diff --git a/tools/remotecommand/websocket.go b/tools/remotecommand/websocket.go index a60986decc..49ef4717cd 100644 --- a/tools/remotecommand/websocket.go +++ b/tools/remotecommand/websocket.go @@ -187,6 +187,9 @@ type wsStreamCreator struct { // map of stream id to stream; multiple streams read/write the connection streams map[byte]*stream streamsMu sync.Mutex + // setStreamErr holds the error to return to anyone calling setStreams. + // this is populated in closeAllStreamReaders + setStreamErr error } func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { @@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream { return c.streams[id] } -func (c *wsStreamCreator) setStream(id byte, s *stream) { +func (c *wsStreamCreator) setStream(id byte, s *stream) error { c.streamsMu.Lock() defer c.streamsMu.Unlock() + if c.setStreamErr != nil { + return c.setStreamErr + } c.streams[id] = s + return nil } // CreateStream uses id from passed headers to create a stream over "c.conn" connection. @@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, connWriteLock: &c.connWriteLock, id: id, } - c.setStream(id, s) + if err := c.setStream(id, s); err != nil { + _ = s.writePipe.Close() + _ = s.readPipe.Close() + return nil, err + } return s, nil } @@ -312,7 +323,7 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de } // closeAllStreamReaders closes readers in all streams. -// This unblocks all stream.Read() calls. +// This unblocks all stream.Read() calls, and keeps any future streams from being created. func (c *wsStreamCreator) closeAllStreamReaders(err error) { c.streamsMu.Lock() defer c.streamsMu.Unlock() @@ -320,6 +331,12 @@ func (c *wsStreamCreator) closeAllStreamReaders(err error) { // Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes. _ = s.writePipe.CloseWithError(err) } + // ensure callers to setStreams receive an error after this point + if err != nil { + c.setStreamErr = err + } else { + c.setStreamErr = fmt.Errorf("closed all streams") + } } type stream struct { diff --git a/tools/remotecommand/websocket_test.go b/tools/remotecommand/websocket_test.go index a6a1baf567..4a333b0b24 100644 --- a/tools/remotecommand/websocket_test.go +++ b/tools/remotecommand/websocket_test.go @@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) { wg.Wait() } +func TestLateStreamCreation(t *testing.T) { + c := newWSStreamCreator(nil) + c.closeAllStreamReaders(nil) + if err := c.setStream(0, nil); err == nil { + t.Fatal("expected error adding stream after closeAllStreamReaders") + } +} + func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) { // Validate Stream functions. c := newWSStreamCreator(nil)