Skip to content

Commit cc21122

Browse files
liggittk8s-publishing-bot
authored andcommitted
Keep streams from being set up after closeAllStreamReaders is called
Kubernetes-commit: efd8578ac75459df19e7589b2767fbdbbc288383
1 parent 8636987 commit cc21122

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

tools/remotecommand/websocket.go

+20-3
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ type wsStreamCreator struct {
187187
// map of stream id to stream; multiple streams read/write the connection
188188
streams map[byte]*stream
189189
streamsMu sync.Mutex
190+
// setStreamErr holds the error to return to anyone calling setStreams.
191+
// this is populated in closeAllStreamReaders
192+
setStreamErr error
190193
}
191194

192195
func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator {
@@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream {
202205
return c.streams[id]
203206
}
204207

205-
func (c *wsStreamCreator) setStream(id byte, s *stream) {
208+
func (c *wsStreamCreator) setStream(id byte, s *stream) error {
206209
c.streamsMu.Lock()
207210
defer c.streamsMu.Unlock()
211+
if c.setStreamErr != nil {
212+
return c.setStreamErr
213+
}
208214
c.streams[id] = s
215+
return nil
209216
}
210217

211218
// CreateStream uses id from passed headers to create a stream over "c.conn" connection.
@@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream,
228235
connWriteLock: &c.connWriteLock,
229236
id: id,
230237
}
231-
c.setStream(id, s)
238+
if err := c.setStream(id, s); err != nil {
239+
_ = s.writePipe.Close()
240+
_ = s.readPipe.Close()
241+
return nil, err
242+
}
232243
return s, nil
233244
}
234245

@@ -312,14 +323,20 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de
312323
}
313324

314325
// closeAllStreamReaders closes readers in all streams.
315-
// This unblocks all stream.Read() calls.
326+
// This unblocks all stream.Read() calls, and keeps any future streams from being created.
316327
func (c *wsStreamCreator) closeAllStreamReaders(err error) {
317328
c.streamsMu.Lock()
318329
defer c.streamsMu.Unlock()
319330
for _, s := range c.streams {
320331
// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes.
321332
_ = s.writePipe.CloseWithError(err)
322333
}
334+
// ensure callers to setStreams receive an error after this point
335+
if err != nil {
336+
c.setStreamErr = err
337+
} else {
338+
c.setStreamErr = fmt.Errorf("closed all streams")
339+
}
323340
}
324341

325342
type stream struct {

tools/remotecommand/websocket_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) {
11161116
wg.Wait()
11171117
}
11181118

1119+
func TestLateStreamCreation(t *testing.T) {
1120+
c := newWSStreamCreator(nil)
1121+
c.closeAllStreamReaders(nil)
1122+
if err := c.setStream(0, nil); err == nil {
1123+
t.Fatal("expected error adding stream after closeAllStreamReaders")
1124+
}
1125+
}
1126+
11191127
func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) {
11201128
// Validate Stream functions.
11211129
c := newWSStreamCreator(nil)

0 commit comments

Comments
 (0)