nghttpx: Remove strieq(const char*, cosnt char*) overload, and fix unittests

This commit is contained in:
Tatsuhiro Tsujikawa 2016-03-24 23:16:20 +09:00
parent 13596bde90
commit 372123c178
7 changed files with 173 additions and 173 deletions

View File

@ -37,40 +37,40 @@
namespace shrpx { namespace shrpx {
void test_shrpx_config_parse_header(void) { void test_shrpx_config_parse_header(void) {
auto p = parse_header("a: b"); auto p = parse_header(StringRef::from_lit("a: b"));
CU_ASSERT("a" == p.name); CU_ASSERT("a" == p.name);
CU_ASSERT("b" == p.value); CU_ASSERT("b" == p.value);
p = parse_header("a: b"); p = parse_header(StringRef::from_lit("a: b"));
CU_ASSERT("a" == p.name); CU_ASSERT("a" == p.name);
CU_ASSERT("b" == p.value); CU_ASSERT("b" == p.value);
p = parse_header(":a: b"); p = parse_header(StringRef::from_lit(":a: b"));
CU_ASSERT(p.name.empty()); CU_ASSERT(p.name.empty());
p = parse_header("a: :b"); p = parse_header(StringRef::from_lit("a: :b"));
CU_ASSERT("a" == p.name); CU_ASSERT("a" == p.name);
CU_ASSERT(":b" == p.value); CU_ASSERT(":b" == p.value);
p = parse_header(": b"); p = parse_header(StringRef::from_lit(": b"));
CU_ASSERT(p.name.empty()); CU_ASSERT(p.name.empty());
p = parse_header("alpha: bravo charlie"); p = parse_header(StringRef::from_lit("alpha: bravo charlie"));
CU_ASSERT("alpha" == p.name); CU_ASSERT("alpha" == p.name);
CU_ASSERT("bravo charlie" == p.value); CU_ASSERT("bravo charlie" == p.value);
p = parse_header("a,: b"); p = parse_header(StringRef::from_lit("a,: b"));
CU_ASSERT(p.name.empty()); CU_ASSERT(p.name.empty());
p = parse_header("a: b\x0a"); p = parse_header(StringRef::from_lit("a: b\x0a"));
CU_ASSERT(p.name.empty()); CU_ASSERT(p.name.empty());
} }
void test_shrpx_config_parse_log_format(void) { void test_shrpx_config_parse_log_format(void) {
auto res = auto res = parse_log_format(StringRef::from_lit(
parse_log_format(R"($remote_addr - $remote_user [$time_local] )" R"($remote_addr - $remote_user [$time_local] )"
R"("$request" $status $body_bytes_sent )" R"("$request" $status $body_bytes_sent )"
R"("${http_referer}" $http_host "$http_user_agent")"); R"("${http_referer}" $http_host "$http_user_agent")"));
CU_ASSERT(16 == res.size()); CU_ASSERT(16 == res.size());
CU_ASSERT(SHRPX_LOGF_REMOTE_ADDR == res[0].type); CU_ASSERT(SHRPX_LOGF_REMOTE_ADDR == res[0].type);
@ -115,35 +115,35 @@ void test_shrpx_config_parse_log_format(void) {
CU_ASSERT(SHRPX_LOGF_LITERAL == res[15].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[15].type);
CU_ASSERT("\"" == res[15].value); CU_ASSERT("\"" == res[15].value);
res = parse_log_format("$"); res = parse_log_format(StringRef::from_lit("$"));
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("$" == res[0].value); CU_ASSERT("$" == res[0].value);
res = parse_log_format("${"); res = parse_log_format(StringRef::from_lit("${"));
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("${" == res[0].value); CU_ASSERT("${" == res[0].value);
res = parse_log_format("${a"); res = parse_log_format(StringRef::from_lit("${a"));
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("${a" == res[0].value); CU_ASSERT("${a" == res[0].value);
res = parse_log_format("${a "); res = parse_log_format(StringRef::from_lit("${a "));
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("${a " == res[0].value); CU_ASSERT("${a " == res[0].value);
res = parse_log_format("$$remote_addr"); res = parse_log_format(StringRef::from_lit("$$remote_addr"));
CU_ASSERT(2 == res.size()); CU_ASSERT(2 == res.size());

View File

@ -145,7 +145,8 @@ int servername_callback(SSL *ssl, int *al, void *arg) {
if (cert_tree) { if (cert_tree) {
const char *hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); const char *hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
if (hostname) { if (hostname) {
auto ssl_ctx = cert_tree->lookup(hostname, strlen(hostname)); auto len = strlen(hostname);
auto ssl_ctx = cert_tree->lookup(StringRef{hostname, len});
if (ssl_ctx) { if (ssl_ctx) {
SSL_set_SSL_CTX(ssl, ssl_ctx); SSL_set_SSL_CTX(ssl, ssl_ctx);
} }
@ -820,53 +821,56 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
faddr); faddr);
} }
bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname, bool tls_hostname_match(const StringRef &pattern, const StringRef &hostname) {
size_t hlen) { auto ptWildcard = std::find(std::begin(pattern), std::end(pattern), '*');
auto pend = pattern + plen; if (ptWildcard == std::end(pattern)) {
auto ptWildcard = std::find(pattern, pend, '*'); return util::strieq(pattern, hostname);
if (ptWildcard == pend) {
return util::strieq(pattern, plen, hostname, hlen);
} }
auto ptLeftLabelEnd = std::find(pattern, pend, '.'); auto ptLeftLabelEnd = std::find(std::begin(pattern), std::end(pattern), '.');
auto wildcardEnabled = true; auto wildcardEnabled = true;
// Do case-insensitive match. At least 2 dots are required to enable // Do case-insensitive match. At least 2 dots are required to enable
// wildcard match. Also wildcard must be in the left-most label. // wildcard match. Also wildcard must be in the left-most label.
// Don't attempt to match a presented identifier where the wildcard // Don't attempt to match a presented identifier where the wildcard
// character is embedded within an A-label. // character is embedded within an A-label.
if (ptLeftLabelEnd == pend || if (ptLeftLabelEnd == std::end(pattern) ||
std::find(ptLeftLabelEnd + 1, pend, '.') == pend || std::find(ptLeftLabelEnd + 1, std::end(pattern), '.') ==
ptLeftLabelEnd < ptWildcard || std::end(pattern) ||
util::istarts_with(pattern, plen, "xn--")) { ptLeftLabelEnd < ptWildcard || util::istarts_with_l(pattern, "xn--")) {
wildcardEnabled = false; wildcardEnabled = false;
} }
if (!wildcardEnabled) { if (!wildcardEnabled) {
return util::strieq(pattern, plen, hostname, hlen); return util::strieq(pattern, hostname);
} }
auto hend = hostname + hlen; auto hnLeftLabelEnd =
auto hnLeftLabelEnd = std::find(hostname, hend, '.'); std::find(std::begin(hostname), std::end(hostname), '.');
if (hnLeftLabelEnd == hend || if (hnLeftLabelEnd == std::end(hostname) ||
!util::strieq(ptLeftLabelEnd, pend, hnLeftLabelEnd, hend)) { !util::strieq(StringRef{ptLeftLabelEnd, std::end(pattern)},
StringRef{hnLeftLabelEnd, std::end(hostname)})) {
return false; return false;
} }
// Perform wildcard match. Here '*' must match at least one // Perform wildcard match. Here '*' must match at least one
// character. // character.
if (hnLeftLabelEnd - hostname < ptLeftLabelEnd - pattern) { if (hnLeftLabelEnd - std::begin(hostname) <
ptLeftLabelEnd - std::begin(pattern)) {
return false; return false;
} }
return util::istarts_with(hostname, hnLeftLabelEnd, pattern, ptWildcard) && return util::istarts_with(StringRef{std::begin(hostname), hnLeftLabelEnd},
util::iends_with(hostname, hnLeftLabelEnd, ptWildcard + 1, StringRef{std::begin(pattern), ptWildcard}) &&
ptLeftLabelEnd); util::iends_with(StringRef{std::begin(hostname), hnLeftLabelEnd},
StringRef{ptWildcard + 1, ptLeftLabelEnd});
} }
namespace { namespace {
ssize_t get_common_name(unsigned char **out_ptr, X509 *cert) { // if return value is not empty, StringRef.c_str() must be freed using
// OPENSSL_free().
StringRef get_common_name(X509 *cert) {
auto subjectname = X509_get_subject_name(cert); auto subjectname = X509_get_subject_name(cert);
if (!subjectname) { if (!subjectname) {
LOG(WARN) << "Could not get X509 name object from the certificate."; LOG(WARN) << "Could not get X509 name object from the certificate.";
return -1; return StringRef{};
} }
int lastpos = -1; int lastpos = -1;
for (;;) { for (;;) {
@ -876,22 +880,29 @@ ssize_t get_common_name(unsigned char **out_ptr, X509 *cert) {
} }
auto entry = X509_NAME_get_entry(subjectname, lastpos); auto entry = X509_NAME_get_entry(subjectname, lastpos);
auto outlen = ASN1_STRING_to_UTF8(out_ptr, X509_NAME_ENTRY_get_data(entry)); unsigned char *p;
if (outlen < 0) { auto plen = ASN1_STRING_to_UTF8(&p, X509_NAME_ENTRY_get_data(entry));
if (plen < 0) {
continue; continue;
} }
if (std::find(*out_ptr, *out_ptr + outlen, '\0') != *out_ptr + outlen) { if (std::find(p, p + plen, '\0') != p + plen) {
// Embedded NULL is not permitted. // Embedded NULL is not permitted.
continue; continue;
} }
return outlen; if (plen == 0) {
LOG(WARN) << "X509 name is empty";
OPENSSL_free(p);
continue;
} }
return -1;
return StringRef{p, static_cast<size_t>(plen)};
}
return StringRef{};
} }
} // namespace } // namespace
namespace { namespace {
int verify_numeric_hostname(X509 *cert, const char *hostname, size_t hlen, int verify_numeric_hostname(X509 *cert, const StringRef &hostname,
const Address *addr) { const Address *addr) {
const void *saddr; const void *saddr;
switch (addr->su.storage.ss_family) { switch (addr->su.storage.ss_family) {
@ -928,15 +939,14 @@ int verify_numeric_hostname(X509 *cert, const char *hostname, size_t hlen,
} }
} }
unsigned char *cn; auto cn = get_common_name(cert);
auto cnlen = get_common_name(&cn, cert); if (cn.empty()) {
if (cnlen == -1) {
return -1; return -1;
} }
// cn is not NULL terminated // cn is not NULL terminated
auto rv = util::streq(hostname, hlen, cn, cnlen); auto rv = util::streq(hostname, cn);
OPENSSL_free(cn); OPENSSL_free(const_cast<char *>(cn.c_str()));
if (rv) { if (rv) {
return 0; return 0;
@ -947,10 +957,10 @@ int verify_numeric_hostname(X509 *cert, const char *hostname, size_t hlen,
} // namespace } // namespace
namespace { namespace {
int verify_hostname(X509 *cert, const char *hostname, size_t hlen, int verify_hostname(X509 *cert, const StringRef &hostname,
const Address *addr) { const Address *addr) {
if (util::numeric_host(hostname)) { if (util::numeric_host(hostname.c_str())) {
return verify_numeric_hostname(cert, hostname, hlen, addr); return verify_numeric_hostname(cert, hostname, addr);
} }
auto altnames = static_cast<GENERAL_NAMES *>( auto altnames = static_cast<GENERAL_NAMES *>(
@ -975,20 +985,20 @@ int verify_hostname(X509 *cert, const char *hostname, size_t hlen,
continue; continue;
} }
if (tls_hostname_match(name, len, hostname, hlen)) { if (tls_hostname_match(StringRef{name, static_cast<size_t>(len)},
hostname)) {
return 0; return 0;
} }
} }
} }
unsigned char *cn; auto cn = get_common_name(cert);
auto cnlen = get_common_name(&cn, cert); if (cn.empty()) {
if (cnlen == -1) {
return -1; return -1;
} }
auto rv = util::strieq(hostname, hlen, cn, cnlen); auto rv = util::strieq(hostname, cn);
OPENSSL_free(cn); OPENSSL_free(const_cast<char *>(cn.c_str()));
if (rv) { if (rv) {
return 0; return 0;
@ -1012,7 +1022,7 @@ int check_cert(SSL *ssl, const Address *addr, const StringRef &host) {
return -1; return -1;
} }
if (verify_hostname(cert, host.c_str(), host.size(), addr) != 0) { if (verify_hostname(cert, host, addr) != 0) {
LOG(ERROR) << "Certificate verification failed: hostname does not match"; LOG(ERROR) << "Certificate verification failed: hostname does not match";
return -1; return -1;
} }
@ -1138,8 +1148,8 @@ void CertLookupTree::add_cert(SSL_CTX *ssl_ctx, const char *hostname,
} }
namespace { namespace {
SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const char *hostname, SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const StringRef &hostname,
size_t len, int offset) { int offset) {
int i, j; int i, j;
for (i = node->first, j = offset; for (i = node->first, j = offset;
i > node->last && j >= 0 && node->str[i] == util::lowcase(hostname[j]); i > node->last && j >= 0 && node->str[i] == util::lowcase(hostname[j]);
@ -1160,23 +1170,26 @@ SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const char *hostname,
} }
for (const auto &wildcert : node->wildcard_certs) { for (const auto &wildcert : node->wildcard_certs) {
if (tls_hostname_match(wildcert.hostname, wildcert.hostnamelen, hostname, if (tls_hostname_match(StringRef{wildcert.hostname, wildcert.hostnamelen},
len)) { hostname)) {
return wildcert.ssl_ctx; return wildcert.ssl_ctx;
} }
} }
auto c = util::lowcase(hostname[j]); auto c = util::lowcase(hostname[j]);
for (const auto &next_node : node->next) { for (const auto &next_node : node->next) {
if (next_node->str[next_node->first] == c) { if (next_node->str[next_node->first] == c) {
return cert_lookup_tree_lookup(next_node.get(), hostname, len, j); return cert_lookup_tree_lookup(next_node.get(), hostname, j);
} }
} }
return nullptr; return nullptr;
} }
} // namespace } // namespace
SSL_CTX *CertLookupTree::lookup(const char *hostname, size_t len) { SSL_CTX *CertLookupTree::lookup(const StringRef &hostname) {
return cert_lookup_tree_lookup(&root_, hostname, len, len - 1); if (hostname.empty()) {
return nullptr;
}
return cert_lookup_tree_lookup(&root_, hostname, hostname.size() - 1);
} }
int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx, int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx,
@ -1225,15 +1238,14 @@ int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx,
} }
} }
unsigned char *cn; auto cn = get_common_name(cert);
auto cnlen = get_common_name(&cn, cert); if (cn.empty()) {
if (cnlen == -1) {
return 0; return 0;
} }
lt->add_cert(ssl_ctx, reinterpret_cast<char *>(cn), cnlen); lt->add_cert(ssl_ctx, cn.c_str(), cn.size());
OPENSSL_free(cn); OPENSSL_free(const_cast<char *>(cn.c_str()));
return 0; return 0;
} }

View File

@ -143,11 +143,11 @@ public:
// to the lookup tree. The |hostname| must be NULL-terminated. // to the lookup tree. The |hostname| must be NULL-terminated.
void add_cert(SSL_CTX *ssl_ctx, const char *hostname, size_t len); void add_cert(SSL_CTX *ssl_ctx, const char *hostname, size_t len);
// Looks up SSL_CTX using the given |hostname| with length |len|. // Looks up SSL_CTX using the given |hostname|. If more than one
// If more than one SSL_CTX which matches the query, it is undefined // SSL_CTX which matches the query, it is undefined which one is
// which one is returned. The |hostname| must be NULL-terminated. // returned. The |hostname| must be NULL-terminated. If no
// If no matching SSL_CTX found, returns NULL. // matching SSL_CTX found, returns NULL.
SSL_CTX *lookup(const char *hostname, size_t len); SSL_CTX *lookup(const StringRef &hostname);
private: private:
CertNode root_; CertNode root_;
@ -219,12 +219,11 @@ bool upstream_tls_enabled();
// Returns true if SSL/TLS is enabled on downstream // Returns true if SSL/TLS is enabled on downstream
bool downstream_tls_enabled(); bool downstream_tls_enabled();
// Performs TLS hostname match. |pattern| of length |plen| can // Performs TLS hostname match. |pattern| can contain wildcard
// contain wildcard character '*', which matches prefix of target // character '*', which matches prefix of target hostname. There are
// hostname. There are several restrictions to make wildcard work. // several restrictions to make wildcard work. The matching algorithm
// The matching algorithm is based on RFC 6125. // is based on RFC 6125.
bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname, bool tls_hostname_match(const StringRef &pattern, const StringRef &hostname);
size_t hlen);
// Caches |session| which is associated to remote address |addr|. // Caches |session| which is associated to remote address |addr|.
// |session| is serialized into ASN1 representation, and stored. |t| // |session| is serialized into ASN1 representation, and stored. |t|

View File

@ -43,41 +43,40 @@ void test_shrpx_ssl_create_lookup_tree(void) {
SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()),
SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method())}; SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method())};
const char *hostnames[] = { constexpr StringRef hostnames[] = {
"example.com", "www.example.org", "*www.example.org", "x*.host.domain", StringRef::from_lit("example.com"),
"*yy.host.domain", "nghttp2.sourceforge.net", "sourceforge.net", StringRef::from_lit("www.example.org"),
"sourceforge.net", // duplicate StringRef::from_lit("*www.example.org"),
"*.foo.bar", // oo.bar is suffix of *.foo.bar StringRef::from_lit("x*.host.domain"),
"oo.bar"}; StringRef::from_lit("*yy.host.domain"),
int num = array_size(ctxs); StringRef::from_lit("nghttp2.sourceforge.net"),
for (int i = 0; i < num; ++i) { StringRef::from_lit("sourceforge.net"),
tree->add_cert(ctxs[i], hostnames[i], strlen(hostnames[i])); StringRef::from_lit("sourceforge.net"), // duplicate
StringRef::from_lit("*.foo.bar"), // oo.bar is suffix of *.foo.bar
StringRef::from_lit("oo.bar")};
auto num = array_size(ctxs);
for (size_t i = 0; i < num; ++i) {
tree->add_cert(ctxs[i], hostnames[i].c_str(), hostnames[i].size());
} }
CU_ASSERT(ctxs[0] == tree->lookup(hostnames[0], strlen(hostnames[0]))); CU_ASSERT(ctxs[0] == tree->lookup(hostnames[0]));
CU_ASSERT(ctxs[1] == tree->lookup(hostnames[1], strlen(hostnames[1]))); CU_ASSERT(ctxs[1] == tree->lookup(hostnames[1]));
const char h1[] = "2www.example.org"; CU_ASSERT(ctxs[2] == tree->lookup(StringRef::from_lit("2www.example.org")));
CU_ASSERT(ctxs[2] == tree->lookup(h1, strlen(h1))); CU_ASSERT(nullptr == tree->lookup(StringRef::from_lit("www2.example.org")));
const char h2[] = "www2.example.org"; CU_ASSERT(ctxs[3] == tree->lookup(StringRef::from_lit("x1.host.domain")));
CU_ASSERT(0 == tree->lookup(h2, strlen(h2)));
const char h3[] = "x1.host.domain";
CU_ASSERT(ctxs[3] == tree->lookup(h3, strlen(h3)));
// Does not match *yy.host.domain, because * must match at least 1 // Does not match *yy.host.domain, because * must match at least 1
// character. // character.
const char h4[] = "yy.Host.domain"; CU_ASSERT(nullptr == tree->lookup(StringRef::from_lit("yy.Host.domain")));
CU_ASSERT(0 == tree->lookup(h4, strlen(h4))); CU_ASSERT(ctxs[4] == tree->lookup(StringRef::from_lit("zyy.host.domain")));
const char h5[] = "zyy.host.domain"; CU_ASSERT(nullptr == tree->lookup(StringRef{}));
CU_ASSERT(ctxs[4] == tree->lookup(h5, strlen(h5))); CU_ASSERT(ctxs[5] == tree->lookup(hostnames[5]));
CU_ASSERT(0 == tree->lookup("", 0)); CU_ASSERT(ctxs[6] == tree->lookup(hostnames[6]));
CU_ASSERT(ctxs[5] == tree->lookup(hostnames[5], strlen(hostnames[5]))); constexpr char h6[] = "pdylay.sourceforge.net";
CU_ASSERT(ctxs[6] == tree->lookup(hostnames[6], strlen(hostnames[6])));
const char h6[] = "pdylay.sourceforge.net";
for (int i = 0; i < 7; ++i) { for (int i = 0; i < 7; ++i) {
CU_ASSERT(0 == tree->lookup(h6 + i, strlen(h6) - i)); CU_ASSERT(0 == tree->lookup(StringRef{h6 + i, str_size(h6) - i}));
} }
const char h7[] = "x.foo.bar"; CU_ASSERT(ctxs[8] == tree->lookup(StringRef::from_lit("x.foo.bar")));
CU_ASSERT(ctxs[8] == tree->lookup(h7, strlen(h7))); CU_ASSERT(ctxs[9] == tree->lookup(hostnames[9]));
CU_ASSERT(ctxs[9] == tree->lookup(hostnames[9], strlen(hostnames[9])));
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
SSL_CTX_free(ctxs[i]); SSL_CTX_free(ctxs[i]);
@ -86,18 +85,20 @@ void test_shrpx_ssl_create_lookup_tree(void) {
SSL_CTX *ctxs2[] = { SSL_CTX *ctxs2[] = {
SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method()),
SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method())}; SSL_CTX_new(SSLv23_method()), SSL_CTX_new(SSLv23_method())};
const char *names[] = {"rab", "zab", "zzub", "ab"}; constexpr StringRef names[] = {
StringRef::from_lit("rab"), StringRef::from_lit("zab"),
StringRef::from_lit("zzub"), StringRef::from_lit("ab")};
num = array_size(ctxs2); num = array_size(ctxs2);
tree = make_unique<ssl::CertLookupTree>(); tree = make_unique<ssl::CertLookupTree>();
for (int i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
tree->add_cert(ctxs2[i], names[i], strlen(names[i])); tree->add_cert(ctxs2[i], names[i].c_str(), names[i].size());
} }
for (int i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
CU_ASSERT(ctxs2[i] == tree->lookup(names[i], strlen(names[i]))); CU_ASSERT(ctxs2[i] == tree->lookup(names[i]));
} }
for (int i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
SSL_CTX_free(ctxs2[i]); SSL_CTX_free(ctxs2[i]);
} }
} }
@ -109,8 +110,7 @@ void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void) {
const char certfile[] = NGHTTP2_TESTS_DIR "/testdata/cacert.pem"; const char certfile[] = NGHTTP2_TESTS_DIR "/testdata/cacert.pem";
rv = ssl::cert_lookup_tree_add_cert_from_file(&tree, ssl_ctx, certfile); rv = ssl::cert_lookup_tree_add_cert_from_file(&tree, ssl_ctx, certfile);
CU_ASSERT(0 == rv); CU_ASSERT(0 == rv);
const char localhost[] = "localhost"; CU_ASSERT(ssl_ctx == tree.lookup(StringRef::from_lit("localhost")));
CU_ASSERT(ssl_ctx == tree.lookup(localhost, sizeof(localhost) - 1));
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
} }
@ -118,7 +118,7 @@ void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void) {
template <size_t N, size_t M> template <size_t N, size_t M>
bool tls_hostname_match_wrapper(const char(&pattern)[N], bool tls_hostname_match_wrapper(const char(&pattern)[N],
const char(&hostname)[M]) { const char(&hostname)[M]) {
return ssl::tls_hostname_match(pattern, N, hostname, M); return ssl::tls_hostname_match(StringRef{pattern, N}, StringRef{hostname, M});
} }
void test_shrpx_ssl_tls_hostname_match(void) { void test_shrpx_ssl_tls_hostname_match(void) {

View File

@ -363,15 +363,6 @@ bool istarts_with(const char *a, const char *b) {
return !*b; return !*b;
} }
bool strieq(const char *a, const char *b) {
if (!a || !b) {
return false;
}
for (; *a && *b && lowcase(*a) == lowcase(*b); ++a, ++b)
;
return !*a && !*b;
}
int strcompare(const char *a, const uint8_t *b, size_t bn) { int strcompare(const char *a, const uint8_t *b, size_t bn) {
assert(a && b); assert(a && b);
const uint8_t *blast = b + bn; const uint8_t *blast = b + bn;

View File

@ -216,6 +216,10 @@ inline bool istarts_with(const std::string &a, const std::string &b) {
return istarts_with(std::begin(a), std::end(a), std::begin(b), std::end(b)); return istarts_with(std::begin(a), std::end(a), std::begin(b), std::end(b));
} }
inline bool istarts_with(const StringRef &a, const StringRef &b) {
return istarts_with(std::begin(a), std::end(a), std::begin(b), std::end(b));
}
template <typename InputIt> template <typename InputIt>
bool istarts_with(InputIt a, size_t an, const char *b) { bool istarts_with(InputIt a, size_t an, const char *b) {
return istarts_with(a, a + an, b, b + strlen(b)); return istarts_with(a, a + an, b, b + strlen(b));
@ -264,6 +268,10 @@ inline bool iends_with(const std::string &a, const std::string &b) {
return iends_with(std::begin(a), std::end(a), std::begin(b), std::end(b)); return iends_with(std::begin(a), std::end(a), std::begin(b), std::end(b));
} }
inline bool iends_with(const StringRef &a, const StringRef &b) {
return iends_with(std::begin(a), std::end(a), std::begin(b), std::end(b));
}
template <typename CharT, size_t N> template <typename CharT, size_t N>
bool iends_with_l(const std::string &a, const CharT(&b)[N]) { bool iends_with_l(const std::string &a, const CharT(&b)[N]) {
return iends_with(std::begin(a), std::end(a), b, b + N - 1); return iends_with(std::begin(a), std::end(a), b, b + N - 1);
@ -276,24 +284,6 @@ bool iends_with_l(const StringRef &a, const CharT(&b)[N]) {
int strcompare(const char *a, const uint8_t *b, size_t n); int strcompare(const char *a, const uint8_t *b, size_t n);
template <typename InputIt> bool strieq(const char *a, InputIt b, size_t bn) {
if (!a) {
return false;
}
auto blast = b + bn;
for (; *a && b != blast && lowcase(*a) == lowcase(*b); ++a, ++b)
;
return !*a && b == blast;
}
template <typename InputIt1, typename InputIt2>
bool strieq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) {
if (alen != blen) {
return false;
}
return std::equal(a, a + alen, b, CaseCmp());
}
template <typename InputIt1, typename InputIt2> template <typename InputIt1, typename InputIt2>
bool strieq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) { bool strieq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) {
if (std::distance(first1, last1) != std::distance(first2, last2)) { if (std::distance(first1, last1) != std::distance(first2, last2)) {
@ -304,28 +294,26 @@ bool strieq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) {
} }
inline bool strieq(const std::string &a, const std::string &b) { inline bool strieq(const std::string &a, const std::string &b) {
return strieq(std::begin(a), a.size(), std::begin(b), b.size()); return strieq(std::begin(a), std::end(a), std::begin(b), std::end(b));
} }
bool strieq(const char *a, const char *b); inline bool strieq(const StringRef &a, const StringRef &b) {
return strieq(std::begin(a), std::end(a), std::begin(b), std::end(b));
inline bool strieq(const char *a, const std::string &b) {
return strieq(a, b.c_str(), b.size());
} }
template <typename CharT, typename InputIt, size_t N> template <typename CharT, typename InputIt, size_t N>
bool strieq_l(const CharT(&a)[N], InputIt b, size_t blen) { bool strieq_l(const CharT(&a)[N], InputIt b, size_t blen) {
return strieq(a, N - 1, b, blen); return strieq(a, a + (N - 1), b, b + blen);
} }
template <typename CharT, size_t N> template <typename CharT, size_t N>
bool strieq_l(const CharT(&a)[N], const std::string &b) { bool strieq_l(const CharT(&a)[N], const std::string &b) {
return strieq(a, N - 1, std::begin(b), b.size()); return strieq(a, a + (N - 1), std::begin(b), std::end(b));
} }
template <typename CharT, size_t N> template <typename CharT, size_t N>
bool strieq_l(const CharT(&a)[N], const StringRef &b) { bool strieq_l(const CharT(&a)[N], const StringRef &b) {
return strieq(a, N - 1, std::begin(b), b.size()); return strieq(a, a + (N - 1), std::begin(b), std::end(b));
} }
template <typename InputIt> bool streq(const char *a, InputIt b, size_t bn) { template <typename InputIt> bool streq(const char *a, InputIt b, size_t bn) {
@ -338,6 +326,14 @@ template <typename InputIt> bool streq(const char *a, InputIt b, size_t bn) {
return !*a && b == blast; return !*a && b == blast;
} }
template <typename InputIt1, typename InputIt2>
bool streq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) {
if (std::distance(first1, last1) != std::distance(first2, last2)) {
return false;
}
return std::equal(first1, last1, first2);
}
template <typename InputIt1, typename InputIt2> template <typename InputIt1, typename InputIt2>
bool streq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) { bool streq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) {
if (alen != blen) { if (alen != blen) {
@ -353,6 +349,10 @@ inline bool streq(const char *a, const char *b) {
return streq(a, strlen(a), b, strlen(b)); return streq(a, strlen(a), b, strlen(b));
} }
inline bool streq(const StringRef &a, const StringRef &b) {
return streq(std::begin(a), std::end(a), std::begin(b), std::end(b));
}
template <typename CharT, typename InputIt, size_t N> template <typename CharT, typename InputIt, size_t N>
bool streq_l(const CharT(&a)[N], InputIt b, size_t blen) { bool streq_l(const CharT(&a)[N], InputIt b, size_t blen) {
return streq(a, N - 1, b, blen); return streq(a, N - 1, b, blen);

View File

@ -73,17 +73,15 @@ void test_util_strieq(void) {
CU_ASSERT(!util::strieq(std::string("alpha"), std::string("AlPhA "))); CU_ASSERT(!util::strieq(std::string("alpha"), std::string("AlPhA ")));
CU_ASSERT(!util::strieq(std::string(), std::string("AlPhA "))); CU_ASSERT(!util::strieq(std::string(), std::string("AlPhA ")));
CU_ASSERT(util::strieq("alpha", "alpha", 5)); CU_ASSERT(
CU_ASSERT(util::strieq("alpha", "AlPhA", 5)); util::strieq(StringRef::from_lit("alpha"), StringRef::from_lit("alpha")));
CU_ASSERT(util::strieq("", static_cast<const char *>(nullptr), 0)); CU_ASSERT(
CU_ASSERT(!util::strieq("alpha", "AlPhA ", 6)); util::strieq(StringRef::from_lit("alpha"), StringRef::from_lit("AlPhA")));
CU_ASSERT(!util::strieq("", "AlPhA ", 6)); CU_ASSERT(util::strieq(StringRef{}, StringRef{}));
CU_ASSERT(!util::strieq(StringRef::from_lit("alpha"),
CU_ASSERT(util::strieq("alpha", "alpha")); StringRef::from_lit("AlPhA ")));
CU_ASSERT(util::strieq("alpha", "AlPhA")); CU_ASSERT(
CU_ASSERT(util::strieq("", "")); !util::strieq(StringRef::from_lit(""), StringRef::from_lit("AlPhA ")));
CU_ASSERT(!util::strieq("alpha", "AlPhA "));
CU_ASSERT(!util::strieq("", "AlPhA "));
CU_ASSERT(util::strieq_l("alpha", "alpha", 5)); CU_ASSERT(util::strieq_l("alpha", "alpha", 5));
CU_ASSERT(util::strieq_l("alpha", "AlPhA", 5)); CU_ASSERT(util::strieq_l("alpha", "AlPhA", 5));
@ -455,27 +453,27 @@ void test_util_get_uint64(void) {
} }
void test_util_parse_config_str_list(void) { void test_util_parse_config_str_list(void) {
auto res = util::parse_config_str_list("a"); auto res = util::parse_config_str_list(StringRef::from_lit("a"));
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT("a" == res[0]); CU_ASSERT("a" == res[0]);
res = util::parse_config_str_list("a,"); res = util::parse_config_str_list(StringRef::from_lit("a,"));
CU_ASSERT(2 == res.size()); CU_ASSERT(2 == res.size());
CU_ASSERT("a" == res[0]); CU_ASSERT("a" == res[0]);
CU_ASSERT("" == res[1]); CU_ASSERT("" == res[1]);
res = util::parse_config_str_list(":a::", ':'); res = util::parse_config_str_list(StringRef::from_lit(":a::"), ':');
CU_ASSERT(4 == res.size()); CU_ASSERT(4 == res.size());
CU_ASSERT("" == res[0]); CU_ASSERT("" == res[0]);
CU_ASSERT("a" == res[1]); CU_ASSERT("a" == res[1]);
CU_ASSERT("" == res[2]); CU_ASSERT("" == res[2]);
CU_ASSERT("" == res[3]); CU_ASSERT("" == res[3]);
res = util::parse_config_str_list(""); res = util::parse_config_str_list(StringRef{});
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT("" == res[0]); CU_ASSERT("" == res[0]);
res = util::parse_config_str_list("alpha,bravo,charlie"); res = util::parse_config_str_list(StringRef::from_lit("alpha,bravo,charlie"));
CU_ASSERT(3 == res.size()); CU_ASSERT(3 == res.size());
CU_ASSERT("alpha" == res[0]); CU_ASSERT("alpha" == res[0]);
CU_ASSERT("bravo" == res[1]); CU_ASSERT("bravo" == res[1]);