Protect "delete this" with a stack refcounter

(to fix use-after-free, too, but "delete this" was a time bomb anyway)
sync-io-test
Vitaliy Filippov 2020-06-01 00:47:49 +03:00
parent 3a5d488f19
commit 3469bead67
1 changed files with 41 additions and 16 deletions

View File

@ -50,7 +50,13 @@ struct http_co_t
websocket_t ws;
int onstack = 0;
bool ended = false;
~http_co_t();
inline void stackin() { onstack++; }
inline void stackout() { onstack--; if (!onstack && ended) end(); }
inline void end() { ended = true; if (!onstack) { delete this; } }
void start_connection();
void handle_events();
void handle_connect_result();
@ -138,12 +144,11 @@ void websocket_t::post_message(int type, const std::string & msg)
void websocket_t::close()
{
delete co;
co->end();
}
http_co_t::~http_co_t()
{
epoll_events = 0;
if (timeout_id >= 0)
{
tfd->clear_timer(timeout_id);
@ -175,14 +180,15 @@ http_co_t::~http_co_t()
void http_co_t::start_connection()
{
stackin();
int port = extract_port(host);
struct sockaddr_in addr;
int r;
if ((r = inet_pton(AF_INET, host.c_str(), &addr.sin_addr)) != 1)
{
parsed.error_code = ENXIO;
// FIXME 'delete this' is ugly...
delete this;
stackout();
end();
return;
}
addr.sin_family = AF_INET;
@ -191,7 +197,8 @@ void http_co_t::start_connection()
if (peer_fd < 0)
{
parsed.error_code = errno;
delete this;
stackout();
end();
return;
}
fcntl(peer_fd, F_SETFL, fcntl(peer_fd, F_GETFL, 0) | O_NONBLOCK);
@ -203,7 +210,7 @@ void http_co_t::start_connection()
{
parsed.error_code = ETIME;
}
delete this;
end();
});
}
epoll_events = 0;
@ -212,7 +219,8 @@ void http_co_t::start_connection()
if (r < 0 && errno != EINPROGRESS)
{
parsed.error_code = errno;
delete this;
stackout();
end();
return;
}
tfd->set_fd_handler(peer_fd, [this](int peer_fd, int epoll_events)
@ -221,10 +229,12 @@ void http_co_t::start_connection()
handle_events();
});
state = HTTP_CO_CONNECTING;
stackout();
}
void http_co_t::handle_events()
{
stackin();
while (epoll_events)
{
if (state == HTTP_CO_CONNECTING)
@ -240,15 +250,17 @@ void http_co_t::handle_events()
}
else if (epoll_events & (EPOLLRDHUP|EPOLLERR))
{
delete this;
return;
end();
break;
}
}
}
stackout();
}
void http_co_t::handle_connect_result()
{
stackin();
int result = 0;
socklen_t result_len = sizeof(result);
if (getsockopt(peer_fd, SOL_SOCKET, SO_ERROR, &result, &result_len) < 0)
@ -258,17 +270,20 @@ void http_co_t::handle_connect_result()
if (result != 0)
{
parsed.error_code = result;
delete this;
stackout();
end();
return;
}
int one = 1;
setsockopt(peer_fd, SOL_TCP, TCP_NODELAY, &one, sizeof(one));
state = HTTP_CO_SENDING_REQUEST;
submit_send();
stackout();
}
void http_co_t::submit_read()
{
stackin();
int res;
if (rbuf.size() != READ_BUFFER_SIZE)
{
@ -288,18 +303,19 @@ void http_co_t::submit_read()
}
else if (res < 0)
{
delete this;
return;
end();
}
else if (res > 0)
{
response += std::string(rbuf.data(), res);
handle_read();
}
stackout();
}
void http_co_t::submit_send()
{
stackin();
int res;
again:
if (sent < request.size())
@ -318,7 +334,8 @@ again:
}
else if (res < 0)
{
delete this;
stackout();
end();
return;
}
sent += res;
@ -338,10 +355,12 @@ again:
goto again;
}
}
stackout();
}
bool http_co_t::handle_read()
{
stackin();
if (state == HTTP_CO_REQUEST_SENT)
{
int pos = response.find("\r\n\r\n");
@ -376,7 +395,8 @@ bool http_co_t::handle_read()
if (!target_response_size)
{
// Sorry, unsupported response
delete this;
stackout();
end();
return false;
}
}
@ -384,7 +404,8 @@ bool http_co_t::handle_read()
}
if (state == HTTP_CO_HEADERS_RECEIVED && target_response_size > 0 && response.size() >= target_response_size)
{
delete this;
stackout();
end();
return false;
}
if (state == HTTP_CO_CHUNKED && response.size() > 0)
@ -412,7 +433,8 @@ bool http_co_t::handle_read()
}
if (parsed.eof)
{
delete this;
stackout();
end();
return false;
}
if (want_streaming && parsed.body.size() > 0)
@ -429,11 +451,13 @@ bool http_co_t::handle_read()
parsed.body = "";
}
}
stackout();
return true;
}
void http_co_t::post_message(int type, const std::string & msg)
{
stackin();
if (state == HTTP_CO_WEBSOCKET)
{
request += ws_format_frame(type, msg.size());
@ -445,6 +469,7 @@ void http_co_t::post_message(int type, const std::string & msg)
ws_outbox += ws_format_frame(type, msg.size());
ws_outbox += msg;
}
stackout();
}
uint64_t stoull_full(const std::string & str, int base)