[WIP] Experiment with eBPF
diff --git a/examples/Makefile.am b/examples/Makefile.am
index 13c2ccd..f5b78c5 100644
--- a/examples/Makefile.am
+++ b/examples/Makefile.am
@@ -33,14 +33,16 @@
 	@LIBNGHTTP3_CFLAGS@ \
 	@DEFS@
 AM_LDFLAGS = -no-install \
-	@LIBTOOL_LDFLAGS@
+	@LIBTOOL_LDFLAGS@ \
+	-pthread
 LDADD = $(top_builddir)/crypto/openssl/libngtcp2_crypto_openssl.la \
 	$(top_builddir)/lib/libngtcp2.la \
 	$(top_builddir)/third-party/libhttp-parser.la \
 	@JEMALLOC_LIBS@ \
 	@OPENSSL_LIBS@ \
 	@LIBEV_LIBS@ \
-	@LIBNGHTTP3_LIBS@
+	@LIBNGHTTP3_LIBS@ \
+	-lbpf
 
 noinst_PROGRAMS = client server h09client h09server
 
diff --git a/examples/server.cc b/examples/server.cc
index 6e8abd2..6767e2e 100644
--- a/examples/server.cc
+++ b/examples/server.cc
@@ -30,6 +30,8 @@
 #include <algorithm>
 #include <memory>
 #include <fstream>
+#include <future>
+#include <thread>
 
 #include <unistd.h>
 #include <getopt.h>
@@ -40,6 +42,14 @@
 #include <fcntl.h>
 #include <sys/mman.h>
 #include <netinet/udp.h>
+#include <signal.h>
+
+#include <linux/bpf.h>
+
+enum bpf_stats_type {}; // NOTE(review): presumably a shim so <bpf/bpf.h> compiles with an older <linux/bpf.h> - confirm
+
+#include <bpf/libbpf.h>
+#include <bpf/bpf.h>
 
 #include <openssl/bio.h>
 #include <openssl/err.h>
@@ -101,10 +111,78 @@
 Config config{};
 } // namespace
 
+namespace {
+int prog_fd;         // fd of the BPF program; filled in by bpf_prog_load() in main
+int reuseport_array; // fd of the "reuseport_array" BPF map; set in main
+} // namespace
+
 Buffer::Buffer(const uint8_t *data, size_t datalen)
     : buf{data, data + datalen}, begin(buf.data()), tail(begin + datalen) {}
 Buffer::Buffer(size_t datalen) : buf(datalen), begin(buf.data()), tail(begin) {}
 
+namespace {
+void sha256(uint8_t *dest, const uint8_t *kpad, const uint8_t *b, size_t blen) {
+  auto ctx = EVP_MD_CTX_new(); // computes dest = SHA-256(kpad[0..64) || b[0..blen))
+  assert(ctx);
+
+  auto ctx_deleter = defer(EVP_MD_CTX_free, ctx); // presumably frees ctx at scope exit (defer is defined elsewhere)
+
+  unsigned int mdlen = 32; // SHA-256 digest size; dest must have room for 32 bytes
+  if (!EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr) ||
+      !EVP_DigestUpdate(ctx, kpad, 64) || !EVP_DigestUpdate(ctx, b, blen) ||
+      !EVP_DigestFinal_ex(ctx, dest, &mdlen)) {
+    assert(0); // a digest failure only aborts when assertions are enabled
+  }
+}
+} // namespace
+
+namespace {
+constexpr size_t HMACLEN = 4; // number of HMAC bytes kept and appended to a Connection ID
+std::array<uint8_t, 64> kopad, kipad; // HMAC outer/inner key pads (secret XOR 0x5c / 0x36); filled in main
+} // namespace
+
+namespace {
+void hmac256_32(uint32_t *dest, const uint8_t *src, size_t srclen) {
+  uint8_t h[32]; // holds the inner digest, then is overwritten with the outer one
+
+  sha256(h, kipad.data(), src, srclen);  // inner hash: SHA-256(kipad || src)
+  sha256(h, kopad.data(), h, sizeof(h)); // outer hash: SHA-256(kopad || inner)
+
+  memcpy(dest, h, HMACLEN); // keep only the first HMACLEN bytes of the MAC
+}
+} // namespace
+
+namespace {
+void generate_authenticated_cid(uint8_t *dest, size_t destlen,
+                                uint32_t svindex) {
+  // - Connection ID is NGTCP2_SV_SCIDLEN bytes in total.
+  //
+  // - Connection ID is authenticated with HMAC-SHA256-32 (the
+  //   first HMACLEN bytes of HMAC-SHA256) to reliably embed
+  //   svindex in it.
+  //
+  // - The last 4 bytes are digest.  Thus 14 bytes are available.
+  //
+  // - For now, encode svindex in the first (most significant) byte.
+
+  assert(destlen == NGTCP2_SV_SCIDLEN);
+  assert(svindex < 256); // svindex must fit in the single byte written below
+
+  auto dis = std::uniform_int_distribution<uint8_t>(0, 255);
+  auto f = [&dis]() { return dis(randgen); };
+  auto len = NGTCP2_SV_SCIDLEN - HMACLEN; // bytes preceding the digest
+
+  std::generate_n(dest, len, f); // random fill; byte 0 is overwritten next
+  // TODO Encode svindex elsewhere other than first byte.
+  dest[0] = static_cast<uint8_t>(svindex);
+
+  uint32_t hmac;
+  hmac256_32(&hmac, dest, len); // MAC covers the first len bytes, svindex included
+
+  memcpy(dest + len, &hmac, sizeof(hmac)); // digest occupies the trailing bytes
+}
+} // namespace
+
 int Handler::on_key(ngtcp2_crypto_level level, const uint8_t *rx_secret,
                     const uint8_t *tx_secret, size_t secretlen) {
   std::array<uint8_t, 64> rx_key, rx_iv, rx_hp_key, tx_key, tx_iv, tx_hp_key;
@@ -1075,11 +1153,12 @@
 namespace {
 int get_new_connection_id(ngtcp2_conn *conn, ngtcp2_cid *cid, uint8_t *token,
                           size_t cidlen, void *user_data) {
-  auto dis = std::uniform_int_distribution<uint8_t>(0, 255);
-  auto f = [&dis]() { return dis(randgen); };
+  auto h = static_cast<Handler *>(user_data);
+  auto server = h->server();
 
-  std::generate_n(cid->data, cidlen, f);
   cid->datalen = cidlen;
+  generate_authenticated_cid(cid->data, cidlen, server->get_svindex());
+
   auto md = ngtcp2_crypto_md{const_cast<EVP_MD *>(EVP_sha256())};
   if (ngtcp2_crypto_generate_stateless_reset_token(
           token, &md, config.static_secret.data(), config.static_secret.size(),
@@ -1087,8 +1166,7 @@
     return NGTCP2_ERR_CALLBACK_FAILURE;
   }
 
-  auto h = static_cast<Handler *>(user_data);
-  h->server()->associate_cid(cid, h);
+  server->associate_cid(cid, h);
 
   return 0;
 }
@@ -1591,8 +1669,8 @@
   auto dis = std::uniform_int_distribution<uint8_t>(0, 255);
 
   scid_.datalen = NGTCP2_SV_SCIDLEN;
-  std::generate(scid_.data, scid_.data + scid_.datalen,
-                [&dis]() { return dis(randgen); });
+  generate_authenticated_cid(scid_.data, NGTCP2_SV_SCIDLEN,
+                             server_->get_svindex());
 
   ngtcp2_settings settings;
   ngtcp2_settings_default(&settings);
@@ -2183,14 +2261,32 @@
 } // namespace
 
 namespace {
+std::vector<std::unique_ptr<Server>> servers; // all Server instances created in main; iterated by siginthandler
+} // namespace
+
+namespace {
 void siginthandler(struct ev_loop *loop, ev_signal *watcher, int revents) {
+  std::cerr << "siginthandler" << std::endl;
+  for (auto &sv : servers) {
+    sv->request_stop(); // signals each server's own loop through its ev_async watcher
+  }
+
+  std::cerr << "Stopping default loop" << std::endl; // the loop passed in is the one this watcher runs on
 ev_break(loop, EVBREAK_ALL);
 }
 } // namespace
 
-Server::Server(struct ev_loop *loop, SSL_CTX *ssl_ctx)
-    : loop_(loop), ssl_ctx_(ssl_ctx) {
-  ev_signal_init(&sigintev_, siginthandler, SIGINT);
+namespace {
+void stopcb(struct ev_loop *loop, ev_async *w, int revents) {
+  auto sv = static_cast<Server *>(w->data); // data is set to the owning Server in its constructor
+  sv->stop(); // breaks that server's loop
+}
+} // namespace
+
+Server::Server(uint32_t svindex, struct ev_loop *loop, SSL_CTX *ssl_ctx)
+    : svindex_(svindex), loop_(loop), ssl_ctx_(ssl_ctx) {
+  ev_async_init(&stopev_, stopcb);
+  stopev_.data = this;
 
   token_aead_.native_handle = const_cast<EVP_CIPHER *>(EVP_aes_128_gcm());
   token_md_.native_handle = const_cast<EVP_MD *>(EVP_sha256());
@@ -2208,7 +2304,7 @@
     ev_io_stop(loop_, &ep.rev);
   }
 
-  ev_signal_stop(loop_, &sigintev_);
+  ev_async_stop(loop_, &stopev_);
 
   while (!handlers_.empty()) {
     auto it = std::begin(handlers_);
@@ -2228,9 +2324,15 @@
   endpoints_.clear();
 }
 
+void Server::start() { ev_run(loop_); } // runs this server's event loop on the calling thread
+
+void Server::stop() { ev_break(loop_, EVBREAK_ALL); } // breaks loop_; invoked from stopcb
+
+void Server::request_stop() { ev_async_send(loop_, &stopev_); } // triggers stopev_, whose callback (stopcb) calls stop()
+
 namespace {
 int create_sock(Address &local_addr, const char *addr, const char *port,
-                int family) {
+                int family, uint32_t svindex) {
   addrinfo hints{};
   addrinfo *res, *rp;
   int val = 1;
@@ -2272,9 +2374,31 @@
       continue;
     }
 
+    if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &val,
+                   static_cast<socklen_t>(sizeof(val))) == -1) {
+      close(fd);
+      continue;
+    }
+
+    if (svindex == 0 &&
+        setsockopt(fd, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &prog_fd,
+                   static_cast<socklen_t>(sizeof(prog_fd))) == -1) {
+      std::cerr << "Unable to attach bpf prog: " << strerror(errno)
+                << std::endl;
+      close(fd);
+      continue;
+    }
+
     fd_set_recv_ecn(fd, rp->ai_family);
 
     if (bind(fd, rp->ai_addr, rp->ai_addrlen) != -1) {
+      if (bpf_map_update_elem(reuseport_array, &svindex, &fd, BPF_NOEXIST) !=
+          0) {
+        std::cerr << "bpf_map_update_elem: " << strerror(errno) << std::endl;
+        close(fd);
+        continue;
+      }
+
       break;
     }
 
@@ -2301,9 +2425,9 @@
 
 namespace {
 int add_endpoint(std::vector<Endpoint> &endpoints, const char *addr,
-                 const char *port, int af) {
+                 const char *port, int af, uint32_t svindex) {
   Address dest;
-  auto fd = create_sock(dest, addr, port, af);
+  auto fd = create_sock(dest, addr, port, af, svindex);
   if (fd == -1) {
     return -1;
   }
@@ -2364,11 +2488,11 @@
 
   auto ready = false;
   if (!util::numeric_host(addr, AF_INET6) &&
-      add_endpoint(endpoints_, addr, port, AF_INET) == 0) {
+      add_endpoint(endpoints_, addr, port, AF_INET, svindex_) == 0) {
     ready = true;
   }
   if (!util::numeric_host(addr, AF_INET) &&
-      add_endpoint(endpoints_, addr, port, AF_INET6) == 0) {
+      add_endpoint(endpoints_, addr, port, AF_INET6, svindex_) == 0) {
     ready = true;
   }
   if (!ready) {
@@ -2393,7 +2517,7 @@
     ev_io_start(loop_, &ep.rev);
   }
 
-  ev_signal_start(loop_, &sigintev_);
+  ev_async_start(loop_, &stopev_);
 
   return 0;
 }
@@ -3265,6 +3389,8 @@
   handlers_.erase(util::make_cid_key(h->scid()));
 }
 
+uint32_t Server::get_svindex() const { return svindex_; } // index given at construction; embedded in generated Connection IDs
+
 namespace {
 int alpn_select_proto_cb(SSL *ssl, const unsigned char **out,
                          unsigned char *outlen, const unsigned char *in,
@@ -3986,15 +4112,80 @@
     exit(EXIT_FAILURE);
   }
 
-  Server s(EV_DEFAULT, ssl_ctx);
-  if (s.init(addr, port) != 0) {
+  bpf_object *obj;
+  if (bpf_prog_load("reuseport_kern.o", BPF_PROG_TYPE_SK_REUSEPORT, &obj,
+                    &prog_fd) != 0) {
+    std::cerr << "bpf_prog_load: " << strerror(errno) << std::endl;
     exit(EXIT_FAILURE);
   }
 
+  auto map = bpf_object__find_map_by_name(obj, "reuseport_array");
+  if (!map) {
+    std::cerr << "failed to find reuseport_array" << std::endl;
+    exit(EXIT_FAILURE);
+  }
+
+  reuseport_array = bpf_map__fd(map);
+
+  // TODO Infer number of threads to bpf prog.
+  // TODO Embed svindex in Connection ID.
+  // TODO Authenticate Connection ID with HMAC-SHA256-32
+
+  sigset_t set;
+  sigemptyset(&set);
+  sigaddset(&set, SIGINT);
+  if (auto err = pthread_sigmask(SIG_BLOCK, &set, nullptr); err) {
+    std::cerr << "pthread_sigmask: " << strerror(err) << std::endl;
+    exit(EXIT_FAILURE);
+  }
+
+  std::array<uint8_t, 32> secret;
+  util::generate_secret(secret.data(), secret.size());
+
+  for (size_t i = 0; i < secret.size(); ++i) {
+    kipad[i] = secret[i] ^ 0x36;
+    kopad[i] = secret[i] ^ 0x5c;
+  }
+  std::fill(std::begin(kipad) + secret.size(), std::end(kipad), 0x36);
+  std::fill(std::begin(kopad) + secret.size(), std::end(kopad), 0x5c);
+
+  constexpr size_t nthreads = 4;
+
+  for (size_t i = 0; i < nthreads; ++i) {
+    auto loop = ev_loop_new(0);
+    auto sv = std::make_unique<Server>(i, loop, ssl_ctx);
+    if (sv->init(addr, port) != 0) {
+      exit(EXIT_FAILURE);
+    }
+    servers.emplace_back(std::move(sv));
+  }
+
+  std::vector<std::future<void>> futures;
+
+  for (auto &sv : servers) {
+    futures.emplace_back(std::async(std::launch::async, [&sv]() {
+      sv->start();
+      sv->disconnect();
+      sv->close();
+    }));
+  }
+
+  if (auto err = pthread_sigmask(SIG_UNBLOCK, &set, nullptr); err) {
+    std::cerr << "pthread_sigmask: " << strerror(err) << std::endl;
+    exit(EXIT_FAILURE);
+  }
+
+  ev_signal sigintev;
+  ev_signal_init(&sigintev, siginthandler, SIGINT);
+  ev_signal_start(EV_DEFAULT, &sigintev);
+
+  std::cerr << "Running main thread" << std::endl;
+
   ev_run(EV_DEFAULT, 0);
 
-  s.disconnect();
-  s.close();
+  for (auto &f : futures) {
+    f.wait();
+  }
 
   return EXIT_SUCCESS;
 }
diff --git a/examples/server.h b/examples/server.h
index 89d87ea..410f4d2 100644
--- a/examples/server.h
+++ b/examples/server.h
@@ -341,12 +341,15 @@
 
 class Server {
 public:
-  Server(struct ev_loop *loop, SSL_CTX *ssl_ctx);
+  Server(uint32_t svindex, struct ev_loop *loop, SSL_CTX *ssl_ctx);
   ~Server();
 
   int init(const char *addr, const char *port);
   void disconnect();
   void close();
+  void start();
+  void stop();
+  void request_stop();
 
   int on_read(Endpoint &ep);
   int send_version_negotiation(uint32_t version, const uint8_t *dcid,
@@ -374,8 +377,10 @@
   void generate_rand_data(uint8_t *buf, size_t len);
   void associate_cid(const ngtcp2_cid *cid, Handler *h);
   void dissociate_cid(const ngtcp2_cid *cid);
+  uint32_t get_svindex() const;
 
 private:
+  uint32_t svindex_;
   std::unordered_map<std::string, std::unique_ptr<Handler>> handlers_;
   // ctos_ is a mapping between client's initial destination
   // connection ID, and server source connection ID.
@@ -385,7 +390,7 @@
   SSL_CTX *ssl_ctx_;
   ngtcp2_crypto_aead token_aead_;
   ngtcp2_crypto_md token_md_;
-  ev_signal sigintev_;
+  ev_async stopev_;
 };
 
 #endif // SERVER_H
diff --git a/reuseport_kern.c b/reuseport_kern.c
new file mode 100644
index 0000000..84036df
--- /dev/null
+++ b/reuseport_kern.c
@@ -0,0 +1,236 @@
+#include <stdlib.h>
+#include <linux/in.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+#include <linux/tcp.h>
+#include <linux/udp.h>
+#include <linux/bpf.h>
+#include <linux/types.h>
+#include <linux/if_ether.h>
+
+#include <bpf/bpf_endian.h>
+#include <bpf/bpf_helpers.h>
+
+/*
+ * How to compile:
+ *
+ * clang-10 -O2 -Wall -target bpf -g -c reuseport_kern.c -o reuseport_kern.o \
+ *   -I/path/to/kernel/include
+ *
+ * See
+ * https://www.kernel.org/doc/Documentation/kbuild/headers_install.txt
+ * how to install kernel header files.
+ */
+
+/* rol32: From linux kernel source code */
+
+/**
+ * rol32 - rotate a 32-bit value left
+ * @word: value to rotate
+ * @shift: bits to roll
+ */
+static inline __u32 rol32(__u32 word, unsigned int shift) {
+  return (word << shift) | (word >> ((-shift) & 31)); /* & 31 keeps the right-shift count within 0..31 */
+}
+
+/* jhash.h: Jenkins hash support.
+ *
+ * Copyright (C) 2006. Bob Jenkins (bob_jenkins@burtleburtle.net)
+ *
+ * https://burtleburtle.net/bob/hash/
+ *
+ * These are the credits from Bob's sources:
+ *
+ * lookup3.c, by Bob Jenkins, May 2006, Public Domain.
+ *
+ * These are functions for producing 32-bit hashes for hash table lookup.
+ * hashword(), hashlittle(), hashlittle2(), hashbig(), mix(), and final()
+ * are externally useful functions.  Routines to test the hash are included
+ * if SELF_TEST is defined.  You can use this free for any purpose.  It's in
+ * the public domain.  It has no warranty.
+ *
+ * Copyright (C) 2009-2010 Jozsef Kadlecsik (kadlec@blackhole.kfki.hu)
+ *
+ * I've modified Bob's hash to be useful in the Linux kernel, and
+ * any bugs present are my fault.
+ * Jozsef
+ */
+
+/* __jhash_final - final mixing of 3 32-bit values (a,b,c) into c; a, b and c are all modified in place */
+#define __jhash_final(a, b, c)                                                 \
+  {                                                                            \
+    c ^= b;                                                                    \
+    c -= rol32(b, 14);                                                         \
+    a ^= c;                                                                    \
+    a -= rol32(c, 11);                                                         \
+    b ^= a;                                                                    \
+    b -= rol32(a, 25);                                                         \
+    c ^= b;                                                                    \
+    c -= rol32(b, 16);                                                         \
+    a ^= c;                                                                    \
+    a -= rol32(c, 4);                                                          \
+    b ^= a;                                                                    \
+    b -= rol32(a, 14);                                                         \
+    c ^= b;                                                                    \
+    c -= rol32(b, 24);                                                         \
+  }
+
+/* __jhash_nwords - hash exactly 3, 2 or 1 word(s) */
+static inline __u32 __jhash_nwords(__u32 a, __u32 b, __u32 c, __u32 initval) {
+  a += initval; /* the same initval is added to each of the three words */
+  b += initval;
+  c += initval;
+
+  __jhash_final(a, b, c); /* macro: mixes a, b and c in place */
+
+  return c; /* c holds the resulting hash */
+}
+
+/* An arbitrary initial parameter */
+#define JHASH_INITVAL 0xdeadbeef
+
+static inline __u32 jhash_2words(__u32 a, __u32 b, __u32 initval) {
+  return __jhash_nwords(a, b, 0, initval + JHASH_INITVAL + (2 << 2)); /* third word is 0; 2 << 2 == 8, presumably the byte length of two words */
+}
+
+struct { /* socket array consulted by bpf_sk_select_reuseport(); key = server index, value = socket fd */
+  __uint(type, BPF_MAP_TYPE_REUSEPORT_SOCKARRAY);
+  __uint(max_entries, 256); /* server index is carried in one Connection ID byte, so 0..255 must fit */
+  __uint(key_size, sizeof(__u32));
+  __uint(value_size, sizeof(__u32));
+} reuseport_array SEC(".maps");
+
+typedef struct vec { /* NOTE(review): not referenced anywhere in the visible part of this file */
+  __u8 *data;     /* start of the byte range */
+  __u8 *data_end; /* presumably one past the last byte - confirm */
+} vec;
+
+typedef struct quic_hd { /* result of parse_quic() */
+  __u8 *dcid;        /* pointer to the first Destination Connection ID byte in the packet */
+  __u32 dcid_offset; /* DCID offset from reuse_md->data (UDP header included) */
+  __u32 dcid_len;    /* DCID length in bytes */
+  __u8 type;         /* (first byte & 0x30) >> 4 for long headers; 0xff for short headers */
+} quic_hd;
+
+#define SV_DCIDLEN 18 /* DCID length assumed for short-header packets; also required for Handshake packets */
+#define MAX_DCIDLEN 20 /* largest DCID length accepted in a long header */
+#define MIN_DCIDLEN 8 /* smallest DCID length accepted in a long header */
+
+static inline int parse_quic(quic_hd *qhd, struct sk_reuseport_md *reuse_md) { /* fills *qhd from the datagram; returns 0, or -1 if it is truncated or the DCID length is out of range */
+  __u64 len = sizeof(struct udphdr) + 1; /* UDP header + first QUIC byte */
+  __u8 *p;
+  __u64 dcidlen;
+
+  if (reuse_md->data + len > reuse_md->data_end) {
+    return -1; /* too short to contain the first QUIC byte */
+  }
+
+  p = reuse_md->data + sizeof(struct udphdr); /* first QUIC byte */
+
+  if (*p & 0x80) { /* high bit set: long header */
+    len += 4 + 1; /* 4 skipped bytes (presumably the version) + DCID length byte */
+    if (reuse_md->data + len > reuse_md->data_end) {
+      return -1;
+    }
+
+    p += 1 + 4; /* now at the DCID length byte */
+
+    dcidlen = *p;
+
+    if (dcidlen > MAX_DCIDLEN || dcidlen < MIN_DCIDLEN) {
+      return -1;
+    }
+
+    len += 1 + dcidlen; /* NOTE(review): this requires one byte beyond the DCID to be present - confirm intent */
+
+    if (reuse_md->data + len > reuse_md->data_end) {
+      return -1;
+    }
+
+    ++p; /* first DCID byte */
+
+    qhd->type =
+        (*((__u8 *)(reuse_md->data) + sizeof(struct udphdr)) & 0x30) >> 4;
+    qhd->dcid = p;
+    qhd->dcid_offset = sizeof(struct udphdr) + 6; /* first byte (1) + 4 + length byte (1) */
+    qhd->dcid_len = dcidlen;
+  } else { /* short header: the DCID is taken to be SV_DCIDLEN bytes right after the first byte */
+    len += SV_DCIDLEN;
+    if (reuse_md->data + len > reuse_md->data_end) {
+      return -1;
+    }
+
+    qhd->type = 0xff; /* marker value used for short-header packets */
+    qhd->dcid = (__u8 *)reuse_md->data + sizeof(struct udphdr) + 1;
+    qhd->dcid_offset = sizeof(struct udphdr) + 1;
+    qhd->dcid_len = SV_DCIDLEN;
+  }
+
+  return 0;
+}
+
+#define NUM_SOCKETS 4 /* modulus for the socket index; NOTE(review): presumably must equal nthreads in examples/server.cc - confirm */
+
+SEC("sk_reuseport") /* picks the reuseport socket for each incoming datagram, based on its QUIC header */
+int _select_by_skb_data(struct sk_reuseport_md *reuse_md) { /* returns SK_PASS after selecting a socket, SK_DROP otherwise */
+  __u32 sk_index; /* key into reuseport_array */
+  int rv;
+  quic_hd qhd;
+  __u32 a, b; /* first 8 DCID bytes as two big-endian words */
+  __u8 *p;
+
+  rv = parse_quic(&qhd, reuse_md);
+  if (rv != 0) {
+    return SK_DROP; /* truncated packet or unacceptable DCID length */
+  }
+
+  switch (qhd.type) {
+  case 0x0: /* Initial */
+  case 0x1: /* 0-RTT */
+    if (reuse_md->data + sizeof(struct udphdr) + 6 + 8 > reuse_md->data_end) {
+      return SK_DROP; /* explicit bounds check before reading 8 DCID bytes */
+    }
+
+    p = (__u8 *)reuse_md->data + sizeof(struct udphdr) + 6; /* first DCID byte */
+    a = ((__u32)p[0] << 24) | (p[1] << 16) | (p[2] << 8) | p[3]; /* cast: shifting a promoted int into the sign bit is undefined */
+    b = ((__u32)p[4] << 24) | (p[5] << 16) | (p[6] << 8) | p[7];
+
+    sk_index = jhash_2words(a, b, reuse_md->hash) % NUM_SOCKETS; /* DCID is not server-chosen here, so spread by hash */
+
+    break;
+  case 0x2: /* Handshake */
+    if (qhd.dcid_len != SV_DCIDLEN) {
+      return SK_DROP;
+    }
+
+    if (reuse_md->data + sizeof(struct udphdr) + 6 + 1 > reuse_md->data_end) {
+      return SK_DROP;
+    }
+
+    sk_index =
+        *((__u8 *)reuse_md->data + sizeof(struct udphdr) + 6) % NUM_SOCKETS; /* first DCID byte, where server.cc stores svindex */
+
+    break;
+  case 0xff: /* Short */
+    if (qhd.dcid_len != SV_DCIDLEN) {
+      return SK_DROP;
+    }
+
+    if (reuse_md->data + sizeof(struct udphdr) + 1 + 1 > reuse_md->data_end) {
+      return SK_DROP;
+    }
+
+    sk_index =
+        *((__u8 *)reuse_md->data + sizeof(struct udphdr) + 1) % NUM_SOCKETS; /* first DCID byte follows the single header byte */
+    break;
+  default: /* any other long-header type is dropped */
+    return SK_DROP;
+  }
+
+  rv = bpf_sk_select_reuseport(reuse_md, &reuseport_array, &sk_index, 0); /* select the socket stored at sk_index */
+  if (rv != 0) {
+    return SK_DROP; /* e.g. no socket registered at that index */
+  }
+
+  return SK_PASS;
+}