Lorenz Brun | c7b036b | 2023-06-01 12:23:57 +0200 | [diff] [blame] | 1 | package kmod |
| 2 | |
| 3 | import ( |
| 4 | "errors" |
| 5 | "fmt" |
| 6 | "log" |
| 7 | "sort" |
| 8 | "strings" |
| 9 | |
| 10 | kmodpb "source.monogon.dev/metropolis/pkg/kmod/spec" |
| 11 | ) |
| 12 | |
| 13 | // LookupModules looks up all matching modules for a given modalias device |
| 14 | // identifier. |
| 15 | func LookupModules(meta *kmodpb.Meta, modalias string) (mods []*kmodpb.Module) { |
| 16 | matches := make(map[uint32]bool) |
| 17 | lookupModulesRec(meta.ModuleDeviceMatches, modalias, matches) |
| 18 | for idx := range matches { |
| 19 | mods = append(mods, meta.Modules[idx]) |
| 20 | } |
| 21 | sort.Slice(mods, func(i, j int) bool { return mods[i].Name < mods[j].Name }) |
| 22 | return |
| 23 | } |
| 24 | |
| 25 | func lookupModulesRec(n *kmodpb.RadixNode, needle string, matches map[uint32]bool) { |
| 26 | for _, c := range n.Children { |
| 27 | switch c.Type { |
| 28 | case kmodpb.RadixNode_LITERAL: |
| 29 | if len(needle) < len(c.Literal) { |
| 30 | continue |
| 31 | } |
| 32 | if c.Literal == needle[:len(c.Literal)] { |
| 33 | lookupModulesRec(c, needle[len(c.Literal):], matches) |
| 34 | } |
| 35 | case kmodpb.RadixNode_WILDCARD: |
| 36 | for i := 0; i <= len(needle); i++ { |
| 37 | lookupModulesRec(c, needle[i:], matches) |
| 38 | } |
| 39 | case kmodpb.RadixNode_SINGLE_WILDCARD: |
| 40 | if len(needle) < 1 { |
| 41 | continue |
| 42 | } |
| 43 | lookupModulesRec(c, needle[1:], matches) |
| 44 | case kmodpb.RadixNode_BYTE_RANGE: |
| 45 | if len(needle) < 1 { |
| 46 | continue |
| 47 | } |
| 48 | if needle[0] >= byte(c.StartByte) && needle[0] <= byte(c.EndByte) { |
| 49 | lookupModulesRec(c, needle[1:], matches) |
| 50 | } |
| 51 | } |
| 52 | } |
| 53 | if len(needle) == 0 { |
| 54 | for _, mi := range n.ModuleIndex { |
| 55 | matches[mi] = true |
| 56 | } |
| 57 | } |
| 58 | return |
| 59 | } |
| 60 | |
| 61 | // AddPattern adds a new pattern associated with a moduleIndex to the radix tree |
| 62 | // rooted at root. |
| 63 | func AddPattern(root *kmodpb.RadixNode, pattern string, moduleIndex uint32) error { |
| 64 | pp, err := parsePattern(pattern) |
| 65 | if err != nil { |
| 66 | return fmt.Errorf("error parsing pattern %q: %w", pattern, err) |
| 67 | } |
| 68 | if len(pp) > 0 { |
| 69 | pp[len(pp)-1].ModuleIndex = []uint32{moduleIndex} |
| 70 | } else { |
| 71 | // This exists to handle empty patterns, which have little use in |
| 72 | // practice (but their behavior is well-defined). It exists primarily |
| 73 | // to not crash in that case as well as to appease the Fuzzer. |
| 74 | root.ModuleIndex = append(root.ModuleIndex, moduleIndex) |
| 75 | } |
| 76 | return addPatternRec(root, pp, nil) |
| 77 | } |
| 78 | |
| 79 | // addPatternRec recursively adds a new pattern to the radix tree. |
| 80 | // If currPartOverride is non-nil it is used instead of the first part in the |
| 81 | // parts array. |
| 82 | func addPatternRec(n *kmodpb.RadixNode, parts []*kmodpb.RadixNode, currPartOverride *kmodpb.RadixNode) error { |
| 83 | if len(parts) == 0 { |
| 84 | return nil |
| 85 | } |
| 86 | var currPart *kmodpb.RadixNode |
| 87 | if currPartOverride != nil { |
| 88 | currPart = currPartOverride |
| 89 | } else { |
| 90 | currPart = parts[0] |
| 91 | } |
| 92 | for _, c := range n.Children { |
| 93 | if c.Type != currPart.Type { |
| 94 | continue |
| 95 | } |
| 96 | switch c.Type { |
| 97 | case kmodpb.RadixNode_LITERAL: |
| 98 | if c.Literal[0] == currPart.Literal[0] { |
| 99 | var i int |
| 100 | for i < len(c.Literal) && i < len(currPart.Literal) && c.Literal[i] == currPart.Literal[i] { |
| 101 | i++ |
| 102 | } |
| 103 | if i == len(c.Literal) && i == len(currPart.Literal) { |
| 104 | if len(parts) == 1 { |
| 105 | c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...) |
| 106 | return nil |
| 107 | } |
| 108 | return addPatternRec(c, parts[1:], nil) |
| 109 | } |
| 110 | if i == len(c.Literal) { |
| 111 | return addPatternRec(c, parts, &kmodpb.RadixNode{Type: kmodpb.RadixNode_LITERAL, Literal: currPart.Literal[i:], ModuleIndex: currPart.ModuleIndex}) |
| 112 | } |
| 113 | // Split current node |
| 114 | splitOldPart := &kmodpb.RadixNode{ |
| 115 | Type: kmodpb.RadixNode_LITERAL, |
| 116 | Literal: c.Literal[i:], |
| 117 | Children: c.Children, |
| 118 | ModuleIndex: c.ModuleIndex, |
| 119 | } |
| 120 | var splitNewPart *kmodpb.RadixNode |
| 121 | // Current part is a strict subset of the node being traversed |
| 122 | if i == len(currPart.Literal) { |
| 123 | if len(parts) < 2 { |
| 124 | c.Children = []*kmodpb.RadixNode{splitOldPart} |
| 125 | c.Literal = currPart.Literal |
| 126 | c.ModuleIndex = currPart.ModuleIndex |
| 127 | return nil |
| 128 | } |
| 129 | splitNewPart = parts[1] |
| 130 | parts = parts[1:] |
| 131 | } else { |
| 132 | splitNewPart = &kmodpb.RadixNode{ |
| 133 | Type: kmodpb.RadixNode_LITERAL, |
| 134 | Literal: currPart.Literal[i:], |
| 135 | ModuleIndex: currPart.ModuleIndex, |
| 136 | } |
| 137 | } |
| 138 | c.Children = []*kmodpb.RadixNode{ |
| 139 | splitOldPart, |
| 140 | splitNewPart, |
| 141 | } |
| 142 | c.Literal = currPart.Literal[:i] |
| 143 | c.ModuleIndex = nil |
| 144 | return addPatternRec(splitNewPart, parts[1:], nil) |
| 145 | } |
| 146 | |
| 147 | case kmodpb.RadixNode_BYTE_RANGE: |
| 148 | if c.StartByte == currPart.StartByte && c.EndByte == currPart.EndByte { |
| 149 | if len(parts) == 1 { |
| 150 | c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...) |
| 151 | } |
| 152 | return addPatternRec(c, parts[1:], nil) |
| 153 | } |
| 154 | case kmodpb.RadixNode_SINGLE_WILDCARD, kmodpb.RadixNode_WILDCARD: |
| 155 | if len(parts) == 1 { |
| 156 | c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...) |
| 157 | } |
| 158 | return addPatternRec(c, parts[1:], nil) |
| 159 | } |
| 160 | } |
| 161 | // No child or common prefix found, append node |
| 162 | n.Children = append(n.Children, currPart) |
| 163 | return addPatternRec(currPart, parts[1:], nil) |
| 164 | } |
| 165 | |
| 166 | // PrintTree prints the tree from the given root node to standard out. |
| 167 | // The output is not stable and should only be used for debugging/diagnostics. |
| 168 | // It will log and exit the process if it encounters invalid nodes. |
| 169 | func PrintTree(r *kmodpb.RadixNode) { |
| 170 | printTree(r, 0, false) |
| 171 | } |
| 172 | |
| 173 | func printTree(r *kmodpb.RadixNode, indent int, noIndent bool) { |
| 174 | if !noIndent { |
| 175 | for i := 0; i < indent; i++ { |
| 176 | fmt.Print(" ") |
| 177 | } |
| 178 | } |
| 179 | if len(r.ModuleIndex) > 0 { |
| 180 | fmt.Printf("%v ", r.ModuleIndex) |
| 181 | } |
| 182 | switch r.Type { |
| 183 | case kmodpb.RadixNode_LITERAL: |
| 184 | fmt.Printf("%q: ", r.Literal) |
| 185 | case kmodpb.RadixNode_SINGLE_WILDCARD: |
| 186 | fmt.Printf("?: ") |
| 187 | case kmodpb.RadixNode_WILDCARD: |
| 188 | fmt.Printf("*: ") |
| 189 | case kmodpb.RadixNode_BYTE_RANGE: |
| 190 | fmt.Printf("[%c-%c]: ", rune(r.StartByte), rune(r.EndByte)) |
| 191 | default: |
| 192 | log.Fatalf("Unknown tree type %T\n", r) |
| 193 | } |
| 194 | if len(r.Children) == 1 { |
| 195 | printTree(r.Children[0], indent, true) |
| 196 | return |
| 197 | } |
| 198 | fmt.Println("") |
| 199 | for _, c := range r.Children { |
| 200 | printTree(c, indent+1, false) |
| 201 | } |
| 202 | } |
| 203 | |
| 204 | // parsePattern parses a string pattern into a non-hierarchical list of |
| 205 | // RadixNodes. These nodes can then be futher modified and integrated into |
| 206 | // a Radix tree. |
| 207 | func parsePattern(pattern string) ([]*kmodpb.RadixNode, error) { |
| 208 | var out []*kmodpb.RadixNode |
| 209 | var i int |
| 210 | var currentLiteral strings.Builder |
| 211 | storeCurrentLiteral := func() { |
| 212 | if currentLiteral.Len() > 0 { |
| 213 | out = append(out, &kmodpb.RadixNode{ |
| 214 | Type: kmodpb.RadixNode_LITERAL, |
| 215 | Literal: currentLiteral.String(), |
| 216 | }) |
| 217 | currentLiteral.Reset() |
| 218 | } |
| 219 | } |
| 220 | for i < len(pattern) { |
| 221 | switch pattern[i] { |
| 222 | case '*': |
| 223 | storeCurrentLiteral() |
| 224 | i += 1 |
| 225 | if len(out) > 0 && out[len(out)-1].Type == kmodpb.RadixNode_WILDCARD { |
| 226 | continue |
| 227 | } |
| 228 | out = append(out, &kmodpb.RadixNode{ |
| 229 | Type: kmodpb.RadixNode_WILDCARD, |
| 230 | }) |
| 231 | case '?': |
| 232 | storeCurrentLiteral() |
| 233 | out = append(out, &kmodpb.RadixNode{ |
| 234 | Type: kmodpb.RadixNode_SINGLE_WILDCARD, |
| 235 | }) |
| 236 | i += 1 |
| 237 | case '[': |
| 238 | storeCurrentLiteral() |
| 239 | if len(pattern) <= i+4 { |
| 240 | return nil, errors.New("illegal byte range notation, not enough characters") |
| 241 | } |
| 242 | if pattern[i+2] != '-' || pattern[i+4] != ']' { |
| 243 | return nil, errors.New("illegal byte range notation, incorrect dash or closing character") |
| 244 | } |
| 245 | nn := &kmodpb.RadixNode{ |
| 246 | Type: kmodpb.RadixNode_BYTE_RANGE, |
| 247 | StartByte: uint32(pattern[i+1]), |
| 248 | EndByte: uint32(pattern[i+3]), |
| 249 | } |
| 250 | if nn.StartByte > nn.EndByte { |
| 251 | return nil, errors.New("byte range start byte larger than end byte") |
| 252 | } |
| 253 | out = append(out, nn) |
| 254 | i += 5 |
| 255 | case '\\': |
| 256 | if len(pattern) <= i+1 { |
| 257 | return nil, errors.New("illegal escape character at the end of the string") |
| 258 | } |
| 259 | currentLiteral.WriteByte(pattern[i+1]) |
| 260 | i += 2 |
| 261 | default: |
| 262 | currentLiteral.WriteByte(pattern[i]) |
| 263 | i += 1 |
| 264 | } |
| 265 | } |
| 266 | storeCurrentLiteral() |
| 267 | return out, nil |
| 268 | } |