core -> metropolis
Smalltown is now called Metropolis!
This is the first commit in a series of cleanup commits that prepare us
for an open source release. This one just moves some Bazel packages around to
follow a stricter directory layout.
All of Metropolis now lives in `//metropolis`.
All of Metropolis Node code now lives in `//metropolis/node`.
All of the main /init now lives in `//m/n/core`.
All of the Kubernetes functionality/glue now lives in `//m/n/kubernetes`.
Next steps:
- hunt down all references to Smalltown and replace them appropriately
- narrow down visibility rules
- document new code organization
- move `//build/toolchain` to `//monogon/build/toolchain`
- do another cleanup pass between `//golibs` and
`//monogon/node/{core,common}`.
- remove `//delta` and `//anubis`
Fixes T799.
Test Plan: Just a very large refactor. CI should help us out here.
Bug: T799
X-Origin-Diff: phab/D667
GitOrigin-RevId: 6029b8d4edc42325d50042596b639e8b122d0ded
diff --git a/metropolis/node/BUILD.bazel b/metropolis/node/BUILD.bazel
new file mode 100644
index 0000000..48c9177
--- /dev/null
+++ b/metropolis/node/BUILD.bazel
@@ -0,0 +1,141 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//metropolis/node/build:def.bzl", "smalltown_initramfs")
+
+go_library(
+ name = "go_default_library",
+ srcs = ["ports.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node",
+ visibility = ["//visibility:public"],
+)
+
+# debug_build checks if we're building in debug mode and enables various debug features for the image. Currently this
+# is only used for attaching a Delve debugger to init when it's enabled.
+config_setting(
+ name = "debug_build",
+ values = {
+ "compilation_mode": "dbg",
+ },
+)
+
+smalltown_initramfs(
+ name = "initramfs",
+ extra_dirs = [
+ "/kubernetes/conf/flexvolume-plugins",
+ "/containerd/run",
+ ],
+ files = {
+ "//metropolis/node/core": "/init",
+ "//third_party/xfsprogs:mkfs.xfs": "/bin/mkfs.xfs",
+
+ # CA Certificate bundle & os-release
+ "@cacerts//file": "/etc/ssl/cert.pem",
+ ":os-release-info": "/etc/os-release",
+
+ # Hyperkube
+ "//metropolis/node/kubernetes/hyperkube": "/kubernetes/bin/kube",
+
+ # CoreDNS
+ "@com_github_coredns_coredns//:coredns": "/kubernetes/bin/coredns",
+
+ # runsc/gVisor
+ "@com_github_google_gvisor//runsc": "/containerd/bin/runsc",
+ "@com_github_google_gvisor_containerd_shim//cmd/containerd-shim-runsc-v1": "/containerd/bin/containerd-shim-runsc-v1",
+
+ # runc (runtime in files_cc because of cgo)
+ "@com_github_containerd_containerd//cmd/containerd-shim-runc-v2": "/containerd/bin/containerd-shim-runc-v2",
+
+ # Containerd
+ "@com_github_containerd_containerd//cmd/containerd": "/containerd/bin/containerd",
+
+ # Containerd config files
+ "//metropolis/node/kubernetes/containerd:runsc.toml": "/containerd/conf/runsc.toml",
+ "//metropolis/node/kubernetes/containerd:config.toml": "/containerd/conf/config.toml",
+ "//metropolis/node/kubernetes/containerd:cnispec.gojson": "/containerd/conf/cnispec.gojson",
+
+ # Containerd preseed bundles
+ "//metropolis/test/e2e/preseedtest:preseedtest.tar": "/containerd/preseed/k8s.io/preseedtest.tar",
+ "//metropolis/test/e2e/k8s_cts:k8s_cts_image.tar": "/containerd/preseed/k8s.io/k8s_cts.tar",
+
+ # CNI Plugins
+ "@com_github_containernetworking_plugins//plugins/main/loopback": "/containerd/bin/cni/loopback",
+ "@com_github_containernetworking_plugins//plugins/main/ptp": "/containerd/bin/cni/ptp",
+ "@com_github_containernetworking_plugins//plugins/ipam/host-local": "/containerd/bin/cni/host-local",
+
+ # Delve
+ "@com_github_go_delve_delve//cmd/dlv:dlv": "/dlv",
+ },
+ files_cc = {
+ # runc runtime, with cgo
+ "@com_github_opencontainers_runc//:runc": "/containerd/bin/runc",
+ },
+)
+
+genrule(
+ name = "image",
+ srcs = [
+ "//third_party/linux:bzImage",
+ ":initramfs",
+ ],
+ outs = [
+ "smalltown.img",
+ ],
+ cmd = """
+ $(location //metropolis/node/build/mkimage) \
+ -efi $(location //third_party/linux:bzImage) \
+ -initramfs $(location :initramfs) \
+ -out $@
+ """,
+ tools = [
+ "//metropolis/node/build/mkimage",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+genrule(
+ name = "swtpm_data",
+ outs = [
+ "tpm/tpm2-00.permall",
+ "tpm/signkey.pem",
+ "tpm/issuercert.pem",
+ ],
+ cmd = """
+ mkdir -p tpm/ca
+
+ cat <<EOF > tpm/swtpm.conf
+create_certs_tool= /usr/share/swtpm/swtpm-localca
+create_certs_tool_config = tpm/swtpm-localca.conf
+create_certs_tool_options = /etc/swtpm-localca.options
+EOF
+
+ cat <<EOF > tpm/swtpm-localca.conf
+statedir = tpm/ca
+signingkey = tpm/ca/signkey.pem
+issuercert = tpm/ca/issuercert.pem
+certserial = tpm/ca/certserial
+EOF
+
+ swtpm_setup \
+ --tpmstate tpm \
+ --create-ek-cert \
+ --create-platform-cert \
+ --allow-signing \
+ --tpm2 \
+ --display \
+ --pcr-banks sha1,sha256,sha384,sha512 \
+ --config tpm/swtpm.conf
+
+ cp tpm/tpm2-00.permall $(location tpm/tpm2-00.permall)
+ cp tpm/ca/issuercert.pem $(location tpm/issuercert.pem)
+ cp tpm/ca/signkey.pem $(location tpm/signkey.pem)
+ """,
+ visibility = ["//visibility:public"],
+)
+
+load("//metropolis/node/build/genosrelease:defs.bzl", "os_release")
+
+os_release(
+ name = "os-release-info",
+ os_id = "smalltown",
+ os_name = "Smalltown",
+ stamp_var = "STABLE_SIGNOS_version",
+)
diff --git a/metropolis/node/build/BUILD b/metropolis/node/build/BUILD
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/metropolis/node/build/BUILD
diff --git a/metropolis/node/build/def.bzl b/metropolis/node/build/def.bzl
new file mode 100644
index 0000000..e2885e5
--- /dev/null
+++ b/metropolis/node/build/def.bzl
@@ -0,0 +1,257 @@
+# Copyright 2020 The Monogon Project Authors.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+def _build_pure_transition_impl(settings, attr):
+    """
+    Implementation of build_pure_transition: switch rules_go to pure, static Go builds.
+
+    Neither `settings` nor `attr` is consulted; the same two flags are always set.
+    """
+    flags = {}
+    flags["@io_bazel_rules_go//go/config:pure"] = True
+    flags["@io_bazel_rules_go//go/config:static"] = True
+    return flags
+
+# Transition that forces the rules_go `pure` and `static` settings on. It declares no inputs because its
+# implementation does not read any existing build setting.
+build_pure_transition = transition(
+    outputs = [
+        "@io_bazel_rules_go//go/config:pure",
+        "@io_bazel_rules_go//go/config:static",
+    ],
+    inputs = [],
+    implementation = _build_pure_transition_impl,
+)
+
+def _build_static_transition_impl(settings, attr):
+    """
+    Implementation of build_static_transition: static Go builds with CGo, compiled against musl.
+
+    Neither `settings` nor `attr` is consulted; the same two options are always set.
+    """
+    options = {}
+    options["@io_bazel_rules_go//go/config:static"] = True
+    options["//command_line_option:crosstool_top"] = "//build/toolchain/musl-host-gcc:musl_host_cc_suite"
+    return options
+
+# Transition that turns on the rules_go `static` setting and points --crosstool_top at the musl host toolchain.
+# It declares no inputs because its implementation does not read any existing build setting.
+build_static_transition = transition(
+    outputs = [
+        "@io_bazel_rules_go//go/config:static",
+        "//command_line_option:crosstool_top",
+    ],
+    inputs = [],
+    implementation = _build_static_transition_impl,
+)
+
+def _initramfs_directories(ctx):
+    """
+    Collect, in first-seen order, every directory that must be created inside the initramfs.
+
+    This covers all parent directories of the destinations in `files` and `files_cc`, plus every directory in
+    `extra_dirs` together with its parents. Fails the build if any of these paths does not begin with "/".
+    """
+    directories_needed = []
+    for p in ctx.attr.files.values() + ctx.attr.files_cc.values():
+        if not p.startswith("/"):
+            fail("file {} invalid: must begin with /".format(p))
+
+        # Keep only the intermediate directories: drop the empty component before the leading "/" and the
+        # file name itself.
+        directories_needed.append(p.split("/")[1:-1])
+
+    # Extend with extra directories defined by user. Here the last component is a directory too, so keep it.
+    for p in ctx.attr.extra_dirs:
+        if not p.startswith("/"):
+            fail("directory {} invalid: must begin with /".format(p))
+        directories_needed.append(p.split("/")[1:])
+
+    directories = []
+    for parts in directories_needed:
+        # Turn directory parts [usr, local, bin] into successive subpaths [/usr, /usr/local, /usr/local/bin].
+        last = ""
+        for part in parts:
+            last += "/" + part
+
+            # TODO(q3k): this is slow - this should be a set instead, but starlark doesn't implement them.
+            # For the amount of files we're dealing with this doesn't matter, but all stars are pointing towards
+            # this becoming accidentally quadratic at some point in the future.
+            if last not in directories:
+                directories.append(last)
+    return directories
+
+def _initramfs_file_entries(files):
+    """
+    Translate a label -> destination path dictionary into gen_init_cpio `file` lines.
+
+    Returns a tuple (entries, inputs): `entries` is the list of cpio list lines and `inputs` the list of Files
+    those lines refer to. Fails the build if a label does not provide exactly one output file.
+    """
+    entries = []
+    inputs = []
+    for label, p in files.items():
+        di = label[DefaultInfo]
+
+        # Figure out if this is an executable. Generated non-executable files have
+        # DefaultInfo.files_to_run.executable == None, and source files have executable.is_source == True;
+        # neither of those is installed as executable.
+        executable = di.files_to_run.executable
+        is_executable = executable != None and not executable.is_source
+
+        # Ensure exactly one output is declared.
+        # If you hit this error, figure out a better logic to find what file you need, maybe looking at
+        # providers other than DefaultInfo.
+        outputs = di.files.to_list()
+        if len(outputs) != 1:
+            fail("file {} must have exactly one output, got {}: {}".format(p, len(outputs), outputs))
+        src = outputs[0]
+        inputs.append(src)
+
+        mode = "0755" if is_executable else "0444"
+        entries.append("file {} {} {} 0 0".format(p, src.path, mode))
+    return entries, inputs
+
+def _smalltown_initramfs_impl(ctx):
+    """
+    Generate an lz4-compressed initramfs based on a label/file list.
+
+    Writes a gen_init_cpio description (`<name>.cpio_list`), runs gen_init_cpio on it to produce the raw archive
+    (`<name>`), and compresses that into `<name>.lz4`, which is the only file exposed through DefaultInfo.
+    """
+
+    # Generate config file for gen_init_cpio that describes the initramfs to build.
+    cpio_list = ctx.actions.declare_file(ctx.label.name + ".cpio_list")
+
+    # Start out with some standard initramfs device files.
+    cpio_list_content = [
+        "dir /dev 0755 0 0",
+        "nod /dev/console 0600 0 0 c 5 1",
+        "nod /dev/null 0644 0 0 c 1 3",
+        "nod /dev/kmsg 0644 0 0 c 1 11",
+        "nod /dev/ptmx 0644 0 0 c 5 2",
+    ]
+
+    # Append instructions to create directories.
+    # Serendipitously, the directories are already in the right order (parents before children) due to us not
+    # using a set to create the list. They might not be in an elegant order (ie, if files
+    # [/foo/one/one, /bar, /foo/two/two] are requested, the order will be [/foo, /foo/one, /bar, /foo/two]), but
+    # that's fine.
+    for d in _initramfs_directories(ctx):
+        cpio_list_content.append("dir {} 0755 0 0".format(d))
+
+    # Append instructions to add files: first the entries from `files`, then those from `files_cc`.
+    inputs = []
+    for files in [ctx.attr.files, ctx.attr.files_cc]:
+        entries, srcs = _initramfs_file_entries(files)
+        cpio_list_content.extend(entries)
+        inputs.extend(srcs)
+
+    # Write cpio_list.
+    ctx.actions.write(cpio_list, "\n".join(cpio_list_content))
+
+    gen_init_cpio = ctx.executable._gen_init_cpio
+    savestdout = ctx.executable._savestdout
+    lz4 = ctx.executable._lz4
+
+    # Generate 'raw' (uncompressed) initramfs.
+    # savestdout receives the output path followed by the command line to run; presumably it stores that
+    # command's standard output in the output file (see //build/savestdout to confirm).
+    initramfs_raw = ctx.actions.declare_file(ctx.label.name)
+    ctx.actions.run(
+        outputs = [initramfs_raw],
+        inputs = [cpio_list] + inputs,
+        tools = [savestdout, gen_init_cpio],
+        executable = savestdout,
+        arguments = [initramfs_raw.path, gen_init_cpio.path, cpio_list.path],
+    )
+
+    # Compress raw initramfs using lz4c. `-l` selects lz4's legacy frame format.
+    initramfs = ctx.actions.declare_file(ctx.label.name + ".lz4")
+    ctx.actions.run(
+        outputs = [initramfs],
+        inputs = [initramfs_raw],
+        tools = [savestdout, lz4],
+        executable = lz4.path,
+        arguments = ["-l", initramfs_raw.path, initramfs.path],
+    )
+
+    return [DefaultInfo(files = depset([initramfs]))]
+
+smalltown_initramfs = rule(
+ implementation = _smalltown_initramfs_impl,
+ doc = """
+ Build a Smalltown initramfs. The initramfs will contain a basic /dev directory and all the files specified by the
+ `files` attribute. Executable files will have their permissions set to 0755, non-executable files will have
+ their permissions set to 0444. All parent directories will be created with 0755 permissions.
+ """,
+ attrs = {
+ "files": attr.label_keyed_string_dict(
+ mandatory = True,
+ allow_files = True,
+ doc = """
+ Dictionary of Labels to String, placing a given Label's output file in the initramfs at the location
+ specified by the String value. The specified labels must only have a single output.
+ """,
+ # Attach pure transition to ensure all binaries added to the initramfs are pure/static binaries.
+ cfg = build_pure_transition,
+ ),
+ "files_cc": attr.label_keyed_string_dict(
+ allow_files = True,
+ doc = """
+ Special case of 'files' for compilation targets that need to be built with the musl toolchain like
+ go_binary targets which need cgo or cc_binary targets.
+ """,
+ # Attach static transition to all files_cc inputs to ensure they are built with musl and static.
+ cfg = build_static_transition,
+ ),
+ "extra_dirs": attr.string_list(
+ default = [],
+ doc = """
+ Extra directories to create. These will be created in addition to all the directories required to
+ contain the files specified in the `files` attribute.
+ """,
+ ),
+
+ # Tools, implicit dependencies.
+ "_gen_init_cpio": attr.label(
+ default = Label("@linux//:gen_init_cpio"),
+ executable = True,
+ cfg = "host",
+ ),
+ "_lz4": attr.label(
+ default = Label("@com_github_lz4_lz4//programs:lz4"),
+ executable = True,
+ cfg = "host",
+ ),
+ "_savestdout": attr.label(
+ default = Label("//build/savestdout"),
+ executable = True,
+ cfg = "host",
+ ),
+
+ # Allow for transitions to be attached to this rule.
+ "_whitelist_function_transition": attr.label(
+ default = "@bazel_tools//tools/whitelists/function_transition_whitelist",
+ ),
+ },
+)
diff --git a/metropolis/node/build/genosrelease/BUILD.bazel b/metropolis/node/build/genosrelease/BUILD.bazel
new file mode 100644
index 0000000..9403d72
--- /dev/null
+++ b/metropolis/node/build/genosrelease/BUILD.bazel
@@ -0,0 +1,15 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = ["main.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/build/genosrelease",
+ visibility = ["//visibility:private"],
+ deps = ["@com_github_joho_godotenv//:go_default_library"],
+)
+
+go_binary(
+ name = "genosrelease",
+ embed = [":go_default_library"],
+ visibility = ["//visibility:public"],
+)
diff --git a/metropolis/node/build/genosrelease/defs.bzl b/metropolis/node/build/genosrelease/defs.bzl
new file mode 100644
index 0000000..61ce9e4
--- /dev/null
+++ b/metropolis/node/build/genosrelease/defs.bzl
@@ -0,0 +1,54 @@
+# Copyright 2020 The Monogon Project Authors.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+def _os_release_impl(ctx):
+    """
+    Run genosrelease to produce the rule's `out` file.
+
+    The tool is handed the workspace status file (ctx.info_file), the output path, and the rule's `stamp_var`,
+    `os_name` and `os_id` attributes as command line flags.
+    """
+    status_file = ctx.info_file
+    out_file = ctx.outputs.out
+
+    genosrelease_args = ["-status_file", status_file.path]
+    genosrelease_args += ["-out_file", out_file.path]
+    genosrelease_args += ["-stamp_var", ctx.attr.stamp_var]
+    genosrelease_args += ["-name", ctx.attr.os_name]
+    genosrelease_args += ["-id", ctx.attr.os_id]
+
+    ctx.actions.run(
+        executable = ctx.executable._genosrelease,
+        arguments = genosrelease_args,
+        inputs = [status_file],
+        outputs = [out_file],
+        mnemonic = "GenOSRelease",
+        progress_message = "Generating os-release",
+    )
+
+os_release = rule(
+ implementation = _os_release_impl,
+ attrs = {
+ "os_name": attr.string(mandatory = True),
+ "os_id": attr.string(mandatory = True),
+ "stamp_var": attr.string(mandatory = True),
+ "_genosrelease": attr.label(
+ default = Label("//metropolis/node/build/genosrelease"),
+ cfg = "host",
+ executable = True,
+ allow_files = True,
+ ),
+ },
+ outputs = {
+ "out": "os-release",
+ },
+)
diff --git a/metropolis/node/build/genosrelease/main.go b/metropolis/node/build/genosrelease/main.go
new file mode 100644
index 0000000..2344f19
--- /dev/null
+++ b/metropolis/node/build/genosrelease/main.go
@@ -0,0 +1,78 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// genosrelease provides rudimentary support to generate os-release files following the freedesktop spec
+// (https://www.freedesktop.org/software/systemd/man/os-release.html) from arguments and stamping
+package main
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+
+ "github.com/joho/godotenv"
+)
+
+var (
+ flagStatusFile = flag.String("status_file", "", "path to bazel workspace status file")
+ flagOutFile = flag.String("out_file", "os-release", "path to os-release output file")
+ flagStampVar = flag.String("stamp_var", "", "variable to use as version from the workspace status file")
+ flagName = flag.String("name", "", "name parameter (see freedesktop spec)")
+ flagID = flag.String("id", "", "id parameter (see freedesktop spec)")
+)
+
+// main reads the Bazel workspace status file named by -status_file, looks up the variable named by -stamp_var
+// and writes an os-release file to -out_file using that value as the version. Any failure is reported on
+// stderr and terminates the process with exit status 1.
+func main() {
+	flag.Parse()
+	statusFileContent, err := ioutil.ReadFile(*flagStatusFile)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to open bazel workspace status file: %v\n", err)
+		os.Exit(1)
+	}
+	statusVars := make(map[string]string)
+	for _, line := range strings.Split(string(statusFileContent), "\n") {
+		line = strings.TrimSpace(line)
+		// Each status line is "KEY VALUE": the key ends at the first whitespace character and the value is
+		// the remainder of the line, which may itself contain spaces. Lines without a value are skipped.
+		sep := strings.IndexAny(line, " \t")
+		if sep < 0 {
+			continue
+		}
+		statusVars[line[:sep]] = strings.TrimSpace(line[sep+1:])
+	}
+
+	smalltownVersion, ok := statusVars[*flagStampVar]
+	if !ok {
+		fmt.Fprintf(os.Stderr, "%v key not set in bazel workspace status file\n", *flagStampVar)
+		os.Exit(1)
+	}
+	// As specified by https://www.freedesktop.org/software/systemd/man/os-release.html
+	osReleaseVars := map[string]string{
+		"NAME":        *flagName,
+		"ID":          *flagID,
+		"VERSION":     smalltownVersion,
+		"VERSION_ID":  smalltownVersion,
+		"PRETTY_NAME": *flagName + " " + smalltownVersion,
+	}
+	osReleaseContent, err := godotenv.Marshal(osReleaseVars)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to encode os-release file: %v\n", err)
+		os.Exit(1)
+	}
+	if err := ioutil.WriteFile(*flagOutFile, []byte(osReleaseContent), 0644); err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to write os-release file: %v\n", err)
+		os.Exit(1)
+	}
+}
diff --git a/metropolis/node/build/kconfig-patcher/BUILD.bazel b/metropolis/node/build/kconfig-patcher/BUILD.bazel
new file mode 100644
index 0000000..55b2b52
--- /dev/null
+++ b/metropolis/node/build/kconfig-patcher/BUILD.bazel
@@ -0,0 +1,20 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library", "go_test")
+
+go_library(
+ name = "go_default_library",
+ srcs = ["main.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/build/kconfig-patcher",
+ visibility = ["//visibility:private"],
+)
+
+go_binary(
+ name = "kconfig-patcher",
+ embed = [":go_default_library"],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "go_default_test",
+ srcs = ["main_test.go"],
+ embed = [":go_default_library"],
+)
diff --git a/metropolis/node/build/kconfig-patcher/kconfig-patcher.bzl b/metropolis/node/build/kconfig-patcher/kconfig-patcher.bzl
new file mode 100644
index 0000000..337642e
--- /dev/null
+++ b/metropolis/node/build/kconfig-patcher/kconfig-patcher.bzl
@@ -0,0 +1,33 @@
+# Copyright 2020 The Monogon Project Authors.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Override configs in a Linux kernel Kconfig
+"""
+
+def kconfig_patch(name, src, out, override_configs, **kwargs):
+    """Generate a copy of a Kconfig file with selected options overridden.
+
+    Args:
+        name: name of the generated genrule target.
+        src: label of the input Kconfig file.
+        out: name of the patched output file.
+        override_configs: dict of option name to new value. It is serialized to JSON under the `overrides` key
+            and passed to kconfig-patcher as its positional argument; an empty value marks the option as
+            "is not set".
+        **kwargs: forwarded unchanged to the underlying genrule.
+    """
+    overrides_json = struct(overrides = override_configs).to_json()
+
+    # The JSON ends up inside a single-quoted shell word of a genrule command. Escape the two characters that
+    # would otherwise be interpreted on the way there: "$" (Make variable expansion performed by Bazel) and "'"
+    # (which would terminate the shell quoting).
+    overrides_arg = overrides_json.replace("$", "$$").replace("'", "'\\''")
+
+    native.genrule(
+        name = name,
+        srcs = [src],
+        outs = [out],
+        tools = [
+            "//metropolis/node/build/kconfig-patcher",
+        ],
+        cmd = """
+    $(location //metropolis/node/build/kconfig-patcher) \
+    -in $< -out $@ '%s'
+    """ % overrides_arg,
+        **kwargs
+    )
diff --git a/metropolis/node/build/kconfig-patcher/main.go b/metropolis/node/build/kconfig-patcher/main.go
new file mode 100644
index 0000000..27c33e9
--- /dev/null
+++ b/metropolis/node/build/kconfig-patcher/main.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bufio"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "os"
+ "strings"
+)
+
+var (
+ inPath = flag.String("in", "", "Path to input Kconfig")
+ outPath = flag.String("out", "", "Path to output Kconfig")
+)
+
+// main patches the Kconfig at -in into -out, applying the overrides given as a JSON object
+// ({"overrides": {"KEY": "value", ...}}) in the first positional argument. It exits with status 2 on missing
+// flags and with status 1 on any other failure.
+func main() {
+	flag.Parse()
+	if *inPath == "" || *outPath == "" {
+		flag.PrintDefaults()
+		os.Exit(2)
+	}
+
+	// Validate the overrides before touching the filesystem, so that a malformed argument does not leave an
+	// empty output file behind.
+	var config struct {
+		Overrides map[string]string `json:"overrides"`
+	}
+	if err := json.Unmarshal([]byte(flag.Arg(0)), &config); err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to parse overrides: %v\n", err)
+		os.Exit(1)
+	}
+
+	inFile, err := os.Open(*inPath)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to open input Kconfig: %v\n", err)
+		os.Exit(1)
+	}
+	// Only reached on the success path (os.Exit skips deferred calls), which is fine for a read-only file.
+	defer inFile.Close()
+
+	outFile, err := os.Create(*outPath)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to create output Kconfig: %v\n", err)
+		os.Exit(1)
+	}
+	if err := patchKconfig(inFile, outFile, config.Overrides); err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to patch: %v\n", err)
+		os.Exit(1)
+	}
+	// Close explicitly: an error here means the output may not have been written out completely.
+	if err := outFile.Close(); err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to write output Kconfig: %v\n", err)
+		os.Exit(1)
+	}
+}
+
+// patchKconfig copies the Kconfig read from inFile to outFile while applying overrides, a map from option
+// name to new value.
+//
+// Comments and blank lines are passed through untouched. Every line that sets an overridden option is replaced
+// by "KEY=value", or by "# KEY is not set" when the override value is empty. Overrides that never matched a
+// line are appended at the end in the same two forms. The overrides map is not modified.
+//
+// It returns the first error encountered while reading inFile or writing outFile.
+func patchKconfig(inFile io.Reader, outFile io.Writer, overrides map[string]string) error {
+	// emit writes a single option in Kconfig syntax.
+	emit := func(key, val string) error {
+		var err error
+		if val == "" {
+			_, err = fmt.Fprintf(outFile, "# %v is not set\n", key)
+		} else {
+			_, err = fmt.Fprintf(outFile, "%v=%v\n", key, val)
+		}
+		return err
+	}
+
+	// applied records which overrides matched at least one input line. The override stays active for later
+	// lines so that a repeated definition further down cannot silently take precedence over it.
+	applied := make(map[string]bool, len(overrides))
+
+	scanner := bufio.NewScanner(inFile)
+	for scanner.Scan() {
+		line := scanner.Text()
+		cleanLine := strings.TrimSpace(line)
+		if cleanLine == "" || strings.HasPrefix(cleanLine, "#") {
+			// Pass through comments and empty lines
+			if _, err := fmt.Fprintln(outFile, line); err != nil {
+				return err
+			}
+			continue
+		}
+
+		// Line contains a configuration option. The key is trimmed once and used consistently for lookup,
+		// output and bookkeeping.
+		key := strings.TrimSpace(strings.SplitN(cleanLine, "=", 2)[0])
+		overrideVal, ok := overrides[key]
+		if !ok {
+			// Pass through unchanged
+			if _, err := fmt.Fprintln(outFile, line); err != nil {
+				return err
+			}
+			continue
+		}
+		if err := emit(key, overrideVal); err != nil {
+			return err
+		}
+		applied[key] = true
+	}
+	if err := scanner.Err(); err != nil {
+		return fmt.Errorf("reading input: %w", err)
+	}
+
+	// Process left over overrides.
+	// NOTE(review): map iteration order is random, so with several left over overrides the order of the
+	// appended lines differs between runs; sort the keys if reproducible output is required.
+	for key, val := range overrides {
+		if applied[key] {
+			continue
+		}
+		if err := emit(key, val); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/metropolis/node/build/kconfig-patcher/main_test.go b/metropolis/node/build/kconfig-patcher/main_test.go
new file mode 100644
index 0000000..11c7d84
--- /dev/null
+++ b/metropolis/node/build/kconfig-patcher/main_test.go
@@ -0,0 +1,61 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+)
+
+// Test_patchKconfig runs patchKconfig over small in-memory Kconfig snippets and compares the complete output.
+func Test_patchKconfig(t *testing.T) {
+	cases := []struct {
+		name      string
+		input     string
+		overrides map[string]string
+		want      string
+		wantErr   bool
+	}{
+		{
+			// The only mention of TEST is a comment, so the override is appended at the end.
+			name:      "passthroughExtend",
+			input:     "# TEST=y\n\n",
+			overrides: map[string]string{"TEST": "n"},
+			want:      "# TEST=y\n\nTEST=n\n",
+		},
+		{
+			// TEST is replaced in place; the unrelated TEST_NO line is left alone.
+			name:      "patch",
+			input:     "TEST=y\nTEST_NO=n\n",
+			overrides: map[string]string{"TEST": "n"},
+			want:      "TEST=n\nTEST_NO=n\n",
+		},
+	}
+	for _, tc := range cases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			var out bytes.Buffer
+			err := patchKconfig(strings.NewReader(tc.input), &out, tc.overrides)
+			if failed := err != nil; failed != tc.wantErr {
+				t.Errorf("patchKconfig() error = %v, wantErr %v", err, tc.wantErr)
+				return
+			}
+			if got := out.String(); got != tc.want {
+				t.Errorf("patchKconfig() = %v, want %v", got, tc.want)
+			}
+		})
+	}
+}
diff --git a/metropolis/node/build/mkimage/BUILD.bazel b/metropolis/node/build/mkimage/BUILD.bazel
new file mode 100644
index 0000000..b489002
--- /dev/null
+++ b/metropolis/node/build/mkimage/BUILD.bazel
@@ -0,0 +1,20 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = ["main.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/build/mkimage",
+ visibility = ["//visibility:private"],
+ deps = [
+ "@com_github_diskfs_go_diskfs//:go_default_library",
+ "@com_github_diskfs_go_diskfs//disk:go_default_library",
+ "@com_github_diskfs_go_diskfs//filesystem:go_default_library",
+ "@com_github_diskfs_go_diskfs//partition/gpt:go_default_library",
+ ],
+)
+
+go_binary(
+ name = "mkimage",
+ embed = [":go_default_library"],
+ visibility = ["//visibility:public"],
+)
diff --git a/metropolis/node/build/mkimage/main.go b/metropolis/node/build/mkimage/main.go
new file mode 100644
index 0000000..9f49f0a
--- /dev/null
+++ b/metropolis/node/build/mkimage/main.go
@@ -0,0 +1,141 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+// mkimage is a tool to generate a Smalltown disk image containing the given EFI payload, and optionally, a given external
+// initramfs image and enrolment credentials.
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+
+ diskfs "github.com/diskfs/go-diskfs"
+ "github.com/diskfs/go-diskfs/disk"
+ "github.com/diskfs/go-diskfs/filesystem"
+ "github.com/diskfs/go-diskfs/partition/gpt"
+)
+
+var SmalltownDataPartition gpt.Type = gpt.Type("9eeec464-6885-414a-b278-4305c51f7966")
+
+var (
+ flagEFI string
+ flagOut string
+ flagInitramfs string
+ flagEnrolmentCredentials string
+ flagDataPartitionSize uint64
+ flagESPPartitionSize uint64
+)
+
+// mibToSectors converts a size given in MiB into a number of 512-byte sectors.
+func mibToSectors(size uint64) uint64 {
+	const (
+		bytesPerMiB    = 1024 * 1024
+		bytesPerSector = 512
+	)
+	sizeBytes := size * bytesPerMiB
+	return sizeBytes / bytesPerSector
+}
+
+// main builds a raw disk image at -out containing a GPT with an EFI System Partition followed by a data
+// partition. The ESP is formatted as FAT32 and populated with the -efi payload and, when given, the
+// -initramfs and -enrolment_credentials files. Any failure is fatal.
+func main() {
+	flag.StringVar(&flagEFI, "efi", "", "UEFI payload")
+	flag.StringVar(&flagOut, "out", "", "Output disk image")
+	flag.StringVar(&flagInitramfs, "initramfs", "", "External initramfs [optional]")
+	flag.StringVar(&flagEnrolmentCredentials, "enrolment_credentials", "", "Enrolment credentials [optional]")
+	flag.Uint64Var(&flagDataPartitionSize, "data_partition_size", 2048, "Override the data partition size (default 2048 MiB)")
+	flag.Uint64Var(&flagESPPartitionSize, "esp_partition_size", 512, "Override the ESP partition size (default: 512MiB)")
+	flag.Parse()
+
+	// Only the EFI payload and the output path are mandatory; the initramfs is optional.
+	if flagEFI == "" || flagOut == "" {
+		log.Fatalf("efi and out must be set")
+	}
+
+	// Best effort removal of a stale image from a previous run; an error (e.g. the file not existing) is
+	// deliberately ignored. Presumably diskfs.Create refuses to overwrite an existing file - TODO confirm.
+	_ = os.Remove(flagOut)
+
+	// The image size is fixed at 3 GiB regardless of the partition size flags.
+	diskImg, err := diskfs.Create(flagOut, 3*1024*1024*1024, diskfs.Raw)
+	if err != nil {
+		log.Fatalf("diskfs.Create(%q): %v", flagOut, err)
+	}
+
+	table := &gpt.Table{
+		// This is appropriate at least for virtio disks. Might need to be adjusted for real ones.
+		LogicalSectorSize:  512,
+		PhysicalSectorSize: 512,
+		ProtectiveMBR:      true,
+		Partitions: []*gpt.Partition{
+			{
+				// The ESP starts at 1 MiB and ends right before the data partition.
+				Type:  gpt.EFISystemPartition,
+				Name:  "ESP",
+				Start: mibToSectors(1),
+				End:   mibToSectors(flagESPPartitionSize) - 1,
+			},
+			{
+				Type:  SmalltownDataPartition,
+				Name:  "SIGNOS-DATA",
+				Start: mibToSectors(flagESPPartitionSize),
+				End:   mibToSectors(flagESPPartitionSize+flagDataPartitionSize) - 1,
+			},
+		},
+	}
+	if err := diskImg.Partition(table); err != nil {
+		log.Fatalf("Failed to apply partition table: %v", err)
+	}
+
+	fs, err := diskImg.CreateFilesystem(disk.FilesystemSpec{Partition: 1, FSType: filesystem.TypeFat32, VolumeLabel: "ESP"})
+	if err != nil {
+		log.Fatalf("Failed to create filesystem: %v", err)
+	}
+
+	// Create EFI partition structure.
+	for _, dir := range []string{"/EFI", "/EFI/BOOT", "/EFI/smalltown"} {
+		if err := fs.Mkdir(dir); err != nil {
+			log.Fatalf("Mkdir(%q): %v", dir, err)
+		}
+	}
+
+	put(fs, flagEFI, "/EFI/BOOT/BOOTX64.EFI")
+
+	if flagInitramfs != "" {
+		put(fs, flagInitramfs, "/EFI/smalltown/initramfs.cpio.lz4")
+	}
+
+	if flagEnrolmentCredentials != "" {
+		put(fs, flagEnrolmentCredentials, "/EFI/smalltown/enrolment.pb")
+	}
+
+	if err := diskImg.File.Close(); err != nil {
+		log.Fatalf("Failed to finalize image: %v", err)
+	}
+	log.Printf("Success! You can now boot %v", flagOut)
+}
+
+// put copies a file from the host filesystem into the target image.
+// put copies the host file src into the target image filesystem at dst. Any failure is fatal: the error is
+// reported on stderr and the process exits with a non-zero status.
+func put(fs filesystem.FileSystem, src, dst string) {
+	target, err := fs.OpenFile(dst, os.O_CREATE|os.O_RDWR)
+	if err != nil {
+		log.Fatalf("fs.OpenFile(%q): %v", dst, err)
+	}
+	source, err := os.Open(src)
+	if err != nil {
+		log.Fatalf("os.Open(%q): %v", src, err)
+	}
+	defer source.Close()
+	// If this is streamed (e.g. using io.Copy) it exposes a bug in diskfs, so do it in one go.
+	data, err := ioutil.ReadAll(source)
+	if err != nil {
+		log.Fatalf("Reading %q: %v", src, err)
+	}
+	if _, err := target.Write(data); err != nil {
+		// Report on stderr, newline-terminated, like every other failure in this tool.
+		fmt.Fprintf(os.Stderr, "writing file %q: %v\n", dst, err)
+		os.Exit(1)
+	}
+}
diff --git a/metropolis/node/common/devicemapper/BUILD.bazel b/metropolis/node/common/devicemapper/BUILD.bazel
new file mode 100644
index 0000000..12ca0b3
--- /dev/null
+++ b/metropolis/node/common/devicemapper/BUILD.bazel
@@ -0,0 +1,13 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Go library for the devicemapper package, built from devicemapper.go and
+# visible to every package in the workspace.
+go_library(
+ name = "go_default_library",
+ srcs = ["devicemapper.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/common/devicemapper",
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_github_pkg_errors//:go_default_library",
+ "@com_github_yalue_native_endian//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/metropolis/node/common/devicemapper/devicemapper.go b/metropolis/node/common/devicemapper/devicemapper.go
new file mode 100644
index 0000000..2687e3a
--- /dev/null
+++ b/metropolis/node/common/devicemapper/devicemapper.go
@@ -0,0 +1,298 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package devicemapper is a thin wrapper for the devicemapper ioctl API.
+// See: https://github.com/torvalds/linux/blob/master/include/uapi/linux/dm-ioctl.h
+package devicemapper
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "os"
+ "runtime"
+ "unsafe"
+
+ "github.com/pkg/errors"
+ "github.com/yalue/native_endian"
+ "golang.org/x/sys/unix"
+)
+
+// DMIoctl is the request/response buffer passed to every device-mapper ioctl
+// issued by this package: a fixed header followed by a 16384-byte payload area.
+// NOTE(review): the header presumably mirrors struct dm_ioctl from the kernel
+// header linked in the package comment; confirm before changing field order or sizes.
+type DMIoctl struct {
+ // Version is set to 4.0.0 by newReq; GetVersion returns it after the ioctl.
+ Version Version
+ // DataSize is initialised to baseDataSize; LoadTable adds the payload length.
+ DataSize uint32
+ // DataStart is initialised to baseDataSize by newReq.
+ DataStart uint32
+ // TargetCount is set by LoadTable to the number of targets serialized into Data.
+ TargetCount uint32
+ OpenCount int32
+ // Flags carries DM_*_FLAG bits; suspendResume sets DM_SUSPEND_FLAG here.
+ Flags uint32
+ EventNumber uint32
+ _padding1 uint32
+ // Dev is what CreateDevice returns after DM_DEV_CREATE_CMD.
+ Dev uint64
+ // Name holds the NUL-terminated device name.
+ Name [128]byte
+ UUID [129]byte
+ _padding2 [7]byte
+ // Data is the payload area; LoadTable copies the serialized target table into it.
+ Data [16384]byte
+}
+
+// DMTargetSpec is the fixed-size header LoadTable serializes in front of each
+// target's NUL-terminated parameter string.
+type DMTargetSpec struct {
+ SectorStart uint64
+ Length uint64
+ Status int32
+ // Next is set by LoadTable to the size of this header plus the padded
+ // parameter string that follows it.
+ Next uint32
+ // TargetType holds the NUL-terminated target type name.
+ TargetType [16]byte
+}
+
+// DMTargetDeps is not used by any function in this file.
+// NOTE(review): presumably modelled on struct dm_target_deps; the Go slice
+// field does not have the in-memory layout of a C trailing array, so it cannot
+// be passed to the kernel as-is. Confirm before use.
+type DMTargetDeps struct {
+ Count uint32
+ Padding uint32
+ Dev []uint64
+}
+
+// DMNameList is not used by any function in this file.
+// NOTE(review): presumably modelled on struct dm_name_list; the same slice
+// layout caveat as for DMTargetDeps applies.
+type DMNameList struct {
+ Dev uint64
+ Next uint32
+ Name []byte
+}
+
+// DMTargetVersions is not used by any function in this file.
+// NOTE(review): presumably modelled on struct dm_target_versions; confirm.
+type DMTargetVersions struct {
+ Next uint32
+ Version [3]uint32
+}
+
+// DMTargetMessage is not used by any function in this file.
+// NOTE(review): presumably modelled on struct dm_target_msg; the same slice
+// layout caveat as for DMTargetDeps applies.
+type DMTargetMessage struct {
+ Sector uint64
+ Message []byte
+}
+
+// Version is the three-component version carried in DMIoctl.Version. newReq
+// sends {4, 0, 0} and GetVersion returns the value present after the ioctl.
+type Version [3]uint32
+
+// ioctl request numbers for the device-mapper control device. Each value is
+// 0xc138fd shifted left by 8 bits plus the command's position in this list,
+// so the order of the entries must not be changed.
+const (
+ /* Top level cmds */
+ DM_VERSION_CMD uintptr = (0xc138fd << 8) + iota
+ DM_REMOVE_ALL_CMD
+ DM_LIST_DEVICES_CMD
+
+ /* device level cmds */
+ DM_DEV_CREATE_CMD
+ DM_DEV_REMOVE_CMD
+ DM_DEV_RENAME_CMD
+ DM_DEV_SUSPEND_CMD
+ DM_DEV_STATUS_CMD
+ DM_DEV_WAIT_CMD
+
+ /* Table level cmds */
+ DM_TABLE_LOAD_CMD
+ DM_TABLE_CLEAR_CMD
+ DM_TABLE_DEPS_CMD
+ DM_TABLE_STATUS_CMD
+
+ /* Added later */
+ DM_LIST_VERSIONS_CMD
+ DM_TARGET_MSG_CMD
+ DM_DEV_SET_GEOMETRY_CMD
+ DM_DEV_ARM_POLL_CMD
+)
+
+// Bit flags for DMIoctl.Flags. Within this file only DM_SUSPEND_FLAG is used:
+// suspendResume sets it to suspend a device and leaves it clear to resume.
+const (
+ DM_READONLY_FLAG = 1 << 0 /* In/Out */
+ DM_SUSPEND_FLAG = 1 << 1 /* In/Out */
+ DM_PERSISTENT_DEV_FLAG = 1 << 3 /* In */
+)
+
+// baseDataSize is the size of DMIoctl without its 16384-byte Data array, i.e.
+// the size of the request header. newReq uses it for DataSize and DataStart.
+const baseDataSize = uint32(unsafe.Sizeof(DMIoctl{})) - 16384
+
+func newReq() DMIoctl {
+ return DMIoctl{
+ Version: Version{4, 0, 0},
+ DataSize: baseDataSize,
+ DataStart: baseDataSize,
+ }
+}
+
+// stringToDelimitedBuf copies src into dst as a NUL-terminated string.
+//
+// It returns an error, leaving dst untouched, if src does not fit into dst
+// together with its terminating NUL byte (len(src) > len(dst)-1) or if src
+// itself contains a NUL byte.
+func stringToDelimitedBuf(dst []byte, src string) error {
+	if len(src) > len(dst)-1 {
+		return fmt.Errorf("string longer than target buffer (%v > %v)", len(src), len(dst)-1)
+	}
+	// Validate before writing so that a rejected string never leaves a
+	// partial copy behind in dst.
+	for i := 0; i < len(src); i++ {
+		if src[i] == 0x00 {
+			return errors.New("string contains null byte, this is unsupported by DM")
+		}
+	}
+	n := copy(dst, src)
+	// Terminate explicitly instead of relying on the caller having zeroed dst.
+	dst[n] = 0x00
+	return nil
+}
+
+// controlFile is the opened /dev/mapper/control node, or nil if it has not
+// been opened yet. The *os.File itself is retained (rather than only its
+// descriptor number) because an unreachable os.File is closed by its
+// finalizer, which would invalidate a cached raw descriptor.
+var controlFile *os.File
+
+// getFd returns the file descriptor of the device-mapper control node,
+// opening it on first use. If the node does not exist it is created as
+// character device 10:236 under /dev/mapper before being opened.
+func getFd() (uintptr, error) {
+	if controlFile != nil {
+		return controlFile.Fd(), nil
+	}
+	f, err := os.Open("/dev/mapper/control")
+	if os.IsNotExist(err) {
+		// Best effort: if the directory cannot be created, Mknod below
+		// reports the failure.
+		_ = os.MkdirAll("/dev/mapper", 0755)
+		if err := unix.Mknod("/dev/mapper/control", unix.S_IFCHR|0600, int(unix.Mkdev(10, 236))); err != nil {
+			return 0, err
+		}
+		f, err = os.Open("/dev/mapper/control")
+	}
+	if err != nil {
+		return 0, err
+	}
+	controlFile = f
+	return f.Fd(), nil
+}
+
+// GetVersion issues DM_VERSION_CMD on the control device and returns the
+// Version field of the request as it stands after the call.
+func GetVersion() (Version, error) {
+	request := newReq()
+	ctl, err := getFd()
+	if err != nil {
+		return Version{}, err
+	}
+	_, _, errno := unix.Syscall(unix.SYS_IOCTL, ctl, DM_VERSION_CMD, uintptr(unsafe.Pointer(&request)))
+	if errno != 0 {
+		return Version{}, errno
+	}
+	return request.Version, nil
+}
+
+// CreateDevice issues DM_DEV_CREATE_CMD for a device called name and returns
+// the Dev field of the request as it stands after the call. The name must fit
+// into the request's name buffer and must not contain NUL bytes.
+func CreateDevice(name string) (uint64, error) {
+	request := newReq()
+	if err := stringToDelimitedBuf(request.Name[:], name); err != nil {
+		return 0, err
+	}
+	ctl, err := getFd()
+	if err != nil {
+		return 0, err
+	}
+	_, _, errno := unix.Syscall(unix.SYS_IOCTL, ctl, DM_DEV_CREATE_CMD, uintptr(unsafe.Pointer(&request)))
+	if errno != 0 {
+		return 0, errno
+	}
+	return request.Dev, nil
+}
+
+// RemoveDevice issues DM_DEV_REMOVE_CMD for the device called name.
+func RemoveDevice(name string) error {
+	request := newReq()
+	if err := stringToDelimitedBuf(request.Name[:], name); err != nil {
+		return err
+	}
+	ctl, err := getFd()
+	if err != nil {
+		return err
+	}
+	_, _, errno := unix.Syscall(unix.SYS_IOCTL, ctl, DM_DEV_REMOVE_CMD, uintptr(unsafe.Pointer(&request)))
+	if errno != 0 {
+		return errno
+	}
+	runtime.KeepAlive(request)
+	return nil
+}
+
+// Target describes one entry of a device table as accepted by LoadTable.
+type Target struct {
+ // StartSector is written to DMTargetSpec.SectorStart.
+ StartSector uint64
+ // Length is written to DMTargetSpec.Length.
+ Length uint64
+ // Type is the target type name; it must fit DMTargetSpec.TargetType
+ // including a terminating NUL byte.
+ Type string
+ // Parameters is serialized NUL-terminated after the target spec.
+ Parameters string
+}
+
+// LoadTable serializes targets into a request payload and issues
+// DM_TABLE_LOAD_CMD for the device called name.
+//
+// It returns an error if name or a target type does not fit its buffer, if the
+// serialized table does not fit the request's payload area, or if the ioctl
+// fails.
+func LoadTable(name string, targets []Target) error {
+	req := newReq()
+	if err := stringToDelimitedBuf(req.Name[:], name); err != nil {
+		return err
+	}
+	var data bytes.Buffer
+	for _, target := range targets {
+		// Each entry is a DMTargetSpec followed by the NUL-terminated
+		// parameter string, zero-padded so that the next entry starts on an
+		// 8-byte boundary.
+		paramSize := len(target.Parameters) + 1
+		padding := (8 - paramSize%8) % 8
+		targetSize := uint32(int(unsafe.Sizeof(DMTargetSpec{})) + paramSize + padding)
+
+		targetSpec := DMTargetSpec{
+			SectorStart: target.StartSector,
+			Length:      target.Length,
+			Next:        targetSize,
+		}
+		if err := stringToDelimitedBuf(targetSpec.TargetType[:], target.Type); err != nil {
+			return err
+		}
+		if err := binary.Write(&data, native_endian.NativeEndian(), &targetSpec); err != nil {
+			return fmt.Errorf("failed to serialize target spec: %w", err)
+		}
+		data.WriteString(target.Parameters)
+		data.WriteByte(0x00)
+		for i := 0; i < padding; i++ {
+			data.WriteByte(0x00)
+		}
+	}
+	req.TargetCount = uint32(len(targets))
+	if data.Len() >= 16384 {
+		return errors.New("table too large for allocated memory")
+	}
+	req.DataSize = baseDataSize + uint32(data.Len())
+	copy(req.Data[:], data.Bytes())
+	fd, err := getFd()
+	if err != nil {
+		return err
+	}
+	if _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, DM_TABLE_LOAD_CMD, uintptr(unsafe.Pointer(&req))); err != 0 {
+		return err
+	}
+	runtime.KeepAlive(req)
+	return nil
+}
+
+// suspendResume issues DM_DEV_SUSPEND_CMD for the device called name. The
+// request carries DM_SUSPEND_FLAG when suspend is true and no flags otherwise.
+func suspendResume(name string, suspend bool) error {
+	request := newReq()
+	if err := stringToDelimitedBuf(request.Name[:], name); err != nil {
+		return err
+	}
+	if suspend {
+		request.Flags = DM_SUSPEND_FLAG
+	}
+	ctl, err := getFd()
+	if err != nil {
+		return err
+	}
+	_, _, errno := unix.Syscall(unix.SYS_IOCTL, ctl, DM_DEV_SUSPEND_CMD, uintptr(unsafe.Pointer(&request)))
+	if errno != 0 {
+		return errno
+	}
+	runtime.KeepAlive(request)
+	return nil
+}
+
+// Suspend issues DM_DEV_SUSPEND_CMD with DM_SUSPEND_FLAG set for the device
+// called name.
+func Suspend(name string) error {
+ return suspendResume(name, true)
+}
+
+// Resume issues DM_DEV_SUSPEND_CMD without DM_SUSPEND_FLAG for the device
+// called name.
+func Resume(name string) error {
+ return suspendResume(name, false)
+}
+
+// CreateActiveDevice creates the device called name, loads targets as its
+// table and resumes it, returning the value CreateDevice reported. If loading
+// the table or resuming fails, the device is removed again (best effort) and
+// the wrapped error of the failing step is returned.
+func CreateActiveDevice(name string, targets []Target) (uint64, error) {
+	dev, createErr := CreateDevice(name)
+	if createErr != nil {
+		return 0, fmt.Errorf("DM_DEV_CREATE failed: %w", createErr)
+	}
+	// From here on a failure must not leave the half-configured device behind.
+	// The removal error is deliberately ignored in favour of the original one.
+	rollback := func() { _ = RemoveDevice(name) }
+
+	loadErr := LoadTable(name, targets)
+	if loadErr != nil {
+		rollback()
+		return 0, fmt.Errorf("DM_TABLE_LOAD failed: %w", loadErr)
+	}
+	resumeErr := Resume(name)
+	if resumeErr != nil {
+		rollback()
+		return 0, fmt.Errorf("DM_DEV_SUSPEND failed: %w", resumeErr)
+	}
+	return dev, nil
+}
diff --git a/metropolis/node/common/fileargs/BUILD.bazel b/metropolis/node/common/fileargs/BUILD.bazel
new file mode 100644
index 0000000..c4fffc2
--- /dev/null
+++ b/metropolis/node/common/fileargs/BUILD.bazel
@@ -0,0 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Go library for the fileargs package, built from fileargs.go and visible to
+# every package in the workspace.
+go_library(
+ name = "go_default_library",
+ srcs = ["fileargs.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/common/fileargs",
+ visibility = ["//visibility:public"],
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/metropolis/node/common/fileargs/fileargs.go b/metropolis/node/common/fileargs/fileargs.go
new file mode 100644
index 0000000..26c054b
--- /dev/null
+++ b/metropolis/node/common/fileargs/fileargs.go
@@ -0,0 +1,101 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileargs
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ "golang.org/x/sys/unix"
+)
+
+// DefaultSize is the default size limit for FileArgs. New passes it to
+// NewWithSize, which forwards it as the size= mount option.
+const DefaultSize = 4 * 1024 * 1024
+
+// TempDirectory is the directory where FileArgs will mount the actual files to. Defaults to
+// os.TempDir() but can be globally overridden by the application before any FileArgs are used.
+// NewWithSize creates a randomly named sub-directory below it for every FileArgs.
+var TempDirectory = os.TempDir()
+
+// FileArgs is a set of argument files stored below a dedicated ramfs mount.
+// Create it with New or NewWithSize and release it with Close.
+type FileArgs struct {
+ // path is the mount point created by NewWithSize; all files live below it.
+ path string
+ // lastError is the first error ArgPath ran into; it is reported by Error
+ // and makes all later ArgPath calls return "".
+ lastError error
+}
+
+// New initializes a new set of file-based arguments with the DefaultSize size
+// limit. Remember to call Close() if you're done using it, otherwise this
+// leaks memory and mounts.
+func New() (*FileArgs, error) {
+ return NewWithSize(DefaultSize)
+}
+
+// NewWithSize is the same as New, but with a custom size limit. Please be aware that this data
+// cannot be swapped out and using a size limit that's too high can deadlock your kernel.
+//
+// It creates a directory with a random 128-bit hex name below TempDirectory
+// and mounts a ramfs on it. If mounting fails the directory is removed again.
+func NewWithSize(size uint64) (*FileArgs, error) {
+	randomNameRaw := make([]byte, 128/8)
+	if _, err := io.ReadFull(rand.Reader, randomNameRaw); err != nil {
+		return nil, err
+	}
+	tmpPath := filepath.Join(TempDirectory, hex.EncodeToString(randomNameRaw))
+	if err := os.MkdirAll(tmpPath, 0700); err != nil {
+		return nil, err
+	}
+	// This uses ramfs instead of tmpfs because we never want to swap this for security reasons
+	if err := unix.Mount("none", tmpPath, "ramfs", unix.MS_NOEXEC|unix.MS_NOSUID|unix.MS_NODEV, fmt.Sprintf("size=%v", size)); err != nil {
+		// Best effort: do not leave the unused mount point behind. The mount
+		// error is the one worth reporting, so a removal failure is ignored.
+		_ = os.Remove(tmpPath)
+		return nil, err
+	}
+	return &FileArgs{
+		path: tmpPath,
+	}, nil
+}
+
+// ArgPath writes content to a file called name below the FileArgs mount (mode
+// 0600) and returns that file's path. Once any call has failed, this and all
+// later calls return an empty string; the failure is available via Error.
+func (f *FileArgs) ArgPath(name string, content []byte) string {
+	if f.lastError != nil {
+		return ""
+	}
+
+	target := filepath.Join(f.path, name)
+	writeErr := ioutil.WriteFile(target, content, 0600)
+	if writeErr == nil {
+		return target
+	}
+
+	f.lastError = writeErr
+	return ""
+}
+
+// FileOpt stores content via ArgPath and returns "optName=<path of the file>".
+// Example: `FileOpt("--testopt", "test.txt", []byte("hello")) == "--testopt=/tmp/daf8ed.../test.txt"`
+func (f *FileArgs) FileOpt(optName, fileName string, content []byte) string {
+	argPath := f.ArgPath(fileName, content)
+	return fmt.Sprintf("%v=%v", optName, argPath)
+}
+
+// Error returns the error recorded by the first failed ArgPath call, or nil
+// if no call has failed.
+func (f *FileArgs) Error() error {
+ return f.lastError
+}
+
+// Close unmounts the FileArgs mount and then removes its mount point
+// directory. If unmounting fails the directory is left in place.
+func (f *FileArgs) Close() error {
+	unmountErr := unix.Unmount(f.path, 0)
+	if unmountErr != nil {
+		return unmountErr
+	}
+	return os.Remove(f.path)
+}
diff --git a/metropolis/node/common/fsquota/BUILD.bazel b/metropolis/node/common/fsquota/BUILD.bazel
new file mode 100644
index 0000000..b16d39e
--- /dev/null
+++ b/metropolis/node/common/fsquota/BUILD.bazel
@@ -0,0 +1,39 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//metropolis/test/ktest:ktest.bzl", "ktest")
+
+# Go library for the fsquota package, built from fsinfo.go and fsquota.go. It
+# depends on the fsxattrs and quotactl sub-packages.
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "fsinfo.go",
+ "fsquota.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/common/fsquota",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//metropolis/node/common/fsquota/fsxattrs:go_default_library",
+ "//metropolis/node/common/fsquota/quotactl:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+# Test binary for the fsquota package; it is used as the tester of the ktest
+# target below.
+go_test(
+ name = "go_default_test",
+ srcs = ["fsquota_test.go"],
+ embed = [":go_default_library"],
+ pure = "on",
+ deps = [
+ "@com_github_stretchr_testify//require:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+# Runs go_default_test through the ktest rule. The mkfs.xfs binary is added to
+# the initramfs at /mkfs.xfs, where TestBasic executes it, and the kernel
+# command line sets ramdisk_size=51200 (TestBasic formats /dev/ram0).
+ktest(
+ tester = ":go_default_test",
+ deps = [
+ "//third_party/xfsprogs:mkfs.xfs",
+ ],
+ initramfs_extra = """
+file /mkfs.xfs $(location //third_party/xfsprogs:mkfs.xfs) 0755 0 0
+ """,
+ cmdline = "ramdisk_size=51200",
+)
diff --git a/metropolis/node/common/fsquota/fsinfo.go b/metropolis/node/common/fsquota/fsinfo.go
new file mode 100644
index 0000000..e40a533
--- /dev/null
+++ b/metropolis/node/common/fsquota/fsinfo.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsquota
+
+import (
+ "fmt"
+ "os"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// This requires fsinfo() support, which is not yet in any stable kernel.
+// Our kernel has that syscall backported. This would otherwise be an extremely expensive
+// operation and also involve lots of logic from our side.
+
+// From syscall_64.tbl
+// sys_fsinfo is the syscall number fsinfoGetSource invokes.
+const sys_fsinfo = 441
+
+// From uapi/linux/fsinfo.h
+// fsinfo_attr_source is the attribute fsinfoGetSource requests, and
+// fsinfo_flags_query_fd is the query flag it sets. fsinfo_flags_query_path is
+// not used in this file.
+const fsinfo_attr_source = 0x09
+const fsinfo_flags_query_path = 0x0000
+const fsinfo_flags_query_fd = 0x0001
+
+// fsinfoParams is the parameter block passed by pointer to the fsinfo syscall.
+// NOTE(review): presumably mirrors struct fsinfo_params from the backported
+// uapi/linux/fsinfo.h; confirm the layout against the kernel in use.
+type fsinfoParams struct {
+ resolveFlags uint64
+ atFlags uint32
+ // flags selects the query mode; fsinfoGetSource sets fsinfo_flags_query_fd.
+ flags uint32
+ // request is the attribute to fetch; fsinfoGetSource sets fsinfo_attr_source.
+ request uint32
+ nth uint32
+ mth uint32
+}
+
+func fsinfoGetSource(dir *os.File) (string, error) {
+ buf := make([]byte, 256)
+ params := fsinfoParams{
+ flags: fsinfo_flags_query_fd,
+ request: fsinfo_attr_source,
+ }
+ n, _, err := unix.Syscall6(sys_fsinfo, dir.Fd(), 0, uintptr(unsafe.Pointer(¶ms)), unsafe.Sizeof(params), uintptr(unsafe.Pointer(&buf[0])), 128)
+ if err != unix.Errno(0) {
+ return "", fmt.Errorf("failed to call fsinfo: %w", err)
+ }
+ return string(buf[:n]), nil
+}
diff --git a/metropolis/node/common/fsquota/fsquota.go b/metropolis/node/common/fsquota/fsquota.go
new file mode 100644
index 0000000..f702d23
--- /dev/null
+++ b/metropolis/node/common/fsquota/fsquota.go
@@ -0,0 +1,146 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsquota provides a simplified interface to interact with Linux's filesystem quota API.
+// It only supports setting quotas on directories, not groups or users.
+// Quotas need to be already enabled on the filesystem to be able to use them using this package.
+// See the quotactl package if you intend to use this on a filesystem where quotas need to be
+// enabled manually.
+package fsquota
+
+import (
+ "fmt"
+ "math"
+ "os"
+
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/fsquota/fsxattrs"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/fsquota/quotactl"
+)
+
+// SetQuota sets the quota of bytes and/or inodes in a given path. To not set a limit, set the
+// corresponding argument to zero. Setting both arguments to zero removes the quota entirely.
+// This function can only be called on an empty directory. It can't be used to create a quota
+// below a directory which already has a quota since Linux doesn't offer hierarchical quotas.
+//
+// If the directory has no project ID yet, the first unused project ID on the
+// filesystem is assigned to it. The byte limit is rounded up to whole 1024-byte
+// blocks and applied as both soft and hard limit, as is the inode limit.
+func SetQuota(path string, maxBytes uint64, maxInodes uint64) error {
+	dir, err := os.Open(path)
+	if err != nil {
+		return err
+	}
+	defer dir.Close()
+	source, err := fsinfoGetSource(dir)
+	if err != nil {
+		return err
+	}
+	var valid uint32
+	if maxBytes > 0 {
+		valid |= quotactl.FlagBLimitsValid
+	}
+	if maxInodes > 0 {
+		valid |= quotactl.FlagILimitsValid
+	}
+
+	attrs, err := fsxattrs.Get(dir)
+	if err != nil {
+		return err
+	}
+
+	lastID := attrs.ProjectID
+	if lastID == 0 {
+		// No project/quota exists for this directory, assign a new project quota
+		// TODO(lorenz): This is racy, but the kernel does not support atomically assigning
+		// quotas. So this needs to be added to the kernels setquota interface. Due to the short
+		// time window and infrequent calls this should not be an immediate issue.
+		for {
+			quota, err := quotactl.GetNextQuota(source, quotactl.QuotaTypeProject, lastID)
+			if err == unix.ENOENT || err == unix.ESRCH {
+				// We have enumerated all quotas, nothing exists here
+				break
+			} else if err != nil {
+				return fmt.Errorf("failed to call GetNextQuota: %w", err)
+			}
+			if quota.ID > lastID+1 {
+				// Take the first ID in the quota ID gap
+				lastID++
+				break
+			}
+			lastID++
+		}
+	}
+
+	// If both limits are zero, this is a delete operation, process it as such
+	if maxBytes == 0 && maxInodes == 0 {
+		valid = quotactl.FlagBLimitsValid | quotactl.FlagILimitsValid
+		attrs.ProjectID = 0
+		attrs.Flags &= ^fsxattrs.FlagProjectInherit
+	} else {
+		attrs.ProjectID = lastID
+		attrs.Flags |= fsxattrs.FlagProjectInherit
+	}
+
+	if err := fsxattrs.Set(dir, attrs); err != nil {
+		return err
+	}
+
+	// Always round up to the nearest block size
+	bytesLimitBlocks := uint64(math.Ceil(float64(maxBytes) / float64(1024)))
+
+	return quotactl.SetQuota(source, quotactl.QuotaTypeProject, lastID, &quotactl.Quota{
+		BHardLimit: bytesLimitBlocks,
+		BSoftLimit: bytesLimitBlocks,
+		IHardLimit: maxInodes,
+		ISoftLimit: maxInodes,
+		Valid:      valid,
+	})
+}
+
+// Quota is the limit and utilization of a directory quota as returned by GetQuota.
+type Quota struct {
+ // Bytes is the hard block limit reported by quotactl multiplied by 1024.
+ Bytes uint64
+ // BytesUsed is the space currently used, as reported by quotactl.
+ BytesUsed uint64
+ // Inodes is the hard inode limit reported by quotactl.
+ Inodes uint64
+ // InodesUsed is the number of inodes currently used, as reported by quotactl.
+ InodesUsed uint64
+}
+
+// GetQuota returns the current active quota and its utilization at the given path.
+// It returns os.ErrNotExist if the directory has no project ID assigned.
+func GetQuota(path string) (*Quota, error) {
+	handle, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+	defer handle.Close()
+
+	device, err := fsinfoGetSource(handle)
+	if err != nil {
+		return nil, err
+	}
+
+	xattrs, err := fsxattrs.Get(handle)
+	if err != nil {
+		return nil, err
+	}
+	projectID := xattrs.ProjectID
+	if projectID == 0 {
+		return nil, os.ErrNotExist
+	}
+
+	raw, err := quotactl.GetQuota(device, quotactl.QuotaTypeProject, projectID)
+	if err != nil {
+		return nil, err
+	}
+
+	result := new(Quota)
+	// The kernel reports the block limit in units of 1024 bytes.
+	result.Bytes = raw.BHardLimit * 1024
+	result.BytesUsed = raw.CurSpace
+	result.Inodes = raw.IHardLimit
+	result.InodesUsed = raw.CurInodes
+	return result, nil
+}
diff --git a/metropolis/node/common/fsquota/fsquota_test.go b/metropolis/node/common/fsquota/fsquota_test.go
new file mode 100644
index 0000000..4729dac
--- /dev/null
+++ b/metropolis/node/common/fsquota/fsquota_test.go
@@ -0,0 +1,152 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsquota
+
+import (
+ "fmt"
+ "io/ioutil"
+ "math"
+ "os"
+ "os/exec"
+ "syscall"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/sys/unix"
+)
+
+// withinTolerance reports a test error for the value called name unless actual
+// lies within expected ± expected*tolerance (the deviation is rounded to the
+// nearest integer). The tolerance is a fraction between 0 (exact match) and 1
+// (anything from 0 to twice the expected value).
+func withinTolerance(t *testing.T, expected uint64, actual uint64, tolerance float64, name string) {
+	t.Helper()
+	allowed := math.Round(float64(expected) * tolerance)
+	deviation := uint64(allowed)
+	minimum, maximum := expected-deviation, expected+deviation
+	if minimum > actual {
+		t.Errorf("Value %v (%v) is too low, expected between %v and %v", name, actual, minimum, maximum)
+	}
+	if maximum < actual {
+		t.Errorf("Value %v (%v) is too high, expected between %v and %v", name, actual, minimum, maximum)
+	}
+}
+
+// TestBasic exercises SetQuota and GetQuota against a freshly formatted XFS
+// filesystem on /dev/ram0 that is mounted on /test with the prjquota option.
+// It only runs when IN_KTEST=true and is skipped otherwise.
+func TestBasic(t *testing.T) {
+	if os.Getenv("IN_KTEST") != "true" {
+		t.Skip("Not in ktest")
+	}
+	mkfsCmd := exec.Command("/mkfs.xfs", "-qf", "/dev/ram0")
+	if _, err := mkfsCmd.Output(); err != nil {
+		t.Fatal(err)
+	}
+	if err := os.Mkdir("/test", 0755); err != nil {
+		t.Error(err)
+	}
+
+	if err := unix.Mount("/dev/ram0", "/test", "xfs", unix.MS_NOEXEC|unix.MS_NODEV, "prjquota"); err != nil {
+		t.Fatal(err)
+	}
+	// Deferred calls run last-in-first-out: unmount first, then remove the
+	// mount point that is no longer in use.
+	defer os.RemoveAll("/test")
+	defer unix.Unmount("/test", 0)
+	t.Run("SetQuota", func(t *testing.T) {
+		defer func() {
+			os.RemoveAll("/test/set")
+		}()
+		if err := os.Mkdir("/test/set", 0755); err != nil {
+			t.Fatal(err)
+		}
+		if err := SetQuota("/test/set", 1024*1024, 100); err != nil {
+			t.Fatal(err)
+		}
+	})
+	t.Run("SetQuotaAndExhaust", func(t *testing.T) {
+		defer func() {
+			os.RemoveAll("/test/sizequota")
+		}()
+		if err := os.Mkdir("/test/sizequota", 0755); err != nil {
+			t.Fatal(err)
+		}
+		const bytesQuota = 1024 * 1024 // 1MiB
+		if err := SetQuota("/test/sizequota", bytesQuota, 0); err != nil {
+			t.Fatal(err)
+		}
+		testfile, err := os.Create("/test/sizequota/testfile")
+		if err != nil {
+			t.Fatal(err)
+		}
+		// Registered after the RemoveAll above, so the file is closed before
+		// its directory is removed.
+		defer testfile.Close()
+		testdata := make([]byte, 1024)
+		var bytesWritten int
+		for {
+			n, err := testfile.Write(testdata)
+			if err != nil {
+				if pathErr, ok := err.(*os.PathError); ok {
+					if pathErr.Err == syscall.ENOSPC {
+						// Running out of space is the only acceptable error to continue execution
+						break
+					}
+				}
+				t.Fatal(err)
+			}
+			bytesWritten += n
+		}
+		if bytesWritten > bytesQuota {
+			t.Errorf("Wrote %v bytes, quota is only %v bytes", bytesWritten, bytesQuota)
+		}
+	})
+	t.Run("GetQuotaReadbackAndUtilization", func(t *testing.T) {
+		defer func() {
+			os.RemoveAll("/test/readback")
+		}()
+		if err := os.Mkdir("/test/readback", 0755); err != nil {
+			t.Fatal(err)
+		}
+		const bytesQuota = 1024 * 1024 // 1MiB
+		const inodesQuota = 100
+		if err := SetQuota("/test/readback", bytesQuota, inodesQuota); err != nil {
+			t.Fatal(err)
+		}
+		sizeFileData := make([]byte, 512*1024)
+		if err := ioutil.WriteFile("/test/readback/512kfile", sizeFileData, 0644); err != nil {
+			t.Fatal(err)
+		}
+
+		quotaUtil, err := GetQuota("/test/readback")
+		if err != nil {
+			t.Fatal(err)
+		}
+		require.Equal(t, uint64(bytesQuota), quotaUtil.Bytes, "bytes quota readback incorrect")
+		require.Equal(t, uint64(inodesQuota), quotaUtil.Inodes, "inodes quota readback incorrect")
+
+		// Give 10% tolerance for quota used values to account for metadata overhead and internal
+		// structures that are also in there. If it's out by more than that it's an issue anyways.
+		withinTolerance(t, uint64(len(sizeFileData)), quotaUtil.BytesUsed, 0.1, "BytesUsed")
+
+		// Write 50 inodes for a total of 51 (with the 512K file)
+		for i := 0; i < 50; i++ {
+			if err := ioutil.WriteFile(fmt.Sprintf("/test/readback/ifile%v", i), []byte("test"), 0644); err != nil {
+				t.Fatal(err)
+			}
+		}
+
+		quotaUtil, err = GetQuota("/test/readback")
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		withinTolerance(t, 51, quotaUtil.InodesUsed, 0.1, "InodesUsed")
+	})
+}
diff --git a/metropolis/node/common/fsquota/fsxattrs/BUILD.bazel b/metropolis/node/common/fsquota/fsxattrs/BUILD.bazel
new file mode 100644
index 0000000..066200b
--- /dev/null
+++ b/metropolis/node/common/fsquota/fsxattrs/BUILD.bazel
@@ -0,0 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Go library for the fsxattrs package, built from fsxattrs.go and visible to
+# every package in the workspace.
+go_library(
+ name = "go_default_library",
+ srcs = ["fsxattrs.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/common/fsquota/fsxattrs",
+ visibility = ["//visibility:public"],
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/metropolis/node/common/fsquota/fsxattrs/fsxattrs.go b/metropolis/node/common/fsquota/fsxattrs/fsxattrs.go
new file mode 100644
index 0000000..1d455eb
--- /dev/null
+++ b/metropolis/node/common/fsquota/fsxattrs/fsxattrs.go
@@ -0,0 +1,77 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsxattrs
+
+import (
+ "fmt"
+ "os"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// FSXAttrFlag is a bit flag stored in FSXAttrs.Flags.
+type FSXAttrFlag uint32
+
+// Defined in uapi/linux/fs.h
+// Of these, the fsquota package sets and clears FlagProjectInherit when it
+// assigns or removes a project quota.
+const (
+ FlagRealtime FSXAttrFlag = 0x00000001
+ FlagPreallocated FSXAttrFlag = 0x00000002
+ FlagImmutable FSXAttrFlag = 0x00000008
+ FlagAppend FSXAttrFlag = 0x00000010
+ FlagSync FSXAttrFlag = 0x00000020
+ FlagNoATime FSXAttrFlag = 0x00000040
+ FlagNoDump FSXAttrFlag = 0x00000080
+ FlagRealtimeInherit FSXAttrFlag = 0x00000100
+ FlagProjectInherit FSXAttrFlag = 0x00000200
+ FlagNoSymlinks FSXAttrFlag = 0x00000400
+ FlagExtentSize FSXAttrFlag = 0x00000800
+ FlagNoDefragment FSXAttrFlag = 0x00002000
+ FlagFilestream FSXAttrFlag = 0x00004000
+ FlagDAX FSXAttrFlag = 0x00008000
+ FlagCOWExtentSize FSXAttrFlag = 0x00010000
+ FlagHasAttribute FSXAttrFlag = 0x80000000
+)
+
+// FS_IOC_FSGETXATTR/FS_IOC_FSSETXATTR are defined in uapi/linux/fs.h
+// Get issues FS_IOC_FSGETXATTR and Set issues FS_IOC_FSSETXATTR.
+const FS_IOC_FSGETXATTR = 0x801c581f
+const FS_IOC_FSSETXATTR = 0x401c5820
+
+// FSXAttrs is the attribute block exchanged with the kernel by Get and Set.
+// NOTE(review): presumably mirrors struct fsxattr from uapi/linux/fs.h;
+// confirm before changing field order or sizes.
+type FSXAttrs struct {
+ // Flags is a combination of the Flag* constants.
+ Flags FSXAttrFlag
+ ExtentSize uint32
+ ExtentCount uint32
+ // ProjectID is the project quota ID of the file; fsquota treats 0 as
+ // "no project assigned".
+ ProjectID uint32
+ CoWExtentSize uint32
+ _pad [8]byte
+}
+
+// Get reads the extended attributes of file using the FS_IOC_FSGETXATTR
+// ioctl. On failure the returned error wraps the errno, so callers can match
+// it with errors.Is.
+func Get(file *os.File) (*FSXAttrs, error) {
+	var attrs FSXAttrs
+	_, _, errno := unix.Syscall(unix.SYS_IOCTL, file.Fd(), FS_IOC_FSGETXATTR, uintptr(unsafe.Pointer(&attrs)))
+	if errno != 0 {
+		return nil, fmt.Errorf("failed to execute getFSXAttrs: %w", errno)
+	}
+	return &attrs, nil
+}
+
+// Set applies attrs to file using the FS_IOC_FSSETXATTR ioctl. On failure the
+// returned error wraps the errno, so callers can match it with errors.Is.
+func Set(file *os.File, attrs *FSXAttrs) error {
+	_, _, errno := unix.Syscall(unix.SYS_IOCTL, file.Fd(), FS_IOC_FSSETXATTR, uintptr(unsafe.Pointer(attrs)))
+	if errno != 0 {
+		return fmt.Errorf("failed to execute setFSXAttrs: %w", errno)
+	}
+	return nil
+}
diff --git a/metropolis/node/common/fsquota/quotactl/BUILD.bazel b/metropolis/node/common/fsquota/quotactl/BUILD.bazel
new file mode 100644
index 0000000..c1582ad
--- /dev/null
+++ b/metropolis/node/common/fsquota/quotactl/BUILD.bazel
@@ -0,0 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Go library for the quotactl package, built from quotactl.go and visible to
+# every package in the workspace.
+go_library(
+ name = "go_default_library",
+ srcs = ["quotactl.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/common/fsquota/quotactl",
+ visibility = ["//visibility:public"],
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/metropolis/node/common/fsquota/quotactl/quotactl.go b/metropolis/node/common/fsquota/quotactl/quotactl.go
new file mode 100644
index 0000000..5ed77d7
--- /dev/null
+++ b/metropolis/node/common/fsquota/quotactl/quotactl.go
@@ -0,0 +1,233 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package quotactl implements a low-level wrapper around the modern portion of Linux's
+// quotactl() syscall. See the fsquota package for a nicer interface to the most common part
+// of this API.
+package quotactl
+
+import (
+ "fmt"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// QuotaType selects which kind of quota a call operates on. Its value is
+// OR-ed into the low bits of the quotactl command word.
+type QuotaType uint
+
+const (
+ QuotaTypeUser QuotaType = iota
+ QuotaTypeGroup
+ QuotaTypeProject
+)
+
+// quotactl command words without the quota type. Each value is
+// (0x800001 + position in this list) shifted left by 8 bits, so the order of
+// the entries must not be changed. Callers OR the QuotaType into the result.
+const (
+ Q_SYNC uint = ((0x800001 + iota) << 8)
+ Q_QUOTAON
+ Q_QUOTAOFF
+ Q_GETFMT
+ Q_GETINFO
+ Q_SETINFO
+ Q_GETQUOTA
+ Q_SETQUOTA
+ Q_GETNEXTQUOTA
+)
+
+// Bit flags for Quota.Valid. fsquota sets FlagBLimitsValid and
+// FlagILimitsValid to mark the block and inode limits it passes to SetQuota.
+const (
+ FlagBLimitsValid = 1 << iota
+ FlagSpaceValid
+ FlagILimitsValid
+ FlagInodesValid
+ FlagBTimeValid
+ FlagITimeValid
+)
+
+// DQInfo is the structure filled in by GetInfo and passed to SetInfo.
+// NOTE(review): presumably mirrors struct if_dqinfo; confirm before changing
+// field order or sizes.
+type DQInfo struct {
+ Bgrace uint64
+ Igrace uint64
+ Flags uint32
+ Valid uint32
+}
+
+// Quota is the structure filled in by GetQuota and passed to SetQuota.
+// NOTE(review): presumably mirrors struct if_dqblk; confirm before changing
+// field order or sizes.
+type Quota struct {
+ BHardLimit uint64 // Both Byte limits are prescaled by 1024 (so are in KiB), but CurSpace is in B
+ BSoftLimit uint64
+ CurSpace uint64
+ IHardLimit uint64
+ ISoftLimit uint64
+ CurInodes uint64
+ BTime uint64
+ ITime uint64
+ // Valid is a combination of the Flag*Valid constants.
+ Valid uint32
+}
+
+// NextDQBlk is the structure filled in by GetNextQuota. In addition to limits
+// and usage it carries the ID of the quota that was found.
+// NOTE(review): presumably mirrors struct if_nextdqblk; confirm before
+// changing field order or sizes.
+type NextDQBlk struct {
+ HardLimitBytes uint64
+ SoftLimitBytes uint64
+ CurrentBytes uint64
+ HardLimitInodes uint64
+ SoftLimitInodes uint64
+ CurrentInodes uint64
+ BTime uint64
+ ITime uint64
+ Valid uint32
+ // ID is the ID of the returned quota; fsquota uses it to find unused IDs.
+ ID uint32
+}
+
+// QuotaFormat identifies a quota format; QuotaOn passes it to the kernel.
+type QuotaFormat uint32
+
+// Collected from quota_format_type structs
+const (
+ // QuotaFormatNone is a special case where all quota information is
+ // stored inside filesystem metadata and thus requires no quotaFilePath.
+ QuotaFormatNone QuotaFormat = 0
+ QuotaFormatVFSOld QuotaFormat = 1
+ QuotaFormatVFSV0 QuotaFormat = 2
+ QuotaFormatOCFS2 QuotaFormat = 3
+ QuotaFormatVFSV1 QuotaFormat = 4
+)
+
+// QuotaOn turns quota accounting and enforcement on for quotas of type qtype
+// on device, passing quotaFormat and quotaFilePath along with Q_QUOTAON.
+func QuotaOn(device string, qtype QuotaType, quotaFormat QuotaFormat, quotaFilePath string) error {
+	devPtr, err := unix.BytePtrFromString(device)
+	if err != nil {
+		return err
+	}
+	filePtr, err := unix.BytePtrFromString(quotaFilePath)
+	if err != nil {
+		return err
+	}
+	cmd := uintptr(Q_QUOTAON | uint(qtype))
+	_, _, errno := unix.Syscall6(unix.SYS_QUOTACTL, cmd, uintptr(unsafe.Pointer(devPtr)), uintptr(quotaFormat), uintptr(unsafe.Pointer(filePtr)), 0, 0)
+	if errno != 0 {
+		return errno
+	}
+	return nil
+}
+
+// QuotaOff turns quotas of type qtype off on device.
+func QuotaOff(device string, qtype QuotaType) error {
+	devPtr, err := unix.BytePtrFromString(device)
+	if err != nil {
+		return err
+	}
+	cmd := uintptr(Q_QUOTAOFF | uint(qtype))
+	_, _, errno := unix.Syscall6(unix.SYS_QUOTACTL, cmd, uintptr(unsafe.Pointer(devPtr)), 0, 0, 0, 0)
+	if errno != 0 {
+		return errno
+	}
+	return nil
+}
+
+// GetFmt gets the quota format used for quotas of type qtype on the given
+// filesystem. The value is returned as the raw 32-bit number the kernel
+// writes.
+func GetFmt(device string, qtype QuotaType) (uint32, error) {
+	devPtr, err := unix.BytePtrFromString(device)
+	if err != nil {
+		return 0, err
+	}
+	var format uint32
+	cmd := uintptr(Q_GETFMT | uint(qtype))
+	_, _, errno := unix.Syscall6(unix.SYS_QUOTACTL, cmd, uintptr(unsafe.Pointer(devPtr)), 0, uintptr(unsafe.Pointer(&format)), 0, 0)
+	if errno != 0 {
+		return 0, errno
+	}
+	return format, nil
+}
+
+// GetInfo gets information about quota files
+func GetInfo(device string, qtype QuotaType) (*DQInfo, error) {
+ var info DQInfo
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return nil, err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_GETINFO|uint(qtype)), uintptr(unsafe.Pointer(devArg)), 0, uintptr(unsafe.Pointer(&info)), 0, 0)
+ if err != unix.Errno(0) {
+ return nil, err
+ }
+ return &info, nil
+}
+
+// SetInfo sets information about quota files for quotas of type qtype on
+// device from info.
+func SetInfo(device string, qtype QuotaType, info *DQInfo) error {
+	devPtr, err := unix.BytePtrFromString(device)
+	if err != nil {
+		return err
+	}
+	cmd := uintptr(Q_SETINFO | uint(qtype))
+	_, _, errno := unix.Syscall6(unix.SYS_QUOTACTL, cmd, uintptr(unsafe.Pointer(devPtr)), 0, uintptr(unsafe.Pointer(info)), 0, 0)
+	if errno != 0 {
+		return errno
+	}
+	return nil
+}
+
+// GetQuota gets user quota structure
+func GetQuota(device string, qtype QuotaType, id uint32) (*Quota, error) {
+ var info Quota
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return nil, err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_GETQUOTA|uint(qtype)), uintptr(unsafe.Pointer(devArg)), uintptr(id), uintptr(unsafe.Pointer(&info)), 0, 0)
+ if err != unix.Errno(0) {
+ return nil, err
+ }
+ return &info, nil
+}
+
+// GetNextQuota gets disk limits and usage > ID
+func GetNextQuota(device string, qtype QuotaType, id uint32) (*NextDQBlk, error) {
+ var info NextDQBlk
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return nil, err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_GETNEXTQUOTA|uint(qtype)), uintptr(unsafe.Pointer(devArg)), uintptr(id), uintptr(unsafe.Pointer(&info)), 0, 0)
+ if err != unix.Errno(0) {
+ return nil, err
+ }
+ return &info, nil
+}
+
+// SetQuota sets the given quota for the given ID of quota type qtype on
+// device. A failing syscall is returned wrapped, so the errno can be matched
+// with errors.Is.
+func SetQuota(device string, qtype QuotaType, id uint32, quota *Quota) error {
+	devPtr, err := unix.BytePtrFromString(device)
+	if err != nil {
+		return err
+	}
+	cmd := uintptr(Q_SETQUOTA | uint(qtype))
+	_, _, errno := unix.Syscall6(unix.SYS_QUOTACTL, cmd, uintptr(unsafe.Pointer(devPtr)), uintptr(id), uintptr(unsafe.Pointer(quota)), 0, 0)
+	if errno == 0 {
+		return nil
+	}
+	return fmt.Errorf("failed to set quota: %w", errno)
+}
+
+// Sync syncs disk copy of filesystems quotas. If device is empty it syncs all filesystems.
+func Sync(device string) error {
+ if device != "" {
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_SYNC), uintptr(unsafe.Pointer(devArg)), 0, 0, 0, 0)
+ if err != unix.Errno(0) {
+ return err
+ }
+ } else {
+ _, _, err := unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_SYNC), 0, 0, 0, 0, 0)
+ if err != unix.Errno(0) {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/metropolis/node/common/jsonpatch/BUILD.bazel b/metropolis/node/common/jsonpatch/BUILD.bazel
new file mode 100644
index 0000000..bd77e0a
--- /dev/null
+++ b/metropolis/node/common/jsonpatch/BUILD.bazel
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "go_default_library",
+ srcs = ["jsonpatch.go.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/common/jsonpatch",
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "go_default_test",
+ srcs = ["jsonpatch_test.go"],
+ embed = [":go_default_library"],
+)
diff --git a/metropolis/node/common/jsonpatch/jsonpatch.go.go b/metropolis/node/common/jsonpatch/jsonpatch.go.go
new file mode 100644
index 0000000..9682980
--- /dev/null
+++ b/metropolis/node/common/jsonpatch/jsonpatch.go.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package jsonpatch contains data structures and encoders for JSON Patch (RFC 6902) and JSON Pointers (RFC 6901)
+package jsonpatch
+
+import "strings"
+
+// JsonPatchOp is a single JSON Patch operation (RFC 6902 Section 4).
+type JsonPatchOp struct {
+ Operation string `json:"op"`
+ Path string `json:"path"` // Technically a JSON Pointer, but called Path in the RFC
+ From string `json:"from,omitempty"` // Left out of the encoded JSON when empty
+ Value interface{} `json:"value,omitempty"` // Left out of the encoded JSON when nil
+}
+
+// EncodeJSONRefToken escapes a single reference token for use in a JSON Pointer (RFC 6901 Section 3): "~" is
+// encoded as "~0" and "/" is encoded as "~1".
+func EncodeJSONRefToken(token string) string {
+	return strings.NewReplacer("~", "~0", "/", "~1").Replace(token)
+}
+
+// PointerFromParts builds a JSON Pointer out of unencoded reference tokens: every part is encoded and prefixed
+// with "/". An empty list of parts yields the empty pointer "".
+func PointerFromParts(pathParts []string) string {
+	var pointer strings.Builder
+	for _, part := range pathParts {
+		pointer.WriteString("/" + EncodeJSONRefToken(part))
+	}
+	return pointer.String()
+}
diff --git a/metropolis/node/common/jsonpatch/jsonpatch_test.go b/metropolis/node/common/jsonpatch/jsonpatch_test.go
new file mode 100644
index 0000000..33a56ba
--- /dev/null
+++ b/metropolis/node/common/jsonpatch/jsonpatch_test.go
@@ -0,0 +1,66 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package jsonpatch
+
+import (
+ "testing"
+)
+
+// TestEncodeJSONRefToken checks the escaping of "~" and "/" in reference tokens, including its ordering.
+func TestEncodeJSONRefToken(t *testing.T) {
+	tests := []struct {
+		name, token, want string
+	}{
+		{"Passes through normal characters", "asdf123", "asdf123"},
+		{"Encodes simple slashes", "a/b", "a~1b"},
+		{"Encodes tildes", "m~n", "m~0n"},
+		{"Encodes both tildes and slashes", "a/m~n", "a~1m~0n"},
+		{"Escapes tildes without re-encoding them", "~1", "~01"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := EncodeJSONRefToken(tt.token); got != tt.want {
+				t.Errorf("EncodeJSONRefToken(%q) = %q, want %q", tt.token, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestPointerFromParts(t *testing.T) {
+ type args struct {
+ pathParts []string
+ }
+ tests := []struct {
+ name string
+ args args
+ want string
+ }{
+ {"Empty path", args{[]string{}}, ""},
+ {"Single level path", args{[]string{"foo"}}, "/foo"},
+ {"Multi-level path", args{[]string{"foo", "0"}}, "/foo/0"},
+ {"Path starting with empty key", args{[]string{""}}, "/"},
+ {"Path with part containing /", args{[]string{"a/b"}}, "/a~1b"},
+ {"Path with part containing spaces", args{[]string{" "}}, "/ "},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := PointerFromParts(tt.args.pathParts); got != tt.want {
+ t.Errorf("PointerFromParts() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/metropolis/node/common/logbuffer/BUILD.bazel b/metropolis/node/common/logbuffer/BUILD.bazel
new file mode 100644
index 0000000..2d4650d
--- /dev/null
+++ b/metropolis/node/common/logbuffer/BUILD.bazel
@@ -0,0 +1,22 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "linebuffer.go",
+ "logbuffer.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/common/logbuffer",
+ visibility = ["//visibility:public"],
+ deps = ["//metropolis/proto/api:go_default_library"],
+)
+
+go_test(
+ name = "go_default_test",
+ srcs = [
+ "linebuffer_test.go",
+ "logbuffer_test.go",
+ ],
+ embed = [":go_default_library"],
+ deps = ["@com_github_stretchr_testify//require:go_default_library"],
+)
diff --git a/metropolis/node/common/logbuffer/linebuffer.go b/metropolis/node/common/logbuffer/linebuffer.go
new file mode 100644
index 0000000..246a91b
--- /dev/null
+++ b/metropolis/node/common/logbuffer/linebuffer.go
@@ -0,0 +1,160 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logbuffer
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+ "sync"
+
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+// Line is a line stored in the log buffer - a string, that has been perhaps truncated (due to exceeded limits).
+type Line struct {
+ Data string // the stored, possibly truncated, content of the line
+ OriginalLength int // the length in bytes of the line as it was written, before any truncation
+}
+
+// Truncated reports whether Data is shorter than the line originally was, ie. whether it was cut to fit limits.
+func (l *Line) Truncated() bool {
+	return len(l.Data) < l.OriginalLength
+}
+
+// String returns the stored data of the line, followed by an ellipsis (...) if the line has been truncated, so
+// that the truncation is visible to the reader.
+func (l *Line) String() string {
+	if !l.Truncated() {
+		return l.Data
+	}
+	return l.Data + "..."
+}
+
+// ProtoLog returns the line as a Logging-specific protobuf structure, carrying over Data and OriginalLength.
+func (l *Line) ProtoLog() *apb.LogEntry_Raw {
+ return &apb.LogEntry_Raw{
+ Data: l.Data,
+ OriginalLength: int64(l.OriginalLength),
+ }
+}
+
+// LineFromLogProto converts a Logging-specific protobuf message back into a Line, rejecting inconsistent lengths.
+func LineFromLogProto(raw *apb.LogEntry_Raw) (*Line, error) {
+ if raw.OriginalLength < int64(len(raw.Data)) { // Data can be truncated, but never longer than the original line.
+ return nil, fmt.Errorf("original_length smaller than length of data")
+ }
+ originalLength := int(raw.OriginalLength)
+ if int64(originalLength) < raw.OriginalLength { // The conversion to int lost bits, ie. int is too small here.
+ return nil, fmt.Errorf("original_length larger than native int size")
+ }
+ return &Line{
+ Data: raw.Data,
+ OriginalLength: originalLength,
+ }, nil
+}
+
+// LineBuffer is an io.WriteCloser that will call a given callback every time a line is completed.
+type LineBuffer struct {
+ maxLineLength int // maximum number of bytes stored per line, anything beyond that is dropped
+ cb LineBufferCallback // called with every completed line
+
+ mu sync.Mutex // taken by Write and Close around the fields below
+ cur strings.Builder // stored data (at most maxLineLength bytes) of the line currently being written
+ // length is the length of the line currently being written - this will continue to increase, even if the string
+ // exceeds maxLineLength.
+ length int
+ closed bool // set by Close, makes subsequent Write calls fail
+}
+
+// LineBufferCallback is a callback that will get called any time the line is completed. It is called with the
+// LineBuffer's lock held, so it must not cause another write to the LineBuffer, or the program will deadlock.
+type LineBufferCallback func(*Line)
+
+// NewLineBuffer returns an empty, open LineBuffer that truncates lines to maxLineLength and reports them to cb.
+func NewLineBuffer(maxLineLength int, cb LineBufferCallback) *LineBuffer {
+	lb := new(LineBuffer)
+	lb.maxLineLength = maxLineLength
+	lb.cb = cb
+	return lb
+}
+
+// writeLimited appends data to the current line, counting all of it but storing at most maxLineLength bytes.
+func (l *LineBuffer) writeLimited(data []byte) {
+	l.length += len(data)
+	if room := l.maxLineLength - l.cur.Len(); len(data) > room {
+		data = data[:room]
+	}
+	l.cur.Write(data)
+}
+
+// commitLine hands the current line to the callback and resets the builder and length for the next line.
+func (l *LineBuffer) commitLine() {
+	// The callback runs synchronously, before the state is reset. Builder.Reset drops the builder's buffer
+	// instead of reusing it, so the Data string handed out here stays intact.
+	line := &Line{Data: l.cur.String(), OriginalLength: l.length}
+	l.cb(line)
+	l.cur.Reset()
+	l.length = 0
+}
+
+func (l *LineBuffer) Write(data []byte) (int, error) {
+ var pos = 0
+
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ if l.closed {
+ return 0, fmt.Errorf("closed")
+ }
+
+ for {
+ nextNewline := bytes.IndexRune(data[pos:], '\n')
+
+ // No newline in the data, write everything to the current line
+ if nextNewline == -1 {
+ l.writeLimited(data[pos:])
+ break
+ }
+
+ // Write this line and update position
+ l.writeLimited(data[pos : pos+nextNewline])
+ l.commitLine()
+ pos += nextNewline + 1
+
+ // Data ends with a newline, stop now without writing an empty line
+ if nextNewline == len(data)-1 {
+ break
+ }
+ }
+ return len(data), nil
+}
+
+// Close will emit any leftover data in the buffer to the callback. Subsequent calls to Write will fail. Subsequent calls to Close
+// will also fail. The closed flag is only checked with the lock held, so that it does not race with other calls.
+func (l *LineBuffer) Close() error {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	if l.closed {
+		return fmt.Errorf("already closed")
+	}
+	l.closed = true
+	if l.length > 0 {
+		l.commitLine()
+	}
+	return nil
+}
diff --git a/metropolis/node/common/logbuffer/linebuffer_test.go b/metropolis/node/common/logbuffer/linebuffer_test.go
new file mode 100644
index 0000000..c821a4b
--- /dev/null
+++ b/metropolis/node/common/logbuffer/linebuffer_test.go
@@ -0,0 +1,75 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logbuffer
+
+import (
+ "fmt"
+ "testing"
+)
+
+// TestLineBuffer exercises a LineBuffer with partial writes, completed lines and Close (flush, then errors).
+func TestLineBuffer(t *testing.T) {
+	var lines []*Line
+	lb := NewLineBuffer(1024, func(l *Line) {
+		lines = append(lines, l)
+	})
+	// compare returns a message if the emitted lines (got) differ from the wanted strings, or "" if they match.
+	compare := func(got []*Line, want ...string) string {
+		msg := fmt.Sprintf("want %v, got %v", want, got)
+		if len(got) != len(want) {
+			return msg
+		}
+		for i := range got {
+			if got[i].String() != want[i] {
+				return msg
+			}
+		}
+		return ""
+	}
+	// Write some data.
+	fmt.Fprintf(lb, "foo ")
+	if diff := compare(lines); diff != "" {
+		t.Fatal(diff)
+	}
+	fmt.Fprintf(lb, "bar\n")
+	if diff := compare(lines, "foo bar"); diff != "" {
+		t.Fatal(diff)
+	}
+	fmt.Fprintf(lb, "baz")
+	if diff := compare(lines, "foo bar"); diff != "" {
+		t.Fatal(diff)
+	}
+	fmt.Fprintf(lb, " baz")
+	if diff := compare(lines, "foo bar"); diff != "" {
+		t.Fatal(diff)
+	}
+	// Close and expect flush.
+	if err := lb.Close(); err != nil {
+		t.Fatalf("Close: %v", err)
+	}
+	if diff := compare(lines, "foo bar", "baz baz"); diff != "" {
+		t.Fatal(diff)
+	}
+
+	// Check behaviour after close
+	if _, err := fmt.Fprintf(lb, "nope"); err == nil {
+		t.Fatalf("Write after Close: wanted error, got nil")
+	}
+	if err := lb.Close(); err == nil {
+		t.Fatalf("second Close: wanted error, got nil")
+	}
+}
diff --git a/metropolis/node/common/logbuffer/logbuffer.go b/metropolis/node/common/logbuffer/logbuffer.go
new file mode 100644
index 0000000..ce47816
--- /dev/null
+++ b/metropolis/node/common/logbuffer/logbuffer.go
@@ -0,0 +1,97 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package logbuffer implements a fixed-size in-memory ring buffer for line-separated logs.
+// It implements io.Writer and splits the data into lines. The lines are kept in a ring where the
+// oldest are overwritten once it's full. It allows retrieval of the last n lines. There is a built-in
+// line length limit to bound the memory usage at maxLineLength * size.
+package logbuffer
+
+import (
+ "sync"
+)
+
+// LogBuffer implements a fixed-size in-memory ring buffer for line-separated logs.
+type LogBuffer struct {
+ mu sync.RWMutex // guards content and length
+ content []Line // ring storage: the i-th committed line (counting from 0) is stored at content[i%len(content)]
+ length int // number of lines committed so far, including ones that have since been overwritten
+ *LineBuffer // splits written data into lines and passes them to lineCallback
+}
+
+// New creates a new LogBuffer with a given ringbuffer size (which must be at least 1) and maximum line length.
+func New(size, maxLineLength int) *LogBuffer {
+ lb := &LogBuffer{
+ content: make([]Line, size),
+ }
+ lb.LineBuffer = NewLineBuffer(maxLineLength, lb.lineCallback) // Completed lines end up in the ring.
+ return lb
+}
+
+// lineCallback stores a completed line in the ring, overwriting the oldest stored line once the ring is full.
+func (b *LogBuffer) lineCallback(line *Line) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ b.content[b.length%len(b.content)] = *line
+ b.length++
+}
+
+// capToContentLength caps the number of requested lines to what is actually available
+func (b *LogBuffer) capToContentLength(n int) int {
+	// At most as many lines are available as fit into the ring, and fewer if not that many have been written.
+	available := len(b.content)
+	if b.length < available {
+		available = b.length
+	}
+	if n > available {
+		return available
+	}
+	return n
+}
+
+// ReadLines returns copies of the last n lines from the buffer, oldest first. If n is bigger than the ring buffer
+// or than the number of lines written so far, only the lines that are actually stored are returned.
+func (b *LogBuffer) ReadLines(n int) []Line {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+
+	n = b.capToContentLength(n)
+
+	// The oldest requested line has sequence number length-n, walk forward in the ring from there.
+	outArray := make([]Line, n)
+	for i := range outArray {
+		outArray[i] = b.content[(b.length-n+i)%len(b.content)]
+	}
+	return outArray
+}
+
+// ReadLinesTruncated works exactly the same as ReadLines, but returns strings, and appends the given ellipsis to
+// every line that was truncated because it was over the maximum line length.
+func (b *LogBuffer) ReadLinesTruncated(n int, ellipsis string) []string {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+	// ReadLines() is not used here, to avoid copying every Line when a lot of lines are processed.
+	n = b.capToContentLength(n)
+	outArray := make([]string, n)
+	for i := 1; i <= n; i++ {
+		line := &b.content[(b.length-i)%len(b.content)]
+		outArray[n-i] = line.Data
+		if line.Truncated() {
+			outArray[n-i] += ellipsis
+		}
+	}
+	return outArray
+}
diff --git a/metropolis/node/common/logbuffer/logbuffer_test.go b/metropolis/node/common/logbuffer/logbuffer_test.go
new file mode 100644
index 0000000..c38d7a6
--- /dev/null
+++ b/metropolis/node/common/logbuffer/logbuffer_test.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logbuffer
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+// TestSingleLine checks that a single newline-terminated write is read back as one line, without the newline.
+func TestSingleLine(t *testing.T) {
+ buf := New(1, 16000)
+ buf.Write([]byte("Hello World\n"))
+ out := buf.ReadLines(1)
+ require.Len(t, out, 1, "Invalid number of lines read")
+ require.Equal(t, "Hello World", out[0].Data, "Read bad log line")
+ require.Equal(t, 11, out[0].OriginalLength, "Invalid line length")
+}
+
+// TestPartialWritesAndReads checks that lines split across writes are reassembled, and that fewer lines can be read.
+func TestPartialWritesAndReads(t *testing.T) {
+ buf := New(2, 16000)
+ buf.Write([]byte("Hello "))
+ buf.Write([]byte("World\nTest "))
+ buf.Write([]byte("2\n"))
+ out := buf.ReadLines(1)
+ require.Len(t, out, 1, "Invalid number of lines for partial read")
+ require.Equal(t, "Test 2", out[0].Data, "Read bad log line")
+
+ out2 := buf.ReadLines(2)
+ require.Len(t, out2, 2, "Invalid number of lines read")
+ require.Equal(t, "Hello World", out2[0].Data, "Read bad log line")
+ require.Equal(t, "Test 2", out2[1].Data, "Read bad log line")
+}
+
+// TestBufferOverwrite checks that once the ring is full the oldest line is overwritten, and order is kept.
+func TestBufferOverwrite(t *testing.T) {
+	buf := New(3, 16000)
+	buf.Write([]byte("Test1\nTest2\nTest3\nTest4\n"))
+	out := buf.ReadLines(3)
+	require.Equal(t, "Test2", out[0].Data, "Read bad log line")
+	require.Equal(t, "Test3", out[1].Data, "Read bad log line")
+	require.Equal(t, "Test4", out[2].Data, "Overwritten data is invalid")
+}
+
+// TestTooLargeRequests checks that reads are capped to the lines written so far and to the size of the ring.
+func TestTooLargeRequests(t *testing.T) {
+ buf := New(1, 16000)
+ outEmpty := buf.ReadLines(1)
+ require.Len(t, outEmpty, 0, "Returned more data than there is")
+ buf.Write([]byte("1\n2\n"))
+ out := buf.ReadLines(2)
+ require.Len(t, out, 1, "Returned more data than the ring buffer can hold")
+}
+// TestSpecialCases checks that a write starting with a newline completes the line left by the previous write.
+func TestSpecialCases(t *testing.T) {
+	buf := New(2, 16000)
+	buf.Write([]byte("Test1"))
+	buf.Write([]byte("\nTest2\n"))
+	out := buf.ReadLines(2)
+	require.Len(t, out, 2, "Too many lines written")
+	require.Equal(t, "Test1", out[0].Data, "Read bad log line")
+	require.Equal(t, "Test2", out[1].Data, "Read bad log line")
+}
+
+// TestLineLengthLimit checks that overlong lines are truncated to the maximum line length while their original
+// length is kept, and that ReadLinesTruncated marks only the truncated line.
+func TestLineLengthLimit(t *testing.T) {
+	buf := New(2, 6)
+	testStr := "Just Testing"
+	buf.Write([]byte(testStr + "\nShort\n"))
+
+	out := buf.ReadLines(2)
+	require.Equal(t, len(testStr), out[0].OriginalLength, "Line is over length limit")
+	require.Equal(t, "Just T", out[0].Data, "Log line not properly truncated")
+
+	out2 := buf.ReadLinesTruncated(2, "...")
+	require.Equal(t, "Just T...", out2[0], "Line is over length limit")
+	require.Equal(t, "Short", out2[1], "Truncated small enough line")
+}
diff --git a/metropolis/node/common/supervisor/BUILD.bazel b/metropolis/node/common/supervisor/BUILD.bazel
new file mode 100644
index 0000000..ae95892
--- /dev/null
+++ b/metropolis/node/common/supervisor/BUILD.bazel
@@ -0,0 +1,28 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "supervisor.go",
+ "supervisor_node.go",
+ "supervisor_processor.go",
+ "supervisor_support.go",
+ "supervisor_testhelpers.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor",
+ visibility = [
+ "//metropolis/node:__subpackages__",
+ "//metropolis/test:__subpackages__",
+ ],
+ deps = [
+ "//metropolis/node/core/logtree:go_default_library",
+ "@com_github_cenkalti_backoff_v4//:go_default_library",
+ "@org_golang_google_grpc//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "go_default_test",
+ srcs = ["supervisor_test.go"],
+ embed = [":go_default_library"],
+)
diff --git a/metropolis/node/common/supervisor/supervisor.go b/metropolis/node/common/supervisor/supervisor.go
new file mode 100644
index 0000000..df7492c
--- /dev/null
+++ b/metropolis/node/common/supervisor/supervisor.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+// The service supervision library allows for writing of reliable, service-style software within Smalltown.
+// It builds upon the Erlang/OTP supervision tree system, adapted to be more Go-ish.
+// For detailed design see go/supervision.
+
+import (
+ "context"
+ "io"
+ "sync"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+)
+
+// A Runnable is a function that will be run in a goroutine, and supervised throughout its lifetime. It can in turn
+// start more runnables as its children (via Run or RunGroup), and those will form part of a supervision tree.
+// The context passed to a runnable is very important and needs to be handled properly. It will be live (non-errored) as
+// long as the runnable should be running, and canceled (ctx.Err() will be non-nil) when the supervisor wants it to
+// exit. This means this context is also perfectly usable for performing any blocking operations.
+type Runnable func(ctx context.Context) error
+
+// RunGroup starts a set of runnables as a group. These runnables will run together, and if any one of them quits
+// unexpectedly, the rest will be canceled and the whole group restarted.
+// The context here must be an existing Runnable context (RunGroup panics otherwise), and the spawned runnables
+// will run under the node that this context represents.
+func RunGroup(ctx context.Context, runnables map[string]Runnable) error {
+ node, unlock := fromContext(ctx)
+ defer unlock()
+ return node.runGroup(runnables)
+}
+
+// Run starts a single runnable in its own group, ie. a group of which it is the only member. Like RunGroup, it
+// must be called with an existing Runnable context.
+func Run(ctx context.Context, name string, runnable Runnable) error {
+	group := map[string]Runnable{name: runnable}
+	return RunGroup(ctx, group)
+}
+
+// Signal tells the supervisor that the calling runnable has reached a certain state of its lifecycle. All runnables
+// should SignalHealthy when they are done with setup, have started any child runnables and are now 'serving'.
+func Signal(ctx context.Context, signal SignalType) {
+ node, unlock := fromContext(ctx) // Panics if ctx is not a Runnable context.
+ defer unlock()
+ node.signal(signal)
+}
+
+// SignalType is the kind of lifecycle state that a runnable reports to the supervisor via Signal.
+type SignalType int
+const (
+ // The runnable is healthy, done with setup, done with spawning more Runnables, and ready to serve in a loop.
+ // The runnable needs to check the parent context and ensure that if that context is done, the runnable exits.
+ SignalHealthy SignalType = iota
+ // The runnable is done - it does not need to run any loop (eg. it only sets up child runnables). It must have
+ // signaled SignalHealthy before. It will be restarted if a related failure happens in the supervision tree.
+ SignalDone
+)
+
+// supervisor represents an instance of the supervision system. It keeps track of a supervision tree and a request
+// channel to its internal processor goroutine.
+type supervisor struct {
+ // mu guards the entire state of the supervisor.
+ mu sync.RWMutex
+ // root is the root node of the supervision tree, named 'root'. It represents the Runnable started with the
+ // supervisor.New call.
+ root *node
+ // logtree is the main logtree exposed to runnables and used internally.
+ logtree *logtree.LogTree
+ // ilogger is the internal logger logging to "supervisor" in the logtree.
+ ilogger logtree.LeveledLogger
+
+ // pReq is an interface channel to the lifecycle processor of the supervisor.
+ pReq chan *processorRequest
+
+ // propagatePanic makes the supervisor propagate panics from runnables, ie. not catch them.
+ propagatePanic bool
+}
+
+// SupervisorOpt is a runtime configurable option for the supervisor, applied to it by New.
+type SupervisorOpt func(s *supervisor)
+
+var (
+ // WithPropagatePanic prevents the Supervisor from catching panics in runnables and treating them as failures.
+ // This is useful to enable for testing and local debugging.
+ WithPropagatePanic = func(s *supervisor) {
+ s.propagatePanic = true
+ }
+)
+// WithExistingLogtree makes the supervisor use the given logtree instead of the one that New creates by default.
+func WithExistingLogtree(lt *logtree.LogTree) SupervisorOpt {
+ return func(s *supervisor) {
+ s.logtree = lt
+ }
+}
+
+// New creates a new supervisor with its root running the given root runnable.
+// The given context can be used to cancel the entire supervision tree.
+func New(ctx context.Context, rootRunnable Runnable, opts ...SupervisorOpt) *supervisor {
+ sup := &supervisor{
+ logtree: logtree.New(),
+ pReq: make(chan *processorRequest),
+ }
+ // Options are applied on top of the defaults above, eg. WithExistingLogtree replaces the fresh logtree.
+ for _, o := range opts {
+ o(sup)
+ }
+ // The internal logger is only derived now, so that it ends up in a logtree that was set by an option.
+ sup.ilogger = sup.logtree.MustLeveledFor("supervisor")
+ sup.root = newNode("root", rootRunnable, sup, nil)
+
+ go sup.processor(ctx)
+ // Request that the root runnable be started. pReq is unbuffered, so this blocks until the request is received.
+ sup.pReq <- &processorRequest{
+ schedule: &processorRequestSchedule{dn: "root"},
+ }
+
+ return sup
+}
+// Logger returns a leveled logger from the supervisor's logtree, at the DN of the runnable that ctx belongs to.
+func Logger(ctx context.Context) logtree.LeveledLogger {
+ node, unlock := fromContext(ctx)
+ defer unlock()
+ return node.sup.logtree.MustLeveledFor(logtree.DN(node.dn()))
+}
+// RawLogger returns a raw io.Writer from the supervisor's logtree, at the DN of the runnable that ctx belongs to.
+func RawLogger(ctx context.Context) io.Writer {
+ node, unlock := fromContext(ctx)
+ defer unlock()
+ return node.sup.logtree.MustRawFor(logtree.DN(node.dn()))
+}
diff --git a/metropolis/node/common/supervisor/supervisor_node.go b/metropolis/node/common/supervisor/supervisor_node.go
new file mode 100644
index 0000000..a7caf82
--- /dev/null
+++ b/metropolis/node/common/supervisor/supervisor_node.go
@@ -0,0 +1,282 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "strings"
+
+ "github.com/cenkalti/backoff/v4"
+)
+
+// node is a supervision tree node. It represents the state of a Runnable within this tree, its relation to other tree
+// elements, and contains supporting data needed to actually supervise it.
+type node struct {
+ // The name of this node. Opaque string. It's used to make up the 'dn' (distinguished name) of a node within
+ // the tree. When starting a runnable inside a tree, this is where that name gets used.
+ name string
+ runnable Runnable // The function that gets run, and supervised, for this node.
+
+ // The supervisor managing this tree.
+ sup *supervisor
+ // The parent, within the tree, of this node. If this is the root node of the tree, this is nil.
+ parent *node
+ // Children of this node. This is represented by a map keyed from child node names, for easy access.
+ children map[string]*node
+ // Supervision groups. Each group is a set of names of children. Sets, and as such groups, don't overlap between
+ // each other. A supervision group indicates that if any child within that group fails, all others should be
+ // canceled and restarted together.
+ groups []map[string]bool
+
+ // The current state of the runnable in this node.
+ state nodeState
+
+ // Backoff used to keep runnables from being restarted too fast.
+ bo *backoff.ExponentialBackOff
+
+ // Context passed to the runnable, and its cancel function.
+ ctx context.Context
+ ctxC context.CancelFunc
+}
+
+// nodeState is the state of a runnable within a node, and in a way the node itself.
+// This follows the state diagram from go/supervision.
+type nodeState int
+
+const (
+ // A node that has just been created, and whose runnable has been started already but hasn't signaled anything yet.
+ nodeStateNew nodeState = iota
+ // A node whose runnable has signaled being healthy - this means it's ready to serve/act.
+ nodeStateHealthy
+ // A node that has unexpectedly returned or panicked.
+ nodeStateDead
+ // A node that has declared that it's done with its work and should not be restarted, unless a supervision tree
+ // failure requires that.
+ nodeStateDone
+ // A node that has returned after being requested to cancel.
+ nodeStateCanceled
+)
+
+// String returns the NODE_STATE_* name of the state, or UNKNOWN for values outside of the defined states.
+func (s nodeState) String() string {
+	// Indexed by state value.
+	names := [...]string{
+		nodeStateNew:      "NODE_STATE_NEW",
+		nodeStateHealthy:  "NODE_STATE_HEALTHY",
+		nodeStateDead:     "NODE_STATE_DEAD",
+		nodeStateDone:     "NODE_STATE_DONE",
+		nodeStateCanceled: "NODE_STATE_CANCELED",
+	}
+	if s < 0 || int(s) >= len(names) {
+		return "UNKNOWN"
+	}
+	return names[s]
+}
+// String returns a human-readable description of the node, made up of its DN and its current state.
+func (n *node) String() string {
+ return fmt.Sprintf("%s (%s)", n.dn(), n.state.String())
+}
+
+// contextKey is a type used to keep data within context values.
+type contextKey string
+
+var (
+ supervisorKey = contextKey("supervisor") // key of the *supervisor that a runnable context belongs to
+ dnKey = contextKey("dn") // key of the DN (a string) of the node that a runnable context belongs to
+)
+
+// fromContext retrieves a tree node from a runnable context, and panics if ctx is not one. It takes a lock on the tree
+// and returns an unlock function, which needs to be called once mutations on the tree/supervisor/node are done.
+func fromContext(ctx context.Context) (*node, func()) {
+ sup, ok := ctx.Value(supervisorKey).(*supervisor)
+ if !ok {
+ panic("supervisor function called from non-runnable context")
+ }
+
+ sup.mu.Lock()
+
+ dnParent, ok := ctx.Value(dnKey).(string)
+ if !ok {
+ sup.mu.Unlock()
+ panic("supervisor function called from non-runnable context")
+ }
+ // The lock is kept on success: releasing it is up to the caller, via the returned unlock function.
+ return sup.nodeByDN(dnParent), sup.mu.Unlock
+}
+
+// All the following 'internal' supervisor functions must only be called with the supervisor lock taken. Getting a lock
+// via fromContext is enough.
+
+// dn returns the distinguished name of a node: the period-separated names of all nodes on the path from the root
+// to this node, eg. 'root.bar.foo' for the runnable 'foo' within the runnable 'bar'. The root's dn is 'root'.
+func (n *node) dn() string {
+	name := n.name
+	for p := n.parent; p != nil; p = p.parent {
+		name = p.name + "." + name
+	}
+	return name
+}
+
+// groupSiblings returns the supervision group (as a set of child names) of this node that the given runnable name
+// is a member of, or nil if no group of this node contains that name.
+func (n *node) groupSiblings(name string) map[string]bool {
+	for i := range n.groups {
+		if _, member := n.groups[i][name]; member {
+			return n.groups[i]
+		}
+	}
+	return nil
+}
+
+// newNode creates a new node with a given parent. It does not register it with the parent (as that depends on group
+// placement) - runGroup does that.
+func newNode(name string, runnable Runnable, sup *supervisor, parent *node) *node {
+ // We use exponential backoff for failed runnables, but at some point we cap at a given backoff time.
+ // To achieve this, we set MaxElapsedTime to 0, which will cap the backoff at MaxInterval.
+ bo := backoff.NewExponentialBackOff()
+ bo.MaxElapsedTime = 0
+
+ n := &node{
+ name: name,
+ runnable: runnable,
+
+ bo: bo,
+
+ sup: sup,
+ parent: parent,
+ }
+ n.reset() // Sets up the node's context, state (NEW), children and groups.
+ return n
+}
+
+// reset sets up all the dynamic fields of the node, in preparation of starting a runnable. It clears the node's
+// children and groups, sets its state to NEW, and replaces its context (without canceling the previous one).
+func (n *node) reset() {
+ // Make new context. First, acquire parent context. For the root node that's Background, otherwise it's the
+ // parent's context.
+ var pCtx context.Context
+ if n.parent == nil {
+ pCtx = context.Background()
+ } else {
+ pCtx = n.parent.ctx
+ }
+ // Mark DN and supervisor in context.
+ ctx := context.WithValue(pCtx, dnKey, n.dn())
+ ctx = context.WithValue(ctx, supervisorKey, n.sup)
+ ctx, ctxC := context.WithCancel(ctx)
+ // Set context
+ n.ctx = ctx
+ n.ctxC = ctxC
+
+ // Clear children and state
+ n.state = nodeStateNew
+ n.children = make(map[string]*node)
+ n.groups = nil
+
+ // The node is now ready to be scheduled.
+}
+
+// nodeByDN returns a node by given DN from the supervisor. It walks the tree from the root, one DN element at a
+// time, and panics if the DN does not start at the root or if it names a node that does not exist.
+func (s *supervisor) nodeByDN(dn string) *node {
+	elements := strings.Split(dn, ".")
+	if elements[0] != "root" {
+		panic("DN does not start with root.")
+	}
+
+	// remaining always holds the DN elements that still have to be descended into. Once there are none left,
+	// cur is the node that was asked for.
+	cur := s.root
+	for remaining := elements[1:]; len(remaining) > 0; remaining = remaining[1:] {
+		child, ok := cur.children[remaining[0]]
+		if !ok {
+			panic(fmt.Errorf("could not find %v (%s) in %s", remaining, dn, cur))
+		}
+		cur = child
+	}
+
+	return cur
+}
+
+// reNodeName validates a node name. It is unanchored: a name passes as soon as it contains a character of [a-z0-9_].
+var reNodeName = regexp.MustCompile(`[a-z0-9_]{1,64}`)
+
+// runGroup schedules a new group of runnables to run on a node. It fails if the node is not in the NEW state, or if
+// any of the runnable names is invalid or already used by a child of this node.
+func (n *node) runGroup(runnables map[string]Runnable) error {
+	// Check that the parent node is in the right state.
+	if n.state != nodeStateNew {
+		return fmt.Errorf("cannot run new runnable on non-NEW node")
+	}
+
+	// Check the requested runnable names.
+	for name := range runnables {
+		if !reNodeName.MatchString(name) {
+			return fmt.Errorf("runnable name %q is invalid", name)
+		}
+		if _, ok := n.children[name]; ok {
+			return fmt.Errorf("runnable %q already exists", name)
+		}
+	}
+
+	// Create child nodes.
+	dns := make([]string, 0, len(runnables))
+	group := make(map[string]bool, len(runnables))
+	for name, runnable := range runnables {
+		if g := n.groupSiblings(name); g != nil {
+			return fmt.Errorf("duplicate child name %q", name)
+		}
+		node := newNode(name, runnable, n.sup, n)
+		n.children[name] = node
+		dns = append(dns, node.dn())
+		group[name] = true
+	}
+	// Add group.
+	n.groups = append(n.groups, group)
+
+	// Schedule execution of group members. This happens in a goroutine, as every send blocks until the request is
+	// received. The goroutine only uses the DNs collected above, not the caller's runnables map, which the caller
+	// is free to modify once runGroup has returned.
+	go func() {
+		for _, dn := range dns {
+			n.sup.pReq <- &processorRequest{
+				schedule: &processorRequestSchedule{dn: dn},
+			}
+		}
+	}()
+	return nil
+}
+
+// signal applies a signal received from the node's runnable to the node's state, and panics on out-of-order signals.
+func (n *node) signal(signal SignalType) {
+ switch signal {
+ case SignalHealthy:
+ if n.state != nodeStateNew { // Healthy may only be signaled once, by a runnable that is still NEW.
+ panic(fmt.Errorf("node %s signaled healthy", n))
+ }
+ n.state = nodeStateHealthy
+ n.bo.Reset()
+ case SignalDone:
+ if n.state != nodeStateHealthy { // Done may only be signaled after having signaled healthy.
+ panic(fmt.Errorf("node %s signaled done", n))
+ }
+ n.state = nodeStateDone
+ n.bo.Reset()
+ }
+}
diff --git a/metropolis/node/common/supervisor/supervisor_processor.go b/metropolis/node/common/supervisor/supervisor_processor.go
new file mode 100644
index 0000000..965a667
--- /dev/null
+++ b/metropolis/node/common/supervisor/supervisor_processor.go
@@ -0,0 +1,404 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "runtime/debug"
+ "time"
+)
+
+// The processor maintains runnable goroutines - ie., when requested it will start one, and then once it exits it will
+// record the result and act accordingly. It is also responsible for detecting and acting upon supervision subtrees that
+// need to be restarted after death (via a 'GC' process).
+
+// processorRequest is a request for the processor. Only one of the fields can be set.
+type processorRequest struct {
+ // schedule asks the processor to start a node's runnable.
+ schedule *processorRequestSchedule
+ // died reports that a runnable goroutine has returned or panicked.
+ died *processorRequestDied
+ // waitSettled registers a waiter that is notified once the supervision tree has settled.
+ waitSettled *processorRequestWaitSettled
+}
+
+// processorRequestSchedule requests that a given node's runnable be started.
+type processorRequestSchedule struct {
+ // dn is the DN of the node to start.
+ dn string
+}
+
+// processorRequestDied is a signal from a runnable goroutine that the runnable has died.
+type processorRequestDied struct {
+ // dn is the DN of the node whose runnable stopped.
+ dn string
+ // err is the value returned by the runnable (nil if it returned without error), or a wrapped panic.
+ err error
+}
+
+// processorRequestWaitSettled asks the processor to close waiter once the GC has stayed clean for a number of
+// consecutive cycles, ie. once the supervision tree is considered settled.
+type processorRequestWaitSettled struct {
+ waiter chan struct{}
+}
+
+// processor is the main processing loop. It serves requests from s.pReq, runs the GC on a millisecond tick whenever
+// the tree has changed since the previous run, releases 'settled' waiters after enough quiet ticks, and kills the
+// whole tree before returning when ctx is done.
+func (s *supervisor) processor(ctx context.Context) {
+	s.ilogger.Info("supervisor processor started")
+
+	// Channels of callers waiting for the tree to settle.
+	var settleWaiters []chan struct{}
+
+	// The GC is considered every millisecond, but only actually runs when something changed in the tree (a death
+	// or a newly scheduled runnable) since the last tick.
+	ticker := time.NewTicker(1 * time.Millisecond)
+	defer ticker.Stop()
+
+	// dirty records that the tree changed since the last GC run. quietTicks counts the ticks seen since the last
+	// change and is used to decide when to notify 'settled' waiters.
+	dirty := false
+	quietTicks := 0
+
+	for {
+		select {
+		case <-ctx.Done():
+			s.ilogger.Infof("supervisor processor exiting: %v", ctx.Err())
+			s.processKill()
+			s.ilogger.Info("supervisor exited")
+			return
+
+		case <-ticker.C:
+			if dirty {
+				s.processGC()
+				dirty = false
+			}
+			quietTicks++
+
+			// This threshold is somewhat arbitrary. It's a balance between test speed and test reliability.
+			if quietTicks > 50 {
+				for _, w := range settleWaiters {
+					close(w)
+				}
+				settleWaiters = nil
+			}
+
+		case req := <-s.pReq:
+			switch {
+			case req.schedule != nil:
+				s.processSchedule(req.schedule)
+				dirty, quietTicks = true, 0
+			case req.died != nil:
+				s.processDied(req.died)
+				dirty, quietTicks = true, 0
+			case req.waitSettled != nil:
+				settleWaiters = append(settleWaiters, req.waitSettled.waiter)
+			default:
+				panic(fmt.Errorf("unhandled request %+v", req))
+			}
+		}
+	}
+}
+
+// processKill cancels the context of every node in the supervision tree. It is only called right before the
+// processor exits, so the canceled nodes are not restarted.
+func (s *supervisor) processKill() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Walk the tree breadth-first from the root, collecting every node's cancel function.
+	var cancels []func()
+	for pending := []*node{s.root}; len(pending) > 0; {
+		cur := pending[0]
+		pending = pending[1:]
+
+		cancels = append(cancels, cur.ctxC)
+		for _, child := range cur.children {
+			pending = append(pending, child)
+		}
+	}
+
+	// Only once the whole tree has been walked are the contexts actually canceled.
+	for _, cancel := range cancels {
+		cancel()
+	}
+}
+
+// processSchedule starts a node's runnable in its own goroutine. When the runnable returns, its result is sent back
+// to the processor as a 'died' request. Unless s.propagatePanic is set, a panic in the runnable is recovered and
+// reported the same way, with the stack trace embedded in the error.
+func (s *supervisor) processSchedule(r *processorRequestSchedule) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	n := s.nodeByDN(r.dn)
+
+	// reportDeath tells the processor that the runnable behind r.dn stopped with the given result.
+	reportDeath := func(err error) {
+		s.pReq <- &processorRequest{
+			died: &processorRequestDied{dn: r.dn, err: err},
+		}
+	}
+
+	go func() {
+		if !s.propagatePanic {
+			defer func() {
+				rec := recover()
+				if rec == nil {
+					return
+				}
+				reportDeath(fmt.Errorf("panic: %v, stacktrace: %s", rec, string(debug.Stack())))
+			}()
+		}
+
+		reportDeath(n.runnable(n.ctx))
+	}()
+}
+
+// processDied records the result from a runnable goroutine, and updates its node state accordingly. If the result
+// is a death and not an expected exit, related nodes (ie. children and group siblings) are canceled accordingly.
+func (s *supervisor) processDied(r *processorRequestDied) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Okay, so a Runnable has quit. What now?
+ n := s.nodeByDN(r.dn)
+ ctx := n.ctx
+
+ // Simple case: it was marked as Done and quit with no error.
+ if n.state == nodeStateDone && r.err == nil {
+ // Do nothing. This was supposed to happen. Keep the process as DONE.
+ return
+ }
+
+ // Find innermost error to check if it's a context canceled error.
+ perr := r.err
+ for {
+ if inner := errors.Unwrap(perr); inner != nil {
+ perr = inner
+ continue
+ }
+ break
+ }
+
+ // Simple case: the context was canceled and the returned error is the context error.
+ if err := ctx.Err(); err != nil && perr == err {
+ // Mark the node as canceled successfully.
+ n.state = nodeStateCanceled
+ return
+ }
+
+ // Otherwise, the Runnable should not have died or quit. Handle accordingly.
+ err := r.err
+ // A lack of returned error is also an error.
+ if err == nil {
+ err = fmt.Errorf("returned when %s", n.state)
+ } else {
+ err = fmt.Errorf("returned error when %s: %w", n.state, err)
+ }
+
+ s.ilogger.Errorf("Runnable %s died: %v", n.dn(), err)
+ // Mark as dead.
+ n.state = nodeStateDead
+
+ // Cancel that node's context, just in case something still depends on it.
+ n.ctxC()
+
+ // Cancel all siblings.
+ if n.parent != nil {
+ for name, _ := range n.parent.groupSiblings(n.name) {
+ if name == n.name {
+ continue
+ }
+ sibling := n.parent.children[name]
+ // TODO(q3k): does this need to run in a goroutine, ie. can a context cancel block?
+ sibling.ctxC()
+ }
+ }
+}
+
+// processGC runs the GC process. It's not really Garbage Collection, as in, it doesn't remove unnecessary tree nodes -
+// but it does find nodes that need to be restarted, find the subset that can and then schedules them for running.
+// As such, it's less of a Garbage Collector and more of a Necromancer. However, GC is a friendlier name.
+//
+// The whole pass runs with s.mu held. Rescheduling itself is asynchronous: for every restartable subtree root a
+// goroutine is started that sleeps for the node's backoff and then sends a schedule request to the processor.
+func (s *supervisor) processGC() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // The 'GC' is the main business logic of the supervision tree. It traverses a locked tree and tries to
+ // find subtrees that must be restarted (because of a DEAD/CANCELED runnable). It then finds which of these
+ // subtrees that should be restarted can be restarted, ie. which ones are fully recursively DEAD/CANCELED. It
+ // also finds the smallest set of largest subtrees that can be restarted, ie. if there's multiple DEAD runnables
+ // that can be restarted at once, it will do so.
+
+ // Phase one: Find all leaves.
+ // This is a simple DFS that finds all the leaves of the tree, ie all nodes that do not have children nodes.
+ leaves := make(map[string]bool)
+ queue := []*node{s.root}
+ for {
+ if len(queue) == 0 {
+ break
+ }
+ cur := queue[0]
+ queue = queue[1:]
+
+ // Children are pushed to the front of the queue, which is what makes this traversal depth-first.
+ for _, c := range cur.children {
+ queue = append([]*node{c}, queue...)
+ }
+
+ if len(cur.children) == 0 {
+ leaves[cur.dn()] = true
+ }
+ }
+
+ // Phase two: traverse tree from the leaves to the root and make note of all subtrees that can be restarted.
+ // A subtree is restartable/ready iff every node in that subtree is either CANCELED, DEAD or DONE.
+ // Such a 'ready' subtree can be restarted by the supervisor if needed.
+
+ // DNs that we already visited.
+ visited := make(map[string]bool)
+ // DNs whose subtrees are ready to be restarted.
+ // These are all subtrees recursively - ie., root.a.a and root.a will both be marked here.
+ ready := make(map[string]bool)
+
+ // We build a queue of nodes to visit, starting from the leaves.
+ queue = []*node{}
+ for l, _ := range leaves {
+ queue = append(queue, s.nodeByDN(l))
+ }
+
+ for {
+ if len(queue) == 0 {
+ break
+ }
+
+ cur := queue[0]
+ curDn := cur.dn()
+
+ queue = queue[1:]
+
+ // Do we have a decision about our children?
+ allVisited := true
+ for _, c := range cur.children {
+ if !visited[c.dn()] {
+ allVisited = false
+ break
+ }
+ }
+
+ // If no decision about children is available, it means we ended up in this subtree through some shorter path
+ // of a shorter/lower-order leaf. There is a path to a leaf that's longer than the one that caused this node
+ // to be enqueued. Easy solution: just push back the current element and retry later.
+ if !allVisited {
+ // Push back to queue and wait for a decision later.
+ queue = append(queue, cur)
+ continue
+ }
+
+ // All children have been visited and we have an idea about whether they're ready/restartable. All of the node's
+ // children must be restartable in order for this node to be restartable.
+ childrenReady := true
+ for _, c := range cur.children {
+ if !ready[c.dn()] {
+ childrenReady = false
+ break
+ }
+ }
+
+ // In addition to children, the node itself must be restartable (ie. DONE, DEAD or CANCELED).
+ curReady := false
+ switch cur.state {
+ case nodeStateDone:
+ curReady = true
+ case nodeStateCanceled:
+ curReady = true
+ case nodeStateDead:
+ curReady = true
+ }
+
+ // Note down that we have an opinion on this node, and note that opinion down.
+ visited[curDn] = true
+ ready[curDn] = childrenReady && curReady
+
+ // Now we can also enqueue the parent of this node for processing.
+ if cur.parent != nil && !visited[cur.parent.dn()] {
+ queue = append(queue, cur.parent)
+ }
+ }
+
+ // Phase 3: traverse tree from root to find largest subtrees that need to be restarted and are ready to be
+ // restarted.
+
+ // All DNs that need to be restarted by the GC process.
+ want := make(map[string]bool)
+ // All DNs that need to be restarted and can be restarted by the GC process - a subset of 'want' DNs.
+ can := make(map[string]bool)
+ // The set difference between 'want' and 'can' are all nodes that should be restarted but can't yet (ie. because
+ // a child is still in the process of being canceled).
+
+ // Breadth-first traversal from root (children are appended to the back of the queue).
+ queue = []*node{s.root}
+ for {
+ if len(queue) == 0 {
+ break
+ }
+
+ cur := queue[0]
+ queue = queue[1:]
+
+ // If this node is DEAD or CANCELED it should be restarted.
+ if cur.state == nodeStateDead || cur.state == nodeStateCanceled {
+ want[cur.dn()] = true
+ }
+
+ // If it should be restarted and is ready to be restarted...
+ if want[cur.dn()] && ready[cur.dn()] {
+ // And its parent context is valid (ie hasn't been canceled), mark it as restartable.
+ // The subtree below it is not descended into: restarting this node resets its children as well.
+ if cur.parent == nil || cur.parent.ctx.Err() == nil {
+ can[cur.dn()] = true
+ continue
+ }
+ }
+
+ // Otherwise, traverse further down the tree to see if something else needs to be done.
+ for _, c := range cur.children {
+ queue = append(queue, c)
+ }
+ }
+
+ // Reinitialize and reschedule all subtrees
+ for dn, _ := range can {
+ n := s.nodeByDN(dn)
+
+ // Only back off when the node unexpectedly died - not when it got canceled.
+ // The backoff must be read before n.reset() below, which sets the state back to new.
+ bo := time.Duration(0)
+ if n.state == nodeStateDead {
+ bo = n.bo.NextBackOff()
+ }
+
+ // Prepare node for rescheduling - remove its children, reset its state to new.
+ n.reset()
+ s.ilogger.Infof("rescheduling supervised node %s with backoff %s", dn, bo.String())
+
+ // Reschedule node runnable to run after backoff.
+ // n and bo are passed as arguments so that each goroutine works on its own copies of the loop variables.
+ go func(n *node, bo time.Duration) {
+ time.Sleep(bo)
+ s.pReq <- &processorRequest{
+ schedule: &processorRequestSchedule{dn: n.dn()},
+ }
+ }(n, bo)
+ }
+}
diff --git a/metropolis/node/common/supervisor/supervisor_support.go b/metropolis/node/common/supervisor/supervisor_support.go
new file mode 100644
index 0000000..d54b35c
--- /dev/null
+++ b/metropolis/node/common/supervisor/supervisor_support.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+// Supporting infrastructure to allow running some non-Go payloads under supervision.
+
+import (
+ "context"
+ "net"
+ "os/exec"
+
+ "google.golang.org/grpc"
+)
+
+// GRPCServer creates a Runnable that serves gRPC requests as long as it's not canceled.
+// If graceful is set to true, the server will be gracefully stopped instead of plain stopped. This means all pending
+// RPCs will finish, but also requires streaming gRPC handlers to check their context liveliness and exit accordingly.
+// If the server code does not support this, `graceful` should be false and the server will be killed violently instead.
+//
+// The returned Runnable signals healthy, then returns either the context error (after stopping the server on
+// cancellation) or whatever srv.Serve returned.
+func GRPCServer(srv *grpc.Server, lis net.Listener, graceful bool) Runnable {
+	return func(ctx context.Context) error {
+		Signal(ctx, SignalHealthy)
+		// Buffered with room for one result, so that the Serve goroutine can always deliver its return value and
+		// exit, even when this runnable has already returned because of a context cancellation. With an unbuffered
+		// channel that goroutine would block on the send forever.
+		errC := make(chan error, 1)
+		go func() {
+			errC <- srv.Serve(lis)
+		}()
+		select {
+		case <-ctx.Done():
+			if graceful {
+				srv.GracefulStop()
+			} else {
+				srv.Stop()
+			}
+			return ctx.Err()
+		case err := <-errC:
+			return err
+		}
+	}
+}
+
+// RunCommand runs a long-running command under supervision. It signals healthy, attaches the command's stdout and
+// stderr to the raw logger of ctx, runs the command to completion, then logs and returns the result of cmd.Run.
+// NOTE(review): ctx is not used here to stop the command; presumably callers build cmd so that it follows the
+// context (eg. via exec.CommandContext) - confirm at call sites.
+func RunCommand(ctx context.Context, cmd *exec.Cmd) error {
+	Signal(ctx, SignalHealthy)
+	cmd.Stdout, cmd.Stderr = RawLogger(ctx), RawLogger(ctx)
+	runErr := cmd.Run()
+	Logger(ctx).Infof("Command returned: %v", runErr)
+	return runErr
+}
diff --git a/metropolis/node/common/supervisor/supervisor_test.go b/metropolis/node/common/supervisor/supervisor_test.go
new file mode 100644
index 0000000..9c7bdb7
--- /dev/null
+++ b/metropolis/node/common/supervisor/supervisor_test.go
@@ -0,0 +1,557 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+)
+
+// runnableBecomesHealthy returns a Runnable that signals healthy, then blocks until its context is canceled and
+// returns the context error. If healthy and/or done are non-nil, one token is sent on them (from a background
+// goroutine) right after becoming healthy and right after observing the cancellation, respectively.
+func runnableBecomesHealthy(healthy, done chan struct{}) Runnable {
+	// notify sends a single token on c from its own goroutine. A nil channel is skipped.
+	notify := func(c chan struct{}) {
+		if c == nil {
+			return
+		}
+		go func() {
+			c <- struct{}{}
+		}()
+	}
+
+	return func(ctx context.Context) error {
+		Signal(ctx, SignalHealthy)
+		notify(healthy)
+
+		<-ctx.Done()
+		notify(done)
+
+		return ctx.Err()
+	}
+}
+
+// runnableSpawnsMore returns a Runnable that, while levels > 0, first starts a group of two children ("a" and "b")
+// which are themselves runnableSpawnsMore with one level less. It then signals healthy, waits for its context to be
+// canceled and returns the context error. The healthy and done channels, when non-nil, each receive one token from
+// a background goroutine at the corresponding point.
+func runnableSpawnsMore(healthy, done chan struct{}, levels int) Runnable {
+	return func(ctx context.Context) error {
+		if levels > 0 {
+			children := map[string]Runnable{
+				"a": runnableSpawnsMore(nil, nil, levels-1),
+				"b": runnableSpawnsMore(nil, nil, levels-1),
+			}
+			if err := RunGroup(ctx, children); err != nil {
+				return err
+			}
+		}
+
+		Signal(ctx, SignalHealthy)
+		if healthy != nil {
+			go func() {
+				healthy <- struct{}{}
+			}()
+		}
+
+		<-ctx.Done()
+		if done != nil {
+			go func() {
+				done <- struct{}{}
+			}()
+		}
+		return ctx.Err()
+	}
+}
+
+// rc is a Remote Controlled runnable. It is a generic runnable used for testing the supervisor.
+type rc struct {
+ // req carries commands from the test to the running runnable.
+ req chan rcRunnableRequest
+}
+
+// rcRunnableRequest is a single command sent to an rc runnable.
+type rcRunnableRequest struct {
+ // cmd is the action the runnable should perform.
+ cmd rcRunnableCommand
+ // stateC is the channel on which the runnable answers an rcRunnableCommandState request.
+ stateC chan rcRunnableState
+}
+
+// rcRunnableCommand enumerates the actions an rc runnable can be asked to perform.
+type rcRunnableCommand int
+
+const (
+ // Signal healthy to the supervisor.
+ rcRunnableCommandBecomeHealthy rcRunnableCommand = iota
+ // Signal done to the supervisor.
+ rcRunnableCommandBecomeDone
+ // Return an error from the runnable.
+ rcRunnableCommandDie
+ // Panic inside the runnable.
+ rcRunnableCommandPanic
+ // Report the current rcRunnableState on the request's stateC.
+ rcRunnableCommandState
+)
+
+// rcRunnableState is the state an rc runnable tracks for itself, as reported to the test.
+type rcRunnableState int
+
+const (
+ // The runnable has (re)started and has not signaled anything yet.
+ rcRunnableStateNew rcRunnableState = iota
+ // The runnable has signaled healthy.
+ rcRunnableStateHealthy
+ // The runnable has signaled done.
+ rcRunnableStateDone
+)
+
+// becomeHealthy asks the runnable to signal healthy. Like all commands below, it blocks until a running instance
+// of the runnable receives the request.
+func (r *rc) becomeHealthy() {
+ r.req <- rcRunnableRequest{cmd: rcRunnableCommandBecomeHealthy}
+}
+
+// becomeDone asks the runnable to signal done.
+func (r *rc) becomeDone() {
+ r.req <- rcRunnableRequest{cmd: rcRunnableCommandBecomeDone}
+}
+
+// die asks the runnable to return an error.
+func (r *rc) die() {
+ r.req <- rcRunnableRequest{cmd: rcRunnableCommandDie}
+}
+
+// panic asks the runnable to panic.
+func (r *rc) panic() {
+ r.req <- rcRunnableRequest{cmd: rcRunnableCommandPanic}
+}
+
+// state queries the runnable for its current state and returns the answer.
+func (r *rc) state() rcRunnableState {
+ c := make(chan rcRunnableState)
+ r.req <- rcRunnableRequest{
+ cmd: rcRunnableCommandState,
+ stateC: c,
+ }
+ return <-c
+}
+
+// waitState blocks until the runnable reports state s, polling every 10ms.
+func (r *rc) waitState(s rcRunnableState) {
+	// This is poll based. Making it non-poll based would make the RC runnable logic a bit more complex for little gain.
+	for r.state() != s {
+		time.Sleep(10 * time.Millisecond)
+	}
+}
+
+// newRC creates a remote controlled runnable with an unbuffered request channel.
+func newRC() *rc {
+	reqC := make(chan rcRunnableRequest)
+	return &rc{req: reqC}
+}
+
+// Remote Controlled Runnable
+func (r *rc) runnable() Runnable {
+ return func(ctx context.Context) error {
+ state := rcRunnableStateNew
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case r := <-r.req:
+ switch r.cmd {
+ case rcRunnableCommandBecomeHealthy:
+ Signal(ctx, SignalHealthy)
+ state = rcRunnableStateHealthy
+ case rcRunnableCommandBecomeDone:
+ Signal(ctx, SignalDone)
+ state = rcRunnableStateDone
+ case rcRunnableCommandDie:
+ return fmt.Errorf("died on request")
+ case rcRunnableCommandPanic:
+ panic("at the disco")
+ case rcRunnableCommandState:
+ r.stateC <- state
+ }
+ }
+ }
+ }
+}
+
+// TestSimple starts a root runnable with a group of two children and checks that, once the supervisor has settled,
+// both children have reported becoming healthy.
+func TestSimple(t *testing.T) {
+	h1 := make(chan struct{})
+	d1 := make(chan struct{})
+	h2 := make(chan struct{})
+	d2 := make(chan struct{})
+
+	ctx, ctxC := context.WithCancel(context.Background())
+	defer ctxC()
+	s := New(ctx, func(ctx context.Context) error {
+		err := RunGroup(ctx, map[string]Runnable{
+			"one": runnableBecomesHealthy(h1, d1),
+			"two": runnableBecomesHealthy(h2, d2),
+		})
+		if err != nil {
+			return err
+		}
+		Signal(ctx, SignalHealthy)
+		Signal(ctx, SignalDone)
+		return nil
+	}, WithPropagatePanic)
+
+	// Expect both to start running.
+	s.waitSettleError(ctx, t)
+	select {
+	case <-h1:
+	default:
+		t.Fatalf("runnable 'one' didn't start")
+	}
+	select {
+	case <-h2:
+	default:
+		// This check is about 'two'; the message used to name 'one' by mistake.
+		t.Fatalf("runnable 'two' didn't start")
+	}
+}
+
+// TestSimpleFailure runs two siblings in one group and checks that when 'two' dies with an error, its healthy
+// sibling 'one' is canceled and then started again.
+func TestSimpleFailure(t *testing.T) {
+ h1 := make(chan struct{})
+ d1 := make(chan struct{})
+ two := newRC()
+
+ // The timeout bounds the whole test, including every waitSettleError call below.
+ ctx, ctxC := context.WithTimeout(context.Background(), 10*time.Second)
+ defer ctxC()
+ s := New(ctx, func(ctx context.Context) error {
+ err := RunGroup(ctx, map[string]Runnable{
+ "one": runnableBecomesHealthy(h1, d1),
+ "two": two.runnable(),
+ })
+ if err != nil {
+ return err
+ }
+ Signal(ctx, SignalHealthy)
+ Signal(ctx, SignalDone)
+ return nil
+ }, WithPropagatePanic)
+ s.waitSettleError(ctx, t)
+
+ two.becomeHealthy()
+ s.waitSettleError(ctx, t)
+ // Expect one to start running.
+ select {
+ case <-h1:
+ default:
+ t.Fatalf("runnable 'one' didn't start")
+ }
+
+ // Kill off two, one should restart.
+ two.die()
+ s.waitSettleError(ctx, t)
+ select {
+ case <-d1:
+ default:
+ t.Fatalf("runnable 'one' didn't acknowledge cancel")
+ }
+
+ // And one should start running again.
+ s.waitSettleError(ctx, t)
+ select {
+ case <-h1:
+ default:
+ t.Fatalf("runnable 'one' didn't restart")
+ }
+}
+
+// TestDeepFailure is like TestSimpleFailure, but 'one' is a runnable that spawns five further levels of children.
+// It checks that 'one' is canceled and started again after its sibling 'two' dies.
+func TestDeepFailure(t *testing.T) {
+ h1 := make(chan struct{})
+ d1 := make(chan struct{})
+ two := newRC()
+
+ // The timeout bounds the whole test, including every waitSettleError call below.
+ ctx, ctxC := context.WithTimeout(context.Background(), 10*time.Second)
+ defer ctxC()
+ s := New(ctx, func(ctx context.Context) error {
+ err := RunGroup(ctx, map[string]Runnable{
+ "one": runnableSpawnsMore(h1, d1, 5),
+ "two": two.runnable(),
+ })
+ if err != nil {
+ return err
+ }
+ Signal(ctx, SignalHealthy)
+ Signal(ctx, SignalDone)
+ return nil
+ }, WithPropagatePanic)
+
+ two.becomeHealthy()
+ s.waitSettleError(ctx, t)
+ // Expect one to start running.
+ select {
+ case <-h1:
+ default:
+ t.Fatalf("runnable 'one' didn't start")
+ }
+
+ // Kill off two, one should restart.
+ two.die()
+ s.waitSettleError(ctx, t)
+ select {
+ case <-d1:
+ default:
+ t.Fatalf("runnable 'one' didn't acknowledge cancel")
+ }
+
+ // And one should start running again.
+ s.waitSettleError(ctx, t)
+ select {
+ case <-h1:
+ default:
+ t.Fatalf("runnable 'one' didn't restart")
+ }
+}
+
+// TestPanic checks that a panic in runnable 'two' is handled like any other death: its sibling 'one' is canceled
+// and then started again. The supervisor is created without WithPropagatePanic, unlike in the other tests.
+func TestPanic(t *testing.T) {
+ h1 := make(chan struct{})
+ d1 := make(chan struct{})
+ two := newRC()
+
+ ctx, ctxC := context.WithCancel(context.Background())
+ defer ctxC()
+ s := New(ctx, func(ctx context.Context) error {
+ err := RunGroup(ctx, map[string]Runnable{
+ "one": runnableBecomesHealthy(h1, d1),
+ "two": two.runnable(),
+ })
+ if err != nil {
+ return err
+ }
+ Signal(ctx, SignalHealthy)
+ Signal(ctx, SignalDone)
+ return nil
+ })
+
+ two.becomeHealthy()
+ s.waitSettleError(ctx, t)
+ // Expect one to start running.
+ select {
+ case <-h1:
+ default:
+ t.Fatalf("runnable 'one' didn't start")
+ }
+
+ // Kill off two (by making it panic), one should restart.
+ two.panic()
+ s.waitSettleError(ctx, t)
+ select {
+ case <-d1:
+ default:
+ t.Fatalf("runnable 'one' didn't acknowledge cancel")
+ }
+
+ // And one should start running again.
+ s.waitSettleError(ctx, t)
+ select {
+ case <-h1:
+ default:
+ t.Fatalf("runnable 'one' didn't restart")
+ }
+}
+
+// TestMultipleLevelFailure starts a supervisor whose root runs two runnables that each spawn four further levels
+// of children.
+// NOTE(review): despite its name, this test makes no assertions and does not wait for the tree to settle. It only
+// checks that creating the supervisor does not fail outright.
+func TestMultipleLevelFailure(t *testing.T) {
+ ctx, ctxC := context.WithCancel(context.Background())
+ defer ctxC()
+ New(ctx, func(ctx context.Context) error {
+ err := RunGroup(ctx, map[string]Runnable{
+ "one": runnableSpawnsMore(nil, nil, 4),
+ "two": runnableSpawnsMore(nil, nil, 4),
+ })
+ if err != nil {
+ return err
+ }
+ Signal(ctx, SignalHealthy)
+ Signal(ctx, SignalDone)
+ return nil
+ }, WithPropagatePanic)
+}
+
+// TestBackoff checks the restart backoff of a repeatedly dying runnable: after several deaths in a row a restart
+// must take at least a second, and after the runnable became healthy again (which resets the backoff) a restart
+// must take between 100ms and one second.
+func TestBackoff(t *testing.T) {
+ one := newRC()
+
+ ctx, ctxC := context.WithTimeout(context.Background(), 20*time.Second)
+ defer ctxC()
+
+ s := New(ctx, func(ctx context.Context) error {
+ if err := Run(ctx, "one", one.runnable()); err != nil {
+ return err
+ }
+ Signal(ctx, SignalHealthy)
+ Signal(ctx, SignalDone)
+ return nil
+ }, WithPropagatePanic)
+
+ one.becomeHealthy()
+ // Die a bunch of times in a row, this brings up the next exponential backoff to over a second.
+ for i := 0; i < 4; i += 1 {
+ one.die()
+ // Wait for the restarted instance; it never signals healthy here, so the backoff keeps growing.
+ one.waitState(rcRunnableStateNew)
+ }
+ // Measure how long it takes for the runnable to respawn after a number of failures
+ start := time.Now()
+ one.die()
+ // becomeHealthy blocks until the restarted instance is running and receives the command.
+ one.becomeHealthy()
+ one.waitState(rcRunnableStateHealthy)
+ taken := time.Since(start)
+ if taken < 1*time.Second {
+ t.Errorf("Runnable took %v to restart, wanted at least a second from backoff", taken)
+ }
+
+ s.waitSettleError(ctx, t)
+ // Now that we've become healthy, die again. Becoming healthy resets the backoff.
+ start = time.Now()
+ one.die()
+ one.becomeHealthy()
+ one.waitState(rcRunnableStateHealthy)
+ taken = time.Since(start)
+ if taken > 1*time.Second || taken < 100*time.Millisecond {
+ t.Errorf("Runnable took %v to restart, wanted at least 100ms from backoff and at most 1s from backoff reset", taken)
+ }
+}
+
+// TestResilience throws some curveballs at the supervisor - either programming errors or high load. It then ensures
+// that another runnable is running, and that it restarts on its sibling failure.
+func TestResilience(t *testing.T) {
+	// request/response channel for testing liveness of the 'one' runnable
+	req := make(chan chan struct{})
+
+	// A runnable that responds on the 'req' channel.
+	one := func(ctx context.Context) error {
+		Signal(ctx, SignalHealthy)
+		for {
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case r := <-req:
+				r <- struct{}{}
+			}
+		}
+	}
+	oneSibling := newRC()
+
+	// oneTest pings 'one' and fails the test if no response arrives within a second of the call.
+	oneTest := func() {
+		// A one-shot timer is enough for a single timeout. Deferring Stop also releases it when Fatalf aborts
+		// the test.
+		timeout := time.NewTimer(1000 * time.Millisecond)
+		defer timeout.Stop()
+		ping := make(chan struct{})
+		req <- ping
+		select {
+		case <-ping:
+		case <-timeout.C:
+			t.Fatalf("one ping response timeout")
+		}
+	}
+
+	// A nasty runnable that calls Signal with the wrong context (this is a programming error)
+	two := func(ctx context.Context) error {
+		Signal(context.TODO(), SignalHealthy)
+		return nil
+	}
+
+	// A nasty runnable that calls Signal wrong (this is a programming error).
+	three := func(ctx context.Context) error {
+		Signal(ctx, SignalDone)
+		return nil
+	}
+
+	// A nasty runnable that runs in a busy loop (this is a programming error).
+	four := func(ctx context.Context) error {
+		for {
+			time.Sleep(0)
+		}
+	}
+
+	// A nasty runnable that keeps creating more runnables.
+	five := func(ctx context.Context) error {
+		i := 1
+		for {
+			err := Run(ctx, fmt.Sprintf("r%d", i), runnableSpawnsMore(nil, nil, 2))
+			if err != nil {
+				return err
+			}
+
+			time.Sleep(100 * time.Millisecond)
+			i++
+		}
+	}
+
+	ctx, ctxC := context.WithCancel(context.Background())
+	defer ctxC()
+	New(ctx, func(ctx context.Context) error {
+		// Do not silently ignore a failure to start the group under test: without 'one' and 'oneSibling' the
+		// rest of the test could only hang.
+		err := RunGroup(ctx, map[string]Runnable{
+			"one":        one,
+			"oneSibling": oneSibling.runnable(),
+		})
+		if err != nil {
+			return err
+		}
+		rs := map[string]Runnable{
+			"two": two, "three": three, "four": four, "five": five,
+		}
+		for k, v := range rs {
+			if err := Run(ctx, k, v); err != nil {
+				return err
+			}
+		}
+		Signal(ctx, SignalHealthy)
+		Signal(ctx, SignalDone)
+		return nil
+	})
+
+	// Five rounds of letting one run, then restarting it.
+	for i := 0; i < 5; i++ {
+		oneSibling.becomeHealthy()
+		oneSibling.waitState(rcRunnableStateHealthy)
+
+		// 'one' should work for at least a second.
+		deadline := time.Now().Add(1 * time.Second)
+		for !time.Now().After(deadline) {
+			oneTest()
+		}
+
+		// Killing 'oneSibling' should restart one.
+		oneSibling.panic()
+	}
+	// Make sure 'one' is still okay.
+	oneTest()
+}
+
+// ExampleNew shows how to start a supervision tree: a root runnable that starts one child, signals healthy and
+// then keeps working until its context is canceled.
+// There is no "Output:" comment, so `go test` compiles this example but does not run it.
+func ExampleNew() {
+ // Minimal runnable that is immediately done.
+ childC := make(chan struct{})
+ child := func(ctx context.Context) error {
+ Signal(ctx, SignalHealthy)
+ close(childC)
+ Signal(ctx, SignalDone)
+ return nil
+ }
+
+ // Start a supervision tree with a root runnable.
+ ctx, ctxC := context.WithCancel(context.Background())
+ defer ctxC()
+ New(ctx, func(ctx context.Context) error {
+ err := Run(ctx, "child", child)
+ if err != nil {
+ return fmt.Errorf("could not run 'child': %w", err)
+ }
+ Signal(ctx, SignalHealthy)
+
+ t := time.NewTicker(time.Second)
+ defer t.Stop()
+
+ // Do something in the background, and exit on context cancel.
+ for {
+ select {
+ case <-t.C:
+ fmt.Printf("tick!")
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ }
+ })
+
+ // root.child will close this channel.
+ <-childC
+}
diff --git a/metropolis/node/common/supervisor/supervisor_testhelpers.go b/metropolis/node/common/supervisor/supervisor_testhelpers.go
new file mode 100644
index 0000000..771e02f
--- /dev/null
+++ b/metropolis/node/common/supervisor/supervisor_testhelpers.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+import (
+ "context"
+ "testing"
+)
+
+// waitSettle waits until the supervisor reaches a 'settled' state - ie., one
+// where no actions have been performed for a number of GC cycles.
+// It returns nil once settled, or the context error if ctx is done first.
+// This is used in tests only.
+func (s *supervisor) waitSettle(ctx context.Context) error {
+	waiter := make(chan struct{})
+	req := &processorRequest{
+		waitSettled: &processorRequestWaitSettled{
+			waiter: waiter,
+		},
+	}
+
+	// Registering the waiter can block (eg. when the processor has already
+	// exited and no longer receives requests), so honor ctx here as well.
+	select {
+	case s.pReq <- req:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+
+	// Wait for the processor to close the waiter channel.
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-waiter:
+		return nil
+	}
+}
+
+// waitSettleError wraps waitSettle to fail a test if an error occurs, eg. the
+// context is canceled.
+func (s *supervisor) waitSettleError(ctx context.Context, t *testing.T) {
+	// Mark this as a helper so that failures are reported at the call site in
+	// the test instead of here.
+	t.Helper()
+	if err := s.waitSettle(ctx); err != nil {
+		t.Fatalf("waitSettle: %v", err)
+	}
+}
diff --git a/metropolis/node/common/sysfs/BUILD.bazel b/metropolis/node/common/sysfs/BUILD.bazel
new file mode 100644
index 0000000..a4c7f18
--- /dev/null
+++ b/metropolis/node/common/sysfs/BUILD.bazel
@@ -0,0 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Go library for the sysfs package (built from uevents.go). Visible to every package in the workspace.
+go_library(
+ name = "go_default_library",
+ srcs = ["uevents.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/common/sysfs",
+ visibility = ["//visibility:public"],
+)
diff --git a/metropolis/node/common/sysfs/uevents.go b/metropolis/node/common/sysfs/uevents.go
new file mode 100644
index 0000000..fed4319
--- /dev/null
+++ b/metropolis/node/common/sysfs/uevents.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sysfs
+
+import (
+ "bufio"
+ "io"
+ "os"
+ "strings"
+)
+
+// ReadUevents reads the file at filename, which is expected to consist of KEY=value lines (as found in sysfs
+// uevent files), and returns its contents as a map from key to value.
+//
+// Each line is split at its first '='. Values have surrounding whitespace trimmed. Lines without a '=' (including
+// blank lines) are skipped. A last line that is not terminated by a newline is still parsed. An error is returned
+// if the file cannot be opened or read.
+func ReadUevents(filename string) (map[string]string, error) {
+	f, err := os.Open(filename)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+
+	ueventMap := make(map[string]string)
+	reader := bufio.NewReader(f)
+	for {
+		// ReadString returns whatever was read before an error, so on io.EOF line still holds a trailing,
+		// unterminated line (or is empty).
+		line, err := reader.ReadString('\n')
+		if err != nil && err != io.EOF {
+			return nil, err
+		}
+		if idx := strings.IndexByte(line, '='); idx >= 0 {
+			ueventMap[line[:idx]] = strings.TrimSpace(line[idx+1:])
+		}
+		if err == io.EOF {
+			break
+		}
+	}
+	return ueventMap, nil
+}
diff --git a/metropolis/node/core/BUILD.bazel b/metropolis/node/core/BUILD.bazel
new file mode 100644
index 0000000..2398205
--- /dev/null
+++ b/metropolis/node/core/BUILD.bazel
@@ -0,0 +1,43 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
+
+# Go library backing the //metropolis/node/core binary. It is private to this package and embedded into the
+# go_binary below.
+go_library(
+ name = "go_default_library",
+ # keep
+ # Exactly one of the delve_*.go files is compiled in, depending on whether this is a debug build.
+ srcs = [
+ "debug_service.go",
+ "main.go",
+ "switchroot.go",
+ ] + select({
+ "//metropolis/node:debug_build": ["delve_enabled.go"],
+ "//conditions:default": ["delve_disabled.go"],
+ }),
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core",
+ visibility = ["//visibility:private"],
+ deps = [
+ "//metropolis/node:go_default_library",
+ "//metropolis/node/common/supervisor:go_default_library",
+ "//metropolis/node/core/cluster:go_default_library",
+ "//metropolis/node/core/consensus/ca:go_default_library",
+ "//metropolis/node/core/localstorage:go_default_library",
+ "//metropolis/node/core/localstorage/declarative:go_default_library",
+ "//metropolis/node/core/logtree:go_default_library",
+ "//metropolis/node/core/network:go_default_library",
+ "//metropolis/node/core/network/dns:go_default_library",
+ "//metropolis/node/core/tpm:go_default_library",
+ "//metropolis/node/kubernetes:go_default_library",
+ "//metropolis/node/kubernetes/containerd:go_default_library",
+ "//metropolis/node/kubernetes/pki:go_default_library",
+ "//metropolis/proto/api:go_default_library",
+ "@org_golang_google_grpc//:go_default_library",
+ "@org_golang_google_grpc//codes:go_default_library",
+ "@org_golang_google_grpc//status:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+# The core binary, built from the library above in pure Go mode. It is publicly visible; the node initramfs
+# installs it as /init.
+go_binary(
+ name = "core",
+ embed = [":go_default_library"],
+ pure = "on", # keep
+ visibility = ["//visibility:public"],
+)
diff --git a/metropolis/node/core/cluster/BUILD.bazel b/metropolis/node/core/cluster/BUILD.bazel
new file mode 100644
index 0000000..70daba2
--- /dev/null
+++ b/metropolis/node/core/cluster/BUILD.bazel
@@ -0,0 +1,25 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Cluster enrolment logic (Manager and Node). Only visible to packages under //metropolis/node/core.
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "manager.go",
+ "node.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/cluster",
+ visibility = ["//metropolis/node/core:__subpackages__"],
+ deps = [
+ "//metropolis/node:go_default_library",
+ "//metropolis/node/common/supervisor:go_default_library",
+ "//metropolis/node/core/consensus:go_default_library",
+ "//metropolis/node/core/localstorage:go_default_library",
+ "//metropolis/node/core/localstorage/declarative:go_default_library",
+ "//metropolis/node/core/network:go_default_library",
+ "//metropolis/proto/api:go_default_library",
+ "//metropolis/proto/internal:go_default_library",
+ "@com_github_cenkalti_backoff_v4//:go_default_library",
+ "@com_github_golang_protobuf//proto:go_default_library",
+ "@io_etcd_go_etcd//clientv3:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/cluster/manager.go b/metropolis/node/core/cluster/manager.go
new file mode 100644
index 0000000..6bb87f4
--- /dev/null
+++ b/metropolis/node/core/cluster/manager.go
@@ -0,0 +1,555 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cluster
+
+import (
+ "context"
+ "crypto/x509"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ "github.com/golang/protobuf/proto"
+ "go.etcd.io/etcd/clientv3"
+
+ common "git.monogon.dev/source/nexantic.git/metropolis/node"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/consensus"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network"
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+// Manager is a finite state machine that joins this node (ie. a Smalltown instance running on a virtual/physical
+// machine) into a Smalltown cluster (ie. a group of nodes that act as a single control plane for Smalltown services).
+// It does that by bringing up all required operating-system level components, including mounting the local
+// filesystem, bringing up a consensus (etcd) server/client, ...
+//
+// The Manager runs as a single-shot Runnable. It will attempt to progress its state from the initial state (New) to
+// either Running (meaning that the node is now part of a cluster), or Failed (meaning that the node couldn't become
+// part of a cluster). It is not restartable, as it mutates quite a bit of implicit operating-system level state (like
+// filesystem mounts). As such, it's difficult to recover reliably from failures, and since these failures indicate
+// some high-level issues with the cluster configuration/state, a failure requires a full kernel reboot to retry (or
+// to fix/reconfigure the node).
+//
+// Currently, the Manager only supports one flow for bringing up a Node: by creating a new cluster. As such, it's
+// missing the following flows:
+// - joining a new node into an already running cluster
+// - restarting a node into an already existing cluster
+// - restarting a node into an already running cluster (ie. full reboot of whole cluster)
+//
+// (A temporary Golden Ticket based join flow also exists, see stateCharlie.)
+type Manager struct {
+ // storageRoot is the node's local storage tree; the Manager reads the enrolment config from its ESP and mounts its
+ // data partition.
+ storageRoot *localstorage.Root
+ // networkService is used to wait for and retrieve the IP address of the node.
+ networkService *network.Service
+
+ // stateLock locks all state* variables.
+ stateLock sync.RWMutex
+ // state is the FSM state of the Manager.
+ state State
+ // stateRunningNode is the Node that this Manager got from joining a cluster. It's only valid if the Manager is
+ // Running.
+ stateRunningNode *Node
+ // stateWaiters is a list of channels that wish to be notified (by sending true or false) for when the Manager
+ // reaches a final state (Running or Failed respectively).
+ stateWaiters []chan bool
+
+ // goldenTicket is the Golden Ticket present in the enrolment config, if any.
+ goldenTicket *apb.GoldenTicket
+
+ // consensus is the spawned etcd/consensus service, if the Manager brought up a Node that should run one.
+ consensus *consensus.Service
+}
+
+// NewManager returns a cluster Manager in its initial (New) state. The given localstorage Root must be placed, but
+// not yet started (it will be started as the Manager makes progress). The given network Service must already be
+// running.
+func NewManager(storageRoot *localstorage.Root, networkService *network.Service) *Manager {
+	m := new(Manager)
+	m.storageRoot = storageRoot
+	m.networkService = networkService
+	return m
+}
+
+// State is the state of the Manager finite state machine. The zero value is StateNew.
+type State int
+
+const (
+ // StateNew is the initial state of the Manager. It decides how to go about joining or creating a cluster.
+ StateNew State = iota
+ // StateCreatingCluster is when the Manager attempts to create a new cluster - this happens when a node is started
+ // with no EnrolmentConfig.
+ StateCreatingCluster
+ // StateCharlie is when the Manager uses the Golden Ticket debug/stopgap system to join an already
+ // existing cluster. This mechanism will be removed before the first Smalltown release.
+ StateCharlie
+ // StateRunning is when the Manager successfully got the node to be part of a cluster. stateRunningNode is valid.
+ StateRunning
+ // StateFailed is when the Manager failed to get the node to be part of a cluster.
+ StateFailed
+)
+
+// String returns the human-readable name of the state (as used in log messages), or "UNKNOWN" for any value that is
+// not one of the State constants.
+func (s State) String() string {
+	names := map[State]string{
+		StateNew:             "New",
+		StateCreatingCluster: "CreatingCluster",
+		StateCharlie:         "Charlie",
+		StateRunning:         "Running",
+		StateFailed:          "Failed",
+	}
+	if name, ok := names[s]; ok {
+		return name
+	}
+	return "UNKNOWN"
+}
+
+// allowedTransitions describes all allowed state transitions (map[From][]To). States without an entry (Running and
+// Failed) have no allowed outgoing transitions.
+var allowedTransitions = map[State][]State{
+ StateNew: {StateCreatingCluster, StateCharlie},
+ StateCreatingCluster: {StateRunning, StateFailed},
+ StateCharlie: {StateRunning, StateFailed},
+}
+
+// allowed reports whether the FSM may move from one state to another, ie. whether `to` is listed under `from` in
+// allowedTransitions.
+func (m *Manager) allowed(from, to State) bool {
+	targets := allowedTransitions[from]
+	for i := range targets {
+		if targets[i] == to {
+			return true
+		}
+	}
+	return false
+}
+
+// next attempts to move the Manager finite state machine from its current state to `n`. If allowedTransitions does
+// not permit that move, the Manager is moved to Failed instead. Either outcome is logged. This does not wake any
+// WaitFinished waiters.
+func (m *Manager) next(ctx context.Context, n State) {
+	m.stateLock.Lock()
+	defer m.stateLock.Unlock()
+
+	if m.allowed(m.state, n) {
+		supervisor.Logger(ctx).Infof("Enrolment state change; from: %s, to: %s", m.state.String(), n.String())
+		m.state = n
+		return
+	}
+
+	supervisor.Logger(ctx).Errorf("Attempted invalid enrolment state transition, failing enrolment; from: %s, to: %s",
+		m.state.String(), n.String())
+	m.state = StateFailed
+}
+
+// State returns the current FSM state of the Manager, read under the state lock. It's safe to call this from any
+// goroutine.
+func (m *Manager) State() State {
+	m.stateLock.RLock()
+	current := m.state
+	m.stateLock.RUnlock()
+	return current
+}
+
+// WaitFinished blocks until the Manager FSM reaches Running or Failed, and returns true if the FSM is Running. If the
+// FSM is already in one of these states, it returns immediately. It's safe to call this from any goroutine.
+func (m *Manager) WaitFinished() (success bool) {
+	m.stateLock.Lock()
+	if current := m.state; current == StateFailed || current == StateRunning {
+		m.stateLock.Unlock()
+		return current == StateRunning
+	}
+
+	// Not in a final state yet: register a channel that wakeWaiters will send the outcome on, and only block on it
+	// after the lock has been released.
+	notify := make(chan bool)
+	m.stateWaiters = append(m.stateWaiters, notify)
+	m.stateLock.Unlock()
+	return <-notify
+}
+
+// wakeWaiters wakes any WaitFinished waiters and lets them know whether the Manager is Running (true) or not (false).
+// The stateLock must already have been taken, and the state must have been set in the same critical section
+// (otherwise this can cause a race condition).
+func (m *Manager) wakeWaiters() {
+	running := m.state == StateRunning
+	pending := m.stateWaiters
+	m.stateWaiters = nil
+
+	for _, ch := range pending {
+		// The waiter channels are unbuffered, so every send happens in its own goroutine in order not to block the
+		// caller (which is holding stateLock).
+		go func(ch chan bool) {
+			ch <- running
+		}(ch)
+	}
+}
+
+// Run is the runnable of the Manager, to be started using the Supervisor. It is one-shot, and should not be restarted.
+//
+// It repeatedly dispatches to the handler of the current state until the FSM reaches a state without a handler
+// (Running or Failed), or until a handler returns an error. In every case the Manager ends up in a final state:
+// anything other than Running is recorded as Failed, so that State() and WaitFinished() report the failure even to
+// callers that arrive after Run has finished. Run always returns nil; the outcome is reported through the state.
+func (m *Manager) Run(ctx context.Context) error {
+	if state := m.State(); state != StateNew {
+		supervisor.Logger(ctx).Errorf("Manager started with non-New state %s, failing", state.String())
+		m.stateLock.Lock()
+		m.state = StateFailed
+		m.wakeWaiters()
+		m.stateLock.Unlock()
+		return nil
+	}
+
+	var err error
+	// NOTE: the backoff is currently only ever reset, never waited on.
+	bo := backoff.NewExponentialBackOff()
+	for {
+		done := false
+		state := m.State()
+		switch state {
+		case StateNew:
+			err = m.stateNew(ctx)
+		case StateCreatingCluster:
+			err = m.stateCreatingCluster(ctx)
+		case StateCharlie:
+			err = m.stateCharlie(ctx)
+		default:
+			// No handler for this state (Running/Failed): the FSM is finished.
+			done = true
+		}
+
+		if err != nil || done {
+			break
+		}
+
+		// A handler that returned without error must have changed the state, otherwise we would spin forever.
+		if state == m.State() && !m.allowed(state, m.State()) {
+			supervisor.Logger(ctx).Errorf("Enrolment got stuck at %s, failing", state.String())
+			m.stateLock.Lock()
+			m.state = StateFailed
+			m.stateLock.Unlock()
+		} else {
+			bo.Reset()
+		}
+	}
+
+	m.stateLock.Lock()
+	if m.state != StateRunning {
+		supervisor.Logger(ctx).Errorf("Enrolment failed at %s: %v", m.state.String(), err)
+		// A handler error leaves the FSM in the state that failed. Record the failure in the same critical section
+		// as the wakeup, otherwise late WaitFinished callers would register a waiter that is never woken.
+		m.state = StateFailed
+	} else {
+		supervisor.Logger(ctx).Info("Enrolment successful!")
+	}
+	m.wakeWaiters()
+	m.stateLock.Unlock()
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	supervisor.Signal(ctx, supervisor.SignalDone)
+	return nil
+}
+
+// stateNew is the handler for the New state. It decides how this node joins a cluster: without any enrolment config
+// it moves on to creating a new cluster, with an enrolment config carrying a Golden Ticket it moves on to the Golden
+// Ticket join flow. Any other case is an error.
+func (m *Manager) stateNew(ctx context.Context) error {
+	supervisor.Logger(ctx).Info("Starting enrolment process...")
+
+	// Look for an EnrolmentConfig on the ESP first, and only if there is none there, in the qemu firmware variables.
+	var configRaw []byte
+	configRaw, err := m.storageRoot.ESP.Enrolment.Read()
+	if err != nil {
+		if !os.IsNotExist(err) {
+			return fmt.Errorf("could not read local enrolment file: %w", err)
+		}
+		configRaw, err = ioutil.ReadFile("/sys/firmware/qemu_fw_cfg/by_name/com.nexantic.smalltown/enrolment.pb/raw")
+		if err != nil && !os.IsNotExist(err) {
+			return fmt.Errorf("could not read firmware enrolment file: %w", err)
+		}
+	}
+
+	// No enrolment file anywhere: this node creates a new cluster.
+	if configRaw == nil {
+		m.next(ctx, StateCreatingCluster)
+		return nil
+	}
+
+	// An enrolment file exists, parse it.
+	var enrolmentConfig apb.EnrolmentConfig
+	if err := proto.Unmarshal(configRaw, &enrolmentConfig); err != nil {
+		return fmt.Errorf("could not unmarshal local enrolment file: %w", err)
+	}
+
+	// Joining is only implemented via a Golden Ticket for now.
+	ticket := enrolmentConfig.GoldenTicket
+	if ticket == nil {
+		return fmt.Errorf("joining a cluster without a golden ticket not yet implemented")
+	}
+	m.goldenTicket = ticket
+
+	m.next(ctx, StateCharlie)
+	return nil
+}
+
+// stateCreatingCluster is called when the Manager has decided to create a new cluster.
+//
+// The process to create a new cluster is as follows:
+// - wait for IP address
+// - initialize new data partition, by generating local and cluster unlock keys (the local unlock key is saved to
+// the ESP, while the cluster unlock key is returned)
+// - create a new node certificate and Node (with new given cluster unlock key)
+// - start up a new etcd cluster, with this node being the only member
+// - save the new Node to the new etcd cluster (thereby saving the node's cluster unlock key to etcd)
+//
+// On success, the new Node is recorded as stateRunningNode and the FSM is moved to Running. On any failure an error
+// is returned and the state is left untouched.
+func (m *Manager) stateCreatingCluster(ctx context.Context) error {
+ logger := supervisor.Logger(ctx)
+ logger.Info("Creating new cluster: waiting for IP address...")
+ ip, err := m.networkService.GetIP(ctx, true)
+ if err != nil {
+ return fmt.Errorf("when getting IP address: %w", err)
+ }
+ logger.Infof("Creating new cluster: got IP address %s", ip.String())
+
+ logger.Info("Creating new cluster: initializing storage...")
+ cuk, err := m.storageRoot.Data.MountNew(&m.storageRoot.ESP.LocalUnlock)
+ if err != nil {
+ return fmt.Errorf("when making new data partition: %w", err)
+ }
+ logger.Info("Creating new cluster: storage initialized")
+
+ // Create certificate for node.
+ cert, err := m.storageRoot.Data.Node.EnsureSelfSigned(localstorage.CertificateForNode)
+ if err != nil {
+ return fmt.Errorf("failed to create new node certificate: %w", err)
+ }
+
+ node := NewNode(cuk, *ip, *cert.Leaf)
+
+ // Start a fresh consensus service under the supervisor, with the node's IP address as its only initial member.
+ m.consensus = consensus.New(consensus.Config{
+ Data: &m.storageRoot.Data.Etcd,
+ Ephemeral: &m.storageRoot.Ephemeral.Consensus,
+ NewCluster: true,
+ Name: node.ID(),
+ InitialCluster: ip.String(),
+ ExternalHost: ip.String(),
+ ListenHost: ip.String(),
+ })
+ if err := supervisor.Run(ctx, "consensus", m.consensus.Run); err != nil {
+ return fmt.Errorf("when starting consensus: %w", err)
+ }
+
+ // TODO(q3k): make timeout configurable?
+ // The timeout only bounds the wait for consensus readiness; all later calls use the parent context.
+ ctxT, ctxC := context.WithTimeout(ctx, 5*time.Second)
+ defer ctxC()
+
+ supervisor.Logger(ctx).Info("Creating new cluster: waiting for consensus...")
+ if err := m.consensus.WaitReady(ctxT); err != nil {
+ return fmt.Errorf("consensus service failed to become ready: %w", err)
+ }
+
+ // Configure node to be a consensus member and kubernetes worker. In the future, different nodes will have
+ // different roles, but for now they're all symmetrical.
+ _, consensusName, err := m.consensus.MemberInfo(ctx)
+ if err != nil {
+ return fmt.Errorf("could not get consensus MemberInfo: %w", err)
+ }
+ if err := node.MakeConsensusMember(consensusName); err != nil {
+ return fmt.Errorf("could not make new node into consensus member: %w", err)
+ }
+ if err := node.MakeKubernetesWorker(node.ID()); err != nil {
+ return fmt.Errorf("could not make new node into kubernetes worker: %w", err)
+ }
+
+ // Save node into etcd.
+ supervisor.Logger(ctx).Info("Creating new cluster: storing first node...")
+ if err := node.Store(ctx, m.consensus.KV("cluster", "enrolment")); err != nil {
+ return fmt.Errorf("could not save new node: %w", err)
+ }
+
+ // Publish the node before moving to Running, so that it is valid as soon as the Running state is observable.
+ m.stateLock.Lock()
+ m.stateRunningNode = node
+ m.stateLock.Unlock()
+
+ m.next(ctx, StateRunning)
+ return nil
+}
+
+// stateCharlie is used to join an existing cluster via the GoldenTicket mechanism. This mechanism is temporarily
+// implemented in Smalltown in order to allow for testing multi-node clusters without a TPM attestation flow implemented.
+// The Golden Ticket contains a pregenerated node certificate, etcd certificate, and other data that any node can
+// use to join the cluster.
+// Since this flow is temporary, it has a slight impedance mismatch with methods exposed by localstorage, node, etc.,
+// and the resulting sequencing is a bit odd:
+//  - the {node,etcd} certificates/keys are loaded (this already dictates the new node name, as the node name is based
+//    off of the node public key)
+//  - local storage is initialized, a local/cluster unlock keypair is generated
+//  - etcd keys are manually saved to localstorage (vs. being generated locally by CA)
+//  - an etcd/consensus member is started, knowing that the remote member was already generated when the golden ticket
+//    was generated (vs. being created now by an RPC call, via a promote-node-to-etcd-member flow)
+//  - the node is then promoted to a consensus member and kubernetes worker, its clusterunlock key is set, and then it
+//    is saved to etcd.
+// As such, in this flow, we first create an etcd member (on goldenticket generation), and then only create a new Smalltown
+// node (when the goldenticket is used).
+//
+// On success, the new Node is recorded as stateRunningNode and the FSM is moved to Running. On any failure an error
+// is returned and the state is left untouched.
+func (m *Manager) stateCharlie(ctx context.Context) error {
+	t := m.goldenTicket
+	nodeCert, err := x509.ParseCertificate(t.NodeCert)
+	if err != nil {
+		return fmt.Errorf("parsing node certificate from ticket: %w", err)
+	}
+
+	supervisor.Logger(ctx).Info("Joining cluster: waiting for IP address...")
+	ip, err := m.networkService.GetIP(ctx, true)
+	if err != nil {
+		return fmt.Errorf("when getting IP address: %w", err)
+	}
+	// This message has a format directive, so it must go through Infof (Info would not interpolate the address).
+	supervisor.Logger(ctx).Infof("Joining cluster: got IP address %s", ip.String())
+
+	supervisor.Logger(ctx).Info("Joining cluster: initializing storage...")
+	cuk, err := m.storageRoot.Data.MountNew(&m.storageRoot.ESP.LocalUnlock)
+	if err != nil {
+		return fmt.Errorf("when making new data partition: %w", err)
+	}
+	supervisor.Logger(ctx).Info("Joining cluster: storage initialized")
+	node := NewNode(cuk, *ip, *nodeCert)
+
+	// Save etcd PKI to disk.
+	for _, f := range []struct {
+		target    declarative.FilePlacement
+		data      []byte
+		blockType string
+	}{
+		{m.storageRoot.Data.Etcd.PeerPKI.Key, t.EtcdClientKey, "PRIVATE KEY"},
+		{m.storageRoot.Data.Etcd.PeerPKI.Certificate, t.EtcdClientCert, "CERTIFICATE"},
+		{m.storageRoot.Data.Etcd.PeerPKI.CACertificate, t.EtcdCaCert, "CERTIFICATE"},
+	} {
+		if err := f.target.Write(pem.EncodeToMemory(&pem.Block{Type: f.blockType, Bytes: f.data}), 0600); err != nil {
+			return fmt.Errorf("when writing etcd PKI data: %w", err)
+		}
+	}
+	if err := m.storageRoot.Data.Etcd.PeerCRL.Write(t.EtcdCrl, 0600); err != nil {
+		return fmt.Errorf("when writing etcd CRL: %w", err)
+	}
+
+	// Build the etcd initial cluster string out of all peers from the ticket, followed by this node's own entry.
+	https := func(p *apb.GoldenTicket_EtcdPeer) string {
+		return fmt.Sprintf("%s=https://%s:%d", p.Name, p.Address, common.ConsensusPort)
+	}
+	var initialCluster []string
+	for _, p := range t.Peers {
+		initialCluster = append(initialCluster, https(p))
+	}
+	initialCluster = append(initialCluster, https(t.This))
+
+	supervisor.Logger(ctx).Infof("Joining cluster: starting etcd join, name: %s, initial_cluster: %s", node.ID(), strings.Join(initialCluster, ","))
+	m.consensus = consensus.New(consensus.Config{
+		Data:           &m.storageRoot.Data.Etcd,
+		Ephemeral:      &m.storageRoot.Ephemeral.Consensus,
+		Name:           node.ID(),
+		InitialCluster: strings.Join(initialCluster, ","),
+		ExternalHost:   ip.String(),
+		ListenHost:     ip.String(),
+	})
+
+	if err := supervisor.Run(ctx, "consensus", m.consensus.Run); err != nil {
+		return fmt.Errorf("when starting consensus: %w", err)
+	}
+
+	// TODO(q3k): make timeout configurable?
+	ctxT, ctxC := context.WithTimeout(ctx, 5*time.Second)
+	defer ctxC()
+
+	supervisor.Logger(ctx).Info("Joining cluster: waiting for consensus...")
+	if err := m.consensus.WaitReady(ctxT); err != nil {
+		return fmt.Errorf("consensus service failed to become ready: %w", err)
+	}
+
+	// Configure node to be a consensus member and kubernetes worker. In the future, different nodes will have
+	// different roles, but for now they're all symmetrical.
+	_, consensusName, err := m.consensus.MemberInfo(ctx)
+	if err != nil {
+		return fmt.Errorf("could not get consensus MemberInfo: %w", err)
+	}
+	if err := node.MakeConsensusMember(consensusName); err != nil {
+		return fmt.Errorf("could not make new node into consensus member: %w", err)
+	}
+	if err := node.MakeKubernetesWorker(node.ID()); err != nil {
+		return fmt.Errorf("could not make new node into kubernetes worker: %w", err)
+	}
+
+	// Save node into etcd.
+	supervisor.Logger(ctx).Info("Joining cluster: storing node...")
+	if err := node.Store(ctx, m.consensus.KV("cluster", "enrolment")); err != nil {
+		return fmt.Errorf("could not save new node: %w", err)
+	}
+
+	// Publish the node before moving to Running, so that it is valid as soon as the Running state is observable.
+	m.stateLock.Lock()
+	m.stateRunningNode = node
+	m.stateLock.Unlock()
+
+	m.next(ctx, StateRunning)
+	return nil
+}
+
+// Node returns the Node that the Manager brought into a cluster, or nil if the Manager is not Running.
+// This is safe to call from any goroutine.
+func (m *Manager) Node() *Node {
+	m.stateLock.Lock()
+	defer m.stateLock.Unlock()
+	if m.state == StateRunning {
+		return m.stateRunningNode
+	}
+	return nil
+}
+
+// ConsensusKV returns an etcd KV client namespaced by the given module and space. It returns nil if the Manager is not
+// Running, or if the running node is not a consensus member.
+// This is safe to call from any goroutine.
+func (m *Manager) ConsensusKV(module, space string) clientv3.KV {
+	m.stateLock.Lock()
+	defer m.stateLock.Unlock()
+	// TODO(q3k): when the node is not a member of consensus, we should still return a client to etcd. For now, all
+	// nodes are consensus members.
+	if m.state != StateRunning || m.stateRunningNode.ConsensusMember() == nil {
+		return nil
+	}
+	return m.consensus.KV(module, space)
+}
+
+// ConsensusKVRoot returns a non-namespaced etcd KV client. It returns nil if the Manager is not Running, or if the
+// running node is not a consensus member.
+// This is safe to call from any goroutine.
+func (m *Manager) ConsensusKVRoot() clientv3.KV {
+	m.stateLock.Lock()
+	defer m.stateLock.Unlock()
+	// TODO(q3k): when the node is not a member of consensus, we should still return a client to etcd. For now, all
+	// nodes are consensus members.
+	if m.state != StateRunning || m.stateRunningNode.ConsensusMember() == nil {
+		return nil
+	}
+	return m.consensus.KVRoot()
+}
+
+// ConsensusCluster returns an etcd Cluster client. It returns nil if the Manager is not Running, or if the running
+// node is not a consensus member.
+// This is safe to call from any goroutine.
+func (m *Manager) ConsensusCluster() clientv3.Cluster {
+	m.stateLock.Lock()
+	defer m.stateLock.Unlock()
+	// TODO(q3k): when the node is not a member of consensus, we should still return a client to etcd. For now, all
+	// nodes are consensus members.
+	if m.state != StateRunning || m.stateRunningNode.ConsensusMember() == nil {
+		return nil
+	}
+	return m.consensus.Cluster()
+}
diff --git a/metropolis/node/core/cluster/node.go b/metropolis/node/core/cluster/node.go
new file mode 100644
index 0000000..449c2ff
--- /dev/null
+++ b/metropolis/node/core/cluster/node.go
@@ -0,0 +1,219 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cluster
+
+import (
+ "context"
+ "crypto/ed25519"
+ "crypto/x509"
+ "encoding/hex"
+ "fmt"
+ "net"
+
+ "github.com/golang/protobuf/proto"
+ "go.etcd.io/etcd/clientv3"
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ ipb "git.monogon.dev/source/nexantic.git/metropolis/proto/internal"
+)
+
+// Node is a Smalltown cluster member. A node is a virtual or physical machine running Smalltown. This object represents a
+// node only as part of a Cluster - ie., this object will never be available outside of //metropolis/node/core/cluster
+// if the Node is not part of a Cluster.
+// Nodes are inherently tied to their long term storage, which is etcd. As such, methods on this object relate heavily
+// to the Node's expected lifecycle on etcd.
+type Node struct {
+ // clusterUnlockKey is half of the unlock key required to mount the node's data partition. It's stored in etcd, and
+ // will only be provided to the Node if it can prove its identity via an integrity mechanism (ie. via TPM), or when
+ // the Node was just created (as the key is generated locally by localstorage on first format/mount).
+ // The other part of the unlock key is the LocalUnlockKey that's present on the node's ESP partition.
+ clusterUnlockKey []byte
+ // certificate is the node's TLS certificate, used to authenticate Smalltown gRPC calls/services (but not
+ // consensus/etcd). The certificate for a node is permanent (and never expires). It's self-signed by the node on
+ // startup, and contains the node's IP address in its SAN. Callers/services should check directly against the
+ // expected certificate, and not against a CA.
+ // The node's ID is derived from the public key of this certificate.
+ certificate x509.Certificate
+ // address is the management IP address of the node. The management IP address of a node is permanent.
+ address net.IP
+
+ // A Node can have multiple Roles. Each Role is represented by the presence of NodeRole* structures in this
+ // structure, with a nil pointer representing the lack of a role.
+
+ // consensusMember is set (by MakeConsensusMember) if the node is a consensus (etcd) member.
+ consensusMember *NodeRoleConsensusMember
+ // kubernetesWorker is set (by MakeKubernetesWorker) if the node is a kubernetes worker.
+ kubernetesWorker *NodeRoleKubernetesWorker
+}
+
+// NewNode creates a new Node without any roles. This is only called when a new node is supposed to be created as
+// part of a cluster, otherwise it should be loaded from etcd.
+// It panics if the given certificate does not carry its raw (DER) form, as that is what gets serialized to etcd.
+func NewNode(cuk []byte, address net.IP, certificate x509.Certificate) *Node {
+	if certificate.Raw == nil {
+		panic("new node must contain raw certificate")
+	}
+	n := new(Node)
+	n.clusterUnlockKey = cuk
+	n.certificate = certificate
+	n.address = address
+	return n
+}
+
+// NodeRoleConsensusMember defines that the Node is a consensus (etcd) cluster member.
+type NodeRoleConsensusMember struct {
+ // etcdMemberName is the name of this node's member in the consensus (etcd) cluster, as passed to
+ // MakeConsensusMember.
+ etcdMemberName string
+}
+
+// NodeRoleKubernetesWorker defines that the Node should be running the Kubernetes control and data plane.
+type NodeRoleKubernetesWorker struct {
+ // nodeName is the name of the node in Kubernetes, as passed to MakeKubernetesWorker. This is for now usually the
+ // same as the ID() of the Node.
+ nodeName string
+}
+
+// ID returns the name of this node, which is `smalltown-{pubkeyHash}`. This name should be the primary way to refer to
+// Smalltown nodes within a cluster, and is guaranteed to be unique by relying on cryptographic randomness.
+// Like IDBare, it panics if the node's certificate does not carry an ed25519 public key.
+func (n *Node) ID() string {
+	return "smalltown-" + n.IDBare()
+}
+
+// IDBare returns the `{pubkeyHash}` part of the node ID: the hex encoding of the first 16 bytes of the ed25519 public
+// key of the node's certificate (ie. 32 hexadecimal characters).
+// It panics if the certificate's public key is not an ed25519 key.
+//
+// This uses a pointer receiver like every other Node method, which also avoids copying the whole Node (including its
+// x509.Certificate) on every call.
+func (n *Node) IDBare() string {
+	pubKey, ok := n.certificate.PublicKey.(ed25519.PublicKey)
+	if !ok {
+		panic("node has non-ed25519 public key")
+	}
+	return hex.EncodeToString(pubKey[:16])
+}
+
+// String returns the ID of the node, so that a Node can be used directly in formatted output.
+func (n *Node) String() string {
+ return n.ID()
+}
+
+// ConsensusMember returns a copy of the NodeRoleConsensusMember struct if the Node is a consensus member, otherwise
+// nil. Changes made to the returned copy do not affect the Node.
+func (n *Node) ConsensusMember() *NodeRoleConsensusMember {
+	if role := n.consensusMember; role != nil {
+		dup := *role
+		return &dup
+	}
+	return nil
+}
+
+// KubernetesWorker returns a copy of the NodeRoleKubernetesWorker struct if the Node is a kubernetes worker, otherwise
+// nil. Changes made to the returned copy do not affect the Node.
+func (n *Node) KubernetesWorker() *NodeRoleKubernetesWorker {
+	if role := n.kubernetesWorker; role != nil {
+		dup := *role
+		return &dup
+	}
+	return nil
+}
+
+// etcdPath returns the etcd key (`/nodes/{ID}`) under which this node's protobuf-serialized state is stored.
+func (n *Node) etcdPath() string {
+	return "/nodes/" + n.ID()
+}
+
+// proto converts the Node object into its protobuf message form, to be used for saving to etcd. The Roles submessage
+// is always present; a role is only populated in it if the Node has that role.
+func (n *Node) proto() *ipb.Node {
+	roles := &ipb.Node_Roles{}
+	if cm := n.consensusMember; cm != nil {
+		roles.ConsensusMember = &ipb.Node_Roles_ConsensusMember{
+			EtcdMemberName: cm.etcdMemberName,
+		}
+	}
+	if kw := n.kubernetesWorker; kw != nil {
+		roles.KubernetesWorker = &ipb.Node_Roles_KubernetesWorker{
+			NodeName: kw.nodeName,
+		}
+	}
+	return &ipb.Node{
+		Certificate:      n.certificate.Raw,
+		ClusterUnlockKey: n.clusterUnlockKey,
+		Address:          n.address.String(),
+		Roles:            roles,
+	}
+}
+
+// Store saves the Node into etcd. This should be called only once per Node (ie. when the Node has been created): it
+// returns an error if a value already exists under the node's etcd key.
+func (n *Node) Store(ctx context.Context, kv clientv3.KV) error {
+	// Currently the only flow to store a node to etcd is a write-once flow: once a node is created, it cannot be
+	// deleted or updated. In the future, flows to change cluster node roles might be introduced (ie. to promote nodes
+	// to consensus members, etc).
+	key := n.etcdPath()
+	nodeRaw, err := proto.Marshal(n.proto())
+	if err != nil {
+		return fmt.Errorf("failed to marshal node: %w", err)
+	}
+
+	// Only put the node if its key has never been created (create revision 0), all within a single transaction.
+	neverCreated := clientv3.Compare(clientv3.CreateRevision(key), "=", 0)
+	put := clientv3.OpPut(key, string(nodeRaw))
+	res, err := kv.Txn(ctx).If(neverCreated).Then(put).Commit()
+	if err != nil {
+		return fmt.Errorf("failed to store node: %w", err)
+	}
+	if !res.Succeeded {
+		return fmt.Errorf("attempted to re-register node (unsupported flow)")
+	}
+	return nil
+}
+
+// MakeConsensusMember turns the node into a consensus member with a given name. This only configures internal fields,
+// and does not actually start any services. It returns an error if the node already is a consensus member.
+func (n *Node) MakeConsensusMember(etcdMemberName string) error {
+	if n.consensusMember == nil {
+		n.consensusMember = &NodeRoleConsensusMember{etcdMemberName: etcdMemberName}
+		return nil
+	}
+	return fmt.Errorf("node already is consensus member")
+}
+
+// MakeKubernetesWorker turns the node into a kubernetes worker with a given name. This only configures internal fields,
+// and does not actually start any services. It returns an error if the node already is a kubernetes worker.
+func (n *Node) MakeKubernetesWorker(name string) error {
+	if n.kubernetesWorker == nil {
+		n.kubernetesWorker = &NodeRoleKubernetesWorker{nodeName: name}
+		return nil
+	}
+	return fmt.Errorf("node is already kubernetes worker")
+}
+
+// Address returns the management IP address of the node. The returned net.IP is not a copy, and shares its backing
+// array with the Node.
+func (n *Node) Address() net.IP {
+ return n.address
+}
+
+// ConfigureLocalHostname uses the node's ID as a hostname: it sets the current (runtime) hostname, and writes the
+// local hosts file (mapping 127.0.0.1 to the ID) and machine-id file (the bare ID) accordingly.
+func (n *Node) ConfigureLocalHostname(etc *localstorage.EtcDirectory) error {
+	hostname := n.ID()
+	if err := unix.Sethostname([]byte(hostname)); err != nil {
+		return fmt.Errorf("failed to set runtime hostname: %w", err)
+	}
+
+	hostsEntry := "127.0.0.1 " + hostname
+	if err := etc.Hosts.Write([]byte(hostsEntry), 0644); err != nil {
+		return fmt.Errorf("failed to write /etc/hosts: %w", err)
+	}
+
+	machineID := n.IDBare()
+	if err := etc.MachineID.Write([]byte(machineID), 0644); err != nil {
+		return fmt.Errorf("failed to write /etc/machine-id: %w", err)
+	}
+	return nil
+}
diff --git a/metropolis/node/core/consensus/BUILD.bazel b/metropolis/node/core/consensus/BUILD.bazel
new file mode 100644
index 0000000..cab2c0a
--- /dev/null
+++ b/metropolis/node/core/consensus/BUILD.bazel
@@ -0,0 +1,30 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+# Consensus (etcd) service of the node. Visible to every package in this repository.
+go_library(
+ name = "go_default_library",
+ srcs = ["consensus.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/consensus",
+ visibility = ["//:__subpackages__"],
+ deps = [
+ "//metropolis/node:go_default_library",
+ "//metropolis/node/common/supervisor:go_default_library",
+ "//metropolis/node/core/consensus/ca:go_default_library",
+ "//metropolis/node/core/localstorage:go_default_library",
+ "@io_etcd_go_etcd//clientv3:go_default_library",
+ "@io_etcd_go_etcd//clientv3/namespace:go_default_library",
+ "@io_etcd_go_etcd//embed:go_default_library",
+ "@org_uber_go_atomic//:go_default_library",
+ ],
+)
+
+# Tests for the consensus library, compiled into the same package as the library (embed).
+go_test(
+ name = "go_default_test",
+ srcs = ["consensus_test.go"],
+ embed = [":go_default_library"],
+ deps = [
+ "//golibs/common:go_default_library",
+ "//metropolis/node/common/supervisor:go_default_library",
+ "//metropolis/node/core/localstorage:go_default_library",
+ "//metropolis/node/core/localstorage/declarative:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/consensus/ca/BUILD.bazel b/metropolis/node/core/consensus/ca/BUILD.bazel
new file mode 100644
index 0000000..fecffa0
--- /dev/null
+++ b/metropolis/node/core/consensus/ca/BUILD.bazel
@@ -0,0 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Library target for the etcd-backed certificate authority (package ca), compiled from ca.go. The etcd clientv3 package
+# is its only non-standard-library dependency.
+go_library(
+ name = "go_default_library",
+ srcs = ["ca.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/consensus/ca",
+ visibility = ["//:__subpackages__"],
+ deps = ["@io_etcd_go_etcd//clientv3:go_default_library"],
+)
diff --git a/metropolis/node/core/consensus/ca/ca.go b/metropolis/node/core/consensus/ca/ca.go
new file mode 100644
index 0000000..9a1b634
--- /dev/null
+++ b/metropolis/node/core/consensus/ca/ca.go
@@ -0,0 +1,440 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ca implements a simple standards-compliant certificate authority.
+// It only supports ed25519 keys, and does not maintain any persistent state.
+//
+// The CA is backed by etcd storage, and can also bootstrap itself before etcd storage is running (committing its
+// in-memory secrets to etcd at a later date).
+//
+// CA and certificates successfully pass https://github.com/zmap/zlint
+// (minus the CA/B rules that a public CA would adhere to, which requires
+// things like OCSP servers, Certificate Policies and ECDSA/RSA-only keys).
+package ca
+
+// TODO(leo): add zlint test
+
+import (
+ "context"
+ "crypto"
+ "crypto/ed25519"
+ "crypto/rand"
+ "crypto/sha1"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/asn1"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "math/big"
+ "net"
+ "time"
+
+ "go.etcd.io/etcd/clientv3"
+)
+
+const (
+	// etcd keys under which the CA certificate, the CA private key, the CRL and all issued certificates are stored.
+	// TODO(q3k): move this to a declarative storage layer
+	pathCACertificate      = "/etcd-ca/ca.der"
+	pathCAKey              = "/etcd-ca/ca-key.der"
+	pathCACRL              = "/etcd-ca/crl.der"
+	pathIssuedCertificates = "/etcd-ca/certs/"
+)
+
+// pathIssuedCertificate returns the etcd key of an issued certificate: the issued-certificates prefix followed by the
+// hex encoding of the serial number's bytes.
+func pathIssuedCertificate(serial *big.Int) string {
+	serialHex := hex.EncodeToString(serial.Bytes())
+	return pathIssuedCertificates + serialHex
+}
+
+var (
+ // unknownNotAfter is the expiry used for every certificate and CRL generated by this package. The value is
+ // 9999-12-31T23:59:59Z, which RFC 5280 Section 4.1.2.5 designates for certificates without a well-defined
+ // expiration date.
+ unknownNotAfter = time.Unix(253402300799, 0)
+)
+
+// CA is a certificate authority that signs certificates and CRLs with an ed25519 key. It is either created in memory
+// with New and later committed to etcd with Save, or restored from etcd with Load.
+type CA struct {
+ // privateKey signs issued certificates and CRLs.
+ // TODO: Potentially protect the key with memguard
+ privateKey *ed25519.PrivateKey
+ // CACert is the CA certificate, used as the issuer when signing certificates and CRLs.
+ CACert *x509.Certificate
+ // CACertRaw is the DER encoding of the CA certificate, as stored in etcd.
+ CACertRaw []byte
+
+ // bootstrapIssued are certificates that have been issued by the CA before it has been successfully Saved to etcd.
+ bootstrapIssued [][]byte
+ // canBootstrapIssue is set on CAs that have been created by New and not yet stored to etcd. If not set,
+ // certificates cannot be issued in-memory.
+ canBootstrapIssue bool
+}
+
+// calculateSKID derives a Subject Key Identifier for pubKey: the SHA-1 digest of the subjectPublicKey bit string taken
+// from the key's PKIX (SubjectPublicKeyInfo) encoding.
+//
+// Workaround for https://github.com/golang/go/issues/26676 in Go's crypto/x509. Specifically Go
+// violates Section 4.2.1.2 of RFC 5280 without this.
+// Fixed for 1.15 in https://go-review.googlesource.com/c/go/+/227098/.
+//
+// Taken from https://github.com/FiloSottile/mkcert/blob/master/cert.go#L295 written by one of Go's
+// crypto engineers (BSD 3-clause).
+func calculateSKID(pubKey crypto.PublicKey) ([]byte, error) {
+	encoded, err := x509.MarshalPKIXPublicKey(pubKey)
+	if err != nil {
+		return nil, err
+	}
+
+	// Only the key bits themselves are hashed, so the SubjectPublicKeyInfo wrapper has to be unpacked again.
+	var info struct {
+		Algorithm        pkix.AlgorithmIdentifier
+		SubjectPublicKey asn1.BitString
+	}
+	if _, err := asn1.Unmarshal(encoded, &info); err != nil {
+		return nil, err
+	}
+
+	digest := sha1.Sum(info.SubjectPublicKey.Bytes)
+	return digest[:], nil
+}
+
+// New creates a new certificate authority with the given common name. The newly created CA will be stored in memory
+// until committed to etcd by calling .Save.
+//
+// A fresh ed25519 key pair and a random serial number are generated, and a self-signed root certificate without a
+// well-defined expiry is created from them. The returned CA may issue certificates to memory until it is saved. An
+// error is returned if key generation, serial number generation or signing fails.
+func New(name string) (*CA, error) {
+	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		// Report entropy failures to the caller like every other failure instead of panicking.
+		return nil, fmt.Errorf("failed to generate CA key: %w", err)
+	}
+
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 127)
+	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate serial number: %w", err)
+	}
+
+	skid, err := calculateSKID(pubKey)
+	if err != nil {
+		return nil, err
+	}
+
+	caCert := &x509.Certificate{
+		SerialNumber: serialNumber,
+		Subject: pkix.Name{
+			CommonName: name,
+		},
+		IsCA:                  true,
+		BasicConstraintsValid: true,
+		NotBefore:             time.Now(),
+		NotAfter:              unknownNotAfter,
+		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageOCSPSigning},
+		// The root is self-signed, so authority and subject key identifiers are the same.
+		AuthorityKeyId: skid,
+		SubjectKeyId:   skid,
+	}
+
+	caCertRaw, err := x509.CreateCertificate(rand.Reader, caCert, caCert, pubKey, privKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create root certificate: %w", err)
+	}
+
+	ca := &CA{
+		privateKey: &privKey,
+		CACertRaw:  caCertRaw,
+		CACert:     caCert,
+
+		canBootstrapIssue: true,
+	}
+
+	return ca, nil
+}
+
+// Load restores CA state from etcd. The CA certificate, CA key and CRL are fetched in one transaction; an error is
+// returned if any of them is missing, if the key does not have the size of an ed25519 private key, or if the
+// certificate cannot be parsed. The returned CA cannot issue certificates to memory.
+func Load(ctx context.Context, kv clientv3.KV) (*CA, error) {
+	txn := kv.Txn(ctx).Then(
+		clientv3.OpGet(pathCACertificate),
+		clientv3.OpGet(pathCAKey),
+		// We only read the CRL to ensure it exists on etcd (and early fail on inconsistency)
+		clientv3.OpGet(pathCACRL),
+	)
+	resp, err := txn.Commit()
+	if err != nil {
+		return nil, fmt.Errorf("failed to retrieve CA from etcd: %w", err)
+	}
+
+	// Sort the returned values by the key they were stored under.
+	var rawCert, rawKey, rawCRL []byte
+	for _, opResp := range resp.Responses {
+		for _, entry := range opResp.GetResponseRange().GetKvs() {
+			switch string(entry.Key) {
+			case pathCACertificate:
+				rawCert = entry.Value
+			case pathCAKey:
+				rawKey = entry.Value
+			case pathCACRL:
+				rawCRL = entry.Value
+			}
+		}
+	}
+	if rawCert == nil || rawKey == nil || rawCRL == nil {
+		return nil, fmt.Errorf("failed to retrieve CA from etcd, missing at least one of {ca key, ca crt, ca crl}")
+	}
+
+	if len(rawKey) != ed25519.PrivateKeySize {
+		return nil, errors.New("invalid CA private key size")
+	}
+	signingKey := ed25519.PrivateKey(rawKey)
+
+	parsedCert, err := x509.ParseCertificate(rawCert)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse CA certificate: %w", err)
+	}
+
+	loaded := &CA{
+		privateKey: &signingKey,
+		CACertRaw:  rawCert,
+		CACert:     parsedCert,
+	}
+	return loaded, nil
+}
+
+// Save stores a newly created CA into etcd, committing both the CA data and any certificates issued until then.
+//
+// The CA certificate, key, an initial empty CRL and all in-memory certificates are written in a single transaction
+// that only succeeds if no CA key exists in etcd yet. On success the CA stops accepting in-memory issuance.
+func (c *CA) Save(ctx context.Context, kv clientv3.KV) error {
+	crl, err := c.makeCRL(nil)
+	if err != nil {
+		return fmt.Errorf("failed to generate initial CRL: %w", err)
+	}
+
+	ops := []clientv3.Op{
+		clientv3.OpPut(pathCACertificate, string(c.CACertRaw)),
+		clientv3.OpPut(pathCAKey, string([]byte(*c.privateKey))),
+		clientv3.OpPut(pathCACRL, string(crl)),
+	}
+	for i, certRaw := range c.bootstrapIssued {
+		// Certificates are only kept in DER form, so parse them again to recover the serial number for the key.
+		cert, err := x509.ParseCertificate(certRaw)
+		if err != nil {
+			return fmt.Errorf("failed to parse in-memory certificate %d: %w", i, err)
+		}
+		ops = append(ops, clientv3.OpPut(pathIssuedCertificate(cert.SerialNumber), string(certRaw)))
+	}
+
+	// A create revision of zero means the key does not exist, ie. no CA has been stored before.
+	res, err := kv.Txn(ctx).If(
+		clientv3.Compare(clientv3.CreateRevision(pathCAKey), "=", 0),
+	).Then(ops...).Commit()
+	if err != nil {
+		return fmt.Errorf("failed to store CA to etcd: %w", err)
+	}
+	if !res.Succeeded {
+		// This should pretty much never happen, but we want to catch it just in case.
+		return fmt.Errorf("failed to store CA to etcd: CA already present - cluster-level data inconsistency")
+	}
+	c.bootstrapIssued = nil
+	c.canBootstrapIssue = false
+	return nil
+}
+
+// Issue issues a certificate. If kv is non-nil, the newly issued certificate will be immediately stored to etcd,
+// otherwise it will be kept in memory (until .Save is called). Certificates can only be issued to memory on
+// newly-created CAs that have not been saved to etcd yet.
+//
+// The certificate is issued for a freshly generated ed25519 key, with commonName as both subject common name and DNS
+// name and externalAddress as IP address. It returns the DER-encoded certificate and the PKCS#8-encoded private key.
+// On error both returned byte slices are nil.
+func (c *CA) Issue(ctx context.Context, kv clientv3.KV, commonName string, externalAddress net.IP) (cert []byte, privkey []byte, err error) {
+	// Fail fast: without etcd the certificate could only be kept in memory, which is reserved for bootstrapping CAs.
+	if kv == nil && !c.canBootstrapIssue {
+		return nil, nil, fmt.Errorf("cannot issue new certificate to memory on existing, etcd-backed CA")
+	}
+
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 127)
+	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
+	}
+
+	pubKey, privKeyRaw, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to generate key: %w", err)
+	}
+	keyRaw, err := x509.MarshalPKCS8PrivateKey(privKeyRaw)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
+	}
+
+	etcdCert := &x509.Certificate{
+		SerialNumber: serialNumber,
+		Subject: pkix.Name{
+			CommonName:         commonName,
+			OrganizationalUnit: []string{"etcd"},
+		},
+		IsCA:                  false,
+		BasicConstraintsValid: true,
+		NotBefore:             time.Now(),
+		NotAfter:              unknownNotAfter,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
+		DNSNames:              []string{commonName},
+		IPAddresses:           []net.IP{externalAddress},
+	}
+	certRaw, err := x509.CreateCertificate(rand.Reader, etcdCert, c.CACert, pubKey, c.privateKey)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to sign new certificate: %w", err)
+	}
+
+	if kv == nil {
+		// Remember the certificate so that Save can commit it to etcd later.
+		c.bootstrapIssued = append(c.bootstrapIssued, certRaw)
+		return certRaw, keyRaw, nil
+	}
+
+	if _, err = kv.Put(ctx, pathIssuedCertificate(serialNumber), string(certRaw)); err != nil {
+		return nil, nil, fmt.Errorf("failed to commit new certificate to etcd: %w", err)
+	}
+	return certRaw, keyRaw, nil
+}
+
+// makeCRL creates a CRL listing the given revoked certificates, signed with the CA key. The CRL is dated with the
+// current time and uses unknownNotAfter as its expiry.
+func (c *CA) makeCRL(revoked []pkix.RevokedCertificate) ([]byte, error) {
+	issuedAt := time.Now()
+	signed, err := c.CACert.CreateCRL(rand.Reader, c.privateKey, revoked, issuedAt, unknownNotAfter)
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate CRL: %w", err)
+	}
+	return signed, nil
+}
+
+// Revoke revokes a certificate by hostname. The first issued certificate whose DNS names contain the hostname is added
+// to the CRL stored in etcd; revoking an already revoked certificate is a no-op. This call might fail (safely) if a
+// simultaneous revoke happened that caused the CRL to be bumped. The call can be then retried safely.
+func (c *CA) Revoke(ctx context.Context, kv clientv3.KV, hostname string) error {
+	// Read the current CRL and all issued certificates in one transaction so that they are mutually consistent.
+	res, err := kv.Txn(ctx).Then(
+		clientv3.OpGet(pathCACRL),
+		clientv3.OpGet(pathIssuedCertificates, clientv3.WithPrefix())).Commit()
+	if err != nil {
+		return fmt.Errorf("failed to retrieve certificates and CRL from etcd: %w", err)
+	}
+
+	var certs []*x509.Certificate
+	var crlRevision int64
+	var crl *pkix.CertificateList
+	for _, el := range res.Responses {
+		for _, entry := range el.GetResponseRange().GetKvs() {
+			if string(entry.Key) == pathCACRL {
+				crl, err = x509.ParseCRL(entry.Value)
+				if err != nil {
+					return fmt.Errorf("could not parse CRL from etcd: %w", err)
+				}
+				// The modification revision changes on every write to the key, unlike the creation revision which
+				// stays constant for as long as the key exists. It is what detects concurrent CRL updates.
+				crlRevision = entry.ModRevision
+				continue
+			}
+
+			cert, err := x509.ParseCertificate(entry.Value)
+			if err != nil {
+				return fmt.Errorf("could not parse certificate %q from etcd: %w", string(entry.Key), err)
+			}
+			certs = append(certs, cert)
+		}
+	}
+
+	if crl == nil {
+		return fmt.Errorf("could not find CRL in etcd")
+	}
+	revoked := crl.TBSCertList.RevokedCertificates
+
+	// Find requested hostname in issued certificates.
+	var serial *big.Int
+	for _, cert := range certs {
+		for _, dnsName := range cert.DNSNames {
+			if dnsName == hostname {
+				serial = cert.SerialNumber
+				break
+			}
+		}
+		if serial != nil {
+			break
+		}
+	}
+	if serial == nil {
+		return fmt.Errorf("could not find requested hostname")
+	}
+
+	// Check if certificate has already been revoked.
+	for _, revokedCert := range revoked {
+		if revokedCert.SerialNumber.Cmp(serial) == 0 {
+			return nil // Already revoked
+		}
+	}
+
+	revoked = append(revoked, pkix.RevokedCertificate{
+		SerialNumber:   serial,
+		RevocationTime: time.Now(),
+	})
+
+	crlRaw, err := c.makeCRL(revoked)
+	if err != nil {
+		return fmt.Errorf("when generating new CRL for revocation: %w", err)
+	}
+
+	// Only replace the CRL if nobody else has written it since it was read above; otherwise the other writer's
+	// revocations would be silently dropped.
+	res, err = kv.Txn(ctx).If(
+		clientv3.Compare(clientv3.ModRevision(pathCACRL), "=", crlRevision),
+	).Then(
+		clientv3.OpPut(pathCACRL, string(crlRaw)),
+	).Commit()
+	if err != nil {
+		return fmt.Errorf("when saving new CRL: %w", err)
+	}
+	if !res.Succeeded {
+		return fmt.Errorf("CRL save transaction failed, retry possibly")
+	}
+
+	return nil
+}
+
+// WaitCRLChange returns a channel that will receive a CRLUpdate any time the remote CRL changed. Immediately after
+// calling this method, the current CRL is retrieved from the cluster and put into the channel.
+//
+// On failure a single CRLUpdate with Err set is sent. The channel is closed when watching ends, be it after a failure
+// or because ctx has been canceled, so consumers can range over it. Updates that cannot be delivered before ctx is
+// canceled are dropped.
+func (c *CA) WaitCRLChange(ctx context.Context, kv clientv3.KV, w clientv3.Watcher) <-chan CRLUpdate {
+	updates := make(chan CRLUpdate)
+
+	go func() {
+		// Closing on every exit path unblocks consumers that range over the channel.
+		defer close(updates)
+
+		watchCtx, cancel := context.WithCancel(ctx)
+		defer cancel()
+
+		// send delivers an update unless ctx is canceled first, so that this goroutine never outlives a consumer
+		// that has stopped reading. It reports whether the update was delivered.
+		send := func(u CRLUpdate) bool {
+			select {
+			case updates <- u:
+				return true
+			case <-ctx.Done():
+				return false
+			}
+		}
+		fail := func(f string, args ...interface{}) {
+			send(CRLUpdate{Err: fmt.Errorf(f, args...)})
+		}
+
+		initial, err := kv.Get(ctx, pathCACRL)
+		if err != nil {
+			fail("failed to retrieve initial CRL: %w", err)
+			return
+		}
+		if len(initial.Kvs) == 0 {
+			fail("failed to retrieve initial CRL: not present in etcd")
+			return
+		}
+
+		if !send(CRLUpdate{CRL: initial.Kvs[0].Value}) {
+			return
+		}
+
+		// Watch from the revision right after the one the initial CRL was read at. This neither replays versions
+		// older than the one just sent, nor does it depend on old revisions surviving compaction.
+		startRev := initial.Header.Revision + 1
+		for wr := range w.Watch(watchCtx, pathCACRL, clientv3.WithRev(startRev)) {
+			if err := wr.Err(); err != nil {
+				fail("failed watching CRL: %w", err)
+				return
+			}
+
+			for _, e := range wr.Events {
+				if string(e.Kv.Key) != pathCACRL {
+					continue
+				}
+
+				if !send(CRLUpdate{CRL: e.Kv.Value}) {
+					return
+				}
+			}
+		}
+	}()
+
+	return updates
+}
+
+// CRLUpdate is emitted for every remote CRL change, and once with the current CRL on every new WaitCRLChange call.
+type CRLUpdate struct {
+ // The new (or existing, in the case of the first call) CRL. If nil, Err will be set.
+ CRL []byte
+ // If set, an error occurred and the WaitCRLChange call must be restarted. If set, CRL will be nil.
+ Err error
+}
+
+// GetCurrentCRL returns the current CRL for the CA. This should only be used for one-shot operations like
+// bootstrapping a new node that doesn't yet have access to etcd - otherwise, WaitCRLChange should be used.
+//
+// An error is returned if the CRL cannot be read from etcd or is not present there.
+func (c *CA) GetCurrentCRL(ctx context.Context, kv clientv3.KV) ([]byte, error) {
+	initial, err := kv.Get(ctx, pathCACRL)
+	if err != nil {
+		return nil, fmt.Errorf("failed to retrieve initial CRL: %w", err)
+	}
+	// A successful Get of a missing key returns no entries; report that instead of panicking on the index below.
+	if len(initial.Kvs) == 0 {
+		return nil, fmt.Errorf("failed to retrieve initial CRL: not present in etcd")
+	}
+	return initial.Kvs[0].Value, nil
+}
diff --git a/metropolis/node/core/consensus/consensus.go b/metropolis/node/core/consensus/consensus.go
new file mode 100644
index 0000000..8916164
--- /dev/null
+++ b/metropolis/node/core/consensus/consensus.go
@@ -0,0 +1,429 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package consensus implements a managed etcd cluster member service, with a self-hosted CA system for issuing peer
+// certificates. Currently each Smalltown node runs an etcd member, and connects to the etcd member locally over a unix
+// domain socket.
+//
+// The service supports two modes of startup:
+// - initializing a new cluster, by bootstrapping the CA in memory, starting a cluster, committing the CA to etcd
+// afterwards, and saving the new node's certificate to local storage
+// - joining an existing cluster, using certificates from local storage and loading the CA from etcd. This flow is also
+// used when the node joins a cluster for the first time (then the certificates required must be provisioned
+// externally before starting the consensus service).
+//
+// Regardless of how the etcd member service was started, the resulting running service is further managed and used
+// in the same way.
+//
+package consensus
+
+import (
+ "context"
+ "encoding/pem"
+ "fmt"
+ "net"
+ "net/url"
+ "sync"
+ "time"
+
+ "go.etcd.io/etcd/clientv3"
+ "go.etcd.io/etcd/clientv3/namespace"
+ "go.etcd.io/etcd/embed"
+ "go.uber.org/atomic"
+
+ common "git.monogon.dev/source/nexantic.git/metropolis/node"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/consensus/ca"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+)
+
+const (
+ // DefaultClusterToken is the initial cluster token set on the embedded etcd configuration.
+ DefaultClusterToken = "SIGNOS"
+ // DefaultLogger is the name of the logger implementation set on the embedded etcd configuration.
+ DefaultLogger = "zap"
+)
+
+// Service is the etcd cluster member service. It is created with New and started by running Run as a supervisor
+// runnable.
+type Service struct {
+ // The configuration with which the service was started. This is immutable.
+ config *Config
+
+ // stateMu guards state. This is locked internally on public methods of Service that require access to state. The
+ // state might be recreated on service restart.
+ stateMu sync.Mutex
+ // state is nil until Run is first called, and replaced with a fresh value on every call to Run.
+ state *state
+}
+
+// state is the runtime state of a running etcd member.
+type state struct {
+ // etcd is the embedded etcd server, set once it has been started.
+ etcd *embed.Etcd
+ // ready is set once etcd is up, the CA has been saved or loaded, and the sub-runnables have been started.
+ ready atomic.Bool
+
+ // ca is the peer certificate authority: created in memory for new clusters, loaded from etcd otherwise.
+ ca *ca.CA
+ // cl is an etcd client that loops back to the locally running etcd server. This runs over the Client unix domain
+ // socket that etcd starts.
+ cl *clientv3.Client
+}
+
+// Config is the configuration of the etcd cluster member service. It is copied by New and not modified afterwards.
+type Config struct {
+ // Data directory (persistent, encrypted storage) for etcd.
+ Data *localstorage.DataEtcdDirectory
+ // Ephemeral directory for etcd.
+ Ephemeral *localstorage.EphemeralConsensusDirectory
+
+ // Name is passed to etcd as this member's name, and is used as the common name of the first peer certificate when
+ // NewCluster is set.
+ // NOTE(review): this was previously documented as the cluster name shared by all members, but the code and tests
+ // treat it as the name of this member - confirm.
+ Name string
+ // NewCluster selects whether the etcd member will start a new cluster and bootstrap a CA and the first member
+ // certificate, or load existing PKI certificates from disk.
+ NewCluster bool
+ // InitialCluster sets the initial cluster peer URLs used to join an existing cluster. It is only used when
+ // NewCluster is not set and it is non-empty; when NewCluster is set it is ignored and the initial cluster is derived
+ // from Name instead.
+ InitialCluster string
+ // ExternalHost is the IP address or hostname at which this cluster member is reachable to other cluster members.
+ // When NewCluster is set this must be an IP address, as it is put into the first peer certificate.
+ ExternalHost string
+ // ListenHost is the IP address or hostname at which this cluster member will listen.
+ ListenHost string
+ // Port is the port at which this cluster member will listen for other members. If zero, defaults to the global
+ // Smalltown setting.
+ Port int
+}
+
+// New returns an etcd cluster member service for the given configuration. The configuration is copied, and nothing is
+// started until Run is called.
+func New(config Config) *Service {
+	svc := new(Service)
+	svc.config = &config
+	return svc
+}
+
+// configure transforms the service configuration into an embedded etcd configuration. As a side effect it creates the
+// ephemeral and data directories if they do not exist yet.
+//
+// The client listener is a unix domain socket in the ephemeral directory, the peer listener uses HTTPS with mutual TLS
+// authentication based on the PKI and CRL files in the data directory.
+func (s *Service) configure(ctx context.Context) (*embed.Config, error) {
+	conf := s.config
+
+	if err := conf.Ephemeral.MkdirAll(0700); err != nil {
+		return nil, fmt.Errorf("failed to create ephemeral directory: %w", err)
+	}
+	if err := conf.Data.MkdirAll(0700); err != nil {
+		return nil, fmt.Errorf("failed to create data directory: %w", err)
+	}
+
+	// Fall back to the global default when no port has been configured.
+	port := conf.Port
+	if port == 0 {
+		port = common.ConsensusPort
+	}
+	listenPeer := url.URL{
+		Scheme: "https",
+		Host:   fmt.Sprintf("%s:%d", conf.ListenHost, port),
+	}
+	advertisePeer := url.URL{
+		Scheme: "https",
+		Host:   fmt.Sprintf("%s:%d", conf.ExternalHost, port),
+	}
+	listenClient := url.URL{
+		Scheme: "unix",
+		Path:   conf.Ephemeral.ClientSocket.FullPath() + ":0",
+	}
+
+	cfg := embed.NewConfig()
+
+	cfg.Name = conf.Name
+	cfg.Dir = conf.Data.Data.FullPath()
+	cfg.InitialClusterToken = DefaultClusterToken
+
+	// Peers authenticate each other with certificates issued by the cluster's own CA.
+	pki := conf.Data.PeerPKI
+	cfg.PeerTLSInfo.CertFile = pki.Certificate.FullPath()
+	cfg.PeerTLSInfo.KeyFile = pki.Key.FullPath()
+	cfg.PeerTLSInfo.TrustedCAFile = pki.CACertificate.FullPath()
+	cfg.PeerTLSInfo.ClientCertAuth = true
+	cfg.PeerTLSInfo.CRLFile = conf.Data.PeerCRL.FullPath()
+
+	cfg.LCUrls = []url.URL{listenClient}
+	cfg.ACUrls = []url.URL{}
+	cfg.LPUrls = []url.URL{listenPeer}
+	cfg.APUrls = []url.URL{advertisePeer}
+
+	switch {
+	case conf.NewCluster:
+		cfg.ClusterState = "new"
+		cfg.InitialCluster = cfg.InitialClusterFromName(cfg.Name)
+	case conf.InitialCluster != "":
+		cfg.ClusterState = "existing"
+		cfg.InitialCluster = conf.InitialCluster
+	}
+
+	// TODO(q3k): pipe logs from etcd to supervisor.RawLogger via a file.
+	cfg.Logger = DefaultLogger
+	cfg.LogOutputs = []string{"stderr"}
+
+	return cfg, nil
+}
+
+// Run is a Supervisor runnable that starts the etcd member service. It will become healthy once the member joins the
+// cluster successfully.
+//
+// For a new cluster, a peer CA and the first member certificate are generated and written to local storage before etcd
+// is started, and the CA is committed to etcd once it is up. Otherwise the certificates must already be present on
+// disk, and the CA is loaded from etcd. Afterwards the CRL watcher and the autopromoter are started as sub-runnables.
+//
+// Run blocks until ctx is canceled. The etcd client and the embedded etcd server are closed when it returns.
+func (s *Service) Run(ctx context.Context) error {
+	// A fresh state is published on every (re)start of the runnable. The zero value of the ready flag is false.
+	st := &state{}
+	s.stateMu.Lock()
+	s.state = st
+	s.stateMu.Unlock()
+
+	if s.config.NewCluster {
+		// Expect certificate to be absent from disk.
+		absent, err := s.config.Data.PeerPKI.AllAbsent()
+		if err != nil {
+			return fmt.Errorf("checking certificate existence: %w", err)
+		}
+		if !absent {
+			return fmt.Errorf("want new cluster, but certificates already exist on disk")
+		}
+
+		// Generate CA, keep in memory, write it down in etcd later.
+		st.ca, err = ca.New("Smalltown etcd peer Root CA")
+		if err != nil {
+			return fmt.Errorf("when creating new cluster's peer CA: %w", err)
+		}
+
+		ip := net.ParseIP(s.config.ExternalHost)
+		if ip == nil {
+			return fmt.Errorf("configued external host is not an IP address (got %q)", s.config.ExternalHost)
+		}
+
+		cert, key, err := st.ca.Issue(ctx, nil, s.config.Name, ip)
+		if err != nil {
+			return fmt.Errorf("when issuing new cluster's first certificate: %w", err)
+		}
+
+		// Directories need the search (execute) bit to be usable by their owner, hence 0700 and not 0600.
+		if err := s.config.Data.PeerPKI.MkdirAll(0700); err != nil {
+			return fmt.Errorf("when creating PKI directory: %w", err)
+		}
+		if err := s.config.Data.PeerPKI.CACertificate.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: st.ca.CACertRaw}), 0600); err != nil {
+			return fmt.Errorf("when writing CA certificate to disk: %w", err)
+		}
+		if err := s.config.Data.PeerPKI.Certificate.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert}), 0600); err != nil {
+			return fmt.Errorf("when writing certificate to disk: %w", err)
+		}
+		if err := s.config.Data.PeerPKI.Key.Write(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: key}), 0600); err != nil {
+			return fmt.Errorf("when writing key to disk: %w", err)
+		}
+	} else {
+		// Expect certificate to be present on disk.
+		present, err := s.config.Data.PeerPKI.AllExist()
+		if err != nil {
+			return fmt.Errorf("checking certificate existence: %w", err)
+		}
+		if !present {
+			return fmt.Errorf("want existing cluster, but certificate is missing from disk")
+		}
+	}
+
+	if err := s.config.Data.MkdirAll(0700); err != nil {
+		return fmt.Errorf("failed to create data directory; %w", err)
+	}
+
+	cfg, err := s.configure(ctx)
+	if err != nil {
+		return fmt.Errorf("when configuring etcd: %w", err)
+	}
+
+	server, err := embed.StartEtcd(cfg)
+	if server != nil {
+		// Registered before the error check, as a server may be returned alongside an error. This also shuts etcd
+		// down on every return path below, including the regular one after ctx is canceled.
+		defer server.Close()
+	}
+	if err != nil {
+		return fmt.Errorf("failed to start etcd: %w", err)
+	}
+	st.etcd = server
+
+	supervisor.Logger(ctx).Info("waiting for etcd...")
+
+	select {
+	case <-st.etcd.Server.ReadyNotify():
+	case <-ctx.Done():
+		supervisor.Logger(ctx).Info("context done, aborting wait")
+		return ctx.Err()
+	}
+
+	socket := s.config.Ephemeral.ClientSocket.FullPath()
+	cl, err := clientv3.New(clientv3.Config{
+		Endpoints:   []string{fmt.Sprintf("unix://%s:0", socket)},
+		DialTimeout: time.Second,
+	})
+	if err != nil {
+		return fmt.Errorf("failed to connect to new etcd instance: %w", err)
+	}
+	// Release the client's connections when the runnable exits; a restart creates a new client. Deferred calls run in
+	// reverse order, so the client is closed before the server it is connected to.
+	defer cl.Close()
+	st.cl = cl
+
+	if s.config.NewCluster {
+		if st.ca == nil {
+			panic("peerCA has not been generated")
+		}
+
+		// Save new CA into etcd.
+		err = st.ca.Save(ctx, cl.KV)
+		if err != nil {
+			return fmt.Errorf("failed to save new CA to etcd: %w", err)
+		}
+	} else {
+		// Load existing CA from etcd.
+		st.ca, err = ca.Load(ctx, cl.KV)
+		if err != nil {
+			return fmt.Errorf("failed to load CA from etcd: %w", err)
+		}
+	}
+
+	// Start CRL watcher.
+	if err := supervisor.Run(ctx, "crl", s.watchCRL); err != nil {
+		return fmt.Errorf("failed to start CRL watcher: %w", err)
+	}
+	// Start autopromoter.
+	if err := supervisor.Run(ctx, "autopromoter", s.autopromoter); err != nil {
+		return fmt.Errorf("failed to start autopromoter: %w", err)
+	}
+
+	supervisor.Logger(ctx).Info("etcd is now ready")
+	st.ready.Store(true)
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+
+	<-ctx.Done()
+	return ctx.Err()
+}
+
+// watchCRL is a sub-runnable of the etcd cluster member service that updates the on-local-storage CRL to match the
+// newest available version in etcd.
+//
+// It returns when ctx is canceled, when watching fails, or when a CRL cannot be written to local storage.
+func (s *Service) watchCRL(ctx context.Context) error {
+	s.stateMu.Lock()
+	cl := s.state.cl
+	peerCA := s.state.ca
+	s.stateMu.Unlock()
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+
+	updates := peerCA.WaitCRLChange(ctx, cl.KV, cl.Watcher)
+	for {
+		select {
+		case <-ctx.Done():
+			// Do not rely on the update channel being closed on cancellation, otherwise this runnable could block
+			// forever.
+			return ctx.Err()
+		case e, ok := <-updates:
+			if !ok {
+				return fmt.Errorf("watching CRL: update channel closed")
+			}
+			if e.Err != nil {
+				return fmt.Errorf("watching CRL: %w", e.Err)
+			}
+
+			if err := s.config.Data.PeerCRL.Write(e.CRL, 0600); err != nil {
+				return fmt.Errorf("saving CRL: %w", err)
+			}
+		}
+	}
+}
+
+// autopromoter is a sub-runnable of the etcd cluster member service. Every five seconds, and only while the local
+// member is the cluster leader, it attempts to promote all learner members. It returns when ctx is canceled.
+func (s *Service) autopromoter(ctx context.Context) error {
+	ticker := time.NewTicker(5 * time.Second)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+		}
+
+		s.stateMu.Lock()
+		st := s.state
+		s.stateMu.Unlock()
+
+		// Only the leader promotes learners.
+		if st.etcd.Server.Leader() != st.etcd.Server.ID() {
+			continue
+		}
+
+		for _, member := range st.etcd.Server.Cluster().Members() {
+			if !member.IsLearner {
+				continue
+			}
+
+			// We always call PromoteMember since the metadata necessary to decide if we should is private.
+			// Luckily etcd already does sanity checks internally and will refuse to promote nodes that aren't
+			// connected or are still behind on transactions.
+			_, err := st.etcd.Server.PromoteMember(ctx, uint64(member.ID))
+			if err != nil {
+				supervisor.Logger(ctx).Infof("Failed to promote consensus node %s: %v", member.Name, err)
+				continue
+			}
+			supervisor.Logger(ctx).Infof("Promoted new consensus node %s", member.Name)
+		}
+	}
+}
+
+// IsReady returns whether etcd is ready and synced. It returns false if the service has not been started yet.
+func (s *Service) IsReady() bool {
+	s.stateMu.Lock()
+	current := s.state
+	s.stateMu.Unlock()
+
+	return current != nil && current.ready.Load()
+}
+
+// WaitReady blocks until the service is ready, polling IsReady every 100 milliseconds. It returns nil once the service
+// is ready, or the context error if ctx is canceled first.
+func (s *Service) WaitReady(ctx context.Context) error {
+	// TODO(q3k): reimplement the atomic ready flag as an event synchronization mechanism
+	if s.IsReady() {
+		return nil
+	}
+
+	poll := time.NewTicker(100 * time.Millisecond)
+	defer poll.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-poll.C:
+		}
+		if s.IsReady() {
+			return nil
+		}
+	}
+}
+
+// KV returns an etcd KV client interface to the etcd member/cluster. All keys accessed through it are prefixed with
+// "module:space".
+func (s *Service) KV(module, space string) clientv3.KV {
+	prefix := fmt.Sprintf("%s:%s", module, space)
+
+	s.stateMu.Lock()
+	defer s.stateMu.Unlock()
+	return namespace.NewKV(s.state.cl.KV, prefix)
+}
+
+// KVRoot returns the KV interface of the local etcd client without any key prefix applied.
+func (s *Service) KVRoot() clientv3.KV {
+	s.stateMu.Lock()
+	defer s.stateMu.Unlock()
+
+	root := s.state.cl.KV
+	return root
+}
+
+// Cluster returns the cluster (membership) interface of the local etcd client.
+func (s *Service) Cluster() clientv3.Cluster {
+	s.stateMu.Lock()
+	defer s.stateMu.Unlock()
+
+	membership := s.state.cl.Cluster
+	return membership
+}
+
+// MemberInfo returns information about this etcd cluster member: its ID and name. This will block until this
+// information is available (ie. the cluster status is Ready), or fail if ctx is canceled before that.
+func (s *Service) MemberInfo(ctx context.Context) (id uint64, name string, err error) {
+	if waitErr := s.WaitReady(ctx); waitErr != nil {
+		return 0, "", fmt.Errorf("when waiting for cluster readiness: %w", waitErr)
+	}
+
+	s.stateMu.Lock()
+	defer s.stateMu.Unlock()
+
+	// The ID comes from the running server, the name from the configuration the service was started with.
+	memberID := uint64(s.state.etcd.Server.ID())
+	return memberID, s.config.Name, nil
+}
diff --git a/metropolis/node/core/consensus/consensus_test.go b/metropolis/node/core/consensus/consensus_test.go
new file mode 100644
index 0000000..e08bd29
--- /dev/null
+++ b/metropolis/node/core/consensus/consensus_test.go
@@ -0,0 +1,261 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package consensus
+
+import (
+ "bytes"
+ "context"
+ "crypto/x509"
+ "io/ioutil"
+ "net"
+ "os"
+ "testing"
+ "time"
+
+ "git.monogon.dev/source/nexantic.git/golibs/common"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative"
+)
+
+// boilerplate bundles the per-test fixtures created by prep and torn down by close.
+type boilerplate struct {
+ // ctx is the context tests run their services and requests under; ctxC cancels it.
+ ctx context.Context
+ ctxC context.CancelFunc
+ // root is the local storage tree, placed inside tmpdir.
+ root *localstorage.Root
+ // tmpdir is the temporary directory backing root. It is removed by close.
+ tmpdir string
+}
+
+// prep creates the fixtures for a test: a temporary directory with a local storage tree placed inside it, the etcd data
+// and consensus ephemeral directories within that tree, and a cancelable context. Any failure aborts the test after
+// removing the temporary directory again. The caller must call close on the result.
+func prep(t *testing.T) *boilerplate {
+	t.Helper()
+
+	tmp, err := ioutil.TempDir("", "smalltown-test")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	root := &localstorage.Root{}
+	if err := declarative.PlaceFS(root, tmp); err != nil {
+		os.RemoveAll(tmp)
+		t.Fatal(err)
+	}
+
+	// Fail here rather than with a confusing error later on if the directories cannot be created.
+	for _, dir := range []string{root.Data.Etcd.FullPath(), root.Ephemeral.Consensus.FullPath()} {
+		if err := os.MkdirAll(dir, 0700); err != nil {
+			os.RemoveAll(tmp)
+			t.Fatalf("creating directory %q: %v", dir, err)
+		}
+	}
+
+	// The context is created last so that it cannot leak on any of the failure paths above.
+	ctx, ctxC := context.WithCancel(context.Background())
+	return &boilerplate{
+		ctx:    ctx,
+		ctxC:   ctxC,
+		root:   root,
+		tmpdir: tmp,
+	}
+}
+
+// close tears down the fixtures created by prep: it cancels the test context and then removes the temporary directory
+// on a best-effort basis.
+func (b *boilerplate) close() {
+	cancel, dir := b.ctxC, b.tmpdir
+	cancel()
+	os.RemoveAll(dir)
+}
+
+// waitEtcd polls the service every 100 milliseconds until it reports ready, and fails the test if that has not happened
+// within five seconds.
+func waitEtcd(t *testing.T, s *Service) {
+	timeout := time.Now().Add(5 * time.Second)
+	for {
+		if time.Now().After(timeout) {
+			t.Fatalf("etcd did not start up on time")
+		}
+		if s.IsReady() {
+			return
+		}
+		time.Sleep(100 * time.Millisecond)
+	}
+}
+
+// TestBootstrap starts a new single-member cluster and checks that a key can be written and read back through a
+// namespaced KV client.
+func TestBootstrap(t *testing.T) {
+	b := prep(t)
+	defer b.close()
+
+	cfg := Config{
+		Data:           &b.root.Data.Etcd,
+		Ephemeral:      &b.root.Ephemeral.Consensus,
+		Name:           "test",
+		NewCluster:     true,
+		InitialCluster: "127.0.0.1",
+		ExternalHost:   "127.0.0.1",
+		ListenHost:     "127.0.0.1",
+		Port:           common.MustConsume(common.AllocateTCPPort()),
+	}
+	etcd := New(cfg)
+
+	supervisor.New(b.ctx, etcd.Run)
+	waitEtcd(t, etcd)
+
+	kv := etcd.KV("foo", "bar")
+	_, err := kv.Put(b.ctx, "/foo", "bar")
+	if err != nil {
+		t.Fatalf("test key creation failed: %v", err)
+	}
+	_, err = kv.Get(b.ctx, "/foo")
+	if err != nil {
+		t.Fatalf("test key retrieval failed: %v", err)
+	}
+}
+
+// TestMemberInfo starts a new single-member cluster and checks that the ID and name returned by MemberInfo agree with
+// both the configuration and the member list reported by the cluster.
+func TestMemberInfo(t *testing.T) {
+	b := prep(t)
+	defer b.close()
+	etcd := New(Config{
+		Data:           &b.root.Data.Etcd,
+		Ephemeral:      &b.root.Ephemeral.Consensus,
+		Name:           "test",
+		NewCluster:     true,
+		InitialCluster: "127.0.0.1",
+		ExternalHost:   "127.0.0.1",
+		ListenHost:     "127.0.0.1",
+		Port:           common.MustConsume(common.AllocateTCPPort()),
+	})
+	supervisor.New(b.ctx, etcd.Run)
+	waitEtcd(t, etcd)
+
+	id, name, err := etcd.MemberInfo(b.ctx)
+	if err != nil {
+		t.Fatalf("MemberInfo: %v", err)
+	}
+
+	// Compare name with configured name.
+	if want, got := "test", name; want != got {
+		t.Errorf("MemberInfo returned name %q, wanted %q (per config)", got, want)
+	}
+
+	// Compare name with cluster information.
+	members, err := etcd.Cluster().MemberList(b.ctx)
+	if err != nil {
+		// The response is dereferenced below, so continuing after an error would panic instead of failing cleanly.
+		t.Fatalf("MemberList: %v", err)
+	}
+
+	if want, got := 1, len(members.Members); want != got {
+		t.Fatalf("expected one cluster member, got %d", got)
+	}
+	if want, got := id, members.Members[0].ID; want != got {
+		t.Errorf("MemberInfo returned ID %d, Cluster endpoint says %d", want, got)
+	}
+	if want, got := name, members.Members[0].Name; want != got {
+		t.Errorf("MemberInfo returned name %q, Cluster endpoint says %q", want, got)
+	}
+}
+
+// TestRestartFromDisk starts a new cluster, stops it, and starts a second service instance on the same storage as an
+// existing cluster. It checks that the test key written by the first instance is still readable and that both instances
+// end up with the same CA certificate.
+func TestRestartFromDisk(t *testing.T) {
+	b := prep(t)
+	defer b.close()
+
+	// startEtcd runs a service instance on the shared storage and returns it together with the function stopping it.
+	// The test key is only written when starting a new cluster, but read back in both cases.
+	startEtcd := func(newCluster bool) (*Service, context.CancelFunc) {
+		cfg := Config{
+			Data:           &b.root.Data.Etcd,
+			Ephemeral:      &b.root.Ephemeral.Consensus,
+			Name:           "test",
+			NewCluster:     newCluster,
+			InitialCluster: "127.0.0.1",
+			ExternalHost:   "127.0.0.1",
+			ListenHost:     "127.0.0.1",
+			Port:           common.MustConsume(common.AllocateTCPPort()),
+		}
+		svc := New(cfg)
+		runCtx, stop := context.WithCancel(b.ctx)
+		supervisor.New(runCtx, svc.Run)
+		waitEtcd(t, svc)
+
+		kv := svc.KV("foo", "bar")
+		if newCluster {
+			if _, err := kv.Put(b.ctx, "/foo", "bar"); err != nil {
+				t.Fatalf("test key creation failed: %v", err)
+			}
+		}
+		if _, err := kv.Get(b.ctx, "/foo"); err != nil {
+			t.Fatalf("test key retrieval failed: %v", err)
+		}
+
+		return svc, stop
+	}
+
+	first, stopFirst := startEtcd(true)
+	first.stateMu.Lock()
+	firstCA := first.state.ca.CACertRaw
+	first.stateMu.Unlock()
+	stopFirst()
+
+	second, stopSecond := startEtcd(false)
+	second.stateMu.Lock()
+	secondCA := second.state.ca.CACertRaw
+	second.stateMu.Unlock()
+	stopSecond()
+
+	if !bytes.Equal(firstCA, secondCA) {
+		t.Fatalf("wanted same, got different CAs accross runs")
+	}
+}
+
+// TestCRL issues a certificate from the CA of a new cluster, revokes it, and then waits up to five seconds for the CRL
+// on local storage to list the serial number of the revoked certificate.
+func TestCRL(t *testing.T) {
+	b := prep(t)
+	defer b.close()
+
+	cfg := Config{
+		Data:           &b.root.Data.Etcd,
+		Ephemeral:      &b.root.Ephemeral.Consensus,
+		Name:           "test",
+		NewCluster:     true,
+		InitialCluster: "127.0.0.1",
+		ExternalHost:   "127.0.0.1",
+		ListenHost:     "127.0.0.1",
+		Port:           common.MustConsume(common.AllocateTCPPort()),
+	}
+	etcd := New(cfg)
+	supervisor.New(b.ctx, etcd.Run)
+	waitEtcd(t, etcd)
+
+	etcd.stateMu.Lock()
+	peerCA := etcd.state.ca
+	kv := etcd.state.cl.KV
+	etcd.stateMu.Unlock()
+
+	certRaw, _, err := peerCA.Issue(b.ctx, kv, "revoketest", net.ParseIP("1.2.3.4"))
+	if err != nil {
+		t.Fatalf("cert issue failed: %v", err)
+	}
+	cert, err := x509.ParseCertificate(certRaw)
+	if err != nil {
+		t.Fatalf("cert parse failed: %v", err)
+	}
+
+	if err := peerCA.Revoke(b.ctx, kv, "revoketest"); err != nil {
+		t.Fatalf("cert revoke failed: %v", err)
+	}
+
+	// revokedOnDisk reports whether the CRL on local storage lists the certificate issued above. A CRL that cannot
+	// be read or parsed is fine, maybe it hasn't been written yet.
+	revokedOnDisk := func() bool {
+		crlRaw, err := b.root.Data.Etcd.PeerCRL.Read()
+		if err != nil {
+			return false
+		}
+		crl, err := x509.ParseCRL(crlRaw)
+		if err != nil {
+			return false
+		}
+
+		listed := false
+		for _, entry := range crl.TBSCertList.RevokedCertificates {
+			if entry.SerialNumber.Cmp(cert.SerialNumber) == 0 {
+				listed = true
+			}
+		}
+		return listed
+	}
+
+	timeout := time.Now().Add(5 * time.Second)
+	for {
+		if time.Now().After(timeout) {
+			t.Fatalf("CRL did not get updated in time")
+		}
+		time.Sleep(100 * time.Millisecond)
+
+		if revokedOnDisk() {
+			return
+		}
+	}
+}
diff --git a/metropolis/node/core/debug_service.go b/metropolis/node/core/debug_service.go
new file mode 100644
index 0000000..0155cc6
--- /dev/null
+++ b/metropolis/node/core/debug_service.go
@@ -0,0 +1,243 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "context"
+ "crypto/x509"
+ "fmt"
+ "net"
+
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ common "git.monogon.dev/source/nexantic.git/metropolis/node"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/cluster"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/consensus/ca"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes"
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+const (
+ // logFilterMax is the maximum number of filters that a single GetLogs
+ // request may carry. Requests with more filters are rejected with
+ // InvalidArgument.
+ logFilterMax = 1000
+)
+
+// debugService implements the Smalltown node debug API.
+type debugService struct {
+ // cluster is used to retrieve the local node as well as the consensus
+ // (etcd) cluster client and key-value root.
+ cluster *cluster.Manager
+ // kubernetes is the service that GetDebugKubeconfig requests are
+ // forwarded to.
+ kubernetes *kubernetes.Service
+ // logtree is the log tree that GetLogs reads backlog and streamed
+ // entries from.
+ logtree *logtree.LogTree
+}
+
+// GetGoldenTicket builds a golden ticket for a new node that will be
+// reachable at req.ExternalIp. It:
+//   - generates a node certificate and key,
+//   - issues an etcd peer certificate for that node from the consensus CA,
+//   - retrieves the current etcd CRL,
+//   - registers the new node as a learner member of the etcd cluster.
+//
+// It returns InvalidArgument if the IP cannot be parsed, Internal if the
+// freshly generated node certificate cannot be parsed, and Unavailable if any
+// other step fails.
+func (s *debugService) GetGoldenTicket(ctx context.Context, req *apb.GetGoldenTicketRequest) (*apb.GetGoldenTicketResponse, error) {
+	ip := net.ParseIP(req.ExternalIp)
+	if ip == nil {
+		return nil, status.Errorf(codes.InvalidArgument, "could not parse IP %q", req.ExternalIp)
+	}
+	local := s.cluster.Node()
+
+	certRaw, key, err := s.nodeCertificate()
+	if err != nil {
+		return nil, status.Errorf(codes.Unavailable, "failed to generate node certificate: %v", err)
+	}
+	cert, err := x509.ParseCertificate(certRaw)
+	if err != nil {
+		// This handler runs inside the node's init binary, so a panic here
+		// would bring down the whole node. Report the failure to the caller
+		// instead.
+		return nil, status.Errorf(codes.Internal, "could not parse generated node certificate: %v", err)
+	}
+
+	kv := s.cluster.ConsensusKVRoot()
+	// Named etcdCA rather than ca so that the imported ca package is not
+	// shadowed for the rest of the function.
+	etcdCA, err := ca.Load(ctx, kv)
+	if err != nil {
+		return nil, status.Errorf(codes.Unavailable, "could not load CA: %v", err)
+	}
+	etcdCert, etcdKey, err := etcdCA.Issue(ctx, kv, cert.Subject.CommonName, ip)
+	if err != nil {
+		return nil, status.Errorf(codes.Unavailable, "could not generate etcd peer certificate: %v", err)
+	}
+	etcdCRL, err := etcdCA.GetCurrentCRL(ctx, kv)
+	if err != nil {
+		return nil, status.Errorf(codes.Unavailable, "could not get etcd CRL: %v", err)
+	}
+
+	// Add new etcd member to etcd cluster.
+	etcd := s.cluster.ConsensusCluster()
+	etcdAddr := fmt.Sprintf("https://%s:%d", ip.String(), common.ConsensusPort)
+	_, err = etcd.MemberAddAsLearner(ctx, []string{etcdAddr})
+	if err != nil {
+		return nil, status.Errorf(codes.Unavailable, "could not add as new etcd consensus member: %v", err)
+	}
+
+	return &apb.GetGoldenTicketResponse{
+		Ticket: &apb.GoldenTicket{
+			EtcdCaCert:     etcdCA.CACertRaw,
+			EtcdClientCert: etcdCert,
+			EtcdClientKey:  etcdKey,
+			EtcdCrl:        etcdCRL,
+			// The only peer handed to the new node is the node serving
+			// this request.
+			Peers: []*apb.GoldenTicket_EtcdPeer{
+				{Name: local.ID(), Address: local.Address().String()},
+			},
+			This: &apb.GoldenTicket_EtcdPeer{Name: cert.Subject.CommonName, Address: ip.String()},
+
+			NodeId:   cert.Subject.CommonName,
+			NodeCert: certRaw,
+			NodeKey:  key,
+		},
+	}, nil
+}
+
+// GetDebugKubeconfig forwards the request unchanged to the Kubernetes service
+// and returns its response and error as-is.
+func (s *debugService) GetDebugKubeconfig(ctx context.Context, req *apb.GetDebugKubeconfigRequest) (*apb.GetDebugKubeconfigResponse, error) {
+ return s.kubernetes.GetDebugKubeconfig(ctx, req)
+}
+
+// GetLogs serves log entries from the logtree for the DN given in the request.
+//
+// The request's backlog mode, stream mode and filters are translated into
+// logtree read options. Backlog entries are sent first, in chunks of at most
+// 2000 entries per response. If streaming was requested, new entries are then
+// sent one per response until the client goes away, a send fails, or logtree
+// closes the stream.
+//
+// It returns InvalidArgument for malformed requests (too many filters, invalid
+// DN, bad backlog settings, invalid severity, or mutually exclusive filters),
+// Unavailable if logs cannot be read or streaming is aborted by logtree, and
+// Canceled once the client's stream context is done.
+func (s *debugService) GetLogs(req *apb.GetLogsRequest, srv apb.NodeDebugService_GetLogsServer) error {
+	if len(req.Filters) > logFilterMax {
+		return status.Errorf(codes.InvalidArgument, "requested %d filters, maximum permitted is %d", len(req.Filters), logFilterMax)
+	}
+	dn := logtree.DN(req.Dn)
+	_, err := dn.Path()
+	switch err {
+	case nil:
+	case logtree.ErrInvalidDN:
+		return status.Errorf(codes.InvalidArgument, "invalid DN")
+	default:
+		return status.Errorf(codes.Unavailable, "could not parse DN: %v", err)
+	}
+
+	var options []logtree.LogReadOption
+
+	// Turn backlog mode into logtree option(s).
+	switch req.BacklogMode {
+	case apb.GetLogsRequest_BACKLOG_DISABLE:
+	case apb.GetLogsRequest_BACKLOG_ALL:
+		options = append(options, logtree.WithBacklog(logtree.BacklogAllAvailable))
+	case apb.GetLogsRequest_BACKLOG_COUNT:
+		count := int(req.BacklogCount)
+		if count <= 0 {
+			return status.Errorf(codes.InvalidArgument, "backlog_count must be > 0 if backlog_mode is BACKLOG_COUNT")
+		}
+		options = append(options, logtree.WithBacklog(count))
+	default:
+		return status.Errorf(codes.InvalidArgument, "unknown backlog_mode %d", req.BacklogMode)
+	}
+
+	// Turn stream mode into logtree option(s). Any stream mode other than
+	// STREAM_UNBUFFERED leaves streaming disabled.
+	streamEnable := false
+	switch req.StreamMode {
+	case apb.GetLogsRequest_STREAM_DISABLE:
+	case apb.GetLogsRequest_STREAM_UNBUFFERED:
+		streamEnable = true
+		options = append(options, logtree.WithStream())
+	}
+
+	// Parse proto filters into logtree options. Filter kinds not listed here
+	// are ignored.
+	for i, filter := range req.Filters {
+		switch inner := filter.Filter.(type) {
+		case *apb.LogFilter_WithChildren_:
+			options = append(options, logtree.WithChildren())
+		case *apb.LogFilter_OnlyRaw_:
+			options = append(options, logtree.OnlyRaw())
+		case *apb.LogFilter_OnlyLeveled_:
+			options = append(options, logtree.OnlyLeveled())
+		case *apb.LogFilter_LeveledWithMinimumSeverity_:
+			severity, err := logtree.SeverityFromProto(inner.LeveledWithMinimumSeverity.Minimum)
+			if err != nil {
+				return status.Errorf(codes.InvalidArgument, "filter %d has invalid severity: %v", i, err)
+			}
+			options = append(options, logtree.LeveledWithMinimumSeverity(severity))
+		}
+	}
+
+	reader, err := s.logtree.Read(dn, options...)
+	switch err {
+	case nil:
+	case logtree.ErrRawAndLeveled:
+		return status.Errorf(codes.InvalidArgument, "requested only raw and only leveled logs simultaneously")
+	default:
+		return status.Errorf(codes.Unavailable, "could not retrieve logs: %v", err)
+	}
+	defer reader.Close()
+
+	// Default protobuf message size limit is 64MB. We want to limit ourselves
+	// to 10MB.
+	// Currently each raw log line can be at most 1024 unicode codepoints (or
+	// 4096 bytes). To cover extra metadata and proto overhead, let's round
+	// this up to 4500 bytes. This in turn means we can store a maximum of
+	// (10e6/4500) == 2222 entries.
+	// Currently each leveled log line can also be at most 1024 unicode
+	// codepoints (or 4096 bytes). To cover extra metadata and proto overhead
+	// let's round this up to 5000 bytes. This in turn means we can store a
+	// maximum of (10e6/5000) == 2000 entries.
+	// The lower of these numbers, ie the worst case scenario, is 2000
+	// maximum entries.
+	maxChunkSize := 2000
+
+	// Serve all backlog entries in chunks.
+	chunk := make([]*apb.LogEntry, 0, maxChunkSize)
+	for _, entry := range reader.Backlog {
+		p := entry.Proto()
+		if p == nil {
+			// TODO(q3k): log this once we have logtree/gRPC compatibility.
+			continue
+		}
+		chunk = append(chunk, p)
+
+		if len(chunk) >= maxChunkSize {
+			err := srv.Send(&apb.GetLogsResponse{
+				BacklogEntries: chunk,
+			})
+			if err != nil {
+				return err
+			}
+			// Start a fresh slice rather than reusing the one just handed
+			// to Send.
+			chunk = make([]*apb.LogEntry, 0, maxChunkSize)
+		}
+	}
+
+	// Send last chunk of backlog, if present.
+	if len(chunk) > 0 {
+		err := srv.Send(&apb.GetLogsResponse{
+			BacklogEntries: chunk,
+		})
+		if err != nil {
+			return err
+		}
+	}
+
+	// Start serving streaming data, if streaming has been requested.
+	if !streamEnable {
+		return nil
+	}
+
+	// Watch the stream context alongside the log stream. Otherwise a client
+	// that disconnects while no new entries are being logged would leave this
+	// handler (and the logtree reader) blocked until the next entry arrives.
+	ctx := srv.Context()
+	for {
+		select {
+		case <-ctx.Done():
+			return status.Errorf(codes.Canceled, "log streaming canceled: %v", ctx.Err())
+		case entry, ok := <-reader.Stream:
+			if !ok {
+				// Streaming has been ended by logtree - tell the client and return.
+				return status.Error(codes.Unavailable, "log streaming aborted by system")
+			}
+			p := entry.Proto()
+			if p == nil {
+				// TODO(q3k): log this once we have logtree/gRPC compatibility.
+				continue
+			}
+			err := srv.Send(&apb.GetLogsResponse{
+				StreamEntries: []*apb.LogEntry{p},
+			})
+			if err != nil {
+				return err
+			}
+		}
+	}
+}
diff --git a/metropolis/node/core/delve_disabled.go b/metropolis/node/core/delve_disabled.go
new file mode 100644
index 0000000..cb0a59b
--- /dev/null
+++ b/metropolis/node/core/delve_disabled.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import "git.monogon.dev/source/nexantic.git/metropolis/node/core/network"
+
+// initializeDebugger does nothing in a non-debug build. It takes the network
+// service as an unnamed, unused parameter.
+func initializeDebugger(*network.Service) {
+}
diff --git a/metropolis/node/core/delve_enabled.go b/metropolis/node/core/delve_enabled.go
new file mode 100644
index 0000000..b3b859c
--- /dev/null
+++ b/metropolis/node/core/delve_enabled.go
@@ -0,0 +1,41 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"context"
+	"fmt"
+	"os/exec"
+
+	common "git.monogon.dev/source/nexantic.git/metropolis/node"
+	"git.monogon.dev/source/nexantic.git/metropolis/node/core/network"
+)
+
+// initializeDebugger starts a background goroutine that attaches a headless
+// Delve instance to PID 1 (/init) and has it listen on common.DebuggerPort.
+// This is tied to compilation_mode=dbg, as Delve otherwise lacks the DWARF
+// debug info it needs. The goroutine panics if Delve cannot be started.
+func initializeDebugger(networkSvc *network.Service) {
+	attach := func() {
+		// Hold off until the network is up: Delve for some reason connects
+		// to itself, which is not possible in early boot when no network
+		// interface is available yet. External access would not work that
+		// early anyway.
+		networkSvc.GetIP(context.Background(), true)
+
+		listenFlag := fmt.Sprintf("--listen=:%v", common.DebuggerPort)
+		dlvCmd := exec.Command("/dlv", "--headless=true", listenFlag,
+			"--accept-multiclient", "--only-same-user=false", "attach", "--continue", "1", "/init")
+		if err := dlvCmd.Start(); err != nil {
+			panic(err)
+		}
+	}
+	go attach()
+}
diff --git a/metropolis/node/core/localstorage/BUILD.bazel b/metropolis/node/core/localstorage/BUILD.bazel
new file mode 100644
index 0000000..099a380
--- /dev/null
+++ b/metropolis/node/core/localstorage/BUILD.bazel
@@ -0,0 +1,26 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "directory_data.go",
+ "directory_pki.go",
+ "directory_root.go",
+ "storage.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage",
+ visibility = ["//metropolis/node:__subpackages__"],
+ deps = [
+ "//metropolis/node/core/localstorage/crypt:go_default_library",
+ "//metropolis/node/core/localstorage/declarative:go_default_library",
+ "//metropolis/node/core/tpm:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "go_default_test",
+ srcs = ["storage_test.go"],
+ embed = [":go_default_library"],
+ deps = ["//metropolis/node/core/localstorage/declarative:go_default_library"],
+)
diff --git a/metropolis/node/core/localstorage/crypt/BUILD.bazel b/metropolis/node/core/localstorage/crypt/BUILD.bazel
new file mode 100644
index 0000000..38e27d6
--- /dev/null
+++ b/metropolis/node/core/localstorage/crypt/BUILD.bazel
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "blockdev.go",
+ "crypt.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/crypt",
+ visibility = ["//metropolis/node/core/localstorage:__subpackages__"],
+ deps = [
+ "//metropolis/node/common/devicemapper:go_default_library",
+ "//metropolis/node/common/sysfs:go_default_library",
+ "@com_github_rekby_gpt//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/localstorage/crypt/blockdev.go b/metropolis/node/core/localstorage/crypt/blockdev.go
new file mode 100644
index 0000000..df5f590
--- /dev/null
+++ b/metropolis/node/core/localstorage/crypt/blockdev.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package crypt
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strconv"
+
+ "github.com/rekby/gpt"
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/sysfs"
+)
+
+// GPT partition type identifiers. The byte arrays hold the GUIDs in their
+// on-disk form, in which the first three groups of the human-readable GUID
+// appear byte-swapped and the remaining bytes appear in order.
+var (
+ // EFIPartitionType is the standardized partition type value for the EFI ESP partition. The human readable GUID is C12A7328-F81F-11D2-BA4B-00A0C93EC93B.
+ EFIPartitionType = gpt.PartType{0x28, 0x73, 0x2a, 0xc1, 0x1f, 0xf8, 0xd2, 0x11, 0xba, 0x4b, 0x00, 0xa0, 0xc9, 0x3e, 0xc9, 0x3b}
+
+ // SmalltownDataPartitionType is the partition type value for a Smalltown data partition. The human-readable GUID is 9eeec464-6885-414a-b278-4305c51f7966.
+ SmalltownDataPartitionType = gpt.PartType{0x64, 0xc4, 0xee, 0x9e, 0x85, 0x68, 0x4a, 0x41, 0xb2, 0x78, 0x43, 0x05, 0xc5, 0x1f, 0x79, 0x66}
+)
+
+const (
+ // ESPDevicePath is the device node that MakeBlockDevices creates for the
+ // EFI system partition.
+ ESPDevicePath = "/dev/esp"
+ // SmalltownDataCryptPath is the device node that MakeBlockDevices creates
+ // for the Smalltown encrypted data partition.
+ SmalltownDataCryptPath = "/dev/data-crypt"
+)
+
+// MakeBlockDevices looks for the ESP and the Smalltown data partition and maps them to ESPDevicePath and
+// SmalltownDataCryptPath respectively. This doesn't fail if it doesn't find the partitions, only if
+// something goes catastrophically wrong.
+//
+// It walks /sys/class/block, and for every entry whose uevent reports DEVTYPE=disk it inspects the
+// disk's GPT (see mapDiskPartitions). The ctx argument is currently unused.
+func MakeBlockDevices(ctx context.Context) error {
+	blockdevNames, err := ioutil.ReadDir("/sys/class/block")
+	if err != nil {
+		return fmt.Errorf("failed to read sysfs block class: %w", err)
+	}
+	for _, blockdevName := range blockdevNames {
+		ueventData, err := sysfs.ReadUevents(filepath.Join("/sys/class/block", blockdevName.Name(), "uevent"))
+		if err != nil {
+			return fmt.Errorf("failed to read uevent for block device %v: %w", blockdevName.Name(), err)
+		}
+		if ueventData["DEVTYPE"] != "disk" {
+			continue
+		}
+		majorDev, err := strconv.Atoi(ueventData["MAJOR"])
+		if err != nil {
+			return fmt.Errorf("failed to convert uevent: %w", err)
+		}
+		devNodeName := fmt.Sprintf("/dev/%v", ueventData["DEVNAME"])
+		if err := mapDiskPartitions(devNodeName, majorDev); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// mapDiskPartitions opens the disk at devNodeName, reads its GPT and creates device nodes for any ESP or
+// Smalltown data partition found on it. Disks that cannot be queried for their sector size, cannot be
+// seeked, or do not carry a readable GPT are silently skipped (nil is returned). The disk is closed
+// before returning, so scanning many disks does not accumulate open file descriptors.
+func mapDiskPartitions(devNodeName string, majorDev int) error {
+	blkdev, err := os.Open(devNodeName)
+	if err != nil {
+		return fmt.Errorf("failed to open block device %v: %w", devNodeName, err)
+	}
+	defer blkdev.Close()
+
+	blockSize, err := unix.IoctlGetUint32(int(blkdev.Fd()), unix.BLKSSZGET)
+	if err != nil {
+		return nil // This is not a regular block device
+	}
+	// Position the reader one sector into the disk before handing it to the GPT parser.
+	if _, err := blkdev.Seek(int64(blockSize), 0); err != nil {
+		// Treat a device we cannot seek on the same way as one without a GPT.
+		return nil
+	}
+	table, err := gpt.ReadTable(blkdev, uint64(blockSize))
+	if err != nil {
+		// Probably just not a GPT-partitioned disk
+		return nil
+	}
+
+	for partNumber, part := range table.Partitions {
+		// NOTE(review): the minor number is derived from the partition index alone (index+1) and does
+		// not take the disk's own minor number into account - confirm this is intended for systems
+		// with more than one disk.
+		devNum := int(unix.Mkdev(uint32(majorDev), uint32(partNumber+1)))
+		if part.Type == EFIPartitionType {
+			if err := unix.Mknod(ESPDevicePath, 0600|unix.S_IFBLK, devNum); err != nil {
+				return fmt.Errorf("failed to create device node for ESP partition: %w", err)
+			}
+		}
+		if part.Type == SmalltownDataPartitionType {
+			if err := unix.Mknod(SmalltownDataCryptPath, 0600|unix.S_IFBLK, devNum); err != nil {
+				return fmt.Errorf("failed to create device node for Smalltown encrypted data partition: %w", err)
+			}
+		}
+	}
+	return nil
+}
diff --git a/metropolis/node/core/localstorage/crypt/crypt.go b/metropolis/node/core/localstorage/crypt/crypt.go
new file mode 100644
index 0000000..e0a8321
--- /dev/null
+++ b/metropolis/node/core/localstorage/crypt/crypt.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package crypt
+
+import (
+ "encoding/binary"
+ "encoding/hex"
+ "fmt"
+ "os"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/devicemapper"
+)
+
+// readDataSectors opens the file or device at path and returns the
+// little-endian 64-bit value stored at byte offset 16, which the caller uses
+// as the number of usable data sectors. The offset is based on the superblock
+// structure defined in
+// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/drivers/md/dm-integrity.c#n59
+//
+// Any error from opening, seeking or reading is returned as-is, together with
+// a zero count.
+func readDataSectors(path string) (uint64, error) {
+	const providedDataSectorsOffset = 16
+
+	dev, err := os.Open(path)
+	if err != nil {
+		return 0, err
+	}
+	defer dev.Close()
+
+	_, err = dev.Seek(providedDataSectorsOffset, 0)
+	if err != nil {
+		return 0, err
+	}
+
+	var sectors uint64
+	err = binary.Read(dev, binary.LittleEndian, &sectors)
+	if err != nil {
+		return 0, err
+	}
+	return sectors, nil
+}
+
+// CryptMap maps an encrypted device (node) at baseName to a decrypted device
+// at /dev/$name using the given encryptionKey.
+//
+// Two device-mapper devices are stacked: an integrity device named
+// "$name-integrity" on top of baseName, and a crypt device named "$name" on
+// top of the integrity device. A device node is created under /dev for each.
+// If a later step fails, the devices and nodes set up so far are torn down
+// again on a best-effort basis before the error is returned.
+func CryptMap(name string, baseName string, encryptionKey []byte) error {
+	integritySectors, err := readDataSectors(baseName)
+	if err != nil {
+		return fmt.Errorf("failed to read the number of usable sectors on the integrity device: %w", err)
+	}
+
+	integrityDevName := fmt.Sprintf("/dev/%v-integrity", name)
+	integrityDMName := fmt.Sprintf("%v-integrity", name)
+	// teardownIntegrity removes the integrity device node and mapping,
+	// ignoring any errors.
+	teardownIntegrity := func() {
+		unix.Unlink(integrityDevName)
+		devicemapper.RemoveDevice(integrityDMName)
+	}
+
+	integrityTargets := []devicemapper.Target{{
+		Length:     integritySectors,
+		Type:       "integrity",
+		Parameters: fmt.Sprintf("%v 0 28 J 1 journal_sectors:1024", baseName),
+	}}
+	integrityDev, err := devicemapper.CreateActiveDevice(integrityDMName, integrityTargets)
+	if err != nil {
+		return fmt.Errorf("failed to create Integrity device: %w", err)
+	}
+	err = unix.Mknod(integrityDevName, 0600|unix.S_IFBLK, int(integrityDev))
+	if err != nil {
+		teardownIntegrity()
+		return fmt.Errorf("failed to create integrity device node: %w", err)
+	}
+
+	cryptDevName := fmt.Sprintf("/dev/%v", name)
+	cryptTargets := []devicemapper.Target{{
+		Length:     integritySectors,
+		Type:       "crypt",
+		Parameters: fmt.Sprintf("capi:gcm(aes)-random %v 0 %v 0 1 integrity:28:aead", hex.EncodeToString(encryptionKey), integrityDevName),
+	}}
+	cryptDev, err := devicemapper.CreateActiveDevice(name, cryptTargets)
+	if err != nil {
+		teardownIntegrity()
+		return fmt.Errorf("failed to create crypt device: %w", err)
+	}
+	err = unix.Mknod(cryptDevName, 0600|unix.S_IFBLK, int(cryptDev))
+	if err != nil {
+		// Tear down in reverse order of creation: crypt first, then integrity.
+		unix.Unlink(cryptDevName)
+		devicemapper.RemoveDevice(name)
+
+		teardownIntegrity()
+		return fmt.Errorf("failed to create crypt device node: %w", err)
+	}
+	return nil
+}
+
+// CryptInit initializes a new encrypted block device. This can take a long
+// time since all bytes on the mapped block device need to be zeroed.
+//
+// The steps are:
+//  1. Overwrite the first 4096 bytes of baseName with zeroes to wipe any
+//     existing header.
+//  2. Create and immediately remove a one-sector integrity device on
+//     baseName (the "discovery" device).
+//  3. Map the device with CryptMap, which reads the data sector count from
+//     baseName.
+//  4. Write zeroes to /dev/$name until the device reports ENOSPC.
+//
+// On failure after step 3 the mapped devices are left in place.
+func CryptInit(name, baseName string, encryptionKey []byte) error {
+	integrityPartition, err := os.OpenFile(baseName, os.O_WRONLY, 0)
+	if err != nil {
+		return err
+	}
+	defer integrityPartition.Close()
+	zeroedHeader := make([]byte, 4096)
+	if _, err := integrityPartition.Write(zeroedHeader); err != nil {
+		return fmt.Errorf("failed to wipe header: %w", err)
+	}
+	// Close explicitly before setting up device-mapper targets on the same
+	// device. The deferred Close above then only covers the error path.
+	integrityPartition.Close()
+
+	integrityDMName := fmt.Sprintf("%v-integrity", name)
+	_, err = devicemapper.CreateActiveDevice(integrityDMName, []devicemapper.Target{
+		{
+			Length:     1,
+			Type:       "integrity",
+			Parameters: fmt.Sprintf("%v 0 28 J 1 journal_sectors:1024", baseName),
+		},
+	})
+	if err != nil {
+		return fmt.Errorf("failed to create discovery integrity device: %w", err)
+	}
+	if err := devicemapper.RemoveDevice(integrityDMName); err != nil {
+		return fmt.Errorf("failed to remove discovery integrity device: %w", err)
+	}
+
+	if err := CryptMap(name, baseName, encryptionKey); err != nil {
+		return err
+	}
+
+	blkdev, err := os.OpenFile(fmt.Sprintf("/dev/%v", name), unix.O_DIRECT|os.O_WRONLY, 0000)
+	if err != nil {
+		return fmt.Errorf("failed to open new encrypted device for zeroing: %w", err)
+	}
+	defer blkdev.Close()
+	blockSize, err := unix.IoctlGetUint32(int(blkdev.Fd()), unix.BLKSSZGET)
+	if err != nil {
+		return fmt.Errorf("failed to get block size of new encrypted device: %w", err)
+	}
+	if blockSize == 0 {
+		// A zero-length buffer would make every Write below succeed without
+		// progress, so the zeroing loop would never terminate.
+		return fmt.Errorf("new encrypted device reports a block size of zero")
+	}
+	zeroedBuf := make([]byte, blockSize*100) // Make it faster
+	for {
+		_, err := blkdev.Write(zeroedBuf)
+		// Running out of space is the expected way for this loop to end: it
+		// means the whole device has been written.
+		if e, ok := err.(*os.PathError); ok && e.Err == syscall.ENOSPC {
+			break
+		}
+		if err != nil {
+			return fmt.Errorf("failed to zero-initalize new encrypted device: %w", err)
+		}
+	}
+	return nil
+}
diff --git a/metropolis/node/core/localstorage/declarative/BUILD.bazel b/metropolis/node/core/localstorage/declarative/BUILD.bazel
new file mode 100644
index 0000000..5b51aa2
--- /dev/null
+++ b/metropolis/node/core/localstorage/declarative/BUILD.bazel
@@ -0,0 +1,13 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "declarative.go",
+ "placement.go",
+ "placement_local.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative",
+ visibility = ["//metropolis/node:__subpackages__"],
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/metropolis/node/core/localstorage/declarative/declarative.go b/metropolis/node/core/localstorage/declarative/declarative.go
new file mode 100644
index 0000000..ce82c42
--- /dev/null
+++ b/metropolis/node/core/localstorage/declarative/declarative.go
@@ -0,0 +1,199 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package declarative
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+// Directory represents the intent of existence of a directory in a hierarchical filesystem (simplified to a tree).
+// This structure can be embedded and still be interpreted as a Directory for purposes of use within this library. Any
+// inner fields of such an embedding structure that are in turn (embedded) Directories or files will be treated as
+// children in the intent expressed by this Directory. All contained directory fields must have a `dir:"name"` struct
+// tag that names them, and all contained file fields must have a `file:"name"` struct tag.
+//
+// Creation and management of the directory at runtime is left to the implementing code. However, the DirectoryPlacement
+// implementation (set as the directory is placed onto a backing store) facilitates this management (by exposing methods
+// that mutate the backing store).
+type Directory struct {
+ // DirectoryPlacement is embedded, so its methods are promoted onto Directory and any structure embedding Directory.
+ DirectoryPlacement
+}
+
+// File represents the intent of existence of a file. Files are usually child structures in types that embed Directory.
+// File can also be embedded in another structure, and this embedding type will still be interpreted as a File for
+// purposes of use within this library.
+//
+// As with Directory, the runtime management of a File in a backing store is left to the implementing code, and the
+// embedded FilePlacement interface facilitates access to the backing store.
+type File struct {
+ // FilePlacement is embedded, so its methods are promoted onto File and any structure embedding File.
+ FilePlacement
+}
+
+// unpackDirectory takes a pointer to Directory or a pointer to a structure embedding Directory, and returns a
+// reflection Value that refers to the passed structure itself (not its pointer) and a plain Go pointer to the
+// (embedded) Directory.
+func unpackDirectory(d interface{}) (*reflect.Value, *Directory, error) {
+ td := reflect.TypeOf(d)
+ if td.Kind() != reflect.Ptr {
+ return nil, nil, fmt.Errorf("wanted a pointer, got %v", td.Kind())
+ }
+
+ var dir *Directory
+ id := reflect.ValueOf(d).Elem()
+ tid := id.Type()
+ switch {
+ case tid.Name() == reflect.TypeOf(Directory{}).Name():
+ dir = id.Addr().Interface().(*Directory)
+ case id.FieldByName("Directory").IsValid():
+ dir = id.FieldByName("Directory").Addr().Interface().(*Directory)
+ default:
+ return nil, nil, fmt.Errorf("not a Directory or embedding Directory (%v)", id.Type().String())
+ }
+ return &id, dir, nil
+}
+
+// unpackFile takes a pointer to a File or a pointer to a structure embedding File, and returns a reflection Value that
+// refers to the passed structure itself (not its pointer) and a plain Go pointer to the (embedded) File.
+func unpackFile(f interface{}) (*reflect.Value, *File, error) {
+ tf := reflect.TypeOf(f)
+ if tf.Kind() != reflect.Ptr {
+ return nil, nil, fmt.Errorf("wanted a pointer, got %v", tf.Kind())
+ }
+
+ var fil *File
+ id := reflect.ValueOf(f).Elem()
+ tid := id.Type()
+ switch {
+ case tid.Name() == reflect.TypeOf(File{}).Name():
+ fil = id.Addr().Interface().(*File)
+ case id.FieldByName("File").IsValid():
+ fil = id.FieldByName("File").Addr().Interface().(*File)
+ default:
+ return nil, nil, fmt.Errorf("not a File or embedding File (%v)", tid.String())
+ }
+ return &id, fil, nil
+
+}
+
+// subdirs takes a pointer to a Directory or a pointer to a structure embedding Directory, and returns one
+// namedDirectory per struct field that carries a non-empty `dir` struct tag. Each entry holds the tag value as
+// its name and a pointer to the field itself. Fields without a `dir` tag are skipped.
+func subdirs(d interface{}) ([]namedDirectory, error) {
+	s, _, err := unpackDirectory(d)
+	if err != nil {
+		return nil, fmt.Errorf("argument could not be parsed as *Directory: %w", err)
+	}
+
+	var res []namedDirectory
+	structType := s.Type()
+	for i, n := 0, s.NumField(); i < n; i++ {
+		name := structType.Field(i).Tag.Get("dir")
+		if name != "" {
+			res = append(res, namedDirectory{name: name, directory: s.Field(i).Addr().Interface()})
+		}
+	}
+	return res, nil
+}
+
+// namedDirectory pairs a subdirectory's name, as given by its `dir` struct tag, with a pointer to the struct
+// field that declares it.
+type namedDirectory struct {
+ // name is the value of the field's `dir` struct tag.
+ name string
+ // directory is a pointer to the tagged field, as an untyped interface value.
+ directory interface{}
+}
+
+// files takes a pointer to a Directory or a pointer to a structure embedding Directory, and returns one namedFile
+// per struct field that carries a non-empty `file` struct tag. Each entry holds the tag value as its name and a
+// pointer to the (embedded) File of that field. An error is returned if a tagged field cannot be interpreted as
+// a File.
+func files(d interface{}) ([]namedFile, error) {
+	s, _, err := unpackDirectory(d)
+	if err != nil {
+		return nil, fmt.Errorf("argument could not be parsed as *Directory: %w", err)
+	}
+
+	var res []namedFile
+	structType := s.Type()
+	for i, n := 0, s.NumField(); i < n; i++ {
+		field := structType.Field(i)
+		name := field.Tag.Get("file")
+		if name == "" {
+			continue
+		}
+		_, unpacked, err := unpackFile(s.Field(i).Addr().Interface())
+		if err != nil {
+			return nil, fmt.Errorf("file %q could not be parsed as *File: %w", field.Name, err)
+		}
+		res = append(res, namedFile{name: name, file: unpacked})
+	}
+	return res, nil
+}
+
+// namedFile pairs a file's name, as given by its `file` struct tag, with a pointer to the File it refers to.
+type namedFile struct {
+ // name is the value of the field's `file` struct tag.
+ name string
+ // file points to the File that is, or is embedded in, the tagged field.
+ file *File
+}
+
+// Validate checks that a given pointer to a Directory or pointer to a structure containing Directory does not contain
+// any programmer errors in its definition:
+//   - all subdirectories/files must be named
+//   - all subdirectory/file names within a directory must be unique
+//   - all subdirectory/file names within a directory must not contain the '/' character (as it is a common path
+//     delimiter)
+//
+// Subdirectories are validated recursively; errors from a subdirectory are prefixed with its name. Subdirectories
+// and files share one namespace, so a file may not reuse the name of a sibling subdirectory.
+func Validate(d interface{}) error {
+	seen := make(map[string]bool)
+
+	subs, err := subdirs(d)
+	if err != nil {
+		return fmt.Errorf("could not get subdirectories: %w", err)
+	}
+	for _, sub := range subs {
+		switch {
+		case sub.name == "":
+			return fmt.Errorf("subdirectory with empty name")
+		case strings.Contains(sub.name, "/"):
+			return fmt.Errorf("subdirectory with invalid path: %q", sub.name)
+		case seen[sub.name]:
+			return fmt.Errorf("subdirectory with duplicate name: %q", sub.name)
+		}
+		seen[sub.name] = true
+
+		if err := Validate(sub.directory); err != nil {
+			return fmt.Errorf("%s: %w", sub.name, err)
+		}
+	}
+
+	filelist, err := files(d)
+	if err != nil {
+		return fmt.Errorf("could not get files: %w", err)
+	}
+	for _, file := range filelist {
+		switch {
+		case file.name == "":
+			return fmt.Errorf("file with empty name")
+		case strings.Contains(file.name, "/"):
+			return fmt.Errorf("file with invalid path: %q", file.name)
+		case seen[file.name]:
+			return fmt.Errorf("file with duplicate name: %q", file.name)
+		}
+		seen[file.name] = true
+	}
+	return nil
+}
diff --git a/metropolis/node/core/localstorage/declarative/placement.go b/metropolis/node/core/localstorage/declarative/placement.go
new file mode 100644
index 0000000..c2ff53d
--- /dev/null
+++ b/metropolis/node/core/localstorage/declarative/placement.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package declarative
+
+import (
+ "fmt"
+ "os"
+)
+
+// A declarative Directory/File tree is an abstract definition until it's 'placed' on a backing file system.
+// By convention, all abstract definitions of hierarchies are stored as copiable structs, and only turned into pointers
+// when placed (ie., implementations like PlaceFS take a *Directory, but Root as a declarative definition is defined as
+// non-pointer).
+
+// Placement is an interface available on Placed files and Directories. All *Placement interfaces on files/Directories
+// are only available on placed trees - eg., after a PlaceFS call. This is unfortunately not typesafe, callers need to
+// either be sure about placement, or check the interface for nil.
+type Placement interface {
+ FullPath() string // path assigned to this element by the placement backend (format is backend-specific)
+ RootRef() interface{} // backend-specific reference to the root of the placement (a *FSRoot for PlaceFS)
+}
+
+// FilePlacement is an interface available on Placed files. It is implemented by different placement backends, and
+// set on all files during placement by a given backend.
+type FilePlacement interface {
+ Placement
+ Exists() (bool, error) // reports whether the file is present on the backing store
+ Read() ([]byte, error) // returns the content of the file
+ Write([]byte, os.FileMode) error // stores the given content, using the given file mode
+}
+
+// DirectoryPlacement is an interface available on Placed Directories. It is implemented by different placement
+// backends, and set on all directories during placement by a given backend.
+type DirectoryPlacement interface {
+	Placement
+	// MkdirAll creates this directory and all its parents on backing stores that have a physical directory
+	// structure. perm is the permission mode used for the directories that get created.
+	MkdirAll(perm os.FileMode) error
+}
+
+// DirectoryPlacer is a placement backend-defined function that, given the full path of the parent directory and the
+// name of a directory within it (empty for the root of the tree), returns a DirectoryPlacement implementation for
+// this directory. The new placement's path (via .FullPath()) will be used for placement of its children.
+type DirectoryPlacer func(parent, this string) DirectoryPlacement
+
+// FilePlacer is analogous to DirectoryPlacer, but for files.
+type FilePlacer func(parent, this string) FilePlacement
+
+// place walks the declarative tree rooted at d (a pointer to a Directory, or to a structure embedding one) and
+// assigns a placement to every node: dpl is invoked for d itself and for each subdirectory, fpl for each file. A
+// directory that already carries a placement is rejected, which makes placing the same tree twice an error.
+func place(d interface{}, parent, this string, dpl DirectoryPlacer, fpl FilePlacer) error {
+	_, node, err := unpackDirectory(d)
+	if err != nil {
+		return err
+	}
+	if node.DirectoryPlacement != nil {
+		return fmt.Errorf("already placed")
+	}
+	node.DirectoryPlacement = dpl(parent, this)
+
+	// Children are placed relative to the path that the backend has just assigned to this directory.
+	subs, err := subdirs(d)
+	if err != nil {
+		return fmt.Errorf("could not list subdirectories: %w", err)
+	}
+	for _, sub := range subs {
+		if err := place(sub.directory, node.FullPath(), sub.name, dpl, fpl); err != nil {
+			return fmt.Errorf("%v: %w", sub.name, err)
+		}
+	}
+
+	leaves, err := files(d)
+	if err != nil {
+		return fmt.Errorf("could not list files: %w", err)
+	}
+	for _, leaf := range leaves {
+		leaf.file.FilePlacement = fpl(node.FullPath(), leaf.name)
+	}
+	return nil
+}
diff --git a/metropolis/node/core/localstorage/declarative/placement_local.go b/metropolis/node/core/localstorage/declarative/placement_local.go
new file mode 100644
index 0000000..82b6a71
--- /dev/null
+++ b/metropolis/node/core/localstorage/declarative/placement_local.go
@@ -0,0 +1,115 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package declarative
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "sync"
+
+ "golang.org/x/sys/unix"
+)
+
+// FSRoot is a root of a storage backend that resides on the local filesystem. One is created per PlaceFS call.
+type FSRoot struct {
+ // The local path at which the declarative directory structure is located (eg. "/"), as passed to PlaceFS.
+ root string
+}
+
+type FSPlacement struct { // FSPlacement is the file/directory placement that PlaceFS creates for each element.
+ root *FSRoot // root this placement belongs to, returned by RootRef
+ path string // path of the placed element, returned by FullPath
+ writeLock sync.Mutex // held for the duration of Write
+}
+
+func (f *FSPlacement) FullPath() string {
+ return f.path // the path computed by PlaceFS when this element was placed
+}
+
+func (f *FSPlacement) RootRef() interface{} {
+ return f.root // a *FSRoot, shared by all placements created by one PlaceFS call
+}
+
+// Exists reports whether something is present at this placement's path, as determined by os.Stat. A missing path
+// yields (false, nil); any other stat failure is returned as an error.
+func (f *FSPlacement) Exists() (bool, error) {
+	_, err := os.Stat(f.FullPath())
+	if err != nil && !os.IsNotExist(err) {
+		return false, err
+	}
+	// err is now either nil (present) or a not-exist error (absent).
+	return err == nil, nil
+}
+
+func (f *FSPlacement) Read() ([]byte, error) {
+ return ioutil.ReadFile(f.FullPath()) // whole-file read; errors are returned unmodified
+}
+
+// Write stores d at this placement's path with the given mode. The data is first written to a sibling temporary
+// file which is then renamed over the target. Concurrent Write calls on the same FSPlacement are serialized.
+func (f *FSPlacement) Write(d []byte, mode os.FileMode) error {
+	f.writeLock.Lock()
+	defer f.writeLock.Unlock()
+	// TODO(q3k): ensure that these do not collide with an existing sibling file, or generate this suffix randomly.
+	target := f.FullPath()
+	scratch := target + ".__smalltown_tmp"
+	// Best-effort cleanup; after a successful rename the temporary file no longer exists.
+	defer os.Remove(scratch)
+	if err := ioutil.WriteFile(scratch, d, mode); err != nil {
+		return fmt.Errorf("temporary file write failed: %w", err)
+	}
+	if err := unix.Rename(scratch, target); err != nil {
+		return fmt.Errorf("renaming target file failed: %w", err)
+	}
+	return nil
+}
+
+func (f *FSPlacement) MkdirAll(perm os.FileMode) error {
+ return os.MkdirAll(f.FullPath(), perm) // also creates any missing parents, using perm for all of them
+}
+
+// PlaceFS takes a pointer to a Directory or a pointer to a structure embedding Directory and places it at a given
+// filesystem root. From this point on the given structure pointer has valid Placement interfaces.
+func PlaceFS(dd interface{}, root string) error {
+ r := &FSRoot{root}
+ pathFor := func(parent, this string) string {
+ var np string
+ switch {
+ case parent == "" && this == "":
+ np = "/"
+ case parent == "/":
+ np = "/" + this
+ default:
+ np = fmt.Sprintf("%s/%s", parent, this)
+ }
+ return np
+ }
+ dp := func(parent, this string) DirectoryPlacement {
+ np := pathFor(parent, this)
+ return &FSPlacement{path: np, root: r}
+ }
+ fp := func(parent, this string) FilePlacement {
+ np := pathFor(parent, this)
+ return &FSPlacement{path: np, root: r}
+ }
+ err := place(dd, r.root, "", dp, fp)
+ if err != nil {
+ return fmt.Errorf("could not place: %w", err)
+ }
+ return nil
+}
diff --git a/metropolis/node/core/localstorage/directory_data.go b/metropolis/node/core/localstorage/directory_data.go
new file mode 100644
index 0000000..e90dc48
--- /dev/null
+++ b/metropolis/node/core/localstorage/directory_data.go
@@ -0,0 +1,150 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package localstorage
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/crypt"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/tpm"
+)
+
+var keySize uint16 = 256 / 8 // length in bytes (256 bits) of the local, global and derived unlock keys
+
+// MountExisting mounts the Smalltown data partition with the given global unlock key. It automatically
+// unseals the local unlock key from the TPM. Unlock keys shorter than keySize bytes are rejected.
+func (d *DataDirectory) MountExisting(unlock *ESPLocalUnlockFile, globalUnlockKey []byte) error {
+	d.flagLock.Lock()
+	defer d.flagLock.Unlock()
+
+	if !d.canMount {
+		return fmt.Errorf("cannot mount yet (root not ready?)")
+	}
+	if d.mounted {
+		return fmt.Errorf("already mounted")
+	}
+	// NOTE(review): set before the mount is attempted, so a failed call cannot be retried - confirm intended.
+	d.mounted = true
+
+	localUnlockBlob, err := unlock.Read()
+	if err != nil {
+		return fmt.Errorf("reading local unlock file from ESP: %w", err)
+	}
+	localUnlockKey, err := tpm.Unseal(localUnlockBlob)
+	if err != nil {
+		return fmt.Errorf("unsealing local unlock key: %w", err)
+	}
+	if len(localUnlockKey) < int(keySize) || len(globalUnlockKey) < int(keySize) {
+		return fmt.Errorf("unlock keys must be at least %d bytes long", keySize)
+	}
+
+	key := make([]byte, keySize)
+	for i := range key {
+		key[i] = localUnlockKey[i] ^ globalUnlockKey[i]
+	}
+	if err := crypt.CryptMap("data", crypt.SmalltownDataCryptPath, key); err != nil {
+		return fmt.Errorf("mapping encrypted block device: %w", err)
+	}
+	return d.mount()
+}
+
+// MountNew initializes the Smalltown data partition and returns the global unlock key. It seals
+// the local portion into the TPM and stores the blob on the ESP. This is a potentially slow
+// operation since it touches the whole partition. The new partition is formatted as XFS, mounted,
+// and populated with the directories of the data hierarchy.
+func (d *DataDirectory) MountNew(unlock *ESPLocalUnlockFile) ([]byte, error) {
+	d.flagLock.Lock()
+	defer d.flagLock.Unlock()
+	if !d.canMount {
+		return nil, fmt.Errorf("cannot mount yet (root not ready?)")
+	}
+	if d.mounted {
+		return nil, fmt.Errorf("already mounted")
+	}
+	d.mounted = true
+
+	localUnlockKey, err := tpm.GenerateSafeKey(keySize)
+	if err != nil {
+		return nil, fmt.Errorf("generating local unlock key: %w", err)
+	}
+	globalUnlockKey, err := tpm.GenerateSafeKey(keySize)
+	if err != nil {
+		return nil, fmt.Errorf("generating global unlock key: %w", err)
+	}
+
+	localUnlockBlob, err := tpm.Seal(localUnlockKey, tpm.SecureBootPCRs)
+	if err != nil {
+		return nil, fmt.Errorf("sealing local unlock key: %w", err)
+	}
+
+	// The actual key is generated by XORing together the localUnlockKey and the globalUnlockKey
+	// This provides us with a mathematical guarantee that the resulting key cannot be recovered
+	// without knowledge of both parts.
+	key := make([]byte, keySize)
+	for i := uint16(0); i < keySize; i++ {
+		key[i] = localUnlockKey[i] ^ globalUnlockKey[i]
+	}
+
+	if err := crypt.CryptInit("data", crypt.SmalltownDataCryptPath, key); err != nil {
+		return nil, fmt.Errorf("initializing encrypted block device: %w", err)
+	}
+	mkfsCmd := exec.Command("/bin/mkfs.xfs", "-qf", "/dev/data")
+	if _, err := mkfsCmd.Output(); err != nil {
+		return nil, fmt.Errorf("formatting encrypted block device: %w", err)
+	}
+
+	if err := d.mount(); err != nil {
+		return nil, fmt.Errorf("mounting: %w", err)
+	}
+
+	// TODO(q3k): do this automatically?
+	for _, dir := range []declarative.DirectoryPlacement{
+		d.Etcd, d.Etcd.Data, d.Etcd.PeerPKI,
+		d.Containerd,
+		d.Kubernetes,
+		d.Kubernetes.Kubelet, d.Kubernetes.Kubelet.PKI, d.Kubernetes.Kubelet.Plugins, d.Kubernetes.Kubelet.PluginsRegistry,
+		d.Kubernetes.ClusterNetworking,
+		d.Node,
+		d.Volumes,
+	} {
+		// The loop variable is not called d, so that it does not shadow the receiver.
+		if err := dir.MkdirAll(0700); err != nil {
+			return nil, fmt.Errorf("creating directory failed: %w", err)
+		}
+	}
+
+	if err := unlock.Write(localUnlockBlob, 0600); err != nil {
+		return nil, fmt.Errorf("writing unlock blob: %w", err)
+	}
+
+	return globalUnlockKey, nil
+}
+
+// mount creates the mount point at d.FullPath() and mounts /dev/data there as xfs (noexec, nodev, "pquota" option).
+func (d *DataDirectory) mount() error {
+ if err := os.Mkdir(d.FullPath(), 0755); err != nil {
+ return fmt.Errorf("making data directory: %w", err)
+ }
+ if err := unix.Mount("/dev/data", d.FullPath(), "xfs", unix.MS_NOEXEC|unix.MS_NODEV, "pquota"); err != nil {
+ return fmt.Errorf("mounting data directory: %w", err)
+ }
+ return nil
+}
diff --git a/metropolis/node/core/localstorage/directory_pki.go b/metropolis/node/core/localstorage/directory_pki.go
new file mode 100644
index 0000000..6bdebff
--- /dev/null
+++ b/metropolis/node/core/localstorage/directory_pki.go
@@ -0,0 +1,167 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package localstorage
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/hex"
+ "fmt"
+ "math/big"
+ "time"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative"
+)
+
+var (
+ // From RFC 5280 Section 4.1.2.5: 99991231235959Z, the notAfter value for certificates without a well-defined expiry.
+ unknownNotAfter = time.Unix(253402300799, 0)
+)
+
+type CertificateTemplateNamer func(pubkey []byte) x509.Certificate // builds a certificate template for a public key
+
+func CertificateForNode(pubkey []byte) x509.Certificate {
+ name := "smalltown-" + hex.EncodeToString([]byte(pubkey[:16]))
+
+ // This has no SANs because it authenticates by public key, not by name
+ return x509.Certificate{
+ Subject: pkix.Name{
+ // We identify nodes by their ID public keys (not hashed since a strong hash is longer and serves no benefit)
+ CommonName: name,
+ },
+ IsCA: false,
+ BasicConstraintsValid: true,
+ NotBefore: time.Now(),
+ NotAfter: unknownNotAfter,
+ // Certificate is used both as server & client
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
+ }
+}
+
+// EnsureSelfSigned returns the TLS certificate stored in this directory, creating it first if needed. If either the
+// certificate or the key file is missing, a new Ed25519 key pair is generated, a certificate based on the template
+// returned by namer is self-signed with it, and both are written out (key as PKCS#8, certificate as DER).
+func (p *PKIDirectory) EnsureSelfSigned(namer CertificateTemplateNamer) (*tls.Certificate, error) {
+	create := false
+	for _, f := range []*declarative.File{&p.Certificate, &p.Key} {
+		exists, err := f.Exists()
+		if err != nil {
+			return nil, fmt.Errorf("could not check existence of file %q: %w", f.FullPath(), err)
+		}
+		if !exists {
+			create = true
+			break
+		}
+	}
+
+	if !create {
+		certRaw, err := p.Certificate.Read()
+		if err != nil {
+			return nil, fmt.Errorf("could not read certificate: %w", err)
+		}
+		privKeyRaw, err := p.Key.Read()
+		if err != nil {
+			return nil, fmt.Errorf("could not read key: %w", err)
+		}
+		cert, err := x509.ParseCertificate(certRaw)
+		if err != nil {
+			return nil, fmt.Errorf("could not parse certificate: %w", err)
+		}
+		privKey, err := x509.ParsePKCS8PrivateKey(privKeyRaw)
+		if err != nil {
+			return nil, fmt.Errorf("could not parse key: %w", err)
+		}
+		return &tls.Certificate{
+			Certificate: [][]byte{certRaw},
+			PrivateKey:  privKey,
+			Leaf:        cert,
+		}, nil
+	}
+
+	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate key: %w", err)
+	}
+	privKeyRaw, err := x509.MarshalPKCS8PrivateKey(privKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal key: %w", err)
+	}
+	if err := p.Key.Write(privKeyRaw, 0600); err != nil {
+		return nil, fmt.Errorf("failed to write new private key: %w", err)
+	}
+
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 127)
+	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate serial number: %w", err)
+	}
+
+	template := namer(pubKey)
+	template.SerialNumber = serialNumber
+	certRaw, err := x509.CreateCertificate(rand.Reader, &template, &template, pubKey, privKey)
+	if err != nil {
+		return nil, fmt.Errorf("could not sign certificate: %w", err)
+	}
+
+	cert, err := x509.ParseCertificate(certRaw)
+	if err != nil {
+		return nil, fmt.Errorf("could not parse newly created certificate: %w", err)
+	}
+	if err := p.Certificate.Write(certRaw, 0600); err != nil {
+		return nil, fmt.Errorf("failed to write new certificate: %w", err)
+	}
+
+	return &tls.Certificate{
+		Certificate: [][]byte{certRaw},
+		// The key itself, not its PKCS#8 encoding: crypto/tls needs a crypto.Signer here, as on the load path.
+		PrivateKey: privKey,
+		Leaf:       cert,
+	}, nil
+}
+
+// AllExist returns true if all PKI files (cert, key, CA cert) are present on the backing
+// store. The first error encountered while checking is returned, wrapped.
+func (p *PKIDirectory) AllExist() (bool, error) {
+	for _, f := range []*declarative.File{&p.CACertificate, &p.Certificate, &p.Key} {
+		exists, err := f.Exists()
+		if err != nil {
+			return false, fmt.Errorf("failed to check %q: %w", f.FullPath(), err)
+		}
+		if !exists {
+			return false, nil
+		}
+	}
+	return true, nil
+}
+
+// AllAbsent returns true if all PKI files (cert, key, CA cert) are missing from the backing
+// store. The first error encountered while checking is returned, wrapped.
+func (p *PKIDirectory) AllAbsent() (bool, error) {
+	for _, f := range []*declarative.File{&p.CACertificate, &p.Certificate, &p.Key} {
+		exists, err := f.Exists()
+		if err != nil {
+			return false, fmt.Errorf("failed to check %q: %w", f.FullPath(), err)
+		}
+		if exists {
+			return false, nil
+		}
+	}
+	return true, nil
+}
diff --git a/metropolis/node/core/localstorage/directory_root.go b/metropolis/node/core/localstorage/directory_root.go
new file mode 100644
index 0000000..883d1e2
--- /dev/null
+++ b/metropolis/node/core/localstorage/directory_root.go
@@ -0,0 +1,83 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package localstorage
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/crypt"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative"
+)
+
+// Start prepares the root storage: it creates block devices via crypt.MakeBlockDevices, mounts the ESP and a tmpfs
+// on the tmp directory, marks the data directory as mountable and creates the etc and ephemeral directories. It
+// fails if the data directory has already been marked as mountable by a previous call.
+func (r *Root) Start(ctx context.Context) error {
+	r.Data.flagLock.Lock()
+	defer r.Data.flagLock.Unlock()
+	if r.Data.canMount {
+		return fmt.Errorf("cannot re-start root storage")
+	}
+	// TODO(q3k): turn this into an Ensure call
+	if err := crypt.MakeBlockDevices(ctx); err != nil {
+		return fmt.Errorf("MakeBlockDevices: %w", err)
+	}
+
+	if err := os.Mkdir(r.ESP.FullPath(), 0755); err != nil {
+		return fmt.Errorf("making ESP directory: %w", err)
+	}
+	if err := unix.Mount(crypt.ESPDevicePath, r.ESP.FullPath(), "vfat", unix.MS_NOEXEC|unix.MS_NODEV|unix.MS_SYNC, ""); err != nil {
+		return fmt.Errorf("mounting ESP partition: %w", err)
+	}
+
+	r.Data.canMount = true
+
+	if err := os.Mkdir(r.Tmp.FullPath(), 0777); err != nil {
+		return fmt.Errorf("making /tmp directory: %w", err)
+	}
+	if err := unix.Mount("tmpfs", r.Tmp.FullPath(), "tmpfs", unix.MS_NOEXEC|unix.MS_NODEV, ""); err != nil {
+		return fmt.Errorf("mounting /tmp: %w", err)
+	}
+
+	// TODO(q3k): do this automatically?
+	for _, d := range []declarative.DirectoryPlacement{
+		r.Etc,
+		r.Ephemeral,
+		r.Ephemeral.Consensus,
+		r.Ephemeral.Containerd, r.Ephemeral.Containerd.Tmp, r.Ephemeral.Containerd.RunSC, r.Ephemeral.Containerd.IPAM,
+		r.Ephemeral.FlexvolumePlugins,
+	} {
+		if err := d.MkdirAll(0700); err != nil {
+			return fmt.Errorf("creating directory failed: %w", err)
+		}
+	}
+
+	// These directories were created with mode 0700 above and are now opened up to 0755.
+	for _, d := range []declarative.DirectoryPlacement{
+		r.Ephemeral, r.Ephemeral.Containerd, r.Ephemeral.Containerd.Tmp,
+	} {
+		if err := os.Chmod(d.FullPath(), 0755); err != nil {
+			return fmt.Errorf("failed to chmod %q: %w", d.FullPath(), err)
+		}
+	}
+
+	return nil
+}
diff --git a/metropolis/node/core/localstorage/storage.go b/metropolis/node/core/localstorage/storage.go
new file mode 100644
index 0000000..8cc291f
--- /dev/null
+++ b/metropolis/node/core/localstorage/storage.go
@@ -0,0 +1,161 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package localstorage
+
+// Localstorage is a replacement for the old 'storage' internal library. It is currently unused, but will become
+// used as the node code gets rewritten.
+
+// The library is centered around the idea of a declarative filesystem tree defined as mutually recursive Go structs.
+// This structure is then Placed onto an abstract real filesystem (eg. a local POSIX filesystem at /), and a handle
+// to that placed filesystem is then used by the consumers of this library to refer to subsets of the tree (that now
+// correspond to locations on a filesystem).
+//
+// Every member of the storage hierarchy must either be, or inherit from Directory or File. In order to be placed
+// correctly, Directory embedding structures must use `dir:` or `file:` tags for child Directories and files
+// respectively. The content of the tag specifies the path part that this element will be placed at.
+//
+// The format of a full placement path (available via FullPath()) is placement implementation-specific. However, it
+// is always a string.
+
+import (
+ "sync"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative"
+)
+
+type Root struct {
+ declarative.Directory // placement of the root directory itself
+ // UEFI ESP partition, mounted from plaintext storage.
+ ESP ESPDirectory `dir:"esp"`
+ // Persistent Data partition, mounted from encrypted and authenticated storage.
+ Data DataDirectory `dir:"data"`
+ // FHS-standard /etc directory, contains /etc/hosts, /etc/machine-id, and other compatibility files.
+ Etc EtcDirectory `dir:"etc"`
+ // Ephemeral data, used by runtime, stored in tmpfs. Things like sockets, temporary config files, etc.
+ Ephemeral EphemeralDirectory `dir:"ephemeral"`
+ // FHS-standard /tmp directory, used by ioutil.TempFile.
+ Tmp TmpDirectory `dir:"tmp"`
+}
+
+type PKIDirectory struct { // PKIDirectory holds a certificate, its private key and a CA certificate as files.
+ declarative.Directory
+ CACertificate declarative.File `file:"ca.pem"` // CA certificate; checked by AllExist/AllAbsent
+ Certificate declarative.File `file:"cert.pem"` // certificate; EnsureSelfSigned reads/writes it as raw DER
+ Key declarative.File `file:"cert-key.pem"` // private key; EnsureSelfSigned reads/writes it as PKCS#8 DER
+}
+
+// ESPDirectory is the EFI System Partition.
+type ESPDirectory struct {
+ declarative.Directory
+ LocalUnlock ESPLocalUnlockFile `file:"local_unlock.bin"` // TPM-sealed local unlock key blob
+ // Enrolment is the configuration/provisioning file for this node, containing information required to begin
+ // joining the cluster.
+ Enrolment declarative.File `file:"enrolment.pb"`
+}
+
+// ESPLocalUnlockFile is the localUnlock file, sealed by the TPM of this node. After unsealing by the TPM it is used
+// in conjunction with the globalUnlock key (retrieved from the existing cluster) to decrypt the local data partition.
+type ESPLocalUnlockFile struct {
+ declarative.File // provides the Read/Write used by DataDirectory.MountExisting and MountNew
+}
+
+// DataDirectory is an xfs partition mounted via cryptsetup/LUKS, with a key derived from {global,local}Unlock keys.
+type DataDirectory struct {
+ declarative.Directory
+
+ // flagLock locks canMount and mounted.
+ flagLock sync.Mutex
+ // canMount is set by Root when it is initialized. It is required to be set for mounting the data directory.
+ canMount bool
+ // mounted is set by DataDirectory when it is mounted. It ensures it's only mounted once.
+ mounted bool
+
+ Containerd declarative.Directory `dir:"containerd"` // created by MountNew
+ Etcd DataEtcdDirectory `dir:"etcd"` // created by MountNew
+ Kubernetes DataKubernetesDirectory `dir:"kubernetes"` // created by MountNew
+ Node PKIDirectory `dir:"node_pki"` // created by MountNew
+ Volumes DataVolumesDirectory `dir:"volumes"` // created by MountNew
+}
+
+type DataEtcdDirectory struct { // DataEtcdDirectory is placed at "etcd" within DataDirectory.
+ declarative.Directory
+ PeerPKI PKIDirectory `dir:"peer_pki"` // created by DataDirectory.MountNew
+ PeerCRL declarative.File `file:"peer_crl"` // presumably a certificate revocation list for peers - TODO confirm
+ Data declarative.Directory `dir:"data"` // created by DataDirectory.MountNew
+}
+
+type DataKubernetesDirectory struct { // DataKubernetesDirectory is placed at "kubernetes" within DataDirectory.
+ declarative.Directory
+ ClusterNetworking DataKubernetesClusterNetworkingDirectory `dir:"clusternet"` // created by DataDirectory.MountNew
+ Kubelet DataKubernetesKubeletDirectory `dir:"kubelet"` // created by DataDirectory.MountNew
+}
+
+type DataKubernetesClusterNetworkingDirectory struct { // placed at "clusternet" within DataKubernetesDirectory
+ declarative.Directory
+ Key declarative.File `file:"private.key"` // placed as "private.key"; presumably a private key - TODO confirm use
+}
+
+type DataKubernetesKubeletDirectory struct { // placed at "kubelet" within DataKubernetesDirectory
+ declarative.Directory
+ Kubeconfig declarative.File `file:"kubeconfig"`
+ PKI PKIDirectory `dir:"pki"` // created by DataDirectory.MountNew
+
+ Plugins struct {
+ declarative.Directory
+ VFS declarative.File `file:"com.smalltown.vfs.sock"` // presumably a socket (named *.sock) - TODO confirm
+ } `dir:"plugins"` // created by DataDirectory.MountNew
+
+ PluginsRegistry struct {
+ declarative.Directory
+ VFSReg declarative.File `file:"com.smalltown.vfs-reg.sock"` // presumably a socket (named *.sock) - TODO confirm
+ } `dir:"plugins_registry"` // created by DataDirectory.MountNew
+}
+
+type DataVolumesDirectory struct { // placed at "volumes" within DataDirectory; declares no children of its own
+ declarative.Directory
+}
+
+type EtcDirectory struct { // EtcDirectory is placed at "etc" within Root and created by Root.Start.
+ declarative.Directory
+ Hosts declarative.File `file:"hosts"` // placed as "hosts"
+ MachineID declarative.File `file:"machine-id"` // placed as "machine-id"
+}
+
+type EphemeralDirectory struct { // placed at "ephemeral" within Root; created and chmodded to 0755 by Root.Start
+ declarative.Directory
+ Consensus EphemeralConsensusDirectory `dir:"consensus"` // created by Root.Start
+ Containerd EphemeralContainerdDirectory `dir:"containerd"` // created and chmodded to 0755 by Root.Start
+ FlexvolumePlugins declarative.Directory `dir:"flexvolume_plugins"` // created by Root.Start
+}
+
+type EphemeralConsensusDirectory struct { // placed at "consensus" within EphemeralDirectory
+ declarative.Directory
+ ClientSocket declarative.File `file:"client.sock"` // presumably a socket (named *.sock) - TODO confirm
+}
+
+type EphemeralContainerdDirectory struct { // placed at "containerd" within EphemeralDirectory
+ declarative.Directory
+ ClientSocket declarative.File `file:"client.sock"` // presumably a socket (named *.sock) - TODO confirm
+ RunSCLogsFIFO declarative.File `file:"runsc-logs.fifo"` // presumably a FIFO (named *.fifo) - TODO confirm
+ Tmp declarative.Directory `dir:"tmp"` // created and chmodded to 0755 by Root.Start
+ RunSC declarative.Directory `dir:"runsc"` // created by Root.Start
+ IPAM declarative.Directory `dir:"ipam"` // created by Root.Start
+}
+
+type TmpDirectory struct { // placed at "tmp" within Root; Root.Start mounts a tmpfs on it
+ declarative.Directory
+}
diff --git a/metropolis/node/core/localstorage/storage_test.go b/metropolis/node/core/localstorage/storage_test.go
new file mode 100644
index 0000000..8029d02
--- /dev/null
+++ b/metropolis/node/core/localstorage/storage_test.go
@@ -0,0 +1,58 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package localstorage
+
+import (
+ "testing"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative"
+)
+
+// TestValidateAll ensures that the declared Root tree passes declarative.Validate.
+func TestValidateAll(t *testing.T) {
+	if err := declarative.Validate(&Root{}); err != nil {
+		t.Errorf("Validation failed: %v", err)
+	}
+}
+
+// TestPlaceFS places a Root on a local filesystem root and checks some of the resulting paths.
+func TestPlaceFS(t *testing.T) {
+	rr := Root{}
+	// Abort on failure: an unplaced tree has no Placement to query in the checks below.
+	if err := declarative.PlaceFS(&rr, ""); err != nil {
+		t.Fatalf("Placement failed: %v", err)
+	}
+
+	// Re-placing should fail.
+	if err := declarative.PlaceFS(&rr, "/foo"); err == nil {
+		t.Errorf("Re-placement didn't fail")
+	}
+
+	// Check some absolute paths.
+	for i, te := range []struct {
+		pl   declarative.Placement
+		want string
+	}{
+		{rr.ESP, "/esp"},
+		{rr.Data.Etcd, "/data/etcd"},
+		{rr.Data.Node.Certificate, "/data/node_pki/cert.pem"},
+	} {
+		if got, want := te.pl.FullPath(), te.want; got != want {
+			t.Errorf("test %d: wanted path %q, got %q", i, want, got)
+		}
+	}
+}
diff --git a/metropolis/node/core/logtree/BUILD.bazel b/metropolis/node/core/logtree/BUILD.bazel
new file mode 100644
index 0000000..120bf9f
--- /dev/null
+++ b/metropolis/node/core/logtree/BUILD.bazel
@@ -0,0 +1,32 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "doc.go",
+ "journal.go",
+ "journal_entry.go",
+ "journal_subscriber.go",
+ "leveled.go",
+ "leveled_payload.go",
+ "logtree.go",
+ "logtree_access.go",
+ "logtree_entry.go",
+ "logtree_publisher.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//metropolis/node/common/logbuffer:go_default_library",
+ "//metropolis/proto/api:go_default_library",
+ ],
+)
+
+go_test(
+ name = "go_default_test",
+ srcs = [
+ "journal_test.go",
+ "logtree_test.go",
+ ],
+ embed = [":go_default_library"],
+)
diff --git a/metropolis/node/core/logtree/doc.go b/metropolis/node/core/logtree/doc.go
new file mode 100644
index 0000000..ab3c537
--- /dev/null
+++ b/metropolis/node/core/logtree/doc.go
@@ -0,0 +1,116 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+Package logtree implements a tree-shaped logger for debug events. It provides log publishers (ie. Go code) with a
+glog-like API and io.Writer API, with loggers placed in a hierarchical structure defined by a dot-delimited path
+(called a DN, short for Distinguished Name).
+
+ tree.MustLeveledFor("foo.bar.baz").Warningf("Houston, we have a problem: %v", err)
+ fmt.Fprintf(tree.MustRawFor("foo.bar.baz"), "some\nunstructured\ndata\n")
+
+Logs in this context are unstructured, operational and developer-centric human readable text messages presented as lines
+of text to consumers, with some attached metadata. Logtree does not deal with 'structured' logs as some parts of the
+industry do, and instead defers any machine-readable logs to either be handled by metrics systems like Prometheus or
+event sourcing systems like Kafka.
+
+Tree Structure
+
+As an example, consider an application that produces logs with the following DNs:
+
+ listener.http
+ listener.grpc
+ svc
+ svc.cache
+ svc.cache.gc
+
+This would correspond to a tree as follows:
+
+ .------.
+ | "" |
+ | (root) |
+ '------'
+ .----------------' '------.
+ .--------------. .---------------.
+ | svc | | listener |
+ '--------------' '---------------'
+ | .----' '----.
+ .--------------. .---------------. .---------------.
+ | svc.cache | | listener.http | | listener.grpc |
+ '--------------' '---------------' '---------------'
+ |
+ .--------------.
+ | svc.cache.gc |
+ '--------------'
+
+In this setup, every DN acts as a separate logging target, each with its own retention policy and quota. Logging to a DN
+under foo.bar does NOT automatically log to foo - all tree mechanisms are applied on log access by consumers. Loggers
+are automatically created on first use, and importantly, can be created at any time, and will automatically be created
+if a sub-DN is created that requires a parent DN to exist first. Note, for instance, that a `listener` logging node was
+created even though the example application only logged to `listener.http` and `listener.grpc`.
+
+An implicit root node is always present in the tree, accessed by DN "" (an empty string). All other logger nodes are
+children (or transitive children) of the root node.
+
+Log consumers (application code that reads the log and passes them on to operators, or ships them off for aggregation in
+other systems) can select subtrees of logs for readout. In the example tree, a consumer could select to either read all
+logs of the entire tree, just a single DN (like svc), or a subtree (like everything under listener, ie. messages emitted
+to listener.http and listener.grpc).
+
+Leveled Log Producer API
+
+As part of the glog-like logging API available to producers, the following metadata is attached to emitted logs in
+addition to the DN of the logger to which the log entry was emitted:
+
+ - timestamp at which the entry was emitted
+ - a severity level (one of FATAL, ERROR, WARN or INFO)
+ - a source of the message (file name and line number)
+
+In addition, the logger mechanism supports a variable verbosity level (so-called 'V-logging') that can be set at every
+node of the tree. For more information about the producer-facing logging API, see the documentation of the LeveledLogger
+interface, which is the main interface exposed to log producers.
+
+If the submitted message contains newlines, it will be split accordingly into a single log entry that contains multiple
+string lines. This allows for log producers to submit long, multi-line messages that are guaranteed to be non-interleaved
+with other entries, and allows for access API consumers to maintain semantic linking between multiple lines being emitted
+as a single atomic entry.
+
+Raw Log Producer API
+
+In addition to leveled, glog-like logging, LogTree supports 'raw logging'. This is implemented as an io.Writer that will
+split incoming bytes into newline-delimited lines, and log them into that logtree's DN. This mechanism is primarily
+intended to support storage of unstructured log data from external processes - for example binaries running with redirected
+stdout/stderr.
+
+Log Access API
+
+The Log Access API is mostly exposed via a single function on the LogTree struct: Read. It allows access to log entries
+that have been already buffered inside LogTree and to subscribe to receive future entries over a channel. As outlined
+earlier, any access can specify whether it is just interested in a single logger (addressed by DN), or a subtree of
+loggers.
+
+Due to the current implementation of the logtree, subtree accesses of backlogged data are significantly slower than
+accessing data of just one DN, or the whole tree (as every subtree backlog access performs a scan on all logged data).
+Thus, log consumers should be aware that it is much better to stream and buffer logs specific to some long-standing
+logging request on their own, rather than repeatedly perform reads of a subtree backlog.
+
+The data returned from the log access API is a LogEntry, which itself can contain either a raw logging entry, or a leveled
+logging entry. Helper functions are available on LogEntry that allow canonical string representations to be returned, for
+easy use in consuming tools/interfaces. Alternatively, the consumer can itself access the internal raw/leveled entries and
+print them according to their own preferred format.
+
+*/
+package logtree
diff --git a/metropolis/node/core/logtree/journal.go b/metropolis/node/core/logtree/journal.go
new file mode 100644
index 0000000..78c55a1
--- /dev/null
+++ b/metropolis/node/core/logtree/journal.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "errors"
+ "strings"
+ "sync"
+)
+
+// DN is a Distinguished Name: a dot-delimited path that addresses a logger within a LogTree. For example, "foo.bar"
+// designates the 'bar' logger node under the 'foo' logger node, itself directly under the root node of the tree. The
+// empty string designates the root node.
+type DN string
+
+// ErrInvalidDN is the error returned by DN.Path for a DN that is not well-formed, ie. one that has an empty part
+// before, between or after its dots.
+var ErrInvalidDN = errors.New("invalid DN")
+
+// Path returns the parts of a DN, ie. all the elements of the dot-delimited DN path. For the root node (the empty DN),
+// a nil list and no error are returned. ErrInvalidDN is returned if the DN contains an empty part (eg. `foo..bar`,
+// `.foo` or `foo.`).
+func (d DN) Path() ([]string, error) {
+	if d == "" {
+		return nil, nil
+	}
+	elems := strings.Split(string(d), ".")
+	for i := range elems {
+		if len(elems[i]) == 0 {
+			return nil, ErrInvalidDN
+		}
+	}
+	return elems, nil
+}
+
+// journal is the main log recording structure of logtree. It manages linked lists containing the actual log entries,
+// and implements scans across them. It does not understand the hierarchical nature of logtree, and instead sees all
+// entries as part of a global linked list and a local linked list for a given DN.
+//
+// The global linked list is represented by the head/tail pointers in journal and nextGlobal/prevGlobal pointers in
+// entries. The local linked lists are represented by heads[DN]/tails[DN] pointers in journal and nextLocal/prevLocal
+// pointers in entries. New entries are linked in at the head side, so head is the newest entry and tail the oldest:
+//
+//        .------------.        .------------.        .------------.
+//        | dn: A.B    |        | dn: Z      |        | dn: A.B    |
+//        | time: 1    |        | time: 2    |        | time: 3    |
+//        |------------|        |------------|        |------------|
+//        | nextGlobal :------->| nextGlobal :------->| nextGlobal :--> nil
+// nil <--: prevGlobal |<-------: prevGlobal |<-------: prevGlobal |
+//        |------------|        |------------|  n     |------------|
+//        | nextLocal  :---. n  | nextLocal  :->i .-->| nextLocal  :--> nil
+// nil <--: prevLocal  |<--: i<-: prevLocal  |  l :---: prevLocal  |
+//        '------------'   | l  '------------'    |   '------------'
+//              ^          '----------------------'         ^
+//              |                     ^                     |
+//              |                     |                     |
+//          ( tail )            ( tails[Z] )            ( head )
+//       ( tails[A.B] )         ( heads[Z] )         ( heads[A.B] )
+//
+type journal struct {
+ // mu locks the rest of the structure. It must be taken during any operation on the journal.
+ mu sync.RWMutex
+
+ // tail is the side of the global linked list that contains the oldest log entry still linked into the journal. It
+ // can be nil when no log entry has yet been pushed. The global linked list contains all log entries pushed to the
+ // journal that have not been unlinked since. Scans (scanEntries) start here and follow nextGlobal pointers.
+ tail *entry
+ // head is the side of the global linked list that contains the newest log entry, ie. the one that has been pushed
+ // the most recently by append. It can be nil when no log entry has yet been pushed.
+ head *entry
+
+ // tails are the tail sides of a local linked list for a given DN, ie. the sides that contain the oldest entry. They
+ // are nil if there are no log entries for that DN. Quota enforcement in append unlinks entries from this side.
+ tails map[DN]*entry
+ // heads are the head sides of a local linked list for a given DN, ie. the sides that contain the newest entry. They
+ // are nil if there are no log entries for that DN.
+ heads map[DN]*entry
+
+ // quota is a map from DN to quota structure, representing the quota policy of a particular DN-designated logger.
+ quota map[DN]*quota
+
+ // subscribers are observers of logs. New log entries get emitted to channels present in the subscriber structure,
+ // after filtering them through subscriber-provided filters (eg. to limit events to subtrees that interest that
+ // particular subscriber).
+ subscribers []*subscriber
+}
+
+// newJournal returns an empty journal with all of its lookup maps allocated and no subscribers. Journals share no
+// state with each other, and as such, all LogTrees are independent from each other too.
+func newJournal() *journal {
+	j := new(journal)
+	j.tails = make(map[DN]*entry)
+	j.heads = make(map[DN]*entry)
+	j.quota = make(map[DN]*quota)
+	// head, tail and subscribers are deliberately left at their zero values: there is nothing to point at yet.
+	return j
+}
+
+// filter is a predicate over journal entries; it returns true if a log subscriber or reader wants the given entry.
+type filter func(*entry) bool
+
+// filterAll returns a filter that accepts every log entry unconditionally.
+func filterAll() filter {
+ return func(*entry) bool { return true }
+}
+
+// filterExact returns a filter that accepts only log entries whose origin is exactly the given DN (sub-DNs do not
+// match). It should not be combined with journal.scanEntries - journal.getEntries does the same, but much faster.
+func filterExact(dn DN) filter {
+ return func(e *entry) bool {
+ return e.origin == dn
+ }
+}
+
+// filterSubtree returns a filter that accepts all log entries at a given DN and sub-DNs. For example, filterSubtree at
+// "foo.bar" would allow entries at "foo.bar", "foo.bar.baz", but not "foo" or "foo.barr".
+//
+// An empty root designates the root of the tree, for which the returned filter accepts everything (see filterAll).
+//
+// The match is performed on the string form of the DNs: an entry is within the subtree if its origin is root itself,
+// or if it starts with root followed by a dot. This is equivalent to comparing the dot-delimited parts of both DNs
+// one by one, but does not need to split (and thus allocate) anything for every entry that gets checked.
+func filterSubtree(root DN) filter {
+	if root == "" {
+		return filterAll()
+	}
+
+	// Computed once per filter, not once per entry.
+	exact := string(root)
+	below := exact + "."
+
+	return func(e *entry) bool {
+		origin := string(e.origin)
+		// The dot in 'below' is what keeps "foo.barr" out of the "foo.bar" subtree.
+		return origin == exact || strings.HasPrefix(origin, below)
+	}
+}
+
+// filterSeverity returns a filter that accepts leveled log entries at a given severity level or above (see the
+// Severity type). Entries that carry no leveled payload, ie. raw entries, are always rejected.
+func filterSeverity(atLeast Severity) filter {
+ return func(e *entry) bool {
+ return e.leveled != nil && e.leveled.severity.AtLeast(atLeast)
+ }
+}
+
+// filterOnlyRaw is a filter (usable directly, without a constructor call) that accepts only entries carrying a raw
+// payload, ie. ones emitted by raw logging.
+func filterOnlyRaw(e *entry) bool { return e.raw != nil }
+
+// filterOnlyLeveled is the counterpart of filterOnlyRaw: it accepts only entries carrying a leveled payload, ie. ones
+// emitted by leveled logging.
+func filterOnlyLeveled(e *entry) bool { return e.leveled != nil }
+
+// scanEntries walks the whole global linked list and returns every entry that is accepted by all of the given filters
+// (or simply every entry, if no filters are given). The walk starts at the tail and follows nextGlobal pointers.
+//
+// When only the entries of one exact DN are needed, getEntries should be used instead: it leverages the DN-local linked
+// list and does not have to visit entries of other DNs.
+//
+// journal.mu must be taken at R or RW level when calling this function.
+func (j *journal) scanEntries(filters ...filter) (res []*entry) {
+	for cur := j.tail; cur != nil; cur = cur.nextGlobal {
+		// An entry is kept only if no filter turns it down; stop asking at the first rejection.
+		rejected := false
+		for _, accept := range filters {
+			if !accept(cur) {
+				rejected = true
+				break
+			}
+		}
+		if rejected {
+			continue
+		}
+		res = append(res, cur)
+	}
+	return res
+}
+
+// getEntries returns the entries recorded at exactly the given DN that are accepted by all of the given filters (or
+// all entries at that DN, if no filters are given). A DN without any entries yields an empty result.
+//
+// This is faster than a scanEntries(filterExact), as only the DN-local linked list is walked (from its tail, along
+// nextLocal pointers) rather than the global one. Passing additional filters narrows down the result, but does not
+// make the walk itself any shorter.
+//
+// journal.mu must be taken at R or RW level when calling this function.
+func (j *journal) getEntries(exact DN, filters ...filter) (res []*entry) {
+	cur := j.tails[exact]
+	for cur != nil {
+		keep := true
+		for _, accept := range filters {
+			if keep = accept(cur); !keep {
+				break
+			}
+		}
+		if keep {
+			res = append(res, cur)
+		}
+		// Move on along this DN's local linked list.
+		cur = cur.nextLocal
+	}
+	return res
+}
diff --git a/metropolis/node/core/logtree/journal_entry.go b/metropolis/node/core/logtree/journal_entry.go
new file mode 100644
index 0000000..61619b3
--- /dev/null
+++ b/metropolis/node/core/logtree/journal_entry.go
@@ -0,0 +1,169 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import "git.monogon.dev/source/nexantic.git/metropolis/node/common/logbuffer"
+
+// entry is a journal entry, representing a single log event (carrying either a leveled or a raw payload) at a given DN.
+// See the journal struct for more information about the global/local linked lists.
+type entry struct {
+ // origin is the DN at which the log entry was recorded, or conversely, in which DN it will be available at.
+ origin DN
+ // journal is the parent journal of this entry. An entry can belong only to a single journal. This pointer is used
+ // to mutate the journal's head/tail pointers when unlinking an entry.
+ journal *journal
+ // leveled is the leveled log entry for this entry, if this log entry was emitted by leveled logging. Otherwise it
+ // is nil.
+ leveled *LeveledPayload
+ // raw is the raw log entry for this entry, if this log entry was emitted by raw logging. Otherwise it is nil.
+ raw *logbuffer.Line
+
+ // prevGlobal is the previous entry in the global linked list, or nil if this entry is the oldest entry in the
+ // global linked list.
+ prevGlobal *entry
+ // nextGlobal is the next entry in the global linked list, or nil if this entry is the newest entry in the global
+ // linked list.
+ nextGlobal *entry
+
+ // prevLocal is the previous entry in this entry DN's local linked list, or nil if this entry is the oldest entry in
+ // this local linked list.
+ prevLocal *entry
+ // nextLocal is the next entry in this entry DN's local linked list, or nil if this entry is the newest entry in
+ // this local linked list.
+ nextLocal *entry
+
+ // seqLocal is a counter within a local linked list that increases by one each time a new log entry is added. It is
+ // used to quickly establish local linked list sizes (by subtracting seqLocal from both ends). This setup allows for
+ // O(1) length calculation for local linked lists as long as entries are only unlinked from the head or tail (which
+ // is the case in the current implementation).
+ seqLocal uint64
+}
+
+// external builds the public LogEntry counterpart of this entry. Only the DN and the leveled/raw payloads are carried
+// over; journal bookkeeping (parent journal, linked list pointers, sequence numbers) stays private to the library.
+func (e *entry) external() *LogEntry {
+	pub := new(LogEntry)
+	pub.DN = e.origin
+	pub.Leveled = e.leveled
+	pub.Raw = e.raw
+	return pub
+}
+
+// unlink removes this entry from both the global linked list and the local linked list of its DN, updating the
+// journal's head/tail and heads/tails pointers if the entry was at either end. The entry's own pointers are kept as is.
+// journal.mu must be taken as RW
+func (e *entry) unlink() {
+ // Unlink from the global linked list, by making both neighbours (if any) point past this entry.
+ if e.prevGlobal != nil {
+ e.prevGlobal.nextGlobal = e.nextGlobal
+ }
+ if e.nextGlobal != nil {
+ e.nextGlobal.prevGlobal = e.prevGlobal
+ }
+ // Update journal head/tail pointers, if this entry was at that end of the global linked list.
+ if e.journal.head == e {
+ e.journal.head = e.prevGlobal
+ }
+ if e.journal.tail == e {
+ e.journal.tail = e.nextGlobal
+ }
+
+ // Unlink from the local linked list, in the same manner as from the global one.
+ if e.prevLocal != nil {
+ e.prevLocal.nextLocal = e.nextLocal
+ }
+ if e.nextLocal != nil {
+ e.nextLocal.prevLocal = e.prevLocal
+ }
+ // Update journal heads/tails pointers for this DN. The map keys are kept (set to nil) if the list becomes empty.
+ if e.journal.heads[e.origin] == e {
+ e.journal.heads[e.origin] = e.prevLocal
+ }
+ if e.journal.tails[e.origin] == e {
+ e.journal.tails[e.origin] = e.nextLocal
+ }
+}
+
+// quota describes the quota policy for logging at a given DN. It is created and enforced by journal.append.
+type quota struct {
+ // origin is the exact DN that this quota applies to.
+ origin DN
+ // max is the maximum count of log entries permitted for this DN - ie, the maximum size of the local linked list.
+ max uint64
+}
+
+// append adds an entry at the head of the global and local linked lists.
+func (j *journal) append(e *entry) {
+ j.mu.Lock()
+ defer j.mu.Unlock()
+
+ e.journal = j
+
+ // Insert at head in global linked list, set pointers.
+ e.nextGlobal = nil
+ e.prevGlobal = j.head
+ if j.head != nil {
+ j.head.nextGlobal = e
+ }
+ j.head = e
+ if j.tail == nil {
+ j.tail = e
+ }
+
+ // Create quota if necessary.
+ if _, ok := j.quota[e.origin]; !ok {
+ j.quota[e.origin] = "a{origin: e.origin, max: 8192}
+ }
+
+ // Insert at head in local linked list, calculate seqLocal, set pointers.
+ e.nextLocal = nil
+ e.prevLocal = j.heads[e.origin]
+ if j.heads[e.origin] != nil {
+ j.heads[e.origin].nextLocal = e
+ e.seqLocal = e.prevLocal.seqLocal + 1
+ } else {
+ e.seqLocal = 0
+ }
+ j.heads[e.origin] = e
+ if j.tails[e.origin] == nil {
+ j.tails[e.origin] = e
+ }
+
+ // Apply quota to the local linked list that this entry got inserted to, ie. remove elements in excess of the
+ // quota.max count.
+ quota := j.quota[e.origin]
+ count := (j.heads[e.origin].seqLocal - j.tails[e.origin].seqLocal) + 1
+ if count > quota.max {
+ // Keep popping elements off the tail of the local linked list until quota is not violated.
+ left := count - quota.max
+ cur := j.tails[e.origin]
+ for {
+ // This shouldn't happen if quota.max >= 1.
+ if cur == nil {
+ break
+ }
+ if left == 0 {
+ break
+ }
+ el := cur
+ cur = el.nextLocal
+ // Unlinking the entry unlinks it from both the global and local linked lists.
+ el.unlink()
+ left -= 1
+ }
+ }
+}
diff --git a/metropolis/node/core/logtree/journal_subscriber.go b/metropolis/node/core/logtree/journal_subscriber.go
new file mode 100644
index 0000000..e6c7c62
--- /dev/null
+++ b/metropolis/node/core/logtree/journal_subscriber.go
@@ -0,0 +1,69 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "sync/atomic"
+)
+
+// subscriber is an observer for new entries that are appended to the journal.
+type subscriber struct {
+ // filters that entries need to pass through in order to be sent to the subscriber.
+ filters []filter
+ // dataC is the channel to which entries that pass filters will be sent. notify never blocks on it: an entry that
+ // cannot be sent right away is dropped for this subscriber and counted in missed. notify closes it after doneC.
+ dataC chan *LogEntry
+ // doneC is a channel that is closed once the subscriber wishes to stop receiving notifications.
+ doneC chan struct{}
+ // missed is the amount of messages missed by the subscriber by not receiving from dataC fast enough
+ missed uint64
+}
+
+// subscribe attaches a subscriber to the journal, making subsequent calls to notify consider it.
+// mu must be taken in W mode
+func (j *journal) subscribe(sub *subscriber) {
+ j.subscribers = append(j.subscribers, sub)
+}
+
+// notify sends an entry to every subscriber whose filters all accept it. Sends never block: a subscriber that is not
+// ready to receive misses the entry. Subscribers whose doneC is closed get their dataC closed and are detached.
+func (j *journal) notify(e *entry) {
+	j.mu.Lock()
+	defer j.mu.Unlock()
+	newSub := make([]*subscriber, 0, len(j.subscribers))
+subs:
+	for _, sub := range j.subscribers {
+		select {
+		case <-sub.doneC:
+			close(sub.dataC)
+			continue
+		default:
+			newSub = append(newSub, sub)
+		}
+		for _, accept := range sub.filters {
+			if !accept(e) {
+				continue subs // Not wanted by this subscriber; a bare continue would only skip to the next filter.
+			}
+		}
+		select {
+		case sub.dataC <- e.external():
+		default:
+			atomic.AddUint64(&sub.missed, 1)
+		}
+	}
+	j.subscribers = newSub
+}
diff --git a/metropolis/node/core/logtree/journal_test.go b/metropolis/node/core/logtree/journal_test.go
new file mode 100644
index 0000000..474748a
--- /dev/null
+++ b/metropolis/node/core/logtree/journal_test.go
@@ -0,0 +1,148 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+)
+
+// testPayload returns a LeveledPayload that carries msg as its only message line, for use in test journal entries.
+// Everything else is fixed (INFO severity, emitted from main.go:1337), apart from the timestamp, which is the time of
+// the call.
+func testPayload(msg string) *LeveledPayload {
+	p := &LeveledPayload{severity: INFO, file: "main.go", line: 1337}
+	p.messages = []string{msg}
+	p.timestamp = time.Now()
+	return p
+}
+
+func TestJournalRetention(t *testing.T) {
+ j := newJournal()
+
+ for i := 0; i < 9000; i += 1 {
+ e := &entry{
+ origin: "main",
+ leveled: testPayload(fmt.Sprintf("test %d", i)),
+ }
+ j.append(e)
+ }
+
+ entries := j.getEntries("main")
+ if want, got := 8192, len(entries); want != got {
+ t.Fatalf("wanted %d entries, got %d", want, got)
+ }
+ for i, entry := range entries {
+ want := fmt.Sprintf("test %d", (9000-8192)+i)
+ got := strings.Join(entry.leveled.messages, "\n")
+ if want != got {
+ t.Fatalf("wanted entry %q, got %q", want, got)
+ }
+ }
+}
+
+func TestJournalQuota(t *testing.T) {
+ j := newJournal()
+
+ for i := 0; i < 9000; i += 1 {
+ j.append(&entry{
+ origin: "chatty",
+ leveled: testPayload(fmt.Sprintf("chatty %d", i)),
+ })
+ if i%10 == 0 {
+ j.append(&entry{
+ origin: "solemn",
+ leveled: testPayload(fmt.Sprintf("solemn %d", i)),
+ })
+ }
+ }
+
+ entries := j.getEntries("chatty")
+ if want, got := 8192, len(entries); want != got {
+ t.Fatalf("wanted %d chatty entries, got %d", want, got)
+ }
+ entries = j.getEntries("solemn")
+ if want, got := 900, len(entries); want != got {
+ t.Fatalf("wanted %d solemn entries, got %d", want, got)
+ }
+ entries = j.getEntries("absent")
+ if want, got := 0, len(entries); want != got {
+ t.Fatalf("wanted %d absent entries, got %d", want, got)
+ }
+
+ entries = j.scanEntries(filterAll())
+ if want, got := 8192+900, len(entries); want != got {
+ t.Fatalf("wanted %d total entries, got %d", want, got)
+ }
+ setMessages := make(map[string]bool)
+ for _, entry := range entries {
+ setMessages[strings.Join(entry.leveled.messages, "\n")] = true
+ }
+
+ for i := 0; i < 900; i += 1 {
+ want := fmt.Sprintf("solemn %d", i*10)
+ if !setMessages[want] {
+ t.Fatalf("could not find entry %q in journal", want)
+ }
+ }
+ for i := 0; i < 8192; i += 1 {
+ want := fmt.Sprintf("chatty %d", i+(9000-8192))
+ if !setMessages[want] {
+ t.Fatalf("could not find entry %q in journal", want)
+ }
+ }
+}
+
+// TestJournalSubtree checks that a scan with filterSubtree returns exactly the entries at the given DN and below it.
+func TestJournalSubtree(t *testing.T) {
+	j := newJournal()
+	for _, dn := range []DN{"a", "a.b", "a.b.c", "a.b.d", "e.f", "e.g"} {
+		j.append(&entry{origin: dn, leveled: testPayload(string(dn))})
+	}
+
+	// expect scans with f and describes how the result differs from msgs (assumed distinct), or returns "" if it doesn't.
+	expect := func(f filter, msgs ...string) string {
+		set := make(map[string]bool)
+		for _, e := range j.scanEntries(f) {
+			set[strings.Join(e.leveled.messages, "\n")] = true
+		}
+		for _, want := range msgs {
+			if !set[want] {
+				return fmt.Sprintf("missing entry %q", want)
+			}
+		}
+		if len(set) != len(msgs) {
+			return fmt.Sprintf("wanted %d entries, got %d (unexpected entries returned)", len(msgs), len(set))
+		}
+		return ""
+	}
+
+	if res := expect(filterAll(), "a", "a.b", "a.b.c", "a.b.d", "e.f", "e.g"); res != "" {
+		t.Fatalf("All: %s", res)
+	}
+	if res := expect(filterSubtree("a"), "a", "a.b", "a.b.c", "a.b.d"); res != "" {
+		t.Fatalf("Subtree(a): %s", res)
+	}
+	if res := expect(filterSubtree("a.b"), "a.b", "a.b.c", "a.b.d"); res != "" {
+		t.Fatalf("Subtree(a.b): %s", res)
+	}
+	if res := expect(filterSubtree("e"), "e.f", "e.g"); res != "" {
+		t.Fatalf("Subtree(e): %s", res)
+	}
+}
diff --git a/metropolis/node/core/logtree/leveled.go b/metropolis/node/core/logtree/leveled.go
new file mode 100644
index 0000000..c24357e
--- /dev/null
+++ b/metropolis/node/core/logtree/leveled.go
@@ -0,0 +1,144 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+// LeveledLogger is a generic interface for glog-style logging. There are four hardcoded log severities, in increasing
+// order: INFO, WARNING, ERROR, FATAL. Logging at a certain severity level logs not only to consumers expecting data
+// at that severity level, but also to those expecting any lower severity level. For example, an ERROR log will also be
+// passed to consumers looking at INFO or WARNING logs.
+type LeveledLogger interface {
+ // Info logs at the INFO severity. Arguments are handled in the manner of fmt.Print, a terminating newline is added
+ // if missing.
+ Info(args ...interface{})
+ // Infof logs at the INFO severity. Arguments are handled in the manner of fmt.Printf, a terminating newline is
+ // added if missing.
+ Infof(format string, args ...interface{})
+
+ // Warning logs at the WARNING severity. Arguments are handled in the manner of fmt.Print, a terminating newline is
+ // added if missing.
+ Warning(args ...interface{})
+ // Warningf logs at the WARNING severity. Arguments are handled in the manner of fmt.Printf, a terminating newline
+ // is added if missing.
+ Warningf(format string, args ...interface{})
+
+ // Error logs at the ERROR severity. Arguments are handled in the manner of fmt.Print, a terminating newline is
+ // added if missing.
+ Error(args ...interface{})
+ // Errorf logs at the ERROR severity. Arguments are handled in the manner of fmt.Printf, a terminating newline is
+ // added if missing.
+ Errorf(format string, args ...interface{})
+
+ // Fatal logs at the FATAL severity and aborts the current program. Arguments are handled in the manner of
+ // fmt.Print, a terminating newline is added if missing.
+ Fatal(args ...interface{})
+ // Fatalf logs at the FATAL severity and aborts the current program. Arguments are handled in the manner of
+ // fmt.Printf, a terminating newline is added if missing.
+ Fatalf(format string, args ...interface{})
+
+ // V returns a VerboseLeveledLogger at a given verbosity level. These verbosity levels can be dynamically set and
+ // unset on a package-granular level by consumers of the LeveledLogger logs. The returned value represents whether
+ // logging at the given verbosity level was active at that time, and as such should not be a long-lived object
+ // in programs.
+ // This construct is further referred to as 'V-logs'.
+ V(level VerbosityLevel) VerboseLeveledLogger
+}
+
+// VerbosityLevel is a verbosity level defined for V-logs. This can be changed programmatically per Go package. When
+// logging at a given VerbosityLevel V, the current level must be equal to or higher than V for the logs to be recorded.
+// Conversely, enabling V-logging at a VerbosityLevel V also enables all logging at lower levels [Int32Min .. (V-1)].
+type VerbosityLevel int32
+
+// VerboseLeveledLogger is the logger returned by LeveledLogger.V, which only records logs if its level was enabled.
+type VerboseLeveledLogger interface {
+ // Enabled returns whether this level was enabled. If not, all logging into this logger is discarded immediately.
+ // Thus, Enabled() can be used to check the verbosity level before performing any logging:
+ //   if l.V(3).Enabled() { l.Info("V3 is enabled") }
+ // or, in simple cases, the convenience function .Info can be used:
+ //   l.V(3).Info("V3 is enabled")
+ // The second form is shorter and more convenient, but more expensive, as its arguments are always evaluated.
+ Enabled() bool
+ // Info is the equivalent of a LeveledLogger's Info call, guarded by whether this VerboseLeveledLogger is enabled.
+ Info(args ...interface{})
+ // Infof is the equivalent of a LeveledLogger's Infof call, guarded by whether this VerboseLeveledLogger is enabled.
+ Infof(format string, args ...interface{})
+}
+
+// Severity is one of the severities as described in LeveledLogger.
+type Severity string
+
+const (
+	INFO    Severity = "I"
+	WARNING Severity = "W"
+	ERROR   Severity = "E"
+	FATAL   Severity = "F"
+)
+
+// SeverityAtLeast maps a given severity to the list of severities that are at that severity or higher. In other words,
+// SeverityAtLeast[X] is the list of severities that might be seen in a log at severity X.
+var SeverityAtLeast = map[Severity][]Severity{
+	INFO:    {INFO, WARNING, ERROR, FATAL},
+	WARNING: {WARNING, ERROR, FATAL},
+	ERROR:   {ERROR, FATAL},
+	FATAL:   {FATAL},
+}
+
+// AtLeast returns true if s is severity other or higher, as per SeverityAtLeast. Unknown severities never match.
+func (s Severity) AtLeast(other Severity) bool {
+	atLeast := SeverityAtLeast[other]
+	for i := range atLeast {
+		if atLeast[i] == s {
+			return true
+		}
+	}
+	return false
+}
+
+// SeverityFromProto converts a protobuf severity to a Severity. Any value other than the four severities is an error.
+func SeverityFromProto(s apb.LeveledLogSeverity) (Severity, error) {
+	switch s {
+	case apb.LeveledLogSeverity_INFO:
+		return INFO, nil
+	case apb.LeveledLogSeverity_WARNING:
+		return WARNING, nil
+	case apb.LeveledLogSeverity_ERROR:
+		return ERROR, nil
+	case apb.LeveledLogSeverity_FATAL:
+		return FATAL, nil
+	}
+	return "", fmt.Errorf("unknown severity value %d", s)
+}
+
+// ToProto converts a Severity to its protobuf counterpart, or to LeveledLogSeverity_INVALID if it is not a known one.
+func (s Severity) ToProto() apb.LeveledLogSeverity {
+	switch s {
+	case INFO:
+		return apb.LeveledLogSeverity_INFO
+	case WARNING:
+		return apb.LeveledLogSeverity_WARNING
+	case ERROR:
+		return apb.LeveledLogSeverity_ERROR
+	case FATAL:
+		return apb.LeveledLogSeverity_FATAL
+	}
+	return apb.LeveledLogSeverity_INVALID
+}
diff --git a/metropolis/node/core/logtree/leveled_payload.go b/metropolis/node/core/logtree/leveled_payload.go
new file mode 100644
index 0000000..fad42e3
--- /dev/null
+++ b/metropolis/node/core/logtree/leveled_payload.go
@@ -0,0 +1,142 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+// LeveledPayload is a log entry for leveled logs (as per leveled.go). It contains the input to these calls (severity and
+// message split into newline-delimited messages) and additional metadata that would be usually seen in a text
+// representation of a leveled log entry. Its fields are read through the accessor methods defined on it.
+type LeveledPayload struct {
+ // messages is the list of messages contained in this payload. This list is built from splitting up the given message
+ // from the user by newline.
+ messages []string
+ // timestamp is the time at which this message was emitted.
+ timestamp time.Time
+ // severity is the leveled Severity at which this message was emitted.
+ severity Severity
+ // file is the filename of the caller that emitted this message.
+ file string
+ // line is the line number within the file of the caller that emitted this message.
+ line int
+}
+
+// String returns the canonical single-string form of this payload: every message line prefixed with the payload's
+// metadata, all joined by newlines. A payload logged with newlines thus yields multiple lines, each with the same
+// prefix. A payload without any message lines yields an empty string.
+// For an alternative call that will instead return a canonical prefix and a list of lines in the message, see Strings().
+func (p *LeveledPayload) String() string {
+	prefix, lines := p.Strings()
+	out := make([]string, 0, len(lines))
+	for _, l := range lines {
+		out = append(out, prefix+l)
+	}
+	return strings.Join(out, "\n")
+}
+
+// Strings returns the canonical representation of this payload split into a prefix and all lines that were contained in
+// the original message. This is meant to be displayed to the user by showing the prefix before each line, concatenated
+// together - possibly in a table form with the prefixes all unified with a rowspan-like mechanism.
+//
+// For example, this function can return:
+//   prefix = "I1102 17:20:06.921395 foo.go:42] "
+//   lines  = []string{"current tags:", " - one", " - two"}
+//
+// With this data, the result should be presented to users this way in text form:
+//   I1102 17:20:06.921395 foo.go:42] current tags:
+//   I1102 17:20:06.921395 foo.go:42]  - one
+//   I1102 17:20:06.921395 foo.go:42]  - two
+//
+// Or, in a table layout:
+//   .----------------------------------------------------------.
+//   | I1102 17:20:06.921395 foo.go:42] : current tags:         |
+//   |                                  :-----------------------|
+//   |                                  :  - one                |
+//   |                                  :-----------------------|
+//   |                                  :  - two                |
+//   '----------------------------------------------------------'
+//
+func (p *LeveledPayload) Strings() (prefix string, lines []string) {
+	ts := p.timestamp
+	_, month, day := ts.Date()
+	hour, minute, second := ts.Clock()
+	// The fractional part of the second is emitted with microsecond resolution.
+	usec := ts.Nanosecond() / 1000
+
+	// Same format as in glog, but without the thread ID:
+	// Lmmdd hh:mm:ss.uuuuuu file:line]
+	// TODO(q3k): rewrite this to printf-less code.
+	prefix = fmt.Sprintf("%s%02d%02d %02d:%02d:%02d.%06d %s:%d] ", p.severity, month, day, hour, minute, second, usec, p.file, p.line)
+	return prefix, p.messages
+}
+
+// Messages returns the inner message lines of this entry, ie. what was passed to the logging method, split by newlines.
+func (p *LeveledPayload) Messages() []string { return p.messages }
+
+// MessagesJoined returns the inner message lines of this entry, joined back together with newlines.
+func (p *LeveledPayload) MessagesJoined() string { return strings.Join(p.messages, "\n") }
+
+// Timestamp returns the time at which this entry was logged.
+func (p *LeveledPayload) Timestamp() time.Time { return p.timestamp }
+
+// Location returns a string in the form of file_name:line_number that shows the origin of the log entry in the
+// program source.
+func (p *LeveledPayload) Location() string { return fmt.Sprintf("%s:%d", p.file, p.line) }
+
+// Severity returns the Severity with which this entry was logged.
+func (p *LeveledPayload) Severity() Severity { return p.severity }
+
+// Proto converts a LeveledPayload to protobuf format, with its timestamp as nanoseconds since the Unix epoch.
+func (p *LeveledPayload) Proto() *apb.LogEntry_Leveled {
+	msg := new(apb.LogEntry_Leveled)
+	msg.Lines = p.messages
+	msg.Timestamp = p.timestamp.UnixNano()
+	msg.Severity = p.severity.ToProto()
+	msg.Location = p.Location()
+	return msg
+}
+
+// LeveledPayloadFromProto parses a protobuf message into the internal format, failing on an unknown severity or on a
+// location that is not in file:line form. The location is split at its last colon, as file names may contain colons.
+func LeveledPayloadFromProto(p *apb.LogEntry_Leveled) (*LeveledPayload, error) {
+	severity, err := SeverityFromProto(p.Severity)
+	if err != nil {
+		return nil, fmt.Errorf("could not convert severity: %w", err)
+	}
+	sep := strings.LastIndex(p.Location, ":")
+	if sep < 0 {
+		return nil, fmt.Errorf("invalid location %q, must be in file:line form", p.Location)
+	}
+	line, err := strconv.Atoi(p.Location[sep+1:])
+	if err != nil {
+		return nil, fmt.Errorf("invalid location line number: %w", err)
+	}
+	return &LeveledPayload{
+		messages:  p.Lines,
+		timestamp: time.Unix(0, p.Timestamp),
+		severity:  severity,
+		file:      p.Location[:sep],
+		line:      line,
+	}, nil
+}
diff --git a/metropolis/node/core/logtree/logtree.go b/metropolis/node/core/logtree/logtree.go
new file mode 100644
index 0000000..fab72ba
--- /dev/null
+++ b/metropolis/node/core/logtree/logtree.go
@@ -0,0 +1,147 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/logbuffer"
+)
+
+// LogTree is a tree-shaped logging system. For more information, see the package-level documentation.
+//
+// A LogTree consists of a single journal and a tree of nodes hanging off root. Use New to construct one, as both
+// fields need to be initialized.
+type LogTree struct {
+ // journal is the tree's journal, storing all log data and managing subscribers.
+ journal *journal
+ // root is the root node of the actual tree of the log tree. The nodes contain per-DN configuration options, notably
+ // the current verbosity level of that DN.
+ root *node
+}
+
+// New returns an empty LogTree, made up of a fresh journal and a root node that carries the empty DN.
+func New() *LogTree {
+	lt := new(LogTree)
+	lt.journal = newJournal()
+	lt.root = newNode(lt, "")
+	return lt
+}
+
+// node represents a given DN as a discrete 'logger'. It implements the LeveledLogger interface for log publishing,
+// entries from which it passes over to the logtree's journal.
+type node struct {
+ // dn is the DN which this node represents (or "" if this is the root node).
+ dn DN
+ // tree is the LogTree to which this node belongs.
+ tree *LogTree
+ // verbosity is the current verbosity level of this DN/node, affecting .V(n) LeveledLogger calls
+ verbosity VerbosityLevel
+ // rawLineBuffer is the line buffer that receives raw log data written for this DN. It is created in newNode with
+ // this node's logRaw as its callback.
+ rawLineBuffer *logbuffer.LineBuffer
+
+ // mu guards children.
+ mu sync.Mutex
+ // children is a map of DN-part to a children node in the logtree. A DN-part is a string representing a part of the
+ // DN between the delimiting dots, as returned by DN.Path.
+ children map[string]*node
+}
+
+// newNode builds a node for the given DN that belongs to the given LogTree. The node is not linked into the tree by
+// this call - inserting it in the right place is up to the caller.
+func newNode(tree *LogTree, dn DN) *node {
+	n := new(node)
+	n.dn = dn
+	n.tree = tree
+	n.children = make(map[string]*node)
+	// TODO(q3k): make this limit configurable. If this happens, or the default (1024) gets changed, max chunk size
+	// calculations when serving the logs (eg. in NodeDebugService) must reflect this.
+	n.rawLineBuffer = logbuffer.NewLineBuffer(1024, n.logRaw)
+	return n
+}
+
+// nodeByDN looks up the node for the given DN, starting at the root of the tree. Any node on the way that does not
+// exist yet (including the target itself) is created. An error is returned if no traversal can be built for the DN.
+func (l *LogTree) nodeByDN(dn DN) (*node, error) {
+	t, err := newTraversal(dn)
+	if err != nil {
+		return nil, fmt.Errorf("traversal failed: %w", err)
+	}
+	found := t.execute(l.root)
+	return found, nil
+}
+
+// nodeTraversal represents a request to traverse the LogTree in search of a given node by DN.
+type nodeTraversal struct {
+ // want is the DN of the node that was requested to be found.
+ want DN
+ // current is the path already taken to find the node, in the form of DN parts. It starts out empty and grows by one
+ // part on every call to next, until it holds all parts of want.
+ current []string
+ // left is the path that's still needed to be taken in order to find the node, in the form of DN parts. It starts
+ // out as the result of want.Path() and shrinks by one part on every call to next, until it is empty.
+ left []string
+}
+
+// next advances the traversal by one element: the first part of left is moved to the end of current. It returns that
+// part together with the full DN of the element it denotes. Once nothing is left, it returns "" and the wanted DN.
+//
+// For example, a traversal of foo.bar.baz will cause .next() to return the following on each invocation:
+//  - part: foo, full: foo
+//  - part: bar, full: foo.bar
+//  - part: baz, full: foo.bar.baz
+//  - part: "", full: foo.bar.baz
+func (t *nodeTraversal) next() (part string, full DN) {
+	if len(t.left) > 0 {
+		head, rest := t.left[0], t.left[1:]
+		t.current = append(t.current, head)
+		t.left = rest
+		return head, DN(strings.Join(t.current, "."))
+	}
+	return "", t.want
+}
+
+// newTraversal returns a nodeTraversal for a given wanted DN, with all parts of the DN still left to be visited. The
+// error from DN.Path is passed through unchanged.
+func newTraversal(dn DN) (*nodeTraversal, error) {
+	path, err := dn.Path()
+	if err != nil {
+		return nil, err
+	}
+	t := &nodeTraversal{want: dn, left: path}
+	return t, nil
+}
+
+// execute runs the traversal starting at node n and returns the node that was looked for. This can only be called once
+// per traversal, as it consumes the traversal's state.
+// Children that are missing on the way are created (in the same LogTree as n), existing ones are reused. Each lookup
+// or creation happens with the parent's mutex held.
+func (t *nodeTraversal) execute(n *node) *node {
+	at := n
+	for part, full := t.next(); part != ""; part, full = t.next() {
+		parent := at
+		parent.mu.Lock()
+		child, ok := parent.children[part]
+		if !ok {
+			child = newNode(n.tree, full)
+			parent.children[part] = child
+		}
+		parent.mu.Unlock()
+		at = child
+	}
+	return at
+}
diff --git a/metropolis/node/core/logtree/logtree_access.go b/metropolis/node/core/logtree/logtree_access.go
new file mode 100644
index 0000000..fed202e
--- /dev/null
+++ b/metropolis/node/core/logtree/logtree_access.go
@@ -0,0 +1,183 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "errors"
+ "sync/atomic"
+)
+
+// LogReadOption describes options for the LogTree.Read call. Values are built with the option functions in this file
+// (WithChildren, WithStream, ...), each of which sets exactly one field; Read merges all options passed to it.
+type LogReadOption struct {
+ // withChildren requests data for sub-DNs of the read DN as well (see WithChildren).
+ withChildren bool
+ // withStream requests a stream of newly logged entries (see WithStream).
+ withStream bool
+ // withBacklog is the requested number of backlog entries, or BacklogAllAvailable (see WithBacklog).
+ withBacklog int
+ // onlyLeveled restricts the result to leveled entries (see OnlyLeveled).
+ onlyLeveled bool
+ // onlyRaw restricts the result to raw entries (see OnlyRaw).
+ onlyRaw bool
+ // leveledWithMinimumSeverity is the minimum severity requested, or "" if unset (see LeveledWithMinimumSeverity).
+ leveledWithMinimumSeverity Severity
+}
+
+// WithChildren makes Read return/stream data not only for the given DN, but for all of its children as well.
+func WithChildren() LogReadOption {
+	return LogReadOption{
+		withChildren: true,
+	}
+}
+
+// WithStream makes Read return a stream of data. It can be combined with WithBacklog to build a read-and-stream
+// construct.
+func WithStream() LogReadOption {
+	return LogReadOption{
+		withStream: true,
+	}
+}
+
+// WithBacklog makes Read return log entries that have already been recorded, up to count elements.
+func WithBacklog(count int) LogReadOption {
+	return LogReadOption{
+		withBacklog: count,
+	}
+}
+
+// BacklogAllAvailable is a special count for WithBacklog which makes Read return all backlogged log data that logtree
+// possesses, instead of limiting the backlog to a number of entries.
+const BacklogAllAvailable int = -1
+
+// OnlyRaw makes Read return only raw log entries. It cannot be combined with OnlyLeveled - Read will then fail with
+// ErrRawAndLeveled.
+func OnlyRaw() LogReadOption {
+	return LogReadOption{
+		onlyRaw: true,
+	}
+}
+
+// OnlyLeveled makes Read return only leveled log entries. It cannot be combined with OnlyRaw - Read will then fail
+// with ErrRawAndLeveled.
+func OnlyLeveled() LogReadOption {
+	return LogReadOption{
+		onlyLeveled: true,
+	}
+}
+
+// LeveledWithMinimumSeverity makes Read return only log entries that are at least at a given Severity. If only leveled
+// entries are needed, OnlyLeveled must be used. This is a no-op when OnlyRaw is used.
+func LeveledWithMinimumSeverity(s Severity) LogReadOption {
+ return LogReadOption{leveledWithMinimumSeverity: s}
+}
+
+// LogReader permits reading an already existing backlog of log entries and to stream further ones.
+type LogReader struct {
+ // Backlog are the log entries already logged by LogTree. This will only contain entries if WithBacklog has been
+ // passed to Read.
+ Backlog []*LogEntry
+ // Stream is a channel of new entries as received live by LogTree. This will only be set if WithStream has been
+ // passed to Read. In this case, entries from this channel must be read as fast as possible by the consumer in order
+ // to prevent missing entries.
+ Stream <-chan *LogEntry
+ // done is a channel used to signal (by closing) that the log consumer is not interested in more Stream data. It is
+ // nil if no streaming has been requested.
+ done chan<- struct{}
+ // missed is an atomic integer pointer that tells the subscriber how many messages in Stream they missed. This
+ // pointer is nil if no streaming has been requested.
+ missed *uint64
+}
+
+// Missed returns the number of entries that were missed from Stream (as the channel was not drained fast enough). It
+// is always zero for readers that have no Stream.
+func (l *LogReader) Missed() uint64 {
+	if counter := l.missed; counter != nil {
+		return atomic.LoadUint64(counter)
+	}
+	// No Stream, so nothing could have been missed.
+	return 0
+}
+
+// Close closes the LogReader's Stream. This must be called once the Reader does not wish to receive streaming messages
+// anymore. It does nothing for readers without a Stream. As it closes the done channel, it must not be called more
+// than once on a streaming reader.
+func (l *LogReader) Close() {
+	if l.done == nil {
+		return
+	}
+	close(l.done)
+}
+
+var (
+ // ErrRawAndLeveled is returned by LogTree.Read when both OnlyRaw and OnlyLeveled options have been passed, as no
+ // entry can satisfy both.
+ ErrRawAndLeveled = errors.New("cannot return logs that are simultaneously OnlyRaw and OnlyLeveled")
+)
+
+// Read and/or stream entries from a LogTree. The returned LogReader is influenced by the LogReadOptions passed, which
+// influence whether the Read will return existing entries, a stream, or both. In addition the options also dictate
+// whether only entries for that particular DN are returned, or for all sub-DNs as well.
+//
+// Options are merged: boolean options are enabled if any passed option enables them, while for WithBacklog and
+// LeveledWithMinimumSeverity the last option that carries a value wins. ErrRawAndLeveled is returned if both OnlyRaw
+// and OnlyLeveled have been requested.
+func (l *LogTree) Read(dn DN, opts ...LogReadOption) (*LogReader, error) {
+ // The journal read lock is held for the whole call, ie. across both the backlog retrieval and the subscription.
+ l.journal.mu.RLock()
+ defer l.journal.mu.RUnlock()
+
+ var backlog int
+ var stream bool
+ var recursive bool
+ var leveledSeverity Severity
+ var onlyRaw, onlyLeveled bool
+
+ // Merge all options into the local settings above.
+ for _, opt := range opts {
+ if opt.withBacklog > 0 || opt.withBacklog == BacklogAllAvailable {
+ backlog = opt.withBacklog
+ }
+ if opt.withStream {
+ stream = true
+ }
+ if opt.withChildren {
+ recursive = true
+ }
+ if opt.leveledWithMinimumSeverity != "" {
+ leveledSeverity = opt.leveledWithMinimumSeverity
+ }
+ if opt.onlyLeveled {
+ onlyLeveled = true
+ }
+ if opt.onlyRaw {
+ onlyRaw = true
+ }
+ }
+
+ if onlyLeveled && onlyRaw {
+ return nil, ErrRawAndLeveled
+ }
+
+ // Build the filter list. The same filters are used for the backlog and for the stream subscriber.
+ var filters []filter
+ if onlyLeveled {
+ filters = append(filters, filterOnlyLeveled)
+ }
+ if onlyRaw {
+ filters = append(filters, filterOnlyRaw)
+ }
+ if recursive {
+ filters = append(filters, filterSubtree(dn))
+ } else {
+ filters = append(filters, filterExact(dn))
+ }
+ if leveledSeverity != "" {
+ filters = append(filters, filterSeverity(leveledSeverity))
+ }
+
+ // Retrieve the backlog, if requested.
+ var entries []*entry
+ if backlog > 0 || backlog == BacklogAllAvailable {
+ // TODO(q3k): pass over the backlog count to scanEntries/getEntries, instead of discarding them here.
+ if recursive {
+ entries = l.journal.scanEntries(filters...)
+ } else {
+ entries = l.journal.getEntries(dn, filters...)
+ }
+ // Limit the result to the requested count by keeping the leading elements of the returned slice.
+ if backlog != BacklogAllAvailable && len(entries) > backlog {
+ entries = entries[:backlog]
+ }
+ }
+
+ // Register a subscriber with the journal, if streaming was requested.
+ var sub *subscriber
+ if stream {
+ sub = &subscriber{
+ // TODO(q3k): make buffer size configurable
+ dataC: make(chan *LogEntry, 128),
+ doneC: make(chan struct{}),
+ filters: filters,
+ }
+ l.journal.subscribe(sub)
+ }
+
+ // Convert the internal entries to their external form. Backlog is always allocated, but is empty if no backlog was
+ // requested.
+ lr := &LogReader{}
+ lr.Backlog = make([]*LogEntry, len(entries))
+ for i, entry := range entries {
+ lr.Backlog[i] = entry.external()
+ }
+ // Hand the subscriber's channels and missed counter over to the reader.
+ if stream {
+ lr.Stream = sub.dataC
+ lr.done = sub.doneC
+ lr.missed = &sub.missed
+ }
+ return lr, nil
+}
diff --git a/metropolis/node/core/logtree/logtree_entry.go b/metropolis/node/core/logtree/logtree_entry.go
new file mode 100644
index 0000000..635e5a8
--- /dev/null
+++ b/metropolis/node/core/logtree/logtree_entry.go
@@ -0,0 +1,141 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strings"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/logbuffer"
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+// LogEntry contains a log entry, combining both leveled and raw logging into a single stream of events. A LogEntry
+// will contain exactly one of either LeveledPayload or RawPayload. An entry with neither set is treated as invalid by
+// String, Strings and Proto.
+type LogEntry struct {
+ // If non-nil, this is a leveled logging entry.
+ Leveled *LeveledPayload
+ // If non-nil, this is a raw logging entry line.
+ Raw *logbuffer.Line
+ // DN from which this entry was logged.
+ DN DN
+}
+
+// String returns a canonical representation of this payload as a single string prefixed with metadata. A leveled
+// entry that was logged with newlines yields one output line per original message line, each of them carrying the
+// metadata prefix. A raw entry yields a single line marked with 'R'. An entry that is neither yields "INVALID".
+// For an alternative call that will instead return a canonical prefix and a list of lines in the message, see Strings().
+func (l *LogEntry) String() string {
+	switch {
+	case l.Leveled != nil:
+		prefix, messages := l.Leveled.Strings()
+		lines := make([]string, 0, len(messages))
+		for _, m := range messages {
+			lines = append(lines, fmt.Sprintf("%-32s %s%s", l.DN, prefix, m))
+		}
+		return strings.Join(lines, "\n")
+	case l.Raw != nil:
+		return fmt.Sprintf("%-32s R %s", l.DN, l.Raw)
+	default:
+		return "INVALID"
+	}
+}
+
+// Strings returns the canonical representation of this payload split into a prefix and all lines that were contained in
+// the original message. Consumers are expected to show the prefix before each of the lines - possibly in a table form
+// with the prefixes all unified with a rowspan-like mechanism.
+//
+// For a leveled entry the prefix is the DN (padded to 32 characters) followed by the prefix of the leveled payload.
+// For a raw entry it is the padded DN followed by 'R', and the only line is the raw data. For an entry that is
+// neither, both the prefix and the single line are set to "INVALID".
+//
+// For example, this function can return:
+//   prefix = "root.foo.bar                    I1102 17:20:06.921395     0 foo.go:42] "
+//   lines = []string{"current tags:", " - one", " - two"}
+//
+// With this data, the result should be presented to users this way in text form:
+//   root.foo.bar                    I1102 17:20:06.921395 foo.go:42] current tags:
+//   root.foo.bar                    I1102 17:20:06.921395 foo.go:42]  - one
+//   root.foo.bar                    I1102 17:20:06.921395 foo.go:42]  - two
+//
+// Or, in a table layout:
+//   .-------------------------------------------------------------------------------------.
+//   | root.foo.bar                    I1102 17:20:06.921395 foo.go:42] : current tags:    |
+//   |                                                                  :------------------|
+//   |                                                                  :  - one           |
+//   |                                                                  :------------------|
+//   |                                                                  :  - two           |
+//   '-------------------------------------------------------------------------------------'
+//
+func (l *LogEntry) Strings() (prefix string, lines []string) {
+	switch {
+	case l.Leveled != nil:
+		inner, messages := l.Leveled.Strings()
+		return fmt.Sprintf("%-32s %s", l.DN, inner), messages
+	case l.Raw != nil:
+		return fmt.Sprintf("%-32s R ", l.DN), []string{l.Raw.Data}
+	default:
+		return "INVALID ", []string{"INVALID"}
+	}
+}
+
+// Proto converts this LogEntry to its protobuf form. The returned value is nil if the LogEntry is invalid, ie. it
+// contains neither a Raw nor a Leveled entry. If both are set, the Leveled entry is the one that gets converted.
+func (l *LogEntry) Proto() *apb.LogEntry {
+	if l.Leveled == nil && l.Raw == nil {
+		return nil
+	}
+	p := &apb.LogEntry{Dn: string(l.DN)}
+	if l.Leveled != nil {
+		p.Kind = &apb.LogEntry_Leveled_{Leveled: l.Leveled.Proto()}
+	} else {
+		p.Kind = &apb.LogEntry_Raw_{Raw: l.Raw.ProtoLog()}
+	}
+	return p
+}
+
+// LogEntryFromProto parses a proto LogEntry back into the internal structure. This can be used in log proto API
+// consumers to easily print received log entries.
+//
+// An error is returned if the message is nil, if its DN is invalid, if it carries neither a leveled nor a raw
+// payload, or if the payload it carries cannot be converted.
+func LogEntryFromProto(l *apb.LogEntry) (*LogEntry, error) {
+	// Messages come from API peers; reject missing data instead of panicking on it.
+	if l == nil {
+		return nil, fmt.Errorf("proto is nil")
+	}
+	dn := DN(l.Dn)
+	if _, err := dn.Path(); err != nil {
+		return nil, fmt.Errorf("could not convert DN: %w", err)
+	}
+	res := &LogEntry{
+		DN: dn,
+	}
+	switch inner := l.Kind.(type) {
+	case *apb.LogEntry_Leveled_:
+		if inner.Leveled == nil {
+			return nil, fmt.Errorf("could not convert leveled entry: payload is nil")
+		}
+		leveled, err := LeveledPayloadFromProto(inner.Leveled)
+		if err != nil {
+			return nil, fmt.Errorf("could not convert leveled entry: %w", err)
+		}
+		res.Leveled = leveled
+	case *apb.LogEntry_Raw_:
+		line, err := logbuffer.LineFromLogProto(inner.Raw)
+		if err != nil {
+			return nil, fmt.Errorf("could not convert raw entry: %w", err)
+		}
+		res.Raw = line
+	default:
+		return nil, fmt.Errorf("proto has neither Leveled nor Raw set")
+	}
+	return res, nil
+}
diff --git a/metropolis/node/core/logtree/logtree_publisher.go b/metropolis/node/core/logtree/logtree_publisher.go
new file mode 100644
index 0000000..c4880bc
--- /dev/null
+++ b/metropolis/node/core/logtree/logtree_publisher.go
@@ -0,0 +1,185 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "io"
+ "runtime"
+ "strings"
+ "time"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/logbuffer"
+)
+
+// LeveledFor returns a LeveledLogger publishing interface for a given DN. An error may be returned if the DN is
+// malformed, in which case the returned LeveledLogger is nil.
+func (l *LogTree) LeveledFor(dn DN) (LeveledLogger, error) {
+	node, err := l.nodeByDN(dn)
+	if err != nil {
+		// Return a literal nil: passing the nil *node through would yield a non-nil interface value holding a nil
+		// pointer, which callers comparing the logger against nil would mistake for a usable logger.
+		return nil, err
+	}
+	return node, nil
+}
+
+// RawFor returns an io.Writer publishing interface for raw logging at a given DN. Data written to it goes to the line
+// buffer of the node for that DN. An error may be returned if the DN is malformed.
+func (l *LogTree) RawFor(dn DN) (io.Writer, error) {
+	n, err := l.nodeByDN(dn)
+	if err == nil {
+		return n.rawLineBuffer, nil
+	}
+	return nil, fmt.Errorf("could not retrieve raw logger: %w", err)
+}
+
+// MustLeveledFor is like LeveledFor, but panics instead of returning an error if the given DN is invalid.
+func (l *LogTree) MustLeveledFor(dn DN) LeveledLogger {
+	logger, err := l.LeveledFor(dn)
+	if err == nil {
+		return logger
+	}
+	panic(fmt.Errorf("LeveledFor returned: %w", err))
+}
+
+// MustRawFor is like RawFor, but panics instead of returning an error if the given DN is invalid.
+func (l *LogTree) MustRawFor(dn DN) io.Writer {
+	w, err := l.RawFor(dn)
+	if err == nil {
+		return w
+	}
+	panic(fmt.Errorf("RawFor returned: %w", err))
+}
+
+// SetVerbosity sets the verbosity for a given DN (non-recursively, ie. for that DN only, not its children). The node
+// for the DN is created if it does not exist yet. An error is returned if the node cannot be looked up.
+func (l *LogTree) SetVerbosity(dn DN, level VerbosityLevel) error {
+	target, err := l.nodeByDN(dn)
+	if err == nil {
+		target.verbosity = level
+	}
+	return err
+}
+
+// logRaw is the callback of this node's LineBuffer, invoked any time a raw log line is completed. It wraps the line
+// into a new entry originating from this node's DN, appends it to the journal, and notifies all pertinent subscribers.
+func (n *node) logRaw(line *logbuffer.Line) {
+	j := n.tree.journal
+	e := &entry{origin: n.dn, raw: line}
+	j.append(e)
+	j.notify(e)
+}
+
+// logLeveled builds a LeveledPayload and entry for a given message, including all related metadata. It will create a
+// new entry, append it to the journal, and notify all pertinent subscribers.
+//
+// depth is the number of additional stack frames to skip when looking up the source location of the log call; with a
+// depth of 0 the location is that of the caller of the function which called logLeveled.
+func (n *node) logLeveled(depth int, severity Severity, msg string) {
+	_, file, line, ok := runtime.Caller(depth + 2)
+	if ok {
+		// Only keep the file name, not the directories leading to it.
+		if idx := strings.LastIndex(file, "/"); idx >= 0 {
+			file = file[idx+1:]
+		}
+	} else {
+		// Location unknown, use placeholders.
+		file, line = "???", 1
+	}
+
+	// Strip newlines from both ends of the message, then split it into lines.
+	trimmed := strings.Trim(msg, "\n")
+	messages := strings.Split(trimmed, "\n")
+
+	e := &entry{
+		origin: n.dn,
+		leveled: &LeveledPayload{
+			timestamp: time.Now(),
+			severity:  severity,
+			messages:  messages,
+			file:      file,
+			line:      line,
+		},
+	}
+	j := n.tree.journal
+	j.append(e)
+	j.notify(e)
+}
+
+// Info implements the LeveledLogger interface. The arguments are formatted as by fmt.Sprint and logged at INFO.
+func (n *node) Info(args ...interface{}) {
+	msg := fmt.Sprint(args...)
+	n.logLeveled(0, INFO, msg)
+}
+
+// Infof implements the LeveledLogger interface. The arguments are formatted as by fmt.Sprintf and logged at INFO.
+func (n *node) Infof(format string, args ...interface{}) {
+	msg := fmt.Sprintf(format, args...)
+	n.logLeveled(0, INFO, msg)
+}
+
+// Warning implements the LeveledLogger interface. The arguments are formatted as by fmt.Sprint and logged at WARNING.
+func (n *node) Warning(args ...interface{}) {
+	msg := fmt.Sprint(args...)
+	n.logLeveled(0, WARNING, msg)
+}
+
+// Warningf implements the LeveledLogger interface. The arguments are formatted as by fmt.Sprintf and logged at
+// WARNING.
+func (n *node) Warningf(format string, args ...interface{}) {
+	msg := fmt.Sprintf(format, args...)
+	n.logLeveled(0, WARNING, msg)
+}
+
+// Error implements the LeveledLogger interface. The arguments are formatted as by fmt.Sprint and logged at ERROR.
+func (n *node) Error(args ...interface{}) {
+	msg := fmt.Sprint(args...)
+	n.logLeveled(0, ERROR, msg)
+}
+
+// Errorf implements the LeveledLogger interface. The arguments are formatted as by fmt.Sprintf and logged at ERROR.
+func (n *node) Errorf(format string, args ...interface{}) {
+	msg := fmt.Sprintf(format, args...)
+	n.logLeveled(0, ERROR, msg)
+}
+
+// Fatal implements the LeveledLogger interface. The arguments are formatted as by fmt.Sprint and logged at FATAL. This
+// method itself only records the entry.
+func (n *node) Fatal(args ...interface{}) {
+	msg := fmt.Sprint(args...)
+	n.logLeveled(0, FATAL, msg)
+}
+
+// Fatalf implements the LeveledLogger interface. The arguments are formatted as by fmt.Sprintf and logged at FATAL.
+// This method itself only records the entry.
+func (n *node) Fatalf(format string, args ...interface{}) {
+	msg := fmt.Sprintf(format, args...)
+	n.logLeveled(0, FATAL, msg)
+}
+
+// V implements the LeveledLogger interface. The returned logger is enabled if the verbosity of this node is, at the
+// time of the call, at least v.
+func (n *node) V(v VerbosityLevel) VerboseLeveledLogger {
+	enabled := n.verbosity >= v
+	return &verbose{node: n, enabled: enabled}
+}
+
+// verbose implements the VerboseLeveledLogger interface. It is a thin wrapper around node, with an 'enabled' bool. This
+// means that V(n)-returned VerboseLeveledLoggers must be short lived, as a change in verbosity will not affect any
+// already existing VerboseLeveledLoggers.
+type verbose struct {
+ // node is the node through which enabled loggers publish their entries.
+ node *node
+ // enabled is whether this logger emits anything at all. It is computed once, when V is called.
+ enabled bool
+}
+
+// Enabled returns whether this logger was enabled at the time it was created, ie. whether its Info/Infof calls log
+// anything.
+func (v *verbose) Enabled() bool {
+	enabled := v.enabled
+	return enabled
+}
+
+// Info logs the arguments, formatted as by fmt.Sprint, at INFO through the underlying node - but only if this logger is
+// enabled. Otherwise it does nothing.
+func (v *verbose) Info(args ...interface{}) {
+	if v.enabled {
+		v.node.logLeveled(0, INFO, fmt.Sprint(args...))
+	}
+}
+
+// Infof logs the arguments, formatted as by fmt.Sprintf, at INFO through the underlying node - but only if this logger
+// is enabled. Otherwise it does nothing.
+func (v *verbose) Infof(format string, args ...interface{}) {
+	if v.enabled {
+		v.node.logLeveled(0, INFO, fmt.Sprintf(format, args...))
+	}
+}
diff --git a/metropolis/node/core/logtree/logtree_test.go b/metropolis/node/core/logtree/logtree_test.go
new file mode 100644
index 0000000..b900201
--- /dev/null
+++ b/metropolis/node/core/logtree/logtree_test.go
@@ -0,0 +1,211 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+)
+
+// expect reads the whole backlog of the given DN and all of its children from the tree and compares it against the
+// given entries. The test is failed immediately if the read fails or the number of backlog entries differs from the
+// number of wanted entries. Otherwise, a description of the first wanted entry that is missing from the backlog is
+// returned, or "" if all of them are present. Leveled entries are compared by their joined messages, raw entries by
+// their data.
+func expect(tree *LogTree, t *testing.T, dn DN, entries ...string) string {
+	t.Helper()
+	res, err := tree.Read(dn, WithChildren(), WithBacklog(BacklogAllAvailable))
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	want, have := len(entries), len(res.Backlog)
+	if want != have {
+		t.Fatalf("wanted %v backlog entries, got %v", want, have)
+	}
+
+	seen := make(map[string]struct{})
+	for _, e := range res.Backlog {
+		if e.Leveled != nil {
+			seen[e.Leveled.MessagesJoined()] = struct{}{}
+		}
+		if e.Raw != nil {
+			seen[e.Raw.Data] = struct{}{}
+		}
+	}
+	for _, wanted := range entries {
+		if _, ok := seen[wanted]; !ok {
+			return fmt.Sprintf("missing entry %q", wanted)
+		}
+	}
+	return ""
+}
+
+// TestMultiline checks that messages containing newlines are kept as a single entry, and that a trailing newline is
+// stripped from the message.
+func TestMultiline(t *testing.T) {
+	tree := New()
+	// Two lines in a single message.
+	tree.MustLeveledFor("main").Info("foo\nbar")
+	// Two lines in a single message with a hanging newline that should get stripped.
+	tree.MustLeveledFor("main").Info("one\ntwo\n")
+
+	res := expect(tree, t, "main", "foo\nbar", "one\ntwo")
+	if res != "" {
+		t.Errorf("retrieval at main failed: %s", res)
+	}
+}
+
+// TestBacklog logs leveled and raw entries at a handful of DNs and checks that reading the backlog recursively at
+// different points of the tree returns exactly the entries logged at or below that point.
+func TestBacklog(t *testing.T) {
+	tree := New()
+	tree.MustLeveledFor("main").Info("hello, main!")
+	tree.MustLeveledFor("main.foo").Info("hello, main.foo!")
+	tree.MustLeveledFor("main.bar").Info("hello, main.bar!")
+	tree.MustLeveledFor("aux").Info("hello, aux!")
+	// No newline at the last entry - shouldn't get propagated to the backlog.
+	fmt.Fprintf(tree.MustRawFor("aux.process"), "processing foo\nprocessing bar\nbaz")
+
+	for _, tc := range []struct {
+		name string
+		dn   DN
+		want []string
+	}{
+		{"main", "main", []string{"hello, main!", "hello, main.foo!", "hello, main.bar!"}},
+		{"root", "", []string{"hello, main!", "hello, main.foo!", "hello, main.bar!", "hello, aux!", "processing foo", "processing bar"}},
+		{"aux", "aux", []string{"hello, aux!", "processing foo", "processing bar"}},
+	} {
+		if res := expect(tree, t, tc.dn, tc.want...); res != "" {
+			t.Errorf("retrieval at %s failed: %s", tc.name, res)
+		}
+	}
+}
+
+// TestStream checks that a reader with both a backlog and a stream receives entries logged before the Read call in its
+// Backlog, and both leveled and raw entries logged after the Read call on its Stream.
+func TestStream(t *testing.T) {
+	tree := New()
+	tree.MustLeveledFor("main").Info("hello, backlog")
+	fmt.Fprintf(tree.MustRawFor("main.process"), "hello, raw backlog\n")
+
+	res, err := tree.Read("", WithBacklog(BacklogAllAvailable), WithChildren(), WithStream())
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	defer res.Close()
+	if want, got := 2, len(res.Backlog); want != got {
+		t.Errorf("wanted %d backlog item, got %d", want, got)
+	}
+
+	tree.MustLeveledFor("main").Info("hello, stream")
+	fmt.Fprintf(tree.MustRawFor("main.raw"), "hello, raw stream\n")
+
+	wanted := []string{"hello, stream", "hello, raw stream"}
+	entries := make(map[string]bool)
+	// The timeout only bounds the failure case: collection stops as soon as everything wanted has been seen, instead
+	// of always blocking for the full second.
+	timeout := time.After(time.Second * 1)
+collect:
+	for !(entries[wanted[0]] && entries[wanted[1]]) {
+		select {
+		case <-timeout:
+			break collect
+		case p := <-res.Stream:
+			if p.Leveled != nil {
+				entries[p.Leveled.MessagesJoined()] = true
+			}
+			if p.Raw != nil {
+				entries[p.Raw.Data] = true
+			}
+		}
+	}
+	for _, entry := range wanted {
+		if !entries[entry] {
+			t.Errorf("Missing entry %q", entry)
+		}
+	}
+}
+
+// TestVerbose checks that V(n) loggers only log once the verbosity of their DN has been raised to at least n.
+func TestVerbose(t *testing.T) {
+	tree := New()
+
+	tree.MustLeveledFor("main").V(10).Info("this shouldn't get logged")
+
+	reader, err := tree.Read("", WithBacklog(BacklogAllAvailable), WithChildren())
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	if want, got := 0, len(reader.Backlog); want != got {
+		t.Fatalf("expected nothing to be logged, got %+v", reader.Backlog)
+	}
+
+	// A failure to set the verbosity would otherwise only show up as a confusing entry count mismatch below.
+	if err := tree.SetVerbosity("main", 10); err != nil {
+		t.Fatalf("SetVerbosity: %v", err)
+	}
+	tree.MustLeveledFor("main").V(10).Info("this should get logged")
+
+	reader, err = tree.Read("", WithBacklog(BacklogAllAvailable), WithChildren())
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	if want, got := 1, len(reader.Backlog); want != got {
+		t.Fatalf("expected %d entries to get logged, got %d", want, got)
+	}
+}
+
+// TestMetadata checks that leveled entries come back from the backlog in the order they were logged, and carry the
+// severity, message and source file name they were logged with.
+func TestMetadata(t *testing.T) {
+	tree := New()
+	tree.MustLeveledFor("main").Error("i am an error")
+	tree.MustLeveledFor("main").Warning("i am a warning")
+	tree.MustLeveledFor("main").Info("i am informative")
+	tree.MustLeveledFor("main").V(0).Info("i am a zero-level debug")
+
+	reader, err := tree.Read("", WithChildren(), WithBacklog(BacklogAllAvailable))
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	if want, got := 4, len(reader.Backlog); want != got {
+		t.Fatalf("expected %d entries, got %d", want, got)
+	}
+
+	wanted := []struct {
+		severity Severity
+		message  string
+	}{
+		{ERROR, "i am an error"},
+		{WARNING, "i am a warning"},
+		{INFO, "i am informative"},
+		{INFO, "i am a zero-level debug"},
+	}
+	for ix, te := range wanted {
+		leveled := reader.Backlog[ix].Leveled
+		if got := leveled.Severity(); te.severity != got {
+			t.Errorf("wanted element %d to have severity %s, got %s", ix, te.severity, got)
+		}
+		if got := leveled.MessagesJoined(); te.message != got {
+			t.Errorf("wanted element %d to have message %q, got %q", ix, te.message, got)
+		}
+		file := strings.Split(leveled.Location(), ":")[0]
+		if file != "logtree_test.go" {
+			t.Errorf("wanted element %d to have file %q, got %q", ix, "logtree_test.go", file)
+		}
+	}
+}
+
+// TestSeverity checks that LeveledWithMinimumSeverity(WARNING) only lets the error and warning entries through, in
+// the order they were logged.
+func TestSeverity(t *testing.T) {
+	tree := New()
+	tree.MustLeveledFor("main").Error("i am an error")
+	tree.MustLeveledFor("main").Warning("i am a warning")
+	tree.MustLeveledFor("main").Info("i am informative")
+	tree.MustLeveledFor("main").V(0).Info("i am a zero-level debug")
+
+	reader, err := tree.Read("main", WithBacklog(BacklogAllAvailable), LeveledWithMinimumSeverity(WARNING))
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	if want, got := 2, len(reader.Backlog); want != got {
+		t.Fatalf("wanted %d entries, got %d", want, got)
+	}
+	for i, want := range []string{"i am an error", "i am a warning"} {
+		if got := reader.Backlog[i].Leveled.MessagesJoined(); want != got {
+			t.Fatalf("wanted entry %q, got %q", want, got)
+		}
+	}
+}
diff --git a/metropolis/node/core/main.go b/metropolis/node/core/main.go
new file mode 100644
index 0000000..54d09a4
--- /dev/null
+++ b/metropolis/node/core/main.go
@@ -0,0 +1,321 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "context"
+ "crypto/ed25519"
+ "crypto/rand"
+ "crypto/x509"
+ "fmt"
+ "log"
+ "math/big"
+ "net"
+ "os"
+ "os/signal"
+ "runtime/debug"
+
+ "golang.org/x/sys/unix"
+ "google.golang.org/grpc"
+
+ common "git.monogon.dev/source/nexantic.git/metropolis/node"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/cluster"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dns"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/tpm"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/containerd"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/pki"
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+var (
+ // kubernetesConfig is the static/global part of the Kubernetes service configuration. In the future, this might
+ // be configurable by loading it from the EnrolmentConfig. For now, it's static and same across all clusters.
+ // The remaining fields (KPKI, Root, AdvertiseAddress, CorednsRegistrationChan) are filled in by main() right
+ // before the Kubernetes service is started.
+ kubernetesConfig = kubernetes.Config{
+ ServiceIPRange: net.IPNet{ // TODO(q3k): Decide if configurable / final value
+ IP: net.IP{10, 0, 255, 1},
+ Mask: net.IPMask{0xff, 0xff, 0xff, 0x00}, // /24, but Go stores as a literal mask
+ },
+ ClusterNet: net.IPNet{
+ IP: net.IP{10, 0, 0, 0},
+ Mask: net.IPMask{0xff, 0xff, 0x00, 0x00}, // /16
+ },
+ }
+)
+
+// main is the entry point of the node's init process. It:
+//   - sets up the logtree and streams all log entries to stderr,
+//   - switches root (re-executing if needed) and raises the open file limit,
+//   - initializes the TPM, the network service and local storage,
+//   - starts a supervisor whose root runnable brings up storage, networking,
+//     the cluster manager and (depending on node roles) Kubernetes, containerd
+//     and the node debug gRPC service,
+//   - then loops forever handling signals (reaping children on SIGCHLD) and
+//     halts with an error banner if the trapdoor channel gets closed.
+//
+// When main returns or panics, the deferred handler syncs filesystems and
+// powers the machine off.
+func main() {
+	defer func() {
+		if r := recover(); r != nil {
+			fmt.Println("Init panicked:", r)
+			debug.PrintStack()
+		}
+		unix.Sync()
+		// TODO(lorenz): Switch this to Reboot when init panics are less likely
+		// Best effort, nothing we can do if this fails except printing the error to the console.
+		if err := unix.Reboot(unix.LINUX_REBOOT_CMD_POWER_OFF); err != nil {
+			panic(fmt.Sprintf("failed to halt node: %v\n", err))
+		}
+	}()
+
+	// Set up logger for Smalltown. Currently logs everything to stderr.
+	lt := logtree.New()
+	reader, err := lt.Read("", logtree.WithChildren(), logtree.WithStream())
+	if err != nil {
+		panic(fmt.Errorf("could not set up root log reader: %w", err))
+	}
+	go func() {
+		for {
+			p := <-reader.Stream
+			fmt.Fprintf(os.Stderr, "%s\n", p.String())
+		}
+	}()
+
+	// Initial logger. Used until we get to a supervisor.
+	logger := lt.MustLeveledFor("init")
+
+	// Remount onto a tmpfs and re-exec if needed. Otherwise, keep running.
+	err = switchRoot(logger)
+	if err != nil {
+		panic(fmt.Errorf("could not remount root: %w", err))
+	}
+
+	// Linux kernel default is 4096 which is far too low. Raise it to 1M which is what gVisor suggests.
+	if err := unix.Setrlimit(unix.RLIMIT_NOFILE, &unix.Rlimit{Cur: 1048576, Max: 1048576}); err != nil {
+		logger.Fatalf("Failed to raise rlimits: %v", err)
+	}
+
+	logger.Info("Starting Smalltown Init")
+
+	// No signal list is given to Notify, so all incoming signals are relayed to signalChannel.
+	signalChannel := make(chan os.Signal, 2)
+	signal.Notify(signalChannel)
+
+	if err := tpm.Initialize(logger); err != nil {
+		logger.Fatalf("Failed to initialize TPM 2.0: %v", err)
+	}
+
+	corednsRegistrationChan := make(chan *dns.ExtraDirective)
+
+	networkSvc := network.New(network.Config{CorednsRegistrationChan: corednsRegistrationChan})
+
+	// This function initializes a headless Delve if this is a debug build or does nothing if it's not
+	initializeDebugger(networkSvc)
+
+	// Prepare local storage.
+	root := &localstorage.Root{}
+	if err := declarative.PlaceFS(root, "/"); err != nil {
+		panic(fmt.Errorf("when placing root FS: %w", err))
+	}
+
+	// trapdoor is a channel used to signal to the init service that a very low-level, unrecoverable failure
+	// occurred. This causes a GURU MEDITATION ERROR visible to the end user.
+	trapdoor := make(chan struct{})
+
+	// Make context for supervisor. We cancel it when we reach the trapdoor.
+	ctxS, ctxC := context.WithCancel(context.Background())
+
+	// Start root initialization code as a supervisor one-shot runnable. This means waiting for the network, starting
+	// the cluster manager, and then starting all services related to the node's roles.
+	// TODO(q3k): move this to a separate 'init' service.
+	supervisor.New(ctxS, func(ctx context.Context) error {
+		logger := supervisor.Logger(ctx)
+
+		// Start storage and network - we need this to get anything else done.
+		if err := root.Start(ctx); err != nil {
+			return fmt.Errorf("cannot start root FS: %w", err)
+		}
+		if err := supervisor.Run(ctx, "network", networkSvc.Run); err != nil {
+			return fmt.Errorf("when starting network: %w", err)
+		}
+
+		// Wait for IP address from network.
+		ip, err := networkSvc.GetIP(ctx, true)
+		if err != nil {
+			return fmt.Errorf("when waiting for IP address: %w", err)
+		}
+
+		// Start cluster manager. This kicks off cluster membership machinery, which will either start
+		// a new cluster, enroll into one or join one.
+		m := cluster.NewManager(root, networkSvc)
+		if err := supervisor.Run(ctx, "enrolment", m.Run); err != nil {
+			return fmt.Errorf("when starting enrolment: %w", err)
+		}
+
+		// Wait until the cluster manager settles.
+		success := m.WaitFinished()
+		if !success {
+			close(trapdoor)
+			return fmt.Errorf("enrolment failed, aborting")
+		}
+
+		// We are now in a cluster. We can thus access our 'node' object and start all services that
+		// we should be running.
+
+		node := m.Node()
+		if err := node.ConfigureLocalHostname(&root.Etc); err != nil {
+			close(trapdoor)
+			return fmt.Errorf("failed to set local hostname: %w", err)
+		}
+
+		logger.Info("Enrolment success, continuing startup.")
+		logger.Info(fmt.Sprintf("This node (%s) has roles:", node.String()))
+		if cm := node.ConsensusMember(); cm != nil {
+			// There's no need to start anything for when we are a consensus member - the cluster
+			// manager does this for us if necessary (as creating/enrolling/joining a cluster is
+			// pretty tied into cluster lifecycle management).
+			logger.Info(" - etcd consensus member")
+		}
+		if kw := node.KubernetesWorker(); kw != nil {
+			logger.Info(" - kubernetes worker")
+		}
+
+		// If we're supposed to be a kubernetes worker, start kubernetes services and containerd.
+		// In the future, this might be split further into kubernetes control plane and data plane
+		// roles.
+		var containerdSvc *containerd.Service
+		var kubeSvc *kubernetes.Service
+		if kw := node.KubernetesWorker(); kw != nil {
+			logger.Info("Starting Kubernetes worker services...")
+
+			// Ensure Kubernetes PKI objects exist in etcd.
+			kpkiKV := m.ConsensusKV("cluster", "kpki")
+			kpki := pki.NewKubernetes(lt.MustLeveledFor("pki.kubernetes"), kpkiKV)
+			if err := kpki.EnsureAll(ctx); err != nil {
+				return fmt.Errorf("failed to ensure kubernetes PKI present: %w", err)
+			}
+
+			containerdSvc = &containerd.Service{
+				EphemeralVolume: &root.Ephemeral.Containerd,
+			}
+			if err := supervisor.Run(ctx, "containerd", containerdSvc.Run); err != nil {
+				return fmt.Errorf("failed to start containerd service: %w", err)
+			}
+
+			// Complete the static Kubernetes configuration with the runtime-dependent parts.
+			kubernetesConfig.KPKI = kpki
+			kubernetesConfig.Root = root
+			kubernetesConfig.AdvertiseAddress = *ip
+			kubernetesConfig.CorednsRegistrationChan = corednsRegistrationChan
+			kubeSvc = kubernetes.New(kubernetesConfig)
+			if err := supervisor.Run(ctx, "kubernetes", kubeSvc.Run); err != nil {
+				return fmt.Errorf("failed to start kubernetes service: %w", err)
+			}
+
+		}
+
+		// Start the node debug service. kubeSvc stays nil when this node is not a Kubernetes worker.
+		dbg := &debugService{
+			cluster:    m,
+			logtree:    lt,
+			kubernetes: kubeSvc,
+		}
+		dbgSrv := grpc.NewServer()
+		apb.RegisterNodeDebugServiceServer(dbgSrv, dbg)
+		dbgLis, err := net.Listen("tcp", fmt.Sprintf(":%d", common.DebugServicePort))
+		if err != nil {
+			return fmt.Errorf("failed to listen on debug service: %w", err)
+		}
+		if err := supervisor.Run(ctx, "debug", supervisor.GRPCServer(dbgSrv, dbgLis, false)); err != nil {
+			return fmt.Errorf("failed to start debug service: %w", err)
+		}
+
+		supervisor.Signal(ctx, supervisor.SignalHealthy)
+		supervisor.Signal(ctx, supervisor.SignalDone)
+		return nil
+	}, supervisor.WithExistingLogtree(lt))
+
+	// We're PID1, so orphaned processes get reparented to us to clean up
+	for {
+		select {
+		case <-trapdoor:
+			// If the trapdoor got closed, we got stuck early enough in the boot process that we can't do anything about
+			// it. Display a generic error message until we handle error conditions better.
+			ctxC()
+			log.Printf("                  ########################")
+			log.Printf("                  # GURU MEDITATION ERROR #")
+			log.Printf("                  ########################")
+			log.Printf("")
+			log.Printf("Smalltown encountered an uncorrectable error and must be restarted.")
+			log.Printf("(Error condition: init trapdoor closed)")
+			log.Printf("")
+			// Block forever; the node stays up showing the message above.
+			select {}
+
+		case sig := <-signalChannel:
+			switch sig {
+			case unix.SIGCHLD:
+				var status unix.WaitStatus
+				var rusage unix.Rusage
+				// Reap every child that has already exited, without blocking.
+				for {
+					res, err := unix.Wait4(-1, &status, unix.WNOHANG, &rusage)
+					if err != nil && err != unix.ECHILD {
+						logger.Errorf("Failed to wait on orphaned child: %v", err)
+						break
+					}
+					if res <= 0 {
+						break
+					}
+				}
+			case unix.SIGURG:
+				// Go 1.14 introduced asynchronous preemption, which uses SIGURG.
+				// In order not to break backwards compatibility in the unlikely case
+				// of an application actually using SIGURG on its own, they're not filtering them.
+				// (https://github.com/golang/go/issues/37942)
+				logger.V(5).Info("Ignoring SIGURG")
+			// TODO(lorenz): We can probably get more than just SIGCHLD as init, but I can't think
+			// of any others right now, just log them in case we hit any of them.
+			default:
+				logger.Warningf("Got unexpected signal %s", sig.String())
+			}
+		}
+	}
+}
+
+// nodeCertificate creates a node key/certificate for a foreign node. This is duplicated code with localstorage's
+// PKIDirectory EnsureSelfSigned, but is temporary (and specific to 'golden tickets').
+//
+// It generates a fresh ed25519 keypair, builds a certificate template through localstorage.CertificateForNode,
+// assigns it a random serial number below 2^127 and self-signs it (the template is used as its own parent).
+//
+// On success, cert is the DER-encoded certificate and key is the PKCS#8 DER encoding of the private key. On
+// failure both are nil, so private key material is never handed out together with an error.
+func (s *debugService) nodeCertificate() (cert, key []byte, err error) {
+	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to generate key: %w", err)
+	}
+
+	keyDER, err := x509.MarshalPKCS8PrivateKey(privKey)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to marshal key: %w", err)
+	}
+
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 127)
+	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
+	}
+
+	template := localstorage.CertificateForNode(pubKey)
+	template.SerialNumber = serialNumber
+
+	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, pubKey, privKey)
+	if err != nil {
+		return nil, nil, fmt.Errorf("could not sign certificate: %w", err)
+	}
+	return certDER, keyDER, nil
+}
diff --git a/metropolis/node/core/network/BUILD.bazel b/metropolis/node/core/network/BUILD.bazel
new file mode 100644
index 0000000..9ba56a9
--- /dev/null
+++ b/metropolis/node/core/network/BUILD.bazel
@@ -0,0 +1,20 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Go library target for the node's network package (single source file: main.go).
+# Visible to every package of this repository.
+go_library(
+ name = "go_default_library",
+ srcs = ["main.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/network",
+ visibility = ["//:__subpackages__"],
+ deps = [
+ "//metropolis/node/common/supervisor:go_default_library",
+ "//metropolis/node/core/logtree:go_default_library",
+ "//metropolis/node/core/network/dhcp4c:go_default_library",
+ "//metropolis/node/core/network/dhcp4c/callback:go_default_library",
+ "//metropolis/node/core/network/dns:go_default_library",
+ "@com_github_google_nftables//:go_default_library",
+ "@com_github_google_nftables//expr:go_default_library",
+ "@com_github_insomniacslk_dhcp//dhcpv4:go_default_library",
+ "@com_github_vishvananda_netlink//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/network/dhcp4c/BUILD.bazel b/metropolis/node/core/network/dhcp4c/BUILD.bazel
new file mode 100644
index 0000000..19b4c70
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/BUILD.bazel
@@ -0,0 +1,35 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+# Go library target for the dhcp4c package (the DHCPv4 client). Publicly visible.
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "dhcpc.go",
+ "doc.go",
+ "lease.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dhcp4c",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//metropolis/node/common/supervisor:go_default_library",
+ "//metropolis/node/core/network/dhcp4c/transport:go_default_library",
+ "@com_github_cenkalti_backoff_v4//:go_default_library",
+ "@com_github_insomniacslk_dhcp//dhcpv4:go_default_library",
+ "@com_github_insomniacslk_dhcp//iana:go_default_library",
+ ],
+)
+
+# Unit tests for the dhcp4c package. The library is embedded so the tests are
+# compiled into the same package; built in pure mode.
+go_test(
+ name = "go_default_test",
+ srcs = [
+ "dhcpc_test.go",
+ "lease_test.go",
+ ],
+ embed = [":go_default_library"],
+ pure = "on",
+ deps = [
+ "//metropolis/node/core/network/dhcp4c/transport:go_default_library",
+ "@com_github_cenkalti_backoff_v4//:go_default_library",
+ "@com_github_insomniacslk_dhcp//dhcpv4:go_default_library",
+ "@com_github_stretchr_testify//assert:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/network/dhcp4c/callback/BUILD.bazel b/metropolis/node/core/network/dhcp4c/callback/BUILD.bazel
new file mode 100644
index 0000000..ed6f330
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/callback/BUILD.bazel
@@ -0,0 +1,36 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//metropolis/test/ktest:ktest.bzl", "ktest")
+
+# Go library target for the dhcp4c callback package (single source file:
+# callback.go). Publicly visible.
+go_library(
+ name = "go_default_library",
+ srcs = ["callback.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dhcp4c/callback",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//metropolis/node/core/network/dhcp4c:go_default_library",
+ "@com_github_insomniacslk_dhcp//dhcpv4:go_default_library",
+ "@com_github_vishvananda_netlink//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+# Tests for the callback package, built in pure mode. The tests skip themselves
+# unless the IN_KTEST environment variable is "true"; this binary is also used
+# as the tester of the ktest rule in this file.
+go_test(
+ name = "go_default_test",
+ srcs = ["callback_test.go"],
+ embed = [":go_default_library"],
+ pure = "on",
+ deps = [
+ "//metropolis/node/core/network/dhcp4c:go_default_library",
+ "@com_github_insomniacslk_dhcp//dhcpv4:go_default_library",
+ "@com_github_stretchr_testify//require:go_default_library",
+ "@com_github_vishvananda_netlink//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+# Runs :go_default_test through the ktest macro (loaded from
+# //metropolis/test/ktest:ktest.bzl) with no extra kernel command line, initramfs
+# content or dependencies.
+# NOTE(review): presumably ktest executes the tester inside a test kernel/VM so
+# that the netlink calls operate on a real network stack - confirm in ktest.bzl.
+ktest(
+ cmdline = "",
+ initramfs_extra = "",
+ tester = ":go_default_test",
+ deps = [],
+)
diff --git a/metropolis/node/core/network/dhcp4c/callback/callback.go b/metropolis/node/core/network/dhcp4c/callback/callback.go
new file mode 100644
index 0000000..10eb6ba
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/callback/callback.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package callback contains minimal callbacks for configuring the kernel with options received over DHCP.
+//
+// These directly configure the relevant kernel subsystems and need to own certain parts of them as documented on a per-
+// callback basis to make sure that they can recover from restarts and crashes of the DHCP client.
+// The callbacks in here are not suitable for use in advanced network scenarios like running multiple DHCP clients
+// per interface via ClientIdentifier or when running an external FIB manager. In these cases it's advised to extract
+// the necessary information from the lease in your own callback and communicate it directly to the responsible entity.
+package callback
+
+import (
+ "fmt"
+ "math"
+ "net"
+ "os"
+ "time"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dhcp4c"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/vishvananda/netlink"
+ "golang.org/x/sys/unix"
+)
+
+// Compose chains multiple callbacks into a single one. The resulting callback
+// invokes the given callbacks in order with the same (old, new) lease pair and
+// stops at the first one that fails, returning its error. It returns nil if
+// all callbacks succeed (or if none were given).
+func Compose(callbacks ...dhcp4c.LeaseCallback) dhcp4c.LeaseCallback {
+	chained := func(old, new *dhcp4c.Lease) error {
+		for i := range callbacks {
+			err := callbacks[i](old, new)
+			if err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+	return chained
+}
+
+// isIPNetEqual reports whether a and b describe the same network: both nil (or
+// the same pointer), or both non-nil with equal IPs (as per net.IP.Equal) and
+// masks of the same prefix length and total bit size.
+func isIPNetEqual(a, b *net.IPNet) bool {
+	switch {
+	case a == b:
+		// Same pointer, which includes the case of both being nil.
+		return true
+	case a == nil, b == nil:
+		// Exactly one side is nil.
+		return false
+	}
+	if !a.IP.Equal(b.IP) {
+		return false
+	}
+	onesA, bitsA := a.Mask.Size()
+	onesB, bitsB := b.Mask.Size()
+	return onesA == onesB && bitsA == bitsB
+}
+
+// ManageIP returns a lease callback which sets up and tears down the assigned IP address on iface. It takes
+// exclusive ownership of all IPv4 addresses on the given interface which do not have IFA_F_PERMANENT set, so it's
+// not possible to run multiple dynamic addressing clients on a single interface.
+//
+// On every invocation all non-permanent IPv4 addresses other than the one of the new lease are deleted. If a new
+// lease is present, its address is then installed (or refreshed) with lifetimes matching the remaining lease time.
+func ManageIP(iface netlink.Link) dhcp4c.LeaseCallback {
+	return func(old, new *dhcp4c.Lease) error {
+		newNet := new.IPNet()
+
+		existing, err := netlink.AddrList(iface, netlink.FAMILY_V4)
+		if err != nil {
+			return fmt.Errorf("netlink failed to list addresses: %w", err)
+		}
+
+		for _, addr := range existing {
+			// Permanent addresses are not ours to manage.
+			if addr.Flags&unix.IFA_F_PERMANENT != 0 {
+				continue
+			}
+			// Linux identifies addresses by IP, mask and peer (see net/ipv4/devinet.find_matching_ifa in Linux 5.10)
+			// So don't touch addresses which match on these properties as AddrReplace will atomically reconfigure
+			// them anyways without interrupting things.
+			keep := new != nil && addr.Peer == nil && isIPNetEqual(addr.IPNet, newNet)
+			if keep {
+				continue
+			}
+
+			// An address which is already gone is not treated as an error.
+			if err := netlink.AddrDel(iface, &addr); err != nil && !os.IsNotExist(err) {
+				return fmt.Errorf("failed to delete address: %w", err)
+			}
+		}
+
+		if new == nil {
+			return nil
+		}
+
+		// Lifetime in whole seconds, rounded up.
+		lifetimeSecs := int(math.Ceil(time.Until(new.ExpiresAt).Seconds()))
+		replacement := &netlink.Addr{
+			IPNet:       newNet,
+			ValidLft:    lifetimeSecs,
+			PreferedLft: lifetimeSecs,
+			Broadcast:   dhcpv4.GetIP(dhcpv4.OptionBroadcastAddress, new.Options),
+		}
+		if err := netlink.AddrReplace(iface, replacement); err != nil {
+			return fmt.Errorf("failed to update address: %w", err)
+		}
+		return nil
+	}
+}
+
+// ManageDefaultRoute returns a lease callback which manages a default route through the first router offered by
+// DHCP. It does nothing if DHCP doesn't provide any routers. It takes ownership of all RTPROTO_DHCP routes on the
+// given interface, so it's not possible to run multiple DHCP clients on the given interface.
+//
+// If the new lease carries no router, every RTPROT_DHCP route on the interface is deleted. If it does carry one,
+// a default route via that router (with the lease's assigned IP as source) is installed using RouteReplace.
+func ManageDefaultRoute(iface netlink.Link) dhcp4c.LeaseCallback {
+	return func(old, new *dhcp4c.Lease) error {
+		newRouter := new.Router()
+
+		// Only consider routes on our interface which were installed with the DHCP protocol marker.
+		filter := &netlink.Route{
+			Protocol:  unix.RTPROT_DHCP,
+			LinkIndex: iface.Attrs().Index,
+		}
+		dhcpRoutes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, filter, netlink.RT_FILTER_OIF|netlink.RT_FILTER_PROTOCOL)
+		if err != nil {
+			return fmt.Errorf("netlink failed to list routes: %w", err)
+		}
+
+		defaultDst := net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
+		for _, route := range dhcpRoutes {
+			// Don't remove routes which can be atomically replaced by RouteReplace to prevent potential traffic
+			// disruptions.
+			// NOTE(review): as written, this skips every route whose Dst does not compare equal to 0.0.0.0/0
+			// whenever a new router is present. This presumably relies on netlink reporting a nil Dst for
+			// default routes - confirm that this matches the intent described above.
+			if newRouter != nil && !isIPNetEqual(&defaultDst, route.Dst) {
+				continue
+			}
+			// A route which is already gone is not treated as an error.
+			if err := netlink.RouteDel(&route); err != nil && !os.IsNotExist(err) {
+				return fmt.Errorf("failed to delete DHCP route: %w", err)
+			}
+		}
+
+		if newRouter == nil {
+			return nil
+		}
+
+		defaultRoute := &netlink.Route{
+			Protocol:  unix.RTPROT_DHCP,
+			Dst:       &defaultDst,
+			Gw:        newRouter,
+			Src:       new.AssignedIP,
+			LinkIndex: iface.Attrs().Index,
+			Scope:     netlink.SCOPE_UNIVERSE,
+		}
+		if err := netlink.RouteReplace(defaultRoute); err != nil {
+			return fmt.Errorf("failed to add default route via %s: %w", newRouter, err)
+		}
+		return nil
+	}
+}
diff --git a/metropolis/node/core/network/dhcp4c/callback/callback_test.go b/metropolis/node/core/network/dhcp4c/callback/callback_test.go
new file mode 100644
index 0000000..d533044
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/callback/callback_test.go
@@ -0,0 +1,313 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package callback
+
+import (
+ "fmt"
+ "math"
+ "net"
+ "os"
+ "testing"
+ "time"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dhcp4c"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/stretchr/testify/require"
+ "github.com/vishvananda/netlink"
+ "golang.org/x/sys/unix"
+)
+
+// trivialLeaseFromNet builds a minimal lease for tests: the assigned IP and the
+// subnet mask option are taken from ipnet, and the lease expires one second
+// from now.
+func trivialLeaseFromNet(ipnet net.IPNet) *dhcp4c.Lease {
+	options := dhcpv4.Options{}
+	options.Update(dhcpv4.OptSubnetMask(ipnet.Mask))
+	lease := dhcp4c.Lease{
+		AssignedIP: ipnet.IP,
+		ExpiresAt:  time.Now().Add(1 * time.Second),
+		Options:    options,
+	}
+	return &lease
+}
+
+// Fixtures shared by the tests in this file: two /24 test networks (each given
+// as the address assigned to the test interface) together with the broadcast
+// and router addresses belonging to them.
+var (
+	testNet1          = net.IPNet{IP: net.IP{10, 0, 1, 2}, Mask: net.CIDRMask(24, 32)}
+	testNet1Broadcast = net.IP{10, 0, 1, 255}
+	testNet1Router    = net.IP{10, 0, 1, 1}
+	testNet2          = net.IPNet{IP: net.IP{10, 0, 2, 2}, Mask: net.CIDRMask(24, 32)}
+	testNet2Broadcast = net.IP{10, 0, 2, 255}
+	testNet2Router    = net.IP{10, 0, 2, 1}
+	// mainRoutingTable is the table (254) Linux automatically puts all routes into unless specified otherwise.
+	mainRoutingTable = unix.RT_TABLE_MAIN
+)
+
+// TestAssignedIPCallback exercises ManageIP against real dummy network
+// interfaces. It only runs when IN_KTEST is "true". Each case creates its own
+// dummy link, installs the initial addresses, runs the callback with the given
+// lease transition and compares the IPv4 addresses found on the link afterwards
+// with the expected ones.
+func TestAssignedIPCallback(t *testing.T) {
+	if os.Getenv("IN_KTEST") != "true" {
+		t.Skip("Not in ktest")
+	}
+
+	type testCase struct {
+		name               string
+		initialAddrs       []netlink.Addr
+		oldLease, newLease *dhcp4c.Lease
+		expectedAddrs      []netlink.Addr
+	}
+	cases := []testCase{
+		{ // Lifetimes are necessary, otherwise the Kernel sets the IFA_F_PERMANENT flag behind our back
+			name:          "RemoveOldIPs",
+			initialAddrs:  []netlink.Addr{{IPNet: &testNet1, ValidLft: 60}, {IPNet: &testNet2, ValidLft: 60}},
+			oldLease:      nil,
+			newLease:      nil,
+			expectedAddrs: nil,
+		},
+		{
+			name:         "IgnoresPermanentIPs",
+			initialAddrs: []netlink.Addr{{IPNet: &testNet1, Flags: unix.IFA_F_PERMANENT}, {IPNet: &testNet2, ValidLft: 60}},
+			oldLease:     nil,
+			newLease:     trivialLeaseFromNet(testNet2),
+			expectedAddrs: []netlink.Addr{
+				{IPNet: &testNet1, Flags: unix.IFA_F_PERMANENT, ValidLft: math.MaxUint32, PreferedLft: math.MaxUint32, Broadcast: testNet1Broadcast},
+				{IPNet: &testNet2, ValidLft: 1, PreferedLft: 1, Broadcast: testNet2Broadcast},
+			},
+		},
+		{
+			name:         "AssignsNewIP",
+			initialAddrs: []netlink.Addr{},
+			oldLease:     nil,
+			newLease:     trivialLeaseFromNet(testNet2),
+			expectedAddrs: []netlink.Addr{
+				{IPNet: &testNet2, ValidLft: 1, PreferedLft: 1, Broadcast: testNet2Broadcast},
+			},
+		},
+		{
+			name:         "UpdatesIP",
+			initialAddrs: []netlink.Addr{},
+			oldLease:     trivialLeaseFromNet(testNet2),
+			newLease:     trivialLeaseFromNet(testNet1),
+			expectedAddrs: []netlink.Addr{
+				{IPNet: &testNet1, ValidLft: 1, PreferedLft: 1, Broadcast: testNet1Broadcast},
+			},
+		},
+		{
+			name:          "RemovesIPOnRelease",
+			initialAddrs:  []netlink.Addr{{IPNet: &testNet1, ValidLft: 60, PreferedLft: 60}},
+			oldLease:      trivialLeaseFromNet(testNet1),
+			newLease:      nil,
+			expectedAddrs: nil,
+		},
+	}
+
+	for idx, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Every case gets its own, uniquely named dummy interface.
+			link := &netlink.Dummy{
+				LinkAttrs: netlink.LinkAttrs{
+					Name:  fmt.Sprintf("aipcb-test-%d", idx),
+					Flags: unix.IFF_UP,
+				},
+			}
+			if err := netlink.LinkAdd(link); err != nil {
+				t.Fatalf("test cannot set up network interface: %v", err)
+			}
+			defer netlink.LinkDel(link)
+
+			for _, initial := range tc.initialAddrs {
+				if err := netlink.AddrAdd(link, &initial); err != nil {
+					t.Fatalf("test cannot set up initial addrs: %v", err)
+				}
+			}
+			// Associate dynamically-generated interface name for later comparison
+			for j := range tc.expectedAddrs {
+				tc.expectedAddrs[j].Label = link.Name
+			}
+
+			callback := ManageIP(link)
+			if err := callback(tc.oldLease, tc.newLease); err != nil {
+				t.Fatalf("callback returned an error: %v", err)
+			}
+
+			actual, err := netlink.AddrList(link, netlink.FAMILY_V4)
+			if err != nil {
+				t.Fatalf("test cannot read back addrs from interface: %v", err)
+			}
+			require.Equal(t, tc.expectedAddrs, actual, "Wrong IPs on interface")
+		})
+	}
+}
+
+// leaseAddRouter adds a router option for the given router to the lease's options. The lease is modified in place
+// and returned to allow chaining with trivialLeaseFromNet.
+func leaseAddRouter(lease *dhcp4c.Lease, router net.IP) *dhcp4c.Lease {
+ lease.Options.Update(dhcpv4.OptRouter(router))
+ return lease
+}
+
+// TestDefaultRouteCallback exercises ManageDefaultRoute against real dummy
+// network interfaces. It only runs when IN_KTEST is "true". Each case creates
+// its own dummy link, installs the initial routes, runs the callback with the
+// given lease transition and compares all non-kernel IPv4 routes present
+// afterwards with the expected ones.
+func TestDefaultRouteCallback(t *testing.T) {
+	if os.Getenv("IN_KTEST") != "true" {
+		t.Skip("Not in ktest")
+	}
+	// testRoute is only used as a route destination and not configured on any interface.
+	testRoute := net.IPNet{IP: net.IP{10, 0, 3, 0}, Mask: net.CIDRMask(24, 32)}
+
+	// A test interface is set up for each test and assigned testNet1 and testNet2 so that testNet1Router and
+	// testNet2Router are valid gateways for routes in this environment. A LinkIndex of -1 is replaced by the correct
+	// link index for this test interface at runtime for both initialRoutes and expectedRoutes.
+	type testCase struct {
+		name               string
+		initialRoutes      []netlink.Route
+		oldLease, newLease *dhcp4c.Lease
+		expectedRoutes     []netlink.Route
+	}
+	cases := []testCase{
+		{
+			name:          "AddsDefaultRoute",
+			initialRoutes: []netlink.Route{},
+			oldLease:      nil,
+			newLease:      leaseAddRouter(trivialLeaseFromNet(testNet1), testNet1Router),
+			expectedRoutes: []netlink.Route{{
+				Protocol:  unix.RTPROT_DHCP,
+				Dst:       nil, // Linux weirdly returns no RTA_DST for default routes, but one for everything else
+				Gw:        testNet1Router,
+				Src:       testNet1.IP,
+				Table:     mainRoutingTable,
+				LinkIndex: -1, // Filled in dynamically with test interface
+				Type:      unix.RTN_UNICAST,
+			}},
+		},
+		{
+			name:           "IgnoresLeasesWithoutRouter",
+			initialRoutes:  []netlink.Route{},
+			oldLease:       nil,
+			newLease:       trivialLeaseFromNet(testNet1),
+			expectedRoutes: nil,
+		},
+		{
+			name: "RemovesUnrelatedOldRoutes",
+			initialRoutes: []netlink.Route{{
+				Dst:       &testRoute,
+				LinkIndex: -1, // Filled in dynamically with test interface
+				Protocol:  unix.RTPROT_DHCP,
+				Gw:        testNet2Router,
+				Scope:     netlink.SCOPE_UNIVERSE,
+			}},
+			oldLease:       nil,
+			newLease:       nil,
+			expectedRoutes: nil,
+		},
+		{
+			name: "IgnoresNonDHCPRoutes",
+			initialRoutes: []netlink.Route{{
+				Dst:       &testRoute,
+				LinkIndex: -1, // Filled in dynamically with test interface
+				Protocol:  unix.RTPROT_BIRD,
+				Gw:        testNet2Router,
+			}},
+			oldLease: trivialLeaseFromNet(testNet1),
+			newLease: nil,
+			expectedRoutes: []netlink.Route{{
+				Protocol:  unix.RTPROT_BIRD,
+				Dst:       &testRoute,
+				Gw:        testNet2Router,
+				Table:     mainRoutingTable,
+				LinkIndex: -1, // Filled in dynamically with test interface
+				Type:      unix.RTN_UNICAST,
+			}},
+		},
+		{
+			name: "RemovesRoute",
+			initialRoutes: []netlink.Route{{
+				Dst:       nil,
+				LinkIndex: -1, // Filled in dynamically with test interface
+				Protocol:  unix.RTPROT_DHCP,
+				Gw:        testNet2Router,
+			}},
+			oldLease:       leaseAddRouter(trivialLeaseFromNet(testNet2), testNet2Router),
+			newLease:       nil,
+			expectedRoutes: nil,
+		},
+		{
+			name: "UpdatesRoute",
+			initialRoutes: []netlink.Route{{
+				Dst:       nil,
+				LinkIndex: -1, // Filled in dynamically with test interface
+				Protocol:  unix.RTPROT_DHCP,
+				Src:       testNet1.IP,
+				Gw:        testNet1Router,
+			}},
+			oldLease: leaseAddRouter(trivialLeaseFromNet(testNet1), testNet1Router),
+			newLease: leaseAddRouter(trivialLeaseFromNet(testNet2), testNet2Router),
+			expectedRoutes: []netlink.Route{{
+				Protocol:  unix.RTPROT_DHCP,
+				Dst:       nil,
+				Gw:        testNet2Router,
+				Src:       testNet2.IP,
+				Table:     mainRoutingTable,
+				LinkIndex: -1, // Filled in dynamically with test interface
+				Type:      unix.RTN_UNICAST,
+			}},
+		},
+	}
+
+	for idx, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Every case gets its own, uniquely named dummy interface.
+			link := &netlink.Dummy{
+				LinkAttrs: netlink.LinkAttrs{
+					Name:  fmt.Sprintf("drcb-test-%d", idx),
+					Flags: unix.IFF_UP,
+				},
+			}
+			if err := netlink.LinkAdd(link); err != nil {
+				t.Fatalf("test cannot set up network interface: %v", err)
+			}
+			defer func() { // Clean up after each test
+				leftover, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{}, 0)
+				if err != nil {
+					return
+				}
+				for _, route := range leftover {
+					netlink.RouteDel(&route)
+				}
+			}()
+			defer netlink.LinkDel(link)
+
+			// Assign both test networks so that both test routers are reachable gateways.
+			for _, ipnet := range []*net.IPNet{&testNet1, &testNet2} {
+				if err := netlink.AddrAdd(link, &netlink.Addr{IPNet: ipnet}); err != nil {
+					t.Fatalf("test cannot set up test addrs: %v", err)
+				}
+			}
+
+			for _, initial := range tc.initialRoutes {
+				if initial.LinkIndex == -1 {
+					initial.LinkIndex = link.Index
+				}
+				if err := netlink.RouteAdd(&initial); err != nil {
+					t.Fatalf("test cannot set up initial routes: %v", err)
+				}
+			}
+			for j := range tc.expectedRoutes {
+				if tc.expectedRoutes[j].LinkIndex == -1 {
+					tc.expectedRoutes[j].LinkIndex = link.Index
+				}
+			}
+
+			callback := ManageDefaultRoute(link)
+			if err := callback(tc.oldLease, tc.newLease); err != nil {
+				t.Fatalf("callback returned an error: %v", err)
+			}
+
+			allRoutes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{}, 0)
+			if err != nil {
+				t.Fatalf("test cannot read back routes: %v", err)
+			}
+			var notKernelRoutes []netlink.Route
+			for _, route := range allRoutes {
+				if route.Protocol == unix.RTPROT_KERNEL { // Filter kernel-managed routes
+					continue
+				}
+				notKernelRoutes = append(notKernelRoutes, route)
+			}
+			require.Equal(t, tc.expectedRoutes, notKernelRoutes, "Wrong Routes")
+		})
+	}
+}
diff --git a/metropolis/node/core/network/dhcp4c/dhcpc.go b/metropolis/node/core/network/dhcp4c/dhcpc.go
new file mode 100644
index 0000000..4352506
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/dhcpc.go
@@ -0,0 +1,677 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package dhcp4c implements a DHCPv4 Client as specified in RFC2131 (with some notable deviations).
+// It implements only the DHCP state machine itself. Any configuration other than the interface IP
+// address (which is always assigned in DHCP and necessary for the protocol to work) is handed to
+// consumers as lease options through Client.LeaseCallback, and they then deal with it.
+package dhcp4c
+
+import (
+ "context"
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "net"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/insomniacslk/dhcp/iana"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dhcp4c/transport"
+)
+
+// state identifies the position of the DHCP client state machine.
+type state int
+
+const (
+	// stateDiscovering broadcasts DHCPDISCOVER messages and waits for a DHCPOFFER or, when the
+	// server supports Rapid Commit, directly for a DHCPACK.
+	stateDiscovering state = iota
+	// stateRequesting broadcasts DHCPREQUEST messages carrying the server identifier of the
+	// selected offer and waits for a DHCPACK or DHCPNAK. Without either it falls back to
+	// discovering.
+	stateRequesting
+	// stateBound only waits for RenewDeadline (derived from RenewTimeValue, half the lifetime by
+	// default) to pass.
+	stateBound
+	// stateRenewing unicasts DHCPREQUEST messages to the currently-selected server and waits for a
+	// DHCPACK or DHCPNAK. A DHCPACK leads back to bound, anything else to discovering.
+	stateRenewing
+	// stateRebinding broadcasts DHCPREQUEST messages and waits for a DHCPACK or DHCPNAK from any
+	// server. Responses are processed exactly like in stateRenewing.
+	stateRebinding
+)
+
+// String returns the upper-case name of the state, or "INVALID" for any value that is not one of
+// the declared states.
+func (s state) String() string {
+	names := [...]string{
+		stateDiscovering: "DISCOVERING",
+		stateRequesting:  "REQUESTING",
+		stateBound:       "BOUND",
+		stateRenewing:    "RENEWING",
+		stateRebinding:   "REBINDING",
+	}
+	if s < 0 || int(s) >= len(names) {
+		return "INVALID"
+	}
+	return names[s]
+}
+
+// internalOptions lists the options the client always asks for, in addition to the
+// user-supplied Client.RequestedOptions (see newMsg).
+// This only requests SubnetMask and IPAddressLeaseTime as renewal and rebinding times are fine if
+// they are just defaulted. They are respected (if valid, otherwise they are clamped to the nearest
+// valid value) if sent by the server.
+var internalOptions = dhcpv4.OptionCodeList{dhcpv4.OptionSubnetMask, dhcpv4.OptionIPAddressLeaseTime}
+
+// Transport represents a mechanism over which DHCP messages can be exchanged with a server.
+type Transport interface {
+ // Send attempts to send the given DHCP payload message to the transport target once. A nil error
+ // does not indicate that the message was successfully received.
+ Send(payload *dhcpv4.DHCPv4) error
+ // SetReceiveDeadline sets a deadline for Receive() calls after which they return with
+ // DeadlineExceededErr.
+ SetReceiveDeadline(time.Time) error
+ // Receive waits for a DHCP message to arrive and returns it. If the deadline expires without a message arriving
+ // it will return DeadlineExceededErr. If the message is completely malformed it will return an instance of
+ // InvalidMessageError.
+ Receive() (*dhcpv4.DHCPv4, error)
+ // Close closes the given transport. Calls to any of the above methods will fail if the transport is closed.
+ // Specific transports can be reopened after being closed.
+ Close() error
+}
+
+// UnicastTransport represents a mechanism over which DHCP messages can be exchanged with a single server over an
+// arbitrary IPv4-based network. Implementers need to support servers running outside the local network via a router.
+type UnicastTransport interface {
+ Transport
+ // Open connects the transport to a new unicast target. Can only be called after calling Close() or after creating
+ // a new transport. The client calls it with the server identifier of the lease as serverIP and the
+ // leased address as bindIP.
+ Open(serverIP, bindIP net.IP) error
+}
+
+// BroadcastTransport represents a mechanism over which DHCP messages can be exchanged with all servers on a Layer 2
+// broadcast domain. Implementers need to support sending and receiving messages without any IP being configured on
+// the interface.
+type BroadcastTransport interface {
+ Transport
+ // Open connects the transport. Can only be called after calling Close() or after creating a new transport.
+ // It returns an error if the transport could not be connected.
+ Open() error
+}
+
+// LeaseCallback is invoked by the client whenever its lease changes. old is the previously held
+// lease (nil if there was none) and new is the lease that is now valid (nil if the lease was lost
+// or the client is shutting down). A non-nil error returned while acquiring, renewing or losing a
+// lease is propagated out of the state machine.
+type LeaseCallback func(old, new *Lease) error
+
+// Client implements a DHCPv4 client.
+//
+// Note that the size of all data sent to the server (RequestedOptions, ClientIdentifier,
+// VendorClassIdentifier and ExtraRequestOptions) should be kept reasonably small (<500 bytes) in
+// order to maximize the chance that requests can be properly transmitted.
+type Client struct {
+ // RequestedOptions contains a list of extra options this client is interested in
+ RequestedOptions dhcpv4.OptionCodeList
+
+ // ClientIdentifier is used by the DHCP server to identify this client.
+ // If empty, on Ethernet the MAC address is used instead.
+ ClientIdentifier []byte
+
+ // VendorClassIdentifier is used by the DHCP server to identify options specific to this type of
+ // clients and to populate the vendor-specific option (43).
+ VendorClassIdentifier string
+
+ // ExtraRequestOptions are extra options sent to the server.
+ ExtraRequestOptions dhcpv4.Options
+
+ // Backoff strategies for each state. These all have sane defaults, override them only if
+ // necessary.
+ DiscoverBackoff backoff.BackOff
+ AcceptOfferBackoff backoff.BackOff
+ RenewBackoff backoff.BackOff
+ RebindBackoff backoff.BackOff
+
+ // state is the current position of the state machine.
+ state state
+
+ // lastBoundTransition is initialized to the creation time in NewClient.
+ lastBoundTransition time.Time
+
+ // iface is the interface the client runs on; its MTU and hardware address go into requests.
+ iface *net.Interface
+
+ // now can be used to override time for testing
+ now func() time.Time
+
+ // LeaseCallback is called every time a lease is acquired, renewed or lost
+ LeaseCallback LeaseCallback
+
+ // Valid in states Discovering, Requesting, Rebinding
+ broadcastConn BroadcastTransport
+
+ // Valid in states Requesting
+ offer *dhcpv4.DHCPv4
+
+ // Valid in states Bound, Renewing
+ unicastConn UnicastTransport
+
+ // Valid in states Bound, Renewing, Rebinding
+ lease *dhcpv4.DHCPv4
+ // leaseDeadline is when the lease expires (request send time plus lease time).
+ leaseDeadline time.Time
+ // leaseBoundDeadline is when the Bound state ends (request send time plus renewal time).
+ leaseBoundDeadline time.Time
+ // leaseRenewDeadline is when the Renewing state ends (request send time plus rebinding time).
+ leaseRenewDeadline time.Time
+}
+
+// newDefaultBackoff builds the randomized exponential backoff used by default for all states.
+// It never gives up on its own and starts with an interval suitable for DHCP.
+func newDefaultBackoff() *backoff.ExponentialBackOff {
+	eb := backoff.NewExponentialBackOff()
+	// Many servers probe for 1s whether an IP is still in use. Start above that and leave some
+	// headroom for randomization, communication and processing overhead.
+	eb.InitialInterval = 1400 * time.Millisecond
+	eb.RandomizationFactor = 0.2
+	eb.MaxInterval = 30 * time.Second
+	// Zero disables the overall timeout, so the backoff retries forever.
+	eb.MaxElapsedTime = 0
+	return eb
+}
+
+// NewClient instantiates (but doesn't start) a new DHCPv4 client on the given interface.
+// To have a working client it's required to set LeaseCallback to something that is capable of configuring the IP
+// address on the given interface. Unless managed through external means like a routing protocol, setting the default
+// route is also required. A simple example with the callback package thus looks like this:
+//   c, err := dhcp4c.NewClient(yourInterface)
+//   if err != nil { ... }
+//   c.LeaseCallback = callback.Compose(callback.ManageIP(yourInterface), callback.ManageDefaultRoute(yourInterface))
+//   err = c.Run(ctx)
+//
+// The broadcast transport is opened immediately; an error is returned if that fails.
+func NewClient(iface *net.Interface) (*Client, error) {
+	broadcastConn := transport.NewBroadcastTransport(iface)
+
+	// broadcastConn needs to be open in stateDiscovering
+	if err := broadcastConn.Open(); err != nil {
+		return nil, fmt.Errorf("failed to create DHCP broadcast transport: %w", err)
+	}
+
+	discoverBackoff := newDefaultBackoff()
+
+	acceptOfferBackoff := newDefaultBackoff()
+	// Abort after 30s and go back to discovering
+	acceptOfferBackoff.MaxElapsedTime = 30 * time.Second
+
+	renewBackoff := newDefaultBackoff()
+	// Increase maximum interval to reduce chatter when the server is down
+	renewBackoff.MaxInterval = 5 * time.Minute
+
+	rebindBackoff := newDefaultBackoff()
+	// Increase maximum interval to reduce chatter when the server is down. This has to be set on
+	// rebindBackoff itself; setting it on renewBackoff again would leave rebinding at the default.
+	rebindBackoff.MaxInterval = 5 * time.Minute
+
+	return &Client{
+		state:               stateDiscovering,
+		broadcastConn:       broadcastConn,
+		unicastConn:         transport.NewUnicastTransport(iface),
+		iface:               iface,
+		RequestedOptions:    dhcpv4.OptionCodeList{},
+		lastBoundTransition: time.Now(),
+		now:                 time.Now,
+		DiscoverBackoff:     discoverBackoff,
+		AcceptOfferBackoff:  acceptOfferBackoff,
+		RenewBackoff:        renewBackoff,
+		RebindBackoff:       rebindBackoff,
+	}, nil
+}
+
+// acceptableLease checks if the given lease is valid enough to even be processed and normalizes
+// it in place: options a server must not send are stripped, rebinding/renewal times are clamped
+// and the inline sname/file/siaddr header fields are mirrored into their option equivalents.
+// It returns false if the message has no IPv4 server identifier, a lease time below one second
+// or an assigned address that is not a unicast IPv4 address. This is intentionally not exposed to
+// users because under certain circumstances it can end up acquiring all available IP addresses
+// from a server.
+func (c *Client) acceptableLease(offer *dhcpv4.DHCPv4) bool {
+	// RFC2131 Section 4.3.1 Table 3
+	if offer.ServerIdentifier() == nil || offer.ServerIdentifier().To4() == nil {
+		return false
+	}
+	// RFC2131 Section 4.3.1 Table 3
+	// Minimum representable lease time is 1s (Section 1.1)
+	leaseTime := offer.IPAddressLeaseTime(0)
+	if leaseTime < 1*time.Second {
+		return false
+	}
+
+	// Ignore IPs that are in no way valid for an interface (multicast, loopback, ...)
+	if offer.YourIPAddr.To4() == nil || (!offer.YourIPAddr.IsGlobalUnicast() && !offer.YourIPAddr.IsLinkLocalUnicast()) {
+		return false
+	}
+
+	// Technically the options Requested IP address, Parameter request list, Client identifier
+	// and Maximum message size should be refused (MUST NOT), but in the interest of interoperability
+	// let's simply remove them if they are present.
+	delete(offer.Options, dhcpv4.OptionRequestedIPAddress.Code())
+	delete(offer.Options, dhcpv4.OptionParameterRequestList.Code())
+	delete(offer.Options, dhcpv4.OptionClientIdentifier.Code())
+	delete(offer.Options, dhcpv4.OptionMaximumDHCPMessageSize.Code())
+
+	// Clamp rebinding times longer than the lease time. Otherwise the state machine might misbehave.
+	if offer.IPAddressRebindingTime(0) > leaseTime {
+		offer.UpdateOption(dhcpv4.OptGeneric(dhcpv4.OptionRebindingTimeValue, dhcpv4.Duration(leaseTime).ToBytes()))
+	}
+	// Clamp renewal times longer than the rebinding time. Otherwise the state machine might misbehave.
+	// The comparison uses the effective rebinding time: if the server sent none, fall back to the
+	// default that transitionToBound applies (85% of the lease time). Comparing against zero in
+	// that case would clamp every server-provided renewal time down to zero.
+	rebindTime := offer.IPAddressRebindingTime(time.Duration(float64(leaseTime) * 0.85))
+	if offer.IPAddressRenewalTime(0) > rebindTime {
+		offer.UpdateOption(dhcpv4.OptGeneric(dhcpv4.OptionRenewTimeValue, dhcpv4.Duration(rebindTime).ToBytes()))
+	}
+
+	// Normalize two options that can be represented either inline or as options.
+	if len(offer.ServerHostName) > 0 {
+		offer.Options[uint8(dhcpv4.OptionTFTPServerName)] = []byte(offer.ServerHostName)
+	}
+	if len(offer.BootFileName) > 0 {
+		offer.Options[uint8(dhcpv4.OptionBootfileName)] = []byte(offer.BootFileName)
+	}
+
+	// Normalize siaddr to option 150 (see RFC5859)
+	if len(offer.GetOneOption(dhcpv4.OptionTFTPServerAddress)) == 0 {
+		if offer.ServerIPAddr.To4() != nil && (offer.ServerIPAddr.IsGlobalUnicast() || offer.ServerIPAddr.IsLinkLocalUnicast()) {
+			offer.Options[uint8(dhcpv4.OptionTFTPServerAddress)] = offer.ServerIPAddr.To4()
+		}
+	}
+
+	return true
+}
+
+// earliestDeadline returns whichever of the two deadlines comes first. If neither is strictly
+// before the other, dl2 is returned.
+func earliestDeadline(dl1, dl2 time.Time) time.Time {
+	earliest := dl2
+	if dl1.Before(dl2) {
+		earliest = dl1
+	}
+	return earliest
+}
+
+// newXID returns a fresh transaction ID filled from the cryptographic random source, or an error
+// if not enough randomness could be read.
+func (c *Client) newXID() (dhcpv4.TransactionID, error) {
+	var id dhcpv4.TransactionID
+	_, err := io.ReadFull(rand.Reader, id[:])
+	if err != nil {
+		return id, fmt.Errorf("cannot read randomness for transaction ID: %w", err)
+	}
+	return id, nil
+}
+
+// maxMsgSize returns the value advertised as maximum DHCP message size.
+// Most servers cannot do reassembly, so the local interface MTU is advertised in the hope that
+// it works out; an oversized packet would not get through anyways. The value is capped at the
+// largest MTU DHCPv4 can express (2 bytes unsigned int).
+func (c *Client) maxMsgSize() uint16 {
+	mtu := c.iface.MTU
+	if mtu >= math.MaxUint16 {
+		return math.MaxUint16
+	}
+	return uint16(mtu)
+}
+
+// newMsg builds a DHCP request message of type t with a fresh transaction ID and the options
+// common to all messages. DISCOVER, REQUEST and INFORM messages additionally carry the parameter
+// request list, the maximum message size, the vendor class (if set) and ExtraRequestOptions,
+// which override any option set before them. It fails only if no transaction ID can be generated.
+func (c *Client) newMsg(t dhcpv4.MessageType) (*dhcpv4.DHCPv4, error) {
+	xid, err := c.newXID()
+	if err != nil {
+		return nil, err
+	}
+
+	options := make(dhcpv4.Options)
+	options.Update(dhcpv4.OptMessageType(t))
+	if len(c.ClientIdentifier) > 0 {
+		options.Update(dhcpv4.OptClientIdentifier(c.ClientIdentifier))
+	}
+
+	switch t {
+	case dhcpv4.MessageTypeDiscover, dhcpv4.MessageTypeRequest, dhcpv4.MessageTypeInform:
+		options.Update(dhcpv4.OptParameterRequestList(append(c.RequestedOptions, internalOptions...)...))
+		options.Update(dhcpv4.OptMaxMessageSize(c.maxMsgSize()))
+		if c.VendorClassIdentifier != "" {
+			options.Update(dhcpv4.OptClassIdentifier(c.VendorClassIdentifier))
+		}
+		for code, value := range c.ExtraRequestOptions {
+			options[code] = value
+		}
+	}
+
+	msg := &dhcpv4.DHCPv4{
+		OpCode:        dhcpv4.OpcodeBootRequest,
+		HWType:        iana.HWTypeEthernet,
+		ClientHWAddr:  c.iface.HardwareAddr,
+		HopCount:      0,
+		TransactionID: xid,
+		NumSeconds:    0,
+		Flags:         0,
+		ClientIPAddr:  net.IPv4zero,
+		YourIPAddr:    net.IPv4zero,
+		ServerIPAddr:  net.IPv4zero,
+		GatewayIPAddr: net.IPv4zero,
+		Options:       options,
+	}
+	return msg, nil
+}
+
+// transactionStateSpec describes a state which is driven by a DHCP message transaction (sending a
+// specific message and then transitioning into a different state depending on the received messages)
+type transactionStateSpec struct {
+ // ctx is a context for canceling the process
+ ctx context.Context
+
+ // transport is used to send and receive messages in this state
+ transport Transport
+
+ // stateDeadline is a fixed external deadline for how long the FSM can remain in this state.
+ // If it's exceeded the stateDeadlineExceeded callback is called and responsible for
+ // transitioning out of this state. It can be left empty to signal that there's no external
+ // deadline for the state.
+ stateDeadline time.Time
+
+ // backoff controls how long to wait for answers until handing control back to the FSM.
+ // Since the FSM hasn't advanced until then this means we just get called again and retransmit.
+ backoff backoff.BackOff
+
+ // requestType is the type of DHCP request sent out in this state. This is used to populate
+ // the default options for the message.
+ requestType dhcpv4.MessageType
+
+ // setExtraOptions can modify the request and set extra options before transmitting. Returning
+ // an error here aborts the FSM and can be used to terminate when no valid request can be
+ // constructed.
+ setExtraOptions func(msg *dhcpv4.DHCPv4) error
+
+ // handleMessage gets called for every parseable (not necessarily valid) DHCP message received
+ // by the transport. It should return an error for every message that doesn't advance the
+ // state machine and no error for every one that does. It is responsible for advancing the FSM
+ // if the required information is present. Returning InvalidMsgErr makes the FSM keep waiting
+ // for further messages; any other error aborts the transaction.
+ handleMessage func(msg *dhcpv4.DHCPv4, sentTime time.Time) error
+
+ // stateDeadlineExceeded gets called if either the backoff returns backoff.Stop or the
+ // stateDeadline runs out (more precisely: if less than 10ms remain until the receive deadline).
+ // It is responsible for advancing the FSM into the next state. It is called without a nil
+ // check, so it must be set whenever one of these conditions can occur.
+ stateDeadlineExceeded func() error
+}
+
+// runTransactionState performs one send/receive round of the transaction described by s: it
+// builds the request, sends it over s.transport and then processes incoming messages until one of
+// them is accepted by s.handleMessage, the receive deadline passes or an error occurs. A nil
+// return does not imply a state change; if the state is unchanged the caller simply runs the
+// state again, which retransmits. It returns the context's error if s.ctx is done after a receive.
+func (c *Client) runTransactionState(s transactionStateSpec) error {
+ // sentTime is taken before the request is built and is later handed to handleMessage.
+ sentTime := c.now()
+ msg, err := c.newMsg(s.requestType)
+ if err != nil {
+ return fmt.Errorf("failed to get new DHCP message: %w", err)
+ }
+ if err := s.setExtraOptions(msg); err != nil {
+ return fmt.Errorf("failed to create DHCP message: %w", err)
+ }
+
+ // An exhausted backoff is treated like an expired state deadline.
+ wait := s.backoff.NextBackOff()
+ if wait == backoff.Stop {
+ return s.stateDeadlineExceeded()
+ }
+
+ // Wait for the backoff interval, but never past the state deadline (if there is one).
+ receiveDeadline := sentTime.Add(wait)
+ if !s.stateDeadline.IsZero() {
+ receiveDeadline = earliestDeadline(s.stateDeadline, receiveDeadline)
+ }
+
+ // Jump out if deadline expires in less than 10ms. Minimum lease time is 1s and if we have less
+ // than 10ms to wait for an answer before switching state it makes no sense to send out another
+ // request. This nearly eliminates the problem of sending two different requests back-to-back.
+ if receiveDeadline.Add(-10 * time.Millisecond).Before(sentTime) {
+ return s.stateDeadlineExceeded()
+ }
+
+ if err := s.transport.Send(msg); err != nil {
+ return fmt.Errorf("failed to send message: %w", err)
+ }
+
+ if err := s.transport.SetReceiveDeadline(receiveDeadline); err != nil {
+ return fmt.Errorf("failed to set deadline: %w", err)
+ }
+
+ for {
+ offer, err := s.transport.Receive()
+ // Cancellation is checked after every Receive call and takes precedence over its result.
+ select {
+ case <-s.ctx.Done():
+ c.cleanup()
+ return s.ctx.Err()
+ default:
+ }
+ // Nothing usable arrived in time; hand control back so the state gets run again.
+ if errors.Is(err, transport.DeadlineExceededErr) {
+ return nil
+ }
+ var e transport.InvalidMessageError
+ if errors.As(err, &e) {
+ // Packet couldn't be read. Maybe log at some point in the future.
+ continue
+ }
+ if err != nil {
+ return fmt.Errorf("failed to receive packet: %w", err)
+ }
+ if offer.TransactionID != msg.TransactionID { // Not our transaction
+ continue
+ }
+ // InvalidMsgErr means the message did not advance the FSM: keep listening until the deadline.
+ err = s.handleMessage(offer, sentTime)
+ if err == nil {
+ return nil
+ } else if !errors.Is(err, InvalidMsgErr) {
+ return err
+ }
+ }
+}
+
+// InvalidMsgErr is returned by message handlers for messages that do not advance the state
+// machine. runTransactionState ignores such messages and keeps waiting for further ones.
+var InvalidMsgErr = errors.New("invalid message")
+
+// runState executes a single step of the state machine for the current state. Transaction-driven
+// states (discovering, requesting, renewing, rebinding) perform one send/receive round via
+// runTransactionState; the bound state just sleeps until the renewal deadline. The handlers
+// passed along are responsible for moving c.state forward. An error is returned if the step
+// fails, the context is canceled or c.state holds an unknown value.
+func (c *Client) runState(ctx context.Context) error {
+ switch c.state {
+ case stateDiscovering:
+ // No stateDeadline and no stateDeadlineExceeded handler are set for this state.
+ return c.runTransactionState(transactionStateSpec{
+ ctx: ctx,
+ transport: c.broadcastConn,
+ backoff: c.DiscoverBackoff,
+ requestType: dhcpv4.MessageTypeDiscover,
+ setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+ // Signal Rapid Commit support so servers may answer with a DHCPACK right away.
+ msg.UpdateOption(dhcpv4.OptGeneric(dhcpv4.OptionRapidCommit, []byte{}))
+ return nil
+ },
+ handleMessage: func(offer *dhcpv4.DHCPv4, sentTime time.Time) error {
+ switch offer.MessageType() {
+ case dhcpv4.MessageTypeOffer:
+ if c.acceptableLease(offer) {
+ c.offer = offer
+ c.AcceptOfferBackoff.Reset()
+ c.state = stateRequesting
+ return nil
+ }
+ case dhcpv4.MessageTypeAck:
+ // Rapid Commit: the server committed the lease without a separate request.
+ if c.acceptableLease(offer) {
+ return c.transitionToBound(offer, sentTime)
+ }
+ }
+ return InvalidMsgErr
+ },
+ })
+ case stateRequesting:
+ return c.runTransactionState(transactionStateSpec{
+ ctx: ctx,
+ transport: c.broadcastConn,
+ backoff: c.AcceptOfferBackoff,
+ requestType: dhcpv4.MessageTypeRequest,
+ setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+ // The request reuses the transaction ID of the offer it accepts.
+ msg.UpdateOption(dhcpv4.OptServerIdentifier(c.offer.ServerIdentifier()))
+ msg.TransactionID = c.offer.TransactionID
+ msg.UpdateOption(dhcpv4.OptRequestedIPAddress(c.offer.YourIPAddr))
+ return nil
+ },
+ handleMessage: func(msg *dhcpv4.DHCPv4, sentTime time.Time) error {
+ switch msg.MessageType() {
+ case dhcpv4.MessageTypeAck:
+ if c.acceptableLease(msg) {
+ return c.transitionToBound(msg, sentTime)
+ }
+ case dhcpv4.MessageTypeNak:
+ c.requestingToDiscovering()
+ return nil
+ }
+ return InvalidMsgErr
+ },
+ stateDeadlineExceeded: func() error {
+ c.requestingToDiscovering()
+ return nil
+ },
+ })
+ case stateBound:
+ // Nothing to send while bound: wait for the renewal deadline or cancellation.
+ select {
+ case <-time.After(c.leaseBoundDeadline.Sub(c.now())):
+ c.state = stateRenewing
+ c.RenewBackoff.Reset()
+ return nil
+ case <-ctx.Done():
+ c.cleanup()
+ return ctx.Err()
+ }
+ case stateRenewing:
+ return c.runTransactionState(transactionStateSpec{
+ ctx: ctx,
+ transport: c.unicastConn,
+ backoff: c.RenewBackoff,
+ requestType: dhcpv4.MessageTypeRequest,
+ stateDeadline: c.leaseRenewDeadline,
+ setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+ msg.ClientIPAddr = c.lease.YourIPAddr
+ return nil
+ },
+ handleMessage: func(ack *dhcpv4.DHCPv4, sentTime time.Time) error {
+ switch ack.MessageType() {
+ case dhcpv4.MessageTypeAck:
+ if c.acceptableLease(ack) {
+ return c.transitionToBound(ack, sentTime)
+ }
+ case dhcpv4.MessageTypeNak:
+ return c.leaseToDiscovering()
+ }
+ return InvalidMsgErr
+ },
+ stateDeadlineExceeded: func() error {
+ // Renewing with the original server failed; ask any server via broadcast instead.
+ c.state = stateRebinding
+ if err := c.switchToBroadcast(); err != nil {
+ return fmt.Errorf("failed to switch to broadcast: %w", err)
+ }
+ c.RebindBackoff.Reset()
+ return nil
+ },
+ })
+ case stateRebinding:
+ return c.runTransactionState(transactionStateSpec{
+ ctx: ctx,
+ transport: c.broadcastConn,
+ backoff: c.RebindBackoff,
+ stateDeadline: c.leaseDeadline,
+ requestType: dhcpv4.MessageTypeRequest,
+ setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+ msg.ClientIPAddr = c.lease.YourIPAddr
+ return nil
+ },
+ handleMessage: func(ack *dhcpv4.DHCPv4, sentTime time.Time) error {
+ switch ack.MessageType() {
+ case dhcpv4.MessageTypeAck:
+ if c.acceptableLease(ack) {
+ return c.transitionToBound(ack, sentTime)
+ }
+ case dhcpv4.MessageTypeNak:
+ return c.leaseToDiscovering()
+ }
+ return InvalidMsgErr
+ },
+ stateDeadlineExceeded: func() error {
+ // The lease itself ran out: report it as lost and start over.
+ return c.leaseToDiscovering()
+ },
+ })
+ }
+ return errors.New("state machine in invalid state")
+}
+
+// Run drives the DHCP state machine step by step and logs every state change. It only returns
+// once a step fails, which includes cancellation of ctx, and then returns that error. It panics
+// if LeaseCallback has not been set.
+func (c *Client) Run(ctx context.Context) error {
+	if c.LeaseCallback == nil {
+		panic("LeaseCallback must be set before calling Run")
+	}
+	log := supervisor.Logger(ctx)
+	for {
+		before := c.state
+		err := c.runState(ctx)
+		if err != nil {
+			return err
+		}
+		if after := c.state; after != before {
+			log.Infof("%s => %s", before, after)
+		}
+	}
+}
+
+// cleanup tears the client down on a best-effort basis: it closes the unicast transport, reports
+// a still-held lease as lost through LeaseCallback and closes the broadcast transport. Errors
+// from any of these steps are deliberately ignored.
+func (c *Client) cleanup() {
+	_ = c.unicastConn.Close()
+	if held := c.lease; held != nil {
+		_ = c.LeaseCallback(leaseFromAck(held, c.leaseDeadline), nil)
+	}
+	_ = c.broadcastConn.Close()
+}
+
+// requestingToDiscovering drops the currently selected offer and puts the state machine back
+// into discovering with a freshly reset discover backoff.
+func (c *Client) requestingToDiscovering() {
+	c.offer = nil
+
+	c.DiscoverBackoff.Reset()
+	c.state = stateDiscovering
+}
+
+// leaseToDiscovering gives up the current lease and returns to discovering. When coming from
+// renewing, the transport is switched back to broadcast first. The lost lease is reported through
+// LeaseCallback with a nil new lease; if the callback fails its error is returned and the lease
+// stays recorded on the client.
+func (c *Client) leaseToDiscovering() error {
+	if c.state == stateRenewing {
+		err := c.switchToBroadcast()
+		if err != nil {
+			return err
+		}
+	}
+	c.state = stateDiscovering
+	c.DiscoverBackoff.Reset()
+
+	lost := leaseFromAck(c.lease, c.leaseDeadline)
+	if err := c.LeaseCallback(lost, nil); err != nil {
+		return fmt.Errorf("lease callback failed: %w", err)
+	}
+	c.lease = nil
+	return nil
+}
+
+// leaseFromAck converts a DHCPACK into the Lease handed to callbacks, using expiresAt as its
+// expiry. A nil ack yields a nil Lease.
+func leaseFromAck(ack *dhcpv4.DHCPv4, expiresAt time.Time) *Lease {
+	if ack == nil {
+		return nil
+	}
+	lease := &Lease{
+		AssignedIP: ack.YourIPAddr,
+		ExpiresAt:  expiresAt,
+		Options:    ack.Options,
+	}
+	return lease
+}
+
+// transitionToBound installs ack as the current lease and moves the state machine to bound.
+// All deadlines are computed relative to sentTime, the time the answered request was created.
+// The renewal and rebinding times default to 50% and 85% of the lease time if the server did not
+// provide them. LeaseCallback is invoked with the previous and the new lease before the transport
+// is switched to unicast (which is skipped when coming from renewing, where it is already open).
+func (c *Client) transitionToBound(ack *dhcpv4.DHCPv4, sentTime time.Time) error {
+	// Guaranteed to exist, leases without a lease time are filtered
+	leaseTime := ack.IPAddressLeaseTime(0)
+	renewAfter := ack.IPAddressRenewalTime(time.Duration(float64(leaseTime) * 0.5))
+	rebindAfter := ack.IPAddressRebindingTime(time.Duration(float64(leaseTime) * 0.85))
+
+	prevDeadline := c.leaseDeadline
+	c.leaseDeadline = sentTime.Add(leaseTime)
+	c.leaseBoundDeadline = sentTime.Add(renewAfter)
+	c.leaseRenewDeadline = sentTime.Add(rebindAfter)
+
+	oldLease := leaseFromAck(c.lease, prevDeadline)
+	newLease := leaseFromAck(ack, c.leaseDeadline)
+	if err := c.LeaseCallback(oldLease, newLease); err != nil {
+		return fmt.Errorf("lease callback failed: %w", err)
+	}
+
+	if c.state != stateRenewing {
+		err := c.switchToUnicast(ack.ServerIdentifier(), ack.YourIPAddr)
+		if err != nil {
+			return fmt.Errorf("failed to switch transports: %w", err)
+		}
+	}
+	c.state = stateBound
+	c.lease = ack
+	return nil
+}
+
+// switchToUnicast closes the broadcast transport and then opens the unicast transport towards
+// serverIP, bound to bindIP. The first failing step is returned as a wrapped error.
+func (c *Client) switchToUnicast(serverIP, bindIP net.IP) error {
+	err := c.broadcastConn.Close()
+	if err != nil {
+		return fmt.Errorf("failed to close broadcast transport: %w", err)
+	}
+	err = c.unicastConn.Open(serverIP, bindIP)
+	if err != nil {
+		return fmt.Errorf("failed to open unicast transport: %w", err)
+	}
+	return nil
+}
+
+// switchToBroadcast closes the unicast transport and then reopens the broadcast transport.
+// The first failing step is returned as a wrapped error.
+func (c *Client) switchToBroadcast() error {
+	err := c.unicastConn.Close()
+	if err != nil {
+		return fmt.Errorf("failed to close unicast transport: %w", err)
+	}
+	err = c.broadcastConn.Open()
+	if err != nil {
+		return fmt.Errorf("failed to open broadcast transport: %w", err)
+	}
+	return nil
+}
diff --git a/metropolis/node/core/network/dhcp4c/dhcpc_test.go b/metropolis/node/core/network/dhcp4c/dhcpc_test.go
new file mode 100644
index 0000000..6988da2
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/dhcpc_test.go
@@ -0,0 +1,514 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dhcp4c
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/stretchr/testify/assert"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dhcp4c/transport"
+)
+
+// fakeTime is a manually advanced clock that can stand in for time.Now in tests.
+type fakeTime struct {
+	time time.Time
+}
+
+// newFakeTime returns a fake clock that initially reports t.
+func newFakeTime(t time.Time) *fakeTime {
+	clock := new(fakeTime)
+	clock.time = t
+	return clock
+}
+
+// Now returns the time the fake clock currently reports.
+func (ft *fakeTime) Now() time.Time {
+	return ft.time
+}
+
+// Advance moves the fake clock by d.
+func (ft *fakeTime) Advance(d time.Duration) {
+	moved := ft.time.Add(d)
+	ft.time = moved
+}
+
+// mockTransport is an in-memory transport for tests. It records what the client sends and replays
+// a scripted list of packets on Receive.
+type mockTransport struct {
+	sentPacket     *dhcpv4.DHCPv4   // last packet passed to Send
+	sendError      error            // returned by Send
+	setDeadline    time.Time        // last deadline passed to SetReceiveDeadline
+	receivePackets []*dhcpv4.DHCPv4 // packets replayed by Receive, in order
+	receiveError   error            // if set, returned by every Receive call
+	receiveIdx     int              // index of the next packet to replay
+	closed         bool             // set by Close, cleared by Open
+}
+
+// sendPackets replaces the scripted packets and restarts the replay from the first one.
+func (mt *mockTransport) sendPackets(pkts ...*dhcpv4.DHCPv4) {
+	mt.receivePackets = pkts
+	mt.receiveIdx = 0
+}
+
+// Open marks the transport as open. It never fails.
+func (mt *mockTransport) Open() error {
+	mt.closed = false
+	return nil
+}
+
+// Send records payload as the last sent packet and returns the configured sendError.
+func (mt *mockTransport) Send(payload *dhcpv4.DHCPv4) error {
+	mt.sentPacket = payload
+	return mt.sendError
+}
+
+// Receive returns the configured receiveError if there is one. Otherwise it returns a copy of the
+// next scripted packet, with the transaction ID rewritten to the one of the last sent packet, or
+// transport.DeadlineExceededErr once all scripted packets have been replayed.
+func (mt *mockTransport) Receive() (*dhcpv4.DHCPv4, error) {
+	if mt.receiveError != nil {
+		return nil, mt.receiveError
+	}
+	if mt.receiveIdx >= len(mt.receivePackets) {
+		return nil, transport.DeadlineExceededErr
+	}
+	// Round-trip through the wire format to hand out an independent clone.
+	clone, err := dhcpv4.FromBytes(mt.receivePackets[mt.receiveIdx].ToBytes())
+	if err != nil {
+		panic("ToBytes => FromBytes failed")
+	}
+	clone.TransactionID = mt.sentPacket.TransactionID
+	mt.receiveIdx++
+	return clone, nil
+}
+
+// SetReceiveDeadline records t so tests can inspect it. It never fails.
+func (mt *mockTransport) SetReceiveDeadline(t time.Time) error {
+	mt.setDeadline = t
+	return nil
+}
+
+// Close marks the transport as closed. It never fails.
+func (mt *mockTransport) Close() error {
+	mt.closed = true
+	return nil
+}
+
+// unicastMockTransport extends mockTransport with the unicast Open signature and remembers the
+// addresses it was opened with.
+type unicastMockTransport struct {
+	mockTransport
+	serverIP net.IP
+	bindIP   net.IP
+}
+
+// Open records the target and bind addresses. It panics if a server address is still recorded
+// from a previous Open that was not followed by Close.
+func (umt *unicastMockTransport) Open(serverIP, bindIP net.IP) error {
+	if umt.serverIP != nil {
+		panic("double-open of unicast transport")
+	}
+	umt.serverIP, umt.bindIP = serverIP, bindIP
+	return nil
+}
+
+// Close forgets the recorded addresses and closes the embedded mockTransport.
+func (umt *unicastMockTransport) Close() error {
+	umt.serverIP, umt.bindIP = nil, nil
+	return umt.mockTransport.Close()
+}
+
+// mockBackoff is a deterministic backoff for tests that hands out a fixed list of intervals.
+type mockBackoff struct {
+	indefinite bool            // cycle through values forever instead of stopping
+	values     []time.Duration // intervals handed out in order
+	idx        int             // number of intervals handed out since the last Reset
+}
+
+// newMockBackoff returns a mockBackoff over vals. With indefinite set the values repeat forever,
+// otherwise backoff.Stop is returned once they are used up.
+func newMockBackoff(vals []time.Duration, indefinite bool) *mockBackoff {
+	mb := new(mockBackoff)
+	mb.values = vals
+	mb.indefinite = indefinite
+	return mb
+}
+
+// NextBackOff returns the next scripted interval, wrapping around when indefinite, or
+// backoff.Stop when a finite list is exhausted.
+func (mb *mockBackoff) NextBackOff() time.Duration {
+	if !mb.indefinite && mb.idx >= len(mb.values) {
+		return backoff.Stop
+	}
+	next := mb.values[mb.idx%len(mb.values)]
+	mb.idx++
+	return next
+}
+
+// Reset restarts the backoff at the first scripted interval.
+func (mb *mockBackoff) Reset() {
+	mb.idx = 0
+}
+
+// TestClient_runTransactionState checks that a single transaction round sends a message of the
+// requested type which carries the options added by setExtraOptions.
+func TestClient_runTransactionState(t *testing.T) {
+	clock := newFakeTime(time.Date(2020, 10, 28, 15, 2, 32, 352, time.UTC))
+	hwAddr := net.HardwareAddr{0x12, 0x23, 0x34, 0x45, 0x56, 0x67}
+	c := Client{
+		now:   clock.Now,
+		iface: &net.Interface{MTU: 9324, HardwareAddr: hwAddr},
+	}
+	mt := &mockTransport{}
+
+	spec := transactionStateSpec{
+		ctx:         context.Background(),
+		transport:   mt,
+		backoff:     newMockBackoff([]time.Duration{1 * time.Second}, true),
+		requestType: dhcpv4.MessageTypeDiscover,
+		setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+			msg.UpdateOption(dhcpv4.OptDomainName("just.testing.invalid"))
+			return nil
+		},
+		handleMessage: func(msg *dhcpv4.DHCPv4, sentTime time.Time) error {
+			return nil
+		},
+		stateDeadlineExceeded: func() error {
+			panic("shouldn't be called")
+		},
+	}
+	err := c.runTransactionState(spec)
+
+	assert.NoError(t, err)
+	assert.Equal(t, "just.testing.invalid", mt.sentPacket.DomainName())
+	assert.Equal(t, dhcpv4.MessageTypeDiscover, mt.sentPacket.MessageType())
+}
+
+// TestAcceptableLease verifies that acceptableLease accepts a minimal valid lease: a boot reply
+// with a message type, an IPv4 server identifier, a lease time and an assigned unicast address.
+func TestAcceptableLease(t *testing.T) {
+	lease := &dhcpv4.DHCPv4{
+		OpCode:     dhcpv4.OpcodeBootReply,
+		YourIPAddr: net.IP{192, 0, 2, 2},
+	}
+	lease.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer))
+	lease.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+	lease.UpdateOption(dhcpv4.OptIPAddressLeaseTime(10 * time.Second))
+
+	client := Client{}
+	assert.True(t, client.acceptableLease(lease), "Valid lease is not acceptable")
+}
+
+// dhcpClientPuppet bundles a Client with the fake clock and mock transports it was wired up
+// with, so tests can drive the state machine and inspect what it did.
+type dhcpClientPuppet struct {
+ // ft is the fake clock backing the client's now function.
+ ft *fakeTime
+ // bmt is the mock used as the client's broadcast transport.
+ bmt *mockTransport
+ // umt is the mock used as the client's unicast transport.
+ umt *unicastMockTransport
+ // c is the client under test.
+ c *Client
+}
+
+// newPuppetClient builds a Client that starts in initState and is wired to a fake clock, mock
+// transports and deterministic backoffs. Only the accept-offer backoff is finite (1s, then 2s);
+// all others return 1s forever. LeaseCallback is left unset.
+func newPuppetClient(initState state) *dhcpClientPuppet {
+	second := []time.Duration{1 * time.Second}
+
+	puppet := &dhcpClientPuppet{
+		ft:  newFakeTime(time.Date(2020, 10, 28, 15, 2, 32, 352, time.UTC)),
+		bmt: &mockTransport{},
+		umt: &unicastMockTransport{},
+	}
+	puppet.c = &Client{
+		state:              initState,
+		now:                puppet.ft.Now,
+		iface:              &net.Interface{MTU: 9324, HardwareAddr: net.HardwareAddr{0x12, 0x23, 0x34, 0x45, 0x56, 0x67}},
+		broadcastConn:      puppet.bmt,
+		unicastConn:        puppet.umt,
+		DiscoverBackoff:    newMockBackoff(second, true),
+		AcceptOfferBackoff: newMockBackoff([]time.Duration{1 * time.Second, 2 * time.Second}, false),
+		RenewBackoff:       newMockBackoff(second, true),
+		RebindBackoff:      newMockBackoff(second, true),
+	}
+	return puppet
+}
+
+// newResponse returns a bare boot reply whose only option is the given message type.
+func newResponse(m dhcpv4.MessageType) *dhcpv4.DHCPv4 {
+	reply := new(dhcpv4.DHCPv4)
+	reply.OpCode = dhcpv4.OpcodeBootReply
+	reply.UpdateOption(dhcpv4.OptMessageType(m))
+	return reply
+}
+
+// TestDiscoverRequesting tests if the DHCP state machine in discovering state properly selects the first valid lease
+// and transitions to requesting state
+func TestDiscoverRequesting(t *testing.T) {
+	p := newPuppetClient(stateDiscovering)
+
+	// A minimal valid lease
+	offer := newResponse(dhcpv4.MessageTypeOffer)
+	testIP := net.IP{192, 0, 2, 2}
+	offer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+	offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(10 * time.Second))
+	offer.YourIPAddr = testIP
+
+	// Intentionally bad offer with no lease time.
+	terribleOffer := newResponse(dhcpv4.MessageTypeOffer)
+	terribleOffer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 2}))
+	terribleOffer.YourIPAddr = net.IPv4(192, 0, 2, 2)
+
+	// Send the bad offer first, then the valid offer
+	p.bmt.sendPackets(terribleOffer, offer)
+
+	if err := p.c.runState(context.Background()); err != nil {
+		t.Fatal(err)
+	}
+	assert.Equal(t, stateRequesting, p.c.state, "DHCP client didn't process offer")
+	// Fail cleanly instead of dereferencing a nil offer below if no offer was selected.
+	if p.c.offer == nil {
+		t.Fatal("DHCP client didn't store the selected offer")
+	}
+	assert.Equal(t, testIP, p.c.offer.YourIPAddr, "DHCP client requested invalid offer")
+}
+
+// TestRequestingBound tests if the DHCP state machine in requesting state processes a valid DHCPACK and transitions to
+// bound state.
+func TestRequestingBound(t *testing.T) {
+	p := newPuppetClient(stateRequesting)
+
+	offer := newResponse(dhcpv4.MessageTypeAck)
+	testIP := net.IP{192, 0, 2, 2}
+	offer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+	offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(10 * time.Second))
+	offer.YourIPAddr = testIP
+
+	p.bmt.sendPackets(offer)
+	p.c.offer = offer
+	p.c.LeaseCallback = func(old, new *Lease) error {
+		assert.Nil(t, old, "old lease is not nil for new lease")
+		assert.Equal(t, testIP, new.AssignedIP, "new lease has wrong IP")
+		return nil
+	}
+
+	if err := p.c.runState(context.Background()); err != nil {
+		t.Fatal(err)
+	}
+	assert.Equal(t, stateBound, p.c.state, "DHCP client didn't process offer")
+	// Fail cleanly instead of dereferencing a nil lease below if no lease was installed.
+	if p.c.lease == nil {
+		t.Fatal("DHCP client didn't store the acknowledged lease")
+	}
+	assert.Equal(t, testIP, p.c.lease.YourIPAddr, "DHCP client requested invalid offer")
+}
+
+// TestRequestingDiscover tests if the DHCP state machine in requesting state transitions back to discovering if it
+// takes too long to get a valid DHCPACK.
+func TestRequestingDiscover(t *testing.T) {
+ p := newPuppetClient(stateRequesting)
+
+ // The offer the client is assumed to have selected before entering requesting state.
+ offer := newResponse(dhcpv4.MessageTypeOffer)
+ testIP := net.IP{192, 0, 2, 2}
+ offer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+ offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(10 * time.Second))
+ offer.YourIPAddr = testIP
+ p.c.offer = offer
+
+ // Step the state machine up to ten times without queueing any response until it gives up on requesting.
+ for i := 0; i < 10; i++ {
+ p.bmt.sendPackets()
+ if err := p.c.runState(context.Background()); err != nil {
+ t.Error(err)
+ }
+ assert.Equal(t, dhcpv4.MessageTypeRequest, p.bmt.sentPacket.MessageType(), "Invalid message type for requesting")
+ if p.c.state == stateDiscovering {
+ break
+ }
+ // Move the fake clock to the receive deadline recorded by the mock transport before the next attempt.
+ p.ft.time = p.bmt.setDeadline
+
+ if i == 9 {
+ t.Fatal("Too many tries while requesting, backoff likely wrong")
+ }
+ }
+ assert.Equal(t, stateDiscovering, p.c.state, "DHCP client didn't switch back to offer after requesting expired")
+}
+
+// TestDiscoverRapidCommit tests if the DHCP state machine in discovering state transitions directly to bound if a
+// rapid commit response (DHCPACK) is received.
+func TestDiscoverRapidCommit(t *testing.T) {
+ // A DHCPACK (instead of a DHCPOFFER) carrying a 10 second lease and no explicit renewal time.
+ testIP := net.IP{192, 0, 2, 2}
+ offer := newResponse(dhcpv4.MessageTypeAck)
+ offer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+ leaseTime := 10 * time.Second
+ offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+ offer.YourIPAddr = testIP
+
+ p := newPuppetClient(stateDiscovering)
+ // The callback must see a first-time lease (no old lease) expiring one lease time from the fake clock's now.
+ p.c.LeaseCallback = func(old, new *Lease) error {
+ assert.Nil(t, old, "old is not nil")
+ assert.Equal(t, testIP, new.AssignedIP, "callback called with wrong IP")
+ assert.Equal(t, p.ft.Now().Add(leaseTime), new.ExpiresAt, "invalid ExpiresAt")
+ return nil
+ }
+ p.bmt.sendPackets(offer)
+ if err := p.c.runState(context.Background()); err != nil {
+ t.Error(err)
+ }
+ assert.Equal(t, stateBound, p.c.state, "DHCP client didn't process offer")
+ assert.Equal(t, testIP, p.c.lease.YourIPAddr, "DHCP client requested invalid offer")
+ // With no renewal time in the response the bound deadline is expected to default to half of the lease time.
+ assert.Equal(t, 5*time.Second, p.c.leaseBoundDeadline.Sub(p.ft.Now()), "Renewal time was incorrectly defaulted")
+}
+
+// TestOption is an option code used by the tests to mark DHCP messages. Its value is an index into the range of
+// option codes starting at 224 (private options).
+type TestOption uint8
+
+// Code returns the option code on the wire, which is the index offset by 224. The addition is done in uint8 and
+// therefore wraps around for indices above 31.
+func (o TestOption) Code() uint8 {
+	const privateOptionBase uint8 = 224 // Private options
+	return privateOptionBase + uint8(o)
+}
+
+// String returns a human-readable name containing the index of the option.
+func (o TestOption) String() string {
+	index := uint8(o)
+	return fmt.Sprintf("Test Option %d", index)
+}
+
+// TestBoundRenewingBound tests if the DHCP state machine in bound correctly transitions to renewing after
+// leaseBoundDeadline expires, sends a DHCPREQUEST and after it gets a DHCPACK response calls LeaseCallback and
+// transitions back to bound with correct new deadlines.
+func TestBoundRenewingBound(t *testing.T) {
+ offer := newResponse(dhcpv4.MessageTypeAck)
+ testIP := net.IP{192, 0, 2, 2}
+ serverIP := net.IP{192, 0, 2, 1}
+ offer.UpdateOption(dhcpv4.OptServerIdentifier(serverIP))
+ leaseTime := 10 * time.Second
+ offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+ offer.YourIPAddr = testIP
+
+ p := newPuppetClient(stateBound)
+ p.umt.Open(serverIP, testIP)
+ // Round-trip through the wire format so the current lease is an independent copy of offer.
+ p.c.lease, _ = dhcpv4.FromBytes(offer.ToBytes())
+ // Other deadlines are intentionally empty to make sure they aren't used
+ p.c.leaseRenewDeadline = p.ft.Now().Add(8500 * time.Millisecond)
+ p.c.leaseBoundDeadline = p.ft.Now().Add(5000 * time.Millisecond)
+
+ // Stop 5ms short of the bound deadline; the first runState call has to wait out the remainder.
+ p.ft.Advance(5*time.Second - 5*time.Millisecond)
+ if err := p.c.runState(context.Background()); err != nil {
+ t.Error(err)
+ }
+ p.ft.Advance(5 * time.Millisecond) // We cannot intercept time.After so we just advance the clock by the time slept
+ assert.Equal(t, stateRenewing, p.c.state, "DHCP client not renewing")
+ // Mark the renewal answer with a private option so that old and new lease can be told apart in the callback.
+ offer.UpdateOption(dhcpv4.OptGeneric(TestOption(1), []byte{0x12}))
+ p.umt.sendPackets(offer)
+ p.c.LeaseCallback = func(old, new *Lease) error {
+ assert.Equal(t, testIP, old.AssignedIP, "callback called with wrong old IP")
+ assert.Equal(t, testIP, new.AssignedIP, "callback called with wrong IP")
+ assert.Equal(t, p.ft.Now().Add(leaseTime), new.ExpiresAt, "invalid ExpiresAt")
+ assert.Empty(t, old.Options.Get(TestOption(1)), "old contains options from new")
+ assert.Equal(t, []byte{0x12}, new.Options.Get(TestOption(1)), "renewal didn't add new option")
+ return nil
+ }
+ if err := p.c.runState(context.Background()); err != nil {
+ t.Error(err)
+ }
+ assert.Equal(t, stateBound, p.c.state, "DHCP client didn't renew")
+ assert.Equal(t, p.ft.Now().Add(leaseTime), p.c.leaseDeadline, "lease deadline not updated")
+ assert.Equal(t, dhcpv4.MessageTypeRequest, p.umt.sentPacket.MessageType(), "Invalid message type for renewal")
+}
+
+// TestRenewingRebinding tests if the DHCP state machine in renewing state correctly sends DHCPREQUESTs and transitions
+// to the rebinding state when it hasn't received a valid response until the deadline expires.
+func TestRenewingRebinding(t *testing.T) {
+ offer := newResponse(dhcpv4.MessageTypeAck)
+ testIP := net.IP{192, 0, 2, 2}
+ serverIP := net.IP{192, 0, 2, 1}
+ offer.UpdateOption(dhcpv4.OptServerIdentifier(serverIP))
+ leaseTime := 10 * time.Second
+ offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+ offer.YourIPAddr = testIP
+
+ p := newPuppetClient(stateRenewing)
+ p.umt.Open(serverIP, testIP)
+ p.c.lease, _ = dhcpv4.FromBytes(offer.ToBytes())
+ // Other deadlines are intentionally empty to make sure they aren't used
+ p.c.leaseRenewDeadline = p.ft.Now().Add(8500 * time.Millisecond)
+ p.c.leaseDeadline = p.ft.Now().Add(10000 * time.Millisecond)
+
+ startTime := p.ft.Now()
+ p.ft.Advance(5 * time.Second)
+
+ // No response is ever queued in this test, so the lease must not change and the callback must not fire.
+ p.c.LeaseCallback = func(old, new *Lease) error {
+ t.Fatal("Lease callback called without valid offer")
+ return nil
+ }
+
+ // Step the state machine up to ten times without any response until it moves on to rebinding.
+ for i := 0; i < 10; i++ {
+ p.umt.sendPackets()
+ if err := p.c.runState(context.Background()); err != nil {
+ t.Error(err)
+ }
+ assert.Equal(t, dhcpv4.MessageTypeRequest, p.umt.sentPacket.MessageType(), "Invalid message type for renewal")
+ // Move the fake clock to the receive deadline recorded by the mock transport before the next attempt.
+ p.ft.time = p.umt.setDeadline
+
+ if p.c.state == stateRebinding {
+ break
+ }
+ if i == 9 {
+ t.Fatal("Too many tries while renewing, backoff likely wrong")
+ }
+ }
+ // The last receive deadline must coincide with leaseRenewDeadline (start + 8.5s).
+ assert.Equal(t, startTime.Add(8500*time.Millisecond), p.umt.setDeadline, "wrong listen deadline when renewing")
+ assert.Equal(t, stateRebinding, p.c.state, "DHCP client not renewing")
+ // After the transition only the unicast transport is expected to be closed.
+ assert.False(t, p.bmt.closed)
+ assert.True(t, p.umt.closed)
+}
+
+// TestRebindingBound tests if the DHCP state machine in rebinding state sends DHCPREQUESTs to the network and if
+// it receives a valid DHCPACK correctly transitions back to bound state.
+func TestRebindingBound(t *testing.T) {
+ offer := newResponse(dhcpv4.MessageTypeAck)
+ testIP := net.IP{192, 0, 2, 2}
+ serverIP := net.IP{192, 0, 2, 1}
+ offer.UpdateOption(dhcpv4.OptServerIdentifier(serverIP))
+ leaseTime := 10 * time.Second
+ offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+ offer.YourIPAddr = testIP
+
+ p := newPuppetClient(stateRebinding)
+ p.c.lease, _ = dhcpv4.FromBytes(offer.ToBytes())
+ // Other deadlines are intentionally empty to make sure they aren't used
+ p.c.leaseDeadline = p.ft.Now().Add(10000 * time.Millisecond)
+
+ // First step: one second before the lease expires and with no response queued the client has to stay in
+ // rebinding state after sending its DHCPREQUEST.
+ p.ft.Advance(9 * time.Second)
+ if err := p.c.runState(context.Background()); err != nil {
+ t.Error(err)
+ }
+ assert.Equal(t, dhcpv4.MessageTypeRequest, p.bmt.sentPacket.MessageType(), "DHCP rebind sent invalid message type")
+ assert.Equal(t, stateRebinding, p.c.state, "DHCP client transferred out of rebinding state without trigger")
+ offer.UpdateOption(dhcpv4.OptGeneric(TestOption(1), []byte{0x12})) // Mark answer
+ p.bmt.sendPackets(offer)
+ // Reset the recorded packet so the assertion after the second step proves that a new request was sent.
+ p.bmt.sentPacket = nil
+ p.c.LeaseCallback = func(old, new *Lease) error {
+ assert.Equal(t, testIP, old.AssignedIP, "callback called with wrong old IP")
+ assert.Equal(t, testIP, new.AssignedIP, "callback called with wrong IP")
+ assert.Equal(t, p.ft.Now().Add(leaseTime), new.ExpiresAt, "invalid ExpiresAt")
+ assert.Empty(t, old.Options.Get(TestOption(1)), "old contains options from new")
+ assert.Equal(t, []byte{0x12}, new.Options.Get(TestOption(1)), "renewal didn't add new option")
+ return nil
+ }
+ if err := p.c.runState(context.Background()); err != nil {
+ t.Error(err)
+ }
+ assert.Equal(t, dhcpv4.MessageTypeRequest, p.bmt.sentPacket.MessageType())
+ assert.Equal(t, stateBound, p.c.state, "DHCP client didn't go back to bound")
+}
+
+// TestRebindingDiscovering tests if the DHCP state machine in rebinding state transitions to discovering state if
+// leaseDeadline expires and calls LeaseCallback with an empty new lease.
+func TestRebindingDiscovering(t *testing.T) {
+ offer := newResponse(dhcpv4.MessageTypeAck)
+ testIP := net.IP{192, 0, 2, 2}
+ serverIP := net.IP{192, 0, 2, 1}
+ offer.UpdateOption(dhcpv4.OptServerIdentifier(serverIP))
+ leaseTime := 10 * time.Second
+ offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+ offer.YourIPAddr = testIP
+
+ p := newPuppetClient(stateRebinding)
+ p.c.lease, _ = dhcpv4.FromBytes(offer.ToBytes())
+ // Other deadlines are intentionally empty to make sure they aren't used
+ p.c.leaseDeadline = p.ft.Now().Add(10000 * time.Millisecond)
+
+ p.ft.Advance(9 * time.Second)
+ // Losing the lease has to be reported as a transition from the old lease to no lease at all.
+ p.c.LeaseCallback = func(old, new *Lease) error {
+ assert.Equal(t, testIP, old.AssignedIP, "callback called with wrong old IP")
+ assert.Nil(t, new, "transition to discovering didn't clear new lease on callback")
+ return nil
+ }
+ // Step the state machine up to ten times without any response until it falls back to discovering.
+ for i := 0; i < 10; i++ {
+ p.bmt.sendPackets()
+ p.bmt.sentPacket = nil
+ if err := p.c.runState(context.Background()); err != nil {
+ t.Error(err)
+ }
+ if p.c.state == stateDiscovering {
+ // The step that gives up the lease must not have sent another packet.
+ assert.Nil(t, p.bmt.sentPacket)
+ break
+ }
+ assert.Equal(t, dhcpv4.MessageTypeRequest, p.bmt.sentPacket.MessageType(), "Invalid message type for rebind")
+ // Move the fake clock to the receive deadline recorded by the mock transport before the next attempt.
+ p.ft.time = p.bmt.setDeadline
+ if i == 9 {
+ t.Fatal("Too many tries while rebinding, backoff likely wrong")
+ }
+ }
+ assert.Nil(t, p.c.lease, "Lease not zeroed on transition to discovering")
+ assert.Equal(t, stateDiscovering, p.c.state, "DHCP client didn't transition to discovering after loosing lease")
+}
diff --git a/metropolis/node/core/network/dhcp4c/doc.go b/metropolis/node/core/network/dhcp4c/doc.go
new file mode 100644
index 0000000..b270c7b
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/doc.go
@@ -0,0 +1,53 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package dhcp4c provides a client implementation of the DHCPv4 protocol (RFC2131) and a few extensions for Linux-based
+// systems.
+// The code is split into three main parts:
+// - The core DHCP state machine, which lives in dhcpc.go
+// - Mechanisms to send and receive DHCP messages, which live in transport/
+// - Standard callbacks which implement necessary kernel configuration steps in a simple and standalone way living in
+// callback/
+//
+// Since the DHCP protocol is ugly and underspecified (see https://tools.ietf.org/html/draft-ietf-dhc-implementation-02
+// for a subset of known issues), this client slightly bends the specification in the following cases:
+// - IP fragmentation for DHCP messages is not supported for both sending and receiving messages
+// This is because the major servers (ISC, dnsmasq, ...) do not implement it and just drop fragmented packets, so it
+// would be counterproductive to try to send them. The client just attempts to send the full message and hopes it
+// passes through to the server.
+// - The suggested timeouts and wait periods have been tightened significantly. When the standard was written 10Mbps
+// Ethernet with hubs was a common interconnect. Using these would make the client extremely slow on today's
+// 1Gbps+ networks.
+// - Wrong data in DHCP responses is fixed up if possible. This fixing includes dropping prohibited options, clamping
+// semantically invalid data and defaulting not set options as far as it's possible. Non-recoverable responses
+// (for example because a non-Unicast IP is handed out or lease time is not set or zero) are still ignored.
+// All data which can be stored in both DHCP fields and options is also normalized to the corresponding option.
+// - Duplicate Address Detection is not implemented by default. It's slow, hard to implement correctly and generally
+// not necessary on modern networks as the servers already waste time checking for duplicate addresses. It's possible
+// to hook it in via a LeaseCallback if necessary in a given application.
+//
+// Operationally, there's one known caveat to using this client: If the lease offered during the select phase (in a
+// DHCPOFFER) is not the same as the one sent in the following DHCPACK the first one might be acceptable, but the second
+// one might not be. This can cause pathological behavior where the client constantly switches between discovering and
+// requesting states. Depending on the reuse policies on the DHCP server this can cause the client to consume all
+// available IP addresses. Sadly there's no good way of fixing this within the boundaries of the protocol. A DHCPRELEASE
+// for the address would need to be unicast, so the unacceptable address would need to be configured, which can be either
+// impossible if it's not valid or not acceptable from a security standpoint (for example because it overlaps with a
+// prefix used internally) and a DHCPDECLINE would cause the server to blacklist the IP thus also depleting the IP pool.
+// This could potentially be avoided by originating DHCPRELEASE packets from a userspace transport, but said transport
+// would need to be routing- and PMTU-aware which would make it even more complicated than the existing
+// BroadcastTransport.
+package dhcp4c
diff --git a/metropolis/node/core/network/dhcp4c/lease.go b/metropolis/node/core/network/dhcp4c/lease.go
new file mode 100644
index 0000000..c56270c
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/lease.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dhcp4c
+
+import (
+ "encoding/binary"
+ "net"
+ "time"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+)
+
+// Lease represents a DHCPv4 lease. It only consists of an IP, an expiration timestamp and options as all other
+// relevant parts of the message have been normalized into their respective options. It also contains some smart
+// getters for commonly-used options which extract only valid information from options. All getters are safe to
+// call on a nil *Lease.
+type Lease struct {
+ // AssignedIP is the IP address handed out by this lease.
+ AssignedIP net.IP
+ // ExpiresAt is the point in time at which this lease expires.
+ ExpiresAt time.Time
+ // Options contains the DHCP options belonging to this lease.
+ Options dhcpv4.Options
+}
+
+// SubnetMask returns the mask from the SubnetMask option. If that option is missing or does not contain a valid
+// 32 bit mask, the default mask of the assigned IP is returned instead.
+// It returns nil if the lease is nil.
+func (l *Lease) SubnetMask() net.IPMask {
+	if l == nil {
+		return nil
+	}
+	optionMask := net.IPMask(dhcpv4.GetIP(dhcpv4.OptionSubnetMask, l.Options))
+	// Size reports a total length of 32 bits only for well-formed IPv4 masks.
+	if _, totalBits := optionMask.Size(); totalBits == 32 {
+		return optionMask
+	}
+	// If given mask is not valid, use the default mask
+	return l.AssignedIP.DefaultMask()
+}
+
+// IPNet combines the assigned IP and the subnet mask (see SubnetMask) of the lease into a net.IPNet.
+// It returns nil if the lease is nil.
+func (l *Lease) IPNet() *net.IPNet {
+	if l == nil {
+		return nil
+	}
+	network := new(net.IPNet)
+	network.IP = l.AssignedIP
+	network.Mask = l.SubnetMask()
+	return network
+}
+
+// Router walks the DHCP router option in order and returns the first entry which is a global or link-local
+// unicast address. It returns nil if there is no such entry.
+// It returns nil if the lease is nil.
+func (l *Lease) Router() net.IP {
+	if l == nil {
+		return nil
+	}
+	for _, candidate := range dhcpv4.GetIPs(dhcpv4.OptionRouter, l.Options) {
+		if !candidate.IsGlobalUnicast() && !candidate.IsLinkLocalUnicast() {
+			continue // Not usable as a router, try the next one
+		}
+		return candidate
+	}
+	// No (valid) router found
+	return nil
+}
+
+// DNSServers represents an ordered collection of DNS servers.
+type DNSServers []net.IP
+
+// Equal reports whether a and b contain the same addresses in the same order. Addresses are compared using
+// net.IP.Equal. A nil collection and an empty collection are equal to each other.
+func (a DNSServers) Equal(b DNSServers) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	// Same length: compare element-wise. For empty/nil collections the loop body never runs.
+	for i := range a {
+		if !a[i].Equal(b[i]) {
+			return false
+		}
+	}
+	return true
+}
+
+// ip4toInt returns the big-endian (network byte order) integer value of an IPv4 address. It returns 0 for
+// anything that is not an IPv4 address, including nil.
+func ip4toInt(ip net.IP) uint32 {
+	if v4 := ip.To4(); v4 != nil {
+		return binary.BigEndian.Uint32(v4)
+	}
+	return 0
+}
+
+// DNSServers returns all unique valid DNS servers from the DHCP DomainNameServers options in the order in which
+// they were given. A server is valid if it is a global or link-local unicast IPv4 address; everything else
+// (unspecified, loopback, multicast and broadcast addresses) is dropped.
+// It returns nil if the lease is nil or if there is no valid server.
+func (l *Lease) DNSServers() DNSServers {
+	if l == nil {
+		return nil
+	}
+	rawServers := dhcpv4.GetIPs(dhcpv4.OptionDomainNameServer, l.Options)
+	var servers DNSServers
+	serversSeenMap := make(map[uint32]struct{}, len(rawServers))
+	for _, s := range rawServers {
+		ip4Num := ip4toInt(s)
+		// Both conditions have to hold: the address has to be a usable unicast address and it has to have a
+		// non-zero IPv4 representation, which is also what the deduplication map is keyed on. OR-ing the
+		// conditions would let every non-zero IPv4 address through, including broadcast and multicast.
+		if ip4Num == 0 || !(s.IsGlobalUnicast() || s.IsLinkLocalUnicast()) {
+			continue
+		}
+		if _, ok := serversSeenMap[ip4Num]; ok {
+			continue
+		}
+		serversSeenMap[ip4Num] = struct{}{}
+		servers = append(servers, s)
+	}
+	return servers
+}
diff --git a/metropolis/node/core/network/dhcp4c/lease_test.go b/metropolis/node/core/network/dhcp4c/lease_test.go
new file mode 100644
index 0000000..823656f
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/lease_test.go
@@ -0,0 +1,55 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dhcp4c
+
+import (
+ "net"
+ "testing"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/stretchr/testify/assert"
+)
+
+// TestLeaseDHCPServers checks that Lease.DNSServers returns nil for a nil lease, discards invalid addresses and
+// removes duplicates while keeping the order of first occurrence.
+func TestLeaseDHCPServers(t *testing.T) {
+	type testCase struct {
+		name     string
+		lease    *Lease
+		expected DNSServers
+	}
+	cases := []testCase{
+		{
+			name:     "ReturnsNilWithNoLease",
+			lease:    nil,
+			expected: nil,
+		},
+		{
+			name: "DiscardsInvalidIPs",
+			lease: &Lease{
+				Options: dhcpv4.OptionsFromList(dhcpv4.OptDNS(net.IP{0, 0, 0, 0})),
+			},
+			expected: nil,
+		},
+		{
+			name: "DeduplicatesIPs",
+			lease: &Lease{
+				Options: dhcpv4.OptionsFromList(dhcpv4.OptDNS(net.IP{192, 0, 2, 1}, net.IP{192, 0, 2, 2}, net.IP{192, 0, 2, 1})),
+			},
+			expected: DNSServers{net.IP{192, 0, 2, 1}, net.IP{192, 0, 2, 2}},
+		},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			got := tc.lease.DNSServers()
+			assert.Equal(t, tc.expected, got)
+		})
+	}
+}
diff --git a/metropolis/node/core/network/dhcp4c/transport/BUILD.bazel b/metropolis/node/core/network/dhcp4c/transport/BUILD.bazel
new file mode 100644
index 0000000..edd47a1
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/transport/BUILD.bazel
@@ -0,0 +1,20 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "transport.go",
+ "transport_broadcast.go",
+ "transport_unicast.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dhcp4c/transport",
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_github_google_gopacket//:go_default_library",
+ "@com_github_google_gopacket//layers:go_default_library",
+ "@com_github_insomniacslk_dhcp//dhcpv4:go_default_library",
+ "@com_github_mdlayher_raw//:go_default_library",
+ "@org_golang_x_net//bpf:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/network/dhcp4c/transport/transport.go b/metropolis/node/core/network/dhcp4c/transport/transport.go
new file mode 100644
index 0000000..8f5f791
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/transport/transport.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport contains Linux-based transports for the DHCP broadcast and unicast
+// specifications.
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "net"
+)
+
+// DeadlineExceededErr is the error which transports return from Receive when the underlying read timed out,
+// for example because the receive deadline has passed.
+var DeadlineExceededErr = errors.New("deadline exceeded")
+
+// NewInvalidMessageError wraps internalErr, the reason why a received packet was rejected, into an
+// *InvalidMessageError.
+func NewInvalidMessageError(internalErr error) error {
+	return &InvalidMessageError{internalErr: internalErr}
+}
+
+// InvalidMessageError is returned by transports when a packet was received but could not be accepted as a DHCP
+// message. The underlying reason is available via Unwrap and thus via errors.Is and errors.As.
+type InvalidMessageError struct {
+	internalErr error
+}
+
+// Error implements the error interface.
+func (i InvalidMessageError) Error() string {
+	// Let fmt format the wrapped error instead of calling Error() on it directly so that an
+	// InvalidMessageError without a cause does not panic with a nil dereference.
+	return fmt.Sprintf("received invalid packet: %v", i.internalErr)
+}
+
+// Unwrap returns the underlying reason why the message was considered invalid.
+func (i InvalidMessageError) Unwrap() error {
+	return i.internalErr
+}
+
+func deadlineFromTimeout(err error) error {
+ if timeoutErr, ok := err.(net.Error); ok && timeoutErr.Timeout() {
+ return DeadlineExceededErr
+ }
+ return err
+}
diff --git a/metropolis/node/core/network/dhcp4c/transport/transport_broadcast.go b/metropolis/node/core/network/dhcp4c/transport/transport_broadcast.go
new file mode 100644
index 0000000..79fad7d
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/transport/transport_broadcast.go
@@ -0,0 +1,207 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "net"
+ "time"
+
+ "github.com/google/gopacket"
+ "github.com/google/gopacket/layers"
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/mdlayher/raw"
+ "golang.org/x/net/bpf"
+)
+
+const (
+ // dscpCS7 is the DSCP codepoint for Class Selector 7, which corresponds to the old IP precedence 7.
+ // RFC2474 Section 4.2.2.1 with reference to RFC791 Section 3.1 (Network Control Precedence)
+ dscpCS7 = 0x7 << 3
+
+ // IPv4 MTU
+ maxIPv4MTU = math.MaxUint16 // IPv4 "Total Length" field is an unsigned 16 bit integer
+)
+
+// mustAssemble calls bpf.Assemble and panics if it returns an error.
+func mustAssemble(insns []bpf.Instruction) []bpf.RawInstruction {
+	assembled, asmErr := bpf.Assemble(insns)
+	if asmErr == nil {
+		return assembled
+	}
+	panic("mustAssemble failed to assemble BPF: " + asmErr.Error())
+}
+
+// BPF filter for UDP in IPv4 with destination port 68 (DHCP Client)
+//
+// This is used to make the kernel drop non-DHCP traffic for us so that we don't have to handle
+// excessive unrelated traffic flowing on a given link which might overwhelm the single-threaded
+// receiver.
+var bpfFilterInstructions = []bpf.Instruction{
+ // Check IP protocol version equals 4 (first 4 bits of the first byte)
+ // With Ethernet II framing, this is more of a sanity check. We already request the kernel
+ // to only return EtherType 0x0800 (IPv4) frames.
+ bpf.LoadAbsolute{Off: 0, Size: 1},
+ bpf.ALUOpConstant{Op: bpf.ALUOpAnd, Val: 0xf0}, // Mask out the second 4 bits (header length), keeping the version
+ bpf.JumpIf{Cond: bpf.JumpEqual, Val: 4 << 4, SkipTrue: 1},
+ bpf.RetConstant{Val: 0}, // Discard
+
+ // Check IPv4 Protocol byte (offset 9) equals UDP
+ bpf.LoadAbsolute{Off: 9, Size: 1},
+ bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(layers.IPProtocolUDP), SkipTrue: 1},
+ bpf.RetConstant{Val: 0}, // Discard
+
+ // Check if IPv4 fragment offset is all-zero (this is not a fragment)
+ bpf.LoadAbsolute{Off: 6, Size: 2},
+ bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x1fff, SkipFalse: 1}, // Skip the discard if no offset bit is set
+ bpf.RetConstant{Val: 0}, // Discard
+
+ // Load IPv4 header size from offset zero and store it into X
+ bpf.LoadMemShift{Off: 0},
+
+ // Check if UDP header destination port equals 68
+ bpf.LoadIndirect{Off: 2, Size: 2}, // Offset relative to header size in register X
+ bpf.JumpIf{Cond: bpf.JumpEqual, Val: 68, SkipTrue: 1},
+ bpf.RetConstant{Val: 0}, // Discard
+
+ // Accept packet and pass through up to the maximum IP packet size
+ bpf.RetConstant{Val: maxIPv4MTU},
+}
+
+// bpfFilter is the assembled form of bpfFilterInstructions. Assembly happens at package initialization and
+// panics on failure (see mustAssemble).
+var bpfFilter = mustAssemble(bpfFilterInstructions)
+
+// BroadcastTransport implements a DHCP transport based on a custom IP/UDP stack fulfilling the
+// specific requirements for broadcasting DHCP packets (like all-zero source address, no ARP, ...)
+type BroadcastTransport struct {
+	// rawConn is the raw socket used for sending and receiving. It is nil while the transport is closed.
+	rawConn *raw.Conn
+	// iface is the network interface this transport operates on.
+	iface *net.Interface
+}
+
+// NewBroadcastTransport returns a BroadcastTransport for the given interface. The returned transport is
+// closed; Open needs to be called before it can be used.
+func NewBroadcastTransport(iface *net.Interface) *BroadcastTransport {
+	transport := new(BroadcastTransport)
+	transport.iface = iface
+	return transport
+}
+
+// Open creates the raw socket for IPv4 frames on the transport's interface and attaches the DHCP BPF filter
+// to it. It returns an error if the transport is already open or if the socket cannot be created.
+func (t *BroadcastTransport) Open() error {
+	if t.rawConn != nil {
+		return errors.New("broadcast transport already open")
+	}
+	listenConfig := raw.Config{
+		LinuxSockDGRAM: true,
+		Filter:         bpfFilter,
+	}
+	conn, err := raw.ListenPacket(t.iface, uint16(layers.EthernetTypeIPv4), &listenConfig)
+	if err != nil {
+		return fmt.Errorf("failed to create raw listener: %w", err)
+	}
+	t.rawConn = conn
+	return nil
+}
+
+// Send wraps payload into a hand-built IPv4/UDP packet (0.0.0.0:68 to 255.255.255.255:67) and transmits it to
+// the Ethernet broadcast address. It returns an error if the transport is closed or if assembling or
+// transmitting the packet fails.
+func (t *BroadcastTransport) Send(payload *dhcpv4.DHCPv4) error {
+	if t.rawConn == nil {
+		return errors.New("broadcast transport closed")
+	}
+
+	ipHeader := &layers.IPv4{
+		Version:  4,
+		TOS:      dscpCS7 << 2, // Shift left of ECN field
+		TTL:      1,            // These packets should never be routed (their IP headers contain garbage)
+		Protocol: layers.IPProtocolUDP,
+		Flags:    layers.IPv4DontFragment, // Most DHCP servers don't support fragmented packets
+		DstIP:    net.IPv4bcast,
+		SrcIP:    net.IPv4zero,
+	}
+	udpHeader := &layers.UDP{
+		DstPort: 67,
+		SrcPort: 68,
+	}
+	// The UDP checksum covers a pseudo-header taken from the IP layer.
+	if err := udpHeader.SetNetworkLayerForChecksum(ipHeader); err != nil {
+		panic("Invalid layer stackup encountered")
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	serializeOpts := gopacket.SerializeOptions{
+		FixLengths:       true,
+		ComputeChecksums: true,
+	}
+	if err := gopacket.SerializeLayers(buffer, serializeOpts,
+		ipHeader,
+		udpHeader,
+		gopacket.Payload(payload.ToBytes())); err != nil {
+		return fmt.Errorf("failed to assemble packet: %w", err)
+	}
+
+	destination := &raw.Addr{HardwareAddr: layers.EthernetBroadcast}
+	if _, err := t.rawConn.WriteTo(buffer.Bytes(), destination); err != nil {
+		return fmt.Errorf("failed to transmit broadcast packet: %w", err)
+	}
+	return nil
+}
+
+// Receive reads a single packet from the raw socket and decodes it as a DHCPv4 message. Read timeouts are
+// returned as DeadlineExceededErr. Packets which are fragmented, not UDP, not addressed to the DHCP client
+// port or which do not contain a decodable DHCPv4 message result in an InvalidMessageError.
+func (t *BroadcastTransport) Receive() (*dhcpv4.DHCPv4, error) {
+	if t.rawConn == nil {
+		return nil, errors.New("broadcast transport closed")
+	}
+	frame := make([]byte, math.MaxUint16) // Maximum IP packet size
+	length, _, err := t.rawConn.ReadFrom(frame)
+	if err != nil {
+		return nil, deadlineFromTimeout(err)
+	}
+	decoded := gopacket.NewPacket(frame[:length], layers.LayerTypeIPv4, gopacket.Default)
+
+	rawIP := decoded.Layer(layers.LayerTypeIPv4)
+	if rawIP == nil {
+		return nil, NewInvalidMessageError(errors.New("got invalid IP packet"))
+	}
+	if ipHeader := rawIP.(*layers.IPv4); ipHeader.Flags&layers.IPv4MoreFragments != 0 {
+		return nil, NewInvalidMessageError(errors.New("got fragmented message"))
+	}
+
+	rawUDP := decoded.Layer(layers.LayerTypeUDP)
+	if rawUDP == nil {
+		return nil, NewInvalidMessageError(errors.New("got non-UDP packet"))
+	}
+	udpHeader := rawUDP.(*layers.UDP)
+	if udpHeader.DstPort != 68 {
+		return nil, NewInvalidMessageError(errors.New("message not for DHCP client port"))
+	}
+
+	message, err := dhcpv4.FromBytes(udpHeader.Payload)
+	if err != nil {
+		return nil, NewInvalidMessageError(fmt.Errorf("failed to decode DHCPv4 message: %w", err))
+	}
+	return message, nil
+}
+
+// Close closes the raw socket. Closing an already closed transport is a no-op. If closing the socket fails the
+// error is returned and the transport keeps its socket.
+func (t *BroadcastTransport) Close() error {
+	if t.rawConn != nil {
+		if err := t.rawConn.Close(); err != nil {
+			return err
+		}
+		t.rawConn = nil
+	}
+	return nil
+}
+
+// SetReceiveDeadline sets the read deadline of the raw socket, which bounds how long Receive blocks. It returns
+// an error if the transport is closed.
+func (t *BroadcastTransport) SetReceiveDeadline(deadline time.Time) error {
+	conn := t.rawConn
+	if conn == nil {
+		return errors.New("broadcast transport closed")
+	}
+	return conn.SetReadDeadline(deadline)
+}
diff --git a/metropolis/node/core/network/dhcp4c/transport/transport_unicast.go b/metropolis/node/core/network/dhcp4c/transport/transport_unicast.go
new file mode 100644
index 0000000..bf2b3a4
--- /dev/null
+++ b/metropolis/node/core/network/dhcp4c/transport/transport_unicast.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "net"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "golang.org/x/sys/unix"
+)
+
+// UnicastTransport implements a DHCP transport based on a normal Linux UDP socket with some custom
+// socket options to influence DSCP and routing.
+type UnicastTransport struct {
+	// udpConn is the UDP socket used for sending and receiving. It is nil while the transport is closed.
+	udpConn *net.UDPConn
+	// targetIP is the address of the DHCP server which messages are sent to. It is set by Open.
+	targetIP net.IP
+	// iface is the network interface the socket gets bound to.
+	iface *net.Interface
+}
+
+// NewUnicastTransport returns a UnicastTransport for the given interface. The returned transport is closed;
+// Open needs to be called before it can be used.
+func NewUnicastTransport(iface *net.Interface) *UnicastTransport {
+	transport := new(UnicastTransport)
+	transport.iface = iface
+	return transport
+}
+
+// Open creates a UDP socket which is bound to the transport's interface and to bindIP on the DHCP client port
+// (68) and which marks outgoing packets with DSCP CS7. Messages sent afterwards go to serverIP on the DHCP
+// server port. It returns an error if the transport is already open or if any step of the socket setup fails.
+//
+// bindIP is expected to be an IPv4 address; for anything else the socket gets bound to the unspecified
+// address.
+func (t *UnicastTransport) Open(serverIP, bindIP net.IP) error {
+	if t.udpConn != nil {
+		return errors.New("unicast transport already open")
+	}
+	rawFd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
+	if err != nil {
+		return fmt.Errorf("failed to get socket: %w", err)
+	}
+	// Hand the descriptor to an os.File right away so that it gets closed on every return path, including the
+	// error returns below which would otherwise leak it. net.FileConn operates on a copy of the descriptor, so
+	// closing the file on return does not affect the resulting connection.
+	filePtr := os.NewFile(uintptr(rawFd), "dhcp-udp")
+	defer filePtr.Close()
+
+	if err := unix.BindToDevice(rawFd, t.iface.Name); err != nil {
+		return fmt.Errorf("failed to bind UDP interface to device: %w", err)
+	}
+	if err := unix.SetsockoptByte(rawFd, unix.SOL_IP, unix.IP_TOS, dscpCS7<<2); err != nil {
+		return fmt.Errorf("failed to set DSCP CS7: %w", err)
+	}
+	var addr [4]byte
+	copy(addr[:], bindIP.To4())
+	if err := unix.Bind(rawFd, &unix.SockaddrInet4{Addr: addr, Port: 68}); err != nil {
+		return fmt.Errorf("failed to bind UDP unicast interface: %w", err)
+	}
+	conn, err := net.FileConn(filePtr)
+	if err != nil {
+		return fmt.Errorf("failed to initialize runtime-supported UDP connection: %w", err)
+	}
+	realConn, ok := conn.(*net.UDPConn)
+	if !ok {
+		panic("UDP socket imported into Go runtime is no longer a UDP socket")
+	}
+	t.udpConn = realConn
+	t.targetIP = serverIP
+	return nil
+}
+
+// Send transmits payload to the DHCP server port (67) of the server address given to Open. It returns an error
+// if the transport is closed or if writing to the socket fails.
+func (t *UnicastTransport) Send(payload *dhcpv4.DHCPv4) error {
+	if t.udpConn == nil {
+		return errors.New("unicast transport closed")
+	}
+	server := net.UDPAddr{
+		IP:   t.targetIP,
+		Port: 67,
+	}
+	_, _, err := t.udpConn.WriteMsgUDP(payload.ToBytes(), []byte{}, &server)
+	return err
+}
+
+// SetReceiveDeadline sets the read deadline of the UDP socket, which bounds how long Receive blocks. It returns
+// an error if the transport is closed.
+func (t *UnicastTransport) SetReceiveDeadline(deadline time.Time) error {
+	// Without this check, calling the method on a transport which has not been opened (or has been closed
+	// again) dereferences a nil connection and panics. All other methods report this as an error.
+	if t.udpConn == nil {
+		return errors.New("unicast transport closed")
+	}
+	return t.udpConn.SetReadDeadline(deadline)
+}
+
+// Receive reads a single datagram from the UDP socket and decodes it as a DHCPv4 message. Read timeouts are
+// returned as DeadlineExceededErr and datagrams which cannot be decoded result in an InvalidMessageError. It
+// returns an error if the transport is closed.
+func (t *UnicastTransport) Receive() (*dhcpv4.DHCPv4, error) {
+	if t.udpConn == nil {
+		return nil, errors.New("unicast transport closed")
+	}
+	receiveBuf := make([]byte, math.MaxUint16)
+	n, _, err := t.udpConn.ReadFromUDP(receiveBuf)
+	if err != nil {
+		return nil, deadlineFromTimeout(err)
+	}
+	// Only decode the bytes which were actually received. Passing the whole buffer would make the decoder
+	// treat the unused remainder of the buffer as part of the message.
+	msg, err := dhcpv4.FromBytes(receiveBuf[:n])
+	if err != nil {
+		return nil, NewInvalidMessageError(err)
+	}
+	return msg, nil
+}
+
+// Close closes the UDP socket and marks the transport as closed, even if closing the socket fails. Closing an
+// already closed transport or an already closed socket is not treated as an error.
+func (t *UnicastTransport) Close() error {
+	if t.udpConn == nil {
+		return nil
+	}
+	closeErr := t.udpConn.Close()
+	t.udpConn = nil
+	switch {
+	case closeErr == nil:
+		return nil
+	// TODO(lorenz): Move to net.ErrClosed once Go 1.16 lands
+	case strings.Contains(closeErr.Error(), "use of closed network connection"):
+		return nil
+	default:
+		return closeErr
+	}
+}
diff --git a/metropolis/node/core/network/dns/BUILD.bazel b/metropolis/node/core/network/dns/BUILD.bazel
new file mode 100644
index 0000000..862d4cf
--- /dev/null
+++ b/metropolis/node/core/network/dns/BUILD.bazel
@@ -0,0 +1,15 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "coredns.go",
+ "directives.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dns",
+ visibility = ["//metropolis/node:__subpackages__"], # only packages under //metropolis/node may depend on this
+ deps = [
+ "//metropolis/node/common/fileargs:go_default_library",
+ "//metropolis/node/common/supervisor:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/network/dns/coredns.go b/metropolis/node/core/network/dns/coredns.go
new file mode 100644
index 0000000..b6400f7
--- /dev/null
+++ b/metropolis/node/core/network/dns/coredns.go
@@ -0,0 +1,148 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package dns provides a DNS server using CoreDNS.
+package dns
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "os/exec"
+ "strings"
+ "sync"
+ "syscall"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/fileargs"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+)
+
+const corefileBase = `
+.:53 {
+ errors
+ hosts {
+ fallthrough
+ }
+
+ cache 30
+ loadbalance
+` // corefileBase opens the server block; makeCorefile appends the directives and the closing brace.
+
+type Service struct { // Service runs CoreDNS with a Corefile built from dynamically registered directives.
+ directiveRegistration chan *ExtraDirective // incoming registrations and cancellations, read by runRegistration
+ directives map[string]ExtraDirective // currently registered directives, keyed by ExtraDirective.ID
+ cmd *exec.Cmd // CoreDNS command created by runCoreDNS; not set by New
+ args *fileargs.FileArgs // fileargs created by runCoreDNS that hold the Corefile of cmd
+ // stateMu guards access to the directives, cmd and args fields
+ stateMu sync.Mutex
+}
+
+// New returns a CoreDNS service that has no directives registered yet.
+// Directives are registered and unregistered dynamically through directiveRegistration: send an ExtraDirective
+// on the channel to add it to the configuration, and send the message created by CancelDirective() to remove
+// it again.
+func New(directiveRegistration chan *ExtraDirective) *Service {
+	svc := new(Service)
+	svc.directives = make(map[string]ExtraDirective)
+	svc.directiveRegistration = directiveRegistration
+	return svc
+}
+
+// makeCorefile renders the Corefile: corefileBase, then every registered directive with each $FILE(<name>) macro
+// replaced by the path fargs.ArgPath returns for that file, then the brace closing the server block.
+func (s *Service) makeCorefile(fargs *fileargs.FileArgs) []byte {
+	out := bytes.NewBufferString(corefileBase)
+	for _, d := range s.directives {
+		text := d.directive
+		for name, content := range d.files {
+			text = strings.ReplaceAll(text, fmt.Sprintf("$FILE(%v)", name), fargs.ArgPath(name, content))
+		}
+		out.WriteString(text + "\n")
+	}
+	out.WriteString("\n}")
+	return out.Bytes()
+}
+
+// CancelDirective builds the removal message for d: a directive that carries only d's ID and no directive
+// text, which processRegistration treats as a request to delete the directive registered under that ID.
+func CancelDirective(d *ExtraDirective) *ExtraDirective {
+	cancel := ExtraDirective{ID: d.ID}
+	return &cancel
+}
+
+// Run runs the DNS service consisting of the CoreDNS process and the directive registration process.
+func (s *Service) Run(ctx context.Context) error {
+ supervisor.Run(ctx, "coredns", s.runCoreDNS) // NOTE(review): the value returned by supervisor.Run is dropped
+ supervisor.Run(ctx, "registration", s.runRegistration) // NOTE(review): same here
+ supervisor.Signal(ctx, supervisor.SignalHealthy)
+ supervisor.Signal(ctx, supervisor.SignalDone)
+ return nil
+}
+
+// runCoreDNS runs the CoreDNS process. The Corefile is generated from the registered directives on every start.
+func (s *Service) runCoreDNS(ctx context.Context) error {
+	s.stateMu.Lock()
+	args, err := fileargs.New()
+	if err != nil {
+		s.stateMu.Unlock()
+		return fmt.Errorf("failed to create fileargs: %w", err)
+	}
+	defer args.Close()
+	s.args = args
+
+	s.cmd = exec.CommandContext(ctx, "/kubernetes/bin/coredns",
+		args.FileOpt("-conf", "Corefile", s.makeCorefile(args)),
+	)
+
+	if err := args.Error(); err != nil { // wrap the fileargs error itself; the outer err is nil at this point
+		s.stateMu.Unlock()
+		return fmt.Errorf("failed to use fileargs: %w", err)
+	}
+
+	s.stateMu.Unlock() // released before running the command so processRegistration can take the lock meanwhile
+	return supervisor.RunCommand(ctx, s.cmd)
+}
+
+// runRegistration is the runnable that applies directive changes. Its lifecycle is separate from that of the
+// CoreDNS runnable: it processes every message arriving on directiveRegistration until ctx is done.
+func (s *Service) runRegistration(ctx context.Context) error {
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	for {
+		select {
+		case d := <-s.directiveRegistration:
+			s.processRegistration(ctx, d)
+		case <-ctx.Done():
+			return nil
+		}
+	}
+}
+
+func (s *Service) processRegistration(ctx context.Context, d *ExtraDirective) { // applies one registration message
+ s.stateMu.Lock()
+ defer s.stateMu.Unlock()
+ if d.directive == "" { // a message without directive text is a cancellation (see CancelDirective)
+ delete(s.directives, d.ID)
+ } else {
+ s.directives[d.ID] = *d
+ }
+ // If the process is not currently running we're relying on corefile regeneration on startup
+ if s.cmd != nil && s.cmd.Process != nil && s.cmd.ProcessState == nil {
+ s.args.ArgPath("Corefile", s.makeCorefile(s.args)) // presumably rewrites the Corefile on disk - confirm in fileargs
+ if err := s.cmd.Process.Signal(syscall.SIGUSR1); err != nil {
+ supervisor.Logger(ctx).Warningf("Failed to send SIGUSR1 to CoreDNS for reload: %v", err)
+ }
+ }
+}
diff --git a/metropolis/node/core/network/dns/directives.go b/metropolis/node/core/network/dns/directives.go
new file mode 100644
index 0000000..72c4f29
--- /dev/null
+++ b/metropolis/node/core/network/dns/directives.go
@@ -0,0 +1,73 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dns
+
+import (
+ "fmt"
+ "net"
+ "strings"
+)
+
+// ExtraDirective contains an additional config directive for CoreDNS.
+type ExtraDirective struct {
+ // ID is the identifier of this directive. There can only be one directive with a given ID active at once.
+ // The ID is also used to identify which directive to purge.
+ ID string
+ // directive contains a full CoreDNS directive as a string. It can also use the $FILE(<filename>) macro,
+ // which will be expanded to the path of a file from the files field. An empty directive requests removal.
+ directive string
+ // files contains additional files used in the configuration. The map key is used as the filename.
+ files map[string][]byte
+}
+
+// NewUpstreamDirective creates a "forward ." directive, without fallthrough, that sends every request not yet
+// matched to the given upstream DNS servers. Without servers the directive text is empty; the ID is never set.
+func NewUpstreamDirective(dnsServers []net.IP) *ExtraDirective {
+	if len(dnsServers) == 0 {
+		return &ExtraDirective{}
+	}
+	parts := make([]string, 0, len(dnsServers)+1)
+	parts = append(parts, "forward .")
+	for _, ip := range dnsServers {
+		parts = append(parts, ip.String())
+	}
+	return &ExtraDirective{
+		directive: strings.Join(parts, " "),
+	}
+}
+
+var kubernetesDirective = `
+kubernetes %v in-addr.arpa ip6.arpa {
+ kubeconfig $FILE(kubeconfig) default
+ pods insecure
+ fallthrough in-addr.arpa ip6.arpa
+ ttl 30
+}
+` // format string for NewKubernetesDirective: %v is the cluster domain, $FILE(kubeconfig) the kubeconfig file
+
+// NewKubernetesDirective returns the directive (ID "k8s-clusterdns") that serves "Kubernetes DNS-Based Service
+// Discovery" [1] under clusterDomain. The kubeconfig needs at least read access to services, endpoints and
+// endpointslices. [1] https://github.com/kubernetes/dns/blob/master/docs/specification.md
+func NewKubernetesDirective(clusterDomain string, kubeconfig []byte) *ExtraDirective {
+	d := new(ExtraDirective)
+	d.ID = "k8s-clusterdns"
+	d.directive = fmt.Sprintf(kubernetesDirective, clusterDomain)
+	d.files = map[string][]byte{
+		"kubeconfig": kubeconfig,
+	}
+	return d
+}
diff --git a/metropolis/node/core/network/main.go b/metropolis/node/core/network/main.go
new file mode 100644
index 0000000..29e757d
--- /dev/null
+++ b/metropolis/node/core/network/main.go
@@ -0,0 +1,260 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "sync"
+ "time"
+
+ "github.com/google/nftables"
+ "github.com/google/nftables/expr"
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/vishvananda/netlink"
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dhcp4c"
+ dhcpcb "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dhcp4c/callback"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dns"
+)
+
+const (
+ resolvConfPath = "/etc/resolv.conf" // final location of the resolver configuration
+ resolvConfSwapPath = "/etc/resolv.conf.new" // temporary file that setResolvconf renames over resolvConfPath
+)
+
+type Service struct { // Service sets up DHCP, NAT and DNS for the node's network interface.
+ config Config
+ dhcp *dhcp4c.Client // DHCP client created by useInterface
+
+ // nftConn is a shared file descriptor handle to nftables, automatically initialized on first use.
+ nftConn nftables.Conn
+ natTable *nftables.Table // "nat" table created in Run
+ natPostroutingChain *nftables.Chain // postrouting chain of natTable; useInterface adds the masquerade rule to it
+
+ // These are a temporary hack pending the removal of the GetIP interface
+ ipLock sync.Mutex // guards currentIPTmp
+ currentIPTmp net.IP // address recorded by getIPCallbackHack; nil until then
+
+ logger logtree.LeveledLogger // set by runInterfaces
+}
+
+type Config struct { // Config holds what the network Service needs from its creator.
+ CorednsRegistrationChan chan *dns.ExtraDirective // handed to dns.New; upstream DNS servers are sent on it too
+}
+
+// New creates a network Service that uses the given configuration.
+func New(config Config) *Service {
+	svc := &Service{config: config}
+	return svc
+}
+
+func setResolvconf(nameservers []net.IP, searchDomains []string) error {
+ _ = os.Mkdir("/etc", 0755)
+ newResolvConf, err := os.Create(resolvConfSwapPath)
+ if err != nil {
+ return err
+ }
+ defer newResolvConf.Close()
+ defer os.Remove(resolvConfSwapPath)
+ for _, ns := range nameservers {
+ if _, err := newResolvConf.WriteString(fmt.Sprintf("nameserver %v\n", ns)); err != nil {
+ return err
+ }
+ }
+ for _, searchDomain := range searchDomains {
+ if _, err := newResolvConf.WriteString(fmt.Sprintf("search %v", searchDomain)); err != nil {
+ return err
+ }
+ }
+ newResolvConf.Close()
+ // Atomically swap in new config
+ return unix.Rename(resolvConfSwapPath, resolvConfPath)
+}
+
+// nfifname converts an interface name into 16 bytes padded with zeroes (for nftables)
+func nfifname(n string) []byte {
+ b := make([]byte, 16)
+ copy(b, []byte(n+"\x00"))
+ return b
+}
+
+// dhcpDNSCallback is a lease callback that registers the DNS servers of the new lease as CoreDNS upstream
+// servers whenever they differ from those of the old lease.
+func (s *Service) dhcpDNSCallback(old, new *dhcp4c.Lease) error {
+	oldServers, newServers := old.DNSServers(), new.DNSServers()
+	if !newServers.Equal(oldServers) {
+		s.logger.Infof("Setting upstream DNS servers to %v", newServers)
+		s.config.CorednsRegistrationChan <- dns.NewUpstreamDirective(newServers)
+	}
+	return nil
+}
+
+// TODO(lorenz): Get rid of this once we have robust node resolution
+func (s *Service) getIPCallbackHack(old, new *dhcp4c.Lease) error { // lease callback that feeds GetIP
+ if old == nil && new != nil { // only a lease without predecessor is recorded; later address changes are not
+ s.ipLock.Lock()
+ s.currentIPTmp = new.AssignedIP
+ s.ipLock.Unlock()
+ }
+ return nil
+}
+
+// useInterface starts a DHCP client for iface under the supervisor and adds an nftables rule that masquerades
+// all traffic leaving through iface. It panics if the nftables change cannot be flushed.
+func (s *Service) useInterface(ctx context.Context, iface netlink.Link) error {
+	attrs := iface.Attrs()
+	netIface, err := net.InterfaceByIndex(attrs.Index)
+	if err != nil {
+		return fmt.Errorf("cannot create Go net.Interface from netlink.Link: %w", err)
+	}
+	s.dhcp, err = dhcp4c.NewClient(netIface)
+	if err != nil {
+		return fmt.Errorf("failed to create DHCP client on interface %v: %w", attrs.Name, err)
+	}
+	s.dhcp.VendorClassIdentifier = "com.nexantic.smalltown.v1"
+	s.dhcp.RequestedOptions = []dhcpv4.OptionCode{dhcpv4.OptionRouter, dhcpv4.OptionNameServer}
+	s.dhcp.LeaseCallback = dhcpcb.Compose(
+		dhcpcb.ManageIP(iface),
+		dhcpcb.ManageDefaultRoute(iface),
+		s.dhcpDNSCallback,
+		s.getIPCallbackHack,
+	)
+	if err := supervisor.Run(ctx, "dhcp", s.dhcp.Run); err != nil {
+		return err
+	}
+
+	// Masquerade/SNAT all traffic going out of the external interface
+	masquerade := []expr.Any{
+		&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
+		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: nfifname(attrs.Name)},
+		&expr.Masq{},
+	}
+	s.nftConn.AddRule(&nftables.Rule{Table: s.natTable, Chain: s.natPostroutingChain, Exprs: masquerade})
+
+	if err := s.nftConn.Flush(); err != nil {
+		panic(err)
+	}
+
+	return nil
+}
+
+// GetIP returns the current IP (and optionally waits for one to be assigned)
+func (s *Service) GetIP(ctx context.Context, wait bool) (*net.IP, error) {
+ for {
+ var currentIP net.IP
+ s.ipLock.Lock()
+ currentIP = s.currentIPTmp
+ s.ipLock.Unlock()
+ if currentIP == nil {
+ if !wait {
+ return nil, errors.New("no IP available")
+ }
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(1 * time.Second):
+ continue
+ }
+ }
+ return ¤tIP, nil
+ }
+}
+
+// Run sets up the nftables NAT table and postrouting chain, enables IPv4 forwarding and points resolv.conf at
+// the local CoreDNS, then starts the "dns" and "interfaces" runnables and signals healthy and done.
+func (s *Service) Run(ctx context.Context) error {
+	logger := supervisor.Logger(ctx)
+	s.natTable = s.nftConn.AddTable(&nftables.Table{Family: nftables.TableFamilyIPv4, Name: "nat"})
+	s.natPostroutingChain = s.nftConn.AddChain(&nftables.Chain{
+		Name:     "postrouting",
+		Hooknum:  nftables.ChainHookPostrouting,
+		Priority: nftables.ChainPriorityNATSource,
+		Table:    s.natTable,
+		Type:     nftables.ChainTypeNAT,
+	})
+	if err := s.nftConn.Flush(); err != nil {
+		logger.Fatalf("Failed to set up nftables base chains: %v", err)
+	}
+
+	if err := ioutil.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1\n"), 0644); err != nil {
+		logger.Fatalf("Failed to enable IPv4 forwarding: %v", err)
+	}
+
+	// We're handling all DNS requests with CoreDNS, including local ones
+	if err := setResolvconf([]net.IP{{127, 0, 0, 1}}, []string{}); err != nil {
+		logger.Fatalf("Failed to set resolv.conf: %v", err)
+	}
+
+	// Start the children only now: the interfaces runnable adds a rule to natTable/natPostroutingChain through
+	// nftConn (see useInterface), so all three have to be set up and no longer in use here before it may run.
+	dnsSvc := dns.New(s.config.CorednsRegistrationChan)
+	supervisor.Run(ctx, "dns", dnsSvc.Run)
+	supervisor.Run(ctx, "interfaces", s.runInterfaces)
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	supervisor.Signal(ctx, supervisor.SignalDone)
+	return nil
+}
+
+// runInterfaces brings up the loopback and all Ethernet links and, if exactly one Ethernet link exists,
+// configures it through useInterface. With any other number of Ethernet links nothing is configured.
+func (s *Service) runInterfaces(ctx context.Context) error {
+	s.logger = supervisor.Logger(ctx)
+	s.logger.Info("Starting network interface management")
+
+	links, err := netlink.LinkList()
+	if err != nil {
+		s.logger.Fatalf("Failed to list network links: %s", err)
+	}
+
+	var ethernetLinks []netlink.Link
+	for _, link := range links {
+		attrs := link.Attrs()
+		switch {
+		case link.Type() == "device" && len(attrs.HardwareAddr) == 6: // Ethernet
+			if attrs.Flags&net.FlagUp != net.FlagUp { // attempt to take up all Ethernet links
+				if err := netlink.LinkSetUp(link); err != nil {
+					s.logger.Warningf("Failed to bring up interface %s: %v", attrs.Name, err)
+				}
+			}
+			ethernetLinks = append(ethernetLinks, link)
+		case link.Type() == "device" && len(attrs.HardwareAddr) > 0:
+			s.logger.Infof("Ignoring non-Ethernet interface %s", attrs.Name)
+		case attrs.Name == "lo":
+			if err := netlink.LinkSetUp(link); err != nil {
+				s.logger.Errorf("Failed to bring up loopback interface: %v", err)
+			}
+		}
+	}
+	if len(ethernetLinks) != 1 {
+		s.logger.Warningf("Network service needs exactly one link, bailing")
+	} else if err := s.useInterface(ctx, ethernetLinks[0]); err != nil {
+		return fmt.Errorf("failed to bring up link %s: %w", ethernetLinks[0].Attrs().Name, err)
+	}
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	supervisor.Signal(ctx, supervisor.SignalDone)
+	return nil
+}
diff --git a/metropolis/node/core/switchroot.go b/metropolis/node/core/switchroot.go
new file mode 100644
index 0000000..5865225
--- /dev/null
+++ b/metropolis/node/core/switchroot.go
@@ -0,0 +1,213 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+)
+
+// switchRoot moves the root from initramfs into a tmpfs and re-executes /init from there.
+// This is necessary because you cannot pivot_root from an initramfs (and runsc wants to do that).
+// In the future, we should instead use something like squashfs instead of an initramfs and just nuke this.
+func switchRoot(log logtree.LeveledLogger) error {
+	// We detect the need to remount to tmpfs over env vars.
+	// The first run of /init (from initramfs) will not have this var, and will be re-exec'd from a new tmpfs root with
+	// that variable set.
+	witness := "SIGNOS_REMOUNTED"
+
+	// If the witness env var is found in the environment, it means we are ready to go.
+	environ := os.Environ()
+	for _, env := range environ {
+		if strings.HasPrefix(env, witness+"=") {
+			log.Info("Smalltown running in tmpfs root")
+			return nil
+		}
+	}
+
+	// Otherwise, we need to remount to a tmpfs.
+	environ = append(environ, witness+"=yes")
+	log.Info("Smalltown running in initramfs, remounting to tmpfs...")
+
+	// Make note of all directories we have to make and files that we have to copy.
+	paths := []string{}
+	dirs := []string{}
+	err := filepath.Walk("/", func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			return err
+		}
+		if path == "/" {
+			return nil
+		}
+		// /dev is prepopulated by the initramfs, skip that. The target root uses devtmpfs.
+		if path == "/dev" || strings.HasPrefix(path, "/dev/") {
+			return nil
+		}
+
+		if info.IsDir() {
+			dirs = append(dirs, path)
+		} else {
+			paths = append(paths, path)
+		}
+
+		return nil
+	})
+	if err != nil {
+		return fmt.Errorf("could not list root files: %w", err)
+	}
+
+	log.Info("Copying paths to tmpfs:")
+	for _, p := range paths {
+		log.Infof(" - %s", p)
+	}
+
+	// Make new root at /mnt
+	if err := os.Mkdir("/mnt", 0755); err != nil {
+		return fmt.Errorf("could not make /mnt: %w", err)
+	}
+	// And mount a tmpfs on it
+	if err := unix.Mount("tmpfs", "/mnt", "tmpfs", 0, ""); err != nil {
+		return fmt.Errorf("could not mount tmpfs on /mnt: %w", err)
+	}
+
+	// Make all directories. Since filepath.Walk is lexicographically ordered, we don't need to ensure that the parent
+	// exists.
+	for _, src := range dirs {
+		stat, err := os.Stat(src)
+		if err != nil {
+			return fmt.Errorf("Stat(%q): %w", src, err)
+		}
+		dst := "/mnt" + src
+		err = os.Mkdir(dst, stat.Mode())
+		if err != nil {
+			return fmt.Errorf("Mkdir(%q): %w", dst, err)
+		}
+	}
+
+	// Move all files over. Parent directories will exist by now.
+	for _, src := range paths {
+		stat, err := os.Stat(src)
+		if err != nil {
+			return fmt.Errorf("Stat(%q): %w", src, err)
+		}
+		dst := "/mnt" + src
+
+		// Copy file.
+		sfd, err := os.Open(src)
+		if err != nil {
+			return fmt.Errorf("Open(%q): %w", src, err)
+		}
+		dfd, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE, stat.Mode())
+		if err != nil {
+			sfd.Close()
+			return fmt.Errorf("OpenFile(%q): %w", dst, err)
+		}
+		_, err = io.Copy(dfd, sfd)
+
+		sfd.Close()
+		dfd.Close()
+		if err != nil {
+			return fmt.Errorf("Copying %q failed: %w", src, err)
+		}
+
+		// Remove the old file.
+		err = unix.Unlink(src)
+		if err != nil {
+			return fmt.Errorf("Unlink(%q): %w", src, err)
+		}
+	}
+
+	// Set up target filesystems.
+	for _, el := range []struct {
+		dir   string
+		fs    string
+		flags uintptr
+	}{
+		{"/sys", "sysfs", unix.MS_NOEXEC | unix.MS_NOSUID | unix.MS_NODEV},
+		{"/proc", "proc", unix.MS_NOEXEC | unix.MS_NOSUID | unix.MS_NODEV},
+		{"/dev", "devtmpfs", unix.MS_NOEXEC | unix.MS_NOSUID},
+		{"/dev/pts", "devpts", unix.MS_NOEXEC | unix.MS_NOSUID},
+	} {
+		if err := os.Mkdir("/mnt"+el.dir, 0755); err != nil {
+			return fmt.Errorf("could not make /mnt%s: %w", el.dir, err)
+		}
+		if err := unix.Mount(el.fs, "/mnt"+el.dir, el.fs, el.flags, ""); err != nil {
+			return fmt.Errorf("could not mount %s on /mnt%s: %w", el.fs, el.dir, err)
+		}
+	}
+
+	// Mount all available CGroups for v1 (v2 uses a single unified hierarchy and is not supported by our runtimes yet)
+	if err := unix.Mount("tmpfs", "/mnt/sys/fs/cgroup", "tmpfs", unix.MS_NOEXEC|unix.MS_NOSUID|unix.MS_NODEV, ""); err != nil {
+		panic(err)
+	}
+	cgroupsRaw, err := ioutil.ReadFile("/mnt/proc/cgroups")
+	if err != nil {
+		panic(err)
+	}
+
+	cgroupLines := strings.Split(string(cgroupsRaw), "\n")
+	for _, cgroupLine := range cgroupLines {
+		if cgroupLine == "" || strings.HasPrefix(cgroupLine, "#") {
+			continue
+		}
+		cgroupParts := strings.Split(cgroupLine, "\t")
+		cgroupName := cgroupParts[0]
+		if err := os.Mkdir("/mnt/sys/fs/cgroup/"+cgroupName, 0755); err != nil {
+			panic(err)
+		}
+		if err := unix.Mount("cgroup", "/mnt/sys/fs/cgroup/"+cgroupName, "cgroup", unix.MS_NOEXEC|unix.MS_NOSUID|unix.MS_NODEV, cgroupName); err != nil {
+			panic(err)
+		}
+	}
+
+	// Enable hierarchical memory accounting
+	useMemoryHierarchy, err := os.OpenFile("/mnt/sys/fs/cgroup/memory/memory.use_hierarchy", os.O_RDWR, 0)
+	if err != nil {
+		panic(err)
+	}
+	if _, err := useMemoryHierarchy.WriteString("1"); err != nil {
+		panic(err)
+	}
+	useMemoryHierarchy.Close()
+
+	// Chroot to new root.
+	// This is adapted from util-linux's switch_root.
+	err = os.Chdir("/mnt")
+	if err != nil {
+		return fmt.Errorf("could not chdir to /mnt: %w", err)
+	}
+	err = syscall.Mount("/mnt", "/", "", syscall.MS_MOVE, "")
+	if err != nil {
+		return fmt.Errorf("could not remount /mnt to /: %w", err)
+	}
+	err = syscall.Chroot(".")
+	if err != nil {
+		return fmt.Errorf("could not chroot to new root: %w", err)
+	}
+
+	// Re-exec into new init with new environment
+	return unix.Exec("/init", os.Args, environ)
+}
diff --git a/metropolis/node/core/tpm/BUILD.bazel b/metropolis/node/core/tpm/BUILD.bazel
new file mode 100644
index 0000000..fd42681
--- /dev/null
+++ b/metropolis/node/core/tpm/BUILD.bazel
@@ -0,0 +1,22 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "credactivation_compat.go",
+ "tpm.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/tpm",
+ visibility = ["//visibility:public"], # public for now; the commit message lists narrowing visibility as a next step
+ deps = [
+ "//metropolis/node/common/sysfs:go_default_library",
+ "//metropolis/node/core/logtree:go_default_library",
+ "@com_github_gogo_protobuf//proto:go_default_library",
+ "@com_github_google_go_tpm//tpm2:go_default_library",
+ "@com_github_google_go_tpm//tpmutil:go_default_library",
+ "@com_github_google_go_tpm_tools//proto:go_default_library",
+ "@com_github_google_go_tpm_tools//tpm2tools:go_default_library",
+ "@com_github_pkg_errors//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/tpm/credactivation_compat.go b/metropolis/node/core/tpm/credactivation_compat.go
new file mode 100644
index 0000000..039f8d5
--- /dev/null
+++ b/metropolis/node/core/tpm/credactivation_compat.go
@@ -0,0 +1,123 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tpm
+
+// This file is adapted from github.com/google/go-tpm/tpm2/credactivation which outputs broken
+// challenges for unknown reasons. They use u16 length-delimited outputs for the challenge blobs
+// which is incorrect. Rather than rewriting the routine, we only applied minimal fixes to it
+// and skip the ECC part of the issue (because we would rather trust the proprietary RSA implementation).
+//
+// TODO(lorenz): I'll eventually deal with this upstream, but for now just fix it here (it's not that
+// much code after all): https://github.com/google/go-tpm/issues/121
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/hmac"
+ "crypto/rsa"
+ "fmt"
+ "io"
+
+ "github.com/google/go-tpm/tpm2"
+ "github.com/google/go-tpm/tpmutil"
+)
+
+const (
+ labelIdentity = "IDENTITY" // OAEP label (generateRSA appends a NUL byte) used when encrypting the seed
+ labelStorage = "STORAGE" // KDFa label for deriving the symmetric key that encrypts the credential
+ labelIntegrity = "INTEGRITY" // KDFa label for deriving the HMAC key
+)
+
+// generateRSA builds the credential activation challenge that protects secret for the key named by aik. It
+// returns the packed ID object (integrity HMAC and encrypted credential) followed by the packed seed, which is
+// RSA-OAEP encrypted to pub. symBlockSize is the seed/symmetric key size in bytes, rnd the randomness source.
+func generateRSA(aik *tpm2.HashValue, pub *rsa.PublicKey, symBlockSize int, secret []byte, rnd io.Reader) ([]byte, []byte, error) {
+	newAIKHash, err := aik.Alg.HashConstructor()
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// The seed length should match the keysize used by the EKs symmetric cipher.
+	// For typical RSA EKs, this will be 128 bits (16 bytes).
+	// Spec: TCG 2.0 EK Credential Profile revision 14, section 2.1.5.1.
+	seed := make([]byte, symBlockSize)
+	if _, err := io.ReadFull(rnd, seed); err != nil {
+		return nil, nil, fmt.Errorf("generating seed: %w", err)
+	}
+
+	// Encrypt the seed value using the provided public key.
+	// See annex B, section 10.4 of the TPM specification revision 2 part 1.
+	label := append([]byte(labelIdentity), 0)
+	encSecret, err := rsa.EncryptOAEP(newAIKHash(), rnd, pub, seed, label)
+	if err != nil {
+		return nil, nil, fmt.Errorf("generating encrypted seed: %w", err)
+	}
+
+	// Generate the encrypted credential by convolving the seed with the digest of
+	// the AIK, and using the result as the key to encrypt the secret.
+	// See section 24.4 of TPM 2.0 specification, part 1.
+	aikNameEncoded, err := aik.Encode()
+	if err != nil {
+		return nil, nil, fmt.Errorf("encoding aikName: %w", err)
+	}
+	symmetricKey, err := tpm2.KDFa(aik.Alg, seed, labelStorage, aikNameEncoded, nil, len(seed)*8)
+	if err != nil {
+		return nil, nil, fmt.Errorf("generating symmetric key: %w", err)
+	}
+	c, err := aes.NewCipher(symmetricKey)
+	if err != nil {
+		return nil, nil, fmt.Errorf("symmetric cipher setup: %w", err)
+	}
+	cv, err := tpmutil.Pack(tpmutil.U16Bytes(secret))
+	if err != nil {
+		return nil, nil, fmt.Errorf("generating cv (TPM2B_Digest): %w", err)
+	}
+
+	// IV is all null bytes and has to be exactly one cipher block long, whatever the key size; sizing it by the
+	// key would make NewCFBEncrypter panic for keys that are not 16 bytes. encIdentity is the encrypted credential.
+	encIdentity := make([]byte, len(cv))
+	cipher.NewCFBEncrypter(c, make([]byte, c.BlockSize())).XORKeyStream(encIdentity, cv)
+
+	// Generate the integrity HMAC, which is used to protect the integrity of the
+	// encrypted structure.
+	// See section 24.5 of the TPM specification revision 2 part 1.
+	macKey, err := tpm2.KDFa(aik.Alg, seed, labelIntegrity, nil, nil, newAIKHash().Size()*8)
+	if err != nil {
+		return nil, nil, fmt.Errorf("generating HMAC key: %w", err)
+	}
+
+	mac := hmac.New(newAIKHash, macKey)
+	mac.Write(encIdentity)
+	mac.Write(aikNameEncoded)
+	integrityHMAC := mac.Sum(nil)
+
+	id, err := tpmutil.Pack(&tpm2.IDObject{IntegrityHMAC: integrityHMAC, EncIdentity: encIdentity})
+	if err != nil {
+		return nil, nil, fmt.Errorf("encoding IDObject: %w", err)
+	}
+
+	packedID, err := tpmutil.Pack(id)
+	if err != nil {
+		return nil, nil, fmt.Errorf("packing id: %w", err)
+	}
+	packedEncSecret, err := tpmutil.Pack(encSecret)
+	if err != nil {
+		return nil, nil, fmt.Errorf("packing encSecret: %w", err)
+	}
+
+	return packedID, packedEncSecret, nil
+}
diff --git a/metropolis/node/core/tpm/eventlog/BUILD.bazel b/metropolis/node/core/tpm/eventlog/BUILD.bazel
new file mode 100644
index 0000000..64fa1ff
--- /dev/null
+++ b/metropolis/node/core/tpm/eventlog/BUILD.bazel
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "compat.go",
+ "eventlog.go",
+ "secureboot.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/tpm/eventlog",
+ visibility = ["//visibility:public"], # public for now; the commit message lists narrowing visibility as a next step
+ deps = [
+ "//metropolis/node/core/tpm/eventlog/internal:go_default_library",
+ "@com_github_google_certificate_transparency_go//x509:go_default_library",
+ "@com_github_google_go_tpm//tpm2:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/tpm/eventlog/LICENSE-3RD-PARTY.txt b/metropolis/node/core/tpm/eventlog/LICENSE-3RD-PARTY.txt
new file mode 100644
index 0000000..2d3298c
--- /dev/null
+++ b/metropolis/node/core/tpm/eventlog/LICENSE-3RD-PARTY.txt
@@ -0,0 +1,12 @@
+Copyright 2020 Google Inc.
+Licensed under the Apache License, Version 2.0 (the "License"); you may not
+use this file except in compliance with the License. You may obtain a copy of
+the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+License for the specific language governing permissions and limitations under
+the License.
\ No newline at end of file
diff --git a/metropolis/node/core/tpm/eventlog/compat.go b/metropolis/node/core/tpm/eventlog/compat.go
new file mode 100644
index 0000000..f83972b
--- /dev/null
+++ b/metropolis/node/core/tpm/eventlog/compat.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventlog
+
+// This file contains compatibility functions for our TPM library
+
+import (
+ "crypto"
+)
+
+// ConvertRawPCRs wraps raw PCR digests as eventlog PCRs: SHA256 as algorithm, position in pcrs as index.
+func ConvertRawPCRs(pcrs [][]byte) []PCR {
+	var converted []PCR
+	for index := range pcrs {
+		converted = append(converted, PCR{DigestAlg: crypto.SHA256, Index: index, Digest: pcrs[index]})
+	}
+	return converted
+}
diff --git a/metropolis/node/core/tpm/eventlog/eventlog.go b/metropolis/node/core/tpm/eventlog/eventlog.go
new file mode 100644
index 0000000..49a8a26
--- /dev/null
+++ b/metropolis/node/core/tpm/eventlog/eventlog.go
@@ -0,0 +1,646 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Taken and pruned from go-attestation revision 2453c8f39a4ff46009f6a9db6fb7c6cca789d9a1 under Apache 2.0
+
+package eventlog
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "sort"
+
+ // Ensure hashes are available.
+ _ "crypto/sha256"
+
+ "github.com/google/go-tpm/tpm2"
+)
+
+// HashAlg identifies a hashing Algorithm. Its values are the go-tpm algorithm
+// identifiers narrowed to 8 bits.
+type HashAlg uint8
+
+// Valid hash algorithms.
+var (
+	HashSHA1   = HashAlg(tpm2.AlgSHA1)
+	HashSHA256 = HashAlg(tpm2.AlgSHA256)
+)
+
+// cryptoHash maps the algorithm to its standard library crypto.Hash, or returns
+// 0 when the algorithm is neither SHA1 nor SHA256.
+func (a HashAlg) cryptoHash() crypto.Hash {
+	if a == HashSHA1 {
+		return crypto.SHA1
+	}
+	if a == HashSHA256 {
+		return crypto.SHA256
+	}
+	return 0
+}
+
+// goTPMAlg maps the algorithm to its go-tpm identifier, or returns 0 when the
+// algorithm is neither SHA1 nor SHA256.
+func (a HashAlg) goTPMAlg() tpm2.Algorithm {
+	if a == HashSHA1 {
+		return tpm2.AlgSHA1
+	}
+	if a == HashSHA256 {
+		return tpm2.AlgSHA256
+	}
+	return 0
+}
+
+// String returns a human-friendly representation of the hash algorithm.
+// Unrecognised values are rendered as "HashAlg<N>".
+func (a HashAlg) String() string {
+	if a == HashSHA1 {
+		return "SHA1"
+	}
+	if a == HashSHA256 {
+		return "SHA256"
+	}
+	return fmt.Sprintf("HashAlg<%d>", int(a))
+}
+
+// ReplayError is returned when one or more PCRs could not be reproduced from
+// the event log. Events carries every parsed event; the indices of the PCRs
+// that failed to replay are kept internally and reported by Error.
+type ReplayError struct {
+	Events      []Event
+	invalidPCRs []int
+}
+
+// affected reports whether the given PCR index is among those that failed to
+// replay.
+func (e ReplayError) affected(pcr int) bool {
+	for i := 0; i < len(e.invalidPCRs); i++ {
+		if e.invalidPCRs[i] == pcr {
+			return true
+		}
+	}
+	return false
+}
+
+// Error returns a human-friendly description of replay failures.
+func (e ReplayError) Error() string {
+	const format = "event log failed to verify: the following registers failed to replay: %v"
+	return fmt.Sprintf(format, e.invalidPCRs)
+}
+
+// TPM algorithms. See the TPM 2.0 specification section 6.3.
+//
+// https://trustedcomputinggroup.org/wp-content/uploads/TPM-Rev-2.0-Part-2-Structures-01.38.pdf#page=42
+//
+// These are raw 16-bit algorithm identifiers, the same width as the ID field
+// of specAlgSize and the per-digest algorithm ID read by parseRawEvent2.
+const (
+ algSHA1 uint16 = 0x0004
+ algSHA256 uint16 = 0x000B
+)
+
+// EventType indicates what kind of data an event is reporting.
+type EventType uint32
+
+// Event is a single event from a TCG event log. This reports discrete items such
+// as BIOS measurements or EFI states.
+type Event struct {
+ // order of the event in the event log.
+ sequence int
+
+ // PCR index of the event.
+ Index int
+ // Type of the event.
+ Type EventType
+
+ // Data of the event. For certain kinds of events, this must match the event
+ // digest to be valid.
+ Data []byte
+ // Digest is the verified digest of the event data. While an event can have
+ // multiple for different hash values, this is the one that was matched to the
+ // PCR value. It is nil for the events attached to a ReplayError.
+ Digest []byte
+
+ // TODO(ericchiang): Provide examples or links for which event types must
+ // match their data to their digest.
+}
+
+func (e *Event) digestEquals(b []byte) error {
+ if len(e.Digest) == 0 {
+ return errors.New("no digests present")
+ }
+
+ switch len(e.Digest) {
+ case crypto.SHA256.Size():
+ s := sha256.Sum256(b)
+ if bytes.Equal(s[:], e.Digest) {
+ return nil
+ }
+ case crypto.SHA1.Size():
+ s := sha1.Sum(b)
+ if bytes.Equal(s[:], e.Digest) {
+ return nil
+ }
+ default:
+ return fmt.Errorf("cannot compare hash of length %d", len(e.Digest))
+ }
+
+ return fmt.Errorf("digest (len %d) does not match", len(e.Digest))
+}
+
+// EventLog is a parsed measurement log. This contains unverified data representing
+// boot events that must be replayed against PCR values to determine authenticity.
+type EventLog struct {
+ // Algs holds the set of algorithms that the event log uses.
+ Algs []HashAlg
+
+ // rawEvents holds the unverified events in the order they were appended,
+ // either by ParseEventLog or by a workaround injecting extra events.
+ rawEvents []rawEvent
+}
+
+// clone returns a copy of the event log whose Algs and rawEvents slices are
+// freshly allocated, so appending to the copy does not affect the original.
+// The individual raw events are copied by value, so their data and digest
+// slices are still shared with the original.
+func (e *EventLog) clone() *EventLog {
+	algs := make([]HashAlg, len(e.Algs))
+	copy(algs, e.Algs)
+
+	events := make([]rawEvent, len(e.rawEvents))
+	copy(events, e.rawEvents)
+
+	return &EventLog{Algs: algs, rawEvents: events}
+}
+
+// elWorkaround describes a fix-up that Verify tries when replaying the log as
+// parsed fails for a particular PCR.
+type elWorkaround struct {
+ // id is a human-readable name, used in error messages.
+ id string
+ // affectedPCR is the PCR index that must have failed to replay for the
+ // workaround to be attempted.
+ affectedPCR int
+ // apply modifies the given event log in place. Verify passes it a clone.
+ apply func(e *EventLog) error
+}
+
+// inject3 appends three new events into the event log, in the order given,
+// stopping at the first failure.
+func inject3(e *EventLog, pcr int, data1, data2, data3 string) error {
+	for _, payload := range []string{data1, data2, data3} {
+		if err := inject(e, pcr, payload); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// inject2 appends two new events into the event log, in the order given. The
+// second event is only injected if the first succeeded.
+func inject2(e *EventLog, pcr int, data1, data2 string) error {
+	err := inject(e, pcr, data1)
+	if err == nil {
+		err = inject(e, pcr, data2)
+	}
+	return err
+}
+
+// inject appends a new event into the event log.
+//
+// The event is tagged with the given PCR index, carries data as its payload
+// and gets one digest of data for every algorithm listed in e.Algs. Its
+// sequence number continues after the last event in the log; for an empty log
+// it is 1, the first number ParseEventLog hands out in its parsing loop.
+//
+// An error is returned, and the log is left unmodified, if one of the log's
+// algorithms has no usable hash implementation.
+func inject(e *EventLog, pcr int, data string) error {
+	payload := []byte(data)
+
+	// Guard against an empty log: indexing the last element would panic.
+	sequence := 1
+	if n := len(e.rawEvents); n > 0 {
+		sequence = e.rawEvents[n-1].sequence + 1
+	}
+
+	evt := rawEvent{
+		data:     payload,
+		index:    pcr,
+		sequence: sequence,
+	}
+	for _, alg := range e.Algs {
+		ch := alg.cryptoHash()
+		// crypto.Hash.New panics for unknown or unlinked hashes, so check first.
+		if !ch.Available() {
+			return fmt.Errorf("cannot inject event: hash algorithm %v is not available", alg)
+		}
+		h := ch.New()
+		h.Write(payload)
+		evt.digests = append(evt.digests, digest{hash: ch, data: h.Sum(nil)})
+	}
+	e.rawEvents = append(e.rawEvents, evt)
+	return nil
+}
+
+// Event payloads injected by the workarounds below.
+const (
+ ebsInvocation = "Exit Boot Services Invocation"
+ ebsSuccess = "Exit Boot Services Returned with Success"
+ ebsFailure = "Exit Boot Services Returned with Failure"
+)
+
+// eventlogWorkarounds lists the fix-ups Verify tries, in order, when PCR 5
+// fails to replay. Each one appends a different combination of Exit Boot
+// Services events for PCR 5 to the end of the log.
+var eventlogWorkarounds = []elWorkaround{
+ {
+ id: "EBS Invocation + Success",
+ affectedPCR: 5,
+ apply: func(e *EventLog) error {
+ return inject2(e, 5, ebsInvocation, ebsSuccess)
+ },
+ },
+ {
+ id: "EBS Invocation + Failure",
+ affectedPCR: 5,
+ apply: func(e *EventLog) error {
+ return inject2(e, 5, ebsInvocation, ebsFailure)
+ },
+ },
+ {
+ id: "EBS Invocation + Failure + Success",
+ affectedPCR: 5,
+ apply: func(e *EventLog) error {
+ return inject3(e, 5, ebsInvocation, ebsFailure, ebsSuccess)
+ },
+ },
+}
+
+// Verify replays the event log against a TPM's PCR values, returning the
+// events which could be matched to a provided PCR value.
+// An error is returned if the replayed digest for events with a given PCR
+// index do not match any provided value for that PCR index.
+//
+// When the plain replay fails with a ReplayError, each known workaround whose
+// PCR is among the failed ones is applied to a copy of the log and the replay
+// is retried; the first copy that verifies wins. If none does, the result of
+// the plain replay is returned.
+func (e *EventLog) Verify(pcrs []PCR) ([]Event, error) {
+	events, err := e.verify(pcrs)
+	replayErr, isReplayErr := err.(ReplayError)
+	if !isReplayErr {
+		return events, err
+	}
+
+	// TODO(jsonp): Allow workarounds to be combined.
+	for i := range eventlogWorkarounds {
+		workaround := &eventlogWorkarounds[i]
+		if !replayErr.affected(workaround.affectedPCR) {
+			continue
+		}
+		patched := e.clone()
+		if applyErr := workaround.apply(patched); applyErr != nil {
+			return nil, fmt.Errorf("failed applying workaround %q: %v", workaround.id, applyErr)
+		}
+		patchedEvents, verifyErr := patched.verify(pcrs)
+		if verifyErr == nil {
+			return patchedEvents, nil
+		}
+	}
+
+	return events, err
+}
+
+// PCR encapsulates the value of a PCR at a point in time.
+type PCR struct {
+ // Index is the number of the PCR.
+ Index int
+ // Digest is the value of the PCR.
+ Digest []byte
+ // DigestAlg is the hash algorithm Digest was computed with; replay only
+ // considers event digests of this algorithm.
+ DigestAlg crypto.Hash
+}
+
+func (e *EventLog) verify(pcrs []PCR) ([]Event, error) {
+ events, err := replayEvents(e.rawEvents, pcrs)
+ if err != nil {
+ if _, isReplayErr := err.(ReplayError); isReplayErr {
+ return nil, err
+ }
+ return nil, fmt.Errorf("pcrs failed to replay: %v", err)
+ }
+ return events, nil
+}
+
+// extend computes the value a PCR takes after event e is extended into it.
+//
+// replay is the PCR value before the event; when empty, a run of zero bytes of
+// the hash size is used instead. The first digest of e whose algorithm matches
+// pcr.DigestAlg is used. The new PCR value and that event digest are returned.
+// An error is returned if the event has no digest for the algorithm, or if the
+// digest's length differs from the length of pcr.Digest.
+func extend(pcr PCR, replay []byte, e rawEvent) (pcrDigest []byte, eventDigest []byte, err error) {
+	alg := pcr.DigestAlg
+
+	for i := range e.digests {
+		candidate := e.digests[i]
+		if candidate.hash != alg {
+			continue
+		}
+		if got, want := len(candidate.data), len(pcr.Digest); got != want {
+			return nil, nil, fmt.Errorf("digest data length (%d) doesn't match PCR digest length (%d)", got, want)
+		}
+
+		hasher := alg.New()
+		previous := replay
+		if len(previous) == 0 {
+			previous = make([]byte, alg.Size())
+		}
+		hasher.Write(previous)
+		hasher.Write(candidate.data)
+		return hasher.Sum(nil), candidate.data, nil
+	}
+	return nil, nil, fmt.Errorf("no event digest matches pcr algorithm: %v", alg)
+}
+
+// replayPCR replays the event log for a specific PCR, using pcr and
+// event digests with the algorithm in pcr. It reports false if the
+// replayed value does not match the final PCR digest, or if any event tagged
+// with that PCR cannot be extended (for instance because it does not possess
+// an event digest with the specified algorithm). A PCR with no events in the
+// log is reported as successful with no events.
+func replayPCR(rawEvents []rawEvent, pcr PCR) ([]Event, bool) {
+	var running []byte
+	var matched []Event
+
+	for i := range rawEvents {
+		raw := rawEvents[i]
+		if raw.index != pcr.Index {
+			continue
+		}
+
+		next, eventDigest, err := extend(pcr, running, raw)
+		if err != nil {
+			return nil, false
+		}
+		running = next
+		matched = append(matched, Event{
+			sequence: raw.sequence,
+			Index:    pcr.Index,
+			Type:     raw.typ,
+			Data:     raw.data,
+			Digest:   eventDigest,
+		})
+	}
+
+	if len(matched) == 0 {
+		return matched, true
+	}
+	if !bytes.Equal(running, pcr.Digest) {
+		return nil, false
+	}
+	return matched, true
+}
+
+// pcrReplayResult is the outcome of replaying the log against one provided
+// PCR value.
+type pcrReplayResult struct {
+ // events are the events matched to the PCR; only set on success.
+ events []Event
+ // successful reports whether the replay matched the provided PCR value.
+ successful bool
+}
+
+// replayEvents replays rawEvents against every provided PCR value.
+//
+// A PCR index counts as verified as soon as one of the values provided for it
+// (for example one per digest algorithm) is reproduced by the log. On success
+// the events of all verified PCRs are returned in log order. If any PCR index
+// has no successful replay, a ReplayError is returned carrying all raw events
+// (without digests) and the failed PCR indices in ascending order.
+func replayEvents(rawEvents []rawEvent, pcrs []PCR) ([]Event, error) {
+	var (
+		invalidReplays []int
+		verifiedEvents []Event
+		allPCRReplays  = map[int][]pcrReplayResult{}
+	)
+
+	// Replay the event log for every PCR and digest algorithm combination.
+	for _, pcr := range pcrs {
+		events, ok := replayPCR(rawEvents, pcr)
+		allPCRReplays[pcr.Index] = append(allPCRReplays[pcr.Index], pcrReplayResult{events, ok})
+	}
+
+	// Record PCR indices which do not have any successful replay. Record the
+	// events for a successful replay.
+pcrLoop:
+	for i, replaysForPCR := range allPCRReplays {
+		for _, replay := range replaysForPCR {
+			if replay.successful {
+				// We consider the PCR verified at this stage: The replay of values with
+				// one digest algorithm matched a provided value.
+				// As such, we save the PCR's events, and proceed to the next PCR.
+				verifiedEvents = append(verifiedEvents, replay.events...)
+				continue pcrLoop
+			}
+		}
+		invalidReplays = append(invalidReplays, i)
+	}
+
+	if len(invalidReplays) > 0 {
+		// Map iteration order is random; sort so that the error message and
+		// the recorded indices are deterministic.
+		sort.Ints(invalidReplays)
+
+		events := make([]Event, 0, len(rawEvents))
+		for _, e := range rawEvents {
+			events = append(events, Event{
+				sequence: e.sequence,
+				Index:    e.index,
+				Type:     e.typ,
+				Data:     e.data,
+			})
+		}
+		return nil, ReplayError{
+			Events:      events,
+			invalidPCRs: invalidReplays,
+		}
+	}
+
+	// Events were gathered per PCR; restore the order of the log.
+	sort.Slice(verifiedEvents, func(i int, j int) bool {
+		return verifiedEvents[i].sequence < verifiedEvents[j].sequence
+	})
+	return verifiedEvents, nil
+}
+
+// EV_NO_ACTION is a special event type that indicates information to the parser
+// instead of holding a measurement. For TPM 2.0, this event type is used to signal
+// switching from SHA1 format to a variable length digest.
+//
+// https://trustedcomputinggroup.org/wp-content/uploads/TCG_PCClientSpecPlat_TPM_2p0_1p04_pub.pdf#page=110
+const eventTypeNoAction = 0x03
+
+// ParseEventLog parses an unverified measurement log.
+//
+// The first event is always read in the SHA1-only format. If it is a
+// "no action" event, it is interpreted as the spec ID event: it determines the
+// log's digest algorithms, is left out of the parsed events, and all following
+// events are read in the crypto agile format. Otherwise the log is treated as
+// SHA1-only and the first event is kept.
+func ParseEventLog(measurementLog []byte) (*EventLog, error) {
+	var (
+		spec   *specIDEvent
+		parsed EventLog
+	)
+	buf := bytes.NewBuffer(measurementLog)
+	parse := parseRawEvent
+
+	first, err := parse(buf, spec)
+	if err != nil {
+		return nil, fmt.Errorf("parse first event: %v", err)
+	}
+
+	if first.typ != eventTypeNoAction {
+		parsed.Algs = []HashAlg{HashSHA1}
+		parsed.rawEvents = append(parsed.rawEvents, first)
+	} else {
+		if spec, err = parseSpecIDEvent(first.data); err != nil {
+			return nil, fmt.Errorf("failed to parse spec ID event: %v", err)
+		}
+		for i := range spec.algs {
+			switch tpm2.Algorithm(spec.algs[i].ID) {
+			case tpm2.AlgSHA1:
+				parsed.Algs = append(parsed.Algs, HashSHA1)
+			case tpm2.AlgSHA256:
+				parsed.Algs = append(parsed.Algs, HashSHA256)
+			}
+		}
+		if len(parsed.Algs) == 0 {
+			return nil, fmt.Errorf("measurement log didn't use sha1 or sha256 digests")
+		}
+		// Switch to parsing crypto agile events. Don't include this in the
+		// replayed events since it intentionally doesn't extend the PCRs.
+		//
+		// Note that this doesn't actually guarantee that events have SHA256
+		// digests.
+		parse = parseRawEvent2
+	}
+
+	// Number the remaining events starting at 1, in log order.
+	for sequence := 1; buf.Len() != 0; sequence++ {
+		event, err := parse(buf, spec)
+		if err != nil {
+			return nil, err
+		}
+		event.sequence = sequence
+		parsed.rawEvents = append(parsed.rawEvents, event)
+	}
+	return &parsed, nil
+}
+
+// specIDEvent holds the parts of the spec ID event that the parser keeps: the
+// digest algorithms declared for the log.
+type specIDEvent struct {
+ algs []specAlgSize
+}
+
+// specAlgSize is one algorithm entry of the spec ID event: an algorithm ID and
+// the size in bytes of that algorithm's digests.
+type specAlgSize struct {
+ ID uint16
+ Size uint16
+}
+
+// Expected values for various Spec ID Event fields.
+// https://trustedcomputinggroup.org/wp-content/uploads/EFI-Protocol-Specification-rev13-160330final.pdf#page=19
+var wantSignature = [16]byte{0x53, 0x70,
+ 0x65, 0x63, 0x20, 0x49,
+ 0x44, 0x20, 0x45, 0x76,
+ 0x65, 0x6e, 0x74, 0x30,
+ 0x33, 0x00} // "Spec ID Event03\0"
+
+// Expected specification version. parseSpecIDEvent checks the major and minor
+// version; it does not currently check the errata.
+const (
+ wantMajor = 2
+ wantMinor = 0
+ wantErrata = 0
+)
+
+// parseSpecIDEvent parses a TCG_EfiSpecIDEventStruct structure from b.
+//
+// https://trustedcomputinggroup.org/wp-content/uploads/EFI-Protocol-Specification-rev13-160330final.pdf#page=18
+//
+// The signature and the major and minor specification versions must match the
+// expected values; the errata is not checked. The bytes following the vendor
+// info size field must be exactly as many as that field announces. On success
+// the returned event lists the digest algorithms declared by the log.
+func parseSpecIDEvent(b []byte) (*specIDEvent, error) {
+	r := bytes.NewReader(b)
+	var header struct {
+		Signature     [16]byte
+		PlatformClass uint32
+		VersionMinor  uint8
+		VersionMajor  uint8
+		Errata        uint8
+		UintnSize     uint8
+		NumAlgs       uint32
+	}
+	if err := binary.Read(r, binary.LittleEndian, &header); err != nil {
+		return nil, fmt.Errorf("reading event header: %v", err)
+	}
+	if header.Signature != wantSignature {
+		return nil, fmt.Errorf("invalid spec id signature: %x", header.Signature)
+	}
+	if header.VersionMajor != wantMajor {
+		return nil, fmt.Errorf("invalid spec major version, got %02x, wanted %02x",
+			header.VersionMajor, wantMajor)
+	}
+	if header.VersionMinor != wantMinor {
+		// Report the minor version that was actually read.
+		return nil, fmt.Errorf("invalid spec minor version, got %02x, wanted %02x",
+			header.VersionMinor, wantMinor)
+	}
+
+	// TODO(ericchiang): Check errata? Or do we expect that to change in ways
+	// we're okay with?
+
+	specAlg := specAlgSize{}
+	e := specIDEvent{}
+	for i := 0; i < int(header.NumAlgs); i++ {
+		if err := binary.Read(r, binary.LittleEndian, &specAlg); err != nil {
+			return nil, fmt.Errorf("reading algorithm: %v", err)
+		}
+		e.algs = append(e.algs, specAlg)
+	}
+
+	var vendorInfoSize uint8
+	if err := binary.Read(r, binary.LittleEndian, &vendorInfoSize); err != nil {
+		return nil, fmt.Errorf("reading vender info size: %v", err)
+	}
+	if r.Len() != int(vendorInfoSize) {
+		return nil, fmt.Errorf("reading vendor info, expected %d remaining bytes, got %d", vendorInfoSize, r.Len())
+	}
+	return &e, nil
+}
+
+// digest is one digest of an event, together with the hash that produced it.
+// hash is 0 when the algorithm is not one this package maps to a crypto.Hash.
+type digest struct {
+ hash crypto.Hash
+ data []byte
+}
+
+// rawEvent is an event as read from the log, before any verification.
+type rawEvent struct {
+ // sequence is the position of the event in the log.
+ sequence int
+ // index is the PCR the event was measured into.
+ index int
+ typ EventType
+ data []byte
+ // digests holds the event's digests, one per algorithm recorded for it.
+ digests []digest
+}
+
+// TPM 1.2 event log format. See "5.1 SHA1 Event Log Entry Format"
+// https://trustedcomputinggroup.org/wp-content/uploads/EFI-Protocol-Specification-rev13-160330final.pdf#page=15
+//
+// This is the fixed-size header, read little-endian by parseRawEvent; it is
+// followed in the log by EventSize bytes of event data.
+type rawEventHeader struct {
+ PCRIndex uint32
+ Type uint32
+ Digest [20]byte
+ EventSize uint32
+}
+
+// eventSizeErr reports an event whose declared data size exceeds the number of
+// bytes left in the measurement log.
+type eventSizeErr struct {
+	eventSize uint32
+	logSize   int
+}
+
+// Error describes the declared event size and the remaining log size.
+func (e *eventSizeErr) Error() string {
+	const format = "event data size (%d bytes) is greater than remaining measurement log (%d bytes)"
+	return fmt.Sprintf(format, e.eventSize, e.logSize)
+}
+
+// parseRawEvent reads one event in the SHA1-only (TPM 1.2) format from r. The
+// specID argument is not used; it is there so that the function has the same
+// signature as parseRawEvent2.
+//
+// An error is returned if the header cannot be read, if the event declares no
+// data, or if it declares more data than is left in r.
+func parseRawEvent(r *bytes.Buffer, specID *specIDEvent) (event rawEvent, err error) {
+	var header rawEventHeader
+	if err = binary.Read(r, binary.LittleEndian, &header); err != nil {
+		return event, err
+	}
+
+	size := header.EventSize
+	if size == 0 {
+		return event, errors.New("event data size is 0")
+	}
+	if remaining := r.Len(); size > uint32(remaining) {
+		return event, &eventSizeErr{size, remaining}
+	}
+
+	payload := make([]byte, int(size))
+	if _, err := io.ReadFull(r, payload); err != nil {
+		return event, err
+	}
+
+	// This format carries exactly one digest per event, always SHA1.
+	parsed := rawEvent{
+		typ:     EventType(header.Type),
+		data:    payload,
+		index:   int(header.PCRIndex),
+		digests: []digest{{hash: crypto.SHA1, data: header.Digest[:]}},
+	}
+	return parsed, nil
+}
+
+// TPM 2.0 event log format. See "5.2 Crypto Agile Log Entry Format"
+// https://trustedcomputinggroup.org/wp-content/uploads/EFI-Protocol-Specification-rev13-160330final.pdf#page=15
+//
+// Only the fixed-size leading fields are described here; parseRawEvent2 reads
+// the digest count, the digests and the event data separately.
+type rawEvent2Header struct {
+ PCRIndex uint32
+ Type uint32
+}
+
+// parseRawEvent2 reads one event in the crypto agile (TPM 2.0) format from r.
+//
+// specID supplies the digest size for every algorithm ID that may appear in
+// the event and must not be nil. Each digest's algorithm ID must be listed in
+// specID with a non-zero size. Digests of algorithms other than SHA1 and
+// SHA256 are kept, but with a zero crypto.Hash.
+//
+// An error is returned if the event is truncated, uses an algorithm ID that
+// specID does not list, declares no data, or declares more data than is left
+// in r.
+func parseRawEvent2(r *bytes.Buffer, specID *specIDEvent) (event rawEvent, err error) {
+	if specID == nil {
+		return event, errors.New("cannot parse crypto agile event without a spec ID event")
+	}
+
+	var h rawEvent2Header
+	if err = binary.Read(r, binary.LittleEndian, &h); err != nil {
+		return event, err
+	}
+	event.typ = EventType(h.Type)
+	event.index = int(h.PCRIndex)
+
+	// parse the event digests
+	var numDigests uint32
+	if err := binary.Read(r, binary.LittleEndian, &numDigests); err != nil {
+		return event, err
+	}
+
+	for i := uint32(0); i < numDigests; i++ {
+		var algID uint16
+		if err := binary.Read(r, binary.LittleEndian, &algID); err != nil {
+			return event, err
+		}
+		var d digest
+
+		for _, alg := range specID.algs {
+			if alg.ID != algID {
+				continue
+			}
+			// Compare as int: narrowing r.Len() to 16 bits would wrap around
+			// whenever 64 KiB or more of the log remain, and wrongly report a
+			// truncated digest.
+			if r.Len() < int(alg.Size) {
+				return event, fmt.Errorf("reading digest: %v", io.ErrUnexpectedEOF)
+			}
+			d.data = make([]byte, alg.Size)
+			// Match on the full 16-bit ID so that an ID which merely shares
+			// its low byte with SHA1 or SHA256 is not mistaken for it.
+			switch alg.ID {
+			case algSHA1:
+				d.hash = crypto.SHA1
+			case algSHA256:
+				d.hash = crypto.SHA256
+			}
+		}
+		if len(d.data) == 0 {
+			return event, fmt.Errorf("unknown algorithm ID %x", algID)
+		}
+		if _, err := io.ReadFull(r, d.data); err != nil {
+			return event, err
+		}
+		event.digests = append(event.digests, d)
+	}
+
+	// parse event data
+	var eventSize uint32
+	if err = binary.Read(r, binary.LittleEndian, &eventSize); err != nil {
+		return event, err
+	}
+	if eventSize == 0 {
+		return event, errors.New("event data size is 0")
+	}
+	if eventSize > uint32(r.Len()) {
+		return event, &eventSizeErr{eventSize, r.Len()}
+	}
+	event.data = make([]byte, int(eventSize))
+	if _, err := io.ReadFull(r, event.data); err != nil {
+		return event, err
+	}
+	return event, nil
+}
diff --git a/metropolis/node/core/tpm/eventlog/internal/BUILD.bazel b/metropolis/node/core/tpm/eventlog/internal/BUILD.bazel
new file mode 100644
index 0000000..48e1e81
--- /dev/null
+++ b/metropolis/node/core/tpm/eventlog/internal/BUILD.bazel
@@ -0,0 +1,12 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Go library built from events.go. Visibility is limited to packages under
+# //metropolis/node/core/tpm/eventlog.
+go_library(
+ name = "go_default_library",
+ srcs = ["events.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/core/tpm/eventlog/internal",
+ visibility = ["//metropolis/node/core/tpm/eventlog:__subpackages__"],
+ deps = [
+ "@com_github_google_certificate_transparency_go//asn1:go_default_library",
+ "@com_github_google_certificate_transparency_go//x509:go_default_library",
+ ],
+)
diff --git a/metropolis/node/core/tpm/eventlog/internal/events.go b/metropolis/node/core/tpm/eventlog/internal/events.go
new file mode 100644
index 0000000..d9b933b
--- /dev/null
+++ b/metropolis/node/core/tpm/eventlog/internal/events.go
@@ -0,0 +1,403 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Taken from go-attestation under Apache 2.0
+package internal
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "unicode/utf16"
+
+ "github.com/google/certificate-transparency-go/asn1"
+ "github.com/google/certificate-transparency-go/x509"
+)
+
+// Upper bounds applied to length fields read from untrusted input before
+// allocating memory for them.
+const (
+ // maxNameLen is the maximum accepted length for a name field.
+ // ParseUEFIVariableData compares it against UnicodeNameLength, which is
+ // used there as a count of 16-bit units.
+ // This value should be larger than any reasonable value.
+ maxNameLen = 2048
+ // maxDataLen is the maximum size in bytes of a variable data field.
+ // This value should be larger than any reasonable value.
+ maxDataLen = 1024 * 1024 // 1 Megabyte.
+)
+
+// GUIDs representing the contents of an UEFI_SIGNATURE_LIST.
+// parseEfiSignatureList decodes lists of type certX509SigGUID and
+// hashSHA256SigGUID; the other types are recognised but rejected as unhandled.
+var (
+ hashSHA256SigGUID = efiGUID{0xc1c41626, 0x504c, 0x4092, [8]byte{0xac, 0xa9, 0x41, 0xf9, 0x36, 0x93, 0x43, 0x28}}
+ hashSHA1SigGUID = efiGUID{0x826ca512, 0xcf10, 0x4ac9, [8]byte{0xb1, 0x87, 0xbe, 0x01, 0x49, 0x66, 0x31, 0xbd}}
+ hashSHA224SigGUID = efiGUID{0x0b6e5233, 0xa65c, 0x44c9, [8]byte{0x94, 0x07, 0xd9, 0xab, 0x83, 0xbf, 0xc8, 0xbd}}
+ hashSHA384SigGUID = efiGUID{0xff3e5307, 0x9fd0, 0x48c9, [8]byte{0x85, 0xf1, 0x8a, 0xd5, 0x6c, 0x70, 0x1e, 0x01}}
+ hashSHA512SigGUID = efiGUID{0x093e0fae, 0xa6c4, 0x4f50, [8]byte{0x9f, 0x1b, 0xd4, 0x1e, 0x2b, 0x89, 0xc1, 0x9a}}
+ keyRSA2048SigGUID = efiGUID{0x3c5766e8, 0x269c, 0x4e34, [8]byte{0xaa, 0x14, 0xed, 0x77, 0x6e, 0x85, 0xb3, 0xb6}}
+ certRSA2048SHA256SigGUID = efiGUID{0xe2b36190, 0x879b, 0x4a3d, [8]byte{0xad, 0x8d, 0xf2, 0xe7, 0xbb, 0xa3, 0x27, 0x84}}
+ certRSA2048SHA1SigGUID = efiGUID{0x67f8444f, 0x8743, 0x48f1, [8]byte{0xa3, 0x28, 0x1e, 0xaa, 0xb8, 0x73, 0x60, 0x80}}
+ certX509SigGUID = efiGUID{0xa5c059a1, 0x94e4, 0x4aa7, [8]byte{0x87, 0xb5, 0xab, 0x15, 0x5c, 0x2b, 0xf0, 0x72}}
+ certHashSHA256SigGUID = efiGUID{0x3bd2a492, 0x96c0, 0x4079, [8]byte{0xb4, 0x20, 0xfc, 0xf9, 0x8e, 0xf1, 0x03, 0xed}}
+ certHashSHA384SigGUID = efiGUID{0x7076876e, 0x80c2, 0x4ee6, [8]byte{0xaa, 0xd2, 0x28, 0xb3, 0x49, 0xa6, 0x86, 0x5b}}
+ certHashSHA512SigGUID = efiGUID{0x446dbf63, 0x2502, 0x4cda, [8]byte{0xbc, 0xfa, 0x24, 0x65, 0xd2, 0xb0, 0xfe, 0x9d}}
+)
+
+// EventType describes the type of event signalled in the event log.
+type EventType uint32
+
+// BIOS Events (TCG PC Client Specific Implementation Specification for Conventional BIOS 1.21)
+// These occupy the contiguous range 0x0 to 0x12.
+const (
+ PrebootCert EventType = 0x00000000
+ PostCode EventType = 0x00000001
+ unused EventType = 0x00000002
+ NoAction EventType = 0x00000003
+ Separator EventType = 0x00000004
+ Action EventType = 0x00000005
+ EventTag EventType = 0x00000006
+ SCRTMContents EventType = 0x00000007
+ SCRTMVersion EventType = 0x00000008
+ CpuMicrocode EventType = 0x00000009
+ PlatformConfigFlags EventType = 0x0000000A
+ TableOfDevices EventType = 0x0000000B
+ CompactHash EventType = 0x0000000C
+ Ipl EventType = 0x0000000D
+ IplPartitionData EventType = 0x0000000E
+ NonhostCode EventType = 0x0000000F
+ NonhostConfig EventType = 0x00000010
+ NonhostInfo EventType = 0x00000011
+ OmitBootDeviceEvents EventType = 0x00000012
+)
+
+// EFI Events (TCG EFI Platform Specification Version 1.22)
+// Note that the values are not contiguous: 0x8000000A to 0x8000000F are not
+// assigned here.
+const (
+ EFIEventBase EventType = 0x80000000
+ EFIVariableDriverConfig EventType = 0x80000001
+ EFIVariableBoot EventType = 0x80000002
+ EFIBootServicesApplication EventType = 0x80000003
+ EFIBootServicesDriver EventType = 0x80000004
+ EFIRuntimeServicesDriver EventType = 0x80000005
+ EFIGPTEvent EventType = 0x80000006
+ EFIAction EventType = 0x80000007
+ EFIPlatformFirmwareBlob EventType = 0x80000008
+ EFIHandoffTables EventType = 0x80000009
+ EFIHCRTMEvent EventType = 0x80000010
+ EFIVariableAuthority EventType = 0x800000e0
+)
+
+// ErrSigMissingGUID is returned if an EFI_SIGNATURE_DATA structure was parsed
+// successfully, however was missing the SignatureOwner GUID. This case is
+// handled specially as a workaround for a bug relating to authority events.
+// The certificate that was parsed is still returned alongside this error.
+var ErrSigMissingGUID = errors.New("signature data was missing owner GUID")
+
+// eventTypeNames maps every known event type to its display name. It is used
+// by EventType.String, and UntrustedParseEventType treats any value that is
+// not a key of this map as unknown.
+var eventTypeNames = map[EventType]string{
+ PrebootCert: "Preboot Cert",
+ PostCode: "POST Code",
+ unused: "Unused",
+ NoAction: "No Action",
+ Separator: "Separator",
+ Action: "Action",
+ EventTag: "Event Tag",
+ SCRTMContents: "S-CRTM Contents",
+ SCRTMVersion: "S-CRTM Version",
+ CpuMicrocode: "CPU Microcode",
+ PlatformConfigFlags: "Platform Config Flags",
+ TableOfDevices: "Table of Devices",
+ CompactHash: "Compact Hash",
+ Ipl: "IPL",
+ IplPartitionData: "IPL Partition Data",
+ NonhostCode: "Non-Host Code",
+ NonhostConfig: "Non-HostConfig",
+ NonhostInfo: "Non-Host Info",
+ OmitBootDeviceEvents: "Omit Boot Device Events",
+
+ EFIEventBase: "EFI Event Base",
+ EFIVariableDriverConfig: "EFI Variable Driver Config",
+ EFIVariableBoot: "EFI Variable Boot",
+ EFIBootServicesApplication: "EFI Boot Services Application",
+ EFIBootServicesDriver: "EFI Boot Services Driver",
+ EFIRuntimeServicesDriver: "EFI Runtime Services Driver",
+ EFIGPTEvent: "EFI GPT Event",
+ EFIAction: "EFI Action",
+ EFIPlatformFirmwareBlob: "EFI Platform Firmware Blob",
+ EFIVariableAuthority: "EFI Variable Authority",
+ EFIHandoffTables: "EFI Handoff Tables",
+ EFIHCRTMEvent: "EFI H-CRTM Event",
+}
+
+// String returns the display name of a known event type, or "EventType(0x…)"
+// with the hexadecimal value for an unknown one.
+func (e EventType) String() string {
+	name, known := eventTypeNames[e]
+	if !known {
+		return fmt.Sprintf("EventType(0x%x)", uint32(e))
+	}
+	return name
+}
+
+// UntrustedParseEventType returns the event type indicated by
+// the provided value.
+//
+// An error is returned if the value lies outside both the BIOS event range
+// [0x0, 0x12] and the UEFI event range [0x80000000, 0x800000FF], or if it lies
+// inside one of them but is not a known event type.
+func UntrustedParseEventType(et uint32) (EventType, error) {
+	// "The value associated with a UEFI specific platform event type MUST be in
+	// the range between 0x80000000 and 0x800000FF, inclusive."
+	inBIOSRange := et <= 0x12
+	inEFIRange := et >= 0x80000000 && et <= 0x800000FF
+	if !inBIOSRange && !inEFIRange {
+		return EventType(0), fmt.Errorf("event type not between [0x0, 0x12] or [0x80000000, 0x800000FF]: got %#x", et)
+	}
+	if _, ok := eventTypeNames[EventType(et)]; !ok {
+		return EventType(0), fmt.Errorf("unknown event type %#x", et)
+	}
+	return EventType(et), nil
+}
+
+// efiGUID represents the EFI_GUID type.
+// See section "2.3.1 Data Types" in the specification for more information.
+// type efiGUID [16]byte
+type efiGUID struct {
+	Data1 uint32
+	Data2 uint16
+	Data3 uint16
+	Data4 [8]byte
+}
+
+// String formats the GUID as five dash-separated groups of lowercase hex
+// digits (8-4-4-4-12).
+func (d efiGUID) String() string {
+	// Lay the GUID out as 16 bytes in display order: the three integer fields
+	// big-endian, followed by Data4 as is.
+	var raw [16]byte
+	binary.BigEndian.PutUint32(raw[0:4], d.Data1)
+	binary.BigEndian.PutUint16(raw[4:6], d.Data2)
+	binary.BigEndian.PutUint16(raw[6:8], d.Data3)
+	copy(raw[8:], d.Data4[:])
+	return fmt.Sprintf("%x-%x-%x-%x-%x", raw[0:4], raw[4:6], raw[6:8], raw[8:10], raw[10:16])
+}
+
+// UEFIVariableDataHeader represents the leading fixed-size fields
+// within UEFI_VARIABLE_DATA.
+type UEFIVariableDataHeader struct {
+ VariableName efiGUID
+ // UnicodeNameLength is the number of 16-bit units in the name that follows.
+ UnicodeNameLength uint64 // uintN
+ // VariableDataLength is the number of bytes of variable data that follow
+ // the name.
+ VariableDataLength uint64 // uintN
+}
+
+// UEFIVariableData represents the UEFI_VARIABLE_DATA structure.
+type UEFIVariableData struct {
+ Header UEFIVariableDataHeader
+ // UnicodeName is the variable name as UTF-16 code units; see VarName.
+ UnicodeName []uint16
+ VariableData []byte // []int8
+}
+
+// ParseUEFIVariableData parses the data section of an event structured as
+// a UEFI variable.
+//
+// https://trustedcomputinggroup.org/wp-content/uploads/TCG_PCClient_Specific_Platform_Profile_for_TPM_2p0_1p04_PUBLIC.pdf#page=100
+//
+// The name and data lengths announced in the header are checked against
+// maxNameLen and maxDataLen before any memory is allocated for them. If
+// reading fails part-way, the partially filled structure is returned along
+// with the error.
+func ParseUEFIVariableData(r io.Reader) (ret UEFIVariableData, err error) {
+	if err = binary.Read(r, binary.LittleEndian, &ret.Header); err != nil {
+		return ret, err
+	}
+
+	nameLen := ret.Header.UnicodeNameLength
+	if nameLen > maxNameLen {
+		return UEFIVariableData{}, fmt.Errorf("unicode name too long: %d > %d", nameLen, maxNameLen)
+	}
+	ret.UnicodeName = make([]uint16, nameLen)
+	for i := range ret.UnicodeName {
+		if err = binary.Read(r, binary.LittleEndian, &ret.UnicodeName[i]); err != nil {
+			return ret, err
+		}
+	}
+
+	dataLen := ret.Header.VariableDataLength
+	if dataLen > maxDataLen {
+		return UEFIVariableData{}, fmt.Errorf("variable data too long: %d > %d", dataLen, maxDataLen)
+	}
+	ret.VariableData = make([]byte, dataLen)
+	_, err = io.ReadFull(r, ret.VariableData)
+	return ret, err
+}
+
+// VarName returns the variable's name, decoded from UTF-16.
+func (v *UEFIVariableData) VarName() string {
+	runes := utf16.Decode(v.UnicodeName)
+	return string(runes)
+}
+
+// SignatureData interprets the variable's data as an EFI signature list and
+// returns the certificates and hashes it contains.
+func (v *UEFIVariableData) SignatureData() (certs []x509.Certificate, hashes [][]byte, err error) {
+	certs, hashes, err = parseEfiSignatureList(v.VariableData)
+	return certs, hashes, err
+}
+
+// UEFIVariableAuthority describes the contents of a UEFI variable authority
+// event.
+type UEFIVariableAuthority struct {
+	Certs []x509.Certificate
+}
+
+// ParseUEFIVariableAuthority parses the data section of an event structured as
+// a UEFI variable authority.
+//
+// https://uefi.org/sites/default/files/resources/UEFI_Spec_2_8_final.pdf#page=1789
+//
+// The event is first parsed as UEFI variable data, whose payload is then
+// parsed as a single EFI signature. Any certificates obtained from the
+// signature are returned even when the signature parser also reports an error.
+func ParseUEFIVariableAuthority(r io.Reader) (UEFIVariableAuthority, error) {
+	var authority UEFIVariableAuthority
+
+	variable, err := ParseUEFIVariableData(r)
+	if err != nil {
+		return authority, err
+	}
+	authority.Certs, err = parseEfiSignature(variable.VariableData)
+	return authority, err
+}
+
+// efiSignatureData represents the EFI_SIGNATURE_DATA type.
+// See section "31.4.1 Signature Database" in the specification for more information.
+type efiSignatureData struct {
+ SignatureOwner efiGUID
+ SignatureData []byte // []int8
+}
+
+// efiSignatureListHeader represents the fixed-size leading fields of the
+// EFI_SIGNATURE_LIST type.
+// See section "31.4.1 Signature Database" in the specification for more information.
+type efiSignatureListHeader struct {
+ // SignatureType selects how the signatures in the list are interpreted.
+ SignatureType efiGUID
+ SignatureListSize uint32
+ SignatureHeaderSize uint32
+ // SignatureSize is the size of each signature entry in the list, including
+ // its 16-byte owner GUID.
+ SignatureSize uint32
+}
+
+// efiSignatureList represents the EFI_SIGNATURE_LIST type.
+type efiSignatureList struct {
+ Header efiSignatureListHeader
+ SignatureData []byte
+ Signatures []byte
+}
+
+// parseEfiSignatureList parses a EFI_SIGNATURE_LIST structure.
+// The structure and related GUIDs are defined at:
+// https://uefi.org/sites/default/files/resources/UEFI_Spec_2_8_final.pdf#page=1790
+func parseEfiSignatureList(b []byte) ([]x509.Certificate, [][]byte, error) {
+ if len(b) < 28 {
+ // Being passed an empty signature list here appears to be valid
+ return nil, nil, nil
+ }
+ signatures := efiSignatureList{}
+ buf := bytes.NewReader(b)
+ certificates := []x509.Certificate{}
+ hashes := [][]byte{}
+
+ for buf.Len() > 0 {
+ err := binary.Read(buf, binary.LittleEndian, &signatures.Header)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if signatures.Header.SignatureHeaderSize > maxDataLen {
+ return nil, nil, fmt.Errorf("signature header too large: %d > %d", signatures.Header.SignatureHeaderSize, maxDataLen)
+ }
+ if signatures.Header.SignatureListSize > maxDataLen {
+ return nil, nil, fmt.Errorf("signature list too large: %d > %d", signatures.Header.SignatureListSize, maxDataLen)
+ }
+
+ signatureType := signatures.Header.SignatureType
+ switch signatureType {
+ case certX509SigGUID: // X509 certificate
+ for sigOffset := 0; uint32(sigOffset) < signatures.Header.SignatureListSize-28; {
+ signature := efiSignatureData{}
+ signature.SignatureData = make([]byte, signatures.Header.SignatureSize-16)
+ err := binary.Read(buf, binary.LittleEndian, &signature.SignatureOwner)
+ if err != nil {
+ return nil, nil, err
+ }
+ err = binary.Read(buf, binary.LittleEndian, &signature.SignatureData)
+ if err != nil {
+ return nil, nil, err
+ }
+ cert, err := x509.ParseCertificate(signature.SignatureData)
+ if err != nil {
+ return nil, nil, err
+ }
+ sigOffset += int(signatures.Header.SignatureSize)
+ certificates = append(certificates, *cert)
+ }
+ case hashSHA256SigGUID: // SHA256
+ for sigOffset := 0; uint32(sigOffset) < signatures.Header.SignatureListSize-28; {
+ signature := efiSignatureData{}
+ signature.SignatureData = make([]byte, signatures.Header.SignatureSize-16)
+ err := binary.Read(buf, binary.LittleEndian, &signature.SignatureOwner)
+ if err != nil {
+ return nil, nil, err
+ }
+ err = binary.Read(buf, binary.LittleEndian, &signature.SignatureData)
+ if err != nil {
+ return nil, nil, err
+ }
+ hashes = append(hashes, signature.SignatureData)
+ sigOffset += int(signatures.Header.SignatureSize)
+ }
+ case keyRSA2048SigGUID:
+ err = errors.New("unhandled RSA2048 key")
+ case certRSA2048SHA256SigGUID:
+ err = errors.New("unhandled RSA2048-SHA256 key")
+ case hashSHA1SigGUID:
+ err = errors.New("unhandled SHA1 hash")
+ case certRSA2048SHA1SigGUID:
+ err = errors.New("unhandled RSA2048-SHA1 key")
+ case hashSHA224SigGUID:
+ err = errors.New("unhandled SHA224 hash")
+ case hashSHA384SigGUID:
+ err = errors.New("unhandled SHA384 hash")
+ case hashSHA512SigGUID:
+ err = errors.New("unhandled SHA512 hash")
+ case certHashSHA256SigGUID:
+ err = errors.New("unhandled X509-SHA256 hash metadata")
+ case certHashSHA384SigGUID:
+ err = errors.New("unhandled X509-SHA384 hash metadata")
+ case certHashSHA512SigGUID:
+ err = errors.New("unhandled X509-SHA512 hash metadata")
+ default:
+ err = fmt.Errorf("unhandled signature type %s", signatureType)
+ }
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ return certificates, hashes, nil
+}
+
+// EFISignatureData represents the EFI_SIGNATURE_DATA type.
+// See section "31.4.1 Signature Database" in the specification
+// for more information.
+//
+// It has the same fields as the unexported efiSignatureData.
+type EFISignatureData struct {
+ SignatureOwner efiGUID
+ SignatureData []byte // []int8
+}
+
+func parseEfiSignature(b []byte) ([]x509.Certificate, error) {
+ certificates := []x509.Certificate{}
+
+ if len(b) < 16 {
+ return nil, fmt.Errorf("invalid signature: buffer smaller than header (%d < %d)", len(b), 16)
+ }
+
+ buf := bytes.NewReader(b)
+ signature := EFISignatureData{}
+ signature.SignatureData = make([]byte, len(b)-16)
+
+ if err := binary.Read(buf, binary.LittleEndian, &signature.SignatureOwner); err != nil {
+ return certificates, err
+ }
+ if err := binary.Read(buf, binary.LittleEndian, &signature.SignatureData); err != nil {
+ return certificates, err
+ }
+
+ cert, err := x509.ParseCertificate(signature.SignatureData)
+ if err == nil {
+ certificates = append(certificates, *cert)
+ } else {
+ // A bug in shim may cause an event to be missing the SignatureOwner GUID.
+ // We handle this, but signal back to the caller using ErrSigMissingGUID.
+ if _, isStructuralErr := err.(asn1.StructuralError); isStructuralErr {
+ var err2 error
+ cert, err2 = x509.ParseCertificate(b)
+ if err2 == nil {
+ certificates = append(certificates, *cert)
+ err = ErrSigMissingGUID
+ }
+ }
+ }
+ return certificates, err
+}
diff --git a/metropolis/node/core/tpm/eventlog/secureboot.go b/metropolis/node/core/tpm/eventlog/secureboot.go
new file mode 100644
index 0000000..f117d30
--- /dev/null
+++ b/metropolis/node/core/tpm/eventlog/secureboot.go
@@ -0,0 +1,210 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Taken and pruned from go-attestation under Apache 2.0
+package eventlog
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+
+ "github.com/google/certificate-transparency-go/x509"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/tpm/eventlog/internal"
+)
+
+// SecurebootState describes the secure boot status of a machine, as determined
+// by processing its event log.
+type SecurebootState struct {
+ // Enabled reports whether the SecureBoot variable was measured with value 1.
+ Enabled bool
+
+ // PlatformKeys enumerates keys which can sign a key exchange key.
+ PlatformKeys []x509.Certificate
+ // PlatformKeyHashes enumerates key hashes which can sign a key exchange key.
+ PlatformKeyHashes [][]byte
+
+ // ExchangeKeys enumerates keys which can sign a database of permitted or
+ // forbidden keys.
+ ExchangeKeys []x509.Certificate
+ // ExchangeKeyHashes enumerates key hashes which can sign a database of
+ // permitted or forbidden keys.
+ ExchangeKeyHashes [][]byte
+
+ // PermittedKeys enumerates keys which may sign binaries to run.
+ PermittedKeys []x509.Certificate
+ // PermittedHashes enumerates hashes which permit binaries to run.
+ PermittedHashes [][]byte
+
+ // ForbiddenKeys enumerates keys which must not permit a binary to run.
+ ForbiddenKeys []x509.Certificate
+ // ForbiddenHashes enumerates hashes which must not permit a binary to run.
+ ForbiddenHashes [][]byte
+
+ // PreSeparatorAuthority describes the use of a secure-boot key to authorize
+ // the execution of a binary before the separator.
+ PreSeparatorAuthority []x509.Certificate
+ // PostSeparatorAuthority describes the use of a secure-boot key to authorize
+ // the execution of a binary after the separator.
+ PostSeparatorAuthority []x509.Certificate
+}
+
+// ParseSecurebootState parses a series of events to determine the
+// configuration of secure boot on a device. An error is returned if
+// the state cannot be determined, or if the event log is structured
+// in such a way that it may have been tampered post-execution of
+// platform firmware.
+//
+// Only events with Index 7 are inspected; all other events are ignored.
+// When the SecureBoot variable is not 1 the state is returned without the
+// final key/authority presence checks.
+func ParseSecurebootState(events []Event) (*SecurebootState, error) {
+ // This algorithm verifies the following:
+ // - All events in PCR 7 have event types which are expected in PCR 7.
+ // - All events are parsable according to their event type.
+ // - All events have digests values corresponding to their data/event type.
+ // - No unverifiable events were present.
+ // - All variables are specified before the separator and never duplicated.
+ // - The SecureBoot variable has a value of 0 or 1.
+ // - If SecureBoot was 1 (enabled), authority events were present indicating
+ // keys were used to perform verification.
+ // - If SecureBoot was 1 (enabled), platform + exchange + database keys
+ // were specified.
+ // - No UEFI debugger was attached.
+
+ var (
+ out SecurebootState
+ seenSeparator bool
+ seenAuthority bool
+ seenVars = map[string]bool{}
+ )
+
+ for _, e := range events {
+ // Events for PCRs other than 7 are not relevant here.
+ if e.Index != 7 {
+ continue
+ }
+
+ et, err := internal.UntrustedParseEventType(uint32(e.Type))
+ if err != nil {
+ return nil, fmt.Errorf("unrecognised event type: %v", err)
+ }
+
+ // A non-nil digestVerify is treated as a digest mismatch by the cases below.
+ digestVerify := e.digestEquals(e.Data)
+ switch et {
+ case internal.Separator:
+ if seenSeparator {
+ return nil, fmt.Errorf("duplicate separator at event %d", e.sequence)
+ }
+ seenSeparator = true
+ // The only accepted separator payload is four zero bytes.
+ if !bytes.Equal(e.Data, []byte{0, 0, 0, 0}) {
+ return nil, fmt.Errorf("invalid separator data at event %d: %v", e.sequence, e.Data)
+ }
+ if digestVerify != nil {
+ return nil, fmt.Errorf("invalid separator digest at event %d: %v", e.sequence, digestVerify)
+ }
+
+ case internal.EFIAction:
+ // Every EFI action event in PCR 7 is rejected; debug mode only gets a
+ // more specific error message.
+ if string(e.Data) == "UEFI Debug Mode" {
+ return nil, errors.New("a UEFI debugger was present during boot")
+ }
+ return nil, fmt.Errorf("event %d: unexpected EFI action event", e.sequence)
+
+ case internal.EFIVariableDriverConfig:
+ v, err := internal.ParseUEFIVariableData(bytes.NewReader(e.Data))
+ if err != nil {
+ return nil, fmt.Errorf("failed parsing EFI variable at event %d: %v", e.sequence, err)
+ }
+ // Each variable name may be measured at most once, and only before the separator.
+ if _, seenBefore := seenVars[v.VarName()]; seenBefore {
+ return nil, fmt.Errorf("duplicate EFI variable %q at event %d", v.VarName(), e.sequence)
+ }
+ seenVars[v.VarName()] = true
+ if seenSeparator {
+ return nil, fmt.Errorf("event %d: variable %q specified after separator", e.sequence, v.VarName())
+ }
+
+ if digestVerify != nil {
+ return nil, fmt.Errorf("invalid digest for variable %q on event %d: %v", v.VarName(), e.sequence, digestVerify)
+ }
+
+ // Variables with any other name are validated above but otherwise ignored.
+ switch v.VarName() {
+ case "SecureBoot":
+ if len(v.VariableData) != 1 {
+ return nil, fmt.Errorf("event %d: SecureBoot data len is %d, expected 1", e.sequence, len(v.VariableData))
+ }
+ out.Enabled = v.VariableData[0] == 1
+ case "PK":
+ if out.PlatformKeys, out.PlatformKeyHashes, err = v.SignatureData(); err != nil {
+ return nil, fmt.Errorf("event %d: failed parsing platform keys: %v", e.sequence, err)
+ }
+ case "KEK":
+ if out.ExchangeKeys, out.ExchangeKeyHashes, err = v.SignatureData(); err != nil {
+ return nil, fmt.Errorf("event %d: failed parsing key exchange keys: %v", e.sequence, err)
+ }
+ case "db":
+ if out.PermittedKeys, out.PermittedHashes, err = v.SignatureData(); err != nil {
+ return nil, fmt.Errorf("event %d: failed parsing signature database: %v", e.sequence, err)
+ }
+ case "dbx":
+ if out.ForbiddenKeys, out.ForbiddenHashes, err = v.SignatureData(); err != nil {
+ return nil, fmt.Errorf("event %d: failed parsing forbidden signature database: %v", e.sequence, err)
+ }
+ }
+
+ case internal.EFIVariableAuthority:
+ a, err := internal.ParseUEFIVariableAuthority(bytes.NewReader(e.Data))
+ if err != nil {
+ // Workaround for: https://github.com/google/go-attestation/issues/157
+ if err == internal.ErrSigMissingGUID {
+ // Versions of shim which do not carry
+ // https://github.com/rhboot/shim/commit/8a27a4809a6a2b40fb6a4049071bf96d6ad71b50
+ // have an erroneous additional byte in the event, which breaks digest
+ // verification. If verification failed, we try removing the last byte.
+ if digestVerify != nil {
+ digestVerify = e.digestEquals(e.Data[:len(e.Data)-1])
+ }
+ } else {
+ return nil, fmt.Errorf("failed parsing EFI variable authority at event %d: %v", e.sequence, err)
+ }
+ }
+ // NOTE(review): on ErrSigMissingGUID the value of a is still used below;
+ // this presumes the parser returns usable certificates alongside that error - confirm.
+ seenAuthority = true
+ if digestVerify != nil {
+ return nil, fmt.Errorf("invalid digest for authority on event %d: %v", e.sequence, digestVerify)
+ }
+ if !seenSeparator {
+ out.PreSeparatorAuthority = append(out.PreSeparatorAuthority, a.Certs...)
+ } else {
+ out.PostSeparatorAuthority = append(out.PostSeparatorAuthority, a.Certs...)
+ }
+
+ default:
+ return nil, fmt.Errorf("unexpected event type: %v", et)
+ }
+ }
+
+ // With secure boot disabled none of the checks below apply.
+ if !out.Enabled {
+ return &out, nil
+ }
+
+ if !seenAuthority {
+ return nil, errors.New("secure boot was enabled but no key was used")
+ }
+ if len(out.PlatformKeys) == 0 && len(out.PlatformKeyHashes) == 0 {
+ return nil, errors.New("secure boot was enabled but no platform keys were known")
+ }
+ if len(out.ExchangeKeys) == 0 && len(out.ExchangeKeyHashes) == 0 {
+ return nil, errors.New("secure boot was enabled but no key exchange keys were known")
+ }
+ if len(out.PermittedKeys) == 0 && len(out.PermittedHashes) == 0 {
+ return nil, errors.New("secure boot was enabled but no keys or hashes were permitted")
+ }
+ return &out, nil
+}
diff --git a/metropolis/node/core/tpm/tpm.go b/metropolis/node/core/tpm/tpm.go
new file mode 100644
index 0000000..76f4f92
--- /dev/null
+++ b/metropolis/node/core/tpm/tpm.go
@@ -0,0 +1,561 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tpm
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gogo/protobuf/proto"
+ tpmpb "github.com/google/go-tpm-tools/proto"
+ "github.com/google/go-tpm-tools/tpm2tools"
+ "github.com/google/go-tpm/tpm2"
+ "github.com/google/go-tpm/tpmutil"
+ "github.com/pkg/errors"
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/sysfs"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+)
+
+var (
+ // SecureBootPCRs are all PCRs that measure the current Secure Boot configuration.
+ // This is what we want if we rely on secure boot to verify boot integrity. The firmware
+ // hashes the secure boot policy and custom keys into the PCR.
+ //
+ // This requires an extra step that provisions the custom keys.
+ //
+ // Some background: https://mjg59.dreamwidth.org/48897.html?thread=1847297
+ // (the initramfs issue mentioned in the article has been solved by integrating
+ // it into the kernel binary, and we don't have a shim bootloader)
+ //
+ // PCR7 alone is not sufficient - it needs to be combined with firmware measurements.
+ SecureBootPCRs = []int{7}
+
+ // FirmwarePCRs are all PCRs that contain the firmware measurements
+ // See https://trustedcomputinggroup.org/wp-content/uploads/TCG_EFI_Platform_1_22_Final_-v15.pdf
+ FirmwarePCRs = []int{
+ 0, // platform firmware
+ 2, // option ROM code
+ 3, // option ROM configuration and data
+ }
+
+ // FullSystemPCRs are all PCRs that contain any measurements up to the currently running EFI payload.
+ FullSystemPCRs = []int{
+ 0, // platform firmware
+ 1, // host platform configuration
+ 2, // option ROM code
+ 3, // option ROM configuration and data
+ 4, // EFI payload
+ }
+
+ // Using FullSystemPCRs is the most secure, but also the most brittle option since updating the EFI
+ // binary, updating the platform firmware, changing platform settings or updating the binary
+ // would invalidate the sealed data. It's annoying (but possible) to predict values for PCR4,
+ // and even more annoying for the firmware PCR (comparison to known values on similar hardware
+ // is the only thing that comes to mind).
+ //
+ // See also: https://github.com/mxre/sealkey (generates PCR4 from EFI image, BSD license)
+ //
+ // Using only SecureBootPCRs is the easiest and still reasonably secure, if we assume that the
+ // platform knows how to take care of itself (i.e. Intel Boot Guard), and that secure boot
+ // is implemented properly. It is, however, a much larger amount of code we need to trust.
+ //
+ // We do not care about PCR 5 (GPT partition table) since modifying it is harmless. All of
+ // the boot options and cmdline are hardcoded in the kernel image, and we use no bootloader,
+ // so there's no PCR for bootloader configuration or kernel cmdline.
+)
+
+var (
+ // numSRTMPCRs is the number of PCRs (indices 0 to numSRTMPCRs-1) read by GetPCRs.
+ numSRTMPCRs = 16
+ // srtmPCRs selects PCRs 0-15 of the SHA256 bank; AttestPlatform quotes this selection.
+ srtmPCRs = tpm2.PCRSelection{Hash: tpm2.AlgSHA256, PCRs: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}
+ // TCG Trusted Platform Module Library Level 00 Revision 0.99 Table 6
+ // VerifyAttestPlatform compares the Magic field of decoded attestation data against this value.
+ tpmGeneratedValue = uint32(0xff544347)
+)
+
+var (
+ // ErrNotExists is returned when no TPMs are available in the system
+ ErrNotExists = errors.New("no TPMs found")
+ // ErrNotInitialized is returned when this package was not initialized successfully
+ ErrNotInitialized = errors.New("no TPM was initialized")
+)
+
+// tpm is the package-wide TPM singleton. It is nil until Initialize succeeds;
+// the exported functions that need it return ErrNotInitialized while it is nil.
+var tpm *TPM
+
+// We're serializing all TPM operations since it has a limited number of handles and recovering
+// if it runs out is difficult to implement correctly. Might also be marginally more secure.
+var lock sync.Mutex
+
+// TPM represents a high-level interface to a connected TPM 2.0
+type TPM struct {
+ // logger is the logger passed to Initialize.
+ logger logtree.LeveledLogger
+ // device is the opened TPM device that all commands are sent to.
+ device io.ReadWriteCloser
+
+ // We keep the AK loaded since it's used fairly often and deriving it is expensive
+ // akHandleCache stays at the zero handle until loadAK has been run.
+ akHandleCache tpmutil.Handle
+ // akPublicKey is set by loadAK together with akHandleCache.
+ akPublicKey crypto.PublicKey
+}
+
+// Initialize finds and opens the TPM (if any). If there is no TPM available it returns
+// ErrNotExists.
+//
+// It enumerates /sys/class/tpm, reads the major/minor numbers of the first device from
+// its uevent file, creates the /dev/tpm character device node and opens it. On success
+// the package-level TPM singleton is set up with the given logger.
+func Initialize(logger logtree.LeveledLogger) error {
+	lock.Lock()
+	defer lock.Unlock()
+	tpmDir, err := os.Open("/sys/class/tpm")
+	if err != nil {
+		return errors.Wrap(err, "failed to open sysfs TPM class")
+	}
+	defer tpmDir.Close()
+
+	tpms, err := tpmDir.Readdirnames(2)
+	// With a positive count, Readdirnames reports an empty directory as io.EOF
+	// rather than as an empty result, so that case must not be treated as a failure.
+	if err != nil && err != io.EOF {
+		return errors.Wrap(err, "failed to read TPM device class")
+	}
+
+	if len(tpms) == 0 {
+		return ErrNotExists
+	}
+	if len(tpms) > 1 {
+		// If this is changed GetMeasurementLog() needs to be updated too
+		logger.Warningf("Found more than one TPM, using the first one")
+	}
+	tpmName := tpms[0]
+	ueventData, err := sysfs.ReadUevents(filepath.Join("/sys/class/tpm", tpmName, "uevent"))
+	if err != nil {
+		// Previously this error was overwritten by the Atoi call below, hiding the real cause.
+		return errors.Wrap(err, "failed to read TPM uevent")
+	}
+	majorDev, err := strconv.Atoi(ueventData["MAJOR"])
+	if err != nil {
+		return fmt.Errorf("failed to convert uevent: %w", err)
+	}
+	minorDev, err := strconv.Atoi(ueventData["MINOR"])
+	if err != nil {
+		return fmt.Errorf("failed to convert uevent: %w", err)
+	}
+	if err := unix.Mknod("/dev/tpm", 0600|unix.S_IFCHR, int(unix.Mkdev(uint32(majorDev), uint32(minorDev)))); err != nil {
+		return errors.Wrap(err, "failed to create TPM device node")
+	}
+	device, err := tpm2.OpenTPM("/dev/tpm")
+	if err != nil {
+		return errors.Wrap(err, "failed to open TPM")
+	}
+	tpm = &TPM{
+		device: device,
+		logger: logger,
+	}
+	return nil
+}
+
+// GenerateSafeKey uses two sources of randomness (Kernel & TPM) to generate the key.
+//
+// size bytes are read from crypto/rand and size bytes are requested from the TPM; the
+// returned key is the byte-wise XOR of both. The TPM may return fewer bytes than asked
+// for, so it is queried up to 48 times until enough bytes have been collected.
+// Returns ErrNotInitialized if the package has not been initialized.
+func GenerateSafeKey(size uint16) ([]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, ErrNotInitialized
+	}
+	encryptionKeyHost := make([]byte, size)
+	if _, err := io.ReadFull(rand.Reader, encryptionKeyHost); err != nil {
+		return []byte{}, errors.Wrap(err, "failed to generate host portion of new key")
+	}
+
+	const maxTPMReads = 48
+	var encryptionKeyTPM []byte
+	for attempt := 0; attempt < maxTPMReads; attempt++ {
+		// Only ask for the bytes that are still missing.
+		tpmKeyPart, err := tpm2.GetRandom(tpm.device, size-uint16(len(encryptionKeyTPM)))
+		if err != nil {
+			return []byte{}, errors.Wrap(err, "failed to generate TPM portion of new key")
+		}
+		encryptionKeyTPM = append(encryptionKeyTPM, tpmKeyPart...)
+		if len(encryptionKeyTPM) >= int(size) {
+			break
+		}
+	}
+
+	if len(encryptionKeyTPM) != int(size) {
+		return []byte{}, fmt.Errorf("got incorrect amount of TPM randomness: %v, requested %v", len(encryptionKeyTPM), size)
+	}
+
+	encryptionKey := make([]byte, size)
+	for i := range encryptionKey {
+		encryptionKey[i] = encryptionKeyHost[i] ^ encryptionKeyTPM[i]
+	}
+	return encryptionKey, nil
+}
+
+// Seal seals sensitive data and only allows access if the current platform configuration
+// matches the one the data was sealed on.
+//
+// The data is sealed under the TPM's RSA storage root key against the given PCRs and the
+// result is returned as a marshalled protobuf message that can be passed to Unseal.
+// Returns ErrNotInitialized if the package has not been initialized.
+func Seal(data []byte, pcrs []int) ([]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, ErrNotInitialized
+	}
+	srk, err := tpm2tools.StorageRootKeyRSA(tpm.device)
+	if err != nil {
+		return []byte{}, errors.Wrap(err, "failed to load TPM SRK")
+	}
+	defer srk.Close()
+	sealedKey, err := srk.Seal(pcrs, data)
+	if err != nil {
+		// Previously this error was dropped and the (invalid) result marshalled anyway.
+		return []byte{}, errors.Wrap(err, "failed to seal data")
+	}
+	sealedKeyRaw, err := proto.Marshal(sealedKey)
+	if err != nil {
+		return []byte{}, errors.Wrapf(err, "failed to marshal sealed data")
+	}
+	return sealedKeyRaw, nil
+}
+
+// Unseal unseals sensitive data if the current platform configuration and sealing
+// constraints allow it.
+//
+// data must be a marshalled sealed blob as produced by Seal. The PCRs the blob is bound
+// to are logged before the unseal attempt for auditing purposes.
+// Returns ErrNotInitialized if the package has not been initialized.
+func Unseal(data []byte) ([]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, ErrNotInitialized
+	}
+	srk, err := tpm2tools.StorageRootKeyRSA(tpm.device)
+	if err != nil {
+		return []byte{}, errors.Wrap(err, "failed to load TPM SRK")
+	}
+	defer srk.Close()
+
+	var sealedKey tpmpb.SealedBytes
+	if err := proto.Unmarshal(data, &sealedKey); err != nil {
+		return []byte{}, errors.Wrap(err, "failed to decode sealed data")
+	}
+	// Logging this for auditing purposes
+	pcrList := make([]string, 0, len(sealedKey.Pcrs))
+	for _, pcr := range sealedKey.Pcrs {
+		// Format the PCR index as a decimal number. A plain string(pcr) conversion
+		// would yield the Unicode code point with that value instead.
+		pcrList = append(pcrList, strconv.Itoa(int(pcr)))
+	}
+	tpm.logger.Infof("Attempting to unseal data protected with PCRs %s", strings.Join(pcrList, ","))
+	unsealedData, err := srk.Unseal(&sealedKey)
+	if err != nil {
+		return []byte{}, errors.Wrap(err, "failed to unseal data")
+	}
+	return unsealedData, nil
+}
+
+// Standard AK template for RSA2048 non-duplicatable restricted signing for attestation
+// loadAK creates the AK from this template and MakeAKChallenge rejects any presented
+// key whose public area does not match it.
+var akTemplate = tpm2.Public{
+ Type: tpm2.AlgRSA,
+ NameAlg: tpm2.AlgSHA256,
+ Attributes: tpm2.FlagSignerDefault,
+ RSAParameters: &tpm2.RSAParams{
+ // RSASSA signatures over SHA256 digests.
+ Sign: &tpm2.SigScheme{
+ Alg: tpm2.AlgRSASSA,
+ Hash: tpm2.AlgSHA256,
+ },
+ KeyBits: 2048,
+ },
+}
+
+func loadAK() error {
+ var err error
+ // Rationale: The AK is an EK-equivalent key and used only for attestation. Using a non-primary
+ // key here would require us to store the wrapped version somewhere, which is inconvenient.
+ // This being a primary key in the Endorsement hierarchy means that it can always be recreated
+ // and can never be "destroyed". Under our security model this is of no concern since we identify
+ // a node by its IK (Identity Key) which we can destroy.
+ tpm.akHandleCache, tpm.akPublicKey, err = tpm2.CreatePrimary(tpm.device, tpm2.HandleEndorsement,
+ tpm2.PCRSelection{}, "", "", akTemplate)
+ return err
+}
+
+// loadEK creates the EK as a primary key in the endorsement hierarchy and returns its
+// handle and public key. The process is documented in TCG EK Credential Profile 2.2.1.
+// The caller is responsible for flushing the returned handle.
+func loadEK() (tpmutil.Handle, crypto.PublicKey, error) {
+	// The EK is a primary key which is supposed to be certified by the manufacturer of the TPM.
+	// Its public attributes are standardized in TCG EK Credential Profile 2.0 Table 1. These need
+	// to match exactly or we aren't getting the key the manufacturer signed. tpm2tools contains
+	// such a template already, so we're using that instead of redoing it ourselves.
+	// This ignores the more complicated ways EKs can be specified, the additional stuff you can do
+	// is just absolutely crazy (see 2.2.1.2 onward)
+	ekTemplate := tpm2tools.DefaultEKTemplateRSA()
+	return tpm2.CreatePrimary(tpm.device, tpm2.HandleEndorsement, tpm2.PCRSelection{}, "", "", ekTemplate)
+}
+
+// GetAKPublic gets the TPM2T_PUBLIC of the AK key
+func GetAKPublic() ([]byte, error) {
+ lock.Lock()
+ defer lock.Unlock()
+ if tpm == nil {
+ return []byte{}, ErrNotInitialized
+ }
+ if tpm.akHandleCache == tpmutil.Handle(0) {
+ if err := loadAK(); err != nil {
+ return []byte{}, fmt.Errorf("failed to load AK primary key: %w", err)
+ }
+ }
+ public, _, _, err := tpm2.ReadPublic(tpm.device, tpm.akHandleCache)
+ if err != nil {
+ return []byte{}, err
+ }
+ return public.Encode()
+}
+
+// ekCertHandle is the NV index from which GetEKPublic reads the EK certificate.
+//
+// TCG TPM v2.0 Provisioning Guidance v1.0 7.8 Table 2 and
+// TCG EK Credential Profile v2.1 2.2.1.4 de-facto Standard for Windows
+// These are both non-normative and reference Windows 10 documentation that's no longer available :(
+// But in practice this is what people are using, so if it's normative or not doesn't really matter
+const ekCertHandle = 0x01c00002
+
+// GetEKPublic returns the EK public key in PKIX form together with the raw EK
+// certificate read from the TPM's NV storage. A failure to read the certificate is
+// returned as an error. Returns ErrNotInitialized if the package has not been initialized.
+func GetEKPublic() ([]byte, []byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+
+	// fail builds the error return value shared by all failure paths.
+	fail := func(err error) ([]byte, []byte, error) {
+		return []byte{}, []byte{}, err
+	}
+
+	if tpm == nil {
+		return fail(ErrNotInitialized)
+	}
+
+	handle, ekPublic, err := loadEK()
+	if err != nil {
+		return fail(fmt.Errorf("failed to load EK primary key: %w", err))
+	}
+	// The EK is only needed for the duration of this call.
+	defer tpm2.FlushContext(tpm.device, handle)
+
+	// Don't question the use of HandleOwner, that's the Standard™
+	certificate, err := tpm2.NVReadEx(tpm.device, ekCertHandle, tpm2.HandleOwner, "", 0)
+	if err != nil {
+		return fail(err)
+	}
+
+	encodedKey, err := x509.MarshalPKIXPublicKey(ekPublic)
+	if err != nil {
+		return fail(err)
+	}
+	return encodedKey, certificate, nil
+}
+
+// MakeAKChallenge generates a challenge for TPM residency and attributes of the AK.
+//
+// ekPubKey is the PKIX-encoded EK public key and akPub the encoded public area of the AK
+// being challenged; nonce is the secret to be recovered by the challenged party. The AK
+// must match akTemplate and the EK must be an RSA key, otherwise an error is returned.
+// The two returned byte slices are the results of generateRSA and are meant to be handed
+// to SolveAKChallenge on the challenged node.
+func MakeAKChallenge(ekPubKey, akPub []byte, nonce []byte) ([]byte, []byte, error) {
+	ekPubKeyData, err := x509.ParsePKIXPublicKey(ekPubKey)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to decode EK pubkey: %w", err)
+	}
+	akPubData, err := tpm2.DecodePublic(akPub)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to decode AK public part: %w", err)
+	}
+	// Make sure we're attesting the right attributes (in particular Restricted)
+	if !akPubData.MatchesTemplate(akTemplate) {
+		return []byte{}, []byte{}, errors.New("the key being challenged is not a valid AK")
+	}
+	akName, err := akPubData.Name()
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to derive AK name: %w", err)
+	}
+	// ekPubKey comes from the party being challenged: a non-RSA key must result in an
+	// error instead of a panic from an unchecked type assertion.
+	ekRSAKey, ok := ekPubKeyData.(*rsa.PublicKey)
+	if !ok {
+		return []byte{}, []byte{}, errors.New("failed to decode EK pubkey: not an RSA public key")
+	}
+	return generateRSA(akName.Digest, ekRSAKey, 16, nonce, rand.Reader)
+}
+
+// SolveAKChallenge solves a challenge for TPM residency of the AK.
+//
+// credBlob and secretChallenge are the two values produced by MakeAKChallenge. The AK is
+// loaded if it is not cached yet, the EK is loaded for the duration of the call, and the
+// credential is activated inside a policy session authorized for the endorsement
+// hierarchy. Activation is retried every 100ms for as long as the TPM answers with a
+// retry warning. Returns ErrNotInitialized if the package has not been initialized.
+func SolveAKChallenge(credBlob, secretChallenge []byte) ([]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, ErrNotInitialized
+	}
+	if tpm.akHandleCache == tpmutil.Handle(0) {
+		if err := loadAK(); err != nil {
+			return []byte{}, fmt.Errorf("failed to load AK primary key: %w", err)
+		}
+	}
+
+	ekHandle, _, err := loadEK()
+	if err != nil {
+		return []byte{}, fmt.Errorf("failed to load EK: %w", err)
+	}
+	defer tpm2.FlushContext(tpm.device, ekHandle)
+
+	// This is necessary since the EK requires an endorsement handle policy in its session
+	// For us this is stupid because we keep all hierarchies open anyways since a) we cannot safely
+	// store secrets on the OS side pre-global unlock and b) it makes no sense in this security model
+	// since an uncompromised host OS will not let an untrusted entity attest as itself and a
+	// compromised OS can either not pass PCR policy checks or the game's already over (you
+	// successfully runtime-exploited a production Smalltown Core)
+	endorsementSession, _, err := tpm2.StartAuthSession(
+		tpm.device,
+		tpm2.HandleNull,
+		tpm2.HandleNull,
+		make([]byte, 16),
+		nil,
+		tpm2.SessionPolicy,
+		tpm2.AlgNull,
+		tpm2.AlgSHA256)
+	if err != nil {
+		// A TPM failure (e.g. running out of session slots) is a runtime condition,
+		// not a programming error: report it to the caller instead of panicking.
+		return []byte{}, fmt.Errorf("failed to start endorsement policy session: %w", err)
+	}
+	defer tpm2.FlushContext(tpm.device, endorsementSession)
+
+	_, err = tpm2.PolicySecret(tpm.device, tpm2.HandleEndorsement, tpm2.AuthCommand{Session: tpm2.HandlePasswordSession, Attributes: tpm2.AttrContinueSession}, endorsementSession, nil, nil, nil, 0)
+	if err != nil {
+		return []byte{}, fmt.Errorf("failed to make a policy secret session: %w", err)
+	}
+
+	for {
+		solution, err := tpm2.ActivateCredentialUsingAuth(tpm.device, []tpm2.AuthCommand{
+			{Session: tpm2.HandlePasswordSession, Attributes: tpm2.AttrContinueSession}, // Use standard no-password authentication
+			{Session: endorsementSession, Attributes: tpm2.AttrContinueSession},         // Use a full policy session for the EK
+		}, tpm.akHandleCache, ekHandle, credBlob, secretChallenge)
+		// The TPM asked us to retry: wait a moment and try again.
+		if warn, ok := err.(tpm2.Warning); ok && warn.Code == tpm2.RCRetry {
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+		return solution, err
+	}
+}
+
+// FlushTransientHandles flushes every transient object, loaded session and saved session
+// handle currently known to the TPM. It stops at the first error.
+// Returns ErrNotInitialized if the package has not been initialized.
+func FlushTransientHandles() error {
+	lock.Lock()
+	defer lock.Unlock()
+
+	if tpm == nil {
+		return ErrNotInitialized
+	}
+
+	for _, kind := range []tpm2.HandleType{
+		tpm2.HandleTypeTransient,
+		tpm2.HandleTypeLoadedSession,
+		tpm2.HandleTypeSavedSession,
+	} {
+		// Enumerate the handles of this kind, then flush them one by one.
+		open, listErr := tpm2tools.Handles(tpm.device, kind)
+		if listErr != nil {
+			return listErr
+		}
+		for _, h := range open {
+			if flushErr := tpm2.FlushContext(tpm.device, h); flushErr != nil {
+				return flushErr
+			}
+		}
+	}
+	return nil
+}
+
+// AttestPlatform performs a PCR quote using the AK and returns the quote and its signature
+func AttestPlatform(nonce []byte) ([]byte, []byte, error) {
+ lock.Lock()
+ defer lock.Unlock()
+ if tpm == nil {
+ return []byte{}, []byte{}, ErrNotInitialized
+ }
+ if tpm.akHandleCache == tpmutil.Handle(0) {
+ if err := loadAK(); err != nil {
+ return []byte{}, []byte{}, fmt.Errorf("failed to load AK primary key: %w", err)
+ }
+ }
+ // We only care about SHA256 since SHA1 is weak. This is supported on at least GCE and
+ // Intel / AMD fTPM, which is good enough for now. Alg is null because that would just hash the
+ // nonce, which is dumb.
+ quote, signature, err := tpm2.Quote(tpm.device, tpm.akHandleCache, "", "", nonce, srtmPCRs,
+ tpm2.AlgNull)
+ if err != nil {
+ return []byte{}, []byte{}, fmt.Errorf("failed to quote PCRs: %w", err)
+ }
+ return quote, signature.RSA.Signature, err
+}
+
+// VerifyAttestPlatform verifies a given attestation. You can rely on all data coming back as being
+// from the TPM on which the AK is bound to.
+//
+// akPub is the encoded public area of the AK, quote and signature are the values returned
+// by AttestPlatform and nonce is the value that was passed to it. The signature is checked
+// as RSA PKCS#1 v1.5 over the SHA256 digest of the quote; afterwards the quote is decoded
+// and its magic value, type and nonce are validated. The decoded attestation data is
+// returned on success.
+func VerifyAttestPlatform(nonce, akPub, quote, signature []byte) (*tpm2.AttestationData, error) {
+	hash := crypto.SHA256.New()
+	hash.Write(quote)
+
+	akPubData, err := tpm2.DecodePublic(akPub)
+	if err != nil {
+		return nil, fmt.Errorf("invalid AK: %w", err)
+	}
+	akPublicKey, err := akPubData.Key()
+	if err != nil {
+		return nil, fmt.Errorf("invalid AK: %w", err)
+	}
+	akRSAKey, ok := akPublicKey.(*rsa.PublicKey)
+	if !ok {
+		return nil, errors.New("invalid AK: invalid key type")
+	}
+
+	if err := rsa.VerifyPKCS1v15(akRSAKey, crypto.SHA256, hash.Sum(nil), signature); err != nil {
+		return nil, err
+	}
+
+	quoteData, err := tpm2.DecodeAttestationData(quote)
+	if err != nil {
+		return nil, err
+	}
+	// quoteData.Magic works together with the TPM's Restricted key attribute. If this attribute is set
+	// (which it needs to be for the AK to be considered valid) the TPM will not sign external data
+	// having this prefix with such a key. Only data that originates inside the TPM like quotes and
+	// key certifications can have this prefix and still be signed by a restricted key. This check
+	// is thus vital, otherwise somebody can just feed the TPM an arbitrary attestation to sign with
+	// its AK and this function will happily accept the forged attestation.
+	if quoteData.Magic != tpmGeneratedValue {
+		return nil, errors.New("invalid TPM quote: data marker for internal data not set - forged attestation")
+	}
+	if quoteData.Type != tpm2.TagAttestQuote {
+		return nil, errors.New("invalid TPM quote: not a TPM quote")
+	}
+	if !bytes.Equal(quoteData.ExtraData, nonce) {
+		return nil, errors.New("invalid TPM quote: wrong nonce")
+	}
+
+	return quoteData, nil
+}
+
+// GetPCRs returns all SRTM PCRs in-order.
+//
+// The result has numSRTMPCRs entries, indexed by PCR number, read from the SHA256 bank.
+// An error is returned if a read fails or if some PCR could still not be obtained after
+// numSRTMPCRs read attempts. Returns ErrNotInitialized if the package has not been
+// initialized.
+func GetPCRs() ([][]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return [][]byte{}, ErrNotInitialized
+	}
+	pcrs := make([][]byte, numSRTMPCRs)
+
+	// The TPM can (and most do) return partial results. Let's just retry as many times as we have
+	// PCRs since each read should return at least one PCR.
+	for attempt := 0; attempt < numSRTMPCRs; attempt++ {
+		// Only request the PCRs that have not been read yet.
+		sel := tpm2.PCRSelection{Hash: tpm2.AlgSHA256}
+		for pcrN, pcr := range pcrs {
+			if len(pcr) == 0 {
+				sel.PCRs = append(sel.PCRs, pcrN)
+			}
+		}
+		if len(sel.PCRs) == 0 {
+			break
+		}
+
+		readPCRs, err := tpm2.ReadPCRs(tpm.device, sel)
+		if err != nil {
+			return nil, fmt.Errorf("failed to read PCRs: %w", err)
+		}
+		for pcrN, pcr := range readPCRs {
+			pcrs[pcrN] = pcr
+		}
+	}
+
+	// Never hand out an incomplete set: callers index the result by PCR number and an
+	// empty entry would otherwise go unnoticed.
+	for pcrN, pcr := range pcrs {
+		if len(pcr) == 0 {
+			return nil, fmt.Errorf("failed to read PCRs: PCR %d was not returned by the TPM", pcrN)
+		}
+	}
+
+	return pcrs, nil
+}
+
+// GetMeasurementLog returns the binary log of all data hashed into PCRs. The result can be parsed by eventlog.
+// As this library currently doesn't support extending PCRs it just returns the log as supplied by the EFI interface.
+// The log is always read from the securityfs entry of tpm0, independent of the package's initialization state.
+func GetMeasurementLog() ([]byte, error) {
+ return ioutil.ReadFile("/sys/kernel/security/tpm0/binary_bios_measurements")
+}
diff --git a/metropolis/node/kubernetes/BUILD.bazel b/metropolis/node/kubernetes/BUILD.bazel
new file mode 100644
index 0000000..f1fa849
--- /dev/null
+++ b/metropolis/node/kubernetes/BUILD.bazel
@@ -0,0 +1,54 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "apiserver.go",
+ "controller-manager.go",
+ "csi.go",
+ "kubelet.go",
+ "provisioner.go",
+ "scheduler.go",
+ "service.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes",
+ visibility = ["//metropolis/node:__subpackages__"],
+ deps = [
+ "//metropolis/node:go_default_library",
+ "//metropolis/node/common/fileargs:go_default_library",
+ "//metropolis/node/common/fsquota:go_default_library",
+ "//metropolis/node/common/supervisor:go_default_library",
+ "//metropolis/node/core/localstorage:go_default_library",
+ "//metropolis/node/core/localstorage/declarative:go_default_library",
+ "//metropolis/node/core/logtree:go_default_library",
+ "//metropolis/node/core/network/dns:go_default_library",
+ "//metropolis/node/kubernetes/clusternet:go_default_library",
+ "//metropolis/node/kubernetes/nfproxy:go_default_library",
+ "//metropolis/node/kubernetes/pki:go_default_library",
+ "//metropolis/node/kubernetes/reconciler:go_default_library",
+ "//metropolis/proto/api:go_default_library",
+ "@com_github_container_storage_interface_spec//lib/go/csi:go_default_library",
+ "@io_bazel_rules_go//proto/wkt:wrappers_go_proto",
+ "@io_k8s_api//core/v1:go_default_library",
+ "@io_k8s_api//storage/v1:go_default_library",
+ "@io_k8s_apimachinery//pkg/api/errors:go_default_library",
+ "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library",
+ "@io_k8s_client_go//informers:go_default_library",
+ "@io_k8s_client_go//informers/core/v1:go_default_library",
+ "@io_k8s_client_go//informers/storage/v1:go_default_library",
+ "@io_k8s_client_go//kubernetes:go_default_library",
+ "@io_k8s_client_go//kubernetes/scheme:go_default_library",
+ "@io_k8s_client_go//kubernetes/typed/core/v1:go_default_library",
+ "@io_k8s_client_go//tools/cache:go_default_library",
+ "@io_k8s_client_go//tools/clientcmd:go_default_library",
+ "@io_k8s_client_go//tools/record:go_default_library",
+ "@io_k8s_client_go//tools/reference:go_default_library",
+ "@io_k8s_client_go//util/workqueue:go_default_library",
+ "@io_k8s_kubelet//config/v1beta1:go_default_library",
+ "@io_k8s_kubelet//pkg/apis/pluginregistration/v1:go_default_library",
+ "@org_golang_google_grpc//:go_default_library",
+ "@org_golang_google_grpc//codes:go_default_library",
+ "@org_golang_google_grpc//status:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/metropolis/node/kubernetes/apiserver.go b/metropolis/node/kubernetes/apiserver.go
new file mode 100644
index 0000000..583e268
--- /dev/null
+++ b/metropolis/node/kubernetes/apiserver.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kubernetes
+
+import (
+ "context"
+ "encoding/pem"
+ "fmt"
+ "io"
+ "net"
+ "os/exec"
+
+ common "git.monogon.dev/source/nexantic.git/metropolis/node"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/fileargs"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/pki"
+)
+
+// apiserverService holds the configuration and the loaded PKI material for the kube-apiserver process that is
+// started by Run.
+type apiserverService struct {
+ // KPKI is the Kubernetes PKI that loadPKI reads certificates and keys from.
+ KPKI *pki.KubernetesPKI
+ // AdvertiseAddress is passed to kube-apiserver as --advertise-address.
+ AdvertiseAddress net.IP
+ // ServiceIPRange is passed to kube-apiserver as --service-cluster-ip-range.
+ ServiceIPRange net.IPNet
+ // Output receives stdout and stderr of the kube-apiserver process.
+ Output io.Writer
+ // EphemeralConsensusDirectory provides the client socket path used to build --etcd-servers.
+ EphemeralConsensusDirectory *localstorage.EphemeralConsensusDirectory
+
+ // All PKI-related things are in DER
+ // The fields below are populated by loadPKI and PEM-encoded by Run before being handed to kube-apiserver.
+ idCA []byte
+ kubeletClientCert []byte
+ kubeletClientKey []byte
+ aggregationCA []byte
+ aggregationClientCert []byte
+ aggregationClientKey []byte
+ serviceAccountPrivKey []byte // In PKIX form
+ serverCert []byte
+ serverKey []byte
+}
+
+// loadPKI retrieves all certificates and keys needed by kube-apiserver from s.KPKI and stores them on the
+// service. It stops at the first certificate that cannot be loaded and returns an error naming it.
+func (s *apiserverService) loadPKI(ctx context.Context) error {
+	// pkiTarget maps a named PKI entry onto the service fields that should receive its certificate and key.
+	// A nil key pointer means the key is not needed and is discarded.
+	type pkiTarget struct {
+		name pki.KubeCertificateName
+		cert *[]byte
+		key  *[]byte
+	}
+	targets := []pkiTarget{
+		{name: pki.IdCA, cert: &s.idCA},
+		{name: pki.KubeletClient, cert: &s.kubeletClientCert, key: &s.kubeletClientKey},
+		{name: pki.AggregationCA, cert: &s.aggregationCA},
+		{name: pki.FrontProxyClient, cert: &s.aggregationClientCert, key: &s.aggregationClientKey},
+		{name: pki.APIServer, cert: &s.serverCert, key: &s.serverKey},
+	}
+
+	for _, target := range targets {
+		certBytes, keyBytes, err := s.KPKI.Certificate(ctx, target.name)
+		if err != nil {
+			return fmt.Errorf("could not load certificate %q from PKI: %w", target.name, err)
+		}
+		if target.cert != nil {
+			*target.cert = certBytes
+		}
+		if target.key != nil {
+			*target.key = keyBytes
+		}
+	}
+
+	// The service account key is not a certificate and has its own accessor.
+	saKey, err := s.KPKI.ServiceAccountKey(ctx)
+	s.serviceAccountPrivKey = saKey
+	if err != nil {
+		return fmt.Errorf("could not load serviceaccount privkey: %w", err)
+	}
+	return nil
+}
+
+// Run loads the PKI material, then starts kube-apiserver as a child process and blocks until it exits.
+//
+// Certificates and keys are PEM-encoded and handed to the process via fileargs; args is closed when Run
+// returns. The supervisor is signaled healthy right before the process is started. The returned error is
+// either a setup error (PKI or fileargs) or whatever cmd.Run reports.
+func (s *apiserverService) Run(ctx context.Context) error {
+	if err := s.loadPKI(ctx); err != nil {
+		return fmt.Errorf("loading PKI data failed: %w", err)
+	}
+	args, err := fileargs.New()
+	if err != nil {
+		panic(err) // If this fails, something is very wrong. Just crash.
+	}
+	defer args.Close()
+
+	cmd := exec.CommandContext(ctx, "/kubernetes/bin/kube", "kube-apiserver",
+		fmt.Sprintf("--advertise-address=%v", s.AdvertiseAddress.String()),
+		"--authorization-mode=Node,RBAC",
+		args.FileOpt("--client-ca-file", "client-ca.pem",
+			pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.idCA})),
+		"--enable-admission-plugins=NodeRestriction,PodSecurityPolicy",
+		"--enable-aggregator-routing=true",
+		"--insecure-port=0",
+		fmt.Sprintf("--secure-port=%v", common.KubernetesAPIPort),
+		fmt.Sprintf("--etcd-servers=unix:///%s:0", s.EphemeralConsensusDirectory.ClientSocket.FullPath()),
+		args.FileOpt("--kubelet-client-certificate", "kubelet-client-cert.pem",
+			pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.kubeletClientCert})),
+		args.FileOpt("--kubelet-client-key", "kubelet-client-key.pem",
+			pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: s.kubeletClientKey})),
+		"--kubelet-preferred-address-types=InternalIP",
+		args.FileOpt("--proxy-client-cert-file", "aggregation-client-cert.pem",
+			pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.aggregationClientCert})),
+		args.FileOpt("--proxy-client-key-file", "aggregation-client-key.pem",
+			pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: s.aggregationClientKey})),
+		"--requestheader-allowed-names=front-proxy-client",
+		args.FileOpt("--requestheader-client-ca-file", "aggregation-ca.pem",
+			pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.aggregationCA})),
+		"--requestheader-extra-headers-prefix=X-Remote-Extra-",
+		"--requestheader-group-headers=X-Remote-Group",
+		"--requestheader-username-headers=X-Remote-User",
+		args.FileOpt("--service-account-key-file", "service-account-pubkey.pem",
+			pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: s.serviceAccountPrivKey})),
+		fmt.Sprintf("--service-cluster-ip-range=%v", s.ServiceIPRange.String()),
+		args.FileOpt("--tls-cert-file", "server-cert.pem",
+			pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.serverCert})),
+		args.FileOpt("--tls-private-key-file", "server-key.pem",
+			pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: s.serverKey})),
+	)
+	// FileOpt accumulates its failures on args. Report that error itself: the outer err is still the nil
+	// result of fileargs.New at this point, so returning it would silently report success.
+	if err := args.Error(); err != nil {
+		return fmt.Errorf("failed to use fileargs: %w", err)
+	}
+	cmd.Stdout = s.Output
+	cmd.Stderr = s.Output
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	err = cmd.Run()
+	fmt.Fprintf(s.Output, "apiserver stopped: %v\n", err)
+	return err
+}
diff --git a/metropolis/node/kubernetes/clusternet/BUILD.bazel b/metropolis/node/kubernetes/clusternet/BUILD.bazel
new file mode 100644
index 0000000..9e9cc01
--- /dev/null
+++ b/metropolis/node/kubernetes/clusternet/BUILD.bazel
@@ -0,0 +1,27 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "clusternet.go",
+ "netlink_compat.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/clusternet",
+ visibility = ["//metropolis/node/kubernetes:__subpackages__"],
+ deps = [
+ "//metropolis/node:go_default_library",
+ "//metropolis/node/common/jsonpatch:go_default_library",
+ "//metropolis/node/common/supervisor:go_default_library",
+ "//metropolis/node/core/localstorage:go_default_library",
+ "//metropolis/node/core/logtree:go_default_library",
+ "@com_github_vishvananda_netlink//:go_default_library",
+ "@com_zx2c4_golang_wireguard_wgctrl//:go_default_library",
+ "@com_zx2c4_golang_wireguard_wgctrl//wgtypes:go_default_library",
+ "@io_k8s_api//core/v1:go_default_library",
+ "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library",
+ "@io_k8s_apimachinery//pkg/types:go_default_library",
+ "@io_k8s_client_go//informers:go_default_library",
+ "@io_k8s_client_go//kubernetes:go_default_library",
+ "@io_k8s_client_go//tools/cache:go_default_library",
+ ],
+)
diff --git a/metropolis/node/kubernetes/clusternet/clusternet.go b/metropolis/node/kubernetes/clusternet/clusternet.go
new file mode 100644
index 0000000..d8dc7ad
--- /dev/null
+++ b/metropolis/node/kubernetes/clusternet/clusternet.go
@@ -0,0 +1,276 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package clusternet implements a WireGuard-based overlay network for Kubernetes. It relies on controller-manager's
+// IPAM to assign IP ranges to nodes and on Kubernetes' Node objects to distribute the Node IPs and public keys.
+//
+// It sets up a single WireGuard network interface and routes the entire ClusterCIDR into that network interface,
+// relying on WireGuard's AllowedIPs mechanism to look up the correct peer node to send the traffic to. This means
+// that the routing table doesn't change and doesn't have to be separately managed. When clusternet is started
+// it annotates its WireGuard public key onto its node object.
+// For each node object that's created or updated on the K8s apiserver it checks if a public key annotation is set and
+// if yes a peer with that public key, its InternalIP as endpoint and the CIDR for that node as AllowedIPs is created.
+package clusternet
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net"
+ "os"
+
+ "github.com/vishvananda/netlink"
+ "golang.zx2c4.com/wireguard/wgctrl"
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+ corev1 "k8s.io/api/core/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/types"
+ "k8s.io/client-go/informers"
+ "k8s.io/client-go/kubernetes"
+ "k8s.io/client-go/tools/cache"
+
+ common "git.monogon.dev/source/nexantic.git/metropolis/node"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/jsonpatch"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+)
+
+const (
+ // clusterNetDeviceName is the name of the WireGuard network interface created and configured by Run.
+ clusterNetDeviceName = "clusternet"
+ // publicKeyAnnotation is the Node annotation key under which a node's WireGuard public key is stored.
+ publicKeyAnnotation = "node.smalltown.nexantic.com/wg-pubkey"
+)
+
+// Service is the clusternet runnable. The exported fields must be set before Run is called; the unexported
+// fields are populated by Run.
+type Service struct {
+ // NodeName is the name of the Kubernetes Node object of this node. It is the node that gets annotated with
+ // the public key and the node that is skipped when peers are added or removed.
+ NodeName string
+ // Kubernetes is the client used to patch the Node object.
+ Kubernetes kubernetes.Interface
+ // ClusterNet is the network that is routed into the WireGuard interface.
+ ClusterNet net.IPNet
+ // InformerFactory provides the Node informer whose events drive peer updates.
+ InformerFactory informers.SharedInformerFactory
+ // DataDirectory holds the persisted WireGuard private key.
+ DataDirectory *localstorage.DataKubernetesClusterNetworkingDirectory
+
+ // wgClient is the WireGuard configuration client created in Run.
+ wgClient *wgctrl.Client
+ // privKey is the private key loaded or generated by ensureOnDiskKey.
+ privKey wgtypes.Key
+ // logger is the supervisor logger obtained in Run.
+ logger logtree.LeveledLogger
+}
+
+// ensureNode creates/updates the corresponding WireGuard peer entry for the given node object.
+//
+// Nodes that are this node itself or that carry no public-key annotation are silently skipped (nil is
+// returned). An error is returned if the annotation cannot be parsed, if no usable InternalIP is found, or if
+// configuring the WireGuard device fails.
+func (s *Service) ensureNode(newNode *corev1.Node) error {
+ if newNode.Name == s.NodeName {
+ // Node doesn't need to connect to itself
+ return nil
+ }
+ pubKeyRaw := newNode.Annotations[publicKeyAnnotation]
+ if pubKeyRaw == "" {
+ return nil
+ }
+ pubKey, err := wgtypes.ParseKey(pubKeyRaw)
+ if err != nil {
+ return fmt.Errorf("failed to parse public-key annotation: %w", err)
+ }
+ // Pick the first InternalIP that parses. Once one has been found, any further InternalIP entry only
+ // triggers a warning and ends the search.
+ var internalIP net.IP
+ for _, addr := range newNode.Status.Addresses {
+ if addr.Type == corev1.NodeInternalIP {
+ if internalIP != nil {
+ s.logger.Warningf("More than one NodeInternalIP specified, using the first one")
+ break
+ }
+ internalIP = net.ParseIP(addr.Address)
+ if internalIP == nil {
+ s.logger.Warningf("Failed to parse Internal IP %s", addr.Address)
+ }
+ }
+ }
+ if internalIP == nil {
+ return errors.New("node has no Internal IP")
+ }
+ // AllowedIPs are all parseable pod CIDRs of the node plus the node's InternalIP itself.
+ var allowedIPs []net.IPNet
+ for _, podNetStr := range newNode.Spec.PodCIDRs {
+ _, podNet, err := net.ParseCIDR(podNetStr)
+ if err != nil {
+ s.logger.Warningf("Node %s PodCIDR failed to parse, ignored: %v", newNode.Name, err)
+ continue
+ }
+ allowedIPs = append(allowedIPs, *podNet)
+ }
+ // NOTE(review): the mask is always a 32-bit (IPv4 host) mask; presumably InternalIPs are IPv4 only - confirm.
+ allowedIPs = append(allowedIPs, net.IPNet{IP: internalIP, Mask: net.CIDRMask(32, 32)})
+ s.logger.V(1).Infof("Adding/Updating WireGuard peer node %s, endpoint %s, allowedIPs %+v", newNode.Name, internalIP.String(), allowedIPs)
+ // WireGuard's kernel side has create/update semantics on peers by default. So we can just add the peer multiple
+ // times to update it.
+ err = s.wgClient.ConfigureDevice(clusterNetDeviceName, wgtypes.Config{
+ Peers: []wgtypes.PeerConfig{{
+ PublicKey: pubKey,
+ Endpoint: &net.UDPAddr{Port: common.WireGuardPort, IP: internalIP},
+ ReplaceAllowedIPs: true,
+ AllowedIPs: allowedIPs,
+ }},
+ })
+ if err != nil {
+ return fmt.Errorf("failed to add WireGuard peer node: %w", err)
+ }
+ return nil
+}
+
+// removeNode deletes the WireGuard peer that belongs to the given node object.
+//
+// Nothing is done (and nil is returned) for this node itself and for nodes without a public-key annotation.
+// An error is returned if the annotation cannot be decoded or the WireGuard device rejects the change.
+func (s *Service) removeNode(oldNode *corev1.Node) error {
+	// The local node never has a peer entry for itself.
+	if oldNode.Name == s.NodeName {
+		return nil
+	}
+
+	rawKey := oldNode.Annotations[publicKeyAnnotation]
+	if len(rawKey) == 0 {
+		return nil
+	}
+
+	peerKey, err := wgtypes.ParseKey(rawKey)
+	if err != nil {
+		return fmt.Errorf("node public-key annotation not decodable: %w", err)
+	}
+
+	// Peers are identified by their public key; Remove asks for that peer to be dropped.
+	removal := wgtypes.PeerConfig{
+		PublicKey: peerKey,
+		Remove:    true,
+	}
+	cfg := wgtypes.Config{Peers: []wgtypes.PeerConfig{removal}}
+	if err := s.wgClient.ConfigureDevice(clusterNetDeviceName, cfg); err != nil {
+		return fmt.Errorf("failed to remove WireGuard peer node: %w", err)
+	}
+	return nil
+}
+
+// ensureOnDiskKey makes sure s.privKey holds this node's WireGuard private key.
+//
+// If a key file exists it is read and parsed. If it does not exist, a new key is generated, written to the
+// key file with mode 0600 and then used. Any other read error, a write error or an unparseable key file is
+// returned as an error.
+func (s *Service) ensureOnDiskKey() error {
+	stored, readErr := s.DataDirectory.Key.Read()
+	switch {
+	case os.IsNotExist(readErr):
+		// First start: no key has been persisted yet.
+		fresh, err := wgtypes.GeneratePrivateKey()
+		if err != nil {
+			return fmt.Errorf("failed to generate private key: %w", err)
+		}
+		if err := s.DataDirectory.Key.Write([]byte(fresh.String()), 0600); err != nil {
+			return fmt.Errorf("failed to store newly generated key: %w", err)
+		}
+		s.privKey = fresh
+		return nil
+	case readErr != nil:
+		return fmt.Errorf("failed to load on-disk key: %w", readErr)
+	}
+
+	parsed, err := wgtypes.ParseKey(string(stored))
+	if err != nil {
+		return fmt.Errorf("invalid private key in file: %w", err)
+	}
+	s.privKey = parsed
+	return nil
+}
+
+// annotateThisNode publishes this node's WireGuard public key by patching the public-key annotation onto the
+// Node object named s.NodeName. It returns an error if the patch cannot be encoded or applied.
+func (s *Service) annotateThisNode(ctx context.Context) error {
+	// The annotation key is escaped so that it forms a single JSON reference token inside the patch path.
+	annotationPath := "/metadata/annotations/" + jsonpatch.EncodeJSONRefToken(publicKeyAnnotation)
+	ops := []jsonpatch.JsonPatchOp{
+		{
+			Operation: "add",
+			Path:      annotationPath,
+			Value:     s.privKey.PublicKey().String(),
+		},
+	}
+
+	body, err := json.Marshal(ops)
+	if err != nil {
+		return fmt.Errorf("failed to encode JSONPatch: %w", err)
+	}
+
+	nodes := s.Kubernetes.CoreV1().Nodes()
+	_, err = nodes.Patch(ctx, s.NodeName, types.JSONPatchType, body, metav1.PatchOptions{})
+	if err != nil {
+		return fmt.Errorf("failed to patch resource: %w", err)
+	}
+	return nil
+}
+
+// Run runs the ClusterNet service. See package description for what it does.
+//
+// In order, it: connects to the WireGuard configuration endpoint, loads or generates the private key, creates
+// the WireGuard interface (removed again when Run returns), configures key and listen port, routes ClusterNet
+// into the interface, annotates this node with its public key and finally runs the Node informer until ctx is
+// canceled. Setup failures are returned as errors; failures while handling individual node events are only
+// logged.
+func (s *Service) Run(ctx context.Context) error {
+	logger := supervisor.Logger(ctx)
+	s.logger = logger
+
+	wgClient, err := wgctrl.New()
+	if err != nil {
+		return fmt.Errorf("failed to connect to netlink's WireGuard config endpoint: %w", err)
+	}
+	s.wgClient = wgClient
+
+	if err := s.ensureOnDiskKey(); err != nil {
+		return fmt.Errorf("failed to ensure on-disk key: %w", err)
+	}
+
+	wgInterface := &Wireguard{LinkAttrs: netlink.LinkAttrs{Name: clusterNetDeviceName, Flags: net.FlagUp}}
+	if err := netlink.LinkAdd(wgInterface); err != nil {
+		return fmt.Errorf("failed to add WireGuard network interfacee: %w", err)
+	}
+	defer netlink.LinkDel(wgInterface)
+
+	listenPort := common.WireGuardPort
+	if err := wgClient.ConfigureDevice(clusterNetDeviceName, wgtypes.Config{
+		PrivateKey: &s.privKey,
+		ListenPort: &listenPort,
+	}); err != nil {
+		return fmt.Errorf("failed to set up WireGuard interface: %w", err)
+	}
+
+	// An already existing route is not an error.
+	if err := netlink.RouteAdd(&netlink.Route{
+		Dst:       &s.ClusterNet,
+		LinkIndex: wgInterface.Index,
+	}); err != nil && !os.IsExist(err) {
+		return fmt.Errorf("failed to add cluster net route to Wireguard interface: %w", err)
+	}
+
+	if err := s.annotateThisNode(ctx); err != nil {
+		return fmt.Errorf("when annotating this node with public key: %w", err)
+	}
+
+	// syncNode handles both add and update events: peers have create/update semantics, so both are the same
+	// operation.
+	syncNode := func(obj interface{}) {
+		node, ok := obj.(*corev1.Node)
+		if !ok {
+			logger.Errorf("Received non-node item %+v in node event handler", obj)
+			return
+		}
+		if err := s.ensureNode(node); err != nil {
+			logger.Warningf("Failed to sync node: %v", err)
+		}
+	}
+
+	nodeInformer := s.InformerFactory.Core().V1().Nodes()
+	nodeInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
+		AddFunc: syncNode,
+		UpdateFunc: func(old, new interface{}) {
+			syncNode(new)
+		},
+		DeleteFunc: func(old interface{}) {
+			// When the informer missed the actual delete event it delivers a tombstone that wraps the last
+			// known state of the object. Unwrap it so that the peer still gets removed.
+			if tombstone, ok := old.(cache.DeletedFinalStateUnknown); ok {
+				old = tombstone.Obj
+			}
+			oldNode, ok := old.(*corev1.Node)
+			if !ok {
+				// Log the object that was actually received; oldNode is always nil on this branch.
+				logger.Errorf("Received non-node item %+v in node event handler", old)
+				return
+			}
+			if err := s.removeNode(oldNode); err != nil {
+				logger.Warningf("Failed to sync node: %v", err)
+			}
+		},
+	})
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	nodeInformer.Informer().Run(ctx.Done())
+	return ctx.Err()
+}
diff --git a/metropolis/node/kubernetes/clusternet/netlink_compat.go b/metropolis/node/kubernetes/clusternet/netlink_compat.go
new file mode 100644
index 0000000..a90cc47
--- /dev/null
+++ b/metropolis/node/kubernetes/clusternet/netlink_compat.go
@@ -0,0 +1,33 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Can be removed once https://github.com/vishvananda/netlink/pull/464 lands
+package clusternet
+
+import "github.com/vishvananda/netlink"
+
+// Wireguard represents links of type "wireguard", see https://www.wireguard.com/
+// It embeds netlink.LinkAttrs and provides the Attrs and Type methods so that it can be handed to the netlink
+// package's link functions.
+type Wireguard struct {
+ netlink.LinkAttrs
+}
+
+// Attrs returns a pointer to the embedded link attributes.
+func (wg *Wireguard) Attrs() *netlink.LinkAttrs {
+ return &wg.LinkAttrs
+}
+
+// Type returns the link type name, "wireguard".
+func (wg *Wireguard) Type() string {
+ return "wireguard"
+}
diff --git a/metropolis/node/kubernetes/containerd/BUILD.bazel b/metropolis/node/kubernetes/containerd/BUILD.bazel
new file mode 100644
index 0000000..9e42595
--- /dev/null
+++ b/metropolis/node/kubernetes/containerd/BUILD.bazel
@@ -0,0 +1,20 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "go_default_library",
+ srcs = ["main.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/containerd",
+ visibility = ["//metropolis/node/core:__subpackages__"],
+ deps = [
+ "//metropolis/node/common/supervisor:go_default_library",
+ "//metropolis/node/core/localstorage:go_default_library",
+ "@com_github_containerd_containerd//:go_default_library",
+ "@com_github_containerd_containerd//namespaces:go_default_library",
+ ],
+)
+
+exports_files([
+ "config.toml",
+ "runsc.toml",
+ "cnispec.gojson",
+])
diff --git a/metropolis/node/kubernetes/containerd/cnispec.gojson b/metropolis/node/kubernetes/containerd/cnispec.gojson
new file mode 100644
index 0000000..0057036
--- /dev/null
+++ b/metropolis/node/kubernetes/containerd/cnispec.gojson
@@ -0,0 +1,29 @@
+{{- /*gotype: github.com/containerd/cri/pkg/server.cniConfigTemplate*/ -}}
+{
+ "name": "k8s-pod-network",
+ "cniVersion": "0.3.1",
+ "plugins": [
+ {
+ "type": "ptp",
+ "mtu": 1420,
+ "ipam": {
+ "type": "host-local",
+ "dataDir": "/containerd/run/ipam",
+ "ranges": [
+ {{range $i, $range := .PodCIDRRanges}}{{if $i}},
+ {{end}}[
+ {
+ "subnet": "{{$range}}"
+ }
+ ]
+ {{end}}
+ ],
+ "routes": [
+ {{range $i, $route := .Routes}}{{if $i}},
+ {{end}}{
+ "dst": "{{$route}}"
+}{{end}}]
+}
+}
+]
+}
\ No newline at end of file
diff --git a/metropolis/node/kubernetes/containerd/config.toml b/metropolis/node/kubernetes/containerd/config.toml
new file mode 100644
index 0000000..f8c7fb1
--- /dev/null
+++ b/metropolis/node/kubernetes/containerd/config.toml
@@ -0,0 +1,125 @@
+version = 2
+root = "/data/containerd"
+state = "/ephemeral/containerd"
+plugin_dir = ""
+disabled_plugins = []
+required_plugins = []
+oom_score = 0
+
+[grpc]
+ address = "/ephemeral/containerd/client.sock"
+ tcp_address = ""
+ tcp_tls_cert = ""
+ tcp_tls_key = ""
+ uid = 0
+ gid = 0
+ max_recv_message_size = 16777216
+ max_send_message_size = 16777216
+
+[ttrpc]
+ address = ""
+ uid = 0
+ gid = 0
+
+[debug]
+ address = ""
+ uid = 0
+ gid = 0
+ level = ""
+
+[metrics]
+ address = ""
+ grpc_histogram = false
+
+[cgroup]
+ path = ""
+
+[timeouts]
+ "io.containerd.timeout.shim.cleanup" = "5s"
+ "io.containerd.timeout.shim.load" = "5s"
+ "io.containerd.timeout.shim.shutdown" = "3s"
+ "io.containerd.timeout.task.state" = "2s"
+
+[plugins]
+ [plugins."io.containerd.gc.v1.scheduler"]
+ pause_threshold = 0.02
+ deletion_threshold = 0
+ mutation_threshold = 100
+ schedule_delay = "0s"
+ startup_delay = "100ms"
+ [plugins."io.containerd.grpc.v1.cri"]
+ disable_tcp_service = true
+ stream_server_address = "127.0.0.1"
+ stream_server_port = "0"
+ stream_idle_timeout = "4h0m0s"
+ enable_selinux = false
+ sandbox_image = "k8s.gcr.io/pause:3.1"
+ stats_collect_period = 10
+ systemd_cgroup = false
+ enable_tls_streaming = false
+ ignore_image_defined_volumes = true
+ max_container_log_line_size = 16384
+ disable_cgroup = false
+ disable_apparmor = true
+ restrict_oom_score_adj = false
+ max_concurrent_downloads = 3
+ disable_proc_mount = false
+ [plugins."io.containerd.grpc.v1.cri".containerd]
+ snapshotter = "overlayfs"
+ default_runtime_name = "runsc"
+ no_pivot = false
+ [plugins."io.containerd.grpc.v1.cri".containerd.default_runtime]
+ runtime_type = ""
+ runtime_engine = ""
+ runtime_root = ""
+ privileged_without_host_devices = false
+ [plugins."io.containerd.grpc.v1.cri".containerd.untrusted_workload_runtime]
+ runtime_type = ""
+ runtime_engine = ""
+ runtime_root = ""
+ privileged_without_host_devices = false
+ [plugins."io.containerd.grpc.v1.cri".containerd.runtimes]
+ [plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc]
+ runtime_type = "io.containerd.runsc.v1"
+ runtime_engine = ""
+ runtime_root = ""
+ privileged_without_host_devices = false
+ [plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc.options]
+ TypeUrl = "io.containerd.runsc.v1.options"
+ ConfigPath = "/containerd/conf/runsc.toml"
+ [plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runc]
+ runtime_type = "io.containerd.runc.v2"
+ runtime_engine = ""
+ runtime_root = ""
+ privileged_without_host_devices = false
+ base_runtime_spec = ""
+ [plugins."io.containerd.grpc.v1.cri".cni]
+ bin_dir = "/containerd/bin/cni"
+ conf_dir = "/containerd/conf/cni"
+ max_conf_num = 0
+ conf_template = "/containerd/conf/cnispec.gojson"
+ [plugins."io.containerd.grpc.v1.cri".registry]
+ [plugins."io.containerd.grpc.v1.cri".registry.mirrors]
+ [plugins."io.containerd.grpc.v1.cri".registry.mirrors."docker.io"]
+ endpoint = ["https://registry-1.docker.io"]
+ [plugins."io.containerd.grpc.v1.cri".x509_key_pair_streaming]
+ tls_cert_file = ""
+ tls_key_file = ""
+ [plugins."io.containerd.internal.v1.opt"]
+ path = "/containerd/bin"
+ [plugins."io.containerd.internal.v1.restart"]
+ interval = "10s"
+ [plugins."io.containerd.metadata.v1.bolt"]
+ content_sharing_policy = "shared"
+ [plugins."io.containerd.monitor.v1.cgroups"]
+ no_prometheus = false
+ [plugins."io.containerd.runtime.v1.linux"]
+ shim = "containerd-shim"
+ runtime = "noop"
+ runtime_root = ""
+ no_shim = false
+ shim_debug = false
+ [plugins."io.containerd.runtime.v2.task"]
+ platforms = ["linux/amd64"]
+ [plugins."io.containerd.service.v1.diff-service"]
+ default = ["walking"]
\ No newline at end of file
diff --git a/metropolis/node/kubernetes/containerd/main.go b/metropolis/node/kubernetes/containerd/main.go
new file mode 100644
index 0000000..366f902
--- /dev/null
+++ b/metropolis/node/kubernetes/containerd/main.go
@@ -0,0 +1,146 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package containerd
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "time"
+
+ ctr "github.com/containerd/containerd"
+ "github.com/containerd/containerd/namespaces"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+)
+
+const (
+ // preseedNamespacesDir is the directory that runPreseed scans. It contains one subdirectory per containerd
+ // namespace, each holding the image bundles to import into that namespace.
+ preseedNamespacesDir = "/containerd/preseed/"
+)
+
+// Service is the containerd runnable. It runs containerd together with the runsc log pump and the image
+// preseed runnable.
+type Service struct {
+ // EphemeralVolume provides the temporary directory, the runsc log FIFO and the client socket path.
+ EphemeralVolume *localstorage.EphemeralContainerdDirectory
+}
+
+// Run opens the runsc log FIFO, starts the "runsc" log pump and the "preseed" runnable as supervisor children
+// and then runs containerd itself via supervisor.RunCommand. It returns an error if the FIFO cannot be opened
+// or a child runnable cannot be started; otherwise it returns whatever supervisor.RunCommand returns.
+func (s *Service) Run(ctx context.Context) error {
+ cmd := exec.CommandContext(ctx, "/containerd/bin/containerd", "--config", "/containerd/conf/config.toml")
+ // containerd gets a minimal environment: only PATH and TMPDIR are set.
+ cmd.Env = []string{"PATH=/containerd/bin", "TMPDIR=" + s.EphemeralVolume.Tmp.FullPath()}
+
+ // NOTE(review): runscFifo is handed to the log pump and is not closed in this function - confirm that this
+ // is intended if Run can be restarted.
+ runscFifo, err := os.OpenFile(s.EphemeralVolume.RunSCLogsFIFO.FullPath(), os.O_CREATE|os.O_RDONLY, os.ModeNamedPipe|0777)
+ if err != nil {
+ return err
+ }
+
+ if err := supervisor.Run(ctx, "runsc", s.logPump(runscFifo)); err != nil {
+ return fmt.Errorf("failed to start runsc log pump: %w", err)
+ }
+
+ if err := supervisor.Run(ctx, "preseed", s.runPreseed); err != nil {
+ return fmt.Errorf("failed to start preseed runnable: %w", err)
+ }
+ return supervisor.RunCommand(ctx, cmd)
+}
+
+// logPump builds a runnable that continuously copies everything readable from the given file/FIFO into the
+// runnable's raw logger. The runnable signals healthy immediately, returns ctx.Err() once the context is done
+// and returns an error if copying fails.
+// TODO(q3k): refactor this out to a generic function in supervisor or logtree.
+func (s *Service) logPump(fifo *os.File) supervisor.Runnable {
+	return func(ctx context.Context) error {
+		supervisor.Signal(ctx, supervisor.SignalHealthy)
+		for {
+			// Stop as soon as the context has been canceled.
+			if err := ctx.Err(); err != nil {
+				return err
+			}
+
+			copied, err := io.Copy(supervisor.RawLogger(ctx), fifo)
+			switch {
+			case err != nil:
+				return fmt.Errorf("log pump failed: %v", err)
+			case copied == 0:
+				// Pipes/FIFOs can report an empty read while nobody is writing. Back off briefly instead of
+				// spinning. No data is lost by this: a full FIFO buffer stalls the writer, and a stall of at
+				// most 10ms is acceptable for debug logs.
+				time.Sleep(10 * time.Millisecond)
+			}
+		}
+	}
+}
+
+// runPreseed loads OCI bundles in tar form from preseedNamespacesDir into containerd at startup.
+// This can be run multiple times, containerd will automatically dedup the layers.
+// containerd uses namespaces to keep images (and everything else) separate so to define where the images will be loaded
+// to they need to be in a folder named after the namespace they should be loaded into.
+// containerd's CRI plugin (which is built as part of containerd) uses a hardcoded namespace ("k8s.io") for everything
+// accessed through CRI, so if an image should be available on K8s it needs to be in that namespace.
+// As an example if image helloworld should be loaded for use with Kubernetes, the OCI bundle needs to be at
+// <preseedNamespacesDir>/k8s.io/helloworld.tar. No tagging beyond what's in the bundle is performed.
+//
+// Entries of the wrong kind (files at namespace level, directories at image level) are skipped with a warning.
+// The first directory, file or import error aborts the run and is returned. On success the runnable signals
+// healthy and done.
+func (s *Service) runPreseed(ctx context.Context) error {
+	client, err := ctr.New(s.EphemeralVolume.ClientSocket.FullPath())
+	if err != nil {
+		return fmt.Errorf("failed to connect to containerd: %w", err)
+	}
+	// The client is only needed for the imports below; release its connection when done.
+	defer client.Close()
+
+	logger := supervisor.Logger(ctx)
+	preseedNamespaceDirs, err := ioutil.ReadDir(preseedNamespacesDir)
+	if err != nil {
+		return fmt.Errorf("failed to open preseed dir: %w", err)
+	}
+	for _, dir := range preseedNamespaceDirs {
+		if !dir.IsDir() {
+			logger.Warningf("Non-Directory %q found in preseed folder, ignoring", dir.Name())
+			continue
+		}
+		namespace := dir.Name()
+		images, err := ioutil.ReadDir(filepath.Join(preseedNamespacesDir, namespace))
+		if err != nil {
+			return fmt.Errorf("failed to list namespace preseed directory for ns \"%v\": %w", namespace, err)
+		}
+		ctxWithNS := namespaces.WithNamespace(ctx, namespace)
+		for _, image := range images {
+			if image.IsDir() {
+				logger.Warningf("Directory %q found in preseed namespaced folder, ignoring", image.Name())
+				continue
+			}
+			imageFile, err := os.Open(filepath.Join(preseedNamespacesDir, namespace, image.Name()))
+			if err != nil {
+				return fmt.Errorf("failed to open preseed image \"%v\": %w", image.Name(), err)
+			}
+			importedImages, err := client.Import(ctxWithNS, imageFile)
+			// Close right after the import instead of deferring, so that only one image file is open at a
+			// time. The file was only read, so a close error carries no information and is ignored.
+			imageFile.Close()
+			if err != nil {
+				return fmt.Errorf("failed to import preseed image: %w", err)
+			}
+			var importedImageNames []string
+			for _, img := range importedImages {
+				importedImageNames = append(importedImageNames, img.Name)
+			}
+			logger.Infof("Successfully imported preseeded bundle %s/%s into containerd", namespace, strings.Join(importedImageNames, ","))
+		}
+	}
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	supervisor.Signal(ctx, supervisor.SignalDone)
+	return nil
+}
diff --git a/metropolis/node/kubernetes/containerd/runsc.toml b/metropolis/node/kubernetes/containerd/runsc.toml
new file mode 100644
index 0000000..4fe0751
--- /dev/null
+++ b/metropolis/node/kubernetes/containerd/runsc.toml
@@ -0,0 +1,6 @@
+root = "/ephemeral/containerd/runsc"
+[runsc_config]
+debug = "true"
+debug-log = "/ephemeral/containerd/runsc-logs.fifo"
+panic-log = "/ephemeral/containerd/runsc-logs.fifo"
+log = "/ephemeral/containerd/runsc-logs.fifo"
diff --git a/metropolis/node/kubernetes/controller-manager.go b/metropolis/node/kubernetes/controller-manager.go
new file mode 100644
index 0000000..487511f
--- /dev/null
+++ b/metropolis/node/kubernetes/controller-manager.go
@@ -0,0 +1,93 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kubernetes
+
+import (
+ "context"
+ "encoding/pem"
+ "fmt"
+ "net"
+ "os/exec"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/fileargs"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/pki"
+)
+
+// controllerManagerConfig holds everything runControllerManager needs to start kube-controller-manager.
+// getPKIControllerManagerConfig fills in the PKI-related fields; it does not set clusterNet.
+type controllerManagerConfig struct {
+ // clusterNet is passed to kube-controller-manager as --cluster-cidr.
+ clusterNet net.IPNet
+ // All PKI-related things are in DER
+ kubeConfig []byte
+ rootCA []byte
+ serviceAccountPrivKey []byte // In PKCS#8 form
+ serverCert []byte
+ serverKey []byte
+}
+
+func getPKIControllerManagerConfig(ctx context.Context, kpki *pki.KubernetesPKI) (*controllerManagerConfig, error) {
+ var config controllerManagerConfig
+ var err error
+ config.rootCA, _, err = kpki.Certificate(ctx, pki.IdCA)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get ID root CA: %w", err)
+ }
+ config.serverCert, config.serverKey, err = kpki.Certificate(ctx, pki.ControllerManager)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get controller-manager serving certificate: %w", err)
+ }
+ config.serviceAccountPrivKey, err = kpki.ServiceAccountKey(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get serviceaccount privkey: %w", err)
+ }
+ config.kubeConfig, err = kpki.Kubeconfig(ctx, pki.ControllerManagerClient)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get controller-manager kubeconfig: %w", err)
+ }
+ return &config, nil
+}
+
+// runControllerManager returns a runnable that starts kube-controller-manager with the given configuration
+// and runs it via supervisor.RunCommand.
+//
+// The kubeconfig, certificates and keys are handed to the process via fileargs (PEM-encoded where applicable);
+// args is closed when the runnable returns. The runnable returns an error if the fileargs could not be set up,
+// otherwise whatever supervisor.RunCommand returns.
+func runControllerManager(config controllerManagerConfig) supervisor.Runnable {
+	return func(ctx context.Context) error {
+		args, err := fileargs.New()
+		if err != nil {
+			panic(err) // If this fails, something is very wrong. Just crash.
+		}
+		defer args.Close()
+
+		cmd := exec.CommandContext(ctx, "/kubernetes/bin/kube", "kube-controller-manager",
+			args.FileOpt("--kubeconfig", "kubeconfig", config.kubeConfig),
+			args.FileOpt("--service-account-private-key-file", "service-account-privkey.pem",
+				pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: config.serviceAccountPrivKey})),
+			args.FileOpt("--root-ca-file", "root-ca.pem",
+				pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: config.rootCA})),
+			"--port=0",                               // Kill insecure serving
+			"--use-service-account-credentials=true", // Enables things like PSP enforcement
+			args.FileOpt("--tls-cert-file", "server-cert.pem",
+				pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: config.serverCert})),
+			args.FileOpt("--tls-private-key-file", "server-key.pem",
+				pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: config.serverKey})),
+			// Node pod CIDRs are allocated out of the cluster network. --cluster-cidr used to be passed twice
+			// with the same value; once is enough.
+			"--allocate-node-cidrs",
+			"--cluster-cidr="+config.clusterNet.String(),
+		)
+
+		// FileOpt accumulates its failures on args. Wrap that error itself: the outer err is still the nil
+		// result of fileargs.New at this point and would be formatted as "%!w(<nil>)".
+		if err := args.Error(); err != nil {
+			return fmt.Errorf("failed to use fileargs: %w", err)
+		}
+		return supervisor.RunCommand(ctx, cmd)
+	}
+}
diff --git a/metropolis/node/kubernetes/csi.go b/metropolis/node/kubernetes/csi.go
new file mode 100644
index 0000000..4b44a1a
--- /dev/null
+++ b/metropolis/node/kubernetes/csi.go
@@ -0,0 +1,246 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kubernetes
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "os"
+ "path/filepath"
+ "regexp"
+
+ "github.com/container-storage-interface/spec/lib/go/csi"
+ "github.com/golang/protobuf/ptypes/wrappers"
+ "golang.org/x/sys/unix"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+ pluginregistration "k8s.io/kubelet/pkg/apis/pluginregistration/v1"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/fsquota"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+)
+
+// Derived from K8s spec for acceptable names, but shortened to 130 characters to avoid issues with
+// maximum path length. We don't provision longer names so this applies only if you manually create
+// a volume with a name of more than 130 characters.
+// An accepted name starts with a lowercase letter, continues with up to 128 lowercase letters, digits, dots or
+// dashes, and ends with a lowercase letter or digit.
+var acceptableNames = regexp.MustCompile("^[a-z][a-z0-9.-]{0,128}[a-z0-9]$")
+
+const volumeDir = "volumes"
+
+// csiPluginServer implements the CSI Identity and Node services as well as the kubelet plugin registration
+// service; Run registers it with all three.
+type csiPluginServer struct {
+ // KubeletDirectory provides the paths of the CSI plugin socket and of the plugin registration socket.
+ KubeletDirectory *localstorage.DataKubernetesKubeletDirectory
+ // VolumesDirectory is the directory under which NodePublishVolume looks up a volume by its volume ID.
+ VolumesDirectory *localstorage.DataVolumesDirectory
+
+ // logger is set by Run from the supervisor context.
+ logger logtree.LeveledLogger
+}
+
+// Run starts the CSI plugin. It listens on the CSI plugin socket and on the kubelet plugin registration socket
+// (both unix sockets whose paths come from KubeletDirectory), hands one gRPC server per socket to supervisor.Run
+// under the names "csi-node" and "registration", and then signals healthy and done.
+func (s *csiPluginServer) Run(ctx context.Context) error {
+	s.logger = supervisor.Logger(ctx)
+
+	csiAddr := net.UnixAddr{Name: s.KubeletDirectory.Plugins.VFS.FullPath(), Net: "unix"}
+	csiListener, err := net.ListenUnix("unix", &csiAddr)
+	if err != nil {
+		return fmt.Errorf("failed to listen on CSI socket: %w", err)
+	}
+	csiListener.SetUnlinkOnClose(true)
+
+	csiServer := grpc.NewServer()
+	csi.RegisterIdentityServer(csiServer, s)
+	csi.RegisterNodeServer(csiServer, s)
+	// Enable graceful shutdown since we don't have long-running RPCs and most of them shouldn't and can't be
+	// cancelled anyways.
+	err = supervisor.Run(ctx, "csi-node", supervisor.GRPCServer(csiServer, csiListener, true))
+	if err != nil {
+		return err
+	}
+
+	regAddr := net.UnixAddr{Name: s.KubeletDirectory.PluginsRegistry.VFSReg.FullPath(), Net: "unix"}
+	regListener, err := net.ListenUnix("unix", &regAddr)
+	if err != nil {
+		return fmt.Errorf("failed to listen on CSI registration socket: %w", err)
+	}
+	regListener.SetUnlinkOnClose(true)
+
+	regServer := grpc.NewServer()
+	pluginregistration.RegisterRegistrationServer(regServer, s)
+	err = supervisor.Run(ctx, "registration", supervisor.GRPCServer(regServer, regListener, true))
+	if err != nil {
+		return err
+	}
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	supervisor.Signal(ctx, supervisor.SignalDone)
+	return nil
+}
+
+// NodeStageVolume is not supported by this plugin and always returns an Unimplemented status.
+func (*csiPluginServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) {
+	err := status.Error(codes.Unimplemented, "method NodeStageVolume not supported")
+	return nil, err
+}
+
+// NodeUnstageVolume is not supported by this plugin and always returns an Unimplemented status.
+func (*csiPluginServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) {
+	err := status.Error(codes.Unimplemented, "method NodeUnstageVolume not supported")
+	return nil, err
+}
+
+// NodePublishVolume bind-mounts the volume named by req.VolumeId (looked up under VolumesDirectory) onto
+// req.TargetPath, and remounts it read-only if req.Readonly is set.
+//
+// It returns InvalidArgument for a volume ID that does not match acceptableNames, or for an access mode or access
+// type other than single-node writer/reader and mount; NotFound if the bind mount fails with ENOENT; and
+// Unavailable for any other mount failure.
+func (s *csiPluginServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) {
+	if !acceptableNames.MatchString(req.VolumeId) {
+		return nil, status.Error(codes.InvalidArgument, "invalid characters in volume id")
+	}
+
+	// TODO(q3k): move this logic to localstorage?
+	volumePath := filepath.Join(s.VolumesDirectory.FullPath(), req.VolumeId)
+
+	// The request comes from outside this process and may omit VolumeCapability or its AccessMode. The generated
+	// getters are nil-safe and yield zero values in that case, which fall through to the InvalidArgument branches
+	// below instead of dereferencing a nil pointer.
+	capability := req.GetVolumeCapability()
+	switch capability.GetAccessMode().GetMode() {
+	case csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER:
+	case csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY:
+	default:
+		return nil, status.Error(codes.InvalidArgument, "unsupported access mode")
+	}
+	switch capability.GetAccessType().(type) {
+	case *csi.VolumeCapability_Mount:
+	default:
+		return nil, status.Error(codes.InvalidArgument, "unsupported access type")
+	}
+
+	err := unix.Mount(volumePath, req.TargetPath, "", unix.MS_BIND, "")
+	switch {
+	case err == unix.ENOENT:
+		return nil, status.Error(codes.NotFound, "volume not found")
+	case err != nil:
+		return nil, status.Errorf(codes.Unavailable, "failed to bind-mount volume: %v", err)
+	}
+
+	if req.Readonly {
+		// The read-only flag is applied by remounting the bind mount that was just created.
+		err := unix.Mount(volumePath, req.TargetPath, "", unix.MS_BIND|unix.MS_REMOUNT|unix.MS_RDONLY, "")
+		if err != nil {
+			_ = unix.Unmount(req.TargetPath, 0) // Best-effort
+			return nil, status.Errorf(codes.Unavailable, "failed to remount volume: %v", err)
+		}
+	}
+	return &csi.NodePublishVolumeResponse{}, nil
+}
+
+// NodeUnpublishVolume unmounts req.TargetPath. An unmount failure is reported as Unavailable.
+func (*csiPluginServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) {
+	unmountErr := unix.Unmount(req.TargetPath, 0)
+	if unmountErr == nil {
+		return &csi.NodeUnpublishVolumeResponse{}, nil
+	}
+	return nil, status.Errorf(codes.Unavailable, "failed to unmount volume: %v", unmountErr)
+}
+
+// NodeGetVolumeStats reports byte and inode usage of the volume at req.VolumePath, as read from its filesystem
+// quota. It returns NotFound if the path does not exist and Unavailable if the quota cannot be read otherwise.
+func (*csiPluginServer) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) {
+	quota, err := fsquota.GetQuota(req.VolumePath)
+	switch {
+	case os.IsNotExist(err):
+		return nil, status.Error(codes.NotFound, "volume does not exist at this path")
+	case err != nil:
+		return nil, status.Errorf(codes.Unavailable, "failed to get quota: %v", err)
+	}
+
+	byteUsage := &csi.VolumeUsage{
+		Unit:      csi.VolumeUsage_BYTES,
+		Total:     int64(quota.Bytes),
+		Used:      int64(quota.BytesUsed),
+		Available: int64(quota.Bytes - quota.BytesUsed),
+	}
+	inodeUsage := &csi.VolumeUsage{
+		Unit:      csi.VolumeUsage_INODES,
+		Total:     int64(quota.Inodes),
+		Used:      int64(quota.InodesUsed),
+		Available: int64(quota.Inodes - quota.InodesUsed),
+	}
+	return &csi.NodeGetVolumeStatsResponse{
+		Usage: []*csi.VolumeUsage{byteUsage, inodeUsage},
+	}, nil
+}
+
+// NodeExpandVolume sets the quota of the volume at req.VolumePath to the requested byte limit and reports that
+// limit back as the new capacity. A missing capacity range or a limit at or below zero is rejected with
+// InvalidArgument; a failure to update the quota is reported as Unavailable.
+func (*csiPluginServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) {
+	// The generated getters are nil-safe: a request without a CapacityRange yields a limit of zero and is rejected
+	// below instead of dereferencing a nil pointer.
+	limitBytes := req.GetCapacityRange().GetLimitBytes()
+	if limitBytes <= 0 {
+		return nil, status.Error(codes.InvalidArgument, "invalid expanded volume size: at or below zero bytes")
+	}
+	if err := fsquota.SetQuota(req.VolumePath, uint64(limitBytes), 0); err != nil {
+		return nil, status.Errorf(codes.Unavailable, "failed to update quota: %v", err)
+	}
+	return &csi.NodeExpandVolumeResponse{CapacityBytes: limitBytes}, nil
+}
+
+// rpcCapability wraps a single RPC capability type into the NodeServiceCapability message used in
+// NodeGetCapabilities responses.
+func rpcCapability(cap csi.NodeServiceCapability_RPC_Type) *csi.NodeServiceCapability {
+	rpc := &csi.NodeServiceCapability_RPC{Type: cap}
+	wrapped := &csi.NodeServiceCapability_Rpc{Rpc: rpc}
+	return &csi.NodeServiceCapability{Type: wrapped}
+}
+
+// NodeGetCapabilities advertises the node RPC capabilities of this plugin: volume expansion and volume stats.
+func (*csiPluginServer) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) {
+	supported := []*csi.NodeServiceCapability{
+		rpcCapability(csi.NodeServiceCapability_RPC_EXPAND_VOLUME),
+		rpcCapability(csi.NodeServiceCapability_RPC_GET_VOLUME_STATS),
+	}
+	return &csi.NodeGetCapabilitiesResponse{Capabilities: supported}, nil
+}
+
+// NodeGetInfo reports the OS hostname as the CSI node ID. A failure to read the hostname is reported as
+// Unavailable.
+func (*csiPluginServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) {
+	nodeID, err := os.Hostname()
+	if err == nil {
+		return &csi.NodeGetInfoResponse{NodeId: nodeID}, nil
+	}
+	return nil, status.Errorf(codes.Unavailable, "failed to get node identity: %v", err)
+}
+
+// CSI Identity endpoints
+
+// GetPluginInfo returns the fixed name and vendor version of this CSI plugin.
+func (*csiPluginServer) GetPluginInfo(ctx context.Context, req *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) {
+	info := csi.GetPluginInfoResponse{
+		Name:          "com.smalltown.vfs",
+		VendorVersion: "0.0.1", // TODO(lorenz): Maybe stamp?
+	}
+	return &info, nil
+}
+
+// GetPluginCapabilities advertises the single plugin-level capability of this driver: online volume expansion.
+func (*csiPluginServer) GetPluginCapabilities(ctx context.Context, req *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) {
+	onlineExpansion := &csi.PluginCapability{
+		Type: &csi.PluginCapability_VolumeExpansion_{
+			VolumeExpansion: &csi.PluginCapability_VolumeExpansion{
+				Type: csi.PluginCapability_VolumeExpansion_ONLINE,
+			},
+		},
+	}
+	return &csi.GetPluginCapabilitiesResponse{
+		Capabilities: []*csi.PluginCapability{onlineExpansion},
+	}, nil
+}
+
+// Probe unconditionally reports the plugin as ready.
+func (s *csiPluginServer) Probe(ctx context.Context, req *csi.ProbeRequest) (*csi.ProbeResponse, error) {
+	ready := &wrappers.BoolValue{Value: true}
+	return &csi.ProbeResponse{Ready: ready}, nil
+}
+
+// Registration endpoints
+
+// GetInfo describes this plugin to the kubelet plugin registration mechanism: its type, name, the path of the CSI
+// socket it serves on, and the supported CSI versions.
+func (s *csiPluginServer) GetInfo(ctx context.Context, req *pluginregistration.InfoRequest) (*pluginregistration.PluginInfo, error) {
+	info := pluginregistration.PluginInfo{
+		Type:              "CSIPlugin",
+		Name:              "com.smalltown.vfs",
+		Endpoint:          s.KubeletDirectory.Plugins.VFS.FullPath(),
+		SupportedVersions: []string{"1.2"}, // Keep in sync with container-storage-interface/spec package version
+	}
+	return &info, nil
+}
+
+// NotifyRegistrationStatus receives the outcome of the plugin registration. A reported error is only logged as a
+// warning; the call itself always succeeds.
+func (s *csiPluginServer) NotifyRegistrationStatus(ctx context.Context, req *pluginregistration.RegistrationStatus) (*pluginregistration.RegistrationStatusResponse, error) {
+	if registrationErr := req.Error; registrationErr != "" {
+		s.logger.Warningf("Kubelet failed registering CSI plugin: %v", registrationErr)
+	}
+	response := &pluginregistration.RegistrationStatusResponse{}
+	return response, nil
+}
diff --git a/metropolis/node/kubernetes/hyperkube/BUILD b/metropolis/node/kubernetes/hyperkube/BUILD
new file mode 100644
index 0000000..dced1c7
--- /dev/null
+++ b/metropolis/node/kubernetes/hyperkube/BUILD
@@ -0,0 +1,29 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
+load("@//third_party/go:kubernetes_version_def.bzl", "version_x_defs")
+
+# Go library built from main.go. It is private to this package and is embedded by the hyperkube binary target
+# declared in this file. Its deps are the Kubernetes component command packages and the CLI/logging/metrics
+# helpers that main.go imports.
+go_library(
+ name = "go_default_library",
+ srcs = ["main.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/hyperkube",
+ visibility = ["//visibility:private"],
+ deps = [
+ "@com_github_spf13_cobra//:go_default_library",
+ "@com_github_spf13_pflag//:go_default_library",
+ "@io_k8s_component_base//cli/flag:go_default_library",
+ "@io_k8s_component_base//logs:go_default_library",
+ "@io_k8s_component_base//metrics/prometheus/restclient:go_default_library",
+ "@io_k8s_component_base//metrics/prometheus/version:go_default_library",
+ "@io_k8s_kubernetes//cmd/kube-apiserver/app:go_default_library",
+ "@io_k8s_kubernetes//cmd/kube-controller-manager/app:go_default_library",
+ "@io_k8s_kubernetes//cmd/kube-scheduler/app:go_default_library",
+ "@io_k8s_kubernetes//cmd/kubelet/app:go_default_library",
+ ],
+)
+
+# The combined Kubernetes binary, built by embedding the library above. Publicly visible so that targets in other
+# packages can depend on it.
+go_binary(
+ name = "hyperkube",
+ embed = [":go_default_library"],
+ # rules_go pure mode: build with cgo disabled.
+ pure = "on",
+ visibility = ["//visibility:public"],
+ # Link-time string definitions produced by version_x_defs() from kubernetes_version_def.bzl; presumably the
+ # Kubernetes version information (confirm against that file).
+ x_defs = version_x_defs(),
+)
diff --git a/metropolis/node/kubernetes/hyperkube/main.go b/metropolis/node/kubernetes/hyperkube/main.go
new file mode 100644
index 0000000..3b4ac08
--- /dev/null
+++ b/metropolis/node/kubernetes/hyperkube/main.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+Copyright 2014 The Kubernetes Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Adapted from https://github.com/dims/hyperkube
+
+package main
+
+import (
+ goflag "flag"
+ "math/rand"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/spf13/cobra"
+ "github.com/spf13/pflag"
+
+ cliflag "k8s.io/component-base/cli/flag"
+ "k8s.io/component-base/logs"
+ _ "k8s.io/component-base/metrics/prometheus/restclient" // for client metric registration
+ _ "k8s.io/component-base/metrics/prometheus/version" // for version metric registration
+ kubeapiserver "k8s.io/kubernetes/cmd/kube-apiserver/app"
+ kubecontrollermanager "k8s.io/kubernetes/cmd/kube-controller-manager/app"
+ kubescheduler "k8s.io/kubernetes/cmd/kube-scheduler/app"
+ kubelet "k8s.io/kubernetes/cmd/kubelet/app"
+)
+
+// main selects the command to run from the basename of argv[0] (falling back to the combined hyperkube command)
+// and executes it. The process exits with status 1 if the command returns an error.
+func main() {
+	rand.Seed(time.Now().UnixNano())
+
+	hyperkubeCommand, allCommandFns := NewHyperKubeCommand()
+
+	// TODO: once we switch everything over to Cobra commands, we can go back to calling
+	// cliflag.InitFlags() (by removing its pflag.Parse() call). For now, we have to set the
+	// normalize func and add the go flag set by hand.
+	pflag.CommandLine.SetNormalizeFunc(cliflag.WordSepNormalizeFunc)
+	pflag.CommandLine.AddGoFlagSet(goflag.CommandLine)
+	// cliflag.InitFlags()
+	logs.InitLogs()
+	defer logs.FlushLogs()
+
+	basename := filepath.Base(os.Args[0])
+	if err := commandFor(basename, hyperkubeCommand, allCommandFns).Execute(); err != nil {
+		// os.Exit terminates the process without running deferred functions, so the logs have to be flushed
+		// explicitly on this path.
+		logs.FlushLogs()
+		os.Exit(1)
+	}
+}
+
+// commandFor picks the command to execute for the given binary basename. Each constructor in commands is invoked
+// in order and the first resulting command whose name or one of whose aliases equals basename is returned. If
+// none matches, defaultCommand is returned.
+func commandFor(basename string, defaultCommand *cobra.Command, commands []func() *cobra.Command) *cobra.Command {
+	// answersTo reports whether the candidate is known under the requested basename.
+	answersTo := func(candidate *cobra.Command) bool {
+		if candidate.Name() == basename {
+			return true
+		}
+		for i := range candidate.Aliases {
+			if candidate.Aliases[i] == basename {
+				return true
+			}
+		}
+		return false
+	}
+
+	for _, construct := range commands {
+		if candidate := construct(); answersTo(candidate) {
+			return candidate
+		}
+	}
+	return defaultCommand
+}
+
+// NewHyperKubeCommand is the entry point for hyperkube. It returns the combined "kube" command, which has the
+// apiserver, controller-manager, scheduler and kubelet commands attached as subcommands, together with the
+// constructors of those commands so that callers can build any of them as a standalone top-level command.
+func NewHyperKubeCommand() (*cobra.Command, []func() *cobra.Command) {
+	// these have to be functions since the command is polymorphic. Cobra wants you to be top level
+	// command to get executed
+	commandFns := []func() *cobra.Command{
+		func() *cobra.Command { return kubeapiserver.NewAPIServerCommand() },
+		func() *cobra.Command { return kubecontrollermanager.NewControllerManagerCommand() },
+		func() *cobra.Command { return kubescheduler.NewSchedulerCommand() },
+		func() *cobra.Command { return kubelet.NewKubeletCommand() },
+	}
+
+	// Without a subcommand the root command does nothing; with stray arguments it prints help and exits 1.
+	showHelpOnArgs := func(cmd *cobra.Command, args []string) {
+		if len(args) == 0 {
+			return
+		}
+		cmd.Help()
+		os.Exit(1)
+	}
+	root := &cobra.Command{
+		Use:   "kube",
+		Short: "Combines all Kubernetes components in a single binary",
+		Run:   showHelpOnArgs,
+	}
+
+	for _, construct := range commandFns {
+		root.AddCommand(construct())
+	}
+
+	return root, commandFns
+}
diff --git a/metropolis/node/kubernetes/kubelet.go b/metropolis/node/kubernetes/kubelet.go
new file mode 100644
index 0000000..e9c6ce5
--- /dev/null
+++ b/metropolis/node/kubernetes/kubelet.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kubernetes
+
+import (
+ "context"
+ "encoding/json"
+ "encoding/pem"
+ "fmt"
+ "io"
+ "net"
+ "os/exec"
+
+ v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ kubeletconfig "k8s.io/kubelet/config/v1beta1"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/fileargs"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage/declarative"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/pki"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/reconciler"
+)
+
+// kubeletService runs the kubelet. Its Run method writes the kubelet's kubeconfig and certificates into
+// KubeletDirectory and then launches the kubelet from /kubernetes/bin/kube.
+type kubeletService struct {
+ // NodeName is used to build the kubelet client identity ("system:node:<NodeName>") and is passed to pki.Server
+ // for the kubelet serving certificate.
+ NodeName string
+ // ClusterDNS lists the DNS server addresses placed into the kubelet configuration.
+ ClusterDNS []net.IP
+ // KubeletDirectory holds the kubeconfig and PKI files and is passed to the kubelet as --root-dir.
+ KubeletDirectory *localstorage.DataKubernetesKubeletDirectory
+ // EphemeralDirectory provides the containerd client socket path and the flexvolume plugin directory.
+ EphemeralDirectory *localstorage.EphemeralDirectory
+ // Output is not referenced by the methods in this file.
+ Output io.Writer
+ // KPKI provides the CA certificate and the etcd KV client used to issue the kubelet's certificates.
+ KPKI *pki.KubernetesPKI
+}
+
+func (s *kubeletService) createCertificates(ctx context.Context) error {
+ identity := fmt.Sprintf("system:node:%s", s.NodeName)
+
+ ca := s.KPKI.Certificates[pki.IdCA]
+ cacert, _, err := ca.Ensure(ctx, s.KPKI.KV)
+ if err != nil {
+ return fmt.Errorf("could not ensure ca certificate: %w", err)
+ }
+
+ kubeconfig, err := pki.New(ca, "", pki.Client(identity, []string{"system:nodes"})).Kubeconfig(ctx, s.KPKI.KV)
+ if err != nil {
+ return fmt.Errorf("could not create volatile kubelet client cert: %w", err)
+ }
+
+ cert, key, err := pki.New(ca, "", pki.Server([]string{s.NodeName}, nil)).Ensure(ctx, s.KPKI.KV)
+ if err != nil {
+ return fmt.Errorf("could not create volatile kubelet server cert: %w", err)
+ }
+
+ // TODO(q3k): this should probably become its own function //metropolis/node/kubernetes/pki.
+ for _, el := range []struct {
+ target declarative.FilePlacement
+ data []byte
+ }{
+ {s.KubeletDirectory.Kubeconfig, kubeconfig},
+ {s.KubeletDirectory.PKI.CACertificate, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cacert})},
+ {s.KubeletDirectory.PKI.Certificate, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert})},
+ {s.KubeletDirectory.PKI.Key, pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: key})},
+ } {
+ if err := el.target.Write(el.data, 0400); err != nil {
+ return fmt.Errorf("could not write %v: %w", el.target, err)
+ }
+ }
+
+ return nil
+}
+
+// configure builds the KubeletConfiguration handed to the kubelet: TLS files and client CA from KubeletDirectory,
+// the cluster DNS servers from ClusterDNS, and a set of fixed settings.
+func (s *kubeletService) configure() *kubeletconfig.KubeletConfiguration {
+	var dnsServers []string
+	for i := range s.ClusterDNS {
+		dnsServers = append(dnsServers, s.ClusterDNS[i].String())
+	}
+
+	cfg := kubeletconfig.KubeletConfiguration{
+		TypeMeta: v1.TypeMeta{
+			Kind:       "KubeletConfiguration",
+			APIVersion: kubeletconfig.GroupName + "/v1beta1",
+		},
+		TLSCertFile:       s.KubeletDirectory.PKI.Certificate.FullPath(),
+		TLSPrivateKeyFile: s.KubeletDirectory.PKI.Key.FullPath(),
+		TLSMinVersion:     "VersionTLS13",
+		ClusterDNS:        dnsServers,
+		Authentication: kubeletconfig.KubeletAuthentication{
+			X509: kubeletconfig.KubeletX509Authentication{
+				ClientCAFile: s.KubeletDirectory.PKI.CACertificate.FullPath(),
+			},
+		},
+		// TODO(q3k): move reconciler.False to a generic package, fix the following references.
+		ClusterDomain:                "cluster.local", // cluster.local is hardcoded in the certificate too currently
+		EnableControllerAttachDetach: reconciler.False(),
+		HairpinMode:                  "none",
+		MakeIPTablesUtilChains:       reconciler.False(), // We don't have iptables
+		FailSwapOn:                   reconciler.False(), // Our kernel doesn't have swap enabled which breaks Kubelet's detection
+		KubeReserved: map[string]string{
+			"cpu":    "200m",
+			"memory": "300Mi",
+		},
+
+		// We're not going to use this, but let's make it point to a known-empty directory in case anybody manages to
+		// trigger it.
+		VolumePluginDir: s.EphemeralDirectory.FlexvolumePlugins.FullPath(),
+	}
+	return &cfg
+}
+
+// Run writes the kubelet's certificates and kubeconfig, marshals the kubelet configuration to JSON and launches
+// the kubelet from /kubernetes/bin/kube with that configuration passed as a fileargs-backed --config option. The
+// fileargs instance is closed when Run returns.
+func (s *kubeletService) Run(ctx context.Context) error {
+	if err := s.createCertificates(ctx); err != nil {
+		return fmt.Errorf("when creating certificates: %w", err)
+	}
+
+	configRaw, err := json.Marshal(s.configure())
+	if err != nil {
+		return fmt.Errorf("when marshaling kubelet configuration: %w", err)
+	}
+
+	fargs, err := fileargs.New()
+	if err != nil {
+		return err
+	}
+	// Release whatever fileargs holds once the kubelet process has ended, as runControllerManager does.
+	defer fargs.Close()
+
+	cmd := exec.CommandContext(ctx, "/kubernetes/bin/kube", "kubelet",
+		fargs.FileOpt("--config", "config.json", configRaw),
+		"--container-runtime=remote",
+		fmt.Sprintf("--container-runtime-endpoint=unix://%s", s.EphemeralDirectory.Containerd.ClientSocket.FullPath()),
+		fmt.Sprintf("--kubeconfig=%s", s.KubeletDirectory.Kubeconfig.FullPath()),
+		fmt.Sprintf("--root-dir=%s", s.KubeletDirectory.FullPath()),
+	)
+	// FileOpt has no error return; a failure to provide the config file is only visible through fargs.Error().
+	if err := fargs.Error(); err != nil {
+		return fmt.Errorf("failed to use fileargs: %w", err)
+	}
+	cmd.Env = []string{"PATH=/kubernetes/bin"}
+	return supervisor.RunCommand(ctx, cmd)
+}
diff --git a/metropolis/node/kubernetes/nfproxy/BUILD.bazel b/metropolis/node/kubernetes/nfproxy/BUILD.bazel
new file mode 100644
index 0000000..29124a6
--- /dev/null
+++ b/metropolis/node/kubernetes/nfproxy/BUILD.bazel
@@ -0,0 +1,22 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Go library for the nfproxy package (nfproxy.go). It is visible only to //metropolis/node/kubernetes and its
+# subpackages.
+go_library(
+ name = "go_default_library",
+ srcs = ["nfproxy.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/nfproxy",
+ visibility = ["//metropolis/node/kubernetes:__subpackages__"],
+ deps = [
+ "//metropolis/node/common/supervisor:go_default_library",
+ "@com_github_sbezverk_nfproxy//pkg/controller:go_default_library",
+ "@com_github_sbezverk_nfproxy//pkg/nftables:go_default_library",
+ "@com_github_sbezverk_nfproxy//pkg/proxy:go_default_library",
+ "@io_k8s_api//core/v1:go_default_library",
+ "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library",
+ "@io_k8s_apimachinery//pkg/labels:go_default_library",
+ "@io_k8s_apimachinery//pkg/selection:go_default_library",
+ "@io_k8s_client_go//informers:go_default_library",
+ "@io_k8s_client_go//kubernetes:go_default_library",
+ "@io_k8s_client_go//kubernetes/scheme:go_default_library",
+ "@io_k8s_client_go//tools/record:go_default_library",
+ ],
+)
diff --git a/metropolis/node/kubernetes/nfproxy/nfproxy.go b/metropolis/node/kubernetes/nfproxy/nfproxy.go
new file mode 100644
index 0000000..5fc9a11
--- /dev/null
+++ b/metropolis/node/kubernetes/nfproxy/nfproxy.go
@@ -0,0 +1,104 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package nfproxy is a Kubernetes Service IP proxy based exclusively on the Linux nftables interface.
+// It uses netfilter's NAT capabilities to accept traffic on service IPs and DNAT it to the respective endpoint.
+package nfproxy
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "os"
+ "time"
+
+ "github.com/sbezverk/nfproxy/pkg/controller"
+ "github.com/sbezverk/nfproxy/pkg/nftables"
+ "github.com/sbezverk/nfproxy/pkg/proxy"
+ v1 "k8s.io/api/core/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/labels"
+ "k8s.io/apimachinery/pkg/selection"
+ kubeinformers "k8s.io/client-go/informers"
+ "k8s.io/client-go/kubernetes"
+ "k8s.io/client-go/kubernetes/scheme"
+ "k8s.io/client-go/tools/record"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+)
+
+// Service is the nfproxy runnable. Its Run method connects nfproxy's Service and EndpointSlice controllers, fed
+// from ClientSet, to a proxy backed by nftables.
+type Service struct {
+ // Traffic in ClusterCIDR is assumed to be originated inside the cluster and will not be SNATed
+ ClusterCIDR net.IPNet
+ // A Kubernetes ClientSet with read access to endpoints and services
+ ClientSet kubernetes.Interface
+}
+
+// Run initializes nftables for the configured cluster CIDR (as either the IPv4 or the IPv6 cluster network),
+// creates the nfproxy proxy with endpoint slices enabled, and starts the Service and EndpointSlice controllers on
+// a shared informer factory that excludes headless services. Once both controllers have started it signals
+// healthy and done. An address that is neither IPv4 nor IPv6 is rejected as an invalid ClusterCIDR.
+func (s *Service) Run(ctx context.Context) error {
+	var ipv4ClusterCIDR, ipv6ClusterCIDR string
+	switch clusterIP := s.ClusterCIDR.IP; {
+	case clusterIP.To4() != nil:
+		ipv4ClusterCIDR = s.ClusterCIDR.String()
+	case clusterIP.To16() != nil:
+		ipv6ClusterCIDR = s.ClusterCIDR.String()
+	default:
+		return errors.New("invalid ClusterCIDR")
+	}
+	nfti, err := nftables.InitNFTables(ipv4ClusterCIDR, ipv6ClusterCIDR)
+	if err != nil {
+		return fmt.Errorf("failed to initialize nftables with error: %w", err)
+	}
+
+	// Create event recorder to report events into K8s
+	hostname, err := os.Hostname()
+	if err != nil {
+		return fmt.Errorf("failed to get local host name with error: %w", err)
+	}
+	source := v1.EventSource{Component: "nfproxy", Host: hostname}
+	recorder := record.NewBroadcaster().NewRecorder(scheme.Scheme, source)
+
+	// Create new proxy controller with endpoint slices enabled
+	// https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/
+	serviceProxy := proxy.NewProxy(nfti, hostname, recorder, true)
+
+	// Create special informer which doesn't track headless services
+	noHeadlessEndpoints, err := labels.NewRequirement(v1.IsHeadlessService, selection.DoesNotExist, nil)
+	if err != nil {
+		return fmt.Errorf("failed to create Requirement for noHeadlessEndpoints: %w", err)
+	}
+	selector := labels.NewSelector().Add(*noHeadlessEndpoints)
+	restrictToSelector := func(options *metav1.ListOptions) {
+		options.LabelSelector = selector.String()
+	}
+
+	informers := kubeinformers.NewSharedInformerFactoryWithOptions(s.ClientSet, time.Minute*5,
+		kubeinformers.WithTweakListOptions(restrictToSelector))
+
+	serviceController := controller.NewServiceController(serviceProxy, s.ClientSet, informers.Core().V1().Services())
+	endpointController := controller.NewEndpointSliceController(serviceProxy, s.ClientSet, informers.Discovery().V1beta1().EndpointSlices())
+	informers.Start(ctx.Done())
+
+	if err := serviceController.Start(ctx.Done()); err != nil {
+		return fmt.Errorf("error running Service controller: %w", err)
+	}
+	if err := endpointController.Start(ctx.Done()); err != nil {
+		return fmt.Errorf("error running endpoint controller: %w", err)
+	}
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	supervisor.Signal(ctx, supervisor.SignalDone)
+	return nil
+}
diff --git a/metropolis/node/kubernetes/pki/BUILD.bazel b/metropolis/node/kubernetes/pki/BUILD.bazel
new file mode 100644
index 0000000..f82603d
--- /dev/null
+++ b/metropolis/node/kubernetes/pki/BUILD.bazel
@@ -0,0 +1,19 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Go library for the Kubernetes PKI package (ca.go, certificate.go, kubernetes.go). It is visible to every package
+# under //metropolis/node.
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "ca.go",
+ "certificate.go",
+ "kubernetes.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/pki",
+ visibility = ["//metropolis/node:__subpackages__"],
+ deps = [
+ "//metropolis/node:go_default_library",
+ "//metropolis/node/core/logtree:go_default_library",
+ "@io_etcd_go_etcd//clientv3:go_default_library",
+ "@io_k8s_client_go//tools/clientcmd:go_default_library",
+ "@io_k8s_client_go//tools/clientcmd/api:go_default_library",
+ ],
+)
diff --git a/metropolis/node/kubernetes/pki/ca.go b/metropolis/node/kubernetes/pki/ca.go
new file mode 100644
index 0000000..64453cd
--- /dev/null
+++ b/metropolis/node/kubernetes/pki/ca.go
@@ -0,0 +1,151 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pki
+
+import (
+ "context"
+ "crypto"
+ "crypto/ed25519"
+ "crypto/rand"
+ "crypto/sha1"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/asn1"
+ "fmt"
+ "math/big"
+ "time"
+
+ "go.etcd.io/etcd/clientv3"
+)
+
+// Issuer is a CA that can issue certificates. Two issuers are currently implemented:
+// - SelfSigned, which will generate a certificate signed by its corresponding private key.
+// - Certificate, which will use another existing Certificate as a CA.
+type Issuer interface {
+ // CACertificate returns the DER-encoded x509 certificate of the CA that will sign certificates when Issue is
+ // called, or nil if this is a self-signing issuer.
+ CACertificate(ctx context.Context, kv clientv3.KV) ([]byte, error)
+ // Issue will generate a key and certificate signed by the Issuer. The returned certificate is x509 DER-encoded,
+ // while the key is a bare ed25519 key.
+ Issue(ctx context.Context, template x509.Certificate, kv clientv3.KV) (cert, key []byte, err error)
+}
+
+var (
+ // From RFC 5280 Section 4.1.2.5
+ // This is 9999-12-31T23:59:59Z, the NotAfter value that section assigns to certificates without a well-defined
+ // expiration date. issueCertificate uses it as the NotAfter of every certificate it issues.
+ unknownNotAfter = time.Unix(253402300799, 0)
+)
+
+// calculateSKID derives a subject key identifier for pubKey: the SHA-1 digest of the subjectPublicKey bit string
+// found in the key's PKIX (SubjectPublicKeyInfo) encoding. It fails if the key cannot be marshalled or the
+// resulting structure cannot be decoded.
+//
+// Workaround for https://github.com/golang/go/issues/26676 in Go's crypto/x509. Specifically Go
+// violates Section 4.2.1.2 of RFC 5280 without this.
+// Fixed for 1.15 in https://go-review.googlesource.com/c/go/+/227098/.
+//
+// Taken from https://github.com/FiloSottile/mkcert/blob/master/cert.go#L295 written by one of Go's
+// crypto engineers
+func calculateSKID(pubKey crypto.PublicKey) ([]byte, error) {
+	// subjectPublicKeyInfo mirrors the ASN.1 layout of the PKIX public key encoding.
+	type subjectPublicKeyInfo struct {
+		Algorithm        pkix.AlgorithmIdentifier
+		SubjectPublicKey asn1.BitString
+	}
+
+	encoded, err := x509.MarshalPKIXPublicKey(pubKey)
+	if err != nil {
+		return nil, err
+	}
+
+	var info subjectPublicKeyInfo
+	if _, err := asn1.Unmarshal(encoded, &info); err != nil {
+		return nil, err
+	}
+
+	digest := sha1.New()
+	digest.Write(info.SubjectPublicKey.Bytes)
+	return digest.Sum(nil), nil
+}
+
+// issueCertificate is a generic low level certificate-and-key issuance function. It generates a fresh ed25519 key
+// pair and a random serial number, fills in the validity, basic constraints and key identifiers of template, and
+// signs the result with caKey as issued by ca. If ca or caKey is nil, the certificate will be self-signed with
+// the freshly generated key.
+//
+// The returned certificate is DER-encoded, while the returned key is the raw ed25519 private key. On error both
+// are nil.
+func issueCertificate(template x509.Certificate, ca *x509.Certificate, caKey interface{}) (cert, key []byte, err error) {
+	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		// Deliberate crash: without a working randomness source nothing sensible can be issued.
+		panic(err)
+	}
+
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 127)
+	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
+	}
+
+	skid, err := calculateSKID(pubKey)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	template.SerialNumber = serialNumber
+	template.NotBefore = time.Now()
+	template.NotAfter = unknownNotAfter
+	template.BasicConstraintsValid = true
+	template.SubjectKeyId = skid
+
+	// Set the AuthorityKeyID to the SKID of the signing certificate (or self, if self-signing).
+	if ca != nil && caKey != nil {
+		template.AuthorityKeyId = ca.SubjectKeyId
+	} else {
+		template.AuthorityKeyId = template.SubjectKeyId
+		// Self-signing: the certificate is its own parent and is signed with its own key.
+		ca = &template
+		caKey = privKey
+	}
+
+	certRaw, err := x509.CreateCertificate(rand.Reader, &template, ca, pubKey, caKey)
+	if err != nil {
+		return nil, nil, err
+	}
+	return certRaw, privKey, nil
+}
+
+// selfSigned is the Issuer implementation behind SelfSigned. It carries no state.
+type selfSigned struct{}
+
+// Issue generates a key and a certificate from template, signed by that same key. The ctx and kv arguments are
+// not used.
+func (s *selfSigned) Issue(ctx context.Context, template x509.Certificate, kv clientv3.KV) (cert, key []byte, err error) {
+ return issueCertificate(template, nil, nil)
+}
+
+// CACertificate always returns nil and no error, as a self-signing issuer has no separate CA certificate.
+func (s *selfSigned) CACertificate(ctx context.Context, kv clientv3.KV) ([]byte, error) {
+ return nil, nil
+}
+
+var (
+ // SelfSigned is an Issuer that generates self-signed certificates.
+ SelfSigned = &selfSigned{}
+)
+
+// Issue makes the Certificate act as a CA: it ensures that its own certificate and key exist, then issues a new
+// key and certificate from template, signed by them. The issued certificate is never itself a CA.
+func (c *Certificate) Issue(ctx context.Context, template x509.Certificate, kv clientv3.KV) (cert, key []byte, err error) {
+	caCertDER, caKeyRaw, err := c.ensure(ctx, kv)
+	if err != nil {
+		return nil, nil, fmt.Errorf("could not ensure CA certificate %q exists: %w", c.name, err)
+	}
+
+	parent, err := x509.ParseCertificate(caCertDER)
+	if err != nil {
+		return nil, nil, fmt.Errorf("could not parse CA certificate: %w", err)
+	}
+
+	// Ensure only one level of CAs exist, and that they are created explicitly.
+	template.IsCA = false
+	signingKey := ed25519.PrivateKey(caKeyRaw)
+	return issueCertificate(template, parent, signingKey)
+}
+
+// CACertificate returns the certificate part of this Certificate as obtained from ensure, for use as the CA
+// certificate of the certificates it issues.
+func (c *Certificate) CACertificate(ctx context.Context, kv clientv3.KV) ([]byte, error) {
+	caCert, _, err := c.ensure(ctx, kv)
+	return caCert, err
+}
diff --git a/metropolis/node/kubernetes/pki/certificate.go b/metropolis/node/kubernetes/pki/certificate.go
new file mode 100644
index 0000000..6bd50f9
--- /dev/null
+++ b/metropolis/node/kubernetes/pki/certificate.go
@@ -0,0 +1,192 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pki
+
+import (
+ "context"
+ "crypto/ed25519"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "fmt"
+ "net"
+
+ "go.etcd.io/etcd/clientv3"
+)
+
+// Certificate is the promise of a Certificate being available to the caller. In this case, Certificate refers to a
+// pair of x509 certificate and corresponding private key.
+// Certificates can be stored in etcd, and their issuers might also be stored in etcd. As such, this type's methods
+// take an etcd KV client.
+// This Certificate type is agnostic to usage, but mostly geared towards Kubernetes certificates.
+type Certificate struct {
+ // issuer is the Issuer that will generate this certificate if one doesn't yet exist in etcd, or if the requested
+ // certificate is volatile (not to be stored in etcd).
+ issuer Issuer
+ // name is a unique key for storing the certificate in etcd. If empty, the certificate is 'volatile': it will not be
+ // stored in etcd, and every .Ensure() call will generate a new pair.
+ name string
+ // template is an x509 certificate definition that will be used to generate the certificate when issuing it.
+ template x509.Certificate
+}
+
+const (
+ // etcdPrefix is the etcd key prefix under which all the PKI data is stored. etcdPath prepends it to every key it
+ // builds.
+ etcdPrefix = "/kube-pki/"
+)
+
+// etcdPath formats f with args (fmt.Sprintf semantics) and returns the result prefixed with etcdPrefix.
+func etcdPath(f string, args ...interface{}) string {
+	suffix := fmt.Sprintf(f, args...)
+	return etcdPrefix + suffix
+}
+
+// New returns a Certificate - more precisely, a promise that a certificate will exist once Ensure is called on it.
+// issuer must be a valid certificate issuer (SelfSigned or another Certificate). name must be unique among all
+// certificates, or empty (which makes the certificate volatile, ie. not stored in etcd). template describes the
+// certificate to generate.
+func New(issuer Issuer, name string, template x509.Certificate) *Certificate {
+	c := new(Certificate)
+	c.issuer = issuer
+	c.name = name
+	c.template = template
+	return c
+}
+
+// Client builds a Kubernetes PKI-compatible client certificate template. identity becomes the subject CommonName and
+// groups the subject Organization.
+// Directly derived from Kubernetes PKI requirements documented at
+// https://kubernetes.io/docs/setup/best-practices/certificates/#configure-certificates-manually
+func Client(identity string, groups []string) x509.Certificate {
+	var tmpl x509.Certificate
+	tmpl.Subject = pkix.Name{
+		CommonName:   identity,
+		Organization: groups,
+	}
+	tmpl.KeyUsage = x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment
+	tmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
+	return tmpl
+}
+
+// Server builds a Kubernetes PKI-compatible server certificate template with an empty subject, valid for the given
+// DNS names and IP addresses.
+func Server(dnsNames []string, ips []net.IP) x509.Certificate {
+	tmpl := x509.Certificate{
+		Subject:     pkix.Name{},
+		DNSNames:    dnsNames,
+		IPAddresses: ips,
+	}
+	tmpl.KeyUsage = x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment
+	tmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
+	return tmpl
+}
+
+// CA builds a certificate template for a certificate authority (IsCA set, certificate/CRL signing key usages) with
+// cn as its subject CommonName.
+func CA(cn string) x509.Certificate {
+	tmpl := x509.Certificate{IsCA: true}
+	tmpl.Subject = pkix.Name{CommonName: cn}
+	tmpl.KeyUsage = x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature
+	tmpl.ExtKeyUsage = []x509.ExtKeyUsage{
+		x509.ExtKeyUsageClientAuth,
+		x509.ExtKeyUsageServerAuth,
+		x509.ExtKeyUsageOCSPSigning,
+	}
+	return tmpl
+}
+
+// etcdPaths returns the etcd keys under which this certificate's DER certificate and key are stored, both derived
+// from the certificate name.
+func (c *Certificate) etcdPaths() (cert, key string) {
+	cert = etcdPath("%s-cert.der", c.name)
+	key = etcdPath("%s-key.der", c.name)
+	return cert, key
+}
+
+// ensure returns a DER-encoded x509 certificate and internally encoded bare ed25519 key for a given Certificate,
+// in memory (if volatile), loading it from etcd, or creating and saving it on etcd if needed.
+// This function is safe to call in parallel from multiple etcd clients (including across machines), but it will error
+// in case a concurrent certificate generation happens. These errors are, however, safe to retry - as long as all the
+// certificate creators (ie., Smalltown nodes) run the same version of this code.
+// NOTE: when the final etcd transaction fails, the named results cert and key still hold the freshly issued (but
+// unsaved) pair alongside the non-nil error; callers must check err before using them.
+// TODO(q3k): in the future, this should be handled better - especially as we introduce new certificates, or worse,
+// change the issuance chain. As a stopgap measure, an explicit per-certificate or even global lock can be implemented.
+// And, even before that, we can handle concurrency errors in a smarter way.
+func (c *Certificate) ensure(ctx context.Context, kv clientv3.KV) (cert, key []byte, err error) {
+ if c.name == "" {
+ // Volatile certificate - generate.
+ // TODO(q3k): cache internally?
+ cert, key, err = c.issuer.Issue(ctx, c.template, kv)
+ if err != nil {
+ err = fmt.Errorf("failed to issue: %w", err)
+ return
+ }
+ return
+ }
+
+ certPath, keyPath := c.etcdPaths()
+
+ // Try loading certificate and key from etcd.
+ certRes, err := kv.Get(ctx, certPath)
+ if err != nil {
+ err = fmt.Errorf("failed to get certificate from etcd: %w", err)
+ return
+ }
+ keyRes, err := kv.Get(ctx, keyPath)
+ if err != nil {
+ err = fmt.Errorf("failed to get key from etcd: %w", err)
+ return
+ }
+
+ // Only a complete pair counts as existing. If just one of the two keys is present, we fall through to issuance,
+ // and the transaction below will then fail its CreateRevision comparison.
+ if len(certRes.Kvs) == 1 && len(keyRes.Kvs) == 1 {
+ // Certificate and key exist in etcd, return that.
+ cert = certRes.Kvs[0].Value
+ key = keyRes.Kvs[0].Value
+
+ err = nil
+ // TODO(q3k): check for expiration
+ return
+ }
+
+ // No certificate found - issue one.
+ cert, key, err = c.issuer.Issue(ctx, c.template, kv)
+ if err != nil {
+ err = fmt.Errorf("failed to issue: %w", err)
+ return
+ }
+
+ // Save to etcd in transaction. This ensures that no partial writes happen, and that we haven't been raced to the
+ // save. A CreateRevision of 0 means the key does not exist yet.
+ res, err := kv.Txn(ctx).
+ If(
+ clientv3.Compare(clientv3.CreateRevision(certPath), "=", 0),
+ clientv3.Compare(clientv3.CreateRevision(keyPath), "=", 0),
+ ).
+ Then(
+ clientv3.OpPut(certPath, string(cert)),
+ clientv3.OpPut(keyPath, string(key)),
+ ).Commit()
+ if err != nil {
+ err = fmt.Errorf("failed to write newly issued certificate: %w", err)
+ } else if !res.Succeeded {
+ // The comparison failed: at least one of the two keys was created in the meantime.
+ err = fmt.Errorf("certificate issuance transaction failed: concurrent write")
+ }
+
+ return
+}
+
+// Ensure returns the DER-encoded (but not PEM-wrapped) x509 certificate and the PKCS#8 DER-encoded private key for
+// this Certificate.
+// If the certificate is volatile, each call to Ensure will cause a new certificate to be generated.
+// Otherwise, it will be retrieved from etcd, or generated and stored there if needed.
+func (c *Certificate) Ensure(ctx context.Context, kv clientv3.KV) (cert, key []byte, err error) {
+	der, rawKey, err := c.ensure(ctx, kv)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// ensure hands back the bare ed25519 key; wrap it into PKCS#8 for consumers.
+	pkcs8, err := x509.MarshalPKCS8PrivateKey(ed25519.PrivateKey(rawKey))
+	if err != nil {
+		return der, pkcs8, fmt.Errorf("could not marshal private key (data corruption?): %w", err)
+	}
+	return der, pkcs8, nil
+}
diff --git a/metropolis/node/kubernetes/pki/kubernetes.go b/metropolis/node/kubernetes/pki/kubernetes.go
new file mode 100644
index 0000000..c4827a9
--- /dev/null
+++ b/metropolis/node/kubernetes/pki/kubernetes.go
@@ -0,0 +1,228 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pki
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "fmt"
+ "net"
+
+ "go.etcd.io/etcd/clientv3"
+ "k8s.io/client-go/tools/clientcmd"
+ configapi "k8s.io/client-go/tools/clientcmd/api"
+
+ common "git.monogon.dev/source/nexantic.git/metropolis/node"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+)
+
+// KubeCertificateName is an enum-like unique name of a static Kubernetes certificate. The value of the name is used
+// as the unique part of an etcd path where the certificate and key are stored.
+type KubeCertificateName string
+
+const (
+ // The main Kubernetes CA, used to authenticate API consumers, and servers.
+ IdCA KubeCertificateName = "id-ca"
+
+ // Kubernetes apiserver server certificate.
+ APIServer KubeCertificateName = "apiserver"
+
+ // Kubelet client certificate, used to authenticate to the apiserver.
+ KubeletClient KubeCertificateName = "kubelet-client"
+
+ // Kubernetes Controller manager client certificate, used to authenticate to the apiserver.
+ ControllerManagerClient KubeCertificateName = "controller-manager-client"
+ // Kubernetes Controller manager server certificate, used to run its HTTP server.
+ ControllerManager KubeCertificateName = "controller-manager"
+
+ // Kubernetes Scheduler client certificate, used to authenticate to the apiserver.
+ SchedulerClient KubeCertificateName = "scheduler-client"
+ // Kubernetes scheduler server certificate, used to run its HTTP server.
+ Scheduler KubeCertificateName = "scheduler"
+
+ // Root-on-kube (system:masters) client certificate. Used to control the apiserver (and resources) by Smalltown
+ // internally.
+ Master KubeCertificateName = "master"
+
+ // OpenAPI Kubernetes Aggregation CA.
+ // See: https://kubernetes.io/docs/tasks/extend-kubernetes/configure-aggregation-layer/#ca-reusage-and-conflicts
+ AggregationCA KubeCertificateName = "aggregation-ca"
+ // Client certificate issued by the aggregation CA (see NewKubernetes).
+ // NOTE(review): presumably the apiserver's front-proxy client certificate described in the document linked above -
+ // confirm against the apiserver configuration.
+ FrontProxyClient KubeCertificateName = "front-proxy-client"
+)
+
+const (
+ // serviceAccountKeyName is the etcd path part that is used to store the ServiceAccount authentication secret.
+ // This is not a certificate, just an RSA key.
+ serviceAccountKeyName = "service-account-privkey"
+)
+
+// KubernetesPKI manages all PKI resources required to run Kubernetes on Smalltown. It contains all static certificates,
+// which can be retrieved, or be used to generate Kubeconfigs from.
+type KubernetesPKI struct {
+ // logger receives progress messages (eg. from EnsureAll).
+ logger logtree.LeveledLogger
+ // KV is the etcd client used to load and store certificates and the service account key.
+ KV clientv3.KV
+ // Certificates maps every static certificate name to its Certificate promise. It is populated by NewKubernetes.
+ Certificates map[KubeCertificateName]*Certificate
+}
+
+// NewKubernetes builds a KubernetesPKI that logs to l and stores its data through kv. It registers the two
+// self-signed CAs (ID CA and aggregation CA) and every static certificate issued by them. Nothing is generated or
+// written to etcd here; that happens lazily when the certificates are ensured.
+func NewKubernetes(l logtree.LeveledLogger, kv clientv3.KV) *KubernetesPKI {
+	pki := &KubernetesPKI{
+		logger:       l,
+		KV:           kv,
+		Certificates: map[KubeCertificateName]*Certificate{},
+	}
+
+	// register adds a certificate called name, issued by the already-registered certificate called issuer.
+	register := func(issuer, name KubeCertificateName, template x509.Certificate) {
+		pki.Certificates[name] = New(pki.Certificates[issuer], string(name), template)
+	}
+
+	// Main identity CA and everything it signs.
+	pki.Certificates[IdCA] = New(SelfSigned, string(IdCA), CA("Smalltown Kubernetes ID CA"))
+
+	apiServerNames := []string{
+		"kubernetes",
+		"kubernetes.default",
+		"kubernetes.default.svc",
+		"kubernetes.default.svc.cluster",
+		"kubernetes.default.svc.cluster.local",
+		"localhost",
+	}
+	// TODO(q3k): add service network internal apiserver address
+	apiServerIPs := []net.IP{{10, 0, 255, 1}, {127, 0, 0, 1}}
+	register(IdCA, APIServer, Server(apiServerNames, apiServerIPs))
+
+	register(IdCA, KubeletClient, Client("smalltown:apiserver-kubelet-client", nil))
+	register(IdCA, ControllerManagerClient, Client("system:kube-controller-manager", nil))
+	register(IdCA, ControllerManager, Server([]string{"kube-controller-manager.local"}, nil))
+	register(IdCA, SchedulerClient, Client("system:kube-scheduler", nil))
+	register(IdCA, Scheduler, Server([]string{"kube-scheduler.local"}, nil))
+	register(IdCA, Master, Client("smalltown:master", []string{"system:masters"}))
+
+	// Aggregation CA and its single client certificate.
+	pki.Certificates[AggregationCA] = New(SelfSigned, string(AggregationCA), CA("Smalltown OpenAPI Aggregation CA"))
+	register(AggregationCA, FrontProxyClient, Client("front-proxy-client", nil))
+
+	return pki
+}
+
+// EnsureAll makes sure every static certificate, and the service account key, is present on etcd. It stops at and
+// returns the first error encountered.
+func (k *KubernetesPKI) EnsureAll(ctx context.Context) error {
+	for name, certificate := range k.Certificates {
+		k.logger.Infof("Ensuring %s exists", string(name))
+		if _, _, err := certificate.Ensure(ctx, k.KV); err != nil {
+			return fmt.Errorf("could not ensure certificate %q exists: %w", name, err)
+		}
+	}
+
+	if _, err := k.ServiceAccountKey(ctx); err != nil {
+		return fmt.Errorf("could not ensure service account key exists: %w", err)
+	}
+	return nil
+}
+
+// Kubeconfig returns a kubeconfig blob for the certificate registered under name, or an error if no such
+// certificate exists. The same lifetime semantics as in .Certificate apply.
+func (k *KubernetesPKI) Kubeconfig(ctx context.Context, name KubeCertificateName) ([]byte, error) {
+	if c, found := k.Certificates[name]; found {
+		return c.Kubeconfig(ctx, k.KV)
+	}
+	return nil, fmt.Errorf("no certificate %q", name)
+}
+
+// Certificate returns the DER-encoded (but not PEM-wrapped) certificate and key registered under name, or an error
+// if no such certificate exists.
+// If the requested certificate is volatile, it will be created on demand. Otherwise it will be created on etcd (if not
+// present), and retrieved from there.
+func (k *KubernetesPKI) Certificate(ctx context.Context, name KubeCertificateName) (cert, key []byte, err error) {
+	if c, found := k.Certificates[name]; found {
+		return c.Ensure(ctx, k.KV)
+	}
+	return nil, nil, fmt.Errorf("no certificate %q", name)
+}
+
+// Kubeconfig builds a kubeconfig blob that authenticates with this certificate against the apiserver on
+// 127.0.0.1:KubernetesAPIPort. The issuer's CA certificate, if it has one, is embedded as the cluster CA. The same
+// lifetime semantics as in .Ensure apply.
+func (c *Certificate) Kubeconfig(ctx context.Context, kv clientv3.KV) ([]byte, error) {
+	// Every entry (cluster, user, context) is registered under the same name.
+	const entry = "default"
+
+	certDER, keyDER, err := c.Ensure(ctx, kv)
+	if err != nil {
+		return nil, fmt.Errorf("could not ensure certificate exists: %w", err)
+	}
+
+	caDER, err := c.issuer.CACertificate(ctx, kv)
+	if err != nil {
+		return nil, fmt.Errorf("could not get CA certificate: %w", err)
+	}
+
+	cluster := configapi.NewCluster()
+	cluster.Server = fmt.Sprintf("https://127.0.0.1:%v", common.KubernetesAPIPort)
+	// Self-signed issuers have no CA certificate; leave the field unset then.
+	if caDER != nil {
+		cluster.CertificateAuthorityData = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caDER})
+	}
+
+	user := configapi.NewAuthInfo()
+	user.ClientCertificateData = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
+	user.ClientKeyData = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyDER})
+
+	kubeContext := configapi.NewContext()
+	kubeContext.Cluster = entry
+	kubeContext.AuthInfo = entry
+
+	cfg := configapi.NewConfig()
+	cfg.Clusters[entry] = cluster
+	cfg.AuthInfos[entry] = user
+	cfg.Contexts[entry] = kubeContext
+	cfg.CurrentContext = entry
+
+	return clientcmd.Write(*cfg)
+}
+
+// ServiceAccountKey retrieves (and possibly generates and stores on etcd) the Kubernetes service account key. The
+// returned data is a PKCS#8 DER-encoded RSA private key.
+// An error is returned if etcd cannot be read, if key generation fails, or if a newly generated key cannot be
+// written to etcd - in which case no key is returned, as it would not be the one other callers end up with.
+func (k *KubernetesPKI) ServiceAccountKey(ctx context.Context) ([]byte, error) {
+	// TODO(q3k): this should be abstracted away once we abstract away etcd access into a library with try-or-create
+	// semantics. Until then the write below is unconditional, so two concurrent first-time callers can race.
+
+	path := etcdPath("%s.der", serviceAccountKeyName)
+
+	// Try loading key from etcd.
+	keyRes, err := k.KV.Get(ctx, path)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get key from etcd: %w", err)
+	}
+
+	if len(keyRes.Kvs) == 1 {
+		// Key exists in etcd, return that.
+		return keyRes.Kvs[0].Value, nil
+	}
+
+	// No key found - generate one.
+	keyRaw, err := rsa.GenerateKey(rand.Reader, 2048)
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate key: %w", err)
+	}
+	key, err := x509.MarshalPKCS8PrivateKey(keyRaw)
+	if err != nil {
+		panic(err) // Always a programmer error
+	}
+
+	// Save to etcd. A key that could not be persisted must not be handed out.
+	if _, err := k.KV.Put(ctx, path, string(key)); err != nil {
+		return nil, fmt.Errorf("failed to write newly generated key: %w", err)
+	}
+	return key, nil
+}
diff --git a/metropolis/node/kubernetes/provisioner.go b/metropolis/node/kubernetes/provisioner.go
new file mode 100644
index 0000000..b671125
--- /dev/null
+++ b/metropolis/node/kubernetes/provisioner.go
@@ -0,0 +1,368 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kubernetes
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ v1 "k8s.io/api/core/v1"
+ storagev1 "k8s.io/api/storage/v1"
+ apierrs "k8s.io/apimachinery/pkg/api/errors"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/client-go/informers"
+ coreinformers "k8s.io/client-go/informers/core/v1"
+ storageinformers "k8s.io/client-go/informers/storage/v1"
+ "k8s.io/client-go/kubernetes"
+ "k8s.io/client-go/kubernetes/scheme"
+ typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1"
+ "k8s.io/client-go/tools/cache"
+ "k8s.io/client-go/tools/record"
+ ref "k8s.io/client-go/tools/reference"
+ "k8s.io/client-go/util/workqueue"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/fsquota"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/logtree"
+)
+
+// csiProvisionerServerName identifies this provisioner. It is used as the event source component, as the CSI driver
+// name of created PVs, and is compared against the provisioner annotations on PVCs/PVs and the StorageClass
+// provisioner field.
+// ONCHANGE(//metropolis/node/kubernetes/reconciler:resources_csi.go): needs to match csiProvisionerServerName declared.
+const csiProvisionerServerName = "com.nexantic.smalltown.vfs"
+
+// csiProvisionerServer is responsible for the provisioning and deprovisioning of CSI-based container volumes. It runs on all
+// nodes and watches PVCs for ones assigned to the node it's running on and fulfills the provisioning request by
+// creating a directory, applying a quota and creating the corresponding PV. When the PV is released and its retention
+// policy is Delete, the directory and the PV resource are deleted.
+// The exported fields must be set by the caller; the unexported ones are initialized by Run.
+type csiProvisionerServer struct {
+ // NodeName is the name of the node this server runs on. Only PVCs/PVs scheduled onto this node are handled.
+ NodeName string
+ // Kubernetes is the API client used to emit events and to create/delete PV objects.
+ Kubernetes kubernetes.Interface
+ // InformerFactory provides the PV, PVC and StorageClass informers.
+ InformerFactory informers.SharedInformerFactory
+ // VolumesDirectory is the directory under which one subdirectory per volume is created.
+ VolumesDirectory *localstorage.DataVolumesDirectory
+
+ // claimQueue and pvQueue hold the keys of PVCs and PVs waiting to be processed.
+ claimQueue workqueue.RateLimitingInterface
+ pvQueue workqueue.RateLimitingInterface
+ // recorder emits Kubernetes events about provisioning results.
+ recorder record.EventRecorder
+ pvcInformer coreinformers.PersistentVolumeClaimInformer
+ pvInformer coreinformers.PersistentVolumeInformer
+ storageClassInformer storageinformers.StorageClassInformer
+ logger logtree.LeveledLogger
+}
+
+// Run runs the main provisioning machinery until ctx is canceled. It consists of a bunch of informers which keep
+// track of the events happening on the Kubernetes control plane and inform us when something happens. If anything
+// happens to PVCs or PVs, we enqueue the identifier of that resource in a work queue. Each queue is worked on by a
+// single worker to limit load and avoid complicated locking infrastructure. Failed items are requeued.
+func (p *csiProvisionerServer) Run(ctx context.Context) error {
+	p.logger = supervisor.Logger(ctx)
+
+	// The recorder is used to log Kubernetes events for successful or failed volume provisions. These events then
+	// show up in `kubectl describe pvc` and can be used by admins to debug issues with this provisioner.
+	broadcaster := record.NewBroadcaster()
+	broadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: p.Kubernetes.CoreV1().Events("")})
+	p.recorder = broadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: csiProvisionerServerName, Host: p.NodeName})
+
+	p.pvInformer = p.InformerFactory.Core().V1().PersistentVolumes()
+	p.pvcInformer = p.InformerFactory.Core().V1().PersistentVolumeClaims()
+	p.storageClassInformer = p.InformerFactory.Storage().V1().StorageClasses()
+
+	limiter := workqueue.DefaultControllerRateLimiter
+	p.claimQueue = workqueue.NewRateLimitingQueue(limiter())
+	p.pvQueue = workqueue.NewRateLimitingQueue(limiter())
+
+	// Both additions and updates simply enqueue the (new) object.
+	p.pvcInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
+		AddFunc:    p.enqueueClaim,
+		UpdateFunc: func(_, updated interface{}) { p.enqueueClaim(updated) },
+	})
+	p.pvInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
+		AddFunc:    p.enqueuePV,
+		UpdateFunc: func(_, updated interface{}) { p.enqueuePV(updated) },
+	})
+
+	stop := ctx.Done()
+	go p.pvcInformer.Informer().Run(stop)
+	go p.pvInformer.Informer().Run(stop)
+	go p.storageClassInformer.Informer().Run(stop)
+
+	// These will self-terminate once the queues are shut down
+	go p.processQueueItems(p.claimQueue, p.processPVC)
+	go p.processQueueItems(p.pvQueue, p.processPV)
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	<-stop
+	p.claimQueue.ShutDown()
+	p.pvQueue.ShutDown()
+	return nil
+}
+
+// isOurPVC reports whether the given PVC is to be provisioned by this provisioner and has been scheduled onto this
+// node, judging by its storage-provisioner and selected-node annotations.
+func (p *csiProvisionerServer) isOurPVC(pvc *v1.PersistentVolumeClaim) bool {
+	annotations := pvc.ObjectMeta.Annotations
+	return annotations["volume.beta.kubernetes.io/storage-provisioner"] == csiProvisionerServerName &&
+		annotations["volume.kubernetes.io/selected-node"] == p.NodeName
+}
+
+// isOurPV reports whether the given PV has been provisioned by this provisioner (provisioned-by annotation) and is
+// pinned to this node (first value of the first match expression of the first required node selector term, which
+// is the shape provisionPVC creates). PVs that carry our annotation but lack that node affinity structure are
+// treated as not ours instead of causing a nil dereference or out-of-range panic.
+func (p *csiProvisionerServer) isOurPV(pv *v1.PersistentVolume) bool {
+	if pv.ObjectMeta.Annotations["pv.kubernetes.io/provisioned-by"] != csiProvisionerServerName {
+		return false
+	}
+
+	affinity := pv.Spec.NodeAffinity
+	if affinity == nil || affinity.Required == nil {
+		return false
+	}
+	terms := affinity.Required.NodeSelectorTerms
+	if len(terms) == 0 || len(terms[0].MatchExpressions) == 0 {
+		return false
+	}
+	values := terms[0].MatchExpressions[0].Values
+	if len(values) == 0 {
+		return false
+	}
+	return values[0] == p.NodeName
+}
+
+// enqueueClaim puts the namespace/name key of an added or changed PVC on the claim work queue. Objects for which no
+// key can be derived are logged and dropped.
+func (p *csiProvisionerServer) enqueueClaim(obj interface{}) {
+	queueKey, keyErr := cache.MetaNamespaceKeyFunc(obj)
+	if keyErr == nil {
+		p.claimQueue.Add(queueKey)
+		return
+	}
+	p.logger.Errorf("Not queuing PVC because key could not be derived: %v", keyErr)
+}
+
+// enqueuePV puts the key of an added or changed PV on the PV work queue. Objects for which no key can be derived
+// are logged and dropped.
+func (p *csiProvisionerServer) enqueuePV(obj interface{}) {
+	queueKey, keyErr := cache.MetaNamespaceKeyFunc(obj)
+	if keyErr == nil {
+		p.pvQueue.Add(queueKey)
+		return
+	}
+	p.logger.Errorf("Not queuing PV because key could not be derived: %v", keyErr)
+}
+
+// processQueueItems gets items from the given work queue and calls the process function for each of them. It self-
+// terminates once the queue is shut down.
+// Items for which process fails are re-added with rate limiting, and their failure history is kept so that repeated
+// failures back off. Items that succeed (or are malformed) are forgotten, which resets that history.
+func (p *csiProvisionerServer) processQueueItems(queue workqueue.RateLimitingInterface, process func(key string) error) {
+	for {
+		obj, shutdown := queue.Get()
+		if shutdown {
+			return
+		}
+
+		// Run in a closure so that Done is always called for the item, whatever path is taken.
+		func(obj interface{}) {
+			defer queue.Done(obj)
+			key, ok := obj.(string)
+			if !ok {
+				// Retrying cannot fix a malformed item; drop it for good.
+				queue.Forget(obj)
+				p.logger.Errorf("Expected string in workqueue, got %+v", obj)
+				return
+			}
+
+			if err := process(key); err != nil {
+				p.logger.Warningf("Failed processing item %q, requeueing (numrequeues: %d): %v", key, queue.NumRequeues(obj), err)
+				queue.AddRateLimited(obj)
+				// Do not Forget a failed item: that would reset the rate limiter's failure count and defeat backoff.
+				return
+			}
+
+			queue.Forget(obj)
+		}(obj)
+	}
+}
+
+// volumePath returns the directory in which the volume with the given ID is stored: a subdirectory of the volumes
+// directory named after the ID.
+func (p *csiProvisionerServer) volumePath(volumeID string) string {
+	root := p.VolumesDirectory.FullPath()
+	return filepath.Join(root, volumeID)
+}
+
+// processPVC looks at a single PVC item from the queue, determines if it needs to be provisioned and logs the
+// provisioning result to the recorder.
+// It returns nil when there is nothing to do (PVC gone, not ours, not pending, no or foreign storage class) and an
+// error when processing should be retried.
+func (p *csiProvisionerServer) processPVC(key string) error {
+	namespace, name, err := cache.SplitMetaNamespaceKey(key)
+	if err != nil {
+		return fmt.Errorf("invalid resource key: %s", key)
+	}
+	pvc, err := p.pvcInformer.Lister().PersistentVolumeClaims(namespace).Get(name)
+	if apierrs.IsNotFound(err) {
+		return nil // nothing to do, no error
+	} else if err != nil {
+		return fmt.Errorf("failed to get PVC for processing: %w", err)
+	}
+
+	if !p.isOurPVC(pvc) {
+		return nil
+	}
+
+	if pvc.Status.Phase != "Pending" {
+		// If the PVC is not pending, we don't need to provision anything
+		return nil
+	}
+
+	if pvc.Spec.StorageClassName == nil {
+		// Without a storage class we cannot tell how to provision. Retrying will not help until the PVC changes,
+		// which enqueues it again anyway.
+		p.logger.Warningf("PVC %q is annotated for this provisioner but has no storage class, ignoring", key)
+		return nil
+	}
+
+	storageClass, err := p.storageClassInformer.Lister().Get(*pvc.Spec.StorageClassName)
+	if err != nil {
+		return fmt.Errorf("failed to get storage class %q: %w", *pvc.Spec.StorageClassName, err)
+	}
+
+	if storageClass.Provisioner != csiProvisionerServerName {
+		// We're not responsible for this PVC. Can only happen if controller-manager makes a mistake
+		// setting the annotations, but we're bailing here anyways for safety.
+		return nil
+	}
+
+	err = p.provisionPVC(pvc, storageClass)
+
+	if err != nil {
+		p.recorder.Eventf(pvc, v1.EventTypeWarning, "ProvisioningFailed", "Failed to provision PV: %v", err)
+		return err
+	}
+	p.recorder.Eventf(pvc, v1.EventTypeNormal, "Provisioned", "Successfully provisioned PV")
+
+	return nil
+}
+
+// provisionPVC creates the directory where the volume lives, sets a quota for the requested amount of storage and
+// creates the PV object representing this new volume.
+// The operation can be retried: an already existing (empty) directory and an already existing PV object are both
+// accepted. Optional API fields that are unset are handled instead of dereferenced: a nil volume mode is treated as
+// a filesystem volume, and a nil reclaim policy on the storage class as Delete.
+func (p *csiProvisionerServer) provisionPVC(pvc *v1.PersistentVolumeClaim, storageClass *storagev1.StorageClass) error {
+	claimRef, err := ref.GetReference(scheme.Scheme, pvc)
+	if err != nil {
+		return fmt.Errorf("failed to get reference to PVC: %w", err)
+	}
+
+	if pvc.Spec.StorageClassName == nil {
+		return errors.New("PVC has no storage class name, this is not supported")
+	}
+
+	storageReq := pvc.Spec.Resources.Requests[v1.ResourceStorage]
+	if storageReq.IsZero() {
+		return fmt.Errorf("PVC is not requesting any storage, this is not supported")
+	}
+	capacity, ok := storageReq.AsInt64()
+	if !ok {
+		return fmt.Errorf("PVC requesting more than 2^63 bytes of storage, this is not supported")
+	}
+
+	if pvc.Spec.VolumeMode != nil && *pvc.Spec.VolumeMode == v1.PersistentVolumeBlock {
+		return fmt.Errorf("Block PVCs are not supported by Smalltown")
+	}
+
+	reclaimPolicy := v1.PersistentVolumeReclaimDelete
+	if storageClass.ReclaimPolicy != nil {
+		reclaimPolicy = *storageClass.ReclaimPolicy
+	}
+
+	volumeID := "pvc-" + string(pvc.ObjectMeta.UID)
+	volumePath := p.volumePath(volumeID)
+
+	p.logger.Infof("Creating local PV %s", volumeID)
+	if err := os.Mkdir(volumePath, 0644); err != nil && !os.IsExist(err) {
+		return fmt.Errorf("failed to create volume directory: %w", err)
+	}
+	// The directory may pre-exist from an earlier attempt; only continue if nothing has been written to it.
+	files, err := ioutil.ReadDir(volumePath)
+	if err != nil {
+		return fmt.Errorf("failed to list files in newly-created volume: %w", err)
+	}
+	if len(files) > 0 {
+		return errors.New("newly-created volume already contains data, bailing")
+	}
+	if err := fsquota.SetQuota(volumePath, uint64(capacity), 100000); err != nil {
+		return fmt.Errorf("failed to update quota: %w", err)
+	}
+
+	vol := &v1.PersistentVolume{
+		ObjectMeta: metav1.ObjectMeta{
+			Name: volumeID,
+			Annotations: map[string]string{
+				"pv.kubernetes.io/provisioned-by": csiProvisionerServerName},
+		},
+		Spec: v1.PersistentVolumeSpec{
+			AccessModes: []v1.PersistentVolumeAccessMode{v1.ReadWriteOnce},
+			Capacity: v1.ResourceList{
+				v1.ResourceStorage: storageReq, // We're always giving the exact amount
+			},
+			PersistentVolumeSource: v1.PersistentVolumeSource{
+				CSI: &v1.CSIPersistentVolumeSource{
+					Driver:       csiProvisionerServerName,
+					VolumeHandle: volumeID,
+				},
+			},
+			ClaimRef: claimRef,
+			// Pin the PV to this node; isOurPV relies on exactly this structure.
+			NodeAffinity: &v1.VolumeNodeAffinity{
+				Required: &v1.NodeSelector{
+					NodeSelectorTerms: []v1.NodeSelectorTerm{
+						{
+							MatchExpressions: []v1.NodeSelectorRequirement{
+								{
+									Key:      "kubernetes.io/hostname",
+									Operator: v1.NodeSelectorOpIn,
+									Values:   []string{p.NodeName},
+								},
+							},
+						},
+					},
+				},
+			},
+			StorageClassName:              *pvc.Spec.StorageClassName,
+			PersistentVolumeReclaimPolicy: reclaimPolicy,
+		},
+	}
+
+	_, err = p.Kubernetes.CoreV1().PersistentVolumes().Create(context.Background(), vol, metav1.CreateOptions{})
+	if err != nil && !apierrs.IsAlreadyExists(err) {
+		return fmt.Errorf("failed to create PV object: %w", err)
+	}
+	return nil
+}
+
+// processPV looks at a single PV item from the queue and checks if it has been released and needs to be deleted. If yes
+// it deletes the associated quota, directory and the PV object and logs failures to the recorder.
+// The operation can be retried: if the volume directory is already gone (eg. a previous attempt removed it but then
+// failed to delete the PV object), quota and directory removal are skipped and only the PV object is deleted.
+func (p *csiProvisionerServer) processPV(key string) error {
+	_, name, err := cache.SplitMetaNamespaceKey(key)
+	if err != nil {
+		return fmt.Errorf("invalid resource key: %s", key)
+	}
+	pv, err := p.pvInformer.Lister().Get(name)
+	if apierrs.IsNotFound(err) {
+		return nil // nothing to do, no error
+	} else if err != nil {
+		return fmt.Errorf("failed to get PV for processing: %w", err)
+	}
+
+	if !p.isOurPV(pv) {
+		return nil
+	}
+	if pv.Spec.PersistentVolumeReclaimPolicy != v1.PersistentVolumeReclaimDelete || pv.Status.Phase != "Released" {
+		return nil
+	}
+	if pv.Spec.CSI == nil || pv.Spec.CSI.VolumeHandle == "" {
+		// Without a volume handle, volumePath would resolve to the volumes directory itself. Never delete that.
+		p.logger.Warningf("PV %s has no CSI volume handle, refusing to deprovision", pv.Name)
+		return nil
+	}
+	volumePath := p.volumePath(pv.Spec.CSI.VolumeHandle)
+
+	// Log deletes for auditing purposes
+	p.logger.Infof("Deleting persistent volume %s", pv.Spec.CSI.VolumeHandle)
+	if _, statErr := os.Stat(volumePath); !os.IsNotExist(statErr) {
+		if err := fsquota.SetQuota(volumePath, 0, 0); err != nil {
+			// We record these here manually since a successful deletion removes the PV we'd be attaching them to
+			p.recorder.Eventf(pv, v1.EventTypeWarning, "DeprovisioningFailed", "Failed to remove quota: %v", err)
+			return fmt.Errorf("failed to remove quota: %w", err)
+		}
+		// os.RemoveAll returns nil for a path that does not exist, so any error here is a real failure.
+		if err := os.RemoveAll(volumePath); err != nil {
+			p.recorder.Eventf(pv, v1.EventTypeWarning, "DeprovisioningFailed", "Failed to delete volume: %v", err)
+			return fmt.Errorf("failed to delete volume: %w", err)
+		}
+	}
+
+	err = p.Kubernetes.CoreV1().PersistentVolumes().Delete(context.Background(), pv.Name, metav1.DeleteOptions{})
+	if err != nil && !apierrs.IsNotFound(err) {
+		p.recorder.Eventf(pv, v1.EventTypeWarning, "DeprovisioningFailed", "Failed to delete PV object from K8s API: %v", err)
+		return fmt.Errorf("failed to delete PV object: %w", err)
+	}
+	return nil
+}
diff --git a/metropolis/node/kubernetes/reconciler/BUILD.bazel b/metropolis/node/kubernetes/reconciler/BUILD.bazel
new file mode 100644
index 0000000..d8f2db6
--- /dev/null
+++ b/metropolis/node/kubernetes/reconciler/BUILD.bazel
@@ -0,0 +1,38 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+# Library target for the reconciler package. Visible to all packages under //metropolis/node.
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "reconciler.go",
+ "resources_csi.go",
+ "resources_podsecuritypolicy.go",
+ "resources_rbac.go",
+ "resources_runtimeclass.go",
+ "resources_storageclass.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/reconciler",
+ visibility = ["//metropolis/node:__subpackages__"],
+ deps = [
+ "//metropolis/node/common/supervisor:go_default_library",
+ "@io_k8s_api//core/v1:go_default_library",
+ "@io_k8s_api//node/v1beta1:go_default_library",
+ "@io_k8s_api//policy/v1beta1:go_default_library",
+ "@io_k8s_api//rbac/v1:go_default_library",
+ "@io_k8s_api//storage/v1:go_default_library",
+ "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library",
+ "@io_k8s_client_go//kubernetes:go_default_library",
+ ],
+)
+
+# Tests for the reconciler package. The library is embedded, so the test is built as part of the same package.
+go_test(
+ name = "go_default_test",
+ srcs = ["reconciler_test.go"],
+ embed = [":go_default_library"],
+ deps = [
+ "@io_k8s_api//node/v1beta1:go_default_library",
+ "@io_k8s_api//policy/v1beta1:go_default_library",
+ "@io_k8s_api//rbac/v1:go_default_library",
+ "@io_k8s_api//storage/v1:go_default_library",
+ "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library",
+ ],
+)
diff --git a/metropolis/node/kubernetes/reconciler/reconciler.go b/metropolis/node/kubernetes/reconciler/reconciler.go
new file mode 100644
index 0000000..9c5ba4e
--- /dev/null
+++ b/metropolis/node/kubernetes/reconciler/reconciler.go
@@ -0,0 +1,163 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package reconciler ensures that a base set of K8s resources is always available in the cluster. These are necessary
+// to ensure correct out-of-the-box functionality. All resources containing the smalltown.com/builtin=true label are
+// assumed to be managed by the reconciler.
+// It currently does not revert modifications made by admins. An admission plugin prohibiting such modifications to
+// resources with the smalltown.com/builtin label is planned to deal with that problem. This would also solve a
+// potential issue where you could delete resources just by adding the smalltown.com/builtin=true label.
+package reconciler
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ meta "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/client-go/kubernetes"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+)
+
+// True returns a pointer to a newly allocated bool holding true. It is a workaround for the many pointer booleans in
+// K8s specs, since Go cannot take the address of a literal.
+func True() *bool {
+	b := new(bool)
+	*b = true
+	return b
+}
+// False returns a pointer to a newly allocated bool holding false, for use in pointer-boolean fields of K8s specs.
+func False() *bool {
+	// new(bool) already points at the zero value, which is false.
+	return new(bool)
+}
+
+const (
+ // BuiltinLabelKey is used as a k8s label key to mark built-in objects (i.e. objects managed by the reconciler).
+ BuiltinLabelKey = "smalltown.com/builtin"
+ // BuiltinLabelValue is used as a k8s label value, under the BuiltinLabelKey key.
+ BuiltinLabelValue = "true"
+ // BuiltinRBACPrefix is used to prefix all built-in objects that are part of the rbac/v1 API (e.g.
+ // {Cluster,}Role{Binding,} objects). This corresponds to the colon-separated 'namespaces' notation used by
+ // Kubernetes system (system:) objects.
+ BuiltinRBACPrefix = "smalltown:"
+)
+
+// builtinLabels makes a kubernetes-compatible label dictionary (key->value) that is used to mark objects that are
+// built-in into Smalltown (i.e. managed by the reconciler). These are then subsequently retrieved by listBuiltins.
+// The extra argument specifies what other labels are to be merged into the labels dictionary, for convenience. If
+// nil or empty, no extra labels will be applied. Entries in extra are written after the built-in marker, so an extra
+// entry keyed BuiltinLabelKey would replace it; callers should not pass one.
+// A new map is returned on every call; extra itself is never modified.
+func builtinLabels(extra map[string]string) map[string]string {
+	labels := make(map[string]string, len(extra)+1)
+	labels[BuiltinLabelKey] = BuiltinLabelValue
+	// Ranging over a nil map is a no-op, so no explicit nil check is needed.
+	for k, v := range extra {
+		labels[k] = v
+	}
+	return labels
+}
+
+// listBuiltins is a k8s client ListOptions structure that selects all objects currently present in the API server
+// that are built-in into Smalltown (i.e. the ones that are to be managed by the reconciler). Objects become
+// selectable by it when builtinLabels is applied to their metadata labels.
+var listBuiltins = meta.ListOptions{
+ LabelSelector: fmt.Sprintf("%s=%s", BuiltinLabelKey, BuiltinLabelValue),
+}
+
+// builtinRBACName prepends BuiltinRBACPrefix to name, yielding a colon-delimited 'namespaced' object name in the
+// style of system:*. All builtins created as part of the rbac/v1 Kubernetes API are to be named this way.
+func builtinRBACName(name string) string {
+	prefixed := BuiltinRBACPrefix + name
+	return prefixed
+}
+
+// resource is a type of resource to be managed by the reconciler. All built-ins/reconciled objects must implement
+// this interface to be managed correctly by the reconciler.
+type resource interface {
+ // List returns a list of names of objects currently present on the target (i.e. the k8s API server).
+ List(ctx context.Context) ([]string, error)
+ // Create creates an object on the target. The el interface{} argument is the black box object returned by the
+ // Expected() call.
+ Create(ctx context.Context, el interface{}) error
+ // Delete deletes an object, by name, from the target.
+ Delete(ctx context.Context, name string) error
+ // Expected returns a map of all objects expected to be present on the target. The keys are names (which must
+ // correspond to the names returned by List() and used by Delete()), and the values are black boxes that will then
+ // be passed to the Create() call if their corresponding key (name) does not exist on the target.
+ Expected() map[string]interface{}
+}
+
+// allResources returns every resource type handled by the reconciler, keyed by a short name that is used when
+// reporting on that type. Each resource is backed by the given clientSet.
+func allResources(clientSet kubernetes.Interface) map[string]resource {
+	all := make(map[string]resource, 6)
+	all["psps"] = resourcePodSecurityPolicies{clientSet}
+	all["clusterroles"] = resourceClusterRoles{clientSet}
+	all["clusterrolebindings"] = resourceClusterRoleBindings{clientSet}
+	all["storageclasses"] = resourceStorageClasses{clientSet}
+	all["csidrivers"] = resourceCSIDrivers{clientSet}
+	all["runtimeclasses"] = resourceRuntimeClasses{clientSet}
+	return all
+}
+
+// Run returns a supervisor.Runnable that reconciles all built-in resources through clientSet: once right after
+// signaling healthy, and then again every 10 seconds until ctx is done, at which point it returns nil.
+// A failure to reconcile one resource type is logged as a warning and does not stop the other types from being
+// reconciled, nor does it stop the runnable; the next tick tries again.
+func Run(clientSet kubernetes.Interface) supervisor.Runnable {
+	return func(ctx context.Context) error {
+		log := supervisor.Logger(ctx)
+		resources := allResources(clientSet)
+		t := time.NewTicker(10 * time.Second)
+		// Stop the ticker on exit so it does not keep running after this runnable returns.
+		defer t.Stop()
+		reconcileAll := func() {
+			for name, resource := range resources {
+				if err := reconcile(ctx, resource); err != nil {
+					log.Warningf("Failed to reconcile built-in resources %s: %v", name, err)
+				}
+			}
+		}
+		// Healthy is signaled before the first reconciliation pass is attempted.
+		supervisor.Signal(ctx, supervisor.SignalHealthy)
+		reconcileAll()
+		for {
+			select {
+			case <-t.C:
+				reconcileAll()
+			case <-ctx.Done():
+				return nil
+			}
+		}
+	}
+}
+
+// reconcile brings the target of a single resource type in line with its expected state: every expected object whose
+// name is not listed on the target is created, and every listed object whose name is not expected is deleted.
+// Objects present under an expected name are left untouched. The first error returned by List, Create or Delete
+// aborts the pass and is returned as-is.
+func reconcile(ctx context.Context, r resource) error {
+	listed, err := r.List(ctx)
+	if err != nil {
+		return err
+	}
+	onTarget := make(map[string]struct{}, len(listed))
+	for _, name := range listed {
+		onTarget[name] = struct{}{}
+	}
+	wanted := r.Expected()
+
+	// Create what is wanted but absent from the target.
+	for name, obj := range wanted {
+		if _, ok := onTarget[name]; ok {
+			continue
+		}
+		if err := r.Create(ctx, obj); err != nil {
+			return err
+		}
+	}
+	// Delete what is on the target but not wanted.
+	for name := range onTarget {
+		if _, ok := wanted[name]; ok {
+			continue
+		}
+		if err := r.Delete(ctx, name); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/metropolis/node/kubernetes/reconciler/reconciler_test.go b/metropolis/node/kubernetes/reconciler/reconciler_test.go
new file mode 100644
index 0000000..b58d4af
--- /dev/null
+++ b/metropolis/node/kubernetes/reconciler/reconciler_test.go
@@ -0,0 +1,184 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reconciler
+
+import (
+ "context"
+ "fmt"
+ "testing"
+
+ node "k8s.io/api/node/v1beta1"
+ policy "k8s.io/api/policy/v1beta1"
+ rbac "k8s.io/api/rbac/v1"
+ storage "k8s.io/api/storage/v1"
+ meta "k8s.io/apimachinery/pkg/apis/meta/v1"
+)
+
+// kubernetesMeta extracts the ObjectMeta of v if v is a pointer to one of the Kubernetes resource types managed by
+// the reconciler, and returns nil for anything else. Any time a new Kubernetes type is managed by the reconciler, the
+// switch below should be extended to cover that type.
+func kubernetesMeta(v interface{}) *meta.ObjectMeta {
+	var om *meta.ObjectMeta
+	switch obj := v.(type) {
+	case *rbac.ClusterRole:
+		om = &obj.ObjectMeta
+	case *rbac.ClusterRoleBinding:
+		om = &obj.ObjectMeta
+	case *storage.CSIDriver:
+		om = &obj.ObjectMeta
+	case *storage.StorageClass:
+		om = &obj.ObjectMeta
+	case *policy.PodSecurityPolicy:
+		om = &obj.ObjectMeta
+	case *node.RuntimeClass:
+		om = &obj.ObjectMeta
+	}
+	return om
+}
+
+// TestExpectedNamedCorrectly checks, for every resource type, that each object returned by Expected is keyed by the
+// same name it carries in its Kubernetes metadata. This contract must be met in order for the reconciler to not
+// create runaway resources. This assumes all managed resources are Kubernetes resources.
+func TestExpectedNamedCorrectly(t *testing.T) {
+	for kind, res := range allResources(nil) {
+		for key, obj := range res.Expected() {
+			om := kubernetesMeta(obj)
+			if om == nil {
+				t.Errorf("reconciler %q, object %q: could not decode kubernetes metadata", kind, key)
+				continue
+			}
+			if om.Name != key {
+				t.Errorf("reconciler %q, object %q: inner name mismatch (%q)", kind, key, om.Name)
+			}
+		}
+	}
+}
+
+// TestExpectedLabeledCorrectly checks, for every resource type, that each object returned by Expected carries the
+// Kubernetes metadata label marking it as a builtin object. This contract must be met in order for the reconciler to
+// not keep overwriting objects (and possibly failing) when a newly created object is not then retrievable using a
+// selector corresponding to this label. This assumes all managed resources are Kubernetes objects.
+func TestExpectedLabeledCorrectly(t *testing.T) {
+	for kind, res := range allResources(nil) {
+		for key, obj := range res.Expected() {
+			om := kubernetesMeta(obj)
+			if om == nil {
+				t.Errorf("reconciler %q, object %q: could not decode kubernetes metadata", kind, key)
+				continue
+			}
+			if got := om.Labels[BuiltinLabelKey]; got != BuiltinLabelValue {
+				t.Errorf("reconciler %q, object %q: %q=%q, wanted =%q", kind, key, BuiltinLabelKey, got, BuiltinLabelValue)
+			}
+		}
+	}
+}
+
+// testResource is a resource type used for testing. The inner type is a string that is equal to its name (key).
+// It simulates a target (i.e. a k8s apiserver mock) that always acts nominally (all resources are created and deleted
+// as requested, and the state is consistent with requests).
+type testResource struct {
+ // current is the simulated state of resources in the target, keyed by name.
+ current map[string]string
+ // expected is what this type will report as the Expected() resources, keyed by name.
+ expected map[string]string
+}
+
+// List returns the names of all resources in the simulated target, in no particular order. It never fails.
+func (r *testResource) List(ctx context.Context) ([]string, error) {
+	var names []string
+	for name := range r.current {
+		names = append(names, name)
+	}
+	return names, nil
+}
+
+// Create stores el, which must be a string, in the simulated target under itself as the name. It panics if el is not
+// a string and otherwise never fails.
+func (r *testResource) Create(ctx context.Context, el interface{}) error {
+	name := el.(string)
+	r.current[name] = name
+	return nil
+}
+
+// Delete removes the resource called name from the simulated target. Deleting a name that is not present is a no-op;
+// it never fails.
+func (r *testResource) Delete(ctx context.Context, name string) error {
+ delete(r.current, name)
+ return nil
+}
+
+// Expected returns a fresh copy of the expected resources, with every string value boxed into an interface{} as
+// required by the resource interface.
+func (r *testResource) Expected() map[string]interface{} {
+	boxed := make(map[string]interface{}, len(r.expected))
+	for name, val := range r.expected {
+		boxed[name] = val
+	}
+	return boxed
+}
+
+// newTestResource builds a testResource whose simulated target starts out empty and whose expected resources are
+// the given strings, each keyed by itself.
+func newTestResource(want ...string) *testResource {
+	res := &testResource{
+		current:  make(map[string]string),
+		expected: make(map[string]string, len(want)),
+	}
+	for _, name := range want {
+		res.expected[name] = name
+	}
+	return res
+}
+
+// currentDiff returns a human-readable string describing the first difference found between the current state and
+// the given resource strings: either a wanted resource that is missing, or a current resource that was not wanted.
+// If no difference is present, the returned string is empty.
+func (r *testResource) currentDiff(want ...string) string {
+	wanted := make(map[string]struct{}, len(want))
+	for _, name := range want {
+		if _, present := r.current[name]; !present {
+			return fmt.Sprintf("%q missing in current", name)
+		}
+		wanted[name] = struct{}{}
+	}
+	// Values equal their keys in the simulated target, so ranging over values yields the resource names.
+	for _, name := range r.current {
+		if _, ok := wanted[name]; !ok {
+			return fmt.Sprintf("%q spurious in current", name)
+		}
+	}
+	return ""
+}
+
+// TestBasicReconciliation ensures that the reconcile function does manipulate a target state based on a set of
+// expected resources: expected resources get created, and resources that stop being expected get deleted.
+func TestBasicReconciliation(t *testing.T) {
+ ctx := context.Background()
+ r := newTestResource("foo", "bar", "baz")
+
+ // nothing should have happened yet (testing the test)
+ if diff := r.currentDiff(); diff != "" {
+ t.Fatalf("wrong state after creation: %s", diff)
+ }
+
+ if err := reconcile(ctx, r); err != nil {
+ t.Fatalf("reconcile: %v", err)
+ }
+ // everything requested should have been created
+ if diff := r.currentDiff("foo", "bar", "baz"); diff != "" {
+ t.Fatalf("wrong state after reconciliation: %s", diff)
+ }
+
+ // stop expecting foo; the next reconciliation must remove it from the target
+ delete(r.expected, "foo")
+ if err := reconcile(ctx, r); err != nil {
+ t.Fatalf("reconcile: %v", err)
+ }
+ // foo should now be missing, with bar and baz left in place
+ if diff := r.currentDiff("bar", "baz"); diff != "" {
+ t.Fatalf("wrong state after deleting foo: %s", diff)
+ }
+}
diff --git a/metropolis/node/kubernetes/reconciler/resources_csi.go b/metropolis/node/kubernetes/reconciler/resources_csi.go
new file mode 100644
index 0000000..ecbcb4b
--- /dev/null
+++ b/metropolis/node/kubernetes/reconciler/resources_csi.go
@@ -0,0 +1,71 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reconciler
+
+import (
+ "context"
+
+ storage "k8s.io/api/storage/v1"
+ meta "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/client-go/kubernetes"
+)
+
+// csiProvisionerName is the name under which the built-in CSIDriver object is created; it is also set as the
+// Provisioner of the built-in "local" StorageClass.
+// TODO(q3k): this is duplicated with //metropolis/node/kubernetes:provisioner.go; integrate this once provisioner.go
+// gets moved into a subpackage.
+// ONCHANGE(//metropolis/node/kubernetes:provisioner.go): needs to match csiProvisionerName declared.
+const csiProvisionerName = "com.nexantic.smalltown.vfs"
+
+// resourceCSIDrivers is the reconciler resource for storage/v1 CSIDriver objects. It embeds the Kubernetes client
+// used to reach the API server.
+type resourceCSIDrivers struct {
+ kubernetes.Interface
+}
+
+// List returns the names of all CSIDriver objects on the API server that match the listBuiltins selector.
+func (r resourceCSIDrivers) List(ctx context.Context) ([]string, error) {
+	drivers, err := r.StorageV1().CSIDrivers().List(ctx, listBuiltins)
+	if err != nil {
+		return nil, err
+	}
+	names := make([]string, 0, len(drivers.Items))
+	for i := range drivers.Items {
+		names = append(names, drivers.Items[i].Name)
+	}
+	return names, nil
+}
+
+// Create submits el, which must be a *storage.CSIDriver as returned by Expected, to the API server. It panics if el
+// has any other type.
+func (r resourceCSIDrivers) Create(ctx context.Context, el interface{}) error {
+	client := r.StorageV1().CSIDrivers()
+	_, err := client.Create(ctx, el.(*storage.CSIDriver), meta.CreateOptions{})
+	return err
+}
+
+// Delete removes the CSIDriver called name from the API server, using default delete options.
+func (r resourceCSIDrivers) Delete(ctx context.Context, name string) error {
+	client := r.StorageV1().CSIDrivers()
+	return client.Delete(ctx, name, meta.DeleteOptions{})
+}
+
+// Expected returns the single built-in CSIDriver, named csiProvisionerName. It requires neither attachment nor pod
+// info on mount, and supports only the persistent volume lifecycle mode.
+func (r resourceCSIDrivers) Expected() map[string]interface{} {
+	driver := &storage.CSIDriver{
+		ObjectMeta: meta.ObjectMeta{
+			Name:   csiProvisionerName,
+			Labels: builtinLabels(nil),
+		},
+		Spec: storage.CSIDriverSpec{
+			AttachRequired:       False(),
+			PodInfoOnMount:       False(),
+			VolumeLifecycleModes: []storage.VolumeLifecycleMode{storage.VolumeLifecyclePersistent},
+		},
+	}
+	return map[string]interface{}{
+		csiProvisionerName: driver,
+	}
+}
diff --git a/metropolis/node/kubernetes/reconciler/resources_podsecuritypolicy.go b/metropolis/node/kubernetes/reconciler/resources_podsecuritypolicy.go
new file mode 100644
index 0000000..507089f
--- /dev/null
+++ b/metropolis/node/kubernetes/reconciler/resources_podsecuritypolicy.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reconciler
+
+import (
+ "context"
+
+ core "k8s.io/api/core/v1"
+ policy "k8s.io/api/policy/v1beta1"
+ meta "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/client-go/kubernetes"
+)
+
+// resourcePodSecurityPolicies is the reconciler resource for policy/v1beta1 PodSecurityPolicy objects. It embeds the
+// Kubernetes client used to reach the API server.
+type resourcePodSecurityPolicies struct {
+ kubernetes.Interface
+}
+
+// List returns the names of all PodSecurityPolicy objects on the API server that match the listBuiltins selector.
+func (r resourcePodSecurityPolicies) List(ctx context.Context) ([]string, error) {
+	policies, err := r.PolicyV1beta1().PodSecurityPolicies().List(ctx, listBuiltins)
+	if err != nil {
+		return nil, err
+	}
+	names := make([]string, 0, len(policies.Items))
+	for i := range policies.Items {
+		names = append(names, policies.Items[i].Name)
+	}
+	return names, nil
+}
+
+// Create submits el, which must be a *policy.PodSecurityPolicy as returned by Expected, to the API server. It panics
+// if el has any other type.
+func (r resourcePodSecurityPolicies) Create(ctx context.Context, el interface{}) error {
+	client := r.PolicyV1beta1().PodSecurityPolicies()
+	_, err := client.Create(ctx, el.(*policy.PodSecurityPolicy), meta.CreateOptions{})
+	return err
+}
+
+// Delete removes the PodSecurityPolicy called name from the API server, using default delete options.
+func (r resourcePodSecurityPolicies) Delete(ctx context.Context, name string) error {
+	client := r.PolicyV1beta1().PodSecurityPolicies()
+	return client.Delete(ctx, name, meta.DeleteOptions{})
+}
+
+// Expected returns the single built-in PodSecurityPolicy, named "default". It allows privilege escalation and runc's
+// default capabilities, forbids host network/IPC/PID, applies RunAsAny to users, groups and SELinux, and limits
+// volumes to a fixed list of types.
+func (r resourcePodSecurityPolicies) Expected() map[string]interface{} {
+	// runc's default list of allowed capabilities
+	capabilities := []core.Capability{
+		"SETPCAP",
+		"MKNOD",
+		"AUDIT_WRITE",
+		"CHOWN",
+		"NET_RAW",
+		"DAC_OVERRIDE",
+		"FOWNER",
+		"FSETID",
+		"KILL",
+		"SETGID",
+		"SETUID",
+		"NET_BIND_SERVICE",
+		"SYS_CHROOT",
+		"SETFCAP",
+	}
+	// Volumes considered safe to use
+	volumes := []policy.FSType{
+		policy.ConfigMap,
+		policy.EmptyDir,
+		policy.Projected,
+		policy.Secret,
+		policy.DownwardAPI,
+		policy.PersistentVolumeClaim,
+	}
+	psp := &policy.PodSecurityPolicy{
+		ObjectMeta: meta.ObjectMeta{
+			Name:   "default",
+			Labels: builtinLabels(nil),
+			Annotations: map[string]string{
+				"kubernetes.io/description": "This default PSP allows the creation of pods using features that are" +
+					" generally considered safe against any sort of escape.",
+			},
+		},
+		Spec: policy.PodSecurityPolicySpec{
+			AllowPrivilegeEscalation: True(),
+			AllowedCapabilities:      capabilities,
+			HostNetwork:              false,
+			HostIPC:                  false,
+			HostPID:                  false,
+			FSGroup: policy.FSGroupStrategyOptions{
+				Rule: policy.FSGroupStrategyRunAsAny,
+			},
+			RunAsUser: policy.RunAsUserStrategyOptions{
+				Rule: policy.RunAsUserStrategyRunAsAny,
+			},
+			SELinux: policy.SELinuxStrategyOptions{
+				Rule: policy.SELinuxStrategyRunAsAny,
+			},
+			SupplementalGroups: policy.SupplementalGroupsStrategyOptions{
+				Rule: policy.SupplementalGroupsStrategyRunAsAny,
+			},
+			Volumes: volumes,
+		},
+	}
+	return map[string]interface{}{
+		"default": psp,
+	}
+}
diff --git a/metropolis/node/kubernetes/reconciler/resources_rbac.go b/metropolis/node/kubernetes/reconciler/resources_rbac.go
new file mode 100644
index 0000000..40ca879
--- /dev/null
+++ b/metropolis/node/kubernetes/reconciler/resources_rbac.go
@@ -0,0 +1,154 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reconciler
+
+import (
+ "context"
+
+ rbac "k8s.io/api/rbac/v1"
+ meta "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/client-go/kubernetes"
+)
+
+// Names of the built-in rbac/v1 objects created by the reconciler, each prefixed through builtinRBACName.
+var (
+ // clusterRolePSPDefault names the ClusterRole that grants use of the "default" PSP.
+ clusterRolePSPDefault = builtinRBACName("psp-default")
+ // clusterRoleBindingDefaultPSP names the binding of clusterRolePSPDefault to the system:serviceaccounts group.
+ clusterRoleBindingDefaultPSP = builtinRBACName("default-psp-for-sa")
+ // clusterRoleBindingAPIServerKubeletClient names the binding of system:kubelet-api-admin to the apiserver user.
+ clusterRoleBindingAPIServerKubeletClient = builtinRBACName("apiserver-kubelet-client")
+)
+
+// resourceClusterRoles is the reconciler resource for rbac/v1 ClusterRole objects. It embeds the Kubernetes client
+// used to reach the API server.
+type resourceClusterRoles struct {
+ kubernetes.Interface
+}
+
+// List returns the names of all ClusterRole objects on the API server that match the listBuiltins selector.
+func (r resourceClusterRoles) List(ctx context.Context) ([]string, error) {
+	roles, err := r.RbacV1().ClusterRoles().List(ctx, listBuiltins)
+	if err != nil {
+		return nil, err
+	}
+	names := make([]string, 0, len(roles.Items))
+	for i := range roles.Items {
+		names = append(names, roles.Items[i].Name)
+	}
+	return names, nil
+}
+
+// Create submits el, which must be a *rbac.ClusterRole as returned by Expected, to the API server. It panics if el
+// has any other type.
+func (r resourceClusterRoles) Create(ctx context.Context, el interface{}) error {
+	client := r.RbacV1().ClusterRoles()
+	_, err := client.Create(ctx, el.(*rbac.ClusterRole), meta.CreateOptions{})
+	return err
+}
+
+// Delete removes the ClusterRole called name from the API server, using default delete options.
+func (r resourceClusterRoles) Delete(ctx context.Context, name string) error {
+	client := r.RbacV1().ClusterRoles()
+	return client.Delete(ctx, name, meta.DeleteOptions{})
+}
+
+// Expected returns the single built-in ClusterRole, named clusterRolePSPDefault, which allows the "use" verb on the
+// PodSecurityPolicy called "default".
+func (r resourceClusterRoles) Expected() map[string]interface{} {
+	useDefaultPSP := rbac.PolicyRule{
+		APIGroups:     []string{"policy"},
+		Resources:     []string{"podsecuritypolicies"},
+		ResourceNames: []string{"default"},
+		Verbs:         []string{"use"},
+	}
+	role := &rbac.ClusterRole{
+		ObjectMeta: meta.ObjectMeta{
+			Name:   clusterRolePSPDefault,
+			Labels: builtinLabels(nil),
+			Annotations: map[string]string{
+				"kubernetes.io/description": "This role grants access to the \"default\" PSP.",
+			},
+		},
+		Rules: []rbac.PolicyRule{useDefaultPSP},
+	}
+	return map[string]interface{}{
+		clusterRolePSPDefault: role,
+	}
+}
+
+// resourceClusterRoleBindings is the reconciler resource for rbac/v1 ClusterRoleBinding objects. It embeds the
+// Kubernetes client used to reach the API server.
+type resourceClusterRoleBindings struct {
+ kubernetes.Interface
+}
+
+// List returns the names of all ClusterRoleBinding objects on the API server that match the listBuiltins selector.
+func (r resourceClusterRoleBindings) List(ctx context.Context) ([]string, error) {
+	bindings, err := r.RbacV1().ClusterRoleBindings().List(ctx, listBuiltins)
+	if err != nil {
+		return nil, err
+	}
+	names := make([]string, 0, len(bindings.Items))
+	for i := range bindings.Items {
+		names = append(names, bindings.Items[i].Name)
+	}
+	return names, nil
+}
+
+// Create submits el, which must be a *rbac.ClusterRoleBinding as returned by Expected, to the API server. It panics
+// if el has any other type.
+func (r resourceClusterRoleBindings) Create(ctx context.Context, el interface{}) error {
+	client := r.RbacV1().ClusterRoleBindings()
+	_, err := client.Create(ctx, el.(*rbac.ClusterRoleBinding), meta.CreateOptions{})
+	return err
+}
+
+// Delete removes the ClusterRoleBinding called name from the API server, using default delete options.
+func (r resourceClusterRoleBindings) Delete(ctx context.Context, name string) error {
+	client := r.RbacV1().ClusterRoleBindings()
+	return client.Delete(ctx, name, meta.DeleteOptions{})
+}
+
+// Expected returns the two built-in ClusterRoleBindings: one binding the default-PSP ClusterRole to the
+// system:serviceaccounts group, and one binding system:kubelet-api-admin to the apiserver's kubelet client user.
+func (r resourceClusterRoleBindings) Expected() map[string]interface{} {
+	defaultPSP := &rbac.ClusterRoleBinding{
+		ObjectMeta: meta.ObjectMeta{
+			Name:   clusterRoleBindingDefaultPSP,
+			Labels: builtinLabels(nil),
+			Annotations: map[string]string{
+				"kubernetes.io/description": "This binding grants every service account access to the \"default\" PSP. " +
+					"Creation of Pods is still restricted by other RBAC roles. Otherwise no pods (unprivileged or not) " +
+					"can be created.",
+			},
+		},
+		RoleRef: rbac.RoleRef{
+			APIGroup: rbac.GroupName,
+			Kind:     "ClusterRole",
+			Name:     clusterRolePSPDefault,
+		},
+		Subjects: []rbac.Subject{
+			{
+				APIGroup: rbac.GroupName,
+				Kind:     "Group",
+				Name:     "system:serviceaccounts",
+			},
+		},
+	}
+	kubeletClient := &rbac.ClusterRoleBinding{
+		ObjectMeta: meta.ObjectMeta{
+			Name:   clusterRoleBindingAPIServerKubeletClient,
+			Labels: builtinLabels(nil),
+			Annotations: map[string]string{
+				"kubernetes.io/description": "This binding grants the apiserver access to the kubelets. This enables " +
+					"lots of built-in functionality like reading logs or forwarding ports via the API.",
+			},
+		},
+		RoleRef: rbac.RoleRef{
+			APIGroup: rbac.GroupName,
+			Kind:     "ClusterRole",
+			Name:     "system:kubelet-api-admin",
+		},
+		Subjects: []rbac.Subject{
+			{
+				APIGroup: rbac.GroupName,
+				Kind:     "User",
+				// TODO(q3k): describe this name's contract, or unify with whatever creates this.
+				Name: "smalltown:apiserver-kubelet-client",
+			},
+		},
+	}
+	return map[string]interface{}{
+		clusterRoleBindingDefaultPSP:             defaultPSP,
+		clusterRoleBindingAPIServerKubeletClient: kubeletClient,
+	}
+}
diff --git a/metropolis/node/kubernetes/reconciler/resources_runtimeclass.go b/metropolis/node/kubernetes/reconciler/resources_runtimeclass.go
new file mode 100644
index 0000000..c202c0e
--- /dev/null
+++ b/metropolis/node/kubernetes/reconciler/resources_runtimeclass.go
@@ -0,0 +1,69 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reconciler
+
+import (
+ "context"
+
+ node "k8s.io/api/node/v1beta1"
+ meta "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/client-go/kubernetes"
+)
+
+// resourceRuntimeClasses is the reconciler resource for node/v1beta1 RuntimeClass objects. It embeds the Kubernetes
+// client used to reach the API server.
+type resourceRuntimeClasses struct {
+ kubernetes.Interface
+}
+
+// List returns the names of all RuntimeClass objects on the API server that match the listBuiltins selector.
+func (r resourceRuntimeClasses) List(ctx context.Context) ([]string, error) {
+	classes, err := r.NodeV1beta1().RuntimeClasses().List(ctx, listBuiltins)
+	if err != nil {
+		return nil, err
+	}
+	names := make([]string, 0, len(classes.Items))
+	for i := range classes.Items {
+		names = append(names, classes.Items[i].Name)
+	}
+	return names, nil
+}
+
+// Create submits el, which must be a *node.RuntimeClass as returned by Expected, to the API server. It panics if el
+// has any other type.
+func (r resourceRuntimeClasses) Create(ctx context.Context, el interface{}) error {
+	client := r.NodeV1beta1().RuntimeClasses()
+	_, err := client.Create(ctx, el.(*node.RuntimeClass), meta.CreateOptions{})
+	return err
+}
+
+// Delete removes the RuntimeClass called name from the API server, using default delete options.
+func (r resourceRuntimeClasses) Delete(ctx context.Context, name string) error {
+	client := r.NodeV1beta1().RuntimeClasses()
+	return client.Delete(ctx, name, meta.DeleteOptions{})
+}
+
+// Expected returns the two built-in RuntimeClasses: "gvisor", handled by "runsc", and "runc", handled by "runc".
+func (r resourceRuntimeClasses) Expected() map[string]interface{} {
+	// class builds a built-in RuntimeClass with the given object name and handler.
+	class := func(name, handler string) *node.RuntimeClass {
+		return &node.RuntimeClass{
+			ObjectMeta: meta.ObjectMeta{
+				Name:   name,
+				Labels: builtinLabels(nil),
+			},
+			Handler: handler,
+		}
+	}
+	return map[string]interface{}{
+		"gvisor": class("gvisor", "runsc"),
+		"runc":   class("runc", "runc"),
+	}
+}
diff --git a/metropolis/node/kubernetes/reconciler/resources_storageclass.go b/metropolis/node/kubernetes/reconciler/resources_storageclass.go
new file mode 100644
index 0000000..72476ec
--- /dev/null
+++ b/metropolis/node/kubernetes/reconciler/resources_storageclass.go
@@ -0,0 +1,72 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reconciler
+
+import (
+ "context"
+
+ core "k8s.io/api/core/v1"
+ storage "k8s.io/api/storage/v1"
+ meta "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/client-go/kubernetes"
+)
+
+// reclaimPolicyDelete and waitForConsumerBinding are held in variables so that their addresses can be taken for the
+// pointer-typed ReclaimPolicy and VolumeBindingMode fields of the built-in StorageClass.
+var reclaimPolicyDelete = core.PersistentVolumeReclaimDelete
+var waitForConsumerBinding = storage.VolumeBindingWaitForFirstConsumer
+
+// resourceStorageClasses is the reconciler resource for storage/v1 StorageClass objects. It embeds the Kubernetes
+// client used to reach the API server.
+type resourceStorageClasses struct {
+ kubernetes.Interface
+}
+
+// List returns the names of all StorageClass objects on the API server that match the listBuiltins selector.
+func (r resourceStorageClasses) List(ctx context.Context) ([]string, error) {
+	classes, err := r.StorageV1().StorageClasses().List(ctx, listBuiltins)
+	if err != nil {
+		return nil, err
+	}
+	names := make([]string, 0, len(classes.Items))
+	for i := range classes.Items {
+		names = append(names, classes.Items[i].Name)
+	}
+	return names, nil
+}
+
+// Create submits el, which must be a *storage.StorageClass as returned by Expected, to the API server. It panics if
+// el has any other type.
+func (r resourceStorageClasses) Create(ctx context.Context, el interface{}) error {
+	client := r.StorageV1().StorageClasses()
+	_, err := client.Create(ctx, el.(*storage.StorageClass), meta.CreateOptions{})
+	return err
+}
+
+// Delete removes the StorageClass called name from the API server, using default delete options.
+func (r resourceStorageClasses) Delete(ctx context.Context, name string) error {
+	client := r.StorageV1().StorageClasses()
+	return client.Delete(ctx, name, meta.DeleteOptions{})
+}
+
+// Expected returns the single built-in StorageClass, named "local". It is annotated as the default class, is
+// provisioned by csiProvisionerName, allows volume expansion, deletes reclaimed volumes and binds volumes only once
+// a first consumer exists.
+func (r resourceStorageClasses) Expected() map[string]interface{} {
+	local := &storage.StorageClass{
+		ObjectMeta: meta.ObjectMeta{
+			Name:   "local",
+			Labels: builtinLabels(nil),
+			Annotations: map[string]string{
+				"storageclass.kubernetes.io/is-default-class": "true",
+			},
+		},
+		AllowVolumeExpansion: True(),
+		Provisioner:          csiProvisionerName,
+		ReclaimPolicy:        &reclaimPolicyDelete,
+		VolumeBindingMode:    &waitForConsumerBinding,
+	}
+	return map[string]interface{}{
+		"local": local,
+	}
+}
diff --git a/metropolis/node/kubernetes/scheduler.go b/metropolis/node/kubernetes/scheduler.go
new file mode 100644
index 0000000..21e6663
--- /dev/null
+++ b/metropolis/node/kubernetes/scheduler.go
@@ -0,0 +1,70 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kubernetes
+
+import (
+ "context"
+ "encoding/pem"
+ "fmt"
+ "os/exec"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/fileargs"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/pki"
+)
+
+// schedulerConfig carries the PKI material that runScheduler hands to kube-scheduler.
+type schedulerConfig struct {
+ // kubeConfig is passed to kube-scheduler as the contents of its --kubeconfig file.
+ kubeConfig []byte
+ // serverCert is PEM-encoded as a CERTIFICATE block and passed as --tls-cert-file.
+ serverCert []byte
+ // serverKey is PEM-encoded as a PRIVATE KEY block and passed as --tls-private-key-file.
+ serverKey []byte
+}
+
+func getPKISchedulerConfig(ctx context.Context, kpki *pki.KubernetesPKI) (*schedulerConfig, error) {
+ var config schedulerConfig
+ var err error
+ config.serverCert, config.serverKey, err = kpki.Certificate(ctx, pki.Scheduler)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get scheduler serving certificate: %w", err)
+ }
+ config.kubeConfig, err = kpki.Kubeconfig(ctx, pki.SchedulerClient)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get scheduler kubeconfig: %w", err)
+ }
+ return &config, nil
+}
+
+// runScheduler returns a supervisor.Runnable that runs kube-scheduler (via /kubernetes/bin/kube) under ctx. The
+// kubeconfig, and the serving certificate and key (both PEM-encoded here), are handed to the process as file-backed
+// options through fileargs; the fileargs handle is closed when the runnable returns.
+// The runnable panics if fileargs cannot be set up at all, and returns an error wrapping args.Error() if any of the
+// file-backed options could not be prepared.
+func runScheduler(config schedulerConfig) supervisor.Runnable {
+	return func(ctx context.Context) error {
+		args, err := fileargs.New()
+		if err != nil {
+			panic(err) // If this fails, something is very wrong. Just crash.
+		}
+		defer args.Close()
+		cmd := exec.CommandContext(ctx, "/kubernetes/bin/kube", "kube-scheduler",
+			args.FileOpt("--kubeconfig", "kubeconfig", config.kubeConfig),
+			"--port=0", // Kill insecure serving
+			args.FileOpt("--tls-cert-file", "server-cert.pem",
+				pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: config.serverCert})),
+			args.FileOpt("--tls-private-key-file", "server-key.pem",
+				pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: config.serverKey})),
+		)
+		// Wrap the error reported by fileargs itself; the err from fileargs.New above is always nil at this point.
+		if err := args.Error(); err != nil {
+			return fmt.Errorf("failed to use fileargs: %w", err)
+		}
+		return supervisor.RunCommand(ctx, cmd)
+	}
+}
diff --git a/metropolis/node/kubernetes/service.go b/metropolis/node/kubernetes/service.go
new file mode 100644
index 0000000..2917bfc
--- /dev/null
+++ b/metropolis/node/kubernetes/service.go
@@ -0,0 +1,177 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kubernetes
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "os"
+ "time"
+
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+ "k8s.io/client-go/informers"
+ "k8s.io/client-go/kubernetes"
+ "k8s.io/client-go/tools/clientcmd"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/node/common/supervisor"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/localstorage"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/core/network/dns"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/clusternet"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/nfproxy"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/pki"
+ "git.monogon.dev/source/nexantic.git/metropolis/node/kubernetes/reconciler"
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+// Config carries the parameters and dependencies of the Kubernetes Service.
+type Config struct {
+ // AdvertiseAddress is passed to the API server as its advertise address
+ // and is also used as the cluster DNS IP handed to the kubelet.
+ AdvertiseAddress net.IP
+ // ServiceIPRange is passed to the API server.
+ ServiceIPRange net.IPNet
+ // ClusterNet is passed to the controller manager, the clusternet service
+ // and nfproxy (as its cluster CIDR).
+ ClusterNet net.IPNet
+
+ // KPKI is the Kubernetes PKI from which certificates and kubeconfigs are
+ // obtained.
+ KPKI *pki.KubernetesPKI
+ // Root is the local storage root from which the data and ephemeral
+ // directories used by the sub-services are taken.
+ Root *localstorage.Root
+ // CorednsRegistrationChan receives the Kubernetes DNS directive once the
+ // sub-services have been started, and its cancel directive when Run's
+ // context is done.
+ CorednsRegistrationChan chan *dns.ExtraDirective
+}
+
+// Service runs the Kubernetes sub-services of a node (see Run) and can issue
+// debug kubeconfigs (see GetDebugKubeconfig). Create it with New.
+type Service struct {
+ // c is the configuration this Service was created with.
+ c Config
+}
+
+// New returns a Service that will operate with the given configuration. It
+// does not start anything; call Run for that.
+func New(c Config) *Service {
+	return &Service{c: c}
+}
+
+// Run starts all Kubernetes sub-services under the supervisor found in ctx and
+// then blocks until ctx is done.
+//
+// It first obtains the controller manager and scheduler credentials and the
+// master kubeconfig from the PKI, builds a Kubernetes client from that
+// kubeconfig, and then starts each sub-service via supervisor.Run. Afterwards
+// it sends the cluster DNS directive on CorednsRegistrationChan and signals
+// health to the supervisor. When ctx is done it sends the matching cancel
+// directive on the same channel and returns nil.
+//
+// Any failure during setup or while starting a sub-service is returned as a
+// wrapped error.
+func (s *Service) Run(ctx context.Context) error {
+	cmConfig, err := getPKIControllerManagerConfig(ctx, s.c.KPKI)
+	if err != nil {
+		return fmt.Errorf("could not generate controller manager pki config: %w", err)
+	}
+	cmConfig.clusterNet = s.c.ClusterNet
+
+	schedConfig, err := getPKISchedulerConfig(ctx, s.c.KPKI)
+	if err != nil {
+		return fmt.Errorf("could not generate scheduler pki config: %w", err)
+	}
+
+	masterKubeconfig, err := s.c.KPKI.Kubeconfig(ctx, pki.Master)
+	if err != nil {
+		return fmt.Errorf("could not generate master kubeconfig: %w", err)
+	}
+
+	rawClientConfig, err := clientcmd.NewClientConfigFromBytes(masterKubeconfig)
+	if err != nil {
+		return fmt.Errorf("could not generate kubernetes client config: %w", err)
+	}
+
+	// This error must be checked on its own: otherwise it is overwritten by
+	// the NewForConfig call below and an unusable config is passed along.
+	clientConfig, err := rawClientConfig.ClientConfig()
+	if err != nil {
+		return fmt.Errorf("could not resolve kubernetes client config: %w", err)
+	}
+
+	clientSet, err := kubernetes.NewForConfig(clientConfig)
+	if err != nil {
+		return fmt.Errorf("could not generate kubernetes client: %w", err)
+	}
+
+	informerFactory := informers.NewSharedInformerFactory(clientSet, 5*time.Minute)
+
+	hostname, err := os.Hostname()
+	if err != nil {
+		return fmt.Errorf("failed to get hostname: %w", err)
+	}
+
+	dnsHostIP := s.c.AdvertiseAddress // TODO: Which IP to use
+
+	apiserver := &apiserverService{
+		KPKI:                        s.c.KPKI,
+		AdvertiseAddress:            s.c.AdvertiseAddress,
+		ServiceIPRange:              s.c.ServiceIPRange,
+		EphemeralConsensusDirectory: &s.c.Root.Ephemeral.Consensus,
+	}
+
+	kubelet := kubeletService{
+		NodeName:           hostname,
+		ClusterDNS:         []net.IP{dnsHostIP},
+		KubeletDirectory:   &s.c.Root.Data.Kubernetes.Kubelet,
+		EphemeralDirectory: &s.c.Root.Ephemeral,
+		KPKI:               s.c.KPKI,
+	}
+
+	csiPlugin := csiPluginServer{
+		KubeletDirectory: &s.c.Root.Data.Kubernetes.Kubelet,
+		VolumesDirectory: &s.c.Root.Data.Volumes,
+	}
+
+	csiProvisioner := csiProvisionerServer{
+		NodeName:         hostname,
+		Kubernetes:       clientSet,
+		InformerFactory:  informerFactory,
+		VolumesDirectory: &s.c.Root.Data.Volumes,
+	}
+
+	// Named with a Svc suffix so the locals do not shadow the imported
+	// clusternet and nfproxy packages.
+	clusternetSvc := clusternet.Service{
+		NodeName:        hostname,
+		Kubernetes:      clientSet,
+		ClusterNet:      s.c.ClusterNet,
+		InformerFactory: informerFactory,
+		DataDirectory:   &s.c.Root.Data.Kubernetes.ClusterNetworking,
+	}
+
+	nfproxySvc := nfproxy.Service{
+		ClusterCIDR: s.c.ClusterNet,
+		ClientSet:   clientSet,
+	}
+
+	// Sub-services are handed to the supervisor in this order; the first one
+	// that cannot be started aborts Run.
+	for _, sub := range []struct {
+		name     string
+		runnable supervisor.Runnable
+	}{
+		{"apiserver", apiserver.Run},
+		{"controller-manager", runControllerManager(*cmConfig)},
+		{"scheduler", runScheduler(*schedConfig)},
+		{"kubelet", kubelet.Run},
+		{"reconciler", reconciler.Run(clientSet)},
+		{"csi-plugin", csiPlugin.Run},
+		{"csi-provisioner", csiProvisioner.Run},
+		{"clusternet", clusternetSvc.Run},
+		{"nfproxy", nfproxySvc.Run},
+	} {
+		if err := supervisor.Run(ctx, sub.name, sub.runnable); err != nil {
+			return fmt.Errorf("could not run sub-service %q: %w", sub.name, err)
+		}
+	}
+
+	supervisor.Logger(ctx).Info("Registering K8s CoreDNS")
+	clusterDNSDirective := dns.NewKubernetesDirective("cluster.local", masterKubeconfig)
+	s.c.CorednsRegistrationChan <- clusterDNSDirective
+
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	<-ctx.Done()
+	// Withdraw the DNS directive registered above before exiting.
+	s.c.CorednsRegistrationChan <- dns.CancelDirective(clusterDNSDirective)
+	return nil
+}
+
+// GetDebugKubeconfig issues a kubeconfig for the identity and groups named in
+// the request, using a client certificate created under the cluster's ID CA.
+// Useful for debugging and testing. If the kubeconfig cannot be generated, a
+// gRPC Unavailable status error is returned.
+func (s *Service) GetDebugKubeconfig(ctx context.Context, request *apb.GetDebugKubeconfigRequest) (*apb.GetDebugKubeconfigResponse, error) {
+	issuer := s.c.KPKI.Certificates[pki.IdCA]
+	identity := pki.Client(request.Id, request.Groups)
+	clientCert := pki.New(issuer, "", identity)
+
+	kubeconfig, err := clientCert.Kubeconfig(ctx, s.c.KPKI.KV)
+	if err != nil {
+		return nil, status.Errorf(codes.Unavailable, "Failed to generate kubeconfig: %v", err)
+	}
+
+	response := &apb.GetDebugKubeconfigResponse{
+		DebugKubeconfig: string(kubeconfig),
+	}
+	return response, nil
+}
diff --git a/metropolis/node/ports.go b/metropolis/node/ports.go
new file mode 100644
index 0000000..c63ec38
--- /dev/null
+++ b/metropolis/node/ports.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package node
+
+// Port numbers shared across the node code base.
+// NOTE(review): the purpose of each port is only implied by its name here;
+// confirm against the services that actually listen on them.
+const (
+ // NodeServicePort is presumably the port of the node service.
+ NodeServicePort = 7835
+ // ConsensusPort is presumably the port used by the consensus service.
+ ConsensusPort = 7834
+ // MasterServicePort is presumably the port of the master service.
+ MasterServicePort = 7833
+ // ExternalServicePort is presumably the port of the externally facing service.
+ ExternalServicePort = 7836
+ // DebugServicePort is presumably the port of the debug service.
+ DebugServicePort = 7837
+ // WireGuardPort is presumably the WireGuard listen port — TODO confirm.
+ WireGuardPort = 7838
+ // KubernetesAPIPort is presumably the port the Kubernetes API server
+ // listens on — TODO confirm.
+ KubernetesAPIPort = 6443
+ // DebuggerPort is presumably the port a debugger attached to init listens
+ // on in debug builds — TODO confirm.
+ DebuggerPort = 2345
+)