Skip to content

Commit 0db70d9

Browse files
committed
fix: implement OAuth CSRF protection with state validation
- Generate cryptographically secure random state for OAuth requests - Store OAuth state in session during authorization - Validate state parameter in callback handlers - Add GenerateState utility function for secure random generation Fixes #27
1 parent d4deea9 commit 0db70d9

4 files changed

Lines changed: 57 additions & 21 deletions

File tree

pkg/auth/github.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package auth
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"net/url"
78

89
"github.com/gin-gonic/gin"
@@ -44,12 +45,15 @@ func (p *githubProvider) AuthURL() string {
4445
return GitHubAuthEndpoint
4546
}
4647

47-
func (p *githubProvider) AuthCodeURL(c *gin.Context) (string, error) {
48-
authURL := p.oauth2.AuthCodeURL("state", oauth2.AccessTypeOffline)
48+
func (p *githubProvider) AuthCodeURL(c *gin.Context, state string) (string, error) {
49+
authURL := p.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline)
4950
return authURL, nil
5051
}
5152

52-
func (p *githubProvider) Exchange(c *gin.Context) (*oauth2.Token, error) {
53+
func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, error) {
54+
if c.Query("state") != state {
55+
return nil, errors.New("invalid OAuth state")
56+
}
5357
code := c.Query("code")
5458
token, err := p.oauth2.Exchange(c, code)
5559
if err != nil {

pkg/auth/google.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package auth
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"net/url"
78

89
"github.com/gin-gonic/gin"
@@ -40,16 +41,19 @@ func (p *googleProvider) RedirectURL() string {
4041
return GoogleCallbackEndpoint
4142
}
4243

43-
func (p *googleProvider) AuthCodeURL(c *gin.Context) (string, error) {
44-
authURL := p.oauth2.AuthCodeURL("state", oauth2.AccessTypeOffline)
44+
func (p *googleProvider) AuthCodeURL(c *gin.Context, state string) (string, error) {
45+
authURL := p.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline)
4546
return authURL, nil
4647
}
4748

4849
func (p *googleProvider) AuthURL() string {
4950
return GoogleAuthEndpoint
5051
}
5152

52-
func (p *googleProvider) Exchange(c *gin.Context) (*oauth2.Token, error) {
53+
func (p *googleProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, error) {
54+
if c.Query("state") != state {
55+
return nil, errors.New("invalid OAuth state")
56+
}
5357
code := c.Query("code")
5458
token, err := p.oauth2.Exchange(c, code)
5559
if err != nil {

pkg/auth/main.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ package auth
33
import (
44
"context"
55
"embed"
6+
"errors"
67
"html/template"
78
"net/http"
89

910
"github.com/gin-contrib/sessions"
1011
"github.com/gin-gonic/gin"
12+
"github.com/sigbit/mcp-auth-proxy/pkg/utils"
1113
"golang.org/x/crypto/bcrypt"
1214
"golang.org/x/oauth2"
1315
)
@@ -19,8 +21,8 @@ type Provider interface {
1921
Name() string
2022
RedirectURL() string
2123
AuthURL() string
22-
AuthCodeURL(c *gin.Context) (string, error)
23-
Exchange(c *gin.Context) (*oauth2.Token, error)
24+
AuthCodeURL(c *gin.Context, state string) (string, error)
25+
Exchange(c *gin.Context, state string) (*oauth2.Token, error)
2426
GetUserID(ctx context.Context, token *oauth2.Token) (string, error)
2527
Authorization(userid string) (bool, error)
2628
}
@@ -63,9 +65,14 @@ const (
6365
GoogleCallbackEndpoint = "/.auth/google/callback"
6466
GitHubAuthEndpoint = "/.auth/github"
6567
GitHubCallbackEndpoint = "/.auth/github/callback"
66-
68+
6769
PasswordProvider = "password"
6870
PasswordUserID = "password_user"
71+
72+
SessionKeyProvider = "provider"
73+
SessionKeyUserID = "user_id"
74+
SessionKeyRedirectURL = "redirect_url"
75+
SessionKeyOAuthState = "oauth_state"
6976
)
7077

7178
func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
@@ -75,8 +82,12 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
7582
for providerName, provider := range a.providers {
7683
router.GET(provider.RedirectURL(), func(c *gin.Context) {
7784
session := sessions.Default(c)
78-
79-
token, err := provider.Exchange(c)
85+
state := session.Get(SessionKeyOAuthState)
86+
if state == nil {
87+
c.Error(errors.New("OAuth state is missing"))
88+
return
89+
}
90+
token, err := provider.Exchange(c, state.(string))
8091
if err != nil {
8192
c.Error(err)
8293
return
@@ -86,19 +97,28 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
8697
c.Error(err)
8798
return
8899
}
89-
session.Set("provider", providerName)
90-
session.Set("user_id", userID)
100+
session.Set(SessionKeyProvider, providerName)
101+
session.Set(SessionKeyUserID, userID)
91102
session.Save()
92-
redirectURL := session.Get("redirect_url")
103+
redirectURL := session.Get(SessionKeyRedirectURL)
93104
c.Redirect(http.StatusFound, redirectURL.(string))
94105
})
95106

96107
router.GET(provider.AuthURL(), func(c *gin.Context) {
97-
url, err := provider.AuthCodeURL(c)
108+
session := sessions.Default(c)
109+
110+
state, err := utils.GenerateState()
111+
if err != nil {
112+
c.Error(err)
113+
return
114+
}
115+
url, err := provider.AuthCodeURL(c, state)
98116
if err != nil {
99117
c.Error(err)
100118
return
101119
}
120+
session.Set(SessionKeyOAuthState, state)
121+
session.Save()
102122
c.Redirect(http.StatusFound, url)
103123
})
104124
}
@@ -190,11 +210,11 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
190210
}
191211

192212
session := sessions.Default(c)
193-
session.Set("provider", PasswordProvider)
194-
session.Set("user_id", PasswordUserID)
213+
session.Set(SessionKeyProvider, PasswordProvider)
214+
session.Set(SessionKeyUserID, PasswordUserID)
195215
session.Save()
196216

197-
redirectURL := session.Get("redirect_url")
217+
redirectURL := session.Get(SessionKeyRedirectURL)
198218
if redirectURL == nil {
199219
c.Redirect(http.StatusFound, "/")
200220
return
@@ -212,10 +232,10 @@ func (a *AuthRouter) handleLogout(c *gin.Context) {
212232
func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
213233
return func(c *gin.Context) {
214234
session := sessions.Default(c)
215-
providerName := session.Get("provider")
216-
userID := session.Get("user_id")
235+
providerName := session.Get(SessionKeyProvider)
236+
userID := session.Get(SessionKeyUserID)
217237
if providerName == nil || userID == nil {
218-
session.Set("redirect_url", c.Request.URL.String())
238+
session.Set(SessionKeyRedirectURL, c.Request.URL.String())
219239
session.Save()
220240
c.Redirect(http.StatusFound, LoginEndpoint)
221241
return

pkg/utils/main.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,11 @@ func GenerateClientSecret() (string, error) {
3232
}
3333
return hex.EncodeToString(bytes), nil
3434
}
35+
36+
func GenerateState() (string, error) {
37+
bytes := make([]byte, 8)
38+
if _, err := rand.Read(bytes); err != nil {
39+
return "", err
40+
}
41+
return hex.EncodeToString(bytes), nil
42+
}

0 commit comments

Comments
 (0)