osbase/logtree: add WithStartPosition option

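To allow users to avoid always requesting all messages, introduce a
WithStartPosition option for LogTree.Read, which reads entries from
(ReadDirectionAfter) or before (ReadDirectionBefore) a specific global
log position. Returned entries expose this position as
LogEntry.Position. This, for example, makes implementing scrollback
easier.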

Change-Id: I1773288f670f476706d94baf3f052fe1e5da9eb0
Reviewed-on: https://review.monogon.dev/c/monogon/+/4452
Tested-by: Jenkins CI
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
diff --git a/osbase/logtree/BUILD.bazel b/osbase/logtree/BUILD.bazel
index 584d22c..f0f7c9a 100644
--- a/osbase/logtree/BUILD.bazel
+++ b/osbase/logtree/BUILD.bazel
@@ -49,6 +49,7 @@
         "journal_test.go",
         "klog_test.go",
         "kmsg_test.go",
+        "logtree_access_test.go",
         "logtree_test.go",
         "zap_test.go",
     ],
diff --git a/osbase/logtree/journal.go b/osbase/logtree/journal.go
index c6cc503..8697b21 100644
--- a/osbase/logtree/journal.go
+++ b/osbase/logtree/journal.go
@@ -5,7 +5,7 @@
 
 import (
 	"errors"
-	"sort"
+	"slices"
 	"strings"
 	"sync"
 
@@ -96,6 +96,10 @@
 	// provided filters (eg. to limit events to subtrees that interest that particular
 	// subscriber).
 	subscribers []*subscriber
+
+	// seq counts the log entries appended since creation, and is thus also
+	// the seqGlobal that will be assigned to the next appended entry.
+	seq uint64
 }
 
 // newJournal creates a new empty journal. All journals are independent from
@@ -169,6 +173,47 @@
 	return e.leveled != nil
 }
 
+// filterStartPosition returns a filter which only accepts entries whose global
+// position (seqGlobal) lies within a half-open window anchored at pos:
+// [pos, pos+count) for ReadDirectionAfter and [pos-count, pos) for
+// ReadDirectionBefore. Both ends are clamped at 0, and a count of
+// BacklogAllAvailable leaves the far end of the window open.
+//
+// The window is measured in global positions, not in matching entries, so
+// when only a subset of entries is read (eg. a single DN, or OnlyLeveled),
+// fewer than count entries may be returned even if more exist.
+//
+// An unknown direction is a programming error and panics as soon as the
+// filter is built, rather than when the first entry is evaluated.
+func filterStartPosition(count int, pos int, direction ReadDirection) filter {
+	// clamp maps a signed position into the uint64 sequence space, so that a
+	// negative pos (or pos-count) cannot wrap around to a huge value.
+	clamp := func(v int) uint64 { return uint64(max(v, 0)) }
+	// unbounded is the exclusive upper bound of a window with an open end.
+	const unbounded = ^uint64(0)
+
+	// Compute the window [lo, hi) once instead of once per scanned entry.
+	var lo, hi uint64
+	switch direction {
+	case ReadDirectionAfter:
+		lo, hi = clamp(pos), unbounded
+		if count != BacklogAllAvailable {
+			hi = clamp(pos + count)
+		}
+	case ReadDirectionBefore:
+		lo, hi = 0, clamp(pos)
+		if count != BacklogAllAvailable {
+			lo = clamp(pos - count)
+		}
+	default:
+		panic("logtree: invalid ReadDirection")
+	}
+
+	return func(e *entry) bool {
+		return e.seqGlobal >= lo && e.seqGlobal < hi
+	}
+}
+
 // scanEntries does a linear scan through the global entry list and returns all
 // entries that match the given filters. If retrieving entries for an exact event,
 // getEntries should be used instead, as it will leverage DN-local linked lists to
@@ -198,9 +215,8 @@
 	}
 
 	// Reverse entries back into chronological order.
-	sort.SliceStable(res, func(i, j int) bool {
-		return i > j
-	})
+	slices.Reverse(res)
+
 	return
 }
 
@@ -234,9 +250,8 @@
 	}
 
 	// Reverse entries back into chronological order.
-	sort.SliceStable(res, func(i, j int) bool {
-		return i > j
-	})
+	slices.Reverse(res)
+
 	return
 }
 
diff --git a/osbase/logtree/journal_entry.go b/osbase/logtree/journal_entry.go
index c553ac8..d492d16 100644
--- a/osbase/logtree/journal_entry.go
+++ b/osbase/logtree/journal_entry.go
@@ -43,6 +43,10 @@
 	// length calculation for local linked lists as long as entries are only unlinked
 	// from the head or tail (which is the case in the current implementation).
 	seqLocal uint64
+
+	// seqGlobal is the position of this entry in the global log, taken from a
+	// journal-wide counter that increases by one for every appended entry.
+	seqGlobal uint64
 }
 
 // defaultDNQuota defines how many messages should be stored per DN.
@@ -53,9 +57,10 @@
 // sequences, etc. These objects are visible to library consumers.
 func (e *entry) external() *LogEntry {
 	return &LogEntry{
-		DN:      e.origin,
-		Leveled: e.leveled,
-		Raw:     e.raw,
+		DN:       e.origin,
+		Leveled:  e.leveled,
+		Raw:      e.raw,
+		Position: int(e.seqGlobal), // this entry's position in the global journal
 	}
 }
 
@@ -108,6 +113,8 @@
 	defer j.mu.Unlock()
 
 	e.journal = j
+	e.seqGlobal = j.seq
+	j.seq++
 
 	// Insert at head in global linked list, set pointers.
 	e.nextGlobal = nil
diff --git a/osbase/logtree/logtree_access.go b/osbase/logtree/logtree_access.go
index 1582a8f..1b4f90d 100644
--- a/osbase/logtree/logtree_access.go
+++ b/osbase/logtree/logtree_access.go
@@ -10,6 +10,17 @@
 	"source.monogon.dev/go/logging"
 )
 
+// ReadDirection selects on which side of the position passed to
+// WithStartPosition log entries are read.
+type ReadDirection int
+
+const (
+	// ReadDirectionAfter reads entries at and after the start position.
+	ReadDirectionAfter ReadDirection = iota
+	// ReadDirectionBefore reads entries strictly before the start position.
+	ReadDirectionBefore
+)
+
 // LogReadOption describes options for the LogTree.Read call.
 type LogReadOption func(*logReaderOptions)
 
@@ -21,6 +28,8 @@
 	onlyRaw                    bool
 	leveledWithMinimumSeverity logging.Severity
 	withStreamBufferSize       int
+	withStartPosition          int
+	startPositionReadDirection ReadDirection
 }
 
 // WithChildren makes Read return/stream data for both a given DN and all its
@@ -54,6 +63,20 @@
 	return func(lro *logReaderOptions) { lro.withBacklog = count }
 }
 
+// WithStartPosition makes Read return log entries relative to a position in
+// the global journal, which numbers entries from 0 (see LogEntry.Position).
+// It requires WithBacklog, whose count sizes the window: ReadDirectionAfter
+// returns positions [pos, pos+count), ReadDirectionBefore [pos-count, pos);
+// BacklogAllAvailable leaves the far end open. The window counts global
+// positions, so reads of a single DN or with filters may return fewer than
+// count entries. Passing pos -1 is equivalent to not passing this option.
+func WithStartPosition(pos int, direction ReadDirection) LogReadOption {
+	return func(lro *logReaderOptions) {
+		lro.withStartPosition = pos
+		lro.startPositionReadDirection = direction
+	}
+}
+
 // BacklogAllAvailable makes WithBacklog return all backlogged log data that
 // logtree possesses.
 const BacklogAllAvailable int = -1
@@ -107,7 +130,10 @@
 }
 
 var (
-	ErrRawAndLeveled = errors.New("cannot return logs that are simultaneously OnlyRaw and OnlyLeveled")
+	ErrRawAndLeveled               = errors.New("cannot return logs that are simultaneously OnlyRaw and OnlyLeveled")
+	// ErrStartPositionWithoutBacklog is returned by Read when WithStartPosition
+	// is passed without also passing WithBacklog.
+	ErrStartPositionWithoutBacklog = errors.New("cannot use WithStartPosition without WithBacklog")
 )
 
 // Read and/or stream entries from a LogTree. The returned LogReader is influenced
@@ -121,6 +145,7 @@
 
 	lro := logReaderOptions{
 		withStreamBufferSize: 128,
+		withStartPosition:    -1,
 	}
 
 	for _, opt := range opts {
@@ -131,7 +156,15 @@
 		return nil, ErrRawAndLeveled
 	}
 
+	isWithBacklog := lro.withBacklog > 0 || lro.withBacklog == BacklogAllAvailable
+	if lro.withStartPosition != -1 && !isWithBacklog {
+		return nil, ErrStartPositionWithoutBacklog
+	}
+
 	var filters []filter
+	if lro.withStartPosition != -1 {
+		filters = append(filters, filterStartPosition(lro.withBacklog, lro.withStartPosition, lro.startPositionReadDirection))
+	}
 	if lro.onlyLeveled {
 		filters = append(filters, filterOnlyLeveled)
 	}
@@ -148,7 +181,7 @@
 	}
 
 	var entries []*entry
-	if lro.withBacklog > 0 || lro.withBacklog == BacklogAllAvailable {
+	if isWithBacklog {
 		if lro.withChildren {
 			entries = l.journal.scanEntries(lro.withBacklog, filters...)
 		} else {
diff --git a/osbase/logtree/logtree_access_test.go b/osbase/logtree/logtree_access_test.go
new file mode 100644
index 0000000..8379fd3
--- /dev/null
+++ b/osbase/logtree/logtree_access_test.go
@@ -0,0 +1,166 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+package logtree
+
+import (
+	"fmt"
+	"strings"
+	"testing"
+)
+
+func TestJournalStartPosition(t *testing.T) {
+	lt := New()
+
+	for i := 0; i < 100; i++ {
+		e := &entry{
+			origin:  "main",
+			leveled: testPayload(fmt.Sprintf("test %d", i)),
+		}
+		lt.journal.append(e)
+	}
+
+	type tCase struct {
+		name string
+
+		count     int
+		direction ReadDirection
+		pos       int
+
+		expectedCount int
+		expectedFirst string
+		expectedLast  string
+	}
+
+	for _, tc := range []tCase{
+		{
+			name:      "fetch all before id 0",
+			count:     BacklogAllAvailable,
+			direction: ReadDirectionBefore,
+			pos:       0,
+
+			expectedCount: 0,
+			expectedFirst: "UNREACHABLE",
+			expectedLast:  "UNREACHABLE",
+		},
+		{
+			name:      "fetch all after id 0",
+			count:     BacklogAllAvailable,
+			direction: ReadDirectionAfter,
+			pos:       0,
+
+			expectedCount: 100,
+			expectedFirst: "test 0",
+			expectedLast:  "test 99",
+		},
+
+		{
+			name:      "fetch all before id 10",
+			count:     BacklogAllAvailable,
+			direction: ReadDirectionBefore,
+			pos:       10,
+
+			expectedCount: 10,
+			expectedFirst: "test 0",
+			expectedLast:  "test 9",
+		},
+		{
+			name:      "fetch all after id 10",
+			count:     BacklogAllAvailable,
+			direction: ReadDirectionAfter,
+			pos:       10,
+
+			expectedCount: 90,
+			expectedFirst: "test 10",
+			expectedLast:  "test 99",
+		},
+
+		{
+			name:      "fetch 10 before id 0",
+			count:     10,
+			direction: ReadDirectionBefore,
+			pos:       0,
+
+			expectedCount: 0,
+			expectedFirst: "UNREACHABLE",
+			expectedLast:  "UNREACHABLE",
+		},
+		{
+			name:      "fetch 10 after id 0",
+			count:     10,
+			direction: ReadDirectionAfter,
+			pos:       0,
+
+			expectedCount: 10,
+			expectedFirst: "test 0",
+			expectedLast:  "test 9",
+		},
+
+		{
+			name:      "fetch 10 before id 3",
+			count:     10,
+			direction: ReadDirectionBefore,
+			pos:       3,
+
+			expectedCount: 3,
+			expectedFirst: "test 0",
+			expectedLast:  "test 2",
+		},
+		{
+			name:      "fetch 10 after id 3",
+			count:     10,
+			direction: ReadDirectionAfter,
+			pos:       3,
+
+			expectedCount: 10,
+			expectedFirst: "test 3",
+			expectedLast:  "test 12",
+		},
+		{
+			name:      "fetch 43 before id 47",
+			count:     43,
+			direction: ReadDirectionBefore,
+			pos:       47,
+
+			expectedCount: 43,
+			expectedFirst: "test 4",
+			expectedLast:  "test 46",
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			lr, err := lt.Read("main", WithBacklog(tc.count), WithStartPosition(tc.pos, tc.direction))
+			if err != nil {
+				t.Fatalf("Read: %v", err)
+			}
+			if l := len(lr.Backlog); l != tc.expectedCount {
+				t.Fatalf("expected %d entries, got %d", tc.expectedCount, l)
+			}
+			if len(lr.Backlog) == 0 {
+				// If there is nothing to test against, skip to next test.
+				return
+			}
+			if first := strings.Join(lr.Backlog[0].Leveled.messages, "\n"); first != tc.expectedFirst {
+				t.Errorf("wanted first entry %q, got %q", tc.expectedFirst, first)
+			}
+			if last := strings.Join(lr.Backlog[len(lr.Backlog)-1].Leveled.messages, "\n"); last != tc.expectedLast {
+				t.Errorf("wanted last entry %q, got %q", tc.expectedLast, last)
+			}
+
+			// Like the expected values above, this relies on "test N" having been
+			// appended at global position N: each entry's payload must match its
+			// Position, and positions of returned entries must be consecutive.
+			start := lr.Backlog[0].Position
+			for i, le := range lr.Backlog {
+				if wantPos := start + i; le.Position != wantPos {
+					t.Errorf("entry %d: wanted position %d, got %d", i, wantPos, le.Position)
+				}
+
+				want := fmt.Sprintf("test %d", le.Position)
+				got := strings.Join(le.Leveled.messages, "\n")
+				if want != got {
+					t.Errorf("wanted entry %q, got %q", want, got)
+				}
+			}
+		})
+	}
+}
diff --git a/osbase/logtree/logtree_entry.go b/osbase/logtree/logtree_entry.go
index 833081f..a434d08 100644
--- a/osbase/logtree/logtree_entry.go
+++ b/osbase/logtree/logtree_entry.go
@@ -23,6 +23,9 @@
 	Raw *logbuffer.Line
 	// DN from which this entry was logged.
 	DN DN
+	// Position of this entry in the global journal. This is only available
+	// locally; entries obtained via protobuf leave it at 0, a valid position.
+	Position int
 }
 
 // String returns a canonical representation of this payload as a single string