blob: 55697ca40243ca8e69041aed8999eff9cef0e3c0 [file] [log] [blame]
// Copyright The Monogon Project Authors.
// SPDX-License-Identifier: Apache-2.0
3
Jan Schär75ea9f42024-07-29 17:01:41 +02004package forward
5
6import (
7 "hash/maphash"
8 "math/rand/v2"
9 "slices"
10 "sync"
11 "time"
12
13 "github.com/miekg/dns"
14
15 netDNS "source.monogon.dev/osbase/net/dns"
16)
17
// The cache uses at most cacheMaxItemSize * cacheCapacity = 20 MB of memory.
// Actual memory usage may be slightly higher due to the overhead of in-memory
// data structures compared to the serialized, uncompressed length.
const (
	// cacheMaxItemSize is the per-item size, in bytes of serialized,
	// uncompressed length, assumed by the memory estimate above.
	cacheMaxItemSize = 2048
	// cacheCapacity is the number of items assumed by the memory estimate
	// above.
	cacheCapacity = 10000
	// cacheMinSeconds is the lower bound, in seconds, for how long a reply is
	// cached. It is also used for error replies and replies without records.
	cacheMinSeconds = 1
	// cacheMaxSeconds is the upper bound, in seconds, for how long a reply is
	// cached, regardless of the TTLs of the records it contains.
	cacheMaxSeconds = 5
)
27
// cacheKey is the key used for cache lookups. Both the DNSSEC ok and the
// Checking Disabled flag influence the reply. While it would be possible to
// always make upstream queries with DNSSEC, and then strip the authenticating
// records if the client did not request it, this would mostly just waste
// bandwidth. In theory, it would be possible to cache NXDOMAINs independently
// of the QTYPE (RFC 2308, Section 5). However, the additional complexity and
// second lookup for each query does not seem worth it.
//
// All fields are comparable, so two keys can be compared with == to detect
// hash collisions.
type cacheKey struct {
	// Name is the canonical form of the queried name.
	Name string
	// Qtype is the queried record type.
	Qtype uint16
	// DNSSEC is true if the client set the DNSSEC ok (DO) flag.
	DNSSEC bool
	// CheckingDisabled is true if the client set the Checking Disabled (CD)
	// flag.
	CheckingDisabled bool
}
41
42type cacheItem struct {
43 key cacheKey
44
45 // lock protects all fields except key. It also doubles as a way to wait for
46 // the reply. A write lock is held for as long as a query is pending.
47 lock sync.RWMutex
48
49 reply proxyReply
50 stored time.Time
51 // ttl is the number of seconds during which the cached reply can be used.
52 ttl uint32
53 // seenTruncated is true if we ever saw a truncated response for this key.
54 // We will then always use TCP when refetching after the item expires.
55 seenTruncated bool
56}
57
58func (k cacheKey) hash(seed maphash.Seed) uint64 {
59 var h maphash.Hash
60 h.SetSeed(seed)
61 h.WriteByte(byte(k.Qtype))
62 h.WriteByte(byte(k.Qtype >> 8))
63 var flags byte
64 if k.DNSSEC {
65 flags += 1
66 }
67 if k.CheckingDisabled {
68 flags += 2
69 }
70 h.WriteByte(flags)
71 h.WriteString(k.Name)
72 return h.Sum64()
73}
74
75// valid returns true if the cache item can be used for this query.
76func (i *cacheItem) valid(now time.Time, tcp bool) bool {
77 expired := now.After(i.stored.Add(time.Duration(i.ttl) * time.Second))
78 return !expired && (!tcp || !i.reply.Truncated)
79}
80
81func (f *Forward) HandleDNS(r *netDNS.Request) {
82 if !r.Reply.RecursionDesired {
83 // Only forward queries if the RD flag is set. If the question has been
84 // redirected by CNAME, return the reply as is without following the CNAME,
85 // else set a REFUSED rcode.
86 if r.Qname == r.Reply.Question[0].Name {
87 r.Reply.Rcode = dns.RcodeRefused
88 rejectsCount.WithLabelValues("no_recursion_desired").Inc()
89 }
90 } else {
91 f.lookupOrForward(r)
92 }
93 r.SendReply()
94}
95
96func (f *Forward) lookupOrForward(r *netDNS.Request) {
97 key := cacheKey{
98 Name: r.QnameCanonical,
99 Qtype: r.Qtype,
100 DNSSEC: r.Ropt != nil && r.Ropt.Do(),
101 CheckingDisabled: r.Reply.CheckingDisabled,
102 }
103 hash := key.hash(f.seed)
104 tcp := r.Writer.RemoteAddr().Network() == "tcp"
105
106 item, exists := f.cache.Get(hash)
107 if !exists {
108 // The lookup failed; allocate a new item and try to insert it.
109 // Lock the new item before inserting it, such that concurrent queries
110 // are blocked until we receive the reply and store it in the item.
111 newItem := &cacheItem{key: key}
112 newItem.lock.Lock()
113 item, exists = f.cache.GetOrPut(hash, newItem)
114 if !exists {
115 cacheLookupsCount.WithLabelValues("miss").Inc()
116 f.forward(r, newItem, hash, tcp)
117 newItem.lock.Unlock()
118 return
119 }
120 }
121 if item.key != key {
122 // We have a hash collision. Replace the existing item.
123 cacheLookupsCount.WithLabelValues("miss").Inc()
124 newItem := &cacheItem{key: key}
125 newItem.lock.Lock()
126 f.cache.Put(hash, newItem)
127 f.forward(r, newItem, hash, tcp)
128 newItem.lock.Unlock()
129 return
130 }
131
132 // Take a read lock and check if the reply is valid for this query.
133 // This blocks if a query for this item is pending.
134 item.lock.RLock()
135 now := f.now()
136 if item.valid(now, tcp) {
137 replyFromCache(r, item, now)
138 item.lock.RUnlock()
139 return
140 }
141 item.lock.RUnlock()
142
143 item.lock.Lock()
144 now = f.now()
145 if item.valid(now, tcp) {
146 replyFromCache(r, item, now)
147 item.lock.Unlock()
148 return
149 }
150 cacheLookupsCount.WithLabelValues("refresh").Inc()
151 f.forward(r, item, hash, tcp || item.seenTruncated)
152 item.lock.Unlock()
153}
154
155func (f *Forward) forward(r *netDNS.Request, item *cacheItem, hash uint64, tcp bool) {
156 // Query proxies.
157 var queryOptions []dns.EDNS0
158 if r.Qopt != nil {
159 // Forward DNSSEC algorithm understood options. These are only for
160 // statistics and must not influence the reply, so we do not need to include
161 // them in the cache key.
162 for _, option := range r.Qopt.Option {
163 switch option.(type) {
164 case *dns.EDNS0_DAU, *dns.EDNS0_DHU, *dns.EDNS0_N3U:
165 queryOptions = append(queryOptions, option)
166 }
167 }
168 }
169
170 question := dns.Question{
171 Name: item.key.Name,
172 Qtype: item.key.Qtype,
173 Qclass: dns.ClassINET,
174 }
175 reply := f.queryProxies(question, item.key.DNSSEC, item.key.CheckingDisabled, queryOptions, tcp)
176
177 r.Reply.Truncated = reply.Truncated
178 r.Reply.Rcode = reply.Rcode
179 r.Reply.Answer = appendOrClip(r.Reply.Answer, reply.Answer)
180 r.Reply.Ns = appendOrClip(r.Reply.Ns, reply.Ns)
181 r.Reply.Extra = appendOrClip(r.Reply.Extra, reply.Extra)
182 if r.Ropt != nil {
183 r.Ropt.Option = appendOrClip(r.Ropt.Option, reply.Options)
184 }
185
186 item.reply = reply
187 if reply.Truncated {
188 item.seenTruncated = true
189 }
190 item.stored = f.now()
191
192 // Compute how long to cache the item.
193 ttl := uint32(cacheMaxSeconds)
194 // If the reply is an error, or contains no ttls, use the minimum cache time.
195 if (reply.Rcode != dns.RcodeSuccess && reply.Rcode != dns.RcodeNameError) ||
196 len(reply.Answer)+len(reply.Ns)+len(reply.Extra) == 0 {
197 ttl = cacheMinSeconds
198 }
199 for _, rr := range reply.Answer {
200 ttl = min(ttl, rr.Header().Ttl)
201 }
202 for _, rr := range reply.Ns {
203 ttl = min(ttl, rr.Header().Ttl)
204 }
205 for _, rr := range reply.Extra {
206 ttl = min(ttl, rr.Header().Ttl)
207 }
208 item.ttl = max(ttl, cacheMinSeconds)
209
210 if reply.NoStore {
211 f.cache.Remove(hash)
212 }
213}
214
215func replyFromCache(r *netDNS.Request, item *cacheItem, now time.Time) {
216 cacheLookupsCount.WithLabelValues("hit").Inc()
217 decrementTtl := uint32(max(0, now.Sub(item.stored)/time.Second))
218
219 r.Reply.Truncated = item.reply.Truncated
220 r.Reply.Rcode = item.reply.Rcode
221
222 existing_len := len(r.Reply.Answer)
223 r.Reply.Answer = appendCached(r.Reply.Answer, item.reply.Answer, decrementTtl)
224 shuffleAnswer(r.Reply.Answer[existing_len:])
225 r.Reply.Ns = appendCached(r.Reply.Ns, item.reply.Ns, decrementTtl)
226 r.Reply.Extra = appendCached(r.Reply.Extra, item.reply.Extra, decrementTtl)
227 if r.Ropt != nil {
228 r.Ropt.Option = appendOrClip(r.Ropt.Option, item.reply.Options)
229 }
230}
231
232func appendCached(existing, add []dns.RR, decrementTtl uint32) []dns.RR {
233 existing = slices.Grow(existing, len(add))
234 for _, rr := range add {
235 decRR := dns.Copy(rr)
236 hdr := decRR.Header()
237 if hdr.Ttl == 0 {
238 } else if decrementTtl >= hdr.Ttl {
239 // Don't decrement the TTL to 0, as that could cause problems.
240 // https://00f.net/2011/11/17/how-long-does-a-dns-ttl-last/
241 hdr.Ttl = 1
242 } else {
243 hdr.Ttl = hdr.Ttl - decrementTtl
244 }
245 existing = append(existing, decRR)
246 }
247 return existing
248}
249
250// shuffleAnswer randomizes the order of consecutive RRs which are part of the
251// same RRset. This provides some load balancing.
252func shuffleAnswer(rrs []dns.RR) {
253 if len(rrs) < 2 {
254 return
255 }
256 startIndex := 0
257 startHdr := rrs[0].Header()
258 for i := 1; i < len(rrs); i++ {
259 hdr := rrs[i].Header()
260 sameRRset := startHdr.Rrtype == hdr.Rrtype &&
261 startHdr.Class == hdr.Class &&
262 startHdr.Name == hdr.Name
263 if sameRRset {
264 swap := startIndex + rand.IntN(i+1-startIndex)
265 rrs[i], rrs[swap] = rrs[swap], rrs[i]
266 } else {
267 startIndex = i
268 startHdr = hdr
269 }
270 }
271}
272
// appendOrClip is similar to append(a, b...) except that it avoids allocation
// if a is empty, in which case it returns b with any free capacity removed.
// The resulting slice can still be appended to without affecting b.
//
// Note that when a is empty, the result shares its elements with b, so
// modifying existing elements of the result in place does affect b.
func appendOrClip[S ~[]E, E any](a, b S) S {
	if len(a) == 0 {
		return slices.Clip(b)
	}
	return append(a, b...)
}