diff --git a/httptunnel/local_conn.go b/httptunnel/local_conn.go index ebbe2d51..dac438d3 100644 --- a/httptunnel/local_conn.go +++ b/httptunnel/local_conn.go @@ -162,6 +162,21 @@ func (l *LocalConn) push() error { buf := bytespool.Get(cipherstream.MaxCipherRelaySize) defer bytespool.MustPut(buf) + defer func() { + p := &pushPayload{RequestUID: l.uuid} + _ = faker.FakeData(p) + + payload, _ := json.Marshal(p) + resp, err := r.SetBody(payload).Post(l.serverAddr + "/push") + if err != nil { + log.Warn("[HTTP_TUNNEL_LOCAL] push end", "err", err, "uuid", l.uuid) + return + } + if _, err = resp.ToBytes(); err != nil { + log.Warn("[HTTP_TUNNEL_LOCAL] push end", "err", err, "uuid", l.uuid) + } + }() + for { var resp *req.Response n, err1 := l.Read(buf) diff --git a/httptunnel/server.go b/httptunnel/server.go index 267c1021..13bf76c7 100644 --- a/httptunnel/server.go +++ b/httptunnel/server.go @@ -169,7 +169,7 @@ func (s *Server) pull(w http.ResponseWriter, r *http.Request) { log.Warn("[HTTP_TUNNEL_SERVER] read from conn", "err", err) } - s.CloseConn(reqID) + s.pullClose(reqID) log.Info("[HTTP_TUNNEL_SERVER] Pull completed...", "uuid", reqID) } @@ -226,6 +226,11 @@ func (s *Server) push(w http.ResponseWriter, r *http.Request) { s.notifyPull(reqID) s.Unlock() + if p.Payload == "" { + // client end push + _ = conns[0].(interface{ CloseWrite() error }).CloseWrite() + return + } cipher, err := base64.StdEncoding.DecodeString(p.Payload) if err != nil { log.Warn("[HTTP_TUNNEL_SERVER] decode cipher", "err", err) @@ -242,10 +247,16 @@ func (s *Server) push(w http.ResponseWriter, r *http.Request) { writeSuccess(w) } -func (s *Server) CloseConn(reqID string) { +func (s *Server) pullClose(reqID string) { s.Lock() defer s.Unlock() + conns, ok := s.connMap[reqID] + if !ok { + return + } + _ = conns[1].(interface{ CloseWrite() error }).CloseWrite() + s.connMap[reqID] = nil delete(s.connMap, reqID) }