Skip to content

Commit 3ba032f

Browse files
committed
fix: fall back to top-level JWT claims when userinfo is absent in HEADER_MAPPING
1 parent 0fc873e commit 3ba032f

2 files changed

Lines changed: 128 additions & 17 deletions

File tree

pkg/proxy/proxy.go

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -93,26 +93,30 @@ func (p *ProxyRouter) handleProxy(c *gin.Context) {
9393

9494
if len(p.headerMapping) > 0 {
9595
if claims, ok := token.Claims.(jwt.MapClaims); ok {
96+
source := map[string]any(claims)
9697
if userinfo, exists := claims["userinfo"]; exists {
97-
for pointer, headerName := range p.headerMapping {
98-
val, err := jsonpointer.Get(userinfo, pointer)
99-
if err != nil {
100-
continue
101-
}
102-
switch v := val.(type) {
103-
case string:
104-
c.Request.Header.Set(headerName, v)
105-
case []any:
106-
var parts []string
107-
for _, item := range v {
108-
if s, ok := item.(string); ok {
109-
parts = append(parts, s)
110-
}
98+
if ui, ok := userinfo.(map[string]any); ok {
99+
source = ui
100+
}
101+
}
102+
for pointer, headerName := range p.headerMapping {
103+
val, err := jsonpointer.Get(source, pointer)
104+
if err != nil {
105+
continue
106+
}
107+
switch v := val.(type) {
108+
case string:
109+
c.Request.Header.Set(headerName, v)
110+
case []any:
111+
var parts []string
112+
for _, item := range v {
113+
if s, ok := item.(string); ok {
114+
parts = append(parts, s)
111115
}
112-
c.Request.Header.Set(headerName, strings.Join(parts, ","))
113-
default:
114-
c.Request.Header.Set(headerName, fmt.Sprintf("%v", v))
115116
}
117+
c.Request.Header.Set(headerName, strings.Join(parts, ","))
118+
default:
119+
c.Request.Header.Set(headerName, fmt.Sprintf("%v", v))
116120
}
117121
}
118122
}

pkg/proxy/proxy_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,113 @@ func TestProxyRouter_HeaderMapping(t *testing.T) {
197197
}
198198
}
199199

200+
func TestProxyRouter_HeaderMapping_FallbackToTopLevelClaims(t *testing.T) {
201+
privateKey, publicKey, err := generateRSAKeyPair()
202+
require.NoError(t, err)
203+
204+
cases := []struct {
205+
name string
206+
headerMapping map[string]string
207+
claims jwt.MapClaims
208+
expectedHeaders map[string]string
209+
missingHeaders []string
210+
}{
211+
{
212+
name: "top-level claim when userinfo is absent",
213+
headerMapping: map[string]string{"/email": "X-Forwarded-Email"},
214+
claims: jwt.MapClaims{
215+
"sub": "test-user",
216+
"email": "[email protected]",
217+
"exp": time.Now().Add(time.Hour).Unix(),
218+
"iat": time.Now().Unix(),
219+
},
220+
expectedHeaders: map[string]string{
221+
"X-Forwarded-Email": "[email protected]",
222+
},
223+
},
224+
{
225+
name: "userinfo takes precedence over top-level claim",
226+
headerMapping: map[string]string{"/email": "X-Forwarded-Email"},
227+
claims: jwt.MapClaims{
228+
"sub": "test-user",
229+
"email": "[email protected]",
230+
"userinfo": map[string]any{"email": "[email protected]"},
231+
"exp": time.Now().Add(time.Hour).Unix(),
232+
"iat": time.Now().Unix(),
233+
},
234+
expectedHeaders: map[string]string{
235+
"X-Forwarded-Email": "[email protected]",
236+
},
237+
},
238+
{
239+
name: "multiple top-level claims without userinfo",
240+
headerMapping: map[string]string{"/email": "X-Forwarded-Email", "/name": "X-Forwarded-Name"},
241+
claims: jwt.MapClaims{
242+
"sub": "test-user",
243+
"email": "[email protected]",
244+
"name": "John Doe",
245+
"exp": time.Now().Add(time.Hour).Unix(),
246+
"iat": time.Now().Unix(),
247+
},
248+
expectedHeaders: map[string]string{
249+
"X-Forwarded-Email": "[email protected]",
250+
"X-Forwarded-Name": "John Doe",
251+
},
252+
},
253+
{
254+
name: "missing claim in both userinfo and top-level",
255+
headerMapping: map[string]string{"/email": "X-Forwarded-Email", "/missing": "X-Missing"},
256+
claims: jwt.MapClaims{
257+
"sub": "test-user",
258+
"email": "[email protected]",
259+
"exp": time.Now().Add(time.Hour).Unix(),
260+
"iat": time.Now().Unix(),
261+
},
262+
expectedHeaders: map[string]string{
263+
"X-Forwarded-Email": "[email protected]",
264+
},
265+
missingHeaders: []string{"X-Missing"},
266+
},
267+
}
268+
269+
for _, tt := range cases {
270+
t.Run(tt.name, func(t *testing.T) {
271+
receivedHeaders := http.Header{}
272+
proxyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
273+
for k, v := range r.Header {
274+
receivedHeaders[k] = v
275+
}
276+
w.WriteHeader(http.StatusOK)
277+
})
278+
279+
proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, false, tt.headerMapping)
280+
require.NoError(t, err)
281+
282+
gin.SetMode(gin.TestMode)
283+
router := gin.New()
284+
proxyRouter.SetupRoutes(router)
285+
286+
token, err := createJWT(privateKey, tt.claims)
287+
require.NoError(t, err)
288+
289+
req, err := http.NewRequest("GET", "/test", nil)
290+
require.NoError(t, err)
291+
req.Header.Set("Authorization", "Bearer "+token)
292+
293+
w := httptest.NewRecorder()
294+
router.ServeHTTP(w, req)
295+
296+
assert.Equal(t, http.StatusOK, w.Code)
297+
for header, expected := range tt.expectedHeaders {
298+
assert.Equal(t, expected, receivedHeaders.Get(header), "header %s mismatch", header)
299+
}
300+
for _, header := range tt.missingHeaders {
301+
assert.Empty(t, receivedHeaders.Get(header), "header %s should not be set", header)
302+
}
303+
})
304+
}
305+
}
306+
200307
func TestProxyRouter_ProtectedResourceTrailingSlash(t *testing.T) {
201308
_, publicKey, err := generateRSAKeyPair()
202309
require.NoError(t, err)

0 commit comments

Comments
 (0)