diff --git a/src/shrpx_config.cc b/src/shrpx_config.cc index 99f1983d..3fae2d29 100644 --- a/src/shrpx_config.cc +++ b/src/shrpx_config.cc @@ -283,13 +283,12 @@ std::vector> split_config_str_list(const char *s, return list; } -std::vector> parse_config_str_list(const char *s, - char delim) { +std::vector parse_config_str_list(const char *s, char delim) { auto ranges = split_config_str_list(s, delim); - auto res = std::vector>(); + auto res = std::vector(); res.reserve(ranges.size()); for (const auto &range : ranges) { - res.push_back(strcopy(range.first, range.second)); + res.emplace_back(range.first, range.second); } return res; } @@ -1687,30 +1686,28 @@ int parse_config(const char *opt, const char *optarg, int port; - if (parse_uint(&port, opt, tokens[1].get()) != 0) { + if (parse_uint(&port, opt, tokens[1].c_str()) != 0) { return -1; } if (port < 1 || port > static_cast(std::numeric_limits::max())) { - LOG(ERROR) << opt << ": port is invalid: " << tokens[1].get(); + LOG(ERROR) << opt << ": port is invalid: " << tokens[1]; return -1; } AltSvc altsvc; - altsvc.port = port; - altsvc.protocol_id = std::move(tokens[0]); - altsvc.protocol_id_len = strlen(altsvc.protocol_id.get()); + + altsvc.port = port; + altsvc.service = std::move(tokens[1]); if (tokens.size() > 2) { altsvc.host = std::move(tokens[2]); - altsvc.host_len = strlen(altsvc.host.get()); if (tokens.size() > 3) { altsvc.origin = std::move(tokens[3]); - altsvc.origin_len = strlen(altsvc.origin.get()); } } diff --git a/src/shrpx_config.h b/src/shrpx_config.h index 68840c1d..2ad0297e 100644 --- a/src/shrpx_config.h +++ b/src/shrpx_config.h @@ -184,15 +184,9 @@ union sockaddr_union { enum shrpx_proto { PROTO_HTTP2, PROTO_HTTP }; struct AltSvc { - AltSvc() : protocol_id_len(0), host_len(0), origin_len(0), port(0) {} + AltSvc() : port(0) {} - std::unique_ptr protocol_id; - std::unique_ptr host; - std::unique_ptr origin; - - size_t protocol_id_len; - size_t host_len; - size_t origin_len; + std::string protocol_id, host, origin, service; uint16_t port; }; @@ -251,6 +245,11 @@ struct Config { std::vector accesslog_format; std::vector downstream_addr_groups; std::vector tls_ticket_key_files; + // list of supported NPN/ALPN protocol strings in the order of + // preference. + std::vector npn_list; + // list of supported SSL/TLS protocol strings. + std::vector tls_proto_list; // binary form of http proxy host and port sockaddr_union downstream_http_proxy_addr; ev_tstamp http2_upstream_read_timeout; @@ -286,13 +285,6 @@ struct Config { // ev_token_bucket_cfg *rate_limit_cfg; // // Rate limit configuration per worker (thread) // ev_token_bucket_cfg *worker_rate_limit_cfg; - // list of supported NPN/ALPN protocol strings in the order of - // preference. The each element of this list is a NULL-terminated - // string. - std::vector> npn_list; - // list of supported SSL/TLS protocol strings. The each element of - // this list is a NULL-terminated string. - std::vector> tls_proto_list; // Path to file containing CA certificate solely used for client // certificate validation std::unique_ptr verify_client_cacert; @@ -413,8 +405,7 @@ template using Range = std::pair; // Parses delimited strings in |s| and returns the array of substring, // delimited by |delim|. The any white spaces around substring are // treated as a part of substring. -std::vector> parse_config_str_list(const char *s, - char delim = ','); +std::vector parse_config_str_list(const char *s, char delim = ','); // Parses delimited strings in |s| and returns the array of pointers, // each element points to the beginning and one beyond last of diff --git a/src/shrpx_config_test.cc b/src/shrpx_config_test.cc index decbb50c..8ccd458c 100644 --- a/src/shrpx_config_test.cc +++ b/src/shrpx_config_test.cc @@ -39,29 +39,29 @@ namespace shrpx { void test_shrpx_config_parse_config_str_list(void) { auto res = parse_config_str_list("a"); CU_ASSERT(1 == res.size()); - CU_ASSERT(0 == strcmp("a", res[0].get())); + CU_ASSERT("a" == res[0]); res = parse_config_str_list("a,"); CU_ASSERT(2 == res.size()); - CU_ASSERT(0 == strcmp("a", res[0].get())); - CU_ASSERT(0 == strcmp("", res[1].get())); + CU_ASSERT("a" == res[0]); + CU_ASSERT("" == res[1]); res = parse_config_str_list(":a::", ':'); CU_ASSERT(4 == res.size()); - CU_ASSERT(0 == strcmp("", res[0].get())); - CU_ASSERT(0 == strcmp("a", res[1].get())); - CU_ASSERT(0 == strcmp("", res[2].get())); - CU_ASSERT(0 == strcmp("", res[3].get())); + CU_ASSERT("" == res[0]); + CU_ASSERT("a" == res[1]); + CU_ASSERT("" == res[2]); + CU_ASSERT("" == res[3]); res = parse_config_str_list(""); CU_ASSERT(1 == res.size()); - CU_ASSERT(0 == strcmp("", res[0].get())); + CU_ASSERT("" == res[0]); res = parse_config_str_list("alpha,bravo,charlie"); CU_ASSERT(3 == res.size()); - CU_ASSERT(0 == strcmp("alpha", res[0].get())); - CU_ASSERT(0 == strcmp("bravo", res[1].get())); - CU_ASSERT(0 == strcmp("charlie", res[2].get())); + CU_ASSERT("alpha" == res[0]); + CU_ASSERT("bravo" == res[1]); + CU_ASSERT("charlie" == res[2]); } void test_shrpx_config_parse_header(void) { diff --git a/src/shrpx_https_upstream.cc b/src/shrpx_https_upstream.cc index ef2659f8..4893c9fd 100644 --- a/src/shrpx_https_upstream.cc +++ b/src/shrpx_https_upstream.cc @@ -844,13 +844,12 @@ int HttpsUpstream::on_downstream_header_complete(Downstream *downstream) { if (!get_config()->altsvcs.empty()) { hdrs += "Alt-Svc: "; - for (auto &altsvc : get_config()->altsvcs) { - hdrs += util::percent_encode_token(altsvc.protocol_id.get()); + for (const auto &altsvc : get_config()->altsvcs) { + hdrs += util::percent_encode_token(altsvc.protocol_id); hdrs += "=\""; - hdrs += - util::quote_string(std::string(altsvc.host.get(), altsvc.host_len)); + hdrs += util::quote_string(altsvc.host); hdrs += ":"; - hdrs += util::utos(altsvc.port); + hdrs += altsvc.service; hdrs += "\", "; } diff --git a/src/shrpx_ssl.cc b/src/shrpx_ssl.cc index 10981803..ea5a1cc1 100644 --- a/src/shrpx_ssl.cc +++ b/src/shrpx_ssl.cc @@ -87,18 +87,16 @@ int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { } // namespace std::vector -set_alpn_prefs(const std::vector> &protos) { +set_alpn_prefs(const std::vector &protos) { size_t len = 0; - for (auto &proto : protos) { - auto n = strlen(proto.get()); - - if (n > 255) { - LOG(FATAL) << "Too long ALPN identifier: " << n; + for (const auto &proto : protos) { + if (proto.size() > 255) { + LOG(FATAL) << "Too long ALPN identifier: " << proto.size(); DIE(); } - len += 1 + n; + len += 1 + proto.size(); } if (len > (1 << 16) - 1) { @@ -109,12 +107,10 @@ set_alpn_prefs(const std::vector> &protos) { auto out = std::vector(len); auto ptr = out.data(); - for (auto &proto : protos) { - auto proto_len = strlen(proto.get()); - - *ptr++ = proto_len; - memcpy(ptr, proto.get(), proto_len); - ptr += proto_len; + for (const auto &proto : protos) { + *ptr++ = proto.size(); + memcpy(ptr, proto.c_str(), proto.size()); + ptr += proto.size(); } return out; @@ -282,15 +278,14 @@ int alpn_select_proto_cb(SSL *ssl, const unsigned char **out, // We assume that get_config()->npn_list contains ALPN protocol // identifier sorted by preference order. So we just break when we // found the first overlap. - for (auto &target_proto_id : get_config()->npn_list) { - auto target_proto_len = strlen(target_proto_id.get()); - + for (const auto &target_proto_id : get_config()->npn_list) { for (auto p = in, end = in + inlen; p < end;) { auto proto_id = p + 1; auto proto_len = *p; - if (proto_id + proto_len <= end && target_proto_len == proto_len && - memcmp(target_proto_id.get(), proto_id, proto_len) == 0) { + if (proto_id + proto_len <= end && + util::streq(target_proto_id.c_str(), target_proto_id.size(), proto_id, + proto_len)) { *out = reinterpret_cast(proto_id); *outlen = proto_len; @@ -314,14 +309,13 @@ constexpr long int tls_masks[] = {SSL_OP_NO_TLSv1_2, SSL_OP_NO_TLSv1_1, SSL_OP_NO_TLSv1}; } // namespace -long int create_tls_proto_mask( - const std::vector> &tls_proto_list) { +long int create_tls_proto_mask(const std::vector &tls_proto_list) { long int res = 0; for (size_t i = 0; i < tls_namelen; ++i) { size_t j; for (j = 0; j < tls_proto_list.size(); ++j) { - if (util::strieq(tls_names[i], tls_proto_list[j].get())) { + if (util::strieq(tls_names[i], tls_proto_list[j])) { break; } } @@ -950,10 +944,10 @@ int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx, return 0; } -bool in_proto_list(const std::vector> &protos, +bool in_proto_list(const std::vector &protos, const unsigned char *needle, size_t len) { for (auto &proto : protos) { - if (strlen(proto.get()) == len && memcmp(proto.get(), needle, len) == 0) { + if (util::streq(proto.c_str(), proto.size(), needle, len)) { return true; } } diff --git a/src/shrpx_ssl.h b/src/shrpx_ssl.h index 85b3460d..de2509a8 100644 --- a/src/shrpx_ssl.h +++ b/src/shrpx_ssl.h @@ -140,7 +140,7 @@ int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx, // Returns true if |needle| which has |len| bytes is included in the // protocol list |protos|. -bool in_proto_list(const std::vector> &protos, +bool in_proto_list(const std::vector &protos, const unsigned char *needle, size_t len); // Returns true if security requirement for HTTP/2 is fulfilled. @@ -149,11 +149,10 @@ bool check_http2_requirement(SSL *ssl); // Returns SSL/TLS option mask to disable SSL/TLS protocol version not // included in |tls_proto_list|. The returned mask can be directly // passed to SSL_CTX_set_options(). -long int create_tls_proto_mask( - const std::vector> &tls_proto_list); +long int create_tls_proto_mask(const std::vector &tls_proto_list); std::vector -set_alpn_prefs(const std::vector> &protos); +set_alpn_prefs(const std::vector &protos); // Setups server side SSL_CTX. This function inspects get_config() // and if upstream_no_tls is true, returns nullptr. Otherwise diff --git a/src/util.h b/src/util.h index 6c04fef0..1f84c37a 100644 --- a/src/util.h +++ b/src/util.h @@ -349,6 +349,10 @@ inline bool strieq(const std::string &a, const std::string &b) { bool strieq(const char *a, const char *b); +inline bool strieq(const char *a, const std::string &b) { + return strieq(a, b.c_str(), b.size()); +} + template bool strieq_l(const char (&a)[N], InputIt b, size_t blen) { return strieq(a, N - 1, b, blen);