asio: Add configurable tls handshake/read timeout to server

This commit is contained in:
Tatsuhiro Tsujikawa 2015-12-22 00:33:12 +09:00
parent 09bd9c94a3
commit 1ee1122d40
7 changed files with 141 additions and 26 deletions

View File

@ -44,8 +44,12 @@ namespace nghttp2 {
namespace asio_http2 { namespace asio_http2 {
namespace server { namespace server {
server::server(std::size_t io_service_pool_size) server::server(std::size_t io_service_pool_size,
: io_service_pool_(io_service_pool_size) {} const boost::posix_time::time_duration &tls_handshake_timeout,
const boost::posix_time::time_duration &read_timeout)
: io_service_pool_(io_service_pool_size),
tls_handshake_timeout_(tls_handshake_timeout),
read_timeout_(read_timeout) {}
boost::system::error_code boost::system::error_code
server::listen_and_serve(boost::system::error_code &ec, server::listen_and_serve(boost::system::error_code &ec,
@ -121,7 +125,8 @@ boost::system::error_code server::bind_and_listen(boost::system::error_code &ec,
void server::start_accept(boost::asio::ssl::context &tls_context, void server::start_accept(boost::asio::ssl::context &tls_context,
tcp::acceptor &acceptor, serve_mux &mux) { tcp::acceptor &acceptor, serve_mux &mux) {
auto new_connection = std::make_shared<connection<ssl_socket>>( auto new_connection = std::make_shared<connection<ssl_socket>>(
mux, io_service_pool_.get_io_service(), tls_context); mux, tls_handshake_timeout_, read_timeout_,
io_service_pool_.get_io_service(), tls_context);
acceptor.async_accept( acceptor.async_accept(
new_connection->socket().lowest_layer(), new_connection->socket().lowest_layer(),
@ -130,14 +135,17 @@ void server::start_accept(boost::asio::ssl::context &tls_context,
if (!e) { if (!e) {
new_connection->socket().lowest_layer().set_option( new_connection->socket().lowest_layer().set_option(
tcp::no_delay(true)); tcp::no_delay(true));
new_connection->start_tls_handshake_deadline();
new_connection->socket().async_handshake( new_connection->socket().async_handshake(
boost::asio::ssl::stream_base::server, boost::asio::ssl::stream_base::server,
[new_connection](const boost::system::error_code &e) { [new_connection](const boost::system::error_code &e) {
if (e) { if (e) {
new_connection->stop();
return; return;
} }
if (!tls_h2_negotiated(new_connection->socket())) { if (!tls_h2_negotiated(new_connection->socket())) {
new_connection->stop();
return; return;
} }
@ -151,13 +159,15 @@ void server::start_accept(boost::asio::ssl::context &tls_context,
void server::start_accept(tcp::acceptor &acceptor, serve_mux &mux) { void server::start_accept(tcp::acceptor &acceptor, serve_mux &mux) {
auto new_connection = std::make_shared<connection<tcp::socket>>( auto new_connection = std::make_shared<connection<tcp::socket>>(
mux, io_service_pool_.get_io_service()); mux, tls_handshake_timeout_, read_timeout_,
io_service_pool_.get_io_service());
acceptor.async_accept( acceptor.async_accept(
new_connection->socket(), [this, &acceptor, &mux, new_connection]( new_connection->socket(), [this, &acceptor, &mux, new_connection](
const boost::system::error_code &e) { const boost::system::error_code &e) {
if (!e) { if (!e) {
new_connection->socket().set_option(tcp::no_delay(true)); new_connection->socket().set_option(tcp::no_delay(true));
new_connection->start_read_deadline();
new_connection->start(); new_connection->start();
} }

View File

@ -63,7 +63,9 @@ using ssl_socket = boost::asio::ssl::stream<tcp::socket>;
class server : private boost::noncopyable { class server : private boost::noncopyable {
public: public:
explicit server(std::size_t io_service_pool_size); explicit server(std::size_t io_service_pool_size,
const boost::posix_time::time_duration &tls_handshake_timeout,
const boost::posix_time::time_duration &read_timeout);
boost::system::error_code boost::system::error_code
listen_and_serve(boost::system::error_code &ec, listen_and_serve(boost::system::error_code &ec,
@ -98,6 +100,9 @@ private:
std::vector<tcp::acceptor> acceptors_; std::vector<tcp::acceptor> acceptors_;
std::unique_ptr<boost::asio::ssl::context> ssl_ctx_; std::unique_ptr<boost::asio::ssl::context> ssl_ctx_;
boost::posix_time::time_duration tls_handshake_timeout_;
boost::posix_time::time_duration read_timeout_;
}; };
} // namespace server } // namespace server

View File

@ -64,9 +64,15 @@ class connection : public std::enable_shared_from_this<connection<socket_type>>,
public: public:
/// Construct a connection with the given io_service. /// Construct a connection with the given io_service.
template <typename... SocketArgs> template <typename... SocketArgs>
explicit connection(serve_mux &mux, SocketArgs &&... args) explicit connection(
: socket_(std::forward<SocketArgs>(args)...), mux_(mux), writing_(false) { serve_mux &mux,
} const boost::posix_time::time_duration &tls_handshake_timeout,
const boost::posix_time::time_duration &read_timeout,
SocketArgs &&... args)
: socket_(std::forward<SocketArgs>(args)...), mux_(mux),
deadline_(socket_.get_io_service()),
tls_handshake_timeout_(tls_handshake_timeout),
read_timeout_(read_timeout), writing_(false), stopped_(false) {}
/// Start the first asynchronous operation for the connection. /// Start the first asynchronous operation for the connection.
void start() { void start() {
@ -74,6 +80,7 @@ public:
socket_.get_io_service(), socket_.lowest_layer().remote_endpoint(), socket_.get_io_service(), socket_.lowest_layer().remote_endpoint(),
[this]() { do_write(); }, mux_); [this]() { do_write(); }, mux_);
if (handler_->start() != 0) { if (handler_->start() != 0) {
stop();
return; return;
} }
do_read(); do_read();
@ -81,27 +88,62 @@ public:
socket_type &socket() { return socket_; } socket_type &socket() { return socket_; }
void start_tls_handshake_deadline() {
deadline_.expires_from_now(tls_handshake_timeout_);
deadline_.async_wait(
std::bind(&connection::handle_deadline, this->shared_from_this()));
}
void start_read_deadline() {
deadline_.expires_from_now(read_timeout_);
deadline_.async_wait(
std::bind(&connection::handle_deadline, this->shared_from_this()));
}
void handle_deadline() {
if (stopped_) {
return;
}
if (deadline_.expires_at() <=
boost::asio::deadline_timer::traits_type::now()) {
stop();
deadline_.expires_at(boost::posix_time::pos_infin);
return;
}
deadline_.async_wait(
std::bind(&connection::handle_deadline, this->shared_from_this()));
}
void do_read() { void do_read() {
auto self = this->shared_from_this(); auto self = this->shared_from_this();
deadline_.expires_from_now(read_timeout_);
socket_.async_read_some( socket_.async_read_some(
boost::asio::buffer(buffer_), boost::asio::buffer(buffer_),
[this, self](const boost::system::error_code &e, [this, self](const boost::system::error_code &e,
std::size_t bytes_transferred) { std::size_t bytes_transferred) {
if (!e) { if (e) {
if (handler_->on_read(buffer_, bytes_transferred) != 0) { stop();
return; return;
}
do_write();
if (!writing_ && handler_->should_stop()) {
return;
}
do_read();
} }
if (handler_->on_read(buffer_, bytes_transferred) != 0) {
stop();
return;
}
do_write();
if (!writing_ && handler_->should_stop()) {
stop();
return;
}
do_read();
// If an error occurs then no new asynchronous operations are // If an error occurs then no new asynchronous operations are
// started. This means that all shared_ptr references to the // started. This means that all shared_ptr references to the
// connection object will disappear and the object will be // connection object will disappear and the object will be
@ -123,23 +165,34 @@ public:
rv = handler_->on_write(outbuf_, nwrite); rv = handler_->on_write(outbuf_, nwrite);
if (rv != 0) { if (rv != 0) {
stop();
return; return;
} }
if (nwrite == 0) { if (nwrite == 0) {
if (handler_->should_stop()) {
stop();
}
return; return;
} }
writing_ = true; writing_ = true;
// Reset read deadline here, because normally client is sending
// something, it does not expect timeout while doing it.
deadline_.expires_from_now(read_timeout_);
boost::asio::async_write( boost::asio::async_write(
socket_, boost::asio::buffer(outbuf_, nwrite), socket_, boost::asio::buffer(outbuf_, nwrite),
[this, self](const boost::system::error_code &e, std::size_t) { [this, self](const boost::system::error_code &e, std::size_t) {
if (!e) { if (e) {
writing_ = false; stop();
return;
do_write();
} }
writing_ = false;
do_write();
}); });
// No new asynchronous operations are started. This means that all // No new asynchronous operations are started. This means that all
@ -148,6 +201,17 @@ public:
// returns. The connection class's destructor closes the socket. // returns. The connection class's destructor closes the socket.
} }
void stop() {
if (stopped_) {
return;
}
stopped_ = true;
boost::system::error_code ignored_ec;
socket_.lowest_layer().close(ignored_ec);
deadline_.cancel();
}
private: private:
socket_type socket_; socket_type socket_;
@ -160,7 +224,12 @@ private:
boost::array<uint8_t, 64_k> outbuf_; boost::array<uint8_t, 64_k> outbuf_;
boost::asio::deadline_timer deadline_;
boost::posix_time::time_duration tls_handshake_timeout_;
boost::posix_time::time_duration read_timeout_;
bool writing_; bool writing_;
bool stopped_;
}; };
} // namespace server } // namespace server

View File

@ -69,6 +69,14 @@ void http2::num_threads(size_t num_threads) { impl_->num_threads(num_threads); }
void http2::backlog(int backlog) { impl_->backlog(backlog); } void http2::backlog(int backlog) { impl_->backlog(backlog); }
void http2::tls_handshake_timeout(const boost::posix_time::time_duration &t) {
impl_->tls_handshake_timeout(t);
}
void http2::read_timeout(const boost::posix_time::time_duration &t) {
impl_->read_timeout(t);
}
bool http2::handle(std::string pattern, request_cb cb) { bool http2::handle(std::string pattern, request_cb cb) {
return impl_->handle(std::move(pattern), std::move(cb)); return impl_->handle(std::move(pattern), std::move(cb));
} }

View File

@ -37,12 +37,16 @@ namespace asio_http2 {
namespace server { namespace server {
http2_impl::http2_impl() : num_threads_(1), backlog_(-1) {} http2_impl::http2_impl()
: num_threads_(1), backlog_(-1),
tls_handshake_timeout_(boost::posix_time::seconds(60)),
read_timeout_(boost::posix_time::seconds(60)) {}
boost::system::error_code http2_impl::listen_and_serve( boost::system::error_code http2_impl::listen_and_serve(
boost::system::error_code &ec, boost::asio::ssl::context *tls_context, boost::system::error_code &ec, boost::asio::ssl::context *tls_context,
const std::string &address, const std::string &port, bool asynchronous) { const std::string &address, const std::string &port, bool asynchronous) {
server_.reset(new server(num_threads_)); server_.reset(
new server(num_threads_, tls_handshake_timeout_, read_timeout_));
return server_->listen_and_serve(ec, tls_context, address, port, backlog_, return server_->listen_and_serve(ec, tls_context, address, port, backlog_,
mux_, asynchronous); mux_, asynchronous);
} }
@ -51,6 +55,15 @@ void http2_impl::num_threads(size_t num_threads) { num_threads_ = num_threads; }
void http2_impl::backlog(int backlog) { backlog_ = backlog; } void http2_impl::backlog(int backlog) { backlog_ = backlog; }
void http2_impl::tls_handshake_timeout(
const boost::posix_time::time_duration &t) {
tls_handshake_timeout_ = t;
}
void http2_impl::read_timeout(const boost::posix_time::time_duration &t) {
read_timeout_ = t;
}
bool http2_impl::handle(std::string pattern, request_cb cb) { bool http2_impl::handle(std::string pattern, request_cb cb) {
return mux_.handle(std::move(pattern), std::move(cb)); return mux_.handle(std::move(pattern), std::move(cb));
} }

View File

@ -47,6 +47,8 @@ public:
const std::string &address, const std::string &port, bool asynchronous); const std::string &address, const std::string &port, bool asynchronous);
void num_threads(size_t num_threads); void num_threads(size_t num_threads);
void backlog(int backlog); void backlog(int backlog);
void tls_handshake_timeout(const boost::posix_time::time_duration &t);
void read_timeout(const boost::posix_time::time_duration &t);
bool handle(std::string pattern, request_cb cb); bool handle(std::string pattern, request_cb cb);
void stop(); void stop();
void join(); void join();
@ -58,6 +60,8 @@ private:
std::size_t num_threads_; std::size_t num_threads_;
int backlog_; int backlog_;
serve_mux mux_; serve_mux mux_;
boost::posix_time::time_duration tls_handshake_timeout_;
boost::posix_time::time_duration read_timeout_;
}; };
} // namespace server } // namespace server

View File

@ -198,6 +198,12 @@ public:
// connections. // connections.
void backlog(int backlog); void backlog(int backlog);
// Sets TLS handshake timeout, which defaults to 60 seconds.
void tls_handshake_timeout(const boost::posix_time::time_duration &t);
// Sets read timeout, which defaults to 60 seconds.
void read_timeout(const boost::posix_time::time_duration &t);
// Gracefully stop http2 server // Gracefully stop http2 server
void stop(); void stop();