osbase/socksproxy: implement hostname addresses

This is required by quite a few clients, like Chrome. Implement it for
better usability of our userspace network while debugging.

Change-Id: I5db16d3702800b79f88d11c132ce8f7469839ec4
Reviewed-on: https://review.monogon.dev/c/monogon/+/3842
Tested-by: Jenkins CI
Reviewed-by: Leopold Schabel <leo@monogon.tech>
diff --git a/metropolis/test/nanoswitch/socks.go b/metropolis/test/nanoswitch/socks.go
index 4ba1d0a..15e0f81 100644
--- a/metropolis/test/nanoswitch/socks.go
+++ b/metropolis/test/nanoswitch/socks.go
@@ -23,16 +23,29 @@
 
 func (s *socksHandler) Connect(ctx context.Context, req *socksproxy.ConnectRequest) *socksproxy.ConnectResponse {
 	logger := supervisor.Logger(ctx)
-	target := net.JoinHostPort(req.Address.String(), fmt.Sprintf("%d", req.Port))
-
-	if len(req.Address) != 4 {
-		logger.Warningf("Connect %s: wrong address type", target)
-		return &socksproxy.ConnectResponse{
-			Error: socksproxy.ReplyAddressTypeNotSupported,
+	var target string
+	var addr net.IP
+	if req.Hostname == "" {
+		target = net.JoinHostPort(req.Address.String(), fmt.Sprintf("%d", req.Port))
+		if req.Address.To4() == nil {
+			logger.Warningf("Connect %s: wrong address type", target)
+			return &socksproxy.ConnectResponse{
+				Error: socksproxy.ReplyAddressTypeNotSupported,
+			}
 		}
+		addr = req.Address
+	} else {
+		target = net.JoinHostPort(req.Hostname, fmt.Sprintf("%d", req.Port))
+		ip, err := net.ResolveIPAddr("ip", req.Hostname)
+		if err != nil {
+			logger.Warningf("Connect %s: while resolving hostname: %v", target, err)
+			return &socksproxy.ConnectResponse{
+				Error: socksproxy.ReplyAddressTypeNotSupported,
+			}
+		}
+		addr = ip.IP
 	}
 
-	addr := req.Address
 	switchCIDR := net.IPNet{
 		IP:   switchIP.Mask(switchSubnetMask),
 		Mask: switchSubnetMask,
diff --git a/osbase/socksproxy/protocol.go b/osbase/socksproxy/protocol.go
index 0d5d133..698d324 100644
--- a/osbase/socksproxy/protocol.go
+++ b/osbase/socksproxy/protocol.go
@@ -115,9 +115,13 @@
 	}
 
 	var addrBytes []byte
+	var hostnameBytes []byte
 	switch header.Atyp {
 	case 1:
 		addrBytes = make([]byte, 4)
+	case 3:
+		// Variable-length string to resolve
+		addrBytes = make([]byte, 1)
 	case 4:
 		addrBytes = make([]byte, 16)
 	default:
@@ -127,20 +131,30 @@
 		return nil, fmt.Errorf("when reading address: %w", err)
 	}
 
+	// Handle domain name addressing, required by for example Chrome
+	if header.Atyp == 3 {
+		hostnameBytes = make([]byte, addrBytes[0])
+		if _, err := io.ReadFull(r, hostnameBytes); err != nil {
+			return nil, fmt.Errorf("when reading address: %w", err)
+		}
+	}
+
 	var port uint16
 	if err := binary.Read(r, binary.BigEndian, &port); err != nil {
 		return nil, fmt.Errorf("when reading port: %w", err)
 	}
 
 	return &connectRequest{
-		address: addrBytes,
-		port:    port,
+		address:  addrBytes,
+		hostname: string(hostnameBytes),
+		port:     port,
 	}, nil
 }
 
 type connectRequest struct {
-	address net.IP
-	port    uint16
+	address  net.IP
+	hostname string
+	port     uint16
 }
 
 // Reply is an RFC1928 6. “Replies” reply field value. It's returned to the
diff --git a/osbase/socksproxy/socksproxy.go b/osbase/socksproxy/socksproxy.go
index 47a398c..6143f90 100644
--- a/osbase/socksproxy/socksproxy.go
+++ b/osbase/socksproxy/socksproxy.go
@@ -52,6 +52,10 @@
 	// This address might be invalid/malformed/internal, and the Connect method
 	// should sanitize it before using it.
 	Address net.IP
+	// Hostname is a string that the client requested to connect to. Only set if
+	// Address is empty. Format and resolution rules are up to the implementer,
+	// a lot of clients only allow valid DNS labels.
+	Hostname string
 	// Port is the TCP port number that the client requested to connect to.
 	Port uint16
 }
@@ -105,7 +109,13 @@
 
 func (h *hostHandler) Connect(ctx context.Context, req *ConnectRequest) *ConnectResponse {
 	port := fmt.Sprintf("%d", req.Port)
-	addr := net.JoinHostPort(req.Address.String(), port)
+	var host string
+	if req.Hostname != "" {
+		host = req.Hostname
+	} else {
+		host = req.Address.String()
+	}
+	addr := net.JoinHostPort(host, port)
 	s, err := net.Dial("tcp", addr)
 	if err != nil {
 		log.Printf("HostHandler could not dial %q: %v", addr, err)
@@ -192,8 +202,9 @@
 
 	// Ask handler.Connect for a backend.
 	conRes := handler.Connect(ctxR, &ConnectRequest{
-		Address: req.address,
-		Port:    req.port,
+		Address:  req.address,
+		Hostname: req.hostname,
+		Port:     req.port,
 	})
 	// Handle programming error when returned value is nil.
 	if conRes == nil {