diff --git a/gennghttpxfun.py b/gennghttpxfun.py index f75add0b..500d47a6 100755 --- a/gennghttpxfun.py +++ b/gennghttpxfun.py @@ -92,6 +92,7 @@ OPTIONS = [ "tls-ticket-key-cipher", "host-rewrite", "tls-session-cache-memcached", + "tls-session-cache-memcached-tls", "tls-ticket-key-memcached", "tls-ticket-key-memcached-interval", "tls-ticket-key-memcached-max-retry", diff --git a/src/shrpx.cc b/src/shrpx.cc index e93c5d67..04e85664 100644 --- a/src/shrpx.cc +++ b/src/shrpx.cc @@ -1564,6 +1564,9 @@ SSL/TLS: Specify address of memcached server to store session cache. This enables shared session cache between multiple nghttpx instances. + --tls-session-cache-memcached-tls + Enable SSL/TLS on memcached connections to store session + cache. --tls-dyn-rec-warmup-threshold= Specify the threshold size for TLS dynamic record size behaviour. During a TLS session, after the threshold @@ -2400,6 +2403,7 @@ int main(int argc, char **argv) { {SHRPX_OPT_BACKEND_HTTP1_TLS, no_argument, &flag, 106}, {SHRPX_OPT_BACKEND_TLS_SESSION_CACHE_PER_WORKER, required_argument, &flag, 107}, + {SHRPX_OPT_TLS_SESSION_CACHE_MEMCACHED_TLS, no_argument, &flag, 108}, {nullptr, 0, nullptr, 0}}; int option_index = 0; @@ -2858,6 +2862,10 @@ int main(int argc, char **argv) { cmdcfgs.emplace_back(SHRPX_OPT_BACKEND_TLS_SESSION_CACHE_PER_WORKER, optarg); break; + case 108: + // --tls-session-cache-memcached-tls + cmdcfgs.emplace_back(SHRPX_OPT_TLS_SESSION_CACHE_MEMCACHED_TLS, "yes"); + break; default: break; } diff --git a/src/shrpx_config.cc b/src/shrpx_config.cc index 102319e5..68e4da5b 100644 --- a/src/shrpx_config.cc +++ b/src/shrpx_config.cc @@ -758,6 +758,7 @@ enum { SHRPX_OPTID_TLS_DYN_REC_WARMUP_THRESHOLD, SHRPX_OPTID_TLS_PROTO_LIST, SHRPX_OPTID_TLS_SESSION_CACHE_MEMCACHED, + SHRPX_OPTID_TLS_SESSION_CACHE_MEMCACHED_TLS, SHRPX_OPTID_TLS_TICKET_KEY_CIPHER, SHRPX_OPTID_TLS_TICKET_KEY_FILE, SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED, @@ -1337,6 +1338,15 @@ int option_lookup_token(const char *name, size_t namelen) { break; } break; + case 31: + switch (name[30]) { + case 's': + if (util::strieq_l("tls-session-cache-memcached-tl", name, 30)) { + return SHRPX_OPTID_TLS_SESSION_CACHE_MEMCACHED_TLS; + } + break; + } + break; case 33: switch (name[32]) { case 'l': @@ -2229,6 +2239,10 @@ int parse_config(const char *opt, const char *optarg, case SHRPX_OPTID_BACKEND_TLS_SESSION_CACHE_PER_WORKER: return parse_uint(&mod_config()->tls.downstream_session_cache_per_worker, opt, optarg); + case SHRPX_OPTID_TLS_SESSION_CACHE_MEMCACHED_TLS: + mod_config()->tls.session_cache.memcached.tls = util::strieq(optarg, "yes"); + + return 0; case SHRPX_OPTID_CONF: LOG(WARN) << "conf: ignored"; diff --git a/src/shrpx_config.h b/src/shrpx_config.h index 2d511fda..5a67d06e 100644 --- a/src/shrpx_config.h +++ b/src/shrpx_config.h @@ -209,6 +209,8 @@ constexpr char SHRPX_OPT_NO_HTTP2_CIPHER_BLACK_LIST[] = constexpr char SHRPX_OPT_BACKEND_HTTP1_TLS[] = "backend-http1-tls"; constexpr char SHRPX_OPT_BACKEND_TLS_SESSION_CACHE_PER_WORKER[] = "backend-tls-session-cache-per-worker"; +constexpr char SHRPX_OPT_TLS_SESSION_CACHE_MEMCACHED_TLS[] = + "tls-session-cache-memcached-tls"; constexpr size_t SHRPX_OBFUSCATED_NODE_LENGTH = 8; @@ -355,6 +357,7 @@ struct TLSConfig { Address addr; uint16_t port; std::unique_ptr host; + bool tls; } memcached; } session_cache; diff --git a/src/shrpx_connection_handler.cc b/src/shrpx_connection_handler.cc index c9066b69..1ff4c7ef 100644 --- a/src/shrpx_connection_handler.cc +++ b/src/shrpx_connection_handler.cc @@ -193,8 +193,21 @@ int ConnectionHandler::create_single_worker() { all_ssl_ctx_.push_back(cl_ssl_ctx); } - single_worker_ = make_unique(loop_, sv_ssl_ctx, cl_ssl_ctx, cert_tree, - ticket_keys_); + auto &session_cacheconf = get_config()->tls.session_cache; + + SSL_CTX *session_cache_ssl_ctx = nullptr; + if (session_cacheconf.memcached.tls) { + session_cache_ssl_ctx = ssl::create_ssl_client_context( +#ifdef HAVE_NEVERBLEED + nb_.get(), +#endif // HAVE_NEVERBLEED + nullptr, nullptr, nullptr, StringRef(), nullptr); + all_ssl_ctx_.push_back(session_cache_ssl_ctx); + } + + single_worker_ = + make_unique(loop_, sv_ssl_ctx, cl_ssl_ctx, session_cache_ssl_ctx, + cert_tree, ticket_keys_); #ifdef HAVE_MRUBY if (single_worker_->create_mruby_context() != 0) { return -1; @@ -225,11 +238,23 @@ int ConnectionHandler::create_worker_thread(size_t num) { all_ssl_ctx_.push_back(cl_ssl_ctx); } + auto &session_cacheconf = get_config()->tls.session_cache; + for (size_t i = 0; i < num; ++i) { auto loop = ev_loop_new(0); - auto worker = make_unique(loop, sv_ssl_ctx, cl_ssl_ctx, cert_tree, - ticket_keys_); + SSL_CTX *session_cache_ssl_ctx = nullptr; + if (session_cacheconf.memcached.tls) { + session_cache_ssl_ctx = ssl::create_ssl_client_context( +#ifdef HAVE_NEVERBLEED + nb_.get(), +#endif // HAVE_NEVERBLEED + nullptr, nullptr, nullptr, StringRef(), nullptr); + all_ssl_ctx_.push_back(session_cache_ssl_ctx); + } + auto worker = + make_unique(loop, sv_ssl_ctx, cl_ssl_ctx, session_cache_ssl_ctx, + cert_tree, ticket_keys_); #ifdef HAVE_MRUBY if (worker->create_mruby_context() != 0) { return -1; diff --git a/src/shrpx_memcached_connection.cc b/src/shrpx_memcached_connection.cc index 8c4d1cc2..4df031d8 100644 --- a/src/shrpx_memcached_connection.cc +++ b/src/shrpx_memcached_connection.cc @@ -32,6 +32,7 @@ #include "shrpx_memcached_request.h" #include "shrpx_memcached_result.h" #include "shrpx_config.h" +#include "shrpx_ssl.h" #include "util.h" namespace shrpx { @@ -78,7 +79,7 @@ void connectcb(struct ev_loop *loop, ev_io *w, int revents) { auto conn = static_cast(w->data); auto mconn = static_cast(conn->data); - if (mconn->on_connect() != 0) { + if (mconn->connected() != 0) { mconn->disconnect(); return; } @@ -91,11 +92,17 @@ constexpr ev_tstamp write_timeout = 10.; constexpr ev_tstamp read_timeout = 10.; MemcachedConnection::MemcachedConnection(const Address *addr, - struct ev_loop *loop) - : conn_(loop, -1, nullptr, nullptr, write_timeout, read_timeout, {}, {}, + struct ev_loop *loop, SSL_CTX *ssl_ctx, + const StringRef &sni_name, + MemchunkPool *mcpool) + : conn_(loop, -1, nullptr, mcpool, write_timeout, read_timeout, {}, {}, connectcb, readcb, timeoutcb, this, 0, 0.), + do_read_(&MemcachedConnection::noop), + do_write_(&MemcachedConnection::noop), + sni_name_(sni_name.str()), parse_state_{}, addr_(addr), + ssl_ctx_(ssl_ctx), sendsum_(0), connected_(false) {} @@ -127,11 +134,21 @@ void MemcachedConnection::disconnect() { assert(recvbuf_.rleft() == 0); recvbuf_.reset(); + + do_read_ = do_write_ = &MemcachedConnection::noop; } int MemcachedConnection::initiate_connection() { assert(conn_.fd == -1); + if (ssl_ctx_ && !conn_.tls.ssl) { + auto ssl = ssl::create_ssl(ssl_ctx_); + if (!ssl) { + return -1; + } + conn_.set_ssl(ssl); + } + conn_.fd = util::create_nonblock_socket(addr_->su.storage.ss_family); if (conn_.fd == -1) { @@ -153,6 +170,14 @@ int MemcachedConnection::initiate_connection() { return -1; } + if (ssl_ctx_) { + if (!util::numeric_host(sni_name_.c_str())) { + SSL_set_tlsext_host_name(conn_.tls.ssl, sni_name_.c_str()); + } + + conn_.prepare_client_handshake(); + } + if (LOG_ENABLED(INFO)) { MCLOG(INFO, this) << "Connecting to memcached server"; } @@ -168,7 +193,7 @@ int MemcachedConnection::initiate_connection() { return 0; } -int MemcachedConnection::on_connect() { +int MemcachedConnection::connected() { if (!util::check_socket_connected(conn_.fd)) { conn_.wlimit.stopw(); @@ -185,15 +210,59 @@ int MemcachedConnection::on_connect() { connected_ = true; - ev_set_cb(&conn_.wev, writecb); - conn_.rlimit.startw(); ev_timer_again(conn_.loop, &conn_.rt); + ev_set_cb(&conn_.wev, writecb); + + if (conn_.tls.ssl) { + do_read_ = &MemcachedConnection::tls_handshake; + do_write_ = &MemcachedConnection::tls_handshake; + + return 0; + } + + do_read_ = &MemcachedConnection::read_clear; + do_write_ = &MemcachedConnection::write_clear; + return 0; } -int MemcachedConnection::on_write() { +int MemcachedConnection::on_write() { return do_write_(*this); } +int MemcachedConnection::on_read() { return do_read_(*this); } + +int MemcachedConnection::tls_handshake() { + ERR_clear_error(); + + ev_timer_again(conn_.loop, &conn_.rt); + + auto rv = conn_.tls_handshake(); + if (rv == SHRPX_ERR_INPROGRESS) { + return 0; + } + + if (rv < 0) { + return rv; + } + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "SSL/TLS handshake completed"; + } + + auto &tlsconf = get_config()->tls; + + if (!tlsconf.insecure && + ssl::check_cert(conn_.tls.ssl, addr_, StringRef(sni_name_)) != 0) { + return -1; + } + + do_read_ = &MemcachedConnection::read_tls; + do_write_ = &MemcachedConnection::write_tls; + + return on_write(); +} + +int MemcachedConnection::write_tls() { if (!connected_) { return 0; } @@ -207,19 +276,30 @@ int MemcachedConnection::on_write() { return 0; } - int rv; + std::array iov; + std::array buf; for (; !sendq_.empty();) { - rv = send_request(); + auto iovcnt = fill_request_buffer(iov.data(), iov.size()); + auto p = std::begin(buf); + for (size_t i = 0; i < iovcnt; ++i) { + auto &v = iov[i]; + auto n = std::min(static_cast(std::end(buf) - p), v.iov_len); + p = std::copy_n(static_cast(v.iov_base), n, p); + if (p == std::end(buf)) { + break; + } + } - if (rv < 0) { + auto nwrite = conn_.write_tls(buf.data(), p - std::begin(buf)); + if (nwrite < 0) { return -1; } - - if (rv == 1) { - // blocked + if (nwrite == 0) { return 0; } + + drain_send_queue(nwrite); } conn_.wlimit.stopw(); @@ -228,7 +308,70 @@ int MemcachedConnection::on_write() { return 0; } -int MemcachedConnection::on_read() { +int MemcachedConnection::read_tls() { + if (!connected_) { + return 0; + } + + ev_timer_again(conn_.loop, &conn_.rt); + + for (;;) { + auto nread = conn_.read_tls(recvbuf_.last, recvbuf_.wleft()); + + if (nread == 0) { + return 0; + } + + if (nread < 0) { + return -1; + } + + recvbuf_.write(nread); + + if (parse_packet() != 0) { + return -1; + } + } + + return 0; +} + +int MemcachedConnection::write_clear() { + if (!connected_) { + return 0; + } + + ev_timer_again(conn_.loop, &conn_.rt); + + if (sendq_.empty()) { + conn_.wlimit.stopw(); + ev_timer_stop(conn_.loop, &conn_.wt); + + return 0; + } + + std::array iov; + + for (; !sendq_.empty();) { + auto iovcnt = fill_request_buffer(iov.data(), iov.size()); + auto nwrite = conn_.writev_clear(iov.data(), iovcnt); + if (nwrite < 0) { + return -1; + } + if (nwrite == 0) { + return 0; + } + + drain_send_queue(nwrite); + } + + conn_.wlimit.stopw(); + ev_timer_stop(conn_.loop, &conn_.wt); + + return 0; +} + +int MemcachedConnection::read_clear() { if (!connected_) { return 0; } @@ -415,9 +558,8 @@ int MemcachedConnection::parse_packet() { #define MAX_WR_IOVCNT DEFAULT_WR_IOVCNT #endif // !defined(IOV_MAX) || IOV_MAX >= DEFAULT_WR_IOVCNT -int MemcachedConnection::send_request() { - ssize_t nwrite; - +size_t MemcachedConnection::fill_request_buffer(struct iovec *iov, + size_t iovlen) { if (sendsum_ == 0) { for (auto &req : sendq_) { if (req->canceled) { @@ -438,32 +580,27 @@ int MemcachedConnection::send_request() { } } - std::array iov; - size_t iovlen = 0; + size_t iovcnt = 0; for (auto &buf : sendbufv_) { - if (iovlen + 2 > iov.size()) { + if (iovcnt + 2 > iovlen) { break; } auto req = buf.req; if (buf.headbuf.rleft()) { - iov[iovlen++] = {buf.headbuf.pos, buf.headbuf.rleft()}; + iov[iovcnt++] = {buf.headbuf.pos, buf.headbuf.rleft()}; } if (buf.send_value_left) { - iov[iovlen++] = {req->value.data() + req->value.size() - + iov[iovcnt++] = {req->value.data() + req->value.size() - buf.send_value_left, buf.send_value_left}; } } - nwrite = conn_.writev_clear(iov.data(), iovlen); - if (nwrite < 0) { - return -1; - } - if (nwrite == 0) { - return 1; - } + return iovcnt; +} +void MemcachedConnection::drain_send_queue(size_t nwrite) { sendsum_ -= nwrite; while (nwrite > 0) { @@ -488,8 +625,6 @@ int MemcachedConnection::send_request() { recvq_.push_back(std::move(sendq_.front())); sendq_.pop_front(); } - - return 0; } size_t MemcachedConnection::serialized_size(MemcachedRequest *req) { @@ -549,4 +684,6 @@ int MemcachedConnection::add_request(std::unique_ptr req) { // TODO should we start write timer too? void MemcachedConnection::signal_write() { conn_.wlimit.startw(); } +int MemcachedConnection::noop() { return 0; } + } // namespace shrpx diff --git a/src/shrpx_memcached_connection.h b/src/shrpx_memcached_connection.h index 43d27c46..c9552198 100644 --- a/src/shrpx_memcached_connection.h +++ b/src/shrpx_memcached_connection.h @@ -93,7 +93,9 @@ constexpr uint8_t MEMCACHED_RES_MAGIC = 0x81; // https://code.google.com/p/memcached/wiki/MemcacheBinaryProtocol class MemcachedConnection { public: - MemcachedConnection(const Address *addr, struct ev_loop *loop); + MemcachedConnection(const Address *addr, struct ev_loop *loop, + SSL_CTX *ssl_ctx, const StringRef &sni_name, + MemchunkPool *mcpool); ~MemcachedConnection(); void disconnect(); @@ -101,23 +103,38 @@ public: int add_request(std::unique_ptr req); int initiate_connection(); - int on_connect(); + int connected(); int on_write(); int on_read(); - int send_request(); + + int write_clear(); + int read_clear(); + + int tls_handshake(); + int write_tls(); + int read_tls(); + + size_t fill_request_buffer(struct iovec *iov, size_t iovlen); + void drain_send_queue(size_t nwrite); + void make_request(MemcachedSendbuf *sendbuf, MemcachedRequest *req); int parse_packet(); size_t serialized_size(MemcachedRequest *req); void signal_write(); + int noop(); + private: Connection conn_; std::deque> recvq_; std::deque> sendq_; std::deque sendbufv_; + std::function do_read_, do_write_; + std::string sni_name_; MemcachedParseState parse_state_; const Address *addr_; + SSL_CTX *ssl_ctx_; // Sum of the bytes to be transmitted in sendbufv_. size_t sendsum_; bool connected_; diff --git a/src/shrpx_memcached_dispatcher.cc b/src/shrpx_memcached_dispatcher.cc index 28c00f13..796fef8e 100644 --- a/src/shrpx_memcached_dispatcher.cc +++ b/src/shrpx_memcached_dispatcher.cc @@ -31,8 +31,12 @@ namespace shrpx { MemcachedDispatcher::MemcachedDispatcher(const Address *addr, - struct ev_loop *loop) - : loop_(loop), mconn_(make_unique(addr, loop_)) {} + struct ev_loop *loop, SSL_CTX *ssl_ctx, + const StringRef &sni_name, + MemchunkPool *mcpool) + : loop_(loop), + mconn_(make_unique(addr, loop_, ssl_ctx, sni_name, + mcpool)) {} MemcachedDispatcher::~MemcachedDispatcher() {} diff --git a/src/shrpx_memcached_dispatcher.h b/src/shrpx_memcached_dispatcher.h index b3ea4866..64021c68 100644 --- a/src/shrpx_memcached_dispatcher.h +++ b/src/shrpx_memcached_dispatcher.h @@ -31,6 +31,10 @@ #include +#include + +#include "memchunk.h" + namespace shrpx { struct MemcachedRequest; @@ -39,7 +43,9 @@ struct Address; class MemcachedDispatcher { public: - MemcachedDispatcher(const Address *addr, struct ev_loop *loop); + MemcachedDispatcher(const Address *addr, struct ev_loop *loop, + SSL_CTX *ssl_ctx, const StringRef &sni_name, + MemchunkPool *mcpool); ~MemcachedDispatcher(); int add_request(std::unique_ptr req); diff --git a/src/shrpx_ssl.cc b/src/shrpx_ssl.cc index 66c703ad..ee82022d 100644 --- a/src/shrpx_ssl.cc +++ b/src/shrpx_ssl.cc @@ -982,7 +982,7 @@ int verify_hostname(X509 *cert, const char *hostname, size_t hlen, } } // namespace -int check_cert(SSL *ssl, const DownstreamAddr *addr) { +int check_cert(SSL *ssl, const Address *addr, const StringRef &host) { auto cert = SSL_get_peer_certificate(ssl); if (!cert) { LOG(ERROR) << "No certificate found"; @@ -996,18 +996,21 @@ int check_cert(SSL *ssl, const DownstreamAddr *addr) { return -1; } - auto &backend_sni_name = get_config()->tls.backend_sni_name; - - auto hostname = !backend_sni_name.empty() ? StringRef(backend_sni_name) - : StringRef(addr->host); - if (verify_hostname(cert, hostname.c_str(), hostname.size(), &addr->addr) != - 0) { + if (verify_hostname(cert, host.c_str(), host.size(), addr) != 0) { LOG(ERROR) << "Certificate verification failed: hostname does not match"; return -1; } return 0; } +int check_cert(SSL *ssl, const DownstreamAddr *addr) { + auto &backend_sni_name = get_config()->tls.backend_sni_name; + + auto hostname = !backend_sni_name.empty() ? StringRef(backend_sni_name) + : StringRef(addr->host); + return check_cert(ssl, &addr->addr, hostname); +} + CertLookupTree::CertLookupTree() { root_.ssl_ctx = nullptr; root_.str = nullptr; diff --git a/src/shrpx_ssl.h b/src/shrpx_ssl.h index da7c1667..4152a8ed 100644 --- a/src/shrpx_ssl.h +++ b/src/shrpx_ssl.h @@ -46,6 +46,7 @@ class Worker; class DownstreamConnectionPool; struct DownstreamAddr; struct UpstreamAddr; +struct Address; namespace ssl { @@ -83,9 +84,8 @@ SSL_CTX *create_ssl_client_context( ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr, int addrlen, const UpstreamAddr *faddr); -// Check peer's certificate against first downstream address in -// Config::downstream_addrs. We only consider first downstream since -// we use this function for HTTP/2 downstream link only. +// Check peer's certificate against given |address| and |host|. +int check_cert(SSL *ssl, const Address *addr, const StringRef &host); int check_cert(SSL *ssl, const DownstreamAddr *addr); // Retrieves DNS and IP address in subjectAltNames and commonName from diff --git a/src/shrpx_worker.cc b/src/shrpx_worker.cc index 5e46e398..e663eace 100644 --- a/src/shrpx_worker.cc +++ b/src/shrpx_worker.cc @@ -68,6 +68,7 @@ std::random_device rd; } // namespace Worker::Worker(struct ev_loop *loop, SSL_CTX *sv_ssl_ctx, SSL_CTX *cl_ssl_ctx, + SSL_CTX *tls_session_cache_memcached_ssl_ctx, ssl::CertLookupTree *cert_tree, const std::shared_ptr &ticket_keys) : randgen_(rd()), @@ -92,7 +93,9 @@ Worker::Worker(struct ev_loop *loop, SSL_CTX *sv_ssl_ctx, SSL_CTX *cl_ssl_ctx, if (session_cacheconf.memcached.host) { session_cache_memcached_dispatcher_ = make_unique( - &session_cacheconf.memcached.addr, loop); + &session_cacheconf.memcached.addr, loop, + tls_session_cache_memcached_ssl_ctx, + session_cacheconf.memcached.host.get(), &mcpool_); } auto &downstreamconf = get_config()->conn.downstream; diff --git a/src/shrpx_worker.h b/src/shrpx_worker.h index c578d9fb..8d352429 100644 --- a/src/shrpx_worker.h +++ b/src/shrpx_worker.h @@ -112,6 +112,7 @@ struct SessionCacheEntry { class Worker { public: Worker(struct ev_loop *loop, SSL_CTX *sv_ssl_ctx, SSL_CTX *cl_ssl_ctx, + SSL_CTX *tls_session_cache_memcached_ssl_ctx, ssl::CertLookupTree *cert_tree, const std::shared_ptr &ticket_keys); ~Worker(); diff --git a/src/shrpx_worker_process.cc b/src/shrpx_worker_process.cc index 0e05e19f..e7a30311 100644 --- a/src/shrpx_worker_process.cc +++ b/src/shrpx_worker_process.cc @@ -426,7 +426,8 @@ int worker_process_event_loop(WorkerProcessConfig *wpconf) { if (ticketconf.memcached.host) { conn_handler.set_tls_ticket_key_memcached_dispatcher( - make_unique(&ticketconf.memcached.addr, loop)); + make_unique(&ticketconf.memcached.addr, loop, + nullptr, "", nullptr)); ev_timer_init(&renew_ticket_key_timer, memcached_get_ticket_key_cb, 0., 0.);