Skip to content

Commit 4546f40

Browse files
pvliesdonkclaude
andauthored
fix: follow backend 307/308 redirects in transparent proxy (#116)
Co-authored-by: Claude Sonnet 4.6 <[email protected]>
1 parent 68437b1 commit 4546f40

2 files changed

Lines changed: 137 additions & 0 deletions

File tree

pkg/backend/transparent.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package backend
22

33
import (
4+
"bytes"
45
"context"
56
"fmt"
7+
"io"
68
"net"
79
"net/http"
810
"net/http/httputil"
@@ -38,6 +40,79 @@ func NewTransparentBackend(logger *zap.Logger, u *url.URL, trusted []string) (Ba
3840
}, nil
3941
}
4042

43+
const maxBackendRedirects = 10
44+
45+
// redirectFollowingTransport wraps an http.RoundTripper to transparently
46+
// follow 307/308 redirects from backend servers. This is needed because
47+
// httputil.ReverseProxy uses Transport.RoundTrip() directly, which does
48+
// not follow redirects. Many MCP backends (Starlette/FastAPI) redirect
49+
// /mcp → /mcp/ via 307, which POST-based MCP clients won't follow.
50+
type redirectFollowingTransport struct {
51+
base http.RoundTripper
52+
targetHost string // only follow redirects to this host
53+
}
54+
55+
func (t *redirectFollowingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
56+
// Buffer body upfront so we can replay it on redirect.
57+
// MCP JSON-RPC payloads are small, so this is fine.
58+
var bodyBytes []byte
59+
if req.Body != nil {
60+
var err error
61+
bodyBytes, err = io.ReadAll(req.Body)
62+
if err != nil {
63+
return nil, fmt.Errorf("failed to read request body: %w", err)
64+
}
65+
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
66+
}
67+
68+
for i := 0; i <= maxBackendRedirects; i++ {
69+
resp, err := t.base.RoundTrip(req)
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
// Only follow 307 (Temporary) and 308 (Permanent) redirects.
75+
// These preserve the original method and body per HTTP spec.
76+
if resp.StatusCode != http.StatusTemporaryRedirect &&
77+
resp.StatusCode != http.StatusPermanentRedirect {
78+
return resp, nil
79+
}
80+
81+
location := resp.Header.Get("Location")
82+
if location == "" {
83+
return resp, nil
84+
}
85+
86+
// Resolve relative Location against the request URL
87+
newURL, err := req.URL.Parse(location)
88+
if err != nil {
89+
return resp, nil
90+
}
91+
92+
// Security: only follow redirects to the same backend host.
93+
// Don't leak Authorization headers or body to arbitrary hosts.
94+
if newURL.Host != "" && newURL.Host != t.targetHost {
95+
return resp, nil
96+
}
97+
98+
// Drain and close the redirect response body
99+
io.Copy(io.Discard, resp.Body)
100+
resp.Body.Close()
101+
102+
// Clone the request for the next hop, replaying the body
103+
newReq := req.Clone(req.Context())
104+
newReq.URL = newURL
105+
newReq.Host = newURL.Host
106+
if bodyBytes != nil {
107+
newReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
108+
newReq.ContentLength = int64(len(bodyBytes))
109+
}
110+
req = newReq
111+
}
112+
113+
return nil, fmt.Errorf("backend exceeded maximum redirects (%d)", maxBackendRedirects)
114+
}
115+
41116
func (p *TransparentBackend) Run(ctx context.Context) (http.Handler, error) {
42117
p.ctxLock.Lock()
43118
defer p.ctxLock.Unlock()
@@ -46,6 +121,10 @@ func (p *TransparentBackend) Run(ctx context.Context) (http.Handler, error) {
46121
}
47122
p.ctx = ctx
48123
rp := httputil.ReverseProxy{
124+
Transport: &redirectFollowingTransport{
125+
base: http.DefaultTransport,
126+
targetHost: p.url.Host,
127+
},
49128
Rewrite: func(pr *httputil.ProxyRequest) {
50129
pr.SetURL(p.url)
51130
if p.isTrusted(pr.In.RemoteAddr) {

pkg/backend/transparent_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package backend
33
import (
44
"context"
55
"encoding/json"
6+
"io"
67
"net/http"
78
"net/http/httptest"
89
"net/url"
10+
"strings"
911
"testing"
1012
"time"
1113

@@ -98,6 +100,62 @@ func TestTransparentBackendWithInvalidProxy(t *testing.T) {
98100
require.Equal(t, "http", header.Get(("X-Forwarded-Proto")))
99101
}
100102

103+
func TestTransparentBackendFollows307Redirect(t *testing.T) {
104+
r := gin.New()
105+
// Simulate Starlette's redirect_slashes: /mcp → 307 → /mcp/
106+
r.POST("/mcp", func(c *gin.Context) {
107+
c.Redirect(http.StatusTemporaryRedirect, "/mcp/")
108+
})
109+
r.POST("/mcp/", func(c *gin.Context) {
110+
body, _ := io.ReadAll(c.Request.Body)
111+
c.JSON(http.StatusOK, gin.H{
112+
"received": string(body),
113+
"method": c.Request.Method,
114+
})
115+
})
116+
ts := httptest.NewServer(r)
117+
defer ts.Close()
118+
u, _ := url.Parse(ts.URL)
119+
120+
be, err := NewTransparentBackend(zap.NewNop(), u, []string{})
121+
require.NoError(t, err)
122+
handler, err := be.Run(context.Background())
123+
require.NoError(t, err)
124+
125+
body := `{"jsonrpc":"2.0","method":"initialize"}`
126+
req := httptest.NewRequest(http.MethodPost, "/mcp", strings.NewReader(body))
127+
req.Header.Set("Content-Type", "application/json")
128+
rr := httptest.NewRecorder()
129+
handler.ServeHTTP(rr, req)
130+
131+
require.Equal(t, http.StatusOK, rr.Code, "should follow 307 internally")
132+
var resp map[string]string
133+
require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &resp))
134+
require.Equal(t, body, resp["received"], "body must be preserved across redirect")
135+
require.Equal(t, "POST", resp["method"], "method must be preserved")
136+
}
137+
138+
func TestTransparentBackendRedirectLoopProtection(t *testing.T) {
139+
r := gin.New()
140+
r.POST("/loop", func(c *gin.Context) {
141+
c.Redirect(http.StatusTemporaryRedirect, "/loop")
142+
})
143+
ts := httptest.NewServer(r)
144+
defer ts.Close()
145+
u, _ := url.Parse(ts.URL)
146+
147+
be, err := NewTransparentBackend(zap.NewNop(), u, []string{})
148+
require.NoError(t, err)
149+
handler, err := be.Run(context.Background())
150+
require.NoError(t, err)
151+
152+
req := httptest.NewRequest(http.MethodPost, "/loop", strings.NewReader("{}"))
153+
rr := httptest.NewRecorder()
154+
handler.ServeHTTP(rr, req)
155+
156+
require.Equal(t, http.StatusBadGateway, rr.Code, "should fail on redirect loop")
157+
}
158+
101159
func TestTransparentBackendRun(t *testing.T) {
102160
r := gin.New()
103161
r.GET("/", func(c *gin.Context) {

0 commit comments

Comments
 (0)