Skip to content

Commit

Permalink
feat: support assist connectivity check (udp-dns) by dns module (#543)
Browse files Browse the repository at this point in the history
Co-authored-by: Sumire (菫) <[email protected]>
  • Loading branch information
mzz2017 and sumire88 authored Jun 16, 2024
1 parent 93e47ff commit 76a4fe9
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 46 deletions.
75 changes: 45 additions & 30 deletions component/outbound/dialer/connectivity_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ func (d *Dialer) ActivateCheck() {
}

func (d *Dialer) aliveBackground() {
timeout := Timeout
cycle := d.CheckInterval
var tcpSomark uint32
if network, err := netproxy.ParseMagicNetwork(d.TcpCheckOptionRaw.ResolverNetwork); err == nil {
Expand Down Expand Up @@ -461,7 +460,7 @@ func (d *Dialer) aliveBackground() {

wg.Add(1)
go func(opt *CheckOption) {
_, _ = d.Check(timeout, opt)
_, _ = d.Check(opt)
wg.Done()
}(opt)
}
Expand Down Expand Up @@ -513,10 +512,48 @@ func (d *Dialer) UnregisterAliveDialerSet(a *AliveDialerSet) {
}
}

func (d *Dialer) Check(timeout time.Duration,
opts *CheckOption,
) (ok bool, err error) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
// logUnavailable marks collection as not alive after a failed connectivity
// check for network.
//
// When err is non-nil it is logged at debug level together with the network
// and the node name; errors ending in a few well-known suffixes are first
// replaced by shorter messages. Whether or not err is set, Timeout is appended
// to Latencies10 and averaged into MovingAverage, and Alive is set to false.
func (d *Dialer) logUnavailable(
	collection *collection,
	network *NetworkType,
	err error,
) {
	if err != nil {
		reason := err.Error()
		switch {
		case strings.HasSuffix(reason, "network is unreachable"):
			// Drop the wrapping context and keep only the essential reason.
			err = fmt.Errorf("network is unreachable")
		case strings.HasSuffix(reason, "no suitable address found"),
			strings.HasSuffix(reason, "non-IPv4 address"):
			err = fmt.Errorf("IPv%v is not supported", network.IpVersion)
		}
		fields := logrus.Fields{
			"network": network.String(),
			"node":    d.property.Name,
			"err":     err.Error(),
		}
		d.Log.WithFields(fields).Debugln("Connectivity Check Failed")
	}

	// Record the failure as if the check had taken the full Timeout.
	collection.Latencies10.AppendLatency(Timeout)
	collection.MovingAverage = (Timeout + collection.MovingAverage) / 2
	collection.Alive = false
}

// informDialerGroupUpdate notifies every AliveDialerSet registered in
// collection.AliveDialerSetSet of the collection's current Alive state so
// that dialer groups can update their view of this dialer.
func (d *Dialer) informDialerGroupUpdate(collection *collection) {
	// AliveDialerSetSet is a reference to the set stored in the collection
	// (not a private copy), so it is iterated under collectionFineMu.
	// Unlocking via defer ensures the mutex is released even if a
	// NotifyLatencyChange call panics.
	d.collectionFineMu.Lock()
	defer d.collectionFineMu.Unlock()
	for a := range collection.AliveDialerSetSet {
		a.NotifyLatencyChange(d, collection.Alive)
	}
}

// ReportUnavailable marks this dialer as unavailable for the network type typ.
//
// It resolves the collection for typ via mustGetCollection, records the
// failure through logUnavailable (which also logs err when it is non-nil),
// and then notifies the registered alive-dialer sets through
// informDialerGroupUpdate.
func (d *Dialer) ReportUnavailable(typ *NetworkType, err error) {
	c := d.mustGetCollection(typ)
	d.logUnavailable(c, typ, err)
	d.informDialerGroupUpdate(c)
}

func (d *Dialer) Check(opts *CheckOption) (ok bool, err error) {
ctx, cancel := context.WithTimeout(context.TODO(), Timeout)
defer cancel()
start := time.Now()
// Calc latency.
Expand All @@ -537,31 +574,9 @@ func (d *Dialer) Check(timeout time.Duration,
"mov_avg": collection.MovingAverage.Truncate(time.Millisecond),
}).Debugln("Connectivity Check")
} else {
// Append timeout if there is any error or unexpected status code.
if err != nil {
if strings.HasSuffix(err.Error(), "network is unreachable") {
err = fmt.Errorf("network is unreachable")
} else if strings.HasSuffix(err.Error(), "no suitable address found") ||
strings.HasSuffix(err.Error(), "non-IPv4 address") {
err = fmt.Errorf("IPv%v is not supported", opts.networkType.IpVersion)
}
d.Log.WithFields(logrus.Fields{
"network": opts.networkType.String(),
"node": d.property.Name,
"err": err.Error(),
}).Debugln("Connectivity Check Failed")
}
collection.Latencies10.AppendLatency(timeout)
collection.MovingAverage = (collection.MovingAverage + timeout) / 2
collection.Alive = false
d.logUnavailable(collection, opts.networkType, err)
}
// Inform DialerGroups to update state.
// We use lock because AliveDialerSetSet is a reference of that in collection.
d.collectionFineMu.Lock()
for a := range collection.AliveDialerSetSet {
a.NotifyLatencyChange(d, collection.Alive)
}
d.collectionFineMu.Unlock()
d.informDialerGroupUpdate(collection)
return ok, err
}

Expand Down
11 changes: 9 additions & 2 deletions control/control_plane.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,15 @@ func NewControlPlane(
}, nil
},
BestDialerChooser: plane.chooseBestDnsDialer,
IpVersionPrefer: dnsConfig.IpVersionPrefer,
FixedDomainTtl: fixedDomainTtl,
TimeoutExceedCallback: func(dialArgument *dialArgument, err error) {
dialArgument.bestDialer.ReportUnavailable(&dialer.NetworkType{
L4Proto: dialArgument.l4proto,
IpVersion: dialArgument.ipversion,
IsDns: true,
}, err)
},
IpVersionPrefer: dnsConfig.IpVersionPrefer,
FixedDomainTtl: fixedDomainTtl,
}); err != nil {
return nil, err
}
Expand Down
36 changes: 22 additions & 14 deletions control/dns_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@ var (
)

type DnsControllerOption struct {
Log *logrus.Logger
CacheAccessCallback func(cache *DnsCache) (err error)
CacheRemoveCallback func(cache *DnsCache) (err error)
NewCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time, originalDeadline time.Time) (cache *DnsCache, err error)
BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
IpVersionPrefer int
FixedDomainTtl map[string]int
Log *logrus.Logger
CacheAccessCallback func(cache *DnsCache) (err error)
CacheRemoveCallback func(cache *DnsCache) (err error)
NewCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time, originalDeadline time.Time) (cache *DnsCache, err error)
BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
TimeoutExceedCallback func(dialArgument *dialArgument, err error)
IpVersionPrefer int
FixedDomainTtl map[string]int
}

type DnsController struct {
Expand All @@ -76,6 +77,8 @@ type DnsController struct {
cacheRemoveCallback func(cache *DnsCache) (err error)
newCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time, originalDeadline time.Time) (cache *DnsCache, err error)
bestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
// timeoutExceedCallback is called to report that the chosen dialer is broken for the given NetworkType.
timeoutExceedCallback func(dialArgument *dialArgument, err error)

fixedDomainTtl map[string]int
// mutex protects the dnsCache.
Expand Down Expand Up @@ -107,11 +110,12 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont
routing: routing,
qtypePrefer: prefer,

log: option.Log,
cacheAccessCallback: option.CacheAccessCallback,
cacheRemoveCallback: option.CacheRemoveCallback,
newCache: option.NewCache,
bestDialerChooser: option.BestDialerChooser,
log: option.Log,
cacheAccessCallback: option.CacheAccessCallback,
cacheRemoveCallback: option.CacheRemoveCallback,
newCache: option.NewCache,
bestDialerChooser: option.BestDialerChooser,
timeoutExceedCallback: option.TimeoutExceedCallback,

fixedDomainTtl: option.FixedDomainTtl,
dnsCacheMu: sync.Mutex{},
Expand Down Expand Up @@ -578,8 +582,9 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
}
}()

_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
dnsReqCtx, cancelDnsReqCtx := context.WithTimeout(context.TODO(), 5*time.Second)
timeout := 5 * time.Second
_ = conn.SetDeadline(time.Now().Add(timeout))
dnsReqCtx, cancelDnsReqCtx := context.WithTimeout(context.TODO(), timeout)
defer cancelDnsReqCtx()
go func() {
// Send DNS request every second.
Expand Down Expand Up @@ -613,6 +618,9 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
// Wait for response.
n, err := conn.Read(respBuf)
if err != nil {
if c.timeoutExceedCallback != nil {
c.timeoutExceedCallback(dialArgument, err)
}
return fmt.Errorf("failed to read from: %v (dialer: %v): %w", dialArgument.bestTarget, dialArgument.bestDialer.Property().Name, err)
}
var msg dnsmessage.Msg
Expand Down

0 comments on commit 76a4fe9

Please sign in to comment.