From acfc9f8baacbdc0662a2a984a22d2e5512455481 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 27 Sep 2024 20:09:35 +0800 Subject: [PATCH] chore: reset resolver's connection after default interface changed --- adapter/outbound/wireguard.go | 2 +- component/resolver/resolver.go | 10 +++++ dns/client.go | 2 + dns/dhcp.go | 6 +++ dns/doh.go | 27 ++++++++++++-- dns/doq.go | 4 ++ dns/patch_android.go | 7 ++++ dns/rcode.go | 2 + dns/resolver.go | 68 +++++++++++++++++++--------------- dns/system_common.go | 2 + hub/executor/executor.go | 5 ++- listener/sing_tun/server.go | 5 +++ 12 files changed, 103 insertions(+), 37 deletions(-) diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index 6f5a18f35d..3928ab1b7e 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -296,7 +296,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { for i := range nss { nss[i].ProxyAdapter = refP } - outbound.resolver = dns.NewResolver(dns.Config{ + outbound.resolver, _ = dns.NewResolver(dns.Config{ Main: nss, IPv6: has6, }) diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index feb3f98fb5..bcdbb7e2c4 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -47,6 +47,7 @@ type Resolver interface { ExchangeContext(ctx context.Context, m *dns.Msg) (msg *dns.Msg, err error) Invalid() bool ClearCache() + ResetConnection() } // LookupIPv4WithResolver same as LookupIPv4, but with a resolver @@ -256,6 +257,15 @@ func LookupIPProxyServerHost(ctx context.Context, host string) ([]netip.Addr, er return LookupIP(ctx, host) } +func ResetConnection() { + if DefaultResolver != nil { + go DefaultResolver.ResetConnection() + } + if ProxyServerHostResolver != nil { + go ProxyServerHostResolver.ResetConnection() + } +} + func SortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) { for _, v := range ips { if v.Unmap().Is4() { diff --git a/dns/client.go b/dns/client.go index 096b96a7f5..62fc12f9c3 100644 --- a/dns/client.go +++ b/dns/client.go @@ -103,3 +103,5 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) return ret.msg, ret.err } } + +func (c *client) ResetConnection() {} diff --git a/dns/dhcp.go b/dns/dhcp.go index dc1344f500..e3829b7c2c 100644 --- a/dns/dhcp.go +++ b/dns/dhcp.go @@ -53,6 +53,12 @@ func (d *dhcpClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, return } +func (d *dhcpClient) ResetConnection() { + for _, client := range d.clients { + client.ResetConnection() + } +} + func (d *dhcpClient) resolve(ctx context.Context) ([]dnsClient, error) { d.lock.Lock() diff --git a/dns/doh.go b/dns/doh.go index ffb65fcef0..027afd58cc 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -203,11 +203,23 @@ func (doh *dnsOverHTTPS) Close() (err error) { return doh.closeClient(doh.client) } -// closeClient cleans up resources used by client if necessary. Note, that at -// this point it should only be done for HTTP/3 as it may leak due to keep-alive -// connections. +func (doh *dnsOverHTTPS) ResetConnection() { + doh.clientMu.Lock() + defer doh.clientMu.Unlock() + + if doh.client == nil { + return + } + + _ = doh.closeClient(doh.client) + doh.client = nil +} + +// closeClient cleans up resources used by client if necessary. func (doh *dnsOverHTTPS) closeClient(client *http.Client) (err error) { - if isHTTP3(client) { + client.CloseIdleConnections() + + if isHTTP3(client) { // HTTP/3 may leak due to keep-alive connections. return client.Transport.(io.Closer).Close() } @@ -508,6 +520,13 @@ func (h *http3Transport) Close() (err error) { return h.baseTransport.Close() } +func (h *http3Transport) CloseIdleConnections() { + h.mu.RLock() + defer h.mu.RUnlock() + + h.baseTransport.CloseIdleConnections() +} + // createTransportH3 tries to create an HTTP/3 transport for this upstream. // We should be able to fall back to H1/H2 in case if HTTP/3 is unavailable or // if it is too slow. In order to do that, this method will run two probes diff --git a/dns/doq.go b/dns/doq.go index ad936f9575..29fdd00660 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -144,6 +144,10 @@ func (doq *dnsOverQUIC) Close() (err error) { return err } +func (doq *dnsOverQUIC) ResetConnection() { + doq.closeConnWithError(nil) +} + // exchangeQUIC attempts to open a QUIC connection, send the DNS message // through it and return the response it got from the server. func (doq *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) { diff --git a/dns/patch_android.go b/dns/patch_android.go index 6579ef071a..e3dcd2492f 100644 --- a/dns/patch_android.go +++ b/dns/patch_android.go @@ -12,6 +12,7 @@ func FlushCacheWithDefaultResolver() { if r := resolver.DefaultResolver; r != nil { r.ClearCache() } + resolver.ResetConnection() } func UpdateSystemDNS(addr []string) { @@ -30,3 +31,9 @@ func UpdateSystemDNS(addr []string) { func (c *systemClient) getDnsClients() ([]dnsClient, error) { return systemResolver, nil } + +func (c *systemClient) ResetConnection() { + for _, r := range systemResolver { + r.ResetConnection() + } +} diff --git a/dns/rcode.go b/dns/rcode.go index 9777d2e77b..901d1019d3 100644 --- a/dns/rcode.go +++ b/dns/rcode.go @@ -48,3 +48,5 @@ func (r rcodeClient) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, err func (r rcodeClient) Address() string { return r.addr } + +func (r rcodeClient) ResetConnection() {} diff --git a/dns/resolver.go b/dns/resolver.go index e03feef46f..ec59f42857 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -24,6 +24,7 @@ import ( type dnsClient interface { ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) Address() string + ResetConnection() } type dnsCache interface { @@ -48,7 +49,7 @@ type Resolver struct { group singleflight.Group[*D.Msg] cache dnsCache policy []dnsPolicy - proxyServer []dnsClient + defaultResolver *Resolver } func (r *Resolver) LookupIPPrimaryIPv4(ctx context.Context, host string) (ips []netip.Addr, err error) { @@ -376,6 +377,20 @@ func (r *Resolver) ClearCache() { } } +func (r *Resolver) ResetConnection() { + if r != nil { + for _, c := range r.main { + c.ResetConnection() + } + for _, c := range r.fallback { + c.ResetConnection() + } + if dr := r.defaultResolver; dr != nil { + dr.ResetConnection() + } + } +} + type NameServer struct { Net string Addr string @@ -425,16 +440,18 @@ type Config struct { CacheAlgorithm string } -func NewResolver(config Config) *Resolver { - var cache dnsCache - if config.CacheAlgorithm == "lru" { - cache = lru.New(lru.WithSize[string, *D.Msg](4096), lru.WithStale[string, *D.Msg](true)) +func (config Config) newCache() dnsCache { + if config.CacheAlgorithm == "" || config.CacheAlgorithm == "lru" { + return lru.New(lru.WithSize[string, *D.Msg](4096), lru.WithStale[string, *D.Msg](true)) } else { - cache = arc.New(arc.WithSize[string, *D.Msg](4096)) + return arc.New(arc.WithSize[string, *D.Msg](4096)) } +} + +func NewResolver(config Config) (r *Resolver, pr *Resolver) { defaultResolver := &Resolver{ main: transform(config.Default, nil), - cache: cache, + cache: config.newCache(), ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, } @@ -465,25 +482,27 @@ func NewResolver(config Config) *Resolver { return } - if config.CacheAlgorithm == "" || config.CacheAlgorithm == "lru" { - cache = lru.New(lru.WithSize[string, *D.Msg](4096), lru.WithStale[string, *D.Msg](true)) - } else { - cache = arc.New(arc.WithSize[string, *D.Msg](4096)) - } - r := &Resolver{ + r = &Resolver{ ipv6: config.IPv6, main: cacheTransform(config.Main), - cache: cache, + cache: config.newCache(), hosts: config.Hosts, ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, } + r.defaultResolver = defaultResolver - if len(config.Fallback) != 0 { - r.fallback = cacheTransform(config.Fallback) + if len(config.ProxyServer) != 0 { + pr = &Resolver{ + ipv6: config.IPv6, + main: cacheTransform(config.ProxyServer), + cache: config.newCache(), + hosts: config.Hosts, + ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, + } } - if len(config.ProxyServer) != 0 { - r.proxyServer = cacheTransform(config.ProxyServer) + if len(config.Fallback) != 0 { + r.fallback = cacheTransform(config.Fallback) } if len(config.Policy) != 0 { @@ -516,18 +535,7 @@ func NewResolver(config Config) *Resolver { r.fallbackIPFilters = config.FallbackIPFilter r.fallbackDomainFilters = config.FallbackDomainFilter - return r -} - -func NewProxyServerHostResolver(old *Resolver) *Resolver { - r := &Resolver{ - ipv6: old.ipv6, - main: old.proxyServer, - cache: old.cache, - hosts: old.hosts, - ipv6Timeout: old.ipv6Timeout, - } - return r + return } var ParseNameServer func(servers []string) ([]NameServer, error) // define in config/config.go diff --git a/dns/system_common.go b/dns/system_common.go index 06dc0b3020..e6dabdcfff 100644 --- a/dns/system_common.go +++ b/dns/system_common.go @@ -69,3 +69,5 @@ func (c *systemClient) getDnsClients() ([]dnsClient, error) { } return nil, err } + +func (c *systemClient) ResetConnection() {} diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 39bf28d246..66514e39bf 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -118,6 +118,8 @@ func ApplyConfig(cfg *config.Config, force bool) { tunnel.OnRunning() hcCompatibleProvider(cfg.Providers) initExternalUI() + + resolver.ResetConnection() } func initInnerTcp() { @@ -253,8 +255,7 @@ func updateDNS(c *config.DNS, generalIPv6 bool) { CacheAlgorithm: c.CacheAlgorithm, } - r := dns.NewResolver(cfg) - pr := dns.NewProxyServerHostResolver(r) + r, pr := dns.NewResolver(cfg) m := dns.NewEnhancer(cfg) // reuse cache of old host mapper diff --git a/listener/sing_tun/server.go b/listener/sing_tun/server.go index c2c668b34e..79856c466c 100644 --- a/listener/sing_tun/server.go +++ b/listener/sing_tun/server.go @@ -440,6 +440,10 @@ func New(options LC.Tun, tunnel C.Tunnel, additions ...inbound.Addition) (l *Lis //l.openAndroidHotspot(tunOptions) + if !l.options.AutoDetectInterface { + resolver.ResetConnection() + } + if options.FileDescriptor != 0 { tunName = fmt.Sprintf("%s(fd=%d)", tunName, options.FileDescriptor) } @@ -507,6 +511,7 @@ func (l *Listener) FlushDefaultInterface() { if old := dialer.DefaultInterface.Swap(autoDetectInterfaceName); old != autoDetectInterfaceName { log.Warnln("[TUN] default interface changed by monitor, %s => %s", old, autoDetectInterfaceName) iface.FlushCache() + resolver.ResetConnection() // reset resolver's connection after default interface changed } return }