diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 04fda55..e81f102 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -25,9 +25,14 @@ type AuthRouter struct { // When true, do not auto-redirect to the sole provider even if // there is only one provider and no password is set. noProviderAutoSelect bool + // userInfoFields is a list of top-level keys to retain from the + // provider's userinfo response. When non-empty, all other keys are + // stripped before the data is stored in the session cookie. This + // prevents oversized cookies when the provider returns many claims. + userInfoFields []string } -func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, providers ...Provider) (*AuthRouter, error) { +func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, userInfoFields []string, providers ...Provider) (*AuthRouter, error) { tmpl, err := template.ParseFS(templateFS, "templates/login.html") if err != nil { return nil, err @@ -50,6 +55,7 @@ func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, providers . unauthorizedTemplate: unauthorizedTmpl, errorTemplate: errorTmpl, noProviderAutoSelect: noProviderAutoSelect, + userInfoFields: userInfoFields, }, nil } @@ -102,6 +108,9 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { session.Set(SessionKeyAuthorized, true) session.Set(SessionKeyUserID, user) if userInfo != nil { + if len(a.userInfoFields) > 0 { + userInfo = filterUserInfo(userInfo, a.userInfoFields) + } if userInfoJSON, err := json.Marshal(userInfo); err == nil { session.Set(SessionKeyUserInfo, string(userInfoJSON)) } @@ -282,6 +291,17 @@ func (a *AuthRouter) renderUnauthorized(c *gin.Context, userID, providerName str } } +// filterUserInfo returns a copy of m containing only the listed keys. +func filterUserInfo(m map[string]any, keys []string) map[string]any { + filtered := make(map[string]any, len(keys)) + for _, k := range keys { + if v, ok := m[k]; ok { + filtered[k] = v + } + } + return filtered +} + func (a *AuthRouter) renderError(c *gin.Context, err error) { data := errorTemplateData{ ErrorMessage: err.Error(), diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 14dbb9d..fa75f3d 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -1,6 +1,7 @@ package auth import ( + "encoding/json" "net/http" "net/http/cookiejar" "net/http/httptest" @@ -42,6 +43,164 @@ func setupClient() *http.Client { } } +func TestFilterUserInfo(t *testing.T) { + t.Run("filters to specified keys", func(t *testing.T) { + info := map[string]any{ + "email": "user@example.com", + "preferred_username": "user", + "groups": []any{"admin", "dev"}, + "realm_access": map[string]any{"roles": []any{"offline_access"}}, + } + filtered := filterUserInfo(info, []string{"email", "preferred_username"}) + require.Equal(t, map[string]any{ + "email": "user@example.com", + "preferred_username": "user", + }, filtered) + }) + + t.Run("missing keys are skipped", func(t *testing.T) { + info := map[string]any{"email": "user@example.com"} + filtered := filterUserInfo(info, []string{"email", "name"}) + require.Equal(t, map[string]any{"email": "user@example.com"}, filtered) + }) + + t.Run("empty keys returns empty map", func(t *testing.T) { + info := map[string]any{"email": "user@example.com"} + filtered := filterUserInfo(info, []string{}) + require.Empty(t, filtered) + }) + + t.Run("nil input returns empty map", func(t *testing.T) { + filtered := filterUserInfo(nil, []string{"email"}) + require.Empty(t, filtered) + }) +} + +func TestUserInfoFilteringInOAuthFlow(t *testing.T) { + t.Run("session stores only filtered userinfo fields", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + fullUserInfo := map[string]any{ + "email": "user@example.com", + "preferred_username": "user", + "groups": []any{"admin", "developers", "platform-team"}, + "realm_access": map[string]any{"roles": []any{"offline_access", "uma_authorization"}}, + "resource_access": map[string]any{"account": map[string]any{"roles": []any{"view-profile"}}}, + } + + mockToken := &oauth2.Token{AccessToken: "test-token"} + mockProvider := NewMockProvider(ctrl) + mockProvider.EXPECT().Name().Return("test").AnyTimes() + mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() + mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() + mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil) + mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) + mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "user@example.com", fullUserInfo, nil) + + authRouter, err := NewAuthRouter(nil, false, []string{"email", "preferred_username"}, mockProvider) + require.NoError(t, err) + + // Add a route that reads back the session to verify stored userinfo + var storedUserInfo string + router := gin.New() + store := memstore.NewStore([]byte("test-secret")) + router.Use(sessions.Sessions("session", store)) + authRouter.SetupRoutes(router) + router.GET("/check-session", func(c *gin.Context) { + session := sessions.Default(c) + if v, ok := session.Get(SessionKeyUserInfo).(string); ok { + storedUserInfo = v + } + c.String(http.StatusOK, "ok") + }) + + server := httptest.NewServer(router) + defer server.Close() + client := setupClient() + + // Start auth flow to set oauth state + resp, err := client.Get(server.URL + "/.auth/test") + require.NoError(t, err) + resp.Body.Close() + + // Complete callback + resp, err = client.Get(server.URL + "/.auth/test/callback") + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, http.StatusFound, resp.StatusCode) + + // Read back session + resp, err = client.Get(server.URL + "/check-session") + require.NoError(t, err) + resp.Body.Close() + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(storedUserInfo), &parsed)) + require.Equal(t, "user@example.com", parsed["email"]) + require.Equal(t, "user", parsed["preferred_username"]) + require.NotContains(t, parsed, "groups") + require.NotContains(t, parsed, "realm_access") + require.NotContains(t, parsed, "resource_access") + }) + + t.Run("nil filter stores full userinfo", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + fullUserInfo := map[string]any{ + "email": "user@example.com", + "groups": []any{"admin"}, + } + + mockToken := &oauth2.Token{AccessToken: "test-token"} + mockProvider := NewMockProvider(ctrl) + mockProvider.EXPECT().Name().Return("test").AnyTimes() + mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() + mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() + mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil) + mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) + mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "user@example.com", fullUserInfo, nil) + + authRouter, err := NewAuthRouter(nil, false, nil, mockProvider) + require.NoError(t, err) + + var storedUserInfo string + router := gin.New() + store := memstore.NewStore([]byte("test-secret")) + router.Use(sessions.Sessions("session", store)) + authRouter.SetupRoutes(router) + router.GET("/check-session", func(c *gin.Context) { + session := sessions.Default(c) + if v, ok := session.Get(SessionKeyUserInfo).(string); ok { + storedUserInfo = v + } + c.String(http.StatusOK, "ok") + }) + + server := httptest.NewServer(router) + defer server.Close() + client := setupClient() + + resp, err := client.Get(server.URL + "/.auth/test") + require.NoError(t, err) + resp.Body.Close() + + resp, err = client.Get(server.URL + "/.auth/test/callback") + require.NoError(t, err) + resp.Body.Close() + + resp, err = client.Get(server.URL + "/check-session") + require.NoError(t, err) + resp.Body.Close() + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(storedUserInfo), &parsed)) + require.Contains(t, parsed, "email") + require.Contains(t, parsed, "groups") + }) +} + func TestAuthenticationFlow(t *testing.T) { t.Run("Unauthenticated access should redirect to login", func(t *testing.T) { ctrl := gomock.NewController(t) @@ -54,7 +213,7 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() // Create AuthRouter (auto-select enabled by default) - authRouter, err := NewAuthRouter(nil, false, mockProvider) + authRouter, err := NewAuthRouter(nil, false, nil, mockProvider) require.NoError(t, err) router := setupTestRouter(authRouter) @@ -88,7 +247,7 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", map[string]any{"email": "authorized_user@example.com"}, nil) // Create AuthRouter - authRouter, err := NewAuthRouter(nil, false, mockProvider) + authRouter, err := NewAuthRouter(nil, false, nil, mockProvider) require.NoError(t, err) router := setupTestRouter(authRouter) @@ -149,7 +308,7 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", map[string]any{"email": "unauthorized_user@example.com"}, nil) // Create AuthRouter - authRouter, err := NewAuthRouter(nil, false, mockProvider) + authRouter, err := NewAuthRouter(nil, false, nil, mockProvider) require.NoError(t, err) router := setupTestRouter(authRouter) @@ -199,7 +358,7 @@ func TestLoginAutoRedirect(t *testing.T) { mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() - authRouter, err := NewAuthRouter(nil, false, mockProvider) + authRouter, err := NewAuthRouter(nil, false, nil, mockProvider) require.NoError(t, err) router := gin.New() @@ -230,7 +389,7 @@ func TestLoginAutoRedirect(t *testing.T) { mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() - authRouter, err := NewAuthRouter(nil, true, mockProvider) + authRouter, err := NewAuthRouter(nil, true, nil, mockProvider) require.NoError(t, err) router := gin.New() @@ -260,7 +419,7 @@ func TestLoginAutoRedirect(t *testing.T) { mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() // Non-empty passwordHash slice disables auto-select - authRouter, err := NewAuthRouter([]string{"dummy"}, false, mockProvider) + authRouter, err := NewAuthRouter([]string{"dummy"}, false, nil, mockProvider) require.NoError(t, err) router := gin.New() diff --git a/pkg/idp/idp_test.go b/pkg/idp/idp_test.go index f53b287..b7631c0 100644 --- a/pkg/idp/idp_test.go +++ b/pkg/idp/idp_test.go @@ -66,7 +66,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str }) // Create auth router and IDP router - authRouter, err := auth.NewAuthRouter([]string{}, false) + authRouter, err := auth.NewAuthRouter([]string{}, false, nil) require.NoError(t, err) logger, _ := zap.NewDevelopment() diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index 1ce8ad8..14e01c0 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -285,7 +285,11 @@ func Run( passwordHashes = append(passwordHashes, passwordHash) } - authRouter, err := auth.NewAuthRouter(passwordHashes, noProviderAutoSelect, providers...) + // Collect the top-level userinfo keys that are actually needed so the + // session cookie doesn't store the entire provider response. + userInfoFields := userInfoFieldsFromConfig(oidcUserIDField, headerMapping) + + authRouter, err := auth.NewAuthRouter(passwordHashes, noProviderAutoSelect, userInfoFields, providers...) if err != nil { return fmt.Errorf("failed to create auth router: %w", err) } @@ -529,3 +533,28 @@ func Run( wg.Wait() return errors.Join(errs...) } + +// userInfoFieldsFromConfig extracts the top-level userinfo keys referenced +// by the OIDC user-ID field and the header mapping. JSON pointers like +// "/email" or "/preferred_username" yield "email" or "preferred_username". +func userInfoFieldsFromConfig(oidcUserIDField string, headerMapping map[string]string) []string { + seen := map[string]struct{}{} + add := func(pointer string) { + pointer = strings.TrimPrefix(pointer, "/") + if i := strings.IndexByte(pointer, '/'); i != -1 { + pointer = pointer[:i] + } + if pointer != "" { + seen[pointer] = struct{}{} + } + } + add(oidcUserIDField) + for pointer := range headerMapping { + add(pointer) + } + fields := make([]string, 0, len(seen)) + for k := range seen { + fields = append(fields, k) + } + return fields +} diff --git a/pkg/mcp-proxy/main_test.go b/pkg/mcp-proxy/main_test.go index b18a03d..e4a6591 100644 --- a/pkg/mcp-proxy/main_test.go +++ b/pkg/mcp-proxy/main_test.go @@ -122,6 +122,40 @@ func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) { require.True(t, streamingOnlyReceived, "httpStreamingOnly should be forwarded to proxy router") } +func TestUserInfoFieldsFromConfig(t *testing.T) { + t.Run("extracts fields from header mapping and user ID field", func(t *testing.T) { + fields := userInfoFieldsFromConfig("/email", map[string]string{ + "/email": "X-Forwarded-Email", + "/preferred_username": "X-Forwarded-User", + }) + require.ElementsMatch(t, []string{"email", "preferred_username"}, fields) + }) + + t.Run("handles nested JSON pointers by taking top-level key", func(t *testing.T) { + fields := userInfoFieldsFromConfig("/email", map[string]string{ + "/address/street": "X-Street", + }) + require.ElementsMatch(t, []string{"email", "address"}, fields) + }) + + t.Run("deduplicates overlapping fields", func(t *testing.T) { + fields := userInfoFieldsFromConfig("/email", map[string]string{ + "/email": "X-Forwarded-Email", + }) + require.Equal(t, []string{"email"}, fields) + }) + + t.Run("empty config returns empty slice", func(t *testing.T) { + fields := userInfoFieldsFromConfig("", nil) + require.Empty(t, fields) + }) + + t.Run("handles user ID field without leading slash", func(t *testing.T) { + fields := userInfoFieldsFromConfig("email", nil) + require.Equal(t, []string{"email"}, fields) + }) +} + func TestHealthzEndpoint(t *testing.T) { gin.SetMode(gin.TestMode) router := gin.New()