m/n/core/curator: prevent nodes from sharing wireguard keys

As WireGuard identifies peers by their public key, we must do our best to
never let two nodes share the same public key.
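
The check itself boils down to a first-writer-wins map lookup. As a rough,
standalone sketch (identifiers are illustrative; the real implementation
below keeps the cache on the curator leader, loads it from etcd, and guards
it with muNodes):

    package main

    import "fmt"

    // keyOwners maps WireGuard public keys to the ID of the node that
    // claimed them first.
    var keyOwners = map[string]string{}

    // claimKey rejects pubkey if a different node already uses it;
    // otherwise it records the claim. A node re-reporting its own key
    // is accepted.
    func claimKey(nodeID, pubkey string) error {
        if owner, ok := keyOwners[pubkey]; ok && owner != nodeID {
            return fmt.Errorf("public key already used by node %s", owner)
        }
        keyOwners[pubkey] = nodeID
        return nil
    }

    func main() {
        fmt.Println(claimKey("node-a", "k1=")) // <nil>
        fmt.Println(claimKey("node-a", "k1=")) // <nil>, same node
        fmt.Println(claimKey("node-b", "k1=")) // error, key already taken
    }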

Change-Id: Ib8bc9b839355c1ee94dcf3ba42368055b47c2c21
Reviewed-on: https://review.monogon.dev/c/monogon/+/1415
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
Tested-by: Jenkins CI
diff --git a/metropolis/node/core/curator/impl_leader.go b/metropolis/node/core/curator/impl_leader.go
index ffb927b..2a66e29 100644
--- a/metropolis/node/core/curator/impl_leader.go
+++ b/metropolis/node/core/curator/impl_leader.go
@@ -29,6 +29,14 @@
 	// startTs is a local monotonic clock timestamp associated with this node's
 	// assumption of Curator leadership.
 	startTs time.Time
+
+	// clusternetCache maps WireGuard public keys (as strings) to node IDs. It is
+	// used to detect possibly re-used WireGuard public keys without having to
+	// fetch all nodes from etcd.
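+	//
+	// The cache is built lazily by prepareClusternetCacheUnlocked on first use,
+	// and kept up to date whenever a node's key is changed through this curator.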
+	clusternetCache map[string]string
 }
 
 // leadership represents the curator leader's ability to perform actions as a
diff --git a/metropolis/node/core/curator/impl_leader_cluster_networking.go b/metropolis/node/core/curator/impl_leader_cluster_networking.go
index 51d813a..80da0d3 100644
--- a/metropolis/node/core/curator/impl_leader_cluster_networking.go
+++ b/metropolis/node/core/curator/impl_leader_cluster_networking.go
@@ -10,10 +10,61 @@
 
 	"source.monogon.dev/metropolis/node/core/identity"
 	"source.monogon.dev/metropolis/node/core/rpc"
+	"source.monogon.dev/metropolis/pkg/event"
+	"source.monogon.dev/metropolis/pkg/event/etcd"
 
 	ipb "source.monogon.dev/metropolis/node/core/curator/proto/api"
 )
 
+// prepareClusternetCacheUnlocked ensures the leader's clusternetCache exists,
+// loading it from etcd if it hasn't been built yet.
+func (l *leaderCurator) prepareClusternetCacheUnlocked(ctx context.Context) error {
+	if l.ls.clusternetCache != nil {
+		return nil
+	}
+
+	cache := make(map[string]string)
+
+	// Get all nodes.
+	start, end := nodeEtcdPrefix.KeyRange()
+	value := etcd.NewValue[*nodeAtID](l.etcd, start, nodeValueConverter, etcd.Range(end))
+	w := value.Watch()
+	defer w.Close()
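+	// Drain the backlog: with BacklogOnly, Get returns only data already
+	// stored in etcd and fails with BacklogDone once that existing data has
+	// been consumed, instead of blocking for new updates.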
+	for {
+		nodeKV, err := w.Get(ctx, event.BacklogOnly[*nodeAtID]())
+		if err == event.BacklogDone {
+			break
+		}
+		if err != nil {
+			rpc.Trace(ctx).Printf("etcd watch failed (initial fetch): %v", err)
+			return status.Error(codes.Unavailable, "internal error during clusternet cache load")
+		}
+		n := nodeKV.value
+		if n == nil {
+			continue
+		}
+
+		// Ignore nodes without cluster networking.
+		if n.wireguardKey == "" {
+			continue
+		}
+
+		// If we have an inconsistency in the database, just pretend it's not there.
+		//
+		// TODO(q3k): try to recover from this.
+		if id, ok := cache[n.wireguardKey]; ok && id != n.ID() {
+			continue
+		}
+		cache[n.wireguardKey] = n.ID()
+	}
+
+	l.ls.clusternetCache = cache
+	return nil
+}
+
 func (l *leaderCurator) UpdateNodeClusterNetworking(ctx context.Context, req *ipb.UpdateNodeClusterNetworkingRequest) (*ipb.UpdateNodeClusterNetworkingResponse, error) {
 	// Ensure that the given node_id matches the calling node. We currently
 	// only allow for direct self-reporting of status by nodes.
@@ -30,19 +81,28 @@
 	if cn.WireguardPubkey == "" {
 		return nil, status.Error(codes.InvalidArgument, "clusternet.wireguard_pubkey must be set")
 	}
-	_, err := wgtypes.ParseKey(cn.WireguardPubkey)
+	key, err := wgtypes.ParseKey(cn.WireguardPubkey)
 	if err != nil {
 		return nil, status.Error(codes.InvalidArgument, "clusternet.wireguard_pubkey must be a valid wireguard public key")
 	}
 
-	// TODO(q3k): unhardcode this and synchronize with Kubernetes code.
-	clusterNet := netip.MustParsePrefix("10.0.0.0/16")
-
-	// Update node with new clusternetworking data. We're doing a load/modify/store,
-	// so lock here.
+	// Lock everything, as we're doing a complex read/modify/store here.
 	l.muNodes.Lock()
 	defer l.muNodes.Unlock()
 
+	if err := l.prepareClusternetCacheUnlocked(ctx); err != nil {
+		return nil, err
+	}
+
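+	// Reject the key if the cache says it's already in use by a different node.
+	// A node re-reporting its own current key (nid == id) is fine.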
+	if nid, ok := l.ls.clusternetCache[key.String()]; ok && nid != id {
+		return nil, status.Error(codes.InvalidArgument, "public key alread used by another node")
+	}
+
+	// TODO(q3k): unhardcode this and synchronize with Kubernetes code.
+	clusterNet := netip.MustParsePrefix("10.0.0.0/16")
+
 	// Retrieve node ...
 	node, err := nodeLoad(ctx, l.leadership, id)
 	if err != nil {
@@ -88,13 +143,15 @@
 
 	}
 
-	// ... update its' clusternetworking bits ...
-	node.wireguardKey = cn.WireguardPubkey
+	// Modify and save node.
+	node.wireguardKey = key.String()
 	node.networkPrefixes = prefixes
-	// ... and save it to etcd.
 	if err := nodeSave(ctx, l.leadership, node); err != nil {
 		return nil, err
 	}
 
+	// Now that etcd is saved, also modify our cache.
+	l.ls.clusternetCache[key.String()] = id
+
 	return &ipb.UpdateNodeClusterNetworkingResponse{}, nil
 }
diff --git a/metropolis/node/core/curator/impl_leader_test.go b/metropolis/node/core/curator/impl_leader_test.go
index be97b63..465de13 100644
--- a/metropolis/node/core/curator/impl_leader_test.go
+++ b/metropolis/node/core/curator/impl_leader_test.go
@@ -1306,6 +1306,14 @@
 	ctx, ctxC := context.WithCancel(context.Background())
 	defer ctxC()
 
+	// Make another fake node out-of-band. We'll use it at the end to make sure
+	// that we can't add a second node with the same pubkey. We have to do this as
+	// early as possible, before the leader builds its clusternet cache.
+	//
+	// TODO(q3k): implement adding more nodes in harness, and just add another node
+	// normally. This will actually exercise the cache better.
+	putNode(t, ctx, cl.l, func(n *Node) { n.wireguardKey = "+nb5grgIKQEbHm5JrUZovPQ9Bv04jR2TtY6sgS0dGG4=" })
+
 	cur := ipb.NewCuratorClient(cl.localNodeConn)
 	// Update the node's external address as it's used in tests.
 	_, err := cur.UpdateNodeStatus(ctx, &ipb.UpdateNodeStatusRequest{
@@ -1425,4 +1433,13 @@
 			}
 		}
 	}
+
+	// Make sure adding another node with the same pubkey fails.
+	_, err = cur.UpdateNodeClusterNetworking(ctx, &ipb.UpdateNodeClusterNetworkingRequest{
+		Clusternet: &cpb.NodeClusterNetworking{
+			WireguardPubkey: "+nb5grgIKQEbHm5JrUZovPQ9Bv04jR2TtY6sgS0dGG4=",
+		}})
+	if err == nil || !strings.Contains(err.Error(), "public key already used by another node") {
+		t.Errorf("Adding the same pubkey to a different node should have failed, got %v", err)
+	}
 }