blob: 29fe848967e906970324e7ee2dd72dd29745c7c4 [file] [log] [blame]
Tim Windelschmidt6d33a432025-02-04 14:34:25 +01001// Copyright The Monogon Project Authors.
2// SPDX-License-Identifier: Apache-2.0
3
Jan Schär75ea9f42024-07-29 17:01:41 +02004package proxy
5
6// Taken and modified from CoreDNS, under Apache 2.0.
7
8import (
9 "crypto/tls"
10 "sync/atomic"
11 "time"
12
13 "github.com/miekg/dns"
14)
15
16// HealthChecker checks the upstream health.
17type HealthChecker interface {
18 Check(*Proxy) error
19 SetTLSConfig(*tls.Config)
20 GetTLSConfig() *tls.Config
21 SetRecursionDesired(bool)
22 GetRecursionDesired() bool
23 SetDomain(domain string)
24 GetDomain() string
25 SetTCPTransport()
26 GetReadTimeout() time.Duration
27 SetReadTimeout(time.Duration)
28 GetWriteTimeout() time.Duration
29 SetWriteTimeout(time.Duration)
30}
31
32// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
33type dnsHc struct {
34 c *dns.Client
35 recursionDesired bool
36 domain string
37}
38
39// NewHealthChecker returns a new HealthChecker.
40func NewHealthChecker(recursionDesired bool, domain string) HealthChecker {
41 c := new(dns.Client)
42 c.Net = "udp"
43 c.ReadTimeout = 1 * time.Second
44 c.WriteTimeout = 1 * time.Second
45
46 return &dnsHc{
47 c: c,
48 recursionDesired: recursionDesired,
49 domain: domain,
50 }
51}
52
53func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
54 h.c.Net = "tcp-tls"
55 h.c.TLSConfig = cfg
56}
57
58func (h *dnsHc) GetTLSConfig() *tls.Config {
59 return h.c.TLSConfig
60}
61
62func (h *dnsHc) SetRecursionDesired(recursionDesired bool) {
63 h.recursionDesired = recursionDesired
64}
65func (h *dnsHc) GetRecursionDesired() bool {
66 return h.recursionDesired
67}
68
69func (h *dnsHc) SetDomain(domain string) {
70 h.domain = domain
71}
72func (h *dnsHc) GetDomain() string {
73 return h.domain
74}
75
76func (h *dnsHc) SetTCPTransport() {
77 h.c.Net = "tcp"
78}
79
80func (h *dnsHc) GetReadTimeout() time.Duration {
81 return h.c.ReadTimeout
82}
83
84func (h *dnsHc) SetReadTimeout(t time.Duration) {
85 h.c.ReadTimeout = t
86}
87
88func (h *dnsHc) GetWriteTimeout() time.Duration {
89 return h.c.WriteTimeout
90}
91
92func (h *dnsHc) SetWriteTimeout(t time.Duration) {
93 h.c.WriteTimeout = t
94}
95
// For health checks, we send an NS query for the configured domain to the
// upstream, with the RD bit set or cleared according to recursionDesired.
// Dial timeouts and empty replies are considered failures; basically
// anything else constitutes a healthy upstream.
99
100// Check is used as the up.Func in the up.Probe.
101func (h *dnsHc) Check(p *Proxy) error {
102 err := h.send(p.addr)
103 if err != nil {
104 healthcheckFailureCount.WithLabelValues(p.addr).Inc()
105 p.incrementFails()
106 return err
107 }
108
109 atomic.StoreUint32(&p.fails, 0)
110 return nil
111}
112
113func (h *dnsHc) send(addr string) error {
114 ping := new(dns.Msg)
115 ping.SetQuestion(h.domain, dns.TypeNS)
116 ping.MsgHdr.RecursionDesired = h.recursionDesired
117 ping.SetEdns0(AdvertiseUDPSize, false)
118
119 m, _, err := h.c.Exchange(ping, addr)
120 // If we got a header, we're alright,
121 // basically only care about I/O errors 'n stuff.
122 if err != nil && m != nil {
123 // Silly check, something sane came back.
124 if m.Response || m.Opcode == dns.OpcodeQuery {
125 err = nil
126 }
127 }
128
129 return err
130}