From bfc17c3e2dd04850c8f4b67e1f79c2e50d0f4293 Mon Sep 17 00:00:00 2001 From: Ekko Date: Sat, 2 Nov 2024 00:45:03 +0800 Subject: [PATCH] feat(dns): support DoH, DoT, DoH3, DoQ (#649) --- component/dns/dns.go | 2 +- component/dns/upstream.go | 41 +++- control/dns.go | 437 ++++++++++++++++++++++++++++++++++++++ control/dns_control.go | 142 +++---------- 4 files changed, 496 insertions(+), 126 deletions(-) create mode 100644 control/dns.go diff --git a/component/dns/dns.go b/component/dns/dns.go index b6917e3ce..96ca5ce34 100644 --- a/component/dns/dns.go +++ b/component/dns/dns.go @@ -128,7 +128,7 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) { func (s *Dns) CheckUpstreamsFormat() error { for _, upstream := range s.upstream { - _, _, _, err := ParseRawUpstream(upstream.Raw) + _, _, _, _, err := ParseRawUpstream(upstream.Raw) if err != nil { return err } diff --git a/component/dns/upstream.go b/component/dns/upstream.go index bd8682f7d..0adee2fe1 100644 --- a/component/dns/upstream.go +++ b/component/dns/upstream.go @@ -30,6 +30,11 @@ const ( UpstreamScheme_UDP UpstreamScheme = "udp" UpstreamScheme_TCP_UDP UpstreamScheme = "tcp+udp" upstreamScheme_TCP_UDP_Alias UpstreamScheme = "udp+tcp" + UpstreamScheme_TLS UpstreamScheme = "tls" + UpstreamScheme_QUIC UpstreamScheme = "quic" + UpstreamScheme_HTTPS UpstreamScheme = "https" + upstreamScheme_H3_Alias UpstreamScheme = "http3" + UpstreamScheme_H3 UpstreamScheme = "h3" ) func (s UpstreamScheme) ContainsTcp() bool { @@ -42,8 +47,9 @@ func (s UpstreamScheme) ContainsTcp() bool { } } -func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, port uint16, err error) { +func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, port uint16, path string, err error) { var __port string + var __path string switch scheme = UpstreamScheme(raw.Scheme); scheme { case upstreamScheme_TCP_UDP_Alias: scheme = UpstreamScheme_TCP_UDP @@ -53,27 +59,45 @@ func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, por if __port == "" { __port = "53" } + case upstreamScheme_H3_Alias: + scheme = UpstreamScheme_H3 + fallthrough + case UpstreamScheme_HTTPS, UpstreamScheme_H3: + __port = raw.Port() + if __port == "" { + __port = "443" + } + __path = raw.Path + if __path == "" { + __path = "/dns-query" + } + case UpstreamScheme_QUIC, UpstreamScheme_TLS: + __port = raw.Port() + if __port == "" { + __port = "853" + } default: - return "", "", 0, fmt.Errorf("unexpected scheme: %v", raw.Scheme) + return "", "", 0, "", fmt.Errorf("unexpected scheme: %v", raw.Scheme) } _port, err := strconv.ParseUint(__port, 10, 16) if err != nil { - return "", "", 0, fmt.Errorf("failed to parse dns_upstream port: %v", err) + return "", "", 0, "", fmt.Errorf("failed to parse dns_upstream port: %v", err) } port = uint16(_port) hostname = raw.Hostname() - return scheme, hostname, port, nil + return scheme, hostname, port, __path, nil } type Upstream struct { Scheme UpstreamScheme Hostname string Port uint16 + Path string *netutils.Ip46 } func NewUpstream(ctx context.Context, upstream *url.URL, resolverNetwork string) (up *Upstream, err error) { - scheme, hostname, port, err := ParseRawUpstream(upstream) + scheme, hostname, port, path, err := ParseRawUpstream(upstream) if err != nil { return nil, fmt.Errorf("%w: %v", ErrFormat, err) } @@ -100,6 +124,7 @@ func NewUpstream(ctx context.Context, upstream *url.URL, resolverNetwork string) Scheme: scheme, Hostname: hostname, Port: port, + Path: path, Ip46: ip46, }, nil } @@ -115,9 +140,9 @@ func (u *Upstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4prot } } switch u.Scheme { - case UpstreamScheme_TCP: + case UpstreamScheme_TCP, UpstreamScheme_HTTPS, UpstreamScheme_TLS: l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_TCP} - case UpstreamScheme_UDP: + case UpstreamScheme_UDP, UpstreamScheme_QUIC, UpstreamScheme_H3: l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP} case UpstreamScheme_TCP_UDP: // UDP first. @@ -127,7 +152,7 @@ func (u *Upstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4prot } func (u *Upstream) String() string { - return string(u.Scheme) + "://" + net.JoinHostPort(u.Hostname, strconv.Itoa(int(u.Port))) + return string(u.Scheme) + "://" + net.JoinHostPort(u.Hostname, strconv.Itoa(int(u.Port))) + u.Path } type UpstreamResolver struct { diff --git a/control/dns.go b/control/dns.go new file mode 100644 index 000000000..1a3878036 --- /dev/null +++ b/control/dns.go @@ -0,0 +1,437 @@ +package control + +import ( + "context" + "crypto/tls" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "net" + "net/http" + "net/url" + "time" + + "github.com/daeuniverse/dae/common" + "github.com/daeuniverse/dae/common/consts" + "github.com/daeuniverse/dae/component/dns" + "github.com/daeuniverse/outbound/netproxy" + "github.com/daeuniverse/outbound/pool" + tc "github.com/daeuniverse/outbound/protocol/tuic/common" + "github.com/daeuniverse/quic-go" + "github.com/daeuniverse/quic-go/http3" + dnsmessage "github.com/miekg/dns" +) + +type DnsForwarder interface { + ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) + Close() error +} + +func newDnsForwarder(upstream *dns.Upstream, dialArgument dialArgument) (DnsForwarder, error) { + forwarder, err := func() (DnsForwarder, error) { + switch dialArgument.l4proto { + case consts.L4ProtoStr_TCP: + switch upstream.Scheme { + case dns.UpstreamScheme_TCP, dns.UpstreamScheme_TCP_UDP: + return &DoTCP{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument}, nil + case dns.UpstreamScheme_TLS: + return &DoTLS{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument}, nil + case dns.UpstreamScheme_HTTPS: + return &DoH{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument, http3: false}, nil + default: + return nil, fmt.Errorf("unexpected scheme: %v", upstream.Scheme) + } + case consts.L4ProtoStr_UDP: + switch upstream.Scheme { + case dns.UpstreamScheme_UDP, dns.UpstreamScheme_TCP_UDP: + return &DoUDP{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument}, nil + case dns.UpstreamScheme_QUIC: + return &DoQ{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument}, nil + case dns.UpstreamScheme_H3: + return &DoH{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument, http3: true}, nil + default: + return nil, fmt.Errorf("unexpected scheme: %v", upstream.Scheme) + } + default: + return nil, fmt.Errorf("unexpected l4proto: %v", dialArgument.l4proto) + } + }() + if err != nil { + return nil, err + } + return forwarder, nil +} + +type DoH struct { + dns.Upstream + netproxy.Dialer + dialArgument dialArgument + http3 bool + client *http.Client +} + +func (d *DoH) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) { + if d.client == nil { + d.client = d.getClient() + } + msg, err := sendHttpDNS(d.client, d.dialArgument.bestTarget.String(), &d.Upstream, data) + if err != nil { + // If failed to send DNS request, we should try to create a new client. + d.client = d.getClient() + msg, err = sendHttpDNS(d.client, d.dialArgument.bestTarget.String(), &d.Upstream, data) + if err != nil { + return nil, err + } + return msg, nil + } + return msg, nil +} + +func (d *DoH) getClient() *http.Client { + var roundTripper http.RoundTripper + if d.http3 { + roundTripper = d.getHttp3RoundTripper() + } else { + roundTripper = d.getHttpRoundTripper() + } + + return &http.Client{ + Transport: roundTripper, + } +} + +func (d *DoH) getHttpRoundTripper() *http.Transport { + httpTransport := http.Transport{ + TLSClientConfig: &tls.Config{ + ServerName: d.Upstream.Hostname, + InsecureSkipVerify: false, + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := d.dialArgument.bestDialer.DialContext( + ctx, + common.MagicNetwork("tcp", d.dialArgument.mark, d.dialArgument.mptcp), + d.dialArgument.bestTarget.String(), + ) + if err != nil { + return nil, err + } + return &netproxy.FakeNetConn{Conn: conn}, nil + }, + } + + return &httpTransport +} + +func (d *DoH) getHttp3RoundTripper() *http3.RoundTripper { + roundTripper := &http3.RoundTripper{ + TLSClientConfig: &tls.Config{ + ServerName: d.Upstream.Hostname, + NextProtos: []string{"h3"}, + InsecureSkipVerify: false, + }, + QuicConfig: &quic.Config{}, + Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + udpAddr := net.UDPAddrFromAddrPort(d.dialArgument.bestTarget) + conn, err := d.dialArgument.bestDialer.DialContext( + ctx, + common.MagicNetwork("udp", d.dialArgument.mark, d.dialArgument.mptcp), + d.dialArgument.bestTarget.String(), + ) + if err != nil { + return nil, err + } + fakePkt := netproxy.NewFakeNetPacketConn(conn.(netproxy.PacketConn), net.UDPAddrFromAddrPort(tc.GetUniqueFakeAddrPort()), udpAddr) + c, e := quic.DialEarly(ctx, fakePkt, udpAddr, tlsCfg, cfg) + return c, e + }, + } + return roundTripper +} + +func (d *DoH) Close() error { + return nil +} + +type DoQ struct { + dns.Upstream + netproxy.Dialer + dialArgument dialArgument + connection quic.EarlyConnection +} + +func (d *DoQ) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) { + if d.connection == nil { + qc, err := d.createConnection(ctx) + if err != nil { + return nil, err + } + d.connection = qc + } + + stream, err := d.connection.OpenStreamSync(ctx) + if err != nil { + // If failed to open stream, we should try to create a new connection. + qc, err := d.createConnection(ctx) + if err != nil { + return nil, err + } + d.connection = qc + stream, err = d.connection.OpenStreamSync(ctx) + if err != nil { + return nil, err + } + } + defer func() { + _ = stream.Close() + }() + + // According https://datatracker.ietf.org/doc/html/rfc9250#section-4.2.1 + // msg id should set to 0 when transport over QUIC. + // thanks https://github.com/natesales/q/blob/1cb2639caf69bd0a9b46494a3c689130df8fb24a/transport/quic.go#L97 + binary.BigEndian.PutUint16(data[0:2], 0) + + msg, err := sendStreamDNS(stream, data) + if err != nil { + return nil, err + } + return msg, nil +} +func (d *DoQ) createConnection(ctx context.Context) (quic.EarlyConnection, error) { + + udpAddr := net.UDPAddrFromAddrPort(d.dialArgument.bestTarget) + conn, err := d.dialArgument.bestDialer.DialContext( + ctx, + common.MagicNetwork("udp", d.dialArgument.mark, d.dialArgument.mptcp), + d.dialArgument.bestTarget.String(), + ) + if err != nil { + return nil, err + } + + fakePkt := netproxy.NewFakeNetPacketConn(conn.(netproxy.PacketConn), net.UDPAddrFromAddrPort(tc.GetUniqueFakeAddrPort()), udpAddr) + tlsCfg := &tls.Config{ + NextProtos: []string{"doq"}, + InsecureSkipVerify: false, + ServerName: d.Upstream.Hostname, + } + addr := net.UDPAddrFromAddrPort(d.dialArgument.bestTarget) + qc, err := quic.DialEarly(ctx, fakePkt, addr, tlsCfg, nil) + if err != nil { + return nil, err + } + return qc, nil + +} + +func (d *DoQ) Close() error { + return nil +} + +type DoTLS struct { + dns.Upstream + netproxy.Dialer + dialArgument dialArgument + conn netproxy.Conn +} + +func (d *DoTLS) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) { + conn, err := d.dialArgument.bestDialer.DialContext( + ctx, + common.MagicNetwork("tcp", d.dialArgument.mark, d.dialArgument.mptcp), + d.dialArgument.bestTarget.String(), + ) + if err != nil { + return nil, err + } + + tlsConn := tls.Client(&netproxy.FakeNetConn{Conn: conn}, &tls.Config{ + InsecureSkipVerify: false, + ServerName: d.Upstream.Hostname, + }) + if err = tlsConn.Handshake(); err != nil { + return nil, err + } + d.conn = tlsConn + + return sendStreamDNS(tlsConn, data) +} + +func (d *DoTLS) Close() error { + if d.conn != nil { + return d.conn.Close() + } + return nil +} + +type DoTCP struct { + dns.Upstream + netproxy.Dialer + dialArgument dialArgument + conn netproxy.Conn +} + +func (d *DoTCP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) { + conn, err := d.dialArgument.bestDialer.DialContext( + ctx, + common.MagicNetwork("tcp", d.dialArgument.mark, d.dialArgument.mptcp), + d.dialArgument.bestTarget.String(), + ) + if err != nil { + return nil, err + } + + d.conn = conn + return sendStreamDNS(conn, data) +} + +func (d *DoTCP) Close() error { + if d.conn != nil { + return d.conn.Close() + } + return nil +} + +type DoUDP struct { + dns.Upstream + netproxy.Dialer + dialArgument dialArgument + conn netproxy.Conn +} + +func (d *DoUDP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) { + conn, err := d.dialArgument.bestDialer.DialContext( + ctx, + common.MagicNetwork("udp", d.dialArgument.mark, d.dialArgument.mptcp), + d.dialArgument.bestTarget.String(), + ) + if err != nil { + return nil, err + } + + timeout := 5 * time.Second + _ = conn.SetDeadline(time.Now().Add(timeout)) + dnsReqCtx, cancelDnsReqCtx := context.WithTimeout(context.TODO(), timeout) + defer cancelDnsReqCtx() + + go func() { + // Send DNS request every seconds. + for { + _, err = conn.Write(data) + // if err != nil { + // if c.log.IsLevelEnabled(logrus.DebugLevel) { + // c.log.WithFields(logrus.Fields{ + // "to": dialArgument.bestTarget.String(), + // "pid": req.routingResult.Pid, + // "pname": ProcessName2String(req.routingResult.Pname[:]), + // "mac": Mac2String(req.routingResult.Mac[:]), + // "from": req.realSrc.String(), + // "network": networkType.String(), + // "err": err.Error(), + // }).Debugln("Failed to write UDP(DNS) packet request.") + // } + // return + // } + select { + case <-dnsReqCtx.Done(): + return + case <-time.After(1 * time.Second): + } + } + }() + + // We can block here because we are in a coroutine. + respBuf := pool.GetFullCap(consts.EthernetMtu) + defer pool.Put(respBuf) + // Wait for response. + n, err := conn.Read(respBuf) + if err != nil { + return nil, err + } + var msg dnsmessage.Msg + if err = msg.Unpack(respBuf[:n]); err != nil { + return nil, err + } + return &msg, nil +} + +func (d *DoUDP) Close() error { + if d.conn != nil { + return d.conn.Close() + } + return nil +} + +func sendHttpDNS(client *http.Client, target string, upstream *dns.Upstream, data []byte) (respMsg *dnsmessage.Msg, err error) { + // disable redirect https://github.com/daeuniverse/dae/pull/649#issuecomment-2379577896 + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return fmt.Errorf("do not use a server that will redirect, upstream: %v", upstream.String()) + } + serverURL := url.URL{ + Scheme: "https", + Host: target, + Path: upstream.Path, + } + q := serverURL.Query() + // According https://datatracker.ietf.org/doc/html/rfc8484#section-4 + // msg id should set to 0 when transport over HTTPS for cache friendly. + binary.BigEndian.PutUint16(data[0:2], 0) + q.Set("dns", base64.RawURLEncoding.EncodeToString(data)) + serverURL.RawQuery = q.Encode() + + req, err := http.NewRequest(http.MethodGet, serverURL.String(), nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/dns-message") + req.Host = upstream.Hostname + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + buf, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var msg dnsmessage.Msg + if err = msg.Unpack(buf); err != nil { + return nil, err + } + return &msg, nil +} + +func sendStreamDNS(stream io.ReadWriter, data []byte) (respMsg *dnsmessage.Msg, err error) { + // We should write two byte length in the front of stream DNS request. + bReq := pool.Get(2 + len(data)) + defer pool.Put(bReq) + binary.BigEndian.PutUint16(bReq, uint16(len(data))) + copy(bReq[2:], data) + _, err = stream.Write(bReq) + if err != nil { + return nil, fmt.Errorf("failed to write DNS req: %w", err) + } + + // Read two byte length. + if _, err = io.ReadFull(stream, bReq[:2]); err != nil { + return nil, fmt.Errorf("failed to read DNS resp payload length: %w", err) + } + respLen := int(binary.BigEndian.Uint16(bReq)) + // Try to reuse the buf. + var buf []byte + if len(bReq) < respLen { + buf = pool.Get(respLen) + defer pool.Put(buf) + } else { + buf = bReq + } + var n int + if n, err = io.ReadFull(stream, buf[:respLen]); err != nil { + return nil, fmt.Errorf("failed to read DNS resp payload: %w", err) + } + var msg dnsmessage.Msg + if err = msg.Unpack(buf[:n]); err != nil { + return nil, err + } + return &msg, nil +} diff --git a/control/dns_control.go b/control/dns_control.go index ac653e885..543581473 100644 --- a/control/dns_control.go +++ b/control/dns_control.go @@ -7,9 +7,7 @@ package control import ( "context" - "encoding/binary" "fmt" - "io" "math" "net" "net/netip" @@ -18,16 +16,12 @@ import ( "sync" "time" - "github.com/daeuniverse/dae/common" - "github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/netutils" "github.com/daeuniverse/dae/component/dns" "github.com/daeuniverse/dae/component/outbound" "github.com/daeuniverse/dae/component/outbound/dialer" - "github.com/daeuniverse/outbound/netproxy" "github.com/daeuniverse/outbound/pkg/fastrand" - "github.com/daeuniverse/outbound/pool" dnsmessage "github.com/miekg/dns" "github.com/mohae/deepcopy" "github.com/sirupsen/logrus" @@ -84,6 +78,8 @@ type DnsController struct { // mutex protects the dnsCache. dnsCacheMu sync.Mutex dnsCache map[string]*DnsCache + dnsForwarderCacheMu sync.Mutex + dnsForwarderCache map[string]DnsForwarder } func parseIpVersionPreference(prefer int) (uint16, error) { @@ -120,6 +116,8 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont fixedDomainTtl: option.FixedDomainTtl, dnsCacheMu: sync.Mutex{}, dnsCache: make(map[string]*DnsCache), + dnsForwarderCacheMu: sync.Mutex{}, + dnsForwarderCache: make(map[string]DnsForwarder), }, nil } @@ -558,130 +556,40 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte // the next recursive call. However, a connection cannot be closed twice. // We should set a connClosed flag to avoid it. var connClosed bool - var conn netproxy.Conn ctxDial, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) defer cancel() - switch dialArgument.l4proto { - case consts.L4ProtoStr_UDP: - // Get udp endpoint. - - // TODO: connection pool. - conn, err = dialArgument.bestDialer.DialContext( - ctxDial, - common.MagicNetwork("udp", dialArgument.mark, dialArgument.mptcp), - dialArgument.bestTarget.String(), - ) - if err != nil { - return fmt.Errorf("failed to dial '%v': %w", dialArgument.bestTarget, err) - } - defer func() { - if !connClosed { - conn.Close() - } - }() - - timeout := 5 * time.Second - _ = conn.SetDeadline(time.Now().Add(timeout)) - dnsReqCtx, cancelDnsReqCtx := context.WithTimeout(context.TODO(), timeout) - defer cancelDnsReqCtx() - go func() { - // Send DNS request every seconds. - for { - _, err = conn.Write(data) - if err != nil { - if c.log.IsLevelEnabled(logrus.DebugLevel) { - c.log.WithFields(logrus.Fields{ - "to": dialArgument.bestTarget.String(), - "pid": req.routingResult.Pid, - "pname": ProcessName2String(req.routingResult.Pname[:]), - "mac": Mac2String(req.routingResult.Mac[:]), - "from": req.realSrc.String(), - "network": networkType.String(), - "err": err.Error(), - }).Debugln("Failed to write UDP(DNS) packet request.") - } - return - } - select { - case <-dnsReqCtx.Done(): - return - case <-time.After(1 * time.Second): - } - } - }() - - // We can block here because we are in a coroutine. - respBuf := pool.GetFullCap(consts.EthernetMtu) - defer pool.Put(respBuf) - // Wait for response. - n, err := conn.Read(respBuf) + // get forwarder from cache + c.dnsForwarderCacheMu.Lock() + forwarder, ok := c.dnsForwarderCache[upstreamName] + if !ok { + forwarder, err = newDnsForwarder(upstream, *dialArgument) if err != nil { - if c.timeoutExceedCallback != nil { - c.timeoutExceedCallback(dialArgument, err) - } - return fmt.Errorf("failed to read from: %v (dialer: %v): %w", dialArgument.bestTarget, dialArgument.bestDialer.Property().Name, err) - } - var msg dnsmessage.Msg - if err = msg.Unpack(respBuf[:n]); err != nil { + c.dnsForwarderCacheMu.Unlock() return err } - respMsg = &msg - cancelDnsReqCtx() - - case consts.L4ProtoStr_TCP: - // We can block here because we are in a coroutine. + c.dnsForwarderCache[upstreamName] = forwarder + } + c.dnsForwarderCacheMu.Unlock() - conn, err = dialArgument.bestDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String()) - if err != nil { - return fmt.Errorf("failed to dial proxy to tcp: %w", err) - } - defer func() { - if !connClosed { - conn.Close() - } - }() - - _ = conn.SetDeadline(time.Now().Add(4900 * time.Millisecond)) - // We should write two byte length in the front of TCP DNS request. - bReq := pool.Get(2 + len(data)) - defer pool.Put(bReq) - binary.BigEndian.PutUint16(bReq, uint16(len(data))) - copy(bReq[2:], data) - _, err = conn.Write(bReq) - if err != nil { - return fmt.Errorf("failed to write DNS req: %w", err) + defer func() { + if !connClosed { + forwarder.Close() } + }() - // Read two byte length. - if _, err = io.ReadFull(conn, bReq[:2]); err != nil { - return fmt.Errorf("failed to read DNS resp payload length: %w", err) - } - respLen := int(binary.BigEndian.Uint16(bReq)) - // Try to reuse the buf. - var buf []byte - if len(bReq) < respLen { - buf = pool.Get(respLen) - defer pool.Put(buf) - } else { - buf = bReq - } - var n int - if n, err = io.ReadFull(conn, buf[:respLen]); err != nil { - return fmt.Errorf("failed to read DNS resp payload: %w", err) - } - var msg dnsmessage.Msg - if err = msg.Unpack(buf[:n]); err != nil { - return err - } - respMsg = &msg - default: - return fmt.Errorf("unexpected l4proto: %v", dialArgument.l4proto) + if err != nil { + return err + } + + respMsg, err = forwarder.ForwardDNS(ctxDial, data) + if err != nil { + return err } // Close conn before the recursive call. - conn.Close() + forwarder.Close() connClosed = true // Route response.