diff --git a/src/shrpx_worker.cc b/src/shrpx_worker.cc index e128badb..9dbdd28f 100644 --- a/src/shrpx_worker.cc +++ b/src/shrpx_worker.cc @@ -1021,7 +1021,10 @@ const UpstreamAddr *Worker::find_quic_upstream_addr(const Address &local_addr) { assert(0); } - auto hostport = util::make_hostport(StringRef{host.data()}, port); + std::array hostport_buf; + + auto hostport = util::make_http_hostport(std::begin(hostport_buf), + StringRef{host.data()}, port); const UpstreamAddr *fallback_faddr = nullptr; for (auto &faddr : quic_upstream_addrs_) { @@ -1033,21 +1036,41 @@ const UpstreamAddr *Worker::find_quic_upstream_addr(const Address &local_addr) { continue; } - switch (faddr.family) { - case AF_INET: - if (util::starts_with(faddr.hostport, StringRef::from_lit("0.0.0.0:"))) { - fallback_faddr = &faddr; - } + if (faddr.port == 443 || faddr.port == 80) { + switch (faddr.family) { + case AF_INET: + if (util::streq(faddr.hostport, StringRef::from_lit("0.0.0.0"))) { + fallback_faddr = &faddr; + } - break; - case AF_INET6: - if (util::starts_with(faddr.hostport, StringRef::from_lit("[::]:"))) { - fallback_faddr = &faddr; - } + break; + case AF_INET6: + if (util::streq(faddr.hostport, StringRef::from_lit("[::]"))) { + fallback_faddr = &faddr; + } - break; - default: - assert(0); + break; + default: + assert(0); + } + } else { + switch (faddr.family) { + case AF_INET: + if (util::starts_with(faddr.hostport, + StringRef::from_lit("0.0.0.0:"))) { + fallback_faddr = &faddr; + } + + break; + case AF_INET6: + if (util::starts_with(faddr.hostport, StringRef::from_lit("[::]:"))) { + fallback_faddr = &faddr; + } + + break; + default: + assert(0); + } } } diff --git a/src/util.cc b/src/util.cc index 6efc0e9e..97af8d79 100644 --- a/src/util.cc +++ b/src/util.cc @@ -1325,81 +1325,26 @@ std::string dtos(double n) { StringRef make_http_hostport(BlockAllocator &balloc, const StringRef &host, uint16_t port) { - if (port != 80 && port != 443) { - return make_hostport(balloc, host, port); - } - - auto ipv6 = ipv6_numeric_addr(host.c_str()); - - auto iov = make_byte_ref(balloc, host.size() + (ipv6 ? 2 : 0) + 1); - auto p = iov.base; - - if (ipv6) { - *p++ = '['; - } - - p = std::copy(std::begin(host), std::end(host), p); - - if (ipv6) { - *p++ = ']'; - } - - *p = '\0'; - - return StringRef{iov.base, p}; + auto iov = make_byte_ref(balloc, host.size() + 2 + 1 + 5 + 1); + return make_http_hostport(iov.base, host, port); } std::string make_hostport(const StringRef &host, uint16_t port) { - auto ipv6 = ipv6_numeric_addr(host.c_str()); - auto serv = utos(port); - std::string hostport; - hostport.resize(host.size() + (ipv6 ? 2 : 0) + 1 + serv.size()); + // I'm not sure we can write \0 at the position std::string::size(), + // so allocate an extra byte. + hostport.resize(host.size() + 2 + 1 + 5 + 1); - auto p = &hostport[0]; - - if (ipv6) { - *p++ = '['; - } - - p = std::copy_n(host.c_str(), host.size(), p); - - if (ipv6) { - *p++ = ']'; - } - - *p++ = ':'; - std::copy_n(serv.c_str(), serv.size(), p); + auto s = make_hostport(std::begin(hostport), host, port); + hostport.resize(s.size()); return hostport; } StringRef make_hostport(BlockAllocator &balloc, const StringRef &host, uint16_t port) { - auto ipv6 = ipv6_numeric_addr(host.c_str()); - auto serv = utos(port); - - auto iov = - make_byte_ref(balloc, host.size() + (ipv6 ? 2 : 0) + 1 + serv.size()); - auto p = iov.base; - - if (ipv6) { - *p++ = '['; - } - - p = std::copy(std::begin(host), std::end(host), p); - - if (ipv6) { - *p++ = ']'; - } - - *p++ = ':'; - - p = std::copy(std::begin(serv), std::end(serv), p); - - *p = '\0'; - - return StringRef{iov.base, p}; + auto iov = make_byte_ref(balloc, host.size() + 2 + 1 + 5 + 1); + return make_hostport(iov.base, host, port); } namespace { diff --git a/src/util.h b/src/util.h index ca9b437d..d4990e03 100644 --- a/src/util.h +++ b/src/util.h @@ -762,12 +762,6 @@ std::string format_duration(const std::chrono::microseconds &u); // Just like above, but this takes |t| as seconds. std::string format_duration(double t); -// Creates "host:port" string using given |host| and |port|. If -// |host| is numeric IPv6 address (e.g., ::1), it is enclosed by "[" -// and "]". If |port| is 80 or 443, port part is omitted. -StringRef make_http_hostport(BlockAllocator &balloc, const StringRef &host, - uint16_t port); - // Just like make_http_hostport(), but doesn't treat 80 and 443 // specially. std::string make_hostport(const StringRef &host, uint16_t port); @@ -775,6 +769,65 @@ std::string make_hostport(const StringRef &host, uint16_t port); StringRef make_hostport(BlockAllocator &balloc, const StringRef &host, uint16_t port); +template +StringRef make_hostport(OutputIt first, const StringRef &host, uint16_t port) { + auto ipv6 = ipv6_numeric_addr(host.c_str()); + auto serv = utos(port); + auto p = first; + + if (ipv6) { + *p++ = '['; + } + + p = std::copy(std::begin(host), std::end(host), p); + + if (ipv6) { + *p++ = ']'; + } + + *p++ = ':'; + + p = std::copy(std::begin(serv), std::end(serv), p); + + *p = '\0'; + + return StringRef{first, p}; +} + +// Creates "host:port" string using given |host| and |port|. If +// |host| is numeric IPv6 address (e.g., ::1), it is enclosed by "[" +// and "]". If |port| is 80 or 443, port part is omitted. +StringRef make_http_hostport(BlockAllocator &balloc, const StringRef &host, + uint16_t port); + +constexpr size_t max_hostport = NI_MAXHOST + /* [] for IPv6 */ 2 + /* : */ 1 + + /* port */ 5 + /* terminal NUL */ 1; + +template +StringRef make_http_hostport(OutputIt first, const StringRef &host, + uint16_t port) { + if (port != 80 && port != 443) { + return make_hostport(first, host, port); + } + + auto ipv6 = ipv6_numeric_addr(host.c_str()); + auto p = first; + + if (ipv6) { + *p++ = '['; + } + + p = std::copy(std::begin(host), std::end(host), p); + + if (ipv6) { + *p++ = ']'; + } + + *p = '\0'; + + return StringRef{first, p}; +} + // Dumps |src| of length |len| in the format similar to `hexdump -C`. void hexdump(FILE *out, const uint8_t *src, size_t len); diff --git a/src/util_test.cc b/src/util_test.cc index dfe87e99..d5ff942f 100644 --- a/src/util_test.cc +++ b/src/util_test.cc @@ -550,6 +550,12 @@ void test_util_make_hostport(void) { util::make_hostport(balloc, StringRef::from_lit("localhost"), 80)); CU_ASSERT("[::1]:443" == util::make_hostport(balloc, StringRef::from_lit("::1"), 443)); + + // Check std::string version + CU_ASSERT( + "abcdefghijklmnopqrstuvwxyz0123456789:65535" == + util::make_hostport( + StringRef::from_lit("abcdefghijklmnopqrstuvwxyz0123456789"), 65535)); } void test_util_strifind(void) {