blob: d2422a902a7fbdfc08fafefd77b91d31d2005bf0 [file] [log] [blame]
Jan Schär75ea9f42024-07-29 17:01:41 +02001package forward
2
3import (
4 "hash/maphash"
5 "math/rand/v2"
6 "slices"
7 "sync"
8 "time"
9
10 "github.com/miekg/dns"
11
12 netDNS "source.monogon.dev/osbase/net/dns"
13)
14
// Cache sizing. The cache holds at most cacheCapacity items of at most
// cacheMaxItemSize bytes each, i.e. roughly 20 MB in total. Actual memory
// usage may be slightly higher, because in-memory data structures carry some
// overhead compared to the serialized, uncompressed message length.
const (
	cacheCapacity    = 10000
	cacheMaxItemSize = 2048
)

// Cache lifetime bounds, in seconds. A reply is cached for the smallest TTL of
// its records, but never longer than cacheMaxSeconds and never shorter than
// cacheMinSeconds. Error replies and replies without records are cached for
// cacheMinSeconds.
const (
	cacheMinSeconds = 1
	cacheMaxSeconds = 5
)
24
// cacheKey identifies a cached reply. Besides the question name and type, it
// includes the DNSSEC OK bit and the Checking Disabled flag, because both
// change what the upstream returns.
//
// Alternatives considered and rejected:
//   - Always querying upstream with DNSSEC and stripping the authenticating
//     records for clients that did not ask for them. This would mostly waste
//     bandwidth.
//   - Caching NXDOMAIN independently of the QTYPE, as RFC 2308, Section 5
//     permits. This needs a second lookup per query and extra complexity that
//     does not seem worth it.
type cacheKey struct {
	// Name is the query name (filled from the request's canonical qname).
	Name string

	// Qtype is the query type.
	Qtype uint16

	// DNSSEC is the request's DNSSEC OK (DO) bit, and CheckingDisabled its
	// Checking Disabled (CD) header flag.
	DNSSEC, CheckingDisabled bool
}
38
// cacheItem is one cache entry. It holds the most recent upstream reply for a
// single cacheKey, together with the bookkeeping needed to decide whether that
// reply can still be served.
type cacheItem struct {
	// key is the query this item caches. It is set when the item is created
	// and is compared with the lookup key to detect hash collisions.
	key cacheKey

	// lock protects all fields except key. It also doubles as a way to wait for
	// the reply. A write lock is held for as long as a query is pending.
	lock sync.RWMutex

	// reply is the last reply received from the upstream proxies for key.
	reply proxyReply

	// stored is the time reply was recorded, taken after the upstream query
	// returned.
	stored time.Time

	// ttl is the number of seconds during which the cached reply can be used.
	ttl uint32

	// seenTruncated is true if we ever saw a truncated response for this key.
	// We will then always use TCP when refetching after the item expires.
	seenTruncated bool
}
54
55func (k cacheKey) hash(seed maphash.Seed) uint64 {
56 var h maphash.Hash
57 h.SetSeed(seed)
58 h.WriteByte(byte(k.Qtype))
59 h.WriteByte(byte(k.Qtype >> 8))
60 var flags byte
61 if k.DNSSEC {
62 flags += 1
63 }
64 if k.CheckingDisabled {
65 flags += 2
66 }
67 h.WriteByte(flags)
68 h.WriteString(k.Name)
69 return h.Sum64()
70}
71
72// valid returns true if the cache item can be used for this query.
73func (i *cacheItem) valid(now time.Time, tcp bool) bool {
74 expired := now.After(i.stored.Add(time.Duration(i.ttl) * time.Second))
75 return !expired && (!tcp || !i.reply.Truncated)
76}
77
78func (f *Forward) HandleDNS(r *netDNS.Request) {
79 if !r.Reply.RecursionDesired {
80 // Only forward queries if the RD flag is set. If the question has been
81 // redirected by CNAME, return the reply as is without following the CNAME,
82 // else set a REFUSED rcode.
83 if r.Qname == r.Reply.Question[0].Name {
84 r.Reply.Rcode = dns.RcodeRefused
85 rejectsCount.WithLabelValues("no_recursion_desired").Inc()
86 }
87 } else {
88 f.lookupOrForward(r)
89 }
90 r.SendReply()
91}
92
// lookupOrForward fills r.Reply for a recursive query. The reply comes from
// the cache when a usable one is stored, and from an upstream query otherwise.
//
// Cache entries are indexed by a hash of the cacheKey. The full key stored in
// each item is compared with the lookup key to detect hash collisions.
//
// Each item's lock serializes upstream queries for its key. The goroutine that
// forwards a query holds the item's write lock until the reply is stored.
// Concurrent requests for the same key therefore block on the lock, and are
// then typically answered from the freshly stored reply instead of issuing
// their own upstream query.
func (f *Forward) lookupOrForward(r *netDNS.Request) {
	key := cacheKey{
		Name:             r.QnameCanonical,
		Qtype:            r.Qtype,
		DNSSEC:           r.Ropt != nil && r.Ropt.Do(),
		CheckingDisabled: r.Reply.CheckingDisabled,
	}
	hash := key.hash(f.seed)
	// A cached truncated reply is not usable for a TCP client (see valid).
	tcp := r.Writer.RemoteAddr().Network() == "tcp"

	item, exists := f.cache.Get(hash)
	if !exists {
		// The lookup failed; allocate a new item and try to insert it.
		// Lock the new item before inserting it, such that concurrent queries
		// are blocked until we receive the reply and store it in the item.
		newItem := &cacheItem{key: key}
		newItem.lock.Lock()
		item, exists = f.cache.GetOrPut(hash, newItem)
		if !exists {
			cacheLookupsCount.WithLabelValues("miss").Inc()
			f.forward(r, newItem, hash, tcp)
			newItem.lock.Unlock()
			return
		}
		// GetOrPut found an existing item for this hash (presumably inserted
		// by another request after our Get). newItem was not used and is
		// dropped. Continue below with the existing item.
	}
	if item.key != key {
		// We have a hash collision. Replace the existing item.
		// Reading item.key without item.lock is fine: lock does not protect
		// key, and key is only set when an item is created.
		cacheLookupsCount.WithLabelValues("miss").Inc()
		newItem := &cacheItem{key: key}
		newItem.lock.Lock()
		f.cache.Put(hash, newItem)
		f.forward(r, newItem, hash, tcp)
		newItem.lock.Unlock()
		return
	}

	// Take a read lock and check if the reply is valid for this query.
	// This blocks if a query for this item is pending.
	item.lock.RLock()
	now := f.now()
	if item.valid(now, tcp) {
		replyFromCache(r, item, now)
		item.lock.RUnlock()
		return
	}
	item.lock.RUnlock()

	// The reply is expired, or it is truncated and the client uses TCP. Take
	// the write lock and check again: another request may have refreshed the
	// item while no lock was held. If so, use its reply.
	item.lock.Lock()
	now = f.now()
	if item.valid(now, tcp) {
		replyFromCache(r, item, now)
		item.lock.Unlock()
		return
	}
	cacheLookupsCount.WithLabelValues("refresh").Inc()
	// Query upstream over TCP if this key ever produced a truncated reply, so
	// the refreshed reply is not truncated again.
	f.forward(r, item, hash, tcp || item.seenTruncated)
	item.lock.Unlock()
}
151
152func (f *Forward) forward(r *netDNS.Request, item *cacheItem, hash uint64, tcp bool) {
153 // Query proxies.
154 var queryOptions []dns.EDNS0
155 if r.Qopt != nil {
156 // Forward DNSSEC algorithm understood options. These are only for
157 // statistics and must not influence the reply, so we do not need to include
158 // them in the cache key.
159 for _, option := range r.Qopt.Option {
160 switch option.(type) {
161 case *dns.EDNS0_DAU, *dns.EDNS0_DHU, *dns.EDNS0_N3U:
162 queryOptions = append(queryOptions, option)
163 }
164 }
165 }
166
167 question := dns.Question{
168 Name: item.key.Name,
169 Qtype: item.key.Qtype,
170 Qclass: dns.ClassINET,
171 }
172 reply := f.queryProxies(question, item.key.DNSSEC, item.key.CheckingDisabled, queryOptions, tcp)
173
174 r.Reply.Truncated = reply.Truncated
175 r.Reply.Rcode = reply.Rcode
176 r.Reply.Answer = appendOrClip(r.Reply.Answer, reply.Answer)
177 r.Reply.Ns = appendOrClip(r.Reply.Ns, reply.Ns)
178 r.Reply.Extra = appendOrClip(r.Reply.Extra, reply.Extra)
179 if r.Ropt != nil {
180 r.Ropt.Option = appendOrClip(r.Ropt.Option, reply.Options)
181 }
182
183 item.reply = reply
184 if reply.Truncated {
185 item.seenTruncated = true
186 }
187 item.stored = f.now()
188
189 // Compute how long to cache the item.
190 ttl := uint32(cacheMaxSeconds)
191 // If the reply is an error, or contains no ttls, use the minimum cache time.
192 if (reply.Rcode != dns.RcodeSuccess && reply.Rcode != dns.RcodeNameError) ||
193 len(reply.Answer)+len(reply.Ns)+len(reply.Extra) == 0 {
194 ttl = cacheMinSeconds
195 }
196 for _, rr := range reply.Answer {
197 ttl = min(ttl, rr.Header().Ttl)
198 }
199 for _, rr := range reply.Ns {
200 ttl = min(ttl, rr.Header().Ttl)
201 }
202 for _, rr := range reply.Extra {
203 ttl = min(ttl, rr.Header().Ttl)
204 }
205 item.ttl = max(ttl, cacheMinSeconds)
206
207 if reply.NoStore {
208 f.cache.Remove(hash)
209 }
210}
211
212func replyFromCache(r *netDNS.Request, item *cacheItem, now time.Time) {
213 cacheLookupsCount.WithLabelValues("hit").Inc()
214 decrementTtl := uint32(max(0, now.Sub(item.stored)/time.Second))
215
216 r.Reply.Truncated = item.reply.Truncated
217 r.Reply.Rcode = item.reply.Rcode
218
219 existing_len := len(r.Reply.Answer)
220 r.Reply.Answer = appendCached(r.Reply.Answer, item.reply.Answer, decrementTtl)
221 shuffleAnswer(r.Reply.Answer[existing_len:])
222 r.Reply.Ns = appendCached(r.Reply.Ns, item.reply.Ns, decrementTtl)
223 r.Reply.Extra = appendCached(r.Reply.Extra, item.reply.Extra, decrementTtl)
224 if r.Ropt != nil {
225 r.Ropt.Option = appendOrClip(r.Ropt.Option, item.reply.Options)
226 }
227}
228
229func appendCached(existing, add []dns.RR, decrementTtl uint32) []dns.RR {
230 existing = slices.Grow(existing, len(add))
231 for _, rr := range add {
232 decRR := dns.Copy(rr)
233 hdr := decRR.Header()
234 if hdr.Ttl == 0 {
235 } else if decrementTtl >= hdr.Ttl {
236 // Don't decrement the TTL to 0, as that could cause problems.
237 // https://00f.net/2011/11/17/how-long-does-a-dns-ttl-last/
238 hdr.Ttl = 1
239 } else {
240 hdr.Ttl = hdr.Ttl - decrementTtl
241 }
242 existing = append(existing, decRR)
243 }
244 return existing
245}
246
247// shuffleAnswer randomizes the order of consecutive RRs which are part of the
248// same RRset. This provides some load balancing.
249func shuffleAnswer(rrs []dns.RR) {
250 if len(rrs) < 2 {
251 return
252 }
253 startIndex := 0
254 startHdr := rrs[0].Header()
255 for i := 1; i < len(rrs); i++ {
256 hdr := rrs[i].Header()
257 sameRRset := startHdr.Rrtype == hdr.Rrtype &&
258 startHdr.Class == hdr.Class &&
259 startHdr.Name == hdr.Name
260 if sameRRset {
261 swap := startIndex + rand.IntN(i+1-startIndex)
262 rrs[i], rrs[swap] = rrs[swap], rrs[i]
263 } else {
264 startIndex = i
265 startHdr = hdr
266 }
267 }
268}
269
// appendOrClip returns a followed by b, like append(a, b...).
//
// When a is empty, nothing is copied. The result is then b itself, resliced so
// that its capacity equals its length. A later append to the result therefore
// reallocates and never writes into b's spare capacity.
func appendOrClip[S ~[]E, E any](a, b S) S {
	if len(a) > 0 {
		return append(a, b...)
	}
	return slices.Clip(b)
}