Skip to content

Commit cd02a3f

Browse files
committed
feat(auth): add no-provider-auto-select flag to disable auto-redirect
- Add noProviderAutoSelect to AuthRouter and mcp-proxy Run - Skip auto-redirect to the sole provider when no password is set - Update docs and tests to cover behavior Notes: This adds a new parameter to exported constructors; call sites pass the flag.
1 parent 3a4baff commit cd02a3f

6 files changed

Lines changed: 183 additions & 77 deletions

File tree

docs/docs/configuration.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ Complete reference for all MCP Auth Proxy configuration options.
2727

2828
#### Password Authentication
2929

30-
| Option | Environment Variable | Default | Description |
31-
| ----------------- | -------------------- | ------- | ------------------------------------------------------------------- |
32-
| `--password` | `PASSWORD` | - | Plain text password for authentication (will be hashed with bcrypt) |
33-
| `--password-hash` | `PASSWORD_HASH` | - | Bcrypt hash of password for authentication |
30+
| Option | Environment Variable | Default | Description |
31+
| --------------------------- | ------------------------- | ------- | -------------------------------------------------------------------------------------------- |
32+
| `--no-provider-auto-select` | `NO_PROVIDER_AUTO_SELECT` | `false` | Disable auto-redirect when only one OAuth/OIDC provider is configured and no password is set |
33+
| `--password` | `PASSWORD` | - | Plain text password for authentication (will be hashed with bcrypt) |
34+
| `--password-hash` | `PASSWORD_HASH` | - | Bcrypt hash of password for authentication |
3435

3536
#### Google OAuth
3637

main.go

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,10 @@ func main() {
8686
var oidcClientSecret string
8787
var oidcScopes string
8888
var oidcUserIDField string
89-
var oidcProviderName string
90-
var oidcAllowedUsers string
91-
var oidcAllowedUsersGlob string
89+
var oidcProviderName string
90+
var oidcAllowedUsers string
91+
var oidcAllowedUsersGlob string
92+
var noProviderAutoSelect bool
9293
var password string
9394
var passwordHash string
9495
var proxyBearerToken string
@@ -170,40 +171,41 @@ func main() {
170171
}
171172
}
172173

173-
if err := mcpproxy.Run(
174-
listen,
175-
tlsListen,
176-
!noAutoTLS,
177-
tlsHost,
178-
tlsDirectoryURL,
179-
tlsAcceptTOS,
180-
dataPath,
181-
externalURL,
182-
googleClientID,
183-
googleClientSecret,
184-
googleAllowedUsersList,
185-
googleAllowedWorkspacesList,
186-
githubClientID,
187-
githubClientSecret,
188-
githubAllowedUsersList,
189-
githubAllowedOrgsList,
190-
oidcConfigurationURL,
191-
oidcClientID,
192-
oidcClientSecret,
193-
oidcScopesList,
194-
oidcUserIDField,
195-
oidcProviderName,
196-
oidcAllowedUsersList,
197-
oidcAllowedUsersGlobList,
198-
password,
199-
passwordHash,
200-
trustedProxiesList,
201-
proxyHeadersList,
202-
proxyBearerToken,
203-
args,
204-
); err != nil {
205-
panic(err)
206-
}
174+
if err := mcpproxy.Run(
175+
listen,
176+
tlsListen,
177+
!noAutoTLS,
178+
tlsHost,
179+
tlsDirectoryURL,
180+
tlsAcceptTOS,
181+
dataPath,
182+
externalURL,
183+
googleClientID,
184+
googleClientSecret,
185+
googleAllowedUsersList,
186+
googleAllowedWorkspacesList,
187+
githubClientID,
188+
githubClientSecret,
189+
githubAllowedUsersList,
190+
githubAllowedOrgsList,
191+
oidcConfigurationURL,
192+
oidcClientID,
193+
oidcClientSecret,
194+
oidcScopesList,
195+
oidcUserIDField,
196+
oidcProviderName,
197+
oidcAllowedUsersList,
198+
oidcAllowedUsersGlobList,
199+
noProviderAutoSelect,
200+
password,
201+
passwordHash,
202+
trustedProxiesList,
203+
proxyHeadersList,
204+
proxyBearerToken,
205+
args,
206+
); err != nil {
207+
panic(err)
208+
}
207209
},
208210
}
209211

@@ -239,6 +241,7 @@ func main() {
239241
rootCmd.Flags().StringVar(&oidcAllowedUsersGlob, "oidc-allowed-users-glob", getEnvWithDefault("OIDC_ALLOWED_USERS_GLOB", ""), "Comma-separated list of glob patterns for allowed OIDC users")
240242

241243
// Password authentication
244+
rootCmd.Flags().BoolVar(&noProviderAutoSelect, "no-provider-auto-select", getEnvBoolWithDefault("NO_PROVIDER_AUTO_SELECT", false), "Disable auto-redirect when only one OAuth/OIDC provider is configured and no password is set")
242245
rootCmd.Flags().StringVar(&password, "password", getEnvWithDefault("PASSWORD", ""), "Plain text password for authentication (will be hashed with bcrypt)")
243246
rootCmd.Flags().StringVar(&passwordHash, "password-hash", getEnvWithDefault("PASSWORD_HASH", ""), "Bcrypt hash of password for authentication")
244247

pkg/auth/auth.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ type AuthRouter struct {
2121
loginTemplate *template.Template
2222
unauthorizedTemplate *template.Template
2323
errorTemplate *template.Template
24+
// When true, do not auto-redirect to the sole provider even if
25+
// there is only one provider and no password is set.
26+
noProviderAutoSelect bool
2427
}
2528

26-
func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, error) {
29+
func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, providers ...Provider) (*AuthRouter, error) {
2730
tmpl, err := template.ParseFS(templateFS, "templates/login.html")
2831
if err != nil {
2932
return nil, err
@@ -45,6 +48,7 @@ func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, e
4548
loginTemplate: tmpl,
4649
unauthorizedTemplate: unauthorizedTmpl,
4750
errorTemplate: errorTmpl,
51+
noProviderAutoSelect: noProviderAutoSelect,
4852
}, nil
4953
}
5054

@@ -137,6 +141,11 @@ func (a *AuthRouter) handleLogin(c *gin.Context) {
137141
a.handleLoginPost(c)
138142
return
139143
}
144+
// Auto-redirect to the sole provider if enabled and no password is set
145+
if !a.noProviderAutoSelect && len(a.passwordHash) == 0 && len(a.providers) == 1 {
146+
c.Redirect(http.StatusFound, a.providers[0].AuthURL())
147+
return
148+
}
140149
a.renderLogin(c, "")
141150
}
142151

pkg/auth/auth_test.go

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ func TestAuthenticationFlow(t *testing.T) {
5353
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
5454
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
5555

56-
// Create AuthRouter
57-
authRouter, err := NewAuthRouter(nil, mockProvider)
56+
// Create AuthRouter (auto-select enabled by default)
57+
authRouter, err := NewAuthRouter(nil, false, mockProvider)
5858
require.NoError(t, err)
5959

6060
router := setupTestRouter(authRouter)
@@ -88,7 +88,7 @@ func TestAuthenticationFlow(t *testing.T) {
8888
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", nil)
8989

9090
// Create AuthRouter
91-
authRouter, err := NewAuthRouter(nil, mockProvider)
91+
authRouter, err := NewAuthRouter(nil, false, mockProvider)
9292
require.NoError(t, err)
9393

9494
router := setupTestRouter(authRouter)
@@ -149,7 +149,7 @@ func TestAuthenticationFlow(t *testing.T) {
149149
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", nil)
150150

151151
// Create AuthRouter
152-
authRouter, err := NewAuthRouter(nil, mockProvider)
152+
authRouter, err := NewAuthRouter(nil, false, mockProvider)
153153
require.NoError(t, err)
154154

155155
router := setupTestRouter(authRouter)
@@ -187,3 +187,95 @@ func TestAuthenticationFlow(t *testing.T) {
187187
require.Equal(t, "/.auth/login", location)
188188
})
189189
}
190+
191+
func TestLoginAutoRedirect(t *testing.T) {
192+
t.Run("Auto-redirects when single provider and no password", func(t *testing.T) {
193+
ctrl := gomock.NewController(t)
194+
defer ctrl.Finish()
195+
196+
mockProvider := NewMockProvider(ctrl)
197+
mockProvider.EXPECT().Name().Return("test").AnyTimes()
198+
mockProvider.EXPECT().Type().Return("test").AnyTimes()
199+
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
200+
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
201+
202+
authRouter, err := NewAuthRouter(nil, false, mockProvider)
203+
require.NoError(t, err)
204+
205+
router := gin.New()
206+
store := memstore.NewStore([]byte("test-secret"))
207+
router.Use(sessions.Sessions("session", store))
208+
authRouter.SetupRoutes(router)
209+
210+
server := httptest.NewServer(router)
211+
defer server.Close()
212+
213+
client := setupClient()
214+
resp, err := client.Get(server.URL + LoginEndpoint)
215+
require.NoError(t, err)
216+
defer resp.Body.Close()
217+
218+
require.Equal(t, http.StatusFound, resp.StatusCode)
219+
location := resp.Header.Get("Location")
220+
require.Equal(t, "/.auth/test", location)
221+
})
222+
223+
t.Run("Does not redirect when disabled", func(t *testing.T) {
224+
ctrl := gomock.NewController(t)
225+
defer ctrl.Finish()
226+
227+
mockProvider := NewMockProvider(ctrl)
228+
mockProvider.EXPECT().Name().Return("test").AnyTimes()
229+
mockProvider.EXPECT().Type().Return("test").AnyTimes()
230+
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
231+
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
232+
233+
authRouter, err := NewAuthRouter(nil, true, mockProvider)
234+
require.NoError(t, err)
235+
236+
router := gin.New()
237+
store := memstore.NewStore([]byte("test-secret"))
238+
router.Use(sessions.Sessions("session", store))
239+
authRouter.SetupRoutes(router)
240+
241+
server := httptest.NewServer(router)
242+
defer server.Close()
243+
244+
client := setupClient()
245+
resp, err := client.Get(server.URL + LoginEndpoint)
246+
require.NoError(t, err)
247+
defer resp.Body.Close()
248+
249+
require.Equal(t, http.StatusOK, resp.StatusCode)
250+
})
251+
252+
t.Run("Does not redirect when password configured", func(t *testing.T) {
253+
ctrl := gomock.NewController(t)
254+
defer ctrl.Finish()
255+
256+
mockProvider := NewMockProvider(ctrl)
257+
mockProvider.EXPECT().Name().Return("test").AnyTimes()
258+
mockProvider.EXPECT().Type().Return("test").AnyTimes()
259+
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
260+
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
261+
262+
// Non-empty passwordHash slice disables auto-select
263+
authRouter, err := NewAuthRouter([]string{"dummy"}, false, mockProvider)
264+
require.NoError(t, err)
265+
266+
router := gin.New()
267+
store := memstore.NewStore([]byte("test-secret"))
268+
router.Use(sessions.Sessions("session", store))
269+
authRouter.SetupRoutes(router)
270+
271+
server := httptest.NewServer(router)
272+
defer server.Close()
273+
274+
client := setupClient()
275+
resp, err := client.Get(server.URL + LoginEndpoint)
276+
require.NoError(t, err)
277+
defer resp.Body.Close()
278+
279+
require.Equal(t, http.StatusOK, resp.StatusCode)
280+
})
281+
}

pkg/idp/idp_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str
6464
})
6565

6666
// Create auth router and IDP router
67-
authRouter, err := auth.NewAuthRouter([]string{})
67+
authRouter, err := auth.NewAuthRouter([]string{}, false)
6868
require.NoError(t, err)
6969

7070
logger, _ := zap.NewDevelopment()

pkg/mcp-proxy/main.go

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,35 +33,36 @@ import (
3333
var ServerShutdownTimeout = 5 * time.Second
3434

3535
func Run(
36-
listen string,
37-
tlsListen string,
38-
autoTLS bool,
39-
tlsHost string,
40-
tlsDirectoryURL string,
41-
tlsAcceptTOS bool,
42-
dataPath string,
43-
externalURL string,
44-
googleClientID string,
45-
googleClientSecret string,
46-
googleAllowedUsers []string,
47-
googleAllowedWorkspaces []string,
48-
githubClientID string,
49-
githubClientSecret string,
50-
githubAllowedUsers []string,
51-
githubAllowedOrgs []string,
52-
oidcConfigurationURL string,
53-
oidcClientID string,
54-
oidcClientSecret string,
55-
oidcScopes []string,
56-
oidcUserIDField string,
57-
oidcProviderName string,
58-
oidcAllowedUsers []string,
59-
oidcAllowedUsersGlob []string,
60-
password string,
61-
passwordHash string,
62-
trustedProxy []string,
63-
proxyHeaders []string,
64-
proxyBearerToken string,
36+
listen string,
37+
tlsListen string,
38+
autoTLS bool,
39+
tlsHost string,
40+
tlsDirectoryURL string,
41+
tlsAcceptTOS bool,
42+
dataPath string,
43+
externalURL string,
44+
googleClientID string,
45+
googleClientSecret string,
46+
googleAllowedUsers []string,
47+
googleAllowedWorkspaces []string,
48+
githubClientID string,
49+
githubClientSecret string,
50+
githubAllowedUsers []string,
51+
githubAllowedOrgs []string,
52+
oidcConfigurationURL string,
53+
oidcClientID string,
54+
oidcClientSecret string,
55+
oidcScopes []string,
56+
oidcUserIDField string,
57+
oidcProviderName string,
58+
oidcAllowedUsers []string,
59+
oidcAllowedUsersGlob []string,
60+
noProviderAutoSelect bool,
61+
password string,
62+
passwordHash string,
63+
trustedProxy []string,
64+
proxyHeaders []string,
65+
proxyBearerToken string,
6566
proxyTarget []string,
6667
) error {
6768
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
@@ -201,7 +202,7 @@ func Run(
201202
passwordHashes = append(passwordHashes, passwordHash)
202203
}
203204

204-
authRouter, err := auth.NewAuthRouter(passwordHashes, providers...)
205+
authRouter, err := auth.NewAuthRouter(passwordHashes, noProviderAutoSelect, providers...)
205206
if err != nil {
206207
return fmt.Errorf("failed to create auth router: %w", err)
207208
}

0 commit comments

Comments
 (0)