diff --git a/bpf/reuseport_kern.c b/bpf/reuseport_kern.c index d0e473e8..710ff2ce 100644 --- a/bpf/reuseport_kern.c +++ b/bpf/reuseport_kern.c @@ -157,8 +157,8 @@ static inline int parse_quic(quic_hd *qhd, const __u8 *data, if (*data & 0x80) { p = data + 1 + 4; - // Do not check the actual DCID length because we might not buffer - // whole DCID here. + /* Do not check the actual DCID length because we might not buffer + entire DCID here. */ dcidlen = *p; if (dcidlen > MAX_DCIDLEN || dcidlen < MIN_DCIDLEN) { @@ -181,23 +181,101 @@ static inline int parse_quic(quic_hd *qhd, const __u8 *data, return 0; } +static __u32 hash(const __u8 *data, __u32 datalen, __u32 initval) { + __u32 a, b; + + a = (data[0] << 24) | (data[1] << 16) | (data[2] << 8) | data[3]; + b = (data[4] << 24) | (data[5] << 16) | (data[6] << 8) | data[7]; + + return jhash_2words(a, b, initval); +} + +static __u32 sk_index_from_dcid(const quic_hd *qhd, + const struct sk_reuseport_md *reuse_md, + __u32 num_socks) { + __u32 len = qhd->dcidlen; + __u32 h = reuse_md->hash; + __u8 hbuf[8]; + + if (len > 16) { + __builtin_memset(hbuf, 0, sizeof(hbuf)); + + switch (len) { + case 20: + __builtin_memcpy(hbuf, qhd->dcid + 16, 4); + break; + case 19: + __builtin_memcpy(hbuf, qhd->dcid + 16, 3); + break; + case 18: + __builtin_memcpy(hbuf, qhd->dcid + 16, 2); + break; + case 17: + __builtin_memcpy(hbuf, qhd->dcid + 16, 1); + break; + } + + h = hash(hbuf, sizeof(hbuf), h); + len = 16; + } + + if (len > 8) { + __builtin_memset(hbuf, 0, sizeof(hbuf)); + + switch (len) { + case 16: + __builtin_memcpy(hbuf, qhd->dcid + 8, 8); + break; + case 15: + __builtin_memcpy(hbuf, qhd->dcid + 8, 7); + break; + case 14: + __builtin_memcpy(hbuf, qhd->dcid + 8, 6); + break; + case 13: + __builtin_memcpy(hbuf, qhd->dcid + 8, 5); + break; + case 12: + __builtin_memcpy(hbuf, qhd->dcid + 8, 4); + break; + case 11: + __builtin_memcpy(hbuf, qhd->dcid + 8, 3); + break; + case 10: + __builtin_memcpy(hbuf, qhd->dcid + 8, 2); + break; + case 9: + __builtin_memcpy(hbuf, qhd->dcid + 8, 1); + break; + } + + h = hash(hbuf, sizeof(hbuf), h); + len = 8; + } + + return hash(qhd->dcid, len, h) % num_socks; +} + SEC("sk_reuseport") int select_reuseport(struct sk_reuseport_md *reuse_md) { __u32 sk_index, *psk_index; - __u8 sk_prefix[8]; __u32 *pnum_socks; __u32 zero = 0; int rv; quic_hd qhd; - __u32 a, b; - __u8 pkt_databuf[6 + MAX_DCIDLEN]; + __u8 qpktbuf[6 + MAX_DCIDLEN]; - if (bpf_skb_load_bytes(reuse_md, sizeof(struct udphdr), pkt_databuf, - sizeof(pkt_databuf)) != 0) { + if (bpf_skb_load_bytes(reuse_md, sizeof(struct udphdr), qpktbuf, + sizeof(qpktbuf)) != 0) { return SK_DROP; } - rv = parse_quic(&qhd, pkt_databuf, pkt_databuf + sizeof(pkt_databuf)); + pnum_socks = bpf_map_lookup_elem(&sk_info, &zero); + if (pnum_socks == NULL) { + return SK_DROP; + } + + rv = parse_quic(&qhd, qpktbuf, qpktbuf + sizeof(qpktbuf)); if (rv != 0) { return SK_DROP; } @@ -205,27 +283,16 @@ int select_reuseport(struct sk_reuseport_md *reuse_md) { switch (qhd.type) { case NGTCP2_PKT_INITIAL: case NGTCP2_PKT_0RTT: - __builtin_memcpy(sk_prefix, pkt_databuf + qhd.dcid_offset, CID_PREFIXLEN); - if (qhd.dcidlen == SV_DCIDLEN) { - psk_index = bpf_map_lookup_elem(&cid_prefix_map, sk_prefix); + psk_index = bpf_map_lookup_elem(&cid_prefix_map, qhd.dcid); if (psk_index != NULL) { sk_index = *psk_index; + break; } } - pnum_socks = bpf_map_lookup_elem(&sk_info, &zero); - if (pnum_socks == NULL) { - return SK_DROP; - } - - a = (sk_prefix[0] << 24) | (sk_prefix[1] << 16) | (sk_prefix[2] << 8) | - sk_prefix[3]; - b = (sk_prefix[4] << 24) | (sk_prefix[5] << 16) | (sk_prefix[6] << 8) | - sk_prefix[7]; - - sk_index = jhash_2words(a, b, reuse_md->hash) % *pnum_socks; + sk_index = sk_index_from_dcid(&qhd, reuse_md, *pnum_socks); break; case NGTCP2_PKT_HANDSHAKE: @@ -234,21 +301,9 @@ int select_reuseport(struct sk_reuseport_md *reuse_md) { return SK_DROP; } - __builtin_memcpy(sk_prefix, pkt_databuf + qhd.dcid_offset, CID_PREFIXLEN); - - psk_index = bpf_map_lookup_elem(&cid_prefix_map, sk_prefix); + psk_index = bpf_map_lookup_elem(&cid_prefix_map, qhd.dcid); if (psk_index == NULL) { - pnum_socks = bpf_map_lookup_elem(&sk_info, &zero); - if (pnum_socks == NULL) { - return SK_DROP; - } - - a = (sk_prefix[0] << 24) | (sk_prefix[1] << 16) | (sk_prefix[2] << 8) | - sk_prefix[3]; - b = (sk_prefix[4] << 24) | (sk_prefix[5] << 16) | (sk_prefix[6] << 8) | - sk_prefix[7]; - - sk_index = jhash_2words(a, b, reuse_md->hash) % *pnum_socks; + sk_index = sk_index_from_dcid(&qhd, reuse_md, *pnum_socks); break; }