@@ -187,6 +187,9 @@ type wsStreamCreator struct {
187
187
// map of stream id to stream; multiple streams read/write the connection
188
188
streams map [byte ]* stream
189
189
streamsMu sync.Mutex
190
+ // setStreamErr holds the error to return to anyone calling setStreams.
191
+ // this is populated in closeAllStreamReaders
192
+ setStreamErr error
190
193
}
191
194
192
195
func newWSStreamCreator (conn * gwebsocket.Conn ) * wsStreamCreator {
@@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream {
202
205
return c .streams [id ]
203
206
}
204
207
205
- func (c * wsStreamCreator ) setStream (id byte , s * stream ) {
208
+ func (c * wsStreamCreator ) setStream (id byte , s * stream ) error {
206
209
c .streamsMu .Lock ()
207
210
defer c .streamsMu .Unlock ()
211
+ if c .setStreamErr != nil {
212
+ return c .setStreamErr
213
+ }
208
214
c .streams [id ] = s
215
+ return nil
209
216
}
210
217
211
218
// CreateStream uses id from passed headers to create a stream over "c.conn" connection.
@@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream,
228
235
connWriteLock : & c .connWriteLock ,
229
236
id : id ,
230
237
}
231
- c .setStream (id , s )
238
+ if err := c .setStream (id , s ); err != nil {
239
+ _ = s .writePipe .Close ()
240
+ _ = s .readPipe .Close ()
241
+ return nil , err
242
+ }
232
243
return s , nil
233
244
}
234
245
@@ -312,14 +323,20 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de
312
323
}
313
324
314
325
// closeAllStreamReaders closes readers in all streams.
315
- // This unblocks all stream.Read() calls.
326
+ // This unblocks all stream.Read() calls, and keeps any future streams from being created .
316
327
func (c * wsStreamCreator ) closeAllStreamReaders (err error ) {
317
328
c .streamsMu .Lock ()
318
329
defer c .streamsMu .Unlock ()
319
330
for _ , s := range c .streams {
320
331
// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes.
321
332
_ = s .writePipe .CloseWithError (err )
322
333
}
334
+ // ensure callers to setStreams receive an error after this point
335
+ if err != nil {
336
+ c .setStreamErr = err
337
+ } else {
338
+ c .setStreamErr = fmt .Errorf ("closed all streams" )
339
+ }
323
340
}
324
341
325
342
type stream struct {
0 commit comments