blob: a11e8e6ca7b4757240325265c94246186eb1188e [file] [log] [blame]
// Copyright The Monogon Project Authors.
// SPDX-License-Identifier: Apache-2.0
3
// Package proxy implements a forwarding proxy. It caches an upstream net.Conn
// for some time, so if the same client returns, the upstream's Conn will
// already be cached. Depending on how you benchmark, this appears to be about
// 50% faster than opening a new connection for every client.
// It works with UDP and TCP and uses in-band health checking.
9package proxy
10
// Taken and modified from CoreDNS, under Apache 2.0.
12
13import (
14 "errors"
15 "io"
16 "strings"
17 "sync/atomic"
18 "time"
19
20 "github.com/miekg/dns"
21)
22
// AdvertiseUDPSize is the maximum message size that we advertise in the OPT RR
// of UDP messages. It is derived from the minimum IPv6 MTU (1280 bytes) by
// subtracting the IPv6 header (40 bytes) and the UDP header (8 bytes), which
// gives 1232.
const AdvertiseUDPSize = 1280 - 40 - 8
27
var (
	// ErrCachedClosed is returned when a reused (cached) connection turns out
	// to have been closed by the peer.
	ErrCachedClosed = errors.New("cached connection was closed by peer")
)
30
// limitTimeout derives the next timeout from the running average stored in
// currentAvg (a time.Duration held as an int64 and read atomically).
//
// The result is twice the current average, bounded by the given limits:
// an average below minValue yields minValue, an average of at least
// maxValue/2 yields maxValue, and anything in between yields 2*average.
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
	avg := time.Duration(atomic.LoadInt64(currentAvg))
	switch {
	case avg < minValue:
		return minValue
	case avg >= maxValue/2:
		// Doubling would reach or exceed the upper bound.
		return maxValue
	default:
		return 2 * avg
	}
}
45
// averageTimeout moves the running average stored in currentAvg (a
// time.Duration held as an int64) towards observedDuration by
// 1/weight of the difference between the two.
//
// The read-modify-write is performed with a CompareAndSwap loop so that
// concurrent callers cannot both compute their delta from the same stale
// average and then both apply it, which would overshoot the target.
// weight must be non-zero.
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
	for {
		old := atomic.LoadInt64(currentAvg)
		next := old + int64(observedDuration-time.Duration(old))/weight
		if atomic.CompareAndSwapInt64(currentAvg, old, next) {
			return
		}
		// Another caller updated the average in between; retry on the fresh value.
	}
}
50
// dialTimeout returns the timeout to use for the next dial, computed by
// limitTimeout from the running average of past dial durations in
// t.avgDialTime and bounded by minDialTimeout and maxDialTimeout.
func (t *Transport) dialTimeout() time.Duration {
	avg := &t.avgDialTime
	return limitTimeout(avg, minDialTimeout, maxDialTimeout)
}
54
55func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
56 averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
57}
58
59// Dial dials the address configured in transport,
60// potentially reusing a connection or creating a new one.
61func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
62 // If tls has been configured; use it.
63 if t.tlsConfig != nil {
64 proto = "tcp-tls"
65 }
66
67 t.dial <- proto
68 pc := <-t.ret
69
70 if pc != nil {
71 return pc, true, nil
72 }
73
74 reqTime := time.Now()
75 timeout := t.dialTimeout()
76 if proto == "tcp-tls" {
77 conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
78 t.updateDialTimeout(time.Since(reqTime))
79 return &persistConn{c: conn}, false, err
80 }
81 conn, err := dns.DialTimeout(proto, t.addr, timeout)
82 t.updateDialTimeout(time.Since(reqTime))
83 return &persistConn{c: conn}, false, err
84}
85
// Connect sends m to the upstream of p and waits for the matching response.
// The query goes over TCP if useTCP is set and over UDP otherwise; the
// transport's Dial may override this (it forces "tcp-tls" when TLS is
// configured).
//
// The query is sent with a fresh random message ID. m.Id is restored to its
// original value before Connect returns, so the caller's message is not left
// modified. Responses whose ID does not match the query are discarded until
// one matches or the read deadline expires.
//
// If a reused (cached) connection hits EOF, ErrCachedClosed is returned so
// the caller can distinguish a stale pooled connection. On other read errors,
// any partially read response is returned together with the error. UDP
// responses that overflow the buffer are turned into an empty response with
// the TC bit set instead of an error.
func (p *Proxy) Connect(m *dns.Msg, useTCP bool) (*dns.Msg, error) {
	proto := "udp"
	if useTCP {
		proto = "tcp"
	}

	pc, cached, err := p.transport.Dial(proto)
	if err != nil {
		return nil, err
	}

	// discard closes pc without handing it back to the transport and maps an
	// EOF on a reused connection to ErrCachedClosed.
	discard := func(cause error) error {
		pc.c.Close() // not giving it back
		if cached && errors.Is(cause, io.EOF) {
			return ErrCachedClosed
		}
		return cause
	}

	pc.c.UDPSize = AdvertiseUDPSize

	// Use a fresh random ID towards the upstream, but don't leave the caller's
	// message modified once we return.
	origID := m.Id
	m.Id = dns.Id()
	defer func() { m.Id = origID }()
	qid := m.Id

	pc.c.SetWriteDeadline(time.Now().Add(p.writeTimeout))
	if err := pc.c.WriteMsg(m); err != nil {
		return nil, discard(err)
	}

	var ret *dns.Msg
	pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
	for {
		ret, err = pc.c.ReadMsg()
		if err != nil {
			if ret != nil && ret.Id == qid && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) {
				// For UDP, if the error is an overflow, we probably have an upstream
				// misbehaving in some way
				// (e.g. sending >512 byte responses without an EDNS0 OPT RR).
				// Instead of returning an error, return an empty response
				// with the TC bit set. This will make the client retry over TCP
				// (if that's supported) or at least receive a clean error.
				// The connection is still good, so we break before the close.
				ret = truncateResponse(ret)
				break
			}

			if err = discard(err); err == ErrCachedClosed {
				return nil, err
			}
			return ret, err
		}
		// Drop responses that don't answer this query (out-of-order responses).
		if ret.Id == qid {
			break
		}
	}

	p.transport.Yield(pc)

	return ret, nil
}
146
const (
	// cumulativeAvgWeight is the weight passed to averageTimeout: each new
	// observation moves the running average by 1/cumulativeAvgWeight of its
	// distance from the current value. Kept untyped so it converts to int64.
	cumulativeAvgWeight = 4
)
148
// shouldTruncateResponse reports whether a read error indicates a response
// that did not fit the receive buffer, so that the caller can replace it with
// an empty truncated response instead of failing.
//
// This matches dns.ErrBuf, which also covers an upstream that sets the TC bit
// but doesn't actually truncate the response, and any error whose text
// mentions "overflow". A nil error never triggers truncation.
func shouldTruncateResponse(err error) bool {
	if err == nil {
		return false
	}
	return errors.Is(err, dns.ErrBuf) || strings.Contains(err.Error(), "overflow")
}
160
161// Function to return an empty response with TC (truncated) bit set.
162func truncateResponse(response *dns.Msg) *dns.Msg {
163 // Clear out Answer, Extra, and Ns sections
164 response.Answer = nil
165 response.Extra = nil
166 response.Ns = nil
167
168 // Set TC bit to indicate truncation.
169 response.Truncated = true
170 return response
171}