blob: d1c5b8e5cc83b66b0d730da37c9bb853d7321114 [file] [log] [blame]
// Package proxy implements a forwarding proxy. It caches an upstream net.Conn
// for some time, so if the same client returns, the upstream's Conn will be
// precached. Depending on how you benchmark, this looks to be 50% faster than
// just opening a new connection for every client.
// It works with UDP and TCP and uses in-band health checking.
6package proxy
7
8// Taken and modified from CoreDNS, under Apache 2.0.
9
10import (
11 "errors"
12 "io"
13 "strings"
14 "sync/atomic"
15 "time"
16
17 "github.com/miekg/dns"
18)
19
// AdvertiseUDPSize is the largest message size advertised in the OPT RR of
// UDP messages. It is the minimum IPv6 MTU (1280 bytes) with the IPv6 header
// (40 bytes) and the UDP header (8 bytes) subtracted, i.e. 1232 bytes.
const AdvertiseUDPSize = 1280 - 40 - 8
24
// ErrCachedClosed means cached connection was closed by peer. Connect returns
// it in place of io.EOF when writing to, or reading from, a connection that
// Dial reported as reused from the cache.
var ErrCachedClosed = errors.New("cached connection was closed by peer")
27
// limitTimeout derives the next timeout from an auto-tuned running average.
//
// The average stored at currentAvg is read atomically and interpreted as a
// time.Duration. The result is:
//   - minValue when the average is below minValue,
//   - twice the average when the average is below half of maxValue,
//   - maxValue otherwise.
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
	avg := time.Duration(atomic.LoadInt64(currentAvg))
	switch {
	case avg < minValue:
		return minValue
	case avg < maxValue/2:
		return avg * 2
	default:
		return maxValue
	}
}
42
// averageTimeout moves the running average stored at currentAvg towards
// observedDuration. The difference between the observation and the current
// average is divided by weight (integer division) and atomically added to the
// average, so a larger weight makes the average react more slowly.
//
// The load and the add are two separate atomic operations, not one combined
// read-modify-write.
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
	previous := atomic.LoadInt64(currentAvg)
	delta := int64(observedDuration) - previous
	atomic.AddInt64(currentAvg, delta/weight)
}
47
48func (t *Transport) dialTimeout() time.Duration {
49 return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
50}
51
52func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
53 averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
54}
55
56// Dial dials the address configured in transport,
57// potentially reusing a connection or creating a new one.
58func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
59 // If tls has been configured; use it.
60 if t.tlsConfig != nil {
61 proto = "tcp-tls"
62 }
63
64 t.dial <- proto
65 pc := <-t.ret
66
67 if pc != nil {
68 return pc, true, nil
69 }
70
71 reqTime := time.Now()
72 timeout := t.dialTimeout()
73 if proto == "tcp-tls" {
74 conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
75 t.updateDialTimeout(time.Since(reqTime))
76 return &persistConn{c: conn}, false, err
77 }
78 conn, err := dns.DialTimeout(proto, t.addr, timeout)
79 t.updateDialTimeout(time.Since(reqTime))
80 return &persistConn{c: conn}, false, err
81}
82
83// Connect selects an upstream, sends the request and waits for a response.
84func (p *Proxy) Connect(m *dns.Msg, useTCP bool) (*dns.Msg, error) {
85 proto := "udp"
86 if useTCP {
87 proto = "tcp"
88 }
89
90 pc, cached, err := p.transport.Dial(proto)
91 if err != nil {
92 return nil, err
93 }
94
95 pc.c.UDPSize = AdvertiseUDPSize
96
97 pc.c.SetWriteDeadline(time.Now().Add(p.writeTimeout))
98 m.Id = dns.Id()
99
100 if err := pc.c.WriteMsg(m); err != nil {
101 pc.c.Close() // not giving it back
102 if err == io.EOF && cached {
103 return nil, ErrCachedClosed
104 }
105 return nil, err
106 }
107
108 var ret *dns.Msg
109 pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
110 for {
111 ret, err = pc.c.ReadMsg()
112 if err != nil {
113 if ret != nil && (m.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) {
114 // For UDP, if the error is an overflow, we probably have an upstream
115 // misbehaving in some way.
116 // (e.g. sending >512 byte responses without an eDNS0 OPT RR).
117 // Instead of returning an error, return an empty response
118 // with TC bit set. This will make the client retry over TCP
119 // (if that's supported) or at least receive a clean error.
120 // The connection is still good so we break before the close.
121
122 // Truncate the response.
123 ret = truncateResponse(ret)
124 break
125 }
126
127 pc.c.Close() // not giving it back
128 if err == io.EOF && cached {
129 return nil, ErrCachedClosed
130 }
131 return ret, err
132 }
133 // drop out-of-order responses
134 if m.Id == ret.Id {
135 break
136 }
137 }
138
139 p.transport.Yield(pc)
140
141 return ret, nil
142}
143
// cumulativeAvgWeight is the weight updateDialTimeout passes to
// averageTimeout: each observed dial time moves the running average by one
// quarter of its distance from the current average.
const cumulativeAvgWeight = 4
145
146// Function to determine if a response should be truncated.
147func shouldTruncateResponse(err error) bool {
148 // This is to handle a scenario in which upstream sets the TC bit,
149 // but doesn't truncate the response and we get ErrBuf instead of overflow.
150 if errors.Is(err, dns.ErrBuf) {
151 return true
152 } else if strings.Contains(err.Error(), "overflow") {
153 return true
154 }
155 return false
156}
157
158// Function to return an empty response with TC (truncated) bit set.
159func truncateResponse(response *dns.Msg) *dns.Msg {
160 // Clear out Answer, Extra, and Ns sections
161 response.Answer = nil
162 response.Extra = nil
163 response.Ns = nil
164
165 // Set TC bit to indicate truncation.
166 response.Truncated = true
167 return response
168}