blob: 4e60b7b51705ed463eede1263522dba83af96767 [file] [log] [blame]
// Copyright The Monogon Project Authors.
// SPDX-License-Identifier: Apache-2.0
3
// Package registry contains a client and server implementation of the OCI
// Distribution spec. Both client and server only support pulling. The server is
// intended for use in tests.
7package registry
8
9import (
10 "context"
11 "encoding/json"
12 "errors"
13 "fmt"
14 "io"
15 "net/http"
16 "net/url"
17 "regexp"
18 "strconv"
19 "strings"
20 "sync"
21 "sync/atomic"
22 "time"
23
24 "github.com/cenkalti/backoff/v4"
25 ocispecv1 "github.com/opencontainers/image-spec/specs-go/v1"
26
27 "source.monogon.dev/osbase/oci"
28)
29
// The expressions below are taken from the grammars published in:
//
//   - https://github.com/opencontainers/distribution-spec/blob/main/spec.md#pulling-manifests
//   - https://github.com/opencontainers/image-spec/blob/main/descriptor.md#digests
//
// They are unanchored so that they can be embedded in larger patterns.
const (
	// repositoryExpr matches a repository name: lowercase alphanumeric
	// components joined by ".", "_", "__" or dashes, with "/" between path
	// segments.
	repositoryExpr = `[a-z0-9]+(?:(?:\.|_|__|-+)[a-z0-9]+)*(?:\/[a-z0-9]+(?:(?:\.|_|__|-+)[a-z0-9]+)*)*`
	// tagExpr matches a tag of 1 to 128 characters that does not start with
	// "." or "-".
	tagExpr = `[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}`
	// digestExpr matches a digest of the form "algorithm:encoded".
	digestExpr = `[a-z0-9]+(?:[+._-][a-z0-9]+)*:[a-zA-Z0-9=_-]+`
)

// Anchored forms of the expressions above. Each one matches only if the
// entire input string conforms.
var (
	// RepositoryRegexp matches a string that is exactly a repository name.
	RepositoryRegexp = regexp.MustCompile("^" + repositoryExpr + "$")
	// TagRegexp matches a string that is exactly a tag.
	TagRegexp = regexp.MustCompile("^" + tagExpr + "$")
	// DigestRegexp matches a string that is exactly a digest.
	DigestRegexp = regexp.MustCompile("^" + digestExpr + "$")
)
45
46// Client is an OCI registry client.
47type Client struct {
48 // Transport will be used to make requests. For example, this allows
49 // configuring TLS client and CA certificates.
50 // If nil, [http.DefaultTransport] is used.
51 Transport http.RoundTripper
52 // GetBackOff can be set to to make the Client retry HTTP requests.
53 GetBackOff func() backoff.BackOff
54 // RetryNotify receives errors that trigger a retry, e.g. for logging.
55 RetryNotify backoff.Notify
56 // UserAgent is used as the User-Agent HTTP header.
57 UserAgent string
58
59 // Scheme must be either http or https.
60 Scheme string
61 // Host is the host with optional port.
62 Host string
63 // Repository is the name of the repository. It is part of the client because
64 // bearer tokens are usually scoped to a repository.
65 Repository string
66
67 authMu sync.RWMutex
68 // bearerToken is a cached token obtained from an authorization service.
69 bearerToken string
70}
71
72// Read fetches an image manifest from the registry and returns an [oci.Image].
73//
74// The context is used for the manifest request and all blob requests made
75// through the Image.
76//
77// At least one of tag and digest must be set. If only tag is set, then you are
78// trusting the registry to return the right content. Otherwise, the digest is
79// used to verify the manifest. If both tag and digest are set, then the tag is
80// used in the request, and the digest is used to verify the response. The
81// advantage of fetching by tag is that it allows a pull through cache to
82// display tags to a user inspecting the cache contents.
83func (c *Client) Read(ctx context.Context, tag, digest string) (*oci.Image, error) {
Jan Schär62cecde2025-04-16 15:24:04 +000084 if !RepositoryRegexp.MatchString(c.Repository) {
Jan Schärcc9e4d12025-04-14 10:28:40 +000085 return nil, fmt.Errorf("invalid repository %q", c.Repository)
86 }
Jan Schär62cecde2025-04-16 15:24:04 +000087 if tag != "" && !TagRegexp.MatchString(tag) {
Jan Schärcc9e4d12025-04-14 10:28:40 +000088 return nil, fmt.Errorf("invalid tag %q", tag)
89 }
90 if digest != "" {
91 if _, _, err := oci.ParseDigest(digest); err != nil {
92 return nil, err
93 }
94 }
95 var reference string
96 if tag != "" {
97 reference = tag
98 } else if digest != "" {
99 reference = digest
100 } else {
101 return nil, fmt.Errorf("tag and digest cannot both be empty")
102 }
103
104 manifestPath := fmt.Sprintf("/v2/%s/manifests/%s", c.Repository, reference)
105 var imageManifestBytes []byte
106 err := c.retry(ctx, func() error {
107 req, err := c.newGet(manifestPath)
108 if err != nil {
109 return err
110 }
111 req.Header.Set("Accept", ocispecv1.MediaTypeImageManifest)
112 resp, err := c.doGet(ctx, req)
113 if err != nil {
114 return err
115 }
116 if resp.StatusCode != http.StatusOK {
117 return readClientError(resp, req)
118 }
119 defer resp.Body.Close()
120 imageManifestBytes, err = readFullBody(resp, 50*1024*1024)
121 return err
122 })
123 if err != nil {
124 return nil, err
125 }
126
127 blobs := &clientBlobs{
128 ctx: ctx,
129 client: c,
130 }
131 return oci.NewImage(imageManifestBytes, digest, blobs)
132}
133
134type clientBlobs struct {
135 ctx context.Context
136 client *Client
137}
138
139func (r *clientBlobs) Blob(descriptor *ocispecv1.Descriptor) (io.ReadCloser, error) {
Jan Schär62cecde2025-04-16 15:24:04 +0000140 if !DigestRegexp.MatchString(string(descriptor.Digest)) {
Jan Schärcc9e4d12025-04-14 10:28:40 +0000141 return nil, fmt.Errorf("invalid blob digest %q", descriptor.Digest)
142 }
143 blobPath := fmt.Sprintf("/v2/%s/blobs/%s", r.client.Repository, descriptor.Digest)
144 var resp *http.Response
145 err := r.client.retry(r.ctx, func() error {
146 req, err := r.client.newGet(blobPath)
147 if err != nil {
148 return err
149 }
150 resp, err = r.client.doGet(r.ctx, req)
151 if err != nil {
152 return err
153 }
154 if resp.StatusCode != http.StatusOK {
155 return readClientError(resp, req)
156 }
157 return nil
158 })
159 if err != nil {
160 return nil, err
161 }
162 if r.client.GetBackOff == nil {
163 return resp.Body, nil
164 }
165 ctx, cancel := context.WithCancelCause(r.ctx)
166 reader := &retryReader{
167 ctx: ctx,
168 cancel: cancel,
169 client: r.client,
170 path: blobPath,
171 pos: 0,
172 size: descriptor.Size,
173 }
174 reader.resp.Store(resp)
175 return reader, nil
176}
177
178type retryReader struct {
179 ctx context.Context
180 cancel context.CancelCauseFunc
181 client *Client
182 path string
183 pos int64
184 size int64
185 // resp is an atomic pointer because it may be concurrently written by Read()
186 // and read by Close().
187 resp atomic.Pointer[http.Response]
188}
189
190func (r *retryReader) Read(p []byte) (n int, err error) {
191 if r.pos >= r.size {
192 return 0, io.EOF
193 }
194 if len(p) == 0 {
195 return 0, nil
196 }
197 if int64(len(p)) > r.size-r.pos {
198 p = p[:r.size-r.pos]
199 }
200 closed := false
201 err = r.client.retry(r.ctx, func() error {
202 if closed {
203 req, err := r.client.newGet(r.path)
204 if err != nil {
205 return err
206 }
207 if r.pos != 0 {
208 req.Header.Set("Range", fmt.Sprintf("bytes=%d-", r.pos))
209 }
210 resp, err := r.client.doGet(r.ctx, req)
211 if err != nil {
212 return err
213 }
214 r.resp.Store(resp)
215 if err := context.Cause(r.ctx); err != nil {
216 resp.Body.Close()
217 return err
218 }
219 switch resp.StatusCode {
220 case http.StatusOK:
221 _, err := io.CopyN(io.Discard, resp.Body, r.pos)
222 if err != nil {
223 return err
224 }
225 case http.StatusPartialContent:
226 if !strings.HasPrefix(resp.Header.Get("Content-Range"), fmt.Sprintf("bytes %d-", r.pos)) {
227 return backoff.Permanent(errors.New("invalid content range"))
228 }
229 default:
230 return readClientError(resp, req)
231 }
232 }
233 var err error
234 n, err = r.resp.Load().Body.Read(p)
235 if n != 0 {
236 r.pos += int64(n)
237 return nil
238 }
239 if err == nil {
240 err = errors.New("read 0 bytes")
241 }
242 closed = true
243 r.resp.Load().Body.Close()
244 return err
245 })
246 if r.pos >= r.size {
247 err = io.EOF
248 } else if err == io.EOF {
249 err = io.ErrUnexpectedEOF
250 }
251 return
252}
253
254func (r *retryReader) Close() error {
255 r.cancel(errors.New("reader closed"))
256 return r.resp.Load().Body.Close()
257}
258
259func (c *Client) retry(ctx context.Context, o func() error) error {
260 if err := ctx.Err(); err != nil {
261 return err
262 }
263 var b backoff.BackOff
264 for {
265 err := o()
266 if err == nil {
267 return nil
268 }
269 var permanent *backoff.PermanentError
270 if errors.As(err, &permanent) {
271 return err
272 }
273 if ctx.Err() != nil {
274 return err
275 }
276 if b == nil {
277 if c.GetBackOff == nil {
278 return err
279 }
280 b = c.GetBackOff()
281 }
282 next := b.NextBackOff()
283 if next == backoff.Stop {
284 return err
285 }
286 var clientErr *ClientError
287 if errors.As(err, &clientErr) && !clientErr.RetryAfter.IsZero() {
288 next = max(next, time.Until(clientErr.RetryAfter))
289 }
290 deadline, hasDeadline := ctx.Deadline()
291 if hasDeadline && time.Until(deadline) < next {
292 return err
293 }
294
295 if c.RetryNotify != nil {
296 c.RetryNotify(err, next)
297 }
298 select {
299 case <-ctx.Done():
300 return ctx.Err()
301 case <-time.After(next):
302 }
303 }
304}
305
306func (c *Client) newGet(path string) (*http.Request, error) {
307 u := url.URL{
308 Scheme: c.Scheme,
309 Host: c.Host,
310 Path: path,
311 }
312 req, err := http.NewRequest("GET", u.String(), nil)
313 if err != nil {
314 return nil, err
315 }
316 if c.UserAgent != "" {
317 req.Header.Set("User-Agent", c.UserAgent)
318 }
319 return req, nil
320}
321
322func (c *Client) doGet(ctx context.Context, req *http.Request) (*http.Response, error) {
323 req = req.WithContext(ctx)
324 c.addAuthorization(req)
325 client := http.Client{Transport: c.Transport}
326 resp, err := client.Do(req)
327 if err != nil {
328 return nil, redactURLError(err)
329 }
330
331 if resp.StatusCode == http.StatusUnauthorized {
332 unauthorizedErr := readClientError(resp, req)
333 retry, err := c.handleUnauthorized(ctx, resp)
334 if err != nil {
335 return nil, err
336 }
337 if !retry {
338 return nil, unauthorizedErr
339 }
340 c.addAuthorization(req)
341 resp, err = client.Do(req)
342 if err != nil {
343 return nil, redactURLError(err)
344 }
345 }
346
347 return resp, nil
348}
349
350func readClientError(resp *http.Response, req *http.Request) error {
351 defer resp.Body.Close()
352 clientErr := &ClientError{
353 StatusCode: resp.StatusCode,
354 }
355 retryAfter := resp.Header.Get("Retry-After")
356 if retryAfter != "" {
357 seconds, err := strconv.ParseInt(retryAfter, 10, 64)
358 if err == nil {
359 clientErr.RetryAfter = time.Now().Add(time.Duration(seconds) * time.Second)
360 } else {
361 clientErr.RetryAfter, _ = http.ParseTime(retryAfter)
362 }
363 }
364 content, err := readFullBody(resp, 2048)
365 if err == nil {
366 clientErr.RawBody = content
367 _ = json.Unmarshal(content, &clientErr.ErrorBody)
368 }
369
370 errReq := resp.Request
371 if errReq == nil {
372 errReq = req
373 }
374 urlErr := &url.Error{
375 Op: errReq.Method,
376 URL: errReq.URL.Redacted(),
377 Err: clientErr,
378 }
379 err = redactURLError(urlErr)
380
381 // Client errors are usually permanent, and server errors are usually
382 // temporary, but there are some exceptions.
383 isTemporary := 500 <= clientErr.StatusCode && clientErr.StatusCode <= 599
384 switch clientErr.StatusCode {
385 case http.StatusRequestTimeout, http.StatusTooEarly,
386 http.StatusTooManyRequests,
387 499: // nginx-specific, client closed request
388 isTemporary = true
389 case http.StatusNotImplemented, http.StatusHTTPVersionNotSupported,
390 http.StatusNetworkAuthenticationRequired:
391 isTemporary = false
392 }
393 if !isTemporary {
394 return backoff.Permanent(err)
395 }
396 return err
397}
398
// ClientError is an HTTP error received from a registry or authorization
// service.
type ClientError struct {
	// ErrorBody holds the structured errors decoded from the response body,
	// if any.
	ErrorBody
	// StatusCode is the HTTP status code of the response.
	StatusCode int
	// RetryAfter is the time derived from the Retry-After header. It is the
	// zero time when no usable header was present.
	RetryAfter time.Time
	// RawBody is the response body as received.
	RawBody []byte
}

// ErrorBody is the JSON document carried in an error response.
type ErrorBody struct {
	// Errors lists the individual errors reported by the server.
	Errors []ErrorInfo `json:"errors,omitempty"`
}

// ErrorInfo is a single entry of an [ErrorBody].
type ErrorInfo struct {
	// Code is the error code reported by the server.
	Code string `json:"code"`
	// Message is the accompanying message, if any.
	Message string `json:"message,omitempty"`
}

// Error implements the error interface.
//
// With structured errors the result is "HTTP <status> <code>: <message>"
// with the entries joined by "; ". Without them it is "HTTP <status> <status
// text>", followed by the quoted raw body when one is present.
func (e *ClientError) Error() string {
	prefix := fmt.Sprintf("HTTP %d ", e.StatusCode)

	if len(e.Errors) > 0 {
		parts := make([]string, 0, len(e.Errors))
		for _, info := range e.Errors {
			parts = append(parts, info.Code+": "+info.Message)
		}
		return prefix + strings.Join(parts, "; ")
	}

	msg := prefix + http.StatusText(e.StatusCode)
	if len(e.RawBody) == 0 {
		return msg
	}
	return fmt.Sprintf("%s: %q", msg, e.RawBody)
}
431
// redactURLError redacts the URL stored in a [url.Error] found in err's chain.
// After redirects, the URL may contain secrets in query parameter values.
//
// Every query value is replaced by "REDACTED", except for the "scope" and
// "service" parameters, which are kept. Any password in the user info is
// masked as well. The [url.Error] is modified in place and err is returned
// unchanged otherwise. If err contains no [url.Error], or its URL cannot be
// parsed, err is returned untouched.
//
// Logic adapted from:
// https://github.com/google/go-containerregistry/blob/v0.20.3/internal/redact/redact.go
func redactURLError(err error) error {
	var target *url.Error
	if !errors.As(err, &target) {
		return err
	}
	parsed, parseErr := url.Parse(target.URL)
	if parseErr != nil {
		return err
	}

	params := parsed.Query()
	for key, values := range params {
		switch key {
		case "scope", "service":
			// Kept verbatim.
		default:
			for i := range values {
				values[i] = "REDACTED"
			}
		}
	}
	parsed.RawQuery = params.Encode()
	target.URL = parsed.Redacted()
	return err
}
459
460func readFullBody(resp *http.Response, limit int) ([]byte, error) {
461 switch {
462 case resp.ContentLength < 0:
463 lr := io.LimitReader(resp.Body, int64(limit)+1)
464 content, err := io.ReadAll(lr)
465 if err != nil {
466 return nil, err
467 }
468 if len(content) > limit {
469 return nil, backoff.Permanent(fmt.Errorf("HTTP response exceeds limit of %d bytes", limit))
470 }
471 return content, nil
472 case resp.ContentLength <= int64(limit):
473 content := make([]byte, resp.ContentLength)
474 _, err := io.ReadFull(resp.Body, content)
475 if err != nil {
476 return nil, err
477 }
478 return content, nil
479 default:
480 return nil, backoff.Permanent(fmt.Errorf("HTTP response of size %d exceeds limit of %d bytes", resp.ContentLength, limit))
481 }
482}