cloud: split shepherd up
Change-Id: I8e386d9eaaf17543743e1e8a37a8d71426910d59
Reviewed-on: https://review.monogon.dev/c/monogon/+/2213
Reviewed-by: Serge Bazanski <serge@monogon.tech>
Tested-by: Jenkins CI
diff --git a/cloud/shepherd/manager/BUILD.bazel b/cloud/shepherd/manager/BUILD.bazel
new file mode 100644
index 0000000..4119ff7
--- /dev/null
+++ b/cloud/shepherd/manager/BUILD.bazel
@@ -0,0 +1,54 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "manager",
+ srcs = [
+ "control_loop.go",
+ "fake_ssh_client.go",
+ "initializer.go",
+ "manager.go",
+ "provisioner.go",
+ "recoverer.go",
+ "ssh_client.go",
+ "ssh_key_signer.go",
+ ],
+ importpath = "source.monogon.dev/cloud/shepherd/manager",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//cloud/agent/api",
+ "//cloud/bmaas/bmdb",
+ "//cloud/bmaas/bmdb/metrics",
+ "//cloud/bmaas/bmdb/model",
+ "//cloud/shepherd",
+ "//go/mflags",
+ "@com_github_google_uuid//:uuid",
+ "@com_github_pkg_sftp//:sftp",
+ "@io_k8s_klog_v2//:klog",
+ "@org_golang_google_protobuf//proto",
+ "@org_golang_x_crypto//ssh",
+ "@org_golang_x_sync//errgroup",
+ "@org_golang_x_time//rate",
+ ],
+)
+
+go_test(
+ name = "manager_test",
+ srcs = [
+ "initializer_test.go",
+ "provider_test.go",
+ "provisioner_test.go",
+ ],
+ data = [
+ "@cockroach",
+ ],
+ embed = [":manager"],
+ deps = [
+ "//cloud/bmaas/bmdb",
+ "//cloud/bmaas/bmdb/model",
+ "//cloud/lib/component",
+ "//cloud/shepherd",
+ "@com_github_google_uuid//:uuid",
+ "@io_k8s_klog_v2//:klog",
+ "@org_golang_x_time//rate",
+ ],
+)
diff --git a/cloud/shepherd/manager/README.md b/cloud/shepherd/manager/README.md
new file mode 100644
index 0000000..d5a17c3
--- /dev/null
+++ b/cloud/shepherd/manager/README.md
@@ -0,0 +1,54 @@
+Equinix Shepherd
+===
+
+Manages Equinix machines in sync with BMDB contents. Made up of two components:
+
+Provisioner
+---
+
+Brings up machines from hardware reservations and populates BMDB with new Provided machines.
+
+Initializer
+---
+
+Starts the Agent over SSH (wherever necessary per the BMDB) and reports success into the BMDB.
+
+
+Running
+===
+
+Unit Tests
+---
+
+The Shepherd has some basic smoke tests which run against a Fakequinix.
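+They can be run with Bazel, for example:
+
+```
+$ bazel test //cloud/shepherd/manager:manager_test
+```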
+
+Manual Testing
+---
+
+If you have Equinix credentials, you can run:
+
+```
+$ bazel build //cloud/shepherd/provider/equinix
+$ bazel build //cloud/shepherd/manager/test_agent
+$ bazel-bin/cloud/shepherd/provider/equinix/equinix_/equinix \
+ -bmdb_eat_my_data \
+ -equinix_project_id FIXME \
+ -equinix_api_username FIXME \
+ -equinix_api_key FIXME \
+ -agent_executable_path bazel-bin/cloud/shepherd/manager/test_agent/test_agent_/test_agent \
+ -agent_endpoint example.com \
+ -equinix_ssh_key_label $USER-FIXME \
+ -equinix_device_prefix $USER-FIXME- \
+ -provisioner_assimilate -provisioner_max_machines 10
+```
+
+Replace $USER-FIXME with `<your username>-test` or some other unique name/prefix.
+
+This will start a single instance of the provisioner accompanied by a single instance of the initializer.
+
+A persistent SSH key will be created in your current working directory.
+
+Prod Deployment
+---
+
+TODO(q3k): split server binary into separate provisioner/initializer for initializer scalability, as that's the main bottleneck.
\ No newline at end of file
diff --git a/cloud/shepherd/manager/control_loop.go b/cloud/shepherd/manager/control_loop.go
new file mode 100644
index 0000000..e1fdd1d
--- /dev/null
+++ b/cloud/shepherd/manager/control_loop.go
@@ -0,0 +1,221 @@
+package manager
+
+import (
+ "context"
+ "errors"
+ "flag"
+ "fmt"
+ "time"
+
+ "github.com/google/uuid"
+ "golang.org/x/sync/errgroup"
+ "golang.org/x/time/rate"
+ "k8s.io/klog/v2"
+
+ "source.monogon.dev/cloud/bmaas/bmdb"
+ "source.monogon.dev/cloud/bmaas/bmdb/metrics"
+ "source.monogon.dev/cloud/bmaas/bmdb/model"
+ "source.monogon.dev/go/mflags"
+)
+
+// task describes a single server currently being processed by a control loop.
+type task struct {
+ // machine is the machine data (including provider and provider ID) retrieved
+ // from the BMDB.
+ machine *model.MachineProvided
+ // work is a machine lock facilitated by BMDB that prevents machines from
+ // being processed by multiple workers at the same time.
+ work *bmdb.Work
+ // backoff is configured from processInfo.defaultBackoff but can be overridden by
+ // processMachine to set a different backoff policy for specific failure modes.
+ backoff bmdb.Backoff
+}
+
+// controlLoop is implemented by any component which should act as a BMDB-based
+// control loop. Implementing these methods allows the given component to be
+// started using RunControlLoop.
+type controlLoop interface {
+ getProcessInfo() processInfo
+
+ // getMachines must return the list of machines ready to be processed by the
+ // control loop for a given control loop implementation.
+ getMachines(ctx context.Context, q *model.Queries, limit int32) ([]model.MachineProvided, error)
+ // processMachine will be called within the scope of an active task/BMDB work by
+ // the control loop logic.
+ processMachine(ctx context.Context, t *task) error
+
+ // getControlLoopConfig is implemented by ControlLoopConfig which should be
+ // embedded by the control loop component. If not embedded, this method will have
+ // to be implemented, too.
+ getControlLoopConfig() *ControlLoopConfig
+}
+
+type processInfo struct {
+ process model.Process
+ processor metrics.Processor
+ defaultBackoff bmdb.Backoff
+}
+
+// ControlLoopConfig should be embedded in every component which acts as a
+// control loop. RegisterFlags should be called by the component whenever it is
+// registering its own flags. Check should be called whenever the component is
+// instantiated, after RegisterFlags has been called.
+type ControlLoopConfig struct {
+ // DBQueryLimiter limits the rate at which the BMDB is queried for machines
+ // ready for processing by the control loop. Must be set.
+ DBQueryLimiter *rate.Limiter
+
+ // Parallelism is how many instances of the control loop will be allowed to run
+ // in parallel against the BMDB. This speeds up processing significantly, as a
+ // single loop instance handles at most one machine at a time.
+ //
+ // If not set (i.e. 0), it defaults to 1. A good starting value for production
+ // deployments is 10 or so.
+ Parallelism int
+}
+
+func (c *ControlLoopConfig) getControlLoopConfig() *ControlLoopConfig {
+ return c
+}
+
+// RegisterFlags should be called on this configuration whenever the embedding
+// component/configuration is registering its own flags. The prefix should be the
+// name of the component.
+func (c *ControlLoopConfig) RegisterFlags(prefix string) {
+ mflags.Limiter(&c.DBQueryLimiter, prefix+"_db_query_rate", "250ms,8", "Rate limiting for BMDB queries")
+ flag.IntVar(&c.Parallelism, prefix+"_loop_parallelism", 1, "How many control loop instances to run in parallel, i.e. how many machines to attempt to process at once")
+}
+
+// Check should be called after RegisterFlags but before the control loop is run.
+// If an error is returned, the control loop cannot start.
+func (c *ControlLoopConfig) Check() error {
+ if c.DBQueryLimiter == nil {
+ return fmt.Errorf("DBQueryLimiter must be configured")
+ }
+ if c.Parallelism == 0 {
+ c.Parallelism = 1
+ }
+ return nil
+}
+
+// RunControlLoop runs the given controlLoop implementation against the BMDB. The
+// loop will be run with the parallelism and rate configured by the
+// ControlLoopConfig embedded or otherwise returned by the controlLoop.
+func RunControlLoop(ctx context.Context, conn *bmdb.Connection, loop controlLoop) error {
+ clr := &controlLoopRunner{
+ loop: loop,
+ config: loop.getControlLoopConfig(),
+ }
+ return clr.run(ctx, conn)
+}
+
+// controlLoopRunner is a configured control loop with an underlying control loop
+// implementation.
+type controlLoopRunner struct {
+ config *ControlLoopConfig
+ loop controlLoop
+}
+
+// run the control loop(s) (as many as configured by Parallelism), blocking the
+// current goroutine until the given context expires and all loop instances quit.
+func (r *controlLoopRunner) run(ctx context.Context, conn *bmdb.Connection) error {
+ pinfo := r.loop.getProcessInfo()
+
+ eg := errgroup.Group{}
+ for j := 0; j < r.config.Parallelism; j += 1 {
+ eg.Go(func() error {
+ return r.runOne(ctx, conn, &pinfo)
+ })
+ }
+ return eg.Wait()
+}
+
+// run the control loop blocking the current goroutine until the given context
+// expires.
+func (r *controlLoopRunner) runOne(ctx context.Context, conn *bmdb.Connection, pinfo *processInfo) error {
+ var err error
+
+ // Maintain a BMDB session as long as possible.
+ var sess *bmdb.Session
+ for {
+ if sess == nil {
+ sess, err = conn.StartSession(ctx, bmdb.SessionOption{Processor: pinfo.processor})
+ if err != nil {
+ return fmt.Errorf("could not start BMDB session: %w", err)
+ }
+ }
+ // Inside that session, run the main logic.
+ err := r.runInSession(ctx, sess, pinfo)
+
+ switch {
+ case err == nil:
+ case errors.Is(err, ctx.Err()):
+ return err
+ case errors.Is(err, bmdb.ErrSessionExpired):
+ klog.Errorf("Session expired, restarting...")
+ sess = nil
+ time.Sleep(time.Second)
+ case err != nil:
+ klog.Errorf("Processing failed: %v", err)
+ // TODO(q3k): close session
+ time.Sleep(time.Second)
+ }
+ }
+}
+
+// runInSession executes one iteration of the control loop within a BMDB session.
+// A single machine is sourced from the BMDB (if any is available) and handed to
+// the control loop implementation for processing.
+func (r *controlLoopRunner) runInSession(ctx context.Context, sess *bmdb.Session, pinfo *processInfo) error {
+ t, err := r.source(ctx, sess, pinfo)
+ if err != nil {
+ return fmt.Errorf("could not source machine: %w", err)
+ }
+ if t == nil {
+ return nil
+ }
+ defer t.work.Cancel(ctx)
+
+ if err := r.loop.processMachine(ctx, t); err != nil {
+ klog.Errorf("Failed to process machine %s: %v", t.machine.MachineID, err)
+ err = t.work.Fail(ctx, &t.backoff, fmt.Sprintf("failed to process: %v", err))
+ return err
+ }
+ return nil
+}
+
+// source returns a BMDB-locked machine ready for processing by the control
+// loop, locked by a work item. If both the task and error are nil, there are
+// currently no machines that need processing. The returned work item in task
+// _must_ be canceled or finished by the caller.
+func (r *controlLoopRunner) source(ctx context.Context, sess *bmdb.Session, pinfo *processInfo) (*task, error) {
+ r.config.DBQueryLimiter.Wait(ctx)
+
+ var machine *model.MachineProvided
+ work, err := sess.Work(ctx, pinfo.process, func(q *model.Queries) ([]uuid.UUID, error) {
+ machines, err := r.loop.getMachines(ctx, q, 1)
+ if err != nil {
+ return nil, err
+ }
+ if len(machines) < 1 {
+ return nil, bmdb.ErrNothingToDo
+ }
+ machine = &machines[0]
+ return []uuid.UUID{machines[0].MachineID}, nil
+ })
+
+ if errors.Is(err, bmdb.ErrNothingToDo) {
+ return nil, nil
+ }
+
+ if err != nil {
+ return nil, fmt.Errorf("while querying BMDB agent candidates: %w", err)
+ }
+
+ return &task{
+ machine: machine,
+ work: work,
+ backoff: pinfo.defaultBackoff,
+ }, nil
+}
diff --git a/cloud/shepherd/manager/fake_ssh_client.go b/cloud/shepherd/manager/fake_ssh_client.go
new file mode 100644
index 0000000..1d9d371
--- /dev/null
+++ b/cloud/shepherd/manager/fake_ssh_client.go
@@ -0,0 +1,58 @@
+package manager
+
+import (
+ "context"
+ "crypto/ed25519"
+ "crypto/rand"
+ "fmt"
+ "time"
+
+ "google.golang.org/protobuf/proto"
+
+ apb "source.monogon.dev/cloud/agent/api"
+)
+
+// FakeSSHClient is an SSHClient that pretends to start an agent, but in reality
+// just responds with what an agent would respond on every execution attempt.
+type FakeSSHClient struct{}
+
+type fakeSSHConnection struct{}
+
+func (f *FakeSSHClient) Dial(ctx context.Context, address string, timeout time.Duration) (SSHConnection, error) {
+ return &fakeSSHConnection{}, nil
+}
+
+func (f *fakeSSHConnection) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
+ var aim apb.TakeoverInit
+ if err := proto.Unmarshal(stdin, &aim); err != nil {
+ return nil, nil, fmt.Errorf("while unmarshaling TakeoverInit message: %v", err)
+ }
+
+ // Agent should send back apb.TakeoverResponse on its standard output.
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ return nil, nil, fmt.Errorf("while generating agent public key: %v", err)
+ }
+ arsp := apb.TakeoverResponse{
+ Result: &apb.TakeoverResponse_Success{Success: &apb.TakeoverSuccess{
+ InitMessage: &aim,
+ Key: pub,
+ }},
+ }
+ arspb, err := proto.Marshal(&arsp)
+ if err != nil {
+ return nil, nil, fmt.Errorf("while marshaling TakeoverResponse message: %v", err)
+ }
+ return arspb, nil, nil
+}
+
+func (f *fakeSSHConnection) Upload(ctx context.Context, targetPath string, data []byte) error {
+ if targetPath != "/fake/path" {
+ return fmt.Errorf("unexpected target path in test")
+ }
+ return nil
+}
+
+func (f *fakeSSHConnection) Close() error {
+ return nil
+}
diff --git a/cloud/shepherd/manager/initializer.go b/cloud/shepherd/manager/initializer.go
new file mode 100644
index 0000000..5abbc68
--- /dev/null
+++ b/cloud/shepherd/manager/initializer.go
@@ -0,0 +1,271 @@
+package manager
+
+import (
+ "context"
+ "crypto/ed25519"
+ "crypto/x509"
+ "encoding/hex"
+ "encoding/pem"
+ "flag"
+ "fmt"
+ "net"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "google.golang.org/protobuf/proto"
+ "k8s.io/klog/v2"
+
+ apb "source.monogon.dev/cloud/agent/api"
+
+ "source.monogon.dev/cloud/bmaas/bmdb"
+ "source.monogon.dev/cloud/bmaas/bmdb/metrics"
+ "source.monogon.dev/cloud/bmaas/bmdb/model"
+ "source.monogon.dev/cloud/shepherd"
+)
+
+// InitializerConfig configures how the Initializer will deploy Agents on
+// machines. In CLI scenarios, this should be populated from flags via
+// RegisterFlags.
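+//
+// A minimal non-CLI population sketch (the values below are placeholders, not
+// recommended settings):
+//
+//	ic := InitializerConfig{
+//		ControlLoopConfig: ControlLoopConfig{
+//			DBQueryLimiter: rate.NewLimiter(rate.Every(time.Second), 10),
+//		},
+//		Executable:        agentBytes,
+//		TargetPath:        "/root/agent",
+//		Endpoint:          "bmaas.example.com:443",
+//		SSHConnectTimeout: 2 * time.Second,
+//		SSHExecTimeout:    60 * time.Second,
+//	}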
+type InitializerConfig struct {
+ ControlLoopConfig
+
+ // Executable is the contents of the agent binary uploaded to and run on
+ // the provisioned servers. Must be set.
+ Executable []byte
+
+ // TargetPath is a filesystem destination path used while uploading the BMaaS
+ // agent executable to hosts as part of the initialization process. Must be set.
+ TargetPath string
+
+ // Endpoint is the address Agent will use to contact the BMaaS
+ // infrastructure. Must be set.
+ Endpoint string
+
+ // EndpointCACertificate is an optional DER-encoded (but not PEM-armored) X509
+ // certificate used to populate the trusted CA store of the agent. It should be
+ // set to the CA certificate of the endpoint if not using a system-trusted CA
+ // certificate.
+ EndpointCACertificate []byte
+
+ // SSHConnectTimeout is the amount of time set aside for establishing the SSH
+ // connection to a machine. Upon timeout, the iteration is declared a failure.
+ // Must be set.
+ SSHConnectTimeout time.Duration
+ // SSHExecTimeout is the amount of time set aside for executing the agent and
+ // getting its output once the SSH connection has been established. Upon timeout,
+ // the iteration is declared a failure. Must be set.
+ SSHExecTimeout time.Duration
+}
+
+func (ic *InitializerConfig) RegisterFlags() {
+ ic.ControlLoopConfig.RegisterFlags("initializer")
+
+ flag.Func("agent_executable_path", "Local filesystem path of agent binary to be uploaded", func(val string) error {
+ if val == "" {
+ return nil
+ }
+ data, err := os.ReadFile(val)
+ if err != nil {
+ return fmt.Errorf("could not read: %w", err)
+ }
+ ic.Executable = data
+ return nil
+ })
+ flag.StringVar(&ic.TargetPath, "agent_target_path", "/root/agent", "Filesystem path where the agent will be uploaded to and run from")
+ flag.StringVar(&ic.Endpoint, "agent_endpoint", "", "Address of BMDB Server to which the agent will attempt to connect")
+ flag.Func("agent_endpoint_ca_certificate_path", "Path to PEM X509 CA certificate that the agent endpoint is serving with. If not set, the agent will attempt to use system CA certificates to authenticate the endpoint.", func(val string) error {
+ if val == "" {
+ return nil
+ }
+ data, err := os.ReadFile(val)
+ if err != nil {
+ return fmt.Errorf("could not read: %w", err)
+ }
+ block, _ := pem.Decode(data)
+ if block.Type != "CERTIFICATE" {
+ return fmt.Errorf("not a certificate")
+ }
+ _, err = x509.ParseCertificate(block.Bytes)
+ if err != nil {
+ return fmt.Errorf("invalid certificate: %w", err)
+ }
+ ic.EndpointCACertificate = block.Bytes
+ return nil
+ })
+ flag.DurationVar(&ic.SSHConnectTimeout, "agent_ssh_connect_timeout", 2*time.Second, "Timeout for connecting over SSH to a machine")
+ flag.DurationVar(&ic.SSHExecTimeout, "agent_ssh_exec_timeout", 60*time.Second, "Timeout for executing the agent over SSH once connected to a machine")
+}
+
+func (ic *InitializerConfig) Check() error {
+ if err := ic.ControlLoopConfig.Check(); err != nil {
+ return err
+ }
+
+ if len(ic.Executable) == 0 {
+ return fmt.Errorf("agent executable not configured")
+ }
+ if ic.TargetPath == "" {
+ return fmt.Errorf("agent target path must be set")
+ }
+ if ic.Endpoint == "" {
+ return fmt.Errorf("agent endpoint must be set")
+ }
+ if ic.SSHConnectTimeout == 0 {
+ return fmt.Errorf("agent SSH connection timeout must be set")
+ }
+ if ic.SSHExecTimeout == 0 {
+ return fmt.Errorf("agent SSH execution timeout must be set")
+ }
+
+ return nil
+}
+
+// The Initializer starts the agent on machines that aren't yet running it.
+type Initializer struct {
+ InitializerConfig
+
+ sshClient SSHClient
+ p shepherd.Provider
+}
+
+// NewInitializer creates an Initializer instance, checking the
+// InitializerConfig for errors.
+func NewInitializer(p shepherd.Provider, sshClient SSHClient, ic InitializerConfig) (*Initializer, error) {
+ if err := ic.Check(); err != nil {
+ return nil, err
+ }
+
+ return &Initializer{
+ InitializerConfig: ic,
+
+ p: p,
+ sshClient: sshClient,
+ }, nil
+}
+
+func (i *Initializer) getProcessInfo() processInfo {
+ return processInfo{
+ process: model.ProcessShepherdAgentStart,
+ defaultBackoff: bmdb.Backoff{
+ Initial: 5 * time.Minute,
+ Maximum: 4 * time.Hour,
+ Exponent: 1.2,
+ },
+ processor: metrics.ProcessorShepherdInitializer,
+ }
+}
+
+func (i *Initializer) getMachines(ctx context.Context, q *model.Queries, limit int32) ([]model.MachineProvided, error) {
+ return q.GetMachinesForAgentStart(ctx, model.GetMachinesForAgentStartParams{
+ Limit: limit,
+ Provider: i.p.Type(),
+ })
+}
+
+func (i *Initializer) processMachine(ctx context.Context, t *task) error {
+ machine, err := i.p.GetMachine(ctx, shepherd.ProviderID(t.machine.ProviderID))
+ if err != nil {
+ return fmt.Errorf("while fetching machine %q: %v", t.machine.ProviderID, err)
+ }
+
+ // Start the agent.
+ klog.Infof("Starting agent on machine (ID: %s, PID %s)", t.machine.MachineID, t.machine.ProviderID)
+ apk, err := i.startAgent(ctx, machine, t.machine.MachineID)
+ if err != nil {
+ return fmt.Errorf("while starting the agent: %w", err)
+ }
+
+ // Agent startup succeeded. Set the appropriate BMDB tag, and release the
+ // lock.
+ klog.Infof("Setting AgentStarted (ID: %s, PID: %s, Agent public key: %s).", t.machine.MachineID, t.machine.ProviderID, hex.EncodeToString(apk))
+ err = t.work.Finish(ctx, func(q *model.Queries) error {
+ return q.MachineSetAgentStarted(ctx, model.MachineSetAgentStartedParams{
+ MachineID: t.machine.MachineID,
+ AgentStartedAt: time.Now(),
+ AgentPublicKey: apk,
+ })
+ })
+ if err != nil {
+ return fmt.Errorf("while setting AgentStarted tag: %w", err)
+ }
+ return nil
+}
+
+// startAgent runs the agent executable on the target machine m, returning the
+// agent's public key on success.
+func (i *Initializer) startAgent(ctx context.Context, m shepherd.Machine, mid uuid.UUID) ([]byte, error) {
+ // Provide a bound on execution time in case we get stuck after the SSH
+ // connection is established.
+ sctx, sctxC := context.WithTimeout(ctx, i.SSHExecTimeout)
+ defer sctxC()
+
+ // Use the machine's IP address
+ ni := m.Addr()
+ if !ni.IsValid() {
+ return nil, fmt.Errorf("machine (machine ID: %s) has no available addresses", mid)
+ }
+
+ addr := net.JoinHostPort(ni.String(), "22")
+ klog.V(1).Infof("Dialing machine (machine ID: %s, addr: %s).", mid, addr)
+
+ conn, err := i.sshClient.Dial(sctx, addr, i.SSHConnectTimeout)
+ if err != nil {
+ return nil, fmt.Errorf("while dialing the machine: %w", err)
+ }
+ defer conn.Close()
+
+ // Upload the agent executable.
+
+ klog.Infof("Uploading the agent executable (machine ID: %s, addr: %s).", mid, addr)
+ if err := conn.Upload(sctx, i.TargetPath, i.Executable); err != nil {
+ return nil, fmt.Errorf("while uploading agent executable: %w", err)
+ }
+ klog.V(1).Infof("Upload successful (machine ID: %s, addr: %s).", mid, addr)
+
+ // The initialization protobuf message will be sent to the agent on its
+ // standard input.
+ imsg := apb.TakeoverInit{
+ MachineId: mid.String(),
+ BmaasEndpoint: i.Endpoint,
+ CaCertificate: i.EndpointCACertificate,
+ }
+ imsgb, err := proto.Marshal(&imsg)
+ if err != nil {
+ return nil, fmt.Errorf("while marshaling agent message: %w", err)
+ }
+
+ // Start the agent and wait for the agent's output to arrive.
+ klog.V(1).Infof("Starting the agent executable at path %q (machine ID: %s).", i.TargetPath, mid)
+ stdout, stderr, err := conn.Execute(sctx, i.TargetPath, imsgb)
+ stderrStr := strings.TrimSpace(string(stderr))
+ if stderrStr != "" {
+ klog.Warningf("Agent stderr: %q", stderrStr)
+ }
+ if err != nil {
+ return nil, fmt.Errorf("while starting the agent executable: %w", err)
+ }
+
+ var arsp apb.TakeoverResponse
+ if err := proto.Unmarshal(stdout, &arsp); err != nil {
+ return nil, fmt.Errorf("agent reply couldn't be unmarshaled: %w", err)
+ }
+ var successResp *apb.TakeoverSuccess
+ switch r := arsp.Result.(type) {
+ case *apb.TakeoverResponse_Error:
+ return nil, fmt.Errorf("agent returned error: %v", r.Error.Message)
+ case *apb.TakeoverResponse_Success:
+ successResp = r.Success
+ default:
+ return nil, fmt.Errorf("agent returned unknown result of type %T", arsp.Result)
+ }
+ if !proto.Equal(&imsg, successResp.InitMessage) {
+ return nil, fmt.Errorf("agent did not send back the init message.")
+ }
+ if len(successResp.Key) != ed25519.PublicKeySize {
+ return nil, fmt.Errorf("agent key length mismatch.")
+ }
+ klog.Infof("Started the agent (machine ID: %s, key: %s).", mid, hex.EncodeToString(successResp.Key))
+ return successResp.Key, nil
+}
diff --git a/cloud/shepherd/manager/initializer_test.go b/cloud/shepherd/manager/initializer_test.go
new file mode 100644
index 0000000..5ba2253
--- /dev/null
+++ b/cloud/shepherd/manager/initializer_test.go
@@ -0,0 +1,91 @@
+package manager
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "golang.org/x/time/rate"
+
+ "source.monogon.dev/cloud/bmaas/bmdb"
+ "source.monogon.dev/cloud/bmaas/bmdb/model"
+ "source.monogon.dev/cloud/lib/component"
+)
+
+// TestInitializerSmokes makes sure the Initializer doesn't go up in flames on
+// the happy path.
+func TestInitializerSmokes(t *testing.T) {
+ provider := newDummyProvider(100)
+
+ ic := InitializerConfig{
+ ControlLoopConfig: ControlLoopConfig{
+ DBQueryLimiter: rate.NewLimiter(rate.Every(time.Second), 10),
+ },
+ Executable: []byte("beep boop i'm a real program"),
+ TargetPath: "/fake/path",
+ Endpoint: "example.com:1234",
+ SSHConnectTimeout: time.Second,
+ SSHExecTimeout: time.Second,
+ }
+
+ i, err := NewInitializer(provider, provider.sshClient(), ic)
+ if err != nil {
+ t.Fatalf("Could not create Initializer: %v", err)
+ }
+
+ b := bmdb.BMDB{
+ Config: bmdb.Config{
+ Database: component.CockroachConfig{
+ InMemory: true,
+ },
+ ComponentName: "test",
+ RuntimeInfo: "test",
+ },
+ }
+ conn, err := b.Open(true)
+ if err != nil {
+ t.Fatalf("Could not create in-memory BMDB: %v", err)
+ }
+
+ ctx, ctxC := context.WithCancel(context.Background())
+ t.Cleanup(ctxC)
+
+ go RunControlLoop(ctx, conn, i)
+
+ sess, err := conn.StartSession(ctx)
+ if err != nil {
+ t.Fatalf("Failed to create BMDB session for verifiaction: %v", err)
+ }
+
+ // Create 10 provided machines for testing.
+ if _, err := provider.createDummyMachines(ctx, sess, 10); err != nil {
+ t.Fatalf("Failed to create dummy machines: %v", err)
+ }
+
+ // Wait until no machines are left needing an agent start.
+ for {
+ time.Sleep(100 * time.Millisecond)
+
+ var machines []model.MachineProvided
+ err = sess.Transact(ctx, func(q *model.Queries) error {
+ var err error
+ machines, err = q.GetMachinesForAgentStart(ctx, model.GetMachinesForAgentStartParams{
+ Limit: 100,
+ Provider: provider.Type(),
+ })
+ return err
+ })
+ if err != nil {
+ t.Fatalf("Failed to run Transaction: %v", err)
+ }
+ if len(machines) == 0 {
+ break
+ }
+ }
+
+ for _, m := range provider.machines {
+ if !m.agentStarted {
+ t.Fatalf("Initializer didn't start agent on machine %q", m.id)
+ }
+ }
+}
diff --git a/cloud/shepherd/manager/manager.go b/cloud/shepherd/manager/manager.go
new file mode 100644
index 0000000..3ae7854
--- /dev/null
+++ b/cloud/shepherd/manager/manager.go
@@ -0,0 +1,18 @@
+// Package manager, itself a part of the BMaaS project, provides an
+// implementation governing the Equinix bare metal server lifecycle according
+// to conditions set by the Bare Metal Database (BMDB).
+//
+// The implementation will attempt to provide as many machines as possible and
+// register them with BMDB. This is limited by the count of Hardware
+// Reservations available in the Equinix Metal project used. The BMaaS agent
+// will then be started on these machines as soon as they become ready.
+//
+// The implementation is provided in the form of a library whose interface is
+// exported through the Provisioner and Initializer types, each taking servers
+// through a single stage of their lifecycle.
+//
+// See the included test code for usage examples.
+//
+// The terms "device" and "machine" are used interchangeably throughout this
+// package due to differences in Equinix Metal and BMDB nomenclature.
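+//
+// A rough wiring sketch (error handling elided; the provider implementation,
+// SSH client, BMDB connection and configuration values are assumed to be set
+// up elsewhere):
+//
+//	p, _ := NewProvisioner(provider, pc)
+//	i, _ := NewInitializer(provider, sshClient, ic)
+//	go p.Run(ctx, conn)
+//	go RunControlLoop(ctx, conn, i)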
+package manager
diff --git a/cloud/shepherd/manager/provider_test.go b/cloud/shepherd/manager/provider_test.go
new file mode 100644
index 0000000..d1c6361
--- /dev/null
+++ b/cloud/shepherd/manager/provider_test.go
@@ -0,0 +1,182 @@
+package manager
+
+import (
+ "context"
+ "fmt"
+ "net/netip"
+ "time"
+
+ "github.com/google/uuid"
+ "k8s.io/klog/v2"
+
+ "source.monogon.dev/cloud/bmaas/bmdb"
+ "source.monogon.dev/cloud/bmaas/bmdb/model"
+ "source.monogon.dev/cloud/shepherd"
+)
+
+type dummyMachine struct {
+ id shepherd.ProviderID
+ addr netip.Addr
+ state shepherd.State
+ agentStarted bool
+}
+
+func (dm *dummyMachine) ID() shepherd.ProviderID {
+ return dm.id
+}
+
+func (dm *dummyMachine) Addr() netip.Addr {
+ return dm.addr
+}
+
+func (dm *dummyMachine) State() shepherd.State {
+ return dm.state
+}
+
+type dummySSHClient struct {
+ SSHClient
+ dp *dummyProvider
+}
+
+type dummySSHConnection struct {
+ SSHConnection
+ m *dummyMachine
+}
+
+func (dsc *dummySSHConnection) Execute(ctx context.Context, command string, stdin []byte) ([]byte, []byte, error) {
+ stdout, stderr, err := dsc.SSHConnection.Execute(ctx, command, stdin)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ dsc.m.agentStarted = true
+ return stdout, stderr, nil
+}
+
+func (dsc *dummySSHClient) Dial(ctx context.Context, address string, timeout time.Duration) (SSHConnection, error) {
+ conn, err := dsc.SSHClient.Dial(ctx, address, timeout)
+ if err != nil {
+ return nil, err
+ }
+
+ addrPort := netip.MustParseAddrPort(address)
+ uid, err := uuid.FromBytes(addrPort.Addr().AsSlice())
+ if err != nil {
+ return nil, err
+ }
+
+ m := dsc.dp.machines[shepherd.ProviderID(uid.String())]
+ if m == nil {
+ return nil, fmt.Errorf("failed finding machine in map")
+ }
+
+ return &dummySSHConnection{conn, m}, nil
+}
+
+func (dp *dummyProvider) sshClient() SSHClient {
+ return &dummySSHClient{
+ SSHClient: &FakeSSHClient{},
+ dp: dp,
+ }
+}
+
+func newDummyProvider(cap int) *dummyProvider {
+ return &dummyProvider{
+ capacity: cap,
+ machines: make(map[shepherd.ProviderID]*dummyMachine),
+ }
+}
+
+type dummyProvider struct {
+ capacity int
+ machines map[shepherd.ProviderID]*dummyMachine
+}
+
+func (dp *dummyProvider) createDummyMachines(ctx context.Context, session *bmdb.Session, count int) ([]shepherd.Machine, error) {
+ if len(dp.machines)+count > dp.capacity {
+ return nil, fmt.Errorf("no capacity left")
+ }
+
+ var machines []shepherd.Machine
+ for i := 0; i < count; i++ {
+ uid := uuid.Must(uuid.NewRandom())
+ m, err := dp.CreateMachine(ctx, session, shepherd.CreateMachineRequest{
+ UnusedMachine: &dummyMachine{
+ id: shepherd.ProviderID(uid.String()),
+ state: shepherd.StateKnownUsed,
+ addr: netip.AddrFrom16(uid),
+ },
+ })
+ if err != nil {
+ return nil, err
+ }
+ machines = append(machines, m)
+ }
+
+ return machines, nil
+}
+
+func (dp *dummyProvider) ListMachines(ctx context.Context) ([]shepherd.Machine, error) {
+ var machines []shepherd.Machine
+ for _, m := range dp.machines {
+ machines = append(machines, m)
+ }
+
+ unusedMachineCount := dp.capacity - len(machines)
+ for i := 0; i < unusedMachineCount; i++ {
+ uid := uuid.Must(uuid.NewRandom())
+ machines = append(machines, &dummyMachine{
+ id: shepherd.ProviderID(uid.String()),
+ state: shepherd.StateKnownUnused,
+ addr: netip.AddrFrom16(uid),
+ })
+ }
+
+ return machines, nil
+}
+
+func (dp *dummyProvider) GetMachine(ctx context.Context, id shepherd.ProviderID) (shepherd.Machine, error) {
+ for _, m := range dp.machines {
+ if m.ID() == id {
+ return m, nil
+ }
+ }
+
+ return nil, shepherd.ErrMachineNotFound
+}
+
+func (dp *dummyProvider) CreateMachine(ctx context.Context, session *bmdb.Session, request shepherd.CreateMachineRequest) (shepherd.Machine, error) {
+ dm := request.UnusedMachine.(*dummyMachine)
+
+ err := session.Transact(ctx, func(q *model.Queries) error {
+ // Create a new machine record within BMDB.
+ m, err := q.NewMachine(ctx)
+ if err != nil {
+ return fmt.Errorf("while creating a new machine record in BMDB: %w", err)
+ }
+
+ p := model.MachineAddProvidedParams{
+ MachineID: m.MachineID,
+ ProviderID: string(dm.id),
+ Provider: dp.Type(),
+ }
+ klog.Infof("Setting \"provided\" tag (ID: %s, PID: %s, Provider: %s).", p.MachineID, p.ProviderID, p.Provider)
+ if err := q.MachineAddProvided(ctx, p); err != nil {
+ return fmt.Errorf("while tagging machine active: %w", err)
+ }
+ return nil
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ dm.state = shepherd.StateKnownUsed
+ dp.machines[dm.id] = dm
+
+ return dm, nil
+}
+
+func (dp *dummyProvider) Type() model.Provider {
+ return model.ProviderNone
+}
diff --git a/cloud/shepherd/manager/provisioner.go b/cloud/shepherd/manager/provisioner.go
new file mode 100644
index 0000000..a77f241
--- /dev/null
+++ b/cloud/shepherd/manager/provisioner.go
@@ -0,0 +1,392 @@
+package manager
+
+import (
+ "context"
+ "errors"
+ "flag"
+ "fmt"
+ "net/netip"
+ "sort"
+ "time"
+
+ "github.com/google/uuid"
+ "golang.org/x/time/rate"
+ "k8s.io/klog/v2"
+
+ "source.monogon.dev/cloud/bmaas/bmdb"
+ "source.monogon.dev/cloud/bmaas/bmdb/metrics"
+ "source.monogon.dev/cloud/bmaas/bmdb/model"
+ "source.monogon.dev/cloud/shepherd"
+ "source.monogon.dev/go/mflags"
+)
+
+// Provisioner implements the server provisioning logic. Provisioning entails
+// bringing all available machines (subject to limits) into BMDB.
+type Provisioner struct {
+ ProvisionerConfig
+ p shepherd.Provider
+}
+
+// ProvisionerConfig configures the provisioning process.
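+//
+// A minimal non-CLI population sketch, mirroring the values used by the
+// package's smoke test (not recommended production settings):
+//
+//	pc := ProvisionerConfig{
+//		MaxCount:              10,
+//		ReconcileLoopLimiter:  rate.NewLimiter(rate.Every(10*time.Second), 3),
+//		DeviceCreationLimiter: rate.NewLimiter(rate.Every(time.Second), 10),
+//		ChunkSize:             4,
+//	}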
+type ProvisionerConfig struct {
+ // MaxCount is the maximum count of managed servers. No new devices will be
+ // created after reaching the limit. No attempt will be made to reduce the
+ // server count.
+ MaxCount uint
+
+ // ReconcileLoopLimiter limits the rate at which the main reconciliation loop
+ // iterates.
+ ReconcileLoopLimiter *rate.Limiter
+
+ // DeviceCreationLimiter limits the rate at which devices are created.
+ DeviceCreationLimiter *rate.Limiter
+
+ // ChunkSize is how many machines the provisioner will try to spawn in a
+ // single reconciliation loop iteration. Higher numbers allow for faster initial
+ // provisioning, but lower numbers decrease potential raciness with other systems
+ // and make sure that other parts of the reconciliation logic are run regularly.
+ //
+ // 20 is a decent starting point.
+ ChunkSize uint
+}
+
+func (pc *ProvisionerConfig) RegisterFlags() {
+ flag.UintVar(&pc.MaxCount, "provisioner_max_machines", 50, "Limit of machines that the provisioner will attempt to pull into the BMDB. Zero for no limit.")
+ mflags.Limiter(&pc.ReconcileLoopLimiter, "provisioner_reconciler_rate", "1m,1", "Rate limiting for main provisioner reconciliation loop")
+ mflags.Limiter(&pc.DeviceCreationLimiter, "provisioner_device_creation_rate", "5s,1", "Rate limiting for machine creation")
+ flag.UintVar(&pc.ChunkSize, "provisioner_reservation_chunk_size", 20, "How many machines will the provisioner attempt to create in a single reconciliation loop iteration")
+}
+
+func (pc *ProvisionerConfig) check() error {
+ // If these are unset, it's probably because someone is using us as a library.
+ // Provide error messages useful to code users instead of flag names.
+ if pc.ReconcileLoopLimiter == nil {
+ return fmt.Errorf("ReconcileLoopLimiter must be set")
+ }
+ if pc.DeviceCreationLimiter == nil {
+ return fmt.Errorf("DeviceCreationLimiter must be set")
+ }
+ if pc.ChunkSize == 0 {
+ return fmt.Errorf("ChunkSize must be set")
+ }
+ return nil
+}
+
+// NewProvisioner creates a Provisioner instance, checking the ProvisionerConfig
+// for errors.
+func NewProvisioner(p shepherd.Provider, pc ProvisionerConfig) (*Provisioner, error) {
+ if err := pc.check(); err != nil {
+ return nil, err
+ }
+
+ return &Provisioner{
+ ProvisionerConfig: pc,
+ p: p,
+ }, nil
+}
+
+// Run the provisioner blocking the current goroutine until the given context
+// expires.
+func (p *Provisioner) Run(ctx context.Context, conn *bmdb.Connection) error {
+
+ var sess *bmdb.Session
+ var err error
+ for {
+ if sess == nil {
+ sess, err = conn.StartSession(ctx, bmdb.SessionOption{Processor: metrics.ProcessorShepherdProvisioner})
+ if err != nil {
+ return fmt.Errorf("could not start BMDB session: %w", err)
+ }
+ }
+ err = p.runInSession(ctx, sess)
+
+ switch {
+ case err == nil:
+ case errors.Is(err, ctx.Err()):
+ return err
+ case errors.Is(err, bmdb.ErrSessionExpired):
+ klog.Errorf("Session expired, restarting...")
+ sess = nil
+ time.Sleep(time.Second)
+ case err != nil:
+ klog.Errorf("Processing failed: %v", err)
+ // TODO(q3k): close session
+ time.Sleep(time.Second)
+ }
+ }
+}
+
+type machineListing struct {
+ machines []shepherd.Machine
+ err error
+}
+
+// runInSession executes one iteration of the provisioner's control loop within a
+// BMDB session. This control loop attempts to bring all available provider
+// capacity into the BMDB as provided machines, subject to limits.
+func (p *Provisioner) runInSession(ctx context.Context, sess *bmdb.Session) error {
+ if err := p.ReconcileLoopLimiter.Wait(ctx); err != nil {
+ return err
+ }
+
+ providerC := make(chan *machineListing, 1)
+ bmdbC := make(chan *machineListing, 1)
+
+ klog.Infof("Getting provider and bmdb machines...")
+
+ // Make sub-context for two parallel operations, and so that we can cancel one
+ // immediately if the other fails.
+ subCtx, subCtxC := context.WithCancel(ctx)
+ defer subCtxC()
+
+ go func() {
+ machines, err := p.listInProvider(subCtx)
+ providerC <- &machineListing{
+ machines: machines,
+ err: err,
+ }
+ }()
+ go func() {
+ machines, err := p.listInBMDB(subCtx, sess)
+ bmdbC <- &machineListing{
+ machines: machines,
+ err: err,
+ }
+ }()
+ var inProvider, inBMDB *machineListing
+ for {
+ select {
+ case inProvider = <-providerC:
+ if err := inProvider.err; err != nil {
+ return fmt.Errorf("listing provider machines failed: %w", err)
+ }
+ klog.Infof("Got %d machines in provider.", len(inProvider.machines))
+ case inBMDB = <-bmdbC:
+ if err := inBMDB.err; err != nil {
+ return fmt.Errorf("listing BMDB machines failed: %w", err)
+ }
+ klog.Infof("Got %d machines in BMDB.", len(inBMDB.machines))
+ }
+ if inProvider != nil && inBMDB != nil {
+ break
+ }
+ }
+
+ subCtxC()
+ if err := p.reconcile(ctx, sess, inProvider.machines, inBMDB.machines); err != nil {
+ return fmt.Errorf("reconciliation failed: %w", err)
+ }
+ return nil
+}
+
+// listInProvider returns all machines that the provider thinks we should be
+// managing.
+func (p *Provisioner) listInProvider(ctx context.Context) ([]shepherd.Machine, error) {
+ machines, err := p.p.ListMachines(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("while fetching managed machines: %w", err)
+ }
+ sort.Slice(machines, func(i, j int) bool {
+ return machines[i].ID() < machines[j].ID()
+ })
+ return machines, nil
+}
+
+type providedMachine struct {
+ model.MachineProvided
+}
+
+func (p providedMachine) ID() shepherd.ProviderID {
+ return shepherd.ProviderID(p.ProviderID)
+}
+
+func (p providedMachine) Addr() netip.Addr {
+ if !p.ProviderIpAddress.Valid {
+ return netip.Addr{}
+ }
+
+ addr, err := netip.ParseAddr(p.ProviderIpAddress.String)
+ if err != nil {
+ return netip.Addr{}
+ }
+ return addr
+}
+
+func (p providedMachine) State() shepherd.State {
+ return shepherd.StateKnownUsed
+}
+
+// listInBMDB returns all the machines that the BMDB thinks we should be managing.
+func (p *Provisioner) listInBMDB(ctx context.Context, sess *bmdb.Session) ([]shepherd.Machine, error) {
+ var res []shepherd.Machine
+ err := sess.Transact(ctx, func(q *model.Queries) error {
+ machines, err := q.GetProvidedMachines(ctx, p.p.Type())
+ if err != nil {
+ return err
+ }
+ res = make([]shepherd.Machine, 0, len(machines))
+ for _, machine := range machines {
+ _, err := uuid.Parse(machine.ProviderID)
+ if err != nil {
+ klog.Errorf("BMDB machine %s has unparseable provider ID %q", machine.MachineID, machine.ProviderID)
+ continue
+ }
+
+ res = append(res, providedMachine{machine})
+ }
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ sort.Slice(res, func(i, j int) bool {
+ return res[i].ID() < res[j].ID()
+ })
+ return res, nil
+}
+
+// resolvePossiblyUsed checks if the given machine's state is StatePossiblyUsed
+// and, if so, resolves it to either StateKnownUsed or StateKnownUnused.
+func (p *Provisioner) resolvePossiblyUsed(machine shepherd.Machine, providedMachines map[shepherd.ProviderID]bool) shepherd.State {
+ state, id := machine.State(), machine.ID()
+
+ // Bail out if this isn't a possibly used state.
+ if state != shepherd.StatePossiblyUsed {
+ return state
+ }
+
+ // If a machine does not have a valid ID, it's always seen as unused.
+ if !id.IsValid() {
+ return shepherd.StateKnownUnused
+ }
+
+ // If the machine is not inside the BMDB, it's seen as unused.
+ if _, ok := providedMachines[id]; !ok {
+ return shepherd.StateKnownUnused
+ }
+
+ return shepherd.StateKnownUsed
+}
+
+// reconcile takes a list of machines that the provider thinks we should be
+// managing and that the BMDB thinks we should be managing, and tries to make
+// sense of that. First, some checks are performed across the two lists to make
+// sure we haven't dropped anything. Then, additional machines are deployed from
+// hardware reservations as needed.
+func (p *Provisioner) reconcile(ctx context.Context, sess *bmdb.Session, inProvider, bmdbMachines []shepherd.Machine) error {
+ klog.Infof("Reconciling...")
+
+ bmdb := make(map[shepherd.ProviderID]bool)
+ for _, machine := range bmdbMachines {
+ // Don't check the state here as it's hardcoded to be known used.
+ bmdb[machine.ID()] = true
+ }
+
+ var availableMachines []shepherd.Machine
+ provider := make(map[shepherd.ProviderID]bool)
+ for _, machine := range inProvider {
+ state := p.resolvePossiblyUsed(machine, bmdb)
+
+ switch state {
+ case shepherd.StateKnownUnused:
+ availableMachines = append(availableMachines, machine)
+
+ case shepherd.StateKnownUsed:
+ provider[machine.ID()] = true
+
+ default:
+ return fmt.Errorf("machine has invalid state (ID: %s, Addr: %s): %s", machine.ID(), machine.Addr(), state)
+ }
+ }
+
+ managed := make(map[shepherd.ProviderID]bool)
+
+ // Some desynchronization between the BMDB and Provider point of view might be so
+ // bad we shouldn't attempt to do any work, at least not any time soon.
+ badbadnotgood := false
+
+ // Find any machines supposedly managed by us in the provider, but not in the
+ // BMDB.
+ for machine := range provider {
+ if bmdb[machine] {
+ managed[machine] = true
+ continue
+ }
+ klog.Errorf("Provider machine %s has no corresponding machine in BMDB.", machine)
+ badbadnotgood = true
+ }
+
+ // Find any machines in the BMDB but not in the provider.
+ for machine := range bmdb {
+ if !provider[machine] {
+ klog.Errorf("Provider device ID %s referred to in BMDB (from TODO) but missing in provider.", machine)
+ badbadnotgood = true
+ }
+ }
+
+ // Bail if things are weird.
+ if badbadnotgood {
+ klog.Errorf("Something's very wrong. Bailing early and refusing to do any work.")
+ return fmt.Errorf("fatal discrepency between BMDB and provider")
+ }
+
+ // Summarize all managed machines, which is the intersection of BMDB and
+ // provider machines; usually both of these sets are equal.
+ nmanaged := len(managed)
+ klog.Infof("Total managed machines: %d", nmanaged)
+
+ if p.MaxCount != 0 && p.MaxCount <= uint(nmanaged) {
+ klog.Infof("Not bringing up more machines (at limit of %d machines)", p.MaxCount)
+ return nil
+ }
+
+ limitName := "no limit"
+ if p.MaxCount != 0 {
+ limitName = fmt.Sprintf("%d", p.MaxCount)
+ }
+ klog.Infof("Below managed machine limit (%s), bringing up more...", limitName)
+
+ if len(availableMachines) == 0 {
+ klog.Infof("No more capacity available.")
+ return nil
+ }
+
+ toProvision := availableMachines
+ // Limit them to MaxCount, if applicable.
+ if p.MaxCount != 0 {
+ needed := int(p.MaxCount) - nmanaged
+ if len(toProvision) < needed {
+ needed = len(toProvision)
+ }
+ toProvision = toProvision[:needed]
+ }
+
+ // Limit them to an arbitrary 'chunk' size so that we don't do too many things in
+ // a single reconciliation operation.
+ if uint(len(toProvision)) > p.ChunkSize {
+ toProvision = toProvision[:p.ChunkSize]
+ }
+
+ if len(toProvision) == 0 {
+ klog.Infof("No more unused machines available, or all filtered out.")
+ return nil
+ }
+
+ klog.Infof("Bringing up %d machines...", len(toProvision))
+ for _, machine := range toProvision {
+ if err := p.DeviceCreationLimiter.Wait(ctx); err != nil {
+ return err
+ }
+
+ nd, err := p.p.CreateMachine(ctx, sess, shepherd.CreateMachineRequest{
+ UnusedMachine: machine,
+ })
+ if err != nil {
+ klog.Errorf("while creating new device (ID: %s, Addr: %s, State: %s): %w", machine.ID(), machine.Addr(), machine.State(), err)
+ continue
+ }
+ klog.Infof("Created new machine with ID: %s", nd.ID())
+ }
+
+ return nil
+}
diff --git a/cloud/shepherd/manager/provisioner_test.go b/cloud/shepherd/manager/provisioner_test.go
new file mode 100644
index 0000000..5adc408
--- /dev/null
+++ b/cloud/shepherd/manager/provisioner_test.go
@@ -0,0 +1,138 @@
+package manager
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "golang.org/x/time/rate"
+
+ "source.monogon.dev/cloud/bmaas/bmdb"
+ "source.monogon.dev/cloud/bmaas/bmdb/model"
+ "source.monogon.dev/cloud/lib/component"
+ "source.monogon.dev/cloud/shepherd"
+)
+
+// TestProvisionerSmokes makes sure the Provisioner doesn't go up in flames on
+// the happy path.
+func TestProvisionerSmokes(t *testing.T) {
+ pc := ProvisionerConfig{
+ MaxCount: 10,
+ // We need 3 iterations to provide 10 machines with a chunk size of 4.
+ ReconcileLoopLimiter: rate.NewLimiter(rate.Every(10*time.Second), 3),
+ DeviceCreationLimiter: rate.NewLimiter(rate.Every(time.Second), 10),
+ ChunkSize: 4,
+ }
+
+ provider := newDummyProvider(100)
+
+ p, err := NewProvisioner(provider, pc)
+ if err != nil {
+ t.Fatalf("Could not create Provisioner: %v", err)
+ }
+
+ ctx, ctxC := context.WithCancel(context.Background())
+ defer ctxC()
+
+ b := bmdb.BMDB{
+ Config: bmdb.Config{
+ Database: component.CockroachConfig{
+ InMemory: true,
+ },
+ ComponentName: "test",
+ RuntimeInfo: "test",
+ },
+ }
+ conn, err := b.Open(true)
+ if err != nil {
+ t.Fatalf("Could not create in-memory BMDB: %v", err)
+ }
+
+ go p.Run(ctx, conn)
+
+ sess, err := conn.StartSession(ctx)
+ if err != nil {
+ t.Fatalf("Failed to create BMDB session for verification: %v", err)
+ }
+ for {
+ time.Sleep(100 * time.Millisecond)
+
+ var provided []model.MachineProvided
+ err = sess.Transact(ctx, func(q *model.Queries) error {
+ var err error
+ provided, err = q.GetProvidedMachines(ctx, provider.Type())
+ return err
+ })
+ if err != nil {
+ t.Errorf("Transact failed: %v", err)
+ }
+ if len(provided) < 10 {
+ continue
+ }
+ if len(provided) > 10 {
+ t.Errorf("%d machines provided (limit: 10)", len(provided))
+ }
+
+ for _, mp := range provided {
+ if provider.machines[shepherd.ProviderID(mp.ProviderID)] == nil {
+ t.Errorf("BMDB machine %q has unknown provider ID %q", mp.MachineID, mp.ProviderID)
+ }
+ }
+
+ return
+ }
+}
+
+// TestProvisioner_resolvePossiblyUsed makes sure the PossiblyUsed state is
+// resolved correctly.
+func TestProvisioner_resolvePossiblyUsed(t *testing.T) {
+ const providedMachineID = "provided-machine"
+
+ providedMachines := map[shepherd.ProviderID]bool{
+ providedMachineID: true,
+ }
+
+ tests := []struct {
+ name string
+ machineID shepherd.ProviderID
+ machineState shepherd.State
+ wantedState shepherd.State
+ }{
+ {
+ name: "skip KnownUsed",
+ machineState: shepherd.StateKnownUsed,
+ wantedState: shepherd.StateKnownUsed,
+ },
+ {
+ name: "skip KnownUnused",
+ machineState: shepherd.StateKnownUnused,
+ wantedState: shepherd.StateKnownUnused,
+ },
+ {
+ name: "invalid ID",
+ machineID: shepherd.InvalidProviderID,
+ machineState: shepherd.StatePossiblyUsed,
+ wantedState: shepherd.StateKnownUnused,
+ },
+ {
+ name: "valid ID, not in providedMachines",
+ machineID: "unused-machine",
+ machineState: shepherd.StatePossiblyUsed,
+ wantedState: shepherd.StateKnownUnused,
+ },
+ {
+ name: "valid ID, in providedMachines",
+ machineID: providedMachineID,
+ machineState: shepherd.StatePossiblyUsed,
+ wantedState: shepherd.StateKnownUsed,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ p := &Provisioner{}
+ if got := p.resolvePossiblyUsed(&dummyMachine{id: tt.machineID, state: tt.machineState}, providedMachines); got != tt.wantedState {
+ t.Errorf("resolvePossiblyUsed() = %v, want %v", got, tt.wantedState)
+ }
+ })
+ }
+}
diff --git a/cloud/shepherd/manager/recoverer.go b/cloud/shepherd/manager/recoverer.go
new file mode 100644
index 0000000..a94700a
--- /dev/null
+++ b/cloud/shepherd/manager/recoverer.go
@@ -0,0 +1,81 @@
+package manager
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "k8s.io/klog/v2"
+
+ "source.monogon.dev/cloud/bmaas/bmdb"
+ "source.monogon.dev/cloud/bmaas/bmdb/metrics"
+ "source.monogon.dev/cloud/bmaas/bmdb/model"
+ "source.monogon.dev/cloud/shepherd"
+)
+
+type RecovererConfig struct {
+ ControlLoopConfig
+}
+
+func (r *RecovererConfig) RegisterFlags() {
+ r.ControlLoopConfig.RegisterFlags("recoverer")
+}
+
+// The Recoverer reboots machines whose agent has stopped sending heartbeats or
+// has not sent any heartbeats at all.
+type Recoverer struct {
+ RecovererConfig
+ r shepherd.Recoverer
+}
+
+func NewRecoverer(r shepherd.Recoverer, rc RecovererConfig) (*Recoverer, error) {
+ if err := rc.ControlLoopConfig.Check(); err != nil {
+ return nil, err
+ }
+ return &Recoverer{
+ RecovererConfig: rc,
+ r: r,
+ }, nil
+}
+
+func (r *Recoverer) getProcessInfo() processInfo {
+ return processInfo{
+ process: model.ProcessShepherdRecovery,
+ defaultBackoff: bmdb.Backoff{
+ Initial: 1 * time.Minute,
+ Maximum: 1 * time.Hour,
+ Exponent: 1.2,
+ },
+ processor: metrics.ProcessorShepherdRecoverer,
+ }
+}
+
+func (r *Recoverer) getMachines(ctx context.Context, q *model.Queries, limit int32) ([]model.MachineProvided, error) {
+ return q.GetMachineForAgentRecovery(ctx, model.GetMachineForAgentRecoveryParams{
+ Limit: limit,
+ Provider: r.r.Type(),
+ })
+}
+
+func (r *Recoverer) processMachine(ctx context.Context, t *task) error {
+ klog.Infof("Starting recovery of machine (ID: %s, PID %s)", t.machine.MachineID, t.machine.ProviderID)
+
+ if err := r.r.RebootMachine(ctx, shepherd.ProviderID(t.machine.ProviderID)); err != nil {
+ return fmt.Errorf("failed to reboot machine: %w", err)
+ }
+
+ klog.Infof("Removing AgentStarted/AgentHeartbeat (ID: %s, PID: %s)...", t.machine.MachineID, t.machine.ProviderID)
+ err := t.work.Finish(ctx, func(q *model.Queries) error {
+ if err := q.MachineDeleteAgentStarted(ctx, t.machine.MachineID); err != nil {
+ return fmt.Errorf("while deleting AgentStarted: %w", err)
+ }
+ if err := q.MachineDeleteAgentHeartbeat(ctx, t.machine.MachineID); err != nil {
+ return fmt.Errorf("while deleting AgentHeartbeat: %w", err)
+ }
+ return nil
+ })
+ if err != nil {
+ return fmt.Errorf("while deleting AgentStarted/AgentHeartbeat tags: %w", err)
+ }
+ return nil
+}
diff --git a/cloud/shepherd/manager/ssh_client.go b/cloud/shepherd/manager/ssh_client.go
new file mode 100644
index 0000000..a1a305a
--- /dev/null
+++ b/cloud/shepherd/manager/ssh_client.go
@@ -0,0 +1,143 @@
+package manager
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "time"
+
+ "github.com/pkg/sftp"
+ "golang.org/x/crypto/ssh"
+)
+
+// SSHClient defines a simple interface to an abstract SSH client. Usually this
+// would be PlainSSHClient, but tests can use this interface to dependency-inject
+// fake SSH connections.
+type SSHClient interface {
+ // Dial returns an SSHConnection to a given address (host:port pair) with
+ // a timeout for connection.
+ Dial(ctx context.Context, address string, connectTimeout time.Duration) (SSHConnection, error)
+}
+
+type SSHConnection interface {
+ // Execute a given command on a remote host synchronously, passing in stdin as
+ // input, and returning a captured stdout/stderr. The returned data might be
+ // valid even when err != nil, which might happen if the remote side returned a
+ // non-zero exit code.
+ Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error)
+ // Upload a given blob to targetPath on the remote system and make it executable.
+ Upload(ctx context.Context, targetPath string, data []byte) error
+ // Close this connection.
+ Close() error
+}
+
+// PlainSSHClient implements SSHClient using golang.org/x/crypto/ssh; the
+// connections it returns implement SSHConnection.
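+//
+// A minimal construction sketch (the SSHKey setup and the username are
+// illustrative assumptions):
+//
+//	var key SSHKey
+//	// key.Key or key.KeyPersistPath configured elsewhere...
+//	signer, err := key.Signer()
+//	if err != nil {
+//		// handle error
+//	}
+//	client := &PlainSSHClient{
+//		AuthMethod: ssh.PublicKeys(signer),
+//		Username:   "root",
+//	}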
+type PlainSSHClient struct {
+ AuthMethod ssh.AuthMethod
+ Username string
+}
+
+type plainSSHConn struct {
+ cl *ssh.Client
+}
+
+func (p *PlainSSHClient) Dial(ctx context.Context, address string, connectTimeout time.Duration) (SSHConnection, error) {
+ d := net.Dialer{
+ Timeout: connectTimeout,
+ }
+ conn, err := d.DialContext(ctx, "tcp", address)
+ if err != nil {
+ return nil, err
+ }
+ conf := &ssh.ClientConfig{
+ User: p.Username,
+ Auth: []ssh.AuthMethod{
+ p.AuthMethod,
+ },
+ // Ignore the host key, since it's likely the first time anything logs into
+ // this device, and also because there's no way of knowing its fingerprint.
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ // Timeout sets a bound on the time it takes to set up the connection, but
+ // not on total session time.
+ Timeout: connectTimeout,
+ }
+ conn2, chanC, reqC, err := ssh.NewClientConn(conn, address, conf)
+ if err != nil {
+ return nil, err
+ }
+ cl := ssh.NewClient(conn2, chanC, reqC)
+ return &plainSSHConn{
+ cl: cl,
+ }, nil
+}
+
+func (p *plainSSHConn) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
+ sess, err := p.cl.NewSession()
+ if err != nil {
+ return nil, nil, fmt.Errorf("while creating SSH session: %w", err)
+ }
+ stdoutBuf := bytes.NewBuffer(nil)
+ stderrBuf := bytes.NewBuffer(nil)
+ sess.Stdin = bytes.NewBuffer(stdin)
+ sess.Stdout = stdoutBuf
+ sess.Stderr = stderrBuf
+ defer sess.Close()
+
+ if err := sess.Start(command); err != nil {
+ return nil, nil, err
+ }
+ doneC := make(chan error, 1)
+ go func() {
+ doneC <- sess.Wait()
+ }()
+ select {
+ case <-ctx.Done():
+ return nil, nil, ctx.Err()
+ case err := <-doneC:
+ return stdoutBuf.Bytes(), stderrBuf.Bytes(), err
+ }
+}
+
+func (p *plainSSHConn) Upload(ctx context.Context, targetPath string, data []byte) error {
+ sc, err := sftp.NewClient(p.cl)
+ if err != nil {
+ return fmt.Errorf("while building sftp client: %w", err)
+ }
+ defer sc.Close()
+
+ acrdr := bytes.NewReader(data)
+ df, err := sc.Create(targetPath)
+ if err != nil {
+ return fmt.Errorf("while creating file on the host: %w", err)
+ }
+
+ doneC := make(chan error, 1)
+
+ go func() {
+ _, err := io.Copy(df, acrdr)
+ df.Close()
+ doneC <- err
+ }()
+
+ select {
+ case err := <-doneC:
+ if err != nil {
+ return fmt.Errorf("while copying file: %w", err)
+ }
+ case <-ctx.Done():
+ df.Close()
+ return ctx.Err()
+ }
+
+ if err := sc.Chmod(targetPath, 0755); err != nil {
+ return fmt.Errorf("while setting file permissions: %w", err)
+ }
+ return nil
+}
+
+func (p *plainSSHConn) Close() error {
+ return p.cl.Close()
+}
diff --git a/cloud/shepherd/manager/ssh_key_signer.go b/cloud/shepherd/manager/ssh_key_signer.go
new file mode 100644
index 0000000..7a8d08a
--- /dev/null
+++ b/cloud/shepherd/manager/ssh_key_signer.go
@@ -0,0 +1,108 @@
+package manager
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "flag"
+ "fmt"
+ "os"
+ "sync"
+
+ "golang.org/x/crypto/ssh"
+ "k8s.io/klog/v2"
+)
+
+type SSHKey struct {
+ // muKey guards Key.
+ muKey sync.Mutex
+
+ // SSH key to use when creating machines and then connecting to them. If not
+ // provided, it will be automatically loaded from KeyPersistPath, and if that
+ // doesn't exist either, it will be first generated and persisted there.
+ Key ed25519.PrivateKey
+
+ // Path at which the SSH key will be loaded from and persisted to, if Key is not
+ // explicitly set. Either KeyPersistPath or Key must be set.
+ KeyPersistPath string
+}
+
+func (c *SSHKey) RegisterFlags() {
+ flag.StringVar(&c.KeyPersistPath, "ssh_key_path", "", "Local filesystem path to read SSH key from, and save generated key to")
+}
+
+// sshKey returns the SSH key as defined by the Key and KeyPersistPath options,
+// loading/generating/persisting it as necessary.
+func (c *SSHKey) sshKey() (ed25519.PrivateKey, error) {
+ c.muKey.Lock()
+ defer c.muKey.Unlock()
+
+ if c.Key != nil {
+ return c.Key, nil
+ }
+ if c.KeyPersistPath == "" {
+ return nil, fmt.Errorf("-ssh_key_path must be set")
+ }
+
+ data, err := os.ReadFile(c.KeyPersistPath)
+ switch {
+ case err == nil:
+ if len(data) != ed25519.PrivateKeySize {
+ return nil, fmt.Errorf("%s is not a valid ed25519 private key", c.KeyPersistPath)
+ }
+ c.Key = data
+ klog.Infof("Loaded SSH key from %s", c.KeyPersistPath)
+ return c.Key, nil
+ case os.IsNotExist(err):
+ if err := c.sshGenerateUnlocked(); err != nil {
+ return nil, err
+ }
+ if err := os.WriteFile(c.KeyPersistPath, c.Key, 0400); err != nil {
+ return nil, fmt.Errorf("could not persist key: %w", err)
+ }
+ return c.Key, nil
+ default:
+ return nil, fmt.Errorf("could not load peristed key: %w", err)
+ }
+}
+
+// PublicKey returns the SSH public key marshaled for use, based on sshKey.
+func (c *SSHKey) PublicKey() (string, error) {
+ private, err := c.sshKey()
+ if err != nil {
+ return "", err
+ }
+ // Marshal the public key part in OpenSSH authorized_keys format.
+ sshpub, err := ssh.NewPublicKey(private.Public())
+ if err != nil {
+ return "", fmt.Errorf("while building SSH public key: %w", err)
+ }
+ return string(ssh.MarshalAuthorizedKey(sshpub)), nil
+}
+
+// Signer builds an ssh.Signer (for use in SSH connections) based on sshKey.
+func (c *SSHKey) Signer() (ssh.Signer, error) {
+ private, err := c.sshKey()
+ if err != nil {
+ return nil, err
+ }
+ // Set up the internal ssh.Signer to be later used to initiate SSH
+ // connections with newly provided hosts.
+ signer, err := ssh.NewSignerFromKey(private)
+ if err != nil {
+ return nil, fmt.Errorf("while building SSH signer: %w", err)
+ }
+ return signer, nil
+}
+
+// sshGenerateUnlocked generates a new private key and stores it in SSHKey.Key.
+func (c *SSHKey) sshGenerateUnlocked() error {
+ if c.Key != nil {
+ return nil
+ }
+ _, priv, err := ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ return fmt.Errorf("while generating SSH key: %w", err)
+ }
+ c.Key = priv
+ return nil
+}
diff --git a/cloud/shepherd/manager/test_agent/BUILD.bazel b/cloud/shepherd/manager/test_agent/BUILD.bazel
new file mode 100644
index 0000000..7636cdd
--- /dev/null
+++ b/cloud/shepherd/manager/test_agent/BUILD.bazel
@@ -0,0 +1,28 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
+load("//build/static_binary_tarball:def.bzl", "static_binary_tarball")
+
+go_binary(
+ name = "test_agent",
+ embed = [":test_agent_lib"],
+ visibility = [
+ "//cloud/shepherd/manager:__pkg__",
+ ],
+)
+
+go_library(
+ name = "test_agent_lib",
+ srcs = ["main.go"],
+ importpath = "source.monogon.dev/cloud/shepherd/manager/test_agent",
+ visibility = ["//visibility:private"],
+ deps = [
+ "//cloud/agent/api",
+ "@org_golang_google_protobuf//proto",
+ ],
+)
+
+# Used by container_images, forces a static build of the test_agent.
+static_binary_tarball(
+ name = "test_agent_layer",
+ executable = ":test_agent",
+ visibility = ["//visibility:public"],
+)
diff --git a/cloud/shepherd/manager/test_agent/main.go b/cloud/shepherd/manager/test_agent/main.go
new file mode 100644
index 0000000..8f29c30
--- /dev/null
+++ b/cloud/shepherd/manager/test_agent/main.go
@@ -0,0 +1,54 @@
+// test_agent is used by the Equinix Metal Manager test code. Its only role
+// is to ensure successful delivery of the BMaaS agent executable to the test
+// hosts, together with its subsequent execution.
+package main
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "os"
+
+ "google.golang.org/protobuf/proto"
+
+ apb "source.monogon.dev/cloud/agent/api"
+)
+
+func main() {
+ // The agent initialization message will arrive from Shepherd on Agent's
+ // standard input.
+ aimb, err := io.ReadAll(os.Stdin)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "while reading AgentInit message: %v\n", err)
+ return
+ }
+ var aim apb.TakeoverInit
+ if err := proto.Unmarshal(aimb, &aim); err != nil {
+ fmt.Fprintf(os.Stderr, "while unmarshaling TakeoverInit message: %v\n", err)
+ return
+ }
+
+ // Agent should send back apb.TakeoverResponse on its standard output.
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "while generating agent public key: %v\n", err)
+ return
+ }
+ arsp := apb.TakeoverResponse{
+ Result: &apb.TakeoverResponse_Success{Success: &apb.TakeoverSuccess{
+ InitMessage: &aim,
+ Key: pub,
+ }},
+ }
+ arspb, err := proto.Marshal(&arsp)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "while marshaling TakeoverResponse message: %v\n", err)
+ return
+ }
+ if _, err := os.Stdout.Write(arspb); err != nil {
+ fmt.Fprintf(os.Stderr, "while writing TakeoverResponse message: %v\n", err)
+ }
+ // The agent must detach and/or terminate after sending back the reply.
+ // Failure to do so will leave the session hanging.
+}