blob: 1f384f63e71daafb456ffa2e4c78dce2bcc42962 [file] [log] [blame]
Serge Bazanskife7134b2022-04-01 15:46:29 +02001package socksproxy
2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"sync/atomic"
	"testing"
	"time"

	"golang.org/x/net/proxy"
)
16
17// TestE2E implements a happy path test by chaining together an HTTP server, a
18// proxy server, a proxy client (from golang.org/x/net) and an HTTP client into
19// an end-to-end test. It uses HostHandler and the actual host network stack for
20// the test HTTP server and test proxy server.
21func TestE2E(t *testing.T) {
22 ctx, ctxC := context.WithCancel(context.Background())
23 defer ctxC()
24
25 // Start test HTTP server.
26 lisSrv, err := net.Listen("tcp", "127.0.0.1:0")
27 if err != nil {
28 t.Fatalf("could not bind http listener: %v", err)
29 }
30
31 mux := http.NewServeMux()
32 mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) {
33 fmt.Fprintf(rw, "foo")
34 })
35 go func() {
36 err := http.Serve(lisSrv, mux)
37 if err != nil {
38 t.Fatalf("http.Serve: %v", err)
39 }
40 }()
41
42 // Start proxy server.
43 lisPrx, err := net.Listen("tcp", ":")
44 if err != nil {
45 t.Fatalf("could not bind proxy listener: %v", err)
46 }
47 go func() {
48 err := Serve(ctx, HostHandler, lisPrx)
49 if err != nil && !errors.Is(err, ctx.Err()) {
50 t.Fatalf("proxy.Serve: %v", err)
51 }
52 }()
53
54 // Start proxy client.
55 dialer, err := proxy.SOCKS5("tcp", lisPrx.Addr().String(), nil, proxy.Direct)
56 if err != nil {
57 t.Fatalf("creating SOCKS dialer failed: %v", err)
58 }
59
60 // Create http client.
61 tr := &http.Transport{
62 Dial: dialer.Dial,
63 }
64 cl := &http.Client{
65 Transport: tr,
66 }
67
68 // Perform request and expect 'foo' in response.
69 url := fmt.Sprintf("http://%s/", lisSrv.Addr().String())
70 req, err := http.NewRequest("GET", url, nil)
71 if err != nil {
72 t.Fatalf("creating test request failed: %v", err)
73 }
74 res, err := cl.Do(req)
75 if err != nil {
76 t.Fatalf("test http request failed: %v", err)
77 }
78 defer res.Body.Close()
79 body, _ := io.ReadAll(res.Body)
80 if want, got := "foo", string(body); want != got {
81 t.Errorf("wrong response from HTTP, wanted %q, got %q", want, got)
82 }
83}
84
85// testHandler is a handler which serves /dev/zero and keeps count of the
86// current number of live connections. It's used in TestCancellation to ensure
87// contexts are canceled appropriately.
88type testHandler struct {
89 live int64
90}
91
92func (t *testHandler) Connect(ctx context.Context, req *ConnectRequest) *ConnectResponse {
93 f, _ := os.Open("/dev/zero")
94
95 atomic.AddInt64(&t.live, 1)
96 go func() {
97 <-ctx.Done()
98 atomic.AddInt64(&t.live, -1)
99 f.Close()
100 }()
101
102 return &ConnectResponse{
103 Backend: f,
104 LocalAddress: net.ParseIP("127.0.0.1"),
105 LocalPort: 42123,
106 }
107}
108
109// TestCancellation ensures request contexts are canceled correctly - when an
110// incoming connection is closed and when the entire server is stopped.
111func TestCancellation(t *testing.T) {
112 handler := &testHandler{}
113
114 ctx, ctxC := context.WithCancel(context.Background())
115 defer ctxC()
116
117 // Start proxy server.
118 lisPrx, err := net.Listen("tcp", ":")
119 if err != nil {
120 t.Fatalf("could not bind proxy listener: %v", err)
121 }
122 go func() {
123 err := Serve(ctx, handler, lisPrx)
124 if err != nil && !errors.Is(err, ctx.Err()) {
125 t.Fatalf("proxy.Serve: %v", err)
126 }
127 }()
128
129 // Start proxy client.
130 dialer, err := proxy.SOCKS5("tcp", lisPrx.Addr().String(), nil, proxy.Direct)
131 if err != nil {
132 t.Fatalf("creating SOCKS dialer failed: %v", err)
133 }
134
135 // Open two connections.
136 con1, err := dialer.Dial("tcp", "192.2.0.10:1234")
137 if err != nil {
138 t.Fatalf("Dialing first client failed: %v", err)
139 }
140 con2, err := dialer.Dial("tcp", "192.2.0.10:1234")
141 if err != nil {
142 t.Fatalf("Dialing first client failed: %v", err)
143 }
144
145 // Read some data. This makes sure we're ready to check for the liveness of
146 // currently running connections.
147 io.ReadFull(con1, make([]byte, 3))
148 io.ReadFull(con2, make([]byte, 3))
149
150 // Ensure we have two connections.
151 if want, got := int64(2), atomic.LoadInt64(&handler.live); want != got {
152 t.Errorf("wanted %d connections at first, got %d", want, got)
153 }
154
155 // Close one connection. Wait for its context to be canceled.
156 con2.Close()
157 for {
158 if atomic.LoadInt64(&handler.live) == 1 {
159 break
160 }
161 }
162
163 // Cancel the entire server context. Wait for the other connection's context to
164 // be canceled as well.
165 ctxC()
166 for {
167 if atomic.LoadInt64(&handler.live) == 0 {
168 break
169 }
170 }
171}