Skip to content

Commit

Permalink
Keep streams from being set up after closeAllStreamReaders is called
Browse files Browse the repository at this point in the history
Kubernetes-commit: efd8578ac75459df19e7589b2767fbdbbc288383
  • Loading branch information
liggitt authored and k8s-publishing-bot committed Feb 29, 2024
1 parent 8636987 commit cc21122
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
23 changes: 20 additions & 3 deletions tools/remotecommand/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ type wsStreamCreator struct {
// map of stream id to stream; multiple streams read/write the connection
streams map[byte]*stream
streamsMu sync.Mutex
// setStreamErr holds the error to return to anyone calling setStreams.
// this is populated in closeAllStreamReaders
setStreamErr error
}

func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator {
Expand All @@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream {
return c.streams[id]
}

func (c *wsStreamCreator) setStream(id byte, s *stream) {
func (c *wsStreamCreator) setStream(id byte, s *stream) error {
c.streamsMu.Lock()
defer c.streamsMu.Unlock()
if c.setStreamErr != nil {
return c.setStreamErr
}
c.streams[id] = s
return nil
}

// CreateStream uses id from passed headers to create a stream over "c.conn" connection.
Expand All @@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream,
connWriteLock: &c.connWriteLock,
id: id,
}
c.setStream(id, s)
if err := c.setStream(id, s); err != nil {
_ = s.writePipe.Close()
_ = s.readPipe.Close()
return nil, err
}
return s, nil
}

Expand Down Expand Up @@ -312,14 +323,20 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de
}

// closeAllStreamReaders closes readers in all streams.
// This unblocks all stream.Read() calls.
// This unblocks all stream.Read() calls, and keeps any future streams from being created.
func (c *wsStreamCreator) closeAllStreamReaders(err error) {
c.streamsMu.Lock()
defer c.streamsMu.Unlock()
for _, s := range c.streams {
// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes.
_ = s.writePipe.CloseWithError(err)
}
// ensure callers to setStreams receive an error after this point
if err != nil {
c.setStreamErr = err
} else {
c.setStreamErr = fmt.Errorf("closed all streams")
}
}

type stream struct {
Expand Down
8 changes: 8 additions & 0 deletions tools/remotecommand/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) {
wg.Wait()
}

func TestLateStreamCreation(t *testing.T) {
c := newWSStreamCreator(nil)
c.closeAllStreamReaders(nil)
if err := c.setStream(0, nil); err == nil {
t.Fatal("expected error adding stream after closeAllStreamReaders")
}
}

func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) {
// Validate Stream functions.
c := newWSStreamCreator(nil)
Expand Down

0 comments on commit cc21122

Please sign in to comment.