From e91a5761797e01d8dba180300ea246e8fa7d51ae Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa Date: Thu, 13 Aug 2015 00:04:41 +0900 Subject: [PATCH] nghttpx: Rewrite TLS async handshake using memchunk buffers --- src/memchunk.h | 152 +++++++++++++++++++ src/memchunk_test.cc | 141 ++++++++++++++++++ src/memchunk_test.h | 5 + src/shrpx-unittest.cc | 11 +- src/shrpx_client_handler.cc | 3 +- src/shrpx_connection.cc | 188 +++++++++++------------- src/shrpx_connection.h | 14 +- src/shrpx_http2_session.cc | 3 +- src/shrpx_http_downstream_connection.cc | 2 +- src/shrpx_memcached_connection.cc | 3 +- src/shrpx_rate_limit.cc | 13 +- src/shrpx_rate_limit.h | 16 +- 12 files changed, 421 insertions(+), 130 deletions(-) diff --git a/src/memchunk.h b/src/memchunk.h index 0ec40f26..6638a79e 100644 --- a/src/memchunk.h +++ b/src/memchunk.h @@ -117,6 +117,31 @@ template struct Pool { template struct Memchunks { Memchunks(Pool *pool) : pool(pool), head(nullptr), tail(nullptr), len(0) {} + Memchunks(const Memchunks &) = delete; + Memchunks(Memchunks &&other) + : pool(other.pool), head(other.head), tail(other.head), len(other.len) { + // keep other.pool + other.head = other.tail = nullptr; + other.len = 0; + } + Memchunks &operator=(const Memchunks &) = delete; + Memchunks &operator=(Memchunks &&other) { + if (this == &other) { + return *this; + } + + reset(); + + pool = other.pool; + head = other.head; + tail = other.tail; + len = other.len; + + other.head = other.tail = nullptr; + other.len = 0; + + return *this; + } ~Memchunks() { if (!pool) { return; @@ -223,15 +248,142 @@ template struct Memchunks { return i; } size_t rleft() const { return len; } + void reset() { + for (auto m = head; m;) { + auto next = m->next; + pool->recycle(m); + m = next; + } + len = 0; + head = tail = nullptr; + } Pool *pool; Memchunk *head, *tail; size_t len; }; +// Wrapper around Memchunks to offer "peeking" functionality. +template struct PeekMemchunks { + PeekMemchunks(Pool *pool) + : memchunks(pool), cur(nullptr), cur_pos(nullptr), cur_last(nullptr), + len(0), peeking(true) {} + PeekMemchunks(const PeekMemchunks &) = delete; + PeekMemchunks(PeekMemchunks &&other) + : memchunks(std::move(other.memchunks)), cur(other.cur), + cur_pos(other.cur_pos), cur_last(other.cur_last), len(other.len), + peeking(other.peeking) { + other.reset(); + } + PeekMemchunks &operator=(const PeekMemchunks &) = delete; + PeekMemchunks &operator=(PeekMemchunks &&other) { + if (this == &other) { + return *this; + } + + memchunks = std::move(other.memchunks); + cur = other.cur; + cur_pos = other.cur_pos; + cur_last = other.cur_last; + len = other.len; + peeking = other.peeking; + + other.reset(); + + return *this; + } + size_t append(const void *src, size_t count) { + count = memchunks.append(src, count); + len += count; + return count; + } + size_t remove(void *dest, size_t count) { + if (!peeking) { + count = memchunks.remove(dest, count); + len -= count; + return count; + } + + if (count == 0 || len == 0) { + return 0; + } + + if (!cur) { + cur = memchunks.head; + cur_pos = cur->pos; + } + + // cur_last could be updated in append + cur_last = cur->last; + + if (cur_pos == cur_last) { + assert(cur->next); + cur = cur->next; + } + + auto first = static_cast(dest); + auto last = first + count; + + for (;;) { + auto n = std::min(last - first, cur_last - cur_pos); + + first = std::copy_n(cur_pos, n, first); + cur_pos += n; + len -= n; + + if (first == last) { + break; + } + assert(cur_pos == cur_last); + if (!cur->next) { + break; + } + cur = cur->next; + cur_pos = cur->pos; + cur_last = cur->last; + } + return first - static_cast(dest); + } + size_t rleft() const { return len; } + size_t rleft_buffered() const { return memchunks.rleft(); } + void disable_peek(bool drain) { + if (!peeking) { + return; + } + if (drain) { + auto n = rleft_buffered() - rleft(); + memchunks.drain(n); + assert(len == memchunks.rleft()); + } else { + len = memchunks.rleft(); + } + cur = nullptr; + cur_pos = cur_last = nullptr; + peeking = false; + } + void reset() { + memchunks.reset(); + cur = nullptr; + cur_pos = cur_last = nullptr; + len = 0; + peeking = true; + } + Memchunks memchunks; + // Pointer to the Memchunk currently we are reading/writing. + Memchunk *cur; + // Region inside cur, we have processed to cur_pos. + uint8_t *cur_pos, *cur_last; + // This is the length we have left unprocessed. len <= + // memchunk.rleft() must hold. + size_t len; + // true if peeking is enabled. Initially it is true. + bool peeking; +}; + using Memchunk16K = Memchunk<16_k>; using MemchunkPool = Pool; using DefaultMemchunks = Memchunks; +using DefaultPeekMemchunks = PeekMemchunks; #define DEFAULT_WR_IOVCNT 16 diff --git a/src/memchunk_test.cc b/src/memchunk_test.cc index f95f3032..c206d153 100644 --- a/src/memchunk_test.cc +++ b/src/memchunk_test.cc @@ -84,6 +84,7 @@ void test_pool_recycle(void) { using Memchunk16 = Memchunk<16>; using MemchunkPool16 = Pool; using Memchunks16 = Memchunks; +using PeekMemchunks16 = PeekMemchunks; void test_memchunks_append(void) { MemchunkPool16 pool; @@ -196,4 +197,144 @@ void test_memchunks_recycle(void) { CU_ASSERT(nullptr == m->next); } +void test_memchunks_reset(void) { + MemchunkPool16 pool; + Memchunks16 chunks(&pool); + + std::array b{}; + + chunks.append(b.data(), b.size()); + + CU_ASSERT(32 == chunks.rleft()); + + chunks.reset(); + + CU_ASSERT(0 == chunks.rleft()); + CU_ASSERT(nullptr == chunks.head); + CU_ASSERT(nullptr == chunks.tail); + + auto m = pool.freelist; + + CU_ASSERT(nullptr != m); + CU_ASSERT(nullptr != m->next); + CU_ASSERT(nullptr == m->next->next); +} + +void test_peek_memchunks_append(void) { + MemchunkPool16 pool; + PeekMemchunks16 pchunks(&pool); + + std::array b{{ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', + '5', '6', '7', '8', '9', '0', '1', + }}, + d; + + pchunks.append(b.data(), b.size()); + + CU_ASSERT(32 == pchunks.rleft()); + CU_ASSERT(32 == pchunks.rleft_buffered()); + + CU_ASSERT(0 == pchunks.remove(nullptr, 0)); + + CU_ASSERT(32 == pchunks.rleft()); + CU_ASSERT(32 == pchunks.rleft_buffered()); + + CU_ASSERT(12 == pchunks.remove(d.data(), 12)); + + CU_ASSERT(std::equal(std::begin(b), std::begin(b) + 12, std::begin(d))); + + CU_ASSERT(20 == pchunks.rleft()); + CU_ASSERT(32 == pchunks.rleft_buffered()); + + CU_ASSERT(20 == pchunks.remove(d.data(), d.size())); + + CU_ASSERT(std::equal(std::begin(b) + 12, std::end(b), std::begin(d))); + + CU_ASSERT(0 == pchunks.rleft()); + CU_ASSERT(32 == pchunks.rleft_buffered()); +} + +void test_peek_memchunks_disable_peek_drain(void) { + MemchunkPool16 pool; + PeekMemchunks16 pchunks(&pool); + + std::array b{{ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', + '5', '6', '7', '8', '9', '0', '1', + }}, + d; + + pchunks.append(b.data(), b.size()); + + CU_ASSERT(12 == pchunks.remove(d.data(), 12)); + + pchunks.disable_peek(true); + + CU_ASSERT(!pchunks.peeking); + CU_ASSERT(20 == pchunks.rleft()); + CU_ASSERT(20 == pchunks.rleft_buffered()); + + CU_ASSERT(20 == pchunks.remove(d.data(), d.size())); + + CU_ASSERT(std::equal(std::begin(b) + 12, std::end(b), std::begin(d))); + + CU_ASSERT(0 == pchunks.rleft()); + CU_ASSERT(0 == pchunks.rleft_buffered()); +} + +void test_peek_memchunks_disable_peek_no_drain(void) { + MemchunkPool16 pool; + PeekMemchunks16 pchunks(&pool); + + std::array b{{ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', + '5', '6', '7', '8', '9', '0', '1', + }}, + d; + + pchunks.append(b.data(), b.size()); + + CU_ASSERT(12 == pchunks.remove(d.data(), 12)); + + pchunks.disable_peek(false); + + CU_ASSERT(!pchunks.peeking); + CU_ASSERT(32 == pchunks.rleft()); + CU_ASSERT(32 == pchunks.rleft_buffered()); + + CU_ASSERT(32 == pchunks.remove(d.data(), d.size())); + + CU_ASSERT(std::equal(std::begin(b), std::end(b), std::begin(d))); + + CU_ASSERT(0 == pchunks.rleft()); + CU_ASSERT(0 == pchunks.rleft_buffered()); +} + +void test_peek_memchunks_reset(void) { + MemchunkPool16 pool; + PeekMemchunks16 pchunks(&pool); + + std::array b{{ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', + '5', '6', '7', '8', '9', '0', '1', + }}, + d; + + pchunks.append(b.data(), b.size()); + + CU_ASSERT(12 == pchunks.remove(d.data(), 12)); + + pchunks.disable_peek(true); + pchunks.reset(); + + CU_ASSERT(0 == pchunks.rleft()); + CU_ASSERT(0 == pchunks.rleft_buffered()); + + CU_ASSERT(nullptr == pchunks.cur); + CU_ASSERT(nullptr == pchunks.cur_pos); + CU_ASSERT(nullptr == pchunks.cur_last); + CU_ASSERT(pchunks.peeking); +} + } // namespace nghttp2 diff --git a/src/memchunk_test.h b/src/memchunk_test.h index 9d170a43..068f8255 100644 --- a/src/memchunk_test.h +++ b/src/memchunk_test.h @@ -36,6 +36,11 @@ void test_memchunks_append(void); void test_memchunks_drain(void); void test_memchunks_riovec(void); void test_memchunks_recycle(void); +void test_memchunks_reset(void); +void test_peek_memchunks_append(void); +void test_peek_memchunks_disable_peek_drain(void); +void test_peek_memchunks_disable_peek_no_drain(void); +void test_peek_memchunks_reset(void); } // namespace nghttp2 diff --git a/src/shrpx-unittest.cc b/src/shrpx-unittest.cc index a783336d..10eefb8b 100644 --- a/src/shrpx-unittest.cc +++ b/src/shrpx-unittest.cc @@ -172,7 +172,16 @@ int main(int argc, char *argv[]) { !CU_add_test(pSuite, "memchunk_drain", nghttp2::test_memchunks_drain) || !CU_add_test(pSuite, "memchunk_riovec", nghttp2::test_memchunks_riovec) || !CU_add_test(pSuite, "memchunk_recycle", - nghttp2::test_memchunks_recycle)) { + nghttp2::test_memchunks_recycle) || + !CU_add_test(pSuite, "memchunk_reset", nghttp2::test_memchunks_reset) || + !CU_add_test(pSuite, "peek_memchunk_append", + nghttp2::test_peek_memchunks_append) || + !CU_add_test(pSuite, "peek_memchunk_disable_peek_drain", + nghttp2::test_peek_memchunks_disable_peek_drain) || + !CU_add_test(pSuite, "peek_memchunk_disable_peek_no_drain", + nghttp2::test_peek_memchunks_disable_peek_no_drain) || + !CU_add_test(pSuite, "peek_memchunk_reset", + nghttp2::test_peek_memchunks_reset)) { CU_cleanup_registry(); return CU_get_error(); } diff --git a/src/shrpx_client_handler.cc b/src/shrpx_client_handler.cc index ba32323f..df829181 100644 --- a/src/shrpx_client_handler.cc +++ b/src/shrpx_client_handler.cc @@ -358,7 +358,8 @@ int ClientHandler::upstream_http1_connhd_read() { ClientHandler::ClientHandler(Worker *worker, int fd, SSL *ssl, const char *ipaddr, const char *port) - : conn_(worker->get_loop(), fd, ssl, get_config()->upstream_write_timeout, + : conn_(worker->get_loop(), fd, ssl, worker->get_mcpool(), + get_config()->upstream_write_timeout, get_config()->upstream_read_timeout, get_config()->write_rate, get_config()->write_burst, get_config()->read_rate, get_config()->read_burst, writecb, readcb, timeoutcb, this), diff --git a/src/shrpx_connection.cc b/src/shrpx_connection.cc index c7c52722..0a9a77bb 100644 --- a/src/shrpx_connection.cc +++ b/src/shrpx_connection.cc @@ -40,12 +40,13 @@ using namespace nghttp2; namespace shrpx { Connection::Connection(struct ev_loop *loop, int fd, SSL *ssl, - ev_tstamp write_timeout, ev_tstamp read_timeout, - size_t write_rate, size_t write_burst, size_t read_rate, - size_t read_burst, IOCb writecb, IOCb readcb, - TimerCb timeoutcb, void *data) - : tls{}, wlimit(loop, &wev, write_rate, write_burst), - rlimit(loop, &rev, read_rate, read_burst, ssl), writecb(writecb), + MemchunkPool *mcpool, ev_tstamp write_timeout, + ev_tstamp read_timeout, size_t write_rate, + size_t write_burst, size_t read_rate, size_t read_burst, + IOCb writecb, IOCb readcb, TimerCb timeoutcb, void *data) + : tls{DefaultMemchunks(mcpool), DefaultPeekMemchunks(mcpool)}, + wlimit(loop, &wev, write_rate, write_burst), + rlimit(loop, &rev, read_rate, read_burst, this), writecb(writecb), readcb(readcb), timeoutcb(timeoutcb), loop(loop), data(data), fd(fd) { ev_io_init(&wev, writecb, fd, EV_WRITE); @@ -83,10 +84,12 @@ void Connection::disconnect() { if (tls.cached_session) { SSL_SESSION_free(tls.cached_session); + tls.cached_session = nullptr; } if (tls.cached_session_lookup_req) { tls.cached_session_lookup_req->canceled = true; + tls.cached_session_lookup_req = nullptr; } // To reuse SSL/TLS session, we have to shutdown, and don't free @@ -96,7 +99,15 @@ void Connection::disconnect() { tls.ssl = nullptr; } - tls = {tls.ssl}; + tls.wbuf.reset(); + tls.rbuf.reset(); + tls.last_write_idle = 0.; + tls.warmup_writelen = 0; + tls.last_writelen = 0; + tls.last_readlen = 0; + tls.handshake_state = 0; + tls.initial_handshake_done = false; + tls.reneg_started = false; } if (fd != -1) { @@ -114,22 +125,9 @@ void Connection::disconnect() { wlimit.stopw(); } -namespace { -void allocate_buffer(Connection *conn) { - conn->tls.rb = make_unique>(); - conn->tls.wb = make_unique>(); -} -} // namespace +void Connection::prepare_client_handshake() { SSL_set_connect_state(tls.ssl); } -void Connection::prepare_client_handshake() { - SSL_set_connect_state(tls.ssl); - allocate_buffer(this); -} - -void Connection::prepare_server_handshake() { - SSL_set_accept_state(tls.ssl); - allocate_buffer(this); -} +void Connection::prepare_server_handshake() { SSL_set_accept_state(tls.ssl); } // BIO implementation is inspired by openldap implementation: // http://www.openldap.org/devel/cvsweb.cgi/~checkout~/libraries/libldap/tls_o.c @@ -140,27 +138,26 @@ int shrpx_bio_write(BIO *b, const char *buf, int len) { } auto conn = static_cast(b->ptr); - auto &wb = conn->tls.wb; + auto &wbuf = conn->tls.wbuf; BIO_clear_retry_flags(b); if (conn->tls.initial_handshake_done) { // After handshake finished, send |buf| of length |len| to the // socket directly. - if (wb && wb->rleft()) { - auto nwrite = conn->write_clear(wb->pos, wb->rleft()); + if (wbuf.rleft()) { + std::array iov; + auto iovcnt = wbuf.riovec(iov.data(), iov.size()); + auto nwrite = conn->writev_clear(iov.data(), iovcnt); if (nwrite < 0) { return -1; } - wb->drain(nwrite); - if (wb->rleft()) { + wbuf.drain(nwrite); + if (wbuf.rleft()) { BIO_set_retry_write(b); return -1; } - - // Here delete TLS write buffer - wb.reset(); } auto nwrite = conn->write_clear(buf, len); if (nwrite < 0) { @@ -175,16 +172,9 @@ int shrpx_bio_write(BIO *b, const char *buf, int len) { return nwrite; } - auto nwrite = std::min(static_cast(len), wb->wleft()); + wbuf.append(buf, len); - if (nwrite == 0) { - BIO_set_retry_write(b); - return -1; - } - - wb->write(buf, nwrite); - - return nwrite; + return len; } } // namespace @@ -195,11 +185,11 @@ int shrpx_bio_read(BIO *b, char *buf, int len) { } auto conn = static_cast(b->ptr); - auto &rb = conn->tls.rb; + auto &rbuf = conn->tls.rbuf; BIO_clear_retry_flags(b); - if (conn->tls.initial_handshake_done && !rb) { + if (conn->tls.initial_handshake_done && rbuf.rleft() == 0) { auto nread = conn->read_clear(buf, len); if (nread < 0) { return -1; @@ -211,22 +201,12 @@ int shrpx_bio_read(BIO *b, char *buf, int len) { return nread; } - auto nread = std::min(static_cast(len), rb->rleft()); - - if (nread == 0) { - if (conn->tls.initial_handshake_done) { - rb.reset(); - } - + if (rbuf.rleft() == 0) { BIO_set_retry_read(b); return -1; } - std::copy_n(rb->pos, nread, buf); - - rb->drain(nread); - - return nread; + return rbuf.remove(buf, len); } } // namespace @@ -289,51 +269,47 @@ void Connection::set_ssl(SSL *ssl) { bio->ptr = this; SSL_set_bio(tls.ssl, bio, bio); SSL_set_app_data(tls.ssl, this); - rlimit.set_ssl(tls.ssl); } +namespace { +// We should buffer at least full encrypted TLS record here. +// Theoretically, peer can send client hello in several TLS records, +// which could exeed this limit, but it is not portable, and we don't +// have to handle such exotic behaviour. +bool read_buffer_full(DefaultPeekMemchunks &rbuf) { + return rbuf.rleft_buffered() >= 20_k; +} +} // namespace + int Connection::tls_handshake() { wlimit.stopw(); ev_timer_stop(loop, &wt); - auto nread = read_clear(tls.rb->last, tls.rb->wleft()); - if (nread < 0) { - if (LOG_ENABLED(INFO)) { - LOG(INFO) << "tls: handshake read error"; - } - return -1; - } - tls.rb->write(nread); - - // We have limited space for read buffer, so stop reading if it - // filled up. - if (tls.rb->wleft() == 0) { - if (tls.handshake_state != TLS_CONN_WRITE_STARTED) { - // Reading 16KiB before writing server hello is unlikely for - // ordinary client. + if (ev_is_active(&rev)) { + std::array buf; + auto nread = read_clear(buf.data(), buf.size()); + if (nread < 0) { if (LOG_ENABLED(INFO)) { - LOG(INFO) << "tls: client hello is too large"; + LOG(INFO) << "tls: handshake read error"; } return -1; } - - rlimit.stopw(); - ev_timer_stop(loop, &rt); + tls.rbuf.append(buf.data(), nread); + if (read_buffer_full(tls.rbuf)) { + rlimit.stopw(); + } } switch (tls.handshake_state) { case TLS_CONN_WAIT_FOR_SESSION_CACHE: - if (nread > 0) { - if (LOG_ENABLED(INFO)) { - LOG(INFO) << "tls: client sent addtional data after client hello"; - } - return -1; - } return SHRPX_ERR_INPROGRESS; case TLS_CONN_GOT_SESSION_CACHE: { - // Use the same trick invented by @kazuho in h2o project - tls.wb->reset(); - tls.rb->pos = tls.rb->begin(); + // Use the same trick invented by @kazuho in h2o project. + + // Discard all outgoing data. + tls.wbuf.reset(); + // Rewind buffered incoming data to replay client hello. + tls.rbuf.disable_peek(false); auto ssl_ctx = SSL_get_SSL_CTX(tls.ssl); auto ssl_opts = SSL_get_options(tls.ssl); @@ -382,32 +358,33 @@ int Connection::tls_handshake() { return SHRPX_ERR_INPROGRESS; } - if (tls.wb->rleft()) { + if (tls.wbuf.rleft()) { // First write indicates that resumption stuff has done. - tls.handshake_state = TLS_CONN_WRITE_STARTED; - auto nwrite = write_clear(tls.wb->pos, tls.wb->rleft()); + if (tls.handshake_state != TLS_CONN_WRITE_STARTED) { + tls.handshake_state = TLS_CONN_WRITE_STARTED; + // If peek has already disabled, this is noop. + tls.rbuf.disable_peek(true); + } + std::array iov; + auto iovcnt = tls.wbuf.riovec(iov.data(), iov.size()); + auto nwrite = writev_clear(iov.data(), iovcnt); if (nwrite < 0) { if (LOG_ENABLED(INFO)) { LOG(INFO) << "tls: handshake write error"; } return -1; } - tls.wb->drain(nwrite); + tls.wbuf.drain(nwrite); + + if (tls.wbuf.rleft()) { + wlimit.startw(); + ev_timer_again(loop, &wt); + } } - if (tls.wb->rleft()) { - wlimit.startw(); - ev_timer_again(loop, &wt); - } else { - tls.wb->reset(); - } - - if (tls.handshake_state == TLS_CONN_WRITE_STARTED && tls.rb->rleft() == 0) { - tls.rb->reset(); - + if (!read_buffer_full(tls.rbuf)) { // We may have stopped reading rlimit.startw(); - ev_timer_again(loop, &rt); } if (rv != 1) { @@ -419,13 +396,14 @@ int Connection::tls_handshake() { tls.initial_handshake_done = true; - if (tls.rb->rleft()) { - ev_feed_event(loop, &rev, EV_READ); - } - - // We may have stopped reading + // We have to start read watcher, since later stage of code expects + // this. rlimit.startw(); - ev_timer_again(loop, &rt); + + // We may have whole request in tls.rbuf. This means that we don't + // get notified further read event. This is especially true for + // HTTP/1.1. + handle_tls_pending_read(); if (LOG_ENABLED(INFO)) { LOG(INFO) << "SSL/TLS handshake completed"; @@ -506,8 +484,8 @@ ssize_t Connection::write_tls(const void *data, size_t len) { return SHRPX_ERR_NETWORK; case SSL_ERROR_WANT_WRITE: tls.last_writelen = len; - wlimit.startw(); - ev_timer_again(loop, &wt); + // starting write watcher and timer is done in write_clear via + // bio. return 0; default: if (LOG_ENABLED(INFO)) { diff --git a/src/shrpx_connection.h b/src/shrpx_connection.h index 2f4105ab..1d6a01c5 100644 --- a/src/shrpx_connection.h +++ b/src/shrpx_connection.h @@ -35,7 +35,7 @@ #include "shrpx_rate_limit.h" #include "shrpx_error.h" -#include "buffer.h" +#include "memchunk.h" namespace shrpx { @@ -50,6 +50,8 @@ enum { }; struct TLSConnection { + DefaultMemchunks wbuf; + DefaultPeekMemchunks rbuf; SSL *ssl; SSL_SESSION *cached_session; MemcachedRequest *cached_session_lookup_req; @@ -62,8 +64,6 @@ struct TLSConnection { int handshake_state; bool initial_handshake_done; bool reneg_started; - std::unique_ptr> rb; - std::unique_ptr> wb; }; template using EVCb = void (*)(struct ev_loop *, T *, int); @@ -72,10 +72,10 @@ using IOCb = EVCb; using TimerCb = EVCb; struct Connection { - Connection(struct ev_loop *loop, int fd, SSL *ssl, ev_tstamp write_timeout, - ev_tstamp read_timeout, size_t write_rate, size_t write_burst, - size_t read_rate, size_t read_burst, IOCb writecb, IOCb readcb, - TimerCb timeoutcb, void *data); + Connection(struct ev_loop *loop, int fd, SSL *ssl, MemchunkPool *mcpool, + ev_tstamp write_timeout, ev_tstamp read_timeout, size_t write_rate, + size_t write_burst, size_t read_rate, size_t read_burst, + IOCb writecb, IOCb readcb, TimerCb timeoutcb, void *data); ~Connection(); void disconnect(); diff --git a/src/shrpx_http2_session.cc b/src/shrpx_http2_session.cc index 59ad9327..c88adf4f 100644 --- a/src/shrpx_http2_session.cc +++ b/src/shrpx_http2_session.cc @@ -144,7 +144,8 @@ void writecb(struct ev_loop *loop, ev_io *w, int revents) { Http2Session::Http2Session(struct ev_loop *loop, SSL_CTX *ssl_ctx, ConnectBlocker *connect_blocker, Worker *worker, size_t group, size_t idx) - : conn_(loop, -1, nullptr, get_config()->downstream_write_timeout, + : conn_(loop, -1, nullptr, worker->get_mcpool(), + get_config()->downstream_write_timeout, get_config()->downstream_read_timeout, 0, 0, 0, 0, writecb, readcb, timeoutcb, this), worker_(worker), connect_blocker_(connect_blocker), ssl_ctx_(ssl_ctx), diff --git a/src/shrpx_http_downstream_connection.cc b/src/shrpx_http_downstream_connection.cc index 3bcadc9e..43d1a300 100644 --- a/src/shrpx_http_downstream_connection.cc +++ b/src/shrpx_http_downstream_connection.cc @@ -112,7 +112,7 @@ void connectcb(struct ev_loop *loop, ev_io *w, int revents) { HttpDownstreamConnection::HttpDownstreamConnection( DownstreamConnectionPool *dconn_pool, size_t group, struct ev_loop *loop) : DownstreamConnection(dconn_pool), - conn_(loop, -1, nullptr, get_config()->downstream_write_timeout, + conn_(loop, -1, nullptr, nullptr, get_config()->downstream_write_timeout, get_config()->downstream_read_timeout, 0, 0, 0, 0, connectcb, readcb, timeoutcb, this), ioctrl_(&conn_.rlimit), response_htp_{0}, group_(group), addr_idx_(0), diff --git a/src/shrpx_memcached_connection.cc b/src/shrpx_memcached_connection.cc index b954b8d7..5f6b93f7 100644 --- a/src/shrpx_memcached_connection.cc +++ b/src/shrpx_memcached_connection.cc @@ -92,7 +92,7 @@ constexpr ev_tstamp read_timeout = 10.; MemcachedConnection::MemcachedConnection(const Address *addr, struct ev_loop *loop) - : conn_(loop, -1, nullptr, write_timeout, read_timeout, 0, 0, 0, 0, + : conn_(loop, -1, nullptr, nullptr, write_timeout, read_timeout, 0, 0, 0, 0, connectcb, readcb, timeoutcb, this), parse_state_{}, addr_(addr), sendsum_(0), connected_(false) {} @@ -403,6 +403,7 @@ int MemcachedConnection::parse_packet() { return 0; } +#undef DEFAULT_WR_IOVCNT #define DEFAULT_WR_IOVCNT 128 #if defined(IOV_MAX) && IOV_MAX < DEFAULT_WR_IOVCNT diff --git a/src/shrpx_rate_limit.cc b/src/shrpx_rate_limit.cc index 53451f4c..e184921b 100644 --- a/src/shrpx_rate_limit.cc +++ b/src/shrpx_rate_limit.cc @@ -26,6 +26,8 @@ #include +#include "shrpx_connection.h" + namespace shrpx { namespace { @@ -36,9 +38,9 @@ void regencb(struct ev_loop *loop, ev_timer *w, int revents) { } // namespace RateLimit::RateLimit(struct ev_loop *loop, ev_io *w, size_t rate, size_t burst, - SSL *ssl) - : w_(w), loop_(loop), ssl_(ssl), rate_(rate), burst_(burst), avail_(burst), - startw_req_(false) { + Connection *conn) + : w_(w), loop_(loop), conn_(conn), rate_(rate), burst_(burst), + avail_(burst), startw_req_(false) { ev_timer_init(&t_, regencb, 0., 1.); t_.data = this; if (rate_ > 0) { @@ -97,7 +99,8 @@ void RateLimit::stopw() { } void RateLimit::handle_tls_pending_read() { - if (!ssl_ || SSL_pending(ssl_) == 0) { + if (!conn_ || !conn_->tls.ssl || + (SSL_pending(conn_->tls.ssl) == 0 && conn_->tls.rbuf.rleft() == 0)) { return; } @@ -106,6 +109,4 @@ void RateLimit::handle_tls_pending_read() { ev_feed_event(loop_, w_, EV_READ); } -void RateLimit::set_ssl(SSL *ssl) { ssl_ = ssl; } - } // namespace shrpx diff --git a/src/shrpx_rate_limit.h b/src/shrpx_rate_limit.h index 24c96275..7502a27a 100644 --- a/src/shrpx_rate_limit.h +++ b/src/shrpx_rate_limit.h @@ -33,28 +33,30 @@ namespace shrpx { +struct Connection; + class RateLimit { public: - // We need |ssl| object to check that it has unread decrypted bytes. + // We need |conn| object to check that it has unread bytes for TLS + // connection. RateLimit(struct ev_loop *loop, ev_io *w, size_t rate, size_t burst, - SSL *ssl = nullptr); + Connection *conn = nullptr); ~RateLimit(); size_t avail() const; void drain(size_t n); void regen(); void startw(); void stopw(); - // Feeds event if ssl_ object has unread decrypted bytes. This is - // required since it is buffered in ssl_ object, io event is not - // generated unless new incoming data is received. + // Feeds event if conn_->tls object has unread bytes. This is + // required since it is buffered in conn_->tls object, io event is + // not generated unless new incoming data is received. void handle_tls_pending_read(); - void set_ssl(SSL *ssl); private: ev_timer t_; ev_io *w_; struct ev_loop *loop_; - SSL *ssl_; + Connection *conn_; size_t rate_; size_t burst_; size_t avail_;