Handle context canceled in ForwardAuth middleware

This commit is contained in:
Ben 2025-06-04 07:38:04 -06:00 committed by GitHub
parent bf72b9768c
commit 2949995abc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 78 additions and 1 deletions

View File

@ -17,6 +17,7 @@ import (
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/accesslog"
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/traefik/traefik/v3/pkg/tracing"
"github.com/traefik/traefik/v3/pkg/types"
"github.com/vulcand/oxy/v2/forward"
@ -195,7 +196,12 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger.Debug().Err(forwardErr).Msgf("Error calling %s", fa.address)
observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause: %s", fa.address, forwardErr)
rw.WriteHeader(http.StatusInternalServerError)
statusCode := http.StatusInternalServerError
if errors.Is(forwardErr, context.Canceled) {
statusCode = httputil.StatusClientClosedRequest
}
rw.WriteHeader(statusCode)
return
}
defer forwardResponse.Body.Close()

View File

@ -11,10 +11,12 @@ import (
"net/url"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/traefik/traefik/v3/pkg/testhelpers"
"github.com/traefik/traefik/v3/pkg/tracing"
"github.com/vulcand/oxy/v2/forward"
@ -408,6 +410,75 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) {
assert.Equal(t, "Forbidden\n", string(body))
}
func TestForwardAuthClientClosedRequest(t *testing.T) {
requestStarted := make(chan struct{})
requestCancelled := make(chan struct{})
responseComplete := make(chan struct{})
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(requestStarted)
<-requestCancelled
}))
t.Cleanup(authTs.Close)
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// next should not be called.
t.Fail()
})
auth := dynamic.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(t.Context(), next, auth, "authTest")
require.NoError(t, err)
ctx, cancel := context.WithCancel(t.Context())
req := httptest.NewRequestWithContext(ctx, "GET", "http://foo", http.NoBody)
recorder := httptest.NewRecorder()
go func() {
authMiddleware.ServeHTTP(recorder, req)
close(responseComplete)
}()
<-requestStarted
cancel()
close(requestCancelled)
<-responseComplete
assert.Equal(t, httputil.StatusClientClosedRequest, recorder.Result().StatusCode)
}
func TestForwardAuthForwardError(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// next should not be called.
t.Fail()
})
auth := dynamic.ForwardAuth{
Address: "http://non-existing-server",
}
authMiddleware, err := NewForward(t.Context(), next, auth, "authTest")
require.NoError(t, err)
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Microsecond)
defer cancel()
req := httptest.NewRequestWithContext(ctx, http.MethodGet, "http://foo", nil)
recorder := httptest.NewRecorder()
responseComplete := make(chan struct{})
go func() {
authMiddleware.ServeHTTP(recorder, req)
close(responseComplete)
}()
<-responseComplete
assert.Equal(t, http.StatusInternalServerError, recorder.Result().StatusCode)
}
func Test_writeHeader(t *testing.T) {
testCases := []struct {
name string