bpf: Take into account entire DCID

This commit is contained in:
Tatsuhiro Tsujikawa 2021-09-04 18:18:17 +09:00
parent 47edc33b0d
commit 1c7a001489
1 changed files with 91 additions and 36 deletions

View File

@ -157,8 +157,8 @@ static inline int parse_quic(quic_hd *qhd, const __u8 *data,
if (*data & 0x80) {
p = data + 1 + 4;
// Do not check the actual DCID length because we might not buffer
// whole DCID here.
/* Do not check the actual DCID length because we might not buffer
entire DCID here. */
dcidlen = *p;
if (dcidlen > MAX_DCIDLEN || dcidlen < MIN_DCIDLEN) {
@ -181,51 +181,118 @@ static inline int parse_quic(quic_hd *qhd, const __u8 *data,
return 0;
}
static __u32 hash(const __u8 *data, __u32 datalen, __u32 initval) {
__u32 a, b;
a = (data[0] << 24) | (data[1] << 16) | (data[2] << 8) | data[3];
b = (data[4] << 24) | (data[5] << 16) | (data[6] << 8) | data[7];
return jhash_2words(a, b, initval);
}
static __u32 sk_index_from_dcid(const quic_hd *qhd,
const struct sk_reuseport_md *reuse_md,
__u32 num_socks) {
__u32 len = qhd->dcidlen;
__u32 h = reuse_md->hash;
__u8 hbuf[8];
if (len > 16) {
__builtin_memset(hbuf, 0, sizeof(hbuf));
switch (len) {
case 20:
__builtin_memcpy(hbuf, qhd->dcid + 16, 4);
break;
case 19:
__builtin_memcpy(hbuf, qhd->dcid + 16, 3);
break;
case 18:
__builtin_memcpy(hbuf, qhd->dcid + 16, 2);
break;
case 17:
__builtin_memcpy(hbuf, qhd->dcid + 16, 1);
break;
}
h = hash(hbuf, sizeof(hbuf), h);
len = 16;
}
if (len > 8) {
__builtin_memset(hbuf, 0, sizeof(hbuf));
switch (len) {
case 16:
__builtin_memcpy(hbuf, qhd->dcid + 8, 8);
break;
case 15:
__builtin_memcpy(hbuf, qhd->dcid + 8, 7);
break;
case 14:
__builtin_memcpy(hbuf, qhd->dcid + 8, 6);
break;
case 13:
__builtin_memcpy(hbuf, qhd->dcid + 8, 5);
break;
case 12:
__builtin_memcpy(hbuf, qhd->dcid + 8, 4);
break;
case 11:
__builtin_memcpy(hbuf, qhd->dcid + 8, 3);
break;
case 10:
__builtin_memcpy(hbuf, qhd->dcid + 8, 2);
break;
case 9:
__builtin_memcpy(hbuf, qhd->dcid + 8, 1);
break;
}
h = hash(hbuf, sizeof(hbuf), h);
len = 8;
}
return hash(qhd->dcid, len, h) % num_socks;
}
SEC("sk_reuseport")
int select_reuseport(struct sk_reuseport_md *reuse_md) {
__u32 sk_index, *psk_index;
__u8 sk_prefix[8];
__u32 *pnum_socks;
__u32 zero = 0;
int rv;
quic_hd qhd;
__u32 a, b;
__u8 pkt_databuf[6 + MAX_DCIDLEN];
__u8 qpktbuf[6 + MAX_DCIDLEN];
if (bpf_skb_load_bytes(reuse_md, sizeof(struct udphdr), pkt_databuf,
sizeof(pkt_databuf)) != 0) {
if (bpf_skb_load_bytes(reuse_md, sizeof(struct udphdr), qpktbuf,
sizeof(qpktbuf)) != 0) {
return SK_DROP;
}
rv = parse_quic(&qhd, pkt_databuf, pkt_databuf + sizeof(pkt_databuf));
if (rv != 0) {
return SK_DROP;
}
switch (qhd.type) {
case NGTCP2_PKT_INITIAL:
case NGTCP2_PKT_0RTT:
__builtin_memcpy(sk_prefix, pkt_databuf + qhd.dcid_offset, CID_PREFIXLEN);
if (qhd.dcidlen == SV_DCIDLEN) {
psk_index = bpf_map_lookup_elem(&cid_prefix_map, sk_prefix);
if (psk_index != NULL) {
sk_index = *psk_index;
break;
}
}
pnum_socks = bpf_map_lookup_elem(&sk_info, &zero);
if (pnum_socks == NULL) {
return SK_DROP;
}
a = (sk_prefix[0] << 24) | (sk_prefix[1] << 16) | (sk_prefix[2] << 8) |
sk_prefix[3];
b = (sk_prefix[4] << 24) | (sk_prefix[5] << 16) | (sk_prefix[6] << 8) |
sk_prefix[7];
rv = parse_quic(&qhd, qpktbuf, qpktbuf + sizeof(qpktbuf));
if (rv != 0) {
return SK_DROP;
}
sk_index = jhash_2words(a, b, reuse_md->hash) % *pnum_socks;
switch (qhd.type) {
case NGTCP2_PKT_INITIAL:
case NGTCP2_PKT_0RTT:
if (qhd.dcidlen == SV_DCIDLEN) {
psk_index = bpf_map_lookup_elem(&cid_prefix_map, qhd.dcid);
if (psk_index != NULL) {
sk_index = *psk_index;
break;
}
}
sk_index = sk_index_from_dcid(&qhd, reuse_md, *pnum_socks);
break;
case NGTCP2_PKT_HANDSHAKE:
@ -234,21 +301,9 @@ int select_reuseport(struct sk_reuseport_md *reuse_md) {
return SK_DROP;
}
__builtin_memcpy(sk_prefix, pkt_databuf + qhd.dcid_offset, CID_PREFIXLEN);
psk_index = bpf_map_lookup_elem(&cid_prefix_map, sk_prefix);
psk_index = bpf_map_lookup_elem(&cid_prefix_map, qhd.dcid);
if (psk_index == NULL) {
pnum_socks = bpf_map_lookup_elem(&sk_info, &zero);
if (pnum_socks == NULL) {
return SK_DROP;
}
a = (sk_prefix[0] << 24) | (sk_prefix[1] << 16) | (sk_prefix[2] << 8) |
sk_prefix[3];
b = (sk_prefix[4] << 24) | (sk_prefix[5] << 16) | (sk_prefix[6] << 8) |
sk_prefix[7];
sk_index = jhash_2words(a, b, reuse_md->hash) % *pnum_socks;
sk_index = sk_index_from_dcid(&qhd, reuse_md, *pnum_socks);
break;
}