osbase/blockdev: add tests, fix minor issues
Add a lot of bounds checks which should make BlockDev safer to use. Fix
a bug in the ReadWriteSeeker.Seek function with io.SeekEnd; the offset
should be added to, not subtracted from the size. Add the Sync()
function to the BlockDev interface.
Change-Id: I247095b3dbc6410064844b4ac7c6208d88a7abcd
Reviewed-on: https://review.monogon.dev/c/monogon/+/3338
Tested-by: Jenkins CI
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
diff --git a/osbase/blockdev/BUILD.bazel b/osbase/blockdev/BUILD.bazel
index 0805ef6..f476e8b 100644
--- a/osbase/blockdev/BUILD.bazel
+++ b/osbase/blockdev/BUILD.bazel
@@ -1,4 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//osbase/test/ktest:ktest.bzl", "ktest")
go_library(
name = "blockdev",
@@ -26,3 +27,25 @@
"//conditions:default": [],
}),
)
+
+go_test(
+ name = "blockdev_test",
+ srcs = [
+ "blockdev_linux_test.go",
+ "blockdev_test.go",
+ ],
+ embed = [":blockdev"],
+ deps = select({
+ "@io_bazel_rules_go//go/platform:android": [
+ "//osbase/loop",
+ ],
+ "@io_bazel_rules_go//go/platform:linux": [
+ "//osbase/loop",
+ ],
+ "//conditions:default": [],
+ }),
+)
+
+ktest(
+ tester = ":blockdev_test",
+)
diff --git a/osbase/blockdev/blockdev.go b/osbase/blockdev/blockdev.go
index 0e3c6e1..9cf4ee2 100644
--- a/osbase/blockdev/blockdev.go
+++ b/osbase/blockdev/blockdev.go
@@ -16,14 +16,15 @@
type BlockDev interface {
io.ReaderAt
io.WriterAt
- // BlockSize returns the block size of the block device in bytes. This must
- // be a power of two and is commonly (but not always) either 512 or 4096.
- BlockSize() int64
// BlockCount returns the number of blocks on the block device or -1 if it
// is an image with an undefined size.
BlockCount() int64
+ // BlockSize returns the block size of the block device in bytes. This must
+ // be a power of two and is commonly (but not always) either 512 or 4096.
+ BlockSize() int64
+
// OptimalBlockSize returns the optimal block size in bytes for aligning
// to as well as issuing I/O. IO operations with block sizes below this
// one might incur read-write overhead. This is the larger of the physical
@@ -41,6 +42,9 @@
	// Zero zeroes a continuous set of blocks. On certain implementations this
// can be significantly faster than just calling Write with zeroes.
Zero(startByte, endByte int64) error
+
+ // Sync commits the current contents to stable storage.
+ Sync() error
}
func NewRWS(b BlockDev) *ReadWriteSeeker {
@@ -68,13 +72,18 @@
func (s *ReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
switch whence {
- case io.SeekCurrent:
- s.currPos += offset
+ default:
+ return 0, errors.New("Seek: invalid whence")
case io.SeekStart:
- s.currPos = offset
+ case io.SeekCurrent:
+ offset += s.currPos
case io.SeekEnd:
- s.currPos = (s.b.BlockCount() * s.b.BlockSize()) - offset
+ offset += s.b.BlockCount() * s.b.BlockSize()
}
+ if offset < 0 {
+ return 0, errors.New("Seek: invalid offset")
+ }
+ s.currPos = offset
return s.currPos, nil
}
@@ -82,12 +91,21 @@
// NewSection returns a new Section, implementing BlockDev over that subset
// of blocks. The interval is inclusive-exclusive.
-func NewSection(b BlockDev, startBlock, endBlock int64) *Section {
+func NewSection(b BlockDev, startBlock, endBlock int64) (*Section, error) {
+ if startBlock < 0 {
+ return nil, fmt.Errorf("invalid range: startBlock (%d) negative", startBlock)
+ }
+ if startBlock > endBlock {
+ return nil, fmt.Errorf("invalid range: startBlock (%d) bigger than endBlock (%d)", startBlock, endBlock)
+ }
+ if endBlock > b.BlockCount() {
+ return nil, fmt.Errorf("endBlock (%d) out of range (%d)", endBlock, b.BlockCount())
+ }
return &Section{
b: b,
startBlock: startBlock,
endBlock: endBlock,
- }
+ }, nil
}
// Section implements BlockDev on a slice of another BlockDev given a startBlock
@@ -98,13 +116,20 @@
}
func (s *Section) ReadAt(p []byte, off int64) (n int, err error) {
+ if off < 0 {
+ return 0, errors.New("blockdev.Section.ReadAt: negative offset")
+ }
bOff := off + (s.startBlock * s.b.BlockSize())
bytesToEnd := (s.endBlock * s.b.BlockSize()) - bOff
- if bytesToEnd <= 0 {
+ if bytesToEnd < 0 {
return 0, io.EOF
}
if bytesToEnd < int64(len(p)) {
- return s.b.ReadAt(p[:bytesToEnd], bOff)
+ n, err := s.b.ReadAt(p[:bytesToEnd], bOff)
+ if err == nil {
+ err = io.EOF
+ }
+ return n, err
}
return s.b.ReadAt(p, bOff)
}
@@ -112,11 +137,11 @@
func (s *Section) WriteAt(p []byte, off int64) (n int, err error) {
bOff := off + (s.startBlock * s.b.BlockSize())
bytesToEnd := (s.endBlock * s.b.BlockSize()) - bOff
- if bytesToEnd <= 0 {
+ if off < 0 || bytesToEnd < 0 {
return 0, ErrOutOfBounds
}
if bytesToEnd < int64(len(p)) {
- n, err := s.b.WriteAt(p[:bytesToEnd], off+(s.startBlock*s.b.BlockSize()))
+ n, err := s.b.WriteAt(p[:bytesToEnd], bOff)
if err != nil {
// If an error happened, prioritize that error
return n, err
@@ -125,7 +150,7 @@
// error.
return n, ErrOutOfBounds
}
- return s.b.WriteAt(p, off+(s.startBlock*s.b.BlockSize()))
+ return s.b.WriteAt(p, bOff)
}
func (s *Section) BlockCount() int64 {
@@ -136,49 +161,56 @@
return s.b.BlockSize()
}
-func (s *Section) inRange(startByte, endByte int64) error {
- if startByte > endByte {
- return fmt.Errorf("invalid range: startByte (%d) bigger than endByte (%d)", startByte, endByte)
- }
- sectionLen := s.BlockCount() * s.BlockSize()
- if startByte >= sectionLen {
- return fmt.Errorf("startByte (%d) out of range (%d)", startByte, sectionLen)
- }
- if endByte > sectionLen {
- return fmt.Errorf("endBlock (%d) out of range (%d)", endByte, sectionLen)
- }
- return nil
+func (s *Section) OptimalBlockSize() int64 {
+ return s.b.OptimalBlockSize()
}
func (s *Section) Discard(startByte, endByte int64) error {
- if err := s.inRange(startByte, endByte); err != nil {
+ if err := validAlignedRange(s, startByte, endByte); err != nil {
return err
}
offset := s.startBlock * s.b.BlockSize()
return s.b.Discard(offset+startByte, offset+endByte)
}
-func (s *Section) OptimalBlockSize() int64 {
- return s.b.OptimalBlockSize()
-}
-
func (s *Section) Zero(startByte, endByte int64) error {
- if err := s.inRange(startByte, endByte); err != nil {
+ if err := validAlignedRange(s, startByte, endByte); err != nil {
return err
}
offset := s.startBlock * s.b.BlockSize()
return s.b.Zero(offset+startByte, offset+endByte)
}
-// GenericZero implements software-based zeroing. This can be used to implement
-// Zero when no acceleration is available or desired.
-func GenericZero(b BlockDev, startByte, endByte int64) error {
+func (s *Section) Sync() error {
+ return s.b.Sync()
+}
+
+func validAlignedRange(b BlockDev, startByte, endByte int64) error {
+ if startByte < 0 {
+ return fmt.Errorf("invalid range: startByte (%d) negative", startByte)
+ }
+ if startByte > endByte {
+ return fmt.Errorf("invalid range: startByte (%d) bigger than endByte (%d)", startByte, endByte)
+ }
+ devLen := b.BlockCount() * b.BlockSize()
+ if endByte > devLen {
+ return fmt.Errorf("endByte (%d) out of range (%d)", endByte, devLen)
+ }
if startByte%b.BlockSize() != 0 {
return fmt.Errorf("startByte (%d) needs to be aligned to block size (%d)", startByte, b.BlockSize())
}
if endByte%b.BlockSize() != 0 {
return fmt.Errorf("endByte (%d) needs to be aligned to block size (%d)", endByte, b.BlockSize())
}
+ return nil
+}
+
+// GenericZero implements software-based zeroing. This can be used to implement
+// Zero when no acceleration is available or desired.
+func GenericZero(b BlockDev, startByte, endByte int64) error {
+ if err := validAlignedRange(b, startByte, endByte); err != nil {
+ return err
+ }
// Choose buffer size close to 16MiB or the range to be zeroed, whatever
// is smaller.
bufSizeTarget := int64(16 * 1024 * 1024)
diff --git a/osbase/blockdev/blockdev_darwin.go b/osbase/blockdev/blockdev_darwin.go
index 5422e55..725c3a5 100644
--- a/osbase/blockdev/blockdev_darwin.go
+++ b/osbase/blockdev/blockdev_darwin.go
@@ -45,22 +45,26 @@
return d.blockSize
}
+func (d *Device) OptimalBlockSize() int64 {
+ return d.blockSize
+}
+
func (d *Device) Discard(startByte int64, endByte int64) error {
// Can be implemented using DKIOCUNMAP, but needs x/sys/unix extension.
// Not mandatory, so this is fine for now.
return errors.ErrUnsupported
}
-func (d *Device) OptimalBlockSize() int64 {
- return d.blockSize
-}
-
func (d *Device) Zero(startByte int64, endByte int64) error {
	// It doesn't look like macOS even has any zeroing acceleration, so just
// use the generic one.
return GenericZero(d, startByte, endByte)
}
+func (d *Device) Sync() error {
+ return d.backend.Sync()
+}
+
// Open opens a block device given a path to its inode.
func Open(path string) (*Device, error) {
outFile, err := os.OpenFile(path, os.O_RDWR, 0640)
@@ -156,16 +160,20 @@
return d.blockSize
}
+func (d *File) OptimalBlockSize() int64 {
+ return d.blockSize
+}
+
func (d *File) Discard(startByte int64, endByte int64) error {
	// Can be supported in the future via fcntl.
return errors.ErrUnsupported
}
-func (d *File) OptimalBlockSize() int64 {
- return d.blockSize
-}
-
func (d *File) Zero(startByte int64, endByte int64) error {
	// Can possibly be accelerated in the future via fcntl.
return GenericZero(d, startByte, endByte)
}
+
+func (d *File) Sync() error {
+ return d.backend.Sync()
+}
diff --git a/osbase/blockdev/blockdev_linux.go b/osbase/blockdev/blockdev_linux.go
index c5fa784..f6d5b4c 100644
--- a/osbase/blockdev/blockdev_linux.go
+++ b/osbase/blockdev/blockdev_linux.go
@@ -5,6 +5,7 @@
import (
"errors"
"fmt"
+ "io"
"math/bits"
"os"
"syscall"
@@ -21,10 +22,32 @@
}
func (d *Device) ReadAt(p []byte, off int64) (n int, err error) {
+ size := d.blockSize * d.blockCount
+ if off > size {
+ return 0, io.EOF
+ }
+ if int64(len(p)) > size-off {
+ n, err = d.backend.ReadAt(p[:size-off], off)
+ if err == nil {
+ err = io.EOF
+ }
+ return
+ }
return d.backend.ReadAt(p, off)
}
func (d *Device) WriteAt(p []byte, off int64) (n int, err error) {
+ size := d.blockSize * d.blockCount
+ if off > size {
+ return 0, ErrOutOfBounds
+ }
+ if int64(len(p)) > size-off {
+ n, err = d.backend.WriteAt(p[:size-off], off)
+ if err == nil {
+ err = ErrOutOfBounds
+ }
+ return
+ }
return d.backend.WriteAt(p, off)
}
@@ -40,7 +63,17 @@
return d.blockSize
}
+func (d *Device) OptimalBlockSize() int64 {
+ return d.blockSize
+}
+
func (d *Device) Discard(startByte int64, endByte int64) error {
+ if err := validAlignedRange(d, startByte, endByte); err != nil {
+ return err
+ }
+ if startByte == endByte {
+ return nil
+ }
var args [2]uint64
var err unix.Errno
args[0] = uint64(startByte)
@@ -59,11 +92,13 @@
return nil
}
-func (d *Device) OptimalBlockSize() int64 {
- return d.blockSize
-}
-
func (d *Device) Zero(startByte int64, endByte int64) error {
+ if err := validAlignedRange(d, startByte, endByte); err != nil {
+ return err
+ }
+ if startByte == endByte {
+ return nil
+ }
var args [2]uint64
var err error
args[0] = uint64(startByte)
@@ -92,6 +127,10 @@
return nil
}
+func (d *Device) Sync() error {
+ return d.backend.Sync()
+}
+
// RefreshPartitionTable refreshes the kernel's view of the partition table
// after changes made from userspace.
func (d *Device) RefreshPartitionTable() error {
@@ -165,7 +204,7 @@
func CreateFile(name string, blockSize int64, blockCount int64) (*File, error) {
if blockSize < 512 {
- return nil, fmt.Errorf("blockSize must be bigger than 512 bytes")
+ return nil, fmt.Errorf("blockSize must be at least 512 bytes")
}
if bits.OnesCount64(uint64(blockSize)) != 1 {
return nil, fmt.Errorf("blockSize must be a power of two")
@@ -187,10 +226,32 @@
}
func (d *File) ReadAt(p []byte, off int64) (n int, err error) {
+ size := d.blockSize * d.blockCount
+ if off > size {
+ return 0, io.EOF
+ }
+ if int64(len(p)) > size-off {
+ n, err = d.backend.ReadAt(p[:size-off], off)
+ if err == nil {
+ err = io.EOF
+ }
+ return
+ }
return d.backend.ReadAt(p, off)
}
func (d *File) WriteAt(p []byte, off int64) (n int, err error) {
+ size := d.blockSize * d.blockCount
+ if off > size {
+ return 0, ErrOutOfBounds
+ }
+ if int64(len(p)) > size-off {
+ n, err = d.backend.WriteAt(p[:size-off], off)
+ if err == nil {
+ err = ErrOutOfBounds
+ }
+ return
+ }
return d.backend.WriteAt(p, off)
}
@@ -206,7 +267,17 @@
return d.blockSize
}
+func (d *File) OptimalBlockSize() int64 {
+ return d.blockSize
+}
+
func (d *File) Discard(startByte int64, endByte int64) error {
+ if err := validAlignedRange(d, startByte, endByte); err != nil {
+ return err
+ }
+ if startByte == endByte {
+ return nil
+ }
var err error
if ctrlErr := d.rawConn.Control(func(fd uintptr) {
// There is FALLOC_FL_NO_HIDE_STALE, but it's not implemented by
@@ -224,11 +295,13 @@
return nil
}
-func (d *File) OptimalBlockSize() int64 {
- return d.blockSize
-}
-
func (d *File) Zero(startByte int64, endByte int64) error {
+ if err := validAlignedRange(d, startByte, endByte); err != nil {
+ return err
+ }
+ if startByte == endByte {
+ return nil
+ }
var err error
if ctrlErr := d.rawConn.Control(func(fd uintptr) {
// Tell the filesystem to punch out the given blocks.
@@ -246,3 +319,7 @@
}
return nil
}
+
+func (d *File) Sync() error {
+ return d.backend.Sync()
+}
diff --git a/osbase/blockdev/blockdev_linux_test.go b/osbase/blockdev/blockdev_linux_test.go
new file mode 100644
index 0000000..31aa827
--- /dev/null
+++ b/osbase/blockdev/blockdev_linux_test.go
@@ -0,0 +1,68 @@
+//go:build linux
+
+package blockdev
+
+import (
+ "os"
+ "testing"
+
+ "source.monogon.dev/osbase/loop"
+)
+
+const loopBlockSize = 1024
+const loopBlockCount = 8
+
+func TestLoopDevice(t *testing.T) {
+ if os.Getenv("IN_KTEST") != "true" {
+ t.Skip("Not in ktest")
+ }
+ underlying, err := os.CreateTemp("/tmp", "")
+ if err != nil {
+ t.Fatalf("CreateTemp failed: %v", err)
+ }
+ defer os.Remove(underlying.Name())
+
+ _, err = underlying.Write(make([]byte, loopBlockSize*loopBlockCount))
+ if err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ loopDev, err := loop.Create(underlying, loop.Config{
+ BlockSize: loopBlockSize,
+ })
+ if err != nil {
+ t.Fatalf("loop.Create failed: %v", err)
+ }
+ defer loopDev.Remove()
+
+ devPath, err := loopDev.DevPath()
+ if err != nil {
+ t.Fatalf("loopDev.DevPath failed: %v", err)
+ }
+
+ loopDev.Close()
+ blk, err := Open(devPath)
+ if err != nil {
+ t.Fatalf("Failed to open loop device: %v", err)
+ }
+ defer blk.Close()
+
+ ValidateBlockDev(t, blk, loopBlockCount, loopBlockSize, loopBlockSize)
+}
+
+const fileBlockSize = 1024
+const fileBlockCount = 8
+
+func TestFile(t *testing.T) {
+ if os.Getenv("IN_KTEST") != "true" {
+ t.Skip("Not in ktest")
+ }
+
+ blk, err := CreateFile("/tmp/testfile", fileBlockSize, fileBlockCount)
+ if err != nil {
+ t.Fatalf("Failed to create file: %v", err)
+ }
+ defer os.Remove("/tmp/testfile")
+
+ ValidateBlockDev(t, blk, fileBlockCount, fileBlockSize, fileBlockSize)
+}
diff --git a/osbase/blockdev/blockdev_test.go b/osbase/blockdev/blockdev_test.go
new file mode 100644
index 0000000..df4e8ae
--- /dev/null
+++ b/osbase/blockdev/blockdev_test.go
@@ -0,0 +1,319 @@
+package blockdev
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "slices"
+ "testing"
+)
+
+func errIsNil(err error) bool {
+ return err == nil
+}
+func errIsEOF(err error) bool {
+ return err == io.EOF
+}
+func errIsReadFailed(err error) bool {
+ return err != nil && err != io.EOF
+}
+
+// ValidateBlockDev tests all methods of the BlockDev interface. This way, all
+// implementations can be tested without duplicating test code. This also
+// ensures that all implementations behave consistently.
+func ValidateBlockDev(t *testing.T, blk BlockDev, blockCount, blockSize, optimalBlockSize int64) {
+ if blk.BlockCount() != blockCount {
+ t.Errorf("Expected block count %d, got %d", blockCount, blk.BlockCount())
+ }
+ if blk.BlockSize() != blockSize {
+ t.Errorf("Expected block size %d, got %d", blockSize, blk.BlockSize())
+ }
+ if blk.OptimalBlockSize() != optimalBlockSize {
+ t.Errorf("Expected optimal block size %d, got %d", optimalBlockSize, blk.OptimalBlockSize())
+ }
+ size := blockCount * blockSize
+
+ // ReadAt
+ checkBlockDevOp(t, blk, func(content []byte) {
+ normalErr := errIsNil
+ if size < 3+8 {
+ normalErr = errIsEOF
+ }
+ readAtTests := []struct {
+ name string
+ offset, len int64
+ expectedErr func(error) bool
+ }{
+ {"empty start", 0, 0, errIsNil},
+ {"empty end", size, 0, errIsNil},
+ {"normal", 3, 8, normalErr},
+ {"ends past the end", 1, size, errIsEOF},
+ {"offset negative", -1, 2, errIsReadFailed},
+ {"starts at the end", size, 8, errIsEOF},
+ {"starts past the end", size + 4, 8, errIsEOF},
+ }
+ for _, tt := range readAtTests {
+ t.Run("read "+tt.name, func(t *testing.T) {
+ buf := make([]byte, tt.len)
+ n, err := blk.ReadAt(buf, tt.offset)
+ if !tt.expectedErr(err) {
+ t.Errorf("unexpected error %v", err)
+ }
+ expected := []byte{}
+ if tt.offset >= 0 && tt.offset <= size {
+ expected = content[tt.offset:min(tt.offset+tt.len, size)]
+ }
+ if n != len(expected) {
+ t.Errorf("got n = %d, expected %d", n, len(expected))
+ }
+ if !slices.Equal(buf[:n], expected) {
+ t.Errorf("read unexpected data")
+ }
+ })
+ }
+ })
+
+ // WriteAt
+ writeAtTests := []struct {
+ name string
+ offset int64
+ data string
+ ok bool
+ }{
+ {"empty start", 0, "", true},
+ {"empty end", size, "", true},
+ {"normal", 3, "abcdef", size >= 9},
+ {"ends at the end", size - 4, "abcd", size >= 4},
+ {"ends past the end", size - 4, "abcde", false},
+ {"offset negative", -1, "abc", false},
+ {"starts at the end", size, "abc", false},
+ {"starts past the end", size + 4, "abc", false},
+ }
+ for _, tt := range writeAtTests {
+ t.Run("write "+tt.name, func(t *testing.T) {
+ checkBlockDevOp(t, blk, func(content []byte) {
+ n, err := blk.WriteAt([]byte(tt.data), tt.offset)
+ if (err == nil) != tt.ok {
+ t.Errorf("expected error %v, got %v", tt.ok, err)
+ }
+ expectedN := 0
+ if tt.offset >= 0 && tt.offset < size {
+ expectedN = copy(content[tt.offset:], tt.data)
+ }
+ if n != expectedN {
+ t.Errorf("got n = %d, expected %d; err: %v", n, expectedN, err)
+ }
+ })
+ })
+ }
+
+ // Zero
+ zeroTests := []struct {
+ name string
+ start, end int64
+ ok bool
+ }{
+ {"empty range start", 0, 0, true},
+ {"empty range end", size, size, true},
+ {"full", 0, size, true},
+ {"partial", blockSize, 3 * blockSize, blockCount >= 3},
+ {"negative start", -blockSize, blockSize, false},
+ {"start after end", 2 * blockSize, blockSize, false},
+ {"unaligned start", 1, blockSize, false},
+ {"unaligned end", 0, 1, false},
+ }
+ for _, tt := range zeroTests {
+ t.Run("zero "+tt.name, func(t *testing.T) {
+ checkBlockDevOp(t, blk, func(content []byte) {
+ err := blk.Zero(tt.start, tt.end)
+ if (err == nil) != tt.ok {
+ t.Errorf("expected error %v, got %v", tt.ok, err)
+ }
+ if tt.ok {
+ for i := tt.start; i < tt.end; i++ {
+ content[i] = 0
+ }
+ }
+ })
+ })
+ }
+
+ // Discard
+ for _, tt := range zeroTests {
+ t.Run("discard "+tt.name, func(t *testing.T) {
+ checkBlockDevOp(t, blk, func(content []byte) {
+ err := blk.Discard(tt.start, tt.end)
+ if (err == nil) != tt.ok {
+ t.Errorf("expected error %v, got %v", tt.ok, err)
+ }
+ if tt.ok {
+ n, err := blk.ReadAt(content[tt.start:tt.end], tt.start)
+ if n != int(tt.end-tt.start) {
+ t.Errorf("read returned %d, %v", n, err)
+ }
+ }
+ })
+ })
+ }
+
+ // Sync
+ checkBlockDevOp(t, blk, func(content []byte) {
+ err := blk.Sync()
+ if err != nil {
+ t.Errorf("Sync failed: %v", err)
+ }
+ })
+}
+
+// checkBlockDevOp is a helper for testing operations on a blockdev. It fills
+// the blockdev with a pattern, then calls f with a slice containing the
+// pattern, and afterwards reads the blockdev to compare against the expected
+// content. f should modify the slice to the content expected afterwards.
+func checkBlockDevOp(t *testing.T, blk BlockDev, f func(content []byte)) {
+ t.Helper()
+
+ testContent := make([]byte, blk.BlockCount()*blk.BlockSize())
+ for i := range testContent {
+ testContent[i] = '1' + byte(i%9)
+ }
+ n, err := blk.WriteAt(testContent, 0)
+ if n != len(testContent) || err != nil {
+ t.Fatalf("WriteAt = %d, %v; expected %d, nil", n, err, len(testContent))
+ }
+ f(testContent)
+ afterContent := make([]byte, len(testContent))
+ n, err = blk.ReadAt(afterContent, 0)
+ if n != len(afterContent) || (err != nil && err != io.EOF) {
+ t.Fatalf("ReadAt = %d, %v; expected %d, (nil or EOF)", n, err, len(afterContent))
+ }
+ if !slices.Equal(afterContent, testContent) {
+ t.Errorf("Unexpected content after operation")
+ }
+}
+
+func TestReadWriteSeeker_Seek(t *testing.T) {
+ // Verifies that NewRWS's Seeker behaves like bytes.NewReader
+ br := bytes.NewReader([]byte("foobar"))
+ m := MustNewMemory(2, 3)
+ rws := NewRWS(m)
+ n, err := rws.Write([]byte("foobar"))
+ if n != 6 || err != nil {
+ t.Errorf("Write = %v, %v; want 6, nil", n, err)
+ }
+
+ for _, whence := range []int{io.SeekStart, io.SeekCurrent, io.SeekEnd} {
+ for offset := int64(-7); offset <= 7; offset++ {
+ brOff, brErr := br.Seek(offset, whence)
+ rwsOff, rwsErr := rws.Seek(offset, whence)
+ if (brErr != nil) != (rwsErr != nil) || brOff != rwsOff {
+ t.Errorf("For whence %d, offset %d: bytes.Reader.Seek = (%v, %v) != ReadWriteSeeker.Seek = (%v, %v)",
+ whence, offset, brOff, brErr, rwsOff, rwsErr)
+ }
+ }
+ }
+
+ // And verify we can just seek past the end and get an EOF
+ got, err := rws.Seek(100, io.SeekStart)
+ if err != nil || got != 100 {
+ t.Errorf("Seek = %v, %v; want 100, nil", got, err)
+ }
+
+ n, err = rws.Read(make([]byte, 10))
+ if n != 0 || err != io.EOF {
+ t.Errorf("Read = %v, %v; want 0, EOF", n, err)
+ }
+}
+
+func TestNewSection(t *testing.T) {
+ tests := []struct {
+ name string
+ blockSize int64
+ count int64
+ startBlock int64
+ endBlock int64
+ ok bool
+ sectionCount int64
+ }{
+ {name: "empty underlying", blockSize: 8, count: 0, startBlock: 0, endBlock: 0, ok: true, sectionCount: 0},
+ {name: "empty section", blockSize: 8, count: 5, startBlock: 2, endBlock: 2, ok: true, sectionCount: 0},
+ {name: "partial section", blockSize: 8, count: 15, startBlock: 1, endBlock: 11, ok: true, sectionCount: 10},
+ {name: "full section", blockSize: 8, count: 15, startBlock: 0, endBlock: 15, ok: true, sectionCount: 15},
+ {name: "negative start", blockSize: 8, count: 15, startBlock: -1, endBlock: 11, ok: false},
+ {name: "start after end", blockSize: 8, count: 15, startBlock: 6, endBlock: 5, ok: false},
+ {name: "end out of bounds", blockSize: 8, count: 15, startBlock: 6, endBlock: 16, ok: false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ m := MustNewMemory(tt.blockSize, tt.count)
+ s, err := NewSection(m, tt.startBlock, tt.endBlock)
+ if (err == nil) != tt.ok {
+ t.Errorf("NewSection: expected %v, got %v", tt.ok, err)
+ }
+ if err == nil {
+ checkBlockDevOp(t, m, func(content []byte) {
+ ValidateBlockDev(t, s, tt.sectionCount, tt.blockSize, tt.blockSize)
+
+ // Check that content outside the section has not changed.
+ start := tt.startBlock * tt.blockSize
+ end := tt.endBlock * tt.blockSize
+ n, err := m.ReadAt(content[start:end], start)
+ if n != int(end-start) {
+ t.Errorf("read returned %d, %v", n, err)
+ }
+ })
+ }
+ })
+ }
+}
+
+type MemoryWithGenericZero struct {
+ *Memory
+}
+
+func (m *MemoryWithGenericZero) Zero(startByte, endByte int64) error {
+ return GenericZero(m, startByte, endByte)
+}
+
+func TestGenericZero(t *testing.T) {
+ if os.Getenv("IN_KTEST") == "true" {
+ t.Skip("In ktest")
+ }
+ // Use size larger than the 16 MiB buffer size in GenericZero.
+ blockSize := int64(512)
+ blockCount := int64(35 * 1024)
+ m, err := NewMemory(blockSize, blockCount)
+ if err != nil {
+ t.Errorf("NewMemory: %v", err)
+ }
+ b := &MemoryWithGenericZero{m}
+ if err == nil {
+ ValidateBlockDev(t, b, blockCount, blockSize, blockSize)
+ }
+}
+
+func TestNewMemory(t *testing.T) {
+ tests := []struct {
+ name string
+ blockSize int64
+ count int64
+ ok bool
+ }{
+ {name: "normal", blockSize: 64, count: 9, ok: true},
+ {name: "count 0", blockSize: 8, count: 0, ok: true},
+ {name: "count negative", blockSize: 8, count: -1, ok: false},
+ {name: "blockSize not a power of 2", blockSize: 9, count: 5, ok: false},
+ {name: "blockSize 0", blockSize: 0, count: 5, ok: false},
+ {name: "blockSize negative", blockSize: -1, count: 5, ok: false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ m, err := NewMemory(tt.blockSize, tt.count)
+ if (err == nil) != tt.ok {
+ t.Errorf("NewMemory: expected %v, got %v", tt.ok, err)
+ }
+ if err == nil {
+ ValidateBlockDev(t, m, tt.count, tt.blockSize, tt.blockSize)
+ }
+ })
+ }
+}
diff --git a/osbase/blockdev/memory.go b/osbase/blockdev/memory.go
index 193f93c..cf5e21d 100644
--- a/osbase/blockdev/memory.go
+++ b/osbase/blockdev/memory.go
@@ -43,70 +43,47 @@
return m
}
-func (m *Memory) ReadAt(p []byte, off int64) (int, error) {
- devSize := m.blockSize * m.blockCount
- if off > devSize {
+func (m *Memory) ReadAt(p []byte, off int64) (n int, err error) {
+ if off < 0 {
+ return 0, errors.New("blockdev.Memory.ReadAt: negative offset")
+ }
+ if off > int64(len(m.data)) {
return 0, io.EOF
}
// TODO: Alignment checks?
- copy(p, m.data[off:])
- n := len(m.data[off:])
+ n = copy(p, m.data[off:])
if n < len(p) {
- return n, io.EOF
+ err = io.EOF
}
- return len(p), nil
+ return
}
-func (m *Memory) WriteAt(p []byte, off int64) (int, error) {
- devSize := m.blockSize * m.blockCount
- if off > devSize {
- return 0, io.EOF
+func (m *Memory) WriteAt(p []byte, off int64) (n int, err error) {
+ if off < 0 || off > int64(len(m.data)) {
+ return 0, ErrOutOfBounds
}
// TODO: Alignment checks?
- copy(m.data[off:], p)
- n := len(m.data[off:])
+ n = copy(m.data[off:], p)
if n < len(p) {
- return n, io.EOF
+ err = ErrOutOfBounds
}
- return len(p), nil
-}
-
-func (m *Memory) BlockSize() int64 {
- return m.blockSize
+ return
}
func (m *Memory) BlockCount() int64 {
return m.blockCount
}
+func (m *Memory) BlockSize() int64 {
+ return m.blockSize
+}
+
func (m *Memory) OptimalBlockSize() int64 {
return m.blockSize
}
-func (m *Memory) validRange(startByte, endByte int64) error {
- if startByte > endByte {
- return fmt.Errorf("startByte (%d) larger than endByte (%d), invalid interval", startByte, endByte)
- }
- devSize := m.blockSize * m.blockCount
- if startByte >= devSize || startByte < 0 {
- return fmt.Errorf("startByte (%d) out of range (0-%d)", endByte, devSize)
- }
- if endByte > devSize || endByte < 0 {
- return fmt.Errorf("endByte (%d) out of range (0-%d)", endByte, devSize)
- }
- // Alignment check works for powers of two by looking at every bit below
- // the bit set in the block size.
- if startByte&(m.blockSize-1) != 0 {
- return fmt.Errorf("startByte (%d) is not aligned to blockSize (%d)", startByte, m.blockSize)
- }
- if endByte&(m.blockSize-1) != 0 {
- return fmt.Errorf("endByte (%d) is not aligned to blockSize (%d)", startByte, m.blockSize)
- }
- return nil
-}
-
func (m *Memory) Discard(startByte, endByte int64) error {
- if err := m.validRange(startByte, endByte); err != nil {
+ if err := validAlignedRange(m, startByte, endByte); err != nil {
return err
}
for i := startByte; i < endByte; i++ {
@@ -118,7 +95,7 @@
}
func (m *Memory) Zero(startByte, endByte int64) error {
- if err := m.validRange(startByte, endByte); err != nil {
+ if err := validAlignedRange(m, startByte, endByte); err != nil {
return err
}
for i := startByte; i < endByte; i++ {
@@ -126,3 +103,7 @@
}
return nil
}
+
+func (m *Memory) Sync() error {
+ return nil
+}
diff --git a/osbase/gpt/gpt.go b/osbase/gpt/gpt.go
index be6fb7f..ca8566a 100644
--- a/osbase/gpt/gpt.go
+++ b/osbase/gpt/gpt.go
@@ -281,7 +281,10 @@
} else {
gpt.Partitions[newPartPos] = p
}
- p.Section = blockdev.NewSection(gpt.b, int64(p.FirstBlock), int64(p.LastBlock)+1)
+ p.Section, err = blockdev.NewSection(gpt.b, int64(p.FirstBlock), int64(p.LastBlock)+1)
+ if err != nil {
+ return fmt.Errorf("failed to create blockdev Section for partition: %w", err)
+ }
return nil
}
}