cloud/bmaas/server: do not start BMDB session per RPC

Fixes https://github.com/monogon-dev/monogon/issues/198

Change-Id: Ia99b9a47bfc9ae6de0b9e12e13add891dea289a5
Reviewed-on: https://review.monogon.dev/c/monogon/+/1567
Reviewed-by: Leopold Schabel <leo@monogon.tech>
Tested-by: Jenkins CI
diff --git a/cloud/bmaas/bmdb/sessions.go b/cloud/bmaas/bmdb/sessions.go
index 3f70393..6058b92 100644
--- a/cloud/bmaas/bmdb/sessions.go
+++ b/cloud/bmaas/bmdb/sessions.go
@@ -69,6 +69,19 @@
 	ctxC context.CancelFunc
 }
 
+// Expired returns true if this session is expired (its ctx reports an error)
+// and will fail all subsequent transactions/work.
+func (s *Session) Expired() bool {
+	return s.ctx.Err() != nil
+}
+
+// expire is a helper which marks this session as expired by calling its
+// cancel function (ctxC), and returns ErrSessionExpired.
+func (s *Session) expire() error {
+	s.ctxC()
+	return ErrSessionExpired
+}
+
 var (
 	// ErrSessionExpired is returned when attempting to Transact or Work on a
 	// Session that has expired or been canceled. Once a Session starts returning
@@ -102,7 +115,7 @@
 				return fmt.Errorf("when retrieving session: %w", err)
 			}
 			if len(sessions) < 1 {
-				return ErrSessionExpired
+				return s.expire()
 			}
 			err = q.SessionPoke(ctx, s.UUID)
 			if err != nil {
@@ -147,7 +160,7 @@
 			return fmt.Errorf("when retrieving session: %w", err)
 		}
 		if len(sessions) < 1 {
-			return ErrSessionExpired
+			return s.expire()
 		}
 
 		if err := fn(qtx); err != nil {
diff --git a/cloud/bmaas/server/BUILD.bazel b/cloud/bmaas/server/BUILD.bazel
index aa4d898..4303d62 100644
--- a/cloud/bmaas/server/BUILD.bazel
+++ b/cloud/bmaas/server/BUILD.bazel
@@ -15,6 +15,7 @@
         "//cloud/bmaas/server/api",
         "//cloud/lib/component",
         "//metropolis/node/core/rpc",
+        "@com_github_cenkalti_backoff_v4//:backoff",
         "@com_github_google_uuid//:uuid",
         "@io_k8s_klog//:klog",
         "@io_k8s_klog_v2//:klog",
diff --git a/cloud/bmaas/server/agent_callback_service.go b/cloud/bmaas/server/agent_callback_service.go
index b6e0e71..155c5de 100644
--- a/cloud/bmaas/server/agent_callback_service.go
+++ b/cloud/bmaas/server/agent_callback_service.go
@@ -39,8 +39,7 @@
 		return nil, status.Error(codes.InvalidArgument, "machine_id invalid")
 	}
 
-	// TODO(q3k): don't start a session for every RPC.
-	session, err := a.s.bmdb.StartSession(ctx)
+	session, err := a.s.session(ctx)
 	if err != nil {
 		klog.Errorf("Could not start session: %v", err)
 		return nil, status.Error(codes.Unavailable, "could not start session")
diff --git a/cloud/bmaas/server/server.go b/cloud/bmaas/server/server.go
index 57496d5..0aeefad 100644
--- a/cloud/bmaas/server/server.go
+++ b/cloud/bmaas/server/server.go
@@ -7,6 +7,7 @@
 	"net"
 	"os"
 
+	"github.com/cenkalti/backoff/v4"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/reflection"
 	"k8s.io/klog/v2"
@@ -59,6 +60,49 @@
 
 	bmdb  *bmdb.Connection
 	acsvc *agentCallbackService
+
+	sessionC chan *bmdb.Session
+}
+
+// sessionWorker emits a valid BMDB session to sessionC as long as ctx is active.
+// It starts a new session (with backoff) whenever the current one is expired.
+func (s *Server) sessionWorker(ctx context.Context) {
+	var session *bmdb.Session
+	for {
+		if session == nil || session.Expired() {
+			klog.Infof("Starting new session...")
+			bo := backoff.NewExponentialBackOff()
+			err := backoff.Retry(func() (err error) {
+				session, err = s.bmdb.StartSession(ctx)
+				if err != nil {
+					klog.Errorf("Failed to start session: %v", err)
+				}
+				return err
+			}, backoff.WithContext(bo, ctx))
+			switch {
+			case ctx.Err() != nil: // canceled (eg. shutdown): stop quietly, don't crash
+				return
+			case err != nil: // backoff gave up; something's really wrong, just crash
+				klog.Exitf("Gave up on starting session: %v", err)
+			}
+			klog.Infof("New session: %s", session.UUID)
+		}
+
+		select {
+		case <-ctx.Done():
+			return
+		case s.sessionC <- session:
+		}
+	}
+}
+
+func (s *Server) session(ctx context.Context) (*bmdb.Session, error) {
+	select {
+	case <-ctx.Done(): // caller's context ended before a session was offered
+		return nil, ctx.Err()
+	case current := <-s.sessionC: // session offered on sessionC by sessionWorker
+		return current, nil
+	}
+}
 
 func (s *Server) startPublic(ctx context.Context) {
@@ -109,6 +153,8 @@
 		s: s,
 	}
 	s.bmdb = conn
+	s.sessionC = make(chan *bmdb.Session)
+	go s.sessionWorker(ctx)
 	s.startInternalGRPC(ctx)
 	s.startPublic(ctx)
 	go func() {