diff --git a/main.go b/main.go index 4b1dd4d..1f8f0a9 100644 --- a/main.go +++ b/main.go @@ -91,6 +91,30 @@ func parseAttributeMap(s string) map[string][]string { return result } +func parseHeaderMapping(s string) map[string]string { + result := make(map[string]string) + if s == "" { + return result + } + parts := splitWithEscapes(s, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + colonIdx := strings.LastIndex(part, ":") + if colonIdx == -1 { + continue + } + pointer := strings.TrimSpace(part[:colonIdx]) + header := strings.TrimSpace(part[colonIdx+1:]) + if pointer != "" && header != "" { + result[pointer] = header + } + } + return result +} + type proxyRunnerFunc func( listen string, tlsListen string, @@ -130,6 +154,7 @@ type proxyRunnerFunc func( proxyBearerToken string, proxyTarget []string, httpStreamingOnly bool, + headerMapping map[string]string, ) error func main() { @@ -174,6 +199,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command { var passwordHash string var proxyBearerToken string var proxyHeaders string + var headerMapping string var httpStreamingOnly bool var trustedProxies string @@ -255,6 +281,8 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command { } } + headerMappingMap := parseHeaderMapping(headerMapping) + if err := run( listen, tlsListen, @@ -294,6 +322,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command { proxyBearerToken, args, httpStreamingOnly, + headerMappingMap, ); err != nil { panic(err) } @@ -347,6 +376,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command { rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", getEnvWithDefault("TRUSTED_PROXIES", ""), "Comma-separated list of trusted proxies (IP addresses or CIDR ranges)") rootCmd.Flags().StringVar(&proxyHeaders, "proxy-headers", getEnvWithDefault("PROXY_HEADERS", ""), "Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2)") rootCmd.Flags().BoolVar(&httpStreamingOnly, "http-streaming-only", getEnvBoolWithDefault("HTTP_STREAMING_ONLY", false), "Reject SSE (GET) requests and keep the backend in HTTP streaming-only mode") + rootCmd.Flags().StringVar(&headerMapping, "header-mapping", getEnvWithDefault("HEADER_MAPPING", ""), "Comma-separated mapping of userinfo JSON pointer paths to header names (e.g., /email:X-Forwarded-Email,/preferred_username:X-Forwarded-User)") return rootCmd } diff --git a/main_test.go b/main_test.go index e3caff1..a78f3d3 100644 --- a/main_test.go +++ b/main_test.go @@ -168,6 +168,64 @@ func TestParseAttributeMap(t *testing.T) { } } +func TestParseHeaderMapping(t *testing.T) { + testCases := []struct { + name string + input string + expected map[string]string + }{ + { + name: "empty string", + input: "", + expected: map[string]string{}, + }, + { + name: "single mapping", + input: "/email:X-Forwarded-Email", + expected: map[string]string{ + "/email": "X-Forwarded-Email", + }, + }, + { + name: "multiple mappings", + input: "/email:X-Forwarded-Email,/preferred_username:X-Forwarded-User", + expected: map[string]string{ + "/email": "X-Forwarded-Email", + "/preferred_username": "X-Forwarded-User", + }, + }, + { + name: "nested JSON pointer", + input: "/org/team:X-Forwarded-Team", + expected: map[string]string{ + "/org/team": "X-Forwarded-Team", + }, + }, + { + name: "whitespace trimming", + input: " /email : X-Forwarded-Email , /sub : X-Forwarded-Sub ", + expected: map[string]string{ + "/email": "X-Forwarded-Email", + "/sub": "X-Forwarded-Sub", + }, + }, + { + name: "no colon - skipped", + input: "invalid", + expected: map[string]string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := parseHeaderMapping(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("Expected %v, got %v", tc.expected, result) + } + }) + } +} + func TestGetEnvWithDefault(t *testing.T) { testCases := []struct { name string @@ -344,6 +402,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFlag(t *testing.T) { proxyBearerToken string, proxyTarget []string, httpStreamingOnly bool, + headerMapping map[string]string, ) error { streamingOnly = httpStreamingOnly receivedTargets = proxyTarget @@ -407,6 +466,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFromEnv(t *testing.T) { proxyBearerToken string, proxyTarget []string, httpStreamingOnly bool, + headerMapping map[string]string, ) error { streamingOnly = httpStreamingOnly return nil diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 339ff7e..04fda55 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "embed" + "encoding/json" "errors" "html/template" "net/http" @@ -68,6 +69,8 @@ const ( SessionKeyAuthorized = "authorized" SessionKeyRedirectURL = "redirect_url" SessionKeyOAuthState = "oauth_state" + SessionKeyUserID = "user_id" + SessionKeyUserInfo = "user_info" ) func (a *AuthRouter) SetupRoutes(router gin.IRouter) { @@ -87,7 +90,7 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { a.renderError(c, err) return } - ok, user, err := provider.Authorization(c, token) + ok, user, userInfo, err := provider.Authorization(c, token) if err != nil { a.renderError(c, err) return @@ -97,6 +100,12 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { return } session.Set(SessionKeyAuthorized, true) + session.Set(SessionKeyUserID, user) + if userInfo != nil { + if userInfoJSON, err := json.Marshal(userInfo); err == nil { + session.Set(SessionKeyUserInfo, string(userInfoJSON)) + } + } redirectURL := session.Get(SessionKeyRedirectURL) if redirectURL != nil { session.Delete(SessionKeyRedirectURL) @@ -177,6 +186,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) { session := sessions.Default(c) session.Set(SessionKeyAuthorized, true) + session.Set(SessionKeyUserID, PasswordUserID) redirectURL := session.Get(SessionKeyRedirectURL) if redirectURL != nil { session.Delete(SessionKeyRedirectURL) diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index e1dc8e2..14dbb9d 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -85,7 +85,7 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil) mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) - mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", nil) + mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", map[string]any{"email": "authorized_user@example.com"}, nil) // Create AuthRouter authRouter, err := NewAuthRouter(nil, false, mockProvider) @@ -146,7 +146,7 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil) mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) - mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", nil) + mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", map[string]any{"email": "unauthorized_user@example.com"}, nil) // Create AuthRouter authRouter, err := NewAuthRouter(nil, false, mockProvider) diff --git a/pkg/auth/github.go b/pkg/auth/github.go index d931917..e1afcfb 100644 --- a/pkg/auth/github.go +++ b/pkg/auth/github.go @@ -85,30 +85,30 @@ func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, return token, nil } -func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { +func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) { client := p.oauth2.Client(ctx, token) resp1, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user"))) if err != nil { - return false, "", err + return false, "", nil, err } if resp1.StatusCode < 200 || resp1.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from GitHub API: " + resp1.Status) + return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp1.Status) } defer resp1.Body.Close() - var userInfo struct { - Login string `json:"login"` - } - if err := json.NewDecoder(resp1.Body).Decode(&userInfo); err != nil { - return false, "", err + var userInfoMap map[string]any + if err := json.NewDecoder(resp1.Body).Decode(&userInfoMap); err != nil { + return false, "", nil, err } + login, _ := userInfoMap["login"].(string) + if len(p.allowedUsers) == 0 && len(p.allowedOrgs) == 0 { - return true, userInfo.Login, nil + return true, login, userInfoMap, nil } - if slices.Contains(p.allowedUsers, userInfo.Login) { - return true, userInfo.Login, nil + if slices.Contains(p.allowedUsers, login) { + return true, login, userInfoMap, nil } allowedOrgTeams := []string{} @@ -124,31 +124,31 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) if len(allowedOrgs) > 0 { resp2, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/orgs"))) if err != nil { - return false, "", err + return false, "", nil, err } if resp2.StatusCode < 200 || resp2.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from GitHub API: " + resp2.Status) + return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp2.Status) } defer resp2.Body.Close() var orgInfo []struct { Login string `json:"login"` } if err := json.NewDecoder(resp2.Body).Decode(&orgInfo); err != nil { - return false, "", err + return false, "", nil, err } for _, o := range orgInfo { if slices.Contains(allowedOrgs, o.Login) { - return true, userInfo.Login, nil + return true, login, userInfoMap, nil } } } if len(allowedOrgTeams) > 0 { resp3, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/teams"))) if err != nil { - return false, "", err + return false, "", nil, err } if resp3.StatusCode < 200 || resp3.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from GitHub API: " + resp3.Status) + return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp3.Status) } defer resp3.Body.Close() var teamInfo []struct { @@ -158,14 +158,14 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) Slug string `json:"slug"` } if err := json.NewDecoder(resp3.Body).Decode(&teamInfo); err != nil { - return false, "", err + return false, "", nil, err } for _, team := range teamInfo { if slices.Contains(allowedOrgTeams, team.Organization.Login+":"+team.Slug) { - return true, userInfo.Login, nil + return true, login, userInfoMap, nil } } } - return false, userInfo.Login, nil + return false, login, userInfoMap, nil } diff --git a/pkg/auth/github_test.go b/pkg/auth/github_test.go index f3e0eeb..93b1b08 100644 --- a/pkg/auth/github_test.go +++ b/pkg/auth/github_test.go @@ -168,7 +168,7 @@ func TestGitHubProviderAuthorization(t *testing.T) { }) // Call the Authorization method - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.NoError(t, err) require.Equal(t, expect, ok) }) diff --git a/pkg/auth/google.go b/pkg/auth/google.go index ccc46cf..69d7763 100644 --- a/pkg/auth/google.go +++ b/pkg/auth/google.go @@ -83,38 +83,36 @@ func (p *googleProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, return token, nil } -func (p *googleProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { +func (p *googleProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) { client := p.oauth2.Client(ctx, token) resp, err := client.Get(p.userinfoEndpoint) if err != nil { - return false, "", err + return false, "", nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from Google API: " + resp.Status) + return false, "", nil, errors.New("failed to get user info from Google API: " + resp.Status) } defer resp.Body.Close() - var userInfo struct { - Sub string `json:"sub"` - Name string `json:"name"` - Email string `json:"email"` - HD string `json:"hd"` - } - if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { - return false, "", err + var userInfoMap map[string]any + if err := json.NewDecoder(resp.Body).Decode(&userInfoMap); err != nil { + return false, "", nil, err } + email, _ := userInfoMap["email"].(string) + hd, _ := userInfoMap["hd"].(string) + if len(p.allowedUsers) == 0 && len(p.allowedWorkspaces) == 0 { - return true, userInfo.Email, nil + return true, email, userInfoMap, nil } - if slices.Contains(p.allowedUsers, userInfo.Email) { - return true, userInfo.Email, nil + if slices.Contains(p.allowedUsers, email) { + return true, email, userInfoMap, nil } - if slices.Contains(p.allowedWorkspaces, userInfo.HD) { - return true, userInfo.Email, nil + if slices.Contains(p.allowedWorkspaces, hd) { + return true, email, userInfoMap, nil } - return false, userInfo.Email, nil + return false, email, userInfoMap, nil } diff --git a/pkg/auth/google_test.go b/pkg/auth/google_test.go index ae792ae..c4aaad6 100644 --- a/pkg/auth/google_test.go +++ b/pkg/auth/google_test.go @@ -156,7 +156,7 @@ func TestGoogleProviderAuthorization(t *testing.T) { c.Data(http.StatusOK, "application/json", []byte(tt.userResp)) }) - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.NoError(t, err) require.Equal(t, tt.expect, ok) }) @@ -170,7 +170,7 @@ func TestGoogleProviderAuthorizationAPIError(t *testing.T) { c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"}) }) - ok, user, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, user, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) require.Empty(t, user) @@ -184,7 +184,7 @@ func TestGoogleProviderAuthorizationInvalidJSON(t *testing.T) { c.Data(http.StatusOK, "application/json", []byte(`invalid json`)) }) - ok, user, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, user, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) require.Empty(t, user) diff --git a/pkg/auth/interface.go b/pkg/auth/interface.go index 5aebc70..f3f6f43 100644 --- a/pkg/auth/interface.go +++ b/pkg/auth/interface.go @@ -15,5 +15,5 @@ type Provider interface { AuthURL() string AuthCodeURL(state string) (string, error) Exchange(c *gin.Context, state string) (*oauth2.Token, error) - Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) + Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) } diff --git a/pkg/auth/mock.go b/pkg/auth/mock.go index 0a031eb..4814a9e 100644 --- a/pkg/auth/mock.go +++ b/pkg/auth/mock.go @@ -72,13 +72,14 @@ func (mr *MockProviderMockRecorder) AuthURL() *gomock.Call { } // Authorization mocks base method. -func (m *MockProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { +func (m *MockProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Authorization", ctx, token) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(string) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret2, _ := ret[2].(map[string]any) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 } // Authorization indicates an expected call of Authorization. diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go index 9d0aa86..1be1c02 100644 --- a/pkg/auth/oidc.go +++ b/pkg/auth/oidc.go @@ -128,40 +128,44 @@ func (p *oidcProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, er return token, nil } -func (p *oidcProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { +func (p *oidcProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) { client := p.oauth2.Client(ctx, token) resp, err := client.Get(p.userInfoURL) if err != nil { - return false, "", err + return false, "", nil, err } defer resp.Body.Close() var obj any if err := json.NewDecoder(resp.Body).Decode(&obj); err != nil { - return false, "", err + return false, "", nil, err + } + userInfoMap, ok := obj.(map[string]any) + if !ok { + return false, "", nil, errors.New("userinfo response is not a JSON object") } v, err := jsonpointer.Get(obj, p.userIDField) if err != nil { - return false, "", err + return false, "", nil, err } userID, ok := v.(string) if !ok { - return false, "", errors.New("user ID field is not a string") + return false, "", nil, errors.New("user ID field is not a string") } // If no restrictions are set, allow all users if len(p.allowedUsers) == 0 && len(p.allowedUsersGlob) == 0 && len(p.allowedAttributes) == 0 && len(p.allowedAttributesGlob) == 0 { - return true, userID, nil + return true, userID, userInfoMap, nil } // Check exact user matches first if slices.Contains(p.allowedUsers, userID) { - return true, userID, nil + return true, userID, userInfoMap, nil } // Check user glob patterns for _, g := range p.allowedUsersGlob { if g.Match(userID) { - return true, userID, nil + return true, userID, userInfoMap, nil } } @@ -172,7 +176,7 @@ func (p *oidcProvider) Authorization(ctx context.Context, token *oauth2.Token) ( continue // Attribute not found, skip } if matchAttributeValue(attrValue, allowedValues) { - return true, userID, nil + return true, userID, userInfoMap, nil } } @@ -183,11 +187,11 @@ func (p *oidcProvider) Authorization(ctx context.Context, token *oauth2.Token) ( continue // Attribute not found, skip } if matchAttributeGlob(attrValue, globs) { - return true, userID, nil + return true, userID, userInfoMap, nil } } - return false, userID, nil + return false, userID, userInfoMap, nil } // matchAttributeValue checks if an attribute value matches any of the allowed values. diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go index 604064c..82da4e6 100644 --- a/pkg/auth/oidc_test.go +++ b/pkg/auth/oidc_test.go @@ -167,7 +167,7 @@ func TestOIDCProviderAuthorization(t *testing.T) { }) // Call the Authorization method - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.NoError(t, err) require.Equal(t, tt.expect, ok) }) @@ -224,7 +224,7 @@ func TestOIDCProviderErrors(t *testing.T) { c.JSON(http.StatusOK, gin.H{"sub": "user1"}) }) - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) }) @@ -237,7 +237,7 @@ func TestOIDCProviderErrors(t *testing.T) { c.JSON(http.StatusOK, gin.H{"sub": 12345}) }) - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) }) @@ -250,7 +250,7 @@ func TestOIDCProviderErrors(t *testing.T) { c.JSON(http.StatusInternalServerError, gin.H{"error": "server error"}) }) - ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) require.Error(t, err) require.False(t, ok) }) @@ -316,7 +316,7 @@ func TestOIDCProviderGlobPatterns(t *testing.T) { provider.userInfoURL = tsUser.URL + "/userinfo" // Test authorization - authorized, userID, err := provider.Authorization(context.Background(), &oauth2.Token{AccessToken: "test"}) + authorized, userID, _, err := provider.Authorization(context.Background(), &oauth2.Token{AccessToken: "test"}) require.NoError(t, err) require.Equal(t, tc.email, userID) require.Equal(t, tc.expected, authorized, "Expected %v for email %s", tc.expected, tc.email) @@ -449,7 +449,7 @@ func TestOIDCProviderAttributeMatching(t *testing.T) { provider.userInfoURL = tsUser.URL + "/userinfo" // Test authorization - authorized, _, err := provider.Authorization(context.Background(), &oauth2.Token{AccessToken: "test"}) + authorized, _, _, err := provider.Authorization(context.Background(), &oauth2.Token{AccessToken: "test"}) require.NoError(t, err) require.Equal(t, tc.expected, authorized, "Expected %v for test case %s", tc.expected, tc.name) }) diff --git a/pkg/idp/idp.go b/pkg/idp/idp.go index 9b8264b..de868c6 100644 --- a/pkg/idp/idp.go +++ b/pkg/idp/idp.go @@ -4,11 +4,13 @@ import ( "context" "crypto/rsa" "encoding/base64" + "encoding/json" "math/big" "net/url" "strings" "time" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/ory/fosite" "github.com/ory/fosite/compose" @@ -159,7 +161,18 @@ func (a *IDPRouter) handleAuthorizationReturn(c *gin.Context) { ar.GrantScope(scope) } ar.GrantAudience(a.externalURL) - jwtSession, err := NewJWTSessionWithKey(a.externalURL, "user", a.privKey) + + session := sessions.Default(c) + subject := "user" + if userID, ok := session.Get(auth.SessionKeyUserID).(string); ok && userID != "" { + subject = userID + } + var userInfo map[string]any + if userInfoJSON, ok := session.Get(auth.SessionKeyUserInfo).(string); ok && userInfoJSON != "" { + json.Unmarshal([]byte(userInfoJSON), &userInfo) + } + + jwtSession, err := NewJWTSessionWithKey(a.externalURL, subject, a.privKey, userInfo) if err != nil { a.logger.With(utils.Err(err)...).Error("Failed to create JWT session", zap.Error(err)) a.provider.WriteAuthorizeError(ctx, c.Writer, ar, err) @@ -179,7 +192,7 @@ func (a *IDPRouter) handleAuthorizationReturn(c *gin.Context) { func (a *IDPRouter) handleToken(c *gin.Context) { ctx := c.Request.Context() - session, err := NewJWTSessionWithKey("", "", a.privKey) + session, err := NewJWTSessionWithKey("", "", a.privKey, nil) if err != nil { a.logger.With(utils.Err(err)...).Error("Failed to create JWT session for token", zap.Error(err)) a.provider.WriteAccessError(ctx, c.Writer, nil, fosite.ErrServerError.WithWrap(err)) @@ -205,7 +218,7 @@ func (a *IDPRouter) handleToken(c *gin.Context) { func (a *IDPRouter) handleIntrospect(c *gin.Context) { ctx := c.Request.Context() - session, err := NewJWTSessionWithKey("", "", a.privKey) + session, err := NewJWTSessionWithKey("", "", a.privKey, nil) if err != nil { a.provider.WriteIntrospectionError(ctx, c.Writer, fosite.ErrServerError.WithWrap(err)) return @@ -410,11 +423,15 @@ func (a *IDPRouter) handleJWKS(c *gin.Context) { c.JSON(200, ks) } -func NewJWTSessionWithKey(iss string, subject string, privateKey *rsa.PrivateKey) (*Session, error) { +func NewJWTSessionWithKey(iss string, subject string, privateKey *rsa.PrivateKey, userInfo map[string]any) (*Session, error) { keyID, err := utils.GenerateKeyID(&privateKey.PublicKey) if err != nil { return nil, err } + var extra map[string]any + if userInfo != nil { + extra = map[string]any{"userinfo": userInfo} + } return &Session{ DefaultSession: &fosite.DefaultSession{ Username: subject, @@ -427,6 +444,7 @@ func NewJWTSessionWithKey(iss string, subject string, privateKey *rsa.PrivateKey ExpiresAt: time.Now().Add(time.Hour), IssuedAt: time.Now(), NotBefore: time.Now(), + Extra: extra, }, JWTHeader: &jwt.Headers{ Extra: map[string]any{ @@ -468,6 +486,7 @@ func (s *Session) Clone() fosite.Session { ExpiresAt: s.JWTClaims.ExpiresAt, IssuedAt: s.JWTClaims.IssuedAt, NotBefore: s.JWTClaims.NotBefore, + Extra: s.JWTClaims.Extra, }, JWTHeader: &jwt.Headers{ Extra: make(map[string]any), diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index cc7b8b1..1ce8ad8 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -76,6 +76,7 @@ func Run( proxyBearerToken string, proxyTarget []string, httpStreamingOnly bool, + headerMapping map[string]string, ) error { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) defer stop() @@ -292,7 +293,7 @@ func Run( if err != nil { return fmt.Errorf("failed to create IDP router: %w", err) } - proxyRouter, err := newProxyRouter(externalURL, beHandler, &privKey.PublicKey, proxyHeadersMap, httpStreamingOnly) + proxyRouter, err := newProxyRouter(externalURL, beHandler, &privKey.PublicKey, proxyHeadersMap, httpStreamingOnly, headerMapping) if err != nil { return fmt.Errorf("failed to create proxy router: %w", err) } diff --git a/pkg/mcp-proxy/main_test.go b/pkg/mcp-proxy/main_test.go index 6873e1d..b18a03d 100644 --- a/pkg/mcp-proxy/main_test.go +++ b/pkg/mcp-proxy/main_test.go @@ -34,7 +34,7 @@ func TestRun_NormalizesExternalURLTrailingSlash(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { var receivedURL string - newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool) (*proxy.ProxyRouter, error) { + newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool, headerMapping map[string]string) (*proxy.ProxyRouter, error) { receivedURL = externalURL return nil, errors.New("stop early") } @@ -47,7 +47,7 @@ func TestRun_NormalizesExternalURLTrailingSlash(t *testing.T) { "", "", nil, nil, "", "", "", nil, "", "", nil, nil, nil, nil, false, "", "", nil, nil, "", - []string{"http://example.com"}, false, + []string{"http://example.com"}, false, nil, ) if tt.wantErr { @@ -70,7 +70,7 @@ func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) { }) var streamingOnlyReceived bool - newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool) (*proxy.ProxyRouter, error) { + newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool, headerMapping map[string]string) (*proxy.ProxyRouter, error) { streamingOnlyReceived = httpStreamingOnly return nil, errors.New("proxy router init failed") } @@ -114,6 +114,7 @@ func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) { "", []string{"http://example.com"}, true, + nil, ) require.Error(t, err) diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 3938365..6b86521 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" + "github.com/mattn/go-jsonpointer" ) type ProxyRouter struct { @@ -16,6 +17,7 @@ type ProxyRouter struct { publicKey *rsa.PublicKey proxyHeaders http.Header httpStreamingOnly bool + headerMapping map[string]string } func NewProxyRouter( @@ -24,6 +26,7 @@ func NewProxyRouter( publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool, + headerMapping map[string]string, ) (*ProxyRouter, error) { return &ProxyRouter{ externalURL: externalURL, @@ -31,6 +34,7 @@ func NewProxyRouter( publicKey: publicKey, proxyHeaders: proxyHeaders, httpStreamingOnly: httpStreamingOnly, + headerMapping: headerMapping, }, nil } @@ -87,6 +91,33 @@ func (p *ProxyRouter) handleProxy(c *gin.Context) { } } + if len(p.headerMapping) > 0 { + if claims, ok := token.Claims.(jwt.MapClaims); ok { + if userinfo, exists := claims["userinfo"]; exists { + for pointer, headerName := range p.headerMapping { + val, err := jsonpointer.Get(userinfo, pointer) + if err != nil { + continue + } + switch v := val.(type) { + case string: + c.Request.Header.Set(headerName, v) + case []any: + var parts []string + for _, item := range v { + if s, ok := item.(string); ok { + parts = append(parts, s) + } + } + c.Request.Header.Set(headerName, strings.Join(parts, ",")) + default: + c.Request.Header.Set(headerName, fmt.Sprintf("%v", v)) + } + } + } + } + } + p.proxy.ServeHTTP(c.Writer, c.Request) } diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index 995ae67..8c9ee47 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -70,7 +70,7 @@ func TestProxyRouter_HandleProxy_ValidToken(t *testing.T) { proxyHeaders := make(http.Header) proxyHeaders.Set("X-Forwarded-By", "mcp-auth-proxy") - proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, proxyHeaders, false) + proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, proxyHeaders, false, nil) require.NoError(t, err) gin.SetMode(gin.TestMode) @@ -104,11 +104,104 @@ func TestProxyRouter_HandleProxy_ValidToken(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, w.Code) } +func TestProxyRouter_HeaderMapping(t *testing.T) { + privateKey, publicKey, err := generateRSAKeyPair() + require.NoError(t, err) + + cases := []struct { + name string + headerMapping map[string]string + userinfo map[string]any + expectedHeaders map[string]string + }{ + { + name: "string field", + headerMapping: map[string]string{"/email": "X-Forwarded-Email"}, + userinfo: map[string]any{"email": "user@example.com"}, + expectedHeaders: map[string]string{ + "X-Forwarded-Email": "user@example.com", + }, + }, + { + name: "array field joined with comma", + headerMapping: map[string]string{"/groups": "X-Forwarded-Groups"}, + userinfo: map[string]any{"groups": []any{"admin", "users"}}, + expectedHeaders: map[string]string{ + "X-Forwarded-Groups": "admin,users", + }, + }, + { + name: "multiple mappings", + headerMapping: map[string]string{"/email": "X-Forwarded-Email", "/preferred_username": "X-Forwarded-User"}, + userinfo: map[string]any{"email": "user@example.com", "preferred_username": "john"}, + expectedHeaders: map[string]string{ + "X-Forwarded-Email": "user@example.com", + "X-Forwarded-User": "john", + }, + }, + { + name: "missing field is skipped", + headerMapping: map[string]string{"/email": "X-Forwarded-Email", "/missing": "X-Missing"}, + userinfo: map[string]any{"email": "user@example.com"}, + expectedHeaders: map[string]string{"X-Forwarded-Email": "user@example.com"}, + }, + { + name: "nil headerMapping", + headerMapping: nil, + userinfo: map[string]any{"email": "user@example.com"}, + expectedHeaders: map[string]string{}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + receivedHeaders := http.Header{} + proxyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for k, v := range r.Header { + receivedHeaders[k] = v + } + w.WriteHeader(http.StatusOK) + }) + + proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, false, tt.headerMapping) + require.NoError(t, err) + + gin.SetMode(gin.TestMode) + router := gin.New() + proxyRouter.SetupRoutes(router) + + claims := jwt.MapClaims{ + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + } + if tt.userinfo != nil { + claims["userinfo"] = tt.userinfo + } + + token, err := createJWT(privateKey, claims) + require.NoError(t, err) + + req, err := http.NewRequest("GET", "/test", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + for header, expected := range tt.expectedHeaders { + assert.Equal(t, expected, receivedHeaders.Get(header), "header %s mismatch", header) + } + }) + } +} + func TestProxyRouter_ProtectedResourceTrailingSlash(t *testing.T) { _, publicKey, err := generateRSAKeyPair() require.NoError(t, err) - proxyRouter, err := NewProxyRouter("https://example.com/", http.NotFoundHandler(), publicKey, http.Header{}, false) + proxyRouter, err := NewProxyRouter("https://example.com/", http.NotFoundHandler(), publicKey, http.Header{}, false, nil) require.NoError(t, err) gin.SetMode(gin.TestMode) @@ -212,7 +305,7 @@ func TestProxyRouter_HTTPStreamingOnlyRejectsSSE(t *testing.T) { w.WriteHeader(http.StatusOK) }) - proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, tt.streamingOnly) + proxyRouter, err := NewProxyRouter("https://example.com", proxyHandler, publicKey, http.Header{}, tt.streamingOnly, nil) require.NoError(t, err) gin.SetMode(gin.TestMode)