diff --git a/src/http2.cc b/src/http2.cc index 6bb0ba4d..d5804816 100644 --- a/src/http2.cc +++ b/src/http2.cc @@ -327,39 +327,24 @@ void dump_nv(FILE *out, const Headers &nva) { std::string rewrite_location_uri(const std::string &uri, const http_parser_url &u, - const std::string &request_host, - const std::string &upstream_scheme, - uint16_t upstream_port) { - // We just rewrite host and optionally port. We don't rewrite https - // link. Not sure it happens in practice. - if (u.field_set & (1 << UF_SCHEMA)) { - auto field = &u.field_data[UF_SCHEMA]; - if (!util::streq("http", &uri[field->off], field->len)) { - return ""; - } - } + const std::string &match_host, + const std::string &request_authority, + const std::string &upstream_scheme) { + // We just rewrite scheme and authority. if ((u.field_set & (1 << UF_HOST)) == 0) { return ""; } auto field = &u.field_data[UF_HOST]; - if (!util::startsWith(std::begin(request_host), std::end(request_host), + if (!util::startsWith(std::begin(match_host), std::end(match_host), &uri[field->off], &uri[field->off] + field->len) || - (request_host.size() != field->len && request_host[field->len] != ':')) { + (match_host.size() != field->len && match_host[field->len] != ':')) { return ""; } - std::string res = upstream_scheme; - res += "://"; - res.append(&uri[field->off], field->len); - if (upstream_scheme == "http") { - if (upstream_port != 80) { - res += ":"; - res += util::utos(upstream_port); - } - } else if (upstream_scheme == "https") { - if (upstream_port != 443) { - res += ":"; - res += util::utos(upstream_port); - } + std::string res; + if (!request_authority.empty()) { + res += upstream_scheme; + res += "://"; + res += request_authority; } if (u.field_set & (1 << UF_PATH)) { field = &u.field_data[UF_PATH]; diff --git a/src/http2.h b/src/http2.h index fcd2c420..5e9b3565 100644 --- a/src/http2.h +++ b/src/http2.h @@ -163,19 +163,22 @@ void dump_nv(FILE *out, const Headers &nva); // Rewrites redirection URI which usually appears in location header // field. The |uri| is the URI in the location header field. The |u| -// stores the result of parsed |uri|. The |request_host| is the host -// or :authority header field value in the request. The +// stores the result of parsed |uri|. The |request_authority| is the +// host or :authority header field value in the request. The // |upstream_scheme| is either "https" or "http" in the upstream -// interface. +// interface. Rewrite is done only if location header field value +// contains |match_host| as host excluding port. The |match_host| and +// |request_authority| could be different. If |request_authority| is +// empty, strip authority. // // This function returns the new rewritten URI on success. If the // location URI is not subject to the rewrite, this function returns // emtpy string. std::string rewrite_location_uri(const std::string &uri, const http_parser_url &u, - const std::string &request_host, - const std::string &upstream_scheme, - uint16_t upstream_port); + const std::string &match_host, + const std::string &request_authority, + const std::string &upstream_scheme); // Checks the header name/value pair using nghttp2_check_header_name() // and nghttp2_check_header_value(). If both function returns nonzero, diff --git a/src/http2_test.cc b/src/http2_test.cc index 46bccdf5..d7df2383 100644 --- a/src/http2_test.cc +++ b/src/http2_test.cc @@ -187,40 +187,44 @@ void test_http2_lws(void) { } namespace { -void check_rewrite_location_uri(const std::string &new_uri, - const std::string &uri, - const std::string &req_host, - const std::string &upstream_scheme, - uint16_t upstream_port) { +void check_rewrite_location_uri(const std::string &want, const std::string &uri, + const std::string &match_host, + const std::string &req_authority, + const std::string &upstream_scheme) { http_parser_url u; memset(&u, 0, sizeof(u)); CU_ASSERT(0 == http_parser_parse_url(uri.c_str(), uri.size(), 0, &u)); - CU_ASSERT(new_uri == http2::rewrite_location_uri( - uri, u, req_host, upstream_scheme, upstream_port)); + auto got = http2::rewrite_location_uri(uri, u, match_host, req_authority, + upstream_scheme); + CU_ASSERT(want == got); } } // namespace void test_http2_rewrite_location_uri(void) { check_rewrite_location_uri("https://localhost:3000/alpha?bravo#charlie", "http://localhost:3001/alpha?bravo#charlie", - "localhost:3001", "https", 3000); + "localhost:3001", "localhost:3000", "https"); check_rewrite_location_uri("https://localhost/", "http://localhost:3001/", - "localhost:3001", "https", 443); + "localhost", "localhost", "https"); check_rewrite_location_uri("http://localhost/", "http://localhost:3001/", - "localhost:3001", "http", 80); + "localhost", "localhost", "http"); check_rewrite_location_uri("http://localhost:443/", "http://localhost:3001/", - "localhost:3001", "http", 443); + "localhost", "localhost:443", "http"); check_rewrite_location_uri("https://localhost:80/", "http://localhost:3001/", - "localhost:3001", "https", 80); - check_rewrite_location_uri("", "http://localhost:3001/", "127.0.0.1", "https", - 3000); + "localhost", "localhost:80", "https"); + check_rewrite_location_uri("", "http://localhost:3001/", "127.0.0.1", + "127.0.0.1", "https"); check_rewrite_location_uri("https://localhost:3000/", - "http://localhost:3001/", "localhost", "https", - 3000); - check_rewrite_location_uri("", "https://localhost:3001/", "localhost", - "https", 3000); + "http://localhost:3001/", "localhost", + "localhost:3000", "https"); check_rewrite_location_uri("https://localhost:3000/", "http://localhost/", - "localhost", "https", 3000); + "localhost", "localhost:3000", "https"); + + // match_host != req_authority + check_rewrite_location_uri("https://example.org", "http://127.0.0.1:8080", + "127.0.0.1", "example.org", "https"); + check_rewrite_location_uri("", "http://example.org", "127.0.0.1", + "example.org", "https"); } void test_http2_parse_http_status_code(void) { diff --git a/src/shrpx-unittest.cc b/src/shrpx-unittest.cc index 4a21892e..7e3e901a 100644 --- a/src/shrpx-unittest.cc +++ b/src/shrpx-unittest.cc @@ -55,7 +55,6 @@ int main(int argc, char *argv[]) { SSL_library_init(); shrpx::create_config(); - shrpx::mod_config()->no_host_rewrite = true; // initialize the CUnit test registry if (CUE_SUCCESS != CU_initialize_registry()) diff --git a/src/shrpx_downstream.cc b/src/shrpx_downstream.cc index 038aa896..7622e1b2 100644 --- a/src/shrpx_downstream.cc +++ b/src/shrpx_downstream.cc @@ -532,9 +532,8 @@ Downstream::get_response_header(int16_t token) const { return http2::get_header(response_hdidx_, token, response_headers_); } -void -Downstream::rewrite_location_response_header(const std::string &upstream_scheme, - uint16_t upstream_port) { +void Downstream::rewrite_location_response_header( + const std::string &upstream_scheme) { auto hd = http2::get_header(response_hdidx_, http2::HD_LOCATION, response_headers_); if (!hd) { @@ -550,24 +549,41 @@ Downstream::rewrite_location_response_header(const std::string &upstream_scheme, std::string new_uri; if (get_config()->no_host_rewrite) { if (!request_http2_authority_.empty()) { - new_uri = - http2::rewrite_location_uri((*hd).value, u, request_http2_authority_, - upstream_scheme, upstream_port); + new_uri = http2::rewrite_location_uri( + (*hd).value, u, request_http2_authority_, request_http2_authority_, + upstream_scheme); } if (new_uri.empty()) { auto host = get_request_header(http2::HD_HOST); - if (!host) { + if (host) { + new_uri = http2::rewrite_location_uri((*hd).value, u, (*host).value, + (*host).value, upstream_scheme); + } else if (!request_downstream_host_.empty()) { + new_uri = http2::rewrite_location_uri( + (*hd).value, u, request_downstream_host_, "", upstream_scheme); + } else { return; } - new_uri = http2::rewrite_location_uri((*hd).value, u, (*host).value, - upstream_scheme, upstream_port); } } else { - assert(dconn_); - auto request_host = - get_config()->downstream_addrs[dconn_->get_addr_idx()].host.get(); - new_uri = http2::rewrite_location_uri((*hd).value, u, request_host, - upstream_scheme, upstream_port); + if (request_downstream_host_.empty()) { + return; + } + if (!request_http2_authority_.empty()) { + new_uri = http2::rewrite_location_uri( + (*hd).value, u, request_downstream_host_, request_http2_authority_, + upstream_scheme); + } else { + auto host = get_request_header(http2::HD_HOST); + if (host) { + new_uri = http2::rewrite_location_uri((*hd).value, u, + request_downstream_host_, + (*host).value, upstream_scheme); + } else { + new_uri = http2::rewrite_location_uri( + (*hd).value, u, request_downstream_host_, "", upstream_scheme); + } + } } if (!new_uri.empty()) { auto idx = response_hdidx_[http2::HD_LOCATION]; @@ -1044,4 +1060,8 @@ void Downstream::add_retry() { ++num_retry_; } bool Downstream::no_more_retry() const { return num_retry_ > 5; } +void Downstream::set_request_downstream_host(std::string host) { + request_downstream_host_ = std::move(host); +} + } // namespace shrpx diff --git a/src/shrpx_downstream.h b/src/shrpx_downstream.h index f759041a..810155ec 100644 --- a/src/shrpx_downstream.h +++ b/src/shrpx_downstream.h @@ -166,6 +166,7 @@ public: int64_t get_request_content_length() const; void set_request_content_length(int64_t len); bool request_pseudo_header_allowed(int16_t token) const; + void set_request_downstream_host(std::string host); bool expect_response_body() const; enum { INITIAL, @@ -194,8 +195,7 @@ public: // This function must be called after response headers are indexed. const Headers::value_type *get_response_header(int16_t token) const; // Rewrites the location response header field. - void rewrite_location_response_header(const std::string &upstream_scheme, - uint16_t upstream_port); + void rewrite_location_response_header(const std::string &upstream_scheme); void add_response_header(std::string name, std::string value); void set_last_response_header_value(std::string value); @@ -310,6 +310,10 @@ private: std::string request_path_; std::string request_http2_scheme_; std::string request_http2_authority_; + // host we requested to downstream. This is used to rewrite + // location header field to decide the location should be rewritten + // or not. + std::string request_downstream_host_; std::string assembled_request_cookie_; DefaultMemchunks request_buf_; diff --git a/src/shrpx_downstream_connection.h b/src/shrpx_downstream_connection.h index 7851aff9..5594ccf7 100644 --- a/src/shrpx_downstream_connection.h +++ b/src/shrpx_downstream_connection.h @@ -58,8 +58,6 @@ public: virtual void on_upstream_change(Upstream *uptream) = 0; virtual int on_priority_change(int32_t pri) = 0; - virtual size_t get_addr_idx() const = 0; - void set_client_handler(ClientHandler *client_handler); ClientHandler *get_client_handler(); Downstream *get_downstream(); diff --git a/src/shrpx_downstream_test.cc b/src/shrpx_downstream_test.cc index aab7ba85..ad0abaa4 100644 --- a/src/shrpx_downstream_test.cc +++ b/src/shrpx_downstream_test.cc @@ -136,20 +136,22 @@ void test_downstream_assemble_request_cookie(void) { void test_downstream_rewrite_location_response_header(void) { { Downstream d(nullptr, 0, 0); - d.add_request_header("host", "localhost:3000"); + d.set_request_downstream_host("localhost:3000"); + d.add_request_header("host", "localhost"); d.add_response_header("location", "http://localhost:3000/"); d.index_request_headers(); d.index_response_headers(); - d.rewrite_location_response_header("https", 443); + d.rewrite_location_response_header("https"); auto location = d.get_response_header(http2::HD_LOCATION); CU_ASSERT("https://localhost/" == (*location).value); } { Downstream d(nullptr, 0, 0); + d.set_request_downstream_host("localhost"); d.set_request_http2_authority("localhost"); - d.add_response_header("location", "http://localhost/"); + d.add_response_header("location", "http://localhost:3000/"); d.index_response_headers(); - d.rewrite_location_response_header("https", 443); + d.rewrite_location_response_header("https"); auto location = d.get_response_header(http2::HD_LOCATION); CU_ASSERT("https://localhost/" == (*location).value); } diff --git a/src/shrpx_http2_downstream_connection.cc b/src/shrpx_http2_downstream_connection.cc index 2152ea3b..74bdcc8e 100644 --- a/src/shrpx_http2_downstream_connection.cc +++ b/src/shrpx_http2_downstream_connection.cc @@ -260,6 +260,12 @@ int Http2DownstreamConnection::push_request_headers() { host = get_config()->downstream_addrs[0].hostport.get(); } + if (authority) { + downstream_->set_request_downstream_host(authority); + } else { + downstream_->set_request_downstream_host(host); + } + size_t nheader = downstream_->get_request_headers().size(); Headers cookies; @@ -578,6 +584,4 @@ int Http2DownstreamConnection::on_timeout() { return submit_rst_stream(downstream_, NGHTTP2_NO_ERROR); } -size_t Http2DownstreamConnection::get_addr_idx() const { return 0; } - } // namespace shrpx diff --git a/src/shrpx_http2_downstream_connection.h b/src/shrpx_http2_downstream_connection.h index 55807f36..cee6265a 100644 --- a/src/shrpx_http2_downstream_connection.h +++ b/src/shrpx_http2_downstream_connection.h @@ -62,8 +62,6 @@ public: virtual void on_upstream_change(Upstream *upstream) {} virtual int on_priority_change(int32_t pri); - virtual size_t get_addr_idx() const; - int send(); void attach_stream_data(StreamData *sd); diff --git a/src/shrpx_http2_upstream.cc b/src/shrpx_http2_upstream.cc index 64d10697..2874a45a 100644 --- a/src/shrpx_http2_upstream.cc +++ b/src/shrpx_http2_upstream.cc @@ -1232,7 +1232,7 @@ int Http2Upstream::on_downstream_header_complete(Downstream *downstream) { if (!get_config()->http2_proxy && !get_config()->client_proxy && !get_config()->no_location_rewrite) { downstream->rewrite_location_response_header( - get_client_handler()->get_upstream_scheme(), get_config()->port); + downstream->get_request_http2_scheme()); } size_t nheader = downstream->get_response_headers().size(); diff --git a/src/shrpx_http_downstream_connection.cc b/src/shrpx_http_downstream_connection.cc index 74fb098c..5dd90151 100644 --- a/src/shrpx_http_downstream_connection.cc +++ b/src/shrpx_http_downstream_connection.cc @@ -235,6 +235,12 @@ int HttpDownstreamConnection::push_request_headers() { host = get_config()->downstream_addrs[addr_idx_].hostport.get(); } + if (authority) { + downstream_->set_request_downstream_host(authority); + } else { + downstream_->set_request_downstream_host(host); + } + downstream_->assemble_request_cookie(); // Assume that method and request path do not contain \r\n. @@ -767,6 +773,4 @@ void HttpDownstreamConnection::on_upstream_change(Upstream *upstream) {} void HttpDownstreamConnection::signal_write() { conn_.wlimit.startw(); } -size_t HttpDownstreamConnection::get_addr_idx() const { return addr_idx_; } - } // namespace shrpx diff --git a/src/shrpx_http_downstream_connection.h b/src/shrpx_http_downstream_connection.h index 280c6738..2eec5def 100644 --- a/src/shrpx_http_downstream_connection.h +++ b/src/shrpx_http_downstream_connection.h @@ -59,8 +59,6 @@ public: virtual void on_upstream_change(Upstream *upstream); virtual int on_priority_change(int32_t pri) { return 0; } - virtual size_t get_addr_idx() const; - int on_connect(); void signal_write(); diff --git a/src/shrpx_https_upstream.cc b/src/shrpx_https_upstream.cc index dda329a6..30439860 100644 --- a/src/shrpx_https_upstream.cc +++ b/src/shrpx_https_upstream.cc @@ -649,7 +649,7 @@ int HttpsUpstream::on_downstream_header_complete(Downstream *downstream) { if (!get_config()->http2_proxy && !get_config()->client_proxy && !get_config()->no_location_rewrite) { downstream->rewrite_location_response_header( - get_client_handler()->get_upstream_scheme(), get_config()->port); + get_client_handler()->get_upstream_scheme()); } http2::build_http1_headers_from_headers(hdrs, diff --git a/src/shrpx_spdy_upstream.cc b/src/shrpx_spdy_upstream.cc index 89fdcaf7..6ca21c11 100644 --- a/src/shrpx_spdy_upstream.cc +++ b/src/shrpx_spdy_upstream.cc @@ -841,7 +841,7 @@ int SpdyUpstream::on_downstream_header_complete(Downstream *downstream) { if (!get_config()->http2_proxy && !get_config()->client_proxy && !get_config()->no_location_rewrite) { downstream->rewrite_location_response_header( - get_client_handler()->get_upstream_scheme(), get_config()->port); + downstream->get_request_http2_scheme()); } size_t nheader = downstream->get_response_headers().size(); // 8 means server, :status, :version and possible via header field.