osbase/oci/registry: add package
This adds the registry package, which contains a client and server
implementation of the OCI Distribution spec.
Change-Id: I080bb1dbc511f8e6466ca370b090d459d2b730e8
Reviewed-on: https://review.monogon.dev/c/monogon/+/4086
Tested-by: Jenkins CI
Reviewed-by: Tim Windelschmidt <tim@monogon.tech>
diff --git a/osbase/oci/registry/client.go b/osbase/oci/registry/client.go
new file mode 100644
index 0000000..c414108
--- /dev/null
+++ b/osbase/oci/registry/client.go
@@ -0,0 +1,482 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package registry contains a client and server implementation of the OCI
+// Distribution spec. Both client and server only support pulling. The server is
+// intended for use in tests.
+package registry
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "regexp"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ ocispecv1 "github.com/opencontainers/image-spec/specs-go/v1"
+
+ "source.monogon.dev/osbase/oci"
+)
+
+// Sources for these expressions:
+//
+// - https://github.com/opencontainers/distribution-spec/blob/main/spec.md#pulling-manifests
+// - https://github.com/opencontainers/image-spec/blob/main/descriptor.md#digests
+//
+// The expressions are unanchored; the regexps declared below anchor them so
+// that they only match complete strings.
+const (
+ // repositoryExpr matches one or more '/'-separated path components, each made
+ // of lowercase alphanumerics joined by '.', '_', '__' or runs of '-'.
+ repositoryExpr = `[a-z0-9]+(?:(?:\.|_|__|-+)[a-z0-9]+)*(?:\/[a-z0-9]+(?:(?:\.|_|__|-+)[a-z0-9]+)*)*`
+ // tagExpr matches 1 to 128 characters from [a-zA-Z0-9._-], where the first
+ // character may not be '.' or '-'.
+ tagExpr = `[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}`
+ // digestExpr matches "<algorithm>:<encoded>", where the algorithm may consist
+ // of several lowercase components separated by '+', '.', '_' or '-'.
+ digestExpr = `[a-z0-9]+(?:[+._-][a-z0-9]+)*:[a-zA-Z0-9=_-]+`
+)
+
+var (
+ // Anchored forms of the expressions above. In this file they validate the
+ // repository, tag and blob digest before these are used to build request
+ // paths.
+ repositoryRegexp = regexp.MustCompile(`^` + repositoryExpr + `$`)
+ tagRegexp = regexp.MustCompile(`^` + tagExpr + `$`)
+ digestRegexp = regexp.MustCompile(`^` + digestExpr + `$`)
+)
+
+// Client is an OCI registry client which pulls from a single repository.
+type Client struct {
+	// Transport will be used to make requests. For example, this allows
+	// configuring TLS client and CA certificates.
+	// If nil, [http.DefaultTransport] is used.
+	Transport http.RoundTripper
+	// GetBackOff can be set to make the Client retry HTTP requests. It is
+	// called to obtain a fresh backoff policy the first time an operation needs
+	// to be retried. If nil, failed requests are not retried, and blob readers
+	// do not resume interrupted downloads.
+	GetBackOff func() backoff.BackOff
+	// RetryNotify receives errors that trigger a retry, together with the
+	// delay before the next attempt, e.g. for logging. It may be nil.
+	RetryNotify backoff.Notify
+	// UserAgent is used as the User-Agent HTTP header. If empty, the header is
+	// not set explicitly.
+	UserAgent string
+
+	// Scheme must be either http or https.
+	Scheme string
+	// Host is the host with optional port.
+	Host string
+	// Repository is the name of the repository. It is part of the client because
+	// bearer tokens are usually scoped to a repository.
+	Repository string
+
+	// NOTE(review): authMu appears to guard bearerToken; its users
+	// (addAuthorization, handleUnauthorized) are not visible here — confirm.
+	authMu sync.RWMutex
+	// bearerToken is a cached token obtained from an authorization service.
+	bearerToken string
+}
+
+// Read fetches an image manifest from the registry and returns an [oci.Image].
+//
+// The context is used for the manifest request and all blob requests made
+// through the Image.
+//
+// At least one of tag and digest must be set. If only tag is set, then you are
+// trusting the registry to return the right content. Otherwise, the digest is
+// used to verify the manifest. If both tag and digest are set, then the tag is
+// used in the request, and the digest is used to verify the response. The
+// advantage of fetching by tag is that it allows a pull through cache to
+// display tags to a user inspecting the cache contents.
+func (c *Client) Read(ctx context.Context, tag, digest string) (*oci.Image, error) {
+	const maxManifestSize = 50 << 20
+
+	if !repositoryRegexp.MatchString(c.Repository) {
+		return nil, fmt.Errorf("invalid repository %q", c.Repository)
+	}
+	if tag != "" && !tagRegexp.MatchString(tag) {
+		return nil, fmt.Errorf("invalid tag %q", tag)
+	}
+	if digest != "" {
+		if _, _, err := oci.ParseDigest(digest); err != nil {
+			return nil, err
+		}
+	}
+
+	// The tag takes precedence as the request reference; the digest (if any)
+	// is only used for verification in that case.
+	reference := tag
+	if reference == "" {
+		reference = digest
+	}
+	if reference == "" {
+		return nil, fmt.Errorf("tag and digest cannot both be empty")
+	}
+
+	manifestPath := fmt.Sprintf("/v2/%s/manifests/%s", c.Repository, reference)
+	var manifest []byte
+	fetchManifest := func() error {
+		req, err := c.newGet(manifestPath)
+		if err != nil {
+			return err
+		}
+		req.Header.Set("Accept", ocispecv1.MediaTypeImageManifest)
+		resp, err := c.doGet(ctx, req)
+		if err != nil {
+			return err
+		}
+		if resp.StatusCode != http.StatusOK {
+			// readClientError takes care of closing the body.
+			return readClientError(resp, req)
+		}
+		defer resp.Body.Close()
+		manifest, err = readFullBody(resp, maxManifestSize)
+		return err
+	}
+	if err := c.retry(ctx, fetchManifest); err != nil {
+		return nil, err
+	}
+
+	return oci.NewImage(manifest, digest, &clientBlobs{ctx: ctx, client: c})
+}
+
+// clientBlobs serves the blobs of an image returned by [Client.Read] by
+// fetching them from the registry.
+type clientBlobs struct {
+	// ctx is the context given to Read; every blob request is made with it.
+	ctx    context.Context
+	client *Client
+}
+
+// Blob opens the blob described by descriptor. When the client has a backoff
+// policy, the returned reader resumes the download after read failures;
+// otherwise the raw response body is returned.
+func (r *clientBlobs) Blob(descriptor *ocispecv1.Descriptor) (io.ReadCloser, error) {
+	if !digestRegexp.MatchString(string(descriptor.Digest)) {
+		return nil, fmt.Errorf("invalid blob digest %q", descriptor.Digest)
+	}
+	client := r.client
+	blobPath := fmt.Sprintf("/v2/%s/blobs/%s", client.Repository, descriptor.Digest)
+
+	var resp *http.Response
+	openBlob := func() error {
+		req, err := client.newGet(blobPath)
+		if err != nil {
+			return err
+		}
+		if resp, err = client.doGet(r.ctx, req); err != nil {
+			return err
+		}
+		if resp.StatusCode != http.StatusOK {
+			return readClientError(resp, req)
+		}
+		return nil
+	}
+	if err := client.retry(r.ctx, openBlob); err != nil {
+		return nil, err
+	}
+
+	// Without a backoff policy a resumed read could not be retried anyway.
+	if client.GetBackOff == nil {
+		return resp.Body, nil
+	}
+
+	ctx, cancel := context.WithCancelCause(r.ctx)
+	reader := &retryReader{
+		ctx:    ctx,
+		cancel: cancel,
+		client: client,
+		path:   blobPath,
+		size:   descriptor.Size,
+	}
+	reader.resp.Store(resp)
+	return reader, nil
+}
+
+// retryReader reads a blob of known size. When reading the current response
+// body fails, it requests the remainder of the blob again, using an HTTP Range
+// request when some bytes have already been read.
+type retryReader struct {
+	// resp is the current response. It is an atomic pointer because Read may
+	// replace it while Close loads it concurrently.
+	resp atomic.Pointer[http.Response]
+
+	// ctx is canceled by Close via cancel.
+	ctx    context.Context
+	cancel context.CancelCauseFunc
+	client *Client
+	// path is the request path of the blob.
+	path string
+
+	// pos counts the bytes returned so far; size is the expected blob size.
+	pos  int64
+	size int64
+}
+
+// Read implements [io.Reader] and never reads past size bytes. When reading
+// the current response body fails, the blob is requested again starting at the
+// current position, subject to the client's retry policy. If the body keeps
+// ending before size bytes were read and retries are exhausted, Read returns
+// [io.ErrUnexpectedEOF].
+func (r *retryReader) Read(p []byte) (n int, err error) {
+	if r.pos >= r.size {
+		return 0, io.EOF
+	}
+	if len(p) == 0 {
+		return 0, nil
+	}
+	if remaining := r.size - r.pos; int64(len(p)) > remaining {
+		p = p[:remaining]
+	}
+	// closed records that the current body was closed after a failed read, so
+	// the next attempt has to reopen the blob.
+	closed := false
+	err = r.client.retry(r.ctx, func() error {
+		if closed {
+			if err := r.reopen(); err != nil {
+				return err
+			}
+		}
+		var err error
+		n, err = r.resp.Load().Body.Read(p)
+		if n != 0 {
+			r.pos += int64(n)
+			return nil
+		}
+		if err == nil {
+			err = errors.New("read 0 bytes")
+		}
+		closed = true
+		r.resp.Load().Body.Close()
+		return err
+	})
+	if r.pos >= r.size {
+		err = io.EOF
+	} else if err == io.EOF {
+		err = io.ErrUnexpectedEOF
+	}
+	return n, err
+}
+
+// reopen requests the blob again starting at r.pos and makes the response the
+// current one. If it returns an error, the body of any response it obtained has
+// been closed. This way the body is not leaked, and a later Read cannot consume
+// a response positioned at the wrong offset.
+func (r *retryReader) reopen() error {
+	req, err := r.client.newGet(r.path)
+	if err != nil {
+		return err
+	}
+	if r.pos != 0 {
+		req.Header.Set("Range", fmt.Sprintf("bytes=%d-", r.pos))
+	}
+	resp, err := r.client.doGet(r.ctx, req)
+	if err != nil {
+		return err
+	}
+	// Store before checking for cancellation: a concurrent Close either sees
+	// this response and closes it, or its cancellation is observed here.
+	r.resp.Store(resp)
+	if err := context.Cause(r.ctx); err != nil {
+		resp.Body.Close()
+		return err
+	}
+	switch resp.StatusCode {
+	case http.StatusOK:
+		// The server sent the whole blob; skip what was already returned.
+		if _, err := io.CopyN(io.Discard, resp.Body, r.pos); err != nil {
+			resp.Body.Close()
+			return err
+		}
+	case http.StatusPartialContent:
+		if !strings.HasPrefix(resp.Header.Get("Content-Range"), fmt.Sprintf("bytes %d-", r.pos)) {
+			resp.Body.Close()
+			return backoff.Permanent(errors.New("invalid content range"))
+		}
+	default:
+		// readClientError closes the body.
+		return readClientError(resp, req)
+	}
+	return nil
+}
+
+// Close cancels the reader's context, so that subsequent reads fail, and
+// closes the current response body.
+func (r *retryReader) Close() error {
+	// Cancel before loading resp: a Read that stores a new response after this
+	// point observes the cancellation and closes that response itself.
+	r.cancel(errors.New("reader closed"))
+	current := r.resp.Load()
+	return current.Body.Close()
+}
+
+// retry runs o until it succeeds. It stops and returns the last error of o in
+// any of these cases:
+//   - o returned a [backoff.PermanentError];
+//   - ctx is done;
+//   - no GetBackOff is configured, or the backoff policy gives up;
+//   - the next attempt would start after ctx's deadline.
+//
+// A [ClientError] with a Retry-After time can lengthen the wait between
+// attempts. If ctx is done while waiting, ctx.Err() is returned.
+func (c *Client) retry(ctx context.Context, o func() error) error {
+	if err := ctx.Err(); err != nil {
+		return err
+	}
+	var policy backoff.BackOff
+	for {
+		err := o()
+		if err == nil {
+			return nil
+		}
+		var permanent *backoff.PermanentError
+		if errors.As(err, &permanent) || ctx.Err() != nil {
+			return err
+		}
+		if policy == nil {
+			if c.GetBackOff == nil {
+				return err
+			}
+			// The policy is only created once a retry is actually needed.
+			policy = c.GetBackOff()
+		}
+		delay := policy.NextBackOff()
+		if delay == backoff.Stop {
+			return err
+		}
+		var clientErr *ClientError
+		if errors.As(err, &clientErr) && !clientErr.RetryAfter.IsZero() {
+			// Wait at least as long as the server asked for.
+			delay = max(delay, time.Until(clientErr.RetryAfter))
+		}
+		if deadline, ok := ctx.Deadline(); ok && time.Until(deadline) < delay {
+			// Waiting would outlive the context; report the real error now.
+			return err
+		}
+		if c.RetryNotify != nil {
+			c.RetryNotify(err, delay)
+		}
+		timer := time.NewTimer(delay)
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return ctx.Err()
+		case <-timer.C:
+		}
+	}
+}
+
+// newGet builds a GET request for path on the client's registry. The request
+// has no context attached; doGet adds one when sending it.
+func (c *Client) newGet(path string) (*http.Request, error) {
+	target := (&url.URL{Scheme: c.Scheme, Host: c.Host, Path: path}).String()
+	req, err := http.NewRequest(http.MethodGet, target, nil)
+	if err != nil {
+		return nil, err
+	}
+	if ua := c.UserAgent; ua != "" {
+		req.Header.Set("User-Agent", ua)
+	}
+	return req, nil
+}
+
+// doGet sends req with ctx and the client's authorization. Transport errors
+// have their URL redacted.
+//
+// On a 401 Unauthorized response, handleUnauthorized decides whether to retry.
+// If it does, the request is sent once more with the updated authorization.
+// Otherwise the 401 is returned as an error. The caller must close the body of
+// a returned response.
+func (c *Client) doGet(ctx context.Context, req *http.Request) (*http.Response, error) {
+	req = req.WithContext(ctx)
+	httpClient := &http.Client{Transport: c.Transport}
+	send := func() (*http.Response, error) {
+		c.addAuthorization(req)
+		resp, err := httpClient.Do(req)
+		if err != nil {
+			return nil, redactURLError(err)
+		}
+		return resp, nil
+	}
+
+	resp, err := send()
+	if err != nil {
+		return nil, err
+	}
+	if resp.StatusCode != http.StatusUnauthorized {
+		return resp, nil
+	}
+
+	// Consume (and close) the 401 response up front, so that it can be
+	// reported if authorization cannot be retried.
+	unauthorizedErr := readClientError(resp, req)
+	retry, err := c.handleUnauthorized(ctx, resp)
+	if err != nil {
+		return nil, err
+	}
+	if !retry {
+		return nil, unauthorizedErr
+	}
+	return send()
+}
+
+// readClientError consumes and closes resp, an unsuccessful response to req.
+// It returns a [*url.Error] with a redacted URL that wraps a [*ClientError].
+// The result is additionally wrapped with [backoff.Permanent] unless the
+// status code suggests that retrying may help.
+func readClientError(resp *http.Response, req *http.Request) error {
+	defer resp.Body.Close()
+
+	clientErr := &ClientError{StatusCode: resp.StatusCode}
+	if header := resp.Header.Get("Retry-After"); header != "" {
+		// Retry-After is either a number of seconds or an HTTP date.
+		if seconds, perr := strconv.ParseInt(header, 10, 64); perr == nil {
+			clientErr.RetryAfter = time.Now().Add(time.Duration(seconds) * time.Second)
+		} else {
+			clientErr.RetryAfter, _ = http.ParseTime(header)
+		}
+	}
+	if body, rerr := readFullBody(resp, 2048); rerr == nil {
+		clientErr.RawBody = body
+		// The body need not be a JSON error document; decoding is best effort.
+		_ = json.Unmarshal(body, &clientErr.ErrorBody)
+	}
+
+	// After redirects, resp.Request is the request that produced resp.
+	origin := req
+	if resp.Request != nil {
+		origin = resp.Request
+	}
+	err := redactURLError(&url.Error{
+		Op:  origin.Method,
+		URL: origin.URL.Redacted(),
+		Err: clientErr,
+	})
+
+	// Client errors are usually permanent, and server errors are usually
+	// temporary, but there are some exceptions.
+	var temporary bool
+	switch code := clientErr.StatusCode; code {
+	case http.StatusRequestTimeout, http.StatusTooEarly,
+		http.StatusTooManyRequests,
+		499: // nginx-specific, client closed request
+		temporary = true
+	case http.StatusNotImplemented, http.StatusHTTPVersionNotSupported,
+		http.StatusNetworkAuthenticationRequired:
+		temporary = false
+	default:
+		temporary = 500 <= code && code <= 599
+	}
+	if temporary {
+		return err
+	}
+	return backoff.Permanent(err)
+}
+
+// ClientError is an HTTP error received from a registry or authorization
+// service.
+type ClientError struct {
+	// ErrorBody holds the error entries decoded from a JSON response body.
+	// Decoding is best effort, so it may be empty.
+	ErrorBody
+	// StatusCode is the HTTP status code of the response.
+	StatusCode int
+	// RetryAfter is the time indicated by the Retry-After response header, or
+	// the zero time if the header was absent or could not be parsed.
+	RetryAfter time.Time
+	// RawBody is the raw response body. It is empty if the body could not be
+	// read in full.
+	RawBody []byte
+}
+
+// ErrorBody is the JSON error response body defined by the OCI Distribution
+// spec.
+type ErrorBody struct {
+	// Errors lists the individual errors reported by the server.
+	Errors []ErrorInfo `json:"errors,omitempty"`
+}
+
+// ErrorInfo is a single entry of [ErrorBody.Errors].
+type ErrorInfo struct {
+	// Code is a machine-readable error code, e.g. MANIFEST_UNKNOWN.
+	Code string `json:"code"`
+	// Message is a human-readable description of the error.
+	Message string `json:"message,omitempty"`
+}
+
+// Error returns "HTTP <code> " followed by the "code: message" pairs of the
+// error entries, joined by "; ". If there are no entries, it returns the
+// status text instead, followed by the quoted raw body if there is one.
+func (e *ClientError) Error() string {
+	if len(e.Errors) == 0 {
+		text := fmt.Sprintf("HTTP %d %s", e.StatusCode, http.StatusText(e.StatusCode))
+		if len(e.RawBody) != 0 {
+			text = fmt.Sprintf("%s: %q", text, e.RawBody)
+		}
+		return text
+	}
+	errorStrs := make([]string, 0, len(e.Errors))
+	for _, ei := range e.Errors {
+		errorStrs = append(errorStrs, fmt.Sprintf("%s: %s", ei.Code, ei.Message))
+	}
+	return fmt.Sprintf("HTTP %d %s", e.StatusCode, strings.Join(errorStrs, "; "))
+}
+
+// redactURLError redacts the URL of the [url.Error] found in err's chain, in
+// place, and returns err itself. After redirects, the URL may contain secrets
+// in query parameter values. The values of all query parameters except "scope"
+// and "service" are replaced by REDACTED, and any password is masked. If err
+// has no [url.Error] or its URL cannot be parsed, err is returned untouched.
+//
+// Logic adapted from:
+// https://github.com/google/go-containerregistry/blob/v0.20.3/internal/redact/redact.go
+func redactURLError(err error) error {
+	var target *url.Error
+	if !errors.As(err, &target) {
+		return err
+	}
+	parsed, parseErr := url.Parse(target.URL)
+	if parseErr != nil {
+		return err
+	}
+	params := parsed.Query()
+	for key, values := range params {
+		switch key {
+		case "scope", "service":
+			// Kept as-is.
+		default:
+			for i := range values {
+				values[i] = "REDACTED"
+			}
+		}
+	}
+	parsed.RawQuery = params.Encode()
+	target.URL = parsed.Redacted()
+	return err
+}
+
+// readFullBody reads the whole body of resp, refusing bodies larger than limit
+// bytes. If the Content-Length is known, oversize bodies are rejected without
+// reading. Otherwise at most limit+1 bytes are read to detect overflow.
+// Exceeding the limit is reported as a permanent (non-retryable) error.
+func readFullBody(resp *http.Response, limit int) ([]byte, error) {
+	if length := resp.ContentLength; length >= 0 {
+		if length > int64(limit) {
+			return nil, backoff.Permanent(fmt.Errorf("HTTP response of size %d exceeds limit of %d bytes", length, limit))
+		}
+		buf := make([]byte, length)
+		if _, err := io.ReadFull(resp.Body, buf); err != nil {
+			return nil, err
+		}
+		return buf, nil
+	}
+
+	// Unknown length: read one byte past the limit to detect an oversize body.
+	buf, err := io.ReadAll(io.LimitReader(resp.Body, int64(limit)+1))
+	if err != nil {
+		return nil, err
+	}
+	if len(buf) > limit {
+		return nil, backoff.Permanent(fmt.Errorf("HTTP response exceeds limit of %d bytes", limit))
+	}
+	return buf, nil
+}