blob: 70fd6b2f331b712920b7e04427457ed19a59415c [file] [log] [blame]
Tim Windelschmidt6d33a432025-02-04 14:34:25 +01001// Copyright The Monogon Project Authors.
2// SPDX-License-Identifier: Apache-2.0
3
Jan Schär75ea9f42024-07-29 17:01:41 +02004package proxy
5
6// Taken and modified from CoreDNS, under Apache 2.0.
7
8import (
9 "errors"
10 "fmt"
11 "math"
12 "testing"
13 "time"
14
15 "github.com/miekg/dns"
16
17 "source.monogon.dev/osbase/net/dns/test"
18)
19
20func TestProxy(t *testing.T) {
21 s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
22 ret := new(dns.Msg)
23 ret.SetReply(r)
24 ret.Answer = append(ret.Answer, test.RR("example.org. IN A 127.0.0.1"))
25 w.WriteMsg(ret)
26 })
27 defer s.Close()
28
29 p := NewProxy(s.Addr)
Jan Schär75ea9f42024-07-29 17:01:41 +020030 p.Start(5 * time.Second)
31 m := new(dns.Msg)
32
33 m.SetQuestion("example.org.", dns.TypeA)
34
35 resp, err := p.Connect(m, false)
36 if err != nil {
37 t.Errorf("Failed to connect to testdnsserver: %s", err)
38 }
39
40 if x := resp.Answer[0].Header().Name; x != "example.org." {
41 t.Errorf("Expected %s, got %s", "example.org.", x)
42 }
43}
44
45func TestProtocolSelection(t *testing.T) {
46 p := NewProxy("bad_address")
Jan Schär75ea9f42024-07-29 17:01:41 +020047
48 go func() {
49 p.Connect(new(dns.Msg), false)
50 p.Connect(new(dns.Msg), true)
51 }()
52
53 for i, exp := range []string{"udp", "tcp"} {
54 proto := <-p.transport.dial
55 p.transport.ret <- nil
56 if proto != exp {
57 t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto)
58 }
59 }
60}
61
62func TestProxyIncrementFails(t *testing.T) {
63 var testCases = []struct {
64 name string
65 fails uint32
66 expectFails uint32
67 }{
68 {
69 name: "increment fails counter overflows",
70 fails: math.MaxUint32,
71 expectFails: math.MaxUint32,
72 },
73 {
74 name: "increment fails counter",
75 fails: 0,
76 expectFails: 1,
77 },
78 }
79
80 for _, tc := range testCases {
81 t.Run(tc.name, func(t *testing.T) {
82 p := NewProxy("bad_address")
83 p.fails = tc.fails
84 p.incrementFails()
85 if p.fails != tc.expectFails {
86 t.Errorf("Expected fails to be %d, got %d", tc.expectFails, p.fails)
87 }
88 })
89 }
90}
91
92func TestCoreDNSOverflow(t *testing.T) {
93 s := test.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
94 ret := new(dns.Msg)
95 ret.SetReply(r)
96
97 var answers []dns.RR
98 for i := range 50 {
99 answers = append(answers, test.RR(fmt.Sprintf("example.org. IN A 127.0.0.%v", i)))
100 }
101 ret.Answer = answers
102 w.WriteMsg(ret)
103 })
104 defer s.Close()
105
106 p := NewProxy(s.Addr)
Jan Schär75ea9f42024-07-29 17:01:41 +0200107 p.Start(5 * time.Second)
108 defer p.Stop()
109
110 // Test different connection modes
111 testConnection := func(proto string, useTCP bool, expectTruncated bool) {
112 t.Helper()
113
114 queryMsg := new(dns.Msg)
115 queryMsg.SetQuestion("example.org.", dns.TypeA)
116
117 response, err := p.Connect(queryMsg, useTCP)
118 if err != nil {
119 t.Errorf("Failed to connect to testdnsserver: %s", err)
Jan Schär363322e2025-03-26 15:41:56 +0000120 return
Jan Schär75ea9f42024-07-29 17:01:41 +0200121 }
122
123 if response.Truncated != expectTruncated {
124 t.Errorf("Expected truncated response for %s, but got TC flag %v", proto, response.Truncated)
125 }
126 }
127
128 // Test udp, expect truncated response
129 testConnection("udp", false, true)
130
131 // Test tcp, expect no truncated response
132 testConnection("tcp", true, false)
133}
134
135func TestShouldTruncateResponse(t *testing.T) {
136 testCases := []struct {
137 testname string
138 err error
139 expected bool
140 }{
141 {"BadAlgorithm", dns.ErrAlg, false},
142 {"BufferSizeTooSmall", dns.ErrBuf, true},
143 {"OverflowUnpackingA", errors.New("overflow unpacking a"), true},
144 {"OverflowingHeaderSize", errors.New("overflowing header size"), true},
145 {"OverflowpackingA", errors.New("overflow packing a"), true},
146 {"ErrSig", dns.ErrSig, false},
147 }
148
149 for _, tc := range testCases {
150 t.Run(tc.testname, func(t *testing.T) {
151 result := shouldTruncateResponse(tc.err)
152 if result != tc.expected {
153 t.Errorf("For testname '%v', expected %v but got %v", tc.testname, tc.expected, result)
154 }
155 })
156 }
157}