diff --git a/core/stream.go b/core/stream.go index 03a19793..0dd75fe4 100644 --- a/core/stream.go +++ b/core/stream.go @@ -4,6 +4,9 @@ import ( "bytes" "io" "net" + "sync" + "sync/atomic" + "time" "github.com/shadowsocks/go-shadowsocks2/socks" ) @@ -31,12 +34,25 @@ func Dial(network, address string, ciph StreamConnCipher) (net.Conn, error) { // Connect sends the shadowsocks standard header to underlying ciphered connection // in the next Write/ReadFrom call. func Connect(c net.Conn, addr socks.Addr) net.Conn { - return &ssconn{Conn: c, addr: addr} + return newSSConn(c, addr, 5 * time.Millisecond) } type ssconn struct { net.Conn addr socks.Addr + done uint32 + m sync.Mutex + t *time.Timer +} + +func newSSConn(c net.Conn, addr socks.Addr, delay time.Duration) *ssconn { + sc := &ssconn{Conn: c, addr: addr} + if delay > 0 { + sc.t = time.AfterFunc(delay, func() { + sc.Write([]byte{}) + }) + } + return sc } func (c *ssconn) Write(b []byte) (int, error) { @@ -45,9 +61,22 @@ func (c *ssconn) Write(b []byte) (int, error) { } func (c *ssconn) ReadFrom(r io.Reader) (int64, error) { - if len(c.addr) > 0 { - r = &readerWithAddr{Reader: r, b: c.addr} - c.addr = nil + if atomic.LoadUint32(&c.done) == 0 { + c.m.Lock() + defer c.m.Unlock() + if c.done == 0 { + defer atomic.StoreUint32(&c.done, 1) + ra := readerWithAddr{Reader: r, b: c.addr} + if c.t != nil { + c.t.Stop() + c.t = nil + } + n, err := io.Copy(c.Conn, &ra) + n -= int64(len(c.addr)) + if n < 0 { n = 0 } + c.addr = nil + return n, err + } } return io.Copy(c.Conn, r) } @@ -63,11 +92,7 @@ type readerWithAddr struct { func (r *readerWithAddr) Read(b []byte) (n int, err error) { nc := copy(b, r.b) - if nc < len(r.b) { - r.b = r.b[:nc] - return nc, nil - } - r.b = nil + r.b = r.b[nc:] nr, err := r.Reader.Read(b[nc:]) return nc + nr, err }