| Jan Schär | 75ea9f4 | 2024-07-29 17:01:41 +0200 | [diff] [blame] | 1 | // Package forward implements a forwarding proxy. |
| 2 | // |
| 3 | // A cache is used to reduce load on the upstream servers. Cached items are only |
| 4 | // used for a short time, because otherwise, we would need to provide a feature |
| 5 | // for flushing the cache. The cache is most useful for taking the load from |
| 6 | // applications making very frequent repeated queries. The cache also doubles as |
| 7 | // a way to merge concurrent identical queries, since items are inserted into |
| 8 | // the cache before sending the query upstream (see also RFC 5452, Section 5). |
| 9 | package forward |
| 10 | |
| 11 | // Taken and modified from the Forward plugin of CoreDNS, under Apache 2.0. |
| 12 | |
| 13 | import ( |
| 14 | "context" |
| 15 | "errors" |
| 16 | "hash/maphash" |
| 17 | "math/rand/v2" |
| 18 | "os" |
| 19 | "slices" |
| 20 | "strconv" |
| 21 | "sync/atomic" |
| 22 | "time" |
| 23 | |
| 24 | "github.com/miekg/dns" |
| 25 | |
| 26 | "source.monogon.dev/osbase/event/memory" |
| 27 | "source.monogon.dev/osbase/net/dns/forward/cache" |
| 28 | "source.monogon.dev/osbase/net/dns/forward/proxy" |
| 29 | "source.monogon.dev/osbase/supervisor" |
| 30 | ) |
| 31 | |
// Timing parameters used when talking to upstream DNS servers.
const (
	// forwardTimeout is added to the start time of a forwarded query to
	// obtain the deadline after which failed attempts are no longer retried.
	forwardTimeout = 5 * time.Second
	// connectionExpire is the expiry passed to each upstream proxy via
	// SetExpire.
	connectionExpire = 10 * time.Second
	// healthcheckInterval is the interval with which each upstream proxy is
	// started.
	healthcheckInterval = 500 * time.Millisecond
)

// Limits protecting the forwarder and selecting usable upstreams.
const (
	// maxConcurrent is the number of simultaneously running upstream queries
	// above which new queries are rejected.
	maxConcurrent = 5000
	// maxUpstreams caps the number of configured upstream addresses that are
	// used; addresses beyond this count are dropped.
	maxUpstreams = 15
	// maxFails is the threshold passed to Proxy.Down when deciding whether an
	// upstream should be skipped.
	maxFails = 2
)
| 40 | |
// Forward represents a plugin instance that can proxy requests to another (DNS)
// server. It has a list of proxies each representing one upstream proxy.
type Forward struct {
	// DNSServers holds the addresses of the upstream DNS servers. Run watches
	// this value and rebuilds the set of upstream proxies when it changes.
	DNSServers memory.Value[[]string]
	// upstreams points to the current list of upstream proxies. Run replaces
	// it as a whole and queryProxies loads it; a nil pointer means that no
	// upstreams are available.
	upstreams atomic.Pointer[[]*proxy.Proxy]

	// concurrent counts the queryProxies calls that are currently running and
	// is compared against maxConcurrent.
	concurrent atomic.Int64

	// seed is created with maphash.MakeSeed in New.
	// NOTE(review): presumably used for hashing cache keys; the code using
	// it is not in this part of the file.
	seed maphash.Seed
	// cache is created in New with cacheCapacity.
	cache *cache.Cache[*cacheItem]

	// now can be used to override time for testing.
	now func() time.Time
}
| 55 | |
| 56 | // New returns a new Forward. |
| 57 | func New() *Forward { |
| 58 | return &Forward{ |
| 59 | seed: maphash.MakeSeed(), |
| 60 | cache: cache.New[*cacheItem](cacheCapacity), |
| 61 | now: time.Now, |
| 62 | } |
| 63 | } |
| 64 | |
| 65 | func (f *Forward) Run(ctx context.Context) error { |
| 66 | supervisor.Signal(ctx, supervisor.SignalHealthy) |
| 67 | |
| 68 | var lastAddrs []string |
| 69 | var upstreams []*proxy.Proxy |
| 70 | |
| 71 | w := f.DNSServers.Watch() |
| 72 | defer w.Close() |
| 73 | for { |
| 74 | addrs, err := w.Get(ctx) |
| 75 | if err != nil { |
| 76 | for _, p := range upstreams { |
| 77 | p.Stop() |
| 78 | } |
| 79 | f.upstreams.Store(nil) |
| 80 | return err |
| 81 | } |
| 82 | |
| 83 | if len(addrs) > maxUpstreams { |
| 84 | addrs = addrs[:maxUpstreams] |
| 85 | } |
| 86 | |
| 87 | if slices.Equal(addrs, lastAddrs) { |
| 88 | continue |
| 89 | } |
| 90 | lastAddrs = addrs |
| 91 | supervisor.Logger(ctx).Infof("New upstream DNS servers: %s", addrs) |
| 92 | |
| 93 | newAddrs := make(map[string]bool) |
| 94 | for _, addr := range addrs { |
| 95 | newAddrs[addr] = true |
| 96 | } |
| 97 | var newUpstreams []*proxy.Proxy |
| 98 | for _, p := range upstreams { |
| 99 | if newAddrs[p.Addr()] { |
| 100 | delete(newAddrs, p.Addr()) |
| 101 | newUpstreams = append(newUpstreams, p) |
| 102 | } else { |
| 103 | p.Stop() |
| 104 | } |
| 105 | } |
| 106 | for newAddr := range newAddrs { |
| 107 | p := proxy.NewProxy(newAddr) |
| 108 | p.SetExpire(connectionExpire) |
| 109 | p.GetHealthchecker().SetRecursionDesired(true) |
| 110 | p.GetHealthchecker().SetDomain(".") |
| 111 | p.Start(healthcheckInterval) |
| 112 | newUpstreams = append(newUpstreams, p) |
| 113 | } |
| 114 | upstreams = newUpstreams |
| 115 | f.upstreams.Store(&newUpstreams) |
| 116 | } |
| 117 | } |
| 118 | |
// proxyReply is the result of queryProxies. It is either built from the
// response of an upstream server, or it is one of the predefined error
// replies.
type proxyReply struct {
	// NoStore indicates that the reply should not be stored in the cache.
	// This could be because it is cheap to obtain or expensive to store.
	NoStore bool

	// Truncated is the truncation flag of the upstream response.
	Truncated bool
	// Rcode is the response code of the reply.
	Rcode int
	// Answer is the answer section of the upstream response.
	Answer []dns.RR
	// Ns is the authority section of the upstream response.
	Ns []dns.RR
	// Extra is the additional section of the upstream response, with the OPT
	// record removed.
	Extra []dns.RR
	// Options contains EDNS0 options for the reply. For upstream responses,
	// these are the Extended DNS Error options found in the OPT record.
	Options []dns.EDNS0
}
| 131 | |
// Predefined replies which queryProxies returns when it cannot provide a
// valid upstream response. Each one carries an Extended DNS Error option
// which describes the failure.
var (
	// replyConcurrencyLimit is returned when more than maxConcurrent queries
	// are running at the same time. It is marked NoStore.
	replyConcurrencyLimit = proxyReply{
		NoStore: true,
		Rcode:   dns.RcodeServerFailure,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeOther,
			ExtraText: "too many concurrent queries",
		}},
	}
	// replyNoUpstreams is returned when the list of upstreams is empty.
	// It is marked NoStore.
	replyNoUpstreams = proxyReply{
		NoStore: true,
		Rcode:   dns.RcodeRefused,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeOther,
			ExtraText: "no upstream DNS servers configured",
		}},
	}
	// replyProtocolError is returned when the upstream query fails with a
	// *dns.Error and is not retried, or when the upstream response fails
	// validation. NoStore is left at false.
	replyProtocolError = proxyReply{
		Rcode: dns.RcodeServerFailure,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeNetworkError,
			ExtraText: "DNS protocol error when querying upstream DNS server",
		}},
	}
	// replyTimeout is returned when the upstream query fails with
	// os.ErrDeadlineExceeded and is not retried. NoStore is left at false.
	replyTimeout = proxyReply{
		Rcode: dns.RcodeServerFailure,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeNetworkError,
			ExtraText: "timeout when querying upstream DNS server",
		}},
	}
	// replyNetworkError is returned when the upstream query fails with any
	// other error and is not retried. NoStore is left at false.
	replyNetworkError = proxyReply{
		Rcode: dns.RcodeServerFailure,
		Options: []dns.EDNS0{&dns.EDNS0_EDE{
			InfoCode:  dns.ExtendedErrorCodeNetworkError,
			ExtraText: "network error when querying upstream DNS server",
		}},
	}
)
| 171 | |
// queryProxies sends a query for question to the upstream DNS servers and
// returns the validated reply, or one of the predefined error replies.
//
// The outgoing query has the RD flag set and carries an OPT record which
// contains queryOptions. dnssec sets the DO flag in the OPT record,
// checkingDisabled sets the CD flag in the header, and useTCP selects TCP
// instead of UDP for connecting to the upstream.
//
// Upstreams are tried in random order. Upstreams which are considered down
// are skipped, unless all of them are down. A failed attempt is retried with
// the next upstream as long as the deadline (forwardTimeout after the start)
// has not passed.
//
// The duration of each attempt is recorded in the upstreamDuration metric,
// labeled with the upstream address, the protocol and the outcome.
func (f *Forward) queryProxies(
	question dns.Question,
	dnssec bool,
	checkingDisabled bool,
	queryOptions []dns.EDNS0,
	useTCP bool,
) proxyReply {
	// Enforce the limit on the number of concurrently running queries.
	count := f.concurrent.Add(1)
	defer f.concurrent.Add(-1)
	if count > maxConcurrent {
		rejectsCount.WithLabelValues("concurrency_limit").Inc()
		return replyConcurrencyLimit
	}

	// Construct outgoing query.
	qopt := new(dns.OPT)
	qopt.Hdr.Name = "."
	qopt.Hdr.Rrtype = dns.TypeOPT
	qopt.SetUDPSize(proxy.AdvertiseUDPSize)
	if dnssec {
		qopt.SetDo()
	}
	qopt.Option = queryOptions
	m := &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Opcode:           dns.OpcodeQuery,
			RecursionDesired: true,
			CheckingDisabled: checkingDisabled,
		},
		Question: []dns.Question{question},
		Extra:    []dns.RR{qopt},
	}

	// Take a snapshot of the current upstreams, in random order.
	var list []*proxy.Proxy
	if upstreams := f.upstreams.Load(); upstreams != nil {
		list = randomList(*upstreams)
	}

	if len(list) == 0 {
		rejectsCount.WithLabelValues("no_upstreams").Inc()
		return replyNoUpstreams
	}

	// proto is only used as a metric label.
	proto := "udp"
	if useTCP {
		proto = "tcp"
	}

	var (
		curUpstream *proxy.Proxy
		curStart    time.Time
		ret         *dns.Msg
		err         error
	)
	// recordDuration observes the time since the start of the current attempt.
	// It reads curUpstream, curStart and proto at the time it is called.
	recordDuration := func(rcode string) {
		upstreamDuration.WithLabelValues(curUpstream.Addr(), proto, rcode).Observe(time.Since(curStart).Seconds())
	}

	// fails counts the upstreams which were found to be down in the current
	// pass through the list, i is the index of the next upstream to try.
	fails := 0
	i := 0
	listStart := time.Now()
	deadline := listStart.Add(forwardTimeout)
	for {
		if i >= len(list) {
			// reached the end of list, reset to begin
			i = 0
			fails = 0

			// Sleep for a bit if the last time we started going through the list was
			// very recent.
			time.Sleep(time.Until(listStart.Add(time.Second)))
			listStart = time.Now()
		}

		curUpstream = list[i]
		i++
		if curUpstream.Down(maxFails) {
			fails++
			if fails < len(list) {
				continue
			}
			// All upstream proxies are dead, assume healthcheck is completely broken
			// and connect to a random upstream.
			healthcheckBrokenCount.Inc()
		}

		curStart = time.Now()

		// Connect is repeated for as long as it reports ErrCachedClosed.
		for {
			ret, err = curUpstream.Connect(m, useTCP)

			if errors.Is(err, proxy.ErrCachedClosed) {
				// Remote side closed conn, can only happen with TCP.
				continue
			}
			break
		}

		if err != nil {
			// Kick off health check to see if *our* upstream is broken.
			curUpstream.Healthcheck()

			// Retry with the next upstream, unless all upstreams are down or
			// the deadline has passed. In that case, return an error reply
			// which matches the kind of the last error.
			retry := fails < len(list) && time.Now().Before(deadline)
			var dnsError *dns.Error
			switch {
			case errors.Is(err, os.ErrDeadlineExceeded):
				recordDuration("timeout")
				if !retry {
					return replyTimeout
				}
			case errors.As(err, &dnsError):
				recordDuration("protocol_error")
				if !retry {
					return replyProtocolError
				}
			default:
				recordDuration("network_error")
				if !retry {
					return replyNetworkError
				}
			}
			continue
		}

		break
	}

	// From here on, ret is the response of curUpstream. A response which fails
	// validation results in a protocol error and is not retried.
	if !ret.Response || ret.Opcode != dns.OpcodeQuery || len(ret.Question) != 1 {
		recordDuration("protocol_error")
		return replyProtocolError
	}

	// A truncated response is only acceptable over UDP.
	if ret.Truncated && useTCP {
		recordDuration("protocol_error")
		return replyProtocolError
	}
	if ret.Truncated {
		// Changes the protocol label of the metrics recorded below.
		proto = "udp_truncated"
	}

	// Check that the reply matches the question.
	retq := ret.Question[0]
	if retq.Qtype != question.Qtype || retq.Qclass != question.Qclass ||
		(retq.Name != question.Name && dns.CanonicalName(retq.Name) != question.Name) {
		recordDuration("protocol_error")
		return replyProtocolError
	}

	// Extract OPT from reply.
	var ropt *dns.OPT
	var options []dns.EDNS0
	// Iterate backwards, such that removing an element does not affect the
	// indices which still need to be visited.
	for i := len(ret.Extra) - 1; i >= 0; i-- {
		if rr, ok := ret.Extra[i].(*dns.OPT); ok {
			if ropt != nil {
				// Found more than one OPT.
				recordDuration("protocol_error")
				return replyProtocolError
			}
			ropt = rr
			ret.Extra = append(ret.Extra[:i], ret.Extra[i+1:]...)
		}
	}
	if ropt != nil {
		// Only EDNS version 0 with the root name as owner is accepted.
		if ropt.Version() != 0 || ropt.Hdr.Name != "." {
			recordDuration("protocol_error")
			return replyProtocolError
		}
		// Forward Extended DNS Error options.
		for _, option := range ropt.Option {
			switch option.(type) {
			case *dns.EDNS0_EDE:
				options = append(options, option)
			}
		}
	}

	rcode, ok := dns.RcodeToString[ret.Rcode]
	if !ok {
		// There are 4096 possible Rcodes, so it's probably still fine
		// to put it in a metric label.
		rcode = strconv.Itoa(ret.Rcode)
	}
	recordDuration(rcode)

	// AuthenticatedData is intentionally not copied from the proxy reply because
	// we don't have a secure channel to the proxy.
	return proxyReply{
		// Don't store large messages in the cache. Such large messages are very
		// rare, and this protects against the cache using huge amounts of memory.
		// DNS messages over TCP can be up to 64 KB in size, and after decompression
		// this could go over 1 MB of memory usage.
		NoStore: ret.Len() > cacheMaxItemSize,

		Truncated: ret.Truncated,
		Rcode:     ret.Rcode,
		Answer:    ret.Answer,
		Ns:        ret.Ns,
		Extra:     ret.Extra,
		Options:   options,
	}
}
| 373 | |
| 374 | func randomList(p []*proxy.Proxy) []*proxy.Proxy { |
| 375 | switch len(p) { |
| 376 | case 1: |
| 377 | return p |
| 378 | case 2: |
| 379 | if rand.Int()%2 == 0 { |
| 380 | return []*proxy.Proxy{p[1], p[0]} // swap |
| 381 | } |
| 382 | return p |
| 383 | } |
| 384 | |
| 385 | shuffled := slices.Clone(p) |
| 386 | rand.Shuffle(len(shuffled), func(i, j int) { |
| 387 | shuffled[i], shuffled[j] = shuffled[j], shuffled[i] |
| 388 | }) |
| 389 | return shuffled |
| 390 | } |