m/n/core/rpc: implement node verification in authenticated connections

The current API of NewAuthenticatedCredentials is not easily extensible,
so switch over to an extensible, option-based API now.

This then adds a WantRemoteNode option which verifies that the remote
connection is established to a node with a given ID.

Change-Id: Ie9f6b33d8b032729181bae5591eba9856ea2f523
Reviewed-on: https://review.monogon.dev/c/monogon/+/1427
Tested-by: Jenkins CI
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
diff --git a/metropolis/node/core/curator/impl_leader_test.go b/metropolis/node/core/curator/impl_leader_test.go
index fcb3e1d..be97b63 100644
--- a/metropolis/node/core/curator/impl_leader_test.go
+++ b/metropolis/node/core/curator/impl_leader_test.go
@@ -175,14 +175,17 @@
 	ca := nodeCredentials.ClusterCA()
 
 	// Create an authenticated manager gRPC client.
-	mcl, err := grpc.Dial("local", withLocalDialer, grpc.WithTransportCredentials(rpc.NewAuthenticatedCredentials(ownerCreds, ca)))
+	creds := rpc.NewAuthenticatedCredentials(ownerCreds, rpc.WantRemoteCluster(ca))
+	gcreds := grpc.WithTransportCredentials(creds)
+	mcl, err := grpc.Dial("local", withLocalDialer, gcreds)
 	if err != nil {
 		t.Fatalf("Dialing external GRPC failed: %v", err)
 	}
 
 	// Create a node gRPC client for the local node.
-	lcl, err := grpc.Dial("local", withLocalDialer,
-		grpc.WithTransportCredentials(rpc.NewAuthenticatedCredentials(nodeCredentials.TLSCredentials(), ca)))
+	creds = rpc.NewAuthenticatedCredentials(nodeCredentials.TLSCredentials(), rpc.WantRemoteCluster(ca))
+	gcreds = grpc.WithTransportCredentials(creds)
+	lcl, err := grpc.Dial("local", withLocalDialer, gcreds)
 	if err != nil {
 		t.Fatalf("Dialing external GRPC failed: %v", err)
 	}
diff --git a/metropolis/node/core/roleserve/value_clustermembership.go b/metropolis/node/core/roleserve/value_clustermembership.go
index e956d10..3044c7c 100644
--- a/metropolis/node/core/roleserve/value_clustermembership.go
+++ b/metropolis/node/core/roleserve/value_clustermembership.go
@@ -104,7 +104,7 @@
 			m.resolver.AddEndpoint(resolver.NodeByHostPort(addr.Host, uint16(common.CuratorServicePort)))
 		}
 	}
-	creds := rpc.NewAuthenticatedCredentials(m.credentials.TLSCredentials(), m.credentials.ClusterCA())
+	creds := rpc.NewAuthenticatedCredentials(m.credentials.TLSCredentials(), rpc.WantRemoteCluster(m.credentials.ClusterCA()))
 	return grpc.Dial(resolver.MetropolisControlAddress, grpc.WithTransportCredentials(creds), grpc.WithResolvers(m.resolver))
 }
 
diff --git a/metropolis/node/core/rpc/client.go b/metropolis/node/core/rpc/client.go
index 70173d8..656fee5 100644
--- a/metropolis/node/core/rpc/client.go
+++ b/metropolis/node/core/rpc/client.go
@@ -20,7 +20,7 @@
 
 type verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
 
-func verifyClusterCertificate(ca *x509.Certificate) verifyPeerCertificate {
+func verifyClusterCertificateAndNodeID(ca *x509.Certificate, nodeID string) verifyPeerCertificate {
 	return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
 		if len(rawCerts) != 1 {
 			return fmt.Errorf("server presented %d certificates, wanted exactly one", len(rawCerts))
@@ -29,14 +29,27 @@
 		if err != nil {
 			return fmt.Errorf("server presented unparseable certificate: %w", err)
 		}
-		if _, err := identity.VerifyNodeInCluster(serverCert, ca); err != nil {
+		pkey, err := identity.VerifyNodeInCluster(serverCert, ca)
+		if err != nil {
 			return fmt.Errorf("node certificate verification failed: %w", err)
 		}
+		if nodeID != "" {
+			id := identity.NodeID(pkey)
+			if id != nodeID {
+				return fmt.Errorf("wanted to reach node %q, got %q", nodeID, id)
+			}
+		}
 
 		return nil
 	}
 }
 
+func verifyFail(err error) verifyPeerCertificate {
+	return func(_ [][]byte, _ [][]*x509.Certificate) error {
+		return err
+	}
+}
+
 // NewEphemeralCredentials returns gRPC TransportCredentials that can be used to
 // dial a cluster without authenticating with a certificate, but instead
 // authenticating by proving the possession of a private key, via an ephemeral
@@ -44,11 +57,11 @@
 //
 // Currently these credentials are used in two flows:
 //
-//   1. Registration of nodes into a cluster, after which a node receives a proper
-//      node certificate
+//  1. Registration of nodes into a cluster, after which a node receives a proper
+//     node certificate
 //
-//   2. Escrow of initial owner credentials into a proper manager
-//      certificate
+//  2. Escrow of initial owner credentials into a proper manager
+//     certificate
 //
 // If 'ca' is given, the remote side will be cryptographically verified to be a
 // node that's part of the cluster represented by the ca. Otherwise, no
@@ -71,28 +84,106 @@
 		Certificate: [][]byte{certificateBytes},
 		PrivateKey:  private,
 	}
-	return NewAuthenticatedCredentials(certificate, ca), nil
+	var opts []AuthenticatedCredentialsOpt
+	if ca != nil {
+		opts = append(opts, WantRemoteCluster(ca))
+	} else {
+		opts = append(opts, WantInsecure())
+	}
+	return NewAuthenticatedCredentials(certificate, opts...), nil
+}
+
+// AuthenticatedCredentialsOpt are created using WantXXX functions and used in
+// NewAuthenticatedCredentials.
+type AuthenticatedCredentialsOpt struct {
+	wantCA       *x509.Certificate
+	wantNodeID   string
+	insecureOkay bool
+}
+
+func (a *AuthenticatedCredentialsOpt) merge(o *AuthenticatedCredentialsOpt) {
+	if a.wantNodeID == "" && o.wantNodeID != "" {
+		a.wantNodeID = o.wantNodeID
+	}
+	if a.wantCA == nil && o.wantCA != nil {
+		a.wantCA = o.wantCA
+	}
+	if !a.insecureOkay && o.insecureOkay {
+		a.insecureOkay = o.insecureOkay
+	}
+}
+
+// WantRemoteCluster enables the verification of the remote cluster identity when
+// using NewAuthenticatedCredentials. If the connection is not terminated at a
+// cluster with the given CA certificate, an error will be returned.
+//
+// This is the bare minimum option required to implement secure connections to
+// clusters.
+func WantRemoteCluster(ca *x509.Certificate) AuthenticatedCredentialsOpt {
+	return AuthenticatedCredentialsOpt{
+		wantCA: ca,
+	}
+}
+
+// WantRemoteNode enables the verification of the remote node identity when using
+// NewAuthenticatedCredentials. If the connection is not terminated at the node
+// with ID 'id', an error will be returned. For this option to take effect,
+// WantRemoteCluster must also be set.
+func WantRemoteNode(id string) AuthenticatedCredentialsOpt {
+	return AuthenticatedCredentialsOpt{
+		wantNodeID: id,
+	}
+}
+
+// WantInsecure disables the verification of the remote side of the connection
+// via NewAuthenticatedCredentials. This is unsafe.
+func WantInsecure() AuthenticatedCredentialsOpt {
+	return AuthenticatedCredentialsOpt{
+		insecureOkay: true,
+	}
 }
 
 // NewAuthenticatedCredentials returns gRPC TransportCredentials that can be
 // used to dial a cluster with a given TLS certificate (from node or manager
 // credentials).
 //
-// If 'ca' is given, the remote side will be cryptographically verified to be a
-// node that's part of the cluster represented by the ca. Otherwise, no
-// verification is performed and this function is unsafe.
-func NewAuthenticatedCredentials(cert tls.Certificate, ca *x509.Certificate) credentials.TransportCredentials {
+// The provided AuthenticatedCredentialsOpt specify the verification of the
+// remote side of the connection. When connecting to a cluster (any node), use
+// WantRemoteCluster. If you also want to verify the connection to a particular
+// node, specify WantRemoteNode alongside it. If no verification should be
+// performed use WantInsecure.
+//
+// Options are merged first-wins; invalid combinations fail at handshake time.
+func NewAuthenticatedCredentials(cert tls.Certificate, opts ...AuthenticatedCredentialsOpt) credentials.TransportCredentials {
 	config := &tls.Config{
 		Certificates:       []tls.Certificate{cert},
 		InsecureSkipVerify: true,
 	}
-	if ca != nil {
-		config.VerifyPeerCertificate = verifyClusterCertificate(ca)
+
+	var merged AuthenticatedCredentialsOpt
+	for _, o := range opts {
+		merged.merge(&o)
 	}
+
+	if merged.insecureOkay {
+		if merged.wantNodeID != "" || merged.wantCA != nil {
+			config.VerifyPeerCertificate = verifyFail(fmt.Errorf("WantInsecure specified alongside WantRemoteNode/WantRemoteCluster"))
+		}
+	} else {
+		switch {
+		case merged.wantNodeID != "" && merged.wantCA == nil:
+			config.VerifyPeerCertificate = verifyFail(fmt.Errorf("WantRemoteNode also requires WantRemoteCluster"))
+		case merged.wantCA == nil:
+			config.VerifyPeerCertificate = verifyFail(fmt.Errorf("no AuthenticatedCredentialsOpts specified"))
+		default:
+			config.VerifyPeerCertificate = verifyClusterCertificateAndNodeID(merged.wantCA, merged.wantNodeID)
+		}
+	}
+
 	return credentials.NewTLS(config)
 }
 
-// RetrieveOwnerCertificates uses AAA.Escrow to retrieve a cluster manager
+// RetrieveOwnerCertificate uses AAA.Escrow to retrieve a cluster manager
 // certificate for the initial owner of the cluster, authenticated by the
 // public/private key set in the clusters NodeParameters.ClusterBoostrap.
 //
diff --git a/metropolis/node/core/rpc/server_authentication_test.go b/metropolis/node/core/rpc/server_authentication_test.go
index 8deaaea..09565ad 100644
--- a/metropolis/node/core/rpc/server_authentication_test.go
+++ b/metropolis/node/core/rpc/server_authentication_test.go
@@ -62,7 +62,7 @@
 
 	// Authenticate as manager externally, ensure that GetRegisterTicket runs.
 	cl, err := grpc.Dial("local",
-		grpc.WithTransportCredentials(NewAuthenticatedCredentials(eph.Manager, eph.CA)),
+		grpc.WithTransportCredentials(NewAuthenticatedCredentials(eph.Manager, WantRemoteCluster(eph.CA))),
 		withLocalDialer)
 	if err != nil {
 		t.Fatalf("Dial: %v", err)
@@ -77,7 +77,7 @@
 	// Authenticate as node externally, ensure that GetRegisterTicket is refused
 	// (this is because nodes miss the GET_REGISTER_TICKET permissions).
 	cl, err = grpc.Dial("local",
-		grpc.WithTransportCredentials(NewAuthenticatedCredentials(eph.Nodes[0].TLSCredentials(), eph.CA)),
+		grpc.WithTransportCredentials(NewAuthenticatedCredentials(eph.Nodes[0].TLSCredentials(), WantRemoteCluster(eph.CA))),
 		withLocalDialer)
 	if err != nil {
 		t.Fatalf("Dial: %v", err)