diff --git a/src/asio_server.cc b/src/asio_server.cc index a332d091..002c947d 100644 --- a/src/asio_server.cc +++ b/src/asio_server.cc @@ -44,8 +44,12 @@ namespace nghttp2 { namespace asio_http2 { namespace server { -server::server(std::size_t io_service_pool_size) - : io_service_pool_(io_service_pool_size) {} +server::server(std::size_t io_service_pool_size, + const boost::posix_time::time_duration &tls_handshake_timeout, + const boost::posix_time::time_duration &read_timeout) + : io_service_pool_(io_service_pool_size), + tls_handshake_timeout_(tls_handshake_timeout), + read_timeout_(read_timeout) {} boost::system::error_code server::listen_and_serve(boost::system::error_code &ec, @@ -121,7 +125,8 @@ boost::system::error_code server::bind_and_listen(boost::system::error_code &ec, void server::start_accept(boost::asio::ssl::context &tls_context, tcp::acceptor &acceptor, serve_mux &mux) { auto new_connection = std::make_shared>( - mux, io_service_pool_.get_io_service(), tls_context); + mux, tls_handshake_timeout_, read_timeout_, + io_service_pool_.get_io_service(), tls_context); acceptor.async_accept( new_connection->socket().lowest_layer(), @@ -130,14 +135,17 @@ void server::start_accept(boost::asio::ssl::context &tls_context, if (!e) { new_connection->socket().lowest_layer().set_option( tcp::no_delay(true)); + new_connection->start_tls_handshake_deadline(); new_connection->socket().async_handshake( boost::asio::ssl::stream_base::server, [new_connection](const boost::system::error_code &e) { if (e) { + new_connection->stop(); return; } if (!tls_h2_negotiated(new_connection->socket())) { + new_connection->stop(); return; } @@ -151,13 +159,15 @@ void server::start_accept(boost::asio::ssl::context &tls_context, void server::start_accept(tcp::acceptor &acceptor, serve_mux &mux) { auto new_connection = std::make_shared>( - mux, io_service_pool_.get_io_service()); + mux, tls_handshake_timeout_, read_timeout_, + io_service_pool_.get_io_service()); acceptor.async_accept( new_connection->socket(), [this, &acceptor, &mux, new_connection]( const boost::system::error_code &e) { if (!e) { new_connection->socket().set_option(tcp::no_delay(true)); + new_connection->start_read_deadline(); new_connection->start(); } diff --git a/src/asio_server.h b/src/asio_server.h index 49f99446..e1c52792 100644 --- a/src/asio_server.h +++ b/src/asio_server.h @@ -63,7 +63,9 @@ using ssl_socket = boost::asio::ssl::stream; class server : private boost::noncopyable { public: - explicit server(std::size_t io_service_pool_size); + explicit server(std::size_t io_service_pool_size, + const boost::posix_time::time_duration &tls_handshake_timeout, + const boost::posix_time::time_duration &read_timeout); boost::system::error_code listen_and_serve(boost::system::error_code &ec, @@ -98,6 +100,9 @@ private: std::vector acceptors_; std::unique_ptr ssl_ctx_; + + boost::posix_time::time_duration tls_handshake_timeout_; + boost::posix_time::time_duration read_timeout_; }; } // namespace server diff --git a/src/asio_server_connection.h b/src/asio_server_connection.h index beac04f8..4cab44b4 100644 --- a/src/asio_server_connection.h +++ b/src/asio_server_connection.h @@ -64,9 +64,15 @@ class connection : public std::enable_shared_from_this>, public: /// Construct a connection with the given io_service. template - explicit connection(serve_mux &mux, SocketArgs &&... args) - : socket_(std::forward(args)...), mux_(mux), writing_(false) { - } + explicit connection( + serve_mux &mux, + const boost::posix_time::time_duration &tls_handshake_timeout, + const boost::posix_time::time_duration &read_timeout, + SocketArgs &&... args) + : socket_(std::forward(args)...), mux_(mux), + deadline_(socket_.get_io_service()), + tls_handshake_timeout_(tls_handshake_timeout), + read_timeout_(read_timeout), writing_(false), stopped_(false) {} /// Start the first asynchronous operation for the connection. void start() { @@ -74,6 +80,7 @@ public: socket_.get_io_service(), socket_.lowest_layer().remote_endpoint(), [this]() { do_write(); }, mux_); if (handler_->start() != 0) { + stop(); return; } do_read(); @@ -81,27 +88,62 @@ public: socket_type &socket() { return socket_; } + void start_tls_handshake_deadline() { + deadline_.expires_from_now(tls_handshake_timeout_); + deadline_.async_wait( + std::bind(&connection::handle_deadline, this->shared_from_this())); + } + + void start_read_deadline() { + deadline_.expires_from_now(read_timeout_); + deadline_.async_wait( + std::bind(&connection::handle_deadline, this->shared_from_this())); + } + + void handle_deadline() { + if (stopped_) { + return; + } + + if (deadline_.expires_at() <= + boost::asio::deadline_timer::traits_type::now()) { + stop(); + deadline_.expires_at(boost::posix_time::pos_infin); + return; + } + + deadline_.async_wait( + std::bind(&connection::handle_deadline, this->shared_from_this())); + } + void do_read() { auto self = this->shared_from_this(); + deadline_.expires_from_now(read_timeout_); + socket_.async_read_some( boost::asio::buffer(buffer_), [this, self](const boost::system::error_code &e, std::size_t bytes_transferred) { - if (!e) { - if (handler_->on_read(buffer_, bytes_transferred) != 0) { - return; - } - - do_write(); - - if (!writing_ && handler_->should_stop()) { - return; - } - - do_read(); + if (e) { + stop(); + return; } + if (handler_->on_read(buffer_, bytes_transferred) != 0) { + stop(); + return; + } + + do_write(); + + if (!writing_ && handler_->should_stop()) { + stop(); + return; + } + + do_read(); + // If an error occurs then no new asynchronous operations are // started. This means that all shared_ptr references to the // connection object will disappear and the object will be @@ -123,23 +165,34 @@ public: rv = handler_->on_write(outbuf_, nwrite); if (rv != 0) { + stop(); return; } if (nwrite == 0) { + if (handler_->should_stop()) { + stop(); + } return; } writing_ = true; + // Reset read deadline here, because normally client is sending + // something, it does not expect timeout while doing it. + deadline_.expires_from_now(read_timeout_); + boost::asio::async_write( socket_, boost::asio::buffer(outbuf_, nwrite), [this, self](const boost::system::error_code &e, std::size_t) { - if (!e) { - writing_ = false; - - do_write(); + if (e) { + stop(); + return; } + + writing_ = false; + + do_write(); }); // No new asynchronous operations are started. This means that all @@ -148,6 +201,17 @@ public: // returns. The connection class's destructor closes the socket. } + void stop() { + if (stopped_) { + return; + } + + stopped_ = true; + boost::system::error_code ignored_ec; + socket_.lowest_layer().close(ignored_ec); + deadline_.cancel(); + } + private: socket_type socket_; @@ -160,7 +224,12 @@ private: boost::array outbuf_; + boost::asio::deadline_timer deadline_; + boost::posix_time::time_duration tls_handshake_timeout_; + boost::posix_time::time_duration read_timeout_; + bool writing_; + bool stopped_; }; } // namespace server diff --git a/src/asio_server_http2.cc b/src/asio_server_http2.cc index 2cc3495f..5606d833 100644 --- a/src/asio_server_http2.cc +++ b/src/asio_server_http2.cc @@ -69,6 +69,14 @@ void http2::num_threads(size_t num_threads) { impl_->num_threads(num_threads); } void http2::backlog(int backlog) { impl_->backlog(backlog); } +void http2::tls_handshake_timeout(const boost::posix_time::time_duration &t) { + impl_->tls_handshake_timeout(t); +} + +void http2::read_timeout(const boost::posix_time::time_duration &t) { + impl_->read_timeout(t); +} + bool http2::handle(std::string pattern, request_cb cb) { return impl_->handle(std::move(pattern), std::move(cb)); } diff --git a/src/asio_server_http2_impl.cc b/src/asio_server_http2_impl.cc index db1e0105..5e420219 100644 --- a/src/asio_server_http2_impl.cc +++ b/src/asio_server_http2_impl.cc @@ -37,12 +37,16 @@ namespace asio_http2 { namespace server { -http2_impl::http2_impl() : num_threads_(1), backlog_(-1) {} +http2_impl::http2_impl() + : num_threads_(1), backlog_(-1), + tls_handshake_timeout_(boost::posix_time::seconds(60)), + read_timeout_(boost::posix_time::seconds(60)) {} boost::system::error_code http2_impl::listen_and_serve( boost::system::error_code &ec, boost::asio::ssl::context *tls_context, const std::string &address, const std::string &port, bool asynchronous) { - server_.reset(new server(num_threads_)); + server_.reset( + new server(num_threads_, tls_handshake_timeout_, read_timeout_)); return server_->listen_and_serve(ec, tls_context, address, port, backlog_, mux_, asynchronous); } @@ -51,6 +55,15 @@ void http2_impl::num_threads(size_t num_threads) { num_threads_ = num_threads; } void http2_impl::backlog(int backlog) { backlog_ = backlog; } +void http2_impl::tls_handshake_timeout( + const boost::posix_time::time_duration &t) { + tls_handshake_timeout_ = t; +} + +void http2_impl::read_timeout(const boost::posix_time::time_duration &t) { + read_timeout_ = t; +} + bool http2_impl::handle(std::string pattern, request_cb cb) { return mux_.handle(std::move(pattern), std::move(cb)); } diff --git a/src/asio_server_http2_impl.h b/src/asio_server_http2_impl.h index 21c2e924..a3b98552 100644 --- a/src/asio_server_http2_impl.h +++ b/src/asio_server_http2_impl.h @@ -47,6 +47,8 @@ public: const std::string &address, const std::string &port, bool asynchronous); void num_threads(size_t num_threads); void backlog(int backlog); + void tls_handshake_timeout(const boost::posix_time::time_duration &t); + void read_timeout(const boost::posix_time::time_duration &t); bool handle(std::string pattern, request_cb cb); void stop(); void join(); @@ -58,6 +60,8 @@ private: std::size_t num_threads_; int backlog_; serve_mux mux_; + boost::posix_time::time_duration tls_handshake_timeout_; + boost::posix_time::time_duration read_timeout_; }; } // namespace server diff --git a/src/includes/nghttp2/asio_http2_server.h b/src/includes/nghttp2/asio_http2_server.h index 7a580081..5b969593 100644 --- a/src/includes/nghttp2/asio_http2_server.h +++ b/src/includes/nghttp2/asio_http2_server.h @@ -198,6 +198,12 @@ public: // connections. void backlog(int backlog); + // Sets TLS handshake timeout, which defaults to 60 seconds. + void tls_handshake_timeout(const boost::posix_time::time_duration &t); + + // Sets read timeout, which defaults to 60 seconds. + void read_timeout(const boost::posix_time::time_duration &t); + // Gracefully stop http2 server void stop();