@@ -46,10 +46,6 @@ var nullMemoryManager = &nullMemoryManagerImpl{}
46
46
type Session struct {
47
47
rtt int64 // to be accessed atomically, in nanoseconds
48
48
49
- // remoteGoAway indicates the remote side does
50
- // not want futher connections. Must be first for alignment.
51
- remoteGoAway int32
52
-
53
49
// localGoAway indicates that we should stop
54
50
// accepting futher connections. Must be first for alignment.
55
51
localGoAway int32
@@ -102,6 +98,8 @@ type Session struct {
102
98
// recvDoneCh is closed when recv() exits to avoid a race
103
99
// between stream registration and stream shutdown
104
100
recvDoneCh chan struct {}
101
+ // recvErr is the error the receive loop ended with
102
+ recvErr error
105
103
106
104
// sendDoneCh is closed when send() exits to avoid a race
107
105
// between returning from a Stream.Write and exiting from the send loop
@@ -203,9 +201,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
203
201
if s .IsClosed () {
204
202
return nil , s .shutdownErr
205
203
}
206
- if atomic .LoadInt32 (& s .remoteGoAway ) == 1 {
207
- return nil , ErrRemoteGoAway
208
- }
209
204
210
205
// Block if we have too many inflight SYNs
211
206
select {
@@ -283,9 +278,23 @@ func (s *Session) AcceptStream() (*Stream, error) {
283
278
}
284
279
}
285
280
286
- // Close is used to close the session and all streams.
287
- // Attempts to send a GoAway before closing the connection.
281
+ // Close is used to close the session and all streams. It doesn't send a GoAway before
282
+ // closing the connection.
288
283
func (s * Session ) Close () error {
284
+ return s .close (ErrSessionShutdown , false , goAwayNormal )
285
+ }
286
+
287
+ // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
288
+ // Blocks for ConnectionWriteTimeout to write the GoAway message.
289
+ //
290
+ // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
291
+ // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
292
+ // receive buffer.
293
+ func (s * Session ) CloseWithError (errCode uint32 ) error {
294
+ return s .close (& GoAwayError {Remote : false , ErrorCode : errCode }, true , errCode )
295
+ }
296
+
297
+ func (s * Session ) close (shutdownErr error , sendGoAway bool , errCode uint32 ) error {
289
298
s .shutdownLock .Lock ()
290
299
defer s .shutdownLock .Unlock ()
291
300
@@ -294,35 +303,42 @@ func (s *Session) Close() error {
294
303
}
295
304
s .shutdown = true
296
305
if s .shutdownErr == nil {
297
- s .shutdownErr = ErrSessionShutdown
306
+ s .shutdownErr = shutdownErr
298
307
}
299
308
close (s .shutdownCh )
300
- s .conn .Close ()
301
309
s .stopKeepalive ()
302
- <- s .recvDoneCh
310
+
311
+ // Only send GoAway if we have an error code.
312
+ if sendGoAway && errCode != goAwayNormal {
313
+ // wait for write loop to exit
314
+ // We need to write the current frame completely before sending a goaway.
315
+ // This will wait for at most s.config.ConnectionWriteTimeout
316
+ <- s .sendDoneCh
317
+ ga := s .goAway (errCode )
318
+ if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
319
+ _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
320
+ }
321
+ s .conn .SetWriteDeadline (time.Time {})
322
+ }
323
+
324
+ s .conn .Close ()
303
325
<- s .sendDoneCh
326
+ <- s .recvDoneCh
304
327
328
+ resetErr := shutdownErr
329
+ if _ , ok := resetErr .(* GoAwayError ); ! ok {
330
+ resetErr = fmt .Errorf ("%w: connection closed: %w" , ErrStreamReset , shutdownErr )
331
+ }
305
332
s .streamLock .Lock ()
306
333
defer s .streamLock .Unlock ()
307
334
for id , stream := range s .streams {
308
- stream .forceClose ()
335
+ stream .forceClose (resetErr )
309
336
delete (s .streams , id )
310
337
stream .memorySpan .Done ()
311
338
}
312
339
return nil
313
340
}
314
341
315
- // exitErr is used to handle an error that is causing the
316
- // session to terminate.
317
- func (s * Session ) exitErr (err error ) {
318
- s .shutdownLock .Lock ()
319
- if s .shutdownErr == nil {
320
- s .shutdownErr = err
321
- }
322
- s .shutdownLock .Unlock ()
323
- s .Close ()
324
- }
325
-
326
342
// GoAway can be used to prevent accepting further
327
343
// connections. It does not close the underlying conn.
328
344
func (s * Session ) GoAway () error {
@@ -451,7 +467,7 @@ func (s *Session) startKeepalive() {
451
467
452
468
if err != nil {
453
469
s .logger .Printf ("[ERR] yamux: keepalive failed: %v" , err )
454
- s .exitErr (ErrKeepAliveTimeout )
470
+ s .close (ErrKeepAliveTimeout , false , 0 )
455
471
}
456
472
})
457
473
}
@@ -516,7 +532,25 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
516
532
// send is a long running goroutine that sends data
517
533
func (s * Session ) send () {
518
534
if err := s .sendLoop (); err != nil {
519
- s .exitErr (err )
535
+ // If we are shutting down because remote closed the connection, prefer the recvLoop error
536
+ // over the sendLoop error. The receive loop might have error code received in a GoAway frame,
537
+ // which was received just before the TCP RST that closed the sendLoop.
538
+ //
539
+ // If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop.
540
+ // We hold the shutdownLock, close the connection, and wait for the receive loop to finish and
541
+ // use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close
542
+ // but the sendLoop does.
543
+ s .shutdownLock .Lock ()
544
+ if s .shutdownErr == nil {
545
+ s .conn .Close ()
546
+ <- s .recvDoneCh
547
+ if _ , ok := s .recvErr .(* GoAwayError ); ok {
548
+ err = s .recvErr
549
+ }
550
+ s .shutdownErr = err
551
+ }
552
+ s .shutdownLock .Unlock ()
553
+ s .close (err , false , 0 )
520
554
}
521
555
}
522
556
@@ -644,7 +678,7 @@ func (s *Session) sendLoop() (err error) {
644
678
// recv is a long running goroutine that accepts new data
645
679
func (s * Session ) recv () {
646
680
if err := s .recvLoop (); err != nil {
647
- s .exitErr (err )
681
+ s .close (err , false , 0 )
648
682
}
649
683
}
650
684
@@ -666,7 +700,10 @@ func (s *Session) recvLoop() (err error) {
666
700
err = fmt .Errorf ("panic in yamux receive loop: %s" , rerr )
667
701
}
668
702
}()
669
- defer close (s .recvDoneCh )
703
+ defer func () {
704
+ s .recvErr = err
705
+ close (s .recvDoneCh )
706
+ }()
670
707
var hdr header
671
708
for {
672
709
// fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -781,18 +818,15 @@ func (s *Session) handleGoAway(hdr header) error {
781
818
code := hdr .Length ()
782
819
switch code {
783
820
case goAwayNormal :
784
- atomic . SwapInt32 ( & s . remoteGoAway , 1 )
821
+ return ErrRemoteGoAway
785
822
case goAwayProtoErr :
786
823
s .logger .Printf ("[ERR] yamux: received protocol error go away" )
787
- return fmt .Errorf ("yamux protocol error" )
788
824
case goAwayInternalErr :
789
825
s .logger .Printf ("[ERR] yamux: received internal error go away" )
790
- return fmt .Errorf ("remote yamux internal error" )
791
826
default :
792
- s .logger .Printf ("[ERR] yamux: received unexpected go away" )
793
- return fmt .Errorf ("unexpected go away received" )
827
+ s .logger .Printf ("[ERR] yamux: received go away with error code: %d" , code )
794
828
}
795
- return nil
829
+ return & GoAwayError { Remote : true , ErrorCode : code }
796
830
}
797
831
798
832
// incomingStream is used to create a new incoming stream
0 commit comments