| package kmod |
| |
| import ( |
| "errors" |
| "fmt" |
| "log" |
| "sort" |
| "strings" |
| |
| kmodpb "source.monogon.dev/metropolis/pkg/kmod/spec" |
| ) |
| |
| // LookupModules looks up all matching modules for a given modalias device |
| // identifier. |
| func LookupModules(meta *kmodpb.Meta, modalias string) (mods []*kmodpb.Module) { |
| matches := make(map[uint32]bool) |
| lookupModulesRec(meta.ModuleDeviceMatches, modalias, matches) |
| for idx := range matches { |
| mods = append(mods, meta.Modules[idx]) |
| } |
| sort.Slice(mods, func(i, j int) bool { return mods[i].Name < mods[j].Name }) |
| return |
| } |
| |
| func lookupModulesRec(n *kmodpb.RadixNode, needle string, matches map[uint32]bool) { |
| for _, c := range n.Children { |
| switch c.Type { |
| case kmodpb.RadixNode_LITERAL: |
| if len(needle) < len(c.Literal) { |
| continue |
| } |
| if c.Literal == needle[:len(c.Literal)] { |
| lookupModulesRec(c, needle[len(c.Literal):], matches) |
| } |
| case kmodpb.RadixNode_WILDCARD: |
| for i := 0; i <= len(needle); i++ { |
| lookupModulesRec(c, needle[i:], matches) |
| } |
| case kmodpb.RadixNode_SINGLE_WILDCARD: |
| if len(needle) < 1 { |
| continue |
| } |
| lookupModulesRec(c, needle[1:], matches) |
| case kmodpb.RadixNode_BYTE_RANGE: |
| if len(needle) < 1 { |
| continue |
| } |
| if needle[0] >= byte(c.StartByte) && needle[0] <= byte(c.EndByte) { |
| lookupModulesRec(c, needle[1:], matches) |
| } |
| } |
| } |
| if len(needle) == 0 { |
| for _, mi := range n.ModuleIndex { |
| matches[mi] = true |
| } |
| } |
| return |
| } |
| |
| // AddPattern adds a new pattern associated with a moduleIndex to the radix tree |
| // rooted at root. |
| func AddPattern(root *kmodpb.RadixNode, pattern string, moduleIndex uint32) error { |
| pp, err := parsePattern(pattern) |
| if err != nil { |
| return fmt.Errorf("error parsing pattern %q: %w", pattern, err) |
| } |
| if len(pp) > 0 { |
| pp[len(pp)-1].ModuleIndex = []uint32{moduleIndex} |
| } else { |
| // This exists to handle empty patterns, which have little use in |
| // practice (but their behavior is well-defined). It exists primarily |
| // to not crash in that case as well as to appease the Fuzzer. |
| root.ModuleIndex = append(root.ModuleIndex, moduleIndex) |
| } |
| return addPatternRec(root, pp, nil) |
| } |
| |
// addPatternRec recursively merges a parsed pattern into the radix tree below
// n.
//
// parts is the not-yet-inserted remainder of a pattern as produced by
// parsePattern. Only its last element carries the module index, which is set
// by AddPattern. Nodes from parts may be linked into the tree as-is, so
// callers must not reuse them afterwards.
//
// If currPartOverride is non-nil it is used instead of the first part in the
// parts array. This happens when an existing LITERAL child matched only a
// prefix of parts[0]. The override then holds the unmatched suffix, together
// with parts[0]'s module index, while parts[0] still marks the position within
// parts.
//
// Children of n are matched by type:
//   - LITERAL children sharing a first byte are either descended into or split
//     at the longest common prefix.
//   - BYTE_RANGE children are shared only if both bounds are identical.
//   - WILDCARD and SINGLE_WILDCARD children always match a part of the same
//     type.
//
// If no child matches, the current part is appended as a new child and the
// rest of the pattern is chained below it. The code in this function never
// produces a non-nil error itself; errors are only propagated from recursive
// calls.
func addPatternRec(n *kmodpb.RadixNode, parts []*kmodpb.RadixNode, currPartOverride *kmodpb.RadixNode) error {
	// Nothing left to insert.
	if len(parts) == 0 {
		return nil
	}
	var currPart *kmodpb.RadixNode
	if currPartOverride != nil {
		currPart = currPartOverride
	} else {
		currPart = parts[0]
	}
	for _, c := range n.Children {
		if c.Type != currPart.Type {
			continue
		}
		switch c.Type {
		case kmodpb.RadixNode_LITERAL:
			// Literals are assumed to be non-empty; parsePattern only emits
			// literals with at least one byte. Siblings are told apart by
			// their first byte.
			if c.Literal[0] == currPart.Literal[0] {
				// Compute i, the length of the common prefix.
				var i int
				for i < len(c.Literal) && i < len(currPart.Literal) && c.Literal[i] == currPart.Literal[i] {
					i++
				}
				// Exact match: reuse c as-is.
				if i == len(c.Literal) && i == len(currPart.Literal) {
					if len(parts) == 1 {
						c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
						return nil
					}
					return addPatternRec(c, parts[1:], nil)
				}
				// c is a proper prefix of the current part: descend into c,
				// overriding the current part with its unmatched suffix.
				if i == len(c.Literal) {
					return addPatternRec(c, parts, &kmodpb.RadixNode{Type: kmodpb.RadixNode_LITERAL, Literal: currPart.Literal[i:], ModuleIndex: currPart.ModuleIndex})
				}
				// Split current node: c keeps the common prefix, while its old
				// suffix, children and module indices move into splitOldPart.
				splitOldPart := &kmodpb.RadixNode{
					Type:        kmodpb.RadixNode_LITERAL,
					Literal:     c.Literal[i:],
					Children:    c.Children,
					ModuleIndex: c.ModuleIndex,
				}
				var splitNewPart *kmodpb.RadixNode
				// Current part is a strict prefix of the node being traversed.
				if i == len(currPart.Literal) {
					// The pattern ends with this literal: c becomes the
					// shortened literal and takes the new module index.
					if len(parts) < 2 {
						c.Children = []*kmodpb.RadixNode{splitOldPart}
						c.Literal = currPart.Literal
						c.ModuleIndex = currPart.ModuleIndex
						return nil
					}
					// Otherwise the next part becomes c's second child.
					// parsePattern merges adjacent literal bytes, so this part
					// is not a LITERAL and cannot collide with splitOldPart.
					splitNewPart = parts[1]
					parts = parts[1:]
				} else {
					// The literals diverge after i bytes: the new suffix
					// becomes a sibling of splitOldPart.
					splitNewPart = &kmodpb.RadixNode{
						Type:        kmodpb.RadixNode_LITERAL,
						Literal:     currPart.Literal[i:],
						ModuleIndex: currPart.ModuleIndex,
					}
				}
				c.Children = []*kmodpb.RadixNode{
					splitOldPart,
					splitNewPart,
				}
				c.Literal = currPart.Literal[:i]
				c.ModuleIndex = nil
				// splitNewPart has no children yet, so the remaining parts
				// are simply chained below it.
				return addPatternRec(splitNewPart, parts[1:], nil)
			}
			// First bytes differ: keep scanning the other children.

		case kmodpb.RadixNode_BYTE_RANGE:
			// Only a child with exactly the same range can be shared.
			if c.StartByte == currPart.StartByte && c.EndByte == currPart.EndByte {
				if len(parts) == 1 {
					c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
				}
				return addPatternRec(c, parts[1:], nil)
			}
		case kmodpb.RadixNode_SINGLE_WILDCARD, kmodpb.RadixNode_WILDCARD:
			// Wildcards have no parameters, so any child of the same type is
			// reusable.
			if len(parts) == 1 {
				c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
			}
			return addPatternRec(c, parts[1:], nil)
		}
	}
	// No child or common prefix found: append the current part as a new child
	// and chain the rest of the pattern below it.
	n.Children = append(n.Children, currPart)
	return addPatternRec(currPart, parts[1:], nil)
}
| |
| // PrintTree prints the tree from the given root node to standard out. |
| // The output is not stable and should only be used for debugging/diagnostics. |
| // It will log and exit the process if it encounters invalid nodes. |
| func PrintTree(r *kmodpb.RadixNode) { |
| printTree(r, 0, false) |
| } |
| |
| func printTree(r *kmodpb.RadixNode, indent int, noIndent bool) { |
| if !noIndent { |
| for i := 0; i < indent; i++ { |
| fmt.Print(" ") |
| } |
| } |
| if len(r.ModuleIndex) > 0 { |
| fmt.Printf("%v ", r.ModuleIndex) |
| } |
| switch r.Type { |
| case kmodpb.RadixNode_LITERAL: |
| fmt.Printf("%q: ", r.Literal) |
| case kmodpb.RadixNode_SINGLE_WILDCARD: |
| fmt.Printf("?: ") |
| case kmodpb.RadixNode_WILDCARD: |
| fmt.Printf("*: ") |
| case kmodpb.RadixNode_BYTE_RANGE: |
| fmt.Printf("[%c-%c]: ", rune(r.StartByte), rune(r.EndByte)) |
| default: |
| log.Fatalf("Unknown tree type %T\n", r) |
| } |
| if len(r.Children) == 1 { |
| printTree(r.Children[0], indent, true) |
| return |
| } |
| fmt.Println("") |
| for _, c := range r.Children { |
| printTree(c, indent+1, false) |
| } |
| } |
| |
| // parsePattern parses a string pattern into a non-hierarchical list of |
| // RadixNodes. These nodes can then be futher modified and integrated into |
| // a Radix tree. |
| func parsePattern(pattern string) ([]*kmodpb.RadixNode, error) { |
| var out []*kmodpb.RadixNode |
| var i int |
| var currentLiteral strings.Builder |
| storeCurrentLiteral := func() { |
| if currentLiteral.Len() > 0 { |
| out = append(out, &kmodpb.RadixNode{ |
| Type: kmodpb.RadixNode_LITERAL, |
| Literal: currentLiteral.String(), |
| }) |
| currentLiteral.Reset() |
| } |
| } |
| for i < len(pattern) { |
| switch pattern[i] { |
| case '*': |
| storeCurrentLiteral() |
| i += 1 |
| if len(out) > 0 && out[len(out)-1].Type == kmodpb.RadixNode_WILDCARD { |
| continue |
| } |
| out = append(out, &kmodpb.RadixNode{ |
| Type: kmodpb.RadixNode_WILDCARD, |
| }) |
| case '?': |
| storeCurrentLiteral() |
| out = append(out, &kmodpb.RadixNode{ |
| Type: kmodpb.RadixNode_SINGLE_WILDCARD, |
| }) |
| i += 1 |
| case '[': |
| storeCurrentLiteral() |
| if len(pattern) <= i+4 { |
| return nil, errors.New("illegal byte range notation, not enough characters") |
| } |
| if pattern[i+2] != '-' || pattern[i+4] != ']' { |
| return nil, errors.New("illegal byte range notation, incorrect dash or closing character") |
| } |
| nn := &kmodpb.RadixNode{ |
| Type: kmodpb.RadixNode_BYTE_RANGE, |
| StartByte: uint32(pattern[i+1]), |
| EndByte: uint32(pattern[i+3]), |
| } |
| if nn.StartByte > nn.EndByte { |
| return nil, errors.New("byte range start byte larger than end byte") |
| } |
| out = append(out, nn) |
| i += 5 |
| case '\\': |
| if len(pattern) <= i+1 { |
| return nil, errors.New("illegal escape character at the end of the string") |
| } |
| currentLiteral.WriteByte(pattern[i+1]) |
| i += 2 |
| default: |
| currentLiteral.WriteByte(pattern[i]) |
| i += 1 |
| } |
| } |
| storeCurrentLiteral() |
| return out, nil |
| } |