Merge branch 'memcached'

This commit is contained in:
Tatsuhiro Tsujikawa 2015-07-28 21:20:21 +09:00
commit e8c83798da
35 changed files with 2117 additions and 168 deletions

114
contrib/tlsticketupdate.go Normal file
View File

@ -0,0 +1,114 @@
//
// nghttp2 - HTTP/2 C Library
//
// Copyright (c) 2015 Tatsuhiro Tsujikawa
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//
package main
import (
"bytes"
"crypto/rand"
"encoding/binary"
"flag"
"fmt"
"github.com/bradfitz/gomemcache/memcache"
"log"
"time"
)
func makeKey(len int) []byte {
b := make([]byte, len)
if _, err := rand.Read(b); err != nil {
log.Fatalf("rand.Read: %v", err)
}
return b
}
func main() {
var host = flag.String("host", "127.0.0.1", "memcached host")
var port = flag.Int("port", 11211, "memcached port")
var cipher = flag.String("cipher", "aes-128-cbc", "cipher for TLS ticket encryption")
var interval = flag.Int("interval", 3600, "interval to update TLS ticket keys")
flag.Parse()
var keylen int
switch *cipher {
case "aes-128-cbc":
keylen = 48
case "aes-256-cbc":
keylen = 80
default:
log.Fatalf("cipher: unknown cipher %v", cipher)
}
mc := memcache.New(fmt.Sprintf("%v:%v", *host, *port))
keys := [][]byte{
makeKey(keylen), // current encryption key
makeKey(keylen), // next encryption key; now decryption only
}
for {
buf := new(bytes.Buffer)
if err := binary.Write(buf, binary.BigEndian, uint32(1)); err != nil {
log.Fatalf("failed to write version: %v", err)
}
for _, key := range keys {
if err := binary.Write(buf, binary.BigEndian, uint16(keylen)); err != nil {
log.Fatalf("failed to write length: %v", err)
}
if _, err := buf.Write(key); err != nil {
log.Fatalf("buf.Write: %v", err)
}
}
mc.Set(&memcache.Item{
Key: "nghttpx:tls-ticket-key",
Value: buf.Bytes(),
})
select {
case <-time.After(time.Duration(*interval) * time.Second):
}
// rotate keys. the last key is now encryption key.
// generate new key and append it to the last, so that
// we can at least decrypt TLS ticket encrypted by new
// key on the host which does not get new key yet.
new_keys := [][]byte{}
new_keys = append(new_keys, keys[len(keys)-1])
for i, key := range keys {
// keep at most past 11 keys as decryption
// only key
if i == len(keys)-1 || i > 11 {
break
}
new_keys = append(new_keys, key)
}
new_keys = append(new_keys, makeKey(keylen))
keys = new_keys
}
}

View File

@ -93,6 +93,11 @@ OPTIONS = [
"include",
"tls-ticket-cipher",
"host-rewrite",
"tls-session-cache-memcached",
"tls-ticket-key-memcached",
"tls-ticket-key-memcached-interval",
"tls-ticket-key-memcached-max-retry",
"tls-ticket-key-memcached-max-fail",
"conf",
]

View File

@ -120,6 +120,10 @@ NGHTTPX_SRCS = \
shrpx_downstream_connection_pool.cc shrpx_downstream_connection_pool.h \
shrpx_rate_limit.cc shrpx_rate_limit.h \
shrpx_connection.cc shrpx_connection.h \
shrpx_memcached_dispatcher.cc shrpx_memcached_dispatcher.h \
shrpx_memcached_connection.cc shrpx_memcached_connection.h \
shrpx_memcached_request.h \
shrpx_memcached_result.h \
buffer.h memchunk.h template.h
if HAVE_SPDYLAY

View File

@ -58,7 +58,17 @@ template <size_t N> struct Buffer {
pos += count;
return count;
}
size_t drain_reset(size_t count) {
count = std::min(count, rleft());
std::copy(pos + count, last, std::begin(buf));
last = std::begin(buf) + (last - (pos + count));
pos = std::begin(buf);
return count;
}
void reset() { pos = last = std::begin(buf); }
uint8_t *begin() { return std::begin(buf); }
uint8_t &operator[](size_t n) { return buf[n]; }
const uint8_t &operator[](size_t n) const { return buf[n]; }
std::array<uint8_t, N> buf;
uint8_t *pos, *last;
};

View File

@ -27,6 +27,7 @@
#include "nghttp2_config.h"
#include <limits.h>
#include <sys/uio.h>
#include <cassert>

View File

@ -164,6 +164,7 @@ int main(int argc, char *argv[]) {
shrpx::test_util_parse_http_date) ||
!CU_add_test(pSuite, "util_localtime_date",
shrpx::test_util_localtime_date) ||
!CU_add_test(pSuite, "util_get_uint64", shrpx::test_util_get_uint64) ||
!CU_add_test(pSuite, "gzip_inflate", test_nghttp2_gzip_inflate) ||
!CU_add_test(pSuite, "buffer_write", nghttp2::test_buffer_write) ||
!CU_add_test(pSuite, "pool_recycle", nghttp2::test_pool_recycle) ||

View File

@ -83,6 +83,8 @@
#include "shrpx_accept_handler.h"
#include "shrpx_http2_upstream.h"
#include "shrpx_http2_session.h"
#include "shrpx_memcached_dispatcher.h"
#include "shrpx_memcached_request.h"
#include "util.h"
#include "app_helper.h"
#include "ssl.h"
@ -117,8 +119,8 @@ const int GRACEFUL_SHUTDOWN_SIGNAL = SIGQUIT;
#define ENV_UNIX_PATH "NGHTTP2_UNIX_PATH"
namespace {
int resolve_hostname(sockaddr_union *addr, size_t *addrlen,
const char *hostname, uint16_t port, int family) {
int resolve_hostname(Address *addr, const char *hostname, uint16_t port,
int family) {
int rv;
auto service = util::utos(port);
@ -155,8 +157,8 @@ int resolve_hostname(sockaddr_union *addr, size_t *addrlen,
<< " succeeded: " << host;
}
memcpy(addr, res->ai_addr, res->ai_addrlen);
*addrlen = res->ai_addrlen;
memcpy(&addr->su, res->ai_addr, res->ai_addrlen);
addr->len = res->ai_addrlen;
freeaddrinfo(res);
return 0;
}
@ -611,8 +613,8 @@ int generate_ticket_key(TicketKey &ticket_key) {
ticket_key.hmac_keylen = EVP_MD_size(ticket_key.hmac);
assert(static_cast<size_t>(EVP_CIPHER_key_length(ticket_key.cipher)) <=
sizeof(ticket_key.data.enc_key));
assert(ticket_key.hmac_keylen <= sizeof(ticket_key.data.hmac_key));
ticket_key.data.enc_key.size());
assert(ticket_key.hmac_keylen <= ticket_key.data.hmac_key.size());
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "enc_keylen=" << EVP_CIPHER_key_length(ticket_key.cipher)
@ -689,6 +691,116 @@ void renew_ticket_key_cb(struct ev_loop *loop, ev_timer *w, int revents) {
}
} // namespace
namespace {
void memcached_get_ticket_key_cb(struct ev_loop *loop, ev_timer *w,
int revents) {
auto conn_handler = static_cast<ConnectionHandler *>(w->data);
auto dispatcher = conn_handler->get_tls_ticket_key_memcached_dispatcher();
auto req = make_unique<MemcachedRequest>();
req->key = "nghttpx:tls-ticket-key";
req->op = MEMCACHED_OP_GET;
req->cb = [conn_handler, dispatcher, w](MemcachedRequest *req,
MemcachedResult res) {
switch (res.status_code) {
case MEMCACHED_ERR_NO_ERROR:
break;
case MEMCACHED_ERR_EXT_NETWORK_ERROR:
conn_handler->on_tls_ticket_key_network_error(w);
return;
default:
conn_handler->on_tls_ticket_key_not_found(w);
return;
}
// |version (4bytes)|len (2bytes)|key (variable length)|...
// (len, key) pairs are repeated as necessary.
auto &value = res.value;
if (value.size() < 4) {
LOG(WARN) << "Memcached: tls ticket key value is too small: got "
<< value.size();
conn_handler->on_tls_ticket_key_not_found(w);
return;
}
auto p = value.data();
auto version = util::get_uint32(p);
// Currently supported version is 1.
if (version != 1) {
LOG(WARN) << "Memcached: tls ticket key version: want 1, got " << version;
conn_handler->on_tls_ticket_key_not_found(w);
return;
}
auto end = p + value.size();
p += 4;
size_t expectedlen;
size_t enc_keylen;
size_t hmac_keylen;
if (get_config()->tls_ticket_cipher == EVP_aes_128_cbc()) {
expectedlen = 48;
enc_keylen = 16;
hmac_keylen = 16;
} else if (get_config()->tls_ticket_cipher == EVP_aes_256_cbc()) {
expectedlen = 80;
enc_keylen = 32;
hmac_keylen = 32;
} else {
return;
}
auto ticket_keys = std::make_shared<TicketKeys>();
for (; p != end;) {
if (end - p < 2) {
LOG(WARN) << "Memcached: tls ticket key data is too small";
conn_handler->on_tls_ticket_key_not_found(w);
return;
}
auto len = util::get_uint16(p);
p += 2;
if (len != expectedlen) {
LOG(WARN) << "Memcached: wrong tls ticket key size: want "
<< expectedlen << ", got " << len;
conn_handler->on_tls_ticket_key_not_found(w);
return;
}
if (p + len > end) {
LOG(WARN) << "Memcached: too short tls ticket key payload: want " << len
<< ", got " << (end - p);
conn_handler->on_tls_ticket_key_not_found(w);
return;
}
auto key = TicketKey();
key.cipher = get_config()->tls_ticket_cipher;
key.hmac = EVP_sha256();
key.hmac_keylen = EVP_MD_size(key.hmac);
std::copy_n(p, key.data.name.size(), key.data.name.data());
p += key.data.name.size();
std::copy_n(p, enc_keylen, key.data.enc_key.data());
p += enc_keylen;
std::copy_n(p, hmac_keylen, key.data.hmac_key.data());
p += hmac_keylen;
ticket_keys->keys.push_back(std::move(key));
}
conn_handler->on_tls_ticket_key_get_success(ticket_keys, w);
};
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "Memcached: tls ticket key get request sent";
}
dispatcher->add_request(std::move(req));
}
} // namespace
namespace {
int call_daemon() {
#ifdef __sgi
@ -749,34 +861,47 @@ int event_loop() {
ev_timer renew_ticket_key_timer;
if (!get_config()->upstream_no_tls) {
bool auto_tls_ticket_key = true;
if (!get_config()->tls_ticket_key_files.empty()) {
if (!get_config()->tls_ticket_cipher_given) {
LOG(WARN) << "It is strongly recommended to specify "
"--tls-ticket-cipher=aes-128-cbc (or "
"tls-ticket-cipher=aes-128-cbc in configuration file) "
"when --tls-ticket-key-file is used for the smooth "
"transition when the default value of --tls-ticket-cipher "
"becomes aes-256-cbc";
}
auto ticket_keys = read_tls_ticket_key_file(
get_config()->tls_ticket_key_files, get_config()->tls_ticket_cipher,
EVP_sha256());
if (!ticket_keys) {
LOG(WARN) << "Use internal session ticket key generator";
} else {
conn_handler->set_ticket_keys(std::move(ticket_keys));
auto_tls_ticket_key = false;
}
}
if (auto_tls_ticket_key) {
// Generate new ticket key every 1hr.
ev_timer_init(&renew_ticket_key_timer, renew_ticket_key_cb, 0., 1_h);
renew_ticket_key_timer.data = conn_handler.get();
ev_timer_again(loop, &renew_ticket_key_timer);
if (get_config()->tls_ticket_key_memcached_host) {
conn_handler->set_tls_ticket_key_memcached_dispatcher(
make_unique<MemcachedDispatcher>(
&get_config()->tls_ticket_key_memcached_addr, loop));
// Generate first session ticket key before running workers.
renew_ticket_key_cb(loop, &renew_ticket_key_timer, 0);
ev_timer_init(&renew_ticket_key_timer, memcached_get_ticket_key_cb, 0.,
0.);
renew_ticket_key_timer.data = conn_handler.get();
// Get first ticket keys.
memcached_get_ticket_key_cb(loop, &renew_ticket_key_timer, 0);
} else {
bool auto_tls_ticket_key = true;
if (!get_config()->tls_ticket_key_files.empty()) {
if (!get_config()->tls_ticket_cipher_given) {
LOG(WARN)
<< "It is strongly recommended to specify "
"--tls-ticket-cipher=aes-128-cbc (or "
"tls-ticket-cipher=aes-128-cbc in configuration file) "
"when --tls-ticket-key-file is used for the smooth "
"transition when the default value of --tls-ticket-cipher "
"becomes aes-256-cbc";
}
auto ticket_keys = read_tls_ticket_key_file(
get_config()->tls_ticket_key_files, get_config()->tls_ticket_cipher,
EVP_sha256());
if (!ticket_keys) {
LOG(WARN) << "Use internal session ticket key generator";
} else {
conn_handler->set_ticket_keys(std::move(ticket_keys));
auto_tls_ticket_key = false;
}
}
if (auto_tls_ticket_key) {
// Generate new ticket key every 1hr.
ev_timer_init(&renew_ticket_key_timer, renew_ticket_key_cb, 0., 1_h);
renew_ticket_key_timer.data = conn_handler.get();
ev_timer_again(loop, &renew_ticket_key_timer);
// Generate first session ticket key before running workers.
renew_ticket_key_cb(loop, &renew_ticket_key_timer, 0);
}
}
}
@ -963,7 +1088,6 @@ void fill_default_config() {
mod_config()->downstream_http_proxy_userinfo = nullptr;
mod_config()->downstream_http_proxy_host = nullptr;
mod_config()->downstream_http_proxy_port = 0;
mod_config()->downstream_http_proxy_addrlen = 0;
mod_config()->read_rate = 0;
mod_config()->read_burst = 0;
mod_config()->write_rate = 0;
@ -1021,6 +1145,9 @@ void fill_default_config() {
mod_config()->tls_ticket_cipher = EVP_aes_128_cbc();
mod_config()->tls_ticket_cipher_given = false;
mod_config()->tls_session_timeout = std::chrono::hours(12);
mod_config()->tls_ticket_key_memcached_max_retry = 3;
mod_config()->tls_ticket_key_memcached_max_fail = 2;
mod_config()->tls_ticket_key_memcached_interval = 10_min;
}
} // namespace
@ -1365,6 +1492,38 @@ SSL/TLS:
Default: )"
<< util::duration_str(get_config()->ocsp_update_interval) << R"(
--no-ocsp Disable OCSP stapling.
--tls-session-cache-memcached=<HOST>,<PORT>
Specify address of memcached server to store session
cache. This enables shared session cache between
multiple nghttpx instances.
--tls-ticket-key-memcached=<HOST>,<PORT>
Specify address of memcached server to store session
cache. This enables shared TLS ticket key between
multiple nghttpx instances. nghttpx does not set TLS
ticket key to memcached. The external ticket key
generator is required. nghttpx just gets TLS ticket
keys from memcached, and use them, possibly replacing
current set of keys. It is up to extern TLS ticket key
generator to rotate keys frequently.
--tls-ticket-key-memcached-interval=<DURATION>
Set interval to get TLS ticket keys from memcached.
Default: )"
<< util::duration_str(get_config()->tls_ticket_key_memcached_interval)
<< R"(
--tls-ticket-key-memcached-max-retry=<N>
Set maximum number of consecutive retries before
abandoning TLS ticket key retrieval. If this number is
reached, the attempt is considered as failure, and
"failure" count is incremented by 1, which contributed
to the value controlled
--tls-ticket-key-memcached-max-fail option.
Default: )" << get_config()->tls_ticket_key_memcached_max_retry
<< R"(
--tls-ticket-key-memcached-max-fail=<N>
Set maximum number of consecutive failure before
disabling TLS ticket until next scheduled key retrieval.
Default: )" << get_config()->tls_ticket_key_memcached_max_fail
<< R"(
HTTP/2 and SPDY:
-c, --http2-max-concurrent-streams=<N>
@ -1728,6 +1887,14 @@ int main(int argc, char **argv) {
{SHRPX_OPT_INCLUDE, required_argument, &flag, 83},
{SHRPX_OPT_TLS_TICKET_CIPHER, required_argument, &flag, 84},
{SHRPX_OPT_HOST_REWRITE, no_argument, &flag, 85},
{SHRPX_OPT_TLS_SESSION_CACHE_MEMCACHED, required_argument, &flag, 86},
{SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED, required_argument, &flag, 87},
{SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED_INTERVAL, required_argument, &flag,
88},
{SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED_MAX_RETRY, required_argument, &flag,
89},
{SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED_MAX_FAIL, required_argument, &flag,
90},
{nullptr, 0, nullptr, 0}};
int option_index = 0;
@ -2102,6 +2269,29 @@ int main(int argc, char **argv) {
// --host-rewrite
cmdcfgs.emplace_back(SHRPX_OPT_HOST_REWRITE, "yes");
break;
case 86:
// --tls-session-cache-memcached
cmdcfgs.emplace_back(SHRPX_OPT_TLS_SESSION_CACHE_MEMCACHED, optarg);
break;
case 87:
// --tls-ticket-key-memcached
cmdcfgs.emplace_back(SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED, optarg);
break;
case 88:
// --tls-ticket-key-memcached-interval
cmdcfgs.emplace_back(SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED_INTERVAL,
optarg);
break;
case 89:
// --tls-ticket-key-memcached-max-retry
cmdcfgs.emplace_back(SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED_MAX_RETRY,
optarg);
break;
case 90:
// --tls-ticket-key-memcached-max-fail
cmdcfgs.emplace_back(SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED_MAX_FAIL,
optarg);
break;
default:
break;
}
@ -2338,19 +2528,19 @@ int main(int argc, char **argv) {
auto path = addr.host.get();
auto pathlen = strlen(path);
if (pathlen + 1 > sizeof(addr.addr.un.sun_path)) {
if (pathlen + 1 > sizeof(addr.addr.su.un.sun_path)) {
LOG(FATAL) << "UNIX domain socket path " << path << " is too long > "
<< sizeof(addr.addr.un.sun_path);
<< sizeof(addr.addr.su.un.sun_path);
exit(EXIT_FAILURE);
}
LOG(INFO) << "Use UNIX domain socket path " << path
<< " for backend connection";
addr.addr.un.sun_family = AF_UNIX;
addr.addr.su.un.sun_family = AF_UNIX;
// copy path including terminal NULL
std::copy_n(path, pathlen + 1, addr.addr.un.sun_path);
addr.addrlen = sizeof(addr.addr.un);
std::copy_n(path, pathlen + 1, addr.addr.su.un.sun_path);
addr.addr.len = sizeof(addr.addr.su.un);
continue;
}
@ -2358,7 +2548,7 @@ int main(int argc, char **argv) {
addr.hostport = strcopy(util::make_hostport(addr.host.get(), addr.port));
if (resolve_hostname(
&addr.addr, &addr.addrlen, addr.host.get(), addr.port,
&addr.addr, addr.host.get(), addr.port,
get_config()->backend_ipv4 ? AF_INET : (get_config()->backend_ipv6
? AF_INET6
: AF_UNSPEC)) == -1) {
@ -2372,7 +2562,6 @@ int main(int argc, char **argv) {
LOG(INFO) << "Resolving backend http proxy address";
}
if (resolve_hostname(&mod_config()->downstream_http_proxy_addr,
&mod_config()->downstream_http_proxy_addrlen,
get_config()->downstream_http_proxy_host.get(),
get_config()->downstream_http_proxy_port,
AF_UNSPEC) == -1) {
@ -2380,6 +2569,24 @@ int main(int argc, char **argv) {
}
}
if (get_config()->session_cache_memcached_host) {
if (resolve_hostname(&mod_config()->session_cache_memcached_addr,
get_config()->session_cache_memcached_host.get(),
get_config()->session_cache_memcached_port,
AF_UNSPEC) == -1) {
exit(EXIT_FAILURE);
}
}
if (get_config()->tls_ticket_key_memcached_host) {
if (resolve_hostname(&mod_config()->tls_ticket_key_memcached_addr,
get_config()->tls_ticket_key_memcached_host.get(),
get_config()->tls_ticket_key_memcached_port,
AF_UNSPEC) == -1) {
exit(EXIT_FAILURE);
}
}
if (get_config()->rlimit_nofile) {
struct rlimit lim = {static_cast<rlim_t>(get_config()->rlimit_nofile),
static_cast<rlim_t>(get_config()->rlimit_nofile)};

View File

@ -380,7 +380,7 @@ ClientHandler::ClientHandler(Worker *worker, int fd, SSL *ssl,
ev_timer_again(conn_.loop, &conn_.rt);
if (conn_.tls.ssl) {
SSL_set_app_data(conn_.tls.ssl, &conn_);
conn_.prepare_server_handshake();
read_ = write_ = &ClientHandler::tls_handshake;
on_read_ = &ClientHandler::upstream_noop;
on_write_ = &ClientHandler::upstream_write;
@ -848,4 +848,6 @@ ev_io *ClientHandler::get_wev() { return &conn_.wev; }
Worker *ClientHandler::get_worker() const { return worker_; }
Connection *ClientHandler::get_connection() { return &conn_; }
} // namespace shrpx

View File

@ -130,6 +130,8 @@ public:
void signal_write();
ev_io *get_wev();
Connection *get_connection();
private:
Connection conn_;
ev_timer reneg_shutdown_timer_;

View File

@ -81,7 +81,7 @@ TicketKeys::~TicketKeys() {
DownstreamAddr::DownstreamAddr(const DownstreamAddr &other)
: addr(other.addr), host(other.host ? strcopy(other.host.get()) : nullptr),
hostport(other.hostport ? strcopy(other.hostport.get()) : nullptr),
addrlen(other.addrlen), port(other.port), host_unix(other.host_unix) {}
port(other.port), host_unix(other.host_unix) {}
DownstreamAddr &DownstreamAddr::operator=(const DownstreamAddr &other) {
if (this == &other) {
@ -91,7 +91,6 @@ DownstreamAddr &DownstreamAddr::operator=(const DownstreamAddr &other) {
addr = other.addr;
host = (other.host ? strcopy(other.host.get()) : nullptr);
hostport = (other.hostport ? strcopy(other.hostport.get()) : nullptr);
addrlen = other.addrlen;
port = other.port;
host_unix = other.host_unix;
@ -156,7 +155,7 @@ read_tls_ticket_key_file(const std::vector<std::string> &files,
// with nginx and apache.
hmac_keylen = 16;
}
auto expectedlen = sizeof(keys[0].data.name) + enc_keylen + hmac_keylen;
auto expectedlen = keys[0].data.name.size() + enc_keylen + hmac_keylen;
char buf[256];
assert(sizeof(buf) >= expectedlen);
@ -202,11 +201,11 @@ read_tls_ticket_key_file(const std::vector<std::string> &files,
}
auto p = buf;
memcpy(key.data.name, p, sizeof(key.data.name));
p += sizeof(key.data.name);
memcpy(key.data.enc_key, p, enc_keylen);
std::copy_n(p, key.data.name.size(), std::begin(key.data.name));
p += key.data.name.size();
std::copy_n(p, enc_keylen, std::begin(key.data.enc_key));
p += enc_keylen;
memcpy(key.data.hmac_key, p, hmac_keylen);
std::copy_n(p, hmac_keylen, std::begin(key.data.hmac_key));
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "session ticket key: " << util::format_hex(key.data.name);
@ -704,8 +703,13 @@ enum {
SHRPX_OPTID_SUBCERT,
SHRPX_OPTID_SYSLOG_FACILITY,
SHRPX_OPTID_TLS_PROTO_LIST,
SHRPX_OPTID_TLS_SESSION_CACHE_MEMCACHED,
SHRPX_OPTID_TLS_TICKET_CIPHER,
SHRPX_OPTID_TLS_TICKET_KEY_FILE,
SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED,
SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED_INTERVAL,
SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED_MAX_FAIL,
SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED_MAX_RETRY,
SHRPX_OPTID_USER,
SHRPX_OPTID_VERIFY_CLIENT,
SHRPX_OPTID_VERIFY_CLIENT_CACERT,
@ -1138,6 +1142,11 @@ int option_lookup_token(const char *name, size_t namelen) {
break;
case 24:
switch (name[23]) {
case 'd':
if (util::strieq_l("tls-ticket-key-memcache", name, 23)) {
return SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED;
}
break;
case 'e':
if (util::strieq_l("fetch-ocsp-response-fil", name, 23)) {
return SHRPX_OPTID_FETCH_OCSP_RESPONSE_FILE;
@ -1180,6 +1189,11 @@ int option_lookup_token(const char *name, size_t namelen) {
break;
case 27:
switch (name[26]) {
case 'd':
if (util::strieq_l("tls-session-cache-memcache", name, 26)) {
return SHRPX_OPTID_TLS_SESSION_CACHE_MEMCACHED;
}
break;
case 's':
if (util::strieq_l("worker-frontend-connection", name, 26)) {
return SHRPX_OPTID_WORKER_FRONTEND_CONNECTIONS;
@ -1210,6 +1224,18 @@ int option_lookup_token(const char *name, size_t namelen) {
break;
}
break;
case 33:
switch (name[32]) {
case 'l':
if (util::strieq_l("tls-ticket-key-memcached-interva", name, 32)) {
return SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED_INTERVAL;
}
if (util::strieq_l("tls-ticket-key-memcached-max-fai", name, 32)) {
return SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED_MAX_FAIL;
}
break;
}
break;
case 34:
switch (name[33]) {
case 'r':
@ -1222,6 +1248,11 @@ int option_lookup_token(const char *name, size_t namelen) {
return SHRPX_OPTID_BACKEND_HTTP1_CONNECTIONS_PER_HOST;
}
break;
case 'y':
if (util::strieq_l("tls-ticket-key-memcached-max-retr", name, 33)) {
return SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED_MAX_RETRY;
}
break;
}
break;
case 35:
@ -1865,6 +1896,48 @@ int parse_config(const char *opt, const char *optarg,
mod_config()->no_host_rewrite = !util::strieq(optarg, "yes");
return 0;
case SHRPX_OPTID_TLS_SESSION_CACHE_MEMCACHED: {
if (split_host_port(host, sizeof(host), &port, optarg, strlen(optarg)) ==
-1) {
return -1;
}
mod_config()->session_cache_memcached_host = strcopy(host);
mod_config()->session_cache_memcached_port = port;
return 0;
}
case SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED: {
if (split_host_port(host, sizeof(host), &port, optarg, strlen(optarg)) ==
-1) {
return -1;
}
mod_config()->tls_ticket_key_memcached_host = strcopy(host);
mod_config()->tls_ticket_key_memcached_port = port;
return 0;
}
case SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED_INTERVAL:
return parse_duration(&mod_config()->tls_ticket_key_memcached_interval, opt,
optarg);
case SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED_MAX_RETRY: {
int n;
if (parse_uint(&n, opt, optarg) != 0) {
return -1;
}
if (n > 30) {
LOG(ERROR) << opt << ": must be smaller than or equal to 30";
return -1;
}
mod_config()->tls_ticket_key_memcached_max_retry = n;
return 0;
}
case SHRPX_OPTID_TLS_TICKET_KEY_MEMCACHED_MAX_FAIL:
return parse_uint(&mod_config()->tls_ticket_key_memcached_max_fail, opt,
optarg);
case SHRPX_OPTID_CONF:
LOG(WARN) << "conf: ignored";

View File

@ -173,6 +173,16 @@ constexpr char SHRPX_OPT_MAX_HEADER_FIELDS[] = "max-header-fields";
constexpr char SHRPX_OPT_INCLUDE[] = "include";
constexpr char SHRPX_OPT_TLS_TICKET_CIPHER[] = "tls-ticket-cipher";
constexpr char SHRPX_OPT_HOST_REWRITE[] = "host-rewrite";
constexpr char SHRPX_OPT_TLS_SESSION_CACHE_MEMCACHED[] =
"tls-session-cache-memcached";
constexpr char SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED[] =
"tls-ticket-key-memcached";
constexpr char SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED_INTERVAL[] =
"tls-ticket-key-memcached-interval";
constexpr char SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED_MAX_RETRY[] =
"tls-ticket-key-memcached-max-retry";
constexpr char SHRPX_OPT_TLS_TICKET_KEY_MEMCACHED_MAX_FAIL[] =
"tls-ticket-key-memcached-max-fail";
union sockaddr_union {
sockaddr_storage storage;
@ -182,6 +192,11 @@ union sockaddr_union {
sockaddr_un un;
};
struct Address {
size_t len;
union sockaddr_union su;
};
enum shrpx_proto { PROTO_HTTP2, PROTO_HTTP };
struct AltSvc {
@ -193,18 +208,17 @@ struct AltSvc {
};
struct DownstreamAddr {
DownstreamAddr() : addr{{0}}, addrlen(0), port(0), host_unix(false) {}
DownstreamAddr() : addr{}, port(0), host_unix(false) {}
DownstreamAddr(const DownstreamAddr &other);
DownstreamAddr(DownstreamAddr &&) = default;
DownstreamAddr &operator=(const DownstreamAddr &other);
DownstreamAddr &operator=(DownstreamAddr &&other) = default;
sockaddr_union addr;
Address addr;
// backend address. If |host_unix| is true, this is UNIX domain
// socket path.
std::unique_ptr<char[]> host;
std::unique_ptr<char[]> hostport;
size_t addrlen;
// backend port. 0 if |host_unix| is true.
uint16_t port;
// true if |host| contains UNIX domain socket path.
@ -223,11 +237,11 @@ struct TicketKey {
size_t hmac_keylen;
struct {
// name of this ticket configuration
uint8_t name[16];
std::array<uint8_t, 16> name;
// encryption key for |cipher|
uint8_t enc_key[32];
std::array<uint8_t, 32> enc_key;
// hmac key for |hmac|
uint8_t hmac_key[32];
std::array<uint8_t, 32> hmac_key;
} data;
};
@ -252,7 +266,9 @@ struct Config {
// list of supported SSL/TLS protocol strings.
std::vector<std::string> tls_proto_list;
// binary form of http proxy host and port
sockaddr_union downstream_http_proxy_addr;
Address downstream_http_proxy_addr;
Address session_cache_memcached_addr;
Address tls_ticket_key_memcached_addr;
std::chrono::seconds tls_session_timeout;
ev_tstamp http2_upstream_read_timeout;
ev_tstamp upstream_read_timeout;
@ -264,6 +280,7 @@ struct Config {
ev_tstamp downstream_idle_read_timeout;
ev_tstamp listener_disable_timeout;
ev_tstamp ocsp_update_interval;
ev_tstamp tls_ticket_key_memcached_interval;
// address of frontend connection. This could be a path to UNIX
// domain socket. In this case, |host_unix| must be true.
std::unique_ptr<char[]> host;
@ -295,6 +312,8 @@ struct Config {
std::unique_ptr<char[]> errorlog_file;
std::unique_ptr<char[]> fetch_ocsp_response_file;
std::unique_ptr<char[]> user;
std::unique_ptr<char[]> session_cache_memcached_host;
std::unique_ptr<char[]> tls_ticket_key_memcached_host;
FILE *http2_upstream_dump_request_header;
FILE *http2_upstream_dump_response_header;
nghttp2_session_callbacks *http2_upstream_callbacks;
@ -314,8 +333,6 @@ struct Config {
size_t http2_downstream_connections_per_worker;
size_t downstream_connections_per_host;
size_t downstream_connections_per_frontend;
// actual size of downstream_http_proxy_addr
size_t downstream_http_proxy_addrlen;
size_t read_rate;
size_t read_burst;
size_t write_rate;
@ -333,6 +350,12 @@ struct Config {
size_t max_header_fields;
// The index of catch-all group in downstream_addr_groups.
size_t downstream_addr_group_catch_all;
// Maximum number of retries when getting TLS ticket key from
// mamcached, due to network error.
size_t tls_ticket_key_memcached_max_retry;
// Maximum number of consecutive error from memcached, when this
// limit reached, TLS ticket is disabled.
size_t tls_ticket_key_memcached_max_fail;
// Bit mask to disable SSL/TLS protocol versions. This will be
// passed to SSL_CTX_set_options().
long int tls_proto_mask;
@ -349,6 +372,8 @@ struct Config {
uint16_t port;
// port in http proxy URI
uint16_t downstream_http_proxy_port;
uint16_t session_cache_memcached_port;
uint16_t tls_ticket_key_memcached_port;
bool verbose;
bool daemon;
bool verify_client;

View File

@ -192,16 +192,24 @@ void test_shrpx_config_read_tls_ticket_key_file(void) {
CU_ASSERT(ticket_keys.get() != nullptr);
CU_ASSERT(2 == ticket_keys->keys.size());
auto key = &ticket_keys->keys[0];
CU_ASSERT(0 ==
memcmp("0..............1", key->data.name, sizeof(key->data.name)));
CU_ASSERT(0 == memcmp("2..............3", key->data.enc_key, 16));
CU_ASSERT(0 == memcmp("4..............5", key->data.hmac_key, 16));
CU_ASSERT(std::equal(std::begin(key->data.name), std::end(key->data.name),
"0..............1"));
CU_ASSERT(std::equal(std::begin(key->data.enc_key),
std::begin(key->data.enc_key) + 16, "2..............3"));
CU_ASSERT(std::equal(std::begin(key->data.hmac_key),
std::begin(key->data.hmac_key) + 16,
"4..............5"));
CU_ASSERT(16 == key->hmac_keylen);
key = &ticket_keys->keys[1];
CU_ASSERT(0 ==
memcmp("6..............7", key->data.name, sizeof(key->data.name)));
CU_ASSERT(0 == memcmp("8..............9", key->data.enc_key, 16));
CU_ASSERT(0 == memcmp("a..............b", key->data.hmac_key, 16));
CU_ASSERT(std::equal(std::begin(key->data.name), std::end(key->data.name),
"6..............7"));
CU_ASSERT(std::equal(std::begin(key->data.enc_key),
std::begin(key->data.enc_key) + 16, "8..............9"));
CU_ASSERT(std::equal(std::begin(key->data.hmac_key),
std::begin(key->data.hmac_key) + 16,
"a..............b"));
CU_ASSERT(16 == key->hmac_keylen);
}
void test_shrpx_config_read_tls_ticket_key_file_aes_256(void) {
@ -227,20 +235,24 @@ void test_shrpx_config_read_tls_ticket_key_file_aes_256(void) {
CU_ASSERT(ticket_keys.get() != nullptr);
CU_ASSERT(2 == ticket_keys->keys.size());
auto key = &ticket_keys->keys[0];
CU_ASSERT(0 ==
memcmp("0..............1", key->data.name, sizeof(key->data.name)));
CU_ASSERT(0 ==
memcmp("2..............................3", key->data.enc_key, 32));
CU_ASSERT(0 ==
memcmp("4..............................5", key->data.hmac_key, 32));
CU_ASSERT(std::equal(std::begin(key->data.name), std::end(key->data.name),
"0..............1"));
CU_ASSERT(std::equal(std::begin(key->data.enc_key),
std::end(key->data.enc_key),
"2..............................3"));
CU_ASSERT(std::equal(std::begin(key->data.hmac_key),
std::end(key->data.hmac_key),
"4..............................5"));
key = &ticket_keys->keys[1];
CU_ASSERT(0 ==
memcmp("6..............7", key->data.name, sizeof(key->data.name)));
CU_ASSERT(0 ==
memcmp("8..............................9", key->data.enc_key, 32));
CU_ASSERT(0 ==
memcmp("a..............................b", key->data.hmac_key, 32));
CU_ASSERT(std::equal(std::begin(key->data.name), std::end(key->data.name),
"6..............7"));
CU_ASSERT(std::equal(std::begin(key->data.enc_key),
std::end(key->data.enc_key),
"8..............................9"));
CU_ASSERT(std::equal(std::begin(key->data.hmac_key),
std::end(key->data.hmac_key),
"a..............................b"));
}
void test_shrpx_config_match_downstream_addr_group(void) {

View File

@ -32,6 +32,8 @@
#include <openssl/err.h>
#include "shrpx_ssl.h"
#include "shrpx_memcached_request.h"
#include "memchunk.h"
using namespace nghttp2;
@ -42,7 +44,7 @@ Connection::Connection(struct ev_loop *loop, int fd, SSL *ssl,
size_t write_rate, size_t write_burst, size_t read_rate,
size_t read_burst, IOCb writecb, IOCb readcb,
TimerCb timeoutcb, void *data)
: tls{ssl}, wlimit(loop, &wev, write_rate, write_burst),
: tls{}, wlimit(loop, &wev, write_rate, write_burst),
rlimit(loop, &rev, read_rate, read_burst, ssl), writecb(writecb),
readcb(readcb), timeoutcb(timeoutcb), loop(loop), data(data), fd(fd) {
@ -60,6 +62,10 @@ Connection::Connection(struct ev_loop *loop, int fd, SSL *ssl,
// set 0. to double field explicitly just in case
tls.last_write_time = 0.;
if (ssl) {
set_ssl(ssl);
}
}
Connection::~Connection() {
@ -78,15 +84,25 @@ void Connection::disconnect() {
wlimit.stopw();
if (tls.ssl) {
SSL_set_app_data(tls.ssl, nullptr);
SSL_set_shutdown(tls.ssl, SSL_RECEIVED_SHUTDOWN);
ERR_clear_error();
if (tls.cached_session) {
SSL_SESSION_free(tls.cached_session);
}
if (tls.cached_session_lookup_req) {
tls.cached_session_lookup_req->canceled = true;
}
// To reuse SSL/TLS session, we have to shutdown, and don't free
// tls.ssl.
if (SSL_shutdown(tls.ssl) != 1) {
SSL_free(tls.ssl);
tls.ssl = nullptr;
}
tls = {tls.ssl};
}
if (fd != -1) {
@ -96,31 +112,275 @@ void Connection::disconnect() {
}
}
int Connection::tls_handshake() {
auto rv = SSL_do_handshake(tls.ssl);
namespace {
void allocate_buffer(Connection *conn) {
conn->tls.rb = make_unique<Buffer<16_k>>();
conn->tls.wb = make_unique<Buffer<16_k>>();
}
} // namespace
if (rv == 0) {
return SHRPX_ERR_NETWORK;
void Connection::prepare_client_handshake() {
SSL_set_connect_state(tls.ssl);
allocate_buffer(this);
}
void Connection::prepare_server_handshake() {
SSL_set_accept_state(tls.ssl);
allocate_buffer(this);
}
// BIO implementation is inspired by openldap implementation:
// http://www.openldap.org/devel/cvsweb.cgi/~checkout~/libraries/libldap/tls_o.c
namespace {
int shrpx_bio_write(BIO *b, const char *buf, int len) {
if (buf == nullptr || len <= 0) {
return 0;
}
if (rv < 0) {
auto conn = static_cast<Connection *>(b->ptr);
auto &wb = conn->tls.wb;
BIO_clear_retry_flags(b);
if (conn->tls.initial_handshake_done) {
// After handshake finished, send |buf| of length |len| to the
// socket directly.
if (wb && wb->rleft()) {
auto nwrite = conn->write_clear(wb->pos, wb->rleft());
if (nwrite < 0) {
return -1;
}
wb->drain(nwrite);
if (wb->rleft()) {
BIO_set_retry_write(b);
return -1;
}
// Here delete TLS write buffer
wb.reset();
}
auto nwrite = conn->write_clear(buf, len);
if (nwrite < 0) {
return -1;
}
if (nwrite == 0) {
BIO_set_retry_write(b);
return -1;
}
return nwrite;
}
auto nwrite = std::min(static_cast<size_t>(len), wb->wleft());
if (nwrite == 0) {
BIO_set_retry_write(b);
return -1;
}
wb->write(buf, nwrite);
return nwrite;
}
} // namespace
namespace {
int shrpx_bio_read(BIO *b, char *buf, int len) {
if (buf == nullptr || len <= 0) {
return 0;
}
auto conn = static_cast<Connection *>(b->ptr);
auto &rb = conn->tls.rb;
BIO_clear_retry_flags(b);
if (conn->tls.initial_handshake_done && !rb) {
auto nread = conn->read_clear(buf, len);
if (nread < 0) {
return -1;
}
if (nread == 0) {
BIO_set_retry_read(b);
return -1;
}
return nread;
}
auto nread = std::min(static_cast<size_t>(len), rb->rleft());
if (nread == 0) {
if (conn->tls.initial_handshake_done) {
rb.reset();
}
BIO_set_retry_read(b);
return -1;
}
std::copy_n(rb->pos, nread, buf);
rb->drain(nread);
return nread;
}
} // namespace
namespace {
int shrpx_bio_puts(BIO *b, const char *str) {
return shrpx_bio_write(b, str, strlen(str));
}
} // namespace
namespace {
int shrpx_bio_gets(BIO *b, char *buf, int len) { return -1; }
} // namespace
namespace {
long shrpx_bio_ctrl(BIO *b, int cmd, long num, void *ptr) {
switch (cmd) {
case BIO_CTRL_FLUSH:
return 1;
}
return 0;
}
} // namespace
namespace {
int shrpx_bio_create(BIO *b) {
b->init = 1;
b->num = 0;
b->ptr = nullptr;
b->flags = 0;
return 1;
}
} // namespace
namespace {
int shrpx_bio_destroy(BIO *b) {
if (b == nullptr) {
return 0;
}
b->ptr = nullptr;
b->init = 0;
b->flags = 0;
return 1;
}
} // namespace
namespace {
BIO_METHOD shrpx_bio_method = {
BIO_TYPE_FD, "nghttpx-bio", shrpx_bio_write,
shrpx_bio_read, shrpx_bio_puts, shrpx_bio_gets,
shrpx_bio_ctrl, shrpx_bio_create, shrpx_bio_destroy,
};
} // namespace
void Connection::set_ssl(SSL *ssl) {
tls.ssl = ssl;
auto bio = BIO_new(&shrpx_bio_method);
bio->ptr = this;
SSL_set_bio(tls.ssl, bio, bio);
SSL_set_app_data(tls.ssl, this);
rlimit.set_ssl(tls.ssl);
}
int Connection::tls_handshake() {
wlimit.stopw();
ev_timer_stop(loop, &wt);
auto nread = read_clear(tls.rb->last, tls.rb->wleft());
if (nread < 0) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "tls: handshake read error";
}
return -1;
}
tls.rb->write(nread);
switch (tls.handshake_state) {
case TLS_CONN_WAIT_FOR_SESSION_CACHE:
if (tls.rb->wleft() == 0) {
// Input buffer is full. Disable read until cache is returned
rlimit.stopw();
ev_timer_stop(loop, &rt);
}
return SHRPX_ERR_INPROGRESS;
case TLS_CONN_GOT_SESSION_CACHE: {
// Use the same trick invented by @kazuho in h2o project
tls.wb->reset();
tls.rb->pos = tls.rb->begin();
auto ssl_ctx = SSL_get_SSL_CTX(tls.ssl);
SSL_free(tls.ssl);
auto ssl = ssl::create_ssl(ssl_ctx);
if (!ssl) {
return -1;
}
set_ssl(ssl);
SSL_set_accept_state(tls.ssl);
tls.handshake_state = TLS_CONN_NORMAL;
break;
}
case TLS_CONN_CANCEL_SESSION_CACHE:
tls.handshake_state = TLS_CONN_NORMAL;
break;
}
auto rv = SSL_do_handshake(tls.ssl);
if (rv <= 0) {
auto err = SSL_get_error(tls.ssl, rv);
switch (err) {
case SSL_ERROR_WANT_READ:
wlimit.stopw();
ev_timer_stop(loop, &wt);
return SHRPX_ERR_INPROGRESS;
case SSL_ERROR_WANT_WRITE:
wlimit.startw();
ev_timer_again(loop, &wt);
return SHRPX_ERR_INPROGRESS;
break;
default:
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "tls: handshake libssl error " << err;
}
return SHRPX_ERR_NETWORK;
}
}
wlimit.stopw();
ev_timer_stop(loop, &wt);
if (tls.handshake_state == TLS_CONN_WAIT_FOR_SESSION_CACHE) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "tls: handshake is still in progress";
}
return SHRPX_ERR_INPROGRESS;
}
if (tls.wb->rleft()) {
auto nwrite = write_clear(tls.wb->pos, tls.wb->rleft());
if (nwrite < 0) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "tls: handshake write error";
}
return -1;
}
tls.wb->drain(nwrite);
}
if (tls.wb->rleft()) {
wlimit.startw();
ev_timer_again(loop, &wt);
}
if (rv != 1) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "tls: handshake is still in progress";
}
return SHRPX_ERR_INPROGRESS;
}
tls.initial_handshake_done = true;

View File

@ -35,19 +35,34 @@
#include "shrpx_rate_limit.h"
#include "shrpx_error.h"
#include "buffer.h"
namespace shrpx {
struct MemcachedRequest;
enum {
TLS_CONN_NORMAL,
TLS_CONN_WAIT_FOR_SESSION_CACHE,
TLS_CONN_GOT_SESSION_CACHE,
TLS_CONN_CANCEL_SESSION_CACHE,
};
struct TLSConnection {
SSL *ssl;
SSL_SESSION *cached_session;
MemcachedRequest *cached_session_lookup_req;
ev_tstamp last_write_time;
size_t warmup_writelen;
// length passed to SSL_write and SSL_read last time. This is
// required since these functions require the exact same parameters
// on non-blocking I/O.
size_t last_writelen, last_readlen;
int handshake_state;
bool initial_handshake_done;
bool reneg_started;
std::unique_ptr<Buffer<16_k>> rb;
std::unique_ptr<Buffer<16_k>> wb;
};
template <typename T> using EVCb = void (*)(struct ev_loop *, T *, int);
@ -64,6 +79,9 @@ struct Connection {
void disconnect();
void prepare_client_handshake();
void prepare_server_handshake();
int tls_handshake();
// All write_* and writev_clear functions return number of bytes
@ -89,6 +107,8 @@ struct Connection {
void handle_tls_pending_read();
void set_ssl(SSL *ssl);
TLSConnection tls;
ev_io wev;
ev_io rev;

View File

@ -32,6 +32,7 @@
#include <cerrno>
#include <thread>
#include <random>
#include "shrpx_client_handler.h"
#include "shrpx_ssl.h"
@ -41,6 +42,7 @@
#include "shrpx_connect_blocker.h"
#include "shrpx_downstream_connection.h"
#include "shrpx_accept_handler.h"
#include "shrpx_memcached_dispatcher.h"
#include "util.h"
#include "template.h"
@ -94,7 +96,9 @@ void ocsp_chld_cb(struct ev_loop *loop, ev_child *w, int revent) {
} // namespace
ConnectionHandler::ConnectionHandler(struct ev_loop *loop)
: single_worker_(nullptr), loop_(loop), worker_round_robin_cnt_(0),
: single_worker_(nullptr), loop_(loop),
tls_ticket_key_memcached_get_retry_count_(0),
tls_ticket_key_memcached_fail_count_(0), worker_round_robin_cnt_(0),
graceful_shutdown_(false) {
ev_timer_init(&disable_acceptor_timer_, acceptor_disable_cb, 0., 0.);
disable_acceptor_timer_.data = this;
@ -553,4 +557,93 @@ void ConnectionHandler::proceed_next_cert_ocsp() {
}
}
void ConnectionHandler::set_tls_ticket_key_memcached_dispatcher(
std::unique_ptr<MemcachedDispatcher> dispatcher) {
tls_ticket_key_memcached_dispatcher_ = std::move(dispatcher);
}
MemcachedDispatcher *
ConnectionHandler::get_tls_ticket_key_memcached_dispatcher() const {
return tls_ticket_key_memcached_dispatcher_.get();
}
namespace {
std::random_device rd;
} // namespace
void ConnectionHandler::on_tls_ticket_key_network_error(ev_timer *w) {
if (++tls_ticket_key_memcached_get_retry_count_ >=
get_config()->tls_ticket_key_memcached_max_retry) {
LOG(WARN) << "Memcached: tls ticket get retry all failed "
<< tls_ticket_key_memcached_get_retry_count_ << " times.";
on_tls_ticket_key_not_found(w);
return;
}
auto dist = std::uniform_int_distribution<int>(
1, std::min(60, 1 << tls_ticket_key_memcached_get_retry_count_));
auto t = dist(rd);
LOG(WARN)
<< "Memcached: tls ticket get failed due to network error, retrying in "
<< t << " seconds";
ev_timer_set(w, t, 0.);
ev_timer_start(loop_, w);
}
void ConnectionHandler::on_tls_ticket_key_not_found(ev_timer *w) {
tls_ticket_key_memcached_get_retry_count_ = 0;
if (++tls_ticket_key_memcached_fail_count_ >=
get_config()->tls_ticket_key_memcached_max_fail) {
LOG(WARN) << "Memcached: could not get tls ticket; disable tls ticket";
tls_ticket_key_memcached_fail_count_ = 0;
set_ticket_keys(nullptr);
set_ticket_keys_to_worker(nullptr);
}
LOG(WARN) << "Memcached: tls ticket get failed, schedule next";
schedule_next_tls_ticket_key_memcached_get(w);
}
void ConnectionHandler::on_tls_ticket_key_get_success(
const std::shared_ptr<TicketKeys> &ticket_keys, ev_timer *w) {
LOG(NOTICE) << "Memcached: tls ticket get success";
tls_ticket_key_memcached_get_retry_count_ = 0;
tls_ticket_key_memcached_fail_count_ = 0;
schedule_next_tls_ticket_key_memcached_get(w);
if (!ticket_keys || ticket_keys->keys.empty()) {
LOG(WARN) << "Memcached: tls ticket keys are empty; tls ticket disabled";
set_ticket_keys(nullptr);
set_ticket_keys_to_worker(nullptr);
return;
}
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "ticket keys get done";
LOG(INFO) << 0 << " enc+dec: "
<< util::format_hex(ticket_keys->keys[0].data.name);
for (size_t i = 1; i < ticket_keys->keys.size(); ++i) {
auto &key = ticket_keys->keys[i];
LOG(INFO) << i << " dec: " << util::format_hex(key.data.name);
}
}
set_ticket_keys(ticket_keys);
set_ticket_keys_to_worker(ticket_keys);
}
void
ConnectionHandler::schedule_next_tls_ticket_key_memcached_get(ev_timer *w) {
ev_timer_set(w, get_config()->tls_ticket_key_memcached_interval, 0.);
ev_timer_start(loop_, w);
}
} // namespace shrpx

View File

@ -49,6 +49,7 @@ class AcceptHandler;
class Worker;
struct WorkerStat;
struct TicketKeys;
class MemcachedDispatcher;
struct OCSPUpdateContext {
// ocsp response buffer
@ -111,6 +112,17 @@ public:
// update.
void proceed_next_cert_ocsp();
void set_tls_ticket_key_memcached_dispatcher(
std::unique_ptr<MemcachedDispatcher> dispatcher);
MemcachedDispatcher *get_tls_ticket_key_memcached_dispatcher() const;
void on_tls_ticket_key_network_error(ev_timer *w);
void on_tls_ticket_key_not_found(ev_timer *w);
void
on_tls_ticket_key_get_success(const std::shared_ptr<TicketKeys> &ticket_keys,
ev_timer *w);
void schedule_next_tls_ticket_key_memcached_get(ev_timer *w);
private:
// Stores all SSL_CTX objects.
std::vector<SSL_CTX *> all_ssl_ctx_;
@ -120,6 +132,7 @@ private:
// Worker instance used when single threaded mode (-n1) is used.
// Otherwise, nullptr and workers_ has instances of Worker instead.
std::unique_ptr<Worker> single_worker_;
std::unique_ptr<MemcachedDispatcher> tls_ticket_key_memcached_dispatcher_;
// Current TLS session ticket keys. Note that TLS connection does
// not refer to this field directly. They use TicketKeys object in
// Worker object.
@ -131,6 +144,8 @@ private:
std::unique_ptr<AcceptHandler> acceptor6_;
ev_timer disable_acceptor_timer_;
ev_timer ocsp_timer_;
size_t tls_ticket_key_memcached_get_retry_count_;
size_t tls_ticket_key_memcached_fail_count_;
unsigned int worker_round_robin_cnt_;
bool graceful_shutdown_;
};

View File

@ -276,15 +276,15 @@ int Http2Session::initiate_connection() {
}
conn_.fd = util::create_nonblock_socket(
get_config()->downstream_http_proxy_addr.storage.ss_family);
get_config()->downstream_http_proxy_addr.su.storage.ss_family);
if (conn_.fd == -1) {
connect_blocker_->on_failure();
return -1;
}
rv = connect(conn_.fd, &get_config()->downstream_http_proxy_addr.sa,
get_config()->downstream_http_proxy_addrlen);
rv = connect(conn_.fd, &get_config()->downstream_http_proxy_addr.su.sa,
get_config()->downstream_http_proxy_addr.len);
if (rv != 0 && errno != EINPROGRESS) {
SSLOG(ERROR, this) << "Failed to connect to the proxy "
<< get_config()->downstream_http_proxy_host.get()
@ -323,12 +323,12 @@ int Http2Session::initiate_connection() {
// We are establishing TLS connection. If conn_.tls.ssl, we may
// reuse the previous session.
if (!conn_.tls.ssl) {
conn_.tls.ssl = SSL_new(ssl_ctx_);
if (!conn_.tls.ssl) {
SSLOG(ERROR, this) << "SSL_new() failed: "
<< ERR_error_string(ERR_get_error(), NULL);
auto ssl = ssl::create_ssl(ssl_ctx_);
if (!ssl) {
return -1;
}
conn_.set_ssl(ssl);
}
const char *sni_name = nullptr;
@ -350,7 +350,7 @@ int Http2Session::initiate_connection() {
assert(conn_.fd == -1);
conn_.fd = util::create_nonblock_socket(
downstream_addr.addr.storage.ss_family);
downstream_addr.addr.su.storage.ss_family);
if (conn_.fd == -1) {
connect_blocker_->on_failure();
return -1;
@ -358,8 +358,8 @@ int Http2Session::initiate_connection() {
rv = connect(conn_.fd,
// TODO maybe not thread-safe?
const_cast<sockaddr *>(&downstream_addr.addr.sa),
downstream_addr.addrlen);
const_cast<sockaddr *>(&downstream_addr.addr.su.sa),
downstream_addr.addr.len);
if (rv != 0 && errno != EINPROGRESS) {
connect_blocker_->on_failure();
return -1;
@ -369,26 +369,23 @@ int Http2Session::initiate_connection() {
ev_io_set(&conn_.wev, conn_.fd, EV_WRITE);
}
if (SSL_set_fd(conn_.tls.ssl, conn_.fd) == 0) {
return -1;
}
SSL_set_connect_state(conn_.tls.ssl);
conn_.prepare_client_handshake();
} else {
if (state_ == DISCONNECTED) {
// Without TLS and proxy.
assert(conn_.fd == -1);
conn_.fd = util::create_nonblock_socket(
downstream_addr.addr.storage.ss_family);
downstream_addr.addr.su.storage.ss_family);
if (conn_.fd == -1) {
connect_blocker_->on_failure();
return -1;
}
rv = connect(conn_.fd, const_cast<sockaddr *>(&downstream_addr.addr.sa),
downstream_addr.addrlen);
rv = connect(conn_.fd,
const_cast<sockaddr *>(&downstream_addr.addr.su.sa),
downstream_addr.addr.len);
if (rv != 0 && errno != EINPROGRESS) {
connect_blocker_->on_failure();
return -1;

View File

@ -147,7 +147,7 @@ int HttpDownstreamConnection::attach_downstream(Downstream *downstream) {
next_downstream = 0;
}
conn_.fd = util::create_nonblock_socket(addr.addr.storage.ss_family);
conn_.fd = util::create_nonblock_socket(addr.addr.su.storage.ss_family);
if (conn_.fd == -1) {
auto error = errno;
@ -159,7 +159,7 @@ int HttpDownstreamConnection::attach_downstream(Downstream *downstream) {
}
int rv;
rv = connect(conn_.fd, &addr.addr.sa, addr.addrlen);
rv = connect(conn_.fd, &addr.addr.su.sa, addr.addr.len);
if (rv != 0 && errno != EINPROGRESS) {
auto error = errno;
DCLOG(WARN, this) << "connect() failed; errno=" << error;
@ -189,6 +189,15 @@ int HttpDownstreamConnection::attach_downstream(Downstream *downstream) {
break;
}
// TODO we should have timeout for connection establishment
ev_timer_again(conn_.loop, &conn_.wt);
} else {
// we may set read timer cb to idle_timeoutcb. Reset again.
conn_.rt.repeat = get_config()->downstream_read_timeout;
ev_set_cb(&conn_.rt, timeoutcb);
ev_timer_again(conn_.loop, &conn_.rt);
ev_set_cb(&conn_.rev, readcb);
}
downstream_ = downstream;
@ -196,15 +205,6 @@ int HttpDownstreamConnection::attach_downstream(Downstream *downstream) {
http_parser_init(&response_htp_, HTTP_RESPONSE);
response_htp_.data = downstream_;
ev_set_cb(&conn_.rev, readcb);
conn_.rt.repeat = get_config()->downstream_read_timeout;
// we may set read timer cb to idle_timeoutcb. Reset again.
ev_set_cb(&conn_.rt, timeoutcb);
ev_timer_again(conn_.loop, &conn_.rt);
// TODO we should have timeout for connection establishment
ev_timer_again(conn_.loop, &conn_.wt);
return 0;
}
@ -472,18 +472,16 @@ void HttpDownstreamConnection::detach_downstream(Downstream *downstream) {
DCLOG(INFO, this) << "Detaching from DOWNSTREAM:" << downstream;
}
downstream_ = nullptr;
ioctrl_.force_resume_read();
conn_.rlimit.startw();
conn_.wlimit.stopw();
ev_set_cb(&conn_.rev, idle_readcb);
ev_timer_stop(conn_.loop, &conn_.wt);
ioctrl_.force_resume_read();
conn_.rt.repeat = get_config()->downstream_idle_read_timeout;
ev_set_cb(&conn_.rt, idle_timeoutcb);
ev_timer_again(conn_.loop, &conn_.rt);
conn_.wlimit.stopw();
ev_timer_stop(conn_.loop, &conn_.wt);
}
void HttpDownstreamConnection::pause_read(IOCtrlReason reason) {
@ -870,6 +868,8 @@ int HttpDownstreamConnection::on_connect() {
connect_blocker->on_success();
conn_.rlimit.startw();
ev_timer_again(conn_.loop, &conn_.rt);
ev_set_cb(&conn_.wev, writecb);
return 0;

View File

@ -76,6 +76,10 @@ class Downstream;
#define SSLOG(SEVERITY, HTTP2) \
(Log(SEVERITY, __FILE__, __LINE__) << "[DHTTP2:" << HTTP2 << "] ")
// Memcached connection log
#define MCLOG(SEVERITY, MCONN) \
(Log(SEVERITY, __FILE__, __LINE__) << "[MCONN:" << MCONN << "] ")
enum SeverityLevel { INFO, NOTICE, WARN, ERROR, FATAL };
class Log {

View File

@ -0,0 +1,546 @@
/*
* nghttp2 - HTTP/2 C Library
*
* Copyright (c) 2015 Tatsuhiro Tsujikawa
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
* LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
* WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "shrpx_memcached_connection.h"
#include <limits.h>
#include <sys/uio.h>
#include "shrpx_memcached_request.h"
#include "shrpx_memcached_result.h"
#include "shrpx_config.h"
#include "util.h"
namespace shrpx {
namespace {
void timeoutcb(struct ev_loop *loop, ev_timer *w, int revents) {
auto conn = static_cast<Connection *>(w->data);
auto mconn = static_cast<MemcachedConnection *>(conn->data);
if (LOG_ENABLED(INFO)) {
MCLOG(INFO, mconn) << "Time out";
}
mconn->disconnect();
}
} // namespace
namespace {
void readcb(struct ev_loop *loop, ev_io *w, int revents) {
auto conn = static_cast<Connection *>(w->data);
auto mconn = static_cast<MemcachedConnection *>(conn->data);
if (mconn->on_read() != 0) {
mconn->disconnect();
return;
}
}
} // namespace
namespace {
void writecb(struct ev_loop *loop, ev_io *w, int revents) {
auto conn = static_cast<Connection *>(w->data);
auto mconn = static_cast<MemcachedConnection *>(conn->data);
if (mconn->on_write() != 0) {
mconn->disconnect();
return;
}
}
} // namespace
namespace {
void connectcb(struct ev_loop *loop, ev_io *w, int revents) {
auto conn = static_cast<Connection *>(w->data);
auto mconn = static_cast<MemcachedConnection *>(conn->data);
if (mconn->on_connect() != 0) {
mconn->disconnect();
return;
}
writecb(loop, w, revents);
}
} // namespace
constexpr ev_tstamp write_timeout = 10.;
constexpr ev_tstamp read_timeout = 10.;
MemcachedConnection::MemcachedConnection(const Address *addr,
struct ev_loop *loop)
: conn_(loop, -1, nullptr, write_timeout, read_timeout, 0, 0, 0, 0,
connectcb, readcb, timeoutcb, this),
parse_state_{}, addr_(addr), sendsum_(0), connected_(false) {}
MemcachedConnection::~MemcachedConnection() { disconnect(); }
namespace {
void clear_request(std::deque<std::unique_ptr<MemcachedRequest>> &q) {
for (auto &req : q) {
if (req->cb) {
req->cb(req.get(), MemcachedResult(MEMCACHED_ERR_EXT_NETWORK_ERROR));
}
}
q.clear();
}
} // namespace
void MemcachedConnection::disconnect() {
clear_request(recvq_);
clear_request(sendq_);
sendbufv_.clear();
sendsum_ = 0;
parse_state_ = {};
connected_ = false;
conn_.disconnect();
assert(recvbuf_.rleft() == 0);
recvbuf_.reset();
}
int MemcachedConnection::initiate_connection() {
assert(conn_.fd == -1);
conn_.fd = util::create_nonblock_socket(addr_->su.storage.ss_family);
if (conn_.fd == -1) {
auto error = errno;
MCLOG(WARN, this) << "socket() failed; errno=" << error;
return -1;
}
int rv;
rv = connect(conn_.fd, &addr_->su.sa, addr_->len);
if (rv != 0 && errno != EINPROGRESS) {
auto error = errno;
MCLOG(WARN, this) << "connect() failed; errno=" << error;
close(conn_.fd);
conn_.fd = -1;
return -1;
}
if (LOG_ENABLED(INFO)) {
MCLOG(INFO, this) << "Connecting to memcached server";
}
ev_io_set(&conn_.wev, conn_.fd, EV_WRITE);
ev_io_set(&conn_.rev, conn_.fd, EV_READ);
ev_set_cb(&conn_.wev, connectcb);
conn_.wlimit.startw();
ev_timer_again(conn_.loop, &conn_.wt);
return 0;
}
int MemcachedConnection::on_connect() {
if (!util::check_socket_connected(conn_.fd)) {
conn_.wlimit.stopw();
if (LOG_ENABLED(INFO)) {
MCLOG(INFO, this) << "memcached connect failed";
}
return -1;
}
if (LOG_ENABLED(INFO)) {
MCLOG(INFO, this) << "connected to memcached server";
}
connected_ = true;
ev_set_cb(&conn_.wev, writecb);
conn_.rlimit.startw();
ev_timer_again(conn_.loop, &conn_.rt);
return 0;
}
int MemcachedConnection::on_write() {
if (!connected_) {
return 0;
}
ev_timer_again(conn_.loop, &conn_.rt);
if (sendq_.empty()) {
conn_.wlimit.stopw();
ev_timer_stop(conn_.loop, &conn_.wt);
return 0;
}
int rv;
for (; !sendq_.empty();) {
rv = send_request();
if (rv < 0) {
return -1;
}
if (rv == 1) {
// blocked
return 0;
}
}
conn_.wlimit.stopw();
ev_timer_stop(conn_.loop, &conn_.wt);
return 0;
}
int MemcachedConnection::on_read() {
if (!connected_) {
return 0;
}
ev_timer_again(conn_.loop, &conn_.rt);
for (;;) {
auto nread = conn_.read_clear(recvbuf_.last, recvbuf_.wleft());
if (nread == 0) {
return 0;
}
if (nread < 0) {
return -1;
}
recvbuf_.write(nread);
if (parse_packet() != 0) {
return -1;
}
}
return 0;
}
int MemcachedConnection::parse_packet() {
auto in = recvbuf_.pos;
for (;;) {
auto busy = false;
switch (parse_state_.state) {
case MEMCACHED_PARSE_HEADER24: {
if (recvbuf_.last - in < 24) {
recvbuf_.drain_reset(in - recvbuf_.pos);
return 0;
}
if (recvq_.empty()) {
MCLOG(WARN, this)
<< "Response received, but there is no in-flight request.";
return -1;
}
auto &req = recvq_.front();
if (*in != MEMCACHED_RES_MAGIC) {
MCLOG(WARN, this) << "Response has bad magic: "
<< static_cast<uint32_t>(*in);
return -1;
}
++in;
parse_state_.op = *in++;
parse_state_.keylen = util::get_uint16(in);
in += 2;
parse_state_.extralen = *in++;
// skip 1 byte reserved data type
++in;
parse_state_.status_code = util::get_uint16(in);
in += 2;
parse_state_.totalbody = util::get_uint32(in);
in += 4;
// skip 4 bytes opaque
in += 4;
parse_state_.cas = util::get_uint64(in);
in += 8;
if (req->op != parse_state_.op) {
MCLOG(WARN, this)
<< "opcode in response does not match to the request: want "
<< static_cast<uint32_t>(req->op) << ", got " << parse_state_.op;
return -1;
}
if (parse_state_.keylen != 0) {
MCLOG(WARN, this) << "zero length keylen expected: got "
<< parse_state_.keylen;
return -1;
}
if (parse_state_.totalbody > 16_k) {
MCLOG(WARN, this) << "totalbody is too large: got "
<< parse_state_.totalbody;
return -1;
}
if (parse_state_.op == MEMCACHED_OP_GET &&
parse_state_.status_code == 0 && parse_state_.extralen == 0) {
MCLOG(WARN, this) << "response for GET does not have extra";
return -1;
}
if (parse_state_.totalbody <
parse_state_.keylen + parse_state_.extralen) {
MCLOG(WARN, this) << "totalbody is too short: totalbody "
<< parse_state_.totalbody << ", want min "
<< parse_state_.keylen + parse_state_.extralen;
return -1;
}
if (parse_state_.extralen) {
parse_state_.state = MEMCACHED_PARSE_EXTRA;
parse_state_.read_left = parse_state_.extralen;
} else {
parse_state_.state = MEMCACHED_PARSE_VALUE;
parse_state_.read_left = parse_state_.totalbody - parse_state_.keylen -
parse_state_.extralen;
}
busy = true;
break;
}
case MEMCACHED_PARSE_EXTRA: {
// We don't use extra for now. Just read and forget.
auto n = std::min(static_cast<size_t>(recvbuf_.last - in),
parse_state_.read_left);
parse_state_.read_left -= n;
in += n;
if (parse_state_.read_left) {
recvbuf_.reset();
return 0;
}
parse_state_.state = MEMCACHED_PARSE_VALUE;
// since we require keylen == 0, totalbody - extralen ==
// valuelen
parse_state_.read_left =
parse_state_.totalbody - parse_state_.keylen - parse_state_.extralen;
busy = true;
break;
}
case MEMCACHED_PARSE_VALUE: {
auto n = std::min(static_cast<size_t>(recvbuf_.last - in),
parse_state_.read_left);
parse_state_.value.insert(std::end(parse_state_.value), in, in + n);
parse_state_.read_left -= n;
in += n;
if (parse_state_.read_left) {
recvbuf_.reset();
return 0;
}
if (LOG_ENABLED(INFO)) {
if (parse_state_.status_code) {
MCLOG(INFO, this)
<< "response returned error status: " << parse_state_.status_code;
}
}
auto req = std::move(recvq_.front());
recvq_.pop_front();
if (!req->canceled && req->cb) {
req->cb(req.get(), MemcachedResult(parse_state_.status_code,
std::move(parse_state_.value)));
}
parse_state_ = {};
break;
}
}
if (!busy && in == recvbuf_.last) {
break;
}
}
assert(in == recvbuf_.last);
recvbuf_.reset();
return 0;
}
#define DEFAULT_WR_IOVCNT 128
#if defined(IOV_MAX) && IOV_MAX < DEFAULT_WR_IOVCNT
#define MAX_WR_IOVCNT IOV_MAX
#else // !defined(IOV_MAX) || IOV_MAX >= DEFAULT_WR_IOVCNT
#define MAX_WR_IOVCNT DEFAULT_WR_IOVCNT
#endif // !defined(IOV_MAX) || IOV_MAX >= DEFAULT_WR_IOVCNT
int MemcachedConnection::send_request() {
ssize_t nwrite;
if (sendsum_ == 0) {
for (auto &req : sendq_) {
if (req->canceled) {
continue;
}
if (serialized_size(req.get()) + sendsum_ > 1300) {
break;
}
sendbufv_.emplace_back();
sendbufv_.back().req = req.get();
make_request(&sendbufv_.back(), req.get());
sendsum_ += sendbufv_.back().left();
}
if (sendsum_ == 0) {
sendq_.clear();
return 0;
}
}
std::array<struct iovec, DEFAULT_WR_IOVCNT> iov;
size_t iovlen = 0;
for (auto &buf : sendbufv_) {
if (iovlen + 2 > iov.size()) {
break;
}
auto req = buf.req;
if (buf.headbuf.rleft()) {
iov[iovlen++] = {buf.headbuf.pos, buf.headbuf.rleft()};
}
if (buf.send_value_left) {
iov[iovlen++] = {req->value.data() + req->value.size() -
buf.send_value_left,
buf.send_value_left};
}
}
nwrite = conn_.writev_clear(iov.data(), iovlen);
if (nwrite < 0) {
return -1;
}
if (nwrite == 0) {
return 1;
}
sendsum_ -= nwrite;
while (nwrite > 0) {
auto &buf = sendbufv_.front();
auto &req = sendq_.front();
if (req->canceled) {
sendq_.pop_front();
continue;
}
assert(buf.req == req.get());
auto n = std::min(static_cast<size_t>(nwrite), buf.headbuf.rleft());
buf.headbuf.drain(n);
nwrite -= n;
n = std::min(static_cast<size_t>(nwrite), buf.send_value_left);
buf.send_value_left -= n;
nwrite -= n;
if (buf.headbuf.rleft() || buf.send_value_left) {
break;
}
sendbufv_.pop_front();
recvq_.push_back(std::move(sendq_.front()));
sendq_.pop_front();
}
return 0;
}
size_t MemcachedConnection::serialized_size(MemcachedRequest *req) {
switch (req->op) {
case MEMCACHED_OP_GET:
return 24 + req->key.size();
case MEMCACHED_OP_ADD:
default:
return 24 + 8 + req->key.size() + req->value.size();
}
}
void MemcachedConnection::make_request(MemcachedSendbuf *sendbuf,
MemcachedRequest *req) {
auto &headbuf = sendbuf->headbuf;
std::fill(std::begin(headbuf.buf), std::end(headbuf.buf), 0);
headbuf[0] = MEMCACHED_REQ_MAGIC;
headbuf[1] = req->op;
switch (req->op) {
case MEMCACHED_OP_GET:
util::put_uint16be(&headbuf[2], req->key.size());
util::put_uint32be(&headbuf[8], req->key.size());
headbuf.write(24);
break;
case MEMCACHED_OP_ADD:
util::put_uint16be(&headbuf[2], req->key.size());
headbuf[4] = 8;
util::put_uint32be(&headbuf[8], 8 + req->key.size() + req->value.size());
util::put_uint32be(&headbuf[28], req->expiry);
headbuf.write(32);
break;
}
headbuf.write(req->key.c_str(), req->key.size());
sendbuf->send_value_left = req->value.size();
}
int MemcachedConnection::add_request(std::unique_ptr<MemcachedRequest> req) {
sendq_.push_back(std::move(req));
if (connected_) {
signal_write();
return 0;
}
if (conn_.fd == -1 && initiate_connection() != 0) {
disconnect();
return -1;
}
return 0;
}
// TODO should we start write timer too?
void MemcachedConnection::signal_write() { conn_.wlimit.startw(); }
} // namespace shrpx

View File

@ -0,0 +1,129 @@
/*
* nghttp2 - HTTP/2 C Library
*
* Copyright (c) 2015 Tatsuhiro Tsujikawa
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
* LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
* WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef SHRPX_MEMCACHED_CONNECTION_H
#define SHRPX_MEMCACHED_CONNECTION_H
#include "shrpx.h"
#include <memory>
#include <deque>
#include <ev.h>
#include "shrpx_connection.h"
#include "buffer.h"
using namespace nghttp2;
namespace shrpx {
struct MemcachedRequest;
struct Address;
enum {
MEMCACHED_PARSE_HEADER24,
MEMCACHED_PARSE_EXTRA,
MEMCACHED_PARSE_VALUE,
};
// Stores state when parsing response from memcached server
struct MemcachedParseState {
// Buffer for value, dynamically allocated.
std::vector<uint8_t> value;
// cas in response
uint64_t cas;
// keylen in response
size_t keylen;
// extralen in response
size_t extralen;
// totalbody in response. The length of value is totalbody -
// extralen - keylen.
size_t totalbody;
// Number of bytes left to read variable length field.
size_t read_left;
// Parser state; see enum above
int state;
// status_code in response
int status_code;
// op in response
int op;
};
struct MemcachedSendbuf {
// Buffer for header + extra + key
Buffer<512> headbuf;
// MemcachedRequest associated to this object
MemcachedRequest *req;
// Number of bytes left when sending value
size_t send_value_left;
// Returns the number of bytes this object transmits.
size_t left() const { return headbuf.rleft() + send_value_left; }
};
constexpr uint8_t MEMCACHED_REQ_MAGIC = 0x80;
constexpr uint8_t MEMCACHED_RES_MAGIC = 0x81;
// MemcachedConnection implements part of memcached binary protocol.
// This is not full brown implementation. Just the part we need is
// implemented. We only use GET and ADD.
//
// https://github.com/memcached/memcached/blob/master/doc/protocol-binary.xml
// https://code.google.com/p/memcached/wiki/MemcacheBinaryProtocol
class MemcachedConnection {
public:
MemcachedConnection(const Address *addr, struct ev_loop *loop);
~MemcachedConnection();
void disconnect();
int add_request(std::unique_ptr<MemcachedRequest> req);
int initiate_connection();
int on_connect();
int on_write();
int on_read();
int send_request();
void make_request(MemcachedSendbuf *sendbuf, MemcachedRequest *req);
int parse_packet();
size_t serialized_size(MemcachedRequest *req);
void signal_write();
private:
Connection conn_;
std::deque<std::unique_ptr<MemcachedRequest>> recvq_;
std::deque<std::unique_ptr<MemcachedRequest>> sendq_;
std::deque<MemcachedSendbuf> sendbufv_;
MemcachedParseState parse_state_;
const Address *addr_;
// Sum of the bytes to be transmitted in sendbufv_.
size_t sendsum_;
bool connected_;
Buffer<8_k> recvbuf_;
};
} // namespace shrpx
#endif // SHRPX_MEMCACHED_CONNECTION_H

View File

@ -0,0 +1,47 @@
/*
* nghttp2 - HTTP/2 C Library
*
* Copyright (c) 2015 Tatsuhiro Tsujikawa
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
* LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
* WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "shrpx_memcached_dispatcher.h"
#include "shrpx_memcached_request.h"
#include "shrpx_memcached_connection.h"
#include "shrpx_config.h"
namespace shrpx {
MemcachedDispatcher::MemcachedDispatcher(const Address *addr,
struct ev_loop *loop)
: loop_(loop), mconn_(make_unique<MemcachedConnection>(addr, loop_)) {}
MemcachedDispatcher::~MemcachedDispatcher() {}
int MemcachedDispatcher::add_request(std::unique_ptr<MemcachedRequest> req) {
if (mconn_->add_request(std::move(req)) != 0) {
return -1;
}
return 0;
}
} // namespace shrpx

View File

@ -0,0 +1,54 @@
/*
* nghttp2 - HTTP/2 C Library
*
* Copyright (c) 2015 Tatsuhiro Tsujikawa
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
* LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
* WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef SHRPX_MEMCACHED_DISPATCHER_H
#define SHRPX_MEMCACHED_DISPATCHER_H
#include "shrpx.h"
#include <memory>
#include <ev.h>
namespace shrpx {
struct MemcachedRequest;
class MemcachedConnection;
struct Address;
class MemcachedDispatcher {
public:
MemcachedDispatcher(const Address *addr, struct ev_loop *loop);
~MemcachedDispatcher();
int add_request(std::unique_ptr<MemcachedRequest> req);
private:
struct ev_loop *loop_;
std::unique_ptr<MemcachedConnection> mconn_;
};
} // namespace shrpx
#endif // SHRPX_MEMCACHED_DISPATCHER_H

View File

@ -0,0 +1,59 @@
/*
* nghttp2 - HTTP/2 C Library
*
* Copyright (c) 2015 Tatsuhiro Tsujikawa
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
* LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
* WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef SHRPX_MEMCACHED_REQUEST_H
#define SHRPX_MEMCACHED_REQUEST_H
#include "shrpx.h"
#include <string>
#include <vector>
#include <memory>
#include "shrpx_memcached_result.h"
namespace shrpx {
enum {
MEMCACHED_OP_GET = 0x00,
MEMCACHED_OP_ADD = 0x02,
};
struct MemcachedRequest;
using MemcachedResultCallback =
std::function<void(MemcachedRequest *req, MemcachedResult res)>;
struct MemcachedRequest {
std::string key;
std::vector<uint8_t> value;
MemcachedResultCallback cb;
uint32_t expiry;
int op;
bool canceled;
};
} // namespace shrpx
#endif // SHRPX_MEMCACHED_REQUEST_H

View File

@ -0,0 +1,50 @@
/*
* nghttp2 - HTTP/2 C Library
*
* Copyright (c) 2015 Tatsuhiro Tsujikawa
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
* LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
* WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef SHRPX_MEMCACHED_RESULT_H
#define SHRPX_MEMCACHED_RESULT_H
#include "shrpx.h"
#include <vector>
namespace shrpx {
enum MemcachedStatusCode {
MEMCACHED_ERR_NO_ERROR,
MEMCACHED_ERR_EXT_NETWORK_ERROR = 0x1001,
};
struct MemcachedResult {
MemcachedResult(int status_code) : status_code(status_code) {}
MemcachedResult(int status_code, std::vector<uint8_t> value)
: value(std::move(value)), status_code(status_code) {}
std::vector<uint8_t> value;
int status_code;
};
} // namespace shrpx
#endif // SHRPX_MEMCACHED_RESULT_H

View File

@ -106,4 +106,6 @@ void RateLimit::handle_tls_pending_read() {
ev_feed_event(loop_, w_, EV_READ);
}
void RateLimit::set_ssl(SSL *ssl) { ssl_ = ssl; }
} // namespace shrpx

View File

@ -48,6 +48,7 @@ public:
// required since it is buffered in ssl_ object, io event is not
// generated unless new incoming data is received.
void handle_tls_pending_read();
void set_ssl(SSL *ssl);
private:
ev_timer t_;

View File

@ -54,6 +54,8 @@
#include "shrpx_worker.h"
#include "shrpx_downstream_connection_pool.h"
#include "shrpx_http2_session.h"
#include "shrpx_memcached_request.h"
#include "shrpx_memcached_dispatcher.h"
#include "util.h"
#include "ssl.h"
#include "template.h"
@ -183,6 +185,127 @@ int ocsp_resp_cb(SSL *ssl, void *arg) {
}
} // namespace
constexpr char MEMCACHED_SESSION_CACHE_KEY_PREFIX[] =
"nghttpx:tls-session-cache:";
namespace {
int tls_session_new_cb(SSL *ssl, SSL_SESSION *session) {
auto handler = static_cast<ClientHandler *>(SSL_get_app_data(ssl));
auto worker = handler->get_worker();
auto dispatcher = worker->get_session_cache_memcached_dispatcher();
const unsigned char *id;
unsigned int idlen;
id = SSL_SESSION_get_id(session, &idlen);
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "Memached: cache session, id=" << util::format_hex(id, idlen);
}
auto req = make_unique<MemcachedRequest>();
req->op = MEMCACHED_OP_ADD;
req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX;
req->key += util::format_hex(id, idlen);
auto sessionlen = i2d_SSL_SESSION(session, nullptr);
req->value.resize(sessionlen);
auto buf = &req->value[0];
i2d_SSL_SESSION(session, &buf);
req->expiry = 12_h;
req->cb = [](MemcachedRequest *req, MemcachedResult res) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "Memcached: session cache done. key=" << req->key
<< ", status_code=" << res.status_code << ", value="
<< std::string(std::begin(res.value), std::end(res.value));
}
if (res.status_code != 0) {
LOG(WARN) << "Memcached: failed to cache session key=" << req->key
<< ", status_code=" << res.status_code << ", value="
<< std::string(std::begin(res.value), std::end(res.value));
}
};
assert(!req->canceled);
dispatcher->add_request(std::move(req));
return 0;
}
} // namespace
namespace {
SSL_SESSION *tls_session_get_cb(SSL *ssl, unsigned char *id, int idlen,
int *copy) {
auto handler = static_cast<ClientHandler *>(SSL_get_app_data(ssl));
auto worker = handler->get_worker();
auto dispatcher = worker->get_session_cache_memcached_dispatcher();
auto conn = handler->get_connection();
if (conn->tls.cached_session) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "Memcached: found cached session, id="
<< util::format_hex(id, idlen);
}
// This is required, without this, memory leak occurs.
*copy = 0;
auto session = conn->tls.cached_session;
conn->tls.cached_session = nullptr;
return session;
}
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "Memcached: get cached session, id="
<< util::format_hex(id, idlen);
}
auto req = make_unique<MemcachedRequest>();
req->op = MEMCACHED_OP_GET;
req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX;
req->key += util::format_hex(id, idlen);
req->cb = [conn](MemcachedRequest *, MemcachedResult res) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "Memcached: returned status code " << res.status_code;
}
// We might stop reading, so start it again
conn->rlimit.startw();
ev_timer_again(conn->loop, &conn->rt);
conn->wlimit.startw();
ev_timer_again(conn->loop, &conn->wt);
conn->tls.cached_session_lookup_req = nullptr;
if (res.status_code != 0) {
conn->tls.handshake_state = TLS_CONN_CANCEL_SESSION_CACHE;
return;
}
const uint8_t *p = res.value.data();
auto session = d2i_SSL_SESSION(nullptr, &p, res.value.size());
if (!session) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "cannot materialize session";
}
conn->tls.handshake_state = TLS_CONN_CANCEL_SESSION_CACHE;
return;
}
conn->tls.cached_session = session;
conn->tls.handshake_state = TLS_CONN_GOT_SESSION_CACHE;
};
conn->tls.handshake_state = TLS_CONN_WAIT_FOR_SESSION_CACHE;
conn->tls.cached_session_lookup_req = req.get();
dispatcher->add_request(std::move(req));
return nullptr;
}
} // namespace
namespace {
int ticket_key_cb(SSL *ssl, unsigned char *key_name, unsigned char *iv,
EVP_CIPHER_CTX *ctx, HMAC_CTX *hctx, int enc) {
@ -213,18 +336,20 @@ int ticket_key_cb(SSL *ssl, unsigned char *key_name, unsigned char *iv,
<< util::format_hex(key.data.name);
}
memcpy(key_name, key.data.name, sizeof(key.data.name));
std::copy(std::begin(key.data.name), std::end(key.data.name), key_name);
EVP_EncryptInit_ex(ctx, get_config()->tls_ticket_cipher, nullptr,
key.data.enc_key, iv);
HMAC_Init_ex(hctx, key.data.hmac_key, key.hmac_keylen, key.hmac, nullptr);
key.data.enc_key.data(), iv);
HMAC_Init_ex(hctx, key.data.hmac_key.data(), key.hmac_keylen, key.hmac,
nullptr);
return 1;
}
size_t i;
for (i = 0; i < keys.size(); ++i) {
auto &key = keys[i];
if (memcmp(key_name, key.data.name, sizeof(key.data.name)) == 0) {
if (std::equal(std::begin(key.data.name), std::end(key.data.name),
key_name)) {
break;
}
}
@ -243,8 +368,9 @@ int ticket_key_cb(SSL *ssl, unsigned char *key_name, unsigned char *iv,
}
auto &key = keys[i];
HMAC_Init_ex(hctx, key.data.hmac_key, key.hmac_keylen, key.hmac, nullptr);
EVP_DecryptInit_ex(ctx, key.cipher, nullptr, key.data.enc_key, iv);
HMAC_Init_ex(hctx, key.data.hmac_key.data(), key.hmac_keylen, key.hmac,
nullptr);
EVP_DecryptInit_ex(ctx, key.cipher, nullptr, key.data.enc_key.data(), iv);
return i == 0 ? 1 : 2;
}
@ -334,18 +460,23 @@ SSL_CTX *create_ssl_context(const char *private_key_file,
DIE();
}
auto ssl_opts = (SSL_OP_ALL & ~SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) |
SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION |
SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
SSL_OP_SINGLE_ECDH_USE | SSL_OP_SINGLE_DH_USE |
SSL_OP_CIPHER_SERVER_PREFERENCE |
get_config()->tls_proto_mask;
constexpr auto ssl_opts =
(SSL_OP_ALL & ~SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) | SSL_OP_NO_SSLv2 |
SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION |
SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION | SSL_OP_SINGLE_ECDH_USE |
SSL_OP_SINGLE_DH_USE | SSL_OP_CIPHER_SERVER_PREFERENCE;
SSL_CTX_set_options(ssl_ctx, ssl_opts);
SSL_CTX_set_options(ssl_ctx, ssl_opts | get_config()->tls_proto_mask);
const unsigned char sid_ctx[] = "shrpx";
SSL_CTX_set_session_id_context(ssl_ctx, sid_ctx, sizeof(sid_ctx) - 1);
SSL_CTX_set_session_cache_mode(ssl_ctx, SSL_SESS_CACHE_SERVER);
if (get_config()->session_cache_memcached_host) {
SSL_CTX_sess_set_new_cb(ssl_ctx, tls_session_new_cb);
SSL_CTX_sess_set_get_cb(ssl_ctx, tls_session_get_cb);
}
SSL_CTX_set_timeout(ssl_ctx, get_config()->tls_session_timeout.count());
const char *ciphers;
@ -493,12 +624,12 @@ SSL_CTX *create_ssl_client_context() {
DIE();
}
auto ssl_opts = (SSL_OP_ALL & ~SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) |
SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION |
SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
get_config()->tls_proto_mask;
constexpr auto ssl_opts = (SSL_OP_ALL & ~SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) |
SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
SSL_OP_NO_COMPRESSION |
SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION;
SSL_CTX_set_options(ssl_ctx, ssl_opts);
SSL_CTX_set_options(ssl_ctx, ssl_opts | get_config()->tls_proto_mask);
const char *ciphers;
if (get_config()->ciphers) {
@ -564,6 +695,17 @@ SSL_CTX *create_ssl_client_context() {
return ssl_ctx;
}
SSL *create_ssl(SSL_CTX *ssl_ctx) {
auto ssl = SSL_new(ssl_ctx);
if (!ssl) {
LOG(ERROR) << "SSL_new() failed: " << ERR_error_string(ERR_get_error(),
nullptr);
return nullptr;
}
return ssl;
}
ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
int addrlen) {
char host[NI_MAXHOST];
@ -586,21 +728,10 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
SSL *ssl = nullptr;
auto ssl_ctx = worker->get_sv_ssl_ctx();
if (ssl_ctx) {
ssl = SSL_new(ssl_ctx);
ssl = create_ssl(ssl_ctx);
if (!ssl) {
LOG(ERROR) << "SSL_new() failed: " << ERR_error_string(ERR_get_error(),
nullptr);
return nullptr;
}
if (SSL_set_fd(ssl, fd) == 0) {
LOG(ERROR) << "SSL_set_fd() failed: " << ERR_error_string(ERR_get_error(),
nullptr);
SSL_free(ssl);
return nullptr;
}
SSL_set_accept_state(ssl);
}
return new ClientHandler(worker, fd, ssl, host, service);
@ -641,8 +772,8 @@ bool tls_hostname_match(const char *pattern, const char *hostname) {
} // namespace
namespace {
int verify_hostname(const char *hostname, const sockaddr_union *su,
size_t salen, const std::vector<std::string> &dns_names,
int verify_hostname(const char *hostname, const Address *addr,
const std::vector<std::string> &dns_names,
const std::vector<std::string> &ip_addrs,
const std::string &common_name) {
if (util::numeric_host(hostname)) {
@ -650,19 +781,19 @@ int verify_hostname(const char *hostname, const sockaddr_union *su,
return util::strieq(common_name.c_str(), hostname) ? 0 : -1;
}
const void *saddr;
switch (su->storage.ss_family) {
switch (addr->su.storage.ss_family) {
case AF_INET:
saddr = &su->in.sin_addr;
saddr = &addr->su.in.sin_addr;
break;
case AF_INET6:
saddr = &su->in6.sin6_addr;
saddr = &addr->su.in6.sin6_addr;
break;
default:
return -1;
}
for (size_t i = 0; i < ip_addrs.size(); ++i) {
if (salen == ip_addrs[i].size() &&
memcmp(saddr, ip_addrs[i].c_str(), salen) == 0) {
if (addr->len == ip_addrs[i].size() &&
memcmp(saddr, ip_addrs[i].c_str(), addr->len) == 0) {
return 0;
}
}
@ -757,8 +888,8 @@ int check_cert(SSL *ssl, const DownstreamAddr *addr) {
std::vector<std::string> dns_names;
std::vector<std::string> ip_addrs;
get_altnames(cert, dns_names, ip_addrs, common_name);
if (verify_hostname(addr->host.get(), &addr->addr, addr->addrlen, dns_names,
ip_addrs, common_name) != 0) {
if (verify_hostname(addr->host.get(), &addr->addr, dns_names, ip_addrs,
common_name) != 0) {
LOG(ERROR) << "Certificate verification failed: hostname does not match";
return -1;
}

View File

@ -172,6 +172,8 @@ SSL_CTX *setup_client_ssl_context();
// this function returns nullptr.
CertLookupTree *create_cert_lookup_tree();
SSL *create_ssl(SSL_CTX *ssl_ctx);
} // namespace ssl
} // namespace shrpx

View File

@ -36,6 +36,7 @@
#include "shrpx_http2_session.h"
#include "shrpx_log_config.h"
#include "shrpx_connect_blocker.h"
#include "shrpx_memcached_dispatcher.h"
#include "util.h"
#include "template.h"
@ -75,6 +76,11 @@ Worker::Worker(struct ev_loop *loop, SSL_CTX *sv_ssl_ctx, SSL_CTX *cl_ssl_ctx,
ev_timer_init(&mcpool_clear_timer_, mcpool_clear_cb, 0., 0.);
mcpool_clear_timer_.data = this;
if (get_config()->session_cache_memcached_host) {
session_cache_memcached_dispatcher_ = make_unique<MemcachedDispatcher>(
&get_config()->session_cache_memcached_addr, loop);
}
if (get_config()->downstream_proto == PROTO_HTTP2) {
auto n = get_config()->http2_downstream_connections_per_worker;
size_t group = 0;
@ -253,4 +259,8 @@ DownstreamGroup *Worker::get_dgrp(size_t group) {
return &dgrps_[group];
}
MemcachedDispatcher *Worker::get_session_cache_memcached_dispatcher() {
return session_cache_memcached_dispatcher_.get();
}
} // namespace shrpx

View File

@ -49,6 +49,7 @@ namespace shrpx {
class Http2Session;
class ConnectBlocker;
class MemcachedDispatcher;
namespace ssl {
class CertLookupTree;
@ -121,6 +122,8 @@ public:
DownstreamGroup *get_dgrp(size_t group);
MemcachedDispatcher *get_session_cache_memcached_dispatcher();
private:
#ifndef NOTHREADS
std::future<void> fut_;
@ -133,6 +136,7 @@ private:
DownstreamConnectionPool dconn_pool_;
WorkerStat worker_stat_;
std::vector<DownstreamGroup> dgrps_;
std::unique_ptr<MemcachedDispatcher> session_cache_memcached_dispatcher_;
struct ev_loop *loop_;
// Following fields are shared across threads if

View File

@ -1130,6 +1130,41 @@ void hexdump(FILE *out, const uint8_t *src, size_t len) {
}
}
void put_uint16be(uint8_t *buf, uint16_t n) {
uint16_t x = htons(n);
memcpy(buf, &x, sizeof(uint16_t));
}
void put_uint32be(uint8_t *buf, uint32_t n) {
uint32_t x = htonl(n);
memcpy(buf, &x, sizeof(uint32_t));
}
uint16_t get_uint16(const uint8_t *data) {
uint16_t n;
memcpy(&n, data, sizeof(uint16_t));
return ntohs(n);
}
uint32_t get_uint32(const uint8_t *data) {
uint32_t n;
memcpy(&n, data, sizeof(uint32_t));
return ntohl(n);
}
uint64_t get_uint64(const uint8_t *data) {
uint64_t n = 0;
n += static_cast<uint64_t>(data[0]) << 56;
n += static_cast<uint64_t>(data[1]) << 48;
n += static_cast<uint64_t>(data[2]) << 40;
n += static_cast<uint64_t>(data[3]) << 32;
n += data[4] << 24;
n += data[5] << 16;
n += data[6] << 8;
n += data[7];
return n;
}
} // namespace util
} // namespace nghttp2

View File

@ -216,6 +216,10 @@ template <size_t N> std::string format_hex(const unsigned char (&s)[N]) {
return format_hex(s, N);
}
template <size_t N> std::string format_hex(const std::array<uint8_t, N> &s) {
return format_hex(s.data(), s.size());
}
std::string http_date(time_t t);
// Returns given time |t| from epoch in Common Log format (e.g.,
@ -631,6 +635,26 @@ std::string make_hostport(const char *host, uint16_t port);
// Dumps |src| of length |len| in the format similar to `hexdump -C`.
void hexdump(FILE *out, const uint8_t *src, size_t len);
// Copies 2 byte unsigned integer |n| in host byte order to |buf| in
// network byte order.
void put_uint16be(uint8_t *buf, uint16_t n);
// Copies 4 byte unsigned integer |n| in host byte order to |buf| in
// network byte order.
void put_uint32be(uint8_t *buf, uint32_t n);
// Retrieves 2 byte unsigned integer stored in |data| in network byte
// order and returns it in host byte order.
uint16_t get_uint16(const uint8_t *data);
// Retrieves 4 byte unsigned integer stored in |data| in network byte
// order and returns it in host byte order.
uint32_t get_uint32(const uint8_t *data);
// Retrieves 8 byte unsigned integer stored in |data| in network byte
// order and returns it in host byte order.
uint64_t get_uint64(const uint8_t *data);
} // namespace util
} // namespace nghttp2

View File

@ -393,4 +393,13 @@ void test_util_localtime_date(void) {
tzset();
}
void test_util_get_uint64(void) {
auto v = std::array<unsigned char, 8>{
{0x01, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xab, 0xbc}};
auto n = util::get_uint64(v.data());
CU_ASSERT(0x01123456789aabbcULL == n);
}
} // namespace shrpx

View File

@ -55,6 +55,7 @@ void test_util_starts_with(void);
void test_util_ends_with(void);
void test_util_parse_http_date(void);
void test_util_localtime_date(void);
void test_util_get_uint64(void);
} // namespace shrpx