diff --git a/src/shrpx_connection_handler.cc b/src/shrpx_connection_handler.cc index 6f2b585d..c9066b69 100644 --- a/src/shrpx_connection_handler.cc +++ b/src/shrpx_connection_handler.cc @@ -183,7 +183,7 @@ int ConnectionHandler::create_single_worker() { nb_.get() #endif // HAVE_NEVERBLEED ); - auto cl_ssl_ctx = ssl::setup_client_ssl_context( + auto cl_ssl_ctx = ssl::setup_downstream_client_ssl_context( #ifdef HAVE_NEVERBLEED nb_.get() #endif // HAVE_NEVERBLEED @@ -215,7 +215,7 @@ int ConnectionHandler::create_worker_thread(size_t num) { nb_.get() #endif // HAVE_NEVERBLEED ); - auto cl_ssl_ctx = ssl::setup_client_ssl_context( + auto cl_ssl_ctx = ssl::setup_downstream_client_ssl_context( #ifdef HAVE_NEVERBLEED nb_.get() #endif // HAVE_NEVERBLEED diff --git a/src/shrpx_ssl.cc b/src/shrpx_ssl.cc index 88b3e95d..66c703ad 100644 --- a/src/shrpx_ssl.cc +++ b/src/shrpx_ssl.cc @@ -660,9 +660,13 @@ int select_h1_next_proto_cb(SSL *ssl, unsigned char **out, SSL_CTX *create_ssl_client_context( #ifdef HAVE_NEVERBLEED - neverbleed_t *nb + neverbleed_t *nb, #endif // HAVE_NEVERBLEED - ) { + const char *cacert, const char *cert_file, const char *private_key_file, + const StringRef &alpn, + int (*next_proto_select_cb)(SSL *s, unsigned char **out, + unsigned char *outlen, const unsigned char *in, + unsigned int inlen, void *arg)) { auto ssl_ctx = SSL_CTX_new(SSLv23_client_method()); if (!ssl_ctx) { LOG(FATAL) << ERR_error_string(ERR_get_error(), nullptr); @@ -698,71 +702,52 @@ SSL_CTX *create_ssl_client_context( << ERR_error_string(ERR_get_error(), nullptr); } - if (tlsconf.cacert) { - if (SSL_CTX_load_verify_locations(ssl_ctx, tlsconf.cacert.get(), nullptr) != - 1) { + if (cacert) { + if (SSL_CTX_load_verify_locations(ssl_ctx, cacert, nullptr) != 1) { - LOG(FATAL) << "Could not load trusted ca certificates from " - << tlsconf.cacert.get() << ": " - << ERR_error_string(ERR_get_error(), nullptr); + LOG(FATAL) << "Could not load trusted ca certificates from " << cacert + << ": " << ERR_error_string(ERR_get_error(), nullptr); DIE(); } } - if (tlsconf.client.private_key_file) { + if (cert_file) { + if (SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_file) != 1) { + + LOG(FATAL) << "Could not load client certificate from " << cert_file + << ": " << ERR_error_string(ERR_get_error(), nullptr); + DIE(); + } + } + + if (private_key_file) { #ifndef HAVE_NEVERBLEED - if (SSL_CTX_use_PrivateKey_file(ssl_ctx, - tlsconf.client.private_key_file.get(), + if (SSL_CTX_use_PrivateKey_file(ssl_ctx, private_key_file, SSL_FILETYPE_PEM) != 1) { LOG(FATAL) << "Could not load client private key from " - << tlsconf.client.private_key_file.get() << ": " + << private_key_file << ": " << ERR_error_string(ERR_get_error(), nullptr); DIE(); } #else // HAVE_NEVERBLEED std::array errbuf; - if (neverbleed_load_private_key_file(nb, ssl_ctx, - tlsconf.client.private_key_file.get(), + if (neverbleed_load_private_key_file(nb, ssl_ctx, private_key_file, errbuf.data()) != 1) { - LOG(FATAL) << "neverbleed_load_private_key_file failed: " + LOG(FATAL) << "neverbleed_load_private_key_file: could not load client " + "private key from " << private_key_file << ": " << errbuf.data(); DIE(); } #endif // HAVE_NEVERBLEED } - if (tlsconf.client.cert_file) { - if (SSL_CTX_use_certificate_chain_file( - ssl_ctx, tlsconf.client.cert_file.get()) != 1) { - LOG(FATAL) << "Could not load client certificate from " - << tlsconf.client.cert_file.get() << ": " - << ERR_error_string(ERR_get_error(), nullptr); - DIE(); - } - } - - auto &downstreamconf = get_config()->conn.downstream; - - if (downstreamconf.proto == PROTO_HTTP2) { - // NPN selection callback - SSL_CTX_set_next_proto_select_cb(ssl_ctx, select_h2_next_proto_cb, nullptr); + // NPN selection callback + SSL_CTX_set_next_proto_select_cb(ssl_ctx, next_proto_select_cb, nullptr); #if OPENSSL_VERSION_NUMBER >= 0x10002000L - // ALPN advertisement; We only advertise HTTP/2 - auto proto_list = util::get_default_alpn(); - - SSL_CTX_set_alpn_protos(ssl_ctx, proto_list.data(), proto_list.size()); + // ALPN advertisement + SSL_CTX_set_alpn_protos(ssl_ctx, alpn.byte(), alpn.size()); #endif // OPENSSL_VERSION_NUMBER >= 0x10002000L - } else { - // NPN selection callback - SSL_CTX_set_next_proto_select_cb(ssl_ctx, select_h1_next_proto_cb, nullptr); - -#if OPENSSL_VERSION_NUMBER >= 0x10002000L - SSL_CTX_set_alpn_protos( - ssl_ctx, reinterpret_cast(NGHTTP2_H1_1_ALPN), - str_size(NGHTTP2_H1_1_ALPN)); -#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L - } return ssl_ctx; } @@ -1304,7 +1289,7 @@ SSL_CTX *setup_server_ssl_context(std::vector &all_ssl_ctx, bool downstream_tls_enabled() { return !get_config()->conn.downstream.no_tls; } -SSL_CTX *setup_client_ssl_context( +SSL_CTX *setup_downstream_client_ssl_context( #ifdef HAVE_NEVERBLEED neverbleed_t *nb #endif // HAVE_NEVERBLEED @@ -1313,11 +1298,30 @@ SSL_CTX *setup_client_ssl_context( return nullptr; } + auto &tlsconf = get_config()->tls; + auto &downstreamconf = get_config()->conn.downstream; + + std::vector h2alpn; + StringRef alpn; + int (*next_proto_select_cb)(SSL *s, unsigned char **out, + unsigned char *outlen, const unsigned char *in, + unsigned int inlen, void *arg); + + if (downstreamconf.proto == PROTO_HTTP2) { + h2alpn = util::get_default_alpn(); + alpn = StringRef(h2alpn.data(), h2alpn.size()); + next_proto_select_cb = select_h2_next_proto_cb; + } else { + alpn = StringRef::from_lit(NGHTTP2_H1_1_ALPN); + next_proto_select_cb = select_h1_next_proto_cb; + } + return ssl::create_ssl_client_context( #ifdef HAVE_NEVERBLEED - nb + nb, #endif // HAVE_NEVERBLEED - ); + tlsconf.cacert.get(), tlsconf.client.cert_file.get(), + tlsconf.client.private_key_file.get(), alpn, next_proto_select_cb); } CertLookupTree *create_cert_lookup_tree() { diff --git a/src/shrpx_ssl.h b/src/shrpx_ssl.h index 41c41f00..da7c1667 100644 --- a/src/shrpx_ssl.h +++ b/src/shrpx_ssl.h @@ -72,9 +72,13 @@ SSL_CTX *create_ssl_context(const char *private_key_file, const char *cert_file // Create client side SSL_CTX SSL_CTX *create_ssl_client_context( #ifdef HAVE_NEVERBLEED - neverbleed_t *nb + neverbleed_t *nb, #endif // HAVE_NEVERBLEED - ); + const char *cacert, const char *cert_file, const char *private_key_file, + const StringRef &alpn, + int (*next_proto_select_cb)(SSL *s, unsigned char **out, + unsigned char *outlen, const unsigned char *in, + unsigned int inlen, void *arg)); ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr, int addrlen, const UpstreamAddr *faddr); @@ -190,7 +194,7 @@ SSL_CTX *setup_server_ssl_context(std::vector &all_ssl_ctx, // Setups client side SSL_CTX. This function inspects get_config() // and if downstream_no_tls is true, returns nullptr. Otherwise, only // construct SSL_CTX if either client_mode or http2_bridge is true. -SSL_CTX *setup_client_ssl_context( +SSL_CTX *setup_downstream_client_ssl_context( #ifdef HAVE_NEVERBLEED neverbleed_t *nb #endif // HAVE_NEVERBLEED diff --git a/src/template.h b/src/template.h index 8657b213..0d4c4ba2 100644 --- a/src/template.h +++ b/src/template.h @@ -392,11 +392,14 @@ public: explicit StringRef(const ImmutableString &s) : base(s.c_str()), len(s.size()) {} StringRef(const char *s) : base(s), len(strlen(s)) {} - StringRef(const char *s, size_t n) : base(s), len(n) {} + template + StringRef(const CharT *s, size_t n) + : base(reinterpret_cast(s)), len(n) {} template StringRef(InputIt first, InputIt last) : base(first), len(std::distance(first, last)) {} - template static StringRef from_lit(const char(&s)[N]) { + template + static StringRef from_lit(const CharT(&s)[N]) { return StringRef(s, N - 1); } @@ -412,6 +415,9 @@ public: const_reference operator[](size_type pos) const { return *(base + pos); } std::string str() const { return std::string(base, len); } + const uint8_t *byte() const { + return reinterpret_cast(base); + } private: const char *base;