Skip to content

Commit

Permalink
Merge pull request #4 from Skyxim/meta
Browse files Browse the repository at this point in the history
Feature:Supported Rule-Set
  • Loading branch information
Clash-Mini authored Dec 2, 2021
2 parents 53eb3f1 + c6f9230 commit 6369921
Show file tree
Hide file tree
Showing 13 changed files with 1,012 additions and 16 deletions.
44 changes: 44 additions & 0 deletions component/trie/ipcidr_node.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package trie

import "errors"

var (
ErrorOverMaxValue = errors.New("the value don't over max value")
)

type IpCidrNode struct {
Mark bool
child map[uint32]*IpCidrNode
maxValue uint32
}

func NewIpCidrNode(mark bool, maxValue uint32) *IpCidrNode {
ipCidrNode := &IpCidrNode{
Mark: mark,
child: map[uint32]*IpCidrNode{},
maxValue: maxValue,
}

return ipCidrNode
}

func (n *IpCidrNode) addChild(value uint32) error {
if value > n.maxValue {
return ErrorOverMaxValue
}

n.child[value] = NewIpCidrNode(false, n.maxValue)
return nil
}

func (n *IpCidrNode) hasChild(value uint32) bool {
return n.getChild(value) != nil
}

func (n *IpCidrNode) getChild(value uint32) *IpCidrNode {
if value <= n.maxValue {
return n.child[value]
}

return nil
}
255 changes: 255 additions & 0 deletions component/trie/ipcidr_trie.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
package trie

import (
"github.com/Dreamacro/clash/log"
"net"
)

type IPV6 bool

const (
ipv4GroupMaxValue = 0xFF
ipv6GroupMaxValue = 0xFFFF
)

type IpCidrTrie struct {
ipv4Trie *IpCidrNode
ipv6Trie *IpCidrNode
}

func NewIpCidrTrie() *IpCidrTrie {
return &IpCidrTrie{
ipv4Trie: NewIpCidrNode(false, ipv4GroupMaxValue),
ipv6Trie: NewIpCidrNode(false, ipv6GroupMaxValue),
}
}

func (trie *IpCidrTrie) AddIpCidr(ipCidr *net.IPNet) error {
subIpCidr, subCidr, isIpv4, err := ipCidrToSubIpCidr(ipCidr)
if err != nil {
return err
}

for _, sub := range subIpCidr {
addIpCidr(trie, isIpv4, sub, subCidr/8)
}

return nil
}

func (trie *IpCidrTrie) AddIpCidrForString(ipCidr string) error {
_, ipNet, err := net.ParseCIDR(ipCidr)
if err != nil {
return err
}

return trie.AddIpCidr(ipNet)
}

func (trie *IpCidrTrie) IsContain(ip net.IP) bool {
ip, isIpv4 := checkAndConverterIp(ip)
if ip == nil {
return false
}

var groupValues []uint32
var ipCidrNode *IpCidrNode

if isIpv4 {
ipCidrNode = trie.ipv4Trie
for _, group := range ip {
groupValues = append(groupValues, uint32(group))
}
} else {
ipCidrNode = trie.ipv6Trie
for i := 0; i < len(ip); i += 2 {
groupValues = append(groupValues, getIpv6GroupValue(ip[i], ip[i+1]))
}
}

return search(ipCidrNode, groupValues) != nil
}

func (trie *IpCidrTrie) IsContainForString(ipString string) bool {
return trie.IsContain(net.ParseIP(ipString))
}

func ipCidrToSubIpCidr(ipNet *net.IPNet) ([]net.IP, int, bool, error) {
maskSize, _ := ipNet.Mask.Size()
var (
ipList []net.IP
newMaskSize int
isIpv4 bool
err error
)

ip, isIpv4 := checkAndConverterIp(ipNet.IP)
ipList, newMaskSize, err = subIpCidr(ip, maskSize, isIpv4)

return ipList, newMaskSize, isIpv4, err
}

func subIpCidr(ip net.IP, maskSize int, isIpv4 bool) ([]net.IP, int, error) {
var subIpCidrList []net.IP
groupSize := 8
if !isIpv4 {
groupSize = 16
}

if maskSize%groupSize == 0 {
return append(subIpCidrList, ip), maskSize, nil
}

lastByteMaskSize := maskSize % 8
lastByteMaskIndex := maskSize / 8
subIpCidrNum := 0xFF >> lastByteMaskSize
for i := 0; i < subIpCidrNum; i++ {
subIpCidr := make([]byte, len(ip))
copy(subIpCidr, ip)
subIpCidr[lastByteMaskIndex] += byte(i)
subIpCidrList = append(subIpCidrList, subIpCidr)
}

newMaskSize := (lastByteMaskIndex + 1) * 8
if !isIpv4 {
newMaskSize = (lastByteMaskIndex/2 + 1) * 16
}

return subIpCidrList, newMaskSize, nil
}

func addIpCidr(trie *IpCidrTrie, isIpv4 bool, ip net.IP, groupSize int) {
if isIpv4 {
addIpv4Cidr(trie, ip, groupSize)
} else {
addIpv6Cidr(trie, ip, groupSize)
}
}

func addIpv4Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) {
preNode := trie.ipv4Trie
node := preNode.getChild(uint32(ip[0]))
if node == nil {
err := preNode.addChild(uint32(ip[0]))
if err != nil {
return
}

node = preNode.getChild(uint32(ip[0]))
}

for i := 1; i < groupSize; i++ {
if node.Mark {
return
}

groupValue := uint32(ip[i])
if !node.hasChild(groupValue) {
err := node.addChild(groupValue)
if err != nil {
log.Errorln(err.Error())
}
}

preNode = node
node = node.getChild(groupValue)
if node == nil {
err := preNode.addChild(uint32(ip[i-1]))
if err != nil {
return
}

node = preNode.getChild(uint32(ip[i-1]))
}
}

node.Mark = true
cleanChild(node)
}

func addIpv6Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) {
preNode := trie.ipv6Trie
node := preNode.getChild(getIpv6GroupValue(ip[0], ip[1]))
if node == nil {
err := preNode.addChild(getIpv6GroupValue(ip[0], ip[1]))
if err != nil {
return
}

node = preNode.getChild(getIpv6GroupValue(ip[0], ip[1]))
}

for i := 2; i < groupSize; i += 2 {
if node.Mark {
return
}

groupValue := getIpv6GroupValue(ip[i], ip[i+1])
if !node.hasChild(groupValue) {
err := node.addChild(groupValue)
if err != nil {
log.Errorln(err.Error())
}
}

preNode = node
node = node.getChild(groupValue)
if node == nil {
err := preNode.addChild(getIpv6GroupValue(ip[i-2], ip[i-1]))
if err != nil {
return
}

node = preNode.getChild(getIpv6GroupValue(ip[i-2], ip[i-1]))
}
}

node.Mark = true
cleanChild(node)
}

func getIpv6GroupValue(high, low byte) uint32 {
return (uint32(high) << 8) | uint32(low)
}

func cleanChild(node *IpCidrNode) {
for i := uint32(0); i < uint32(len(node.child)); i++ {
delete(node.child, i)
}
}

func search(root *IpCidrNode, groupValues []uint32) *IpCidrNode {
node := root.getChild(groupValues[0])
if node == nil || node.Mark {
return node
}

for _, value := range groupValues[1:] {
if !node.hasChild(value) {
return nil
}

node = node.getChild(value)

if node == nil || node.Mark {
return node
}
}

return nil
}

// return net.IP To4 or To16 and is ipv4
func checkAndConverterIp(ip net.IP) (net.IP, bool) {
ipResult := ip.To4()
if ipResult == nil {
ipResult = ip.To16()
if ipResult == nil {
return nil, false
}

return ipResult, false
}

return ipResult, true
}
82 changes: 82 additions & 0 deletions component/trie/trie_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package trie

import (
"net"
"testing"
)
import "github.com/stretchr/testify/assert"

func TestIpv4AddSuccess(t *testing.T) {
trie := NewIpCidrTrie()
err := trie.AddIpCidrForString("10.0.0.2/16")
assert.Equal(t, nil, err)
}

func TestIpv4AddFail(t *testing.T) {
trie := NewIpCidrTrie()
err := trie.AddIpCidrForString("333.00.23.2/23")
assert.IsType(t, new(net.ParseError), err)

err = trie.AddIpCidrForString("22.3.34.2/222")
assert.IsType(t, new(net.ParseError), err)

err = trie.AddIpCidrForString("2.2.2.2")
assert.IsType(t, new(net.ParseError), err)
}

func TestIpv4Search(t *testing.T) {
trie := NewIpCidrTrie()
assert.NoError(t, trie.AddIpCidrForString("129.2.36.0/16"))
assert.NoError(t, trie.AddIpCidrForString("10.2.36.0/18"))
assert.NoError(t, trie.AddIpCidrForString("16.2.23.0/24"))
assert.NoError(t, trie.AddIpCidrForString("11.2.13.2/26"))
assert.NoError(t, trie.AddIpCidrForString("55.5.6.3/8"))
assert.NoError(t, trie.AddIpCidrForString("66.23.25.4/6"))
assert.Equal(t, true, trie.IsContainForString("129.2.3.65"))
assert.Equal(t, false, trie.IsContainForString("15.2.3.1"))
assert.Equal(t, true, trie.IsContainForString("11.2.13.1"))
assert.Equal(t, true, trie.IsContainForString("55.0.0.0"))
assert.Equal(t, true, trie.IsContainForString("64.0.0.0"))
assert.Equal(t, false, trie.IsContainForString("128.0.0.0"))

assert.Equal(t, false, trie.IsContain(net.ParseIP("22")))
assert.Equal(t, false, trie.IsContain(net.ParseIP("")))
}

func TestIpv6AddSuccess(t *testing.T) {
trie := NewIpCidrTrie()
err := trie.AddIpCidrForString("2001:0db8:02de:0000:0000:0000:0000:0e13/32")
assert.Equal(t, nil, err)

err = trie.AddIpCidrForString("2001:1db8:f2de::0e13/18")
assert.Equal(t, nil, err)
}

func TestIpv6AddFail(t *testing.T) {
trie := NewIpCidrTrie()
err := trie.AddIpCidrForString("2001::25de::cade/23")
assert.IsType(t, new(net.ParseError), err)

err = trie.AddIpCidrForString("2001:0fa3:25de::cade/222")
assert.IsType(t, new(net.ParseError), err)

err = trie.AddIpCidrForString("2001:0fa3:25de::cade")
assert.IsType(t, new(net.ParseError), err)
}

func TestIpv6Search(t *testing.T) {
trie := NewIpCidrTrie()
assert.NoError(t, trie.AddIpCidrForString("2001:b28:f23d:f001::e/128"))
assert.NoError(t, trie.AddIpCidrForString("2001:67c:4e8:f002::e/12"))
assert.NoError(t, trie.AddIpCidrForString("2001:b28:f23d:f003::e/96"))
assert.NoError(t, trie.AddIpCidrForString("2001:67c:4e8:f002::a/32"))
assert.NoError(t, trie.AddIpCidrForString("2001:67c:4e8:f004::a/60"))
assert.NoError(t, trie.AddIpCidrForString("2001:b28:f23f:f005::a/64"))
assert.Equal(t, true, trie.IsContainForString("2001:b28:f23d:f001::e"))
assert.Equal(t, false, trie.IsContainForString("2222::fff2"))
assert.Equal(t, true, trie.IsContainForString("2000::ffa0"))
assert.Equal(t, true, trie.IsContainForString("2001:b28:f23f:f005:5662::"))
assert.Equal(t, true, trie.IsContainForString("2001:67c:4e8:9666::1213"))

assert.Equal(t, false, trie.IsContain(net.ParseIP("22233:22")))
}
Loading

0 comments on commit 6369921

Please sign in to comment.