Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize(routing): fix slow domain++ ip routing #133

Merged
merged 4 commits into from
Jun 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 2 additions & 35 deletions component/dns/response_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@ import (
"github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser"
"github.com/daeuniverse/dae/pkg/trie"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)

var ValidCidrChars = trie.NewValidChars([]byte{'0', '1'})

type ResponseMatcherBuilder struct {
log *logrus.Logger
upstreamName2Id map[string]uint8
Expand Down Expand Up @@ -71,31 +68,6 @@ func (b *ResponseMatcherBuilder) upstreamToId(upstream string) (upstreamId const
return upstreamId, nil
}

func prefix2bin128(prefix netip.Prefix) (bin128 string) {
bits := prefix.Bits()
if prefix.Addr().Is4() {
bits += 96
}
ip := prefix.Addr().As16()
buf := buffer.NewBuffer(128)
defer buf.Put()
loop:
for i := 0; i < len(ip); i++ {
for j := 0; j < 8; j++ {
if (ip[i]>>j)&1 == 1 {
buf.WriteByte('1')
} else {
buf.WriteByte('0')
}
bits--
if bits == 0 {
break loop
}
}
}
return buf.String()
}

func (b *ResponseMatcherBuilder) addIp(f *config_parser.Function, cidrs []netip.Prefix, upstream *routing.Outbound) (err error) {
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
Expand All @@ -107,12 +79,7 @@ func (b *ResponseMatcherBuilder) addIp(f *config_parser.Function, cidrs []netip.
Not: f.Not,
Upstream: uint8(upstreamId),
}
var keys []string
// Convert netip.Prefix -> '0' '1' string
for _, prefix := range cidrs {
keys = append(keys, prefix2bin128(prefix))
}
t, err := trie.NewTrie(keys, ValidCidrChars)
t, err := trie.NewTrieFromPrefixes(cidrs)
if err != nil {
return err
}
Expand Down Expand Up @@ -263,7 +230,7 @@ func (m *ResponseMatcher) Match(
domainMatchBitmap := m.domainMatcher.MatchDomainBitmap(qName)
bin128 := make([]string, 0, len(ips))
for _, ip := range ips {
bin128 = append(bin128, prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(ip.As16()), 128)))
bin128 = append(bin128, trie.Prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(ip.As16()), 128)))
}

goodSubrule := false
Expand Down
6 changes: 6 additions & 0 deletions component/outbound/dialer/connectivity_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
package dialer

import (
"bytes"
"context"
"errors"
"fmt"
"github.com/daeuniverse/dae/common"
"io"
"net"
"net/http"
"net/netip"
Expand Down Expand Up @@ -580,6 +582,10 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
// Judge the status code.
if page := path.Base(req.URL.Path); strings.HasPrefix(page, "generate_") {
if strconv.Itoa(resp.StatusCode) != strings.TrimPrefix(page, "generate_") {
b, _ := io.ReadAll(resp.Body)
buf := bytes.NewBuffer(nil)
_ = resp.Request.Write(buf)
d.Log.Debugln(buf.String(), "Resp: ", string(b))
return false, fmt.Errorf("unexpected status code: %v", resp.StatusCode)
}
return true, nil
Expand Down
2 changes: 1 addition & 1 deletion control/control_plane.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ func NewControlPlane(
if err = builder.BuildKernspace(log); err != nil {
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err)
}
routingMatcher, err := builder.BuildUserspace(core.bpf.LpmArrayMap)
routingMatcher, err := builder.BuildUserspace()
if err != nil {
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildUserspace: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions control/kern/tproxy.c
Original file line number Diff line number Diff line change
Expand Up @@ -1016,8 +1016,8 @@ route(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
#ifdef __DEBUG_ROUTING
key = match_set->type;
bpf_printk("key(match_set->type): %llu", key);
bpf_printk("Skip to judge. bad_rule: %d, good_subrule: %d", bad_rule,
good_subrule);
bpf_printk("Skip to judge. bad_rule: %d, good_subrule: %d", isdns_must_goodsubrule_badrule&0b10,
isdns_must_goodsubrule_badrule&0b1);
#endif
goto before_next_loop;
}
Expand Down Expand Up @@ -1103,7 +1103,7 @@ route(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],

before_next_loop:
#ifdef __DEBUG_ROUTING
bpf_printk("good_subrule: %d, bad_rule: %d", good_subrule, bad_rule);
bpf_printk("good_subrule: %d, bad_rule: %d", isdns_must_goodsubrule_badrule&0b10, isdns_must_goodsubrule_badrule&0b1);
#endif
if (match_set->outbound != OUTBOUND_LOGICAL_OR) {
// This match_set reaches the end of subrule.
Expand All @@ -1119,7 +1119,7 @@ route(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
isdns_must_goodsubrule_badrule &= ~0b10;
}
#ifdef __DEBUG_ROUTING
bpf_printk("_bad_rule: %d", bad_rule);
bpf_printk("_bad_rule: %d", isdns_must_goodsubrule_badrule&0b1);
#endif
if ((match_set->outbound & OUTBOUND_LOGICAL_MASK) !=
OUTBOUND_LOGICAL_MASK) {
Expand Down
14 changes: 12 additions & 2 deletions control/routing_matcher_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package control
import (
"encoding/binary"
"fmt"
"github.com/daeuniverse/dae/pkg/trie"
"net/netip"
"strconv"

Expand Down Expand Up @@ -328,12 +329,21 @@ func (b *RoutingMatcherBuilder) BuildKernspace(log *logrus.Logger) (err error) {
return nil
}

func (b *RoutingMatcherBuilder) BuildUserspace(lpmArrayMap *ebpf.Map) (matcher *RoutingMatcher, err error) {
func (b *RoutingMatcherBuilder) BuildUserspace() (matcher *RoutingMatcher, err error) {
// Build domainMatcher
domainMatcher := domain_matcher.NewAhocorasickSlimtrie(b.log, consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet {
domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
}
// Build Ip matcher.
var lpmMatcher []*trie.Trie
for _, prefixes := range b.simulatedLpmTries {
t, err := trie.NewTrieFromPrefixes(prefixes)
if err != nil {
return nil, err
}
lpmMatcher = append(lpmMatcher, t)
}
if err = domainMatcher.Build(); err != nil {
return nil, err
}
Expand All @@ -345,7 +355,7 @@ func (b *RoutingMatcherBuilder) BuildUserspace(lpmArrayMap *ebpf.Map) (matcher *
}

return &RoutingMatcher{
lpmArrayMap: lpmArrayMap,
lpmMatcher: lpmMatcher,
domainMatcher: domainMatcher,
matches: b.rules,
}, nil
Expand Down
40 changes: 12 additions & 28 deletions control/routing_matcher_userspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ package control
import (
"encoding/binary"
"fmt"
"github.com/daeuniverse/dae/pkg/trie"
"net"
"net/netip"

"github.com/cilium/ebpf"
"github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/routing"
)

type RoutingMatcher struct {
lpmArrayMap *ebpf.Map
lpmMatcher []*trie.Trie
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.

matches []bpfMatchSet
Expand All @@ -38,19 +38,12 @@ func (m *RoutingMatcher) Match(
if len(sourceAddr) != net.IPv6len || len(destAddr) != net.IPv6len || len(mac) != net.IPv6len {
return 0, 0, false, fmt.Errorf("bad address length")
}
lpmKeys := make([]*_bpfLpmKey, consts.MatchType_Mac+1)
lpmKeys[consts.MatchType_IpSet] = &_bpfLpmKey{
PrefixLen: 128,
Data: common.Ipv6ByteSliceToUint32Array(destAddr),
}
lpmKeys[consts.MatchType_SourceIpSet] = &_bpfLpmKey{
PrefixLen: 128,
Data: common.Ipv6ByteSliceToUint32Array(sourceAddr),
}
lpmKeys[consts.MatchType_Mac] = &_bpfLpmKey{
PrefixLen: 128,
Data: common.Ipv6ByteSliceToUint32Array(mac),
}

bin128s := make([]string, consts.MatchType_Mac+1)
bin128s[consts.MatchType_IpSet] = trie.Prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(destAddr)), 128))
bin128s[consts.MatchType_SourceIpSet] = trie.Prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(sourceAddr)), 128))
bin128s[consts.MatchType_Mac] = trie.Prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(mac)), 128))

var domainMatchBitmap []uint32
if domain != "" {
domainMatchBitmap = m.domainMatcher.MatchDomainBitmap(domain)
Expand All @@ -65,19 +58,10 @@ func (m *RoutingMatcher) Match(
switch consts.MatchType(match.Type) {
case consts.MatchType_IpSet, consts.MatchType_SourceIpSet, consts.MatchType_Mac:
lpmIndex := uint32(binary.LittleEndian.Uint16(match.Value[:]))
var lpm *ebpf.Map
if err = m.lpmArrayMap.Lookup(lpmIndex, &lpm); err != nil {
//logrus.Debugln("m.lpmArrayMap.Lookup:", err)
break
}
var v uint32
if err = lpm.Lookup(*lpmKeys[int(match.Type)], &v); err != nil {
_ = lpm.Close()
//logrus.Debugln("lpm.Lookup:", err, lpmKeys[int(match.Type)], match.Type, destAddr)
break
m := m.lpmMatcher[lpmIndex]
if m.HasPrefix(bin128s[match.Type]) {
goodSubrule = true
}
_ = lpm.Close()
goodSubrule = true
case consts.MatchType_DomainSet:
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
goodSubrule = true
Expand Down
45 changes: 45 additions & 0 deletions pkg/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ package trie

import (
"fmt"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"math/bits"
"net/netip"
"sort"

"github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/bitlist"
)

var ValidCidrChars = NewValidChars([]byte{'0', '1'})

type ValidChars struct {
table [256]byte
n uint16
Expand Down Expand Up @@ -87,6 +91,47 @@ type Trie struct {
chars *ValidChars
}

func Prefix2bin128(prefix netip.Prefix) (bin128 string) {
n := prefix.Bits()
if n == -1 {
panic("! BadPrefix: " + prefix.String())
}
if prefix.Addr().Is4() {
n += 96
}
ip := prefix.Addr().As16()
buf := buffer.NewBuffer(128)
defer buf.Put()
loop:
for i := 0; i < len(ip); i++ {
for j := 7; j >= 0; j-- {
if (ip[i]>>j)&1 == 1 {
_ = buf.WriteByte('1')
} else {
_ = buf.WriteByte('0')
}
n--
if n == 0 {
break loop
}
}
}
return buf.String()
}

func NewTrieFromPrefixes(cidrs []netip.Prefix) (*Trie, error) {
var keys []string
// Convert netip.Prefix -> '0' '1' string
for _, prefix := range cidrs {
keys = append(keys, Prefix2bin128(prefix))
}
t, err := NewTrie(keys, ValidCidrChars)
if err != nil {
return nil, err
}
return t, nil
}

// NewTrie creates a new *Trie struct, from a slice of sorted strings.
func NewTrie(keys []string, chars *ValidChars) (*Trie, error) {
// Check chars.
Expand Down