@@ -102,6 +102,8 @@ type Session struct {
102
102
// recvDoneCh is closed when recv() exits to avoid a race
103
103
// between stream registration and stream shutdown
104
104
recvDoneCh chan struct {}
105
+ // recvErr is the error the receive loop ended with
106
+ recvErr error
105
107
106
108
// sendDoneCh is closed when send() exits to avoid a race
107
109
// between returning from a Stream.Write and exiting from the send loop
@@ -288,10 +290,18 @@ func (s *Session) AcceptStream() (*Stream, error) {
288
290
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
289
291
// if there's unread data in the kernel receive buffer.
290
292
func (s * Session ) Close () error {
291
- return s .close ( true , goAwayNormal )
293
+ return s .closeWithGoAway ( goAwayNormal )
292
294
}
293
295
294
- func (s * Session ) close (sendGoAway bool , errCode uint32 ) error {
296
+ // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
297
+ // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
298
+ // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
299
+ // receive buffer.
300
+ func (s * Session ) CloseWithError (errCode uint32 ) error {
301
+ return s .closeWithGoAway (errCode )
302
+ }
303
+
304
+ func (s * Session ) closeWithGoAway (errCode uint32 ) error {
295
305
s .shutdownLock .Lock ()
296
306
defer s .shutdownLock .Unlock ()
297
307
@@ -308,14 +318,12 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
308
318
// wait for write loop to exit
309
319
_ = s .conn .SetWriteDeadline (time .Now ().Add (- 1 * time .Hour )) // if SetWriteDeadline errored, any blocked writes will be unblocked
310
320
<- s .sendDoneCh
311
- if sendGoAway {
312
- ga := s .goAway (errCode )
313
- if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
314
- _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
315
- }
321
+ ga := s .goAway (errCode )
322
+ if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
323
+ _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
316
324
}
317
-
318
325
s .conn .SetWriteDeadline (time.Time {})
326
+
319
327
s .conn .Close ()
320
328
<- s .recvDoneCh
321
329
@@ -329,15 +337,37 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
329
337
return nil
330
338
}
331
339
332
- // exitErr is used to handle an error that is causing the
333
- // session to terminate.
334
- func (s * Session ) exitErr (err error ) {
340
+ func (s * Session ) closeWithoutGoAway (err error ) error {
335
341
s .shutdownLock .Lock ()
342
+ defer s .shutdownLock .Unlock ()
343
+ if s .shutdown {
344
+ return nil
345
+ }
346
+ s .shutdown = true
336
347
if s .shutdownErr == nil {
337
348
s .shutdownErr = err
338
349
}
339
- s .shutdownLock .Unlock ()
340
- s .close (false , 0 )
350
+
351
+ s .conn .Close ()
352
+ <- s .recvDoneCh
353
+ // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
354
+ // received in a GoAway frame received just before the RST that closed the sendLoop
355
+ if _ , ok := s .recvErr .(* GoAwayError ); ok {
356
+ s .shutdownErr = s .recvErr
357
+ }
358
+ close (s .shutdownCh )
359
+
360
+ s .stopKeepalive ()
361
+ <- s .sendDoneCh
362
+
363
+ s .streamLock .Lock ()
364
+ defer s .streamLock .Unlock ()
365
+ for id , stream := range s .streams {
366
+ stream .forceClose ()
367
+ delete (s .streams , id )
368
+ stream .memorySpan .Done ()
369
+ }
370
+ return nil
341
371
}
342
372
343
373
// GoAway can be used to prevent accepting further
@@ -468,7 +498,12 @@ func (s *Session) startKeepalive() {
468
498
469
499
if err != nil {
470
500
s .logger .Printf ("[ERR] yamux: keepalive failed: %v" , err )
471
- s .exitErr (ErrKeepAliveTimeout )
501
+ s .shutdownLock .Lock ()
502
+ if s .shutdownErr == nil {
503
+ s .shutdownErr = ErrKeepAliveTimeout
504
+ }
505
+ s .shutdownLock .Unlock ()
506
+ s .closeWithGoAway (goAwayNormal )
472
507
}
473
508
})
474
509
}
@@ -533,7 +568,7 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
533
568
// send is a long running goroutine that sends data
534
569
func (s * Session ) send () {
535
570
if err := s .sendLoop (); err != nil {
536
- s .exitErr (err )
571
+ s .closeWithoutGoAway (err )
537
572
}
538
573
}
539
574
@@ -661,7 +696,7 @@ func (s *Session) sendLoop() (err error) {
661
696
// recv is a long running goroutine that accepts new data
662
697
func (s * Session ) recv () {
663
698
if err := s .recvLoop (); err != nil {
664
- s .exitErr (err )
699
+ s .closeWithoutGoAway (err )
665
700
}
666
701
}
667
702
@@ -683,7 +718,10 @@ func (s *Session) recvLoop() (err error) {
683
718
err = fmt .Errorf ("panic in yamux receive loop: %s" , rerr )
684
719
}
685
720
}()
686
- defer close (s .recvDoneCh )
721
+ defer func () {
722
+ s .recvErr = err
723
+ close (s .recvDoneCh )
724
+ }()
687
725
var hdr header
688
726
for {
689
727
// fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -799,17 +837,17 @@ func (s *Session) handleGoAway(hdr header) error {
799
837
switch code {
800
838
case goAwayNormal :
801
839
atomic .SwapInt32 (& s .remoteGoAway , 1 )
840
+ // Don't close connection on normal go away. Let the existing streams
841
+ // complete gracefully.
842
+ return nil
802
843
case goAwayProtoErr :
803
844
s .logger .Printf ("[ERR] yamux: received protocol error go away" )
804
- return fmt .Errorf ("yamux protocol error" )
805
845
case goAwayInternalErr :
806
846
s .logger .Printf ("[ERR] yamux: received internal error go away" )
807
- return fmt .Errorf ("remote yamux internal error" )
808
847
default :
809
- s .logger .Printf ("[ERR] yamux: received unexpected go away" )
810
- return fmt .Errorf ("unexpected go away received" )
848
+ s .logger .Printf ("[ERR] yamux: received go away with error code: %d" , code )
811
849
}
812
- return nil
850
+ return & GoAwayError { Remote : true , ErrorCode : code }
813
851
}
814
852
815
853
// incomingStream is used to create a new incoming stream
0 commit comments