@@ -133,9 +133,10 @@ const (
133
133
type clientStream struct {
134
134
grpc.ClientStream
135
135
136
- desc * grpc.StreamDesc
137
- events chan streamEvent
138
- finished chan error
136
+ desc * grpc.StreamDesc
137
+ events chan streamEvent
138
+ eventsDone chan struct {}
139
+ finished chan error
139
140
140
141
receivedMessageID int
141
142
sentMessageID int
@@ -147,11 +148,11 @@ func (w *clientStream) RecvMsg(m interface{}) error {
147
148
err := w .ClientStream .RecvMsg (m )
148
149
149
150
if err == nil && ! w .desc .ServerStreams {
150
- w .events <- streamEvent { receiveEndEvent , nil }
151
+ w .sendStreamEvent ( receiveEndEvent , nil )
151
152
} else if err == io .EOF {
152
- w .events <- streamEvent { receiveEndEvent , nil }
153
+ w .sendStreamEvent ( receiveEndEvent , nil )
153
154
} else if err != nil {
154
- w .events <- streamEvent { errorEvent , err }
155
+ w .sendStreamEvent ( errorEvent , err )
155
156
} else {
156
157
w .receivedMessageID ++
157
158
messageReceived .Event (w .Context (), w .receivedMessageID , m )
@@ -167,7 +168,7 @@ func (w *clientStream) SendMsg(m interface{}) error {
167
168
messageSent .Event (w .Context (), w .sentMessageID , m )
168
169
169
170
if err != nil {
170
- w .events <- streamEvent { errorEvent , err }
171
+ w .sendStreamEvent ( errorEvent , err )
171
172
}
172
173
173
174
return err
@@ -177,7 +178,7 @@ func (w *clientStream) Header() (metadata.MD, error) {
177
178
md , err := w .ClientStream .Header ()
178
179
179
180
if err != nil {
180
- w .events <- streamEvent { errorEvent , err }
181
+ w .sendStreamEvent ( errorEvent , err )
181
182
}
182
183
183
184
return md , err
@@ -187,9 +188,9 @@ func (w *clientStream) CloseSend() error {
187
188
err := w .ClientStream .CloseSend ()
188
189
189
190
if err != nil {
190
- w .events <- streamEvent { errorEvent , err }
191
+ w .sendStreamEvent ( errorEvent , err )
191
192
} else {
192
- w .events <- streamEvent { closeEvent , nil }
193
+ w .sendStreamEvent ( closeEvent , nil )
193
194
}
194
195
195
196
return err
@@ -201,10 +202,13 @@ const (
201
202
)
202
203
203
204
func wrapClientStream (s grpc.ClientStream , desc * grpc.StreamDesc ) * clientStream {
204
- events := make (chan streamEvent , 1 )
205
+ events := make (chan streamEvent )
206
+ eventsDone := make (chan struct {})
205
207
finished := make (chan error )
206
208
207
209
go func () {
210
+ defer close (eventsDone )
211
+
208
212
// Both streams have to be closed
209
213
state := byte (0 )
210
214
@@ -216,12 +220,12 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
216
220
state |= receiveEndedState
217
221
case errorEvent :
218
222
finished <- event .Err
219
- close ( events )
223
+ return
220
224
}
221
225
222
226
if state == clientClosedState | receiveEndedState {
223
227
finished <- nil
224
- close ( events )
228
+ return
225
229
}
226
230
}
227
231
}()
@@ -230,10 +234,18 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
230
234
ClientStream : s ,
231
235
desc : desc ,
232
236
events : events ,
237
+ eventsDone : eventsDone ,
233
238
finished : finished ,
234
239
}
235
240
}
236
241
242
+ func (w * clientStream ) sendStreamEvent (eventType streamEventType , err error ) {
243
+ select {
244
+ case <- w .eventsDone :
245
+ case w .events <- streamEvent {Type : eventType , Err : err }:
246
+ }
247
+ }
248
+
237
249
// StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable
238
250
// for use in a grpc.Dial call.
239
251
//
0 commit comments