Skip to content

Commit a43f4f4

Browse files
authored
feat: add HEADER_MAPPING_BASE flag to control JWT claims source for HEADER_MAPPING (#144)
* fix: fall back to top-level JWT claims when userinfo is absent in HEADER_MAPPING * feat: add HEADER_MAPPING_BASE flag to control JWT claims source Instead of implicit fallback from userinfo to top-level claims, add an explicit --header-mapping-base / HEADER_MAPPING_BASE flag (default: /userinfo) that controls which JWT claims subtree HEADER_MAPPING reads from. Set to / for top-level claims. Closes #143
1 parent cdef9cf commit a43f4f4

7 files changed

Lines changed: 163 additions & 17 deletions

File tree

docs/docs/configuration.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,13 @@ openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048
176176

177177
### Proxy Options
178178

179-
| Option | Environment Variable | Default | Description |
180-
| ----------------------- | --------------------- | ------- | ----------------------------------------------------------------------------------------------------- |
181-
| `--proxy-bearer-token` | `PROXY_BEARER_TOKEN` | - | Bearer token to add to Authorization header when proxying requests |
182-
| `--proxy-headers` | `PROXY_HEADERS` | - | Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2) |
183-
| `--http-streaming-only` | `HTTP_STREAMING_ONLY` | `false` | Reject SSE (GET) requests and keep the backend operating in HTTP streaming-only mode |
184-
| `--trusted-proxies` | `TRUSTED_PROXIES` | - | Comma-separated list of trusted proxies (IP addresses or CIDR ranges) |
179+
| Option | Environment Variable | Default | Description |
180+
| ----------------------- | --------------------- | ----------- | ----------------------------------------------------------------------------------------------------- |
181+
| `--proxy-bearer-token` | `PROXY_BEARER_TOKEN` | - | Bearer token to add to Authorization header when proxying requests |
182+
| `--proxy-headers` | `PROXY_HEADERS` | - | Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2) |
183+
| `--header-mapping` | `HEADER_MAPPING` | - | Comma-separated mapping of JSON pointer paths to header names (e.g., `/email:X-Forwarded-Email`) |
184+
| `--header-mapping-base` | `HEADER_MAPPING_BASE` | `/userinfo` | JSON pointer base path for header mapping claims lookup (e.g., `/userinfo` or `/`) |
185+
| `--http-streaming-only` | `HTTP_STREAMING_ONLY` | `false` | Reject SSE (GET) requests and keep the backend operating in HTTP streaming-only mode |
186+
| `--trusted-proxies` | `TRUSTED_PROXIES` | - | Comma-separated list of trusted proxies (IP addresses or CIDR ranges) |
185187

186188
For practical configuration examples including environment variables, Docker Compose, and Kubernetes deployments, see the [Configuration Examples](./examples.md) page.

main.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ type proxyRunnerFunc func(
155155
proxyTarget []string,
156156
httpStreamingOnly bool,
157157
headerMapping map[string]string,
158+
headerMappingBase string,
158159
) error
159160

160161
func main() {
@@ -200,6 +201,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
200201
var proxyBearerToken string
201202
var proxyHeaders string
202203
var headerMapping string
204+
var headerMappingBase string
203205
var httpStreamingOnly bool
204206
var trustedProxies string
205207

@@ -323,6 +325,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
323325
args,
324326
httpStreamingOnly,
325327
headerMappingMap,
328+
headerMappingBase,
326329
); err != nil {
327330
panic(err)
328331
}
@@ -376,7 +379,8 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
376379
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", getEnvWithDefault("TRUSTED_PROXIES", ""), "Comma-separated list of trusted proxies (IP addresses or CIDR ranges)")
377380
rootCmd.Flags().StringVar(&proxyHeaders, "proxy-headers", getEnvWithDefault("PROXY_HEADERS", ""), "Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2)")
378381
rootCmd.Flags().BoolVar(&httpStreamingOnly, "http-streaming-only", getEnvBoolWithDefault("HTTP_STREAMING_ONLY", false), "Reject SSE (GET) requests and keep the backend in HTTP streaming-only mode")
379-
rootCmd.Flags().StringVar(&headerMapping, "header-mapping", getEnvWithDefault("HEADER_MAPPING", ""), "Comma-separated mapping of userinfo JSON pointer paths to header names (e.g., /email:X-Forwarded-Email,/preferred_username:X-Forwarded-User)")
382+
rootCmd.Flags().StringVar(&headerMapping, "header-mapping", getEnvWithDefault("HEADER_MAPPING", ""), "Comma-separated mapping of JSON pointer paths to header names (e.g., /email:X-Forwarded-Email,/preferred_username:X-Forwarded-User)")
383+
rootCmd.Flags().StringVar(&headerMappingBase, "header-mapping-base", getEnvWithDefault("HEADER_MAPPING_BASE", "/userinfo"), "JSON pointer base path for header mapping claims lookup (e.g., /userinfo or /)")
380384

381385
return rootCmd
382386
}

main_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFlag(t *testing.T) {
403403
proxyTarget []string,
404404
httpStreamingOnly bool,
405405
headerMapping map[string]string,
406+
headerMappingBase string,
406407
) error {
407408
streamingOnly = httpStreamingOnly
408409
receivedTargets = proxyTarget
@@ -467,6 +468,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFromEnv(t *testing.T) {
467468
proxyTarget []string,
468469
httpStreamingOnly bool,
469470
headerMapping map[string]string,
471+
headerMappingBase string,
470472
) error {
471473
streamingOnly = httpStreamingOnly
472474
return nil

pkg/mcp-proxy/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ func Run(
7777
proxyTarget []string,
7878
httpStreamingOnly bool,
7979
headerMapping map[string]string,
80+
headerMappingBase string,
8081
) error {
8182
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
8283
defer stop()
@@ -297,7 +298,7 @@ func Run(
297298
if err != nil {
298299
return fmt.Errorf("failed to create IDP router: %w", err)
299300
}
300-
proxyRouter, err := newProxyRouter(externalURL, beHandler, &privKey.PublicKey, proxyHeadersMap, httpStreamingOnly, headerMapping)
301+
proxyRouter, err := newProxyRouter(externalURL, beHandler, &privKey.PublicKey, proxyHeadersMap, httpStreamingOnly, headerMapping, headerMappingBase)
301302
if err != nil {
302303
return fmt.Errorf("failed to create proxy router: %w", err)
303304
}

pkg/mcp-proxy/main_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestRun_NormalizesExternalURLTrailingSlash(t *testing.T) {
3434
for _, tt := range cases {
3535
t.Run(tt.name, func(t *testing.T) {
3636
var receivedURL string
37-
newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool, headerMapping map[string]string) (*proxy.ProxyRouter, error) {
37+
newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool, headerMapping map[string]string, headerMappingBase string) (*proxy.ProxyRouter, error) {
3838
receivedURL = externalURL
3939
return nil, errors.New("stop early")
4040
}
@@ -47,7 +47,7 @@ func TestRun_NormalizesExternalURLTrailingSlash(t *testing.T) {
4747
"", "", nil, nil,
4848
"", "", "", nil, "", "", nil, nil, nil, nil,
4949
false, "", "", nil, nil, "",
50-
[]string{"http://example.com"}, false, nil,
50+
[]string{"http://example.com"}, false, nil, "/userinfo",
5151
)
5252

5353
if tt.wantErr {
@@ -70,7 +70,7 @@ func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) {
7070
})
7171

7272
var streamingOnlyReceived bool
73-
newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool, headerMapping map[string]string) (*proxy.ProxyRouter, error) {
73+
newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool, headerMapping map[string]string, headerMappingBase string) (*proxy.ProxyRouter, error) {
7474
streamingOnlyReceived = httpStreamingOnly
7575
return nil, errors.New("proxy router init failed")
7676
}
@@ -115,6 +115,7 @@ func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) {
115115
[]string{"http://example.com"},
116116
true,
117117
nil,
118+
"/userinfo",
118119
)
119120

120121
require.Error(t, err)

pkg/proxy/proxy.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ type ProxyRouter struct {
1818
proxyHeaders http.Header
1919
httpStreamingOnly bool
2020
headerMapping map[string]string
21+
headerMappingBase string
2122
}
2223

2324
func NewProxyRouter(
@@ -27,6 +28,7 @@ func NewProxyRouter(
2728
proxyHeaders http.Header,
2829
httpStreamingOnly bool,
2930
headerMapping map[string]string,
31+
headerMappingBase string,
3032
) (*ProxyRouter, error) {
3133
return &ProxyRouter{
3234
externalURL: externalURL,
@@ -35,6 +37,7 @@ func NewProxyRouter(
3537
proxyHeaders: proxyHeaders,
3638
httpStreamingOnly: httpStreamingOnly,
3739
headerMapping: headerMapping,
40+
headerMappingBase: headerMappingBase,
3841
}, nil
3942
}
4043

@@ -93,9 +96,18 @@ func (p *ProxyRouter) handleProxy(c *gin.Context) {
9396

9497
if len(p.headerMapping) > 0 {
9598
if claims, ok := token.Claims.(jwt.MapClaims); ok {
96-
if userinfo, exists := claims["userinfo"]; exists {
99+
var source any = map[string]any(claims)
100+
if p.headerMappingBase != "/" {
101+
val, err := jsonpointer.Get(source, p.headerMappingBase)
102+
if err != nil {
103+
source = nil
104+
} else {
105+
source = val
106+
}
107+
}
108+
if source != nil {
97109
for pointer, headerName := range p.headerMapping {
98-
val, err := jsonpointer.Get(userinfo, pointer)
110+
val, err := jsonpointer.Get(source, pointer)
99111
if err != nil {
100112
continue
101113
}

pkg/proxy/proxy_test.go

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func TestProxyRouter_HandleProxy_ValidToken(t *testing.T) {
7070
proxyHeaders := make(http.Header)
7171
proxyHeaders.Set("X-Forwarded-By", "mcp-auth-proxy")
7272

73-
proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, proxyHeaders, false, nil)
73+
proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, proxyHeaders, false, nil, "/userinfo")
7474
require.NoError(t, err)
7575

7676
gin.SetMode(gin.TestMode)
@@ -163,7 +163,7 @@ func TestProxyRouter_HeaderMapping(t *testing.T) {
163163
w.WriteHeader(http.StatusOK)
164164
})
165165

166-
proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, false, tt.headerMapping)
166+
proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, false, tt.headerMapping, "/userinfo")
167167
require.NoError(t, err)
168168

169169
gin.SetMode(gin.TestMode)
@@ -197,11 +197,135 @@ func TestProxyRouter_HeaderMapping(t *testing.T) {
197197
}
198198
}
199199

200+
func TestProxyRouter_HeaderMappingBase(t *testing.T) {
201+
privateKey, publicKey, err := generateRSAKeyPair()
202+
require.NoError(t, err)
203+
204+
cases := []struct {
205+
name string
206+
headerMapping map[string]string
207+
headerMappingBase string
208+
claims jwt.MapClaims
209+
expectedHeaders map[string]string
210+
missingHeaders []string
211+
}{
212+
{
213+
name: "base=/ reads top-level claims",
214+
headerMapping: map[string]string{"/email": "X-Forwarded-Email"},
215+
headerMappingBase: "/",
216+
claims: jwt.MapClaims{
217+
"sub": "test-user",
218+
"email": "[email protected]",
219+
"exp": time.Now().Add(time.Hour).Unix(),
220+
"iat": time.Now().Unix(),
221+
},
222+
expectedHeaders: map[string]string{
223+
"X-Forwarded-Email": "[email protected]",
224+
},
225+
},
226+
{
227+
name: "base=/userinfo reads userinfo claims",
228+
headerMapping: map[string]string{"/email": "X-Forwarded-Email"},
229+
headerMappingBase: "/userinfo",
230+
claims: jwt.MapClaims{
231+
"sub": "test-user",
232+
"email": "[email protected]",
233+
"userinfo": map[string]any{"email": "[email protected]"},
234+
"exp": time.Now().Add(time.Hour).Unix(),
235+
"iat": time.Now().Unix(),
236+
},
237+
expectedHeaders: map[string]string{
238+
"X-Forwarded-Email": "[email protected]",
239+
},
240+
},
241+
{
242+
name: "base=/ with multiple claims",
243+
headerMapping: map[string]string{"/email": "X-Forwarded-Email", "/name": "X-Forwarded-Name"},
244+
headerMappingBase: "/",
245+
claims: jwt.MapClaims{
246+
"sub": "test-user",
247+
"email": "[email protected]",
248+
"name": "John Doe",
249+
"exp": time.Now().Add(time.Hour).Unix(),
250+
"iat": time.Now().Unix(),
251+
},
252+
expectedHeaders: map[string]string{
253+
"X-Forwarded-Email": "[email protected]",
254+
"X-Forwarded-Name": "John Doe",
255+
},
256+
},
257+
{
258+
name: "base=/userinfo skips when userinfo is absent",
259+
headerMapping: map[string]string{"/email": "X-Forwarded-Email"},
260+
headerMappingBase: "/userinfo",
261+
claims: jwt.MapClaims{
262+
"sub": "test-user",
263+
"email": "[email protected]",
264+
"exp": time.Now().Add(time.Hour).Unix(),
265+
"iat": time.Now().Unix(),
266+
},
267+
missingHeaders: []string{"X-Forwarded-Email"},
268+
},
269+
{
270+
name: "base=/ missing claim is skipped",
271+
headerMapping: map[string]string{"/email": "X-Forwarded-Email", "/missing": "X-Missing"},
272+
headerMappingBase: "/",
273+
claims: jwt.MapClaims{
274+
"sub": "test-user",
275+
"email": "[email protected]",
276+
"exp": time.Now().Add(time.Hour).Unix(),
277+
"iat": time.Now().Unix(),
278+
},
279+
expectedHeaders: map[string]string{
280+
"X-Forwarded-Email": "[email protected]",
281+
},
282+
missingHeaders: []string{"X-Missing"},
283+
},
284+
}
285+
286+
for _, tt := range cases {
287+
t.Run(tt.name, func(t *testing.T) {
288+
receivedHeaders := http.Header{}
289+
proxyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
290+
for k, v := range r.Header {
291+
receivedHeaders[k] = v
292+
}
293+
w.WriteHeader(http.StatusOK)
294+
})
295+
296+
proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, false, tt.headerMapping, tt.headerMappingBase)
297+
require.NoError(t, err)
298+
299+
gin.SetMode(gin.TestMode)
300+
router := gin.New()
301+
proxyRouter.SetupRoutes(router)
302+
303+
token, err := createJWT(privateKey, tt.claims)
304+
require.NoError(t, err)
305+
306+
req, err := http.NewRequest("GET", "/test", nil)
307+
require.NoError(t, err)
308+
req.Header.Set("Authorization", "Bearer "+token)
309+
310+
w := httptest.NewRecorder()
311+
router.ServeHTTP(w, req)
312+
313+
assert.Equal(t, http.StatusOK, w.Code)
314+
for header, expected := range tt.expectedHeaders {
315+
assert.Equal(t, expected, receivedHeaders.Get(header), "header %s mismatch", header)
316+
}
317+
for _, header := range tt.missingHeaders {
318+
assert.Empty(t, receivedHeaders.Get(header), "header %s should not be set", header)
319+
}
320+
})
321+
}
322+
}
323+
200324
func TestProxyRouter_ProtectedResourceTrailingSlash(t *testing.T) {
201325
_, publicKey, err := generateRSAKeyPair()
202326
require.NoError(t, err)
203327

204-
proxyRouter, err := NewProxyRouter("https://example.com/", http.NotFoundHandler(), publicKey, http.Header{}, false, nil)
328+
proxyRouter, err := NewProxyRouter("https://example.com/", http.NotFoundHandler(), publicKey, http.Header{}, false, nil, "/userinfo")
205329
require.NoError(t, err)
206330

207331
gin.SetMode(gin.TestMode)
@@ -305,7 +429,7 @@ func TestProxyRouter_HTTPStreamingOnlyRejectsSSE(t *testing.T) {
305429
w.WriteHeader(http.StatusOK)
306430
})
307431

308-
proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, tt.streamingOnly, nil)
432+
proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, tt.streamingOnly, nil, "/userinfo")
309433
require.NoError(t, err)
310434

311435
gin.SetMode(gin.TestMode)

0 commit comments

Comments
 (0)