Skip to content

Commit 84a21fe

Browse files
authored
Merge pull request #755 from realdave/748-prevent-panic
Ensure gRPC ClientStream override methods do not panic
2 parents 029c0ee + 316b4fd commit 84a21fe

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

plugin/grpctrace/interceptor.go

+25-13
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,10 @@ const (
133133
type clientStream struct {
134134
grpc.ClientStream
135135

136-
desc *grpc.StreamDesc
137-
events chan streamEvent
138-
finished chan error
136+
desc *grpc.StreamDesc
137+
events chan streamEvent
138+
eventsDone chan struct{}
139+
finished chan error
139140

140141
receivedMessageID int
141142
sentMessageID int
@@ -147,11 +148,11 @@ func (w *clientStream) RecvMsg(m interface{}) error {
147148
err := w.ClientStream.RecvMsg(m)
148149

149150
if err == nil && !w.desc.ServerStreams {
150-
w.events <- streamEvent{receiveEndEvent, nil}
151+
w.sendStreamEvent(receiveEndEvent, nil)
151152
} else if err == io.EOF {
152-
w.events <- streamEvent{receiveEndEvent, nil}
153+
w.sendStreamEvent(receiveEndEvent, nil)
153154
} else if err != nil {
154-
w.events <- streamEvent{errorEvent, err}
155+
w.sendStreamEvent(errorEvent, err)
155156
} else {
156157
w.receivedMessageID++
157158
messageReceived.Event(w.Context(), w.receivedMessageID, m)
@@ -167,7 +168,7 @@ func (w *clientStream) SendMsg(m interface{}) error {
167168
messageSent.Event(w.Context(), w.sentMessageID, m)
168169

169170
if err != nil {
170-
w.events <- streamEvent{errorEvent, err}
171+
w.sendStreamEvent(errorEvent, err)
171172
}
172173

173174
return err
@@ -177,7 +178,7 @@ func (w *clientStream) Header() (metadata.MD, error) {
177178
md, err := w.ClientStream.Header()
178179

179180
if err != nil {
180-
w.events <- streamEvent{errorEvent, err}
181+
w.sendStreamEvent(errorEvent, err)
181182
}
182183

183184
return md, err
@@ -187,9 +188,9 @@ func (w *clientStream) CloseSend() error {
187188
err := w.ClientStream.CloseSend()
188189

189190
if err != nil {
190-
w.events <- streamEvent{errorEvent, err}
191+
w.sendStreamEvent(errorEvent, err)
191192
} else {
192-
w.events <- streamEvent{closeEvent, nil}
193+
w.sendStreamEvent(closeEvent, nil)
193194
}
194195

195196
return err
@@ -201,10 +202,13 @@ const (
201202
)
202203

203204
func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream {
204-
events := make(chan streamEvent, 1)
205+
events := make(chan streamEvent)
206+
eventsDone := make(chan struct{})
205207
finished := make(chan error)
206208

207209
go func() {
210+
defer close(eventsDone)
211+
208212
// Both streams have to be closed
209213
state := byte(0)
210214

@@ -216,12 +220,12 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
216220
state |= receiveEndedState
217221
case errorEvent:
218222
finished <- event.Err
219-
close(events)
223+
return
220224
}
221225

222226
if state == clientClosedState|receiveEndedState {
223227
finished <- nil
224-
close(events)
228+
return
225229
}
226230
}
227231
}()
@@ -230,10 +234,18 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
230234
ClientStream: s,
231235
desc: desc,
232236
events: events,
237+
eventsDone: eventsDone,
233238
finished: finished,
234239
}
235240
}
236241

242+
func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) {
243+
select {
244+
case <-w.eventsDone:
245+
case w.events <- streamEvent{Type: eventType, Err: err}:
246+
}
247+
}
248+
237249
// StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable
238250
// for use in a grpc.Dial call.
239251
//

plugin/grpctrace/interceptor_test.go

+3
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,9 @@ func TestStreamClientInterceptor(t *testing.T) {
376376
validate("SENT", events[i].Attributes)
377377
validate("RECEIVED", events[i+1].Attributes)
378378
}
379+
380+
// ensure CloseSend can be subsequently called
381+
_ = streamClient.CloseSend()
379382
}
380383

381384
func TestServerInterceptorError(t *testing.T) {

0 commit comments

Comments
 (0)