c/bmaas/bmdb: rework Work API into working form

The Work API within the BMDB library wasn't quite working for purely
BMDB-directed control loops. We fix that by turning Work into a
three-phase process of retrieval, side-effect-causing work and
committing.

We also add a test that exercises this functionality alongside the Agent
tags and installation retrieval queries.

Change-Id: I43af0c02af034e402dadb3e3a1fd10e5a8fe680a
Reviewed-on: https://review.monogon.dev/c/monogon/+/998
Tested-by: Jenkins CI
Reviewed-by: Mateusz Zalega <mateusz@monogon.tech>
diff --git a/cloud/bmaas/bmdb/sessions.go b/cloud/bmaas/bmdb/sessions.go
index 6b19171..ccb1510 100644
--- a/cloud/bmaas/bmdb/sessions.go
+++ b/cloud/bmaas/bmdb/sessions.go
@@ -198,7 +198,12 @@
 		}
 		// Success. Keep going.
 		deadline = time.Now().Add(s.interval)
-		time.Sleep(s.interval / 2)
+		select {
+		case <-ctx.Done():
+			// Do nothing, next loop iteration will exit.
+		case <-time.After(s.interval / 2):
+			// Do nothing, next loop iteration will heartbeat.
+		}
 	}
 }
 
@@ -226,40 +231,137 @@
 	})
 }
 
-// Work runs a given function as a work item with a given process name against
-// some identified machine. Not more than one process of a given name can run
-// against a machine concurrently.
-//
-// Most impure (meaning with side effects outside the database itself) BMDB
-// transactions should be run this way.
-func (s *Session) Work(ctx context.Context, machine uuid.UUID, process model.Process, fn func() error) error {
-	err := model.New(s.connection.db).StartWork(ctx, model.StartWorkParams{
-		MachineID: machine,
-		SessionID: s.UUID,
-		Process:   process,
-	})
-	if err != nil {
-		var perr *pq.Error
-		if errors.As(err, &perr) && perr.Code == "23505" {
-			return ErrWorkConflict
-		}
-		return fmt.Errorf("could not start work: %w", err)
-	}
-	klog.Infof("Started work: %q on machine %s, session %s", process, machine, s.UUID)
+var (
+	ErrNothingToDo = errors.New("nothing to do")
+	// postgresUniqueViolation is the error code returned by the lib/pq driver
+	// when a mutation cannot be performed due to a UNIQUE constraint being
+	// violated as a result of the query.
+	postgresUniqueViolation = pq.ErrorCode("23505")
+)
 
-	defer func() {
-		err := model.New(s.connection.db).FinishWork(s.ctx, model.FinishWorkParams{
-			MachineID: machine,
+// Work starts work on a machine. Full work execution is performed in three
+// phases:
+//
+//  1. Retrieval phase. This is performed by 'fn' given to this function.
+//     The retrieval function must return zero or more machines that some work
+//     should be performed on per the BMDB. The first returned machine will be
+//     locked for work under the given process and made available in the Work
+//     structure returned by this call. The function may be called multiple times,
+//     as it's run within a CockroachDB transaction which may be retried an
+//     arbitrary number of times. Thus, it should be side-effect free, ideally only
+//     performing read queries to the database.
+//  2. Work phase. This is performed by user code while holding on to the Work
+//     structure instance.
+//  3. Commit phase. This is performed by the function passed to Work.Finish. See
+//     that method's documentation for more details.
+//
+// Important: after retrieving Work successfully, either Finish or Cancel must be
+// called, otherwise the machine will be locked until the parent session expires
+// or is closed! It's safe and recommended to `defer work.Cancel(ctx)` after
+// calling Work(), as Cancel is a no-op once the work has been finished.
+//
+// If no machine is eligible for work, ErrNothingToDo should be returned by the
+// retrieval function, and the same error (wrapped) will be returned by Work. In
+// case the retrieval function returns no machines and no error, that error will
+// also be returned.
+//
+// The returned Work object is _not_ goroutine safe.
+func (s *Session) Work(ctx context.Context, process model.Process, fn func(q *model.Queries) ([]uuid.UUID, error)) (*Work, error) {
+	var mid *uuid.UUID
+	err := s.Transact(ctx, func(q *model.Queries) error {
+		mids, err := fn(q)
+		if err != nil {
+			return fmt.Errorf("could not retrieve machines for work: %w", err)
+		}
+		if len(mids) < 1 {
+			return ErrNothingToDo
+		}
+		mid = &mids[0]
+		err = q.StartWork(ctx, model.StartWorkParams{
+			MachineID: mids[0],
 			SessionID: s.UUID,
 			Process:   process,
 		})
-		klog.Errorf("Finished work: %q on machine %s, session %s", process, machine, s.UUID)
-		if err != nil && !errors.Is(err, s.ctx.Err()) {
-			klog.Errorf("Failed to finish work: %v", err)
-			klog.Errorf("Closing session out of an abundance of caution")
-			s.ctxC()
+		if err != nil {
+			var perr *pq.Error
+			if errors.As(err, &perr) && perr.Code == postgresUniqueViolation {
+				return ErrWorkConflict
+			}
+			return fmt.Errorf("could not start work on %q: %w", mids[0], err)
 		}
-	}()
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	klog.Infof("Started work %q on machine %q (sess %q)", process, *mid, s.UUID)
+	return &Work{
+		Machine: *mid,
+		s:       s,
+		process: process,
+	}, nil
+}
 
-	return fn()
+// Work being performed on a machine.
+type Work struct {
+	// Machine that this work is being performed on, as retrieved by the retrieval
+	// function passed to the Work method.
+	Machine uuid.UUID
+	// s is the parent session.
+	s *Session
+	// done marks that this work has already been canceled or finished.
+	done bool
+	// process that this work performs.
+	process model.Process
+}
+
+// Cancel the Work started on a machine. If the work has already been finished
+// or canceled, this is a no-op. In case of error, a log line will be emitted.
+func (w *Work) Cancel(ctx context.Context) {
+	if w.done {
+		return
+	}
+	w.done = true
+
+	klog.Infof("Canceling work %q on machine %q (sess %q)", w.process, w.Machine, w.s.UUID)
+	// Eat error and log. There's nothing we can do if this fails, and if it does, it's
+	// probably because our connectivity to the BMDB has failed. If so, our session
+	// will be invalidated soon and so will the work being performed on this
+	// machine.
+	err := w.s.Transact(ctx, func(q *model.Queries) error {
+		return q.FinishWork(ctx, model.FinishWorkParams{
+			MachineID: w.Machine,
+			SessionID: w.s.UUID,
+			Process:   w.process,
+		})
+	})
+	if err != nil {
+		klog.Errorf("Failed to cancel work %q on %q (sess %q): %v", w.process, w.Machine, w.s.UUID, err)
+	}
+}
+
+// Finish work by executing a commit function 'fn' and releasing the machine
+// from the work performed. The function given should apply tags to the
+// processed machine in a way that causes it to not be eligible for retrieval
+// again. As with the retriever function, the commit function might be called an
+// arbitrary number of times as part of CockroachDB transaction retries.
+//
+// This may be called only once.
+func (w *Work) Finish(ctx context.Context, fn func(q *model.Queries) error) error {
+	if w.done {
+		return fmt.Errorf("already finished")
+	}
+	w.done = true
+	klog.Infof("Finishing work %q on machine %q (sess %q)", w.process, w.Machine, w.s.UUID)
+	return w.s.Transact(ctx, func(q *model.Queries) error {
+		err := q.FinishWork(ctx, model.FinishWorkParams{
+			MachineID: w.Machine,
+			SessionID: w.s.UUID,
+			Process:   w.process,
+		})
+		if err != nil {
+			return err
+		}
+		return fn(q)
+	})
 }