diff --git a/src/HttpServer.cc b/src/HttpServer.cc index 99752138..e112edd5 100644 --- a/src/HttpServer.cc +++ b/src/HttpServer.cc @@ -196,6 +196,7 @@ Http2Handler::~Http2Handler() SSL_shutdown(ssl_); } if(bev_) { + bufferevent_disable(bev_, EV_READ | EV_WRITE); bufferevent_free(bev_); } if(ssl_) { @@ -398,19 +399,30 @@ int Http2Handler::verify_npn_result() { const unsigned char *next_proto = nullptr; unsigned int next_proto_len; + // Check the negotiated protocol in NPN or ALPN SSL_get0_next_proto_negotiated(ssl_, &next_proto, &next_proto_len); - if(next_proto) { - std::string proto(next_proto, next_proto+next_proto_len); - if(sessions_->get_config()->verbose) { - std::cout << "The negotiated next protocol: " << proto << std::endl; - } - if(proto == NGHTTP2_PROTO_VERSION_ID) { - return 0; + for(int i = 0; i < 2; ++i) { + if(next_proto) { + std::string proto(next_proto, next_proto+next_proto_len); + if(sessions_->get_config()->verbose) { + std::cout << "The negotiated protocol: " << proto << std::endl; + } + if(proto == NGHTTP2_PROTO_VERSION_ID) { + return 0; + } + break; + } else { +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + SSL_get0_alpn_selected(ssl_, &next_proto, &next_proto_len); +#else // OPENSSL_VERSION_NUMBER < 0x10002000L + break; +#endif // OPENSSL_VERSION_NUMBER < 0x10002000L } } - std::cerr << "The negotiated next protocol is not supported." + std::cerr << "Client did not advertise HTTP/2.0 protocol." + << " (nghttp2 expects " << NGHTTP2_PROTO_VERSION_ID << ")" << std::endl; - return 0; + return -1; } int Http2Handler::sendcb(const uint8_t *data, size_t len) @@ -1094,6 +1106,33 @@ int start_listen(event_base *evbase, Sessions *sessions, } } // namespace +#if OPENSSL_VERSION_NUMBER >= 0x10002000L +namespace { +int alpn_select_proto_cb(SSL* ssl, + const unsigned char **out, unsigned char *outlen, + const unsigned char *in, unsigned int inlen, + void *arg) +{ + auto config = reinterpret_cast(arg)->get_config(); + if(config->verbose) { + std::cout << "[ALPN] client offers:" << std::endl; + } + if(config->verbose) { + for(unsigned int i = 0; i < inlen; i += in[i]+1) { + std::cout << " * "; + std::cout.write(reinterpret_cast(&in[i+1]), in[i]); + std::cout << std::endl; + } + } + if(nghttp2_select_next_protocol(const_cast(out), outlen, + in, inlen) <= 0) { + return SSL_TLSEXT_ERR_NOACK; + } + return SSL_TLSEXT_ERR_OK; +} +} // namespace +#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L + int HttpServer::run() { SSL_CTX *ssl_ctx = nullptr; @@ -1138,6 +1177,10 @@ int HttpServer::run() next_proto.second = proto_list[0] + 1; SSL_CTX_set_next_protos_advertised_cb(ssl_ctx, next_proto_cb, &next_proto); +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + // ALPN selection callback + SSL_CTX_set_alpn_select_cb(ssl_ctx, alpn_select_proto_cb, this); +#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L } auto evbase = event_base_new(); @@ -1152,4 +1195,9 @@ int HttpServer::run() return 0; } +const Config* HttpServer::get_config() const +{ + return config_; +} + } // namespace nghttp2 diff --git a/src/HttpServer.h b/src/HttpServer.h index 6eb26164..9b0a2198 100644 --- a/src/HttpServer.h +++ b/src/HttpServer.h @@ -137,6 +137,7 @@ public: HttpServer(const Config* config); int listen(); int run(); + const Config* get_config() const; private: const Config *config_; }; diff --git a/src/nghttp.cc b/src/nghttp.cc index 8fb6047a..7c736bd7 100644 --- a/src/nghttp.cc +++ b/src/nghttp.cc @@ -1223,6 +1223,15 @@ void print_stats(const HttpClient& client) } } // namespace +namespace { +void print_protocol_nego_error() +{ + std::cerr << "Server did not select HTTP/2.0 protocol." + << " (nghttp2 expects " << NGHTTP2_PROTO_VERSION_ID << ")" + << std::endl; +} +} // namespace + namespace { int client_select_next_proto_cb(SSL* ssl, unsigned char **out, unsigned char *outlen, @@ -1231,8 +1240,7 @@ int client_select_next_proto_cb(SSL* ssl, { if(config.verbose) { print_timer(); - std::cout << " NPN select next protocol: the remote server offers:" - << std::endl; + std::cout << "[NPN] server offers:" << std::endl; } for(unsigned int i = 0; i < inlen; i += in[i]+1) { if(config.verbose) { @@ -1242,15 +1250,8 @@ int client_select_next_proto_cb(SSL* ssl, } } if(nghttp2_select_next_protocol(out, outlen, in, inlen) <= 0) { - std::cerr << "Server did not advertise HTTP/2.0 protocol." - << " (nghttp2 expects " << NGHTTP2_PROTO_VERSION_ID << ")" - << std::endl; - } else { - if(config.verbose) { - std::cout << " NPN selected the protocol: "; - std::cout.write(reinterpret_cast(*out), (size_t)*outlen); - std::cout << std::endl; - } + print_protocol_nego_error(); + return SSL_TLSEXT_ERR_NOACK; } return SSL_TLSEXT_ERR_OK; } @@ -1312,7 +1313,37 @@ void eventcb(bufferevent *bev, short events, void *ptr) if(client->need_upgrade()) { rv = client->on_upgrade_connect(); } else { - // TODO Check NPN result and fail fast? + // Check NPN or ALPN result + const unsigned char *next_proto = nullptr; + unsigned int next_proto_len; + SSL_get0_next_proto_negotiated(client->ssl, + &next_proto, &next_proto_len); + for(int i = 0; i < 2; ++i) { + if(next_proto) { + if(config.verbose) { + std::cout << "The negotiated protocol: "; + std::cout.write(reinterpret_cast(next_proto), + next_proto_len); + std::cout << std::endl; + } + if(NGHTTP2_PROTO_VERSION_ID_LEN != next_proto_len || + memcmp(NGHTTP2_PROTO_VERSION_ID, next_proto, + NGHTTP2_PROTO_VERSION_ID_LEN) != 0) { + next_proto = nullptr; + } + break; + } +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + SSL_get0_alpn_selected(client->ssl, &next_proto, &next_proto_len); +#else // OPENSSL_VERSION_NUMBER < 0x10002000L + break; +#endif // OPENSSL_VERSION_NUMBER < 0x10002000L + } + if(!next_proto) { + print_protocol_nego_error(); + client->disconnect(); + return; + } rv = client->on_connect(); } if(rv != 0) { @@ -1404,6 +1435,14 @@ int communicate(const std::string& scheme, const std::string& host, } SSL_CTX_set_next_proto_select_cb(ssl_ctx, client_select_next_proto_cb, nullptr); + +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + unsigned char proto_list[255]; + proto_list[0] = NGHTTP2_PROTO_VERSION_ID_LEN; + memcpy(&proto_list[1], NGHTTP2_PROTO_VERSION_ID, + NGHTTP2_PROTO_VERSION_ID_LEN); + SSL_CTX_set_alpn_protos(ssl_ctx, proto_list, proto_list[0] + 1); +#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L } { HttpClient client{callbacks, evbase, ssl_ctx}; diff --git a/src/shrpx_client_handler.cc b/src/shrpx_client_handler.cc index 9e101eb1..a5667053 100644 --- a/src/shrpx_client_handler.cc +++ b/src/shrpx_client_handler.cc @@ -299,26 +299,37 @@ int ClientHandler::validate_next_proto() // First set callback for catch all cases set_bev_cb(upstream_readcb, upstream_writecb, upstream_eventcb); SSL_get0_next_proto_negotiated(ssl_, &next_proto, &next_proto_len); - if(next_proto) { - std::string proto(next_proto, next_proto+next_proto_len); - if(LOG_ENABLED(INFO)) { - CLOG(INFO, this) << "The negotiated next protocol: " << proto; - } - if(proto == NGHTTP2_PROTO_VERSION_ID) { - set_bev_cb(upstream_http2_connhd_readcb, upstream_writecb, - upstream_eventcb); - upstream_ = util::make_unique(this); - return 0; - } else { -#ifdef HAVE_SPDYLAY - uint16_t version = spdylay_npn_get_version(next_proto, next_proto_len); - if(version) { - upstream_ = util::make_unique(version, this); - return 0; + for(int i = 0; i < 2; ++i) { + if(next_proto) { + if(LOG_ENABLED(INFO)) { + std::string proto(next_proto, next_proto+next_proto_len); + CLOG(INFO, this) << "The negotiated next protocol: " << proto; } + if(next_proto_len == NGHTTP2_PROTO_VERSION_ID_LEN && + memcmp(NGHTTP2_PROTO_VERSION_ID, next_proto, + NGHTTP2_PROTO_VERSION_ID_LEN) == 0) { + set_bev_cb(upstream_http2_connhd_readcb, upstream_writecb, + upstream_eventcb); + upstream_ = util::make_unique(this); + return 0; + } else { +#ifdef HAVE_SPDYLAY + uint16_t version = spdylay_npn_get_version(next_proto, next_proto_len); + if(version) { + upstream_ = util::make_unique(version, this); + return 0; + } #endif // HAVE_SPDYLAY + } + break; } - } else { +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + SSL_get0_alpn_selected(ssl_, &next_proto, &next_proto_len); +#else // OPENSSL_VERSION_NUMBER < 0x10002000L + break; +#endif // OPENSSL_VERSION_NUMBER < 0x10002000L + } + if(!next_proto) { if(LOG_ENABLED(INFO)) { CLOG(INFO, this) << "No proto negotiated."; } diff --git a/src/shrpx_http2_session.cc b/src/shrpx_http2_session.cc index c7ea40a3..fb2fdffd 100644 --- a/src/shrpx_http2_session.cc +++ b/src/shrpx_http2_session.cc @@ -1120,15 +1120,30 @@ int on_unknown_frame_recv_callback(nghttp2_session *session, int Http2Session::on_connect() { int rv; - const unsigned char *next_proto = nullptr; - unsigned int next_proto_len; if(ssl_ctx_) { + const unsigned char *next_proto = nullptr; + unsigned int next_proto_len; SSL_get0_next_proto_negotiated(ssl_, &next_proto, &next_proto_len); - std::string proto(next_proto, next_proto+next_proto_len); - if(LOG_ENABLED(INFO)) { - SSLOG(INFO, this) << "Negotiated next protocol: " << proto; + for(int i = 0; i < 2; ++i) { + if(next_proto) { + if(LOG_ENABLED(INFO)) { + std::string proto(next_proto, next_proto+next_proto_len); + SSLOG(INFO, this) << "Negotiated next protocol: " << proto; + } + if(next_proto_len != NGHTTP2_PROTO_VERSION_ID_LEN || + memcmp(NGHTTP2_PROTO_VERSION_ID, next_proto, + NGHTTP2_PROTO_VERSION_ID_LEN) != 0) { + return -1; + } + break; + } +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + SSL_get0_alpn_selected(ssl_, &next_proto, &next_proto_len); +#else // OPENSSL_VERSION_NUMBER < 0x10002000L + break; +#endif // OPENSSL_VERSION_NUMBER < 0x10002000L } - if(proto != NGHTTP2_PROTO_VERSION_ID) { + if(!next_proto) { return -1; } } diff --git a/src/shrpx_ssl.cc b/src/shrpx_ssl.cc index 4699f564..9f513e7e 100644 --- a/src/shrpx_ssl.cc +++ b/src/shrpx_ssl.cc @@ -132,6 +132,23 @@ int servername_callback(SSL *ssl, int *al, void *arg) } } // namespace +#if OPENSSL_VERSION_NUMBER >= 0x10002000L +namespace { +int alpn_select_proto_cb(SSL* ssl, + const unsigned char **out, + unsigned char *outlen, + const unsigned char *in, unsigned int inlen, + void *arg) +{ + if(nghttp2_select_next_protocol + (const_cast(out), outlen, in, inlen) == -1) { + return SSL_TLSEXT_ERR_NOACK; + } + return SSL_TLSEXT_ERR_OK; +} +} // namespace +#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L + SSL_CTX* create_ssl_context(const char *private_key_file, const char *cert_file) { @@ -241,11 +258,16 @@ SSL_CTX* create_ssl_context(const char *private_key_file, } SSL_CTX_set_tlsext_servername_callback(ssl_ctx, servername_callback); + // NPN advertisement auto proto_list_len = set_npn_prefs(proto_list, get_config()->npn_list, get_config()->npn_list_len); next_proto.first = proto_list; next_proto.second = proto_list_len; SSL_CTX_set_next_protos_advertised_cb(ssl_ctx, next_proto_cb, &next_proto); +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + // ALPN selection callback + SSL_CTX_set_alpn_select_cb(ssl_ctx, alpn_select_proto_cb, nullptr); +#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L return ssl_ctx; } @@ -322,8 +344,18 @@ SSL_CTX* create_ssl_client_context() DIE(); } } - + // NPN selection callback SSL_CTX_set_next_proto_select_cb(ssl_ctx, select_next_proto_cb, nullptr); + +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + // ALPN advertisement + auto proto_list_len = set_npn_prefs(proto_list, get_config()->npn_list, + get_config()->npn_list_len); + next_proto.first = proto_list; + next_proto.second = proto_list_len; + SSL_CTX_set_alpn_protos(ssl_ctx, proto_list, proto_list[0] + 1); +#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L + return ssl_ctx; }