diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go new file mode 100644 index 00000000..e0f5d20c --- /dev/null +++ b/common/cache/lrucache.go @@ -0,0 +1,223 @@ +package cache + +// Modified by https://github.com/die-net/lrucache + +import ( + "container/list" + "sync" + "time" +) + +// Option is part of Functional Options Pattern +type Option func(*LruCache) + +// EvictCallback is used to get a callback when a cache entry is evicted +type EvictCallback = func(key any, value any) + +// WithEvict set the evict callback +func WithEvict(cb EvictCallback) Option { + return func(l *LruCache) { + l.onEvict = cb + } +} + +// WithUpdateAgeOnGet update expires when Get element +func WithUpdateAgeOnGet() Option { + return func(l *LruCache) { + l.updateAgeOnGet = true + } +} + +// WithAge defined element max age (second) +func WithAge(maxAge int64) Option { + return func(l *LruCache) { + l.maxAge = maxAge + } +} + +// WithSize defined max length of LruCache +func WithSize(maxSize int) Option { + return func(l *LruCache) { + l.maxSize = maxSize + } +} + +// WithStale decide whether Stale return is enabled. +// If this feature is enabled, element will not get Evicted according to `WithAge`. +func WithStale(stale bool) Option { + return func(l *LruCache) { + l.staleReturn = stale + } +} + +// LruCache is a thread-safe, in-memory lru-cache that evicts the +// least recently used entries from memory when (if set) the entries are +// older than maxAge (in seconds). Use the New constructor to create one. +type LruCache struct { + maxAge int64 + maxSize int + mu sync.Mutex + cache map[any]*list.Element + lru *list.List // Front is least-recent + updateAgeOnGet bool + staleReturn bool + onEvict EvictCallback +} + +// New creates an LruCache +func New(options ...Option) *LruCache { + lc := &LruCache{ + lru: list.New(), + cache: make(map[any]*list.Element), + } + + for _, option := range options { + option(lc) + } + + return lc +} + +// Get returns the any representation of a cached response and a bool +// set to true if the key was found. +func (c *LruCache) Get(key any) (any, bool) { + entry := c.get(key) + if entry == nil { + return nil, false + } + value := entry.value + + return value, true +} + +// GetWithExpire returns the any representation of a cached response, +// a time.Time Give expected expires, +// and a bool set to true if the key was found. +// This method will NOT check the maxAge of element and will NOT update the expires. +func (c *LruCache) GetWithExpire(key any) (any, time.Time, bool) { + entry := c.get(key) + if entry == nil { + return nil, time.Time{}, false + } + + return entry.value, time.Unix(entry.expires, 0), true +} + +// Exist returns if key exist in cache but not put item to the head of linked list +func (c *LruCache) Exist(key any) bool { + c.mu.Lock() + defer c.mu.Unlock() + + _, ok := c.cache[key] + return ok +} + +// Set stores the any representation of a response for a given key. +func (c *LruCache) Set(key any, value any) { + expires := int64(0) + if c.maxAge > 0 { + expires = time.Now().Unix() + c.maxAge + } + c.SetWithExpire(key, value, time.Unix(expires, 0)) +} + +// SetWithExpire stores the any representation of a response for a given key and given expires. +// The expires time will round to second. +func (c *LruCache) SetWithExpire(key any, value any, expires time.Time) { + c.mu.Lock() + defer c.mu.Unlock() + + if le, ok := c.cache[key]; ok { + c.lru.MoveToBack(le) + e := le.Value.(*entry) + e.value = value + e.expires = expires.Unix() + } else { + e := &entry{key: key, value: value, expires: expires.Unix()} + c.cache[key] = c.lru.PushBack(e) + + if c.maxSize > 0 { + if len := c.lru.Len(); len > c.maxSize { + c.deleteElement(c.lru.Front()) + } + } + } + + c.maybeDeleteOldest() +} + +// CloneTo clone and overwrite elements to another LruCache +func (c *LruCache) CloneTo(n *LruCache) { + c.mu.Lock() + defer c.mu.Unlock() + + n.mu.Lock() + defer n.mu.Unlock() + + n.lru = list.New() + n.cache = make(map[any]*list.Element) + + for e := c.lru.Front(); e != nil; e = e.Next() { + elm := e.Value.(*entry) + n.cache[elm.key] = n.lru.PushBack(elm) + } +} + +func (c *LruCache) get(key any) *entry { + c.mu.Lock() + defer c.mu.Unlock() + + le, ok := c.cache[key] + if !ok { + return nil + } + + if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { + c.deleteElement(le) + c.maybeDeleteOldest() + + return nil + } + + c.lru.MoveToBack(le) + entry := le.Value.(*entry) + if c.maxAge > 0 && c.updateAgeOnGet { + entry.expires = time.Now().Unix() + c.maxAge + } + return entry +} + +// Delete removes the value associated with a key. +func (c *LruCache) Delete(key any) { + c.mu.Lock() + + if le, ok := c.cache[key]; ok { + c.deleteElement(le) + } + + c.mu.Unlock() +} + +func (c *LruCache) maybeDeleteOldest() { + if !c.staleReturn && c.maxAge > 0 { + now := time.Now().Unix() + for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() { + c.deleteElement(le) + } + } +} + +func (c *LruCache) deleteElement(le *list.Element) { + c.lru.Remove(le) + e := le.Value.(*entry) + delete(c.cache, e.key) + if c.onEvict != nil { + c.onEvict(e.key, e.value) + } +} + +type entry struct { + key any + value any + expires int64 +} diff --git a/common/cache/lrucache_test.go b/common/cache/lrucache_test.go new file mode 100644 index 00000000..1a09975d --- /dev/null +++ b/common/cache/lrucache_test.go @@ -0,0 +1,183 @@ +package cache + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var entries = []struct { + key string + value string +}{ + {"1", "one"}, + {"2", "two"}, + {"3", "three"}, + {"4", "four"}, + {"5", "five"}, +} + +func TestLRUCache(t *testing.T) { + c := New() + + for _, e := range entries { + c.Set(e.key, e.value) + } + + c.Delete("missing") + _, ok := c.Get("missing") + assert.False(t, ok) + + for _, e := range entries { + value, ok := c.Get(e.key) + if assert.True(t, ok) { + assert.Equal(t, e.value, value.(string)) + } + } + + for _, e := range entries { + c.Delete(e.key) + + _, ok := c.Get(e.key) + assert.False(t, ok) + } +} + +func TestLRUMaxAge(t *testing.T) { + c := New(WithAge(86400)) + + now := time.Now().Unix() + expected := now + 86400 + + // Add one expired entry + c.Set("foo", "bar") + c.lru.Back().Value.(*entry).expires = now + + // Reset + c.Set("foo", "bar") + e := c.lru.Back().Value.(*entry) + assert.True(t, e.expires >= now) + c.lru.Back().Value.(*entry).expires = now + + // Set a few and verify expiration times + for _, s := range entries { + c.Set(s.key, s.value) + e := c.lru.Back().Value.(*entry) + assert.True(t, e.expires >= expected && e.expires <= expected+10) + } + + // Make sure we can get them all + for _, s := range entries { + _, ok := c.Get(s.key) + assert.True(t, ok) + } + + // Expire all entries + for _, s := range entries { + le, ok := c.cache[s.key] + if assert.True(t, ok) { + le.Value.(*entry).expires = now + } + } + + // Get one expired entry, which should clear all expired entries + _, ok := c.Get("3") + assert.False(t, ok) + assert.Equal(t, c.lru.Len(), 0) +} + +func TestLRUpdateOnGet(t *testing.T) { + c := New(WithAge(86400), WithUpdateAgeOnGet()) + + now := time.Now().Unix() + expires := now + 86400/2 + + // Add one expired entry + c.Set("foo", "bar") + c.lru.Back().Value.(*entry).expires = expires + + _, ok := c.Get("foo") + assert.True(t, ok) + assert.True(t, c.lru.Back().Value.(*entry).expires > expires) +} + +func TestMaxSize(t *testing.T) { + c := New(WithSize(2)) + // Add one expired entry + c.Set("foo", "bar") + _, ok := c.Get("foo") + assert.True(t, ok) + + c.Set("bar", "foo") + c.Set("baz", "foo") + + _, ok = c.Get("foo") + assert.False(t, ok) +} + +func TestExist(t *testing.T) { + c := New(WithSize(1)) + c.Set(1, 2) + assert.True(t, c.Exist(1)) + c.Set(2, 3) + assert.False(t, c.Exist(1)) +} + +func TestEvict(t *testing.T) { + temp := 0 + evict := func(key any, value any) { + temp = key.(int) + value.(int) + } + + c := New(WithEvict(evict), WithSize(1)) + c.Set(1, 2) + c.Set(2, 3) + + assert.Equal(t, temp, 3) +} + +func TestSetWithExpire(t *testing.T) { + c := New(WithAge(1)) + now := time.Now().Unix() + + tenSecBefore := time.Unix(now-10, 0) + c.SetWithExpire(1, 2, tenSecBefore) + + // res is expected not to exist, and expires should be empty time.Time + res, expires, exist := c.GetWithExpire(1) + assert.Equal(t, nil, res) + assert.Equal(t, time.Time{}, expires) + assert.Equal(t, false, exist) +} + +func TestStale(t *testing.T) { + c := New(WithAge(1), WithStale(true)) + now := time.Now().Unix() + + tenSecBefore := time.Unix(now-10, 0) + c.SetWithExpire(1, 2, tenSecBefore) + + res, expires, exist := c.GetWithExpire(1) + assert.Equal(t, 2, res) + assert.Equal(t, tenSecBefore, expires) + assert.Equal(t, true, exist) +} + +func TestCloneTo(t *testing.T) { + o := New(WithSize(10)) + o.Set("1", 1) + o.Set("2", 2) + + n := New(WithSize(2)) + n.Set("3", 3) + n.Set("4", 4) + + o.CloneTo(n) + + assert.False(t, n.Exist("3")) + assert.True(t, n.Exist("1")) + + n.Set("5", 5) + assert.False(t, n.Exist("1")) +} diff --git a/common/sockopt/reuseaddr_linux.go b/common/sockopt/reuseaddr_linux.go new file mode 100644 index 00000000..a1d19bfd --- /dev/null +++ b/common/sockopt/reuseaddr_linux.go @@ -0,0 +1,19 @@ +package sockopt + +import ( + "net" + "syscall" +) + +func UDPReuseaddr(c *net.UDPConn) (err error) { + rc, err := c.SyscallConn() + if err != nil { + return + } + + rc.Control(func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + }) + + return +} diff --git a/common/sockopt/reuseaddr_other.go b/common/sockopt/reuseaddr_other.go new file mode 100644 index 00000000..04fc8ed7 --- /dev/null +++ b/common/sockopt/reuseaddr_other.go @@ -0,0 +1,11 @@ +//go:build !linux + +package sockopt + +import ( + "net" +) + +func UDPReuseaddr(c *net.UDPConn) (err error) { + return +} diff --git a/component/fakeip/memory.go b/component/fakeip/memory.go new file mode 100644 index 00000000..69ba70e4 --- /dev/null +++ b/component/fakeip/memory.go @@ -0,0 +1,69 @@ +package fakeip + +import ( + "net" + + "github.com/xjasonlyu/tun2socks/v2/common/cache" +) + +type memoryStore struct { + cache *cache.LruCache +} + +// GetByHost implements store.GetByHost +func (m *memoryStore) GetByHost(host string) (net.IP, bool) { + if elm, exist := m.cache.Get(host); exist { + ip := elm.(net.IP) + + // ensure ip --> host on head of linked list + m.cache.Get(ipToUint(ip.To4())) + return ip, true + } + + return nil, false +} + +// PutByHost implements store.PutByHost +func (m *memoryStore) PutByHost(host string, ip net.IP) { + m.cache.Set(host, ip) +} + +// GetByIP implements store.GetByIP +func (m *memoryStore) GetByIP(ip net.IP) (string, bool) { + if elm, exist := m.cache.Get(ipToUint(ip.To4())); exist { + host := elm.(string) + + // ensure host --> ip on head of linked list + m.cache.Get(host) + return host, true + } + + return "", false +} + +// PutByIP implements store.PutByIP +func (m *memoryStore) PutByIP(ip net.IP, host string) { + m.cache.Set(ipToUint(ip.To4()), host) +} + +// DelByIP implements store.DelByIP +func (m *memoryStore) DelByIP(ip net.IP) { + ipNum := ipToUint(ip.To4()) + if elm, exist := m.cache.Get(ipNum); exist { + m.cache.Delete(elm.(string)) + } + m.cache.Delete(ipNum) +} + +// Exist implements store.Exist +func (m *memoryStore) Exist(ip net.IP) bool { + return m.cache.Exist(ipToUint(ip.To4())) +} + +// CloneTo implements store.CloneTo +// only for memoryStore to memoryStore +func (m *memoryStore) CloneTo(store store) { + if ms, ok := store.(*memoryStore); ok { + m.cache.CloneTo(ms.cache) + } +} diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go new file mode 100644 index 00000000..134f6eb1 --- /dev/null +++ b/component/fakeip/pool.go @@ -0,0 +1,169 @@ +package fakeip + +import ( + "errors" + "net" + "strings" + "sync" + + "github.com/xjasonlyu/tun2socks/v2/common/cache" + "github.com/xjasonlyu/tun2socks/v2/component/trie" +) + +type store interface { + GetByHost(host string) (net.IP, bool) + PutByHost(host string, ip net.IP) + GetByIP(ip net.IP) (string, bool) + PutByIP(ip net.IP, host string) + DelByIP(ip net.IP) + Exist(ip net.IP) bool + CloneTo(store) +} + +// Pool is an implementation about fake ip generator without storage +type Pool struct { + max uint32 + min uint32 + gateway uint32 + offset uint32 + mux sync.Mutex + host *trie.DomainTrie + ipnet *net.IPNet + store store +} + +// Lookup return a fake ip with host +func (p *Pool) Lookup(host string) net.IP { + p.mux.Lock() + defer p.mux.Unlock() + + // RFC4343: DNS Case Insensitive, we SHOULD return result with all cases. + host = strings.ToLower(host) + if ip, exist := p.store.GetByHost(host); exist { + return ip + } + + ip := p.get(host) + p.store.PutByHost(host, ip) + return ip +} + +// LookBack return host with the fake ip +func (p *Pool) LookBack(ip net.IP) (string, bool) { + p.mux.Lock() + defer p.mux.Unlock() + + if ip = ip.To4(); ip == nil { + return "", false + } + + return p.store.GetByIP(ip) +} + +// ShouldSkipped return if domain should be skipped +func (p *Pool) ShouldSkipped(domain string) bool { + if p.host == nil { + return false + } + return p.host.Search(domain) != nil +} + +// Exist returns if given ip exists in fake-ip pool +func (p *Pool) Exist(ip net.IP) bool { + p.mux.Lock() + defer p.mux.Unlock() + + if ip = ip.To4(); ip == nil { + return false + } + + return p.store.Exist(ip) +} + +// Gateway return gateway ip +func (p *Pool) Gateway() net.IP { + return uintToIP(p.gateway) +} + +// IPNet return raw ipnet +func (p *Pool) IPNet() *net.IPNet { + return p.ipnet +} + +// CloneFrom clone cache from old pool +func (p *Pool) CloneFrom(o *Pool) { + o.store.CloneTo(p.store) +} + +func (p *Pool) get(host string) net.IP { + current := p.offset + for { + ip := uintToIP(p.min + p.offset) + if !p.store.Exist(ip) { + break + } + + p.offset = (p.offset + 1) % (p.max - p.min) + // Avoid infinite loops + if p.offset == current { + p.offset = (p.offset + 1) % (p.max - p.min) + ip := uintToIP(p.min + p.offset) + p.store.DelByIP(ip) + break + } + } + ip := uintToIP(p.min + p.offset) + p.store.PutByIP(ip, host) + return ip +} + +func ipToUint(ip net.IP) uint32 { + v := uint32(ip[0]) << 24 + v += uint32(ip[1]) << 16 + v += uint32(ip[2]) << 8 + v += uint32(ip[3]) + return v +} + +func uintToIP(v uint32) net.IP { + return net.IP{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +type Options struct { + IPNet *net.IPNet + Host *trie.DomainTrie + + // Size sets the maximum number of entries in memory + // and does not work if Persistence is true + Size int + + // Persistence will save the data to disk. + // Size will not work and record will be fully stored. + Persistence bool +} + +// New return Pool instance +func New(options Options) (*Pool, error) { + minIP := ipToUint(options.IPNet.IP) + 2 + + ones, bits := options.IPNet.Mask.Size() + total := 1<= 0; i-- { + part := parts[i] + if !node.hasChild(part) { + node.addChild(part, newNode(nil)) + } + + node = node.getChild(part) + } + + node.Data = data +} + +// Search is the most important part of the Trie. +// Priority as: +// 1. static part +// 2. wildcard domain +// 2. dot wildcard domain +func (t *DomainTrie) Search(domain string) *Node { + parts, valid := ValidAndSplitDomain(domain) + if !valid || parts[0] == "" { + return nil + } + + n := t.search(t.root, parts) + + if n == nil || n.Data == nil { + return nil + } + + return n +} + +func (t *DomainTrie) search(node *Node, parts []string) *Node { + if len(parts) == 0 { + return node + } + + if c := node.getChild(parts[len(parts)-1]); c != nil { + if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != nil { + return n + } + } + + if c := node.getChild(wildcard); c != nil { + if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != nil { + return n + } + } + + return node.getChild(dotWildcard) +} + +// New returns a new, empty Trie. +func New() *DomainTrie { + return &DomainTrie{root: newNode(nil)} +} diff --git a/component/trie/domain_test.go b/component/trie/domain_test.go new file mode 100644 index 00000000..4322699a --- /dev/null +++ b/component/trie/domain_test.go @@ -0,0 +1,107 @@ +package trie + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +var localIP = net.IP{127, 0, 0, 1} + +func TestTrie_Basic(t *testing.T) { + tree := New() + domains := []string{ + "example.com", + "google.com", + "localhost", + } + + for _, domain := range domains { + tree.Insert(domain, localIP) + } + + node := tree.Search("example.com") + assert.NotNil(t, node) + assert.True(t, node.Data.(net.IP).Equal(localIP)) + assert.NotNil(t, tree.Insert("", localIP)) + assert.Nil(t, tree.Search("")) + assert.NotNil(t, tree.Search("localhost")) + assert.Nil(t, tree.Search("www.google.com")) +} + +func TestTrie_Wildcard(t *testing.T) { + tree := New() + domains := []string{ + "*.example.com", + "sub.*.example.com", + "*.dev", + ".org", + ".example.net", + ".apple.*", + "+.foo.com", + "+.stun.*.*", + "+.stun.*.*.*", + "+.stun.*.*.*.*", + "stun.l.google.com", + } + + for _, domain := range domains { + tree.Insert(domain, localIP) + } + + assert.NotNil(t, tree.Search("sub.example.com")) + assert.NotNil(t, tree.Search("sub.foo.example.com")) + assert.NotNil(t, tree.Search("test.org")) + assert.NotNil(t, tree.Search("test.example.net")) + assert.NotNil(t, tree.Search("test.apple.com")) + assert.NotNil(t, tree.Search("test.foo.com")) + assert.NotNil(t, tree.Search("foo.com")) + assert.NotNil(t, tree.Search("global.stun.website.com")) + assert.Nil(t, tree.Search("foo.sub.example.com")) + assert.Nil(t, tree.Search("foo.example.dev")) + assert.Nil(t, tree.Search("example.com")) +} + +func TestTrie_Priority(t *testing.T) { + tree := New() + domains := []string{ + ".dev", + "example.dev", + "*.example.dev", + "test.example.dev", + } + + assertFn := func(domain string, data int) { + node := tree.Search(domain) + assert.NotNil(t, node) + assert.Equal(t, data, node.Data) + } + + for idx, domain := range domains { + tree.Insert(domain, idx) + } + + assertFn("test.dev", 0) + assertFn("foo.bar.dev", 0) + assertFn("example.dev", 1) + assertFn("foo.example.dev", 2) + assertFn("test.example.dev", 3) +} + +func TestTrie_Boundary(t *testing.T) { + tree := New() + tree.Insert("*.dev", localIP) + + assert.NotNil(t, tree.Insert(".", localIP)) + assert.NotNil(t, tree.Insert("..dev", localIP)) + assert.Nil(t, tree.Search("dev")) +} + +func TestTrie_WildcardBoundary(t *testing.T) { + tree := New() + tree.Insert("+.*", localIP) + tree.Insert("stun.*.*.*", localIP) + + assert.NotNil(t, tree.Search("example.com")) +} diff --git a/component/trie/node.go b/component/trie/node.go new file mode 100644 index 00000000..67ef64a4 --- /dev/null +++ b/component/trie/node.go @@ -0,0 +1,26 @@ +package trie + +// Node is the trie's node +type Node struct { + children map[string]*Node + Data any +} + +func (n *Node) getChild(s string) *Node { + return n.children[s] +} + +func (n *Node) hasChild(s string) bool { + return n.getChild(s) != nil +} + +func (n *Node) addChild(s string, child *Node) { + n.children[s] = child +} + +func newNode(data any) *Node { + return &Node{ + Data: data, + children: map[string]*Node{}, + } +} diff --git a/dns/fakedns.go b/dns/fakedns.go new file mode 100644 index 00000000..d0df2635 --- /dev/null +++ b/dns/fakedns.go @@ -0,0 +1,77 @@ +package dns + +import "C" + +import ( + "errors" + "strings" + + D "github.com/miekg/dns" + + "github.com/xjasonlyu/tun2socks/v2/component/fakeip" + M "github.com/xjasonlyu/tun2socks/v2/metadata" +) + +var ( + fakePool *fakeip.Pool + fakeDNSenabled = false +) + +func setMsgTTL(msg *D.Msg, ttl uint32) { + for _, answer := range msg.Answer { + answer.Header().Ttl = ttl + } + + for _, ns := range msg.Ns { + ns.Header().Ttl = ttl + } + + for _, extra := range msg.Extra { + extra.Header().Ttl = ttl + } +} + +func EnableFakeDNS() { + fakeDNSenabled = true +} + +func ProcessMetadata(metadata *M.Metadata) bool { + if !fakeDNSenabled { + return false + } + dstName, found := fakePool.LookBack(metadata.DstIP) + if !found { + return false + } + metadata.DstIP = nil + metadata.DstName = dstName + return true +} + +func fakeipHandler(fakePool *fakeip.Pool) handler { + return func(r *D.Msg) (*D.Msg, error) { + if len(r.Question) == 0 { + return nil, errors.New("at least one question is required") + } + + q := r.Question[0] + + host := strings.TrimRight(q.Name, ".") + msg := r.Copy() + + if q.Qtype == D.TypeA { + rr := &D.A{} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} + ip := fakePool.Lookup(host) + rr.A = ip + msg.Answer = []D.RR{rr} + } + + setMsgTTL(msg, 1) + msg.SetRcode(r, D.RcodeSuccess) + msg.RecursionAvailable = true + msg.Response = true + + return msg, nil + } +} diff --git a/dns/server.go b/dns/server.go new file mode 100644 index 00000000..f996101a --- /dev/null +++ b/dns/server.go @@ -0,0 +1,95 @@ +package dns + +import ( + "errors" + "net" + + D "github.com/miekg/dns" + + "github.com/xjasonlyu/tun2socks/v2/common/sockopt" + "github.com/xjasonlyu/tun2socks/v2/component/fakeip" + "github.com/xjasonlyu/tun2socks/v2/log" +) + +var ( + server = &Server{} + dnsDefaultTTL uint32 = 600 +) + +type ( + handler func(r *D.Msg) (*D.Msg, error) +) + +type Server struct { + *D.Server + handler handler +} + +// ServeDNS implement D.Handler ServeDNS +func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) { + msg, err := handlerWithContext(s.handler, r) + if err != nil { + D.HandleFailed(w, r) + return + } + msg.Compress = true + w.WriteMsg(msg) +} + +func handlerWithContext(handler handler, msg *D.Msg) (*D.Msg, error) { + if len(msg.Question) == 0 { + return nil, errors.New("at least one question is required") + } + + return handler(msg) +} + +func ReCreateServer(addr string, pool *fakeip.Pool) { + fakePool = pool + if server.Server != nil { + server.Shutdown() + server = &Server{} + } + + if addr == "" { + return + } + + var err error + defer func() { + if err != nil { + log.Errorf("Start DNS server error: %s", err.Error()) + } + }() + + _, port, err := net.SplitHostPort(addr) + if port == "0" || port == "" || err != nil { + return + } + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return + } + + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return + } + + err = sockopt.UDPReuseaddr(udpConn) + if err != nil { + log.Warnf("Failed to Reuse UDP Address: %s", err) + + err = nil + } + + server = &Server{handler: fakeipHandler(fakePool)} + server.Server = &D.Server{Addr: addr, PacketConn: udpConn, Handler: server} + + go func() { + server.ActivateAndServe() + }() + + log.Infof("DNS server listening at: %s", udpConn.LocalAddr().String()) +} diff --git a/engine/engine.go b/engine/engine.go index 52c2a938..13c54c0b 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -12,13 +12,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" + "github.com/xjasonlyu/tun2socks/v2/component/fakeip" + "github.com/xjasonlyu/tun2socks/v2/component/trie" "github.com/xjasonlyu/tun2socks/v2/core" "github.com/xjasonlyu/tun2socks/v2/core/device" "github.com/xjasonlyu/tun2socks/v2/core/option" "github.com/xjasonlyu/tun2socks/v2/dialer" + "github.com/xjasonlyu/tun2socks/v2/dns" "github.com/xjasonlyu/tun2socks/v2/engine/mirror" "github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/proxy" + "github.com/xjasonlyu/tun2socks/v2/proxy/proto" "github.com/xjasonlyu/tun2socks/v2/restapi" "github.com/xjasonlyu/tun2socks/v2/tunnel" ) @@ -164,6 +168,35 @@ func restAPI(k *Key) error { return nil } +func remoteDNS(k *Key, proxy proxy.Proxy) (err error) { + if !k.RemoteDNS { + return + } + + if proxy.Proto() != proto.Socks5 && proxy.Proto() != proto.HTTP && proxy.Proto() != proto.Shadowsocks && + proxy.Proto() != proto.Socks4 { + return errors.New("remote DNS not supported with this proxy protocol") + } + + _, ipnet, err := net.ParseCIDR(k.RemoteDNSNetIPv4) + if err != nil { + return err + } + + pool, err := fakeip.New(fakeip.Options{ + IPNet: ipnet, + Size: 1000, + Host: trie.New(), + }) + + dns.EnableFakeDNS() + + dns.ReCreateServer(k.RemoteDNSListenAddress, pool) + + log.Infof("[DNS] Remote DNS enabled") + return +} + func netstack(k *Key) (err error) { if k.Proxy == "" { return errors.New("empty proxy") @@ -238,5 +271,11 @@ func netstack(k *Key) (err error) { _defaultDevice.Type(), _defaultDevice.Name(), _defaultProxy.Proto(), _defaultProxy.Addr(), ) + + err = remoteDNS(k, _defaultProxy) + if err != nil { + return err + } + return nil } diff --git a/engine/key.go b/engine/key.go index 5a86d53a..b56a46da 100644 --- a/engine/key.go +++ b/engine/key.go @@ -17,4 +17,7 @@ type Key struct { TUNPreUp string `yaml:"tun-pre-up"` TUNPostUp string `yaml:"tun-post-up"` UDPTimeout time.Duration `yaml:"udp-timeout"` + RemoteDNS bool `yaml:"remote-dns"` + RemoteDNSNetIPv4 string `yaml:"remote-dns-net-ipv4"` + RemoteDNSListenAddress string `yaml:"remote-dns-listen-addr"` } diff --git a/go.mod b/go.mod index bc18e77c..dfa67efd 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,8 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/schema v1.4.1 github.com/gorilla/websocket v1.5.1 + github.com/jellydator/ttlcache/v3 v3.2.0 + github.com/miekg/dns v1.1.52 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 go.uber.org/atomic v1.11.0 @@ -28,8 +30,10 @@ require ( github.com/ajg/form v1.5.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect - github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/mod v0.14.0 // indirect golang.org/x/net v0.24.0 // indirect + golang.org/x/sync v0.6.0 // indirect + golang.org/x/tools v0.16.1 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect ) diff --git a/go.sum b/go.sum index 95e27fdb..8aff59a1 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,5 @@ github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -24,10 +23,14 @@ github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E= github.com/gorilla/schema v1.4.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/jellydator/ttlcache/v3 v3.2.0 h1:6lqVJ8X3ZaUwvzENqPAobDsXNExfUJd61u++uW8a3LE= +github.com/jellydator/ttlcache/v3 v3.2.0/go.mod h1:hi7MGFdMAwZna5n2tuvh63DvFLzVKySzCVW6+0gA2n4= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/miekg/dns v1.1.52 h1:Bmlc/qsNNULOe6bpXcUTsuOajd0DzRHwup6D9k1An0c= +github.com/miekg/dns v1.1.52/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= @@ -42,15 +45,23 @@ go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= +go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= +go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= +golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= diff --git a/main.go b/main.go index a7c08e4c..f5a2c3ec 100644 --- a/main.go +++ b/main.go @@ -40,6 +40,9 @@ func init() { flag.StringVar(&key.TUNPreUp, "tun-pre-up", "", "Execute a command before TUN device setup") flag.StringVar(&key.TUNPostUp, "tun-post-up", "", "Execute a command after TUN device setup") flag.BoolVar(&versionFlag, "version", false, "Show version and then quit") + flag.BoolVar(&key.RemoteDNS, "remote-dns", false, "Enable remote DNS (HTTP, Shadowsocks, SOCKS)") + flag.StringVar(&key.RemoteDNSNetIPv4, "remote-dns-net-ipv4", "198.18.0.0/15", "IPv4 network for remote DNS A records") + flag.StringVar(&key.RemoteDNSListenAddress, "remote-dns-listen-addr", "198.18.0.1", "IP to listen on for DNS requests") flag.Parse() } diff --git a/metadata/metadata.go b/metadata/metadata.go index 63755dbe..3908a6d5 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -11,13 +11,20 @@ type Metadata struct { SrcIP net.IP `json:"sourceIP"` MidIP net.IP `json:"dialerIP"` DstIP net.IP `json:"destinationIP"` + DstName string `json:"destinationName"` SrcPort uint16 `json:"sourcePort"` MidPort uint16 `json:"dialerPort"` DstPort uint16 `json:"destinationPort"` } func (m *Metadata) DestinationAddress() string { - return net.JoinHostPort(m.DstIP.String(), strconv.FormatUint(uint64(m.DstPort), 10)) + var remote string + if m.DstIP == nil { + remote = m.DstName + } else { + remote = m.DstIP.String() + } + return net.JoinHostPort(remote, strconv.FormatUint(uint64(m.DstPort), 10)) } func (m *Metadata) SourceAddress() string { diff --git a/proxy/socks5.go b/proxy/socks5.go index bdc9b04c..de91466f 100644 --- a/proxy/socks5.go +++ b/proxy/socks5.go @@ -186,5 +186,5 @@ func (pc *socksPacketConn) Close() error { } func serializeSocksAddr(m *M.Metadata) socks5.Addr { - return socks5.SerializeAddr("", m.DstIP, m.DstPort) + return socks5.SerializeAddr(m.DstName, m.DstIP, m.DstPort) } diff --git a/tunnel/tcp.go b/tunnel/tcp.go index 03cebab5..cd1d7aba 100644 --- a/tunnel/tcp.go +++ b/tunnel/tcp.go @@ -8,6 +8,7 @@ import ( "github.com/xjasonlyu/tun2socks/v2/common/pool" "github.com/xjasonlyu/tun2socks/v2/core/adapter" + "github.com/xjasonlyu/tun2socks/v2/dns" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" "github.com/xjasonlyu/tun2socks/v2/proxy" @@ -31,6 +32,8 @@ func handleTCPConn(originConn adapter.TCPConn) { DstPort: id.LocalPort, } + dns.ProcessMetadata(metadata) + remoteConn, err := proxy.Dial(metadata) if err != nil { log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err) diff --git a/tunnel/udp.go b/tunnel/udp.go index a10e2d47..04710e1c 100644 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -8,6 +8,7 @@ import ( "github.com/xjasonlyu/tun2socks/v2/common/pool" "github.com/xjasonlyu/tun2socks/v2/core/adapter" + "github.com/xjasonlyu/tun2socks/v2/dns" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" "github.com/xjasonlyu/tun2socks/v2/proxy" @@ -34,6 +35,8 @@ func handleUDPConn(uc adapter.UDPConn) { DstPort: id.LocalPort, } + dns.ProcessMetadata(metadata) + pc, err := proxy.DialUDP(metadata) if err != nil { log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err) @@ -113,7 +116,10 @@ func (pc *symmetricNATPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { for { n, from, err := pc.PacketConn.ReadFrom(p) - if from != nil && from.String() != pc.dst { + // If pc.dst is not an IP address, it is a hostname. In that case, we + // do not know the source IP which packets should originate from and + // cannot drop them accordingly. + if from != nil && from.String() != pc.dst && net.ParseIP(pc.dst) != nil { log.Warnf("[UDP] symmetric NAT %s->%s: drop packet from %s", pc.src, pc.dst, from) continue }