Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ type AuthRouter struct {
// When true, do not auto-redirect to the sole provider even if
// there is only one provider and no password is set.
noProviderAutoSelect bool
// userInfoFields is a list of top-level keys to retain from the
// provider's userinfo response. When non-empty, all other keys are
// stripped before the data is stored in the session cookie. This
// prevents oversized cookies when the provider returns many claims.
userInfoFields []string
}

func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, providers ...Provider) (*AuthRouter, error) {
func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, userInfoFields []string, providers ...Provider) (*AuthRouter, error) {
tmpl, err := template.ParseFS(templateFS, "templates/login.html")
if err != nil {
return nil, err
Expand All @@ -50,6 +55,7 @@ func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, providers .
unauthorizedTemplate: unauthorizedTmpl,
errorTemplate: errorTmpl,
noProviderAutoSelect: noProviderAutoSelect,
userInfoFields: userInfoFields,
}, nil
}

Expand Down Expand Up @@ -102,6 +108,9 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
session.Set(SessionKeyAuthorized, true)
session.Set(SessionKeyUserID, user)
if userInfo != nil {
if len(a.userInfoFields) > 0 {
userInfo = filterUserInfo(userInfo, a.userInfoFields)
}
if userInfoJSON, err := json.Marshal(userInfo); err == nil {
session.Set(SessionKeyUserInfo, string(userInfoJSON))
}
Expand Down Expand Up @@ -282,6 +291,17 @@ func (a *AuthRouter) renderUnauthorized(c *gin.Context, userID, providerName str
}
}

// filterUserInfo returns a copy of m containing only the listed keys.
func filterUserInfo(m map[string]any, keys []string) map[string]any {
filtered := make(map[string]any, len(keys))
for _, k := range keys {
if v, ok := m[k]; ok {
filtered[k] = v
}
}
return filtered
}

func (a *AuthRouter) renderError(c *gin.Context, err error) {
data := errorTemplateData{
ErrorMessage: err.Error(),
Expand Down
171 changes: 165 additions & 6 deletions pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"encoding/json"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
Expand Down Expand Up @@ -42,6 +43,164 @@ func setupClient() *http.Client {
}
}

func TestFilterUserInfo(t *testing.T) {
t.Run("filters to specified keys", func(t *testing.T) {
info := map[string]any{
"email": "[email protected]",
"preferred_username": "user",
"groups": []any{"admin", "dev"},
"realm_access": map[string]any{"roles": []any{"offline_access"}},
}
filtered := filterUserInfo(info, []string{"email", "preferred_username"})
require.Equal(t, map[string]any{
"email": "[email protected]",
"preferred_username": "user",
}, filtered)
})

t.Run("missing keys are skipped", func(t *testing.T) {
info := map[string]any{"email": "[email protected]"}
filtered := filterUserInfo(info, []string{"email", "name"})
require.Equal(t, map[string]any{"email": "[email protected]"}, filtered)
})

t.Run("empty keys returns empty map", func(t *testing.T) {
info := map[string]any{"email": "[email protected]"}
filtered := filterUserInfo(info, []string{})
require.Empty(t, filtered)
})

t.Run("nil input returns empty map", func(t *testing.T) {
filtered := filterUserInfo(nil, []string{"email"})
require.Empty(t, filtered)
})
}

func TestUserInfoFilteringInOAuthFlow(t *testing.T) {
t.Run("session stores only filtered userinfo fields", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

fullUserInfo := map[string]any{
"email": "[email protected]",
"preferred_username": "user",
"groups": []any{"admin", "developers", "platform-team"},
"realm_access": map[string]any{"roles": []any{"offline_access", "uma_authorization"}},
"resource_access": map[string]any{"account": map[string]any{"roles": []any{"view-profile"}}},
}

mockToken := &oauth2.Token{AccessToken: "test-token"}
mockProvider := NewMockProvider(ctrl)
mockProvider.EXPECT().Name().Return("test").AnyTimes()
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "[email protected]", fullUserInfo, nil)

authRouter, err := NewAuthRouter(nil, false, []string{"email", "preferred_username"}, mockProvider)
require.NoError(t, err)

// Add a route that reads back the session to verify stored userinfo
var storedUserInfo string
router := gin.New()
store := memstore.NewStore([]byte("test-secret"))
router.Use(sessions.Sessions("session", store))
authRouter.SetupRoutes(router)
router.GET("/check-session", func(c *gin.Context) {
session := sessions.Default(c)
if v, ok := session.Get(SessionKeyUserInfo).(string); ok {
storedUserInfo = v
}
c.String(http.StatusOK, "ok")
})

server := httptest.NewServer(router)
defer server.Close()
client := setupClient()

// Start auth flow to set oauth state
resp, err := client.Get(server.URL + "/.auth/test")
require.NoError(t, err)
resp.Body.Close()

// Complete callback
resp, err = client.Get(server.URL + "/.auth/test/callback")
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusFound, resp.StatusCode)

// Read back session
resp, err = client.Get(server.URL + "/check-session")
require.NoError(t, err)
resp.Body.Close()

var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(storedUserInfo), &parsed))
require.Equal(t, "[email protected]", parsed["email"])
require.Equal(t, "user", parsed["preferred_username"])
require.NotContains(t, parsed, "groups")
require.NotContains(t, parsed, "realm_access")
require.NotContains(t, parsed, "resource_access")
})

t.Run("nil filter stores full userinfo", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

fullUserInfo := map[string]any{
"email": "[email protected]",
"groups": []any{"admin"},
}

mockToken := &oauth2.Token{AccessToken: "test-token"}
mockProvider := NewMockProvider(ctrl)
mockProvider.EXPECT().Name().Return("test").AnyTimes()
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "[email protected]", fullUserInfo, nil)

authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
require.NoError(t, err)

var storedUserInfo string
router := gin.New()
store := memstore.NewStore([]byte("test-secret"))
router.Use(sessions.Sessions("session", store))
authRouter.SetupRoutes(router)
router.GET("/check-session", func(c *gin.Context) {
session := sessions.Default(c)
if v, ok := session.Get(SessionKeyUserInfo).(string); ok {
storedUserInfo = v
}
c.String(http.StatusOK, "ok")
})

server := httptest.NewServer(router)
defer server.Close()
client := setupClient()

resp, err := client.Get(server.URL + "/.auth/test")
require.NoError(t, err)
resp.Body.Close()

resp, err = client.Get(server.URL + "/.auth/test/callback")
require.NoError(t, err)
resp.Body.Close()

resp, err = client.Get(server.URL + "/check-session")
require.NoError(t, err)
resp.Body.Close()

var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(storedUserInfo), &parsed))
require.Contains(t, parsed, "email")
require.Contains(t, parsed, "groups")
})
}

func TestAuthenticationFlow(t *testing.T) {
t.Run("Unauthenticated access should redirect to login", func(t *testing.T) {
ctrl := gomock.NewController(t)
Expand All @@ -54,7 +213,7 @@ func TestAuthenticationFlow(t *testing.T) {
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()

// Create AuthRouter (auto-select enabled by default)
authRouter, err := NewAuthRouter(nil, false, mockProvider)
authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
require.NoError(t, err)

router := setupTestRouter(authRouter)
Expand Down Expand Up @@ -88,7 +247,7 @@ func TestAuthenticationFlow(t *testing.T) {
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", map[string]any{"email": "[email protected]"}, nil)

// Create AuthRouter
authRouter, err := NewAuthRouter(nil, false, mockProvider)
authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
require.NoError(t, err)

router := setupTestRouter(authRouter)
Expand Down Expand Up @@ -149,7 +308,7 @@ func TestAuthenticationFlow(t *testing.T) {
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", map[string]any{"email": "[email protected]"}, nil)

// Create AuthRouter
authRouter, err := NewAuthRouter(nil, false, mockProvider)
authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
require.NoError(t, err)

router := setupTestRouter(authRouter)
Expand Down Expand Up @@ -199,7 +358,7 @@ func TestLoginAutoRedirect(t *testing.T) {
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()

authRouter, err := NewAuthRouter(nil, false, mockProvider)
authRouter, err := NewAuthRouter(nil, false, nil, mockProvider)
require.NoError(t, err)

router := gin.New()
Expand Down Expand Up @@ -230,7 +389,7 @@ func TestLoginAutoRedirect(t *testing.T) {
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()

authRouter, err := NewAuthRouter(nil, true, mockProvider)
authRouter, err := NewAuthRouter(nil, true, nil, mockProvider)
require.NoError(t, err)

router := gin.New()
Expand Down Expand Up @@ -260,7 +419,7 @@ func TestLoginAutoRedirect(t *testing.T) {
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()

// Non-empty passwordHash slice disables auto-select
authRouter, err := NewAuthRouter([]string{"dummy"}, false, mockProvider)
authRouter, err := NewAuthRouter([]string{"dummy"}, false, nil, mockProvider)
require.NoError(t, err)

router := gin.New()
Expand Down
2 changes: 1 addition & 1 deletion pkg/idp/idp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str
})

// Create auth router and IDP router
authRouter, err := auth.NewAuthRouter([]string{}, false)
authRouter, err := auth.NewAuthRouter([]string{}, false, nil)
require.NoError(t, err)

logger, _ := zap.NewDevelopment()
Expand Down
31 changes: 30 additions & 1 deletion pkg/mcp-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,11 @@ func Run(
passwordHashes = append(passwordHashes, passwordHash)
}

authRouter, err := auth.NewAuthRouter(passwordHashes, noProviderAutoSelect, providers...)
// Collect the top-level userinfo keys that are actually needed so the
// session cookie doesn't store the entire provider response.
userInfoFields := userInfoFieldsFromConfig(oidcUserIDField, headerMapping)

authRouter, err := auth.NewAuthRouter(passwordHashes, noProviderAutoSelect, userInfoFields, providers...)
if err != nil {
return fmt.Errorf("failed to create auth router: %w", err)
}
Expand Down Expand Up @@ -529,3 +533,28 @@ func Run(
wg.Wait()
return errors.Join(errs...)
}

// userInfoFieldsFromConfig extracts the top-level userinfo keys referenced
// by the OIDC user-ID field and the header mapping. JSON pointers like
// "/email" or "/preferred_username" yield "email" or "preferred_username".
func userInfoFieldsFromConfig(oidcUserIDField string, headerMapping map[string]string) []string {
seen := map[string]struct{}{}
add := func(pointer string) {
pointer = strings.TrimPrefix(pointer, "/")
if i := strings.IndexByte(pointer, '/'); i != -1 {
pointer = pointer[:i]
}
if pointer != "" {
seen[pointer] = struct{}{}
}
}
add(oidcUserIDField)
for pointer := range headerMapping {
add(pointer)
}
fields := make([]string, 0, len(seen))
for k := range seen {
fields = append(fields, k)
}
return fields
}
34 changes: 34 additions & 0 deletions pkg/mcp-proxy/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,40 @@ func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) {
require.True(t, streamingOnlyReceived, "httpStreamingOnly should be forwarded to proxy router")
}

func TestUserInfoFieldsFromConfig(t *testing.T) {
t.Run("extracts fields from header mapping and user ID field", func(t *testing.T) {
fields := userInfoFieldsFromConfig("/email", map[string]string{
"/email": "X-Forwarded-Email",
"/preferred_username": "X-Forwarded-User",
})
require.ElementsMatch(t, []string{"email", "preferred_username"}, fields)
})

t.Run("handles nested JSON pointers by taking top-level key", func(t *testing.T) {
fields := userInfoFieldsFromConfig("/email", map[string]string{
"/address/street": "X-Street",
})
require.ElementsMatch(t, []string{"email", "address"}, fields)
})

t.Run("deduplicates overlapping fields", func(t *testing.T) {
fields := userInfoFieldsFromConfig("/email", map[string]string{
"/email": "X-Forwarded-Email",
})
require.Equal(t, []string{"email"}, fields)
})

t.Run("empty config returns empty slice", func(t *testing.T) {
fields := userInfoFieldsFromConfig("", nil)
require.Empty(t, fields)
})

t.Run("handles user ID field without leading slash", func(t *testing.T) {
fields := userInfoFieldsFromConfig("email", nil)
require.Equal(t, []string{"email"}, fields)
})
}

func TestHealthzEndpoint(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
Expand Down
Loading