osbase/net/sshtakeover: refactor package

This is an extensive refactoring of the sshtakeover package.

The package is renamed from //go/net/ssh to //osbase/net/sshtakeover,
to make it clearer what the package is for and to avoid name clashes
with golang.org/x/crypto/ssh.

The previous Client type was removed, and Dial is now a top-level
function which takes an ssh.ClientConfig. The previous Connection type
was renamed to Client, which makes the naming match ssh.Client.

The Client interface type was moved to //cloud/shepherd/manager, where
it is now called SSHClient. This allows us to add more functions to
sshtakeover.Client without breaking consumers of the package, which
would otherwise need to add dummy implementations for functions they
don't need.

The Upload function was renamed to UploadExecutable, and the new Upload
function can be used for files that don't need to be executable.

The SFTP client is now created once, at the same time as the SSH
client, instead of a new one being created for each uploaded file.

Change-Id: I3be9c346713cb4e5c2b33f9c8c9a6f11ca569a75
Reviewed-on: https://review.monogon.dev/c/monogon/+/4047
Tested-by: Jenkins CI
Reviewed-by: Tim Windelschmidt <tim@monogon.tech>
diff --git a/cloud/shepherd/manager/BUILD.bazel b/cloud/shepherd/manager/BUILD.bazel
index aa61caa..92a2ffe 100644
--- a/cloud/shepherd/manager/BUILD.bazel
+++ b/cloud/shepherd/manager/BUILD.bazel
@@ -20,7 +20,7 @@
         "//cloud/bmaas/bmdb/model",
         "//cloud/shepherd",
         "//go/mflags",
-        "//go/net/ssh",
+        "//osbase/net/sshtakeover",
         "@com_github_google_uuid//:uuid",
         "@io_k8s_klog_v2//:klog",
         "@org_golang_google_protobuf//proto",
@@ -46,9 +46,9 @@
         "//cloud/bmaas/bmdb/model",
         "//cloud/lib/component",
         "//cloud/shepherd",
-        "//go/net/ssh",
         "@com_github_google_uuid//:uuid",
         "@io_k8s_klog_v2//:klog",
+        "@org_golang_x_crypto//ssh",
         "@org_golang_x_time//rate",
     ],
 )
diff --git a/cloud/shepherd/manager/fake_ssh_client.go b/cloud/shepherd/manager/fake_ssh_client.go
index d932210..97de575 100644
--- a/cloud/shepherd/manager/fake_ssh_client.go
+++ b/cloud/shepherd/manager/fake_ssh_client.go
@@ -9,26 +9,22 @@
 	"crypto/rand"
 	"fmt"
 	"io"
-	"time"
 
+	"golang.org/x/crypto/ssh"
 	"google.golang.org/protobuf/proto"
 
 	apb "source.monogon.dev/cloud/agent/api"
-
-	"source.monogon.dev/go/net/ssh"
 )
 
-// FakeSSHClient is an Client that pretends to start an agent, but in reality
-// just responds with what an agent would respond on every execution attempt.
-type FakeSSHClient struct{}
+type fakeSSHClient struct{}
 
-type fakeSSHConnection struct{}
-
-func (f *FakeSSHClient) Dial(ctx context.Context, address string, timeout time.Duration) (ssh.Connection, error) {
-	return &fakeSSHConnection{}, nil
+// FakeSSHDial pretends to start an agent, but in reality just responds with
+// what an agent would respond on every execution attempt.
+func FakeSSHDial(ctx context.Context, address string, config *ssh.ClientConfig) (SSHClient, error) {
+	return &fakeSSHClient{}, nil
 }
 
-func (f *fakeSSHConnection) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
+func (f *fakeSSHClient) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
 	var aim apb.TakeoverInit
 	if err := proto.Unmarshal(stdin, &aim); err != nil {
 		return nil, nil, fmt.Errorf("while unmarshaling TakeoverInit message: %w", err)
@@ -52,13 +48,13 @@
 	return arspb, nil, nil
 }
 
-func (f *fakeSSHConnection) Upload(ctx context.Context, targetPath string, _ io.Reader) error {
+func (f *fakeSSHClient) UploadExecutable(ctx context.Context, targetPath string, _ io.Reader) error {
 	if targetPath != "/fake/path" {
 		return fmt.Errorf("unexpected target path in test")
 	}
 	return nil
 }
 
-func (f *fakeSSHConnection) Close() error {
+func (f *fakeSSHClient) Close() error {
 	return nil
 }
diff --git a/cloud/shepherd/manager/initializer.go b/cloud/shepherd/manager/initializer.go
index f4495e8..d9e3579 100644
--- a/cloud/shepherd/manager/initializer.go
+++ b/cloud/shepherd/manager/initializer.go
@@ -12,12 +12,14 @@
 	"encoding/pem"
 	"flag"
 	"fmt"
+	"io"
 	"net"
 	"os"
 	"strings"
 	"time"
 
 	"github.com/google/uuid"
+	"golang.org/x/crypto/ssh"
 	"google.golang.org/protobuf/proto"
 	"k8s.io/klog/v2"
 
@@ -27,7 +29,7 @@
 	"source.monogon.dev/cloud/bmaas/bmdb/metrics"
 	"source.monogon.dev/cloud/bmaas/bmdb/model"
 	"source.monogon.dev/cloud/shepherd"
-	"source.monogon.dev/go/net/ssh"
+	"source.monogon.dev/osbase/net/sshtakeover"
 )
 
 // InitializerConfig configures how the Initializer will deploy Agents on
@@ -54,14 +56,20 @@
 	// certificate.
 	EndpointCACertificate []byte
 
-	// SSHTimeout is the amount of time set aside for the initializing
-	// SSH session to run its course. Upon timeout, the iteration would be
-	// declared a failure. Must be set.
-	SSHConnectTimeout time.Duration
+	SSHConfig ssh.ClientConfig
 	// SSHExecTimeout is the amount of time set aside for executing the agent and
 	// getting its output once the SSH connection has been established. Upon timeout,
 	// the iteration would be declared as failure. Must be set.
 	SSHExecTimeout time.Duration
+
+	// DialSSH can be set in tests to override how ssh connections are started.
+	DialSSH func(ctx context.Context, address string, config *ssh.ClientConfig) (SSHClient, error)
+}
+
+type SSHClient interface {
+	Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error)
+	UploadExecutable(ctx context.Context, targetPath string, src io.Reader) error
+	Close() error
 }
 
 func (ic *InitializerConfig) RegisterFlags() {
@@ -99,7 +107,7 @@
 		ic.EndpointCACertificate = block.Bytes
 		return nil
 	})
-	flag.DurationVar(&ic.SSHConnectTimeout, "agent_ssh_connect_timeout", 2*time.Second, "Timeout for connecting over SSH to a machine")
+	flag.DurationVar(&ic.SSHConfig.Timeout, "agent_ssh_connect_timeout", 2*time.Second, "Timeout for connecting over SSH to a machine")
 	flag.DurationVar(&ic.SSHExecTimeout, "agent_ssh_exec_timeout", 60*time.Second, "Timeout for connecting over SSH to a machine")
 }
 
@@ -117,7 +125,7 @@
 	if ic.Endpoint == "" {
 		return fmt.Errorf("agent endpoint must be set")
 	}
-	if ic.SSHConnectTimeout == 0 {
+	if ic.SSHConfig.Timeout == 0 {
 		return fmt.Errorf("agent SSH connection timeout must be set")
 	}
 	if ic.SSHExecTimeout == 0 {
@@ -131,13 +139,12 @@
 type Initializer struct {
 	InitializerConfig
 
-	sshClient ssh.Client
-	p         shepherd.Provider
+	p shepherd.Provider
 }
 
 // NewInitializer creates an Initializer instance, checking the
 // InitializerConfig, SharedConfig and AgentConfig for errors.
-func NewInitializer(p shepherd.Provider, sshClient ssh.Client, ic InitializerConfig) (*Initializer, error) {
+func NewInitializer(p shepherd.Provider, ic InitializerConfig) (*Initializer, error) {
 	if err := ic.Check(); err != nil {
 		return nil, err
 	}
@@ -145,8 +152,7 @@
 	return &Initializer{
 		InitializerConfig: ic,
 
-		p:         p,
-		sshClient: sshClient,
+		p: p,
 	}, nil
 }
 
@@ -215,7 +221,13 @@
 	addr := net.JoinHostPort(ni.String(), "22")
 	klog.V(1).Infof("Dialing machine (machine ID: %s, addr: %s).", mid, addr)
 
-	conn, err := i.sshClient.Dial(sctx, addr, i.SSHConnectTimeout)
+	var conn SSHClient
+	var err error
+	if i.DialSSH != nil {
+		conn, err = i.DialSSH(sctx, addr, &i.SSHConfig)
+	} else {
+		conn, err = sshtakeover.Dial(sctx, addr, &i.SSHConfig)
+	}
 	if err != nil {
 		return nil, fmt.Errorf("while dialing the machine: %w", err)
 	}
@@ -224,7 +236,7 @@
 	// Upload the agent executable.
 
 	klog.Infof("Uploading the agent executable (machine ID: %s, addr: %s).", mid, addr)
-	if err := conn.Upload(sctx, i.TargetPath, bytes.NewReader(i.Executable)); err != nil {
+	if err := conn.UploadExecutable(sctx, i.TargetPath, bytes.NewReader(i.Executable)); err != nil {
 		return nil, fmt.Errorf("while uploading agent executable: %w", err)
 	}
 	klog.V(1).Infof("Upload successful (machine ID: %s, addr: %s).", mid, addr)
diff --git a/cloud/shepherd/manager/initializer_test.go b/cloud/shepherd/manager/initializer_test.go
index 3c95527..3b41044 100644
--- a/cloud/shepherd/manager/initializer_test.go
+++ b/cloud/shepherd/manager/initializer_test.go
@@ -8,6 +8,7 @@
 	"testing"
 	"time"
 
+	"golang.org/x/crypto/ssh"
 	"golang.org/x/time/rate"
 
 	"source.monogon.dev/cloud/bmaas/bmdb"
@@ -24,14 +25,17 @@
 		ControlLoopConfig: ControlLoopConfig{
 			DBQueryLimiter: rate.NewLimiter(rate.Every(time.Second), 10),
 		},
-		Executable:        []byte("beep boop i'm a real program"),
-		TargetPath:        "/fake/path",
-		Endpoint:          "example.com:1234",
-		SSHConnectTimeout: time.Second,
-		SSHExecTimeout:    time.Second,
+		Executable: []byte("beep boop i'm a real program"),
+		TargetPath: "/fake/path",
+		Endpoint:   "example.com:1234",
+		SSHConfig: ssh.ClientConfig{
+			Timeout: time.Second,
+		},
+		SSHExecTimeout: time.Second,
+		DialSSH:        provider.FakeSSHDial,
 	}
 
-	i, err := NewInitializer(provider, provider.sshClient(), ic)
+	i, err := NewInitializer(provider, ic)
 	if err != nil {
 		t.Fatalf("Could not create Initializer: %v", err)
 	}
diff --git a/cloud/shepherd/manager/provider_test.go b/cloud/shepherd/manager/provider_test.go
index 003fd5e..2dc3fa9 100644
--- a/cloud/shepherd/manager/provider_test.go
+++ b/cloud/shepherd/manager/provider_test.go
@@ -8,15 +8,14 @@
 	"fmt"
 	"net/netip"
 	"sync"
-	"time"
 
 	"github.com/google/uuid"
+	"golang.org/x/crypto/ssh"
 	"k8s.io/klog/v2"
 
 	"source.monogon.dev/cloud/bmaas/bmdb"
 	"source.monogon.dev/cloud/bmaas/bmdb/model"
 	"source.monogon.dev/cloud/shepherd"
-	"source.monogon.dev/go/net/ssh"
 )
 
 type dummyMachine struct {
@@ -43,17 +42,12 @@
 }
 
 type dummySSHClient struct {
-	ssh.Client
-	dp *dummyProvider
-}
-
-type dummySSHConnection struct {
-	ssh.Connection
+	SSHClient
 	m *dummyMachine
 }
 
-func (dsc *dummySSHConnection) Execute(ctx context.Context, command string, stdin []byte) ([]byte, []byte, error) {
-	stdout, stderr, err := dsc.Connection.Execute(ctx, command, stdin)
+func (dsc *dummySSHClient) Execute(ctx context.Context, command string, stdin []byte) ([]byte, []byte, error) {
+	stdout, stderr, err := dsc.SSHClient.Execute(ctx, command, stdin)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -62,8 +56,8 @@
 	return stdout, stderr, nil
 }
 
-func (dsc *dummySSHClient) Dial(ctx context.Context, address string, timeout time.Duration) (ssh.Connection, error) {
-	conn, err := dsc.Client.Dial(ctx, address, timeout)
+func (dp *dummyProvider) FakeSSHDial(ctx context.Context, address string, config *ssh.ClientConfig) (SSHClient, error) {
+	conn, err := FakeSSHDial(ctx, address, config)
 	if err != nil {
 		return nil, err
 	}
@@ -74,21 +68,14 @@
 		return nil, err
 	}
 
-	dsc.dp.muMachines.RLock()
-	m := dsc.dp.machines[shepherd.ProviderID(uid.String())]
-	dsc.dp.muMachines.RUnlock()
+	dp.muMachines.RLock()
+	m := dp.machines[shepherd.ProviderID(uid.String())]
+	dp.muMachines.RUnlock()
 	if m == nil {
 		return nil, fmt.Errorf("failed finding machine in map")
 	}
 
-	return &dummySSHConnection{conn, m}, nil
-}
-
-func (dp *dummyProvider) sshClient() ssh.Client {
-	return &dummySSHClient{
-		Client: &FakeSSHClient{},
-		dp:     dp,
-	}
+	return &dummySSHClient{conn, m}, nil
 }
 
 func newDummyProvider(cap int) *dummyProvider {