Skip to content

Commit f10dc4a

Browse files
authored
fix(middleware): Close created writer in the compressor middleware (#919)
1 parent ef31c0b commit f10dc4a

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

middleware/compress.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) er
371371
}
372372

373373
func (cw *compressResponseWriter) Close() error {
374-
if c, ok := cw.writer().(io.WriteCloser); ok {
374+
if c, ok := cw.w.(io.WriteCloser); ok {
375375
return c.Close()
376376
}
377377
return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer")

middleware/compress_test.go

+44-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ func TestCompressor(t *testing.T) {
2626
return w
2727
})
2828

29-
if len(compressor.encoders) != 1 {
30-
t.Errorf("nop encoder should be stored in the encoders map")
29+
var sideEffect int
30+
compressor.SetEncoder("test", func(w io.Writer, _ int) io.Writer {
31+
return newSideEffectWriter(w, &sideEffect)
32+
})
33+
34+
if len(compressor.encoders) != 2 {
35+
t.Errorf("nop and test encoders should be stored in the encoders map")
3136
}
3237

3338
r.Use(compressor.Handler)
@@ -47,6 +52,11 @@ func TestCompressor(t *testing.T) {
4752
w.Write([]byte("textstring"))
4853
})
4954

55+
r.Get("/getimage", func(w http.ResponseWriter, r *http.Request) {
56+
w.Header().Set("Content-Type", "image/png")
57+
w.Write([]byte("textstring"))
58+
})
59+
5060
ts := httptest.NewServer(r)
5161
defer ts.Close()
5262

@@ -93,6 +103,12 @@ func TestCompressor(t *testing.T) {
93103
acceptedEncodings: []string{"nop, gzip, deflate"},
94104
expectedEncoding: "nop",
95105
},
106+
{
107+
name: "test is used and side effect is cleared after close",
108+
path: "/getimage",
109+
acceptedEncodings: []string{"test"},
110+
expectedEncoding: "",
111+
},
96112
}
97113

98114
for _, tc := range tests {
@@ -107,7 +123,10 @@ func TestCompressor(t *testing.T) {
107123
}
108124

109125
})
126+
}
110127

128+
if sideEffect > 1 {
129+
t.Errorf("side effect should be cleared after close")
111130
}
112131
}
113132

@@ -217,3 +236,26 @@ func decodeResponseBody(t *testing.T, resp *http.Response) string {
217236

218237
return string(respBody)
219238
}
239+
240+
type (
241+
sideEffectWriter struct {
242+
w io.Writer
243+
s *int
244+
}
245+
)
246+
247+
func newSideEffectWriter(w io.Writer, sideEffect *int) io.Writer {
248+
*sideEffect = *sideEffect + 1
249+
250+
return &sideEffectWriter{w: w, s: sideEffect}
251+
}
252+
253+
func (w *sideEffectWriter) Write(p []byte) (n int, err error) {
254+
return w.w.Write(p)
255+
}
256+
257+
func (w *sideEffectWriter) Close() error {
258+
*w.s = *w.s - 1
259+
260+
return nil
261+
}

0 commit comments

Comments
 (0)