osbase/net/dns: add new DNS server
This adds a new DNS server service, which will replace CoreDNS. The
service has built-in handlers for certain names (localhost, loopback
reverse names and invalid.); all other names are handled by handlers
which can be configured at runtime.
Change-Id: I4184d11422496e899794ef658ca1450e7bb01471
Reviewed-on: https://review.monogon.dev/c/monogon/+/3126
Tested-by: Jenkins CI
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
diff --git a/osbase/net/dns/BUILD.bazel b/osbase/net/dns/BUILD.bazel
new file mode 100644
index 0000000..ac0e20e
--- /dev/null
+++ b/osbase/net/dns/BUILD.bazel
@@ -0,0 +1,37 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+# The DNS server library.
+# NOTE(review): testhelpers.go is part of the non-test library sources, not of
+# dns_test. Presumably this is so that tests in other packages can reuse its
+# helpers; confirm this is intended.
+go_library(
+ name = "dns",
+ srcs = [
+ "dns.go",
+ "metrics.go",
+ "name.go",
+ "testhelpers.go",
+ ],
+ importpath = "source.monogon.dev/osbase/net/dns",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//osbase/supervisor",
+ "@com_github_miekg_dns//:dns",
+ "@com_github_prometheus_client_golang//prometheus",
+ "@com_github_prometheus_client_golang//prometheus/promauto",
+ ],
+)
+
+# Make the resolv.conf and hosts files of this package available to other
+# packages.
+exports_files([
+ "resolv.conf",
+ "hosts",
+])
+
+# Unit tests, compiled into the dns package itself (embed) so that they can
+# access unexported identifiers.
+go_test(
+ name = "dns_test",
+ srcs = [
+ "dns_test.go",
+ "name_test.go",
+ ],
+ embed = [":dns"],
+ deps = [
+ "//osbase/net/dns/test",
+ "@com_github_miekg_dns//:dns",
+ ],
+)
diff --git a/osbase/net/dns/dns.go b/osbase/net/dns/dns.go
new file mode 100644
index 0000000..3207622
--- /dev/null
+++ b/osbase/net/dns/dns.go
@@ -0,0 +1,468 @@
+// Package dns provides a DNS server for resolving services against.
+package dns
+
+import (
+ "context"
+ "fmt"
+ "net/netip"
+ "runtime/debug"
+ "strconv"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/dns"
+
+ "source.monogon.dev/osbase/supervisor"
+)
+
+// Service is a DNS server service with configurable handlers.
+//
+// The number and names of handlers are fixed when New is called. For each name
+// in handlerNames there is a corresponding pointer to a handler in the handlers
+// slice at the same index, which can be atomically updated at runtime through
+// its atomic.Pointer via the SetHandler function.
+type Service struct {
+ // handlerNames are the names passed to New, in the order in which the
+ // handlers are consulted.
+ handlerNames []string
+ // handlers holds the current handler for each entry of handlerNames. A nil
+ // pointer means SetHandler has not been called for that name yet, in which
+ // case queries reaching it are answered with SERVFAIL.
+ handlers []atomic.Pointer[Handler]
+}
+
+// serviceCtx adapts a Service to the dns.Handler interface for one listener.
+// ctx is the supervisor context of that listener; ServeDNS uses it to log
+// recovered panics.
+type serviceCtx struct {
+ service *Service
+ ctx context.Context
+}
+
+// New creates a Service instance. DNS handlers with the names given in
+// handlerNames must be set with SetHandler. When serving DNS queries, they will
+// be tried in the order they appear here. Doing it this way instead of directly
+// passing a []Handler avoids circular dependencies.
+//
+// handlerNames is copied, so later modifications by the caller have no effect.
+// New panics if a name appears more than once: SetHandler could only ever set
+// the first occurrence, and the unset duplicate would fail every query that
+// reaches it with SERVFAIL.
+func New(handlerNames []string) *Service {
+	names := make([]string, len(handlerNames))
+	copy(names, handlerNames)
+
+	seen := make(map[string]bool, len(names))
+	for _, name := range names {
+		if seen[name] {
+			panic(fmt.Sprintf("Duplicate DNS handler name: %q", name))
+		}
+		seen[name] = true
+	}
+
+	return &Service{
+		handlerNames: names,
+		handlers:     make([]atomic.Pointer[Handler], len(names)),
+	}
+}
+
+// Run runs the DNS service on the loopback addresses 127.0.0.1 and ::1, port
+// 53, with one UDP and one TCP listener for each address. Every listener is
+// started as a supervised child runnable. Run itself signals healthy and done
+// as soon as the children have been started.
+func (s *Service) Run(ctx context.Context) error {
+	listeners := []struct {
+		name, addr, network string
+	}{
+		{"udp4", "127.0.0.1:53", "udp"},
+		{"tcp4", "127.0.0.1:53", "tcp"},
+		{"udp6", "[::1]:53", "udp"},
+		{"tcp6", "[::1]:53", "tcp"},
+	}
+	for _, l := range listeners {
+		l := l // Per-iteration copy for the closure (needed before Go 1.22).
+		supervisor.Run(ctx, l.name, func(ctx context.Context) error {
+			return s.runListener(ctx, l.addr, l.network)
+		})
+	}
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	supervisor.Signal(ctx, supervisor.SignalDone)
+	return nil
+}
+
+// RunListenerAddr runs a DNS listener on a specific address. A UDP and a TCP
+// listener are started as supervised children named after their network, after
+// which RunListenerAddr signals healthy and done.
+func (s *Service) RunListenerAddr(ctx context.Context, addr string) error {
+	for _, network := range []string{"udp", "tcp"} {
+		network := network // Per-iteration copy for the closure (needed before Go 1.22).
+		supervisor.Run(ctx, network, func(ctx context.Context) error {
+			return s.runListener(ctx, addr, network)
+		})
+	}
+	supervisor.Signal(ctx, supervisor.SignalHealthy)
+	supervisor.Signal(ctx, supervisor.SignalDone)
+	return nil
+}
+
+// runListener serves DNS queries for s on addr over network ("udp" or "tcp")
+// until ctx is canceled. It signals healthy once the server has started
+// listening. When ctx is canceled, the server is shut down and ctx.Err() is
+// returned. Other failures are returned wrapped with the listening address.
+func (s *Service) runListener(ctx context.Context, addr string, network string) error {
+	handler := &serviceCtx{service: s, ctx: ctx}
+	server := &dns.Server{Addr: addr, Net: network, ReusePort: true, Handler: handler}
+
+	// stopped is closed when ListenAndServe returns, so that the shutdown
+	// goroutine started below does not outlive a server which failed on its
+	// own.
+	stopped := make(chan struct{})
+	defer close(stopped)
+
+	server.NotifyStartedFunc = func() {
+		supervisor.Signal(ctx, supervisor.SignalHealthy)
+		supervisor.Logger(ctx).Infof("DNS server listening on %s %s", addr, network)
+		go func() {
+			select {
+			case <-ctx.Done():
+				// The error is ignored: Shutdown can only fail if the server
+				// is not running anymore.
+				server.Shutdown()
+			case <-stopped:
+			}
+		}()
+	}
+
+	err := server.ListenAndServe()
+	if ctxErr := ctx.Err(); ctxErr != nil {
+		// The server was stopped because this runnable is being torn down.
+		return ctxErr
+	}
+	if err != nil {
+		return fmt.Errorf("serving DNS on %s %s: %w", addr, network, err)
+	}
+	return nil
+}
+
+// Requests
+
+// Request represents an incoming DNS query that is being handled.
+type Request struct {
+ // Reply is the reply that will be sent, and should be filled in by the
+ // handler. It is guaranteed to contain exactly one question. It is built by
+ // reusing the query message, so Reply.Question holds the original question.
+ Reply *dns.Msg
+ // Writer will be used to send the reply, and contains network information.
+ Writer dns.ResponseWriter
+
+ // Qopt is the OPT record from the query, or nil if not present.
+ Qopt *dns.OPT
+ // Ropt, if non-nil, is the OPT record that will be added to the reply. The
+ // handler can modify this as needed. Ropt is nil when Qopt is nil. It starts
+ // out with the UDP size set to advertiseUDPSize and the DO bit copied from
+ // Qopt.
+ Ropt *dns.OPT
+
+ // Qname contains the current question name. This is different from the
+ // original question in Reply.Question[0].Name if a CNAME has been followed
+ // already.
+ Qname string
+
+ // QnameCanonical contains the canonicalized name of the question. This means
+ // that ASCII letters are lowercased.
+ QnameCanonical string
+
+ // Qtype contains the question type for convenient access.
+ Qtype uint16
+
+ // Handled is set to true when the current question name has been handled and
+ // no other handlers should be attempted. It is set by SendReply and by
+ // AddCNAME.
+ Handled bool
+
+ // done is set to true when a reply has been sent. When a CNAME is
+ // encountered, Handled is set to true, but done is false.
+ done bool
+}
+
+// SetAuthoritative marks the reply as authoritative, but only as long as the
+// question has not been redirected by a CNAME yet (RFC 1034 6.2.7). Otherwise
+// it does nothing.
+func (r *Request) SetAuthoritative() {
+	redirected := r.Qname != r.Reply.Question[0].Name
+	if redirected {
+		return
+	}
+	r.Reply.Authoritative = true
+}
+
+// SendReply sends the reply. It may only be called once.
+func (r *Request) SendReply() {
+ if r.Handled {
+ panic("SendReply called twice for the same DNS request")
+ }
+ r.Handled = true
+ r.done = true
+
+ if r.Ropt != nil {
+ r.Reply.Extra = append(r.Reply.Extra, r.Ropt)
+ } else {
+ // Cannot use extended RCODEs without an OPT, so replace with SERVFAIL.
+ if r.Reply.Rcode > 0xF {
+ r.Reply.Rcode = dns.RcodeServerFailure
+ }
+ }
+
+ size := uint16(0)
+ if r.Writer.RemoteAddr().Network() == "tcp" {
+ size = dns.MaxMsgSize
+ } else if r.Qopt != nil {
+ size = r.Qopt.UDPSize()
+ }
+ if size < dns.MinMsgSize {
+ size = dns.MinMsgSize
+ }
+ r.Reply.Truncate(int(size))
+ if !r.Reply.Compress && r.Reply.Len() >= 1024 {
+ r.Reply.Compress = true
+ }
+
+ r.Writer.WriteMsg(r.Reply)
+}
+
+// SendRcode stores rcode in the reply header and sends the reply. Like
+// SendReply, it may only be called once per request.
+func (r *Request) SendRcode(rcode int) {
+	hdr := &r.Reply.MsgHdr
+	hdr.Rcode = rcode
+	r.SendReply()
+}
+
+// AddExtendedError adds an Extended DNS Error Option (RFC 8914) to the reply.
+// If the query had no OPT, the reply has none either, and the error is
+// silently dropped.
+func (r *Request) AddExtendedError(infoCode uint16, extraText string) {
+	if r.Ropt == nil {
+		return
+	}
+	ede := &dns.EDNS0_EDE{InfoCode: infoCode, ExtraText: extraText}
+	r.Ropt.Option = append(r.Ropt.Option, ede)
+}
+
+// AddCNAME appends a CNAME record from the current question name to target
+// (which must be fully qualified) to the answer section. For CNAME and ANY
+// queries the CNAME itself is the answer, so the reply is sent right away.
+// For every other query type, the request is marked as handled and the
+// lookup is restarted at target.
+func (r *Request) AddCNAME(target string, ttl uint32) {
+	r.Reply.Answer = append(r.Reply.Answer, &dns.CNAME{
+		Hdr:    dns.RR_Header{Name: r.Qname, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: ttl},
+		Target: target,
+	})
+
+	switch r.Qtype {
+	case dns.TypeCNAME, dns.TypeANY:
+		r.SendReply()
+	default:
+		r.Handled = true
+		r.Qname = target
+		r.QnameCanonical = dns.CanonicalName(target)
+	}
+}
+
+// Handlers
+
+// Handler can handle DNS requests. The handler should first inspect the query
+// (for example r.QnameCanonical and r.Qtype) and decide if it wants to handle
+// it. If not, it should return immediately without setting r.Handled. The
+// next handler will then be tried. Otherwise, it should fill in the Reply,
+// and then call SendReply (or SendRcode). Alternatively, it can redirect the
+// query to another name with AddCNAME. The Answer section may already contain
+// CNAMEs that have been followed.
+type Handler interface {
+ HandleDNS(r *Request)
+}
+
+// SetHandler installs h as the handler registered under name. The name must
+// have been passed to New, otherwise SetHandler panics. Until SetHandler has
+// been called for a name, queries that reach that handler (because no earlier
+// handler took them) are answered with SERVFAIL. Calling SetHandler again for
+// the same name replaces the previous handler.
+func (s *Service) SetHandler(name string, h Handler) {
+	idx := -1
+	for i := range s.handlerNames {
+		if s.handlerNames[i] == name {
+			idx = i
+			break
+		}
+	}
+	if idx < 0 {
+		panic(fmt.Sprintf("Attempted to set undeclared DNS handler: %q", name))
+	}
+	s.handlers[idx].Store(&h)
+}
+
+// EmptyDNSHandler is a handler that does not handle any queries. It can be used
+// as a placeholder with SetHandler when a handler is inactive.
+type EmptyDNSHandler struct{}
+
+// HandleDNS does nothing, so every query is passed on to the next handler.
+func (EmptyDNSHandler) HandleDNS(*Request) {}
+
+// Serving requests
+
+// advertiseUDPSize is the maximum message size that we advertise in the OPT RR
+// of UDP messages. It is the minimum IPv6 MTU (1280) minus the sizes of the
+// IPv6 (40) and UDP (8) headers.
+const advertiseUDPSize = 1232
+
+// ServeDNS implements dns.Handler.
+//
+// Queries that cannot be processed are answered directly:
+//   - a non-QUERY opcode gets NOTIMP,
+//   - a truncated query, a query with several OPT RRs or an OPT not owned by
+//     the root gets FORMERR, and an unsupported EDNS version gets BADVERS,
+//   - a query without exactly one IN question, or with answer or authority
+//     records, is refused,
+//   - an OPT question gets FORMERR, and zone transfers are refused.
+//
+// All other queries are passed to Service.HandleDNS. Whenever a handler
+// redirects the query through a CNAME, resolution restarts at the new name.
+// Loops and overly long chains are answered with SERVFAIL. If a handler
+// panics before a reply was sent, the panic is logged and SERVFAIL is sent.
+func (s *serviceCtx) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+	// req is set once the query has been validated. The panic handler uses it
+	// to tell whether the client is still owed a reply.
+	var req *Request
+	defer func() {
+		rec := recover()
+		if rec == nil {
+			return
+		}
+		supervisor.Logger(s.ctx).Errorf("panic in DNS handler: %v, stacktrace: %s", rec, string(debug.Stack()))
+		if req != nil && !req.done {
+			// Answer with SERVFAIL instead of letting the client time out and
+			// retry. The reply is built from scratch because req.Reply may
+			// contain a partially built answer.
+			m := new(dns.Msg)
+			m.RecursionAvailable = true
+			m.SetRcode(req.Reply, dns.RcodeServerFailure)
+			w.WriteMsg(m)
+		}
+	}()
+
+	// Only QUERY opcode is implemented.
+	if r.Opcode != dns.OpcodeQuery {
+		m := new(dns.Msg)
+		m.SetRcode(r, dns.RcodeNotImplemented)
+		w.WriteMsg(m)
+		return
+	}
+
+	if r.Truncated {
+		m := new(dns.Msg)
+		m.RecursionAvailable = true
+		m.SetRcode(r, dns.RcodeFormatError)
+		w.WriteMsg(m)
+		return
+	}
+
+	// Look for an OPT RR.
+	var opt *dns.OPT
+	for _, rr := range r.Extra {
+		if rr, ok := rr.(*dns.OPT); ok {
+			if opt != nil {
+				// RFC 6891 6.1.1
+				// If a query has more than one OPT RR, FORMERR MUST be returned.
+				m := new(dns.Msg)
+				m.RecursionAvailable = true
+				m.SetRcode(r, dns.RcodeFormatError)
+				w.WriteMsg(m)
+				return
+			}
+			opt = rr
+		}
+	}
+
+	// RFC 6891 6.1.3
+	// If the VERSION of the query is not implemented, BADVERS MUST be returned.
+	if opt != nil && opt.Version() != 0 {
+		m := new(dns.Msg)
+		m.RecursionAvailable = true
+		m.SetRcode(r, dns.RcodeBadVers)
+		// BADVERS is an extended RCODE, which can only be sent with an OPT.
+		m.SetEdns0(advertiseUDPSize, false)
+		w.WriteMsg(m)
+		return
+	}
+
+	// If the OPT name is not the root, the OPT is invalid.
+	if opt != nil && opt.Hdr.Name != "." {
+		m := new(dns.Msg)
+		m.RecursionAvailable = true
+		m.SetRcode(r, dns.RcodeFormatError)
+		w.WriteMsg(m)
+		return
+	}
+
+	// Reuse the query message as the reply message.
+	r.Response = true
+	r.Authoritative = false
+	r.RecursionAvailable = true
+	r.Zero = false
+	r.AuthenticatedData = false
+	r.Rcode = dns.RcodeSuccess
+	r.Extra = nil
+
+	req = &Request{
+		Reply:  r,
+		Writer: w,
+		Qopt:   opt,
+	}
+	if opt != nil {
+		req.Ropt = new(dns.OPT)
+		req.Ropt.Hdr.Name = "."
+		req.Ropt.Hdr.Rrtype = dns.TypeOPT
+		req.Ropt.SetUDPSize(advertiseUDPSize)
+		if opt.Do() {
+			req.Ropt.SetDo()
+		}
+	}
+
+	// Refuse queries that don't have exactly one question of class INET, or that
+	// have non-empty answer or authority sections.
+	if len(r.Question) != 1 || r.Question[0].Qclass != dns.ClassINET || len(r.Answer) != 0 || len(r.Ns) != 0 {
+		r.Answer = nil
+		r.Ns = nil
+		req.SendRcode(dns.RcodeRefused)
+		return
+	}
+	req.Qtype = r.Question[0].Qtype
+	req.Qname = r.Question[0].Name
+	req.QnameCanonical = dns.CanonicalName(req.Qname)
+
+	switch req.Qtype {
+	case dns.TypeOPT:
+		// OPT is a pseudo-RR and may only appear in the additional section.
+		req.SendRcode(dns.RcodeFormatError)
+		return
+	case dns.TypeAXFR, dns.TypeIXFR:
+		// Zone transfer is not supported.
+		req.AddExtendedError(dns.ExtendedErrorCodeNotSupported, "")
+		req.SendRcode(dns.RcodeRefused)
+		return
+	}
+
+	// If we encounter a CNAME, DNS resolution must be restarted with the new
+	// name. That's what this loop is for. A name seen twice is a loop; more than
+	// maxRedirects redirects are not followed either.
+	const maxRedirects = 8
+	redirects := 0
+	seen := make(map[string]bool)
+	for {
+		prevName := req.QnameCanonical
+		s.service.HandleDNS(req)
+		if req.done {
+			break
+		}
+		req.Handled = false
+		redirects++
+		seen[prevName] = true
+		if seen[req.QnameCanonical] {
+			req.AddExtendedError(dns.ExtendedErrorCodeOther, "CNAME loop")
+			req.SendRcode(dns.RcodeServerFailure)
+			break
+		}
+		if redirects >= maxRedirects {
+			req.AddExtendedError(dns.ExtendedErrorCodeOther, "too many CNAME redirects")
+			req.SendRcode(dns.RcodeServerFailure)
+			break
+		}
+	}
+}
+
+func (s *Service) HandleDNS(r *Request) {
+ start := time.Now()
+
+ // Handle localhost.
+ if IsSubDomain("localhost.", r.QnameCanonical) {
+ handleLocalhost(r)
+ return
+ }
+ if IsSubDomain("127.in-addr.arpa.", r.QnameCanonical) ||
+ IsSubDomain("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", r.QnameCanonical) {
+ handleLocalhostPtr(r)
+ return
+ }
+
+ // Serve NXDOMAIN for the "invalid." domain, see RFC 6761.
+ if IsSubDomain("invalid.", r.QnameCanonical) {
+ r.SetAuthoritative()
+ r.SendRcode(dns.RcodeNameError)
+ return
+ }
+
+ for i := range s.handlers {
+ handler := s.handlers[i].Load()
+ if handler == nil {
+ // The application is still starting up. Fail queries instead of leaking
+ // local queries to the Internet and sending the wrong reply.
+ r.AddExtendedError(dns.ExtendedErrorCodeNotReady, fmt.Sprintf("%s handler not ready", s.handlerNames[i]))
+ r.SendRcode(dns.RcodeServerFailure)
+ handlerDuration.WithLabelValues(s.handlerNames[i], "not_ready").Observe(time.Since(start).Seconds())
+ return
+ }
+ (*handler).HandleDNS(r)
+ if r.Handled {
+ rcode, ok := dns.RcodeToString[r.Reply.Rcode]
+ if !ok {
+ // There are 4096 possible Rcodes, so it's probably still fine to put it
+ // in a metric label.
+ rcode = strconv.Itoa(r.Reply.Rcode)
+ }
+ if !r.done {
+ rcode = "redirected"
+ }
+ handlerDuration.WithLabelValues(s.handlerNames[i], rcode).Observe(time.Since(start).Seconds())
+ return
+ }
+ }
+
+ // No handler can handle this request.
+ r.SendRcode(dns.RcodeRefused)
+}
+
+// Addresses served for localhost names: a 4-byte IPv4 and a 16-byte IPv6
+// address, as returned by netip.Addr.AsSlice.
+var (
+ localhostA = netip.MustParseAddr("127.0.0.1").AsSlice()
+ localhostAAAA = netip.MustParseAddr("::1").AsSlice()
+)
+
+// localhostTtl is the TTL in seconds (5 minutes) of all records served by the
+// built-in localhost handlers.
+const localhostTtl = 60 * 5
+
+// handleLocalhost answers queries for localhost. and all names below it
+// authoritatively. A queries get 127.0.0.1, AAAA queries get ::1, and ANY
+// queries get both. All other types get an empty NOERROR reply.
+func handleLocalhost(r *Request) {
+	r.SetAuthoritative()
+	wantA := r.Qtype == dns.TypeA || r.Qtype == dns.TypeANY
+	wantAAAA := r.Qtype == dns.TypeAAAA || r.Qtype == dns.TypeANY
+	if wantA {
+		r.Reply.Answer = append(r.Reply.Answer, &dns.A{
+			Hdr: dns.RR_Header{Name: r.Qname, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: localhostTtl},
+			A:   localhostA,
+		})
+	}
+	if wantAAAA {
+		r.Reply.Answer = append(r.Reply.Answer, &dns.AAAA{
+			Hdr:  dns.RR_Header{Name: r.Qname, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: localhostTtl},
+			AAAA: localhostAAAA,
+		})
+	}
+	r.SendReply()
+}
+
+// handleLocalhostPtr answers reverse lookups in the loopback reverse zones
+// authoritatively. A complete reverse name resolves to localhost. for PTR and
+// ANY queries. A partial reverse name exists but has no records, and a name
+// with extra labels in front of an address does not exist (NXDOMAIN).
+func handleLocalhostPtr(r *Request) {
+	r.SetAuthoritative()
+	addr, bits, extra := ParseReverse(r.QnameCanonical)
+	switch {
+	case extra:
+		// For example foo.1.0.0.127.in-addr.arpa.
+		r.Reply.Rcode = dns.RcodeNameError
+	case bits != addr.BitLen():
+		// For example 127.in-addr.arpa.: exists, but has no records.
+	case r.Qtype == dns.TypePTR || r.Qtype == dns.TypeANY:
+		r.Reply.Answer = append(r.Reply.Answer, &dns.PTR{
+			Hdr: dns.RR_Header{Name: r.Qname, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: localhostTtl},
+			Ptr: "localhost.",
+		})
+	}
+	r.SendReply()
+}
diff --git a/osbase/net/dns/dns_test.go b/osbase/net/dns/dns_test.go
new file mode 100644
index 0000000..7dd670c
--- /dev/null
+++ b/osbase/net/dns/dns_test.go
@@ -0,0 +1,445 @@
+package dns
+
+import (
+ "net"
+ "testing"
+
+ "github.com/miekg/dns"
+
+ "source.monogon.dev/osbase/net/dns/test"
+)
+
+// testQuery is like testMsg, but additionally expects the RA (recursion
+// available) bit in the reply, which the service sets when it processes the
+// query normally.
+func testQuery(t *testing.T, service *Service, query, wantReply *dns.Msg) {
+	t.Helper()
+	wantReply.RecursionAvailable = true
+	testMsg(t, service, query, wantReply)
+}
+
+// testMsg serves query with service through a testWriter with a UDP remote
+// address, and checks that the reply written equals wantReply (marked as a
+// response) in its textual representation.
+func testMsg(t *testing.T, service *Service, query, wantReply *dns.Msg) {
+	t.Helper()
+	sCtx := &serviceCtx{service: service}
+	wantReply.Response = true
+	writer := &testWriter{addr: &net.UDPAddr{}}
+	sCtx.ServeDNS(writer, query)
+	if got, want := writer.msg.String(), wantReply.String(); got != want {
+		t.Errorf("Want reply:\n%s\nGot:\n%s", want, got)
+	}
+}
+
+// TestBuiltinHandlers checks the names which the service answers by itself,
+// without any configured handlers: localhost., the loopback reverse names and
+// invalid. Each case runs as a subtest named after its question, so failures
+// are attributed to the right case.
+func TestBuiltinHandlers(t *testing.T) {
+	service := New(nil)
+
+	cases := []struct {
+		name   string
+		qtype  uint16
+		rcode  int
+		answer []dns.RR
+	}{
+		{
+			name:   "localhost.",
+			qtype:  dns.TypeA,
+			answer: []dns.RR{test.RR("localhost. 300 IN A 127.0.0.1")},
+		},
+		{
+			name:   "foo.bar.localhost.",
+			qtype:  dns.TypeA,
+			answer: []dns.RR{test.RR("foo.bar.localhost. 300 IN A 127.0.0.1")},
+		},
+		{
+			name:   "localhost.",
+			qtype:  dns.TypeAAAA,
+			answer: []dns.RR{test.RR("localhost. 300 IN AAAA ::1")},
+		},
+		{
+			name:  "localhost.",
+			qtype: dns.TypeANY,
+			answer: []dns.RR{
+				test.RR("localhost. 300 IN A 127.0.0.1"),
+				test.RR("localhost. 300 IN AAAA ::1"),
+			},
+		},
+		{
+			name:  "localhost.",
+			qtype: dns.TypeMX,
+		},
+		{
+			name:   "1.0.0.127.in-addr.arpa.",
+			qtype:  dns.TypePTR,
+			answer: []dns.RR{test.RR("1.0.0.127.in-addr.arpa. 300 IN PTR localhost.")},
+		},
+		{
+			name:  "1.0.0.127.in-addr.arpa.",
+			qtype: dns.TypeNS,
+		},
+		{
+			name:  "2.127.in-addr.arpa.",
+			qtype: dns.TypePTR,
+		},
+		{
+			name:  "foo.127.in-addr.arpa.",
+			qtype: dns.TypePTR,
+			rcode: dns.RcodeNameError,
+		},
+		{
+			name:   "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
+			qtype:  dns.TypePTR,
+			answer: []dns.RR{test.RR("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa. 300 IN PTR localhost.")},
+		},
+		{
+			name:  "foo.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
+			qtype: dns.TypePTR,
+			rcode: dns.RcodeNameError,
+		},
+		{
+			name:  "invalid.",
+			qtype: dns.TypeA,
+			rcode: dns.RcodeNameError,
+		},
+		{
+			name:  "foo.bar.invalid.",
+			qtype: dns.TypeA,
+			rcode: dns.RcodeNameError,
+		},
+	}
+
+	for _, c := range cases {
+		t.Run(c.name+" "+dns.TypeToString[c.qtype], func(t *testing.T) {
+			query := new(dns.Msg)
+			query.SetQuestion(c.name, c.qtype)
+			query.RecursionDesired = false
+			wantReply := query.Copy()
+			wantReply.Authoritative = true
+			wantReply.Rcode = c.rcode
+			wantReply.Answer = c.answer
+			testQuery(t, service, query, wantReply)
+		})
+	}
+}
+
+// handlerFunc adapts an ordinary function to the Handler interface, so tests
+// can define handlers inline.
+type handlerFunc func(*Request)
+
+// HandleDNS calls f(r).
+func (f handlerFunc) HandleDNS(r *Request) {
+ f(r)
+}
+
+// TestCustomHandlers checks that configured handlers are consulted in
+// declaration order, that unset handlers cause SERVFAIL, that handlers can be
+// replaced at runtime, and that names taken by no handler are refused.
+func TestCustomHandlers(t *testing.T) {
+	// aHandler returns a handler which answers example.com. and all names
+	// below it authoritatively, with an A record pointing to ip.
+	aHandler := func(ip net.IP) Handler {
+		return handlerFunc(func(r *Request) {
+			if !IsSubDomain("example.com.", r.QnameCanonical) {
+				return
+			}
+			r.SetAuthoritative()
+			if r.Qtype == dns.TypeA || r.Qtype == dns.TypeANY {
+				rr := new(dns.A)
+				rr.Hdr = dns.RR_Header{Name: r.Qname, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}
+				rr.A = ip
+				r.Reply.Answer = append(r.Reply.Answer, rr)
+			}
+			r.SendReply()
+		})
+	}
+	// newQuery builds a fresh A query for name.
+	newQuery := func(name string) *dns.Msg {
+		q := new(dns.Msg)
+		q.SetQuestion(name, dns.TypeA)
+		return q
+	}
+
+	service := New([]string{"handler1", "handler2"})
+	service.SetHandler("handler2", aHandler(net.IP{1, 2, 3, 4}))
+
+	// handler1 has not been set yet, so the query fails.
+	query := newQuery("example.com.")
+	wantReply := query.Copy()
+	wantReply.Rcode = dns.RcodeServerFailure
+	testQuery(t, service, query, wantReply)
+
+	// With an empty handler1, the answer comes from handler2.
+	service.SetHandler("handler1", EmptyDNSHandler{})
+	query = newQuery("example.com.")
+	wantReply = query.Copy()
+	wantReply.Authoritative = true
+	wantReply.Answer = []dns.RR{test.RR("example.com. 300 IN A 1.2.3.4")}
+	testQuery(t, service, query, wantReply)
+
+	// Replacing handler1 takes effect immediately, and as handler1 is declared
+	// first, it now wins over handler2.
+	service.SetHandler("handler1", aHandler(net.IP{5, 6, 7, 8}))
+	query = newQuery("example.com.")
+	wantReply = query.Copy()
+	wantReply.Authoritative = true
+	wantReply.Answer = []dns.RR{test.RR("example.com. 300 IN A 5.6.7.8")}
+	testQuery(t, service, query, wantReply)
+
+	// Names which are not handled by any handler get refused.
+	query = newQuery("example.net.")
+	wantReply = query.Copy()
+	wantReply.Rcode = dns.RcodeRefused
+	testQuery(t, service, query, wantReply)
+}
+
+// TestRedirect checks CNAME handling: chains across handlers are followed,
+// CNAME and ANY queries stop at the first CNAME, and loops as well as overly
+// long chains end in SERVFAIL.
+func TestRedirect(t *testing.T) {
+	service := New([]string{"handler1", "handler2"})
+
+	// handler1 answers example.net. and all names below it with an A record.
+	service.SetHandler("handler1", handlerFunc(func(r *Request) {
+		if !IsSubDomain("example.net.", r.QnameCanonical) {
+			return
+		}
+		r.SetAuthoritative()
+		if r.Qtype == dns.TypeA || r.Qtype == dns.TypeANY {
+			rr := new(dns.A)
+			rr.Hdr = dns.RR_Header{Name: r.Qname, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}
+			rr.A = net.IP{1, 2, 3, 4}
+			r.Reply.Answer = append(r.Reply.Answer, rr)
+		}
+		r.SendReply()
+	}))
+
+	// handler2 serves example.com. from this table of CNAMEs. Names below
+	// example.com. which are not in the table do not exist.
+	cnames := map[string]string{
+		// A chain ending in example.net., which handler1 answers.
+		"1.example.com.": "2.example.com.",
+		"2.example.com.": "example.net.",
+		// A CNAME pointing to itself.
+		"loop.example.com.": "loop.example.com.",
+		// A loop of three CNAMEs.
+		"loop1.example.com.": "loop2.example.com.",
+		"loop2.example.com.": "loop3.example.com.",
+		"loop3.example.com.": "loop1.example.com.",
+		// A chain longer than the redirect limit.
+		"chain1.example.com.": "chain2.example.com.",
+		"chain2.example.com.": "chain3.example.com.",
+		"chain3.example.com.": "chain4.example.com.",
+		"chain4.example.com.": "chain5.example.com.",
+		"chain5.example.com.": "chain6.example.com.",
+		"chain6.example.com.": "chain7.example.com.",
+		"chain7.example.com.": "chain8.example.com.",
+		"chain8.example.com.": "chain9.example.com.",
+		"chain9.example.com.": "chain10.example.com.",
+	}
+	service.SetHandler("handler2", handlerFunc(func(r *Request) {
+		if !IsSubDomain("example.com.", r.QnameCanonical) {
+			return
+		}
+		if target, ok := cnames[r.QnameCanonical]; ok {
+			r.AddCNAME(target, 30)
+		} else {
+			r.SendRcode(dns.RcodeNameError)
+		}
+	}))
+
+	// CNAME redirects are followed.
+	query := new(dns.Msg)
+	query.SetQuestion("1.example.com.", dns.TypeA)
+	wantReply := query.Copy()
+	wantReply.Answer = []dns.RR{
+		test.RR("1.example.com. 30 IN CNAME 2.example.com."),
+		test.RR("2.example.com. 30 IN CNAME example.net."),
+		test.RR("example.net. 300 IN A 1.2.3.4"),
+	}
+	testQuery(t, service, query, wantReply)
+
+	// Queries of type CNAME or ANY do not follow the redirect.
+	for _, qtype := range []uint16{dns.TypeCNAME, dns.TypeANY} {
+		query = new(dns.Msg)
+		query.SetQuestion("1.example.com.", qtype)
+		wantReply = query.Copy()
+		wantReply.Answer = []dns.RR{test.RR("1.example.com. 30 IN CNAME 2.example.com.")}
+		testQuery(t, service, query, wantReply)
+	}
+
+	// A CNAME pointing to itself is detected as a loop.
+	query = new(dns.Msg)
+	query.SetQuestion("loop.example.com.", dns.TypeA)
+	wantReply = query.Copy()
+	wantReply.Answer = []dns.RR{test.RR("loop.example.com. 30 IN CNAME loop.example.com.")}
+	wantReply.Rcode = dns.RcodeServerFailure
+	testQuery(t, service, query, wantReply)
+
+	// Longer loops are detected too.
+	query = new(dns.Msg)
+	query.SetQuestion("loop1.example.com.", dns.TypeA)
+	wantReply = query.Copy()
+	wantReply.Answer = []dns.RR{
+		test.RR("loop1.example.com. 30 IN CNAME loop2.example.com."),
+		test.RR("loop2.example.com. 30 IN CNAME loop3.example.com."),
+		test.RR("loop3.example.com. 30 IN CNAME loop1.example.com."),
+	}
+	wantReply.Rcode = dns.RcodeServerFailure
+	testQuery(t, service, query, wantReply)
+
+	// Number of redirects is limited.
+	query = new(dns.Msg)
+	query.SetQuestion("chain1.example.com.", dns.TypeA)
+	wantReply = query.Copy()
+	wantReply.Answer = []dns.RR{
+		test.RR("chain1.example.com. 30 IN CNAME chain2.example.com."),
+		test.RR("chain2.example.com. 30 IN CNAME chain3.example.com."),
+		test.RR("chain3.example.com. 30 IN CNAME chain4.example.com."),
+		test.RR("chain4.example.com. 30 IN CNAME chain5.example.com."),
+		test.RR("chain5.example.com. 30 IN CNAME chain6.example.com."),
+		test.RR("chain6.example.com. 30 IN CNAME chain7.example.com."),
+		test.RR("chain7.example.com. 30 IN CNAME chain8.example.com."),
+		test.RR("chain8.example.com. 30 IN CNAME chain9.example.com."),
+	}
+	wantReply.Rcode = dns.RcodeServerFailure
+	testQuery(t, service, query, wantReply)
+}
+
+// TestFlags checks which header flags of the query are carried over into the
+// reply (RD, CD) and which are reset (AA, RA, Z, AD, RCODE) before the
+// built-in localhost handler fills in the reply.
+func TestFlags(t *testing.T) {
+ service := New(nil)
+
+ query := new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeA)
+
+ // Set flags which should be copied to the reply.
+ query.RecursionDesired = true
+ query.CheckingDisabled = true
+
+ // The expected reply is copied before the ignored flags are set below, so
+ // it does not contain them.
+ wantReply := query.Copy()
+ wantReply.Authoritative = true
+ wantReply.Answer = []dns.RR{test.RR("localhost. 300 IN A 127.0.0.1")}
+
+ // Set flags which should be ignored.
+ query.Authoritative = true
+ query.RecursionAvailable = true
+ query.Zero = true
+ query.AuthenticatedData = true
+ query.Rcode = dns.RcodeRefused
+
+ testQuery(t, service, query, wantReply)
+}
+
+// TestOPT checks that a query with an OPT gets an OPT in the reply, which
+// advertises our own UDP size and mirrors the DNSSEC OK (DO) bit.
+func TestOPT(t *testing.T) {
+	service := New(nil)
+
+	for _, dnssecOK := range []bool{false, true} {
+		query := new(dns.Msg)
+		query.SetQuestion("localhost.", dns.TypeA)
+		wantReply := query.Copy()
+		wantReply.Authoritative = true
+		wantReply.Answer = []dns.RR{test.RR("localhost. 300 IN A 127.0.0.1")}
+		wantReply.SetEdns0(advertiseUDPSize, dnssecOK)
+		query.SetEdns0(512, dnssecOK)
+		testQuery(t, service, query, wantReply)
+	}
+}
+
+// TestInvalidQuery checks that malformed or unsupported queries are answered
+// with the appropriate error RCODE. A valid query is included first as a
+// baseline. In most cases the query is modified only after the expected reply
+// was copied from it, so the modification is not part of the expected reply.
+func TestInvalidQuery(t *testing.T) {
+ service := New(nil)
+
+ // Valid query.
+ query := new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeA)
+ wantReply := query.Copy()
+ wantReply.Authoritative = true
+ wantReply.Answer = []dns.RR{test.RR("localhost. 300 IN A 127.0.0.1")}
+ testQuery(t, service, query, wantReply)
+
+ // Not query opcode. The NOTIMP reply has neither RD nor RA set, so testMsg
+ // is used instead of testQuery.
+ query = new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeA)
+ query.Opcode = dns.OpcodeNotify
+ wantReply = query.Copy()
+ wantReply.RecursionDesired = false
+ wantReply.Rcode = dns.RcodeNotImplemented
+ testMsg(t, service, query, wantReply)
+
+ // Truncated.
+ query = new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeA)
+ wantReply = query.Copy()
+ wantReply.Rcode = dns.RcodeFormatError
+ query.Truncated = true
+ testQuery(t, service, query, wantReply)
+
+ // Multiple OPTs.
+ query = new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeA)
+ wantReply = query.Copy()
+ wantReply.Rcode = dns.RcodeFormatError
+ query.SetEdns0(512, false)
+ query.SetEdns0(512, false)
+ testQuery(t, service, query, wantReply)
+
+ // Unknown OPT version. The BADVERS reply carries an OPT, which is needed to
+ // encode the extended RCODE.
+ query = new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeA)
+ wantReply = query.Copy()
+ wantReply.Rcode = dns.RcodeBadVers
+ wantReply.SetEdns0(advertiseUDPSize, false)
+ query.SetEdns0(512, false)
+ query.Extra[0].(*dns.OPT).SetVersion(1)
+ testQuery(t, service, query, wantReply)
+
+ // Invalid OPT name (must be the root).
+ query = new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeA)
+ wantReply = query.Copy()
+ wantReply.Rcode = dns.RcodeFormatError
+ query.SetEdns0(512, false)
+ query.Extra[0].(*dns.OPT).Hdr.Name = "localhost."
+ testQuery(t, service, query, wantReply)
+
+ // No question.
+ query = new(dns.Msg)
+ query.Id = dns.Id()
+ wantReply = query.Copy()
+ wantReply.Rcode = dns.RcodeRefused
+ testQuery(t, service, query, wantReply)
+
+ // Multiple questions.
+ query = new(dns.Msg)
+ query.Id = dns.Id()
+ query.Question = []dns.Question{
+ {Name: "localhost.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ {Name: "localhost.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET},
+ }
+ wantReply = query.Copy()
+ wantReply.Rcode = dns.RcodeRefused
+ testQuery(t, service, query, wantReply)
+
+ // OPT qtype.
+ query = new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeOPT)
+ wantReply = query.Copy()
+ wantReply.Rcode = dns.RcodeFormatError
+ testQuery(t, service, query, wantReply)
+
+ // Full zone transfer (AXFR).
+ query = new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeAXFR)
+ wantReply = query.Copy()
+ wantReply.Rcode = dns.RcodeRefused
+ testQuery(t, service, query, wantReply)
+
+ // Incremental zone transfer (IXFR).
+ query = new(dns.Msg)
+ query.SetQuestion("localhost.", dns.TypeIXFR)
+ wantReply = query.Copy()
+ wantReply.Rcode = dns.RcodeRefused
+ testQuery(t, service, query, wantReply)
+}
diff --git a/osbase/net/dns/hosts b/osbase/net/dns/hosts
new file mode 100644
index 0000000..77c24a1
--- /dev/null
+++ b/osbase/net/dns/hosts
@@ -0,0 +1,2 @@
+127.0.0.1 localhost
+::1 localhost
diff --git a/osbase/net/dns/metrics.go b/osbase/net/dns/metrics.go
new file mode 100644
index 0000000..13d046a
--- /dev/null
+++ b/osbase/net/dns/metrics.go
@@ -0,0 +1,26 @@
+package dns
+
+import (
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+var (
+	// MetricsRegistry is the metrics registry in which all DNS metrics are
+	// registered.
+	MetricsRegistry = prometheus.NewRegistry()
+
+	// MetricsFactory creates metrics which are automatically registered in
+	// MetricsRegistry.
+	MetricsFactory = promauto.With(MetricsRegistry)
+
+	// handlerDuration measures how long each handler took, labeled by handler
+	// name and rcode. The rcode label is an uppercase rcode name, a numeric
+	// rcode if the rcode is not known, or one of:
+	//   - redirected: the query was redirected by CNAME, so the final rcode is
+	//     not known yet.
+	//   - not_ready: the handler is not set yet, and SERVFAIL was replied.
+	handlerDuration = MetricsFactory.NewHistogramVec(prometheus.HistogramOpts{
+		Namespace: "dnsserver",
+		Subsystem: "server",
+		Name:      "handler_duration_seconds",
+		Help:      "Histogram of the time each handler took.",
+		// 16 buckets, doubling from 0.25ms up to about 8 seconds.
+		Buckets: prometheus.ExponentialBuckets(0.00025, 2, 16),
+	}, []string{"handler", "rcode"})
+)
diff --git a/osbase/net/dns/name.go b/osbase/net/dns/name.go
new file mode 100644
index 0000000..9769d7e
--- /dev/null
+++ b/osbase/net/dns/name.go
@@ -0,0 +1,157 @@
+package dns
+
+import (
+ "net/netip"
+ "strings"
+)
+
// IsSubDomain reports whether child is equal to parent or lies below it in the
// DNS tree. Both arguments are expected to be canonical domain names; a dot in
// child that is escaped with a backslash does not count as a label boundary.
func IsSubDomain(parent, child string) bool {
	if !strings.HasSuffix(child, parent) {
		return false
	}
	cut := len(child) - len(parent)
	switch {
	case cut == 0, parent == ".":
		// Identical names, or everything is below the root.
		return true
	case child[cut-1] != '.':
		// parent matches only part of a label, e.g. "xample.com." in
		// "wwwexample.com.".
		return false
	}
	// The dot in front of parent must be a real separator, i.e. it must be
	// preceded by an even number of backslashes.
	backslashes := 0
	for i := cut - 2; i >= 0 && child[i] == '\\'; i-- {
		backslashes++
	}
	return backslashes%2 == 0
}
+
// SplitLastLabel splits off the last label of a domain name. For example,
// "www.example.com." is split into "www.example." and "com".
//
// name may be fully qualified (ending in the root dot) or relative. Dots
// escaped with a backslash are part of a label, including a trailing escaped
// dot: the relative name `www.example\.` is split into "www." and
// `example\.`.
func SplitLastLabel(name string) (rest string, label string) {
	// unescapedDot reports whether name[i] is a dot which is a label
	// separator, i.e. a dot preceded by an even number of backslashes.
	unescapedDot := func(i int) bool {
		if name[i] != '.' {
			return false
		}
		backslashes := 0
		for j := i - 1; j >= 0 && name[j] == '\\'; j-- {
			backslashes++
		}
		return backslashes%2 == 0
	}

	labelEnd := len(name)
	// Drop the root dot of a fully qualified name. An escaped trailing dot is
	// not the root dot; it belongs to the last label and must be kept.
	if labelEnd != 0 && unescapedDot(labelEnd-1) {
		labelEnd--
	}
	labelStart := labelEnd
	for labelStart > 0 && !unescapedDot(labelStart-1) {
		labelStart--
	}
	return name[:labelStart], name[labelStart:labelEnd]
}
+
// ParseReverse parses name as a reverse lookup name. If name is not a reverse
// name, the returned IP is invalid. The second return value indicates how many
// bits of the address are present. The third return value is true if there are
// extra labels before the reverse name.
//
// Labels are consumed from right to left, starting next to the "in-addr.arpa."
// or "ip6.arpa." suffix. Parsing stops at the first label that is not a valid
// address component (a decimal number 0-255 without leading zeros for IPv4, a
// single hex digit for IPv6), or once the address is complete. Anything left of
// that point counts as extra labels. The suffix and the hex digits a-f are
// matched case-sensitively, so name is expected in canonical (lowercase) form.
func ParseReverse(name string) (ip netip.Addr, bits int, extra bool) {
	// escapedAt reports whether name[i] is escaped, i.e. preceded by an odd
	// number of backslashes.
	escapedAt := func(i int) bool {
		backslashes := 0
		for k := i - 1; k >= 0 && name[k] == '\\'; k-- {
			backslashes++
		}
		return backslashes%2 == 1
	}

	switch {
	case strings.HasSuffix(name, "in-addr.arpa."):
		var octets [4]uint8
		count := 0
		// i is the index of the dot which should terminate the next label.
		i := len(name) - len("in-addr.arpa.") - 1
		for i >= 0 && count < 4 && name[i] == '.' {
			first := i
			for first > 0 && '0' <= name[first-1] && name[first-1] <= '9' {
				first--
			}
			digits := name[first:i]
			if len(digits) == 0 || len(digits) > 3 || (len(digits) > 1 && digits[0] == '0') {
				// Empty, too long, or with a leading zero.
				break
			}
			value := 0
			for k := 0; k < len(digits); k++ {
				value = value*10 + int(digits[k]-'0')
			}
			if value > 255 {
				break
			}
			octets[count] = uint8(value)
			count++
			i = first - 1
		}
		if i >= 0 && (name[i] != '.' || escapedAt(i)) {
			// The most recently consumed label is not preceded by a real
			// separator, so it is only part of a longer label: undo it. If
			// nothing was consumed, the suffix itself is not a whole label.
			if count == 0 {
				return netip.Addr{}, 0, false
			}
			count--
			octets[count] = 0
		}
		return netip.AddrFrom4(octets), count * 8, i >= 0

	case strings.HasSuffix(name, "ip6.arpa."):
		var addr [16]uint8
		// nibbles counts parsed hex digits; even counts fill the high half of
		// a byte, odd counts the low half.
		nibbles := 0
		i := len(name) - len("ip6.arpa.") - 1
		for i > 0 && nibbles < 32 && name[i] == '.' {
			c := name[i-1]
			var v uint8
			if '0' <= c && c <= '9' {
				v = c - '0'
			} else if 'a' <= c && c <= 'f' {
				v = c - 'a' + 10
			} else {
				break
			}
			if nibbles%2 == 0 {
				addr[nibbles/2] = v << 4
			} else {
				addr[nibbles/2] |= v
			}
			nibbles++
			// Skip the digit and land on the dot which should precede it.
			i -= 2
		}
		if i >= 0 && (name[i] != '.' || escapedAt(i)) {
			// Same as for IPv4: the last consumed digit was not a whole label.
			if nibbles == 0 {
				return netip.Addr{}, 0, false
			}
			nibbles--
			if nibbles%2 == 0 {
				addr[nibbles/2] = 0
			} else {
				addr[nibbles/2] &= 0xf0
			}
		}
		return netip.AddrFrom16(addr), nibbles * 4, i >= 0
	}

	return netip.Addr{}, 0, false
}
diff --git a/osbase/net/dns/name_test.go b/osbase/net/dns/name_test.go
new file mode 100644
index 0000000..fab7bba
--- /dev/null
+++ b/osbase/net/dns/name_test.go
@@ -0,0 +1,155 @@
+package dns
+
+import (
+ "testing"
+)
+
+func TestIsSubDomain(t *testing.T) {
+ cases := []struct {
+ parent, child string
+ expected bool
+ }{
+ {".", ".", true},
+ {".", "test.", true},
+ {"example.com.", "example.com.", true},
+ {"example.com.", "www.example.com.", true},
+ {"example.com.", "xample.com.", false},
+ {"example.com.", "www.axample.com.", false},
+ {"example.com.", "wwwexample.com.", false},
+ {"example.com.", `www\.example.com.`, false},
+ {"example.com.", `www\\.example.com.`, true},
+ }
+ for _, c := range cases {
+ if IsSubDomain(c.parent, c.child) != c.expected {
+ t.Errorf("IsSubDomain(%q, %q): expected %v", c.parent, c.child, c.expected)
+ }
+ }
+}
+
+func TestSplitLastLabel(t *testing.T) {
+ cases := []struct {
+ name, rest, label string
+ }{
+ {"", "", ""},
+ {".", "", ""},
+ {"com.", "", "com"},
+ {"www.example.com", "www.example.", "com"},
+ {"www.example.com.", "www.example.", "com"},
+ {`www.example\.com.`, "www.", `example\.com`},
+ {`www.example\\.com.`, `www.example\\.`, "com"},
+ }
+ for _, c := range cases {
+ rest, label := SplitLastLabel(c.name)
+ if rest != c.rest || label != c.label {
+ t.Errorf("SplitLastLabel(%q) = (%q, %q), expected (%q, %q)", c.name, rest, label, c.rest, c.label)
+ }
+ }
+}
+
+func TestParseReverse(t *testing.T) {
+ cases := []struct {
+ name string
+ ip string
+ bits int
+ extra bool
+ }{
+ {"example.", "invalid IP", 0, false},
+ {"0.10.200.255.in-addr.arpa.", "255.200.10.0", 32, false},
+ {"7.6.45.123.in-addr.arpa.", "123.45.6.7", 32, false},
+ {"6.45.123.in-addr.arpa.", "123.45.6.0", 24, false},
+ {"45.123.in-addr.arpa.", "123.45.0.0", 16, false},
+ {"123.in-addr.arpa.", "123.0.0.0", 8, false},
+ {"in-addr.arpa.", "0.0.0.0", 0, false},
+ {"8.7.6.45.123.in-addr.arpa.", "123.45.6.7", 32, true}, // too many fields
+ {".6.45.123.in-addr.arpa.", "123.45.6.0", 24, true}, // empty field
+ {"7.06.45.123.in-addr.arpa.", "123.45.0.0", 16, true}, // leading 0
+ {"7.256.45.123.in-addr.arpa.", "123.45.0.0", 16, true}, // number too large
+ {"a6.45.123.in-addr.arpa.", "123.45.0.0", 16, true}, // invalid character
+ {`7\.6.45.123.in-addr.arpa.`, "123.45.0.0", 16, true}, // escaped .
+ {"0.6.45.123in-addr.arpa.", "invalid IP", 0, false}, // missing .
+ {
+ "0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
+ "::fedc:ba98:7654:3210",
+ 128,
+ false,
+ },
+ {
+ "1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
+ "::fedc:ba98:7654:3210",
+ 124,
+ false,
+ },
+ {
+ "2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
+ "::fedc:ba98:7654:3200",
+ 120,
+ false,
+ },
+ {
+ "3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
+ "::fedc:ba98:7654:3000",
+ 116,
+ false,
+ },
+ {
+ "2.ip6.arpa.",
+ "2000::",
+ 4,
+ false,
+ },
+ {
+ "ip6.arpa.",
+ "::",
+ 0,
+ false,
+ },
+ {
+ "0.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", // too long
+ "::fedc:ba98:7654:3210",
+ 128,
+ true,
+ },
+ {
+ "01.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", // missing dot
+ "::fedc:ba98:7654:3200",
+ 120,
+ true,
+ },
+ {
+ "001.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", // missing dot
+ "::fedc:ba98:7654:3200",
+ 120,
+ true,
+ },
+ {
+ `0.1\.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.`, // escaped dot
+ "::fedc:ba98:7654:3000",
+ 116,
+ true,
+ },
+ {
+ "g.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", // invalid character
+ "::fedc:ba98:7654:3210",
+ 124,
+ true,
+ },
+ }
+ for _, c := range cases {
+ ip, bits, extra := ParseReverse(c.name)
+ if ip.String() != c.ip || bits != c.bits || extra != c.extra {
+ t.Errorf("ParseReverse(%q) = (%s, %v, %v), expected (%s, %v, %v)", c.name, ip, bits, extra, c.ip, c.bits, c.extra)
+ }
+ }
+}
+
+func BenchmarkParseReverseIPv4(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ ParseReverse("7.6.45.123.in-addr.arpa.")
+ }
+}
+
+func BenchmarkParseReverseIPv6(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ ParseReverse("0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.")
+ }
+}
diff --git a/osbase/net/dns/resolv.conf b/osbase/net/dns/resolv.conf
new file mode 100644
index 0000000..d148e5f
--- /dev/null
+++ b/osbase/net/dns/resolv.conf
@@ -0,0 +1,3 @@
+# NOTE: This is baked into the rootfs. All DNS-related settings are in the
+# DNS service at //osbase/net/dns.
+nameserver 127.0.0.1
diff --git a/osbase/net/dns/test/BUILD.bazel b/osbase/net/dns/test/BUILD.bazel
new file mode 100644
index 0000000..3cbd61d
--- /dev/null
+++ b/osbase/net/dns/test/BUILD.bazel
@@ -0,0 +1,16 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "test",
+ srcs = ["server.go"],
+ importpath = "source.monogon.dev/osbase/net/dns/test",
+ visibility = ["//osbase/net/dns:__subpackages__"],
+ deps = ["@com_github_miekg_dns//:dns"],
+)
+
+go_test(
+ name = "test_test",
+ srcs = ["server_test.go"],
+ embed = [":test"],
+ deps = ["@com_github_miekg_dns//:dns"],
+)
diff --git a/osbase/net/dns/test/LICENSE-3RD-PARTY.txt b/osbase/net/dns/test/LICENSE-3RD-PARTY.txt
new file mode 100644
index 0000000..98f9935
--- /dev/null
+++ b/osbase/net/dns/test/LICENSE-3RD-PARTY.txt
@@ -0,0 +1,13 @@
+Copyright 2016-2024 The CoreDNS authors and contributors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/osbase/net/dns/test/server.go b/osbase/net/dns/test/server.go
new file mode 100644
index 0000000..a877820
--- /dev/null
+++ b/osbase/net/dns/test/server.go
@@ -0,0 +1,73 @@
+package test
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "fmt"
+ "net"
+
+ "github.com/miekg/dns"
+)
+
+// A Server is an DNS server listening on a system-chosen port on the local
+// loopback interface, for use in end-to-end DNS tests.
+type Server struct {
+ Addr string // Address where the server listening.
+
+ s1 *dns.Server // udp
+ s2 *dns.Server // tcp
+}
+
+// NewServer starts and returns a new Server. The caller should call Close when
+// finished, to shut it down.
+func NewServer(f dns.HandlerFunc) *Server {
+ ch1 := make(chan bool)
+ ch2 := make(chan bool)
+
+ s1 := &dns.Server{Handler: f} // udp
+ s2 := &dns.Server{Handler: f} // tcp
+
+ for i := 0; i < 5; i++ { // 5 attempts
+ s2.Listener, _ = net.Listen("tcp", "[::1]:0")
+ if s2.Listener == nil {
+ continue
+ }
+
+ s1.PacketConn, _ = net.ListenPacket("udp", s2.Listener.Addr().String())
+ if s1.PacketConn != nil {
+ break
+ }
+
+ // perhaps UDP port is in use, try again
+ s2.Listener.Close()
+ s2.Listener = nil
+ }
+ if s2.Listener == nil {
+ panic("dnstest.NewServer(): failed to create new server")
+ }
+
+ s1.NotifyStartedFunc = func() { close(ch1) }
+ s2.NotifyStartedFunc = func() { close(ch2) }
+ go s1.ActivateAndServe()
+ go s2.ActivateAndServe()
+
+ <-ch1
+ <-ch2
+
+ return &Server{s1: s1, s2: s2, Addr: s2.Listener.Addr().String()}
+}
+
+// Close shuts down the server.
+func (s *Server) Close() {
+ s.s1.Shutdown()
+ s.s2.Shutdown()
+}
+
+// RR parses s as a DNS resource record.
+func RR(s string) dns.RR {
+ rr, err := dns.NewRR(s)
+ if err != nil {
+ panic(fmt.Sprintf("Failed to parse rr %q: %v", s, err))
+ }
+ return rr
+}
diff --git a/osbase/net/dns/test/server_test.go b/osbase/net/dns/test/server_test.go
new file mode 100644
index 0000000..ea60845
--- /dev/null
+++ b/osbase/net/dns/test/server_test.go
@@ -0,0 +1,39 @@
+package test
+
+// Taken and modified from CoreDNS, under Apache 2.0.
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestNewServer(t *testing.T) {
+ s := NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ c := new(dns.Client)
+ c.Net = "tcp"
+ m := new(dns.Msg)
+ m.SetQuestion("example.org.", dns.TypeSOA)
+ ret, _, err := c.Exchange(m, s.Addr)
+ if err != nil {
+ t.Fatalf("Could not send message to dnstest.Server: %s", err)
+ }
+ if ret.Id != m.Id {
+ t.Fatalf("Msg ID's should match, expected %d, got %d", m.Id, ret.Id)
+ }
+
+ c.Net = "udp"
+ ret, _, err = c.Exchange(m, s.Addr)
+ if err != nil {
+ t.Fatalf("Could not send message to dnstest.Server: %s", err)
+ }
+ if ret.Id != m.Id {
+ t.Fatalf("Msg ID's should match, expected %d, got %d", m.Id, ret.Id)
+ }
+}
diff --git a/osbase/net/dns/testhelpers.go b/osbase/net/dns/testhelpers.go
new file mode 100644
index 0000000..b8943fd
--- /dev/null
+++ b/osbase/net/dns/testhelpers.go
@@ -0,0 +1,68 @@
+package dns
+
+import (
+ "errors"
+ "fmt"
+ "net"
+
+ "github.com/miekg/dns"
+)
+
+// CreateTestRequest creates a Request for use in tests.
+func CreateTestRequest(qname string, qtype uint16, proto string) *Request {
+ var addr net.Addr
+ switch proto {
+ case "udp":
+ addr = &net.UDPAddr{}
+ case "tcp":
+ addr = &net.TCPAddr{}
+ default:
+ panic(fmt.Sprintf("Unknown protocol %q", proto))
+ }
+ req := &Request{
+ Reply: new(dns.Msg),
+ Writer: &testWriter{addr: addr},
+ Qopt: &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
+ Ropt: &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
+ Qname: qname,
+ QnameCanonical: dns.CanonicalName(qname),
+ Qtype: qtype,
+ }
+ req.Reply.Response = true
+ req.Reply.Question = []dns.Question{{
+ Name: qname,
+ Qtype: qtype,
+ Qclass: dns.ClassINET,
+ }}
+ req.Reply.RecursionAvailable = true
+ req.Reply.RecursionDesired = true
+ req.Qopt.SetUDPSize(advertiseUDPSize)
+ req.Ropt.SetUDPSize(advertiseUDPSize)
+ return req
+}
+
+type testWriter struct {
+ addr net.Addr
+ msg *dns.Msg
+}
+
+func (t *testWriter) LocalAddr() net.Addr { return t.addr }
+func (t *testWriter) RemoteAddr() net.Addr { return t.addr }
+func (*testWriter) Write([]byte) (int, error) {
+ return 0, errors.New("testWriter only supports WriteMsg")
+}
+func (*testWriter) Close() error { return nil }
+func (*testWriter) TsigStatus() error { return nil }
+func (*testWriter) TsigTimersOnly(bool) {}
+func (*testWriter) Hijack() {}
+
+func (t *testWriter) WriteMsg(msg *dns.Msg) error {
+ if msg == nil {
+ panic("WriteMsg(nil)")
+ }
+ if t.msg != nil {
+ panic("duplicate WriteMsg()")
+ }
+ t.msg = msg
+ return nil
+}