Skip to content

Commit d51a392

Browse files
authored
fix: fix concurrency bugs causing panics and race conditions in internal/syncutil (#921)
1. Remove the use of `atomic.Value` 2. Cancel all go-routines before releasing semaphore permits 3. Wait for schedule go-routines to release the semaphore permits before returning 4. Add unit tests Fixes #916, fixes #908 Signed-off-by: Lixia (Sylvia) Lei <[email protected]>
1 parent fdf2c51 commit d51a392

File tree

4 files changed

+235
-14
lines changed

4 files changed

+235
-14
lines changed

internal/syncutil/limit.go

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ package syncutil
1717

1818
import (
1919
"context"
20-
"sync/atomic"
2120

2221
"golang.org/x/sync/errgroup"
2322
"golang.org/x/sync/semaphore"
@@ -68,27 +67,41 @@ type GoFunc[T any] func(ctx context.Context, region *LimitedRegion, t T) error
6867

6968
// Go concurrently invokes fn on items.
7069
func Go[T any](ctx context.Context, limiter *semaphore.Weighted, fn GoFunc[T], items ...T) error {
70+
ctx, cancel := context.WithCancelCause(ctx)
71+
defer cancel(nil)
72+
7173
eg, egCtx := errgroup.WithContext(ctx)
72-
var egErr atomic.Value
7374
for _, item := range items {
7475
region := LimitRegion(egCtx, limiter)
7576
if err := region.Start(); err != nil {
76-
if egErr, ok := egErr.Load().(error); ok && egErr != nil {
77-
return egErr
78-
}
79-
return err
77+
cancel(err)
78+
// break loop instead of returning to allow previously scheduled
79+
// goroutines to finish their deferred region.End() calls
80+
break
8081
}
81-
eg.Go(func(t T) func() error {
82+
83+
eg.Go(func(t T, lr *LimitedRegion) func() error {
8284
return func() error {
83-
defer region.End()
84-
err := fn(egCtx, region, t)
85-
if err != nil {
86-
egErr.CompareAndSwap(nil, err)
85+
defer lr.End()
86+
87+
select {
88+
case <-egCtx.Done():
89+
// skip the task if the context is already cancelled
90+
return nil
91+
default:
92+
}
93+
94+
if err := fn(egCtx, lr, t); err != nil {
95+
cancel(err)
8796
return err
8897
}
8998
return nil
9099
}
91-
}(item))
100+
}(item, region))
101+
}
102+
103+
if err := eg.Wait(); err != nil {
104+
cancel(err)
92105
}
93-
return eg.Wait()
106+
return context.Cause(ctx)
94107
}

internal/syncutil/limit_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
Copyright The ORAS Authors.
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
*/
15+
16+
package syncutil
17+
18+
import (
19+
"context"
20+
"errors"
21+
"sync/atomic"
22+
"testing"
23+
"time"
24+
25+
"golang.org/x/sync/semaphore"
26+
)
27+
28+
func TestLimitedRegion_Success(t *testing.T) {
29+
limiter := semaphore.NewWeighted(2)
30+
ctx := context.Background()
31+
var counter int32
32+
33+
err := Go(ctx, limiter, func(ctx context.Context, region *LimitedRegion, i int) error {
34+
// just sleeps a little and increments counter to simulate task
35+
time.Sleep(10 * time.Millisecond)
36+
atomic.AddInt32(&counter, 1)
37+
return nil
38+
}, 1, 2, 3, 4, 5)
39+
40+
if err != nil {
41+
t.Fatalf("expected no error, got %v", err)
42+
}
43+
44+
// after everything finishes, we expect counter to be 5
45+
if want := 5; atomic.LoadInt32(&counter) != int32(want) {
46+
t.Errorf("expected counter == %v, got %v", want, counter)
47+
}
48+
49+
// when all work is done the semaphore should have all permits available
50+
if !limiter.TryAcquire(2) {
51+
t.Error("semaphore permits were not fully released at the end")
52+
}
53+
limiter.Release(2)
54+
}
55+
56+
func TestLimitedRegion_Cancellation(t *testing.T) {
57+
limiter := semaphore.NewWeighted(2)
58+
ctx := context.Background()
59+
var counter int32
60+
61+
errTest := errors.New("test error")
62+
err := Go(ctx, limiter, func(ctx context.Context, region *LimitedRegion, i int) error {
63+
if i < 0 {
64+
// trigger an error on negative values
65+
return errTest
66+
}
67+
// just sleeps a little and increments counter to simulate task
68+
time.Sleep(50 * time.Millisecond)
69+
atomic.AddInt32(&counter, 1)
70+
return nil
71+
}, 1, -1, 2, 0, -2)
72+
73+
// we expect the returned error to be errTest.
74+
if !errors.Is(err, errTest) {
75+
t.Fatalf("expected error %v; got %v", errTest, err)
76+
}
77+
78+
// after everything finishes, we expect counter to be smaller than 5.
79+
if max := 5; atomic.LoadInt32(&counter) >= int32(max) {
80+
t.Errorf("expected counter < %v, got %v", max, counter)
81+
}
82+
83+
// when all work is done the semaphore should have all permits available
84+
if !limiter.TryAcquire(2) {
85+
t.Error("semaphore permit was not released after error cancellation")
86+
}
87+
limiter.Release(2)
88+
}

internal/syncutil/limitgroup.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import (
2121
"golang.org/x/sync/errgroup"
2222
)
2323

24-
// A LimitedGroup is a collection of goroutines working on subtasks that are part of
24+
// LimitedGroup is a collection of goroutines working on subtasks that are part of
2525
// the same overall task.
2626
type LimitedGroup struct {
2727
grp *errgroup.Group

internal/syncutil/limitgroup_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
Copyright The ORAS Authors.
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
*/
15+
16+
package syncutil
17+
18+
import (
19+
"context"
20+
"errors"
21+
"sync/atomic"
22+
"testing"
23+
"time"
24+
)
25+
26+
func TestLimitedGroup_Success(t *testing.T) {
27+
ctx := context.Background()
28+
numTasks := 5
29+
30+
// Create a limited group with a concurrency limit of 2.
31+
lg, _ := LimitGroup(ctx, 2)
32+
var counter int32
33+
34+
for range numTasks {
35+
lg.Go(func() error {
36+
// simulate some work.
37+
time.Sleep(10 * time.Millisecond)
38+
atomic.AddInt32(&counter, 1)
39+
return nil
40+
})
41+
}
42+
43+
if err := lg.Wait(); err != nil {
44+
t.Fatalf("unexpected error: %v", err)
45+
}
46+
47+
if got := atomic.LoadInt32(&counter); got != int32(numTasks) {
48+
t.Errorf("expected counter %d, got %d", numTasks, got)
49+
}
50+
}
51+
52+
func TestLimitedGroup_Error(t *testing.T) {
53+
ctx := context.Background()
54+
lg, _ := LimitGroup(ctx, 2)
55+
errTest := errors.New("test error")
56+
var executed int32
57+
58+
lg.Go(func() error {
59+
// delay a bit so that other tasks are scheduled.
60+
time.Sleep(20 * time.Millisecond)
61+
atomic.AddInt32(&executed, 1)
62+
return errTest
63+
})
64+
65+
// simulates a slower, normal task.
66+
lg.Go(func() error {
67+
// wait until cancellation is (hopefully) in effect.
68+
time.Sleep(50 * time.Millisecond)
69+
atomic.AddInt32(&executed, 1)
70+
return nil
71+
})
72+
73+
err := lg.Wait()
74+
if !errors.Is(err, errTest) {
75+
t.Fatalf("expected error %v, got %v", errTest, err)
76+
}
77+
78+
if atomic.LoadInt32(&executed) < 1 {
79+
t.Errorf("expected at least one task executed, got %d", executed)
80+
}
81+
}
82+
83+
func TestLimitedGroup_Limit(t *testing.T) {
84+
ctx := context.Background()
85+
limit := 2
86+
lg, _ := LimitGroup(ctx, limit)
87+
var concurrent, maxConcurrent int32
88+
numTasks := 10
89+
90+
for range numTasks {
91+
lg.Go(func() error {
92+
// increment the count of concurrently active tasks.
93+
cur := atomic.AddInt32(&concurrent, 1)
94+
// update the max concurrent tasks if needed.
95+
for {
96+
prevMax := atomic.LoadInt32(&maxConcurrent)
97+
if cur > prevMax {
98+
if atomic.CompareAndSwapInt32(&maxConcurrent, prevMax, cur) {
99+
break
100+
}
101+
} else {
102+
break
103+
}
104+
}
105+
106+
// simulate a short task.
107+
time.Sleep(20 * time.Millisecond)
108+
atomic.AddInt32(&concurrent, -1)
109+
return nil
110+
})
111+
}
112+
113+
if err := lg.Wait(); err != nil {
114+
t.Fatalf("unexpected error: %v", err)
115+
}
116+
117+
if maxConcurrent > int32(limit) {
118+
t.Errorf("expected max concurrent tasks <= %d, got %d", limit, maxConcurrent)
119+
}
120+
}

0 commit comments

Comments
 (0)