blob: fdb63f330ae030760fd94b20b4176c40ef30ed6b [file] [log] [blame]
Jan Schär0175d7a2025-03-26 12:57:23 +00001// Copyright The Monogon Project Authors.
2// SPDX-License-Identifier: Apache-2.0
3
4// Package sshtakeover provides an [ssh.Client] wrapper which provides utilities
5// for taking over a machine over ssh, by uploading an executable and other
6// payloads, and then executing the executable.
7package sshtakeover
8
9import (
10 "bytes"
11 "context"
12 "fmt"
13 "io"
14 "net"
15
16 "github.com/pkg/sftp"
17 "golang.org/x/crypto/ssh"
18)
19
20type Client struct {
Jan Schär7ea9ef12025-03-27 15:57:41 +000021 cl *ssh.Client
22 sc *sftp.Client
23 progress func(int64)
Jan Schär0175d7a2025-03-26 12:57:23 +000024}
25
26// Dial starts an ssh client connection.
27func Dial(ctx context.Context, address string, config *ssh.ClientConfig) (*Client, error) {
28 d := net.Dialer{
29 Timeout: config.Timeout,
30 }
31 conn, err := d.DialContext(ctx, "tcp", address)
32 if err != nil {
33 return nil, err
34 }
35 conn2, chanC, reqC, err := ssh.NewClientConn(conn, address, config)
36 if err != nil {
37 return nil, err
38 }
39 cl := ssh.NewClient(conn2, chanC, reqC)
40
41 sc, err := sftp.NewClient(cl, sftp.UseConcurrentWrites(true), sftp.MaxConcurrentRequestsPerFile(1024))
42 if err != nil {
43 cl.Close()
44 return nil, fmt.Errorf("while building sftp client: %w", err)
45 }
46 return &Client{
47 cl: cl,
48 sc: sc,
49 }, nil
50}
51
52// Execute a given command on a remote host synchronously, passing in stdin as
53// input, and returning a captured stdout/stderr. The returned data might be
54// valid even when err != nil, which might happen if the remote side returned a
55// non-zero exit code.
56func (p *Client) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
57 sess, err := p.cl.NewSession()
58 if err != nil {
59 return nil, nil, fmt.Errorf("while creating SSH session: %w", err)
60 }
61 stdoutBuf := bytes.NewBuffer(nil)
62 stderrBuf := bytes.NewBuffer(nil)
63 sess.Stdin = bytes.NewBuffer(stdin)
64 sess.Stdout = stdoutBuf
65 sess.Stderr = stderrBuf
66 defer sess.Close()
67
68 if err := sess.Start(command); err != nil {
69 return nil, nil, err
70 }
71 doneC := make(chan error, 1)
72 go func() {
73 doneC <- sess.Wait()
74 }()
75 select {
76 case <-ctx.Done():
77 return nil, nil, ctx.Err()
78 case err := <-doneC:
79 return stdoutBuf.Bytes(), stderrBuf.Bytes(), err
80 }
81}
82
// wrappedReader decorates an io.Reader so that reads stop once a context has
// ended, and so that the size of every read can be reported to a callback.
type wrappedReader struct {
	// r is the reader that actually supplies the data.
	r io.Reader
	// ctx is consulted before each read; once it has ended, reads fail.
	ctx context.Context
	// progress, when non-nil, receives the byte count of every read.
	progress func(int64)
}

// Read implements io.Reader. If the context has already ended it returns the
// context's error without touching the wrapped reader. Otherwise it forwards
// the call and, when a progress callback is set, reports the number of bytes
// read (even if that number is zero or the read also returned an error).
func (r *wrappedReader) Read(p []byte) (int, error) {
	if ctxErr := r.ctx.Err(); ctxErr != nil {
		return 0, ctxErr
	}
	count, readErr := r.r.Read(p)
	if report := r.progress; report != nil {
		report(int64(count))
	}
	return count, readErr
}
99
Jan Schär0175d7a2025-03-26 12:57:23 +0000100// Upload a given blob to a targetPath on the system.
101func (p *Client) Upload(ctx context.Context, targetPath string, src io.Reader) error {
Jan Schär7ea9ef12025-03-27 15:57:41 +0000102 src = &wrappedReader{r: src, ctx: ctx, progress: p.progress}
Jan Schär51f81e52025-03-27 13:13:46 +0000103
Jan Schär0175d7a2025-03-26 12:57:23 +0000104 df, err := p.sc.Create(targetPath)
105 if err != nil {
106 return fmt.Errorf("while creating file on the host: %w", err)
107 }
Jan Schär51f81e52025-03-27 13:13:46 +0000108 _, err = df.ReadFromWithConcurrency(src, 0)
109 closeErr := df.Close()
110 if err != nil {
111 return err
Jan Schär0175d7a2025-03-26 12:57:23 +0000112 }
Jan Schär51f81e52025-03-27 13:13:46 +0000113 return closeErr
Jan Schär0175d7a2025-03-26 12:57:23 +0000114}
115
116// UploadExecutable uploads a given blob to a targetPath on the system
117// and makes it executable.
118func (p *Client) UploadExecutable(ctx context.Context, targetPath string, src io.Reader) error {
119 if err := p.Upload(ctx, targetPath, src); err != nil {
120 return err
121 }
122 if err := p.sc.Chmod(targetPath, 0755); err != nil {
123 return fmt.Errorf("while setting file permissions: %w", err)
124 }
125 return nil
126}
127
Jan Schär7ea9ef12025-03-27 15:57:41 +0000128// SetProgress sets a callback which will be called repeatedly during uploads
129// with a number of bytes that have been read.
130func (p *Client) SetProgress(callback func(int64)) {
131 p.progress = callback
132}
133
Jan Schär0175d7a2025-03-26 12:57:23 +0000134func (p *Client) Close() error {
135 scErr := p.sc.Close()
136 clErr := p.cl.Close()
137 if clErr != nil {
138 return clErr
139 }
140 return scErr
141}