shrpx: Use patricia trie for cert lookup

This commit is contained in:
Tatsuhiro Tsujikawa 2013-02-14 00:28:55 +09:00
parent e322af8a6f
commit 291cbc639b
4 changed files with 177 additions and 82 deletions

View File

@ -513,7 +513,8 @@ CertLookupTree* cert_lookup_tree_new()
CertLookupTree *tree = new CertLookupTree(); CertLookupTree *tree = new CertLookupTree();
CertNode *root = new CertNode(); CertNode *root = new CertNode();
root->ssl_ctx = 0; root->ssl_ctx = 0;
root->c = 0; root->str = 0;
root->first = root->last = 0;
tree->root = root; tree->root = root;
return tree; return tree;
} }
@ -525,11 +526,6 @@ void cert_node_del(CertNode *node)
eoi = node->next.end(); i != eoi; ++i) { eoi = node->next.end(); i != eoi; ++i) {
cert_node_del(*i); cert_node_del(*i);
} }
for(std::vector<std::pair<char*, SSL_CTX*> >::iterator i =
node->wildcard_certs.begin(), eoi = node->wildcard_certs.end();
i != eoi; ++i) {
delete [] (*i).first;
}
delete node; delete node;
} }
} // namespace } // namespace
@ -537,58 +533,85 @@ void cert_node_del(CertNode *node)
void cert_lookup_tree_del(CertLookupTree *lt) void cert_lookup_tree_del(CertLookupTree *lt)
{ {
cert_node_del(lt->root); cert_node_del(lt->root);
for(std::vector<char*>::iterator i = lt->hosts.begin(),
eoi = lt->hosts.end(); i != eoi; ++i) {
delete [] *i;
}
delete lt; delete lt;
} }
namespace { namespace {
// The |offset| is the index in the hostname we are examining. // The |offset| is the index in the hostname we are examining. We are
// going to scan from |offset| in backwards.
void cert_lookup_tree_add_cert(CertLookupTree *lt, CertNode *node, void cert_lookup_tree_add_cert(CertLookupTree *lt, CertNode *node,
SSL_CTX *ssl_ctx, SSL_CTX *ssl_ctx,
const char *hostname, size_t len, int offset) char *hostname, size_t len, int offset)
{ {
if(offset == -1) {
if(!node->ssl_ctx) {
node->ssl_ctx = ssl_ctx;
}
return;
}
int i, next_len = node->next.size(); int i, next_len = node->next.size();
char c = util::lowcase(hostname[offset]); char c = hostname[offset];
CertNode *cn = 0;
for(i = 0; i < next_len; ++i) { for(i = 0; i < next_len; ++i) {
if(node->next[i]->c == c) { cn = node->next[i];
if(cn->str[cn->first] == c) {
break; break;
} }
} }
if(i == next_len) { if(i == next_len) {
CertNode *parent = node; if(c == '*') {
int j;
for(j = offset; j >= 0; --j) {
if(hostname[j] == '*') {
// We assume hostname as wildcard hostname when first '*' is // We assume hostname as wildcard hostname when first '*' is
// encountered. Note that as per RFC 6125 (6.4.3), there are // encountered. Note that as per RFC 6125 (6.4.3), there are
// some restrictions for wildcard hostname. We just ignore // some restrictions for wildcard hostname. We just ignore
// these rules here but do the proper check when we do the // these rules here but do the proper check when we do the
// match. // match.
char *hostcopy = strdup(hostname); node->wildcard_certs.push_back(std::make_pair(hostname, ssl_ctx));
for(int k = 0; hostcopy[k]; ++k) { } else {
hostcopy[k] = util::lowcase(hostcopy[k]); int j;
}
parent->wildcard_certs.push_back(std::make_pair(hostcopy, ssl_ctx));
break;
}
CertNode *new_node = new CertNode(); CertNode *new_node = new CertNode();
new_node->ssl_ctx = 0; new_node->str = hostname;
new_node->c = util::lowcase(hostname[j]); new_node->first = offset;
parent->next.push_back(new_node); // If wildcard is found, set the region before it because we
parent = new_node; // don't include it in [first, last).
} for(j = offset; j >= 0 && hostname[j] != '*'; --j);
new_node->last = j;
if(j == -1) { if(j == -1) {
// non-wildcard hostname, exact match case. new_node->ssl_ctx = ssl_ctx;
parent->ssl_ctx = ssl_ctx; } else {
new_node->ssl_ctx = 0;
new_node->wildcard_certs.push_back(std::make_pair(hostname, ssl_ctx));
}
node->next.push_back(new_node);
} }
} else { } else {
cert_lookup_tree_add_cert(lt, node->next[i], ssl_ctx, int j;
hostname, len, offset-1); for(i = cn->first, j = offset; i > cn->last && j >= 0 &&
cn->str[i] == hostname[j]; --i, --j);
if(i == cn->last) {
if(j == -1) {
// same hostname, we don't overwrite exiting ssl_ctx
} else {
// The existing hostname is a suffix of this hostname.
// Continue matching at potion j.
cert_lookup_tree_add_cert(lt, cn, ssl_ctx, hostname, len, j);
}
} else {
CertNode *new_node = new CertNode();
new_node->ssl_ctx = cn->ssl_ctx;
new_node->str = cn->str;
new_node->first = i;
new_node->last = cn->last;
new_node->wildcard_certs.swap(cn->wildcard_certs);
cn->next.push_back(new_node);
cn->last = i;
if(j == -1) {
// This hostname is a suffix of the existing hostname.
cn->ssl_ctx = ssl_ctx;
} else {
// This hostname and existing one share suffix.
cn->ssl_ctx = 0;
cert_lookup_tree_add_cert(lt, cn, ssl_ctx, hostname, len, j);
}
}
} }
} }
} // namespace } // namespace
@ -599,23 +622,34 @@ void cert_lookup_tree_add_cert(CertLookupTree *lt, SSL_CTX *ssl_ctx,
if(len == 0) { if(len == 0) {
return; return;
} }
cert_lookup_tree_add_cert(lt, lt->root, ssl_ctx, hostname, len, len-1); // Copy hostname including terminal NULL
char *host_copy = new char[len + 1];
for(size_t i = 0; i < len; ++i) {
host_copy[i] = util::lowcase(hostname[i]);
}
host_copy[len] = '\0';
lt->hosts.push_back(host_copy);
cert_lookup_tree_add_cert(lt, lt->root, ssl_ctx, host_copy, len, len-1);
} }
namespace { namespace {
SSL_CTX* cert_lookup_tree_lookup(CertLookupTree *lt, CertNode *node, SSL_CTX* cert_lookup_tree_lookup(CertLookupTree *lt, CertNode *node,
const char *hostname, size_t len, int offset) const char *hostname, size_t len, int offset)
{ {
if(offset == -1) { int i, j;
for(i = node->first, j = offset; i > node->last && j >= 0 &&
node->str[i] == util::lowcase(hostname[j]); --i, --j);
if(i == node->last) {
if(j == -1) {
if(node->ssl_ctx) { if(node->ssl_ctx) {
// exact match
return node->ssl_ctx; return node->ssl_ctx;
} else { } else {
// Do not perform wildcard-match because '*' must match at least // Do not perform wildcard-match because '*' must match at least
// one character. // one character.
return 0; return 0;
} }
} } else {
for(std::vector<std::pair<char*, SSL_CTX*> >::iterator i = for(std::vector<std::pair<char*, SSL_CTX*> >::iterator i =
node->wildcard_certs.begin(), eoi = node->wildcard_certs.end(); node->wildcard_certs.begin(), eoi = node->wildcard_certs.end();
i != eoi; ++i) { i != eoi; ++i) {
@ -623,14 +657,18 @@ SSL_CTX* cert_lookup_tree_lookup(CertLookupTree *lt, CertNode *node,
return (*i).second; return (*i).second;
} }
} }
char c = util::lowcase(hostname[offset]); char c = util::lowcase(hostname[j]);
for(std::vector<CertNode*>::iterator i = node->next.begin(), for(std::vector<CertNode*>::iterator i = node->next.begin(),
eoi = node->next.end(); i != eoi; ++i) { eoi = node->next.end(); i != eoi; ++i) {
if((*i)->c == c) { if((*i)->str[(*i)->first] == c) {
return cert_lookup_tree_lookup(lt, *i, hostname, len, offset-1); return cert_lookup_tree_lookup(lt, *i, hostname, len, j);
} }
} }
return 0; return 0;
}
} else {
return 0;
}
} }
} // namespace } // namespace

View File

@ -65,14 +65,15 @@ void get_altnames(X509 *cert,
std::string& common_name); std::string& common_name);
// CertLookupTree forms lookup tree to get SSL_CTX whose DNS or // CertLookupTree forms lookup tree to get SSL_CTX whose DNS or
// commonName matches hostname in query. The tree is trie data // commonName matches hostname in query. The tree is patricia trie
// structure form from the tail of the hostname pattern. Each CertNode // data structure formed from the tail of the hostname pattern. Each
// contains one ASCII character in the c member and the next member // CertNode contains part of hostname str member in range [first,
// contains the following CertNode pointers ('following' means // last) member and the next member contains the following CertNode
// character before the current one). The CertNode where a hostname // pointers ('following' means character before the current one). The
// pattern ends contains its SSL_CTX pointer in the ssl_ctx member. // CertNode where a hostname pattern ends contains its SSL_CTX pointer
// For wildcard hostname pattern, we store the its pattern and SSL_CTX // in the ssl_ctx member. For wildcard hostname pattern, we store the
// in CertNode one before first "*" found from the tail. // its pattern and SSL_CTX in CertNode one before first "*" found from
// the tail.
// //
// When querying SSL_CTX with particular hostname, we match from its // When querying SSL_CTX with particular hostname, we match from its
// tail in our lookup tree. If the query goes to the first character // tail in our lookup tree. If the query goes to the first character
@ -89,14 +90,17 @@ struct CertNode {
// list of wildcard domain name and its SSL_CTX pair, the wildcard // list of wildcard domain name and its SSL_CTX pair, the wildcard
// '*' appears in this position. // '*' appears in this position.
std::vector<std::pair<char*, SSL_CTX*> > wildcard_certs; std::vector<std::pair<char*, SSL_CTX*> > wildcard_certs;
// ASCII byte in this position
char c;
// Next CertNode index of CertLookupTree::nodes // Next CertNode index of CertLookupTree::nodes
std::vector<CertNode*> next; std::vector<CertNode*> next;
char *str;
// [first, last) in the reverse direction in str, first >=
// last. This indices only work for str member.
int first, last;
}; };
struct CertLookupTree { struct CertLookupTree {
std::vector<SSL_CTX*> certs; std::vector<SSL_CTX*> certs;
std::vector<char*> hosts;
CertNode *root; CertNode *root;
}; };

View File

@ -34,15 +34,27 @@ void test_shrpx_ssl_create_lookup_tree(void)
{ {
ssl::CertLookupTree* tree = ssl::cert_lookup_tree_new(); ssl::CertLookupTree* tree = ssl::cert_lookup_tree_new();
SSL_CTX *ctxs[] = {SSL_CTX_new(TLSv1_method()), SSL_CTX *ctxs[] = {SSL_CTX_new(TLSv1_method()),
SSL_CTX_new(TLSv1_method()),
SSL_CTX_new(TLSv1_method()),
SSL_CTX_new(TLSv1_method()),
SSL_CTX_new(TLSv1_method()),
SSL_CTX_new(TLSv1_method()),
SSL_CTX_new(TLSv1_method()), SSL_CTX_new(TLSv1_method()),
SSL_CTX_new(TLSv1_method()), SSL_CTX_new(TLSv1_method()),
SSL_CTX_new(TLSv1_method()), SSL_CTX_new(TLSv1_method()),
SSL_CTX_new(TLSv1_method())}; SSL_CTX_new(TLSv1_method())};
const char *hostnames[] = { "example.com", const char *hostnames[] = { "example.com",
"www.example.org", "www.example.org",
"*www.example.org", "*www.example.org",
"x*.host.domain", "x*.host.domain",
"*yy.host.domain"}; "*yy.host.domain",
"spdylay.sourceforge.net",
"sourceforge.net",
"sourceforge.net", // duplicate
"*.foo.bar", // oo.bar is suffix of *.foo.bar
"oo.bar"
};
int num = sizeof(ctxs)/sizeof(ctxs[0]); int num = sizeof(ctxs)/sizeof(ctxs[0]);
for(int i = 0; i < num; ++i) { for(int i = 0; i < num; ++i) {
ssl::cert_lookup_tree_add_cert(tree, ctxs[i], hostnames[i], ssl::cert_lookup_tree_add_cert(tree, ctxs[i], hostnames[i],
@ -61,12 +73,23 @@ void test_shrpx_ssl_create_lookup_tree(void)
CU_ASSERT(ctxs[3] == ssl::cert_lookup_tree_lookup(tree, h3, strlen(h3))); CU_ASSERT(ctxs[3] == ssl::cert_lookup_tree_lookup(tree, h3, strlen(h3)));
// Does not match *yy.host.domain, because * must match at least 1 // Does not match *yy.host.domain, because * must match at least 1
// character. // character.
const char h4[] = "yy.host.domain"; const char h4[] = "yy.Host.domain";
CU_ASSERT(0 == ssl::cert_lookup_tree_lookup(tree, h4, strlen(h4))); CU_ASSERT(0 == ssl::cert_lookup_tree_lookup(tree, h4, strlen(h4)));
const char h5[] = "zyy.host.domain"; const char h5[] = "zyy.host.domain";
CU_ASSERT(ctxs[4] == ssl::cert_lookup_tree_lookup(tree, h5, strlen(h5))); CU_ASSERT(ctxs[4] == ssl::cert_lookup_tree_lookup(tree, h5, strlen(h5)));
CU_ASSERT(0 == ssl::cert_lookup_tree_lookup(tree, "", 0)); CU_ASSERT(0 == ssl::cert_lookup_tree_lookup(tree, "", 0));
CU_ASSERT(ctxs[5] == ssl::cert_lookup_tree_lookup(tree, hostnames[5],
strlen(hostnames[5])));
CU_ASSERT(ctxs[6] == ssl::cert_lookup_tree_lookup(tree, hostnames[6],
strlen(hostnames[6])));
const char h6[] = "pdylay.sourceforge.net";
for(int i = 0; i < 7; ++i) {
CU_ASSERT(0 == ssl::cert_lookup_tree_lookup(tree, h6 + i, strlen(h6) - i));
}
const char h7[] = "x.foo.bar";
CU_ASSERT(ctxs[8] == ssl::cert_lookup_tree_lookup(tree, h7, strlen(h7)));
CU_ASSERT(ctxs[9] == ssl::cert_lookup_tree_lookup(tree, hostnames[9],
strlen(hostnames[9])));
ssl::cert_lookup_tree_del(tree); ssl::cert_lookup_tree_del(tree);
for(int i = 0; i < num; ++i) { for(int i = 0; i < num; ++i) {
SSL_CTX_free(ctxs[i]); SSL_CTX_free(ctxs[i]);

View File

@ -300,11 +300,41 @@ char lowcase(char c);
inline char lowcase(char c) inline char lowcase(char c)
{ {
if('A' <= c && c <= 'Z') { static unsigned char tbl[] = {
return c-'A'+'a'; 0, 1, 2, 3, 4, 5, 6, 7,
} else { 8, 9, 10, 11, 12, 13, 14, 15,
return c; 16, 17, 18, 19, 20, 21, 22, 23,
} 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63,
64, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
'p', 'q', 'r', 's', 't', 'u', 'v', 'w',
'x', 'y', 'z', 91, 92, 93, 94, 95,
96, 97, 98, 99, 100, 101, 102, 103,
104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119,
120, 121, 122, 123, 124, 125, 126, 127,
128, 129, 130, 131, 132, 133, 134, 135,
136, 137, 138, 139, 140, 141, 142, 143,
144, 145, 146, 147, 148, 149, 150, 151,
152, 153, 154, 155, 156, 157, 158, 159,
160, 161, 162, 163, 164, 165, 166, 167,
168, 169, 170, 171, 172, 173, 174, 175,
176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191,
192, 193, 194, 195, 196, 197, 198, 199,
200, 201, 202, 203, 204, 205, 206, 207,
208, 209, 210, 211, 212, 213, 214, 215,
216, 217, 218, 219, 220, 221, 222, 223,
224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247,
248, 249, 250, 251, 252, 253, 254, 255,
};
return tbl[static_cast<unsigned char>(c)];
} }
template<typename T> template<typename T>