m/pkg/logtree: fix exact backlog fetch, head/tail confusion

This started off as 'hm, the backlog data returned seems wrong'. I
realized we had no test for that, so I added one. It was indeed broken.

This was because we had two simultaneous bugs: we confused head/tail
between docs and different parts of the code, and we forgot to do a
reverse operation when scanning/retrieving journal entries.

With those two fixed, we also implement backlog retrieval in an
optimized fashion, by not scanning/retrieving more entries than is
necessary.

Finally, we drive-by fix a massacred ASCII graphic in a comment.

Change-Id: I2ec5dd9b5b58f66fbc015c142feb91bd92038e4f
Reviewed-on: https://review.monogon.dev/c/monogon/+/1430
Tested-by: Jenkins CI
Reviewed-by: Leopold Schabel <leo@monogon.tech>
diff --git a/metropolis/pkg/logtree/journal.go b/metropolis/pkg/logtree/journal.go
index 5df6e1a..412c042 100644
--- a/metropolis/pkg/logtree/journal.go
+++ b/metropolis/pkg/logtree/journal.go
@@ -18,6 +18,7 @@
 
 import (
 	"errors"
+	"sort"
 	"strings"
 	"sync"
 )
@@ -59,25 +60,21 @@
 // represented by heads[DN]/tails[DN] pointers in journal and nextLocal/prevLocal
 // pointers in entries:
 //
-//	.------------.        .------------.        .------------.
-//	| dn: A.B    |        | dn: Z      |        | dn: A.B    |
-//	| time: 1    |        | time: 2    |        | time: 3    |
-//	|------------|        |------------|        |------------|
-//	| nextGlobal :------->| nextGlobal :------->| nextGlobal :--> nil
-//
-// nil <-: prevGlobal |<-------: prevGlobal |<-------| prevGlobal |
-//
-//	|------------|        |------------|  n     |------------|
-//	| nextLocal  :---. n  | nextLocal  :->i .-->| nextLocal  :--> nil
-//
-// nil <-: prevLocal  |<--: i<-: prevLocal  |  l :---| prevLocal  |
-//
-//	 '------------'   | l  '------------'    |   '------------'
-//	      ^           '----------------------'         ^
-//	      |                      ^                     |
-//	      |                      |                     |
-//	   ( head )             ( tails[Z] )            ( tail )
-//	( heads[A.B] )          ( heads[Z] )         ( tails[A.B] )
+//	      .------------.        .------------.        .------------.
+//	      | dn: A.B    |        | dn: Z      |        | dn: A.B    |
+//	      | time: 1    |        | time: 2    |        | time: 3    |
+//	      |------------|        |------------|        |------------|
+//	      | nextGlobal :------->| nextGlobal :------->| nextGlobal :--> nil
+//	nil <-: prevGlobal |<-------: prevGlobal |<-------| prevGlobal |
+//	      |------------|        |------------|  n     |------------|
+//	      | nextLocal  :---. n  | nextLocal  :->i .-->| nextLocal  :--> nil
+//	nil <-: prevLocal  |<--: i<-: prevLocal  |  l :---| prevLocal  |
+//	      '------------'   | l  '------------'    |   '------------'
+//	           ^           '----------------------'         ^
+//	           |                      ^                     |
+//	           |                      |                     |
+//	        ( head )             ( tails[Z] )            ( tail )
+//	     ( heads[A.B] )          ( heads[Z] )         ( tails[A.B] )
 type journal struct {
 	// mu locks the rest of the structure. It must be taken during any operation on the
 	// journal.
@@ -188,11 +185,11 @@
 // getEntries should be used instead, as it will leverage DN-local linked lists to
 // retrieve them faster. journal.mu must be taken at R or RW level when calling
 // this function.
-func (j *journal) scanEntries(filters ...filter) (res []*entry) {
+func (j *journal) scanEntries(count int, filters ...filter) (res []*entry) {
 	cur := j.tail
 	for {
 		if cur == nil {
-			return
+			break
 		}
 
 		passed := true
@@ -205,8 +202,17 @@
 		if passed {
 			res = append(res, cur)
 		}
-		cur = cur.nextGlobal
+		if count != BacklogAllAvailable && len(res) >= count {
+			break
+		}
+		cur = cur.prevGlobal
 	}
+
+	// Reverse entries back into chronological order.
+	sort.SliceStable(res, func(i, j int) bool {
+		return i > j
+	})
+	return
 }
 
 // getEntries returns all entries at a given DN. This is faster than a
@@ -215,11 +221,11 @@
 // entries returned, but a scan through this DN's local linked list will be
 // performed regardless. journal.mu must be taken at R or RW level when calling
 // this function.
-func (j *journal) getEntries(exact DN, filters ...filter) (res []*entry) {
+func (j *journal) getEntries(count int, exact DN, filters ...filter) (res []*entry) {
 	cur := j.tails[exact]
 	for {
 		if cur == nil {
-			return
+			break
 		}
 
 		passed := true
@@ -232,9 +238,17 @@
 		if passed {
 			res = append(res, cur)
 		}
-		cur = cur.nextLocal
+		if count != BacklogAllAvailable && len(res) >= count {
+			break
+		}
+		cur = cur.prevLocal
 	}
 
+	// Reverse entries back into chronological order.
+	sort.SliceStable(res, func(i, j int) bool {
+		return i > j
+	})
+	return
 }
 
 // Shorten returns a shortened version of this DN for constrained logging
diff --git a/metropolis/pkg/logtree/journal_entry.go b/metropolis/pkg/logtree/journal_entry.go
index d51d406..1580f54 100644
--- a/metropolis/pkg/logtree/journal_entry.go
+++ b/metropolis/pkg/logtree/journal_entry.go
@@ -81,10 +81,10 @@
 	}
 	// Update journal head/tail pointers.
 	if e.journal.head == e {
-		e.journal.head = e.prevGlobal
+		e.journal.head = e.nextGlobal
 	}
 	if e.journal.tail == e {
-		e.journal.tail = e.nextGlobal
+		e.journal.tail = e.prevGlobal
 	}
 
 	// Unlink from the local linked list.
@@ -96,10 +96,10 @@
 	}
 	// Update journal head/tail pointers.
 	if e.journal.heads[e.origin] == e {
-		e.journal.heads[e.origin] = e.prevLocal
+		e.journal.heads[e.origin] = e.nextLocal
 	}
 	if e.journal.tails[e.origin] == e {
-		e.journal.tails[e.origin] = e.nextLocal
+		e.journal.tails[e.origin] = e.prevLocal
 	}
 }
 
@@ -121,13 +121,13 @@
 
 	// Insert at head in global linked list, set pointers.
 	e.nextGlobal = nil
-	e.prevGlobal = j.head
-	if j.head != nil {
-		j.head.nextGlobal = e
+	e.prevGlobal = j.tail
+	if j.tail != nil {
+		j.tail.nextGlobal = e
 	}
-	j.head = e
-	if j.tail == nil {
-		j.tail = e
+	j.tail = e
+	if j.head == nil {
+		j.head = e
 	}
 
 	// Create quota if necessary.
@@ -137,27 +137,27 @@
 
 	// Insert at head in local linked list, calculate seqLocal, set pointers.
 	e.nextLocal = nil
-	e.prevLocal = j.heads[e.origin]
-	if j.heads[e.origin] != nil {
-		j.heads[e.origin].nextLocal = e
+	e.prevLocal = j.tails[e.origin]
+	if j.tails[e.origin] != nil {
+		j.tails[e.origin].nextLocal = e
 		e.seqLocal = e.prevLocal.seqLocal + 1
 	} else {
 		e.seqLocal = 0
 	}
-	j.heads[e.origin] = e
-	if j.tails[e.origin] == nil {
-		j.tails[e.origin] = e
+	j.tails[e.origin] = e
+	if j.heads[e.origin] == nil {
+		j.heads[e.origin] = e
 	}
 
 	// Apply quota to the local linked list that this entry got inserted to, ie. remove
 	// elements in excess of the quota.max count.
 	quota := j.quota[e.origin]
-	count := (j.heads[e.origin].seqLocal - j.tails[e.origin].seqLocal) + 1
+	count := (j.tails[e.origin].seqLocal - j.heads[e.origin].seqLocal) + 1
 	if count > quota.max {
-		// Keep popping elements off the tail of the local linked list until quota is not
+		// Keep popping elements off the head of the local linked list until quota is not
 		// violated.
 		left := count - quota.max
-		cur := j.tails[e.origin]
+		cur := j.heads[e.origin]
 		for {
 			// This shouldn't happen if quota.max >= 1.
 			if cur == nil {
diff --git a/metropolis/pkg/logtree/journal_test.go b/metropolis/pkg/logtree/journal_test.go
index 1df3f12..e9fc3b4 100644
--- a/metropolis/pkg/logtree/journal_test.go
+++ b/metropolis/pkg/logtree/journal_test.go
@@ -44,7 +44,7 @@
 		j.append(e)
 	}
 
-	entries := j.getEntries("main")
+	entries := j.getEntries(BacklogAllAvailable, "main")
 	if want, got := 8192, len(entries); want != got {
 		t.Fatalf("wanted %d entries, got %d", want, got)
 	}
@@ -73,20 +73,20 @@
 		}
 	}
 
-	entries := j.getEntries("chatty")
+	entries := j.getEntries(BacklogAllAvailable, "chatty")
 	if want, got := 8192, len(entries); want != got {
 		t.Fatalf("wanted %d chatty entries, got %d", want, got)
 	}
-	entries = j.getEntries("solemn")
+	entries = j.getEntries(BacklogAllAvailable, "solemn")
 	if want, got := 900, len(entries); want != got {
 		t.Fatalf("wanted %d solemn entries, got %d", want, got)
 	}
-	entries = j.getEntries("absent")
+	entries = j.getEntries(BacklogAllAvailable, "absent")
 	if want, got := 0, len(entries); want != got {
 		t.Fatalf("wanted %d absent entries, got %d", want, got)
 	}
 
-	entries = j.scanEntries(filterAll())
+	entries = j.scanEntries(BacklogAllAvailable, filterAll())
 	if want, got := 8192+900, len(entries); want != got {
 		t.Fatalf("wanted %d total entries, got %d", want, got)
 	}
@@ -119,7 +119,7 @@
 	j.append(&entry{origin: "e.g", leveled: testPayload("e.g")})
 
 	expect := func(f filter, msgs ...string) string {
-		res := j.scanEntries(f)
+		res := j.scanEntries(BacklogAllAvailable, f)
 		set := make(map[string]bool)
 		for _, entry := range res {
 			set[strings.Join(entry.leveled.messages, "\n")] = true
diff --git a/metropolis/pkg/logtree/logtree_access.go b/metropolis/pkg/logtree/logtree_access.go
index 1babe1e..b601ea4 100644
--- a/metropolis/pkg/logtree/logtree_access.go
+++ b/metropolis/pkg/logtree/logtree_access.go
@@ -157,14 +157,10 @@
 
 	var entries []*entry
 	if backlog > 0 || backlog == BacklogAllAvailable {
-		// TODO(q3k): pass over the backlog count to scanEntries/getEntries, instead of discarding them here.
 		if recursive {
-			entries = l.journal.scanEntries(filters...)
+			entries = l.journal.scanEntries(backlog, filters...)
 		} else {
-			entries = l.journal.getEntries(dn, filters...)
-		}
-		if backlog != BacklogAllAvailable && len(entries) > backlog {
-			entries = entries[:backlog]
+			entries = l.journal.getEntries(backlog, dn, filters...)
 		}
 	}
 
diff --git a/metropolis/pkg/logtree/logtree_test.go b/metropolis/pkg/logtree/logtree_test.go
index a7614a4..e37893a 100644
--- a/metropolis/pkg/logtree/logtree_test.go
+++ b/metropolis/pkg/logtree/logtree_test.go
@@ -29,6 +29,7 @@
 	if err != nil {
 		t.Fatalf("Read: %v", err)
 	}
+	defer res.Close()
 	if want, got := len(entries), len(res.Backlog); want != got {
 		t.Fatalf("wanted %v backlog entries, got %v", want, got)
 	}
@@ -49,6 +50,29 @@
 	return ""
 }
 
+func readBacklog(tree *LogTree, t *testing.T, dn DN, backlog int, recursive bool) []string {
+	t.Helper()
+	opts := []LogReadOption{
+		WithBacklog(backlog),
+	}
+	if recursive {
+		opts = append(opts, WithChildren())
+	}
+	res, err := tree.Read(dn, opts...)
+	if err != nil {
+		t.Fatalf("Read: %v", err)
+	}
+	defer res.Close()
+
+	var lines []string
+	for _, e := range res.Backlog {
+		for _, msg := range e.Leveled.Messages() {
+			lines = append(lines, msg)
+		}
+	}
+	return lines
+}
+
 func TestMultiline(t *testing.T) {
 	tree := New()
 	// Two lines in a single message.
@@ -61,7 +85,7 @@
 	}
 }
 
-func TestBacklog(t *testing.T) {
+func TestBacklogAll(t *testing.T) {
 	tree := New()
 	tree.MustLeveledFor("main").Info("hello, main!")
 	tree.MustLeveledFor("main.foo").Info("hello, main.foo!")
@@ -81,6 +105,44 @@
 	}
 }
 
+func TestBacklogExact(t *testing.T) {
+	tree := New()
+	tree.MustLeveledFor("main").Info("hello, main!")
+	tree.MustLeveledFor("main.foo").Info("hello, main.foo!")
+	tree.MustLeveledFor("main.bar").Info("hello, main.bar!")
+	tree.MustLeveledFor("main.bar.chatty").Info("hey there how are you")
+	tree.MustLeveledFor("main.bar.quiet").Info("fine how are you")
+	tree.MustLeveledFor("main.bar.chatty").Info("i've been alright myself")
+	tree.MustLeveledFor("main.bar.chatty").Info("but to tell you honestly...")
+	tree.MustLeveledFor("main.bar.chatty").Info("i feel like i'm stuck?")
+	tree.MustLeveledFor("main.bar.quiet").Info("mhm")
+	tree.MustLeveledFor("main.bar.chatty").Info("like you know what i'm saying, stuck in like")
+	tree.MustLeveledFor("main.bar.chatty").Info("like a go test?")
+	tree.MustLeveledFor("main.bar.quiet").Info("yeah totally")
+	tree.MustLeveledFor("main.bar.chatty").Info("it's hard to put my finger on it")
+	tree.MustLeveledFor("main.bar.chatty").Info("anyway, how's the wife doing?")
+
+	check := func(a []string, b ...string) {
+		t.Helper()
+		if len(a) != len(b) {
+			t.Errorf("Legth mismatch: wanted %d, got %d", len(b), len(a))
+		}
+		count := len(a)
+		if len(b) < count {
+			count = len(b)
+		}
+		for i := 0; i < count; i++ {
+			if want, got := b[i], a[i]; want != got {
+				t.Errorf("Message %d: wanted %q, got %q", i, want, got)
+			}
+		}
+	}
+
+	check(readBacklog(tree, t, "main", 3, true), "yeah totally", "it's hard to put my finger on it", "anyway, how's the wife doing?")
+	check(readBacklog(tree, t, "main.foo", 3, false), "hello, main.foo!")
+	check(readBacklog(tree, t, "main.bar.quiet", 2, true), "mhm", "yeah totally")
+}
+
 func TestStream(t *testing.T) {
 	tree := New()
 	tree.MustLeveledFor("main").Info("hello, backlog")