go/net/tinylb: init

This implements tinylb, a tiny round-robin load balancer for
net.Conn/net.Listener protocols.

This will be used to loadbalance connections to Kubernetes apiservers
before cluster networking is available.

Change-Id: I48892e1fe03e0648df60c674e7394ca69b32932d
Reviewed-on: https://review.monogon.dev/c/monogon/+/1369
Tested-by: Jenkins CI
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
diff --git a/go/net/tinylb/BUILD.bazel b/go/net/tinylb/BUILD.bazel
new file mode 100644
index 0000000..a238bd3
--- /dev/null
+++ b/go/net/tinylb/BUILD.bazel
@@ -0,0 +1,26 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+    name = "tinylb",
+    srcs = [
+        "connection_pool.go",
+        "tinylb.go",
+    ],
+    importpath = "source.monogon.dev/go/net/tinylb",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//go/types/mapsets",
+        "//metropolis/pkg/event",
+        "//metropolis/pkg/supervisor",
+    ],
+)
+
+go_test(
+    name = "tinylb_test",
+    srcs = ["tinylb_test.go"],
+    embed = [":tinylb"],
+    deps = [
+        "//metropolis/pkg/event/memory",
+        "//metropolis/pkg/supervisor",
+    ],
+)
diff --git a/go/net/tinylb/connection_pool.go b/go/net/tinylb/connection_pool.go
new file mode 100644
index 0000000..956bf75
--- /dev/null
+++ b/go/net/tinylb/connection_pool.go
@@ -0,0 +1,121 @@
+package tinylb
+
+import (
+	"net"
+	"sort"
+	"sync"
+)
+
+// connectionPool maintains information about open connections to backends, and
+// allows for closing either arbitrary connections (by ID) or all connections to
+// a given backend.
+//
+// This structure exists to allow tinylb to kill all connections of a backend
+// that has just been removed from the BackendSet.
+//
+// Any time a connection is inserted into the pool, a unique ID for that
+// connection is returned.
+//
+// Backends are identified by a 'target name', which is an opaque string.
+//
+// This structure is likely the performance bottleneck of the implementation, as
+// it takes a non-RW lock for every incoming connection.
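+//
+// A quick usage sketch (the target name shown is arbitrary, and conn is assumed
+// to be some accepted net.Conn):
+//
+//	pool := newConnectionPool()
+//	id := pool.Insert("10.0.0.1:6443", conn)
+//	pool.CloseConn(id)                // close a single connection by ID
+//	pool.CloseTarget("10.0.0.1:6443") // close all connections to a target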
+type connectionPool struct {
+	mu sync.Mutex
+	// detailsById maps connection IDs to details about each connection.
+	detailsById map[int64]*connectionDetails
+	// idsByTarget maps a target name to the IDs of all connections opened to it.
+	idsByTarget map[string][]int64
+
+	// cid is the connection ID counter, incremented every time a connection ID
+	// is allocated.
+	cid int64
+}
+
+// connectionDetails describes a single open connection. These are held in
+// connectionPool.detailsById.
+type connectionDetails struct {
+	// conn is the active net.Conn backing this connection.
+	conn net.Conn
+	// target is the target name to which this connection was initiated.
+	target string
+}
+
+func newConnectionPool() *connectionPool {
+	return &connectionPool{
+		detailsById: make(map[int64]*connectionDetails),
+		idsByTarget: make(map[string][]int64),
+	}
+}
+
+// Insert a connection opened to the given backend target name, and return the
+// connection ID that can later be used to close it.
+func (c *connectionPool) Insert(target string, conn net.Conn) int64 {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	id := c.cid
+	c.cid++
+
+	c.detailsById[id] = &connectionDetails{
+		conn:   conn,
+		target: target,
+	}
+	c.idsByTarget[target] = append(c.idsByTarget[target], id)
+	return id
+}
+
+// CloseConn closes the underlying connection for the given connection ID, and
+// removes that connection ID from internal tracking.
+func (c *connectionPool) CloseConn(id int64) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	cd, ok := c.detailsById[id]
+	if !ok {
+		return
+	}
+
+	ids := c.idsByTarget[cd.target]
+	// ids is technically sorted because 'id' is always monotonically increasing, so
+	// we could be smarter and do a binary search here.
+	ix := -1
+	for i, id2 := range ids {
+		if id2 == id {
+			ix = i
+			break
+		}
+	}
+	if ix == -1 {
+		panic("Programming error: connection present in detailsById but not in idsByTarget")
+	}
+	c.idsByTarget[cd.target] = append(ids[:ix], ids[ix+1:]...)
+	cd.conn.Close()
+	delete(c.detailsById, id)
+}
+
+// CloseTarget closes all connections to a given backend target name, and removes
+// them from internal tracking.
+func (c *connectionPool) CloseTarget(target string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	for _, id := range c.idsByTarget[target] {
+		c.detailsById[id].conn.Close()
+		delete(c.detailsById, id)
+	}
+	delete(c.idsByTarget, target)
+}
+
+// Targets returns the names of all currently active backend targets, sorted
+// alphabetically.
+func (c *connectionPool) Targets() []string {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	res := make([]string, 0, len(c.idsByTarget))
+	for target := range c.idsByTarget {
+		res = append(res, target)
+	}
+	sort.Strings(res)
+	return res
+}
diff --git a/go/net/tinylb/tinylb.go b/go/net/tinylb/tinylb.go
new file mode 100644
index 0000000..67a639e
--- /dev/null
+++ b/go/net/tinylb/tinylb.go
@@ -0,0 +1,189 @@
+// Package tinylb implements a small and simple userland round-robin load
+// balancer, mostly for TCP connections. However, it is entirely
+// protocol-agnostic, and only expects net.Listener and net.Conn objects.
+//
+// Apart from the simple act of load-balancing across a set of backends, tinylb
+// also automatically and immediately closes all open connections to backend
+// targets that have been removed from the set. This is perhaps not the ideal
+// behaviour for user-facing services, but it's the sort of behaviour that works
+// very well for machine-to-machine communication.
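+//
+// A rough usage sketch follows. The memory.Value used as the Provider here is
+// just one possible event.Value implementation, and the keys and addresses are
+// purely illustrative:
+//
+//	var set tinylb.BackendSet
+//	set.Insert("node-a", &tinylb.SimpleTCPBackend{Remote: "10.0.0.1:6443"})
+//
+//	var val memory.Value[tinylb.BackendSet]
+//	val.Set(set.Clone())
+//
+//	ln, err := net.Listen("tcp", ":443")
+//	if err != nil {
+//		// handle error
+//	}
+//	srv := tinylb.Server{
+//		Provider: &val,
+//		Listener: ln,
+//	}
+//	// srv.Run is then started as a supervisor runnable; see Server.Run.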
+package tinylb
+
+import (
+	"context"
+	"io"
+	"net"
+	"sync"
+
+	"source.monogon.dev/go/types/mapsets"
+	"source.monogon.dev/metropolis/pkg/event"
+	"source.monogon.dev/metropolis/pkg/supervisor"
+)
+
+// Backend is to be implemented by different kinds of loadbalancing backends, e.g.
+// one per network protocol.
+type Backend interface {
+	// TargetName returns the 'target name' of the backend, which is _not_ the same
+	// as its key in the BackendSet. Instead, the TargetName should uniquely identify
+	// some backend address, and will be used to figure out that while a backend
+	// might still exist, its target address has changed - and thus, all existing
+	// connections to the previous target address should be closed.
+	//
+	// For simple load balancing backends, this could be the connection string used
+	// to connect to the backend.
+	TargetName() string
+	// Dial returns a new connection to this backend.
+	Dial() (net.Conn, error)
+}
+
+// BackendSet is the main structure used to provide the current set of backends
+// that should be targeted by tinylb. The key is some unique backend identifier.
+type BackendSet = mapsets.OrderedMap[string, Backend]
+
+// SimpleTCPBackend implements Backend for trivial TCP-based backends.
+type SimpleTCPBackend struct {
+	Remote string
+}
+
+func (t *SimpleTCPBackend) TargetName() string {
+	return t.Remote
+}
+
+func (t *SimpleTCPBackend) Dial() (net.Conn, error) {
+	return net.Dial("tcp", t.Remote)
+}
+
+// Server is a tiny round-robin loadbalancer for net.Listener/net.Conn compatible
+// protocols.
+//
+// All fields must be set before the loadbalancer can be run.
+type Server struct {
+	// Provider is some event Value which provides the current BackendSet for the
+	// loadbalancer to use. As the BackendSet is updated, the internal loadbalancing
+	// algorithm will adjust to the updated set, and any connections to backend
+	// TargetNames that are not present in the set anymore will be closed.
+	Provider event.Value[BackendSet]
+	// Listener is what the loadbalancer will listen on. After the loadbalancer
+	// exits, this listener will be closed.
+	Listener net.Listener
+}
+
+// Run the loadbalancer as a supervisor.Runnable, blocking until canceled.
+// Because the Server's Listener will be closed after the loadbalancer exits, it
+// should be opened in the same runnable in which this function is then started.
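+//
+// A minimal sketch of that pattern, assuming some event.Value[BackendSet] named
+// provider is already available (the runnable name and listen address are
+// illustrative):
+//
+//	if err := supervisor.Run(ctx, "lb", func(ctx context.Context) error {
+//		ln, err := net.Listen("tcp", ":443")
+//		if err != nil {
+//			return err
+//		}
+//		srv := tinylb.Server{Provider: provider, Listener: ln}
+//		return srv.Run(ctx)
+//	}); err != nil {
+//		// handle error
+//	}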
+func (s *Server) Run(ctx context.Context) error {
+	// Connection pool used to track connections/backends.
+	pool := newConnectionPool()
+
+	// Current backend set and its lock.
+	var curSetMu sync.RWMutex
+	var curSet BackendSet
+
+	// Close listener on exit.
+	go func() {
+		<-ctx.Done()
+		s.Listener.Close()
+	}()
+
+	// The acceptor runs the main Accept() loop on the given Listener.
+	err := supervisor.Run(ctx, "acceptor", func(ctx context.Context) error {
+		// This doesn't need a lock, as it doesn't read any fields of curSet.
+		it := curSet.Cycle()
+
+		supervisor.Signal(ctx, supervisor.SignalHealthy)
+
+		for {
+			if ctx.Err() != nil {
+				return ctx.Err()
+			}
+			conn, err := s.Listener.Accept()
+			if err != nil {
+				return err
+			}
+
+			// Get next backend.
+			curSetMu.RLock()
+			id, backend, ok := it.Next()
+			curSetMu.RUnlock()
+
+			if !ok {
+				supervisor.Logger(ctx).Warningf("Balancing %s: failed due to empty backend set", conn.RemoteAddr().String())
+				conn.Close()
+				continue
+			}
+			// Dial backend.
+			r, err := backend.Dial()
+			if err != nil {
+				supervisor.Logger(ctx).Warningf("Balancing %s: failed due to backend %s error: %v", conn.RemoteAddr(), id, err)
+				conn.Close()
+				continue
+			}
+			// Insert connection/target name into connectionPool.
+			target := backend.TargetName()
+			cid := pool.Insert(target, r)
+
+			// Pipe data in both directions. Close both connections when either side fails.
+			go func() {
+				defer conn.Close()
+				defer pool.CloseConn(cid)
+				io.Copy(r, conn)
+			}()
+			go func() {
+				defer conn.Close()
+				defer pool.CloseConn(cid)
+				io.Copy(conn, r)
+			}()
+		}
+	})
+	if err != nil {
+		return err
+	}
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+
+	// Update curSet from Provider.
+	w := s.Provider.Watch()
+	defer w.Close()
+	for {
+		set, err := w.Get(ctx)
+		if err != nil {
+			return err
+		}
+
+		// Did we lose a backend? If so, kill all connections going through it.
+
+		// First, gather a map from TargetName to backend ID for the current set.
+		curTargets := make(map[string]string)
+		curSetMu.Lock()
+		for _, kv := range curSet.Values() {
+			curTargets[kv.Value.TargetName()] = kv.Key
+		}
+		curSetMu.Unlock()
+
+		// Then, gather it for the new set.
+		targets := make(map[string]string)
+		for _, kv := range set.Values() {
+			targets[kv.Value.TargetName()] = kv.Key
+		}
+
+		// Then, if we have any target name in the connection pool that's not in the new
+		// set, close all of its connections.
+		for _, target := range pool.Targets() {
+			if _, ok := targets[target]; ok {
+				continue
+			}
+			// Use curTargets just for displaying the name of the backend that has changed.
+			supervisor.Logger(ctx).Infof("Backend %s / target %s removed, killing connections", curTargets[target], target)
+			pool.CloseTarget(target)
+		}
+
+		// Log the new backend set and replace the current one with it.
+		supervisor.Logger(ctx).Infof("New backend set (%d backends):", len(set.Keys()))
+		for _, kv := range set.Values() {
+			supervisor.Logger(ctx).Infof(" - %s, target %s", kv.Key, kv.Value.TargetName())
+		}
+		curSetMu.Lock()
+		curSet.Replace(&set)
+		curSetMu.Unlock()
+	}
+}
diff --git a/go/net/tinylb/tinylb_test.go b/go/net/tinylb/tinylb_test.go
new file mode 100644
index 0000000..acf2dda
--- /dev/null
+++ b/go/net/tinylb/tinylb_test.go
@@ -0,0 +1,230 @@
+package tinylb
+
+import (
+	"bufio"
+	"fmt"
+	"io"
+	"net"
+	"strings"
+	"testing"
+	"time"
+
+	"source.monogon.dev/metropolis/pkg/event/memory"
+	"source.monogon.dev/metropolis/pkg/supervisor"
+)
+
+func TestLoadbalancer(t *testing.T) {
+	v := memory.Value[BackendSet]{}
+	set := BackendSet{}
+	v.Set(set.Clone())
+
+	ln, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+	s := Server{
+		Provider: &v,
+		Listener: ln,
+	}
+	supervisor.TestHarness(t, s.Run)
+
+	connect := func() net.Conn {
+		conn, err := net.Dial("tcp", ln.Addr().String())
+		if err != nil {
+			t.Fatalf("Connection failed: %v", err)
+		}
+		return conn
+	}
+
+	c := connect()
+	buf := make([]byte, 128)
+	if _, err := c.Read(buf); err == nil {
+		t.Fatalf("Expected error on read (no backends yet)")
+	}
+
+	// Now add a backend and expect it to be served.
+	makeBackend := func(hello string) net.Listener {
+		aln, err := net.Listen("tcp", ":0")
+		if err != nil {
+			t.Fatalf("Failed to make backend listener: %v", err)
+		}
+		// Start backend.
+		go func() {
+			for {
+				c, err := aln.Accept()
+				if err != nil {
+					return
+				}
+				// For each connection, keep writing 'hello' over and over, newline-separated.
+				go func() {
+					defer c.Close()
+					for {
+						if _, err := fmt.Fprintf(c, "%s\n", hello); err != nil {
+							return
+						}
+						time.Sleep(100 * time.Millisecond)
+					}
+				}()
+			}
+		}()
+		addr := aln.Addr().(*net.TCPAddr)
+		set.Insert(hello, &SimpleTCPBackend{Remote: addr.AddrPort().String()})
+		v.Set(set.Clone())
+		return aln
+	}
+
+	as1 := makeBackend("a")
+	defer as1.Close()
+
+	for {
+		c = connect()
+		_, err := c.Read(buf)
+		c.Close()
+		if err == nil {
+			break
+		}
+	}
+
+	measure := func() map[string]int {
+		res := make(map[string]int)
+		for {
+			count := 0
+			for _, v := range res {
+				count += v
+			}
+			if count >= 20 {
+				return res
+			}
+
+			c := connect()
+			b := bufio.NewScanner(c)
+			if !b.Scan() {
+				err := b.Err()
+				if err == nil {
+					err = io.EOF
+				}
+				t.Fatalf("Scan failed: %v", err)
+			}
+			v := b.Text()
+			res[v]++
+			c.Close()
+		}
+	}
+
+	m := measure()
+	if m["a"] < 20 {
+		t.Errorf("Expected only one backend, got: %v", m)
+	}
+
+	as2 := makeBackend("b")
+	defer as2.Close()
+
+	as3 := makeBackend("c")
+	defer as3.Close()
+
+	as4 := makeBackend("d")
+	defer as4.Close()
+
+	m = measure()
+	for _, id := range []string{"a", "b", "c", "d"} {
+		if want, got := 4, m[id]; got < want {
+			t.Errorf("Expected at least %d responses from %s, got %d", want, id, got)
+		}
+	}
+
+	// Test killing backend connections on backend removal.
+	// Open a bunch of connections to 'a'.
+	var conns []*bufio.Scanner
+	for len(conns) < 5 {
+		c := connect()
+		b := bufio.NewScanner(c)
+		b.Scan()
+		if b.Text() != "a" {
+			c.Close()
+		} else {
+			conns = append(conns, b)
+		}
+	}
+
+	// Now remove the 'a' backend.
+	set.Delete("a")
+	v.Set(set.Clone())
+	// All open connections should now get killed.
+	for _, b := range conns {
+		deadline := time.Now().Add(time.Second)
+		for b.Scan() {
+			if time.Now().After(deadline) {
+				t.Errorf("Connection still alive")
+				break
+			}
+		}
+	}
+}
+
+func BenchmarkLB(b *testing.B) {
+	v := memory.Value[BackendSet]{}
+	set := BackendSet{}
+	v.Set(set.Clone())
+
+	ln, err := net.Listen("tcp", ":0")
+	if err != nil {
+		b.Fatalf("Listen failed: %v", err)
+	}
+	s := Server{
+		Provider: &v,
+		Listener: ln,
+	}
+	supervisor.TestHarness(b, s.Run)
+
+	makeBackend := func(hello string) net.Listener {
+		aln, err := net.Listen("tcp", ":0")
+		if err != nil {
+			b.Fatalf("Failed to make backend listener: %v", err)
+		}
+		// Start backend.
+		go func() {
+			for {
+				c, err := aln.Accept()
+				if err != nil {
+					return
+				}
+				go func() {
+					fmt.Fprintf(c, "%s\n", hello)
+					c.Close()
+				}()
+			}
+		}()
+		addr := aln.Addr().(*net.TCPAddr)
+		set.Insert(hello, &SimpleTCPBackend{Remote: addr.AddrPort().String()})
+		v.Set(set.Clone())
+		return aln
+	}
+	var backends []net.Listener
+	for i := 0; i < 10; i++ {
+		backends = append(backends, makeBackend(fmt.Sprintf("backend%d", i)))
+	}
+
+	defer func() {
+		for _, b := range backends {
+			b.Close()
+		}
+	}()
+
+	b.ResetTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			conn, err := net.Dial("tcp", ln.Addr().String())
+			if err != nil {
+				b.Fatalf("Connection failed: %v", err)
+			}
+			buf := bufio.NewScanner(conn)
+			buf.Scan()
+			if !strings.HasPrefix(buf.Text(), "backend") {
+				b.Fatalf("Invalid backend response: %q", buf.Text())
+			}
+			conn.Close()
+		}
+	})
+	b.StopTimer()
+}