Skip to content

Commit fcf44e9

Browse files
authored
Fix a few streaming interceptor bugs (#3655)
* Fix a couple streaming interceptor bugs - The wrap function for server interceptors was working as if it were a client interceptor: expecting the stream to be the result of calling the endpoint or the next unary interceptor in the chain. It now properly wraps the stream in the raw payload (which is actually a special `EndpointStruct` with the stream and optionally the unary payload). - The unary `Payload` methods of server interceptor info structs for a streaming method assumed that the raw payload would be the payload type rather than the special `EndpointStruct` type. Now, the `Payload` methods properly expect the raw payload to either be the `EndpointStruct` or the payload type in the case when the server interceptor info struct is also used for a client interceptor. * Fix a generation bug when an interceptor does not have any accessors - With the change of `goa.InterceptorInfo` from a struct to an interface, the generation of implementation of the interface methods on the generated interceptor info structs became necessary. The code for those methods was included in a section of the template that was conditional on the interceptor having accessors. Now, the conditional has been moved to a more appropriate place so the interface methods are implemented.
1 parent fc0dd51 commit fcf44e9

15 files changed

+246
-49
lines changed

codegen/service/interceptors.go

+12
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ func interceptorFile(svc *Data, server bool) *codegen.File {
122122
Data: interceptors,
123123
FuncMap: map[string]any{
124124
"hasPrivateImplementationTypes": hasPrivateImplementationTypes,
125+
"hasEndpointStruct": hasEndpointStruct(server),
125126
},
126127
})
127128
}
@@ -226,6 +227,17 @@ func hasPrivateImplementationTypes(interceptors []*InterceptorData) bool {
226227
return false
227228
}
228229

230+
// hasEndpointStruct returns a function that returns true if the method has an endpoint struct
231+
// if server is true, otherwise it returns false.
232+
func hasEndpointStruct(server bool) func(*MethodInterceptorData) bool {
233+
if !server {
234+
return func(*MethodInterceptorData) bool { return false }
235+
}
236+
return func(m *MethodInterceptorData) bool {
237+
return m.ServerStream != nil && m.ServerStream.EndpointStruct != ""
238+
}
239+
}
240+
229241
// collectWrappedStreams returns a slice of streams to be wrapped by interceptor wrapper functions.
230242
func collectWrappedStreams(interceptors []*InterceptorData, server bool) []*StreamInterceptorData {
231243
var (

codegen/service/service_data.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,7 @@ type (
227227
// function.
228228
MustClose bool
229229
// EndpointStruct is the name of the endpoint struct that holds a payload
230-
// reference (if any) and the endpoint server stream. It is set only if the
231-
// client sends a normal payload and server streams a result.
230+
// reference (if any) and the endpoint server stream.
232231
EndpointStruct string
233232
// Kind is the kind of the stream (payload, result or bidirectional).
234233
Kind expr.StreamKind
@@ -345,6 +344,9 @@ type (
345344
// MustClose indicates whether the stream should implement the Close()
346345
// function.
347346
MustClose bool
347+
// EndpointStruct is the name of the endpoint struct that holds a payload
348+
// reference (if any) and the endpoint server stream.
349+
EndpointStruct string
348350
}
349351

350352
// AttributeData describes a single attribute.
@@ -1314,6 +1316,7 @@ func buildInterceptorMethodData(i *expr.InterceptorExpr, md *MethodData) *Method
13141316
RecvWithContextName: md.ServerStream.RecvWithContextName,
13151317
RecvTypeRef: md.ServerStream.RecvTypeRef,
13161318
MustClose: md.ServerStream.MustClose,
1319+
EndpointStruct: md.ServerStream.EndpointStruct,
13171320
}
13181321
}
13191322
if md.ClientStream != nil {

codegen/service/templates/interceptors.go.tpl

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
{{- if hasPrivateImplementationTypes . }}
21
// Public accessor methods for Info types
32
{{- range . }}
43

@@ -29,13 +28,31 @@ func (info *{{ .Name }}Info) Payload() {{ .Name }}Payload {
2928
switch info.Method() {
3029
{{- range .Methods }}
3130
case "{{ .MethodName }}":
31+
{{- if hasEndpointStruct . }}
32+
switch pay := info.RawPayload().(type) {
33+
case *{{ .ServerStream.EndpointStruct }}:
34+
return &{{ .PayloadAccess }}{payload: pay.Payload}
35+
default:
36+
return &{{ .PayloadAccess }}{payload: pay.({{ .PayloadRef }})}
37+
}
38+
{{- else }}
3239
return &{{ .PayloadAccess }}{payload: info.RawPayload().({{ .PayloadRef }})}
40+
{{- end }}
3341
{{- end }}
3442
default:
3543
return nil
3644
}
3745
{{- else }}
46+
{{- if hasEndpointStruct (index .Methods 0) }}
47+
switch pay := info.RawPayload().(type) {
48+
case *{{ (index .Methods 0).ServerStream.EndpointStruct }}:
49+
return &{{ (index .Methods 0).PayloadAccess }}{payload: pay.Payload}
50+
default:
51+
return &{{ (index .Methods 0).PayloadAccess }}{payload: pay.({{ (index .Methods 0).PayloadRef }})}
52+
}
53+
{{- else }}
3854
return &{{ (index .Methods 0).PayloadAccess }}{payload: info.RawPayload().({{ (index .Methods 0).PayloadRef }})}
55+
{{- end }}
3956
{{- end }}
4057
}
4158
{{- end }}
@@ -131,6 +148,7 @@ func (info *{{ .Name }}Info) ServerStreamingResult() {{ .Name }}StreamingResult
131148
{{- end }}
132149
{{- end }}
133150

151+
{{- if hasPrivateImplementationTypes . }}
134152
// Private implementation methods
135153
{{- range . }}
136154
{{ $interceptor := . }}

codegen/service/templates/server_interceptor_wrappers.go.tpl

+14-17
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,8 @@
66
func wrap{{ .MethodName }}{{ $interceptor.Name }}(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint {
77
return func(ctx context.Context, req any) (any, error) {
88
{{- if or $interceptor.HasStreamingPayloadAccess $interceptor.HasStreamingResultAccess }}
9-
{{- if $interceptor.HasPayloadAccess }}
10-
info := &{{ $interceptor.Name }}Info{
11-
service: "{{ $.Service }}",
12-
method: "{{ .MethodName }}",
13-
callType: goa.InterceptorUnary,
14-
rawPayload: req,
15-
}
16-
res, err := i.{{ $interceptor.Name }}(ctx, info, endpoint)
17-
{{- else }}
18-
res, err := endpoint(ctx, req)
19-
{{- end }}
20-
if err != nil {
21-
return res, err
22-
}
23-
stream := res.({{ .ServerStream.Interface }})
24-
return &wrapped{{ .ServerStream.Interface }}{
9+
stream := req.(*{{ .ServerStream.EndpointStruct }}).Stream
10+
req.(*{{ .ServerStream.EndpointStruct }}).Stream = &wrapped{{ .ServerStream.Interface }}{
2511
ctx: ctx,
2612
{{- if $interceptor.HasStreamingResultAccess }}
2713
sendWithContext: func(ctx context.Context, req {{ .ServerStream.SendTypeRef }}) error {
@@ -53,7 +39,18 @@ func wrap{{ .MethodName }}{{ $interceptor.Name }}(endpoint goa.Endpoint, i Serve
5339
},
5440
{{- end }}
5541
stream: stream,
56-
}, nil
42+
}
43+
{{- if $interceptor.HasPayloadAccess }}
44+
info := &{{ $interceptor.Name }}Info{
45+
service: "{{ $.Service }}",
46+
method: "{{ .MethodName }}",
47+
callType: goa.InterceptorUnary,
48+
rawPayload: req,
49+
}
50+
return i.{{ $interceptor.Name }}(ctx, info, endpoint)
51+
{{- else }}
52+
return endpoint(ctx, req)
53+
{{- end }}
5754
{{- else }}
5855
info := &{{ $interceptor.Name }}Info{
5956
service: "{{ $.Service }}",

codegen/service/testdata/interceptors/multiple-interceptors_client_interceptors.go.golden

+41
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,44 @@ func WrapMethodClientEndpoint(endpoint goa.Endpoint, i ClientInterceptors) goa.E
3535
return endpoint
3636
}
3737

38+
// Public accessor methods for Info types
39+
40+
// Service returns the name of the service handling the request.
41+
func (info *Test2Info) Service() string {
42+
return info.service
43+
}
44+
45+
// Method returns the name of the method handling the request.
46+
func (info *Test2Info) Method() string {
47+
return info.method
48+
}
49+
50+
// CallType returns the type of call the interceptor is handling.
51+
func (info *Test2Info) CallType() goa.InterceptorCallType {
52+
return info.callType
53+
}
54+
55+
// RawPayload returns the raw payload of the request.
56+
func (info *Test2Info) RawPayload() any {
57+
return info.rawPayload
58+
}
59+
60+
// Service returns the name of the service handling the request.
61+
func (info *Test4Info) Service() string {
62+
return info.service
63+
}
64+
65+
// Method returns the name of the method handling the request.
66+
func (info *Test4Info) Method() string {
67+
return info.method
68+
}
69+
70+
// CallType returns the type of call the interceptor is handling.
71+
func (info *Test4Info) CallType() goa.InterceptorCallType {
72+
return info.callType
73+
}
74+
75+
// RawPayload returns the raw payload of the request.
76+
func (info *Test4Info) RawPayload() any {
77+
return info.rawPayload
78+
}

codegen/service/testdata/interceptors/multiple-interceptors_service_interceptors.go.golden

+41
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,44 @@ func WrapMethodEndpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoin
3535
return endpoint
3636
}
3737

38+
// Public accessor methods for Info types
39+
40+
// Service returns the name of the service handling the request.
41+
func (info *TestInfo) Service() string {
42+
return info.service
43+
}
44+
45+
// Method returns the name of the method handling the request.
46+
func (info *TestInfo) Method() string {
47+
return info.method
48+
}
49+
50+
// CallType returns the type of call the interceptor is handling.
51+
func (info *TestInfo) CallType() goa.InterceptorCallType {
52+
return info.callType
53+
}
54+
55+
// RawPayload returns the raw payload of the request.
56+
func (info *TestInfo) RawPayload() any {
57+
return info.rawPayload
58+
}
59+
60+
// Service returns the name of the service handling the request.
61+
func (info *Test3Info) Service() string {
62+
return info.service
63+
}
64+
65+
// Method returns the name of the method handling the request.
66+
func (info *Test3Info) Method() string {
67+
return info.method
68+
}
69+
70+
// CallType returns the type of call the interceptor is handling.
71+
func (info *Test3Info) CallType() goa.InterceptorCallType {
72+
return info.callType
73+
}
74+
75+
// RawPayload returns the raw payload of the request.
76+
func (info *Test3Info) RawPayload() any {
77+
return info.rawPayload
78+
}

codegen/service/testdata/interceptors/single-api-server-interceptor_service_interceptors.go.golden

+21
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,24 @@ func WrapMethod2Endpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoi
3232
return endpoint
3333
}
3434

35+
// Public accessor methods for Info types
36+
37+
// Service returns the name of the service handling the request.
38+
func (info *LoggingInfo) Service() string {
39+
return info.service
40+
}
41+
42+
// Method returns the name of the method handling the request.
43+
func (info *LoggingInfo) Method() string {
44+
return info.method
45+
}
46+
47+
// CallType returns the type of call the interceptor is handling.
48+
func (info *LoggingInfo) CallType() goa.InterceptorCallType {
49+
return info.callType
50+
}
51+
52+
// RawPayload returns the raw payload of the request.
53+
func (info *LoggingInfo) RawPayload() any {
54+
return info.rawPayload
55+
}

codegen/service/testdata/interceptors/single-client-interceptor_client_interceptors.go.golden

+21
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,24 @@ func WrapMethodClientEndpoint(endpoint goa.Endpoint, i ClientInterceptors) goa.E
2525
return endpoint
2626
}
2727

28+
// Public accessor methods for Info types
29+
30+
// Service returns the name of the service handling the request.
31+
func (info *TracingInfo) Service() string {
32+
return info.service
33+
}
34+
35+
// Method returns the name of the method handling the request.
36+
func (info *TracingInfo) Method() string {
37+
return info.method
38+
}
39+
40+
// CallType returns the type of call the interceptor is handling.
41+
func (info *TracingInfo) CallType() goa.InterceptorCallType {
42+
return info.callType
43+
}
44+
45+
// RawPayload returns the raw payload of the request.
46+
func (info *TracingInfo) RawPayload() any {
47+
return info.rawPayload
48+
}

codegen/service/testdata/interceptors/single-method-server-interceptor_service_interceptors.go.golden

+21
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,24 @@ func WrapMethodEndpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoin
2525
return endpoint
2626
}
2727

28+
// Public accessor methods for Info types
29+
30+
// Service returns the name of the service handling the request.
31+
func (info *LoggingInfo) Service() string {
32+
return info.service
33+
}
34+
35+
// Method returns the name of the method handling the request.
36+
func (info *LoggingInfo) Method() string {
37+
return info.method
38+
}
39+
40+
// CallType returns the type of call the interceptor is handling.
41+
func (info *LoggingInfo) CallType() goa.InterceptorCallType {
42+
return info.callType
43+
}
44+
45+
// RawPayload returns the raw payload of the request.
46+
func (info *LoggingInfo) RawPayload() any {
47+
return info.rawPayload
48+
}

codegen/service/testdata/interceptors/single-service-server-interceptor_service_interceptors.go.golden

+21
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,24 @@ func WrapMethod2Endpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoi
3232
return endpoint
3333
}
3434

35+
// Public accessor methods for Info types
36+
37+
// Service returns the name of the service handling the request.
38+
func (info *LoggingInfo) Service() string {
39+
return info.service
40+
}
41+
42+
// Method returns the name of the method handling the request.
43+
func (info *LoggingInfo) Method() string {
44+
return info.method
45+
}
46+
47+
// CallType returns the type of call the interceptor is handling.
48+
func (info *LoggingInfo) CallType() goa.InterceptorCallType {
49+
return info.callType
50+
}
51+
52+
// RawPayload returns the raw payload of the request.
53+
func (info *LoggingInfo) RawPayload() any {
54+
return info.rawPayload
55+
}

codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload-and-read-streaming-payload_interceptor_wrappers.go.golden

+10-13
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,8 @@ type wrappedMethodClientStream struct {
1919
// wrapLoggingMethod applies the logging server interceptor to endpoints.
2020
func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint {
2121
return func(ctx context.Context, req any) (any, error) {
22-
info := &LoggingInfo{
23-
service: "StreamingInterceptorsWithReadPayloadAndReadStreamingPayload",
24-
method: "Method",
25-
callType: goa.InterceptorUnary,
26-
rawPayload: req,
27-
}
28-
res, err := i.Logging(ctx, info, endpoint)
29-
if err != nil {
30-
return res, err
31-
}
32-
stream := res.(MethodServerStream)
33-
return &wrappedMethodServerStream{
22+
stream := req.(*MethodEndpointInput).Stream
23+
req.(*MethodEndpointInput).Stream = &wrappedMethodServerStream{
3424
ctx: ctx,
3525
recvWithContext: func(ctx context.Context) (*MethodStreamingPayload, error) {
3626
info := &LoggingInfo{
@@ -45,7 +35,14 @@ func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint
4535
return castRes, err
4636
},
4737
stream: stream,
48-
}, nil
38+
}
39+
info := &LoggingInfo{
40+
service: "StreamingInterceptorsWithReadPayloadAndReadStreamingPayload",
41+
method: "Method",
42+
callType: goa.InterceptorUnary,
43+
rawPayload: req,
44+
}
45+
return i.Logging(ctx, info, endpoint)
4946
}
5047
}
5148

codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload-and-read-streaming-payload_service_interceptors.go.golden

+6-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,12 @@ func (info *LoggingInfo) RawPayload() any {
7373

7474
// Payload returns a type-safe accessor for the method payload.
7575
func (info *LoggingInfo) Payload() LoggingPayload {
76-
return &loggingMethodPayload{payload: info.RawPayload().(*MethodPayload)}
76+
switch pay := info.RawPayload().(type) {
77+
case *MethodEndpointInput:
78+
return &loggingMethodPayload{payload: pay.Payload}
79+
default:
80+
return &loggingMethodPayload{payload: pay.(*MethodPayload)}
81+
}
7782
}
7883

7984
// ClientStreamingPayload returns a type-safe accessor for the method streaming payload for a client-side interceptor.

codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload_service_interceptors.go.golden

+6-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@ func (info *LoggingInfo) RawPayload() any {
6363

6464
// Payload returns a type-safe accessor for the method payload.
6565
func (info *LoggingInfo) Payload() LoggingPayload {
66-
return &loggingMethodPayload{payload: info.RawPayload().(*MethodPayload)}
66+
switch pay := info.RawPayload().(type) {
67+
case *MethodEndpointInput:
68+
return &loggingMethodPayload{payload: pay.Payload}
69+
default:
70+
return &loggingMethodPayload{payload: pay.(*MethodPayload)}
71+
}
6772
}
6873

6974
// Private implementation methods

0 commit comments

Comments
 (0)