metropolis: unify utility packages
One last sweeping rename / reshuffle.
We get rid of //metropolis/node/common and //golibs, unifying them into
a single //metropolis/pkg meta-package.
This is to be documented somewhere properly, but here's the new logic
behind selecting where to place a new library package:
- if it's specific to k8s-on-metropolis, put it in
//metropolis/node/kubernetes/*. This is a self-contained tree that
other paths cannot import from.
- if it's a big new subsystem of the metropolis core, put it in
//metropolis/node/core. This can be imported by anything in
//m/n (e.g. the Kubernetes code at //m/n/kubernetes).
- otherwise, treat it as a generic library that's part of the metropolis
project, and put it in //metropolis/pkg. This can be imported by
anything within //metropolis.
This will be followed up by a diff that updates visibility rules.
Test Plan: Pure refactor, CI only.
X-Origin-Diff: phab/D683
GitOrigin-RevId: 883e7f09a7d22d64e966d07bbe839454ed081c79
diff --git a/metropolis/pkg/devicemapper/BUILD.bazel b/metropolis/pkg/devicemapper/BUILD.bazel
new file mode 100644
index 0000000..17c50cc
--- /dev/null
+++ b/metropolis/pkg/devicemapper/BUILD.bazel
@@ -0,0 +1,13 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
# Go wrapper around the Linux device-mapper ioctl API (/dev/mapper/control).
go_library(
    name = "go_default_library",
    srcs = ["devicemapper.go"],
    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/devicemapper",
    # NOTE(review): the commit message says visibility will be tightened in a follow-up.
    visibility = ["//visibility:public"],
    deps = [
        "@com_github_pkg_errors//:go_default_library",
        "@com_github_yalue_native_endian//:go_default_library",
        "@org_golang_x_sys//unix:go_default_library",
    ],
)
diff --git a/metropolis/pkg/devicemapper/devicemapper.go b/metropolis/pkg/devicemapper/devicemapper.go
new file mode 100644
index 0000000..2687e3a
--- /dev/null
+++ b/metropolis/pkg/devicemapper/devicemapper.go
@@ -0,0 +1,298 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
// Package devicemapper is a thin wrapper for the devicemapper ioctl API.
+// See: https://github.com/torvalds/linux/blob/master/include/uapi/linux/dm-ioctl.h
+package devicemapper
+
import (
	"bytes"
	"encoding/binary"
	"fmt"
	"os"
	"runtime"
	"sync"
	"unsafe"

	"github.com/pkg/errors"
	"github.com/yalue/native_endian"
	"golang.org/x/sys/unix"
)
+
// DMIoctl mirrors the kernel's struct dm_ioctl (uapi/linux/dm-ioctl.h) followed by
// an inline payload buffer. It is handed to the kernel by pointer, so field order
// and sizes must not change. Everything before Data adds up to 312 (0x138) bytes,
// the struct size encoded into the DM_*_CMD ioctl numbers.
type DMIoctl struct {
	// Version is the interface version; the kernel writes its own version back
	// into it (see GetVersion).
	Version Version
	// DataSize is the size in bytes of the whole request: header plus used payload.
	DataSize uint32
	// DataStart is the offset from the start of the struct at which the payload begins.
	DataStart uint32
	// TargetCount is the number of dm_target_spec entries in the payload.
	TargetCount uint32
	OpenCount   int32
	// Flags holds DM_*_FLAG bits.
	Flags       uint32
	EventNumber uint32
	_padding1   uint32
	// Dev is the device number of the mapped device (returned by CreateDevice).
	Dev uint64
	// Name is the NUL-terminated device name.
	Name [128]byte
	// UUID is the NUL-terminated device UUID.
	UUID      [129]byte
	_padding2 [7]byte
	// Data is the inline payload area. NOTE(review): its 16 KiB size is this
	// package's choice, not part of the kernel ABI.
	Data [16384]byte
}
+
// DMTargetSpec mirrors struct dm_target_spec (uapi/linux/dm-ioctl.h, 40 bytes).
// In a DM_TABLE_LOAD payload each spec is immediately followed by its
// NUL-terminated parameter string.
type DMTargetSpec struct {
	// SectorStart is the first sector of the mapped device covered by this target.
	SectorStart uint64
	// Length is the number of sectors covered by this target.
	Length uint64
	Status int32
	// Next is the offset in bytes from the start of this spec to the next spec.
	Next uint32
	// TargetType is the NUL-terminated target type name.
	TargetType [16]byte
}
+
// The following types loosely follow further payload structs from
// uapi/linux/dm-ioctl.h. NOTE(review): none of them is used in this file, and
// their slice fields are Go slice headers rather than inline flexible arrays, so
// they cannot be passed to the kernel or (de)serialized with encoding/binary as-is.

// DMTargetDeps loosely follows struct dm_target_deps (DM_TABLE_DEPS payload).
type DMTargetDeps struct {
	Count   uint32
	Padding uint32
	Dev     []uint64
}

// DMNameList loosely follows struct dm_name_list (DM_LIST_DEVICES payload).
type DMNameList struct {
	Dev  uint64
	Next uint32
	Name []byte
}

// DMTargetVersions loosely follows struct dm_target_versions (DM_LIST_VERSIONS payload).
type DMTargetVersions struct {
	Next    uint32
	Version [3]uint32
}

// DMTargetMessage loosely follows struct dm_target_msg (DM_TARGET_MSG payload).
type DMTargetMessage struct {
	Sector  uint64
	Message []byte
}
+
// Version is a device-mapper ioctl interface version as {major, minor, patch}.
type Version [3]uint32
+
// Ioctl request numbers for /dev/mapper/control. Each is _IOWR(0xfd, nr, struct dm_ioctl):
// 0xc0000000 (read/write direction) | 0x138<<16 (the 312-byte dm_ioctl header) |
// 0xfd<<8 (device-mapper ioctl type) | nr. The nr values count up from 0 via iota,
// so the order below must match the command enum in uapi/linux/dm-ioctl.h.
const (
	/* Top level cmds */
	DM_VERSION_CMD uintptr = (0xc138fd << 8) + iota
	DM_REMOVE_ALL_CMD
	DM_LIST_DEVICES_CMD

	/* device level cmds */
	DM_DEV_CREATE_CMD
	DM_DEV_REMOVE_CMD
	DM_DEV_RENAME_CMD
	DM_DEV_SUSPEND_CMD
	DM_DEV_STATUS_CMD
	DM_DEV_WAIT_CMD

	/* Table level cmds */
	DM_TABLE_LOAD_CMD
	DM_TABLE_CLEAR_CMD
	DM_TABLE_DEPS_CMD
	DM_TABLE_STATUS_CMD

	/* Added later */
	DM_LIST_VERSIONS_CMD
	DM_TARGET_MSG_CMD
	DM_DEV_SET_GEOMETRY_CMD
	DM_DEV_ARM_POLL_CMD
)
+
// Bits for DMIoctl.Flags, values from uapi/linux/dm-ioctl.h. The In/Out markers
// come from the kernel header and say whether the kernel reads the flag, writes
// it back, or both.
const (
	DM_READONLY_FLAG       = 1 << 0 /* In/Out */
	DM_SUSPEND_FLAG        = 1 << 1 /* In/Out */
	DM_PERSISTENT_DEV_FLAG = 1 << 3 /* In */
)
+
+const baseDataSize = uint32(unsafe.Sizeof(DMIoctl{})) - 16384
+
+func newReq() DMIoctl {
+ return DMIoctl{
+ Version: Version{4, 0, 0},
+ DataSize: baseDataSize,
+ DataStart: baseDataSize,
+ }
+}
+
// stringToDelimitedBuf copies src into dst as a NUL-terminated C string.
//
// It returns an error if src plus its terminating NUL does not fit into dst
// (len(src) > len(dst)-1). It also returns an error if src contains a NUL byte,
// which the kernel would treat as an early end of string. On error dst is left
// untouched. On success only the first len(src)+1 bytes of dst are written.
func stringToDelimitedBuf(dst []byte, src string) error {
	if len(src) > len(dst)-1 {
		return fmt.Errorf("string longer than target buffer (%v > %v)", len(src), len(dst)-1)
	}
	// Validate before copying so a rejected string never leaves a partial write behind.
	for i := 0; i < len(src); i++ {
		if src[i] == 0x00 {
			return errors.New("string contains null byte, this is unsupported by DM")
		}
	}
	n := copy(dst, src)
	// Terminate explicitly instead of relying on dst having been zeroed.
	dst[n] = 0x00
	return nil
}
+
+var fd uintptr
+
+func getFd() (uintptr, error) {
+ if fd == 0 {
+ f, err := os.Open("/dev/mapper/control")
+ if os.IsNotExist(err) {
+ _ = os.MkdirAll("/dev/mapper", 0755)
+ if err := unix.Mknod("/dev/mapper/control", unix.S_IFCHR|0600, int(unix.Mkdev(10, 236))); err != nil {
+ return 0, err
+ }
+ f, err = os.Open("/dev/mapper/control")
+ if err != nil {
+ return 0, err
+ }
+ } else if err != nil {
+ return 0, err
+ }
+ fd = f.Fd()
+ return f.Fd(), nil
+ }
+ return fd, nil
+}
+
+func GetVersion() (Version, error) {
+ req := newReq()
+ fd, err := getFd()
+ if err != nil {
+ return Version{}, err
+ }
+ if _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, DM_VERSION_CMD, uintptr(unsafe.Pointer(&req))); err != 0 {
+ return Version{}, err
+ }
+ return req.Version, nil
+}
+
+func CreateDevice(name string) (uint64, error) {
+ req := newReq()
+ if err := stringToDelimitedBuf(req.Name[:], name); err != nil {
+ return 0, err
+ }
+ fd, err := getFd()
+ if err != nil {
+ return 0, err
+ }
+ if _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, DM_DEV_CREATE_CMD, uintptr(unsafe.Pointer(&req))); err != 0 {
+ return 0, err
+ }
+ return req.Dev, nil
+}
+
+func RemoveDevice(name string) error {
+ req := newReq()
+ if err := stringToDelimitedBuf(req.Name[:], name); err != nil {
+ return err
+ }
+ fd, err := getFd()
+ if err != nil {
+ return err
+ }
+ if _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, DM_DEV_REMOVE_CMD, uintptr(unsafe.Pointer(&req))); err != 0 {
+ return err
+ }
+ runtime.KeepAlive(req)
+ return nil
+}
+
// Target describes one entry of a device-mapper table, as passed to LoadTable.
type Target struct {
	// StartSector is the first sector of the mapped device covered by this target.
	StartSector uint64
	// Length is the number of sectors covered by this target.
	Length uint64
	// Type is the target type name; it must fit in 15 bytes (DMTargetSpec.TargetType
	// holds 16 including the NUL terminator).
	Type string
	// Parameters is the target-specific parameter string.
	Parameters string
}
+
+func LoadTable(name string, targets []Target) error {
+ req := newReq()
+ if err := stringToDelimitedBuf(req.Name[:], name); err != nil {
+ return err
+ }
+ var data bytes.Buffer
+ for _, target := range targets {
+ // Gives the size of the spec and the null-terminated params aligned to 8 bytes
+ padding := len(target.Parameters) % 8
+ targetSize := uint32(int(unsafe.Sizeof(DMTargetSpec{})) + (len(target.Parameters) + 1) + padding)
+
+ targetSpec := DMTargetSpec{
+ SectorStart: target.StartSector,
+ Length: target.Length,
+ Next: targetSize,
+ }
+ if err := stringToDelimitedBuf(targetSpec.TargetType[:], target.Type); err != nil {
+ return err
+ }
+ if err := binary.Write(&data, native_endian.NativeEndian(), &targetSpec); err != nil {
+ panic(err)
+ }
+ data.WriteString(target.Parameters)
+ data.WriteByte(0x00)
+ for i := 0; i < padding; i++ {
+ data.WriteByte(0x00)
+ }
+ }
+ req.TargetCount = uint32(len(targets))
+ if data.Len() >= 16384 {
+ return errors.New("table too large for allocated memory")
+ }
+ req.DataSize = baseDataSize + uint32(data.Len())
+ copy(req.Data[:], data.Bytes())
+ fd, err := getFd()
+ if err != nil {
+ return err
+ }
+ if _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, DM_TABLE_LOAD_CMD, uintptr(unsafe.Pointer(&req))); err != 0 {
+ return err
+ }
+ runtime.KeepAlive(req)
+ return nil
+}
+
+func suspendResume(name string, suspend bool) error {
+ req := newReq()
+ if err := stringToDelimitedBuf(req.Name[:], name); err != nil {
+ return err
+ }
+ if suspend {
+ req.Flags = DM_SUSPEND_FLAG
+ }
+ fd, err := getFd()
+ if err != nil {
+ return err
+ }
+ if _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, DM_DEV_SUSPEND_CMD, uintptr(unsafe.Pointer(&req))); err != 0 {
+ return err
+ }
+ runtime.KeepAlive(req)
+ return nil
+}
+
// Suspend suspends the mapped device called name (DM_DEV_SUSPEND with
// DM_SUSPEND_FLAG set).
func Suspend(name string) error {
	return suspendResume(name, true)
}

// Resume resumes the mapped device called name (DM_DEV_SUSPEND without
// DM_SUSPEND_FLAG). CreateActiveDevice uses this to make a table loaded with
// LoadTable live.
func Resume(name string) error {
	return suspendResume(name, false)
}
+
+func CreateActiveDevice(name string, targets []Target) (uint64, error) {
+ dev, err := CreateDevice(name)
+ if err != nil {
+ return 0, fmt.Errorf("DM_DEV_CREATE failed: %w", err)
+ }
+ if err := LoadTable(name, targets); err != nil {
+ _ = RemoveDevice(name)
+ return 0, fmt.Errorf("DM_TABLE_LOAD failed: %w", err)
+ }
+ if err := Resume(name); err != nil {
+ _ = RemoveDevice(name)
+ return 0, fmt.Errorf("DM_DEV_SUSPEND failed: %w", err)
+ }
+ return dev, nil
+}
diff --git a/metropolis/pkg/fileargs/BUILD.bazel b/metropolis/pkg/fileargs/BUILD.bazel
new file mode 100644
index 0000000..fab70d7
--- /dev/null
+++ b/metropolis/pkg/fileargs/BUILD.bazel
@@ -0,0 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
# Passes secrets/config to child processes as paths to files on a private ramfs mount.
go_library(
    name = "go_default_library",
    srcs = ["fileargs.go"],
    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/fileargs",
    # NOTE(review): the commit message says visibility will be tightened in a follow-up.
    visibility = ["//visibility:public"],
    deps = ["@org_golang_x_sys//unix:go_default_library"],
)
diff --git a/metropolis/pkg/fileargs/fileargs.go b/metropolis/pkg/fileargs/fileargs.go
new file mode 100644
index 0000000..26c054b
--- /dev/null
+++ b/metropolis/pkg/fileargs/fileargs.go
@@ -0,0 +1,101 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileargs
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ "golang.org/x/sys/unix"
+)
+
// DefaultSize is the default size limit for FileArgs, in bytes, as used by New.
// NOTE(review): it is passed as a "size=" option to a ramfs mount, and ramfs is
// documented not to enforce size limits, so this is likely advisory only.
const DefaultSize = 4 * 1024 * 1024

// TempDirectory is the directory under which FileArgs creates and mounts its
// per-instance directories. Defaults to os.TempDir() but can be globally
// overridden by the application before any FileArgs are used.
var TempDirectory = os.TempDir()
+
// FileArgs manages a private ramfs mount holding files whose paths are handed to
// other programs, typically as command-line options (see FileOpt). Create it with
// New or NewWithSize and release it with Close.
type FileArgs struct {
	// path is the directory the ramfs instance is mounted on.
	path string
	// lastError is the first error hit by ArgPath. Once it is set, ArgPath does
	// nothing and returns "". It is exposed through Error.
	lastError error
}
+
// New initializes a new set of file-based arguments using DefaultSize as the size
// limit. Remember to call Close() when you're done using it, otherwise this leaks
// memory and mounts.
func New() (*FileArgs, error) {
	return NewWithSize(DefaultSize)
}
+
+// NewWthSize is the same as new, but with a custom size limit. Please be aware that this data
+// cannot be swapped out and using a size limit that's too high can deadlock your kernel.
+func NewWithSize(size uint64) (*FileArgs, error) {
+ randomNameRaw := make([]byte, 128/8)
+ if _, err := io.ReadFull(rand.Reader, randomNameRaw); err != nil {
+ return nil, err
+ }
+ tmpPath := filepath.Join(TempDirectory, hex.EncodeToString(randomNameRaw))
+ if err := os.MkdirAll(tmpPath, 0700); err != nil {
+ return nil, err
+ }
+ // This uses ramfs instead of tmpfs because we never want to swap this for security reasons
+ if err := unix.Mount("none", tmpPath, "ramfs", unix.MS_NOEXEC|unix.MS_NOSUID|unix.MS_NODEV, fmt.Sprintf("size=%v", size)); err != nil {
+ return nil, err
+ }
+ return &FileArgs{
+ path: tmpPath,
+ }, nil
+}
+
+// ArgPath returns the path of the temporary file for this argument. It names the temporary
+// file according to name.
+func (f *FileArgs) ArgPath(name string, content []byte) string {
+ if f.lastError != nil {
+ return ""
+ }
+
+ path := filepath.Join(f.path, name)
+
+ if err := ioutil.WriteFile(path, content, 0600); err != nil {
+ f.lastError = err
+ return ""
+ }
+
+ return path
+}
+
+// FileOpt returns a full option with the temporary file name already filled in.
+// Example: `FileOpt("--testopt", "test.txt", []byte("hello")) == "--testopt=/tmp/daf8ed.../test.txt"`
+func (f *FileArgs) FileOpt(optName, fileName string, content []byte) string {
+ return fmt.Sprintf("%v=%v", optName, f.ArgPath(fileName, content))
+}
+
// Error returns the first error encountered by ArgPath (and therefore FileOpt),
// or nil if every file was written successfully.
func (f *FileArgs) Error() error {
	return f.lastError
}
+
+func (f *FileArgs) Close() error {
+ if err := unix.Unmount(f.path, 0); err != nil {
+ return err
+ }
+ return os.Remove(f.path)
+}
diff --git a/metropolis/pkg/freeport/BUILD.bazel b/metropolis/pkg/freeport/BUILD.bazel
new file mode 100644
index 0000000..8ac6daf
--- /dev/null
+++ b/metropolis/pkg/freeport/BUILD.bazel
@@ -0,0 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
# Helpers for allocating free local TCP ports (mainly for tests).
go_library(
    name = "go_default_library",
    srcs = ["freeport.go"],
    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/freeport",
    # NOTE(review): the commit message says visibility will be tightened in a follow-up.
    visibility = ["//visibility:public"],
)
diff --git a/metropolis/pkg/freeport/freeport.go b/metropolis/pkg/freeport/freeport.go
new file mode 100644
index 0000000..bd047b5
--- /dev/null
+++ b/metropolis/pkg/freeport/freeport.go
@@ -0,0 +1,51 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package freeport
+
+import (
+ "io"
+ "net"
+)
+
// AllocateTCPPort picks a free TCP port on the IPv4 loopback address by binding a
// temporary listener to port 0. It returns the port number together with that
// listener, which the caller must close right before using the port. Another
// process can still grab the port between that Close and the caller's own bind,
// but there doesn't seem to be a better way to do this.
func AllocateTCPPort() (uint16, io.Closer, error) {
	listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)})
	if err != nil {
		return 0, nil, err
	}
	port := listener.Addr().(*net.TCPAddr).Port
	return uint16(port), listener, nil
}

// MustConsume takes the results of AllocateTCPPort, closes the listener and
// returns the allocated port. It panics if the port could not be allocated or the
// listener could not be closed.
func MustConsume(port uint16, lis io.Closer, err error) int {
	if err == nil {
		err = lis.Close()
	}
	if err != nil {
		panic(err)
	}
	return int(port)
}
diff --git a/metropolis/pkg/fsquota/BUILD.bazel b/metropolis/pkg/fsquota/BUILD.bazel
new file mode 100644
index 0000000..5f875a9
--- /dev/null
+++ b/metropolis/pkg/fsquota/BUILD.bazel
@@ -0,0 +1,39 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//metropolis/test/ktest:ktest.bzl", "ktest")
+
# Directory (project) quotas on top of fsxattrs and quotactl.
go_library(
    name = "go_default_library",
    srcs = [
        "fsinfo.go",
        "fsquota.go",
    ],
    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/fsquota",
    # NOTE(review): the commit message says visibility will be tightened in a follow-up.
    visibility = ["//visibility:public"],
    deps = [
        "//metropolis/pkg/fsquota/fsxattrs:go_default_library",
        "//metropolis/pkg/fsquota/quotactl:go_default_library",
        "@org_golang_x_sys//unix:go_default_library",
    ],
)

# The test skips itself unless IN_KTEST=true; it is meant to run via the ktest rule below.
go_test(
    name = "go_default_test",
    srcs = ["fsquota_test.go"],
    embed = [":go_default_library"],
    pure = "on",
    deps = [
        "@com_github_stretchr_testify//require:go_default_library",
        "@org_golang_x_sys//unix:go_default_library",
    ],
)
+
# Runs go_default_test in a VM. mkfs.xfs is placed at /mkfs.xfs in the initramfs,
# and ramdisk_size provides the /dev/ram0 the test formats.
ktest(
    tester = ":go_default_test",
    deps = [
        "//third_party/xfsprogs:mkfs.xfs",
    ],
    initramfs_extra = """
file /mkfs.xfs $(location //third_party/xfsprogs:mkfs.xfs) 0755 0 0
 """,
    cmdline = "ramdisk_size=51200",
)
diff --git a/metropolis/pkg/fsquota/fsinfo.go b/metropolis/pkg/fsquota/fsinfo.go
new file mode 100644
index 0000000..e40a533
--- /dev/null
+++ b/metropolis/pkg/fsquota/fsinfo.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsquota
+
+import (
+ "fmt"
+ "os"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
// This requires fsinfo() support, which is not yet in any stable kernel.
// Our kernel has that syscall backported. This would otherwise be an extremely expensive
// operation and also involve lots of logic from our side.
//
// NOTE(review): fsinfo() was never merged upstream, and mainline later assigned x86_64
// syscall 441 to epoll_pwait2. On a kernel without the backport, this number invokes a
// different syscall. Confirm the target kernel before relying on it.

// From syscall_64.tbl
const sys_fsinfo = 441

// From uapi/linux/fsinfo.h (fsinfo patch set).
// fsinfo_attr_source requests the filesystem's source (what it was mounted from).
// The fsinfo_flags_query_* values select whether the target is resolved by path or
// by file descriptor.
const fsinfo_attr_source = 0x09
const fsinfo_flags_query_path = 0x0000
const fsinfo_flags_query_fd = 0x0001
+
// fsinfoParams mirrors struct fsinfo_params from the fsinfo patch set. It is passed
// to the kernel by pointer, so field order and sizes must not change.
type fsinfoParams struct {
	resolveFlags uint64 // RESOLVE_* path resolution flags (left zero here)
	atFlags      uint32 // AT_* flags (left zero here)
	flags        uint32 // fsinfo_flags_query_* mode
	request      uint32 // fsinfo_attr_* attribute to query
	nth          uint32 // attribute instance index (left zero here)
	mth          uint32 // attribute sub-instance index (left zero here)
}
+
+func fsinfoGetSource(dir *os.File) (string, error) {
+ buf := make([]byte, 256)
+ params := fsinfoParams{
+ flags: fsinfo_flags_query_fd,
+ request: fsinfo_attr_source,
+ }
+ n, _, err := unix.Syscall6(sys_fsinfo, dir.Fd(), 0, uintptr(unsafe.Pointer(¶ms)), unsafe.Sizeof(params), uintptr(unsafe.Pointer(&buf[0])), 128)
+ if err != unix.Errno(0) {
+ return "", fmt.Errorf("failed to call fsinfo: %w", err)
+ }
+ return string(buf[:n]), nil
+}
diff --git a/metropolis/pkg/fsquota/fsquota.go b/metropolis/pkg/fsquota/fsquota.go
new file mode 100644
index 0000000..b1305f8
--- /dev/null
+++ b/metropolis/pkg/fsquota/fsquota.go
@@ -0,0 +1,146 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
// Package fsquota provides a simplified interface to interact with Linux's filesystem quota API.
+// It only supports setting quotas on directories, not groups or users.
+// Quotas need to be already enabled on the filesystem to be able to use them using this package.
+// See the quotactl package if you intend to use this on a filesystem where quotas need to be
+// enabled manually.
+package fsquota
+
+import (
+ "fmt"
+ "math"
+ "os"
+
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/pkg/fsquota/fsxattrs"
+ "git.monogon.dev/source/nexantic.git/metropolis/pkg/fsquota/quotactl"
+)
+
+// SetQuota sets the quota of bytes and/or inodes in a given path. To not set a limit, set the
+// corresponding argument to zero. Setting both arguments to zero removes the quota entirely.
+// This function can only be called on an empty directory. It can't be used to create a quota
+// below a directory which already has a quota since Linux doesn't offer hierarchical quotas.
+func SetQuota(path string, maxBytes uint64, maxInodes uint64) error {
+ dir, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ defer dir.Close()
+ source, err := fsinfoGetSource(dir)
+ if err != nil {
+ return err
+ }
+ var valid uint32
+ if maxBytes > 0 {
+ valid |= quotactl.FlagBLimitsValid
+ }
+ if maxInodes > 0 {
+ valid |= quotactl.FlagILimitsValid
+ }
+
+ attrs, err := fsxattrs.Get(dir)
+ if err != nil {
+ return err
+ }
+
+ var lastID uint32 = attrs.ProjectID
+ if lastID == 0 {
+ // No project/quota exists for this directory, assign a new project quota
+ // TODO(lorenz): This is racy, but the kernel does not support atomically assigning
+ // quotas. So this needs to be added to the kernels setquota interface. Due to the short
+ // time window and infrequent calls this should not be an immediate issue.
+ for {
+ quota, err := quotactl.GetNextQuota(source, quotactl.QuotaTypeProject, lastID)
+ if err == unix.ENOENT || err == unix.ESRCH {
+ // We have enumerated all quotas, nothing exists here
+ break
+ } else if err != nil {
+ return fmt.Errorf("failed to call GetNextQuota: %w", err)
+ }
+ if quota.ID > lastID+1 {
+ // Take the first ID in the quota ID gap
+ lastID++
+ break
+ }
+ lastID++
+ }
+ }
+
+ // If both limits are zero, this is a delete operation, process it as such
+ if maxBytes == 0 && maxInodes == 0 {
+ valid = quotactl.FlagBLimitsValid | quotactl.FlagILimitsValid
+ attrs.ProjectID = 0
+ attrs.Flags &= ^fsxattrs.FlagProjectInherit
+ } else {
+ attrs.ProjectID = lastID
+ attrs.Flags |= fsxattrs.FlagProjectInherit
+ }
+
+ if err := fsxattrs.Set(dir, attrs); err != nil {
+ return err
+ }
+
+ // Always round up to the nearest block size
+ bytesLimitBlocks := uint64(math.Ceil(float64(maxBytes) / float64(1024)))
+
+ return quotactl.SetQuota(source, quotactl.QuotaTypeProject, lastID, "actl.Quota{
+ BHardLimit: bytesLimitBlocks,
+ BSoftLimit: bytesLimitBlocks,
+ IHardLimit: maxInodes,
+ ISoftLimit: maxInodes,
+ Valid: valid,
+ })
+}
+
// Quota describes the quota of a directory and its current utilization, as returned by
// GetQuota.
type Quota struct {
	// Bytes is the hard byte limit.
	Bytes uint64
	// BytesUsed is the number of bytes currently charged to the quota.
	BytesUsed uint64
	// Inodes is the hard inode limit.
	Inodes uint64
	// InodesUsed is the number of inodes currently charged to the quota.
	InodesUsed uint64
}
+
+// GetQuota returns the current active quota and its utilization at the given path
+func GetQuota(path string) (*Quota, error) {
+ dir, err := os.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer dir.Close()
+ source, err := fsinfoGetSource(dir)
+ if err != nil {
+ return nil, err
+ }
+ attrs, err := fsxattrs.Get(dir)
+ if err != nil {
+ return nil, err
+ }
+ if attrs.ProjectID == 0 {
+ return nil, os.ErrNotExist
+ }
+ quota, err := quotactl.GetQuota(source, quotactl.QuotaTypeProject, attrs.ProjectID)
+ if err != nil {
+ return nil, err
+ }
+ return &Quota{
+ Bytes: quota.BHardLimit * 1024,
+ BytesUsed: quota.CurSpace,
+ Inodes: quota.IHardLimit,
+ InodesUsed: quota.CurInodes,
+ }, nil
+}
diff --git a/metropolis/pkg/fsquota/fsquota_test.go b/metropolis/pkg/fsquota/fsquota_test.go
new file mode 100644
index 0000000..4729dac
--- /dev/null
+++ b/metropolis/pkg/fsquota/fsquota_test.go
@@ -0,0 +1,152 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsquota
+
+import (
+ "fmt"
+ "io/ioutil"
+ "math"
+ "os"
+ "os/exec"
+ "syscall"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/sys/unix"
+)
+
// withinTolerance is a helper for asserting that a value is within a certain percentage of the
// expected value. The tolerance is a non-negative fraction of expected: 0 requires an exact
// match, and 1 accepts anything between 0 and twice the expected value. Larger tolerances are
// allowed; the bounds are clamped to the uint64 range instead of wrapping around.
func withinTolerance(t *testing.T, expected uint64, actual uint64, tolerance float64, name string) {
	t.Helper()
	delta := uint64(math.Round(float64(expected) * tolerance))
	var lowerBound uint64
	if delta < expected {
		lowerBound = expected - delta
	}
	upperBound := expected + delta
	if upperBound < expected {
		// expected+delta overflowed.
		upperBound = math.MaxUint64
	}
	if actual < lowerBound {
		t.Errorf("Value %v (%v) is too low, expected between %v and %v", name, actual, lowerBound, upperBound)
	}
	if actual > upperBound {
		t.Errorf("Value %v (%v) is too high, expected between %v and %v", name, actual, lowerBound, upperBound)
	}
}
+
+func TestBasic(t *testing.T) {
+ if os.Getenv("IN_KTEST") != "true" {
+ t.Skip("Not in ktest")
+ }
+ mkfsCmd := exec.Command("/mkfs.xfs", "-qf", "/dev/ram0")
+ if _, err := mkfsCmd.Output(); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.Mkdir("/test", 0755); err != nil {
+ t.Error(err)
+ }
+
+ if err := unix.Mount("/dev/ram0", "/test", "xfs", unix.MS_NOEXEC|unix.MS_NODEV, "prjquota"); err != nil {
+ t.Fatal(err)
+ }
+ defer unix.Unmount("/test", 0)
+ defer os.RemoveAll("/test")
+ t.Run("SetQuota", func(t *testing.T) {
+ defer func() {
+ os.RemoveAll("/test/set")
+ }()
+ if err := os.Mkdir("/test/set", 0755); err != nil {
+ t.Fatal(err)
+ }
+ if err := SetQuota("/test/set", 1024*1024, 100); err != nil {
+ t.Fatal(err)
+ }
+ })
+ t.Run("SetQuotaAndExhaust", func(t *testing.T) {
+ defer func() {
+ os.RemoveAll("/test/sizequota")
+ }()
+ if err := os.Mkdir("/test/sizequota", 0755); err != nil {
+ t.Fatal(err)
+ }
+ const bytesQuota = 1024 * 1024 // 1MiB
+ if err := SetQuota("/test/sizequota", bytesQuota, 0); err != nil {
+ t.Fatal(err)
+ }
+ testfile, err := os.Create("/test/sizequota/testfile")
+ if err != nil {
+ t.Fatal(err)
+ }
+ testdata := make([]byte, 1024)
+ var bytesWritten int
+ for {
+ n, err := testfile.Write([]byte(testdata))
+ if err != nil {
+ if pathErr, ok := err.(*os.PathError); ok {
+ if pathErr.Err == syscall.ENOSPC {
+ // Running out of space is the only acceptable error to continue execution
+ break
+ }
+ }
+ t.Fatal(err)
+ }
+ bytesWritten += n
+ }
+ if bytesWritten > bytesQuota {
+ t.Errorf("Wrote %v bytes, quota is only %v bytes", bytesWritten, bytesQuota)
+ }
+ })
+ t.Run("GetQuotaReadbackAndUtilization", func(t *testing.T) {
+ defer func() {
+ os.RemoveAll("/test/readback")
+ }()
+ if err := os.Mkdir("/test/readback", 0755); err != nil {
+ t.Fatal(err)
+ }
+ const bytesQuota = 1024 * 1024 // 1MiB
+ const inodesQuota = 100
+ if err := SetQuota("/test/readback", bytesQuota, inodesQuota); err != nil {
+ t.Fatal(err)
+ }
+ sizeFileData := make([]byte, 512*1024)
+ if err := ioutil.WriteFile("/test/readback/512kfile", sizeFileData, 0644); err != nil {
+ t.Fatal(err)
+ }
+
+ quotaUtil, err := GetQuota("/test/readback")
+ if err != nil {
+ t.Fatal(err)
+ }
+ require.Equal(t, uint64(bytesQuota), quotaUtil.Bytes, "bytes quota readback incorrect")
+ require.Equal(t, uint64(inodesQuota), quotaUtil.Inodes, "inodes quota readback incorrect")
+
+ // Give 10% tolerance for quota used values to account for metadata overhead and internal
+ // structures that are also in there. If it's out by more than that it's an issue anyways.
+ withinTolerance(t, uint64(len(sizeFileData)), quotaUtil.BytesUsed, 0.1, "BytesUsed")
+
+ // Write 50 inodes for a total of 51 (with the 512K file)
+ for i := 0; i < 50; i++ {
+ if err := ioutil.WriteFile(fmt.Sprintf("/test/readback/ifile%v", i), []byte("test"), 0644); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ quotaUtil, err = GetQuota("/test/readback")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ withinTolerance(t, 51, quotaUtil.InodesUsed, 0.1, "InodesUsed")
+ })
+}
diff --git a/metropolis/pkg/fsquota/fsxattrs/BUILD.bazel b/metropolis/pkg/fsquota/fsxattrs/BUILD.bazel
new file mode 100644
index 0000000..87f2617
--- /dev/null
+++ b/metropolis/pkg/fsquota/fsxattrs/BUILD.bazel
@@ -0,0 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
# Access to FS_IOC_FSGETXATTR/FS_IOC_FSSETXATTR (extended file attributes incl. project IDs).
go_library(
    name = "go_default_library",
    srcs = ["fsxattrs.go"],
    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/fsquota/fsxattrs",
    # NOTE(review): the commit message says visibility will be tightened in a follow-up.
    visibility = ["//visibility:public"],
    deps = ["@org_golang_x_sys//unix:go_default_library"],
)
diff --git a/metropolis/pkg/fsquota/fsxattrs/fsxattrs.go b/metropolis/pkg/fsquota/fsxattrs/fsxattrs.go
new file mode 100644
index 0000000..1d455eb
--- /dev/null
+++ b/metropolis/pkg/fsquota/fsxattrs/fsxattrs.go
@@ -0,0 +1,77 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsxattrs
+
+import (
+ "fmt"
+ "os"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
// FSXAttrFlag is a bit in FSXAttrs.Flags (FS_XFLAG_* in the kernel).
type FSXAttrFlag uint32

// Defined in uapi/linux/fs.h as FS_XFLAG_*.
const (
	FlagRealtime        FSXAttrFlag = 0x00000001 // data in realtime volume
	FlagPreallocated    FSXAttrFlag = 0x00000002 // preallocated file extents
	FlagImmutable       FSXAttrFlag = 0x00000008 // file cannot be modified
	FlagAppend          FSXAttrFlag = 0x00000010 // all writes append
	FlagSync            FSXAttrFlag = 0x00000020 // all writes synchronous
	FlagNoATime         FSXAttrFlag = 0x00000040 // do not update access time
	FlagNoDump          FSXAttrFlag = 0x00000080 // do not include in backups
	FlagRealtimeInherit FSXAttrFlag = 0x00000100 // create with realtime bit set
	FlagProjectInherit  FSXAttrFlag = 0x00000200 // create with parent's project ID
	FlagNoSymlinks      FSXAttrFlag = 0x00000400 // disallow symlink creation
	FlagExtentSize      FSXAttrFlag = 0x00000800 // extent size allocator hint
	// FlagExtentSizeInherit (FS_XFLAG_EXTSZINHERIT) was missing from this list.
	FlagExtentSizeInherit FSXAttrFlag = 0x00001000 // inherit extent size hint
	FlagNoDefragment      FSXAttrFlag = 0x00002000 // do not defragment
	FlagFilestream        FSXAttrFlag = 0x00004000 // use filestream allocator
	FlagDAX               FSXAttrFlag = 0x00008000 // use DAX for IO
	FlagCOWExtentSize     FSXAttrFlag = 0x00010000 // copy-on-write extent size hint
	FlagHasAttribute      FSXAttrFlag = 0x80000000 // no DIFLAG for this
)
+
// FS_IOC_FSGETXATTR/FS_IOC_FSSETXATTR are defined in uapi/linux/fs.h as
// _IOR('X', 31, struct fsxattr) and _IOW('X', 32, struct fsxattr); the 0x1c in
// both is the 28-byte size of struct fsxattr (FSXAttrs).
const FS_IOC_FSGETXATTR = 0x801c581f
const FS_IOC_FSSETXATTR = 0x401c5820
+
// FSXAttrs mirrors struct fsxattr from uapi/linux/fs.h (28 bytes). It is passed to
// the kernel by pointer, so field order and sizes must not change.
type FSXAttrs struct {
	// Flags holds FS_XFLAG_* bits (see the Flag* constants).
	Flags FSXAttrFlag
	// ExtentSize is the extent size hint.
	ExtentSize uint32
	// ExtentCount is the number of data extents, as reported by the kernel on Get.
	ExtentCount uint32
	// ProjectID is the project the file belongs to; fsquota treats 0 as "no project".
	ProjectID uint32
	// CoWExtentSize is the copy-on-write extent size hint.
	CoWExtentSize uint32
	_pad          [8]byte
}
+
+func Get(file *os.File) (*FSXAttrs, error) {
+ var attrs FSXAttrs
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL, file.Fd(), FS_IOC_FSGETXATTR, uintptr(unsafe.Pointer(&attrs)))
+ if errno != 0 {
+ return nil, fmt.Errorf("failed to execute getFSXAttrs: %v", errno)
+ }
+ return &attrs, nil
+}
+
+func Set(file *os.File, attrs *FSXAttrs) error {
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL, file.Fd(), FS_IOC_FSSETXATTR, uintptr(unsafe.Pointer(attrs)))
+ if errno != 0 {
+ return fmt.Errorf("failed to execute setFSXAttrs: %v", errno)
+ }
+ return nil
+}
diff --git a/metropolis/pkg/fsquota/quotactl/BUILD.bazel b/metropolis/pkg/fsquota/quotactl/BUILD.bazel
new file mode 100644
index 0000000..406c784
--- /dev/null
+++ b/metropolis/pkg/fsquota/quotactl/BUILD.bazel
@@ -0,0 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
# Low-level wrapper around the quotactl(2) syscall.
go_library(
    name = "go_default_library",
    srcs = ["quotactl.go"],
    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/fsquota/quotactl",
    # NOTE(review): the commit message says visibility will be tightened in a follow-up.
    visibility = ["//visibility:public"],
    deps = ["@org_golang_x_sys//unix:go_default_library"],
)
diff --git a/metropolis/pkg/fsquota/quotactl/quotactl.go b/metropolis/pkg/fsquota/quotactl/quotactl.go
new file mode 100644
index 0000000..5ed77d7
--- /dev/null
+++ b/metropolis/pkg/fsquota/quotactl/quotactl.go
@@ -0,0 +1,233 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package quotactl implements a low-level wrapper around the modern portion of Linux's
+// quotactl() syscall. See the fsquota package for a nicer interface to the most common part
+// of this API.
+package quotactl
+
+import (
+ "fmt"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
// QuotaType selects which kind of quota a quotactl command operates on. It is
// OR-ed into the low bits of the command word, and its values match the
// kernel's USRQUOTA, GRPQUOTA and PRJQUOTA.
type QuotaType uint

const (
	// QuotaTypeUser addresses per-user quotas.
	QuotaTypeUser QuotaType = 0
	// QuotaTypeGroup addresses per-group quotas.
	QuotaTypeGroup QuotaType = 1
	// QuotaTypeProject addresses per-project quotas.
	QuotaTypeProject QuotaType = 2
)
+
// quotactl sub-command numbers (0x800001 and up). Each is already shifted left
// by 8 bits, so OR-ing in a QuotaType yields the complete command word passed
// to the syscall.
const (
	Q_SYNC         uint = 0x800001 << 8
	Q_QUOTAON      uint = 0x800002 << 8
	Q_QUOTAOFF     uint = 0x800003 << 8
	Q_GETFMT       uint = 0x800004 << 8
	Q_GETINFO      uint = 0x800005 << 8
	Q_SETINFO      uint = 0x800006 << 8
	Q_GETQUOTA     uint = 0x800007 << 8
	Q_SETQUOTA     uint = 0x800008 << 8
	Q_GETNEXTQUOTA uint = 0x800009 << 8
)
+
// Validity bits for the Valid fields of Quota, NextDQBlk and DQInfo-style
// structures. Each bit says that the corresponding group of fields is meaningful.
const (
	FlagBLimitsValid = 1 << 0 // block hard/soft limits
	FlagSpaceValid   = 1 << 1 // current space usage
	FlagILimitsValid = 1 << 2 // inode hard/soft limits
	FlagInodesValid  = 1 << 3 // current inode usage
	FlagBTimeValid   = 1 << 4 // block grace time
	FlagITimeValid   = 1 << 5 // inode grace time
)
+
// DQInfo holds per-filesystem quota information. GetInfo (Q_GETINFO) reads it and
// SetInfo (Q_SETINFO) writes it. It is passed to quotactl by pointer, so field
// order and sizes must match the kernel's layout (presumably struct if_dqinfo -
// TODO confirm against linux/quota.h).
type DQInfo struct {
	// Bgrace is presumably the block-limit grace period - confirm units.
	Bgrace uint64
	// Igrace is presumably the inode-limit grace period - confirm units.
	Igrace uint64
	Flags  uint32
	// Valid presumably marks which of the fields above are meaningful - confirm.
	Valid uint32
}
+
// Quota is a single quota entry. GetQuota (Q_GETQUOTA) reads it and SetQuota
// (Q_SETQUOTA) writes it. It is passed to quotactl by pointer, so field order and
// sizes must match the kernel's layout (presumably struct if_dqblk - TODO confirm).
type Quota struct {
	BHardLimit uint64 // Both Byte limits are prescaled by 1024 (so are in KiB), but CurSpace is in B
	BSoftLimit uint64
	CurSpace   uint64
	IHardLimit uint64 // inode limits
	ISoftLimit uint64
	CurInodes  uint64
	BTime      uint64 // NOTE(review): presumably the block grace-period expiry time - confirm
	ITime      uint64 // NOTE(review): presumably the inode grace-period expiry time - confirm
	Valid      uint32 // presumably a bitmask of the Flag*Valid constants - confirm
}
+
// NextDQBlk is the result of GetNextQuota (Q_GETNEXTQUOTA): a quota entry plus
// the ID it belongs to. It is passed to quotactl by pointer, so field order and
// sizes must match the kernel's layout.
type NextDQBlk struct {
	// NOTE(review): despite the "Bytes" names, these limits presumably use the same
	// 1024-byte units as Quota.BHardLimit/BSoftLimit - confirm before relying on it.
	HardLimitBytes  uint64
	SoftLimitBytes  uint64
	CurrentBytes    uint64
	HardLimitInodes uint64
	SoftLimitInodes uint64
	CurrentInodes   uint64
	BTime           uint64
	ITime           uint64
	Valid           uint32
	// ID is the user/group/project ID this entry describes.
	ID uint32
}
+
// QuotaFormat identifies an on-disk quota format, as passed to QuotaOn and
// reported by GetFmt. The values are collected from the kernel's
// quota_format_type structs.
type QuotaFormat uint32

const (
	// QuotaFormatNone is a special case where all quota information is
	// stored inside filesystem metadata and thus requires no quotaFilePath.
	QuotaFormatNone QuotaFormat = iota
	QuotaFormatVFSOld
	QuotaFormatVFSV0
	QuotaFormatOCFS2
	QuotaFormatVFSV1
)
+
+// QuotaOn turns quota accounting and enforcement on
+func QuotaOn(device string, qtype QuotaType, quotaFormat QuotaFormat, quotaFilePath string) error {
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return err
+ }
+ pathArg, err := unix.BytePtrFromString(quotaFilePath)
+ if err != nil {
+ return err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_QUOTAON|uint(qtype)), uintptr(unsafe.Pointer(devArg)), uintptr(quotaFormat), uintptr(unsafe.Pointer(pathArg)), 0, 0)
+ if err != unix.Errno(0) {
+ return err
+ }
+ return nil
+}
+
+// QuotaOff turns quotas off
+func QuotaOff(device string, qtype QuotaType) error {
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_QUOTAOFF|uint(qtype)), uintptr(unsafe.Pointer(devArg)), 0, 0, 0, 0)
+ if err != unix.Errno(0) {
+ return err
+ }
+ return nil
+}
+
+// GetFmt gets the quota format used on given filesystem
+func GetFmt(device string, qtype QuotaType) (uint32, error) {
+ var fmt uint32
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return 0, err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_GETFMT|uint(qtype)), uintptr(unsafe.Pointer(devArg)), 0, uintptr(unsafe.Pointer(&fmt)), 0, 0)
+ if err != unix.Errno(0) {
+ return 0, err
+ }
+ return fmt, nil
+}
+
+// GetInfo gets information about quota files
+func GetInfo(device string, qtype QuotaType) (*DQInfo, error) {
+ var info DQInfo
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return nil, err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_GETINFO|uint(qtype)), uintptr(unsafe.Pointer(devArg)), 0, uintptr(unsafe.Pointer(&info)), 0, 0)
+ if err != unix.Errno(0) {
+ return nil, err
+ }
+ return &info, nil
+}
+
+// SetInfo sets information about quota files
+func SetInfo(device string, qtype QuotaType, info *DQInfo) error {
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_SETINFO|uint(qtype)), uintptr(unsafe.Pointer(devArg)), 0, uintptr(unsafe.Pointer(info)), 0, 0)
+ if err != unix.Errno(0) {
+ return err
+ }
+ return nil
+}
+
+// GetQuota gets user quota structure
+func GetQuota(device string, qtype QuotaType, id uint32) (*Quota, error) {
+ var info Quota
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return nil, err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_GETQUOTA|uint(qtype)), uintptr(unsafe.Pointer(devArg)), uintptr(id), uintptr(unsafe.Pointer(&info)), 0, 0)
+ if err != unix.Errno(0) {
+ return nil, err
+ }
+ return &info, nil
+}
+
+// GetNextQuota gets disk limits and usage > ID
+func GetNextQuota(device string, qtype QuotaType, id uint32) (*NextDQBlk, error) {
+ var info NextDQBlk
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return nil, err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_GETNEXTQUOTA|uint(qtype)), uintptr(unsafe.Pointer(devArg)), uintptr(id), uintptr(unsafe.Pointer(&info)), 0, 0)
+ if err != unix.Errno(0) {
+ return nil, err
+ }
+ return &info, nil
+}
+
+// SetQuota sets the given quota
+func SetQuota(device string, qtype QuotaType, id uint32, quota *Quota) error {
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_SETQUOTA|uint(qtype)), uintptr(unsafe.Pointer(devArg)), uintptr(id), uintptr(unsafe.Pointer(quota)), 0, 0)
+ if err != unix.Errno(0) {
+ return fmt.Errorf("failed to set quota: %w", err)
+ }
+ return nil
+}
+
+// Sync syncs disk copy of filesystems quotas. If device is empty it syncs all filesystems.
+func Sync(device string) error {
+ if device != "" {
+ devArg, err := unix.BytePtrFromString(device)
+ if err != nil {
+ return err
+ }
+ _, _, err = unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_SYNC), uintptr(unsafe.Pointer(devArg)), 0, 0, 0, 0)
+ if err != unix.Errno(0) {
+ return err
+ }
+ } else {
+ _, _, err := unix.Syscall6(unix.SYS_QUOTACTL, uintptr(Q_SYNC), 0, 0, 0, 0, 0)
+ if err != unix.Errno(0) {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/metropolis/pkg/jsonpatch/BUILD.bazel b/metropolis/pkg/jsonpatch/BUILD.bazel
new file mode 100644
index 0000000..b733c57
--- /dev/null
+++ b/metropolis/pkg/jsonpatch/BUILD.bazel
@@ -0,0 +1,14 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

# JSON Patch (RFC 6902) / JSON Pointer (RFC 6901) helpers. Standard library only,
# so no deps.
# NOTE(review): the source file carries a doubled extension (jsonpatch.go.go).
# Consider renaming it to jsonpatch.go together with this srcs entry.
go_library(
    name = "go_default_library",
    srcs = ["jsonpatch.go.go"],
    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/jsonpatch",
    visibility = ["//visibility:public"],
)

# Unit tests, compiled into the library package via embed.
go_test(
    name = "go_default_test",
    srcs = ["jsonpatch_test.go"],
    embed = [":go_default_library"],
)
diff --git a/metropolis/pkg/jsonpatch/jsonpatch.go.go b/metropolis/pkg/jsonpatch/jsonpatch.go.go
new file mode 100644
index 0000000..9682980
--- /dev/null
+++ b/metropolis/pkg/jsonpatch/jsonpatch.go.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package jsonpatch contains data structures and encoders for JSON Patch (RFC 6902) and JSON Pointers (RFC 6901)
+package jsonpatch
+
+import "strings"
+
// JsonPatchOp is a single JSON Patch operation (RFC 6902, Section 4).
//
// Path and From hold JSON Pointers (RFC 6901); build them with PointerFromParts.
// Because Value is tagged omitempty, a nil Value is left out of the encoding
// entirely. An explicit JSON null therefore cannot be expressed with this type.
type JsonPatchOp struct {
	Operation string      `json:"op"`
	Path      string      `json:"path"` // Technically a JSON Pointer, but called Path in the RFC
	From      string      `json:"from,omitempty"`
	Value     interface{} `json:"value,omitempty"`
}

// jsonRefTokenEscaper applies the RFC 6901 escapes in a single left-to-right
// pass. strings.Replacer never rescans its own output, so the '~' inside a
// produced "~1" is not escaped again. This matches escaping '~' before '/'.
var jsonRefTokenEscaper = strings.NewReplacer("~", "~0", "/", "~1")

// EncodeJSONRefToken escapes a single reference token so it can be embedded in
// a JSON Pointer (RFC 6901, Section 3): '~' becomes "~0" and '/' becomes "~1".
func EncodeJSONRefToken(token string) string {
	return jsonRefTokenEscaper.Replace(token)
}

// PointerFromParts builds an encoded JSON Pointer from unescaped path parts.
// Each part is escaped with EncodeJSONRefToken and prefixed with '/'. An empty
// or nil slice yields "", the pointer to the whole document.
func PointerFromParts(pathParts []string) string {
	var sb strings.Builder
	for _, part := range pathParts {
		sb.WriteByte('/')
		sb.WriteString(EncodeJSONRefToken(part))
	}
	return sb.String()
}
diff --git a/metropolis/pkg/jsonpatch/jsonpatch_test.go b/metropolis/pkg/jsonpatch/jsonpatch_test.go
new file mode 100644
index 0000000..33a56ba
--- /dev/null
+++ b/metropolis/pkg/jsonpatch/jsonpatch_test.go
@@ -0,0 +1,66 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package jsonpatch
+
+import (
+ "testing"
+)
+
+func TestEncodeJSONRefToken(t *testing.T) {
+ tests := []struct {
+ name string
+ token string
+ want string
+ }{
+ {"Passes through normal characters", "asdf123", "asdf123"},
+ {"Encodes simple slashes", "a/b", "a~1b"},
+ {"Encodes tildes", "m~n", "m~0n"},
+ {"Encodes bot tildes and slashes", "a/m~n", "a~1m~0n"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := EncodeJSONRefToken(tt.token); got != tt.want {
+ t.Errorf("EncodeJSONRefToken() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestPointerFromParts(t *testing.T) {
+ type args struct {
+ pathParts []string
+ }
+ tests := []struct {
+ name string
+ args args
+ want string
+ }{
+ {"Empty path", args{[]string{}}, ""},
+ {"Single level path", args{[]string{"foo"}}, "/foo"},
+ {"Multi-level path", args{[]string{"foo", "0"}}, "/foo/0"},
+ {"Path starting with empty key", args{[]string{""}}, "/"},
+ {"Path with part containing /", args{[]string{"a/b"}}, "/a~1b"},
+ {"Path with part containing spaces", args{[]string{" "}}, "/ "},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := PointerFromParts(tt.args.pathParts); got != tt.want {
+ t.Errorf("PointerFromParts() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/metropolis/pkg/logbuffer/BUILD.bazel b/metropolis/pkg/logbuffer/BUILD.bazel
new file mode 100644
index 0000000..57a85d8
--- /dev/null
+++ b/metropolis/pkg/logbuffer/BUILD.bazel
@@ -0,0 +1,22 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

# Line splitting (linebuffer.go) and the in-memory ring buffer of log lines
# (logbuffer.go). Depends on the API protos for the LogEntry_Raw conversion
# helpers.
go_library(
    name = "go_default_library",
    srcs = [
        "linebuffer.go",
        "logbuffer.go",
    ],
    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/logbuffer",
    visibility = ["//visibility:public"],
    deps = ["//metropolis/proto/api:go_default_library"],
)

# Unit tests for both files; logbuffer_test.go uses testify's require.
go_test(
    name = "go_default_test",
    srcs = [
        "linebuffer_test.go",
        "logbuffer_test.go",
    ],
    embed = [":go_default_library"],
    deps = ["@com_github_stretchr_testify//require:go_default_library"],
)
diff --git a/metropolis/pkg/logbuffer/linebuffer.go b/metropolis/pkg/logbuffer/linebuffer.go
new file mode 100644
index 0000000..246a91b
--- /dev/null
+++ b/metropolis/pkg/logbuffer/linebuffer.go
@@ -0,0 +1,160 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logbuffer
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+ "sync"
+
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
// Line is a single line kept by the log buffer. Its content may have been cut
// short to respect a length limit.
type Line struct {
	// Data is the stored (possibly truncated) content of the line.
	Data string
	// OriginalLength is the length in bytes the line had before any truncation.
	OriginalLength int
}

// Truncated reports whether Data is shorter than the line originally was,
// i.e. whether the line was cut to fit a limit.
func (l *Line) Truncated() bool {
	return len(l.Data) < l.OriginalLength
}

// String renders the line for display. Truncated lines get a trailing "..."
// and all other lines are returned as-is.
func (l *Line) String() string {
	if !l.Truncated() {
		return l.Data
	}
	return l.Data + "..."
}
+
+// ProtoLog returns a Logging-specific protobuf structure.
+func (l *Line) ProtoLog() *apb.LogEntry_Raw {
+ return &apb.LogEntry_Raw{
+ Data: l.Data,
+ OriginalLength: int64(l.OriginalLength),
+ }
+}
+
+// LineFromLogProto converts a Logging-specific protobuf message back into a Line.
+func LineFromLogProto(raw *apb.LogEntry_Raw) (*Line, error) {
+ if raw.OriginalLength < int64(len(raw.Data)) {
+ return nil, fmt.Errorf("original_length smaller than length of data")
+ }
+ originalLength := int(raw.OriginalLength)
+ if int64(originalLength) < raw.OriginalLength {
+ return nil, fmt.Errorf("original_length larger than native int size")
+ }
+ return &Line{
+ Data: raw.Data,
+ OriginalLength: originalLength,
+ }, nil
+}
+
// LineBuffer is an io.WriteCloser that will call a given callback every time a line is completed.
//
// Lines longer than maxLineLength bytes are truncated before being handed to the callback. The callback still sees the
// full length via Line.OriginalLength.
type LineBuffer struct {
	// maxLineLength is the maximum number of bytes of a line that are kept in cur.
	maxLineLength int
	// cb is invoked for every completed line, while mu is held.
	cb LineBufferCallback

	// mu serializes Write and Close.
	mu sync.Mutex
	// cur holds the (possibly truncated) bytes of the line currently being written.
	cur strings.Builder
	// length is the length of the line currently being written - this will continue to increase, even if the string
	// exceeds maxLineLength.
	length int
	// closed is set by Close; Write rejects data once it is true.
	closed bool
}
+
// LineBufferCallback is a callback that will get called any time the line is completed. The function must not cause another
// write to the LineBuffer, or the program will deadlock.
//
// The callback runs while the LineBuffer's internal mutex is held. It receives a newly allocated *Line for every
// completed line, so it may keep the pointer.
type LineBufferCallback func(*Line)
+
+// NewLineBuffer creates a new LineBuffer with a given line length limit and callback.
+func NewLineBuffer(maxLineLength int, cb LineBufferCallback) *LineBuffer {
+ return &LineBuffer{
+ maxLineLength: maxLineLength,
+ cb: cb,
+ }
+}
+
+// writeLimited writes to the internal buffer, making sure that its size does not exceed the maxLineLength.
+func (l *LineBuffer) writeLimited(data []byte) {
+ l.length += len(data)
+ if l.cur.Len()+len(data) > l.maxLineLength {
+ data = data[:l.maxLineLength-l.cur.Len()]
+ }
+ l.cur.Write(data)
+}
+
+// comitLine calls the callback and resets the builder.
+func (l *LineBuffer) commitLine() {
+ l.cb(&Line{
+ Data: l.cur.String(),
+ OriginalLength: l.length,
+ })
+ l.cur.Reset()
+ l.length = 0
+}
+
+func (l *LineBuffer) Write(data []byte) (int, error) {
+ var pos = 0
+
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ if l.closed {
+ return 0, fmt.Errorf("closed")
+ }
+
+ for {
+ nextNewline := bytes.IndexRune(data[pos:], '\n')
+
+ // No newline in the data, write everything to the current line
+ if nextNewline == -1 {
+ l.writeLimited(data[pos:])
+ break
+ }
+
+ // Write this line and update position
+ l.writeLimited(data[pos : pos+nextNewline])
+ l.commitLine()
+ pos += nextNewline + 1
+
+ // Data ends with a newline, stop now without writing an empty line
+ if nextNewline == len(data)-1 {
+ break
+ }
+ }
+ return len(data), nil
+}
+
+// Close will emit any leftover data in the buffer to the callback. Subsequent calls to Write will fail. Subsequent calls to Close
+// will also fail.
+func (l *LineBuffer) Close() error {
+ if l.closed {
+ return fmt.Errorf("already closed")
+ }
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ l.closed = true
+ if l.length > 0 {
+ l.commitLine()
+ }
+ return nil
+}
diff --git a/metropolis/pkg/logbuffer/linebuffer_test.go b/metropolis/pkg/logbuffer/linebuffer_test.go
new file mode 100644
index 0000000..c821a4b
--- /dev/null
+++ b/metropolis/pkg/logbuffer/linebuffer_test.go
@@ -0,0 +1,75 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logbuffer
+
+import (
+ "fmt"
+ "testing"
+)
+
+func TestLineBuffer(t *testing.T) {
+ var lines []*Line
+ lb := NewLineBuffer(1024, func(l *Line) {
+ lines = append(lines, l)
+ })
+
+ compare := func(a []*Line, b ...string) string {
+ msg := fmt.Sprintf("want %v, got %v", a, b)
+ if len(a) != len(b) {
+ return msg
+ }
+ for i, _ := range a {
+ if a[i].String() != b[i] {
+ return msg
+ }
+ }
+ return ""
+ }
+
+ // Write some data.
+ fmt.Fprintf(lb, "foo ")
+ if diff := compare(lines); diff != "" {
+ t.Fatal(diff)
+ }
+ fmt.Fprintf(lb, "bar\n")
+ if diff := compare(lines, "foo bar"); diff != "" {
+ t.Fatal(diff)
+ }
+ fmt.Fprintf(lb, "baz")
+ if diff := compare(lines, "foo bar"); diff != "" {
+ t.Fatal(diff)
+ }
+ fmt.Fprintf(lb, " baz")
+ if diff := compare(lines, "foo bar"); diff != "" {
+ t.Fatal(diff)
+ }
+ // Close and expect flush.
+ if err := lb.Close(); err != nil {
+ t.Fatalf("Close: %v", err)
+ }
+ if diff := compare(lines, "foo bar", "baz baz"); diff != "" {
+ t.Fatal(diff)
+ }
+
+ // Check behaviour after close
+ if _, err := fmt.Fprintf(lb, "nope"); err == nil {
+ t.Fatalf("Write after Close: wanted error, got nil")
+ }
+ if err := lb.Close(); err == nil {
+ t.Fatalf("second Close: wanted error, got nil")
+ }
+}
diff --git a/metropolis/pkg/logbuffer/logbuffer.go b/metropolis/pkg/logbuffer/logbuffer.go
new file mode 100644
index 0000000..ce47816
--- /dev/null
+++ b/metropolis/pkg/logbuffer/logbuffer.go
@@ -0,0 +1,97 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package logbuffer implements a fixed-size in-memory ring buffer for line-separated logs.
+// It implements io.Writer and splits the data into lines. The lines are kept in a ring where the
+// oldest are overwritten once it's full. It allows retrieval of the last n lines. There is a built-in
+// line length limit to bound the memory usage at maxLineLength * size.
+package logbuffer
+
+import (
+ "sync"
+)
+
// LogBuffer implements a fixed-size in-memory ring buffer for line-separated logs.
//
// Writes go through the embedded *LineBuffer, whose Write and Close methods are promoted. Every completed line is
// stored in content by lineCallback.
type LogBuffer struct {
	// mu guards content and length.
	mu sync.RWMutex
	// content is the ring storage; its length is the ring size given to New.
	content []Line
	// length counts all lines ever stored; length % len(content) is the next slot to be overwritten.
	length int
	*LineBuffer
}
+
+// New creates a new LogBuffer with a given ringbuffer size and maximum line length.
+func New(size, maxLineLength int) *LogBuffer {
+ lb := &LogBuffer{
+ content: make([]Line, size),
+ }
+ lb.LineBuffer = NewLineBuffer(maxLineLength, lb.lineCallback)
+ return lb
+}
+
+func (b *LogBuffer) lineCallback(line *Line) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ b.content[b.length%len(b.content)] = *line
+ b.length++
+}
+
+// capToContentLength caps the number of requested lines to what is actually available
+func (b *LogBuffer) capToContentLength(n int) int {
+ // If there aren't enough lines to read, reduce the request size
+ if n > b.length {
+ n = b.length
+ }
+ // If there isn't enough ringbuffer space, reduce the request size
+ if n > len(b.content) {
+ n = len(b.content)
+ }
+ return n
+}
+
+// ReadLines reads the last n lines from the buffer in chronological order. If n is bigger than the
+// ring buffer or the number of available lines only the number of stored lines are returned.
+func (b *LogBuffer) ReadLines(n int) []Line {
+ b.mu.RLock()
+ defer b.mu.RUnlock()
+
+ n = b.capToContentLength(n)
+
+ // Copy references out to keep them around
+ outArray := make([]Line, n)
+ for i := 1; i <= n; i++ {
+ outArray[n-i] = b.content[(b.length-i)%len(b.content)]
+ }
+ return outArray
+}
+
+// ReadLinesTruncated works exactly the same as ReadLines, but adds an ellipsis at the end of every
+// line that was truncated because it was over MaxLineLength
+func (b *LogBuffer) ReadLinesTruncated(n int, ellipsis string) []string {
+ b.mu.RLock()
+ defer b.mu.RUnlock()
+ // This does not use ReadLines() to prevent excessive reference copying and associated GC pressure
+ // since it could process a lot of lines.
+
+ n = b.capToContentLength(n)
+
+ outArray := make([]string, n)
+ for i := 1; i <= n; i++ {
+ line := b.content[(b.length-i)%len(b.content)]
+ outArray[n-i] = line.String()
+ }
+ return outArray
+}
diff --git a/metropolis/pkg/logbuffer/logbuffer_test.go b/metropolis/pkg/logbuffer/logbuffer_test.go
new file mode 100644
index 0000000..c38d7a6
--- /dev/null
+++ b/metropolis/pkg/logbuffer/logbuffer_test.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logbuffer
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestSingleLine(t *testing.T) {
+ buf := New(1, 16000)
+ buf.Write([]byte("Hello World\n"))
+ out := buf.ReadLines(1)
+ require.Len(t, out, 1, "Invalid number of lines read")
+ require.Equal(t, "Hello World", out[0].Data, "Read bad log line")
+ require.Equal(t, 11, out[0].OriginalLength, "Invalid line length")
+}
+
+func TestPartialWritesAndReads(t *testing.T) {
+ buf := New(2, 16000)
+ buf.Write([]byte("Hello "))
+ buf.Write([]byte("World\nTest "))
+ buf.Write([]byte("2\n"))
+
+ out := buf.ReadLines(1)
+ require.Len(t, out, 1, "Invalid number of lines for partial read")
+ require.Equal(t, "Test 2", out[0].Data, "Read bad log line")
+
+ out2 := buf.ReadLines(2)
+ require.Len(t, out2, 2, "Invalid number of lines read")
+ require.Equal(t, "Hello World", out2[0].Data, "Read bad log line")
+ require.Equal(t, "Test 2", out2[1].Data, "Read bad log line")
+}
+
+func TestBufferOverwrite(t *testing.T) {
+ buf := New(3, 16000)
+ buf.Write([]byte("Test1\nTest2\nTest3\nTest4\n"))
+
+ out := buf.ReadLines(3)
+ require.Equal(t, out[0].Data, "Test2", "Read bad log line")
+ require.Equal(t, out[1].Data, "Test3", "Read bad log line")
+ require.Equal(t, out[2].Data, "Test4", "Overwritten data is invalid")
+}
+
+func TestTooLargeRequests(t *testing.T) {
+ buf := New(1, 16000)
+ outEmpty := buf.ReadLines(1)
+ require.Len(t, outEmpty, 0, "Returned more data than there is")
+
+ buf.Write([]byte("1\n2\n"))
+ out := buf.ReadLines(2)
+ require.Len(t, out, 1, "Returned more data than the ring buffer can hold")
+}
+
+func TestSpecialCases(t *testing.T) {
+ buf := New(2, 16000)
+ buf.Write([]byte("Test1"))
+ buf.Write([]byte("\nTest2\n"))
+ out := buf.ReadLines(2)
+ require.Len(t, out, 2, "Too many lines written")
+ require.Equal(t, out[0].Data, "Test1", "Read bad log line")
+ require.Equal(t, out[1].Data, "Test2", "Read bad log line")
+}
+
+func TestLineLengthLimit(t *testing.T) {
+ buf := New(2, 6)
+
+ testStr := "Just Testing"
+
+ buf.Write([]byte(testStr + "\nShort\n"))
+
+ out := buf.ReadLines(2)
+ require.Equal(t, len(testStr), out[0].OriginalLength, "Line is over length limit")
+ require.Equal(t, "Just T", out[0].Data, "Log line not properly truncated")
+
+ out2 := buf.ReadLinesTruncated(2, "...")
+ require.Equal(t, out2[0], "Just T...", "Line is over length limit")
+ require.Equal(t, out2[1], "Short", "Truncated small enough line")
+}
diff --git a/metropolis/pkg/logtree/BUILD.bazel b/metropolis/pkg/logtree/BUILD.bazel
new file mode 100644
index 0000000..bb07e99
--- /dev/null
+++ b/metropolis/pkg/logtree/BUILD.bazel
@@ -0,0 +1,32 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

# Tree-shaped logger (see doc.go): journal storage, leveled and raw publishers,
# and the access API. Raw logging builds on logbuffer's line splitting.
go_library(
    name = "go_default_library",
    srcs = [
        "doc.go",
        "journal.go",
        "journal_entry.go",
        "journal_subscriber.go",
        "leveled.go",
        "leveled_payload.go",
        "logtree.go",
        "logtree_access.go",
        "logtree_entry.go",
        "logtree_publisher.go",
    ],
    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/logtree",
    visibility = ["//visibility:public"],
    deps = [
        "//metropolis/pkg/logbuffer:go_default_library",
        "//metropolis/proto/api:go_default_library",
    ],
)

# Journal and tree-level tests, compiled into the library package via embed.
go_test(
    name = "go_default_test",
    srcs = [
        "journal_test.go",
        "logtree_test.go",
    ],
    embed = [":go_default_library"],
)
diff --git a/metropolis/pkg/logtree/doc.go b/metropolis/pkg/logtree/doc.go
new file mode 100644
index 0000000..ab3c537
--- /dev/null
+++ b/metropolis/pkg/logtree/doc.go
@@ -0,0 +1,116 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+Package logtree implements a tree-shaped logger for debug events. It provides log publishers (ie. Go code) with a
+glog-like API and io.Writer API, with loggers placed in a hierarchical structure defined by a dot-delimited path
+(called a DN, short for Distinguished Name).
+
+ tree.MustLeveledFor("foo.bar.baz").Warningf("Houston, we have a problem: %v", err)
+ fmt.Fprintf(tree.MustRawFor("foo.bar.baz"), "some\nunstructured\ndata\n")
+
+Logs in this context are unstructured, operational and developer-centric human readable text messages presented as lines
+of text to consumers, with some attached metadata. Logtree does not deal with 'structured' logs as some parts of the
+industry do, and instead defers any machine-readable logs to either be handled by metrics systems like Prometheus or
+event sourcing systems like Kafka.
+
+Tree Structure
+
+As an example, consider an application that produces logs with the following DNs:
+
+ listener.http
+ listener.grpc
+ svc
+ svc.cache
+ svc.cache.gc
+
+This would correspond to a tree as follows:
+
+ .------.
+ | "" |
+ | (root) |
+ '------'
+ .----------------' '------.
+ .--------------. .---------------.
+ | svc | | listener |
+ '--------------' '---------------'
+ | .----' '----.
+ .--------------. .---------------. .---------------.
+ | svc.cache | | listener.http | | listener.grpc |
+ '--------------' '---------------' '---------------'
+ |
+ .--------------.
+ | svc.cache.gc |
+ '--------------'
+
+In this setup, every DN acts as a separate logging target, each with its own retention policy and quota. Logging to a DN
+under foo.bar does NOT automatically log to foo - all tree mechanisms are applied on log access by consumers. Loggers
+are automatically created on first use, and importantly, can be created at any time, and will automatically be created
+if a sub-DN is created that requires a parent DN to exist first. Note, for instance, that a `listener` logging node was
+created even though the example application only logged to `listener.http` and `listener.grpc`.
+
+An implicit root node is always present in the tree, accessed by DN "" (an empty string). All other logger nodes are
+children (or transitive children) of the root node.
+
Log consumers (application code that reads the logs and passes them on to operators, or ships them off for aggregation
in other systems) can select subtrees of logs for readout. In the example tree, a consumer could choose to read either
all logs of the entire tree, just a single DN (like svc), or a subtree (like everything under listener, ie. messages
emitted to listener.http and listener.grpc).
+
+Leveled Log Producer API
+
+As part of the glog-like logging API available to producers, the following metadata is attached to emitted logs in
+addition to the DN of the logger to which the log entry was emitted:
+
+ - timestamp at which the entry was emitted
+ - a severity level (one of FATAL, ERROR, WARN or INFO)
+ - a source of the message (file name and line number)
+
+In addition, the logger mechanism supports a variable verbosity level (so-called 'V-logging') that can be set at every
+node of the tree. For more information about the producer-facing logging API, see the documentation of the LeveledLogger
+interface, which is the main interface exposed to log producers.
+
+If the submitted message contains newlines, it will be split accordingly into a single log entry that contains multiple
+string lines. This allows for log producers to submit long, multi-line messages that are guaranteed to be non-interleaved
+with other entries, and allows for access API consumers to maintain semantic linking between multiple lines being emitted
+as a single atomic entry.
+
+Raw Log Producer API
+
+In addition to leveled, glog-like logging, LogTree supports 'raw logging'. This is implemented as an io.Writer that will
+split incoming bytes into newline-delimited lines, and log them into that logtree's DN. This mechanism is primarily
+intended to support storage of unstructured log data from external processes - for example binaries running with redirected
+stdout/stderr.
+
+Log Access API
+
+The Log Access API is mostly exposed via a single function on the LogTree struct: Read. It allows access to log entries
+that have been already buffered inside LogTree and to subscribe to receive future entries over a channel. As outlined
+earlier, any access can specify whether it is just interested in a single logger (addressed by DN), or a subtree of
+loggers.
+
Due to the current implementation of the logtree, subtree accesses of backlogged data are significantly slower than
accessing data of just one DN, or the whole tree (as every subtree backlog access performs a scan on all logged data).
+Thus, log consumers should be aware that it is much better to stream and buffer logs specific to some long-standing
+logging request on their own, rather than repeatedly perform reads of a subtree backlog.
+
+The data returned from the log access API is a LogEntry, which itself can contain either a raw logging entry, or a leveled
+logging entry. Helper functions are available on LogEntry that allow canonical string representations to be returned, for
+easy use in consuming tools/interfaces. Alternatively, the consumer can itself access the internal raw/leveled entries and
+print them according to their own preferred format.
+
+*/
+package logtree
diff --git a/metropolis/pkg/logtree/journal.go b/metropolis/pkg/logtree/journal.go
new file mode 100644
index 0000000..78c55a1
--- /dev/null
+++ b/metropolis/pkg/logtree/journal.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "errors"
+ "strings"
+ "sync"
+)
+
+// DN is the Distinguished Name, a dot-delimited path used to address loggers within a LogTree. For example, "foo.bar"
+// designates the 'bar' logger node under the 'foo' logger node under the root node of the logger. An empty string is
+// the root node of the tree.
+type DN string
+
+// ErrInvalidDN is returned by DN.Path when the DN contains empty parts.
+var ErrInvalidDN = errors.New("invalid DN")
+
+// Path returns the parts of a DN, ie. all the elements of its dot-delimited path, in order. For the root node (the
+// empty DN) a nil list and a nil error are returned. ErrInvalidDN is returned if the DN contains empty parts, eg.
+// `foo..bar`, `.foo` or `foo.`.
+func (d DN) Path() ([]string, error) {
+	if len(d) == 0 {
+		return nil, nil
+	}
+	elems := strings.Split(string(d), ".")
+	for i := range elems {
+		if len(elems[i]) == 0 {
+			return nil, ErrInvalidDN
+		}
+	}
+	return elems, nil
+}
+
+// journal is the main log recording structure of logtree. It manages linked lists containing the actual log entries,
+// and implements scans across them. It does not understand the hierarchical nature of logtree, and instead sees all
+// entries as part of a global linked list and a local linked list for a given DN.
+//
+// The global linked list is represented by the head/tail pointers in journal and nextGlobal/prevGlobal pointers in
+// entries. The local linked lists are represented by heads[DN]/tails[DN] pointers in journal and nextLocal/prevLocal
+// pointers in entries. New entries are appended at the head side, so the tail side holds the oldest entry, and
+// following next pointers walks from older to newer entries:
+//
+//        .------------.        .------------.        .------------.
+//        | dn: A.B    |        | dn: Z      |        | dn: A.B    |
+//        | time: 1    |        | time: 2    |        | time: 3    |
+//        |------------|        |------------|        |------------|
+//        | nextGlobal :------->| nextGlobal :------->| nextGlobal :--> nil
+// nil <--: prevGlobal |<-------: prevGlobal |<-------: prevGlobal |
+//        |------------|        |------------|        |------------|
+//        | nextLocal  :----.   | nextLocal  |  .---->| nextLocal  :--> nil
+// nil <--: prevLocal  |<-. |   | prevLocal  |  | .---: prevLocal  |
+//        '------------'  | |   '------------'  | |   '------------'
+//                        | '-------------------' |
+//                        '-----------------------'
+//
+//              ^                     ^                     ^
+//              |                     |                     |
+//           ( tail )            ( tails[Z] )            ( head )
+//         ( tails[A.B] )        ( heads[Z] )          ( heads[A.B] )
+//
+// Entry Z is the only entry at its DN, so its nextLocal and prevLocal pointers are both nil.
+type journal struct {
+	// mu locks the rest of the structure. It must be taken during any operation on the journal.
+	mu sync.RWMutex
+
+	// tail is the side of the global linked list that contains the oldest log entry. It can be nil when no log entry
+	// has yet been pushed. The global linked list contains all log entries pushed to the journal, minus those that
+	// were evicted by quota enforcement.
+	tail *entry
+	// head is the side of the global linked list that contains the newest log entry, ie. the one that has been pushed
+	// the most recently. It can be nil when no log entry has yet been pushed.
+	head *entry
+
+	// tails are the tail sides of a local linked list for a given DN, ie. the sides that contain the oldest entry.
+	// They are nil (or absent from the map) if there are no log entries for that DN.
+	tails map[DN]*entry
+	// heads are the head sides of a local linked list for a given DN, ie. the sides that contain the newest entry.
+	// They are nil (or absent from the map) if there are no log entries for that DN.
+	heads map[DN]*entry
+
+	// quota is a map from DN to quota structure, representing the quota policy of a particular DN-designated logger.
+	quota map[DN]*quota
+
+	// subscribers are observers of logs. New log entries get emitted to channels present in the subscriber structure,
+	// after filtering them through subscriber-provided filters (eg. to limit events to subtrees that interest that
+	// particular subscriber).
+	subscribers []*subscriber
+}
+
+// newJournal returns an empty journal with its per-DN maps allocated. Every journal is independent from all others,
+// and as a consequence so is every LogTree.
+func newJournal() *journal {
+	j := new(journal)
+	j.tails = make(map[DN]*entry)
+	j.heads = make(map[DN]*entry)
+	j.quota = make(map[DN]*quota)
+	return j
+}
+
+// filter is a predicate over journal entries: it returns true if a log subscriber or reader is interested in the
+// given entry.
+type filter func(*entry) bool
+
+// filterAll returns a filter that accepts every log entry.
+func filterAll() filter {
+	accept := func(*entry) bool {
+		return true
+	}
+	return accept
+}
+
+// filterExact returns a filter that only accepts log entries recorded at exactly the given DN. Do not combine it with
+// journal.scanEntries; journal.getEntries yields the same result much faster by walking the DN's local linked list.
+func filterExact(dn DN) filter {
+	var f filter = func(e *entry) bool {
+		return dn == e.origin
+	}
+	return f
+}
+
+// filterSubtree returns a filter that accepts all log entries at a given DN and sub-DNs. For example, filterSubtree at
+// "foo.bar" would allow entries at "foo.bar", "foo.bar.baz", but not "foo" or "foo.barr". An empty root accepts
+// everything.
+//
+// An entry matches if its DN equals root, or starts with root followed by a dot. This is equivalent to comparing the
+// dot-delimited parts of both DNs, but does not allocate for every checked entry. That matters because the filter is
+// evaluated against every entry in the journal during journal.scanEntries.
+func filterSubtree(root DN) filter {
+	if root == "" {
+		return filterAll()
+	}
+
+	// Computed once, so that the returned closure does not allocate per call.
+	prefix := string(root) + "."
+	return func(e *entry) bool {
+		origin := string(e.origin)
+		return origin == string(root) || strings.HasPrefix(origin, prefix)
+	}
+}
+
+// filterSeverity returns a filter that accepts leveled log entries logged at the given severity level or above (see
+// Severity.AtLeast). Entries without a leveled payload, ie. raw entries, are always rejected.
+func filterSeverity(atLeast Severity) filter {
+	return func(e *entry) bool {
+		if e.leveled == nil {
+			return false
+		}
+		return e.leveled.severity.AtLeast(atLeast)
+	}
+}
+
+// filterOnlyRaw accepts only entries that carry a raw payload, ie. entries emitted by raw logging.
+func filterOnlyRaw(e *entry) bool {
+	return e.raw != nil
+}
+
+// filterOnlyLeveled accepts only entries that carry a leveled payload, ie. entries emitted by leveled logging.
+func filterOnlyLeveled(e *entry) bool {
+	return e.leveled != nil
+}
+
+// scanEntries walks the whole global entry list, from its tail (oldest entry) towards its head (newest entry), and
+// returns every entry accepted by all of the given filters. When retrieving entries for an exact DN, getEntries should
+// be used instead, as it only walks that DN's local linked list.
+// journal.mu must be taken at R or RW level when calling this function.
+func (j *journal) scanEntries(filters ...filter) (res []*entry) {
+next:
+	for cur := j.tail; cur != nil; cur = cur.nextGlobal {
+		for _, f := range filters {
+			if !f(cur) {
+				continue next
+			}
+		}
+		res = append(res, cur)
+	}
+	return res
+}
+
+// getEntries returns all entries at a given DN, from oldest to newest. This is faster than a
+// scanEntries(filterExact), as it uses the special local linked list pointers to traverse the journal. Additional
+// filters can be passed to further limit the entries returned, but a scan through this DN's local linked list will be
+// performed regardless.
+// journal.mu must be taken at R or RW level when calling this function.
+func (j *journal) getEntries(exact DN, filters ...filter) (res []*entry) {
+	// accepted reports whether e passes every filter, stopping at the first one that rejects it.
+	accepted := func(e *entry) bool {
+		for _, f := range filters {
+			if !f(e) {
+				return false
+			}
+		}
+		return true
+	}
+
+	for cur := j.tails[exact]; cur != nil; cur = cur.nextLocal {
+		if accepted(cur) {
+			res = append(res, cur)
+		}
+	}
+	return res
+}
diff --git a/metropolis/pkg/logtree/journal_entry.go b/metropolis/pkg/logtree/journal_entry.go
new file mode 100644
index 0000000..2a60aa1
--- /dev/null
+++ b/metropolis/pkg/logtree/journal_entry.go
@@ -0,0 +1,169 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import "git.monogon.dev/source/nexantic.git/metropolis/pkg/logbuffer"
+
+// entry is a journal entry, representing a single log event (encompassed in a Payload) at a given DN.
+// See the journal struct for more information about the global/local linked lists.
+type entry struct {
+	// origin is the DN at which the log entry was recorded, or conversely, in which DN it will be available at.
+	origin DN
+	// journal is the parent journal of this entry. An entry can belong only to a single journal. This pointer is used
+	// to mutate the journal's head/tail pointers when unlinking an entry.
+	journal *journal
+	// leveled is the leveled log entry for this entry, if this log entry was emitted by leveled logging. Otherwise it
+	// is nil.
+	leveled *LeveledPayload
+	// raw is the raw log entry for this entry, if this log entry was emitted by raw logging. Otherwise it is nil.
+	raw *logbuffer.Line
+
+	// prevGlobal is the previous (older) entry in the global linked list, or nil if this entry is the oldest entry in
+	// the global linked list.
+	prevGlobal *entry
+	// nextGlobal is the next (newer) entry in the global linked list, or nil if this entry is the newest entry in the
+	// global linked list.
+	nextGlobal *entry
+
+	// prevLocal is the previous (older) entry in this entry DN's local linked list, or nil if this entry is the oldest
+	// entry in this local linked list.
+	prevLocal *entry
+	// nextLocal is the next (newer) entry in this entry DN's local linked list, or nil if this entry is the newest
+	// entry in this local linked list.
+	nextLocal *entry
+
+	// seqLocal is a counter within a local linked list that increases by one each time a new log entry is added. It is
+	// used to quickly establish local linked list sizes (by subtracting the seqLocal values of both ends). This setup
+	// allows for O(1) length calculation for local linked lists as long as entries are only unlinked from the head or
+	// tail (which is the case in the current implementation).
+	seqLocal uint64
+}
+
+// external converts this entry into its public form, a LogEntry, which is what library consumers get to see. Only
+// the DN and the leveled/raw payloads are carried over; journal bookkeeping (parent journal, linked list pointers,
+// sequence numbers) stays private.
+func (e *entry) external() *LogEntry {
+	le := new(LogEntry)
+	le.DN = e.origin
+	le.Leveled = e.leveled
+	le.Raw = e.raw
+	return le
+}
+
+// unlink detaches this entry from both the global linked list and its DN's local linked list. The neighbouring
+// entries are bridged over it, and the journal's head/tail (and heads/tails for the DN) are moved if this entry was at
+// either end. The entry's own link pointers are left untouched.
+// journal.mu must be taken as RW
+func (e *entry) unlink() {
+	j := e.journal
+
+	// Global linked list: bridge the neighbours over e.
+	if prev := e.prevGlobal; prev != nil {
+		prev.nextGlobal = e.nextGlobal
+	}
+	if next := e.nextGlobal; next != nil {
+		next.prevGlobal = e.prevGlobal
+	}
+	// If e was at either end of the global list, move that end inwards.
+	if j.head == e {
+		j.head = e.prevGlobal
+	}
+	if j.tail == e {
+		j.tail = e.nextGlobal
+	}
+
+	// Same for the local linked list of e's DN.
+	dn := e.origin
+	if prev := e.prevLocal; prev != nil {
+		prev.nextLocal = e.nextLocal
+	}
+	if next := e.nextLocal; next != nil {
+		next.prevLocal = e.prevLocal
+	}
+	if j.heads[dn] == e {
+		j.heads[dn] = e.prevLocal
+	}
+	if j.tails[dn] == e {
+		j.tails[dn] = e.nextLocal
+	}
+}
+
+// quota describes the quota policy for logging at a given DN. In this file, quotas are created on demand by
+// journal.append, with a default max of 8192 entries, the first time an entry is appended at a DN.
+type quota struct {
+	// origin is the exact DN that this quota applies to.
+	origin DN
+	// max is the maximum count of log entries permitted for this DN - ie, the maximum size of the local linked list.
+	// When appending pushes the local linked list over max, journal.append evicts entries from its oldest (tail) side.
+	max uint64
+}
+
+// append adds an entry at the head (newest side) of the global and local linked lists. If the entry's DN has no
+// quota yet, a default one is created. The DN's local linked list is then trimmed from its tail (oldest side) until it
+// satisfies that quota, and trimmed entries are unlinked from both the global and local linked lists.
+// It takes journal.mu for writing.
+func (j *journal) append(e *entry) {
+	// defaultQuotaMax is the maximum number of entries retained per DN when a quota is created on demand below.
+	const defaultQuotaMax = 8192
+
+	j.mu.Lock()
+	defer j.mu.Unlock()
+
+	e.journal = j
+
+	// Insert at head in global linked list, set pointers.
+	e.nextGlobal = nil
+	e.prevGlobal = j.head
+	if j.head != nil {
+		j.head.nextGlobal = e
+	}
+	j.head = e
+	if j.tail == nil {
+		j.tail = e
+	}
+
+	// Look up the quota for this DN, creating it if necessary. The local is named q so that it does not shadow the
+	// quota type.
+	q, ok := j.quota[e.origin]
+	if !ok {
+		q = &quota{origin: e.origin, max: defaultQuotaMax}
+		j.quota[e.origin] = q
+	}
+
+	// Insert at head in local linked list, calculate seqLocal, set pointers.
+	e.nextLocal = nil
+	e.prevLocal = j.heads[e.origin]
+	if e.prevLocal != nil {
+		e.prevLocal.nextLocal = e
+		e.seqLocal = e.prevLocal.seqLocal + 1
+	} else {
+		e.seqLocal = 0
+	}
+	j.heads[e.origin] = e
+	if j.tails[e.origin] == nil {
+		j.tails[e.origin] = e
+	}
+
+	// Apply quota to the local linked list that this entry got inserted to, ie. remove elements in excess of the
+	// q.max count. e is now the head of that list, so its seqLocal is the head's.
+	count := (e.seqLocal - j.tails[e.origin].seqLocal) + 1
+	// Guard separately: count-q.max is unsigned and would wrap around if count <= q.max.
+	if count > q.max {
+		// Pop elements off the tail of the local linked list until the quota is no longer violated. cur becoming nil
+		// shouldn't happen if q.max >= 1.
+		for left, cur := count-q.max, j.tails[e.origin]; left > 0 && cur != nil; left-- {
+			next := cur.nextLocal
+			// Unlinking the entry unlinks it from both the global and local linked lists.
+			cur.unlink()
+			cur = next
+		}
+	}
+}
diff --git a/metropolis/pkg/logtree/journal_subscriber.go b/metropolis/pkg/logtree/journal_subscriber.go
new file mode 100644
index 0000000..e6c7c62
--- /dev/null
+++ b/metropolis/pkg/logtree/journal_subscriber.go
@@ -0,0 +1,69 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "sync/atomic"
+)
+
+// subscriber is an observer for new entries that are appended to the journal.
+type subscriber struct {
+	// missed is the number of entries missed by the subscriber by not receiving from dataC fast enough. It is updated
+	// with sync/atomic. It is deliberately the first field: 64-bit atomic operations require 64-bit alignment, which on
+	// 32-bit platforms is only guaranteed for the first word of an allocated struct.
+	missed uint64
+	// filters that entries need to pass through in order to be sent to the subscriber.
+	filters []filter
+	// dataC is the channel to which entries that pass filters will be sent. Sends are non-blocking: if dataC cannot
+	// accept an entry at the time it is emitted, the entry is dropped for this subscriber and counted in missed. The
+	// channel should therefore be drained regularly to avoid losing entries. The journal closes it on the first
+	// notification after doneC has been closed.
+	dataC chan *LogEntry
+	// doneC is a channel that is closed once the subscriber wishes to stop receiving notifications.
+	doneC chan struct{}
+}
+
+// subscribe registers sub with the journal, so that subsequent notifications are considered for delivery to it.
+// journal.mu must be taken in W mode.
+func (j *journal) subscribe(sub *subscriber) {
+	subs := j.subscribers
+	j.subscribers = append(subs, sub)
+}
+
+// notify sends an entry to all subscribers that wish to receive it, ie. those whose filters all accept the entry.
+// Subscribers whose doneC has been closed are removed from the journal, and their dataC is closed. Delivery never
+// blocks: if a subscriber's dataC cannot accept the entry immediately, the entry is dropped for that subscriber and
+// its missed counter is incremented.
+// It takes journal.mu for writing.
+func (j *journal) notify(e *entry) {
+	j.mu.Lock()
+	defer j.mu.Unlock()
+
+	newSub := make([]*subscriber, 0, len(j.subscribers))
+nextSub:
+	for _, sub := range j.subscribers {
+		select {
+		case <-sub.doneC:
+			// The subscriber is gone: close its data channel and drop it from the list.
+			close(sub.dataC)
+			continue
+		default:
+			newSub = append(newSub, sub)
+		}
+
+		// Skip this subscriber unless every filter accepts the entry. The continue has to target the outer loop: a
+		// bare continue would only advance the filter loop, which would silently ignore the filters.
+		for _, filter := range sub.filters {
+			if !filter(e) {
+				continue nextSub
+			}
+		}
+		select {
+		case sub.dataC <- e.external():
+		default:
+			atomic.AddUint64(&sub.missed, 1)
+		}
+	}
+	j.subscribers = newSub
+}
diff --git a/metropolis/pkg/logtree/journal_test.go b/metropolis/pkg/logtree/journal_test.go
new file mode 100644
index 0000000..474748a
--- /dev/null
+++ b/metropolis/pkg/logtree/journal_test.go
@@ -0,0 +1,148 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+)
+
+// testPayload builds an INFO-level LeveledPayload carrying the single message msg, timestamped now, with a fixed fake
+// call site (main.go:1337).
+func testPayload(msg string) *LeveledPayload {
+	p := new(LeveledPayload)
+	p.messages = []string{msg}
+	p.timestamp = time.Now()
+	p.severity = INFO
+	p.file = "main.go"
+	p.line = 1337
+	return p
+}
+
+func TestJournalRetention(t *testing.T) {
+ j := newJournal()
+
+ for i := 0; i < 9000; i += 1 {
+ e := &entry{
+ origin: "main",
+ leveled: testPayload(fmt.Sprintf("test %d", i)),
+ }
+ j.append(e)
+ }
+
+ entries := j.getEntries("main")
+ if want, got := 8192, len(entries); want != got {
+ t.Fatalf("wanted %d entries, got %d", want, got)
+ }
+ for i, entry := range entries {
+ want := fmt.Sprintf("test %d", (9000-8192)+i)
+ got := strings.Join(entry.leveled.messages, "\n")
+ if want != got {
+ t.Fatalf("wanted entry %q, got %q", want, got)
+ }
+ }
+}
+
+func TestJournalQuota(t *testing.T) {
+ j := newJournal()
+
+ for i := 0; i < 9000; i += 1 {
+ j.append(&entry{
+ origin: "chatty",
+ leveled: testPayload(fmt.Sprintf("chatty %d", i)),
+ })
+ if i%10 == 0 {
+ j.append(&entry{
+ origin: "solemn",
+ leveled: testPayload(fmt.Sprintf("solemn %d", i)),
+ })
+ }
+ }
+
+ entries := j.getEntries("chatty")
+ if want, got := 8192, len(entries); want != got {
+ t.Fatalf("wanted %d chatty entries, got %d", want, got)
+ }
+ entries = j.getEntries("solemn")
+ if want, got := 900, len(entries); want != got {
+ t.Fatalf("wanted %d solemn entries, got %d", want, got)
+ }
+ entries = j.getEntries("absent")
+ if want, got := 0, len(entries); want != got {
+ t.Fatalf("wanted %d absent entries, got %d", want, got)
+ }
+
+ entries = j.scanEntries(filterAll())
+ if want, got := 8192+900, len(entries); want != got {
+ t.Fatalf("wanted %d total entries, got %d", want, got)
+ }
+ setMessages := make(map[string]bool)
+ for _, entry := range entries {
+ setMessages[strings.Join(entry.leveled.messages, "\n")] = true
+ }
+
+ for i := 0; i < 900; i += 1 {
+ want := fmt.Sprintf("solemn %d", i*10)
+ if !setMessages[want] {
+ t.Fatalf("could not find entry %q in journal", want)
+ }
+ }
+ for i := 0; i < 8192; i += 1 {
+ want := fmt.Sprintf("chatty %d", i+(9000-8192))
+ if !setMessages[want] {
+ t.Fatalf("could not find entry %q in journal", want)
+ }
+ }
+}
+
+// TestJournalSubtree checks that filterSubtree selects exactly a DN and its sub-DNs, and nothing else.
+func TestJournalSubtree(t *testing.T) {
+	j := newJournal()
+	j.append(&entry{origin: "a", leveled: testPayload("a")})
+	j.append(&entry{origin: "a.b", leveled: testPayload("a.b")})
+	j.append(&entry{origin: "a.b.c", leveled: testPayload("a.b.c")})
+	j.append(&entry{origin: "a.b.d", leveled: testPayload("a.b.d")})
+	j.append(&entry{origin: "e.f", leveled: testPayload("e.f")})
+	j.append(&entry{origin: "e.g", leveled: testPayload("e.g")})
+
+	// expect checks that scanning the journal with filter f returns exactly the entries with the given messages, in
+	// any order. It returns an empty string on success, or a description of the first mismatch found. All fixture
+	// messages are unique, so presence of every wanted message plus equal counts means an exact match.
+	expect := func(f filter, msgs ...string) string {
+		res := j.scanEntries(f)
+		set := make(map[string]bool)
+		for _, e := range res {
+			set[strings.Join(e.leveled.messages, "\n")] = true
+		}
+
+		for _, want := range msgs {
+			if !set[want] {
+				return fmt.Sprintf("missing entry %q", want)
+			}
+		}
+		// Also make sure nothing else (eg. a sibling or parent DN) was returned.
+		if len(res) != len(msgs) {
+			return fmt.Sprintf("wanted %d entries, got %d", len(msgs), len(res))
+		}
+		return ""
+	}
+
+	if res := expect(filterAll(), "a", "a.b", "a.b.c", "a.b.d", "e.f", "e.g"); res != "" {
+		t.Fatalf("All: %s", res)
+	}
+	if res := expect(filterSubtree("a"), "a", "a.b", "a.b.c", "a.b.d"); res != "" {
+		t.Fatalf("Subtree(a): %s", res)
+	}
+	if res := expect(filterSubtree("a.b"), "a.b", "a.b.c", "a.b.d"); res != "" {
+		t.Fatalf("Subtree(a.b): %s", res)
+	}
+	if res := expect(filterSubtree("e"), "e.f", "e.g"); res != "" {
+		t.Fatalf("Subtree(e): %s", res)
+	}
+	if res := expect(filterSubtree("a.b.c"), "a.b.c"); res != "" {
+		t.Fatalf("Subtree(a.b.c): %s", res)
+	}
+	if res := expect(filterSubtree("x")); res != "" {
+		t.Fatalf("Subtree(x): %s", res)
+	}
+}
diff --git a/metropolis/pkg/logtree/leveled.go b/metropolis/pkg/logtree/leveled.go
new file mode 100644
index 0000000..c24357e
--- /dev/null
+++ b/metropolis/pkg/logtree/leveled.go
@@ -0,0 +1,144 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+// LeveledLogger is a generic interface for glog-style logging. There are four hardcoded log severities, in increasing
+// order: INFO, WARNING, ERROR, FATAL. An entry logged at a certain severity level is delivered not only to consumers
+// expecting data at that severity level, but also to consumers expecting data at any lower severity level. For
+// example, an ERROR log will also be passed to consumers looking at INFO or WARNING logs.
+type LeveledLogger interface {
+	// Info logs at the INFO severity. Arguments are handled in the manner of fmt.Print, a terminating newline is added
+	// if missing.
+	Info(args ...interface{})
+	// Infof logs at the INFO severity. Arguments are handled in the manner of fmt.Printf, a terminating newline is
+	// added if missing.
+	Infof(format string, args ...interface{})
+
+	// Warning logs at the WARNING severity. Arguments are handled in the manner of fmt.Print, a terminating newline is
+	// added if missing.
+	Warning(args ...interface{})
+	// Warningf logs at the WARNING severity. Arguments are handled in the manner of fmt.Printf, a terminating newline
+	// is added if missing.
+	Warningf(format string, args ...interface{})
+
+	// Error logs at the ERROR severity. Arguments are handled in the manner of fmt.Print, a terminating newline is
+	// added if missing.
+	Error(args ...interface{})
+	// Errorf logs at the ERROR severity. Arguments are handled in the manner of fmt.Printf, a terminating newline is
+	// added if missing.
+	Errorf(format string, args ...interface{})
+
+	// Fatal logs at the FATAL severity and aborts the current program. Arguments are handled in the manner of
+	// fmt.Print, a terminating newline is added if missing.
+	Fatal(args ...interface{})
+	// Fatalf logs at the FATAL severity and aborts the current program. Arguments are handled in the manner of
+	// fmt.Printf, a terminating newline is added if missing.
+	Fatalf(format string, args ...interface{})
+
+	// V returns a VerboseLeveledLogger at a given verbosity level. These verbosity levels can be dynamically set and
+	// unset on a package-granular level by consumers of the LeveledLogger logs. The returned value represents whether
+	// logging at the given verbosity level was active at that time, and as such should not be a long-lived object
+	// in programs.
+	// This construct is further referred to as 'V-logs'.
+	V(level VerbosityLevel) VerboseLeveledLogger
+}
+
+// VerbosityLevel is a verbosity level defined for V-logs. This can be changed programmatically per Go package. When
+// logging at a given VerbosityLevel V, the current level must be equal to or higher than V for the logs to be
+// recorded. Conversely, enabling V-logging at a VerbosityLevel V also enables all V-logging at lower levels
+// [Int32Min .. (V-1)].
+type VerbosityLevel int32
+
+// VerboseLeveledLogger is returned by LeveledLogger.V. Its Info/Infof calls behave like the corresponding
+// LeveledLogger calls, but only if V-logging at the requested VerbosityLevel was enabled.
+type VerboseLeveledLogger interface {
+	// Enabled returns whether this level was enabled. If not enabled, all logging into this logger will be discarded
+	// immediately.
+	// Thus, Enabled() can be used to check the verbosity level before performing any logging:
+	//   if l.V(3).Enabled() { l.Info("V3 is enabled") }
+	// or, in simple cases, the convenience function .Info can be used:
+	//   l.V(3).Info("V3 is enabled")
+	// The second form is shorter and more convenient, but more expensive, as its arguments are always evaluated.
+	Enabled() bool
+	// Info is the equivalent of a LeveledLogger's Info call, guarded by whether this VerboseLeveledLogger is enabled.
+	Info(args ...interface{})
+	// Infof is the equivalent of a LeveledLogger's Infof call, guarded by whether this VerboseLeveledLogger is enabled.
+	Infof(format string, args ...interface{})
+}
+
+// Severity is one of the severities as described in LeveledLogger.
+type Severity string
+
+const (
+	INFO    Severity = "I"
+	WARNING Severity = "W"
+	ERROR   Severity = "E"
+	FATAL   Severity = "F"
+)
+
+var (
+	// SeverityAtLeast maps a given severity to the list of severities that are at that severity or higher. In other
+	// words, SeverityAtLeast[X] returns a list of severities that might be seen in a log at severity X.
+	SeverityAtLeast = map[Severity][]Severity{
+		INFO:    {INFO, WARNING, ERROR, FATAL},
+		WARNING: {WARNING, ERROR, FATAL},
+		ERROR:   {ERROR, FATAL},
+		FATAL:   {FATAL},
+	}
+)
+
+// AtLeast returns true if s is at least as severe as other, ie. if s is listed in SeverityAtLeast[other]. It returns
+// false if either s or other is not one of the four known severities.
+func (s Severity) AtLeast(other Severity) bool {
+	for _, el := range SeverityAtLeast[other] {
+		if el == s {
+			return true
+		}
+	}
+	return false
+}
+
+func SeverityFromProto(s apb.LeveledLogSeverity) (Severity, error) {
+ switch s {
+ case apb.LeveledLogSeverity_INFO:
+ return INFO, nil
+ case apb.LeveledLogSeverity_WARNING:
+ return WARNING, nil
+ case apb.LeveledLogSeverity_ERROR:
+ return ERROR, nil
+ case apb.LeveledLogSeverity_FATAL:
+ return FATAL, nil
+ default:
+ return "", fmt.Errorf("unknown severity value %d", s)
+ }
+}
+
+// ToProto converts a Severity into its protobuf representation. Unknown severities map to
+// apb.LeveledLogSeverity_INVALID.
+func (s Severity) ToProto() apb.LeveledLogSeverity {
+	res := apb.LeveledLogSeverity_INVALID
+	switch s {
+	case FATAL:
+		res = apb.LeveledLogSeverity_FATAL
+	case ERROR:
+		res = apb.LeveledLogSeverity_ERROR
+	case WARNING:
+		res = apb.LeveledLogSeverity_WARNING
+	case INFO:
+		res = apb.LeveledLogSeverity_INFO
+	}
+	return res
+}
diff --git a/metropolis/pkg/logtree/leveled_payload.go b/metropolis/pkg/logtree/leveled_payload.go
new file mode 100644
index 0000000..fad42e3
--- /dev/null
+++ b/metropolis/pkg/logtree/leveled_payload.go
@@ -0,0 +1,142 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+// LeveledPayload is a log entry for leveled logs (as per leveled.go). It contains the input to these calls (severity and
+// message split into newline-delimited messages) and additional metadata that would be usually seen in a text
+// representation of a leveled log entry. Its fields are unexported; consumers read them through accessor methods
+// such as Messages, Timestamp, Severity and Location.
+type LeveledPayload struct {
+	// messages is the list of messages contained in this payload. This list is built from splitting up the given message
+	// from the user by newline.
+	messages []string
+	// timestamp is the time at which this message was emitted.
+	timestamp time.Time
+	// severity is the leveled Severity at which this message was emitted.
+	severity Severity
+	// file is the filename of the caller that emitted this message.
+	file string
+	// line is the line number within the file of the caller that emitted this message.
+	line int
+}
+
+// String returns a canonical representation of this payload as a single string prefixed with metadata. If the original
+// message was logged with newlines, this representation will also contain newlines, with each original message part
+// prefixed by the metadata.
+// For an alternative call that will instead return a canonical prefix and a list of lines in the message, see Strings().
+func (p *LeveledPayload) String() string {
+ prefix, lines := p.Strings()
+ res := make([]string, len(p.messages))
+ for i, line := range lines {
+ res[i] = fmt.Sprintf("%s%s", prefix, line)
+ }
+ return strings.Join(res, "\n")
+}
+
+// Strings returns the canonical representation of this payload split into a prefix and all lines that were contained in
+// the original message. This is meant to be displayed to the user by showing the prefix before each line, concatenated
+// together - possibly in a table form with the prefixes all unified with a rowspan-like mechanism.
+//
+// The prefix follows the glog format minus the thread ID: the severity letter, month and day, wall-clock time with
+// microsecond precision, and the file:line location of the caller, followed by "] ".
+//
+// For example, this function can return:
+//   prefix = "I1102 17:20:06.921395 foo.go:42] "
+//   lines  = []string{"current tags:", " - one", " - two"}
+//
+// With this data, the result should be presented to users this way in text form:
+//   I1102 17:20:06.921395 foo.go:42] current tags:
+//   I1102 17:20:06.921395 foo.go:42]  - one
+//   I1102 17:20:06.921395 foo.go:42]  - two
+//
+// Or, in a table layout:
+//   .-----------------------------------------------------.
+//   | I1102 17:20:06.921395 foo.go:42] : current tags:    |
+//   |                                  :------------------|
+//   |                                  :  - one           |
+//   |                                  :------------------|
+//   |                                  :  - two           |
+//   '-----------------------------------------------------'
+//
+// The returned lines slice is the payload's internal message slice, not a copy.
+func (p *LeveledPayload) Strings() (prefix string, lines []string) {
+	_, month, day := p.timestamp.Date()
+	hour, minute, second := p.timestamp.Clock()
+	// glog timestamps have microsecond precision.
+	usec := p.timestamp.Nanosecond() / 1000
+
+	// Same format as in glog, but without the thread ID:
+	//   Lmmdd hh:mm:ss.uuuuuu file:line]
+	// TODO(q3k): rewrite this to printf-less code.
+	prefix = fmt.Sprintf("%s%02d%02d %02d:%02d:%02d.%06d %s:%d] ", p.severity, month, day, hour, minute, second, usec, p.file, p.line)
+
+	lines = p.messages
+	return prefix, lines
+}
+
+// Messages returns the inner message lines of this entry, ie. what was passed to the actual logging method, but split
+// by newlines. The returned slice is the payload's internal storage, not a copy.
+func (p *LeveledPayload) Messages() []string { return p.messages }
+
+// MessagesJoined returns the inner message lines of this entry joined back together with newlines.
+func (p *LeveledPayload) MessagesJoined() string { return strings.Join(p.messages, "\n") }
+
+// Timestamp returns the time at which this entry was logged.
+func (p *LeveledPayload) Timestamp() time.Time { return p.timestamp }
+
+// Location returns a string in the form of file_name:line_number that shows the origin of the log entry in the
+// program source.
+func (p *LeveledPayload) Location() string { return fmt.Sprintf("%s:%d", p.file, p.line) }
+
+// Severity returns the Severity with which this entry was logged.
+func (p *LeveledPayload) Severity() Severity { return p.severity }
+
+// Proto converts a LeveledPayload to protobuf format. The timestamp is carried as nanoseconds since the Unix epoch,
+// and the location in the file:line form returned by Location.
+func (p *LeveledPayload) Proto() *apb.LogEntry_Leveled {
+	res := new(apb.LogEntry_Leveled)
+	res.Lines = p.Messages()
+	res.Timestamp = p.Timestamp().UnixNano()
+	res.Severity = p.Severity().ToProto()
+	res.Location = p.Location()
+	return res
+}
+
+// LeveledPayloadFromProto parses a protobuf message into the internal format. The location is expected in the
+// file:line form produced by LeveledPayload.Location. It is split at its last colon, so that locations whose file
+// name itself contains a colon still round-trip. An error is returned for a nil message, an unknown severity, or a
+// malformed location.
+func LeveledPayloadFromProto(p *apb.LogEntry_Leveled) (*LeveledPayload, error) {
+	if p == nil {
+		return nil, fmt.Errorf("nil leveled log entry")
+	}
+	severity, err := SeverityFromProto(p.Severity)
+	if err != nil {
+		return nil, fmt.Errorf("could not convert severity: %w", err)
+	}
+	// The line number never contains a colon, but the file name might, so split at the last one.
+	sep := strings.LastIndex(p.Location, ":")
+	if sep == -1 {
+		return nil, fmt.Errorf("invalid location %q, must be of the form file:line", p.Location)
+	}
+	file := p.Location[:sep]
+	line, err := strconv.Atoi(p.Location[sep+1:])
+	if err != nil {
+		return nil, fmt.Errorf("invalid location line number: %w", err)
+	}
+	return &LeveledPayload{
+		messages:  p.Lines,
+		timestamp: time.Unix(0, p.Timestamp),
+		severity:  severity,
+		file:      file,
+		line:      line,
+	}, nil
+}
diff --git a/metropolis/pkg/logtree/logtree.go b/metropolis/pkg/logtree/logtree.go
new file mode 100644
index 0000000..8523569
--- /dev/null
+++ b/metropolis/pkg/logtree/logtree.go
@@ -0,0 +1,147 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/pkg/logbuffer"
+)
+
+// LogTree is a tree-shaped logging system. For more information, see the package-level documentation.
+type LogTree struct {
+	// journal is the tree's journal, storing all log data and managing subscribers.
+	journal *journal
+	// root is the root node of the tree (the node with an empty DN). Nodes carry per-DN configuration options,
+	// notably the current verbosity level of that DN.
+	root *node
+}
+
+// New returns an empty LogTree with a fresh journal and a root node at the empty DN.
+func New() *LogTree {
+	tree := &LogTree{journal: newJournal()}
+	tree.root = newNode(tree, "")
+	return tree
+}
+
+// node represents a given DN as a discrete 'logger'. It implements the LeveledLogger interface for log publishing,
+// passing the resulting entries over to the logtree's journal.
+type node struct {
+	// dn is the DN which this node represents (or "" if this is the root node).
+	dn DN
+	// tree is the LogTree to which this node belongs.
+	tree *LogTree
+	// verbosity is the current verbosity level of this DN/node, affecting .V(n) LeveledLogger calls. It is not
+	// guarded by mu.
+	verbosity VerbosityLevel
+	// rawLineBuffer collects raw (unleveled) writes for this DN and calls logRaw for every completed line. It is the
+	// io.Writer handed out by LogTree.RawFor.
+	rawLineBuffer *logbuffer.LineBuffer
+
+	// mu guards children.
+	mu sync.Mutex
+	// children maps a DN-part to the corresponding child node in the logtree. A DN-part is a string representing a
+	// part of the DN between the delimiting dots, as returned by DN.Path.
+	children map[string]*node
+}
+
+// newNode returns a node at a given DN in the LogTree - but doesn't set up the LogTree to insert it accordingly.
+func newNode(tree *LogTree, dn DN) *node {
+ n := &node{
+ dn: dn,
+ tree: tree,
+ children: make(map[string]*node),
+ }
+ // TODO(q3k): make this limit configurable. If this happens, or the default (1024) gets changes, max chunk size
+ // calculations when serving the logs (eg. in NodeDebugService) must reflect this.
+ n.rawLineBuffer = logbuffer.NewLineBuffer(1024, n.logRaw)
+ return n
+}
+
+// nodeByDN looks up the LogTree node for a given DN, creating it (and any missing ancestors) on the way. An error is
+// returned only if the DN cannot be split into parts.
+func (l *LogTree) nodeByDN(dn DN) (*node, error) {
+	t, err := newTraversal(dn)
+	if err != nil {
+		return nil, fmt.Errorf("traversal failed: %w", err)
+	}
+	found := t.execute(l.root)
+	return found, nil
+}
+
+// nodeTraversal represents a request to traverse the LogTree in search of a given node by DN.
+type nodeTraversal struct {
+	// want is the DN of the node that is requested to be found.
+	want DN
+	// current is the path already taken to find the node, in the form of DN parts. It starts out empty and grows
+	// towards want's parts (as returned by DN.Path) as the traversal continues.
+	current []string
+	// left is the path that still needs to be taken in order to find the node, in the form of DN parts. It starts out
+	// as want's parts (as returned by DN.Path) and shrinks until it is empty as the traversal continues.
+	left []string
+}
+
+// next advances the traversal by one DN part. It returns the part that is now being looked for (or "" if the
+// traversal is done) and the full DN of the element that is being looked for.
+//
+// For example, a traversal of foo.bar.baz will cause .next() to return the following on each invocation:
+// - part: foo, full: foo
+// - part: bar, full: foo.bar
+// - part: baz, full: foo.bar.baz
+// - part: "", full: foo.bar.baz
+func (t *nodeTraversal) next() (part string, full DN) {
+	if len(t.left) > 0 {
+		part, t.left = t.left[0], t.left[1:]
+		t.current = append(t.current, part)
+		return part, DN(strings.Join(t.current, "."))
+	}
+	// Nothing left to walk: the wanted node has been reached.
+	return "", t.want
+}
+
+// newTraversal prepares a nodeTraversal for a given wanted DN. It fails if the DN cannot be split into parts.
+func newTraversal(dn DN) (*nodeTraversal, error) {
+	parts, err := dn.Path()
+	if err != nil {
+		return nil, err
+	}
+	t := &nodeTraversal{want: dn}
+	t.left = parts
+	return t, nil
+}
+
+// execute walks the traversal starting at n and returns the target node. It can only be called once per traversal.
+// Missing nodes along the way are created, and existing ones are reused, which makes this an idempotent way of
+// accessing a node in the tree. Each parent's mu is held only while its children map is read or updated.
+func (t *nodeTraversal) execute(n *node) *node {
+	cur := n
+	for {
+		part, full := t.next()
+		if part == "" {
+			return cur
+		}
+
+		parent := cur
+		parent.mu.Lock()
+		child, ok := parent.children[part]
+		if !ok {
+			child = newNode(n.tree, full)
+			parent.children[part] = child
+		}
+		parent.mu.Unlock()
+		cur = child
+	}
+}
diff --git a/metropolis/pkg/logtree/logtree_access.go b/metropolis/pkg/logtree/logtree_access.go
new file mode 100644
index 0000000..fed202e
--- /dev/null
+++ b/metropolis/pkg/logtree/logtree_access.go
@@ -0,0 +1,183 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "errors"
+ "sync/atomic"
+)
+
+// LogReadOption describes options for the LogTree.Read call. Each option constructor (WithChildren, WithStream, ...)
+// sets a single field, and Read merges all options passed to it.
+type LogReadOption struct {
+	// withChildren requests entries for all sub-DNs of the read DN as well.
+	withChildren bool
+	// withStream requests a live Stream of new entries.
+	withStream bool
+	// withBacklog is the number of already recorded entries to return, or BacklogAllAvailable for all of them. Other
+	// non-positive values are ignored by Read.
+	withBacklog int
+	// onlyLeveled restricts results to leveled entries.
+	onlyLeveled bool
+	// onlyRaw restricts results to raw entries.
+	onlyRaw bool
+	// leveledWithMinimumSeverity, if non-empty, is the minimum Severity of leveled entries to return.
+	leveledWithMinimumSeverity Severity
+}
+
+// WithChildren makes Read return/stream data for both a given DN and all its children.
+func WithChildren() LogReadOption { return LogReadOption{withChildren: true} }
+
+// WithStream makes Read return a stream of data. This works alongside WithBacklog to create a read-and-stream
+// construct.
+func WithStream() LogReadOption { return LogReadOption{withStream: true} }
+
+// WithBacklog makes Read return already recorded log entries, up to count elements. Pass BacklogAllAvailable to
+// retrieve all of them; other non-positive counts are ignored by Read.
+func WithBacklog(count int) LogReadOption { return LogReadOption{withBacklog: count} }
+
+// BacklogAllAvailable makes WithBacklog return all backlogged log data that logtree possesses.
+const BacklogAllAvailable int = -1
+
+// OnlyRaw makes Read return only raw log entries. Combining it with OnlyLeveled makes Read fail with
+// ErrRawAndLeveled.
+func OnlyRaw() LogReadOption { return LogReadOption{onlyRaw: true} }
+
+// OnlyLeveled makes Read return only leveled log entries. Combining it with OnlyRaw makes Read fail with
+// ErrRawAndLeveled.
+func OnlyLeveled() LogReadOption { return LogReadOption{onlyLeveled: true} }
+
+// LeveledWithMinimumSeverity makes Read return only log entries that are at least at a given Severity. If only leveled
+// entries are needed, OnlyLeveled must be used. This is a no-op when OnlyRaw is used.
+func LeveledWithMinimumSeverity(s Severity) LogReadOption {
+	return LogReadOption{leveledWithMinimumSeverity: s}
+}
+
+// LogReader permits reading an already existing backlog of log entries and streaming further ones.
+type LogReader struct {
+	// Backlog are the log entries already logged by LogTree. This will only be non-empty if WithBacklog has been
+	// passed to Read.
+	Backlog []*LogEntry
+	// Stream is a channel of new entries as received live by LogTree. This will only be set if WithStream has been
+	// passed to Read. In this case, entries from this channel must be read as fast as possible by the consumer in order
+	// to prevent missing entries.
+	Stream <-chan *LogEntry
+	// done is a channel used to signal (by closing) that the log consumer is not interested in more Stream data. It
+	// is nil if no streaming has been requested.
+	done chan<- struct{}
+	// missed points to an atomically accessed counter of how many messages the subscriber missed from Stream. This
+	// pointer is nil if no streaming has been requested.
+	missed *uint64
+}
+
+// Missed reports how many entries were missed from Stream because the channel was not drained fast enough. It is
+// always 0 for a LogReader without a Stream.
+func (l *LogReader) Missed() uint64 {
+	if counter := l.missed; counter != nil {
+		return atomic.LoadUint64(counter)
+	}
+	return 0
+}
+
+// Close closes the LogReader's Stream. This must be called once the Reader does not wish to receive streaming messages
+// anymore. Calling Close on a LogReader without a Stream, or calling it more than once, is a no-op.
+func (l *LogReader) Close() {
+	if l.done == nil {
+		return
+	}
+	close(l.done)
+	// Forget the channel so that a repeated Close does not panic by closing it a second time.
+	l.done = nil
+}
+
+var (
+	// ErrRawAndLeveled is returned by LogTree.Read when both OnlyRaw and OnlyLeveled are passed.
+	ErrRawAndLeveled = errors.New("cannot return logs that are simultaneously OnlyRaw and OnlyLeveled")
+)
+
+// Read and/or stream entries from a LogTree. The returned LogReader is influenced by the LogReadOptions passed, which
+// influence whether the Read will return existing entries, a stream, or both. In addition the options also dictate
+// whether only entries for that particular DN are returned, or for all sub-DNs as well.
+//
+// ErrRawAndLeveled is returned if both OnlyRaw and OnlyLeveled are passed. When WithStream is passed, the caller must
+// call Close on the returned LogReader once it no longer wishes to receive streamed entries.
+func (l *LogTree) Read(dn DN, opts ...LogReadOption) (*LogReader, error) {
+	// The journal stays read-locked for the whole call, so the backlog scan and the stream subscription below both
+	// happen within the same lock acquisition.
+	l.journal.mu.RLock()
+	defer l.journal.mu.RUnlock()
+
+	var backlog int
+	var stream bool
+	var recursive bool
+	var leveledSeverity Severity
+	var onlyRaw, onlyLeveled bool
+
+	// Merge all options. Boolean options are sticky once set; for backlog and severity, the last option carrying a
+	// meaningful value wins.
+	for _, opt := range opts {
+		if opt.withBacklog > 0 || opt.withBacklog == BacklogAllAvailable {
+			backlog = opt.withBacklog
+		}
+		if opt.withStream {
+			stream = true
+		}
+		if opt.withChildren {
+			recursive = true
+		}
+		if opt.leveledWithMinimumSeverity != "" {
+			leveledSeverity = opt.leveledWithMinimumSeverity
+		}
+		if opt.onlyLeveled {
+			onlyLeveled = true
+		}
+		if opt.onlyRaw {
+			onlyRaw = true
+		}
+	}
+
+	if onlyLeveled && onlyRaw {
+		return nil, ErrRawAndLeveled
+	}
+
+	// Build the filter chain; it is shared by the backlog scan and the stream subscriber.
+	var filters []filter
+	if onlyLeveled {
+		filters = append(filters, filterOnlyLeveled)
+	}
+	if onlyRaw {
+		filters = append(filters, filterOnlyRaw)
+	}
+	if recursive {
+		filters = append(filters, filterSubtree(dn))
+	} else {
+		filters = append(filters, filterExact(dn))
+	}
+	if leveledSeverity != "" {
+		filters = append(filters, filterSeverity(leveledSeverity))
+	}
+
+	var entries []*entry
+	if backlog > 0 || backlog == BacklogAllAvailable {
+		// TODO(q3k): pass over the backlog count to scanEntries/getEntries, instead of discarding them here.
+		if recursive {
+			entries = l.journal.scanEntries(filters...)
+		} else {
+			entries = l.journal.getEntries(dn, filters...)
+		}
+		// NOTE(review): this keeps the first `backlog` entries returned by the journal. If the journal returns entries
+		// oldest-first, this yields the oldest rather than the most recent entries - confirm this is intended.
+		if backlog != BacklogAllAvailable && len(entries) > backlog {
+			entries = entries[:backlog]
+		}
+	}
+
+	var sub *subscriber
+	if stream {
+		sub = &subscriber{
+			// TODO(q3k): make buffer size configurable
+			dataC:   make(chan *LogEntry, 128),
+			doneC:   make(chan struct{}),
+			filters: filters,
+		}
+		l.journal.subscribe(sub)
+	}
+
+	// Backlog is always allocated (possibly empty); Stream/done/missed are only set when streaming was requested.
+	lr := &LogReader{}
+	lr.Backlog = make([]*LogEntry, len(entries))
+	for i, entry := range entries {
+		lr.Backlog[i] = entry.external()
+	}
+	if stream {
+		lr.Stream = sub.dataC
+		lr.done = sub.doneC
+		lr.missed = &sub.missed
+	}
+	return lr, nil
+}
diff --git a/metropolis/pkg/logtree/logtree_entry.go b/metropolis/pkg/logtree/logtree_entry.go
new file mode 100644
index 0000000..321406d
--- /dev/null
+++ b/metropolis/pkg/logtree/logtree_entry.go
@@ -0,0 +1,141 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strings"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/pkg/logbuffer"
+ apb "git.monogon.dev/source/nexantic.git/metropolis/proto/api"
+)
+
+// LogEntry contains a log entry, combining both leveled and raw logging into a single stream of events. A LogEntry
+// will contain exactly one of either LeveledPayload or RawPayload. Entries with neither set are considered invalid
+// (see String, Strings and Proto).
+type LogEntry struct {
+	// If non-nil, this is a leveled logging entry.
+	Leveled *LeveledPayload
+	// If non-nil, this is a raw logging entry line.
+	Raw *logbuffer.Line
+	// DN from which this entry was logged.
+	DN DN
+}
+
+// String returns a canonical, metadata-prefixed representation of this entry as a single string. A leveled entry that
+// was logged with newlines yields one output line per message line, each carrying the full prefix. Entries with
+// neither payload render as "INVALID".
+// To get the prefix and message lines separately instead, see Strings().
+func (l *LogEntry) String() string {
+	switch {
+	case l.Leveled != nil:
+		prefix, messages := l.Leveled.Strings()
+		lines := make([]string, 0, len(messages))
+		for _, m := range messages {
+			lines = append(lines, fmt.Sprintf("%-32s %s%s", l.DN, prefix, m))
+		}
+		return strings.Join(lines, "\n")
+	case l.Raw != nil:
+		return fmt.Sprintf("%-32s R %s", l.DN, l.Raw)
+	default:
+		return "INVALID"
+	}
+}
+
+// Strings returns the canonical representation of this entry split into a prefix and all lines that were contained in
+// the original message. This is meant to be displayed to the user by showing the prefix before each line, concatenated
+// together - possibly in a table form with the prefixes all unified with a rowspan-like mechanism.
+//
+// For example, this function can return:
+//   prefix = "root.foo.bar                    I1102 17:20:06.921395     0 foo.go:42] "
+//   lines = []string{"current tags:", " - one", " - two"}
+//
+// With this data, the result should be presented to users this way in text form:
+// root.foo.bar                    I1102 17:20:06.921395 foo.go:42] current tags:
+// root.foo.bar                    I1102 17:20:06.921395 foo.go:42]  - one
+// root.foo.bar                    I1102 17:20:06.921395 foo.go:42]  - two
+//
+// Or, in a table layout:
+// .-------------------------------------------------------------------------------------.
+// | root.foo.bar                    I1102 17:20:06.921395 foo.go:42] : current tags:    |
+// |                                                                  :------------------|
+// |                                                                  :  - one           |
+// |                                                                  :------------------|
+// |                                                                  :  - two           |
+// '-------------------------------------------------------------------------------------'
+//
+// Entries with neither payload yield the prefix "INVALID " and the single line "INVALID".
+func (l *LogEntry) Strings() (prefix string, lines []string) {
+	switch {
+	case l.Leveled != nil:
+		leveledPrefix, messages := l.Leveled.Strings()
+		return fmt.Sprintf("%-32s %s", l.DN, leveledPrefix), messages
+	case l.Raw != nil:
+		return fmt.Sprintf("%-32s R ", l.DN), []string{l.Raw.Data}
+	default:
+		return "INVALID ", []string{"INVALID"}
+	}
+}
+
+// Proto converts this LogEntry to its protobuf representation. It returns nil if the entry is invalid, ie. has neither
+// a Leveled nor a Raw payload. If both are set, only the Leveled payload is converted.
+func (l *LogEntry) Proto() *apb.LogEntry {
+	switch {
+	case l.Leveled != nil:
+		return &apb.LogEntry{
+			Dn:   string(l.DN),
+			Kind: &apb.LogEntry_Leveled_{Leveled: l.Leveled.Proto()},
+		}
+	case l.Raw != nil:
+		return &apb.LogEntry{
+			Dn:   string(l.DN),
+			Kind: &apb.LogEntry_Raw_{Raw: l.Raw.ProtoLog()},
+		}
+	}
+	return nil
+}
+
+// LogEntryFromProto parses a proto LogEntry back into the internal structure. This can be used by consumers of the
+// log proto API to easily print received log entries. An error is returned if the message is nil, its DN is invalid,
+// or it does not carry a convertible Leveled or Raw payload.
+func LogEntryFromProto(l *apb.LogEntry) (*LogEntry, error) {
+	if l == nil {
+		return nil, fmt.Errorf("log entry is nil")
+	}
+	dn := DN(l.Dn)
+	if _, err := dn.Path(); err != nil {
+		return nil, fmt.Errorf("could not convert DN: %w", err)
+	}
+	res := &LogEntry{
+		DN: dn,
+	}
+	// Messages may come from remote peers, so nil oneof wrappers/payloads are rejected rather than dereferenced.
+	switch inner := l.Kind.(type) {
+	case *apb.LogEntry_Leveled_:
+		if inner == nil || inner.Leveled == nil {
+			return nil, fmt.Errorf("proto has Leveled kind without payload")
+		}
+		leveled, err := LeveledPayloadFromProto(inner.Leveled)
+		if err != nil {
+			return nil, fmt.Errorf("could not convert leveled entry: %w", err)
+		}
+		res.Leveled = leveled
+	case *apb.LogEntry_Raw_:
+		if inner == nil || inner.Raw == nil {
+			return nil, fmt.Errorf("proto has Raw kind without payload")
+		}
+		line, err := logbuffer.LineFromLogProto(inner.Raw)
+		if err != nil {
+			return nil, fmt.Errorf("could not convert raw entry: %w", err)
+		}
+		res.Raw = line
+	default:
+		return nil, fmt.Errorf("proto has neither Leveled nor Raw set")
+	}
+	return res, nil
+}
diff --git a/metropolis/pkg/logtree/logtree_publisher.go b/metropolis/pkg/logtree/logtree_publisher.go
new file mode 100644
index 0000000..3e2711a
--- /dev/null
+++ b/metropolis/pkg/logtree/logtree_publisher.go
@@ -0,0 +1,185 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "io"
+ "runtime"
+ "strings"
+ "time"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/pkg/logbuffer"
+)
+
+// LeveledFor returns a LeveledLogger publishing interface for a given DN. An error may be returned if the DN is
+// malformed, in which case the returned LeveledLogger is nil.
+func (l *LogTree) LeveledFor(dn DN) (LeveledLogger, error) {
+	n, err := l.nodeByDN(dn)
+	if err != nil {
+		// Return an untyped nil: returning n directly would wrap a nil *node in a non-nil LeveledLogger interface.
+		return nil, err
+	}
+	return n, nil
+}
+
+// RawFor returns an io.Writer for raw (unleveled) log output under a given DN. Written data is split into lines by the
+// node's line buffer, and each completed line is logged as a separate raw entry. An error is returned if the DN is
+// malformed.
+func (l *LogTree) RawFor(dn DN) (io.Writer, error) {
+	n, err := l.nodeByDN(dn)
+	if err != nil {
+		return nil, fmt.Errorf("could not retrieve raw logger: %w", err)
+	}
+	writer := n.rawLineBuffer
+	return writer, nil
+}
+
+// MustLeveledFor is like LeveledFor, but panics instead of returning an error when the DN is invalid.
+func (l *LogTree) MustLeveledFor(dn DN) LeveledLogger {
+	logger, err := l.LeveledFor(dn)
+	if err == nil {
+		return logger
+	}
+	panic(fmt.Errorf("LeveledFor returned: %w", err))
+}
+
+// MustRawFor is like RawFor, but panics instead of returning an error when the DN is invalid.
+func (l *LogTree) MustRawFor(dn DN) io.Writer {
+	w, err := l.RawFor(dn)
+	if err == nil {
+		return w
+	}
+	panic(fmt.Errorf("RawFor returned: %w", err))
+}
+
+// SetVerbosity sets the verbosity level of a single DN. It is not recursive: children of the DN keep their own
+// levels. The node for the DN is created if it does not exist yet. An error is returned if the DN is malformed.
+//
+// NOTE(review): the level is stored without synchronization, while V reads it on every call; concurrent use of
+// SetVerbosity and V on the same DN may be a data race - confirm whether callers serialize these.
+func (l *LogTree) SetVerbosity(dn DN, level VerbosityLevel) error {
+	n, err := l.nodeByDN(dn)
+	if err == nil {
+		n.verbosity = level
+	}
+	return err
+}
+
+// logRaw is the callback given to this node's LineBuffer (see newNode) and runs for every completed raw line. It wraps
+// the line in an entry originating from this node's DN, appends it to the journal and notifies pertinent subscribers.
+func (n *node) logRaw(line *logbuffer.Line) {
+	j := n.tree.journal
+	e := &entry{origin: n.dn, raw: line}
+	j.append(e)
+	j.notify(e)
+}
+
+// logLeveled builds a LeveledPayload and entry for a given message, including all related metadata. It will create a
+// new entry, append it to the journal, and notify all pertinent subscribers.
+//
+// depth is the number of additional stack frames to skip when determining the source location. With depth 0, the
+// recorded location is that of whoever called the method that called logLeveled (eg. the caller of Info). Therefore
+// the public logging methods must call logLeveled directly, without intermediate helpers.
+func (n *node) logLeveled(depth int, severity Severity, msg string) {
+	// Skip frame 0 (logLeveled itself) and frame 1 (the public logging method).
+	_, file, line, ok := runtime.Caller(2 + depth)
+	if !ok {
+		file = "???"
+		line = 1
+	} else {
+		// Keep only the base name of the source file.
+		slash := strings.LastIndex(file, "/")
+		if slash >= 0 {
+			file = file[slash+1:]
+		}
+	}
+
+	// Remove leading/trailing newlines and split.
+	messages := strings.Split(strings.Trim(msg, "\n"), "\n")
+
+	p := &LeveledPayload{
+		timestamp: time.Now(),
+		severity:  severity,
+		messages:  messages,
+		file:      file,
+		line:      line,
+	}
+	e := &entry{
+		origin:  n.dn,
+		leveled: p,
+	}
+	n.tree.journal.append(e)
+	n.tree.journal.notify(e)
+}
+
+// The leveled logging methods below each call logLeveled directly with depth 0, so that the source location recorded
+// for an entry is that of their caller. Do not route them through a shared helper without adjusting the depth.
+
+// Info implements the LeveledLogger interface.
+func (n *node) Info(args ...interface{}) {
+	n.logLeveled(0, INFO, fmt.Sprint(args...))
+}
+
+// Infof implements the LeveledLogger interface.
+func (n *node) Infof(format string, args ...interface{}) {
+	n.logLeveled(0, INFO, fmt.Sprintf(format, args...))
+}
+
+// Warning implements the LeveledLogger interface.
+func (n *node) Warning(args ...interface{}) {
+	n.logLeveled(0, WARNING, fmt.Sprint(args...))
+}
+
+// Warningf implements the LeveledLogger interface.
+func (n *node) Warningf(format string, args ...interface{}) {
+	n.logLeveled(0, WARNING, fmt.Sprintf(format, args...))
+}
+
+// Error implements the LeveledLogger interface.
+func (n *node) Error(args ...interface{}) {
+	n.logLeveled(0, ERROR, fmt.Sprint(args...))
+}
+
+// Errorf implements the LeveledLogger interface.
+func (n *node) Errorf(format string, args ...interface{}) {
+	n.logLeveled(0, ERROR, fmt.Sprintf(format, args...))
+}
+
+// Fatal implements the LeveledLogger interface. Note that it only records the entry at FATAL severity; it neither
+// exits the process nor panics.
+func (n *node) Fatal(args ...interface{}) {
+	n.logLeveled(0, FATAL, fmt.Sprint(args...))
+}
+
+// Fatalf implements the LeveledLogger interface. Like Fatal, it only records the entry at FATAL severity.
+func (n *node) Fatalf(format string, args ...interface{}) {
+	n.logLeveled(0, FATAL, fmt.Sprintf(format, args...))
+}
+
+// V implements the LeveledLogger interface. Whether level v is enabled is decided once, at the time of this call,
+// by comparing it against the node's current verbosity.
+func (n *node) V(v VerbosityLevel) VerboseLeveledLogger {
+	enabled := n.verbosity >= v
+	return &verbose{enabled: enabled, node: n}
+}
+
+// verbose implements the VerboseLeveledLogger interface. It is a thin wrapper around node, with an 'enabled' bool
+// captured when V was called. This means that V(n)-returned VerboseLeveledLoggers must be short lived, as a change in
+// verbosity will not affect already existing VerboseLeveledLoggers.
+type verbose struct {
+	// node is the node that entries are published to.
+	node *node
+	// enabled is whether the requested verbosity level was enabled when V was called.
+	enabled bool
+}
+
+// Enabled returns whether the verbosity level requested from V was enabled at the time V was called.
+func (v *verbose) Enabled() bool {
+	return v.enabled
+}
+
+// Info logs at INFO severity if this verbosity level is enabled, and is a no-op otherwise. It calls logLeveled
+// directly with depth 0 so that the recorded location is this method's caller.
+func (v *verbose) Info(args ...interface{}) {
+	if !v.enabled {
+		return
+	}
+	v.node.logLeveled(0, INFO, fmt.Sprint(args...))
+}
+
+// Infof is the formatting variant of Info; it is likewise a no-op unless this verbosity level is enabled.
+func (v *verbose) Infof(format string, args ...interface{}) {
+	if !v.enabled {
+		return
+	}
+	v.node.logLeveled(0, INFO, fmt.Sprintf(format, args...))
+}
diff --git a/metropolis/pkg/logtree/logtree_test.go b/metropolis/pkg/logtree/logtree_test.go
new file mode 100644
index 0000000..b900201
--- /dev/null
+++ b/metropolis/pkg/logtree/logtree_test.go
@@ -0,0 +1,211 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logtree
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+)
+
+// expect reads the complete backlog of dn and its children from tree. It fails the test immediately if reading fails
+// or if the number of backlog entries differs from len(entries). Otherwise it returns a description of the first of
+// entries not found among the backlog messages, or "" if all of them were found.
+func expect(tree *LogTree, t *testing.T, dn DN, entries ...string) string {
+	t.Helper()
+	reader, err := tree.Read(dn, WithChildren(), WithBacklog(BacklogAllAvailable))
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	if want, got := len(entries), len(reader.Backlog); want != got {
+		t.Fatalf("wanted %v backlog entries, got %v", want, got)
+	}
+
+	seen := make(map[string]bool, len(reader.Backlog))
+	for _, e := range reader.Backlog {
+		if lv := e.Leveled; lv != nil {
+			seen[lv.MessagesJoined()] = true
+		}
+		if raw := e.Raw; raw != nil {
+			seen[raw.Data] = true
+		}
+	}
+	for _, wanted := range entries {
+		if !seen[wanted] {
+			return fmt.Sprintf("missing entry %q", wanted)
+		}
+	}
+	return ""
+}
+
+func TestMultiline(t *testing.T) {
+	tree := New()
+	logger := tree.MustLeveledFor("main")
+	// A single message spanning two lines.
+	logger.Info("foo\nbar")
+	// A two-line message with a trailing newline, which is expected to be stripped.
+	logger.Info("one\ntwo\n")
+
+	if res := expect(tree, t, "main", "foo\nbar", "one\ntwo"); res != "" {
+		t.Errorf("retrieval at main failed: %s", res)
+	}
+}
+
+func TestBacklog(t *testing.T) {
+	tree := New()
+	tree.MustLeveledFor("main").Info("hello, main!")
+	tree.MustLeveledFor("main.foo").Info("hello, main.foo!")
+	tree.MustLeveledFor("main.bar").Info("hello, main.bar!")
+	tree.MustLeveledFor("aux").Info("hello, aux!")
+	// The trailing "baz" has no terminating newline, so it must not show up in the backlog.
+	fmt.Fprintf(tree.MustRawFor("aux.process"), "processing foo\nprocessing bar\nbaz")
+
+	for _, tc := range []struct {
+		name string
+		dn   DN
+		want []string
+	}{
+		{"main", "main", []string{"hello, main!", "hello, main.foo!", "hello, main.bar!"}},
+		{"root", "", []string{"hello, main!", "hello, main.foo!", "hello, main.bar!", "hello, aux!", "processing foo", "processing bar"}},
+		{"aux", "aux", []string{"hello, aux!", "processing foo", "processing bar"}},
+	} {
+		if res := expect(tree, t, tc.dn, tc.want...); res != "" {
+			t.Errorf("retrieval at %s failed: %s", tc.name, res)
+		}
+	}
+}
+
+// TestStream checks that Read returns the existing backlog and then streams entries logged afterwards. Collection
+// stops as soon as all expected streamed entries have arrived, or after a one second timeout.
+func TestStream(t *testing.T) {
+	tree := New()
+	tree.MustLeveledFor("main").Info("hello, backlog")
+	fmt.Fprintf(tree.MustRawFor("main.process"), "hello, raw backlog\n")
+
+	res, err := tree.Read("", WithBacklog(BacklogAllAvailable), WithChildren(), WithStream())
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	defer res.Close()
+	if want, got := 2, len(res.Backlog); want != got {
+		t.Errorf("wanted %d backlog items, got %d", want, got)
+	}
+
+	tree.MustLeveledFor("main").Info("hello, stream")
+	fmt.Fprintf(tree.MustRawFor("main.raw"), "hello, raw stream\n")
+
+	wanted := []string{"hello, stream", "hello, raw stream"}
+	entries := make(map[string]bool)
+	haveAll := func() bool {
+		for _, w := range wanted {
+			if !entries[w] {
+				return false
+			}
+		}
+		return true
+	}
+
+	timeout := time.After(time.Second)
+collect:
+	for !haveAll() {
+		select {
+		case <-timeout:
+			break collect
+		case p, ok := <-res.Stream:
+			if !ok {
+				t.Fatalf("stream closed unexpectedly")
+			}
+			if p.Leveled != nil {
+				entries[p.Leveled.MessagesJoined()] = true
+			}
+			if p.Raw != nil {
+				entries[p.Raw.Data] = true
+			}
+		}
+	}
+	for _, entry := range wanted {
+		if !entries[entry] {
+			t.Errorf("Missing entry %q", entry)
+		}
+	}
+}
+
+// TestVerbose checks that V(n) entries are dropped until the DN's verbosity is raised to at least n.
+func TestVerbose(t *testing.T) {
+	tree := New()
+
+	tree.MustLeveledFor("main").V(10).Info("this shouldn't get logged")
+
+	reader, err := tree.Read("", WithBacklog(BacklogAllAvailable), WithChildren())
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	if want, got := 0, len(reader.Backlog); want != got {
+		t.Fatalf("expected nothing to be logged, got %+v", reader.Backlog)
+	}
+
+	if err := tree.SetVerbosity("main", 10); err != nil {
+		t.Fatalf("SetVerbosity: %v", err)
+	}
+	tree.MustLeveledFor("main").V(10).Info("this should get logged")
+
+	reader, err = tree.Read("", WithBacklog(BacklogAllAvailable), WithChildren())
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	if want, got := 1, len(reader.Backlog); want != got {
+		t.Fatalf("expected %d entries to get logged, got %d", want, got)
+	}
+}
+
+func TestMetadata(t *testing.T) {
+	tree := New()
+	logger := tree.MustLeveledFor("main")
+	logger.Error("i am an error")
+	logger.Warning("i am a warning")
+	logger.Info("i am informative")
+	logger.V(0).Info("i am a zero-level debug")
+
+	reader, err := tree.Read("", WithChildren(), WithBacklog(BacklogAllAvailable))
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	if want, got := 4, len(reader.Backlog); want != got {
+		t.Fatalf("expected %d entries, got %d", want, got)
+	}
+
+	expected := []struct {
+		severity Severity
+		message  string
+	}{
+		{ERROR, "i am an error"},
+		{WARNING, "i am a warning"},
+		{INFO, "i am informative"},
+		{INFO, "i am a zero-level debug"},
+	}
+	for ix, exp := range expected {
+		lv := reader.Backlog[ix].Leveled
+		if want, got := exp.severity, lv.Severity(); want != got {
+			t.Errorf("wanted element %d to have severity %s, got %s", ix, want, got)
+		}
+		if want, got := exp.message, lv.MessagesJoined(); want != got {
+			t.Errorf("wanted element %d to have message %q, got %q", ix, want, got)
+		}
+		if want, got := "logtree_test.go", strings.Split(lv.Location(), ":")[0]; want != got {
+			t.Errorf("wanted element %d to have file %q, got %q", ix, want, got)
+		}
+	}
+}
+
+func TestSeverity(t *testing.T) {
+	tree := New()
+	logger := tree.MustLeveledFor("main")
+	logger.Error("i am an error")
+	logger.Warning("i am a warning")
+	logger.Info("i am informative")
+	logger.V(0).Info("i am a zero-level debug")
+
+	reader, err := tree.Read("main", WithBacklog(BacklogAllAvailable), LeveledWithMinimumSeverity(WARNING))
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	if want, got := 2, len(reader.Backlog); want != got {
+		t.Fatalf("wanted %d entries, got %d", want, got)
+	}
+	// Only entries at WARNING or above are expected, in logging order.
+	for i, want := range []string{"i am an error", "i am a warning"} {
+		if got := reader.Backlog[i].Leveled.MessagesJoined(); want != got {
+			t.Fatalf("wanted entry %q, got %q", want, got)
+		}
+	}
+}
diff --git a/metropolis/pkg/supervisor/BUILD.bazel b/metropolis/pkg/supervisor/BUILD.bazel
new file mode 100644
index 0000000..40b0469
--- /dev/null
+++ b/metropolis/pkg/supervisor/BUILD.bazel
@@ -0,0 +1,28 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+# Service supervision library (see supervisor.go).
+go_library(
+    name = "go_default_library",
+    srcs = [
+        "supervisor.go",
+        "supervisor_node.go",
+        "supervisor_processor.go",
+        "supervisor_support.go",
+        "supervisor_testhelpers.go",
+    ],
+    importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/supervisor",
+    # Only node code and tests may depend on the supervisor.
+    visibility = [
+        "//metropolis/node:__subpackages__",
+        "//metropolis/test:__subpackages__",
+    ],
+    deps = [
+        "//metropolis/pkg/logtree:go_default_library",
+        "@com_github_cenkalti_backoff_v4//:go_default_library",
+        "@org_golang_google_grpc//:go_default_library",
+    ],
+)
+
+# Unit tests, compiled together with the library sources.
+go_test(
+    name = "go_default_test",
+    srcs = ["supervisor_test.go"],
+    embed = [":go_default_library"],
+)
diff --git a/metropolis/pkg/supervisor/supervisor.go b/metropolis/pkg/supervisor/supervisor.go
new file mode 100644
index 0000000..ed79c69
--- /dev/null
+++ b/metropolis/pkg/supervisor/supervisor.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+// The service supervision library allows for writing of reliable, service-style software within a Metropolis node.
+// It builds upon the Erlang/OTP supervision tree system, adapted to be more Go-ish.
+// For detailed design see go/supervision.
+
+import (
+ "context"
+ "io"
+ "sync"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/pkg/logtree"
+)
+
+// A Runnable is a function that will be run in a goroutine, and supervised throughout its lifetime. It can in turn
+// start more runnables as its children, and those will form part of a supervision tree.
+// The context passed to a runnable is very important and needs to be handled properly. It will be live (non-errored) as
+// long as the runnable should be running, and canceled (ctx.Err() will be non-nil) when the supervisor wants it to
+// exit. This means this context is also perfectly usable for performing any blocking operations.
+// Once its context is canceled, a runnable should return ctx.Err() (it may be wrapped with %w). Such a return is
+// recorded as a clean cancellation. Any other return is treated as an unexpected death, including returning nil
+// unless the runnable previously signaled SignalDone. On death, the node is marked dead and its group siblings are
+// canceled.
+type Runnable func(ctx context.Context) error
+
+// RunGroup starts the given runnables as one supervision group: they run side by side, and if any one of them exits
+// unexpectedly, the whole group is canceled and restarted together.
+// ctx must be the context of an already running Runnable; the new runnables become children of that runnable's node
+// in the supervision tree. An error is returned if the group cannot be started (eg. invalid or duplicate names).
+func RunGroup(ctx context.Context, runnables map[string]Runnable) error {
+	parent, release := fromContext(ctx)
+	defer release()
+
+	return parent.runGroup(runnables)
+}
+
+// Run starts a single runnable under the given name, in a supervision group of its own.
+func Run(ctx context.Context, name string, runnable Runnable) error {
+	single := map[string]Runnable{name: runnable}
+	return RunGroup(ctx, single)
+}
+
+// Signal reports a lifecycle transition of the calling runnable to its supervisor. Every runnable should send
+// SignalHealthy once it has finished setting up, started any child runnables, and is now 'serving'.
+func Signal(ctx context.Context, signal SignalType) {
+	self, release := fromContext(ctx)
+	defer release()
+
+	self.signal(signal)
+}
+
+// SignalType is a lifecycle event that a runnable reports to its supervisor via Signal.
+type SignalType int
+
+const (
+	// The runnable is healthy, done with setup, done with spawning more Runnables, and ready to serve in a loop.
+	// The runnable needs to check the parent context and ensure that if that context is done, the runnable exits.
+	// May only be sent while the node is still NEW, ie. once per run; sending it in any other state panics.
+	SignalHealthy SignalType = iota
+	// The runnable is done - it does not need to run any loop. This is useful for Runnables that only set up other
+	// child runnables. This runnable will be restarted if a related failure happens somewhere in the supervision tree.
+	// May only be sent after SignalHealthy; sending it in any other state panics.
+	SignalDone
+)
+
+// supervisor represents an instance of the supervision system. It keeps track of a supervision tree and a request
+// channel to its internal processor goroutine.
+type supervisor struct {
+	// mu guards the entire state of the supervisor, including every node of the tree.
+	mu sync.RWMutex
+	// root is the root node of the supervision tree, named 'root'. It represents the Runnable started with the
+	// supervisor.New call.
+	root *node
+	// logtree is the main logtree exposed to runnables (via Logger and RawLogger) and used internally.
+	logtree *logtree.LogTree
+	// ilogger is the internal logger logging to "supervisor" in the logtree.
+	ilogger logtree.LeveledLogger
+
+	// pReq is the (unbuffered) request channel to the lifecycle processor of the supervisor.
+	pReq chan *processorRequest
+
+	// propagatePanic, if set, lets panics in runnables propagate instead of catching them and treating them as
+	// runnable deaths. Set by WithPropagatePanic.
+	propagatePanic bool
+}
+
+// SupervisorOpt is a runtime configurable option for the supervisor. Options are passed to New, which applies them
+// in order to the freshly created supervisor before the tree is started.
+type SupervisorOpt func(s *supervisor)
+
+// WithPropagatePanic stops the Supervisor from catching panics in runnables and treating them as failures, so that
+// they crash the program instead. Enabling this is useful for testing and local debugging.
+var WithPropagatePanic = func(s *supervisor) {
+	s.propagatePanic = true
+}
+
+// WithExistingLogtree makes the supervisor log into lt, instead of into the fresh LogTree that New creates by default.
+func WithExistingLogtree(lt *logtree.LogTree) SupervisorOpt {
+	useTree := func(s *supervisor) {
+		s.logtree = lt
+	}
+	return useTree
+}
+
+// New creates a new supervisor whose root node runs rootRunnable, starts its processor, and schedules the root.
+// Options are applied, in order, before anything else uses the supervisor. Canceling ctx stops the processor,
+// which cancels the context of every runnable in the tree.
+func New(ctx context.Context, rootRunnable Runnable, opts ...SupervisorOpt) *supervisor {
+	s := &supervisor{
+		logtree: logtree.New(),
+		pReq:    make(chan *processorRequest),
+	}
+	for _, apply := range opts {
+		apply(s)
+	}
+
+	// The internal logger must be created after the options ran, as they may have replaced the logtree.
+	s.ilogger = s.logtree.MustLeveledFor("supervisor")
+	s.root = newNode("root", rootRunnable, s, nil)
+
+	go s.processor(ctx)
+
+	// Kick off the root runnable.
+	s.pReq <- &processorRequest{schedule: &processorRequestSchedule{dn: "root"}}
+	return s
+}
+
+// Logger returns the leveled logger of the runnable owning ctx. It logs into the supervisor's logtree under the
+// runnable's DN.
+func Logger(ctx context.Context) logtree.LeveledLogger {
+	self, release := fromContext(ctx)
+	defer release()
+
+	dn := logtree.DN(self.dn())
+	return self.sup.logtree.MustLeveledFor(dn)
+}
+
+// RawLogger returns an io.Writer for unstructured output of the runnable owning ctx. It writes into the supervisor's
+// logtree under the runnable's DN. It is used, for instance, to capture the stdout/stderr of a child process.
+func RawLogger(ctx context.Context) io.Writer {
+	self, release := fromContext(ctx)
+	defer release()
+
+	dn := logtree.DN(self.dn())
+	return self.sup.logtree.MustRawFor(dn)
+}
diff --git a/metropolis/pkg/supervisor/supervisor_node.go b/metropolis/pkg/supervisor/supervisor_node.go
new file mode 100644
index 0000000..a7caf82
--- /dev/null
+++ b/metropolis/pkg/supervisor/supervisor_node.go
@@ -0,0 +1,282 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "strings"
+
+ "github.com/cenkalti/backoff/v4"
+)
+
+// node is a supervision tree node. It represents the state of a Runnable within this tree, its relation to other tree
+// elements, and contains supporting data needed to actually supervise it.
+// All fields are guarded by the supervisor's lock (supervisor.mu).
+type node struct {
+	// The name of this node. Opaque string. It's used to make up the 'dn' (distinguished name) of a node within
+	// the tree. When starting a runnable inside a tree, this is where that name gets used.
+	name     string
+	runnable Runnable
+
+	// The supervisor managing this tree.
+	sup *supervisor
+	// The parent, within the tree, of this node. If this is the root node of the tree, this is nil.
+	parent *node
+	// Children of this tree. This is represented by a map keyed from child node names, for easy access.
+	children map[string]*node
+	// Supervision groups. Each group is a set of names of children. Sets, and as such groups, don't overlap between
+	// each other. A supervision group indicates that if any child within that group fails, all others should be
+	// canceled and restarted together.
+	groups []map[string]bool
+
+	// The current state of the runnable in this node.
+	state nodeState
+
+	// Backoff used to keep runnables from being restarted too fast. Reset whenever the runnable signals healthy or
+	// done; advanced only when the node is restarted after dying.
+	bo *backoff.ExponentialBackOff
+
+	// Context passed to the runnable, and its cancel function. The context derives from the parent's context (or
+	// context.Background for the root), so canceling a node also cancels all of its descendants.
+	ctx  context.Context
+	ctxC context.CancelFunc
+}
+
+// nodeState is the lifecycle state of the runnable within a node, and in a way of the node itself.
+// This follows the state diagram from go/supervision.
+type nodeState int
+
+const (
+	// A node that has just been created, and whose runnable has been started already but hasn't signaled anything yet.
+	nodeStateNew nodeState = iota
+	// A node whose runnable has signaled being healthy - this means it's ready to serve/act.
+	nodeStateHealthy
+	// A node that has unexpectedly returned or panicked.
+	nodeStateDead
+	// A node that has declared that its done with its work and should not be restarted, unless a supervision tree
+	// failure requires that.
+	nodeStateDone
+	// A node that has returned after being requested to cancel.
+	nodeStateCanceled
+)
+
+// String returns the upper-case NODE_STATE_* name of the state, or "UNKNOWN" for values outside the enumeration.
+func (s nodeState) String() string {
+	name := "UNKNOWN"
+	switch s {
+	case nodeStateNew:
+		name = "NODE_STATE_NEW"
+	case nodeStateHealthy:
+		name = "NODE_STATE_HEALTHY"
+	case nodeStateDead:
+		name = "NODE_STATE_DEAD"
+	case nodeStateDone:
+		name = "NODE_STATE_DONE"
+	case nodeStateCanceled:
+		name = "NODE_STATE_CANCELED"
+	}
+	return name
+}
+
+// String renders the node as its DN followed by its state in parentheses, eg. "root.foo (NODE_STATE_HEALTHY)".
+func (n *node) String() string {
+	return n.dn() + " (" + n.state.String() + ")"
+}
+
+// contextKey is a type used to keep data within context values. Being an unexported type, its keys cannot collide with
+// context values set by other packages.
+type contextKey string
+
+var (
+	// supervisorKey maps to the *supervisor owning a runnable context.
+	supervisorKey = contextKey("supervisor")
+	// dnKey maps to the DN (a string) of the node that a runnable context belongs to.
+	dnKey = contextKey("dn")
+)
+
+// fromContext retrieves the tree node of the runnable owning ctx. It takes the supervisor lock and returns, together
+// with the node, a function releasing that lock. The unlock function must be called once reads/mutations on the
+// tree/supervisor/node are done.
+// It panics if ctx is not a runnable context, or if the runnable's DN is no longer part of the tree (eg. a stray
+// goroutine still using the context of a runnable whose parent has since been reset). The lock is never left held
+// when it panics.
+func fromContext(ctx context.Context) (*node, func()) {
+	sup, ok := ctx.Value(supervisorKey).(*supervisor)
+	if !ok {
+		panic("supervisor function called from non-runnable context")
+	}
+	dn, ok := ctx.Value(dnKey).(string)
+	if !ok {
+		panic("supervisor function called from non-runnable context")
+	}
+
+	sup.mu.Lock()
+	// nodeByDN panics on DNs that are not in the tree. Release the lock before propagating such a panic. Panics in
+	// runnables may be recovered by the processor, which then needs this same lock to record the death. Leaving the
+	// lock held would deadlock the entire supervisor.
+	defer func() {
+		if rec := recover(); rec != nil {
+			sup.mu.Unlock()
+			panic(rec)
+		}
+	}()
+	return sup.nodeByDN(dn), sup.mu.Unlock
+}
+
+// All the following 'internal' supervisor functions must only be called with the supervisor lock taken. Getting a lock
+// via fromContext is enough.
+
+// dn returns the distinguished name of a node: the period-separated names of all nodes on the path from the root
+// down to this node. For instance, the runnable 'foo' within the runnable 'bar' will be called 'root.bar.foo'. The
+// root of the tree is always named, and has the dn, 'root'.
+func (n *node) dn() string {
+	var names []string
+	for cur := n; cur != nil; cur = cur.parent {
+		names = append(names, cur.name)
+	}
+	// The walk above collected names leaf-first; flip them into root-first order.
+	for i, j := 0, len(names)-1; i < j; i, j = i+1, j-1 {
+		names[i], names[j] = names[j], names[i]
+	}
+	return strings.Join(names, ".")
+}
+
+// groupSiblings returns the supervision group (a set of child names, including name itself) that the child called
+// name belongs to, or nil if no group contains it. Every started child belongs to exactly one group, even if that
+// group has a single member.
+func (n *node) groupSiblings(name string) map[string]bool {
+	for _, group := range n.groups {
+		_, member := group[name]
+		if member {
+			return group
+		}
+	}
+	return nil
+}
+
+// newNode creates a new node with a given parent, ready to be scheduled. It does not register the node with the
+// parent, as that depends on group placement.
+func newNode(name string, runnable Runnable, sup *supervisor, parent *node) *node {
+	n := &node{
+		name:     name,
+		runnable: runnable,
+		sup:      sup,
+		parent:   parent,
+		bo:       backoff.NewExponentialBackOff(),
+	}
+	// Failed runnables are restarted with exponential backoff. MaxElapsedTime is set to zero so the backoff never
+	// gives up, and instead stays capped at MaxInterval.
+	n.bo.MaxElapsedTime = 0
+
+	n.reset()
+	return n
+}
+
+// reset (re)initializes all the dynamic fields of the node, in preparation for starting its runnable. It gives the
+// node a fresh cancelable context, sets its state back to NEW, and drops all of its children and groups.
+func (n *node) reset() {
+	// The new context derives from the parent's context, or from Background for the root node.
+	base := context.Background()
+	if n.parent != nil {
+		base = n.parent.ctx
+	}
+
+	// Tag the context with the node's DN and supervisor so fromContext can find the node again.
+	tagged := context.WithValue(base, dnKey, n.dn())
+	tagged = context.WithValue(tagged, supervisorKey, n.sup)
+	n.ctx, n.ctxC = context.WithCancel(tagged)
+
+	n.state = nodeStateNew
+	n.children = map[string]*node{}
+	n.groups = nil
+}
+
+// nodeByDN walks the tree from the root and returns the node with the given DN. It panics if the DN does not start
+// with "root", or names a node that is not part of the tree. Must be called with the supervisor lock held.
+func (s *supervisor) nodeByDN(dn string) *node {
+	parts := strings.Split(dn, ".")
+	if parts[0] != "root" {
+		panic("DN does not start with root.")
+	}
+
+	cur := s.root
+	for i := 1; i < len(parts); i++ {
+		child, ok := cur.children[parts[i]]
+		if !ok {
+			panic(fmt.Errorf("could not find %v (%s) in %s", parts[i:], dn, cur))
+		}
+		cur = child
+	}
+	return cur
+}
+
+// reNodeName validates a runnable name: 1 to 64 characters, each a lowercase ASCII letter, a digit or an underscore.
+// The pattern is anchored so the entire name must match. In particular, names can never contain '.', which is the
+// separator of DNs and would otherwise make nodeByDN resolve the wrong (or no) node.
+var reNodeName = regexp.MustCompile(`^[a-z0-9_]{1,64}$`)
+
+// runGroup schedules a new group of runnables to run as children of this node. The whole request is validated before
+// the tree is touched, so on error the node is left unmodified. Must be called with the supervisor lock held.
+func (n *node) runGroup(runnables map[string]Runnable) error {
+	// Check that the parent node is in the right state.
+	if n.state != nodeStateNew {
+		return fmt.Errorf("cannot run new runnable on non-NEW node")
+	}
+
+	// Check all requested runnable names up front. This runs before any child is created, so a failure cannot leave
+	// children registered in n.children without them being part of a group or ever being scheduled.
+	for name := range runnables {
+		if !reNodeName.MatchString(name) {
+			return fmt.Errorf("runnable name %q is invalid", name)
+		}
+		if _, ok := n.children[name]; ok {
+			return fmt.Errorf("runnable %q already exists", name)
+		}
+		if n.groupSiblings(name) != nil {
+			return fmt.Errorf("duplicate child name %q", name)
+		}
+	}
+
+	// Create child nodes, and the group containing them.
+	dns := make([]string, 0, len(runnables))
+	group := make(map[string]bool, len(runnables))
+	for name, runnable := range runnables {
+		child := newNode(name, runnable, n.sup, n)
+		n.children[name] = child
+		group[name] = true
+		dns = append(dns, child.dn())
+	}
+	n.groups = append(n.groups, group)
+
+	// Schedule execution of group members. The requests are sent from a separate goroutine. The caller holds the
+	// supervisor lock, and the processor takes that same lock while handling requests, so a synchronous send on the
+	// unbuffered request channel could deadlock. The goroutine only reads the DN slice built here, never the caller's
+	// map, which the caller is free to modify after we return.
+	go func() {
+		for _, dn := range dns {
+			n.sup.pReq <- &processorRequest{
+				schedule: &processorRequestSchedule{
+					dn: dn,
+				},
+			}
+		}
+	}()
+	return nil
+}
+
+// signal applies a lifecycle signal received from the node's runnable. SignalHealthy is only accepted on a NEW node,
+// and SignalDone only on a HEALTHY one; any other transition is a programming error and panics. Both signals reset
+// the node's restart backoff. Signal types other than these two are ignored.
+func (n *node) signal(signal SignalType) {
+	if signal == SignalHealthy {
+		if n.state != nodeStateNew {
+			panic(fmt.Errorf("node %s signaled healthy", n))
+		}
+		n.state = nodeStateHealthy
+		n.bo.Reset()
+		return
+	}
+	if signal == SignalDone {
+		if n.state != nodeStateHealthy {
+			panic(fmt.Errorf("node %s signaled done", n))
+		}
+		n.state = nodeStateDone
+		n.bo.Reset()
+	}
+}
diff --git a/metropolis/pkg/supervisor/supervisor_processor.go b/metropolis/pkg/supervisor/supervisor_processor.go
new file mode 100644
index 0000000..965a667
--- /dev/null
+++ b/metropolis/pkg/supervisor/supervisor_processor.go
@@ -0,0 +1,404 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "runtime/debug"
+ "time"
+)
+
+// The processor maintains runnable goroutines - ie., when requested it will start one, and then once it exits it will
+// record the result and act accordingly. It is also responsible for detecting and acting upon supervision subtrees that
+// need to be restarted after death (via a 'GC' process).
+
+// processorRequest is a request for the processor. Only one of the fields can be set.
+type processorRequest struct {
+	// schedule requests that a node's runnable be started.
+	schedule *processorRequestSchedule
+	// died reports that a node's runnable goroutine has returned (or panicked, unless panics are propagated).
+	died *processorRequestDied
+	// waitSettled registers a waiter to be notified once the supervision tree has settled.
+	waitSettled *processorRequestWaitSettled
+}
+
+// processorRequestSchedule requests that a given node's runnable be started.
+type processorRequestSchedule struct {
+	// dn is the distinguished name of the node to start.
+	dn string
+}
+
+// processorRequestDied is a signal from a runnable goroutine that the runnable has died.
+type processorRequestDied struct {
+	// dn is the distinguished name of the node whose runnable exited.
+	dn string
+	// err is the value returned by the runnable, or an error describing a recovered panic.
+	err error
+}
+
+// processorRequestWaitSettled asks the processor to notify the caller once the supervision tree has settled. The
+// tree counts as settled once more than 50 consecutive GC ticks have passed without a schedule or died request.
+type processorRequestWaitSettled struct {
+	// waiter is closed by the processor once the tree has settled.
+	waiter chan struct{}
+}
+
+// processor is the main processing loop of the supervisor. It serializes all tree changes (scheduling runnables and
+// recording their deaths), and runs the GC after the tree has changed. It notifies waiters once the tree has settled,
+// and cancels the whole tree once ctx is canceled.
+func (s *supervisor) processor(ctx context.Context) {
+	s.ilogger.Info("supervisor processor started")
+
+	// Channels of callers waiting for the tree to settle.
+	var settleWaiters []chan struct{}
+
+	// The GC runs on a one-millisecond tick, but only when some request since the previous run changed the tree
+	// (a new runnable or a death) and thereby marked it dirty.
+	ticker := time.NewTicker(1 * time.Millisecond)
+	defer ticker.Stop()
+	dirty := false
+	// Number of consecutive ticks without any tree change; used to decide when the tree counts as settled.
+	quietTicks := 0
+
+	for {
+		select {
+		case <-ctx.Done():
+			s.ilogger.Infof("supervisor processor exiting: %v", ctx.Err())
+			s.processKill()
+			s.ilogger.Info("supervisor exited")
+			return
+
+		case <-ticker.C:
+			if dirty {
+				s.processGC()
+				dirty = false
+			}
+			quietTicks++
+
+			// This threshold is somewhat arbitrary: a balance between test speed and test reliability.
+			if quietTicks > 50 {
+				for _, waiter := range settleWaiters {
+					close(waiter)
+				}
+				settleWaiters = nil
+			}
+
+		case r := <-s.pReq:
+			switch {
+			case r.schedule != nil:
+				s.processSchedule(r.schedule)
+				dirty, quietTicks = true, 0
+			case r.died != nil:
+				s.processDied(r.died)
+				dirty, quietTicks = true, 0
+			case r.waitSettled != nil:
+				settleWaiters = append(settleWaiters, r.waitSettled.waiter)
+			default:
+				panic(fmt.Errorf("unhandled request %+v", r))
+			}
+		}
+	}
+}
+
+// processKill cancels the context of every node in the supervision tree. It is only called right before the
+// processor exits, so the canceled runnables never get restarted.
+func (s *supervisor) processKill() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Collect every node breadth-first: nodes grows while it is being walked, with each visited node appending its
+	// children to the end.
+	nodes := []*node{s.root}
+	for i := 0; i < len(nodes); i++ {
+		for _, child := range nodes[i].children {
+			nodes = append(nodes, child)
+		}
+	}
+
+	// Only cancel once the whole tree has been collected.
+	for _, n := range nodes {
+		n.ctxC()
+	}
+}
+
+// processSchedule starts a node's runnable in a new goroutine. Once the runnable returns (or panics, unless panics
+// are propagated), the goroutine reports that to the processor as a died request.
+func (s *supervisor) processSchedule(r *processorRequestSchedule) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	n := s.nodeByDN(r.dn)
+
+	// reportDeath tells the processor that the runnable of r.dn has exited with err.
+	reportDeath := func(err error) {
+		s.pReq <- &processorRequest{
+			died: &processorRequestDied{dn: r.dn, err: err},
+		}
+	}
+
+	go func() {
+		if !s.propagatePanic {
+			// Turn a panic in the runnable into a death, carrying the panic value and the stack trace.
+			defer func() {
+				rec := recover()
+				if rec == nil {
+					return
+				}
+				reportDeath(fmt.Errorf("panic: %v, stacktrace: %s", rec, string(debug.Stack())))
+			}()
+		}
+
+		reportDeath(n.runnable(n.ctx))
+	}()
+}
+
+// processDied records the result from a runnable goroutine, and updates its node state accordingly.
+// If the result is a death rather than an expected exit, the node is marked DEAD. Its context is canceled, which also
+// cancels the contexts of its children, since those derive from it. The contexts of its group siblings are canceled
+// too, so the GC can later restart the whole group.
+func (s *supervisor) processDied(r *processorRequestDied) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Okay, so a Runnable has quit. What now?
+	n := s.nodeByDN(r.dn)
+	ctx := n.ctx
+
+	// Simple case: it was marked as Done and quit with no error.
+	if n.state == nodeStateDone && r.err == nil {
+		// Do nothing. This was supposed to happen. Keep the process as DONE.
+		return
+	}
+
+	// Simple case: the context was canceled and the runnable returned that context error, possibly wrapped.
+	// errors.Is walks the whole wrap chain, including error trees (errors.Join / multiple %w) and custom Is methods.
+	// Following errors.Unwrap down to the innermost error would miss those, and misreport a clean cancellation as a
+	// death (with restart backoff).
+	if ctxErr := ctx.Err(); ctxErr != nil && errors.Is(r.err, ctxErr) {
+		// Mark the node as canceled successfully.
+		n.state = nodeStateCanceled
+		return
+	}
+
+	// Otherwise, the Runnable should not have died or quit. Handle accordingly.
+	err := r.err
+	// A lack of returned error is also an error.
+	if err == nil {
+		err = fmt.Errorf("returned when %s", n.state)
+	} else {
+		err = fmt.Errorf("returned error when %s: %w", n.state, err)
+	}
+
+	s.ilogger.Errorf("Runnable %s died: %v", n.dn(), err)
+	// Mark as dead.
+	n.state = nodeStateDead
+
+	// Cancel that node's context, just in case something still depends on it.
+	n.ctxC()
+
+	// Cancel all siblings.
+	if n.parent != nil {
+		for name := range n.parent.groupSiblings(n.name) {
+			if name == n.name {
+				continue
+			}
+			sibling := n.parent.children[name]
+			// TODO(q3k): does this need to run in a goroutine, ie. can a context cancel block?
+			sibling.ctxC()
+		}
+	}
+}
+
+// processGC runs the GC process. It's not really Garbage Collection, as in, it doesn't remove unnecessary tree nodes -
+// but it does find nodes that need to be restarted, find the subset that can and then schedules them for running.
+// As such, it's less of a Garbage Collector and more of a Necromancer. However, GC is a friendlier name.
+// It is called by the processor only when the tree has changed since the last run.
+func (s *supervisor) processGC() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// The 'GC' is the main business logic of the supervision tree. It traverses a locked tree and tries to
+	// find subtrees that must be restarted (because of a DEAD/CANCELED runnable). It then finds which of these
+	// subtrees that should be restarted can be restarted, ie. which ones are fully recursively DEAD/CANCELED/DONE.
+	// It also finds the smallest set of largest subtrees that can be restarted, ie. if there's multiple DEAD
+	// runnables that can be restarted at once, it will do so.
+
+	// Phase one: Find all leaves.
+	// This is a simple DFS that finds all the leaves of the tree, ie all nodes that do not have children nodes.
+	// Children are prepended to the work queue, which makes the traversal depth-first (at the cost of copying the
+	// queue on every prepend).
+	leaves := make(map[string]bool)
+	queue := []*node{s.root}
+	for {
+		if len(queue) == 0 {
+			break
+		}
+		cur := queue[0]
+		queue = queue[1:]
+
+		for _, c := range cur.children {
+			queue = append([]*node{c}, queue...)
+		}
+
+		if len(cur.children) == 0 {
+			leaves[cur.dn()] = true
+		}
+	}
+
+	// Phase two: traverse the tree upwards, from the leaves to the root, and make note of all subtrees that can be
+	// restarted. A subtree is restartable/ready iff every node in that subtree is either CANCELED, DEAD or DONE.
+	// Such a 'ready' subtree can be restarted by the supervisor if needed.
+
+	// DNs that we already visited.
+	visited := make(map[string]bool)
+	// DNs whose subtrees are ready to be restarted.
+	// These are all subtrees recursively - ie., root.a.a and root.a will both be marked here.
+	ready := make(map[string]bool)
+
+	// We build a queue of nodes to visit, starting from the leaves.
+	queue = []*node{}
+	for l, _ := range leaves {
+		queue = append(queue, s.nodeByDN(l))
+	}
+
+	for {
+		if len(queue) == 0 {
+			break
+		}
+
+		cur := queue[0]
+		curDn := cur.dn()
+
+		queue = queue[1:]
+
+		// Do we have a decision about our children?
+		allVisited := true
+		for _, c := range cur.children {
+			if !visited[c.dn()] {
+				allVisited = false
+				break
+			}
+		}
+
+		// If no decision about children is available, it means we ended up in this subtree through some shorter path
+		// of a shorter/lower-order leaf. There is a path to a leaf that's longer than the one that caused this node
+		// to be enqueued. Easy solution: just push back the current element and retry later.
+		if !allVisited {
+			// Push back to queue and wait for a decision later.
+			queue = append(queue, cur)
+			continue
+		}
+
+		// All children have been visited and we have an idea about whether they're ready/restartable. All of the node's
+		// children must be restartable in order for this node to be restartable.
+		childrenReady := true
+		for _, c := range cur.children {
+			if !ready[c.dn()] {
+				childrenReady = false
+				break
+			}
+		}
+
+		// In addition to children, the node itself must be restartable (ie. DONE, DEAD or CANCELED).
+		curReady := false
+		switch cur.state {
+		case nodeStateDone:
+			curReady = true
+		case nodeStateCanceled:
+			curReady = true
+		case nodeStateDead:
+			curReady = true
+		}
+
+		// Note down that we have an opinion on this node, and note that opinion down.
+		visited[curDn] = true
+		ready[curDn] = childrenReady && curReady
+
+		// Now we can also enqueue the parent of this node for processing. A parent with several children may get
+		// enqueued more than once (once per child visited before it); re-evaluating it yields the same result.
+		if cur.parent != nil && !visited[cur.parent.dn()] {
+			queue = append(queue, cur.parent)
+		}
+	}
+
+	// Phase 3: traverse tree from root to find largest subtrees that need to be restarted and are ready to be
+	// restarted.
+
+	// All DNs that need to be restarted by the GC process.
+	want := make(map[string]bool)
+	// All DNs that need to be restarted and can be restarted by the GC process - a subset of 'want' DNs.
+	can := make(map[string]bool)
+	// The set difference between 'want' and 'can' are all nodes that should be restarted but can't yet (ie. because
+	// a child is still in the process of being canceled).
+
+	// Breadth-first traversal from root (children are appended to the end of the queue).
+	queue = []*node{s.root}
+	for {
+		if len(queue) == 0 {
+			break
+		}
+
+		cur := queue[0]
+		queue = queue[1:]
+
+		// If this node is DEAD or CANCELED it should be restarted.
+		if cur.state == nodeStateDead || cur.state == nodeStateCanceled {
+			want[cur.dn()] = true
+		}
+
+		// If it should be restarted and is ready to be restarted...
+		if want[cur.dn()] && ready[cur.dn()] {
+			// And its parent context is valid (ie hasn't been canceled), mark it as restartable. If the parent context
+			// is canceled, the parent itself is on its way to being restarted, which will take this node along.
+			if cur.parent == nil || cur.parent.ctx.Err() == nil {
+				can[cur.dn()] = true
+				// Restarting this node restarts its whole subtree (reset drops all children), so don't descend.
+				continue
+			}
+		}
+
+		// Otherwise, traverse further down the tree to see if something else needs to be done.
+		for _, c := range cur.children {
+			queue = append(queue, c)
+		}
+	}
+
+	// Reinitialize and reschedule all subtrees
+	for dn, _ := range can {
+		n := s.nodeByDN(dn)
+
+		// Only back off when the node unexpectedly died - not when it got canceled.
+		bo := time.Duration(0)
+		if n.state == nodeStateDead {
+			bo = n.bo.NextBackOff()
+		}
+
+		// Prepare node for rescheduling - remove its children, reset its state to new.
+		n.reset()
+		s.ilogger.Infof("rescheduling supervised node %s with backoff %s", dn, bo.String())
+
+		// Reschedule node runnable to run after backoff. This happens in a goroutine, as the processor (which would
+		// receive the request) is the one currently running this GC.
+		go func(n *node, bo time.Duration) {
+			time.Sleep(bo)
+			s.pReq <- &processorRequest{
+				schedule: &processorRequestSchedule{dn: n.dn()},
+			}
+		}(n, bo)
+	}
+}
diff --git a/metropolis/pkg/supervisor/supervisor_support.go b/metropolis/pkg/supervisor/supervisor_support.go
new file mode 100644
index 0000000..d54b35c
--- /dev/null
+++ b/metropolis/pkg/supervisor/supervisor_support.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+// Supporting infrastructure to allow running some non-Go payloads under supervision.
+
+import (
+ "context"
+ "net"
+ "os/exec"
+
+ "google.golang.org/grpc"
+)
+
+// GRPCServer creates a Runnable that serves gRPC requests on lis for as long as it's not canceled.
+// If graceful is set to true, the server will be gracefully stopped on cancellation instead of plain stopped. This means
+// all pending RPCs will finish, but also requires streaming gRPC handlers to check their context liveliness and exit
+// accordingly. If the server code does not support this, `graceful` should be false and the server will be killed
+// violently instead.
+// The Runnable returns ctx.Err() when canceled, or the error returned by Serve if serving stops on its own.
+func GRPCServer(srv *grpc.Server, lis net.Listener, graceful bool) Runnable {
+	return func(ctx context.Context) error {
+		Signal(ctx, SignalHealthy)
+		// Buffered so that the Serve goroutine can always deliver its result and exit. After cancellation, this
+		// runnable returns without reading errC, and Stop/GracefulStop then makes Serve return. With an unbuffered
+		// channel, that goroutine would block forever on its send, leaking one goroutine per restart.
+		errC := make(chan error, 1)
+		go func() {
+			errC <- srv.Serve(lis)
+		}()
+		select {
+		case <-ctx.Done():
+			if graceful {
+				srv.GracefulStop()
+			} else {
+				srv.Stop()
+			}
+			return ctx.Err()
+		case err := <-errC:
+			return err
+		}
+	}
+}
+
+// RunCommand has the shape of a Runnable body: it runs a long-running command under supervision, sending its stdout
+// and stderr to the runnable's raw logger. It signals healthy as soon as it is called. Since the command is expected to
+// run for as long as the runnable lives, any exit of the command (successful or not) is considered a failure. The exit
+// result is logged and returned.
+func RunCommand(ctx context.Context, cmd *exec.Cmd) error {
+	Signal(ctx, SignalHealthy)
+	cmd.Stdout, cmd.Stderr = RawLogger(ctx), RawLogger(ctx)
+
+	runErr := cmd.Run()
+	Logger(ctx).Infof("Command returned: %v", runErr)
+	return runErr
+}
diff --git a/metropolis/pkg/supervisor/supervisor_test.go b/metropolis/pkg/supervisor/supervisor_test.go
new file mode 100644
index 0000000..9c7bdb7
--- /dev/null
+++ b/metropolis/pkg/supervisor/supervisor_test.go
@@ -0,0 +1,557 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+)
+
+// runnableBecomesHealthy returns a Runnable that signals healthy and then waits for its context to be canceled.
+// Starting and stopping are also reported asynchronously on the optional healthy and done channels (nil channels are
+// skipped), so that tests can observe both transitions.
+func runnableBecomesHealthy(healthy, done chan struct{}) Runnable {
+	return func(ctx context.Context) error {
+		Signal(ctx, SignalHealthy)
+		if healthy != nil {
+			go func() { healthy <- struct{}{} }()
+		}
+
+		<-ctx.Done()
+		if done != nil {
+			go func() { done <- struct{}{} }()
+		}
+		return ctx.Err()
+	}
+}
+
+// runnableSpawnsMore returns a Runnable that first starts a binary tree of children `levels` deep (a group of "a" and
+// "b" at every level), then signals healthy and waits for its context to be canceled. Starting and stopping are also
+// reported asynchronously on the optional healthy and done channels (nil channels are skipped).
+func runnableSpawnsMore(healthy, done chan struct{}, levels int) Runnable {
+	return func(ctx context.Context) error {
+		if levels > 0 {
+			children := map[string]Runnable{
+				"a": runnableSpawnsMore(nil, nil, levels-1),
+				"b": runnableSpawnsMore(nil, nil, levels-1),
+			}
+			if err := RunGroup(ctx, children); err != nil {
+				return err
+			}
+		}
+
+		Signal(ctx, SignalHealthy)
+		if healthy != nil {
+			go func() { healthy <- struct{}{} }()
+		}
+
+		<-ctx.Done()
+		if done != nil {
+			go func() { done <- struct{}{} }()
+		}
+		return ctx.Err()
+	}
+}
+
+// rc is a Remote Controlled runnable. It is a generic runnable used for testing the supervisor: tests send it
+// commands (see rcRunnableCommand) to make it change its health state, fail, panic or report its state.
+type rc struct {
+ // req carries commands to the runnable. newRC makes it unbuffered, so every send blocks until the runnable reads it.
+ req chan rcRunnableRequest
+}
+
+// rcRunnableRequest is a single command sent to an rc runnable.
+type rcRunnableRequest struct {
+ // cmd is the action the runnable should perform.
+ cmd rcRunnableCommand
+ // stateC receives the current state of the runnable; only used with rcRunnableCommandState.
+ stateC chan rcRunnableState
+}
+
+// rcRunnableCommand is a command that a test can send to an rc runnable.
+type rcRunnableCommand int
+
+const (
+ // rcRunnableCommandBecomeHealthy makes the runnable signal healthy.
+ rcRunnableCommandBecomeHealthy rcRunnableCommand = iota
+ // rcRunnableCommandBecomeDone makes the runnable signal done.
+ rcRunnableCommandBecomeDone
+ // rcRunnableCommandDie makes the runnable return an error.
+ rcRunnableCommandDie
+ // rcRunnableCommandPanic makes the runnable panic.
+ rcRunnableCommandPanic
+ // rcRunnableCommandState makes the runnable send its current state on the request's stateC.
+ rcRunnableCommandState
+)
+
+// rcRunnableState is the state an rc runnable reports back to tests.
+type rcRunnableState int
+
+const (
+ // rcRunnableStateNew is the state of a freshly (re)started runnable.
+ rcRunnableStateNew rcRunnableState = iota
+ // rcRunnableStateHealthy is reached once the runnable signaled healthy.
+ rcRunnableStateHealthy
+ // rcRunnableStateDone is reached once the runnable signaled done.
+ rcRunnableStateDone
+)
+
+// becomeHealthy asks the runnable to signal healthy. It blocks until the runnable accepts the command.
+func (r *rc) becomeHealthy() {
+	req := rcRunnableRequest{cmd: rcRunnableCommandBecomeHealthy}
+	r.req <- req
+}
+
+// becomeDone asks the runnable to signal done. It blocks until the runnable accepts the command.
+func (r *rc) becomeDone() {
+	req := rcRunnableRequest{cmd: rcRunnableCommandBecomeDone}
+	r.req <- req
+}
+
+// die asks the runnable to return an error. It blocks until the runnable accepts the command.
+func (r *rc) die() {
+	req := rcRunnableRequest{cmd: rcRunnableCommandDie}
+	r.req <- req
+}
+
+// panic asks the runnable to panic. It blocks until the runnable accepts the command.
+func (r *rc) panic() {
+	req := rcRunnableRequest{cmd: rcRunnableCommandPanic}
+	r.req <- req
+}
+
+// state asks the runnable for its current state and blocks until it answers.
+func (r *rc) state() rcRunnableState {
+	reply := make(chan rcRunnableState)
+	r.req <- rcRunnableRequest{cmd: rcRunnableCommandState, stateC: reply}
+	return <-reply
+}
+
+// waitState polls the runnable every 10ms until it reports state s. It never gives up on its own.
+func (r *rc) waitState(s rcRunnableState) {
+	// This is poll based. Making it non-poll based would make the RC runnable logic a bit more complex for little gain.
+	for r.state() != s {
+		time.Sleep(10 * time.Millisecond)
+	}
+}
+
+// newRC returns a remote controlled runnable. Its request channel is unbuffered, so every command blocks until the
+// runnable picks it up.
+func newRC() *rc {
+	r := &rc{}
+	r.req = make(chan rcRunnableRequest)
+	return r
+}
+
+// runnable returns the Runnable driven by this rc. Every invocation starts in rcRunnableStateNew (so a restarted
+// runnable reports New again). It then executes commands from r.req until its context is canceled.
+func (r *rc) runnable() Runnable {
+	return func(ctx context.Context) error {
+		state := rcRunnableStateNew
+
+		for {
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			// Named req so it does not shadow the receiver r.
+			case req := <-r.req:
+				switch req.cmd {
+				case rcRunnableCommandBecomeHealthy:
+					Signal(ctx, SignalHealthy)
+					state = rcRunnableStateHealthy
+				case rcRunnableCommandBecomeDone:
+					Signal(ctx, SignalDone)
+					state = rcRunnableStateDone
+				case rcRunnableCommandDie:
+					return fmt.Errorf("died on request")
+				case rcRunnableCommandPanic:
+					panic("at the disco")
+				case rcRunnableCommandState:
+					req.stateC <- state
+				}
+			}
+		}
+	}
+}
+
+// TestSimple starts two runnables in a group and checks that both of them start running.
+func TestSimple(t *testing.T) {
+	h1 := make(chan struct{})
+	d1 := make(chan struct{})
+	h2 := make(chan struct{})
+	d2 := make(chan struct{})
+
+	// Bounded, so that a supervisor that never settles fails the test instead of hanging it.
+	ctx, ctxC := context.WithTimeout(context.Background(), 10*time.Second)
+	defer ctxC()
+	s := New(ctx, func(ctx context.Context) error {
+		err := RunGroup(ctx, map[string]Runnable{
+			"one": runnableBecomesHealthy(h1, d1),
+			"two": runnableBecomesHealthy(h2, d2),
+		})
+		if err != nil {
+			return err
+		}
+		Signal(ctx, SignalHealthy)
+		Signal(ctx, SignalDone)
+		return nil
+	}, WithPropagatePanic)
+
+	// Expect both to start running.
+	s.waitSettleError(ctx, t)
+	select {
+	case <-h1:
+	default:
+		t.Fatalf("runnable 'one' didn't start")
+	}
+	select {
+	case <-h2:
+	default:
+		t.Fatalf("runnable 'two' didn't start")
+	}
+}
+
+// TestSimpleFailure checks that when one runnable of a group dies, its sibling is canceled and started again.
+func TestSimpleFailure(t *testing.T) {
+	h1 := make(chan struct{})
+	d1 := make(chan struct{})
+	two := newRC()
+
+	// expectOne fails the test unless runnable 'one' has already reported on c.
+	expectOne := func(c chan struct{}, what string) {
+		t.Helper()
+		select {
+		case <-c:
+		default:
+			t.Fatalf("runnable 'one' didn't %s", what)
+		}
+	}
+
+	ctx, ctxC := context.WithTimeout(context.Background(), 10*time.Second)
+	defer ctxC()
+	root := func(ctx context.Context) error {
+		group := map[string]Runnable{
+			"one": runnableBecomesHealthy(h1, d1),
+			"two": two.runnable(),
+		}
+		if err := RunGroup(ctx, group); err != nil {
+			return err
+		}
+		Signal(ctx, SignalHealthy)
+		Signal(ctx, SignalDone)
+		return nil
+	}
+	s := New(ctx, root, WithPropagatePanic)
+	s.waitSettleError(ctx, t)
+
+	two.becomeHealthy()
+	s.waitSettleError(ctx, t)
+	// Expect one to start running.
+	expectOne(h1, "start")
+
+	// Kill off two, one should restart.
+	two.die()
+	s.waitSettleError(ctx, t)
+	expectOne(d1, "acknowledge cancel")
+
+	// And one should start running again.
+	s.waitSettleError(ctx, t)
+	expectOne(h1, "restart")
+}
+
+// TestDeepFailure checks that when a runnable dies, a sibling with a deep tree of children is canceled and restarted.
+func TestDeepFailure(t *testing.T) {
+	h1 := make(chan struct{})
+	d1 := make(chan struct{})
+	two := newRC()
+
+	// expectOne fails the test unless runnable 'one' has already reported on c.
+	expectOne := func(c chan struct{}, what string) {
+		t.Helper()
+		select {
+		case <-c:
+		default:
+			t.Fatalf("runnable 'one' didn't %s", what)
+		}
+	}
+
+	ctx, ctxC := context.WithTimeout(context.Background(), 10*time.Second)
+	defer ctxC()
+	root := func(ctx context.Context) error {
+		group := map[string]Runnable{
+			"one": runnableSpawnsMore(h1, d1, 5),
+			"two": two.runnable(),
+		}
+		if err := RunGroup(ctx, group); err != nil {
+			return err
+		}
+		Signal(ctx, SignalHealthy)
+		Signal(ctx, SignalDone)
+		return nil
+	}
+	s := New(ctx, root, WithPropagatePanic)
+
+	two.becomeHealthy()
+	s.waitSettleError(ctx, t)
+	// Expect one to start running.
+	expectOne(h1, "start")
+
+	// Kill off two, one should restart.
+	two.die()
+	s.waitSettleError(ctx, t)
+	expectOne(d1, "acknowledge cancel")
+
+	// And one should start running again.
+	s.waitSettleError(ctx, t)
+	expectOne(h1, "restart")
+}
+
+// TestPanic checks that a panicking runnable is handled like a failing one: its sibling is canceled and restarted.
+// The supervisor is deliberately created without WithPropagatePanic, so the panic is recovered.
+func TestPanic(t *testing.T) {
+	h1 := make(chan struct{})
+	d1 := make(chan struct{})
+	two := newRC()
+
+	// Bounded like the other failure tests, so a supervisor that never settles fails the test instead of hanging it.
+	ctx, ctxC := context.WithTimeout(context.Background(), 10*time.Second)
+	defer ctxC()
+	s := New(ctx, func(ctx context.Context) error {
+		err := RunGroup(ctx, map[string]Runnable{
+			"one": runnableBecomesHealthy(h1, d1),
+			"two": two.runnable(),
+		})
+		if err != nil {
+			return err
+		}
+		Signal(ctx, SignalHealthy)
+		Signal(ctx, SignalDone)
+		return nil
+	})
+
+	two.becomeHealthy()
+	s.waitSettleError(ctx, t)
+	// Expect one to start running.
+	select {
+	case <-h1:
+	default:
+		t.Fatalf("runnable 'one' didn't start")
+	}
+
+	// Kill off two, one should restart.
+	two.panic()
+	s.waitSettleError(ctx, t)
+	select {
+	case <-d1:
+	default:
+		t.Fatalf("runnable 'one' didn't acknowledge cancel")
+	}
+
+	// And one should start running again.
+	s.waitSettleError(ctx, t)
+	select {
+	case <-h1:
+	default:
+		t.Fatalf("runnable 'one' didn't restart")
+	}
+}
+
+// TestMultipleLevelFailure starts two deep trees of runnables and waits for the supervisor to settle with all of them
+// running.
+// NOTE(review): despite its name, no failure is injected here; this is a smoke test of deep trees.
+func TestMultipleLevelFailure(t *testing.T) {
+	ctx, ctxC := context.WithTimeout(context.Background(), 10*time.Second)
+	defer ctxC()
+	s := New(ctx, func(ctx context.Context) error {
+		err := RunGroup(ctx, map[string]Runnable{
+			"one": runnableSpawnsMore(nil, nil, 4),
+			"two": runnableSpawnsMore(nil, nil, 4),
+		})
+		if err != nil {
+			return err
+		}
+		Signal(ctx, SignalHealthy)
+		Signal(ctx, SignalDone)
+		return nil
+	}, WithPropagatePanic)
+
+	// Without waiting, the deferred cancel could tear the supervisor down before the tree was even built, and the test
+	// would exercise nothing.
+	s.waitSettleError(ctx, t)
+}
+
+// TestBackoff checks that a repeatedly failing runnable is restarted with a growing delay, and that becoming healthy
+// (and the supervisor settling) resets that delay.
+func TestBackoff(t *testing.T) {
+	one := newRC()
+
+	ctx, ctxC := context.WithTimeout(context.Background(), 20*time.Second)
+	defer ctxC()
+
+	s := New(ctx, func(ctx context.Context) error {
+		if err := Run(ctx, "one", one.runnable()); err != nil {
+			return err
+		}
+		Signal(ctx, SignalHealthy)
+		Signal(ctx, SignalDone)
+		return nil
+	}, WithPropagatePanic)
+
+	// restartTime kills 'one', makes its next incarnation healthy, and returns how long that took.
+	restartTime := func() time.Duration {
+		start := time.Now()
+		one.die()
+		one.becomeHealthy()
+		one.waitState(rcRunnableStateHealthy)
+		return time.Since(start)
+	}
+
+	one.becomeHealthy()
+	// Die a bunch of times in a row, this brings up the next exponential backoff to over a second.
+	for i := 0; i < 4; i++ {
+		one.die()
+		one.waitState(rcRunnableStateNew)
+	}
+	// Measure how long it takes for the runnable to respawn after a number of failures
+	if taken := restartTime(); taken < time.Second {
+		t.Errorf("Runnable took %v to restart, wanted at least a second from backoff", taken)
+	}
+
+	s.waitSettleError(ctx, t)
+	// Now that we've become healthy, die again. Becoming healthy resets the backoff.
+	if taken := restartTime(); taken > time.Second || taken < 100*time.Millisecond {
+		t.Errorf("Runnable took %v to restart, wanted at least 100ms from backoff and at most 1s from backoff reset", taken)
+	}
+}
+
+// TestResilience throws some curveballs at the supervisor - either programming errors or high load. It then ensures
+// that another runnable is running, and that it restarts on its sibling failure.
+func TestResilience(t *testing.T) {
+	// request/response channel for testing liveness of the 'one' runnable
+	req := make(chan chan struct{})
+
+	// A runnable that responds on the 'req' channel.
+	one := func(ctx context.Context) error {
+		Signal(ctx, SignalHealthy)
+		for {
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case r := <-req:
+				r <- struct{}{}
+			}
+		}
+	}
+	oneSibling := newRC()
+
+	// oneTest pings 'one' and fails the test if it does not answer. Both handing over the request and receiving the
+	// response are bounded, so a dead 'one' fails the test instead of hanging it forever.
+	oneTest := func() {
+		// Buffered so that 'one' never blocks on answering, even after we gave up waiting for it.
+		ping := make(chan struct{}, 1)
+
+		// 'one' might be in the middle of a (backed off) restart, so give it generous time to pick up the request.
+		reqTimeout := time.NewTimer(10 * time.Second)
+		defer reqTimeout.Stop()
+		select {
+		case req <- ping:
+		case <-reqTimeout.C:
+			t.Fatalf("one ping request timeout")
+		}
+
+		respTimeout := time.NewTimer(1000 * time.Millisecond)
+		defer respTimeout.Stop()
+		select {
+		case <-ping:
+		case <-respTimeout.C:
+			t.Fatalf("one ping response timeout")
+		}
+	}
+
+	// A nasty runnable that calls Signal with the wrong context (this is a programming error)
+	two := func(ctx context.Context) error {
+		Signal(context.TODO(), SignalHealthy)
+		return nil
+	}
+
+	// A nasty runnable that calls Signal wrong (this is a programming error).
+	three := func(ctx context.Context) error {
+		Signal(ctx, SignalDone)
+		return nil
+	}
+
+	// A nasty runnable that runs in a busy loop (this is a programming error).
+	four := func(ctx context.Context) error {
+		for {
+			time.Sleep(0)
+		}
+	}
+
+	// A nasty runnable that keeps creating more runnables.
+	five := func(ctx context.Context) error {
+		i := 1
+		for {
+			err := Run(ctx, fmt.Sprintf("r%d", i), runnableSpawnsMore(nil, nil, 2))
+			if err != nil {
+				return err
+			}
+
+			time.Sleep(100 * time.Millisecond)
+			i++
+		}
+	}
+
+	ctx, ctxC := context.WithCancel(context.Background())
+	defer ctxC()
+	New(ctx, func(ctx context.Context) error {
+		err := RunGroup(ctx, map[string]Runnable{
+			"one":        one,
+			"oneSibling": oneSibling.runnable(),
+		})
+		if err != nil {
+			return err
+		}
+		rs := map[string]Runnable{
+			"two": two, "three": three, "four": four, "five": five,
+		}
+		for k, v := range rs {
+			if err := Run(ctx, k, v); err != nil {
+				return err
+			}
+		}
+		Signal(ctx, SignalHealthy)
+		Signal(ctx, SignalDone)
+		return nil
+	})
+
+	// Five rounds of letting one run, then restarting it.
+	for i := 0; i < 5; i++ {
+		oneSibling.becomeHealthy()
+		oneSibling.waitState(rcRunnableStateHealthy)
+
+		// 'one' should work for at least a second.
+		deadline := time.Now().Add(1 * time.Second)
+		for !time.Now().After(deadline) {
+			oneTest()
+		}
+
+		// Killing 'oneSibling' should restart one.
+		oneSibling.panic()
+	}
+	// Make sure 'one' is still okay.
+	oneTest()
+}
+
+// ExampleNew shows how to start a supervision tree whose root runnable starts a child and then keeps doing periodic
+// work until its context is canceled. It has no Output comment, so `go test` compiles it but does not run it.
+func ExampleNew() {
+	// Minimal runnable that is immediately done.
+	childStarted := make(chan struct{})
+	child := func(ctx context.Context) error {
+		Signal(ctx, SignalHealthy)
+		close(childStarted)
+		Signal(ctx, SignalDone)
+		return nil
+	}
+
+	// Start a supervision tree with a root runnable.
+	ctx, ctxC := context.WithCancel(context.Background())
+	defer ctxC()
+	root := func(ctx context.Context) error {
+		if err := Run(ctx, "child", child); err != nil {
+			return fmt.Errorf("could not run 'child': %w", err)
+		}
+		Signal(ctx, SignalHealthy)
+
+		ticker := time.NewTicker(time.Second)
+		defer ticker.Stop()
+
+		// Do something in the background, and exit on context cancel.
+		for {
+			select {
+			case <-ticker.C:
+				fmt.Printf("tick!")
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+		}
+	}
+	New(ctx, root)
+
+	// root.child will close this channel.
+	<-childStarted
+}
diff --git a/metropolis/pkg/supervisor/supervisor_testhelpers.go b/metropolis/pkg/supervisor/supervisor_testhelpers.go
new file mode 100644
index 0000000..771e02f
--- /dev/null
+++ b/metropolis/pkg/supervisor/supervisor_testhelpers.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package supervisor
+
+import (
+ "context"
+ "testing"
+)
+
+// waitSettle waits until the supervisor reaches a 'settled' state - ie., one
+// where no actions have been performed for a number of GC cycles.
+// This is used in tests only.
+//
+// Both handing the request to the processor and waiting for its answer are bounded by ctx, so a processor that is no
+// longer consuming requests makes this return ctx.Err() instead of blocking forever.
+func (s *supervisor) waitSettle(ctx context.Context) error {
+	waiter := make(chan struct{})
+	req := &processorRequest{
+		waitSettled: &processorRequestWaitSettled{
+			waiter: waiter,
+		},
+	}
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case s.pReq <- req:
+	}
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-waiter:
+		return nil
+	}
+}
+
+// waitSettleError wraps waitSettle to fail a test if an error occurs, eg. the
+// context is canceled. It is marked as a test helper, so failures are reported at the caller's line.
+func (s *supervisor) waitSettleError(ctx context.Context, t *testing.T) {
+	t.Helper()
+	if err := s.waitSettle(ctx); err != nil {
+		t.Fatalf("waitSettle: %v", err)
+	}
+}
diff --git a/metropolis/pkg/sysfs/BUILD.bazel b/metropolis/pkg/sysfs/BUILD.bazel
new file mode 100644
index 0000000..0cea1f8
--- /dev/null
+++ b/metropolis/pkg/sysfs/BUILD.bazel
@@ -0,0 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Helpers for reading Linux sysfs files; currently uevent parsing (ReadUevents).
+go_library(
+ name = "go_default_library",
+ srcs = ["uevents.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/sysfs",
+ visibility = ["//visibility:public"],
+)
diff --git a/metropolis/pkg/sysfs/uevents.go b/metropolis/pkg/sysfs/uevents.go
new file mode 100644
index 0000000..fed4319
--- /dev/null
+++ b/metropolis/pkg/sysfs/uevents.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sysfs
+
+import (
+ "bufio"
+ "io"
+ "os"
+ "strings"
+)
+
+// ReadUevents parses a uevent file (lines of the form KEY=VALUE, e.g. /sys/.../uevent) into a map from key to value.
+// Everything after the first '=' on a line belongs to the value, which has surrounding whitespace trimmed. Lines
+// without a '=' are ignored. A final line without a trailing newline is still parsed (it used to be dropped).
+func ReadUevents(filename string) (map[string]string, error) {
+	f, err := os.Open(filename)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+
+	ueventMap := make(map[string]string)
+	reader := bufio.NewReader(f)
+	for {
+		line, err := reader.ReadString('\n')
+		if err != nil && err != io.EOF {
+			return nil, err
+		}
+		// Parse per line, so a line without '=' cannot be merged into the key of the following line.
+		if sep := strings.IndexByte(line, '='); sep >= 0 {
+			ueventMap[line[:sep]] = strings.TrimSpace(line[sep+1:])
+		}
+		if err == io.EOF {
+			return ueventMap, nil
+		}
+	}
+}
diff --git a/metropolis/pkg/tpm/BUILD.bazel b/metropolis/pkg/tpm/BUILD.bazel
new file mode 100644
index 0000000..d06ff37
--- /dev/null
+++ b/metropolis/pkg/tpm/BUILD.bazel
@@ -0,0 +1,22 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# TPM 2.0 helper library. credactivation_compat.go carries a locally patched copy of go-tpm's credential activation
+# code (see the comment at the top of that file).
+# NOTE(review): visibility is public for now; the commit message announces a follow-up diff updating visibility rules.
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "credactivation_compat.go",
+ "tpm.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/tpm",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//metropolis/pkg/logtree:go_default_library",
+ "//metropolis/pkg/sysfs:go_default_library",
+ "@com_github_gogo_protobuf//proto:go_default_library",
+ "@com_github_google_go_tpm//tpm2:go_default_library",
+ "@com_github_google_go_tpm//tpmutil:go_default_library",
+ "@com_github_google_go_tpm_tools//proto:go_default_library",
+ "@com_github_google_go_tpm_tools//tpm2tools:go_default_library",
+ "@com_github_pkg_errors//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/metropolis/pkg/tpm/credactivation_compat.go b/metropolis/pkg/tpm/credactivation_compat.go
new file mode 100644
index 0000000..039f8d5
--- /dev/null
+++ b/metropolis/pkg/tpm/credactivation_compat.go
@@ -0,0 +1,123 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tpm
+
+// This file is adapted from github.com/google/go-tpm/tpm2/credactivation, which outputs broken
+// challenges for unknown reasons. They use u16 length-delimited outputs for the challenge blobs,
+// which is incorrect. Rather than rewriting the routine, we only applied minimal fixes to it
+// and skipped its ECC part (because we would rather trust the proprietary RSA implementation).
+//
+// TODO(lorenz): I'll eventually deal with this upstream, but for now just fix it here (it's not that
+// much code after all). See https://github.com/google/go-tpm/issues/121.
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/hmac"
+ "crypto/rsa"
+ "fmt"
+ "io"
+
+ "github.com/google/go-tpm/tpm2"
+ "github.com/google/go-tpm/tpmutil"
+)
+
+// Labels used by generateRSA: labelIdentity is the RSA-OAEP label for the encrypted seed, labelStorage is the KDFa
+// label for the symmetric encryption key and labelIntegrity is the KDFa label for the HMAC key.
+const (
+ labelIdentity = "IDENTITY"
+ labelStorage = "STORAGE"
+ labelIntegrity = "INTEGRITY"
+)
+
+// generateRSA builds a credential activation challenge for the key named by aik, to be decrypted by the TPM holding
+// the private part of pub (an RSA EK). A random seed of symBlockSize bytes is encrypted to pub with RSA-OAEP. From
+// that seed it derives:
+//   - an AES key of symBlockSize bytes (16 for AES-128), which encrypts secret;
+//   - an HMAC key, which authenticates the result.
+// It returns the tpmutil-packed IDObject (the credential blob) and the tpmutil-packed encrypted seed. All randomness
+// is read from rnd.
+func generateRSA(aik *tpm2.HashValue, pub *rsa.PublicKey, symBlockSize int, secret []byte, rnd io.Reader) ([]byte, []byte, error) {
+	newAIKHash, err := aik.Alg.HashConstructor()
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// The seed length should match the keysize used by the EKs symmetric cipher.
+	// For typical RSA EKs, this will be 128 bits (16 bytes).
+	// Spec: TCG 2.0 EK Credential Profile revision 14, section 2.1.5.1.
+	seed := make([]byte, symBlockSize)
+	if _, err := io.ReadFull(rnd, seed); err != nil {
+		return nil, nil, fmt.Errorf("generating seed: %w", err)
+	}
+
+	// Encrypt the seed value using the provided public key.
+	// See annex B, section 10.4 of the TPM specification revision 2 part 1.
+	label := append([]byte(labelIdentity), 0)
+	encSecret, err := rsa.EncryptOAEP(newAIKHash(), rnd, pub, seed, label)
+	if err != nil {
+		return nil, nil, fmt.Errorf("generating encrypted seed: %w", err)
+	}
+
+	// Generate the encrypted credential by convolving the seed with the digest of
+	// the AIK, and using the result as the key to encrypt the secret.
+	// See section 24.4 of TPM 2.0 specification, part 1.
+	aikNameEncoded, err := aik.Encode()
+	if err != nil {
+		return nil, nil, fmt.Errorf("encoding aikName: %w", err)
+	}
+	symmetricKey, err := tpm2.KDFa(aik.Alg, seed, labelStorage, aikNameEncoded, nil, len(seed)*8)
+	if err != nil {
+		return nil, nil, fmt.Errorf("generating symmetric key: %w", err)
+	}
+	c, err := aes.NewCipher(symmetricKey)
+	if err != nil {
+		return nil, nil, fmt.Errorf("symmetric cipher setup: %w", err)
+	}
+	cv, err := tpmutil.Pack(tpmutil.U16Bytes(secret))
+	if err != nil {
+		return nil, nil, fmt.Errorf("generating cv (TPM2B_Digest): %w", err)
+	}
+
+	// IV is all null bytes. encIdentity represents the encrypted credential.
+	// The IV must be exactly one cipher block long (CFB panics otherwise). Sizing it from the key length only worked
+	// for AES-128, where key and block size happen to both be 16 bytes.
+	encIdentity := make([]byte, len(cv))
+	cipher.NewCFBEncrypter(c, make([]byte, c.BlockSize())).XORKeyStream(encIdentity, cv)
+
+	// Generate the integrity HMAC, which is used to protect the integrity of the
+	// encrypted structure.
+	// See section 24.5 of the TPM specification revision 2 part 1.
+	macKey, err := tpm2.KDFa(aik.Alg, seed, labelIntegrity, nil, nil, newAIKHash().Size()*8)
+	if err != nil {
+		return nil, nil, fmt.Errorf("generating HMAC key: %w", err)
+	}
+
+	mac := hmac.New(newAIKHash, macKey)
+	mac.Write(encIdentity)
+	mac.Write(aikNameEncoded)
+	integrityHMAC := mac.Sum(nil)
+
+	idObject := &tpm2.IDObject{
+		IntegrityHMAC: integrityHMAC,
+		EncIdentity:   encIdentity,
+	}
+	id, err := tpmutil.Pack(idObject)
+	if err != nil {
+		return nil, nil, fmt.Errorf("encoding IDObject: %w", err)
+	}
+
+	packedID, err := tpmutil.Pack(id)
+	if err != nil {
+		return nil, nil, fmt.Errorf("packing id: %w", err)
+	}
+	packedEncSecret, err := tpmutil.Pack(encSecret)
+	if err != nil {
+		return nil, nil, fmt.Errorf("packing encSecret: %w", err)
+	}
+
+	return packedID, packedEncSecret, nil
+}
diff --git a/metropolis/pkg/tpm/eventlog/BUILD.bazel b/metropolis/pkg/tpm/eventlog/BUILD.bazel
new file mode 100644
index 0000000..94a7ee9
--- /dev/null
+++ b/metropolis/pkg/tpm/eventlog/BUILD.bazel
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# TCG event log parsing and replay, adapted from go-attestation (its license is in LICENSE-3RD-PARTY.txt).
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "compat.go",
+ "eventlog.go",
+ "secureboot.go",
+ ],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/tpm/eventlog",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//metropolis/pkg/tpm/eventlog/internal:go_default_library",
+ "@com_github_google_certificate_transparency_go//x509:go_default_library",
+ "@com_github_google_go_tpm//tpm2:go_default_library",
+ ],
+)
diff --git a/metropolis/pkg/tpm/eventlog/LICENSE-3RD-PARTY.txt b/metropolis/pkg/tpm/eventlog/LICENSE-3RD-PARTY.txt
new file mode 100644
index 0000000..2d3298c
--- /dev/null
+++ b/metropolis/pkg/tpm/eventlog/LICENSE-3RD-PARTY.txt
@@ -0,0 +1,12 @@
+Copyright 2020 Google Inc.
+Licensed under the Apache License, Version 2.0 (the "License"); you may not
+use this file except in compliance with the License. You may obtain a copy of
+the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+License for the specific language governing permissions and limitations under
+the License.
\ No newline at end of file
diff --git a/metropolis/pkg/tpm/eventlog/compat.go b/metropolis/pkg/tpm/eventlog/compat.go
new file mode 100644
index 0000000..f83972b
--- /dev/null
+++ b/metropolis/pkg/tpm/eventlog/compat.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventlog
+
+// This file contains compatibility functions for our TPM library
+
+import (
+ "crypto"
+)
+
+// ConvertRawPCRs converts from raw PCRs to eventlog PCR structures
+func ConvertRawPCRs(pcrs [][]byte) []PCR {
+ var evPCRs []PCR
+ for i, digest := range pcrs {
+ evPCRs = append(evPCRs, PCR{DigestAlg: crypto.SHA256, Index: i, Digest: digest})
+ }
+ return evPCRs
+}
diff --git a/metropolis/pkg/tpm/eventlog/eventlog.go b/metropolis/pkg/tpm/eventlog/eventlog.go
new file mode 100644
index 0000000..49a8a26
--- /dev/null
+++ b/metropolis/pkg/tpm/eventlog/eventlog.go
@@ -0,0 +1,646 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Taken and pruned from go-attestation revision 2453c8f39a4ff46009f6a9db6fb7c6cca789d9a1 under Apache 2.0
+
+package eventlog
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "sort"
+
+ // Ensure hashes are available.
+ _ "crypto/sha256"
+
+ "github.com/google/go-tpm/tpm2"
+)
+
+// HashAlg identifies a hashing Algorithm. Its values are the go-tpm tpm2 algorithm IDs converted to HashAlg.
+type HashAlg uint8
+
+// Valid hash algorithms.
+var (
+ HashSHA1 = HashAlg(tpm2.AlgSHA1)
+ HashSHA256 = HashAlg(tpm2.AlgSHA256)
+)
+
+func (a HashAlg) cryptoHash() crypto.Hash {
+ switch a {
+ case HashSHA1:
+ return crypto.SHA1
+ case HashSHA256:
+ return crypto.SHA256
+ }
+ return 0
+}
+
+// goTPMAlg returns the go-tpm algorithm ID corresponding to a, or 0 if a is not a supported algorithm.
+func (a HashAlg) goTPMAlg() tpm2.Algorithm {
+	if a == HashSHA1 {
+		return tpm2.AlgSHA1
+	}
+	if a == HashSHA256 {
+		return tpm2.AlgSHA256
+	}
+	return 0
+}
+
+// String returns a human-friendly representation of the hash algorithm, falling back to its numeric value for
+// unsupported algorithms.
+func (a HashAlg) String() string {
+	if a == HashSHA1 {
+		return "SHA1"
+	}
+	if a == HashSHA256 {
+		return "SHA256"
+	}
+	return fmt.Sprintf("HashAlg<%d>", int(a))
+}
+
+// ReplayError describes the parsed events that failed to verify against
+// a particular PCR.
+type ReplayError struct {
+ // Events holds the parsed events. NOTE(review): where this is populated is not visible in this chunk.
+ Events []Event
+ // invalidPCRs holds the indices of the PCRs whose replay did not match.
+ invalidPCRs []int
+}
+
+// affected reports whether pcr is one of the registers that failed to replay.
+func (e ReplayError) affected(pcr int) bool {
+	for i := range e.invalidPCRs {
+		if e.invalidPCRs[i] == pcr {
+			return true
+		}
+	}
+	return false
+}
+
+// Error returns a human-friendly description of replay failures, listing the PCRs that did not replay.
+func (e ReplayError) Error() string {
+	return "event log failed to verify: the following registers failed to replay: " + fmt.Sprint(e.invalidPCRs)
+}
+
+// TPM algorithms. See the TPM 2.0 specification section 6.3.
+// These are raw TPM_ALG_ID values; presumably used when decoding algorithm IDs from the binary log - confirm.
+//
+// https://trustedcomputinggroup.org/wp-content/uploads/TPM-Rev-2.0-Part-2-Structures-01.38.pdf#page=42
+const (
+ algSHA1 uint16 = 0x0004
+ algSHA256 uint16 = 0x000B
+)
+
+// EventType indicates what kind of data an event is reporting.
+// NOTE(review): values presumably are the TCG event types (EV_*) as read from the log - confirm.
+type EventType uint32
+
+// Event is a single event from a TCG event log. This reports discrete items such
+// as BIOS measurements or EFI states.
+type Event struct {
+ // order of the event in the event log.
+ sequence int
+
+ // PCR index of the event.
+ Index int
+ // Type of the event.
+ Type EventType
+
+ // Data of the event. For certain kinds of events, this must match the event
+ // digest to be valid.
+ Data []byte
+ // Digest is the verified digest of the event data. While an event can have
+ // multiple digests for different hash algorithms, this is the one that was matched to the
+ // PCR value.
+ Digest []byte
+
+ // TODO(ericchiang): Provide examples or links for which event types must
+ // match their data to their digest.
+}
+
+func (e *Event) digestEquals(b []byte) error {
+ if len(e.Digest) == 0 {
+ return errors.New("no digests present")
+ }
+
+ switch len(e.Digest) {
+ case crypto.SHA256.Size():
+ s := sha256.Sum256(b)
+ if bytes.Equal(s[:], e.Digest) {
+ return nil
+ }
+ case crypto.SHA1.Size():
+ s := sha1.Sum(b)
+ if bytes.Equal(s[:], e.Digest) {
+ return nil
+ }
+ default:
+ return fmt.Errorf("cannot compare hash of length %d", len(e.Digest))
+ }
+
+ return fmt.Errorf("digest (len %d) does not match", len(e.Digest))
+}
+
+// EventLog is a parsed measurement log. This contains unverified data representing
+// boot events that must be replayed against PCR values to determine authenticity.
+type EventLog struct {
+ // Algs holds the set of algorithms that the event log uses.
+ Algs []HashAlg
+
+ // rawEvents are the parsed, unverified events that Verify replays. Presumably in log order (inject derives the
+ // next sequence number from the last entry) - confirm against the parser.
+ rawEvents []rawEvent
+}
+
+// clone returns a copy of the event log whose Algs and rawEvents slices have their own backing arrays, so appending to
+// the copy does not affect e. The rawEvent values themselves are copied shallowly (their slices stay shared).
+func (e *EventLog) clone() *EventLog {
+	algs := make([]HashAlg, len(e.Algs))
+	copy(algs, e.Algs)
+	events := make([]rawEvent, len(e.rawEvents))
+	copy(events, e.rawEvents)
+	return &EventLog{Algs: algs, rawEvents: events}
+}
+
+// elWorkaround is a fix for a known event log defect. Verify applies it to a clone of the log and only tries it when
+// affectedPCR is among the registers that failed to replay.
+type elWorkaround struct {
+ // id names the workaround in error messages.
+ id string
+ // affectedPCR is the PCR index that must have failed to replay for this workaround to be tried.
+ affectedPCR int
+ // apply modifies the given event log in place.
+ apply func(e *EventLog) error
+}
+
+// inject3 appends three new events, data1, data2 and data3 in that order, for the given PCR to the event log. It
+// stops at the first error.
+func inject3(e *EventLog, pcr int, data1, data2, data3 string) error {
+	for _, data := range []string{data1, data2, data3} {
+		if err := inject(e, pcr, data); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// inject2 appends two new events, data1 then data2, for the given PCR to the event log. It stops at the first error.
+func inject2(e *EventLog, pcr int, data1, data2 string) error {
+	for _, data := range []string{data1, data2} {
+		if err := inject(e, pcr, data); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// inject appends a synthetic event carrying data for the given PCR to the event log. It gives the event one digest per
+// algorithm in e.Algs and a sequence number following the last existing event (0 for an empty log). It returns an
+// error instead of panicking if one of the log's algorithms is unsupported or not linked into the binary.
+func inject(e *EventLog, pcr int, data string) error {
+	payload := []byte(data)
+	seq := 0
+	if n := len(e.rawEvents); n > 0 {
+		seq = e.rawEvents[n-1].sequence + 1
+	}
+	evt := rawEvent{
+		data:     payload,
+		index:    pcr,
+		sequence: seq,
+	}
+	for _, alg := range e.Algs {
+		h := alg.cryptoHash()
+		if !h.Available() {
+			return fmt.Errorf("cannot inject event: hash algorithm %v is not available", alg)
+		}
+		hasher := h.New()
+		hasher.Write(payload)
+		evt.digests = append(evt.digests, digest{hash: h, data: hasher.Sum(nil)})
+	}
+	e.rawEvents = append(e.rawEvents, evt)
+	return nil
+}
+
+// Event data strings for the EFI Exit Boot Services events that eventlogWorkarounds injects into PCR 5.
+const (
+ ebsInvocation = "Exit Boot Services Invocation"
+ ebsSuccess = "Exit Boot Services Returned with Success"
+ ebsFailure = "Exit Boot Services Returned with Failure"
+)
+
+// eventlogWorkarounds are tried in order by Verify when PCR 5 fails to replay. Each one appends a different
+// combination of Exit Boot Services events to a copy of the log.
+// NOTE(review): presumably for firmware that extends PCR 5 with these events without logging them - confirm.
+var eventlogWorkarounds = []elWorkaround{
+ {
+ id: "EBS Invocation + Success",
+ affectedPCR: 5,
+ apply: func(e *EventLog) error {
+ return inject2(e, 5, ebsInvocation, ebsSuccess)
+ },
+ },
+ {
+ id: "EBS Invocation + Failure",
+ affectedPCR: 5,
+ apply: func(e *EventLog) error {
+ return inject2(e, 5, ebsInvocation, ebsFailure)
+ },
+ },
+ {
+ id: "EBS Invocation + Failure + Success",
+ affectedPCR: 5,
+ apply: func(e *EventLog) error {
+ return inject3(e, 5, ebsInvocation, ebsFailure, ebsSuccess)
+ },
+ },
+}
+
+// Verify replays the event log against a TPM's PCR values, returning the
+// events which could be matched to a provided PCR value.
+// An error is returned if the replayed digest for events with a given PCR
+// index do not match any provided value for that PCR index.
+//
+// If the plain replay fails with a ReplayError, each workaround whose affected PCR failed is applied to a copy of the
+// log. The first one that makes the copy verify wins. If none does, the result of the plain replay is returned.
+func (e *EventLog) Verify(pcrs []PCR) ([]Event, error) {
+	events, err := e.verify(pcrs)
+	rErr, isReplayErr := err.(ReplayError)
+	if !isReplayErr {
+		return events, err
+	}
+
+	// TODO(jsonp): Allow workarounds to be combined.
+	for _, wkrd := range eventlogWorkarounds {
+		if !rErr.affected(wkrd.affectedPCR) {
+			continue
+		}
+		patched := e.clone()
+		if applyErr := wkrd.apply(patched); applyErr != nil {
+			return nil, fmt.Errorf("failed applying workaround %q: %v", wkrd.id, applyErr)
+		}
+		if fixedEvents, verifyErr := patched.verify(pcrs); verifyErr == nil {
+			return fixedEvents, nil
+		}
+	}
+	return events, err
+}
+
+// PCR encapsulates the value of a PCR at a point in time.
+type PCR struct {
+ // Index is the PCR register number.
+ Index int
+ // Digest is the PCR value.
+ Digest []byte
+ // DigestAlg is the hash algorithm of Digest; replay only uses event digests of this algorithm.
+ DigestAlg crypto.Hash
+}
+
+func (e *EventLog) verify(pcrs []PCR) ([]Event, error) {
+ events, err := replayEvents(e.rawEvents, pcrs)
+ if err != nil {
+ if _, isReplayErr := err.(ReplayError); isReplayErr {
+ return nil, err
+ }
+ return nil, fmt.Errorf("pcrs failed to replay: %v", err)
+ }
+ return events, nil
+}
+
+// extend computes the PCR value obtained by extending replay (the running
+// value; an empty slice stands for an all-zero register) with the digest of
+// e that uses pcr.DigestAlg. It returns the new running value and the event
+// digest that was used. It fails if e has no digest for pcr.DigestAlg, or if
+// that digest's length differs from the length of pcr.Digest.
+func extend(pcr PCR, replay []byte, e rawEvent) (pcrDigest []byte, eventDigest []byte, err error) {
+ alg := pcr.DigestAlg
+ for _, d := range e.digests {
+ if d.hash != alg {
+ continue
+ }
+ if got, want := len(d.data), len(pcr.Digest); got != want {
+ return nil, nil, fmt.Errorf("digest data length (%d) doesn't match PCR digest length (%d)", got, want)
+ }
+ prev := replay
+ if len(prev) == 0 {
+ prev = make([]byte, alg.Size())
+ }
+ hasher := alg.New()
+ hasher.Write(prev)
+ hasher.Write(d.data)
+ return hasher.Sum(nil), d.data, nil
+ }
+ return nil, nil, fmt.Errorf("no event digest matches pcr algorithm: %v", alg)
+}
+
+// replayPCR replays the events of rawEvents that target pcr.Index, using the
+// event digests made with pcr.DigestAlg. It returns those events (each with
+// the digest that was used) and true if the final replayed value equals
+// pcr.Digest, or if no event targets this PCR at all. It returns nil and
+// false if the replayed value does not match, or if any event for this PCR
+// lacks a usable digest for the algorithm.
+func replayPCR(rawEvents []rawEvent, pcr PCR) ([]Event, bool) {
+ var running []byte
+ var matched []Event
+ for _, raw := range rawEvents {
+ if raw.index != pcr.Index {
+ continue
+ }
+ next, d, err := extend(pcr, running, raw)
+ if err != nil {
+ return nil, false
+ }
+ running = next
+ matched = append(matched, Event{
+ sequence: raw.sequence,
+ Index: pcr.Index,
+ Type: raw.typ,
+ Data: raw.data,
+ Digest: d,
+ })
+ }
+ if len(matched) == 0 || bytes.Equal(running, pcr.Digest) {
+ return matched, true
+ }
+ return nil, false
+}
+
+// pcrReplayResult records the outcome of replaying the log against one
+// provided PCR value (see replayPCR).
+type pcrReplayResult struct {
+ // events are the events replayed for the PCR; nil if the replay failed.
+ events []Event
+ // successful reports whether the replay matched the provided value.
+ successful bool
+}
+
+// replayEvents replays rawEvents against every provided PCR value. Several
+// values may be provided for the same PCR index (e.g. one per digest
+// algorithm); an index counts as verified if the replay against at least one
+// of them succeeds.
+//
+// On success, the events of all verified PCRs are returned ordered by their
+// position in the log. Otherwise a ReplayError is returned that lists the
+// unverified PCR indices in ascending order and carries all raw events,
+// without digests.
+func replayEvents(rawEvents []rawEvent, pcrs []PCR) ([]Event, error) {
+ var (
+ invalidReplays []int
+ verifiedEvents []Event
+ allPCRReplays = map[int][]pcrReplayResult{}
+ )
+
+ // Replay the event log for every PCR and digest algorithm combination.
+ for _, pcr := range pcrs {
+ events, ok := replayPCR(rawEvents, pcr)
+ allPCRReplays[pcr.Index] = append(allPCRReplays[pcr.Index], pcrReplayResult{events, ok})
+ }
+
+ // Record PCR indices which do not have any successful replay. Record the
+ // events for a successful replay.
+pcrLoop:
+ for i, replaysForPCR := range allPCRReplays {
+ for _, replay := range replaysForPCR {
+ if replay.successful {
+ // We consider the PCR verified at this stage: The replay of values with
+ // one digest algorithm matched a provided value.
+ // As such, we save the PCR's events, and proceed to the next PCR.
+ verifiedEvents = append(verifiedEvents, replay.events...)
+ continue pcrLoop
+ }
+ }
+ invalidReplays = append(invalidReplays, i)
+ }
+
+ if len(invalidReplays) > 0 {
+ // Map iteration order is randomized; sort so that the reported PCR list
+ // is deterministic across runs.
+ sort.Ints(invalidReplays)
+ events := make([]Event, 0, len(rawEvents))
+ for _, e := range rawEvents {
+ // Keyed fields, so this does not silently depend on Event's field order.
+ events = append(events, Event{sequence: e.sequence, Index: e.index, Type: e.typ, Data: e.data})
+ }
+ return nil, ReplayError{
+ Events: events,
+ invalidPCRs: invalidReplays,
+ }
+ }
+
+ sort.Slice(verifiedEvents, func(i int, j int) bool {
+ return verifiedEvents[i].sequence < verifiedEvents[j].sequence
+ })
+ return verifiedEvents, nil
+}
+
+// EV_NO_ACTION is a special event type that indicates information to the parser
+// instead of holding a measurement. For TPM 2.0, this event type is used to signal
+// switching from SHA1 format to a variable length digest.
+//
+// ParseEventLog compares the type of the first event against this value to
+// detect a crypto-agile log.
+//
+// https://trustedcomputinggroup.org/wp-content/uploads/TCG_PCClientSpecPlat_TPM_2p0_1p04_pub.pdf#page=110
+const eventTypeNoAction = 0x03
+
+// ParseEventLog parses an unverified measurement log.
+//
+// The format is decided by the first event. If it is an EV_NO_ACTION event,
+// it is parsed as the TCG Spec ID event. In that case:
+//   - its algorithm list determines Algs; only SHA1 and SHA256 are
+//     recognised, and at least one of them is required;
+//   - the event itself is dropped, since it intentionally doesn't extend
+//     the PCRs;
+//   - all following events are parsed in the crypto-agile (TPM 2.0) format.
+//
+// Otherwise the log is parsed in the SHA1-only (TPM 1.2) format, Algs is
+// {SHA1}, and the first event is kept with sequence number 0. The remaining
+// events are numbered from 1 in log order.
+func ParseEventLog(measurementLog []byte) (*EventLog, error) {
+ r := bytes.NewBuffer(measurementLog)
+ var el EventLog
+
+ first, err := parseRawEvent(r, nil)
+ if err != nil {
+ return nil, fmt.Errorf("parse first event: %v", err)
+ }
+
+ var specID *specIDEvent
+ parseNext := parseRawEvent
+ if first.typ != eventTypeNoAction {
+ el.Algs = []HashAlg{HashSHA1}
+ el.rawEvents = append(el.rawEvents, first)
+ } else {
+ if specID, err = parseSpecIDEvent(first.data); err != nil {
+ return nil, fmt.Errorf("failed to parse spec ID event: %v", err)
+ }
+ supported := map[tpm2.Algorithm]HashAlg{
+ tpm2.AlgSHA1: HashSHA1,
+ tpm2.AlgSHA256: HashSHA256,
+ }
+ for _, alg := range specID.algs {
+ if h, ok := supported[tpm2.Algorithm(alg.ID)]; ok {
+ el.Algs = append(el.Algs, h)
+ }
+ }
+ if len(el.Algs) == 0 {
+ return nil, fmt.Errorf("measurement log didn't use sha1 or sha256 digests")
+ }
+ // Switch to parsing crypto agile events. Note that this doesn't actually
+ // guarantee that events have SHA256 digests.
+ parseNext = parseRawEvent2
+ }
+
+ for sequence := 1; r.Len() != 0; sequence++ {
+ ev, err := parseNext(r, specID)
+ if err != nil {
+ return nil, err
+ }
+ ev.sequence = sequence
+ el.rawEvents = append(el.rawEvents, ev)
+ }
+ return &el, nil
+}
+
+// specIDEvent holds the parts of the TCG Spec ID event that the parser uses:
+// the digest algorithms, with their sizes, that crypto-agile events may carry.
+type specIDEvent struct {
+ algs []specAlgSize
+}
+
+// specAlgSize is one (algorithm ID, digest size in bytes) entry of the Spec
+// ID event. It is decoded directly from the log with binary.Read.
+type specAlgSize struct {
+ ID uint16
+ Size uint16
+}
+
+// Expected values for various Spec ID Event fields.
+// https://trustedcomputinggroup.org/wp-content/uploads/EFI-Protocol-Specification-rev13-160330final.pdf#page=19
+var wantSignature = [16]byte{0x53, 0x70,
+ 0x65, 0x63, 0x20, 0x49,
+ 0x44, 0x20, 0x45, 0x76,
+ 0x65, 0x6e, 0x74, 0x30,
+ 0x33, 0x00} // "Spec ID Event03\0"
+
+// Spec version expected in a TCG_EfiSpecIDEventStruct header (2.0, errata 0).
+// parseSpecIDEvent checks the major and minor versions; wantErrata is
+// currently not checked.
+const (
+ wantMajor = 2
+ wantMinor = 0
+ wantErrata = 0
+)
+
+// parseSpecIDEvent parses a TCG_EfiSpecIDEventStruct structure from b, the
+// data of the Spec ID (EV_NO_ACTION) event that starts a crypto-agile log.
+//
+// It validates the signature and the major and minor spec versions, and
+// returns the (algorithm ID, digest size) pairs the log uses. It also
+// requires that the bytes left after the algorithm list are exactly the
+// vendor info announced by the structure. The errata field is not checked.
+//
+// https://trustedcomputinggroup.org/wp-content/uploads/EFI-Protocol-Specification-rev13-160330final.pdf#page=18
+func parseSpecIDEvent(b []byte) (*specIDEvent, error) {
+ r := bytes.NewReader(b)
+ var header struct {
+ Signature [16]byte
+ PlatformClass uint32
+ VersionMinor uint8
+ VersionMajor uint8
+ Errata uint8
+ UintnSize uint8
+ NumAlgs uint32
+ }
+ if err := binary.Read(r, binary.LittleEndian, &header); err != nil {
+ return nil, fmt.Errorf("reading event header: %v", err)
+ }
+ if header.Signature != wantSignature {
+ return nil, fmt.Errorf("invalid spec id signature: %x", header.Signature)
+ }
+ if header.VersionMajor != wantMajor {
+ return nil, fmt.Errorf("invalid spec major version, got %02x, wanted %02x",
+ header.VersionMajor, wantMajor)
+ }
+ if header.VersionMinor != wantMinor {
+ // Report the minor version, which is what was compared.
+ return nil, fmt.Errorf("invalid spec minor version, got %02x, wanted %02x",
+ header.VersionMinor, wantMinor)
+ }
+
+ // TODO(ericchiang): Check errata? Or do we expect that to change in ways
+ // we're okay with?
+
+ var e specIDEvent
+ // Count in uint32 so a large NumAlgs cannot overflow int on 32-bit
+ // platforms; each iteration consumes input, so the loop stays bounded.
+ for i := uint32(0); i < header.NumAlgs; i++ {
+ var specAlg specAlgSize
+ if err := binary.Read(r, binary.LittleEndian, &specAlg); err != nil {
+ return nil, fmt.Errorf("reading algorithm: %v", err)
+ }
+ e.algs = append(e.algs, specAlg)
+ }
+
+ var vendorInfoSize uint8
+ if err := binary.Read(r, binary.LittleEndian, &vendorInfoSize); err != nil {
+ return nil, fmt.Errorf("reading vender info size: %v", err)
+ }
+ if r.Len() != int(vendorInfoSize) {
+ return nil, fmt.Errorf("reading vendor info, expected %d remaining bytes, got %d", vendorInfoSize, r.Len())
+ }
+ return &e, nil
+}
+
+// digest is one digest of an event, tagged with the hash algorithm that
+// produced it.
+type digest struct {
+ hash crypto.Hash
+ data []byte
+}
+
+// rawEvent is an event as parsed from the measurement log, before any replay
+// against PCR values.
+type rawEvent struct {
+ // sequence is the event's position in the log, assigned by ParseEventLog.
+ sequence int
+ // index is the PCR index the event claims to extend.
+ index int
+ // typ is the event type as recorded in the (untrusted) log.
+ typ EventType
+ // data is the event data.
+ data []byte
+ // digests holds the event's digests: a single SHA1 digest for TPM 1.2
+ // format events, one per listed algorithm for crypto-agile events.
+ digests []digest
+}
+
+// TPM 1.2 event log format. See "5.1 SHA1 Event Log Entry Format"
+// https://trustedcomputinggroup.org/wp-content/uploads/EFI-Protocol-Specification-rev13-160330final.pdf#page=15
+//
+// rawEventHeader is the fixed-size prefix of such an entry; EventSize bytes
+// of event data follow it (see parseRawEvent).
+type rawEventHeader struct {
+ PCRIndex uint32
+ Type uint32
+ Digest [20]byte
+ EventSize uint32
+}
+
+// eventSizeErr is returned when an event declares more data than is left in
+// the measurement log.
+type eventSizeErr struct {
+ eventSize uint32
+ logSize int
+}
+
+// Error reports both the declared event size and the remaining log size.
+func (se *eventSizeErr) Error() string {
+ const msg = "event data size (%d bytes) is greater than remaining measurement log (%d bytes)"
+ return fmt.Sprintf(msg, se.eventSize, se.logSize)
+}
+
+// parseRawEvent reads one event in the TPM 1.2 (SHA1-only) log format from r.
+// specID is unused; it exists so that this function has the same signature
+// as parseRawEvent2, letting ParseEventLog switch between them. An event with
+// an empty data section, or one that declares more data than is left in r, is
+// rejected. The returned event carries a single SHA1 digest; its sequence
+// number is left unset.
+func parseRawEvent(r *bytes.Buffer, specID *specIDEvent) (event rawEvent, err error) {
+ var hdr rawEventHeader
+ if err = binary.Read(r, binary.LittleEndian, &hdr); err != nil {
+ return event, err
+ }
+ switch {
+ case hdr.EventSize == 0:
+ return event, errors.New("event data size is 0")
+ case hdr.EventSize > uint32(r.Len()):
+ return event, &eventSizeErr{hdr.EventSize, r.Len()}
+ }
+
+ payload := make([]byte, int(hdr.EventSize))
+ if _, readErr := io.ReadFull(r, payload); readErr != nil {
+ return event, readErr
+ }
+
+ event.index = int(hdr.PCRIndex)
+ event.typ = EventType(hdr.Type)
+ event.data = payload
+ event.digests = []digest{{hash: crypto.SHA1, data: hdr.Digest[:]}}
+ return event, nil
+}
+
+// TPM 2.0 event log format. See "5.2 Crypto Agile Log Entry Format"
+// https://trustedcomputinggroup.org/wp-content/uploads/EFI-Protocol-Specification-rev13-160330final.pdf#page=15
+//
+// rawEvent2Header is only the fixed-size start of such an entry. It is
+// followed by a digest count, (algorithm ID, digest) pairs, the event size
+// and the event data, which parseRawEvent2 reads field by field.
+type rawEvent2Header struct {
+ PCRIndex uint32
+ Type uint32
+}
+
+// parseRawEvent2 reads one event in the TPM 2.0 crypto-agile log format from
+// r. specID must be the log's parsed Spec ID event. It supplies the digest
+// size for each algorithm ID, and every digest in the event must use an
+// algorithm listed there. An event with an empty data section is rejected,
+// as is one whose declared digest or data sizes exceed what is left in r.
+// The sequence number of the returned event is left unset.
+func parseRawEvent2(r *bytes.Buffer, specID *specIDEvent) (event rawEvent, err error) {
+ var h rawEvent2Header
+
+ if err = binary.Read(r, binary.LittleEndian, &h); err != nil {
+ return event, err
+ }
+ event.typ = EventType(h.Type)
+ event.index = int(h.PCRIndex)
+
+ // parse the event digests
+ var numDigests uint32
+ if err := binary.Read(r, binary.LittleEndian, &numDigests); err != nil {
+ return event, err
+ }
+
+ // Count in uint32 so a large numDigests cannot overflow int on 32-bit
+ // platforms; each iteration consumes input, so the loop stays bounded.
+ for i := uint32(0); i < numDigests; i++ {
+ var algID uint16
+ if err := binary.Read(r, binary.LittleEndian, &algID); err != nil {
+ return event, err
+ }
+ var d digest
+
+ for _, alg := range specID.algs {
+ if alg.ID != algID {
+ continue
+ }
+ // Compare as int: truncating r.Len() to uint16 wraps once more than
+ // 64 KiB of log remain, and could reject a well-formed digest.
+ if r.Len() < int(alg.Size) {
+ return event, fmt.Errorf("reading digest: %v", io.ErrUnexpectedEOF)
+ }
+ d.data = make([]byte, alg.Size)
+ d.hash = HashAlg(alg.ID).cryptoHash()
+ }
+ if len(d.data) == 0 {
+ return event, fmt.Errorf("unknown algorithm ID %x", algID)
+ }
+ if _, err := io.ReadFull(r, d.data); err != nil {
+ return event, err
+ }
+ event.digests = append(event.digests, d)
+ }
+
+ // parse event data
+ var eventSize uint32
+ if err = binary.Read(r, binary.LittleEndian, &eventSize); err != nil {
+ return event, err
+ }
+ if eventSize == 0 {
+ return event, errors.New("event data size is 0")
+ }
+ if eventSize > uint32(r.Len()) {
+ return event, &eventSizeErr{eventSize, r.Len()}
+ }
+ event.data = make([]byte, int(eventSize))
+ if _, err := io.ReadFull(r, event.data); err != nil {
+ return event, err
+ }
+ return event, nil
+}
diff --git a/metropolis/pkg/tpm/eventlog/internal/BUILD.bazel b/metropolis/pkg/tpm/eventlog/internal/BUILD.bazel
new file mode 100644
index 0000000..a73bcba
--- /dev/null
+++ b/metropolis/pkg/tpm/eventlog/internal/BUILD.bazel
@@ -0,0 +1,12 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Parsers for UEFI-specific event data (variables, signature lists, event
+# types). Visibility is limited to the eventlog package tree, in line with
+# Go's internal/ import rule for this path.
+go_library(
+ name = "go_default_library",
+ srcs = ["events.go"],
+ importpath = "git.monogon.dev/source/nexantic.git/metropolis/pkg/tpm/eventlog/internal",
+ visibility = ["//metropolis/pkg/tpm/eventlog:__subpackages__"],
+ deps = [
+ "@com_github_google_certificate_transparency_go//asn1:go_default_library",
+ "@com_github_google_certificate_transparency_go//x509:go_default_library",
+ ],
+)
diff --git a/metropolis/pkg/tpm/eventlog/internal/events.go b/metropolis/pkg/tpm/eventlog/internal/events.go
new file mode 100644
index 0000000..d9b933b
--- /dev/null
+++ b/metropolis/pkg/tpm/eventlog/internal/events.go
@@ -0,0 +1,403 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Taken from go-attestation under Apache 2.0
+package internal
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "unicode/utf16"
+
+ "github.com/google/certificate-transparency-go/asn1"
+ "github.com/google/certificate-transparency-go/x509"
+)
+
+// Upper bounds applied to length fields read from (untrusted) event data
+// before a buffer of that length is allocated.
+const (
+ // maxNameLen is the maximum accepted length of a name field. It is compared
+ // against UnicodeNameLength, so it counts UTF-16 code units, not bytes.
+ // This value should be larger than any reasonable value.
+ maxNameLen = 2048
+ // maxDataLen is the maximum size in bytes of a variable data field.
+ // This value should be larger than any reasonable value.
+ maxDataLen = 1024 * 1024 // 1 Megabyte.
+)
+
+// GUIDs identifying the contents (SignatureType) of an EFI_SIGNATURE_LIST.
+// parseEfiSignatureList decodes only X.509 certificate and SHA-256 hash
+// lists; the other types are recognised so that they can be reported with a
+// specific error.
+var (
+ hashSHA256SigGUID = efiGUID{0xc1c41626, 0x504c, 0x4092, [8]byte{0xac, 0xa9, 0x41, 0xf9, 0x36, 0x93, 0x43, 0x28}}
+ hashSHA1SigGUID = efiGUID{0x826ca512, 0xcf10, 0x4ac9, [8]byte{0xb1, 0x87, 0xbe, 0x01, 0x49, 0x66, 0x31, 0xbd}}
+ hashSHA224SigGUID = efiGUID{0x0b6e5233, 0xa65c, 0x44c9, [8]byte{0x94, 0x07, 0xd9, 0xab, 0x83, 0xbf, 0xc8, 0xbd}}
+ hashSHA384SigGUID = efiGUID{0xff3e5307, 0x9fd0, 0x48c9, [8]byte{0x85, 0xf1, 0x8a, 0xd5, 0x6c, 0x70, 0x1e, 0x01}}
+ hashSHA512SigGUID = efiGUID{0x093e0fae, 0xa6c4, 0x4f50, [8]byte{0x9f, 0x1b, 0xd4, 0x1e, 0x2b, 0x89, 0xc1, 0x9a}}
+ keyRSA2048SigGUID = efiGUID{0x3c5766e8, 0x269c, 0x4e34, [8]byte{0xaa, 0x14, 0xed, 0x77, 0x6e, 0x85, 0xb3, 0xb6}}
+ certRSA2048SHA256SigGUID = efiGUID{0xe2b36190, 0x879b, 0x4a3d, [8]byte{0xad, 0x8d, 0xf2, 0xe7, 0xbb, 0xa3, 0x27, 0x84}}
+ certRSA2048SHA1SigGUID = efiGUID{0x67f8444f, 0x8743, 0x48f1, [8]byte{0xa3, 0x28, 0x1e, 0xaa, 0xb8, 0x73, 0x60, 0x80}}
+ certX509SigGUID = efiGUID{0xa5c059a1, 0x94e4, 0x4aa7, [8]byte{0x87, 0xb5, 0xab, 0x15, 0x5c, 0x2b, 0xf0, 0x72}}
+ certHashSHA256SigGUID = efiGUID{0x3bd2a492, 0x96c0, 0x4079, [8]byte{0xb4, 0x20, 0xfc, 0xf9, 0x8e, 0xf1, 0x03, 0xed}}
+ certHashSHA384SigGUID = efiGUID{0x7076876e, 0x80c2, 0x4ee6, [8]byte{0xaa, 0xd2, 0x28, 0xb3, 0x49, 0xa6, 0x86, 0x5b}}
+ certHashSHA512SigGUID = efiGUID{0x446dbf63, 0x2502, 0x4cda, [8]byte{0xbc, 0xfa, 0x24, 0x65, 0xd2, 0xb0, 0xfe, 0x9d}}
+)
+
+// EventType describes the type of event signalled in the event log.
+// Values read from a log are untrusted; UntrustedParseEventType validates
+// them against the types known to this package.
+type EventType uint32
+
+// BIOS Events (TCG PC Client Specific Implementation Specification for Conventional BIOS 1.21)
+// These cover the contiguous range [0x0, 0x12]; 0x2 is named unused and kept
+// unexported.
+const (
+ PrebootCert EventType = 0x00000000
+ PostCode EventType = 0x00000001
+ unused EventType = 0x00000002
+ NoAction EventType = 0x00000003
+ Separator EventType = 0x00000004
+ Action EventType = 0x00000005
+ EventTag EventType = 0x00000006
+ SCRTMContents EventType = 0x00000007
+ SCRTMVersion EventType = 0x00000008
+ CpuMicrocode EventType = 0x00000009
+ PlatformConfigFlags EventType = 0x0000000A
+ TableOfDevices EventType = 0x0000000B
+ CompactHash EventType = 0x0000000C
+ Ipl EventType = 0x0000000D
+ IplPartitionData EventType = 0x0000000E
+ NonhostCode EventType = 0x0000000F
+ NonhostConfig EventType = 0x00000010
+ NonhostInfo EventType = 0x00000011
+ OmitBootDeviceEvents EventType = 0x00000012
+)
+
+// EFI Events (TCG EFI Platform Specification Version 1.22)
+// All values lie in the UEFI range [0x80000000, 0x800000FF]; they are not
+// contiguous (see EFIHCRTMEvent and EFIVariableAuthority).
+const (
+ EFIEventBase EventType = 0x80000000
+ EFIVariableDriverConfig EventType = 0x80000001
+ EFIVariableBoot EventType = 0x80000002
+ EFIBootServicesApplication EventType = 0x80000003
+ EFIBootServicesDriver EventType = 0x80000004
+ EFIRuntimeServicesDriver EventType = 0x80000005
+ EFIGPTEvent EventType = 0x80000006
+ EFIAction EventType = 0x80000007
+ EFIPlatformFirmwareBlob EventType = 0x80000008
+ EFIHandoffTables EventType = 0x80000009
+ EFIHCRTMEvent EventType = 0x80000010
+ EFIVariableAuthority EventType = 0x800000e0
+)
+
+// ErrSigMissingGUID is returned if an EFI_SIGNATURE_DATA structure was parsed
+// successfully, however was missing the SignatureOwner GUID. This case is
+// handled specially as a workaround for a bug relating to authority events.
+// It accompanies a valid result: ParseUEFIVariableAuthority returns it
+// together with the certificate that was parsed from the full data.
+var ErrSigMissingGUID = errors.New("signature data was missing owner GUID")
+
+// eventTypeNames maps each EventType known to this package to a
+// human-readable name. EventType.String uses it for formatting, and
+// UntrustedParseEventType rejects any type that is not a key of this map.
+var eventTypeNames = map[EventType]string{
+ PrebootCert: "Preboot Cert",
+ PostCode: "POST Code",
+ unused: "Unused",
+ NoAction: "No Action",
+ Separator: "Separator",
+ Action: "Action",
+ EventTag: "Event Tag",
+ SCRTMContents: "S-CRTM Contents",
+ SCRTMVersion: "S-CRTM Version",
+ CpuMicrocode: "CPU Microcode",
+ PlatformConfigFlags: "Platform Config Flags",
+ TableOfDevices: "Table of Devices",
+ CompactHash: "Compact Hash",
+ Ipl: "IPL",
+ IplPartitionData: "IPL Partition Data",
+ NonhostCode: "Non-Host Code",
+ NonhostConfig: "Non-HostConfig",
+ NonhostInfo: "Non-Host Info",
+ OmitBootDeviceEvents: "Omit Boot Device Events",
+
+ EFIEventBase: "EFI Event Base",
+ EFIVariableDriverConfig: "EFI Variable Driver Config",
+ EFIVariableBoot: "EFI Variable Boot",
+ EFIBootServicesApplication: "EFI Boot Services Application",
+ EFIBootServicesDriver: "EFI Boot Services Driver",
+ EFIRuntimeServicesDriver: "EFI Runtime Services Driver",
+ EFIGPTEvent: "EFI GPT Event",
+ EFIAction: "EFI Action",
+ EFIPlatformFirmwareBlob: "EFI Platform Firmware Blob",
+ EFIVariableAuthority: "EFI Variable Authority",
+ EFIHandoffTables: "EFI Handoff Tables",
+ EFIHCRTMEvent: "EFI H-CRTM Event",
+}
+
+// String returns the human-readable name of a known event type, or
+// "EventType(0x...)" with the value in hex for an unknown one.
+func (e EventType) String() string {
+ name, known := eventTypeNames[e]
+ if !known {
+ return fmt.Sprintf("EventType(0x%x)", uint32(e))
+ }
+ return name
+}
+
+// UntrustedParseEventType validates an event type value read from an
+// untrusted event log and returns it as an EventType. It returns an error if
+// the value lies outside the ranges [0x0, 0x12] and [0x80000000, 0x800000FF],
+// or if it is not an event type known to this package.
+func UntrustedParseEventType(et uint32) (EventType, error) {
+ // "The value associated with a UEFI specific platform event type MUST be in
+ // the range between 0x80000000 and 0x800000FF, inclusive."
+ // A value is out of range if it falls in the gap between the BIOS and UEFI
+ // ranges, or lies above the UEFI range. et is unsigned, so no lower bound
+ // needs checking.
+ if (et > 0x12 && et < 0x80000000) || et > 0x800000FF {
+ return EventType(0), fmt.Errorf("event type not between [0x0, 0x12] or [0x80000000, 0x800000FF]: got %#x", et)
+ }
+ if _, ok := eventTypeNames[EventType(et)]; !ok {
+ return EventType(0), fmt.Errorf("unknown event type %#x", et)
+ }
+ return EventType(et), nil
+}
+
+// efiGUID represents the EFI_GUID type.
+// See section "2.3.1 Data Types" in the specification for more information.
+// The fields total 16 bytes; this file decodes them from event data with
+// binary.LittleEndian.
+type efiGUID struct {
+ Data1 uint32
+ Data2 uint16
+ Data3 uint16
+ Data4 [8]byte
+}
+
+func (d efiGUID) String() string {
+ var u [8]byte
+ binary.BigEndian.PutUint32(u[:4], d.Data1)
+ binary.BigEndian.PutUint16(u[4:6], d.Data2)
+ binary.BigEndian.PutUint16(u[6:8], d.Data3)
+ return fmt.Sprintf("%x-%x-%x-%x-%x", u[:4], u[4:6], u[6:8], d.Data4[:2], d.Data4[2:])
+}
+
+// UEFIVariableDataHeader represents the leading fixed-size fields
+// within UEFI_VARIABLE_DATA.
+type UEFIVariableDataHeader struct {
+ // VariableName is the variable's GUID (the spec's VariableName field).
+ VariableName efiGUID
+ // UnicodeNameLength is the length of the name in UTF-16 code units.
+ UnicodeNameLength uint64 // uintN
+ // VariableDataLength is the length of the variable data in bytes.
+ VariableDataLength uint64 // uintN
+}
+
+// UEFIVariableData represents the UEFI_VARIABLE_DATA structure.
+type UEFIVariableData struct {
+ Header UEFIVariableDataHeader
+ // UnicodeName is the variable name as UTF-16 code units; see VarName.
+ UnicodeName []uint16
+ VariableData []byte // []int8
+}
+
+// ParseUEFIVariableData parses the data section of an event structured as
+// a UEFI variable.
+//
+// https://trustedcomputinggroup.org/wp-content/uploads/TCG_PCClient_Specific_Platform_Profile_for_TPM_2p0_1p04_PUBLIC.pdf#page=100
+func ParseUEFIVariableData(r io.Reader) (ret UEFIVariableData, err error) {
+ err = binary.Read(r, binary.LittleEndian, &ret.Header)
+ if err != nil {
+ return
+ }
+ if ret.Header.UnicodeNameLength > maxNameLen {
+ return UEFIVariableData{}, fmt.Errorf("unicode name too long: %d > %d", ret.Header.UnicodeNameLength, maxNameLen)
+ }
+ ret.UnicodeName = make([]uint16, ret.Header.UnicodeNameLength)
+ for i := 0; uint64(i) < ret.Header.UnicodeNameLength; i++ {
+ err = binary.Read(r, binary.LittleEndian, &ret.UnicodeName[i])
+ if err != nil {
+ return
+ }
+ }
+ if ret.Header.VariableDataLength > maxDataLen {
+ return UEFIVariableData{}, fmt.Errorf("variable data too long: %d > %d", ret.Header.VariableDataLength, maxDataLen)
+ }
+ ret.VariableData = make([]byte, ret.Header.VariableDataLength)
+ _, err = io.ReadFull(r, ret.VariableData)
+ return
+}
+
+// VarName returns the variable name, decoded from UTF-16 (unpaired
+// surrogates become U+FFFD, per utf16.Decode).
+func (v *UEFIVariableData) VarName() string {
+ return string(utf16.Decode(v.UnicodeName))
+}
+
+// SignatureData interprets VariableData as a sequence of EFI_SIGNATURE_LIST
+// structures, the layout of signature database variables such as PK, KEK, db
+// and dbx. It returns the X.509 certificates and SHA-256 hashes they contain.
+func (v *UEFIVariableData) SignatureData() (certs []x509.Certificate, hashes [][]byte, err error) {
+ return parseEfiSignatureList(v.VariableData)
+}
+
+// UEFIVariableAuthority describes the contents of a UEFI variable authority
+// event.
+type UEFIVariableAuthority struct {
+ // Certs holds the certificate parsed from the event's signature data.
+ Certs []x509.Certificate
+}
+
+// ParseUEFIVariableAuthority parses the data section of an event structured as
+// a UEFI variable authority.
+//
+// https://uefi.org/sites/default/files/resources/UEFI_Spec_2_8_final.pdf#page=1789
+func ParseUEFIVariableAuthority(r io.Reader) (UEFIVariableAuthority, error) {
+ v, err := ParseUEFIVariableData(r)
+ if err != nil {
+ return UEFIVariableAuthority{}, err
+ }
+ certs, err := parseEfiSignature(v.VariableData)
+ return UEFIVariableAuthority{Certs: certs}, err
+}
+
+// efiSignatureData represents the EFI_SIGNATURE_DATA type.
+// See section "31.4.1 Signature Database" in the specification for more information.
+// SignatureData must be sized by the caller before decoding into it, since
+// its length comes from the enclosing list's SignatureSize.
+type efiSignatureData struct {
+ SignatureOwner efiGUID
+ SignatureData []byte // []int8
+}
+
+// efiSignatureListHeader represents the fixed-size (28 byte) leading fields
+// of the EFI_SIGNATURE_LIST type. SignatureListSize is the size of the whole
+// list including these fields. SignatureSize is the size of each
+// EFI_SIGNATURE_DATA entry in the list.
+// See section "31.4.1 Signature Database" in the specification for more information.
+type efiSignatureListHeader struct {
+ SignatureType efiGUID
+ SignatureListSize uint32
+ SignatureHeaderSize uint32
+ SignatureSize uint32
+}
+
+// efiSignatureList represents an EFI_SIGNATURE_LIST: its fixed header plus
+// the variable-length parts.
+// NOTE(review): the SignatureData and Signatures fields are not populated
+// anywhere in this file; only Header is decoded.
+type efiSignatureList struct {
+ Header efiSignatureListHeader
+ SignatureData []byte
+ Signatures []byte
+}
+
+// parseEfiSignatureList parses a EFI_SIGNATURE_LIST structure.
+// The structure and related GUIDs are defined at:
+// https://uefi.org/sites/default/files/resources/UEFI_Spec_2_8_final.pdf#page=1790
+func parseEfiSignatureList(b []byte) ([]x509.Certificate, [][]byte, error) {
+ if len(b) < 28 {
+ // Being passed an empty signature list here appears to be valid
+ return nil, nil, nil
+ }
+ signatures := efiSignatureList{}
+ buf := bytes.NewReader(b)
+ certificates := []x509.Certificate{}
+ hashes := [][]byte{}
+
+ for buf.Len() > 0 {
+ err := binary.Read(buf, binary.LittleEndian, &signatures.Header)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if signatures.Header.SignatureHeaderSize > maxDataLen {
+ return nil, nil, fmt.Errorf("signature header too large: %d > %d", signatures.Header.SignatureHeaderSize, maxDataLen)
+ }
+ if signatures.Header.SignatureListSize > maxDataLen {
+ return nil, nil, fmt.Errorf("signature list too large: %d > %d", signatures.Header.SignatureListSize, maxDataLen)
+ }
+
+ signatureType := signatures.Header.SignatureType
+ switch signatureType {
+ case certX509SigGUID: // X509 certificate
+ for sigOffset := 0; uint32(sigOffset) < signatures.Header.SignatureListSize-28; {
+ signature := efiSignatureData{}
+ signature.SignatureData = make([]byte, signatures.Header.SignatureSize-16)
+ err := binary.Read(buf, binary.LittleEndian, &signature.SignatureOwner)
+ if err != nil {
+ return nil, nil, err
+ }
+ err = binary.Read(buf, binary.LittleEndian, &signature.SignatureData)
+ if err != nil {
+ return nil, nil, err
+ }
+ cert, err := x509.ParseCertificate(signature.SignatureData)
+ if err != nil {
+ return nil, nil, err
+ }
+ sigOffset += int(signatures.Header.SignatureSize)
+ certificates = append(certificates, *cert)
+ }
+ case hashSHA256SigGUID: // SHA256
+ for sigOffset := 0; uint32(sigOffset) < signatures.Header.SignatureListSize-28; {
+ signature := efiSignatureData{}
+ signature.SignatureData = make([]byte, signatures.Header.SignatureSize-16)
+ err := binary.Read(buf, binary.LittleEndian, &signature.SignatureOwner)
+ if err != nil {
+ return nil, nil, err
+ }
+ err = binary.Read(buf, binary.LittleEndian, &signature.SignatureData)
+ if err != nil {
+ return nil, nil, err
+ }
+ hashes = append(hashes, signature.SignatureData)
+ sigOffset += int(signatures.Header.SignatureSize)
+ }
+ case keyRSA2048SigGUID:
+ err = errors.New("unhandled RSA2048 key")
+ case certRSA2048SHA256SigGUID:
+ err = errors.New("unhandled RSA2048-SHA256 key")
+ case hashSHA1SigGUID:
+ err = errors.New("unhandled SHA1 hash")
+ case certRSA2048SHA1SigGUID:
+ err = errors.New("unhandled RSA2048-SHA1 key")
+ case hashSHA224SigGUID:
+ err = errors.New("unhandled SHA224 hash")
+ case hashSHA384SigGUID:
+ err = errors.New("unhandled SHA384 hash")
+ case hashSHA512SigGUID:
+ err = errors.New("unhandled SHA512 hash")
+ case certHashSHA256SigGUID:
+ err = errors.New("unhandled X509-SHA256 hash metadata")
+ case certHashSHA384SigGUID:
+ err = errors.New("unhandled X509-SHA384 hash metadata")
+ case certHashSHA512SigGUID:
+ err = errors.New("unhandled X509-SHA512 hash metadata")
+ default:
+ err = fmt.Errorf("unhandled signature type %s", signatureType)
+ }
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ return certificates, hashes, nil
+}
+
+// EFISignatureData represents the EFI_SIGNATURE_DATA type.
+// See section "31.4.1 Signature Database" in the specification
+// for more information. parseEfiSignature uses it to split a single
+// signature into its owner GUID and certificate bytes.
+type EFISignatureData struct {
+ SignatureOwner efiGUID
+ SignatureData []byte // []int8
+}
+
+// parseEfiSignature parses b as a single EFI_SIGNATURE_DATA holding a DER
+// X.509 certificate: a 16-byte SignatureOwner GUID followed by the
+// certificate.
+//
+// If the certificate after the GUID fails to parse with an ASN.1 structural
+// error, all of b is retried as a bare certificate (no GUID). If that
+// succeeds, the certificate is returned together with ErrSigMissingGUID.
+// Otherwise the first x509.ParseCertificate error is returned with no
+// certificates. Input shorter than the GUID is rejected up front.
+func parseEfiSignature(b []byte) ([]x509.Certificate, error) {
+ certificates := []x509.Certificate{}
+
+ if len(b) < 16 {
+ return nil, fmt.Errorf("invalid signature: buffer smaller than header (%d < %d)", len(b), 16)
+ }
+
+ buf := bytes.NewReader(b)
+ signature := EFISignatureData{}
+ // Everything after the 16-byte owner GUID is treated as certificate data.
+ signature.SignatureData = make([]byte, len(b)-16)
+
+ if err := binary.Read(buf, binary.LittleEndian, &signature.SignatureOwner); err != nil {
+ return certificates, err
+ }
+ if err := binary.Read(buf, binary.LittleEndian, &signature.SignatureData); err != nil {
+ return certificates, err
+ }
+
+ cert, err := x509.ParseCertificate(signature.SignatureData)
+ if err == nil {
+ certificates = append(certificates, *cert)
+ } else {
+ // A bug in shim may cause an event to be missing the SignatureOwner GUID.
+ // We handle this, but signal back to the caller using ErrSigMissingGUID.
+ if _, isStructuralErr := err.(asn1.StructuralError); isStructuralErr {
+ var err2 error
+ cert, err2 = x509.ParseCertificate(b)
+ if err2 == nil {
+ certificates = append(certificates, *cert)
+ err = ErrSigMissingGUID
+ }
+ }
+ }
+ return certificates, err
+}
diff --git a/metropolis/pkg/tpm/eventlog/secureboot.go b/metropolis/pkg/tpm/eventlog/secureboot.go
new file mode 100644
index 0000000..46e1f95
--- /dev/null
+++ b/metropolis/pkg/tpm/eventlog/secureboot.go
@@ -0,0 +1,210 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Taken and pruned from go-attestation under Apache 2.0
+package eventlog
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+
+ "github.com/google/certificate-transparency-go/x509"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/pkg/tpm/eventlog/internal"
+)
+
+// SecurebootState describes the secure boot status of a machine, as determined
+// by processing its event log.
+type SecurebootState struct {
+ // Enabled is true if the SecureBoot variable was measured with the value 1.
+ Enabled bool
+
+ // PlatformKeys enumerates keys which can sign a key exchange key (from PK).
+ PlatformKeys []x509.Certificate
+ // PlatformKeyHashes enumerates key hashes which can sign a key exchange key
+ // (from PK).
+ PlatformKeyHashes [][]byte
+
+ // ExchangeKeys enumerates keys which can sign a database of permitted or
+ // forbidden keys (from KEK).
+ ExchangeKeys []x509.Certificate
+ // ExchangeKeyHashes enumerates key hashes which can sign a database of
+ // permitted or forbidden keys (from KEK).
+ ExchangeKeyHashes [][]byte
+
+ // PermittedKeys enumerates keys which may sign binaries to run (from db).
+ PermittedKeys []x509.Certificate
+ // PermittedHashes enumerates hashes which permit binaries to run (from db).
+ PermittedHashes [][]byte
+
+ // ForbiddenKeys enumerates keys which must not permit a binary to run
+ // (from dbx).
+ ForbiddenKeys []x509.Certificate
+ // ForbiddenHashes enumerates hashes which must not permit a binary to run
+ // (from dbx).
+ ForbiddenHashes [][]byte
+
+ // PreSeparatorAuthority describes the use of a secure-boot key to authorize
+ // the execution of a binary before the separator.
+ PreSeparatorAuthority []x509.Certificate
+ // PostSeparatorAuthority describes the use of a secure-boot key to authorize
+ // the execution of a binary after the separator.
+ PostSeparatorAuthority []x509.Certificate
+}
+
+// ParseSecurebootState parses a series of events to determine the
+// configuration of secure boot on a device. An error is returned if
+// the state cannot be determined, or if the event log is structured
+// in such a way that it may have been tampered post-execution of
+// platform firmware.
+//
+// Only events with PCR index 7 are examined; all others are ignored.
+// NOTE(review): events are presumably the output of EventLog.Verify, so
+// that their digests have been replayed against real PCR values — confirm
+// at call sites.
+func ParseSecurebootState(events []Event) (*SecurebootState, error) {
+ // This algorithm verifies the following:
+ // - All events in PCR 7 have event types which are expected in PCR 7.
+ // - All events are parsable according to their event type.
+ // - All events have digest values corresponding to their data/event type.
+ // - No unverifiable events were present.
+ // - All variables are specified before the separator and never duplicated.
+ // - The SecureBoot variable has a value of 0 or 1.
+ // - If SecureBoot was 1 (enabled), authority events were present indicating
+ // keys were used to perform verification.
+ // - If SecureBoot was 1 (enabled), platform + exchange + database keys
+ // were specified.
+ // - No UEFI debugger was attached.
+ //
+ // NOTE(review): values of SecureBoot other than 0 or 1 are not rejected
+ // below; any value other than 1 is treated as disabled.
+
+ var (
+ out SecurebootState
+ seenSeparator bool
+ seenAuthority bool
+ seenVars = map[string]bool{}
+ )
+
+ for _, e := range events {
+ if e.Index != 7 {
+ continue
+ }
+
+ et, err := internal.UntrustedParseEventType(uint32(e.Type))
+ if err != nil {
+ return nil, fmt.Errorf("unrecognised event type: %v", err)
+ }
+
+ // Non-nil when the event's digest does not match its data (judging by
+ // the "invalid digest" errors below; digestEquals is defined elsewhere).
+ digestVerify := e.digestEquals(e.Data)
+ switch et {
+ case internal.Separator:
+ if seenSeparator {
+ return nil, fmt.Errorf("duplicate separator at event %d", e.sequence)
+ }
+ seenSeparator = true
+ if !bytes.Equal(e.Data, []byte{0, 0, 0, 0}) {
+ return nil, fmt.Errorf("invalid separator data at event %d: %v", e.sequence, e.Data)
+ }
+ if digestVerify != nil {
+ return nil, fmt.Errorf("invalid separator digest at event %d: %v", e.sequence, digestVerify)
+ }
+
+ case internal.EFIAction:
+ // Every EFI action event in PCR 7 is rejected; the debug-mode string
+ // only selects a more specific error.
+ if string(e.Data) == "UEFI Debug Mode" {
+ return nil, errors.New("a UEFI debugger was present during boot")
+ }
+ return nil, fmt.Errorf("event %d: unexpected EFI action event", e.sequence)
+
+ case internal.EFIVariableDriverConfig:
+ v, err := internal.ParseUEFIVariableData(bytes.NewReader(e.Data))
+ if err != nil {
+ return nil, fmt.Errorf("failed parsing EFI variable at event %d: %v", e.sequence, err)
+ }
+ if _, seenBefore := seenVars[v.VarName()]; seenBefore {
+ return nil, fmt.Errorf("duplicate EFI variable %q at event %d", v.VarName(), e.sequence)
+ }
+ seenVars[v.VarName()] = true
+ if seenSeparator {
+ return nil, fmt.Errorf("event %d: variable %q specified after separator", e.sequence, v.VarName())
+ }
+
+ if digestVerify != nil {
+ return nil, fmt.Errorf("invalid digest for variable %q on event %d: %v", v.VarName(), e.sequence, digestVerify)
+ }
+
+ // Variables with other names are accepted but not recorded.
+ switch v.VarName() {
+ case "SecureBoot":
+ if len(v.VariableData) != 1 {
+ return nil, fmt.Errorf("event %d: SecureBoot data len is %d, expected 1", e.sequence, len(v.VariableData))
+ }
+ out.Enabled = v.VariableData[0] == 1
+ case "PK":
+ if out.PlatformKeys, out.PlatformKeyHashes, err = v.SignatureData(); err != nil {
+ return nil, fmt.Errorf("event %d: failed parsing platform keys: %v", e.sequence, err)
+ }
+ case "KEK":
+ if out.ExchangeKeys, out.ExchangeKeyHashes, err = v.SignatureData(); err != nil {
+ return nil, fmt.Errorf("event %d: failed parsing key exchange keys: %v", e.sequence, err)
+ }
+ case "db":
+ if out.PermittedKeys, out.PermittedHashes, err = v.SignatureData(); err != nil {
+ return nil, fmt.Errorf("event %d: failed parsing signature database: %v", e.sequence, err)
+ }
+ case "dbx":
+ if out.ForbiddenKeys, out.ForbiddenHashes, err = v.SignatureData(); err != nil {
+ return nil, fmt.Errorf("event %d: failed parsing forbidden signature database: %v", e.sequence, err)
+ }
+ }
+
+ case internal.EFIVariableAuthority:
+ a, err := internal.ParseUEFIVariableAuthority(bytes.NewReader(e.Data))
+ if err != nil {
+ // Workaround for: https://github.com/google/go-attestation/issues/157
+ if err == internal.ErrSigMissingGUID {
+ // Versions of shim which do not carry
+ // https://github.com/rhboot/shim/commit/8a27a4809a6a2b40fb6a4049071bf96d6ad71b50
+ // have an erroneous additional byte in the event, which breaks digest
+ // verification. If verification failed, we try removing the last byte.
+ if digestVerify != nil {
+ digestVerify = e.digestEquals(e.Data[:len(e.Data)-1])
+ }
+ } else {
+ return nil, fmt.Errorf("failed parsing EFI variable authority at event %d: %v", e.sequence, err)
+ }
+ }
+ seenAuthority = true
+ if digestVerify != nil {
+ return nil, fmt.Errorf("invalid digest for authority on event %d: %v", e.sequence, digestVerify)
+ }
+ // Authorities are recorded separately depending on whether they were
+ // used before or after the separator.
+ if !seenSeparator {
+ out.PreSeparatorAuthority = append(out.PreSeparatorAuthority, a.Certs...)
+ } else {
+ out.PostSeparatorAuthority = append(out.PostSeparatorAuthority, a.Certs...)
+ }
+
+ default:
+ return nil, fmt.Errorf("unexpected event type: %v", et)
+ }
+ }
+
+ // With secure boot disabled, the key databases are not required.
+ if !out.Enabled {
+ return &out, nil
+ }
+
+ if !seenAuthority {
+ return nil, errors.New("secure boot was enabled but no key was used")
+ }
+ if len(out.PlatformKeys) == 0 && len(out.PlatformKeyHashes) == 0 {
+ return nil, errors.New("secure boot was enabled but no platform keys were known")
+ }
+ if len(out.ExchangeKeys) == 0 && len(out.ExchangeKeyHashes) == 0 {
+ return nil, errors.New("secure boot was enabled but no key exchange keys were known")
+ }
+ if len(out.PermittedKeys) == 0 && len(out.PermittedHashes) == 0 {
+ return nil, errors.New("secure boot was enabled but no keys or hashes were permitted")
+ }
+ return &out, nil
+}
diff --git a/metropolis/pkg/tpm/tpm.go b/metropolis/pkg/tpm/tpm.go
new file mode 100644
index 0000000..29bd208
--- /dev/null
+++ b/metropolis/pkg/tpm/tpm.go
@@ -0,0 +1,561 @@
+// Copyright 2020 The Monogon Project Authors.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tpm
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gogo/protobuf/proto"
+ tpmpb "github.com/google/go-tpm-tools/proto"
+ "github.com/google/go-tpm-tools/tpm2tools"
+ "github.com/google/go-tpm/tpm2"
+ "github.com/google/go-tpm/tpmutil"
+ "github.com/pkg/errors"
+ "golang.org/x/sys/unix"
+
+ "git.monogon.dev/source/nexantic.git/metropolis/pkg/logtree"
+ "git.monogon.dev/source/nexantic.git/metropolis/pkg/sysfs"
+)
+
+// The PCR sets below trade security against brittleness when sealing data to the platform state.
+//
+// Using FullSystemPCRs is the most secure, but also the most brittle option: updating the EFI
+// binary, updating the platform firmware or changing platform settings would invalidate the
+// sealed data. It's annoying (but possible) to predict values for PCR4, and even more annoying
+// for the firmware PCRs (comparison to known values on similar hardware is the only thing that
+// comes to mind).
+//
+// See also: https://github.com/mxre/sealkey (generates PCR4 from EFI image, BSD license)
+//
+// Using only SecureBootPCRs is the easiest and still reasonably secure, if we assume that the
+// platform knows how to take care of itself (i.e. Intel Boot Guard), and that secure boot is
+// implemented properly. It is, however, a much larger amount of code we need to trust.
+//
+// We do not care about PCR 5 (GPT partition table) since modifying it is harmless. All of the
+// boot options and cmdline are hardcoded in the kernel image, and we use no bootloader, so there's
+// no PCR for bootloader configuration or kernel cmdline.
+var (
+	// SecureBootPCRs are all PCRs that measure the current Secure Boot configuration. This is
+	// what we want if we rely on secure boot to verify boot integrity: the firmware hashes the
+	// secure boot policy and custom keys into the PCR. It requires an extra step that provisions
+	// the custom keys.
+	//
+	// Some background: https://mjg59.dreamwidth.org/48897.html?thread=1847297
+	// (the initramfs issue mentioned in the article has been solved by integrating it into the
+	// kernel binary, and we don't have a shim bootloader)
+	//
+	// PCR7 alone is not sufficient - it needs to be combined with firmware measurements.
+	SecureBootPCRs = []int{7}
+
+	// FirmwarePCRs are all PCRs that contain the firmware measurements.
+	// See https://trustedcomputinggroup.org/wp-content/uploads/TCG_EFI_Platform_1_22_Final_-v15.pdf
+	FirmwarePCRs = []int{
+		0, // platform firmware
+		2, // option ROM code
+		3, // option ROM configuration and data
+	}
+
+	// FullSystemPCRs are all PCRs that contain any measurements up to the currently running EFI
+	// payload.
+	FullSystemPCRs = []int{
+		0, // platform firmware
+		1, // host platform configuration
+		2, // option ROM code
+		3, // option ROM configuration and data
+		4, // EFI payload
+	}
+)
+
+var (
+	// numSRTMPCRs is the number of SRTM PCRs (PCR 0 up to and including PCR 15).
+	numSRTMPCRs = 16
+	// srtmPCRs selects the SHA256 bank of all SRTM PCRs. The PCR list is derived from
+	// numSRTMPCRs so that the two cannot drift apart.
+	srtmPCRs = tpm2.PCRSelection{Hash: tpm2.AlgSHA256, PCRs: func() []int {
+		pcrs := make([]int, numSRTMPCRs)
+		for i := range pcrs {
+			pcrs[i] = i
+		}
+		return pcrs
+	}()}
+	// tpmGeneratedValue is TPM_GENERATED_VALUE, the magic value marking attestation structures
+	// which were created inside the TPM.
+	// TCG Trusted Platform Module Library Level 00 Revision 0.99 Table 6
+	tpmGeneratedValue = uint32(0xff544347)
+)
+
+// ErrNotExists is returned when no TPMs are available in the system.
+var ErrNotExists = errors.New("no TPMs found")
+
+// ErrNotInitialized is returned when this package was not initialized successfully.
+var ErrNotInitialized = errors.New("no TPM was initialized")
+
+var (
+	// tpm is the package-wide TPM instance, set up by Initialize. It is a singleton since the
+	// TPM is one too, and it stays nil until Initialize succeeds.
+	tpm *TPM
+
+	// lock serializes all TPM operations. The TPM has a limited number of handles and
+	// recovering if it runs out is difficult to implement correctly. Serializing might also be
+	// marginally more secure.
+	lock sync.Mutex
+)
+
+// TPM represents a high-level interface to a connected TPM 2.0. The package keeps a single
+// instance of it, created by Initialize.
+type TPM struct {
+	// logger receives this package's log messages.
+	logger logtree.LeveledLogger
+	// device is the opened TPM character device.
+	device io.ReadWriteCloser
+
+	// akHandleCache is the handle of the loaded AK, or 0 while the AK is not loaded. The AK is
+	// kept loaded since it's used fairly often and deriving it is expensive.
+	akHandleCache tpmutil.Handle
+	// akPublicKey is the AK public key as returned when the AK was loaded.
+	akPublicKey crypto.PublicKey
+}
+
+// Initialize finds and opens the TPM (if any) and stores it as the package-wide TPM used by all
+// other functions in this package. If there is no TPM available it returns ErrNotExists.
+//
+// The first entry returned from listing /sys/class/tpm is used. A character device node for it
+// is created at /dev/tpm from the MAJOR/MINOR numbers in its sysfs uevent file, then opened.
+func Initialize(logger logtree.LeveledLogger) error {
+	lock.Lock()
+	defer lock.Unlock()
+	tpmDir, err := os.Open("/sys/class/tpm")
+	if err != nil {
+		// A missing TPM class directory is treated like an empty one: there is no TPM to use.
+		if os.IsNotExist(err) {
+			return ErrNotExists
+		}
+		return errors.Wrap(err, "failed to open sysfs TPM class")
+	}
+	defer tpmDir.Close()
+
+	// With n > 0, Readdirnames reports an empty directory as io.EOF. That's not a failure, it
+	// just means that there are no TPMs.
+	tpms, err := tpmDir.Readdirnames(2)
+	if err != nil && err != io.EOF {
+		return errors.Wrap(err, "failed to read TPM device class")
+	}
+
+	if len(tpms) == 0 {
+		return ErrNotExists
+	}
+	if len(tpms) > 1 {
+		// If this is changed GetMeasurementLog() needs to be updated too
+		logger.Warningf("Found more than one TPM, using the first one")
+	}
+	tpmName := tpms[0]
+	ueventData, err := sysfs.ReadUevents(filepath.Join("/sys/class/tpm", tpmName, "uevent"))
+	if err != nil {
+		return fmt.Errorf("failed to read TPM uevent data: %w", err)
+	}
+	majorDev, err := strconv.Atoi(ueventData["MAJOR"])
+	if err != nil {
+		return fmt.Errorf("failed to convert uevent: %w", err)
+	}
+	minorDev, err := strconv.Atoi(ueventData["MINOR"])
+	if err != nil {
+		return fmt.Errorf("failed to convert uevent: %w", err)
+	}
+	if err := unix.Mknod("/dev/tpm", 0600|unix.S_IFCHR, int(unix.Mkdev(uint32(majorDev), uint32(minorDev)))); err != nil {
+		return errors.Wrap(err, "failed to create TPM device node")
+	}
+	device, err := tpm2.OpenTPM("/dev/tpm")
+	if err != nil {
+		return errors.Wrap(err, "failed to open TPM")
+	}
+	tpm = &TPM{
+		device: device,
+		logger: logger,
+	}
+	return nil
+}
+
+// GenerateSafeKey returns a key of size bytes that mixes two sources of randomness: the kernel
+// (crypto/rand) and the TPM's random number generator. Both portions are XORed together, so
+// neither source alone determines the key.
+func GenerateSafeKey(size uint16) ([]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, ErrNotInitialized
+	}
+	keyLen := int(size)
+
+	hostPart := make([]byte, keyLen)
+	if _, err := io.ReadFull(rand.Reader, hostPart); err != nil {
+		return []byte{}, errors.Wrap(err, "failed to generate host portion of new key")
+	}
+
+	// The TPM may return fewer bytes per call than requested, so keep asking for the remainder,
+	// giving up after 48 calls.
+	var tpmPart []byte
+	for call := 0; call < 48; call++ {
+		chunk, err := tpm2.GetRandom(tpm.device, size-uint16(len(tpmPart)))
+		if err != nil {
+			return []byte{}, errors.Wrap(err, "failed to generate TPM portion of new key")
+		}
+		tpmPart = append(tpmPart, chunk...)
+		if len(tpmPart) >= keyLen {
+			break
+		}
+	}
+	if len(tpmPart) != keyLen {
+		return []byte{}, fmt.Errorf("got incorrect amount of TPM randomess: %v, requested %v", len(tpmPart), size)
+	}
+
+	key := make([]byte, keyLen)
+	for i := range key {
+		key[i] = hostPart[i] ^ tpmPart[i]
+	}
+	return key, nil
+}
+
+// Seal seals sensitive data with the TPM's SRK and only allows access if the current platform
+// configuration (the values of the given PCRs) matches the one the data was sealed on. The
+// result is a serialized sealed-data message meant to be passed to Unseal.
+func Seal(data []byte, pcrs []int) ([]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, ErrNotInitialized
+	}
+	srk, err := tpm2tools.StorageRootKeyRSA(tpm.device)
+	if err != nil {
+		return []byte{}, errors.Wrap(err, "failed to load TPM SRK")
+	}
+	defer srk.Close()
+	sealedKey, err := srk.Seal(pcrs, data)
+	if err != nil {
+		// Without this check a failed seal would marshal an empty message and hand it out as if
+		// it were valid sealed data.
+		return []byte{}, errors.Wrap(err, "failed to seal data")
+	}
+	sealedKeyRaw, err := proto.Marshal(sealedKey)
+	if err != nil {
+		return []byte{}, errors.Wrap(err, "failed to marshal sealed data")
+	}
+	return sealedKeyRaw, nil
+}
+
+// Unseal unseals data previously produced by Seal if the current platform configuration and
+// sealing constraints allow it. The PCRs the data is bound to are logged for auditing.
+func Unseal(data []byte) ([]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, ErrNotInitialized
+	}
+
+	// Decode first so that malformed input is rejected without doing any work on the TPM.
+	var sealedKey tpmpb.SealedBytes
+	if err := proto.Unmarshal(data, &sealedKey); err != nil {
+		return []byte{}, errors.Wrap(err, "failed to decode sealed data")
+	}
+
+	srk, err := tpm2tools.StorageRootKeyRSA(tpm.device)
+	if err != nil {
+		return []byte{}, errors.Wrap(err, "failed to load TPM SRK")
+	}
+	defer srk.Close()
+
+	// Logging this for auditing purposes. PCR indices are integers and must be formatted as
+	// decimal numbers; a plain string() conversion would yield the character with that code point.
+	pcrList := make([]string, 0, len(sealedKey.Pcrs))
+	for _, pcr := range sealedKey.Pcrs {
+		pcrList = append(pcrList, strconv.FormatInt(int64(pcr), 10))
+	}
+	tpm.logger.Infof("Attempting to unseal data protected with PCRs %s", strings.Join(pcrList, ","))
+	unsealedData, err := srk.Unseal(&sealedKey)
+	if err != nil {
+		return []byte{}, errors.Wrap(err, "failed to unseal data")
+	}
+	return unsealedData, nil
+}
+
+// akTemplate is the public template of the Attestation Key: a standard RSA2048 non-duplicatable
+// restricted signing key for attestation (tpm2.FlagSignerDefault), signing with RSASSA over
+// SHA256 and using SHA256 as its name algorithm. MakeAKChallenge compares submitted AKs against
+// this template.
+var akTemplate = tpm2.Public{
+	Type:       tpm2.AlgRSA,
+	Attributes: tpm2.FlagSignerDefault,
+	NameAlg:    tpm2.AlgSHA256,
+	RSAParameters: &tpm2.RSAParams{
+		KeyBits: 2048,
+		Sign: &tpm2.SigScheme{
+			Hash: tpm2.AlgSHA256,
+			Alg:  tpm2.AlgRSASSA,
+		},
+	},
+}
+
+func loadAK() error {
+ var err error
+ // Rationale: The AK is an EK-equivalent key and used only for attestation. Using a non-primary
+ // key here would require us to store the wrapped version somewhere, which is inconvenient.
+ // This being a primary key in the Endorsement hierarchy means that it can always be recreated
+ // and can never be "destroyed". Under our security model this is of no concern since we identify
+ // a node by its IK (Identity Key) which we can destroy.
+ tpm.akHandleCache, tpm.akPublicKey, err = tpm2.CreatePrimary(tpm.device, tpm2.HandleEndorsement,
+ tpm2.PCRSelection{}, "", "", akTemplate)
+ return err
+}
+
+// loadEK creates the Endorsement Key as a primary key in the Endorsement hierarchy and returns its
+// handle and public key. The caller is responsible for flushing the returned handle. The process
+// is documented in TCG EK Credential Profile 2.2.1.
+func loadEK() (tpmutil.Handle, crypto.PublicKey, error) {
+	// The EK is a primary key which is supposed to be certified by the manufacturer of the TPM.
+	// Its public attributes are standardized in TCG EK Credential Profile 2.0 Table 1. These need
+	// to match exactly or we aren't getting the key the manufacturer signed. tpm2tools contains
+	// such a template already, so we're using that instead of redoing it ourselves.
+	// This ignores the more complicated ways EKs can be specified, the additional stuff you can
+	// do is just absolutely crazy (see 2.2.1.2 onward)
+	ekTemplate := tpm2tools.DefaultEKTemplateRSA()
+	return tpm2.CreatePrimary(tpm.device, tpm2.HandleEndorsement, tpm2.PCRSelection{}, "", "", ekTemplate)
+}
+
+// GetAKPublic returns the public area of the AK in TPM wire format, as produced by
+// tpm2.Public.Encode. This is the encoding MakeAKChallenge and VerifyAttestPlatform decode with
+// tpm2.DecodePublic. The AK is loaded first if it isn't cached yet.
+func GetAKPublic() ([]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, ErrNotInitialized
+	}
+	if akLoaded := tpm.akHandleCache != tpmutil.Handle(0); !akLoaded {
+		if err := loadAK(); err != nil {
+			return []byte{}, fmt.Errorf("failed to load AK primary key: %w", err)
+		}
+	}
+	akPublic, _, _, readErr := tpm2.ReadPublic(tpm.device, tpm.akHandleCache)
+	if readErr != nil {
+		return []byte{}, readErr
+	}
+	return akPublic.Encode()
+}
+
+// TCG TPM v2.0 Provisioning Guidance v1.0 7.8 Table 2 and
+// TCG EK Credential Profile v2.1 2.2.1.4 de-facto Standard for Windows
+// These are both non-normative and reference Windows 10 documentation that's no longer available :(
+// But in practice this is what people are using, so if it's normative or not doesn't really matter
+const ekCertHandle = 0x01c00002
+
+// GetEKPublic returns the public key of the Endorsement Key (PKIX, DER-encoded) and the raw
+// contents of the NV index at ekCertHandle, which is where the EK certificate is conventionally
+// stored. If that NV index can't be read (for example because the TPM does not carry a
+// certificate) an error is returned; the certificate is not optional.
+func GetEKPublic() ([]byte, []byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, []byte{}, ErrNotInitialized
+	}
+	ekHandle, ekPublicKey, err := loadEK()
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to load EK primary key: %w", err)
+	}
+	defer tpm2.FlushContext(tpm.device, ekHandle)
+	// Don't question the use of HandleOwner, that's the Standard™
+	ekCertRaw, err := tpm2.NVReadEx(tpm.device, ekCertHandle, tpm2.HandleOwner, "", 0)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to read EK certificate from NV index %#x: %w", ekCertHandle, err)
+	}
+
+	publicKey, err := x509.MarshalPKIXPublicKey(ekPublicKey)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to marshal EK public key: %w", err)
+	}
+
+	return publicKey, ekCertRaw, nil
+}
+
+// MakeAKChallenge generates a challenge for TPM residency and attributes of the AK.
+//
+// ekPubKey is a PKIX-encoded EK public key (the format returned by GetEKPublic). akPub is an
+// encoded AK public area (the format returned by GetAKPublic). Only RSA EKs are supported, and
+// the AK must match akTemplate. The two returned byte slices come from generateRSA and are
+// presumably the credential blob and encrypted secret consumed by SolveAKChallenge.
+func MakeAKChallenge(ekPubKey, akPub []byte, nonce []byte) ([]byte, []byte, error) {
+	ekPubKeyData, err := x509.ParsePKIXPublicKey(ekPubKey)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to decode EK pubkey: %w", err)
+	}
+	// The EK public key is caller-supplied (presumably by the node being challenged), so check
+	// its type instead of blindly asserting it, which would panic on e.g. an ECDSA key.
+	ekRSAPubKey, ok := ekPubKeyData.(*rsa.PublicKey)
+	if !ok {
+		return []byte{}, []byte{}, fmt.Errorf("unsupported EK pubkey type %T: only RSA EKs are supported", ekPubKeyData)
+	}
+	akPubData, err := tpm2.DecodePublic(akPub)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to decode AK public part: %w", err)
+	}
+	// Make sure we're attesting the right attributes (in particular Restricted)
+	if !akPubData.MatchesTemplate(akTemplate) {
+		return []byte{}, []byte{}, errors.New("the key being challenged is not a valid AK")
+	}
+	akName, err := akPubData.Name()
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to derive AK name: %w", err)
+	}
+	return generateRSA(akName.Digest, ekRSAPubKey, 16, nonce, rand.Reader)
+}
+
+// SolveAKChallenge solves a challenge for TPM residency of the AK
+func SolveAKChallenge(credBlob, secretChallenge []byte) ([]byte, error) {
+ lock.Lock()
+ defer lock.Unlock()
+ if tpm == nil {
+ return []byte{}, ErrNotInitialized
+ }
+ if tpm.akHandleCache == tpmutil.Handle(0) {
+ if err := loadAK(); err != nil {
+ return []byte{}, fmt.Errorf("failed to load AK primary key: %w", err)
+ }
+ }
+
+ ekHandle, _, err := loadEK()
+ if err != nil {
+ return []byte{}, fmt.Errorf("failed to load EK: %w", err)
+ }
+ defer tpm2.FlushContext(tpm.device, ekHandle)
+
+ // This is necessary since the EK requires an endorsement handle policy in its session
+ // For us this is stupid because we keep all hierarchies open anyways since a) we cannot safely
+ // store secrets on the OS side pre-global unlock and b) it makes no sense in this security model
+ // since an uncompromised host OS will not let an untrusted entity attest as itself and a
+ // compromised OS can either not pass PCR policy checks or the game's already over (you
+ // successfully runtime-exploited a production Metropolis node)
+ endorsementSession, _, err := tpm2.StartAuthSession(
+ tpm.device,
+ tpm2.HandleNull,
+ tpm2.HandleNull,
+ make([]byte, 16),
+ nil,
+ tpm2.SessionPolicy,
+ tpm2.AlgNull,
+ tpm2.AlgSHA256)
+ if err != nil {
+ panic(err)
+ }
+ defer tpm2.FlushContext(tpm.device, endorsementSession)
+
+ _, err = tpm2.PolicySecret(tpm.device, tpm2.HandleEndorsement, tpm2.AuthCommand{Session: tpm2.HandlePasswordSession, Attributes: tpm2.AttrContinueSession}, endorsementSession, nil, nil, nil, 0)
+ if err != nil {
+ return []byte{}, fmt.Errorf("failed to make a policy secret session: %w", err)
+ }
+
+ for {
+ solution, err := tpm2.ActivateCredentialUsingAuth(tpm.device, []tpm2.AuthCommand{
+ {Session: tpm2.HandlePasswordSession, Attributes: tpm2.AttrContinueSession}, // Use standard no-password authentication
+ {Session: endorsementSession, Attributes: tpm2.AttrContinueSession}, // Use a full policy session for the EK
+ }, tpm.akHandleCache, ekHandle, credBlob, secretChallenge)
+ if warn, ok := err.(tpm2.Warning); ok && warn.Code == tpm2.RCRetry {
+ time.Sleep(100 * time.Millisecond)
+ continue
+ }
+ return solution, err
+ }
+}
+
+// FlushTransientHandles flushes all sessions and non-persistent (transient) objects from the
+// TPM. If the cached AK is among the flushed objects, the cache is invalidated so that the next
+// operation needing the AK recreates it.
+func FlushTransientHandles() error {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return ErrNotInitialized
+	}
+	flushHandleTypes := []tpm2.HandleType{tpm2.HandleTypeTransient, tpm2.HandleTypeLoadedSession, tpm2.HandleTypeSavedSession}
+	for _, handleType := range flushHandleTypes {
+		handles, err := tpm2tools.Handles(tpm.device, handleType)
+		if err != nil {
+			return err
+		}
+		for _, handle := range handles {
+			if err := tpm2.FlushContext(tpm.device, handle); err != nil {
+				return err
+			}
+			// The cached AK is created by CreatePrimary (see loadAK) and never made persistent,
+			// so it is flushed here as well. Afterwards its handle no longer refers to the AK;
+			// drop the cache so that the next user of the AK reloads it.
+			if handle == tpm.akHandleCache {
+				tpm.akHandleCache = tpmutil.Handle(0)
+				tpm.akPublicKey = nil
+			}
+		}
+	}
+	return nil
+}
+
+// AttestPlatform performs a quote of all SRTM PCRs (SHA256 bank) using the AK, with nonce as the
+// qualifying data. It returns the encoded attestation data and its raw RSA signature.
+func AttestPlatform(nonce []byte) ([]byte, []byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return []byte{}, []byte{}, ErrNotInitialized
+	}
+	if tpm.akHandleCache == tpmutil.Handle(0) {
+		if err := loadAK(); err != nil {
+			return []byte{}, []byte{}, fmt.Errorf("failed to load AK primary key: %w", err)
+		}
+	}
+	// We only care about SHA256 since SHA1 is weak. This is supported on at least GCE and
+	// Intel / AMD fTPM, which is good enough for now. Alg is null because that would just hash
+	// the nonce, which is dumb.
+	quote, signature, err := tpm2.Quote(tpm.device, tpm.akHandleCache, "", "", nonce, srtmPCRs,
+		tpm2.AlgNull)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("failed to quote PCRs: %w", err)
+	}
+	// The AK is an RSA key, so an RSA signature is expected. Check anyway instead of
+	// dereferencing a nil pointer if the TPM hands back something else.
+	if signature == nil || signature.RSA == nil {
+		return []byte{}, []byte{}, errors.New("TPM quote did not contain an RSA signature")
+	}
+	return quote, signature.RSA.Signature, nil
+}
+
+// VerifyAttestPlatform verifies an attestation as produced by AttestPlatform. It checks four
+// things:
+//   - the signature over the quote verifies against the given AK;
+//   - the AK matches akTemplate (in particular it is a restricted key);
+//   - the quote was generated inside the TPM and is a quote;
+//   - the quote carries the given nonce.
+//
+// You can rely on all data coming back as being from the TPM to which the AK is bound.
+func VerifyAttestPlatform(nonce, akPub, quote, signature []byte) (*tpm2.AttestationData, error) {
+	akPubData, err := tpm2.DecodePublic(akPub)
+	if err != nil {
+		return nil, fmt.Errorf("invalid AK: %w", err)
+	}
+	// The Magic check further down is only meaningful if the AK is a restricted key, so refuse
+	// keys which don't match our AK template (the same check MakeAKChallenge performs).
+	if !akPubData.MatchesTemplate(akTemplate) {
+		return nil, errors.New("invalid AK: key does not match the AK template")
+	}
+	akPublicKey, err := akPubData.Key()
+	if err != nil {
+		return nil, fmt.Errorf("invalid AK: %w", err)
+	}
+	akRSAKey, ok := akPublicKey.(*rsa.PublicKey)
+	if !ok {
+		return nil, errors.New("invalid AK: invalid key type")
+	}
+
+	hash := crypto.SHA256.New()
+	hash.Write(quote) // hash.Hash's Write never returns an error
+	if err := rsa.VerifyPKCS1v15(akRSAKey, crypto.SHA256, hash.Sum(nil), signature); err != nil {
+		return nil, err
+	}
+
+	quoteData, err := tpm2.DecodeAttestationData(quote)
+	if err != nil {
+		return nil, err
+	}
+	// quoteData.Magic works together with the TPM's Restricted key attribute. If this attribute
+	// is set (which it needs to be for the AK to be considered valid) the TPM will not sign
+	// external data having this prefix with such a key. Only data that originates inside the TPM
+	// like quotes and key certifications can have this prefix and still be signed by a restricted
+	// key. This check is thus vital, otherwise somebody can just feed the TPM an arbitrary
+	// attestation to sign with its AK and this function will happily accept the forged
+	// attestation.
+	if quoteData.Magic != tpmGeneratedValue {
+		return nil, errors.New("invalid TPM quote: data marker for internal data not set - forged attestation")
+	}
+	if quoteData.Type != tpm2.TagAttestQuote {
+		return nil, errors.New("invalid TPM quote: not a TPM quote")
+	}
+	if !bytes.Equal(quoteData.ExtraData, nonce) {
+		return nil, errors.New("invalid TPM quote: wrong nonce")
+	}
+
+	return quoteData, nil
+}
+
+// GetPCRs returns the SHA256 values of all SRTM PCRs, indexed by PCR number. An error is
+// returned if the TPM does not provide a value for every one of them.
+func GetPCRs() ([][]byte, error) {
+	lock.Lock()
+	defer lock.Unlock()
+	if tpm == nil {
+		return [][]byte{}, ErrNotInitialized
+	}
+	pcrs := make([][]byte, numSRTMPCRs)
+
+	// The TPM can (and most do) return partial results. Let's just retry as many times as we have
+	// PCRs since each read should return at least one PCR.
+	for attempt := 0; attempt < numSRTMPCRs; attempt++ {
+		sel := tpm2.PCRSelection{Hash: tpm2.AlgSHA256}
+		for pcrN, pcr := range pcrs {
+			if len(pcr) == 0 {
+				sel.PCRs = append(sel.PCRs, pcrN)
+			}
+		}
+		if len(sel.PCRs) == 0 {
+			// Everything has been read.
+			break
+		}
+
+		readPCRs, err := tpm2.ReadPCRs(tpm.device, sel)
+		if err != nil {
+			return nil, fmt.Errorf("failed to read PCRs: %w", err)
+		}
+		for pcrN, pcr := range readPCRs {
+			// Ignore anything outside of the requested range instead of panicking on it.
+			if pcrN < 0 || pcrN >= len(pcrs) {
+				continue
+			}
+			pcrs[pcrN] = pcr
+		}
+	}
+
+	// Don't hand out a partial result with empty entries if the TPM didn't make enough progress.
+	for pcrN, pcr := range pcrs {
+		if len(pcr) == 0 {
+			return nil, fmt.Errorf("TPM did not return a value for PCR %d after %d reads", pcrN, numSRTMPCRs)
+		}
+	}
+	return pcrs, nil
+}
+
+// GetMeasurmentLog returns the binary log of all data hashed into PCRs. The result can be parsed by eventlog.
+// As this library currently doesn't support extending PCRs it just returns the log as supplied by the EFI interface.
+func GetMeasurementLog() ([]byte, error) {
+ return ioutil.ReadFile("/sys/kernel/security/tpm0/binary_bios_measurements")
+}