blob: 20a70d247983609b07698c0279e317d00d744416 [file] [log] [blame]
Jan Schär0175d7a2025-03-26 12:57:23 +00001// Copyright The Monogon Project Authors.
2// SPDX-License-Identifier: Apache-2.0
3
// Package sshtakeover provides an [ssh.Client] wrapper with utilities for
// taking over a machine via SSH: uploading an executable and other payloads,
// then running that executable.
7package sshtakeover
8
9import (
10 "bytes"
11 "context"
12 "fmt"
13 "io"
14 "net"
15
16 "github.com/pkg/sftp"
17 "golang.org/x/crypto/ssh"
18)
19
20type Client struct {
21 cl *ssh.Client
22 sc *sftp.Client
23}
24
25// Dial starts an ssh client connection.
26func Dial(ctx context.Context, address string, config *ssh.ClientConfig) (*Client, error) {
27 d := net.Dialer{
28 Timeout: config.Timeout,
29 }
30 conn, err := d.DialContext(ctx, "tcp", address)
31 if err != nil {
32 return nil, err
33 }
34 conn2, chanC, reqC, err := ssh.NewClientConn(conn, address, config)
35 if err != nil {
36 return nil, err
37 }
38 cl := ssh.NewClient(conn2, chanC, reqC)
39
40 sc, err := sftp.NewClient(cl, sftp.UseConcurrentWrites(true), sftp.MaxConcurrentRequestsPerFile(1024))
41 if err != nil {
42 cl.Close()
43 return nil, fmt.Errorf("while building sftp client: %w", err)
44 }
45 return &Client{
46 cl: cl,
47 sc: sc,
48 }, nil
49}
50
51// Execute a given command on a remote host synchronously, passing in stdin as
52// input, and returning a captured stdout/stderr. The returned data might be
53// valid even when err != nil, which might happen if the remote side returned a
54// non-zero exit code.
55func (p *Client) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
56 sess, err := p.cl.NewSession()
57 if err != nil {
58 return nil, nil, fmt.Errorf("while creating SSH session: %w", err)
59 }
60 stdoutBuf := bytes.NewBuffer(nil)
61 stderrBuf := bytes.NewBuffer(nil)
62 sess.Stdin = bytes.NewBuffer(stdin)
63 sess.Stdout = stdoutBuf
64 sess.Stderr = stderrBuf
65 defer sess.Close()
66
67 if err := sess.Start(command); err != nil {
68 return nil, nil, err
69 }
70 doneC := make(chan error, 1)
71 go func() {
72 doneC <- sess.Wait()
73 }()
74 select {
75 case <-ctx.Done():
76 return nil, nil, ctx.Err()
77 case err := <-doneC:
78 return stdoutBuf.Bytes(), stderrBuf.Bytes(), err
79 }
80}
81
82// Upload a given blob to a targetPath on the system.
83func (p *Client) Upload(ctx context.Context, targetPath string, src io.Reader) error {
84 df, err := p.sc.Create(targetPath)
85 if err != nil {
86 return fmt.Errorf("while creating file on the host: %w", err)
87 }
88
89 doneC := make(chan error, 1)
90
91 go func() {
92 _, err := df.ReadFromWithConcurrency(src, 0)
93 df.Close()
94 doneC <- err
95 }()
96
97 select {
98 case err := <-doneC:
99 if err != nil {
100 return fmt.Errorf("while copying file: %w", err)
101 }
102 case <-ctx.Done():
103 df.Close()
104 return ctx.Err()
105 }
106 return nil
107}
108
109// UploadExecutable uploads a given blob to a targetPath on the system
110// and makes it executable.
111func (p *Client) UploadExecutable(ctx context.Context, targetPath string, src io.Reader) error {
112 if err := p.Upload(ctx, targetPath, src); err != nil {
113 return err
114 }
115 if err := p.sc.Chmod(targetPath, 0755); err != nil {
116 return fmt.Errorf("while setting file permissions: %w", err)
117 }
118 return nil
119}
120
121func (p *Client) Close() error {
122 scErr := p.sc.Close()
123 clErr := p.cl.Close()
124 if clErr != nil {
125 return clErr
126 }
127 return scErr
128}