blob: 6ee8afabef0feee33be4f1fbaaff3eeb35306db9 [file] [log] [blame]
Jan Schär56d12992025-04-14 11:49:37 +00001// Copyright The Monogon Project Authors.
2// SPDX-License-Identifier: Apache-2.0
3
4package osimage
5
6import (
7 "bytes"
8 "context"
9 "crypto/sha256"
10 "fmt"
11 "io"
12 "net"
13 "net/http"
14 "os"
15 "strings"
16 "testing"
17 "time"
18
19 "github.com/bazelbuild/rules_go/go/runfiles"
20 "github.com/cenkalti/backoff/v4"
21 ocispecv1 "github.com/opencontainers/image-spec/specs-go/v1"
22
23 "source.monogon.dev/osbase/oci"
24 "source.monogon.dev/osbase/oci/registry"
25)
26
// Canonical runfile paths of the test fixtures. Bazel fills these in at link
// time; the init function that follows resolves each of them in place to a
// real filesystem path using the rules_go runfiles package.
var xImagePath, xImageUncompressedPath, xTestPayloadPath string
35
36func init() {
37 var err error
38 for _, path := range []*string{
39 &xImagePath, &xImageUncompressedPath, &xTestPayloadPath,
40 } {
41 *path, err = runfiles.Rlocation(*path)
42 if err != nil {
43 panic(err)
44 }
45 }
46}
47
// Length in bytes and SHA-256 digest of the reference test payload, filled in
// by init from the file at xTestPayloadPath.
var (
	expectedPayloadLen  int64
	expectedPayloadHash [32]byte
)
50
// init reads the reference payload and records its length and SHA-256
// digest, which the tests compare against what they read back through an
// OS image. Panics if the payload file cannot be read.
func init() {
	data, err := os.ReadFile(xTestPayloadPath)
	if err != nil {
		panic(err)
	}
	expectedPayloadLen = int64(len(data))
	expectedPayloadHash = sha256.Sum256(data)
}
59
// TestRead reads the "test" payload from both the compressed and the
// uncompressed OCI layout via Read and Image.Payload. It checks that the
// reported size and the SHA-256 of the content match the reference payload.
func TestRead(t *testing.T) {
	testCases := []struct {
		desc string
		path string
	}{
		{"compressed", xImagePath},
		{"uncompressed", xImageUncompressedPath},
	}
	for _, tC := range testCases {
		t.Run(tC.desc, func(t *testing.T) {
			image, err := oci.ReadLayout(tC.path)
			if err != nil {
				t.Fatal(err)
			}
			osImage, err := Read(image)
			if err != nil {
				t.Fatal(err)
			}
			payload, err := osImage.Payload("test")
			if err != nil {
				t.Fatal(err)
			}
			if got, want := payload.Size(), expectedPayloadLen; got != want {
				t.Errorf("payload has size %d, expected %d", got, want)
			}
			reader, err := payload.Open()
			if err != nil {
				t.Fatal(err)
			}
			// Close from a deferred func so the reader is also released when
			// io.ReadAll fails and t.Fatal ends the subtest early, while
			// still reporting a failing Close.
			defer func() {
				if err := reader.Close(); err != nil {
					t.Error(err)
				}
			}()
			content, err := io.ReadAll(reader)
			if err != nil {
				t.Fatal(err)
			}
			contentHash := sha256.Sum256(content)
			if contentHash != expectedPayloadHash {
				t.Errorf("Payload read through Image does not match expected content, expected %x, got %x", expectedPayloadHash, contentHash)
			}
		})
	}
}
103
// TestVerification serves the uncompressed test image from a local registry.
// For one chosen URL path at a time, the response body is corrupted (see
// corruptingServer). The test checks that the manifest (when its digest is
// known), the config and the payload are each rejected with a verification
// error, and that such errors do not cause retries.
func TestVerification(t *testing.T) {
	server := registry.NewServer()
	srcImage, err := oci.ReadLayout(xImageUncompressedPath)
	if err != nil {
		t.Fatal(err)
	}
	server.AddImage("test/repo", "test-tag", srcImage)
	corrupter := &corruptingServer{handler: server}

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer listener.Close()
	// Serve through an http.Server so that Close also shuts down accepted
	// keep-alive connections. Closing only the listener would stop new
	// connections but leave existing ones, and the goroutines serving them,
	// running after the test returns. Serve's error is ignored; it is
	// expected once the server has been closed.
	httpServer := &http.Server{Handler: corrupter}
	go httpServer.Serve(listener)
	defer httpServer.Close()

	client := &registry.Client{
		GetBackOff: func() backoff.BackOff {
			return backoff.NewExponentialBackOff()
		},
		RetryNotify: func(err error, _ time.Duration) {
			t.Errorf("Unexpected retry; verification errors should not trigger retries: %v", err)
		},
		Scheme:     "http",
		Host:       listener.Addr().String(),
		Repository: "test/repo",
	}

	// Test manifest verification. Without an expected digest the corrupted
	// (but still parseable) manifest is accepted; with the digest it must be
	// rejected.
	corrupter.affectedPath = "/v2/test/repo/manifests/test-tag"
	_, err = client.Read(context.Background(), "test-tag", "")
	if err != nil {
		t.Errorf("Expected reading manifest to succeed when digest not given: %v", err)
	}
	_, err = client.Read(context.Background(), "test-tag", srcImage.ManifestDigest)
	if !strings.Contains(fmt.Sprintf("%v", err), "failed verification") {
		t.Errorf("Expected failed verification, got %v", err)
	}

	// Test config verification
	corrupter.affectedPath = fmt.Sprintf("/v2/test/repo/blobs/%s", srcImage.Manifest.Config.Digest)
	image, err := client.Read(context.Background(), "test-tag", srcImage.ManifestDigest)
	if err != nil {
		t.Fatal(err)
	}
	_, err = Read(image)
	if !strings.Contains(fmt.Sprintf("%v", err), "failed verification") {
		t.Errorf("Expected failed verification, got %v", err)
	}

	// Test payload verification: the corrupted chunk must be rejected before
	// any of its bytes are handed to the reader.
	corrupter.affectedPath = fmt.Sprintf("/v2/test/repo/blobs/%s", srcImage.Manifest.Layers[0].Digest)
	image, err = client.Read(context.Background(), "test-tag", srcImage.ManifestDigest)
	if err != nil {
		t.Fatal(err)
	}
	osImage, err := Read(image)
	if err != nil {
		t.Fatal(err)
	}
	payload, err := osImage.Payload("test")
	if err != nil {
		t.Fatal(err)
	}
	reader, err := payload.Open()
	if err != nil {
		t.Fatal(err)
	}
	defer reader.Close()
	content, err := io.ReadAll(reader)
	if !strings.Contains(fmt.Sprintf("%v", err), "payload failed verification") {
		t.Errorf("Expected failed verification, got %v", err)
	}
	if len(content) != 0 {
		t.Errorf("Did not expect to read any content, got %d bytes", len(content))
	}
}
181
182type corruptingServer struct {
183 affectedPath string
184 handler http.Handler
185}
186
187func (s *corruptingServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
188 if req.URL.Path == s.affectedPath {
189 w = &corruptingResponseWriter{ResponseWriter: w}
190 }
191 s.handler.ServeHTTP(w, req)
192}
193
194// corruptingResponseWriter replaces the first newline in the response with a
195// space. This means that JSON parsing will still succeed, but digest
196// verification should fail.
197type corruptingResponseWriter struct {
198 http.ResponseWriter
199 corrupted bool
200}
201
202func (w *corruptingResponseWriter) Write(b []byte) (n int, err error) {
203 index := bytes.IndexByte(b, '\n')
204 if w.corrupted || index == -1 {
205 return w.ResponseWriter.Write(b)
206 }
207 b = bytes.Clone(b)
208 b[index] = ' '
209 n, err = w.ResponseWriter.Write(b)
210 if n > index {
211 w.corrupted = true
212 }
213 return
214}
215
// TestTruncation serves the image's blobs through truncatedBlobs. It checks
// that reading the "test" payload returns an error when its layer is cut off
// after a single hash chunk.
func TestTruncation(t *testing.T) {
	src, err := oci.ReadLayout(xImageUncompressedPath)
	if err != nil {
		t.Fatal(err)
	}
	// The initial limit equals the config blob's size so that Read can load
	// the config; the limit is lowered before the payload is opened.
	truncated := &truncatedBlobs{image: src, length: src.Manifest.Config.Size}
	img, err := oci.NewImage(src.RawManifest, "", truncated)
	if err != nil {
		t.Fatal(err)
	}

	osImage, err := Read(img)
	if err != nil {
		t.Fatal(err)
	}
	truncated.length = osImage.Config.Payloads[0].HashChunkSize
	payload, err := osImage.Payload("test")
	if err != nil {
		t.Fatal(err)
	}
	reader, err := payload.Open()
	if err != nil {
		t.Fatal(err)
	}
	defer reader.Close()
	if _, err := io.ReadAll(reader); err == nil {
		t.Error("Expected to get an error, got nil")
	}
}
249
250type truncatedBlobs struct {
251 image *oci.Image
252 length int64
253}
254
255func (b *truncatedBlobs) Blob(d *ocispecv1.Descriptor) (io.ReadCloser, error) {
256 reader, err := b.image.Blob(d)
257 if err != nil {
258 return nil, err
259 }
260 reader = &readCloser{
261 Reader: io.LimitReader(reader, b.length),
262 Closer: reader,
263 }
264 return reader, nil
265}
266
267type readCloser struct {
268 io.Reader
269 io.Closer
270}