From fd060eb9f1490909c9022f5eb0e4938e07cdfc1b Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa Date: Wed, 15 Sep 2021 20:07:33 +0900 Subject: [PATCH] nghttpx: Connection ID encryption --- bpf/reuseport_kern.c | 337 ++++++++++++++++++++++++++- gennghttpxfun.py | 1 + src/shrpx.cc | 24 ++ src/shrpx_config.cc | 16 ++ src/shrpx_config.h | 4 + src/shrpx_connection_handler.cc | 18 +- src/shrpx_http3_upstream.cc | 22 +- src/shrpx_quic.cc | 47 +++- src/shrpx_quic.h | 13 +- src/shrpx_quic_connection_handler.cc | 47 ++-- src/shrpx_worker.cc | 25 +- 11 files changed, 513 insertions(+), 41 deletions(-) diff --git a/bpf/reuseport_kern.c b/bpf/reuseport_kern.c index 710ff2ce..25e1102a 100644 --- a/bpf/reuseport_kern.c +++ b/bpf/reuseport_kern.c @@ -38,6 +38,304 @@ * how to install kernel header files. */ +/* AES_CBC_decrypt_buffer: https://github.com/kokke/tiny-AES-c + License is Public Domain. Commit hash: + 12e7744b4919e9d55de75b7ab566326a1c8e7a67 */ + +#define AES_BLOCKLEN \ + 16 /* Block length in bytes - AES is 128b block \ + only */ + +#define AES_KEYLEN 16 /* Key length in bytes */ +#define AES_keyExpSize 176 + +struct AES_ctx { + __u8 RoundKey[AES_keyExpSize]; +}; + +/* The number of columns comprising a state in AES. This is a constant + in AES. Value=4 */ +#define Nb 4 + +#define Nk 4 /* The number of 32 bit words in a key. */ +#define Nr 10 /* The number of rounds in AES Cipher. */ + +/* state - array holding the intermediate results during + decryption. */ +typedef __u8 state_t[4][4]; + +/* The lookup-tables are marked const so they can be placed in + read-only storage instead of RAM The numbers below can be computed + dynamically trading ROM for RAM - This can be useful in (embedded) + bootloader applications, where ROM is often limited. */ +static const __u8 sbox[256] = { + /* 0 1 2 3 4 5 6 7 8 9 A B C D E F */ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, + 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, + 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, + 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, + 0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, + 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, + 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, + 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, + 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, + 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, + 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, + 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d, + 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, + 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, + 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, + 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, + 0xb0, 0x54, 0xbb, 0x16}; + +static const __u8 rsbox[256] = { + 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, + 0x81, 0xf3, 0xd7, 0xfb, 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, + 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb, 0x54, 0x7b, 0x94, 0x32, + 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, + 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, + 0x6d, 0x8b, 0xd1, 0x25, 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, + 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92, 0x6c, 0x70, 0x48, 0x50, + 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, + 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, + 0xb8, 0xb3, 0x45, 0x06, 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, + 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, 0x3a, 0x91, 0x11, 0x41, + 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, + 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, + 0x1c, 0x75, 0xdf, 0x6e, 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, + 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, 0xfc, 0x56, 0x3e, 0x4b, + 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, + 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, + 0x27, 0x80, 0xec, 0x5f, 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, + 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, 0xa0, 0xe0, 0x3b, 0x4d, + 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, + 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, + 0x55, 0x21, 0x0c, 0x7d}; + +/* The round constant word array, Rcon[i], contains the values given + by x to the power (i-1) being powers of x (x is denoted as {02}) in + the field GF(2^8) */ +static const __u8 Rcon[11] = {0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, + 0x20, 0x40, 0x80, 0x1b, 0x36}; + +#define getSBoxValue(num) (sbox[(num)]) + +/* This function produces Nb(Nr+1) round keys. The round keys are used + in each round to decrypt the states. */ +static void KeyExpansion(__u8 *RoundKey, const __u8 *Key) { + unsigned i, j, k; + __u8 tempa[4]; /* Used for the column/row operations */ + + /* The first round key is the key itself. */ + for (i = 0; i < Nk; ++i) { + RoundKey[(i * 4) + 0] = Key[(i * 4) + 0]; + RoundKey[(i * 4) + 1] = Key[(i * 4) + 1]; + RoundKey[(i * 4) + 2] = Key[(i * 4) + 2]; + RoundKey[(i * 4) + 3] = Key[(i * 4) + 3]; + } + + /* All other round keys are found from the previous round keys. */ + for (i = Nk; i < Nb * (Nr + 1); ++i) { + { + k = (i - 1) * 4; + tempa[0] = RoundKey[k + 0]; + tempa[1] = RoundKey[k + 1]; + tempa[2] = RoundKey[k + 2]; + tempa[3] = RoundKey[k + 3]; + } + + if (i % Nk == 0) { + /* This function shifts the 4 bytes in a word to the left once. + [a0,a1,a2,a3] becomes [a1,a2,a3,a0] */ + + /* Function RotWord() */ + { + const __u8 u8tmp = tempa[0]; + tempa[0] = tempa[1]; + tempa[1] = tempa[2]; + tempa[2] = tempa[3]; + tempa[3] = u8tmp; + } + + /* SubWord() is a function that takes a four-byte input word and + applies the S-box to each of the four bytes to produce an + output word. */ + + /* Function Subword() */ + { + tempa[0] = getSBoxValue(tempa[0]); + tempa[1] = getSBoxValue(tempa[1]); + tempa[2] = getSBoxValue(tempa[2]); + tempa[3] = getSBoxValue(tempa[3]); + } + + tempa[0] = tempa[0] ^ Rcon[i / Nk]; + } + j = i * 4; + k = (i - Nk) * 4; + RoundKey[j + 0] = RoundKey[k + 0] ^ tempa[0]; + RoundKey[j + 1] = RoundKey[k + 1] ^ tempa[1]; + RoundKey[j + 2] = RoundKey[k + 2] ^ tempa[2]; + RoundKey[j + 3] = RoundKey[k + 3] ^ tempa[3]; + } +} + +static void AES_init_ctx(struct AES_ctx *ctx, const __u8 *key) { + KeyExpansion(ctx->RoundKey, key); +} + +/* This function adds the round key to state. The round key is added + to the state by an XOR function. */ +static void AddRoundKey(__u8 round, state_t *state, const __u8 *RoundKey) { + __u8 i, j; + for (i = 0; i < 4; ++i) { + for (j = 0; j < 4; ++j) { + (*state)[i][j] ^= RoundKey[(round * Nb * 4) + (i * Nb) + j]; + } + } +} + +static __u8 xtime(__u8 x) { return ((x << 1) ^ (((x >> 7) & 1) * 0x1b)); } + +#define Multiply(x, y) \ + (((y & 1) * x) ^ ((y >> 1 & 1) * xtime(x)) ^ \ + ((y >> 2 & 1) * xtime(xtime(x))) ^ \ + ((y >> 3 & 1) * xtime(xtime(xtime(x)))) ^ \ + ((y >> 4 & 1) * xtime(xtime(xtime(xtime(x)))))) + +#define getSBoxInvert(num) (rsbox[(num)]) + +/* MixColumns function mixes the columns of the state matrix. The + method used to multiply may be difficult to understand for the + inexperienced. Please use the references to gain more + information. */ +static void InvMixColumns(state_t *state) { + int i; + __u8 a, b, c, d; + for (i = 0; i < 4; ++i) { + a = (*state)[i][0]; + b = (*state)[i][1]; + c = (*state)[i][2]; + d = (*state)[i][3]; + + (*state)[i][0] = Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ + Multiply(d, 0x09); + (*state)[i][1] = Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ + Multiply(d, 0x0d); + (*state)[i][2] = Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ + Multiply(d, 0x0b); + (*state)[i][3] = Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ + Multiply(d, 0x0e); + } +} + +/* The SubBytes Function Substitutes the values in the state matrix + with values in an S-box. */ +static void InvSubBytes(state_t *state) { + __u8 i, j; + for (i = 0; i < 4; ++i) { + for (j = 0; j < 4; ++j) { + (*state)[j][i] = getSBoxInvert((*state)[j][i]); + } + } +} + +static void InvShiftRows(state_t *state) { + __u8 temp; + + /* Rotate first row 1 columns to right */ + temp = (*state)[3][1]; + (*state)[3][1] = (*state)[2][1]; + (*state)[2][1] = (*state)[1][1]; + (*state)[1][1] = (*state)[0][1]; + (*state)[0][1] = temp; + + /* Rotate second row 2 columns to right */ + temp = (*state)[0][2]; + (*state)[0][2] = (*state)[2][2]; + (*state)[2][2] = temp; + + temp = (*state)[1][2]; + (*state)[1][2] = (*state)[3][2]; + (*state)[3][2] = temp; + + /* Rotate third row 3 columns to right */ + temp = (*state)[0][3]; + (*state)[0][3] = (*state)[1][3]; + (*state)[1][3] = (*state)[2][3]; + (*state)[2][3] = (*state)[3][3]; + (*state)[3][3] = temp; +} + +static void InvCipher(state_t *state, const __u8 *RoundKey) { + /* Add the First round key to the state before starting the + rounds. */ + AddRoundKey(Nr, state, RoundKey); + + /* There will be Nr rounds. The first Nr-1 rounds are identical. + These Nr rounds are executed in the loop below. Last one without + InvMixColumn() */ + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 1, state, RoundKey); + InvMixColumns(state); + + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 2, state, RoundKey); + InvMixColumns(state); + + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 3, state, RoundKey); + InvMixColumns(state); + + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 4, state, RoundKey); + InvMixColumns(state); + + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 5, state, RoundKey); + InvMixColumns(state); + + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 6, state, RoundKey); + InvMixColumns(state); + + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 7, state, RoundKey); + InvMixColumns(state); + + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 8, state, RoundKey); + InvMixColumns(state); + + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 9, state, RoundKey); + InvMixColumns(state); + + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(Nr - 10, state, RoundKey); +} + +static void AES_ECB_decrypt(const struct AES_ctx *ctx, __u8 *buf) { + /* The next function call decrypts the PlainText with the Key using + AES algorithm. */ + InvCipher((state_t *)buf, ctx->RoundKey); +} + /* rol32: From linux kernel source code */ /** @@ -125,13 +423,13 @@ struct bpf_map_def SEC("maps") reuseport_array = { struct bpf_map_def SEC("maps") sk_info = { .type = BPF_MAP_TYPE_ARRAY, - .max_entries = 1, + .max_entries = 3, .key_size = sizeof(__u32), - .value_size = sizeof(__u32), + .value_size = sizeof(__u64), }; typedef struct quic_hd { - const __u8 *dcid; + __u8 *dcid; __u32 dcidlen; __u32 dcid_offset; __u8 type; @@ -149,9 +447,8 @@ enum { NGTCP2_PKT_SHORT = 0x40, }; -static inline int parse_quic(quic_hd *qhd, const __u8 *data, - const __u8 *data_end) { - const __u8 *p; +static inline int parse_quic(quic_hd *qhd, __u8 *data, __u8 *data_end) { + __u8 *p; __u64 dcidlen; if (*data & 0x80) { @@ -192,7 +489,7 @@ static __u32 hash(const __u8 *data, __u32 datalen, __u32 initval) { static __u32 sk_index_from_dcid(const quic_hd *qhd, const struct sk_reuseport_md *reuse_md, - __u32 num_socks) { + __u64 num_socks) { __u32 len = qhd->dcidlen; __u32 h = reuse_md->hash; __u8 hbuf[8]; @@ -259,11 +556,13 @@ static __u32 sk_index_from_dcid(const quic_hd *qhd, SEC("sk_reuseport") int select_reuseport(struct sk_reuseport_md *reuse_md) { __u32 sk_index, *psk_index; - __u32 *pnum_socks; - __u32 zero = 0; + __u64 *pnum_socks, *pkey; + __u32 zero = 0, key_high_idx = 1, key_low_idx = 2; int rv; quic_hd qhd; __u8 qpktbuf[6 + MAX_DCIDLEN]; + struct AES_ctx aes_ctx; + __u8 key[AES_KEYLEN]; if (bpf_skb_load_bytes(reuse_md, sizeof(struct udphdr), qpktbuf, sizeof(qpktbuf)) != 0) { @@ -275,15 +574,33 @@ int select_reuseport(struct sk_reuseport_md *reuse_md) { return SK_DROP; } + pkey = bpf_map_lookup_elem(&sk_info, &key_high_idx); + if (pkey == NULL) { + return SK_DROP; + } + + __builtin_memcpy(key, pkey, sizeof(*pkey)); + + pkey = bpf_map_lookup_elem(&sk_info, &key_low_idx); + if (pkey == NULL) { + return SK_DROP; + } + + __builtin_memcpy(key + sizeof(*pkey), pkey, sizeof(*pkey)); + rv = parse_quic(&qhd, qpktbuf, qpktbuf + sizeof(qpktbuf)); if (rv != 0) { return SK_DROP; } + AES_init_ctx(&aes_ctx, key); + switch (qhd.type) { case NGTCP2_PKT_INITIAL: case NGTCP2_PKT_0RTT: if (qhd.dcidlen == SV_DCIDLEN) { + AES_ECB_decrypt(&aes_ctx, qhd.dcid); + psk_index = bpf_map_lookup_elem(&cid_prefix_map, qhd.dcid); if (psk_index != NULL) { sk_index = *psk_index; @@ -301,6 +618,8 @@ int select_reuseport(struct sk_reuseport_md *reuse_md) { return SK_DROP; } + AES_ECB_decrypt(&aes_ctx, qhd.dcid); + psk_index = bpf_map_lookup_elem(&cid_prefix_map, qhd.dcid); if (psk_index == NULL) { sk_index = sk_index_from_dcid(&qhd, reuse_md, *pnum_socks); diff --git a/gennghttpxfun.py b/gennghttpxfun.py index 13fecf26..17eefd7b 100755 --- a/gennghttpxfun.py +++ b/gennghttpxfun.py @@ -192,6 +192,7 @@ OPTIONS = [ "frontend-quic-qlog-dir", "frontend-quic-require-token", "frontend-quic-congestion-controller", + "frontend-quic-connection-id-encryption-key", ] LOGVARS = [ diff --git a/src/shrpx.cc b/src/shrpx.cc index c60e4143..1aa47dc4 100644 --- a/src/shrpx.cc +++ b/src/shrpx.cc @@ -1856,6 +1856,14 @@ void fill_default_config(Config *config) { bpfconf.prog_file = StringRef::from_lit(PKGLIBDIR "/reuseport_kern.o"); upstreamconf.congestion_controller = NGTCP2_CC_ALGO_CUBIC; + + // TODO Not really nice to generate random key here, but fine for + // now. + if (RAND_bytes(upstreamconf.cid_encryption_key.data(), + upstreamconf.cid_encryption_key.size()) != 1) { + assert(0); + abort(); + } } auto &http3conf = config->http3; @@ -3237,6 +3245,14 @@ HTTP/3 and QUIC: ? "cubic" : "bbr") << R"( + --frontend-quic-connection-id-encryption-key= + Specify Connection ID encryption key. The encryption + key must be 16 bytes, and it must be encoded in hex + string (which is 32 bytes long). If this option is + omitted, new key is generated. In order to survive QUIC + connection in a configuration reload event, old and new + configuration must have this option and share the same + key. --no-quic-bpf Disable eBPF. --frontend-http3-window-size= @@ -4035,6 +4051,8 @@ int main(int argc, char **argv) { 182}, {SHRPX_OPT_FRONTEND_QUIC_CONGESTION_CONTROLLER.c_str(), required_argument, &flag, 183}, + {SHRPX_OPT_FRONTEND_QUIC_CONNECTION_ID_ENCRYPTION_KEY.c_str(), + required_argument, &flag, 184}, {nullptr, 0, nullptr, 0}}; int option_index = 0; @@ -4912,6 +4930,12 @@ int main(int argc, char **argv) { cmdcfgs.emplace_back(SHRPX_OPT_FRONTEND_QUIC_CONGESTION_CONTROLLER, StringRef{optarg}); break; + case 184: + // --frontend-quic-connection-id-encryption-key + cmdcfgs.emplace_back( + SHRPX_OPT_FRONTEND_QUIC_CONNECTION_ID_ENCRYPTION_KEY, + StringRef{optarg}); + break; default: break; } diff --git a/src/shrpx_config.cc b/src/shrpx_config.cc index be272f62..9a2f24a7 100644 --- a/src/shrpx_config.cc +++ b/src/shrpx_config.cc @@ -2684,6 +2684,10 @@ int option_lookup_token(const char *name, size_t namelen) { case 42: switch (name[41]) { case 'y': + if (util::strieq_l("frontend-quic-connection-id-encryption-ke", name, + 41)) { + return SHRPX_OPTID_FRONTEND_QUIC_CONNECTION_ID_ENCRYPTION_KEY; + } if (util::strieq_l("tls-session-cache-memcached-address-famil", name, 41)) { return SHRPX_OPTID_TLS_SESSION_CACHE_MEMCACHED_ADDRESS_FAMILY; @@ -4013,6 +4017,18 @@ int parse_config(Config *config, int optid, const StringRef &opt, } #endif // ENABLE_HTTP3 + return 0; + case SHRPX_OPTID_FRONTEND_QUIC_CONNECTION_ID_ENCRYPTION_KEY: +#ifdef ENABLE_HTTP3 + if (optarg.size() != config->quic.upstream.cid_encryption_key.size() * 2 || + !util::is_hex_string(optarg)) { + LOG(ERROR) << opt << ": must be a hex-string"; + return -1; + } + util::decode_hex(std::begin(config->quic.upstream.cid_encryption_key), + optarg); +#endif // ENABLE_HTTP3 + return 0; case SHRPX_OPTID_CONF: LOG(WARN) << "conf: ignored"; diff --git a/src/shrpx_config.h b/src/shrpx_config.h index c46ebe5b..6202b174 100644 --- a/src/shrpx_config.h +++ b/src/shrpx_config.h @@ -391,6 +391,8 @@ constexpr auto SHRPX_OPT_FRONTEND_QUIC_REQUIRE_TOKEN = StringRef::from_lit("frontend-quic-require-token"); constexpr auto SHRPX_OPT_FRONTEND_QUIC_CONGESTION_CONTROLLER = StringRef::from_lit("frontend-quic-congestion-controller"); +constexpr auto SHRPX_OPT_FRONTEND_QUIC_CONNECTION_ID_ENCRYPTION_KEY = + StringRef::from_lit("frontend-quic-connection-id-encryption-key"); constexpr size_t SHRPX_OBFUSCATED_NODE_LENGTH = 8; @@ -761,6 +763,7 @@ struct QUICConfig { ngtcp2_cc_algo congestion_controller; bool early_data; bool require_token; + std::array cid_encryption_key; } upstream; struct { StringRef prog_file; @@ -1214,6 +1217,7 @@ enum { SHRPX_OPTID_FRONTEND_MAX_REQUESTS, SHRPX_OPTID_FRONTEND_NO_TLS, SHRPX_OPTID_FRONTEND_QUIC_CONGESTION_CONTROLLER, + SHRPX_OPTID_FRONTEND_QUIC_CONNECTION_ID_ENCRYPTION_KEY, SHRPX_OPTID_FRONTEND_QUIC_DEBUG_LOG, SHRPX_OPTID_FRONTEND_QUIC_EARLY_DATA, SHRPX_OPTID_FRONTEND_QUIC_IDLE_TIMEOUT, diff --git a/src/shrpx_connection_handler.cc b/src/shrpx_connection_handler.cc index c32eb86f..c1c72a02 100644 --- a/src/shrpx_connection_handler.cc +++ b/src/shrpx_connection_handler.cc @@ -1265,8 +1265,8 @@ int ConnectionHandler::quic_ipc_read() { return -1; } - if (dcidlen < SHRPX_QUIC_CID_PREFIXLEN) { - LOG(ERROR) << "DCID is too short"; + if (dcidlen != SHRPX_QUIC_SCIDLEN) { + LOG(ERROR) << "DCID length is invalid"; return -1; } @@ -1287,8 +1287,20 @@ int ConnectionHandler::quic_ipc_read() { return 0; } + auto config = get_config(); + auto &quicconf = config->quic; + + std::array decrypted_dcid; + + if (decrypt_quic_connection_id(decrypted_dcid.data(), dcid, + quicconf.upstream.cid_encryption_key.data()) != + 0) { + return -1; + } + for (auto &worker : workers_) { - if (!std::equal(dcid, dcid + SHRPX_QUIC_CID_PREFIXLEN, + if (!std::equal(std::begin(decrypted_dcid), + std::begin(decrypted_dcid) + SHRPX_QUIC_CID_PREFIXLEN, worker->get_cid_prefix())) { continue; } diff --git a/src/shrpx_http3_upstream.cc b/src/shrpx_http3_upstream.cc index 492bea4b..e21ec9bd 100644 --- a/src/shrpx_http3_upstream.cc +++ b/src/shrpx_http3_upstream.cc @@ -216,7 +216,12 @@ int get_new_connection_id(ngtcp2_conn *conn, ngtcp2_cid *cid, uint8_t *token, auto handler = upstream->get_client_handler(); auto worker = handler->get_worker(); - if (generate_quic_connection_id(cid, cidlen, worker->get_cid_prefix()) != 0) { + auto config = get_config(); + auto &quicconf = config->quic; + + if (generate_encrypted_quic_connection_id( + cid, cidlen, worker->get_cid_prefix(), + quicconf.upstream.cid_encryption_key.data()) != 0) { return NGTCP2_ERR_CALLBACK_FAILURE; } @@ -546,17 +551,18 @@ int Http3Upstream::init(const UpstreamAddr *faddr, const Address &remote_addr, shrpx::stream_stop_sending, }; - ngtcp2_cid scid; - - if (generate_quic_connection_id(&scid, SHRPX_QUIC_SCIDLEN, - worker->get_cid_prefix()) != 0) { - return -1; - } - auto config = get_config(); auto &quicconf = config->quic; auto &http3conf = config->http3; + ngtcp2_cid scid; + + if (generate_encrypted_quic_connection_id( + &scid, SHRPX_QUIC_SCIDLEN, worker->get_cid_prefix(), + quicconf.upstream.cid_encryption_key.data()) != 0) { + return -1; + } + ngtcp2_settings settings; ngtcp2_settings_default(&settings); if (quicconf.upstream.debug.log) { diff --git a/src/shrpx_quic.cc b/src/shrpx_quic.cc index 0cde3618..1c6bee52 100644 --- a/src/shrpx_quic.cc +++ b/src/shrpx_quic.cc @@ -155,8 +155,9 @@ int generate_quic_connection_id(ngtcp2_cid *cid, size_t cidlen) { return 0; } -int generate_quic_connection_id(ngtcp2_cid *cid, size_t cidlen, - const uint8_t *cid_prefix) { +int generate_encrypted_quic_connection_id(ngtcp2_cid *cid, size_t cidlen, + const uint8_t *cid_prefix, + const uint8_t *key) { assert(cidlen > SHRPX_QUIC_CID_PREFIXLEN); auto p = std::copy_n(cid_prefix, SHRPX_QUIC_CID_PREFIXLEN, cid->data); @@ -167,6 +168,48 @@ int generate_quic_connection_id(ngtcp2_cid *cid, size_t cidlen, cid->datalen = cidlen; + return encrypt_quic_connection_id(cid->data, cid->data, key); +} + +int encrypt_quic_connection_id(uint8_t *dest, const uint8_t *src, + const uint8_t *key) { + auto ctx = EVP_CIPHER_CTX_new(); + auto d = defer(EVP_CIPHER_CTX_free, ctx); + + if (!EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), nullptr, key, nullptr)) { + return -1; + } + + EVP_CIPHER_CTX_set_padding(ctx, 0); + + int len; + + if (!EVP_EncryptUpdate(ctx, dest, &len, src, SHRPX_QUIC_DECRYPTED_DCIDLEN) || + !EVP_EncryptFinal_ex(ctx, dest + len, &len)) { + return -1; + } + + return 0; +} + +int decrypt_quic_connection_id(uint8_t *dest, const uint8_t *src, + const uint8_t *key) { + auto ctx = EVP_CIPHER_CTX_new(); + auto d = defer(EVP_CIPHER_CTX_free, ctx); + + if (!EVP_DecryptInit_ex(ctx, EVP_aes_128_ecb(), nullptr, key, nullptr)) { + return -1; + } + + EVP_CIPHER_CTX_set_padding(ctx, 0); + + int len; + + if (!EVP_DecryptUpdate(ctx, dest, &len, src, SHRPX_QUIC_DECRYPTED_DCIDLEN) || + !EVP_DecryptFinal_ex(ctx, dest + len, &len)) { + return -1; + } + return 0; } diff --git a/src/shrpx_quic.h b/src/shrpx_quic.h index 2fa12e2d..4938a15c 100644 --- a/src/shrpx_quic.h +++ b/src/shrpx_quic.h @@ -59,6 +59,8 @@ struct UpstreamAddr; constexpr size_t SHRPX_QUIC_SCIDLEN = 20; constexpr size_t SHRPX_QUIC_CID_PREFIXLEN = 8; +constexpr size_t SHRPX_QUIC_DECRYPTED_DCIDLEN = 16; +constexpr size_t SHRPX_QUIC_CID_ENCRYPTION_KEYLEN = 16; constexpr size_t SHRPX_QUIC_MAX_UDP_PAYLOAD_SIZE = 1472; constexpr size_t SHRPX_QUIC_STATELESS_RESET_SECRETLEN = 32; constexpr size_t SHRPX_QUIC_TOKEN_SECRETLEN = 32; @@ -74,8 +76,15 @@ int quic_send_packet(const UpstreamAddr *faddr, const sockaddr *remote_sa, int generate_quic_connection_id(ngtcp2_cid *cid, size_t cidlen); -int generate_quic_connection_id(ngtcp2_cid *cid, size_t cidlen, - const uint8_t *cid_prefix); +int generate_encrypted_quic_connection_id(ngtcp2_cid *cid, size_t cidlen, + const uint8_t *cid_prefix, + const uint8_t *key); + +int encrypt_quic_connection_id(uint8_t *dest, const uint8_t *src, + const uint8_t *key); + +int decrypt_quic_connection_id(uint8_t *dest, const uint8_t *src, + const uint8_t *key); int generate_quic_stateless_reset_token(uint8_t *token, const ngtcp2_cid *cid, const uint8_t *secret, diff --git a/src/shrpx_quic_connection_handler.cc b/src/shrpx_quic_connection_handler.cc index c792f489..49991f03 100644 --- a/src/shrpx_quic_connection_handler.cc +++ b/src/shrpx_quic_connection_handler.cc @@ -90,20 +90,33 @@ int QUICConnectionHandler::handle_packet(const UpstreamAddr *faddr, ClientHandler *handler; + auto &quicconf = config->quic; + auto it = connections_.find(dcid_key); if (it == std::end(connections_)) { - if (!std::equal(dcid, dcid + SHRPX_QUIC_CID_PREFIXLEN, - worker_->get_cid_prefix())) { - auto quic_lwp = - conn_handler->match_quic_lingering_worker_process_cid_prefix(dcid, - dcidlen); - if (quic_lwp) { - if (conn_handler->forward_quic_packet_to_lingering_worker_process( - quic_lwp, remote_addr, local_addr, data, datalen) == 0) { + std::array decrypted_dcid; + + if (dcidlen == SHRPX_QUIC_SCIDLEN) { + if (decrypt_quic_connection_id( + decrypted_dcid.data(), dcid, + quicconf.upstream.cid_encryption_key.data()) != 0) { + return 0; + } + + if (!std::equal(std::begin(decrypted_dcid), + std::begin(decrypted_dcid) + SHRPX_QUIC_CID_PREFIXLEN, + worker_->get_cid_prefix())) { + auto quic_lwp = + conn_handler->match_quic_lingering_worker_process_cid_prefix( + decrypted_dcid.data(), decrypted_dcid.size()); + if (quic_lwp) { + if (conn_handler->forward_quic_packet_to_lingering_worker_process( + quic_lwp, remote_addr, local_addr, data, datalen) == 0) { + return 0; + } + return 0; } - - return 0; } } @@ -134,14 +147,14 @@ int QUICConnectionHandler::handle_packet(const UpstreamAddr *faddr, const uint8_t *token = nullptr; size_t tokenlen = 0; - auto &quicconf = config->quic; - switch (ngtcp2_accept(&hd, data, datalen)) { case 0: { // If we get Initial and it has the CID prefix of this worker, it // is likely that client is intentionally use the our prefix. // Just drop it. - if (std::equal(dcid, dcid + SHRPX_QUIC_CID_PREFIXLEN, + if (dcidlen == SHRPX_QUIC_SCIDLEN && + std::equal(std::begin(decrypted_dcid), + std::begin(decrypted_dcid) + SHRPX_QUIC_CID_PREFIXLEN, worker_->get_cid_prefix())) { return 0; } @@ -237,11 +250,13 @@ int QUICConnectionHandler::handle_packet(const UpstreamAddr *faddr, return 0; default: if (!config->single_thread && !(data[0] & 0x80) && - dcidlen > SHRPX_QUIC_CID_PREFIXLEN && - !std::equal(dcid, dcid + SHRPX_QUIC_CID_PREFIXLEN, + dcidlen == SHRPX_QUIC_SCIDLEN && + !std::equal(std::begin(decrypted_dcid), + std::begin(decrypted_dcid) + SHRPX_QUIC_CID_PREFIXLEN, worker_->get_cid_prefix())) { if (conn_handler->forward_quic_packet(faddr, remote_addr, local_addr, - dcid, data, datalen) == 0) { + decrypted_dcid.data(), data, + datalen) == 0) { return 0; } } diff --git a/src/shrpx_worker.cc b/src/shrpx_worker.cc index 9dbdd28f..75bd7325 100644 --- a/src/shrpx_worker.cc +++ b/src/shrpx_worker.cc @@ -923,7 +923,7 @@ int Worker::create_quic_server_socket(UpstreamAddr &faddr) { } constexpr uint32_t zero = 0; - uint32_t num_socks = config->num_worker; + uint64_t num_socks = config->num_worker; if (bpf_map_update_elem(bpf_map__fd(sk_info), &zero, &num_socks, BPF_ANY) != 0) { @@ -933,6 +933,29 @@ int Worker::create_quic_server_socket(UpstreamAddr &faddr) { return -1; } + auto &quicconf = config->quic; + + constexpr uint32_t key_high_idx = 1; + constexpr uint32_t key_low_idx = 2; + + if (bpf_map_update_elem(bpf_map__fd(sk_info), &key_high_idx, + quicconf.upstream.cid_encryption_key.data(), + BPF_ANY) != 0) { + LOG(FATAL) << "Failed to update key_high_idx sk_info: " + << xsi_strerror(errno, errbuf.data(), errbuf.size()); + close(fd); + return -1; + } + + if (bpf_map_update_elem(bpf_map__fd(sk_info), &key_low_idx, + quicconf.upstream.cid_encryption_key.data() + 8, + BPF_ANY) != 0) { + LOG(FATAL) << "Failed to update key_low_idx sk_info: " + << xsi_strerror(errno, errbuf.data(), errbuf.size()); + close(fd); + return -1; + } + auto prog_fd = bpf_program__fd(prog); if (setsockopt(fd, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &prog_fd,