blob: b8fd8b2d5f6268d9d8d728ed4e6540e484c5a47f [file] [log] [blame]
Jan Schär0175d7a2025-03-26 12:57:23 +00001// Copyright The Monogon Project Authors.
2// SPDX-License-Identifier: Apache-2.0
3
// Package sshtakeover wraps an [ssh.Client] with utilities for taking over a
// machine over SSH: uploading an executable and other payloads, then running
// the executable.
7package sshtakeover
8
9import (
10 "bytes"
11 "context"
12 "fmt"
13 "io"
14 "net"
15
16 "github.com/pkg/sftp"
17 "golang.org/x/crypto/ssh"
18)
19
20type Client struct {
21 cl *ssh.Client
22 sc *sftp.Client
23}
24
25// Dial starts an ssh client connection.
26func Dial(ctx context.Context, address string, config *ssh.ClientConfig) (*Client, error) {
27 d := net.Dialer{
28 Timeout: config.Timeout,
29 }
30 conn, err := d.DialContext(ctx, "tcp", address)
31 if err != nil {
32 return nil, err
33 }
34 conn2, chanC, reqC, err := ssh.NewClientConn(conn, address, config)
35 if err != nil {
36 return nil, err
37 }
38 cl := ssh.NewClient(conn2, chanC, reqC)
39
40 sc, err := sftp.NewClient(cl, sftp.UseConcurrentWrites(true), sftp.MaxConcurrentRequestsPerFile(1024))
41 if err != nil {
42 cl.Close()
43 return nil, fmt.Errorf("while building sftp client: %w", err)
44 }
45 return &Client{
46 cl: cl,
47 sc: sc,
48 }, nil
49}
50
51// Execute a given command on a remote host synchronously, passing in stdin as
52// input, and returning a captured stdout/stderr. The returned data might be
53// valid even when err != nil, which might happen if the remote side returned a
54// non-zero exit code.
55func (p *Client) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
56 sess, err := p.cl.NewSession()
57 if err != nil {
58 return nil, nil, fmt.Errorf("while creating SSH session: %w", err)
59 }
60 stdoutBuf := bytes.NewBuffer(nil)
61 stderrBuf := bytes.NewBuffer(nil)
62 sess.Stdin = bytes.NewBuffer(stdin)
63 sess.Stdout = stdoutBuf
64 sess.Stderr = stderrBuf
65 defer sess.Close()
66
67 if err := sess.Start(command); err != nil {
68 return nil, nil, err
69 }
70 doneC := make(chan error, 1)
71 go func() {
72 doneC <- sess.Wait()
73 }()
74 select {
75 case <-ctx.Done():
76 return nil, nil, ctx.Err()
77 case err := <-doneC:
78 return stdoutBuf.Bytes(), stderrBuf.Bytes(), err
79 }
80}
81
// contextReader wraps an io.Reader so that reads fail with the context's error
// once ctx is done. Cancellation is only checked at the start of each Read; a
// Read already blocked in the underlying reader is not interrupted.
type contextReader struct {
	r   io.Reader
	ctx context.Context
}

// Read returns ctx.Err() without touching the underlying reader when ctx is
// done, and otherwise delegates to the underlying reader.
func (cr *contextReader) Read(p []byte) (n int, err error) {
	if ctxErr := cr.ctx.Err(); ctxErr != nil {
		return 0, ctxErr
	}
	return cr.r.Read(p)
}
93
Jan Schär0175d7a2025-03-26 12:57:23 +000094// Upload a given blob to a targetPath on the system.
95func (p *Client) Upload(ctx context.Context, targetPath string, src io.Reader) error {
Jan Schär51f81e52025-03-27 13:13:46 +000096 src = &contextReader{r: src, ctx: ctx}
97
Jan Schär0175d7a2025-03-26 12:57:23 +000098 df, err := p.sc.Create(targetPath)
99 if err != nil {
100 return fmt.Errorf("while creating file on the host: %w", err)
101 }
Jan Schär51f81e52025-03-27 13:13:46 +0000102 _, err = df.ReadFromWithConcurrency(src, 0)
103 closeErr := df.Close()
104 if err != nil {
105 return err
Jan Schär0175d7a2025-03-26 12:57:23 +0000106 }
Jan Schär51f81e52025-03-27 13:13:46 +0000107 return closeErr
Jan Schär0175d7a2025-03-26 12:57:23 +0000108}
109
110// UploadExecutable uploads a given blob to a targetPath on the system
111// and makes it executable.
112func (p *Client) UploadExecutable(ctx context.Context, targetPath string, src io.Reader) error {
113 if err := p.Upload(ctx, targetPath, src); err != nil {
114 return err
115 }
116 if err := p.sc.Chmod(targetPath, 0755); err != nil {
117 return fmt.Errorf("while setting file permissions: %w", err)
118 }
119 return nil
120}
121
122func (p *Client) Close() error {
123 scErr := p.sc.Close()
124 clErr := p.cl.Close()
125 if clErr != nil {
126 return clErr
127 }
128 return scErr
129}