blob: 7a815bad2fa413e1fe98be697d1a743db42b0814 [file] [log] [blame] [edit]
package kmod
import (
"fmt"
"regexp"
"strings"
"testing"
"unicode"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/testing/protocmp"
kmodpb "source.monogon.dev/metropolis/pkg/kmod/spec"
)
// TestParsePattern verifies that parsePattern splits a glob-style pattern
// into the expected sequence of radix nodes: plain literals, '*' wildcards,
// '?' single-character wildcards, "[a-b]" byte ranges, and backslash-escaped
// metacharacters (which become part of the surrounding literal).
func TestParsePattern(t *testing.T) {
	type testCase struct {
		name    string
		pattern string
		want    []*kmodpb.RadixNode
	}
	// lit builds a LITERAL node holding s.
	lit := func(s string) *kmodpb.RadixNode {
		return &kmodpb.RadixNode{Type: kmodpb.RadixNode_LITERAL, Literal: s}
	}

	tests := []testCase{
		{name: "Empty", pattern: "", want: nil},
		{name: "SingleLiteral", pattern: "asdf", want: []*kmodpb.RadixNode{lit("asdf")}},
		{
			name:    "SingleWildcard",
			pattern: "as*df",
			want: []*kmodpb.RadixNode{
				lit("as"),
				{Type: kmodpb.RadixNode_WILDCARD},
				lit("df"),
			},
		},
		{name: "EscapedWildcard", pattern: `a\*`, want: []*kmodpb.RadixNode{lit("a*")}},
		{
			name:    "SingleRange",
			pattern: "[y-z]",
			want: []*kmodpb.RadixNode{
				{Type: kmodpb.RadixNode_BYTE_RANGE, StartByte: 'y', EndByte: 'z'},
			},
		},
		{
			name:    "SingleWildcardChar",
			pattern: "a?c",
			want: []*kmodpb.RadixNode{
				lit("a"),
				{Type: kmodpb.RadixNode_SINGLE_WILDCARD},
				lit("c"),
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := parsePattern(tc.pattern)
			if err != nil {
				t.Fatal(err)
			}
			// protocmp.Transform makes cmp compare proto messages by content.
			if diff := cmp.Diff(tc.want, got, protocmp.Transform()); diff != "" {
				t.Error(diff)
			}
		})
	}
}
// TestLookupComplex builds a radix tree from several realistic modalias
// patterns and checks that lookups report exactly the expected pattern IDs.
// The tree holds two USB patterns that share a long common prefix
// ("usb:v0B95p17") and one ACPI pattern. For each looked-up value, every
// inserted ID must be present in the result if and only if it is expected.
func TestLookupComplex(t *testing.T) {
	root := &kmodpb.RadixNode{
		Type: kmodpb.RadixNode_LITERAL,
	}
	patterns := []struct {
		pattern string
		id      uint32
	}{
		{"usb:v0B95p1790d*dc*dsc*dp*icFFiscFFip00in*", 2},
		{"usb:v0B95p178Ad*dc*dsc*dp*icFFiscFFip00in*", 3},
		{"acpi*:PNP0C14:*", 10},
	}
	for _, p := range patterns {
		// A failed insertion leaves the tree incomplete, so every lookup
		// below would fail for a misleading reason; stop immediately.
		if err := AddPattern(root, p.pattern, p.id); err != nil {
			t.Fatalf("AddPattern(%q, %d): %v", p.pattern, p.id, err)
		}
	}

	lookups := []struct {
		value string
		want  []uint32
	}{
		{"acpi:PNP0C14:asdf", []uint32{10}},
		{"usb:v0B95p1790d0dc0dsc0dp0icFFiscFFip00in", []uint32{2}},
		{"usb:v0B95p178Ad0dc0dsc0dp0icFFiscFFip00in", []uint32{3}},
		{"pci:v00008086d00001234", nil},
	}
	for _, l := range lookups {
		matches := make(map[uint32]bool)
		lookupModulesRec(root, l.value, matches)

		want := make(map[uint32]bool, len(l.want))
		for _, id := range l.want {
			want[id] = true
		}
		// Check every inserted ID in both directions so that spurious
		// matches are caught as well as missing ones.
		for _, p := range patterns {
			if got := matches[p.id]; got != want[p.id] {
				t.Errorf("lookup of %q: pattern %d (%q) matched = %v, want %v",
					l.value, p.id, p.pattern, got, want[p.id])
			}
		}
	}
}
// isASCII reports whether s consists solely of 7-bit ASCII characters.
// Ranging over a string decodes UTF-8. Any byte >= 0x80 therefore yields
// either a rune above unicode.MaxASCII or, for invalid encodings,
// utf8.RuneError (U+FFFD), so such strings are rejected in both cases.
func isASCII(s string) bool {
	for _, r := range s {
		if r > unicode.MaxASCII {
			return false
		}
	}
	return true
}
// FuzzRadixImpl cross-checks the radix-tree matcher against Go's regexp
// engine.
//
// Both fuzz arguments are NUL-separated lists: a holds glob patterns and b
// holds values to look up. Each pattern is inserted into a radix tree under
// its list index. Separately, it is translated into an anchored regular
// expression built from the nodes that parsePattern returns. For every value,
// lookupModulesRec must report a pattern ID exactly when that pattern's
// regexp matches the value.
func FuzzRadixImpl(f *testing.F) {
	f.Add("acpi*:PNP0C14:*\x00usb:v0B95p1790d*dc*dsc*dp*icFFiscFFip00in*", "acpi:PNP0C14:asdf\x00usb:v0B95p1790d0dc0dsc0dp0icFFiscFFip00in")
	f.Fuzz(func(t *testing.T, a string, b string) {
		patternsRaw := strings.Split(a, "\x00")
		values := strings.Split(b, "\x00")
		patternsRegexp := make([]*regexp.Regexp, 0, len(patternsRaw))
		root := &kmodpb.RadixNode{
			Type: kmodpb.RadixNode_LITERAL,
		}
		for i, p := range patternsRaw {
			// The whole input is discarded, rather than skipping just this
			// pattern, so that pattern IDs stay equal to their index in
			// patternsRegexp.
			if !isASCII(p) {
				// Ignore non-ASCII patterns, there are tons of edge cases with them.
				return
			}
			pp, err := parsePattern(p)
			if err != nil {
				// Bad pattern.
				return
			}
			if err := AddPattern(root, p, uint32(i)); err != nil {
				t.Fatal(err)
			}

			var regexb strings.Builder
			// (?s) lets '.' also match '\n'; ^...$ anchors the whole value.
			regexb.WriteString("(?s)^")
			for _, part := range pp {
				switch part.Type {
				case kmodpb.RadixNode_LITERAL:
					regexb.WriteString(regexp.QuoteMeta(part.Literal))
				case kmodpb.RadixNode_SINGLE_WILDCARD:
					regexb.WriteString(".")
				case kmodpb.RadixNode_WILDCARD:
					regexb.WriteString(".*")
				case kmodpb.RadixNode_BYTE_RANGE:
					// Hex escapes sidestep any quoting of class
					// metacharacters ('-', '^', ']', '\') in the bounds.
					fmt.Fprintf(&regexb, `[\x{%x}-\x{%x}]`, part.StartByte, part.EndByte)
				default:
					// The reference regexp would be incomplete, so any
					// comparison against it would be meaningless.
					t.Fatalf("Unknown node type %v", part.Type)
				}
			}
			regexb.WriteString("$")
			re, err := regexp.Compile(regexb.String())
			if err != nil {
				t.Fatalf("pattern %q produced invalid reference regexp %q: %v", p, regexb.String(), err)
			}
			patternsRegexp = append(patternsRegexp, re)
		}

		for _, v := range values {
			if !isASCII(v) {
				// Ignore non-ASCII values, but keep checking the rest.
				continue
			}
			if len(v) > 64 {
				// Ignore big values as they are not realistic and cause the
				// wildcard matches to be very expensive.
				continue
			}
			radixMatchesSet := make(map[uint32]bool)
			lookupModulesRec(root, v, radixMatchesSet)
			for i, re := range patternsRegexp {
				want := re.MatchString(v)
				got := radixMatchesSet[uint32(i)]
				switch {
				case want && !got:
					t.Errorf("Pattern %q is expected to match %q but didn't", patternsRaw[i], v)
				case !want && got:
					t.Errorf("Pattern %q is not expected to match %q but did", patternsRaw[i], v)
				}
			}
		}
	})
}