Skip to content

Commit

Permalink
Implement remote DNS
Browse files Browse the repository at this point in the history
This commit implements the ability to resolve IP addresses remotely via
SOCKS5 or HTTP. It does so by intercepting DNS A and AAAA queries and
replying with IP addresses from a virtual IP space which is reserved
for DNS resolves. When an application performs a DNS lookup, a mapping
from the virtual IP to a DNS name is created and cached for a short
period of time. If the application subsequently connects to the virtual
IP, the dialer will use the DNS name from the mapping instead of the
destination IP address.
  • Loading branch information
blechschmidt committed Mar 20, 2023
1 parent 6cfc253 commit 894f087
Show file tree
Hide file tree
Showing 14 changed files with 385 additions and 26 deletions.
63 changes: 63 additions & 0 deletions component/remotedns/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package remotedns

import (
"net"
"sync"

"github.com/jellydator/ttlcache/v2"
)

var (
cache = ttlcache.NewCache()
mutex = sync.Mutex{}
ttl uint32 = 0
)

func insertNameIntoCache(ipVersion int, name string) net.IP {
mutex.Lock()
defer mutex.Unlock()
var result net.IP = nil
var ipnet *net.IPNet
var nextAddress *net.IP
var broadcastAddress net.IP
if ipVersion == 4 {
ipnet = ip4net
nextAddress = &ip4NextAddress
broadcastAddress = ip4BroadcastAddress
}

// Beginning from the pointer to the next most likely free IP, loop through the IP address space
// until either a free IP is found or the space is exhausted
passedBroadcastAddress := false
for result == nil {
if nextAddress.Equal(broadcastAddress) {
*nextAddress = getNetworkAddress(ipnet)
*nextAddress = incrementIp(ipnet.IP)

// We have seen the broadcast address twice during looping
// This means that our IP address space is exhausted
if passedBroadcastAddress {
return nil
}
passedBroadcastAddress = true
}

// This method is protected by a mutex, and we are only inserting elements into the cache here.
_, err := cache.Get((*nextAddress).String())
if err == ttlcache.ErrNotFound {
_ = cache.Set((*nextAddress).String(), name)
result = *nextAddress
} else if err != nil { // Should never happen
panic(nil)
}

*nextAddress = incrementIp(*nextAddress)
}

return result
}

func getCachedName(address net.IP) (interface{}, bool) {
name, err := cache.Get(address.String())
return name, err == nil
}
88 changes: 88 additions & 0 deletions component/remotedns/handle.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package remotedns

import (
"net"

"github.com/miekg/dns"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"

"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
)

func RewriteMetadata(metadata *M.Metadata) bool {
if !IsEnabled() {
return false
}
dstName, found := getCachedName(metadata.DstIP)
if !found {
return false
}
metadata.VirtualIP = metadata.DstIP
metadata.DstIP = nil
metadata.DstName = dstName.(string)
return true
}

func HandleDNSQuery(s *stack.Stack, id stack.TransportEndpointID, ptr stack.PacketBufferPtr) bool {
if !IsEnabled() {
return false
}

msg := dns.Msg{}
err := msg.Unpack(ptr.Data().AsRange().ToSlice())

// Ignore UDP packets that are not IP queries to a recursive resolver
if id.LocalPort != 53 || err != nil || len(msg.Question) != 1 || msg.Question[0].Qtype != dns.TypeA &&
msg.Question[0].Qtype != dns.TypeAAAA || msg.Question[0].Qclass != dns.ClassINET || !msg.RecursionDesired ||
msg.Response {
return false
}

qname := msg.Question[0].Name
qtype := msg.Question[0].Qtype

log.Infof("[DNS] query %s %s", dns.TypeToString[qtype], qname)

msg.RecursionDesired = false
msg.RecursionAvailable = true
var ip net.IP
if qtype == dns.TypeA {
rr := dns.A{}
ip = insertNameIntoCache(4, qname)
if ip == nil {
log.Warnf("[DNS] IP space exhausted")
return true
}
rr.A = ip
rr.Hdr.Name = qname
rr.Hdr.Ttl = ttl
rr.Hdr.Class = dns.ClassINET
rr.Hdr.Rrtype = qtype
msg.Answer = append(msg.Answer, &rr)
}

msg.Response = true
msg.RecursionAvailable = true

var wq waiter.Queue

ep, err2 := s.NewEndpoint(ptr.TransportProtocolNumber, ptr.NetworkProtocolNumber, &wq)
if err2 != nil {
return true
}
defer ep.Close()

ep.Bind(tcpip.FullAddress{NIC: ptr.NICID, Addr: id.LocalAddress, Port: id.LocalPort})
conn := gonet.NewUDPConn(s, &wq, ep)
defer conn.Close()
packed, err := msg.Pack()
if err != nil {
return true
}
_, _ = conn.WriteTo(packed, &net.UDPAddr{IP: net.IP(id.RemoteAddress), Port: int(id.RemotePort)})
return true
}
36 changes: 36 additions & 0 deletions component/remotedns/iputil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package remotedns

import "net"

func copyIP(ip net.IP) net.IP {
dup := make(net.IP, len(ip))
copy(dup, ip)
return dup
}

func incrementIp(ip net.IP) net.IP {
result := copyIP(ip)
for i := len(result) - 1; i >= 0; i-- {
result[i]++
if result[i] != 0 {
break
}
}
return result
}

func getBroadcastAddress(ipnet *net.IPNet) net.IP {
result := copyIP(ipnet.IP)
for i := 0; i < len(ipnet.IP); i++ {
result[i] |= ^ipnet.Mask[i]
}
return result
}

func getNetworkAddress(ipnet *net.IPNet) net.IP {
result := copyIP(ipnet.IP)
for i := 0; i < len(ipnet.IP); i++ {
result[i] &= ipnet.Mask[i]
}
return result
}
55 changes: 55 additions & 0 deletions component/remotedns/settings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package remotedns

import (
"errors"
"net"
"time"
)

// The timeout is somewhat arbitrary. For example, netcat will resolve the DNS
// names upon startup and then stick to the resolved IP address. A timeout of 1
// second may therefore be too low in cases where the first UDP packet is not
// sent immediately.
const (
minTimeout = 30 * time.Second
)

var (
enabled = false
ip4net *net.IPNet
ip4NextAddress net.IP
ip4BroadcastAddress net.IP
)

func IsEnabled() bool {
return enabled
}

func SetCacheTimeout(timeout time.Duration) error {
if timeout < minTimeout {
timeout = minTimeout
}
ttl = uint32(timeout.Seconds())

// Keep the value a little longer in cache than propagated via DNS
return cache.SetTTL(timeout + 10*time.Second)
}

func SetNetwork(ipnet *net.IPNet) error {
leadingOnes, _ := ipnet.Mask.Size()
if len(ipnet.IP) == 4 {
if leadingOnes > 30 {
return errors.New("IPv4 remote DNS subnet too small")
}
ip4net = ipnet
} else {
return errors.New("unsupported protocol")
}
return nil
}

func Enable() {
ip4NextAddress = incrementIp(getNetworkAddress(ip4net))
ip4BroadcastAddress = getBroadcastAddress(ip4net)
enabled = true
}
39 changes: 23 additions & 16 deletions core/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,38 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"

"github.com/xjasonlyu/tun2socks/v2/component/remotedns"
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/core/option"
)

func withUDPHandler(handle func(adapter.UDPConn), printf func(string, ...any)) option.Option {
return func(s *stack.Stack) error {
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq)
if err != nil {
printf("udp forwarder request %s:%d->%s:%d: %s",
id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
return
s.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, ptr stack.PacketBufferPtr) bool {
if remotedns.HandleDNSQuery(s, id, ptr) {
return true
}

conn := &udpConn{
UDPConn: gonet.NewUDPConn(s, &wq, ep),
id: id,
}
handle(conn)
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq)
if err != nil {
printf("udp forwarder request %s:%d->%s:%d: %s",
id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
return
}

conn := &udpConn{
UDPConn: gonet.NewUDPConn(s, &wq, ep),
id: id,
}
handle(conn)
})
return udpForwarder.HandlePacket(id, ptr)
})
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
return nil
}
}
Expand Down
38 changes: 38 additions & 0 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"

"github.com/xjasonlyu/tun2socks/v2/component/dialer"
"github.com/xjasonlyu/tun2socks/v2/component/remotedns"
"github.com/xjasonlyu/tun2socks/v2/core"
"github.com/xjasonlyu/tun2socks/v2/core/device"
"github.com/xjasonlyu/tun2socks/v2/core/option"
"github.com/xjasonlyu/tun2socks/v2/engine/mirror"
"github.com/xjasonlyu/tun2socks/v2/log"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/proxy/proto"
"github.com/xjasonlyu/tun2socks/v2/restapi"
"github.com/xjasonlyu/tun2socks/v2/tunnel"
)
Expand Down Expand Up @@ -151,6 +153,36 @@ func restAPI(k *Key) error {
return nil
}

func remoteDNS(k *Key, proxy proxy.Proxy) (err error) {
if !k.RemoteDNS {
return
}
if proxy.Proto() != proto.Socks5 && proxy.Proto() != proto.HTTP && proxy.Proto() != proto.Shadowsocks &&
proxy.Proto() != proto.Socks4 {
return errors.New("remote DNS not supported with this proxy protocol")
}

_, ipnet, err := net.ParseCIDR(k.RemoteDNSNetIPv4)
if err != nil {
return err
}

err = remotedns.SetNetwork(ipnet)
if err != nil {
return err
}

// Use the UDP timeout as cache timeout, so a DNS value is present in the cache for the duration of a connection
err = remotedns.SetCacheTimeout(k.UDPTimeout)
if err != nil {
return err
}

remotedns.Enable()
log.Infof("[DNS] Remote DNS enabled")
return
}

func netstack(k *Key) (err error) {
if k.Proxy == "" {
return errors.New("empty proxy")
Expand Down Expand Up @@ -205,5 +237,11 @@ func netstack(k *Key) (err error) {
_defaultDevice.Type(), _defaultDevice.Name(),
_defaultProxy.Proto(), _defaultProxy.Addr(),
)

err = remoteDNS(k, _defaultProxy)
if err != nil {
return err
}

return nil
}
2 changes: 2 additions & 0 deletions engine/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ type Key struct {
TCPSendBufferSize string `yaml:"tcp-send-buffer-size"`
TCPReceiveBufferSize string `yaml:"tcp-receive-buffer-size"`
UDPTimeout time.Duration `yaml:"udp-timeout"`
RemoteDNS bool `yaml:"remote-dns"`
RemoteDNSNetIPv4 string `yaml:"remote-dns-net-ipv4"`
}
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ require (
github.com/ajg/form v1.5.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/jellydator/ttlcache/v2 v2.11.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/miekg/dns v1.1.52 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.6.0 // indirect
golang.org/x/mod v0.7.0 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/tools v0.5.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
)
Loading

0 comments on commit 894f087

Please sign in to comment.