osbase/net/dns/forward: add DNS forward handler

This adds a DNS server handler for forwarding queries to upstream DNS
resolvers, with a built-in cache. The implementation is partially based
on CoreDNS. The proxy, cache, and up packages are only lightly modified.
The forward package itself, however, is mostly new code. Unlike CoreDNS,
it supports changing upstreams at runtime and has integrated caching and
answer order randomization.

Some improvements over CoreDNS:
- Concurrent identical queries only result in one upstream query.
- In case of errors, Extended DNS Errors are added to replies.
- Very large replies are not stored in the cache to avoid using too much
  memory.
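
A minimal usage sketch (illustrative only: it assumes memory.Value's Set
method and supervisor.Run from the existing tree, omits registration of the
handler with the DNS server, and uses example addresses):

    f := forward.New()
    // Upstream resolvers can be changed again at runtime with the same call.
    f.DNSServers.Set([]string{"192.0.2.1:53", "192.0.2.2:53"})
    // Run the watcher which keeps the upstream proxies in sync.
    supervisor.Run(ctx, "forward", f.Run)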

Change-Id: I42294ae4997d621a6e55c98e46a04874eab75c99
Reviewed-on: https://review.monogon.dev/c/monogon/+/3258
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
Tested-by: Jenkins CI
diff --git a/osbase/net/dns/forward/forward.go b/osbase/net/dns/forward/forward.go
new file mode 100644
index 0000000..00271ec
--- /dev/null
+++ b/osbase/net/dns/forward/forward.go
@@ -0,0 +1,390 @@
+// Package forward implements a forwarding DNS proxy.
+//
+// A cache is used to reduce load on the upstream servers. Cached items are only
+// used for a short time, because otherwise we would need to provide a way to
+// flush the cache. The cache is most useful for absorbing the load of
+// applications which make very frequent, repeated queries. It also doubles as
+// a way to merge concurrent identical queries, since items are inserted into
+// the cache before the query is sent upstream (see also RFC 5452, Section 5).
+package forward
+
+// Taken and modified from the Forward plugin of CoreDNS, under Apache 2.0.
+
+import (
+	"context"
+	"errors"
+	"hash/maphash"
+	"math/rand/v2"
+	"os"
+	"slices"
+	"strconv"
+	"sync/atomic"
+	"time"
+
+	"github.com/miekg/dns"
+
+	"source.monogon.dev/osbase/event/memory"
+	"source.monogon.dev/osbase/net/dns/forward/cache"
+	"source.monogon.dev/osbase/net/dns/forward/proxy"
+	"source.monogon.dev/osbase/supervisor"
+)
+
+const (
+	connectionExpire    = 10 * time.Second
+	healthcheckInterval = 500 * time.Millisecond
+	forwardTimeout      = 5 * time.Second
+	maxFails            = 2
+	maxConcurrent       = 5000
+	maxUpstreams        = 15
+)
+
+// Forward is a DNS handler which proxies requests to upstream DNS servers.
+// It holds a list of proxies, each representing one upstream server.
+type Forward struct {
+	DNSServers memory.Value[[]string]
+	upstreams  atomic.Pointer[[]*proxy.Proxy]
+
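+	// concurrent is the number of queries currently being forwarded upstream.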
+	concurrent atomic.Int64
+
+	seed  maphash.Seed
+	cache *cache.Cache[*cacheItem]
+
+	// now can be used to override time for testing.
+	now func() time.Time
+}
+
+// New returns a new Forward.
+func New() *Forward {
+	return &Forward{
+		seed:  maphash.MakeSeed(),
+		cache: cache.New[*cacheItem](cacheCapacity),
+		now:   time.Now,
+	}
+}
+
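+// Run watches DNSServers and keeps the set of upstream proxies in sync with
+// the configured addresses, starting proxies for new addresses and stopping
+// proxies for removed ones.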
+func (f *Forward) Run(ctx context.Context) error {
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+
+	var lastAddrs []string
+	var upstreams []*proxy.Proxy
+
+	w := f.DNSServers.Watch()
+	defer w.Close()
+	for {
+		addrs, err := w.Get(ctx)
+		if err != nil {
+			for _, p := range upstreams {
+				p.Stop()
+			}
+			f.upstreams.Store(nil)
+			return err
+		}
+
+		if len(addrs) > maxUpstreams {
+			addrs = addrs[:maxUpstreams]
+		}
+
+		if slices.Equal(addrs, lastAddrs) {
+			continue
+		}
+		lastAddrs = addrs
+		supervisor.Logger(ctx).Infof("New upstream DNS servers: %s", addrs)
+
+		newAddrs := make(map[string]bool)
+		for _, addr := range addrs {
+			newAddrs[addr] = true
+		}
+		var newUpstreams []*proxy.Proxy
+		for _, p := range upstreams {
+			if newAddrs[p.Addr()] {
+				delete(newAddrs, p.Addr())
+				newUpstreams = append(newUpstreams, p)
+			} else {
+				p.Stop()
+			}
+		}
+		for newAddr := range newAddrs {
+			p := proxy.NewProxy(newAddr)
+			p.SetExpire(connectionExpire)
+			p.GetHealthchecker().SetRecursionDesired(true)
+			p.GetHealthchecker().SetDomain(".")
+			p.Start(healthcheckInterval)
+			newUpstreams = append(newUpstreams, p)
+		}
+		upstreams = newUpstreams
+		f.upstreams.Store(&newUpstreams)
+	}
+}
+
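+// proxyReply is the outcome of forwarding a query to the upstream proxies.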
+type proxyReply struct {
+	// NoStore indicates that the reply should not be stored in the cache.
+	// This could be because it is cheap to obtain or expensive to store.
+	NoStore bool
+
+	Truncated bool
+	Rcode     int
+	Answer    []dns.RR
+	Ns        []dns.RR
+	Extra     []dns.RR
+	Options   []dns.EDNS0
+}
+
+var (
+	replyConcurrencyLimit = proxyReply{
+		NoStore: true,
+		Rcode:   dns.RcodeServerFailure,
+		Options: []dns.EDNS0{&dns.EDNS0_EDE{
+			InfoCode:  dns.ExtendedErrorCodeOther,
+			ExtraText: "too many concurrent queries",
+		}},
+	}
+	replyNoUpstreams = proxyReply{
+		NoStore: true,
+		Rcode:   dns.RcodeRefused,
+		Options: []dns.EDNS0{&dns.EDNS0_EDE{
+			InfoCode:  dns.ExtendedErrorCodeOther,
+			ExtraText: "no upstream DNS servers configured",
+		}},
+	}
+	replyProtocolError = proxyReply{
+		Rcode: dns.RcodeServerFailure,
+		Options: []dns.EDNS0{&dns.EDNS0_EDE{
+			InfoCode:  dns.ExtendedErrorCodeNetworkError,
+			ExtraText: "DNS protocol error when querying upstream DNS server",
+		}},
+	}
+	replyTimeout = proxyReply{
+		Rcode: dns.RcodeServerFailure,
+		Options: []dns.EDNS0{&dns.EDNS0_EDE{
+			InfoCode:  dns.ExtendedErrorCodeNetworkError,
+			ExtraText: "timeout when querying upstream DNS server",
+		}},
+	}
+	replyNetworkError = proxyReply{
+		Rcode: dns.RcodeServerFailure,
+		Options: []dns.EDNS0{&dns.EDNS0_EDE{
+			InfoCode:  dns.ExtendedErrorCodeNetworkError,
+			ExtraText: "network error when querying upstream DNS server",
+		}},
+	}
+)
+
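+// queryProxies forwards the question to the upstream proxies, trying them in
+// random order and retrying on errors within the forward timeout, and returns
+// the validated reply as a proxyReply.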
+func (f *Forward) queryProxies(
+	question dns.Question,
+	dnssec bool,
+	checkingDisabled bool,
+	queryOptions []dns.EDNS0,
+	useTCP bool,
+) proxyReply {
+	count := f.concurrent.Add(1)
+	defer f.concurrent.Add(-1)
+	if count > maxConcurrent {
+		rejectsCount.WithLabelValues("concurrency_limit").Inc()
+		return replyConcurrencyLimit
+	}
+
+	// Construct outgoing query.
+	qopt := new(dns.OPT)
+	qopt.Hdr.Name = "."
+	qopt.Hdr.Rrtype = dns.TypeOPT
+	qopt.SetUDPSize(proxy.AdvertiseUDPSize)
+	if dnssec {
+		qopt.SetDo()
+	}
+	qopt.Option = queryOptions
+	m := &dns.Msg{
+		MsgHdr: dns.MsgHdr{
+			Opcode:           dns.OpcodeQuery,
+			RecursionDesired: true,
+			CheckingDisabled: checkingDisabled,
+		},
+		Question: []dns.Question{question},
+		Extra:    []dns.RR{qopt},
+	}
+
+	var list []*proxy.Proxy
+	if upstreams := f.upstreams.Load(); upstreams != nil {
+		list = randomList(*upstreams)
+	}
+
+	if len(list) == 0 {
+		rejectsCount.WithLabelValues("no_upstreams").Inc()
+		return replyNoUpstreams
+	}
+
+	proto := "udp"
+	if useTCP {
+		proto = "tcp"
+	}
+
+	var (
+		curUpstream *proxy.Proxy
+		curStart    time.Time
+		ret         *dns.Msg
+		err         error
+	)
+	recordDuration := func(rcode string) {
+		upstreamDuration.WithLabelValues(curUpstream.Addr(), proto, rcode).Observe(time.Since(curStart).Seconds())
+	}
+
+	fails := 0
+	i := 0
+	listStart := time.Now()
+	deadline := listStart.Add(forwardTimeout)
+	for {
+		if i >= len(list) {
+			// Reached the end of the list; start over from the beginning.
+			i = 0
+			fails = 0
+
+			// Sleep for a bit if the last time we started going through the list was
+			// very recent.
+			time.Sleep(time.Until(listStart.Add(time.Second)))
+			listStart = time.Now()
+		}
+
+		curUpstream = list[i]
+		i++
+		if curUpstream.Down(maxFails) {
+			fails++
+			if fails < len(list) {
+				continue
+			}
+			// All upstream proxies are down; assume the healthcheck is completely
+			// broken and connect to a random upstream anyway.
+			healthcheckBrokenCount.Inc()
+		}
+
+		curStart = time.Now()
+
+		for {
+			ret, err = curUpstream.Connect(m, useTCP)
+
+			if errors.Is(err, proxy.ErrCachedClosed) {
+				// Remote side closed conn, can only happen with TCP.
+				continue
+			}
+			break
+		}
+
+		if err != nil {
+			// Kick off health check to see if *our* upstream is broken.
+			curUpstream.Healthcheck()
+
+			retry := fails < len(list) && time.Now().Before(deadline)
+			var dnsError *dns.Error
+			switch {
+			case errors.Is(err, os.ErrDeadlineExceeded):
+				recordDuration("timeout")
+				if !retry {
+					return replyTimeout
+				}
+			case errors.As(err, &dnsError):
+				recordDuration("protocol_error")
+				if !retry {
+					return replyProtocolError
+				}
+			default:
+				recordDuration("network_error")
+				if !retry {
+					return replyNetworkError
+				}
+			}
+			continue
+		}
+
+		break
+	}
+
+	if !ret.Response || ret.Opcode != dns.OpcodeQuery || len(ret.Question) != 1 {
+		recordDuration("protocol_error")
+		return replyProtocolError
+	}
+
+	if ret.Truncated && useTCP {
+		recordDuration("protocol_error")
+		return replyProtocolError
+	}
+	if ret.Truncated {
+		proto = "udp_truncated"
+	}
+
+	// Check that the reply matches the question.
+	retq := ret.Question[0]
+	if retq.Qtype != question.Qtype || retq.Qclass != question.Qclass ||
+		(retq.Name != question.Name && dns.CanonicalName(retq.Name) != question.Name) {
+		recordDuration("protocol_error")
+		return replyProtocolError
+	}
+
+	// Extract OPT from reply.
+	var ropt *dns.OPT
+	var options []dns.EDNS0
+	for i := len(ret.Extra) - 1; i >= 0; i-- {
+		if rr, ok := ret.Extra[i].(*dns.OPT); ok {
+			if ropt != nil {
+				// Found more than one OPT.
+				recordDuration("protocol_error")
+				return replyProtocolError
+			}
+			ropt = rr
+			ret.Extra = append(ret.Extra[:i], ret.Extra[i+1:]...)
+		}
+	}
+	if ropt != nil {
+		if ropt.Version() != 0 || ropt.Hdr.Name != "." {
+			recordDuration("protocol_error")
+			return replyProtocolError
+		}
+		// Forward Extended DNS Error options.
+		for _, option := range ropt.Option {
+			switch option.(type) {
+			case *dns.EDNS0_EDE:
+				options = append(options, option)
+			}
+		}
+	}
+
+	rcode, ok := dns.RcodeToString[ret.Rcode]
+	if !ok {
+		// There are 4096 possible Rcodes, so it's probably still fine
+		// to put it in a metric label.
+		rcode = strconv.Itoa(ret.Rcode)
+	}
+	recordDuration(rcode)
+
+	// AuthenticatedData is intentionally not copied from the proxy reply because
+	// we don't have a secure channel to the proxy.
+	return proxyReply{
+		// Don't store large messages in the cache. Such large messages are very
+		// rare, and this protects against the cache using huge amounts of memory.
+		// DNS messages over TCP can be up to 64 KB in size, and after decompression
+		// this could go over 1 MB of memory usage.
+		NoStore: ret.Len() > cacheMaxItemSize,
+
+		Truncated: ret.Truncated,
+		Rcode:     ret.Rcode,
+		Answer:    ret.Answer,
+		Ns:        ret.Ns,
+		Extra:     ret.Extra,
+		Options:   options,
+	}
+}
+
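+// randomList returns the proxies in random order, spreading queries across
+// all upstreams.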
+func randomList(p []*proxy.Proxy) []*proxy.Proxy {
+	switch len(p) {
+	case 1:
+		return p
+	case 2:
+		if rand.Int()%2 == 0 {
+			return []*proxy.Proxy{p[1], p[0]} // swap
+		}
+		return p
+	}
+
+	shuffled := slices.Clone(p)
+	rand.Shuffle(len(shuffled), func(i, j int) {
+		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
+	})
+	return shuffled
+}