| Jan Schär | 75ea9f4 | 2024-07-29 17:01:41 +0200 | [diff] [blame^] | 1 | package forward |
| 2 | |
| 3 | import ( |
| 4 | "hash/maphash" |
| 5 | "math/rand/v2" |
| 6 | "slices" |
| 7 | "sync" |
| 8 | "time" |
| 9 | |
| 10 | "github.com/miekg/dns" |
| 11 | |
| 12 | netDNS "source.monogon.dev/osbase/net/dns" |
| 13 | ) |
| 14 | |
// The cache uses at most cacheMaxItemSize * cacheCapacity = 20 MB of memory.
// Actual memory usage may be slightly higher due to the overhead of in-memory
// data structures compared to the serialized, uncompressed length.
const (
	// cacheMaxItemSize is the per-item size bound in bytes (presumably the
	// serialized reply length; the check itself is enforced outside this
	// chunk — confirm against the cache implementation).
	cacheMaxItemSize = 2048
	// cacheCapacity is the maximum number of cache items (enforced by
	// f.cache, outside this chunk).
	cacheCapacity = 10000
	// cacheMinSeconds and cacheMaxSeconds clamp the computed cache lifetime
	// of an item, regardless of the record TTLs in the reply (see the ttl
	// computation in forward).
	cacheMinSeconds = 1
	cacheMaxSeconds = 5
)
| 24 | |
// cacheKey is the key used for cache lookups. Both the DNSSEC ok and the
// Checking Disabled flag influence the reply. While it would be possible to
// always make upstream queries with DNSSEC, and then strip the authenticating
// records if the client did not request it, this would mostly just waste
// bandwidth. In theory, it would be possible to cache NXDOMAINs independently
// of the QTYPE (RFC 2308, Section 5). However, the additional complexity and
// second lookup for each query does not seem worth it.
type cacheKey struct {
	// Name is the canonical query name (taken from r.QnameCanonical).
	Name string
	// Qtype is the query type.
	Qtype uint16
	// DNSSEC is the DNSSEC OK (DO) bit from the query's OPT record.
	DNSSEC bool
	// CheckingDisabled is the CD flag from the query header.
	CheckingDisabled bool
}
| 38 | |
// cacheItem is a single cache entry: the most recent upstream reply for a
// cacheKey, together with the metadata needed to decide whether the reply
// can still be served.
type cacheItem struct {
	key cacheKey

	// lock protects all fields except key. It also doubles as a way to wait for
	// the reply. A write lock is held for as long as a query is pending.
	lock sync.RWMutex

	// reply is the cached upstream reply.
	reply proxyReply
	// stored is the time at which reply was received and stored.
	stored time.Time
	// ttl is the number of seconds during which the cached reply can be used.
	ttl uint32
	// seenTruncated is true if we ever saw a truncated response for this key.
	// We will then always use TCP when refetching after the item expires.
	seenTruncated bool
}
| 54 | |
| 55 | func (k cacheKey) hash(seed maphash.Seed) uint64 { |
| 56 | var h maphash.Hash |
| 57 | h.SetSeed(seed) |
| 58 | h.WriteByte(byte(k.Qtype)) |
| 59 | h.WriteByte(byte(k.Qtype >> 8)) |
| 60 | var flags byte |
| 61 | if k.DNSSEC { |
| 62 | flags += 1 |
| 63 | } |
| 64 | if k.CheckingDisabled { |
| 65 | flags += 2 |
| 66 | } |
| 67 | h.WriteByte(flags) |
| 68 | h.WriteString(k.Name) |
| 69 | return h.Sum64() |
| 70 | } |
| 71 | |
| 72 | // valid returns true if the cache item can be used for this query. |
| 73 | func (i *cacheItem) valid(now time.Time, tcp bool) bool { |
| 74 | expired := now.After(i.stored.Add(time.Duration(i.ttl) * time.Second)) |
| 75 | return !expired && (!tcp || !i.reply.Truncated) |
| 76 | } |
| 77 | |
| 78 | func (f *Forward) HandleDNS(r *netDNS.Request) { |
| 79 | if !r.Reply.RecursionDesired { |
| 80 | // Only forward queries if the RD flag is set. If the question has been |
| 81 | // redirected by CNAME, return the reply as is without following the CNAME, |
| 82 | // else set a REFUSED rcode. |
| 83 | if r.Qname == r.Reply.Question[0].Name { |
| 84 | r.Reply.Rcode = dns.RcodeRefused |
| 85 | rejectsCount.WithLabelValues("no_recursion_desired").Inc() |
| 86 | } |
| 87 | } else { |
| 88 | f.lookupOrForward(r) |
| 89 | } |
| 90 | r.SendReply() |
| 91 | } |
| 92 | |
// lookupOrForward serves the query from the cache when a valid entry exists,
// and otherwise forwards it upstream via forward, which also stores the fresh
// reply in the cache item for subsequent queries.
func (f *Forward) lookupOrForward(r *netDNS.Request) {
	key := cacheKey{
		Name:             r.QnameCanonical,
		Qtype:            r.Qtype,
		DNSSEC:           r.Ropt != nil && r.Ropt.Do(),
		CheckingDisabled: r.Reply.CheckingDisabled,
	}
	hash := key.hash(f.seed)
	tcp := r.Writer.RemoteAddr().Network() == "tcp"

	item, exists := f.cache.Get(hash)
	if !exists {
		// The lookup failed; allocate a new item and try to insert it.
		// Lock the new item before inserting it, such that concurrent queries
		// are blocked until we receive the reply and store it in the item.
		newItem := &cacheItem{key: key}
		newItem.lock.Lock()
		item, exists = f.cache.GetOrPut(hash, newItem)
		if !exists {
			cacheLookupsCount.WithLabelValues("miss").Inc()
			f.forward(r, newItem, hash, tcp)
			newItem.lock.Unlock()
			return
		}
		// Another goroutine inserted an item for this hash concurrently;
		// fall through and use that item instead of ours.
	}
	if item.key != key {
		// We have a hash collision. Replace the existing item.
		cacheLookupsCount.WithLabelValues("miss").Inc()
		newItem := &cacheItem{key: key}
		newItem.lock.Lock()
		f.cache.Put(hash, newItem)
		f.forward(r, newItem, hash, tcp)
		newItem.lock.Unlock()
		return
	}

	// Take a read lock and check if the reply is valid for this query.
	// This blocks if a query for this item is pending.
	item.lock.RLock()
	now := f.now()
	if item.valid(now, tcp) {
		replyFromCache(r, item, now)
		item.lock.RUnlock()
		return
	}
	item.lock.RUnlock()

	// The cached reply is unusable for this query. Upgrade to a write lock
	// and re-check validity: another goroutine may have refreshed the item
	// in the window between RUnlock and Lock.
	item.lock.Lock()
	now = f.now()
	if item.valid(now, tcp) {
		replyFromCache(r, item, now)
		item.lock.Unlock()
		return
	}
	cacheLookupsCount.WithLabelValues("refresh").Inc()
	// If a truncated response was ever seen for this key, refetch over TCP
	// directly (see cacheItem.seenTruncated).
	f.forward(r, item, hash, tcp || item.seenTruncated)
	item.lock.Unlock()
}
| 151 | |
// forward queries the upstream proxies for the question described by
// item.key, copies the reply into the request r, stores it in the cache
// item, and computes the item's cache lifetime. The caller must hold
// item.lock for writing.
func (f *Forward) forward(r *netDNS.Request, item *cacheItem, hash uint64, tcp bool) {
	// Query proxies.
	var queryOptions []dns.EDNS0
	if r.Qopt != nil {
		// Forward DNSSEC algorithm understood options. These are only for
		// statistics and must not influence the reply, so we do not need to include
		// them in the cache key.
		for _, option := range r.Qopt.Option {
			switch option.(type) {
			case *dns.EDNS0_DAU, *dns.EDNS0_DHU, *dns.EDNS0_N3U:
				queryOptions = append(queryOptions, option)
			}
		}
	}

	question := dns.Question{
		Name:   item.key.Name,
		Qtype:  item.key.Qtype,
		Qclass: dns.ClassINET,
	}
	reply := f.queryProxies(question, item.key.DNSSEC, item.key.CheckingDisabled, queryOptions, tcp)

	// Copy the reply into the request. TTLs are used as-is since the reply
	// is fresh (contrast with replyFromCache, which decrements them).
	r.Reply.Truncated = reply.Truncated
	r.Reply.Rcode = reply.Rcode
	r.Reply.Answer = appendOrClip(r.Reply.Answer, reply.Answer)
	r.Reply.Ns = appendOrClip(r.Reply.Ns, reply.Ns)
	r.Reply.Extra = appendOrClip(r.Reply.Extra, reply.Extra)
	if r.Ropt != nil {
		r.Ropt.Option = appendOrClip(r.Ropt.Option, reply.Options)
	}

	// Store the reply in the cache item.
	item.reply = reply
	if reply.Truncated {
		// Remember the truncation so future refetches go straight to TCP.
		item.seenTruncated = true
	}
	item.stored = f.now()

	// Compute how long to cache the item.
	ttl := uint32(cacheMaxSeconds)
	// If the reply is an error, or contains no ttls, use the minimum cache time.
	if (reply.Rcode != dns.RcodeSuccess && reply.Rcode != dns.RcodeNameError) ||
		len(reply.Answer)+len(reply.Ns)+len(reply.Extra) == 0 {
		ttl = cacheMinSeconds
	}
	// Never cache longer than the smallest record TTL in the reply.
	for _, rr := range reply.Answer {
		ttl = min(ttl, rr.Header().Ttl)
	}
	for _, rr := range reply.Ns {
		ttl = min(ttl, rr.Header().Ttl)
	}
	for _, rr := range reply.Extra {
		ttl = min(ttl, rr.Header().Ttl)
	}
	// Clamp to at least cacheMinSeconds, since a record TTL of 0 would
	// otherwise make the item expire immediately.
	item.ttl = max(ttl, cacheMinSeconds)

	if reply.NoStore {
		// The reply must not be served from the cache; drop the item we just
		// populated. NOTE(review): NoStore is set by queryProxies, outside
		// this chunk — confirm its exact semantics there.
		f.cache.Remove(hash)
	}
}
| 211 | |
| 212 | func replyFromCache(r *netDNS.Request, item *cacheItem, now time.Time) { |
| 213 | cacheLookupsCount.WithLabelValues("hit").Inc() |
| 214 | decrementTtl := uint32(max(0, now.Sub(item.stored)/time.Second)) |
| 215 | |
| 216 | r.Reply.Truncated = item.reply.Truncated |
| 217 | r.Reply.Rcode = item.reply.Rcode |
| 218 | |
| 219 | existing_len := len(r.Reply.Answer) |
| 220 | r.Reply.Answer = appendCached(r.Reply.Answer, item.reply.Answer, decrementTtl) |
| 221 | shuffleAnswer(r.Reply.Answer[existing_len:]) |
| 222 | r.Reply.Ns = appendCached(r.Reply.Ns, item.reply.Ns, decrementTtl) |
| 223 | r.Reply.Extra = appendCached(r.Reply.Extra, item.reply.Extra, decrementTtl) |
| 224 | if r.Ropt != nil { |
| 225 | r.Ropt.Option = appendOrClip(r.Ropt.Option, item.reply.Options) |
| 226 | } |
| 227 | } |
| 228 | |
| 229 | func appendCached(existing, add []dns.RR, decrementTtl uint32) []dns.RR { |
| 230 | existing = slices.Grow(existing, len(add)) |
| 231 | for _, rr := range add { |
| 232 | decRR := dns.Copy(rr) |
| 233 | hdr := decRR.Header() |
| 234 | if hdr.Ttl == 0 { |
| 235 | } else if decrementTtl >= hdr.Ttl { |
| 236 | // Don't decrement the TTL to 0, as that could cause problems. |
| 237 | // https://00f.net/2011/11/17/how-long-does-a-dns-ttl-last/ |
| 238 | hdr.Ttl = 1 |
| 239 | } else { |
| 240 | hdr.Ttl = hdr.Ttl - decrementTtl |
| 241 | } |
| 242 | existing = append(existing, decRR) |
| 243 | } |
| 244 | return existing |
| 245 | } |
| 246 | |
| 247 | // shuffleAnswer randomizes the order of consecutive RRs which are part of the |
| 248 | // same RRset. This provides some load balancing. |
| 249 | func shuffleAnswer(rrs []dns.RR) { |
| 250 | if len(rrs) < 2 { |
| 251 | return |
| 252 | } |
| 253 | startIndex := 0 |
| 254 | startHdr := rrs[0].Header() |
| 255 | for i := 1; i < len(rrs); i++ { |
| 256 | hdr := rrs[i].Header() |
| 257 | sameRRset := startHdr.Rrtype == hdr.Rrtype && |
| 258 | startHdr.Class == hdr.Class && |
| 259 | startHdr.Name == hdr.Name |
| 260 | if sameRRset { |
| 261 | swap := startIndex + rand.IntN(i+1-startIndex) |
| 262 | rrs[i], rrs[swap] = rrs[swap], rrs[i] |
| 263 | } else { |
| 264 | startIndex = i |
| 265 | startHdr = hdr |
| 266 | } |
| 267 | } |
| 268 | } |
| 269 | |
| 270 | // appendOrClip is similar to append(a, b...) except that it avoids allocation |
| 271 | // if a is empty, in which case it returns b with any free capacity removed. |
| 272 | // The resulting slice can still be appended to without affecting b. |
| 273 | func appendOrClip[S ~[]E, E any](a, b S) S { |
| 274 | if len(a) == 0 { |
| 275 | return slices.Clip(b) |
| 276 | } |
| 277 | return append(a, b...) |
| 278 | } |