Skip to content

Commit c76f5d7

Browse files
authored
add support for sending error codes on session close (#121)
1 parent c60349b commit c76f5d7

File tree

4 files changed

+254
-57
lines changed

4 files changed

+254
-57
lines changed

const.go

+64-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package yamux
33
import (
44
"encoding/binary"
55
"fmt"
6+
"time"
67
)
78

89
type Error struct {
@@ -22,6 +23,64 @@ func (ye *Error) Temporary() bool {
2223
return ye.temporary
2324
}
2425

26+
type GoAwayError struct {
27+
ErrorCode uint32
28+
Remote bool
29+
}
30+
31+
func (e *GoAwayError) Error() string {
32+
if e.Remote {
33+
return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode)
34+
}
35+
return fmt.Sprintf("sent go away, code: %d", e.ErrorCode)
36+
}
37+
38+
func (e *GoAwayError) Timeout() bool {
39+
return false
40+
}
41+
42+
func (e *GoAwayError) Temporary() bool {
43+
return false
44+
}
45+
46+
func (e *GoAwayError) Is(target error) bool {
47+
// to maintain compatibility with errors returned by previous versions
48+
if e.Remote && target == ErrRemoteGoAway {
49+
return true
50+
} else if !e.Remote && target == ErrSessionShutdown {
51+
return true
52+
} else if target == ErrStreamReset {
53+
// A GoAway on a connection also resets all the streams.
54+
return true
55+
}
56+
57+
if err, ok := target.(*GoAwayError); ok {
58+
return *e == *err
59+
}
60+
return false
61+
}
62+
63+
// A StreamError is used for errors returned from Read and Write calls after the stream is Reset
64+
type StreamError struct {
65+
ErrorCode uint32
66+
Remote bool
67+
}
68+
69+
func (s *StreamError) Error() string {
70+
if s.Remote {
71+
return fmt.Sprintf("stream reset by remote, error code: %d", s.ErrorCode)
72+
}
73+
return fmt.Sprintf("stream reset, error code: %d", s.ErrorCode)
74+
}
75+
76+
func (s *StreamError) Is(target error) bool {
77+
if target == ErrStreamReset {
78+
return true
79+
}
80+
e, ok := target.(*StreamError)
81+
return ok && *e == *s
82+
}
83+
2584
var (
2685
// ErrInvalidVersion means we received a frame with an
2786
// invalid version
@@ -33,7 +92,7 @@ var (
3392

3493
// ErrSessionShutdown is used if there is a shutdown during
3594
// an operation
36-
ErrSessionShutdown = &Error{msg: "session shutdown"}
95+
ErrSessionShutdown = &GoAwayError{ErrorCode: goAwayNormal, Remote: false}
3796

3897
// ErrStreamsExhausted is returned if we have no more
3998
// stream ids to issue
@@ -55,8 +114,9 @@ var (
55114
// ErrUnexpectedFlag is set when we get an unexpected flag
56115
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}
57116

58-
// ErrRemoteGoAway is used when we get a go away from the other side
59-
ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"}
117+
// ErrRemoteGoAway is used when we get a go away from the other side with error code
118+
// goAwayNormal(0).
119+
ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}
60120

61121
// ErrStreamReset is sent if a stream is reset. This can happen
62122
// if the backlog is exceeded, or if there was a remote GoAway.
@@ -117,6 +177,7 @@ const (
117177
// It's not an implementation choice, the value defined in the specification.
118178
initialStreamWindow = 256 * 1024
119179
maxStreamWindow = 16 * 1024 * 1024
180+
goAwayWaitTime = 100 * time.Millisecond
120181
)
121182

122183
const (

session.go

+68-34
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ var nullMemoryManager = &nullMemoryManagerImpl{}
4646
type Session struct {
4747
rtt int64 // to be accessed atomically, in nanoseconds
4848

49-
// remoteGoAway indicates the remote side does
50-
// not want futher connections. Must be first for alignment.
51-
remoteGoAway int32
52-
5349
// localGoAway indicates that we should stop
5450
// accepting futher connections. Must be first for alignment.
5551
localGoAway int32
@@ -102,6 +98,8 @@ type Session struct {
10298
// recvDoneCh is closed when recv() exits to avoid a race
10399
// between stream registration and stream shutdown
104100
recvDoneCh chan struct{}
101+
// recvErr is the error the receive loop ended with
102+
recvErr error
105103

106104
// sendDoneCh is closed when send() exits to avoid a race
107105
// between returning from a Stream.Write and exiting from the send loop
@@ -203,9 +201,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
203201
if s.IsClosed() {
204202
return nil, s.shutdownErr
205203
}
206-
if atomic.LoadInt32(&s.remoteGoAway) == 1 {
207-
return nil, ErrRemoteGoAway
208-
}
209204

210205
// Block if we have too many inflight SYNs
211206
select {
@@ -283,9 +278,23 @@ func (s *Session) AcceptStream() (*Stream, error) {
283278
}
284279
}
285280

286-
// Close is used to close the session and all streams.
287-
// Attempts to send a GoAway before closing the connection.
281+
// Close is used to close the session and all streams. It doesn't send a GoAway before
282+
// closing the connection.
288283
func (s *Session) Close() error {
284+
return s.close(ErrSessionShutdown, false, goAwayNormal)
285+
}
286+
287+
// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
288+
// Blocks for ConnectionWriteTimeout to write the GoAway message.
289+
//
290+
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
291+
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
292+
// receive buffer.
293+
func (s *Session) CloseWithError(errCode uint32) error {
294+
return s.close(&GoAwayError{Remote: false, ErrorCode: errCode}, true, errCode)
295+
}
296+
297+
func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) error {
289298
s.shutdownLock.Lock()
290299
defer s.shutdownLock.Unlock()
291300

@@ -294,35 +303,42 @@ func (s *Session) Close() error {
294303
}
295304
s.shutdown = true
296305
if s.shutdownErr == nil {
297-
s.shutdownErr = ErrSessionShutdown
306+
s.shutdownErr = shutdownErr
298307
}
299308
close(s.shutdownCh)
300-
s.conn.Close()
301309
s.stopKeepalive()
302-
<-s.recvDoneCh
310+
311+
// Only send GoAway if we have an error code.
312+
if sendGoAway && errCode != goAwayNormal {
313+
// wait for write loop to exit
314+
// We need to write the current frame completely before sending a goaway.
315+
// This will wait for at most s.config.ConnectionWriteTimeout
316+
<-s.sendDoneCh
317+
ga := s.goAway(errCode)
318+
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
319+
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
320+
}
321+
s.conn.SetWriteDeadline(time.Time{})
322+
}
323+
324+
s.conn.Close()
303325
<-s.sendDoneCh
326+
<-s.recvDoneCh
304327

328+
resetErr := shutdownErr
329+
if _, ok := resetErr.(*GoAwayError); !ok {
330+
resetErr = fmt.Errorf("%w: connection closed: %w", ErrStreamReset, shutdownErr)
331+
}
305332
s.streamLock.Lock()
306333
defer s.streamLock.Unlock()
307334
for id, stream := range s.streams {
308-
stream.forceClose()
335+
stream.forceClose(resetErr)
309336
delete(s.streams, id)
310337
stream.memorySpan.Done()
311338
}
312339
return nil
313340
}
314341

315-
// exitErr is used to handle an error that is causing the
316-
// session to terminate.
317-
func (s *Session) exitErr(err error) {
318-
s.shutdownLock.Lock()
319-
if s.shutdownErr == nil {
320-
s.shutdownErr = err
321-
}
322-
s.shutdownLock.Unlock()
323-
s.Close()
324-
}
325-
326342
// GoAway can be used to prevent accepting further
327343
// connections. It does not close the underlying conn.
328344
func (s *Session) GoAway() error {
@@ -451,7 +467,7 @@ func (s *Session) startKeepalive() {
451467

452468
if err != nil {
453469
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
454-
s.exitErr(ErrKeepAliveTimeout)
470+
s.close(ErrKeepAliveTimeout, false, 0)
455471
}
456472
})
457473
}
@@ -516,7 +532,25 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
516532
// send is a long running goroutine that sends data
517533
func (s *Session) send() {
518534
if err := s.sendLoop(); err != nil {
519-
s.exitErr(err)
535+
// If we are shutting down because remote closed the connection, prefer the recvLoop error
536+
// over the sendLoop error. The receive loop might have error code received in a GoAway frame,
537+
// which was received just before the TCP RST that closed the sendLoop.
538+
//
539+
// If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop.
540+
// We hold the shutdownLock, close the connection, and wait for the receive loop to finish and
541+
// use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close
542+
// but the sendLoop does.
543+
s.shutdownLock.Lock()
544+
if s.shutdownErr == nil {
545+
s.conn.Close()
546+
<-s.recvDoneCh
547+
if _, ok := s.recvErr.(*GoAwayError); ok {
548+
err = s.recvErr
549+
}
550+
s.shutdownErr = err
551+
}
552+
s.shutdownLock.Unlock()
553+
s.close(err, false, 0)
520554
}
521555
}
522556

@@ -644,7 +678,7 @@ func (s *Session) sendLoop() (err error) {
644678
// recv is a long running goroutine that accepts new data
645679
func (s *Session) recv() {
646680
if err := s.recvLoop(); err != nil {
647-
s.exitErr(err)
681+
s.close(err, false, 0)
648682
}
649683
}
650684

@@ -666,7 +700,10 @@ func (s *Session) recvLoop() (err error) {
666700
err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
667701
}
668702
}()
669-
defer close(s.recvDoneCh)
703+
defer func() {
704+
s.recvErr = err
705+
close(s.recvDoneCh)
706+
}()
670707
var hdr header
671708
for {
672709
// fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -781,18 +818,15 @@ func (s *Session) handleGoAway(hdr header) error {
781818
code := hdr.Length()
782819
switch code {
783820
case goAwayNormal:
784-
atomic.SwapInt32(&s.remoteGoAway, 1)
821+
return ErrRemoteGoAway
785822
case goAwayProtoErr:
786823
s.logger.Printf("[ERR] yamux: received protocol error go away")
787-
return fmt.Errorf("yamux protocol error")
788824
case goAwayInternalErr:
789825
s.logger.Printf("[ERR] yamux: received internal error go away")
790-
return fmt.Errorf("remote yamux internal error")
791826
default:
792-
s.logger.Printf("[ERR] yamux: received unexpected go away")
793-
return fmt.Errorf("unexpected go away received")
827+
s.logger.Printf("[ERR] yamux: received go away with error code: %d", code)
794828
}
795-
return nil
829+
return &GoAwayError{Remote: true, ErrorCode: code}
796830
}
797831

798832
// incomingStream is used to create a new incoming stream

0 commit comments

Comments
 (0)