diff --git a/src/shrpx-unittest.cc b/src/shrpx-unittest.cc index f65b8002..37a8e0ec 100644 --- a/src/shrpx-unittest.cc +++ b/src/shrpx-unittest.cc @@ -69,6 +69,8 @@ int main(int argc, char *argv[]) { shrpx::test_shrpx_ssl_create_lookup_tree) || !CU_add_test(pSuite, "ssl_cert_lookup_tree_add_cert_from_file", shrpx::test_shrpx_ssl_cert_lookup_tree_add_cert_from_file) || + !CU_add_test(pSuite, "ssl_tls_hostname_match", + shrpx::test_shrpx_ssl_tls_hostname_match) || !CU_add_test(pSuite, "http2_add_header", shrpx::test_http2_add_header) || !CU_add_test(pSuite, "http2_get_header", shrpx::test_http2_get_header) || !CU_add_test(pSuite, "http2_copy_headers_to_nva", diff --git a/src/shrpx_ssl.cc b/src/shrpx_ssl.cc index 61e61cbf..890484f2 100644 --- a/src/shrpx_ssl.cc +++ b/src/shrpx_ssl.cc @@ -780,27 +780,35 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr, return new ClientHandler(worker, fd, ssl, host, service); } -namespace { -bool tls_hostname_match(const char *pattern, const char *hostname) { - const char *ptWildcard = strchr(pattern, '*'); - if (ptWildcard == nullptr) { - return util::strieq(pattern, hostname); +bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname, + size_t hlen) { + auto pend = pattern + plen; + auto ptWildcard = std::find(pattern, pend, '*'); + if (ptWildcard == pend) { + return util::strieq(pattern, plen, hostname, hlen); } - const char *ptLeftLabelEnd = strchr(pattern, '.'); - bool wildcardEnabled = true; + + auto ptLeftLabelEnd = std::find(pattern, pend, '.'); + auto wildcardEnabled = true; // Do case-insensitive match. At least 2 dots are required to enable // wildcard match. Also wildcard must be in the left-most label. // Don't attempt to match a presented identifier where the wildcard // character is embedded within an A-label. - if (ptLeftLabelEnd == 0 || strchr(ptLeftLabelEnd + 1, '.') == 0 || - ptLeftLabelEnd < ptWildcard || util::istarts_with(pattern, "xn--")) { + if (ptLeftLabelEnd == pend || + std::find(ptLeftLabelEnd + 1, pend, '.') == pend || + ptLeftLabelEnd < ptWildcard || + util::istarts_with(pattern, plen, "xn--")) { wildcardEnabled = false; } + if (!wildcardEnabled) { - return util::strieq(pattern, hostname); + return util::strieq(pattern, plen, hostname, hlen); } - const char *hnLeftLabelEnd = strchr(hostname, '.'); - if (hnLeftLabelEnd == 0 || !util::strieq(ptLeftLabelEnd, hnLeftLabelEnd)) { + + auto hend = hostname + hlen; + auto hnLeftLabelEnd = std::find(hostname, hend, '.'); + if (hnLeftLabelEnd == hend || + !util::strieq(ptLeftLabelEnd, pend, hnLeftLabelEnd, hend)) { return false; } // Perform wildcard match. Here '*' must match at least one @@ -812,107 +820,143 @@ bool tls_hostname_match(const char *pattern, const char *hostname) { util::iends_with(hostname, hnLeftLabelEnd, ptWildcard + 1, ptLeftLabelEnd); } -} // namespace namespace { -int verify_hostname(const char *hostname, const Address *addr, - const std::vector &dns_names, - const std::vector &ip_addrs, - const std::string &common_name) { - if (util::numeric_host(hostname)) { - if (ip_addrs.empty()) { - return util::strieq(common_name.c_str(), hostname) ? 0 : -1; - } - const void *saddr; - switch (addr->su.storage.ss_family) { - case AF_INET: - saddr = &addr->su.in.sin_addr; +ssize_t get_common_name(unsigned char **out_ptr, X509 *cert) { + auto subjectname = X509_get_subject_name(cert); + if (!subjectname) { + LOG(WARN) << "Could not get X509 name object from the certificate."; + return -1; + } + int lastpos = -1; + for (;;) { + lastpos = X509_NAME_get_index_by_NID(subjectname, NID_commonName, lastpos); + if (lastpos == -1) { break; - case AF_INET6: - saddr = &addr->su.in6.sin6_addr; - break; - default: - return -1; } - for (size_t i = 0; i < ip_addrs.size(); ++i) { - if (addr->len == ip_addrs[i].size() && - memcmp(saddr, ip_addrs[i].c_str(), addr->len) == 0) { - return 0; - } + auto entry = X509_NAME_get_entry(subjectname, lastpos); + + auto outlen = ASN1_STRING_to_UTF8(out_ptr, X509_NAME_ENTRY_get_data(entry)); + if (outlen < 0) { + continue; } - } else { - if (dns_names.empty()) { - return tls_hostname_match(common_name.c_str(), hostname) ? 0 : -1; - } - for (size_t i = 0; i < dns_names.size(); ++i) { - if (tls_hostname_match(dns_names[i].c_str(), hostname)) { - return 0; - } + if (std::find(*out_ptr, *out_ptr + outlen, '\0') != *out_ptr + outlen) { + // Embedded NULL is not permitted. + continue; } + return outlen; } return -1; } } // namespace -void get_altnames(X509 *cert, std::vector &dns_names, - std::vector &ip_addrs, - std::string &common_name) { - GENERAL_NAMES *altnames = static_cast( +namespace { +int verify_numeric_hostname(X509 *cert, const char *hostname, size_t hlen, + const Address *addr) { + const void *saddr; + switch (addr->su.storage.ss_family) { + case AF_INET: + saddr = &addr->su.in.sin_addr; + break; + case AF_INET6: + saddr = &addr->su.in6.sin6_addr; + break; + default: + return -1; + } + + auto altnames = static_cast( X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); if (altnames) { auto altnames_deleter = defer(GENERAL_NAMES_free, altnames); - size_t n = sk_GENERAL_NAME_num(altnames); + auto n = sk_GENERAL_NAME_num(altnames); for (size_t i = 0; i < n; ++i) { - const GENERAL_NAME *altname = sk_GENERAL_NAME_value(altnames, i); - if (altname->type == GEN_DNS) { - const char *name; - name = reinterpret_cast(ASN1_STRING_data(altname->d.ia5)); - if (!name) { - continue; - } - size_t len = ASN1_STRING_length(altname->d.ia5); - if (std::find(name, name + len, '\0') != name + len) { - // Embedded NULL is not permitted. - continue; - } - dns_names.push_back(std::string(name, len)); - } else if (altname->type == GEN_IPADD) { - const unsigned char *ip_addr = altname->d.iPAddress->data; - if (!ip_addr) { - continue; - } - size_t len = altname->d.iPAddress->length; - ip_addrs.push_back( - std::string(reinterpret_cast(ip_addr), len)); + auto altname = sk_GENERAL_NAME_value(altnames, i); + if (altname->type != GEN_IPADD) { + continue; + } + + auto ip_addr = altname->d.iPAddress->data; + if (!ip_addr) { + continue; + } + auto ip_addrlen = altname->d.iPAddress->length; + + if (addr->len == ip_addrlen && memcmp(saddr, ip_addr, ip_addrlen) == 0) { + return 0; } } } - X509_NAME *subjectname = X509_get_subject_name(cert); - if (!subjectname) { - LOG(WARN) << "Could not get X509 name object from the certificate."; - return; + + unsigned char *cn; + auto cnlen = get_common_name(&cn, cert); + if (cnlen == -1) { + return -1; } - int lastpos = -1; - while (1) { - lastpos = X509_NAME_get_index_by_NID(subjectname, NID_commonName, lastpos); - if (lastpos == -1) { - break; - } - X509_NAME_ENTRY *entry = X509_NAME_get_entry(subjectname, lastpos); - unsigned char *out; - int outlen = ASN1_STRING_to_UTF8(&out, X509_NAME_ENTRY_get_data(entry)); - if (outlen < 0) { - continue; - } - if (std::find(out, out + outlen, '\0') != out + outlen) { - // Embedded NULL is not permitted. - continue; - } - common_name.assign(&out[0], &out[outlen]); - OPENSSL_free(out); - break; + + // cn is not NULL terminated + auto rv = util::streq(hostname, hlen, cn, cnlen); + OPENSSL_free(cn); + + if (rv) { + return 0; } + + return -1; } +} // namespace + +namespace { +int verify_hostname(X509 *cert, const char *hostname, size_t hlen, + const Address *addr) { + if (util::numeric_host(hostname)) { + return verify_numeric_hostname(cert, hostname, hlen, addr); + } + + auto altnames = static_cast( + X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); + if (altnames) { + auto altnames_deleter = defer(GENERAL_NAMES_free, altnames); + auto n = sk_GENERAL_NAME_num(altnames); + for (size_t i = 0; i < n; ++i) { + auto altname = sk_GENERAL_NAME_value(altnames, i); + if (altname->type != GEN_DNS) { + continue; + } + + auto name = reinterpret_cast(ASN1_STRING_data(altname->d.ia5)); + if (!name) { + continue; + } + + auto len = ASN1_STRING_length(altname->d.ia5); + if (std::find(name, name + len, '\0') != name + len) { + // Embedded NULL is not permitted. + continue; + } + + if (tls_hostname_match(name, len, hostname, hlen)) { + return 0; + } + } + } + + unsigned char *cn; + auto cnlen = get_common_name(&cn, cert); + if (cnlen == -1) { + return -1; + } + + auto rv = util::strieq(hostname, hlen, cn, cnlen); + OPENSSL_free(cn); + + if (rv) { + return 0; + } + + return -1; +} +} // namespace int check_cert(SSL *ssl, const DownstreamAddr *addr) { auto cert = SSL_get_peer_certificate(ssl); @@ -921,21 +965,16 @@ int check_cert(SSL *ssl, const DownstreamAddr *addr) { return -1; } auto cert_deleter = defer(X509_free, cert); - long verify_res = SSL_get_verify_result(ssl); + auto verify_res = SSL_get_verify_result(ssl); if (verify_res != X509_V_OK) { LOG(ERROR) << "Certificate verification failed: " << X509_verify_cert_error_string(verify_res); return -1; } - std::string common_name; - std::vector dns_names; - std::vector ip_addrs; - get_altnames(cert, dns_names, ip_addrs, common_name); auto hostname = get_config()->backend_tls_sni_name ? get_config()->backend_tls_sni_name.get() : addr->host.get(); - if (verify_hostname(hostname, &addr->addr, dns_names, ip_addrs, - common_name) != 0) { + if (verify_hostname(cert, hostname, strlen(hostname), &addr->addr) != 0) { LOG(ERROR) << "Certificate verification failed: hostname does not match"; return -1; } @@ -969,7 +1008,7 @@ void cert_lookup_tree_add_cert(CertNode *node, SSL_CTX *ssl_ctx, char *hostname, // some restrictions for wildcard hostname. We just ignore // these rules here but do the proper check when we do the // match. - node->wildcard_certs.emplace_back(hostname, ssl_ctx); + node->wildcard_certs.push_back({ssl_ctx, hostname, len}); return; } @@ -986,7 +1025,7 @@ void cert_lookup_tree_add_cert(CertNode *node, SSL_CTX *ssl_ctx, char *hostname, new_node->ssl_ctx = ssl_ctx; } else { new_node->ssl_ctx = nullptr; - new_node->wildcard_certs.emplace_back(hostname, ssl_ctx); + new_node->wildcard_certs.push_back({ssl_ctx, hostname, len}); } node->next.push_back(std::move(new_node)); return; @@ -1073,9 +1112,11 @@ SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const char *hostname, // one character. return nullptr; } + for (const auto &wildcert : node->wildcard_certs) { - if (tls_hostname_match(wildcert.first, hostname)) { - return wildcert.second; + if (tls_hostname_match(wildcert.hostname, wildcert.hostnamelen, hostname, + len)) { + return wildcert.ssl_ctx; } } auto c = util::lowcase(hostname[j]); @@ -1111,14 +1152,43 @@ int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx, return -1; } auto cert_deleter = defer(X509_free, cert); - std::string common_name; - std::vector dns_names; - std::vector ip_addrs; - get_altnames(cert, dns_names, ip_addrs, common_name); - for (auto &dns_name : dns_names) { - lt->add_cert(ssl_ctx, dns_name.c_str(), dns_name.size()); + + auto altnames = static_cast( + X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); + if (altnames) { + auto altnames_deleter = defer(GENERAL_NAMES_free, altnames); + auto n = sk_GENERAL_NAME_num(altnames); + for (size_t i = 0; i < n; ++i) { + auto altname = sk_GENERAL_NAME_value(altnames, i); + if (altname->type != GEN_DNS) { + continue; + } + + auto name = reinterpret_cast(ASN1_STRING_data(altname->d.ia5)); + if (!name) { + continue; + } + + auto len = ASN1_STRING_length(altname->d.ia5); + if (std::find(name, name + len, '\0') != name + len) { + // Embedded NULL is not permitted. + continue; + } + + lt->add_cert(ssl_ctx, name, len); + } } - lt->add_cert(ssl_ctx, common_name.c_str(), common_name.size()); + + unsigned char *cn; + auto cnlen = get_common_name(&cn, cert); + if (cnlen == -1) { + return 0; + } + + lt->add_cert(ssl_ctx, reinterpret_cast(cn), cnlen); + + OPENSSL_free(cn); + return 0; } diff --git a/src/shrpx_ssl.h b/src/shrpx_ssl.h index ffc7bab6..5d893b08 100644 --- a/src/shrpx_ssl.h +++ b/src/shrpx_ssl.h @@ -108,10 +108,16 @@ void get_altnames(X509 *cert, std::vector &dns_names, // them. If there is a match, its SSL_CTX is returned. If none // matches, query is continued to the next character. +struct WildcardCert { + SSL_CTX *ssl_ctx; + char *hostname; + size_t hostnamelen; +}; + struct CertNode { // list of wildcard domain name and its SSL_CTX pair, the wildcard // '*' appears in this position. - std::vector> wildcard_certs; + std::vector wildcard_certs; // Next CertNode index of CertLookupTree::nodes std::vector> next; // SSL_CTX for exact match @@ -198,6 +204,13 @@ SSL *create_ssl(SSL_CTX *ssl_ctx); // Returns true if SSL/TLS is enabled on downstream bool downstream_tls_enabled(); +// Performs TLS hostname match. |pattern| of length |plen| can +// contain wildcard character '*', which matches prefix of target +// hostname. There are several restrictions to make wildcard work. +// The matching algorithm is based on RFC 6125. +bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname, + size_t hlen); + } // namespace ssl } // namespace shrpx diff --git a/src/shrpx_ssl_test.cc b/src/shrpx_ssl_test.cc index ff646b85..35167324 100644 --- a/src/shrpx_ssl_test.cc +++ b/src/shrpx_ssl_test.cc @@ -115,4 +115,37 @@ void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void) { SSL_CTX_free(ssl_ctx); } +template +bool tls_hostname_match_wrapper(const char(&pattern)[N], + const char(&hostname)[M]) { + return ssl::tls_hostname_match(pattern, N, hostname, M); +} + +void test_shrpx_ssl_tls_hostname_match(void) { + CU_ASSERT(tls_hostname_match_wrapper("example.com", "example.com")); + CU_ASSERT(tls_hostname_match_wrapper("example.com", "EXAMPLE.com")); + + // check wildcard + CU_ASSERT(tls_hostname_match_wrapper("*.example.com", "www.example.com")); + CU_ASSERT(tls_hostname_match_wrapper("*w.example.com", "www.example.com")); + CU_ASSERT(tls_hostname_match_wrapper("www*.example.com", "www1.example.com")); + CU_ASSERT( + tls_hostname_match_wrapper("www*.example.com", "WWW12.EXAMPLE.com")); + // at least 2 dots are required after '*' + CU_ASSERT(!tls_hostname_match_wrapper("*.com", "example.com")); + CU_ASSERT(!tls_hostname_match_wrapper("*", "example.com")); + // '*' must be in left most label + CU_ASSERT( + !tls_hostname_match_wrapper("blog.*.example.com", "blog.my.example.com")); + // prefix is wrong + CU_ASSERT( + !tls_hostname_match_wrapper("client*.example.com", "server.example.com")); + // '*' must match at least one character + CU_ASSERT(!tls_hostname_match_wrapper("www*.example.com", "www.example.com")); + + CU_ASSERT(!tls_hostname_match_wrapper("example.com", "nghttp2.org")); + CU_ASSERT(!tls_hostname_match_wrapper("www.example.com", "example.com")); + CU_ASSERT(!tls_hostname_match_wrapper("example.com", "www.example.com")); +} + } // namespace shrpx diff --git a/src/shrpx_ssl_test.h b/src/shrpx_ssl_test.h index 89b00449..0b3619e2 100644 --- a/src/shrpx_ssl_test.h +++ b/src/shrpx_ssl_test.h @@ -33,6 +33,7 @@ namespace shrpx { void test_shrpx_ssl_create_lookup_tree(void); void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void); +void test_shrpx_ssl_tls_hostname_match(void); } // namespace shrpx diff --git a/src/util.h b/src/util.h index eb879c15..67f9c27e 100644 --- a/src/util.h +++ b/src/util.h @@ -258,6 +258,15 @@ bool strieq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) { return std::equal(a, a + alen, b, CaseCmp()); } +template +bool strieq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) { + if (std::distance(first1, last1) != std::distance(first2, last2)) { + return false; + } + + return std::equal(first1, last1, first2, CaseCmp()); +} + inline bool strieq(const std::string &a, const std::string &b) { return strieq(std::begin(a), a.size(), std::begin(b), b.size()); }