diff --git a/README.md b/README.md index 3bcc194..512865c 100644 --- a/README.md +++ b/README.md @@ -222,7 +222,7 @@ $ docker compose up -d + `tun-mode`: TUN 模式(实验性)。请阅读后文中的 TUN 模式注意事项 -+ `add_route`: 启用 TUN 模式时根据服务端下发配置添加路由 ++ `add-route`: 启用 TUN 模式时根据服务端下发配置添加路由 + `dns-ttl`: DNS 缓存时间,默认为 `3600` 秒 @@ -230,11 +230,11 @@ $ docker compose up -d + `zju-dns-server`: ZJU DNS 服务器地址,默认为 `10.10.0.21` -+ `secondary_dns_server`: 当使用 ZJU DNS 服务器无法解析时使用的备用 DNS 服务器,默认为 `114.114.114.114`。留空则使用系统默认 DNS,但在开启 `tun_dns_server` 时必须设置 ++ `secondary-dns-server`: 当使用 ZJU DNS 服务器无法解析时使用的备用 DNS 服务器,默认为 `114.114.114.114`。留空则使用系统默认 DNS,但在开启 `tun_dns_server` 时必须设置 -+ `dns_server_bind`: DNS 服务器监听地址,默认为空即禁用。例如,设置为 `127.0.0.1:53`,则可向 `127.0.0.1:53` 发起 DNS 请求 ++ `dns-server-bind`: DNS 服务器监听地址,默认为空即禁用。例如,设置为 `127.0.0.1:53`,则可向 `127.0.0.1:53` 发起 DNS 请求 -+ `tun_dns_server`: 启用 TUN 模式时使用的 DNS 服务器,不带端口。例如:`127.0.0.1`。可配合 `dns_server_bind` 实现 TUN 模式下正确的 DNS 解析。目前仅支持 Windows 系统 ++ `dns-hijack`: 启用 TUN 模式时劫持 DNS 请求,建议在启用 TUN 模式时添加此参数 + `debug-dump`: 是否开启调试,一般不需要加此参数 @@ -254,7 +254,7 @@ $ docker compose up -d 2. Windows 系统需要前往 [Wintun 官网](https://www.wintun.net)下载 `wintun.dll` 并放置于可执行文件同目录下 -3. Linux 和 macOS 暂不支持通过 `tun_dns_server` 自动配置系统 DNS。为保证 `*.zju.edu.cn` 解析正确,建议配置 `dns_server_bind` 并手动配置系统 DNS +3. 为保证 `*.zju.edu.cn` 解析正确,建议配置 `dns-hijack` 劫持系统 DNS 4. macOS 暂不支持通过 TUN 接口访问 `10.0.0.0/8` 外的地址 @@ -276,10 +276,10 @@ $ docker compose up -d - [x] 通过配置文件启动 - [x] 定时保活 - [x] TUN 模式 +- [x] 自动劫持 DNS #### To Do -- [ ] 自动劫持 DNS - [ ] Fake IP 模式 ### 贡献者 diff --git a/client/rvpn_conn.go b/client/rvpn_conn.go index 57346b1..0c3735f 100644 --- a/client/rvpn_conn.go +++ b/client/rvpn_conn.go @@ -3,40 +3,47 @@ package client import ( "github.com/mythologyli/zju-connect/log" "io" + "sync" ) type RvpnConn struct { easyConnectClient *EasyConnectClient sendConn io.WriteCloser + sendLock sync.Mutex sendErrCount int recvConn io.ReadCloser + recvLock sync.Mutex recvErrCount int } -// always success or panic +// try best to read, if return err!=nil, please panic func (r *RvpnConn) Read(p []byte) (n int, err error) { + r.recvLock.Lock() + defer r.recvLock.Unlock() for n, err = r.recvConn.Read(p); err != nil && r.recvErrCount < 5; { + log.Printf("Error occurred while receiving, retrying: %v", err) // Do handshake again and create a new recvConn _ = r.recvConn.Close() r.recvConn, err = r.easyConnectClient.RecvConn() if err != nil { - // TODO graceful shutdown - panic(err) + return 0, err } r.recvErrCount++ if r.recvErrCount >= 5 { - panic("recv retry limit exceeded.") + return 0, err } } return } -// always success or panic +// try best to write, if return err!=nil, please panic func (r *RvpnConn) Write(p []byte) (n int, err error) { + r.sendLock.Lock() + defer r.sendLock.Unlock() for n, err = r.sendConn.Write(p); err != nil && r.sendErrCount < 5; { log.Printf("Error occurred while sending, retrying: %v", err) @@ -44,12 +51,11 @@ func (r *RvpnConn) Write(p []byte) (n int, err error) { _ = r.sendConn.Close() r.sendConn, err = r.easyConnectClient.SendConn() if err != nil { - // TODO graceful shutdown - panic(err) + return 0, err } r.sendErrCount++ if r.sendErrCount >= 5 { - panic("send retry limit exceeded.") + return 0, err } } return @@ -76,13 +82,13 @@ func NewRvpnConn(ec *EasyConnectClient) (*RvpnConn, error) { c.sendConn, err = ec.SendConn() if err != nil { log.Printf("Error occurred while creating sendConn: %v", err) - panic(err) + return nil, err } c.recvConn, err = ec.RecvConn() if err != nil { log.Printf("Error occurred while creating recvConn: %v", err) - panic(err) + return nil, err } return c, nil } diff --git a/config.toml.example b/config.toml.example index 60551cd..0d09e8b 100644 --- a/config.toml.example +++ b/config.toml.example @@ -20,7 +20,7 @@ disable_keep_alive = false zju_dns_server = "10.10.0.21" secondary_dns_server = "114.114.114.114" dns_server_bind = "" -tun_dns_server = "" +dns_hijack = false debug_dump = false # Port forwarding diff --git a/go.mod b/go.mod index 0279ff9..f30d418 100644 --- a/go.mod +++ b/go.mod @@ -8,11 +8,10 @@ require ( github.com/BurntSushi/toml v1.2.1 github.com/beevik/etree v1.2.0 github.com/cloverstd/tcping v0.1.1 + github.com/cxz66666/sing-tun v0.0.0-20231028191617-2867d9374292 github.com/miekg/dns v1.1.56 github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/things-go/go-socks5 v0.0.4 - golang.org/x/net v0.17.0 golang.org/x/sys v0.13.0 golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb golang.zx2c4.com/wireguard/windows v0.5.3 @@ -23,14 +22,23 @@ require ( require ( github.com/andybalholm/brotli v1.0.6 // indirect github.com/cloudflare/circl v1.3.5 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gaukas/godicttls v0.0.4 // indirect + github.com/go-ole/go-ole v1.3.0 // indirect github.com/google/btree v1.1.2 // indirect github.com/klauspost/compress v1.17.1 // indirect + github.com/metacubex/gvisor v0.0.0-20231001104248-0f672c3fb8d8 // indirect github.com/quic-go/quic-go v0.39.1 // indirect + github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 // indirect + github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 // indirect + github.com/sagernet/sing v0.2.14 // indirect + github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 // indirect + github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20230525183740-e7c30c78aeb2 // indirect golang.org/x/crypto v0.14.0 // indirect golang.org/x/mod v0.13.0 // indirect + golang.org/x/net v0.17.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.14.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index eb73a98..0fe0950 100644 --- a/go.sum +++ b/go.sum @@ -8,13 +8,19 @@ github.com/cloudflare/circl v1.3.5 h1:g+wWynZqVALYAlpSQFAa7TscDnUK8mKYtrxMpw6AUK github.com/cloudflare/circl v1.3.5/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/cloverstd/tcping v0.1.1 h1:3Yp9nvSDI7Z63zoVQDJzVk1PUczrF9tJoOrKGV30iOk= github.com/cloverstd/tcping v0.1.1/go.mod h1:NYXTrTDwlwuOKQ0vwksUVUbIr0sxDDsf1J6aFpScCBo= +github.com/cxz66666/sing-tun v0.0.0-20231028191617-2867d9374292 h1:ualmTq9VM4rYWNYgr7/9Pg6SHz1uXWHKmI2Ufd8sEQY= +github.com/cxz66666/sing-tun v0.0.0-20231028191617-2867d9374292/go.mod h1:MBoiGEiPODP6YZko83BH5Nhnp2IqRIclCYNyFdLocWM= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dvyukov/go-fuzz v0.0.0-20210103155950-6a8e9d1f2415/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk= github.com/gaukas/godicttls v0.0.4/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= +github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= @@ -25,6 +31,8 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/klauspost/compress v1.17.1 h1:NE3C767s2ak2bweCZo3+rdP4U/HoyVXLv/X9f2gPS5g= github.com/klauspost/compress v1.17.1/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/metacubex/gvisor v0.0.0-20231001104248-0f672c3fb8d8 h1:npBvaPAT145UY8682AzpUMWpdIxJti/WPLjy7gCiYYs= +github.com/metacubex/gvisor v0.0.0-20231001104248-0f672c3fb8d8/go.mod h1:ZR6Gas7P1GcADCVBc1uOrA0bLQqDDyp70+63fD/BE2c= github.com/miekg/dns v1.1.56 h1:5imZaSeoRNvpM9SzWNhEcP9QliKiz20/dA2QabIGVnE= github.com/miekg/dns v1.1.56/go.mod h1:cRm6Oo2C8TY9ZS/TqsSrseAcncm74lfK5G+ikN2SWWY= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= @@ -39,12 +47,21 @@ github.com/quic-go/quic-go v0.39.1 h1:d/m3oaN/SD2c+f7/yEjZxe2zEVotXprnrCCJ2y/ZZF github.com/quic-go/quic-go v0.39.1/go.mod h1:T09QsDQWjLiQ74ZmacDfqZmhY/NLnw5BC40MANNNZ1Q= github.com/refraction-networking/utls v1.5.4 h1:9k6EO2b8TaOGsQ7Pl7p9w6PUhx18/ZCeT0WNTZ7Uw4o= github.com/refraction-networking/utls v1.5.4/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= +github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA= +github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h1:QUQ4RRHD6hGGHdFMEtR8T2P6GS6R3D/CXKdaYHKKXms= +github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE= +github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= +github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= +github.com/sagernet/sing v0.2.14 h1:L3AXDh22nsOOYz2nTRU1JvpRsmzViWKI1B8TsQYG1eY= +github.com/sagernet/sing v0.2.14/go.mod h1:AhNEHu0GXrpqkuzvTwvC8+j2cQUU/dh+zLEmq4C99pg= +github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg= +github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0= github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ= +github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= +github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go4.org/intern v0.0.0-20211027215823-ae77deb06f29 h1:UXLjNohABv4S58tHmeuIZDO6e3mHpW2Dx33gaNt03LE= go4.org/intern v0.0.0-20211027215823-ae77deb06f29/go.mod h1:cS2ma+47FKrLPdXFpr7CuxiTW3eyJbWew4qx0qtQWDA= @@ -70,8 +87,11 @@ golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/init.go b/init.go index 76f09c6..8000b5a 100644 --- a/init.go +++ b/init.go @@ -31,7 +31,7 @@ type ( ZJUDNSServer string SecondaryDNSServer string DNSServerBind string - TUNDNSServer string + DNSHijack bool DebugDump bool PortForwardingList []SinglePortForwarding CustomDNSList []SingleCustomDNS @@ -72,7 +72,7 @@ type ( ZJUDNSServer *string `toml:"zju_dns_server"` SecondaryDNSServer *string `toml:"secondary_dns_server"` DNSServerBind *string `toml:"dns_server_bind"` - TUNDNSServer *string `toml:"tun_dns_server"` + DNSHijack *bool `toml:"dns_hijack"` DebugDump *bool `toml:"debug_dump"` PortForwarding []SinglePortForwardingTOML `toml:"port_forwarding"` CustomDNS []SingleCustomDNSTOML `toml:"custom_dns"` @@ -127,7 +127,7 @@ func parseTOMLConfig(configFile string, conf *Config) error { conf.ZJUDNSServer = getTOMLVal(confTOML.ZJUDNSServer, "10.10.0.21") conf.SecondaryDNSServer = getTOMLVal(confTOML.SecondaryDNSServer, "114.114.114.114") conf.DNSServerBind = getTOMLVal(confTOML.DNSServerBind, "") - conf.TUNDNSServer = getTOMLVal(confTOML.TUNDNSServer, "") + conf.DNSHijack = getTOMLVal(confTOML.DNSHijack, false) for _, singlePortForwarding := range confTOML.PortForwarding { if singlePortForwarding.NetworkType == nil { @@ -193,7 +193,7 @@ func init() { flag.StringVar(&conf.ZJUDNSServer, "zju-dns-server", "10.10.0.21", "ZJU DNS server address") flag.StringVar(&conf.SecondaryDNSServer, "secondary-dns-server", "114.114.114.114", "Secondary DNS server address. Leave empty to use system default DNS server") flag.StringVar(&conf.DNSServerBind, "dns-server-bind", "", "The address DNS server listens on (e.g. 127.0.0.1:53)") - flag.StringVar(&conf.TUNDNSServer, "tun-dns-server", "", "DNS Server address for TUN interface (e.g. 127.0.0.1). You should not specify the port") + flag.BoolVar(&conf.DNSHijack, "dns-hijack", false, "Hijack all dns query to ZJU Connect") flag.StringVar(&conf.TwfID, "twf-id", "", "Login using twfID captured (mostly for debug usage)") flag.StringVar(&tcpPortForwarding, "tcp-port-forwarding", "", "TCP port forwarding (e.g. 0.0.0.0:9898-10.10.98.98:80,127.0.0.1:9899-10.10.98.98:80)") flag.StringVar(&udpPortForwarding, "udp-port-forwarding", "", "UDP port forwarding (e.g. 127.0.0.1:53-10.10.0.21:53)") diff --git a/internal/terminal_func/terminal_func.go b/internal/terminal_func/terminal_func.go new file mode 100644 index 0000000..604ea38 --- /dev/null +++ b/internal/terminal_func/terminal_func.go @@ -0,0 +1,43 @@ +package terminal_func + +import ( + "context" + "github.com/mythologyli/zju-connect/log" +) + +type TerminalFunc func(ctx context.Context) error +type TerminalItem struct { + f TerminalFunc + name string +} + +var terminalFuncList []TerminalItem + +var terminalBegin = false + +func RegisterTerminalFunc(execName string, fun TerminalFunc) { + terminalFuncList = append(terminalFuncList, TerminalItem{ + f: fun, + name: execName, + }) + log.Println("Register func on terminal:", execName) +} + +func ExecTerminalFunc(ctx context.Context) []error { + var errList []error + terminalBegin = true + for _, item := range terminalFuncList { + log.Println("Exec func on terminal:", item.name) + if err := item.f(ctx); err != nil { + errList = append(errList, err) + log.Println("Exec func on terminal ", item.name, "failed:", err) + } else { + log.Println("Exec func on terminal ", item.name, "success") + } + } + return errList +} + +func IsTerminal() bool { + return terminalBegin +} diff --git a/internal/zcdns/local_server.go b/internal/zcdns/local_server.go new file mode 100644 index 0000000..5422818 --- /dev/null +++ b/internal/zcdns/local_server.go @@ -0,0 +1,13 @@ +package zcdns + +import ( + "context" + + "github.com/miekg/dns" + "net" +) + +type LocalServer interface { + HandleDnsMsg(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) + CheckDnsHijack(dstIP net.IP) bool +} diff --git a/internal/zctcpip/icmp.go b/internal/zctcpip/icmp.go new file mode 100644 index 0000000..be94794 --- /dev/null +++ b/internal/zctcpip/icmp.go @@ -0,0 +1,40 @@ +package zctcpip + +import ( + "encoding/binary" +) + +type ICMPType = byte + +const ( + ICMPTypePingRequest byte = 0x8 + ICMPTypePingResponse byte = 0x0 +) + +type ICMPPacket []byte + +func (p ICMPPacket) Type() ICMPType { + return p[0] +} + +func (p ICMPPacket) SetType(v ICMPType) { + p[0] = v +} + +func (p ICMPPacket) Code() byte { + return p[1] +} + +func (p ICMPPacket) Checksum() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func (p ICMPPacket) SetChecksum(sum [2]byte) { + p[2] = sum[0] + p[3] = sum[1] +} + +func (p ICMPPacket) ResetChecksum() { + p.SetChecksum(zeroChecksum) + p.SetChecksum(Checksum(0, p)) +} diff --git a/internal/zctcpip/ip.go b/internal/zctcpip/ip.go new file mode 100644 index 0000000..1d0fa39 --- /dev/null +++ b/internal/zctcpip/ip.go @@ -0,0 +1,209 @@ +package zctcpip + +import ( + "encoding/binary" + "errors" + "net" +) + +type IPProtocol = byte + +type IP interface { + Payload() []byte + SourceIP() net.IP + DestinationIP() net.IP + SetSourceIP(ip net.IP) + SetDestinationIP(ip net.IP) + Protocol() IPProtocol + DecTimeToLive() + ResetChecksum() + PseudoSum() uint32 +} + +// IPProtocol type +const ( + ICMP IPProtocol = 0x01 + TCP IPProtocol = 0x06 + UDP IPProtocol = 0x11 + ICMPv6 IPProtocol = 0x3a +) + +const ( + FlagDontFragment = 1 << 1 + FlagMoreFragment = 1 << 2 +) + +const ( + IPv4HeaderSize = 20 + + IPv4Version = 4 + + IPv4OptionsOffset = 20 + IPv4PacketMinLength = IPv4OptionsOffset +) + +var ( + ErrInvalidLength = errors.New("invalid packet length") + ErrInvalidIPVersion = errors.New("invalid ip version") + ErrInvalidChecksum = errors.New("invalid checksum") +) + +type IPv4Packet []byte + +func (p IPv4Packet) TotalLen() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func (p IPv4Packet) SetTotalLength(length uint16) { + binary.BigEndian.PutUint16(p[2:], length) +} + +func (p IPv4Packet) HeaderLen() uint16 { + return uint16(p[0]&0xf) * 4 +} + +func (p IPv4Packet) SetHeaderLen(length uint16) { + p[0] &= 0xF0 + p[0] |= byte(length / 4) +} + +func (p IPv4Packet) TypeOfService() byte { + return p[1] +} + +func (p IPv4Packet) SetTypeOfService(tos byte) { + p[1] = tos +} + +func (p IPv4Packet) Identification() uint16 { + return binary.BigEndian.Uint16(p[4:]) +} + +func (p IPv4Packet) SetIdentification(id uint16) { + binary.BigEndian.PutUint16(p[4:], id) +} + +func (p IPv4Packet) FragmentOffset() uint16 { + return binary.BigEndian.Uint16([]byte{p[6] & 0x7, p[7]}) * 8 +} + +func (p IPv4Packet) SetFragmentOffset(offset uint32) { + flags := p.Flags() + binary.BigEndian.PutUint16(p[6:], uint16(offset/8)) + p.SetFlags(flags) +} + +func (p IPv4Packet) DataLen() uint16 { + return p.TotalLen() - p.HeaderLen() +} + +func (p IPv4Packet) Payload() []byte { + return p[p.HeaderLen():p.TotalLen()] +} + +func (p IPv4Packet) Protocol() IPProtocol { + return p[9] +} + +func (p IPv4Packet) SetProtocol(protocol IPProtocol) { + p[9] = protocol +} + +func (p IPv4Packet) Flags() byte { + return p[6] >> 5 +} + +func (p IPv4Packet) SetFlags(flags byte) { + p[6] &= 0x1F + p[6] |= flags << 5 +} + +func (p IPv4Packet) SourceIP() net.IP { + return net.IPv4(p[12], p[13], p[14], p[15]) +} + +func (p IPv4Packet) SetSourceIP(ip net.IP) { + if newIP := ip.To4(); newIP != nil { + copy(p[12:16], newIP) + } +} + +func (p IPv4Packet) DestinationIP() net.IP { + return net.IPv4(p[16], p[17], p[18], p[19]) +} + +func (p IPv4Packet) SetDestinationIP(ip net.IP) { + if newIP := ip.To4(); newIP != nil { + copy(p[16:20], newIP) + } +} + +func (p IPv4Packet) Checksum() uint16 { + return binary.BigEndian.Uint16(p[10:]) +} + +func (p IPv4Packet) SetChecksum(sum [2]byte) { + p[10] = sum[0] + p[11] = sum[1] +} + +func (p IPv4Packet) TimeToLive() uint8 { + return p[8] +} + +func (p IPv4Packet) SetTimeToLive(ttl uint8) { + p[8] = ttl +} + +func (p IPv4Packet) DecTimeToLive() { + p[8] = p[8] - uint8(1) +} + +func (p IPv4Packet) ResetChecksum() { + p.SetChecksum(zeroChecksum) + p.SetChecksum(Checksum(0, p[:p.HeaderLen()])) +} + +// PseudoSum for tcp checksum +func (p IPv4Packet) PseudoSum() uint32 { + sum := Sum(p[12:20]) + sum += uint32(p.Protocol()) + sum += uint32(p.DataLen()) + return sum +} + +func (p IPv4Packet) Valid() bool { + return len(p) >= IPv4HeaderSize && p.TotalLen() >= p.HeaderLen() && uint16(len(p)) >= p.TotalLen() +} + +func (p IPv4Packet) Verify() error { + if len(p) < IPv4PacketMinLength { + return ErrInvalidLength + } + + checksum := []byte{p[10], p[11]} + headerLength := uint16(p[0]&0xF) * 4 + packetLength := binary.BigEndian.Uint16(p[2:]) + + if p[0]>>4 != 4 { + return ErrInvalidIPVersion + } + + if uint16(len(p)) < packetLength || packetLength < headerLength { + return ErrInvalidLength + } + + p[10] = 0 + p[11] = 0 + defer copy(p[10:12], checksum) + + answer := Checksum(0, p[:headerLength]) + + if answer[0] != checksum[0] || answer[1] != checksum[1] { + return ErrInvalidChecksum + } + + return nil +} + +var _ IP = (*IPv4Packet)(nil) diff --git a/internal/zctcpip/tcp.go b/internal/zctcpip/tcp.go new file mode 100644 index 0000000..79f19d7 --- /dev/null +++ b/internal/zctcpip/tcp.go @@ -0,0 +1,90 @@ +package zctcpip + +import ( + "encoding/binary" + "net" +) + +const ( + TCPFin uint16 = 1 << 0 + TCPSyn uint16 = 1 << 1 + TCPRst uint16 = 1 << 2 + TCPPuh uint16 = 1 << 3 + TCPAck uint16 = 1 << 4 + TCPUrg uint16 = 1 << 5 + TCPEce uint16 = 1 << 6 + TCPEwr uint16 = 1 << 7 + TCPNs uint16 = 1 << 8 +) + +const TCPHeaderSize = 20 + +type TCPPacket []byte + +func (p TCPPacket) SourcePort() uint16 { + return binary.BigEndian.Uint16(p) +} + +func (p TCPPacket) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(p, port) +} + +func (p TCPPacket) DestinationPort() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func (p TCPPacket) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(p[2:], port) +} + +func (p TCPPacket) Flags() uint16 { + return uint16(p[13] | (p[12] & 0x1)) +} + +func (p TCPPacket) Checksum() uint16 { + return binary.BigEndian.Uint16(p[16:]) +} + +func (p TCPPacket) SetChecksum(sum [2]byte) { + p[16] = sum[0] + p[17] = sum[1] +} + +func (p TCPPacket) ResetChecksum(psum uint32) { + p.SetChecksum(zeroChecksum) + p.SetChecksum(Checksum(psum, p)) +} + +func (p TCPPacket) Valid() bool { + return len(p) >= TCPHeaderSize +} + +func (p TCPPacket) Verify(sourceAddress net.IP, targetAddress net.IP) error { + var checksum [2]byte + checksum[0] = p[16] + checksum[1] = p[17] + + // reset checksum + p[16] = 0 + p[17] = 0 + + // restore checksum + defer func() { + p[16] = checksum[0] + p[17] = checksum[1] + }() + + // check checksum + s := uint32(0) + s += Sum(sourceAddress) + s += Sum(targetAddress) + s += uint32(TCP) + s += uint32(len(p)) + + check := Checksum(s, p) + if checksum[0] != check[0] || checksum[1] != check[1] { + return ErrInvalidChecksum + } + + return nil +} diff --git a/internal/zctcpip/tcpip.go b/internal/zctcpip/tcpip.go new file mode 100644 index 0000000..2f108dc --- /dev/null +++ b/internal/zctcpip/tcpip.go @@ -0,0 +1,33 @@ +package zctcpip + +var zeroChecksum = [2]byte{0x00, 0x00} + +func Sum(b []byte) uint32 { + // TODO use neon on arm and AVX on amd64 + var sum uint32 + n := len(b) + if n&1 != 0 { + n-- + sum += uint32(b[n]) << 8 + } + + for i := 0; i < n; i += 2 { + sum += (uint32(b[i]) << 8) | uint32(b[i+1]) + } + return sum +} + +// Checksum for Internet Protocol family headers +func Checksum(sum uint32, b []byte) (answer [2]byte) { + sum += Sum(b) + sum = (sum >> 16) + (sum & 0xffff) + sum += sum >> 16 + sum = ^sum + answer[0] = byte(sum >> 8) + answer[1] = byte(sum) + return +} + +func SetIPv4(packet []byte) { + packet[0] = (packet[0] & 0x0f) | (4 << 4) +} diff --git a/internal/zctcpip/udp.go b/internal/zctcpip/udp.go new file mode 100644 index 0000000..08e7e98 --- /dev/null +++ b/internal/zctcpip/udp.go @@ -0,0 +1,55 @@ +package zctcpip + +import ( + "encoding/binary" +) + +const UDPHeaderSize = 8 + +type UDPPacket []byte + +func (p UDPPacket) Length() uint16 { + return binary.BigEndian.Uint16(p[4:]) +} + +func (p UDPPacket) SetLength(length uint16) { + binary.BigEndian.PutUint16(p[4:], length) +} + +func (p UDPPacket) SourcePort() uint16 { + return binary.BigEndian.Uint16(p) +} + +func (p UDPPacket) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(p, port) +} + +func (p UDPPacket) DestinationPort() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func (p UDPPacket) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(p[2:], port) +} + +func (p UDPPacket) Payload() []byte { + return p[UDPHeaderSize:p.Length()] +} + +func (p UDPPacket) Checksum() uint16 { + return binary.BigEndian.Uint16(p[6:]) +} + +func (p UDPPacket) SetChecksum(sum [2]byte) { + p[6] = sum[0] + p[7] = sum[1] +} + +func (p UDPPacket) ResetChecksum(psum uint32) { + p.SetChecksum(zeroChecksum) + p.SetChecksum(Checksum(psum, p)) +} + +func (p UDPPacket) Valid() bool { + return len(p) >= UDPHeaderSize && uint16(len(p)) >= p.Length() +} diff --git a/main.go b/main.go index 8acf8a0..bf01862 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,11 @@ package main import ( + "context" "fmt" "github.com/mythologyli/zju-connect/client" "github.com/mythologyli/zju-connect/dial" + "github.com/mythologyli/zju-connect/internal/terminal_func" "github.com/mythologyli/zju-connect/log" "github.com/mythologyli/zju-connect/resolve" "github.com/mythologyli/zju-connect/service" @@ -12,6 +14,9 @@ import ( "github.com/mythologyli/zju-connect/stack/tun" "inet.af/netaddr" "net" + "os" + "os/signal" + "syscall" ) var conf Config @@ -73,7 +78,7 @@ func main() { var vpnStack stack.Stack if conf.TUNMode { - vpnTUNStack, err := tun.NewStack(vpnClient, conf.TUNDNSServer) + vpnTUNStack, err := tun.NewStack(vpnClient, conf.DNSHijack) if err != nil { log.Fatalf("Tun stack setup error: %s", err) } @@ -93,8 +98,6 @@ func main() { } } - go vpnStack.Run() - vpnResolver := resolve.NewResolver( vpnStack, conf.ZJUDNSServer, @@ -113,11 +116,19 @@ func main() { vpnResolver.SetPermanentDNS(customDns.HostName, ipAddr) log.Printf("Add custom DNS: %s -> %s\n", customDns.HostName, customDns.IP) } + localResolver := service.NewDnsServer(vpnResolver, []string{conf.ZJUDNSServer, conf.SecondaryDNSServer}) + vpnStack.SetupResolve(localResolver) + + go vpnStack.Run() vpnDialer := dial.NewDialer(vpnStack, vpnResolver, ipResource, conf.ProxyAll) if conf.DNSServerBind != "" { - go service.ServeDNS(conf.DNSServerBind, vpnResolver) + go service.ServeDNS(conf.DNSServerBind, localResolver) + } + if conf.TUNMode { + clientIP, _ := vpnClient.IP() + go service.ServeDNS(clientIP.String()+":53", localResolver) } if conf.SocksBind != "" { @@ -139,8 +150,18 @@ func main() { } if !conf.DisableKeepAlive { - service.KeepAlive(vpnResolver) + go service.KeepAlive(vpnResolver) } - select {} + quit := make(chan os.Signal) + signal.Notify(quit, os.Interrupt, syscall.SIGTERM) + <-quit + log.Println("Shutdown ZJU-Connect ......") + if errs := terminal_func.ExecTerminalFunc(context.Background()); errs != nil { + for _, err := range errs { + log.Printf("Shutdown ZJU-Connect failed:", err) + } + } else { + log.Println("Shutdown ZJU-Connect success, Bye~") + } } diff --git a/service/dns.go b/service/dns.go index ed002d6..6810d1d 100644 --- a/service/dns.go +++ b/service/dns.go @@ -6,20 +6,46 @@ import ( "github.com/miekg/dns" "github.com/mythologyli/zju-connect/log" "github.com/mythologyli/zju-connect/resolve" + "net" ) type DNSServer struct { resolver *resolve.Resolver + localDNS []net.IP } -func (d DNSServer) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { +func (d DNSServer) serveDNSRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false - switch r.Opcode { + _ = d.handleSingleDNSResolve(context.Background(), r, m) + + _ = w.WriteMsg(m) +} + +func (d DNSServer) HandleDnsMsg(ctx context.Context, requestMsg *dns.Msg) (*dns.Msg, error) { + resMsg := new(dns.Msg) + resMsg.SetReply(requestMsg) + resMsg.Compress = false + + err := d.handleSingleDNSResolve(ctx, requestMsg, resMsg) + return resMsg, err +} + +func (d DNSServer) CheckDnsHijack(dstIP net.IP) bool { + for _, ip := range d.localDNS { + if ip.Equal(dstIP) { + return false + } + } + return true +} + +func (d DNSServer) handleSingleDNSResolve(ctx context.Context, requestMsg *dns.Msg, resMsg *dns.Msg) error { + switch requestMsg.Opcode { case dns.OpcodeQuery: - for _, q := range r.Question { + for _, q := range requestMsg.Question { name := q.Name if len(name) > 1 && name[len(name)-1] == '.' { name = name[:len(name)-1] @@ -27,33 +53,41 @@ func (d DNSServer) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { switch q.Qtype { case dns.TypeA: - if _, ip, err := d.resolver.Resolve(context.Background(), name); err == nil { + if _, ip, err := d.resolver.Resolve(ctx, name); err == nil { if ip.To4() != nil { rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) if err == nil { - m.Answer = append(m.Answer, rr) + resMsg.Answer = append(resMsg.Answer, rr) } } } case dns.TypeAAAA: - if _, ip, err := d.resolver.Resolve(context.Background(), name); err == nil { + if _, ip, err := d.resolver.Resolve(ctx, name); err == nil { if ip.To4() == nil { rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip)) if err == nil { - m.Answer = append(m.Answer, rr) + resMsg.Answer = append(resMsg.Answer, rr) } } } } } } + return nil +} - _ = w.WriteMsg(m) +func NewDnsServer(resolver *resolve.Resolver, dnsServers []string) DNSServer { + netIPs := make([]net.IP, len(dnsServers)) + for _, dnsServer := range dnsServers { + if net.ParseIP(dnsServer) != nil { + netIPs = append(netIPs, net.ParseIP(dnsServer)) + } + } + return DNSServer{resolver: resolver, localDNS: netIPs} } -func ServeDNS(bindAddr string, resolver *resolve.Resolver) { - dnsServer := &DNSServer{resolver: resolver} - dns.HandleFunc(".", dnsServer.handleDNSRequest) +func ServeDNS(bindAddr string, dnsServer DNSServer) { + dns.HandleFunc(".", dnsServer.serveDNSRequest) server := &dns.Server{Addr: bindAddr, Net: "udp"} log.Printf("Starting DNS server at %s", server.Addr) diff --git a/stack/gvisor/stack.go b/stack/gvisor/stack.go index 077e553..83a1164 100644 --- a/stack/gvisor/stack.go +++ b/stack/gvisor/stack.go @@ -3,6 +3,8 @@ package gvisor import ( "errors" "github.com/mythologyli/zju-connect/client" + "github.com/mythologyli/zju-connect/internal/terminal_func" + "github.com/mythologyli/zju-connect/internal/zcdns" "github.com/mythologyli/zju-connect/log" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" @@ -16,6 +18,7 @@ import ( type Stack struct { gvisorStack *stack.Stack + resolve zcdns.LocalServer endpoint *Endpoint } @@ -76,8 +79,14 @@ func (ep *Endpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) } if ep.rvpnConn != nil { - n, _ := ep.rvpnConn.Write(buf) - + n, err := ep.rvpnConn.Write(buf) + if err != nil { + if terminal_func.IsTerminal() { + return list.Len(), nil + } else { + panic(err) + } + } log.DebugPrintf("Send: wrote %d bytes", n) log.DebugDumpHex(buf[:n]) } @@ -132,15 +141,27 @@ func NewStack(easyConnectClient *client.EasyConnectClient) (*Stack, error) { return s, nil } -func (s *Stack) Run() { - - s.endpoint.rvpnConn, _ = client.NewRvpnConn(s.endpoint.easyConnectClient) +func (s *Stack) SetupResolve(r zcdns.LocalServer) { + s.resolve = r +} +func (s *Stack) Run() { + var connErr error + s.endpoint.rvpnConn, connErr = client.NewRvpnConn(s.endpoint.easyConnectClient) + if connErr != nil { + panic(connErr) + } // Read from VPN server and send to gVisor stack for { - buf := make([]byte, 1500) - n, _ := s.endpoint.rvpnConn.Read(buf) - + buf := make([]byte, MTU) + n, err := s.endpoint.rvpnConn.Read(buf) + if err != nil { + if terminal_func.IsTerminal() { + return + } else { + panic(err) + } + } log.DebugPrintf("Recv: read %d bytes", n) log.DebugDumpHex(buf[:n]) diff --git a/stack/stack.go b/stack/stack.go index 361880f..2363ae8 100644 --- a/stack/stack.go +++ b/stack/stack.go @@ -1,9 +1,13 @@ package stack -import "net" +import ( + "github.com/mythologyli/zju-connect/internal/zcdns" + "net" +) type Stack interface { Run() + SetupResolve(r zcdns.LocalServer) DialTCP(addr *net.TCPAddr) (net.Conn, error) DialUDP(addr *net.UDPAddr) (net.Conn, error) } diff --git a/stack/tun/stack.go b/stack/tun/stack.go index 9f6ba8c..92e3c1f 100644 --- a/stack/tun/stack.go +++ b/stack/tun/stack.go @@ -1,65 +1,207 @@ package tun import ( + "context" + "fmt" + "github.com/mythologyli/zju-connect/internal/terminal_func" + "io" + + tun "github.com/cxz66666/sing-tun" + "github.com/miekg/dns" "github.com/mythologyli/zju-connect/client" + "github.com/mythologyli/zju-connect/internal/zcdns" + "github.com/mythologyli/zju-connect/internal/zctcpip" "github.com/mythologyli/zju-connect/log" - "golang.org/x/net/ipv4" - "io" - "syscall" ) +const MTU uint32 = 1400 + type Stack struct { endpoint *Endpoint rvpnConn io.ReadWriteCloser + resolve zcdns.LocalServer } -func (s *Stack) Run() { - s.rvpnConn, _ = client.NewRvpnConn(s.endpoint.easyConnectClient) +func (s *Stack) SetupResolve(r zcdns.LocalServer) { + s.resolve = r +} +func (s *Stack) Run() { + var connErr error + s.rvpnConn, connErr = client.NewRvpnConn(s.endpoint.easyConnectClient) + if connErr != nil { + panic(connErr) + } // Read from VPN server and send to TUN stack go func() { for { - buf := make([]byte, 1500) - n, _ := s.rvpnConn.Read(buf) - + buf := make([]byte, MTU+tun.PacketOffset) + n, err := s.rvpnConn.Read(buf) + if err != nil { + panic(err) + } log.DebugPrintf("Recv: read %d bytes", n) log.DebugDumpHex(buf[:n]) - err := s.endpoint.Write(buf[:n]) + err = s.endpoint.Write(buf[:n]) if err != nil { - log.Printf("Error occurred while writing to TUN stack: %v", err) - panic(err) + if terminal_func.IsTerminal() { + return + } else { + log.Printf("Error occurred while writing to TUN stack: %v", err) + panic(err) + } } } }() // Read from TUN stack and send to VPN server for { - buf := make([]byte, 1500) + buf := make([]byte, MTU+tun.PacketOffset) n, err := s.endpoint.Read(buf) if err != nil { - log.Printf("Error occurred while reading from TUN stack: %v", err) - // TODO graceful shutdown - panic(err) + if terminal_func.IsTerminal() { + return + } else { + log.Printf("Error occurred while reading from TUN stack: %v", err) + // TODO graceful shutdown + panic(err) + } } - if n < 20 { + if n < zctcpip.IPv4PacketMinLength { continue } - header, err := ipv4.ParseHeader(buf[:n]) + // whether this should be a blocking operation? + packet := buf[tun.PacketOffset:n] + switch ipVersion := packet[0] >> 4; ipVersion { + case zctcpip.IPv4Version: + err = s.processIPV4(packet) + default: + err = fmt.Errorf("unsupport IP version %d", ipVersion) + } if err != nil { + log.DebugPrintf("Error occurred while processing IP packet: %v", err) continue } - // Filter out non-TCP/UDP packets otherwise error may occur - if header.Protocol != syscall.IPPROTO_TCP && header.Protocol != syscall.IPPROTO_UDP { - continue - } + } +} + +func (s *Stack) processIPV4(packet zctcpip.IPv4Packet) error { + switch packet.Protocol() { + case zctcpip.TCP: + return s.processIPV4TCP(packet, packet.Payload()) + case zctcpip.UDP: + return s.processIPV4UDP(packet, packet.Payload()) + case zctcpip.ICMP: + return s.processIPV4ICMP(packet, packet.Payload()) + default: + return fmt.Errorf("unknown protocol %d", packet[9]) + } +} + +func (s *Stack) processIPV4TCP(packet zctcpip.IPv4Packet, tcpPacket zctcpip.TCPPacket) error { + log.DebugPrintf("receive tcp %s:%d -> %s:%d", packet.SourceIP(), tcpPacket.SourcePort(), packet.DestinationIP(), tcpPacket.DestinationPort()) + + if !packet.DestinationIP().IsGlobalUnicast() { + return s.endpoint.Write(packet) + } + n, err := s.rvpnConn.Write(packet) + if err != nil { + panic(err) + } + log.DebugPrintf("Send: wrote %d bytes", n) + log.DebugDumpHex(packet[:n]) + + return err +} + +func (s *Stack) processIPV4UDP(packet zctcpip.IPv4Packet, udpPacket zctcpip.UDPPacket) error { + log.DebugPrintf("receive udp %s:%d -> %s:%d", packet.SourceIP(), udpPacket.SourcePort(), packet.DestinationIP(), udpPacket.DestinationPort()) + + if !packet.DestinationIP().IsGlobalUnicast() { + return s.endpoint.Write(packet) + } + + if s.shouldHijackUDPDns(packet, udpPacket) { + newPacket := make(zctcpip.IPv4Packet, len(packet)) + copy(newPacket, packet) + newUdpPacket := zctcpip.UDPPacket(newPacket.Payload()) + // need to be non-blocking + go s.doHijackUDPDns(newPacket, newUdpPacket) + return nil + } + + n, err := s.rvpnConn.Write(packet) + if err != nil { + panic(err) + } + log.DebugPrintf("Send: wrote %d bytes", n) + log.DebugDumpHex(packet[:n]) + + return err +} + +func (s *Stack) processIPV4ICMP(packet zctcpip.IPv4Packet, icmpHeader zctcpip.ICMPPacket) error { + log.DebugPrintf("receive icmp %s -> %s", packet.SourceIP(), packet.DestinationIP()) + if icmpHeader.Type() != zctcpip.ICMPTypePingRequest || icmpHeader.Code() != 0 { + return nil + } + icmpHeader.SetType(zctcpip.ICMPTypePingResponse) + sourceIP := packet.SourceIP() + packet.SetSourceIP(packet.DestinationIP()) + packet.SetDestinationIP(sourceIP) + + icmpHeader.ResetChecksum() + packet.ResetChecksum() - _, _ = s.rvpnConn.Write(buf[:n]) + return s.endpoint.Write(packet) +} - log.DebugPrintf("Send: wrote %d bytes", n) - log.DebugDumpHex(buf[:n]) +// only can handle udp dns query! +func (s *Stack) shouldHijackUDPDns(ipHeader zctcpip.IPv4Packet, udpHeader zctcpip.UDPPacket) bool { + if udpHeader.DestinationPort() != 53 { + return false + } + return s.resolve.CheckDnsHijack(ipHeader.DestinationIP()) +} + +func (s *Stack) doHijackUDPDns(ipHeader zctcpip.IPv4Packet, udpHeader zctcpip.UDPPacket) { + log.Printf("hijack dns %s:%d -> %s:%d", ipHeader.SourceIP(), udpHeader.SourcePort(), ipHeader.DestinationIP(), udpHeader.DestinationPort()) + msg := dns.Msg{} + if err := msg.Unpack(udpHeader.Payload()); err != nil { + log.Printf("unpack dns msg error: %v", err) + return + } + resMsg, err := s.resolve.HandleDnsMsg(context.Background(), &msg) + if err != nil { + log.Printf("hijack dns error: %v", err) + return } + + resByte, err := resMsg.Pack() + if err != nil { + log.Printf("pack dns msg error: %v", err) + return + } + + totalLen := int(ipHeader.HeaderLen()) + zctcpip.UDPHeaderSize + len(resByte) + + newPacket := make(zctcpip.IPv4Packet, totalLen) + copy(newPacket, ipHeader[:ipHeader.HeaderLen()]) + newPacket.SetTotalLength(uint16(totalLen)) + newPacket.SetSourceIP(ipHeader.DestinationIP()) + newPacket.SetDestinationIP(ipHeader.SourceIP()) + + newUDPHeader := zctcpip.UDPPacket(newPacket.Payload()) + newUDPHeader.SetSourcePort(udpHeader.DestinationPort()) + newUDPHeader.SetDestinationPort(udpHeader.SourcePort()) + newUDPHeader.SetLength(zctcpip.UDPHeaderSize + uint16(len(resByte))) + copy(newUDPHeader.Payload(), resByte) + + newUDPHeader.ResetChecksum(newPacket.PseudoSum()) + newPacket.ResetChecksum() + _ = s.endpoint.Write(newPacket) } diff --git a/stack/tun/stack_darwin.go b/stack/tun/stack_darwin.go index 172f9b4..a748a01 100644 --- a/stack/tun/stack_darwin.go +++ b/stack/tun/stack_darwin.go @@ -1,36 +1,50 @@ package tun import ( + "context" + "fmt" + tun "github.com/cxz66666/sing-tun" "github.com/mythologyli/zju-connect/client" + "github.com/mythologyli/zju-connect/internal/terminal_func" "github.com/mythologyli/zju-connect/log" - "github.com/songgao/water" "golang.org/x/sys/unix" "net" + "net/netip" + "os" "os/exec" + "sync" "syscall" ) type Endpoint struct { easyConnectClient *client.EasyConnectClient - ifce *water.Interface - ip net.IP + ifce tun.Tun + ifceName string + ifceIndex int + readLock sync.Mutex + writeLock sync.Mutex + ip net.IP tcpDialer *net.Dialer udpDialer *net.Dialer } func (ep *Endpoint) Write(buf []byte) error { + ep.writeLock.Lock() + defer ep.writeLock.Unlock() _, err := ep.ifce.Write(buf) return err } func (ep *Endpoint) Read(buf []byte) (int, error) { + ep.readLock.Lock() + defer ep.readLock.Unlock() return ep.ifce.Read(buf) } func (s *Stack) AddRoute(target string) error { - command := exec.Command("route", "-n", "add", "-net", target, "-interface", s.endpoint.ifce.Name()) + command := exec.Command("route", "-n", "add", "-net", target, "-interface", s.endpoint.ifceName) err := command.Run() if err != nil { return err @@ -39,38 +53,74 @@ func (s *Stack) AddRoute(target string) error { return nil } -func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*Stack, error) { - s := &Stack{} - - ifce, err := water.New(water.Config{ - DeviceType: water.TUN, - }) +func (s *Stack) AddDnsServer(dnsServer string, targetHost string) error { + fileName := fmt.Sprintf("/etc/resolver/%s", targetHost) + file, err := os.Create(fileName) if err != nil { - return nil, err + return err } + defer file.Close() + + file.WriteString(fmt.Sprintf("nameserver %s\n", dnsServer)) - log.Printf("Interface Name: %s\n", ifce.Name()) + terminal_func.RegisterTerminalFunc("DelDnsServer_"+targetHost, func(ctx context.Context) error { + delCommand := exec.Command("rm", fmt.Sprintf("/etc/resolver/%s", targetHost)) + delErr := delCommand.Run() + if delErr != nil { + return delErr + } + return nil + }) + return nil +} +func NewStack(easyConnectClient *client.EasyConnectClient, dnsHijack bool) (*Stack, error) { + var err error + s := &Stack{} s.endpoint = &Endpoint{ easyConnectClient: easyConnectClient, } - s.endpoint.ifce = ifce - - // Get index of TUN interface - netIfce, err := net.InterfaceByName(ifce.Name()) + s.endpoint.ip, err = easyConnectClient.IP() if err != nil { return nil, err } + ipPrefix, _ := netip.ParsePrefix(s.endpoint.ip.String() + "/8") + zjuPrefix, _ := netip.ParsePrefix("10.0.0.0/8") + tunName := "utun0" + tunName = tun.CalculateInterfaceName(tunName) + tunOptions := tun.Options{ + Name: tunName, + MTU: MTU, + Inet4Address: []netip.Prefix{ + ipPrefix, + }, + Inet4RouteAddress: []netip.Prefix{ + zjuPrefix, + }, + AutoRoute: true, + TableIndex: 1897, + } - ifceIndex := netIfce.Index - - s.endpoint.ip, err = easyConnectClient.IP() + ifce, err := tun.New(tunOptions) + if err != nil { + return nil, err + } + terminal_func.RegisterTerminalFunc("Close Tun Device", func(ctx context.Context) error { + return ifce.Close() + }) + s.endpoint.ifce = ifce + s.endpoint.ifceName = tunName + netIfce, err := net.InterfaceByName(tunName) if err != nil { return nil, err } + s.endpoint.ifceIndex = netIfce.Index + log.Printf("Interface Name: %s, index %d\n", tunName, netIfce.Index) + // We need this dialer to bind to device otherwise packets will not be sent via TUN + // Doesn't work on macos. See https://github.com/Mythologyli/zju-connect/pull/44#issuecomment-1784050022 s.endpoint.tcpDialer = &net.Dialer{ LocalAddr: &net.TCPAddr{ IP: s.endpoint.ip, @@ -78,13 +128,14 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S }, Control: func(network, address string, c syscall.RawConn) error { // By ChenXuzheng return c.Control(func(fd uintptr) { - if err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVIF, ifceIndex); err != nil { - log.Println("Warning: failed to bind to interface", s.endpoint.ifce.Name()) + if err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVIF, s.endpoint.ifceIndex); err != nil { + log.Println("Warning: failed to bind to interface", s.endpoint.ifceName) } }) }, } + // Doesn't work on macos. See https://github.com/Mythologyli/zju-connect/pull/44#issuecomment-1784050022 s.endpoint.udpDialer = &net.Dialer{ LocalAddr: &net.UDPAddr{ IP: s.endpoint.ip, @@ -92,29 +143,19 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S }, Control: func(network, address string, c syscall.RawConn) error { // By ChenXuzheng return c.Control(func(fd uintptr) { - if err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVIF, ifceIndex); err != nil { - log.Println("Warning: failed to bind to interface", s.endpoint.ifce.Name()) + if err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVIF, s.endpoint.ifceIndex); err != nil { + log.Println("Warning: failed to bind to interface", s.endpoint.ifceName) } }) }, } - - cmd := exec.Command("ifconfig", ifce.Name(), s.endpoint.ip.String(), "255.0.0.0", s.endpoint.ip.String()) - err = cmd.Run() - if err != nil { - log.Printf("Run %s failed: %v", cmd.String(), err) - } - - if err = s.AddRoute("10.0.0.0/8"); err != nil { - log.Printf("Run AddRoute 10.0.0.0/8 failed: %v", err) + if dnsHijack { + if err = s.AddDnsServer(s.endpoint.ip.String(), "zju.edu.cn"); err != nil { + log.Printf("AddDnsServer failed: %v", err) + } + if err = s.AddDnsServer(s.endpoint.ip.String(), "cc98.org"); err != nil { + log.Printf("AddDnsServer failed: %v", err) + } } - - // Set MTU to 1400 otherwise error may occur when packets are large - cmd = exec.Command("ifconfig", ifce.Name(), "mtu", "1400", "up") - err = cmd.Run() - if err != nil { - log.Printf("Run %s failed: %v", cmd.String(), err) - } - return s, nil } diff --git a/stack/tun/stack_linux.go b/stack/tun/stack_linux.go index b24d2b6..5c83f3d 100644 --- a/stack/tun/stack_linux.go +++ b/stack/tun/stack_linux.go @@ -1,35 +1,46 @@ package tun import ( + "context" + tun "github.com/cxz66666/sing-tun" "github.com/mythologyli/zju-connect/client" + "github.com/mythologyli/zju-connect/internal/terminal_func" "github.com/mythologyli/zju-connect/log" - "github.com/songgao/water" "net" + "net/netip" "os/exec" + "sync" "syscall" ) type Endpoint struct { easyConnectClient *client.EasyConnectClient - ifce *water.Interface - ip net.IP + ifce tun.Tun + ifceName string + readLock sync.Mutex + writeLock sync.Mutex + ip net.IP tcpDialer *net.Dialer udpDialer *net.Dialer } func (ep *Endpoint) Write(buf []byte) error { + ep.writeLock.Lock() + defer ep.writeLock.Unlock() _, err := ep.ifce.Write(buf) return err } func (ep *Endpoint) Read(buf []byte) (int, error) { + ep.readLock.Lock() + defer ep.readLock.Unlock() return ep.ifce.Read(buf) } func (s *Stack) AddRoute(target string) error { - command := exec.Command("ip", "route", "add", target, "dev", s.endpoint.ifce.Name()) + command := exec.Command("ip", "route", "add", target, "dev", s.endpoint.ifceName) err := command.Run() if err != nil { return err @@ -38,28 +49,42 @@ func (s *Stack) AddRoute(target string) error { return nil } -func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*Stack, error) { +func NewStack(easyConnectClient *client.EasyConnectClient, dnsHijack bool) (*Stack, error) { + var err error s := &Stack{} - - ifce, err := water.New(water.Config{ - DeviceType: water.TUN, - }) - if err != nil { - return nil, err - } - - log.Printf("Interface Name: %s\n", ifce.Name()) - s.endpoint = &Endpoint{ easyConnectClient: easyConnectClient, } - s.endpoint.ifce = ifce - s.endpoint.ip, err = easyConnectClient.IP() if err != nil { return nil, err } + ipPrefix, _ := netip.ParsePrefix(s.endpoint.ip.String() + "/8") + tunName := "ZJU-Connect" + tunName = tun.CalculateInterfaceName(tunName) + + tunOptions := tun.Options{ + Name: tunName, + MTU: MTU, + Inet4Address: []netip.Prefix{ + ipPrefix, + }, + } + if dnsHijack { + tunOptions.AutoRoute = true + tunOptions.TableIndex = 1897 + } + ifce, err := tun.New(tunOptions) + if err != nil { + return nil, err + } + terminal_func.RegisterTerminalFunc("Close Tun Device", func(ctx context.Context) error { + return ifce.Close() + }) + s.endpoint.ifce = ifce + s.endpoint.ifceName = tunName + log.Printf("Interface Name: %s\n", tunName) // We need this dialer to bind to device otherwise packets will not be sent via TUN s.endpoint.tcpDialer = &net.Dialer{ @@ -69,8 +94,8 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S }, Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { - if err := syscall.BindToDevice(int(fd), s.endpoint.ifce.Name()); err != nil { - log.Println("Warning: failed to bind to interface", s.endpoint.ifce.Name()) + if err := syscall.BindToDevice(int(fd), s.endpoint.ifceName); err != nil { + log.Println("Warning: failed to bind to interface", s.endpoint.ifceName) } }) }, @@ -83,31 +108,12 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S }, Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { - if err := syscall.BindToDevice(int(fd), s.endpoint.ifce.Name()); err != nil { - log.Println("Warning: failed to bind to interface", s.endpoint.ifce.Name()) + if err := syscall.BindToDevice(int(fd), s.endpoint.ifceName); err != nil { + log.Println("Warning: failed to bind to interface", s.endpoint.ifceName) } }) }, } - cmd := exec.Command("ip", "link", "set", ifce.Name(), "up") - err = cmd.Run() - if err != nil { - log.Printf("Run %s failed: %v", cmd.String(), err) - } - - // Set MTU to 1400 otherwise error may occur when packets are large - cmd = exec.Command("ip", "link", "set", "dev", ifce.Name(), "mtu", "1400") - err = cmd.Run() - if err != nil { - log.Printf("Run %s failed: %v", cmd.String(), err) - } - - cmd = exec.Command("ip", "addr", "add", s.endpoint.ip.String()+"/8", "dev", ifce.Name()) - err = cmd.Run() - if err != nil { - log.Printf("Run %s failed: %v", cmd.String(), err) - } - return s, nil } diff --git a/stack/tun/stack_windows.go b/stack/tun/stack_windows.go index 99728d1..9ab1f55 100644 --- a/stack/tun/stack_windows.go +++ b/stack/tun/stack_windows.go @@ -1,8 +1,10 @@ package tun import ( + "context" "fmt" "github.com/mythologyli/zju-connect/client" + "github.com/mythologyli/zju-connect/internal/terminal_func" "github.com/mythologyli/zju-connect/log" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/tun" @@ -10,6 +12,7 @@ import ( "net" "net/netip" "os/exec" + "sync" ) const guid = "{4F5CDE94-D2A3-4AA5-A4A3-0FE6CB909E83}" @@ -18,11 +21,15 @@ const interfaceName = "ZJU Connect" type Endpoint struct { easyConnectClient *client.EasyConnectClient - dev tun.Device - ip net.IP + dev tun.Device + readLock sync.Mutex + writeLock sync.Mutex + ip net.IP } func (ep *Endpoint) Write(buf []byte) error { + ep.writeLock.Lock() + defer ep.writeLock.Unlock() bufs := [][]byte{buf} _, err := ep.dev.Write(bufs, 0) @@ -34,6 +41,8 @@ func (ep *Endpoint) Write(buf []byte) error { } func (ep *Endpoint) Read(buf []byte) (int, error) { + ep.readLock.Lock() + defer ep.readLock.Unlock() bufs := [][]byte{buf} sizes := []int{1} @@ -65,7 +74,7 @@ func (s *Stack) AddRoute(target string) error { return nil } -func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*Stack, error) { +func NewStack(easyConnectClient *client.EasyConnectClient, dnsHijack bool) (*Stack, error) { s := &Stack{} guid, err := windows.GUIDFromString(guid) @@ -73,7 +82,7 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S return nil, err } - dev, err := tun.CreateTUNWithRequestedGUID(interfaceName, &guid, 1400) + dev, err := tun.CreateTUNWithRequestedGUID(interfaceName, &guid, int(MTU)) if err != nil { return nil, err } @@ -104,7 +113,7 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S } // Set MTU to 1400 otherwise error may occur when packets are large - command := exec.Command("netsh", "interface", "ipv4", "set", "subinterface", interfaceName, "mtu=1400", "store=persistent") + command := exec.Command("netsh", "interface", "ipv4", "set", "subinterface", interfaceName, fmt.Sprintf("mtu=%d", MTU), "store=persistent") err = command.Run() if err != nil { log.Printf("Run %s failed: %v", command.String(), err) @@ -118,8 +127,8 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S log.Printf("Run %s failed: %v", command.String(), err) } - if dnsServer != "" { - command = exec.Command("netsh", "interface", "ipv4", "add", "dnsservers", "ZJU Connect", dnsServer) + if dnsHijack { + command = exec.Command("netsh", "interface", "ipv4", "add", "dnsservers", "ZJU Connect", s.endpoint.ip.String()) } else { command = exec.Command("netsh", "interface", "ipv4", "delete", "dnsservers", "ZJU Connect", "all") } @@ -128,5 +137,10 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S log.Printf("Run %s failed: %v", command.String(), err) } + terminal_func.RegisterTerminalFunc("Close Tun Device", func(ctx context.Context) error { + dev.Close() + closeCommand := exec.Command("netsh", "interface", "ipv4", "delete", "dnsservers", "ZJU Connect", "all") + return closeCommand.Run() + }) return s, nil }