nghttpx: Rewrite TLS async handshake using memchunk buffers

This commit is contained in:
Tatsuhiro Tsujikawa 2015-08-13 00:04:41 +09:00
parent 72c661f1dd
commit e91a576179
12 changed files with 421 additions and 130 deletions

View File

@ -117,6 +117,31 @@ template <typename T> struct Pool {
template <typename Memchunk> struct Memchunks { template <typename Memchunk> struct Memchunks {
Memchunks(Pool<Memchunk> *pool) Memchunks(Pool<Memchunk> *pool)
: pool(pool), head(nullptr), tail(nullptr), len(0) {} : pool(pool), head(nullptr), tail(nullptr), len(0) {}
Memchunks(const Memchunks &) = delete;
Memchunks(Memchunks &&other)
: pool(other.pool), head(other.head), tail(other.head), len(other.len) {
// keep other.pool
other.head = other.tail = nullptr;
other.len = 0;
}
Memchunks &operator=(const Memchunks &) = delete;
Memchunks &operator=(Memchunks &&other) {
if (this == &other) {
return *this;
}
reset();
pool = other.pool;
head = other.head;
tail = other.tail;
len = other.len;
other.head = other.tail = nullptr;
other.len = 0;
return *this;
}
~Memchunks() { ~Memchunks() {
if (!pool) { if (!pool) {
return; return;
@ -223,15 +248,142 @@ template <typename Memchunk> struct Memchunks {
return i; return i;
} }
size_t rleft() const { return len; } size_t rleft() const { return len; }
void reset() {
for (auto m = head; m;) {
auto next = m->next;
pool->recycle(m);
m = next;
}
len = 0;
head = tail = nullptr;
}
Pool<Memchunk> *pool; Pool<Memchunk> *pool;
Memchunk *head, *tail; Memchunk *head, *tail;
size_t len; size_t len;
}; };
// Wrapper around Memchunks to offer "peeking" functionality.
template <typename Memchunk> struct PeekMemchunks {
PeekMemchunks(Pool<Memchunk> *pool)
: memchunks(pool), cur(nullptr), cur_pos(nullptr), cur_last(nullptr),
len(0), peeking(true) {}
PeekMemchunks(const PeekMemchunks &) = delete;
PeekMemchunks(PeekMemchunks &&other)
: memchunks(std::move(other.memchunks)), cur(other.cur),
cur_pos(other.cur_pos), cur_last(other.cur_last), len(other.len),
peeking(other.peeking) {
other.reset();
}
PeekMemchunks &operator=(const PeekMemchunks &) = delete;
PeekMemchunks &operator=(PeekMemchunks &&other) {
if (this == &other) {
return *this;
}
memchunks = std::move(other.memchunks);
cur = other.cur;
cur_pos = other.cur_pos;
cur_last = other.cur_last;
len = other.len;
peeking = other.peeking;
other.reset();
return *this;
}
size_t append(const void *src, size_t count) {
count = memchunks.append(src, count);
len += count;
return count;
}
size_t remove(void *dest, size_t count) {
if (!peeking) {
count = memchunks.remove(dest, count);
len -= count;
return count;
}
if (count == 0 || len == 0) {
return 0;
}
if (!cur) {
cur = memchunks.head;
cur_pos = cur->pos;
}
// cur_last could be updated in append
cur_last = cur->last;
if (cur_pos == cur_last) {
assert(cur->next);
cur = cur->next;
}
auto first = static_cast<uint8_t *>(dest);
auto last = first + count;
for (;;) {
auto n = std::min(last - first, cur_last - cur_pos);
first = std::copy_n(cur_pos, n, first);
cur_pos += n;
len -= n;
if (first == last) {
break;
}
assert(cur_pos == cur_last);
if (!cur->next) {
break;
}
cur = cur->next;
cur_pos = cur->pos;
cur_last = cur->last;
}
return first - static_cast<uint8_t *>(dest);
}
size_t rleft() const { return len; }
size_t rleft_buffered() const { return memchunks.rleft(); }
void disable_peek(bool drain) {
if (!peeking) {
return;
}
if (drain) {
auto n = rleft_buffered() - rleft();
memchunks.drain(n);
assert(len == memchunks.rleft());
} else {
len = memchunks.rleft();
}
cur = nullptr;
cur_pos = cur_last = nullptr;
peeking = false;
}
void reset() {
memchunks.reset();
cur = nullptr;
cur_pos = cur_last = nullptr;
len = 0;
peeking = true;
}
Memchunks<Memchunk> memchunks;
// Pointer to the Memchunk currently we are reading/writing.
Memchunk *cur;
// Region inside cur, we have processed to cur_pos.
uint8_t *cur_pos, *cur_last;
// This is the length we have left unprocessed. len <=
// memchunk.rleft() must hold.
size_t len;
// true if peeking is enabled. Initially it is true.
bool peeking;
};
using Memchunk16K = Memchunk<16_k>; using Memchunk16K = Memchunk<16_k>;
using MemchunkPool = Pool<Memchunk16K>; using MemchunkPool = Pool<Memchunk16K>;
using DefaultMemchunks = Memchunks<Memchunk16K>; using DefaultMemchunks = Memchunks<Memchunk16K>;
using DefaultPeekMemchunks = PeekMemchunks<Memchunk16K>;
#define DEFAULT_WR_IOVCNT 16 #define DEFAULT_WR_IOVCNT 16

View File

@ -84,6 +84,7 @@ void test_pool_recycle(void) {
using Memchunk16 = Memchunk<16>; using Memchunk16 = Memchunk<16>;
using MemchunkPool16 = Pool<Memchunk16>; using MemchunkPool16 = Pool<Memchunk16>;
using Memchunks16 = Memchunks<Memchunk16>; using Memchunks16 = Memchunks<Memchunk16>;
using PeekMemchunks16 = PeekMemchunks<Memchunk16>;
void test_memchunks_append(void) { void test_memchunks_append(void) {
MemchunkPool16 pool; MemchunkPool16 pool;
@ -196,4 +197,144 @@ void test_memchunks_recycle(void) {
CU_ASSERT(nullptr == m->next); CU_ASSERT(nullptr == m->next);
} }
void test_memchunks_reset(void) {
MemchunkPool16 pool;
Memchunks16 chunks(&pool);
std::array<uint8_t, 32> b{};
chunks.append(b.data(), b.size());
CU_ASSERT(32 == chunks.rleft());
chunks.reset();
CU_ASSERT(0 == chunks.rleft());
CU_ASSERT(nullptr == chunks.head);
CU_ASSERT(nullptr == chunks.tail);
auto m = pool.freelist;
CU_ASSERT(nullptr != m);
CU_ASSERT(nullptr != m->next);
CU_ASSERT(nullptr == m->next->next);
}
void test_peek_memchunks_append(void) {
MemchunkPool16 pool;
PeekMemchunks16 pchunks(&pool);
std::array<uint8_t, 32> b{{
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4',
'5', '6', '7', '8', '9', '0', '1',
}},
d;
pchunks.append(b.data(), b.size());
CU_ASSERT(32 == pchunks.rleft());
CU_ASSERT(32 == pchunks.rleft_buffered());
CU_ASSERT(0 == pchunks.remove(nullptr, 0));
CU_ASSERT(32 == pchunks.rleft());
CU_ASSERT(32 == pchunks.rleft_buffered());
CU_ASSERT(12 == pchunks.remove(d.data(), 12));
CU_ASSERT(std::equal(std::begin(b), std::begin(b) + 12, std::begin(d)));
CU_ASSERT(20 == pchunks.rleft());
CU_ASSERT(32 == pchunks.rleft_buffered());
CU_ASSERT(20 == pchunks.remove(d.data(), d.size()));
CU_ASSERT(std::equal(std::begin(b) + 12, std::end(b), std::begin(d)));
CU_ASSERT(0 == pchunks.rleft());
CU_ASSERT(32 == pchunks.rleft_buffered());
}
void test_peek_memchunks_disable_peek_drain(void) {
MemchunkPool16 pool;
PeekMemchunks16 pchunks(&pool);
std::array<uint8_t, 32> b{{
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4',
'5', '6', '7', '8', '9', '0', '1',
}},
d;
pchunks.append(b.data(), b.size());
CU_ASSERT(12 == pchunks.remove(d.data(), 12));
pchunks.disable_peek(true);
CU_ASSERT(!pchunks.peeking);
CU_ASSERT(20 == pchunks.rleft());
CU_ASSERT(20 == pchunks.rleft_buffered());
CU_ASSERT(20 == pchunks.remove(d.data(), d.size()));
CU_ASSERT(std::equal(std::begin(b) + 12, std::end(b), std::begin(d)));
CU_ASSERT(0 == pchunks.rleft());
CU_ASSERT(0 == pchunks.rleft_buffered());
}
void test_peek_memchunks_disable_peek_no_drain(void) {
MemchunkPool16 pool;
PeekMemchunks16 pchunks(&pool);
std::array<uint8_t, 32> b{{
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4',
'5', '6', '7', '8', '9', '0', '1',
}},
d;
pchunks.append(b.data(), b.size());
CU_ASSERT(12 == pchunks.remove(d.data(), 12));
pchunks.disable_peek(false);
CU_ASSERT(!pchunks.peeking);
CU_ASSERT(32 == pchunks.rleft());
CU_ASSERT(32 == pchunks.rleft_buffered());
CU_ASSERT(32 == pchunks.remove(d.data(), d.size()));
CU_ASSERT(std::equal(std::begin(b), std::end(b), std::begin(d)));
CU_ASSERT(0 == pchunks.rleft());
CU_ASSERT(0 == pchunks.rleft_buffered());
}
void test_peek_memchunks_reset(void) {
MemchunkPool16 pool;
PeekMemchunks16 pchunks(&pool);
std::array<uint8_t, 32> b{{
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4',
'5', '6', '7', '8', '9', '0', '1',
}},
d;
pchunks.append(b.data(), b.size());
CU_ASSERT(12 == pchunks.remove(d.data(), 12));
pchunks.disable_peek(true);
pchunks.reset();
CU_ASSERT(0 == pchunks.rleft());
CU_ASSERT(0 == pchunks.rleft_buffered());
CU_ASSERT(nullptr == pchunks.cur);
CU_ASSERT(nullptr == pchunks.cur_pos);
CU_ASSERT(nullptr == pchunks.cur_last);
CU_ASSERT(pchunks.peeking);
}
} // namespace nghttp2 } // namespace nghttp2

View File

@ -36,6 +36,11 @@ void test_memchunks_append(void);
void test_memchunks_drain(void); void test_memchunks_drain(void);
void test_memchunks_riovec(void); void test_memchunks_riovec(void);
void test_memchunks_recycle(void); void test_memchunks_recycle(void);
void test_memchunks_reset(void);
void test_peek_memchunks_append(void);
void test_peek_memchunks_disable_peek_drain(void);
void test_peek_memchunks_disable_peek_no_drain(void);
void test_peek_memchunks_reset(void);
} // namespace nghttp2 } // namespace nghttp2

View File

@ -172,7 +172,16 @@ int main(int argc, char *argv[]) {
!CU_add_test(pSuite, "memchunk_drain", nghttp2::test_memchunks_drain) || !CU_add_test(pSuite, "memchunk_drain", nghttp2::test_memchunks_drain) ||
!CU_add_test(pSuite, "memchunk_riovec", nghttp2::test_memchunks_riovec) || !CU_add_test(pSuite, "memchunk_riovec", nghttp2::test_memchunks_riovec) ||
!CU_add_test(pSuite, "memchunk_recycle", !CU_add_test(pSuite, "memchunk_recycle",
nghttp2::test_memchunks_recycle)) { nghttp2::test_memchunks_recycle) ||
!CU_add_test(pSuite, "memchunk_reset", nghttp2::test_memchunks_reset) ||
!CU_add_test(pSuite, "peek_memchunk_append",
nghttp2::test_peek_memchunks_append) ||
!CU_add_test(pSuite, "peek_memchunk_disable_peek_drain",
nghttp2::test_peek_memchunks_disable_peek_drain) ||
!CU_add_test(pSuite, "peek_memchunk_disable_peek_no_drain",
nghttp2::test_peek_memchunks_disable_peek_no_drain) ||
!CU_add_test(pSuite, "peek_memchunk_reset",
nghttp2::test_peek_memchunks_reset)) {
CU_cleanup_registry(); CU_cleanup_registry();
return CU_get_error(); return CU_get_error();
} }

View File

@ -358,7 +358,8 @@ int ClientHandler::upstream_http1_connhd_read() {
ClientHandler::ClientHandler(Worker *worker, int fd, SSL *ssl, ClientHandler::ClientHandler(Worker *worker, int fd, SSL *ssl,
const char *ipaddr, const char *port) const char *ipaddr, const char *port)
: conn_(worker->get_loop(), fd, ssl, get_config()->upstream_write_timeout, : conn_(worker->get_loop(), fd, ssl, worker->get_mcpool(),
get_config()->upstream_write_timeout,
get_config()->upstream_read_timeout, get_config()->write_rate, get_config()->upstream_read_timeout, get_config()->write_rate,
get_config()->write_burst, get_config()->read_rate, get_config()->write_burst, get_config()->read_rate,
get_config()->read_burst, writecb, readcb, timeoutcb, this), get_config()->read_burst, writecb, readcb, timeoutcb, this),

View File

@ -40,12 +40,13 @@ using namespace nghttp2;
namespace shrpx { namespace shrpx {
Connection::Connection(struct ev_loop *loop, int fd, SSL *ssl, Connection::Connection(struct ev_loop *loop, int fd, SSL *ssl,
ev_tstamp write_timeout, ev_tstamp read_timeout, MemchunkPool *mcpool, ev_tstamp write_timeout,
size_t write_rate, size_t write_burst, size_t read_rate, ev_tstamp read_timeout, size_t write_rate,
size_t read_burst, IOCb writecb, IOCb readcb, size_t write_burst, size_t read_rate, size_t read_burst,
TimerCb timeoutcb, void *data) IOCb writecb, IOCb readcb, TimerCb timeoutcb, void *data)
: tls{}, wlimit(loop, &wev, write_rate, write_burst), : tls{DefaultMemchunks(mcpool), DefaultPeekMemchunks(mcpool)},
rlimit(loop, &rev, read_rate, read_burst, ssl), writecb(writecb), wlimit(loop, &wev, write_rate, write_burst),
rlimit(loop, &rev, read_rate, read_burst, this), writecb(writecb),
readcb(readcb), timeoutcb(timeoutcb), loop(loop), data(data), fd(fd) { readcb(readcb), timeoutcb(timeoutcb), loop(loop), data(data), fd(fd) {
ev_io_init(&wev, writecb, fd, EV_WRITE); ev_io_init(&wev, writecb, fd, EV_WRITE);
@ -83,10 +84,12 @@ void Connection::disconnect() {
if (tls.cached_session) { if (tls.cached_session) {
SSL_SESSION_free(tls.cached_session); SSL_SESSION_free(tls.cached_session);
tls.cached_session = nullptr;
} }
if (tls.cached_session_lookup_req) { if (tls.cached_session_lookup_req) {
tls.cached_session_lookup_req->canceled = true; tls.cached_session_lookup_req->canceled = true;
tls.cached_session_lookup_req = nullptr;
} }
// To reuse SSL/TLS session, we have to shutdown, and don't free // To reuse SSL/TLS session, we have to shutdown, and don't free
@ -96,7 +99,15 @@ void Connection::disconnect() {
tls.ssl = nullptr; tls.ssl = nullptr;
} }
tls = {tls.ssl}; tls.wbuf.reset();
tls.rbuf.reset();
tls.last_write_idle = 0.;
tls.warmup_writelen = 0;
tls.last_writelen = 0;
tls.last_readlen = 0;
tls.handshake_state = 0;
tls.initial_handshake_done = false;
tls.reneg_started = false;
} }
if (fd != -1) { if (fd != -1) {
@ -114,22 +125,9 @@ void Connection::disconnect() {
wlimit.stopw(); wlimit.stopw();
} }
namespace { void Connection::prepare_client_handshake() { SSL_set_connect_state(tls.ssl); }
void allocate_buffer(Connection *conn) {
conn->tls.rb = make_unique<Buffer<16_k>>();
conn->tls.wb = make_unique<Buffer<16_k>>();
}
} // namespace
void Connection::prepare_client_handshake() { void Connection::prepare_server_handshake() { SSL_set_accept_state(tls.ssl); }
SSL_set_connect_state(tls.ssl);
allocate_buffer(this);
}
void Connection::prepare_server_handshake() {
SSL_set_accept_state(tls.ssl);
allocate_buffer(this);
}
// BIO implementation is inspired by openldap implementation: // BIO implementation is inspired by openldap implementation:
// http://www.openldap.org/devel/cvsweb.cgi/~checkout~/libraries/libldap/tls_o.c // http://www.openldap.org/devel/cvsweb.cgi/~checkout~/libraries/libldap/tls_o.c
@ -140,27 +138,26 @@ int shrpx_bio_write(BIO *b, const char *buf, int len) {
} }
auto conn = static_cast<Connection *>(b->ptr); auto conn = static_cast<Connection *>(b->ptr);
auto &wb = conn->tls.wb; auto &wbuf = conn->tls.wbuf;
BIO_clear_retry_flags(b); BIO_clear_retry_flags(b);
if (conn->tls.initial_handshake_done) { if (conn->tls.initial_handshake_done) {
// After handshake finished, send |buf| of length |len| to the // After handshake finished, send |buf| of length |len| to the
// socket directly. // socket directly.
if (wb && wb->rleft()) { if (wbuf.rleft()) {
auto nwrite = conn->write_clear(wb->pos, wb->rleft()); std::array<struct iovec, 4> iov;
auto iovcnt = wbuf.riovec(iov.data(), iov.size());
auto nwrite = conn->writev_clear(iov.data(), iovcnt);
if (nwrite < 0) { if (nwrite < 0) {
return -1; return -1;
} }
wb->drain(nwrite); wbuf.drain(nwrite);
if (wb->rleft()) { if (wbuf.rleft()) {
BIO_set_retry_write(b); BIO_set_retry_write(b);
return -1; return -1;
} }
// Here delete TLS write buffer
wb.reset();
} }
auto nwrite = conn->write_clear(buf, len); auto nwrite = conn->write_clear(buf, len);
if (nwrite < 0) { if (nwrite < 0) {
@ -175,16 +172,9 @@ int shrpx_bio_write(BIO *b, const char *buf, int len) {
return nwrite; return nwrite;
} }
auto nwrite = std::min(static_cast<size_t>(len), wb->wleft()); wbuf.append(buf, len);
if (nwrite == 0) { return len;
BIO_set_retry_write(b);
return -1;
}
wb->write(buf, nwrite);
return nwrite;
} }
} // namespace } // namespace
@ -195,11 +185,11 @@ int shrpx_bio_read(BIO *b, char *buf, int len) {
} }
auto conn = static_cast<Connection *>(b->ptr); auto conn = static_cast<Connection *>(b->ptr);
auto &rb = conn->tls.rb; auto &rbuf = conn->tls.rbuf;
BIO_clear_retry_flags(b); BIO_clear_retry_flags(b);
if (conn->tls.initial_handshake_done && !rb) { if (conn->tls.initial_handshake_done && rbuf.rleft() == 0) {
auto nread = conn->read_clear(buf, len); auto nread = conn->read_clear(buf, len);
if (nread < 0) { if (nread < 0) {
return -1; return -1;
@ -211,22 +201,12 @@ int shrpx_bio_read(BIO *b, char *buf, int len) {
return nread; return nread;
} }
auto nread = std::min(static_cast<size_t>(len), rb->rleft()); if (rbuf.rleft() == 0) {
if (nread == 0) {
if (conn->tls.initial_handshake_done) {
rb.reset();
}
BIO_set_retry_read(b); BIO_set_retry_read(b);
return -1; return -1;
} }
std::copy_n(rb->pos, nread, buf); return rbuf.remove(buf, len);
rb->drain(nread);
return nread;
} }
} // namespace } // namespace
@ -289,51 +269,47 @@ void Connection::set_ssl(SSL *ssl) {
bio->ptr = this; bio->ptr = this;
SSL_set_bio(tls.ssl, bio, bio); SSL_set_bio(tls.ssl, bio, bio);
SSL_set_app_data(tls.ssl, this); SSL_set_app_data(tls.ssl, this);
rlimit.set_ssl(tls.ssl);
} }
namespace {
// We should buffer at least full encrypted TLS record here.
// Theoretically, peer can send client hello in several TLS records,
// which could exeed this limit, but it is not portable, and we don't
// have to handle such exotic behaviour.
bool read_buffer_full(DefaultPeekMemchunks &rbuf) {
return rbuf.rleft_buffered() >= 20_k;
}
} // namespace
int Connection::tls_handshake() { int Connection::tls_handshake() {
wlimit.stopw(); wlimit.stopw();
ev_timer_stop(loop, &wt); ev_timer_stop(loop, &wt);
auto nread = read_clear(tls.rb->last, tls.rb->wleft()); if (ev_is_active(&rev)) {
if (nread < 0) { std::array<uint8_t, 8_k> buf;
if (LOG_ENABLED(INFO)) { auto nread = read_clear(buf.data(), buf.size());
LOG(INFO) << "tls: handshake read error"; if (nread < 0) {
}
return -1;
}
tls.rb->write(nread);
// We have limited space for read buffer, so stop reading if it
// filled up.
if (tls.rb->wleft() == 0) {
if (tls.handshake_state != TLS_CONN_WRITE_STARTED) {
// Reading 16KiB before writing server hello is unlikely for
// ordinary client.
if (LOG_ENABLED(INFO)) { if (LOG_ENABLED(INFO)) {
LOG(INFO) << "tls: client hello is too large"; LOG(INFO) << "tls: handshake read error";
} }
return -1; return -1;
} }
tls.rbuf.append(buf.data(), nread);
rlimit.stopw(); if (read_buffer_full(tls.rbuf)) {
ev_timer_stop(loop, &rt); rlimit.stopw();
}
} }
switch (tls.handshake_state) { switch (tls.handshake_state) {
case TLS_CONN_WAIT_FOR_SESSION_CACHE: case TLS_CONN_WAIT_FOR_SESSION_CACHE:
if (nread > 0) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "tls: client sent addtional data after client hello";
}
return -1;
}
return SHRPX_ERR_INPROGRESS; return SHRPX_ERR_INPROGRESS;
case TLS_CONN_GOT_SESSION_CACHE: { case TLS_CONN_GOT_SESSION_CACHE: {
// Use the same trick invented by @kazuho in h2o project // Use the same trick invented by @kazuho in h2o project.
tls.wb->reset();
tls.rb->pos = tls.rb->begin(); // Discard all outgoing data.
tls.wbuf.reset();
// Rewind buffered incoming data to replay client hello.
tls.rbuf.disable_peek(false);
auto ssl_ctx = SSL_get_SSL_CTX(tls.ssl); auto ssl_ctx = SSL_get_SSL_CTX(tls.ssl);
auto ssl_opts = SSL_get_options(tls.ssl); auto ssl_opts = SSL_get_options(tls.ssl);
@ -382,32 +358,33 @@ int Connection::tls_handshake() {
return SHRPX_ERR_INPROGRESS; return SHRPX_ERR_INPROGRESS;
} }
if (tls.wb->rleft()) { if (tls.wbuf.rleft()) {
// First write indicates that resumption stuff has done. // First write indicates that resumption stuff has done.
tls.handshake_state = TLS_CONN_WRITE_STARTED; if (tls.handshake_state != TLS_CONN_WRITE_STARTED) {
auto nwrite = write_clear(tls.wb->pos, tls.wb->rleft()); tls.handshake_state = TLS_CONN_WRITE_STARTED;
// If peek has already disabled, this is noop.
tls.rbuf.disable_peek(true);
}
std::array<struct iovec, 4> iov;
auto iovcnt = tls.wbuf.riovec(iov.data(), iov.size());
auto nwrite = writev_clear(iov.data(), iovcnt);
if (nwrite < 0) { if (nwrite < 0) {
if (LOG_ENABLED(INFO)) { if (LOG_ENABLED(INFO)) {
LOG(INFO) << "tls: handshake write error"; LOG(INFO) << "tls: handshake write error";
} }
return -1; return -1;
} }
tls.wb->drain(nwrite); tls.wbuf.drain(nwrite);
if (tls.wbuf.rleft()) {
wlimit.startw();
ev_timer_again(loop, &wt);
}
} }
if (tls.wb->rleft()) { if (!read_buffer_full(tls.rbuf)) {
wlimit.startw();
ev_timer_again(loop, &wt);
} else {
tls.wb->reset();
}
if (tls.handshake_state == TLS_CONN_WRITE_STARTED && tls.rb->rleft() == 0) {
tls.rb->reset();
// We may have stopped reading // We may have stopped reading
rlimit.startw(); rlimit.startw();
ev_timer_again(loop, &rt);
} }
if (rv != 1) { if (rv != 1) {
@ -419,13 +396,14 @@ int Connection::tls_handshake() {
tls.initial_handshake_done = true; tls.initial_handshake_done = true;
if (tls.rb->rleft()) { // We have to start read watcher, since later stage of code expects
ev_feed_event(loop, &rev, EV_READ); // this.
}
// We may have stopped reading
rlimit.startw(); rlimit.startw();
ev_timer_again(loop, &rt);
// We may have whole request in tls.rbuf. This means that we don't
// get notified further read event. This is especially true for
// HTTP/1.1.
handle_tls_pending_read();
if (LOG_ENABLED(INFO)) { if (LOG_ENABLED(INFO)) {
LOG(INFO) << "SSL/TLS handshake completed"; LOG(INFO) << "SSL/TLS handshake completed";
@ -506,8 +484,8 @@ ssize_t Connection::write_tls(const void *data, size_t len) {
return SHRPX_ERR_NETWORK; return SHRPX_ERR_NETWORK;
case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_WRITE:
tls.last_writelen = len; tls.last_writelen = len;
wlimit.startw(); // starting write watcher and timer is done in write_clear via
ev_timer_again(loop, &wt); // bio.
return 0; return 0;
default: default:
if (LOG_ENABLED(INFO)) { if (LOG_ENABLED(INFO)) {

View File

@ -35,7 +35,7 @@
#include "shrpx_rate_limit.h" #include "shrpx_rate_limit.h"
#include "shrpx_error.h" #include "shrpx_error.h"
#include "buffer.h" #include "memchunk.h"
namespace shrpx { namespace shrpx {
@ -50,6 +50,8 @@ enum {
}; };
struct TLSConnection { struct TLSConnection {
DefaultMemchunks wbuf;
DefaultPeekMemchunks rbuf;
SSL *ssl; SSL *ssl;
SSL_SESSION *cached_session; SSL_SESSION *cached_session;
MemcachedRequest *cached_session_lookup_req; MemcachedRequest *cached_session_lookup_req;
@ -62,8 +64,6 @@ struct TLSConnection {
int handshake_state; int handshake_state;
bool initial_handshake_done; bool initial_handshake_done;
bool reneg_started; bool reneg_started;
std::unique_ptr<Buffer<16_k>> rb;
std::unique_ptr<Buffer<16_k>> wb;
}; };
template <typename T> using EVCb = void (*)(struct ev_loop *, T *, int); template <typename T> using EVCb = void (*)(struct ev_loop *, T *, int);
@ -72,10 +72,10 @@ using IOCb = EVCb<ev_io>;
using TimerCb = EVCb<ev_timer>; using TimerCb = EVCb<ev_timer>;
struct Connection { struct Connection {
Connection(struct ev_loop *loop, int fd, SSL *ssl, ev_tstamp write_timeout, Connection(struct ev_loop *loop, int fd, SSL *ssl, MemchunkPool *mcpool,
ev_tstamp read_timeout, size_t write_rate, size_t write_burst, ev_tstamp write_timeout, ev_tstamp read_timeout, size_t write_rate,
size_t read_rate, size_t read_burst, IOCb writecb, IOCb readcb, size_t write_burst, size_t read_rate, size_t read_burst,
TimerCb timeoutcb, void *data); IOCb writecb, IOCb readcb, TimerCb timeoutcb, void *data);
~Connection(); ~Connection();
void disconnect(); void disconnect();

View File

@ -144,7 +144,8 @@ void writecb(struct ev_loop *loop, ev_io *w, int revents) {
Http2Session::Http2Session(struct ev_loop *loop, SSL_CTX *ssl_ctx, Http2Session::Http2Session(struct ev_loop *loop, SSL_CTX *ssl_ctx,
ConnectBlocker *connect_blocker, Worker *worker, ConnectBlocker *connect_blocker, Worker *worker,
size_t group, size_t idx) size_t group, size_t idx)
: conn_(loop, -1, nullptr, get_config()->downstream_write_timeout, : conn_(loop, -1, nullptr, worker->get_mcpool(),
get_config()->downstream_write_timeout,
get_config()->downstream_read_timeout, 0, 0, 0, 0, writecb, readcb, get_config()->downstream_read_timeout, 0, 0, 0, 0, writecb, readcb,
timeoutcb, this), timeoutcb, this),
worker_(worker), connect_blocker_(connect_blocker), ssl_ctx_(ssl_ctx), worker_(worker), connect_blocker_(connect_blocker), ssl_ctx_(ssl_ctx),

View File

@ -112,7 +112,7 @@ void connectcb(struct ev_loop *loop, ev_io *w, int revents) {
HttpDownstreamConnection::HttpDownstreamConnection( HttpDownstreamConnection::HttpDownstreamConnection(
DownstreamConnectionPool *dconn_pool, size_t group, struct ev_loop *loop) DownstreamConnectionPool *dconn_pool, size_t group, struct ev_loop *loop)
: DownstreamConnection(dconn_pool), : DownstreamConnection(dconn_pool),
conn_(loop, -1, nullptr, get_config()->downstream_write_timeout, conn_(loop, -1, nullptr, nullptr, get_config()->downstream_write_timeout,
get_config()->downstream_read_timeout, 0, 0, 0, 0, connectcb, get_config()->downstream_read_timeout, 0, 0, 0, 0, connectcb,
readcb, timeoutcb, this), readcb, timeoutcb, this),
ioctrl_(&conn_.rlimit), response_htp_{0}, group_(group), addr_idx_(0), ioctrl_(&conn_.rlimit), response_htp_{0}, group_(group), addr_idx_(0),

View File

@ -92,7 +92,7 @@ constexpr ev_tstamp read_timeout = 10.;
MemcachedConnection::MemcachedConnection(const Address *addr, MemcachedConnection::MemcachedConnection(const Address *addr,
struct ev_loop *loop) struct ev_loop *loop)
: conn_(loop, -1, nullptr, write_timeout, read_timeout, 0, 0, 0, 0, : conn_(loop, -1, nullptr, nullptr, write_timeout, read_timeout, 0, 0, 0, 0,
connectcb, readcb, timeoutcb, this), connectcb, readcb, timeoutcb, this),
parse_state_{}, addr_(addr), sendsum_(0), connected_(false) {} parse_state_{}, addr_(addr), sendsum_(0), connected_(false) {}
@ -403,6 +403,7 @@ int MemcachedConnection::parse_packet() {
return 0; return 0;
} }
#undef DEFAULT_WR_IOVCNT
#define DEFAULT_WR_IOVCNT 128 #define DEFAULT_WR_IOVCNT 128
#if defined(IOV_MAX) && IOV_MAX < DEFAULT_WR_IOVCNT #if defined(IOV_MAX) && IOV_MAX < DEFAULT_WR_IOVCNT

View File

@ -26,6 +26,8 @@
#include <limits> #include <limits>
#include "shrpx_connection.h"
namespace shrpx { namespace shrpx {
namespace { namespace {
@ -36,9 +38,9 @@ void regencb(struct ev_loop *loop, ev_timer *w, int revents) {
} // namespace } // namespace
RateLimit::RateLimit(struct ev_loop *loop, ev_io *w, size_t rate, size_t burst, RateLimit::RateLimit(struct ev_loop *loop, ev_io *w, size_t rate, size_t burst,
SSL *ssl) Connection *conn)
: w_(w), loop_(loop), ssl_(ssl), rate_(rate), burst_(burst), avail_(burst), : w_(w), loop_(loop), conn_(conn), rate_(rate), burst_(burst),
startw_req_(false) { avail_(burst), startw_req_(false) {
ev_timer_init(&t_, regencb, 0., 1.); ev_timer_init(&t_, regencb, 0., 1.);
t_.data = this; t_.data = this;
if (rate_ > 0) { if (rate_ > 0) {
@ -97,7 +99,8 @@ void RateLimit::stopw() {
} }
void RateLimit::handle_tls_pending_read() { void RateLimit::handle_tls_pending_read() {
if (!ssl_ || SSL_pending(ssl_) == 0) { if (!conn_ || !conn_->tls.ssl ||
(SSL_pending(conn_->tls.ssl) == 0 && conn_->tls.rbuf.rleft() == 0)) {
return; return;
} }
@ -106,6 +109,4 @@ void RateLimit::handle_tls_pending_read() {
ev_feed_event(loop_, w_, EV_READ); ev_feed_event(loop_, w_, EV_READ);
} }
void RateLimit::set_ssl(SSL *ssl) { ssl_ = ssl; }
} // namespace shrpx } // namespace shrpx

View File

@ -33,28 +33,30 @@
namespace shrpx { namespace shrpx {
struct Connection;
class RateLimit { class RateLimit {
public: public:
// We need |ssl| object to check that it has unread decrypted bytes. // We need |conn| object to check that it has unread bytes for TLS
// connection.
RateLimit(struct ev_loop *loop, ev_io *w, size_t rate, size_t burst, RateLimit(struct ev_loop *loop, ev_io *w, size_t rate, size_t burst,
SSL *ssl = nullptr); Connection *conn = nullptr);
~RateLimit(); ~RateLimit();
size_t avail() const; size_t avail() const;
void drain(size_t n); void drain(size_t n);
void regen(); void regen();
void startw(); void startw();
void stopw(); void stopw();
// Feeds event if ssl_ object has unread decrypted bytes. This is // Feeds event if conn_->tls object has unread bytes. This is
// required since it is buffered in ssl_ object, io event is not // required since it is buffered in conn_->tls object, io event is
// generated unless new incoming data is received. // not generated unless new incoming data is received.
void handle_tls_pending_read(); void handle_tls_pending_read();
void set_ssl(SSL *ssl);
private: private:
ev_timer t_; ev_timer t_;
ev_io *w_; ev_io *w_;
struct ev_loop *loop_; struct ev_loop *loop_;
SSL *ssl_; Connection *conn_;
size_t rate_; size_t rate_;
size_t burst_; size_t burst_;
size_t avail_; size_t avail_;