Skip to content

Commit 2b88763

Browse files
feature(streamwork): make cancellation more predictable
1 parent 629ff7f commit 2b88763

File tree

6 files changed

+63
-0
lines changed

6 files changed

+63
-0
lines changed

streamwork/circuitbreaker.go

+12
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ func CircuitBreakerWorker[T any](cb CircuitBreaker) Worker[T, T] {
102102
go func() {
103103
defer close(chanOut)
104104
for {
105+
if ctx.Err() != nil {
106+
return
107+
}
105108
select {
106109
case <-ctx.Done():
107110
return
@@ -110,6 +113,9 @@ func CircuitBreakerWorker[T any](cb CircuitBreaker) Worker[T, T] {
110113
return
111114
}
112115
if cb.IsPaused() {
116+
if ctx.Err() != nil {
117+
return
118+
}
113119
select {
114120
case <-ctx.Done():
115121
return
@@ -118,12 +124,18 @@ func CircuitBreakerWorker[T any](cb CircuitBreaker) Worker[T, T] {
118124
}
119125
vLoop:
120126
for {
127+
if ctx.Err() != nil {
128+
return
129+
}
121130
select {
122131
case <-ctx.Done():
123132
return
124133
case chanOut <- v:
125134
break vLoop
126135
case <-cb.pauseChan():
136+
if ctx.Err() != nil {
137+
return
138+
}
127139
select {
128140
case <-ctx.Done():
129141
return

streamwork/filter.go

+6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ func Filter[TIN any](f func(ctx context.Context, v TIN) bool) Worker[TIN, TIN] {
99
go func() {
1010
defer close(chanOut)
1111
for {
12+
if ctx.Err() != nil {
13+
return
14+
}
1215
select {
1316
case <-ctx.Done():
1417
return
@@ -19,6 +22,9 @@ func Filter[TIN any](f func(ctx context.Context, v TIN) bool) Worker[TIN, TIN] {
1922
if !f(ctx, v) {
2023
continue
2124
}
25+
if ctx.Err() != nil {
26+
return
27+
}
2228
select {
2329
case <-ctx.Done():
2430
return

streamwork/flatten.go

+6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ func Flatten[T any]() Worker[[]T, T] {
1111
go func() {
1212
defer close(chanOut)
1313
for {
14+
if ctx.Err() != nil {
15+
return
16+
}
1417
select {
1518
case <-ctx.Done():
1619
return
@@ -19,6 +22,9 @@ func Flatten[T any]() Worker[[]T, T] {
1922
return
2023
}
2124
for _, v := range batch {
25+
if ctx.Err() != nil {
26+
return
27+
}
2228
select {
2329
case <-ctx.Done():
2430
return

streamwork/parallel.go

+6
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,19 @@ func Parallelize[TIN any, TOUT any](instancesCount int, worker Worker[TIN, TOUT]
3030

3131
func pipeChannels[T any](ctx context.Context, src <-chan T, dest chan T) {
3232
for {
33+
if ctx.Err() != nil {
34+
return
35+
}
3336
select {
3437
case <-ctx.Done():
3538
return
3639
case v, ok := <-src:
3740
if !ok {
3841
return
3942
}
43+
if ctx.Err() != nil {
44+
return
45+
}
4046
select {
4147
case <-ctx.Done():
4248
return

streamwork/sources.go

+9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ func ReadSlice[T any](s []T) Source[T] {
1414
go func() {
1515
defer close(c)
1616
for _, v := range s {
17+
if ctx.Err() != nil {
18+
return
19+
}
1720
select {
1821
case <-ctx.Done():
1922
return
@@ -48,6 +51,9 @@ func ReadSeq[T any](s iter.Seq[T], options ...StreamOption) Source[T] {
4851
go func() {
4952
defer close(c)
5053
for v := range s {
54+
if ctx.Err() != nil {
55+
return
56+
}
5157
select {
5258
case <-ctx.Done():
5359
return
@@ -76,6 +82,9 @@ func ReadSeqErr[T any](s iter.Seq2[T, error], options ...StreamOption) Source[T]
7682
}
7783
continue
7884
}
85+
if ctx.Err() != nil {
86+
return
87+
}
7988
select {
8089
case <-ctx.Done():
8190
return

streamwork/worker.go

+24
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,19 @@ func WorkerFunc[TIN any, TOUT any](f func(ctx context.Context, v TIN) TOUT) Work
1515
go func() {
1616
defer close(chanOut)
1717
for {
18+
if ctx.Err() != nil {
19+
return
20+
}
1821
select {
1922
case <-ctx.Done():
2023
return
2124
case v, ok := <-source:
2225
if !ok {
2326
return
2427
}
28+
if ctx.Err() != nil {
29+
return
30+
}
2531
select {
2632
case <-ctx.Done():
2733
return
@@ -43,6 +49,9 @@ func WorkerSeq[TIN any, TOUT any](f func(ctx context.Context, v iter.Seq[TIN]) i
4349
seqOut := f(
4450
ctx, func(yield func(TIN) bool) {
4551
for {
52+
if ctx.Err() != nil {
53+
return
54+
}
4655
select {
4756
case <-ctx.Done():
4857
return
@@ -60,6 +69,9 @@ func WorkerSeq[TIN any, TOUT any](f func(ctx context.Context, v iter.Seq[TIN]) i
6069
go func() {
6170
defer close(chanOut)
6271
for vOut := range seqOut {
72+
if ctx.Err() != nil {
73+
return
74+
}
6375
select {
6476
case <-ctx.Done():
6577
return
@@ -80,6 +92,9 @@ func WorkerSeqErr[TIN any, TOUT any](f func(ctx context.Context, v iter.Seq[TIN]
8092
seqOut := f(
8193
ctx, func(yield func(TIN) bool) {
8294
for {
95+
if ctx.Err() != nil {
96+
return
97+
}
8398
select {
8499
case <-ctx.Done():
85100
return
@@ -103,6 +118,9 @@ func WorkerSeqErr[TIN any, TOUT any](f func(ctx context.Context, v iter.Seq[TIN]
103118
}
104119
continue
105120
}
121+
if ctx.Err() != nil {
122+
return
123+
}
106124
select {
107125
case <-ctx.Done():
108126
return
@@ -123,6 +141,9 @@ func WorkerFuncErr[TIN any, TOUT any](f func(ctx context.Context, v TIN) (TOUT,
123141
go func() {
124142
defer close(chanOut)
125143
for {
144+
if ctx.Err() != nil {
145+
return
146+
}
126147
select {
127148
case <-ctx.Done():
128149
return
@@ -137,6 +158,9 @@ func WorkerFuncErr[TIN any, TOUT any](f func(ctx context.Context, v TIN) (TOUT,
137158
}
138159
continue
139160
}
161+
if ctx.Err() != nil {
162+
return
163+
}
140164
select {
141165
case <-ctx.Done():
142166
return

0 commit comments

Comments
 (0)