diff --git a/src/net/http/http_test.go b/src/net/http/http_test.go index c12bbedac986db..a8aa20b5800813 100644 --- a/src/net/http/http_test.go +++ b/src/net/http/http_test.go @@ -9,6 +9,7 @@ package http import ( "bytes" "internal/testenv" + "io" "io/fs" "net/url" "os" @@ -189,6 +190,45 @@ func TestNoUnicodeStrings(t *testing.T) { } } +type requestTooLargerResponseWriter struct { + called bool +} + +func (rw *requestTooLargerResponseWriter) Header() Header { return Header{} } +func (rw *requestTooLargerResponseWriter) Write(b []byte) (int, error) { return len(b), nil } +func (rw *requestTooLargerResponseWriter) WriteHeader(statusCode int) {} +func (rw *requestTooLargerResponseWriter) requestTooLarge() { + rw.called = true +} + +type wrapper struct { + ResponseWriter +} + +func (w *wrapper) Unwrap() ResponseWriter { + return w.ResponseWriter +} + +func TestMaxBytesReaderUnwrapTriggersRequestTooLarge(t *testing.T) { + body := strings.NewReader("123456") + limit := int64(5) + + innerRw := &requestTooLargerResponseWriter{} + wrappedRw := &wrapper{ResponseWriter: innerRw} + + l := MaxBytesReader(wrappedRw, io.NopCloser(body), limit) + + buf := make([]byte, 10) + _, err := l.Read(buf) + + if _, ok := err.(*MaxBytesError); !ok { + t.Errorf("expected MaxBytesError, got %T", err) + } + if !innerRw.called { + t.Errorf("expected requestTooLarge to be called, but it wasn't") + } +} + func TestProtocols(t *testing.T) { var p Protocols if p.HTTP1() { diff --git a/src/net/http/request.go b/src/net/http/request.go index 167cff585af34a..96b8363b523d8d 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -1243,9 +1243,23 @@ func (l *maxBytesReader) Read(p []byte) (n int, err error) { type requestTooLarger interface { requestTooLarge() } - if res, ok := l.w.(requestTooLarger); ok { - res.requestTooLarge() + // Unwrap the ResponseWriter wrappers until we find one that implements + // the server-only requestTooLarger interface, then call requestTooLarge(). + // This ensures that even if the ResponseWriter is wrapped by a custom implementation, + // the underlying server writer can be notified when the request body is too large. + rw := l.w + for { + if res, ok := rw.(requestTooLarger); ok { + res.requestTooLarge() + break + } + unwrapper, ok := rw.(rwUnwrapper) + if !ok { + break + } + rw = unwrapper.Unwrap() } + l.err = &MaxBytesError{l.i} return n, l.err }