// Copyright The Monogon Project Authors.
// SPDX-License-Identifier: Apache-2.0

// Package forward implements a forwarding proxy.
//
// A cache is used to reduce load on the upstream servers. Cached items are only
// used for a short time, because otherwise, we would need to provide a feature
// for flushing the cache. The cache is most useful for taking the load from
// applications making very frequent repeated queries. The cache also doubles as
// a way to merge concurrent identical queries, since items are inserted into
// the cache before sending the query upstream (see also RFC 5452, Section 5).
package forward

// Taken and modified from the Forward plugin of CoreDNS, under Apache 2.0.

import (
	"context"
	"errors"
	"hash/maphash"
	"math/rand/v2"
	"os"
	"slices"
	"strconv"
	"sync/atomic"
	"time"

	"github.com/miekg/dns"

	"source.monogon.dev/osbase/event/memory"
	"source.monogon.dev/osbase/net/dns/forward/cache"
	"source.monogon.dev/osbase/net/dns/forward/proxy"
	"source.monogon.dev/osbase/supervisor"
)

const (
	// How long cached upstream connections are kept before they expire.
	connectionExpire = 10 * time.Second
	// Interval between upstream health check queries.
	healthcheckInterval = 500 * time.Millisecond
	// Total time budget for retrying upstreams for a single query.
	forwardTimeout = 5 * time.Second
	// Number of failures after which an upstream is considered down.
	maxFails = 2
	// Maximum number of concurrently forwarded queries.
	maxConcurrent = 5000
	// Maximum number of upstream servers that are used.
	maxUpstreams = 15
)

// Forward represents a plugin instance that can proxy requests to another (DNS)
// server. It has a list of proxies, each representing one upstream server.
type Forward struct {
	// DNSServers provides the list of upstream DNS server addresses.
	DNSServers memory.Value[[]string]
	// upstreams is the current set of proxies, one per upstream server.
	upstreams atomic.Pointer[[]*proxy.Proxy]

	// concurrent counts queries currently being forwarded.
	concurrent atomic.Int64

	seed  maphash.Seed
	cache *cache.Cache[*cacheItem]

	// now can be used to override time for testing.
	now func() time.Time
}

// New returns a new Forward.
func New() *Forward {
	return &Forward{
		seed:  maphash.MakeSeed(),
		cache: cache.New[*cacheItem](cacheCapacity),
		now:   time.Now,
	}
}

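// Run watches DNSServers for changes and maintains the corresponding set of
// upstream proxies: proxies for new addresses are created and health-checked,
// proxies for removed addresses are stopped.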
func (f *Forward) Run(ctx context.Context) error {
	supervisor.Signal(ctx, supervisor.SignalHealthy)

	var lastAddrs []string
	var upstreams []*proxy.Proxy

	w := f.DNSServers.Watch()
	defer w.Close()
	for {
		addrs, err := w.Get(ctx)
		if err != nil {
			for _, p := range upstreams {
				p.Stop()
			}
			f.upstreams.Store(nil)
			return err
		}

		if len(addrs) > maxUpstreams {
			addrs = addrs[:maxUpstreams]
		}

		if slices.Equal(addrs, lastAddrs) {
			continue
		}
		lastAddrs = addrs
		supervisor.Logger(ctx).Infof("New upstream DNS servers: %s", addrs)

		newAddrs := make(map[string]bool)
		for _, addr := range addrs {
			newAddrs[addr] = true
		}
		var newUpstreams []*proxy.Proxy
		for _, p := range upstreams {
			if newAddrs[p.Addr()] {
				delete(newAddrs, p.Addr())
				newUpstreams = append(newUpstreams, p)
			} else {
				p.Stop()
			}
		}
		for newAddr := range newAddrs {
			p := proxy.NewProxy(newAddr)
			p.SetExpire(connectionExpire)
			p.GetHealthchecker().SetRecursionDesired(true)
			p.GetHealthchecker().SetDomain(".")
			p.Start(healthcheckInterval)
			newUpstreams = append(newUpstreams, p)
		}
		upstreams = newUpstreams
		f.upstreams.Store(&newUpstreams)
	}
}

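// proxyReply is the result of forwarding a query to the upstream proxies.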
type proxyReply struct {
	// NoStore indicates that the reply should not be stored in the cache.
	// This could be because it is cheap to obtain or expensive to store.
	NoStore bool

	Truncated bool
	Rcode     int
	Answer    []dns.RR
	Ns        []dns.RR
	Extra     []dns.RR
	Options   []dns.EDNS0
}

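// Predefined replies for the error conditions in queryProxies.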
var (
	replyConcurrencyLimit = proxyReply{
		NoStore: true,
		Rcode:   dns.RcodeServerFailure,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeOther,
			ExtraText: "too many concurrent queries",
		}},
	}
	replyNoUpstreams = proxyReply{
		NoStore: true,
		Rcode:   dns.RcodeRefused,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeOther,
			ExtraText: "no upstream DNS servers configured",
		}},
	}
	replyProtocolError = proxyReply{
		Rcode: dns.RcodeServerFailure,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeNetworkError,
			ExtraText: "DNS protocol error when querying upstream DNS server",
		}},
	}
	replyTimeout = proxyReply{
		Rcode: dns.RcodeServerFailure,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeNetworkError,
			ExtraText: "timeout when querying upstream DNS server",
		}},
	}
	replyNetworkError = proxyReply{
		Rcode: dns.RcodeServerFailure,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeNetworkError,
			ExtraText: "network error when querying upstream DNS server",
		}},
	}
)

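// queryProxies forwards the given question to the upstream proxies, trying
// them in random order until one returns a reply or the query is given up.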
func (f *Forward) queryProxies(
	question dns.Question,
	dnssec bool,
	checkingDisabled bool,
	queryOptions []dns.EDNS0,
	useTCP bool,
) proxyReply {
	count := f.concurrent.Add(1)
	defer f.concurrent.Add(-1)
	if count > maxConcurrent {
		rejectsCount.WithLabelValues("concurrency_limit").Inc()
		return replyConcurrencyLimit
	}

	// Construct outgoing query.
	qopt := new(dns.OPT)
	qopt.Hdr.Name = "."
	qopt.Hdr.Rrtype = dns.TypeOPT
	qopt.SetUDPSize(proxy.AdvertiseUDPSize)
	if dnssec {
		qopt.SetDo()
	}
	qopt.Option = queryOptions
	m := &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Opcode:           dns.OpcodeQuery,
			RecursionDesired: true,
			CheckingDisabled: checkingDisabled,
		},
		Question: []dns.Question{question},
		Extra:    []dns.RR{qopt},
	}

	var list []*proxy.Proxy
	if upstreams := f.upstreams.Load(); upstreams != nil {
		list = randomList(*upstreams)
	}

	if len(list) == 0 {
		rejectsCount.WithLabelValues("no_upstreams").Inc()
		return replyNoUpstreams
	}

	proto := "udp"
	if useTCP {
		proto = "tcp"
	}

	var (
		curUpstream *proxy.Proxy
		curStart    time.Time
		ret         *dns.Msg
		err         error
	)
	recordDuration := func(rcode string) {
		upstreamDuration.WithLabelValues(curUpstream.Addr(), proto, rcode).Observe(time.Since(curStart).Seconds())
	}

	fails := 0
	i := 0
	listStart := time.Now()
	deadline := listStart.Add(forwardTimeout)
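	// Try the upstreams in order, skipping ones that appear to be down, until
	// a reply is received or retrying is no longer possible.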
	for {
		if i >= len(list) {
			// reached the end of list, reset to begin
			i = 0
			fails = 0

			// Sleep for a bit if the last time we started going through the list was
			// very recent.
			time.Sleep(time.Until(listStart.Add(time.Second)))
			listStart = time.Now()
		}

		curUpstream = list[i]
		i++
		if curUpstream.Down(maxFails) {
			fails++
			if fails < len(list) {
				continue
			}
			// All upstream proxies are dead, assume healthcheck is completely broken
			// and connect to a random upstream.
			healthcheckBrokenCount.Inc()
		}

		curStart = time.Now()

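		// Query the upstream, retrying if a cached TCP connection was already
		// closed by the remote side.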
		for {
			ret, err = curUpstream.Connect(m, useTCP)

			if errors.Is(err, proxy.ErrCachedClosed) {
				// Remote side closed conn, can only happen with TCP.
				continue
			}
			break
		}

		if err != nil {
			// Kick off health check to see if *our* upstream is broken.
			curUpstream.Healthcheck()

			retry := fails < len(list) && time.Now().Before(deadline)
			var dnsError *dns.Error
			switch {
			case errors.Is(err, os.ErrDeadlineExceeded):
				recordDuration("timeout")
				if !retry {
					return replyTimeout
				}
			case errors.As(err, &dnsError):
				recordDuration("protocol_error")
				if !retry {
					return replyProtocolError
				}
			default:
				recordDuration("network_error")
				if !retry {
					return replyNetworkError
				}
			}
			continue
		}

		break
	}

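	// Validate the reply received from the upstream before using it.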
	if !ret.Response || ret.Opcode != dns.OpcodeQuery || len(ret.Question) != 1 {
		recordDuration("protocol_error")
		return replyProtocolError
	}

	if ret.Truncated && useTCP {
		recordDuration("protocol_error")
		return replyProtocolError
	}
	if ret.Truncated {
		proto = "udp_truncated"
	}

	// Check that the reply matches the question.
	retq := ret.Question[0]
	if retq.Qtype != question.Qtype || retq.Qclass != question.Qclass ||
		(retq.Name != question.Name && dns.CanonicalName(retq.Name) != question.Name) {
		recordDuration("protocol_error")
		return replyProtocolError
	}

	// Extract OPT from reply.
	var ropt *dns.OPT
	var options []dns.EDNS0
	for i := len(ret.Extra) - 1; i >= 0; i-- {
		if rr, ok := ret.Extra[i].(*dns.OPT); ok {
			if ropt != nil {
				// Found more than one OPT.
				recordDuration("protocol_error")
				return replyProtocolError
			}
			ropt = rr
			ret.Extra = append(ret.Extra[:i], ret.Extra[i+1:]...)
		}
	}
	if ropt != nil {
		if ropt.Version() != 0 || ropt.Hdr.Name != "." {
			recordDuration("protocol_error")
			return replyProtocolError
		}
		// Forward Extended DNS Error options.
		for _, option := range ropt.Option {
			switch option.(type) {
			case *dns.EDNS0_EDE:
				options = append(options, option)
			}
		}
	}

	rcode, ok := dns.RcodeToString[ret.Rcode]
	if !ok {
		// There are 4096 possible Rcodes, so it's probably still fine
		// to put it in a metric label.
		rcode = strconv.Itoa(ret.Rcode)
	}
	recordDuration(rcode)

	// AuthenticatedData is intentionally not copied from the proxy reply because
	// we don't have a secure channel to the proxy.
	return proxyReply{
		// Don't store large messages in the cache. Such large messages are very
		// rare, and this protects against the cache using huge amounts of memory.
		// DNS messages over TCP can be up to 64 KB in size, and after decompression
		// this could go over 1 MB of memory usage.
		NoStore: ret.Len() > cacheMaxItemSize,

		Truncated: ret.Truncated,
		Rcode:     ret.Rcode,
		Answer:    ret.Answer,
		Ns:        ret.Ns,
		Extra:     ret.Extra,
		Options:   options,
	}
}

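// randomList returns the given proxies in random order.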
func randomList(p []*proxy.Proxy) []*proxy.Proxy {
	switch len(p) {
	case 1:
		return p
	case 2:
		if rand.Int()%2 == 0 {
			return []*proxy.Proxy{p[1], p[0]} // swap
		}
		return p
	}

	shuffled := slices.Clone(p)
	rand.Shuffle(len(shuffled), func(i, j int) {
		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
	})
	return shuffled
}