blob: abf82ea7da2d96bf41d3fae1942e6fb760d2c82b [file] [log] [blame]
Jan Schär0175d7a2025-03-26 12:57:23 +00001// Copyright The Monogon Project Authors.
2// SPDX-License-Identifier: Apache-2.0
3
4// Package sshtakeover provides an [ssh.Client] wrapper which provides utilities
5// for taking over a machine over ssh, by uploading an executable and other
6// payloads, and then executing the executable.
7package sshtakeover
8
9import (
10 "bytes"
11 "context"
Jan Schär5fdca562025-04-14 11:33:29 +000012 "errors"
Jan Schär0175d7a2025-03-26 12:57:23 +000013 "fmt"
14 "io"
15 "net"
Jan Schär5fdca562025-04-14 11:33:29 +000016 "os"
Jan Schär0175d7a2025-03-26 12:57:23 +000017
18 "github.com/pkg/sftp"
19 "golang.org/x/crypto/ssh"
Jan Schär5fdca562025-04-14 11:33:29 +000020
21 "source.monogon.dev/osbase/structfs"
Jan Schär0175d7a2025-03-26 12:57:23 +000022)
23
24type Client struct {
Jan Schär7ea9ef12025-03-27 15:57:41 +000025 cl *ssh.Client
26 sc *sftp.Client
27 progress func(int64)
Jan Schär0175d7a2025-03-26 12:57:23 +000028}
29
// Dial connects over TCP to the SSH server at address, authenticates using
// config, and starts an SFTP client on the resulting connection.
//
// config.Timeout limits how long establishing the TCP connection may take.
// ctx covers the whole of Dial: if it is done before Dial returns, the
// underlying connection is closed, aborting an in-flight SSH handshake or
// SFTP initialization, and ctx.Err() is returned. Once Dial has returned
// successfully, the Client is no longer tied to ctx.
func Dial(ctx context.Context, address string, config *ssh.ClientConfig) (*Client, error) {
	d := net.Dialer{
		Timeout: config.Timeout,
	}
	conn, err := d.DialContext(ctx, "tcp", address)
	if err != nil {
		return nil, err
	}
	// Neither the SSH handshake nor the SFTP initialization accepts a context,
	// so enforce ctx by closing the raw connection, which makes any blocked
	// read or write on it fail. stop() detaches this hook once setup is done.
	stop := context.AfterFunc(ctx, func() {
		conn.Close()
	})

	conn2, chanC, reqC, err := ssh.NewClientConn(conn, address, config)
	if err != nil {
		cancelled := !stop()
		// Make sure the TCP connection is released; closing it again is
		// harmless if it was already closed.
		conn.Close()
		if cancelled {
			return nil, ctx.Err()
		}
		return nil, err
	}
	cl := ssh.NewClient(conn2, chanC, reqC)

	sc, err := sftp.NewClient(cl, sftp.UseConcurrentWrites(true), sftp.MaxConcurrentRequestsPerFile(1024))
	if !stop() {
		// ctx finished while setup was still running. The hook has closed (or
		// is closing) conn, so whatever was built on top of it is unusable.
		if err == nil {
			sc.Close()
		}
		cl.Close()
		return nil, ctx.Err()
	}
	if err != nil {
		cl.Close()
		return nil, fmt.Errorf("while building sftp client: %w", err)
	}
	return &Client{
		cl: cl,
		sc: sc,
	}, nil
}
55
56// Execute a given command on a remote host synchronously, passing in stdin as
57// input, and returning a captured stdout/stderr. The returned data might be
58// valid even when err != nil, which might happen if the remote side returned a
59// non-zero exit code.
60func (p *Client) Execute(ctx context.Context, command string, stdin []byte) (stdout []byte, stderr []byte, err error) {
61 sess, err := p.cl.NewSession()
62 if err != nil {
63 return nil, nil, fmt.Errorf("while creating SSH session: %w", err)
64 }
65 stdoutBuf := bytes.NewBuffer(nil)
66 stderrBuf := bytes.NewBuffer(nil)
67 sess.Stdin = bytes.NewBuffer(stdin)
68 sess.Stdout = stdoutBuf
69 sess.Stderr = stderrBuf
70 defer sess.Close()
71
72 if err := sess.Start(command); err != nil {
73 return nil, nil, err
74 }
75 doneC := make(chan error, 1)
76 go func() {
77 doneC <- sess.Wait()
78 }()
79 select {
80 case <-ctx.Done():
81 return nil, nil, ctx.Err()
82 case err := <-doneC:
83 return stdoutBuf.Bytes(), stderrBuf.Bytes(), err
84 }
85}
86
Jan Schär7ea9ef12025-03-27 15:57:41 +000087type wrappedReader struct {
88 r io.Reader
89 ctx context.Context
90 progress func(int64)
Jan Schär51f81e52025-03-27 13:13:46 +000091}
92
Jan Schär7ea9ef12025-03-27 15:57:41 +000093func (r *wrappedReader) Read(p []byte) (n int, err error) {
Jan Schär51f81e52025-03-27 13:13:46 +000094 if r.ctx.Err() != nil {
95 return 0, r.ctx.Err()
96 }
Jan Schär7ea9ef12025-03-27 15:57:41 +000097 n, err = r.r.Read(p)
98 if r.progress != nil {
99 r.progress(int64(n))
100 }
101 return
Jan Schär51f81e52025-03-27 13:13:46 +0000102}
103
Jan Schär0175d7a2025-03-26 12:57:23 +0000104// Upload a given blob to a targetPath on the system.
105func (p *Client) Upload(ctx context.Context, targetPath string, src io.Reader) error {
Jan Schär7ea9ef12025-03-27 15:57:41 +0000106 src = &wrappedReader{r: src, ctx: ctx, progress: p.progress}
Jan Schär51f81e52025-03-27 13:13:46 +0000107
Jan Schär0175d7a2025-03-26 12:57:23 +0000108 df, err := p.sc.Create(targetPath)
109 if err != nil {
110 return fmt.Errorf("while creating file on the host: %w", err)
111 }
Jan Schär51f81e52025-03-27 13:13:46 +0000112 _, err = df.ReadFromWithConcurrency(src, 0)
113 closeErr := df.Close()
114 if err != nil {
115 return err
Jan Schär0175d7a2025-03-26 12:57:23 +0000116 }
Jan Schär51f81e52025-03-27 13:13:46 +0000117 return closeErr
Jan Schär0175d7a2025-03-26 12:57:23 +0000118}
119
// UploadExecutable performs Upload and then sets the permissions of the
// uploaded file at targetPath to 0755 so it can be executed.
func (p *Client) UploadExecutable(ctx context.Context, targetPath string, src io.Reader) error {
	err := p.Upload(ctx, targetPath, src)
	if err == nil {
		if chmodErr := p.sc.Chmod(targetPath, 0o755); chmodErr != nil {
			err = fmt.Errorf("while setting file permissions: %w", chmodErr)
		}
	}
	return err
}
131
// UploadTree replaces whatever is at targetPath on the remote host with the
// contents of tree.
//
// Anything already present at targetPath is removed first; a missing path is
// not an error. targetPath is then created as a directory, and every node
// yielded by tree.Walk is recreated beneath it: directories with an SFTP
// mkdir, regular files with Upload. Any other node type aborts with an error.
// Nodes are handled in the order tree.Walk yields them, so a directory must
// come before its contents. NOTE(review): presumably guaranteed by structfs;
// confirm.
//
// ctx is checked before every node, so cancellation stops the walk promptly.
// This also covers directory creation and avoids creating a remote file only
// to have Upload fail on its first read.
func (p *Client) UploadTree(ctx context.Context, targetPath string, tree structfs.Tree) error {
	if err := p.sc.RemoveAll(targetPath); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("RemoveAll: %w", err)
	}
	if err := p.sc.Mkdir(targetPath); err != nil {
		return fmt.Errorf("sftp mkdir %q: %w", targetPath, err)
	}
	for nodePath, node := range tree.Walk() {
		if err := ctx.Err(); err != nil {
			return err
		}
		// Remote SFTP paths are slash-separated regardless of the local OS.
		fullPath := targetPath + "/" + nodePath
		switch {
		case node.Mode.IsDir():
			if err := p.sc.Mkdir(fullPath); err != nil {
				return fmt.Errorf("sftp mkdir %q: %w", fullPath, err)
			}
		case node.Mode.IsRegular():
			reader, err := node.Content.Open()
			if err != nil {
				return fmt.Errorf("upload %q: %w", nodePath, err)
			}
			err = p.Upload(ctx, fullPath, reader)
			// The source is closed whether or not the upload succeeded; its
			// close error is ignored and only the upload result is reported.
			reader.Close()
			if err != nil {
				return fmt.Errorf("upload %q: %w", fullPath, err)
			}
		default:
			return fmt.Errorf("upload %q: unsupported file type %s", nodePath, node.Mode.Type().String())
		}
	}
	return nil
}
162
Jan Schär7ea9ef12025-03-27 15:57:41 +0000163// SetProgress sets a callback which will be called repeatedly during uploads
164// with a number of bytes that have been read.
165func (p *Client) SetProgress(callback func(int64)) {
166 p.progress = callback
167}
168
Jan Schär0175d7a2025-03-26 12:57:23 +0000169func (p *Client) Close() error {
170 scErr := p.sc.Close()
171 clErr := p.cl.Close()
172 if clErr != nil {
173 return clErr
174 }
175 return scErr
176}