nghttpx: Rewrite TLS async handshake using memchunk buffers
This commit is contained in:
parent
72c661f1dd
commit
e91a576179
152
src/memchunk.h
152
src/memchunk.h
|
@ -117,6 +117,31 @@ template <typename T> struct Pool {
|
|||
template <typename Memchunk> struct Memchunks {
|
||||
Memchunks(Pool<Memchunk> *pool)
|
||||
: pool(pool), head(nullptr), tail(nullptr), len(0) {}
|
||||
Memchunks(const Memchunks &) = delete;
|
||||
Memchunks(Memchunks &&other)
|
||||
: pool(other.pool), head(other.head), tail(other.head), len(other.len) {
|
||||
// keep other.pool
|
||||
other.head = other.tail = nullptr;
|
||||
other.len = 0;
|
||||
}
|
||||
Memchunks &operator=(const Memchunks &) = delete;
|
||||
Memchunks &operator=(Memchunks &&other) {
|
||||
if (this == &other) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
reset();
|
||||
|
||||
pool = other.pool;
|
||||
head = other.head;
|
||||
tail = other.tail;
|
||||
len = other.len;
|
||||
|
||||
other.head = other.tail = nullptr;
|
||||
other.len = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
~Memchunks() {
|
||||
if (!pool) {
|
||||
return;
|
||||
|
@ -223,15 +248,142 @@ template <typename Memchunk> struct Memchunks {
|
|||
return i;
|
||||
}
|
||||
size_t rleft() const { return len; }
|
||||
void reset() {
|
||||
for (auto m = head; m;) {
|
||||
auto next = m->next;
|
||||
pool->recycle(m);
|
||||
m = next;
|
||||
}
|
||||
len = 0;
|
||||
head = tail = nullptr;
|
||||
}
|
||||
|
||||
Pool<Memchunk> *pool;
|
||||
Memchunk *head, *tail;
|
||||
size_t len;
|
||||
};
|
||||
|
||||
// Wrapper around Memchunks to offer "peeking" functionality.
|
||||
template <typename Memchunk> struct PeekMemchunks {
|
||||
PeekMemchunks(Pool<Memchunk> *pool)
|
||||
: memchunks(pool), cur(nullptr), cur_pos(nullptr), cur_last(nullptr),
|
||||
len(0), peeking(true) {}
|
||||
PeekMemchunks(const PeekMemchunks &) = delete;
|
||||
PeekMemchunks(PeekMemchunks &&other)
|
||||
: memchunks(std::move(other.memchunks)), cur(other.cur),
|
||||
cur_pos(other.cur_pos), cur_last(other.cur_last), len(other.len),
|
||||
peeking(other.peeking) {
|
||||
other.reset();
|
||||
}
|
||||
PeekMemchunks &operator=(const PeekMemchunks &) = delete;
|
||||
PeekMemchunks &operator=(PeekMemchunks &&other) {
|
||||
if (this == &other) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
memchunks = std::move(other.memchunks);
|
||||
cur = other.cur;
|
||||
cur_pos = other.cur_pos;
|
||||
cur_last = other.cur_last;
|
||||
len = other.len;
|
||||
peeking = other.peeking;
|
||||
|
||||
other.reset();
|
||||
|
||||
return *this;
|
||||
}
|
||||
size_t append(const void *src, size_t count) {
|
||||
count = memchunks.append(src, count);
|
||||
len += count;
|
||||
return count;
|
||||
}
|
||||
size_t remove(void *dest, size_t count) {
|
||||
if (!peeking) {
|
||||
count = memchunks.remove(dest, count);
|
||||
len -= count;
|
||||
return count;
|
||||
}
|
||||
|
||||
if (count == 0 || len == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!cur) {
|
||||
cur = memchunks.head;
|
||||
cur_pos = cur->pos;
|
||||
}
|
||||
|
||||
// cur_last could be updated in append
|
||||
cur_last = cur->last;
|
||||
|
||||
if (cur_pos == cur_last) {
|
||||
assert(cur->next);
|
||||
cur = cur->next;
|
||||
}
|
||||
|
||||
auto first = static_cast<uint8_t *>(dest);
|
||||
auto last = first + count;
|
||||
|
||||
for (;;) {
|
||||
auto n = std::min(last - first, cur_last - cur_pos);
|
||||
|
||||
first = std::copy_n(cur_pos, n, first);
|
||||
cur_pos += n;
|
||||
len -= n;
|
||||
|
||||
if (first == last) {
|
||||
break;
|
||||
}
|
||||
assert(cur_pos == cur_last);
|
||||
if (!cur->next) {
|
||||
break;
|
||||
}
|
||||
cur = cur->next;
|
||||
cur_pos = cur->pos;
|
||||
cur_last = cur->last;
|
||||
}
|
||||
return first - static_cast<uint8_t *>(dest);
|
||||
}
|
||||
size_t rleft() const { return len; }
|
||||
size_t rleft_buffered() const { return memchunks.rleft(); }
|
||||
void disable_peek(bool drain) {
|
||||
if (!peeking) {
|
||||
return;
|
||||
}
|
||||
if (drain) {
|
||||
auto n = rleft_buffered() - rleft();
|
||||
memchunks.drain(n);
|
||||
assert(len == memchunks.rleft());
|
||||
} else {
|
||||
len = memchunks.rleft();
|
||||
}
|
||||
cur = nullptr;
|
||||
cur_pos = cur_last = nullptr;
|
||||
peeking = false;
|
||||
}
|
||||
void reset() {
|
||||
memchunks.reset();
|
||||
cur = nullptr;
|
||||
cur_pos = cur_last = nullptr;
|
||||
len = 0;
|
||||
peeking = true;
|
||||
}
|
||||
Memchunks<Memchunk> memchunks;
|
||||
// Pointer to the Memchunk currently we are reading/writing.
|
||||
Memchunk *cur;
|
||||
// Region inside cur, we have processed to cur_pos.
|
||||
uint8_t *cur_pos, *cur_last;
|
||||
// This is the length we have left unprocessed. len <=
|
||||
// memchunk.rleft() must hold.
|
||||
size_t len;
|
||||
// true if peeking is enabled. Initially it is true.
|
||||
bool peeking;
|
||||
};
|
||||
|
||||
using Memchunk16K = Memchunk<16_k>;
|
||||
using MemchunkPool = Pool<Memchunk16K>;
|
||||
using DefaultMemchunks = Memchunks<Memchunk16K>;
|
||||
using DefaultPeekMemchunks = PeekMemchunks<Memchunk16K>;
|
||||
|
||||
#define DEFAULT_WR_IOVCNT 16
|
||||
|
||||
|
|
|
@ -84,6 +84,7 @@ void test_pool_recycle(void) {
|
|||
using Memchunk16 = Memchunk<16>;
|
||||
using MemchunkPool16 = Pool<Memchunk16>;
|
||||
using Memchunks16 = Memchunks<Memchunk16>;
|
||||
using PeekMemchunks16 = PeekMemchunks<Memchunk16>;
|
||||
|
||||
void test_memchunks_append(void) {
|
||||
MemchunkPool16 pool;
|
||||
|
@ -196,4 +197,144 @@ void test_memchunks_recycle(void) {
|
|||
CU_ASSERT(nullptr == m->next);
|
||||
}
|
||||
|
||||
void test_memchunks_reset(void) {
|
||||
MemchunkPool16 pool;
|
||||
Memchunks16 chunks(&pool);
|
||||
|
||||
std::array<uint8_t, 32> b{};
|
||||
|
||||
chunks.append(b.data(), b.size());
|
||||
|
||||
CU_ASSERT(32 == chunks.rleft());
|
||||
|
||||
chunks.reset();
|
||||
|
||||
CU_ASSERT(0 == chunks.rleft());
|
||||
CU_ASSERT(nullptr == chunks.head);
|
||||
CU_ASSERT(nullptr == chunks.tail);
|
||||
|
||||
auto m = pool.freelist;
|
||||
|
||||
CU_ASSERT(nullptr != m);
|
||||
CU_ASSERT(nullptr != m->next);
|
||||
CU_ASSERT(nullptr == m->next->next);
|
||||
}
|
||||
|
||||
void test_peek_memchunks_append(void) {
|
||||
MemchunkPool16 pool;
|
||||
PeekMemchunks16 pchunks(&pool);
|
||||
|
||||
std::array<uint8_t, 32> b{{
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4',
|
||||
'5', '6', '7', '8', '9', '0', '1',
|
||||
}},
|
||||
d;
|
||||
|
||||
pchunks.append(b.data(), b.size());
|
||||
|
||||
CU_ASSERT(32 == pchunks.rleft());
|
||||
CU_ASSERT(32 == pchunks.rleft_buffered());
|
||||
|
||||
CU_ASSERT(0 == pchunks.remove(nullptr, 0));
|
||||
|
||||
CU_ASSERT(32 == pchunks.rleft());
|
||||
CU_ASSERT(32 == pchunks.rleft_buffered());
|
||||
|
||||
CU_ASSERT(12 == pchunks.remove(d.data(), 12));
|
||||
|
||||
CU_ASSERT(std::equal(std::begin(b), std::begin(b) + 12, std::begin(d)));
|
||||
|
||||
CU_ASSERT(20 == pchunks.rleft());
|
||||
CU_ASSERT(32 == pchunks.rleft_buffered());
|
||||
|
||||
CU_ASSERT(20 == pchunks.remove(d.data(), d.size()));
|
||||
|
||||
CU_ASSERT(std::equal(std::begin(b) + 12, std::end(b), std::begin(d)));
|
||||
|
||||
CU_ASSERT(0 == pchunks.rleft());
|
||||
CU_ASSERT(32 == pchunks.rleft_buffered());
|
||||
}
|
||||
|
||||
void test_peek_memchunks_disable_peek_drain(void) {
|
||||
MemchunkPool16 pool;
|
||||
PeekMemchunks16 pchunks(&pool);
|
||||
|
||||
std::array<uint8_t, 32> b{{
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4',
|
||||
'5', '6', '7', '8', '9', '0', '1',
|
||||
}},
|
||||
d;
|
||||
|
||||
pchunks.append(b.data(), b.size());
|
||||
|
||||
CU_ASSERT(12 == pchunks.remove(d.data(), 12));
|
||||
|
||||
pchunks.disable_peek(true);
|
||||
|
||||
CU_ASSERT(!pchunks.peeking);
|
||||
CU_ASSERT(20 == pchunks.rleft());
|
||||
CU_ASSERT(20 == pchunks.rleft_buffered());
|
||||
|
||||
CU_ASSERT(20 == pchunks.remove(d.data(), d.size()));
|
||||
|
||||
CU_ASSERT(std::equal(std::begin(b) + 12, std::end(b), std::begin(d)));
|
||||
|
||||
CU_ASSERT(0 == pchunks.rleft());
|
||||
CU_ASSERT(0 == pchunks.rleft_buffered());
|
||||
}
|
||||
|
||||
void test_peek_memchunks_disable_peek_no_drain(void) {
|
||||
MemchunkPool16 pool;
|
||||
PeekMemchunks16 pchunks(&pool);
|
||||
|
||||
std::array<uint8_t, 32> b{{
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4',
|
||||
'5', '6', '7', '8', '9', '0', '1',
|
||||
}},
|
||||
d;
|
||||
|
||||
pchunks.append(b.data(), b.size());
|
||||
|
||||
CU_ASSERT(12 == pchunks.remove(d.data(), 12));
|
||||
|
||||
pchunks.disable_peek(false);
|
||||
|
||||
CU_ASSERT(!pchunks.peeking);
|
||||
CU_ASSERT(32 == pchunks.rleft());
|
||||
CU_ASSERT(32 == pchunks.rleft_buffered());
|
||||
|
||||
CU_ASSERT(32 == pchunks.remove(d.data(), d.size()));
|
||||
|
||||
CU_ASSERT(std::equal(std::begin(b), std::end(b), std::begin(d)));
|
||||
|
||||
CU_ASSERT(0 == pchunks.rleft());
|
||||
CU_ASSERT(0 == pchunks.rleft_buffered());
|
||||
}
|
||||
|
||||
void test_peek_memchunks_reset(void) {
|
||||
MemchunkPool16 pool;
|
||||
PeekMemchunks16 pchunks(&pool);
|
||||
|
||||
std::array<uint8_t, 32> b{{
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4',
|
||||
'5', '6', '7', '8', '9', '0', '1',
|
||||
}},
|
||||
d;
|
||||
|
||||
pchunks.append(b.data(), b.size());
|
||||
|
||||
CU_ASSERT(12 == pchunks.remove(d.data(), 12));
|
||||
|
||||
pchunks.disable_peek(true);
|
||||
pchunks.reset();
|
||||
|
||||
CU_ASSERT(0 == pchunks.rleft());
|
||||
CU_ASSERT(0 == pchunks.rleft_buffered());
|
||||
|
||||
CU_ASSERT(nullptr == pchunks.cur);
|
||||
CU_ASSERT(nullptr == pchunks.cur_pos);
|
||||
CU_ASSERT(nullptr == pchunks.cur_last);
|
||||
CU_ASSERT(pchunks.peeking);
|
||||
}
|
||||
|
||||
} // namespace nghttp2
|
||||
|
|
|
@ -36,6 +36,11 @@ void test_memchunks_append(void);
|
|||
void test_memchunks_drain(void);
|
||||
void test_memchunks_riovec(void);
|
||||
void test_memchunks_recycle(void);
|
||||
void test_memchunks_reset(void);
|
||||
void test_peek_memchunks_append(void);
|
||||
void test_peek_memchunks_disable_peek_drain(void);
|
||||
void test_peek_memchunks_disable_peek_no_drain(void);
|
||||
void test_peek_memchunks_reset(void);
|
||||
|
||||
} // namespace nghttp2
|
||||
|
||||
|
|
|
@ -172,7 +172,16 @@ int main(int argc, char *argv[]) {
|
|||
!CU_add_test(pSuite, "memchunk_drain", nghttp2::test_memchunks_drain) ||
|
||||
!CU_add_test(pSuite, "memchunk_riovec", nghttp2::test_memchunks_riovec) ||
|
||||
!CU_add_test(pSuite, "memchunk_recycle",
|
||||
nghttp2::test_memchunks_recycle)) {
|
||||
nghttp2::test_memchunks_recycle) ||
|
||||
!CU_add_test(pSuite, "memchunk_reset", nghttp2::test_memchunks_reset) ||
|
||||
!CU_add_test(pSuite, "peek_memchunk_append",
|
||||
nghttp2::test_peek_memchunks_append) ||
|
||||
!CU_add_test(pSuite, "peek_memchunk_disable_peek_drain",
|
||||
nghttp2::test_peek_memchunks_disable_peek_drain) ||
|
||||
!CU_add_test(pSuite, "peek_memchunk_disable_peek_no_drain",
|
||||
nghttp2::test_peek_memchunks_disable_peek_no_drain) ||
|
||||
!CU_add_test(pSuite, "peek_memchunk_reset",
|
||||
nghttp2::test_peek_memchunks_reset)) {
|
||||
CU_cleanup_registry();
|
||||
return CU_get_error();
|
||||
}
|
||||
|
|
|
@ -358,7 +358,8 @@ int ClientHandler::upstream_http1_connhd_read() {
|
|||
|
||||
ClientHandler::ClientHandler(Worker *worker, int fd, SSL *ssl,
|
||||
const char *ipaddr, const char *port)
|
||||
: conn_(worker->get_loop(), fd, ssl, get_config()->upstream_write_timeout,
|
||||
: conn_(worker->get_loop(), fd, ssl, worker->get_mcpool(),
|
||||
get_config()->upstream_write_timeout,
|
||||
get_config()->upstream_read_timeout, get_config()->write_rate,
|
||||
get_config()->write_burst, get_config()->read_rate,
|
||||
get_config()->read_burst, writecb, readcb, timeoutcb, this),
|
||||
|
|
|
@ -40,12 +40,13 @@ using namespace nghttp2;
|
|||
|
||||
namespace shrpx {
|
||||
Connection::Connection(struct ev_loop *loop, int fd, SSL *ssl,
|
||||
ev_tstamp write_timeout, ev_tstamp read_timeout,
|
||||
size_t write_rate, size_t write_burst, size_t read_rate,
|
||||
size_t read_burst, IOCb writecb, IOCb readcb,
|
||||
TimerCb timeoutcb, void *data)
|
||||
: tls{}, wlimit(loop, &wev, write_rate, write_burst),
|
||||
rlimit(loop, &rev, read_rate, read_burst, ssl), writecb(writecb),
|
||||
MemchunkPool *mcpool, ev_tstamp write_timeout,
|
||||
ev_tstamp read_timeout, size_t write_rate,
|
||||
size_t write_burst, size_t read_rate, size_t read_burst,
|
||||
IOCb writecb, IOCb readcb, TimerCb timeoutcb, void *data)
|
||||
: tls{DefaultMemchunks(mcpool), DefaultPeekMemchunks(mcpool)},
|
||||
wlimit(loop, &wev, write_rate, write_burst),
|
||||
rlimit(loop, &rev, read_rate, read_burst, this), writecb(writecb),
|
||||
readcb(readcb), timeoutcb(timeoutcb), loop(loop), data(data), fd(fd) {
|
||||
|
||||
ev_io_init(&wev, writecb, fd, EV_WRITE);
|
||||
|
@ -83,10 +84,12 @@ void Connection::disconnect() {
|
|||
|
||||
if (tls.cached_session) {
|
||||
SSL_SESSION_free(tls.cached_session);
|
||||
tls.cached_session = nullptr;
|
||||
}
|
||||
|
||||
if (tls.cached_session_lookup_req) {
|
||||
tls.cached_session_lookup_req->canceled = true;
|
||||
tls.cached_session_lookup_req = nullptr;
|
||||
}
|
||||
|
||||
// To reuse SSL/TLS session, we have to shutdown, and don't free
|
||||
|
@ -96,7 +99,15 @@ void Connection::disconnect() {
|
|||
tls.ssl = nullptr;
|
||||
}
|
||||
|
||||
tls = {tls.ssl};
|
||||
tls.wbuf.reset();
|
||||
tls.rbuf.reset();
|
||||
tls.last_write_idle = 0.;
|
||||
tls.warmup_writelen = 0;
|
||||
tls.last_writelen = 0;
|
||||
tls.last_readlen = 0;
|
||||
tls.handshake_state = 0;
|
||||
tls.initial_handshake_done = false;
|
||||
tls.reneg_started = false;
|
||||
}
|
||||
|
||||
if (fd != -1) {
|
||||
|
@ -114,22 +125,9 @@ void Connection::disconnect() {
|
|||
wlimit.stopw();
|
||||
}
|
||||
|
||||
namespace {
|
||||
void allocate_buffer(Connection *conn) {
|
||||
conn->tls.rb = make_unique<Buffer<16_k>>();
|
||||
conn->tls.wb = make_unique<Buffer<16_k>>();
|
||||
}
|
||||
} // namespace
|
||||
void Connection::prepare_client_handshake() { SSL_set_connect_state(tls.ssl); }
|
||||
|
||||
void Connection::prepare_client_handshake() {
|
||||
SSL_set_connect_state(tls.ssl);
|
||||
allocate_buffer(this);
|
||||
}
|
||||
|
||||
void Connection::prepare_server_handshake() {
|
||||
SSL_set_accept_state(tls.ssl);
|
||||
allocate_buffer(this);
|
||||
}
|
||||
void Connection::prepare_server_handshake() { SSL_set_accept_state(tls.ssl); }
|
||||
|
||||
// BIO implementation is inspired by openldap implementation:
|
||||
// http://www.openldap.org/devel/cvsweb.cgi/~checkout~/libraries/libldap/tls_o.c
|
||||
|
@ -140,27 +138,26 @@ int shrpx_bio_write(BIO *b, const char *buf, int len) {
|
|||
}
|
||||
|
||||
auto conn = static_cast<Connection *>(b->ptr);
|
||||
auto &wb = conn->tls.wb;
|
||||
auto &wbuf = conn->tls.wbuf;
|
||||
|
||||
BIO_clear_retry_flags(b);
|
||||
|
||||
if (conn->tls.initial_handshake_done) {
|
||||
// After handshake finished, send |buf| of length |len| to the
|
||||
// socket directly.
|
||||
if (wb && wb->rleft()) {
|
||||
auto nwrite = conn->write_clear(wb->pos, wb->rleft());
|
||||
if (wbuf.rleft()) {
|
||||
std::array<struct iovec, 4> iov;
|
||||
auto iovcnt = wbuf.riovec(iov.data(), iov.size());
|
||||
auto nwrite = conn->writev_clear(iov.data(), iovcnt);
|
||||
if (nwrite < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
wb->drain(nwrite);
|
||||
if (wb->rleft()) {
|
||||
wbuf.drain(nwrite);
|
||||
if (wbuf.rleft()) {
|
||||
BIO_set_retry_write(b);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Here delete TLS write buffer
|
||||
wb.reset();
|
||||
}
|
||||
auto nwrite = conn->write_clear(buf, len);
|
||||
if (nwrite < 0) {
|
||||
|
@ -175,16 +172,9 @@ int shrpx_bio_write(BIO *b, const char *buf, int len) {
|
|||
return nwrite;
|
||||
}
|
||||
|
||||
auto nwrite = std::min(static_cast<size_t>(len), wb->wleft());
|
||||
wbuf.append(buf, len);
|
||||
|
||||
if (nwrite == 0) {
|
||||
BIO_set_retry_write(b);
|
||||
return -1;
|
||||
}
|
||||
|
||||
wb->write(buf, nwrite);
|
||||
|
||||
return nwrite;
|
||||
return len;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -195,11 +185,11 @@ int shrpx_bio_read(BIO *b, char *buf, int len) {
|
|||
}
|
||||
|
||||
auto conn = static_cast<Connection *>(b->ptr);
|
||||
auto &rb = conn->tls.rb;
|
||||
auto &rbuf = conn->tls.rbuf;
|
||||
|
||||
BIO_clear_retry_flags(b);
|
||||
|
||||
if (conn->tls.initial_handshake_done && !rb) {
|
||||
if (conn->tls.initial_handshake_done && rbuf.rleft() == 0) {
|
||||
auto nread = conn->read_clear(buf, len);
|
||||
if (nread < 0) {
|
||||
return -1;
|
||||
|
@ -211,22 +201,12 @@ int shrpx_bio_read(BIO *b, char *buf, int len) {
|
|||
return nread;
|
||||
}
|
||||
|
||||
auto nread = std::min(static_cast<size_t>(len), rb->rleft());
|
||||
|
||||
if (nread == 0) {
|
||||
if (conn->tls.initial_handshake_done) {
|
||||
rb.reset();
|
||||
}
|
||||
|
||||
if (rbuf.rleft() == 0) {
|
||||
BIO_set_retry_read(b);
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::copy_n(rb->pos, nread, buf);
|
||||
|
||||
rb->drain(nread);
|
||||
|
||||
return nread;
|
||||
return rbuf.remove(buf, len);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -289,51 +269,47 @@ void Connection::set_ssl(SSL *ssl) {
|
|||
bio->ptr = this;
|
||||
SSL_set_bio(tls.ssl, bio, bio);
|
||||
SSL_set_app_data(tls.ssl, this);
|
||||
rlimit.set_ssl(tls.ssl);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// We should buffer at least full encrypted TLS record here.
|
||||
// Theoretically, peer can send client hello in several TLS records,
|
||||
// which could exeed this limit, but it is not portable, and we don't
|
||||
// have to handle such exotic behaviour.
|
||||
bool read_buffer_full(DefaultPeekMemchunks &rbuf) {
|
||||
return rbuf.rleft_buffered() >= 20_k;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int Connection::tls_handshake() {
|
||||
wlimit.stopw();
|
||||
ev_timer_stop(loop, &wt);
|
||||
|
||||
auto nread = read_clear(tls.rb->last, tls.rb->wleft());
|
||||
if (ev_is_active(&rev)) {
|
||||
std::array<uint8_t, 8_k> buf;
|
||||
auto nread = read_clear(buf.data(), buf.size());
|
||||
if (nread < 0) {
|
||||
if (LOG_ENABLED(INFO)) {
|
||||
LOG(INFO) << "tls: handshake read error";
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
tls.rb->write(nread);
|
||||
|
||||
// We have limited space for read buffer, so stop reading if it
|
||||
// filled up.
|
||||
if (tls.rb->wleft() == 0) {
|
||||
if (tls.handshake_state != TLS_CONN_WRITE_STARTED) {
|
||||
// Reading 16KiB before writing server hello is unlikely for
|
||||
// ordinary client.
|
||||
if (LOG_ENABLED(INFO)) {
|
||||
LOG(INFO) << "tls: client hello is too large";
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
tls.rbuf.append(buf.data(), nread);
|
||||
if (read_buffer_full(tls.rbuf)) {
|
||||
rlimit.stopw();
|
||||
ev_timer_stop(loop, &rt);
|
||||
}
|
||||
}
|
||||
|
||||
switch (tls.handshake_state) {
|
||||
case TLS_CONN_WAIT_FOR_SESSION_CACHE:
|
||||
if (nread > 0) {
|
||||
if (LOG_ENABLED(INFO)) {
|
||||
LOG(INFO) << "tls: client sent addtional data after client hello";
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
return SHRPX_ERR_INPROGRESS;
|
||||
case TLS_CONN_GOT_SESSION_CACHE: {
|
||||
// Use the same trick invented by @kazuho in h2o project
|
||||
tls.wb->reset();
|
||||
tls.rb->pos = tls.rb->begin();
|
||||
// Use the same trick invented by @kazuho in h2o project.
|
||||
|
||||
// Discard all outgoing data.
|
||||
tls.wbuf.reset();
|
||||
// Rewind buffered incoming data to replay client hello.
|
||||
tls.rbuf.disable_peek(false);
|
||||
|
||||
auto ssl_ctx = SSL_get_SSL_CTX(tls.ssl);
|
||||
auto ssl_opts = SSL_get_options(tls.ssl);
|
||||
|
@ -382,32 +358,33 @@ int Connection::tls_handshake() {
|
|||
return SHRPX_ERR_INPROGRESS;
|
||||
}
|
||||
|
||||
if (tls.wb->rleft()) {
|
||||
if (tls.wbuf.rleft()) {
|
||||
// First write indicates that resumption stuff has done.
|
||||
if (tls.handshake_state != TLS_CONN_WRITE_STARTED) {
|
||||
tls.handshake_state = TLS_CONN_WRITE_STARTED;
|
||||
auto nwrite = write_clear(tls.wb->pos, tls.wb->rleft());
|
||||
// If peek has already disabled, this is noop.
|
||||
tls.rbuf.disable_peek(true);
|
||||
}
|
||||
std::array<struct iovec, 4> iov;
|
||||
auto iovcnt = tls.wbuf.riovec(iov.data(), iov.size());
|
||||
auto nwrite = writev_clear(iov.data(), iovcnt);
|
||||
if (nwrite < 0) {
|
||||
if (LOG_ENABLED(INFO)) {
|
||||
LOG(INFO) << "tls: handshake write error";
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
tls.wb->drain(nwrite);
|
||||
}
|
||||
tls.wbuf.drain(nwrite);
|
||||
|
||||
if (tls.wb->rleft()) {
|
||||
if (tls.wbuf.rleft()) {
|
||||
wlimit.startw();
|
||||
ev_timer_again(loop, &wt);
|
||||
} else {
|
||||
tls.wb->reset();
|
||||
}
|
||||
}
|
||||
|
||||
if (tls.handshake_state == TLS_CONN_WRITE_STARTED && tls.rb->rleft() == 0) {
|
||||
tls.rb->reset();
|
||||
|
||||
if (!read_buffer_full(tls.rbuf)) {
|
||||
// We may have stopped reading
|
||||
rlimit.startw();
|
||||
ev_timer_again(loop, &rt);
|
||||
}
|
||||
|
||||
if (rv != 1) {
|
||||
|
@ -419,13 +396,14 @@ int Connection::tls_handshake() {
|
|||
|
||||
tls.initial_handshake_done = true;
|
||||
|
||||
if (tls.rb->rleft()) {
|
||||
ev_feed_event(loop, &rev, EV_READ);
|
||||
}
|
||||
|
||||
// We may have stopped reading
|
||||
// We have to start read watcher, since later stage of code expects
|
||||
// this.
|
||||
rlimit.startw();
|
||||
ev_timer_again(loop, &rt);
|
||||
|
||||
// We may have whole request in tls.rbuf. This means that we don't
|
||||
// get notified further read event. This is especially true for
|
||||
// HTTP/1.1.
|
||||
handle_tls_pending_read();
|
||||
|
||||
if (LOG_ENABLED(INFO)) {
|
||||
LOG(INFO) << "SSL/TLS handshake completed";
|
||||
|
@ -506,8 +484,8 @@ ssize_t Connection::write_tls(const void *data, size_t len) {
|
|||
return SHRPX_ERR_NETWORK;
|
||||
case SSL_ERROR_WANT_WRITE:
|
||||
tls.last_writelen = len;
|
||||
wlimit.startw();
|
||||
ev_timer_again(loop, &wt);
|
||||
// starting write watcher and timer is done in write_clear via
|
||||
// bio.
|
||||
return 0;
|
||||
default:
|
||||
if (LOG_ENABLED(INFO)) {
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
|
||||
#include "shrpx_rate_limit.h"
|
||||
#include "shrpx_error.h"
|
||||
#include "buffer.h"
|
||||
#include "memchunk.h"
|
||||
|
||||
namespace shrpx {
|
||||
|
||||
|
@ -50,6 +50,8 @@ enum {
|
|||
};
|
||||
|
||||
struct TLSConnection {
|
||||
DefaultMemchunks wbuf;
|
||||
DefaultPeekMemchunks rbuf;
|
||||
SSL *ssl;
|
||||
SSL_SESSION *cached_session;
|
||||
MemcachedRequest *cached_session_lookup_req;
|
||||
|
@ -62,8 +64,6 @@ struct TLSConnection {
|
|||
int handshake_state;
|
||||
bool initial_handshake_done;
|
||||
bool reneg_started;
|
||||
std::unique_ptr<Buffer<16_k>> rb;
|
||||
std::unique_ptr<Buffer<16_k>> wb;
|
||||
};
|
||||
|
||||
template <typename T> using EVCb = void (*)(struct ev_loop *, T *, int);
|
||||
|
@ -72,10 +72,10 @@ using IOCb = EVCb<ev_io>;
|
|||
using TimerCb = EVCb<ev_timer>;
|
||||
|
||||
struct Connection {
|
||||
Connection(struct ev_loop *loop, int fd, SSL *ssl, ev_tstamp write_timeout,
|
||||
ev_tstamp read_timeout, size_t write_rate, size_t write_burst,
|
||||
size_t read_rate, size_t read_burst, IOCb writecb, IOCb readcb,
|
||||
TimerCb timeoutcb, void *data);
|
||||
Connection(struct ev_loop *loop, int fd, SSL *ssl, MemchunkPool *mcpool,
|
||||
ev_tstamp write_timeout, ev_tstamp read_timeout, size_t write_rate,
|
||||
size_t write_burst, size_t read_rate, size_t read_burst,
|
||||
IOCb writecb, IOCb readcb, TimerCb timeoutcb, void *data);
|
||||
~Connection();
|
||||
|
||||
void disconnect();
|
||||
|
|
|
@ -144,7 +144,8 @@ void writecb(struct ev_loop *loop, ev_io *w, int revents) {
|
|||
Http2Session::Http2Session(struct ev_loop *loop, SSL_CTX *ssl_ctx,
|
||||
ConnectBlocker *connect_blocker, Worker *worker,
|
||||
size_t group, size_t idx)
|
||||
: conn_(loop, -1, nullptr, get_config()->downstream_write_timeout,
|
||||
: conn_(loop, -1, nullptr, worker->get_mcpool(),
|
||||
get_config()->downstream_write_timeout,
|
||||
get_config()->downstream_read_timeout, 0, 0, 0, 0, writecb, readcb,
|
||||
timeoutcb, this),
|
||||
worker_(worker), connect_blocker_(connect_blocker), ssl_ctx_(ssl_ctx),
|
||||
|
|
|
@ -112,7 +112,7 @@ void connectcb(struct ev_loop *loop, ev_io *w, int revents) {
|
|||
HttpDownstreamConnection::HttpDownstreamConnection(
|
||||
DownstreamConnectionPool *dconn_pool, size_t group, struct ev_loop *loop)
|
||||
: DownstreamConnection(dconn_pool),
|
||||
conn_(loop, -1, nullptr, get_config()->downstream_write_timeout,
|
||||
conn_(loop, -1, nullptr, nullptr, get_config()->downstream_write_timeout,
|
||||
get_config()->downstream_read_timeout, 0, 0, 0, 0, connectcb,
|
||||
readcb, timeoutcb, this),
|
||||
ioctrl_(&conn_.rlimit), response_htp_{0}, group_(group), addr_idx_(0),
|
||||
|
|
|
@ -92,7 +92,7 @@ constexpr ev_tstamp read_timeout = 10.;
|
|||
|
||||
MemcachedConnection::MemcachedConnection(const Address *addr,
|
||||
struct ev_loop *loop)
|
||||
: conn_(loop, -1, nullptr, write_timeout, read_timeout, 0, 0, 0, 0,
|
||||
: conn_(loop, -1, nullptr, nullptr, write_timeout, read_timeout, 0, 0, 0, 0,
|
||||
connectcb, readcb, timeoutcb, this),
|
||||
parse_state_{}, addr_(addr), sendsum_(0), connected_(false) {}
|
||||
|
||||
|
@ -403,6 +403,7 @@ int MemcachedConnection::parse_packet() {
|
|||
return 0;
|
||||
}
|
||||
|
||||
#undef DEFAULT_WR_IOVCNT
|
||||
#define DEFAULT_WR_IOVCNT 128
|
||||
|
||||
#if defined(IOV_MAX) && IOV_MAX < DEFAULT_WR_IOVCNT
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
|
||||
#include <limits>
|
||||
|
||||
#include "shrpx_connection.h"
|
||||
|
||||
namespace shrpx {
|
||||
|
||||
namespace {
|
||||
|
@ -36,9 +38,9 @@ void regencb(struct ev_loop *loop, ev_timer *w, int revents) {
|
|||
} // namespace
|
||||
|
||||
RateLimit::RateLimit(struct ev_loop *loop, ev_io *w, size_t rate, size_t burst,
|
||||
SSL *ssl)
|
||||
: w_(w), loop_(loop), ssl_(ssl), rate_(rate), burst_(burst), avail_(burst),
|
||||
startw_req_(false) {
|
||||
Connection *conn)
|
||||
: w_(w), loop_(loop), conn_(conn), rate_(rate), burst_(burst),
|
||||
avail_(burst), startw_req_(false) {
|
||||
ev_timer_init(&t_, regencb, 0., 1.);
|
||||
t_.data = this;
|
||||
if (rate_ > 0) {
|
||||
|
@ -97,7 +99,8 @@ void RateLimit::stopw() {
|
|||
}
|
||||
|
||||
void RateLimit::handle_tls_pending_read() {
|
||||
if (!ssl_ || SSL_pending(ssl_) == 0) {
|
||||
if (!conn_ || !conn_->tls.ssl ||
|
||||
(SSL_pending(conn_->tls.ssl) == 0 && conn_->tls.rbuf.rleft() == 0)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -106,6 +109,4 @@ void RateLimit::handle_tls_pending_read() {
|
|||
ev_feed_event(loop_, w_, EV_READ);
|
||||
}
|
||||
|
||||
void RateLimit::set_ssl(SSL *ssl) { ssl_ = ssl; }
|
||||
|
||||
} // namespace shrpx
|
||||
|
|
|
@ -33,28 +33,30 @@
|
|||
|
||||
namespace shrpx {
|
||||
|
||||
struct Connection;
|
||||
|
||||
class RateLimit {
|
||||
public:
|
||||
// We need |ssl| object to check that it has unread decrypted bytes.
|
||||
// We need |conn| object to check that it has unread bytes for TLS
|
||||
// connection.
|
||||
RateLimit(struct ev_loop *loop, ev_io *w, size_t rate, size_t burst,
|
||||
SSL *ssl = nullptr);
|
||||
Connection *conn = nullptr);
|
||||
~RateLimit();
|
||||
size_t avail() const;
|
||||
void drain(size_t n);
|
||||
void regen();
|
||||
void startw();
|
||||
void stopw();
|
||||
// Feeds event if ssl_ object has unread decrypted bytes. This is
|
||||
// required since it is buffered in ssl_ object, io event is not
|
||||
// generated unless new incoming data is received.
|
||||
// Feeds event if conn_->tls object has unread bytes. This is
|
||||
// required since it is buffered in conn_->tls object, io event is
|
||||
// not generated unless new incoming data is received.
|
||||
void handle_tls_pending_read();
|
||||
void set_ssl(SSL *ssl);
|
||||
|
||||
private:
|
||||
ev_timer t_;
|
||||
ev_io *w_;
|
||||
struct ev_loop *loop_;
|
||||
SSL *ssl_;
|
||||
Connection *conn_;
|
||||
size_t rate_;
|
||||
size_t burst_;
|
||||
size_t avail_;
|
||||
|
|
Loading…
Reference in New Issue