osbase/net/dns/forward: add DNS forward handler
This adds a DNS server handler for forwarding queries to upstream DNS
resolvers, with a built-in cache. The implementation is partially based
on CoreDNS. The proxy, cache and up packages are only lightly modified;
the forward package itself, however, is mostly new code. Unlike CoreDNS,
it supports changing upstreams at runtime, and has integrated caching
and answer order randomization.
Some improvements over CoreDNS:
- Concurrent identical queries only result in one upstream query.
- In case of errors, Extended DNS Errors are added to replies.
- Very large replies are not stored in the cache to avoid using too much
memory.
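As an illustration only (not part of this change), a rough sketch of
exercising the proxy subpackage added below directly; the upstream address
is a placeholder, and the package's Bazel visibility normally limits its use
to the forward tree:

  package main

  import (
      "fmt"
      "time"

      "github.com/miekg/dns"

      "source.monogon.dev/osbase/net/dns/forward/proxy"
  )

  func main() {
      p := proxy.NewProxy("192.0.2.53:53") // placeholder upstream address
      p.Start(5 * time.Second)             // starts health checking and the connection manager
      defer p.Stop()

      m := new(dns.Msg)
      m.SetQuestion("example.org.", dns.TypeA)
      resp, err := p.Connect(m, false) // false selects UDP, true selects TCP
      if err != nil {
          fmt.Println("query failed:", err)
          return
      }
      fmt.Println(resp.Answer)
  }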
Change-Id: I42294ae4997d621a6e55c98e46a04874eab75c99
Reviewed-on: https://review.monogon.dev/c/monogon/+/3258
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
Tested-by: Jenkins CI
diff --git a/osbase/net/dns/forward/proxy/BUILD.bazel b/osbase/net/dns/forward/proxy/BUILD.bazel
new file mode 100644
index 0000000..5f3b177
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/BUILD.bazel
@@ -0,0 +1,35 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "proxy",
+ srcs = [
+ "connect.go",
+ "health.go",
+ "metrics.go",
+ "persistent.go",
+ "proxy.go",
+ "type.go",
+ ],
+ importpath = "source.monogon.dev/osbase/net/dns/forward/proxy",
+ visibility = ["//osbase/net/dns/forward:__subpackages__"],
+ deps = [
+ "//osbase/net/dns",
+ "//osbase/net/dns/forward/up",
+ "@com_github_miekg_dns//:dns",
+ "@com_github_prometheus_client_golang//prometheus",
+ ],
+)
+
+go_test(
+ name = "proxy_test",
+ srcs = [
+ "health_test.go",
+ "persistent_test.go",
+ "proxy_test.go",
+ ],
+ embed = [":proxy"],
+ deps = [
+ "//osbase/net/dns/test",
+ "@com_github_miekg_dns//:dns",
+ ],
+)
diff --git a/osbase/net/dns/forward/proxy/connect.go b/osbase/net/dns/forward/proxy/connect.go
new file mode 100644
index 0000000..d1c5b8e
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/connect.go
@@ -0,0 +1,168 @@
+// Package proxy implements a forwarding proxy. It caches an upstream net.Conn
+// for some time, so if the same client returns, the upstream's Conn is still
+// open and can be reused. Depending on how you benchmark, this looks to be
+// about 50% faster than opening a new connection for every client.
+// It works with UDP and TCP and uses in-band health checking.
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "errors"
+ "io"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// AdvertiseUDPSize is the maximum message size that we advertise in the OPT RR
+// of UDP messages. This is calculated as the minimum IPv6 MTU (1280) minus
+// the size of the IPv6 (40) and UDP (8) headers.
+const AdvertiseUDPSize = 1232
+
+// ErrCachedClosed means the cached connection was closed by the peer.
+var ErrCachedClosed = errors.New("cached connection was closed by peer")
+
+// limitTimeout is a utility function to auto-tune timeout values. The
+// average observed time is moved towards the last observed delay, moderated
+// by a weight. The next timeout to use is double the computed average,
+// limited by the min and max bounds.
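+// For example, with minDialTimeout=1s and maxDialTimeout=30s: an average of
+// 200ms yields a 1s timeout, an average of 5s yields 10s, and an average of
+// 15s or more is capped at 30s.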
+func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
+ rt := time.Duration(atomic.LoadInt64(currentAvg))
+ if rt < minValue {
+ return minValue
+ }
+ if rt < maxValue/2 {
+ return 2 * rt
+ }
+ return maxValue
+}
+
+func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
+ dt := time.Duration(atomic.LoadInt64(currentAvg))
+ atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
+}
+
+func (t *Transport) dialTimeout() time.Duration {
+ return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
+}
+
+func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
+ averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
+}
+
+// Dial dials the address configured in transport,
+// potentially reusing a connection or creating a new one.
+func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
+ // If TLS has been configured, use it.
+ if t.tlsConfig != nil {
+ proto = "tcp-tls"
+ }
+
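+ // Ask connManager for a cached connection of this protocol; it answers on
+ // t.ret with either a reusable *persistConn or nil if none is available.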
+ t.dial <- proto
+ pc := <-t.ret
+
+ if pc != nil {
+ return pc, true, nil
+ }
+
+ reqTime := time.Now()
+ timeout := t.dialTimeout()
+ if proto == "tcp-tls" {
+ conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return &persistConn{c: conn}, false, err
+ }
+ conn, err := dns.DialTimeout(proto, t.addr, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return &persistConn{c: conn}, false, err
+}
+
+// Connect sends the request to the configured upstream and waits for a response.
+func (p *Proxy) Connect(m *dns.Msg, useTCP bool) (*dns.Msg, error) {
+ proto := "udp"
+ if useTCP {
+ proto = "tcp"
+ }
+
+ pc, cached, err := p.transport.Dial(proto)
+ if err != nil {
+ return nil, err
+ }
+
+ pc.c.UDPSize = AdvertiseUDPSize
+
+ pc.c.SetWriteDeadline(time.Now().Add(p.writeTimeout))
+ m.Id = dns.Id()
+
+ if err := pc.c.WriteMsg(m); err != nil {
+ pc.c.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, ErrCachedClosed
+ }
+ return nil, err
+ }
+
+ var ret *dns.Msg
+ pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
+ for {
+ ret, err = pc.c.ReadMsg()
+ if err != nil {
+ if ret != nil && (m.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) {
+ // For UDP, if the error is an overflow, we probably have an upstream
+ // misbehaving in some way.
+ // (e.g. sending >512 byte responses without an eDNS0 OPT RR).
+ // Instead of returning an error, return an empty response
+ // with TC bit set. This will make the client retry over TCP
+ // (if that's supported) or at least receive a clean error.
+ // The connection is still good so we break before the close.
+
+ // Truncate the response.
+ ret = truncateResponse(ret)
+ break
+ }
+
+ pc.c.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, ErrCachedClosed
+ }
+ return ret, err
+ }
+ // drop out-of-order responses
+ if m.Id == ret.Id {
+ break
+ }
+ }
+
+ p.transport.Yield(pc)
+
+ return ret, nil
+}
+
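+// cumulativeAvgWeight controls how quickly avgDialTime moves towards a new
+// observation: each sample shifts the average by 1/weight of the difference.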
+const cumulativeAvgWeight = 4
+
+// shouldTruncateResponse reports whether a UDP read error indicates an
+// oversized reply that should be answered with a truncated response.
+func shouldTruncateResponse(err error) bool {
+ // This handles a scenario in which the upstream sets the TC bit, but
+ // doesn't truncate the response and we get ErrBuf instead of an overflow.
+ if errors.Is(err, dns.ErrBuf) {
+ return true
+ } else if strings.Contains(err.Error(), "overflow") {
+ return true
+ }
+ return false
+}
+
+// truncateResponse returns an empty response with the TC (truncated) bit set.
+func truncateResponse(response *dns.Msg) *dns.Msg {
+ // Clear out Answer, Extra, and Ns sections
+ response.Answer = nil
+ response.Extra = nil
+ response.Ns = nil
+
+ // Set TC bit to indicate truncation.
+ response.Truncated = true
+ return response
+}
diff --git a/osbase/net/dns/forward/proxy/health.go b/osbase/net/dns/forward/proxy/health.go
new file mode 100644
index 0000000..4630cdc
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/health.go
@@ -0,0 +1,127 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "crypto/tls"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// HealthChecker checks the upstream health.
+type HealthChecker interface {
+ Check(*Proxy) error
+ SetTLSConfig(*tls.Config)
+ GetTLSConfig() *tls.Config
+ SetRecursionDesired(bool)
+ GetRecursionDesired() bool
+ SetDomain(domain string)
+ GetDomain() string
+ SetTCPTransport()
+ GetReadTimeout() time.Duration
+ SetReadTimeout(time.Duration)
+ GetWriteTimeout() time.Duration
+ SetWriteTimeout(time.Duration)
+}
+
+// dnsHc is a health checker for a DNS endpoint (plain DNS and DoT).
+type dnsHc struct {
+ c *dns.Client
+ recursionDesired bool
+ domain string
+}
+
+// NewHealthChecker returns a new HealthChecker.
+func NewHealthChecker(recursionDesired bool, domain string) HealthChecker {
+ c := new(dns.Client)
+ c.Net = "udp"
+ c.ReadTimeout = 1 * time.Second
+ c.WriteTimeout = 1 * time.Second
+
+ return &dnsHc{
+ c: c,
+ recursionDesired: recursionDesired,
+ domain: domain,
+ }
+}
+
+func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
+ h.c.Net = "tcp-tls"
+ h.c.TLSConfig = cfg
+}
+
+func (h *dnsHc) GetTLSConfig() *tls.Config {
+ return h.c.TLSConfig
+}
+
+func (h *dnsHc) SetRecursionDesired(recursionDesired bool) {
+ h.recursionDesired = recursionDesired
+}
+func (h *dnsHc) GetRecursionDesired() bool {
+ return h.recursionDesired
+}
+
+func (h *dnsHc) SetDomain(domain string) {
+ h.domain = domain
+}
+func (h *dnsHc) GetDomain() string {
+ return h.domain
+}
+
+func (h *dnsHc) SetTCPTransport() {
+ h.c.Net = "tcp"
+}
+
+func (h *dnsHc) GetReadTimeout() time.Duration {
+ return h.c.ReadTimeout
+}
+
+func (h *dnsHc) SetReadTimeout(t time.Duration) {
+ h.c.ReadTimeout = t
+}
+
+func (h *dnsHc) GetWriteTimeout() time.Duration {
+ return h.c.WriteTimeout
+}
+
+func (h *dnsHc) SetWriteTimeout(t time.Duration) {
+ h.c.WriteTimeout = t
+}
+
+// For health checking we send a "<domain> IN NS" query (with or without the
+// RD bit, depending on recursionDesired) to the upstream. Dial timeouts and
+// empty replies are considered failures; basically anything else constitutes
+// a healthy upstream.
+
+// Check is used as the up.Func in the up.Probe.
+func (h *dnsHc) Check(p *Proxy) error {
+ err := h.send(p.addr)
+ if err != nil {
+ healthcheckFailureCount.WithLabelValues(p.addr).Inc()
+ p.incrementFails()
+ return err
+ }
+
+ atomic.StoreUint32(&p.fails, 0)
+ return nil
+}
+
+func (h *dnsHc) send(addr string) error {
+ ping := new(dns.Msg)
+ ping.SetQuestion(h.domain, dns.TypeNS)
+ ping.MsgHdr.RecursionDesired = h.recursionDesired
+ ping.SetEdns0(AdvertiseUDPSize, false)
+
+ m, _, err := h.c.Exchange(ping, addr)
+ // If we got a message back at all, we're alright; we basically only
+ // care about I/O errors here.
+ if err != nil && m != nil {
+ // Silly check, something sane came back.
+ if m.Response || m.Opcode == dns.OpcodeQuery {
+ err = nil
+ }
+ }
+
+ return err
+}
diff --git a/osbase/net/dns/forward/proxy/health_test.go b/osbase/net/dns/forward/proxy/health_test.go
new file mode 100644
index 0000000..ef8b60a
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/health_test.go
@@ -0,0 +1,154 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+
+ "source.monogon.dev/osbase/net/dns/test"
+)
+
+func TestHealth(t *testing.T) {
+ i := uint32(0)
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == "." && r.RecursionDesired == true {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(true, ".")
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(20 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
+ }
+}
+
+func TestHealthTCP(t *testing.T) {
+ i := uint32(0)
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == "." && r.RecursionDesired == true {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(true, ".")
+ hc.SetTCPTransport()
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(20 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
+ }
+}
+
+func TestHealthNoRecursion(t *testing.T) {
+ i := uint32(0)
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == "." && r.RecursionDesired == false {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(false, ".")
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(20 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with RecursionDesired==false to be %d, got %d", 1, i1)
+ }
+}
+
+func TestHealthTimeout(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ // timeout
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(false, ".")
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err == nil {
+ t.Errorf("expected error")
+ }
+}
+
+func TestHealthDomain(t *testing.T) {
+ hcDomain := "example.org."
+
+ i := uint32(0)
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == hcDomain && r.RecursionDesired == true {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(true, hcDomain)
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(12 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1)
+ }
+}
diff --git a/osbase/net/dns/forward/proxy/metrics.go b/osbase/net/dns/forward/proxy/metrics.go
new file mode 100644
index 0000000..c08f1d8
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/metrics.go
@@ -0,0 +1,19 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "github.com/prometheus/client_golang/prometheus"
+
+ "source.monogon.dev/osbase/net/dns"
+)
+
+// Variables declared for monitoring.
+var (
+ healthcheckFailureCount = dns.MetricsFactory.NewCounterVec(prometheus.CounterOpts{
+ Namespace: "dnsserver",
+ Subsystem: "forward",
+ Name: "healthcheck_failures_total",
+ Help: "Counter of the number of failed healthchecks.",
+ }, []string{"to"})
+)
diff --git a/osbase/net/dns/forward/proxy/persistent.go b/osbase/net/dns/forward/proxy/persistent.go
new file mode 100644
index 0000000..cb4f618
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/persistent.go
@@ -0,0 +1,160 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "crypto/tls"
+ "sort"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// a persistConn holds the dns.Conn and the last used time.
+type persistConn struct {
+ c *dns.Conn
+ used time.Time
+}
+
+// Transport holds the persistent connection cache.
+type Transport struct {
+ avgDialTime int64 // rolling average of the dial time
+ conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls.
+ expire time.Duration // After this duration a connection is expired.
+ addr string
+ tlsConfig *tls.Config
+
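+ // Dial, Yield and Stop communicate with the connManager goroutine, which
+ // owns conns, through these channels.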
+ dial chan string
+ yield chan *persistConn
+ ret chan *persistConn
+ stop chan bool
+}
+
+func newTransport(addr string) *Transport {
+ t := &Transport{
+ avgDialTime: int64(maxDialTimeout / 2),
+ conns: [typeTotalCount][]*persistConn{},
+ expire: defaultExpire,
+ addr: addr,
+ dial: make(chan string),
+ yield: make(chan *persistConn),
+ ret: make(chan *persistConn),
+ stop: make(chan bool),
+ }
+ return t
+}
+
+// connManager manages the persistent connection cache for UDP, TCP and TCP-TLS.
+func (t *Transport) connManager() {
+ ticker := time.NewTicker(defaultExpire)
+ defer ticker.Stop()
+Wait:
+ for {
+ select {
+ case proto := <-t.dial:
+ transtype := stringToTransportType(proto)
+ // take the last used conn - complexity O(1)
+ if stack := t.conns[transtype]; len(stack) > 0 {
+ pc := stack[len(stack)-1]
+ if time.Since(pc.used) < t.expire {
+ // Found one, remove from pool and return this conn.
+ t.conns[transtype] = stack[:len(stack)-1]
+ t.ret <- pc
+ continue Wait
+ }
+ // clear entire cache if the last conn is expired
+ t.conns[transtype] = nil
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack)
+ }
+ t.ret <- nil
+
+ case pc := <-t.yield:
+ transtype := t.transportTypeFromConn(pc)
+ t.conns[transtype] = append(t.conns[transtype], pc)
+
+ case <-ticker.C:
+ t.cleanup(false)
+
+ case <-t.stop:
+ t.cleanup(true)
+ close(t.ret)
+ return
+ }
+ }
+}
+
+// closeConns closes connections.
+func closeConns(conns []*persistConn) {
+ for _, pc := range conns {
+ pc.c.Close()
+ }
+}
+
+// cleanup removes connections from cache.
+func (t *Transport) cleanup(all bool) {
+ staleTime := time.Now().Add(-t.expire)
+ for transtype, stack := range t.conns {
+ if len(stack) == 0 {
+ continue
+ }
+ if all {
+ t.conns[transtype] = nil
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack)
+ continue
+ }
+ if stack[0].used.After(staleTime) {
+ continue
+ }
+
+ // connections in stack are sorted by "used"
+ good := sort.Search(len(stack), func(i int) bool {
+ return stack[i].used.After(staleTime)
+ })
+ t.conns[transtype] = stack[good:]
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack[:good])
+ }
+}
+
+// It is hard to pin a value to this; the important thing is to not block
+// forever. Losing a cached connection is not terrible.
+const yieldTimeout = 25 * time.Millisecond
+
+// Yield returns the connection to transport for reuse.
+func (t *Transport) Yield(pc *persistConn) {
+ pc.used = time.Now() // update used time
+
+ // Make this non-blocking, because in the case of a very busy forwarder
+ // we would *block* on this yield. That would block the outer goroutine and
+ // things would just pile up. We time out if the send fails, as returning
+ // these connections is an optimization anyway.
+ select {
+ case t.yield <- pc:
+ return
+ case <-time.After(yieldTimeout):
+ return
+ }
+}
+
+// Start starts the transport's connection manager.
+func (t *Transport) Start() { go t.connManager() }
+
+// Stop stops the transport's connection manager.
+func (t *Transport) Stop() { close(t.stop) }
+
+// SetExpire sets the connection expire time in transport.
+func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
+
+// SetTLSConfig sets the TLS config in transport.
+func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
+
+const (
+ defaultExpire = 10 * time.Second
+ minDialTimeout = 1 * time.Second
+ maxDialTimeout = 30 * time.Second
+)
diff --git a/osbase/net/dns/forward/proxy/persistent_test.go b/osbase/net/dns/forward/proxy/persistent_test.go
new file mode 100644
index 0000000..2856c3f
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/persistent_test.go
@@ -0,0 +1,111 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+
+ "source.monogon.dev/osbase/net/dns/test"
+)
+
+func TestCached(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr)
+ tr.Start()
+ defer tr.Stop()
+
+ c1, cache1, _ := tr.Dial("udp")
+ c2, cache2, _ := tr.Dial("udp")
+
+ if cache1 || cache2 {
+ t.Errorf("Expected non-cached connection")
+ }
+
+ tr.Yield(c1)
+ tr.Yield(c2)
+ c3, cached3, _ := tr.Dial("udp")
+ if !cached3 {
+ t.Error("Expected cached connection (c3)")
+ }
+ if c2 != c3 {
+ t.Error("Expected c2 == c3")
+ }
+
+ tr.Yield(c3)
+
+ // dial another protocol
+ c4, cached4, _ := tr.Dial("tcp")
+ if cached4 {
+ t.Errorf("Expected non-cached connection (c4)")
+ }
+ tr.Yield(c4)
+}
+
+func TestCleanupByTimer(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr)
+ tr.SetExpire(10 * time.Millisecond)
+ tr.Start()
+ defer tr.Stop()
+
+ c1, _, _ := tr.Dial("udp")
+ c2, _, _ := tr.Dial("udp")
+ tr.Yield(c1)
+ time.Sleep(2 * time.Millisecond)
+ tr.Yield(c2)
+
+ time.Sleep(15 * time.Millisecond)
+ c3, cached, _ := tr.Dial("udp")
+ if cached {
+ t.Error("Expected non-cached connection (c3)")
+ }
+ tr.Yield(c3)
+
+ time.Sleep(15 * time.Millisecond)
+ c4, cached, _ := tr.Dial("udp")
+ if cached {
+ t.Error("Expected non-cached connection (c4)")
+ }
+ tr.Yield(c4)
+}
+
+func TestCleanupAll(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr)
+
+ c1, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+ c2, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+ c3, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+
+ tr.conns[typeUDP] = []*persistConn{{c1, time.Now()}, {c2, time.Now()}, {c3, time.Now()}}
+
+ if len(tr.conns[typeUDP]) != 3 {
+ t.Error("Expected 3 connections")
+ }
+ tr.cleanup(true)
+
+ if len(tr.conns[typeUDP]) > 0 {
+ t.Error("Expected no cached connections")
+ }
+}
diff --git a/osbase/net/dns/forward/proxy/proxy.go b/osbase/net/dns/forward/proxy/proxy.go
new file mode 100644
index 0000000..dc25a31
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/proxy.go
@@ -0,0 +1,108 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "crypto/tls"
+ "runtime"
+ "sync/atomic"
+ "time"
+
+ "source.monogon.dev/osbase/net/dns/forward/up"
+)
+
+// Proxy defines an upstream host.
+type Proxy struct {
+ fails uint32
+ addr string
+
+ transport *Transport
+
+ writeTimeout time.Duration
+ readTimeout time.Duration
+
+ // health checking
+ probe *up.Probe
+ health HealthChecker
+}
+
+// NewProxy returns a new proxy.
+func NewProxy(addr string) *Proxy {
+ p := &Proxy{
+ addr: addr,
+ fails: 0,
+ probe: up.New(),
+ writeTimeout: 2 * time.Second,
+ readTimeout: 2 * time.Second,
+ transport: newTransport(addr),
+ health: NewHealthChecker(true, "."),
+ }
+
+ runtime.SetFinalizer(p, (*Proxy).finalizer)
+ return p
+}
+
+func (p *Proxy) Addr() string { return p.addr }
+
+// SetTLSConfig sets the TLS config in the lower p.transport
+// and in the healthchecking client.
+func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
+ p.transport.SetTLSConfig(cfg)
+ p.health.SetTLSConfig(cfg)
+}
+
+// SetExpire sets the expire duration in the lower p.transport.
+func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
+
+func (p *Proxy) GetHealthchecker() HealthChecker {
+ return p.health
+}
+
+func (p *Proxy) Fails() uint32 {
+ return atomic.LoadUint32(&p.fails)
+}
+
+// Healthcheck kicks off a round of health checks for this proxy.
+func (p *Proxy) Healthcheck() {
+ if p.health == nil {
+ return
+ }
+
+ p.probe.Do(func() error {
+ return p.health.Check(p)
+ })
+}
+
+// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
+func (p *Proxy) Down(maxfails uint32) bool {
+ if maxfails == 0 {
+ return false
+ }
+
+ fails := atomic.LoadUint32(&p.fails)
+ return fails > maxfails
+}
+
+// Stop stops the health checking goroutine.
+func (p *Proxy) Stop() { p.probe.Stop() }
+func (p *Proxy) finalizer() { p.transport.Stop() }
+
+// Start starts the proxy's healthchecking.
+func (p *Proxy) Start(duration time.Duration) {
+ p.probe.Start(duration)
+ p.transport.Start()
+}
+
+func (p *Proxy) SetReadTimeout(duration time.Duration) {
+ p.readTimeout = duration
+}
+
+// incrementFails increments the number of fails safely.
+func (p *Proxy) incrementFails() {
+ curVal := atomic.LoadUint32(&p.fails)
+ if curVal > curVal+1 {
+ // overflow occurred, do not update the counter again
+ return
+ }
+ atomic.AddUint32(&p.fails, 1)
+}
diff --git a/osbase/net/dns/forward/proxy/proxy_test.go b/osbase/net/dns/forward/proxy/proxy_test.go
new file mode 100644
index 0000000..a9297da
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/proxy_test.go
@@ -0,0 +1,156 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+
+ "source.monogon.dev/osbase/net/dns/test"
+)
+
+func TestProxy(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = append(ret.Answer, test.RR("example.org. IN A 127.0.0.1"))
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ p.Start(5 * time.Second)
+ m := new(dns.Msg)
+
+ m.SetQuestion("example.org.", dns.TypeA)
+
+ resp, err := p.Connect(m, false)
+ if err != nil {
+ t.Errorf("Failed to connect to testdnsserver: %s", err)
+ }
+
+ if x := resp.Answer[0].Header().Name; x != "example.org." {
+ t.Errorf("Expected %s, got %s", "example.org.", x)
+ }
+}
+
+func TestProtocolSelection(t *testing.T) {
+ p := NewProxy("bad_address")
+ p.readTimeout = 10 * time.Millisecond
+
+ go func() {
+ p.Connect(new(dns.Msg), false)
+ p.Connect(new(dns.Msg), true)
+ }()
+
+ for i, exp := range []string{"udp", "tcp"} {
+ proto := <-p.transport.dial
+ p.transport.ret <- nil
+ if proto != exp {
+ t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto)
+ }
+ }
+}
+
+func TestProxyIncrementFails(t *testing.T) {
+ var testCases = []struct {
+ name string
+ fails uint32
+ expectFails uint32
+ }{
+ {
+ name: "increment fails counter overflows",
+ fails: math.MaxUint32,
+ expectFails: math.MaxUint32,
+ },
+ {
+ name: "increment fails counter",
+ fails: 0,
+ expectFails: 1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ p := NewProxy("bad_address")
+ p.fails = tc.fails
+ p.incrementFails()
+ if p.fails != tc.expectFails {
+ t.Errorf("Expected fails to be %d, got %d", tc.expectFails, p.fails)
+ }
+ })
+ }
+}
+
+func TestCoreDNSOverflow(t *testing.T) {
+ s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+
+ var answers []dns.RR
+ for i := range 50 {
+ answers = append(answers, test.RR(fmt.Sprintf("example.org. IN A 127.0.0.%v", i)))
+ }
+ ret.Answer = answers
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ p := NewProxy(s.Addr)
+ p.readTimeout = 10 * time.Millisecond
+ p.Start(5 * time.Second)
+ defer p.Stop()
+
+ // Test different connection modes
+ testConnection := func(proto string, useTCP bool, expectTruncated bool) {
+ t.Helper()
+
+ queryMsg := new(dns.Msg)
+ queryMsg.SetQuestion("example.org.", dns.TypeA)
+
+ response, err := p.Connect(queryMsg, useTCP)
+ if err != nil {
+ t.Errorf("Failed to connect to testdnsserver: %s", err)
+ }
+
+ if response.Truncated != expectTruncated {
+ t.Errorf("Expected truncated response for %s, but got TC flag %v", proto, response.Truncated)
+ }
+ }
+
+ // Test udp, expect truncated response
+ testConnection("udp", false, true)
+
+ // Test tcp, expect no truncated response
+ testConnection("tcp", true, false)
+}
+
+func TestShouldTruncateResponse(t *testing.T) {
+ testCases := []struct {
+ testname string
+ err error
+ expected bool
+ }{
+ {"BadAlgorithm", dns.ErrAlg, false},
+ {"BufferSizeTooSmall", dns.ErrBuf, true},
+ {"OverflowUnpackingA", errors.New("overflow unpacking a"), true},
+ {"OverflowingHeaderSize", errors.New("overflowing header size"), true},
+ {"OverflowpackingA", errors.New("overflow packing a"), true},
+ {"ErrSig", dns.ErrSig, false},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.testname, func(t *testing.T) {
+ result := shouldTruncateResponse(tc.err)
+ if result != tc.expected {
+ t.Errorf("For testname '%v', expected %v but got %v", tc.testname, tc.expected, result)
+ }
+ })
+ }
+}
diff --git a/osbase/net/dns/forward/proxy/type.go b/osbase/net/dns/forward/proxy/type.go
new file mode 100644
index 0000000..6eb78ce
--- /dev/null
+++ b/osbase/net/dns/forward/proxy/type.go
@@ -0,0 +1,41 @@
+package proxy
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "net"
+)
+
+type transportType int
+
+const (
+ typeUDP transportType = iota
+ typeTCP
+ typeTLS
+ typeTotalCount // keep this last
+)
+
+func stringToTransportType(s string) transportType {
+ switch s {
+ case "udp":
+ return typeUDP
+ case "tcp":
+ return typeTCP
+ case "tcp-tls":
+ return typeTLS
+ }
+
+ return typeUDP
+}
+
+func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
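+ // A non-UDP connection is either plain TCP or TCP-TLS; Dial forces
+ // "tcp-tls" whenever tlsConfig is set, so the config decides which.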
+ if _, ok := pc.c.Conn.(*net.UDPConn); ok {
+ return typeUDP
+ }
+
+ if t.tlsConfig == nil {
+ return typeTCP
+ }
+
+ return typeTLS
+}