| Tim Windelschmidt | 6d33a43 | 2025-02-04 14:34:25 +0100 | [diff] [blame^] | 1 | // Copyright The Monogon Project Authors. |
| 2 | // SPDX-License-Identifier: Apache-2.0 |
| 3 | |
| Jan Schär | 75ea9f4 | 2024-07-29 17:01:41 +0200 | [diff] [blame] | 4 | package proxy |
| 5 | |
| 6 | // Taken and modified from CoreDNS, under Apache 2.0. |
| 7 | |
| 8 | import ( |
| 9 | "crypto/tls" |
| 10 | "sync/atomic" |
| 11 | "time" |
| 12 | |
| 13 | "github.com/miekg/dns" |
| 14 | ) |
| 15 | |
| 16 | // HealthChecker checks the upstream health. |
| 17 | type HealthChecker interface { |
| 18 | Check(*Proxy) error |
| 19 | SetTLSConfig(*tls.Config) |
| 20 | GetTLSConfig() *tls.Config |
| 21 | SetRecursionDesired(bool) |
| 22 | GetRecursionDesired() bool |
| 23 | SetDomain(domain string) |
| 24 | GetDomain() string |
| 25 | SetTCPTransport() |
| 26 | GetReadTimeout() time.Duration |
| 27 | SetReadTimeout(time.Duration) |
| 28 | GetWriteTimeout() time.Duration |
| 29 | SetWriteTimeout(time.Duration) |
| 30 | } |
| 31 | |
| 32 | // dnsHc is a health checker for a DNS endpoint (DNS, and DoT). |
| 33 | type dnsHc struct { |
| 34 | c *dns.Client |
| 35 | recursionDesired bool |
| 36 | domain string |
| 37 | } |
| 38 | |
| 39 | // NewHealthChecker returns a new HealthChecker. |
| 40 | func NewHealthChecker(recursionDesired bool, domain string) HealthChecker { |
| 41 | c := new(dns.Client) |
| 42 | c.Net = "udp" |
| 43 | c.ReadTimeout = 1 * time.Second |
| 44 | c.WriteTimeout = 1 * time.Second |
| 45 | |
| 46 | return &dnsHc{ |
| 47 | c: c, |
| 48 | recursionDesired: recursionDesired, |
| 49 | domain: domain, |
| 50 | } |
| 51 | } |
| 52 | |
| 53 | func (h *dnsHc) SetTLSConfig(cfg *tls.Config) { |
| 54 | h.c.Net = "tcp-tls" |
| 55 | h.c.TLSConfig = cfg |
| 56 | } |
| 57 | |
| 58 | func (h *dnsHc) GetTLSConfig() *tls.Config { |
| 59 | return h.c.TLSConfig |
| 60 | } |
| 61 | |
| 62 | func (h *dnsHc) SetRecursionDesired(recursionDesired bool) { |
| 63 | h.recursionDesired = recursionDesired |
| 64 | } |
| 65 | func (h *dnsHc) GetRecursionDesired() bool { |
| 66 | return h.recursionDesired |
| 67 | } |
| 68 | |
| 69 | func (h *dnsHc) SetDomain(domain string) { |
| 70 | h.domain = domain |
| 71 | } |
| 72 | func (h *dnsHc) GetDomain() string { |
| 73 | return h.domain |
| 74 | } |
| 75 | |
| 76 | func (h *dnsHc) SetTCPTransport() { |
| 77 | h.c.Net = "tcp" |
| 78 | } |
| 79 | |
| 80 | func (h *dnsHc) GetReadTimeout() time.Duration { |
| 81 | return h.c.ReadTimeout |
| 82 | } |
| 83 | |
| 84 | func (h *dnsHc) SetReadTimeout(t time.Duration) { |
| 85 | h.c.ReadTimeout = t |
| 86 | } |
| 87 | |
| 88 | func (h *dnsHc) GetWriteTimeout() time.Duration { |
| 89 | return h.c.WriteTimeout |
| 90 | } |
| 91 | |
| 92 | func (h *dnsHc) SetWriteTimeout(t time.Duration) { |
| 93 | h.c.WriteTimeout = t |
| 94 | } |
| 95 | |
// For health checks, we send a "<domain> IN NS +[no]rec" query to the
// upstream. Dial timeouts and missing replies are considered failures;
// basically anything else constitutes a healthy upstream.
| 99 | |
| 100 | // Check is used as the up.Func in the up.Probe. |
| 101 | func (h *dnsHc) Check(p *Proxy) error { |
| 102 | err := h.send(p.addr) |
| 103 | if err != nil { |
| 104 | healthcheckFailureCount.WithLabelValues(p.addr).Inc() |
| 105 | p.incrementFails() |
| 106 | return err |
| 107 | } |
| 108 | |
| 109 | atomic.StoreUint32(&p.fails, 0) |
| 110 | return nil |
| 111 | } |
| 112 | |
| 113 | func (h *dnsHc) send(addr string) error { |
| 114 | ping := new(dns.Msg) |
| 115 | ping.SetQuestion(h.domain, dns.TypeNS) |
| 116 | ping.MsgHdr.RecursionDesired = h.recursionDesired |
| 117 | ping.SetEdns0(AdvertiseUDPSize, false) |
| 118 | |
| 119 | m, _, err := h.c.Exchange(ping, addr) |
| 120 | // If we got a header, we're alright, |
| 121 | // basically only care about I/O errors 'n stuff. |
| 122 | if err != nil && m != nil { |
| 123 | // Silly check, something sane came back. |
| 124 | if m.Response || m.Opcode == dns.OpcodeQuery { |
| 125 | err = nil |
| 126 | } |
| 127 | } |
| 128 | |
| 129 | return err |
| 130 | } |