diff --git a/remote.go b/remote.go index 039c16f4..443f4e90 100644 --- a/remote.go +++ b/remote.go @@ -118,7 +118,7 @@ func (es *EasyServer) handleConn(conn net.Conn, tryReuse bool) { if err != nil { if errors.Is(err, io.EOF) { log.Debug("[REMOTE] got EOF error when handshake with client-server, maybe the connection pool closed the idle conn") - } else if !errors.Is(err, netpipe.ErrDeadline) { + } else if !errors.Is(err, netpipe.ErrReadDeadline) { log.Warn("[REMOTE] handshake with client", "err", err) } return @@ -190,9 +190,9 @@ func (es *EasyServer) handShakeWithClient(conn net.Conn) (hsRes, error) { } cs := csStream.(*cipherstream.CipherStream) - _ = csStream.SetDeadline(time.Now().Add(es.MaxConnWaitTimeout())) + _ = csStream.SetReadDeadline(time.Now().Add(es.MaxConnWaitTimeout())) defer func() { - _ = csStream.SetDeadline(time.Time{}) + _ = csStream.SetReadDeadline(time.Time{}) cs.Release() }() @@ -203,7 +203,7 @@ func (es *EasyServer) handShakeWithClient(conn net.Conn) (hsRes, error) { return res, err } - _ = csStream.SetDeadline(time.Now().Add(es.MaxConnWaitTimeout())) + _ = csStream.SetReadDeadline(time.Now().Add(es.MaxConnWaitTimeout())) if frame.IsPingFrame() { log.Debug("[REMOTE] got ping message", diff --git a/remote_udp.go b/remote_udp.go index 6b6486c0..0f412fe4 100644 --- a/remote_udp.go +++ b/remote_udp.go @@ -48,7 +48,7 @@ func (es *EasyServer) remoteUDPHandle(conn net.Conn, addrStr, method string, isD if errors.Is(err, cipherstream.ErrFINRSTStream) { _tryReuse = true log.Debug("[REMOTE_UDP] received FIN when reading data from client, try to reuse the connection") - } else if !errors.Is(err, io.EOF) && !errors.Is(err, netpipe.ErrDeadline) { + } else if !errors.Is(err, io.EOF) && !errors.Is(err, netpipe.ErrReadDeadline) { log.Warn("[REMOTE_UDP] read data from client connection", "err", err) } diff --git a/util/net_test.go b/util/net_test.go index 87859cfc..37c61cf1 100644 --- a/util/net_test.go +++ b/util/net_test.go @@ -23,7 +23,7 @@ func TestIP(t *testing.T) { } func TestLookupIPV4From(t *testing.T) { - ips, err := LookupIPV4From("8.8.8.8:53", "dnspod.cn") + ips, err := LookupIPV4From("119.29.29.29:53", "dnspod.cn") assert.Nil(t, err) assert.Greater(t, len(ips), 0) } diff --git a/util/netpipe/dup_pipe.go b/util/netpipe/dup_pipe.go index 4100c8ab..24a8befe 100644 --- a/util/netpipe/dup_pipe.go +++ b/util/netpipe/dup_pipe.go @@ -70,6 +70,8 @@ func Pipe(maxSize int, addrs ...net.Addr) (net.Conn, net.Conn) { sp := &pipe{ buf: ringbuffer.NewBuffer(buf1), back: buf1, + rdChan: make(chan struct{}), + wdChan: make(chan struct{}), maxSize: maxSize, remoteAddr: remoteAddr, localAddr: localAddr, @@ -80,6 +82,8 @@ func Pipe(maxSize int, addrs ...net.Addr) (net.Conn, net.Conn) { rp := &pipe{ buf: ringbuffer.NewBuffer(buf2), back: buf2, + rdChan: make(chan struct{}), + wdChan: make(chan struct{}), maxSize: maxSize, remoteAddr: remoteAddr, localAddr: localAddr, diff --git a/util/netpipe/dup_pipe_test.go b/util/netpipe/dup_pipe_test.go index 151d5546..64d1a356 100644 --- a/util/netpipe/dup_pipe_test.go +++ b/util/netpipe/dup_pipe_test.go @@ -109,10 +109,10 @@ func TestDupPipe_Timeout(t *testing.T) { b := make([]byte, 10) n, err := p1.Read(b) assert.Equal(t, 0, n) - assert.Equal(t, ErrDeadline, err) + assert.Equal(t, ErrReadDeadline, err) n, err = p2.Read(b) assert.Equal(t, 0, n) - assert.Equal(t, ErrDeadline, err) + assert.Equal(t, ErrReadDeadline, err) err = p1.SetReadDeadline(time.Time{}) assert.Nil(t, err) @@ -134,7 +134,7 @@ func TestDupPipe_Timeout(t *testing.T) { assert.Nil(t, err) n, err = p1.Write([]byte("hello2")) assert.Equal(t, 0, n) - assert.Equal(t, ErrDeadline, err) + assert.Equal(t, ErrWriteDeadline, err) err = p1.SetWriteDeadline(time.Time{}) assert.Nil(t, err) @@ -149,3 +149,15 @@ func TestDupPipe_Timeout(t *testing.T) { assert.Equal(t, 6, n) assert.Nil(t, err) } + +func TestDupPipe_Read_Then_SetDeadline(t *testing.T) { + p1, _ := Pipe(10) + go func() { + time.Sleep(time.Second) + err := p1.SetDeadline(time.Now().Add(time.Second)) + assert.Nil(t, err) + }() + n, err := p1.Read(make([]byte, 10)) + assert.Equal(t, 0, n) + assert.Equal(t, ErrReadDeadline, err) +} diff --git a/util/netpipe/pipe.go b/util/netpipe/pipe.go index 459ad2b5..6e419951 100644 --- a/util/netpipe/pipe.go +++ b/util/netpipe/pipe.go @@ -1,6 +1,7 @@ package netpipe import ( + "errors" "fmt" "net" "sync" @@ -15,10 +16,12 @@ var _ net.Conn = (*pipe)(nil) // ensure to implements net.Conn // Pipe is buffered version of net.Pipe. Reads // will block until data is available. type pipe struct { - buf *ringbuffer.RingBuffer - back []byte - cond sync.Cond - mu sync.Mutex + buf *ringbuffer.RingBuffer + back []byte + rdChan chan struct{} + wdChan chan struct{} + cond sync.Cond + mu sync.Mutex maxSize int rLate bool @@ -30,7 +33,8 @@ type pipe struct { localAddr net.Addr } -var ErrDeadline = fmt.Errorf("pipe deadline exceeded") +var ErrReadDeadline = fmt.Errorf("pipe read deadline exceeded") +var ErrWriteDeadline = fmt.Errorf("pipe write deadline exceeded") var ErrPipeClosed = fmt.Errorf("pipe closed") var ErrExceedMaxSize = fmt.Errorf("exceed max size") @@ -41,30 +45,7 @@ func (p *pipe) Read(b []byte) (n int, err error) { defer p.cond.L.Unlock() if p.rLate { - return 0, ErrDeadline - } - - if !p.readDeadline.IsZero() { - now := time.Now() - dur := p.readDeadline.Sub(now) - if dur <= 0 { - p.rLate = true - return 0, ErrDeadline - } - nextReadDone := make(chan struct{}) - defer close(nextReadDone) - go func(dur time.Duration) { - timer := time.NewTimer(dur) - defer timer.Stop() - select { - case <-timer.C: - p.cond.L.Lock() - p.rLate = true - p.cond.L.Unlock() - p.cond.Broadcast() - case <-nextReadDone: - } - }(dur) + return 0, ErrReadDeadline } defer p.cond.Broadcast() @@ -75,7 +56,7 @@ func (p *pipe) Read(b []byte) (n int, err error) { } if p.rLate { - return 0, ErrDeadline + return 0, ErrReadDeadline } return p.buf.Read(b) @@ -92,31 +73,9 @@ func (p *pipe) Write(b []byte) (n int, err error) { defer p.cond.L.Unlock() if p.wLate { - return 0, ErrDeadline - } - - if !p.writeDeadline.IsZero() { - now := time.Now() - dur := p.writeDeadline.Sub(now) - if dur <= 0 { - p.wLate = true - return 0, ErrDeadline - } - nextWriteDone := make(chan struct{}) - defer close(nextWriteDone) - go func(dur time.Duration) { - timer := time.NewTimer(dur) - defer timer.Stop() - select { - case <-timer.C: - p.cond.L.Lock() - p.wLate = true - p.cond.L.Unlock() - p.cond.Broadcast() - case <-nextWriteDone: - } - }(dur) + return 0, ErrWriteDeadline } + defer p.cond.Broadcast() for p.buf.Free() < len(b) && !p.closed && !p.wLate { @@ -125,7 +84,7 @@ func (p *pipe) Write(b []byte) (n int, err error) { } if p.wLate { - return 0, ErrDeadline + return 0, ErrWriteDeadline } return p.buf.Write(b) @@ -179,26 +138,81 @@ func (a addr) Network() string { return "pipe" } func (p *pipe) SetDeadline(t time.Time) error { err := p.SetReadDeadline(t) err2 := p.SetWriteDeadline(t) - if err != nil { - return err - } - return err2 + return errors.Join(err, err2) } // SetWriteDeadline implements the net.Conn method func (p *pipe) SetWriteDeadline(t time.Time) error { + // Let the previous goroutine exit, if it exists. + select { + case p.wdChan <- struct{}{}: + default: + } + p.cond.L.Lock() - p.writeDeadline = t - p.wLate = false - p.cond.L.Unlock() + defer p.cond.L.Unlock() + + if t.IsZero() || t.After(time.Now()) { + p.wLate = false + } else { + p.wLate = true + p.cond.Broadcast() + return nil + } + + if !t.IsZero() { + go func() { + timer := time.NewTimer(t.Sub(time.Now())) + defer timer.Stop() + + select { + case <-timer.C: + p.cond.L.Lock() + p.wLate = true + p.cond.Broadcast() + p.cond.L.Unlock() + case <-p.wdChan: + } + }() + } + return nil } // SetReadDeadline implements the net.Conn method func (p *pipe) SetReadDeadline(t time.Time) error { + // Let the previous goroutine exit, if it exists. + select { + case p.rdChan <- struct{}{}: + default: + } + p.cond.L.Lock() - p.readDeadline = t - p.rLate = false - p.cond.L.Unlock() + defer p.cond.L.Unlock() + + if t.IsZero() || t.After(time.Now()) { + p.rLate = false + } else { + p.rLate = true + p.cond.Broadcast() + return nil + } + + if !t.IsZero() { + go func() { + timer := time.NewTimer(t.Sub(time.Now())) + defer timer.Stop() + + select { + case <-timer.C: + p.cond.L.Lock() + p.rLate = true + p.cond.Broadcast() + p.cond.L.Unlock() + case <-p.rdChan: + } + }() + } + return nil }