Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pkg/auth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"context"
"encoding/json"
"errors"
"net/url"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -44,12 +45,15 @@ func (p *githubProvider) AuthURL() string {
return GitHubAuthEndpoint
}

func (p *githubProvider) AuthCodeURL(c *gin.Context) (string, error) {
authURL := p.oauth2.AuthCodeURL("state", oauth2.AccessTypeOffline)
func (p *githubProvider) AuthCodeURL(c *gin.Context, state string) (string, error) {
authURL := p.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline)
return authURL, nil
}

func (p *githubProvider) Exchange(c *gin.Context) (*oauth2.Token, error) {
func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, error) {
if c.Query("state") != state {
return nil, errors.New("invalid OAuth state")
}
code := c.Query("code")
token, err := p.oauth2.Exchange(c, code)
if err != nil {
Expand Down
10 changes: 7 additions & 3 deletions pkg/auth/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"context"
"encoding/json"
"errors"
"net/url"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -40,16 +41,19 @@ func (p *googleProvider) RedirectURL() string {
return GoogleCallbackEndpoint
}

func (p *googleProvider) AuthCodeURL(c *gin.Context) (string, error) {
authURL := p.oauth2.AuthCodeURL("state", oauth2.AccessTypeOffline)
func (p *googleProvider) AuthCodeURL(c *gin.Context, state string) (string, error) {
authURL := p.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline)
return authURL, nil
}

func (p *googleProvider) AuthURL() string {
return GoogleAuthEndpoint
}

func (p *googleProvider) Exchange(c *gin.Context) (*oauth2.Token, error) {
func (p *googleProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, error) {
if c.Query("state") != state {
return nil, errors.New("invalid OAuth state")
}
code := c.Query("code")
token, err := p.oauth2.Exchange(c, code)
if err != nil {
Expand Down
50 changes: 35 additions & 15 deletions pkg/auth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package auth
import (
"context"
"embed"
"errors"
"html/template"
"net/http"

"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/sigbit/mcp-auth-proxy/pkg/utils"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
)
Expand All @@ -19,8 +21,8 @@ type Provider interface {
Name() string
RedirectURL() string
AuthURL() string
AuthCodeURL(c *gin.Context) (string, error)
Exchange(c *gin.Context) (*oauth2.Token, error)
AuthCodeURL(c *gin.Context, state string) (string, error)
Exchange(c *gin.Context, state string) (*oauth2.Token, error)
GetUserID(ctx context.Context, token *oauth2.Token) (string, error)
Authorization(userid string) (bool, error)
}
Expand Down Expand Up @@ -63,9 +65,14 @@ const (
GoogleCallbackEndpoint = "/.auth/google/callback"
GitHubAuthEndpoint = "/.auth/github"
GitHubCallbackEndpoint = "/.auth/github/callback"

PasswordProvider = "password"
PasswordUserID = "password_user"

SessionKeyProvider = "provider"
SessionKeyUserID = "user_id"
SessionKeyRedirectURL = "redirect_url"
SessionKeyOAuthState = "oauth_state"
)

func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
Expand All @@ -75,8 +82,12 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
for providerName, provider := range a.providers {
router.GET(provider.RedirectURL(), func(c *gin.Context) {
session := sessions.Default(c)

token, err := provider.Exchange(c)
state := session.Get(SessionKeyOAuthState)
if state == nil {
c.Error(errors.New("OAuth state is missing"))
return
}
token, err := provider.Exchange(c, state.(string))
if err != nil {
c.Error(err)
return
Expand All @@ -86,19 +97,28 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
c.Error(err)
return
}
session.Set("provider", providerName)
session.Set("user_id", userID)
session.Set(SessionKeyProvider, providerName)
session.Set(SessionKeyUserID, userID)
session.Save()
redirectURL := session.Get("redirect_url")
redirectURL := session.Get(SessionKeyRedirectURL)
c.Redirect(http.StatusFound, redirectURL.(string))
})

router.GET(provider.AuthURL(), func(c *gin.Context) {
url, err := provider.AuthCodeURL(c)
session := sessions.Default(c)

state, err := utils.GenerateState()
if err != nil {
c.Error(err)
return
}
url, err := provider.AuthCodeURL(c, state)
if err != nil {
c.Error(err)
return
}
session.Set(SessionKeyOAuthState, state)
session.Save()
c.Redirect(http.StatusFound, url)
})
}
Expand Down Expand Up @@ -190,11 +210,11 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
}

session := sessions.Default(c)
session.Set("provider", PasswordProvider)
session.Set("user_id", PasswordUserID)
session.Set(SessionKeyProvider, PasswordProvider)
session.Set(SessionKeyUserID, PasswordUserID)
session.Save()

redirectURL := session.Get("redirect_url")
redirectURL := session.Get(SessionKeyRedirectURL)
if redirectURL == nil {
c.Redirect(http.StatusFound, "/")
return
Expand All @@ -212,10 +232,10 @@ func (a *AuthRouter) handleLogout(c *gin.Context) {
func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
providerName := session.Get("provider")
userID := session.Get("user_id")
providerName := session.Get(SessionKeyProvider)
userID := session.Get(SessionKeyUserID)
if providerName == nil || userID == nil {
session.Set("redirect_url", c.Request.URL.String())
session.Set(SessionKeyRedirectURL, c.Request.URL.String())
session.Save()
c.Redirect(http.StatusFound, LoginEndpoint)
return
Expand Down
8 changes: 8 additions & 0 deletions pkg/utils/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,11 @@ func GenerateClientSecret() (string, error) {
}
return hex.EncodeToString(bytes), nil
}

func GenerateState() (string, error) {
bytes := make([]byte, 8)
Comment thread
hrntknr marked this conversation as resolved.
Outdated
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}