Skip to content

Commit c7d6d43

Browse files
authored
proxy middleware: reuse echo request context (#2537)
1 parent 69a0de8 commit c7d6d43

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

middleware/proxy.go

+4
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
359359
c.Set("_error", nil)
360360
}
361361

362+
// This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request
363+
// that Balancer may have replaced with c.SetRequest.
364+
req = c.Request()
365+
362366
// Proxy
363367
switch {
364368
case c.IsWebSocket():

middleware/proxy_test.go

+60
Original file line numberDiff line numberDiff line change
@@ -747,3 +747,63 @@ func TestProxyBalancerWithNoTargets(t *testing.T) {
747747
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
748748
assert.Nil(t, rrb.Next(nil))
749749
}
750+
751+
type testContextKey string
752+
753+
type customBalancer struct {
754+
target *ProxyTarget
755+
}
756+
757+
func (b *customBalancer) AddTarget(target *ProxyTarget) bool {
758+
return false
759+
}
760+
761+
func (b *customBalancer) RemoveTarget(name string) bool {
762+
return false
763+
}
764+
765+
func (b *customBalancer) Next(c echo.Context) *ProxyTarget {
766+
ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
767+
c.SetRequest(c.Request().WithContext(ctx))
768+
return b.target
769+
}
770+
771+
func TestModifyResponseUseContext(t *testing.T) {
772+
server := httptest.NewServer(
773+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
774+
w.WriteHeader(http.StatusOK)
775+
w.Write([]byte("OK"))
776+
}),
777+
)
778+
defer server.Close()
779+
780+
targetURL, _ := url.Parse(server.URL)
781+
e := echo.New()
782+
e.Use(ProxyWithConfig(
783+
ProxyConfig{
784+
Balancer: &customBalancer{
785+
target: &ProxyTarget{
786+
Name: "tst",
787+
URL: targetURL,
788+
},
789+
},
790+
RetryCount: 1,
791+
ModifyResponse: func(res *http.Response) error {
792+
val := res.Request.Context().Value(testContextKey("FROM_BALANCER"))
793+
if valStr, ok := val.(string); ok {
794+
res.Header.Set("FROM_BALANCER", valStr)
795+
}
796+
return nil
797+
},
798+
},
799+
))
800+
801+
req := httptest.NewRequest(http.MethodGet, "/", nil)
802+
rec := httptest.NewRecorder()
803+
804+
e.ServeHTTP(rec, req)
805+
806+
assert.Equal(t, http.StatusOK, rec.Code)
807+
assert.Equal(t, "OK", rec.Body.String())
808+
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
809+
}

0 commit comments

Comments
 (0)