osbase/oci: add end-to-end tests

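This adds end-to-end tests of the OCI OS image generation and
consumption implementations, covering compressed and uncompressed
images, digest verification and truncated blobs, as well as tests of
the registry client's handling of retries, redirects and token
authentication.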

Change-Id: Id9f4e3ab5b2c959807657e06990525810d4979ff
Reviewed-on: https://review.monogon.dev/c/monogon/+/4092
Reviewed-by: Tim Windelschmidt <tim@monogon.tech>
Tested-by: Jenkins CI
diff --git a/osbase/oci/osimage/BUILD.bazel b/osbase/oci/osimage/BUILD.bazel
index bfb8dd6..f54cdfb 100644
--- a/osbase/oci/osimage/BUILD.bazel
+++ b/osbase/oci/osimage/BUILD.bazel
@@ -1,4 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//osbase/build/mkoci:def.bzl", "oci_os_image")
 
 go_library(
     name = "osimage",
@@ -14,3 +15,45 @@
         "@com_github_klauspost_compress//zstd",
     ],
 )
+
+oci_os_image(
+    name = "test_image",  # compression_level left at the rule default; the "compressed" case in osimage_test.go.
+    srcs = {
+        # Test payload, stored as payload "test". It should be multiple MB in
+        # size to exercise the chunking, but also not excessively large.
+        "test": "//third_party/linux",
+    },
+    visibility = ["//osbase/oci:__subpackages__"],
+)
+
+oci_os_image(
+    name = "test_image_uncompressed",  # Same payload as :test_image, without compression.
+    srcs = {
+        "test": "//third_party/linux",
+    },
+    compression_level = 0,  # Presumably "no compression", per the target name; confirm in mkoci/def.bzl.
+    visibility = ["//osbase/oci:__subpackages__"],  # Also used by //osbase/oci/registry:registry_test.
+)
+
+go_test(
+    name = "osimage_test",
+    srcs = ["osimage_test.go"],
+    data = [
+        ":test_image",
+        ":test_image_uncompressed",
+        "//third_party/linux",  # Raw payload that the image contents are compared against.
+    ],
+    embed = [":osimage"],
+    x_defs = {
+        "xImagePath": "$(rlocationpath :test_image )",  # All x_defs values are resolved via runfiles.Rlocation in init().
+        "xImageUncompressedPath": "$(rlocationpath :test_image_uncompressed )",
+        "xTestPayloadPath": "$(rlocationpath //third_party/linux )",
+    },
+    deps = [
+        "//osbase/oci",
+        "//osbase/oci/registry",
+        "@com_github_cenkalti_backoff_v4//:backoff",
+        "@com_github_opencontainers_image_spec//specs-go/v1:specs-go",
+        "@io_bazel_rules_go//go/runfiles",
+    ],
+)
diff --git a/osbase/oci/osimage/osimage_test.go b/osbase/oci/osimage/osimage_test.go
new file mode 100644
index 0000000..6ee8afa
--- /dev/null
+++ b/osbase/oci/osimage/osimage_test.go
@@ -0,0 +1,270 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+package osimage
+
+import (
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"os"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/bazelbuild/rules_go/go/runfiles"
+	"github.com/cenkalti/backoff/v4"
+	ocispecv1 "github.com/opencontainers/image-spec/specs-go/v1"
+
+	"source.monogon.dev/osbase/oci"
+	"source.monogon.dev/osbase/oci/registry"
+)
+
+// Runfiles paths of the test data: the compressed image (:test_image), the
+// uncompressed image (:test_image_uncompressed) and the raw test payload
+// (//third_party/linux), in that order. Bazel sets them via x_defs at link
+// time to rlocation paths, which are not directly openable; the first init
+// function below rewrites each of them in place to a real filesystem path
+// using the rules_go runfiles library, panicking if resolution fails.
+// They must stay package-level string variables for the linker to set them.
+var xImagePath, xImageUncompressedPath, xTestPayloadPath string
+
+// init replaces each injected rlocation path with the real path of the file.
+func init() {
+	targets := []*string{&xImagePath, &xImageUncompressedPath, &xTestPayloadPath}
+	for i := range targets {
+		resolved, err := runfiles.Rlocation(*targets[i])
+		if err != nil {
+			panic(err)
+		}
+		*targets[i] = resolved
+	}
+}
+
+var expectedPayloadHash [32]byte // SHA-256 digest of the raw test payload.
+var expectedPayloadLen int64     // Size of the raw test payload in bytes.
+
+// init loads the raw payload once to derive the reference values above.
+func init() {
+	raw, err := os.ReadFile(xTestPayloadPath) // Already resolved by the earlier init.
+	if err != nil {
+		panic(err)
+	}
+	expectedPayloadLen, expectedPayloadHash = int64(len(raw)), sha256.Sum256(raw)
+}
+
+// TestRead checks that both image variants yield exactly the raw payload.
+func TestRead(t *testing.T) {
+	testCases := []struct{ desc, path string }{
+		{"compressed", xImagePath},
+		{"uncompressed", xImageUncompressedPath},
+	}
+	for _, tC := range testCases {
+		t.Run(tC.desc, func(t *testing.T) {
+			image, err := oci.ReadLayout(tC.path)
+			if err != nil {
+				t.Fatal(err)
+			}
+			osImage, err := Read(image)
+			if err != nil {
+				t.Fatal(err)
+			}
+			payload, err := osImage.Payload("test")
+			if err != nil {
+				t.Fatal(err)
+			}
+			if got, want := payload.Size(), expectedPayloadLen; got != want {
+				t.Errorf("payload has size %d, expected %d", got, want)
+			}
+			reader, err := payload.Open()
+			if err != nil {
+				t.Fatal(err)
+			}
+			// Deferred so that the reader is also closed if t.Fatal fires below.
+			defer func() {
+				if err := reader.Close(); err != nil {
+					t.Error(err)
+				}
+			}()
+			hasher := sha256.New() // Hash while streaming; avoids buffering the multi-MB payload.
+			if _, err := io.Copy(hasher, reader); err != nil {
+				t.Fatal(err)
+			}
+			if contentHash := hasher.Sum(nil); !bytes.Equal(contentHash, expectedPayloadHash[:]) {
+				t.Errorf("Payload read through Image does not match expected content, expected %x, got %x", expectedPayloadHash, contentHash)
+			}
+		})
+	}
+}
+
+// TestVerification checks that corrupted registry responses are rejected.
+func TestVerification(t *testing.T) {
+	srcImage, err := oci.ReadLayout(xImageUncompressedPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	server := registry.NewServer()
+	server.AddImage("test/repo", "test-tag", srcImage)
+	corrupter := &corruptingServer{handler: server}
+
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer listener.Close()
+	go http.Serve(listener, corrupter)
+
+	client := &registry.Client{
+		GetBackOff: func() backoff.BackOff {
+			return backoff.NewExponentialBackOff()
+		},
+		RetryNotify: func(err error, _ time.Duration) {
+			t.Errorf("Unexpected retry; verification errors should not trigger retries: %v", err)
+		},
+		Scheme:     "http",
+		Host:       listener.Addr().String(),
+		Repository: "test/repo",
+	}
+	expectVerificationFailure := func(err error, msg string) {
+		t.Helper()
+		if !strings.Contains(fmt.Sprintf("%v", err), msg) {
+			t.Errorf("Expected failed verification, got %v", err)
+		}
+	}
+
+	// A corrupted manifest is only noticed when its digest is known upfront.
+	corrupter.affectedPath = "/v2/test/repo/manifests/test-tag"
+	if _, err := client.Read(context.Background(), "test-tag", ""); err != nil {
+		t.Errorf("Expected reading manifest to succeed when digest not given: %v", err)
+	}
+	_, err = client.Read(context.Background(), "test-tag", srcImage.ManifestDigest)
+	expectVerificationFailure(err, "failed verification")
+
+	// A corrupted config blob is noticed when the OS image is parsed.
+	corrupter.affectedPath = fmt.Sprintf("/v2/test/repo/blobs/%s", srcImage.Manifest.Config.Digest)
+	image, err := client.Read(context.Background(), "test-tag", srcImage.ManifestDigest)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = Read(image)
+	expectVerificationFailure(err, "failed verification")
+
+	// A corrupted payload layer is noticed before any of its bytes are returned.
+	corrupter.affectedPath = fmt.Sprintf("/v2/test/repo/blobs/%s", srcImage.Manifest.Layers[0].Digest)
+	image, err = client.Read(context.Background(), "test-tag", srcImage.ManifestDigest)
+	if err != nil {
+		t.Fatal(err)
+	}
+	osImage, err := Read(image)
+	if err != nil {
+		t.Fatal(err)
+	}
+	payload, err := osImage.Payload("test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	reader, err := payload.Open()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer reader.Close()
+	content, err := io.ReadAll(reader)
+	expectVerificationFailure(err, "payload failed verification")
+	if len(content) != 0 {
+		t.Errorf("Did not expect to read any content, got %d bytes", len(content))
+	}
+}
+
+type corruptingServer struct {
+	handler      http.Handler // Upstream handler serving the actual content.
+	affectedPath string       // Responses for exactly this path get corrupted.
+}
+
+func (c *corruptingServer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
+	if r.URL.Path == c.affectedPath {
+		rw = &corruptingResponseWriter{ResponseWriter: rw}
+	}
+	c.handler.ServeHTTP(rw, r)
+}
+
+// corruptingResponseWriter replaces the first '\n' written through it with
+// a ' '. JSON documents still parse after that change, but their digest no
+// longer matches, so digest verification is expected to fail.
+type corruptingResponseWriter struct {
+	http.ResponseWriter
+	corrupted bool // Whether the replaced byte has already been written out.
+}
+
+func (w *corruptingResponseWriter) Write(b []byte) (int, error) {
+	pos := bytes.IndexByte(b, '\n')
+	if pos < 0 || w.corrupted {
+		return w.ResponseWriter.Write(b)
+	}
+	// Edit a copy: per the io.Writer contract, b must not be modified.
+	modified := bytes.Clone(b)
+	modified[pos] = ' '
+	written, err := w.ResponseWriter.Write(modified)
+	// If the write stopped before pos, a later newline gets corrupted instead.
+	w.corrupted = written > pos
+	return written, err
+}
+
+// TestTruncation checks that reading a payload from truncated blobs fails.
+func TestTruncation(t *testing.T) {
+	srcImage, err := oci.ReadLayout(xImageUncompressedPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Blobs are first cut off only at the config size, so the config parses.
+	blobs := &truncatedBlobs{image: srcImage, length: srcImage.Manifest.Config.Size}
+	truncatedImage, err := oci.NewImage(srcImage.RawManifest, "", blobs)
+	if err != nil {
+		t.Fatal(err)
+	}
+	osImage, err := Read(truncatedImage)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Now cut every blob off after one hash chunk (HashChunkSize bytes); the
+	// payload read must then fail instead of silently returning short content.
+	blobs.length = osImage.Config.Payloads[0].HashChunkSize
+	payload, err := osImage.Payload("test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	reader, err := payload.Open()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer reader.Close()
+	if _, err := io.ReadAll(reader); err == nil {
+		t.Error("Expected to get an error, got nil")
+	}
+}
+
+// truncatedBlobs serves the blobs of image, but every blob reader hits EOF
+// after at most length bytes, simulating a truncated download.
+type truncatedBlobs struct {
+	image  *oci.Image
+	length int64
+}
+
+func (b *truncatedBlobs) Blob(d *ocispecv1.Descriptor) (io.ReadCloser, error) {
+	full, err := b.image.Blob(d)
+	if err != nil {
+		return nil, err
+	}
+	// Reads stop after length bytes, while Close still releases the full blob.
+	limited := io.LimitReader(full, b.length)
+	return &readCloser{Reader: limited, Closer: full}, nil
+}
+
+type readCloser struct { // Builds an io.ReadCloser from separate parts.
+	io.Closer // Close is forwarded here.
+	io.Reader // Read is forwarded here.
+}
diff --git a/osbase/oci/registry/BUILD.bazel b/osbase/oci/registry/BUILD.bazel
index ca2d806..cbe81b0 100644
--- a/osbase/oci/registry/BUILD.bazel
+++ b/osbase/oci/registry/BUILD.bazel
@@ -20,6 +20,20 @@
 
 go_test(
     name = "registry_test",
-    srcs = ["headers_test.go"],
+    srcs = [
+        "client_test.go",
+        "headers_test.go",
+    ],
+    data = [
+        "//osbase/oci/osimage:test_image_uncompressed",  # Served by the in-process registry in client_test.go.
+    ],
     embed = [":registry"],
+    x_defs = {
+        "xImagePath": "$(rlocationpath //osbase/oci/osimage:test_image_uncompressed )",
+    },
+    deps = [
+        "//osbase/oci",
+        "@com_github_cenkalti_backoff_v4//:backoff",
+        "@io_bazel_rules_go//go/runfiles",
+    ],
 )
diff --git a/osbase/oci/registry/client_test.go b/osbase/oci/registry/client_test.go
new file mode 100644
index 0000000..328d2fa
--- /dev/null
+++ b/osbase/oci/registry/client_test.go
@@ -0,0 +1,158 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+package registry
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"net/http"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/bazelbuild/rules_go/go/runfiles"
+	"github.com/cenkalti/backoff/v4"
+
+	"source.monogon.dev/osbase/oci"
+)
+
+// xImagePath locates the uncompressed test image (an OCI layout). Bazel sets
+// it at link time, via x_defs, to the rlocation path of
+// //osbase/oci/osimage:test_image_uncompressed; the init function below then
+// replaces it with the real filesystem path using the rules_go runfiles
+// library.
+var xImagePath string
+
+// init turns the rlocation path stored in xImagePath into a real filesystem
+// path, so that tests can open the image with oci.ReadLayout directly. A
+// resolution failure panics, as TestRetries cannot run without the image.
+// (Bazel injects the rlocation path via x_defs at link time.)
+func init() {
+	resolved, err := runfiles.Rlocation(xImagePath)
+	if err != nil {
+		panic(err)
+	}
+	xImagePath = resolved
+}
+
+// TestRetries checks that Client recovers from an unreliable registry.
+func TestRetries(t *testing.T) {
+	srcImage, err := oci.ReadLayout(xImagePath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	server := NewServer()
+	server.AddImage("test/repo", "test-tag", srcImage)
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer listener.Close()
+	// Set host before serving, since ServeHTTP reads it on other goroutines.
+	wrapper := &unreliableServer{
+		handler:   server,
+		host:      listener.Addr().String(),
+		blobLimit: srcImage.Manifest.Config.Size / 2,
+		seen:      make(map[string]bool),
+	}
+	go http.Serve(listener, wrapper)
+
+	client := &Client{
+		GetBackOff: func() backoff.BackOff {
+			return backoff.NewExponentialBackOff(backoff.WithInitialInterval(time.Millisecond))
+		},
+		RetryNotify: func(err error, d time.Duration) {
+			t.Logf("Retrying in %v: %v", d, err)
+		},
+		Scheme:     "http",
+		Host:       listener.Addr().String(),
+		Repository: "test/repo",
+	}
+
+	image, err := client.Read(context.Background(), "test-tag", srcImage.ManifestDigest)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := image.ReadBlobVerified(&image.Manifest.Config); err != nil {
+		t.Error(err)
+	}
+}
+
+type unreliableServer struct { // Wraps a registry handler and injects failures.
+	handler   http.Handler    // Registry handler that serves the actual content.
+	host      string          // Address used in the token realm URL.
+	blobLimit int64           // Bytes served per blob response before failing.
+	mu        sync.Mutex      // Guards seen.
+	seen      map[string]bool // Paths that have already failed once.
+}
+
+func (s *unreliableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	fmt.Printf("%s %s %s\n", req.Method, req.URL.String(), req.Header.Get("Range")) // Logs each request, including any Range header.
+
+	// The first request for each path (including /redirected/... and /token)
+	// fails with 503 Service Unavailable, so every step needs a retry.
+	s.mu.Lock()
+	if !s.seen[req.URL.Path] {
+		s.seen[req.URL.Path] = true
+		s.mu.Unlock()
+		w.WriteHeader(http.StatusServiceUnavailable)
+		return
+	}
+	s.mu.Unlock()
+
+	// Paths outside /redirected get a 307 there; inside, the prefix is stripped.
+	var ok bool
+	req.URL.Path, ok = strings.CutPrefix(req.URL.Path, "/redirected")
+	if !ok {
+		req.URL.Path = "/redirected" + req.URL.Path
+		w.Header().Set("Location", req.URL.String())
+		w.WriteHeader(http.StatusTemporaryRedirect)
+		return
+	}
+
+	// /token issues the bearer token that every other request must carry.
+	if req.URL.Path == "/token" {
+		query := req.URL.Query()
+		if query.Get("service") != "myregistry.test" || query.Get("scope") != "repository:test/repo:pull" {
+			w.WriteHeader(http.StatusBadRequest)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte(`{"token": "the_token"}`))
+		return
+	} else if req.Header.Get("Authorization") != "Bearer the_token" {
+		w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm="http://%s/token",service="myregistry.test",scope="repository:test/repo:pull"`, s.host))
+		w.WriteHeader(http.StatusUnauthorized)
+		return
+	}
+
+	// Blob responses are cut off after blobLimit bytes, so the client has to
+	// retry them with a Range header to fetch the rest.
+	if strings.Contains(req.URL.Path, "/blobs/") {
+		w = &limitResponseWriter{ResponseWriter: w, remaining: s.blobLimit}
+	}
+
+	s.handler.ServeHTTP(w, req)
+}
+
+// limitResponseWriter forwards at most remaining bytes, then fails writes.
+type limitResponseWriter struct {
+	http.ResponseWriter
+	remaining int64
+}
+
+func (w *limitResponseWriter) Write(b []byte) (int, error) {
+	switch {
+	case w.remaining <= 0:
+		return 0, fmt.Errorf("limit reached")
+	case int64(len(b)) <= w.remaining:
+		w.remaining -= int64(len(b))
+		return w.ResponseWriter.Write(b)
+	}
+	n, _ := w.ResponseWriter.Write(b[:w.remaining]) // Forward the part that still fits.
+	w.remaining = 0
+	return n, fmt.Errorf("limit reached")
+}