Skip to content

Commit 5dfd063

Browse files
authored
Only write Content-Length if the runtime.WithWriteContentLength() option is specified (#5151)
1 parent e1364b5 commit 5dfd063

File tree

3 files changed

+90
-3
lines changed

3 files changed

+90
-3
lines changed

runtime/handler.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha
196196
return
197197
}
198198

199-
if !doForwardTrailers {
199+
if !doForwardTrailers && mux.writeContentLength {
200200
w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
201201
}
202202

runtime/handler_test.go

+81-2
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,77 @@ func TestForwardResponseMessage(t *testing.T) {
377377
}
378378

379379
func TestOutgoingHeaderMatcher(t *testing.T) {
380+
t.Parallel()
381+
msg := &pb.SimpleMessage{Id: "foo"}
382+
for _, tc := range []struct {
383+
name string
384+
md runtime.ServerMetadata
385+
headers http.Header
386+
matcher runtime.HeaderMatcherFunc
387+
}{
388+
{
389+
name: "default matcher",
390+
md: runtime.ServerMetadata{
391+
HeaderMD: metadata.Pairs(
392+
"foo", "bar",
393+
"baz", "qux",
394+
),
395+
},
396+
headers: http.Header{
397+
"Content-Type": []string{"application/json"},
398+
"Grpc-Metadata-Foo": []string{"bar"},
399+
"Grpc-Metadata-Baz": []string{"qux"},
400+
},
401+
},
402+
{
403+
name: "custom matcher",
404+
md: runtime.ServerMetadata{
405+
HeaderMD: metadata.Pairs(
406+
"foo", "bar",
407+
"baz", "qux",
408+
),
409+
},
410+
headers: http.Header{
411+
"Content-Type": []string{"application/json"},
412+
"Custom-Foo": []string{"bar"},
413+
},
414+
matcher: func(key string) (string, bool) {
415+
switch key {
416+
case "foo":
417+
return "custom-foo", true
418+
default:
419+
return "", false
420+
}
421+
},
422+
},
423+
} {
424+
tc := tc
425+
t.Run(tc.name, func(t *testing.T) {
426+
t.Parallel()
427+
ctx := runtime.NewServerMetadataContext(context.Background(), tc.md)
428+
429+
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
430+
resp := httptest.NewRecorder()
431+
432+
mux := runtime.NewServeMux(
433+
runtime.WithOutgoingHeaderMatcher(tc.matcher),
434+
)
435+
runtime.ForwardResponseMessage(ctx, mux, &runtime.JSONPb{}, resp, req, msg)
436+
437+
w := resp.Result()
438+
defer w.Body.Close()
439+
if w.StatusCode != http.StatusOK {
440+
t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
441+
}
442+
443+
if !reflect.DeepEqual(w.Header, tc.headers) {
444+
t.Fatalf("Header %v want %v", w.Header, tc.headers)
445+
}
446+
})
447+
}
448+
}
449+
450+
func TestOutgoingHeaderMatcherWithContentLength(t *testing.T) {
380451
t.Parallel()
381452
msg := &pb.SimpleMessage{Id: "foo"}
382453
for _, tc := range []struct {
@@ -431,7 +502,11 @@ func TestOutgoingHeaderMatcher(t *testing.T) {
431502
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
432503
resp := httptest.NewRecorder()
433504

434-
runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingHeaderMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)
505+
mux := runtime.NewServeMux(
506+
runtime.WithOutgoingHeaderMatcher(tc.matcher),
507+
runtime.WithWriteContentLength(),
508+
)
509+
runtime.ForwardResponseMessage(ctx, mux, &runtime.JSONPb{}, resp, req, msg)
435510

436511
w := resp.Result()
437512
defer w.Body.Close()
@@ -529,7 +604,11 @@ func TestOutgoingTrailerMatcher(t *testing.T) {
529604
req.Header = tc.caller
530605
resp := httptest.NewRecorder()
531606

532-
runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingTrailerMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)
607+
mux := runtime.NewServeMux(
608+
runtime.WithOutgoingTrailerMatcher(tc.matcher),
609+
runtime.WithWriteContentLength(),
610+
)
611+
runtime.ForwardResponseMessage(ctx, mux, &runtime.JSONPb{}, resp, req, msg)
533612

534613
w := resp.Result()
535614
_, _ = io.Copy(io.Discard, w.Body)

runtime/mux.go

+8
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ type ServeMux struct {
7171
routingErrorHandler RoutingErrorHandlerFunc
7272
disablePathLengthFallback bool
7373
unescapingMode UnescapingMode
74+
writeContentLength bool
7475
}
7576

7677
// ServeMuxOption is an option that can be given to a ServeMux on construction.
@@ -258,6 +259,13 @@ func WithDisablePathLengthFallback() ServeMuxOption {
258259
}
259260
}
260261

262+
// WithWriteContentLength returns a ServeMuxOption to enable writing content length on non-streaming responses
263+
func WithWriteContentLength() ServeMuxOption {
264+
return func(serveMux *ServeMux) {
265+
serveMux.writeContentLength = true
266+
}
267+
}
268+
261269
// WithHealthEndpointAt returns a ServeMuxOption that will add an endpoint to the created ServeMux at the path specified by endpointPath.
262270
// When called the handler will forward the request to the upstream grpc service health check (defined in the
263271
// gRPC Health Checking Protocol).

0 commit comments

Comments
 (0)