From cc3f4c755d1d3ceb6f9f08809b9e3c3b82132cef Mon Sep 17 00:00:00 2001 From: enfein <83481737+enfein@users.noreply.github.com> Date: Wed, 13 Nov 2024 01:16:32 +0000 Subject: [PATCH] v3.8.0 release 1. Allow applications to read data after network connection is closed (issue #168). 2. Update dependency versions. Breaking change: the MTU value in client and server configuration now represent the maximum transmission unit in UDP layer or nested network connection, rather than in data-link layer. We don't recommend setting MTU value bigger than 1440. TCP protocol is not affected. --- Makefile | 2 +- apis/client/interface.go | 4 +- apis/client/mieru.go | 58 ++- apis/common/dns.go | 9 + apis/model/addr.go | 11 + .../package/mieru/amd64/debian/DEBIAN/control | 2 +- build/package/mieru/amd64/rpm/mieru.spec | 2 +- .../package/mieru/arm64/debian/DEBIAN/control | 2 +- build/package/mieru/arm64/rpm/mieru.spec | 2 +- .../package/mita/amd64/debian/DEBIAN/control | 2 +- build/package/mita/amd64/rpm/mita.spec | 2 +- .../package/mita/arm64/debian/DEBIAN/control | 2 +- build/package/mita/arm64/rpm/mita.spec | 2 +- docs/client-install.md | 2 +- docs/client-install.zh_CN.md | 2 +- docs/server-install.md | 18 +- docs/server-install.zh_CN.md | 18 +- pkg/cli/client.go | 37 +- pkg/common/dns.go | 76 --- pkg/common/dns_test.go | 58 --- pkg/common/sockopts/control.go | 32 +- pkg/congestion/bbr_sender.go | 4 + pkg/protocol/mux.go | 20 +- pkg/protocol/session.go | 481 ++++++++++-------- pkg/protocol/underlay_packet.go | 13 +- pkg/protocol/underlay_stream.go | 7 +- pkg/socks5/request.go | 17 +- pkg/socks5/socks5.go | 5 +- pkg/stderror/template.go | 1 + pkg/version/current.go | 2 +- 30 files changed, 440 insertions(+), 453 deletions(-) delete mode 100644 pkg/common/dns.go delete mode 100644 pkg/common/dns_test.go diff --git a/Makefile b/Makefile index 264c2fcd..3ff84e13 100644 --- a/Makefile +++ b/Makefile @@ -32,7 +32,7 @@ PROJECT_NAME=$(shell basename "${ROOT}") # - pkg/version/current.go # # Use `tools/bump_version.sh` script to change all those files at one shot. -VERSION="3.7.0" +VERSION="3.8.0" # Build binaries and installation packages. .PHONY: build diff --git a/apis/client/interface.go b/apis/client/interface.go index cfcaf5f3..ceb9aa9b 100644 --- a/apis/client/interface.go +++ b/apis/client/interface.go @@ -20,6 +20,7 @@ import ( "errors" "net" + apicommon "github.com/enfein/mieru/v3/apis/common" "github.com/enfein/mieru/v3/pkg/appctl/appctlpb" ) @@ -79,7 +80,8 @@ type ClientNetworkService interface { // ClientConfig stores proxy client configuration. type ClientConfig struct { - Profile *appctlpb.ClientProfile + Profile *appctlpb.ClientProfile + Resolver apicommon.DNSResolver } // NewClient creates a blank mieru client with no client config. diff --git a/apis/client/mieru.go b/apis/client/mieru.go index cb191b95..09a755ac 100644 --- a/apis/client/mieru.go +++ b/apis/client/mieru.go @@ -26,6 +26,7 @@ import ( "sync" "time" + apicommon "github.com/enfein/mieru/v3/apis/common" "github.com/enfein/mieru/v3/apis/constant" "github.com/enfein/mieru/v3/apis/model" "github.com/enfein/mieru/v3/pkg/appctl" @@ -97,6 +98,15 @@ func (mc *mieruClient) Start() error { mc.mux = protocol.NewMux(true) activeProfile := mc.config.Profile + // Set DNS resolver. + var resolver apicommon.DNSResolver + if mc.config.Resolver != nil { + resolver = mc.config.Resolver + } else { + resolver = &net.Resolver{} // Default DNS resolver. + } + mc.mux.SetResolver(resolver) + // Set user name and password. user := activeProfile.GetUser() var hashedPassword []byte @@ -130,16 +140,19 @@ func (mc *mieruClient) Start() error { mtu = int(activeProfile.GetMtu()) } endpoints := make([]protocol.UnderlayProperties, 0) - resolver := &common.DNSResolver{} for _, serverInfo := range activeProfile.GetServers() { var proxyHost string var proxyIP net.IP if serverInfo.GetDomainName() != "" { proxyHost = serverInfo.GetDomainName() - proxyIP, err = resolver.LookupIP(context.Background(), proxyHost) + proxyIPs, err := resolver.LookupIP(context.Background(), "ip", proxyHost) if err != nil { return fmt.Errorf(stderror.LookupIPFailedErr, err) } + if len(proxyIPs) == 0 { + return fmt.Errorf(stderror.IPAddressNotFound, proxyHost) + } + proxyIP = proxyIPs[0] } else { proxyHost = serverInfo.GetIpAddress() proxyIP = net.ParseIP(proxyHost) @@ -197,14 +210,8 @@ func (mc *mieruClient) DialContext(ctx context.Context, addr net.Addr) (net.Conn // Check destination address. var netAddrSpec model.NetAddrSpec - if nas, ok := addr.(model.NetAddrSpec); ok { - netAddrSpec = nas - } else if nas, ok := addr.(*model.NetAddrSpec); ok { - netAddrSpec = *nas - } else { - if err := netAddrSpec.From(addr); err != nil { - return nil, fmt.Errorf("invalid destination address: %w", err) - } + if err := netAddrSpec.From(addr); err != nil { + return nil, fmt.Errorf("invalid destination address: %w", err) } if !strings.HasPrefix(netAddrSpec.Network(), "tcp") { return nil, fmt.Errorf("only tcp network is supported") @@ -214,6 +221,33 @@ func (mc *mieruClient) DialContext(ctx context.Context, addr net.Addr) (net.Conn if err != nil { return nil, err } + return mc.dialPostHandshake(conn, netAddrSpec) +} + +func (mc *mieruClient) DialContextWithConn(ctx context.Context, conn net.Conn, addr net.Addr) (net.Conn, error) { + mc.mu.RLock() + defer mc.mu.RUnlock() + if !mc.running { + return nil, ErrClientIsNotRunning + } + + // Check destination address. + var netAddrSpec model.NetAddrSpec + if err := netAddrSpec.From(addr); err != nil { + return nil, fmt.Errorf("invalid destination address: %w", err) + } + if !strings.HasPrefix(netAddrSpec.Network(), "tcp") { + return nil, fmt.Errorf("only tcp network is supported") + } + + subConn, err := mc.mux.DialContextWithConn(ctx, conn) + if err != nil { + return nil, err + } + return mc.dialPostHandshake(subConn, netAddrSpec) +} + +func (mc *mieruClient) dialPostHandshake(conn net.Conn, netAddrSpec model.NetAddrSpec) (net.Conn, error) { var req bytes.Buffer req.Write([]byte{constant.Socks5Version, constant.Socks5ConnectCmd, 0}) if err := netAddrSpec.WriteToSocks5(&req); err != nil { @@ -241,7 +275,3 @@ func (mc *mieruClient) DialContext(ctx context.Context, addr net.Addr) (net.Conn } return conn, nil } - -func (mc *mieruClient) DialContextWithConn(ctx context.Context, conn net.Conn, addr net.Addr) (net.Conn, error) { - return nil, fmt.Errorf("not implemented") -} diff --git a/apis/common/dns.go b/apis/common/dns.go index 0983139b..67a99438 100644 --- a/apis/common/dns.go +++ b/apis/common/dns.go @@ -104,3 +104,12 @@ func ResolveUDPAddr(r DNSResolver, network, address string) (*net.UDPAddr, error return &net.UDPAddr{IP: ips[0], Port: port}, nil } + +// ForbidDefaultResolver causes the process to panic if +// net.DefaultResolver object is used. +func ForbidDefaultResolver() { + net.DefaultResolver.PreferGo = true + net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { + panic("Using net.DefaultResolver is forbidden") + } +} diff --git a/apis/model/addr.go b/apis/model/addr.go index a9f39256..2392086c 100644 --- a/apis/model/addr.go +++ b/apis/model/addr.go @@ -137,6 +137,17 @@ func (n NetAddrSpec) Network() string { // From modifies the NetAddrSpec object with the given network address. func (n *NetAddrSpec) From(addr net.Addr) error { + if nas, ok := addr.(NetAddrSpec); ok { + n.AddrSpec = nas.AddrSpec + n.Net = nas.Net + return nil + } + if nas, ok := addr.(*NetAddrSpec); ok { + n.AddrSpec = nas.AddrSpec + n.Net = nas.Net + return nil + } + n.Net = addr.Network() host, portStr, err := net.SplitHostPort(addr.String()) diff --git a/build/package/mieru/amd64/debian/DEBIAN/control b/build/package/mieru/amd64/debian/DEBIAN/control index 1416d8a6..3f8a86ef 100755 --- a/build/package/mieru/amd64/debian/DEBIAN/control +++ b/build/package/mieru/amd64/debian/DEBIAN/control @@ -1,5 +1,5 @@ Package: mieru -Version: 3.7.0 +Version: 3.8.0 Section: net Priority: optional Architecture: amd64 diff --git a/build/package/mieru/amd64/rpm/mieru.spec b/build/package/mieru/amd64/rpm/mieru.spec index f588bc97..86768378 100644 --- a/build/package/mieru/amd64/rpm/mieru.spec +++ b/build/package/mieru/amd64/rpm/mieru.spec @@ -1,5 +1,5 @@ Name: mieru -Version: 3.7.0 +Version: 3.8.0 Release: 1%{?dist} Summary: Mieru proxy client License: GPLv3+ diff --git a/build/package/mieru/arm64/debian/DEBIAN/control b/build/package/mieru/arm64/debian/DEBIAN/control index 7c50e53b..9de9c81b 100755 --- a/build/package/mieru/arm64/debian/DEBIAN/control +++ b/build/package/mieru/arm64/debian/DEBIAN/control @@ -1,5 +1,5 @@ Package: mieru -Version: 3.7.0 +Version: 3.8.0 Section: net Priority: optional Architecture: arm64 diff --git a/build/package/mieru/arm64/rpm/mieru.spec b/build/package/mieru/arm64/rpm/mieru.spec index f588bc97..86768378 100644 --- a/build/package/mieru/arm64/rpm/mieru.spec +++ b/build/package/mieru/arm64/rpm/mieru.spec @@ -1,5 +1,5 @@ Name: mieru -Version: 3.7.0 +Version: 3.8.0 Release: 1%{?dist} Summary: Mieru proxy client License: GPLv3+ diff --git a/build/package/mita/amd64/debian/DEBIAN/control b/build/package/mita/amd64/debian/DEBIAN/control index 9787f04e..b57783b2 100755 --- a/build/package/mita/amd64/debian/DEBIAN/control +++ b/build/package/mita/amd64/debian/DEBIAN/control @@ -1,5 +1,5 @@ Package: mita -Version: 3.7.0 +Version: 3.8.0 Section: net Priority: optional Architecture: amd64 diff --git a/build/package/mita/amd64/rpm/mita.spec b/build/package/mita/amd64/rpm/mita.spec index 782ea619..4c6accbe 100644 --- a/build/package/mita/amd64/rpm/mita.spec +++ b/build/package/mita/amd64/rpm/mita.spec @@ -1,5 +1,5 @@ Name: mita -Version: 3.7.0 +Version: 3.8.0 Release: 1%{?dist} Summary: Mieru proxy server License: GPLv3+ diff --git a/build/package/mita/arm64/debian/DEBIAN/control b/build/package/mita/arm64/debian/DEBIAN/control index 27af556a..8932173a 100755 --- a/build/package/mita/arm64/debian/DEBIAN/control +++ b/build/package/mita/arm64/debian/DEBIAN/control @@ -1,5 +1,5 @@ Package: mita -Version: 3.7.0 +Version: 3.8.0 Section: net Priority: optional Architecture: arm64 diff --git a/build/package/mita/arm64/rpm/mita.spec b/build/package/mita/arm64/rpm/mita.spec index c0e18f58..03f068c0 100644 --- a/build/package/mita/arm64/rpm/mita.spec +++ b/build/package/mita/arm64/rpm/mita.spec @@ -1,5 +1,5 @@ Name: mita -Version: 3.7.0 +Version: 3.8.0 Release: 1%{?dist} Summary: Mieru proxy server License: GPLv3+ diff --git a/docs/client-install.md b/docs/client-install.md index d95eb617..a1137640 100644 --- a/docs/client-install.md +++ b/docs/client-install.md @@ -64,7 +64,7 @@ Please use a text editor to modify the following fields. 3. In the `profiles` -> `servers` -> `ipAddress` property, fill in the public address of the proxy server. Both IPv4 and IPv6 addresses are supported. 4. If you have registered a domain name for the proxy server, please fill in the domain name in `profiles` -> `servers` -> `domainName`. Otherwise, do not modify this property. 5. Fill in `profiles` -> `servers` -> `portBindings` -> `port` with the TCP or UDP port number that mita is listening to. The port number must be the same as the one set in the proxy server. If you want to listen to a range of consecutive port numbers, you can also use the `portRange` property instead. -6. Specify a value between 1280 and 1500 for the `profiles` -> `mtu` property. The default value is 1400. This value can be different from the setting in the proxy server. +6. Specify a value between 1280 and 1400 for the `profiles` -> `mtu` property. The default value is 1400. This value can be different from the setting in the proxy server. 7. If you want to adjust the frequency of multiplexing, you can set a value for the `profiles` -> `multiplexing` -> `level` property. The values you can use here include `MULTIPLEXING_OFF`, `MULTIPLEXING_LOW`, `MULTIPLEXING_MIDDLE`, and `MULTIPLEXING_HIGH`. `MULTIPLEXING_OFF` will disable multiplexing, and the default value is `MULTIPLEXING_LOW`. 8. Please specify a value between 1025 and 65535 for the `rpcPort` property. 9. Please specify a value between 1025 and 65535 for the `socks5Port` property. This port cannot be the same as `rpcPort`. diff --git a/docs/client-install.zh_CN.md b/docs/client-install.zh_CN.md index 4a1cd820..1897a036 100644 --- a/docs/client-install.zh_CN.md +++ b/docs/client-install.zh_CN.md @@ -64,7 +64,7 @@ mieru apply config 3. 在 `profiles` -> `servers` -> `ipAddress` 属性中,填写代理服务器的公网地址。支持 IPv4 和 IPv6 地址。 4. 如果你为代理服务器注册了域名,请在 `profiles` -> `servers` -> `domainName` 中填写域名。否则,请勿修改这个属性。 5. 在 `profiles` -> `servers` -> `portBindings` -> `port` 中填写 mita 监听的 TCP 或 UDP 端口号。这个端口号必须与代理服务器中的设置相同。如果想要监听连续的端口号,也可以改为使用 `portRange` 属性。 -6. 请为 `profiles` -> `mtu` 属性中指定一个从 1280 到 1500 之间的值。默认值为 1400。这个值可以与代理服务器中的设置不同。 +6. 请为 `profiles` -> `mtu` 属性中指定一个从 1280 到 1400 之间的值。默认值为 1400。这个值可以与代理服务器中的设置不同。 7. 如果想要调整多路复用的频率,是更多地创建新连接,还是更多地重用旧连接,可以为 `profiles` -> `multiplexing` -> `level` 属性设定一个值。这里可以使用的值包括 `MULTIPLEXING_OFF`, `MULTIPLEXING_LOW`, `MULTIPLEXING_MIDDLE`, `MULTIPLEXING_HIGH`。其中 `MULTIPLEXING_OFF` 会关闭多路复用功能。默认值为 `MULTIPLEXING_LOW`。 8. 请为 `rpcPort` 属性指定一个从 1025 到 65535 之间的数值。 9. 请为 `socks5Port` 属性指定一个从 1025 到 65535 之间的数值。该端口不能与 `rpcPort` 相同。 diff --git a/docs/server-install.md b/docs/server-install.md index 048eb268..971fdaf9 100644 --- a/docs/server-install.md +++ b/docs/server-install.md @@ -8,32 +8,32 @@ Before installation and configuration, connect to the server via SSH and then ex ```sh # Debian / Ubuntu - X86_64 -curl -LSO https://github.com/enfein/mieru/releases/download/v3.7.0/mita_3.7.0_amd64.deb +curl -LSO https://github.com/enfein/mieru/releases/download/v3.8.0/mita_3.8.0_amd64.deb # Debian / Ubuntu - ARM 64 -curl -LSO https://github.com/enfein/mieru/releases/download/v3.7.0/mita_3.7.0_arm64.deb +curl -LSO https://github.com/enfein/mieru/releases/download/v3.8.0/mita_3.8.0_arm64.deb # RedHat / CentOS / Rocky Linux - X86_64 -curl -LSO https://github.com/enfein/mieru/releases/download/v3.7.0/mita-3.7.0-1.x86_64.rpm +curl -LSO https://github.com/enfein/mieru/releases/download/v3.8.0/mita-3.8.0-1.x86_64.rpm # RedHat / CentOS / Rocky Linux - ARM 64 -curl -LSO https://github.com/enfein/mieru/releases/download/v3.7.0/mita-3.7.0-1.aarch64.rpm +curl -LSO https://github.com/enfein/mieru/releases/download/v3.8.0/mita-3.8.0-1.aarch64.rpm ``` ## Install mita package ```sh # Debian / Ubuntu - X86_64 -sudo dpkg -i mita_3.7.0_amd64.deb +sudo dpkg -i mita_3.8.0_amd64.deb # Debian / Ubuntu - ARM 64 -sudo dpkg -i mita_3.7.0_arm64.deb +sudo dpkg -i mita_3.8.0_arm64.deb # RedHat / CentOS / Rocky Linux - X86_64 -sudo rpm -Uvh --force mita-3.7.0-1.x86_64.rpm +sudo rpm -Uvh --force mita-3.8.0-1.x86_64.rpm # RedHat / CentOS / Rocky Linux - ARM 64 -sudo rpm -Uvh --force mita-3.7.0-1.aarch64.rpm +sudo rpm -Uvh --force mita-3.8.0-1.aarch64.rpm ``` Those instructions can also be used to upgrade the version of mita software package. @@ -106,7 +106,7 @@ to modify the proxy server settings. `` is a JSON formatted configuration 2. The `portBindings` -> `protocol` property can be set to `TCP` or `UDP`. 3. Fill in the `users` -> `name` property with the user name. 4. Fill in the `users` -> `password` property with the user's password. -5. The `mtu` property is the maximum data link layer payload size when using the UDP proxy protocol. The default value is 1400. You can choose a value between 1280 and 1500. +5. The `mtu` property is the maximum transport layer payload size when using the UDP proxy protocol. The default value is 1400. The minimum value is 1280. In addition to this, mita can listen to several different ports. We recommend using multiple ports in both server and client configurations. diff --git a/docs/server-install.zh_CN.md b/docs/server-install.zh_CN.md index bcd40259..57314689 100644 --- a/docs/server-install.zh_CN.md +++ b/docs/server-install.zh_CN.md @@ -8,32 +8,32 @@ ```sh # Debian / Ubuntu - X86_64 -curl -LSO https://github.com/enfein/mieru/releases/download/v3.7.0/mita_3.7.0_amd64.deb +curl -LSO https://github.com/enfein/mieru/releases/download/v3.8.0/mita_3.8.0_amd64.deb # Debian / Ubuntu - ARM 64 -curl -LSO https://github.com/enfein/mieru/releases/download/v3.7.0/mita_3.7.0_arm64.deb +curl -LSO https://github.com/enfein/mieru/releases/download/v3.8.0/mita_3.8.0_arm64.deb # RedHat / CentOS / Rocky Linux - X86_64 -curl -LSO https://github.com/enfein/mieru/releases/download/v3.7.0/mita-3.7.0-1.x86_64.rpm +curl -LSO https://github.com/enfein/mieru/releases/download/v3.8.0/mita-3.8.0-1.x86_64.rpm # RedHat / CentOS / Rocky Linux - ARM 64 -curl -LSO https://github.com/enfein/mieru/releases/download/v3.7.0/mita-3.7.0-1.aarch64.rpm +curl -LSO https://github.com/enfein/mieru/releases/download/v3.8.0/mita-3.8.0-1.aarch64.rpm ``` ## 安装 mita 软件包 ```sh # Debian / Ubuntu - X86_64 -sudo dpkg -i mita_3.7.0_amd64.deb +sudo dpkg -i mita_3.8.0_amd64.deb # Debian / Ubuntu - ARM 64 -sudo dpkg -i mita_3.7.0_arm64.deb +sudo dpkg -i mita_3.8.0_arm64.deb # RedHat / CentOS / Rocky Linux - X86_64 -sudo rpm -Uvh --force mita-3.7.0-1.x86_64.rpm +sudo rpm -Uvh --force mita-3.8.0-1.x86_64.rpm # RedHat / CentOS / Rocky Linux - ARM 64 -sudo rpm -Uvh --force mita-3.7.0-1.aarch64.rpm +sudo rpm -Uvh --force mita-3.8.0-1.aarch64.rpm ``` 上述指令也可以用来升级 mita 软件包的版本。 @@ -106,7 +106,7 @@ mita apply config 2. `portBindings` -> `protocol` 属性可以使用 `TCP` 或者 `UDP`。 3. 在 `users` -> `name` 属性中填写用户名。 4. 在 `users` -> `password` 属性中填写该用户的密码。 -5. `mtu` 属性是使用 UDP 代理协议时,数据链路层最大的载荷大小。默认值是 1400,可以选择 1280 到 1500 之间的值。 +5. `mtu` 属性是使用 UDP 代理协议时,传输层最大的载荷大小。默认值是 1400,最小值是 1280。 除此之外,mita 可以监听多个不同的端口。我们建议在服务器和客户端配置中使用多个端口。 diff --git a/pkg/cli/client.go b/pkg/cli/client.go index c7a364b2..ee6de0c4 100644 --- a/pkg/cli/client.go +++ b/pkg/cli/client.go @@ -31,6 +31,7 @@ import ( "sync" "time" + apicommon "github.com/enfein/mieru/v3/apis/common" "github.com/enfein/mieru/v3/apis/constant" "github.com/enfein/mieru/v3/pkg/appctl" "github.com/enfein/mieru/v3/pkg/appctl/appctlgrpc" @@ -439,6 +440,8 @@ var clientRunFunc = func(s []string) error { serverDecryptionMetricGroup.DisableLogging() } + resolver := &net.Resolver{} + var wg sync.WaitGroup // RPC port is allowed to set to 0. In that case, don't run RPC server. @@ -448,10 +451,16 @@ var clientRunFunc = func(s []string) error { wg.Add(1) go func() { rpcAddr := "localhost:" + strconv.Itoa(int(config.GetRpcPort())) - listenConfig := sockopts.ListenConfigWithControls() - rpcListener, err := listenConfig.Listen(context.Background(), "tcp", rpcAddr) + rpcTCPAddr, err := apicommon.ResolveTCPAddr(resolver, "tcp", rpcAddr) + if err != nil { + log.Fatalf("Resolve RPC address %q failed: %v", rpcAddr, err) + } + rpcListener, err := net.ListenTCP("tcp", rpcTCPAddr) if err != nil { - log.Fatalf("listen on RPC address tcp %q failed: %v", rpcAddr, err) + log.Fatalf("Listen on RPC address %q failed: %v", rpcAddr, err) + } + if err := sockopts.ApplyTCPControls(rpcListener); err != nil { + log.Fatalf("ApplyTCPControls() failed: %v", err) } grpcServer := grpc.NewServer(grpc.MaxRecvMsgSize(appctl.MaxRecvMsgSize)) appctl.SetClientRPCServerRef(grpcServer) @@ -504,16 +513,19 @@ var clientRunFunc = func(s []string) error { mtu = int(activeProfile.GetMtu()) } endpoints := make([]protocol.UnderlayProperties, 0) - resolver := &common.DNSResolver{} for _, serverInfo := range activeProfile.GetServers() { var proxyHost string var proxyIP net.IP if serverInfo.GetDomainName() != "" { proxyHost = serverInfo.GetDomainName() - proxyIP, err = resolver.LookupIP(context.Background(), proxyHost) + proxyIPs, err := resolver.LookupIP(context.Background(), "ip", proxyHost) if err != nil { return fmt.Errorf(stderror.LookupIPFailedErr, err) } + if len(proxyIPs) == 0 { + return fmt.Errorf(stderror.IPAddressNotFound, proxyHost) + } + proxyIP = proxyIPs[0] } else { proxyHost = serverInfo.GetIpAddress() proxyIP = net.ParseIP(proxyHost) @@ -556,6 +568,7 @@ var clientRunFunc = func(s []string) error { IngressCredentials: socks5IngressCredentials, }, ProxyMux: mux, + Resolver: resolver, HandshakeTimeout: 10 * time.Second, } socks5Server, err := socks5.New(socks5Config) @@ -573,14 +586,20 @@ var clientRunFunc = func(s []string) error { } wg.Add(1) go func(socks5Addr string) { - listenConfig := sockopts.ListenConfigWithControls() - l, err := listenConfig.Listen(context.Background(), "tcp", socks5Addr) + socks5TCPAddr, err := apicommon.ResolveTCPAddr(resolver, "tcp", socks5Addr) + if err != nil { + log.Fatalf("Resolve socks5 address %q failed: %v", socks5Addr, err) + } + socks5Listener, err := net.ListenTCP("tcp", socks5TCPAddr) if err != nil { - log.Fatalf("listen on socks5 address tcp %q failed: %v", socks5Addr, err) + log.Fatalf("Listen on socks5 address %q failed: %v", socks5Addr, err) + } + if err := sockopts.ApplyTCPControls(socks5Listener); err != nil { + log.Fatalf("ApplyTCPControls() failed: %v", err) } close(appctl.ClientSocks5ServerStarted) log.Infof("mieru client socks5 server is running") - if err = socks5Server.Serve(l); err != nil { + if err = socks5Server.Serve(socks5Listener); err != nil { log.Fatalf("run socks5 server failed: %v", err) } log.Infof("mieru client socks5 server is stopped") diff --git a/pkg/common/dns.go b/pkg/common/dns.go deleted file mode 100644 index 5b85fa0f..00000000 --- a/pkg/common/dns.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (C) 2023 mieru authors -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . - -package common - -import ( - "context" - "fmt" - "net" -) - -type DNSPolicy uint8 - -const ( - DNSPolicyDefault DNSPolicy = iota - DNSPolicyIPv4Only - DNSPolicyIPv6Only -) - -func (p DNSPolicy) String() string { - switch p { - case DNSPolicyDefault: - return "DEFAULT" - case DNSPolicyIPv4Only: - return "IPv4_ONLY" - case DNSPolicyIPv6Only: - return "IPv6_ONLY" - default: - return "UNSPECIFIED" - } -} - -// DNSResolver uses Golang's default DNS implementation to resolve host names. -type DNSResolver struct { - DNSPolicy DNSPolicy -} - -// LookupIP looks up host for the given network using the DNS resolver. -func (d *DNSResolver) LookupIP(ctx context.Context, host string) (net.IP, error) { - network := "ip" - switch d.DNSPolicy { - case DNSPolicyIPv4Only: - network = "ip4" - case DNSPolicyIPv6Only: - network = "ip6" - } - ips, err := net.DefaultResolver.LookupIP(ctx, network, host) - if err != nil { - return nil, err - } - if len(ips) == 0 { - return nil, fmt.Errorf("lookup IP from %s returned no result", host) - } - return ips[0], nil -} - -// ForbidDefaultResolver causes the process to panic if -// net.DefaultResolver is used. -func ForbidDefaultResolver() { - net.DefaultResolver.PreferGo = true - net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { - panic("Using net.DefaultResolver is forbidden") - } -} diff --git a/pkg/common/dns_test.go b/pkg/common/dns_test.go deleted file mode 100644 index b7c4a758..00000000 --- a/pkg/common/dns_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (C) 2023 mieru authors -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . - -package common - -import ( - "context" - "testing" -) - -func TestDNSResolver(t *testing.T) { - ctx := context.Background() - - // Default policy. - d := DNSResolver{} - addr, err := d.LookupIP(ctx, "localhost") - if err != nil { - t.Fatalf("LookupIP() failed: %v", err) - } - if !addr.IsLoopback() { - t.Errorf("Returned IP address is not loopback address") - } - - // IPv4 only. - d.DNSPolicy = DNSPolicyIPv4Only - addr, err = d.LookupIP(ctx, "localhost") - if err != nil { - t.Fatalf("LookupIP() failed: %v", err) - } - if !addr.IsLoopback() { - t.Errorf("Returned IP address is not loopback address") - } - - if IsIPDualStack() { - // IPv6 only. - d.DNSPolicy = DNSPolicyIPv6Only - addr, err := d.LookupIP(ctx, "localhost") - addr2, err2 := d.LookupIP(ctx, "ip6-localhost") - if err != nil && err2 != nil { - t.Fatalf("LookupIP() failed: %v; %v", err, err2) - } - if (addr != nil && !addr.IsLoopback()) || (addr2 != nil && !addr2.IsLoopback()) { - t.Errorf("Returned IP address is not loopback address") - } - } -} diff --git a/pkg/common/sockopts/control.go b/pkg/common/sockopts/control.go index 0063dd9c..c9eda673 100644 --- a/pkg/common/sockopts/control.go +++ b/pkg/common/sockopts/control.go @@ -22,7 +22,7 @@ import ( "syscall" ) -// Control is the Control function used by net.Dialer and net.ListenConfig. +// Control is the Control function used by net.Dialer. type Control = func(network, address string, c syscall.RawConn) error // RawControl is the Control function used by syscall.RawConn. @@ -31,31 +31,19 @@ type RawControl = func(fd uintptr) // RawControlErr returns an error with RawControl. type RawControlErr = func(fd uintptr) error -// Append returns a Control function that chains next after prev. -func Append(prev, next Control) Control { - if prev == nil { - return next - } else if next == nil { - return prev +// ApplyTCPControls applies all the recommended controls to the TCP listener. +func ApplyTCPControls(listener *net.TCPListener) error { + rawConn, err := listener.SyscallConn() + if err != nil { + return fmt.Errorf("SyscallConn() failed: %w", err) } - return func(network, address string, c syscall.RawConn) error { - if err := prev(network, address, c); err != nil { - return err - } - return next(network, address, c) + if err := rawConn.Control(ReuseAddrPortRaw()); err != nil { + return err } -} - -// ListenConfigWithControls returns a net.ListenConfig with -// all the recommended controls applied. -func ListenConfigWithControls() net.ListenConfig { - var protectPathControl Control if path, found := os.LookupEnv("MIERU_PROTECT_PATH"); found { - protectPathControl = ProtectPath(path) - } - return net.ListenConfig{ - Control: Append(ReuseAddrPort(), protectPathControl), + return rawConn.Control(ProtectPathRaw(path)) } + return nil } // ApplyUDPControls applies all the recommended controls to the UDP connection. diff --git a/pkg/congestion/bbr_sender.go b/pkg/congestion/bbr_sender.go index 9e7bf474..b27f0159 100644 --- a/pkg/congestion/bbr_sender.go +++ b/pkg/congestion/bbr_sender.go @@ -413,6 +413,9 @@ func (b *BBRSender) OnCongestionEvent(priorInFlight int64, eventTime time.Time, // OnApplicationLimited updates BBR sender state when there is no application // data to send. func (b *BBRSender) OnApplicationLimited(bytesInFlight int64) { + b.mu.Lock() + defer b.mu.Unlock() + if bytesInFlight >= b.getCongestionWindow() { return } @@ -425,6 +428,7 @@ func (b *BBRSender) OnApplicationLimited(bytesInFlight int64) { func (b *BBRSender) CanSend(bytesInFlight, bytes int64) bool { b.mu.Lock() defer b.mu.Unlock() + pacerCanSend := b.pacer.CanSend(time.Now(), bytes, b.getPacingRate()) return bytesInFlight < b.getCongestionWindow() && pacerCanSend } diff --git a/pkg/protocol/mux.go b/pkg/protocol/mux.go index 003af7dd..ad6af9fc 100644 --- a/pkg/protocol/mux.go +++ b/pkg/protocol/mux.go @@ -78,7 +78,7 @@ func NewMux(isClinet bool) *Mux { underlays: make([]Underlay, 0), chAccept: make(chan net.Conn, sessionChanCapacity), chAcceptErr: make(chan error, 1), // non-blocking - resolver: &net.Resolver{PreferGo: true}, + resolver: &net.Resolver{}, done: make(chan struct{}), cleaner: time.NewTicker(idleUnderlayTickerInterval), } @@ -426,10 +426,18 @@ func (m *Mux) acceptUnderlayLoop(ctx context.Context, properties UnderlayPropert network := properties.LocalAddr().Network() switch network { case "tcp", "tcp4", "tcp6": - listenConfig := sockopts.ListenConfigWithControls() - rawListener, err := listenConfig.Listen(ctx, network, laddr) + tcpAddr, err := apicommon.ResolveTCPAddr(m.resolver, "tcp", laddr) if err != nil { - m.chAcceptErr <- fmt.Errorf("Listen() failed: %w", err) + m.chAcceptErr <- fmt.Errorf("ResolveTCPAddr() failed: %w", err) + return + } + rawListener, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + m.chAcceptErr <- fmt.Errorf("ListenTCP() failed: %w", err) + return + } + if err := sockopts.ApplyTCPControls(rawListener); err != nil { + m.chAcceptErr <- fmt.Errorf("ApplyTCPControls() failed: %w", err) return } log.Infof("Mux is listening to endpoint %s %s", network, laddr) @@ -597,7 +605,7 @@ func (m *Mux) newUnderlay(ctx context.Context) (Underlay, error) { block.SetBlockContext(cipher.BlockContext{ UserName: m.username, }) - underlay, err = NewStreamUnderlay(ctx, p.RemoteAddr().Network(), "", p.RemoteAddr().String(), p.MTU(), block) + underlay, err = NewStreamUnderlay(ctx, p.RemoteAddr().Network(), "", p.RemoteAddr().String(), p.MTU(), block, m.resolver) if err != nil { return nil, fmt.Errorf("NewTCPUnderlay() failed: %v", err) } @@ -609,7 +617,7 @@ func (m *Mux) newUnderlay(ctx context.Context) (Underlay, error) { block.SetBlockContext(cipher.BlockContext{ UserName: m.username, }) - underlay, err = NewPacketUnderlay(ctx, p.RemoteAddr().Network(), "", p.RemoteAddr().String(), p.MTU(), block) + underlay, err = NewPacketUnderlay(ctx, p.RemoteAddr().Network(), "", p.RemoteAddr().String(), p.MTU(), block, m.resolver) if err != nil { return nil, fmt.Errorf("NewUDPUnderlay() failed: %v", err) } diff --git a/pkg/protocol/session.go b/pkg/protocol/session.go index 2ea3f2bf..ef80279f 100644 --- a/pkg/protocol/session.go +++ b/pkg/protocol/session.go @@ -87,12 +87,13 @@ type Session struct { status statusCode // session status users map[string]*appctlpb.User // all registered users - ready chan struct{} // indicate the session is ready to use - done chan struct{} // indicate the session is complete - readDeadline time.Time // read deadline - writeDeadline time.Time // write deadline - inputErr chan error // input error - outputErr chan error // output error + ready chan struct{} // indicate the session is ready to use + closeRequested atomic.Bool // the session is being closed or has been closed + closedChan chan struct{} // indicate the session is closed + readDeadline time.Time // read deadline + writeDeadline time.Time // write deadline + inputErr chan error // input error + outputErr chan error // output error sendQueue *segmentTree // segments waiting to send sendBuf *segmentTree // segments sent but not acknowledged @@ -102,6 +103,7 @@ type Session struct { nextSend uint32 // next sequence number to send a segment nextRecv uint32 // next sequence number to receive + lastSend uint32 // last segment sequence number sent lastRXTime time.Time // last timestamp when a segment is received lastTXTime time.Time // last timestamp when a segment is sent ackOnDataRecv atomic.Bool // whether ack should be sent due to receive of new data @@ -116,10 +118,9 @@ type Session struct { remoteWindowSize uint16 wg sync.WaitGroup - rLock sync.Mutex - wLock sync.Mutex - cLock sync.Mutex - sLock sync.Mutex + rLock sync.Mutex // serialize read from application + oLock sync.Mutex // serialize the output sequence + sLock sync.Mutex // serialize the state transition } // Session must implement net.Conn interface. @@ -140,7 +141,7 @@ func NewSession(id uint32, isClient bool, mtu int, users map[string]*appctlpb.Us status: statusOK, users: users, ready: make(chan struct{}), - done: make(chan struct{}), + closedChan: make(chan struct{}), readDeadline: time.Time{}, writeDeadline: time.Time{}, inputErr: make(chan error, 2), // allow nested @@ -167,15 +168,10 @@ func (s *Session) String() string { } // Read lets a user to read data from receive queue. +// Read is allowed even after the session has been closed. func (s *Session) Read(b []byte) (n int, err error) { s.rLock.Lock() defer s.rLock.Unlock() - if s.isStateBefore(sessionAttached, false) { - return 0, fmt.Errorf("%v is not ready for Read()", s) - } - if s.isStateAfter(sessionClosed, true) { - return 0, io.ErrClosedPipe - } defer func() { s.readDeadline = time.Time{} }() @@ -228,7 +224,7 @@ func (s *Session) Read(b []byte) (n int, err error) { } else { // Wait for incoming segments. select { - case <-s.done: + case <-s.closedChan: return 0, io.EOF case <-s.inputErr: return 0, io.ErrUnexpectedEOF @@ -257,8 +253,9 @@ func (s *Session) Read(b []byte) (n int, err error) { // Write stores the data to send queue. func (s *Session) Write(b []byte) (n int, err error) { - s.wLock.Lock() - defer s.wLock.Unlock() + if s.closeRequested.Load() { + return 0, io.ErrClosedPipe + } if s.isStateBefore(sessionAttached, false) { return 0, fmt.Errorf("%v is not ready for Write()", s) @@ -272,6 +269,7 @@ func (s *Session) Write(b []byte) (n int, err error) { if s.isClient && s.isState(sessionAttached) { // Before the first write, client needs to send open session request. + s.oLock.Lock() seg := &segment{ metadata: &sessionStruct{ baseStruct: baseStruct{ @@ -292,6 +290,7 @@ func (s *Session) Write(b []byte) (n int, err error) { log.Tracef("%v writing %d bytes with open session request", s, len(seg.payload)) } s.sendQueue.InsertBlocking(seg) + s.oLock.Unlock() if len(seg.payload) > 0 { return len(seg.payload), nil } @@ -319,60 +318,7 @@ func (s *Session) Write(b []byte) (n int, err error) { // Close terminates the session. func (s *Session) Close() error { - s.cLock.Lock() - defer s.cLock.Unlock() - select { - case <-s.done: - s.forwardStateTo(sessionClosed) - return nil - default: - } - - log.Debugf("Closing %v", s) - s.sendQueue.DeleteAll() - s.sendBuf.DeleteAll() - s.recvBuf.DeleteAll() - s.recvQueue.DeleteAll() - if s.isState(sessionAttached) || s.isState(sessionEstablished) { - // Send closeSessionRequest, but don't wait for closeSessionResponse, - // because the underlay connection may be already broken. - // The closeSessionRequest won't be sent again. - seg := &segment{ - metadata: &sessionStruct{ - baseStruct: baseStruct{ - protocol: uint8(closeSessionRequest), - }, - sessionID: s.id, - seq: s.nextSend, - statusCode: uint8(s.status), - }, - transport: s.conn.TransportProtocol(), - } - s.nextSend++ - switch s.conn.TransportProtocol() { - case common.StreamTransport: - s.sendQueue.InsertBlocking(seg) - case common.PacketTransport: - if err := s.output(seg, s.RemoteAddr()); err != nil { - log.Debugf("output() failed: %v", err) - } - default: - log.Debugf("Unsupported transport protocol %v", s.conn.TransportProtocol()) - } - } - - // Wait for sendQueue to flush. - timeC := time.After(time.Second) - select { - case <-timeC: - case <-s.sendQueue.chanEmptyEvent: - } - - s.forwardStateTo(sessionClosed) - close(s.done) - log.Debugf("Closed %v", s) - metrics.CurrEstablished.Add(-1) - return nil + return s.closeWithError(nil) } func (s *Session) LocalAddr() net.Addr { @@ -486,14 +432,18 @@ func (s *Session) writeChunk(b []byte) (n int, err error) { nFragment = (len(b)-1)/fragmentSize + 1 } + s.oLock.Lock() ptr := b for i := nFragment - 1; i >= 0; i-- { select { - case <-s.done: + case <-s.closedChan: + s.oLock.Unlock() return 0, io.EOF case <-s.outputErr: + s.oLock.Unlock() return 0, io.ErrClosedPipe case <-timeC: + s.oLock.Unlock() return 0, stderror.ErrTimeout default: } @@ -525,6 +475,7 @@ func (s *Session) writeChunk(b []byte) (n int, err error) { s.sendQueue.InsertBlocking(seg) ptr = ptr[partLen:] } + s.oLock.Unlock() // To create back pressure, wait until sendQueue is moving. for { @@ -554,14 +505,14 @@ func (s *Session) runInputLoop(ctx context.Context) error { select { case <-ctx.Done(): return nil - case <-s.done: + case <-s.closedChan: return nil case seg := <-s.recvChan: if err := s.input(seg); err != nil { err = fmt.Errorf("input() failed: %w", err) log.Debugf("%v %v", s, err) s.inputErr <- err - s.Close() + s.closeWithError(err) return err } } @@ -575,7 +526,7 @@ func (s *Session) runOutputLoop(ctx context.Context) error { select { case <-ctx.Done(): return nil - case <-s.done: + case <-s.closedChan: return nil case <-ticker.C: case <-s.sendQueue.chanNotEmptyEvent: @@ -583,154 +534,178 @@ func (s *Session) runOutputLoop(ctx context.Context) error { switch s.conn.TransportProtocol() { case common.StreamTransport: - for { - seg, ok := s.sendQueue.DeleteMin() - if !ok { - break - } - if err := s.output(seg, nil); err != nil { - err = fmt.Errorf("output() failed: %w", err) - log.Debugf("%v %v", s, err) - s.outputErr <- err - s.Close() - break - } - } + s.runOutputOnceStream() case common.PacketTransport: - closeSession := false - hasLoss := false - hasTimeout := false - - // Resend segments in sendBuf. - // To avoid deadlock, session can't be closed inside Ascend(). - var bytesInFlight int64 - s.sendBuf.Ascend(func(iter *segment) bool { - bytesInFlight += int64(packetOverhead + len(iter.payload)) - if iter.txCount >= txCountLimit { - err := fmt.Errorf("too many retransmission of %v", iter) - log.Debugf("%v is unhealthy: %v", s, err) - s.outputErr <- err - closeSession = true - return false - } - if (iter.ackCount >= earlyRetransmission && iter.txCount <= earlyRetransmissionLimit) || time.Since(iter.txTime) > iter.txTimeout { - if iter.ackCount >= earlyRetransmission { - hasLoss = true - } else { - hasTimeout = true - } - iter.ackCount = 0 - iter.txCount++ - iter.txTime = time.Now() - iter.txTimeout = s.rttStat.RTO() * time.Duration(mathext.Min(math.Pow(txTimeoutBackOff, float64(iter.txCount)), maxBackOffMultiplier)) - if isDataAckProtocol(iter.metadata.Protocol()) { - das, _ := toDataAckStruct(iter.metadata) - das.unAckSeq = s.nextRecv - } - if err := s.output(iter, s.RemoteAddr()); err != nil { - err = fmt.Errorf("output() failed: %w", err) - log.Debugf("%v %v", s, err) - s.outputErr <- err - closeSession = true - return false - } - bytesInFlight += int64(packetOverhead + len(iter.payload)) - return true - } - return true - }) - if closeSession { - s.Close() + s.runOutputOncePacket() + default: + err := fmt.Errorf("unsupported transport protocol %v", s.conn.TransportProtocol()) + log.Debugf("%v %v", s, err) + s.outputErr <- err + s.closeWithError(err) + } + } +} + +func (s *Session) runOutputOnceStream() { + s.oLock.Lock() + defer s.oLock.Unlock() + + for { + seg, ok := s.sendQueue.DeleteMin() + if !ok { + break + } + if err := s.output(seg, nil); err != nil { + err = fmt.Errorf("output() failed: %w", err) + log.Debugf("%v %v", s, err) + s.outputErr <- err + s.closeWithError(err) + break + } + } +} + +func (s *Session) runOutputOncePacket() { + var closeSessionReason error + hasLoss := false + hasTimeout := false + var bytesInFlight int64 + + // Resend segments in sendBuf. + // To avoid deadlock, session can't be closed inside Ascend(). + s.oLock.Lock() + s.sendBuf.Ascend(func(iter *segment) bool { + bytesInFlight += int64(packetOverhead + len(iter.payload)) + if iter.txCount >= txCountLimit { + err := fmt.Errorf("too many retransmission of %v", iter) + log.Debugf("%v is unhealthy: %v", s, err) + s.outputErr <- err + closeSessionReason = err + return false + } + if (iter.ackCount >= earlyRetransmission && iter.txCount <= earlyRetransmissionLimit) || time.Since(iter.txTime) > iter.txTimeout { + if iter.ackCount >= earlyRetransmission { + hasLoss = true + } else { + hasTimeout = true + } + iter.ackCount = 0 + iter.txCount++ + iter.txTime = time.Now() + iter.txTimeout = s.rttStat.RTO() * time.Duration(mathext.Min(math.Pow(txTimeoutBackOff, float64(iter.txCount)), maxBackOffMultiplier)) + if isDataAckProtocol(iter.metadata.Protocol()) { + das, _ := toDataAckStruct(iter.metadata) + das.unAckSeq = s.nextRecv + } + if err := s.output(iter, s.RemoteAddr()); err != nil { + err = fmt.Errorf("output() failed: %w", err) + log.Debugf("%v %v", s, err) + s.outputErr <- err + closeSessionReason = err + return false } - if hasLoss || hasTimeout { - s.legacysendAlgorithm.OnLoss() // OnTimeout() is too aggressive. + bytesInFlight += int64(packetOverhead + len(iter.payload)) + return true + } + return true + }) + s.oLock.Unlock() + if closeSessionReason != nil { + s.closeWithError(closeSessionReason) + } + if hasLoss || hasTimeout { + s.legacysendAlgorithm.OnLoss() // OnTimeout() is too aggressive. + } + + // Send new segments in sendQueue. + if s.sendQueue.Len() > 0 { + s.oLock.Lock() + for { + seg, deleted := s.sendQueue.DeleteMinIf(func(iter *segment) bool { + return s.sendAlgorithm.CanSend(bytesInFlight, int64(packetOverhead+len(iter.payload))) + }) + if !deleted { + s.oLock.Unlock() + break } - // Send new segments in sendQueue. - if s.sendQueue.Len() > 0 { - for { - seg, deleted := s.sendQueue.DeleteMinIf(func(iter *segment) bool { - return s.sendAlgorithm.CanSend(bytesInFlight, int64(packetOverhead+len(iter.payload))) - }) - if !deleted { - break - } - seg.txCount++ - seg.txTime = time.Now() - seg.txTimeout = s.rttStat.RTO() * time.Duration(mathext.Min(math.Pow(txTimeoutBackOff, float64(seg.txCount)), maxBackOffMultiplier)) - if isDataAckProtocol(seg.metadata.Protocol()) { - das, _ := toDataAckStruct(seg.metadata) - das.unAckSeq = s.nextRecv - } - s.sendBuf.InsertBlocking(seg) - if err := s.output(seg, s.RemoteAddr()); err != nil { - err = fmt.Errorf("output() failed: %w", err) - log.Debugf("%v %v", s, err) - s.outputErr <- err - s.Close() - break - } else { - seq, err := seg.Seq() - if err != nil { - err = fmt.Errorf("failed to get sequence number from %v: %w", seg, err) - log.Debugf("%v %v", s, err) - s.outputErr <- err - s.Close() - break - } - newBytesInFlight := int64(packetOverhead + len(seg.payload)) - s.sendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true) - bytesInFlight += newBytesInFlight - } - } - } else { - s.sendAlgorithm.OnApplicationLimited(bytesInFlight) + seg.txCount++ + seg.txTime = time.Now() + seg.txTimeout = s.rttStat.RTO() * time.Duration(mathext.Min(math.Pow(txTimeoutBackOff, float64(seg.txCount)), maxBackOffMultiplier)) + if isDataAckProtocol(seg.metadata.Protocol()) { + das, _ := toDataAckStruct(seg.metadata) + das.unAckSeq = s.nextRecv } + s.sendBuf.InsertBlocking(seg) - // Send ACK or heartbeat if needed. - exceedHeartbeatInterval := time.Since(s.lastTXTime) > sessionHeartbeatInterval - if s.ackOnDataRecv.Load() || exceedHeartbeatInterval { - baseStruct := baseStruct{} - if s.isClient { - baseStruct.protocol = uint8(ackClientToServer) - } else { - baseStruct.protocol = uint8(ackServerToClient) - } - ackSeg := &segment{ - metadata: &dataAckStruct{ - baseStruct: baseStruct, - sessionID: s.id, - seq: uint32(mathext.Max(0, int(s.nextSend)-1)), - unAckSeq: s.nextRecv, - windowSize: uint16(mathext.Max(0, int(s.legacysendAlgorithm.CongestionWindowSize())-s.recvBuf.Len())), - }, - transport: s.conn.TransportProtocol(), - } - if err := s.output(ackSeg, s.RemoteAddr()); err != nil { - err = fmt.Errorf("output() failed: %w", err) + if err := s.output(seg, s.RemoteAddr()); err != nil { + s.oLock.Unlock() + err = fmt.Errorf("output() failed: %w", err) + log.Debugf("%v %v", s, err) + s.outputErr <- err + s.closeWithError(err) + break + } else { + seq, err := seg.Seq() + if err != nil { + s.oLock.Unlock() + err = fmt.Errorf("failed to get sequence number from %v: %w", seg, err) log.Debugf("%v %v", s, err) s.outputErr <- err - s.Close() - } else { - seq, err := ackSeg.Seq() - if err != nil { - err = fmt.Errorf("failed to get sequence number from %v: %w", ackSeg, err) - log.Debugf("%v %v", s, err) - s.outputErr <- err - s.Close() - } - newBytesInFlight := int64(packetOverhead + len(ackSeg.payload)) - s.sendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true) - bytesInFlight += newBytesInFlight + s.closeWithError(err) + break } - s.ackOnDataRecv.Store(false) + newBytesInFlight := int64(packetOverhead + len(seg.payload)) + s.sendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true) + bytesInFlight += newBytesInFlight } - default: - err := fmt.Errorf("unsupported transport protocol %v", s.conn.TransportProtocol()) + } + } else { + s.sendAlgorithm.OnApplicationLimited(bytesInFlight) + } + + // Send ACK or heartbeat if needed. + exceedHeartbeatInterval := time.Since(s.lastTXTime) > sessionHeartbeatInterval + if s.ackOnDataRecv.Load() || exceedHeartbeatInterval { + baseStruct := baseStruct{} + if s.isClient { + baseStruct.protocol = uint8(ackClientToServer) + } else { + baseStruct.protocol = uint8(ackServerToClient) + } + s.oLock.Lock() + ackSeg := &segment{ + metadata: &dataAckStruct{ + baseStruct: baseStruct, + sessionID: s.id, + seq: uint32(mathext.Max(0, int(s.nextSend)-1)), + unAckSeq: s.nextRecv, + windowSize: uint16(mathext.Max(0, int(s.legacysendAlgorithm.CongestionWindowSize())-s.recvBuf.Len())), + }, + transport: s.conn.TransportProtocol(), + } + if err := s.output(ackSeg, s.RemoteAddr()); err != nil { + s.oLock.Unlock() + err = fmt.Errorf("output() failed: %w", err) log.Debugf("%v %v", s, err) s.outputErr <- err - s.Close() + s.closeWithError(err) + } else { + seq, err := ackSeg.Seq() + if err != nil { + s.oLock.Unlock() + err = fmt.Errorf("failed to get sequence number from %v: %w", ackSeg, err) + log.Debugf("%v %v", s, err) + s.outputErr <- err + s.closeWithError(err) + } else { + s.oLock.Unlock() + newBytesInFlight := int64(packetOverhead + len(ackSeg.payload)) + s.sendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true) + bytesInFlight += newBytesInFlight + } } + s.ackOnDataRecv.Store(false) } } @@ -855,10 +830,10 @@ func (s *Session) inputData(seg *segment) error { } if !s.isClient && seg.metadata.Protocol() == openSessionRequest { - s.wLock.Lock() if s.isState(sessionAttached) { // Server needs to send open session response. // Check user quota if we can identify the user. + s.oLock.Lock() var userName string if seg.block != nil && seg.block.BlockContext().UserName != "" { userName = seg.block.BlockContext().UserName @@ -873,7 +848,7 @@ func (s *Session) inputData(seg *segment) error { if !quotaOK { s.status = statusQuotaExhausted log.Debugf("Closing %v because user %s used all the quota", s, userName) - s.wLock.Unlock() + s.oLock.Unlock() s.Close() return nil } @@ -894,8 +869,8 @@ func (s *Session) inputData(seg *segment) error { } s.sendQueue.InsertBlocking(seg4) s.forwardStateTo(sessionEstablished) + s.oLock.Unlock() } - s.wLock.Unlock() } return nil } @@ -955,7 +930,7 @@ func (s *Session) inputAck(seg *segment) error { } func (s *Session) inputClose(seg *segment) error { - s.wLock.Lock() + s.oLock.Lock() if seg.metadata.Protocol() == closeSessionRequest { // Send close session response. seg2 := &segment{ @@ -973,7 +948,7 @@ func (s *Session) inputClose(seg *segment) error { s.nextSend++ // The response will not retry if it is not delivered. if err := s.output(seg2, s.RemoteAddr()); err != nil { - s.wLock.Unlock() + s.oLock.Unlock() return fmt.Errorf("output() failed: %v", err) } // Immediately shutdown event loop. @@ -982,12 +957,12 @@ func (s *Session) inputClose(seg *segment) error { } else { log.Debugf("Remote requested to shut down %v", s) } - s.wLock.Unlock() + s.oLock.Unlock() s.Close() } else if seg.metadata.Protocol() == closeSessionResponse { // Immediately shutdown event loop. log.Debugf("Remote received the request from %v to shut down", s) - s.wLock.Unlock() + s.oLock.Unlock() s.Close() } return nil @@ -1013,10 +988,76 @@ func (s *Session) output(seg *segment, remoteAddr net.Addr) error { default: return fmt.Errorf("unsupported transport protocol %v", s.conn.TransportProtocol()) } + s.lastSend, _ = seg.Seq() s.lastTXTime = time.Now() return nil } +func (s *Session) closeWithError(err error) error { + if !s.closeRequested.CompareAndSwap(false, true) { + // This function has been called before. + return nil + } + + var gracefulClose bool + if err == nil { + log.Debugf("Closing %v", s) + gracefulClose = true + } else { + log.Debugf("Closing %v with error %v", s, err) + } + if s.isState(sessionAttached) || s.isState(sessionEstablished) { + // Send closeSessionRequest, but don't wait for closeSessionResponse, + // because the underlay connection may be already broken. + s.oLock.Lock() + closeRequestSeq := s.nextSend + seg := &segment{ + metadata: &sessionStruct{ + baseStruct: baseStruct{ + protocol: uint8(closeSessionRequest), + }, + sessionID: s.id, + seq: closeRequestSeq, + statusCode: uint8(s.status), + }, + transport: s.conn.TransportProtocol(), + } + s.nextSend++ + + var gracefulCloseSuccess bool + if gracefulClose { + s.sendQueue.InsertBlocking(seg) + s.oLock.Unlock() + for i := 0; i < 1000; i++ { + time.Sleep(time.Millisecond) + if s.lastSend >= closeRequestSeq { + gracefulCloseSuccess = true + break + } + } + } else { + s.oLock.Unlock() + } + if !gracefulCloseSuccess { + s.oLock.Lock() + if err := s.output(seg, s.RemoteAddr()); err != nil { + log.Debugf("output() failed: %v", err) + } + s.oLock.Unlock() + } + } + + // Don't clear receive buf and queue, because read is allowed after + // the session is closed. + s.sendQueue.DeleteAll() + s.sendBuf.DeleteAll() + s.forwardStateTo(sessionClosed) + close(s.closedChan) + log.Debugf("Closed %v", s) + metrics.CurrEstablished.Add(-1) + return nil +} + func (s *Session) checkQuota(userName string) (ok bool, err error) { if len(s.users) == 0 { return true, fmt.Errorf("no registered user") diff --git a/pkg/protocol/underlay_packet.go b/pkg/protocol/underlay_packet.go index c6d81187..b244748b 100644 --- a/pkg/protocol/underlay_packet.go +++ b/pkg/protocol/underlay_packet.go @@ -23,6 +23,7 @@ import ( "net" "time" + apicommon "github.com/enfein/mieru/v3/apis/common" "github.com/enfein/mieru/v3/pkg/appctl/appctlpb" "github.com/enfein/mieru/v3/pkg/cipher" "github.com/enfein/mieru/v3/pkg/common" @@ -67,7 +68,7 @@ var _ Underlay = &PacketUnderlay{} // "block" is the block encryption algorithm to encrypt packets. // // This function is only used by proxy client. -func NewPacketUnderlay(ctx context.Context, network, laddr, raddr string, mtu int, block cipher.BlockCipher) (*PacketUnderlay, error) { +func NewPacketUnderlay(ctx context.Context, network, laddr, raddr string, mtu int, block cipher.BlockCipher, resolver apicommon.DNSResolver) (*PacketUnderlay, error) { switch network { case "udp", "udp4", "udp6": default: @@ -79,14 +80,14 @@ func NewPacketUnderlay(ctx context.Context, network, laddr, raddr string, mtu in var localAddr *net.UDPAddr var err error if laddr != "" { - localAddr, err = net.ResolveUDPAddr("udp", laddr) + localAddr, err = apicommon.ResolveUDPAddr(resolver, "udp", laddr) if err != nil { - return nil, fmt.Errorf("net.ResolveUDPAddr() failed: %w", err) + return nil, fmt.Errorf("ResolveUDPAddr() failed: %w", err) } } - remoteAddr, err := net.ResolveUDPAddr("udp", raddr) + remoteAddr, err := apicommon.ResolveUDPAddr(resolver, "udp", raddr) if err != nil { - return nil, fmt.Errorf("net.ResolveUDPAddr() failed: %w", err) + return nil, fmt.Errorf("ResolveUDPAddr() failed: %w", err) } conn, err := net.ListenUDP(network, localAddr) @@ -189,7 +190,7 @@ func (u *PacketUnderlay) RunEventLoop(ctx context.Context) error { u.sessionMap.Range(func(k, v any) bool { session := v.(*Session) select { - case <-session.done: + case <-session.closedChan: log.Debugf("Found closed %v", session) if err := u.RemoveSession(session); err != nil { log.Debugf("%v RemoveSession() failed: %v", u, err) diff --git a/pkg/protocol/underlay_stream.go b/pkg/protocol/underlay_stream.go index 4e553ece..ba463d73 100644 --- a/pkg/protocol/underlay_stream.go +++ b/pkg/protocol/underlay_stream.go @@ -22,6 +22,7 @@ import ( "net" "time" + apicommon "github.com/enfein/mieru/v3/apis/common" "github.com/enfein/mieru/v3/pkg/appctl/appctlpb" "github.com/enfein/mieru/v3/pkg/cipher" "github.com/enfein/mieru/v3/pkg/common" @@ -61,7 +62,7 @@ var _ Underlay = &StreamUnderlay{} // "block" is the block encryption algorithm to encrypt packets. // // This function is only used by proxy client. -func NewStreamUnderlay(ctx context.Context, network, laddr, raddr string, mtu int, block cipher.BlockCipher) (*StreamUnderlay, error) { +func NewStreamUnderlay(ctx context.Context, network, laddr, raddr string, mtu int, block cipher.BlockCipher, resolver apicommon.DNSResolver) (*StreamUnderlay, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -74,9 +75,9 @@ func NewStreamUnderlay(ctx context.Context, network, laddr, raddr string, mtu in Control: sockopts.ReuseAddrPort(), } if laddr != "" { - tcpLocalAddr, err := net.ResolveTCPAddr(network, laddr) + tcpLocalAddr, err := apicommon.ResolveTCPAddr(resolver, network, laddr) if err != nil { - return nil, fmt.Errorf("net.ResolveTCPAddr() failed: %w", err) + return nil, fmt.Errorf("ResolveTCPAddr() failed: %w", err) } dialer.LocalAddr = tcpLocalAddr } diff --git a/pkg/socks5/request.go b/pkg/socks5/request.go index 0d65fa7f..d70e0a07 100644 --- a/pkg/socks5/request.go +++ b/pkg/socks5/request.go @@ -11,6 +11,7 @@ import ( "sync" "sync/atomic" + apicommon "github.com/enfein/mieru/v3/apis/common" "github.com/enfein/mieru/v3/apis/constant" "github.com/enfein/mieru/v3/apis/model" "github.com/enfein/mieru/v3/pkg/common" @@ -83,15 +84,19 @@ func (s *Server) handleRequest(ctx context.Context, req *Request, conn io.ReadWr // Resolve the address if we have a FQDN. dst := req.DstAddr if dst.FQDN != "" { - addr, err := s.config.Resolver.LookupIP(ctx, dst.FQDN) - if err != nil { + ips, err := s.config.Resolver.LookupIP(ctx, "ip", dst.FQDN) + if err != nil || len(ips) == 0 { DNSResolveErrors.Add(1) if err := sendReply(conn, hostUnreachable, nil); err != nil { return fmt.Errorf("failed to send reply: %w", err) } - return fmt.Errorf("failed to resolve destination %q: %w", dst.FQDN, err) + if err != nil { + return fmt.Errorf("failed to resolve destination %q: %w", dst.FQDN, err) + } else { + return fmt.Errorf(stderror.IPAddressNotFound, dst.FQDN) + } } - dst.IP = addr + dst.IP = ips[0] } // Return error if access local destination is not allowed. @@ -165,7 +170,7 @@ func (s *Server) handleBind(_ context.Context, _ *Request, conn io.ReadWriteClos func (s *Server) handleAssociate(_ context.Context, _ *Request, conn io.ReadWriteCloser) error { // Create a UDP listener on a random port. // All the requests associated to this connection will go through this port. - udpListenerAddr, err := net.ResolveUDPAddr("udp", common.MaybeDecorateIPv6(common.AllIPAddr())+":0") + udpListenerAddr, err := apicommon.ResolveUDPAddr(s.config.Resolver, "udp", common.MaybeDecorateIPv6(common.AllIPAddr())+":0") if err != nil { UDPAssociateErrors.Add(1) return fmt.Errorf("failed to resolve UDP address: %w", err) @@ -267,7 +272,7 @@ func (s *Server) handleAssociate(_ context.Context, _ *Request, conn io.ReadWrit case 0x03: fqdnLen := buf[4] fqdn := string(buf[5 : 5+fqdnLen]) - dstAddr, err := net.ResolveUDPAddr("udp", fqdn+":"+strconv.Itoa(int(buf[5+fqdnLen])<<8+int(buf[6+fqdnLen]))) + dstAddr, err := apicommon.ResolveUDPAddr(s.config.Resolver, "udp", fqdn+":"+strconv.Itoa(int(buf[5+fqdnLen])<<8+int(buf[6+fqdnLen]))) if err != nil { log.Debugf("UDP associate %v ResolveUDPAddr() failed: %v", udpConn.LocalAddr(), err) UDPAssociateErrors.Add(1) diff --git a/pkg/socks5/socks5.go b/pkg/socks5/socks5.go index d140f731..1661c5b0 100644 --- a/pkg/socks5/socks5.go +++ b/pkg/socks5/socks5.go @@ -8,6 +8,7 @@ import ( "strconv" "time" + apicommon "github.com/enfein/mieru/v3/apis/common" "github.com/enfein/mieru/v3/apis/model" "github.com/enfein/mieru/v3/pkg/appctl/appctlpb" "github.com/enfein/mieru/v3/pkg/common" @@ -42,7 +43,7 @@ type Config struct { EgressController egress.Controller // Resolver can be provided to do custom name resolution. - Resolver *common.DNSResolver + Resolver apicommon.DNSResolver // BindIP is used for bind or udp associate BindIP net.IP @@ -84,7 +85,7 @@ func New(conf *Config) (*Server, error) { // Ensure we have a DNS resolver. if conf.Resolver == nil { - conf.Resolver = &common.DNSResolver{} + conf.Resolver = &net.Resolver{} } // Provide a default bind IP. diff --git a/pkg/stderror/template.go b/pkg/stderror/template.go index 9eb22c2f..c1d2b5f2 100644 --- a/pkg/stderror/template.go +++ b/pkg/stderror/template.go @@ -38,6 +38,7 @@ const ( GetThreadDumpFailedErr = "get thread dump failed: %w" InvalidPortBindingsErr = "invalid port bindings: %w" InvalidTransportProtocol = "invalid transport protocol" + IPAddressNotFound = "IP address not found from domain name %q" LookupIPFailedErr = "look up IP address failed: %w" ParseIPFailed = "parse IP address failed" ReloadServerFailedErr = "reload mita server failed: %w" diff --git a/pkg/version/current.go b/pkg/version/current.go index 23fca8a0..89b43cee 100644 --- a/pkg/version/current.go +++ b/pkg/version/current.go @@ -16,5 +16,5 @@ package version const ( - AppVersion = "3.7.0" + AppVersion = "3.8.0" )