Skip to content

Commit 940e91e

Browse files
authored
fix: generate server-side OAuth state when client omits it (#126)
1 parent 22b2591 commit 940e91e

2 files changed

Lines changed: 158 additions & 0 deletions

File tree

pkg/idp/idp.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,22 @@ func (a *IDPRouter) SetupRoutes(router gin.IRouter) {
114114
func (a *IDPRouter) handleAuth(c *gin.Context) {
115115
ctx := c.Request.Context()
116116

117+
// RFC 6749 makes state RECOMMENDED, not REQUIRED, but fosite enforces
118+
// minimum entropy (8 chars). Generate a server-side state for clients
119+
// that omit it (e.g., MCP Inspector, Cursor CLI) so they can complete
120+
// the OAuth flow. The generated state is echoed back in the redirect;
121+
// clients that didn't send state will simply ignore it.
122+
if c.Request.URL.Query().Get("state") == "" {
123+
state, err := utils.GenerateState()
124+
if err != nil {
125+
a.provider.WriteAuthorizeError(ctx, c.Writer, nil, fosite.ErrServerError.WithWrap(err))
126+
return
127+
}
128+
q := c.Request.URL.Query()
129+
q.Set("state", state)
130+
c.Request.URL.RawQuery = q.Encode()
131+
}
132+
117133
ar, err := a.provider.NewAuthorizeRequest(ctx, c.Request)
118134
if err != nil {
119135
a.provider.WriteAuthorizeError(ctx, c.Writer, ar, err)

pkg/idp/idp_test.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"crypto/rsa"
77
"crypto/sha256"
88
"encoding/json"
9+
"fmt"
910
"net/http"
1011
"net/http/cookiejar"
1112
"net/http/httptest"
@@ -271,3 +272,144 @@ func TestPrivateClient(t *testing.T) {
271272
require.NotEmpty(t, newAccessToken)
272273
require.NotEqual(t, originalAccessToken, newAccessToken, "Access token should be different after refresh")
273274
}
275+
276+
// registerTestClient is a helper that registers a private OAuth client and returns the registration response.
277+
func registerTestClient(t *testing.T, serverURL string) registrationResponse {
278+
t.Helper()
279+
280+
regReq := registrationRequest{
281+
ClientName: "Test OAuth Client",
282+
GrantTypes: []string{"authorization_code", "refresh_token"},
283+
ResponseTypes: []string{"code"},
284+
TokenEndpointAuthMethod: "client_secret_basic",
285+
Scope: "test",
286+
RedirectURIs: []string{"http://localhost:8080/callback"},
287+
}
288+
289+
reqBody, err := json.Marshal(regReq)
290+
require.NoError(t, err)
291+
292+
resp, err := http.Post(serverURL+RegistrationEndpoint, "application/json", bytes.NewReader(reqBody))
293+
require.NoError(t, err)
294+
defer resp.Body.Close()
295+
296+
require.Equal(t, http.StatusCreated, resp.StatusCode)
297+
298+
var regResp registrationResponse
299+
err = json.NewDecoder(resp.Body).Decode(&regResp)
300+
require.NoError(t, err)
301+
302+
return regResp
303+
}
304+
305+
// testAuthFlowWithURL performs the OAuth authorization flow given a raw authorization URL
306+
// and returns the callback URL after authorization completes.
307+
func testAuthFlowWithURL(t *testing.T, serverURL, authURL string) *url.URL {
308+
t.Helper()
309+
310+
jar, err := cookiejar.New(nil)
311+
require.NoError(t, err)
312+
client := &http.Client{
313+
Jar: jar,
314+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
315+
return http.ErrUseLastResponse
316+
},
317+
}
318+
319+
// Step 1: Make initial authorization request
320+
authResp, err := client.Get(authURL)
321+
require.NoError(t, err)
322+
defer authResp.Body.Close()
323+
324+
require.Contains(t, []int{http.StatusFound, http.StatusSeeOther}, authResp.StatusCode,
325+
"expected redirect, got %d", authResp.StatusCode)
326+
location := authResp.Header.Get("Location")
327+
require.NotEmpty(t, location)
328+
require.Contains(t, location, strings.ReplaceAll(AuthorizationReturnEndpoint, ":ar_id", ""))
329+
330+
// Step 2: Follow the redirect to complete authorization
331+
authReturnResp, err := client.Get(serverURL + location)
332+
require.NoError(t, err)
333+
defer authReturnResp.Body.Close()
334+
335+
require.Contains(t, []int{http.StatusFound, http.StatusSeeOther}, authReturnResp.StatusCode,
336+
"expected redirect with authorization code, got %d", authReturnResp.StatusCode)
337+
callbackLocation := authReturnResp.Header.Get("Location")
338+
require.NotEmpty(t, callbackLocation)
339+
340+
callbackURL, err := url.Parse(callbackLocation)
341+
require.NoError(t, err)
342+
require.NotEmpty(t, callbackURL.Query().Get("code"), "callback URL should contain an authorization code")
343+
344+
return callbackURL
345+
}
346+
347+
func TestAuthWithoutState(t *testing.T) {
348+
server, _, _ := setupTestServer(t)
349+
regResp := registerTestClient(t, server.URL)
350+
351+
// Build authorization URL manually WITHOUT a state parameter
352+
authURL := fmt.Sprintf("%s%s?response_type=code&client_id=%s&redirect_uri=%s",
353+
server.URL, AuthorizationEndpoint, regResp.ClientID,
354+
url.QueryEscape("http://localhost:8080/callback"))
355+
356+
callbackURL := testAuthFlowWithURL(t, server.URL, authURL)
357+
358+
// Server should have generated a state and echoed it back
359+
require.NotEmpty(t, callbackURL.Query().Get("state"), "server should generate a state when client omits it")
360+
361+
// Exchange authorization code for tokens
362+
code := callbackURL.Query().Get("code")
363+
tokenReq := url.Values{}
364+
tokenReq.Set("grant_type", "authorization_code")
365+
tokenReq.Set("code", code)
366+
tokenReq.Set("redirect_uri", "http://localhost:8080/callback")
367+
tokenReq.Set("client_id", regResp.ClientID)
368+
tokenReq.Set("client_secret", regResp.ClientSecret)
369+
370+
tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq)
371+
require.NoError(t, err)
372+
defer tokenResp.Body.Close()
373+
374+
require.Equal(t, http.StatusOK, tokenResp.StatusCode)
375+
376+
var tokenResult map[string]any
377+
err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult)
378+
require.NoError(t, err)
379+
require.NotEmpty(t, tokenResult["access_token"])
380+
}
381+
382+
func TestAuthWithEmptyState(t *testing.T) {
383+
server, _, _ := setupTestServer(t)
384+
regResp := registerTestClient(t, server.URL)
385+
386+
// Build authorization URL with an empty state parameter
387+
authURL := fmt.Sprintf("%s%s?response_type=code&client_id=%s&redirect_uri=%s&state=",
388+
server.URL, AuthorizationEndpoint, regResp.ClientID,
389+
url.QueryEscape("http://localhost:8080/callback"))
390+
391+
callbackURL := testAuthFlowWithURL(t, server.URL, authURL)
392+
393+
// Server should have generated a state and echoed it back
394+
require.NotEmpty(t, callbackURL.Query().Get("state"), "server should generate a state when client sends empty state")
395+
396+
// Exchange authorization code for tokens
397+
code := callbackURL.Query().Get("code")
398+
tokenReq := url.Values{}
399+
tokenReq.Set("grant_type", "authorization_code")
400+
tokenReq.Set("code", code)
401+
tokenReq.Set("redirect_uri", "http://localhost:8080/callback")
402+
tokenReq.Set("client_id", regResp.ClientID)
403+
tokenReq.Set("client_secret", regResp.ClientSecret)
404+
405+
tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq)
406+
require.NoError(t, err)
407+
defer tokenResp.Body.Close()
408+
409+
require.Equal(t, http.StatusOK, tokenResp.StatusCode)
410+
411+
var tokenResult map[string]any
412+
err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult)
413+
require.NoError(t, err)
414+
require.NotEmpty(t, tokenResult["access_token"])
415+
}

0 commit comments

Comments
 (0)