diff --git a/src/shrpx_http3_upstream.cc b/src/shrpx_http3_upstream.cc index e21ec9bd..c0443700 100644 --- a/src/shrpx_http3_upstream.cc +++ b/src/shrpx_http3_upstream.cc @@ -118,6 +118,7 @@ size_t downstream_queue_size(Worker *worker) { Http3Upstream::Http3Upstream(ClientHandler *handler) : handler_{handler}, qlog_fd_{-1}, + hashed_scid_{}, conn_{nullptr}, tls_alert_{0}, httpconn_{nullptr}, @@ -636,7 +637,12 @@ int Http3Upstream::init(const UpstreamAddr *faddr, const Address &remote_addr, auto quic_connection_handler = worker->get_quic_connection_handler(); - quic_connection_handler->add_connection_id(&initial_hd.dcid, handler_); + if (generate_quic_hashed_connection_id(hashed_scid_, remote_addr, local_addr, + initial_hd.dcid) != 0) { + return -1; + } + + quic_connection_handler->add_connection_id(&hashed_scid_, handler_); quic_connection_handler->add_connection_id(&scid, handler_); return 0; @@ -1324,8 +1330,7 @@ void Http3Upstream::on_handler_delete() { auto worker = handler_->get_worker(); auto quic_conn_handler = worker->get_quic_connection_handler(); - quic_conn_handler->remove_connection_id( - ngtcp2_conn_get_client_initial_dcid(conn_)); + quic_conn_handler->remove_connection_id(&hashed_scid_); std::vector scids(ngtcp2_conn_get_num_scid(conn_)); ngtcp2_conn_get_scid(conn_, scids.data()); diff --git a/src/shrpx_http3_upstream.h b/src/shrpx_http3_upstream.h index 0a33dd9f..0778a3f5 100644 --- a/src/shrpx_http3_upstream.h +++ b/src/shrpx_http3_upstream.h @@ -159,6 +159,7 @@ private: ev_timer shutdown_timer_; ev_prepare prep_; int qlog_fd_; + ngtcp2_cid hashed_scid_; ngtcp2_conn *conn_; quic::Error last_error_; uint8_t tls_alert_; diff --git a/src/shrpx_quic.cc b/src/shrpx_quic.cc index 1c6bee52..fc710b48 100644 --- a/src/shrpx_quic.cc +++ b/src/shrpx_quic.cc @@ -43,8 +43,6 @@ #include "util.h" #include "xsi_strerror.h" -using namespace nghttp2; - bool operator==(const ngtcp2_cid &lhs, const ngtcp2_cid &rhs) { return ngtcp2_cid_eq(&lhs, &rhs); } @@ -213,6 +211,32 @@ int decrypt_quic_connection_id(uint8_t *dest, const uint8_t *src, return 0; } +int generate_quic_hashed_connection_id(ngtcp2_cid &dest, + const Address &remote_addr, + const Address &local_addr, + const ngtcp2_cid &cid) { + auto ctx = EVP_MD_CTX_new(); + auto d = defer(EVP_MD_CTX_free, ctx); + + std::array h; + unsigned int hlen = EVP_MD_size(EVP_sha256()); + + if (!EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr) || + !EVP_DigestUpdate(ctx, &remote_addr.su.sa, remote_addr.len) || + !EVP_DigestUpdate(ctx, &local_addr.su.sa, local_addr.len) || + !EVP_DigestUpdate(ctx, cid.data, cid.datalen) || + !EVP_DigestFinal_ex(ctx, h.data(), &hlen)) { + return -1; + } + + assert(hlen == h.size()); + + std::copy_n(std::begin(h), sizeof(dest.data), std::begin(dest.data)); + dest.datalen = sizeof(dest.data); + + return 0; +} + int generate_quic_stateless_reset_token(uint8_t *token, const ngtcp2_cid *cid, const uint8_t *secret, size_t secretlen) { diff --git a/src/shrpx_quic.h b/src/shrpx_quic.h index 4938a15c..8bf6e4ab 100644 --- a/src/shrpx_quic.h +++ b/src/shrpx_quic.h @@ -33,6 +33,10 @@ #include +#include "network.h" + +using namespace nghttp2; + namespace std { template <> struct hash { std::size_t operator()(const ngtcp2_cid &cid) const noexcept { @@ -86,6 +90,11 @@ int encrypt_quic_connection_id(uint8_t *dest, const uint8_t *src, int decrypt_quic_connection_id(uint8_t *dest, const uint8_t *src, const uint8_t *key); +int generate_quic_hashed_connection_id(ngtcp2_cid &dest, + const Address &remote_addr, + const Address &local_addr, + const ngtcp2_cid &cid); + int generate_quic_stateless_reset_token(uint8_t *token, const ngtcp2_cid *cid, const uint8_t *secret, size_t secretlen); diff --git a/src/shrpx_quic_connection_handler.cc b/src/shrpx_quic_connection_handler.cc index 49991f03..38acdf31 100644 --- a/src/shrpx_quic_connection_handler.cc +++ b/src/shrpx_quic_connection_handler.cc @@ -93,6 +93,15 @@ int QUICConnectionHandler::handle_packet(const UpstreamAddr *faddr, auto &quicconf = config->quic; auto it = connections_.find(dcid_key); + if ((data[0] & 0x80) && it == std::end(connections_)) { + if (generate_quic_hashed_connection_id(dcid_key, remote_addr, local_addr, + dcid_key) != 0) { + return 0; + } + + it = connections_.find(dcid_key); + } + if (it == std::end(connections_)) { std::array decrypted_dcid;