From d792a6f64104a204c852188e2e7437222e07f04b Mon Sep 17 00:00:00 2001 From: qdongxu Date: Sun, 29 Jun 2025 19:45:39 +0800 Subject: [PATCH 1/3] 1. record replaer "http.request.size" 2. make metrics records accurate reqeust size by read from req.Body. ( metrics read from Content-Length field, which not applicable to web socket or chunked Body) 3. cache RequestRecorder to handle records in multiple moudles. ( logRequest and metrics creates its own ResponseRecorder.) --- modules/caddyhttp/intercept/intercept.go | 2 +- modules/caddyhttp/metrics.go | 20 ++++--- modules/caddyhttp/responsewriter.go | 73 ++++++++++++++++++------ modules/caddyhttp/responsewriter_test.go | 36 +++++++++++- modules/caddyhttp/server.go | 52 ++++------------- modules/caddyhttp/server_test.go | 25 ++++---- modules/caddyhttp/templates/templates.go | 2 +- 7 files changed, 127 insertions(+), 83 deletions(-) diff --git a/modules/caddyhttp/intercept/intercept.go b/modules/caddyhttp/intercept/intercept.go index cb23adf0a2b..ff6c2c9dda4 100644 --- a/modules/caddyhttp/intercept/intercept.go +++ b/modules/caddyhttp/intercept/intercept.go @@ -131,7 +131,7 @@ func (ir Intercept) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) rec := interceptedResponseHandler{replacer: repl} - rec.ResponseRecorder = caddyhttp.NewResponseRecorder(w, buf, func(status int, header http.Header) bool { + rec.ResponseRecorder, _ = caddyhttp.NewResponseRecorder(w, &r, buf, func(status int, header http.Header) bool { // see if any response handler is configured for this original response for i, rh := range ir.HandleResponse { if rh.Match != nil && !rh.Match.Match(status, header) { diff --git a/modules/caddyhttp/metrics.go b/modules/caddyhttp/metrics.go index 9bb97e0b47b..e0d935bba47 100644 --- a/modules/caddyhttp/metrics.go +++ b/modules/caddyhttp/metrics.go @@ -153,8 +153,14 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re h.metrics.httpMetrics.responseDuration.With(statusLabels).Observe(ttfb) return false }) - wrec := NewResponseRecorder(w, nil, writeHeaderRecorder) - err := h.mh.ServeHTTP(wrec, r, next) + + wrec, cached := NewResponseRecorder(w, &r, nil, writeHeaderRecorder) + if !cached { + w = wrec + } + + err := h.mh.ServeHTTP(w, r, next) + dur := time.Since(start).Seconds() h.metrics.httpMetrics.requestCount.With(labels).Inc() @@ -168,7 +174,7 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re } h.metrics.httpMetrics.requestDuration.With(statusLabels).Observe(dur) - h.metrics.httpMetrics.requestSize.With(statusLabels).Observe(float64(computeApproximateRequestSize(r))) + h.metrics.httpMetrics.requestSize.With(statusLabels).Observe(float64(computeApproximateRequestSize(wrec, r))) h.metrics.httpMetrics.responseSize.With(statusLabels).Observe(float64(wrec.Size())) } @@ -189,7 +195,7 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re } // taken from https://github.com/prometheus/client_golang/blob/6007b2b5cae01203111de55f753e76d8dac1f529/prometheus/promhttp/instrument_server.go#L298 -func computeApproximateRequestSize(r *http.Request) int { +func computeApproximateRequestSize(wrec ResponseRecorder, r *http.Request) int { s := 0 if r.URL != nil { s += len(r.URL.String()) @@ -205,10 +211,6 @@ func computeApproximateRequestSize(r *http.Request) int { } s += len(r.Host) - // N.B. r.Form and r.MultipartForm are assumed to be included in r.URL. - - if r.ContentLength != -1 { - s += int(r.ContentLength) - } + s += wrec.RequestSize() return s } diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 904c30c0352..4256280a161 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -17,6 +17,7 @@ package caddyhttp import ( "bufio" "bytes" + `context` "fmt" "io" "net" @@ -71,12 +72,14 @@ type responseRecorder struct { wroteHeader bool stream bool - readSize *int + reqBodyLengthReader lengthReader } // NewResponseRecorder returns a new ResponseRecorder that can be -// used instead of a standard http.ResponseWriter. The recorder is -// useful for middlewares which need to buffer a response and +// used instead of a standard http.ResponseWriter. +// The ResponseRecorder will be cached in the http.Request.Context to avoid +// multiple instances created by different Modules in a pipeline of http handlers. +// The recorder is useful for middlewares which need to buffer a response and // potentially process its entire body before actually writing the // response to the underlying writer. Of course, buffering the entire // body has a memory overhead, but sometimes there is no way to avoid @@ -101,8 +104,12 @@ type responseRecorder struct { // // Proper usage of a recorder looks like this: // -// rec := caddyhttp.NewResponseRecorder(w, buf, shouldBuffer) -// err := next.ServeHTTP(rec, req) +// rec, cached := caddyhttp.NewResponseRecorder(w, &req, buf, shouldBuffer) +// if !cached { +// w = rec +// } +// +// err := next.ServeHTTP(w, req) // do not replace rec if got from a cached one // if err != nil { // return err // } @@ -134,12 +141,27 @@ type responseRecorder struct { // As a special case, 1xx responses are not buffered nor recorded // because they are not the final response; they are passed through // directly to the underlying ResponseWriter. -func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer ShouldBufferFunc) ResponseRecorder { - return &responseRecorder{ +func NewResponseRecorder(w http.ResponseWriter, req **http.Request, buf *bytes.Buffer, + shouldBuffer ShouldBufferFunc) (wrec ResponseRecorder, cached bool) { + r := *req + if wrec, ok := r.Context().Value(ResponseRecorderVarKey).(ResponseRecorder); ok { + return wrec, true + } + + rr := &responseRecorder{ ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: w}, buf: buf, shouldBuffer: shouldBuffer, + reqBodyLengthReader: lengthReader{}, + } + if r.Body != nil { + rr.reqBodyLengthReader.Source = r.Body + r.Body = &rr.reqBodyLengthReader } + + c := context.WithValue(r.Context(), ResponseRecorderVarKey, rr) + *req = r.WithContext(c) + return rr, false } // WriteHeader writes the headers with statusCode to the wrapped @@ -211,6 +233,12 @@ func (rr *responseRecorder) Size() int { return rr.size } +// RequestSize returns the number of bytes read from the Request, +// not including the request headers. +func (rr *responseRecorder) RequestSize() int { + return rr.reqBodyLengthReader.Length +} + // Buffer returns the body buffer that rr was created with. // You should still have your original pointer, though. func (rr *responseRecorder) Buffer() *bytes.Buffer { @@ -246,12 +274,6 @@ func (rr *responseRecorder) FlushError() error { return nil } -// Private interface so it can only be used in this package -// #TODO: maybe export it later -func (rr *responseRecorder) setReadSize(size *int) { - rr.readSize = size -} - func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { //nolint:bodyclose conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack() @@ -282,9 +304,7 @@ type hijackedConn struct { } func (hc *hijackedConn) updateReadSize(n int) { - if hc.rr.readSize != nil { - *hc.rr.readSize += n - } + hc.rr.reqBodyLengthReader.Length += n } func (hc *hijackedConn) Read(p []byte) (int, error) { @@ -320,6 +340,7 @@ type ResponseRecorder interface { Buffer() *bytes.Buffer Buffered() bool Size() int + RequestSize() int WriteResponse() error } @@ -342,3 +363,23 @@ var ( _ io.WriterTo = (*hijackedConn)(nil) ) + +// lengthReader is an io.ReadCloser that keeps track of the +// number of bytes read from the request body. +// This wrapper is for http request process only. If the underlying +// conn hijacked by a websocket session. ResponseRecorder will +// update the Length field. +type lengthReader struct { + Source io.ReadCloser + Length int +} + +func (r *lengthReader) Read(b []byte) (int, error) { + n, err := r.Source.Read(b) + r.Length += n + return n, err +} + +func (r *lengthReader) Close() error { + return r.Source.Close() +} diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go index c08ad26a472..1aa0a141447 100644 --- a/modules/caddyhttp/responsewriter_test.go +++ b/modules/caddyhttp/responsewriter_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net/http" + `net/http/httptest` "strings" "testing" ) @@ -104,26 +105,31 @@ func TestResponseWriterWrapperUnwrap(t *testing.T) { func TestResponseRecorderReadFrom(t *testing.T) { tests := map[string]struct { responseWriter responseWriterSpy + req *http.Request shouldBuffer bool wantReadFrom bool }{ "buffered plain": { responseWriter: &baseRespWriter{}, + req: &http.Request{}, shouldBuffer: true, wantReadFrom: false, }, "streamed plain": { responseWriter: &baseRespWriter{}, + req: &http.Request{}, shouldBuffer: false, wantReadFrom: false, }, "buffered ReadFrom": { responseWriter: &readFromRespWriter{}, + req: &http.Request{}, shouldBuffer: true, wantReadFrom: false, }, "streamed ReadFrom": { responseWriter: &readFromRespWriter{}, + req: &http.Request{}, shouldBuffer: false, wantReadFrom: true, }, @@ -132,7 +138,7 @@ func TestResponseRecorderReadFrom(t *testing.T) { t.Run(name, func(t *testing.T) { var buf bytes.Buffer - rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool { + rr, _ := NewResponseRecorder(tt.responseWriter, &tt.req, &buf, func(status int, header http.Header) bool { return tt.shouldBuffer }) @@ -169,3 +175,31 @@ func TestResponseRecorderReadFrom(t *testing.T) { }) } } + +func TestCachedResponseRecorder(t *testing.T) { + r := httptest.NewRequest("GET", "http://example.com/foo", nil) + rOld := r + w := &ResponseWriterWrapper{&baseRespWriter{}} + wrec1, cached := NewResponseRecorder(w, &r, nil, nil) + if cached { + t.Errorf("NewResponseRecorder() should not have been called and cached ") + } + + if rOld == r { + t.Errorf("r should be different from rOld") + } + + rOld = r + wrec2, cached := NewResponseRecorder(w, &r, nil, nil) + if !cached { + t.Errorf("NewResponseRecorder() has been caleed and should be cached ") + } + + if rOld != r { + t.Errorf("r should be identical as from rOld") + } + + if wrec1 != wrec2 { + t.Errorf("NewResponseRecorder() should be identical since it was cached") + } +} diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index a2b29d65831..c28600a099f 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -19,7 +19,6 @@ import ( "crypto/tls" "encoding/json" "fmt" - "io" "net" "net/http" "net/netip" @@ -335,26 +334,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { var duration time.Duration if s.shouldLogRequest(r) { - wrec := NewResponseRecorder(w, nil, nil) - w = wrec - - // wrap the request body in a LengthReader - // so we can track the number of bytes read from it - var bodyReader *lengthReader - if r.Body != nil { - bodyReader = &lengthReader{Source: r.Body} - r.Body = bodyReader - - // should always be true, private interface can only be referenced in the same package - if setReadSizer, ok := wrec.(interface{ setReadSize(*int) }); ok { - setReadSizer.setReadSize(&bodyReader.Length) - } + wrec, cached := NewResponseRecorder(w, &r, nil, nil) + if !cached { + w = wrec } // capture the original version of the request accLog := s.accessLogger.With(loggableReq) - defer s.logRequest(accLog, r, wrec, &duration, repl, bodyReader, shouldLogCredentials) + defer s.logRequest(accLog, r, wrec, &duration, repl, shouldLogCredentials) } start := time.Now() @@ -771,7 +759,7 @@ func (s *Server) logTrace(mh MiddlewareHandler) { // logRequest logs the request to access logs, unless skipped. func (s *Server) logRequest( accLog *zap.Logger, r *http.Request, wrec ResponseRecorder, duration *time.Duration, - repl *caddy.Replacer, bodyReader *lengthReader, shouldLogCredentials bool, + repl *caddy.Replacer, shouldLogCredentials bool, ) { // this request may be flagged as omitted from the logs if skip, ok := GetVar(r.Context(), LogSkipVar).(bool); ok && skip { @@ -780,9 +768,11 @@ func (s *Server) logRequest( status := wrec.Status() size := wrec.Size() + reqSize := wrec.RequestSize() repl.Set("http.response.status", status) // will be 0 if no response is written by us (Go will write 200 to client) repl.Set("http.response.size", size) + repl.Set("http.request.size", reqSize) repl.Set("http.response.duration", duration) repl.Set("http.response.duration_ms", duration.Seconds()*1e3) // multiply seconds to preserve decimal (see #4666) @@ -811,17 +801,12 @@ func (s *Server) logRequest( if fields == nil { userID, _ := repl.GetString("http.auth.user.id") - reqBodyLength := 0 - if bodyReader != nil { - reqBodyLength = bodyReader.Length - } - extra := r.Context().Value(ExtraLogFieldsCtxKey).(*ExtraLogFields) fieldCount := 6 fields = make([]zapcore.Field, 0, fieldCount+len(extra.fields)) fields = append(fields, - zap.Int("bytes_read", reqBodyLength), + zap.Int("bytes_read", wrec.RequestSize()), zap.String("user_id", userID), zap.Duration("duration", *duration), zap.Int("size", size), @@ -1050,23 +1035,6 @@ func cloneURL(from, to *url.URL) { } } -// lengthReader is an io.ReadCloser that keeps track of the -// number of bytes read from the request body. -type lengthReader struct { - Source io.ReadCloser - Length int -} - -func (r *lengthReader) Read(b []byte) (int, error) { - n, err := r.Source.Read(b) - r.Length += n - return n, err -} - -func (r *lengthReader) Close() error { - return r.Source.Close() -} - // Context keys for HTTP request context values. const ( // For referencing the server instance @@ -1087,6 +1055,10 @@ const ( // For tracking the real client IP (affected by trusted_proxy) ClientIPVarKey string = "client_ip" + + // For referencing underlying wrec to avoid create multiple ResponseRecorder + // in diffrenct Modules. + ResponseRecorderVarKey string = "response_recorder" ) var networkTypesHTTP3 = map[string]string{ diff --git a/modules/caddyhttp/server_test.go b/modules/caddyhttp/server_test.go index 6ce09974be5..8a2e1a0a696 100644 --- a/modules/caddyhttp/server_test.go +++ b/modules/caddyhttp/server_test.go @@ -51,16 +51,15 @@ func TestServer_LogRequest(t *testing.T) { ctx = context.WithValue(ctx, ExtraLogFieldsCtxKey, new(ExtraLogFields)) req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec := NewResponseRecorder(rec, nil, nil) + wrec, _ := NewResponseRecorder(rec, &req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) - bodyReader := &lengthReader{Source: req.Body} shouldLogCredentials := false buf := bytes.Buffer{} accLog := testLogger(buf.Write) - s.logRequest(accLog, req, wrec, &duration, repl, bodyReader, shouldLogCredentials) + s.logRequest(accLog, req, wrec, &duration, repl, shouldLogCredentials) assert.JSONEq(t, `{ "msg":"handled request", "level":"info", "bytes_read":0, @@ -79,16 +78,15 @@ func TestServer_LogRequest_WithTrace(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec := NewResponseRecorder(rec, nil, nil) + wrec, _ := NewResponseRecorder(rec, &req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) - bodyReader := &lengthReader{Source: req.Body} shouldLogCredentials := false buf := bytes.Buffer{} accLog := testLogger(buf.Write) - s.logRequest(accLog, req, wrec, &duration, repl, bodyReader, shouldLogCredentials) + s.logRequest(accLog, req, wrec, &duration, repl, shouldLogCredentials) assert.JSONEq(t, `{ "msg":"handled request", "level":"info", "bytes_read":0, @@ -107,11 +105,10 @@ func BenchmarkServer_LogRequest(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec := NewResponseRecorder(rec, nil, nil) + wrec, _ := NewResponseRecorder(rec, &req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) - bodyReader := &lengthReader{Source: req.Body} buf := io.Discard accLog := testLogger(buf.Write) @@ -119,7 +116,7 @@ func BenchmarkServer_LogRequest(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - s.logRequest(accLog, req, wrec, &duration, repl, bodyReader, false) + s.logRequest(accLog, req, wrec, &duration, repl, false) } } @@ -131,18 +128,17 @@ func BenchmarkServer_LogRequest_NopLogger(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec := NewResponseRecorder(rec, nil, nil) + wrec, _ := NewResponseRecorder(rec, &req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) - bodyReader := &lengthReader{Source: req.Body} accLog := zap.NewNop() b.ResetTimer() for i := 0; i < b.N; i++ { - s.logRequest(accLog, req, wrec, &duration, repl, bodyReader, false) + s.logRequest(accLog, req, wrec, &duration, repl, false) } } @@ -156,11 +152,10 @@ func BenchmarkServer_LogRequest_WithTrace(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec := NewResponseRecorder(rec, nil, nil) + wrec, _ := NewResponseRecorder(rec, &req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) - bodyReader := &lengthReader{Source: req.Body} buf := io.Discard accLog := testLogger(buf.Write) @@ -168,7 +163,7 @@ func BenchmarkServer_LogRequest_WithTrace(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - s.logRequest(accLog, req, wrec, &duration, repl, bodyReader, false) + s.logRequest(accLog, req, wrec, &duration, repl, false) } } diff --git a/modules/caddyhttp/templates/templates.go b/modules/caddyhttp/templates/templates.go index eb648865983..0b2b7ded066 100644 --- a/modules/caddyhttp/templates/templates.go +++ b/modules/caddyhttp/templates/templates.go @@ -406,7 +406,7 @@ func (t *Templates) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy return false } - rec := caddyhttp.NewResponseRecorder(w, buf, shouldBuf) + rec, _ := caddyhttp.NewResponseRecorder(w, &r, buf, shouldBuf) err := next.ServeHTTP(rec, r) if err != nil { From 705b1ad538a8f3aa166a27053a3d90e97d88d086 Mon Sep 17 00:00:00 2001 From: qdongxu Date: Tue, 1 Jul 2025 04:39:01 +0800 Subject: [PATCH 2/3] 1. refine formatting 2. do not export internal fields of lengthReader --- modules/caddyhttp/responsewriter.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 4256280a161..96d628d9e0f 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -17,7 +17,7 @@ package caddyhttp import ( "bufio" "bytes" - `context` + "context" "fmt" "io" "net" @@ -155,7 +155,7 @@ func NewResponseRecorder(w http.ResponseWriter, req **http.Request, buf *bytes.B reqBodyLengthReader: lengthReader{}, } if r.Body != nil { - rr.reqBodyLengthReader.Source = r.Body + rr.reqBodyLengthReader.source = r.Body r.Body = &rr.reqBodyLengthReader } @@ -236,7 +236,7 @@ func (rr *responseRecorder) Size() int { // RequestSize returns the number of bytes read from the Request, // not including the request headers. func (rr *responseRecorder) RequestSize() int { - return rr.reqBodyLengthReader.Length + return rr.reqBodyLengthReader.length } // Buffer returns the body buffer that rr was created with. @@ -304,7 +304,7 @@ type hijackedConn struct { } func (hc *hijackedConn) updateReadSize(n int) { - hc.rr.reqBodyLengthReader.Length += n + hc.rr.reqBodyLengthReader.length += n } func (hc *hijackedConn) Read(p []byte) (int, error) { @@ -370,16 +370,16 @@ var ( // conn hijacked by a websocket session. ResponseRecorder will // update the Length field. type lengthReader struct { - Source io.ReadCloser - Length int + source io.ReadCloser + length int } func (r *lengthReader) Read(b []byte) (int, error) { - n, err := r.Source.Read(b) - r.Length += n + n, err := r.source.Read(b) + r.length += n return n, err } func (r *lengthReader) Close() error { - return r.Source.Close() + return r.source.Close() } From c8fc9381e40d53a719ad837c97a7e9d3c918f8a9 Mon Sep 17 00:00:00 2001 From: qdongxu Date: Wed, 2 Jul 2025 22:23:48 +0800 Subject: [PATCH 3/3] fix and add unit test --- modules/caddyhttp/intercept/intercept.go | 11 ++- modules/caddyhttp/metrics.go | 3 +- modules/caddyhttp/responsewriter.go | 26 +++--- modules/caddyhttp/responsewriter_test.go | 104 +++++++++++++++++++---- modules/caddyhttp/server.go | 3 +- modules/caddyhttp/server_test.go | 10 +-- modules/caddyhttp/templates/templates.go | 8 +- 7 files changed, 127 insertions(+), 38 deletions(-) diff --git a/modules/caddyhttp/intercept/intercept.go b/modules/caddyhttp/intercept/intercept.go index ff6c2c9dda4..d00d068666a 100644 --- a/modules/caddyhttp/intercept/intercept.go +++ b/modules/caddyhttp/intercept/intercept.go @@ -131,7 +131,11 @@ func (ir Intercept) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) rec := interceptedResponseHandler{replacer: repl} - rec.ResponseRecorder, _ = caddyhttp.NewResponseRecorder(w, &r, buf, func(status int, header http.Header) bool { + var ( + newReq *http.Request + cached bool + ) + rec.ResponseRecorder, newReq, cached = caddyhttp.NewResponseRecorder(w, r, buf, func(status int, header http.Header) bool { // see if any response handler is configured for this original response for i, rh := range ir.HandleResponse { if rh.Match != nil && !rh.Match.Match(status, header) { @@ -157,6 +161,11 @@ func (ir Intercept) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy return false }) + if !cached { + w = rec.Unwrap() + r = newReq + } + if err := next.ServeHTTP(rec, r); err != nil { return err } diff --git a/modules/caddyhttp/metrics.go b/modules/caddyhttp/metrics.go index e0d935bba47..8d4424ff18a 100644 --- a/modules/caddyhttp/metrics.go +++ b/modules/caddyhttp/metrics.go @@ -154,9 +154,10 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re return false }) - wrec, cached := NewResponseRecorder(w, &r, nil, writeHeaderRecorder) + wrec, newReq, cached := NewResponseRecorder(w, r, nil, writeHeaderRecorder) if !cached { w = wrec + r = newReq } err := h.mh.ServeHTTP(w, r, next) diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 96d628d9e0f..cfa9d070bc5 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -77,6 +77,7 @@ type responseRecorder struct { // NewResponseRecorder returns a new ResponseRecorder that can be // used instead of a standard http.ResponseWriter. +// The returned wrec and newReq must be forwarded to the next Handler if not cached // The ResponseRecorder will be cached in the http.Request.Context to avoid // multiple instances created by different Modules in a pipeline of http handlers. // The recorder is useful for middlewares which need to buffer a response and @@ -104,12 +105,13 @@ type responseRecorder struct { // // Proper usage of a recorder looks like this: // -// rec, cached := caddyhttp.NewResponseRecorder(w, &req, buf, shouldBuffer) +// rec, newReq, cached := caddyhttp.NewResponseRecorder(w, req, buf, shouldBuffer) // if !cached { // w = rec +// req = newReq // } // -// err := next.ServeHTTP(w, req) // do not replace rec if got from a cached one +// err := next.ServeHTTP(w, req) // do not replace rec and newReq if got from a cached one // if err != nil { // return err // } @@ -141,11 +143,12 @@ type responseRecorder struct { // As a special case, 1xx responses are not buffered nor recorded // because they are not the final response; they are passed through // directly to the underlying ResponseWriter. -func NewResponseRecorder(w http.ResponseWriter, req **http.Request, buf *bytes.Buffer, - shouldBuffer ShouldBufferFunc) (wrec ResponseRecorder, cached bool) { - r := *req - if wrec, ok := r.Context().Value(ResponseRecorderVarKey).(ResponseRecorder); ok { - return wrec, true +func NewResponseRecorder(w http.ResponseWriter, r *http.Request, buf *bytes.Buffer, + shouldBuffer ShouldBufferFunc) (wrec ResponseRecorder, nerReq *http.Request, cached bool) { + if buf == nil { + if wrec, ok := r.Context().Value(ResponseRecorderVarKey).(ResponseRecorder); ok { + return wrec, r, true + } } rr := &responseRecorder{ @@ -159,9 +162,12 @@ func NewResponseRecorder(w http.ResponseWriter, req **http.Request, buf *bytes.B r.Body = &rr.reqBodyLengthReader } - c := context.WithValue(r.Context(), ResponseRecorderVarKey, rr) - *req = r.WithContext(c) - return rr, false + if buf == nil { + c := context.WithValue(r.Context(), ResponseRecorderVarKey, rr) + r = r.WithContext(c) + } + + return rr, r, false } // WriteHeader writes the headers with statusCode to the wrapped diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go index 1aa0a141447..dbfd46a1a96 100644 --- a/modules/caddyhttp/responsewriter_test.go +++ b/modules/caddyhttp/responsewriter_test.go @@ -4,7 +4,7 @@ import ( "bytes" "io" "net/http" - `net/http/httptest` + "net/http/httptest" "strings" "testing" ) @@ -138,7 +138,7 @@ func TestResponseRecorderReadFrom(t *testing.T) { t.Run(name, func(t *testing.T) { var buf bytes.Buffer - rr, _ := NewResponseRecorder(tt.responseWriter, &tt.req, &buf, func(status int, header http.Header) bool { + rr, _, _ := NewResponseRecorder(tt.responseWriter, tt.req, &buf, func(status int, header http.Header) bool { return tt.shouldBuffer }) @@ -177,29 +177,97 @@ func TestResponseRecorderReadFrom(t *testing.T) { } func TestCachedResponseRecorder(t *testing.T) { - r := httptest.NewRequest("GET", "http://example.com/foo", nil) - rOld := r - w := &ResponseWriterWrapper{&baseRespWriter{}} - wrec1, cached := NewResponseRecorder(w, &r, nil, nil) - if cached { - t.Errorf("NewResponseRecorder() should not have been called and cached ") + buf1 := &bytes.Buffer{} + shouldBuf1 := func(status int, header http.Header) bool { + return false + } + buf2 := &bytes.Buffer{} + shouldBuf2 := func(status int, header http.Header) bool { + return true } + var r *http.Request = httptest.NewRequest("GET", "http://example.com/foo", nil) + var w http.ResponseWriter = &ResponseWriterWrapper{&baseRespWriter{}} - if rOld == r { - t.Errorf("r should be different from rOld") + tests := []struct { + name string + buf *bytes.Buffer + ShouldBufferFunc func(status int, header http.Header) bool + wantCached bool + }{ + { + name: "init nil buffer", + buf: nil, + ShouldBufferFunc: nil, + wantCached: false, + }, + { + name: "reuse nil buffer", + buf: nil, + ShouldBufferFunc: nil, + wantCached: true, + }, + { + name: "init buffered", + buf: buf1, + ShouldBufferFunc: shouldBuf1, + wantCached: false, + }, + { + name: "init another buffered", + buf: buf2, + ShouldBufferFunc: shouldBuf2, + wantCached: false, + }, + { + name: "reuse nil buffer", + buf: nil, + ShouldBufferFunc: nil, + wantCached: true, + }, } - rOld = r - wrec2, cached := NewResponseRecorder(w, &r, nil, nil) - if !cached { - t.Errorf("NewResponseRecorder() has been caleed and should be cached ") + var ( + newW = w + newReq = r + cached bool + ) + for _, tt := range tests { + + newW, newReq, cached = NewResponseRecorder(newW, newReq, tt.buf, tt.ShouldBufferFunc) + if cached != tt.wantCached { + t.Errorf("NewResponseRecorder() name = %s, cache = %v, want %v", tt.name, cached, tt.wantCached) + break + } } +} + +func BenchmarkNewResponseRecorderNil(b *testing.B) { + var r *http.Request = httptest.NewRequest("GET", "http://example.com/foo", nil) + var w http.ResponseWriter = &ResponseWriterWrapper{&baseRespWriter{}} + wrec1, newReq, _ := NewResponseRecorder(w, r, nil, nil) + for i := 0; i < b.N; i++ { + wrec1, newReq, _ = NewResponseRecorder(wrec1, r, nil, nil) + } + + if newReq.Context().Value(ResponseRecorderVarKey) == nil { + b.Errorf("NewResponseRecorder() did not set ResponseRecorder context") + } +} - if rOld != r { - t.Errorf("r should be identical as from rOld") +func BenchmarkNewResponseRecorderBuffer(b *testing.B) { + buf := &bytes.Buffer{} + shouldBuf := func(status int, header http.Header) bool { + return false } - if wrec1 != wrec2 { - t.Errorf("NewResponseRecorder() should be identical since it was cached") + var r *http.Request = httptest.NewRequest("GET", "http://example.com/foo", nil) + var w http.ResponseWriter = &ResponseWriterWrapper{&baseRespWriter{}} + wrec1, newReq, _ := NewResponseRecorder(w, r, buf, shouldBuf) + + for i := 0; i < b.N; i++ { + wrec1, newReq, _ = NewResponseRecorder(wrec1, r, buf, shouldBuf) + } + if newReq.Context().Value(ResponseRecorderVarKey) == nil { + b.Errorf("NewResponseRecorder() did not set ResponseRecorder context") } } diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index c28600a099f..4a500f130b2 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -334,9 +334,10 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { var duration time.Duration if s.shouldLogRequest(r) { - wrec, cached := NewResponseRecorder(w, &r, nil, nil) + wrec, newReq, cached := NewResponseRecorder(w, r, nil, nil) if !cached { w = wrec + r = newReq } // capture the original version of the request diff --git a/modules/caddyhttp/server_test.go b/modules/caddyhttp/server_test.go index 8a2e1a0a696..5433dceec07 100644 --- a/modules/caddyhttp/server_test.go +++ b/modules/caddyhttp/server_test.go @@ -51,7 +51,7 @@ func TestServer_LogRequest(t *testing.T) { ctx = context.WithValue(ctx, ExtraLogFieldsCtxKey, new(ExtraLogFields)) req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec, _ := NewResponseRecorder(rec, &req, nil, nil) + wrec, _, _ := NewResponseRecorder(rec, req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) @@ -78,7 +78,7 @@ func TestServer_LogRequest_WithTrace(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec, _ := NewResponseRecorder(rec, &req, nil, nil) + wrec, _, _ := NewResponseRecorder(rec, req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) @@ -105,7 +105,7 @@ func BenchmarkServer_LogRequest(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec, _ := NewResponseRecorder(rec, &req, nil, nil) + wrec, _, _ := NewResponseRecorder(rec, req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) @@ -128,7 +128,7 @@ func BenchmarkServer_LogRequest_NopLogger(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec, _ := NewResponseRecorder(rec, &req, nil, nil) + wrec, _, _ := NewResponseRecorder(rec, req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) @@ -152,7 +152,7 @@ func BenchmarkServer_LogRequest_WithTrace(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) rec := httptest.NewRecorder() - wrec, _ := NewResponseRecorder(rec, &req, nil, nil) + wrec, _, _ := NewResponseRecorder(rec, req, nil, nil) duration := 50 * time.Millisecond repl := NewTestReplacer(req) diff --git a/modules/caddyhttp/templates/templates.go b/modules/caddyhttp/templates/templates.go index 0b2b7ded066..778d154465c 100644 --- a/modules/caddyhttp/templates/templates.go +++ b/modules/caddyhttp/templates/templates.go @@ -406,9 +406,13 @@ func (t *Templates) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy return false } - rec, _ := caddyhttp.NewResponseRecorder(w, &r, buf, shouldBuf) + rec, newReq, cached := caddyhttp.NewResponseRecorder(w, r, buf, shouldBuf) + if !cached { + w = rec + r = newReq + } - err := next.ServeHTTP(rec, r) + err := next.ServeHTTP(w, r) if err != nil { return err }