Skip to content

Commit a1b0988

Browse files
authored
feat: Add WithForwardResponseRewriter to allow easier/more useful response control (#4622)
## Context/Background I am working with grpc-gateway to mimick an older REST API while implementing the core as a gRPC server. This REST API, by convention, emits a response envelope. After extensively researching grpc-gateway's code and options, I think that there is not a good enough way to address my use-case; for a few more nuanced reasons. The sum of my requirements are as follows: - 1) Allow a particular field of a response message to be used as the main response. ✅ This is handled with `response_body` annotation. - 2) Be able to run interface checks on the response to extract useful information, like `next_page_token` [0] and surface it in the final response envelope. (`npt, ok := resp.(interface { GetNextPageToken() string })`). - 3) Take the true result and place it into a response envelope along with other parts of the response by convention and let that be encoded and sent as the response instead. ### Implementing a response envelope with `Marshaler` My first attempt at getting my gRPC server's responses in an envelope led me to implement my own Marshaler, I have seen this approach discussed in #4483. This does satisfy requirements 1 and 3 just fine, since the HTTP annotations helpfully allow the code to only receive the true result, and the Marshal interface has enough capabilities to take that and wrap it in a response envelope. However, requirements 1 and 2 are not _both_ satisfiable with the current grpc-gateway code because of how the `XXX_ResponseBody()` is called _before_ passing to the `Marshal(v)` function. This strips out the other fields that I would normally be able to detect and place in the response envelope. I even tried creating my _own_ protobuf message extension that would let me define another way of defining the "true result" field. But the options for implementing that are either a _ton_ of protoreflect at runtime to detect and extract that, or I am writing another protobuf generator plugin (which I have done before [1]), but both of those options seem quite complex. ### Other non-starter options Just to get ahead of the discussion, `WithForwardResponseOption` clearly was not meant for this use-case. At best, it seems to only be a way to take information that might be in the response and add it as a header. [0]: https://google.aip.dev/158#:~:text=Response%20messages%20for%20collections%20should%20define%20a%20string%20next_page_token%20field [1]: https://github.com/nkcmr/protoc-gen-twirp_js ### In practice This change fulfills my requirements by allowing logic to be inserted right before the Marshal is called: ```go gatewayMux := runtime.NewServeMux( runtime.WithForwardResponseRewriter(func(ctx context.Context, response proto.Message) (interface{}, error) { if s, ok := response.(*statuspb.Status); ok { return rewriteStatusToErrorEnvelope(ctx, s) } return rewriteResultToEnvelope(ctx, response) }), ) ``` ## In this PR This PR introduces a new `ServeMuxOption` called `WithForwardResponseRewriter` that allows for a user-provided function to be supplied that can take a response `proto.Message` and return `any` during unary response forwarding, stream response forwarding, and error response forwarding. The code generation was also updated to make the `XXX_ResponseBody()` response wrappers embed the concrete type instead of just `proto.Message`. This allows any code in response rewriter functions to be able to have access to the original type, so that interface checks against it should pass as if it was the original message. Updated the "Customizing Your Gateway" documentation to use `WithForwardResponseRewriter` in the `Fully Overriding Custom HTTP Responses` sections. ## Testing Added some basic unit tests to ensure Unary/Stream and error handlers invoke `ForwardResponseRewriter` correctly.
1 parent 169370b commit a1b0988

File tree

8 files changed

+156
-74
lines changed

8 files changed

+156
-74
lines changed

docs/docs/mapping/customizing_your_gateway.md

+14-17
Original file line numberDiff line numberDiff line change
@@ -324,15 +324,15 @@ First, set up the gRPC-Gateway with the custom options:
324324

325325
```go
326326
mux := runtime.NewServeMux(
327-
runtime.WithMarshalerOption(runtime.MIMEWildcard, &ResponseWrapper{}),
328-
runtime.WithForwardResponseOption(forwardResponse),
327+
runtime.WithForwardResponseOption(setStatus),
328+
runtime.WithForwardResponseRewriter(responseEnvelope),
329329
)
330330
```
331331

332-
Define the `forwardResponse` function to handle specific response types:
332+
Define the `setStatus` function to handle specific response types:
333333

334334
```go
335-
func forwardResponse(ctx context.Context, w http.ResponseWriter, m protoreflect.ProtoMessage) error {
335+
func setStatus(ctx context.Context, w http.ResponseWriter, m protoreflect.ProtoMessage) error {
336336
switch v := m.(type) {
337337
case *pb.CreateUserResponse:
338338
w.WriteHeader(http.StatusCreated)
@@ -342,32 +342,29 @@ func forwardResponse(ctx context.Context, w http.ResponseWriter, m protoreflect.
342342
}
343343
```
344344

345-
Create a custom marshaler to format the response data which utilizes the `JSONPb` marshaler as a fallback:
345+
Define the `responseEnvelope` function to rewrite the response to a different type/shape:
346346

347347
```go
348-
type ResponseWrapper struct {
349-
runtime.JSONPb
350-
}
351-
352-
func (c *ResponseWrapper) Marshal(data any) ([]byte, error) {
353-
resp := data
348+
func responseEnvelope(_ context.Context, response proto.Message) (interface{}, error) {
354349
switch v := data.(type) {
355350
case *pb.CreateUserResponse:
356351
// wrap the response in a custom structure
357-
resp = map[string]any{
352+
return map[string]any{
358353
"success": true,
359354
"data": data,
360-
}
355+
}, nil
361356
}
362-
// otherwise, use the default JSON marshaller
363-
return c.JSONPb.Marshal(resp)
357+
return response, nil
364358
}
365359
```
366360

367361
In this setup:
368362

369-
- The `forwardResponse` function intercepts the response and formats it as needed.
370-
- The `CustomPB` marshaller ensures that specific types of responses are wrapped in a custom structure before being sent to the client.
363+
- The `setStatus` function intercepts the response and uses its type to send `201 Created` only when it sees `*pb.CreateUserResponse`.
364+
- The `responseEnvelope` function ensures that specific types of responses are wrapped in a custom structure before being sent to the client.
365+
366+
**NOTE:** Using `WithForwardResponseRewriter` is partially incompatible with OpenAPI annotations. Because response
367+
rewriting happens at runtime, it is not possible to represent that in `protoc-gen-openapiv2` output.
371368

372369
## Error handler
373370

examples/internal/proto/examplepb/response_body_service.pb.gw.go

+14-18
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

protoc-gen-grpc-gateway/internal/gengateway/template.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Server(ctx context.Context,
654654
}
655655
656656
{{ if $b.ResponseBody }}
657-
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
657+
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})}, mux.GetForwardResponseOptions()...)
658658
{{ else }}
659659
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
660660
{{end}}
@@ -744,7 +744,7 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context,
744744
{{end}}
745745
{{else}}
746746
{{ if $b.ResponseBody }}
747-
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
747+
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})}, mux.GetForwardResponseOptions()...)
748748
{{ else }}
749749
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
750750
{{end}}
@@ -759,12 +759,11 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context,
759759
{{range $b := $m.Bindings}}
760760
{{if $b.ResponseBody}}
761761
type response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} struct {
762-
proto.Message
762+
*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}}
763763
}
764764
765765
func (m response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}) XXX_ResponseBody() interface{} {
766-
response := m.Message.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})
767-
return {{$b.ResponseBody.AssignableExpr "response" $m.Service.File.GoPkg.Path}}
766+
return {{$b.ResponseBody.AssignableExpr "m" $m.Service.File.GoPkg.Path}}
768767
}
769768
{{end}}
770769
{{end}}

runtime/errors.go

+13-3
Original file line numberDiff line numberDiff line change
@@ -93,26 +93,36 @@ func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.R
9393
func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
9494
// return Internal when Marshal failed
9595
const fallback = `{"code": 13, "message": "failed to marshal error message"}`
96+
const fallbackRewriter = `{"code": 13, "message": "failed to rewrite error message"}`
9697

9798
var customStatus *HTTPStatusError
9899
if errors.As(err, &customStatus) {
99100
err = customStatus.Err
100101
}
101102

102103
s := status.Convert(err)
103-
pb := s.Proto()
104104

105105
w.Header().Del("Trailer")
106106
w.Header().Del("Transfer-Encoding")
107107

108-
contentType := marshaler.ContentType(pb)
108+
respRw, err := mux.forwardResponseRewriter(ctx, s.Proto())
109+
if err != nil {
110+
grpclog.Errorf("Failed to rewrite error message %q: %v", s, err)
111+
w.WriteHeader(http.StatusInternalServerError)
112+
if _, err := io.WriteString(w, fallbackRewriter); err != nil {
113+
grpclog.Errorf("Failed to write response: %v", err)
114+
}
115+
return
116+
}
117+
118+
contentType := marshaler.ContentType(respRw)
109119
w.Header().Set("Content-Type", contentType)
110120

111121
if s.Code() == codes.Unauthenticated {
112122
w.Header().Set("WWW-Authenticate", s.Message())
113123
}
114124

115-
buf, merr := marshaler.Marshal(pb)
125+
buf, merr := marshaler.Marshal(respRw)
116126
if merr != nil {
117127
grpclog.Errorf("Failed to marshal error message %q: %v", s, merr)
118128
w.WriteHeader(http.StatusInternalServerError)

runtime/errors_test.go

+36-11
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
statuspb "google.golang.org/genproto/googleapis/rpc/status"
1515
"google.golang.org/grpc/codes"
1616
"google.golang.org/grpc/status"
17+
"google.golang.org/protobuf/proto"
1718
)
1819

1920
func TestDefaultHTTPError(t *testing.T) {
@@ -24,12 +25,14 @@ func TestDefaultHTTPError(t *testing.T) {
2425
)
2526

2627
for i, spec := range []struct {
27-
err error
28-
status int
29-
msg string
30-
marshaler runtime.Marshaler
31-
contentType string
32-
details string
28+
err error
29+
status int
30+
msg string
31+
marshaler runtime.Marshaler
32+
contentType string
33+
details string
34+
fordwardRespRewriter runtime.ForwardResponseRewriter
35+
extractMessage func(*testing.T)
3336
}{
3437
{
3538
err: errors.New("example error"),
@@ -70,23 +73,45 @@ func TestDefaultHTTPError(t *testing.T) {
7073
contentType: "application/json",
7174
msg: "Method Not Allowed",
7275
},
76+
{
77+
err: status.Error(codes.InvalidArgument, "example error"),
78+
status: http.StatusBadRequest,
79+
marshaler: &runtime.JSONPb{},
80+
contentType: "application/json",
81+
msg: "bad request: example error",
82+
fordwardRespRewriter: func(ctx context.Context, response proto.Message) (any, error) {
83+
if s, ok := response.(*statuspb.Status); ok && strings.HasPrefix(s.Message, "example") {
84+
return &statuspb.Status{
85+
Code: s.Code,
86+
Message: "bad request: " + s.Message,
87+
Details: s.Details,
88+
}, nil
89+
}
90+
return response, nil
91+
},
92+
},
7393
} {
7494
t.Run(strconv.Itoa(i), func(t *testing.T) {
7595
w := httptest.NewRecorder()
7696
req, _ := http.NewRequestWithContext(ctx, "", "", nil) // Pass in an empty request to match the signature
77-
mux := runtime.NewServeMux()
78-
marshaler := &runtime.JSONPb{}
79-
runtime.HTTPError(ctx, mux, marshaler, w, req, spec.err)
8097

81-
if got, want := w.Header().Get("Content-Type"), "application/json"; got != want {
98+
opts := []runtime.ServeMuxOption{}
99+
if spec.fordwardRespRewriter != nil {
100+
opts = append(opts, runtime.WithForwardResponseRewriter(spec.fordwardRespRewriter))
101+
}
102+
mux := runtime.NewServeMux(opts...)
103+
104+
runtime.HTTPError(ctx, mux, spec.marshaler, w, req, spec.err)
105+
106+
if got, want := w.Header().Get("Content-Type"), spec.contentType; got != want {
82107
t.Errorf(`w.Header().Get("Content-Type") = %q; want %q; on spec.err=%v`, got, want, spec.err)
83108
}
84109
if got, want := w.Code, spec.status; got != want {
85110
t.Errorf("w.Code = %d; want %d", got, want)
86111
}
87112

88113
var st statuspb.Status
89-
if err := marshaler.Unmarshal(w.Body.Bytes(), &st); err != nil {
114+
if err := spec.marshaler.Unmarshal(w.Body.Bytes(), &st); err != nil {
90115
t.Errorf("marshaler.Unmarshal(%q, &body) failed with %v; want success", w.Body.Bytes(), err)
91116
return
92117
}

runtime/handler.go

+20-8
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,27 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
5656
return
5757
}
5858

59+
respRw, err := mux.forwardResponseRewriter(ctx, resp)
60+
if err != nil {
61+
grpclog.Errorf("Rewrite error: %v", err)
62+
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
63+
return
64+
}
65+
5966
if !wroteHeader {
60-
w.Header().Set("Content-Type", marshaler.ContentType(resp))
67+
w.Header().Set("Content-Type", marshaler.ContentType(respRw))
6168
}
6269

6370
var buf []byte
64-
httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
71+
httpBody, isHTTPBody := respRw.(*httpbody.HttpBody)
6572
switch {
66-
case resp == nil:
73+
case respRw == nil:
6774
buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
6875
case isHTTPBody:
6976
buf = httpBody.GetData()
7077
default:
71-
result := map[string]interface{}{"result": resp}
72-
if rb, ok := resp.(responseBody); ok {
78+
result := map[string]interface{}{"result": respRw}
79+
if rb, ok := respRw.(responseBody); ok {
7380
result["result"] = rb.XXX_ResponseBody()
7481
}
7582

@@ -165,12 +172,17 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha
165172
HTTPError(ctx, mux, marshaler, w, req, err)
166173
return
167174
}
175+
respRw, err := mux.forwardResponseRewriter(ctx, resp)
176+
if err != nil {
177+
grpclog.Errorf("Rewrite error: %v", err)
178+
HTTPError(ctx, mux, marshaler, w, req, err)
179+
return
180+
}
168181
var buf []byte
169-
var err error
170-
if rb, ok := resp.(responseBody); ok {
182+
if rb, ok := respRw.(responseBody); ok {
171183
buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
172184
} else {
173-
buf, err = marshaler.Marshal(resp)
185+
buf, err = marshaler.Marshal(respRw)
174186
}
175187
if err != nil {
176188
grpclog.Errorf("Marshal error: %v", err)

0 commit comments

Comments
 (0)