-
-
Notifications
You must be signed in to change notification settings - Fork 472
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit implements remote DNS. It introduces two new dependencies: ttlcache and dns. Remote DNS intercepts UDP DNS queries for A records on port 53. It replies with an unused IP address from an address pool, 198.18.0.0/15 by default. When obtaining a new address from the pool, tun2socks needs to memorize which name the address belongs to, so that when a client connects to the address, it can instruct the proxy to connect to the FQDN. To implement this IP to name mapping, ttlcache is used. To prevent using multiple addresses for the same name, ttlcache is also used to implement a name to IP mapping. If an IP address is already cached for a name, that address is returned instread. When building a connection, the connection metadata is inspected and if the destination address is associated with a DNS name, the proxy is instructed to use this name instead of the IP address.
- Loading branch information
1 parent
63f71e0
commit 6a0dc7e
Showing
14 changed files
with
379 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package remotedns | ||
|
||
import ( | ||
"net" | ||
|
||
"github.com/miekg/dns" | ||
"gvisor.dev/gvisor/pkg/tcpip" | ||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" | ||
"gvisor.dev/gvisor/pkg/tcpip/stack" | ||
"gvisor.dev/gvisor/pkg/waiter" | ||
|
||
"github.com/xjasonlyu/tun2socks/v2/log" | ||
M "github.com/xjasonlyu/tun2socks/v2/metadata" | ||
) | ||
|
||
func RewriteMetadata(metadata *M.Metadata) bool { | ||
if !IsEnabled() { | ||
return false | ||
} | ||
dstName, found := getCachedName(metadata.DstIP) | ||
if !found { | ||
return false | ||
} | ||
metadata.DstIP = nil | ||
metadata.DstName = dstName | ||
return true | ||
} | ||
|
||
func HandleDNSQuery(s *stack.Stack, id stack.TransportEndpointID, ptr *stack.PacketBuffer) bool { | ||
if !IsEnabled() { | ||
return false | ||
} | ||
|
||
msg := dns.Msg{} | ||
err := msg.Unpack(ptr.Data().AsRange().ToSlice()) | ||
|
||
isCorrectEndpoint := id.LocalPort == 53 && (listenAddress.Equal(id.LocalAddress.AsSlice()) || listenAddress.IsUnspecified()) | ||
|
||
// Ignore UDP packets that are not matching the listen address and are not recursive queries | ||
if !isCorrectEndpoint || err != nil || len(msg.Question) != 1 || msg.Question[0].Qtype != dns.TypeA && | ||
msg.Question[0].Qtype != dns.TypeAAAA || msg.Question[0].Qclass != dns.ClassINET || !msg.RecursionDesired || | ||
msg.Response { | ||
return false | ||
} | ||
|
||
qname := msg.Question[0].Name | ||
qtype := msg.Question[0].Qtype | ||
|
||
log.Debugf("[DNS] query %s %s", dns.TypeToString[qtype], qname) | ||
|
||
var ip net.IP | ||
if qtype == dns.TypeA { | ||
rr := dns.A{} | ||
ip = findOrInsertNameAndReturnIP(4, qname) | ||
if ip == nil { | ||
log.Warnf("[DNS] IP space exhausted") | ||
return true | ||
} | ||
rr.A = ip | ||
rr.Hdr.Name = qname | ||
rr.Hdr.Ttl = dnsTTL | ||
rr.Hdr.Class = dns.ClassINET | ||
rr.Hdr.Rrtype = qtype | ||
msg.Answer = append(msg.Answer, &rr) | ||
} | ||
|
||
msg.Response = true | ||
msg.RecursionDesired = false | ||
msg.RecursionAvailable = true | ||
|
||
var wq waiter.Queue | ||
|
||
ep, err2 := s.NewEndpoint(ptr.TransportProtocolNumber, ptr.NetworkProtocolNumber, &wq) | ||
if err2 != nil { | ||
return true | ||
} | ||
defer ep.Close() | ||
|
||
ep.Bind(tcpip.FullAddress{NIC: ptr.NICID, Addr: id.LocalAddress, Port: id.LocalPort}) | ||
conn := gonet.NewUDPConn(&wq, ep) | ||
defer conn.Close() | ||
packed, err := msg.Pack() | ||
if err != nil { | ||
return true | ||
} | ||
_, _ = conn.WriteTo(packed, &net.UDPAddr{IP: id.RemoteAddress.AsSlice(), Port: int(id.RemotePort)}) | ||
return true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package remotedns | ||
|
||
import "net" | ||
|
||
func copyIP(ip net.IP) net.IP { | ||
dup := make(net.IP, len(ip)) | ||
copy(dup, ip) | ||
return dup | ||
} | ||
|
||
func incrementIP(ip net.IP) net.IP { | ||
result := copyIP(ip) | ||
for i := len(result) - 1; i >= 0; i-- { | ||
result[i]++ | ||
if result[i] != 0 { | ||
break | ||
} | ||
} | ||
return result | ||
} | ||
|
||
func getBroadcastAddress(ipnet *net.IPNet) net.IP { | ||
result := copyIP(ipnet.IP) | ||
for i := 0; i < len(ipnet.IP); i++ { | ||
result[i] |= ^ipnet.Mask[i] | ||
} | ||
return result | ||
} | ||
|
||
func getNetworkAddress(ipnet *net.IPNet) net.IP { | ||
result := copyIP(ipnet.IP) | ||
for i := 0; i < len(ipnet.IP); i++ { | ||
result[i] &= ipnet.Mask[i] | ||
} | ||
return result | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package remotedns | ||
|
||
import ( | ||
"net" | ||
"sync" | ||
"time" | ||
|
||
"github.com/jellydator/ttlcache/v3" | ||
) | ||
|
||
var ( | ||
ipToName = ttlcache.New[string, string]() | ||
nameToIP = ttlcache.New[string, net.IP]() | ||
mutex = sync.Mutex{} | ||
|
||
ip4NextAddress net.IP | ||
ip4BroadcastAddress net.IP | ||
) | ||
|
||
func findOrInsertNameAndReturnIP(ipVersion int, name string) net.IP { | ||
if ipVersion != 4 { | ||
panic("Method not implemented for IPv6") | ||
} | ||
mutex.Lock() | ||
defer mutex.Unlock() | ||
var result net.IP = nil | ||
var ipnet *net.IPNet | ||
var nextAddress *net.IP | ||
var broadcastAddress net.IP | ||
if ipVersion == 4 { | ||
ipnet = ip4net | ||
nextAddress = &ip4NextAddress | ||
broadcastAddress = ip4BroadcastAddress | ||
} | ||
|
||
nameToIP.DeleteExpired() | ||
ipToName.DeleteExpired() | ||
|
||
entry := nameToIP.Get(name) | ||
if entry != nil { | ||
ip := entry.Value() | ||
ipToName.Touch(ip.String()) | ||
return ip | ||
} | ||
|
||
// Beginning from the pointer to the next most likely free IP, loop through the IP address space | ||
// until either a free IP is found or the space is exhausted | ||
passedBroadcastAddress := false | ||
for result == nil { | ||
if nextAddress.Equal(broadcastAddress) { | ||
*nextAddress = getNetworkAddress(ipnet) | ||
*nextAddress = incrementIP(ipnet.IP) | ||
|
||
// We have seen the broadcast address twice during looping | ||
// This means that our IP address space is exhausted | ||
if passedBroadcastAddress { | ||
return nil | ||
} | ||
passedBroadcastAddress = true | ||
} | ||
|
||
// Skip the listen address if that is inside our pool range | ||
if nextAddress.Equal(listenAddress) { | ||
*nextAddress = incrementIP(*nextAddress) | ||
continue | ||
} | ||
|
||
// Do not touch entries that exist in the cache already. | ||
hasKey := ipToName.Has((*nextAddress).String()) | ||
if !hasKey { | ||
_ = ipToName.Set((*nextAddress).String(), name, time.Duration(dnsTTL)*time.Second+cacheGraceTime) | ||
_ = nameToIP.Set(name, *nextAddress, time.Duration(dnsTTL)*time.Second+cacheGraceTime) | ||
result = *nextAddress | ||
} | ||
|
||
*nextAddress = incrementIP(*nextAddress) | ||
} | ||
|
||
return result | ||
} | ||
|
||
func getCachedName(address net.IP) (string, bool) { | ||
mutex.Lock() | ||
defer mutex.Unlock() | ||
entry := ipToName.Get(address.String()) | ||
if entry == nil { | ||
return "", false | ||
} | ||
nameToIP.Touch(entry.Value()) | ||
return entry.Value(), true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package remotedns | ||
|
||
import ( | ||
"errors" | ||
"net" | ||
"time" | ||
) | ||
|
||
// Timeouts are somewhat arbitrary. For example, netcat will resolve the DNS | ||
// names upon startup and then stick to the resolved IP address. A timeout of 1 | ||
// second may therefore be too low in cases where the first UDP packet is not | ||
// sent immediately. | ||
// cacheGraceTime defines how long an entry should still be retained in the cache | ||
// after being resolved by DNS. | ||
const ( | ||
cacheGraceTime = 30 * time.Second | ||
) | ||
|
||
var ( | ||
enabled = false | ||
dnsTTL uint32 = 0 | ||
ip4net *net.IPNet | ||
listenAddress net.IP | ||
) | ||
|
||
func IsEnabled() bool { | ||
return enabled | ||
} | ||
|
||
func SetDNSTTL(timeout time.Duration) { | ||
dnsTTL = uint32(timeout.Seconds()) | ||
} | ||
|
||
func SetListenAddress(ip net.IP) { | ||
listenAddress = ip | ||
} | ||
|
||
func SetNetwork(ipnet *net.IPNet) error { | ||
leadingOnes, _ := ipnet.Mask.Size() | ||
if len(ipnet.IP) == 4 { | ||
if leadingOnes > 30 { | ||
return errors.New("IPv4 remote DNS subnet too small") | ||
} | ||
ip4net = ipnet | ||
} else { | ||
return errors.New("unsupported protocol") | ||
} | ||
return nil | ||
} | ||
|
||
func Enable() { | ||
ip4NextAddress = incrementIP(getNetworkAddress(ip4net)) | ||
ip4BroadcastAddress = getBroadcastAddress(ip4net) | ||
enabled = true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.