blob: 60e75f0f26fd3d8425de368cf648dbdafa7f28bb [file] [log] [blame]
Lorenz Brunc7b036b2023-06-01 12:23:57 +02001package kmod
2
3import (
4 "errors"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9
10 kmodpb "source.monogon.dev/metropolis/pkg/kmod/spec"
11)
12
13// LookupModules looks up all matching modules for a given modalias device
14// identifier.
15func LookupModules(meta *kmodpb.Meta, modalias string) (mods []*kmodpb.Module) {
16 matches := make(map[uint32]bool)
17 lookupModulesRec(meta.ModuleDeviceMatches, modalias, matches)
18 for idx := range matches {
19 mods = append(mods, meta.Modules[idx])
20 }
21 sort.Slice(mods, func(i, j int) bool { return mods[i].Name < mods[j].Name })
22 return
23}
24
25func lookupModulesRec(n *kmodpb.RadixNode, needle string, matches map[uint32]bool) {
26 for _, c := range n.Children {
27 switch c.Type {
28 case kmodpb.RadixNode_LITERAL:
29 if len(needle) < len(c.Literal) {
30 continue
31 }
32 if c.Literal == needle[:len(c.Literal)] {
33 lookupModulesRec(c, needle[len(c.Literal):], matches)
34 }
35 case kmodpb.RadixNode_WILDCARD:
36 for i := 0; i <= len(needle); i++ {
37 lookupModulesRec(c, needle[i:], matches)
38 }
39 case kmodpb.RadixNode_SINGLE_WILDCARD:
40 if len(needle) < 1 {
41 continue
42 }
43 lookupModulesRec(c, needle[1:], matches)
44 case kmodpb.RadixNode_BYTE_RANGE:
45 if len(needle) < 1 {
46 continue
47 }
48 if needle[0] >= byte(c.StartByte) && needle[0] <= byte(c.EndByte) {
49 lookupModulesRec(c, needle[1:], matches)
50 }
51 }
52 }
53 if len(needle) == 0 {
54 for _, mi := range n.ModuleIndex {
55 matches[mi] = true
56 }
57 }
58 return
59}
60
61// AddPattern adds a new pattern associated with a moduleIndex to the radix tree
62// rooted at root.
63func AddPattern(root *kmodpb.RadixNode, pattern string, moduleIndex uint32) error {
64 pp, err := parsePattern(pattern)
65 if err != nil {
66 return fmt.Errorf("error parsing pattern %q: %w", pattern, err)
67 }
68 if len(pp) > 0 {
69 pp[len(pp)-1].ModuleIndex = []uint32{moduleIndex}
70 } else {
71 // This exists to handle empty patterns, which have little use in
72 // practice (but their behavior is well-defined). It exists primarily
73 // to not crash in that case as well as to appease the Fuzzer.
74 root.ModuleIndex = append(root.ModuleIndex, moduleIndex)
75 }
76 return addPatternRec(root, pp, nil)
77}
78
// addPatternRec recursively merges the pattern nodes in parts into the radix
// tree below n. parts is the flat list produced by parsePattern; AddPattern
// has already attached the module index to its last element. That index ends
// up on whichever tree node the pattern terminates at.
// If currPartOverride is non-nil it is used instead of the first part in the
// parts array. This carries the unmatched remainder of a literal part down
// into a child after an existing literal node matched a prefix of it.
//
// Nodes from parts are linked into the tree directly rather than copied.
// The function currently always returns nil.
func addPatternRec(n *kmodpb.RadixNode, parts []*kmodpb.RadixNode, currPartOverride *kmodpb.RadixNode) error {
	if len(parts) == 0 {
		return nil
	}
	var currPart *kmodpb.RadixNode
	if currPartOverride != nil {
		currPart = currPartOverride
	} else {
		currPart = parts[0]
	}
	for _, c := range n.Children {
		// Only a child of the same node type can be merged with currPart.
		if c.Type != currPart.Type {
			continue
		}
		switch c.Type {
		case kmodpb.RadixNode_LITERAL:
			// In a tree built by this function, literal siblings never share
			// a first byte, so at most one literal child can share a prefix
			// with currPart.
			if c.Literal[0] == currPart.Literal[0] {
				// i ends up as the length of the common prefix.
				var i int
				for i < len(c.Literal) && i < len(currPart.Literal) && c.Literal[i] == currPart.Literal[i] {
					i++
				}
				// Exact match: tag c if this was the last part, otherwise
				// continue with the next part below c.
				if i == len(c.Literal) && i == len(currPart.Literal) {
					if len(parts) == 1 {
						c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
						return nil
					}
					return addPatternRec(c, parts[1:], nil)
				}
				// c is a proper prefix of currPart: descend into c, keeping
				// the same parts but overriding the current one with the
				// unmatched remainder of its literal.
				if i == len(c.Literal) {
					return addPatternRec(c, parts, &kmodpb.RadixNode{Type: kmodpb.RadixNode_LITERAL, Literal: currPart.Literal[i:], ModuleIndex: currPart.ModuleIndex})
				}
				// Split current node: c keeps the common prefix, while its old
				// suffix, children and module indices move to splitOldPart.
				splitOldPart := &kmodpb.RadixNode{
					Type:        kmodpb.RadixNode_LITERAL,
					Literal:     c.Literal[i:],
					Children:    c.Children,
					ModuleIndex: c.ModuleIndex,
				}
				var splitNewPart *kmodpb.RadixNode
				// Current part is a strict subset of the node being traversed
				// (currPart is a proper prefix of c).
				if i == len(currPart.Literal) {
					// The pattern ends here: c becomes exactly currPart, tagged
					// with its module index, with the old suffix as its only
					// child.
					if len(parts) < 2 {
						c.Children = []*kmodpb.RadixNode{splitOldPart}
						c.Literal = currPart.Literal
						c.ModuleIndex = currPart.ModuleIndex
						return nil
					}
					// Otherwise the next part becomes c's second child. It is
					// not a literal, because parsePattern never emits two
					// adjacent literals. The remaining parts continue below it.
					splitNewPart = parts[1]
					parts = parts[1:]
				} else {
					// The literals diverge after i bytes: the remainder of
					// currPart becomes a sibling of c's old suffix.
					splitNewPart = &kmodpb.RadixNode{
						Type:        kmodpb.RadixNode_LITERAL,
						Literal:     currPart.Literal[i:],
						ModuleIndex: currPart.ModuleIndex,
					}
				}
				c.Children = []*kmodpb.RadixNode{
					splitOldPart,
					splitNewPart,
				}
				c.Literal = currPart.Literal[:i]
				c.ModuleIndex = nil
				return addPatternRec(splitNewPart, parts[1:], nil)
			}

		case kmodpb.RadixNode_BYTE_RANGE:
			// Byte ranges are only merged if both bounds are identical.
			if c.StartByte == currPart.StartByte && c.EndByte == currPart.EndByte {
				if len(parts) == 1 {
					c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
				}
				return addPatternRec(c, parts[1:], nil)
			}
		case kmodpb.RadixNode_SINGLE_WILDCARD, kmodpb.RadixNode_WILDCARD:
			// Wildcards of the same kind carry no data, so c is reused as is.
			if len(parts) == 1 {
				c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
			}
			return addPatternRec(c, parts[1:], nil)
		}
	}
	// No child or common prefix found, append node
	n.Children = append(n.Children, currPart)
	return addPatternRec(currPart, parts[1:], nil)
}
165
166// PrintTree prints the tree from the given root node to standard out.
167// The output is not stable and should only be used for debugging/diagnostics.
168// It will log and exit the process if it encounters invalid nodes.
169func PrintTree(r *kmodpb.RadixNode) {
170 printTree(r, 0, false)
171}
172
173func printTree(r *kmodpb.RadixNode, indent int, noIndent bool) {
174 if !noIndent {
175 for i := 0; i < indent; i++ {
176 fmt.Print(" ")
177 }
178 }
179 if len(r.ModuleIndex) > 0 {
180 fmt.Printf("%v ", r.ModuleIndex)
181 }
182 switch r.Type {
183 case kmodpb.RadixNode_LITERAL:
184 fmt.Printf("%q: ", r.Literal)
185 case kmodpb.RadixNode_SINGLE_WILDCARD:
186 fmt.Printf("?: ")
187 case kmodpb.RadixNode_WILDCARD:
188 fmt.Printf("*: ")
189 case kmodpb.RadixNode_BYTE_RANGE:
190 fmt.Printf("[%c-%c]: ", rune(r.StartByte), rune(r.EndByte))
191 default:
192 log.Fatalf("Unknown tree type %T\n", r)
193 }
194 if len(r.Children) == 1 {
195 printTree(r.Children[0], indent, true)
196 return
197 }
198 fmt.Println("")
199 for _, c := range r.Children {
200 printTree(c, indent+1, false)
201 }
202}
203
204// parsePattern parses a string pattern into a non-hierarchical list of
205// RadixNodes. These nodes can then be futher modified and integrated into
206// a Radix tree.
207func parsePattern(pattern string) ([]*kmodpb.RadixNode, error) {
208 var out []*kmodpb.RadixNode
209 var i int
210 var currentLiteral strings.Builder
211 storeCurrentLiteral := func() {
212 if currentLiteral.Len() > 0 {
213 out = append(out, &kmodpb.RadixNode{
214 Type: kmodpb.RadixNode_LITERAL,
215 Literal: currentLiteral.String(),
216 })
217 currentLiteral.Reset()
218 }
219 }
220 for i < len(pattern) {
221 switch pattern[i] {
222 case '*':
223 storeCurrentLiteral()
224 i += 1
225 if len(out) > 0 && out[len(out)-1].Type == kmodpb.RadixNode_WILDCARD {
226 continue
227 }
228 out = append(out, &kmodpb.RadixNode{
229 Type: kmodpb.RadixNode_WILDCARD,
230 })
231 case '?':
232 storeCurrentLiteral()
233 out = append(out, &kmodpb.RadixNode{
234 Type: kmodpb.RadixNode_SINGLE_WILDCARD,
235 })
236 i += 1
237 case '[':
238 storeCurrentLiteral()
239 if len(pattern) <= i+4 {
240 return nil, errors.New("illegal byte range notation, not enough characters")
241 }
242 if pattern[i+2] != '-' || pattern[i+4] != ']' {
243 return nil, errors.New("illegal byte range notation, incorrect dash or closing character")
244 }
245 nn := &kmodpb.RadixNode{
246 Type: kmodpb.RadixNode_BYTE_RANGE,
247 StartByte: uint32(pattern[i+1]),
248 EndByte: uint32(pattern[i+3]),
249 }
250 if nn.StartByte > nn.EndByte {
251 return nil, errors.New("byte range start byte larger than end byte")
252 }
253 out = append(out, nn)
254 i += 5
255 case '\\':
256 if len(pattern) <= i+1 {
257 return nil, errors.New("illegal escape character at the end of the string")
258 }
259 currentLiteral.WriteByte(pattern[i+1])
260 i += 2
261 default:
262 currentLiteral.WriteByte(pattern[i])
263 i += 1
264 }
265 }
266 storeCurrentLiteral()
267 return out, nil
268}