diff --git a/src/shrpx_downstream.cc b/src/shrpx_downstream.cc index 7d4e69ee..8480d856 100644 --- a/src/shrpx_downstream.cc +++ b/src/shrpx_downstream.cc @@ -663,6 +663,18 @@ size_t Downstream::get_response_headers_sum() const { return response_headers_sum_; } +const Headers &Downstream::get_response_trailers() const { + return response_trailers_; +} + +void Downstream::add_response_trailer(const uint8_t *name, size_t namelen, + const uint8_t *value, size_t valuelen, + bool no_index, int16_t token) { + response_headers_sum_ += namelen + valuelen; + http2::add_header(response_trailers_, name, namelen, value, valuelen, + no_index, -1); +} + unsigned int Downstream::get_response_http_status() const { return response_http_status_; } diff --git a/src/shrpx_downstream.h b/src/shrpx_downstream.h index 2343dc44..f138811a 100644 --- a/src/shrpx_downstream.h +++ b/src/shrpx_downstream.h @@ -221,6 +221,11 @@ public: size_t get_response_headers_sum() const; + const Headers &get_response_trailers() const; + void add_response_trailer(const uint8_t *name, size_t namelen, + const uint8_t *value, size_t valuelen, + bool no_index, int16_t token); + unsigned int get_response_http_status() const; void set_response_http_status(unsigned int status); void set_response_major(int major); @@ -316,6 +321,7 @@ private: // trailer part. For HTTP/1.1, trailer part is only included with // chunked encoding. For HTTP/2, there is no such limit. Headers request_trailers_; + Headers response_trailers_; std::chrono::high_resolution_clock::time_point request_start_time_; diff --git a/src/shrpx_http2_downstream_connection.cc b/src/shrpx_http2_downstream_connection.cc index 3536a509..a359c3db 100644 --- a/src/shrpx_http2_downstream_connection.cc +++ b/src/shrpx_http2_downstream_connection.cc @@ -203,10 +203,11 @@ ssize_t http2_data_read_callback(nghttp2_session *session, int32_t stream_id, *data_flags |= NGHTTP2_DATA_FLAG_EOF; - if (!downstream->get_request_trailers().empty()) { + auto &trailers = downstream->get_request_trailers(); + if (!trailers.empty()) { std::vector nva; - nva.reserve(downstream->get_request_trailers().size()); - for (auto &kv : downstream->get_request_trailers()) { + nva.reserve(trailers.size()); + for (auto &kv : trailers) { nva.push_back(http2::make_nv(kv.name, kv.value, kv.no_index)); } rv = nghttp2_submit_trailer(session, stream_id, nva.data(), nva.size()); diff --git a/src/shrpx_http2_session.cc b/src/shrpx_http2_session.cc index 41301737..92e668dd 100644 --- a/src/shrpx_http2_session.cc +++ b/src/shrpx_http2_session.cc @@ -668,20 +668,35 @@ int on_header_callback(nghttp2_session *session, const nghttp2_frame *frame, return 0; } - if (frame->hd.type != NGHTTP2_HEADERS || - (frame->headers.cat != NGHTTP2_HCAT_RESPONSE && - !downstream->get_expect_final_response())) { + if (frame->hd.type != NGHTTP2_HEADERS) { return 0; } + auto trailer = frame->headers.cat != NGHTTP2_HCAT_RESPONSE && + !downstream->get_expect_final_response(); + if (downstream->get_response_headers_sum() > Downstream::MAX_HEADERS_SUM) { if (LOG_ENABLED(INFO)) { DLOG(INFO, downstream) << "Too large header block size=" << downstream->get_response_headers_sum(); } + + if (trailer) { + // we don't care trailer part exceeds header size limit; just + // discard it. + return 0; + } + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; } + if (trailer) { + // just store header fields for trailer part + downstream->add_response_trailer(name, namelen, value, valuelen, + flags & NGHTTP2_NV_FLAG_NO_INDEX, -1); + return 0; + } + auto token = http2::lookup_token(name, namelen); if (token == http2::HD_CONTENT_LENGTH) { @@ -891,10 +906,6 @@ int on_frame_recv_callback(nghttp2_session *session, const nghttp2_frame *frame, if (rv != 0) { return 0; } - } else if ((frame->hd.flags & NGHTTP2_FLAG_END_STREAM) == 0) { - http2session->submit_rst_stream(frame->hd.stream_id, - NGHTTP2_PROTOCOL_ERROR); - return 0; } } diff --git a/src/shrpx_http2_upstream.cc b/src/shrpx_http2_upstream.cc index bcaa8236..abe28be7 100644 --- a/src/shrpx_http2_upstream.cc +++ b/src/shrpx_http2_upstream.cc @@ -1068,6 +1068,7 @@ ssize_t downstream_data_read_callback(nghttp2_session *session, size_t length, uint32_t *data_flags, nghttp2_data_source *source, void *user_data) { + int rv; auto downstream = static_cast(source->ptr); auto upstream = static_cast(downstream->get_upstream()); auto body = downstream->get_response_buf(); @@ -1093,7 +1094,22 @@ ssize_t downstream_data_read_callback(nghttp2_session *session, *data_flags |= NGHTTP2_DATA_FLAG_EOF; if (!downstream->get_upgraded()) { - + auto &trailers = downstream->get_response_trailers(); + if (!trailers.empty()) { + std::vector nva; + nva.reserve(trailers.size()); + for (auto &kv : trailers) { + nva.push_back(http2::make_nv(kv.name, kv.value, kv.no_index)); + } + rv = nghttp2_submit_trailer(session, stream_id, nva.data(), nva.size()); + if (rv != 0) { + if (nghttp2_is_fatal(rv)) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + } else { + *data_flags |= NGHTTP2_DATA_FLAG_NO_END_STREAM; + } + } if (nghttp2_session_get_stream_remote_close(session, stream_id) == 0) { upstream->rst_stream(downstream, NGHTTP2_NO_ERROR); }