| package kmod | 
 |  | 
 | import ( | 
 | 	"errors" | 
 | 	"fmt" | 
 | 	"log" | 
 | 	"sort" | 
 | 	"strings" | 
 |  | 
 | 	kmodpb "source.monogon.dev/metropolis/pkg/kmod/spec" | 
 | ) | 
 |  | 
 | // LookupModules looks up all matching modules for a given modalias device | 
 | // identifier. | 
 | func LookupModules(meta *kmodpb.Meta, modalias string) (mods []*kmodpb.Module) { | 
 | 	matches := make(map[uint32]bool) | 
 | 	lookupModulesRec(meta.ModuleDeviceMatches, modalias, matches) | 
 | 	for idx := range matches { | 
 | 		mods = append(mods, meta.Modules[idx]) | 
 | 	} | 
 | 	sort.Slice(mods, func(i, j int) bool { return mods[i].Name < mods[j].Name }) | 
 | 	return | 
 | } | 
 |  | 
 | func lookupModulesRec(n *kmodpb.RadixNode, needle string, matches map[uint32]bool) { | 
 | 	for _, c := range n.Children { | 
 | 		switch c.Type { | 
 | 		case kmodpb.RadixNode_LITERAL: | 
 | 			if len(needle) < len(c.Literal) { | 
 | 				continue | 
 | 			} | 
 | 			if c.Literal == needle[:len(c.Literal)] { | 
 | 				lookupModulesRec(c, needle[len(c.Literal):], matches) | 
 | 			} | 
 | 		case kmodpb.RadixNode_WILDCARD: | 
 | 			for i := 0; i <= len(needle); i++ { | 
 | 				lookupModulesRec(c, needle[i:], matches) | 
 | 			} | 
 | 		case kmodpb.RadixNode_SINGLE_WILDCARD: | 
 | 			if len(needle) < 1 { | 
 | 				continue | 
 | 			} | 
 | 			lookupModulesRec(c, needle[1:], matches) | 
 | 		case kmodpb.RadixNode_BYTE_RANGE: | 
 | 			if len(needle) < 1 { | 
 | 				continue | 
 | 			} | 
 | 			if needle[0] >= byte(c.StartByte) && needle[0] <= byte(c.EndByte) { | 
 | 				lookupModulesRec(c, needle[1:], matches) | 
 | 			} | 
 | 		} | 
 | 	} | 
 | 	if len(needle) == 0 { | 
 | 		for _, mi := range n.ModuleIndex { | 
 | 			matches[mi] = true | 
 | 		} | 
 | 	} | 
 | 	return | 
 | } | 
 |  | 
 | // AddPattern adds a new pattern associated with a moduleIndex to the radix tree | 
 | // rooted at root. | 
 | func AddPattern(root *kmodpb.RadixNode, pattern string, moduleIndex uint32) error { | 
 | 	pp, err := parsePattern(pattern) | 
 | 	if err != nil { | 
 | 		return fmt.Errorf("error parsing pattern %q: %w", pattern, err) | 
 | 	} | 
 | 	if len(pp) > 0 { | 
 | 		pp[len(pp)-1].ModuleIndex = []uint32{moduleIndex} | 
 | 	} else { | 
 | 		// This exists to handle empty patterns, which have little use in | 
 | 		// practice (but their behavior is well-defined). It exists primarily | 
 | 		// to not crash in that case as well as to appease the Fuzzer. | 
 | 		root.ModuleIndex = append(root.ModuleIndex, moduleIndex) | 
 | 	} | 
 | 	return addPatternRec(root, pp, nil) | 
 | } | 
 |  | 
// addPatternRec recursively adds a new pattern to the radix tree.
//
// n is the node whose children are searched for a place to insert the next
// part. parts is the remaining flat list of pattern nodes as produced by
// parsePattern. As set up by AddPattern, only the last element carries the
// ModuleIndex to register. Nodes taken from parts may be linked into the
// tree directly (see the append at the end and the split handling), so they
// become part of the tree.
// If currPartOverride is non-nil it is used instead of the first part in the
// parts array. This is how the unmatched suffix of a partially matched
// literal is passed down without mutating parts[0].
//
// Literal children are distinguished only by their first byte. This relies
// on the invariant, maintained by this function, that no two LITERAL
// children of a node start with the same byte. It also assumes literals are
// never empty, which holds for literals produced by parsePattern and by the
// splits below.
//
// No path in this function creates an error; the returned error is
// currently always nil.
func addPatternRec(n *kmodpb.RadixNode, parts []*kmodpb.RadixNode, currPartOverride *kmodpb.RadixNode) error {
	if len(parts) == 0 {
		return nil
	}
	var currPart *kmodpb.RadixNode
	if currPartOverride != nil {
		currPart = currPartOverride
	} else {
		currPart = parts[0]
	}
	// Look for an existing child of the same type that can absorb currPart.
	for _, c := range n.Children {
		if c.Type != currPart.Type {
			continue
		}
		switch c.Type {
		case kmodpb.RadixNode_LITERAL:
			// A shared first byte means c and currPart have a common prefix.
			if c.Literal[0] == currPart.Literal[0] {
				// i becomes the length of the common prefix.
				var i int
				for i < len(c.Literal) && i < len(currPart.Literal) && c.Literal[i] == currPart.Literal[i] {
					i++
				}
				// Exact match: register the module here if this was the last
				// part, otherwise continue with the next part below c.
				if i == len(c.Literal) && i == len(currPart.Literal) {
					if len(parts) == 1 {
						c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
						return nil
					}
					return addPatternRec(c, parts[1:], nil)
				}
				// c.Literal is a proper prefix of currPart: descend into c with
				// the unmatched suffix of currPart as the override.
				if i == len(c.Literal) {
					return addPatternRec(c, parts, &kmodpb.RadixNode{Type: kmodpb.RadixNode_LITERAL, Literal: currPart.Literal[i:], ModuleIndex: currPart.ModuleIndex})
				}
				// Split current node at i. c keeps the common prefix; its old
				// suffix, children and module indices move to splitOldPart.
				splitOldPart := &kmodpb.RadixNode{
					Type:        kmodpb.RadixNode_LITERAL,
					Literal:     c.Literal[i:],
					Children:    c.Children,
					ModuleIndex: c.ModuleIndex,
				}
				var splitNewPart *kmodpb.RadixNode
				// Current part is a strict prefix of the node being traversed.
				if i == len(currPart.Literal) {
					// currPart is the last part: c (shortened to currPart)
					// becomes the terminal node carrying its module index.
					if len(parts) < 2 {
						c.Children = []*kmodpb.RadixNode{splitOldPart}
						c.Literal = currPart.Literal
						c.ModuleIndex = currPart.ModuleIndex
						return nil
					}
					// Otherwise the next part is attached directly under c,
					// next to splitOldPart. parsePattern never emits two
					// adjacent literals, so it cannot collide with splitOldPart.
					splitNewPart = parts[1]
					// Consume the current part so that parts[1:] below refers
					// to what follows splitNewPart.
					parts = parts[1:]
				} else {
					// Both literals diverge after i bytes: the remaining suffix
					// of currPart becomes a new literal sibling of splitOldPart.
					splitNewPart = &kmodpb.RadixNode{
						Type:        kmodpb.RadixNode_LITERAL,
						Literal:     currPart.Literal[i:],
						ModuleIndex: currPart.ModuleIndex,
					}
				}
				c.Children = []*kmodpb.RadixNode{
					splitOldPart,
					splitNewPart,
				}
				c.Literal = currPart.Literal[:i]
				// c's former module indices now live on splitOldPart.
				c.ModuleIndex = nil
				// Continue inserting the rest of the pattern below the new node.
				return addPatternRec(splitNewPart, parts[1:], nil)
			}

		case kmodpb.RadixNode_BYTE_RANGE:
			// Identical ranges share a node; different ranges fall through
			// to the next child (or to the append below).
			if c.StartByte == currPart.StartByte && c.EndByte == currPart.EndByte {
				if len(parts) == 1 {
					c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
				}
				return addPatternRec(c, parts[1:], nil)
			}
		case kmodpb.RadixNode_SINGLE_WILDCARD, kmodpb.RadixNode_WILDCARD:
			// Wildcards of the same type carry no data, so any existing one is
			// reused.
			if len(parts) == 1 {
				c.ModuleIndex = append(c.ModuleIndex, parts[0].ModuleIndex...)
			}
			return addPatternRec(c, parts[1:], nil)
		}
	}
	// No child or common prefix found, append node
	n.Children = append(n.Children, currPart)
	return addPatternRec(currPart, parts[1:], nil)
}
 |  | 
 | // PrintTree prints the tree from the given root node to standard out. | 
 | // The output is not stable and should only be used for debugging/diagnostics. | 
 | // It will log and exit the process if it encounters invalid nodes. | 
 | func PrintTree(r *kmodpb.RadixNode) { | 
 | 	printTree(r, 0, false) | 
 | } | 
 |  | 
 | func printTree(r *kmodpb.RadixNode, indent int, noIndent bool) { | 
 | 	if !noIndent { | 
 | 		for i := 0; i < indent; i++ { | 
 | 			fmt.Print("  ") | 
 | 		} | 
 | 	} | 
 | 	if len(r.ModuleIndex) > 0 { | 
 | 		fmt.Printf("%v ", r.ModuleIndex) | 
 | 	} | 
 | 	switch r.Type { | 
 | 	case kmodpb.RadixNode_LITERAL: | 
 | 		fmt.Printf("%q: ", r.Literal) | 
 | 	case kmodpb.RadixNode_SINGLE_WILDCARD: | 
 | 		fmt.Printf("?: ") | 
 | 	case kmodpb.RadixNode_WILDCARD: | 
 | 		fmt.Printf("*: ") | 
 | 	case kmodpb.RadixNode_BYTE_RANGE: | 
 | 		fmt.Printf("[%c-%c]: ", rune(r.StartByte), rune(r.EndByte)) | 
 | 	default: | 
 | 		log.Fatalf("Unknown tree type %T\n", r) | 
 | 	} | 
 | 	if len(r.Children) == 1 { | 
 | 		printTree(r.Children[0], indent, true) | 
 | 		return | 
 | 	} | 
 | 	fmt.Println("") | 
 | 	for _, c := range r.Children { | 
 | 		printTree(c, indent+1, false) | 
 | 	} | 
 | } | 
 |  | 
 | // parsePattern parses a string pattern into a non-hierarchical list of | 
 | // RadixNodes. These nodes can then be futher modified and integrated into | 
 | // a Radix tree. | 
 | func parsePattern(pattern string) ([]*kmodpb.RadixNode, error) { | 
 | 	var out []*kmodpb.RadixNode | 
 | 	var i int | 
 | 	var currentLiteral strings.Builder | 
 | 	storeCurrentLiteral := func() { | 
 | 		if currentLiteral.Len() > 0 { | 
 | 			out = append(out, &kmodpb.RadixNode{ | 
 | 				Type:    kmodpb.RadixNode_LITERAL, | 
 | 				Literal: currentLiteral.String(), | 
 | 			}) | 
 | 			currentLiteral.Reset() | 
 | 		} | 
 | 	} | 
 | 	for i < len(pattern) { | 
 | 		switch pattern[i] { | 
 | 		case '*': | 
 | 			storeCurrentLiteral() | 
 | 			i += 1 | 
 | 			if len(out) > 0 && out[len(out)-1].Type == kmodpb.RadixNode_WILDCARD { | 
 | 				continue | 
 | 			} | 
 | 			out = append(out, &kmodpb.RadixNode{ | 
 | 				Type: kmodpb.RadixNode_WILDCARD, | 
 | 			}) | 
 | 		case '?': | 
 | 			storeCurrentLiteral() | 
 | 			out = append(out, &kmodpb.RadixNode{ | 
 | 				Type: kmodpb.RadixNode_SINGLE_WILDCARD, | 
 | 			}) | 
 | 			i += 1 | 
 | 		case '[': | 
 | 			storeCurrentLiteral() | 
 | 			if len(pattern) <= i+4 { | 
 | 				return nil, errors.New("illegal byte range notation, not enough characters") | 
 | 			} | 
 | 			if pattern[i+2] != '-' || pattern[i+4] != ']' { | 
 | 				return nil, errors.New("illegal byte range notation, incorrect dash or closing character") | 
 | 			} | 
 | 			nn := &kmodpb.RadixNode{ | 
 | 				Type:      kmodpb.RadixNode_BYTE_RANGE, | 
 | 				StartByte: uint32(pattern[i+1]), | 
 | 				EndByte:   uint32(pattern[i+3]), | 
 | 			} | 
 | 			if nn.StartByte > nn.EndByte { | 
 | 				return nil, errors.New("byte range start byte larger than end byte") | 
 | 			} | 
 | 			out = append(out, nn) | 
 | 			i += 5 | 
 | 		case '\\': | 
 | 			if len(pattern) <= i+1 { | 
 | 				return nil, errors.New("illegal escape character at the end of the string") | 
 | 			} | 
 | 			currentLiteral.WriteByte(pattern[i+1]) | 
 | 			i += 2 | 
 | 		default: | 
 | 			currentLiteral.WriteByte(pattern[i]) | 
 | 			i += 1 | 
 | 		} | 
 | 	} | 
 | 	storeCurrentLiteral() | 
 | 	return out, nil | 
 | } |