From 372123c178eb88f2fdb11a23f224d8b30a19838d Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa Date: Thu, 24 Mar 2016 23:16:20 +0900 Subject: [PATCH] nghttpx: Remove strieq(const char*, cosnt char*) overload, and fix unittests --- src/shrpx_config_test.cc | 34 +++++------ src/shrpx_ssl.cc | 124 +++++++++++++++++++++------------------ src/shrpx_ssl.h | 21 ++++--- src/shrpx_ssl_test.cc | 76 ++++++++++++------------ src/util.cc | 9 --- src/util.h | 52 ++++++++-------- src/util_test.cc | 30 +++++----- 7 files changed, 173 insertions(+), 173 deletions(-) diff --git a/src/shrpx_config_test.cc b/src/shrpx_config_test.cc index 635c23d9..5317fe00 100644 --- a/src/shrpx_config_test.cc +++ b/src/shrpx_config_test.cc @@ -37,40 +37,40 @@ namespace shrpx { void test_shrpx_config_parse_header(void) { - auto p = parse_header("a: b"); + auto p = parse_header(StringRef::from_lit("a: b")); CU_ASSERT("a" == p.name); CU_ASSERT("b" == p.value); - p = parse_header("a: b"); + p = parse_header(StringRef::from_lit("a: b")); CU_ASSERT("a" == p.name); CU_ASSERT("b" == p.value); - p = parse_header(":a: b"); + p = parse_header(StringRef::from_lit(":a: b")); CU_ASSERT(p.name.empty()); - p = parse_header("a: :b"); + p = parse_header(StringRef::from_lit("a: :b")); CU_ASSERT("a" == p.name); CU_ASSERT(":b" == p.value); - p = parse_header(": b"); + p = parse_header(StringRef::from_lit(": b")); CU_ASSERT(p.name.empty()); - p = parse_header("alpha: bravo charlie"); + p = parse_header(StringRef::from_lit("alpha: bravo charlie")); CU_ASSERT("alpha" == p.name); CU_ASSERT("bravo charlie" == p.value); - p = parse_header("a,: b"); + p = parse_header(StringRef::from_lit("a,: b")); CU_ASSERT(p.name.empty()); - p = parse_header("a: b\x0a"); + p = parse_header(StringRef::from_lit("a: b\x0a")); CU_ASSERT(p.name.empty()); } void test_shrpx_config_parse_log_format(void) { - auto res = - parse_log_format(R"($remote_addr - $remote_user [$time_local] )" - R"("$request" $status $body_bytes_sent )" - R"("${http_referer}" $http_host "$http_user_agent")"); + auto res = parse_log_format(StringRef::from_lit( + R"($remote_addr - $remote_user [$time_local] )" + R"("$request" $status $body_bytes_sent )" + R"("${http_referer}" $http_host "$http_user_agent")")); CU_ASSERT(16 == res.size()); CU_ASSERT(SHRPX_LOGF_REMOTE_ADDR == res[0].type); @@ -115,35 +115,35 @@ void test_shrpx_config_parse_log_format(void) { CU_ASSERT(SHRPX_LOGF_LITERAL == res[15].type); CU_ASSERT("\"" == res[15].value); - res = parse_log_format("$"); + res = parse_log_format(StringRef::from_lit("$")); CU_ASSERT(1 == res.size()); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT("$" == res[0].value); - res = parse_log_format("${"); + res = parse_log_format(StringRef::from_lit("${")); CU_ASSERT(1 == res.size()); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT("${" == res[0].value); - res = parse_log_format("${a"); + res = parse_log_format(StringRef::from_lit("${a")); CU_ASSERT(1 == res.size()); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT("${a" == res[0].value); - res = parse_log_format("${a "); + res = parse_log_format(StringRef::from_lit("${a ")); CU_ASSERT(1 == res.size()); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT("${a " == res[0].value); - res = parse_log_format("$$remote_addr"); + res = parse_log_format(StringRef::from_lit("$$remote_addr")); CU_ASSERT(2 == res.size()); diff --git a/src/shrpx_ssl.cc b/src/shrpx_ssl.cc index b6346d64..ecd5ad65 100644 --- a/src/shrpx_ssl.cc +++ b/src/shrpx_ssl.cc @@ -145,7 +145,8 @@ int servername_callback(SSL *ssl, int *al, void *arg) { if (cert_tree) { const char *hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); if (hostname) { - auto ssl_ctx = cert_tree->lookup(hostname, strlen(hostname)); + auto len = strlen(hostname); + auto ssl_ctx = cert_tree->lookup(StringRef{hostname, len}); if (ssl_ctx) { SSL_set_SSL_CTX(ssl, ssl_ctx); } @@ -820,53 +821,56 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr, faddr); } -bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname, - size_t hlen) { - auto pend = pattern + plen; - auto ptWildcard = std::find(pattern, pend, '*'); - if (ptWildcard == pend) { - return util::strieq(pattern, plen, hostname, hlen); +bool tls_hostname_match(const StringRef &pattern, const StringRef &hostname) { + auto ptWildcard = std::find(std::begin(pattern), std::end(pattern), '*'); + if (ptWildcard == std::end(pattern)) { + return util::strieq(pattern, hostname); } - auto ptLeftLabelEnd = std::find(pattern, pend, '.'); + auto ptLeftLabelEnd = std::find(std::begin(pattern), std::end(pattern), '.'); auto wildcardEnabled = true; // Do case-insensitive match. At least 2 dots are required to enable // wildcard match. Also wildcard must be in the left-most label. // Don't attempt to match a presented identifier where the wildcard // character is embedded within an A-label. - if (ptLeftLabelEnd == pend || - std::find(ptLeftLabelEnd + 1, pend, '.') == pend || - ptLeftLabelEnd < ptWildcard || - util::istarts_with(pattern, plen, "xn--")) { + if (ptLeftLabelEnd == std::end(pattern) || + std::find(ptLeftLabelEnd + 1, std::end(pattern), '.') == + std::end(pattern) || + ptLeftLabelEnd < ptWildcard || util::istarts_with_l(pattern, "xn--")) { wildcardEnabled = false; } if (!wildcardEnabled) { - return util::strieq(pattern, plen, hostname, hlen); + return util::strieq(pattern, hostname); } - auto hend = hostname + hlen; - auto hnLeftLabelEnd = std::find(hostname, hend, '.'); - if (hnLeftLabelEnd == hend || - !util::strieq(ptLeftLabelEnd, pend, hnLeftLabelEnd, hend)) { + auto hnLeftLabelEnd = + std::find(std::begin(hostname), std::end(hostname), '.'); + if (hnLeftLabelEnd == std::end(hostname) || + !util::strieq(StringRef{ptLeftLabelEnd, std::end(pattern)}, + StringRef{hnLeftLabelEnd, std::end(hostname)})) { return false; } // Perform wildcard match. Here '*' must match at least one // character. - if (hnLeftLabelEnd - hostname < ptLeftLabelEnd - pattern) { + if (hnLeftLabelEnd - std::begin(hostname) < + ptLeftLabelEnd - std::begin(pattern)) { return false; } - return util::istarts_with(hostname, hnLeftLabelEnd, pattern, ptWildcard) && - util::iends_with(hostname, hnLeftLabelEnd, ptWildcard + 1, - ptLeftLabelEnd); + return util::istarts_with(StringRef{std::begin(hostname), hnLeftLabelEnd}, + StringRef{std::begin(pattern), ptWildcard}) && + util::iends_with(StringRef{std::begin(hostname), hnLeftLabelEnd}, + StringRef{ptWildcard + 1, ptLeftLabelEnd}); } namespace { -ssize_t get_common_name(unsigned char **out_ptr, X509 *cert) { +// if return value is not empty, StringRef.c_str() must be freed using +// OPENSSL_free(). +StringRef get_common_name(X509 *cert) { auto subjectname = X509_get_subject_name(cert); if (!subjectname) { LOG(WARN) << "Could not get X509 name object from the certificate."; - return -1; + return StringRef{}; } int lastpos = -1; for (;;) { @@ -876,22 +880,29 @@ ssize_t get_common_name(unsigned char **out_ptr, X509 *cert) { } auto entry = X509_NAME_get_entry(subjectname, lastpos); - auto outlen = ASN1_STRING_to_UTF8(out_ptr, X509_NAME_ENTRY_get_data(entry)); - if (outlen < 0) { + unsigned char *p; + auto plen = ASN1_STRING_to_UTF8(&p, X509_NAME_ENTRY_get_data(entry)); + if (plen < 0) { continue; } - if (std::find(*out_ptr, *out_ptr + outlen, '\0') != *out_ptr + outlen) { + if (std::find(p, p + plen, '\0') != p + plen) { // Embedded NULL is not permitted. continue; } - return outlen; + if (plen == 0) { + LOG(WARN) << "X509 name is empty"; + OPENSSL_free(p); + continue; + } + + return StringRef{p, static_cast(plen)}; } - return -1; + return StringRef{}; } } // namespace namespace { -int verify_numeric_hostname(X509 *cert, const char *hostname, size_t hlen, +int verify_numeric_hostname(X509 *cert, const StringRef &hostname, const Address *addr) { const void *saddr; switch (addr->su.storage.ss_family) { @@ -928,15 +939,14 @@ int verify_numeric_hostname(X509 *cert, const char *hostname, size_t hlen, } } - unsigned char *cn; - auto cnlen = get_common_name(&cn, cert); - if (cnlen == -1) { + auto cn = get_common_name(cert); + if (cn.empty()) { return -1; } // cn is not NULL terminated - auto rv = util::streq(hostname, hlen, cn, cnlen); - OPENSSL_free(cn); + auto rv = util::streq(hostname, cn); + OPENSSL_free(const_cast(cn.c_str())); if (rv) { return 0; @@ -947,10 +957,10 @@ int verify_numeric_hostname(X509 *cert, const char *hostname, size_t hlen, } // namespace namespace { -int verify_hostname(X509 *cert, const char *hostname, size_t hlen, +int verify_hostname(X509 *cert, const StringRef &hostname, const Address *addr) { - if (util::numeric_host(hostname)) { - return verify_numeric_hostname(cert, hostname, hlen, addr); + if (util::numeric_host(hostname.c_str())) { + return verify_numeric_hostname(cert, hostname, addr); } auto altnames = static_cast( @@ -975,20 +985,20 @@ int verify_hostname(X509 *cert, const char *hostname, size_t hlen, continue; } - if (tls_hostname_match(name, len, hostname, hlen)) { + if (tls_hostname_match(StringRef{name, static_cast(len)}, + hostname)) { return 0; } } } - unsigned char *cn; - auto cnlen = get_common_name(&cn, cert); - if (cnlen == -1) { + auto cn = get_common_name(cert); + if (cn.empty()) { return -1; } - auto rv = util::strieq(hostname, hlen, cn, cnlen); - OPENSSL_free(cn); + auto rv = util::strieq(hostname, cn); + OPENSSL_free(const_cast(cn.c_str())); if (rv) { return 0; @@ -1012,7 +1022,7 @@ int check_cert(SSL *ssl, const Address *addr, const StringRef &host) { return -1; } - if (verify_hostname(cert, host.c_str(), host.size(), addr) != 0) { + if (verify_hostname(cert, host, addr) != 0) { LOG(ERROR) << "Certificate verification failed: hostname does not match"; return -1; } @@ -1138,8 +1148,8 @@ void CertLookupTree::add_cert(SSL_CTX *ssl_ctx, const char *hostname, } namespace { -SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const char *hostname, - size_t len, int offset) { +SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const StringRef &hostname, + int offset) { int i, j; for (i = node->first, j = offset; i > node->last && j >= 0 && node->str[i] == util::lowcase(hostname[j]); @@ -1160,23 +1170,26 @@ SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const char *hostname, } for (const auto &wildcert : node->wildcard_certs) { - if (tls_hostname_match(wildcert.hostname, wildcert.hostnamelen, hostname, - len)) { + if (tls_hostname_match(StringRef{wildcert.hostname, wildcert.hostnamelen}, + hostname)) { return wildcert.ssl_ctx; } } auto c = util::lowcase(hostname[j]); for (const auto &next_node : node->next) { if (next_node->str[next_node->first] == c) { - return cert_lookup_tree_lookup(next_node.get(), hostname, len, j); + return cert_lookup_tree_lookup(next_node.get(), hostname, j); } } return nullptr; } } // namespace -SSL_CTX *CertLookupTree::lookup(const char *hostname, size_t len) { - return cert_lookup_tree_lookup(&root_, hostname, len, len - 1); +SSL_CTX *CertLookupTree::lookup(const StringRef &hostname) { + if (hostname.empty()) { + return nullptr; + } + return cert_lookup_tree_lookup(&root_, hostname, hostname.size() - 1); } int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx, @@ -1225,15 +1238,14 @@ int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx, } } - unsigned char *cn; - auto cnlen = get_common_name(&cn, cert); - if (cnlen == -1) { + auto cn = get_common_name(cert); + if (cn.empty()) { return 0; } - lt->add_cert(ssl_ctx, reinterpret_cast(cn), cnlen); + lt->add_cert(ssl_ctx, cn.c_str(), cn.size()); - OPENSSL_free(cn); + OPENSSL_free(const_cast(cn.c_str())); return 0; } diff --git a/src/shrpx_ssl.h b/src/shrpx_ssl.h index 748c034c..4be68d1a 100644 --- a/src/shrpx_ssl.h +++ b/src/shrpx_ssl.h @@ -143,11 +143,11 @@ public: // to the lookup tree. The |hostname| must be NULL-terminated. void add_cert(SSL_CTX *ssl_ctx, const char *hostname, size_t len); - // Looks up SSL_CTX using the given |hostname| with length |len|. - // If more than one SSL_CTX which matches the query, it is undefined - // which one is returned. The |hostname| must be NULL-terminated. - // If no matching SSL_CTX found, returns NULL. - SSL_CTX *lookup(const char *hostname, size_t len); + // Looks up SSL_CTX using the given |hostname|. If more than one + // SSL_CTX which matches the query, it is undefined which one is + // returned. The |hostname| must be NULL-terminated. If no + // matching SSL_CTX found, returns NULL. + SSL_CTX *lookup(const StringRef &hostname); private: CertNode root_; @@ -219,12 +219,11 @@ bool upstream_tls_enabled(); // Returns true if SSL/TLS is enabled on downstream bool downstream_tls_enabled(); -// Performs TLS hostname match. |pattern| of length |plen| can -// contain wildcard character '*', which matches prefix of target -// hostname. There are several restrictions to make wildcard work. -// The matching algorithm is based on RFC 6125. -bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname, - size_t hlen); +// Performs TLS hostname match. |pattern| can contain wildcard +// character '*', which matches prefix of target hostname. There are +// several restrictions to make wildcard work. The matching algorithm +// is based on RFC 6125. +bool tls_hostname_match(const StringRef &pattern, const StringRef &hostname); // Caches |session| which is associated to remote address |addr|. // |session| is serialized into ASN1 representation, and stored. |t| diff --git a/src/shrpx_ssl_test.cc b/src/shrpx_ssl_test.cc index 35167324..79129a5c 100644 --- a/src/shrpx_ssl_test.cc +++ b/src/shrpx_ssl_test.cc @@ -43,41 +43,40 @@ void test_shrpx_ssl_create_lookup_tree(void) { SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method())}; - const char *hostnames[] = { - "example.com", "www.example.org", "*www.example.org", "x*.host.domain", - "*yy.host.domain", "nghttp2.sourceforge.net", "sourceforge.net", - "sourceforge.net", // duplicate - "*.foo.bar", // oo.bar is suffix of *.foo.bar - "oo.bar"}; - int num = array_size(ctxs); - for (int i = 0; i < num; ++i) { - tree->add_cert(ctxs[i], hostnames[i], strlen(hostnames[i])); + constexpr StringRef hostnames[] = { + StringRef::from_lit("example.com"), + StringRef::from_lit("www.example.org"), + StringRef::from_lit("*www.example.org"), + StringRef::from_lit("x*.host.domain"), + StringRef::from_lit("*yy.host.domain"), + StringRef::from_lit("nghttp2.sourceforge.net"), + StringRef::from_lit("sourceforge.net"), + StringRef::from_lit("sourceforge.net"), // duplicate + StringRef::from_lit("*.foo.bar"), // oo.bar is suffix of *.foo.bar + StringRef::from_lit("oo.bar")}; + auto num = array_size(ctxs); + for (size_t i = 0; i < num; ++i) { + tree->add_cert(ctxs[i], hostnames[i].c_str(), hostnames[i].size()); } - CU_ASSERT(ctxs[0] == tree->lookup(hostnames[0], strlen(hostnames[0]))); - CU_ASSERT(ctxs[1] == tree->lookup(hostnames[1], strlen(hostnames[1]))); - const char h1[] = "2www.example.org"; - CU_ASSERT(ctxs[2] == tree->lookup(h1, strlen(h1))); - const char h2[] = "www2.example.org"; - CU_ASSERT(0 == tree->lookup(h2, strlen(h2))); - const char h3[] = "x1.host.domain"; - CU_ASSERT(ctxs[3] == tree->lookup(h3, strlen(h3))); + CU_ASSERT(ctxs[0] == tree->lookup(hostnames[0])); + CU_ASSERT(ctxs[1] == tree->lookup(hostnames[1])); + CU_ASSERT(ctxs[2] == tree->lookup(StringRef::from_lit("2www.example.org"))); + CU_ASSERT(nullptr == tree->lookup(StringRef::from_lit("www2.example.org"))); + CU_ASSERT(ctxs[3] == tree->lookup(StringRef::from_lit("x1.host.domain"))); // Does not match *yy.host.domain, because * must match at least 1 // character. - const char h4[] = "yy.Host.domain"; - CU_ASSERT(0 == tree->lookup(h4, strlen(h4))); - const char h5[] = "zyy.host.domain"; - CU_ASSERT(ctxs[4] == tree->lookup(h5, strlen(h5))); - CU_ASSERT(0 == tree->lookup("", 0)); - CU_ASSERT(ctxs[5] == tree->lookup(hostnames[5], strlen(hostnames[5]))); - CU_ASSERT(ctxs[6] == tree->lookup(hostnames[6], strlen(hostnames[6]))); - const char h6[] = "pdylay.sourceforge.net"; + CU_ASSERT(nullptr == tree->lookup(StringRef::from_lit("yy.Host.domain"))); + CU_ASSERT(ctxs[4] == tree->lookup(StringRef::from_lit("zyy.host.domain"))); + CU_ASSERT(nullptr == tree->lookup(StringRef{})); + CU_ASSERT(ctxs[5] == tree->lookup(hostnames[5])); + CU_ASSERT(ctxs[6] == tree->lookup(hostnames[6])); + constexpr char h6[] = "pdylay.sourceforge.net"; for (int i = 0; i < 7; ++i) { - CU_ASSERT(0 == tree->lookup(h6 + i, strlen(h6) - i)); + CU_ASSERT(0 == tree->lookup(StringRef{h6 + i, str_size(h6) - i})); } - const char h7[] = "x.foo.bar"; - CU_ASSERT(ctxs[8] == tree->lookup(h7, strlen(h7))); - CU_ASSERT(ctxs[9] == tree->lookup(hostnames[9], strlen(hostnames[9]))); + CU_ASSERT(ctxs[8] == tree->lookup(StringRef::from_lit("x.foo.bar"))); + CU_ASSERT(ctxs[9] == tree->lookup(hostnames[9])); for (int i = 0; i < num; ++i) { SSL_CTX_free(ctxs[i]); @@ -86,18 +85,20 @@ void test_shrpx_ssl_create_lookup_tree(void) { SSL_CTX *ctxs2[] = { SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method())}; - const char *names[] = {"rab", "zab", "zzub", "ab"}; + constexpr StringRef names[] = { + StringRef::from_lit("rab"), StringRef::from_lit("zab"), + StringRef::from_lit("zzub"), StringRef::from_lit("ab")}; num = array_size(ctxs2); tree = make_unique(); - for (int i = 0; i < num; ++i) { - tree->add_cert(ctxs2[i], names[i], strlen(names[i])); + for (size_t i = 0; i < num; ++i) { + tree->add_cert(ctxs2[i], names[i].c_str(), names[i].size()); } - for (int i = 0; i < num; ++i) { - CU_ASSERT(ctxs2[i] == tree->lookup(names[i], strlen(names[i]))); + for (size_t i = 0; i < num; ++i) { + CU_ASSERT(ctxs2[i] == tree->lookup(names[i])); } - for (int i = 0; i < num; ++i) { + for (size_t i = 0; i < num; ++i) { SSL_CTX_free(ctxs2[i]); } } @@ -109,8 +110,7 @@ void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void) { const char certfile[] = NGHTTP2_TESTS_DIR "/testdata/cacert.pem"; rv = ssl::cert_lookup_tree_add_cert_from_file(&tree, ssl_ctx, certfile); CU_ASSERT(0 == rv); - const char localhost[] = "localhost"; - CU_ASSERT(ssl_ctx == tree.lookup(localhost, sizeof(localhost) - 1)); + CU_ASSERT(ssl_ctx == tree.lookup(StringRef::from_lit("localhost"))); SSL_CTX_free(ssl_ctx); } @@ -118,7 +118,7 @@ void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void) { template bool tls_hostname_match_wrapper(const char(&pattern)[N], const char(&hostname)[M]) { - return ssl::tls_hostname_match(pattern, N, hostname, M); + return ssl::tls_hostname_match(StringRef{pattern, N}, StringRef{hostname, M}); } void test_shrpx_ssl_tls_hostname_match(void) { diff --git a/src/util.cc b/src/util.cc index b1234574..e761c37c 100644 --- a/src/util.cc +++ b/src/util.cc @@ -363,15 +363,6 @@ bool istarts_with(const char *a, const char *b) { return !*b; } -bool strieq(const char *a, const char *b) { - if (!a || !b) { - return false; - } - for (; *a && *b && lowcase(*a) == lowcase(*b); ++a, ++b) - ; - return !*a && !*b; -} - int strcompare(const char *a, const uint8_t *b, size_t bn) { assert(a && b); const uint8_t *blast = b + bn; diff --git a/src/util.h b/src/util.h index 4c554e0d..49d72dc0 100644 --- a/src/util.h +++ b/src/util.h @@ -216,6 +216,10 @@ inline bool istarts_with(const std::string &a, const std::string &b) { return istarts_with(std::begin(a), std::end(a), std::begin(b), std::end(b)); } +inline bool istarts_with(const StringRef &a, const StringRef &b) { + return istarts_with(std::begin(a), std::end(a), std::begin(b), std::end(b)); +} + template bool istarts_with(InputIt a, size_t an, const char *b) { return istarts_with(a, a + an, b, b + strlen(b)); @@ -264,6 +268,10 @@ inline bool iends_with(const std::string &a, const std::string &b) { return iends_with(std::begin(a), std::end(a), std::begin(b), std::end(b)); } +inline bool iends_with(const StringRef &a, const StringRef &b) { + return iends_with(std::begin(a), std::end(a), std::begin(b), std::end(b)); +} + template bool iends_with_l(const std::string &a, const CharT(&b)[N]) { return iends_with(std::begin(a), std::end(a), b, b + N - 1); @@ -276,24 +284,6 @@ bool iends_with_l(const StringRef &a, const CharT(&b)[N]) { int strcompare(const char *a, const uint8_t *b, size_t n); -template bool strieq(const char *a, InputIt b, size_t bn) { - if (!a) { - return false; - } - auto blast = b + bn; - for (; *a && b != blast && lowcase(*a) == lowcase(*b); ++a, ++b) - ; - return !*a && b == blast; -} - -template -bool strieq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) { - if (alen != blen) { - return false; - } - return std::equal(a, a + alen, b, CaseCmp()); -} - template bool strieq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) { if (std::distance(first1, last1) != std::distance(first2, last2)) { @@ -304,28 +294,26 @@ bool strieq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) { } inline bool strieq(const std::string &a, const std::string &b) { - return strieq(std::begin(a), a.size(), std::begin(b), b.size()); + return strieq(std::begin(a), std::end(a), std::begin(b), std::end(b)); } -bool strieq(const char *a, const char *b); - -inline bool strieq(const char *a, const std::string &b) { - return strieq(a, b.c_str(), b.size()); +inline bool strieq(const StringRef &a, const StringRef &b) { + return strieq(std::begin(a), std::end(a), std::begin(b), std::end(b)); } template bool strieq_l(const CharT(&a)[N], InputIt b, size_t blen) { - return strieq(a, N - 1, b, blen); + return strieq(a, a + (N - 1), b, b + blen); } template bool strieq_l(const CharT(&a)[N], const std::string &b) { - return strieq(a, N - 1, std::begin(b), b.size()); + return strieq(a, a + (N - 1), std::begin(b), std::end(b)); } template bool strieq_l(const CharT(&a)[N], const StringRef &b) { - return strieq(a, N - 1, std::begin(b), b.size()); + return strieq(a, a + (N - 1), std::begin(b), std::end(b)); } template bool streq(const char *a, InputIt b, size_t bn) { @@ -338,6 +326,14 @@ template bool streq(const char *a, InputIt b, size_t bn) { return !*a && b == blast; } +template +bool streq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) { + if (std::distance(first1, last1) != std::distance(first2, last2)) { + return false; + } + return std::equal(first1, last1, first2); +} + template bool streq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) { if (alen != blen) { @@ -353,6 +349,10 @@ inline bool streq(const char *a, const char *b) { return streq(a, strlen(a), b, strlen(b)); } +inline bool streq(const StringRef &a, const StringRef &b) { + return streq(std::begin(a), std::end(a), std::begin(b), std::end(b)); +} + template bool streq_l(const CharT(&a)[N], InputIt b, size_t blen) { return streq(a, N - 1, b, blen); diff --git a/src/util_test.cc b/src/util_test.cc index fac806be..f5c1c305 100644 --- a/src/util_test.cc +++ b/src/util_test.cc @@ -73,17 +73,15 @@ void test_util_strieq(void) { CU_ASSERT(!util::strieq(std::string("alpha"), std::string("AlPhA "))); CU_ASSERT(!util::strieq(std::string(), std::string("AlPhA "))); - CU_ASSERT(util::strieq("alpha", "alpha", 5)); - CU_ASSERT(util::strieq("alpha", "AlPhA", 5)); - CU_ASSERT(util::strieq("", static_cast(nullptr), 0)); - CU_ASSERT(!util::strieq("alpha", "AlPhA ", 6)); - CU_ASSERT(!util::strieq("", "AlPhA ", 6)); - - CU_ASSERT(util::strieq("alpha", "alpha")); - CU_ASSERT(util::strieq("alpha", "AlPhA")); - CU_ASSERT(util::strieq("", "")); - CU_ASSERT(!util::strieq("alpha", "AlPhA ")); - CU_ASSERT(!util::strieq("", "AlPhA ")); + CU_ASSERT( + util::strieq(StringRef::from_lit("alpha"), StringRef::from_lit("alpha"))); + CU_ASSERT( + util::strieq(StringRef::from_lit("alpha"), StringRef::from_lit("AlPhA"))); + CU_ASSERT(util::strieq(StringRef{}, StringRef{})); + CU_ASSERT(!util::strieq(StringRef::from_lit("alpha"), + StringRef::from_lit("AlPhA "))); + CU_ASSERT( + !util::strieq(StringRef::from_lit(""), StringRef::from_lit("AlPhA "))); CU_ASSERT(util::strieq_l("alpha", "alpha", 5)); CU_ASSERT(util::strieq_l("alpha", "AlPhA", 5)); @@ -455,27 +453,27 @@ void test_util_get_uint64(void) { } void test_util_parse_config_str_list(void) { - auto res = util::parse_config_str_list("a"); + auto res = util::parse_config_str_list(StringRef::from_lit("a")); CU_ASSERT(1 == res.size()); CU_ASSERT("a" == res[0]); - res = util::parse_config_str_list("a,"); + res = util::parse_config_str_list(StringRef::from_lit("a,")); CU_ASSERT(2 == res.size()); CU_ASSERT("a" == res[0]); CU_ASSERT("" == res[1]); - res = util::parse_config_str_list(":a::", ':'); + res = util::parse_config_str_list(StringRef::from_lit(":a::"), ':'); CU_ASSERT(4 == res.size()); CU_ASSERT("" == res[0]); CU_ASSERT("a" == res[1]); CU_ASSERT("" == res[2]); CU_ASSERT("" == res[3]); - res = util::parse_config_str_list(""); + res = util::parse_config_str_list(StringRef{}); CU_ASSERT(1 == res.size()); CU_ASSERT("" == res[0]); - res = util::parse_config_str_list("alpha,bravo,charlie"); + res = util::parse_config_str_list(StringRef::from_lit("alpha,bravo,charlie")); CU_ASSERT(3 == res.size()); CU_ASSERT("alpha" == res[0]); CU_ASSERT("bravo" == res[1]);