c/bmaas/bmdb: rework Work API into working form

The Work API within the BMDB library wasn't quite working for purely
BMDB-directed control loops. We fix that by turning Work into a
three-phase process: retrieval, side-effect-causing work, and
committing.
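
Roughly, a BMDB-driven control loop now looks as follows (a sketch based
on the new test code; startAgent and pubkey are stand-ins for whatever
side effects the caller actually performs):

    // Phase 1: retrieval. The retriever runs in a BMDB transaction and
    // returns the machine(s) to lock for this process.
    work, err := session.Work(ctx, model.ProcessShepherdInstall,
        func(q *model.Queries) ([]uuid.UUID, error) {
            machines, err := q.GetMachinesForAgentStart(ctx, 1)
            if err != nil {
                return nil, err
            }
            if len(machines) < 1 {
                return nil, ErrNothingToDo
            }
            return []uuid.UUID{machines[0].MachineID}, nil
        })
    if err != nil {
        return err
    }
    defer work.Cancel(ctx)

    // Phase 2: the side-effect-causing work itself, performed outside of
    // any database transaction (startAgent is hypothetical).
    pubkey, err := startAgent(ctx, work.Machine)
    if err != nil {
        return err
    }

    // Phase 3: commit the result back into the BMDB, which releases the
    // work item. Cancel instead releases it without committing.
    return work.Finish(ctx, func(q *model.Queries) error {
        return q.MachineSetAgentStarted(ctx, model.MachineSetAgentStartedParams{
            MachineID:      work.Machine,
            AgentStartedAt: time.Now(),
            AgentPublicKey: pubkey,
        })
    })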

We also add a test that exercises this functionality alongside the Agent
tags and installation retrieval queries.

Change-Id: I43af0c02af034e402dadb3e3a1fd10e5a8fe680a
Reviewed-on: https://review.monogon.dev/c/monogon/+/998
Tested-by: Jenkins CI
Reviewed-by: Mateusz Zalega <mateusz@monogon.tech>
diff --git a/cloud/bmaas/bmdb/sessions_test.go b/cloud/bmaas/bmdb/sessions_test.go
index 0018109..9664c17 100644
--- a/cloud/bmaas/bmdb/sessions_test.go
+++ b/cloud/bmaas/bmdb/sessions_test.go
@@ -3,9 +3,12 @@
 import (
 	"context"
 	"errors"
+	"fmt"
 	"testing"
 	"time"
 
+	"github.com/google/uuid"
+
 	"source.monogon.dev/cloud/bmaas/bmdb/model"
 	"source.monogon.dev/cloud/lib/component"
 )
@@ -111,22 +114,13 @@
 	// part of the test.
 	ctxB, ctxBC := context.WithCancel(ctx)
 	defer ctxBC()
-	// Start work which will block forever. We have to go rendezvous through a
-	// channel to make sure the work actually starts.
-	started := make(chan error)
-	done := make(chan error, 1)
-	go func() {
-		err := session1.Work(ctxB, machine.MachineID, model.ProcessUnitTest1, func() error {
-			started <- nil
-			<-ctxB.Done()
-			return ctxB.Err()
-		})
-		done <- err
-		if err != nil {
-			started <- err
-		}
-	}()
-	err = <-started
+
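+	// constantRetriever always returns the single test machine.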
+	constantRetriever := func(_ *model.Queries) ([]uuid.UUID, error) {
+		return []uuid.UUID{machine.MachineID}, nil
+	}
+
+	// Start work on the machine, which we won't finish for a while.
+	work1, err := session1.Work(ctxB, model.ProcessUnitTest1, constantRetriever)
 	if err != nil {
 		t.Fatalf("Starting first work failed: %v", err)
 	}
@@ -134,39 +128,277 @@
 	// Starting more work on the same machine but a different process should still
 	// be allowed.
 	for _, session := range []*Session{session1, session2} {
-		err = session.Work(ctxB, machine.MachineID, model.ProcessUnitTest2, func() error {
-			return nil
-		})
+		work2, err := session.Work(ctxB, model.ProcessUnitTest2, constantRetriever)
 		if err != nil {
 			t.Errorf("Could not run concurrent process on machine: %v", err)
+		} else {
+			work2.Cancel(ctxB)
 		}
 	}
 
 	// However, starting work with the same process on the same machine should
 	// fail.
 	for _, session := range []*Session{session1, session2} {
-		err = session.Work(ctxB, machine.MachineID, model.ProcessUnitTest1, func() error {
-			return nil
-		})
+		work2, err := session.Work(ctxB, model.ProcessUnitTest1, constantRetriever)
 		if !errors.Is(err, ErrWorkConflict) {
 			t.Errorf("Concurrent work with same process should've been forbidden, got %v", err)
+			if work2 != nil {
+				work2.Cancel(ctxB)
+			}
 		}
 	}
 
-	// Now, cancel the first long-running request and wait for it to return.
-	ctxBC()
-	err = <-done
-	if !errors.Is(err, ctxB.Err()) {
-		t.Fatalf("First work item should've failed with %v, got %v", ctxB.Err(), err)
-	}
+	// Now, cancel the first long-running work item, releasing the machine.
+	work1.Cancel(ctx)
 
 	// We should now be able to perform 'test1' work again against this machine.
 	for _, session := range []*Session{session1, session2} {
-		err = session.Work(ctx, machine.MachineID, model.ProcessUnitTest1, func() error {
-			return nil
-		})
+		work1, err := session.Work(ctxB, model.ProcessUnitTest1, constantRetriever)
 		if err != nil {
 			t.Errorf("Could not run work against machine: %v", err)
+		} else {
+			work1.Cancel(ctxB)
 		}
 	}
 }
+
+// TestInstallationWorkflow exercises the agent installation workflow within the
+// BMDB.
+func TestInstallationWorkflow(t *testing.T) {
+	b := dut()
+	conn, err := b.Open(true)
+	if err != nil {
+		t.Fatalf("Open failed: %v", err)
+	}
+
+	ctx, ctxC := context.WithCancel(context.Background())
+	defer ctxC()
+
+	session, err := conn.StartSession(ctx)
+	if err != nil {
+		t.Fatalf("Starting session failed: %v", err)
+	}
+
+	// Create a machine. Its ID is discarded; the work retriever below finds it by query.
+	err = session.Transact(ctx, func(q *model.Queries) error {
+		machine, err := q.NewMachine(ctx)
+		if err != nil {
+			return err
+		}
+		return q.MachineAddProvided(ctx, model.MachineAddProvidedParams{
+			MachineID:  machine.MachineID,
+			Provider:   model.ProviderEquinix,
+			ProviderID: "123",
+		})
+	})
+	if err != nil {
+		t.Fatalf("Creating machine failed: %v", err)
+	}
+
+	// Start working on a machine.
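+	// startedC signals that the machine has been claimed (or that claiming
+	// failed), doneC lets the work finish, and errC carries the final error.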
+	startedC := make(chan struct{})
+	doneC := make(chan struct{})
+	errC := make(chan error)
+	go func() {
+		work, err := session.Work(ctx, model.ProcessShepherdInstall, func(q *model.Queries) ([]uuid.UUID, error) {
+			machines, err := q.GetMachinesForAgentStart(ctx, 1)
+			if err != nil {
+				return nil, err
+			}
+			if len(machines) < 1 {
+				return nil, ErrNothingToDo
+			}
+			return []uuid.UUID{machines[0].MachineID}, nil
+		})
+		if err != nil {
+			close(startedC)
+			errC <- err
+			return
+		}
+		defer work.Cancel(ctx)
+
+		// Simulate work by blocking on a channel.
+		close(startedC)
+		<-doneC
+
+		err = work.Finish(ctx, func(q *model.Queries) error {
+			return q.MachineSetAgentStarted(ctx, model.MachineSetAgentStartedParams{
+				MachineID:      work.Machine,
+				AgentStartedAt: time.Now(),
+				AgentPublicKey: []byte("fakefakefake"),
+			})
+		})
+		errC <- err
+	}()
+	<-startedC
+	// Work on the machine has started. Attempting to get more machines now should
+	// return no machines.
+	err = session.Transact(ctx, func(q *model.Queries) error {
+		machines, err := q.GetMachinesForAgentStart(ctx, 1)
+		if err != nil {
+			return err
+		}
+		if len(machines) > 0 {
+			t.Errorf("Expected no machines ready for installation.")
+		}
+		return nil
+	})
+	if err != nil {
+		t.Errorf("Failed to retrieve machines for installation in parallel: %v", err)
+	}
+	// Finish working on machine.
+	close(doneC)
+	err = <-errC
+	if err != nil {
+		t.Errorf("Failed to finish work on machine: %v", err)
+	}
+	// That machine is now installed, so we still expect no pending work.
+	err = session.Transact(ctx, func(q *model.Queries) error {
+		machines, err := q.GetMachinesForAgentStart(ctx, 1)
+		if err != nil {
+			return err
+		}
+		if len(machines) > 0 {
+			t.Errorf("Expected still no machines ready for installation.")
+		}
+		return nil
+	})
+	if err != nil {
+		t.Errorf("Failed to retrieve machines for installation after work finished: %v", err)
+	}
+}
+
+// TestInstallationWorkflowParallel starts work on six machines across three
+// workers and makes sure that there are no scheduling conflicts between them.
+func TestInstallationWorkflowParallel(t *testing.T) {
+	b := dut()
+	conn, err := b.Open(true)
+	if err != nil {
+		t.Fatalf("Open failed: %v", err)
+	}
+
+	ctx, ctxC := context.WithCancel(context.Background())
+	defer ctxC()
+
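+	// makeMachine creates a machine with the given provider ID in its own short-lived session.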
+	makeMachine := func(providerID string) {
+		ctxS, ctxC := context.WithCancel(ctx)
+		defer ctxC()
+		session, err := conn.StartSession(ctxS)
+		if err != nil {
+			t.Fatalf("Starting session failed: %v", err)
+		}
+		err = session.Transact(ctx, func(q *model.Queries) error {
+			machine, err := q.NewMachine(ctx)
+			if err != nil {
+				return err
+			}
+			return q.MachineAddProvided(ctx, model.MachineAddProvidedParams{
+				MachineID:  machine.MachineID,
+				Provider:   model.ProviderEquinix,
+				ProviderID: providerID,
+			})
+		})
+		if err != nil {
+			t.Fatalf("Creating machine failed: %v", err)
+		}
+	}
+	// Make six machines for testing.
+	for i := 0; i < 6; i++ {
+		makeMachine(fmt.Sprintf("test%d", i))
+	}
+
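+	// workStarted gates the workers: each claimed work item blocks until the
+	// test lets it proceed. workDone reports which worker took which machine.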
+	workStarted := make(chan struct{})
+	workDone := make(chan struct {
+		machine  uuid.UUID
+		workerID int
+	})
+
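+	// workOnce claims one machine ready for agent start, reports the claim on
+	// workDone, and marks the machine's agent as started.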
+	workOnce := func(ctx context.Context, workerID int, session *Session) error {
+		work, err := session.Work(ctx, model.ProcessShepherdInstall, func(q *model.Queries) ([]uuid.UUID, error) {
+			machines, err := q.GetMachinesForAgentStart(ctx, 1)
+			if err != nil {
+				return nil, err
+			}
+			if len(machines) < 1 {
+				return nil, ErrNothingToDo
+			}
+			return []uuid.UUID{machines[0].MachineID}, nil
+		})
+
+		if err != nil {
+			return err
+		}
+		defer work.Cancel(ctx)
+
+		select {
+		case <-workStarted:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
+		select {
+		case workDone <- struct {
+			machine  uuid.UUID
+			workerID int
+		}{
+			machine:  work.Machine,
+			workerID: workerID,
+		}:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
+		return work.Finish(ctx, func(q *model.Queries) error {
+			return q.MachineSetAgentStarted(ctx, model.MachineSetAgentStartedParams{
+				MachineID:      work.Machine,
+				AgentStartedAt: time.Now(),
+				AgentPublicKey: []byte("fakefakefake"),
+			})
+		})
+	}
+
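+	// worker runs workOnce in a loop within its own session until the context is canceled.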
+	worker := func(workerID int) {
+		ctxS, ctxC := context.WithCancel(ctx)
+		defer ctxC()
+		session, err := conn.StartSession(ctxS)
+		if err != nil {
+			t.Fatalf("Starting session failed: %v", err)
+		}
+		for {
+			err := workOnce(ctxS, workerID, session)
+			if err != nil {
+				if errors.Is(err, ctxS.Err()) {
+					return
+				}
+				t.Fatalf("worker failed: %v", err)
+			}
+		}
+	}
+	// Start three workers.
+	for i := 0; i < 3; i++ {
+		go worker(i)
+	}
+
+	// Make sure all three workers have each claimed a machine before any finish.
+	for i := 0; i < 3; i++ {
+		workStarted <- struct{}{}
+	}
+
+	// Allow all workers to continue running from now on.
+	close(workStarted)
+
+	// Expect six machines to have been handled in parallel by three workers.
+	seenWorkers := make(map[int]bool)
+	seenMachines := make(map[string]bool)
+	for i := 0; i < 6; i++ {
+		res := <-workDone
+		seenWorkers[res.workerID] = true
+		seenMachines[res.machine.String()] = true
+	}
+
+	if want, got := 3, len(seenWorkers); want != got {
+		t.Errorf("Expected %d workers, got %d", want, got)
+	}
+	if want, got := 6, len(seenMachines); want != got {
+		t.Errorf("Expected %d machines, got %d", want, got)
+	}
+}