| // Package tinylb implements a small and simple userland round-robin load | 
 | // balancer, mostly for TCP connections. However, it is entirely | 
 | // protocol-agnostic, and only expects net.Listener and net.Conn objects. | 
 | // | 
 | // Apart from the simple act of load-balancing across a set of backends, tinylb | 
 | // also automatically and immediately closes all open connections to backend | 
 | // targets that have been removed from the set. This is perhaps not the ideal | 
 | // behaviour for user-facing services, but it's the sort of behaviour that works | 
 | // very well for machine-to-machine communication. | 
 | package tinylb | 
 |  | 
 | import ( | 
 | 	"context" | 
 | 	"io" | 
 | 	"net" | 
 | 	"sync" | 
 |  | 
 | 	"source.monogon.dev/go/types/mapsets" | 
 | 	"source.monogon.dev/metropolis/pkg/event" | 
 | 	"source.monogon.dev/metropolis/pkg/supervisor" | 
 | ) | 
 |  | 
 | // Backend is to be implemented by different kinds of loadbalancing backends, eg. | 
 | // one per network protocol. | 
 | type Backend interface { | 
 | 	// TargetName returns the 'target name' of the backend, which is _not_ the same | 
 | 	// as its key in the BackendSet. Instead, the TargetName should uniquely identify | 
 | 	// some backend address, and will be used to figure out that while a backend | 
 | 	// might still exist, its target address has changed - and thus, all existing | 
 | 	// connections to the previous target address should be closed. | 
 | 	// | 
 | 	// For simple load balancing backends, this could be the connection string used | 
 | 	// to connect to the backend. | 
 | 	TargetName() string | 
 | 	// Dial returns a new connection to this backend. | 
 | 	Dial() (net.Conn, error) | 
 | } | 
 |  | 
// BackendSet is the main structure used to provide the current set of backends
// that should be targeted by tinylb. The key is some unique backend identifier,
// which is also what tinylb uses to refer to a backend in its log messages; the
// value is the Backend itself.
//
// As this is a type alias, any mapsets.OrderedMap[string, Backend] can be used
// directly as a BackendSet. New connections are assigned to backends by cycling
// through the set; NOTE(review): this presumably visits backends in the
// map's insertion order — confirm against the mapsets package.
type BackendSet = mapsets.OrderedMap[string, Backend]
 |  | 
 | // SimpleTCPBackend implements Backend for trivial TCP-based backends. | 
 | type SimpleTCPBackend struct { | 
 | 	Remote string | 
 | } | 
 |  | 
 | func (t *SimpleTCPBackend) TargetName() string { | 
 | 	return t.Remote | 
 | } | 
 |  | 
 | func (t *SimpleTCPBackend) Dial() (net.Conn, error) { | 
 | 	return net.Dial("tcp", t.Remote) | 
 | } | 
 |  | 
// Server is a tiny round-robin loadbalancer for net.Listener/net.Conn compatible
// protocols: every connection accepted on Listener is forwarded to the next
// backend of the BackendSet currently provided by Provider.
//
// All fields must be set before the loadbalancer can be run (see Run).
type Server struct {
	// Provider is some event Value which provides the current BackendSet for the
	// loadbalancer to use. Run watches it: each newly provided BackendSet
	// replaces the previous one for subsequently accepted connections, and any
	// open connections to backend TargetNames that are not present in the new
	// set anymore are closed.
	Provider event.Value[BackendSet]
	// Listener is where the loadbalancer accepts incoming connections. Run
	// closes it once its context is canceled.
	Listener net.Listener
}
 |  | 
 | // Run the loadbalancer in a superervisor.Runnable and block until canceled. | 
 | // Because the Server's Listener will be closed after exit, it should be opened | 
 | // in the same runnable as this function is then started. | 
 | func (s *Server) Run(ctx context.Context) error { | 
 | 	// Connection pool used to track connections/backends. | 
 | 	pool := newConnectionPool() | 
 |  | 
 | 	// Current backend set and its lock. | 
 | 	var curSetMu sync.RWMutex | 
 | 	var curSet BackendSet | 
 |  | 
 | 	// Close listener on exit. | 
 | 	go func() { | 
 | 		<-ctx.Done() | 
 | 		s.Listener.Close() | 
 | 	}() | 
 |  | 
 | 	// The acceptor is runs the main Accept() loop on the given Listener. | 
 | 	err := supervisor.Run(ctx, "acceptor", func(ctx context.Context) error { | 
 | 		// This doesn't need a lock, as it doesn't read any fields of curSet. | 
 | 		it := curSet.Cycle() | 
 |  | 
 | 		supervisor.Signal(ctx, supervisor.SignalHealthy) | 
 |  | 
 | 		for { | 
 | 			if ctx.Err() != nil { | 
 | 				return ctx.Err() | 
 | 			} | 
 | 			conn, err := s.Listener.Accept() | 
 | 			if err != nil { | 
 | 				return err | 
 | 			} | 
 |  | 
 | 			// Get next backend. | 
 | 			curSetMu.RLock() | 
 | 			id, backend, ok := it.Next() | 
 | 			curSetMu.RUnlock() | 
 |  | 
 | 			if !ok { | 
 | 				supervisor.Logger(ctx).Warningf("Balancing %s: failed due to empty backend set", conn.RemoteAddr().String()) | 
 | 				conn.Close() | 
 | 				continue | 
 | 			} | 
 | 			// Dial backend. | 
 | 			r, err := backend.Dial() | 
 | 			if err != nil { | 
 | 				supervisor.Logger(ctx).Warningf("Balancing %s: failed due to backend %s error: %v", conn.RemoteAddr(), id, err) | 
 | 				conn.Close() | 
 | 				continue | 
 | 			} | 
 | 			// Insert connection/target name into connectionPool. | 
 | 			target := backend.TargetName() | 
 | 			cid := pool.Insert(target, r) | 
 |  | 
 | 			// Pipe data. Close both connections on any side failing. | 
 | 			go func() { | 
 | 				defer conn.Close() | 
 | 				defer pool.CloseConn(cid) | 
 | 				io.Copy(r, conn) | 
 | 			}() | 
 | 			go func() { | 
 | 				defer conn.Close() | 
 | 				defer pool.CloseConn(cid) | 
 | 				io.Copy(conn, r) | 
 | 			}() | 
 | 		} | 
 | 	}) | 
 | 	if err != nil { | 
 | 		return err | 
 | 	} | 
 |  | 
 | 	supervisor.Signal(ctx, supervisor.SignalHealthy) | 
 |  | 
 | 	// Update curSet from Provider. | 
 | 	w := s.Provider.Watch() | 
 | 	defer w.Close() | 
 | 	for { | 
 | 		set, err := w.Get(ctx) | 
 | 		if err != nil { | 
 | 			return err | 
 | 		} | 
 |  | 
 | 		// Did we lose a backend? If so, kill all connections going through it. | 
 |  | 
 | 		// First, gather a map from TargetName to backend ID for the current set. | 
 | 		curTargets := make(map[string]string) | 
 | 		curSetMu.Lock() | 
 | 		for _, kv := range curSet.Values() { | 
 | 			curTargets[kv.Value.TargetName()] = kv.Key | 
 | 		} | 
 | 		curSetMu.Unlock() | 
 |  | 
 | 		// Then, gather it for the new set. | 
 | 		targets := make(map[string]string) | 
 | 		for _, kv := range set.Values() { | 
 | 			targets[kv.Value.TargetName()] = kv.Key | 
 | 		} | 
 |  | 
 | 		// Then, if we have any target name in the connection pool that's not in the new | 
 | 		// set, close all of its connections. | 
 | 		for _, target := range pool.Targets() { | 
 | 			if _, ok := targets[target]; ok { | 
 | 				continue | 
 | 			} | 
 | 			// Use curTargets just for displaying the name of the backend that has changed. | 
 | 			supervisor.Logger(ctx).Infof("Backend %s / target %s removed, killing connections", curTargets[target], target) | 
 | 			pool.CloseTarget(target) | 
 | 		} | 
 |  | 
 | 		// Tell about new backend set and actually replace it. | 
 | 		supervisor.Logger(ctx).Infof("New backend set (%d backends):", len(set.Keys())) | 
 | 		for _, kv := range set.Values() { | 
 | 			supervisor.Logger(ctx).Infof(" - %s, target %s", kv.Key, kv.Value.TargetName()) | 
 | 		} | 
 | 		curSetMu.Lock() | 
 | 		curSet.Replace(&set) | 
 | 		curSetMu.Unlock() | 
 | 	} | 
 | } |