From a5778ade4622940eb600d3d505f5c35bbef8908f Mon Sep 17 00:00:00 2001 From: nange Date: Fri, 1 Nov 2024 18:55:15 +0800 Subject: [PATCH] fix: memory leak in the HTTP tunnel module under extreme scenarios --- httptunnel/server.go | 99 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 84 insertions(+), 15 deletions(-) diff --git a/httptunnel/server.go b/httptunnel/server.go index 58fe76c5..748a6787 100644 --- a/httptunnel/server.go +++ b/httptunnel/server.go @@ -21,13 +21,19 @@ import ( ) const RelayBufferSize = cipherstream.MaxCipherRelaySize +const DefaultConnCount = 256 type Server struct { addr string timeout time.Duration - sync.Mutex - connMap map[string][]net.Conn + sync.RWMutex + connMap map[string]*struct { + conns []net.Conn + ch chan struct{} + timer *time.Timer + isPushCloseRunning bool + } connCh chan net.Conn closing chan struct{} pullWaiting map[string]chan struct{} @@ -44,12 +50,17 @@ func NewServer(addr string, timeout time.Duration, tlsConfig *tls.Config) *Serve } return &Server{ - addr: addr, - timeout: timeout, - connMap: make(map[string][]net.Conn, 256), + addr: addr, + timeout: timeout, + connMap: make(map[string]*struct { + conns []net.Conn + ch chan struct{} + timer *time.Timer + isPushCloseRunning bool + }, DefaultConnCount), connCh: make(chan net.Conn, 1), closing: make(chan struct{}, 1), - pullWaiting: make(map[string]chan struct{}, 256), + pullWaiting: make(map[string]chan struct{}, DefaultConnCount), tlsConfig: tlsConfig, server: server, } @@ -131,9 +142,9 @@ func (s *Server) pull(w http.ResponseWriter, r *http.Request) { return } - s.Lock() - conns := s.connMap[reqID] - s.Unlock() + s.RLock() + conns := s.connMap[reqID].conns + s.RUnlock() log.Debug("[HTTP_TUNNEL_SERVER] pull", "uuid", reqID) w.Header().Set("Content-Type", "application/json") @@ -172,7 +183,7 @@ func (s *Server) pull(w http.ResponseWriter, r *http.Request) { log.Warn("[HTTP_TUNNEL_SERVER] read from conn", "err", err) } - s.pullClose(reqID) + s.pullCloseConn(reqID) log.Info("[HTTP_TUNNEL_SERVER] Pull completed...", "uuid", reqID) } @@ -222,16 +233,43 @@ func (s *Server) push(w http.ResponseWriter, r *http.Request) { conns, ok := s.connMap[reqID] if !ok { conn1, conn2 := netpipe.Pipe(2*cipherstream.MaxPayloadSize, addr) - conns = []net.Conn{conn1, conn2} + conns = &struct { + conns []net.Conn + ch chan struct{} + timer *time.Timer + isPushCloseRunning bool + }{ + conns: []net.Conn{conn1, conn2}, + ch: make(chan struct{}, 1), + timer: time.NewTimer(s.timeout), + } s.connMap[reqID] = conns s.connCh <- conn2 } s.notifyPull(reqID) s.Unlock() + defer func() { + s.Lock() + defer s.Unlock() + + conns, ok := s.connMap[reqID] + if !ok { + return + } + timer := conns.timer + timer.Reset(s.timeout) + if conns.isPushCloseRunning { + return + } + conns.isPushCloseRunning = true + + go s.pushCloseConn(reqID) + }() + if p.Payload == "" { // client end push - _ = conns[0].(interface{ CloseWrite() error }).CloseWrite() + _ = conns.conns[0].(interface{ CloseWrite() error }).CloseWrite() return } cipher, err := base64.StdEncoding.DecodeString(p.Payload) @@ -241,7 +279,7 @@ func (s *Server) push(w http.ResponseWriter, r *http.Request) { return } - if _, err = conns[0].Write(cipher); err != nil { + if _, err = conns.conns[0].Write(cipher); err != nil { log.Warn("[HTTP_TUNNEL_SERVER] write local", "err", err) writeServiceUnavailableError(w, "write local:"+err.Error()) return @@ -250,12 +288,43 @@ func (s *Server) push(w http.ResponseWriter, r *http.Request) { writeSuccess(w) } -func (s *Server) pullClose(reqID string) { +func (s *Server) pushCloseConn(reqID string) { + s.RLock() + conns, ok := s.connMap[reqID] + if !ok { + s.RUnlock() + return + } + + timer := conns.timer + ch := conns.ch + s.RUnlock() + + select { + case <-s.closing: + case <-timer.C: + case <-ch: + } + + s.Lock() + defer s.Unlock() + s.closeConn(reqID) +} + +func (s *Server) pullCloseConn(reqID string) { s.Lock() defer s.Unlock() if conns, ok := s.connMap[reqID]; ok { - _ = conns[0].Close() + close(conns.ch) + } + + s.closeConn(reqID) +} + +func (s *Server) closeConn(reqID string) { + if conns, ok := s.connMap[reqID]; ok { + _ = conns.conns[0].Close() } s.connMap[reqID] = nil