blob: 15ced4abfaed319b9ff2c4366b5a620b1a71da98 [file] [log] [blame]
Tim Windelschmidt7a1b27d2024-02-22 23:54:58 +01001package main
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "log"
8 "net"
9 "net/netip"
10 "os"
11 "os/signal"
12 "strings"
13 "syscall"
14 "time"
15
16 "github.com/schollz/progressbar/v3"
17 "github.com/spf13/cobra"
18 xssh "golang.org/x/crypto/ssh"
19 "golang.org/x/crypto/ssh/agent"
20 "golang.org/x/crypto/ssh/terminal"
21 "golang.org/x/term"
22 "google.golang.org/protobuf/proto"
23
24 "source.monogon.dev/go/net/ssh"
25 "source.monogon.dev/osbase/fat32"
26)
27
28var sshCmd = &cobra.Command{
29 Use: "ssh --disk=<disk> <target>",
30 Short: "Installs Metropolis on a Linux system accessible via SSH.",
31 Example: "metroctl install --bundle=metropolis-v0.1.zip --takeover=takeover ssh --disk=nvme0n1 root@ssh-enabled-server.example",
32 Args: cobra.ExactArgs(1), // One positional argument: the target
33 RunE: doSSH,
34}
35
// parseAddrOptionalPort splits addr into a host and an optional port.
//
// Accepted forms:
//   - "example.com", "192.0.2.1"           -> host, ""
//   - "example.com:2222", "192.0.2.1:2222" -> host, "2222"
//   - "2001:db8::1", "[2001:db8::1]"       -> "2001:db8::1", ""
//   - "[2001:db8::1]:2222"                 -> "2001:db8::1", "2222"
//
// The returned port is empty when addr does not specify one. An error is
// returned for an empty address, an empty host (e.g. ":22"), or a string with
// several colons that is not an IPv6 address with or without a port.
func parseAddrOptionalPort(addr string) (string, string, error) {
	if addr == "" {
		return "", "", fmt.Errorf("address is empty")
	}

	switch strings.Count(addr, ":") {
	case 0:
		// IPv4 or DNS name without port.
		return addr, "", nil
	case 1:
		// IPv4 or DNS name with port. An empty port is passed through so the
		// caller can apply its default.
		host, port, _ := strings.Cut(addr, ":")
		if host == "" {
			return "", "", fmt.Errorf("address %q has an empty host", addr)
		}
		return host, port, nil
	}

	// More than one colon: must be IPv6. With a port it has to be bracketed.
	if addrPort, err := netip.ParseAddrPort(addr); err == nil {
		return addrPort.Addr().String(), fmt.Sprintf("%d", addrPort.Port()), nil
	}

	// IPv6 without port, either bare ("::1") or bracketed ("[::1]").
	bare := addr
	if strings.HasPrefix(bare, "[") && strings.HasSuffix(bare, "]") {
		bare = bare[1 : len(bare)-1]
	}
	if ip, err := netip.ParseAddr(bare); err == nil {
		return ip.String(), "", nil
	}

	return "", "", fmt.Errorf("failed to parse address: %q", addr)
}
64
65func parseSSHAddr(s string) (string, string, error) {
66 user, rawAddr, ok := strings.Cut(s, "@")
67 if !ok {
68 return "", "", fmt.Errorf("SSH user is mandatory")
69 }
70
71 addr, port, err := parseAddrOptionalPort(rawAddr)
72 if err != nil {
73 return "", "", err
74 }
75 if port == "" {
76 port = "22"
77 }
78
79 return user, net.JoinHostPort(addr, port), nil
80}
81
82func doSSH(cmd *cobra.Command, args []string) error {
83 user, address, err := parseSSHAddr(args[0])
84 if err != nil {
85 return err
86 }
87
88 diskName, err := cmd.Flags().GetString("disk")
89 if err != nil {
90 return err
91 }
92
93 if len(diskName) == 0 {
94 return fmt.Errorf("flag disk is required")
95 }
96
97 var authMethods []xssh.AuthMethod
98 if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
99 defer aconn.Close()
100 a := agent.NewClient(aconn)
101 authMethods = append(authMethods, xssh.PublicKeysCallback(a.Signers))
102 } else {
103 log.Printf("error while establishing ssh agent connection: %v", err)
104 log.Println("ssh agent authentication will not be available.")
105 }
106
107 if term.IsTerminal(int(os.Stdin.Fd())) {
108 authMethods = append(authMethods,
109 xssh.PasswordCallback(func() (string, error) {
110 fmt.Printf("%s@%s's password: ", user, address)
111 b, err := terminal.ReadPassword(syscall.Stdin)
112 if err != nil {
113 return "", err
114 }
115 fmt.Println()
116 return string(b), nil
117 }),
118 xssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
119 answers := make([]string, 0, len(questions))
120 for i, q := range questions {
121 fmt.Print(q)
122 if echos[i] {
123 if _, err := fmt.Scan(&questions[i]); err != nil {
124 return nil, err
125 }
126 } else {
127 b, err := terminal.ReadPassword(syscall.Stdin)
128 if err != nil {
129 return nil, err
130 }
131 fmt.Println()
132 answers = append(answers, string(b))
133 }
134 }
135 return answers, nil
136 }),
137 )
138 } else {
139 log.Println("stdin is not interactive. password authentication will not be available.")
140 }
141
142 cl := ssh.DirectClient{
143 Username: user,
144 AuthMethods: authMethods,
145 }
146
147 ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
148 conn, err := cl.Dial(ctx, address, 5*time.Second)
149 if err != nil {
150 return fmt.Errorf("error while establishing ssh connection: %v", err)
151 }
152
153 params := makeNodeParams()
154 rawParams, err := proto.Marshal(params)
155 if err != nil {
156 return fmt.Errorf("error while marshaling node params: %v", err)
157 }
158
159 const takeoverTargetPath = "/root/takeover"
160 const bundleTargetPath = "/root/bundle.zip"
161 bundle := external("bundle", "_main/metropolis/node/bundle.zip", bundlePath)
162 takeover := external("takeover", "_main/metropolis/cli/takeover/takeover_bin_/takeover_bin", bundlePath)
163
164 barUploader := func(r fat32.SizedReader, targetPath string) {
165 bar := progressbar.DefaultBytes(
166 r.Size(),
167 targetPath,
168 )
169 defer bar.Close()
170
171 proxyReader := progressbar.NewReader(r, bar)
172 defer proxyReader.Close()
173
174 if err := conn.Upload(ctx, targetPath, &proxyReader); err != nil {
175 log.Fatalf("error while uploading %q: %v", targetPath, err)
176 }
177 }
178
179 log.Println("Uploading required binaries to target host.")
180 barUploader(takeover, takeoverTargetPath)
181 barUploader(bundle, bundleTargetPath)
182
183 // Start the agent and wait for the agent's output to arrive.
184 log.Printf("Starting the takeover executable at path %q.", takeoverTargetPath)
185 _, stderr, err := conn.Execute(ctx, fmt.Sprintf("%s -disk %s", takeoverTargetPath, diskName), rawParams)
186 stderrStr := strings.TrimSpace(string(stderr))
187 if stderrStr != "" {
188 log.Printf("Agent stderr: %q", stderrStr)
189 }
190 if err != nil {
191 return fmt.Errorf("while starting the takeover executable: %v", err)
192 }
193
194 return nil
195}
196
197func init() {
198 sshCmd.Flags().String("disk", "", "Which disk Metropolis should be installed to")
199 sshCmd.Flags().String("takeover", "", "Path to the Metropolis takeover binary")
200
201 installCmd.AddCommand(sshCmd)
202}