@@ -3,11 +3,13 @@ package auth
33import (
44 "context"
55 "embed"
6+ "errors"
67 "html/template"
78 "net/http"
89
910 "github.com/gin-contrib/sessions"
1011 "github.com/gin-gonic/gin"
12+ "github.com/sigbit/mcp-auth-proxy/pkg/utils"
1113 "golang.org/x/crypto/bcrypt"
1214 "golang.org/x/oauth2"
1315)
@@ -19,8 +21,8 @@ type Provider interface {
1921 Name () string
2022 RedirectURL () string
2123 AuthURL () string
22- AuthCodeURL (c * gin.Context ) (string , error )
23- Exchange (c * gin.Context ) (* oauth2.Token , error )
24+ AuthCodeURL (c * gin.Context , state string ) (string , error )
25+ Exchange (c * gin.Context , state string ) (* oauth2.Token , error )
2426 GetUserID (ctx context.Context , token * oauth2.Token ) (string , error )
2527 Authorization (userid string ) (bool , error )
2628}
@@ -63,9 +65,14 @@ const (
6365 GoogleCallbackEndpoint = "/.auth/google/callback"
6466 GitHubAuthEndpoint = "/.auth/github"
6567 GitHubCallbackEndpoint = "/.auth/github/callback"
66-
68+
6769 PasswordProvider = "password"
6870 PasswordUserID = "password_user"
71+
72+ SessionKeyProvider = "provider"
73+ SessionKeyUserID = "user_id"
74+ SessionKeyRedirectURL = "redirect_url"
75+ SessionKeyOAuthState = "oauth_state"
6976)
7077
7178func (a * AuthRouter ) SetupRoutes (router gin.IRouter ) {
@@ -75,8 +82,12 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
7582 for providerName , provider := range a .providers {
7683 router .GET (provider .RedirectURL (), func (c * gin.Context ) {
7784 session := sessions .Default (c )
78-
79- token , err := provider .Exchange (c )
85+ state := session .Get (SessionKeyOAuthState )
86+ if state == nil {
87+ c .Error (errors .New ("OAuth state is missing" ))
88+ return
89+ }
90+ token , err := provider .Exchange (c , state .(string ))
8091 if err != nil {
8192 c .Error (err )
8293 return
@@ -86,19 +97,28 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
8697 c .Error (err )
8798 return
8899 }
89- session .Set ("provider" , providerName )
90- session .Set ("user_id" , userID )
100+ session .Set (SessionKeyProvider , providerName )
101+ session .Set (SessionKeyUserID , userID )
91102 session .Save ()
92- redirectURL := session .Get ("redirect_url" )
103+ redirectURL := session .Get (SessionKeyRedirectURL )
93104 c .Redirect (http .StatusFound , redirectURL .(string ))
94105 })
95106
96107 router .GET (provider .AuthURL (), func (c * gin.Context ) {
97- url , err := provider .AuthCodeURL (c )
108+ session := sessions .Default (c )
109+
110+ state , err := utils .GenerateState ()
111+ if err != nil {
112+ c .Error (err )
113+ return
114+ }
115+ url , err := provider .AuthCodeURL (c , state )
98116 if err != nil {
99117 c .Error (err )
100118 return
101119 }
120+ session .Set (SessionKeyOAuthState , state )
121+ session .Save ()
102122 c .Redirect (http .StatusFound , url )
103123 })
104124 }
@@ -190,11 +210,11 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
190210 }
191211
192212 session := sessions .Default (c )
193- session .Set ("provider" , PasswordProvider )
194- session .Set ("user_id" , PasswordUserID )
213+ session .Set (SessionKeyProvider , PasswordProvider )
214+ session .Set (SessionKeyUserID , PasswordUserID )
195215 session .Save ()
196216
197- redirectURL := session .Get ("redirect_url" )
217+ redirectURL := session .Get (SessionKeyRedirectURL )
198218 if redirectURL == nil {
199219 c .Redirect (http .StatusFound , "/" )
200220 return
@@ -212,10 +232,10 @@ func (a *AuthRouter) handleLogout(c *gin.Context) {
212232func (a * AuthRouter ) RequireAuth () gin.HandlerFunc {
213233 return func (c * gin.Context ) {
214234 session := sessions .Default (c )
215- providerName := session .Get ("provider" )
216- userID := session .Get ("user_id" )
235+ providerName := session .Get (SessionKeyProvider )
236+ userID := session .Get (SessionKeyUserID )
217237 if providerName == nil || userID == nil {
218- session .Set ("redirect_url" , c .Request .URL .String ())
238+ session .Set (SessionKeyRedirectURL , c .Request .URL .String ())
219239 session .Save ()
220240 c .Redirect (http .StatusFound , LoginEndpoint )
221241 return
0 commit comments