Skip to content

Commit

Permalink
chore: cleanup dns policy match code
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Aug 15, 2024
1 parent 4c10d42 commit d48db2c
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 400 deletions.
270 changes: 125 additions & 145 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ import (
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/auth"
"github.com/metacubex/mihomo/component/cidr"
"github.com/metacubex/mihomo/component/fakeip"
"github.com/metacubex/mihomo/component/geodata"
"github.com/metacubex/mihomo/component/geodata/router"
P "github.com/metacubex/mihomo/component/process"
"github.com/metacubex/mihomo/component/resolver"
SNIFF "github.com/metacubex/mihomo/component/sniffer"
Expand Down Expand Up @@ -114,33 +114,25 @@ type NTP struct {

// DNS config
type DNS struct {
Enable bool `yaml:"enable"`
PreferH3 bool `yaml:"prefer-h3"`
IPv6 bool `yaml:"ipv6"`
IPv6Timeout uint `yaml:"ipv6-timeout"`
UseSystemHosts bool `yaml:"use-system-hosts"`
NameServer []dns.NameServer `yaml:"nameserver"`
Fallback []dns.NameServer `yaml:"fallback"`
FallbackFilter FallbackFilter `yaml:"fallback-filter"`
Listen string `yaml:"listen"`
EnhancedMode C.DNSMode `yaml:"enhanced-mode"`
DefaultNameserver []dns.NameServer `yaml:"default-nameserver"`
CacheAlgorithm string `yaml:"cache-algorithm"`
Enable bool
PreferH3 bool
IPv6 bool
IPv6Timeout uint
UseSystemHosts bool
NameServer []dns.NameServer
Fallback []dns.NameServer
FallbackIPFilter []C.Rule
FallbackDomainFilter []C.Rule
Listen string
EnhancedMode C.DNSMode
DefaultNameserver []dns.NameServer
CacheAlgorithm string
FakeIPRange *fakeip.Pool
Hosts *trie.DomainTrie[resolver.HostValue]
NameServerPolicy *orderedmap.OrderedMap[string, []dns.NameServer]
NameServerPolicy []dns.Policy
ProxyServerNameserver []dns.NameServer
}

// FallbackFilter config
type FallbackFilter struct {
GeoIP bool `yaml:"geoip"`
GeoIPCode string `yaml:"geoip-code"`
IPCIDR []netip.Prefix `yaml:"ipcidr"`
Domain []string `yaml:"domain"`
GeoSite []router.DomainMatcher `yaml:"geosite"`
}

// Profile config
type Profile struct {
StoreSelected bool `yaml:"store-selected"`
Expand Down Expand Up @@ -1205,125 +1197,81 @@ func parsePureDNSServer(server string) string {
}
}

func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], ruleProviders map[string]providerTypes.RuleProvider, respectRules bool, preferH3 bool) (*orderedmap.OrderedMap[string, []dns.NameServer], error) {
policy := orderedmap.New[string, []dns.NameServer]()
updatedPolicy := orderedmap.New[string, any]()
func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rules []C.Rule, ruleProviders map[string]providerTypes.RuleProvider, respectRules bool, preferH3 bool) ([]dns.Policy, error) {
var tmpPolicy []dns.Policy
re := regexp.MustCompile(`[a-zA-Z0-9\-]+\.[a-zA-Z]{2,}(\.[a-zA-Z]{2,})?`)

for pair := nsPolicy.Oldest(); pair != nil; pair = pair.Next() {
k, v := pair.Key, pair.Value
servers, err := utils.ToStringSlice(v)
if err != nil {
return nil, err
}
nameservers, err := parseNameServer(servers, respectRules, preferH3)
if err != nil {
return nil, err
}
if strings.Contains(strings.ToLower(k), ",") {
if strings.Contains(k, "geosite:") {
subkeys := strings.Split(k, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
for _, subkey := range subkeys {
newKey := "geosite:" + subkey
updatedPolicy.Store(newKey, v)
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: newKey, NameServers: nameservers})
}
} else if strings.Contains(strings.ToLower(k), "rule-set:") {
subkeys := strings.Split(k, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
for _, subkey := range subkeys {
newKey := "rule-set:" + subkey
updatedPolicy.Store(newKey, v)
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: newKey, NameServers: nameservers})
}
} else if re.MatchString(k) {
subkeys := strings.Split(k, ",")
for _, subkey := range subkeys {
updatedPolicy.Store(subkey, v)
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: subkey, NameServers: nameservers})
}
}
} else {
if strings.Contains(strings.ToLower(k), "geosite:") {
updatedPolicy.Store("geosite:"+k[8:], v)
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: "geosite:" + k[8:], NameServers: nameservers})
} else if strings.Contains(strings.ToLower(k), "rule-set:") {
updatedPolicy.Store("rule-set:"+k[9:], v)
}
updatedPolicy.Store(k, v)
}
}

for pair := updatedPolicy.Oldest(); pair != nil; pair = pair.Next() {
domain, server := pair.Key, pair.Value
servers, err := utils.ToStringSlice(server)
if err != nil {
return nil, err
}
nameservers, err := parseNameServer(servers, respectRules, preferH3)
if err != nil {
return nil, err
}
if _, valid := trie.ValidAndSplitDomain(domain); !valid {
return nil, fmt.Errorf("DNS ResoverRule invalid domain: %s", domain)
}
if strings.HasPrefix(domain, "rule-set:") {
domainSetName := domain[9:]
if provider, ok := ruleProviders[domainSetName]; !ok {
return nil, fmt.Errorf("not found rule-set: %s", domainSetName)
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: "rule-set:" + k[9:], NameServers: nameservers})
} else {
switch provider.Behavior() {
case providerTypes.IPCIDR:
return nil, fmt.Errorf("rule provider type error, except domain,actual %s", provider.Behavior())
case providerTypes.Classical:
log.Warnln("%s provider is %s, only matching it contain domain rule", provider.Name(), provider.Behavior())
}
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: k, NameServers: nameservers})
}
}
policy.Store(domain, nameservers)
}

return policy, nil
}

func parseFallbackIPCIDR(ips []string) ([]netip.Prefix, error) {
var ipNets []netip.Prefix

for idx, ip := range ips {
ipnet, err := netip.ParsePrefix(ip)
if err != nil {
return nil, fmt.Errorf("DNS FallbackIP[%d] format error: %s", idx, err.Error())
}
ipNets = append(ipNets, ipnet)
}

return ipNets, nil
}

func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]router.DomainMatcher, error) {
var sites []router.DomainMatcher
if len(countries) > 0 {
if err := geodata.InitGeoSite(); err != nil {
return nil, fmt.Errorf("can't initial GeoSite: %s", err)
}
log.Warnln("replace fallback-filter.geosite with nameserver-policy, it will be removed in the future")
}
var policy []dns.Policy
for _, p := range tmpPolicy {
domain, nameservers := p.Domain, p.NameServers

for _, country := range countries {
found := false
for _, rule := range rules {
if rule.RuleType() == C.GEOSITE {
if strings.EqualFold(country, rule.Payload()) {
found = true
sites = append(sites, rule.(C.RuleGeoSite).GetDomainMatcher())
log.Infoln("Start initial GeoSite dns fallback filter from rule `%s`", country)
}
if strings.HasPrefix(domain, "rule-set:") {
domainSetName := domain[9:]
rule, err := parseDomainRuleSet(domainSetName, ruleProviders)
if err != nil {
return nil, err
}
}

if !found {
matcher, recordsCount, err := geodata.LoadGeoSiteMatcher(country)
policy = append(policy, dns.Policy{Rule: rule, NameServers: nameservers})
} else if strings.HasPrefix(domain, "geosite:") {
country := domain[8:]
rule, err := parseGEOSITE(country, rules)
if err != nil {
return nil, err
}

sites = append(sites, matcher)

log.Infoln("Start initial GeoSite dns fallback filter `%s`, records: %d", country, recordsCount)
policy = append(policy, dns.Policy{Rule: rule, NameServers: nameservers})
} else {
if _, valid := trie.ValidAndSplitDomain(domain); !valid {
return nil, fmt.Errorf("DNS ResoverRule invalid domain: %s", domain)
}
policy = append(policy, dns.Policy{Domain: domain, NameServers: nameservers})
}
}
return sites, nil

return policy, nil
}

func paresNTP(rawCfg *RawConfig) *NTP {
Expand Down Expand Up @@ -1357,10 +1305,6 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
IPv6: cfg.IPv6,
UseSystemHosts: cfg.UseSystemHosts,
EnhancedMode: cfg.EnhancedMode,
FallbackFilter: FallbackFilter{
IPCIDR: []netip.Prefix{},
GeoSite: []router.DomainMatcher{},
},
}
var err error
if dnsCfg.NameServer, err = parseNameServer(cfg.NameServer, cfg.RespectRules, cfg.PreferH3); err != nil {
Expand All @@ -1371,7 +1315,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
return nil, err
}

if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, ruleProviders, cfg.RespectRules, cfg.PreferH3); err != nil {
if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, rules, ruleProviders, cfg.RespectRules, cfg.PreferH3); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1438,18 +1382,51 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
dnsCfg.FakeIPRange = pool
}

var rule C.Rule
if len(cfg.Fallback) != 0 {
dnsCfg.FallbackFilter.GeoIP = cfg.FallbackFilter.GeoIP
dnsCfg.FallbackFilter.GeoIPCode = cfg.FallbackFilter.GeoIPCode
if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil {
dnsCfg.FallbackFilter.IPCIDR = fallbackip
if cfg.FallbackFilter.GeoIP {
rule, err = RC.NewGEOIP(cfg.FallbackFilter.GeoIPCode, "", false, true)
if err != nil {
return nil, fmt.Errorf("load GeoIP dns fallback filter error, %w", err)
}
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
}
dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain
fallbackGeoSite, err := parseFallbackGeoSite(cfg.FallbackFilter.GeoSite, rules)
if err != nil {
return nil, fmt.Errorf("load GeoSite dns fallback filter error, %w", err)
if len(cfg.FallbackFilter.IPCIDR) > 0 {
cidrSet := cidr.NewIpCidrSet()
for idx, ipcidr := range cfg.FallbackFilter.IPCIDR {
err = cidrSet.AddIpCidrForString(ipcidr)
if err != nil {
return nil, fmt.Errorf("DNS FallbackIP[%d] format error: %w", idx, err)
}
}
err = cidrSet.Merge()
if err != nil {
return nil, err
}
rule = RP.NewIpCidrSet(cidrSet, "")
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
}
if len(cfg.FallbackFilter.Domain) > 0 {
domainTrie := trie.New[struct{}]()
for idx, domain := range cfg.FallbackFilter.Domain {
err = domainTrie.Insert(domain, struct{}{})
if err != nil {
return nil, fmt.Errorf("DNS FallbackDomain[%d] format error: %w", idx, err)
}
}
rule = RP.NewDomainSet(domainTrie.NewDomainSet(), "")
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
}
if len(cfg.FallbackFilter.GeoSite) > 0 {
log.Warnln("replace fallback-filter.geosite with nameserver-policy, it will be removed in the future")
for idx, geoSite := range cfg.FallbackFilter.GeoSite {
rule, err = parseGEOSITE(geoSite, rules)
if err != nil {
return nil, fmt.Errorf("DNS FallbackGeosite[%d] format error: %w", idx, err)
}
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
}
}
dnsCfg.FallbackFilter.GeoSite = fallbackGeoSite
}

if cfg.UseHosts {
Expand Down Expand Up @@ -1636,44 +1613,21 @@ func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], rules
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
for _, country := range subkeys {
found := false
for _, rule = range rules {
if rule.RuleType() == C.GEOSITE {
if strings.EqualFold(country, rule.Payload()) {
found = true
domainRules = append(domainRules, rule)
}
}
}
if !found {
rule, err = RC.NewGEOSITE(country, "")
if err != nil {
return nil, err
}
domainRules = append(domainRules, rule)
rule, err = parseGEOSITE(country, rules)
if err != nil {
return nil, err
}
domainRules = append(domainRules, rule)
}
} else if strings.Contains(domainLower, "rule-set:") {
subkeys := strings.Split(domain, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
for _, domainSetName := range subkeys {
if rp, ok := ruleProviders[domainSetName]; !ok {
return nil, fmt.Errorf("not found rule-set: %s", domainSetName)
} else {
switch rp.Behavior() {
case providerTypes.IPCIDR:
return nil, fmt.Errorf("rule provider type error, except domain,actual %s", rp.Behavior())
case providerTypes.Classical:
log.Warnln("%s provider is %s, only matching it contain domain rule", rp.Name(), rp.Behavior())
default:
}
}
rule, err = RP.NewRuleSet(domainSetName, "", true)
rule, err = parseDomainRuleSet(domainSetName, ruleProviders)
if err != nil {
return nil, err
}

domainRules = append(domainRules, rule)
}
} else {
Expand All @@ -1692,3 +1646,29 @@ func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], rules
}
return
}

func parseDomainRuleSet(domainSetName string, ruleProviders map[string]providerTypes.RuleProvider) (C.Rule, error) {
if rp, ok := ruleProviders[domainSetName]; !ok {
return nil, fmt.Errorf("not found rule-set: %s", domainSetName)
} else {
switch rp.Behavior() {
case providerTypes.IPCIDR:
return nil, fmt.Errorf("rule provider type error, except domain,actual %s", rp.Behavior())
case providerTypes.Classical:
log.Warnln("%s provider is %s, only matching it contain domain rule", rp.Name(), rp.Behavior())
default:
}
}
return RP.NewRuleSet(domainSetName, "", true)
}

func parseGEOSITE(country string, rules []C.Rule) (C.Rule, error) {
for _, rule := range rules {
if rule.RuleType() == C.GEOSITE {
if strings.EqualFold(country, rule.Payload()) {
return rule, nil
}
}
}
return RC.NewGEOSITE(country, "")
}
Loading

0 comments on commit d48db2c

Please sign in to comment.