osbase/blockdev: add tests, fix minor issues
Add bounds checks that make BlockDev implementations safer to use;
NewSection now validates its block range and returns an error. Fix a
bug in ReadWriteSeeker.Seek with io.SeekEnd: the offset must be added
to the device size, not subtracted from it. Add a Sync() method to the
BlockDev interface.
Change-Id: I247095b3dbc6410064844b4ac7c6208d88a7abcd
Reviewed-on: https://review.monogon.dev/c/monogon/+/3338
Tested-by: Jenkins CI
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
diff --git a/osbase/blockdev/blockdev.go b/osbase/blockdev/blockdev.go
index 0e3c6e1..9cf4ee2 100644
--- a/osbase/blockdev/blockdev.go
+++ b/osbase/blockdev/blockdev.go
@@ -16,14 +16,15 @@
type BlockDev interface {
io.ReaderAt
io.WriterAt
- // BlockSize returns the block size of the block device in bytes. This must
- // be a power of two and is commonly (but not always) either 512 or 4096.
- BlockSize() int64
// BlockCount returns the number of blocks on the block device or -1 if it
// is an image with an undefined size.
BlockCount() int64
+ // BlockSize returns the block size of the block device in bytes. This must
+ // be a power of two and is commonly (but not always) either 512 or 4096.
+ BlockSize() int64
+
// OptimalBlockSize returns the optimal block size in bytes for aligning
// to as well as issuing I/O. IO operations with block sizes below this
// one might incur read-write overhead. This is the larger of the physical
@@ -41,6 +42,9 @@
// Zero zeroes a continouous set of blocks. On certain implementations this
// can be significantly faster than just calling Write with zeroes.
Zero(startByte, endByte int64) error
+
+ // Sync commits the current contents to stable storage.
+ Sync() error
}
func NewRWS(b BlockDev) *ReadWriteSeeker {
@@ -68,13 +72,18 @@
func (s *ReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
switch whence {
- case io.SeekCurrent:
- s.currPos += offset
+ default:
+ return 0, errors.New("Seek: invalid whence")
case io.SeekStart:
- s.currPos = offset
+ case io.SeekCurrent:
+ offset += s.currPos
case io.SeekEnd:
- s.currPos = (s.b.BlockCount() * s.b.BlockSize()) - offset
+ offset += s.b.BlockCount() * s.b.BlockSize()
}
+ if offset < 0 {
+ return 0, errors.New("Seek: invalid offset")
+ }
+ s.currPos = offset
return s.currPos, nil
}
@@ -82,12 +91,21 @@
// NewSection returns a new Section, implementing BlockDev over that subset
// of blocks. The interval is inclusive-exclusive.
-func NewSection(b BlockDev, startBlock, endBlock int64) *Section {
+func NewSection(b BlockDev, startBlock, endBlock int64) (*Section, error) {
+ if startBlock < 0 {
+ return nil, fmt.Errorf("invalid range: startBlock (%d) negative", startBlock)
+ }
+ if startBlock > endBlock {
+ return nil, fmt.Errorf("invalid range: startBlock (%d) bigger than endBlock (%d)", startBlock, endBlock)
+ }
+ if endBlock > b.BlockCount() {
+ return nil, fmt.Errorf("endBlock (%d) out of range (%d)", endBlock, b.BlockCount())
+ }
return &Section{
b: b,
startBlock: startBlock,
endBlock: endBlock,
- }
+ }, nil
}
// Section implements BlockDev on a slice of another BlockDev given a startBlock
@@ -98,13 +116,20 @@
}
func (s *Section) ReadAt(p []byte, off int64) (n int, err error) {
+ if off < 0 {
+ return 0, errors.New("blockdev.Section.ReadAt: negative offset")
+ }
bOff := off + (s.startBlock * s.b.BlockSize())
bytesToEnd := (s.endBlock * s.b.BlockSize()) - bOff
- if bytesToEnd <= 0 {
+ if bytesToEnd < 0 {
return 0, io.EOF
}
if bytesToEnd < int64(len(p)) {
- return s.b.ReadAt(p[:bytesToEnd], bOff)
+ n, err := s.b.ReadAt(p[:bytesToEnd], bOff)
+ if err == nil {
+ err = io.EOF
+ }
+ return n, err
}
return s.b.ReadAt(p, bOff)
}
@@ -112,11 +137,11 @@
func (s *Section) WriteAt(p []byte, off int64) (n int, err error) {
bOff := off + (s.startBlock * s.b.BlockSize())
bytesToEnd := (s.endBlock * s.b.BlockSize()) - bOff
- if bytesToEnd <= 0 {
+ if off < 0 || bytesToEnd < 0 {
return 0, ErrOutOfBounds
}
if bytesToEnd < int64(len(p)) {
- n, err := s.b.WriteAt(p[:bytesToEnd], off+(s.startBlock*s.b.BlockSize()))
+ n, err := s.b.WriteAt(p[:bytesToEnd], bOff)
if err != nil {
// If an error happened, prioritize that error
return n, err
@@ -125,7 +150,7 @@
// error.
return n, ErrOutOfBounds
}
- return s.b.WriteAt(p, off+(s.startBlock*s.b.BlockSize()))
+ return s.b.WriteAt(p, bOff)
}
func (s *Section) BlockCount() int64 {
@@ -136,49 +161,56 @@
return s.b.BlockSize()
}
-func (s *Section) inRange(startByte, endByte int64) error {
- if startByte > endByte {
- return fmt.Errorf("invalid range: startByte (%d) bigger than endByte (%d)", startByte, endByte)
- }
- sectionLen := s.BlockCount() * s.BlockSize()
- if startByte >= sectionLen {
- return fmt.Errorf("startByte (%d) out of range (%d)", startByte, sectionLen)
- }
- if endByte > sectionLen {
- return fmt.Errorf("endBlock (%d) out of range (%d)", endByte, sectionLen)
- }
- return nil
+func (s *Section) OptimalBlockSize() int64 {
+ return s.b.OptimalBlockSize()
}
func (s *Section) Discard(startByte, endByte int64) error {
- if err := s.inRange(startByte, endByte); err != nil {
+ if err := validAlignedRange(s, startByte, endByte); err != nil {
return err
}
offset := s.startBlock * s.b.BlockSize()
return s.b.Discard(offset+startByte, offset+endByte)
}
-func (s *Section) OptimalBlockSize() int64 {
- return s.b.OptimalBlockSize()
-}
-
func (s *Section) Zero(startByte, endByte int64) error {
- if err := s.inRange(startByte, endByte); err != nil {
+ if err := validAlignedRange(s, startByte, endByte); err != nil {
return err
}
offset := s.startBlock * s.b.BlockSize()
return s.b.Zero(offset+startByte, offset+endByte)
}
-// GenericZero implements software-based zeroing. This can be used to implement
-// Zero when no acceleration is available or desired.
-func GenericZero(b BlockDev, startByte, endByte int64) error {
+func (s *Section) Sync() error {
+ return s.b.Sync()
+}
+
+func validAlignedRange(b BlockDev, startByte, endByte int64) error {
+ if startByte < 0 {
+ return fmt.Errorf("invalid range: startByte (%d) negative", startByte)
+ }
+ if startByte > endByte {
+ return fmt.Errorf("invalid range: startByte (%d) bigger than endByte (%d)", startByte, endByte)
+ }
+ devLen := b.BlockCount() * b.BlockSize()
+ if endByte > devLen {
+ return fmt.Errorf("endByte (%d) out of range (%d)", endByte, devLen)
+ }
if startByte%b.BlockSize() != 0 {
return fmt.Errorf("startByte (%d) needs to be aligned to block size (%d)", startByte, b.BlockSize())
}
if endByte%b.BlockSize() != 0 {
return fmt.Errorf("endByte (%d) needs to be aligned to block size (%d)", endByte, b.BlockSize())
}
+ return nil
+}
+
+// GenericZero implements software-based zeroing. This can be used to implement
+// Zero when no acceleration is available or desired.
+func GenericZero(b BlockDev, startByte, endByte int64) error {
+ if err := validAlignedRange(b, startByte, endByte); err != nil {
+ return err
+ }
// Choose buffer size close to 16MiB or the range to be zeroed, whatever
// is smaller.
bufSizeTarget := int64(16 * 1024 * 1024)