From 6081e6a3ff488990aab3fa349a84521cd495a1b3 Mon Sep 17 00:00:00 2001 From: Nedyalko Andreev Date: Mon, 16 Nov 2020 13:51:34 +0200 Subject: [PATCH] Block overlapping sub-domains properly --- lib/types/hostnametrie.go | 39 ++++++++++++++++++++++------------ lib/types/hostnametrie_test.go | 9 ++++++-- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/lib/types/hostnametrie.go b/lib/types/hostnametrie.go index 09240f11933..006bfad8fb1 100644 --- a/lib/types/hostnametrie.go +++ b/lib/types/hostnametrie.go @@ -83,9 +83,8 @@ func (d NullHostnameTrie) MarshalJSON() ([]byte, error) { // for wildcards exclusively at the start of the pattern. Items may only // be inserted and searched. Internationalized hostnames are valid. type HostnameTrie struct { + *trieNode source []string - - children map[rune]*HostnameTrie } // NewNullHostnameTrie returns a NullHostnameTrie encapsulating HostnameTrie or an error if the @@ -105,6 +104,10 @@ func NewNullHostnameTrie(source []string) (NullHostnameTrie, error) { func NewHostnameTrie(source []string) (*HostnameTrie, error) { h := &HostnameTrie{ source: source, + trieNode: &trieNode{ + isLeaf: false, + children: make(map[rune]*trieNode), + }, } for _, s := range h.source { if err := h.insert(s); err != nil { @@ -135,42 +138,52 @@ func (t *HostnameTrie) insert(s string) error { return err } - return t.childInsert(s) + return t.trieNode.insert(s) +} + +// Contains returns whether s matches a pattern in the HostnameTrie +// along with the matching pattern, if one was found. +func (t *HostnameTrie) Contains(s string) (matchedPattern string, matchFound bool) { + return t.trieNode.contains(s) } -func (t *HostnameTrie) childInsert(s string) error { +type trieNode struct { + isLeaf bool + children map[rune]*trieNode +} + +func (t *trieNode) insert(s string) error { if len(s) == 0 { + t.isLeaf = true return nil } // mask creation of the trie by initializing the root here if t.children == nil { - t.children = make(map[rune]*HostnameTrie) + t.children = make(map[rune]*trieNode) } rStr := []rune(s) // need to iterate by runes for intl' names last := len(rStr) - 1 if c, ok := t.children[rStr[last]]; ok { - return c.childInsert(string(rStr[:last])) + return c.insert(string(rStr[:last])) } - t.children[rStr[last]] = &HostnameTrie{children: make(map[rune]*HostnameTrie)} - return t.children[rStr[last]].childInsert(string(rStr[:last])) + t.children[rStr[last]] = &trieNode{children: make(map[rune]*trieNode)} + return t.children[rStr[last]].insert(string(rStr[:last])) } -// Contains returns whether s matches a pattern in the HostnameTrie -// along with the matching pattern, if one was found. -func (t *HostnameTrie) Contains(s string) (matchedPattern string, matchFound bool) { +func (t *trieNode) contains(s string) (matchedPattern string, matchFound bool) { s = strings.ToLower(s) if len(s) == 0 { - if len(t.children) == 0 { + if t.isLeaf { return "", true } } else { rStr := []rune(s) last := len(rStr) - 1 if c, ok := t.children[rStr[last]]; ok { - if match, matched := c.Contains(string(rStr[:last])); matched { + if match, matched := c.contains(string(rStr[:last])); matched { return match + string(rStr[last]), true } } diff --git a/lib/types/hostnametrie_test.go b/lib/types/hostnametrie_test.go index c9eecd69992..a6be5ded1a2 100644 --- a/lib/types/hostnametrie_test.go +++ b/lib/types/hostnametrie_test.go @@ -28,21 +28,26 @@ import ( ) func TestHostnameTrieInsert(t *testing.T) { - hostnames := HostnameTrie{} + hostnames, err := NewHostnameTrie([]string{"foo.bar"}) + assert.NoError(t, err) assert.NoError(t, hostnames.insert("test.k6.io")) assert.Error(t, hostnames.insert("inval*d.pattern")) assert.NoError(t, hostnames.insert("*valid.pattern")) } func TestHostnameTrieContains(t *testing.T) { - trie, err := NewHostnameTrie([]string{"test.k6.io", "*valid.pattern"}) + trie, err := NewHostnameTrie([]string{"sub.test.k6.io", "test.k6.io", "*valid.pattern", "sub.valid.pattern"}) require.NoError(t, err) cases := map[string]string{ "K6.Io": "", "tEsT.k6.Io": "test.k6.io", "TESt.K6.IO": "test.k6.io", + "sub.test.k6.io": "sub.test.k6.io", + "sub.sub.test.k6.io": "", "blocked.valId.paTtern": "*valid.pattern", "valId.paTtern": "*valid.pattern", + "sub.valid.pattern": "sub.valid.pattern", // use the most specific blocker + "www.sub.valid.pattern": "*valid.pattern", "example.test.k6.io": "", } for key, value := range cases {