Skip to content

Allow ResponseWriters to unwrap writers when flushing/hijacking #2595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Allow ResponseWriters to unwrap writers when flushing/hijacking
  • Loading branch information
aldas committed Feb 20, 2024
commit a5999fcbbd7f5db0b060f671ddf1f5f3b463f392
12 changes: 10 additions & 2 deletions middleware/body_dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -98,9 +99,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
}

func (w *bodyDumpResponseWriter) Flush() {
w.ResponseWriter.(http.Flusher).Flush()
err := http.NewResponseController(w.ResponseWriter).Flush()
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}

func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return http.NewResponseController(w.ResponseWriter).Hijack()
}

func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
50 changes: 50 additions & 0 deletions middleware/body_dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,53 @@ func TestBodyDumpFails(t *testing.T) {
}
})
}

func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
}

assert.PanicsWithError(t, "response writer flushing is not supported", func() {
bdrw.Flush()
})
}

func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu,
}

bdrw.Flush()
assert.Equal(t, 1, trwu.unwrapCalled)
}

func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: trwu,
}

result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}

func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}

func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}
10 changes: 6 additions & 4 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,15 @@ func (w *gzipResponseWriter) Flush() {
}

w.Writer.(*gzip.Writer).Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
http.NewResponseController(w.ResponseWriter).Flush()
}

func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return http.NewResponseController(w.ResponseWriter).Hijack()
}

func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
Expand Down
30 changes: 30 additions & 0 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,36 @@ func TestGzipWithStatic(t *testing.T) {
}
}

func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: trwu,
}

result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}

func TestGzipResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}

func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}

func BenchmarkGzip(b *testing.B) {
e := echo.New()

Expand Down
46 changes: 46 additions & 0 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package middleware

import (
"bufio"
"errors"
"github.com/stretchr/testify/assert"
"net"
"net/http"
"net/http/httptest"
"regexp"
Expand Down Expand Up @@ -90,3 +93,46 @@ func TestRewriteURL(t *testing.T) {
})
}
}

type testResponseWriterNoFlushHijack struct {
}

func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
}

func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriterNoFlushHijack) Header() http.Header {
return nil
}

type testResponseWriterUnwrapper struct {
unwrapCalled int
rw http.ResponseWriter
}

func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
}

func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriterUnwrapper) Header() http.Header {
return nil
}

func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
w.unwrapCalled++
return w.rw
}

type testResponseWriterUnwrapperHijack struct {
testResponseWriterUnwrapper
}

func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("can hijack")
}
8 changes: 6 additions & 2 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package echo

import (
"bufio"
"errors"
"net"
"net/http"
)
Expand Down Expand Up @@ -84,14 +85,17 @@ func (r *Response) Write(b []byte) (n int, err error) {
// buffered data to the client.
// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher)
func (r *Response) Flush() {
r.Writer.(http.Flusher).Flush()
err := http.NewResponseController(r.Writer).Flush()
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}

// Hijack implements the http.Hijacker interface to allow an HTTP handler to
// take over the connection.
// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker)
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.Writer.(http.Hijacker).Hijack()
return http.NewResponseController(r.Writer).Hijack()
}

// Unwrap returns the original http.ResponseWriter.
Expand Down
25 changes: 25 additions & 0 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ func TestResponse_Flush(t *testing.T) {
assert.True(t, rec.Flushed)
}

type testResponseWriter struct {
}

func (w *testResponseWriter) WriteHeader(statusCode int) {
}

func (w *testResponseWriter) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriter) Header() http.Header {
return nil
}

func TestResponse_FlushPanics(t *testing.T) {
e := New()
rw := new(testResponseWriter)
res := &Response{echo: e, Writer: rw}

// we test that we behave as before unwrapping flushers - flushing writer that does not support it causes panic
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
res.Flush()
})
}

func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
Expand Down