From b4013df62a8a2991f7e960a6a509e2d65b3bb54c Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Fri, 3 Apr 2026 21:37:36 +0900 Subject: [PATCH] feat: forward authenticated user identity to upstream via headers Add --header-mapping flag to inject OIDC/Google/GitHub userinfo attributes into upstream request headers. Userinfo is embedded as JWT custom claims and extracted in the proxy using JSON pointers. Also fixes JWT subject claim from hardcoded "user" to the actual authenticated user's identity. Closes #130 --- main.go | 30 ++++++++++++ main_test.go | 60 +++++++++++++++++++++++ pkg/auth/auth.go | 12 ++++- pkg/auth/auth_test.go | 4 +- pkg/auth/github.go | 40 ++++++++-------- pkg/auth/github_test.go | 2 +- pkg/auth/google.go | 32 ++++++------- pkg/auth/google_test.go | 6 +-- pkg/auth/interface.go | 2 +- pkg/auth/mock.go | 7 +-- pkg/auth/oidc.go | 26 +++++----- pkg/auth/oidc_test.go | 12 ++--- pkg/idp/idp.go | 27 +++++++++-- pkg/mcp-proxy/main.go | 3 +- pkg/mcp-proxy/main_test.go | 3 +- pkg/proxy/proxy.go | 31 ++++++++++++ pkg/proxy/proxy_test.go | 97 +++++++++++++++++++++++++++++++++++++- 17 files changed, 321 insertions(+), 73 deletions(-) diff --git a/main.go b/main.go index 4b1dd4d..1f8f0a9 100644 --- a/main.go +++ b/main.go @@ -91,6 +91,30 @@ func parseAttributeMap(s string) map[string][]string { return result } +func parseHeaderMapping(s string) map[string]string { + result := make(map[string]string) + if s == "" { + return result + } + parts := splitWithEscapes(s, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + colonIdx := strings.LastIndex(part, ":") + if colonIdx == -1 { + continue + } + pointer := strings.TrimSpace(part[:colonIdx]) + header := strings.TrimSpace(part[colonIdx+1:]) + if pointer != "" && header != "" { + result[pointer] = header + } + } + return result +} + type proxyRunnerFunc func( listen string, tlsListen string, @@ -130,6 +154,7 @@ type proxyRunnerFunc func( proxyBearerToken string, proxyTarget []string, httpStreamingOnly bool, + headerMapping map[string]string, ) error func main() { @@ -174,6 +199,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command { var passwordHash string var proxyBearerToken string var proxyHeaders string + var headerMapping string var httpStreamingOnly bool var trustedProxies string @@ -255,6 +281,8 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command { } } + headerMappingMap := parseHeaderMapping(headerMapping) + if err := run( listen, tlsListen, @@ -294,6 +322,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command { proxyBearerToken, args, httpStreamingOnly, + headerMappingMap, ); err != nil { panic(err) } @@ -347,6 +376,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command { rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", getEnvWithDefault("TRUSTED_PROXIES", ""), "Comma-separated list of trusted proxies (IP addresses or CIDR ranges)") rootCmd.Flags().StringVar(&proxyHeaders, "proxy-headers", getEnvWithDefault("PROXY_HEADERS", ""), "Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2)") rootCmd.Flags().BoolVar(&httpStreamingOnly, "http-streaming-only", getEnvBoolWithDefault("HTTP_STREAMING_ONLY", false), "Reject SSE (GET) requests and keep the backend in HTTP streaming-only mode") + rootCmd.Flags().StringVar(&headerMapping, "header-mapping", getEnvWithDefault("HEADER_MAPPING", ""), "Comma-separated mapping of userinfo JSON pointer paths to header names (e.g., /email:X-Forwarded-Email,/preferred_username:X-Forwarded-User)") return rootCmd } diff --git a/main_test.go b/main_test.go index e3caff1..a78f3d3 100644 --- a/main_test.go +++ b/main_test.go @@ -168,6 +168,64 @@ func TestParseAttributeMap(t *testing.T) { } } +func TestParseHeaderMapping(t *testing.T) { + testCases := []struct { + name string + input string + expected map[string]string + }{ + { + name: "empty string", + input: "", + expected: map[string]string{}, + }, + { + name: "single mapping", + input: "/email:X-Forwarded-Email", + expected: map[string]string{ + "/email": "X-Forwarded-Email", + }, + }, + { + name: "multiple mappings", + input: "/email:X-Forwarded-Email,/preferred_username:X-Forwarded-User", + expected: map[string]string{ + "/email": "X-Forwarded-Email", + "/preferred_username": "X-Forwarded-User", + }, + }, + { + name: "nested JSON pointer", + input: "/org/team:X-Forwarded-Team", + expected: map[string]string{ + "/org/team": "X-Forwarded-Team", + }, + }, + { + name: "whitespace trimming", + input: " /email : X-Forwarded-Email , /sub : X-Forwarded-Sub ", + expected: map[string]string{ + "/email": "X-Forwarded-Email", + "/sub": "X-Forwarded-Sub", + }, + }, + { + name: "no colon - skipped", + input: "invalid", + expected: map[string]string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := parseHeaderMapping(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("Expected %v, got %v", tc.expected, result) + } + }) + } +} + func TestGetEnvWithDefault(t *testing.T) { testCases := []struct { name string @@ -344,6 +402,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFlag(t *testing.T) { proxyBearerToken string, proxyTarget []string, httpStreamingOnly bool, + headerMapping map[string]string, ) error { streamingOnly = httpStreamingOnly receivedTargets = proxyTarget @@ -407,6 +466,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFromEnv(t *testing.T) { proxyBearerToken string, proxyTarget []string, httpStreamingOnly bool, + headerMapping map[string]string, ) error { streamingOnly = httpStreamingOnly return nil diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 339ff7e..04fda55 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "embed" + "encoding/json" "errors" "html/template" "net/http" @@ -68,6 +69,8 @@ const ( SessionKeyAuthorized = "authorized" SessionKeyRedirectURL = "redirect_url" SessionKeyOAuthState = "oauth_state" + SessionKeyUserID = "user_id" + SessionKeyUserInfo = "user_info" ) func (a *AuthRouter) SetupRoutes(router gin.IRouter) { @@ -87,7 +90,7 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { a.renderError(c, err) return } - ok, user, err := provider.Authorization(c, token) + ok, user, userInfo, err := provider.Authorization(c, token) if err != nil { a.renderError(c, err) return @@ -97,6 +100,12 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { return } session.Set(SessionKeyAuthorized, true) + session.Set(SessionKeyUserID, user) + if userInfo != nil { + if userInfoJSON, err := json.Marshal(userInfo); err == nil { + session.Set(SessionKeyUserInfo, string(userInfoJSON)) + } + } redirectURL := session.Get(SessionKeyRedirectURL) if redirectURL != nil { session.Delete(SessionKeyRedirectURL) @@ -177,6 +186,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) { session := sessions.Default(c) session.Set(SessionKeyAuthorized, true) + session.Set(SessionKeyUserID, PasswordUserID) redirectURL := session.Get(SessionKeyRedirectURL) if redirectURL != nil { session.Delete(SessionKeyRedirectURL) diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index e1dc8e2..14dbb9d 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -85,7 +85,7 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil) mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) - mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", nil) + mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", map[string]any{"email": "authorized_user@example.com"}, nil) // Create AuthRouter authRouter, err := NewAuthRouter(nil, false, mockProvider) @@ -146,7 +146,7 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil) mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) - mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", nil) + mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", map[string]any{"email": "unauthorized_user@example.com"}, nil) // Create AuthRouter authRouter, err := NewAuthRouter(nil, false, mockProvider) diff --git a/pkg/auth/github.go b/pkg/auth/github.go index d931917..e1afcfb 100644 --- a/pkg/auth/github.go +++ b/pkg/auth/github.go @@ -85,30 +85,30 @@ func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, return token, nil } -func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { +func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) { client := p.oauth2.Client(ctx, token) resp1, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user"))) if err != nil { - return false, "", err + return false, "", nil, err } if resp1.StatusCode < 200 || resp1.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from GitHub API: " + resp1.Status) + return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp1.Status) } defer resp1.Body.Close() - var userInfo struct { - Login string `json:"login"` - } - if err := json.NewDecoder(resp1.Body).Decode(&userInfo); err != nil { - return false, "", err + var userInfoMap map[string]any + if err := json.NewDecoder(resp1.Body).Decode(&userInfoMap); err != nil { + return false, "", nil, err } + login, _ := userInfoMap["login"].(string) + if len(p.allowedUsers) == 0 && len(p.allowedOrgs) == 0 { - return true, userInfo.Login, nil + return true, login, userInfoMap, nil } - if slices.Contains(p.allowedUsers, userInfo.Login) { - return true, userInfo.Login, nil + if slices.Contains(p.allowedUsers, login) { + return true, login, userInfoMap, nil } allowedOrgTeams := []string{} @@ -124,31 +124,31 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) if len(allowedOrgs) > 0 { resp2, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/orgs"))) if err != nil { - return false, "", err + return false, "", nil, err } if resp2.StatusCode < 200 || resp2.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from GitHub API: " + resp2.Status) + return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp2.Status) } defer resp2.Body.Close() var orgInfo []struct { Login string `json:"login"` } if err := json.NewDecoder(resp2.Body).Decode(&orgInfo); err != nil { - return false, "", err + return false, "", nil, err } for _, o := range orgInfo { if slices.Contains(allowedOrgs, o.Login) { - return true, userInfo.Login, nil + return true, login, userInfoMap, nil } } } if len(allowedOrgTeams) > 0 { resp3, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/teams"))) if err != nil { - return false, "", err + return false, "", nil, err } if resp3.StatusCode < 200 || resp3.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from GitHub API: " + resp3.Status) + return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp3.Status) } defer resp3.Body.Close() var teamInfo []struct { @@ -158,14 +158,14 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) Slug string `json:"slug"` } if err := json.NewDecoder(resp3.Body).Decode(&teamInfo); err != nil { - return false, "", err + return false, "", nil, err } for _, team := range teamInfo { if slices.Contains(allowedOrgTeams, team.Organization.Login+":"+team.Slug) { - return true, userInfo.Login, nil + return true, login, userInfoMap, nil } } } - return false, userInfo.Login, nil + return false, login, userInfoMap, nil } diff --git a/pkg/auth/github_test.go b/pkg/auth/github_test.go index f3e0eeb..93b1b08 100644 --- a/pkg/auth/github_test.go +++ b/pkg/auth/github_test.go @@ -168,7 +168,7 @@ func TestGitHubProviderAuthorization(t *testing.T) { }) // Call the Authorization method - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.NoError(t, err) require.Equal(t, expect, ok) }) diff --git a/pkg/auth/google.go b/pkg/auth/google.go index ccc46cf..69d7763 100644 --- a/pkg/auth/google.go +++ b/pkg/auth/google.go @@ -83,38 +83,36 @@ func (p *googleProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, return token, nil } -func (p *googleProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { +func (p *googleProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) { client := p.oauth2.Client(ctx, token) resp, err := client.Get(p.userinfoEndpoint) if err != nil { - return false, "", err + return false, "", nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from Google API: " + resp.Status) + return false, "", nil, errors.New("failed to get user info from Google API: " + resp.Status) } defer resp.Body.Close() - var userInfo struct { - Sub string `json:"sub"` - Name string `json:"name"` - Email string `json:"email"` - HD string `json:"hd"` - } - if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { - return false, "", err + var userInfoMap map[string]any + if err := json.NewDecoder(resp.Body).Decode(&userInfoMap); err != nil { + return false, "", nil, err } + email, _ := userInfoMap["email"].(string) + hd, _ := userInfoMap["hd"].(string) + if len(p.allowedUsers) == 0 && len(p.allowedWorkspaces) == 0 { - return true, userInfo.Email, nil + return true, email, userInfoMap, nil } - if slices.Contains(p.allowedUsers, userInfo.Email) { - return true, userInfo.Email, nil + if slices.Contains(p.allowedUsers, email) { + return true, email, userInfoMap, nil } - if slices.Contains(p.allowedWorkspaces, userInfo.HD) { - return true, userInfo.Email, nil + if slices.Contains(p.allowedWorkspaces, hd) { + return true, email, userInfoMap, nil } - return false, userInfo.Email, nil + return false, email, userInfoMap, nil } diff --git a/pkg/auth/google_test.go b/pkg/auth/google_test.go index ae792ae..c4aaad6 100644 --- a/pkg/auth/google_test.go +++ b/pkg/auth/google_test.go @@ -156,7 +156,7 @@ func TestGoogleProviderAuthorization(t *testing.T) { c.Data(http.StatusOK, "application/json", []byte(tt.userResp)) }) - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.NoError(t, err) require.Equal(t, tt.expect, ok) }) @@ -170,7 +170,7 @@ func TestGoogleProviderAuthorizationAPIError(t *testing.T) { c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"}) }) - ok, user, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, user, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) require.Empty(t, user) @@ -184,7 +184,7 @@ func TestGoogleProviderAuthorizationInvalidJSON(t *testing.T) { c.Data(http.StatusOK, "application/json", []byte(`invalid json`)) }) - ok, user, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, user, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) require.Empty(t, user) diff --git a/pkg/auth/interface.go b/pkg/auth/interface.go index 5aebc70..f3f6f43 100644 --- a/pkg/auth/interface.go +++ b/pkg/auth/interface.go @@ -15,5 +15,5 @@ type Provider interface { AuthURL() string AuthCodeURL(state string) (string, error) Exchange(c *gin.Context, state string) (*oauth2.Token, error) - Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) + Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) } diff --git a/pkg/auth/mock.go b/pkg/auth/mock.go index 0a031eb..4814a9e 100644 --- a/pkg/auth/mock.go +++ b/pkg/auth/mock.go @@ -72,13 +72,14 @@ func (mr *MockProviderMockRecorder) AuthURL() *gomock.Call { } // Authorization mocks base method. -func (m *MockProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { +func (m *MockProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Authorization", ctx, token) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(string) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret2, _ := ret[2].(map[string]any) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 } // Authorization indicates an expected call of Authorization. diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go index 9d0aa86..1be1c02 100644 --- a/pkg/auth/oidc.go +++ b/pkg/auth/oidc.go @@ -128,40 +128,44 @@ func (p *oidcProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, er return token, nil } -func (p *oidcProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { +func (p *oidcProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) { client := p.oauth2.Client(ctx, token) resp, err := client.Get(p.userInfoURL) if err != nil { - return false, "", err + return false, "", nil, err } defer resp.Body.Close() var obj any if err := json.NewDecoder(resp.Body).Decode(&obj); err != nil { - return false, "", err + return false, "", nil, err + } + userInfoMap, ok := obj.(map[string]any) + if !ok { + return false, "", nil, errors.New("userinfo response is not a JSON object") } v, err := jsonpointer.Get(obj, p.userIDField) if err != nil { - return false, "", err + return false, "", nil, err } userID, ok := v.(string) if !ok { - return false, "", errors.New("user ID field is not a string") + return false, "", nil, errors.New("user ID field is not a string") } // If no restrictions are set, allow all users if len(p.allowedUsers) == 0 && len(p.allowedUsersGlob) == 0 && len(p.allowedAttributes) == 0 && len(p.allowedAttributesGlob) == 0 { - return true, userID, nil + return true, userID, userInfoMap, nil } // Check exact user matches first if slices.Contains(p.allowedUsers, userID) { - return true, userID, nil + return true, userID, userInfoMap, nil } // Check user glob patterns for _, g := range p.allowedUsersGlob { if g.Match(userID) { - return true, userID, nil + return true, userID, userInfoMap, nil } } @@ -172,7 +176,7 @@ func (p *oidcProvider) Authorization(ctx context.Context, token *oauth2.Token) ( continue // Attribute not found, skip } if matchAttributeValue(attrValue, allowedValues) { - return true, userID, nil + return true, userID, userInfoMap, nil } } @@ -183,11 +187,11 @@ func (p *oidcProvider) Authorization(ctx context.Context, token *oauth2.Token) ( continue // Attribute not found, skip } if matchAttributeGlob(attrValue, globs) { - return true, userID, nil + return true, userID, userInfoMap, nil } } - return false, userID, nil + return false, userID, userInfoMap, nil } // matchAttributeValue checks if an attribute value matches any of the allowed values. diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go index 604064c..82da4e6 100644 --- a/pkg/auth/oidc_test.go +++ b/pkg/auth/oidc_test.go @@ -167,7 +167,7 @@ func TestOIDCProviderAuthorization(t *testing.T) { }) // Call the Authorization method - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.NoError(t, err) require.Equal(t, tt.expect, ok) }) @@ -224,7 +224,7 @@ func TestOIDCProviderErrors(t *testing.T) { c.JSON(http.StatusOK, gin.H{"sub": "user1"}) }) - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) }) @@ -237,7 +237,7 @@ func TestOIDCProviderErrors(t *testing.T) { c.JSON(http.StatusOK, gin.H{"sub": 12345}) }) - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) }) @@ -250,7 +250,7 @@ func TestOIDCProviderErrors(t *testing.T) { c.JSON(http.StatusInternalServerError, gin.H{"error": "server error"}) }) - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) }) @@ -316,7 +316,7 @@ func TestOIDCProviderGlobPatterns(t *testing.T) { provider.userInfoURL = tsUser.URL + "/userinfo" // Test authorization - authorized, userID, err := provider.Authorization(context.Background(), &oauth2.Token{AccessToken: "test"}) + authorized, userID, _, err := provider.Authorization(context.Background(), &oauth2.Token{AccessToken: "test"}) require.NoError(t, err) require.Equal(t, tc.email, userID) require.Equal(t, tc.expected, authorized, "Expected %v for email %s", tc.expected, tc.email) @@ -449,7 +449,7 @@ func TestOIDCProviderAttributeMatching(t *testing.T) { provider.userInfoURL = tsUser.URL + "/userinfo" // Test authorization - authorized, _, err := provider.Authorization(context.Background(), &oauth2.Token{AccessToken: "test"}) + authorized, _, _, err := provider.Authorization(context.Background(), &oauth2.Token{AccessToken: "test"}) require.NoError(t, err) require.Equal(t, tc.expected, authorized, "Expected %v for test case %s", tc.expected, tc.name) }) diff --git a/pkg/idp/idp.go b/pkg/idp/idp.go index 69048ad..f5bcfdf 100644 --- a/pkg/idp/idp.go +++ b/pkg/idp/idp.go @@ -4,11 +4,13 @@ import ( "context" "crypto/rsa" "encoding/base64" + "encoding/json" "math/big" "net/url" "strings" "time" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/ory/fosite" "github.com/ory/fosite/compose" @@ -142,7 +144,18 @@ func (a *IDPRouter) handleAuthorizationReturn(c *gin.Context) { for _, scope := range ar.GetRequestedScopes() { ar.GrantScope(scope) } - jwtSession, err := NewJWTSessionWithKey(a.externalURL, "user", a.privKey) + + session := sessions.Default(c) + subject := "user" + if userID, ok := session.Get(auth.SessionKeyUserID).(string); ok && userID != "" { + subject = userID + } + var userInfo map[string]any + if userInfoJSON, ok := session.Get(auth.SessionKeyUserInfo).(string); ok && userInfoJSON != "" { + json.Unmarshal([]byte(userInfoJSON), &userInfo) + } + + jwtSession, err := NewJWTSessionWithKey(a.externalURL, subject, a.privKey, userInfo) if err != nil { a.logger.With(utils.Err(err)...).Error("Failed to create JWT session", zap.Error(err)) a.provider.WriteAuthorizeError(ctx, c.Writer, ar, err) @@ -162,7 +175,7 @@ func (a *IDPRouter) handleAuthorizationReturn(c *gin.Context) { func (a *IDPRouter) handleToken(c *gin.Context) { ctx := c.Request.Context() - session, err := NewJWTSessionWithKey("", "", a.privKey) + session, err := NewJWTSessionWithKey("", "", a.privKey, nil) if err != nil { a.logger.With(utils.Err(err)...).Error("Failed to create JWT session for token", zap.Error(err)) a.provider.WriteAccessError(ctx, c.Writer, nil, fosite.ErrServerError.WithWrap(err)) @@ -188,7 +201,7 @@ func (a *IDPRouter) handleToken(c *gin.Context) { func (a *IDPRouter) handleIntrospect(c *gin.Context) { ctx := c.Request.Context() - session, err := NewJWTSessionWithKey("", "", a.privKey) + session, err := NewJWTSessionWithKey("", "", a.privKey, nil) if err != nil { a.provider.WriteIntrospectionError(ctx, c.Writer, fosite.ErrServerError.WithWrap(err)) return @@ -392,11 +405,15 @@ func (a *IDPRouter) handleJWKS(c *gin.Context) { c.JSON(200, ks) } -func NewJWTSessionWithKey(iss string, subject string, privateKey *rsa.PrivateKey) (*Session, error) { +func NewJWTSessionWithKey(iss string, subject string, privateKey *rsa.PrivateKey, userInfo map[string]any) (*Session, error) { keyID, err := utils.GenerateKeyID(&privateKey.PublicKey) if err != nil { return nil, err } + var extra map[string]any + if userInfo != nil { + extra = map[string]any{"userinfo": userInfo} + } return &Session{ DefaultSession: &fosite.DefaultSession{ Username: subject, @@ -409,6 +426,7 @@ func NewJWTSessionWithKey(iss string, subject string, privateKey *rsa.PrivateKey ExpiresAt: time.Now().Add(time.Hour), IssuedAt: time.Now(), NotBefore: time.Now(), + Extra: extra, }, JWTHeader: &jwt.Headers{ Extra: map[string]any{ @@ -450,6 +468,7 @@ func (s *Session) Clone() fosite.Session { ExpiresAt: s.JWTClaims.ExpiresAt, IssuedAt: s.JWTClaims.IssuedAt, NotBefore: s.JWTClaims.NotBefore, + Extra: s.JWTClaims.Extra, }, JWTHeader: &jwt.Headers{ Extra: make(map[string]any), diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index 0b59c73..e867f7d 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -76,6 +76,7 @@ func Run( proxyBearerToken string, proxyTarget []string, httpStreamingOnly bool, + headerMapping map[string]string, ) error { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) defer stop() @@ -290,7 +291,7 @@ func Run( if err != nil { return fmt.Errorf("failed to create IDP router: %w", err) } - proxyRouter, err := newProxyRouter(externalURL, beHandler, &privKey.PublicKey, proxyHeadersMap, httpStreamingOnly) + proxyRouter, err := newProxyRouter(externalURL, beHandler, &privKey.PublicKey, proxyHeadersMap, httpStreamingOnly, headerMapping) if err != nil { return fmt.Errorf("failed to create proxy router: %w", err) } diff --git a/pkg/mcp-proxy/main_test.go b/pkg/mcp-proxy/main_test.go index db220d9..30b6c38 100644 --- a/pkg/mcp-proxy/main_test.go +++ b/pkg/mcp-proxy/main_test.go @@ -17,7 +17,7 @@ func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) { }) var streamingOnlyReceived bool - newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool) (*proxy.ProxyRouter, error) { + newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool, headerMapping map[string]string) (*proxy.ProxyRouter, error) { streamingOnlyReceived = httpStreamingOnly return nil, errors.New("proxy router init failed") } @@ -61,6 +61,7 @@ func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) { "", []string{"http://example.com"}, true, + nil, ) require.Error(t, err) diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 3938365..6b86521 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" + "github.com/mattn/go-jsonpointer" ) type ProxyRouter struct { @@ -16,6 +17,7 @@ type ProxyRouter struct { publicKey *rsa.PublicKey proxyHeaders http.Header httpStreamingOnly bool + headerMapping map[string]string } func NewProxyRouter( @@ -24,6 +26,7 @@ func NewProxyRouter( publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool, + headerMapping map[string]string, ) (*ProxyRouter, error) { return &ProxyRouter{ externalURL: externalURL, @@ -31,6 +34,7 @@ func NewProxyRouter( publicKey: publicKey, proxyHeaders: proxyHeaders, httpStreamingOnly: httpStreamingOnly, + headerMapping: headerMapping, }, nil } @@ -87,6 +91,33 @@ func (p *ProxyRouter) handleProxy(c *gin.Context) { } } + if len(p.headerMapping) > 0 { + if claims, ok := token.Claims.(jwt.MapClaims); ok { + if userinfo, exists := claims["userinfo"]; exists { + for pointer, headerName := range p.headerMapping { + val, err := jsonpointer.Get(userinfo, pointer) + if err != nil { + continue + } + switch v := val.(type) { + case string: + c.Request.Header.Set(headerName, v) + case []any: + var parts []string + for _, item := range v { + if s, ok := item.(string); ok { + parts = append(parts, s) + } + } + c.Request.Header.Set(headerName, strings.Join(parts, ",")) + default: + c.Request.Header.Set(headerName, fmt.Sprintf("%v", v)) + } + } + } + } + } + p.proxy.ServeHTTP(c.Writer, c.Request) } diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index 58252fb..289c174 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -69,7 +69,7 @@ func TestProxyRouter_HandleProxy_ValidToken(t *testing.T) { proxyHeaders := make(http.Header) proxyHeaders.Set("X-Forwarded-By", "mcp-auth-proxy") - proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, proxyHeaders, false) + proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, proxyHeaders, false, nil) require.NoError(t, err) gin.SetMode(gin.TestMode) @@ -103,6 +103,99 @@ func TestProxyRouter_HandleProxy_ValidToken(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, w.Code) } +func TestProxyRouter_HeaderMapping(t *testing.T) { + privateKey, publicKey, err := generateRSAKeyPair() + require.NoError(t, err) + + cases := []struct { + name string + headerMapping map[string]string + userinfo map[string]any + expectedHeaders map[string]string + }{ + { + name: "string field", + headerMapping: map[string]string{"/email": "X-Forwarded-Email"}, + userinfo: map[string]any{"email": "user@example.com"}, + expectedHeaders: map[string]string{ + "X-Forwarded-Email": "user@example.com", + }, + }, + { + name: "array field joined with comma", + headerMapping: map[string]string{"/groups": "X-Forwarded-Groups"}, + userinfo: map[string]any{"groups": []any{"admin", "users"}}, + expectedHeaders: map[string]string{ + "X-Forwarded-Groups": "admin,users", + }, + }, + { + name: "multiple mappings", + headerMapping: map[string]string{"/email": "X-Forwarded-Email", "/preferred_username": "X-Forwarded-User"}, + userinfo: map[string]any{"email": "user@example.com", "preferred_username": "john"}, + expectedHeaders: map[string]string{ + "X-Forwarded-Email": "user@example.com", + "X-Forwarded-User": "john", + }, + }, + { + name: "missing field is skipped", + headerMapping: map[string]string{"/email": "X-Forwarded-Email", "/missing": "X-Missing"}, + userinfo: map[string]any{"email": "user@example.com"}, + expectedHeaders: map[string]string{"X-Forwarded-Email": "user@example.com"}, + }, + { + name: "nil headerMapping", + headerMapping: nil, + userinfo: map[string]any{"email": "user@example.com"}, + expectedHeaders: map[string]string{}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + receivedHeaders := http.Header{} + proxyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for k, v := range r.Header { + receivedHeaders[k] = v + } + w.WriteHeader(http.StatusOK) + }) + + proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, false, tt.headerMapping) + require.NoError(t, err) + + gin.SetMode(gin.TestMode) + router := gin.New() + proxyRouter.SetupRoutes(router) + + claims := jwt.MapClaims{ + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + } + if tt.userinfo != nil { + claims["userinfo"] = tt.userinfo + } + + token, err := createJWT(privateKey, claims) + require.NoError(t, err) + + req, err := http.NewRequest("GET", "/test", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + for header, expected := range tt.expectedHeaders { + assert.Equal(t, expected, receivedHeaders.Get(header), "header %s mismatch", header) + } + }) + } +} + func TestProxyRouter_HTTPStreamingOnlyRejectsSSE(t *testing.T) { privateKey, publicKey, err := generateRSAKeyPair() require.NoError(t, err) @@ -185,7 +278,7 @@ func TestProxyRouter_HTTPStreamingOnlyRejectsSSE(t *testing.T) { w.WriteHeader(http.StatusOK) }) - proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, tt.streamingOnly) + proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, tt.streamingOnly, nil) require.NoError(t, err) gin.SetMode(gin.TestMode)