Skip to content

Commit 6706199

Browse files
committed
Rework timeout middleware to use http.TimeoutHandler implementation. Replace ErrorHandler with ErrorMessage on config struct (fix labstack#1761)
1 parent 6f9b71c commit 6706199

File tree

2 files changed

+132
-104
lines changed

2 files changed

+132
-104
lines changed

middleware/timeout.go

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ package middleware
44

55
import (
66
"context"
7-
"fmt"
87
"github.com/labstack/echo/v4"
8+
"net/http"
99
"time"
1010
)
1111

@@ -14,24 +14,25 @@ type (
1414
TimeoutConfig struct {
1515
// Skipper defines a function to skip middleware.
1616
Skipper Skipper
17-
// ErrorHandler defines a function which is executed for a timeout
18-
// It can be used to define a custom timeout error
19-
ErrorHandler TimeoutErrorHandlerWithContext
17+
18+
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
19+
// It can be used to define a custom timeout error message
20+
ErrorMessage string
21+
2022
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
23+
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
24+
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
25+
// difference over 500microseconds (0.5millisecond) response seems to be reliable
2126
Timeout time.Duration
2227
}
23-
24-
// TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can
25-
// handle the error as we see fit
26-
TimeoutErrorHandlerWithContext func(error, echo.Context) error
2728
)
2829

2930
var (
3031
// DefaultTimeoutConfig is the default Timeout middleware config.
3132
DefaultTimeoutConfig = TimeoutConfig{
3233
Skipper: DefaultSkipper,
3334
Timeout: 0,
34-
ErrorHandler: nil,
35+
ErrorMessage: "",
3536
}
3637
)
3738

@@ -55,39 +56,42 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
5556
return next(c)
5657
}
5758

58-
ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
59-
defer cancel()
60-
61-
// this does a deep clone of the context, wondering if there is a better way to do this?
62-
c.SetRequest(c.Request().Clone(ctx))
63-
64-
done := make(chan error, 1)
65-
go func() {
66-
defer func() {
67-
if r := recover(); r != nil {
68-
err, ok := r.(error)
69-
if !ok {
70-
err = fmt.Errorf("panic recovered in timeout middleware: %v", r)
71-
}
72-
c.Logger().Error(err)
73-
done <- err
74-
}
75-
}()
76-
77-
// This goroutine will keep running even if this middleware times out and
78-
// will be stopped when ctx.Done() is called down the next(c) call chain
79-
done <- next(c)
80-
}()
59+
handlerWrapper := echoHandlerFuncWrapper{
60+
ctx: c,
61+
handler: next,
62+
errChan: make(chan error, 1),
63+
}
64+
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
65+
handler.ServeHTTP(c.Response().Writer, c.Request())
8166

8267
select {
83-
case <-ctx.Done():
84-
if config.ErrorHandler != nil {
85-
return config.ErrorHandler(ctx.Err(), c)
86-
}
87-
return ctx.Err()
88-
case err := <-done:
68+
case err := <-handlerWrapper.errChan:
8969
return err
70+
default:
71+
return nil
9072
}
9173
}
9274
}
9375
}
76+
77+
type echoHandlerFuncWrapper struct {
78+
ctx echo.Context
79+
handler echo.HandlerFunc
80+
errChan chan error
81+
}
82+
83+
func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
84+
// replace writer with TimeoutHandler custom one. This will guarantee that
85+
// `writes by h to its ResponseWriter will return ErrHandlerTimeout.`
86+
originalWriter := t.ctx.Response().Writer
87+
t.ctx.Response().Writer = rw
88+
89+
err := t.handler(t.ctx)
90+
t.ctx.Response().Writer = originalWriter
91+
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
92+
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
93+
}
94+
if err != nil {
95+
t.errChan <- err
96+
}
97+
}

middleware/timeout_test.go

Lines changed: 90 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
package middleware
44

55
import (
6-
"context"
76
"errors"
87
"github.com/labstack/echo/v4"
98
"github.com/stretchr/testify/assert"
@@ -22,6 +21,7 @@ func TestTimeoutSkipper(t *testing.T) {
2221
Skipper: func(context echo.Context) bool {
2322
return true
2423
},
24+
Timeout: 1 * time.Nanosecond,
2525
})
2626

2727
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -31,18 +31,17 @@ func TestTimeoutSkipper(t *testing.T) {
3131
c := e.NewContext(req, rec)
3232

3333
err := m(func(c echo.Context) error {
34-
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
35-
return nil
34+
time.Sleep(25 * time.Microsecond)
35+
return errors.New("response from handler")
3636
})(c)
3737

38-
assert.NoError(t, err)
38+
// if not skipped we would have not returned error due context timeout logic
39+
assert.EqualError(t, err, "response from handler")
3940
}
4041

4142
func TestTimeoutWithTimeout0(t *testing.T) {
4243
t.Parallel()
43-
m := TimeoutWithConfig(TimeoutConfig{
44-
Timeout: 0,
45-
})
44+
m := Timeout()
4645

4746
req := httptest.NewRequest(http.MethodGet, "/", nil)
4847
rec := httptest.NewRecorder()
@@ -58,10 +57,11 @@ func TestTimeoutWithTimeout0(t *testing.T) {
5857
assert.NoError(t, err)
5958
}
6059

61-
func TestTimeoutIsCancelable(t *testing.T) {
60+
func TestTimeoutErrorOutInHandler(t *testing.T) {
6261
t.Parallel()
6362
m := TimeoutWithConfig(TimeoutConfig{
64-
Timeout: time.Minute,
63+
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
64+
Timeout: 50 * time.Millisecond,
6565
})
6666

6767
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -71,58 +71,76 @@ func TestTimeoutIsCancelable(t *testing.T) {
7171
c := e.NewContext(req, rec)
7272

7373
err := m(func(c echo.Context) error {
74-
assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
75-
return nil
74+
return errors.New("err")
7675
})(c)
7776

78-
assert.NoError(t, err)
77+
assert.Error(t, err)
7978
}
8079

81-
func TestTimeoutErrorOutInHandler(t *testing.T) {
80+
func TestTimeoutTestRequestClone(t *testing.T) {
8281
t.Parallel()
83-
m := Timeout()
84-
85-
req := httptest.NewRequest(http.MethodGet, "/", nil)
82+
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
83+
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
84+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
8685
rec := httptest.NewRecorder()
8786

87+
m := TimeoutWithConfig(TimeoutConfig{
88+
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
89+
Timeout: 1 * time.Second,
90+
})
91+
8892
e := echo.New()
8993
c := e.NewContext(req, rec)
9094

9195
err := m(func(c echo.Context) error {
92-
return errors.New("err")
96+
// Cookie test
97+
cookie, err := c.Request().Cookie("cookie")
98+
if assert.NoError(t, err) {
99+
assert.EqualValues(t, "cookie", cookie.Name)
100+
assert.EqualValues(t, "value", cookie.Value)
101+
}
102+
103+
// Form values
104+
if assert.NoError(t, c.Request().ParseForm()) {
105+
assert.EqualValues(t, "value", c.Request().FormValue("form"))
106+
}
107+
108+
// Query string
109+
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
110+
return nil
93111
})(c)
94112

95-
assert.Error(t, err)
113+
assert.NoError(t, err)
114+
96115
}
97116

98-
func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) {
117+
func TestTimeoutRecoversPanic(t *testing.T) {
99118
t.Parallel()
100-
m := TimeoutWithConfig(TimeoutConfig{
101-
Timeout: time.Second,
102-
ErrorHandler: func(err error, e echo.Context) error {
103-
assert.EqualError(t, err, context.DeadlineExceeded.Error())
104-
return errors.New("err")
105-
},
119+
e := echo.New()
120+
e.Use(Recover()) // recover middleware will handler our panic
121+
e.Use(TimeoutWithConfig(TimeoutConfig{
122+
Timeout: 50 * time.Millisecond,
123+
}))
124+
125+
e.GET("/", func(c echo.Context) error {
126+
panic("panic!!!")
106127
})
107128

108129
req := httptest.NewRequest(http.MethodGet, "/", nil)
109130
rec := httptest.NewRecorder()
110131

111-
e := echo.New()
112-
c := e.NewContext(req, rec)
113-
114-
err := m(func(c echo.Context) error {
115-
time.Sleep(time.Minute)
116-
return nil
117-
})(c)
118-
119-
assert.EqualError(t, err, errors.New("err").Error())
132+
assert.NotPanics(t, func() {
133+
e.ServeHTTP(rec, req)
134+
})
120135
}
121136

122-
func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
137+
func TestTimeoutDataRace(t *testing.T) {
123138
t.Parallel()
139+
140+
timeout := 1 * time.Millisecond
124141
m := TimeoutWithConfig(TimeoutConfig{
125-
Timeout: time.Second,
142+
Timeout: timeout,
143+
ErrorMessage: "Timeout! change me",
126144
})
127145

128146
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -132,54 +150,57 @@ func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
132150
c := e.NewContext(req, rec)
133151

134152
err := m(func(c echo.Context) error {
135-
time.Sleep(time.Minute)
136-
return nil
153+
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
154+
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
155+
// difference over 500microseconds (0.5millisecond) response seems to be reliable
156+
time.Sleep(timeout) // timeout and handler execution time difference is close to zero
157+
return c.String(http.StatusOK, "Hello, World!")
137158
})(c)
138159

139-
assert.EqualError(t, err, context.DeadlineExceeded.Error())
160+
assert.NoError(t, err)
161+
162+
if rec.Code == http.StatusServiceUnavailable {
163+
assert.Equal(t, "Timeout! change me", rec.Body.String())
164+
} else {
165+
assert.Equal(t, "Hello, World!", rec.Body.String())
166+
}
140167
}
141168

142-
func TestTimeoutTestRequestClone(t *testing.T) {
169+
func TestTimeoutWithErrorMessage(t *testing.T) {
143170
t.Parallel()
144-
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
145-
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
146-
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
147-
rec := httptest.NewRecorder()
148171

172+
timeout := 1 * time.Millisecond
149173
m := TimeoutWithConfig(TimeoutConfig{
150-
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
151-
Timeout: time.Second,
174+
Timeout: timeout,
175+
ErrorMessage: "Timeout! change me",
152176
})
153177

178+
req := httptest.NewRequest(http.MethodGet, "/", nil)
179+
rec := httptest.NewRecorder()
180+
154181
e := echo.New()
155182
c := e.NewContext(req, rec)
156183

157184
err := m(func(c echo.Context) error {
158-
// Cookie test
159-
cookie, err := c.Request().Cookie("cookie")
160-
if assert.NoError(t, err) {
161-
assert.EqualValues(t, "cookie", cookie.Name)
162-
assert.EqualValues(t, "value", cookie.Value)
163-
}
164-
165-
// Form values
166-
if assert.NoError(t, c.Request().ParseForm()) {
167-
assert.EqualValues(t, "value", c.Request().FormValue("form"))
168-
}
169-
170-
// Query string
171-
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
172-
return nil
185+
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
186+
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
187+
// difference over 500microseconds (0.5millisecond) response seems to be reliable
188+
time.Sleep(timeout + 1*time.Millisecond) // minimal difference
189+
return c.String(http.StatusOK, "Hello, World!")
173190
})(c)
174191

175192
assert.NoError(t, err)
176-
193+
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
194+
assert.Equal(t, "Timeout! change me", rec.Body.String())
177195
}
178196

179-
func TestTimeoutRecoversPanic(t *testing.T) {
197+
func TestTimeoutWithDefaultErrorMessage(t *testing.T) {
180198
t.Parallel()
199+
200+
timeout := 1 * time.Millisecond
181201
m := TimeoutWithConfig(TimeoutConfig{
182-
Timeout: 25 * time.Millisecond,
202+
Timeout: timeout,
203+
ErrorMessage: "",
183204
})
184205

185206
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -189,8 +210,11 @@ func TestTimeoutRecoversPanic(t *testing.T) {
189210
c := e.NewContext(req, rec)
190211

191212
err := m(func(c echo.Context) error {
192-
panic("panic in handler")
213+
time.Sleep(timeout + 25*time.Millisecond)
214+
return c.String(http.StatusOK, "Hello, World!")
193215
})(c)
194216

195-
assert.Error(t, err, "panic recovered in timeout middleware: panic in handler")
217+
assert.NoError(t, err)
218+
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
219+
assert.Equal(t, `<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>`, rec.Body.String())
196220
}

0 commit comments

Comments
 (0)