blob: 67a639e0b43fddb6cfcdc4834f9567dc9441a2c6 [file] [log] [blame]
// Package tinylb implements a small and simple userland round-robin load
// balancer, mostly for TCP connections. However, it is entirely
// protocol-agnostic, and only expects net.Listener and net.Conn objects.
//
// Apart from the simple act of load-balancing across a set of backends, tinylb
// also automatically and immediately closes all open connections to backend
// targets that have been removed from the set. This is perhaps not the ideal
// behaviour for user-facing services, but it's the sort of behaviour that works
// very well for machine-to-machine communication.
package tinylb
import (
"context"
"io"
"net"
"sync"
"source.monogon.dev/go/types/mapsets"
"source.monogon.dev/metropolis/pkg/event"
"source.monogon.dev/metropolis/pkg/supervisor"
)
// Backend is to be implemented by different kinds of loadbalancing backends,
// e.g. one per network protocol.
//
// Server.Run calls Dial once for every accepted client connection, and then
// TargetName on the same Backend to file the resulting connection under that
// target name in its connection pool.
type Backend interface {
// TargetName returns the 'target name' of the backend, which is _not_ the same
// as its key in the BackendSet. Instead, the TargetName should uniquely identify
// some backend address, and will be used to figure out that while a backend
// might still exist, its target address has changed - and thus, all existing
// connections to the previous target address should be closed.
//
// Concretely, whenever a new BackendSet is provided, Server.Run closes all
// pooled connections whose target name is not returned by any backend of the
// new set.
//
// For simple load balancing backends, this could be the connection string used
// to connect to the backend.
TargetName() string
// Dial returns a new connection to this backend.
//
// Server.Run calls Dial synchronously from its accept loop, so a Dial that
// blocks delays the acceptance of further client connections. If Dial returns
// an error, a warning is logged and the client connection is closed.
Dial() (net.Conn, error)
}
// BackendSet is the main structure used to provide the current set of backends
// that should be targeted by tinylb. The key is some unique backend identifier.
//
// It is an alias of mapsets.OrderedMap, not a distinct type. Server.Run picks
// the backend for each new connection by advancing the iterator returned by
// the set's Cycle method.
type BackendSet = mapsets.OrderedMap[string, Backend]
// SimpleTCPBackend is a Backend for plain TCP targets which are fully
// described by the address they are dialed at.
type SimpleTCPBackend struct {
	// Remote is the address of the backend, in the form accepted by net.Dial
	// for the "tcp" network.
	Remote string
}

// TargetName returns the remote address, which doubles as the target name of
// this backend.
func (b *SimpleTCPBackend) TargetName() string {
	return b.Remote
}

// Dial opens a new TCP connection to the remote address using default dialer
// settings.
func (b *SimpleTCPBackend) Dial() (net.Conn, error) {
	var dialer net.Dialer
	return dialer.Dial("tcp", b.Remote)
}
// Server is a tiny round-robin loadbalancer for net.Listener/net.Conn compatible
// protocols.
//
// All fields must be set before the loadbalancer can be run.
type Server struct {
// Provider is some event Value which provides the current BackendSet for the
// loadbalancer to use. As the BackendSet is updated, the internal loadbalancing
// algorithm will adjust to the updated set, and any connections to backend
// TargetNames that are not present in the set anymore will be closed.
Provider event.Value[BackendSet]
// Listener is the listener from which the loadbalancer accepts client
// connections. Run closes it as soon as the context passed to Run is canceled,
// so it cannot be reused afterwards.
Listener net.Listener
}
// Run the loadbalancer in a supervisor.Runnable and block until canceled.
// Because the Server's Listener is closed once ctx is canceled, it should be
// opened in the same runnable that then calls this function.
//
// Run starts an "acceptor" sub-runnable which accepts client connections,
// dials the next backend of the current set in round-robin order and pipes
// data between both sides. Run itself then watches the Provider: on every new
// BackendSet it closes all pooled connections whose target name is gone from
// the set, and makes the new set the current one.
//
// Run never returns nil. It returns the error of starting the acceptor, or the
// error returned by the Provider watcher.
func (s *Server) Run(ctx context.Context) error {
	// Connection pool used to track connections/backends.
	pool := newConnectionPool()

	// Current backend set and its lock. The acceptor reads the set, the
	// watcher loop at the bottom of this function is its only writer.
	var curSetMu sync.RWMutex
	var curSet BackendSet

	// Close listener on exit. This is also what unblocks an acceptor that is
	// waiting in Accept().
	go func() {
		<-ctx.Done()
		s.Listener.Close()
	}()

	// The acceptor runs the main Accept() loop on the given Listener.
	err := supervisor.Run(ctx, "acceptor", func(ctx context.Context) error {
		// This doesn't need a lock, as it doesn't read any fields of curSet.
		it := curSet.Cycle()

		supervisor.Signal(ctx, supervisor.SignalHealthy)

		for {
			if ctx.Err() != nil {
				return ctx.Err()
			}
			conn, err := s.Listener.Accept()
			if err != nil {
				// The listener is closed on cancellation, which makes a pending
				// Accept() fail. Report that as the cancellation it is rather
				// than as a listener failure.
				if ctx.Err() != nil {
					return ctx.Err()
				}
				return err
			}

			// Get next backend.
			curSetMu.RLock()
			id, backend, ok := it.Next()
			curSetMu.RUnlock()

			if !ok {
				supervisor.Logger(ctx).Warningf("Balancing %s: failed due to empty backend set", conn.RemoteAddr())
				conn.Close()
				continue
			}

			// Dial backend. This happens inline, so accepting the next
			// connection waits for the dial to finish.
			r, err := backend.Dial()
			if err != nil {
				supervisor.Logger(ctx).Warningf("Balancing %s: failed due to backend %s error: %v", conn.RemoteAddr(), id, err)
				conn.Close()
				continue
			}

			// Insert connection/target name into connectionPool.
			target := backend.TargetName()
			cid := pool.Insert(target, r)

			// Pipe data. Close both connections on any side failing. Each
			// direction closes both ends when it finishes, which in turn
			// terminates the copy running in the other direction.
			go func() {
				defer conn.Close()
				defer pool.CloseConn(cid)
				io.Copy(r, conn)
			}()
			go func() {
				defer conn.Close()
				defer pool.CloseConn(cid)
				io.Copy(conn, r)
			}()
		}
	})
	if err != nil {
		return err
	}

	supervisor.Signal(ctx, supervisor.SignalHealthy)

	// Update curSet from Provider.
	w := s.Provider.Watch()
	defer w.Close()
	for {
		set, err := w.Get(ctx)
		if err != nil {
			return err
		}

		// Did we lose a backend? If so, kill all connections going through it.

		// First, gather a map from TargetName to backend ID for the current set.
		// This only reads curSet, so the read lock is sufficient.
		curTargets := make(map[string]string)
		curSetMu.RLock()
		for _, kv := range curSet.Values() {
			curTargets[kv.Value.TargetName()] = kv.Key
		}
		curSetMu.RUnlock()

		// Then, gather it for the new set.
		targets := make(map[string]string)
		for _, kv := range set.Values() {
			targets[kv.Value.TargetName()] = kv.Key
		}

		// Then, if we have any target name in the connection pool that's not in the new
		// set, close all of its connections.
		for _, target := range pool.Targets() {
			if _, ok := targets[target]; ok {
				continue
			}
			// Use curTargets just for displaying the name of the backend that has changed.
			supervisor.Logger(ctx).Infof("Backend %s / target %s removed, killing connections", curTargets[target], target)
			pool.CloseTarget(target)
		}

		// Tell about new backend set and actually replace it.
		supervisor.Logger(ctx).Infof("New backend set (%d backends):", len(set.Keys()))
		for _, kv := range set.Values() {
			supervisor.Logger(ctx).Infof(" - %s, target %s", kv.Key, kv.Value.TargetName())
		}
		curSetMu.Lock()
		curSet.Replace(&set)
		curSetMu.Unlock()
	}
}