Skip to content

Commit 43cd707

Browse files
committed
Merge branch 'sukun/stream-error-code' into sukun/conn-error-2
2 parents 5727def + 9190b78 commit 43cd707

File tree

4 files changed

+102
-11
lines changed

4 files changed

+102
-11
lines changed

const.go

+21
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,27 @@ func (e *GoAwayError) Is(target error) bool {
5757
return false
5858
}
5959

60+
// A StreamError is used for errors returned from Read and Write calls after the stream is Reset
61+
type StreamError struct {
62+
ErrorCode uint32
63+
Remote bool
64+
}
65+
66+
func (s *StreamError) Error() string {
67+
if s.Remote {
68+
return fmt.Sprintf("stream reset by remote, error code: %d", s.ErrorCode)
69+
}
70+
return fmt.Sprintf("stream reset, error code: %d", s.ErrorCode)
71+
}
72+
73+
func (s *StreamError) Is(target error) bool {
74+
if target == ErrStreamReset {
75+
return true
76+
}
77+
e, ok := target.(*StreamError)
78+
return ok && *e == *s
79+
}
80+
6081
var (
6182
// ErrInvalidVersion means we received a frame with an
6283
// invalid version

session.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro
334334
s.streamLock.Lock()
335335
defer s.streamLock.Unlock()
336336
for id, stream := range s.streams {
337-
stream.forceClose()
337+
stream.forceClose(fmt.Errorf("%w: connection closed: %w", ErrStreamReset, s.shutdownErr))
338338
delete(s.streams, id)
339339
stream.memorySpan.Done()
340340
}

session_test.go

+52-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"testing"
1717
"time"
1818

19+
"github.com/stretchr/testify/assert"
1920
"github.com/stretchr/testify/require"
2021
)
2122

@@ -1571,6 +1572,56 @@ func TestStreamResetRead(t *testing.T) {
15711572
wc.Wait()
15721573
}
15731574

1575+
func TestStreamResetWithError(t *testing.T) {
1576+
client, server := testClientServer()
1577+
defer client.Close()
1578+
defer server.Close()
1579+
1580+
wc := new(sync.WaitGroup)
1581+
wc.Add(2)
1582+
go func() {
1583+
defer wc.Done()
1584+
stream, err := server.AcceptStream()
1585+
if err != nil {
1586+
t.Error(err)
1587+
}
1588+
1589+
se := &StreamError{}
1590+
_, err = io.ReadAll(stream)
1591+
if !errors.As(err, &se) {
1592+
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
1593+
return
1594+
}
1595+
expected := &StreamError{Remote: true, ErrorCode: 42}
1596+
assert.Equal(t, se, expected)
1597+
}()
1598+
1599+
stream, err := client.OpenStream(context.Background())
1600+
if err != nil {
1601+
t.Error(err)
1602+
}
1603+
1604+
go func() {
1605+
defer wc.Done()
1606+
1607+
se := &StreamError{}
1608+
_, err := io.ReadAll(stream)
1609+
if !errors.As(err, &se) {
1610+
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
1611+
return
1612+
}
1613+
expected := &StreamError{Remote: false, ErrorCode: 42}
1614+
assert.Equal(t, se, expected)
1615+
}()
1616+
1617+
time.Sleep(1 * time.Second)
1618+
err = stream.ResetWithError(42)
1619+
if err != nil {
1620+
t.Fatal(err)
1621+
}
1622+
wc.Wait()
1623+
}
1624+
15741625
func TestLotsOfWritesWithStreamDeadline(t *testing.T) {
15751626
config := testConf()
15761627
config.EnableKeepAlive = false
@@ -1809,7 +1860,7 @@ func TestMaxIncomingStreams(t *testing.T) {
18091860
require.NoError(t, err)
18101861
str.SetDeadline(time.Now().Add(time.Second))
18111862
_, err = str.Read([]byte{0})
1812-
require.EqualError(t, err, "stream reset")
1863+
require.ErrorIs(t, err, ErrStreamReset)
18131864

18141865
// Now close one of the streams.
18151866
// This should then allow the client to open a new stream.

stream.go

+28-9
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ type Stream struct {
4141

4242
state streamState
4343
writeState, readState halfStreamState
44+
writeErr, readErr error
4445
stateLock sync.Mutex
4546

4647
recvBuf segmentedBuffer
@@ -89,6 +90,7 @@ func (s *Stream) Read(b []byte) (n int, err error) {
8990
START:
9091
s.stateLock.Lock()
9192
state := s.readState
93+
resetErr := s.readErr
9294
s.stateLock.Unlock()
9395

9496
switch state {
@@ -101,7 +103,7 @@ START:
101103
}
102104
// Closed, but we have data pending -> read.
103105
case halfReset:
104-
return 0, ErrStreamReset
106+
return 0, resetErr
105107
default:
106108
panic("unknown state")
107109
}
@@ -147,6 +149,7 @@ func (s *Stream) write(b []byte) (n int, err error) {
147149
START:
148150
s.stateLock.Lock()
149151
state := s.writeState
152+
resetErr := s.writeErr
150153
s.stateLock.Unlock()
151154

152155
switch state {
@@ -155,7 +158,7 @@ START:
155158
case halfClosed:
156159
return 0, ErrStreamClosed
157160
case halfReset:
158-
return 0, ErrStreamReset
161+
return 0, resetErr
159162
default:
160163
panic("unknown state")
161164
}
@@ -250,13 +253,17 @@ func (s *Stream) sendClose() error {
250253
}
251254

252255
// sendReset is used to send a RST
253-
func (s *Stream) sendReset() error {
254-
hdr := encode(typeWindowUpdate, flagRST, s.id, 0)
256+
func (s *Stream) sendReset(errCode uint32) error {
257+
hdr := encode(typeWindowUpdate, flagRST, s.id, errCode)
255258
return s.session.sendMsg(hdr, nil, nil)
256259
}
257260

258261
// Reset resets the stream (forcibly closes the stream)
259262
func (s *Stream) Reset() error {
263+
return s.ResetWithError(0)
264+
}
265+
266+
func (s *Stream) ResetWithError(errCode uint32) error {
260267
sendReset := false
261268
s.stateLock.Lock()
262269
switch s.state {
@@ -276,15 +283,17 @@ func (s *Stream) Reset() error {
276283
// If we've already sent/received an EOF, no need to reset that side.
277284
if s.writeState == halfOpen {
278285
s.writeState = halfReset
286+
s.writeErr = &StreamError{Remote: false, ErrorCode: errCode}
279287
}
280288
if s.readState == halfOpen {
281289
s.readState = halfReset
290+
s.readErr = &StreamError{Remote: false, ErrorCode: errCode}
282291
}
283292
s.state = streamFinished
284293
s.notifyWaiting()
285294
s.stateLock.Unlock()
286295
if sendReset {
287-
_ = s.sendReset()
296+
_ = s.sendReset(errCode)
288297
}
289298
s.cleanup()
290299
return nil
@@ -336,6 +345,7 @@ func (s *Stream) CloseRead() error {
336345
panic("invalid state")
337346
}
338347
s.readState = halfReset
348+
s.readErr = ErrStreamReset
339349
cleanup = s.writeState != halfOpen
340350
if cleanup {
341351
s.state = streamFinished
@@ -357,13 +367,15 @@ func (s *Stream) Close() error {
357367
}
358368

359369
// forceClose is used for when the session is exiting
360-
func (s *Stream) forceClose() {
370+
func (s *Stream) forceClose(err error) {
361371
s.stateLock.Lock()
362372
if s.readState == halfOpen {
363373
s.readState = halfReset
374+
s.readErr = err
364375
}
365376
if s.writeState == halfOpen {
366377
s.writeState = halfReset
378+
s.writeErr = err
367379
}
368380
s.state = streamFinished
369381
s.notifyWaiting()
@@ -382,7 +394,7 @@ func (s *Stream) cleanup() {
382394

383395
// processFlags is used to update the state of the stream
384396
// based on set flags, if any. Lock must be held
385-
func (s *Stream) processFlags(flags uint16) {
397+
func (s *Stream) processFlags(flags uint16, hdr header) {
386398
// Close the stream without holding the state lock
387399
var closeStream bool
388400
defer func() {
@@ -418,11 +430,18 @@ func (s *Stream) processFlags(flags uint16) {
418430
}
419431
if flags&flagRST == flagRST {
420432
s.stateLock.Lock()
433+
var resetErr error = ErrStreamReset
434+
// Length in a window update frame with RST flag encodes an error code.
435+
if hdr.MsgType() == typeWindowUpdate {
436+
resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()}
437+
}
421438
if s.readState == halfOpen {
422439
s.readState = halfReset
440+
s.readErr = resetErr
423441
}
424442
if s.writeState == halfOpen {
425443
s.writeState = halfReset
444+
s.writeErr = resetErr
426445
}
427446
s.state = streamFinished
428447
s.stateLock.Unlock()
@@ -439,15 +458,15 @@ func (s *Stream) notifyWaiting() {
439458

440459
// incrSendWindow updates the size of our send window
441460
func (s *Stream) incrSendWindow(hdr header, flags uint16) {
442-
s.processFlags(flags)
461+
s.processFlags(flags, hdr)
443462
// Increase window, unblock a sender
444463
atomic.AddUint32(&s.sendWindow, hdr.Length())
445464
asyncNotify(s.sendNotifyCh)
446465
}
447466

448467
// readData is used to handle a data frame
449468
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
450-
s.processFlags(flags)
469+
s.processFlags(flags, hdr)
451470

452471
// Check that our recv window is not exceeded
453472
length := hdr.Length()

0 commit comments

Comments
 (0)