Skip to content

Commit

Permalink
refactor: add streamDNS func for DoT, tcp, DoQ
Browse files Browse the repository at this point in the history
  • Loading branch information
EkkoG committed Sep 26, 2024
1 parent 0054e6c commit 543f6cb
Showing 1 changed file with 45 additions and 58 deletions.
103 changes: 45 additions & 58 deletions control/dns_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,38 +698,11 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
_ = stream.Close()
}()

// We should write two byte length in the front of QUIC DNS request.
bReq := pool.Get(2 + len(data))
defer pool.Put(bReq)
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
copy(bReq[2:], data)
_, err = stream.Write(bReq)
msg, err := streamDNS(stream, data)
if err != nil {
return fmt.Errorf("failed to write DNS req: %w", err)
}

// Read two byte length.
if _, err = io.ReadFull(stream, bReq[:2]); err != nil {
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
}
respLen := int(binary.BigEndian.Uint16(bReq))
// Try to reuse the buf.
var buf []byte
if len(bReq) < respLen {
buf = pool.Get(respLen)
defer pool.Put(buf)
} else {
buf = bReq
}
var n int
if n, err = io.ReadFull(stream, buf[:respLen]); err != nil {
return fmt.Errorf("failed to read DNS resp payload: %w", err)
}
var msg dnsmessage.Msg
if err = msg.Unpack(buf[:n]); err != nil {
return err
}
respMsg = &msg
respMsg = msg
}

case consts.L4ProtoStr_TCP:
Expand All @@ -755,38 +728,11 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
_ = conn.SetDeadline(time.Now().Add(4900 * time.Millisecond))
switch upstream.Scheme {
case dns.UpstreamScheme_TCP, dns.UpstreamScheme_TLS, dns.UpstreamScheme_TCP_UDP:
// We should write two byte length in the front of TCP DNS request.
bReq := pool.Get(2 + len(data))
defer pool.Put(bReq)
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
copy(bReq[2:], data)
_, err = conn.Write(bReq)
msg, err := streamDNS(conn, data)
if err != nil {
return fmt.Errorf("failed to write DNS req: %w", err)
}

// Read two byte length.
if _, err = io.ReadFull(conn, bReq[:2]); err != nil {
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
}
respLen := int(binary.BigEndian.Uint16(bReq))
// Try to reuse the buf.
var buf []byte
if len(bReq) < respLen {
buf = pool.Get(respLen)
defer pool.Put(buf)
} else {
buf = bReq
}
var n int
if n, err = io.ReadFull(conn, buf[:respLen]); err != nil {
return fmt.Errorf("failed to read DNS resp payload: %w", err)
}
var msg dnsmessage.Msg
if err = msg.Unpack(buf[:n]); err != nil {
return err
}
respMsg = &msg
respMsg = msg
case dns.UpstreamScheme_HTTPS:

httpTransport := http.Transport{
Expand Down Expand Up @@ -923,3 +869,44 @@ func httpDNS(client *http.Client, target string, data []byte) (respMsg *dnsmessa
respMsg = &msg
return respMsg, nil
}

type stream interface {
io.Reader
io.Writer
}

func streamDNS(stream stream, data []byte) (respMsg *dnsmessage.Msg, err error) {
// We should write two byte length in the front of QUIC DNS request.
bReq := pool.Get(2 + len(data))
defer pool.Put(bReq)
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
copy(bReq[2:], data)
_, err = stream.Write(bReq)
if err != nil {
return nil, fmt.Errorf("failed to write DNS req: %w", err)
}

// Read two byte length.
if _, err = io.ReadFull(stream, bReq[:2]); err != nil {
return nil, fmt.Errorf("failed to read DNS resp payload length: %w", err)
}
respLen := int(binary.BigEndian.Uint16(bReq))
// Try to reuse the buf.
var buf []byte
if len(bReq) < respLen {
buf = pool.Get(respLen)
defer pool.Put(buf)
} else {
buf = bReq
}
var n int
if n, err = io.ReadFull(stream, buf[:respLen]); err != nil {
return nil, fmt.Errorf("failed to read DNS resp payload: %w", err)
}
var msg dnsmessage.Msg
if err = msg.Unpack(buf[:n]); err != nil {
return nil, err
}
return &msg, nil
}

0 comments on commit 543f6cb

Please sign in to comment.