m/n/c/curator: implement Join Flow

This implements Join Flow in Curator, as described in the Cluster Lifecycle
and Integrity design document.

Change-Id: Idabb471575e1d22a7eb7cce2ad29d18f1f94760a
Reviewed-on: https://review.monogon.dev/c/monogon/+/667
Reviewed-by: Sergiusz Bazanski <serge@monogon.tech>
diff --git a/metropolis/node/core/curator/state_node.go b/metropolis/node/core/curator/state_node.go
index ce87723..82f9859 100644
--- a/metropolis/node/core/curator/state_node.go
+++ b/metropolis/node/core/curator/state_node.go
@@ -19,6 +19,7 @@
 import (
 	"context"
 	"crypto/x509"
+	"encoding/hex"
 	"fmt"
 
 	clientv3 "go.etcd.io/etcd/client/v3"
@@ -62,6 +63,11 @@
 	// The public key is used to generate the Node's canonical ID.
 	pubkey []byte
 
+	// jkey is the node's ED25519 public Join Key. The private part of the key
+	// never leaves the node. The key is generated by the node and passed to
+	// Curator during the registration process.
+	jkey []byte
+
 	// state is the state of this node as seen from the point of view of the
 	// cluster. See //metropolis/proto:common.proto for more information.
 	state cpb.NodeState
@@ -167,21 +173,31 @@
 }
 
 var (
+	// nodeEtcdPrefix is an etcd key prefix preceding cluster member node IDs,
+	// mapping to ppb.Node values.
 	nodeEtcdPrefix = mustNewEtcdPrefix("/nodes/")
+	// joinCredPrefix precedes hex-encoded node Join Keys, mapping to node IDs.
 	joinCredPrefix = mustNewEtcdPrefix("/join_keys/")
 )
 
-// etcdPath builds the etcd path in which this node's protobuf-serialized state
-// is stored in etcd.
-func (n *Node) etcdPath() (string, error) {
+// etcdNodePath builds the etcd path in which this node's protobuf-serialized
+// state is stored in etcd.
+func (n *Node) etcdNodePath() (string, error) {
 	return nodeEtcdPrefix.Key(n.ID())
 }
 
+// etcdJoinKeyPath builds the etcd path mapping this node's Join Key to its ID.
+func (n *Node) etcdJoinKeyPath() (string, error) {
+	return joinCredPrefix.Key(hex.EncodeToString(n.jkey))
+}
+
 // proto serializes the Node object into protobuf, to be used for saving to
 // etcd.
 func (n *Node) proto() *ppb.Node {
 	msg := &ppb.Node{
 		ClusterUnlockKey: n.clusterUnlockKey,
 		PublicKey:        n.pubkey,
+		JoinKey:          n.jkey,
 		FsmState:         n.state,
 		Roles:            &cpb.NodeRoles{},
 		Status:           n.status,
@@ -215,6 +231,7 @@
 	n := &Node{
 		clusterUnlockKey: msg.ClusterUnlockKey,
 		pubkey:           msg.PublicKey,
+		jkey:             msg.JoinKey,
 		state:            msg.FsmState,
 		status:           msg.Status,
 	}
@@ -266,7 +283,7 @@
 	rpc.Trace(ctx).Printf("loadNode(%s)...", id)
 	key, err := nodeEtcdPrefix.Key(id)
 	if err != nil {
-		// TODO(issues/85): log err
+		rpc.Trace(ctx).Printf("invalid node id: %v", err)
 		return nil, status.Errorf(codes.InvalidArgument, "invalid node id")
 	}
 	res, err := l.txnAsLeader(ctx, clientv3.OpGet(key))
@@ -274,7 +291,7 @@
 		if rpcErr, ok := rpcError(err); ok {
 			return nil, rpcErr
 		}
-		// TODO(issues/85): log this
+		rpc.Trace(ctx).Printf("could not retrieve node %s: %v", id, err)
 		return nil, status.Errorf(codes.Unavailable, "could not retrieve node %s: %v", id, err)
 	}
 	kvs := res.Responses[0].GetResponseRange().Kvs
@@ -284,7 +301,7 @@
 	}
 	node, err := nodeUnmarshal(kvs[0].Value)
 	if err != nil {
-		// TODO(issues/85): log this
+		rpc.Trace(ctx).Printf("could not unmarshal node: %v", err)
 		return nil, status.Errorf(codes.Unavailable, "could not unmarshal node")
 	}
 	rpc.Trace(ctx).Printf("loadNode(%s): unmarshal ok", id)
@@ -294,26 +311,80 @@
 // nodeSave attempts to save a node into etcd, within a given active leadership.
 // All returned errors are gRPC statuses that safe to return to untrusted callers.
 func nodeSave(ctx context.Context, l *leadership, n *Node) error {
+	// Build an etcd operation to save the node with a key based on its ID.
 	id := n.ID()
 	rpc.Trace(ctx).Printf("nodeSave(%s)...", id)
-	key, err := nodeEtcdPrefix.Key(id)
+	nkey, err := nodeEtcdPrefix.Key(id)
 	if err != nil {
-		// TODO(issues/85): log err
+		rpc.Trace(ctx).Printf("invalid node id: %v", err)
 		return status.Errorf(codes.InvalidArgument, "invalid node id")
 	}
 	nodeBytes, err := proto.Marshal(n.proto())
 	if err != nil {
-		// TODO(issues/85): log this
+		rpc.Trace(ctx).Printf("could not marshal updated node: %v", err)
 		return status.Errorf(codes.Unavailable, "could not marshal updated node")
 	}
-	_, err = l.txnAsLeader(ctx, clientv3.OpPut(key, string(nodeBytes)))
+	ons := clientv3.OpPut(nkey, string(nodeBytes))
+	ops := []clientv3.Op{ons}
+
+	// Build an etcd operation to map the node's Join Key into its ID for use in
+	// Join Flow, if jkey is non-empty. Once Join Flow is implemented on the
+	// client side, this operation will become mandatory.
+	if len(n.jkey) != 0 {
+		jpath, err := n.etcdJoinKeyPath()
+		if err != nil {
+			// This should never happen.
+			rpc.Trace(ctx).Printf("invalid join key representation: %v", err)
+			return status.Errorf(codes.InvalidArgument, "invalid join key representation")
+		}
+		// TODO(mateusz@monogon.tech): ensure that if the join key index already
+		// exists, it points to the node we're saving. Refuse to save/update the
+		// node if it doesn't.
+		oks := clientv3.OpPut(jpath, id)
+		ops = append(ops, oks)
+	}
+
+	// Execute one or both operations atomically.
+	_, err = l.txnAsLeader(ctx, ops...)
 	if err != nil {
 		if rpcErr, ok := rpcError(err); ok {
 			return rpcErr
 		}
-		// TODO(issues/85): log this
+		rpc.Trace(ctx).Printf("could not save updated node: %v", err)
 		return status.Error(codes.Unavailable, "could not save updated node")
 	}
 	rpc.Trace(ctx).Printf("nodeSave(%s): write ok", id)
 	return nil
 }
+
+// nodeIdByJoinKey attempts to fetch a Node ID corresponding to the given Join
+// Key from etcd, within a given active leadership. All returned errors are
+// gRPC statuses that are safe to return to untrusted callers. If the given
+// Join Key is not found, errNodeNotFound will be returned along with an empty
+// string.
+func nodeIdByJoinKey(ctx context.Context, l *leadership, jkey []byte) (string, error) {
+	if len(jkey) == 0 {
+		return "", status.Errorf(codes.InvalidArgument, "join key is empty")
+	}
+
+	cred := hex.EncodeToString(jkey)
+	key, err := joinCredPrefix.Key(cred)
+	if err != nil {
+		// This should never happen.
+		rpc.Trace(ctx).Printf("invalid join key representation: %v", err)
+		return "", status.Errorf(codes.InvalidArgument, "invalid join key representation")
+	}
+	res, err := l.txnAsLeader(ctx, clientv3.OpGet(key))
+	if err != nil {
+		if rpcErr, ok := rpcError(err); ok {
+			return "", rpcErr
+		}
+		rpc.Trace(ctx).Printf("could not retrieve node id matching join key %s: %v", cred, err)
+		return "", status.Errorf(codes.Unavailable, "could not retrieve node id matching join key %s: %v", cred, err)
+	}
+	kvs := res.Responses[0].GetResponseRange().Kvs
+	if len(kvs) != 1 {
+		return "", errNodeNotFound
+	}
+	return string(kvs[0].Value), nil
+}