Skip to content

Align the semantics of http.request.size with http.response.size #7100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion modules/caddyhttp/intercept/intercept.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ func (ir Intercept) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy

repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
rec := interceptedResponseHandler{replacer: repl}
rec.ResponseRecorder = caddyhttp.NewResponseRecorder(w, buf, func(status int, header http.Header) bool {
var (
newReq *http.Request
cached bool
)
rec.ResponseRecorder, newReq, cached = caddyhttp.NewResponseRecorder(w, r, buf, func(status int, header http.Header) bool {
// see if any response handler is configured for this original response
for i, rh := range ir.HandleResponse {
if rh.Match != nil && !rh.Match.Match(status, header) {
Expand All @@ -157,6 +161,11 @@ func (ir Intercept) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy
return false
})

if !cached {
w = rec.Unwrap()
r = newReq
}

if err := next.ServeHTTP(rec, r); err != nil {
return err
}
Expand Down
21 changes: 12 additions & 9 deletions modules/caddyhttp/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,15 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
h.metrics.httpMetrics.responseDuration.With(statusLabels).Observe(ttfb)
return false
})
wrec := NewResponseRecorder(w, nil, writeHeaderRecorder)
err := h.mh.ServeHTTP(wrec, r, next)

wrec, newReq, cached := NewResponseRecorder(w, r, nil, writeHeaderRecorder)
if !cached {
w = wrec
r = newReq
}

err := h.mh.ServeHTTP(w, r, next)

dur := time.Since(start).Seconds()
h.metrics.httpMetrics.requestCount.With(labels).Inc()

Expand All @@ -168,7 +175,7 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
}

h.metrics.httpMetrics.requestDuration.With(statusLabels).Observe(dur)
h.metrics.httpMetrics.requestSize.With(statusLabels).Observe(float64(computeApproximateRequestSize(r)))
h.metrics.httpMetrics.requestSize.With(statusLabels).Observe(float64(computeApproximateRequestSize(wrec, r)))
h.metrics.httpMetrics.responseSize.With(statusLabels).Observe(float64(wrec.Size()))
}

Expand All @@ -189,7 +196,7 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
}

// taken from https://github.com/prometheus/client_golang/blob/6007b2b5cae01203111de55f753e76d8dac1f529/prometheus/promhttp/instrument_server.go#L298
func computeApproximateRequestSize(r *http.Request) int {
func computeApproximateRequestSize(wrec ResponseRecorder, r *http.Request) int {
s := 0
if r.URL != nil {
s += len(r.URL.String())
Expand All @@ -205,10 +212,6 @@ func computeApproximateRequestSize(r *http.Request) int {
}
s += len(r.Host)

// N.B. r.Form and r.MultipartForm are assumed to be included in r.URL.

if r.ContentLength != -1 {
s += int(r.ContentLength)
}
s += wrec.RequestSize()
return s
}
79 changes: 63 additions & 16 deletions modules/caddyhttp/responsewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -71,12 +72,15 @@
wroteHeader bool
stream bool

readSize *int
reqBodyLengthReader lengthReader
}

// NewResponseRecorder returns a new ResponseRecorder that can be
// used instead of a standard http.ResponseWriter. The recorder is
// useful for middlewares which need to buffer a response and
// used instead of a standard http.ResponseWriter.
// The returned wrec and newReq must be forwarded to the next Handler if not cached
// The ResponseRecorder will be cached in the http.Request.Context to avoid
// multiple instances created by different Modules in a pipeline of http handlers.
// The recorder is useful for middlewares which need to buffer a response and
// potentially process its entire body before actually writing the
// response to the underlying writer. Of course, buffering the entire
// body has a memory overhead, but sometimes there is no way to avoid
Expand All @@ -101,8 +105,13 @@
//
// Proper usage of a recorder looks like this:
//
// rec := caddyhttp.NewResponseRecorder(w, buf, shouldBuffer)
// err := next.ServeHTTP(rec, req)
// rec, newReq, cached := caddyhttp.NewResponseRecorder(w, req, buf, shouldBuffer)
// if !cached {
// w = rec
// req = newReq
// }
//
// err := next.ServeHTTP(w, req) // do not replace rec and newReq if got from a cached one
// if err != nil {
// return err
// }
Expand Down Expand Up @@ -134,12 +143,31 @@
// As a special case, 1xx responses are not buffered nor recorded
// because they are not the final response; they are passed through
// directly to the underlying ResponseWriter.
func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer ShouldBufferFunc) ResponseRecorder {
return &responseRecorder{
func NewResponseRecorder(w http.ResponseWriter, r *http.Request, buf *bytes.Buffer,
shouldBuffer ShouldBufferFunc) (wrec ResponseRecorder, nerReq *http.Request, cached bool) {

Check failure on line 147 in modules/caddyhttp/responsewriter.go

View workflow job for this annotation

GitHub Actions / lint (mac)

File is not properly formatted (gofumpt)
if buf == nil {
if wrec, ok := r.Context().Value(ResponseRecorderVarKey).(ResponseRecorder); ok {
return wrec, r, true
}
}

rr := &responseRecorder{
ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: w},
buf: buf,
shouldBuffer: shouldBuffer,
reqBodyLengthReader: lengthReader{},
}
if r.Body != nil {
rr.reqBodyLengthReader.source = r.Body
r.Body = &rr.reqBodyLengthReader
}

if buf == nil {
c := context.WithValue(r.Context(), ResponseRecorderVarKey, rr)

Check failure on line 166 in modules/caddyhttp/responsewriter.go

View workflow job for this annotation

GitHub Actions / lint (mac)

SA1029: should not use built-in type string as key for value; define your own type to avoid collisions (staticcheck)
r = r.WithContext(c)
}

return rr, r, false
}

// WriteHeader writes the headers with statusCode to the wrapped
Expand Down Expand Up @@ -211,6 +239,12 @@
return rr.size
}

// RequestSize returns the number of bytes read from the Request,
// not including the request headers.
func (rr *responseRecorder) RequestSize() int {
return rr.reqBodyLengthReader.length
}

// Buffer returns the body buffer that rr was created with.
// You should still have your original pointer, though.
func (rr *responseRecorder) Buffer() *bytes.Buffer {
Expand Down Expand Up @@ -246,12 +280,6 @@
return nil
}

// Private interface so it can only be used in this package
// #TODO: maybe export it later
func (rr *responseRecorder) setReadSize(size *int) {
rr.readSize = size
}

func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
Expand Down Expand Up @@ -282,9 +310,7 @@
}

func (hc *hijackedConn) updateReadSize(n int) {
if hc.rr.readSize != nil {
*hc.rr.readSize += n
}
hc.rr.reqBodyLengthReader.length += n
}

func (hc *hijackedConn) Read(p []byte) (int, error) {
Expand Down Expand Up @@ -320,6 +346,7 @@
Buffer() *bytes.Buffer
Buffered() bool
Size() int
RequestSize() int
WriteResponse() error
}

Expand All @@ -342,3 +369,23 @@

_ io.WriterTo = (*hijackedConn)(nil)
)

// lengthReader is an io.ReadCloser that keeps track of the
// number of bytes read from the request body.
// This wrapper is for http request process only. If the underlying
// conn hijacked by a websocket session. ResponseRecorder will
// update the Length field.
type lengthReader struct {
source io.ReadCloser
length int
}

func (r *lengthReader) Read(b []byte) (int, error) {
n, err := r.source.Read(b)
r.length += n
return n, err
}

func (r *lengthReader) Close() error {
return r.source.Close()
}
104 changes: 103 additions & 1 deletion modules/caddyhttp/responsewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
Expand Down Expand Up @@ -104,26 +105,31 @@ func TestResponseWriterWrapperUnwrap(t *testing.T) {
func TestResponseRecorderReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
req *http.Request
shouldBuffer bool
wantReadFrom bool
}{
"buffered plain": {
responseWriter: &baseRespWriter{},
req: &http.Request{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed plain": {
responseWriter: &baseRespWriter{},
req: &http.Request{},
shouldBuffer: false,
wantReadFrom: false,
},
"buffered ReadFrom": {
responseWriter: &readFromRespWriter{},
req: &http.Request{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed ReadFrom": {
responseWriter: &readFromRespWriter{},
req: &http.Request{},
shouldBuffer: false,
wantReadFrom: true,
},
Expand All @@ -132,7 +138,7 @@ func TestResponseRecorderReadFrom(t *testing.T) {
t.Run(name, func(t *testing.T) {
var buf bytes.Buffer

rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool {
rr, _, _ := NewResponseRecorder(tt.responseWriter, tt.req, &buf, func(status int, header http.Header) bool {
return tt.shouldBuffer
})

Expand Down Expand Up @@ -169,3 +175,99 @@ func TestResponseRecorderReadFrom(t *testing.T) {
})
}
}

func TestCachedResponseRecorder(t *testing.T) {
buf1 := &bytes.Buffer{}
shouldBuf1 := func(status int, header http.Header) bool {
return false
}
buf2 := &bytes.Buffer{}
shouldBuf2 := func(status int, header http.Header) bool {
return true
}
var r *http.Request = httptest.NewRequest("GET", "http://example.com/foo", nil)
var w http.ResponseWriter = &ResponseWriterWrapper{&baseRespWriter{}}

tests := []struct {
name string
buf *bytes.Buffer
ShouldBufferFunc func(status int, header http.Header) bool
wantCached bool
}{
{
name: "init nil buffer",
buf: nil,
ShouldBufferFunc: nil,
wantCached: false,
},
{
name: "reuse nil buffer",
buf: nil,
ShouldBufferFunc: nil,
wantCached: true,
},
{
name: "init buffered",
buf: buf1,
ShouldBufferFunc: shouldBuf1,
wantCached: false,
},
{
name: "init another buffered",
buf: buf2,
ShouldBufferFunc: shouldBuf2,
wantCached: false,
},
{
name: "reuse nil buffer",
buf: nil,
ShouldBufferFunc: nil,
wantCached: true,
},
}

var (
newW = w
newReq = r
cached bool
)
for _, tt := range tests {

newW, newReq, cached = NewResponseRecorder(newW, newReq, tt.buf, tt.ShouldBufferFunc)
if cached != tt.wantCached {
t.Errorf("NewResponseRecorder() name = %s, cache = %v, want %v", tt.name, cached, tt.wantCached)
break
}
}
}

func BenchmarkNewResponseRecorderNil(b *testing.B) {
var r *http.Request = httptest.NewRequest("GET", "http://example.com/foo", nil)
var w http.ResponseWriter = &ResponseWriterWrapper{&baseRespWriter{}}
wrec1, newReq, _ := NewResponseRecorder(w, r, nil, nil)
for i := 0; i < b.N; i++ {
wrec1, newReq, _ = NewResponseRecorder(wrec1, r, nil, nil)
}

if newReq.Context().Value(ResponseRecorderVarKey) == nil {
b.Errorf("NewResponseRecorder() did not set ResponseRecorder context")
}
}

func BenchmarkNewResponseRecorderBuffer(b *testing.B) {
buf := &bytes.Buffer{}
shouldBuf := func(status int, header http.Header) bool {
return false
}

var r *http.Request = httptest.NewRequest("GET", "http://example.com/foo", nil)
var w http.ResponseWriter = &ResponseWriterWrapper{&baseRespWriter{}}
wrec1, newReq, _ := NewResponseRecorder(w, r, buf, shouldBuf)

for i := 0; i < b.N; i++ {
wrec1, newReq, _ = NewResponseRecorder(wrec1, r, buf, shouldBuf)
}
if newReq.Context().Value(ResponseRecorderVarKey) == nil {
b.Errorf("NewResponseRecorder() did not set ResponseRecorder context")
}
}
Loading
Loading