Skip to content

Commit b952594

Browse files
committed
quic: fix data race in connection close
We were failing to hold streamsState.streamsMu when removing a closed stream from the conn's stream map. Rework this to remove the mutex entirely. The only access to the map that isn't on the conn's loop is during stream creation. Send a message to the loop to register the stream instead of using a mutex. Change-Id: I2e87089e87c61a6ade8219dfb8acec3809bf95de Reviewed-on: https://go-review.googlesource.com/c/net/+/545217 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Jonathan Amsterdam <[email protected]>
1 parent 577e44a commit b952594

File tree

5 files changed

+93
-22
lines changed

5 files changed

+93
-22
lines changed

internal/quic/conn.go

+28-3
Original file line numberDiff line numberDiff line change
@@ -369,12 +369,37 @@ func (c *Conn) wake() {
369369
}
370370

371371
// runOnLoop executes a function within the conn's loop goroutine.
372-
func (c *Conn) runOnLoop(f func(now time.Time, c *Conn)) error {
372+
func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) error {
373373
donec := make(chan struct{})
374-
c.sendMsg(func(now time.Time, c *Conn) {
374+
msg := func(now time.Time, c *Conn) {
375375
defer close(donec)
376376
f(now, c)
377-
})
377+
}
378+
if c.testHooks != nil {
379+
// In tests, we can't rely on being able to send a message immediately:
380+
// c.msgc might be full, and testConnHooks.nextMessage might be waiting
381+
// for us to block before it processes the next message.
382+
// To avoid a deadlock, we send the message in waitUntil.
383+
// If msgc is empty, the message is buffered.
384+
// If msgc is full, we block and let nextMessage process the queue.
385+
msgc := c.msgc
386+
c.testHooks.waitUntil(ctx, func() bool {
387+
for {
388+
select {
389+
case msgc <- msg:
390+
msgc = nil // send msg only once
391+
case <-donec:
392+
return true
393+
case <-c.donec:
394+
return true
395+
default:
396+
return false
397+
}
398+
}
399+
})
400+
} else {
401+
c.sendMsg(msg)
402+
}
378403
select {
379404
case <-donec:
380405
case <-c.donec:

internal/quic/conn_async_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ func runAsync[T any](tc *testConn, f func(context.Context) (T, error)) *asyncOp[
125125
})
126126
// Wait for the operation to either finish or block.
127127
<-as.notify
128+
tc.wait()
128129
return a
129130
}
130131

internal/quic/conn_streams.go

+8-12
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@ import (
1414
)
1515

1616
type streamsState struct {
17-
queue queue[*Stream] // new, peer-created streams
18-
19-
streamsMu sync.Mutex
20-
streams map[streamID]*Stream
17+
queue queue[*Stream] // new, peer-created streams
18+
streams map[streamID]*Stream
2119

2220
// Limits on the number of streams, indexed by streamType.
2321
localLimit [streamTypeCount]localStreamLimits
@@ -82,9 +80,6 @@ func (c *Conn) NewSendOnlyStream(ctx context.Context) (*Stream, error) {
8280
}
8381

8482
func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, error) {
85-
c.streams.streamsMu.Lock()
86-
defer c.streams.streamsMu.Unlock()
87-
8883
num, err := c.streams.localLimit[styp].open(ctx, c)
8984
if err != nil {
9085
return nil, err
@@ -100,7 +95,12 @@ func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, er
10095
s.inUnlock()
10196
s.outUnlock()
10297

103-
c.streams.streams[s.id] = s
98+
// Modify c.streams on the conn's loop.
99+
if err := c.runOnLoop(ctx, func(now time.Time, c *Conn) {
100+
c.streams.streams[s.id] = s
101+
}); err != nil {
102+
return nil, err
103+
}
104104
return s, nil
105105
}
106106

@@ -119,8 +119,6 @@ const (
119119
// streamForID returns the stream with the given id.
120120
// If the stream does not exist, it returns nil.
121121
func (c *Conn) streamForID(id streamID) *Stream {
122-
c.streams.streamsMu.Lock()
123-
defer c.streams.streamsMu.Unlock()
124122
return c.streams.streams[id]
125123
}
126124

@@ -146,8 +144,6 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
146144
}
147145
}
148146

149-
c.streams.streamsMu.Lock()
150-
defer c.streams.streamsMu.Unlock()
151147
s, isOpen := c.streams.streams[id]
152148
if s != nil {
153149
return s

internal/quic/conn_streams_test.go

+44
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"fmt"
1212
"io"
1313
"math"
14+
"sync"
1415
"testing"
1516
)
1617

@@ -478,3 +479,46 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) {
478479
t.Fatalf("after test, stream send queue is not empty; should be")
479480
}
480481
}
482+
483+
func TestStreamsCreateConcurrency(t *testing.T) {
484+
cli, srv := newLocalConnPair(t, &Config{}, &Config{})
485+
486+
srvdone := make(chan int)
487+
go func() {
488+
defer close(srvdone)
489+
for streams := 0; ; streams++ {
490+
s, err := srv.AcceptStream(context.Background())
491+
if err != nil {
492+
srvdone <- streams
493+
return
494+
}
495+
s.Close()
496+
}
497+
}()
498+
499+
var wg sync.WaitGroup
500+
const concurrency = 10
501+
const streams = 10
502+
for i := 0; i < concurrency; i++ {
503+
wg.Add(1)
504+
go func() {
505+
defer wg.Done()
506+
for j := 0; j < streams; j++ {
507+
s, err := cli.NewStream(context.Background())
508+
if err != nil {
509+
t.Errorf("NewStream: %v", err)
510+
return
511+
}
512+
s.Flush()
513+
s.Close()
514+
}
515+
}()
516+
}
517+
wg.Wait()
518+
519+
cli.Abort(nil)
520+
srv.Abort(nil)
521+
if got, want := <-srvdone, concurrency*streams; got != want {
522+
t.Errorf("accepted %v streams, want %v", got, want)
523+
}
524+
}

internal/quic/conn_test.go

+12-7
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,25 @@ func TestConnTestConn(t *testing.T) {
3030
t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
3131
}
3232

33-
var ranAt time.Time
34-
tc.conn.runOnLoop(func(now time.Time, c *Conn) {
35-
ranAt = now
36-
})
33+
ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
34+
tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
35+
when = now
36+
})
37+
return
38+
}).result()
3739
if !ranAt.Equal(tc.endpoint.now) {
3840
t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
3941
}
4042
tc.wait()
4143

4244
nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
4345
tc.advanceTo(nextTime)
44-
tc.conn.runOnLoop(func(now time.Time, c *Conn) {
45-
ranAt = now
46-
})
46+
ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
47+
tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
48+
when = now
49+
})
50+
return
51+
}).result()
4752
if !ranAt.Equal(nextTime) {
4853
t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
4954
}

0 commit comments

Comments (0)