osbase/oci: implement support for OCI index

Previously, only OCI images were supported; now we can also handle
indexes. The new Ref type is either an Image or an Index.

Change-Id: I1b282ed6078d53e9a69e7613f601fdbbe64e192b
Reviewed-on: https://review.monogon.dev/c/monogon/+/4475
Tested-by: Jenkins CI
Reviewed-by: Tim Windelschmidt <tim@monogon.tech>
diff --git a/osbase/oci/registry/client.go b/osbase/oci/registry/client.go
index 4e60b7b..4f2b6d2 100644
--- a/osbase/oci/registry/client.go
+++ b/osbase/oci/registry/client.go
@@ -43,6 +43,12 @@
 	DigestRegexp     = regexp.MustCompile(`^` + digestExpr + `$`)
 )
 
+// unknownManifest can be used to parse the media type from a manifest of
+// unknown type.
+type unknownManifest struct {
+	MediaType string `json:"mediaType,omitempty"`
+}
+
 // Client is an OCI registry client.
 type Client struct {
 	// Transport will be used to make requests. For example, this allows
@@ -69,10 +75,10 @@
 	bearerToken string
 }
 
-// Read fetches an image manifest from the registry and returns an [oci.Image].
+// Read fetches a manifest from the registry and returns an [oci.Ref].
 //
-// The context is used for the manifest request and all blob requests made
-// through the Image.
+// The context is used for the manifest request and for all blob and manifest
+// requests made through the Ref.
 //
 // At least one of tag and digest must be set. If only tag is set, then you are
 // trusting the registry to return the right content. Otherwise, the digest is
@@ -80,7 +86,7 @@
 // used in the request, and the digest is used to verify the response. The
 // advantage of fetching by tag is that it allows a pull through cache to
 // display tags to a user inspecting the cache contents.
-func (c *Client) Read(ctx context.Context, tag, digest string) (*oci.Image, error) {
+func (c *Client) Read(ctx context.Context, tag, digest string) (oci.Ref, error) {
 	if !RepositoryRegexp.MatchString(c.Repository) {
 		return nil, fmt.Errorf("invalid repository %q", c.Repository)
 	}
@@ -102,13 +108,14 @@
 	}
 
 	manifestPath := fmt.Sprintf("/v2/%s/manifests/%s", c.Repository, reference)
-	var imageManifestBytes []byte
+	var manifestBytes []byte
+	var manifestMediaType string
 	err := c.retry(ctx, func() error {
 		req, err := c.newGet(manifestPath)
 		if err != nil {
 			return err
 		}
-		req.Header.Set("Accept", ocispecv1.MediaTypeImageManifest)
+		req.Header.Set("Accept", ocispecv1.MediaTypeImageManifest+","+ocispecv1.MediaTypeImageIndex)
 		resp, err := c.doGet(ctx, req)
 		if err != nil {
 			return err
@@ -117,18 +124,34 @@
 			return readClientError(resp, req)
 		}
 		defer resp.Body.Close()
-		imageManifestBytes, err = readFullBody(resp, 50*1024*1024)
+		manifestMediaType = resp.Header.Get("Content-Type")
+		manifestBytes, err = readFullBody(resp, 50*1024*1024)
 		return err
 	})
 	if err != nil {
 		return nil, err
 	}
 
+	// Remove any parameters from the Content-Type header.
+	manifestMediaType, _, _ = strings.Cut(manifestMediaType, ";")
+	switch manifestMediaType {
+	case ocispecv1.MediaTypeImageManifest, ocispecv1.MediaTypeImageIndex:
+		// The Content-Type header is valid, use it.
+	default:
+		// We need to parse the manifest to extract the media type, then parse it
+		// again for that media type.
+		var manifest unknownManifest
+		if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
+			return nil, fmt.Errorf("failed to parse manifest: %w", err)
+		}
+		manifestMediaType = manifest.MediaType
+	}
+
 	blobs := &clientBlobs{
 		ctx:    ctx,
 		client: c,
 	}
-	return oci.NewImage(imageManifestBytes, digest, blobs)
+	return oci.NewRef(manifestBytes, manifestMediaType, digest, blobs)
 }
 
 type clientBlobs struct {
@@ -136,6 +159,41 @@
 	client *Client
 }
 
+func (r *clientBlobs) Manifest(descriptor *ocispecv1.Descriptor) ([]byte, error) {
+	digest := string(descriptor.Digest)
+	if _, _, err := oci.ParseDigest(digest); err != nil {
+		return nil, err
+	}
+
+	manifestPath := fmt.Sprintf("/v2/%s/manifests/%s", r.client.Repository, digest)
+	var manifestBytes []byte
+	err := r.client.retry(r.ctx, func() error {
+		req, err := r.client.newGet(manifestPath)
+		if err != nil {
+			return err
+		}
+		req.Header.Set("Accept", ocispecv1.MediaTypeImageManifest+","+ocispecv1.MediaTypeImageIndex)
+		resp, err := r.client.doGet(r.ctx, req)
+		if err != nil {
+			return err
+		}
+		if resp.StatusCode != http.StatusOK {
+			return readClientError(resp, req)
+		}
+		defer resp.Body.Close()
+		manifestBytes, err = readKnownSizeBody(resp, int(descriptor.Size))
+		return err
+	})
+	if err != nil {
+		return nil, err
+	}
+	return manifestBytes, nil
+}
+
+func (r *clientBlobs) Blobs(_ *ocispecv1.Descriptor) (oci.Blobs, error) {
+	return r, nil
+}
+
 func (r *clientBlobs) Blob(descriptor *ocispecv1.Descriptor) (io.ReadCloser, error) {
 	if !DigestRegexp.MatchString(string(descriptor.Digest)) {
 		return nil, fmt.Errorf("invalid blob digest %q", descriptor.Digest)
@@ -480,3 +538,15 @@
 		return nil, backoff.Permanent(fmt.Errorf("HTTP response of size %d exceeds limit of %d bytes", resp.ContentLength, limit))
 	}
 }
+
+func readKnownSizeBody(resp *http.Response, size int) ([]byte, error) {
+	if resp.ContentLength >= 0 && resp.ContentLength != int64(size) {
+		return nil, backoff.Permanent(fmt.Errorf("HTTP response has size %d, expected %d bytes", resp.ContentLength, size))
+	}
+	content := make([]byte, size)
+	_, err := io.ReadFull(resp.Body, content)
+	if err != nil {
+		return nil, err
+	}
+	return content, nil
+}
diff --git a/osbase/oci/registry/client_test.go b/osbase/oci/registry/client_test.go
index 328d2fa..622a3ab 100644
--- a/osbase/oci/registry/client_test.go
+++ b/osbase/oci/registry/client_test.go
@@ -39,12 +39,12 @@
 }
 
 func TestRetries(t *testing.T) {
-	srcImage, err := oci.ReadLayout(xImagePath)
+	srcImage, err := oci.AsImage(oci.ReadLayout(xImagePath))
 	if err != nil {
 		t.Fatal(err)
 	}
 	server := NewServer()
-	server.AddImage("test/repo", "test-tag", srcImage)
+	server.AddRef("test/repo", "test-tag", srcImage)
 	wrapper := &unreliableServer{
 		handler:   server,
 		blobLimit: srcImage.Manifest.Config.Size / 2,
@@ -71,7 +71,7 @@
 		Repository: "test/repo",
 	}
 
-	image, err := client.Read(context.Background(), "test-tag", srcImage.ManifestDigest)
+	image, err := oci.AsImage(client.Read(context.Background(), "test-tag", srcImage.Digest()))
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/osbase/oci/registry/server.go b/osbase/oci/registry/server.go
index 9c99c40..d7765b1 100644
--- a/osbase/oci/registry/server.go
+++ b/osbase/oci/registry/server.go
@@ -15,8 +15,6 @@
 	"strings"
 	"time"
 
-	ocispecv1 "github.com/opencontainers/image-spec/specs-go/v1"
-
 	"source.monogon.dev/osbase/oci"
 	"source.monogon.dev/osbase/structfs"
 )
@@ -49,16 +47,24 @@
 	}
 }
 
-// AddImage adds an image to the server in the specified repository.
+// AddRef adds a Ref to the server in the specified repository.
 //
 // If the tag is empty, the image can only be fetched by digest.
-func (s *Server) AddImage(repository string, tag string, image *oci.Image) error {
+func (s *Server) AddRef(repository string, tag string, ref oci.Ref) error {
 	if !RepositoryRegexp.MatchString(repository) {
 		return fmt.Errorf("invalid repository %q", repository)
 	}
 	if tag != "" && !TagRegexp.MatchString(tag) {
 		return fmt.Errorf("invalid tag %q", tag)
 	}
+	var refs []oci.Ref
+	err := oci.WalkRefs(ref.Digest(), ref, func(_ string, r oci.Ref) error {
+		refs = append(refs, r)
+		return nil
+	})
+	if err != nil {
+		return err
+	}
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -71,17 +77,22 @@
 		}
 		s.repositories[repository] = repo
 	}
-	if _, ok := repo.manifests[image.ManifestDigest]; !ok {
-		for descriptor := range image.Descriptors() {
-			repo.blobs[string(descriptor.Digest)] = image.StructfsBlob(descriptor)
+	for _, ref := range refs {
+		if _, ok := repo.manifests[ref.Digest()]; ok {
+			continue
 		}
-		repo.manifests[image.ManifestDigest] = serverManifest{
-			contentType: ocispecv1.MediaTypeImageManifest,
-			content:     image.RawManifest,
+		if image, ok := ref.(*oci.Image); ok {
+			for descriptor := range image.Descriptors() {
+				repo.blobs[string(descriptor.Digest)] = image.StructfsBlob(descriptor)
+			}
+		}
+		repo.manifests[ref.Digest()] = serverManifest{
+			contentType: ref.MediaType(),
+			content:     ref.RawManifest(),
 		}
 	}
 	if tag != "" {
-		repo.tags[tag] = image.ManifestDigest
+		repo.tags[tag] = ref.Digest()
 	}
 	return nil
 }