@@ -377,6 +377,77 @@ func TestForwardResponseMessage(t *testing.T) {
377
377
}
378
378
379
379
func TestOutgoingHeaderMatcher (t * testing.T ) {
380
+ t .Parallel ()
381
+ msg := & pb.SimpleMessage {Id : "foo" }
382
+ for _ , tc := range []struct {
383
+ name string
384
+ md runtime.ServerMetadata
385
+ headers http.Header
386
+ matcher runtime.HeaderMatcherFunc
387
+ }{
388
+ {
389
+ name : "default matcher" ,
390
+ md : runtime.ServerMetadata {
391
+ HeaderMD : metadata .Pairs (
392
+ "foo" , "bar" ,
393
+ "baz" , "qux" ,
394
+ ),
395
+ },
396
+ headers : http.Header {
397
+ "Content-Type" : []string {"application/json" },
398
+ "Grpc-Metadata-Foo" : []string {"bar" },
399
+ "Grpc-Metadata-Baz" : []string {"qux" },
400
+ },
401
+ },
402
+ {
403
+ name : "custom matcher" ,
404
+ md : runtime.ServerMetadata {
405
+ HeaderMD : metadata .Pairs (
406
+ "foo" , "bar" ,
407
+ "baz" , "qux" ,
408
+ ),
409
+ },
410
+ headers : http.Header {
411
+ "Content-Type" : []string {"application/json" },
412
+ "Custom-Foo" : []string {"bar" },
413
+ },
414
+ matcher : func (key string ) (string , bool ) {
415
+ switch key {
416
+ case "foo" :
417
+ return "custom-foo" , true
418
+ default :
419
+ return "" , false
420
+ }
421
+ },
422
+ },
423
+ } {
424
+ tc := tc
425
+ t .Run (tc .name , func (t * testing.T ) {
426
+ t .Parallel ()
427
+ ctx := runtime .NewServerMetadataContext (context .Background (), tc .md )
428
+
429
+ req := httptest .NewRequest ("GET" , "http://example.com/foo" , nil )
430
+ resp := httptest .NewRecorder ()
431
+
432
+ mux := runtime .NewServeMux (
433
+ runtime .WithOutgoingHeaderMatcher (tc .matcher ),
434
+ )
435
+ runtime .ForwardResponseMessage (ctx , mux , & runtime.JSONPb {}, resp , req , msg )
436
+
437
+ w := resp .Result ()
438
+ defer w .Body .Close ()
439
+ if w .StatusCode != http .StatusOK {
440
+ t .Fatalf ("StatusCode %d want %d" , w .StatusCode , http .StatusOK )
441
+ }
442
+
443
+ if ! reflect .DeepEqual (w .Header , tc .headers ) {
444
+ t .Fatalf ("Header %v want %v" , w .Header , tc .headers )
445
+ }
446
+ })
447
+ }
448
+ }
449
+
450
+ func TestOutgoingHeaderMatcherWithContentLength (t * testing.T ) {
380
451
t .Parallel ()
381
452
msg := & pb.SimpleMessage {Id : "foo" }
382
453
for _ , tc := range []struct {
@@ -431,7 +502,11 @@ func TestOutgoingHeaderMatcher(t *testing.T) {
431
502
req := httptest .NewRequest ("GET" , "http://example.com/foo" , nil )
432
503
resp := httptest .NewRecorder ()
433
504
434
- runtime .ForwardResponseMessage (ctx , runtime .NewServeMux (runtime .WithOutgoingHeaderMatcher (tc .matcher )), & runtime.JSONPb {}, resp , req , msg )
505
+ mux := runtime .NewServeMux (
506
+ runtime .WithOutgoingHeaderMatcher (tc .matcher ),
507
+ runtime .WithWriteContentLength (),
508
+ )
509
+ runtime .ForwardResponseMessage (ctx , mux , & runtime.JSONPb {}, resp , req , msg )
435
510
436
511
w := resp .Result ()
437
512
defer w .Body .Close ()
@@ -529,7 +604,11 @@ func TestOutgoingTrailerMatcher(t *testing.T) {
529
604
req .Header = tc .caller
530
605
resp := httptest .NewRecorder ()
531
606
532
- runtime .ForwardResponseMessage (ctx , runtime .NewServeMux (runtime .WithOutgoingTrailerMatcher (tc .matcher )), & runtime.JSONPb {}, resp , req , msg )
607
+ mux := runtime .NewServeMux (
608
+ runtime .WithOutgoingTrailerMatcher (tc .matcher ),
609
+ runtime .WithWriteContentLength (),
610
+ )
611
+ runtime .ForwardResponseMessage (ctx , mux , & runtime.JSONPb {}, resp , req , msg )
533
612
534
613
w := resp .Result ()
535
614
_ , _ = io .Copy (io .Discard , w .Body )
0 commit comments