diff --git a/src/http2.cc b/src/http2.cc index 5289ef5c..d703006a 100644 --- a/src/http2.cc +++ b/src/http2.cc @@ -115,7 +115,7 @@ void sanitize_header_value(std::string& s, size_t offset) } } -void copy_url_component(std::string& dest, http_parser_url *u, int field, +void copy_url_component(std::string& dest, const http_parser_url *u, int field, const char* url) { if(u->field_set & (1 << field)) { @@ -439,6 +439,74 @@ void dump_nv(FILE *out, const nghttp2_nv *nva, size_t nvlen) fflush(out); } +std::string rewrite_location_uri(const std::string& uri, + const http_parser_url& u, + const std::string& request_host, + const std::string& upstream_scheme, + uint16_t upstream_port, + uint16_t downstream_port) +{ + // We just rewrite host and optionally port. We don't rewrite https + // link. Not sure it happens in practice. + if(u.field_set & (1 << UF_SCHEMA)) { + auto field = &u.field_data[UF_SCHEMA]; + if(!util::streq("http", &uri[field->off], field->len)) { + return ""; + } + } + if((u.field_set & (1 << UF_HOST)) == 0) { + return ""; + } + std::string host; + copy_url_component(host, &u, UF_HOST, uri.c_str()); + if(u.field_set & (1 << UF_PORT)) { + host += ":"; + host += util::utos(u.port); + if(host != request_host) { + // :authority or host have "host", but host in location header + // field may have "host:port". + auto field = &u.field_data[UF_HOST]; + if(!util::streq(request_host.c_str(), request_host.size(), + &uri[field->off], field->len) || + downstream_port != u.port) { + return ""; + } + } + } else if(host != request_host) { + return ""; + } + std::string res = upstream_scheme; + res += "://"; + auto field = &u.field_data[UF_HOST]; + res.append(&uri[field->off], field->len); + if(upstream_scheme == "http") { + if(upstream_port != 80) { + res += ":"; + res += util::utos(upstream_port); + } + } else if(upstream_scheme == "https") { + if(upstream_port != 443) { + res += ":"; + res += util::utos(upstream_port); + } + } + if(u.field_set & (1 << UF_PATH)) { + field = &u.field_data[UF_PATH]; + res.append(&uri[field->off], field->len); + } + if(u.field_set & (1 << UF_QUERY)) { + field = &u.field_data[UF_QUERY]; + res += "?"; + res.append(&uri[field->off], field->len); + } + if(u.field_set & (1 << UF_FRAGMENT)) { + field = &u.field_data[UF_FRAGMENT]; + res += "#"; + res.append(&uri[field->off], field->len); + } + return res; +} + } // namespace http2 } // namespace nghttp2 diff --git a/src/http2.h b/src/http2.h index e5325f4a..3635768c 100644 --- a/src/http2.h +++ b/src/http2.h @@ -55,7 +55,7 @@ void sanitize_header_value(std::string& s, size_t offset); // Copies the |field| component value from |u| and |url| to the // |dest|. If |u| does not have |field|, then this function does // nothing. -void copy_url_component(std::string& dest, http_parser_url *u, int field, +void copy_url_component(std::string& dest, const http_parser_url *u, int field, const char* url); // Returns true if the header field |name| with length |namelen| bytes @@ -170,6 +170,24 @@ void dump_nv(FILE *out, const char **nv); // Dumps name/value pairs in |nva| to |out|. void dump_nv(FILE *out, const nghttp2_nv *nva, size_t nvlen); +// Rewrites redirection URI which usually appears in location header +// field. The |uri| is the URI in the location header field. The |u| +// stores the result of parsed |uri|. The |request_host| is the host +// or :authority header field value in the request. The +// |upstream_scheme| is either "https" or "http" in the upstream +// interface. The |downstream_port| is the port in the downstream +// connection. +// +// This function returns the new rewritten URI on success. If the +// location URI is not subject to the rewrite, this function returns +// emtpy string. +std::string rewrite_location_uri(const std::string& uri, + const http_parser_url& u, + const std::string& request_host, + const std::string& upstream_scheme, + uint16_t upstream_port, + uint16_t downstream_port); + } // namespace http2 } // namespace nghttp2 diff --git a/src/http2_test.cc b/src/http2_test.cc index 9f92ef00..91ca067c 100644 --- a/src/http2_test.cc +++ b/src/http2_test.cc @@ -30,6 +30,8 @@ #include +#include "http-parser/http_parser.h" + #include "http2.h" #include "util.h" @@ -222,4 +224,52 @@ void test_http2_check_header_value(void) CU_ASSERT(!http2::check_header_value(&nv3)); } +namespace { +void check_rewrite_location_uri(const std::string& new_uri, + const std::string& uri, + const std::string& req_host, + const std::string& upstream_scheme, + uint16_t upstream_port, + uint16_t downstream_port) +{ + http_parser_url u; + CU_ASSERT(0 == http_parser_parse_url(uri.c_str(), uri.size(), 0, &u)); + CU_ASSERT(new_uri == + http2::rewrite_location_uri(uri, u, req_host, + upstream_scheme, upstream_port, + downstream_port)); +} +} // namespace + +void test_http2_rewrite_location_uri(void) +{ + check_rewrite_location_uri("https://localhost:3000/alpha?bravo#charlie", + "http://localhost:3001/alpha?bravo#charlie", + "localhost:3001", "https", 3000, 3001); + check_rewrite_location_uri("https://localhost/", + "http://localhost:3001/", + "localhost:3001", "https", 443, 3001); + check_rewrite_location_uri("http://localhost/", + "http://localhost:3001/", + "localhost:3001", "http", 80, 3001); + check_rewrite_location_uri("http://localhost:443/", + "http://localhost:3001/", + "localhost:3001", "http", 443, 3001); + check_rewrite_location_uri("https://localhost:80/", + "http://localhost:3001/", + "localhost:3001", "https", 80, 3001); + check_rewrite_location_uri("", + "http://localhost:3001/", + "127.0.0.1", "https", 3000, 3001); + check_rewrite_location_uri("https://localhost:3000/", + "http://localhost:3001/", + "localhost", "https", 3000, 3001); + check_rewrite_location_uri("", + "https://localhost:3001/", + "localhost", "https", 3000, 3001); + check_rewrite_location_uri("https://localhost:3000/", + "http://localhost/", + "localhost", "https", 3000, 80); +} + } // namespace shrpx diff --git a/src/http2_test.h b/src/http2_test.h index d1c87e27..5b7ecfdf 100644 --- a/src/http2_test.h +++ b/src/http2_test.h @@ -36,6 +36,7 @@ void test_http2_concat_norm_headers(void); void test_http2_copy_norm_headers_to_nva(void); void test_http2_build_http1_headers_from_norm_headers(void); void test_http2_check_header_value(void); +void test_http2_rewrite_location_uri(void); } // namespace shrpx diff --git a/src/shrpx-unittest.cc b/src/shrpx-unittest.cc index db1bb49c..e2d11350 100644 --- a/src/shrpx-unittest.cc +++ b/src/shrpx-unittest.cc @@ -86,6 +86,8 @@ int main(int argc, char* argv[]) shrpx::test_http2_build_http1_headers_from_norm_headers) || !CU_add_test(pSuite, "http2_check_header_value", shrpx::test_http2_check_header_value) || + !CU_add_test(pSuite, "http2_rewrite_location_uri", + shrpx::test_http2_rewrite_location_uri) || !CU_add_test(pSuite, "downstream_normalize_request_headers", shrpx::test_downstream_normalize_request_headers) || !CU_add_test(pSuite, "downstream_normalize_response_headers", @@ -98,6 +100,8 @@ int main(int argc, char* argv[]) shrpx::test_downstream_crumble_request_cookie) || !CU_add_test(pSuite, "downstream_assemble_request_cookie", shrpx::test_downstream_assemble_request_cookie) || + !CU_add_test(pSuite, "downstream_rewrite_norm_location_response_header", + shrpx::test_downstream_rewrite_norm_location_response_header) || !CU_add_test(pSuite, "util_streq", shrpx::test_util_streq) || !CU_add_test(pSuite, "util_inp_strlower", shrpx::test_util_inp_strlower) || diff --git a/src/shrpx_client_handler.cc b/src/shrpx_client_handler.cc index d6a1ba5b..9e101eb1 100644 --- a/src/shrpx_client_handler.cc +++ b/src/shrpx_client_handler.cc @@ -460,4 +460,13 @@ bool ClientHandler::get_http2_upgrade_allowed() const return !ssl_; } +std::string ClientHandler::get_upstream_scheme() const +{ + if(ssl_) { + return "https"; + } else { + return "http"; + } +} + } // namespace shrpx diff --git a/src/shrpx_client_handler.h b/src/shrpx_client_handler.h index 1ebaf457..e2039efc 100644 --- a/src/shrpx_client_handler.h +++ b/src/shrpx_client_handler.h @@ -75,6 +75,8 @@ public: // terminated. This function returns 0 if it succeeds, or -1. int perform_http2_upgrade(HttpsUpstream *http); bool get_http2_upgrade_allowed() const; + // Returns upstream scheme, either "http" or "https" + std::string get_upstream_scheme() const; private: std::set dconn_pool_; std::unique_ptr upstream_; diff --git a/src/shrpx_downstream.cc b/src/shrpx_downstream.cc index 1cb81ee0..0521d9d8 100644 --- a/src/shrpx_downstream.cc +++ b/src/shrpx_downstream.cc @@ -26,6 +26,8 @@ #include +#include "http-parser/http_parser.h" + #include "shrpx_upstream.h" #include "shrpx_client_handler.h" #include "shrpx_config.h" @@ -174,6 +176,19 @@ Headers::const_iterator get_norm_header(const Headers& headers, } } // namespace +namespace { +Headers::iterator get_norm_header(Headers& headers, + const std::string& name) +{ + auto i = std::lower_bound(std::begin(headers), std::end(headers), + std::make_pair(name, std::string()), name_less); + if(i != std::end(headers) && (*i).first == name) { + return i; + } + return std::end(headers); +} +} // namespace + const Headers& Downstream::get_request_headers() const { return request_headers_; @@ -253,6 +268,11 @@ Headers::const_iterator Downstream::get_norm_request_header return get_norm_header(request_headers_, name); } +void Downstream::concat_norm_request_headers() +{ + request_headers_ = http2::concat_norm_headers(std::move(request_headers_)); +} + void Downstream::add_request_header(std::string name, std::string value) { request_header_key_prev_ = true; @@ -467,6 +487,42 @@ Headers::const_iterator Downstream::get_norm_response_header return get_norm_header(response_headers_, name); } +void Downstream::rewrite_norm_location_response_header +(const std::string& upstream_scheme, + uint16_t upstream_port, + uint16_t downstream_port) +{ + auto hd = get_norm_header(response_headers_, "location"); + if(hd == std::end(response_headers_)) { + return; + } + http_parser_url u; + int rv = http_parser_parse_url((*hd).second.c_str(), (*hd).second.size(), + 0, &u); + if(rv != 0) { + return; + } + std::string new_uri; + if(!request_http2_authority_.empty()) { + new_uri = http2::rewrite_location_uri((*hd).second, u, + request_http2_authority_, + upstream_scheme, upstream_port, + downstream_port); + } + if(new_uri.empty()) { + auto host = get_norm_request_header("host"); + if(host == std::end(request_headers_)) { + return; + } + new_uri = http2::rewrite_location_uri((*hd).second, u, (*host).second, + upstream_scheme, upstream_port, + downstream_port); + } + if(!new_uri.empty()) { + (*hd).second = std::move(new_uri); + } +} + void Downstream::add_response_header(std::string name, std::string value) { response_header_key_prev_ = true; diff --git a/src/shrpx_downstream.h b/src/shrpx_downstream.h index 25eb6f68..c4923cb0 100644 --- a/src/shrpx_downstream.h +++ b/src/shrpx_downstream.h @@ -94,6 +94,10 @@ public: // called after calling normalize_request_headers(). Headers::const_iterator get_norm_request_header (const std::string& name) const; + // Concatenates request header fields with same name by NULL as + // delimiter. See http2::concat_norm_headers(). This function must + // be called after calling normalize_request_headers(). + void concat_norm_request_headers(); void add_request_header(std::string name, std::string value); void set_last_request_header_value(std::string value); @@ -151,6 +155,13 @@ public: // called after calling normalize_response_headers(). Headers::const_iterator get_norm_response_header (const std::string& name) const; + // Rewrites the location response header field. This function must + // be called after calling normalize_response_headers() and + // normalize_request_headers(). + void rewrite_norm_location_response_header + (const std::string& upstream_scheme, + uint16_t upstream_port, + uint16_t downstream_port); void add_response_header(std::string name, std::string value); void set_last_response_header_value(std::string value); diff --git a/src/shrpx_downstream_test.cc b/src/shrpx_downstream_test.cc index 2a0ab18b..6ad989ae 100644 --- a/src/shrpx_downstream_test.cc +++ b/src/shrpx_downstream_test.cc @@ -146,4 +146,24 @@ void test_downstream_assemble_request_cookie(void) } +void test_downstream_rewrite_norm_location_response_header(void) +{ + { + Downstream d(nullptr, 0, 0); + d.add_request_header("host", "localhost:3000"); + d.add_response_header("location", "http://localhost:3000/"); + d.rewrite_norm_location_response_header("https", 443, 3000); + auto location = d.get_norm_response_header("location"); + CU_ASSERT("https://localhost/" == (*location).second); + } + { + Downstream d(nullptr, 0, 0); + d.set_request_http2_authority("localhost"); + d.add_response_header("location", "http://localhost/"); + d.rewrite_norm_location_response_header("https", 443, 80); + auto location = d.get_norm_response_header("location"); + CU_ASSERT("https://localhost/" == (*location).second); + } +} + } // namespace shrpx diff --git a/src/shrpx_downstream_test.h b/src/shrpx_downstream_test.h index ef645578..81a916b5 100644 --- a/src/shrpx_downstream_test.h +++ b/src/shrpx_downstream_test.h @@ -33,6 +33,7 @@ void test_downstream_get_norm_request_header(void); void test_downstream_get_norm_response_header(void); void test_downstream_crumble_request_cookie(void); void test_downstream_assemble_request_cookie(void); +void test_downstream_rewrite_norm_location_response_header(void); } // namespace shrpx diff --git a/src/shrpx_http2_downstream_connection.cc b/src/shrpx_http2_downstream_connection.cc index 9dfc90fa..6cfe1ec4 100644 --- a/src/shrpx_http2_downstream_connection.cc +++ b/src/shrpx_http2_downstream_connection.cc @@ -236,7 +236,7 @@ int Http2DownstreamConnection::push_request_headers() downstream_->crumble_request_cookie(); } downstream_->normalize_request_headers(); - downstream_->concat_norm_response_headers(); + downstream_->concat_norm_request_headers(); auto end_headers = std::end(downstream_->get_request_headers()); // 6 means: diff --git a/src/shrpx_http2_upstream.cc b/src/shrpx_http2_upstream.cc index 28706c55..c7b07797 100644 --- a/src/shrpx_http2_upstream.cc +++ b/src/shrpx_http2_upstream.cc @@ -945,6 +945,9 @@ int Http2Upstream::on_downstream_header_complete(Downstream *downstream) DLOG(INFO, downstream) << "HTTP response header completed"; } downstream->normalize_response_headers(); + downstream->rewrite_norm_location_response_header + (get_client_handler()->get_upstream_scheme(), get_config()->port, + get_config()->downstream_port); downstream->concat_norm_response_headers(); auto end_headers = std::end(downstream->get_response_headers()); size_t nheader = downstream->get_response_headers().size(); diff --git a/src/shrpx_https_upstream.cc b/src/shrpx_https_upstream.cc index 19006b54..caf7bbbe 100644 --- a/src/shrpx_https_upstream.cc +++ b/src/shrpx_https_upstream.cc @@ -656,6 +656,9 @@ int HttpsUpstream::on_downstream_header_complete(Downstream *downstream) hdrs += http2::get_status_string(downstream->get_response_http_status()); hdrs += "\r\n"; downstream->normalize_response_headers(); + downstream->rewrite_norm_location_response_header + (get_client_handler()->get_upstream_scheme(), get_config()->port, + get_config()->downstream_port); auto end_headers = std::end(downstream->get_response_headers()); http2::build_http1_headers_from_norm_headers (hdrs, downstream->get_response_headers()); diff --git a/src/shrpx_spdy_upstream.cc b/src/shrpx_spdy_upstream.cc index dc081e57..376d27ac 100644 --- a/src/shrpx_spdy_upstream.cc +++ b/src/shrpx_spdy_upstream.cc @@ -839,6 +839,10 @@ int SpdyUpstream::on_downstream_header_complete(Downstream *downstream) if(LOG_ENABLED(INFO)) { DLOG(INFO, downstream) << "HTTP response header completed"; } + downstream->normalize_response_headers(); + downstream->rewrite_norm_location_response_header + (get_client_handler()->get_upstream_scheme(), get_config()->port, + get_config()->downstream_port); size_t nheader = downstream->get_response_headers().size(); // 6 means :status, :version and possible via header field. auto nv = util::make_unique(nheader * 2 + 6 + 1); diff --git a/src/util.cc b/src/util.cc index fc67ef0c..30a3370e 100644 --- a/src/util.cc +++ b/src/util.cc @@ -172,24 +172,6 @@ bool strieq(const char *a, const uint8_t *b, size_t bn) return !*a && b == blast; } -bool streq(const char *a, const uint8_t *b, size_t bn) -{ - if(!a || !b) { - return false; - } - const uint8_t *blast = b + bn; - for(; *a && b != blast && *a == *b; ++a, ++b); - return !*a && b == blast; -} - -bool streq(const uint8_t *a, size_t alen, const uint8_t *b, size_t blen) -{ - if(alen != blen) { - return false; - } - return memcmp(a, b, alen) == 0; -} - int strcompare(const char *a, const uint8_t *b, size_t bn) { assert(a && b); diff --git a/src/util.h b/src/util.h index 89734ae0..45b8db28 100644 --- a/src/util.h +++ b/src/util.h @@ -299,9 +299,25 @@ bool strieq(const char *a, const char *b); bool strieq(const char *a, const uint8_t *b, size_t n); -bool streq(const char *a, const uint8_t *b, size_t bn); +template +bool streq(const A *a, const B *b, size_t bn) +{ + if(!a || !b) { + return false; + } + auto blast = b + bn; + for(; *a && b != blast && *a == *b; ++a, ++b); + return !*a && b == blast; +} -bool streq(const uint8_t *a, size_t alen, const uint8_t *b, size_t blen); +template +bool streq(const A *a, size_t alen, const B *b, size_t blen) +{ + if(alen != blen) { + return false; + } + return memcmp(a, b, alen) == 0; +} bool strifind(const char *a, const char *b); diff --git a/src/util_test.cc b/src/util_test.cc index ba143bc8..848ffad9 100644 --- a/src/util_test.cc +++ b/src/util_test.cc @@ -53,7 +53,9 @@ void test_util_streq(void) (const uint8_t*)"alpha", 4)); CU_ASSERT(!util::streq((const uint8_t*)"alpha", 5, (const uint8_t*)"alphA", 5)); - CU_ASSERT(util::streq(nullptr, 0, nullptr, 0)); + char *a = nullptr; + char *b = nullptr; + CU_ASSERT(util::streq(a, 0, b, 0)); } void test_util_inp_strlower(void)