Skip to content

Commit 683d79a

Browse files
authored
fix: trim userinfo to mapped fields before storing in session cookie (#141)
1 parent 0fc873e commit 683d79a

5 files changed

Lines changed: 251 additions & 9 deletions

File tree

pkg/auth/auth.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,14 @@ type AuthRouter struct {
2525
// When true, do not auto-redirect to the sole provider even if
2626
// there is only one provider and no password is set.
2727
noProviderAutoSelect bool
28+
// userInfoFields is a list of top-level keys to retain from the
29+
// provider's userinfo response. When non-empty, all other keys are
30+
// stripped before the data is stored in the session cookie. This
31+
// prevents oversized cookies when the provider returns many claims.
32+
userInfoFields []string
2833
}
2934

30-
func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, providers ...Provider) (*AuthRouter, error) {
35+
func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, userInfoFields []string, providers ...Provider) (*AuthRouter, error) {
3136
tmpl, err := template.ParseFS(templateFS, "templates/login.html")
3237
if err != nil {
3338
return nil, err
@@ -50,6 +55,7 @@ func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, providers .
5055
unauthorizedTemplate: unauthorizedTmpl,
5156
errorTemplate: errorTmpl,
5257
noProviderAutoSelect: noProviderAutoSelect,
58+
userInfoFields: userInfoFields,
5359
}, nil
5460
}
5561

@@ -102,6 +108,9 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
102108
session.Set(SessionKeyAuthorized, true)
103109
session.Set(SessionKeyUserID, user)
104110
if userInfo != nil {
111+
if len(a.userInfoFields) > 0 {
112+
userInfo = filterUserInfo(userInfo, a.userInfoFields)
113+
}
105114
if userInfoJSON, err := json.Marshal(userInfo); err == nil {
106115
session.Set(SessionKeyUserInfo, string(userInfoJSON))
107116
}
@@ -282,6 +291,17 @@ func (a *AuthRouter) renderUnauthorized(c *gin.Context, userID, providerName str
282291
}
283292
}
284293

294+
// filterUserInfo returns a copy of m containing only the listed keys.
295+
func filterUserInfo(m map[string]any, keys []string) map[string]any {
296+
filtered := make(map[string]any, len(keys))
297+
for _, k := range keys {
298+
if v, ok := m[k]; ok {
299+
filtered[k] = v
300+
}
301+
}
302+
return filtered
303+
}
304+
285305
func (a *AuthRouter) renderError(c *gin.Context, err error) {
286306
data := errorTemplateData{
287307
ErrorMessage: err.Error(),

pkg/auth/auth_test.go

Lines changed: 165 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package auth
22

33
import (
4+
"encoding/json"
45
"net/http"
56
"net/http/cookiejar"
67
"net/http/httptest"
@@ -42,6 +43,164 @@ func setupClient() *http.Client {
4243
}
4344
}
4445

46+
func TestFilterUserInfo(t *testing.T) {
47+
t.Run("filters to specified keys", func(t *testing.T) {
48+
info := map[string]any{
49+
"email": "[email protected]",
50+
"preferred_username": "user",
51+
"groups": []any{"admin", "dev"},
52+
"realm_access": map[string]any{"roles": []any{"offline_access"}},
53+
}
54+
filtered := filterUserInfo(info, []string{"email", "preferred_username"})
55+
require.Equal(t, map[string]any{
56+
"email": "[email protected]",
57+
"preferred_username": "user",
58+
}, filtered)
59+
})
60+
61+
t.Run("missing keys are skipped", func(t *testing.T) {
62+
info := map[string]any{"email": "[email protected]"}
63+
filtered := filterUserInfo(info, []string{"email", "name"})
64+
require.Equal(t, map[string]any{"email": "[email protected]"}, filtered)
65+
})
66+
67+
t.Run("empty keys returns empty map", func(t *testing.T) {
68+
info := map[string]any{"email": "[email protected]"}
69+
filtered := filterUserInfo(info, []string{})
70+
require.Empty(t, filtered)
71+
})
72+
73+
t.Run("nil input returns empty map", func(t *testing.T) {
74+
filtered := filterUserInfo(nil, []string{"email"})
75+
require.Empty(t, filtered)
76+
})
77+
}
78+
79+
func TestUserInfoFilteringInOAuthFlow(t *testing.T) {
80+
t.Run("session stores only filtered userinfo fields", func(t *testing.T) {
81+
ctrl := gomock.NewController(t)
82+
defer ctrl.Finish()
83+
84+
fullUserInfo := map[string]any{
85+
"email": "[email protected]",
86+
"preferred_username": "user",
87+
"groups": []any{"admin", "developers", "platform-team"},
88+
"realm_access": map[string]any{"roles": []any{"offline_access", "uma_authorization"}},
89+
"resource_access": map[string]any{"account": map[string]any{"roles": []any{"view-profile"}}},
90+
}
91+
92+
mockToken := &oauth2.Token{AccessToken: "test-token"}
93+
mockProvider := NewMockProvider(ctrl)
94+
mockProvider.EXPECT().Name().Return("test").AnyTimes()
95+
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
96+
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
97+
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
98+
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
99+
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "[email protected]", fullUserInfo, nil)
100+
101+
authRouter, err := NewAuthRouter(nil, false, []string{"email", "preferred_username"}, mockProvider)
102+
require.NoError(t, err)
103+
104+
// Add a route that reads back the session to verify stored userinfo
105+
var storedUserInfo string
106+
router := gin.New()
107+
store := memstore.NewStore([]byte("test-secret"))
108+
router.Use(sessions.Sessions("session", store))
109+
authRouter.SetupRoutes(router)
110+
router.GET("/check-session", func(c *gin.Context) {
111+
session := sessions.Default(c)
112+
if v, ok := session.Get(SessionKeyUserInfo).(string); ok {
113+
storedUserInfo = v
114+
}
115+
c.String(http.StatusOK, "ok")
116+
})
117+
118+
server := httptest.NewServer(router)
119+
defer server.Close()
120+
client := setupClient()
121+
122+
// Start auth flow to set oauth state
123+
resp, err := client.Get(server.URL + "/.auth/test")
124+
require.NoError(t, err)
125+
resp.Body.Close()
126+
127+
// Complete callback
128+
resp, err = client.Get(server.URL + "/.auth/test/callback")
129+
require.NoError(t, err)
130+
resp.Body.Close()
131+
require.Equal(t, http.StatusFound, resp.StatusCode)
132+
133+
// Read back session
134+
resp, err = client.Get(server.URL + "/check-session")
135+
require.NoError(t, err)
136+
resp.Body.Close()
137+
138+
var parsed map[string]any
139+
require.NoError(t, json.Unmarshal([]byte(storedUserInfo), &parsed))
140+
require.Equal(t, "[email protected]", parsed["email"])
141+
require.Equal(t, "user", parsed["preferred_username"])
142+
require.NotContains(t, parsed, "groups")
143+
require.NotContains(t, parsed, "realm_access")
144+
require.NotContains(t, parsed, "resource_access")
145+
})
146+
147+
t.Run("nil filter stores full userinfo", func(t *testing.T) {
148+
ctrl := gomock.NewController(t)
149+
defer ctrl.Finish()
150+
151+
fullUserInfo := map[string]any{
152+
"email": "[email protected]",
153+
"groups": []any{"admin"},
154+
}
155+
156+
mockToken := &oauth2.Token{AccessToken: "test-token"}
157+
mockProvider := NewMockProvider(ctrl)
158+
mockProvider.EXPECT().Name().Return("test").AnyTimes()
159+
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
160+
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
161+
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
162+
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
163+
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "[email protected]", fullUserInfo, nil)
164+
165+
authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
166+
require.NoError(t, err)
167+
168+
var storedUserInfo string
169+
router := gin.New()
170+
store := memstore.NewStore([]byte("test-secret"))
171+
router.Use(sessions.Sessions("session", store))
172+
authRouter.SetupRoutes(router)
173+
router.GET("/check-session", func(c *gin.Context) {
174+
session := sessions.Default(c)
175+
if v, ok := session.Get(SessionKeyUserInfo).(string); ok {
176+
storedUserInfo = v
177+
}
178+
c.String(http.StatusOK, "ok")
179+
})
180+
181+
server := httptest.NewServer(router)
182+
defer server.Close()
183+
client := setupClient()
184+
185+
resp, err := client.Get(server.URL + "/.auth/test")
186+
require.NoError(t, err)
187+
resp.Body.Close()
188+
189+
resp, err = client.Get(server.URL + "/.auth/test/callback")
190+
require.NoError(t, err)
191+
resp.Body.Close()
192+
193+
resp, err = client.Get(server.URL + "/check-session")
194+
require.NoError(t, err)
195+
resp.Body.Close()
196+
197+
var parsed map[string]any
198+
require.NoError(t, json.Unmarshal([]byte(storedUserInfo), &parsed))
199+
require.Contains(t, parsed, "email")
200+
require.Contains(t, parsed, "groups")
201+
})
202+
}
203+
45204
func TestAuthenticationFlow(t *testing.T) {
46205
t.Run("Unauthenticated access should redirect to login", func(t *testing.T) {
47206
ctrl := gomock.NewController(t)
@@ -54,7 +213,7 @@ func TestAuthenticationFlow(t *testing.T) {
54213
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
55214

56215
// Create AuthRouter (auto-select enabled by default)
57-
authRouter, err := NewAuthRouter(nil, false, mockProvider)
216+
authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
58217
require.NoError(t, err)
59218

60219
router := setupTestRouter(authRouter)
@@ -88,7 +247,7 @@ func TestAuthenticationFlow(t *testing.T) {
88247
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", map[string]any{"email": "[email protected]"}, nil)
89248

90249
// Create AuthRouter
91-
authRouter, err := NewAuthRouter(nil, false, mockProvider)
250+
authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
92251
require.NoError(t, err)
93252

94253
router := setupTestRouter(authRouter)
@@ -149,7 +308,7 @@ func TestAuthenticationFlow(t *testing.T) {
149308
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", map[string]any{"email": "[email protected]"}, nil)
150309

151310
// Create AuthRouter
152-
authRouter, err := NewAuthRouter(nil, false, mockProvider)
311+
authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
153312
require.NoError(t, err)
154313

155314
router := setupTestRouter(authRouter)
@@ -199,7 +358,7 @@ func TestLoginAutoRedirect(t *testing.T) {
199358
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
200359
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
201360

202-
authRouter, err := NewAuthRouter(nil, false, mockProvider)
361+
authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
203362
require.NoError(t, err)
204363

205364
router := gin.New()
@@ -230,7 +389,7 @@ func TestLoginAutoRedirect(t *testing.T) {
230389
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
231390
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
232391

233-
authRouter, err := NewAuthRouter(nil, true, mockProvider)
392+
authRouter, err := NewAuthRouter(nil, true, nil, mockProvider)
234393
require.NoError(t, err)
235394

236395
router := gin.New()
@@ -260,7 +419,7 @@ func TestLoginAutoRedirect(t *testing.T) {
260419
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
261420

262421
// Non-empty passwordHash slice disables auto-select
263-
authRouter, err := NewAuthRouter([]string{"dummy"}, false, mockProvider)
422+
authRouter, err := NewAuthRouter([]string{"dummy"}, false, nil, mockProvider)
264423
require.NoError(t, err)
265424

266425
router := gin.New()

pkg/idp/idp_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str
6666
})
6767

6868
// Create auth router and IDP router
69-
authRouter, err := auth.NewAuthRouter([]string{}, false)
69+
authRouter, err := auth.NewAuthRouter([]string{}, false, nil)
7070
require.NoError(t, err)
7171

7272
logger, _ := zap.NewDevelopment()

pkg/mcp-proxy/main.go

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,11 @@ func Run(
285285
passwordHashes = append(passwordHashes, passwordHash)
286286
}
287287

288-
authRouter, err := auth.NewAuthRouter(passwordHashes, noProviderAutoSelect, providers...)
288+
// Collect the top-level userinfo keys that are actually needed so the
289+
// session cookie doesn't store the entire provider response.
290+
userInfoFields := userInfoFieldsFromConfig(oidcUserIDField, headerMapping)
291+
292+
authRouter, err := auth.NewAuthRouter(passwordHashes, noProviderAutoSelect, userInfoFields, providers...)
289293
if err != nil {
290294
return fmt.Errorf("failed to create auth router: %w", err)
291295
}
@@ -529,3 +533,28 @@ func Run(
529533
wg.Wait()
530534
return errors.Join(errs...)
531535
}
536+
537+
// userInfoFieldsFromConfig extracts the top-level userinfo keys referenced
538+
// by the OIDC user-ID field and the header mapping. JSON pointers like
539+
// "/email" or "/preferred_username" yield "email" or "preferred_username".
540+
func userInfoFieldsFromConfig(oidcUserIDField string, headerMapping map[string]string) []string {
541+
seen := map[string]struct{}{}
542+
add := func(pointer string) {
543+
pointer = strings.TrimPrefix(pointer, "/")
544+
if i := strings.IndexByte(pointer, '/'); i != -1 {
545+
pointer = pointer[:i]
546+
}
547+
if pointer != "" {
548+
seen[pointer] = struct{}{}
549+
}
550+
}
551+
add(oidcUserIDField)
552+
for pointer := range headerMapping {
553+
add(pointer)
554+
}
555+
fields := make([]string, 0, len(seen))
556+
for k := range seen {
557+
fields = append(fields, k)
558+
}
559+
return fields
560+
}

pkg/mcp-proxy/main_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,40 @@ func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) {
122122
require.True(t, streamingOnlyReceived, "httpStreamingOnly should be forwarded to proxy router")
123123
}
124124

125+
func TestUserInfoFieldsFromConfig(t *testing.T) {
126+
t.Run("extracts fields from header mapping and user ID field", func(t *testing.T) {
127+
fields := userInfoFieldsFromConfig("/email", map[string]string{
128+
"/email": "X-Forwarded-Email",
129+
"/preferred_username": "X-Forwarded-User",
130+
})
131+
require.ElementsMatch(t, []string{"email", "preferred_username"}, fields)
132+
})
133+
134+
t.Run("handles nested JSON pointers by taking top-level key", func(t *testing.T) {
135+
fields := userInfoFieldsFromConfig("/email", map[string]string{
136+
"/address/street": "X-Street",
137+
})
138+
require.ElementsMatch(t, []string{"email", "address"}, fields)
139+
})
140+
141+
t.Run("deduplicates overlapping fields", func(t *testing.T) {
142+
fields := userInfoFieldsFromConfig("/email", map[string]string{
143+
"/email": "X-Forwarded-Email",
144+
})
145+
require.Equal(t, []string{"email"}, fields)
146+
})
147+
148+
t.Run("empty config returns empty slice", func(t *testing.T) {
149+
fields := userInfoFieldsFromConfig("", nil)
150+
require.Empty(t, fields)
151+
})
152+
153+
t.Run("handles user ID field without leading slash", func(t *testing.T) {
154+
fields := userInfoFieldsFromConfig("email", nil)
155+
require.Equal(t, []string{"email"}, fields)
156+
})
157+
}
158+
125159
func TestHealthzEndpoint(t *testing.T) {
126160
gin.SetMode(gin.TestMode)
127161
router := gin.New()

0 commit comments

Comments
 (0)