From dec824102c6022cd67e3191a155363dfc2c139c4 Mon Sep 17 00:00:00 2001 From: zhiyi Date: Tue, 9 Apr 2024 16:10:06 +0800 Subject: [PATCH] perf: handle specified tcp relay directly --- bin/src/main.rs | 1 - libcs/client/conn.go | 56 ++++++++++++++++++++++-------------------- libcs/server/client.go | 2 +- libcs/server/conn.go | 26 ++++++++++++++++++++ 4 files changed, 57 insertions(+), 28 deletions(-) diff --git a/bin/src/main.rs b/bin/src/main.rs index 72b850ad..8f7f9fa6 100644 --- a/bin/src/main.rs +++ b/bin/src/main.rs @@ -14,7 +14,6 @@ * limitations under the License. */ -use std::env; use std::path::PathBuf; use clap::Parser; diff --git a/libcs/client/conn.go b/libcs/client/conn.go index 255b3890..41578b5e 100644 --- a/libcs/client/conn.go +++ b/libcs/client/conn.go @@ -315,7 +315,7 @@ func (c *conn) readLoop(connID uint) { if rErr != nil { err = wErr if !errors.Is(rErr, net.ErrClosed) { - c.Logger.Warn().Err(rErr).Msg("failed to read data in processData") + c.Logger.Warn().Err(rErr).Msg("failed to read data in processServiceData") } return } @@ -327,7 +327,7 @@ func (c *conn) readLoop(connID uint) { } if wErr != nil { if !errors.Is(wErr, net.ErrClosed) { - c.Logger.Warn().Err(wErr).Msg("failed to write data in processData") + c.Logger.Warn().Err(wErr).Msg("failed to write data in processServiceData") } continue } @@ -387,22 +387,24 @@ func (c *conn) dial(s *service) (task *httpTask, err error) { } func (c *conn) processServiceData(connID uint, taskID uint32, s *service, r *bufio.LimitedReader) (readErr, writeErr error) { - var peekBytes []byte - peekBytes, readErr = r.Peek(2) - if readErr != nil { - return - } - // first 2 bytes of p2p sdp request is "XP"(0x5850) - isP2P := (uint16(peekBytes[1]) | uint16(peekBytes[0])<<8) == 0x5850 - if isP2P { - if len(c.stuns) < 1 { - respAndClose(taskID, c, [][]byte{ - []byte("HTTP/1.1 403 Forbidden\r\nConnection: Closed\r\n\r\n"), - }) + if r.N > 0 { + var peekBytes []byte + peekBytes, readErr = r.Peek(2) + if readErr != nil { + return + } + // first 2 bytes of p2p sdp request is "XP"(0x5850) + isP2P := (uint16(peekBytes[1]) | uint16(peekBytes[0])<<8) == 0x5850 + if isP2P { + if len(c.stuns) < 1 { + respAndClose(taskID, c, [][]byte{ + []byte("HTTP/1.1 403 Forbidden\r\nConnection: Closed\r\n\r\n"), + }) + return + } + c.processP2P(taskID, r) return } - c.processP2P(taskID, r) - return } var task *httpTask @@ -429,18 +431,20 @@ func (c *conn) processServiceData(connID uint, taskID uint32, s *service, r *buf c.tasksRWMtx.Unlock() go task.process(connID, taskID, c) - _, err := r.WriteTo(task) - if err != nil { - switch e := err.(type) { - case *net.OpError: - switch e.Op { - case "write": + if r.N > 0 { + _, err := r.WriteTo(task) + if err != nil { + switch e := err.(type) { + case *net.OpError: + switch e.Op { + case "write": + writeErr = err + } + case *bufio.WriteErr: writeErr = err + default: + readErr = err } - case *bufio.WriteErr: - writeErr = err - default: - readErr = err } } if task.service.LocalTimeout.Duration > 0 { diff --git a/libcs/server/client.go b/libcs/server/client.go index 9d2dc0a5..424051d2 100644 --- a/libcs/server/client.go +++ b/libcs/server/client.go @@ -448,7 +448,7 @@ func (c *client) openSpecifiedTCPPort(serviceIndex uint16, l *tcpListener, tcpPo }() tunnel.Logger.Info().Uint16("serviceIndex", serviceIndex).Uint16("tcpPort", tcpPort).Msg("tcp forward start") conn.serviceIndex = serviceIndex - conn.handle(func() { + conn.handleTCP(func() { err = c.process(conn) if err != nil { conn.Logger.Error().Err(err).Msg("tcp handle") diff --git a/libcs/server/conn.go b/libcs/server/conn.go index d17198a9..eabefca6 100644 --- a/libcs/server/conn.go +++ b/libcs/server/conn.go @@ -172,6 +172,32 @@ func (c *conn) handle(handleFunc func()) { handleFunc() } +func (c *conn) handleTCP(handleFunc func()) { + startTime := time.Now() + reader := pool.GetReader(c.Conn) + c.Reader = reader + defer func() { + c.Close() + pool.PutReader(reader) + endTime := time.Now() + if !predef.Debug { + if e := recover(); e != nil { + c.Logger.Error().Msgf("recovered panic: %#v\n%s", e, debug.Stack()) + } + } + c.Logger.Info().Dur("cost", endTime.Sub(startTime)).Msg("closed") + }() + if c.server.config.Timeout.Duration > 0 { + dl := startTime.Add(c.server.config.Timeout.Duration) + err := c.SetReadDeadline(dl) + if err != nil { + c.Logger.Debug().Err(err).Msg("handle set deadline failed") + return + } + } + handleFunc() +} + func (c *conn) handleProbe(r *bufio.Reader) { op, err := r.ReadByte() if err != nil {