blob: b61af8007e5bbf1d790dc2ce4a339d9851d969af [file] [log] [blame]
// Copyright The Monogon Project Authors.
// SPDX-License-Identifier: Apache-2.0
package transport
import (
"errors"
"fmt"
"math"
"net"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/packet"
"golang.org/x/net/bpf"
)
const (
	// dscpCS7 is the Class Selector 7 DSCP codepoint (binary 111000), i.e.
	// the "Network Control" precedence of RFC 791 Section 3.1 as mapped by
	// RFC 2474 Section 4.2.2.1. This is the bare 6-bit DSCP value; it still
	// has to be shifted left past the two ECN bits to form a TOS byte.
	dscpCS7 = 7 << 3

	// maxIPv4MTU is the largest possible IPv4 packet size in bytes. The IPv4
	// "Total Length" header field is an unsigned 16 bit integer, so no IPv4
	// packet can be larger than this, independent of the link MTU.
	maxIPv4MTU = math.MaxUint16
)
// mustAssemble assembles insns into raw BPF instructions via bpf.Assemble.
// It panics instead of returning an error, so it is only suitable for
// programs that are fixed at compile time, where a failure is a bug.
func mustAssemble(insns []bpf.Instruction) []bpf.RawInstruction {
	raw, err := bpf.Assemble(insns)
	if err == nil {
		return raw
	}
	panic("mustAssemble failed to assemble BPF: " + err.Error())
}
// bpfFilterInstructions is a classic BPF program which accepts only IPv4 UDP
// packets with destination port 68 (DHCP Client) that are not non-first
// fragments, and discards everything else.
//
// This is used to make the kernel drop non-DHCP traffic for us so that we
// don't have to handle excessive unrelated traffic flowing on a given link
// which might overwhelm the single-threaded receiver.
//
// The filter is attached to a packet.Datagram socket, so all absolute offsets
// below are relative to the start of the IPv4 header (no link-layer header).
// NOTE(review): this assumption breaks if the socket type is ever changed to
// a raw (link-layer) socket.
var bpfFilterInstructions = []bpf.Instruction{
// Check IP protocol version equals 4 (upper 4 bits of the first byte).
// With Ethernet II framing, this is more of a sanity check. We already
// request the kernel to only return EtherType 0x0800 (IPv4) frames.
bpf.LoadAbsolute{Off: 0, Size: 1},
bpf.ALUOpConstant{Op: bpf.ALUOpAnd, Val: 0xf0}, // Mask off the IHL (lower 4 bits), keeping only the version
bpf.JumpIf{Cond: bpf.JumpEqual, Val: 4 << 4, SkipTrue: 1}, // Version 4: skip over the discard
bpf.RetConstant{Val: 0}, // Discard
// Check IPv4 Protocol byte (offset 9) equals UDP
bpf.LoadAbsolute{Off: 9, Size: 1},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(layers.IPProtocolUDP), SkipTrue: 1}, // UDP: skip over the discard
bpf.RetConstant{Val: 0}, // Discard
// Check if the IPv4 fragment offset (lower 13 bits of the 16-bit
// flags/fragment-offset word at offset 6) is all-zero. Fragments with a
// non-zero offset carry no UDP header and are dropped. First fragments
// (offset zero, More Fragments flag set) are NOT rejected by this check.
bpf.LoadAbsolute{Off: 6, Size: 2},
bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x1fff, SkipFalse: 1}, // No offset bits set: skip over the discard
bpf.RetConstant{Val: 0}, // Discard
// Load IPv4 header size in bytes (4 * IHL of the byte at offset zero) into X
bpf.LoadMemShift{Off: 0},
// Check if UDP header destination port (bytes 2-3 of the UDP header) equals 68
bpf.LoadIndirect{Off: 2, Size: 2}, // Offset relative to header size in register X
bpf.JumpIf{Cond: bpf.JumpEqual, Val: 68, SkipTrue: 1}, // Port 68: skip over the discard
bpf.RetConstant{Val: 0}, // Discard
// Accept packet, passing through up to the maximum IP packet size
bpf.RetConstant{Val: maxIPv4MTU},
}
// bpfFilter is the assembled form of bpfFilterInstructions. It is built once
// during package initialization; mustAssemble panics if the program is invalid.
var bpfFilter = mustAssemble(bpfFilterInstructions)
// BroadcastTransport is a DHCP transport that sends and receives over a raw
// packet socket (github.com/mdlayher/packet) and builds the IPv4/UDP headers
// itself. This is needed for broadcast DHCP traffic, which has requirements
// a regular UDP socket can't meet (all-zero source address, no ARP, ...).
type BroadcastTransport struct {
	// rawConn is the open packet socket, or nil while the transport is closed.
	rawConn *packet.Conn
	// iface is the network interface this transport operates on.
	iface *net.Interface
}

// NewBroadcastTransport returns a transport bound to iface in the closed
// state. Open must be called before Send or Receive can be used.
func NewBroadcastTransport(iface *net.Interface) *BroadcastTransport {
	t := &BroadcastTransport{}
	t.iface = iface
	return t
}
// Open creates the underlying packet socket on the transport's interface.
// The socket only receives IPv4 frames, further narrowed down in-kernel by
// bpfFilter to UDP packets addressed to the DHCP client port. Opening an
// already open transport returns an error.
func (t *BroadcastTransport) Open() error {
	if t.rawConn != nil {
		return errors.New("broadcast transport already open")
	}
	cfg := &packet.Config{Filter: bpfFilter}
	conn, err := packet.Listen(t.iface, packet.Datagram, int(layers.EthernetTypeIPv4), cfg)
	if err != nil {
		return fmt.Errorf("failed to create raw listener: %w", err)
	}
	t.rawConn = conn
	return nil
}
// Send broadcasts payload on the transport's interface. The serialized DHCP
// message is wrapped in a UDP datagram from port 68 to port 67, inside an
// IPv4 packet from 0.0.0.0 to 255.255.255.255, and written to the Ethernet
// broadcast address. An error is returned if the transport is closed or if
// assembling or transmitting the packet fails.
func (t *BroadcastTransport) Send(payload *dhcpv4.DHCPv4) error {
	if t.rawConn == nil {
		return errors.New("broadcast transport closed")
	}

	ip := &layers.IPv4{
		Version: 4,
		// DSCP occupies the upper six bits of the TOS byte, above the two
		// ECN bits.
		TOS: dscpCS7 << 2,
		// These packets should never be routed; their IP headers are not
		// meaningful beyond the local link.
		TTL:      1,
		Protocol: layers.IPProtocolUDP,
		// Most DHCP servers don't support fragmented packets.
		Flags: layers.IPv4DontFragment,
		SrcIP: net.IPv4zero,
		DstIP: net.IPv4bcast,
	}
	udp := &layers.UDP{
		SrcPort: 68,
		DstPort: 67,
	}
	// The UDP checksum covers a pseudo-header derived from the IP layer, so
	// the UDP layer must know which network layer it is stacked on.
	if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
		panic("Invalid layer stackup encountered")
	}

	buf := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		ComputeChecksums: true, // fill in the IPv4 header and UDP checksums
		FixLengths:       true, // fill in length fields that depend on the payload
	}
	if err := gopacket.SerializeLayers(buf, opts, ip, udp, gopacket.Payload(payload.ToBytes())); err != nil {
		return fmt.Errorf("failed to assemble packet: %w", err)
	}

	dst := &packet.Addr{HardwareAddr: layers.EthernetBroadcast}
	if _, err := t.rawConn.WriteTo(buf.Bytes(), dst); err != nil {
		return fmt.Errorf("failed to transmit broadcast packet: %w", err)
	}
	return nil
}
// Receive reads one packet from the transport and decodes it as a DHCPv4
// message. Read failures are passed through deadlineFromTimeout. Packets that
// are not a complete IPv4/UDP datagram addressed to port 68 carrying a
// decodable DHCPv4 message are reported via NewInvalidMessageError.
func (t *BroadcastTransport) Receive() (*dhcpv4.DHCPv4, error) {
	if t.rawConn == nil {
		return nil, errors.New("broadcast transport closed")
	}

	// Large enough for the biggest possible IPv4 packet.
	frame := make([]byte, math.MaxUint16)
	n, _, err := t.rawConn.ReadFrom(frame)
	if err != nil {
		return nil, deadlineFromTimeout(err)
	}

	// The socket is opened in datagram mode, so the data starts directly at
	// the IPv4 header.
	decoded := gopacket.NewPacket(frame[:n], layers.LayerTypeIPv4, gopacket.Default)

	ip, ok := decoded.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
	if !ok {
		return nil, NewInvalidMessageError(errors.New("got invalid IP packet"))
	}
	// The BPF filter only drops fragments with a non-zero offset; first
	// fragments still arrive here and are rejected now.
	if ip.Flags&layers.IPv4MoreFragments != 0 {
		return nil, NewInvalidMessageError(errors.New("got fragmented message"))
	}

	udp, ok := decoded.Layer(layers.LayerTypeUDP).(*layers.UDP)
	if !ok {
		return nil, NewInvalidMessageError(errors.New("got non-UDP packet"))
	}
	if udp.DstPort != 68 {
		return nil, NewInvalidMessageError(errors.New("message not for DHCP client port"))
	}

	msg, err := dhcpv4.FromBytes(udp.Payload)
	if err != nil {
		return nil, NewInvalidMessageError(fmt.Errorf("failed to decode DHCPv4 message: %w", err))
	}
	return msg, nil
}
// Close closes the underlying packet socket. Closing an already closed
// transport is a no-op that returns nil. After Close the transport can be
// opened again with Open.
//
// The socket reference is dropped before closing, even if Close reports an
// error. A closed socket is unusable either way. Keeping the reference
// would leave the transport stuck: Open would report "already open" while
// Send, Receive and further Close calls operated on a dead socket.
func (t *BroadcastTransport) Close() error {
	conn := t.rawConn
	if conn == nil {
		return nil
	}
	t.rawConn = nil
	return conn.Close()
}
// SetReceiveDeadline sets the read deadline of the underlying socket. This
// bounds how long Receive waits for a packet. An error is returned if the
// transport is closed.
func (t *BroadcastTransport) SetReceiveDeadline(deadline time.Time) error {
	conn := t.rawConn
	if conn == nil {
		return errors.New("broadcast transport closed")
	}
	return conn.SetReadDeadline(deadline)
}