Skip to content

Commit

Permalink
Implement remote DNS
Browse files Browse the repository at this point in the history
This commit implements remote DNS. It introduces two new dependencies:
ttlcache and dns.

Remote DNS intercepts UDP DNS queries for A records on port 53. It
replies with an unused IP address from an address pool, 198.18.0.0/15 by
default. When obtaining a new address from the pool, tun2socks needs to
memorize which name the address belongs to, so that when a client
connects to the address, it can instruct the proxy to connect to the
FQDN. To implement this IP to name mapping, ttlcache is used.
To prevent using multiple addresses for the same name, ttlcache is also
used to implement a name to IP mapping. If an IP address is already
cached for a name, that address is returned instread.
When building a connection, the connection metadata is inspected and if
the destination address is associated with a DNS name, the proxy is
instructed to use this name instead of the IP address.
  • Loading branch information
blechschmidt committed Jul 16, 2024
1 parent 63f71e0 commit 6a0dc7e
Show file tree
Hide file tree
Showing 14 changed files with 379 additions and 21 deletions.
88 changes: 88 additions & 0 deletions component/remotedns/handle.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package remotedns

import (
"net"

"github.com/miekg/dns"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"

"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
)

func RewriteMetadata(metadata *M.Metadata) bool {
if !IsEnabled() {
return false
}
dstName, found := getCachedName(metadata.DstIP)
if !found {
return false
}
metadata.DstIP = nil
metadata.DstName = dstName
return true
}

func HandleDNSQuery(s *stack.Stack, id stack.TransportEndpointID, ptr *stack.PacketBuffer) bool {
if !IsEnabled() {
return false
}

msg := dns.Msg{}
err := msg.Unpack(ptr.Data().AsRange().ToSlice())

isCorrectEndpoint := id.LocalPort == 53 && (listenAddress.Equal(id.LocalAddress.AsSlice()) || listenAddress.IsUnspecified())

// Ignore UDP packets that are not matching the listen address and are not recursive queries
if !isCorrectEndpoint || err != nil || len(msg.Question) != 1 || msg.Question[0].Qtype != dns.TypeA &&
msg.Question[0].Qtype != dns.TypeAAAA || msg.Question[0].Qclass != dns.ClassINET || !msg.RecursionDesired ||
msg.Response {
return false
}

qname := msg.Question[0].Name
qtype := msg.Question[0].Qtype

log.Debugf("[DNS] query %s %s", dns.TypeToString[qtype], qname)

var ip net.IP
if qtype == dns.TypeA {
rr := dns.A{}
ip = findOrInsertNameAndReturnIP(4, qname)
if ip == nil {
log.Warnf("[DNS] IP space exhausted")
return true
}
rr.A = ip
rr.Hdr.Name = qname
rr.Hdr.Ttl = dnsTTL
rr.Hdr.Class = dns.ClassINET
rr.Hdr.Rrtype = qtype
msg.Answer = append(msg.Answer, &rr)
}

msg.Response = true
msg.RecursionDesired = false
msg.RecursionAvailable = true

var wq waiter.Queue

ep, err2 := s.NewEndpoint(ptr.TransportProtocolNumber, ptr.NetworkProtocolNumber, &wq)
if err2 != nil {
return true
}
defer ep.Close()

ep.Bind(tcpip.FullAddress{NIC: ptr.NICID, Addr: id.LocalAddress, Port: id.LocalPort})
conn := gonet.NewUDPConn(&wq, ep)
defer conn.Close()
packed, err := msg.Pack()
if err != nil {
return true
}
_, _ = conn.WriteTo(packed, &net.UDPAddr{IP: id.RemoteAddress.AsSlice(), Port: int(id.RemotePort)})
return true
}
36 changes: 36 additions & 0 deletions component/remotedns/iputil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package remotedns

import "net"

func copyIP(ip net.IP) net.IP {
dup := make(net.IP, len(ip))
copy(dup, ip)
return dup
}

func incrementIP(ip net.IP) net.IP {
result := copyIP(ip)
for i := len(result) - 1; i >= 0; i-- {
result[i]++
if result[i] != 0 {
break
}
}
return result
}

func getBroadcastAddress(ipnet *net.IPNet) net.IP {
result := copyIP(ipnet.IP)
for i := 0; i < len(ipnet.IP); i++ {
result[i] |= ^ipnet.Mask[i]
}
return result
}

func getNetworkAddress(ipnet *net.IPNet) net.IP {
result := copyIP(ipnet.IP)
for i := 0; i < len(ipnet.IP); i++ {
result[i] &= ipnet.Mask[i]
}
return result
}
91 changes: 91 additions & 0 deletions component/remotedns/pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package remotedns

import (
"net"
"sync"
"time"

"github.com/jellydator/ttlcache/v3"
)

var (
ipToName = ttlcache.New[string, string]()
nameToIP = ttlcache.New[string, net.IP]()
mutex = sync.Mutex{}

ip4NextAddress net.IP
ip4BroadcastAddress net.IP
)

func findOrInsertNameAndReturnIP(ipVersion int, name string) net.IP {
if ipVersion != 4 {
panic("Method not implemented for IPv6")
}
mutex.Lock()
defer mutex.Unlock()
var result net.IP = nil
var ipnet *net.IPNet
var nextAddress *net.IP
var broadcastAddress net.IP
if ipVersion == 4 {
ipnet = ip4net
nextAddress = &ip4NextAddress
broadcastAddress = ip4BroadcastAddress
}

nameToIP.DeleteExpired()
ipToName.DeleteExpired()

entry := nameToIP.Get(name)
if entry != nil {
ip := entry.Value()
ipToName.Touch(ip.String())
return ip
}

// Beginning from the pointer to the next most likely free IP, loop through the IP address space
// until either a free IP is found or the space is exhausted
passedBroadcastAddress := false
for result == nil {
if nextAddress.Equal(broadcastAddress) {
*nextAddress = getNetworkAddress(ipnet)
*nextAddress = incrementIP(ipnet.IP)

// We have seen the broadcast address twice during looping
// This means that our IP address space is exhausted
if passedBroadcastAddress {
return nil
}
passedBroadcastAddress = true
}

// Skip the listen address if that is inside our pool range
if nextAddress.Equal(listenAddress) {
*nextAddress = incrementIP(*nextAddress)
continue
}

// Do not touch entries that exist in the cache already.
hasKey := ipToName.Has((*nextAddress).String())
if !hasKey {
_ = ipToName.Set((*nextAddress).String(), name, time.Duration(dnsTTL)*time.Second+cacheGraceTime)
_ = nameToIP.Set(name, *nextAddress, time.Duration(dnsTTL)*time.Second+cacheGraceTime)
result = *nextAddress
}

*nextAddress = incrementIP(*nextAddress)
}

return result
}

func getCachedName(address net.IP) (string, bool) {
mutex.Lock()
defer mutex.Unlock()
entry := ipToName.Get(address.String())
if entry == nil {
return "", false
}
nameToIP.Touch(entry.Value())
return entry.Value(), true
}
55 changes: 55 additions & 0 deletions component/remotedns/settings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package remotedns

import (
"errors"
"net"
"time"
)

// Timeouts are somewhat arbitrary. For example, netcat will resolve the DNS
// names upon startup and then stick to the resolved IP address. A timeout of 1
// second may therefore be too low in cases where the first UDP packet is not
// sent immediately.
// cacheGraceTime defines how long an entry should still be retained in the cache
// after being resolved by DNS.
const (
cacheGraceTime = 30 * time.Second
)

var (
enabled = false
dnsTTL uint32 = 0
ip4net *net.IPNet
listenAddress net.IP
)

func IsEnabled() bool {
return enabled
}

func SetDNSTTL(timeout time.Duration) {
dnsTTL = uint32(timeout.Seconds())
}

func SetListenAddress(ip net.IP) {
listenAddress = ip
}

func SetNetwork(ipnet *net.IPNet) error {
leadingOnes, _ := ipnet.Mask.Size()
if len(ipnet.IP) == 4 {
if leadingOnes > 30 {
return errors.New("IPv4 remote DNS subnet too small")
}
ip4net = ipnet
} else {
return errors.New("unsupported protocol")
}
return nil
}

func Enable() {
ip4NextAddress = incrementIP(getNetworkAddress(ip4net))
ip4BroadcastAddress = getBroadcastAddress(ip4net)
enabled = true
}
39 changes: 23 additions & 16 deletions core/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,38 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"

"github.com/xjasonlyu/tun2socks/v2/component/remotedns"
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/core/option"
)

func withUDPHandler(handle func(adapter.UDPConn)) option.Option {
return func(s *stack.Stack) error {
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq)
if err != nil {
glog.Debugf("forward udp request: %s:%d->%s:%d: %s",
id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
return
s.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, ptr *stack.PacketBuffer) bool {
if remotedns.HandleDNSQuery(s, id, ptr) {
return true
}

conn := &udpConn{
UDPConn: gonet.NewUDPConn(&wq, ep),
id: id,
}
handle(conn)
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq)
if err != nil {
glog.Debugf("forward udp request %s:%d->%s:%d: %s",
id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
return
}

conn := &udpConn{
UDPConn: gonet.NewUDPConn(&wq, ep),
id: id,
}
handle(conn)
})
return udpForwarder.HandlePacket(id, ptr)
})
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
return nil
}
}
Expand Down
Loading

0 comments on commit 6a0dc7e

Please sign in to comment.