m/n/c/curator: factor out node{Load,Save}

This removes some duplicated logic between RPC implementations.

Change-Id: I3683ba11635a53f792def4d8dabddc09776ab427
Reviewed-on: https://review.monogon.dev/c/monogon/+/447
Reviewed-by: Mateusz Zalega <mateusz@monogon.tech>
diff --git a/metropolis/node/core/curator/impl_leader_curator.go b/metropolis/node/core/curator/impl_leader_curator.go
index c54dc11..6241827 100644
--- a/metropolis/node/core/curator/impl_leader_curator.go
+++ b/metropolis/node/core/curator/impl_leader_curator.go
@@ -5,7 +5,6 @@
 	"crypto/subtle"
 	"fmt"
 
-	"go.etcd.io/etcd/clientv3"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 	"google.golang.org/protobuf/proto"
@@ -234,38 +233,15 @@
 	defer l.muNodes.Unlock()
 
 	// Retrieve node ...
-	key, err := nodeEtcdPrefix.Key(id)
+	node, err := nodeLoad(ctx, l.leadership, id)
 	if err != nil {
-		return nil, status.Errorf(codes.InvalidArgument, "invalid node id")
-	}
-	res, err := l.txnAsLeader(ctx, clientv3.OpGet(key))
-	if err != nil {
-		if rpcErr, ok := rpcError(err); ok {
-			return nil, rpcErr
-		}
-		return nil, status.Errorf(codes.Unavailable, "could not retrieve node: %v", err)
-	}
-	kvs := res.Responses[0].GetResponseRange().Kvs
-	if len(kvs) < 1 {
-		return nil, status.Error(codes.NotFound, "no such node")
-	}
-	node, err := nodeUnmarshal(kvs[0].Value)
-	if err != nil {
-		return nil, status.Errorf(codes.Unavailable, "failed to unmarshal node: %v", err)
+		return nil, err
 	}
 	// ... update its' status ...
 	node.status = req.Status
 	// ... and save it to etcd.
-	bytes, err := proto.Marshal(node.proto())
-	if err != nil {
-		return nil, status.Errorf(codes.Unavailable, "failed to marshal node: %v", err)
-	}
-	_, err = l.txnAsLeader(ctx, clientv3.OpPut(key, string(bytes)))
-	if err != nil {
-		if rpcErr, ok := rpcError(err); ok {
-			return nil, rpcErr
-		}
-		return nil, status.Errorf(codes.Unavailable, "could not update node: %v", err)
+	if err := nodeSave(ctx, l.leadership, node); err != nil {
+		return nil, err
 	}
 
 	return &ipb.UpdateNodeStatusResponse{}, nil
@@ -302,26 +278,8 @@
 
 	// Check if there already is a node with this pubkey in the cluster.
 	id := identity.NodeID(pubkey)
-	key, err := nodeEtcdPrefix.Key(id)
-	if err != nil {
-		// TODO(issues/85): log err
-		return nil, status.Errorf(codes.InvalidArgument, "invalid node id")
-	}
-	res, err := l.txnAsLeader(ctx, clientv3.OpGet(key))
-	if err != nil {
-		if rpcErr, ok := rpcError(err); ok {
-			return nil, rpcErr
-		}
-		// TODO(issues/85): log this
-		return nil, status.Errorf(codes.Unavailable, "could not retrieve node %s: %v", id, err)
-	}
-	kvs := res.Responses[0].GetResponseRange().Kvs
-	if len(kvs) > 0 {
-		node, err := nodeUnmarshal(kvs[0].Value)
-		if err != nil {
-			// TODO(issues/85): log this
-			return nil, status.Errorf(codes.Unavailable, "could not unmarshal node")
-		}
+	node, err := nodeLoad(ctx, l.leadership, id)
+	if err == nil {
 		// If the existing node is in the NEW state already, there's nothing to do,
 		// return no error. This can happen in case of spurious retries from the calling
 		// node.
@@ -336,23 +294,18 @@
 		// TODO(issues/85): log this
 		return nil, status.Errorf(codes.FailedPrecondition, "node already exists in cluster, state %s", node.state.String())
 	}
+	if err != errNodeNotFound {
+		return nil, err
+	}
 
 	// No node exists, create one.
-	node := &Node{
+	node = &Node{
 		pubkey: pubkey,
 		state:  cpb.NodeState_NODE_STATE_NEW,
 	}
-	nodeBytes, err := proto.Marshal(node.proto())
-	if err != nil {
-		// TODO(issues/85): log this
-		return nil, status.Errorf(codes.Unavailable, "could not marshal new node")
+	if err := nodeSave(ctx, l.leadership, node); err != nil {
+		return nil, err
 	}
-	_, err = l.txnAsLeader(ctx, clientv3.OpPut(key, string(nodeBytes)))
-	if err != nil {
-		// TODO(issues/85): log this
-		return nil, status.Error(codes.Unavailable, "could not save new node")
-	}
-
 	return &ipb.RegisterNodeResponse{}, nil
 }
 
@@ -373,33 +326,14 @@
 	l.muNodes.Lock()
 	defer l.muNodes.Unlock()
 
-	// Check if there is a node with this pubkey in the cluster.
-	id := identity.NodeID(pubkey)
-	key, err := nodeEtcdPrefix.Key(id)
-	if err != nil {
-		// TODO(issues/85): log err
-		return nil, status.Errorf(codes.InvalidArgument, "invalid node id")
-	}
-
 	// Retrieve the node and act on its current state, either returning early or
 	// mutating it and continuing with the rest of the Commit logic.
-	res, err := l.txnAsLeader(ctx, clientv3.OpGet(key))
+	id := identity.NodeID(pubkey)
+	node, err := nodeLoad(ctx, l.leadership, id)
 	if err != nil {
-		if rpcErr, ok := rpcError(err); ok {
-			return nil, rpcErr
-		}
-		// TODO(issues/85): log this
-		return nil, status.Errorf(codes.Unavailable, "could not retrieve node %s: %v", id, err)
+		return nil, err
 	}
-	kvs := res.Responses[0].GetResponseRange().Kvs
-	if len(kvs) != 1 {
-		return nil, status.Errorf(codes.NotFound, "node %s not found", id)
-	}
-	node, err := nodeUnmarshal(kvs[0].Value)
-	if err != nil {
-		// TODO(issues/85): log this
-		return nil, status.Errorf(codes.Unavailable, "could not unmarshal node")
-	}
+
 	switch node.state {
 	case cpb.NodeState_NODE_STATE_NEW:
 		return nil, status.Error(codes.PermissionDenied, "node is NEW, wait for attestation/approval")
@@ -437,15 +371,8 @@
 	node.state = cpb.NodeState_NODE_STATE_UP
 	node.clusterUnlockKey = req.ClusterUnlockKey
 
-	nodeBytes, err := proto.Marshal(node.proto())
-	if err != nil {
-		// TODO(issues/85): log this
-		return nil, status.Errorf(codes.Unavailable, "could not marshal updated node")
-	}
-	_, err = l.txnAsLeader(ctx, clientv3.OpPut(key, string(nodeBytes)))
-	if err != nil {
-		// TODO(issues/85): log this
-		return nil, status.Error(codes.Unavailable, "could not save updated node")
+	if err := nodeSave(ctx, l.leadership, node); err != nil {
+		return nil, err
 	}
 
 	// From this point on, any failure (in the server, or in the network, ...) dooms
diff --git a/metropolis/node/core/curator/impl_leader_management.go b/metropolis/node/core/curator/impl_leader_management.go
index b92229f..710eb93 100644
--- a/metropolis/node/core/curator/impl_leader_management.go
+++ b/metropolis/node/core/curator/impl_leader_management.go
@@ -6,10 +6,8 @@
 	"crypto/ed25519"
 	"sort"
 
-	"go.etcd.io/etcd/clientv3"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
-	"google.golang.org/protobuf/proto"
 
 	"source.monogon.dev/metropolis/node/core/identity"
 	apb "source.monogon.dev/metropolis/proto/api"
@@ -173,21 +171,9 @@
 
 	// Find node for this pubkey.
 	id := identity.NodeID(req.Pubkey)
-	key, err := nodeEtcdPrefix.Key(id)
+	node, err := nodeLoad(ctx, l.leadership, id)
 	if err != nil {
-		return nil, status.Errorf(codes.InvalidArgument, "pubkey invalid: %v", err)
-	}
-	res, err := l.txnAsLeader(ctx, clientv3.OpGet(key))
-	if err != nil {
-		return nil, status.Errorf(codes.Unavailable, "could not retrieve node: %v", err)
-	}
-	kvs := res.Responses[0].GetResponseRange().Kvs
-	if len(kvs) != 1 {
-		return nil, status.Errorf(codes.NotFound, "node with given pubkey not found")
-	}
-	node, err := nodeUnmarshal(kvs[0].Value)
-	if err != nil {
-		return nil, status.Errorf(codes.Internal, "could not deserialize node: %v", err)
+		return nil, err
 	}
 
 	// Ensure node is either UP/STANDBY (no-op) or NEW (set to STANDBY).
@@ -203,15 +189,8 @@
 
 	// Set node to be STANDBY.
 	node.state = cpb.NodeState_NODE_STATE_STANDBY
-	nodeBytes, err := proto.Marshal(node.proto())
-	if err != nil {
-		// TODO(issues/85): log this
-		return nil, status.Errorf(codes.Unavailable, "could not marshal updated node")
-	}
-	_, err = l.txnAsLeader(ctx, clientv3.OpPut(key, string(nodeBytes)))
-	if err != nil {
-		// TODO(issues/85): log this
-		return nil, status.Error(codes.Unavailable, "could not save updated node")
+	if err := nodeSave(ctx, l.leadership, node); err != nil {
+		return nil, err
 	}
 
 	return &apb.ApproveNodeResponse{}, nil
diff --git a/metropolis/node/core/curator/state_node.go b/metropolis/node/core/curator/state_node.go
index 69db4f7..c391e73 100644
--- a/metropolis/node/core/curator/state_node.go
+++ b/metropolis/node/core/curator/state_node.go
@@ -17,8 +17,12 @@
 package curator
 
 import (
+	"context"
 	"fmt"
 
+	"go.etcd.io/etcd/clientv3"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 	"google.golang.org/protobuf/proto"
 
 	ppb "source.monogon.dev/metropolis/node/core/curator/proto/private"
@@ -151,3 +155,62 @@
 	}
 	return n, nil
 }
+
+var (
+	errNodeNotFound = status.Error(codes.NotFound, "node not found")
+)
+
+// nodeLoad attempts to load a node by ID from etcd, within a given active
+// leadership. All returned errors are gRPC statuses that are safe to return to
+// untrusted callers. If the given node is not found, errNodeNotFound will be
+// returned.
+func nodeLoad(ctx context.Context, l *leadership, id string) (*Node, error) {
+	key, err := nodeEtcdPrefix.Key(id)
+	if err != nil {
+		// TODO(issues/85): log err
+		return nil, status.Errorf(codes.InvalidArgument, "invalid node id")
+	}
+	res, err := l.txnAsLeader(ctx, clientv3.OpGet(key))
+	if err != nil {
+		if rpcErr, ok := rpcError(err); ok {
+			return nil, rpcErr
+		}
+		// TODO(issues/85): log this
+		return nil, status.Errorf(codes.Unavailable, "could not retrieve node %s: %v", id, err)
+	}
+	kvs := res.Responses[0].GetResponseRange().Kvs
+	if len(kvs) != 1 {
+		return nil, errNodeNotFound
+	}
+	node, err := nodeUnmarshal(kvs[0].Value)
+	if err != nil {
+		// TODO(issues/85): log this
+		return nil, status.Errorf(codes.Unavailable, "could not unmarshal node")
+	}
+	return node, nil
+}
+
+// nodeSave attempts to save a node into etcd, within a given active leadership.
+// All returned errors are gRPC statuses that safe to return to untrusted callers.
+func nodeSave(ctx context.Context, l *leadership, n *Node) error {
+	id := n.ID()
+	key, err := nodeEtcdPrefix.Key(id)
+	if err != nil {
+		// TODO(issues/85): log err
+		return status.Errorf(codes.InvalidArgument, "invalid node id")
+	}
+	nodeBytes, err := proto.Marshal(n.proto())
+	if err != nil {
+		// TODO(issues/85): log this
+		return status.Errorf(codes.Unavailable, "could not marshal updated node")
+	}
+	_, err = l.txnAsLeader(ctx, clientv3.OpPut(key, string(nodeBytes)))
+	if err != nil {
+		if rpcErr, ok := rpcError(err); ok {
+			return rpcErr
+		}
+		// TODO(issues/85): log this
+		return status.Error(codes.Unavailable, "could not save updated node")
+	}
+	return nil
+}