blob: 60e75f0f26fd3d8425de368cf648dbdafa7f28bb [file] [log] [blame] [edit]
package kmod
import (
"errors"
"fmt"
"log"
"sort"
"strings"
kmodpb "source.monogon.dev/metropolis/pkg/kmod/spec"
)
// LookupModules returns every module with at least one device-match pattern
// that matches the given modalias device identifier. The result is ordered by
// module name and is nil when nothing matches.
func LookupModules(meta *kmodpb.Meta, modalias string) (mods []*kmodpb.Module) {
	hits := make(map[uint32]bool)
	lookupModulesRec(meta.ModuleDeviceMatches, modalias, hits)
	for moduleIdx := range hits {
		mods = append(mods, meta.Modules[moduleIdx])
	}
	// Map iteration order is random; order the result by name instead.
	sort.Slice(mods, func(a, b int) bool {
		return mods[a].Name < mods[b].Name
	})
	return mods
}
// lookupModulesRec matches needle against the subtree below n. Each child
// consumes a prefix of needle according to its type, and the walk continues
// with whatever input is left. The module indices of every node reached with
// the needle fully consumed are added to matches. All viable branches are
// explored, so one needle can match several patterns.
func lookupModulesRec(n *kmodpb.RadixNode, needle string, matches map[uint32]bool) {
	if needle == "" {
		// The whole input has been consumed on the way to n, so every
		// pattern terminating at n matches.
		for _, idx := range n.ModuleIndex {
			matches[idx] = true
		}
	}
	for _, child := range n.Children {
		switch child.Type {
		case kmodpb.RadixNode_LITERAL:
			if strings.HasPrefix(needle, child.Literal) {
				lookupModulesRec(child, needle[len(child.Literal):], matches)
			}
		case kmodpb.RadixNode_WILDCARD:
			// '*' may absorb any number of bytes, including none: try every
			// suffix of needle, down to and including the empty string.
			for rest := needle; ; rest = rest[1:] {
				lookupModulesRec(child, rest, matches)
				if rest == "" {
					break
				}
			}
		case kmodpb.RadixNode_SINGLE_WILDCARD:
			// '?' needs exactly one byte of input.
			if needle != "" {
				lookupModulesRec(child, needle[1:], matches)
			}
		case kmodpb.RadixNode_BYTE_RANGE:
			if needle == "" {
				continue
			}
			if b := needle[0]; byte(child.StartByte) <= b && b <= byte(child.EndByte) {
				lookupModulesRec(child, needle[1:], matches)
			}
		}
	}
}
// AddPattern inserts pattern into the radix tree rooted at root. After the
// insertion, looking up any string the pattern matches yields moduleIndex.
// It returns an error if the pattern cannot be parsed.
func AddPattern(root *kmodpb.RadixNode, pattern string, moduleIndex uint32) error {
	parts, err := parsePattern(pattern)
	if err != nil {
		return fmt.Errorf("error parsing pattern %q: %w", pattern, err)
	}
	if len(parts) == 0 {
		// An empty pattern terminates at the root itself. Such patterns are
		// of little practical use but well-defined. Handling them keeps this
		// path from crashing and keeps the fuzzer happy.
		root.ModuleIndex = append(root.ModuleIndex, moduleIndex)
		return nil
	}
	// Only the terminal node of a pattern carries its module index.
	parts[len(parts)-1].ModuleIndex = []uint32{moduleIndex}
	return addPatternRec(root, parts, nil)
}
// addPatternRec recursively adds a new pattern to the radix tree.
// If currPartOverride is non-nil it is used instead of the first part in the
// parts array.
//
// n is the node below which the pattern continues. parts is the remaining
// sequence of pattern nodes as produced by parsePattern. AddPattern sets the
// ModuleIndex on the final part only. Nodes taken from parts may be linked
// into the tree as-is (see the append at the end and the split below), so
// the caller must not reuse them.
//
// currPartOverride is set when an existing literal node already covered a
// prefix of parts[0]'s literal and only the remainder still has to be
// inserted. In that case parts[0] stays in parts, so that parts[1:] still
// names the parts that follow.
//
// Literal children of a node are kept with distinct first bytes; the split
// below maintains this. Therefore at most one literal child can share a
// prefix with the part being inserted.
func addPatternRec(n *kmodpb.RadixNode, parts []*kmodpb.RadixNode, currPartOverride *kmodpb.RadixNode) error {
	if len(parts) == 0 {
		return nil
	}
	var currPart *kmodpb.RadixNode
	if currPartOverride != nil {
		currPart = currPartOverride
	} else {
		currPart = parts[0]
	}
	for _, c := range n.Children {
		// Only a child of the same kind can be shared with currPart.
		if c.Type != currPart.Type {
			continue
		}
		switch c.Type {
		case kmodpb.RadixNode_LITERAL:
			// NOTE(review): assumes literals are never empty. parsePattern
			// only emits non-empty ones; an empty Literal would panic here.
			if c.Literal[0] == currPart.Literal[0] {
				// Compute i, the length of the common prefix of both literals.
				var i int
				for i < len(c.Literal) && i < len(currPart.Literal) && c.Literal[i] == currPart.Literal[i] {
					i++
				}
				// Identical literals: reuse c. If this is the last part, the
				// pattern ends at c; otherwise continue with the next part.
				if i == len(c.Literal) && i == len(currPart.Literal) {
					if len(parts) == 1 {
						c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
						return nil
					}
					return addPatternRec(c, parts[1:], nil)
				}
				// c.Literal is a proper prefix of currPart.Literal: descend
				// into c and insert only the unmatched remainder of the literal.
				if i == len(c.Literal) {
					return addPatternRec(c, parts, &kmodpb.RadixNode{Type: kmodpb.RadixNode_LITERAL, Literal: currPart.Literal[i:], ModuleIndex: currPart.ModuleIndex})
				}
				// Split current node: c keeps only the common prefix. Its old
				// suffix, children and module indices move into a new child.
				splitOldPart := &kmodpb.RadixNode{
					Type:        kmodpb.RadixNode_LITERAL,
					Literal:     c.Literal[i:],
					Children:    c.Children,
					ModuleIndex: c.ModuleIndex,
				}
				var splitNewPart *kmodpb.RadixNode
				// currPart.Literal is a proper prefix of c.Literal.
				if i == len(currPart.Literal) {
					// currPart is the last part, so the new pattern ends
					// exactly at the shortened c.
					if len(parts) < 2 {
						c.Children = []*kmodpb.RadixNode{splitOldPart}
						c.Literal = currPart.Literal
						c.ModuleIndex = currPart.ModuleIndex
						return nil
					}
					// The following part becomes the sibling of the old
					// suffix. parsePattern never emits two adjacent
					// literals, so that part is not a literal. Advance
					// parts so that parts[1:] below names the parts after it.
					splitNewPart = parts[1]
					parts = parts[1:]
				} else {
					// The literals diverge after the common prefix, so the
					// rest of currPart becomes a new literal sibling.
					splitNewPart = &kmodpb.RadixNode{
						Type:        kmodpb.RadixNode_LITERAL,
						Literal:     currPart.Literal[i:],
						ModuleIndex: currPart.ModuleIndex,
					}
				}
				c.Children = []*kmodpb.RadixNode{
					splitOldPart,
					splitNewPart,
				}
				c.Literal = currPart.Literal[:i]
				// No pattern terminates at the shortened c any more. Its
				// former module indices now live on splitOldPart.
				c.ModuleIndex = nil
				return addPatternRec(splitNewPart, parts[1:], nil)
			}
		case kmodpb.RadixNode_BYTE_RANGE:
			// Only identical ranges are shared. Differing (even
			// overlapping) ranges end up as separate children below.
			if c.StartByte == currPart.StartByte && c.EndByte == currPart.EndByte {
				if len(parts) == 1 {
					c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
				}
				return addPatternRec(c, parts[1:], nil)
			}
		case kmodpb.RadixNode_SINGLE_WILDCARD, kmodpb.RadixNode_WILDCARD:
			// Wildcards have no parameters, so any child of the same
			// wildcard type can be shared.
			if len(parts) == 1 {
				c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
			}
			return addPatternRec(c, parts[1:], nil)
		}
	}
	// No child could be shared: attach currPart itself as a new child and
	// continue inserting the remaining parts below it.
	n.Children = append(n.Children, currPart)
	return addPatternRec(currPart, parts[1:], nil)
}
// PrintTree dumps the radix tree rooted at r to standard out, with nested
// nodes indented below their parents. The output format is not stable; use
// it only for debugging and diagnostics. The process exits via log.Fatalf if
// a node of unknown type is encountered.
func PrintTree(r *kmodpb.RadixNode) {
	const rootIndent = 0
	printTree(r, rootIndent, false /* noIndent */)
}
// printTree writes r and its subtree to standard out.
//
// Each node prints its module indices (if any), followed by a description of
// what it matches. A node with exactly one child continues on the same line
// as that child, which keeps long chains compact. Otherwise each child starts
// on its own line, indented one space deeper.
//
// indent is the indentation depth (in spaces) for r. noIndent suppresses
// that indentation, for nodes that continue their parent's line.
//
// The process exits via log.Fatalf on a node with an unknown type.
func printTree(r *kmodpb.RadixNode, indent int, noIndent bool) {
	if !noIndent {
		fmt.Print(strings.Repeat(" ", indent))
	}
	if len(r.ModuleIndex) > 0 {
		fmt.Printf("%v ", r.ModuleIndex)
	}
	switch r.Type {
	case kmodpb.RadixNode_LITERAL:
		fmt.Printf("%q: ", r.Literal)
	case kmodpb.RadixNode_SINGLE_WILDCARD:
		fmt.Print("?: ")
	case kmodpb.RadixNode_WILDCARD:
		fmt.Print("*: ")
	case kmodpb.RadixNode_BYTE_RANGE:
		fmt.Printf("[%c-%c]: ", rune(r.StartByte), rune(r.EndByte))
	default:
		// Report the offending node type value. Formatting r with %T would
		// always print the static Go type (*RadixNode), which says nothing
		// about what was actually wrong.
		log.Fatalf("Unknown tree type %v", r.Type)
	}
	if len(r.Children) == 1 {
		printTree(r.Children[0], indent, true)
		return
	}
	fmt.Println()
	for _, c := range r.Children {
		printTree(c, indent+1, false)
	}
}
// parsePattern splits a glob-like pattern into a flat, non-hierarchical
// sequence of RadixNodes, one per matcher, in pattern order. The nodes have
// no children; callers link them into a radix tree afterwards.
//
// Supported syntax:
//   - '*' matches any run of bytes. Consecutive '*' collapse into one node.
//   - '?' matches exactly one byte.
//   - "[a-b]" matches one byte in the inclusive range a..b. It is exactly
//     five bytes long, and each bound is a single raw byte.
//   - '\' followed by any byte yields that byte literally.
//   - Every other byte is literal. Adjacent literal bytes are merged into a
//     single LITERAL node.
func parsePattern(pattern string) ([]*kmodpb.RadixNode, error) {
	var nodes []*kmodpb.RadixNode
	var lit strings.Builder
	// flush emits the pending literal bytes, if any, as one LITERAL node.
	flush := func() {
		if lit.Len() == 0 {
			return
		}
		nodes = append(nodes, &kmodpb.RadixNode{
			Type:    kmodpb.RadixNode_LITERAL,
			Literal: lit.String(),
		})
		lit.Reset()
	}
	for pos := 0; pos < len(pattern); {
		switch ch := pattern[pos]; ch {
		case '*':
			flush()
			pos++
			// "**" is equivalent to "*"; don't stack wildcard nodes.
			if last := len(nodes) - 1; last < 0 || nodes[last].Type != kmodpb.RadixNode_WILDCARD {
				nodes = append(nodes, &kmodpb.RadixNode{Type: kmodpb.RadixNode_WILDCARD})
			}
		case '?':
			flush()
			nodes = append(nodes, &kmodpb.RadixNode{Type: kmodpb.RadixNode_SINGLE_WILDCARD})
			pos++
		case '[':
			flush()
			if pos+4 >= len(pattern) {
				return nil, errors.New("illegal byte range notation, not enough characters")
			}
			lo, dash, hi, closer := pattern[pos+1], pattern[pos+2], pattern[pos+3], pattern[pos+4]
			if dash != '-' || closer != ']' {
				return nil, errors.New("illegal byte range notation, incorrect dash or closing character")
			}
			if lo > hi {
				return nil, errors.New("byte range start byte larger than end byte")
			}
			nodes = append(nodes, &kmodpb.RadixNode{
				Type:      kmodpb.RadixNode_BYTE_RANGE,
				StartByte: uint32(lo),
				EndByte:   uint32(hi),
			})
			pos += 5
		case '\\':
			if pos+1 >= len(pattern) {
				return nil, errors.New("illegal escape character at the end of the string")
			}
			lit.WriteByte(pattern[pos+1])
			pos += 2
		default:
			lit.WriteByte(ch)
			pos++
		}
	}
	flush()
	return nodes, nil
}