Skip to content

Commit 239f2b2

Browse files
authored
feat: enhance OAuth providers with organization and workspace support (#69)
* feat: enhance OAuth providers with organization and workspace support - Add GitHub organization and team-based authorization - Add Google Workspace domain-based authorization - Consolidate authentication flow by combining user retrieval and authorization - Add comprehensive test coverage for OAuth providers - Add utilities for better error handling - Improve session management with proper cookie settings BREAKING CHANGE: Authorization interface changed from separate GetUserID/Authorization calls to combined Authorization method * refactor: remove unused getProvider method from AuthRouter Remove dead code that was not being used anywhere in the codebase. * refactor: rename response variables in GitHub OAuth for clarity Renamed resp variables to resp1, resp2, resp3 to avoid variable shadowing and improve code readability in the GitHub OAuth authorization flow.
1 parent 37bbe8c commit 239f2b2

15 files changed

Lines changed: 841 additions & 142 deletions

File tree

main.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ func main() {
3737
var googleClientID string
3838
var googleClientSecret string
3939
var googleAllowedUsers string
40+
var googleAllowedWorkspaces string
4041
var githubClientID string
4142
var githubClientSecret string
4243
var githubAllowedUsers string
44+
var githubAllowedOrgs string
4345
var oidcConfigurationURL string
4446
var oidcClientID string
4547
var oidcClientSecret string
@@ -63,6 +65,14 @@ func main() {
6365
}
6466
}
6567

68+
var googleAllowedWorkspacesList []string
69+
if googleAllowedWorkspaces != "" {
70+
googleAllowedWorkspacesList = strings.Split(googleAllowedWorkspaces, ",")
71+
for i := range googleAllowedWorkspacesList {
72+
googleAllowedWorkspacesList[i] = strings.TrimSpace(googleAllowedWorkspacesList[i])
73+
}
74+
}
75+
6676
var githubAllowedUsersList []string
6777
if githubAllowedUsers != "" {
6878
githubAllowedUsersList = strings.Split(githubAllowedUsers, ",")
@@ -71,6 +81,14 @@ func main() {
7181
}
7282
}
7383

84+
var githubAllowedOrgsList []string
85+
if githubAllowedOrgs != "" {
86+
githubAllowedOrgsList = strings.Split(githubAllowedOrgs, ",")
87+
for i := range githubAllowedOrgsList {
88+
githubAllowedOrgsList[i] = strings.TrimSpace(githubAllowedOrgsList[i])
89+
}
90+
}
91+
7492
var oidcAllowedUsersList []string
7593
if oidcAllowedUsers != "" {
7694
oidcAllowedUsersList = strings.Split(oidcAllowedUsers, ",")
@@ -110,9 +128,11 @@ func main() {
110128
googleClientID,
111129
googleClientSecret,
112130
googleAllowedUsersList,
131+
googleAllowedWorkspacesList,
113132
githubClientID,
114133
githubClientSecret,
115134
githubAllowedUsersList,
135+
githubAllowedOrgsList,
116136
oidcConfigurationURL,
117137
oidcClientID,
118138
oidcClientSecret,
@@ -144,11 +164,13 @@ func main() {
144164
rootCmd.Flags().StringVar(&googleClientID, "google-client-id", getEnvWithDefault("GOOGLE_CLIENT_ID", ""), "Google OAuth client ID")
145165
rootCmd.Flags().StringVar(&googleClientSecret, "google-client-secret", getEnvWithDefault("GOOGLE_CLIENT_SECRET", ""), "Google OAuth client secret")
146166
rootCmd.Flags().StringVar(&googleAllowedUsers, "google-allowed-users", getEnvWithDefault("GOOGLE_ALLOWED_USERS", ""), "Comma-separated list of allowed Google users (emails)")
167+
rootCmd.Flags().StringVar(&googleAllowedWorkspaces, "google-allowed-workspaces", getEnvWithDefault("GOOGLE_ALLOWED_WORKSPACES", ""), "Comma-separated list of allowed Google workspaces")
147168

148169
// GitHub OAuth configuration
149170
rootCmd.Flags().StringVar(&githubClientID, "github-client-id", getEnvWithDefault("GITHUB_CLIENT_ID", ""), "GitHub OAuth client ID")
150171
rootCmd.Flags().StringVar(&githubClientSecret, "github-client-secret", getEnvWithDefault("GITHUB_CLIENT_SECRET", ""), "GitHub OAuth client secret")
151172
rootCmd.Flags().StringVar(&githubAllowedUsers, "github-allowed-users", getEnvWithDefault("GITHUB_ALLOWED_USERS", ""), "Comma-separated list of allowed GitHub users (usernames)")
173+
rootCmd.Flags().StringVar(&githubAllowedOrgs, "github-allowed-orgs", getEnvWithDefault("GITHUB_ALLOWED_ORGS", ""), "Comma-separated list of allowed GitHub organizations. You can also restrict access to specific teams using the format `Org:Team`")
152174

153175
// OIDC configuration
154176
rootCmd.Flags().StringVar(&oidcConfigurationURL, "oidc-configuration-url", getEnvWithDefault("OIDC_CONFIGURATION_URL", ""), "OIDC configuration URL")

pkg/auth/auth.go

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ const (
6161
PasswordProvider = "password"
6262
PasswordUserID = "password_user"
6363

64-
SessionKeyProvider = "provider"
65-
SessionKeyUserID = "user_id"
64+
SessionKeyAuthorized = "authorized"
6665
SessionKeyRedirectURL = "redirect_url"
6766
SessionKeyOAuthState = "oauth_state"
6867
)
@@ -84,22 +83,16 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
8483
a.renderError(c, err)
8584
return
8685
}
87-
userID, err := provider.GetUserID(c, token)
88-
if err != nil {
89-
a.renderError(c, err)
90-
return
91-
}
92-
ok, err := provider.Authorization(userID)
86+
ok, user, err := provider.Authorization(c, token)
9387
if err != nil {
9488
a.renderError(c, err)
9589
return
9690
}
9791
if !ok {
98-
a.renderUnauthorized(c, userID, provider.Name())
92+
a.renderUnauthorized(c, user, provider.Name())
9993
return
10094
}
101-
session.Set(SessionKeyProvider, provider.Name())
102-
session.Set(SessionKeyUserID, userID)
95+
session.Set(SessionKeyAuthorized, true)
10396
redirectURL := session.Get(SessionKeyRedirectURL)
10497
if redirectURL != nil {
10598
session.Delete(SessionKeyRedirectURL)
@@ -124,7 +117,7 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
124117
a.renderError(c, err)
125118
return
126119
}
127-
url, err := provider.AuthCodeURL(c, state)
120+
url, err := provider.AuthCodeURL(state)
128121
if err != nil {
129122
a.renderError(c, err)
130123
return
@@ -139,15 +132,6 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
139132
}
140133
}
141134

142-
func (a *AuthRouter) getProvider(name string) Provider {
143-
for _, provider := range a.providers {
144-
if provider.Name() == name {
145-
return provider
146-
}
147-
}
148-
return nil
149-
}
150-
151135
func (a *AuthRouter) handleLogin(c *gin.Context) {
152136
if c.Request.Method == "POST" {
153137
a.handleLoginPost(c)
@@ -183,8 +167,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
183167
}
184168

185169
session := sessions.Default(c)
186-
session.Set(SessionKeyProvider, PasswordProvider)
187-
session.Set(SessionKeyUserID, PasswordUserID)
170+
session.Set(SessionKeyAuthorized, true)
188171
redirectURL := session.Get(SessionKeyRedirectURL)
189172
if redirectURL != nil {
190173
session.Delete(SessionKeyRedirectURL)
@@ -203,8 +186,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
203186

204187
func (a *AuthRouter) handleLogout(c *gin.Context) {
205188
session := sessions.Default(c)
206-
session.Delete(SessionKeyProvider)
207-
session.Delete(SessionKeyUserID)
189+
session.Delete(SessionKeyAuthorized)
208190
if err := session.Save(); err != nil {
209191
a.renderError(c, err)
210192
return
@@ -215,9 +197,8 @@ func (a *AuthRouter) handleLogout(c *gin.Context) {
215197
func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
216198
return func(c *gin.Context) {
217199
session := sessions.Default(c)
218-
providerName := session.Get(SessionKeyProvider)
219-
userID := session.Get(SessionKeyUserID)
220-
if providerName == nil || userID == nil {
200+
authorized := session.Get(SessionKeyAuthorized)
201+
if authorized == nil {
221202
session.Set(SessionKeyRedirectURL, c.Request.URL.String())
222203
if err := session.Save(); err != nil {
223204
a.renderError(c, err)
@@ -227,25 +208,9 @@ func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
227208
return
228209
}
229210

230-
// Allow password authentication
231-
if providerName.(string) == PasswordProvider {
232-
c.Next()
233-
return
234-
}
235-
236-
p := a.getProvider(providerName.(string))
237-
if p == nil {
238-
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unknown provider"})
239-
return
240-
}
241-
ok, err := p.Authorization(userID.(string))
242-
if err != nil {
243-
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authorization failed"})
244-
return
245-
}
246-
if !ok {
247-
a.renderUnauthorized(c, userID.(string), providerName.(string))
248-
c.Abort()
211+
if !authorized.(bool) {
212+
// not expected
213+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
249214
return
250215
}
251216
c.Next()

pkg/auth/auth_test.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
)
1616

1717
func setupTestRouter(authRouter *AuthRouter) *gin.Engine {
18-
gin.SetMode(gin.TestMode)
1918
router := gin.New()
2019

2120
// Setup session middleware
@@ -84,10 +83,9 @@ func TestAuthenticationFlow(t *testing.T) {
8483
mockProvider.EXPECT().Name().Return("test").AnyTimes()
8584
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
8685
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
87-
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
86+
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
8887
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
89-
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("test-user", nil)
90-
mockProvider.EXPECT().Authorization("test-user").Return(true, nil).AnyTimes()
88+
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", nil)
9189

9290
// Create AuthRouter
9391
authRouter, err := NewAuthRouter(nil, mockProvider)
@@ -146,10 +144,9 @@ func TestAuthenticationFlow(t *testing.T) {
146144
mockProvider.EXPECT().Name().Return("test").AnyTimes()
147145
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
148146
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
149-
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
147+
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
150148
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
151-
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("unauthorized-user", nil)
152-
mockProvider.EXPECT().Authorization("unauthorized-user").Return(false, nil).AnyTimes()
149+
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", nil)
153150

154151
// Create AuthRouter
155152
authRouter, err := NewAuthRouter(nil, mockProvider)

pkg/auth/github.go

Lines changed: 90 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,53 @@ import (
55
"encoding/json"
66
"errors"
77
"net/url"
8+
"slices"
9+
"strings"
810

911
"github.com/gin-gonic/gin"
12+
"github.com/sigbit/mcp-auth-proxy/pkg/utils"
1013
"golang.org/x/oauth2"
1114
"golang.org/x/oauth2/github"
1215
)
1316

1417
type githubProvider struct {
18+
endpoint string
1519
oauth2 oauth2.Config
1620
allowedUsers []string
21+
allowedOrgs []string
1722
}
1823

19-
func NewGithubProvider(clientID, clientSecret, externalURL string, allowedUsers []string) (Provider, error) {
24+
func NewGithubProvider(clientID, clientSecret, externalURL string, allowedUsers []string, allowedOrgs []string) (Provider, error) {
2025
r, err := url.JoinPath(externalURL, GitHubCallbackEndpoint)
2126
if err != nil {
2227
return nil, err
2328
}
29+
scopes := []string{}
30+
if len(allowedOrgs) > 0 {
31+
scopes = append(scopes, "read:org")
32+
}
2433
return &githubProvider{
34+
endpoint: "https://api.github.com",
2535
oauth2: oauth2.Config{
2636
ClientID: clientID,
2737
ClientSecret: clientSecret,
2838
RedirectURL: r,
29-
Scopes: []string{"user:email"},
39+
Scopes: scopes,
3040
Endpoint: github.Endpoint,
3141
},
3242
allowedUsers: allowedUsers,
43+
allowedOrgs: allowedOrgs,
3344
}, nil
3445
}
3546

47+
func (p *githubProvider) SetApiEndpoint(u string) {
48+
p.endpoint = u
49+
}
50+
51+
func (p *githubProvider) SetOAuth2Endpoint(cfg oauth2.Endpoint) {
52+
p.oauth2.Endpoint = cfg
53+
}
54+
3655
func (p *githubProvider) Name() string {
3756
return "GitHub"
3857
}
@@ -49,7 +68,7 @@ func (p *githubProvider) AuthURL() string {
4968
return GitHubAuthEndpoint
5069
}
5170

52-
func (p *githubProvider) AuthCodeURL(c *gin.Context, state string) (string, error) {
71+
func (p *githubProvider) AuthCodeURL(state string) (string, error) {
5372
authURL := p.oauth2.AuthCodeURL(state)
5473
return authURL, nil
5574
}
@@ -66,37 +85,87 @@ func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token,
6685
return token, nil
6786
}
6887

69-
func (p *githubProvider) GetUserID(ctx context.Context, token *oauth2.Token) (string, error) {
88+
func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) {
7089
client := p.oauth2.Client(ctx, token)
71-
resp, err := client.Get("https://api.github.com/user")
90+
resp1, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user")))
7291
if err != nil {
73-
return "", err
92+
return false, "", err
93+
}
94+
if resp1.StatusCode < 200 || resp1.StatusCode >= 300 {
95+
return false, "", errors.New("failed to get user info from GitHub API: " + resp1.Status)
7496
}
75-
defer resp.Body.Close()
97+
defer resp1.Body.Close()
7698

7799
var userInfo struct {
78-
ID uint64 `json:"id"`
79100
Login string `json:"login"`
80-
Name string `json:"name"`
81-
Email string `json:"email"`
82101
}
83-
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
84-
return "", err
102+
if err := json.NewDecoder(resp1.Body).Decode(&userInfo); err != nil {
103+
return false, "", err
85104
}
86105

87-
return userInfo.Login, nil
88-
}
106+
if len(p.allowedUsers) == 0 && len(p.allowedOrgs) == 0 {
107+
return true, userInfo.Login, nil
108+
}
89109

90-
func (p *githubProvider) Authorization(userid string) (bool, error) {
91-
if len(p.allowedUsers) == 0 {
92-
return true, nil
110+
if slices.Contains(p.allowedUsers, userInfo.Login) {
111+
return true, userInfo.Login, nil
112+
}
113+
114+
allowedOrgTeams := []string{}
115+
allowedOrgs := []string{}
116+
for _, allowedOrg := range p.allowedOrgs {
117+
if strings.Contains(allowedOrg, ":") {
118+
allowedOrgTeams = append(allowedOrgTeams, allowedOrg)
119+
} else {
120+
allowedOrgs = append(allowedOrgs, allowedOrg)
121+
}
93122
}
94123

95-
for _, allowedUser := range p.allowedUsers {
96-
if allowedUser == userid {
97-
return true, nil
124+
if len(allowedOrgs) > 0 {
125+
resp2, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/orgs")))
126+
if err != nil {
127+
return false, "", err
128+
}
129+
if resp2.StatusCode < 200 || resp2.StatusCode >= 300 {
130+
return false, "", errors.New("failed to get user info from GitHub API: " + resp2.Status)
131+
}
132+
defer resp2.Body.Close()
133+
var orgInfo []struct {
134+
Login string `json:"login"`
135+
}
136+
if err := json.NewDecoder(resp2.Body).Decode(&orgInfo); err != nil {
137+
return false, "", err
138+
}
139+
for _, o := range orgInfo {
140+
if slices.Contains(allowedOrgs, o.Login) {
141+
return true, userInfo.Login, nil
142+
}
143+
}
144+
}
145+
if len(allowedOrgTeams) > 0 {
146+
resp3, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/teams")))
147+
if err != nil {
148+
return false, "", err
149+
}
150+
if resp3.StatusCode < 200 || resp3.StatusCode >= 300 {
151+
return false, "", errors.New("failed to get user info from GitHub API: " + resp3.Status)
152+
}
153+
defer resp3.Body.Close()
154+
var teamInfo []struct {
155+
Organization struct {
156+
Login string `json:"login"`
157+
} `json:"organization"`
158+
Slug string `json:"slug"`
159+
}
160+
if err := json.NewDecoder(resp3.Body).Decode(&teamInfo); err != nil {
161+
return false, "", err
162+
}
163+
for _, team := range teamInfo {
164+
if slices.Contains(allowedOrgTeams, team.Organization.Login+":"+team.Slug) {
165+
return true, userInfo.Login, nil
166+
}
98167
}
99168
}
100169

101-
return false, nil
170+
return false, userInfo.Login, nil
102171
}

0 commit comments

Comments
 (0)