osbase/net/dns/forward: add DNS forward handler
This adds a DNS server handler for forwarding queries to upstream DNS
resolvers, with a built-in cache. The implementation is partially based
on CoreDNS. The proxy, cache and up packages are only lightly modified.
The forward package itself, however, is mostly new code. Unlike CoreDNS,
it supports changing upstreams at runtime, and has integrated caching
and answer order randomization.
Some improvements over CoreDNS:
- Concurrent identical queries only result in one upstream query.
- In case of errors, Extended DNS Errors are added to replies.
- Very large replies are not stored in the cache to avoid using too much
memory.
Change-Id: I42294ae4997d621a6e55c98e46a04874eab75c99
Reviewed-on: https://review.monogon.dev/c/monogon/+/3258
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
Tested-by: Jenkins CI
diff --git a/osbase/net/dns/forward/cache/BUILD.bazel b/osbase/net/dns/forward/cache/BUILD.bazel
new file mode 100644
index 0000000..3d91fa2
--- /dev/null
+++ b/osbase/net/dns/forward/cache/BUILD.bazel
@@ -0,0 +1,18 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

# Sharded in-memory cache used by the DNS forward handler. Visibility is
# restricted to //osbase/net/dns/forward and its subpackages.
go_library(
    name = "cache",
    srcs = ["cache.go"],
    importpath = "source.monogon.dev/osbase/net/dns/forward/cache",
    visibility = ["//osbase/net/dns/forward:__subpackages__"],
    # Needed for cpu.CacheLinePad, which pads each shard.
    deps = ["@org_golang_x_sys//cpu"],
)

# Unit tests and benchmarks for both the Cache and its per-shard logic.
go_test(
    name = "cache_test",
    srcs = [
        "cache_test.go",
        "shard_test.go",
    ],
    embed = [":cache"],
)
diff --git a/osbase/net/dns/forward/cache/cache.go b/osbase/net/dns/forward/cache/cache.go
new file mode 100644
index 0000000..8c151a6
--- /dev/null
+++ b/osbase/net/dns/forward/cache/cache.go
@@ -0,0 +1,115 @@
+// Package cache implements a sharded in-memory cache. The cache holds 256
+// shards; each shard holds a map guarded by a mutex. There is no fancy
+// eviction algorithm: a shard simply evicts a random element when it is full.
+package cache
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "sync"
+
+ "golang.org/x/sys/cpu"
+)
+
// shardSize is the number of shards in a Cache (not the capacity of a single
// shard). Keys are assigned to shards by key % shardSize.
const shardSize = 256
+
// Cache is a bounded map from uint64 keys to elements of type T that is safe
// for concurrent use. It is split into shardSize independently locked shards;
// a key lives in shard key % shardSize. Create a Cache with New: the zero
// value is not usable, because its shard maps are nil and Put would panic.
type Cache[T any] struct {
	shards [shardSize]shard[T]
}
+
// shard is a cache with random eviction: a bounded map guarded by its own
// lock. It forms one partition of a Cache.
type shard[T any] struct {
	// items holds the cached elements, keyed by the caller-supplied key.
	items map[uint64]T
	// size is the maximum number of entries in items. Inserting a new key
	// into a full shard first evicts one arbitrary existing entry.
	size int

	// RWMutex guards items. Get takes the read lock; Put, GetOrPut and
	// Remove take the write lock.
	sync.RWMutex

	// Padding so that the fields of adjacent shards in Cache.shards do not
	// share a CPU cache line. This avoids false sharing between shard locks.
	_ cpu.CacheLinePad
}
+
+// New returns a new cache.
+func New[T any](size int) *Cache[T] {
+ ssize := size / shardSize
+ if ssize < 4 {
+ ssize = 4
+ }
+
+ c := &Cache[T]{}
+
+ // Initialize all the shards
+ for i := 0; i < shardSize; i++ {
+ c.shards[i] = shard[T]{items: make(map[uint64]T), size: ssize}
+ }
+ return c
+}
+
+// Get returns the element under key, and whether the element exists.
+func (c *Cache[T]) Get(key uint64) (el T, exists bool) {
+ shard := key % shardSize
+ return c.shards[shard].Get(key)
+}
+
+// Put adds a new element to the cache. If the element already exists,
+// it is overwritten.
+func (c *Cache[T]) Put(key uint64, el T) {
+ shard := key % shardSize
+ c.shards[shard].Put(key, el)
+}
+
+// GetOrPut returns the element under key if it exists,
+// or else stores newEl there. This operation is atomic.
+func (c *Cache[T]) GetOrPut(key uint64, newEl T) (el T, exists bool) {
+ shard := key % shardSize
+ return c.shards[shard].GetOrPut(key, newEl)
+}
+
+// Remove removes the element indexed with key.
+func (c *Cache[T]) Remove(key uint64) {
+ shard := key % shardSize
+ c.shards[shard].Remove(key)
+}
+
+func (s *shard[T]) Get(key uint64) (el T, exists bool) {
+ s.RLock()
+ el, exists = s.items[key]
+ s.RUnlock()
+ return
+}
+
+func (s *shard[T]) Put(key uint64, el T) {
+ s.Lock()
+ if len(s.items) >= s.size {
+ if _, ok := s.items[key]; !ok {
+ for k := range s.items {
+ delete(s.items, k)
+ break
+ }
+ }
+ }
+ s.items[key] = el
+ s.Unlock()
+}
+
+func (s *shard[T]) GetOrPut(key uint64, newEl T) (el T, exists bool) {
+ s.Lock()
+ el, exists = s.items[key]
+ if !exists {
+ if len(s.items) >= s.size {
+ for k := range s.items {
+ delete(s.items, k)
+ break
+ }
+ }
+ s.items[key] = newEl
+ el = newEl
+ }
+ s.Unlock()
+ return
+}
+
+func (s *shard[T]) Remove(key uint64) {
+ s.Lock()
+ delete(s.items, key)
+ s.Unlock()
+}
diff --git a/osbase/net/dns/forward/cache/cache_test.go b/osbase/net/dns/forward/cache/cache_test.go
new file mode 100644
index 0000000..8a2b80c
--- /dev/null
+++ b/osbase/net/dns/forward/cache/cache_test.go
@@ -0,0 +1,39 @@
+package cache
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "testing"
+)
+
+func TestCacheAddAndGet(t *testing.T) {
+ const N = shardSize * 4
+ c := New[int](N)
+ c.Put(1, 1)
+
+ if _, found := c.Get(1); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+}
+
+func TestCacheSharding(t *testing.T) {
+ c := New[int](shardSize)
+ for i := 0; i < shardSize*2; i++ {
+ c.Put(uint64(i), 1)
+ }
+ for i := range c.shards {
+ if len(c.shards[i].items) == 0 {
+ t.Errorf("Failed to populate shard: %d", i)
+ }
+ }
+}
+
+func BenchmarkCache(b *testing.B) {
+ b.ReportAllocs()
+
+ c := New[int](4)
+ for n := 0; n < b.N; n++ {
+ c.Put(1, 1)
+ c.Get(1)
+ }
+}
diff --git a/osbase/net/dns/forward/cache/shard_test.go b/osbase/net/dns/forward/cache/shard_test.go
new file mode 100644
index 0000000..452f8cc
--- /dev/null
+++ b/osbase/net/dns/forward/cache/shard_test.go
@@ -0,0 +1,155 @@
+package cache
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "testing"
+)
+
+// newShard returns a new shard with size.
+func newShard[T any](size int) *shard[T] { return &shard[T]{items: make(map[uint64]T), size: size} }
+
+func TestShardAddAndGet(t *testing.T) {
+ s := newShard[int](1)
+ s.Put(1, 1)
+
+ if _, found := s.Get(1); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+
+ s.Put(2, 1)
+ if _, found := s.Get(1); found {
+ t.Fatal("Failed to evict record")
+ }
+ if _, found := s.Get(2); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+}
+
+func TestGetOrPut(t *testing.T) {
+ s := newShard[int](1)
+ el, exists := s.GetOrPut(1, 2)
+ if exists {
+ t.Fatalf("Element should not have existed")
+ }
+ if el != 2 {
+ t.Fatalf("Expected element %d, got %d ", 2, el)
+ }
+
+ el, exists = s.GetOrPut(1, 3)
+ if !exists {
+ t.Fatalf("Element should have existed")
+ }
+ if el != 2 {
+ t.Fatalf("Expected element %d, got %d ", 2, el)
+ }
+}
+
+func TestShardRemove(t *testing.T) {
+ s := newShard[int](2)
+ s.Put(1, 1)
+ s.Put(2, 2)
+
+ s.Remove(1)
+
+ if _, found := s.Get(1); found {
+ t.Fatal("Failed to remove record")
+ }
+ if _, found := s.Get(2); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+}
+
+func TestAddEvict(t *testing.T) {
+ const size = 1024
+ s := newShard[int](size)
+
+ for i := uint64(0); i < size; i++ {
+ s.Put(i, 1)
+ }
+ for i := uint64(0); i < size; i++ {
+ s.Put(i, 1)
+ if len(s.items) != size {
+ t.Fatal("A item was unnecessarily evicted from the cache")
+ }
+ }
+}
+
+func TestShardLen(t *testing.T) {
+ s := newShard[int](4)
+
+ s.Put(1, 1)
+ if l := len(s.items); l != 1 {
+ t.Fatalf("Shard size should %d, got %d", 1, l)
+ }
+
+ s.Put(1, 1)
+ if l := len(s.items); l != 1 {
+ t.Fatalf("Shard size should %d, got %d", 1, l)
+ }
+
+ s.Put(2, 2)
+ if l := len(s.items); l != 2 {
+ t.Fatalf("Shard size should %d, got %d", 2, l)
+ }
+}
+
+func TestShardEvict(t *testing.T) {
+ s := newShard[int](1)
+ s.Put(1, 1)
+ s.Put(2, 2)
+ // 1 should be gone
+
+ if _, found := s.Get(1); found {
+ t.Fatal("Found item that should have been evicted")
+ }
+}
+
+func TestShardLenEvict(t *testing.T) {
+ s := newShard[int](4)
+ s.Put(1, 1)
+ s.Put(2, 1)
+ s.Put(3, 1)
+ s.Put(4, 1)
+
+ if l := len(s.items); l != 4 {
+ t.Fatalf("Shard size should %d, got %d", 4, l)
+ }
+
+ // This should evict one element
+ s.Put(5, 1)
+ if l := len(s.items); l != 4 {
+ t.Fatalf("Shard size should %d, got %d", 4, l)
+ }
+
+ // Make sure we don't accidentally evict an element when
+ // we the key is already stored.
+ for i := 0; i < 4; i++ {
+ s.Put(5, 1)
+ if l := len(s.items); l != 4 {
+ t.Fatalf("Shard size should %d, got %d", 4, l)
+ }
+ }
+}
+
+func BenchmarkShard(b *testing.B) {
+ s := newShard[int](shardSize)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ k := uint64(i) % shardSize * 2
+ s.Put(k, 1)
+ s.Get(k)
+ }
+}
+
+func BenchmarkShardParallel(b *testing.B) {
+ s := newShard[int](shardSize)
+ b.ResetTimer()
+ b.RunParallel(func(pb *testing.PB) {
+ for i := uint64(0); pb.Next(); i++ {
+ k := i % shardSize * 2
+ s.Put(k, 1)
+ s.Get(k)
+ }
+ })
+}