| Jan Schär | 75ea9f4 | 2024-07-29 17:01:41 +0200 | [diff] [blame^] | 1 | package proxy |
| 2 | |
| 3 | // Taken and modified from CoreDNS, under Apache 2.0. |
| 4 | |
| 5 | import ( |
| 6 | "crypto/tls" |
| 7 | "sync/atomic" |
| 8 | "time" |
| 9 | |
| 10 | "github.com/miekg/dns" |
| 11 | ) |
| 12 | |
| 13 | // HealthChecker checks the upstream health. |
| 14 | type HealthChecker interface { |
| 15 | Check(*Proxy) error |
| 16 | SetTLSConfig(*tls.Config) |
| 17 | GetTLSConfig() *tls.Config |
| 18 | SetRecursionDesired(bool) |
| 19 | GetRecursionDesired() bool |
| 20 | SetDomain(domain string) |
| 21 | GetDomain() string |
| 22 | SetTCPTransport() |
| 23 | GetReadTimeout() time.Duration |
| 24 | SetReadTimeout(time.Duration) |
| 25 | GetWriteTimeout() time.Duration |
| 26 | SetWriteTimeout(time.Duration) |
| 27 | } |
| 28 | |
| 29 | // dnsHc is a health checker for a DNS endpoint (DNS, and DoT). |
| 30 | type dnsHc struct { |
| 31 | c *dns.Client |
| 32 | recursionDesired bool |
| 33 | domain string |
| 34 | } |
| 35 | |
| 36 | // NewHealthChecker returns a new HealthChecker. |
| 37 | func NewHealthChecker(recursionDesired bool, domain string) HealthChecker { |
| 38 | c := new(dns.Client) |
| 39 | c.Net = "udp" |
| 40 | c.ReadTimeout = 1 * time.Second |
| 41 | c.WriteTimeout = 1 * time.Second |
| 42 | |
| 43 | return &dnsHc{ |
| 44 | c: c, |
| 45 | recursionDesired: recursionDesired, |
| 46 | domain: domain, |
| 47 | } |
| 48 | } |
| 49 | |
| 50 | func (h *dnsHc) SetTLSConfig(cfg *tls.Config) { |
| 51 | h.c.Net = "tcp-tls" |
| 52 | h.c.TLSConfig = cfg |
| 53 | } |
| 54 | |
| 55 | func (h *dnsHc) GetTLSConfig() *tls.Config { |
| 56 | return h.c.TLSConfig |
| 57 | } |
| 58 | |
| 59 | func (h *dnsHc) SetRecursionDesired(recursionDesired bool) { |
| 60 | h.recursionDesired = recursionDesired |
| 61 | } |
| 62 | func (h *dnsHc) GetRecursionDesired() bool { |
| 63 | return h.recursionDesired |
| 64 | } |
| 65 | |
| 66 | func (h *dnsHc) SetDomain(domain string) { |
| 67 | h.domain = domain |
| 68 | } |
| 69 | func (h *dnsHc) GetDomain() string { |
| 70 | return h.domain |
| 71 | } |
| 72 | |
| 73 | func (h *dnsHc) SetTCPTransport() { |
| 74 | h.c.Net = "tcp" |
| 75 | } |
| 76 | |
| 77 | func (h *dnsHc) GetReadTimeout() time.Duration { |
| 78 | return h.c.ReadTimeout |
| 79 | } |
| 80 | |
| 81 | func (h *dnsHc) SetReadTimeout(t time.Duration) { |
| 82 | h.c.ReadTimeout = t |
| 83 | } |
| 84 | |
| 85 | func (h *dnsHc) GetWriteTimeout() time.Duration { |
| 86 | return h.c.WriteTimeout |
| 87 | } |
| 88 | |
| 89 | func (h *dnsHc) SetWriteTimeout(t time.Duration) { |
| 90 | h.c.WriteTimeout = t |
| 91 | } |
| 92 | |
// For HC, we send an `IN NS +[no]rec` query for the configured domain to
// the upstream. Dial timeouts and empty replies are considered failures;
// basically anything else constitutes a healthy upstream.
| 96 | |
| 97 | // Check is used as the up.Func in the up.Probe. |
| 98 | func (h *dnsHc) Check(p *Proxy) error { |
| 99 | err := h.send(p.addr) |
| 100 | if err != nil { |
| 101 | healthcheckFailureCount.WithLabelValues(p.addr).Inc() |
| 102 | p.incrementFails() |
| 103 | return err |
| 104 | } |
| 105 | |
| 106 | atomic.StoreUint32(&p.fails, 0) |
| 107 | return nil |
| 108 | } |
| 109 | |
| 110 | func (h *dnsHc) send(addr string) error { |
| 111 | ping := new(dns.Msg) |
| 112 | ping.SetQuestion(h.domain, dns.TypeNS) |
| 113 | ping.MsgHdr.RecursionDesired = h.recursionDesired |
| 114 | ping.SetEdns0(AdvertiseUDPSize, false) |
| 115 | |
| 116 | m, _, err := h.c.Exchange(ping, addr) |
| 117 | // If we got a header, we're alright, |
| 118 | // basically only care about I/O errors 'n stuff. |
| 119 | if err != nil && m != nil { |
| 120 | // Silly check, something sane came back. |
| 121 | if m.Response || m.Opcode == dns.OpcodeQuery { |
| 122 | err = nil |
| 123 | } |
| 124 | } |
| 125 | |
| 126 | return err |
| 127 | } |