// Package forward implements a forwarding proxy.
//
// A cache is used to reduce load on the upstream servers. Cached items are only
// used for a short time, because otherwise we would need to provide a way to
// flush the cache. The cache is most useful for absorbing the load of
// applications that make very frequent, repeated queries. It also doubles as a
// way to merge concurrent identical queries, since items are inserted into the
// cache before the query is sent upstream (see also RFC 5452, Section 5).
package forward
// Taken and modified from the Forward plugin of CoreDNS, under Apache 2.0.
import (
"context"
"errors"
"hash/maphash"
"math/rand/v2"
"os"
"slices"
"strconv"
"sync/atomic"
"time"
"github.com/miekg/dns"
"source.monogon.dev/osbase/event/memory"
"source.monogon.dev/osbase/net/dns/forward/cache"
"source.monogon.dev/osbase/net/dns/forward/proxy"
"source.monogon.dev/osbase/supervisor"
)
const (
	// connectionExpire is how long cached upstream connections are kept.
	connectionExpire = 10 * time.Second
	// healthcheckInterval is the interval between upstream health checks.
	healthcheckInterval = 500 * time.Millisecond
	// forwardTimeout bounds how long a query is retried across upstreams.
	forwardTimeout = 5 * time.Second
	// maxFails is the failure count after which an upstream is considered down.
	maxFails = 2
	// maxConcurrent caps the number of in-flight upstream queries.
	maxConcurrent = 5000
	// maxUpstreams caps how many configured upstream servers are used.
	maxUpstreams = 15
)
// Forward represents a forwarder that can proxy requests to upstream (DNS)
// servers. It has a list of proxies, each representing one upstream server.
type Forward struct {
	// DNSServers provides the addresses of the upstream DNS servers to use.
	DNSServers memory.Value[[]string]
	// upstreams is the current list of running upstream proxies.
	upstreams atomic.Pointer[[]*proxy.Proxy]
	// concurrent counts the queries currently in flight to upstreams.
	concurrent atomic.Int64
	// seed is a per-process seed for maphash-based hashing.
	seed maphash.Seed
	// cache holds recently received replies (see the package comment).
	cache *cache.Cache[*cacheItem]
	// now can be used to override time for testing.
	now func() time.Time
}
// New returns a new Forward.
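//
// A minimal usage sketch; the supervisor wiring, context, and server address
// below are illustrative only:
//
//	f := forward.New()
//	f.DNSServers.Set([]string{"192.0.2.53:53"})
//	supervisor.Run(ctx, "forward", f.Run)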
func New() *Forward {
return &Forward{
seed: maphash.MakeSeed(),
cache: cache.New[*cacheItem](cacheCapacity),
now: time.Now,
}
}
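// Run watches DNSServers and keeps the set of upstream proxies in sync with
// it: proxies for removed addresses are stopped, proxies for added addresses
// are started, and proxies for unchanged addresses are reused. At most
// maxUpstreams addresses are used. When the watcher fails (for example because
// the context is canceled), all proxies are stopped and the error is returned.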
func (f *Forward) Run(ctx context.Context) error {
supervisor.Signal(ctx, supervisor.SignalHealthy)
var lastAddrs []string
var upstreams []*proxy.Proxy
w := f.DNSServers.Watch()
defer w.Close()
for {
addrs, err := w.Get(ctx)
if err != nil {
for _, p := range upstreams {
p.Stop()
}
f.upstreams.Store(nil)
return err
}
if len(addrs) > maxUpstreams {
addrs = addrs[:maxUpstreams]
}
if slices.Equal(addrs, lastAddrs) {
continue
}
lastAddrs = addrs
supervisor.Logger(ctx).Infof("New upstream DNS servers: %s", addrs)
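		// Diff the new address list against the running proxies: reuse proxies
		// whose address is still present, stop the rest, and start proxies for
		// newly added addresses.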
newAddrs := make(map[string]bool)
for _, addr := range addrs {
newAddrs[addr] = true
}
var newUpstreams []*proxy.Proxy
for _, p := range upstreams {
if newAddrs[p.Addr()] {
delete(newAddrs, p.Addr())
newUpstreams = append(newUpstreams, p)
} else {
p.Stop()
}
}
for newAddr := range newAddrs {
p := proxy.NewProxy(newAddr)
p.SetExpire(connectionExpire)
p.GetHealthchecker().SetRecursionDesired(true)
p.GetHealthchecker().SetDomain(".")
p.Start(healthcheckInterval)
newUpstreams = append(newUpstreams, p)
}
upstreams = newUpstreams
f.upstreams.Store(&newUpstreams)
}
}
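// proxyReply holds the parts of an upstream reply that are needed to build a
// response to the client and, unless NoStore is set, a cache item.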
type proxyReply struct {
// NoStore indicates that the reply should not be stored in the cache.
// This could be because it is cheap to obtain or expensive to store.
NoStore bool
Truncated bool
Rcode int
Answer []dns.RR
Ns []dns.RR
Extra []dns.RR
Options []dns.EDNS0
}
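// Canned replies for error conditions. Each carries an Extended DNS Error
// (RFC 8914) option explaining the failure to the client.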
var (
replyConcurrencyLimit = proxyReply{
NoStore: true,
Rcode: dns.RcodeServerFailure,
Options: []dns.EDNS0{&dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeOther,
ExtraText: "too many concurrent queries",
}},
}
replyNoUpstreams = proxyReply{
NoStore: true,
Rcode: dns.RcodeRefused,
Options: []dns.EDNS0{&dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeOther,
ExtraText: "no upstream DNS servers configured",
}},
}
replyProtocolError = proxyReply{
Rcode: dns.RcodeServerFailure,
Options: []dns.EDNS0{&dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeNetworkError,
ExtraText: "DNS protocol error when querying upstream DNS server",
}},
}
replyTimeout = proxyReply{
Rcode: dns.RcodeServerFailure,
Options: []dns.EDNS0{&dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeNetworkError,
ExtraText: "timeout when querying upstream DNS server",
}},
}
replyNetworkError = proxyReply{
Rcode: dns.RcodeServerFailure,
Options: []dns.EDNS0{&dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeNetworkError,
ExtraText: "network error when querying upstream DNS server",
}},
}
)
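// queryProxies forwards the question to the upstream proxies. Upstreams are
// tried in random order until one returns a usable reply, the deadline is
// reached, or all of them have failed; transport and protocol failures are
// converted into canned error replies.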
func (f *Forward) queryProxies(
question dns.Question,
dnssec bool,
checkingDisabled bool,
queryOptions []dns.EDNS0,
useTCP bool,
) proxyReply {
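	// Limit the number of queries that are in flight to upstreams at any time.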
count := f.concurrent.Add(1)
defer f.concurrent.Add(-1)
if count > maxConcurrent {
rejectsCount.WithLabelValues("concurrency_limit").Inc()
return replyConcurrencyLimit
}
// Construct outgoing query.
qopt := new(dns.OPT)
qopt.Hdr.Name = "."
qopt.Hdr.Rrtype = dns.TypeOPT
qopt.SetUDPSize(proxy.AdvertiseUDPSize)
if dnssec {
qopt.SetDo()
}
qopt.Option = queryOptions
m := &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
RecursionDesired: true,
CheckingDisabled: checkingDisabled,
},
Question: []dns.Question{question},
Extra: []dns.RR{qopt},
}
var list []*proxy.Proxy
if upstreams := f.upstreams.Load(); upstreams != nil {
list = randomList(*upstreams)
}
if len(list) == 0 {
rejectsCount.WithLabelValues("no_upstreams").Inc()
return replyNoUpstreams
}
proto := "udp"
if useTCP {
proto = "tcp"
}
var (
curUpstream *proxy.Proxy
curStart time.Time
ret *dns.Msg
err error
)
recordDuration := func(rcode string) {
upstreamDuration.WithLabelValues(curUpstream.Addr(), proto, rcode).Observe(time.Since(curStart).Seconds())
}
fails := 0
i := 0
listStart := time.Now()
deadline := listStart.Add(forwardTimeout)
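	// Try upstreams one by one, wrapping around the list if needed, until a
	// reply is received or we give up.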
for {
if i >= len(list) {
			// Reached the end of the list, start over from the beginning.
i = 0
fails = 0
// Sleep for a bit if the last time we started going through the list was
// very recent.
time.Sleep(time.Until(listStart.Add(time.Second)))
listStart = time.Now()
}
curUpstream = list[i]
i++
if curUpstream.Down(maxFails) {
fails++
if fails < len(list) {
continue
}
// All upstream proxies are dead, assume healthcheck is completely broken
// and connect to a random upstream.
healthcheckBrokenCount.Inc()
}
curStart = time.Now()
for {
ret, err = curUpstream.Connect(m, useTCP)
if errors.Is(err, proxy.ErrCachedClosed) {
// Remote side closed conn, can only happen with TCP.
continue
}
break
}
if err != nil {
			// Kick off a health check to see if this particular upstream is broken.
curUpstream.Healthcheck()
retry := fails < len(list) && time.Now().Before(deadline)
var dnsError *dns.Error
switch {
case errors.Is(err, os.ErrDeadlineExceeded):
recordDuration("timeout")
if !retry {
return replyTimeout
}
case errors.As(err, &dnsError):
recordDuration("protocol_error")
if !retry {
return replyProtocolError
}
default:
recordDuration("network_error")
if !retry {
return replyNetworkError
}
}
continue
}
break
}
if !ret.Response || ret.Opcode != dns.OpcodeQuery || len(ret.Question) != 1 {
recordDuration("protocol_error")
return replyProtocolError
}
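	// A truncated reply over TCP cannot be recovered by retrying over a larger
	// transport, so treat it as a protocol error.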
if ret.Truncated && useTCP {
recordDuration("protocol_error")
return replyProtocolError
}
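	// Truncated UDP replies are recorded under a separate duration label.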
if ret.Truncated {
proto = "udp_truncated"
}
// Check that the reply matches the question.
retq := ret.Question[0]
if retq.Qtype != question.Qtype || retq.Qclass != question.Qclass ||
(retq.Name != question.Name && dns.CanonicalName(retq.Name) != question.Name) {
recordDuration("protocol_error")
return replyProtocolError
}
// Extract OPT from reply.
var ropt *dns.OPT
var options []dns.EDNS0
for i := len(ret.Extra) - 1; i >= 0; i-- {
if rr, ok := ret.Extra[i].(*dns.OPT); ok {
if ropt != nil {
// Found more than one OPT.
recordDuration("protocol_error")
return replyProtocolError
}
ropt = rr
ret.Extra = append(ret.Extra[:i], ret.Extra[i+1:]...)
}
}
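	// Validate the OPT record: only EDNS version 0 is supported, and the owner
	// name must be the root (RFC 6891).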
if ropt != nil {
if ropt.Version() != 0 || ropt.Hdr.Name != "." {
recordDuration("protocol_error")
return replyProtocolError
}
// Forward Extended DNS Error options.
for _, option := range ropt.Option {
switch option.(type) {
case *dns.EDNS0_EDE:
options = append(options, option)
}
}
}
rcode, ok := dns.RcodeToString[ret.Rcode]
if !ok {
// There are 4096 possible Rcodes, so it's probably still fine
// to put it in a metric label.
rcode = strconv.Itoa(ret.Rcode)
}
recordDuration(rcode)
// AuthenticatedData is intentionally not copied from the proxy reply because
// we don't have a secure channel to the proxy.
return proxyReply{
// Don't store large messages in the cache. Such large messages are very
// rare, and this protects against the cache using huge amounts of memory.
// DNS messages over TCP can be up to 64 KB in size, and after decompression
// this could go over 1 MB of memory usage.
NoStore: ret.Len() > cacheMaxItemSize,
Truncated: ret.Truncated,
Rcode: ret.Rcode,
Answer: ret.Answer,
Ns: ret.Ns,
Extra: ret.Extra,
Options: options,
}
}
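// randomList returns the proxies in random order, so that load is spread
// across upstreams. Lists of length 1 and 2 are handled with cheap special
// cases instead of a full shuffle.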
func randomList(p []*proxy.Proxy) []*proxy.Proxy {
switch len(p) {
case 1:
return p
case 2:
if rand.Int()%2 == 0 {
return []*proxy.Proxy{p[1], p[0]} // swap
}
return p
}
shuffled := slices.Clone(p)
rand.Shuffle(len(shuffled), func(i, j int) {
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
})
return shuffled
}