From 97fed99fc3d90d9c0f2e6dac4db76e2e9dab20e0 Mon Sep 17 00:00:00 2001 From: tantra35 Date: Wed, 12 Feb 2020 15:56:04 +0300 Subject: [PATCH 1/4] add ability to make create streams with timeout like DialTimeout https://golang.org/pkg/net/#DialTimeout --- session.go | 105 +++++++++++++++++++++++++++++++++++++++++++++++------ stream.go | 12 +++++- util.go | 7 +++- 3 files changed, 110 insertions(+), 14 deletions(-) diff --git a/session.go b/session.go index a80ddec..eb94c34 100644 --- a/session.go +++ b/session.go @@ -151,8 +151,44 @@ func (s *Session) Open() (net.Conn, error) { return conn, nil } +// OpenTimeout is used to create a new stream as a net.Conn with TimeOut +func (s *Session) OpenTimeout(timeout time.Duration) (net.Conn, error) { + conn, err := s.OpenStreamTimeout(timeout) + if err != nil { + return nil, err + } + return conn, nil +} + // OpenStream is used to create a new stream func (s *Session) OpenStream() (*Stream, error) { + return s._openStream(nil, func(stream *Stream) error { + return stream.sendWindowUpdate() + }) +} + +// OpenStreamTimeout is used to create a new stream with TimeOut +func (s *Session) OpenStreamTimeout(timeout time.Duration) (*Stream, error) { + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(timeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + + return s._openStream(timer.C, func(stream *Stream) error { + return stream.sendWindowUpdateTimeout(timer.C) + }) +} + +// OpenStream is used to create a new stream +func (s *Session) _openStream(timeout <-chan time.Time, fn func(*Stream) error) (*Stream, error) { if s.IsClosed() { return nil, ErrSessionShutdown } @@ -165,6 +201,9 @@ func (s *Session) OpenStream() (*Stream, error) { case s.synCh <- struct{}{}: case <-s.shutdownCh: return nil, ErrSessionShutdown + case <-timeout: + s.logger.Printf("[ERR] yamux: Failed to openstream due %v", ErrTimeout) + return nil, ErrTimeout } GET_ID: @@ -185,7 +224,12 @@ GET_ID: s.streamLock.Unlock() // Send the window update to create - if err := stream.sendWindowUpdate(); err != nil { + if err := fn(stream); err != nil { + s.logger.Printf("[ERR] yamux: Failed to openstream due %v", err) + s.streamLock.Lock() + delete(s.streams, id) + delete(s.inflight, id) + s.streamLock.Unlock() select { case <-s.synCh: default: @@ -282,10 +326,24 @@ func (s *Session) Ping() (time.Duration, error) { s.pings[id] = ch s.pingLock.Unlock() + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + // Send the ping request hdr := header(make([]byte, headerSize)) hdr.encode(typePing, flagSYN, 0, id) - if err := s.waitForSend(hdr, nil); err != nil { + errCh := make(chan error, 1) + if err := s.waitForSendErrTimeout(timer.C, hdr, nil, errCh); err != nil { return 0, err } @@ -293,7 +351,7 @@ func (s *Session) Ping() (time.Duration, error) { start := time.Now() select { case <-ch: - case <-time.After(s.config.ConnectionWriteTimeout): + case <-timer.C: s.pingLock.Lock() delete(s.pings, id) // Ignore it if a response comes later. s.pingLock.Unlock() @@ -309,10 +367,23 @@ func (s *Session) Ping() (time.Duration, error) { // keepalive is a long running goroutine that periodically does // a ping to keep the connection alive. func (s *Session) keepalive() { + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.KeepAliveInterval) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + for { select { - case <-time.After(s.config.KeepAliveInterval): - _, err := s.Ping() + case <-timer.C: + rtt, err := s.Ping() if err != nil { if err != ErrSessionShutdown { s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) @@ -320,6 +391,10 @@ func (s *Session) keepalive() { } return } + + if rtt >= s.config.KeepAliveInterval { + s.logger.Printf("[WARN] yamux: keepalive ping too long: %.01f seconds", rtt.Seconds()) + } case <-s.shutdownCh: return } @@ -340,20 +415,25 @@ func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) e timer := t.(*time.Timer) timer.Reset(s.config.ConnectionWriteTimeout) defer func() { - timer.Stop() - select { - case <-timer.C: - default: + if !timer.Stop() { + select { + case <-timer.C: + default: + } } timerPool.Put(t) }() + return s.waitForSendErrTimeout(timer.C, hdr, body, errCh) +} + +func (s *Session) waitForSendErrTimeout(timeout <-chan time.Time, hdr header, body io.Reader, errCh chan error) error { ready := sendReady{Hdr: hdr, Body: body, Err: errCh} select { case s.sendCh <- ready: case <-s.shutdownCh: return ErrSessionShutdown - case <-timer.C: + case <-timeout: return ErrConnectionWriteTimeout } @@ -362,7 +442,7 @@ func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) e return err case <-s.shutdownCh: return ErrSessionShutdown - case <-timer.C: + case <-timeout: return ErrConnectionWriteTimeout } } @@ -624,7 +704,9 @@ func (s *Session) incomingStream(id uint32) error { // was not yet established, then this will give the credit back. func (s *Session) closeStream(id uint32) { s.streamLock.Lock() + defer s.streamLock.Unlock() if _, ok := s.inflight[id]; ok { + delete(s.inflight, id) select { case <-s.synCh: default: @@ -632,7 +714,6 @@ func (s *Session) closeStream(id uint32) { } } delete(s.streams, id) - s.streamLock.Unlock() } // establishStream is used to mark a stream that was in the diff --git a/stream.go b/stream.go index aa23919..b2a5235 100644 --- a/stream.go +++ b/stream.go @@ -237,6 +237,16 @@ func (s *Stream) sendFlags() uint16 { // sendWindowUpdate potentially sends a window update enabling // further writes to take place. Must be invoked with the lock. func (s *Stream) sendWindowUpdate() error { + return s._sendWindowUpdate(s.session.waitForSendErr) +} + +func (s *Stream) sendWindowUpdateTimeout(timeout <-chan time.Time) error { + return s._sendWindowUpdate(func(hdr header, body io.Reader, errCh chan error) error { + return s.session.waitForSendErrTimeout(timeout, hdr, body, errCh) + }) +} + +func (s *Stream) _sendWindowUpdate(fn func(hdr header, body io.Reader, errCh chan error) error) error { s.controlHdrLock.Lock() defer s.controlHdrLock.Unlock() @@ -264,7 +274,7 @@ func (s *Stream) sendWindowUpdate() error { // Send the header s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) - if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { + if err := fn(s.controlHdr, nil, s.controlErr); err != nil { return err } return nil diff --git a/util.go b/util.go index 8a73e92..ff1a31b 100644 --- a/util.go +++ b/util.go @@ -9,7 +9,12 @@ var ( timerPool = &sync.Pool{ New: func() interface{} { timer := time.NewTimer(time.Hour * 1e6) - timer.Stop() + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } return timer }, } From 7e64346fcb48fd331cd0760df1237ac84f64346f Mon Sep 17 00:00:00 2001 From: tantra35 Date: Wed, 19 Feb 2020 02:25:33 +0300 Subject: [PATCH 2/4] add context support --- const.go | 3 ++ go.mod | 2 ++ session.go | 82 +++++++++++++++++++++++++++--------------------------- stream.go | 63 ++++++++++++++++++++++++++++++----------- 4 files changed, 93 insertions(+), 57 deletions(-) diff --git a/const.go b/const.go index 4f52938..31d32e4 100644 --- a/const.go +++ b/const.go @@ -32,6 +32,9 @@ var ( // ErrTimeout is used when we reach an IO deadline ErrTimeout = fmt.Errorf("i/o deadline reached") + // ErrTimeout is used when we reach an IO deadline + ErrCallCanceled = fmt.Errorf("i/o was canceled") + // ErrStreamClosed is returned when using a closed stream ErrStreamClosed = fmt.Errorf("stream closed") diff --git a/go.mod b/go.mod index 672a0e5..35e9595 100644 --- a/go.mod +++ b/go.mod @@ -1 +1,3 @@ module github.com/hashicorp/yamux + +go 1.12 diff --git a/session.go b/session.go index eb94c34..ef29cc6 100644 --- a/session.go +++ b/session.go @@ -70,7 +70,6 @@ type Session struct { recvDoneCh chan struct{} // shutdown is used to safely close a session - shutdown bool shutdownErr error shutdownCh chan struct{} shutdownLock sync.Mutex @@ -267,43 +266,42 @@ func (s *Session) AcceptStream() (*Stream, error) { // Close is used to close the session and all streams. // Attempts to send a GoAway before closing the connection. func (s *Session) Close() error { + return s.close(nil) +} + +func (s *Session) close(err error) error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() - if s.shutdown { + if s.IsClosed() { return nil } - s.shutdown = true - if s.shutdownErr == nil { + if err != nil { + s.shutdownErr = err + } else if s.shutdownErr == nil { s.shutdownErr = ErrSessionShutdown } close(s.shutdownCh) s.conn.Close() - <-s.recvDoneCh - s.streamLock.Lock() - defer s.streamLock.Unlock() - for _, stream := range s.streams { - stream.forceClose() - } - return nil -} + go func() { + <-s.recvDoneCh -// exitErr is used to handle an error that is causing the -// session to terminate. -func (s *Session) exitErr(err error) { - s.shutdownLock.Lock() - if s.shutdownErr == nil { - s.shutdownErr = err - } - s.shutdownLock.Unlock() - s.Close() + s.streamLock.Lock() + defer s.streamLock.Unlock() + for _, stream := range s.streams { + stream.forceClose() + } + }() + + return nil } // GoAway can be used to prevent accepting further // connections. It does not close the underlying conn. func (s *Session) GoAway() error { - return s.waitForSend(s.goAway(goAwayNormal), nil) + errCh := make(chan error, 1) + return s.waitForSendErr(nil, s.goAway(goAwayNormal), nil, errCh) } // goAway is used to send a goAway message @@ -343,7 +341,7 @@ func (s *Session) Ping() (time.Duration, error) { hdr := header(make([]byte, headerSize)) hdr.encode(typePing, flagSYN, 0, id) errCh := make(chan error, 1) - if err := s.waitForSendErrTimeout(timer.C, hdr, nil, errCh); err != nil { + if err := s.waitForSendErrTimeout(nil, timer.C, hdr, nil, errCh); err != nil { return 0, err } @@ -387,7 +385,7 @@ func (s *Session) keepalive() { if err != nil { if err != ErrSessionShutdown { s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) - s.exitErr(ErrKeepAliveTimeout) + s.close(ErrKeepAliveTimeout) } return } @@ -401,16 +399,10 @@ func (s *Session) keepalive() { } } -// waitForSendErr waits to send a header, checking for a potential shutdown -func (s *Session) waitForSend(hdr header, body io.Reader) error { - errCh := make(chan error, 1) - return s.waitForSendErr(hdr, body, errCh) -} - // waitForSendErr waits to send a header with optional data, checking for a // potential shutdown. Since there's the expectation that sends can happen // in a timely manner, we enforce the connection write timeout here. -func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { +func (s *Session) waitForSendErr(cancel <-chan struct{}, hdr header, body io.Reader, errCh chan error) error { t := timerPool.Get() timer := t.(*time.Timer) timer.Reset(s.config.ConnectionWriteTimeout) @@ -424,15 +416,17 @@ func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) e timerPool.Put(t) }() - return s.waitForSendErrTimeout(timer.C, hdr, body, errCh) + return s.waitForSendErrTimeout(cancel, timer.C, hdr, body, errCh) } -func (s *Session) waitForSendErrTimeout(timeout <-chan time.Time, hdr header, body io.Reader, errCh chan error) error { +func (s *Session) waitForSendErrTimeout(cancel <-chan struct{}, timeout <-chan time.Time, hdr header, body io.Reader, errCh chan error) error { ready := sendReady{Hdr: hdr, Body: body, Err: errCh} select { case s.sendCh <- ready: case <-s.shutdownCh: return ErrSessionShutdown + case <-cancel: + return ErrCallCanceled case <-timeout: return ErrConnectionWriteTimeout } @@ -442,6 +436,8 @@ func (s *Session) waitForSendErrTimeout(timeout <-chan time.Time, hdr header, bo return err case <-s.shutdownCh: return ErrSessionShutdown + case <-cancel: + return ErrCallCanceled case <-timeout: return ErrConnectionWriteTimeout } @@ -475,6 +471,12 @@ func (s *Session) sendNoWait(hdr header) error { // send is a long running goroutine that sends data func (s *Session) send() { + if err := s.sendLoop(); err != nil { + s.close(err) + } +} + +func (s *Session) sendLoop() error { for { select { case ready := <-s.sendCh: @@ -486,8 +488,7 @@ func (s *Session) send() { if err != nil { s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) asyncSendErr(ready.Err, err) - s.exitErr(err) - return + return err } sent += n } @@ -499,15 +500,14 @@ func (s *Session) send() { if err != nil { s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) asyncSendErr(ready.Err, err) - s.exitErr(err) - return + return err } } // No error, successful send asyncSendErr(ready.Err, nil) case <-s.shutdownCh: - return + return nil } } } @@ -515,7 +515,7 @@ func (s *Session) send() { // recv is a long running goroutine that accepts new data func (s *Session) recv() { if err := s.recvLoop(); err != nil { - s.exitErr(err) + s.close(err) } } @@ -582,7 +582,7 @@ func (s *Session) handleStreamMessage(hdr header) error { s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id) if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil { s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err) - return nil + return err } } else { s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr) @@ -720,6 +720,7 @@ func (s *Session) closeStream(id uint32) { // SYN Sent state as established. func (s *Session) establishStream(id uint32) { s.streamLock.Lock() + defer s.streamLock.Unlock() if _, ok := s.inflight[id]; ok { delete(s.inflight, id) } else { @@ -730,5 +731,4 @@ func (s *Session) establishStream(id uint32) { default: s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)") } - s.streamLock.Unlock() } diff --git a/stream.go b/stream.go index b2a5235..bb87737 100644 --- a/stream.go +++ b/stream.go @@ -2,6 +2,7 @@ package yamux import ( "bytes" + "context" "io" "sync" "sync/atomic" @@ -47,8 +48,9 @@ type Stream struct { recvNotifyCh chan struct{} sendNotifyCh chan struct{} - readDeadline atomic.Value // time.Time - writeDeadline atomic.Value // time.Time + readDeadline atomic.Value // time.Time + writeDeadline atomic.Value // time.Time + cancelDeadline atomic.Value } // newStream is used to construct a new stream within @@ -84,7 +86,15 @@ func (s *Stream) StreamID() uint32 { // Read is used to read from the stream func (s *Stream) Read(b []byte) (n int, err error) { - defer asyncNotify(s.recvNotifyCh) + var cancel <-chan struct{} + cancelval := s.cancelDeadline.Load() + if cancelval != nil { + cancel = cancelval.(<-chan struct{}) + } + return s.read(cancel, b) +} + +func (s *Stream) read(cancel <-chan struct{}, b []byte) (n int, err error) { START: s.stateLock.Lock() switch s.state { @@ -99,7 +109,11 @@ START: s.stateLock.Unlock() return 0, io.EOF } + n, _ = s.recvBuf.Read(b) s.recvLock.Unlock() + + s.stateLock.Unlock() + return n, nil case streamReset: s.stateLock.Unlock() return 0, ErrConnectionReset @@ -136,6 +150,8 @@ WAIT: timer.Stop() } goto START + case <-cancel: + return 0, ErrCallCanceled case <-timeout: return 0, ErrTimeout } @@ -146,8 +162,13 @@ func (s *Stream) Write(b []byte) (n int, err error) { s.sendLock.Lock() defer s.sendLock.Unlock() total := 0 + var cancel <-chan struct{} + cancelval := s.cancelDeadline.Load() + if cancelval != nil { + cancel = cancelval.(<-chan struct{}) + } for total < len(b) { - n, err := s.write(b[total:]) + n, err := s.write(cancel, b[total:]) total += n if err != nil { return total, err @@ -158,7 +179,7 @@ func (s *Stream) Write(b []byte) (n int, err error) { // write is used to write to the stream, may return on // a short write. -func (s *Stream) write(b []byte) (n int, err error) { +func (s *Stream) write(cancel <-chan struct{}, b []byte) (n int, err error) { var flags uint16 var max uint32 var body io.Reader @@ -191,7 +212,7 @@ START: // Send the header s.sendHdr.encode(typeData, flags, s.id, max) - if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil { + if err = s.session.waitForSendErr(cancel, s.sendHdr, body, s.sendErr); err != nil { return 0, err } @@ -211,6 +232,8 @@ WAIT: select { case <-s.sendNotifyCh: goto START + case <-cancel: + return 0, ErrCallCanceled case <-timeout: return 0, ErrTimeout } @@ -237,12 +260,14 @@ func (s *Stream) sendFlags() uint16 { // sendWindowUpdate potentially sends a window update enabling // further writes to take place. Must be invoked with the lock. func (s *Stream) sendWindowUpdate() error { - return s._sendWindowUpdate(s.session.waitForSendErr) + return s._sendWindowUpdate(func(hdr header, body io.Reader, errCh chan error) error { + return s.session.waitForSendErr(nil, hdr, body, errCh) + }) } func (s *Stream) sendWindowUpdateTimeout(timeout <-chan time.Time) error { return s._sendWindowUpdate(func(hdr header, body io.Reader, errCh chan error) error { - return s.session.waitForSendErrTimeout(timeout, hdr, body, errCh) + return s.session.waitForSendErrTimeout(nil, timeout, hdr, body, errCh) }) } @@ -252,12 +277,12 @@ func (s *Stream) _sendWindowUpdate(fn func(hdr header, body io.Reader, errCh cha // Determine the delta update max := s.session.config.MaxStreamWindowSize - var bufLen uint32 + var recvBufLen uint32 s.recvLock.Lock() if s.recvBuf != nil { - bufLen = uint32(s.recvBuf.Len()) + recvBufLen = uint32(s.recvBuf.Len()) } - delta := (max - bufLen) - s.recvWindow + delta := (max - recvBufLen) - s.recvWindow // Determine the flags if any flags := s.sendFlags() @@ -288,7 +313,7 @@ func (s *Stream) sendClose() error { flags := s.sendFlags() flags |= flagFIN s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) - if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { + if err := s.session.waitForSendErr(nil, s.controlHdr, nil, s.controlErr); err != nil { return err } return nil @@ -414,9 +439,6 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { return nil } - // Wrap in a limited reader - conn = &io.LimitedReader{R: conn, N: int64(length)} - // Copy into buffer s.recvLock.Lock() @@ -430,7 +452,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { // This way we can read in the whole packet without further allocations. s.recvBuf = bytes.NewBuffer(make([]byte, 0, length)) } - if _, err := io.Copy(s.recvBuf, conn); err != nil { + if _, err := io.CopyN(s.recvBuf, conn, int64(length)); err != nil { s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) s.recvLock.Unlock() return err @@ -468,6 +490,15 @@ func (s *Stream) SetWriteDeadline(t time.Time) error { return nil } +// SetContext set stream context +func (s *Stream) SetContext(ctx context.Context) { + if ctx != nil { + s.cancelDeadline.Store(ctx.Done()) + } else { + s.cancelDeadline.Store(nil) + } +} + // Shrink is used to compact the amount of buffers utilized // This is useful when using Yamux in a connection pool to reduce // the idle memory utilization. From 88b7b4d2d8ae0402730eda447fff1d115feb6760 Mon Sep 17 00:00:00 2001 From: tantra35 Date: Thu, 20 Feb 2020 19:00:44 +0300 Subject: [PATCH 3/4] reogranize code of yamux(but this not helps in https://github.com/hashicorp/nomad/issues/6620 due i think real problem in nomad code, which hung in imo Accept) + add support for context --- const.go | 16 ++--- const_test.go | 2 +- session.go | 179 ++++++++++++++++++++++++++---------------------- session_test.go | 12 ++-- stream.go | 109 ++++++++++++++++------------- 5 files changed, 176 insertions(+), 142 deletions(-) diff --git a/const.go b/const.go index 31d32e4..7cfb2d4 100644 --- a/const.go +++ b/const.go @@ -124,34 +124,34 @@ const ( sizeOfStreamID + sizeOfLength ) -type header []byte +type header [headerSize]byte -func (h header) Version() uint8 { +func (h *header) Version() uint8 { return h[0] } -func (h header) MsgType() uint8 { +func (h *header) MsgType() uint8 { return h[1] } -func (h header) Flags() uint16 { +func (h *header) Flags() uint16 { return binary.BigEndian.Uint16(h[2:4]) } -func (h header) StreamID() uint32 { +func (h *header) StreamID() uint32 { return binary.BigEndian.Uint32(h[4:8]) } -func (h header) Length() uint32 { +func (h *header) Length() uint32 { return binary.BigEndian.Uint32(h[8:12]) } -func (h header) String() string { +func (h *header) String() string { return fmt.Sprintf("Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d", h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length()) } -func (h header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) { +func (h *header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) { h[0] = protoVersion h[1] = msgType binary.BigEndian.PutUint16(h[2:4], flags) diff --git a/const_test.go b/const_test.go index 153da18..feb34ef 100644 --- a/const_test.go +++ b/const_test.go @@ -51,7 +51,7 @@ func TestConst(t *testing.T) { } func TestEncodeDecode(t *testing.T) { - hdr := header(make([]byte, headerSize)) + var hdr header hdr.encode(typeWindowUpdate, flagACK|flagRST, 1234, 4321) if hdr.Version() != protoVersion { diff --git a/session.go b/session.go index ef29cc6..0824b9a 100644 --- a/session.go +++ b/session.go @@ -78,7 +78,7 @@ type Session struct { // sendReady is used to either mark a stream as ready // or to directly send a header type sendReady struct { - Hdr []byte + Hdr header Body io.Reader Err chan error } @@ -161,9 +161,7 @@ func (s *Session) OpenTimeout(timeout time.Duration) (net.Conn, error) { // OpenStream is used to create a new stream func (s *Session) OpenStream() (*Stream, error) { - return s._openStream(nil, func(stream *Stream) error { - return stream.sendWindowUpdate() - }) + return s.OpenStreamTimeout(s.config.ConnectionWriteTimeout) } // OpenStreamTimeout is used to create a new stream with TimeOut @@ -181,13 +179,6 @@ func (s *Session) OpenStreamTimeout(timeout time.Duration) (*Stream, error) { timerPool.Put(t) }() - return s._openStream(timer.C, func(stream *Stream) error { - return stream.sendWindowUpdateTimeout(timer.C) - }) -} - -// OpenStream is used to create a new stream -func (s *Session) _openStream(timeout <-chan time.Time, fn func(*Stream) error) (*Stream, error) { if s.IsClosed() { return nil, ErrSessionShutdown } @@ -200,7 +191,7 @@ func (s *Session) _openStream(timeout <-chan time.Time, fn func(*Stream) error) case s.synCh <- struct{}{}: case <-s.shutdownCh: return nil, ErrSessionShutdown - case <-timeout: + case <-timer.C: s.logger.Printf("[ERR] yamux: Failed to openstream due %v", ErrTimeout) return nil, ErrTimeout } @@ -223,7 +214,7 @@ GET_ID: s.streamLock.Unlock() // Send the window update to create - if err := fn(stream); err != nil { + if err := stream.sendWindowUpdateTimeout(nil, timer.C); err != nil { s.logger.Printf("[ERR] yamux: Failed to openstream due %v", err) s.streamLock.Lock() delete(s.streams, id) @@ -300,16 +291,29 @@ func (s *Session) close(err error) error { // GoAway can be used to prevent accepting further // connections. It does not close the underlying conn. func (s *Session) GoAway() error { - errCh := make(chan error, 1) - return s.waitForSendErr(nil, s.goAway(goAwayNormal), nil, errCh) + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + + ready := sendReady{Body: nil, Err: make(chan error, 1)} + s.goAway(goAwayNormal, &ready.Hdr) + + return s.waitForSendErrTimeout(nil, timer.C, &ready) } // goAway is used to send a goAway message -func (s *Session) goAway(reason uint32) header { +func (s *Session) goAway(reason uint32, hdr *header) { atomic.SwapInt32(&s.localGoAway, 1) - hdr := header(make([]byte, headerSize)) hdr.encode(typeGoAway, 0, 0, reason) - return hdr } // Ping is used to measure the RTT response time @@ -338,10 +342,9 @@ func (s *Session) Ping() (time.Duration, error) { }() // Send the ping request - hdr := header(make([]byte, headerSize)) - hdr.encode(typePing, flagSYN, 0, id) - errCh := make(chan error, 1) - if err := s.waitForSendErrTimeout(nil, timer.C, hdr, nil, errCh); err != nil { + ready := sendReady{Body: nil, Err: make(chan error, 1)} + ready.Hdr.encode(typePing, flagSYN, 0, id) + if err := s.waitForSendErrTimeout(nil, timer.C, &ready); err != nil { return 0, err } @@ -378,6 +381,8 @@ func (s *Session) keepalive() { timerPool.Put(t) }() + lcheckcount := 0 + for { select { case <-timer.C: @@ -390,6 +395,21 @@ func (s *Session) keepalive() { return } + lsynChLen := len(s.synCh) + lacceptChLen := len(s.acceptCh) + + if (lsynChLen >= s.config.AcceptBacklog/2) || (lacceptChLen >= s.config.AcceptBacklog/2) { + lcheckcount++ + } else { + lcheckcount = 0 + } + + if lcheckcount >= 5 { + s.logger.Printf("[WARN] yamux: too long synCh(%d)/acceptCh(%d) for last %d keepalive intervals, something wentwrong, so close backend session", lsynChLen, lacceptChLen, lcheckcount) + s.close(ErrKeepAliveTimeout) + return + } + if rtt >= s.config.KeepAliveInterval { s.logger.Printf("[WARN] yamux: keepalive ping too long: %.01f seconds", rtt.Seconds()) } @@ -399,30 +419,12 @@ func (s *Session) keepalive() { } } -// waitForSendErr waits to send a header with optional data, checking for a +// waitForSendErrTimeout waits to send a header with optional data, checking for a // potential shutdown. Since there's the expectation that sends can happen // in a timely manner, we enforce the connection write timeout here. -func (s *Session) waitForSendErr(cancel <-chan struct{}, hdr header, body io.Reader, errCh chan error) error { - t := timerPool.Get() - timer := t.(*time.Timer) - timer.Reset(s.config.ConnectionWriteTimeout) - defer func() { - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - timerPool.Put(t) - }() - - return s.waitForSendErrTimeout(cancel, timer.C, hdr, body, errCh) -} - -func (s *Session) waitForSendErrTimeout(cancel <-chan struct{}, timeout <-chan time.Time, hdr header, body io.Reader, errCh chan error) error { - ready := sendReady{Hdr: hdr, Body: body, Err: errCh} +func (s *Session) waitForSendErrTimeout(cancel <-chan struct{}, timeout <-chan time.Time, ready *sendReady) error { select { - case s.sendCh <- ready: + case s.sendCh <- *ready: case <-s.shutdownCh: return ErrSessionShutdown case <-cancel: @@ -432,7 +434,7 @@ func (s *Session) waitForSendErrTimeout(cancel <-chan struct{}, timeout <-chan t } select { - case err := <-errCh: + case err := <-ready.Err: return err case <-s.shutdownCh: return ErrSessionShutdown @@ -446,7 +448,7 @@ func (s *Session) waitForSendErrTimeout(cancel <-chan struct{}, timeout <-chan t // sendNoWait does a send without waiting. Since there's the expectation that // the send happens right here, we enforce the connection write timeout if we // can't queue the header to be sent. -func (s *Session) sendNoWait(hdr header) error { +func (s *Session) sendNoWait(ready *sendReady) error { t := timerPool.Get() timer := t.(*time.Timer) timer.Reset(s.config.ConnectionWriteTimeout) @@ -460,7 +462,7 @@ func (s *Session) sendNoWait(hdr header) error { }() select { - case s.sendCh <- sendReady{Hdr: hdr}: + case s.sendCh <- *ready: return nil case <-s.shutdownCh: return ErrSessionShutdown @@ -477,21 +479,20 @@ func (s *Session) send() { } func (s *Session) sendLoop() error { + var ready sendReady for { select { - case ready := <-s.sendCh: + case ready = <-s.sendCh: // Send a header if ready - if ready.Hdr != nil { - sent := 0 - for sent < len(ready.Hdr) { - n, err := s.conn.Write(ready.Hdr[sent:]) - if err != nil { - s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) - asyncSendErr(ready.Err, err) - return err - } - sent += n + sent := 0 + for sent < headerSize { + n, err := s.conn.Write(ready.Hdr[sent:]) + if err != nil { + s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) + asyncSendErr(ready.Err, err) + return err } + sent += n } // Send data from a body if given @@ -532,10 +533,10 @@ var ( // recvLoop continues to receive data until a fatal error is encountered func (s *Session) recvLoop() error { defer close(s.recvDoneCh) - hdr := header(make([]byte, headerSize)) + var hdr header for { // Read the header - if _, err := io.ReadFull(s.bufRead, hdr); err != nil { + if _, err := io.ReadFull(s.bufRead, hdr[:]); err != nil { if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) } @@ -590,20 +591,29 @@ func (s *Session) handleStreamMessage(hdr header) error { return nil } + if err := stream.processFlags(flags); err != nil { + var ready sendReady + s.goAway(goAwayProtoErr, &ready.Hdr) + + if sendErr := s.sendNoWait(&ready); sendErr != nil { + s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) + } + + return err + } + // Check if this is a window update if hdr.MsgType() == typeWindowUpdate { - if err := stream.incrSendWindow(hdr, flags); err != nil { - if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { - s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) - } - return err - } + stream.incrSetWindow(hdr) return nil } // Read the new data - if err := stream.readData(hdr, flags, s.bufRead); err != nil { - if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { + if err := stream.readData(hdr.Length(), flags, s.bufRead); err != nil { + var ready sendReady + s.goAway(goAwayProtoErr, &ready.Hdr) + + if sendErr := s.sendNoWait(&ready); sendErr != nil { s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) } return err @@ -620,9 +630,10 @@ func (s *Session) handlePing(hdr header) error { // don't interfere with the receiving thread blocking for the write. if flags&flagSYN == flagSYN { go func() { - hdr := header(make([]byte, headerSize)) - hdr.encode(typePing, flagACK, 0, pingID) - if err := s.sendNoWait(hdr); err != nil { + var ready sendReady + ready.Hdr.encode(typePing, flagACK, 0, pingID) + + if err := s.sendNoWait(&ready); err != nil { s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err) } }() @@ -663,13 +674,11 @@ func (s *Session) handleGoAway(hdr header) error { func (s *Session) incomingStream(id uint32) error { // Reject immediately if we are doing a go away if atomic.LoadInt32(&s.localGoAway) == 1 { - hdr := header(make([]byte, headerSize)) - hdr.encode(typeWindowUpdate, flagRST, id, 0) - return s.sendNoWait(hdr) - } + var ready sendReady + ready.Hdr.encode(typeWindowUpdate, flagRST, id, 0) - // Allocate a new stream - stream := newStream(s, id, streamSYNReceived) + return s.sendNoWait(&ready) + } s.streamLock.Lock() defer s.streamLock.Unlock() @@ -677,25 +686,33 @@ func (s *Session) incomingStream(id uint32) error { // Check if stream already exists if _, ok := s.streams[id]; ok { s.logger.Printf("[ERR] yamux: duplicate stream declared") - if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { + var ready sendReady + s.goAway(goAwayProtoErr, &ready.Hdr) + + if sendErr := s.sendNoWait(&ready); sendErr != nil { s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) } return ErrDuplicateStream } - // Register the stream - s.streams[id] = stream + // Allocate a new stream + stream := newStream(s, id, streamSYNReceived) // Check if we've exceeded the backlog select { case s.acceptCh <- stream: + // Register the stream + s.streams[id] = stream return nil default: // Backlog exceeded! RST the stream s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") delete(s.streams, id) - stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0) - return s.sendNoWait(stream.sendHdr) + + var ready sendReady + ready.Hdr.encode(typeWindowUpdate, flagRST, id, 0) + + return s.sendNoWait(&ready) } } diff --git a/session_test.go b/session_test.go index 1645e2b..42de8c8 100644 --- a/session_test.go +++ b/session_test.go @@ -1120,10 +1120,10 @@ func TestSession_sendNoWait_Timeout(t *testing.T) { conn := client.conn.(*pipeConn) conn.writeBlocker.Lock() - hdr := header(make([]byte, headerSize)) - hdr.encode(typePing, flagACK, 0, 0) + var ready sendReady + ready.Hdr.encode(typePing, flagACK, 0, 0) for { - err = client.sendNoWait(hdr) + err = client.sendNoWait(&ready) if err == nil { continue } else if err == ErrConnectionWriteTimeout { @@ -1164,9 +1164,9 @@ func TestSession_PingOfDeath(t *testing.T) { conn.writeBlocker.Lock() for { - hdr := header(make([]byte, headerSize)) - hdr.encode(typePing, 0, 0, 0) - err = server.sendNoWait(hdr) + var ready sendReady + ready.Hdr.encode(typePing, 0, 0, 0) + err = server.sendNoWait(&ready) if err == nil { continue } else if err == ErrConnectionWriteTimeout { diff --git a/stream.go b/stream.go index bb87737..ec3fd96 100644 --- a/stream.go +++ b/stream.go @@ -37,13 +37,9 @@ type Stream struct { recvBuf *bytes.Buffer recvLock sync.Mutex - controlHdr header - controlErr chan error - controlHdrLock sync.Mutex - - sendHdr header - sendErr chan error - sendLock sync.Mutex + controlErr chan error + sendErr chan error + sendLock sync.Mutex recvNotifyCh chan struct{} sendNotifyCh chan struct{} @@ -60,9 +56,7 @@ func newStream(session *Session, id uint32, state streamState) *Stream { id: id, session: session, state: state, - controlHdr: header(make([]byte, headerSize)), controlErr: make(chan error, 1), - sendHdr: header(make([]byte, headerSize)), sendErr: make(chan error, 1), recvWindow: initialStreamWindow, sendWindow: initialStreamWindow, @@ -162,11 +156,13 @@ func (s *Stream) Write(b []byte) (n int, err error) { s.sendLock.Lock() defer s.sendLock.Unlock() total := 0 + var cancel <-chan struct{} cancelval := s.cancelDeadline.Load() if cancelval != nil { cancel = cancelval.(<-chan struct{}) } + for total < len(b) { n, err := s.write(cancel, b[total:]) total += n @@ -183,6 +179,20 @@ func (s *Stream) write(cancel <-chan struct{}, b []byte) (n int, err error) { var flags uint16 var max uint32 var body io.Reader + var ready sendReady + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.session.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + START: s.stateLock.Lock() switch s.state { @@ -211,8 +221,10 @@ START: body = bytes.NewReader(b[:max]) // Send the header - s.sendHdr.encode(typeData, flags, s.id, max) - if err = s.session.waitForSendErr(cancel, s.sendHdr, body, s.sendErr); err != nil { + ready.Body = body + ready.Err = s.sendErr + ready.Hdr.encode(typeData, flags, s.id, max) + if err = s.session.waitForSendErrTimeout(cancel, timer.C, &ready); err != nil { return 0, err } @@ -260,29 +272,31 @@ func (s *Stream) sendFlags() uint16 { // sendWindowUpdate potentially sends a window update enabling // further writes to take place. Must be invoked with the lock. func (s *Stream) sendWindowUpdate() error { - return s._sendWindowUpdate(func(hdr header, body io.Reader, errCh chan error) error { - return s.session.waitForSendErr(nil, hdr, body, errCh) - }) -} + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.session.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() -func (s *Stream) sendWindowUpdateTimeout(timeout <-chan time.Time) error { - return s._sendWindowUpdate(func(hdr header, body io.Reader, errCh chan error) error { - return s.session.waitForSendErrTimeout(nil, timeout, hdr, body, errCh) - }) + return s.sendWindowUpdateTimeout(nil, timer.C) } -func (s *Stream) _sendWindowUpdate(fn func(hdr header, body io.Reader, errCh chan error) error) error { - s.controlHdrLock.Lock() - defer s.controlHdrLock.Unlock() - +func (s *Stream) sendWindowUpdateTimeout(cancel <-chan struct{}, timeout <-chan time.Time) error { // Determine the delta update max := s.session.config.MaxStreamWindowSize - var recvBufLen uint32 + var bufLen uint32 s.recvLock.Lock() if s.recvBuf != nil { - recvBufLen = uint32(s.recvBuf.Len()) + bufLen = uint32(s.recvBuf.Len()) } - delta := (max - recvBufLen) - s.recvWindow + delta := (max - bufLen) - s.recvWindow // Determine the flags if any flags := s.sendFlags() @@ -298,8 +312,9 @@ func (s *Stream) _sendWindowUpdate(fn func(hdr header, body io.Reader, errCh cha s.recvLock.Unlock() // Send the header - s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) - if err := fn(s.controlHdr, nil, s.controlErr); err != nil { + ready := sendReady{Body: nil, Err: s.controlErr} + ready.Hdr.encode(typeWindowUpdate, flags, s.id, delta) + if err := s.session.waitForSendErrTimeout(cancel, timeout, &ready); err != nil { return err } return nil @@ -307,13 +322,25 @@ func (s *Stream) _sendWindowUpdate(fn func(hdr header, body io.Reader, errCh cha // sendClose is used to send a FIN func (s *Stream) sendClose() error { - s.controlHdrLock.Lock() - defer s.controlHdrLock.Unlock() + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.session.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() flags := s.sendFlags() flags |= flagFIN - s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) - if err := s.session.waitForSendErr(nil, s.controlHdr, nil, s.controlErr); err != nil { + + ready := sendReady{Body: nil, Err: s.controlErr} + ready.Hdr.encode(typeWindowUpdate, flags, s.id, 0) + if err := s.session.waitForSendErrTimeout(nil, timer.C, &ready); err != nil { return err } return nil @@ -372,6 +399,7 @@ func (s *Stream) processFlags(flags uint16) error { defer func() { if closeStream { s.session.closeStream(s.id) + s.notifyWaiting() } }() @@ -395,7 +423,6 @@ func (s *Stream) processFlags(flags uint16) error { case streamLocalClose: s.state = streamClosed closeStream = true - s.notifyWaiting() default: s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state) return ErrUnexpectedFlag @@ -404,7 +431,6 @@ func (s *Stream) processFlags(flags uint16) error { if flags&flagRST == flagRST { s.state = streamReset closeStream = true - s.notifyWaiting() } return nil } @@ -416,25 +442,15 @@ func (s *Stream) notifyWaiting() { } // incrSendWindow updates the size of our send window -func (s *Stream) incrSendWindow(hdr header, flags uint16) error { - if err := s.processFlags(flags); err != nil { - return err - } - +func (s *Stream) incrSetWindow(hdr header) { // Increase window, unblock a sender atomic.AddUint32(&s.sendWindow, hdr.Length()) asyncNotify(s.sendNotifyCh) - return nil } // readData is used to handle a data frame -func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { - if err := s.processFlags(flags); err != nil { - return err - } - +func (s *Stream) readData(length uint32, flags uint16, conn io.Reader) error { // Check that our recv window is not exceeded - length := hdr.Length() if length == 0 { return nil } @@ -444,6 +460,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { if length > s.recvWindow { s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length) + s.recvLock.Unlock() return ErrRecvWindowExceeded } From 9665eb76f2281b81a89b9d70ab0a05e85231b3de Mon Sep 17 00:00:00 2001 From: tantra35 Date: Thu, 5 Mar 2020 03:35:42 +0300 Subject: [PATCH 4/4] fix error, which prevents send Ping in keepalive loop only once --- session.go | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/session.go b/session.go index 0824b9a..e342fc7 100644 --- a/session.go +++ b/session.go @@ -368,24 +368,16 @@ func (s *Session) Ping() (time.Duration, error) { // keepalive is a long running goroutine that periodically does // a ping to keep the connection alive. func (s *Session) keepalive() { - t := timerPool.Get() - timer := t.(*time.Timer) - timer.Reset(s.config.KeepAliveInterval) + ticker := time.NewTicker(s.config.KeepAliveInterval) defer func() { - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - timerPool.Put(t) + ticker.Stop() }() lcheckcount := 0 for { select { - case <-timer.C: + case <-ticker.C: rtt, err := s.Ping() if err != nil { if err != ErrSessionShutdown { @@ -398,7 +390,7 @@ func (s *Session) keepalive() { lsynChLen := len(s.synCh) lacceptChLen := len(s.acceptCh) - if (lsynChLen >= s.config.AcceptBacklog/2) || (lacceptChLen >= s.config.AcceptBacklog/2) { + if (lsynChLen >= s.config.AcceptBacklog) || (lacceptChLen >= s.config.AcceptBacklog) { lcheckcount++ } else { lcheckcount = 0