diff --git a/pkg/idp/idp.go b/pkg/idp/idp.go index 11ec298..9b8264b 100644 --- a/pkg/idp/idp.go +++ b/pkg/idp/idp.go @@ -158,6 +158,7 @@ func (a *IDPRouter) handleAuthorizationReturn(c *gin.Context) { for _, scope := range ar.GetRequestedScopes() { ar.GrantScope(scope) } + ar.GrantAudience(a.externalURL) jwtSession, err := NewJWTSessionWithKey(a.externalURL, "user", a.privKey) if err != nil { a.logger.With(utils.Err(err)...).Error("Failed to create JWT session", zap.Error(err)) @@ -284,6 +285,7 @@ func (a *IDPRouter) handleRegister(c *gin.Context) { GrantTypes: req.GrantTypes, ResponseTypes: req.ResponseTypes, Scopes: strings.Fields(req.Scope), + Audience: []string{a.externalURL}, Public: isPublic, } if err := a.repo.RegisterClient(ctx, client); err != nil { diff --git a/pkg/idp/idp_test.go b/pkg/idp/idp_test.go index ac7c700..f53b287 100644 --- a/pkg/idp/idp_test.go +++ b/pkg/idp/idp_test.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "net/http" @@ -413,3 +414,55 @@ func TestAuthWithEmptyState(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, tokenResult["access_token"]) } + +func TestAccessTokenAudienceClaim(t *testing.T) { + server, _, _ := setupTestServer(t) + regResp := registerTestClient(t, server.URL) + + config := &oauth2.Config{ + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{}, + Endpoint: oauth2.Endpoint{ + AuthURL: server.URL + AuthorizationEndpoint, + TokenURL: server.URL + TokenEndpoint, + }, + } + + callbackURL := testAuthFlowWithURL(t, server.URL, config.AuthCodeURL("test-state")) + code := callbackURL.Query().Get("code") + + tokenReq := url.Values{} + tokenReq.Set("grant_type", "authorization_code") + tokenReq.Set("code", code) + tokenReq.Set("redirect_uri", "http://localhost:8080/callback") + tokenReq.Set("client_id", regResp.ClientID) + tokenReq.Set("client_secret", regResp.ClientSecret) + + tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq) + require.NoError(t, err) + defer tokenResp.Body.Close() + require.Equal(t, http.StatusOK, tokenResp.StatusCode) + + var tokenResult map[string]any + err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult) + require.NoError(t, err) + + accessToken := tokenResult["access_token"].(string) + + // Decode JWT payload and verify aud claim contains the external URL + parts := strings.Split(accessToken, ".") + require.Len(t, parts, 3, "access token should be a JWT with 3 parts") + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + + var claims map[string]any + err = json.Unmarshal(payload, &claims) + require.NoError(t, err) + + aud, ok := claims["aud"].([]any) + require.True(t, ok, "aud claim should be present as an array") + require.Contains(t, aud, "http://localhost:8080", "aud should contain the external URL") +}