osbase/blockdev: implement copy_file_range optimization

This change enables the use of the copy_file_range syscall on Linux when
copying from an os.File to a blockdev.File. This speeds up building of
system images, especially with a file system which supports reflinks.

The implementation is partially based on the implementation in the Go
standard library for copy_file_range between two os.File in
src/os/zero_copy_linux.go and src/internal/poll/copy_file_range_unix.go.
We can't use that implementation, because it only supports using the
file offset for both source and destination, but we want to provide the
destination offset as an argument. To support this, the ReaderFromAt
interface is introduced.

With these changes, copy_file_range is now used when building system
images, for both the rootfs and files on the FAT32 boot partition. If
the file system supports it (e.g. XFS), reflinks will be used for the
rootfs, which means no data is copied. For files on the FAT32 partition,
reflinks probably can't be used, because these are only aligned to 512
bytes but would need to be aligned to 4096 bytes on my system for
reflinking.

Change-Id: Ie42b5834e6d3f63a5cc1f347d2681d8a6bb5c006
Reviewed-on: https://review.monogon.dev/c/monogon/+/4293
Tested-by: Jenkins CI
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
diff --git a/osbase/blockdev/blockdev.go b/osbase/blockdev/blockdev.go
index 5eb7fe8..9877186 100644
--- a/osbase/blockdev/blockdev.go
+++ b/osbase/blockdev/blockdev.go
@@ -97,6 +97,28 @@
 	Sync() error
 }
 
+// ReaderFromAt is similar to [io.ReaderFrom], except that the write starts at
+// offset off instead of using the file offset.
+type ReaderFromAt interface {
+	ReadFromAt(r io.Reader, off int64) (n int64, err error)
+}
+
+// writerOnly wraps an [io.Writer] and hides all methods other than Write
+// (such as ReadFrom).
+// This forces [io.Copy] to use plain Write calls instead of a ReadFrom
+// fast path.
+type writerOnly struct {
+	io.Writer
+}
+
+// genericReadFromAt is a generic implementation which does not use b.ReadFromAt
+// to prevent recursive calls.
+// It copies through a [ReadWriteSeeker] positioned at off, wrapped in
+// writerOnly so that io.Copy cannot call back into ReadFrom.
+func genericReadFromAt(b BlockDev, r io.Reader, off int64) (int64, error) {
+	w := &writerOnly{Writer: &ReadWriteSeeker{b: b, currPos: off}}
+	return io.Copy(w, r)
+}
+
 func NewRWS(b BlockDev) *ReadWriteSeeker {
 	return &ReadWriteSeeker{b: b}
 }
@@ -120,6 +139,20 @@
 	return
 }
 
+// ReadFrom implements [io.ReaderFrom]. It delegates to the underlying
+// device's ReadFromAt when available, and advances the current position
+// by the number of bytes written.
+func (s *ReadWriteSeeker) ReadFrom(r io.Reader) (n int64, err error) {
+	rfa, rfaOK := s.b.(ReaderFromAt)
+	if !rfaOK {
+		w := &writerOnly{Writer: s}
+		return io.Copy(w, r)
+	}
+	n, err = rfa.ReadFromAt(r, s.currPos)
+	s.currPos += n
+	return
+}
+
 func (s *ReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
 	switch whence {
 	default:
@@ -203,6 +233,44 @@
 	return s.b.WriteAt(p, bOff)
 }
 
+// ReadFromAt implements [ReaderFromAt]. It translates off into an offset on
+// the underlying device and bounds the copy to the end of the section.
+func (s *Section) ReadFromAt(r io.Reader, off int64) (n int64, err error) {
+	rfa, rfaOK := s.b.(ReaderFromAt)
+	if !rfaOK {
+		return genericReadFromAt(s, r, off)
+	}
+	bOff := off + (s.startBlock * s.b.BlockSize())
+	bytesToEnd := (s.endBlock * s.b.BlockSize()) - bOff
+	if off < 0 || bytesToEnd < 0 {
+		return 0, ErrOutOfBounds
+	}
+	// If r is an *io.LimitedReader, unwrap it so that the smaller of the
+	// two limits is applied, and update lr.N afterwards to stay consistent.
+	ur := r
+	lr, lrOK := r.(*io.LimitedReader)
+	if lrOK {
+		if bytesToEnd >= lr.N {
+			return rfa.ReadFromAt(r, bOff)
+		}
+		ur = lr.R
+	}
+	n, err = rfa.ReadFromAt(io.LimitReader(ur, bytesToEnd), bOff)
+	if lrOK {
+		lr.N -= n
+	}
+	if err == nil && n == bytesToEnd {
+		// Return an error if we have not reached EOF.
+		moreN, moreErr := io.CopyN(io.Discard, r, 1)
+		if moreN != 0 {
+			err = ErrOutOfBounds
+		} else if moreErr != io.EOF {
+			err = moreErr
+		}
+	}
+	return
+}
+
 func (s *Section) BlockCount() int64 {
 	return s.endBlock - s.startBlock
 }
diff --git a/osbase/blockdev/blockdev_linux.go b/osbase/blockdev/blockdev_linux.go
index fbcbf5b..ab9c96b 100644
--- a/osbase/blockdev/blockdev_linux.go
+++ b/osbase/blockdev/blockdev_linux.go
@@ -299,6 +299,108 @@
 	return d.backend.WriteAt(p, off)
 }
 
+// ReadFromAt implements [ReaderFromAt]. It attempts to copy using the
+// copy_file_range syscall when r reads from an *os.File, and falls back
+// to a generic buffered copy otherwise.
+func (d *File) ReadFromAt(r io.Reader, off int64) (n int64, err error) {
+	size := d.blockSize * d.blockCount
+	if off > size || off < 0 {
+		return 0, ErrOutOfBounds
+	}
+	limit := size - off
+	ur := r
+	lr, lrOK := r.(*io.LimitedReader)
+	if lrOK {
+		ur = lr.R
+		limit = min(limit, lr.N)
+	}
+	n, handled, err := d.doCopyFileRange(ur, off, limit)
+	if lrOK {
+		lr.N -= n
+	}
+	off += n
+	if !handled {
+		var fallbackN int64
+		fallbackN, err = genericReadFromAt(d, r, off)
+		n += fallbackN
+		return
+	}
+	if err == nil && off == size {
+		// Return an error if we have not reached EOF.
+		moreN, moreErr := io.CopyN(io.Discard, r, 1)
+		if moreN != 0 {
+			err = ErrOutOfBounds
+		} else if moreErr != io.EOF {
+			err = moreErr
+		}
+	}
+	return
+}
+
+// Copied from Go src/internal/poll/copy_file_range_linux.go
+const maxCopyFileRangeRound = 0x7ffff000
+
+// doCopyFileRange attempts to copy using the copy_file_range syscall.
+//
+// This is only implemented for [File] because Linux does not support this
+// syscall on block devices.
+func (d *File) doCopyFileRange(r io.Reader, off int64, remain int64) (written int64, handled bool, err error) {
+	if remain <= 0 {
+		handled = true
+		return
+	}
+	// Note: We should also check for os.fileWithoutWriteTo, but that type isn't
+	// exported. This means that this optimization won't work if the top-level
+	// copy is io.Copy, but it does work with io.CopyN and w.ReadFrom(r).
+	src, srcOK := r.(*os.File)
+	if !srcOK {
+		return
+	}
+	srcConn, err := src.SyscallConn()
+	if err != nil {
+		return
+	}
+	// We need a read lock of src, because its file offset is used and updated.
+	// We don't need a lock of dest, because its file offset is not used.
+	readErr := srcConn.Read(func(srcFD uintptr) bool {
+		controlErr := d.rawConn.Control(func(destFD uintptr) {
+			handled = true
+			for remain > 0 {
+				n := int(min(remain, maxCopyFileRangeRound))
+				n, err = unix.CopyFileRange(int(srcFD), nil, int(destFD), &off, n, 0)
+				if n > 0 {
+					remain -= int64(n)
+					written += int64(n)
+					// The kernel adds n to off.
+				}
+				// See handleCopyFileRangeErr in
+				// src/internal/poll/copy_file_range_linux.go
+				if err != nil {
+					if errors.Is(err, unix.ENOSYS) || errors.Is(err, unix.EXDEV) ||
+						errors.Is(err, unix.EINVAL) || errors.Is(err, unix.EIO) ||
+						errors.Is(err, unix.EOPNOTSUPP) || errors.Is(err, unix.EPERM) {
+						handled = false
+					}
+					break
+				} else if n == 0 {
+					if written == 0 {
+						handled = false
+					}
+					break
+				}
+			}
+		})
+		if err == nil {
+			err = controlErr
+		}
+		return true
+	})
+	if err == nil {
+		err = readErr
+	}
+	return
+}
+
 func (d *File) Close() error {
 	return d.backend.Close()
 }
diff --git a/osbase/blockdev/blockdev_linux_test.go b/osbase/blockdev/blockdev_linux_test.go
index 9cbf027..b61c1ee 100644
--- a/osbase/blockdev/blockdev_linux_test.go
+++ b/osbase/blockdev/blockdev_linux_test.go
@@ -6,6 +6,7 @@
 package blockdev
 
 import (
+	"io"
 	"os"
 	"testing"
 
@@ -66,6 +67,100 @@
 		t.Fatalf("Failed to create file: %v", err)
 	}
 	defer os.Remove("/tmp/testfile")
+	defer blk.Close()
 
 	ValidateBlockDev(t, blk, fileBlockCount, fileBlockSize, fileBlockSize)
+
+	// ReadFromAt
+	srcFile, err := os.Create("/tmp/copysrc")
+	if err != nil {
+		t.Fatalf("Failed to create source file: %v", err)
+	}
+	defer os.Remove("/tmp/copysrc")
+	defer srcFile.Close()
+	var size int64 = fileBlockSize * fileBlockCount
+	readFromAtTests := []struct {
+		name   string
+		offset int64
+		data   string
+		limit  int64
+		ok     bool
+	}{
+		{"empty start", 0, "", -1, true},
+		{"empty end", size, "", -1, true},
+		{"normal", 3, "abcdef", -1, true},
+		{"limited", 3, "abcdef", 4, true},
+		{"large limit", 3, "abcdef", size, true},
+		{"ends at the end", size - 4, "abcd", -1, true},
+		{"ends past the end", size - 4, "abcde", -1, false},
+		{"ends past the end with limit", size - 4, "abcde", 10, false},
+		{"offset negative", -1, "abc", -1, false},
+		{"starts at the end", size, "abc", -1, false},
+		{"starts past the end", size + 4, "abc", -1, false},
+	}
+	for _, tt := range readFromAtTests {
+		t.Run("readFromAt "+tt.name, func(t *testing.T) {
+			checkBlockDevOp(t, blk, func(content []byte) {
+				// Prepare source file
+				err = srcFile.Truncate(0)
+				if err != nil {
+					t.Fatalf("Failed to truncate source file: %v", err)
+				}
+				_, err = srcFile.WriteAt([]byte("123"+tt.data), 0)
+				if err != nil {
+					t.Fatalf("Failed to write source file: %v", err)
+				}
+				_, err = srcFile.Seek(3, io.SeekStart)
+				if err != nil {
+					t.Fatalf("Failed to seek source file: %v", err)
+				}
+
+				// Do ReadFromAt
+				r := io.Reader(srcFile)
+				lr := &io.LimitedReader{R: srcFile, N: tt.limit}
+				if tt.limit != -1 {
+					r = lr
+				}
+				n, err := blk.ReadFromAt(r, tt.offset)
+				if (err == nil) != tt.ok {
+					t.Errorf("expected error %v, got %v", tt.ok, err)
+				}
+				expectedN := 0
+				if tt.offset >= 0 && tt.offset < size {
+					c := content[tt.offset:]
+					if tt.limit != -1 && tt.limit < int64(len(c)) {
+						c = c[:tt.limit]
+					}
+					expectedN = copy(c, tt.data)
+				}
+				if n != int64(expectedN) {
+					t.Errorf("got n = %d, expected %d; err: %v", n, expectedN, err)
+				}
+
+				// Check new offset
+				// On failure, the implementation may have consumed source
+				// bytes beyond what was written, so accept a range of
+				// final positions.
+				newOffset, err := srcFile.Seek(0, io.SeekCurrent)
+				if err != nil {
+					t.Fatalf("Failed to get source file position: %v", err)
+				}
+				newOffset -= 3
+				minOffset := n
+				maxOffset := n
+				if !tt.ok {
+					maxOffset = int64(len(tt.data))
+					if tt.limit != -1 {
+						maxOffset = min(maxOffset, tt.limit)
+					}
+				}
+				if minOffset > newOffset || newOffset > maxOffset {
+					t.Errorf("Got newOffset = %d, expected between %d and %d", newOffset, minOffset, maxOffset)
+				}
+				remaining := tt.limit - newOffset
+				if tt.limit != -1 && lr.N != remaining {
+					t.Errorf("Got lr.N = %d, expected %d", lr.N, remaining)
+				}
+			})
+		})
+	}
 }