cloud: move takeover to agent/takeover

The takeover package is tightly coupled with the agent, so let's move it
under cloud/agent.

Change-Id: I38ae69d4f4e7a4f6a04b0fefb5f127ebc71f5961
Reviewed-on: https://review.monogon.dev/c/monogon/+/2790
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
Tested-by: Jenkins CI
diff --git a/cloud/agent/BUILD.bazel b/cloud/agent/BUILD.bazel
index 96e3a38..6f5ce82 100644
--- a/cloud/agent/BUILD.bazel
+++ b/cloud/agent/BUILD.bazel
@@ -1,6 +1,4 @@
 load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library", "go_test")
-load("//metropolis/node/build/fwprune:def.bzl", "fsspec_linux_firmware")
-load("//metropolis/node/build:def.bzl", "node_initramfs")
 
 go_library(
     name = "agent_lib",
@@ -57,25 +55,3 @@
         "@com_github_stretchr_testify//assert",
     ],
 )
-
-fsspec_linux_firmware(
-    name = "firmware",
-    firmware_files = ["@linux-firmware//:all_files"],
-    kernel = "//third_party/linux",
-    metadata = "@linux-firmware//:metadata",
-)
-
-node_initramfs(
-    name = "initramfs",
-    files = {
-        ":agent": "/init",
-        "@com_github_coredns_coredns//:coredns": "/kubernetes/bin/coredns",
-        "//metropolis/node/core/network/dns:resolv.conf": "/etc/resolv.conf",
-        "@cacerts//file": "/etc/ssl/cert.pem",
-    },
-    fsspecs = [
-        "//metropolis/node/build:earlydev.fsspec",
-        ":firmware",
-    ],
-    visibility = ["//cloud:__subpackages__"],
-)
diff --git a/cloud/agent/e2e/BUILD.bazel b/cloud/agent/e2e/BUILD.bazel
index 16e9731..d05031b 100644
--- a/cloud/agent/e2e/BUILD.bazel
+++ b/cloud/agent/e2e/BUILD.bazel
@@ -4,7 +4,7 @@
     name = "e2e_test",
     srcs = ["main_test.go"],
     data = [
-        "//cloud/agent:initramfs",
+        "//cloud/agent/takeover:initramfs",
         "//metropolis/installer/test/testos:testos_bundle",
         "//third_party/edk2:firmware",
         "//third_party/linux",
diff --git a/cloud/agent/e2e/main_test.go b/cloud/agent/e2e/main_test.go
index a422d0e..4fd06ec 100644
--- a/cloud/agent/e2e/main_test.go
+++ b/cloud/agent/e2e/main_test.go
@@ -184,7 +184,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	initramfsOrigPath, err := runfiles.Rlocation("_main/cloud/agent/initramfs.cpio.zst")
+	initramfsOrigPath, err := runfiles.Rlocation("_main/cloud/agent/takeover/initramfs.cpio.zst")
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/cloud/agent/takeover/BUILD.bazel b/cloud/agent/takeover/BUILD.bazel
new file mode 100644
index 0000000..855621a
--- /dev/null
+++ b/cloud/agent/takeover/BUILD.bazel
@@ -0,0 +1,79 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
+load("//metropolis/node/build/mkucode:def.bzl", "cpio_ucode")
+load("//build/static_binary_tarball:def.bzl", "static_binary_tarball")
+load("//metropolis/node/build:def.bzl", "node_initramfs")
+load("//metropolis/node/build/fwprune:def.bzl", "fsspec_linux_firmware")
+
+go_library(
+    name = "takeover_lib",
+    srcs = ["takeover.go"],
+    # The agent kernel, CPU microcode and agent initramfs are embedded into
+    # the takeover binary via the //go:embed directives in takeover.go.
+    embedsrcs = [
+        "//third_party/linux",  #keep
+        ":ucode",  #keep
+        ":initramfs",  #keep
+    ],
+    importpath = "source.monogon.dev/cloud/agent/takeover",
+    visibility = ["//visibility:private"],
+    deps = [
+        "//cloud/agent/api",
+        "//metropolis/pkg/bootparam",
+        "//metropolis/pkg/kexec",
+        "//net/dump",
+        "//net/proto",
+        "@com_github_cavaliergopher_cpio//:cpio",
+        "@com_github_klauspost_compress//zstd",
+        "@org_golang_google_protobuf//proto",
+        "@org_golang_x_sys//unix",
+    ],
+)
+
+# Initramfs of the BMaaS agent, with the agent binary as /init. takeover
+# stages it via kexec, preceded by the microcode archive below.
+node_initramfs(
+    name = "initramfs",
+    files = {
+        "//cloud/agent:agent": "/init",
+        "@com_github_coredns_coredns//:coredns": "/kubernetes/bin/coredns",
+        "//metropolis/node/core/network/dns:resolv.conf": "/etc/resolv.conf",
+        "@cacerts//file": "/etc/ssl/cert.pem",
+    },
+    fsspecs = [
+        "//metropolis/node/build:earlydev.fsspec",
+        ":firmware",
+    ],
+    # Also consumed by //cloud/agent/e2e.
+    visibility = ["//cloud/agent:__subpackages__"],
+)
+
+go_binary(
+    name = "takeover",
+    embed = [":takeover_lib"],
+    visibility = ["//visibility:public"],
+)
+
+# CPU microcode for AMD and Intel as an uncompressed cpio archive, placed
+# in front of the agent initramfs by takeover.
+cpio_ucode(
+    name = "ucode",
+    ucode = {
+        "@linux-firmware//:amd_ucode": "AuthenticAMD",
+        "@intel_ucode//:fam6h": "GenuineIntel",
+    },
+)
+
+# Linux firmware files shipped in the agent initramfs via its fsspecs.
+fsspec_linux_firmware(
+    name = "firmware",
+    firmware_files = ["@linux-firmware//:all_files"],
+    kernel = "//third_party/linux",
+    metadata = "@linux-firmware//:metadata",
+)
+
+# Used by container_images, forces a static build of the takeover binary.
+static_binary_tarball(
+    name = "takeover_layer",
+    executable = ":takeover",
+    visibility = ["//visibility:public"],
+)
diff --git a/cloud/agent/takeover/e2e/BUILD.bazel b/cloud/agent/takeover/e2e/BUILD.bazel
new file mode 100644
index 0000000..1cdd840
--- /dev/null
+++ b/cloud/agent/takeover/e2e/BUILD.bazel
@@ -0,0 +1,24 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+# End-to-end test: boots a Debian cloud image in QEMU (requires KVM on the
+# host), uploads the takeover binary over SFTP, runs it and waits for the
+# BMaaS agent to report startup on the serial console.
+go_test(
+    name = "e2e_test",
+    srcs = ["main_test.go"],
+    # Runfiles located at runtime via runfiles.Rlocation in main_test.go.
+    data = [
+        "//cloud/agent/takeover",
+        "//third_party/edk2:firmware",
+        "@debian_11_cloudimage//file",
+    ],
+    deps = [
+        "//cloud/agent/api",
+        "//metropolis/pkg/fat32",
+        "//metropolis/pkg/freeport",
+        "@com_github_pkg_sftp//:sftp",
+        "@io_bazel_rules_go//go/runfiles:go_default_library",
+        "@org_golang_google_protobuf//proto",
+        "@org_golang_x_crypto//ssh",
+    ],
+)
diff --git a/cloud/agent/takeover/e2e/main_test.go b/cloud/agent/takeover/e2e/main_test.go
new file mode 100644
index 0000000..6d489eb
--- /dev/null
+++ b/cloud/agent/takeover/e2e/main_test.go
@@ -0,0 +1,251 @@
+package e2e
+
+import (
+	"bufio"
+	"bytes"
+	"crypto/ed25519"
+	"crypto/rand"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"os/exec"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/bazelbuild/rules_go/go/runfiles"
+	"github.com/pkg/sftp"
+	"golang.org/x/crypto/ssh"
+	"google.golang.org/protobuf/proto"
+
+	"source.monogon.dev/cloud/agent/api"
+	"source.monogon.dev/metropolis/pkg/fat32"
+	"source.monogon.dev/metropolis/pkg/freeport"
+)
+
+// TestE2E boots a Debian cloud image in QEMU, uploads the takeover binary
+// into it over SFTP, runs it with a TakeoverInit message over SSH and checks
+// that it reports success and that the BMaaS agent then starts in the VM.
+func TestE2E(t *testing.T) {
+	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sshPubKey, err := ssh.NewPublicKey(pubKey)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sshPrivkey, err := ssh.NewSignerFromKey(privKey)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// CloudConfig doesn't really have a rigid spec, so just put things into it
+	cloudConfig := make(map[string]any)
+	cloudConfig["ssh_authorized_keys"] = []string{
+		strings.TrimSuffix(string(ssh.MarshalAuthorizedKey(sshPubKey)), "\n"),
+	}
+
+	userData, err := json.Marshal(cloudConfig)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Build a cloud-init seed image (FAT filesystem labeled "cidata") which
+	// hands the SSH key above to the guest.
+	rootInode := fat32.Inode{
+		Attrs: fat32.AttrDirectory,
+		Children: []*fat32.Inode{
+			{
+				Name:    "user-data",
+				Content: strings.NewReader("#cloud-config\n" + string(userData)),
+			},
+			{
+				Name:    "meta-data",
+				Content: strings.NewReader(""),
+			},
+		},
+	}
+	cloudInitDataFile, err := os.CreateTemp("", "cidata*.img")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(cloudInitDataFile.Name())
+	if err := fat32.WriteFS(cloudInitDataFile, rootInode, fat32.Options{Label: "cidata"}); err != nil {
+		t.Fatal(err)
+	}
+	// Close the image before handing it to QEMU.
+	if err := cloudInitDataFile.Close(); err != nil {
+		t.Fatal(err)
+	}
+	cloudImagePath, err := runfiles.Rlocation("debian_11_cloudimage/file/downloaded")
+	if err != nil {
+		t.Fatal(err)
+	}
+	ovmfVarsPath, err := runfiles.Rlocation("edk2/OVMF_VARS.fd")
+	if err != nil {
+		t.Fatal(err)
+	}
+	ovmfCodePath, err := runfiles.Rlocation("edk2/OVMF_CODE.fd")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sshPort, sshPortCloser, err := freeport.AllocateTCPPort()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	qemuArgs := []string{
+		"-machine", "q35", "-accel", "kvm", "-nographic", "-nodefaults", "-m", "1024",
+		"-cpu", "host", "-smp", "sockets=1,cpus=1,cores=2,threads=2,maxcpus=4",
+		"-drive", "if=pflash,format=raw,readonly=on,file=" + ovmfCodePath,
+		"-drive", "if=pflash,format=raw,snapshot=on,file=" + ovmfVarsPath,
+		"-drive", "if=virtio,format=qcow2,snapshot=on,cache=unsafe,file=" + cloudImagePath,
+		"-drive", "if=virtio,format=raw,snapshot=on,file=" + cloudInitDataFile.Name(),
+		"-netdev", fmt.Sprintf("user,id=net0,net=10.42.0.0/24,dhcpstart=10.42.0.10,hostfwd=tcp::%d-:22", sshPort),
+		"-device", "virtio-net-pci,netdev=net0,mac=22:d5:8e:76:1d:07",
+		"-device", "virtio-rng-pci",
+		"-serial", "stdio",
+		"-no-reboot",
+	}
+	qemuCmd := exec.Command("qemu-system-x86_64", qemuArgs...)
+	qemuCmd.Stderr = os.Stderr
+	stdoutPipe, err := qemuCmd.StdoutPipe()
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Release the reserved port as late as possible so that QEMU can bind it
+	// for the SSH port forward.
+	sshPortCloser.Close()
+	if err := qemuCmd.Start(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Log the serial console until QEMU exits. agentStarted is closed once the
+	// agent announces itself; qemuDone is closed once QEMU has exited and all
+	// of its console output has been logged.
+	agentStarted := make(chan struct{})
+	qemuDone := make(chan struct{})
+	go func() {
+		defer close(qemuDone)
+		started := false
+		s := bufio.NewScanner(stdoutPipe)
+		for s.Scan() {
+			t.Log("kernel: " + s.Text())
+			if !started && strings.Contains(s.Text(), "Monogon BMaaS Agent started") {
+				started = true
+				close(agentStarted)
+			}
+		}
+		qemuCmd.Wait()
+	}()
+	defer func() {
+		qemuCmd.Process.Kill()
+		// Wait for the console logger so that it never calls t.Log after the
+		// test has returned, which would panic.
+		<-qemuDone
+	}()
+
+	// Wait for the guest's SSH server to come up, giving up if QEMU exits or
+	// the guest stays unreachable for too long.
+	sshDeadline := time.Now().Add(5 * time.Minute)
+	var c *ssh.Client
+	for {
+		c, err = ssh.Dial("tcp", net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)), &ssh.ClientConfig{
+			User:            "debian",
+			Auth:            []ssh.AuthMethod{ssh.PublicKeys(sshPrivkey)},
+			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+			Timeout:         5 * time.Second,
+		})
+		if err == nil {
+			break
+		}
+		if time.Now().After(sshDeadline) {
+			t.Fatalf("giving up connecting via SSH: %v", err)
+		}
+		t.Logf("error connecting via SSH, retrying: %v", err)
+		select {
+		case <-qemuDone:
+			t.Fatal("QEMU exited before SSH became reachable")
+		case <-time.After(1 * time.Second):
+		}
+	}
+	defer c.Close()
+	sc, err := sftp.NewClient(c)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sc.Close()
+	takeoverFile, err := sc.Create("takeover")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer takeoverFile.Close()
+	if err := takeoverFile.Chmod(0o755); err != nil {
+		t.Fatal(err)
+	}
+	takeoverPath, err := runfiles.Rlocation("_main/cloud/agent/takeover/takeover_/takeover")
+	if err != nil {
+		t.Fatal(err)
+	}
+	takeoverSrcFile, err := os.Open(takeoverPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer takeoverSrcFile.Close()
+	if _, err := io.Copy(takeoverFile, takeoverSrcFile); err != nil {
+		t.Fatal(err)
+	}
+	if err := takeoverFile.Close(); err != nil {
+		t.Fatal(err)
+	}
+	sc.Close()
+
+	sess, err := c.NewSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sess.Close()
+
+	takeoverInit := api.TakeoverInit{
+		MachineId:     "test",
+		BmaasEndpoint: "localhost:1234",
+	}
+	initRaw, err := proto.Marshal(&takeoverInit)
+	if err != nil {
+		t.Fatal(err)
+	}
+	sess.Stdin = bytes.NewReader(initRaw)
+	var stdoutBuf bytes.Buffer
+	var stderrBuf bytes.Buffer
+	sess.Stdout = &stdoutBuf
+	sess.Stderr = &stderrBuf
+	if err := sess.Run("sudo ./takeover"); err != nil {
+		t.Fatalf("running takeover failed: %v\nstderr:\n%s", err, stderrBuf.String())
+	}
+	var resp api.TakeoverResponse
+	if err := proto.Unmarshal(stdoutBuf.Bytes(), &resp); err != nil {
+		t.Fatal(err)
+	}
+	switch res := resp.Result.(type) {
+	case *api.TakeoverResponse_Success:
+		if res.Success.GetInitMessage().GetBmaasEndpoint() != takeoverInit.BmaasEndpoint {
+			t.Error("InitMessage not passed through properly")
+		}
+	case *api.TakeoverResponse_Error:
+		t.Fatalf("takeover returned error: %v", res.Error.Message)
+	default:
+		t.Fatalf("takeover returned no known result: %T", res)
+	}
+	select {
+	case <-agentStarted:
+		// Done, test passed
+	case <-time.After(30 * time.Second):
+		t.Fatal("Waiting for BMaaS agent startup timed out")
+	}
+}
diff --git a/cloud/agent/takeover/takeover.go b/cloud/agent/takeover/takeover.go
new file mode 100644
index 0000000..d313174
--- /dev/null
+++ b/cloud/agent/takeover/takeover.go
@@ -0,0 +1,293 @@
+// takeover is a self-contained executable which when executed loads the BMaaS
+// agent via kexec. It is intended to be called over SSH, given a binary
+// TakeoverInit message over standard input, and responds with a
+// TakeoverResponse on standard output.
+// If all preparation work completed successfully, the response carries a
+// TakeoverSuccess and the new kernel and agent initramfs are fully staged by
+// the current kernel. The second stage, which is also part of this binary and
+// selected by an environment variable, is then executed in detached mode and
+// the main takeover binary called over SSH terminates.
+// The second stage waits for 5 seconds for the main binary to exit, the SSH
+// session to be torn down and various other things before issuing the final
+// non-returning syscall which jumps into the new kernel.
+// If preparation fails, the response carries a TakeoverError instead and no
+// second stage is started.
+package main
+
+import (
+	"bytes"
+	"crypto/ed25519"
+	"crypto/rand"
+	_ "embed"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"os"
+	"os/exec"
+	"syscall"
+	"time"
+
+	"github.com/cavaliergopher/cpio"
+	"github.com/klauspost/compress/zstd"
+	"golang.org/x/sys/unix"
+	"google.golang.org/protobuf/proto"
+
+	"source.monogon.dev/cloud/agent/api"
+	"source.monogon.dev/metropolis/pkg/bootparam"
+	"source.monogon.dev/metropolis/pkg/kexec"
+	netdump "source.monogon.dev/net/dump"
+	netapi "source.monogon.dev/net/proto"
+)
+
+//go:embed third_party/linux/bzImage
+var kernel []byte
+
+//go:embed ucode.cpio
+var ucode []byte
+
+//go:embed initramfs.cpio.zst
+var initramfs []byte
+
+// newMemfile creates a new file which is not located on a specific filesystem,
+// but is instead backed by anonymous memory.
+func newMemfile(name string, flags int) (*os.File, error) {
+	fd, err := unix.MemfdCreate(name, flags)
+	if err != nil {
+		return nil, fmt.Errorf("memfd_create failed: %w", err)
+	}
+	return os.NewFile(uintptr(fd), name), nil
+}
+
+// appendAgentInit writes a zstd-compressed cpio archive containing
+// agentInitRaw as /init.pb to w. setupTakeover writes it right after the
+// agent initramfs, which is how the AgentInit message reaches the agent.
+func appendAgentInit(w io.Writer, agentInitRaw []byte) error {
+	compressedW, err := zstd.NewWriter(w, zstd.WithEncoderLevel(1))
+	if err != nil {
+		return fmt.Errorf("while creating zstd writer: %w", err)
+	}
+	cpioW := cpio.NewWriter(compressedW)
+	if err := cpioW.WriteHeader(&cpio.Header{
+		Name: "/init.pb",
+		Size: int64(len(agentInitRaw)),
+		Mode: cpio.TypeReg | 0o644,
+	}); err != nil {
+		compressedW.Close()
+		return fmt.Errorf("while writing cpio header: %w", err)
+	}
+	if _, err := cpioW.Write(agentInitRaw); err != nil {
+		compressedW.Close()
+		return fmt.Errorf("while writing cpio file contents: %w", err)
+	}
+	if err := cpioW.Close(); err != nil {
+		compressedW.Close()
+		return fmt.Errorf("while closing cpio writer: %w", err)
+	}
+	if err := compressedW.Close(); err != nil {
+		return fmt.Errorf("while closing zstd writer: %w", err)
+	}
+	return nil
+}
+
+// setupTakeover reads a TakeoverInit message from standard input, builds the
+// AgentInit message for the agent (including a freshly generated Ed25519 key
+// and the current network configuration) and stages the agent kernel and
+// initramfs via kexec. On success it returns the TakeoverSuccess to report to
+// the caller, with non-fatal problems listed as warnings.
+func setupTakeover() (*api.TakeoverSuccess, error) {
+	// Read init specification from stdin.
+	initRaw, err := io.ReadAll(os.Stdin)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read TakeoverInit message from stdin: %w", err)
+	}
+	var takeoverInit api.TakeoverInit
+	if err := proto.Unmarshal(initRaw, &takeoverInit); err != nil {
+		return nil, fmt.Errorf("failed to parse TakeoverInit message from stdin: %w", err)
+	}
+
+	// Sanity check for empty TakeoverInit messages
+	if takeoverInit.BmaasEndpoint == "" {
+		return nil, errors.New("BMaaS endpoint is empty, check that a proper TakeoverInit message has been provided")
+	}
+
+	// Load data from embedded files into memfiles as the kexec load syscall
+	// requires file descriptors. Nothing uses them after kexec.FileLoad below,
+	// so they are closed when this function returns.
+	kernelFile, err := newMemfile("kernel", 0)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create kernel memfile: %w", err)
+	}
+	defer kernelFile.Close()
+	initramfsFile, err := newMemfile("initramfs", 0)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create initramfs memfile: %w", err)
+	}
+	defer initramfsFile.Close()
+	if _, err := kernelFile.ReadFrom(bytes.NewReader(kernel)); err != nil {
+		return nil, fmt.Errorf("failed to read kernel into memory-backed file: %w", err)
+	}
+	// The agent's initramfs consists of the microcode archive, the agent
+	// initramfs and the AgentInit archive appended further below.
+	if _, err := initramfsFile.ReadFrom(bytes.NewReader(ucode)); err != nil {
+		return nil, fmt.Errorf("failed to read ucode into memory-backed file: %w", err)
+	}
+	if _, err := initramfsFile.ReadFrom(bytes.NewReader(initramfs)); err != nil {
+		return nil, fmt.Errorf("failed to read initramfs into memory-backed file: %w", err)
+	}
+
+	// Dump the current network configuration
+	netconf, warnings, err := netdump.Dump()
+	if err != nil {
+		return nil, fmt.Errorf("failed to dump network configuration: %w", err)
+	}
+
+	// Fall back to public resolvers if no nameserver is configured.
+	if len(netconf.Nameserver) == 0 {
+		netconf.Nameserver = []*netapi.Nameserver{{
+			Ip: "8.8.8.8",
+		}, {
+			Ip: "1.1.1.1",
+		}}
+	}
+
+	// Generate agent private key
+	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		return nil, fmt.Errorf("unable to generate Ed25519 key: %w", err)
+	}
+
+	agentInit := api.AgentInit{
+		TakeoverInit:  &takeoverInit,
+		PrivateKey:    privKey,
+		NetworkConfig: netconf,
+	}
+	agentInitRaw, err := proto.Marshal(&agentInit)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal AgentInit message: %w", err)
+	}
+
+	// Append AgentInit spec to initramfs
+	if err := appendAgentInit(initramfsFile, agentInitRaw); err != nil {
+		return nil, fmt.Errorf("failed to append AgentInit to initramfs: %w", err)
+	}
+
+	agentParams := bootparam.Params{
+		bootparam.Param{Param: "quiet"},
+		bootparam.Param{Param: "init", Value: "/init"},
+	}
+
+	var customConsoles bool
+	cmdline, err := os.ReadFile("/proc/cmdline")
+	if err != nil {
+		warnings = append(warnings, fmt.Errorf("unable to read current kernel command line: %w", err))
+	} else if params, _, err := bootparam.Unmarshal(string(cmdline)); err != nil {
+		warnings = append(warnings, fmt.Errorf("unable to parse current kernel command line: %w", err))
+	} else {
+		// The existing command line is well-formed, so carry over all of its
+		// console parameters to the agent.
+		for _, p := range params {
+			if p.Param == "console" {
+				agentParams = append(agentParams, p)
+				customConsoles = true
+			}
+		}
+	}
+	if !customConsoles {
+		// Add the "default" console on x86
+		agentParams = append(agentParams, bootparam.Param{Param: "console", Value: "ttyS0,115200"})
+	}
+	agentCmdline, err := bootparam.Marshal(agentParams, "")
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal agent kernel command line: %w", err)
+	}
+	// Stage agent payload into kernel memory
+	if err := kexec.FileLoad(kernelFile, initramfsFile, agentCmdline); err != nil {
+		return nil, fmt.Errorf("failed to load kexec payload: %w", err)
+	}
+	var warningsStrs []string
+	for _, w := range warnings {
+		warningsStrs = append(warningsStrs, w.Error())
+	}
+	return &api.TakeoverSuccess{
+		InitMessage: &takeoverInit,
+		Key:         pubKey,
+		Warning:     warningsStrs,
+	}, nil
+}
+
+// Environment variable which tells the takeover binary to run the second stage
+const detachedLaunchEnv = "TAKEOVER_DETACHED_LAUNCH"
+
+// runSecondStage implements the detached second stage. It waits for the first
+// stage and the SSH session to go away and then jumps into the staged kexec
+// payload. It never returns: if the kexec reboot does not take over, the
+// process tries to log the problem to /dev/kmsg and exits non-zero.
+func runSecondStage() {
+	// Wait 5 seconds for data to be sent, connections to be closed and
+	// syncs to be executed
+	time.Sleep(5 * time.Second)
+	// Perform kexec, this will not return unless it fails
+	err := unix.Reboot(unix.LINUX_REBOOT_CMD_KEXEC)
+	msg := "takeover: reboot succeeded, but we're still running??"
+	if err != nil {
+		msg = "takeover: kexec reboot failed: " + err.Error()
+	}
+	// We have no standard output/error anymore, if this fails it's
+	// just borked. Attempt to dump the error into kmsg for manual
+	// debugging.
+	kmsg, err := os.OpenFile("/dev/kmsg", os.O_WRONLY, 0)
+	if err != nil {
+		os.Exit(2)
+	}
+	kmsg.WriteString(msg)
+	kmsg.Close()
+	os.Exit(1)
+}
+
+func main() {
+	// Check if the second stage should be executed
+	if os.Getenv(detachedLaunchEnv) == "1" {
+		runSecondStage()
+	}
+
+	var takeoverResp api.TakeoverResponse
+	res, setupErr := setupTakeover()
+	if setupErr != nil {
+		takeoverResp.Result = &api.TakeoverResponse_Error{Error: &api.TakeoverError{
+			Message: setupErr.Error(),
+		}}
+	} else {
+		takeoverResp.Result = &api.TakeoverResponse_Success{Success: res}
+	}
+	// Respond to stdout
+	takeoverRespRaw, err := proto.Marshal(&takeoverResp)
+	if err != nil {
+		log.Fatalf("failed to marshal response: %v", err)
+	}
+	if _, err := os.Stdout.Write(takeoverRespRaw); err != nil {
+		log.Fatalf("failed to write response to stdout: %v", err)
+	}
+	// Close stdout, we're done responding
+	os.Stdout.Close()
+
+	if setupErr != nil {
+		// The error has been reported to the caller and the new kexec payload
+		// has not been staged, so do not start the second stage.
+		return
+	}
+
+	// Start second stage which waits for 5 seconds while performing
+	// final cleanup. It gets its own session so that it is not part of the
+	// SSH session's process group once the first stage exits.
+	detachedCmd := exec.Command("/proc/self/exe")
+	detachedCmd.Env = []string{detachedLaunchEnv + "=1"}
+	detachedCmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
+	if err := detachedCmd.Start(); err != nil {
+		log.Fatalf("failed to launch final stage: %v", err)
+	}
+	// Release the second stage so that the first stage can cleanly terminate.
+	if err := detachedCmd.Process.Release(); err != nil {
+		log.Fatalf("error releasing final stage process: %v", err)
+	}
+}