blob: 20a70d247983609b07698c0279e317d00d744416 [file] [log] [blame]
// Copyright The Monogon Project Authors.
// SPDX-License-Identifier: Apache-2.0
// Package sshtakeover provides an [ssh.Client] wrapper with utilities for
// taking over a machine via SSH: uploading an executable and other payloads,
// and then running the executable.
package sshtakeover
import (
"bytes"
"context"
"fmt"
"io"
"net"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// Client is a connection to a remote host that can run commands over SSH and
// transfer files over SFTP. Create one with Dial and release it with Close.
type Client struct {
	// cl is the underlying SSH connection, used for command sessions.
	cl *ssh.Client
	// sc is the SFTP session running on top of cl, used for file transfer.
	sc *sftp.Client
}
// Dial connects to address over TCP, performs the SSH handshake described by
// config and opens an SFTP session on the resulting connection.
//
// ctx bounds the whole setup, not only the TCP dial. ssh.NewClientConn and
// sftp.NewClient take no context, so if ctx is done before they complete, the
// underlying connection is closed to abort them and ctx.Err() is returned.
// config.Timeout is applied to the TCP dial.
func Dial(ctx context.Context, address string, config *ssh.ClientConfig) (*Client, error) {
	d := net.Dialer{
		Timeout: config.Timeout,
	}
	conn, err := d.DialContext(ctx, "tcp", address)
	if err != nil {
		return nil, err
	}

	// setup runs the context-unaware part of connection establishment.
	setup := func() (*Client, error) {
		sshConn, chanC, reqC, err := ssh.NewClientConn(conn, address, config)
		if err != nil {
			return nil, err
		}
		cl := ssh.NewClient(sshConn, chanC, reqC)
		sc, err := sftp.NewClient(cl, sftp.UseConcurrentWrites(true), sftp.MaxConcurrentRequestsPerFile(1024))
		if err != nil {
			cl.Close()
			return nil, fmt.Errorf("while building sftp client: %w", err)
		}
		return &Client{cl: cl, sc: sc}, nil
	}

	// Watch ctx while setup runs. Closing conn makes any handshake I/O that
	// is blocked on it fail, so setup returns.
	setupDone := make(chan struct{})
	watcherDone := make(chan struct{})
	aborted := false
	go func() {
		defer close(watcherDone)
		select {
		case <-ctx.Done():
			aborted = true
			conn.Close()
		case <-setupDone:
		}
	}()
	c, err := setup()
	close(setupDone)
	// Waiting for the watcher makes reading aborted race-free and ensures the
	// goroutine has exited before Dial returns.
	<-watcherDone
	if aborted {
		if err == nil {
			// Setup won the race, but its connection was closed underneath it.
			_ = c.Close()
		}
		return nil, ctx.Err()
	}
	return c, err
}
// Execute runs command on the remote host in a new SSH session, feeding it
// stdin and waiting for it to finish. It returns everything the command wrote
// to stdout and stderr.
//
// The returned output may be valid even when err != nil, which happens e.g.
// when the remote command exits with a non-zero status. If ctx is done before
// the command finishes, the session is closed and ctx.Err() is returned with
// nil output. If ctx is already done on entry, the command is not started.
func (p *Client) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
	// Don't launch a remote command (with its side effects) for a caller that
	// has already given up.
	if ctxErr := ctx.Err(); ctxErr != nil {
		return nil, nil, ctxErr
	}
	sess, err := p.cl.NewSession()
	if err != nil {
		return nil, nil, fmt.Errorf("while creating SSH session: %w", err)
	}
	defer sess.Close()

	var stdoutBuf, stderrBuf bytes.Buffer
	sess.Stdin = bytes.NewReader(stdin)
	sess.Stdout = &stdoutBuf
	sess.Stderr = &stderrBuf
	if startErr := sess.Start(command); startErr != nil {
		return nil, nil, fmt.Errorf("while starting command: %w", startErr)
	}

	// sess.Wait takes no context, so run it in the background. The channel is
	// buffered so the goroutine can always deliver its result and exit. On
	// cancellation the deferred sess.Close is expected to make Wait return.
	doneC := make(chan error, 1)
	go func() {
		doneC <- sess.Wait()
	}()
	select {
	case <-ctx.Done():
		// The session may still be copying into the buffers, so they must not
		// be read on this path.
		return nil, nil, ctx.Err()
	case waitErr := <-doneC:
		return stdoutBuf.Bytes(), stderrBuf.Bytes(), waitErr
	}
}
// Upload creates targetPath on the remote host over SFTP and copies
// everything read from src into it.
//
// Cancellation of ctx is observed between reads from src. Once ctx is done,
// no further data is read from src, the remote file is closed, and ctx.Err()
// is returned. Upload does not return until the copy has stopped. A failure
// to close the remote file is reported as an error.
func (p *Client) Upload(ctx context.Context, targetPath string, src io.Reader) error {
	df, err := p.sc.Create(targetPath)
	if err != nil {
		return fmt.Errorf("while creating file on the host: %w", err)
	}
	_, copyErr := df.ReadFromWithConcurrency(contextReader{ctx: ctx, r: src}, 0)
	// Close exactly once, regardless of how the copy ended.
	closeErr := df.Close()
	if copyErr != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		return fmt.Errorf("while copying file: %w", copyErr)
	}
	if closeErr != nil {
		return fmt.Errorf("while closing file on the host: %w", closeErr)
	}
	return nil
}

// contextReader wraps an io.Reader so that every Read fails with ctx.Err()
// once ctx is done. It is a short-lived adapter scoped to a single Upload
// call, which is why it holds a context.
type contextReader struct {
	ctx context.Context
	r   io.Reader
}

// Read checks ctx before delegating to the wrapped reader.
func (c contextReader) Read(p []byte) (int, error) {
	if err := c.ctx.Err(); err != nil {
		return 0, err
	}
	return c.r.Read(p)
}
// UploadExecutable uploads src to targetPath exactly like Upload does, then
// sets the remote file's permission bits to 0755 so it can be executed.
// Errors from Upload are returned unchanged; a failed chmod is wrapped.
func (p *Client) UploadExecutable(ctx context.Context, targetPath string, src io.Reader) error {
	const executableMode = 0o755
	if uploadErr := p.Upload(ctx, targetPath, src); uploadErr != nil {
		return uploadErr
	}
	chmodErr := p.sc.Chmod(targetPath, executableMode)
	if chmodErr == nil {
		return nil
	}
	return fmt.Errorf("while setting file permissions: %w", chmodErr)
}
// Close shuts down the SFTP session and then the underlying SSH connection.
// Both are always attempted. If both fail, the SSH connection's error is the
// one returned.
func (p *Client) Close() error {
	sftpErr := p.sc.Close()
	if sshErr := p.cl.Close(); sshErr != nil {
		return sshErr
	}
	return sftpErr
}