osbase/net/dns/forward: add DNS forward handler
This adds a DNS server handler for forwarding queries to upstream DNS
resolvers, with a built-in cache. The implementation is partially based
on CoreDNS. The proxy, cache and up packages are only lightly modified.
The forward package itself, however, is mostly new code. Unlike CoreDNS,
it supports changing upstreams at runtime, and has integrated caching
and answer order randomization.
Some improvements over CoreDNS:
- Concurrent identical queries only result in one upstream query.
- In case of errors, Extended DNS Errors are added to replies.
- Very large replies are not stored in the cache to avoid using too much
memory.
Change-Id: I42294ae4997d621a6e55c98e46a04874eab75c99
Reviewed-on: https://review.monogon.dev/c/monogon/+/3258
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
Tested-by: Jenkins CI
diff --git a/osbase/net/dns/forward/cache.go b/osbase/net/dns/forward/cache.go
new file mode 100644
index 0000000..d2422a9
--- /dev/null
+++ b/osbase/net/dns/forward/cache.go
@@ -0,0 +1,278 @@
+package forward
+
+import (
+ "hash/maphash"
+ "math/rand/v2"
+ "slices"
+ "sync"
+ "time"
+
+ "github.com/miekg/dns"
+
+ netDNS "source.monogon.dev/osbase/net/dns"
+)
+
+// The cache uses at most cacheMaxItemSize * cacheCapacity = 20 MB of memory.
+// Actual memory usage may be slightly higher due to the overhead of in-memory
+// data structures compared to the serialized, uncompressed length of a reply.
+const (
+ cacheMaxItemSize = 2048 // presumably the size limit in bytes of a cached reply; confirm at use site
+ cacheCapacity = 10000 // presumably the maximum number of cached items; confirm at use site
+ cacheMinSeconds = 1 // lower bound for the number of seconds an item is cached
+ cacheMaxSeconds = 5 // upper bound for the number of seconds an item is cached
+)
+
+// cacheKey is the key used for cache lookups. Both the DNSSEC OK and the
+// Checking Disabled flag influence the reply, so both are part of the key.
+// While it would be possible to always make upstream queries with DNSSEC, and
+// then strip the authenticating records if the client did not request them,
+// this would mostly just waste bandwidth. In theory, NXDOMAINs could be cached
+// independently of the QTYPE (RFC 2308, Section 5). However, the additional
+// complexity and second lookup for each query do not seem worth it.
+type cacheKey struct {
+ Name string // query name, filled from the request's QnameCanonical
+ Qtype uint16 // query type
+ DNSSEC bool // true if the request's Ropt is present and reports Do()
+ CheckingDisabled bool // Checking Disabled flag, filled from the request's Reply
+}
+
+type cacheItem struct {
+ key cacheKey // compared against the looked-up key to detect hash collisions
+
+ // lock protects all fields except key. It also doubles as a way to wait for
+ // the reply: a write lock is held for as long as a query is pending.
+ lock sync.RWMutex
+
+ reply proxyReply // reply most recently stored by forward
+ stored time.Time // time at which reply was stored
+ // ttl is the number of seconds after stored during which reply can be used.
+ ttl uint32
+ // seenTruncated is true if we ever saw a truncated response for this key.
+ // We will then always use TCP when refetching after the item expires.
+ seenTruncated bool
+}
+
+// hash returns a 64-bit hash of the key, computed with the given seed.
+// The hashed byte sequence is: the low byte of Qtype, the high byte of Qtype,
+// one flags byte (bit 0 for DNSSEC, bit 1 for CheckingDisabled), and then the
+// bytes of Name.
+func (k cacheKey) hash(seed maphash.Seed) uint64 {
+	var flags byte
+	if k.DNSSEC {
+		flags |= 1 << 0
+	}
+	if k.CheckingDisabled {
+		flags |= 1 << 1
+	}
+
+	var hasher maphash.Hash
+	hasher.SetSeed(seed)
+	// maphash only depends on the sequence of bytes written, not on how the
+	// bytes are split across calls, so the fixed-size prefix goes in at once.
+	hasher.Write([]byte{byte(k.Qtype), byte(k.Qtype >> 8), flags})
+	hasher.WriteString(k.Name)
+	return hasher.Sum64()
+}
+
+// valid reports whether the stored reply can be used to answer a query at
+// time now. The reply is unusable once more than ttl seconds have passed since
+// it was stored. If tcp is true, a truncated reply is unusable as well.
+func (i *cacheItem) valid(now time.Time, tcp bool) bool {
+	expiry := i.stored.Add(time.Duration(i.ttl) * time.Second)
+	if now.After(expiry) {
+		return false
+	}
+	if tcp && i.reply.Truncated {
+		return false
+	}
+	return true
+}
+
+// HandleDNS handles the request by answering it from the cache or forwarding
+// it, and then sends the reply. Queries are only forwarded if the RD flag is
+// set on the reply.
+func (f *Forward) HandleDNS(r *netDNS.Request) {
+	switch {
+	case r.Reply.RecursionDesired:
+		f.lookupOrForward(r)
+	case r.Qname == r.Reply.Question[0].Name:
+		// Recursion is not desired and the question has not been redirected
+		// by CNAME, so there is nothing to return: set a REFUSED rcode.
+		r.Reply.Rcode = dns.RcodeRefused
+		rejectsCount.WithLabelValues("no_recursion_desired").Inc()
+	}
+	// In the remaining case, the question has been redirected by CNAME and
+	// recursion is not desired. The reply is returned as is, without following
+	// the CNAME.
+	r.SendReply()
+}
+
+// lookupOrForward fills in the reply of r, either from a valid cache item or
+// by forwarding the query and storing the result in the cache. While a query
+// for an item is pending, the write lock of that item is held, such that
+// concurrent identical queries wait for the result instead of being forwarded
+// as well.
+func (f *Forward) lookupOrForward(r *netDNS.Request) {
+	key := cacheKey{
+		Name:             r.QnameCanonical,
+		Qtype:            r.Qtype,
+		DNSSEC:           r.Ropt != nil && r.Ropt.Do(),
+		CheckingDisabled: r.Reply.CheckingDisabled,
+	}
+	hash := key.hash(f.seed)
+	tcp := r.Writer.RemoteAddr().Network() == "tcp"
+
+	item, found := f.cache.Get(hash)
+	if !found {
+		// Nothing is cached under this hash. Prepare a new item which is
+		// already locked when it becomes visible to concurrent queries, so
+		// that they block until the reply has been stored in it.
+		fresh := &cacheItem{key: key}
+		fresh.lock.Lock()
+		if item, found = f.cache.GetOrPut(hash, fresh); !found {
+			cacheLookupsCount.WithLabelValues("miss").Inc()
+			f.forward(r, fresh, hash, tcp)
+			fresh.lock.Unlock()
+			return
+		}
+		// Another item was inserted in the meantime; continue with that one.
+	}
+
+	if item.key != key {
+		// Hash collision: the cached item belongs to a different key.
+		// Replace it with a new item for this key.
+		cacheLookupsCount.WithLabelValues("miss").Inc()
+		replacement := &cacheItem{key: key}
+		replacement.lock.Lock()
+		f.cache.Put(hash, replacement)
+		f.forward(r, replacement, hash, tcp)
+		replacement.lock.Unlock()
+		return
+	}
+
+	// Fast path: check under a read lock if the cached reply is usable for
+	// this query. Taking the lock blocks while a query for the item is pending.
+	item.lock.RLock()
+	now := f.now()
+	usable := item.valid(now, tcp)
+	if usable {
+		replyFromCache(r, item, now)
+	}
+	item.lock.RUnlock()
+	if usable {
+		return
+	}
+
+	// Slow path: take the write lock and check again, because the item may
+	// have been refreshed between releasing the read lock and getting here.
+	item.lock.Lock()
+	if now = f.now(); item.valid(now, tcp) {
+		replyFromCache(r, item, now)
+	} else {
+		cacheLookupsCount.WithLabelValues("refresh").Inc()
+		f.forward(r, item, hash, tcp || item.seenTruncated)
+	}
+	item.lock.Unlock()
+}
+
+// forward queries the proxies for the key of item, copies the result into
+// the reply of r, and stores it in item together with the time and the number
+// of seconds for which it may be used. All callers in this file hold the
+// write lock of item. hash is the cache hash under which item is stored; it
+// is used to remove the item again if the reply is marked NoStore. tcp is
+// passed on to queryProxies.
+func (f *Forward) forward(r *netDNS.Request, item *cacheItem, hash uint64, tcp bool) {
+	// Forward DNSSEC algorithm understood options. These are only for
+	// statistics and must not influence the reply, so they do not need to be
+	// part of the cache key.
+	var queryOptions []dns.EDNS0
+	if r.Qopt != nil {
+		for _, opt := range r.Qopt.Option {
+			switch opt.(type) {
+			case *dns.EDNS0_DAU, *dns.EDNS0_DHU, *dns.EDNS0_N3U:
+				queryOptions = append(queryOptions, opt)
+			}
+		}
+	}
+
+	// Query proxies.
+	reply := f.queryProxies(dns.Question{
+		Name:   item.key.Name,
+		Qtype:  item.key.Qtype,
+		Qclass: dns.ClassINET,
+	}, item.key.DNSSEC, item.key.CheckingDisabled, queryOptions, tcp)
+
+	// Copy the result into the reply for this request.
+	r.Reply.Truncated = reply.Truncated
+	r.Reply.Rcode = reply.Rcode
+	r.Reply.Answer = appendOrClip(r.Reply.Answer, reply.Answer)
+	r.Reply.Ns = appendOrClip(r.Reply.Ns, reply.Ns)
+	r.Reply.Extra = appendOrClip(r.Reply.Extra, reply.Extra)
+	if r.Ropt != nil {
+		r.Ropt.Option = appendOrClip(r.Ropt.Option, reply.Options)
+	}
+
+	// Store the result in the cache item.
+	item.reply = reply
+	item.seenTruncated = item.seenTruncated || reply.Truncated
+	item.stored = f.now()
+
+	// Compute how long to cache the item: start at the maximum, or at the
+	// minimum if the reply is an error or contains no records and thus no
+	// ttls. Then lower to the smallest record ttl, but not below the minimum.
+	ttl := uint32(cacheMaxSeconds)
+	isError := reply.Rcode != dns.RcodeSuccess && reply.Rcode != dns.RcodeNameError
+	if isError || len(reply.Answer)+len(reply.Ns)+len(reply.Extra) == 0 {
+		ttl = cacheMinSeconds
+	}
+	for _, section := range [][]dns.RR{reply.Answer, reply.Ns, reply.Extra} {
+		for _, rr := range section {
+			ttl = min(ttl, rr.Header().Ttl)
+		}
+	}
+	item.ttl = max(ttl, cacheMinSeconds)
+
+	// Replies marked as NoStore are removed from the cache again.
+	if reply.NoStore {
+		f.cache.Remove(hash)
+	}
+}
+
+// replyFromCache fills in the reply of r from the reply stored in item.
+// Records are copied, and their ttls are reduced by the number of whole
+// seconds which have passed between storing the item and now. The answer
+// records taken from the cache are shuffled; records which were already in
+// the answer section keep their position.
+func replyFromCache(r *netDNS.Request, item *cacheItem, now time.Time) {
+	cacheLookupsCount.WithLabelValues("hit").Inc()
+	age := uint32(max(0, now.Sub(item.stored)/time.Second))
+	cached := &item.reply
+
+	r.Reply.Truncated = cached.Truncated
+	r.Reply.Rcode = cached.Rcode
+
+	answerStart := len(r.Reply.Answer)
+	r.Reply.Answer = appendCached(r.Reply.Answer, cached.Answer, age)
+	shuffleAnswer(r.Reply.Answer[answerStart:])
+	r.Reply.Ns = appendCached(r.Reply.Ns, cached.Ns, age)
+	r.Reply.Extra = appendCached(r.Reply.Extra, cached.Extra, age)
+	if r.Ropt != nil {
+		r.Ropt.Option = appendOrClip(r.Ropt.Option, cached.Options)
+	}
+}
+
+// appendCached appends copies of the records in add to existing and returns
+// the resulting slice. The ttl of each copy is reduced by decrementTtl
+// seconds. A ttl of 0 is left unchanged, and a ttl which would drop to 0 or
+// below is set to 1 instead.
+func appendCached(existing, add []dns.RR, decrementTtl uint32) []dns.RR {
+	existing = slices.Grow(existing, len(add))
+	for _, rr := range add {
+		clone := dns.Copy(rr)
+		hdr := clone.Header()
+		switch {
+		case hdr.Ttl == 0:
+			// Nothing to decrement.
+		case hdr.Ttl <= decrementTtl:
+			// Don't decrement the TTL to 0, as that could cause problems.
+			// https://00f.net/2011/11/17/how-long-does-a-dns-ttl-last/
+			hdr.Ttl = 1
+		default:
+			hdr.Ttl -= decrementTtl
+		}
+		existing = append(existing, clone)
+	}
+	return existing
+}
+
+// shuffleAnswer shuffles rrs in place, such that each run of consecutive
+// RRs belonging to the same RRset (same name, type and class) ends up in
+// random order. RRs never move out of their run. This provides some load
+// balancing.
+func shuffleAnswer(rrs []dns.RR) {
+	if len(rrs) < 2 {
+		return
+	}
+	runStart, runHdr := 0, rrs[0].Header()
+	for i := 1; i < len(rrs); i++ {
+		hdr := rrs[i].Header()
+		if hdr.Rrtype != runHdr.Rrtype || hdr.Class != runHdr.Class || hdr.Name != runHdr.Name {
+			// A new RRset begins at this position.
+			runStart, runHdr = i, hdr
+			continue
+		}
+		// Swap the RR with a random position in the run so far, which
+		// includes its own position.
+		j := runStart + rand.IntN(i-runStart+1)
+		rrs[i], rrs[j] = rrs[j], rrs[i]
+	}
+}
+
+// appendOrClip returns the concatenation of a and b, like append(a, b...).
+// However, if a is empty, no new slice is allocated; instead, b is returned
+// with its capacity reduced to its length. Because of that, appending to the
+// result never writes into the backing array of b.
+func appendOrClip[S ~[]E, E any](a, b S) S {
+	if len(a) > 0 {
+		return append(a, b...)
+	}
+	return b[:len(b):len(b)]
+}