blob: 00271ec16f935da629f65b02c3ecb81ddae1612b [file] [log] [blame]
Jan Schär75ea9f42024-07-29 17:01:41 +02001// Package forward implements a forwarding proxy.
2//
3// A cache is used to reduce load on the upstream servers. Cached items are only
4// used for a short time, because otherwise, we would need to provide a feature
5// for flushing the cache. The cache is most useful for taking the load from
6// applications making very frequent repeated queries. The cache also doubles as
7// a way to merge concurrent identical queries, since items are inserted into
8// the cache before sending the query upstream (see also RFC 5452, Section 5).
9package forward
10
11// Taken and modified from the Forward plugin of CoreDNS, under Apache 2.0.
12
13import (
14 "context"
15 "errors"
16 "hash/maphash"
17 "math/rand/v2"
18 "os"
19 "slices"
20 "strconv"
21 "sync/atomic"
22 "time"
23
24 "github.com/miekg/dns"
25
26 "source.monogon.dev/osbase/event/memory"
27 "source.monogon.dev/osbase/net/dns/forward/cache"
28 "source.monogon.dev/osbase/net/dns/forward/proxy"
29 "source.monogon.dev/osbase/supervisor"
30)
31
// Tunables for the forwarder.
const (
	// connectionExpire is passed to proxy.SetExpire for every upstream proxy;
	// presumably the lifetime of cached upstream connections — TODO confirm.
	connectionExpire = 10 * time.Second
	// healthcheckInterval is the interval passed to proxy.Start for every
	// upstream proxy.
	healthcheckInterval = 500 * time.Millisecond
	// forwardTimeout bounds the period during which queryProxies retries
	// failed upstream attempts.
	forwardTimeout = 5 * time.Second
	// maxFails is the threshold passed to proxy.Down when deciding whether an
	// upstream should be skipped.
	maxFails = 2
	// maxConcurrent is the number of simultaneously forwarded queries above
	// which new queries are rejected.
	maxConcurrent = 5000
	// maxUpstreams is the maximum number of configured upstream servers that
	// are used; additional addresses are ignored.
	maxUpstreams = 15
)
40
41// Forward represents a plugin instance that can proxy requests to another (DNS)
42// server. It has a list of proxies each representing one upstream proxy.
43type Forward struct {
44 DNSServers memory.Value[[]string]
45 upstreams atomic.Pointer[[]*proxy.Proxy]
46
47 concurrent atomic.Int64
48
49 seed maphash.Seed
50 cache *cache.Cache[*cacheItem]
51
52 // now can be used to override time for testing.
53 now func() time.Time
54}
55
56// New returns a new Forward.
57func New() *Forward {
58 return &Forward{
59 seed: maphash.MakeSeed(),
60 cache: cache.New[*cacheItem](cacheCapacity),
61 now: time.Now,
62 }
63}
64
65func (f *Forward) Run(ctx context.Context) error {
66 supervisor.Signal(ctx, supervisor.SignalHealthy)
67
68 var lastAddrs []string
69 var upstreams []*proxy.Proxy
70
71 w := f.DNSServers.Watch()
72 defer w.Close()
73 for {
74 addrs, err := w.Get(ctx)
75 if err != nil {
76 for _, p := range upstreams {
77 p.Stop()
78 }
79 f.upstreams.Store(nil)
80 return err
81 }
82
83 if len(addrs) > maxUpstreams {
84 addrs = addrs[:maxUpstreams]
85 }
86
87 if slices.Equal(addrs, lastAddrs) {
88 continue
89 }
90 lastAddrs = addrs
91 supervisor.Logger(ctx).Infof("New upstream DNS servers: %s", addrs)
92
93 newAddrs := make(map[string]bool)
94 for _, addr := range addrs {
95 newAddrs[addr] = true
96 }
97 var newUpstreams []*proxy.Proxy
98 for _, p := range upstreams {
99 if newAddrs[p.Addr()] {
100 delete(newAddrs, p.Addr())
101 newUpstreams = append(newUpstreams, p)
102 } else {
103 p.Stop()
104 }
105 }
106 for newAddr := range newAddrs {
107 p := proxy.NewProxy(newAddr)
108 p.SetExpire(connectionExpire)
109 p.GetHealthchecker().SetRecursionDesired(true)
110 p.GetHealthchecker().SetDomain(".")
111 p.Start(healthcheckInterval)
112 newUpstreams = append(newUpstreams, p)
113 }
114 upstreams = newUpstreams
115 f.upstreams.Store(&newUpstreams)
116 }
117}
118
119type proxyReply struct {
120 // NoStore indicates that the reply should not be stored in the cache.
121 // This could be because it is cheap to obtain or expensive to store.
122 NoStore bool
123
124 Truncated bool
125 Rcode int
126 Answer []dns.RR
127 Ns []dns.RR
128 Extra []dns.RR
129 Options []dns.EDNS0
130}
131
132var (
133 replyConcurrencyLimit = proxyReply{
134 NoStore: true,
135 Rcode: dns.RcodeServerFailure,
136 Options: []dns.EDNS0{&dns.EDNS0_EDE{
137 InfoCode: dns.ExtendedErrorCodeOther,
138 ExtraText: "too many concurrent queries",
139 }},
140 }
141 replyNoUpstreams = proxyReply{
142 NoStore: true,
143 Rcode: dns.RcodeRefused,
144 Options: []dns.EDNS0{&dns.EDNS0_EDE{
145 InfoCode: dns.ExtendedErrorCodeOther,
146 ExtraText: "no upstream DNS servers configured",
147 }},
148 }
149 replyProtocolError = proxyReply{
150 Rcode: dns.RcodeServerFailure,
151 Options: []dns.EDNS0{&dns.EDNS0_EDE{
152 InfoCode: dns.ExtendedErrorCodeNetworkError,
153 ExtraText: "DNS protocol error when querying upstream DNS server",
154 }},
155 }
156 replyTimeout = proxyReply{
157 Rcode: dns.RcodeServerFailure,
158 Options: []dns.EDNS0{&dns.EDNS0_EDE{
159 InfoCode: dns.ExtendedErrorCodeNetworkError,
160 ExtraText: "timeout when querying upstream DNS server",
161 }},
162 }
163 replyNetworkError = proxyReply{
164 Rcode: dns.RcodeServerFailure,
165 Options: []dns.EDNS0{&dns.EDNS0_EDE{
166 InfoCode: dns.ExtendedErrorCodeNetworkError,
167 ExtraText: "network error when querying upstream DNS server",
168 }},
169 }
170)
171
172func (f *Forward) queryProxies(
173 question dns.Question,
174 dnssec bool,
175 checkingDisabled bool,
176 queryOptions []dns.EDNS0,
177 useTCP bool,
178) proxyReply {
179 count := f.concurrent.Add(1)
180 defer f.concurrent.Add(-1)
181 if count > maxConcurrent {
182 rejectsCount.WithLabelValues("concurrency_limit").Inc()
183 return replyConcurrencyLimit
184 }
185
186 // Construct outgoing query.
187 qopt := new(dns.OPT)
188 qopt.Hdr.Name = "."
189 qopt.Hdr.Rrtype = dns.TypeOPT
190 qopt.SetUDPSize(proxy.AdvertiseUDPSize)
191 if dnssec {
192 qopt.SetDo()
193 }
194 qopt.Option = queryOptions
195 m := &dns.Msg{
196 MsgHdr: dns.MsgHdr{
197 Opcode: dns.OpcodeQuery,
198 RecursionDesired: true,
199 CheckingDisabled: checkingDisabled,
200 },
201 Question: []dns.Question{question},
202 Extra: []dns.RR{qopt},
203 }
204
205 var list []*proxy.Proxy
206 if upstreams := f.upstreams.Load(); upstreams != nil {
207 list = randomList(*upstreams)
208 }
209
210 if len(list) == 0 {
211 rejectsCount.WithLabelValues("no_upstreams").Inc()
212 return replyNoUpstreams
213 }
214
215 proto := "udp"
216 if useTCP {
217 proto = "tcp"
218 }
219
220 var (
221 curUpstream *proxy.Proxy
222 curStart time.Time
223 ret *dns.Msg
224 err error
225 )
226 recordDuration := func(rcode string) {
227 upstreamDuration.WithLabelValues(curUpstream.Addr(), proto, rcode).Observe(time.Since(curStart).Seconds())
228 }
229
230 fails := 0
231 i := 0
232 listStart := time.Now()
233 deadline := listStart.Add(forwardTimeout)
234 for {
235 if i >= len(list) {
236 // reached the end of list, reset to begin
237 i = 0
238 fails = 0
239
240 // Sleep for a bit if the last time we started going through the list was
241 // very recent.
242 time.Sleep(time.Until(listStart.Add(time.Second)))
243 listStart = time.Now()
244 }
245
246 curUpstream = list[i]
247 i++
248 if curUpstream.Down(maxFails) {
249 fails++
250 if fails < len(list) {
251 continue
252 }
253 // All upstream proxies are dead, assume healthcheck is completely broken
254 // and connect to a random upstream.
255 healthcheckBrokenCount.Inc()
256 }
257
258 curStart = time.Now()
259
260 for {
261 ret, err = curUpstream.Connect(m, useTCP)
262
263 if errors.Is(err, proxy.ErrCachedClosed) {
264 // Remote side closed conn, can only happen with TCP.
265 continue
266 }
267 break
268 }
269
270 if err != nil {
271 // Kick off health check to see if *our* upstream is broken.
272 curUpstream.Healthcheck()
273
274 retry := fails < len(list) && time.Now().Before(deadline)
275 var dnsError *dns.Error
276 switch {
277 case errors.Is(err, os.ErrDeadlineExceeded):
278 recordDuration("timeout")
279 if !retry {
280 return replyTimeout
281 }
282 case errors.As(err, &dnsError):
283 recordDuration("protocol_error")
284 if !retry {
285 return replyProtocolError
286 }
287 default:
288 recordDuration("network_error")
289 if !retry {
290 return replyNetworkError
291 }
292 }
293 continue
294 }
295
296 break
297 }
298
299 if !ret.Response || ret.Opcode != dns.OpcodeQuery || len(ret.Question) != 1 {
300 recordDuration("protocol_error")
301 return replyProtocolError
302 }
303
304 if ret.Truncated && useTCP {
305 recordDuration("protocol_error")
306 return replyProtocolError
307 }
308 if ret.Truncated {
309 proto = "udp_truncated"
310 }
311
312 // Check that the reply matches the question.
313 retq := ret.Question[0]
314 if retq.Qtype != question.Qtype || retq.Qclass != question.Qclass ||
315 (retq.Name != question.Name && dns.CanonicalName(retq.Name) != question.Name) {
316 recordDuration("protocol_error")
317 return replyProtocolError
318 }
319
320 // Extract OPT from reply.
321 var ropt *dns.OPT
322 var options []dns.EDNS0
323 for i := len(ret.Extra) - 1; i >= 0; i-- {
324 if rr, ok := ret.Extra[i].(*dns.OPT); ok {
325 if ropt != nil {
326 // Found more than one OPT.
327 recordDuration("protocol_error")
328 return replyProtocolError
329 }
330 ropt = rr
331 ret.Extra = append(ret.Extra[:i], ret.Extra[i+1:]...)
332 }
333 }
334 if ropt != nil {
335 if ropt.Version() != 0 || ropt.Hdr.Name != "." {
336 recordDuration("protocol_error")
337 return replyProtocolError
338 }
339 // Forward Extended DNS Error options.
340 for _, option := range ropt.Option {
341 switch option.(type) {
342 case *dns.EDNS0_EDE:
343 options = append(options, option)
344 }
345 }
346 }
347
348 rcode, ok := dns.RcodeToString[ret.Rcode]
349 if !ok {
350 // There are 4096 possible Rcodes, so it's probably still fine
351 // to put it in a metric label.
352 rcode = strconv.Itoa(ret.Rcode)
353 }
354 recordDuration(rcode)
355
356 // AuthenticatedData is intentionally not copied from the proxy reply because
357 // we don't have a secure channel to the proxy.
358 return proxyReply{
359 // Don't store large messages in the cache. Such large messages are very
360 // rare, and this protects against the cache using huge amounts of memory.
361 // DNS messages over TCP can be up to 64 KB in size, and after decompression
362 // this could go over 1 MB of memory usage.
363 NoStore: ret.Len() > cacheMaxItemSize,
364
365 Truncated: ret.Truncated,
366 Rcode: ret.Rcode,
367 Answer: ret.Answer,
368 Ns: ret.Ns,
369 Extra: ret.Extra,
370 Options: options,
371 }
372}
373
374func randomList(p []*proxy.Proxy) []*proxy.Proxy {
375 switch len(p) {
376 case 1:
377 return p
378 case 2:
379 if rand.Int()%2 == 0 {
380 return []*proxy.Proxy{p[1], p[0]} // swap
381 }
382 return p
383 }
384
385 shuffled := slices.Clone(p)
386 rand.Shuffle(len(shuffled), func(i, j int) {
387 shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
388 })
389 return shuffled
390}