blob: 4630cdcc08af03ad3ecb63826e62c16f99833347 [file] [log] [blame]
Jan Schär75ea9f42024-07-29 17:01:41 +02001package proxy
2
3// Taken and modified from CoreDNS, under Apache 2.0.
4
5import (
6 "crypto/tls"
7 "sync/atomic"
8 "time"
9
10 "github.com/miekg/dns"
11)
12
13// HealthChecker checks the upstream health.
14type HealthChecker interface {
15 Check(*Proxy) error
16 SetTLSConfig(*tls.Config)
17 GetTLSConfig() *tls.Config
18 SetRecursionDesired(bool)
19 GetRecursionDesired() bool
20 SetDomain(domain string)
21 GetDomain() string
22 SetTCPTransport()
23 GetReadTimeout() time.Duration
24 SetReadTimeout(time.Duration)
25 GetWriteTimeout() time.Duration
26 SetWriteTimeout(time.Duration)
27}
28
29// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
30type dnsHc struct {
31 c *dns.Client
32 recursionDesired bool
33 domain string
34}
35
36// NewHealthChecker returns a new HealthChecker.
37func NewHealthChecker(recursionDesired bool, domain string) HealthChecker {
38 c := new(dns.Client)
39 c.Net = "udp"
40 c.ReadTimeout = 1 * time.Second
41 c.WriteTimeout = 1 * time.Second
42
43 return &dnsHc{
44 c: c,
45 recursionDesired: recursionDesired,
46 domain: domain,
47 }
48}
49
50func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
51 h.c.Net = "tcp-tls"
52 h.c.TLSConfig = cfg
53}
54
55func (h *dnsHc) GetTLSConfig() *tls.Config {
56 return h.c.TLSConfig
57}
58
59func (h *dnsHc) SetRecursionDesired(recursionDesired bool) {
60 h.recursionDesired = recursionDesired
61}
62func (h *dnsHc) GetRecursionDesired() bool {
63 return h.recursionDesired
64}
65
66func (h *dnsHc) SetDomain(domain string) {
67 h.domain = domain
68}
69func (h *dnsHc) GetDomain() string {
70 return h.domain
71}
72
73func (h *dnsHc) SetTCPTransport() {
74 h.c.Net = "tcp"
75}
76
77func (h *dnsHc) GetReadTimeout() time.Duration {
78 return h.c.ReadTimeout
79}
80
81func (h *dnsHc) SetReadTimeout(t time.Duration) {
82 h.c.ReadTimeout = t
83}
84
85func (h *dnsHc) GetWriteTimeout() time.Duration {
86 return h.c.WriteTimeout
87}
88
89func (h *dnsHc) SetWriteTimeout(t time.Duration) {
90 h.c.WriteTimeout = t
91}
92
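// For health checks, we send an NS query for the configured domain (for
// example the root, ". IN NS") to the upstream, with recursion desired set or
// cleared (+[no]rec). Dial timeouts and empty replies are considered
// failures; basically anything else constitutes a healthy upstream.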
96
97// Check is used as the up.Func in the up.Probe.
98func (h *dnsHc) Check(p *Proxy) error {
99 err := h.send(p.addr)
100 if err != nil {
101 healthcheckFailureCount.WithLabelValues(p.addr).Inc()
102 p.incrementFails()
103 return err
104 }
105
106 atomic.StoreUint32(&p.fails, 0)
107 return nil
108}
109
110func (h *dnsHc) send(addr string) error {
111 ping := new(dns.Msg)
112 ping.SetQuestion(h.domain, dns.TypeNS)
113 ping.MsgHdr.RecursionDesired = h.recursionDesired
114 ping.SetEdns0(AdvertiseUDPSize, false)
115
116 m, _, err := h.c.Exchange(ping, addr)
117 // If we got a header, we're alright,
118 // basically only care about I/O errors 'n stuff.
119 if err != nil && m != nil {
120 // Silly check, something sane came back.
121 if m.Response || m.Opcode == dns.OpcodeQuery {
122 err = nil
123 }
124 }
125
126 return err
127}