blob: b52c35366d218df702aface9bfd97a7fe6d0fa87 [file] [log] [blame]
Tim Windelschmidt6d33a432025-02-04 14:34:25 +01001// Copyright The Monogon Project Authors.
2// SPDX-License-Identifier: Apache-2.0
3
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +01004package main
5
6import (
7 "context"
8 _ "embed"
9 "fmt"
10 "log"
11 "net"
12 "net/netip"
13 "os"
14 "os/signal"
15 "strings"
16 "syscall"
17 "time"
18
19 "github.com/schollz/progressbar/v3"
20 "github.com/spf13/cobra"
Jan Schär0175d7a2025-03-26 12:57:23 +000021 "golang.org/x/crypto/ssh"
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +010022 "golang.org/x/crypto/ssh/agent"
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +010023 "golang.org/x/term"
24 "google.golang.org/protobuf/proto"
25
Jan Schär0175d7a2025-03-26 12:57:23 +000026 "source.monogon.dev/osbase/net/sshtakeover"
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +010027)
28
Jan Schär7ea9ef12025-03-27 15:57:41 +000029// progressbarUpdater wraps a [progressbar.ProgressBar] with an improved
30// interface for updating progress. It updates the progress bar in a separate
31// goroutine and at most 60 times per second. The stop function stops the
32// updates and can be safely called multiple times.
33type progressbarUpdater struct {
34 bar *progressbar.ProgressBar
35 update chan int64
36 close chan struct{}
37}
38
39func startProgressbarUpdater(bar *progressbar.ProgressBar) *progressbarUpdater {
40 updater := &progressbarUpdater{
41 bar: bar,
42 update: make(chan int64, 1),
43 close: make(chan struct{}),
44 }
45 go updater.run()
46 return updater
47}
48
49func (p *progressbarUpdater) add(num int64) {
50 for {
51 select {
52 case p.update <- num:
53 return
54 case oldNum := <-p.update:
55 num += oldNum
56 }
57 }
58}
59
60func (p *progressbarUpdater) run() {
61 for {
62 select {
63 case num := <-p.update:
64 p.bar.Add64(num)
65 case <-p.close:
66 return
67 }
68 select {
69 case <-time.After(time.Second / 60):
70 case <-p.close:
71 return
72 }
73 }
74}
75
76func (p *progressbarUpdater) stop() {
77 if p.close == nil {
78 return
79 }
80 p.close <- struct{}{}
81 p.close = nil
82 select {
83 case num := <-p.update:
84 // Do one last update to make the bar reach 100%.
85 p.bar.Add64(num)
86 default:
87 }
88 if !p.bar.IsFinished() {
89 p.bar.Exit()
90 }
91}
92
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +010093var sshCmd = &cobra.Command{
94 Use: "ssh --disk=<disk> <target>",
95 Short: "Installs Metropolis on a Linux system accessible via SSH.",
96 Example: "metroctl install --bundle=metropolis-v0.1.zip --takeover=takeover ssh --disk=nvme0n1 root@ssh-enabled-server.example",
97 Args: cobra.ExactArgs(1), // One positional argument: the target
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +020098 RunE: func(cmd *cobra.Command, args []string) error {
99 user, address, err := parseSSHAddr(args[0])
100 if err != nil {
101 return err
102 }
103
104 diskName, err := cmd.Flags().GetString("disk")
105 if err != nil {
106 return err
107 }
108
109 if len(diskName) == 0 {
110 return fmt.Errorf("flag disk is required")
111 }
112
Jan Schär0175d7a2025-03-26 12:57:23 +0000113 var authMethods []ssh.AuthMethod
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200114 if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
115 defer aconn.Close()
116 a := agent.NewClient(aconn)
Jan Schär0175d7a2025-03-26 12:57:23 +0000117 authMethods = append(authMethods, ssh.PublicKeysCallback(a.Signers))
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200118 } else {
119 log.Printf("error while establishing ssh agent connection: %v", err)
120 log.Println("ssh agent authentication will not be available.")
121 }
122
Timon Stampflid7c8bbb2024-12-15 17:26:35 +0100123 // On Windows syscall.Stdin is a handle and needs to be cast to an
124 // int for term.
125 stdin := int(syscall.Stdin) // nolint:unconvert
126 if term.IsTerminal(stdin) {
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200127 authMethods = append(authMethods,
Jan Schär0175d7a2025-03-26 12:57:23 +0000128 ssh.PasswordCallback(func() (string, error) {
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200129 fmt.Printf("%s@%s's password: ", user, address)
Timon Stampflid7c8bbb2024-12-15 17:26:35 +0100130 b, err := term.ReadPassword(stdin)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200131 if err != nil {
132 return "", err
133 }
134 fmt.Println()
135 return string(b), nil
136 }),
Jan Schär0175d7a2025-03-26 12:57:23 +0000137 ssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200138 answers := make([]string, 0, len(questions))
139 for i, q := range questions {
140 fmt.Print(q)
141 if echos[i] {
142 if _, err := fmt.Scan(&questions[i]); err != nil {
143 return nil, err
144 }
145 } else {
Timon Stampflid7c8bbb2024-12-15 17:26:35 +0100146 b, err := term.ReadPassword(stdin)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200147 if err != nil {
148 return nil, err
149 }
150 fmt.Println()
151 answers = append(answers, string(b))
152 }
153 }
154 return answers, nil
155 }),
156 )
157 } else {
158 log.Println("stdin is not interactive. password authentication will not be available.")
159 }
160
Jan Schär0175d7a2025-03-26 12:57:23 +0000161 conf := &ssh.ClientConfig{
162 User: user,
163 Auth: authMethods,
164 // Ignore the host key, since it's likely the first time anything logs into
165 // this device, and also because there's no way of knowing its fingerprint.
166 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
167 // Timeout sets a bound on the time it takes to set up the connection, but
168 // not on total session time.
169 Timeout: 5 * time.Second,
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200170 }
171
172 ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
Jan Schär0175d7a2025-03-26 12:57:23 +0000173 conn, err := sshtakeover.Dial(ctx, address, conf)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200174 if err != nil {
175 return fmt.Errorf("error while establishing ssh connection: %w", err)
176 }
177
178 params, err := makeNodeParams()
179 if err != nil {
180 return err
181 }
182 rawParams, err := proto.Marshal(params)
183 if err != nil {
184 return fmt.Errorf("error while marshaling node params: %w", err)
185 }
186
187 const takeoverTargetPath = "/root/takeover"
188 const bundleTargetPath = "/root/bundle.zip"
189 bundle, err := external("bundle", "_main/metropolis/node/bundle.zip", bundlePath)
190 if err != nil {
191 return err
192 }
Jan Schärf07d1b32025-03-24 18:36:06 +0000193 takeoverPath, err := cmd.Flags().GetString("takeover")
194 if err != nil {
195 return err
196 }
197 takeover, err := external("takeover", "_main/metropolis/cli/takeover/takeover_bin_/takeover_bin", &takeoverPath)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200198 if err != nil {
199 return err
200 }
201
Jan Schär7ea9ef12025-03-27 15:57:41 +0000202 log.Println("Uploading files to target host.")
203 totalSize := takeover.Size() + bundle.Size()
204 barUpdater := startProgressbarUpdater(progressbar.DefaultBytes(totalSize))
205 defer barUpdater.stop()
206 conn.SetProgress(barUpdater.add)
Jan Schärc1b6df42025-03-20 08:52:18 +0000207
Jan Schär7ea9ef12025-03-27 15:57:41 +0000208 takeoverContent, err := takeover.Open()
209 if err != nil {
210 return err
211 }
212 err = conn.UploadExecutable(ctx, takeoverTargetPath, takeoverContent)
213 takeoverContent.Close()
214 if err != nil {
215 return fmt.Errorf("error while uploading %q: %w", takeoverTargetPath, err)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200216 }
217
Jan Schär7ea9ef12025-03-27 15:57:41 +0000218 bundleContent, err := bundle.Open()
219 if err != nil {
220 return err
221 }
222 err = conn.Upload(ctx, bundleTargetPath, bundleContent)
223 bundleContent.Close()
224 if err != nil {
225 return fmt.Errorf("error while uploading %q: %w", bundleTargetPath, err)
226 }
227
228 barUpdater.stop()
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200229
230 // Start the agent and wait for the agent's output to arrive.
231 log.Printf("Starting the takeover executable at path %q.", takeoverTargetPath)
232 _, stderr, err := conn.Execute(ctx, fmt.Sprintf("%s -disk %s", takeoverTargetPath, diskName), rawParams)
233 stderrStr := strings.TrimSpace(string(stderr))
234 if stderrStr != "" {
235 log.Printf("Agent stderr: %q", stderrStr)
236 }
237 if err != nil {
238 return fmt.Errorf("while starting the takeover executable: %w", err)
239 }
240
241 return nil
242 },
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +0100243}
244
// parseAddrOptionalPort splits addr into a host and an optional port.
//
// Accepted forms:
//   - hostnames and IPv4 addresses, with or without a ":port" suffix
//     (e.g. "example.com", "192.0.2.1:22");
//   - bare IPv6 addresses (e.g. "2001:db8::1");
//   - bracketed IPv6 addresses, with or without a port
//     (e.g. "[2001:db8::1]", "[2001:db8::1]:22").
//
// The returned port is empty when addr does not specify one. IPv6 hosts are
// returned without brackets, formatted by [netip.Addr.String]. An error is
// returned for an empty address, an empty host, or an unparseable
// multi-colon address.
func parseAddrOptionalPort(addr string) (string, string, error) {
	if addr == "" {
		return "", "", fmt.Errorf("address is empty")
	}

	switch strings.Count(addr, ":") {
	case 0:
		// IPv4, DNS without Port.
		return addr, "", nil
	case 1:
		// IPv4, DNS with Port.
		host, port, _ := strings.Cut(addr, ":")
		if host == "" {
			// Without this check ":22" would later be joined into an
			// address that dials the local machine.
			return "", "", fmt.Errorf("address %q has an empty host", addr)
		}
		return host, port, nil
	}

	// IPv6 with Port, which requires the bracketed form.
	if addrPort, err := netip.ParseAddrPort(addr); err == nil {
		return addrPort.Addr().String(), fmt.Sprintf("%d", addrPort.Port()), nil
	}

	// IPv6 without Port, either bare ("::1") or bracketed ("[::1]").
	bare := addr
	if strings.HasPrefix(bare, "[") && strings.HasSuffix(bare, "]") {
		bare = bare[1 : len(bare)-1]
	}
	if ip, err := netip.ParseAddr(bare); err == nil {
		return ip.String(), "", nil
	}

	return "", "", fmt.Errorf("failed to parse address: %q", addr)
}
273
274func parseSSHAddr(s string) (string, string, error) {
275 user, rawAddr, ok := strings.Cut(s, "@")
276 if !ok {
277 return "", "", fmt.Errorf("SSH user is mandatory")
278 }
279
280 addr, port, err := parseAddrOptionalPort(rawAddr)
281 if err != nil {
282 return "", "", err
283 }
284 if port == "" {
285 port = "22"
286 }
287
288 return user, net.JoinHostPort(addr, port), nil
289}
290
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +0100291func init() {
292 sshCmd.Flags().String("disk", "", "Which disk Metropolis should be installed to")
293 sshCmd.Flags().String("takeover", "", "Path to the Metropolis takeover binary")
294
295 installCmd.AddCommand(sshCmd)
296}