diff --git a/const.go b/const.go index 4f52938..7cfb2d4 100644 --- a/const.go +++ b/const.go @@ -32,6 +32,9 @@ var ( // ErrTimeout is used when we reach an IO deadline ErrTimeout = fmt.Errorf("i/o deadline reached") + // ErrTimeout is used when we reach an IO deadline + ErrCallCanceled = fmt.Errorf("i/o was canceled") + // ErrStreamClosed is returned when using a closed stream ErrStreamClosed = fmt.Errorf("stream closed") @@ -121,34 +124,34 @@ const ( sizeOfStreamID + sizeOfLength ) -type header []byte +type header [headerSize]byte -func (h header) Version() uint8 { +func (h *header) Version() uint8 { return h[0] } -func (h header) MsgType() uint8 { +func (h *header) MsgType() uint8 { return h[1] } -func (h header) Flags() uint16 { +func (h *header) Flags() uint16 { return binary.BigEndian.Uint16(h[2:4]) } -func (h header) StreamID() uint32 { +func (h *header) StreamID() uint32 { return binary.BigEndian.Uint32(h[4:8]) } -func (h header) Length() uint32 { +func (h *header) Length() uint32 { return binary.BigEndian.Uint32(h[8:12]) } -func (h header) String() string { +func (h *header) String() string { return fmt.Sprintf("Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d", h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length()) } -func (h header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) { +func (h *header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) { h[0] = protoVersion h[1] = msgType binary.BigEndian.PutUint16(h[2:4], flags) diff --git a/const_test.go b/const_test.go index 153da18..feb34ef 100644 --- a/const_test.go +++ b/const_test.go @@ -51,7 +51,7 @@ func TestConst(t *testing.T) { } func TestEncodeDecode(t *testing.T) { - hdr := header(make([]byte, headerSize)) + var hdr header hdr.encode(typeWindowUpdate, flagACK|flagRST, 1234, 4321) if hdr.Version() != protoVersion { diff --git a/go.mod b/go.mod index 672a0e5..35e9595 100644 --- a/go.mod +++ b/go.mod @@ -1 +1,3 @@ module github.com/hashicorp/yamux + +go 1.12 diff --git a/session.go b/session.go index a80ddec..e342fc7 100644 --- a/session.go +++ b/session.go @@ -70,7 +70,6 @@ type Session struct { recvDoneCh chan struct{} // shutdown is used to safely close a session - shutdown bool shutdownErr error shutdownCh chan struct{} shutdownLock sync.Mutex @@ -79,7 +78,7 @@ type Session struct { // sendReady is used to either mark a stream as ready // or to directly send a header type sendReady struct { - Hdr []byte + Hdr header Body io.Reader Err chan error } @@ -151,8 +150,35 @@ func (s *Session) Open() (net.Conn, error) { return conn, nil } +// OpenTimeout is used to create a new stream as a net.Conn with TimeOut +func (s *Session) OpenTimeout(timeout time.Duration) (net.Conn, error) { + conn, err := s.OpenStreamTimeout(timeout) + if err != nil { + return nil, err + } + return conn, nil +} + // OpenStream is used to create a new stream func (s *Session) OpenStream() (*Stream, error) { + return s.OpenStreamTimeout(s.config.ConnectionWriteTimeout) +} + +// OpenStreamTimeout is used to create a new stream with TimeOut +func (s *Session) OpenStreamTimeout(timeout time.Duration) (*Stream, error) { + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(timeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + if s.IsClosed() { return nil, ErrSessionShutdown } @@ -165,6 +191,9 @@ func (s *Session) OpenStream() (*Stream, error) { case s.synCh <- struct{}{}: case <-s.shutdownCh: return nil, ErrSessionShutdown + case <-timer.C: + s.logger.Printf("[ERR] yamux: Failed to openstream due %v", ErrTimeout) + return nil, ErrTimeout } GET_ID: @@ -185,7 +214,12 @@ GET_ID: s.streamLock.Unlock() // Send the window update to create - if err := stream.sendWindowUpdate(); err != nil { + if err := stream.sendWindowUpdateTimeout(nil, timer.C); err != nil { + s.logger.Printf("[ERR] yamux: Failed to openstream due %v", err) + s.streamLock.Lock() + delete(s.streams, id) + delete(s.inflight, id) + s.streamLock.Unlock() select { case <-s.synCh: default: @@ -223,51 +257,63 @@ func (s *Session) AcceptStream() (*Stream, error) { // Close is used to close the session and all streams. // Attempts to send a GoAway before closing the connection. func (s *Session) Close() error { + return s.close(nil) +} + +func (s *Session) close(err error) error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() - if s.shutdown { + if s.IsClosed() { return nil } - s.shutdown = true - if s.shutdownErr == nil { + if err != nil { + s.shutdownErr = err + } else if s.shutdownErr == nil { s.shutdownErr = ErrSessionShutdown } close(s.shutdownCh) s.conn.Close() - <-s.recvDoneCh - s.streamLock.Lock() - defer s.streamLock.Unlock() - for _, stream := range s.streams { - stream.forceClose() - } - return nil -} + go func() { + <-s.recvDoneCh -// exitErr is used to handle an error that is causing the -// session to terminate. -func (s *Session) exitErr(err error) { - s.shutdownLock.Lock() - if s.shutdownErr == nil { - s.shutdownErr = err - } - s.shutdownLock.Unlock() - s.Close() + s.streamLock.Lock() + defer s.streamLock.Unlock() + for _, stream := range s.streams { + stream.forceClose() + } + }() + + return nil } // GoAway can be used to prevent accepting further // connections. It does not close the underlying conn. func (s *Session) GoAway() error { - return s.waitForSend(s.goAway(goAwayNormal), nil) + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + + ready := sendReady{Body: nil, Err: make(chan error, 1)} + s.goAway(goAwayNormal, &ready.Hdr) + + return s.waitForSendErrTimeout(nil, timer.C, &ready) } // goAway is used to send a goAway message -func (s *Session) goAway(reason uint32) header { +func (s *Session) goAway(reason uint32, hdr *header) { atomic.SwapInt32(&s.localGoAway, 1) - hdr := header(make([]byte, headerSize)) hdr.encode(typeGoAway, 0, 0, reason) - return hdr } // Ping is used to measure the RTT response time @@ -282,10 +328,23 @@ func (s *Session) Ping() (time.Duration, error) { s.pings[id] = ch s.pingLock.Unlock() + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + // Send the ping request - hdr := header(make([]byte, headerSize)) - hdr.encode(typePing, flagSYN, 0, id) - if err := s.waitForSend(hdr, nil); err != nil { + ready := sendReady{Body: nil, Err: make(chan error, 1)} + ready.Hdr.encode(typePing, flagSYN, 0, id) + if err := s.waitForSendErrTimeout(nil, timer.C, &ready); err != nil { return 0, err } @@ -293,7 +352,7 @@ func (s *Session) Ping() (time.Duration, error) { start := time.Now() select { case <-ch: - case <-time.After(s.config.ConnectionWriteTimeout): + case <-timer.C: s.pingLock.Lock() delete(s.pings, id) // Ignore it if a response comes later. s.pingLock.Unlock() @@ -309,60 +368,71 @@ func (s *Session) Ping() (time.Duration, error) { // keepalive is a long running goroutine that periodically does // a ping to keep the connection alive. func (s *Session) keepalive() { + ticker := time.NewTicker(s.config.KeepAliveInterval) + defer func() { + ticker.Stop() + }() + + lcheckcount := 0 + for { select { - case <-time.After(s.config.KeepAliveInterval): - _, err := s.Ping() + case <-ticker.C: + rtt, err := s.Ping() if err != nil { if err != ErrSessionShutdown { s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) - s.exitErr(ErrKeepAliveTimeout) + s.close(ErrKeepAliveTimeout) } return } + + lsynChLen := len(s.synCh) + lacceptChLen := len(s.acceptCh) + + if (lsynChLen >= s.config.AcceptBacklog) || (lacceptChLen >= s.config.AcceptBacklog) { + lcheckcount++ + } else { + lcheckcount = 0 + } + + if lcheckcount >= 5 { + s.logger.Printf("[WARN] yamux: too long synCh(%d)/acceptCh(%d) for last %d keepalive intervals, something wentwrong, so close backend session", lsynChLen, lacceptChLen, lcheckcount) + s.close(ErrKeepAliveTimeout) + return + } + + if rtt >= s.config.KeepAliveInterval { + s.logger.Printf("[WARN] yamux: keepalive ping too long: %.01f seconds", rtt.Seconds()) + } case <-s.shutdownCh: return } } } -// waitForSendErr waits to send a header, checking for a potential shutdown -func (s *Session) waitForSend(hdr header, body io.Reader) error { - errCh := make(chan error, 1) - return s.waitForSendErr(hdr, body, errCh) -} - -// waitForSendErr waits to send a header with optional data, checking for a +// waitForSendErrTimeout waits to send a header with optional data, checking for a // potential shutdown. Since there's the expectation that sends can happen // in a timely manner, we enforce the connection write timeout here. -func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { - t := timerPool.Get() - timer := t.(*time.Timer) - timer.Reset(s.config.ConnectionWriteTimeout) - defer func() { - timer.Stop() - select { - case <-timer.C: - default: - } - timerPool.Put(t) - }() - - ready := sendReady{Hdr: hdr, Body: body, Err: errCh} +func (s *Session) waitForSendErrTimeout(cancel <-chan struct{}, timeout <-chan time.Time, ready *sendReady) error { select { - case s.sendCh <- ready: + case s.sendCh <- *ready: case <-s.shutdownCh: return ErrSessionShutdown - case <-timer.C: + case <-cancel: + return ErrCallCanceled + case <-timeout: return ErrConnectionWriteTimeout } select { - case err := <-errCh: + case err := <-ready.Err: return err case <-s.shutdownCh: return ErrSessionShutdown - case <-timer.C: + case <-cancel: + return ErrCallCanceled + case <-timeout: return ErrConnectionWriteTimeout } } @@ -370,7 +440,7 @@ func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) e // sendNoWait does a send without waiting. Since there's the expectation that // the send happens right here, we enforce the connection write timeout if we // can't queue the header to be sent. -func (s *Session) sendNoWait(hdr header) error { +func (s *Session) sendNoWait(ready *sendReady) error { t := timerPool.Get() timer := t.(*time.Timer) timer.Reset(s.config.ConnectionWriteTimeout) @@ -384,7 +454,7 @@ func (s *Session) sendNoWait(hdr header) error { }() select { - case s.sendCh <- sendReady{Hdr: hdr}: + case s.sendCh <- *ready: return nil case <-s.shutdownCh: return ErrSessionShutdown @@ -395,22 +465,26 @@ func (s *Session) sendNoWait(hdr header) error { // send is a long running goroutine that sends data func (s *Session) send() { + if err := s.sendLoop(); err != nil { + s.close(err) + } +} + +func (s *Session) sendLoop() error { + var ready sendReady for { select { - case ready := <-s.sendCh: + case ready = <-s.sendCh: // Send a header if ready - if ready.Hdr != nil { - sent := 0 - for sent < len(ready.Hdr) { - n, err := s.conn.Write(ready.Hdr[sent:]) - if err != nil { - s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) - asyncSendErr(ready.Err, err) - s.exitErr(err) - return - } - sent += n + sent := 0 + for sent < headerSize { + n, err := s.conn.Write(ready.Hdr[sent:]) + if err != nil { + s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) + asyncSendErr(ready.Err, err) + return err } + sent += n } // Send data from a body if given @@ -419,15 +493,14 @@ func (s *Session) send() { if err != nil { s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) asyncSendErr(ready.Err, err) - s.exitErr(err) - return + return err } } // No error, successful send asyncSendErr(ready.Err, nil) case <-s.shutdownCh: - return + return nil } } } @@ -435,7 +508,7 @@ func (s *Session) send() { // recv is a long running goroutine that accepts new data func (s *Session) recv() { if err := s.recvLoop(); err != nil { - s.exitErr(err) + s.close(err) } } @@ -452,10 +525,10 @@ var ( // recvLoop continues to receive data until a fatal error is encountered func (s *Session) recvLoop() error { defer close(s.recvDoneCh) - hdr := header(make([]byte, headerSize)) + var hdr header for { // Read the header - if _, err := io.ReadFull(s.bufRead, hdr); err != nil { + if _, err := io.ReadFull(s.bufRead, hdr[:]); err != nil { if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) } @@ -502,7 +575,7 @@ func (s *Session) handleStreamMessage(hdr header) error { s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id) if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil { s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err) - return nil + return err } } else { s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr) @@ -510,20 +583,29 @@ func (s *Session) handleStreamMessage(hdr header) error { return nil } + if err := stream.processFlags(flags); err != nil { + var ready sendReady + s.goAway(goAwayProtoErr, &ready.Hdr) + + if sendErr := s.sendNoWait(&ready); sendErr != nil { + s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) + } + + return err + } + // Check if this is a window update if hdr.MsgType() == typeWindowUpdate { - if err := stream.incrSendWindow(hdr, flags); err != nil { - if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { - s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) - } - return err - } + stream.incrSetWindow(hdr) return nil } // Read the new data - if err := stream.readData(hdr, flags, s.bufRead); err != nil { - if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { + if err := stream.readData(hdr.Length(), flags, s.bufRead); err != nil { + var ready sendReady + s.goAway(goAwayProtoErr, &ready.Hdr) + + if sendErr := s.sendNoWait(&ready); sendErr != nil { s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) } return err @@ -540,9 +622,10 @@ func (s *Session) handlePing(hdr header) error { // don't interfere with the receiving thread blocking for the write. if flags&flagSYN == flagSYN { go func() { - hdr := header(make([]byte, headerSize)) - hdr.encode(typePing, flagACK, 0, pingID) - if err := s.sendNoWait(hdr); err != nil { + var ready sendReady + ready.Hdr.encode(typePing, flagACK, 0, pingID) + + if err := s.sendNoWait(&ready); err != nil { s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err) } }() @@ -583,13 +666,11 @@ func (s *Session) handleGoAway(hdr header) error { func (s *Session) incomingStream(id uint32) error { // Reject immediately if we are doing a go away if atomic.LoadInt32(&s.localGoAway) == 1 { - hdr := header(make([]byte, headerSize)) - hdr.encode(typeWindowUpdate, flagRST, id, 0) - return s.sendNoWait(hdr) - } + var ready sendReady + ready.Hdr.encode(typeWindowUpdate, flagRST, id, 0) - // Allocate a new stream - stream := newStream(s, id, streamSYNReceived) + return s.sendNoWait(&ready) + } s.streamLock.Lock() defer s.streamLock.Unlock() @@ -597,25 +678,33 @@ func (s *Session) incomingStream(id uint32) error { // Check if stream already exists if _, ok := s.streams[id]; ok { s.logger.Printf("[ERR] yamux: duplicate stream declared") - if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { + var ready sendReady + s.goAway(goAwayProtoErr, &ready.Hdr) + + if sendErr := s.sendNoWait(&ready); sendErr != nil { s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) } return ErrDuplicateStream } - // Register the stream - s.streams[id] = stream + // Allocate a new stream + stream := newStream(s, id, streamSYNReceived) // Check if we've exceeded the backlog select { case s.acceptCh <- stream: + // Register the stream + s.streams[id] = stream return nil default: // Backlog exceeded! RST the stream s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") delete(s.streams, id) - stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0) - return s.sendNoWait(stream.sendHdr) + + var ready sendReady + ready.Hdr.encode(typeWindowUpdate, flagRST, id, 0) + + return s.sendNoWait(&ready) } } @@ -624,7 +713,9 @@ func (s *Session) incomingStream(id uint32) error { // was not yet established, then this will give the credit back. func (s *Session) closeStream(id uint32) { s.streamLock.Lock() + defer s.streamLock.Unlock() if _, ok := s.inflight[id]; ok { + delete(s.inflight, id) select { case <-s.synCh: default: @@ -632,13 +723,13 @@ func (s *Session) closeStream(id uint32) { } } delete(s.streams, id) - s.streamLock.Unlock() } // establishStream is used to mark a stream that was in the // SYN Sent state as established. func (s *Session) establishStream(id uint32) { s.streamLock.Lock() + defer s.streamLock.Unlock() if _, ok := s.inflight[id]; ok { delete(s.inflight, id) } else { @@ -649,5 +740,4 @@ func (s *Session) establishStream(id uint32) { default: s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)") } - s.streamLock.Unlock() } diff --git a/session_test.go b/session_test.go index 1645e2b..42de8c8 100644 --- a/session_test.go +++ b/session_test.go @@ -1120,10 +1120,10 @@ func TestSession_sendNoWait_Timeout(t *testing.T) { conn := client.conn.(*pipeConn) conn.writeBlocker.Lock() - hdr := header(make([]byte, headerSize)) - hdr.encode(typePing, flagACK, 0, 0) + var ready sendReady + ready.Hdr.encode(typePing, flagACK, 0, 0) for { - err = client.sendNoWait(hdr) + err = client.sendNoWait(&ready) if err == nil { continue } else if err == ErrConnectionWriteTimeout { @@ -1164,9 +1164,9 @@ func TestSession_PingOfDeath(t *testing.T) { conn.writeBlocker.Lock() for { - hdr := header(make([]byte, headerSize)) - hdr.encode(typePing, 0, 0, 0) - err = server.sendNoWait(hdr) + var ready sendReady + ready.Hdr.encode(typePing, 0, 0, 0) + err = server.sendNoWait(&ready) if err == nil { continue } else if err == ErrConnectionWriteTimeout { diff --git a/stream.go b/stream.go index aa23919..ec3fd96 100644 --- a/stream.go +++ b/stream.go @@ -2,6 +2,7 @@ package yamux import ( "bytes" + "context" "io" "sync" "sync/atomic" @@ -36,19 +37,16 @@ type Stream struct { recvBuf *bytes.Buffer recvLock sync.Mutex - controlHdr header - controlErr chan error - controlHdrLock sync.Mutex - - sendHdr header - sendErr chan error - sendLock sync.Mutex + controlErr chan error + sendErr chan error + sendLock sync.Mutex recvNotifyCh chan struct{} sendNotifyCh chan struct{} - readDeadline atomic.Value // time.Time - writeDeadline atomic.Value // time.Time + readDeadline atomic.Value // time.Time + writeDeadline atomic.Value // time.Time + cancelDeadline atomic.Value } // newStream is used to construct a new stream within @@ -58,9 +56,7 @@ func newStream(session *Session, id uint32, state streamState) *Stream { id: id, session: session, state: state, - controlHdr: header(make([]byte, headerSize)), controlErr: make(chan error, 1), - sendHdr: header(make([]byte, headerSize)), sendErr: make(chan error, 1), recvWindow: initialStreamWindow, sendWindow: initialStreamWindow, @@ -84,7 +80,15 @@ func (s *Stream) StreamID() uint32 { // Read is used to read from the stream func (s *Stream) Read(b []byte) (n int, err error) { - defer asyncNotify(s.recvNotifyCh) + var cancel <-chan struct{} + cancelval := s.cancelDeadline.Load() + if cancelval != nil { + cancel = cancelval.(<-chan struct{}) + } + return s.read(cancel, b) +} + +func (s *Stream) read(cancel <-chan struct{}, b []byte) (n int, err error) { START: s.stateLock.Lock() switch s.state { @@ -99,7 +103,11 @@ START: s.stateLock.Unlock() return 0, io.EOF } + n, _ = s.recvBuf.Read(b) s.recvLock.Unlock() + + s.stateLock.Unlock() + return n, nil case streamReset: s.stateLock.Unlock() return 0, ErrConnectionReset @@ -136,6 +144,8 @@ WAIT: timer.Stop() } goto START + case <-cancel: + return 0, ErrCallCanceled case <-timeout: return 0, ErrTimeout } @@ -146,8 +156,15 @@ func (s *Stream) Write(b []byte) (n int, err error) { s.sendLock.Lock() defer s.sendLock.Unlock() total := 0 + + var cancel <-chan struct{} + cancelval := s.cancelDeadline.Load() + if cancelval != nil { + cancel = cancelval.(<-chan struct{}) + } + for total < len(b) { - n, err := s.write(b[total:]) + n, err := s.write(cancel, b[total:]) total += n if err != nil { return total, err @@ -158,10 +175,24 @@ func (s *Stream) Write(b []byte) (n int, err error) { // write is used to write to the stream, may return on // a short write. -func (s *Stream) write(b []byte) (n int, err error) { +func (s *Stream) write(cancel <-chan struct{}, b []byte) (n int, err error) { var flags uint16 var max uint32 var body io.Reader + var ready sendReady + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.session.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + START: s.stateLock.Lock() switch s.state { @@ -190,8 +221,10 @@ START: body = bytes.NewReader(b[:max]) // Send the header - s.sendHdr.encode(typeData, flags, s.id, max) - if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil { + ready.Body = body + ready.Err = s.sendErr + ready.Hdr.encode(typeData, flags, s.id, max) + if err = s.session.waitForSendErrTimeout(cancel, timer.C, &ready); err != nil { return 0, err } @@ -211,6 +244,8 @@ WAIT: select { case <-s.sendNotifyCh: goto START + case <-cancel: + return 0, ErrCallCanceled case <-timeout: return 0, ErrTimeout } @@ -237,9 +272,23 @@ func (s *Stream) sendFlags() uint16 { // sendWindowUpdate potentially sends a window update enabling // further writes to take place. Must be invoked with the lock. func (s *Stream) sendWindowUpdate() error { - s.controlHdrLock.Lock() - defer s.controlHdrLock.Unlock() + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.session.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() + + return s.sendWindowUpdateTimeout(nil, timer.C) +} +func (s *Stream) sendWindowUpdateTimeout(cancel <-chan struct{}, timeout <-chan time.Time) error { // Determine the delta update max := s.session.config.MaxStreamWindowSize var bufLen uint32 @@ -263,8 +312,9 @@ func (s *Stream) sendWindowUpdate() error { s.recvLock.Unlock() // Send the header - s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) - if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { + ready := sendReady{Body: nil, Err: s.controlErr} + ready.Hdr.encode(typeWindowUpdate, flags, s.id, delta) + if err := s.session.waitForSendErrTimeout(cancel, timeout, &ready); err != nil { return err } return nil @@ -272,13 +322,25 @@ func (s *Stream) sendWindowUpdate() error { // sendClose is used to send a FIN func (s *Stream) sendClose() error { - s.controlHdrLock.Lock() - defer s.controlHdrLock.Unlock() + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.session.config.ConnectionWriteTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(t) + }() flags := s.sendFlags() flags |= flagFIN - s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) - if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { + + ready := sendReady{Body: nil, Err: s.controlErr} + ready.Hdr.encode(typeWindowUpdate, flags, s.id, 0) + if err := s.session.waitForSendErrTimeout(nil, timer.C, &ready); err != nil { return err } return nil @@ -337,6 +399,7 @@ func (s *Stream) processFlags(flags uint16) error { defer func() { if closeStream { s.session.closeStream(s.id) + s.notifyWaiting() } }() @@ -360,7 +423,6 @@ func (s *Stream) processFlags(flags uint16) error { case streamLocalClose: s.state = streamClosed closeStream = true - s.notifyWaiting() default: s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state) return ErrUnexpectedFlag @@ -369,7 +431,6 @@ func (s *Stream) processFlags(flags uint16) error { if flags&flagRST == flagRST { s.state = streamReset closeStream = true - s.notifyWaiting() } return nil } @@ -381,37 +442,25 @@ func (s *Stream) notifyWaiting() { } // incrSendWindow updates the size of our send window -func (s *Stream) incrSendWindow(hdr header, flags uint16) error { - if err := s.processFlags(flags); err != nil { - return err - } - +func (s *Stream) incrSetWindow(hdr header) { // Increase window, unblock a sender atomic.AddUint32(&s.sendWindow, hdr.Length()) asyncNotify(s.sendNotifyCh) - return nil } // readData is used to handle a data frame -func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { - if err := s.processFlags(flags); err != nil { - return err - } - +func (s *Stream) readData(length uint32, flags uint16, conn io.Reader) error { // Check that our recv window is not exceeded - length := hdr.Length() if length == 0 { return nil } - // Wrap in a limited reader - conn = &io.LimitedReader{R: conn, N: int64(length)} - // Copy into buffer s.recvLock.Lock() if length > s.recvWindow { s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length) + s.recvLock.Unlock() return ErrRecvWindowExceeded } @@ -420,7 +469,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { // This way we can read in the whole packet without further allocations. s.recvBuf = bytes.NewBuffer(make([]byte, 0, length)) } - if _, err := io.Copy(s.recvBuf, conn); err != nil { + if _, err := io.CopyN(s.recvBuf, conn, int64(length)); err != nil { s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) s.recvLock.Unlock() return err @@ -458,6 +507,15 @@ func (s *Stream) SetWriteDeadline(t time.Time) error { return nil } +// SetContext set stream context +func (s *Stream) SetContext(ctx context.Context) { + if ctx != nil { + s.cancelDeadline.Store(ctx.Done()) + } else { + s.cancelDeadline.Store(nil) + } +} + // Shrink is used to compact the amount of buffers utilized // This is useful when using Yamux in a connection pool to reduce // the idle memory utilization. diff --git a/util.go b/util.go index 8a73e92..ff1a31b 100644 --- a/util.go +++ b/util.go @@ -9,7 +9,12 @@ var ( timerPool = &sync.Pool{ New: func() interface{} { timer := time.NewTimer(time.Hour * 1e6) - timer.Stop() + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } return timer }, }