diff --git a/src/shrpx_client_handler.cc b/src/shrpx_client_handler.cc index f7185f84..341566b6 100644 --- a/src/shrpx_client_handler.cc +++ b/src/shrpx_client_handler.cc @@ -296,6 +296,8 @@ int ClientHandler::read_quic(const UpstreamAddr *faddr, return upstream->on_read(faddr, remote_addr, local_addr, data, datalen); } +int ClientHandler::write_quic() { return upstream_->on_write(); } + int ClientHandler::upstream_noop() { return 0; } int ClientHandler::upstream_read() { @@ -429,7 +431,9 @@ ClientHandler::ClientHandler(Worker *worker, int fd, SSL *ssl, reneg_shutdown_timer_.data = this; - conn_.rlimit.startw(); + if (!faddr->quic) { + conn_.rlimit.startw(); + } ev_timer_again(conn_.loop, &conn_.rt); auto config = get_config(); @@ -509,6 +513,7 @@ void ClientHandler::setup_http3_upstream( std::unique_ptr &&upstream) { upstream_ = std::move(upstream); alpn_ = StringRef::from_lit("h3"); + write_ = &ClientHandler::write_quic; } ClientHandler::~ClientHandler() { diff --git a/src/shrpx_client_handler.h b/src/shrpx_client_handler.h index dbdff961..ef467810 100644 --- a/src/shrpx_client_handler.h +++ b/src/shrpx_client_handler.h @@ -148,6 +148,7 @@ public: void setup_upstream_io_callback(); void setup_http3_upstream(std::unique_ptr &&upstream); + int write_quic(); // Returns string suitable for use in "by" parameter of Forwarded // header field. diff --git a/src/shrpx_connection.cc b/src/shrpx_connection.cc index daec6296..d07314b6 100644 --- a/src/shrpx_connection.cc +++ b/src/shrpx_connection.cc @@ -74,7 +74,7 @@ Connection::Connection(struct ev_loop *loop, int fd, SSL *ssl, read_timeout(read_timeout) { ev_io_init(&wev, writecb, fd, EV_WRITE); - ev_io_init(&rev, readcb, fd, EV_READ); + ev_io_init(&rev, readcb, proto == Proto::HTTP3 ? 0 : fd, EV_READ); wev.data = this; rev.data = this; @@ -128,7 +128,7 @@ void Connection::disconnect() { tls.early_data_finish = false; } - if (fd != -1) { + if (proto != Proto::HTTP3 && fd != -1) { shutdown(fd, SHUT_WR); close(fd); fd = -1; diff --git a/src/shrpx_http3_upstream.cc b/src/shrpx_http3_upstream.cc index 83b957bd..040f3192 100644 --- a/src/shrpx_http3_upstream.cc +++ b/src/shrpx_http3_upstream.cc @@ -185,6 +185,7 @@ int Http3Upstream::init(const UpstreamAddr *faddr, const Address &remote_addr, ngtcp2_transport_params params; ngtcp2_transport_params_default(¶ms); + params.initial_max_streams_uni = 3; params.initial_max_data = 1_m; params.initial_max_stream_data_bidi_remote = 256_k; params.initial_max_stream_data_uni = 256_k; @@ -217,7 +218,119 @@ int Http3Upstream::init(const UpstreamAddr *faddr, const Address &remote_addr, int Http3Upstream::on_read() { return 0; } -int Http3Upstream::on_write() { return 0; } +int Http3Upstream::on_write() { + std::array buf; + size_t max_pktcnt = + std::min(static_cast(64_k), ngtcp2_conn_get_send_quantum(conn_)) / + SHRPX_MAX_UDP_PAYLOAD_SIZE; + ngtcp2_pkt_info pi; + uint8_t *bufpos = buf.data(); + ngtcp2_path_storage ps, prev_ps; + size_t pktcnt = 0; + auto ts = quic_timestamp(); + + ngtcp2_path_storage_zero(&ps); + ngtcp2_path_storage_zero(&prev_ps); + + for (;;) { + int64_t stream_id = -1; + int fin = 0; + + ngtcp2_ssize ndatalen; + + uint32_t flags = NGTCP2_WRITE_STREAM_FLAG_MORE; + if (fin) { + flags |= NGTCP2_WRITE_STREAM_FLAG_FIN; + } + + auto nwrite = ngtcp2_conn_writev_stream( + conn_, &ps.path, &pi, bufpos, SHRPX_MAX_UDP_PAYLOAD_SIZE, &ndatalen, + flags, stream_id, nullptr, 0, ts); + if (nwrite < 0) { + switch (nwrite) { + case NGTCP2_ERR_STREAM_DATA_BLOCKED: + assert(ndatalen == -1); + continue; + case NGTCP2_ERR_STREAM_SHUT_WR: + assert(ndatalen == -1); + continue; + case NGTCP2_ERR_WRITE_MORE: + assert(ndatalen >= 0); + continue; + } + + assert(ndatalen == -1); + + LOG(ERROR) << "ngtcp2_conn_writev_stream: " << ngtcp2_strerror(nwrite); + + last_error_ = quic::err_transport(nwrite); + + handler_->get_connection()->wlimit.stopw(); + + return handle_error(); + } else if (ndatalen >= 0) { + // TODO do something + } + + if (nwrite == 0) { + if (bufpos - buf.data()) { + quic_send_packet(static_cast(prev_ps.path.user_data), + prev_ps.path.remote.addr, prev_ps.path.remote.addrlen, + prev_ps.path.local.addr, prev_ps.path.local.addrlen, + buf.data(), bufpos - buf.data(), + SHRPX_MAX_UDP_PAYLOAD_SIZE); + + ngtcp2_conn_update_pkt_tx_time(conn_, ts); + // reset_idle_timer here + } + + handler_->get_connection()->wlimit.stopw(); + + return 0; + } + + bufpos += nwrite; + + if (pktcnt == 0) { + ngtcp2_path_copy(&prev_ps.path, &ps.path); + } else if (!ngtcp2_path_eq(&prev_ps.path, &ps.path)) { + quic_send_packet(static_cast(prev_ps.path.user_data), + prev_ps.path.remote.addr, prev_ps.path.remote.addrlen, + prev_ps.path.local.addr, prev_ps.path.local.addrlen, + buf.data(), bufpos - buf.data() - nwrite, + SHRPX_MAX_UDP_PAYLOAD_SIZE); + + quic_send_packet(static_cast(ps.path.user_data), + ps.path.remote.addr, ps.path.remote.addrlen, + ps.path.local.addr, ps.path.local.addrlen, + bufpos - nwrite, nwrite, SHRPX_MAX_UDP_PAYLOAD_SIZE); + + ngtcp2_conn_update_pkt_tx_time(conn_, ts); + // reset_idle_timer here + + handler_->signal_write(); + + return 0; + } + + if (++pktcnt == max_pktcnt || + static_cast(nwrite) < SHRPX_MAX_UDP_PAYLOAD_SIZE) { + quic_send_packet(static_cast(ps.path.user_data), + ps.path.remote.addr, ps.path.remote.addrlen, + ps.path.local.addr, ps.path.local.addrlen, buf.data(), + bufpos - buf.data(), SHRPX_MAX_UDP_PAYLOAD_SIZE); + + ngtcp2_conn_update_pkt_tx_time(conn_, ts); + // reset_idle_timer here + + handler_->signal_write(); + + return 0; + } + } + + return 0; +} int Http3Upstream::on_timeout(Downstream *downstream) { return 0; } @@ -382,6 +495,13 @@ int Http3Upstream::on_tx_secret(ngtcp2_crypto_level level, int Http3Upstream::add_crypto_data(ngtcp2_crypto_level level, const uint8_t *data, size_t datalen) { + int rv = ngtcp2_conn_submit_crypto_data(conn_, level, data, datalen); + + if (rv != 0) { + LOG(ERROR) << "ngtcp2_conn_submit_crypto_data: " << ngtcp2_strerror(rv); + return -1; + } + return 0; } diff --git a/src/shrpx_quic.cc b/src/shrpx_quic.cc index cd989f9f..09f0683f 100644 --- a/src/shrpx_quic.cc +++ b/src/shrpx_quic.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -198,10 +199,86 @@ int create_quic_server_socket(UpstreamAddr &faddr) { return 0; } -int quic_send_packet(const UpstreamAddr *addr, const sockaddr *remote_sa, +int quic_send_packet(const UpstreamAddr *faddr, const sockaddr *remote_sa, size_t remote_salen, const sockaddr *local_sa, size_t local_salen, const uint8_t *data, size_t datalen, size_t gso_size) { + iovec msg_iov = {const_cast(data), datalen}; + msghdr msg{}; + msg.msg_name = const_cast(remote_sa); + msg.msg_namelen = remote_salen; + msg.msg_iov = &msg_iov; + msg.msg_iovlen = 1; + + uint8_t + msg_ctrl[CMSG_SPACE(sizeof(uint16_t)) + CMSG_SPACE(sizeof(in6_pktinfo))]; + + memset(msg_ctrl, 0, sizeof(msg_ctrl)); + + msg.msg_control = msg_ctrl; + msg.msg_controllen = sizeof(msg_ctrl); + + size_t controllen = 0; + + auto cm = CMSG_FIRSTHDR(&msg); + + switch (local_sa->sa_family) { + case AF_INET: { + controllen += CMSG_SPACE(sizeof(in_pktinfo)); + cm->cmsg_level = IPPROTO_IP; + cm->cmsg_type = IP_PKTINFO; + cm->cmsg_len = CMSG_LEN(sizeof(in_pktinfo)); + auto pktinfo = reinterpret_cast(CMSG_DATA(cm)); + memset(pktinfo, 0, sizeof(in_pktinfo)); + auto addrin = + reinterpret_cast(const_cast(local_sa)); + pktinfo->ipi_spec_dst = addrin->sin_addr; + break; + } + case AF_INET6: { + controllen += CMSG_SPACE(sizeof(in6_pktinfo)); + cm->cmsg_level = IPPROTO_IPV6; + cm->cmsg_type = IPV6_PKTINFO; + cm->cmsg_len = CMSG_LEN(sizeof(in6_pktinfo)); + auto pktinfo = reinterpret_cast(CMSG_DATA(cm)); + memset(pktinfo, 0, sizeof(in6_pktinfo)); + auto addrin = + reinterpret_cast(const_cast(local_sa)); + pktinfo->ipi6_addr = addrin->sin6_addr; + break; + } + default: + assert(0); + } + + if (gso_size && datalen > gso_size) { + controllen += CMSG_SPACE(sizeof(uint16_t)); + cm = CMSG_NXTHDR(&msg, cm); + cm->cmsg_level = SOL_UDP; + cm->cmsg_type = UDP_SEGMENT; + cm->cmsg_len = CMSG_LEN(sizeof(uint16_t)); + *(reinterpret_cast(CMSG_DATA(cm))) = gso_size; + } + + msg.msg_controllen = controllen; + + ssize_t nwrite; + + do { + nwrite = sendmsg(faddr->fd, &msg, 0); + } while (nwrite == -1 && errno == EINTR); + + if (nwrite == -1) { + return -1; + } + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "QUIC sent packet: local=" + << util::to_numeric_addr(local_sa, local_salen) + << " remote=" << util::to_numeric_addr(remote_sa, remote_salen) + << " " << nwrite << " bytes"; + } + return 0; } diff --git a/src/shrpx_quic.h b/src/shrpx_quic.h index cf405e61..4fa329e8 100644 --- a/src/shrpx_quic.h +++ b/src/shrpx_quic.h @@ -42,7 +42,7 @@ ngtcp2_tstamp quic_timestamp(); int create_quic_server_socket(UpstreamAddr &addr); -int quic_send_packet(const UpstreamAddr *addr, const sockaddr *remote_sa, +int quic_send_packet(const UpstreamAddr *faddr, const sockaddr *remote_sa, size_t remote_salen, const sockaddr *local_sa, size_t local_salen, const uint8_t *data, size_t datalen, size_t gso_size); diff --git a/src/shrpx_quic_connection_handler.cc b/src/shrpx_quic_connection_handler.cc index 7ff08767..3d6d97ed 100644 --- a/src/shrpx_quic_connection_handler.cc +++ b/src/shrpx_quic_connection_handler.cc @@ -105,8 +105,11 @@ int QUICConnectionHandler::handle_packet(const UpstreamAddr *faddr, if (handler->read_quic(faddr, remote_addr, local_addr, data, datalen) != 0) { delete handler; + return 0; } + handler->signal_write(); + return 0; } @@ -135,6 +138,10 @@ ClientHandler *QUICConnectionHandler::handle_new_connection( return nullptr; } + assert(SSL_is_quic(ssl)); + + SSL_set_accept_state(ssl); + // Disable TLS session ticket if we don't have working ticket // keys. if (!worker_->get_ticket_keys()) { diff --git a/src/shrpx_tls.cc b/src/shrpx_tls.cc index 9ca883e8..e6d7fabf 100644 --- a/src/shrpx_tls.cc +++ b/src/shrpx_tls.cc @@ -1118,6 +1118,10 @@ int quic_send_alert(SSL *ssl, OSSL_ENCRYPTION_LEVEL ossl_level, uint8_t alert) { auto handler = static_cast(conn->data); auto upstream = static_cast(handler->get_upstream()); + if (!upstream) { + return 1; + } + upstream->set_tls_alert(alert); return 1;