Skip to content

Commit 0094132

Browse files
committed
Fix a couple streaming interceptor bugs
- The wrap function for server interceptors was working as if it were a client interceptor: expecting the stream to be the result of calling the endpoint or the next unary interceptor in the chain. It now properly wraps the stream in the raw payload (which is actually a special `EndpointStruct` with the stream and optionally the unary payload). - The unary `Payload` methods of server interceptor info structs for a streaming method assumed that the raw payload would be the payload type rather than the special `EndpointStruct` type. Now, the `Payload` methods properly expect the raw payload to either be the `EndpointStruct` or the payload type in the case when the server interceptor info struct is also used for a client interceptor.
1 parent fc0dd51 commit 0094132

9 files changed

+79
-48
lines changed

codegen/service/interceptors.go

+12
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ func interceptorFile(svc *Data, server bool) *codegen.File {
122122
Data: interceptors,
123123
FuncMap: map[string]any{
124124
"hasPrivateImplementationTypes": hasPrivateImplementationTypes,
125+
"hasEndpointStruct": hasEndpointStruct(server),
125126
},
126127
})
127128
}
@@ -226,6 +227,17 @@ func hasPrivateImplementationTypes(interceptors []*InterceptorData) bool {
226227
return false
227228
}
228229

230+
// hasEndpointStruct returns a function that returns true if the method has an endpoint struct
231+
// if server is true, otherwise it returns false.
232+
func hasEndpointStruct(server bool) func(*MethodInterceptorData) bool {
233+
if !server {
234+
return func(*MethodInterceptorData) bool { return false }
235+
}
236+
return func(m *MethodInterceptorData) bool {
237+
return m.ServerStream != nil && m.ServerStream.EndpointStruct != ""
238+
}
239+
}
240+
229241
// collectWrappedStreams returns a slice of streams to be wrapped by interceptor wrapper functions.
230242
func collectWrappedStreams(interceptors []*InterceptorData, server bool) []*StreamInterceptorData {
231243
var (

codegen/service/service_data.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,7 @@ type (
227227
// function.
228228
MustClose bool
229229
// EndpointStruct is the name of the endpoint struct that holds a payload
230-
// reference (if any) and the endpoint server stream. It is set only if the
231-
// client sends a normal payload and server streams a result.
230+
// reference (if any) and the endpoint server stream.
232231
EndpointStruct string
233232
// Kind is the kind of the stream (payload, result or bidirectional).
234233
Kind expr.StreamKind
@@ -345,6 +344,9 @@ type (
345344
// MustClose indicates whether the stream should implement the Close()
346345
// function.
347346
MustClose bool
347+
// EndpointStruct is the name of the endpoint struct that holds a payload
348+
// reference (if any) and the endpoint server stream.
349+
EndpointStruct string
348350
}
349351

350352
// AttributeData describes a single attribute.
@@ -1314,6 +1316,7 @@ func buildInterceptorMethodData(i *expr.InterceptorExpr, md *MethodData) *Method
13141316
RecvWithContextName: md.ServerStream.RecvWithContextName,
13151317
RecvTypeRef: md.ServerStream.RecvTypeRef,
13161318
MustClose: md.ServerStream.MustClose,
1319+
EndpointStruct: md.ServerStream.EndpointStruct,
13171320
}
13181321
}
13191322
if md.ClientStream != nil {

codegen/service/templates/interceptors.go.tpl

+18
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,31 @@ func (info *{{ .Name }}Info) Payload() {{ .Name }}Payload {
2929
switch info.Method() {
3030
{{- range .Methods }}
3131
case "{{ .MethodName }}":
32+
{{- if hasEndpointStruct . }}
33+
switch pay := info.RawPayload().(type) {
34+
case *{{ .ServerStream.EndpointStruct }}:
35+
return &{{ .PayloadAccess }}{payload: pay.Payload}
36+
default:
37+
return &{{ .PayloadAccess }}{payload: pay.({{ .PayloadRef }})}
38+
}
39+
{{- else }}
3240
return &{{ .PayloadAccess }}{payload: info.RawPayload().({{ .PayloadRef }})}
41+
{{- end }}
3342
{{- end }}
3443
default:
3544
return nil
3645
}
3746
{{- else }}
47+
{{- if hasEndpointStruct (index .Methods 0) }}
48+
switch pay := info.RawPayload().(type) {
49+
case *{{ (index .Methods 0).ServerStream.EndpointStruct }}:
50+
return &{{ (index .Methods 0).PayloadAccess }}{payload: pay.Payload}
51+
default:
52+
return &{{ (index .Methods 0).PayloadAccess }}{payload: pay.({{ (index .Methods 0).PayloadRef }})}
53+
}
54+
{{- else }}
3855
return &{{ (index .Methods 0).PayloadAccess }}{payload: info.RawPayload().({{ (index .Methods 0).PayloadRef }})}
56+
{{- end }}
3957
{{- end }}
4058
}
4159
{{- end }}

codegen/service/templates/server_interceptor_wrappers.go.tpl

+14-17
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,8 @@
66
func wrap{{ .MethodName }}{{ $interceptor.Name }}(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint {
77
return func(ctx context.Context, req any) (any, error) {
88
{{- if or $interceptor.HasStreamingPayloadAccess $interceptor.HasStreamingResultAccess }}
9-
{{- if $interceptor.HasPayloadAccess }}
10-
info := &{{ $interceptor.Name }}Info{
11-
service: "{{ $.Service }}",
12-
method: "{{ .MethodName }}",
13-
callType: goa.InterceptorUnary,
14-
rawPayload: req,
15-
}
16-
res, err := i.{{ $interceptor.Name }}(ctx, info, endpoint)
17-
{{- else }}
18-
res, err := endpoint(ctx, req)
19-
{{- end }}
20-
if err != nil {
21-
return res, err
22-
}
23-
stream := res.({{ .ServerStream.Interface }})
24-
return &wrapped{{ .ServerStream.Interface }}{
9+
stream := req.(*{{ .ServerStream.EndpointStruct }}).Stream
10+
req.(*{{ .ServerStream.EndpointStruct }}).Stream = &wrapped{{ .ServerStream.Interface }}{
2511
ctx: ctx,
2612
{{- if $interceptor.HasStreamingResultAccess }}
2713
sendWithContext: func(ctx context.Context, req {{ .ServerStream.SendTypeRef }}) error {
@@ -53,7 +39,18 @@ func wrap{{ .MethodName }}{{ $interceptor.Name }}(endpoint goa.Endpoint, i Serve
5339
},
5440
{{- end }}
5541
stream: stream,
56-
}, nil
42+
}
43+
{{- if $interceptor.HasPayloadAccess }}
44+
info := &{{ $interceptor.Name }}Info{
45+
service: "{{ $.Service }}",
46+
method: "{{ .MethodName }}",
47+
callType: goa.InterceptorUnary,
48+
rawPayload: req,
49+
}
50+
return i.{{ $interceptor.Name }}(ctx, info, endpoint)
51+
{{- else }}
52+
return endpoint(ctx, req)
53+
{{- end }}
5754
{{- else }}
5855
info := &{{ $interceptor.Name }}Info{
5956
service: "{{ $.Service }}",

codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload-and-read-streaming-payload_interceptor_wrappers.go.golden

+10-13
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,8 @@ type wrappedMethodClientStream struct {
1919
// wrapLoggingMethod applies the logging server interceptor to endpoints.
2020
func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint {
2121
return func(ctx context.Context, req any) (any, error) {
22-
info := &LoggingInfo{
23-
service: "StreamingInterceptorsWithReadPayloadAndReadStreamingPayload",
24-
method: "Method",
25-
callType: goa.InterceptorUnary,
26-
rawPayload: req,
27-
}
28-
res, err := i.Logging(ctx, info, endpoint)
29-
if err != nil {
30-
return res, err
31-
}
32-
stream := res.(MethodServerStream)
33-
return &wrappedMethodServerStream{
22+
stream := req.(*MethodEndpointInput).Stream
23+
req.(*MethodEndpointInput).Stream = &wrappedMethodServerStream{
3424
ctx: ctx,
3525
recvWithContext: func(ctx context.Context) (*MethodStreamingPayload, error) {
3626
info := &LoggingInfo{
@@ -45,7 +35,14 @@ func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint
4535
return castRes, err
4636
},
4737
stream: stream,
48-
}, nil
38+
}
39+
info := &LoggingInfo{
40+
service: "StreamingInterceptorsWithReadPayloadAndReadStreamingPayload",
41+
method: "Method",
42+
callType: goa.InterceptorUnary,
43+
rawPayload: req,
44+
}
45+
return i.Logging(ctx, info, endpoint)
4946
}
5047
}
5148

codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload-and-read-streaming-payload_service_interceptors.go.golden

+6-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,12 @@ func (info *LoggingInfo) RawPayload() any {
7373

7474
// Payload returns a type-safe accessor for the method payload.
7575
func (info *LoggingInfo) Payload() LoggingPayload {
76-
return &loggingMethodPayload{payload: info.RawPayload().(*MethodPayload)}
76+
switch pay := info.RawPayload().(type) {
77+
case *MethodEndpointInput:
78+
return &loggingMethodPayload{payload: pay.Payload}
79+
default:
80+
return &loggingMethodPayload{payload: pay.(*MethodPayload)}
81+
}
7782
}
7883

7984
// ClientStreamingPayload returns a type-safe accessor for the method streaming payload for a client-side interceptor.

codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload_service_interceptors.go.golden

+6-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@ func (info *LoggingInfo) RawPayload() any {
6363

6464
// Payload returns a type-safe accessor for the method payload.
6565
func (info *LoggingInfo) Payload() LoggingPayload {
66-
return &loggingMethodPayload{payload: info.RawPayload().(*MethodPayload)}
66+
switch pay := info.RawPayload().(type) {
67+
case *MethodEndpointInput:
68+
return &loggingMethodPayload{payload: pay.Payload}
69+
default:
70+
return &loggingMethodPayload{payload: pay.(*MethodPayload)}
71+
}
6772
}
6873

6974
// Private implementation methods

codegen/service/testdata/interceptors/streaming-interceptors-with-read-streaming-result_interceptor_wrappers.go.golden

+4-7
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,8 @@ type wrappedMethodClientStream struct {
1919
// wrapLoggingMethod applies the logging server interceptor to endpoints.
2020
func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint {
2121
return func(ctx context.Context, req any) (any, error) {
22-
res, err := endpoint(ctx, req)
23-
if err != nil {
24-
return res, err
25-
}
26-
stream := res.(MethodServerStream)
27-
return &wrappedMethodServerStream{
22+
stream := req.(*MethodEndpointInput).Stream
23+
req.(*MethodEndpointInput).Stream = &wrappedMethodServerStream{
2824
ctx: ctx,
2925
sendWithContext: func(ctx context.Context, req *MethodResult) error {
3026
info := &LoggingInfo{
@@ -40,7 +36,8 @@ func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint
4036
return err
4137
},
4238
stream: stream,
43-
}, nil
39+
}
40+
return endpoint(ctx, req)
4441
}
4542
}
4643

codegen/service/testdata/interceptors/streaming-interceptors_interceptor_wrappers.go.golden

+4-7
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,8 @@ type wrappedMethodClientStream struct {
2121
// wrapLoggingMethod applies the logging server interceptor to endpoints.
2222
func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint {
2323
return func(ctx context.Context, req any) (any, error) {
24-
res, err := endpoint(ctx, req)
25-
if err != nil {
26-
return res, err
27-
}
28-
stream := res.(MethodServerStream)
29-
return &wrappedMethodServerStream{
24+
stream := req.(*MethodEndpointInput).Stream
25+
req.(*MethodEndpointInput).Stream = &wrappedMethodServerStream{
3026
ctx: ctx,
3127
sendWithContext: func(ctx context.Context, req *MethodResult) error {
3228
info := &LoggingInfo{
@@ -54,7 +50,8 @@ func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint
5450
return castRes, err
5551
},
5652
stream: stream,
57-
}, nil
53+
}
54+
return endpoint(ctx, req)
5855
}
5956
}
6057

0 commit comments

Comments
 (0)