Skip to content

Commit af8e895

Browse files
committed
review comments
1 parent ede18a5 commit af8e895

File tree

3 files changed

+21
-20
lines changed

3 files changed

+21
-20
lines changed

session.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,14 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
535535
// send is a long running goroutine that sends data
536536
func (s *Session) send() {
537537
if err := s.sendLoop(); err != nil {
538-
// Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
539-
// received in a GoAway frame received just before the TCP RST that closed the sendLoop
538+
// If we are shutting down because remote closed the connection, prefer the recvLoop error
539+
// over the sendLoop error. The receive loop might have error code received in a GoAway frame,
540+
// which was received just before the TCP RST that closed the sendLoop.
541+
//
542+
// If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop.
543+
// We hold the shutdownLock, close the connection, and wait for the receive loop to finish and
544+
// use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close
545+
// but the sendLoop does.
540546
s.shutdownLock.Lock()
541547
if s.shutdownErr == nil {
542548
s.conn.Close()

session_test.go

+10-15
Original file line numberDiff line numberDiff line change
@@ -1578,7 +1578,7 @@ func TestStreamResetWithError(t *testing.T) {
15781578
defer server.Close()
15791579

15801580
wc := new(sync.WaitGroup)
1581-
wc.Add(2)
1581+
wc.Add(1)
15821582
go func() {
15831583
defer wc.Done()
15841584
stream, err := server.AcceptStream()
@@ -1589,7 +1589,7 @@ func TestStreamResetWithError(t *testing.T) {
15891589
se := &StreamError{}
15901590
_, err = io.ReadAll(stream)
15911591
if !errors.As(err, &se) {
1592-
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
1592+
t.Errorf("expected StreamError, got type:%T, err: %s", err, err)
15931593
return
15941594
}
15951595
expected := &StreamError{Remote: true, ErrorCode: 42}
@@ -1601,24 +1601,19 @@ func TestStreamResetWithError(t *testing.T) {
16011601
t.Error(err)
16021602
}
16031603

1604-
go func() {
1605-
defer wc.Done()
1606-
1607-
se := &StreamError{}
1608-
_, err := io.ReadAll(stream)
1609-
if !errors.As(err, &se) {
1610-
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
1611-
return
1612-
}
1613-
expected := &StreamError{Remote: false, ErrorCode: 42}
1614-
assert.Equal(t, se, expected)
1615-
}()
1616-
16171604
time.Sleep(1 * time.Second)
16181605
err = stream.ResetWithError(42)
16191606
if err != nil {
16201607
t.Fatal(err)
16211608
}
1609+
se := &StreamError{}
1610+
_, err = io.ReadAll(stream)
1611+
if !errors.As(err, &se) {
1612+
t.Errorf("expected StreamError, got type:%T, err: %s", err, err)
1613+
return
1614+
}
1615+
expected := &StreamError{Remote: false, ErrorCode: 42}
1616+
assert.Equal(t, se, expected)
16221617
wc.Wait()
16231618
}
16241619

stream.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ func (s *Stream) cleanup() {
395395

396396
// processFlags is used to update the state of the stream
397397
// based on set flags, if any. Lock must be held
398-
func (s *Stream) processFlags(flags uint16, hdr header) {
398+
func (s *Stream) processFlags(hdr header, flags uint16) {
399399
// Close the stream without holding the state lock
400400
var closeStream bool
401401
defer func() {
@@ -459,15 +459,15 @@ func (s *Stream) notifyWaiting() {
459459

460460
// incrSendWindow updates the size of our send window
461461
func (s *Stream) incrSendWindow(hdr header, flags uint16) {
462-
s.processFlags(flags, hdr)
462+
s.processFlags(hdr, flags)
463463
// Increase window, unblock a sender
464464
atomic.AddUint32(&s.sendWindow, hdr.Length())
465465
asyncNotify(s.sendNotifyCh)
466466
}
467467

468468
// readData is used to handle a data frame
469469
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
470-
s.processFlags(flags, hdr)
470+
s.processFlags(hdr, flags)
471471

472472
// Check that our recv window is not exceeded
473473
length := hdr.Length()

0 commit comments

Comments
 (0)