m/n/core/localstorage/crypt: support more enc/auth modes

This is in preparation for introducing configurable disk
encryption/authentication policies in Metropolis (eg. low integrity
mode).

We also use the opportunity to add some tests for the newly refactored
crypt library. All modes go through an end-to-end test making sure data
is preserved and repeatedly mapping/unmapping the device works.

This change also disables insecure mode in debug builds. The equivalent
functionality will be re-established at a higher level in the cluster
code in a subsequent change, alongside the encryption/authentication
policy code.

Change-Id: I85db001c7c37a918cb491b1fcc3a51ea1d715817
Reviewed-on: https://review.monogon.dev/c/monogon/+/1724
Tested-by: Jenkins CI
Reviewed-by: Lorenz Brun <lorenz@monogon.tech>
diff --git a/metropolis/node/core/localstorage/crypt/BUILD.bazel b/metropolis/node/core/localstorage/crypt/BUILD.bazel
index 84b289b..b330bf1 100644
--- a/metropolis/node/core/localstorage/crypt/BUILD.bazel
+++ b/metropolis/node/core/localstorage/crypt/BUILD.bazel
@@ -1,14 +1,15 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//metropolis/test/ktest:ktest.bzl", "ktest")
 
 go_library(
     name = "crypt",
     # keep
     srcs = [
         "blockdev.go",
-    ] + select({
-        "//metropolis/node:debug_build": ["crypt_debug.go"],
-        "//conditions:default": ["crypt.go"],
-    }),
+        "crypt.go",
+        "crypt_encryption.go",
+        "crypt_integrity.go",
+    ],
     importpath = "source.monogon.dev/metropolis/node/core/localstorage/crypt",
     visibility = ["//metropolis/node/core/localstorage:__subpackages__"],
     deps = [
@@ -21,3 +22,14 @@
         "@org_golang_x_sys//unix",
     ],
 )
+
+go_test(
+    name = "crypt_test",
+    srcs = ["crypt_test.go"],
+    embed = [":crypt"],
+)
+
+ktest(
+    cmdline = "ramdisk_size=4096",
+    tester = ":crypt_test",
+)
diff --git a/metropolis/node/core/localstorage/crypt/blockdev.go b/metropolis/node/core/localstorage/crypt/blockdev.go
index 8b572e8..8379180 100644
--- a/metropolis/node/core/localstorage/crypt/blockdev.go
+++ b/metropolis/node/core/localstorage/crypt/blockdev.go
@@ -38,8 +38,8 @@
 var NodeDataPartitionType = uuid.MustParse("9eeec464-6885-414a-b278-4305c51f7966")
 
 const (
-	ESPDevicePath     = "/dev/esp"
-	NodeDataCryptPath = "/dev/data-crypt"
+	ESPDevicePath   = "/dev/esp"
+	NodeDataRawPath = "/dev/data-raw"
 )
 
 // MakeBlockDevices looks for the ESP and the node data partition and maps them
@@ -116,7 +116,7 @@
 				}
 				if part.Type == NodeDataPartitionType {
 					seenTypes[part.Type] = true
-					err := unix.Mknod(NodeDataCryptPath, 0600|unix.S_IFBLK, int(unix.Mkdev(uint32(majorDev), uint32(partNumber+1))))
+					err := unix.Mknod(NodeDataRawPath, 0600|unix.S_IFBLK, int(unix.Mkdev(uint32(majorDev), uint32(partNumber+1))))
 					if err != nil && !os.IsExist(err) {
 						return fmt.Errorf("failed to create device node for Metropolis node encrypted data partition: %w", err)
 					}
diff --git a/metropolis/node/core/localstorage/crypt/crypt.go b/metropolis/node/core/localstorage/crypt/crypt.go
index e84390e..af7451b 100644
--- a/metropolis/node/core/localstorage/crypt/crypt.go
+++ b/metropolis/node/core/localstorage/crypt/crypt.go
@@ -14,179 +14,265 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+// Package crypt implements block device (eg. disk) encryption and authentication
+// using dm-crypt and dm-integrity.
+//
+// Encryption using dm-crypt is implemented using AES (either in GCM or XTS mode,
+// depending on whether authentication is enabled).
+//
+// Authentication using dm-integrity provides per-sector integrity protection which
+// guards against accidental and malicious bit flips in the underlying storage,
+// but does nor protect against individual sectors (or the entire disk) being
+// rolled back.
+//
+// The same key is used for both authentication and encryption. The key must be
+// exactly 256 bits long.
+//
+// When initializing or mapping a device, a name must be provided. This name will
+// be used as the device-mapper target name if the device will have a
+// device-mapper set up, and will also form the base of any intermediary target
+// names used. Thus, it must be unique per data store.
+
 package crypt
 
 import (
-	"encoding/binary"
-	"encoding/hex"
 	"errors"
 	"fmt"
 	"os"
 	"syscall"
+	"unsafe"
 
 	"golang.org/x/sys/unix"
-
-	"source.monogon.dev/metropolis/pkg/devicemapper"
 )
 
-func readDataSectors(path string) (uint64, error) {
-	integrityPartition, err := os.Open(path)
-	if err != nil {
-		return 0, err
+// Mode of block device encryption and/or authentication, if any. See the
+// package-level documentation for information about how encryption and
+// authentication is implemented and what guarantees they provide.
+type Mode string
+
+// ModeEncryptedAuthenticated means the block device will first be authenticated
+// using dm-integrity, then encrypted using dm-crypt.
+//
+// A key needs to be provided when initializing and mapping a block device.
+const ModeEncryptedAuthenticated Mode = "encrypted+authenticated"
+
+// ModeEncrypted means the device will be encrypted using dm-crypt, but will not
+// be authenticated.
+//
+// A key needs to be provided when initializing and mapping a block device.
+const ModeEncrypted Mode = "encrypted"
+
+// ModeAuthenticated means the device will be authenticated using dm-integrity,
+// but will not be encrypted.
+//
+// A key needs to be provided when initializing and mapping a block device.
+const ModeAuthenticated Mode = "authenticated"
+
+// ModeInsecure means the device will be neither authenticated nor encrypted.
+//
+// A key must not be provided, or must be exactly zero bytes long.
+const ModeInsecure Mode = "insecure"
+
+func (m Mode) encrypted() bool {
+	switch m {
+	case ModeEncryptedAuthenticated, ModeEncrypted:
+		return true
+	case ModeInsecure, ModeAuthenticated:
+		return false
 	}
-	defer integrityPartition.Close()
-	// Based on structure defined in
-	//   https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/drivers/md/dm-integrity.c#n59
-	if _, err := integrityPartition.Seek(16, 0); err != nil {
-		return 0, err
-	}
-	var providedDataSectors uint64
-	if err := binary.Read(integrityPartition, binary.LittleEndian, &providedDataSectors); err != nil {
-		return 0, err
-	}
-	return providedDataSectors, nil
+	panic("invalid mode " + m)
 }
 
-// cryptMap maps an encrypted device (node) at baseName to a
-// decrypted device at /dev/$name using the given encryptionKey
-func CryptMap(name string, baseName string, encryptionKey []byte) error {
-	return cryptMap(name, baseName, encryptionKey, true)
+func (m Mode) authenticated() bool {
+	switch m {
+	case ModeEncryptedAuthenticated, ModeAuthenticated:
+		return true
+	case ModeEncrypted, ModeInsecure:
+		return false
+	}
+	panic("invalid mode " + m)
 }
 
-func cryptMap(name string, baseName string, encryptionKey []byte, enableJournal bool) error {
-	integritySectors, err := readDataSectors(baseName)
+// getSizeBytes returns the size of a block device in bytes.
+func getSizeBytes(path string) (uint64, error) {
+	blkdev, err := os.Open(path)
 	if err != nil {
-		return fmt.Errorf("failed to read the number of usable sectors on the integrity device: %w", err)
-	}
-
-	integrityDevName := fmt.Sprintf("/dev/%v-integrity", name)
-	integrityDMName := fmt.Sprintf("%v-integrity", name)
-	mode := "D"
-	if enableJournal {
-		mode = "J"
-	}
-	integrityDev, err := devicemapper.CreateActiveDevice(integrityDMName, false, []devicemapper.Target{
-		devicemapper.Target{
-			Length:     integritySectors,
-			Type:       "integrity",
-			Parameters: []string{baseName, "0", "28", mode, "1", "journal_sectors:1024"},
-		},
-	})
-	if err != nil {
-		return fmt.Errorf("failed to create Integrity device: %w", err)
-	}
-	if err := unix.Mknod(integrityDevName, 0600|unix.S_IFBLK, int(integrityDev)); err != nil {
-		unix.Unlink(integrityDevName)
-		devicemapper.RemoveDevice(integrityDMName)
-		return fmt.Errorf("failed to create integrity device node: %w", err)
-	}
-
-	cryptDevName := fmt.Sprintf("/dev/%v", name)
-	cryptDev, err := devicemapper.CreateActiveDevice(name, false, []devicemapper.Target{
-		devicemapper.Target{
-			Length:     integritySectors,
-			Type:       "crypt",
-			Parameters: []string{"capi:gcm(aes)-random", hex.EncodeToString(encryptionKey), "0", integrityDevName, "0", "1", "integrity:28:aead", "no_read_workqueue", "no_write_workqueue"},
-		},
-	})
-	if err != nil {
-		unix.Unlink(integrityDevName)
-		devicemapper.RemoveDevice(integrityDMName)
-		return fmt.Errorf("failed to create crypt device: %w", err)
-	}
-	if err := unix.Mknod(cryptDevName, 0600|unix.S_IFBLK, int(cryptDev)); err != nil {
-		unix.Unlink(cryptDevName)
-		devicemapper.RemoveDevice(name)
-
-		unix.Unlink(integrityDevName)
-		devicemapper.RemoveDevice(integrityDMName)
-		return fmt.Errorf("failed to create crypt device node: %w", err)
-	}
-	return nil
-}
-
-func cryptUnmap(name string, baseName string) error {
-	integrityDevName := fmt.Sprintf("/dev/%v-integrity", name)
-	if err := unix.Unlink(integrityDevName); err != nil && !os.IsNotExist(err) {
-		return fmt.Errorf("failed to delete integrity device inode: %w", err)
-	}
-	cryptDevName := fmt.Sprintf("/dev/%v", name)
-	if err := unix.Unlink(cryptDevName); err != nil {
-		return fmt.Errorf("failed to delete crypt device inode: %w", err)
-	}
-	integrityDMName := fmt.Sprintf("%v-integrity", name)
-	if err := devicemapper.RemoveDevice(name); err != nil && !errors.Is(err, unix.ENOENT) {
-		return fmt.Errorf("failed to remove dm-crypt device: %w", err)
-	}
-	if err := devicemapper.RemoveDevice(integrityDMName); err != nil && !errors.Is(err, unix.ENOENT) {
-		return fmt.Errorf("failed to remove dm-integrity device: %w", err)
-	}
-	return nil
-}
-
-// cryptInit initializes a new encrypted block device. This can take a long
-// time since all bytes on the mapped block device need to be zeroed.
-func CryptInit(name, baseName string, encryptionKey []byte) error {
-	integrityPartition, err := os.OpenFile(baseName, os.O_WRONLY, 0)
-	if err != nil {
-		return err
-	}
-	defer integrityPartition.Close()
-	zeroed512BBuf := make([]byte, 4096)
-	if _, err := integrityPartition.Write(zeroed512BBuf); err != nil {
-		return fmt.Errorf("failed to wipe header: %w", err)
-	}
-	integrityPartition.Close()
-
-	integrityDMName := fmt.Sprintf("%v-integrity", name)
-	_, err = devicemapper.CreateActiveDevice(integrityDMName, false, []devicemapper.Target{
-		{
-			Length:     1,
-			Type:       "integrity",
-			Parameters: []string{baseName, "0", "28", "J", "1", "journal_sectors:1024"},
-		},
-	})
-	if err != nil {
-		return fmt.Errorf("failed to create discovery integrity device: %w", err)
-	}
-	if err := devicemapper.RemoveDevice(integrityDMName); err != nil {
-		return fmt.Errorf("failed to remove discovery integrity device: %w", err)
-	}
-
-	// First, map the device without journal. Zeroing with journal is extremely
-	// slow as it transforms sequential IO into random IO and also consumes
-	// twice the write operations. This is fine as if we abort here we'll
-	// reinitialize the whole device so the reliability is of no concern.
-	if err := cryptMap(name, baseName, encryptionKey, false); err != nil {
-		return err
-	}
-
-	blkdev, err := os.OpenFile(fmt.Sprintf("/dev/%v", name), unix.O_DIRECT|os.O_WRONLY, 0000)
-	if err != nil {
-		return fmt.Errorf("failed to open new encrypted device for zeroing: %w", err)
+		return 0, fmt.Errorf("failed to open block device: %w", err)
 	}
 	defer blkdev.Close()
+
+	var sizeBytes uint64
+	_, _, err = unix.Syscall(unix.SYS_IOCTL, blkdev.Fd(), unix.BLKGETSIZE64, uintptr(unsafe.Pointer(&sizeBytes)))
+	if err != unix.Errno(0) {
+		return 0, fmt.Errorf("failed to get device size: %w", err)
+	}
+	return sizeBytes, nil
+}
+
+// getBlockSize returns the size of a block device's sector in bytes.
+func getBlockSize(path string) (uint32, error) {
+	blkdev, err := os.Open(path)
+	if err != nil {
+		return 0, fmt.Errorf("failed to open block device: %w", err)
+	}
+	defer blkdev.Close()
+
 	blockSize, err := unix.IoctlGetUint32(int(blkdev.Fd()), unix.BLKSSZGET)
-	zeroedBuf := make([]byte, blockSize*256) // Make it faster
-	for {
-		_, err := blkdev.Write(zeroedBuf)
-		if e, ok := err.(*os.PathError); ok && e.Err == syscall.ENOSPC {
-			break
+	if err != nil {
+		return 0, fmt.Errorf("BLKSSZGET: %w", err)
+	}
+	return blockSize, nil
+}
+
+// Map sets up an underlying block device (at path 'underlying') for access.
+// Depending on the given mode, authentication/integrity device-mapper targets
+// will be set up, and the top-level new block device path will be returned.
+//
+// The given name will be used as a base for the device mapper targets created,
+// and is used to uniquely identify this particular mapping setup. The same name
+// must then be used to unmap the device.
+//
+// If an error occurs during Map, cleanup will be attempted and an error will be
+// returned.
+//
+// The encryption key must be exactly 32 bytes / 256 bits long when
+// authentication and/or encryption is enabled, and nil / 0 bytes long when
+// insecure mode is used.
+//
+// Note: a successful Map does not necessarily mean the underlying device is
+// ready to access. Integrity errors or data corruption might mean accesses to
+// the newly mapped device will fail. The caller is responsible for catching
+// these conditions.
+func Map(name string, underlying string, encryptionKey []byte, mode Mode) (string, error) {
+	return map_(name, underlying, encryptionKey, mode, true)
+}
+
+// map_ is the internal implementation of Map, which also allows
+// enabling/disabling the integrity journal.
+//
+// This would be called map, but map is a reserved keyword in Go.
+func map_(name string, underlying string, encryptionKey []byte, mode Mode, enableJournal bool) (string, error) {
+	// Verify key length.
+	switch mode {
+	case ModeInsecure:
+		if len(encryptionKey) != 0 {
+			return "", fmt.Errorf("can't use key in insecure mode")
 		}
+	default:
+		if len(encryptionKey) != 32 {
+			return "", fmt.Errorf("key must be exactly 32 bytes / 256 bits")
+		}
+	}
+
+	device := underlying
+	if mode.authenticated() {
+		var err error
+		device, err = mapIntegrity(name, device, enableJournal)
 		if err != nil {
-			return fmt.Errorf("failed to zero-initalize new encrypted device: %w", err)
+			return "", err
 		}
 	}
-	blkdev.Close()
 
-	// Now, unmap the non-journaled device and remap it with journaling for
-	// further use.
-	if err := cryptUnmap(name, baseName); err != nil {
-		return fmt.Errorf("failed to unmap temporary encrypted block device: %w", err)
-	}
-	if err := cryptMap(name, baseName, encryptionKey, true); err != nil {
-		return fmt.Errorf("failed to map initialized encrypted device: %w", err)
+	if mode.encrypted() {
+		var err error
+		device, err = mapEncryption(name, device, encryptionKey, mode.authenticated())
+		if err != nil {
+			unmapIntegrity(name)
+			return "", err
+		}
 	}
 
+	return device, nil
+}
+
+// Unmap tears down all block devices related to the named mapping. The given
+// name and mode must match the name and mode used when mapping and/or
+// initializing the disk.
+func Unmap(name string, mode Mode) error {
+	if mode.encrypted() {
+		if err := unmapEncryption(name); err != nil {
+			return err
+		}
+	}
+	if mode.authenticated() {
+		if err := unmapIntegrity(name); err != nil {
+			return err
+		}
+	}
 	return nil
 }
+
+// Init sets up encryption/authentication as defined by mode on an underlying
+// block device path. After initialization, the setup/mapping is preserved and
+// the path of the resulting top-level block device is returned.
+//
+// Any existing data present on the underlying storage will be ignored. If
+// authentication is enabled, the underlying storage will also be fully
+// overwritten.
+//
+// The given name will be used as a base for the device mapper targets created,
+// and is used to uniquely identify this particular mapping setup. The same name
+// must then be used to unmap the device.
+//
+// The encryption key must be exactly 32 bytes / 256 bits long when
+// authentication and/or encryption is enabled, and nil / 0 bytes long when
+// insecure mode is used.
+func Init(name, underlying string, encryptionKey []byte, mode Mode) (string, error) {
+	// If using an authenticated mode, we'll do an initial map with journaling
+	// enabled to speed up the initial zeroing, then remap it with journaling.
+	// Otherwise, we immediately map with journaling enabled and don't remap.
+	initWithJournal := true
+	if mode.authenticated() {
+		if err := initializeIntegrity(name, underlying); err != nil {
+			return "", err
+		}
+		initWithJournal = false
+	}
+
+	device, err := map_(name, underlying, encryptionKey, mode, initWithJournal)
+	if err != nil {
+		return "", fmt.Errorf("initial mount failed: %w", err)
+	}
+
+	// Zero out device if authentication is enabled.
+	if mode.authenticated() {
+		blockSize, err := getBlockSize(device)
+		if err != nil {
+			return "", err
+		}
+
+		blkdev, err := os.OpenFile(device, unix.O_DIRECT|os.O_WRONLY, 0000)
+		if err != nil {
+			return "", fmt.Errorf("failed to open new device for zeroing: %w", err)
+		}
+
+		// Use a multiple of the block size to make the initial zeroing faster.
+		zeroedBuf := make([]byte, blockSize*256)
+		for {
+			_, err := blkdev.Write(zeroedBuf)
+			if errors.Is(err, syscall.ENOSPC) {
+				break
+			}
+			if err != nil {
+				blkdev.Close()
+				return "", fmt.Errorf("failed to zero-initalize new device: %w", err)
+			}
+		}
+		if err := blkdev.Close(); err != nil {
+			return "", fmt.Errorf("failed to close initialized device: %w", err)
+		}
+	}
+
+	// Remap with journaling if needed.
+	if !initWithJournal {
+		if err := Unmap(name, mode); err != nil {
+			return "", fmt.Errorf("failed to unmap temporary encrypted block device: %w", err)
+		}
+
+		device, err = map_(name, underlying, encryptionKey, mode, true)
+		if err != nil {
+			return "", fmt.Errorf("failed to map initialized encrypted device: %w", err)
+		}
+	}
+	return device, nil
+}
diff --git a/metropolis/node/core/localstorage/crypt/crypt_debug.go b/metropolis/node/core/localstorage/crypt/crypt_debug.go
deleted file mode 100644
index 2fbdc50..0000000
--- a/metropolis/node/core/localstorage/crypt/crypt_debug.go
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2020 The Monogon Project Authors.
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package crypt
-
-import (
-	"fmt"
-
-	"golang.org/x/sys/unix"
-)
-
-// CryptMap implements a debug version of CryptMap from crypt.go. It aliases
-// the given baseName device into name without any encryption.
-func CryptMap(name string, baseName string, _ []byte) error {
-	var stat unix.Stat_t
-	if err := unix.Stat(baseName, &stat); err != nil {
-		return fmt.Errorf("cannot stat base device: %w", err)
-	}
-	cryptDevName := fmt.Sprintf("/dev/%v", name)
-	if err := unix.Mknod(cryptDevName, 0600|unix.S_IFBLK, int(stat.Rdev)); err != nil {
-		return fmt.Errorf("failed to create crypt device node: %w", err)
-	}
-	return nil
-}
-
-// CryptInit implements a debug version of CryptInit from crypt.go. It aliases
-// the given baseName device into name without any encryption. As an identity
-// mapping doesn't need any initialization it doesn't do anything else.
-func CryptInit(name, baseName string, encryptionKey []byte) error {
-	return CryptMap(name, baseName, encryptionKey)
-}
diff --git a/metropolis/node/core/localstorage/crypt/crypt_encryption.go b/metropolis/node/core/localstorage/crypt/crypt_encryption.go
new file mode 100644
index 0000000..4fd6061
--- /dev/null
+++ b/metropolis/node/core/localstorage/crypt/crypt_encryption.go
@@ -0,0 +1,81 @@
+package crypt
+
+import (
+	"encoding/hex"
+	"fmt"
+	"os"
+
+	"golang.org/x/sys/unix"
+
+	"source.monogon.dev/metropolis/pkg/devicemapper"
+)
+
+func encryptionDevPath(name string) string {
+	return fmt.Sprintf("/dev/%s-crypt", name)
+}
+
+func encryptionDMName(name string) string {
+	return fmt.Sprintf("%s-crypt", name)
+}
+
+func mapEncryption(name, underlying string, encryptionKey []byte, authenticated bool) (string, error) {
+	sizeBytes, err := getSizeBytes(underlying)
+	if err != nil {
+		return "", fmt.Errorf("getting size of block device failed: %w", err)
+	}
+	blockSize, err := getBlockSize(underlying)
+	if err != nil {
+		return "", fmt.Errorf("getting block size failed: %w", err)
+	}
+
+	optParams := []string{
+		"no_read_workqueue", "no_write_workqueue",
+	}
+	cipher := "capi:xts(aes)-essiv:sha256"
+	if authenticated {
+		optParams = append(optParams, "integrity:28:aead")
+		cipher = "capi:gcm(aes)-random"
+	} else {
+		// discard (TRIM/UNMAP) only works without integrity enabled.
+		optParams = append(optParams, "allow_discards")
+	}
+	params := []string{
+		// cipher, key, iv_offset, device_path, offset
+		cipher, hex.EncodeToString(encryptionKey), "0", underlying, "0",
+		// number of opt params
+		fmt.Sprintf("%d", len(optParams)),
+	}
+	params = append(params, optParams...)
+
+	cryptDev, err := devicemapper.CreateActiveDevice(encryptionDMName(name), false, []devicemapper.Target{
+		{
+			Length:     sizeBytes / uint64(blockSize),
+			Type:       "crypt",
+			Parameters: params,
+		},
+	})
+	if err != nil {
+		return "", fmt.Errorf("failed to create crypt device: %w", err)
+	}
+	if err := unix.Mknod(encryptionDevPath(name), 0600|unix.S_IFBLK, int(cryptDev)); err != nil {
+		// Best-effort cleanup, swallow errors.
+		unmapEncryption(name)
+		return "", fmt.Errorf("failed to create crypt device node: %w", err)
+	}
+	return encryptionDevPath(name), nil
+}
+
+func unmapEncryption(name string) error {
+	// Remove /dev node if present.
+	if _, err := os.Stat(encryptionDevPath(name)); err == nil {
+		if err := unix.Unlink(encryptionDevPath(name)); err != nil {
+			return fmt.Errorf("unlinking encryption device failed: %w", err)
+		}
+	}
+
+	// Remove dm target.
+	if err := devicemapper.RemoveDevice(encryptionDMName(name)); err != nil {
+		return fmt.Errorf("removing encryption device failed: %w", err)
+	}
+	return nil
+}
diff --git a/metropolis/node/core/localstorage/crypt/crypt_integrity.go b/metropolis/node/core/localstorage/crypt/crypt_integrity.go
new file mode 100644
index 0000000..4130aef
--- /dev/null
+++ b/metropolis/node/core/localstorage/crypt/crypt_integrity.go
@@ -0,0 +1,139 @@
+package crypt
+
+import (
+	"encoding/binary"
+	"fmt"
+	"os"
+
+	"golang.org/x/sys/unix"
+
+	"source.monogon.dev/metropolis/pkg/devicemapper"
+)
+
+func integrityDevPath(name string) string {
+	return fmt.Sprintf("/dev/%s-integrity", name)
+}
+
+func integrityDMName(name string) string {
+	return fmt.Sprintf("%s-integrity", name)
+}
+
+// readIntegrityDataSectors parses the number of available integrity data sectors
+// from a raw dm-integrity formatted device. This is needed to then map the
+// device.
+//
+// This is described in further detail in
+// https://docs.kernel.org/admin-guide/device-mapper/dm-integrity.html.
+func readIntegrityDataSectors(path string) (uint64, error) {
+	integrityPartition, err := os.Open(path)
+	if err != nil {
+		return 0, err
+	}
+	defer integrityPartition.Close()
+	// Based on structure defined in
+	//   https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/drivers/md/dm-integrity.c#n59
+	if _, err := integrityPartition.Seek(16, 0); err != nil {
+		return 0, err
+	}
+	var providedDataSectors uint64
+	if err := binary.Read(integrityPartition, binary.LittleEndian, &providedDataSectors); err != nil {
+		return 0, err
+	}
+
+	// Let's perform some simple checks on the read value to make sure the returned
+	// data isn't corrupted or has been tampered with.
+
+	if providedDataSectors == 0 {
+		return 0, fmt.Errorf("invalid data sector count of zero")
+	}
+
+	underlyingSizeBytes, err := getSizeBytes(path)
+	if err != nil {
+		return 0, fmt.Errorf("getting underlying block device size failed: %w", err)
+	}
+	underlyingBlockSize, err := getBlockSize(path)
+	if err != nil {
+		return 0, fmt.Errorf("getting underlying block device block size failed: %w", err)
+	}
+	underlyingSectors := underlyingSizeBytes / uint64(underlyingBlockSize)
+	if providedDataSectors > underlyingSectors {
+		return 0, fmt.Errorf("device claims %d data sectors but underlying device only has %d", providedDataSectors, underlyingSectors)
+	}
+	return providedDataSectors, nil
+}
+
+// initializeIntegrity performs the initialization steps outlined in
+// https://docs.kernel.org/admin-guide/device-mapper/dm-integrity.html.
+func initializeIntegrity(name, baseName string) error {
+	// Zero out superblock.
+	integrityPartition, err := os.OpenFile(baseName, os.O_WRONLY, 0)
+	if err != nil {
+		return err
+	}
+	zeroedBuf := make([]byte, 4096)
+	if _, err := integrityPartition.Write(zeroedBuf); err != nil {
+		integrityPartition.Close()
+		return fmt.Errorf("failed to wipe header: %w", err)
+	}
+	integrityPartition.Close()
+
+	// Load target with one-sector size. The kernel will format the device.
+	_, err = devicemapper.CreateActiveDevice(integrityDMName(name), false, []devicemapper.Target{
+		{
+			Length:     1,
+			Type:       "integrity",
+			Parameters: []string{baseName, "0", "28", "J", "1", "journal_sectors:1024"},
+		},
+	})
+	if err != nil {
+		return fmt.Errorf("failed to create initial integrity device: %w", err)
+	}
+	// Unload the target.
+	if err := devicemapper.RemoveDevice(integrityDMName(name)); err != nil {
+		return fmt.Errorf("failed to remove initial integrity device: %w", err)
+	}
+
+	return nil
+}
+
+func mapIntegrity(name, baseName string, enableJournal bool) (string, error) {
+	integritySectors, err := readIntegrityDataSectors(baseName)
+	if err != nil {
+		return "", fmt.Errorf("failed to read the number of usable sectors on the integrity device: %w", err)
+	}
+
+	mode := "D"
+	if enableJournal {
+		mode = "J"
+	}
+	integrityDev, err := devicemapper.CreateActiveDevice(integrityDMName(name), false, []devicemapper.Target{
+		{
+			Length:     integritySectors,
+			Type:       "integrity",
+			Parameters: []string{baseName, "0", "28", mode, "1", "journal_sectors:1024"},
+		},
+	})
+	if err != nil {
+		return "", fmt.Errorf("failed to create Integrity device: %w", err)
+	}
+	if err := unix.Mknod(integrityDevPath(name), 0600|unix.S_IFBLK, int(integrityDev)); err != nil {
+		unmapIntegrity(name)
+		return "", fmt.Errorf("failed to create integrity device node: %w", err)
+	}
+
+	return integrityDevPath(name), nil
+}
+
+func unmapIntegrity(name string) error {
+	// Remove /dev node if present.
+	if _, err := os.Stat(integrityDevPath(name)); err == nil {
+		if err := unix.Unlink(integrityDevPath(name)); err != nil {
+			return fmt.Errorf("unlinking integrity device failed: %w", err)
+		}
+	}
+
+	if err := devicemapper.RemoveDevice(integrityDMName(name)); err != nil {
+		return fmt.Errorf("removing integrity DM device failed: %w", err)
+	}
+	return nil
+}
diff --git a/metropolis/node/core/localstorage/crypt/crypt_test.go b/metropolis/node/core/localstorage/crypt/crypt_test.go
new file mode 100644
index 0000000..f021e56
--- /dev/null
+++ b/metropolis/node/core/localstorage/crypt/crypt_test.go
@@ -0,0 +1,123 @@
+package crypt
+
+import (
+	"bytes"
+	"fmt"
+	"os"
+	"testing"
+)
+
+// TestMapUnmap performs a round-trip test for all modes, making sure we can
+// intialize, map, unmap, map again and unmap again and that data isn't getting
+// corrupted.
+func TestMapUnmap(t *testing.T) {
+	if os.Getenv("IN_KTEST") != "true" {
+		t.Skip("Not in ktest")
+	}
+
+	init := func(name string, key []byte, mode Mode) string {
+		t.Helper()
+
+		target, err := Init(name, "/dev/ram0", key, mode)
+		if err != nil {
+			t.Fatalf("Init failed: %v", err)
+		}
+		return target
+	}
+
+	unmap := func(name string, mode Mode) {
+		t.Helper()
+		if err := Unmap(name, mode); err != nil {
+			t.Fatalf("Unmap failed: %v", err)
+		}
+
+	}
+
+	map_ := func(name string, key []byte, mode Mode) string {
+		t.Helper()
+		target, err := Map(name, "/dev/ram0", key, mode)
+		if err != nil {
+			t.Fatalf("Map fialed: %v", err)
+		}
+		return target
+	}
+
+	writeWitness := func(target string, i int) string {
+		t.Helper()
+
+		file, err := os.OpenFile(target, os.O_WRONLY, 0644)
+		if err != nil {
+			t.Fatalf("opening initialized crypt failed: %v", err)
+		}
+		defer file.Close()
+
+		witness := fmt.Sprintf("this is test %d", i)
+		_, err = fmt.Fprintf(file, "%s", witness)
+		if err != nil {
+			t.Fatalf("writing to initialized crypt failed; %v", err)
+		}
+		return witness
+	}
+
+	checkWitness := func(target, witness string) {
+		t.Helper()
+
+		file, err := os.OpenFile(target, 0, 644)
+		if err != nil {
+			t.Fatalf("opening mapped crypt failed: %v", err)
+		}
+		defer file.Close()
+
+		buf := make([]byte, len(witness))
+		_, err = file.Read(buf)
+		if err != nil {
+			t.Fatalf("reading mapped crypt failed: %v", err)
+		}
+		defer file.Close()
+
+		if want, got := witness, string(buf); want != got {
+			t.Fatalf("read data differs, wanted %q, got %q", want, got)
+		}
+		file.Close()
+	}
+
+	for i, mode := range []Mode{
+		ModeInsecure,
+		ModeEncrypted,
+		ModeAuthenticated,
+		ModeEncryptedAuthenticated,
+	} {
+		t.Run(string(mode), func(t *testing.T) {
+			name := fmt.Sprintf("test-%d", i)
+			key := bytes.Repeat([]byte("a"), 32)
+			if mode == ModeInsecure {
+				key = nil
+			}
+
+			target := init(name, key, mode)
+			witness := writeWitness(target, i)
+			unmap(name, mode)
+
+			if target != "/dev/ram0" {
+				if _, err := os.Stat(target); !os.IsNotExist(err) {
+					t.Fatalf("Unmount didn't remove %s", target)
+				}
+			}
+
+			target2 := map_(name, key, mode)
+			if target != target2 {
+				t.Fatalf("Init mounted at %s, first Map mounted at %s", target, target2)
+			}
+
+			checkWitness(target, witness)
+			unmap(name, mode)
+
+			target3 := map_(name, key, mode)
+			if target != target3 {
+				t.Fatalf("Init mounted at %s, second Map mounted at %s", target, target2)
+			}
+			checkWitness(target, witness)
+			unmap(name, mode)
+		})
+	}
+}
diff --git a/metropolis/node/core/localstorage/directory_data.go b/metropolis/node/core/localstorage/directory_data.go
index d88afb7..f824ac2 100644
--- a/metropolis/node/core/localstorage/directory_data.go
+++ b/metropolis/node/core/localstorage/directory_data.go
@@ -50,10 +50,11 @@
 		key[i] = config.NodeUnlockKey[i] ^ clusterUnlockKey[i]
 	}
 
-	if err := crypt.CryptMap("data", crypt.NodeDataCryptPath, key); err != nil {
+	target, err := crypt.Map("data", crypt.NodeDataRawPath, key, crypt.ModeEncryptedAuthenticated)
+	if err != nil {
 		return err
 	}
-	if err := d.mount(); err != nil {
+	if err := d.mount(target); err != nil {
 		return err
 	}
 	return nil
@@ -102,15 +103,16 @@
 		key[i] = nodeUnlockKey[i] ^ globalUnlockKey[i]
 	}
 
-	if err := crypt.CryptInit("data", crypt.NodeDataCryptPath, key); err != nil {
+	target, err := crypt.Init("data", crypt.NodeDataRawPath, key, crypt.ModeEncryptedAuthenticated)
+	if err != nil {
 		return nil, fmt.Errorf("initializing encrypted block device: %w", err)
 	}
-	mkfsCmd := exec.Command("/bin/mkfs.xfs", "-qKf", "/dev/data")
+	mkfsCmd := exec.Command("/bin/mkfs.xfs", "-qKf", target)
 	if _, err := mkfsCmd.Output(); err != nil {
 		return nil, fmt.Errorf("formatting encrypted block device: %w", err)
 	}
 
-	if err := d.mount(); err != nil {
+	if err := d.mount(target); err != nil {
 		return nil, fmt.Errorf("mounting: %w", err)
 	}
 
@@ -136,10 +138,10 @@
 	return globalUnlockKey, nil
 }
 
-func (d *DataDirectory) mount() error {
+func (d *DataDirectory) mount(path string) error {
 	// TODO(T965): MS_NODEV should definitely be set on the data partition, but as long as the kubelet root
 	// is on there, we can't do it.
-	if err := unix.Mount("/dev/data", d.FullPath(), "xfs", unix.MS_NOEXEC, "pquota"); err != nil {
+	if err := unix.Mount(path, d.FullPath(), "xfs", unix.MS_NOEXEC, "pquota"); err != nil {
 		return fmt.Errorf("mounting data directory: %w", err)
 	}
 	return nil