diff --git a/src/nghttp.cc b/src/nghttp.cc index fa763d40..3c2284ea 100644 --- a/src/nghttp.cc +++ b/src/nghttp.cc @@ -105,6 +105,7 @@ Config::Config() window_bits(-1), connection_window_bits(-1), verbose(0), + port_override(0), null_out(false), remote_name(false), get_assets(false), @@ -178,7 +179,57 @@ void Request::init_inflater() { assert(rv == 0); } -void Request::init_html_parser() { html_parser = new HtmlParser(uri); } +StringRef Request::get_real_scheme() const { + return config.scheme_override.empty() + ? util::get_uri_field(uri.c_str(), u, UF_SCHEMA) + : StringRef{config.scheme_override}; +} + +StringRef Request::get_real_host() const { + return config.host_override.empty() + ? util::get_uri_field(uri.c_str(), u, UF_HOST) + : StringRef{config.host_override}; +} + +uint16_t Request::get_real_port() const { + auto scheme = get_real_scheme(); + return config.host_override.empty() + ? util::has_uri_field(u, UF_PORT) ? u.port + : scheme == "https" ? 443 : 80 + : config.port_override == 0 ? scheme == "https" ? 443 : 80 + : config.port_override; +} + +void Request::init_html_parser() { + // We crawl HTML using overridden scheme, host, and port. 
+ auto scheme = get_real_scheme(); + auto host = get_real_host(); + auto port = get_real_port(); + auto ipv6_lit = + std::find(std::begin(host), std::end(host), ':') != std::end(host); + + auto base_uri = scheme.str(); + base_uri += "://"; + if (ipv6_lit) { + base_uri += '['; + } + base_uri += host; + if (ipv6_lit) { + base_uri += ']'; + } + if (!((scheme == "https" && port == 443) || + (scheme == "http" && port == 80))) { + base_uri += ':'; + base_uri += util::utos(port); + } + base_uri += util::get_uri_field(uri.c_str(), u, UF_PATH); + if (util::has_uri_field(u, UF_QUERY)) { + base_uri += '?'; + base_uri += util::get_uri_field(uri.c_str(), u, UF_QUERY); + } + + html_parser = new HtmlParser(base_uri); +} int Request::update_html_parser(const uint8_t *data, size_t len, int fin) { if (!html_parser) { @@ -625,20 +676,11 @@ int HttpClient::initiate_connection() { // If the user overrode the :authority or host header, use that // value for the SNI extension - const char *host_string = nullptr; - auto i = - std::find_if(std::begin(config.headers), std::end(config.headers), - [](const Header &nv) { - return ":authority" == nv.name || "host" == nv.name; - }); - if (i != std::end(config.headers)) { - host_string = (*i).value.c_str(); - } else { - host_string = host.c_str(); - } + const auto &host_string = + config.host_override.empty() ? 
host : config.host_override; - if (!util::numeric_host(host_string)) { - SSL_set_tlsext_host_name(ssl, host_string); + if (!util::numeric_host(host_string.c_str())) { + SSL_set_tlsext_host_name(ssl, host_string.c_str()); } } @@ -1591,35 +1633,36 @@ void update_html_parser(HttpClient *client, Request *req, const uint8_t *data, } req->update_html_parser(data, len, fin); + auto scheme = req->get_real_scheme(); + auto host = req->get_real_host(); + auto port = req->get_real_port(); + for (auto &p : req->html_parser->get_links()) { auto uri = strip_fragment(p.first.c_str()); auto res_type = p.second; http_parser_url u{}; - if (http_parser_parse_url(uri.c_str(), uri.size(), 0, &u) == 0) { + if (http_parser_parse_url(uri.c_str(), uri.size(), 0, &u) != 0) { + continue; + } - const char *host_string = nullptr; - auto found = - std::find_if(std::begin(config.headers), std::end(config.headers), - [](const Header &nv) { - return ":authority" == nv.name || "host" == nv.name; - }); - if (found != std::end(config.headers)) { - host_string = (*found).value.c_str(); - } + if (!util::fieldeq(uri.c_str(), u, UF_SCHEMA, scheme) || + !util::fieldeq(uri.c_str(), u, UF_HOST, host)) { + continue; + } - if (util::fieldeq(uri.c_str(), u, req->uri.c_str(), req->u, UF_SCHEMA) && - (util::fieldeq(uri.c_str(), u, req->uri.c_str(), req->u, UF_HOST) || - (host_string != nullptr && - util::fieldeq(uri.c_str(), u, UF_HOST, host_string))) && - util::porteq(uri.c_str(), u, req->uri.c_str(), req->u)) { - // No POST data for assets - auto pri_spec = resolve_dep(res_type); + auto link_port = + util::has_uri_field(u, UF_PORT) ? u.port : scheme == "https" ? 
443 : 80; - if (client->add_request(uri, nullptr, 0, pri_spec, req->level + 1)) { - submit_request(client, config.headers, client->reqvec.back().get()); - } - } + if (port != link_port) { + continue; + } + + // No POST data for assets + auto pri_spec = resolve_dep(res_type); + + if (client->add_request(uri, nullptr, 0, pri_spec, req->level + 1)) { + submit_request(client, config.headers, client->reqvec.back().get()); } } req->html_parser->clear_links(); @@ -2945,6 +2988,41 @@ int main(int argc, char **argv) { } config.weight.insert(std::end(config.weight), argc - optind, weight_to_fill); + // Find scheme overridden by extra header fields. + auto scheme_it = + std::find_if(std::begin(config.headers), std::end(config.headers), + [](const Header &nv) { return nv.name == ":scheme"; }); + if (scheme_it != std::end(config.headers)) { + config.scheme_override = (*scheme_it).value; + } + + // Find host and port overridden by extra header fields. + auto authority_it = + std::find_if(std::begin(config.headers), std::end(config.headers), + [](const Header &nv) { return nv.name == ":authority"; }); + if (authority_it == std::end(config.headers)) { + authority_it = + std::find_if(std::begin(config.headers), std::end(config.headers), + [](const Header &nv) { return nv.name == "host"; }); + } + + if (authority_it != std::end(config.headers)) { + // authority_it may look like "host:port". 
+ auto uri = "https://" + (*authority_it).value; + http_parser_url u{}; + if (http_parser_parse_url(uri.c_str(), uri.size(), 0, &u) != 0) { + std::cerr << "[ERROR] Could not parse authority in " + << (*authority_it).name << ": " << (*authority_it).value + << std::endl; + exit(EXIT_FAILURE); + } + + config.host_override = util::get_uri_field(uri.c_str(), u, UF_HOST).str(); + if (util::has_uri_field(u, UF_PORT)) { + config.port_override = u.port; + } + } + set_color_output(color || isatty(fileno(stdout))); nghttp2_option_set_peer_max_concurrent_streams( diff --git a/src/nghttp.h b/src/nghttp.h index eb36a4ce..ccc2230c 100644 --- a/src/nghttp.h +++ b/src/nghttp.h @@ -69,6 +69,8 @@ struct Config { std::string keyfile; std::string datafile; std::string harfile; + std::string scheme_override; + std::string host_override; nghttp2_option *http2_option; int64_t header_table_size; int64_t min_header_table_size; @@ -82,6 +84,7 @@ struct Config { int window_bits; int connection_window_bits; int verbose; + uint16_t port_override; bool null_out; bool remote_name; bool get_assets; @@ -151,6 +154,15 @@ struct Request { void record_response_start_time(); void record_response_end_time(); + // Returns scheme taking into account overridden scheme. + StringRef get_real_scheme() const; + // Returns request host, without port, taking into account + // overridden host. + StringRef get_real_host() const; + // Returns request port, taking into account overridden host, port, + // and scheme. 
+ uint16_t get_real_port() const; + Headers res_nva; Headers req_nva; std::string method; diff --git a/src/util.cc b/src/util.cc index 587fee51..8dec1d97 100644 --- a/src/util.cc +++ b/src/util.cc @@ -555,20 +555,16 @@ bool fieldeq(const char *uri1, const http_parser_url &u1, const char *uri2, bool fieldeq(const char *uri, const http_parser_url &u, http_parser_url_fields field, const char *t) { + return fieldeq(uri, u, field, StringRef{t}); +} + +bool fieldeq(const char *uri, const http_parser_url &u, + http_parser_url_fields field, const StringRef &t) { if (!has_uri_field(u, field)) { - if (!t[0]) { - return true; - } else { - return false; - } - } else if (!t[0]) { - return false; + return t.empty(); } - int i, len = u.field_data[field].len; - const char *p = uri + u.field_data[field].off; - for (i = 0; i < len && t[i] && p[i] == t[i]; ++i) - ; - return i == len && !t[i]; + auto &f = u.field_data[field]; + return StringRef{uri + f.off, f.len} == t; } StringRef get_uri_field(const char *uri, const http_parser_url &u, diff --git a/src/util.h b/src/util.h index 3bf5ff92..39280b4d 100644 --- a/src/util.h +++ b/src/util.h @@ -461,6 +461,9 @@ bool fieldeq(const char *uri1, const http_parser_url &u1, const char *uri2, bool fieldeq(const char *uri, const http_parser_url &u, http_parser_url_fields field, const char *t); +bool fieldeq(const char *uri, const http_parser_url &u, + http_parser_url_fields field, const StringRef &t); + StringRef get_uri_field(const char *uri, const http_parser_url &u, http_parser_url_fields field);