c/bmaas/bmdb: rework Work API into working form

The Work API within the BMDB library wasn't quite working for purely
BMDB-directed control loops. We fix that by turning Work into a
three-phase process: retrieval, side-effect-causing work, and
committing.

We also add a test that exercises this functionality alongside the Agent
tags and installation retrieval queries.
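
Roughly, the new call pattern looks as follows (process stands in for a
model.Process value, retrieveFn and commitFn for the per-process
retrieval and commit queries; error handling elided):

    work, err := session.Work(ctx, process, retrieveFn) // 1. retrieval
    defer work.Cancel(ctx)
    // ... 2. side-effect-causing work against work.Machine ...
    err = work.Finish(ctx, commitFn)                     // 3. commit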

Change-Id: I43af0c02af034e402dadb3e3a1fd10e5a8fe680a
Reviewed-on: https://review.monogon.dev/c/monogon/+/998
Tested-by: Jenkins CI
Reviewed-by: Mateusz Zalega <mateusz@monogon.tech>
diff --git a/cloud/bmaas/bmdb/BUILD.bazel b/cloud/bmaas/bmdb/BUILD.bazel
index 882d73c..65a11c7 100644
--- a/cloud/bmaas/bmdb/BUILD.bazel
+++ b/cloud/bmaas/bmdb/BUILD.bazel
@@ -31,5 +31,6 @@
     deps = [
         "//cloud/bmaas/bmdb/model",
         "//cloud/lib/component",
+        "@com_github_google_uuid//:uuid",
     ],
 )
diff --git a/cloud/bmaas/bmdb/sessions.go b/cloud/bmaas/bmdb/sessions.go
index 6b19171..ccb1510 100644
--- a/cloud/bmaas/bmdb/sessions.go
+++ b/cloud/bmaas/bmdb/sessions.go
@@ -198,7 +198,12 @@
 		}
 		// Success. Keep going.
 		deadline = time.Now().Add(s.interval)
-		time.Sleep(s.interval / 2)
+		select {
+		case <-ctx.Done():
+			// Do nothing, next loop iteration will exit.
+		case <-time.After(s.interval / 2):
+			// Do nothing, next loop iteration will heartbeat.
+		}
 	}
 }
 
@@ -226,40 +231,137 @@
 	})
 }
 
-// Work runs a given function as a work item with a given process name against
-// some identified machine. Not more than one process of a given name can run
-// against a machine concurrently.
-//
-// Most impure (meaning with side effects outside the database itself) BMDB
-// transactions should be run this way.
-func (s *Session) Work(ctx context.Context, machine uuid.UUID, process model.Process, fn func() error) error {
-	err := model.New(s.connection.db).StartWork(ctx, model.StartWorkParams{
-		MachineID: machine,
-		SessionID: s.UUID,
-		Process:   process,
-	})
-	if err != nil {
-		var perr *pq.Error
-		if errors.As(err, &perr) && perr.Code == "23505" {
-			return ErrWorkConflict
-		}
-		return fmt.Errorf("could not start work: %w", err)
-	}
-	klog.Infof("Started work: %q on machine %s, session %s", process, machine, s.UUID)
+var (
+	// ErrNothingToDo is returned by Work (and should be returned by retrieval
+	// functions passed to it) when no machine is currently eligible for work
+	// under the given process.
+	ErrNothingToDo = errors.New("nothing to do")
+	// postgresUniqueViolation is the error code reported by the lib/pq driver
+	// when a mutation cannot be performed due to a UNIQUE constraint being
+	// violated as a result of the query.
+	postgresUniqueViolation = pq.ErrorCode("23505")
+)
 
-	defer func() {
-		err := model.New(s.connection.db).FinishWork(s.ctx, model.FinishWorkParams{
-			MachineID: machine,
+// Work starts work on a machine. Full work execution is performed in three
+// phases:
+//
+//  1. Retrieval phase. This is performed by 'fn' given to this function.
+//     The retrieval function must return zero or more machines that some work
+//     should be performed on per the BMDB. The first returned machine will be
+//     locked for work under the given process and made available in the Work
+//     structure returned by this call. The function may be called multiple times,
+//     as it's run within a CockroachDB transaction which may be retried an
+//     arbitrary number of times. Thus, it should be side-effect free, ideally only
+//     performing read queries to the database.
+//  2. Work phase. This is performed by user code while holding on to the Work
+//     structure instance.
+//  3. Commit phase. This is performed by the function passed to Work.Finish. See
+//     that method's documentation for more details.
+//
+// Important: after retrieving Work successfully, either Finish or Cancel must
+// be called, otherwise the machine will stay locked until the parent session
+// expires or is closed! It's safe and recommended to `defer work.Cancel(ctx)`
+// right after calling Work().
+//
+// If no machine is eligible for work, the retrieval function should return
+// ErrNothingToDo, and Work will return that error (wrapped). If the retrieval
+// function returns no machines and no error, Work also returns ErrNothingToDo.
+//
+// The returned Work object is _not_ goroutine safe.
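+//
+// A rough usage sketch (error handling mostly elided; the retrieval and commit
+// queries shown are just examples taken from the agent installation flow):
+//
+//	work, err := session.Work(ctx, model.ProcessShepherdInstall, func(q *model.Queries) ([]uuid.UUID, error) {
+//		machines, err := q.GetMachinesForAgentStart(ctx, 1)
+//		if err != nil {
+//			return nil, err
+//		}
+//		if len(machines) < 1 {
+//			return nil, ErrNothingToDo
+//		}
+//		return []uuid.UUID{machines[0].MachineID}, nil
+//	})
+//	if err != nil {
+//		return err // e.g. ErrNothingToDo or ErrWorkConflict
+//	}
+//	defer work.Cancel(ctx)
+//	// ... perform side effects against work.Machine ...
+//	return work.Finish(ctx, func(q *model.Queries) error {
+//		return q.MachineSetAgentStarted(ctx, model.MachineSetAgentStartedParams{ /* ... */ })
+//	})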
+func (s *Session) Work(ctx context.Context, process model.Process, fn func(q *model.Queries) ([]uuid.UUID, error)) (*Work, error) {
+	var mid *uuid.UUID
+	err := s.Transact(ctx, func(q *model.Queries) error {
+		mids, err := fn(q)
+		if err != nil {
+			return fmt.Errorf("could not retrieve machines for work: %w", err)
+		}
+		if len(mids) < 1 {
+			return ErrNothingToDo
+		}
+		mid = &mids[0]
+		err = q.StartWork(ctx, model.StartWorkParams{
+			MachineID: mids[0],
 			SessionID: s.UUID,
 			Process:   process,
 		})
-		klog.Errorf("Finished work: %q on machine %s, session %s", process, machine, s.UUID)
-		if err != nil && !errors.Is(err, s.ctx.Err()) {
-			klog.Errorf("Failed to finish work: %v", err)
-			klog.Errorf("Closing session out of an abundance of caution")
-			s.ctxC()
+		if err != nil {
+			var perr *pq.Error
+			if errors.As(err, &perr) && perr.Code == postgresUniqueViolation {
+				return ErrWorkConflict
+			}
+			return fmt.Errorf("could not start work on %q: %w", mids[0], err)
 		}
-	}()
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	klog.Infof("Started work %q on machine %q (sess %q)", process, *mid, s.UUID)
+	return &Work{
+		Machine: *mid,
+		s:       s,
+		process: process,
+	}, nil
+}
 
-	return fn()
+// Work represents work being performed on a machine.
+type Work struct {
+	// Machine that this work is being performed on, as retrieved by the retrieval
+	// function passed to the Work method.
+	Machine uuid.UUID
+	// s is the parent session.
+	s *Session
+	// done marks that this work has already been canceled or finished.
+	done bool
+	// process that this work performs.
+	process model.Process
+}
+
+// Cancel the Work started on a machine. If the work has already been finished
+// or canceled, this is a no-op. In case of error, a log line will be emitted.
+func (w *Work) Cancel(ctx context.Context) {
+	if w.done {
+		return
+	}
+	w.done = true
+
+	klog.Infof("Canceling work %q on machine %q (sess %q)", w.process, w.Machine, w.s.UUID)
+	// Eat error and log. There's nothing we can do if this fails, and if it does, it's
+	// probably because our connectivity to the BMDB has failed. If so, our session
+	// will be invalidated soon and so will the work being performed on this
+	// machine.
+	err := w.s.Transact(ctx, func(q *model.Queries) error {
+		return q.FinishWork(ctx, model.FinishWorkParams{
+			MachineID: w.Machine,
+			SessionID: w.s.UUID,
+			Process:   w.process,
+		})
+	})
+	if err != nil {
+		klog.Errorf("Failed to cancel work %q on %q (sess %q): %v", w.process, w.Machine, w.s.UUID, err)
+	}
+}
+
+// Finish work by executing a commit function 'fn' and releasing the machine
+// from the work performed. The given function should apply tags to the
+// processed machine in a way that makes it no longer eligible for retrieval.
+// As with the retrieval function, the commit function might be called an
+// arbitrary number of times as part of CockroachDB transaction retries.
+//
+// This may be called only once.
+func (w *Work) Finish(ctx context.Context, fn func(q *model.Queries) error) error {
+	if w.done {
+		return fmt.Errorf("already finished")
+	}
+	w.done = true
+	klog.Infof("Finishing work %q on machine %q (sess %q)", w.process, w.Machine, w.s.UUID)
+	return w.s.Transact(ctx, func(q *model.Queries) error {
+		err := q.FinishWork(ctx, model.FinishWorkParams{
+			MachineID: w.Machine,
+			SessionID: w.s.UUID,
+			Process:   w.process,
+		})
+		if err != nil {
+			return err
+		}
+		return fn(q)
+	})
 }
diff --git a/cloud/bmaas/bmdb/sessions_test.go b/cloud/bmaas/bmdb/sessions_test.go
index 0018109..9664c17 100644
--- a/cloud/bmaas/bmdb/sessions_test.go
+++ b/cloud/bmaas/bmdb/sessions_test.go
@@ -3,9 +3,12 @@
 import (
 	"context"
 	"errors"
+	"fmt"
 	"testing"
 	"time"
 
+	"github.com/google/uuid"
+
 	"source.monogon.dev/cloud/bmaas/bmdb/model"
 	"source.monogon.dev/cloud/lib/component"
 )
@@ -111,22 +114,13 @@
 	// part of the test.
 	ctxB, ctxBC := context.WithCancel(ctx)
 	defer ctxBC()
-	// Start work which will block forever. We have to go rendezvous through a
-	// channel to make sure the work actually starts.
-	started := make(chan error)
-	done := make(chan error, 1)
-	go func() {
-		err := session1.Work(ctxB, machine.MachineID, model.ProcessUnitTest1, func() error {
-			started <- nil
-			<-ctxB.Done()
-			return ctxB.Err()
-		})
-		done <- err
-		if err != nil {
-			started <- err
-		}
-	}()
-	err = <-started
+
+	constantRetriever := func(_ *model.Queries) ([]uuid.UUID, error) {
+		return []uuid.UUID{machine.MachineID}, nil
+	}
+
+	// Start work on a machine which we won't finish for a while.
+	work1, err := session1.Work(ctxB, model.ProcessUnitTest1, constantRetriever)
 	if err != nil {
 		t.Fatalf("Starting first work failed: %v", err)
 	}
@@ -134,39 +128,277 @@
 	// Starting more work on the same machine but a different process should still
 	// be allowed.
 	for _, session := range []*Session{session1, session2} {
-		err = session.Work(ctxB, machine.MachineID, model.ProcessUnitTest2, func() error {
-			return nil
-		})
+		work2, err := session.Work(ctxB, model.ProcessUnitTest2, constantRetriever)
 		if err != nil {
 			t.Errorf("Could not run concurrent process on machine: %v", err)
+		} else {
+			work2.Cancel(ctxB)
 		}
 	}
 
 	// However, starting work with the same process on the same machine should
 	// fail.
 	for _, session := range []*Session{session1, session2} {
-		err = session.Work(ctxB, machine.MachineID, model.ProcessUnitTest1, func() error {
-			return nil
-		})
+		work2, err := session.Work(ctxB, model.ProcessUnitTest1, constantRetriever)
 		if !errors.Is(err, ErrWorkConflict) {
 			t.Errorf("Concurrent work with same process should've been forbidden, got %v", err)
+			if work2 != nil {
+				work2.Cancel(ctxB)
+			}
 		}
 	}
 
-	// Now, cancel the first long-running request and wait for it to return.
-	ctxBC()
-	err = <-done
-	if !errors.Is(err, ctxB.Err()) {
-		t.Fatalf("First work item should've failed with %v, got %v", ctxB.Err(), err)
-	}
+	// Now, cancel the long-running work.
+	work1.Cancel(ctx)
 
 	// We should now be able to perform 'test1' work again against this machine.
 	for _, session := range []*Session{session1, session2} {
-		err = session.Work(ctx, machine.MachineID, model.ProcessUnitTest1, func() error {
-			return nil
-		})
+		work1, err := session.Work(ctxB, model.ProcessUnitTest1, constantRetriever)
 		if err != nil {
 			t.Errorf("Could not run work against machine: %v", err)
+		} else {
+			work1.Cancel(ctxB)
 		}
 	}
 }
+
+// TestInstallationWorkflow exercises the agent installation workflow within the
+// BMDB.
+func TestInstallationWorkflow(t *testing.T) {
+	b := dut()
+	conn, err := b.Open(true)
+	if err != nil {
+		t.Fatalf("Open failed: %v", err)
+	}
+
+	ctx, ctxC := context.WithCancel(context.Background())
+	defer ctxC()
+
+	session, err := conn.StartSession(ctx)
+	if err != nil {
+		t.Fatalf("Starting session failed: %v", err)
+	}
+	// Create a machine. Its ID is not recorded, as the installation retrieval
+	// query will find it later.
+	// Create machine. Drop its ID.
+	err = session.Transact(ctx, func(q *model.Queries) error {
+		machine, err := q.NewMachine(ctx)
+		if err != nil {
+			return err
+		}
+		return q.MachineAddProvided(ctx, model.MachineAddProvidedParams{
+			MachineID:  machine.MachineID,
+			Provider:   model.ProviderEquinix,
+			ProviderID: "123",
+		})
+	})
+	if err != nil {
+		t.Fatalf("Creating machine failed: %v", err)
+	}
+
+	// Start working on a machine.
+	startedC := make(chan struct{})
+	doneC := make(chan struct{})
+	errC := make(chan error)
+	go func() {
+		work, err := session.Work(ctx, model.ProcessShepherdInstall, func(q *model.Queries) ([]uuid.UUID, error) {
+			machines, err := q.GetMachinesForAgentStart(ctx, 1)
+			if err != nil {
+				return nil, err
+			}
+			if len(machines) < 1 {
+				return nil, ErrNothingToDo
+			}
+			return []uuid.UUID{machines[0].MachineID}, nil
+		})
+		if err != nil {
+			close(startedC)
+			errC <- err
+			return
+		}
+		defer work.Cancel(ctx)
+
+		// Simulate work by blocking on a channel.
+		close(startedC)
+		<-doneC
+
+		err = work.Finish(ctx, func(q *model.Queries) error {
+			return q.MachineSetAgentStarted(ctx, model.MachineSetAgentStartedParams{
+				MachineID:      work.Machine,
+				AgentStartedAt: time.Now(),
+				AgentPublicKey: []byte("fakefakefake"),
+			})
+		})
+		errC <- err
+	}()
+	<-startedC
+	// Work on the machine has started. Attempting to get more machines now should
+	// return no machines.
+	err = session.Transact(ctx, func(q *model.Queries) error {
+		machines, err := q.GetMachinesForAgentStart(ctx, 1)
+		if err != nil {
+			return err
+		}
+		if len(machines) > 0 {
+			t.Errorf("Expected no machines ready for installation.")
+		}
+		return nil
+	})
+	if err != nil {
+		t.Errorf("Failed to retrieve machines for installation in parallel: %v", err)
+	}
+	// Finish working on machine.
+	close(doneC)
+	err = <-errC
+	if err != nil {
+		t.Errorf("Failed to finish work on machine: %v", err)
+	}
+	// That machine is now installed, so we still expect no work to have to be done.
+	err = session.Transact(ctx, func(q *model.Queries) error {
+		machines, err := q.GetMachinesForAgentStart(ctx, 1)
+		if err != nil {
+			return err
+		}
+		if len(machines) > 0 {
+			t.Errorf("Expected still no machines ready for installation.")
+		}
+		return nil
+	})
+	if err != nil {
+		t.Errorf("Failed to retrieve machines for installation after work finished: %v", err)
+	}
+}
+
+// TestInstallationWorkflowParallel starts work on six machines across three
+// workers and makes sure that there are no scheduling conflicts between them.
+func TestInstallationWorkflowParallel(t *testing.T) {
+	b := dut()
+	conn, err := b.Open(true)
+	if err != nil {
+		t.Fatalf("Open failed: %v", err)
+	}
+
+	ctx, ctxC := context.WithCancel(context.Background())
+	defer ctxC()
+
+	makeMachine := func(providerID string) {
+		ctxS, ctxC := context.WithCancel(ctx)
+		defer ctxC()
+		session, err := conn.StartSession(ctxS)
+		if err != nil {
+			t.Fatalf("Starting session failed: %v", err)
+		}
+		err = session.Transact(ctx, func(q *model.Queries) error {
+			machine, err := q.NewMachine(ctx)
+			if err != nil {
+				return err
+			}
+			return q.MachineAddProvided(ctx, model.MachineAddProvidedParams{
+				MachineID:  machine.MachineID,
+				Provider:   model.ProviderEquinix,
+				ProviderID: providerID,
+			})
+		})
+		if err != nil {
+			t.Fatalf("Creating machine failed: %v", err)
+		}
+	}
+	// Make six machines for testing.
+	for i := 0; i < 6; i++ {
+		makeMachine(fmt.Sprintf("test%d", i))
+	}
+
+	workStarted := make(chan struct{})
+	workDone := make(chan struct {
+		machine  uuid.UUID
+		workerID int
+	})
+
+	workOnce := func(ctx context.Context, workerID int, session *Session) error {
+		work, err := session.Work(ctx, model.ProcessShepherdInstall, func(q *model.Queries) ([]uuid.UUID, error) {
+			machines, err := q.GetMachinesForAgentStart(ctx, 1)
+			if err != nil {
+				return nil, err
+			}
+			if len(machines) < 1 {
+				return nil, ErrNothingToDo
+			}
+			return []uuid.UUID{machines[0].MachineID}, nil
+		})
+
+		if err != nil {
+			return err
+		}
+		defer work.Cancel(ctx)
+
+		select {
+		case <-workStarted:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
+		select {
+		case workDone <- struct {
+			machine  uuid.UUID
+			workerID int
+		}{
+			machine:  work.Machine,
+			workerID: workerID,
+		}:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
+		return work.Finish(ctx, func(q *model.Queries) error {
+			return q.MachineSetAgentStarted(ctx, model.MachineSetAgentStartedParams{
+				MachineID:      work.Machine,
+				AgentStartedAt: time.Now(),
+				AgentPublicKey: []byte("fakefakefake"),
+			})
+		})
+	}
+
+	worker := func(workerID int) {
+		ctxS, ctxC := context.WithCancel(ctx)
+		defer ctxC()
+		session, err := conn.StartSession(ctxS)
+		if err != nil {
+			t.Fatalf("Starting session failed: %v", err)
+		}
+		for {
+			err := workOnce(ctxS, workerID, session)
+			if err != nil {
+				// Exit cleanly when our context is canceled or when all
+				// machines have already been processed.
+				if errors.Is(err, ctxS.Err()) || errors.Is(err, ErrNothingToDo) {
+					return
+				}
+				t.Fatalf("worker failed: %v", err)
+			}
+		}
+	}
+	// Start three workers.
+	for i := 0; i < 3; i++ {
+		go worker(i)
+	}
+
+	// Wait for at least three workers to be alive.
+	for i := 0; i < 3; i++ {
+		workStarted <- struct{}{}
+	}
+
+	// Allow all workers to continue running from now on.
+	close(workStarted)
+
+	// Expect six machines to have been handled in parallel by three workers.
+	seenWorkers := make(map[int]bool)
+	seenMachines := make(map[string]bool)
+	for i := 0; i < 6; i++ {
+		res := <-workDone
+		seenWorkers[res.workerID] = true
+		seenMachines[res.machine.String()] = true
+	}
+
+	if want, got := 3, len(seenWorkers); want != got {
+		t.Errorf("Expected %d workers, got %d", want, got)
+	}
+	if want, got := 6, len(seenMachines); want != got {
+		t.Errorf("Expected %d machines, got %d", want, got)
+	}
+}