Skip to content

Commit

Permalink
feat: port, network rule match condition support
Browse files Browse the repository at this point in the history
  • Loading branch information
wintbiit committed Dec 11, 2023
1 parent c673864 commit 163c33e
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 20 deletions.
39 changes: 38 additions & 1 deletion model/config.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package model

import (
"strconv"
"strings"
)

type Config struct {
Addr string `json:"addr"`
Debug bool `json:"debug"`
Expand All @@ -21,5 +26,37 @@ type Domain struct {
}

type Rule struct {
CIDRs []string `json:"cidrs"`
CIDRs []string `json:"cidrs"`
Ports []PortRule `json:"ports"`
Types []string `json:"types"`
}

type PortRule string

func (p *PortRule) Contains(port int) bool {
if strings.Count(string(*p), "-") == 1 {
ports := strings.Split(string(*p), "-")
if len(ports) != 2 {
return false
}

start, err := strconv.Atoi(ports[0])
if err != nil {
return false
}

end, err := strconv.Atoi(ports[1])
if err != nil {
return false
}

return port >= start && port <= end
}

por, err := strconv.Atoi(string(*p))
if err == nil {
return por == port
}

return false
}
42 changes: 40 additions & 2 deletions server/ruleset.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,60 @@ func (s *Server) newRuleSet(name string, rule model.Rule) *RuleSet {
Server: s,
Name: name,
l: s.l.Named(name),
cidrs: make([]*net.IPNet, len(rule.CIDRs)),
}

// Init CIDR rules
for _, cidr := range ruleSet.CIDRs {
for i, cidr := range ruleSet.CIDRs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
s.l.Errorf("Failed to parse CIDR %s: %s", cidr, err)
continue
}

ruleSet.cidrs = append(ruleSet.cidrs, ipNet)
ruleSet.cidrs[i] = ipNet
}

return ruleSet
}

func (s *RuleSet) ShouldHandle(ip net.IP, port int, zone, network string) bool {
matchers := 0
matched := 0

if len(s.cidrs) > 0 {
matchers++
for _, cidr := range s.cidrs {
if cidr.Contains(ip) {
matched++
break
}
}
}

if len(s.Ports) > 0 {
matchers++
for _, portRule := range s.Ports {
if portRule.Contains(port) {
matched++
break
}
}
}

if len(s.Types) > 0 {
matchers++
for _, typ := range s.Types {
if typ == network {
matched++
break
}
}
}

return matchers == matched
}

func (s *RuleSet) findRecords(name string, quesType uint16) []model.Record {
name = strings.TrimSuffix(name, s.DomainName)
name = strings.TrimSuffix(name, ".")
Expand Down
44 changes: 27 additions & 17 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,13 @@ func (s *Server) checkConfig() {
func (s *Server) handle(w dns.ResponseWriter, r *dns.Msg) {
remoteAddr := w.RemoteAddr()
s.l.Debugf("Receive DNS request {%+v} from %s: %s", r, remoteAddr.Network(), remoteAddr.String())
var remoteIp net.IP
if remoteAddr.Network() == "udp" {
remoteIp = remoteAddr.(*net.UDPAddr).IP
} else if remoteAddr.Network() == "tcp" {
remoteIp = remoteAddr.(*net.TCPAddr).IP
} else {
s.l.Warnf("Unsupported network %s", remoteAddr.Network())
return
}

m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = s.Authoritative
m.RecursionAvailable = s.Recursion

handler := s.MatchHandler(remoteIp)
handler := s.MatchHandler(w)

if handler == nil {
s.l.Warnf("No rule found for %s", remoteAddr)
Expand Down Expand Up @@ -178,14 +169,34 @@ func (s *Server) Header(r *model.Record) dns.RR_Header {
}
}

func (s *Server) MatchHandler(ip net.IP) *RuleSet {
func (s *Server) MatchHandler(w dns.ResponseWriter) *RuleSet {
addr := w.RemoteAddr()
var ip net.IP
var port int
var zone, network string
if addr.Network() == "tcp" {
addr := addr.(*net.TCPAddr)
ip = addr.IP
port = addr.Port
zone = addr.Zone
network = addr.Network()
} else if addr.Network() == "udp" {
addr := addr.(*net.UDPAddr)
ip = addr.IP
port = addr.Port
zone = addr.Zone
} else {
s.l.Warnf("Unknown network type %s", addr.Network())
return nil
}

var handlerName string
var err error

handlerName, err = s.cacheClient.GetRuntimeCache("handler:" + ip.String())
if err != nil {
if err == redis.Nil {
handlerName = s.matchHandler(ip)
handlerName = s.matchHandler(ip, port, zone, network)
if handlerName == "" {
return nil
}
Expand All @@ -209,13 +220,12 @@ func (s *Server) MatchHandler(ip net.IP) *RuleSet {
return handler
}

func (s *Server) matchHandler(ip net.IP) string {
func (s *Server) matchHandler(ip net.IP, port int, zone, network string) string {
ruleName := ""
for _, rule := range s.rules {
for _, cidr := range rule.cidrs {
if cidr.Contains(ip) {
ruleName = rule.Name
}
if rule.ShouldHandle(ip, port, zone, network) {
ruleName = rule.Name
break
}
}

Expand Down

0 comments on commit 163c33e

Please sign in to comment.