osbase/net/dhcp4c: move package out of metropolis
Move the dhcp4c package from metropolis/node/core/network/dhcp4c to
osbase/net/dhcp4c. The package is not specific to metropolis, and is
also used by nanoswitch and cloud/agent.
Change-Id: I508261c93c623d5b7a33a2089da11625b7a3abd0
Reviewed-on: https://review.monogon.dev/c/monogon/+/4565
Tested-by: Jenkins CI
Reviewed-by: Tim Windelschmidt <tim@monogon.tech>
diff --git a/osbase/net/dhcp4c/BUILD.bazel b/osbase/net/dhcp4c/BUILD.bazel
new file mode 100644
index 0000000..af7081c
--- /dev/null
+++ b/osbase/net/dhcp4c/BUILD.bazel
@@ -0,0 +1,34 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

# DHCPv4 client state machine and lease types. Public because it is used by
# multiple consumers outside of metropolis.
go_library(
    name = "dhcp4c",
    srcs = [
        "dhcpc.go",
        "doc.go",
        "lease.go",
    ],
    importpath = "source.monogon.dev/osbase/net/dhcp4c",
    visibility = ["//visibility:public"],
    deps = [
        "//osbase/net/dhcp4c/transport",
        "//osbase/supervisor",
        "@com_github_cenkalti_backoff_v4//:backoff",
        "@com_github_insomniacslk_dhcp//dhcpv4",
        "@com_github_insomniacslk_dhcp//iana",
    ],
)

# Tests for dhcpc.go and lease.go, compiled into the dhcp4c package itself.
go_test(
    name = "dhcp4c_test",
    srcs = [
        "dhcpc_test.go",
        "lease_test.go",
    ],
    embed = [":dhcp4c"],
    deps = [
        "//osbase/net/dhcp4c/transport",
        "@com_github_cenkalti_backoff_v4//:backoff",
        "@com_github_insomniacslk_dhcp//dhcpv4",
        "@com_github_stretchr_testify//assert",
    ],
)
diff --git a/osbase/net/dhcp4c/callback/BUILD.bazel b/osbase/net/dhcp4c/callback/BUILD.bazel
new file mode 100644
index 0000000..a8780fc
--- /dev/null
+++ b/osbase/net/dhcp4c/callback/BUILD.bazel
@@ -0,0 +1,33 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
load("//osbase/test/ktest:ktest.bzl", "k_test")

# Netlink-based lease callbacks (address and route management).
go_library(
    name = "callback",
    srcs = ["callback.go"],
    importpath = "source.monogon.dev/osbase/net/dhcp4c/callback",
    visibility = ["//visibility:public"],
    deps = [
        "//osbase/net/dhcp4c",
        "@com_github_insomniacslk_dhcp//dhcpv4",
        "@com_github_vishvananda_netlink//:netlink",
        "@org_golang_x_sys//unix",
    ],
)

# The tests in callback_test.go skip themselves unless IN_KTEST is "true".
go_test(
    name = "callback_test",
    srcs = ["callback_test.go"],
    embed = [":callback"],
    deps = [
        "//osbase/net/dhcp4c",
        "@com_github_google_go_cmp//cmp",
        "@com_github_insomniacslk_dhcp//dhcpv4",
        "@com_github_vishvananda_netlink//:netlink",
        "@org_golang_x_sys//unix",
    ],
)

# Runs :callback_test under ktest so the netlink tests actually execute.
k_test(
    name = "ktest",
    tester = ":callback_test",
)
diff --git a/osbase/net/dhcp4c/callback/callback.go b/osbase/net/dhcp4c/callback/callback.go
new file mode 100644
index 0000000..dc3088b
--- /dev/null
+++ b/osbase/net/dhcp4c/callback/callback.go
@@ -0,0 +1,159 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package callback contains minimal callbacks for configuring the kernel with
+// options received over DHCP.
+//
// These directly configure the relevant kernel subsystems and need to own
// certain parts of them, as documented on a per-callback basis, to make sure
+// that they can recover from restarts and crashes of the DHCP client.
+// The callbacks in here are not suitable for use in advanced network scenarios
+// like running multiple DHCP clients per interface via ClientIdentifier or
+// when running an external FIB manager. In these cases it's advised to extract
+// the necessary information from the lease in your own callback and
+// communicate it directly to the responsible entity.
+package callback
+
import (
	"errors"
	"fmt"
	"math"
	"net"
	"os"
	"time"

	"github.com/insomniacslk/dhcp/dhcpv4"
	"github.com/vishvananda/netlink"
	"golang.org/x/sys/unix"

	"source.monogon.dev/osbase/net/dhcp4c"
)
+
+// Compose can be used to chain multiple callbacks
+func Compose(callbacks ...dhcp4c.LeaseCallback) dhcp4c.LeaseCallback {
+ return func(lease *dhcp4c.Lease) error {
+ for _, cb := range callbacks {
+ if err := cb(lease); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+}
+
// isIPNetEqual reports whether a and b describe the same network: equal IPs
// (4- and 16-byte forms of an IPv4 address compare equal) and masks with the
// same prefix length and total size. Two nil pointers are equal; a nil and a
// non-nil pointer are not.
func isIPNetEqual(a, b *net.IPNet) bool {
	switch {
	case a == b:
		return true
	case a == nil, b == nil:
		return false
	}
	if !a.IP.Equal(b.IP) {
		return false
	}
	aOnes, aBits := a.Mask.Size()
	bOnes, bBits := b.Mask.Size()
	return aOnes == bOnes && aBits == bBits
}
+
+// ManageIP sets up and tears down the assigned IP address. It takes exclusive
+// ownership of all IPv4 addresses on the given interface which do not have
+// IFA_F_PERMANENT set, so it's not possible to run multiple dynamic addressing
+// clients on a single interface.
+func ManageIP(iface netlink.Link) dhcp4c.LeaseCallback {
+ return func(lease *dhcp4c.Lease) error {
+ newNet := lease.IPNet()
+
+ addrs, err := netlink.AddrList(iface, netlink.FAMILY_V4)
+ if err != nil {
+ return fmt.Errorf("netlink failed to list addresses: %w", err)
+ }
+
+ for _, addr := range addrs {
+ if addr.Flags&unix.IFA_F_PERMANENT == 0 {
+ // Linux identifies addreses by IP, mask and peer (see
+ // net/ipv4/devinet.find_matching_ifa in Linux 5.10).
+ // So don't touch addresses which match on these properties as
+ // AddrReplace will atomically reconfigure them anyways without
+ // interrupting things.
+ if isIPNetEqual(addr.IPNet, newNet) && addr.Peer == nil && lease != nil {
+ continue
+ }
+
+ if err := netlink.AddrDel(iface, &addr); !os.IsNotExist(err) && err != nil {
+ return fmt.Errorf("failed to delete address: %w", err)
+ }
+ }
+ }
+
+ if lease != nil {
+ remainingLifetimeSecs := int(math.Ceil(time.Until(lease.ExpiresAt).Seconds()))
+ newBroadcastIP := dhcpv4.GetIP(dhcpv4.OptionBroadcastAddress, lease.Options)
+ if err := netlink.AddrReplace(iface, &netlink.Addr{
+ IPNet: newNet,
+ ValidLft: remainingLifetimeSecs,
+ PreferedLft: remainingLifetimeSecs,
+ Broadcast: newBroadcastIP,
+ }); err != nil {
+ return fmt.Errorf("failed to update address: %w", err)
+ }
+ }
+ return nil
+ }
+}
+
+// ManageRoutes installs and removes routes according to the current DHCP lease,
+// including the default route (if any).
+// It takes ownership of all RTPROTO_DHCP routes on the given interface, so it's
+// not possible to run multiple DHCP clients on the given interface.
+func ManageRoutes(iface netlink.Link) dhcp4c.LeaseCallback {
+ return func(lease *dhcp4c.Lease) error {
+ newRoutes := lease.Routes()
+
+ dhcpRoutes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{
+ Protocol: unix.RTPROT_DHCP,
+ LinkIndex: iface.Attrs().Index,
+ }, netlink.RT_FILTER_OIF|netlink.RT_FILTER_PROTOCOL)
+ if err != nil {
+ return fmt.Errorf("netlink failed to list routes: %w", err)
+ }
+ for _, route := range dhcpRoutes {
+ // Don't remove routes which can be atomically replaced by
+ // RouteReplace to prevent potential traffic disruptions.
+ //
+ // This is O(n^2) but the number of routes is bounded by the size
+ // of a DHCP packet (around 100 routes). Sorting both would be
+ // be marginally faster for large amounts of routes only and in 99%
+ // of cases it's going to be <5 routes.
+ var found bool
+ for _, newRoute := range newRoutes {
+ if isIPNetEqual(newRoute.Dest, route.Dst) {
+ found = true
+ break
+ }
+ }
+ if !found {
+ err := netlink.RouteDel(&route)
+ if !os.IsNotExist(err) && err != nil {
+ return fmt.Errorf("failed to delete DHCP route: %w", err)
+ }
+ }
+ }
+
+ for _, route := range newRoutes {
+ newRoute := netlink.Route{
+ Protocol: unix.RTPROT_DHCP,
+ Dst: route.Dest,
+ Gw: route.Router,
+ Src: lease.AssignedIP,
+ LinkIndex: iface.Attrs().Index,
+ Scope: netlink.SCOPE_UNIVERSE,
+ }
+ // Routes with a non-L3 gateway are link-scoped
+ if route.Router.IsUnspecified() {
+ newRoute.Scope = netlink.SCOPE_LINK
+ }
+ err := netlink.RouteReplace(&newRoute)
+ if err != nil {
+ return fmt.Errorf("failed to add %s: %w", route, err)
+ }
+ }
+ return nil
+ }
+}
diff --git a/osbase/net/dhcp4c/callback/callback_test.go b/osbase/net/dhcp4c/callback/callback_test.go
new file mode 100644
index 0000000..db616bb
--- /dev/null
+++ b/osbase/net/dhcp4c/callback/callback_test.go
@@ -0,0 +1,355 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+package callback
+
+import (
+ "fmt"
+ "math"
+ "net"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/vishvananda/netlink"
+ "golang.org/x/sys/unix"
+
+ "source.monogon.dev/osbase/net/dhcp4c"
+)
+
+func trivialLeaseFromNet(ipnet net.IPNet) *dhcp4c.Lease {
+ opts := make(dhcpv4.Options)
+ opts.Update(dhcpv4.OptSubnetMask(ipnet.Mask))
+ return &dhcp4c.Lease{
+ AssignedIP: ipnet.IP,
+ ExpiresAt: time.Now().Add(1 * time.Second),
+ Options: opts,
+ }
+}
+
// Fixture networks shared by the callback tests. Each comes with its
// broadcast address and the address of a router inside it. All IPv4 values
// are kept in their 4-byte form so they compare equal to what netlink reports.
var (
	testNet1          = net.IPNet{IP: net.IPv4(10, 0, 1, 2).To4(), Mask: net.IPv4Mask(255, 255, 255, 0)}
	testNet1Broadcast = net.IPv4(10, 0, 1, 255).To4()
	testNet1Router    = net.IPv4(10, 0, 1, 1).To4()

	testNet2          = net.IPNet{IP: net.IPv4(10, 0, 2, 2).To4(), Mask: net.IPv4Mask(255, 255, 255, 0)}
	testNet2Broadcast = net.IPv4(10, 0, 2, 255).To4()
	testNet2Router    = net.IPv4(10, 0, 2, 1).To4()

	// mainRoutingTable is RT_TABLE_MAIN; Linux puts all routes into this
	// table unless another one is specified.
	mainRoutingTable = 254
)
+
// TestAssignedIPCallback exercises ManageIP against a freshly created dummy
// interface per test case. It only runs when IN_KTEST is "true" because it
// creates links and addresses via netlink.
func TestAssignedIPCallback(t *testing.T) {
	if os.Getenv("IN_KTEST") != "true" {
		t.Skip("Not in ktest")
	}

	var tests = []struct {
		name          string
		initialAddrs  []netlink.Addr
		newLease      *dhcp4c.Lease
		expectedAddrs []netlink.Addr
	}{
		// Lifetimes are necessary, otherwise the Kernel sets the
		// IFA_F_PERMANENT flag behind our back.
		{
			name:          "RemoveOldIPs",
			initialAddrs:  []netlink.Addr{{IPNet: &testNet1, ValidLft: 60}, {IPNet: &testNet2, ValidLft: 60}},
			newLease:      nil,
			expectedAddrs: nil,
		},
		{
			name:         "IgnoresPermanentIPs",
			initialAddrs: []netlink.Addr{{IPNet: &testNet1, Flags: unix.IFA_F_PERMANENT}, {IPNet: &testNet2, ValidLft: 60}},
			newLease:     trivialLeaseFromNet(testNet2),
			expectedAddrs: []netlink.Addr{
				{IPNet: &testNet1, Flags: unix.IFA_F_PERMANENT, ValidLft: math.MaxUint32, PreferedLft: math.MaxUint32, Broadcast: testNet1Broadcast},
				{IPNet: &testNet2, ValidLft: 1, PreferedLft: 1, Broadcast: testNet2Broadcast},
			},
		},
		{
			name:         "AssignsNewIP",
			initialAddrs: []netlink.Addr{},
			newLease:     trivialLeaseFromNet(testNet2),
			expectedAddrs: []netlink.Addr{
				{IPNet: &testNet2, ValidLft: 1, PreferedLft: 1, Broadcast: testNet2Broadcast},
			},
		},
		// NOTE(review): this case starts without any addresses, so it
		// currently exercises the same path as AssignsNewIP rather than an
		// in-place update of an existing address — consider seeding
		// initialAddrs.
		{
			name:         "UpdatesIP",
			initialAddrs: []netlink.Addr{},
			newLease:     trivialLeaseFromNet(testNet1),
			expectedAddrs: []netlink.Addr{
				{IPNet: &testNet1, ValidLft: 1, PreferedLft: 1, Broadcast: testNet1Broadcast},
			},
		},
		{
			name:          "RemovesIPOnRelease",
			initialAddrs:  []netlink.Addr{{IPNet: &testNet1, ValidLft: 60, PreferedLft: 60}},
			newLease:      nil,
			expectedAddrs: nil,
		},
	}
	for i, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Each case gets its own uniquely named dummy link.
			testLink := &netlink.Dummy{
				LinkAttrs: netlink.LinkAttrs{
					Name:  fmt.Sprintf("aipcb-test-%d", i),
					Flags: unix.IFF_UP,
				},
			}
			if err := netlink.LinkAdd(testLink); err != nil {
				t.Fatalf("test cannot set up network interface: %v", err)
			}
			defer netlink.LinkDel(testLink)
			for _, addr := range test.initialAddrs {
				if err := netlink.AddrAdd(testLink, &addr); err != nil {
					t.Fatalf("test cannot set up initial addrs: %v", err)
				}
			}
			// Associate dynamically-generated interface name for later comparison
			for i := range test.expectedAddrs {
				test.expectedAddrs[i].Label = testLink.Name
				test.expectedAddrs[i].LinkIndex = testLink.Index
			}
			cb := ManageIP(testLink)
			if err := cb(test.newLease); err != nil {
				t.Fatalf("callback returned an error: %v", err)
			}
			// Read back what the kernel actually has and compare it in full.
			addrs, err := netlink.AddrList(testLink, netlink.FAMILY_V4)
			if err != nil {
				t.Fatalf("test cannot read back addrs from interface: %v", err)
			}
			if diff := cmp.Diff(test.expectedAddrs, addrs); diff != "" {
				t.Errorf("Wrong IPs on interface (-want +got):\n%s", diff)
			}
		})
	}
}
+
+func leaseAddRouter(lease *dhcp4c.Lease, router net.IP) *dhcp4c.Lease {
+ lease.Options.Update(dhcpv4.OptRouter(router))
+ return lease
+}
+
+func leaseAddClasslessRoutes(lease *dhcp4c.Lease, routes ...*dhcpv4.Route) *dhcp4c.Lease {
+ lease.Options.Update(dhcpv4.OptClasslessStaticRoute(routes...))
+ return lease
+}
+
// mustParseCIDR parses cidr and returns the network it describes (host bits
// masked off). It panics on malformed input, so it is only meant for
// test fixtures.
func mustParseCIDR(cidr string) *net.IPNet {
	_, network, err := net.ParseCIDR(cidr)
	if err != nil {
		panic(err)
	}
	// Equality checks don't know about net.IP's canonicalization rules, so
	// keep IPv4 networks in their 4-byte form.
	if v4 := network.IP.To4(); v4 != nil {
		network.IP = v4
	}
	return network
}
+
// TestDefaultRouteCallback exercises ManageRoutes against a freshly created
// dummy interface per test case. It only runs when IN_KTEST is "true" because
// it creates links, addresses and routes via netlink.
func TestDefaultRouteCallback(t *testing.T) {
	if os.Getenv("IN_KTEST") != "true" {
		t.Skip("Not in ktest")
	}
	// testRoute is only used as a route destination and not configured on any
	// interface.
	testRoute := net.IPNet{IP: net.IP{10, 0, 3, 0}, Mask: net.CIDRMask(24, 32)}

	// A test interface is set up for each test and assigned testNet1 and
	// testNet2 so that testNet1Router and testNet2Router are valid gateways
	// for routes in this environment. A LinkIndex of -1 is replaced by the
	// correct link index for this test interface at runtime for both
	// initialRoutes and expectedRoutes.
	var tests = []struct {
		name           string
		initialRoutes  []netlink.Route
		newLease       *dhcp4c.Lease
		expectedRoutes []netlink.Route
	}{
		{
			name:          "AddsDefaultRoute",
			initialRoutes: []netlink.Route{},
			newLease:      leaseAddRouter(trivialLeaseFromNet(testNet1), testNet1Router),
			expectedRoutes: []netlink.Route{{
				Protocol:  unix.RTPROT_DHCP,
				Dst:       mustParseCIDR("0.0.0.0/0"),
				Family:    unix.AF_INET,
				Gw:        testNet1Router,
				Src:       testNet1.IP,
				Table:     mainRoutingTable,
				LinkIndex: -1, // Filled in dynamically with test interface
				Type:      unix.RTN_UNICAST,
			}},
		},
		{
			name:           "IgnoresLeasesWithoutRouter",
			initialRoutes:  []netlink.Route{},
			newLease:       trivialLeaseFromNet(testNet1),
			expectedRoutes: nil,
		},
		{
			name: "RemovesUnrelatedOldRoutes",
			initialRoutes: []netlink.Route{{
				Dst:       &testRoute,
				Family:    unix.AF_INET,
				LinkIndex: -1, // Filled in dynamically with test interface
				Protocol:  unix.RTPROT_DHCP,
				Gw:        testNet2Router,
				Scope:     netlink.SCOPE_UNIVERSE,
			}},
			newLease:       nil,
			expectedRoutes: nil,
		},
		{
			// Routes from other protocols (here BIRD) must survive untouched.
			name: "IgnoresNonDHCPRoutes",
			initialRoutes: []netlink.Route{{
				Dst:       &testRoute,
				Family:    unix.AF_INET,
				LinkIndex: -1, // Filled in dynamically with test interface
				Protocol:  unix.RTPROT_BIRD,
				Gw:        testNet2Router,
			}},
			newLease: nil,
			expectedRoutes: []netlink.Route{{
				Protocol:  unix.RTPROT_BIRD,
				Dst:       &testRoute,
				Family:    unix.AF_INET,
				Gw:        testNet2Router,
				Table:     mainRoutingTable,
				LinkIndex: -1, // Filled in dynamically with test interface
				Type:      unix.RTN_UNICAST,
			}},
		},
		{
			name: "RemovesRoute",
			initialRoutes: []netlink.Route{{
				Dst:       nil,
				Family:    unix.AF_INET,
				LinkIndex: -1, // Filled in dynamically with test interface
				Protocol:  unix.RTPROT_DHCP,
				Gw:        testNet2Router,
			}},
			newLease:       nil,
			expectedRoutes: nil,
		},
		{
			name: "UpdatesRoute",
			initialRoutes: []netlink.Route{{
				Dst:       nil,
				Family:    unix.AF_INET,
				LinkIndex: -1, // Filled in dynamically with test interface
				Protocol:  unix.RTPROT_DHCP,
				Src:       testNet1.IP,
				Gw:        testNet1Router,
			}},
			newLease: leaseAddRouter(trivialLeaseFromNet(testNet2), testNet2Router),
			expectedRoutes: []netlink.Route{{
				Protocol:  unix.RTPROT_DHCP,
				Dst:       mustParseCIDR("0.0.0.0/0"),
				Family:    unix.AF_INET,
				Gw:        testNet2Router,
				Src:       testNet2.IP,
				Table:     mainRoutingTable,
				LinkIndex: -1, // Filled in dynamically with test interface
				Type:      unix.RTN_UNICAST,
			}},
		},
		{
			name:          "AddsClasslessStaticRoutes",
			initialRoutes: []netlink.Route{},
			newLease: leaseAddClasslessRoutes(
				// Router should be ignored
				leaseAddRouter(trivialLeaseFromNet(testNet1), testNet1Router),
				// P2P/foreign gateway route
				&dhcpv4.Route{Dest: mustParseCIDR("192.168.42.1/32"), Router: net.IPv4zero},
				// Standard route over foreign gateway set up by previous route
				&dhcpv4.Route{Dest: mustParseCIDR("0.0.0.0/0"), Router: net.IPv4(192, 168, 42, 1)},
			),
			expectedRoutes: []netlink.Route{{
				Protocol:  unix.RTPROT_DHCP,
				Dst:       mustParseCIDR("0.0.0.0/0"),
				Family:    unix.AF_INET,
				Gw:        net.IPv4(192, 168, 42, 1).To4(), // Equal() doesn't know about canonicalization
				Src:       testNet1.IP,
				Table:     mainRoutingTable,
				LinkIndex: -1, // Filled in dynamically with test interface
				Type:      unix.RTN_UNICAST,
			}, {
				Protocol:  unix.RTPROT_DHCP,
				Dst:       mustParseCIDR("192.168.42.1/32"),
				Family:    unix.AF_INET,
				Gw:        nil,
				Src:       testNet1.IP,
				Table:     mainRoutingTable,
				LinkIndex: -1, // Filled in dynamically with test interface
				Type:      unix.RTN_UNICAST,
				Scope:     unix.RT_SCOPE_LINK,
			}},
		},
	}
	for i, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			testLink := &netlink.Dummy{
				LinkAttrs: netlink.LinkAttrs{
					Name:  fmt.Sprintf("drcb-test-%d", i),
					Flags: unix.IFF_UP,
				},
			}
			if err := netlink.LinkAdd(testLink); err != nil {
				t.Fatalf("test cannot set up network interface: %v", err)
			}
			// Deferred calls run LIFO: the link is deleted first, then this
			// cleanup deletes every remaining IPv4 route in the namespace (not
			// only this test's), ignoring errors.
			defer func() { // Clean up after each test
				routes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{}, 0)
				if err == nil {
					for _, route := range routes {
						netlink.RouteDel(&route)
					}
				}
			}()
			defer netlink.LinkDel(testLink)
			if err := netlink.AddrAdd(testLink, &netlink.Addr{
				IPNet: &testNet1,
			}); err != nil {
				t.Fatalf("test cannot set up test addrs: %v", err)
			}
			if err := netlink.AddrAdd(testLink, &netlink.Addr{
				IPNet: &testNet2,
			}); err != nil {
				t.Fatalf("test cannot set up test addrs: %v", err)
			}
			// route is a per-iteration copy, so patching LinkIndex here does
			// not modify the test table.
			for _, route := range test.initialRoutes {
				if route.LinkIndex == -1 {
					route.LinkIndex = testLink.Index
				}
				if err := netlink.RouteAdd(&route); err != nil {
					t.Fatalf("test cannot set up initial routes: %v", err)
				}
			}
			for i := range test.expectedRoutes {
				if test.expectedRoutes[i].LinkIndex == -1 {
					test.expectedRoutes[i].LinkIndex = testLink.Index
				}
			}

			cb := ManageRoutes(testLink)
			if err := cb(test.newLease); err != nil {
				t.Fatalf("callback returned an error: %v", err)
			}
			routes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{}, 0)
			if err != nil {
				t.Fatalf("test cannot read back routes: %v", err)
			}
			var notKernelRoutes []netlink.Route
			for _, route := range routes {
				if route.Protocol != unix.RTPROT_KERNEL { // Filter kernel-managed routes
					notKernelRoutes = append(notKernelRoutes, route)
				}
			}
			if diff := cmp.Diff(test.expectedRoutes, notKernelRoutes); diff != "" {
				t.Errorf("Expected route mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
diff --git a/osbase/net/dhcp4c/dhcpc.go b/osbase/net/dhcp4c/dhcpc.go
new file mode 100644
index 0000000..ad4c9dd
--- /dev/null
+++ b/osbase/net/dhcp4c/dhcpc.go
@@ -0,0 +1,709 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
// Package dhcp4c implements a DHCPv4 Client as specified in RFC2131 (with some
// notable deviations). It implements only the DHCP state machine itself. All
// configuration derived from a lease, including the interface IP address, is
// left to consumers, which are handed every acquired, renewed or lost lease
// through a LeaseCallback.
+package dhcp4c
+
+import (
+ "context"
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "net"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/insomniacslk/dhcp/iana"
+
+ "source.monogon.dev/osbase/net/dhcp4c/transport"
+ "source.monogon.dev/osbase/supervisor"
+)
+
// state is a state of the client's DHCP finite state machine (see
// Client.state). The constants below are its possible values.
type state int

const (
	// stateDiscovering sends broadcast DHCPDISCOVER messages to the network
	// and waits for either a DHCPOFFER or (in case of Rapid Commit) DHCPACK.
	stateDiscovering state = iota
	// stateRequesting sends broadcast DHCPREQUEST messages containing the
	// server identifier for the selected lease and waits for a DHCPACK or a
	// DHCPNAK. If it doesn't get either it transitions back into discovering.
	stateRequesting
	// stateBound just waits until RenewDeadline (derived from RenewTimeValue,
	// half the lifetime by default) expires.
	stateBound
	// stateRenewing sends unicast DHCPREQUEST messages to the
	// currently-selected server and waits for either a DHCPACK or DHCPNAK
	// message. On DHCPACK it transitions to bound, otherwise to discovering.
	stateRenewing
	// stateRebinding sends broadcast DHCPREQUEST messages to the network and
	// waits for either a DHCPACK or DHCPNAK from any server. Response
	// processing is identical to stateRenewing.
	stateRebinding
)
+
+func (s state) String() string {
+ switch s {
+ case stateDiscovering:
+ return "DISCOVERING"
+ case stateRequesting:
+ return "REQUESTING"
+ case stateBound:
+ return "BOUND"
+ case stateRenewing:
+ return "RENEWING"
+ case stateRebinding:
+ return "REBINDING"
+ default:
+ return "INVALID"
+ }
+}
+
// internalOptions are the options the client itself needs. newMsg appends
// them to RequestedOptions in the Parameter Request List.
//
// This only requests SubnetMask and IPAddressLeaseTime as renewal and
// rebinding times are fine if they are just defaulted. They are respected (if
// valid, otherwise they are clamped to the nearest valid value) if sent by the
// server.
var internalOptions = dhcpv4.OptionCodeList{dhcpv4.OptionSubnetMask, dhcpv4.OptionIPAddressLeaseTime}
+
// Transport represents a mechanism over which DHCP messages can be exchanged
// with a server.
type Transport interface {
	// Send attempts to send the given DHCP payload message to the transport
	// target once. A nil error does not indicate that the message was
	// successfully received.
	Send(payload *dhcpv4.DHCPv4) error
	// SetReceiveDeadline sets a deadline for Receive() calls after which they
	// return with ErrDeadlineExceeded.
	SetReceiveDeadline(time.Time) error
	// Receive waits for a DHCP message to arrive and returns it. If the
	// deadline expires without a message arriving it will return
	// ErrDeadlineExceeded. If the message is completely malformed it will
	// return an instance of InvalidMessageError.
	Receive() (*dhcpv4.DHCPv4, error)
	// Close closes the given transport. Calls to any of the above methods will
	// fail if the transport is closed. Specific transports can be reopened
	// after being closed.
	Close() error
}
+
// UnicastTransport represents a mechanism over which DHCP messages can be
// exchanged with a single server over an arbitrary IPv4-based network.
// Implementers need to support servers running outside the local network via a
// router. The client keeps one in Client.unicastConn for the Bound and
// Renewing states.
type UnicastTransport interface {
	Transport
	// Open connects the transport to a new unicast target, sending from
	// bindIP. Can only be called after calling Close() or after creating a
	// new transport.
	Open(serverIP, bindIP net.IP) error
}
+
// BroadcastTransport represents a mechanism over which DHCP messages can be
// exchanged with all servers on a Layer 2 broadcast domain. Implementers need
// to support sending and receiving messages without any IP being configured on
// the interface. The client keeps one in Client.broadcastConn for the
// Discovering, Requesting and Rebinding states.
type BroadcastTransport interface {
	Transport
	// Open connects the transport. Can only be called after calling Close() or
	// after creating a new transport.
	Open() error
}
+
// LeaseCallback is invoked by the Client every time a lease is acquired,
// renewed or lost (see Client.LeaseCallback). A returned error is reported
// back to the client. NOTE(review): a lost lease is presumably signalled by a
// nil *Lease — the callbacks in package callback treat nil that way; confirm
// against Client.Run.
type LeaseCallback func(*Lease) error
+
// Client implements a DHCPv4 client.
//
// Note that the size of all data sent to the server (RequestedOptions,
// ClientIdentifier, VendorClassIdentifier and ExtraRequestOptions) should be
// kept reasonably small (<500 bytes) in order to maximize the chance that
// requests can be properly transmitted.
//
// Use NewClient to create one; the unexported fields must be initialized by it.
type Client struct {
	// RequestedOptions contains a list of extra options this client is
	// interested in
	RequestedOptions dhcpv4.OptionCodeList

	// ClientIdentifier is used by the DHCP server to identify this client.
	// If empty, on Ethernet the MAC address is used instead.
	ClientIdentifier []byte

	// VendorClassIdentifier is used by the DHCP server to identify options
	// specific to this type of clients and to populate the vendor-specific
	// option (43).
	VendorClassIdentifier string

	// ExtraRequestOptions are extra options sent to the server.
	ExtraRequestOptions dhcpv4.Options

	// Backoff strategies for each state. These all have sane defaults,
	// override them only if necessary.
	DiscoverBackoff    backoff.BackOff
	AcceptOfferBackoff backoff.BackOff
	RenewBackoff       backoff.BackOff
	RebindBackoff      backoff.BackOff

	// state is the current state of the DHCP state machine.
	state state

	// iface is the interface this client runs on; its hardware address and
	// MTU are used when building messages.
	iface *net.Interface

	// now can be used to override time for testing
	now func() time.Time

	// LeaseCallback is called every time a lease is acquired, renewed or lost
	LeaseCallback LeaseCallback

	// Valid in states Discovering, Requesting, Rebinding
	broadcastConn BroadcastTransport

	// Valid in states Requesting
	offer *dhcpv4.DHCPv4

	// Valid in states Bound, Renewing
	unicastConn UnicastTransport

	// Valid in states Bound, Renewing, Rebinding
	lease              *dhcpv4.DHCPv4
	leaseDeadline      time.Time
	leaseBoundDeadline time.Time
	leaseRenewDeadline time.Time
}
+
+// defaultBackoffOpts can be passed to NewExponentialBackOff and configures it
+// to retry infinitely and use a DHCP-appropriate InitialInterval.
+func defaultBackoffOpts(b *backoff.ExponentialBackOff) {
+ b.MaxElapsedTime = 0 // No Timeout
+ // Lots of servers wait 1s for existing users of an IP. Wait at least for
+ // that and keep some slack for randomization, communication and processing
+ // overhead.
+ b.InitialInterval = 1400 * time.Millisecond
+ b.MaxInterval = 30 * time.Second
+ b.RandomizationFactor = 0.2
+}
+
+// NewClient instantiates (but doesn't start) a new DHCPv4 client.
+// To have a working client it's required to set LeaseCallback to something
+// that is capable of configuring the IP address on the given interface. Unless
+// managed through external means like a routing protocol, setting the default
+// route is also required. A simple example with the callback package thus
+// looks like this:
+//
+// c := dhcp4c.NewClient(yourInterface)
+// c.LeaseCallback = callback.Compose(callback.ManageIP(yourInterface), callback.ManageDefaultRoute(yourInterface))
+// c.Run(ctx)
+func NewClient(iface *net.Interface) (*Client, error) {
+ broadcastConn := transport.NewBroadcastTransport(iface)
+
+ // broadcastConn needs to be open in stateDiscovering
+ if err := broadcastConn.Open(); err != nil {
+ return nil, fmt.Errorf("failed to create DHCP broadcast transport: %w", err)
+ }
+
+ discoverBackoff := backoff.NewExponentialBackOff(defaultBackoffOpts)
+
+ acceptOfferBackoff := backoff.NewExponentialBackOff(defaultBackoffOpts,
+ // Abort after 30s and go back to discovering
+ backoff.WithMaxElapsedTime(30*time.Second))
+
+ renewBackoff := backoff.NewExponentialBackOff(defaultBackoffOpts,
+ // Increase maximum interval to reduce chatter when the server is down
+ backoff.WithMaxInterval(5*time.Minute))
+
+ rebindBackoff := backoff.NewExponentialBackOff(defaultBackoffOpts,
+ // Increase maximum interval to reduce chatter when the server is down
+ backoff.WithMaxInterval(5*time.Minute))
+
+ // Check if the hardware address contains at least one non-zero value.
+ // This exists to catch undefined/non-supplied hardware address values,
+ // it does not check for L2 protocol-specific hardware address constraints.
+ hasValidHWAddr := false
+ for _, b := range iface.HardwareAddr {
+ if b != 0x00 {
+ hasValidHWAddr = true
+ break
+ }
+ }
+ if !hasValidHWAddr {
+ return nil, fmt.Errorf("iface HardwareAddr is invalid (only zeroes or invalid length): %x", iface.HardwareAddr)
+ }
+
+ return &Client{
+ state: stateDiscovering,
+ broadcastConn: broadcastConn,
+ unicastConn: transport.NewUnicastTransport(iface),
+ iface: iface,
+ RequestedOptions: dhcpv4.OptionCodeList{},
+ now: time.Now,
+ DiscoverBackoff: discoverBackoff,
+ AcceptOfferBackoff: acceptOfferBackoff,
+ RenewBackoff: renewBackoff,
+ RebindBackoff: rebindBackoff,
+ }, nil
+}
+
// acceptableLease checks if the given lease is valid enough to even be
// processed. This is intentionally not exposed to users because under certain
// circumstances it can end up acquiring all available IP addresses from a
// server.
//
// It returns false for offers without an IPv4 server identifier, with a lease
// time below 1s, or with a yiaddr that is not an IPv4 global or link-local
// unicast address. Accepted offers are normalized IN PLACE: forbidden
// options are removed, renewal/rebinding times are clamped, sname/file are
// copied into options 66/67 and siaddr into option 150.
func (c *Client) acceptableLease(offer *dhcpv4.DHCPv4) bool {
	// RFC2131 Section 4.3.1 Table 3
	if offer.ServerIdentifier() == nil || offer.ServerIdentifier().To4() == nil {
		return false
	}
	// RFC2131 Section 4.3.1 Table 3
	// Minimum representable lease time is 1s (Section 1.1)
	if offer.IPAddressLeaseTime(0) < 1*time.Second {
		return false
	}

	// Ignore IPs that are in no way valid for an interface (multicast,
	// loopback, ...)
	if offer.YourIPAddr.To4() == nil || (!offer.YourIPAddr.IsGlobalUnicast() && !offer.YourIPAddr.IsLinkLocalUnicast()) {
		return false
	}

	// Technically the options Requested IP address, Parameter request list,
	// Client identifier and Maximum message size should be refused (MUST NOT),
	// but in the interest of interoperability let's simply remove them if they
	// are present.
	delete(offer.Options, dhcpv4.OptionRequestedIPAddress.Code())
	delete(offer.Options, dhcpv4.OptionParameterRequestList.Code())
	delete(offer.Options, dhcpv4.OptionClientIdentifier.Code())
	delete(offer.Options, dhcpv4.OptionMaximumDHCPMessageSize.Code())

	// Clamp rebinding times longer than the lease time. Otherwise the state
	// machine might misbehave.
	if offer.IPAddressRebindingTime(0) > offer.IPAddressLeaseTime(0) {
		offer.UpdateOption(dhcpv4.OptGeneric(dhcpv4.OptionRebindingTimeValue, dhcpv4.Duration(offer.IPAddressLeaseTime(0)).ToBytes()))
	}
	// Clamp renewal times longer than the rebinding time. Otherwise the state
	// machine might misbehave.
	// NOTE(review): IPAddressRebindingTime(0) presumably yields the default
	// (0) when the server sent no rebinding option, in which case any renewal
	// time gets rewritten to an explicit 0 here — confirm how the state
	// machine interprets a zero renewal time.
	if offer.IPAddressRenewalTime(0) > offer.IPAddressRebindingTime(0) {
		offer.UpdateOption(dhcpv4.OptGeneric(dhcpv4.OptionRenewTimeValue, dhcpv4.Duration(offer.IPAddressRebindingTime(0)).ToBytes()))
	}

	// Normalize two options that can be represented either inline or as
	// options.
	if len(offer.ServerHostName) > 0 {
		offer.Options[uint8(dhcpv4.OptionTFTPServerName)] = []byte(offer.ServerHostName)
	}
	if len(offer.BootFileName) > 0 {
		offer.Options[uint8(dhcpv4.OptionBootfileName)] = []byte(offer.BootFileName)
	}

	// Normalize siaddr to option 150 (see RFC5859)
	if len(offer.GetOneOption(dhcpv4.OptionTFTPServerAddress)) == 0 {
		if offer.ServerIPAddr.To4() != nil && (offer.ServerIPAddr.IsGlobalUnicast() || offer.ServerIPAddr.IsLinkLocalUnicast()) {
			offer.Options[uint8(dhcpv4.OptionTFTPServerAddress)] = offer.ServerIPAddr.To4()
		}
	}

	return true
}
+
// earliestDeadline returns whichever of dl1 and dl2 comes first. If dl1 is not
// strictly before dl2 (including when both denote the same instant), dl2 is
// returned.
func earliestDeadline(dl1, dl2 time.Time) time.Time {
	if !dl1.Before(dl2) {
		return dl2
	}
	return dl1
}
+
+// newXID generates a new transaction ID
+func (c *Client) newXID() (dhcpv4.TransactionID, error) {
+ var xid dhcpv4.TransactionID
+ if _, err := io.ReadFull(rand.Reader, xid[:]); err != nil {
+ return xid, fmt.Errorf("cannot read randomness for transaction ID: %w", err)
+ }
+ return xid, nil
+}
+
+// As most servers out there cannot do reassembly, let's just hope for the best
+// and provide the local interface MTU. If the packet is too big it won't work
+// anyways. Also clamp to the biggest representable MTU in DHCPv4 (2 bytes
+// unsigned int).
+func (c *Client) maxMsgSize() uint16 {
+ if c.iface.MTU < math.MaxUint16 {
+ return uint16(c.iface.MTU)
+ } else {
+ return math.MaxUint16
+ }
+}
+
+// newMsg builds a client message (BOOTREQUEST) of type t with a fresh random
+// transaction ID. The client identifier is attached whenever one is
+// configured. DISCOVER, REQUEST and INFORM messages additionally get:
+//   - the parameter request list (RequestedOptions followed by internalOptions),
+//   - the maximum message size,
+//   - the vendor class identifier, if one is set.
+// ExtraRequestOptions are copied in last and override any option with the
+// same code. All address fields start out as 0.0.0.0; callers fill them in as
+// needed.
+func (c *Client) newMsg(t dhcpv4.MessageType) (*dhcpv4.DHCPv4, error) {
+	xid, err := c.newXID()
+	if err != nil {
+		return nil, err
+	}
+
+	options := dhcpv4.Options{}
+	options.Update(dhcpv4.OptMessageType(t))
+	if len(c.ClientIdentifier) != 0 {
+		options.Update(dhcpv4.OptClientIdentifier(c.ClientIdentifier))
+	}
+
+	switch t {
+	case dhcpv4.MessageTypeDiscover, dhcpv4.MessageTypeRequest, dhcpv4.MessageTypeInform:
+		// Build into a fresh slice so that appending internalOptions can never
+		// write into the backing array of c.RequestedOptions.
+		requested := append(dhcpv4.OptionCodeList(nil), c.RequestedOptions...)
+		requested = append(requested, internalOptions...)
+		options.Update(dhcpv4.OptParameterRequestList(requested...))
+		options.Update(dhcpv4.OptMaxMessageSize(c.maxMsgSize()))
+		if c.VendorClassIdentifier != "" {
+			options.Update(dhcpv4.OptClassIdentifier(c.VendorClassIdentifier))
+		}
+		// Applied last on purpose: user-supplied raw options win over the
+		// defaults set above.
+		for code, value := range c.ExtraRequestOptions {
+			options[code] = value
+		}
+	}
+
+	return &dhcpv4.DHCPv4{
+		OpCode:        dhcpv4.OpcodeBootRequest,
+		HWType:        iana.HWTypeEthernet,
+		ClientHWAddr:  c.iface.HardwareAddr,
+		TransactionID: xid,
+		ClientIPAddr:  net.IPv4zero,
+		YourIPAddr:    net.IPv4zero,
+		ServerIPAddr:  net.IPv4zero,
+		GatewayIPAddr: net.IPv4zero,
+		Options:       options,
+	}, nil
+}
+
+// transactionStateSpec describes a state which is driven by a DHCP message
+// transaction (sending a specific message and then transitioning into a
+// different state depending on the received messages). It is consumed by
+// runTransactionState, which performs one send/receive round per call.
+type transactionStateSpec struct {
+	// ctx is a context for canceling the process
+	ctx context.Context
+
+	// transport is used to send and receive messages in this state
+	transport Transport
+
+	// stateDeadline is a fixed external deadline for how long the FSM can
+	// remain in this state.
+	// If it's exceeded the stateDeadlineExceeded callback is called and
+	// responsible for transitioning out of this state. It can be left empty to
+	// signal that there's no external deadline for the state.
+	stateDeadline time.Time
+
+	// backoff controls how long to wait for answers until handing control back
+	// to the FSM.
+	// Since the FSM hasn't advanced until then this means we just get called
+	// again and retransmit.
+	backoff backoff.BackOff
+
+	// requestType is the type of DHCP request sent out in this state. This is
+	// used to populate the default options for the message.
+	requestType dhcpv4.MessageType
+
+	// setExtraOptions can modify the request and set extra options before
+	// transmitting. Returning an error here aborts the FSM and can be used to
+	// terminate when no valid request can be constructed.
+	setExtraOptions func(msg *dhcpv4.DHCPv4) error
+
+	// handleMessage gets called for every parseable (not necessarily valid)
+	// DHCP message received by the transport whose transaction ID matches the
+	// request. It is responsible for advancing the FSM if the required
+	// information is present. Its return value means:
+	//   - nil: the message advanced the FSM and the current round ends.
+	//   - an error matching ErrInvalidMsg (errors.Is): the message is ignored
+	//     and the next one is awaited.
+	//   - any other error: aborts the FSM and is returned to the caller.
+	// sentTime is the time the request was sent.
+	handleMessage func(msg *dhcpv4.DHCPv4, sentTime time.Time) error
+
+	// stateDeadlineExceeded gets called if either the backoff returns
+	// backoff.Stop or the stateDeadline runs out. It is responsible for
+	// advancing the FSM into the next state. It is invoked without a nil check,
+	// so it must be set whenever the backoff can return backoff.Stop or
+	// stateDeadline is non-zero.
+	stateDeadlineExceeded func() error
+}
+
+// runTransactionState performs one send/receive round of the
+// transaction-driven state described by s:
+//  1. A message of s.requestType is built and amended by s.setExtraOptions.
+//  2. s.backoff determines how long to wait for replies. If it returns
+//     backoff.Stop, s.stateDeadlineExceeded is called instead of sending.
+//  3. The wait is capped at s.stateDeadline (if set). If less than 10ms of the
+//     state remain, s.stateDeadlineExceeded is called instead of sending.
+//  4. The message is sent. Every received reply carrying the same transaction
+//     ID is passed to s.handleMessage until one is accepted or the receive
+//     deadline passes.
+//
+// It returns nil both when the FSM advanced and when the receive deadline
+// passed without an accepted reply. In the latter case the caller is expected
+// to call it again, which retransmits. If ctx is canceled, c.cleanup is run
+// and ctx.Err() is returned; cancellation is only noticed once Receive
+// returns.
+func (c *Client) runTransactionState(s transactionStateSpec) error {
+	// sentTime anchors the receive deadline and is handed to handleMessage.
+	sentTime := c.now()
+	msg, err := c.newMsg(s.requestType)
+	if err != nil {
+		return fmt.Errorf("failed to get new DHCP message: %w", err)
+	}
+	if err := s.setExtraOptions(msg); err != nil {
+		return fmt.Errorf("failed to create DHCP message: %w", err)
+	}
+
+	// A backoff step is consumed here even if the state deadline check below
+	// ends up leaving the state without sending anything.
+	wait := s.backoff.NextBackOff()
+	if wait == backoff.Stop {
+		return s.stateDeadlineExceeded()
+	}
+
+	receiveDeadline := sentTime.Add(wait)
+	if !s.stateDeadline.IsZero() {
+		receiveDeadline = earliestDeadline(s.stateDeadline, receiveDeadline)
+
+		// Jump out if deadline expires in less than 10ms. Minimum lease time is 1s
+		// and if we have less than 10ms to wait for an answer before switching
+		// state it makes no sense to send out another request. This nearly
+		// eliminates the problem of sending two different requests back-to-back.
+		if s.stateDeadline.Add(-10 * time.Millisecond).Before(sentTime) {
+			return s.stateDeadlineExceeded()
+		}
+	}
+
+	if err := s.transport.Send(msg); err != nil {
+		return fmt.Errorf("failed to send message: %w", err)
+	}
+
+	if err := s.transport.SetReceiveDeadline(receiveDeadline); err != nil {
+		return fmt.Errorf("failed to set deadline: %w", err)
+	}
+
+	for {
+		offer, err := s.transport.Receive()
+		// Cancellation takes precedence over whatever Receive returned.
+		select {
+		case <-s.ctx.Done():
+			c.cleanup()
+			return s.ctx.Err()
+		default:
+		}
+		// No accepted reply before the deadline: hand control back to the
+		// FSM, which calls us again to retransmit.
+		if errors.Is(err, transport.ErrDeadlineExceeded) {
+			return nil
+		}
+		var e transport.InvalidMessageError
+		if errors.As(err, &e) {
+			// Packet couldn't be read. Maybe log at some point in the future.
+			continue
+		}
+		if err != nil {
+			return fmt.Errorf("failed to receive packet: %w", err)
+		}
+		if offer.TransactionID != msg.TransactionID { // Not our transaction
+			continue
+		}
+		// nil: the FSM advanced. ErrInvalidMsg: ignore it and keep listening.
+		// Any other error is fatal.
+		err = s.handleMessage(offer, sentTime)
+		if err == nil {
+			return nil
+		} else if !errors.Is(err, ErrInvalidMsg) {
+			return err
+		}
+	}
+}
+
+// ErrInvalidMsg is returned by transactionStateSpec.handleMessage callbacks to
+// signal that a received message does not advance the state machine and should
+// be ignored; runTransactionState then keeps waiting for further messages.
+var ErrInvalidMsg = errors.New("invalid message")
+
+// runState executes one step of the DHCP client state machine for c.state.
+// The transaction-driven states (discovering, requesting, renewing, rebinding)
+// perform a single send/receive round via runTransactionState. The bound state
+// waits until the renewal time (T1) or until ctx is canceled. Transitions are
+// recorded by mutating c. A non-nil error means the state machine cannot
+// continue.
+func (c *Client) runState(ctx context.Context) error {
+	switch c.state {
+	case stateDiscovering:
+		return c.runTransactionState(transactionStateSpec{
+			ctx:         ctx,
+			transport:   c.broadcastConn,
+			backoff:     c.DiscoverBackoff,
+			requestType: dhcpv4.MessageTypeDiscover,
+			setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+				// Signal that a direct DHCPACK (Rapid Commit, RFC 4039) is fine.
+				msg.UpdateOption(dhcpv4.OptGeneric(dhcpv4.OptionRapidCommit, []byte{}))
+				return nil
+			},
+			handleMessage: func(offer *dhcpv4.DHCPv4, sentTime time.Time) error {
+				switch offer.MessageType() {
+				case dhcpv4.MessageTypeOffer:
+					if c.acceptableLease(offer) {
+						c.offer = offer
+						c.AcceptOfferBackoff.Reset()
+						c.state = stateRequesting
+						return nil
+					}
+				case dhcpv4.MessageTypeAck:
+					// Rapid Commit: the server skipped the offer.
+					if c.acceptableLease(offer) {
+						return c.transitionToBound(offer, sentTime)
+					}
+				}
+				return ErrInvalidMsg
+			},
+			// runTransactionState calls this without a nil check once
+			// DiscoverBackoff returns backoff.Stop (discovering has no state
+			// deadline). Leaving it unset would panic with a nil function call,
+			// so report that discovery gave up instead.
+			stateDeadlineExceeded: func() error {
+				return errors.New("discover backoff exhausted without an acceptable lease")
+			},
+		})
+	case stateRequesting:
+		return c.runTransactionState(transactionStateSpec{
+			ctx:         ctx,
+			transport:   c.broadcastConn,
+			backoff:     c.AcceptOfferBackoff,
+			requestType: dhcpv4.MessageTypeRequest,
+			setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+				msg.UpdateOption(dhcpv4.OptServerIdentifier(c.offer.ServerIdentifier()))
+				msg.TransactionID = c.offer.TransactionID
+				msg.UpdateOption(dhcpv4.OptRequestedIPAddress(c.offer.YourIPAddr))
+				return nil
+			},
+			handleMessage: func(msg *dhcpv4.DHCPv4, sentTime time.Time) error {
+				switch msg.MessageType() {
+				case dhcpv4.MessageTypeAck:
+					if c.acceptableLease(msg) {
+						return c.transitionToBound(msg, sentTime)
+					}
+				case dhcpv4.MessageTypeNak:
+					c.requestingToDiscovering()
+					return nil
+				}
+				return ErrInvalidMsg
+			},
+			stateDeadlineExceeded: func() error {
+				c.requestingToDiscovering()
+				return nil
+			},
+		})
+	case stateBound:
+		// Use an explicit timer rather than time.After so that it is stopped
+		// right away if ctx is canceled before the renewal time.
+		renewTimer := time.NewTimer(c.leaseBoundDeadline.Sub(c.now()))
+		defer renewTimer.Stop()
+		select {
+		case <-renewTimer.C:
+			c.state = stateRenewing
+			c.RenewBackoff.Reset()
+			return nil
+		case <-ctx.Done():
+			c.cleanup()
+			return ctx.Err()
+		}
+	case stateRenewing:
+		return c.runTransactionState(transactionStateSpec{
+			ctx:           ctx,
+			transport:     c.unicastConn,
+			backoff:       c.RenewBackoff,
+			requestType:   dhcpv4.MessageTypeRequest,
+			stateDeadline: c.leaseRenewDeadline,
+			setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+				msg.ClientIPAddr = c.lease.YourIPAddr
+				return nil
+			},
+			handleMessage: func(ack *dhcpv4.DHCPv4, sentTime time.Time) error {
+				switch ack.MessageType() {
+				case dhcpv4.MessageTypeAck:
+					if c.acceptableLease(ack) {
+						return c.transitionToBound(ack, sentTime)
+					}
+				case dhcpv4.MessageTypeNak:
+					return c.leaseToDiscovering()
+				}
+				return ErrInvalidMsg
+			},
+			stateDeadlineExceeded: func() error {
+				c.state = stateRebinding
+				if err := c.switchToBroadcast(); err != nil {
+					return fmt.Errorf("failed to switch to broadcast: %w", err)
+				}
+				c.RebindBackoff.Reset()
+				return nil
+			},
+		})
+	case stateRebinding:
+		return c.runTransactionState(transactionStateSpec{
+			ctx:           ctx,
+			transport:     c.broadcastConn,
+			backoff:       c.RebindBackoff,
+			stateDeadline: c.leaseDeadline,
+			requestType:   dhcpv4.MessageTypeRequest,
+			setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+				msg.ClientIPAddr = c.lease.YourIPAddr
+				return nil
+			},
+			handleMessage: func(ack *dhcpv4.DHCPv4, sentTime time.Time) error {
+				switch ack.MessageType() {
+				case dhcpv4.MessageTypeAck:
+					if c.acceptableLease(ack) {
+						return c.transitionToBound(ack, sentTime)
+					}
+				case dhcpv4.MessageTypeNak:
+					return c.leaseToDiscovering()
+				}
+				return ErrInvalidMsg
+			},
+			stateDeadlineExceeded: func() error {
+				return c.leaseToDiscovering()
+			},
+		})
+	}
+	return errors.New("state machine in invalid state")
+}
+
+// Run drives the DHCP client state machine until ctx is canceled or an
+// unrecoverable error occurs, logging every state transition to the supervisor
+// logger. It panics if LeaseCallback has not been set, as that is a
+// programming error.
+func (c *Client) Run(ctx context.Context) error {
+	if c.LeaseCallback == nil {
+		panic("LeaseCallback must be set before calling Run")
+	}
+	l := supervisor.Logger(ctx)
+	for {
+		previous := c.state
+		err := c.runState(ctx)
+		if err != nil {
+			return err
+		}
+		if current := c.state; current != previous {
+			l.Infof("%s => %s", previous, current)
+		}
+	}
+}
+
+// cleanup closes both transports and, if a lease is currently held, reports
+// its loss via LeaseCallback(nil). It is invoked on context cancellation and
+// is best-effort: errors from Close and from the callback are deliberately
+// ignored.
+func (c *Client) cleanup() {
+	_ = c.unicastConn.Close()
+	if c.lease != nil {
+		_ = c.LeaseCallback(nil)
+	}
+	_ = c.broadcastConn.Close()
+}
+
+// requestingToDiscovering abandons the pending offer (after a DHCPNAK or when
+// requesting timed out) and restarts discovery with a fresh backoff.
+func (c *Client) requestingToDiscovering() {
+	c.state = stateDiscovering
+	c.DiscoverBackoff.Reset()
+	c.offer = nil
+}
+
+// leaseToDiscovering drops the current lease (after a DHCPNAK or once the
+// lease has run out) and restarts discovery. When called from the renewing
+// state, the unicast transport is first swapped for the broadcast one, which
+// discovery uses. Finally, LeaseCallback is told that no lease is held anymore.
+func (c *Client) leaseToDiscovering() error {
+	if c.state == stateRenewing {
+		err := c.switchToBroadcast()
+		if err != nil {
+			return err
+		}
+	}
+	c.state, c.lease = stateDiscovering, nil
+	c.DiscoverBackoff.Reset()
+	err := c.LeaseCallback(nil)
+	if err != nil {
+		return fmt.Errorf("lease callback failed: %w", err)
+	}
+	return nil
+}
+
+func leaseFromAck(ack *dhcpv4.DHCPv4, expiresAt time.Time) *Lease {
+ if ack == nil {
+ return nil
+ }
+ return &Lease{Options: ack.Options, AssignedIP: ack.YourIPAddr, ExpiresAt: expiresAt}
+}
+
+// transitionToBound enters the bound state with the lease carried by ack. All
+// deadlines are measured from sentTime, the time the request was sent:
+//   - leaseDeadline: lease expiry.
+//   - leaseBoundDeadline: renewal time (T1), defaulting to 50% of the lease time.
+//   - leaseRenewDeadline: rebinding time (T2), defaulting to 85% of the lease
+//     time.
+// LeaseCallback is informed of the new lease before transports are touched.
+// Unless the client is renewing (which already uses the unicast transport), it
+// then switches from broadcast to unicast towards the server.
+func (c *Client) transitionToBound(ack *dhcpv4.DHCPv4, sentTime time.Time) error {
+	// The lease time is guaranteed to exist because leases without one are
+	// filtered out, so the zero default is never used.
+	lt := ack.IPAddressLeaseTime(0)
+	t1 := ack.IPAddressRenewalTime(time.Duration(float64(lt) * 0.5))
+	t2 := ack.IPAddressRebindingTime(time.Duration(float64(lt) * 0.85))
+	c.leaseDeadline = sentTime.Add(lt)
+	c.leaseBoundDeadline = sentTime.Add(t1)
+	c.leaseRenewDeadline = sentTime.Add(t2)
+
+	err := c.LeaseCallback(leaseFromAck(ack, c.leaseDeadline))
+	if err != nil {
+		return fmt.Errorf("lease callback failed: %w", err)
+	}
+
+	if c.state != stateRenewing {
+		err = c.switchToUnicast(ack.ServerIdentifier(), ack.YourIPAddr)
+		if err != nil {
+			return fmt.Errorf("failed to switch transports: %w", err)
+		}
+	}
+
+	c.lease, c.state = ack, stateBound
+	return nil
+}
+
+// switchToUnicast closes the broadcast transport and opens the unicast one,
+// passing it serverIP and bindIP.
+func (c *Client) switchToUnicast(serverIP, bindIP net.IP) error {
+	err := c.broadcastConn.Close()
+	if err != nil {
+		return fmt.Errorf("failed to close broadcast transport: %w", err)
+	}
+	err = c.unicastConn.Open(serverIP, bindIP)
+	if err != nil {
+		return fmt.Errorf("failed to open unicast transport: %w", err)
+	}
+	return nil
+}
+
+// switchToBroadcast closes the unicast transport and reopens the broadcast one.
+func (c *Client) switchToBroadcast() error {
+	err := c.unicastConn.Close()
+	if err != nil {
+		return fmt.Errorf("failed to close unicast transport: %w", err)
+	}
+	err = c.broadcastConn.Open()
+	if err != nil {
+		return fmt.Errorf("failed to open broadcast transport: %w", err)
+	}
+	return nil
+}
diff --git a/osbase/net/dhcp4c/dhcpc_test.go b/osbase/net/dhcp4c/dhcpc_test.go
new file mode 100644
index 0000000..67651a7
--- /dev/null
+++ b/osbase/net/dhcp4c/dhcpc_test.go
@@ -0,0 +1,501 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+package dhcp4c
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/stretchr/testify/assert"
+
+ "source.monogon.dev/osbase/net/dhcp4c/transport"
+)
+
+// fakeTime is a manually advanced clock. Its Now method is plugged into the
+// client as its time source so tests control the passage of time.
+type fakeTime struct {
+	time time.Time
+}
+
+// newFakeTime returns a fake clock set to t.
+func newFakeTime(t time.Time) *fakeTime {
+	ft := new(fakeTime)
+	ft.time = t
+	return ft
+}
+
+// Now returns the current fake time.
+func (ft *fakeTime) Now() time.Time { return ft.time }
+
+// Advance moves the fake clock forward by d.
+func (ft *fakeTime) Advance(d time.Duration) { ft.time = ft.time.Add(d) }
+
+// mockTransport is a scripted in-memory transport. It records the last sent
+// packet and the last receive deadline. Queued packets are handed out one by
+// one, each as a fresh copy carrying the transaction ID of the last sent
+// packet, after which transport.ErrDeadlineExceeded is reported.
+type mockTransport struct {
+	sentPacket     *dhcpv4.DHCPv4
+	sendError      error
+	setDeadline    time.Time
+	receivePackets []*dhcpv4.DHCPv4
+	receiveError   error
+	receiveIdx     int
+	closed         bool
+}
+
+// sendPackets queues pkts for subsequent Receive calls, replacing anything
+// still queued.
+func (mt *mockTransport) sendPackets(pkts ...*dhcpv4.DHCPv4) {
+	mt.receivePackets, mt.receiveIdx = pkts, 0
+}
+
+func (mt *mockTransport) Open() error {
+	mt.closed = false
+	return nil
+}
+
+func (mt *mockTransport) Send(payload *dhcpv4.DHCPv4) error {
+	mt.sentPacket = payload
+	return mt.sendError
+}
+
+func (mt *mockTransport) Receive() (*dhcpv4.DHCPv4, error) {
+	switch {
+	case mt.receiveError != nil:
+		return nil, mt.receiveError
+	case mt.receiveIdx >= len(mt.receivePackets):
+		return nil, transport.ErrDeadlineExceeded
+	}
+	// Round-trip through the wire format so the client gets an independent
+	// copy and cannot mutate the queued packet.
+	clone, err := dhcpv4.FromBytes(mt.receivePackets[mt.receiveIdx].ToBytes())
+	if err != nil {
+		panic("ToBytes => FromBytes failed")
+	}
+	clone.TransactionID = mt.sentPacket.TransactionID
+	mt.receiveIdx++
+	return clone, nil
+}
+
+func (mt *mockTransport) SetReceiveDeadline(t time.Time) error {
+	mt.setDeadline = t
+	return nil
+}
+
+func (mt *mockTransport) Close() error {
+	mt.closed = true
+	return nil
+}
+
+// unicastMockTransport is the unicast counterpart of mockTransport. It also
+// records the server and bind addresses passed to Open, and panics on a double
+// open, which would indicate a transport-switching bug in the client.
+type unicastMockTransport struct {
+	mockTransport
+	serverIP net.IP
+	bindIP   net.IP
+}
+
+func (umt *unicastMockTransport) Open(serverIP, bindIP net.IP) error {
+	if umt.serverIP != nil {
+		panic("double-open of unicast transport")
+	}
+	umt.serverIP = serverIP
+	umt.bindIP = bindIP
+	// Mirror mockTransport.Open: a (re)opened transport is no longer closed.
+	// Without this, closed stayed true after a close/reopen cycle.
+	umt.closed = false
+	return nil
+}
+
+func (umt *unicastMockTransport) Close() error {
+	umt.serverIP = nil
+	umt.bindIP = nil
+	return umt.mockTransport.Close()
+}
+
+// mockBackoff replays a fixed list of durations. Once the list is exhausted it
+// returns backoff.Stop, unless indefinite is set, in which case it starts over
+// from the beginning. An empty list always yields backoff.Stop.
+type mockBackoff struct {
+	indefinite bool
+	values     []time.Duration
+	idx        int
+}
+
+func newMockBackoff(vals []time.Duration, indefinite bool) *mockBackoff {
+	return &mockBackoff{values: vals, indefinite: indefinite}
+}
+
+func (mb *mockBackoff) NextBackOff() time.Duration {
+	// With nothing to replay, cycling indefinitely would otherwise panic with
+	// an integer modulo by zero below.
+	if len(mb.values) == 0 {
+		return backoff.Stop
+	}
+	if mb.idx < len(mb.values) || mb.indefinite {
+		val := mb.values[mb.idx%len(mb.values)]
+		mb.idx++
+		return val
+	}
+	return backoff.Stop
+}
+
+// Reset restarts the replay from the first value.
+func (mb *mockBackoff) Reset() {
+	mb.idx = 0
+}
+
+// TestClient_runTransactionState checks that a single transaction round sends a
+// message of the requested type which includes the options added by
+// setExtraOptions.
+func TestClient_runTransactionState(t *testing.T) {
+	clock := newFakeTime(time.Date(2020, 10, 28, 15, 02, 32, 352, time.UTC))
+	iface := &net.Interface{MTU: 9324, HardwareAddr: net.HardwareAddr{0x12, 0x23, 0x34, 0x45, 0x56, 0x67}}
+	c := Client{now: clock.Now, iface: iface}
+	tr := &mockTransport{}
+	spec := transactionStateSpec{
+		ctx:         context.Background(),
+		transport:   tr,
+		backoff:     newMockBackoff([]time.Duration{1 * time.Second}, true),
+		requestType: dhcpv4.MessageTypeDiscover,
+		setExtraOptions: func(msg *dhcpv4.DHCPv4) error {
+			msg.UpdateOption(dhcpv4.OptDomainName("just.testing.invalid"))
+			return nil
+		},
+		handleMessage: func(*dhcpv4.DHCPv4, time.Time) error { return nil },
+		stateDeadlineExceeded: func() error {
+			panic("shouldn't be called")
+		},
+	}
+
+	assert.NoError(t, c.runTransactionState(spec))
+	assert.Equal(t, "just.testing.invalid", tr.sentPacket.DomainName())
+	assert.Equal(t, dhcpv4.MessageTypeDiscover, tr.sentPacket.MessageType())
+}
+
+// TestAcceptableLease tests if a minimal valid lease is accepted by
+// acceptableLease
+func TestAcceptableLease(t *testing.T) {
+ var c Client
+ offer := &dhcpv4.DHCPv4{
+ OpCode: dhcpv4.OpcodeBootReply,
+ }
+ offer.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer))
+ offer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+ offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(10 * time.Second))
+ offer.YourIPAddr = net.IP{192, 0, 2, 2}
+ assert.True(t, c.acceptableLease(offer), "Valid lease is not acceptable")
+}
+
+// dhcpClientPuppet bundles a Client wired to mock transports and a fake clock,
+// together with handles to those mocks, so tests can script its environment.
+type dhcpClientPuppet struct {
+	ft  *fakeTime
+	bmt *mockTransport
+	umt *unicastMockTransport
+	c   *Client
+}
+
+// newPuppetClient returns a puppet whose client starts in initState. The
+// discover, renew and rebind backoffs retry every second forever. The
+// accept-offer backoff waits 1s, then 2s, and then gives up.
+func newPuppetClient(initState state) *dhcpClientPuppet {
+	p := &dhcpClientPuppet{
+		ft:  newFakeTime(time.Date(2020, 10, 28, 15, 02, 32, 352, time.UTC)),
+		bmt: &mockTransport{},
+		umt: &unicastMockTransport{},
+	}
+	everySecond := func() *mockBackoff {
+		return newMockBackoff([]time.Duration{1 * time.Second}, true)
+	}
+	p.c = &Client{
+		state:              initState,
+		now:                p.ft.Now,
+		iface:              &net.Interface{MTU: 9324, HardwareAddr: net.HardwareAddr{0x12, 0x23, 0x34, 0x45, 0x56, 0x67}},
+		broadcastConn:      p.bmt,
+		unicastConn:        p.umt,
+		DiscoverBackoff:    everySecond(),
+		AcceptOfferBackoff: newMockBackoff([]time.Duration{1 * time.Second, 2 * time.Second}, false),
+		RenewBackoff:       everySecond(),
+		RebindBackoff:      everySecond(),
+	}
+	return p
+}
+
+// newResponse returns a bare server reply (BOOTREPLY) of message type m with no
+// other options or addresses set.
+func newResponse(m dhcpv4.MessageType) *dhcpv4.DHCPv4 {
+	resp := new(dhcpv4.DHCPv4)
+	resp.OpCode = dhcpv4.OpcodeBootReply
+	resp.UpdateOption(dhcpv4.OptMessageType(m))
+	return resp
+}
+
+// TestDiscoverRequesting tests if the DHCP state machine in discovering state
+// skips an unacceptable offer, selects the first valid one and transitions to
+// requesting state.
+func TestDiscoverRequesting(t *testing.T) {
+	p := newPuppetClient(stateDiscovering)
+
+	// A minimal valid lease
+	offer := newResponse(dhcpv4.MessageTypeOffer)
+	testIP := net.IP{192, 0, 2, 2}
+	offer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+	offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(10 * time.Second))
+	offer.YourIPAddr = testIP
+
+	// Intentionally bad offer with no lease time.
+	terribleOffer := newResponse(dhcpv4.MessageTypeOffer)
+	terribleOffer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 2}))
+	terribleOffer.YourIPAddr = net.IPv4(192, 0, 2, 2)
+
+	// Send the bad offer first, then the valid offer
+	p.bmt.sendPackets(terribleOffer, offer)
+
+	if err := p.c.runState(context.Background()); err != nil {
+		t.Fatal(err)
+	}
+	assert.Equal(t, stateRequesting, p.c.state, "DHCP client didn't process offer")
+	assert.Equal(t, testIP, p.c.offer.YourIPAddr, "DHCP client requested invalid offer")
+}
+
+// TestRequestingBound tests if the DHCP state machine in requesting state
+// processes a valid DHCPACK, hands the lease to LeaseCallback and transitions
+// to bound state.
+func TestRequestingBound(t *testing.T) {
+	p := newPuppetClient(stateRequesting)
+
+	offer := newResponse(dhcpv4.MessageTypeAck)
+	testIP := net.IP{192, 0, 2, 2}
+	offer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+	offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(10 * time.Second))
+	offer.YourIPAddr = testIP
+
+	p.bmt.sendPackets(offer)
+	p.c.offer = offer
+	p.c.LeaseCallback = func(lease *Lease) error {
+		assert.Equal(t, testIP, lease.AssignedIP, "new lease has wrong IP")
+		return nil
+	}
+
+	if err := p.c.runState(context.Background()); err != nil {
+		t.Fatal(err)
+	}
+	assert.Equal(t, stateBound, p.c.state, "DHCP client didn't process offer")
+	assert.Equal(t, testIP, p.c.lease.YourIPAddr, "DHCP client requested invalid offer")
+}
+
+// TestRequestingDiscover tests if the DHCP state machine in requesting state
+// transitions back to discovering if it takes too long to get a valid DHCPACK.
+// Each round advances the fake clock to the receive deadline, so the finite
+// AcceptOfferBackoff of the puppet client runs out after a few rounds.
+func TestRequestingDiscover(t *testing.T) {
+	p := newPuppetClient(stateRequesting)
+
+	offer := newResponse(dhcpv4.MessageTypeOffer)
+	testIP := net.IP{192, 0, 2, 2}
+	offer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+	offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(10 * time.Second))
+	offer.YourIPAddr = testIP
+	p.c.offer = offer
+
+	for i := 0; i < 10; i++ {
+		p.bmt.sendPackets()
+		if err := p.c.runState(context.Background()); err != nil {
+			t.Fatal(err)
+		}
+		assert.Equal(t, dhcpv4.MessageTypeRequest, p.bmt.sentPacket.MessageType(), "Invalid message type for requesting")
+		if p.c.state == stateDiscovering {
+			break
+		}
+		p.ft.time = p.bmt.setDeadline
+
+		if i == 9 {
+			t.Fatal("Too many tries while requesting, backoff likely wrong")
+		}
+	}
+	assert.Equal(t, stateDiscovering, p.c.state, "DHCP client didn't switch back to discovering after requesting expired")
+}
+
+// TestDiscoverRapidCommit tests if the DHCP state machine in discovering state
+// transitions directly to bound if a rapid commit response (DHCPACK) is
+// received.
+func TestDiscoverRapidCommit(t *testing.T) {
+ testIP := net.IP{192, 0, 2, 2}
+ offer := newResponse(dhcpv4.MessageTypeAck)
+ offer.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 0, 2, 1}))
+ leaseTime := 10 * time.Second
+ offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+ offer.YourIPAddr = testIP
+
+ p := newPuppetClient(stateDiscovering)
+ p.c.LeaseCallback = func(lease *Lease) error {
+ assert.Equal(t, testIP, lease.AssignedIP, "callback called with wrong IP")
+ assert.Equal(t, p.ft.Now().Add(leaseTime), lease.ExpiresAt, "invalid ExpiresAt")
+ return nil
+ }
+ p.bmt.sendPackets(offer)
+ if err := p.c.runState(context.Background()); err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, stateBound, p.c.state, "DHCP client didn't process offer")
+ assert.Equal(t, testIP, p.c.lease.YourIPAddr, "DHCP client requested invalid offer")
+ assert.Equal(t, 5*time.Second, p.c.leaseBoundDeadline.Sub(p.ft.Now()), "Renewal time was incorrectly defaulted")
+}
+
+// TestOption is a DHCP option code in the private range (224 and up), used to
+// tag responses in tests.
+type TestOption uint8
+
+// Code maps the test option number into the private option range.
+func (o TestOption) Code() uint8 {
+	const privateBase = 224
+	return privateBase + uint8(o)
+}
+
+// String returns a human-readable name for the option.
+func (o TestOption) String() string {
+	return fmt.Sprintf("Test Option %d", uint8(o))
+}
+
+// TestBoundRenewingBound tests if the DHCP state machine in bound correctly
+// transitions to renewing after leaseBoundDeadline expires, sends a
+// DHCPREQUEST and after it gets a DHCPACK response calls LeaseCallback and
+// transitions back to bound with correct new deadlines.
+func TestBoundRenewingBound(t *testing.T) {
+	offer := newResponse(dhcpv4.MessageTypeAck)
+	testIP := net.IP{192, 0, 2, 2}
+	serverIP := net.IP{192, 0, 2, 1}
+	offer.UpdateOption(dhcpv4.OptServerIdentifier(serverIP))
+	leaseTime := 10 * time.Second
+	offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+	offer.YourIPAddr = testIP
+
+	p := newPuppetClient(stateBound)
+	if err := p.umt.Open(serverIP, testIP); err != nil {
+		t.Fatal(err)
+	}
+	initial, err := dhcpv4.FromBytes(offer.ToBytes())
+	if err != nil {
+		t.Fatal(err)
+	}
+	p.c.lease = initial
+	// Other deadlines are intentionally empty to make sure they aren't used
+	p.c.leaseRenewDeadline = p.ft.Now().Add(8500 * time.Millisecond)
+	p.c.leaseBoundDeadline = p.ft.Now().Add(5000 * time.Millisecond)
+
+	p.ft.Advance(5*time.Second - 5*time.Millisecond)
+	if err := p.c.runState(context.Background()); err != nil {
+		t.Fatal(err)
+	}
+	// We cannot intercept the real timer so we just advance the clock by the time slept
+	p.ft.Advance(5 * time.Millisecond)
+	assert.Equal(t, stateRenewing, p.c.state, "DHCP client not renewing")
+	offer.UpdateOption(dhcpv4.OptGeneric(TestOption(1), []byte{0x12}))
+	p.umt.sendPackets(offer)
+	p.c.LeaseCallback = func(lease *Lease) error {
+		assert.Equal(t, testIP, lease.AssignedIP, "callback called with wrong IP")
+		assert.Equal(t, p.ft.Now().Add(leaseTime), lease.ExpiresAt, "invalid ExpiresAt")
+		assert.Equal(t, []byte{0x12}, lease.Options.Get(TestOption(1)), "renewal didn't add new option")
+		return nil
+	}
+	if err := p.c.runState(context.Background()); err != nil {
+		t.Fatal(err)
+	}
+	assert.Equal(t, stateBound, p.c.state, "DHCP client didn't renew")
+	assert.Equal(t, p.ft.Now().Add(leaseTime), p.c.leaseDeadline, "lease deadline not updated")
+	assert.Equal(t, dhcpv4.MessageTypeRequest, p.umt.sentPacket.MessageType(), "Invalid message type for renewal")
+}
+
+// TestRenewingRebinding tests if the DHCP state machine in renewing state
+// correctly sends DHCPREQUESTs and transitions to the rebinding state when it
+// hasn't received a valid response until the deadline expires. It also checks
+// that the unicast transport ends up closed while the broadcast one is open.
+func TestRenewingRebinding(t *testing.T) {
+	offer := newResponse(dhcpv4.MessageTypeAck)
+	testIP := net.IP{192, 0, 2, 2}
+	serverIP := net.IP{192, 0, 2, 1}
+	offer.UpdateOption(dhcpv4.OptServerIdentifier(serverIP))
+	leaseTime := 10 * time.Second
+	offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+	offer.YourIPAddr = testIP
+
+	p := newPuppetClient(stateRenewing)
+	if err := p.umt.Open(serverIP, testIP); err != nil {
+		t.Fatal(err)
+	}
+	initial, err := dhcpv4.FromBytes(offer.ToBytes())
+	if err != nil {
+		t.Fatal(err)
+	}
+	p.c.lease = initial
+	// Other deadlines are intentionally empty to make sure they aren't used
+	p.c.leaseRenewDeadline = p.ft.Now().Add(8500 * time.Millisecond)
+	p.c.leaseDeadline = p.ft.Now().Add(10000 * time.Millisecond)
+
+	startTime := p.ft.Now()
+	p.ft.Advance(5 * time.Second)
+
+	p.c.LeaseCallback = func(*Lease) error {
+		t.Fatal("Lease callback called without valid offer")
+		return nil
+	}
+
+	for i := 0; i < 10; i++ {
+		p.umt.sendPackets()
+		if err := p.c.runState(context.Background()); err != nil {
+			t.Fatal(err)
+		}
+		assert.Equal(t, dhcpv4.MessageTypeRequest, p.umt.sentPacket.MessageType(), "Invalid message type for renewal")
+		p.ft.time = p.umt.setDeadline
+
+		if p.c.state == stateRebinding {
+			break
+		}
+		if i == 9 {
+			t.Fatal("Too many tries while renewing, backoff likely wrong")
+		}
+	}
+	assert.Equal(t, startTime.Add(8500*time.Millisecond), p.umt.setDeadline, "wrong listen deadline when renewing")
+	assert.Equal(t, stateRebinding, p.c.state, "DHCP client not rebinding")
+	assert.False(t, p.bmt.closed)
+	assert.True(t, p.umt.closed)
+}
+
+// TestRebindingBound tests if the DHCP state machine in rebinding state sends
+// DHCPREQUESTs to the network and if it receives a valid DHCPACK correctly
+// transitions back to bound state.
+func TestRebindingBound(t *testing.T) {
+	offer := newResponse(dhcpv4.MessageTypeAck)
+	testIP := net.IP{192, 0, 2, 2}
+	serverIP := net.IP{192, 0, 2, 1}
+	offer.UpdateOption(dhcpv4.OptServerIdentifier(serverIP))
+	leaseTime := 10 * time.Second
+	offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+	offer.YourIPAddr = testIP
+
+	p := newPuppetClient(stateRebinding)
+	initial, err := dhcpv4.FromBytes(offer.ToBytes())
+	if err != nil {
+		t.Fatal(err)
+	}
+	p.c.lease = initial
+	// Other deadlines are intentionally empty to make sure they aren't used
+	p.c.leaseDeadline = p.ft.Now().Add(10000 * time.Millisecond)
+
+	p.ft.Advance(9 * time.Second)
+	if err := p.c.runState(context.Background()); err != nil {
+		t.Fatal(err)
+	}
+	assert.Equal(t, dhcpv4.MessageTypeRequest, p.bmt.sentPacket.MessageType(), "DHCP rebind sent invalid message type")
+	assert.Equal(t, stateRebinding, p.c.state, "DHCP client transferred out of rebinding state without trigger")
+	offer.UpdateOption(dhcpv4.OptGeneric(TestOption(1), []byte{0x12})) // Mark answer
+	p.bmt.sendPackets(offer)
+	p.bmt.sentPacket = nil
+	p.c.LeaseCallback = func(lease *Lease) error {
+		assert.Equal(t, testIP, lease.AssignedIP, "callback called with wrong IP")
+		assert.Equal(t, p.ft.Now().Add(leaseTime), lease.ExpiresAt, "invalid ExpiresAt")
+		assert.Equal(t, []byte{0x12}, lease.Options.Get(TestOption(1)), "rebind didn't add new option")
+		return nil
+	}
+	if err := p.c.runState(context.Background()); err != nil {
+		t.Fatal(err)
+	}
+	assert.Equal(t, dhcpv4.MessageTypeRequest, p.bmt.sentPacket.MessageType())
+	assert.Equal(t, stateBound, p.c.state, "DHCP client didn't go back to bound")
+}
+
+// TestRebindingDiscovering tests if the DHCP state machine in rebinding state
+// transitions to discovering state if leaseDeadline expires and calls
+// LeaseCallback with an empty new lease.
+func TestRebindingDiscovering(t *testing.T) {
+	offer := newResponse(dhcpv4.MessageTypeAck)
+	testIP := net.IP{192, 0, 2, 2}
+	serverIP := net.IP{192, 0, 2, 1}
+	offer.UpdateOption(dhcpv4.OptServerIdentifier(serverIP))
+	leaseTime := 10 * time.Second
+	offer.UpdateOption(dhcpv4.OptIPAddressLeaseTime(leaseTime))
+	offer.YourIPAddr = testIP
+
+	p := newPuppetClient(stateRebinding)
+	initial, err := dhcpv4.FromBytes(offer.ToBytes())
+	if err != nil {
+		t.Fatal(err)
+	}
+	p.c.lease = initial
+	// Other deadlines are intentionally empty to make sure they aren't used
+	p.c.leaseDeadline = p.ft.Now().Add(10000 * time.Millisecond)
+
+	p.ft.Advance(9 * time.Second)
+	p.c.LeaseCallback = func(lease *Lease) error {
+		assert.Nil(t, lease, "transition to discovering didn't clear new lease on callback")
+		return nil
+	}
+	for i := 0; i < 10; i++ {
+		p.bmt.sendPackets()
+		p.bmt.sentPacket = nil
+		if err := p.c.runState(context.Background()); err != nil {
+			t.Fatal(err)
+		}
+		if p.c.state == stateDiscovering {
+			assert.Nil(t, p.bmt.sentPacket)
+			break
+		}
+		assert.Equal(t, dhcpv4.MessageTypeRequest, p.bmt.sentPacket.MessageType(), "Invalid message type for rebind")
+		p.ft.time = p.bmt.setDeadline
+		if i == 9 {
+			t.Fatal("Too many tries while rebinding, backoff likely wrong")
+		}
+	}
+	assert.Nil(t, p.c.lease, "Lease not zeroed on transition to discovering")
+	assert.Equal(t, stateDiscovering, p.c.state, "DHCP client didn't transition to discovering after losing lease")
+}
diff --git a/osbase/net/dhcp4c/doc.go b/osbase/net/dhcp4c/doc.go
new file mode 100644
index 0000000..622452b
--- /dev/null
+++ b/osbase/net/dhcp4c/doc.go
@@ -0,0 +1,55 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package dhcp4c provides a client implementation of the DHCPv4 protocol
+// (RFC2131) and a few extensions for Linux-based systems.
+// The code is split into three main parts:
+// - The core DHCP state machine, which lives in dhcpc.go
+// - Mechanisms to send and receive DHCP messages, which live in transport/
+// - Standard callbacks which implement necessary kernel configuration steps in
+// a simple and standalone way living in callback/
+//
+// Since the DHCP protocol is ugly and underspecified (see
+// https://tools.ietf.org/html/draft-ietf-dhc-implementation-02 for a subset of
+// known issues), this client slightly bends the specification in the following
+// cases:
+//   - IP fragmentation for DHCP messages is supported neither for sending nor
+//     for receiving messages. This is because the major servers (ISC,
+//     dnsmasq, ...) do not implement it and just drop fragmented packets, so
+//     it would be counterproductive to try to send them. The client just
+//     attempts to send the full message and hopes it passes through to the
+//     server.
+// - The suggested timeouts and wait periods have been tightened significantly.
+// When the standard was written 10Mbps Ethernet with hubs was a common
+// interconnect. Using these would make the client extremely slow on today's
+// 1Gbps+ networks.
+// - Wrong data in DHCP responses is fixed up if possible. This fixing includes
+// dropping prohibited options, clamping semantically invalid data and
+// defaulting not set options as far as it's possible. Non-recoverable
+// responses (for example because a non-Unicast IP is handed out or lease
+// time is not set or zero) are still ignored. All data which can be stored
+// in both DHCP fields and options is also normalized to the corresponding
+// option.
+// - Duplicate Address Detection is not implemented by default. It's slow, hard
+// to implement correctly and generally not necessary on modern networks as
+// the servers already waste time checking for duplicate addresses. It's
+// possible to hook it in via a LeaseCallback if necessary in a given
+// application.
+//
+// Operationally, there's one known caveat to using this client: If the lease
+// offered during the select phase (in a DHCPOFFER) is not the same as the one
+// sent in the following DHCPACK the first one might be acceptable, but the
+// second one might not be. This can cause pathological behavior where the
+// client constantly switches between discovering and requesting states.
+// Depending on the reuse policies on the DHCP server this can cause the client
+// to consume all available IP addresses. Sadly there's no good way of fixing
+// this within the boundaries of the protocol. A DHCPRELEASE for the address
+// would need to be unicast, so the unacceptable address would need to be
+// configured. That is either impossible (if the address is not valid) or not
+// acceptable from a security standpoint (for example because it overlaps with
+// a prefix used internally). A DHCPDECLINE, on the other hand, would cause the
+// server to blacklist the IP, thus also depleting the IP pool.
+// This could potentially be avoided by originating DHCPRELEASE packets from a
+// userspace transport, but said transport would need to be routing- and
+// PMTU-aware, which would make it even more complicated than the existing
+// BroadcastTransport.
+package dhcp4c
diff --git a/osbase/net/dhcp4c/lease.go b/osbase/net/dhcp4c/lease.go
new file mode 100644
index 0000000..fefe092
--- /dev/null
+++ b/osbase/net/dhcp4c/lease.go
@@ -0,0 +1,229 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+package dhcp4c
+
+import (
+ "encoding/binary"
+ "net"
+ "time"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+)
+
+// Lease represents a DHCPv4 lease. It only consists of an IP, an expiration
+// timestamp and options, as all other relevant parts of the message have been
+// normalized into their respective options. Its nil-safe methods are getters
+// for commonly-used options which extract only valid information from the
+// options.
+type Lease struct {
+ AssignedIP net.IP // the IP address assigned by this lease
+ ExpiresAt time.Time // when the lease expires
+ Options dhcpv4.Options // normalized DHCP options, see above
+}
+
+// SubnetMask returns the mask from the SubnetMask option. If that option is
+// absent or not a canonical 32-bit mask, the classful default mask of the
+// assigned IP is used instead. A nil lease yields a nil mask.
+func (l *Lease) SubnetMask() net.IPMask {
+	if l == nil {
+		return nil
+	}
+	fromOption := net.IPMask(dhcpv4.GetIP(dhcpv4.OptionSubnetMask, l.Options))
+	// Size reports 32 total bits only for a canonical 4-byte mask.
+	if _, totalBits := fromOption.Size(); totalBits == 32 {
+		return fromOption
+	}
+	return l.AssignedIP.DefaultMask()
+}
+
+// IPNet returns an IPNet from the assigned IP and subnet mask.
+// It returns nil if the lease is nil.
+func (l *Lease) IPNet() *net.IPNet {
+ if l == nil {
+ return nil
+ }
+ return &net.IPNet{
+ IP: l.AssignedIP,
+ Mask: l.SubnetMask(),
+ }
+}
+
+// Routes returns all routes assigned by a DHCP answer after filtering them
+// through sanitizeRoutes. It combines and normalizes data from the Router,
+// StaticRoutingTable and ClasslessStaticRoute options. A nil lease yields nil.
+func (l *Lease) Routes() []*dhcpv4.Route {
+ if l == nil {
+ return nil
+ }
+
+ // Note that this is different from l.IPNet() because we care about the
+ // network base address of the network instead of the assigned IP.
+ assignedNet := &net.IPNet{IP: l.AssignedIP.Mask(l.SubnetMask()), Mask: l.SubnetMask()}
+
+ // RFC 3442 Section DHCP Client Behavior:
+ // If the DHCP server returns both a Classless Static Routes option and
+ // a Router option, the DHCP client MUST ignore the Router option.
+ // Similarly, if the DHCP server returns both a Classless Static Routes
+ // option and a Static Routes option, the DHCP client MUST ignore the
+ // Static Routes option.
+ var routes dhcpv4.Routes
+ rawCIDRRoutes := l.Options.Get(dhcpv4.OptionClasslessStaticRoute)
+ if rawCIDRRoutes != nil {
+ // TODO(#96): This needs a logging story
+ // Ignore errors intentionally and just return what has been parsed
+ _ = routes.FromBytes(rawCIDRRoutes)
+ return sanitizeRoutes(routes, assignedNet)
+ }
+ // The Static Routes option contains legacy classful routes (i.e. routes
+ // whose mask is determined by the IP of the network).
+ // Each static route is expressed as a pair of IPs, the first one being
+ // the destination network and the second one being the router IP.
+ // See RFC 2132 Section 5.8 for further details.
+ legacyRouteIPs := dhcpv4.GetIPs(dhcpv4.OptionStaticRoutingTable, l.Options)
+ // Routes are only valid in pairs, cut the last one off if necessary
+ if len(legacyRouteIPs)%2 != 0 {
+ legacyRouteIPs = legacyRouteIPs[:len(legacyRouteIPs)-1]
+ }
+ for i := 0; i < len(legacyRouteIPs)/2; i++ {
+ dest := legacyRouteIPs[i*2]
+ if dest.IsUnspecified() {
+ // RFC 2132 Section 5.8:
+ // The default route (0.0.0.0) is an illegal destination for a
+ // static route.
+ continue
+ }
+ via := legacyRouteIPs[i*2+1]
+ destNet := net.IPNet{
+ // Apply the default mask just to make sure this is a valid route
+ IP: dest.Mask(dest.DefaultMask()),
+ Mask: dest.DefaultMask(),
+ }
+ routes = append(routes, &dhcpv4.Route{Dest: &destNet, Router: via})
+ }
+ for _, r := range dhcpv4.GetIPs(dhcpv4.OptionRouter, l.Options) {
+ if r.IsGlobalUnicast() || r.IsLinkLocalUnicast() { // first usable router becomes the default gateway
+ routes = append(routes, &dhcpv4.Route{
+ Dest: &net.IPNet{IP: net.IPv4zero, Mask: net.IPv4Mask(0, 0, 0, 0)},
+ Router: r,
+ })
+ // Only one default router can exist, exit after the first one
+ break
+ }
+ }
+ return sanitizeRoutes(routes, assignedNet)
+}
+
+// sanitizeRoutes filters the list of routes by removing routes that are
+// obviously invalid, preserving the relative order of the remaining ones.
+// A route is dropped if any of the following applies:
+//  1. It has a router (i.e. it is not an interface route) which is neither a
+//     global unicast nor a link-local unicast address.
+//  2. Its router is neither inside assignedNet nor inside the destination of
+//     a route accepted before it.
+//  3. Its network mask is not all-ones followed by all-zeros.
+//  4. Its destination is a loopback address (127.0.0.0/8) and its prefix is
+//     /8 or longer, as it could then redirect loopback traffic.
+//  5. It has the same destination and prefix length as assignedNet, which
+//     would shadow the implicit interface route.
+//  6. An earlier accepted route has the same destination and prefix length.
+func sanitizeRoutes(routes []*dhcpv4.Route, assignedNet *net.IPNet) []*dhcpv4.Route {
+	// The interface route's prefix length does not depend on the route being
+	// checked, so compute it once instead of once per route.
+	assignedOnes, _ := assignedNet.Mask.Size()
+	var saneRoutes []*dhcpv4.Route
+	for _, route := range routes {
+		// A nil or unspecified router marks an interface route, which needs no
+		// reachable next hop.
+		if route.Router != nil && !route.Router.IsUnspecified() {
+			if !route.Router.IsGlobalUnicast() && !route.Router.IsLinkLocalUnicast() {
+				// Ignore non-unicast routers
+				continue
+			}
+			reachable := assignedNet.Contains(route.Router)
+			for i := 0; !reachable && i < len(saneRoutes); i++ {
+				reachable = saneRoutes[i].Dest.Contains(route.Router)
+			}
+			if !reachable {
+				// Router is not reachable via any known route
+				continue
+			}
+		}
+		ones, bits := route.Dest.Mask.Size()
+		if bits == 0 && len(route.Dest.Mask) > 0 {
+			// Bitmask is not ones followed by zeros, i.e. invalid
+			continue
+		}
+		// Ignore routes that would be able to redirect loopback IPs
+		if route.Dest.IP.IsLoopback() && ones >= 8 {
+			continue
+		}
+		// Ignore routes that would shadow the implicit interface route
+		if assignedNet.IP.Equal(route.Dest.IP) && assignedOnes == ones {
+			continue
+		}
+		duplicate := false
+		for _, r := range saneRoutes {
+			rOnes, _ := r.Dest.Mask.Size()
+			if r.Dest.IP.Equal(route.Dest.IP) && ones == rOnes {
+				// Exact duplicate, ignore
+				duplicate = true
+				break
+			}
+		}
+		if !duplicate {
+			saneRoutes = append(saneRoutes, route)
+		}
+	}
+	return saneRoutes
+}
+
+// DNSServers is an ordered list of DNS server addresses.
+type DNSServers []net.IP
+
+// Equal reports whether a and b hold the same addresses in the same order,
+// comparing each pair with net.IP.Equal. Two empty lists (nil or not) are
+// considered equal.
+func (a DNSServers) Equal(b DNSServers) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		if !a[i].Equal(b[i]) {
+			// A single differing position is enough.
+			return false
+		}
+	}
+	return true
+}
+
+// ip4toInt packs an IPv4 address into a big-endian uint32; non-IPv4 input yields 0.
+func ip4toInt(ip net.IP) uint32 {
+	if v4 := ip.To4(); v4 != nil {
+		return binary.BigEndian.Uint32(v4)
+	}
+	return 0
+}
+
+// DNSServers returns the unique global or link-local unicast addresses from
+// the DomainNameServer option in first-seen order. A nil lease yields nil.
+func (l *Lease) DNSServers() DNSServers {
+	if l == nil {
+		return nil
+	}
+	var result DNSServers
+	seen := map[uint32]struct{}{}
+	for _, candidate := range dhcpv4.GetIPs(dhcpv4.OptionDomainNameServer, l.Options) {
+		if !candidate.IsGlobalUnicast() && !candidate.IsLinkLocalUnicast() {
+			continue
+		}
+		// Key on the numeric IPv4 value so 4- and 16-byte forms of one address match.
+		key := ip4toInt(candidate)
+		if _, dup := seen[key]; dup {
+			continue
+		}
+		seen[key] = struct{}{}
+		result = append(result, candidate)
+	}
+	return result
+}
diff --git a/osbase/net/dhcp4c/lease_test.go b/osbase/net/dhcp4c/lease_test.go
new file mode 100644
index 0000000..5d233df
--- /dev/null
+++ b/osbase/net/dhcp4c/lease_test.go
@@ -0,0 +1,156 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+package dhcp4c
+
+import (
+ "bytes"
+ "net"
+ "testing"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestLeaseDHCPServers(t *testing.T) {
+ var tests = []struct {
+ name string
+ lease *Lease
+ expected DNSServers
+ }{{
+ name: "ReturnsNilWithNoLease",
+ lease: nil,
+ expected: nil,
+ }, {
+ name: "DiscardsInvalidIPs",
+ lease: &Lease{
+ Options: dhcpv4.OptionsFromList(dhcpv4.OptDNS(net.IP{0, 0, 0, 0})),
+ },
+ expected: nil,
+ }, {
+ name: "DeduplicatesIPs",
+ lease: &Lease{
+ Options: dhcpv4.OptionsFromList(dhcpv4.OptDNS(net.IP{192, 0, 2, 1}, net.IP{192, 0, 2, 2}, net.IP{192, 0, 2, 1})),
+ },
+ expected: DNSServers{net.IP{192, 0, 2, 1}, net.IP{192, 0, 2, 2}},
+ }}
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ res := test.lease.DNSServers()
+ assert.Equal(t, test.expected, res)
+ })
+ }
+}
+
+func makeIPNet(cidr string) *net.IPNet {
+	_, network, err := net.ParseCIDR(cidr)
+	if err == nil {
+		return network
+	}
+	panic(err) // test helper: a malformed CIDR literal is a programmer error
+}
+
+type testRoute struct {
+ dest string // destination CIDR; must have no host bits set
+ via string // router IP, or "" for an interface route
+ expected bool // whether sanitizeRoutes should keep this route
+}
+
+// toRealRoute converts t into a dhcpv4.Route, panicking on malformed fields.
+func (t testRoute) toRealRoute() *dhcpv4.Route {
+	addr, dest, err := net.ParseCIDR(t.dest)
+	// The destination must already be a network address (no host bits set).
+	switch {
+	case err != nil:
+		panic(err)
+	case !addr.Equal(dest.IP):
+		panic("testRoute is not aligned to route boundary")
+	}
+	router := net.ParseIP(t.via)
+	if t.via != "" && router == nil {
+		panic("routerIP is not valid")
+	}
+	// An empty via leaves Router nil, i.e. an interface route.
+	return &dhcpv4.Route{Dest: dest, Router: router}
+}
+
+func TestSanitizeRoutes(t *testing.T) { // checks which routes survive sanitizeRoutes, and in which order
+ var tests = []struct {
+ name string
+ assignedNet *net.IPNet
+ routes []testRoute // input order matters: reachability only considers earlier routes
+ }{{
+ name: "SimpleAdditionalRoute",
+ assignedNet: makeIPNet("10.0.5.0/24"),
+ routes: []testRoute{
+ {"10.0.7.0/24", "10.0.5.1", true},
+ },
+ }, {
+ name: "OutOfNetworkGateway",
+ assignedNet: makeIPNet("10.5.0.2/32"),
+ routes: []testRoute{
+ {"10.0.7.1/32", "", true}, // interface route making 10.0.7.1 reachable
+ {"0.0.0.0/0", "10.0.7.1", true}, // router outside assignedNet, reachable via the route above
+ },
+ }, {
+ name: "InvalidRouter",
+ assignedNet: makeIPNet("10.0.5.0/24"),
+ routes: []testRoute{
+ // Router is localhost
+ {"10.0.7.0/24", "127.0.0.1", false},
+ // Router is unreachable
+ {"10.0.8.0/24", "10.254.0.1", false},
+ },
+ }, {
+ name: "SameDestinationRoutes",
+ assignedNet: makeIPNet("10.0.5.0/24"),
+ routes: []testRoute{
+ {"0.0.0.0/0", "10.0.5.1", true},
+ {"10.0.7.0/24", "10.0.5.1", true},
+ {"0.0.0.0/0", "10.0.7.1", false}, // same destination as the first route
+ },
+ }, {
+ name: "RoutesShadowingLoopback",
+ assignedNet: makeIPNet("10.0.5.0/24"),
+ routes: []testRoute{
+ // Default route, technically covers 127.0.0.0/8, but less-specific
+ {"0.0.0.0/0", "10.0.5.1", true},
+ // 127.0.0.0/8 is still more-specific
+ {"126.0.0.0/7", "10.0.5.1", true},
+ // Duplicate of 127.0.0.0/8, behavior undefined, disallowed
+ {"127.0.0.0/8", "10.0.5.1", false},
+ // Shadows localhost, disallowed
+ {"127.0.0.1/32", "10.0.5.1", false},
+ },
+ }}
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var routes []*dhcpv4.Route
+ var expectedRoutes []*dhcpv4.Route
+ for _, r := range test.routes {
+ routes = append(routes, r.toRealRoute())
+ if r.expected {
+ expectedRoutes = append(expectedRoutes, r.toRealRoute())
+ }
+ }
+ out := sanitizeRoutes(routes, test.assignedNet)
+ if len(out) != len(expectedRoutes) { // dump both lists, an index-wise diff would be meaningless
+ t.Errorf("expected %d routes, got %d", len(expectedRoutes), len(out))
+ t.Error("Expected:")
+ for _, r := range expectedRoutes {
+ t.Errorf("\t%s via %s", r.Dest, r.Router)
+ }
+ t.Error("Actual:")
+ for _, r := range out {
+ t.Errorf("\t%s via %s", r.Dest, r.Router)
+ }
+ return
+ }
+ for i, r := range expectedRoutes { // toRealRoute returns fresh pointers, so compare by value
+ if !out[i].Router.Equal(r.Router) || !out[i].Dest.IP.Equal(r.Dest.IP) || !bytes.Equal(out[i].Dest.Mask, r.Dest.Mask) {
+ t.Errorf("expected %s via %s, got %s via %s", r.Dest, r.Router, out[i].Dest, out[i].Router)
+ }
+ }
+ })
+ }
+}
diff --git a/osbase/net/dhcp4c/transport/BUILD.bazel b/osbase/net/dhcp4c/transport/BUILD.bazel
new file mode 100644
index 0000000..8ba0830
--- /dev/null
+++ b/osbase/net/dhcp4c/transport/BUILD.bazel
@@ -0,0 +1,20 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "transport",
+ srcs = [
+ "transport.go",
+ "transport_broadcast.go",
+ "transport_unicast.go",
+ ],
+ importpath = "source.monogon.dev/osbase/net/dhcp4c/transport",
+ visibility = ["//osbase/net/dhcp4c:__subpackages__"],  # only dhcp4c and packages below it
+ deps = [
+ "@com_github_google_gopacket//:gopacket",
+ "@com_github_google_gopacket//layers",
+ "@com_github_insomniacslk_dhcp//dhcpv4",
+ "@com_github_mdlayher_packet//:packet",
+ "@org_golang_x_net//bpf",
+ "@org_golang_x_sys//unix",
+ ],
+)
diff --git a/osbase/net/dhcp4c/transport/transport.go b/osbase/net/dhcp4c/transport/transport.go
new file mode 100644
index 0000000..9a5ff14
--- /dev/null
+++ b/osbase/net/dhcp4c/transport/transport.go
@@ -0,0 +1,38 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package transport contains Linux-based transports for the DHCP broadcast and
+// unicast specifications.
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "net"
+)
+
+// ErrDeadlineExceeded is returned by Receive once the receive deadline passed.
+var ErrDeadlineExceeded = errors.New("deadline exceeded")
+
+// NewInvalidMessageError wraps internalErr in an *InvalidMessageError.
+func NewInvalidMessageError(internalErr error) error {
+	return &InvalidMessageError{internalErr: internalErr}
+}
+
+// InvalidMessageError reports a received packet that was rejected as invalid.
+type InvalidMessageError struct{ internalErr error }
+
+// Error formats internalErr via %v, which also copes with a nil error.
+func (i InvalidMessageError) Error() string {
+	return fmt.Sprintf("received invalid packet: %v", i.internalErr)
+}
+
+func (i InvalidMessageError) Unwrap() error { return i.internalErr } // for errors.Is/As
+
+func deadlineFromTimeout(err error) error {
+ var timeoutErr net.Error
+ if errors.As(err, &timeoutErr) && timeoutErr.Timeout() {
+ return ErrDeadlineExceeded
+ }
+ return err
+}
diff --git a/osbase/net/dhcp4c/transport/transport_broadcast.go b/osbase/net/dhcp4c/transport/transport_broadcast.go
new file mode 100644
index 0000000..b61af80
--- /dev/null
+++ b/osbase/net/dhcp4c/transport/transport_broadcast.go
@@ -0,0 +1,199 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "net"
+ "time"
+
+ "github.com/google/gopacket"
+ "github.com/google/gopacket/layers"
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/mdlayher/packet"
+ "golang.org/x/net/bpf"
+)
+
+const (
+	// dscpCS7 is DSCP Class Selector 7 (RFC 2474 Section 4.2.2.1), which
+	// corresponds to RFC 791 Section 3.1 precedence 7 (Network Control).
+	dscpCS7 = 0b111 << 3
+	// maxIPv4MTU is the largest IPv4 packet, as the "Total Length" header
+	// field is an unsigned 16 bit integer.
+	maxIPv4MTU = math.MaxUint16
+)
+
+// mustAssemble is bpf.Assemble for static programs: it panics on failure.
+func mustAssemble(insns []bpf.Instruction) []bpf.RawInstruction {
+	raw, err := bpf.Assemble(insns)
+	if err == nil {
+		return raw
+	}
+	panic("mustAssemble failed to assemble BPF: " + err.Error())
+}
+
+// bpfFilterInstructions accepts only unfragmented IPv4 UDP packets addressed
+// to port 68 (DHCP client) and drops everything else. Offsets are relative to
+// the IPv4 header, as the socket is opened in packet.Datagram mode. Filtering
+// in the kernel keeps unrelated traffic on the link from overwhelming the
+// single-threaded receiver.
+var bpfFilterInstructions = []bpf.Instruction{
+ // Check IP protocol version equals 4 (first 4 bits of the first byte)
+ // With Ethernet II framing, this is more of a sanity check. We already
+ // request the kernel to only return EtherType 0x0800 (IPv4) frames.
+ bpf.LoadAbsolute{Off: 0, Size: 1},
+ bpf.ALUOpConstant{Op: bpf.ALUOpAnd, Val: 0xf0}, // Keep only the version nibble
+ bpf.JumpIf{Cond: bpf.JumpEqual, Val: 4 << 4, SkipTrue: 1},
+ bpf.RetConstant{Val: 0}, // Discard
+
+ // Check IPv4 Protocol byte (offset 9) equals UDP
+ bpf.LoadAbsolute{Off: 9, Size: 1},
+ bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(layers.IPProtocolUDP), SkipTrue: 1},
+ bpf.RetConstant{Val: 0}, // Discard
+
+ // Check that the 13-bit fragment offset is zero (drops non-first fragments)
+ bpf.LoadAbsolute{Off: 6, Size: 2}, // Flags and fragment offset
+ bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x1fff, SkipFalse: 1}, // 0x1fff masks the fragment offset
+ bpf.RetConstant{Val: 0}, // Discard
+
+ // X = 4 * IHL (low nibble of byte 0), i.e. the IPv4 header length in bytes
+ bpf.LoadMemShift{Off: 0},
+
+ // Check if UDP header destination port equals 68
+ bpf.LoadIndirect{Off: 2, Size: 2}, // Offset relative to header size in register X
+ bpf.JumpIf{Cond: bpf.JumpEqual, Val: 68, SkipTrue: 1},
+ bpf.RetConstant{Val: 0}, // Discard
+
+ // Accept the packet, passing up to the maximum IPv4 packet size
+ bpf.RetConstant{Val: maxIPv4MTU},
+}
+
+var bpfFilter = mustAssemble(bpfFilterInstructions) // assembled once at package initialization
+
+// BroadcastTransport implements a DHCP transport with its own IPv4/UDP stack on
+// a packet socket, meeting broadcast DHCP needs (all-zero source, no ARP, ...).
+type BroadcastTransport struct {
+ rawConn *packet.Conn // nil until Open succeeds and after a successful Close
+ iface *net.Interface // interface passed to packet.Listen in Open
+}
+
+// NewBroadcastTransport returns a closed transport for iface; call Open first.
+func NewBroadcastTransport(iface *net.Interface) *BroadcastTransport {
+ return &BroadcastTransport{iface: iface}
+}
+
+// Open opens an IPv4 packet socket on the interface, filtered by bpfFilter.
+func (t *BroadcastTransport) Open() error {
+	if t.rawConn != nil {
+		return errors.New("broadcast transport already open")
+	}
+	cfg := &packet.Config{Filter: bpfFilter}
+	conn, err := packet.Listen(t.iface, packet.Datagram, int(layers.EthernetTypeIPv4), cfg)
+	if err != nil {
+		return fmt.Errorf("failed to create raw listener: %w", err)
+	}
+	t.rawConn = conn
+	return nil
+}
+
+// Send wraps payload in UDP (port 68 to 67) and IPv4 (0.0.0.0 to
+// 255.255.255.255) headers and broadcasts the result on the link.
+func (t *BroadcastTransport) Send(payload *dhcpv4.DHCPv4) error {
+	if t.rawConn == nil {
+		return errors.New("broadcast transport closed")
+	}
+
+	ip := &layers.IPv4{
+		Version:  4,
+		Protocol: layers.IPProtocolUDP,
+		SrcIP:    net.IPv4zero,
+		DstIP:    net.IPv4bcast,
+		// The DSCP sits above the two ECN bits of the TOS byte.
+		TOS: dscpCS7 << 2,
+		// These packets must never be routed, their IP headers carry
+		// placeholder addresses.
+		TTL: 1,
+		// Most DHCP servers don't support fragmented packets.
+		Flags: layers.IPv4DontFragment,
+	}
+	udp := &layers.UDP{SrcPort: 68, DstPort: 67}
+	// The UDP checksum covers an IPv4 pseudo-header, so the UDP layer needs
+	// to know its network layer. This only fails for non-IP layers.
+	if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
+		panic("Invalid layer stackup encountered")
+	}
+
+	buf := gopacket.NewSerializeBuffer()
+	serializeOpts := gopacket.SerializeOptions{
+		// Have gopacket fill in the length and checksum fields.
+		FixLengths:       true,
+		ComputeChecksums: true,
+	}
+	err := gopacket.SerializeLayers(buf, serializeOpts,
+		ip, udp, gopacket.Payload(payload.ToBytes()))
+	if err != nil {
+		return fmt.Errorf("failed to assemble packet: %w", err)
+	}
+
+	// No ARP is done, so the frame goes to the Ethernet broadcast address.
+	dst := &packet.Addr{HardwareAddr: layers.EthernetBroadcast}
+	if _, err := t.rawConn.WriteTo(buf.Bytes(), dst); err != nil {
+		return fmt.Errorf("failed to transmit broadcast packet: %w", err)
+	}
+	return nil
+}
+
+// Receive reads the next packet that passed bpfFilter and decodes its DHCPv4
+// payload. Packets failing validation yield an *InvalidMessageError.
+func (t *BroadcastTransport) Receive() (*dhcpv4.DHCPv4, error) {
+	if t.rawConn == nil {
+		return nil, errors.New("broadcast transport closed")
+	}
+	frame := make([]byte, math.MaxUint16) // Maximum IP packet size
+	size, _, err := t.rawConn.ReadFrom(frame)
+	if err != nil {
+		return nil, deadlineFromTimeout(err)
+	}
+	decoded := gopacket.NewPacket(frame[:size], layers.LayerTypeIPv4, gopacket.Default)
+	ip, isIPv4 := decoded.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
+	if !isIPv4 {
+		return nil, NewInvalidMessageError(errors.New("got invalid IP packet"))
+	}
+	// Fragments are never reassembled, so a first fragment is useless too.
+	if ip.Flags&layers.IPv4MoreFragments != 0 {
+		return nil, NewInvalidMessageError(errors.New("got fragmented message"))
+	}
+	udp, isUDP := decoded.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	if !isUDP {
+		return nil, NewInvalidMessageError(errors.New("got non-UDP packet"))
+	}
+	if udp.DstPort != 68 {
+		return nil, NewInvalidMessageError(errors.New("message not for DHCP client port"))
+	}
+	msg, err := dhcpv4.FromBytes(udp.Payload)
+	if err != nil {
+		return nil, NewInvalidMessageError(fmt.Errorf("failed to decode DHCPv4 message: %w", err))
+	}
+	return msg, nil
+}
+
+// Close releases the socket. The transport counts as closed afterwards even on
+// error, as the conn must not be reused after Close and would block Open.
+func (t *BroadcastTransport) Close() error {
+	if t.rawConn == nil {
+		return nil
+	}
+	err := t.rawConn.Close()
+	t.rawConn = nil
+	return err
+}
+
+// SetReceiveDeadline sets the socket read deadline that bounds Receive.
+func (t *BroadcastTransport) SetReceiveDeadline(deadline time.Time) error {
+	if t.rawConn == nil {
+		return errors.New("broadcast transport closed")
+	}
+	return t.rawConn.SetReadDeadline(deadline)
+}
diff --git a/osbase/net/dhcp4c/transport/transport_unicast.go b/osbase/net/dhcp4c/transport/transport_unicast.go
new file mode 100644
index 0000000..b76e37c
--- /dev/null
+++ b/osbase/net/dhcp4c/transport/transport_unicast.go
@@ -0,0 +1,108 @@
+// Copyright The Monogon Project Authors.
+// SPDX-License-Identifier: Apache-2.0
+
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "net"
+ "os"
+ "time"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "golang.org/x/sys/unix"
+)
+
+// UnicastTransport is a DHCP transport on top of a regular Linux UDP socket,
+// tuned with socket options for DSCP marking and interface binding.
+type UnicastTransport struct {
+	udpConn  *net.UDPConn   // nil until Open succeeds and after Close
+	targetIP net.IP         // DHCP server that Send addresses
+	iface    *net.Interface // interface the socket is bound to in Open
+}
+
+// NewUnicastTransport returns a closed transport for iface. Call Open to
+// pick the server and local address before sending or receiving.
+func NewUnicastTransport(iface *net.Interface) *UnicastTransport {
+	return &UnicastTransport{iface: iface}
+}
+
+// Open binds a UDP socket to bindIP:68 on t.iface for talking to serverIP:67.
+func (t *UnicastTransport) Open(serverIP, bindIP net.IP) error {
+	if t.udpConn != nil {
+		return errors.New("unicast transport already open")
+	}
+	rawFd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
+	if err != nil {
+		return fmt.Errorf("failed to get socket: %w", err)
+	}
+	filePtr := os.NewFile(uintptr(rawFd), "dhcp-udp")
+	defer filePtr.Close() // closes rawFd on every return; net.FileConn uses a dup
+	if err := unix.BindToDevice(rawFd, t.iface.Name); err != nil {
+		return fmt.Errorf("failed to bind UDP interface to device: %w", err)
+	}
+	if err := unix.SetsockoptByte(rawFd, unix.SOL_IP, unix.IP_TOS, dscpCS7<<2); err != nil {
+		return fmt.Errorf("failed to set DSCP CS7: %w", err)
+	}
+	var addr [4]byte
+	copy(addr[:], bindIP.To4())
+	if err := unix.Bind(rawFd, &unix.SockaddrInet4{Addr: addr, Port: 68}); err != nil {
+		return fmt.Errorf("failed to bind UDP unicast interface: %w", err)
+	}
+	conn, err := net.FileConn(filePtr)
+	if err != nil {
+		return fmt.Errorf("failed to initialize runtime-supported UDP connection: %w", err)
+	}
+	realConn, ok := conn.(*net.UDPConn)
+	if !ok {
+		panic("UDP socket imported into Go runtime is no longer a UDP socket")
+	}
+	t.udpConn, t.targetIP = realConn, serverIP
+	return nil
+}
+
+// Send transmits payload as a single UDP datagram to port 67 of the server
+// selected in Open. Write errors are returned unwrapped.
+func (t *UnicastTransport) Send(payload *dhcpv4.DHCPv4) error {
+	if t.udpConn == nil {
+		return errors.New("unicast transport closed")
+	}
+	server := &net.UDPAddr{IP: t.targetIP, Port: 67}
+	_, err := t.udpConn.WriteToUDP(payload.ToBytes(), server)
+	return err
+}
+
+// SetReceiveDeadline sets the point in time after which pending and future
+// Receive calls fail with ErrDeadlineExceeded. Like the other methods, it
+// reports an error instead of panicking when the transport is closed.
+func (t *UnicastTransport) SetReceiveDeadline(deadline time.Time) error {
+	if t.udpConn == nil {
+		return errors.New("unicast transport closed")
+	}
+	return t.udpConn.SetReadDeadline(deadline)
+}
+
+// Receive reads a single datagram and decodes it as a DHCPv4 message. Read
+// timeouts surface as ErrDeadlineExceeded, undecodable datagrams as an
+// *InvalidMessageError.
+func (t *UnicastTransport) Receive() (*dhcpv4.DHCPv4, error) {
+	if t.udpConn == nil {
+		return nil, errors.New("unicast transport closed")
+	}
+	receiveBuf := make([]byte, math.MaxUint16)
+	n, _, err := t.udpConn.ReadFromUDP(receiveBuf)
+	if err != nil {
+		return nil, deadlineFromTimeout(err)
+	}
+	// Only decode the n bytes of the datagram itself; the rest of the buffer
+	// is zero padding that does not belong to the message.
+	msg, err := dhcpv4.FromBytes(receiveBuf[:n])
+	if err != nil {
+		return nil, NewInvalidMessageError(err)
+	}
+	return msg, nil
+}
+
+// Close closes the socket; an already closed transport or socket is no error.
+func (t *UnicastTransport) Close() error {
+	if t.udpConn == nil {
+		return nil
+	}
+	err := t.udpConn.Close()
+	t.udpConn = nil
+	if errors.Is(err, net.ErrClosed) {
+		err = nil // the socket was already closed; nothing left to report
+	}
+	return err
+}