Skip to content

Commit efa52bc

Browse files
committed
add CloseWithError
1 parent d8cf4e7 commit efa52bc

File tree

2 files changed

+90
-22
lines changed

2 files changed

+90
-22
lines changed

session.go

+60-22
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ type Session struct {
102102
// recvDoneCh is closed when recv() exits to avoid a race
103103
// between stream registration and stream shutdown
104104
recvDoneCh chan struct{}
105+
// recvErr is the error the receive loop ended with
106+
recvErr error
105107

106108
// sendDoneCh is closed when send() exits to avoid a race
107109
// between returning from a Stream.Write and exiting from the send loop
@@ -288,10 +290,18 @@ func (s *Session) AcceptStream() (*Stream, error) {
288290
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
289291
// if there's unread data in the kernel receive buffer.
290292
func (s *Session) Close() error {
291-
return s.close(true, goAwayNormal)
293+
return s.closeWithGoAway(goAwayNormal)
292294
}
293295

294-
func (s *Session) close(sendGoAway bool, errCode uint32) error {
296+
// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
297+
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
298+
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
299+
// receive buffer.
300+
func (s *Session) CloseWithError(errCode uint32) error {
301+
return s.closeWithGoAway(errCode)
302+
}
303+
304+
func (s *Session) closeWithGoAway(errCode uint32) error {
295305
s.shutdownLock.Lock()
296306
defer s.shutdownLock.Unlock()
297307

@@ -308,14 +318,12 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
308318
// wait for write loop to exit
309319
_ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked
310320
<-s.sendDoneCh
311-
if sendGoAway {
312-
ga := s.goAway(errCode)
313-
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
314-
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
315-
}
321+
ga := s.goAway(errCode)
322+
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
323+
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
316324
}
317-
318325
s.conn.SetWriteDeadline(time.Time{})
326+
319327
s.conn.Close()
320328
<-s.recvDoneCh
321329

@@ -329,15 +337,37 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
329337
return nil
330338
}
331339

332-
// exitErr is used to handle an error that is causing the
333-
// session to terminate.
334-
func (s *Session) exitErr(err error) {
340+
func (s *Session) closeWithoutGoAway(err error) error {
335341
s.shutdownLock.Lock()
342+
defer s.shutdownLock.Unlock()
343+
if s.shutdown {
344+
return nil
345+
}
346+
s.shutdown = true
336347
if s.shutdownErr == nil {
337348
s.shutdownErr = err
338349
}
339-
s.shutdownLock.Unlock()
340-
s.close(false, 0)
350+
351+
s.conn.Close()
352+
<-s.recvDoneCh
353+
// Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
354+
// received in a GoAway frame received just before the RST that closed the sendLoop
355+
if _, ok := s.recvErr.(*GoAwayError); ok {
356+
s.shutdownErr = s.recvErr
357+
}
358+
close(s.shutdownCh)
359+
360+
s.stopKeepalive()
361+
<-s.sendDoneCh
362+
363+
s.streamLock.Lock()
364+
defer s.streamLock.Unlock()
365+
for id, stream := range s.streams {
366+
stream.forceClose()
367+
delete(s.streams, id)
368+
stream.memorySpan.Done()
369+
}
370+
return nil
341371
}
342372

343373
// GoAway can be used to prevent accepting further
@@ -468,7 +498,12 @@ func (s *Session) startKeepalive() {
468498

469499
if err != nil {
470500
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
471-
s.exitErr(ErrKeepAliveTimeout)
501+
s.shutdownLock.Lock()
502+
if s.shutdownErr == nil {
503+
s.shutdownErr = ErrKeepAliveTimeout
504+
}
505+
s.shutdownLock.Unlock()
506+
s.closeWithGoAway(goAwayNormal)
472507
}
473508
})
474509
}
@@ -533,7 +568,7 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
533568
// send is a long running goroutine that sends data
534569
func (s *Session) send() {
535570
if err := s.sendLoop(); err != nil {
536-
s.exitErr(err)
571+
s.closeWithoutGoAway(err)
537572
}
538573
}
539574

@@ -661,7 +696,7 @@ func (s *Session) sendLoop() (err error) {
661696
// recv is a long running goroutine that accepts new data
662697
func (s *Session) recv() {
663698
if err := s.recvLoop(); err != nil {
664-
s.exitErr(err)
699+
s.closeWithoutGoAway(err)
665700
}
666701
}
667702

@@ -683,7 +718,10 @@ func (s *Session) recvLoop() (err error) {
683718
err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
684719
}
685720
}()
686-
defer close(s.recvDoneCh)
721+
defer func() {
722+
s.recvErr = err
723+
close(s.recvDoneCh)
724+
}()
687725
var hdr header
688726
for {
689727
// fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -799,17 +837,17 @@ func (s *Session) handleGoAway(hdr header) error {
799837
switch code {
800838
case goAwayNormal:
801839
atomic.SwapInt32(&s.remoteGoAway, 1)
840+
// Don't close connection on normal go away. Let the existing streams
841+
// complete gracefully.
842+
return nil
802843
case goAwayProtoErr:
803844
s.logger.Printf("[ERR] yamux: received protocol error go away")
804-
return fmt.Errorf("yamux protocol error")
805845
case goAwayInternalErr:
806846
s.logger.Printf("[ERR] yamux: received internal error go away")
807-
return fmt.Errorf("remote yamux internal error")
808847
default:
809-
s.logger.Printf("[ERR] yamux: received unexpected go away")
810-
return fmt.Errorf("unexpected go away received")
848+
s.logger.Printf("[ERR] yamux: received go away with error code: %d", code)
811849
}
812-
return nil
850+
return &GoAwayError{Remote: true, ErrorCode: code}
813851
}
814852

815853
// incomingStream is used to create a new incoming stream

session_test.go

+30
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package yamux
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"fmt"
78
"io"
89
"math/rand"
@@ -650,6 +651,35 @@ func TestGoAway(t *testing.T) {
650651
default:
651652
t.Fatalf("err: %v", err)
652653
}
654+
time.Sleep(50 * time.Millisecond)
655+
}
656+
t.Fatalf("expected GoAway error")
657+
}
658+
659+
func TestCloseWithError(t *testing.T) {
660+
// This test is noisy.
661+
conf := testConf()
662+
conf.LogOutput = io.Discard
663+
664+
client, server := testClientServerConfig(conf)
665+
defer client.Close()
666+
defer server.Close()
667+
668+
if err := server.CloseWithError(42); err != nil {
669+
t.Fatalf("err: %v", err)
670+
}
671+
672+
for i := 0; i < 100; i++ {
673+
s, err := client.Open(context.Background())
674+
if err == nil {
675+
s.Close()
676+
time.Sleep(50 * time.Millisecond)
677+
continue
678+
}
679+
if !errors.Is(err, &GoAwayError{ErrorCode: 42, Remote: true}) {
680+
t.Fatalf("err: %v", err)
681+
}
682+
return
653683
}
654684
t.Fatalf("expected GoAway error")
655685
}

0 commit comments

Comments
 (0)