blob: 014e6b1c207f4d8c0c3a3265df4efd693b7b7ad6 [file] [log] [blame]
Tim Windelschmidt6d33a432025-02-04 14:34:25 +01001// Copyright The Monogon Project Authors.
2// SPDX-License-Identifier: Apache-2.0
3
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +01004package main
5
6import (
7 "context"
8 _ "embed"
9 "fmt"
10 "log"
11 "net"
12 "net/netip"
13 "os"
14 "os/signal"
15 "strings"
16 "syscall"
17 "time"
18
19 "github.com/schollz/progressbar/v3"
20 "github.com/spf13/cobra"
Jan Schär0175d7a2025-03-26 12:57:23 +000021 "golang.org/x/crypto/ssh"
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +010022 "golang.org/x/crypto/ssh/agent"
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +010023 "golang.org/x/term"
24 "google.golang.org/protobuf/proto"
25
Jan Schär0175d7a2025-03-26 12:57:23 +000026 "source.monogon.dev/osbase/net/sshtakeover"
Jan Schär5fdca562025-04-14 11:33:29 +000027 "source.monogon.dev/osbase/oci"
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +010028)
29
Jan Schär7ea9ef12025-03-27 15:57:41 +000030// progressbarUpdater wraps a [progressbar.ProgressBar] with an improved
31// interface for updating progress. It updates the progress bar in a separate
32// goroutine and at most 60 times per second. The stop function stops the
33// updates and can be safely called multiple times.
34type progressbarUpdater struct {
35 bar *progressbar.ProgressBar
36 update chan int64
37 close chan struct{}
38}
39
40func startProgressbarUpdater(bar *progressbar.ProgressBar) *progressbarUpdater {
41 updater := &progressbarUpdater{
42 bar: bar,
43 update: make(chan int64, 1),
44 close: make(chan struct{}),
45 }
46 go updater.run()
47 return updater
48}
49
50func (p *progressbarUpdater) add(num int64) {
51 for {
52 select {
53 case p.update <- num:
54 return
55 case oldNum := <-p.update:
56 num += oldNum
57 }
58 }
59}
60
61func (p *progressbarUpdater) run() {
62 for {
63 select {
64 case num := <-p.update:
65 p.bar.Add64(num)
66 case <-p.close:
67 return
68 }
69 select {
70 case <-time.After(time.Second / 60):
71 case <-p.close:
72 return
73 }
74 }
75}
76
77func (p *progressbarUpdater) stop() {
78 if p.close == nil {
79 return
80 }
81 p.close <- struct{}{}
82 p.close = nil
83 select {
84 case num := <-p.update:
85 // Do one last update to make the bar reach 100%.
86 p.bar.Add64(num)
87 default:
88 }
89 if !p.bar.IsFinished() {
90 p.bar.Exit()
91 }
92}
93
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +010094var sshCmd = &cobra.Command{
95 Use: "ssh --disk=<disk> <target>",
96 Short: "Installs Metropolis on a Linux system accessible via SSH.",
Jan Schär5fdca562025-04-14 11:33:29 +000097 Example: "metroctl install --image=metropolis-v0.1 --takeover=takeover ssh --disk=nvme0n1 root@ssh-enabled-server.example",
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +010098 Args: cobra.ExactArgs(1), // One positional argument: the target
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +020099 RunE: func(cmd *cobra.Command, args []string) error {
100 user, address, err := parseSSHAddr(args[0])
101 if err != nil {
102 return err
103 }
104
105 diskName, err := cmd.Flags().GetString("disk")
106 if err != nil {
107 return err
108 }
109
110 if len(diskName) == 0 {
111 return fmt.Errorf("flag disk is required")
112 }
113
Jan Schär0175d7a2025-03-26 12:57:23 +0000114 var authMethods []ssh.AuthMethod
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200115 if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
116 defer aconn.Close()
117 a := agent.NewClient(aconn)
Jan Schär0175d7a2025-03-26 12:57:23 +0000118 authMethods = append(authMethods, ssh.PublicKeysCallback(a.Signers))
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200119 } else {
120 log.Printf("error while establishing ssh agent connection: %v", err)
121 log.Println("ssh agent authentication will not be available.")
122 }
123
Timon Stampflid7c8bbb2024-12-15 17:26:35 +0100124 // On Windows syscall.Stdin is a handle and needs to be cast to an
125 // int for term.
126 stdin := int(syscall.Stdin) // nolint:unconvert
127 if term.IsTerminal(stdin) {
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200128 authMethods = append(authMethods,
Jan Schär0175d7a2025-03-26 12:57:23 +0000129 ssh.PasswordCallback(func() (string, error) {
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200130 fmt.Printf("%s@%s's password: ", user, address)
Timon Stampflid7c8bbb2024-12-15 17:26:35 +0100131 b, err := term.ReadPassword(stdin)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200132 if err != nil {
133 return "", err
134 }
135 fmt.Println()
136 return string(b), nil
137 }),
Jan Schär0175d7a2025-03-26 12:57:23 +0000138 ssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200139 answers := make([]string, 0, len(questions))
140 for i, q := range questions {
141 fmt.Print(q)
142 if echos[i] {
143 if _, err := fmt.Scan(&questions[i]); err != nil {
144 return nil, err
145 }
146 } else {
Timon Stampflid7c8bbb2024-12-15 17:26:35 +0100147 b, err := term.ReadPassword(stdin)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200148 if err != nil {
149 return nil, err
150 }
151 fmt.Println()
152 answers = append(answers, string(b))
153 }
154 }
155 return answers, nil
156 }),
157 )
158 } else {
159 log.Println("stdin is not interactive. password authentication will not be available.")
160 }
161
Jan Schär0175d7a2025-03-26 12:57:23 +0000162 conf := &ssh.ClientConfig{
163 User: user,
164 Auth: authMethods,
165 // Ignore the host key, since it's likely the first time anything logs into
166 // this device, and also because there's no way of knowing its fingerprint.
167 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
168 // Timeout sets a bound on the time it takes to set up the connection, but
169 // not on total session time.
170 Timeout: 5 * time.Second,
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200171 }
172
173 ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
Jan Schär0175d7a2025-03-26 12:57:23 +0000174 conn, err := sshtakeover.Dial(ctx, address, conf)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200175 if err != nil {
176 return fmt.Errorf("error while establishing ssh connection: %w", err)
177 }
178
179 params, err := makeNodeParams()
180 if err != nil {
181 return err
182 }
183 rawParams, err := proto.Marshal(params)
184 if err != nil {
185 return fmt.Errorf("error while marshaling node params: %w", err)
186 }
187
188 const takeoverTargetPath = "/root/takeover"
Jan Schär5fdca562025-04-14 11:33:29 +0000189 const imageTargetPath = "/root/osimage"
190
191 imagePathResolved, err := external("image", "_main/metropolis/node/oci_image", imagePath)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200192 if err != nil {
193 return err
194 }
Jan Schär5fdca562025-04-14 11:33:29 +0000195 image, err := oci.ReadLayout(imagePathResolved)
196 if err != nil {
197 return fmt.Errorf("failed to read OS image: %w", err)
198 }
199 imageLayout, err := oci.CreateLayout(image)
200 if err != nil {
201 return fmt.Errorf("failed to read OS image: %w", err)
202 }
Jan Schärf07d1b32025-03-24 18:36:06 +0000203 takeoverPath, err := cmd.Flags().GetString("takeover")
204 if err != nil {
205 return err
206 }
Jan Schär2b9a0a02025-07-09 07:54:12 +0000207 takeover, err := externalFile("takeover", "_main/metropolis/cli/takeover/takeover_/takeover", &takeoverPath)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200208 if err != nil {
209 return err
210 }
211
Jan Schär7ea9ef12025-03-27 15:57:41 +0000212 log.Println("Uploading files to target host.")
Jan Schär5fdca562025-04-14 11:33:29 +0000213 totalSize := takeover.Size()
214 for _, entry := range imageLayout.Walk() {
215 if entry.Mode.IsRegular() {
216 totalSize += entry.Content.Size()
217 }
218 }
Jan Schär7ea9ef12025-03-27 15:57:41 +0000219 barUpdater := startProgressbarUpdater(progressbar.DefaultBytes(totalSize))
220 defer barUpdater.stop()
221 conn.SetProgress(barUpdater.add)
Jan Schärc1b6df42025-03-20 08:52:18 +0000222
Jan Schär7ea9ef12025-03-27 15:57:41 +0000223 takeoverContent, err := takeover.Open()
224 if err != nil {
225 return err
226 }
227 err = conn.UploadExecutable(ctx, takeoverTargetPath, takeoverContent)
228 takeoverContent.Close()
229 if err != nil {
230 return fmt.Errorf("error while uploading %q: %w", takeoverTargetPath, err)
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200231 }
232
Jan Schär5fdca562025-04-14 11:33:29 +0000233 err = conn.UploadTree(ctx, imageTargetPath, imageLayout)
Jan Schär7ea9ef12025-03-27 15:57:41 +0000234 if err != nil {
Jan Schär5fdca562025-04-14 11:33:29 +0000235 return fmt.Errorf("error while uploading OS image: %w", err)
Jan Schär7ea9ef12025-03-27 15:57:41 +0000236 }
237
238 barUpdater.stop()
Tim Windelschmidt0b4fb8c2024-09-18 17:34:23 +0200239
240 // Start the agent and wait for the agent's output to arrive.
241 log.Printf("Starting the takeover executable at path %q.", takeoverTargetPath)
242 _, stderr, err := conn.Execute(ctx, fmt.Sprintf("%s -disk %s", takeoverTargetPath, diskName), rawParams)
243 stderrStr := strings.TrimSpace(string(stderr))
244 if stderrStr != "" {
245 log.Printf("Agent stderr: %q", stderrStr)
246 }
247 if err != nil {
248 return fmt.Errorf("while starting the takeover executable: %w", err)
249 }
250
251 return nil
252 },
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +0100253}
254
// parseAddrOptionalPort splits addr into a host and an optional port. It
// accepts:
//   - a DNS name or IPv4 address, optionally followed by ":port"
//     ("example.com", "192.0.2.1:22");
//   - an IPv6 address in brackets followed by ":port" ("[2001:db8::1]:22");
//   - an IPv6 address without a port, bare ("2001:db8::1") or in brackets
//     ("[2001:db8::1]").
//
// The returned port is empty when addr does not specify one. For DNS/IPv4
// input the port text is returned as-is without validation. IPv6 hosts are
// returned without brackets, in the form produced by [netip.Addr.String].
// An error is returned for an empty address, an empty host (":22"), or input
// that matches none of the forms above.
func parseAddrOptionalPort(addr string) (string, string, error) {
	if addr == "" {
		return "", "", fmt.Errorf("address is empty")
	}

	idx := strings.LastIndex(addr, ":")
	// IPv4, DNS without Port.
	if idx == -1 {
		return addr, "", nil
	}

	// IPv4, DNS with Port.
	if strings.Count(addr, ":") == 1 {
		host := addr[:idx]
		// An empty host would later be joined into ":port", which dials the
		// local machine instead of the intended target.
		if host == "" {
			return "", "", fmt.Errorf("address %q has an empty host", addr)
		}
		return host, addr[idx+1:], nil
	}

	// IPv6 with Port.
	if addrPort, err := netip.ParseAddrPort(addr); err == nil {
		return addrPort.Addr().String(), fmt.Sprintf("%d", addrPort.Port()), nil
	}

	// IPv6 without Port, optionally enclosed in brackets.
	bare := addr
	if len(bare) >= 2 && bare[0] == '[' && bare[len(bare)-1] == ']' {
		bare = bare[1 : len(bare)-1]
	}
	if ip, err := netip.ParseAddr(bare); err == nil {
		return ip.String(), "", nil
	}

	return "", "", fmt.Errorf("failed to parse address: %q", addr)
}
283
284func parseSSHAddr(s string) (string, string, error) {
285 user, rawAddr, ok := strings.Cut(s, "@")
286 if !ok {
287 return "", "", fmt.Errorf("SSH user is mandatory")
288 }
289
290 addr, port, err := parseAddrOptionalPort(rawAddr)
291 if err != nil {
292 return "", "", err
293 }
294 if port == "" {
295 port = "22"
296 }
297
298 return user, net.JoinHostPort(addr, port), nil
299}
300
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +0100301func init() {
302 sshCmd.Flags().String("disk", "", "Which disk Metropolis should be installed to")
303 sshCmd.Flags().String("takeover", "", "Path to the Metropolis takeover binary")
304
305 installCmd.AddCommand(sshCmd)
306}