osbase/net/sshtakeover: refactor package
This is an extensive refactoring of the sshtakeover package.
The package is renamed from //go/net/ssh to //osbase/net/sshtakeover,
to make it more clear what the package is for and avoid clashes with
golang.org/x/crypto/ssh.
The previous Client type was removed, and Dial is now a top-level
function which takes an ssh.ClientConfig. The previous Connection type
was renamed to Client, which makes the naming match ssh.Client.
The Client interface type was moved to //cloud/shepherd/manager. This
allows us to add more functions to sshtakeover.Client without breaking
consumers of the package, which would need to add dummy implementations
for functions which they don't need.
The Upload function was renamed to UploadExecutable, and the new Upload
function can be used for files that don't need to be executable.
The sftp client is now created at the same time as the client, instead
of creating a new one for each uploaded file.
Change-Id: I3be9c346713cb4e5c2b33f9c8c9a6f11ca569a75
Reviewed-on: https://review.monogon.dev/c/monogon/+/4047
Tested-by: Jenkins CI
Reviewed-by: Tim Windelschmidt <tim@monogon.tech>
diff --git a/cloud/shepherd/manager/BUILD.bazel b/cloud/shepherd/manager/BUILD.bazel
index aa61caa..92a2ffe 100644
--- a/cloud/shepherd/manager/BUILD.bazel
+++ b/cloud/shepherd/manager/BUILD.bazel
@@ -20,7 +20,7 @@
"//cloud/bmaas/bmdb/model",
"//cloud/shepherd",
"//go/mflags",
- "//go/net/ssh",
+ "//osbase/net/sshtakeover",
"@com_github_google_uuid//:uuid",
"@io_k8s_klog_v2//:klog",
"@org_golang_google_protobuf//proto",
@@ -46,9 +46,9 @@
"//cloud/bmaas/bmdb/model",
"//cloud/lib/component",
"//cloud/shepherd",
- "//go/net/ssh",
"@com_github_google_uuid//:uuid",
"@io_k8s_klog_v2//:klog",
+ "@org_golang_x_crypto//ssh",
"@org_golang_x_time//rate",
],
)
diff --git a/cloud/shepherd/manager/fake_ssh_client.go b/cloud/shepherd/manager/fake_ssh_client.go
index d932210..97de575 100644
--- a/cloud/shepherd/manager/fake_ssh_client.go
+++ b/cloud/shepherd/manager/fake_ssh_client.go
@@ -9,26 +9,22 @@
"crypto/rand"
"fmt"
"io"
- "time"
+ "golang.org/x/crypto/ssh"
"google.golang.org/protobuf/proto"
apb "source.monogon.dev/cloud/agent/api"
-
- "source.monogon.dev/go/net/ssh"
)
-// FakeSSHClient is an Client that pretends to start an agent, but in reality
-// just responds with what an agent would respond on every execution attempt.
-type FakeSSHClient struct{}
+type fakeSSHClient struct{}
-type fakeSSHConnection struct{}
-
-func (f *FakeSSHClient) Dial(ctx context.Context, address string, timeout time.Duration) (ssh.Connection, error) {
- return &fakeSSHConnection{}, nil
+// FakeSSHDial pretends to start an agent, but in reality just responds with
+// what an agent would respond on every execution attempt.
+func FakeSSHDial(ctx context.Context, address string, config *ssh.ClientConfig) (SSHClient, error) {
+ return &fakeSSHClient{}, nil
}
-func (f *fakeSSHConnection) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
+func (f *fakeSSHClient) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
var aim apb.TakeoverInit
if err := proto.Unmarshal(stdin, &aim); err != nil {
return nil, nil, fmt.Errorf("while unmarshaling TakeoverInit message: %w", err)
@@ -52,13 +48,13 @@
return arspb, nil, nil
}
-func (f *fakeSSHConnection) Upload(ctx context.Context, targetPath string, _ io.Reader) error {
+func (f *fakeSSHClient) UploadExecutable(ctx context.Context, targetPath string, _ io.Reader) error {
if targetPath != "/fake/path" {
return fmt.Errorf("unexpected target path in test")
}
return nil
}
-func (f *fakeSSHConnection) Close() error {
+func (f *fakeSSHClient) Close() error {
return nil
}
diff --git a/cloud/shepherd/manager/initializer.go b/cloud/shepherd/manager/initializer.go
index f4495e8..d9e3579 100644
--- a/cloud/shepherd/manager/initializer.go
+++ b/cloud/shepherd/manager/initializer.go
@@ -12,12 +12,14 @@
"encoding/pem"
"flag"
"fmt"
+ "io"
"net"
"os"
"strings"
"time"
"github.com/google/uuid"
+ "golang.org/x/crypto/ssh"
"google.golang.org/protobuf/proto"
"k8s.io/klog/v2"
@@ -27,7 +29,7 @@
"source.monogon.dev/cloud/bmaas/bmdb/metrics"
"source.monogon.dev/cloud/bmaas/bmdb/model"
"source.monogon.dev/cloud/shepherd"
- "source.monogon.dev/go/net/ssh"
+ "source.monogon.dev/osbase/net/sshtakeover"
)
// InitializerConfig configures how the Initializer will deploy Agents on
@@ -54,14 +56,20 @@
// certificate.
EndpointCACertificate []byte
- // SSHTimeout is the amount of time set aside for the initializing
- // SSH session to run its course. Upon timeout, the iteration would be
- // declared a failure. Must be set.
- SSHConnectTimeout time.Duration
+ SSHConfig ssh.ClientConfig
// SSHExecTimeout is the amount of time set aside for executing the agent and
// getting its output once the SSH connection has been established. Upon timeout,
// the iteration would be declared as failure. Must be set.
SSHExecTimeout time.Duration
+
+ // DialSSH can be set in tests to override how ssh connections are started.
+ DialSSH func(ctx context.Context, address string, config *ssh.ClientConfig) (SSHClient, error)
+}
+
+type SSHClient interface {
+ Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error)
+ UploadExecutable(ctx context.Context, targetPath string, src io.Reader) error
+ Close() error
}
func (ic *InitializerConfig) RegisterFlags() {
@@ -99,7 +107,7 @@
ic.EndpointCACertificate = block.Bytes
return nil
})
- flag.DurationVar(&ic.SSHConnectTimeout, "agent_ssh_connect_timeout", 2*time.Second, "Timeout for connecting over SSH to a machine")
+ flag.DurationVar(&ic.SSHConfig.Timeout, "agent_ssh_connect_timeout", 2*time.Second, "Timeout for connecting over SSH to a machine")
flag.DurationVar(&ic.SSHExecTimeout, "agent_ssh_exec_timeout", 60*time.Second, "Timeout for connecting over SSH to a machine")
}
@@ -117,7 +125,7 @@
if ic.Endpoint == "" {
return fmt.Errorf("agent endpoint must be set")
}
- if ic.SSHConnectTimeout == 0 {
+ if ic.SSHConfig.Timeout == 0 {
return fmt.Errorf("agent SSH connection timeout must be set")
}
if ic.SSHExecTimeout == 0 {
@@ -131,13 +139,12 @@
type Initializer struct {
InitializerConfig
- sshClient ssh.Client
- p shepherd.Provider
+ p shepherd.Provider
}
// NewInitializer creates an Initializer instance, checking the
// InitializerConfig, SharedConfig and AgentConfig for errors.
-func NewInitializer(p shepherd.Provider, sshClient ssh.Client, ic InitializerConfig) (*Initializer, error) {
+func NewInitializer(p shepherd.Provider, ic InitializerConfig) (*Initializer, error) {
if err := ic.Check(); err != nil {
return nil, err
}
@@ -145,8 +152,7 @@
return &Initializer{
InitializerConfig: ic,
- p: p,
- sshClient: sshClient,
+ p: p,
}, nil
}
@@ -215,7 +221,13 @@
addr := net.JoinHostPort(ni.String(), "22")
klog.V(1).Infof("Dialing machine (machine ID: %s, addr: %s).", mid, addr)
- conn, err := i.sshClient.Dial(sctx, addr, i.SSHConnectTimeout)
+ var conn SSHClient
+ var err error
+ if i.DialSSH != nil {
+ conn, err = i.DialSSH(sctx, addr, &i.SSHConfig)
+ } else {
+ conn, err = sshtakeover.Dial(sctx, addr, &i.SSHConfig)
+ }
if err != nil {
return nil, fmt.Errorf("while dialing the machine: %w", err)
}
@@ -224,7 +236,7 @@
// Upload the agent executable.
klog.Infof("Uploading the agent executable (machine ID: %s, addr: %s).", mid, addr)
- if err := conn.Upload(sctx, i.TargetPath, bytes.NewReader(i.Executable)); err != nil {
+ if err := conn.UploadExecutable(sctx, i.TargetPath, bytes.NewReader(i.Executable)); err != nil {
return nil, fmt.Errorf("while uploading agent executable: %w", err)
}
klog.V(1).Infof("Upload successful (machine ID: %s, addr: %s).", mid, addr)
diff --git a/cloud/shepherd/manager/initializer_test.go b/cloud/shepherd/manager/initializer_test.go
index 3c95527..3b41044 100644
--- a/cloud/shepherd/manager/initializer_test.go
+++ b/cloud/shepherd/manager/initializer_test.go
@@ -8,6 +8,7 @@
"testing"
"time"
+ "golang.org/x/crypto/ssh"
"golang.org/x/time/rate"
"source.monogon.dev/cloud/bmaas/bmdb"
@@ -24,14 +25,17 @@
ControlLoopConfig: ControlLoopConfig{
DBQueryLimiter: rate.NewLimiter(rate.Every(time.Second), 10),
},
- Executable: []byte("beep boop i'm a real program"),
- TargetPath: "/fake/path",
- Endpoint: "example.com:1234",
- SSHConnectTimeout: time.Second,
- SSHExecTimeout: time.Second,
+ Executable: []byte("beep boop i'm a real program"),
+ TargetPath: "/fake/path",
+ Endpoint: "example.com:1234",
+ SSHConfig: ssh.ClientConfig{
+ Timeout: time.Second,
+ },
+ SSHExecTimeout: time.Second,
+ DialSSH: provider.FakeSSHDial,
}
- i, err := NewInitializer(provider, provider.sshClient(), ic)
+ i, err := NewInitializer(provider, ic)
if err != nil {
t.Fatalf("Could not create Initializer: %v", err)
}
diff --git a/cloud/shepherd/manager/provider_test.go b/cloud/shepherd/manager/provider_test.go
index 003fd5e..2dc3fa9 100644
--- a/cloud/shepherd/manager/provider_test.go
+++ b/cloud/shepherd/manager/provider_test.go
@@ -8,15 +8,14 @@
"fmt"
"net/netip"
"sync"
- "time"
"github.com/google/uuid"
+ "golang.org/x/crypto/ssh"
"k8s.io/klog/v2"
"source.monogon.dev/cloud/bmaas/bmdb"
"source.monogon.dev/cloud/bmaas/bmdb/model"
"source.monogon.dev/cloud/shepherd"
- "source.monogon.dev/go/net/ssh"
)
type dummyMachine struct {
@@ -43,17 +42,12 @@
}
type dummySSHClient struct {
- ssh.Client
- dp *dummyProvider
-}
-
-type dummySSHConnection struct {
- ssh.Connection
+ SSHClient
m *dummyMachine
}
-func (dsc *dummySSHConnection) Execute(ctx context.Context, command string, stdin []byte) ([]byte, []byte, error) {
- stdout, stderr, err := dsc.Connection.Execute(ctx, command, stdin)
+func (dsc *dummySSHClient) Execute(ctx context.Context, command string, stdin []byte) ([]byte, []byte, error) {
+ stdout, stderr, err := dsc.SSHClient.Execute(ctx, command, stdin)
if err != nil {
return nil, nil, err
}
@@ -62,8 +56,8 @@
return stdout, stderr, nil
}
-func (dsc *dummySSHClient) Dial(ctx context.Context, address string, timeout time.Duration) (ssh.Connection, error) {
- conn, err := dsc.Client.Dial(ctx, address, timeout)
+func (dp *dummyProvider) FakeSSHDial(ctx context.Context, address string, config *ssh.ClientConfig) (SSHClient, error) {
+ conn, err := FakeSSHDial(ctx, address, config)
if err != nil {
return nil, err
}
@@ -74,21 +68,14 @@
return nil, err
}
- dsc.dp.muMachines.RLock()
- m := dsc.dp.machines[shepherd.ProviderID(uid.String())]
- dsc.dp.muMachines.RUnlock()
+ dp.muMachines.RLock()
+ m := dp.machines[shepherd.ProviderID(uid.String())]
+ dp.muMachines.RUnlock()
if m == nil {
return nil, fmt.Errorf("failed finding machine in map")
}
- return &dummySSHConnection{conn, m}, nil
-}
-
-func (dp *dummyProvider) sshClient() ssh.Client {
- return &dummySSHClient{
- Client: &FakeSSHClient{},
- dp: dp,
- }
+ return &dummySSHClient{conn, m}, nil
}
func newDummyProvider(cap int) *dummyProvider {
diff --git a/cloud/shepherd/mini/BUILD.bazel b/cloud/shepherd/mini/BUILD.bazel
index 948eb40..7b1a7ad 100644
--- a/cloud/shepherd/mini/BUILD.bazel
+++ b/cloud/shepherd/mini/BUILD.bazel
@@ -18,7 +18,6 @@
"//cloud/lib/component",
"//cloud/shepherd",
"//cloud/shepherd/manager",
- "//go/net/ssh",
"@io_k8s_klog_v2//:klog",
"@org_golang_x_crypto//ssh",
],
diff --git a/cloud/shepherd/mini/main.go b/cloud/shepherd/mini/main.go
index 75a5005..b056d20 100644
--- a/cloud/shepherd/mini/main.go
+++ b/cloud/shepherd/mini/main.go
@@ -144,7 +144,7 @@
klog.Exitf("Failed to open BMDB connection: %v", err)
}
- sshClient, err := c.SSHConfig.NewClient()
+ err = c.SSHConfig.Configure(&c.InitializerConfig.SSHConfig)
if err != nil {
klog.Exitf("Failed to create SSH client: %v", err)
}
@@ -168,7 +168,7 @@
klog.Exitf("%v", err)
}
- initializer, err := manager.NewInitializer(mini, sshClient, c.InitializerConfig)
+ initializer, err := manager.NewInitializer(mini, c.InitializerConfig)
if err != nil {
klog.Exitf("%v", err)
}
diff --git a/cloud/shepherd/mini/ssh.go b/cloud/shepherd/mini/ssh.go
index dceafc4..59cefca 100644
--- a/cloud/shepherd/mini/ssh.go
+++ b/cloud/shepherd/mini/ssh.go
@@ -7,11 +7,10 @@
"flag"
"fmt"
- xssh "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh"
"k8s.io/klog/v2"
"source.monogon.dev/cloud/shepherd/manager"
- "source.monogon.dev/go/net/ssh"
)
type sshConfig struct {
@@ -40,32 +39,35 @@
sc.SSHKey.RegisterFlags()
}
-func (sc *sshConfig) NewClient() (*ssh.DirectClient, error) {
+func (sc *sshConfig) Configure(config *ssh.ClientConfig) error {
if err := sc.check(); err != nil {
- return nil, err
+ return err
}
- c := ssh.DirectClient{
- Username: sc.User,
- }
+ config.User = sc.User
switch {
case sc.Pass != "":
- c.AuthMethods = []xssh.AuthMethod{xssh.Password(sc.Pass)}
+ config.Auth = []ssh.AuthMethod{ssh.Password(sc.Pass)}
case sc.SSHKey.KeyPersistPath != "":
signer, err := sc.SSHKey.Signer()
if err != nil {
- return nil, err
+ return err
}
pubKey, err := sc.SSHKey.PublicKey()
if err != nil {
- return nil, err
+ return err
}
klog.Infof("Using ssh key auth with public key: %s", pubKey)
- c.AuthMethods = []xssh.AuthMethod{xssh.PublicKeys(signer)}
+ config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
}
- return &c, nil
+
+ // Ignore the host key, since it's likely the first time anything logs into
+ // this device, and also because there's no way of knowing its fingerprint.
+ config.HostKeyCallback = ssh.InsecureIgnoreHostKey()
+
+ return nil
}
diff --git a/cloud/shepherd/provider/equinix/BUILD.bazel b/cloud/shepherd/provider/equinix/BUILD.bazel
index 23502b3..e0333b1 100644
--- a/cloud/shepherd/provider/equinix/BUILD.bazel
+++ b/cloud/shepherd/provider/equinix/BUILD.bazel
@@ -22,7 +22,6 @@
"//cloud/lib/sinbin",
"//cloud/shepherd",
"//cloud/shepherd/manager",
- "//go/net/ssh",
"@com_github_packethost_packngo//:packngo",
"@io_k8s_klog_v2//:klog",
"@org_golang_x_crypto//ssh",
@@ -49,6 +48,7 @@
"//cloud/shepherd/manager",
"@com_github_google_uuid//:uuid",
"@com_github_packethost_packngo//:packngo",
+ "@org_golang_x_crypto//ssh",
"@org_golang_x_time//rate",
],
)
diff --git a/cloud/shepherd/provider/equinix/initializer_test.go b/cloud/shepherd/provider/equinix/initializer_test.go
index 587b70a..fc34a10 100644
--- a/cloud/shepherd/provider/equinix/initializer_test.go
+++ b/cloud/shepherd/provider/equinix/initializer_test.go
@@ -12,6 +12,7 @@
"time"
"github.com/packethost/packngo"
+ "golang.org/x/crypto/ssh"
"golang.org/x/time/rate"
"source.monogon.dev/cloud/bmaas/bmdb"
@@ -51,14 +52,17 @@
ControlLoopConfig: manager.ControlLoopConfig{
DBQueryLimiter: rate.NewLimiter(rate.Every(time.Second), 10),
},
- Executable: []byte("beep boop i'm a real program"),
- TargetPath: "/fake/path",
- Endpoint: "example.com:1234",
- SSHConnectTimeout: time.Second,
- SSHExecTimeout: time.Second,
+ Executable: []byte("beep boop i'm a real program"),
+ TargetPath: "/fake/path",
+ Endpoint: "example.com:1234",
+ SSHConfig: ssh.ClientConfig{
+ Timeout: time.Second,
+ },
+ SSHExecTimeout: time.Second,
+ DialSSH: manager.FakeSSHDial,
}
- i, err := manager.NewInitializer(provider, &manager.FakeSSHClient{}, ic)
+ i, err := manager.NewInitializer(provider, ic)
if err != nil {
t.Fatalf("Could not create Initializer: %v", err)
}
diff --git a/cloud/shepherd/provider/equinix/main.go b/cloud/shepherd/provider/equinix/main.go
index 1322039..9903b04 100644
--- a/cloud/shepherd/provider/equinix/main.go
+++ b/cloud/shepherd/provider/equinix/main.go
@@ -11,7 +11,7 @@
"os"
"os/signal"
- xssh "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh"
"k8s.io/klog/v2"
"source.monogon.dev/cloud/bmaas/bmdb"
@@ -19,7 +19,6 @@
"source.monogon.dev/cloud/equinix/wrapngo"
"source.monogon.dev/cloud/lib/component"
"source.monogon.dev/cloud/shepherd/manager"
- "source.monogon.dev/go/net/ssh"
)
type Config struct {
@@ -94,18 +93,19 @@
klog.Exitf("%v", err)
}
- sshClient := &ssh.DirectClient{
- AuthMethods: []xssh.AuthMethod{xssh.PublicKeys(sshSigner)},
- // Equinix OS installations always use root.
- Username: "root",
- }
+ c.InitializerConfig.SSHConfig.Auth = []ssh.AuthMethod{ssh.PublicKeys(sshSigner)}
+ // Equinix OS installations always use root.
+ c.InitializerConfig.SSHConfig.User = "root"
+ // Ignore the host key, since it's likely the first time anything logs into
+ // this device, and also because there's no way of knowing its fingerprint.
+ c.InitializerConfig.SSHConfig.HostKeyCallback = ssh.InsecureIgnoreHostKey()
provisioner, err := manager.NewProvisioner(provider, c.ProvisionerConfig)
if err != nil {
klog.Exitf("%v", err)
}
- initializer, err := manager.NewInitializer(provider, sshClient, c.InitializerConfig)
+ initializer, err := manager.NewInitializer(provider, c.InitializerConfig)
if err != nil {
klog.Exitf("%v", err)
}
diff --git a/go/net/ssh/ssh_client.go b/go/net/ssh/ssh_client.go
deleted file mode 100644
index 73974d3..0000000
--- a/go/net/ssh/ssh_client.go
+++ /dev/null
@@ -1,143 +0,0 @@
-// Copyright The Monogon Project Authors.
-// SPDX-License-Identifier: Apache-2.0
-
-package ssh
-
-import (
- "bytes"
- "context"
- "fmt"
- "io"
- "net"
- "time"
-
- "github.com/pkg/sftp"
- "golang.org/x/crypto/ssh"
-)
-
-// Client defines a simple interface to an abstract SSH client. Usually this
-// would be DirectClient, but tests can use this interface to dependency-inject
-// fake SSH connections.
-type Client interface {
- // Dial returns an Connection to a given address (host:port pair) with
- // a timeout for connection.
- Dial(ctx context.Context, address string, connectTimeout time.Duration) (Connection, error)
-}
-
-type Connection interface {
- // Execute a given command on a remote host synchronously, passing in stdin as
- // input, and returning a captured stdout/stderr. The returned data might be
- // valid even when err != nil, which might happen if the remote side returned a
- // non-zero exit code.
- Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error)
- // Upload a given blob to a targetPath on the system and make executable.
- Upload(ctx context.Context, targetPath string, src io.Reader) error
- // Close this connection.
- Close() error
-}
-
-// DirectClient implements Client (and Connection) using
-// golang.org/x/crypto/ssh.
-type DirectClient struct {
- AuthMethods []ssh.AuthMethod
- Username string
-}
-
-type directConn struct {
- cl *ssh.Client
-}
-
-func (p *DirectClient) Dial(ctx context.Context, address string, connectTimeout time.Duration) (Connection, error) {
- d := net.Dialer{
- Timeout: connectTimeout,
- }
- conn, err := d.DialContext(ctx, "tcp", address)
- if err != nil {
- return nil, err
- }
- conf := &ssh.ClientConfig{
- User: p.Username,
- Auth: p.AuthMethods,
- // Ignore the host key, since it's likely the first time anything logs into
- // this device, and also because there's no way of knowing its fingerprint.
- HostKeyCallback: ssh.InsecureIgnoreHostKey(),
- // Timeout sets a bound on the time it takes to set up the connection, but
- // not on total session time.
- Timeout: connectTimeout,
- }
- conn2, chanC, reqC, err := ssh.NewClientConn(conn, address, conf)
- if err != nil {
- return nil, err
- }
- cl := ssh.NewClient(conn2, chanC, reqC)
- return &directConn{
- cl: cl,
- }, nil
-}
-
-func (p *directConn) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
- sess, err := p.cl.NewSession()
- if err != nil {
- return nil, nil, fmt.Errorf("while creating SSH session: %w", err)
- }
- stdoutBuf := bytes.NewBuffer(nil)
- stderrBuf := bytes.NewBuffer(nil)
- sess.Stdin = bytes.NewBuffer(stdin)
- sess.Stdout = stdoutBuf
- sess.Stderr = stderrBuf
- defer sess.Close()
-
- if err := sess.Start(command); err != nil {
- return nil, nil, err
- }
- doneC := make(chan error, 1)
- go func() {
- doneC <- sess.Wait()
- }()
- select {
- case <-ctx.Done():
- return nil, nil, ctx.Err()
- case err := <-doneC:
- return stdoutBuf.Bytes(), stderrBuf.Bytes(), err
- }
-}
-
-func (p *directConn) Upload(ctx context.Context, targetPath string, src io.Reader) error {
- sc, err := sftp.NewClient(p.cl, sftp.UseConcurrentWrites(true), sftp.MaxConcurrentRequestsPerFile(1024))
- if err != nil {
- return fmt.Errorf("while building sftp client: %w", err)
- }
- defer sc.Close()
-
- df, err := sc.Create(targetPath)
- if err != nil {
- return fmt.Errorf("while creating file on the host: %w", err)
- }
-
- doneC := make(chan error, 1)
-
- go func() {
- _, err := df.ReadFromWithConcurrency(src, 0)
- df.Close()
- doneC <- err
- }()
-
- select {
- case err := <-doneC:
- if err != nil {
- return fmt.Errorf("while copying file: %w", err)
- }
- case <-ctx.Done():
- df.Close()
- return ctx.Err()
- }
-
- if err := sc.Chmod(targetPath, 0755); err != nil {
- return fmt.Errorf("while setting file permissions: %w", err)
- }
- return nil
-}
-
-func (p *directConn) Close() error {
- return p.cl.Close()
-}
diff --git a/metropolis/cli/metroctl/BUILD.bazel b/metropolis/cli/metroctl/BUILD.bazel
index bd7ab70..8829be4 100644
--- a/metropolis/cli/metroctl/BUILD.bazel
+++ b/metropolis/cli/metroctl/BUILD.bazel
@@ -40,7 +40,6 @@
deps = [
"//go/clitable",
"//go/logging",
- "//go/net/ssh",
"//metropolis/cli/flagdefs",
"//metropolis/cli/metroctl/core",
"//metropolis/node",
@@ -50,6 +49,7 @@
"//metropolis/proto/common",
"//osbase/logtree",
"//osbase/logtree/proto",
+ "//osbase/net/sshtakeover",
"//osbase/structfs",
"//version",
"@com_github_adrg_xdg//:xdg",
diff --git a/metropolis/cli/metroctl/cmd_install_ssh.go b/metropolis/cli/metroctl/cmd_install_ssh.go
index f0a5379..410a88d 100644
--- a/metropolis/cli/metroctl/cmd_install_ssh.go
+++ b/metropolis/cli/metroctl/cmd_install_ssh.go
@@ -18,12 +18,12 @@
"github.com/schollz/progressbar/v3"
"github.com/spf13/cobra"
- xssh "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/term"
"google.golang.org/protobuf/proto"
- "source.monogon.dev/go/net/ssh"
+ "source.monogon.dev/osbase/net/sshtakeover"
"source.monogon.dev/osbase/structfs"
)
@@ -47,11 +47,11 @@
return fmt.Errorf("flag disk is required")
}
- var authMethods []xssh.AuthMethod
+ var authMethods []ssh.AuthMethod
if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
defer aconn.Close()
a := agent.NewClient(aconn)
- authMethods = append(authMethods, xssh.PublicKeysCallback(a.Signers))
+ authMethods = append(authMethods, ssh.PublicKeysCallback(a.Signers))
} else {
log.Printf("error while establishing ssh agent connection: %v", err)
log.Println("ssh agent authentication will not be available.")
@@ -62,7 +62,7 @@
stdin := int(syscall.Stdin) // nolint:unconvert
if term.IsTerminal(stdin) {
authMethods = append(authMethods,
- xssh.PasswordCallback(func() (string, error) {
+ ssh.PasswordCallback(func() (string, error) {
fmt.Printf("%s@%s's password: ", user, address)
b, err := term.ReadPassword(stdin)
if err != nil {
@@ -71,7 +71,7 @@
fmt.Println()
return string(b), nil
}),
- xssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
+ ssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
answers := make([]string, 0, len(questions))
for i, q := range questions {
fmt.Print(q)
@@ -95,13 +95,19 @@
log.Println("stdin is not interactive. password authentication will not be available.")
}
- cl := ssh.DirectClient{
- Username: user,
- AuthMethods: authMethods,
+ conf := &ssh.ClientConfig{
+ User: user,
+ Auth: authMethods,
+ // Ignore the host key, since it's likely the first time anything logs into
+ // this device, and also because there's no way of knowing its fingerprint.
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ // Timeout sets a bound on the time it takes to set up the connection, but
+ // not on total session time.
+ Timeout: 5 * time.Second,
}
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
- conn, err := cl.Dial(ctx, address, 5*time.Second)
+ conn, err := sshtakeover.Dial(ctx, address, conf)
if err != nil {
return fmt.Errorf("error while establishing ssh connection: %w", err)
}
@@ -146,7 +152,7 @@
proxyReader := progressbar.NewReader(content, bar)
defer proxyReader.Close()
- if err := conn.Upload(ctx, targetPath, &proxyReader); err != nil {
+ if err := conn.UploadExecutable(ctx, targetPath, &proxyReader); err != nil {
log.Fatalf("error while uploading %q: %v", targetPath, err)
}
}
diff --git a/go/net/ssh/BUILD.bazel b/osbase/net/sshtakeover/BUILD.bazel
similarity index 63%
rename from go/net/ssh/BUILD.bazel
rename to osbase/net/sshtakeover/BUILD.bazel
index cb82262..d7a35b0 100644
--- a/go/net/ssh/BUILD.bazel
+++ b/osbase/net/sshtakeover/BUILD.bazel
@@ -1,9 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
- name = "ssh",
- srcs = ["ssh_client.go"],
- importpath = "source.monogon.dev/go/net/ssh",
+ name = "sshtakeover",
+ srcs = ["sshtakeover.go"],
+ importpath = "source.monogon.dev/osbase/net/sshtakeover",
visibility = ["//visibility:public"],
deps = [
"@com_github_pkg_sftp//:sftp",
diff --git a/osbase/net/sshtakeover/sshtakeover.go b/osbase/net/sshtakeover/sshtakeover.go
new file mode 100644
index 0000000..20a70d2
--- /dev/null
+++ b/osbase/net/sshtakeover/sshtakeover.go
@@ -0,0 +1,128 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package sshtakeover provides an [ssh.Client] wrapper which provides utilities
+// for taking over a machine over ssh, by uploading an executable and other
+// payloads, and then executing the executable.
+package sshtakeover
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net"
+
+ "github.com/pkg/sftp"
+ "golang.org/x/crypto/ssh"
+)
+
+type Client struct {
+ cl *ssh.Client
+ sc *sftp.Client
+}
+
+// Dial starts an ssh client connection.
+func Dial(ctx context.Context, address string, config *ssh.ClientConfig) (*Client, error) {
+ d := net.Dialer{
+ Timeout: config.Timeout,
+ }
+ conn, err := d.DialContext(ctx, "tcp", address)
+ if err != nil {
+ return nil, err
+ }
+ conn2, chanC, reqC, err := ssh.NewClientConn(conn, address, config)
+ if err != nil {
+ return nil, err
+ }
+ cl := ssh.NewClient(conn2, chanC, reqC)
+
+ sc, err := sftp.NewClient(cl, sftp.UseConcurrentWrites(true), sftp.MaxConcurrentRequestsPerFile(1024))
+ if err != nil {
+ cl.Close()
+ return nil, fmt.Errorf("while building sftp client: %w", err)
+ }
+ return &Client{
+ cl: cl,
+ sc: sc,
+ }, nil
+}
+
+// Execute a given command on a remote host synchronously, passing in stdin as
+// input, and returning a captured stdout/stderr. The returned data might be
+// valid even when err != nil, which might happen if the remote side returned a
+// non-zero exit code.
+func (p *Client) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
+ sess, err := p.cl.NewSession()
+ if err != nil {
+ return nil, nil, fmt.Errorf("while creating SSH session: %w", err)
+ }
+ stdoutBuf := bytes.NewBuffer(nil)
+ stderrBuf := bytes.NewBuffer(nil)
+ sess.Stdin = bytes.NewBuffer(stdin)
+ sess.Stdout = stdoutBuf
+ sess.Stderr = stderrBuf
+ defer sess.Close()
+
+ if err := sess.Start(command); err != nil {
+ return nil, nil, err
+ }
+ doneC := make(chan error, 1)
+ go func() {
+ doneC <- sess.Wait()
+ }()
+ select {
+ case <-ctx.Done():
+ return nil, nil, ctx.Err()
+ case err := <-doneC:
+ return stdoutBuf.Bytes(), stderrBuf.Bytes(), err
+ }
+}
+
+// Upload a given blob to a targetPath on the system.
+func (p *Client) Upload(ctx context.Context, targetPath string, src io.Reader) error {
+ df, err := p.sc.Create(targetPath)
+ if err != nil {
+ return fmt.Errorf("while creating file on the host: %w", err)
+ }
+
+ doneC := make(chan error, 1)
+
+ go func() {
+ _, err := df.ReadFromWithConcurrency(src, 0)
+ df.Close()
+ doneC <- err
+ }()
+
+ select {
+ case err := <-doneC:
+ if err != nil {
+ return fmt.Errorf("while copying file: %w", err)
+ }
+ case <-ctx.Done():
+ df.Close()
+ return ctx.Err()
+ }
+ return nil
+}
+
+// UploadExecutable uploads a given blob to a targetPath on the system
+// and makes it executable.
+func (p *Client) UploadExecutable(ctx context.Context, targetPath string, src io.Reader) error {
+ if err := p.Upload(ctx, targetPath, src); err != nil {
+ return err
+ }
+ if err := p.sc.Chmod(targetPath, 0755); err != nil {
+ return fmt.Errorf("while setting file permissions: %w", err)
+ }
+ return nil
+}
+
+func (p *Client) Close() error {
+ scErr := p.sc.Close()
+ clErr := p.cl.Close()
+ if clErr != nil {
+ return clErr
+ }
+ return scErr
+}