osbase/blockdev: add tests, fix minor issues
Add a lot of bounds checks which should make BlockDev safer to use. Fix
a bug in the ReadWriteSeeker.Seek function with io.SeekEnd; the offset
should be added to, not subtracted from the size. Add the Sync()
function to the BlockDev interface.
Change-Id: I247095b3dbc6410064844b4ac7c6208d88a7abcd
Reviewed-on: https://review.monogon.dev/c/monogon/+/3338
Tested-by: Jenkins CI
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
diff --git a/osbase/blockdev/blockdev_test.go b/osbase/blockdev/blockdev_test.go
new file mode 100644
index 0000000..df4e8ae
--- /dev/null
+++ b/osbase/blockdev/blockdev_test.go
@@ -0,0 +1,319 @@
+package blockdev
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "slices"
+ "testing"
+)
+
+// errIsNil reports whether err is nil.
+func errIsNil(err error) bool { return err == nil }
+
+// errIsEOF reports whether err is exactly io.EOF.
+func errIsEOF(err error) bool { return err == io.EOF }
+
+// errIsReadFailed reports whether err is an actual failure, that is, neither
+// nil nor io.EOF.
+func errIsReadFailed(err error) bool { return !(err == nil || err == io.EOF) }
+
+// ValidateBlockDev tests the methods of the BlockDev interface on blk, so that
+// all implementations can share one test suite and are held to the same
+// behavior. blk must report the given block count and (optimal) block size.
+func ValidateBlockDev(t *testing.T, blk BlockDev, blockCount, blockSize, optimalBlockSize int64) {
+	if blk.BlockCount() != blockCount {
+		t.Errorf("Expected block count %d, got %d", blockCount, blk.BlockCount())
+	}
+	if blk.BlockSize() != blockSize {
+		t.Errorf("Expected block size %d, got %d", blockSize, blk.BlockSize())
+	}
+	if blk.OptimalBlockSize() != optimalBlockSize {
+		t.Errorf("Expected optimal block size %d, got %d", optimalBlockSize, blk.OptimalBlockSize())
+	}
+	size := blockCount * blockSize
+
+	// ReadAt: all cases share one pattern fill, since reads must not modify it.
+	checkBlockDevOp(t, blk, func(content []byte) {
+		normalErr := errIsNil
+		if size < 3+8 {
+			normalErr = errIsEOF
+		}
+		readAtTests := []struct {
+			name        string
+			offset, len int64
+			expectedErr func(error) bool
+		}{
+			{"empty start", 0, 0, errIsNil},
+			{"empty end", size, 0, errIsNil},
+			{"normal", 3, 8, normalErr},
+			{"ends past the end", 1, size, errIsEOF},
+			{"offset negative", -1, 2, errIsReadFailed},
+			{"starts at the end", size, 8, errIsEOF},
+			{"starts past the end", size + 4, 8, errIsEOF},
+		}
+		for _, tt := range readAtTests {
+			t.Run("read "+tt.name, func(t *testing.T) {
+				buf := make([]byte, tt.len)
+				n, err := blk.ReadAt(buf, tt.offset)
+				if !tt.expectedErr(err) {
+					t.Errorf("unexpected error %v", err)
+				}
+				// Only offsets within [0, size] yield data, truncated at the device end.
+				expected := []byte{}
+				if tt.offset >= 0 && tt.offset <= size {
+					expected = content[tt.offset:min(tt.offset+tt.len, size)]
+				}
+				if n != len(expected) {
+					t.Errorf("got n = %d, expected %d", n, len(expected))
+				}
+				if !slices.Equal(buf[:n], expected) {
+					t.Errorf("read unexpected data")
+				}
+			})
+		}
+	})
+
+	// WriteAt: every case starts from a freshly written pattern.
+	writeAtTests := []struct {
+		name   string
+		offset int64
+		data   string
+		ok     bool
+	}{
+		{"empty start", 0, "", true},
+		{"empty end", size, "", true},
+		{"normal", 3, "abcdef", size >= 9},
+		{"ends at the end", size - 4, "abcd", size >= 4},
+		{"ends past the end", size - 4, "abcde", false},
+		{"offset negative", -1, "abc", false},
+		{"starts at the end", size, "abc", false},
+		{"starts past the end", size + 4, "abc", false},
+	}
+	for _, tt := range writeAtTests {
+		t.Run("write "+tt.name, func(t *testing.T) {
+			checkBlockDevOp(t, blk, func(content []byte) {
+				n, err := blk.WriteAt([]byte(tt.data), tt.offset)
+				if (err == nil) != tt.ok {
+					t.Errorf("expected success = %v, got error %v", tt.ok, err)
+				}
+				// Writes starting inside the device should store as much of data as fits.
+				expectedN := 0
+				if tt.offset >= 0 && tt.offset < size {
+					expectedN = copy(content[tt.offset:], tt.data)
+				}
+				if n != expectedN {
+					t.Errorf("got n = %d, expected %d; err: %v", n, expectedN, err)
+				}
+			})
+		})
+	}
+
+	// Zero: a successful call must zero exactly [start, end); a failed one nothing.
+	zeroTests := []struct {
+		name       string
+		start, end int64
+		ok         bool
+	}{
+		{"empty range start", 0, 0, true},
+		{"empty range end", size, size, true},
+		{"full", 0, size, true},
+		{"partial", blockSize, 3 * blockSize, blockCount >= 3},
+		{"negative start", -blockSize, blockSize, false},
+		{"start after end", 2 * blockSize, blockSize, false},
+		{"unaligned start", 1, blockSize, false},
+		{"unaligned end", 0, 1, false},
+	}
+	for _, tt := range zeroTests {
+		t.Run("zero "+tt.name, func(t *testing.T) {
+			checkBlockDevOp(t, blk, func(content []byte) {
+				err := blk.Zero(tt.start, tt.end)
+				if (err == nil) != tt.ok {
+					t.Errorf("expected success = %v, got error %v", tt.ok, err)
+				}
+				if tt.ok {
+					clear(content[tt.start:tt.end])
+				}
+			})
+		})
+	}
+
+	// Discard: same ranges as Zero; content outside the range must be kept.
+	for _, tt := range zeroTests {
+		t.Run("discard "+tt.name, func(t *testing.T) {
+			checkBlockDevOp(t, blk, func(content []byte) {
+				err := blk.Discard(tt.start, tt.end)
+				if (err == nil) != tt.ok {
+					t.Errorf("expected success = %v, got error %v", tt.ok, err)
+				}
+				if tt.ok {
+					// Discarded data is not checked: adopt whatever the device now returns.
+					n, err := blk.ReadAt(content[tt.start:tt.end], tt.start)
+					if n != int(tt.end-tt.start) || errIsReadFailed(err) {
+						t.Errorf("read returned %d, %v", n, err)
+					}
+				}
+			})
+		})
+	}
+
+	// Sync: must succeed and leave the content unchanged.
+	checkBlockDevOp(t, blk, func(content []byte) {
+		if err := blk.Sync(); err != nil {
+			t.Errorf("Sync failed: %v", err)
+		}
+	})
+}
+
+// checkBlockDevOp runs f between a known starting state and a verification
+// read. The whole blockdev is first filled with a repeating '1'..'9' pattern;
+// f receives a slice holding that pattern and must mutate it into the content
+// it expects the blockdev to hold afterwards, which is then compared in full.
+func checkBlockDevOp(t *testing.T, blk BlockDev, f func(content []byte)) {
+	t.Helper()
+
+	want := make([]byte, blk.BlockCount()*blk.BlockSize())
+	for i := range want {
+		want[i] = '1' + byte(i%9)
+	}
+	if n, err := blk.WriteAt(want, 0); n != len(want) || err != nil {
+		t.Fatalf("WriteAt = %d, %v; expected %d, nil", n, err, len(want))
+	}
+	f(want)
+	got := make([]byte, len(want))
+	n, err := blk.ReadAt(got, 0)
+	// A complete read that additionally reports EOF is accepted.
+	if readOK := err == nil || err == io.EOF; n != len(got) || !readOK {
+		t.Fatalf("ReadAt = %d, %v; expected %d, (nil or EOF)", n, err, len(got))
+	}
+	if !slices.Equal(got, want) {
+		t.Errorf("Unexpected content after operation")
+	}
+}
+
+func TestReadWriteSeeker_Seek(t *testing.T) {
+	// bytes.Reader serves as the reference: the Seeker returned by NewRWS must
+	// report the same offsets and fail in the same cases.
+	ref := bytes.NewReader([]byte("foobar"))
+	rws := NewRWS(MustNewMemory(2, 3))
+	if n, err := rws.Write([]byte("foobar")); n != 6 || err != nil {
+		t.Errorf("Write = %v, %v; want 6, nil", n, err)
+	}
+
+	for _, whence := range []int{io.SeekStart, io.SeekCurrent, io.SeekEnd} {
+		for offset := int64(-7); offset <= 7; offset++ {
+			refOff, refErr := ref.Seek(offset, whence)
+			rwsOff, rwsErr := rws.Seek(offset, whence)
+			sameOutcome := (refErr == nil) == (rwsErr == nil)
+			if !sameOutcome || refOff != rwsOff {
+				t.Errorf("For whence %d, offset %d: bytes.Reader.Seek = (%v, %v) != ReadWriteSeeker.Seek = (%v, %v)",
+					whence, offset, refOff, refErr, rwsOff, rwsErr)
+			}
+		}
+	}
+
+	// Seeking beyond the end must succeed; only the subsequent read hits EOF.
+	pos, err := rws.Seek(100, io.SeekStart)
+	if err != nil || pos != 100 {
+		t.Errorf("Seek = %v, %v; want 100, nil", pos, err)
+	}
+
+	n, err := rws.Read(make([]byte, 10))
+	if n != 0 || err != io.EOF {
+		t.Errorf("Read = %v, %v; want 0, EOF", n, err)
+	}
+}
+
+// TestNewSection checks which block ranges NewSection accepts, and that using
+// a section leaves the underlying device untouched outside of that range.
+func TestNewSection(t *testing.T) {
+	tests := []struct {
+		name         string
+		blockSize    int64
+		count        int64
+		startBlock   int64
+		endBlock     int64
+		ok           bool
+		sectionCount int64
+	}{
+		{"empty underlying", 8, 0, 0, 0, true, 0},
+		{"empty section", 8, 5, 2, 2, true, 0},
+		{"partial section", 8, 15, 1, 11, true, 10},
+		{"full section", 8, 15, 0, 15, true, 15},
+		{"negative start", 8, 15, -1, 11, false, 0},
+		{"start after end", 8, 15, 6, 5, false, 0},
+		{"end out of bounds", 8, 15, 6, 16, false, 0},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			m := MustNewMemory(tt.blockSize, tt.count)
+			s, err := NewSection(m, tt.startBlock, tt.endBlock)
+			if (err == nil) != tt.ok {
+				t.Errorf("NewSection: expected %v, got %v", tt.ok, err)
+			}
+			if err == nil {
+				checkBlockDevOp(t, m, func(content []byte) {
+					ValidateBlockDev(t, s, tt.sectionCount, tt.blockSize, tt.blockSize)
+					// Accept whatever the section now holds, so that the comparison in
+					// checkBlockDevOp only fails if content outside the section changed.
+					start, end := tt.startBlock*tt.blockSize, tt.endBlock*tt.blockSize
+					if n, err := m.ReadAt(content[start:end], start); n != int(end-start) {
+						t.Errorf("read returned %d, %v", n, err)
+					}
+				})
+			}
+		})
+	}
+}
+
+// MemoryWithGenericZero is a Memory whose Zero method goes through GenericZero.
+type MemoryWithGenericZero struct{ *Memory }
+
+// Zero shadows the Zero promoted from the embedded Memory and calls GenericZero.
+func (m *MemoryWithGenericZero) Zero(startByte, endByte int64) error {
+	return GenericZero(m, startByte, endByte)
+}
+
+// TestGenericZero runs the shared BlockDev validation against a Memory whose
+// Zero is routed through GenericZero. It is skipped when IN_KTEST is "true".
+func TestGenericZero(t *testing.T) {
+	if os.Getenv("IN_KTEST") == "true" {
+		t.Skip("In ktest")
+	}
+	// Use size larger than the 16 MiB buffer size in GenericZero (512 B * 35 Ki
+	// blocks = 17.5 MiB), presumably to exercise zeroing across buffer boundaries.
+	blockSize := int64(512)
+	blockCount := int64(35 * 1024)
+	m, err := NewMemory(blockSize, blockCount)
+	if err != nil {
+		t.Fatalf("NewMemory: %v", err)
+	}
+	ValidateBlockDev(t, &MemoryWithGenericZero{m}, blockCount, blockSize, blockSize)
+}
+
+// TestNewMemory checks which geometries NewMemory accepts and validates every
+// Memory it successfully creates.
+func TestNewMemory(t *testing.T) {
+	for _, tt := range []struct {
+		name             string
+		blockSize, count int64
+		ok               bool
+	}{
+		{"normal", 64, 9, true},
+		{"count 0", 8, 0, true},
+		{"count negative", 8, -1, false},
+		{"blockSize not a power of 2", 9, 5, false},
+		{"blockSize 0", 0, 5, false},
+		{"blockSize negative", -1, 5, false},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			m, err := NewMemory(tt.blockSize, tt.count)
+			if (err == nil) != tt.ok {
+				t.Errorf("NewMemory: expected %v, got %v", tt.ok, err)
+			}
+			if err == nil {
+				ValidateBlockDev(t, m, tt.count, tt.blockSize, tt.blockSize)
+			}
+		})
+	}
+}