cloud: move takeover to agent/takeover
The takeover package is tightly coupled with the agent, so let's move it
there.
Change-Id: I38ae69d4f4e7a4f6a04b0fefb5f127ebc71f5961
Reviewed-on: https://review.monogon.dev/c/monogon/+/2790
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
Tested-by: Jenkins CI
diff --git a/cloud/agent/takeover/BUILD.bazel b/cloud/agent/takeover/BUILD.bazel
new file mode 100644
index 0000000..855621a
--- /dev/null
+++ b/cloud/agent/takeover/BUILD.bazel
@@ -0,0 +1,71 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
+load("//metropolis/node/build/mkucode:def.bzl", "cpio_ucode")
+load("//build/static_binary_tarball:def.bzl", "static_binary_tarball")
+load("//metropolis/node/build:def.bzl", "node_initramfs")
+load("//metropolis/node/build/fwprune:def.bzl", "fsspec_linux_firmware")
+
+go_library( # Library embedded by the :takeover binary; embedsrcs pulls in the kernel, :ucode and :initramfs.
+ name = "takeover_lib",
+ srcs = ["takeover.go"],
+ embedsrcs = [
+ "//third_party/linux", #keep
+ ":ucode", #keep
+ ":initramfs", #keep
+ ],
+ importpath = "source.monogon.dev/cloud/agent/takeover",
+ visibility = ["//visibility:private"],
+ deps = [
+ "//cloud/agent/api",
+ "//metropolis/pkg/bootparam",
+ "//metropolis/pkg/kexec",
+ "//net/dump",
+ "//net/proto",
+ "@com_github_cavaliergopher_cpio//:cpio",
+ "@com_github_klauspost_compress//zstd",
+ "@org_golang_google_protobuf//proto",
+ "@org_golang_x_sys//unix",
+ ],
+)
+
+node_initramfs( # Initramfs listed in :takeover_lib's embedsrcs; the agent binary is installed as /init.
+ name = "initramfs",
+ files = {
+ "//cloud/agent:agent": "/init",
+ "@com_github_coredns_coredns//:coredns": "/kubernetes/bin/coredns",
+ "//metropolis/node/core/network/dns:resolv.conf": "/etc/resolv.conf",
+ "@cacerts//file": "/etc/ssl/cert.pem",
+ },
+ fsspecs = [
+ "//metropolis/node/build:earlydev.fsspec",
+ ":firmware",
+ ],
+ visibility = ["//cloud/agent:__subpackages__"],
+)
+
+go_binary( # The takeover executable, built from :takeover_lib; packaged by :takeover_layer.
+ name = "takeover",
+ embed = [":takeover_lib"],
+ visibility = ["//visibility:public"],
+)
+
+cpio_ucode( # AMD and Intel CPU microcode; listed in :takeover_lib's embedsrcs.
+ name = "ucode",
+ ucode = {
+ "@linux-firmware//:amd_ucode": "AuthenticAMD",
+ "@intel_ucode//:fam6h": "GenuineIntel",
+ },
+)
+
+fsspec_linux_firmware( # Firmware fsspec consumed by the :initramfs target above.
+ name = "firmware",
+ firmware_files = ["@linux-firmware//:all_files"],
+ kernel = "//third_party/linux",
+ metadata = "@linux-firmware//:metadata",
+)
+
+# Used by container_images, forces a static build of the takeover binary.
+static_binary_tarball(
+ name = "takeover_layer",
+ executable = ":takeover",
+ visibility = ["//visibility:public"],
+)
diff --git a/cloud/agent/takeover/e2e/BUILD.bazel b/cloud/agent/takeover/e2e/BUILD.bazel
new file mode 100644
index 0000000..1cdd840
--- /dev/null
+++ b/cloud/agent/takeover/e2e/BUILD.bazel
@@ -0,0 +1,20 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+go_test( # End-to-end test; the takeover binary, OVMF firmware and a Debian 11 cloud image are provided as runfiles.
+ name = "e2e_test",
+ srcs = ["main_test.go"],
+ data = [
+ "//cloud/agent/takeover",
+ "//third_party/edk2:firmware",
+ "@debian_11_cloudimage//file",
+ ],
+ deps = [
+ "//cloud/agent/api",
+ "//metropolis/pkg/fat32",
+ "//metropolis/pkg/freeport",
+ "@com_github_pkg_sftp//:sftp",
+ "@io_bazel_rules_go//go/runfiles:go_default_library",
+ "@org_golang_google_protobuf//proto",
+ "@org_golang_x_crypto//ssh",
+ ],
+)
diff --git a/cloud/agent/takeover/e2e/main_test.go b/cloud/agent/takeover/e2e/main_test.go
new file mode 100644
index 0000000..6d489eb
--- /dev/null
+++ b/cloud/agent/takeover/e2e/main_test.go
@@ -0,0 +1,219 @@
+package e2e
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/bazelbuild/rules_go/go/runfiles"
+ "github.com/pkg/sftp"
+ "golang.org/x/crypto/ssh"
+ "google.golang.org/protobuf/proto"
+
+ "source.monogon.dev/cloud/agent/api"
+
+ "source.monogon.dev/metropolis/pkg/fat32"
+ "source.monogon.dev/metropolis/pkg/freeport"
+)
+
+// TestE2E boots a Debian cloud image in QEMU (KVM, OVMF firmware), injects an
+// SSH key through a cloud-init "cidata" FAT32 volume, uploads the takeover
+// binary over SFTP and runs it with a TakeoverInit message on standard input.
+// The test passes once takeover reports success and the VM's serial console
+// prints the BMaaS agent startup line.
+func TestE2E(t *testing.T) {
+	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sshPubKey, err := ssh.NewPublicKey(pubKey)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sshPrivkey, err := ssh.NewSignerFromKey(privKey)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// CloudConfig doesn't really have a rigid spec, so just put things into it
+	cloudConfig := make(map[string]any)
+	cloudConfig["ssh_authorized_keys"] = []string{
+		strings.TrimSuffix(string(ssh.MarshalAuthorizedKey(sshPubKey)), "\n"),
+	}
+
+	userData, err := json.Marshal(cloudConfig)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	rootInode := fat32.Inode{
+		Attrs: fat32.AttrDirectory,
+		Children: []*fat32.Inode{
+			{
+				Name:    "user-data",
+				Content: strings.NewReader("#cloud-config\n" + string(userData)),
+			},
+			{
+				Name:    "meta-data",
+				Content: strings.NewReader(""),
+			},
+		},
+	}
+	cloudInitDataFile, err := os.CreateTemp("", "cidata*.img")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(cloudInitDataFile.Name())
+	// Deferred calls run in reverse order, so the handle is closed before the
+	// file is removed.
+	defer cloudInitDataFile.Close()
+	if err := fat32.WriteFS(cloudInitDataFile, rootInode, fat32.Options{Label: "cidata"}); err != nil {
+		t.Fatal(err)
+	}
+	cloudImagePath, err := runfiles.Rlocation("debian_11_cloudimage/file/downloaded")
+	if err != nil {
+		t.Fatal(err)
+	}
+	ovmfVarsPath, err := runfiles.Rlocation("edk2/OVMF_VARS.fd")
+	if err != nil {
+		t.Fatal(err)
+	}
+	ovmfCodePath, err := runfiles.Rlocation("edk2/OVMF_CODE.fd")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sshPort, sshPortCloser, err := freeport.AllocateTCPPort()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	qemuArgs := []string{
+		"-machine", "q35", "-accel", "kvm", "-nographic", "-nodefaults", "-m", "1024",
+		"-cpu", "host", "-smp", "sockets=1,cpus=1,cores=2,threads=2,maxcpus=4",
+		"-drive", "if=pflash,format=raw,readonly=on,file=" + ovmfCodePath,
+		"-drive", "if=pflash,format=raw,snapshot=on,file=" + ovmfVarsPath,
+		"-drive", "if=virtio,format=qcow2,snapshot=on,cache=unsafe,file=" + cloudImagePath,
+		"-drive", "if=virtio,format=raw,snapshot=on,file=" + cloudInitDataFile.Name(),
+		"-netdev", fmt.Sprintf("user,id=net0,net=10.42.0.0/24,dhcpstart=10.42.0.10,hostfwd=tcp::%d-:22", sshPort),
+		"-device", "virtio-net-pci,netdev=net0,mac=22:d5:8e:76:1d:07",
+		"-device", "virtio-rng-pci",
+		"-serial", "stdio",
+		"-no-reboot",
+	}
+	qemuCmd := exec.Command("qemu-system-x86_64", qemuArgs...)
+	stdoutPipe, err := qemuCmd.StdoutPipe()
+	if err != nil {
+		t.Fatal(err)
+	}
+	// agentStarted is closed once the agent's startup line was seen on the
+	// serial console. consoleDone is closed when the console reader exits,
+	// i.e. when QEMU's stdout reached EOF or could no longer be scanned.
+	agentStarted := make(chan struct{})
+	consoleDone := make(chan struct{})
+	go func() {
+		defer close(consoleDone)
+		seen := false
+		s := bufio.NewScanner(stdoutPipe)
+		// Keep draining after the marker was found so that QEMU never blocks
+		// on a full stdout pipe.
+		for s.Scan() {
+			t.Log("kernel: " + s.Text())
+			if !seen && strings.Contains(s.Text(), "Monogon BMaaS Agent started") {
+				seen = true
+				close(agentStarted)
+			}
+		}
+	}()
+	qemuCmd.Stderr = os.Stderr
+	sshPortCloser.Close()
+	if err := qemuCmd.Start(); err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		qemuCmd.Process.Kill()
+		// Wait for the console reader before returning: calling t.Log after
+		// the test has completed panics, and Wait must only be called once
+		// all reads from the stdout pipe are done.
+		<-consoleDone
+		qemuCmd.Wait()
+	}()
+
+	var c *ssh.Client
+	for {
+		// Fail fast instead of retrying until the test timeout when the VM
+		// is already gone (for example because KVM is unavailable).
+		select {
+		case <-consoleDone:
+			t.Fatal("QEMU console closed before SSH became reachable, QEMU has likely exited")
+		default:
+		}
+		c, err = ssh.Dial("tcp", net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)), &ssh.ClientConfig{
+			User:            "debian",
+			Auth:            []ssh.AuthMethod{ssh.PublicKeys(sshPrivkey)},
+			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+			Timeout:         5 * time.Second,
+		})
+		if err != nil {
+			t.Logf("error connecting via SSH, retrying: %v", err)
+			time.Sleep(1 * time.Second)
+			continue
+		}
+		break
+	}
+	defer c.Close()
+	sc, err := sftp.NewClient(c)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sc.Close()
+	takeoverFile, err := sc.Create("takeover")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer takeoverFile.Close()
+	if err := takeoverFile.Chmod(0o755); err != nil {
+		t.Fatal(err)
+	}
+	takeoverPath, err := runfiles.Rlocation("_main/cloud/agent/takeover/takeover_/takeover")
+	if err != nil {
+		t.Fatal(err)
+	}
+	takeoverSrcFile, err := os.Open(takeoverPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer takeoverSrcFile.Close()
+	if _, err := io.Copy(takeoverFile, takeoverSrcFile); err != nil {
+		t.Fatal(err)
+	}
+	// Close explicitly so that upload errors are noticed before the binary
+	// is executed; the deferred Close above only covers early returns.
+	if err := takeoverFile.Close(); err != nil {
+		t.Fatal(err)
+	}
+	sc.Close()
+
+	sess, err := c.NewSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sess.Close()
+
+	takeoverInit := api.TakeoverInit{
+		MachineId:     "test",
+		BmaasEndpoint: "localhost:1234",
+	}
+	initRaw, err := proto.Marshal(&takeoverInit)
+	if err != nil {
+		t.Fatal(err)
+	}
+	sess.Stdin = bytes.NewReader(initRaw)
+	var stdoutBuf bytes.Buffer
+	var stderrBuf bytes.Buffer
+	sess.Stdout = &stdoutBuf
+	sess.Stderr = &stderrBuf
+	if err := sess.Run("sudo ./takeover"); err != nil {
+		t.Errorf("stderr:\n%s\n\n", stderrBuf.String())
+		t.Fatal(err)
+	}
+	var resp api.TakeoverResponse
+	if err := proto.Unmarshal(stdoutBuf.Bytes(), &resp); err != nil {
+		t.Fatal(err)
+	}
+	switch res := resp.Result.(type) {
+	case *api.TakeoverResponse_Success:
+		// Getters are nil-safe, so a missing InitMessage fails the
+		// comparison instead of panicking.
+		if res.Success.GetInitMessage().GetBmaasEndpoint() != takeoverInit.BmaasEndpoint {
+			t.Error("InitMessage not passed through properly")
+		}
+	case *api.TakeoverResponse_Error:
+		t.Fatalf("takeover returned error: %v", res.Error.GetMessage())
+	default:
+		// An empty or truncated response decodes to a message without any
+		// result set; that must not count as success.
+		t.Fatalf("takeover returned a response without a result (%T)", resp.Result)
+	}
+	select {
+	case <-agentStarted:
+		// Done, test passed
+	case <-time.After(30 * time.Second):
+		t.Fatal("Waiting for BMaaS agent startup timed out")
+	}
+}
diff --git a/cloud/agent/takeover/takeover.go b/cloud/agent/takeover/takeover.go
new file mode 100644
index 0000000..d313174
--- /dev/null
+++ b/cloud/agent/takeover/takeover.go
@@ -0,0 +1,241 @@
+// takeover is a self-contained executable which when executed loads the BMaaS
+// agent via kexec. It is intended to be called over SSH, given a binary
+// TakeoverInit message over standard input and (if all preparation work
+// completed successfully) will respond with a TakeoverResponse on standard
+// output. At that point the new kernel and agent initramfs are fully staged
+// by the current kernel.
+// The second stage which is also part of this binary, selected by an
+// environment variable, is then executed in detached mode and the main
+// takeover binary called over SSH terminates.
+// The second stage waits for 5 seconds for the main binary to exit, the SSH
+// session to be torn down and various other things before issuing the final
+// non-returning syscall which jumps into the new kernel.
+
+package main
+
+import (
+ "bytes"
+ "crypto/ed25519"
+ "crypto/rand"
+ _ "embed"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "os/exec"
+ "time"
+
+ "github.com/cavaliergopher/cpio"
+ "github.com/klauspost/compress/zstd"
+ "golang.org/x/sys/unix"
+ "google.golang.org/protobuf/proto"
+
+ "source.monogon.dev/cloud/agent/api"
+ "source.monogon.dev/metropolis/pkg/bootparam"
+ "source.monogon.dev/metropolis/pkg/kexec"
+ netdump "source.monogon.dev/net/dump"
+ netapi "source.monogon.dev/net/proto"
+)
+
+//go:embed third_party/linux/bzImage
+var kernel []byte // Kernel image; copied into the "kernel" memfile by setupTakeover.
+
+//go:embed ucode.cpio
+var ucode []byte // Microcode cpio archive; written first into the "initramfs" memfile.
+
+//go:embed initramfs.cpio.zst
+var initramfs []byte // Agent initramfs; written into the "initramfs" memfile right after ucode.
+
+// newMemfile returns an *os.File which lives purely in anonymous memory (it is
+// created through memfd_create) and is therefore not tied to any filesystem.
+// Both name and flags are handed to memfd_create unchanged.
+func newMemfile(name string, flags int) (*os.File, error) {
+	memfd, createErr := unix.MemfdCreate(name, flags)
+	if createErr != nil {
+		return nil, fmt.Errorf("memfd_create failed: %w", createErr)
+	}
+	memfile := os.NewFile(uintptr(memfd), name)
+	return memfile, nil
+}
+
+// setupTakeover performs all preparation work of the first stage. It reads a
+// binary TakeoverInit message from standard input, copies the embedded kernel,
+// microcode and initramfs into memory-backed files, appends an AgentInit
+// message (the TakeoverInit, a freshly generated Ed25519 private key and the
+// dumped network configuration) to the initramfs as /init.pb and finally
+// stages kernel, initramfs and command line via kexec.FileLoad.
+//
+// On success it returns a TakeoverSuccess carrying the received TakeoverInit,
+// the public half of the generated key and all non-fatal warnings. Loading
+// the kexec payload is the last step that can fail, so a non-nil error means
+// that this call has not staged anything.
+func setupTakeover() (*api.TakeoverSuccess, error) {
+	// Read init specification from stdin.
+	initRaw, err := io.ReadAll(os.Stdin)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read TakeoverInit message from stdin: %w", err)
+	}
+	var takeoverInit api.TakeoverInit
+	if err := proto.Unmarshal(initRaw, &takeoverInit); err != nil {
+		return nil, fmt.Errorf("failed to parse TakeoverInit message from stdin: %w", err)
+	}
+
+	// Sanity check for empty TakeoverInit messages
+	if takeoverInit.BmaasEndpoint == "" {
+		return nil, errors.New("BMaaS endpoint is empty, check that a proper TakeoverInit message has been provided")
+	}
+
+	// Load data from embedded files into memfiles as the kexec load syscall
+	// requires file descriptors. The descriptors are only needed until the
+	// payload has been loaded, so they are released when this function
+	// returns.
+	kernelFile, err := newMemfile("kernel", 0)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create kernel memfile: %w", err)
+	}
+	defer kernelFile.Close()
+	initramfsFile, err := newMemfile("initramfs", 0)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create initramfs memfile: %w", err)
+	}
+	defer initramfsFile.Close()
+	if _, err := kernelFile.ReadFrom(bytes.NewReader(kernel)); err != nil {
+		return nil, fmt.Errorf("failed to read kernel into memory-backed file: %w", err)
+	}
+	if _, err := initramfsFile.ReadFrom(bytes.NewReader(ucode)); err != nil {
+		return nil, fmt.Errorf("failed to read ucode into memory-backed file: %w", err)
+	}
+	if _, err := initramfsFile.ReadFrom(bytes.NewReader(initramfs)); err != nil {
+		return nil, fmt.Errorf("failed to read initramfs into memory-backed file: %w", err)
+	}
+
+	// Dump the current network configuration
+	netconf, warnings, err := netdump.Dump()
+	if err != nil {
+		return nil, fmt.Errorf("failed to dump network configuration: %w", err)
+	}
+
+	// Fall back to public resolvers if the dump did not yield any nameserver.
+	if len(netconf.Nameserver) == 0 {
+		netconf.Nameserver = []*netapi.Nameserver{{
+			Ip: "8.8.8.8",
+		}, {
+			Ip: "1.1.1.1",
+		}}
+	}
+
+	// Generate agent private key
+	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		return nil, fmt.Errorf("unable to generate Ed25519 key: %w", err)
+	}
+
+	agentInit := api.AgentInit{
+		TakeoverInit:  &takeoverInit,
+		PrivateKey:    privKey,
+		NetworkConfig: netconf,
+	}
+	agentInitRaw, err := proto.Marshal(&agentInit)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal AgentInit message: %w", err)
+	}
+
+	// Append AgentInit spec to initramfs
+	if err := appendAgentInit(initramfsFile, agentInitRaw); err != nil {
+		return nil, fmt.Errorf("failed to append AgentInit to initramfs: %w", err)
+	}
+
+	agentParams := bootparam.Params{
+		bootparam.Param{Param: "quiet"},
+		bootparam.Param{Param: "init", Value: "/init"},
+	}
+
+	var customConsoles bool
+	cmdline, err := os.ReadFile("/proc/cmdline")
+	if err != nil {
+		warnings = append(warnings, fmt.Errorf("unable to read current kernel command line: %w", err))
+	} else {
+		params, _, err := bootparam.Unmarshal(string(cmdline))
+		// If the existing command line is well-formed, add all existing console
+		// parameters to the console for the agent
+		if err == nil {
+			for _, p := range params {
+				if p.Param == "console" {
+					agentParams = append(agentParams, p)
+					customConsoles = true
+				}
+			}
+		}
+	}
+	if !customConsoles {
+		// Add the "default" console on x86
+		agentParams = append(agentParams, bootparam.Param{Param: "console", Value: "ttyS0,115200"})
+	}
+	agentCmdline, err := bootparam.Marshal(agentParams, "")
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal agent kernel command line: %w", err)
+	}
+	// Stage agent payload into kernel memory
+	if err := kexec.FileLoad(kernelFile, initramfsFile, agentCmdline); err != nil {
+		return nil, fmt.Errorf("failed to load kexec payload: %w", err)
+	}
+	var warningsStrs []string
+	for _, w := range warnings {
+		warningsStrs = append(warningsStrs, w.Error())
+	}
+	return &api.TakeoverSuccess{
+		InitMessage: &takeoverInit,
+		Key:         pubKey,
+		Warning:     warningsStrs,
+	}, nil
+}
+
+// appendAgentInit writes a zstd-compressed cpio archive containing a single
+// regular file, /init.pb, with the given contents to w. Every write and both
+// Close calls are checked so that an incomplete archive is never staged
+// silently.
+func appendAgentInit(w io.Writer, agentInitRaw []byte) error {
+	compressedW, err := zstd.NewWriter(w, zstd.WithEncoderLevel(1))
+	if err != nil {
+		return fmt.Errorf("while creating zstd writer: %w", err)
+	}
+	cpioW := cpio.NewWriter(compressedW)
+	if err := cpioW.WriteHeader(&cpio.Header{
+		Name: "/init.pb",
+		Size: int64(len(agentInitRaw)),
+		Mode: cpio.TypeReg | 0o644,
+	}); err != nil {
+		return fmt.Errorf("while writing cpio header: %w", err)
+	}
+	if _, err := cpioW.Write(agentInitRaw); err != nil {
+		return fmt.Errorf("while writing cpio data: %w", err)
+	}
+	// Closing the cpio writer finishes the archive; closing the zstd writer
+	// flushes the remaining compressed data to w.
+	if err := cpioW.Close(); err != nil {
+		return fmt.Errorf("while closing cpio writer: %w", err)
+	}
+	if err := compressedW.Close(); err != nil {
+		return fmt.Errorf("while closing zstd writer: %w", err)
+	}
+	return nil
+}
+
+// Environment variable which tells the takeover binary to run the second stage
+const detachedLaunchEnv = "TAKEOVER_DETACHED_LAUNCH"
+
+// main runs one of the two stages of the takeover.
+//
+// When detachedLaunchEnv is set to "1" it acts as the second stage: it sleeps
+// for 5 seconds and then issues the kexec reboot. If that call returns, the
+// reason is written to /dev/kmsg and the process exits with status 1, or with
+// status 2 if /dev/kmsg cannot be opened.
+//
+// Otherwise it acts as the first stage: it runs setupTakeover, writes the
+// resulting TakeoverResponse (success or error) to standard output and, only
+// if setup succeeded, re-executes itself as the detached second stage.
+func main() {
+	// Check if the second stage should be executed
+	if os.Getenv(detachedLaunchEnv) == "1" {
+		// Wait 5 seconds for data to be sent, connections to be closed and
+		// syncs to be executed
+		time.Sleep(5 * time.Second)
+		// Perform kexec, this will not return unless it fails
+		err := unix.Reboot(unix.LINUX_REBOOT_CMD_KEXEC)
+		msg := "takeover: reboot succeeded, but we're still running??"
+		if err != nil {
+			msg = "takeover: kexec failed: " + err.Error()
+		}
+		// We have no standard output/error anymore, if this fails it's
+		// just borked. Attempt to dump the error into kmsg for manual
+		// debugging.
+		kmsg, err := os.OpenFile("/dev/kmsg", os.O_WRONLY, 0)
+		if err != nil {
+			os.Exit(2)
+		}
+		kmsg.WriteString(msg)
+		kmsg.Close()
+		os.Exit(1)
+	}
+
+	var takeoverResp api.TakeoverResponse
+	res, setupErr := setupTakeover()
+	if setupErr != nil {
+		takeoverResp.Result = &api.TakeoverResponse_Error{Error: &api.TakeoverError{
+			Message: setupErr.Error(),
+		}}
+	} else {
+		takeoverResp.Result = &api.TakeoverResponse_Success{Success: res}
+	}
+	// Respond to stdout
+	takeoverRespRaw, err := proto.Marshal(&takeoverResp)
+	if err != nil {
+		log.Fatalf("failed to marshal response: %v", err)
+	}
+	if _, err := os.Stdout.Write(takeoverRespRaw); err != nil {
+		log.Fatalf("failed to write response to stdout: %v", err)
+	}
+	// Close stdout, we're done responding
+	os.Stdout.Close()
+
+	if setupErr != nil {
+		// Setup failed, so this run has not staged a kexec payload. The error
+		// has been reported in the response; do not launch the second stage,
+		// which would only attempt a kexec that was never prepared. The exit
+		// status stays zero as the failure is carried by the response.
+		return
+	}
+
+	// Start second stage which waits for 5 seconds while performing
+	// final cleanup.
+	detachedCmd := exec.Command("/proc/self/exe")
+	detachedCmd.Env = []string{detachedLaunchEnv + "=1"}
+	if err := detachedCmd.Start(); err != nil {
+		log.Fatalf("failed to launch final stage: %v", err)
+	}
+	// Release the second stage so that the first stage can cleanly terminate.
+	if err := detachedCmd.Process.Release(); err != nil {
+		log.Fatalf("error releasing final stage process: %v", err)
+	}
+}