osbase/net/dns/forward: add DNS forward handler
This adds a DNS server handler for forwarding queries to upstream DNS
resolvers, with a built-in cache. The implementation is partially based
on CoreDNS. The proxy, cache and up packages are only lightly modified.
The forward package itself, however, is mostly new code. Unlike CoreDNS,
it supports changing upstreams at runtime and has integrated caching
and answer order randomization.
Some improvements over CoreDNS:
- Concurrent identical queries only result in one upstream query.
- In case of errors, Extended DNS Errors are added to replies.
- Very large replies are not stored in the cache to avoid using too much
memory.
Change-Id: I42294ae4997d621a6e55c98e46a04874eab75c99
Reviewed-on: https://review.monogon.dev/c/monogon/+/3258
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
Tested-by: Jenkins CI
diff --git a/osbase/net/dns/forward/BUILD.bazel b/osbase/net/dns/forward/BUILD.bazel
new file mode 100644
index 0000000..0728ff1
--- /dev/null
+++ b/osbase/net/dns/forward/BUILD.bazel
@@ -0,0 +1,33 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "forward",
+ srcs = [
+ "cache.go",
+ "forward.go",
+ "metrics.go",
+ ],
+ importpath = "source.monogon.dev/osbase/net/dns/forward",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//osbase/event/memory",
+ "//osbase/net/dns",
+ "//osbase/net/dns/forward/cache",
+ "//osbase/net/dns/forward/proxy",
+ "//osbase/supervisor",
+ "@com_github_miekg_dns//:dns",
+ "@com_github_prometheus_client_golang//prometheus",
+ ],
+)
+
+go_test(
+ name = "forward_test",
+ srcs = ["forward_test.go"],
+ embed = [":forward"],
+ deps = [
+ "//osbase/net/dns",
+ "//osbase/net/dns/test",
+ "//osbase/supervisor",
+ "@com_github_miekg_dns//:dns",
+ ],
+)
diff --git a/osbase/net/dns/forward/LICENSE-3RD-PARTY.txt b/osbase/net/dns/forward/LICENSE-3RD-PARTY.txt
new file mode 100644
index 0000000..98f9935
--- /dev/null
+++ b/osbase/net/dns/forward/LICENSE-3RD-PARTY.txt
@@ -0,0 +1,13 @@
+Copyright 2016-2024 The CoreDNS authors and contributors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/osbase/net/dns/forward/cache.go b/osbase/net/dns/forward/cache.go
new file mode 100644
index 0000000..d2422a9
--- /dev/null
+++ b/osbase/net/dns/forward/cache.go
@@ -0,0 +1,278 @@
+package forward
+
+import (
+ "hash/maphash"
+ "math/rand/v2"
+ "slices"
+ "sync"
+ "time"
+
+ "github.com/miekg/dns"
+
+ netDNS "source.monogon.dev/osbase/net/dns"
+)
+
+// The cache uses at most cacheMaxItemSize * cacheCapacity = 20 MB of memory.
+// Actual memory usage may be slightly higher due to the overhead of in-memory
+// data structures compared to the serialized, uncompressed length.
+const (
+ cacheMaxItemSize = 2048
+ cacheCapacity = 10000
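+ // Replies are cached for at least cacheMinSeconds and at most
+ // cacheMaxSeconds, further limited by the smallest TTL in the reply.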
+ cacheMinSeconds = 1
+ cacheMaxSeconds = 5
+)
+
+// cacheKey is the key used for cache lookups. Both the DNSSEC OK (DO) and the
+// Checking Disabled (CD) flags influence the reply. While it would be possible
+// always make upstream queries with DNSSEC, and then strip the authenticating
+// records if the client did not request it, this would mostly just waste
+// bandwidth. In theory, it would be possible to cache NXDOMAINs independently
+// of the QTYPE (RFC 2308, Section 5). However, the additional complexity and
+// second lookup for each query does not seem worth it.
+type cacheKey struct {
+ Name string
+ Qtype uint16
+ DNSSEC bool
+ CheckingDisabled bool
+}
+
+type cacheItem struct {
+ key cacheKey
+
+ // lock protects all fields except key. It also doubles as a way to wait for
+ // the reply. A write lock is held for as long as a query is pending.
+ lock sync.RWMutex
+
+ reply proxyReply
+ stored time.Time
+ // ttl is the number of seconds during which the cached reply can be used.
+ ttl uint32
+ // seenTruncated is true if we ever saw a truncated response for this key.
+ // We will then always use TCP when refetching after the item expires.
+ seenTruncated bool
+}
+
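+// hash computes a seeded hash of the key, for use as the cache index.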
+func (k cacheKey) hash(seed maphash.Seed) uint64 {
+ var h maphash.Hash
+ h.SetSeed(seed)
+ h.WriteByte(byte(k.Qtype))
+ h.WriteByte(byte(k.Qtype >> 8))
+ var flags byte
+ if k.DNSSEC {
+ flags += 1
+ }
+ if k.CheckingDisabled {
+ flags += 2
+ }
+ h.WriteByte(flags)
+ h.WriteString(k.Name)
+ return h.Sum64()
+}
+
+// valid returns true if the cache item can be used for this query.
+func (i *cacheItem) valid(now time.Time, tcp bool) bool {
+ expired := now.After(i.stored.Add(time.Duration(i.ttl) * time.Second))
+ return !expired && (!tcp || !i.reply.Truncated)
+}
+
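+// HandleDNS answers a request from the cache or by forwarding it to an
+// upstream resolver. Queries without the Recursion Desired flag are not
+// forwarded.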
+func (f *Forward) HandleDNS(r *netDNS.Request) {
+ if !r.Reply.RecursionDesired {
+ // Only forward queries if the RD flag is set. If the question has been
+ // redirected by CNAME, return the reply as is without following the CNAME,
+ // else set a REFUSED rcode.
+ if r.Qname == r.Reply.Question[0].Name {
+ r.Reply.Rcode = dns.RcodeRefused
+ rejectsCount.WithLabelValues("no_recursion_desired").Inc()
+ }
+ } else {
+ f.lookupOrForward(r)
+ }
+ r.SendReply()
+}
+
+func (f *Forward) lookupOrForward(r *netDNS.Request) {
+ key := cacheKey{
+ Name: r.QnameCanonical,
+ Qtype: r.Qtype,
+ DNSSEC: r.Ropt != nil && r.Ropt.Do(),
+ CheckingDisabled: r.Reply.CheckingDisabled,
+ }
+ hash := key.hash(f.seed)
+ tcp := r.Writer.RemoteAddr().Network() == "tcp"
+
+ item, exists := f.cache.Get(hash)
+ if !exists {
+ // The lookup failed; allocate a new item and try to insert it.
+ // Lock the new item before inserting it, such that concurrent queries
+ // are blocked until we receive the reply and store it in the item.
+ newItem := &cacheItem{key: key}
+ newItem.lock.Lock()
+ item, exists = f.cache.GetOrPut(hash, newItem)
+ if !exists {
+ cacheLookupsCount.WithLabelValues("miss").Inc()
+ f.forward(r, newItem, hash, tcp)
+ newItem.lock.Unlock()
+ return
+ }
+ }
+ if item.key != key {
+ // We have a hash collision. Replace the existing item.
+ cacheLookupsCount.WithLabelValues("miss").Inc()
+ newItem := &cacheItem{key: key}
+ newItem.lock.Lock()
+ f.cache.Put(hash, newItem)
+ f.forward(r, newItem, hash, tcp)
+ newItem.lock.Unlock()
+ return
+ }
+
+ // Take a read lock and check if the reply is valid for this query.
+ // This blocks if a query for this item is pending.
+ item.lock.RLock()
+ now := f.now()
+ if item.valid(now, tcp) {
+ replyFromCache(r, item, now)
+ item.lock.RUnlock()
+ return
+ }
+ item.lock.RUnlock()
+
+ item.lock.Lock()
+ now = f.now()
+ if item.valid(now, tcp) {
+ replyFromCache(r, item, now)
+ item.lock.Unlock()
+ return
+ }
+ cacheLookupsCount.WithLabelValues("refresh").Inc()
+ f.forward(r, item, hash, tcp || item.seenTruncated)
+ item.lock.Unlock()
+}
+
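+// forward sends the query to an upstream proxy and stores the reply in both
+// the request and the cache item. The caller must hold the item's write lock.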
+func (f *Forward) forward(r *netDNS.Request, item *cacheItem, hash uint64, tcp bool) {
+ // Query proxies.
+ var queryOptions []dns.EDNS0
+ if r.Qopt != nil {
+ // Forward DNSSEC algorithm understood options. These are only for
+ // statistics and must not influence the reply, so we do not need to include
+ // them in the cache key.
+ for _, option := range r.Qopt.Option {
+ switch option.(type) {
+ case *dns.EDNS0_DAU, *dns.EDNS0_DHU, *dns.EDNS0_N3U:
+ queryOptions = append(queryOptions, option)
+ }
+ }
+ }
+
+ question := dns.Question{
+ Name: item.key.Name,
+ Qtype: item.key.Qtype,
+ Qclass: dns.ClassINET,
+ }
+ reply := f.queryProxies(question, item.key.DNSSEC, item.key.CheckingDisabled, queryOptions, tcp)
+
+ r.Reply.Truncated = reply.Truncated
+ r.Reply.Rcode = reply.Rcode
+ r.Reply.Answer = appendOrClip(r.Reply.Answer, reply.Answer)
+ r.Reply.Ns = appendOrClip(r.Reply.Ns, reply.Ns)
+ r.Reply.Extra = appendOrClip(r.Reply.Extra, reply.Extra)
+ if r.Ropt != nil {
+ r.Ropt.Option = appendOrClip(r.Ropt.Option, reply.Options)
+ }
+
+ item.reply = reply
+ if reply.Truncated {
+ item.seenTruncated = true
+ }
+ item.stored = f.now()
+
+ // Compute how long to cache the item.
+ ttl := uint32(cacheMaxSeconds)
+ // If the reply is an error, or contains no TTLs, use the minimum cache time.
+ if (reply.Rcode != dns.RcodeSuccess && reply.Rcode != dns.RcodeNameError) ||
+ len(reply.Answer)+len(reply.Ns)+len(reply.Extra) == 0 {
+ ttl = cacheMinSeconds
+ }
+ for _, rr := range reply.Answer {
+ ttl = min(ttl, rr.Header().Ttl)
+ }
+ for _, rr := range reply.Ns {
+ ttl = min(ttl, rr.Header().Ttl)
+ }
+ for _, rr := range reply.Extra {
+ ttl = min(ttl, rr.Header().Ttl)
+ }
+ item.ttl = max(ttl, cacheMinSeconds)
+
+ if reply.NoStore {
+ f.cache.Remove(hash)
+ }
+}
+
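+// replyFromCache copies the cached reply into the request, decrementing TTLs
+// by the time elapsed since the reply was stored and shuffling answer RRsets.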
+func replyFromCache(r *netDNS.Request, item *cacheItem, now time.Time) {
+ cacheLookupsCount.WithLabelValues("hit").Inc()
+ decrementTtl := uint32(max(0, now.Sub(item.stored)/time.Second))
+
+ r.Reply.Truncated = item.reply.Truncated
+ r.Reply.Rcode = item.reply.Rcode
+
+ existingLen := len(r.Reply.Answer)
+ r.Reply.Answer = appendCached(r.Reply.Answer, item.reply.Answer, decrementTtl)
+ shuffleAnswer(r.Reply.Answer[existingLen:])
+ r.Reply.Ns = appendCached(r.Reply.Ns, item.reply.Ns, decrementTtl)
+ r.Reply.Extra = appendCached(r.Reply.Extra, item.reply.Extra, decrementTtl)
+ if r.Ropt != nil {
+ r.Ropt.Option = appendOrClip(r.Ropt.Option, item.reply.Options)
+ }
+}
+
+func appendCached(existing, add []dns.RR, decrementTtl uint32) []dns.RR {
+ existing = slices.Grow(existing, len(add))
+ for _, rr := range add {
+ decRR := dns.Copy(rr)
+ hdr := decRR.Header()
+ switch {
+ case hdr.Ttl == 0:
+ // Leave a TTL of 0 unchanged.
+ case decrementTtl >= hdr.Ttl:
+ // Don't decrement the TTL to 0, as that could cause problems.
+ // https://00f.net/2011/11/17/how-long-does-a-dns-ttl-last/
+ hdr.Ttl = 1
+ default:
+ hdr.Ttl -= decrementTtl
+ }
+ existing = append(existing, decRR)
+ }
+ return existing
+}
+
+// shuffleAnswer randomizes the order of consecutive RRs which are part of the
+// same RRset. This provides some load balancing.
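+//
+// For example, in an answer of the form CNAME, CNAME, A, A, A, RRSIG, only
+// the three consecutive A records are shuffled among themselves.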
+func shuffleAnswer(rrs []dns.RR) {
+ if len(rrs) < 2 {
+ return
+ }
+ startIndex := 0
+ startHdr := rrs[0].Header()
+ for i := 1; i < len(rrs); i++ {
+ hdr := rrs[i].Header()
+ sameRRset := startHdr.Rrtype == hdr.Rrtype &&
+ startHdr.Class == hdr.Class &&
+ startHdr.Name == hdr.Name
+ if sameRRset {
+ swap := startIndex + rand.IntN(i+1-startIndex)
+ rrs[i], rrs[swap] = rrs[swap], rrs[i]
+ } else {
+ startIndex = i
+ startHdr = hdr
+ }
+ }
+}
+
+// appendOrClip is similar to append(a, b...) except that it avoids allocation
+// if a is empty, in which case it returns b with any free capacity removed.
+// The resulting slice can still be appended to without affecting b.
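+//
+// For example (hypothetical values):
+//
+//  var a []dns.RR            // empty
+//  b := make([]dns.RR, 2, 8) // spare capacity
+//  c := appendOrClip(a, b)   // c aliases b's elements, but cap(c) == len(c),
+//                            // so appending to c cannot clobber b's spare slots.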
+func appendOrClip[S ~[]E, E any](a, b S) S {
+ if len(a) == 0 {
+ return slices.Clip(b)
+ }
+ return append(a, b...)
+}
diff --git a/osbase/net/dns/forward/cache/BUILD.bazel b/osbase/net/dns/forward/cache/BUILD.bazel
new file mode 100644
index 0000000..3d91fa2
--- /dev/null
+++ b/osbase/net/dns/forward/cache/BUILD.bazel
@@ -0,0 +1,18 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "cache",
+ srcs = ["cache.go"],
+ importpath = "source.monogon.dev/osbase/net/dns/forward/cache",
+ visibility = ["//osbase/net/dns/forward:__subpackages__"],
+ deps = ["@org_golang_x_sys//cpu"],
+)
+
+go_test(
+ name = "cache_test",
+ srcs = [
+ "cache_test.go",
+ "shard_test.go",
+ ],
+ embed = [":cache"],
+)
diff --git a/osbase/net/dns/forward/cache/cache.go b/osbase/net/dns/forward/cache/cache.go
new file mode 100644
index 0000000..8c151a6
--- /dev/null
+++ b/osbase/net/dns/forward/cache/cache.go
@@ -0,0 +1,115 @@
+// Package cache implements a cache. The cache holds 256 shards; each shard
+// holds a map protected by a mutex. There is no fancy eviction algorithm:
+// elements are evicted at random when a shard gets full.
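+//
+// Illustrative usage, with caller-computed uint64 keys:
+//
+//  c := New[string](10000)
+//  c.Put(42, "value")
+//  v, exists := c.Get(42)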
+package cache
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "sync"
+
+ "golang.org/x/sys/cpu"
+)
+
+const shardSize = 256
+
+// Cache is a sharded cache mapping uint64 keys to values of type T.
+type Cache[T any] struct {
+ shards [shardSize]shard[T]
+}
+
+// shard is a cache with random eviction.
+type shard[T any] struct {
+ items map[uint64]T
+ size int
+
+ sync.RWMutex
+
+ _ cpu.CacheLinePad
+}
+
+// New returns a new cache.
+func New[T any](size int) *Cache[T] {
+ ssize := size / shardSize
+ if ssize < 4 {
+ ssize = 4
+ }
+
+ c := &Cache[T]{}
+
+ // Initialize all the shards
+ for i := 0; i < shardSize; i++ {
+ c.shards[i] = shard[T]{items: make(map[uint64]T), size: ssize}
+ }
+ return c
+}
+
+// Get returns the element under key, and whether the element exists.
+func (c *Cache[T]) Get(key uint64) (el T, exists bool) {
+ shard := key % shardSize
+ return c.shards[shard].Get(key)
+}
+
+// Put adds a new element to the cache. If the element already exists,
+// it is overwritten.
+func (c *Cache[T]) Put(key uint64, el T) {
+ shard := key % shardSize
+ c.shards[shard].Put(key, el)
+}
+
+// GetOrPut returns the element under key if it exists,
+// or else stores newEl there. This operation is atomic.
+func (c *Cache[T]) GetOrPut(key uint64, newEl T) (el T, exists bool) {
+ shard := key % shardSize
+ return c.shards[shard].GetOrPut(key, newEl)
+}
+
+// Remove removes the element indexed with key.
+func (c *Cache[T]) Remove(key uint64) {
+ shard := key % shardSize
+ c.shards[shard].Remove(key)
+}
+
+func (s *shard[T]) Get(key uint64) (el T, exists bool) {
+ s.RLock()
+ el, exists = s.items[key]
+ s.RUnlock()
+ return
+}
+
+func (s *shard[T]) Put(key uint64, el T) {
+ s.Lock()
+ if len(s.items) >= s.size {
+ if _, ok := s.items[key]; !ok {
+ for k := range s.items {
+ delete(s.items, k)
+ break
+ }
+ }
+ }
+ s.items[key] = el
+ s.Unlock()
+}
+
+func (s *shard[T]) GetOrPut(key uint64, newEl T) (el T, exists bool) {
+ s.Lock()
+ el, exists = s.items[key]
+ if !exists {
+ if len(s.items) >= s.size {
+ for k := range s.items {
+ delete(s.items, k)
+ break
+ }
+ }
+ s.items[key] = newEl
+ el = newEl
+ }
+ s.Unlock()
+ return
+}
+
+func (s *shard[T]) Remove(key uint64) {
+ s.Lock()
+ delete(s.items, key)
+ s.Unlock()
+}
diff --git a/osbase/net/dns/forward/cache/cache_test.go b/osbase/net/dns/forward/cache/cache_test.go
new file mode 100644
index 0000000..8a2b80c
--- /dev/null
+++ b/osbase/net/dns/forward/cache/cache_test.go
@@ -0,0 +1,39 @@
+package cache
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "testing"
+)
+
+func TestCacheAddAndGet(t *testing.T) {
+ const N = shardSize * 4
+ c := New[int](N)
+ c.Put(1, 1)
+
+ if _, found := c.Get(1); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+}
+
+func TestCacheSharding(t *testing.T) {
+ c := New[int](shardSize)
+ for i := 0; i < shardSize*2; i++ {
+ c.Put(uint64(i), 1)
+ }
+ for i := range c.shards {
+ if len(c.shards[i].items) == 0 {
+ t.Errorf("Failed to populate shard: %d", i)
+ }
+ }
+}
+
+func BenchmarkCache(b *testing.B) {
+ b.ReportAllocs()
+
+ c := New[int](4)
+ for n := 0; n < b.N; n++ {
+ c.Put(1, 1)
+ c.Get(1)
+ }
+}
diff --git a/osbase/net/dns/forward/cache/shard_test.go b/osbase/net/dns/forward/cache/shard_test.go
new file mode 100644
index 0000000..452f8cc
--- /dev/null
+++ b/osbase/net/dns/forward/cache/shard_test.go
@@ -0,0 +1,155 @@
+package cache
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "testing"
+)
+
+// newShard returns a new shard with the given size.
+func newShard[T any](size int) *shard[T] { return &shard[T]{items: make(map[uint64]T), size: size} }
+
+func TestShardAddAndGet(t *testing.T) {
+ s := newShard[int](1)
+ s.Put(1, 1)
+
+ if _, found := s.Get(1); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+
+ s.Put(2, 1)
+ if _, found := s.Get(1); found {
+ t.Fatal("Failed to evict record")
+ }
+ if _, found := s.Get(2); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+}
+
+func TestGetOrPut(t *testing.T) {
+ s := newShard[int](1)
+ el, exists := s.GetOrPut(1, 2)
+ if exists {
+ t.Fatalf("Element should not have existed")
+ }
+ if el != 2 {
+ t.Fatalf("Expected element %d, got %d ", 2, el)
+ }
+
+ el, exists = s.GetOrPut(1, 3)
+ if !exists {
+ t.Fatalf("Element should have existed")
+ }
+ if el != 2 {
+ t.Fatalf("Expected element %d, got %d ", 2, el)
+ }
+}
+
+func TestShardRemove(t *testing.T) {
+ s := newShard[int](2)
+ s.Put(1, 1)
+ s.Put(2, 2)
+
+ s.Remove(1)
+
+ if _, found := s.Get(1); found {
+ t.Fatal("Failed to remove record")
+ }
+ if _, found := s.Get(2); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+}
+
+func TestAddEvict(t *testing.T) {
+ const size = 1024
+ s := newShard[int](size)
+
+ for i := uint64(0); i < size; i++ {
+ s.Put(i, 1)
+ }
+ for i := uint64(0); i < size; i++ {
+ s.Put(i, 1)
+ if len(s.items) != size {
+ t.Fatal("A item was unnecessarily evicted from the cache")
+ }
+ }
+}
+
+func TestShardLen(t *testing.T) {
+ s := newShard[int](4)
+
+ s.Put(1, 1)
+ if l := len(s.items); l != 1 {
+ t.Fatalf("Shard size should %d, got %d", 1, l)
+ }
+
+ s.Put(1, 1)
+ if l := len(s.items); l != 1 {
+ t.Fatalf("Shard size should %d, got %d", 1, l)
+ }
+
+ s.Put(2, 2)
+ if l := len(s.items); l != 2 {
+ t.Fatalf("Shard size should %d, got %d", 2, l)
+ }
+}
+
+func TestShardEvict(t *testing.T) {
+ s := newShard[int](1)
+ s.Put(1, 1)
+ s.Put(2, 2)
+ // 1 should be gone
+
+ if _, found := s.Get(1); found {
+ t.Fatal("Found item that should have been evicted")
+ }
+}
+
+func TestShardLenEvict(t *testing.T) {
+ s := newShard[int](4)
+ s.Put(1, 1)
+ s.Put(2, 1)
+ s.Put(3, 1)
+ s.Put(4, 1)
+
+ if l := len(s.items); l != 4 {
+ t.Fatalf("Shard size should %d, got %d", 4, l)
+ }
+
+ // This should evict one element
+ s.Put(5, 1)
+ if l := len(s.items); l != 4 {
+ t.Fatalf("Shard size should %d, got %d", 4, l)
+ }
+
+ // Make sure we don't accidentally evict an element when
+ // the key is already stored.
+ for i := 0; i < 4; i++ {
+ s.Put(5, 1)
+ if l := len(s.items); l != 4 {
+ t.Fatalf("Shard size should %d, got %d", 4, l)
+ }
+ }
+}
+
+func BenchmarkShard(b *testing.B) {
+ s := newShard[int](shardSize)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ k := uint64(i) % shardSize * 2
+ s.Put(k, 1)
+ s.Get(k)
+ }
+}
+
+func BenchmarkShardParallel(b *testing.B) {
+ s := newShard[int](shardSize)
+ b.ResetTimer()
+ b.RunParallel(func(pb *testing.PB) {
+ for i := uint64(0); pb.Next(); i++ {
+ k := i % shardSize * 2
+ s.Put(k, 1)
+ s.Get(k)
+ }
+ })
+}
diff --git a/osbase/net/dns/forward/forward.go b/osbase/net/dns/forward/forward.go
new file mode 100644
index 0000000..00271ec
--- /dev/null
+++ b/osbase/net/dns/forward/forward.go
@@ -0,0 +1,390 @@
+// Package forward implements a forwarding proxy.
+//
+// A cache is used to reduce load on the upstream servers. Cached items are only
+// used for a short time, because otherwise, we would need to provide a feature
+// for flushing the cache. The cache is most useful for taking the load from
+// applications making very frequent repeated queries. The cache also doubles as
+// a way to merge concurrent identical queries, since items are inserted into
+// the cache before sending the query upstream (see also RFC 5452, Section 5).
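+//
+// A minimal usage sketch (illustrative; the address is a placeholder, and the
+// wiring mirrors the test harness in forward_test.go):
+//
+//  f := New()
+//  f.DNSServers.Set([]string{"192.0.2.53:53"})
+//  // Run f.Run under a supervisor, then serve queries via f.HandleDNS.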
+package forward
+
+// Taken and modified from the Forward plugin of CoreDNS, under Apache 2.0.
+
+import (
+ "context"
+ "errors"
+ "hash/maphash"
+ "math/rand/v2"
+ "os"
+ "slices"
+ "strconv"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/dns"
+
+ "source.monogon.dev/osbase/event/memory"
+ "source.monogon.dev/osbase/net/dns/forward/cache"
+ "source.monogon.dev/osbase/net/dns/forward/proxy"
+ "source.monogon.dev/osbase/supervisor"
+)
+
+const (
+ connectionExpire = 10 * time.Second
+ healthcheckInterval = 500 * time.Millisecond
+ forwardTimeout = 5 * time.Second
+ maxFails = 2
+ maxConcurrent = 5000
+ maxUpstreams = 15
+)
+
+// Forward is a DNS handler that proxies requests to upstream DNS servers.
+// It has a list of proxies, each representing one upstream server.
+type Forward struct {
+ DNSServers memory.Value[[]string]
+ upstreams atomic.Pointer[[]*proxy.Proxy]
+
+ concurrent atomic.Int64
+
+ seed maphash.Seed
+ cache *cache.Cache[*cacheItem]
+
+ // now can be used to override time for testing.
+ now func() time.Time
+}
+
+// New returns a new Forward.
+func New() *Forward {
+ return &Forward{
+ seed: maphash.MakeSeed(),
+ cache: cache.New[*cacheItem](cacheCapacity),
+ now: time.Now,
+ }
+}
+
+func (f *Forward) Run(ctx context.Context) error {
+ supervisor.Signal(ctx, supervisor.SignalHealthy)
+
+ var lastAddrs []string
+ var upstreams []*proxy.Proxy
+
+ w := f.DNSServers.Watch()
+ defer w.Close()
+ for {
+ addrs, err := w.Get(ctx)
+ if err != nil {
+ for _, p := range upstreams {
+ p.Stop()
+ }
+ f.upstreams.Store(nil)
+ return err
+ }
+
+ if len(addrs) > maxUpstreams {
+ addrs = addrs[:maxUpstreams]
+ }
+
+ if slices.Equal(addrs, lastAddrs) {
+ continue
+ }
+ lastAddrs = addrs
+ supervisor.Logger(ctx).Infof("New upstream DNS servers: %s", addrs)
+
+ newAddrs := make(map[string]bool)
+ for _, addr := range addrs {
+ newAddrs[addr] = true
+ }
+ var newUpstreams []*proxy.Proxy
+ for _, p := range upstreams {
+ if newAddrs[p.Addr()] {
+ delete(newAddrs, p.Addr())
+ newUpstreams = append(newUpstreams, p)
+ } else {
+ p.Stop()
+ }
+ }
+ for newAddr := range newAddrs {
+ p := proxy.NewProxy(newAddr)
+ p.SetExpire(connectionExpire)
+ p.GetHealthchecker().SetRecursionDesired(true)
+ p.GetHealthchecker().SetDomain(".")
+ p.Start(healthcheckInterval)
+ newUpstreams = append(newUpstreams, p)
+ }
+ upstreams = newUpstreams
+ f.upstreams.Store(&newUpstreams)
+ }
+}
+
+type proxyReply struct {
+ // NoStore indicates that the reply should not be stored in the cache.
+ // This could be because it is cheap to obtain or expensive to store.
+ NoStore bool
+
+ Truncated bool
+ Rcode int
+ Answer []dns.RR
+ Ns []dns.RR
+ Extra []dns.RR
+ Options []dns.EDNS0
+}
+
+var (
+ replyConcurrencyLimit = proxyReply{
+ NoStore: true,
+ Rcode: dns.RcodeServerFailure,
+ Options: []dns.EDNS0{&dns.EDNS0_EDE{
+ InfoCode: dns.ExtendedErrorCodeOther,
+ ExtraText: "too many concurrent queries",
+ }},
+ }
+ replyNoUpstreams = proxyReply{
+ NoStore: true,
+ Rcode: dns.RcodeRefused,
+ Options: []dns.EDNS0{&dns.EDNS0_EDE{
+ InfoCode: dns.ExtendedErrorCodeOther,
+ ExtraText: "no upstream DNS servers configured",
+ }},
+ }
+ replyProtocolError = proxyReply{
+ Rcode: dns.RcodeServerFailure,
+ Options: []dns.EDNS0{&dns.EDNS0_EDE{
+ InfoCode: dns.ExtendedErrorCodeNetworkError,
+ ExtraText: "DNS protocol error when querying upstream DNS server",
+ }},
+ }
+ replyTimeout = proxyReply{
+ Rcode: dns.RcodeServerFailure,
+ Options: []dns.EDNS0{&dns.EDNS0_EDE{
+ InfoCode: dns.ExtendedErrorCodeNetworkError,
+ ExtraText: "timeout when querying upstream DNS server",
+ }},
+ }
+ replyNetworkError = proxyReply{
+ Rcode: dns.RcodeServerFailure,
+ Options: []dns.EDNS0{&dns.EDNS0_EDE{
+ InfoCode: dns.ExtendedErrorCodeNetworkError,
+ ExtraText: "network error when querying upstream DNS server",
+ }},
+ }
+)
+
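+// queryProxies builds an upstream query for the question and tries the
+// configured proxies in random order until one returns a usable reply or the
+// forward timeout expires.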
+func (f *Forward) queryProxies(
+ question dns.Question,
+ dnssec bool,
+ checkingDisabled bool,
+ queryOptions []dns.EDNS0,
+ useTCP bool,
+) proxyReply {
+ count := f.concurrent.Add(1)
+ defer f.concurrent.Add(-1)
+ if count > maxConcurrent {
+ rejectsCount.WithLabelValues("concurrency_limit").Inc()
+ return replyConcurrencyLimit
+ }
+
+ // Construct outgoing query.
+ qopt := new(dns.OPT)
+ qopt.Hdr.Name = "."
+ qopt.Hdr.Rrtype = dns.TypeOPT
+ qopt.SetUDPSize(proxy.AdvertiseUDPSize)
+ if dnssec {
+ qopt.SetDo()
+ }
+ qopt.Option = queryOptions
+ m := &dns.Msg{
+ MsgHdr: dns.MsgHdr{
+ Opcode: dns.OpcodeQuery,
+ RecursionDesired: true,
+ CheckingDisabled: checkingDisabled,
+ },
+ Question: []dns.Question{question},
+ Extra: []dns.RR{qopt},
+ }
+
+ var list []*proxy.Proxy
+ if upstreams := f.upstreams.Load(); upstreams != nil {
+ list = randomList(*upstreams)
+ }
+
+ if len(list) == 0 {
+ rejectsCount.WithLabelValues("no_upstreams").Inc()
+ return replyNoUpstreams
+ }
+
+ proto := "udp"
+ if useTCP {
+ proto = "tcp"
+ }
+
+ var (
+ curUpstream *proxy.Proxy
+ curStart time.Time
+ ret *dns.Msg
+ err error
+ )
+ recordDuration := func(rcode string) {
+ upstreamDuration.WithLabelValues(curUpstream.Addr(), proto, rcode).Observe(time.Since(curStart).Seconds())
+ }
+
+ fails := 0
+ i := 0
+ listStart := time.Now()
+ deadline := listStart.Add(forwardTimeout)
+ for {
+ if i >= len(list) {
+ // reached the end of list, reset to begin
+ i = 0
+ fails = 0
+
+ // Sleep for a bit if the last time we started going through the list was
+ // very recent.
+ time.Sleep(time.Until(listStart.Add(time.Second)))
+ listStart = time.Now()
+ }
+
+ curUpstream = list[i]
+ i++
+ if curUpstream.Down(maxFails) {
+ fails++
+ if fails < len(list) {
+ continue
+ }
+ // All upstream proxies are dead, assume healthcheck is completely broken
+ // and connect to a random upstream.
+ healthcheckBrokenCount.Inc()
+ }
+
+ curStart = time.Now()
+
+ for {
+ ret, err = curUpstream.Connect(m, useTCP)
+
+ if errors.Is(err, proxy.ErrCachedClosed) {
+ // Remote side closed conn, can only happen with TCP.
+ continue
+ }
+ break
+ }
+
+ if err != nil {
+ // Kick off health check to see if *our* upstream is broken.
+ curUpstream.Healthcheck()
+
+ retry := fails < len(list) && time.Now().Before(deadline)
+ var dnsError *dns.Error
+ switch {
+ case errors.Is(err, os.ErrDeadlineExceeded):
+ recordDuration("timeout")
+ if !retry {
+ return replyTimeout
+ }
+ case errors.As(err, &dnsError):
+ recordDuration("protocol_error")
+ if !retry {
+ return replyProtocolError
+ }
+ default:
+ recordDuration("network_error")
+ if !retry {
+ return replyNetworkError
+ }
+ }
+ continue
+ }
+
+ break
+ }
+
+ if !ret.Response || ret.Opcode != dns.OpcodeQuery || len(ret.Question) != 1 {
+ recordDuration("protocol_error")
+ return replyProtocolError
+ }
+
+ if ret.Truncated && useTCP {
+ recordDuration("protocol_error")
+ return replyProtocolError
+ }
+ if ret.Truncated {
+ proto = "udp_truncated"
+ }
+
+ // Check that the reply matches the question.
+ retq := ret.Question[0]
+ if retq.Qtype != question.Qtype || retq.Qclass != question.Qclass ||
+ (retq.Name != question.Name && dns.CanonicalName(retq.Name) != question.Name) {
+ recordDuration("protocol_error")
+ return replyProtocolError
+ }
+
+ // Extract OPT from reply.
+ var ropt *dns.OPT
+ var options []dns.EDNS0
+ for i := len(ret.Extra) - 1; i >= 0; i-- {
+ if rr, ok := ret.Extra[i].(*dns.OPT); ok {
+ if ropt != nil {
+ // Found more than one OPT.
+ recordDuration("protocol_error")
+ return replyProtocolError
+ }
+ ropt = rr
+ ret.Extra = append(ret.Extra[:i], ret.Extra[i+1:]...)
+ }
+ }
+ if ropt != nil {
+ if ropt.Version() != 0 || ropt.Hdr.Name != "." {
+ recordDuration("protocol_error")
+ return replyProtocolError
+ }
+ // Forward Extended DNS Error options.
+ for _, option := range ropt.Option {
+ switch option.(type) {
+ case *dns.EDNS0_EDE:
+ options = append(options, option)
+ }
+ }
+ }
+
+ rcode, ok := dns.RcodeToString[ret.Rcode]
+ if !ok {
+ // There are 4096 possible Rcodes, so it's probably still fine
+ // to put it in a metric label.
+ rcode = strconv.Itoa(ret.Rcode)
+ }
+ recordDuration(rcode)
+
+ // AuthenticatedData is intentionally not copied from the proxy reply because
+ // we don't have a secure channel to the proxy.
+ return proxyReply{
+ // Don't store large messages in the cache. Such large messages are very
+ // rare, and this protects against the cache using huge amounts of memory.
+ // DNS messages over TCP can be up to 64 KB in size, and after decompression
+ // this could go over 1 MB of memory usage.
+ NoStore: ret.Len() > cacheMaxItemSize,
+
+ Truncated: ret.Truncated,
+ Rcode: ret.Rcode,
+ Answer: ret.Answer,
+ Ns: ret.Ns,
+ Extra: ret.Extra,
+ Options: options,
+ }
+}
+
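+// randomList returns the proxies in random order. Lists of one or two
+// elements are special-cased to avoid a full shuffle.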
+func randomList(p []*proxy.Proxy) []*proxy.Proxy {
+ switch len(p) {
+ case 1:
+ return p
+ case 2:
+ if rand.Int()%2 == 0 {
+ return []*proxy.Proxy{p[1], p[0]} // swap
+ }
+ return p
+ }
+
+ shuffled := slices.Clone(p)
+ rand.Shuffle(len(shuffled), func(i, j int) {
+ shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
+ })
+ return shuffled
+}
diff --git a/osbase/net/dns/forward/forward_test.go b/osbase/net/dns/forward/forward_test.go
new file mode 100644
index 0000000..f850591
--- /dev/null
+++ b/osbase/net/dns/forward/forward_test.go
@@ -0,0 +1,618 @@
+package forward
+
+import (
+ "fmt"
+ "slices"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+
+ netDNS "source.monogon.dev/osbase/net/dns"
+ "source.monogon.dev/osbase/net/dns/test"
+ "source.monogon.dev/osbase/supervisor"
+)
+
+func rrStrings(rrs []dns.RR) []string {
+ s := make([]string, len(rrs))
+ for i, rr := range rrs {
+ s[i] = rr.String()
+ }
+ return s
+}
+
+func expectReply(t *testing.T, req *netDNS.Request, wantReply proxyReply) {
+ t.Helper()
+ if !req.Handled {
+ t.Errorf("Request was not handled")
+ }
+ if got, want := req.Reply.Truncated, wantReply.Truncated; got != want {
+ t.Errorf("Want truncated %v, got %v", want, got)
+ }
+ if got, want := req.Reply.Rcode, wantReply.Rcode; got != want {
+ t.Errorf("Want rcode %v, got %v", want, got)
+ }
+
+ wantExtra := wantReply.Extra
+ if req.Ropt != nil {
+ wantOpt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
+ wantOpt.Option = wantReply.Options
+ wantOpt.SetUDPSize(req.Ropt.UDPSize())
+ wantOpt.SetDo(req.Qopt.Do())
+ wantExtra = slices.Concat(wantExtra, []dns.RR{wantOpt})
+ }
+ checkReplySection(t, "answer", req.Reply.Answer, wantReply.Answer)
+ checkReplySection(t, "ns", req.Reply.Ns, wantReply.Ns)
+ checkReplySection(t, "extra", req.Reply.Extra, wantExtra)
+}
+
+func checkReplySection(t *testing.T, sectionName string, got []dns.RR, want []dns.RR) {
+ t.Helper()
+ gotStr := rrStrings(got)
+ wantStr := rrStrings(want)
+ if !slices.Equal(gotStr, wantStr) {
+ t.Errorf("Want %s:\n%s\nGot:\n%v", sectionName,
+ strings.Join(wantStr, "\n"), strings.Join(gotStr, "\n"))
+ }
+}
+
+type fakeTime struct {
+ time atomic.Pointer[time.Time]
+}
+
+func (f *fakeTime) now() time.Time {
+ t := f.time.Load()
+ if t != nil {
+ return *t
+ }
+ return time.Time{}
+}
+
+func (f *fakeTime) set(t time.Time) {
+ f.time.Store(&t)
+}
+
+func (f *fakeTime) add(t time.Duration) {
+ f.set(f.now().Add(t))
+}
+
+func TestUpstreams(t *testing.T) {
+ answerRecord1 := test.RR("example.com. IN A 127.0.0.1")
+ s1 := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = append(ret.Answer, answerRecord1)
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s1.Close()
+ answerRecord2 := test.RR("2.example.com. IN A 127.0.0.1")
+ s2 := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = append(ret.Answer, answerRecord2)
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s2.Close()
+
+ forward := New()
+ supervisor.TestHarness(t, forward.Run)
+
+ // If no upstreams are set, return an error.
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, replyNoUpstreams)
+
+ forward.DNSServers.Set([]string{s1.Addr})
+ time.Sleep(10 * time.Millisecond)
+
+ req = netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord1}})
+
+ forward.DNSServers.Set([]string{s2.Addr})
+ time.Sleep(10 * time.Millisecond)
+
+ // New DNS server should be used.
+ req = netDNS.CreateTestRequest("2.example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord2}})
+}
+
+// TestHealthcheck tests that if one of multiple upstreams is broken,
+// this upstream receives health check queries, and that client queries
+// succeed since they are retried on the good upstream.
+func TestHealthcheck(t *testing.T) {
+ var healthcheckCount atomic.Int64
+
+ sGood := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Rcode = dns.RcodeNameError
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer sGood.Close()
+ sBad := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0] == (dns.Question{Name: ".", Qtype: dns.TypeNS, Qclass: dns.ClassINET}) {
+ healthcheckCount.Add(1)
+ }
+ w.Write([]byte("this is not a dns message"))
+ })
+ defer sBad.Close()
+
+ forward := New()
+ forward.DNSServers.Set([]string{sGood.Addr, sBad.Addr})
+ supervisor.TestHarness(t, forward.Run)
+ time.Sleep(10 * time.Millisecond)
+
+ for i := range 100 {
+ req := netDNS.CreateTestRequest(fmt.Sprintf("%v.example.com.", i), dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Rcode: dns.RcodeNameError})
+ }
+
+ if healthcheckCount.Load() == 0 {
+ t.Error("Expected to see at least one healthcheck query.")
+ }
+}
+
+func TestRecursionDesired(t *testing.T) {
+ forward := New()
+
+ // If RecursionDesired is not set, refuse query.
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ req.Reply.RecursionDesired = false
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Rcode: dns.RcodeRefused})
+
+ // If RecursionDesired is not set and the query was redirected, send reply as is.
+ req = netDNS.CreateTestRequest("external.default.scv.cluster.local.", dns.TypeA, "udp")
+ req.Reply.RecursionDesired = false
+ req.AddCNAME("example.com.", 10)
+ req.Handled = false
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{
+ Answer: []dns.RR{test.RR("external.default.scv.cluster.local. 10 IN CNAME example.com.")},
+ })
+}
+
+// TestCache tests that both concurrent and sequential queries use the cache.
+func TestCache(t *testing.T) {
+ var queryCount atomic.Int64
+
+ answerRecord := test.RR("example.com. IN A 127.0.0.1")
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ queryCount.Add(1)
+ // Sleep a bit until all concurrent queries are blocked.
+ time.Sleep(10 * time.Millisecond)
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = append(ret.Answer, answerRecord)
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s.Close()
+
+ forward := New()
+ forward.DNSServers.Set([]string{s.Addr})
+ supervisor.TestHarness(t, forward.Run)
+ time.Sleep(10 * time.Millisecond)
+
+ wg := sync.WaitGroup{}
+ for range 3 {
+ wg.Add(1)
+ go func() {
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord}})
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord}})
+
+ // tcp query
+ req = netDNS.CreateTestRequest("example.com.", dns.TypeA, "tcp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord}})
+
+ // query without OPT
+ req = netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ req.Qopt = nil
+ req.Ropt = nil
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord}})
+
+ if got, want := queryCount.Load(), int64(1); got != want {
+ t.Errorf("Want %v queries, got %v", want, got)
+ }
+}
+
+func TestTtl(t *testing.T) {
+ var queryCount atomic.Int64
+ answer := []dns.RR{
+ test.RR("1.example.com. 3 CNAME 2.example.com."),
+ test.RR("2.example.com. 3600 IN A 127.0.0.1"),
+ }
+ answerDecrement := []dns.RR{
+ test.RR("1.example.com. 2 CNAME 2.example.com."),
+ test.RR("2.example.com. 3599 IN A 127.0.0.1"),
+ }
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ queryCount.Add(1)
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = answer
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s.Close()
+
+ forward := New()
+ ft := fakeTime{}
+ ft.set(time.Now())
+ forward.now = ft.now
+ forward.DNSServers.Set([]string{s.Addr})
+ supervisor.TestHarness(t, forward.Run)
+ time.Sleep(10 * time.Millisecond)
+
+ req := netDNS.CreateTestRequest("1.example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: answer})
+
+ ft.add(1500 * time.Millisecond)
+
+ // TTL of cached reply should be decremented.
+ req = netDNS.CreateTestRequest("1.example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: answerDecrement})
+
+ ft.add(2000 * time.Millisecond)
+
+ // Cache expired.
+ req = netDNS.CreateTestRequest("1.example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: answer})
+
+ if got, want := queryCount.Load(), int64(2); got != want {
+ t.Errorf("Want %v queries, got %v", want, got)
+ }
+}
+
+// TestShuffle tests that replies from cache have shuffled RRsets.
+// In this example, only the A records should be shuffled; the CNAMEs and
+// RRSIG should stay where they are.
+func TestShuffle(t *testing.T) {
+ testAnswer := []dns.RR{
+ test.RR("1.example.com. CNAME 2.example.com."),
+ test.RR("2.example.com. CNAME 3.example.com."),
+ }
+ // A random shuffle of 20 items is extremely unlikely (probability 1/20!)
+ // to end up in the original order.
+ for i := range 20 {
+ testAnswer = append(testAnswer, test.RR(fmt.Sprintf("3.example.com. IN A 127.0.0.%v", i)))
+ }
+ testAnswer = append(testAnswer, test.RR("3.example.com. RRSIG A 8 2 3600 1 1 1 example.com AAAA AAAA AAAA"))
+
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = testAnswer
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s.Close()
+
+ forward := New()
+ forward.DNSServers.Set([]string{s.Addr})
+ supervisor.TestHarness(t, forward.Run)
+ time.Sleep(10 * time.Millisecond)
+
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "tcp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: testAnswer})
+
+ req = netDNS.CreateTestRequest("example.com.", dns.TypeA, "tcp")
+ forward.HandleDNS(req)
+
+ if slices.Equal(rrStrings(req.Reply.Answer), rrStrings(testAnswer)) {
+ t.Error("Expected second reply to be shuffled.")
+ }
+ slices.SortFunc(req.Reply.Answer[2:len(testAnswer)-1], func(a, b dns.RR) int {
+ return int(a.(*dns.A).A[3]) - int(b.(*dns.A).A[3])
+ })
+ expectReply(t, req, proxyReply{Answer: testAnswer})
+}
+
+func TestTruncated(t *testing.T) {
+ var queryCount atomic.Int64
+ answerRecord1 := test.RR("example.com. IN A 127.0.0.1")
+ answerRecord2 := test.RR("example.com. IN A 127.0.0.2")
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ queryCount.Add(1)
+ tcp := w.RemoteAddr().Network() == "tcp"
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ if tcp {
+ ret.Answer = append(ret.Answer, answerRecord2)
+ } else {
+ ret.Answer = append(ret.Answer, answerRecord1)
+ ret.Truncated = true
+ }
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s.Close()
+
+ forward := New()
+ ft := fakeTime{}
+ ft.set(time.Now())
+ forward.now = ft.now
+ forward.DNSServers.Set([]string{s.Addr})
+ supervisor.TestHarness(t, forward.Run)
+ time.Sleep(10 * time.Millisecond)
+
+ for range 2 {
+ // Truncated replies are cached and returned for udp queries.
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Truncated: true, Answer: []dns.RR{answerRecord1}})
+ }
+
+ // Cached truncated replies are not used for tcp queries.
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "tcp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord2}})
+
+ // Subsequent udp queries get the tcp reply.
+ req = netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord2}})
+
+ ft.add(10000 * time.Second)
+
+ // After the cache expires, tcp is used.
+ req = netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord2}})
+
+ if got, want := queryCount.Load(), int64(3); got != want {
+ t.Errorf("Want %v queries, got %v", want, got)
+ }
+}
+
+type testQuery struct {
+ qtype uint16
+ dnssec bool
+ checkingDisabled bool
+}
+
+// TestQueries tests that queries which differ in relevant fields
+// result in separate upstream queries and are separately cached.
+func TestQueries(t *testing.T) {
+ var queryCount atomic.Int64
+
+ answerRecord := test.RR("example.com. IN A 127.0.0.1")
+ answerRecordAAAA := test.RR("example.com. IN AAAA ::1")
+ answerRecordRRSIG := test.RR("example.com. IN RRSIG A 8 2 3600 1 1 1 example.com AAAA AAAA AAAA")
+ answerRecordCD := test.RR("example.com. IN A 127.0.0.2")
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ queryCount.Add(1)
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ if r.Question[0].Name != "example.com." || r.Question[0].Qclass != dns.ClassINET {
+ t.Errorf("Unexpected Name or Qclass")
+ return
+ }
+ switch (testQuery{r.Question[0].Qtype, r.IsEdns0().Do(), r.CheckingDisabled}) {
+ case testQuery{dns.TypeA, false, false}:
+ ret.Answer = append(ret.Answer, answerRecord)
+ case testQuery{dns.TypeAAAA, false, false}:
+ ret.Answer = append(ret.Answer, answerRecordAAAA)
+ case testQuery{dns.TypeA, true, false}:
+ ret.Answer = append(ret.Answer, answerRecord)
+ ret.Answer = append(ret.Answer, answerRecordRRSIG)
+ case testQuery{dns.TypeA, false, true}:
+ ret.Answer = append(ret.Answer, answerRecordCD)
+ }
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s.Close()
+
+ forward := New()
+ forward.DNSServers.Set([]string{s.Addr})
+ supervisor.TestHarness(t, forward.Run)
+ time.Sleep(10 * time.Millisecond)
+
+ for range 2 {
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord}})
+
+ // different qtype
+ req = netDNS.CreateTestRequest("example.com.", dns.TypeAAAA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecordAAAA}})
+
+ // DNSSEC flag
+ req = netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ req.Qopt.SetDo()
+ req.Ropt.SetDo()
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecord, answerRecordRRSIG}})
+
+ // CheckingDisabled flag
+ req = netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ req.Reply.CheckingDisabled = true
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: []dns.RR{answerRecordCD}})
+ }
+
+ if got, want := queryCount.Load(), int64(4); got != want {
+ t.Errorf("Want %v queries, got %v", want, got)
+ }
+}
+
+// TestOPT tests that only certain OPT options are forwarded
+// in the query and reply.
+func TestOPT(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ wantOpt := &dns.OPT{}
+ wantOpt.Hdr.Name = "."
+ wantOpt.Hdr.Rrtype = dns.TypeOPT
+ wantOpt.SetUDPSize(r.IsEdns0().UDPSize())
+ wantOpt.Option = []dns.EDNS0{
+ &dns.EDNS0_DAU{AlgCode: []uint8{1, 4}},
+ &dns.EDNS0_DHU{AlgCode: []uint8{5}},
+ &dns.EDNS0_N3U{AlgCode: []uint8{6}},
+ }
+ if got, want := r.IsEdns0().String(), wantOpt.String(); got != want {
+ t.Errorf("Wanted opt %q, got %q", want, got)
+ }
+
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Rcode = dns.RcodeBadAlg
+ ret.SetEdns0(512, false)
+ ropt := ret.Extra[0].(*dns.OPT)
+ ropt.Option = []dns.EDNS0{
+ &dns.EDNS0_NSID{Nsid: "c0ffee"},
+ &dns.EDNS0_EDE{InfoCode: dns.ExtendedErrorCodeCensored, ExtraText: "****"},
+ &dns.EDNS0_EDE{InfoCode: dns.ExtendedErrorCodeDNSKEYMissing, ExtraText: "second problem"},
+ &dns.EDNS0_PADDING{Padding: []byte{0, 0, 0}},
+ }
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s.Close()
+
+ forward := New()
+ forward.DNSServers.Set([]string{s.Addr})
+ supervisor.TestHarness(t, forward.Run)
+ time.Sleep(10 * time.Millisecond)
+
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ req.Qopt.Option = []dns.EDNS0{
+ &dns.EDNS0_NSID{Nsid: ""},
+ &dns.EDNS0_EDE{InfoCode: dns.ExtendedErrorCodeDNSBogus, ExtraText: "huh?"},
+ &dns.EDNS0_DAU{AlgCode: []uint8{1, 4}},
+ &dns.EDNS0_DHU{AlgCode: []uint8{5}},
+ &dns.EDNS0_N3U{AlgCode: []uint8{6}},
+ &dns.EDNS0_PADDING{Padding: []byte{0, 0, 0}},
+ }
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{
+ Rcode: dns.RcodeBadAlg,
+ Options: []dns.EDNS0{
+ &dns.EDNS0_EDE{InfoCode: dns.ExtendedErrorCodeCensored, ExtraText: "****"},
+ &dns.EDNS0_EDE{InfoCode: dns.ExtendedErrorCodeDNSKEYMissing, ExtraText: "second problem"},
+ },
+ })
+}
+
+// TestBadReply tests that if the qname of the reply is not what was
+// sent in the query, the reply is rejected.
+func TestBadReply(t *testing.T) {
+ answerRecord := test.RR("1.example.com. IN A 127.0.0.1")
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Question[0].Name = "1.example.com."
+ ret.Answer = append(ret.Answer, answerRecord)
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s.Close()
+
+ forward := New()
+ forward.DNSServers.Set([]string{s.Addr})
+ supervisor.TestHarness(t, forward.Run)
+ time.Sleep(10 * time.Millisecond)
+
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "udp")
+ forward.HandleDNS(req)
+ expectReply(t, req, replyProtocolError)
+}
+
+// TestLargeReply tests that large replies are not stored,
+// but still shared with concurrent queries.
+func TestLargeReply(t *testing.T) {
+ var queryCount atomic.Int64
+
+ var testAnswer []dns.RR
+ for i := range 100 {
+ testAnswer = append(testAnswer, test.RR(fmt.Sprintf("%v.example.com. IN A 127.0.0.1", i)))
+ }
+
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ queryCount.Add(1)
+ // Sleep a bit until all concurrent queries are blocked.
+ time.Sleep(10 * time.Millisecond)
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = testAnswer
+ err := w.WriteMsg(ret)
+ if err != nil {
+ t.Error(err)
+ }
+ })
+ defer s.Close()
+
+ forward := New()
+ forward.DNSServers.Set([]string{s.Addr})
+ supervisor.TestHarness(t, forward.Run)
+ time.Sleep(10 * time.Millisecond)
+
+ wg := sync.WaitGroup{}
+ for range 2 {
+ wg.Add(1)
+ go func() {
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "tcp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: testAnswer})
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+
+ if got, want := queryCount.Load(), int64(1); got != want {
+ t.Errorf("Want %v queries, got %v", want, got)
+ }
+
+ req := netDNS.CreateTestRequest("example.com.", dns.TypeA, "tcp")
+ forward.HandleDNS(req)
+ expectReply(t, req, proxyReply{Answer: testAnswer})
+
+ if got, want := queryCount.Load(), int64(2); got != want {
+ t.Errorf("Want %v queries, got %v", want, got)
+ }
+}
diff --git a/osbase/net/dns/forward/metrics.go b/osbase/net/dns/forward/metrics.go
new file mode 100644
index 0000000..4977b02
--- /dev/null
+++ b/osbase/net/dns/forward/metrics.go
@@ -0,0 +1,59 @@
+package forward
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "github.com/prometheus/client_golang/prometheus"
+
+ "source.monogon.dev/osbase/net/dns"
+)
+
+// Variables declared for monitoring.
+var (
+ // Possible results:
+ // * hit: Item found and returned from cache.
+ // * miss: Item not found in cache.
+ // * refresh: Item found in cache, but is either expired, or
+ // truncated while the client used TCP.
+ cacheLookupsCount = dns.MetricsFactory.NewCounterVec(prometheus.CounterOpts{
+ Namespace: "dnsserver",
+ Subsystem: "forward",
+ Name: "cache_lookups_total",
+ Help: "Counter of the number of cache lookups.",
+ }, []string{"result"})
+
+ // protocol is one of:
+ // * udp
+ // * udp_truncated
+ // * tcp
+ // rcode can be an uppercase rcode name, a numeric rcode if the rcode is not
+ // known, or one of:
+ // * timeout
+ // * network_error
+ // * protocol_error
+ upstreamDuration = dns.MetricsFactory.NewHistogramVec(prometheus.HistogramOpts{
+ Namespace: "dnsserver",
+ Subsystem: "forward",
+ Name: "upstream_duration_seconds",
+ Buckets: prometheus.ExponentialBuckets(0.00025, 2, 16), // from 0.25ms to 8 seconds
+ Help: "Histogram of the time each upstream request took.",
+ }, []string{"to", "protocol", "rcode"})
+
+ // Possible reasons:
+ // * concurrency_limit: Too many concurrent upstream queries.
+ // * no_upstreams: There are no upstreams configured.
+ // * no_recursion_desired: Client did not set Recursion Desired flag.
+ rejectsCount = dns.MetricsFactory.NewCounterVec(prometheus.CounterOpts{
+ Namespace: "dnsserver",
+ Subsystem: "forward",
+ Name: "rejects_total",
+ Help: "Counter of the number of queries rejected and not forwarded to an upstream.",
+ }, []string{"reason"})
+
+ healthcheckBrokenCount = dns.MetricsFactory.NewCounter(prometheus.CounterOpts{
+ Namespace: "dnsserver",
+ Subsystem: "forward",
+ Name: "healthcheck_broken_total",
+ Help: "Counter of the number of complete failures of the healthchecks.",
+ })
+)
diff --git a/osbase/net/dns/forward/proxy/BUILD.bazel b/osbase/net/dns/forward/proxy/BUILD.bazel
new file mode 100644
index 0000000..5f3b177
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/BUILD.bazel
@@ -0,0 +1,35 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "proxy",
+ srcs = [
+ "connect.go",
+ "health.go",
+ "metrics.go",
+ "persistent.go",
+ "proxy.go",
+ "type.go",
+ ],
+ importpath = "source.monogon.dev/osbase/net/dns/forward/proxy",
+ visibility = ["//osbase/net/dns/forward:__subpackages__"],
+ deps = [
+ "//osbase/net/dns",
+ "//osbase/net/dns/forward/up",
+ "@com_github_miekg_dns//:dns",
+ "@com_github_prometheus_client_golang//prometheus",
+ ],
+)
+
+go_test(
+ name = "proxy_test",
+ srcs = [
+ "health_test.go",
+ "persistent_test.go",
+ "proxy_test.go",
+ ],
+ embed = [":proxy"],
+ deps = [
+ "//osbase/net/dns/test",
+ "@com_github_miekg_dns//:dns",
+ ],
+)
diff --git a/osbase/net/dns/forward/proxy/connect.go b/osbase/net/dns/forward/proxy/connect.go
new file mode 100644
index 0000000..d1c5b8e
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/connect.go
@@ -0,0 +1,168 @@
+// Package proxy implements a forwarding proxy. It caches an upstream net.Conn
+// for some time, so if the same client returns, the upstream's Conn is already
+// cached. Depending on how you benchmark, this looks to be about 50% faster
+// than opening a new connection for every client.
+// It works with UDP and TCP and uses inband healthchecking.
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "errors"
+ "io"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// AdvertiseUDPSize is the maximum message size that we advertise in the OPT RR
+// of UDP messages. This is calculated as the minimum IPv6 MTU (1280) minus
+// size of IPv6 (40) and UDP (8) headers.
+const AdvertiseUDPSize = 1232
+
+// ErrCachedClosed means cached connection was closed by peer.
+var ErrCachedClosed = errors.New("cached connection was closed by peer")
+
+// limitTimeout is a utility function to auto-tune timeout values.
+// The average observed time is moved towards the last observed delay,
+// moderated by a weight. The next timeout to use is double the computed
+// average, clamped to the [minValue, maxValue] range.
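+//
+// For example (hypothetical values): with an average of 100ms, a minValue of
+// 1s and a maxValue of 30s, the next timeout is 1s; with an average of 10s it
+// is 20s; and with an average of 20s it is capped at 30s.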
+func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
+ rt := time.Duration(atomic.LoadInt64(currentAvg))
+ if rt < minValue {
+ return minValue
+ }
+ if rt < maxValue/2 {
+ return 2 * rt
+ }
+ return maxValue
+}
+
+func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
+ dt := time.Duration(atomic.LoadInt64(currentAvg))
+ atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
+}
+
+func (t *Transport) dialTimeout() time.Duration {
+ return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
+}
+
+func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
+ averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
+}
+
+// Dial dials the address configured in transport,
+// potentially reusing a connection or creating a new one.
+func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
+ // If tls has been configured; use it.
+ if t.tlsConfig != nil {
+ proto = "tcp-tls"
+ }
+
+ t.dial <- proto
+ pc := <-t.ret
+
+ if pc != nil {
+ return pc, true, nil
+ }
+
+ reqTime := time.Now()
+ timeout := t.dialTimeout()
+ if proto == "tcp-tls" {
+ conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return &persistConn{c: conn}, false, err
+ }
+ conn, err := dns.DialTimeout(proto, t.addr, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return &persistConn{c: conn}, false, err
+}
+
+// Connect dials the upstream, reusing a cached connection if available,
+// sends the request and waits for a response.
+func (p *Proxy) Connect(m *dns.Msg, useTCP bool) (*dns.Msg, error) {
+ proto := "udp"
+ if useTCP {
+ proto = "tcp"
+ }
+
+ pc, cached, err := p.transport.Dial(proto)
+ if err != nil {
+ return nil, err
+ }
+
+ pc.c.UDPSize = AdvertiseUDPSize
+
+ pc.c.SetWriteDeadline(time.Now().Add(p.writeTimeout))
+ m.Id = dns.Id()
+
+ if err := pc.c.WriteMsg(m); err != nil {
+ pc.c.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, ErrCachedClosed
+ }
+ return nil, err
+ }
+
+ var ret *dns.Msg
+ pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
+ for {
+ ret, err = pc.c.ReadMsg()
+ if err != nil {
+ if ret != nil && (m.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) {
+ // For UDP, if the error is an overflow, we probably have an upstream
+ // misbehaving in some way.
+ // (e.g. sending >512 byte responses without an eDNS0 OPT RR).
+ // Instead of returning an error, return an empty response
+ // with TC bit set. This will make the client retry over TCP
+ // (if that's supported) or at least receive a clean error.
+ // The connection is still good so we break before the close.
+
+ // Truncate the response.
+ ret = truncateResponse(ret)
+ break
+ }
+
+ pc.c.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, ErrCachedClosed
+ }
+ return ret, err
+ }
+ // drop out-of-order responses
+ if m.Id == ret.Id {
+ break
+ }
+ }
+
+ p.transport.Yield(pc)
+
+ return ret, nil
+}
+
+const cumulativeAvgWeight = 4
+
+// shouldTruncateResponse determines if an error indicates that the response
+// should be truncated.
+func shouldTruncateResponse(err error) bool {
+ // This is to handle a scenario in which upstream sets the TC bit,
+ // but doesn't truncate the response and we get ErrBuf instead of overflow.
+ if errors.Is(err, dns.ErrBuf) {
+ return true
+ }
+ return strings.Contains(err.Error(), "overflow")
+}
+
+// truncateResponse clears all record sections and sets the TC (truncated) bit.
+func truncateResponse(response *dns.Msg) *dns.Msg {
+ // Clear out Answer, Extra, and Ns sections
+ response.Answer = nil
+ response.Extra = nil
+ response.Ns = nil
+
+ // Set TC bit to indicate truncation.
+ response.Truncated = true
+ return response
+}
diff --git a/osbase/net/dns/forward/proxy/health.go b/osbase/net/dns/forward/proxy/health.go
new file mode 100644
index 0000000..4630cdc
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/health.go
@@ -0,0 +1,127 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "crypto/tls"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// HealthChecker checks the upstream health.
+type HealthChecker interface {
+ Check(*Proxy) error
+ SetTLSConfig(*tls.Config)
+ GetTLSConfig() *tls.Config
+ SetRecursionDesired(bool)
+ GetRecursionDesired() bool
+ SetDomain(domain string)
+ GetDomain() string
+ SetTCPTransport()
+ GetReadTimeout() time.Duration
+ SetReadTimeout(time.Duration)
+ GetWriteTimeout() time.Duration
+ SetWriteTimeout(time.Duration)
+}
+
+// dnsHc is a health checker for a DNS endpoint (DNS and DoT).
+type dnsHc struct {
+ c *dns.Client
+ recursionDesired bool
+ domain string
+}
+
+// NewHealthChecker returns a new HealthChecker.
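+//
+// For example (illustrative; mirrors health_test.go):
+//
+//  hc := NewHealthChecker(true, ".")
+//  hc.SetReadTimeout(10 * time.Millisecond)
+//  hc.SetWriteTimeout(10 * time.Millisecond)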
+func NewHealthChecker(recursionDesired bool, domain string) HealthChecker {
+ c := new(dns.Client)
+ c.Net = "udp"
+ c.ReadTimeout = 1 * time.Second
+ c.WriteTimeout = 1 * time.Second
+
+ return &dnsHc{
+ c: c,
+ recursionDesired: recursionDesired,
+ domain: domain,
+ }
+}
+
+func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
+ h.c.Net = "tcp-tls"
+ h.c.TLSConfig = cfg
+}
+
+func (h *dnsHc) GetTLSConfig() *tls.Config {
+ return h.c.TLSConfig
+}
+
+func (h *dnsHc) SetRecursionDesired(recursionDesired bool) {
+ h.recursionDesired = recursionDesired
+}
+func (h *dnsHc) GetRecursionDesired() bool {
+ return h.recursionDesired
+}
+
+func (h *dnsHc) SetDomain(domain string) {
+ h.domain = domain
+}
+func (h *dnsHc) GetDomain() string {
+ return h.domain
+}
+
+func (h *dnsHc) SetTCPTransport() {
+ h.c.Net = "tcp"
+}
+
+func (h *dnsHc) GetReadTimeout() time.Duration {
+ return h.c.ReadTimeout
+}
+
+func (h *dnsHc) SetReadTimeout(t time.Duration) {
+ h.c.ReadTimeout = t
+}
+
+func (h *dnsHc) GetWriteTimeout() time.Duration {
+ return h.c.WriteTimeout
+}
+
+func (h *dnsHc) SetWriteTimeout(t time.Duration) {
+ h.c.WriteTimeout = t
+}
+
+// For health checks we send a ". IN NS" query (with or without the RD bit)
+// to the upstream. Dial timeouts and empty replies are considered failures;
+// basically anything else constitutes a healthy upstream.
+
+// Check is used as the up.Func in the up.Probe.
+func (h *dnsHc) Check(p *Proxy) error {
+ err := h.send(p.addr)
+ if err != nil {
+ healthcheckFailureCount.WithLabelValues(p.addr).Inc()
+ p.incrementFails()
+ return err
+ }
+
+ atomic.StoreUint32(&p.fails, 0)
+ return nil
+}
+
+func (h *dnsHc) send(addr string) error {
+ ping := new(dns.Msg)
+ ping.SetQuestion(h.domain, dns.TypeNS)
+ ping.MsgHdr.RecursionDesired = h.recursionDesired
+ ping.SetEdns0(AdvertiseUDPSize, false)
+
+ m, _, err := h.c.Exchange(ping, addr)
+ // If we got a header back, the upstream is alive;
+ // we basically only care about I/O errors.
+ if err != nil && m != nil {
+ // Sanity check: something sensible came back.
+ if m.Response || m.Opcode == dns.OpcodeQuery {
+ err = nil
+ }
+ }
+
+ return err
+}
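As a reference for what send puts on the wire, here is a standalone sketch of the same probe; the upstream address, the one-second timeouts, and the 1232-byte EDNS0 buffer size are illustrative assumptions, not values from this change:

    // probe sends a ". IN NS" query to addr and treats any sane reply as healthy.
    func probe(addr string, recursionDesired bool) error {
    	ping := new(dns.Msg)
    	ping.SetQuestion(".", dns.TypeNS)
    	ping.RecursionDesired = recursionDesired
    	ping.SetEdns0(1232, false) // advertise a larger UDP buffer (illustrative size)

    	c := &dns.Client{Net: "udp", ReadTimeout: time.Second, WriteTimeout: time.Second}
    	m, _, err := c.Exchange(ping, addr)
    	// Only I/O errors count as failures; any plausible header is fine.
    	if err != nil && m != nil && (m.Response || m.Opcode == dns.OpcodeQuery) {
    		err = nil
    	}
    	return err
    }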
diff --git a/osbase/net/dns/forward/proxy/health_test.go b/osbase/net/dns/forward/proxy/health_test.go
new file mode 100644
index 0000000..ef8b60a
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/health_test.go
@@ -0,0 +1,154 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+
+ "source.monogon.dev/osbase/net/dns/test"
+)
+
+func TestHealth(t *testing.T) {
+ i := uint32(0)
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == "." && r.RecursionDesired {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(true, ".")
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(20 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
+ }
+}
+
+func TestHealthTCP(t *testing.T) {
+ i := uint32(0)
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == "." && r.RecursionDesired {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(true, ".")
+ hc.SetTCPTransport()
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(20 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
+ }
+}
+
+func TestHealthNoRecursion(t *testing.T) {
+ i := uint32(0)
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == "." && !r.RecursionDesired {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(false, ".")
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(20 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with RecursionDesired==false to be %d, got %d", 1, i1)
+ }
+}
+
+func TestHealthTimeout(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ // timeout
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(false, ".")
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err == nil {
+ t.Errorf("expected error")
+ }
+}
+
+func TestHealthDomain(t *testing.T) {
+ hcDomain := "example.org."
+
+ i := uint32(0)
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == hcDomain && r.RecursionDesired {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(true, hcDomain)
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(12 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1)
+ }
+}
diff --git a/osbase/net/dns/forward/proxy/metrics.go b/osbase/net/dns/forward/proxy/metrics.go
new file mode 100644
index 0000000..c08f1d8
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/metrics.go
@@ -0,0 +1,19 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "github.com/prometheus/client_golang/prometheus"
+
+ "source.monogon.dev/osbase/net/dns"
+)
+
+// Variables declared for monitoring.
+var (
+ healthcheckFailureCount = dns.MetricsFactory.NewCounterVec(prometheus.CounterOpts{
+ Namespace: "dnsserver",
+ Subsystem: "forward",
+ Name: "healthcheck_failures_total",
+ Help: "Counter of the number of failed healthchecks.",
+ }, []string{"to"})
+)
diff --git a/osbase/net/dns/forward/proxy/persistent.go b/osbase/net/dns/forward/proxy/persistent.go
new file mode 100644
index 0000000..cb4f618
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/persistent.go
@@ -0,0 +1,160 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "crypto/tls"
+ "sort"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// persistConn holds a dns.Conn and the time it was last used.
+type persistConn struct {
+ c *dns.Conn
+ used time.Time
+}
+
+// Transport holds the persistent connection cache.
+type Transport struct {
+ avgDialTime int64 // cumulative moving average of dial time
+ conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls.
+ expire time.Duration // After this duration a connection is expired.
+ addr string
+ tlsConfig *tls.Config
+
+ dial chan string
+ yield chan *persistConn
+ ret chan *persistConn
+ stop chan bool
+}
+
+func newTransport(addr string) *Transport {
+ t := &Transport{
+ avgDialTime: int64(maxDialTimeout / 2),
+ conns: [typeTotalCount][]*persistConn{},
+ expire: defaultExpire,
+ addr: addr,
+ dial: make(chan string),
+ yield: make(chan *persistConn),
+ ret: make(chan *persistConn),
+ stop: make(chan bool),
+ }
+ return t
+}
+
+// connManager manages the persistent connection cache for UDP and TCP.
+func (t *Transport) connManager() {
+ ticker := time.NewTicker(defaultExpire)
+ defer ticker.Stop()
+Wait:
+ for {
+ select {
+ case proto := <-t.dial:
+ transtype := stringToTransportType(proto)
+ // take the last used conn - complexity O(1)
+ if stack := t.conns[transtype]; len(stack) > 0 {
+ pc := stack[len(stack)-1]
+ if time.Since(pc.used) < t.expire {
+ // Found one, remove from pool and return this conn.
+ t.conns[transtype] = stack[:len(stack)-1]
+ t.ret <- pc
+ continue Wait
+ }
+ // clear entire cache if the last conn is expired
+ t.conns[transtype] = nil
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack)
+ }
+ t.ret <- nil
+
+ case pc := <-t.yield:
+ transtype := t.transportTypeFromConn(pc)
+ t.conns[transtype] = append(t.conns[transtype], pc)
+
+ case <-ticker.C:
+ t.cleanup(false)
+
+ case <-t.stop:
+ t.cleanup(true)
+ close(t.ret)
+ return
+ }
+ }
+}
+
+// closeConns closes connections.
+func closeConns(conns []*persistConn) {
+ for _, pc := range conns {
+ pc.c.Close()
+ }
+}
+
+// cleanup removes connections from cache.
+func (t *Transport) cleanup(all bool) {
+ staleTime := time.Now().Add(-t.expire)
+ for transtype, stack := range t.conns {
+ if len(stack) == 0 {
+ continue
+ }
+ if all {
+ t.conns[transtype] = nil
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack)
+ continue
+ }
+ if stack[0].used.After(staleTime) {
+ continue
+ }
+
+ // connections in stack are sorted by "used"
+ good := sort.Search(len(stack), func(i int) bool {
+ return stack[i].used.After(staleTime)
+ })
+ t.conns[transtype] = stack[good:]
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack[:good])
+ }
+}
+
+// It is hard to pin a value to this; the important thing is to not block
+// forever. Losing a cached connection is not terrible.
+const yieldTimeout = 25 * time.Millisecond
+
+// Yield returns the connection to transport for reuse.
+func (t *Transport) Yield(pc *persistConn) {
+ pc.used = time.Now() // update used time
+
+ // Make this non-blocking, because in the case of a very busy forwarder
+ // we would *block* on this yield. That would stall the outer goroutine
+ // and work would pile up. We time out when the send fails, as returning
+ // these connections is only an optimization anyway.
+ select {
+ case t.yield <- pc:
+ return
+ case <-time.After(yieldTimeout):
+ return
+ }
+}
+
+// Start starts the transport's connection manager.
+func (t *Transport) Start() { go t.connManager() }
+
+// Stop stops the transport's connection manager.
+func (t *Transport) Stop() { close(t.stop) }
+
+// SetExpire sets the connection expire time in transport.
+func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
+
+// SetTLSConfig sets the TLS config in transport.
+func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
+
+const (
+ defaultExpire = 10 * time.Second
+ minDialTimeout = 1 * time.Second
+ maxDialTimeout = 30 * time.Second
+)
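Taken together, the Transport lifecycle (as exercised by the tests in this change) looks roughly like this from inside the proxy package; the upstream address is a placeholder and error handling is elided:

    tr := newTransport("203.0.113.1:53")
    tr.SetExpire(10 * time.Second)
    tr.Start() // runs connManager in its own goroutine

    pc, cached, err := tr.Dial("udp") // cached reports whether a pooled conn was reused
    if err == nil {
    	// ... write a query and read the reply on pc.c ...
    	tr.Yield(pc) // hand the conn back; non-blocking, gives up after yieldTimeout
    }

    tr.Stop()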
diff --git a/osbase/net/dns/forward/proxy/persistent_test.go b/osbase/net/dns/forward/proxy/persistent_test.go
new file mode 100644
index 0000000..2856c3f
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/persistent_test.go
@@ -0,0 +1,111 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+
+ "source.monogon.dev/osbase/net/dns/test"
+)
+
+func TestCached(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr)
+ tr.Start()
+ defer tr.Stop()
+
+ c1, cache1, _ := tr.Dial("udp")
+ c2, cache2, _ := tr.Dial("udp")
+
+ if cache1 || cache2 {
+ t.Errorf("Expected non-cached connection")
+ }
+
+ tr.Yield(c1)
+ tr.Yield(c2)
+ c3, cached3, _ := tr.Dial("udp")
+ if !cached3 {
+ t.Error("Expected cached connection (c3)")
+ }
+ if c2 != c3 {
+ t.Error("Expected c2 == c3")
+ }
+
+ tr.Yield(c3)
+
+ // dial another protocol
+ c4, cached4, _ := tr.Dial("tcp")
+ if cached4 {
+ t.Errorf("Expected non-cached connection (c4)")
+ }
+ tr.Yield(c4)
+}
+
+func TestCleanupByTimer(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr)
+ tr.SetExpire(10 * time.Millisecond)
+ tr.Start()
+ defer tr.Stop()
+
+ c1, _, _ := tr.Dial("udp")
+ c2, _, _ := tr.Dial("udp")
+ tr.Yield(c1)
+ time.Sleep(2 * time.Millisecond)
+ tr.Yield(c2)
+
+ time.Sleep(15 * time.Millisecond)
+ c3, cached, _ := tr.Dial("udp")
+ if cached {
+ t.Error("Expected non-cached connection (c3)")
+ }
+ tr.Yield(c3)
+
+ time.Sleep(15 * time.Millisecond)
+ c4, cached, _ := tr.Dial("udp")
+ if cached {
+ t.Error("Expected non-cached connection (c4)")
+ }
+ tr.Yield(c4)
+}
+
+func TestCleanupAll(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr)
+
+ c1, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+ c2, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+ c3, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+
+ tr.conns[typeUDP] = []*persistConn{{c1, time.Now()}, {c2, time.Now()}, {c3, time.Now()}}
+
+ if len(tr.conns[typeUDP]) != 3 {
+ t.Error("Expected 3 connections")
+ }
+ tr.cleanup(true)
+
+ if len(tr.conns[typeUDP]) > 0 {
+ t.Error("Expected no cached connections")
+ }
+}
diff --git a/osbase/net/dns/forward/proxy/proxy.go b/osbase/net/dns/forward/proxy/proxy.go
new file mode 100644
index 0000000..dc25a31
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/proxy.go
@@ -0,0 +1,108 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "crypto/tls"
+ "runtime"
+ "sync/atomic"
+ "time"
+
+ "source.monogon.dev/osbase/net/dns/forward/up"
+)
+
+// Proxy defines an upstream host.
+type Proxy struct {
+ fails uint32
+ addr string
+
+ transport *Transport
+
+ writeTimeout time.Duration
+ readTimeout time.Duration
+
+ // health checking
+ probe *up.Probe
+ health HealthChecker
+}
+
+// NewProxy returns a new proxy.
+func NewProxy(addr string) *Proxy {
+ p := &Proxy{
+ addr: addr,
+ fails: 0,
+ probe: up.New(),
+ writeTimeout: 2 * time.Second,
+ readTimeout: 2 * time.Second,
+ transport: newTransport(addr),
+ health: NewHealthChecker(true, "."),
+ }
+
+ runtime.SetFinalizer(p, (*Proxy).finalizer)
+ return p
+}
+
+func (p *Proxy) Addr() string { return p.addr }
+
+// SetTLSConfig sets the TLS config in the lower p.transport
+// and in the healthchecking client.
+func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
+ p.transport.SetTLSConfig(cfg)
+ p.health.SetTLSConfig(cfg)
+}
+
+// SetExpire sets the expire duration in the lower p.transport.
+func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
+
+func (p *Proxy) GetHealthchecker() HealthChecker {
+ return p.health
+}
+
+func (p *Proxy) Fails() uint32 {
+ return atomic.LoadUint32(&p.fails)
+}
+
+// Healthcheck kicks off a round of health checks for this proxy.
+func (p *Proxy) Healthcheck() {
+ if p.health == nil {
+ return
+ }
+
+ p.probe.Do(func() error {
+ return p.health.Check(p)
+ })
+}
+
+// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
+func (p *Proxy) Down(maxfails uint32) bool {
+ if maxfails == 0 {
+ return false
+ }
+
+ fails := atomic.LoadUint32(&p.fails)
+ return fails > maxfails
+}
+
+// Stop stops the health checking goroutine.
+func (p *Proxy) Stop() { p.probe.Stop() }
+func (p *Proxy) finalizer() { p.transport.Stop() }
+
+// Start starts the proxy's healthchecking.
+func (p *Proxy) Start(duration time.Duration) {
+ p.probe.Start(duration)
+ p.transport.Start()
+}
+
+func (p *Proxy) SetReadTimeout(duration time.Duration) {
+ p.readTimeout = duration
+}
+
+// incrementFails increments the number of fails safely.
+func (p *Proxy) incrementFails() {
+ curVal := atomic.LoadUint32(&p.fails)
+ if curVal > curVal+1 {
+ // overflow occurred, do not update the counter again
+ return
+ }
+ atomic.AddUint32(&p.fails, 1)
+}
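For orientation, a Proxy as used by the tests below goes through roughly this lifecycle; the upstream address and the maxfails value of 3 are placeholders:

    p := NewProxy("203.0.113.1:53")
    p.Start(5 * time.Second) // health check interval
    defer p.Stop()

    m := new(dns.Msg)
    m.SetQuestion("example.org.", dns.TypeA)

    resp, err := p.Connect(m, false) // false: start over UDP
    if err != nil {
    	p.Healthcheck() // kick off probing of this upstream
    }
    if p.Down(3) { // true once there are more than 3 fails
    	// skip this upstream when picking one for the next query
    }
    _ = resp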
diff --git a/osbase/net/dns/forward/proxy/proxy_test.go b/osbase/net/dns/forward/proxy/proxy_test.go
new file mode 100644
index 0000000..a9297da
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/proxy_test.go
@@ -0,0 +1,156 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+
+ "source.monogon.dev/osbase/net/dns/test"
+)
+
+func TestProxy(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = append(ret.Answer, test.RR("example.org. IN A 127.0.0.1"))
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ p.Start(5 * time.Second)
+ m := new(dns.Msg)
+
+ m.SetQuestion("example.org.", dns.TypeA)
+
+ resp, err := p.Connect(m, false)
+ if err != nil {
+ t.Errorf("Failed to connect to testdnsserver: %s", err)
+ }
+
+ if x := resp.Answer[0].Header().Name; x != "example.org." {
+ t.Errorf("Expected %s, got %s", "example.org.", x)
+ }
+}
+
+func TestProtocolSelection(t *testing.T) {
+ p := NewProxy("bad_address")
+ p.readTimeout = 10 * time.Millisecond
+
+ go func() {
+ p.Connect(new(dns.Msg), false)
+ p.Connect(new(dns.Msg), true)
+ }()
+
+ for i, exp := range []string{"udp", "tcp"} {
+ proto := <-p.transport.dial
+ p.transport.ret <- nil
+ if proto != exp {
+ t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto)
+ }
+ }
+}
+
+func TestProxyIncrementFails(t *testing.T) {
+ var testCases = []struct {
+ name string
+ fails uint32
+ expectFails uint32
+ }{
+ {
+ name: "increment fails counter overflows",
+ fails: math.MaxUint32,
+ expectFails: math.MaxUint32,
+ },
+ {
+ name: "increment fails counter",
+ fails: 0,
+ expectFails: 1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ p := NewProxy("bad_address")
+ p.fails = tc.fails
+ p.incrementFails()
+ if p.fails != tc.expectFails {
+ t.Errorf("Expected fails to be %d, got %d", tc.expectFails, p.fails)
+ }
+ })
+ }
+}
+
+func TestCoreDNSOverflow(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+
+ var answers []dns.RR
+ for i := range 50 {
+ answers = append(answers, test.RR(fmt.Sprintf("example.org. IN A 127.0.0.%v", i)))
+ }
+ ret.Answer = answers
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ p.Start(5 * time.Second)
+ defer p.Stop()
+
+ // Test different connection modes
+ testConnection := func(proto string, useTCP bool, expectTruncated bool) {
+ t.Helper()
+
+ queryMsg := new(dns.Msg)
+ queryMsg.SetQuestion("example.org.", dns.TypeA)
+
+ response, err := p.Connect(queryMsg, useTCP)
+ if err != nil {
+ t.Errorf("Failed to connect to testdnsserver: %s", err)
+ }
+
+ if response.Truncated != expectTruncated {
+ t.Errorf("Expected truncated response for %s, but got TC flag %v", proto, response.Truncated)
+ }
+ }
+
+ // Test udp, expect truncated response
+ testConnection("udp", false, true)
+
+ // Test tcp, expect no truncated response
+ testConnection("tcp", true, false)
+}
+
+func TestShouldTruncateResponse(t *testing.T) {
+ testCases := []struct {
+ testname string
+ err error
+ expected bool
+ }{
+ {"BadAlgorithm", dns.ErrAlg, false},
+ {"BufferSizeTooSmall", dns.ErrBuf, true},
+ {"OverflowUnpackingA", errors.New("overflow unpacking a"), true},
+ {"OverflowingHeaderSize", errors.New("overflowing header size"), true},
+ {"OverflowpackingA", errors.New("overflow packing a"), true},
+ {"ErrSig", dns.ErrSig, false},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.testname, func(t *testing.T) {
+ result := shouldTruncateResponse(tc.err)
+ if result != tc.expected {
+ t.Errorf("For testname '%v', expected %v but got %v", tc.testname, tc.expected, result)
+ }
+ })
+ }
+}
diff --git a/osbase/net/dns/forward/proxy/type.go b/osbase/net/dns/forward/proxy/type.go
new file mode 100644
index 0000000..6eb78ce
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/type.go
@@ -0,0 +1,41 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "net"
+)
+
+type transportType int
+
+const (
+ typeUDP transportType = iota
+ typeTCP
+ typeTLS
+ typeTotalCount // keep this last
+)
+
+func stringToTransportType(s string) transportType {
+ switch s {
+ case "udp":
+ return typeUDP
+ case "tcp":
+ return typeTCP
+ case "tcp-tls":
+ return typeTLS
+ }
+
+ return typeUDP
+}
+
+func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
+ if _, ok := pc.c.Conn.(*net.UDPConn); ok {
+ return typeUDP
+ }
+
+ if t.tlsConfig == nil {
+ return typeTCP
+ }
+
+ return typeTLS
+}
diff --git a/osbase/net/dns/forward/up/BUILD.bazel b/osbase/net/dns/forward/up/BUILD.bazel
new file mode 100644
index 0000000..b9f2683
--- /dev/null
+++ b/osbase/net/dns/forward/up/BUILD.bazel
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "up",
+ srcs = ["up.go"],
+ importpath = "source.monogon.dev/osbase/net/dns/forward/up",
+ visibility = ["//osbase/net/dns/forward:__subpackages__"],
+)
+
+go_test(
+ name = "up_test",
+ srcs = ["up_test.go"],
+ embed = [":up"],
+)
diff --git a/osbase/net/dns/forward/up/up.go b/osbase/net/dns/forward/up/up.go
new file mode 100644
index 0000000..9806319
--- /dev/null
+++ b/osbase/net/dns/forward/up/up.go
@@ -0,0 +1,93 @@
+// Package up is used to periodically run a function until it succeeds. If a
+// new function is added while a previous run is still ongoing, nothing new
+// will be executed.
+package up
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "sync"
+ "time"
+)
+
+// Probe is used to run a single Func until it returns nil
+// (indicating a target is healthy).
+// If a Func is already in progress no new one will be added,
+// i.e. there is at most one check in flight at any time.
+//
+// There is a tradeoff to be made in figuring out quickly that an upstream is
+// healthy and not doing much work (sending queries) to find that out.
+// Having some kind of exp. backoff here won't help much, because you don't
+// want to backoff too much. You then also need random queries to be performed
+// every so often to quickly detect a working upstream. In the end we just send
+// a query every 0.5 second to check the upstream. This hopefully strikes a
+// balance between getting information about the upstream state quickly and not
+// doing too much work. Note that 0.5s is still an eternity in DNS, so we may
+// actually want to shorten it.
+type Probe struct {
+ sync.Mutex
+ inprogress int
+ interval time.Duration
+}
+
+// Func is used to determine if a target is alive.
+// If so this function must return nil.
+type Func func() error
+
+// New returns a pointer to an initialized Probe.
+func New() *Probe { return &Probe{} }
+
+// Do will probe the target; if a probe is already in progress this is a no-op.
+func (p *Probe) Do(f Func) {
+ p.Lock()
+ if p.inprogress != idle {
+ p.Unlock()
+ return
+ }
+ p.inprogress = active
+ interval := p.interval
+ p.Unlock()
+ // Past the lock. Now run f for as long as it returns an error.
+ // Once nil is returned we exit the goroutine
+ // and can accept another Func to run.
+ go func() {
+ i := 1
+ for {
+ if err := f(); err == nil {
+ break
+ }
+ time.Sleep(interval)
+ p.Lock()
+ if p.inprogress == stop {
+ p.Unlock()
+ return
+ }
+ p.Unlock()
+ i++
+ }
+
+ p.Lock()
+ p.inprogress = idle
+ p.Unlock()
+ }()
+}
+
+// Stop stops the probing.
+func (p *Probe) Stop() {
+ p.Lock()
+ p.inprogress = stop
+ p.Unlock()
+}
+
+// Start will initialize the probe manager,
+// after which probes can be initiated with Do.
+func (p *Probe) Start(interval time.Duration) {
+ p.Lock()
+ p.interval = interval
+ p.Unlock()
+}
+
+const (
+ idle = iota
+ active
+ stop
+)
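A minimal usage sketch of the Probe API; the interval and the checkUpstream helper are hypothetical, and Bazel visibility restricts this package to the forward tree:

    pr := up.New()
    pr.Start(500 * time.Millisecond) // set the probe interval
    defer pr.Stop()

    pr.Do(func() error {
    	// Returning nil ends the probe loop; any error keeps probing
    	// at the configured interval.
    	return checkUpstream() // hypothetical health check
    })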
diff --git a/osbase/net/dns/forward/up/up_test.go b/osbase/net/dns/forward/up/up_test.go
new file mode 100644
index 0000000..0d0f928
--- /dev/null
+++ b/osbase/net/dns/forward/up/up_test.go
@@ -0,0 +1,42 @@
+package up
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestUp(t *testing.T) {
+ pr := New()
+ wg := sync.WaitGroup{}
+ hits := int32(0)
+
+ upfunc := func() error {
+ atomic.AddInt32(&hits, 1)
+ // Sleep a tiny amount so that our other pr.Do() calls hit the lock.
+ time.Sleep(3 * time.Millisecond)
+ wg.Done()
+ return nil
+ }
+
+ pr.Start(5 * time.Millisecond)
+ defer pr.Stop()
+
+ // These functions AddInt32 to the same hits variable, but we only want to
+ // wait when upfunc finishes, as that only calls Done() on the waitgroup.
+ upfuncNoWg := func() error { atomic.AddInt32(&hits, 1); return nil }
+ wg.Add(1)
+ pr.Do(upfunc)
+ pr.Do(upfuncNoWg)
+ pr.Do(upfuncNoWg)
+
+ wg.Wait()
+
+ h := atomic.LoadInt32(&hits)
+ if h != 1 {
+ t.Errorf("Expected hits to be %d, got %d", 1, h)
+ }
+}
diff --git a/osbase/supervisor/supervisor_testhelpers.go b/osbase/supervisor/supervisor_testhelpers.go
index ba015a2..cca93ff 100644
--- a/osbase/supervisor/supervisor_testhelpers.go
+++ b/osbase/supervisor/supervisor_testhelpers.go
@@ -96,7 +96,7 @@
}
}
- time.Sleep(time.Second)
+ time.Sleep(10 * time.Millisecond)
}
})
return ctxC, lt