m/pkg/socksproxy: init
This implements a simple SOCKS5 proxy server, which will be used within
nanoswitch to expose multiple nodes to test code and metroctl.
Some existing alternatives were considered, but none were in a healthy
enough state to be usable within Metropolis. And, in the end, we only
need a small subset of an already simple standard, so implementing this
ourselves isn't a massive waste of time.
Change-Id: Ifa4d4edf837b55b93cae9981028efef336ff2a3d
Reviewed-on: https://review.monogon.dev/c/monogon/+/646
Reviewed-by: Mateusz Zalega <mateusz@monogon.tech>
diff --git a/metropolis/pkg/socksproxy/BUILD.bazel b/metropolis/pkg/socksproxy/BUILD.bazel
new file mode 100644
index 0000000..c953dbf
--- /dev/null
+++ b/metropolis/pkg/socksproxy/BUILD.bazel
@@ -0,0 +1,18 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

# SOCKS5 (RFC1928 subset) proxy server: wire-protocol helpers (protocol.go)
# and the public Serve/Handler API (socksproxy.go).
go_library(
    name = "socksproxy",
    srcs = [
        "protocol.go",
        "socksproxy.go",
    ],
    importpath = "source.monogon.dev/metropolis/pkg/socksproxy",
    visibility = ["//visibility:public"],
)

# Tests are compiled together with the library sources (embed) so they can
# reach unexported identifiers; the x/net proxy package provides the SOCKS5
# client used to drive the server.
go_test(
    name = "socksproxy_test",
    srcs = ["socksproxy_test.go"],
    embed = [":socksproxy"],
    deps = ["@org_golang_x_net//proxy"],
)
diff --git a/metropolis/pkg/socksproxy/protocol.go b/metropolis/pkg/socksproxy/protocol.go
new file mode 100644
index 0000000..cb9ae0a
--- /dev/null
+++ b/metropolis/pkg/socksproxy/protocol.go
@@ -0,0 +1,195 @@
+package socksproxy
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+)
+
+// readMethods implements RFC1928 3. “Procedure for TCP-based clients”,
+// paragraph 3. It receives a 'version identifier/method selection message' from
+// r and returns the methods supported by the client.
+func readMethods(r io.Reader) ([]method, error) {
+ var ver uint8
+ if err := binary.Read(r, binary.BigEndian, &ver); err != nil {
+ return nil, fmt.Errorf("when reading ver: %w", err)
+ }
+ if ver != 5 {
+ return nil, fmt.Errorf("unimplemented version %d", ver)
+ }
+ var nmethods uint8
+ if err := binary.Read(r, binary.BigEndian, &nmethods); err != nil {
+ return nil, fmt.Errorf("when reading nmethods: %w", err)
+ }
+ methodBytes := make([]byte, nmethods)
+ if _, err := io.ReadFull(r, methodBytes); err != nil {
+ return nil, fmt.Errorf("while reading methods: %w", err)
+ }
+ methods := make([]method, nmethods)
+ for i, m := range methodBytes {
+ methods[i] = method(m)
+ }
+ return methods, nil
+}
+
// writeMethod implements RFC1928 3. “Procedure for TCP-based clients”,
// paragraph 5: it sends the server's METHOD selection message, VER (always 5)
// followed by the selected METHOD, to w.
//
// Both bytes are handed to w in a single Write call, so the two-byte message
// is never split into two separate writes (on a net.Conn each Write may
// otherwise leave as its own packet).
func writeMethod(w io.Writer, m method) error {
	msg := [2]byte{5, byte(m)}
	if _, err := w.Write(msg[:]); err != nil {
		return fmt.Errorf("while writing method selection: %w", err)
	}
	return nil
}

// method is an RFC1928 authentication method, as carried in the METHODS field
// of the client's method selection message and in the METHOD field of the
// server's response to it.
type method uint8

const (
	// methodNoAuthenticationRequired is 'NO AUTHENTICATION REQUIRED' (X'00'),
	// the only method this package accepts.
	methodNoAuthenticationRequired method = 0
	// methodNoAcceptableMethods is 'NO ACCEPTABLE METHODS' (X'FF'), sent when
	// none of the client's offered methods can be used.
	methodNoAcceptableMethods method = 0xff
)
+
+// negotiateMethod implements the entire flow RFC1928 3. “Procedure for
+// TCP-based clients” by negotiating for the 'NO AUTHENTICATION REQUIRED'
+// authentication method, and failing otherwise.
+func negotiateMethod(rw io.ReadWriter) error {
+ methods, err := readMethods(rw)
+ if err != nil {
+ return fmt.Errorf("could not read methods: %w", err)
+ }
+
+ found := false
+ for _, m := range methods {
+ if m == methodNoAuthenticationRequired {
+ found = true
+ break
+ }
+ }
+ if !found {
+ // Discard error, as this connection is failed anyway.
+ writeMethod(rw, methodNoAcceptableMethods)
+ return fmt.Errorf("no acceptable methods found")
+ }
+ if err := writeMethod(rw, methodNoAuthenticationRequired); err != nil {
+ return fmt.Errorf("could not respond with method: %w", err)
+ }
+ return nil
+}
+
var (
	// errNotConnect is returned by readRequest when the request contained some
	// other command than CONNECT (eg. BIND or UDP ASSOCIATE).
	errNotConnect = errors.New("not CONNECT")
	// errUnsupportedAddressType is returned by readRequest when the request
	// contained some unsupported address type (not IPv4 or IPv6, eg. a domain
	// name).
	errUnsupportedAddressType = errors.New("unsupported address type")
)

// readRequest implements RFC1928 4. “Requests” by reading a SOCKS request from
// r and ensuring it's an IPv4/IPv6 CONNECT request. The parsed address/port
// pair is then returned.
//
// The request is laid out as VER (1 byte, must be 5), CMD (1), RSV (1),
// ATYP (1), DST.ADDR (4 bytes for ATYP X'01' IPv4, 16 bytes for ATYP X'04'
// IPv6) and DST.PORT (2 bytes, network byte order).
//
// errNotConnect and errUnsupportedAddressType are returned unwrapped so that
// callers can compare against them directly; in those cases the remainder of
// the request is left unread. Any other error means the request was malformed
// or could not be read.
func readRequest(r io.Reader) (*connectRequest, error) {
	header := struct {
		Ver  uint8
		Cmd  uint8
		Rsv  uint8
		Atyp uint8
	}{}
	if err := binary.Read(r, binary.BigEndian, &header); err != nil {
		return nil, fmt.Errorf("when reading request header: %w", err)
	}

	if header.Ver != 5 {
		return nil, fmt.Errorf("invalid version %d", header.Ver)
	}
	// CMD X'01' is CONNECT; BIND and UDP ASSOCIATE are not implemented.
	if header.Cmd != 1 {
		return nil, errNotConnect
	}

	var addrBytes []byte
	switch header.Atyp {
	case 1: // IP V4 address: 4 octets.
		addrBytes = make([]byte, net.IPv4len)
	case 4: // IP V6 address: 16 octets.
		addrBytes = make([]byte, net.IPv6len)
	default: // X'03' DOMAINNAME and unknown address types.
		return nil, errUnsupportedAddressType
	}
	if _, err := io.ReadFull(r, addrBytes); err != nil {
		return nil, fmt.Errorf("when reading address: %w", err)
	}

	var port uint16
	if err := binary.Read(r, binary.BigEndian, &port); err != nil {
		return nil, fmt.Errorf("when reading port: %w", err)
	}

	return &connectRequest{
		address: addrBytes,
		port:    port,
	}, nil
}

// connectRequest is a parsed RFC1928 CONNECT request.
type connectRequest struct {
	// address is the requested destination: 4 bytes for IPv4, 16 bytes for
	// IPv6.
	address net.IP
	// port is the requested destination TCP port.
	port uint16
}
+
// Reply is an RFC1928 6. “Replies” reply field value. It's returned to the
// client by internal socksproxy code or a Handler to signal a success or error
// condition within an RFC1928 reply.
type Reply uint8

const (
	// ReplySucceeded (X'00'): the request succeeded.
	ReplySucceeded Reply = 0
	// ReplyGeneralFailure (X'01'): general SOCKS server failure.
	ReplyGeneralFailure Reply = 1
	// ReplyConnectionNotAllowed (X'02'): connection not allowed by ruleset.
	ReplyConnectionNotAllowed Reply = 2
	// ReplyNetworkUnreachable (X'03'): network unreachable.
	ReplyNetworkUnreachable Reply = 3
	// ReplyHostUnreachable (X'04'): host unreachable.
	ReplyHostUnreachable Reply = 4
	// ReplyConnectionRefused (X'05'): connection refused.
	ReplyConnectionRefused Reply = 5
	// ReplyTTLExpired (X'06'): TTL expired.
	ReplyTTLExpired Reply = 6
	// ReplyCommandNotSupported (X'07'): command not supported.
	ReplyCommandNotSupported Reply = 7
	// ReplyAddressTypeNotSupported (X'08'): address type not supported.
	ReplyAddressTypeNotSupported Reply = 8
)

// writeReply implements RFC1928 6. “Replies” by sending a given Reply, bind
// address and bind port to w.
//
// Any IPv4 address, including the 16-byte IPv4-in-IPv6 form returned by
// net.IPv4 and net.ParseIP, is sent as ATYP X'01' with a 4-byte BND.ADDR.
// Other 16-byte addresses are sent as ATYP X'04' (IPv6). An error is returned
// if the given bind address is neither, or if a communication error occurred.
//
// The whole reply is written with a single Write call.
func writeReply(w io.Writer, r Reply, bindAddr net.IP, bindPort uint16) error {
	var (
		atyp uint8
		addr net.IP
	)
	v4 := bindAddr.To4()
	switch {
	case v4 != nil:
		atyp, addr = 1, v4
	case len(bindAddr) == net.IPv6len:
		atyp, addr = 4, bindAddr
	default:
		return fmt.Errorf("unsupported bind address type")
	}

	// VER | REP | RSV | ATYP | BND.ADDR | BND.PORT
	msg := make([]byte, 4+len(addr)+2)
	msg[0], msg[1], msg[2], msg[3] = 5, uint8(r), 0, atyp
	copy(msg[4:], addr)
	binary.BigEndian.PutUint16(msg[4+len(addr):], bindPort)
	if _, err := w.Write(msg); err != nil {
		return fmt.Errorf("when writing reply: %w", err)
	}
	return nil
}
diff --git a/metropolis/pkg/socksproxy/socksproxy.go b/metropolis/pkg/socksproxy/socksproxy.go
new file mode 100644
index 0000000..ce35cec
--- /dev/null
+++ b/metropolis/pkg/socksproxy/socksproxy.go
@@ -0,0 +1,218 @@
+// package socksproxy implements a limited subset of the SOCKS 5 (RFC1928)
+// protocol in the form of a pluggable Proxy object. However, this
+// implementation is _not_ RFC1928 compliant, as it does not implement GSSAPI
+// (which is mandated by the spec). It currently only implements CONNECT
+// requests to IPv4/IPv6 addresses. It also doesn't implement any
+// timeout/keepalive system for killing inactive connections.
+//
+// The intended use of the library is internally within Metropolis development
+// environments for contacting test clusters. The code is simple and robust, but
+// not really productionized (as noted above - no timeouts and no authentication
+// make it a bad idea to ever expose this proxy server publicly).
+//
+// There are multiple other, existing Go SOCKS4/5 server implementations, but
+// many of them are either not context aware, part of a larger project (and thus
+// difficult to extract) or are brand new/untested/bleeding edge code.
+package socksproxy
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "strconv"
+)
+
+// Handler should be implemented by socksproxy users to allow SOCKS connections
+// to be proxied in any other way than via the HostHandler.
+type Handler interface {
+ // Connect is called by the server any time a SOCKS client sends a CONNECT
+ // request. The function should return a ConnectResponse describing some
+ // 'backend' connection, ie. the connection that will then be exposed to the
+ // SOCKS client.
+ //
+ // Connect should return with Error set to a non-default value to abort/deny the
+ // connection request.
+ //
+ // The underlying incoming socket is managed by the proxy server and is not
+ // visible to the client. However, any sockets/connections/files opened by the
+ // Handler should be cleaned up by tying them to the given context, which will
+ // be canceled whenever the connection is closed.
+ Connect(context.Context, *ConnectRequest) *ConnectResponse
+}
+
// ConnectRequest describes a pending CONNECT request received from a SOCKS
// client, as handed to Handler.Connect.
type ConnectRequest struct {
	// Address is the IPv4 or IPv6 address the client asked to connect to. It
	// comes straight from the client, so it might be invalid, malformed or
	// point at internal hosts; Connect implementations should sanitize it
	// before using it.
	Address net.IP
	// Port is the TCP port number the client asked to connect to.
	Port uint16
}
+
+// ConnectResponse indicates a 'backend' connection that the proxy should expose
+// to the client, or an error if the connection cannot be made.
+type ConnectResponse struct {
+ // Error will cause an error to be returned if it is anything else than the
+ // default value (ReplySucceeded).
+ Error Reply
+
+ // Backend is the ReadWriter that will be bridged over to the connecting client
+ // if no Error is set.
+ Backend io.ReadWriter
+ // LocalAddress is the IP address that is returned to the client as the local
+ // address of the newly established backend connection.
+ LocalAddress net.IP
+ // LocalPort is the local TCP port number that is returned to the client as the
+ // local port of the newly established backend connection.
+ LocalPort uint16
+}
+
+// ConnectResponseFromConn builds a ConnectResponse from a net.Conn. This can be
+// used by custom Handlers to easily return a ConnectResponse for a newly
+// established net.Conn, eg. from a Dial call.
+//
+// An error is returned if the given net.Conn does not carry a properly formed
+// LocalAddr.
+func ConnectResponseFromConn(c net.Conn) (*ConnectResponse, error) {
+ laddr := c.LocalAddr().String()
+ host, port, err := net.SplitHostPort(laddr)
+ if err != nil {
+ return nil, fmt.Errorf("could not parse LocalAddr %q: %w", laddr, err)
+ }
+ addr := net.ParseIP(host)
+ if addr == nil {
+ return nil, fmt.Errorf("could not parse LocalAddr host %q as IP", host)
+ }
+ portNum, err := strconv.ParseUint(port, 10, 16)
+ if err != nil {
+ return nil, fmt.Errorf("could not parse LocalAddr port %q", port)
+ }
+ return &ConnectResponse{
+ Backend: c,
+ LocalAddress: addr,
+ LocalPort: uint16(portNum),
+ }, nil
+}
+
+type hostHandler struct{}
+
+func (h *hostHandler) Connect(ctx context.Context, req *ConnectRequest) *ConnectResponse {
+ port := fmt.Sprintf("%d", req.Port)
+ addr := net.JoinHostPort(req.Address.String(), port)
+ s, err := net.Dial("tcp", addr)
+ if err != nil {
+ log.Printf("HostHandler could not dial %q: %v", addr, err)
+ return &ConnectResponse{Error: ReplyConnectionRefused}
+ }
+ go func() {
+ <-ctx.Done()
+ s.Close()
+ }()
+ res, err := ConnectResponseFromConn(s)
+ if err != nil {
+ log.Printf("HostHandler could not build response: %v", err)
+ return &ConnectResponse{Error: ReplyGeneralFailure}
+ }
+ return res
+}
+
+var (
+ // HostHandler is an unsafe SOCKS5 proxy Handler which passes all incoming
+ // connections into the local network stack. The incoming addresses/ports are
+ // not sanitized, and as the proxy does not perform authentication, this handler
+ // is an open proxy. This handler should never be used in cases where the proxy
+ // server is publicly available.
+ HostHandler = &hostHandler{}
+)
+
+// Serve runs a SOCKS5 proxy server for a given Handler at a given listener.
+//
+// When the given context is canceled, the server will stop and the listener
+// will be closed. All pending connections will also be canceled and their
+// sockets closed.
+func Serve(ctx context.Context, handler Handler, lis net.Listener) error {
+ go func() {
+ <-ctx.Done()
+ lis.Close()
+ }()
+
+ for {
+ con, err := lis.Accept()
+ if err != nil {
+ // Context cancellation will close listener socket with a generic 'use of closed
+ // network connection' error, translate that back to context error.
+ if ctx.Err() != nil {
+ return ctx.Err()
+ }
+ return err
+ }
+ go handle(ctx, handler, con)
+ }
+}
+
+// handle runs in a goroutine per incoming SOCKS connection. Its lifecycle
+// corresponds to the lifecycle of a running proxy connection.
+func handle(ctx context.Context, handler Handler, con net.Conn) {
+ // ctxR is a per-request context, and will be canceled whenever the handler
+ // exits or the server is stopped.
+ ctxR, ctxRC := context.WithCancel(ctx)
+ defer ctxRC()
+
+ go func() {
+ <-ctxR.Done()
+ con.Close()
+ }()
+
+ // Perform method negotiation with the client.
+ if err := negotiateMethod(con); err != nil {
+ return
+ }
+
+ // Read request from the client and translate problems into early error replies.
+ req, err := readRequest(con)
+ switch err {
+ case errNotConnect:
+ writeReply(con, ReplyCommandNotSupported, net.IPv4(0, 0, 0, 0), 0)
+ return
+ case errUnsupportedAddressType:
+ writeReply(con, ReplyAddressTypeNotSupported, net.IPv4(0, 0, 0, 0), 0)
+ return
+ case nil:
+ default:
+ writeReply(con, ReplyGeneralFailure, net.IPv4(0, 0, 0, 0), 0)
+ return
+ }
+
+ // Ask handler.Connect for a backend.
+ conRes := handler.Connect(ctxR, &ConnectRequest{
+ Address: req.address,
+ Port: req.port,
+ })
+ // Handle programming error when returned value is nil.
+ if conRes == nil {
+ writeReply(con, ReplyGeneralFailure, net.IPv4(0, 0, 0, 0), 0)
+ return
+ }
+ // Handle returned errors.
+ if conRes.Error != ReplySucceeded {
+ writeReply(con, conRes.Error, net.IPv4(0, 0, 0, 0), 0)
+ return
+ }
+
+ // Ensure Bound.* fields are set.
+ if conRes.Backend == nil || conRes.LocalAddress == nil || conRes.LocalPort == 0 {
+ writeReply(con, ReplyGeneralFailure, net.IPv4(0, 0, 0, 0), 0)
+ return
+ }
+ // Send reply.
+ if err := writeReply(con, ReplySucceeded, conRes.LocalAddress, conRes.LocalPort); err != nil {
+ return
+ }
+
+ // Pipe returned backend into connection.
+ go io.Copy(conRes.Backend, con)
+ io.Copy(con, conRes.Backend)
+}
diff --git a/metropolis/pkg/socksproxy/socksproxy_test.go b/metropolis/pkg/socksproxy/socksproxy_test.go
new file mode 100644
index 0000000..1f384f6
--- /dev/null
+++ b/metropolis/pkg/socksproxy/socksproxy_test.go
@@ -0,0 +1,171 @@
+package socksproxy
+
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"sync/atomic"
	"testing"
	"time"

	"golang.org/x/net/proxy"
)
+
+// TestE2E implements a happy path test by chaining together an HTTP server, a
+// proxy server, a proxy client (from golang.org/x/net) and an HTTP client into
+// an end-to-end test. It uses HostHandler and the actual host network stack for
+// the test HTTP server and test proxy server.
+func TestE2E(t *testing.T) {
+ ctx, ctxC := context.WithCancel(context.Background())
+ defer ctxC()
+
+ // Start test HTTP server.
+ lisSrv, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("could not bind http listener: %v", err)
+ }
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) {
+ fmt.Fprintf(rw, "foo")
+ })
+ go func() {
+ err := http.Serve(lisSrv, mux)
+ if err != nil {
+ t.Fatalf("http.Serve: %v", err)
+ }
+ }()
+
+ // Start proxy server.
+ lisPrx, err := net.Listen("tcp", ":")
+ if err != nil {
+ t.Fatalf("could not bind proxy listener: %v", err)
+ }
+ go func() {
+ err := Serve(ctx, HostHandler, lisPrx)
+ if err != nil && !errors.Is(err, ctx.Err()) {
+ t.Fatalf("proxy.Serve: %v", err)
+ }
+ }()
+
+ // Start proxy client.
+ dialer, err := proxy.SOCKS5("tcp", lisPrx.Addr().String(), nil, proxy.Direct)
+ if err != nil {
+ t.Fatalf("creating SOCKS dialer failed: %v", err)
+ }
+
+ // Create http client.
+ tr := &http.Transport{
+ Dial: dialer.Dial,
+ }
+ cl := &http.Client{
+ Transport: tr,
+ }
+
+ // Perform request and expect 'foo' in response.
+ url := fmt.Sprintf("http://%s/", lisSrv.Addr().String())
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("creating test request failed: %v", err)
+ }
+ res, err := cl.Do(req)
+ if err != nil {
+ t.Fatalf("test http request failed: %v", err)
+ }
+ defer res.Body.Close()
+ body, _ := io.ReadAll(res.Body)
+ if want, got := "foo", string(body); want != got {
+ t.Errorf("wrong response from HTTP, wanted %q, got %q", want, got)
+ }
+}
+
+// testHandler is a handler which serves /dev/zero and keeps count of the
+// current number of live connections. It's used in TestCancellation to ensure
+// contexts are canceled appropriately.
+type testHandler struct {
+ live int64
+}
+
+func (t *testHandler) Connect(ctx context.Context, req *ConnectRequest) *ConnectResponse {
+ f, _ := os.Open("/dev/zero")
+
+ atomic.AddInt64(&t.live, 1)
+ go func() {
+ <-ctx.Done()
+ atomic.AddInt64(&t.live, -1)
+ f.Close()
+ }()
+
+ return &ConnectResponse{
+ Backend: f,
+ LocalAddress: net.ParseIP("127.0.0.1"),
+ LocalPort: 42123,
+ }
+}
+
+// TestCancellation ensures request contexts are canceled correctly - when an
+// incoming connection is closed and when the entire server is stopped.
+func TestCancellation(t *testing.T) {
+ handler := &testHandler{}
+
+ ctx, ctxC := context.WithCancel(context.Background())
+ defer ctxC()
+
+ // Start proxy server.
+ lisPrx, err := net.Listen("tcp", ":")
+ if err != nil {
+ t.Fatalf("could not bind proxy listener: %v", err)
+ }
+ go func() {
+ err := Serve(ctx, handler, lisPrx)
+ if err != nil && !errors.Is(err, ctx.Err()) {
+ t.Fatalf("proxy.Serve: %v", err)
+ }
+ }()
+
+ // Start proxy client.
+ dialer, err := proxy.SOCKS5("tcp", lisPrx.Addr().String(), nil, proxy.Direct)
+ if err != nil {
+ t.Fatalf("creating SOCKS dialer failed: %v", err)
+ }
+
+ // Open two connections.
+ con1, err := dialer.Dial("tcp", "192.2.0.10:1234")
+ if err != nil {
+ t.Fatalf("Dialing first client failed: %v", err)
+ }
+ con2, err := dialer.Dial("tcp", "192.2.0.10:1234")
+ if err != nil {
+ t.Fatalf("Dialing first client failed: %v", err)
+ }
+
+ // Read some data. This makes sure we're ready to check for the liveness of
+ // currently running connections.
+ io.ReadFull(con1, make([]byte, 3))
+ io.ReadFull(con2, make([]byte, 3))
+
+ // Ensure we have two connections.
+ if want, got := int64(2), atomic.LoadInt64(&handler.live); want != got {
+ t.Errorf("wanted %d connections at first, got %d", want, got)
+ }
+
+ // Close one connection. Wait for its context to be canceled.
+ con2.Close()
+ for {
+ if atomic.LoadInt64(&handler.live) == 1 {
+ break
+ }
+ }
+
+ // Cancel the entire server context. Wait for the other connection's context to
+ // be canceled as well.
+ ctxC()
+ for {
+ if atomic.LoadInt64(&handler.live) == 0 {
+ break
+ }
+ }
+}