blob: 9fc301a79b615e02ebe2c794ba37501e15716a1e [file] [log] [blame]
// Copyright The Monogon Project Authors.
// SPDX-License-Identifier: Apache-2.0
3
Jan Schära6da1712024-08-21 15:12:11 +02004package blockdev
5
6import (
7 "bytes"
8 "io"
9 "os"
10 "slices"
11 "testing"
12)
13
// errIsNil is an error predicate that accepts only a nil error, i.e. the
// operation under test must have fully succeeded.
func errIsNil(err error) bool {
	if err != nil {
		return false
	}
	return true
}
// errIsEOF is an error predicate that accepts exactly io.EOF. It compares
// with == rather than errors.Is on purpose: the io package requires readers
// to return io.EOF itself, never an error wrapping it.
func errIsEOF(err error) bool {
	return io.EOF == err
}
// errIsReadFailed is an error predicate that accepts any genuine failure:
// a non-nil error that is not the io.EOF end-of-input marker.
func errIsReadFailed(err error) bool {
	switch err {
	case nil, io.EOF:
		return false
	}
	return true
}
23
24// ValidateBlockDev tests all methods of the BlockDev interface. This way, all
25// implementations can be tested without duplicating test code. This also
26// ensures that all implementations behave consistently.
27func ValidateBlockDev(t *testing.T, blk BlockDev, blockCount, blockSize, optimalBlockSize int64) {
28 if blk.BlockCount() != blockCount {
29 t.Errorf("Expected block count %d, got %d", blockCount, blk.BlockCount())
30 }
31 if blk.BlockSize() != blockSize {
32 t.Errorf("Expected block size %d, got %d", blockSize, blk.BlockSize())
33 }
34 if blk.OptimalBlockSize() != optimalBlockSize {
35 t.Errorf("Expected optimal block size %d, got %d", optimalBlockSize, blk.OptimalBlockSize())
36 }
37 size := blockCount * blockSize
38
39 // ReadAt
40 checkBlockDevOp(t, blk, func(content []byte) {
41 normalErr := errIsNil
42 if size < 3+8 {
43 normalErr = errIsEOF
44 }
45 readAtTests := []struct {
46 name string
47 offset, len int64
48 expectedErr func(error) bool
49 }{
50 {"empty start", 0, 0, errIsNil},
51 {"empty end", size, 0, errIsNil},
52 {"normal", 3, 8, normalErr},
53 {"ends past the end", 1, size, errIsEOF},
54 {"offset negative", -1, 2, errIsReadFailed},
55 {"starts at the end", size, 8, errIsEOF},
56 {"starts past the end", size + 4, 8, errIsEOF},
57 }
58 for _, tt := range readAtTests {
59 t.Run("read "+tt.name, func(t *testing.T) {
60 buf := make([]byte, tt.len)
61 n, err := blk.ReadAt(buf, tt.offset)
62 if !tt.expectedErr(err) {
63 t.Errorf("unexpected error %v", err)
64 }
65 expected := []byte{}
66 if tt.offset >= 0 && tt.offset <= size {
67 expected = content[tt.offset:min(tt.offset+tt.len, size)]
68 }
69 if n != len(expected) {
70 t.Errorf("got n = %d, expected %d", n, len(expected))
71 }
72 if !slices.Equal(buf[:n], expected) {
73 t.Errorf("read unexpected data")
74 }
75 })
76 }
77 })
78
79 // WriteAt
80 writeAtTests := []struct {
81 name string
82 offset int64
83 data string
84 ok bool
85 }{
86 {"empty start", 0, "", true},
87 {"empty end", size, "", true},
88 {"normal", 3, "abcdef", size >= 9},
89 {"ends at the end", size - 4, "abcd", size >= 4},
90 {"ends past the end", size - 4, "abcde", false},
91 {"offset negative", -1, "abc", false},
92 {"starts at the end", size, "abc", false},
93 {"starts past the end", size + 4, "abc", false},
94 }
95 for _, tt := range writeAtTests {
96 t.Run("write "+tt.name, func(t *testing.T) {
97 checkBlockDevOp(t, blk, func(content []byte) {
98 n, err := blk.WriteAt([]byte(tt.data), tt.offset)
99 if (err == nil) != tt.ok {
100 t.Errorf("expected error %v, got %v", tt.ok, err)
101 }
102 expectedN := 0
103 if tt.offset >= 0 && tt.offset < size {
104 expectedN = copy(content[tt.offset:], tt.data)
105 }
106 if n != expectedN {
107 t.Errorf("got n = %d, expected %d; err: %v", n, expectedN, err)
108 }
109 })
110 })
111 }
112
113 // Zero
114 zeroTests := []struct {
115 name string
116 start, end int64
117 ok bool
118 }{
119 {"empty range start", 0, 0, true},
120 {"empty range end", size, size, true},
121 {"full", 0, size, true},
122 {"partial", blockSize, 3 * blockSize, blockCount >= 3},
123 {"negative start", -blockSize, blockSize, false},
124 {"start after end", 2 * blockSize, blockSize, false},
125 {"unaligned start", 1, blockSize, false},
126 {"unaligned end", 0, 1, false},
127 }
128 for _, tt := range zeroTests {
129 t.Run("zero "+tt.name, func(t *testing.T) {
130 checkBlockDevOp(t, blk, func(content []byte) {
131 err := blk.Zero(tt.start, tt.end)
132 if (err == nil) != tt.ok {
133 t.Errorf("expected error %v, got %v", tt.ok, err)
134 }
135 if tt.ok {
136 for i := tt.start; i < tt.end; i++ {
137 content[i] = 0
138 }
139 }
140 })
141 })
142 }
143
144 // Discard
145 for _, tt := range zeroTests {
146 t.Run("discard "+tt.name, func(t *testing.T) {
147 checkBlockDevOp(t, blk, func(content []byte) {
148 err := blk.Discard(tt.start, tt.end)
149 if (err == nil) != tt.ok {
150 t.Errorf("expected error %v, got %v", tt.ok, err)
151 }
152 if tt.ok {
153 n, err := blk.ReadAt(content[tt.start:tt.end], tt.start)
154 if n != int(tt.end-tt.start) {
155 t.Errorf("read returned %d, %v", n, err)
156 }
157 }
158 })
159 })
160 }
161
162 // Sync
163 checkBlockDevOp(t, blk, func(content []byte) {
164 err := blk.Sync()
165 if err != nil {
166 t.Errorf("Sync failed: %v", err)
167 }
168 })
169}
170
171// checkBlockDevOp is a helper for testing operations on a blockdev. It fills
172// the blockdev with a pattern, then calls f with a slice containing the
173// pattern, and afterwards reads the blockdev to compare against the expected
174// content. f should modify the slice to the content expected afterwards.
175func checkBlockDevOp(t *testing.T, blk BlockDev, f func(content []byte)) {
176 t.Helper()
177
178 testContent := make([]byte, blk.BlockCount()*blk.BlockSize())
179 for i := range testContent {
180 testContent[i] = '1' + byte(i%9)
181 }
182 n, err := blk.WriteAt(testContent, 0)
183 if n != len(testContent) || err != nil {
184 t.Fatalf("WriteAt = %d, %v; expected %d, nil", n, err, len(testContent))
185 }
186 f(testContent)
187 afterContent := make([]byte, len(testContent))
188 n, err = blk.ReadAt(afterContent, 0)
189 if n != len(afterContent) || (err != nil && err != io.EOF) {
190 t.Fatalf("ReadAt = %d, %v; expected %d, (nil or EOF)", n, err, len(afterContent))
191 }
192 if !slices.Equal(afterContent, testContent) {
193 t.Errorf("Unexpected content after operation")
194 }
195}
196
197func TestReadWriteSeeker_Seek(t *testing.T) {
198 // Verifies that NewRWS's Seeker behaves like bytes.NewReader
199 br := bytes.NewReader([]byte("foobar"))
200 m := MustNewMemory(2, 3)
201 rws := NewRWS(m)
202 n, err := rws.Write([]byte("foobar"))
203 if n != 6 || err != nil {
204 t.Errorf("Write = %v, %v; want 6, nil", n, err)
205 }
206
207 for _, whence := range []int{io.SeekStart, io.SeekCurrent, io.SeekEnd} {
208 for offset := int64(-7); offset <= 7; offset++ {
209 brOff, brErr := br.Seek(offset, whence)
210 rwsOff, rwsErr := rws.Seek(offset, whence)
211 if (brErr != nil) != (rwsErr != nil) || brOff != rwsOff {
212 t.Errorf("For whence %d, offset %d: bytes.Reader.Seek = (%v, %v) != ReadWriteSeeker.Seek = (%v, %v)",
213 whence, offset, brOff, brErr, rwsOff, rwsErr)
214 }
215 }
216 }
217
218 // And verify we can just seek past the end and get an EOF
219 got, err := rws.Seek(100, io.SeekStart)
220 if err != nil || got != 100 {
221 t.Errorf("Seek = %v, %v; want 100, nil", got, err)
222 }
223
224 n, err = rws.Read(make([]byte, 10))
225 if n != 0 || err != io.EOF {
226 t.Errorf("Read = %v, %v; want 0, EOF", n, err)
227 }
228}
229
230func TestNewSection(t *testing.T) {
231 tests := []struct {
232 name string
233 blockSize int64
234 count int64
235 startBlock int64
236 endBlock int64
237 ok bool
238 sectionCount int64
239 }{
240 {name: "empty underlying", blockSize: 8, count: 0, startBlock: 0, endBlock: 0, ok: true, sectionCount: 0},
241 {name: "empty section", blockSize: 8, count: 5, startBlock: 2, endBlock: 2, ok: true, sectionCount: 0},
242 {name: "partial section", blockSize: 8, count: 15, startBlock: 1, endBlock: 11, ok: true, sectionCount: 10},
243 {name: "full section", blockSize: 8, count: 15, startBlock: 0, endBlock: 15, ok: true, sectionCount: 15},
244 {name: "negative start", blockSize: 8, count: 15, startBlock: -1, endBlock: 11, ok: false},
245 {name: "start after end", blockSize: 8, count: 15, startBlock: 6, endBlock: 5, ok: false},
246 {name: "end out of bounds", blockSize: 8, count: 15, startBlock: 6, endBlock: 16, ok: false},
247 }
248 for _, tt := range tests {
249 t.Run(tt.name, func(t *testing.T) {
250 m := MustNewMemory(tt.blockSize, tt.count)
251 s, err := NewSection(m, tt.startBlock, tt.endBlock)
252 if (err == nil) != tt.ok {
253 t.Errorf("NewSection: expected %v, got %v", tt.ok, err)
254 }
255 if err == nil {
256 checkBlockDevOp(t, m, func(content []byte) {
257 ValidateBlockDev(t, s, tt.sectionCount, tt.blockSize, tt.blockSize)
258
259 // Check that content outside the section has not changed.
260 start := tt.startBlock * tt.blockSize
261 end := tt.endBlock * tt.blockSize
262 n, err := m.ReadAt(content[start:end], start)
263 if n != int(end-start) {
264 t.Errorf("read returned %d, %v", n, err)
265 }
266 })
267 }
268 })
269 }
270}
271
272type MemoryWithGenericZero struct {
273 *Memory
274}
275
276func (m *MemoryWithGenericZero) Zero(startByte, endByte int64) error {
277 return GenericZero(m, startByte, endByte)
278}
279
280func TestGenericZero(t *testing.T) {
281 if os.Getenv("IN_KTEST") == "true" {
282 t.Skip("In ktest")
283 }
284 // Use size larger than the 16 MiB buffer size in GenericZero.
285 blockSize := int64(512)
286 blockCount := int64(35 * 1024)
287 m, err := NewMemory(blockSize, blockCount)
288 if err != nil {
289 t.Errorf("NewMemory: %v", err)
290 }
291 b := &MemoryWithGenericZero{m}
292 if err == nil {
293 ValidateBlockDev(t, b, blockCount, blockSize, blockSize)
294 }
295}
296
297func TestNewMemory(t *testing.T) {
298 tests := []struct {
299 name string
300 blockSize int64
301 count int64
302 ok bool
303 }{
304 {name: "normal", blockSize: 64, count: 9, ok: true},
305 {name: "count 0", blockSize: 8, count: 0, ok: true},
306 {name: "count negative", blockSize: 8, count: -1, ok: false},
307 {name: "blockSize not a power of 2", blockSize: 9, count: 5, ok: false},
308 {name: "blockSize 0", blockSize: 0, count: 5, ok: false},
309 {name: "blockSize negative", blockSize: -1, count: 5, ok: false},
310 }
311 for _, tt := range tests {
312 t.Run(tt.name, func(t *testing.T) {
313 m, err := NewMemory(tt.blockSize, tt.count)
314 if (err == nil) != tt.ok {
315 t.Errorf("NewMemory: expected %v, got %v", tt.ok, err)
316 }
317 if err == nil {
318 ValidateBlockDev(t, m, tt.count, tt.blockSize, tt.blockSize)
319 }
320 })
321 }
322}