From 0db70d91c839d4c8b09771426010d05c9a9fa02b Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Mon, 18 Aug 2025 12:36:53 +0000 Subject: [PATCH 1/2] fix: implement OAuth CSRF protection with state validation - Generate cryptographically secure random state for OAuth requests - Store OAuth state in session during authorization - Validate state parameter in callback handlers - Add GenerateState utility function for secure random generation Fixes #27 --- pkg/auth/github.go | 10 +++++++--- pkg/auth/google.go | 10 +++++++--- pkg/auth/main.go | 50 ++++++++++++++++++++++++++++++++-------------- pkg/utils/main.go | 8 ++++++++ 4 files changed, 57 insertions(+), 21 deletions(-) diff --git a/pkg/auth/github.go b/pkg/auth/github.go index c697ebb..df83505 100644 --- a/pkg/auth/github.go +++ b/pkg/auth/github.go @@ -3,6 +3,7 @@ package auth import ( "context" "encoding/json" + "errors" "net/url" "github.com/gin-gonic/gin" @@ -44,12 +45,15 @@ func (p *githubProvider) AuthURL() string { return GitHubAuthEndpoint } -func (p *githubProvider) AuthCodeURL(c *gin.Context) (string, error) { - authURL := p.oauth2.AuthCodeURL("state", oauth2.AccessTypeOffline) +func (p *githubProvider) AuthCodeURL(c *gin.Context, state string) (string, error) { + authURL := p.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline) return authURL, nil } -func (p *githubProvider) Exchange(c *gin.Context) (*oauth2.Token, error) { +func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, error) { + if c.Query("state") != state { + return nil, errors.New("invalid OAuth state") + } code := c.Query("code") token, err := p.oauth2.Exchange(c, code) if err != nil { diff --git a/pkg/auth/google.go b/pkg/auth/google.go index 2335484..b19d74e 100644 --- a/pkg/auth/google.go +++ b/pkg/auth/google.go @@ -3,6 +3,7 @@ package auth import ( "context" "encoding/json" + "errors" "net/url" "github.com/gin-gonic/gin" @@ -40,8 +41,8 @@ func (p *googleProvider) RedirectURL() string { return GoogleCallbackEndpoint } -func (p *googleProvider) AuthCodeURL(c *gin.Context) (string, error) { - authURL := p.oauth2.AuthCodeURL("state", oauth2.AccessTypeOffline) +func (p *googleProvider) AuthCodeURL(c *gin.Context, state string) (string, error) { + authURL := p.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline) return authURL, nil } @@ -49,7 +50,10 @@ func (p *googleProvider) AuthURL() string { return GoogleAuthEndpoint } -func (p *googleProvider) Exchange(c *gin.Context) (*oauth2.Token, error) { +func (p *googleProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, error) { + if c.Query("state") != state { + return nil, errors.New("invalid OAuth state") + } code := c.Query("code") token, err := p.oauth2.Exchange(c, code) if err != nil { diff --git a/pkg/auth/main.go b/pkg/auth/main.go index 98eb90e..be28a36 100644 --- a/pkg/auth/main.go +++ b/pkg/auth/main.go @@ -3,11 +3,13 @@ package auth import ( "context" "embed" + "errors" "html/template" "net/http" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/sigbit/mcp-auth-proxy/pkg/utils" "golang.org/x/crypto/bcrypt" "golang.org/x/oauth2" ) @@ -19,8 +21,8 @@ type Provider interface { Name() string RedirectURL() string AuthURL() string - AuthCodeURL(c *gin.Context) (string, error) - Exchange(c *gin.Context) (*oauth2.Token, error) + AuthCodeURL(c *gin.Context, state string) (string, error) + Exchange(c *gin.Context, state string) (*oauth2.Token, error) GetUserID(ctx context.Context, token *oauth2.Token) (string, error) Authorization(userid string) (bool, error) } @@ -63,9 +65,14 @@ const ( GoogleCallbackEndpoint = "/.auth/google/callback" GitHubAuthEndpoint = "/.auth/github" GitHubCallbackEndpoint = "/.auth/github/callback" - + PasswordProvider = "password" PasswordUserID = "password_user" + + SessionKeyProvider = "provider" + SessionKeyUserID = "user_id" + SessionKeyRedirectURL = "redirect_url" + SessionKeyOAuthState = "oauth_state" ) func (a *AuthRouter) SetupRoutes(router gin.IRouter) { @@ -75,8 +82,12 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { for providerName, provider := range a.providers { router.GET(provider.RedirectURL(), func(c *gin.Context) { session := sessions.Default(c) - - token, err := provider.Exchange(c) + state := session.Get(SessionKeyOAuthState) + if state == nil { + c.Error(errors.New("OAuth state is missing")) + return + } + token, err := provider.Exchange(c, state.(string)) if err != nil { c.Error(err) return @@ -86,19 +97,28 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { c.Error(err) return } - session.Set("provider", providerName) - session.Set("user_id", userID) + session.Set(SessionKeyProvider, providerName) + session.Set(SessionKeyUserID, userID) session.Save() - redirectURL := session.Get("redirect_url") + redirectURL := session.Get(SessionKeyRedirectURL) c.Redirect(http.StatusFound, redirectURL.(string)) }) router.GET(provider.AuthURL(), func(c *gin.Context) { - url, err := provider.AuthCodeURL(c) + session := sessions.Default(c) + + state, err := utils.GenerateState() + if err != nil { + c.Error(err) + return + } + url, err := provider.AuthCodeURL(c, state) if err != nil { c.Error(err) return } + session.Set(SessionKeyOAuthState, state) + session.Save() c.Redirect(http.StatusFound, url) }) } @@ -190,11 +210,11 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) { } session := sessions.Default(c) - session.Set("provider", PasswordProvider) - session.Set("user_id", PasswordUserID) + session.Set(SessionKeyProvider, PasswordProvider) + session.Set(SessionKeyUserID, PasswordUserID) session.Save() - redirectURL := session.Get("redirect_url") + redirectURL := session.Get(SessionKeyRedirectURL) if redirectURL == nil { c.Redirect(http.StatusFound, "/") return @@ -212,10 +232,10 @@ func (a *AuthRouter) handleLogout(c *gin.Context) { func (a *AuthRouter) RequireAuth() gin.HandlerFunc { return func(c *gin.Context) { session := sessions.Default(c) - providerName := session.Get("provider") - userID := session.Get("user_id") + providerName := session.Get(SessionKeyProvider) + userID := session.Get(SessionKeyUserID) if providerName == nil || userID == nil { - session.Set("redirect_url", c.Request.URL.String()) + session.Set(SessionKeyRedirectURL, c.Request.URL.String()) session.Save() c.Redirect(http.StatusFound, LoginEndpoint) return diff --git a/pkg/utils/main.go b/pkg/utils/main.go index 15b152a..ade15f9 100644 --- a/pkg/utils/main.go +++ b/pkg/utils/main.go @@ -32,3 +32,11 @@ func GenerateClientSecret() (string, error) { } return hex.EncodeToString(bytes), nil } + +func GenerateState() (string, error) { + bytes := make([]byte, 8) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} From 42013fb4dea4a5e98865036eebb755f7515e8630 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Mon, 18 Aug 2025 21:41:51 +0900 Subject: [PATCH 2/2] Update pkg/utils/main.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pkg/utils/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/utils/main.go b/pkg/utils/main.go index ade15f9..fff790c 100644 --- a/pkg/utils/main.go +++ b/pkg/utils/main.go @@ -34,7 +34,7 @@ func GenerateClientSecret() (string, error) { } func GenerateState() (string, error) { - bytes := make([]byte, 8) + bytes := make([]byte, 16) if _, err := rand.Read(bytes); err != nil { return "", err }