Skip to content

Commit 67be7d9

Browse files
authored
middleware: add Discard method to WrapResponseWriter (#926)
* middleware: add Discard method to WrapResponseWriter * resolve review comments * use ioutil.Discard and deprecate the public interface * move the Discard method back to the public interface * discard calls to WriteHeader too
1 parent 7957c0d commit 67be7d9

File tree

2 files changed

+89
-8
lines changed

2 files changed

+89
-8
lines changed

middleware/wrap_writer.go

+27-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package middleware
66
import (
77
"bufio"
88
"io"
9+
"io/ioutil"
910
"net"
1011
"net/http"
1112
)
@@ -61,6 +62,11 @@ type WrapResponseWriter interface {
6162
Tee(io.Writer)
6263
// Unwrap returns the original proxied target.
6364
Unwrap() http.ResponseWriter
65+
// Discard causes all writes to the original ResponseWriter be discarded,
66+
// instead writing only to the tee'd writer if it's set.
67+
// The caller is responsible for calling WriteHeader and Write on the
68+
// original ResponseWriter once the processing is done.
69+
Discard()
6470
}
6571

6672
// basicWriter wraps a http.ResponseWriter that implements the minimal
@@ -71,25 +77,34 @@ type basicWriter struct {
7177
code int
7278
bytes int
7379
tee io.Writer
80+
discard bool
7481
}
7582

7683
func (b *basicWriter) WriteHeader(code int) {
7784
if !b.wroteHeader {
7885
b.code = code
7986
b.wroteHeader = true
80-
b.ResponseWriter.WriteHeader(code)
87+
if !b.discard {
88+
b.ResponseWriter.WriteHeader(code)
89+
}
8190
}
8291
}
8392

84-
func (b *basicWriter) Write(buf []byte) (int, error) {
93+
func (b *basicWriter) Write(buf []byte) (n int, err error) {
8594
b.maybeWriteHeader()
86-
n, err := b.ResponseWriter.Write(buf)
87-
if b.tee != nil {
88-
_, err2 := b.tee.Write(buf[:n])
89-
// Prefer errors generated by the proxied writer.
90-
if err == nil {
91-
err = err2
95+
if !b.discard {
96+
n, err = b.ResponseWriter.Write(buf)
97+
if b.tee != nil {
98+
_, err2 := b.tee.Write(buf[:n])
99+
// Prefer errors generated by the proxied writer.
100+
if err == nil {
101+
err = err2
102+
}
92103
}
104+
} else if b.tee != nil {
105+
n, err = b.tee.Write(buf)
106+
} else {
107+
n, err = ioutil.Discard.Write(buf)
93108
}
94109
b.bytes += n
95110
return n, err
@@ -117,6 +132,10 @@ func (b *basicWriter) Unwrap() http.ResponseWriter {
117132
return b.ResponseWriter
118133
}
119134

135+
func (b *basicWriter) Discard() {
136+
b.discard = true
137+
}
138+
120139
// flushWriter ...
121140
type flushWriter struct {
122141
basicWriter

middleware/wrap_writer_test.go

+62
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package middleware
22

33
import (
4+
"bytes"
5+
"net/http"
46
"net/http/httptest"
57
"testing"
68
)
@@ -22,3 +24,63 @@ func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
2224
t.Fatal("want Flush to have set wroteHeader=true")
2325
}
2426
}
27+
28+
func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) {
29+
// explicitly create the struct instead of NewRecorder to control the value of Code
30+
original := &httptest.ResponseRecorder{
31+
HeaderMap: make(http.Header),
32+
Body: new(bytes.Buffer),
33+
}
34+
wrap := &basicWriter{ResponseWriter: original}
35+
36+
var buf bytes.Buffer
37+
wrap.Tee(&buf)
38+
39+
_, err := wrap.Write([]byte("hello world"))
40+
assertNoError(t, err)
41+
42+
assertEqual(t, 200, original.Code)
43+
assertEqual(t, []byte("hello world"), original.Body.Bytes())
44+
assertEqual(t, []byte("hello world"), buf.Bytes())
45+
assertEqual(t, 11, wrap.BytesWritten())
46+
}
47+
48+
func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) {
49+
t.Run("With Tee", func(t *testing.T) {
50+
// explicitly create the struct instead of NewRecorder to control the value of Code
51+
original := &httptest.ResponseRecorder{
52+
HeaderMap: make(http.Header),
53+
Body: new(bytes.Buffer),
54+
}
55+
wrap := &basicWriter{ResponseWriter: original}
56+
57+
var buf bytes.Buffer
58+
wrap.Tee(&buf)
59+
wrap.Discard()
60+
61+
_, err := wrap.Write([]byte("hello world"))
62+
assertNoError(t, err)
63+
64+
assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly
65+
assertEqual(t, 0, original.Body.Len())
66+
assertEqual(t, []byte("hello world"), buf.Bytes())
67+
assertEqual(t, 11, wrap.BytesWritten())
68+
})
69+
70+
t.Run("Without Tee", func(t *testing.T) {
71+
// explicitly create the struct instead of NewRecorder to control the value of Code
72+
original := &httptest.ResponseRecorder{
73+
HeaderMap: make(http.Header),
74+
Body: new(bytes.Buffer),
75+
}
76+
wrap := &basicWriter{ResponseWriter: original}
77+
wrap.Discard()
78+
79+
_, err := wrap.Write([]byte("hello world"))
80+
assertNoError(t, err)
81+
82+
assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly
83+
assertEqual(t, 0, original.Body.Len())
84+
assertEqual(t, 11, wrap.BytesWritten())
85+
})
86+
}

0 commit comments

Comments
 (0)