diff --git a/pkg/idp/idp_test.go b/pkg/idp/idp_test.go index 5bf8194..8c7a5d6 100644 --- a/pkg/idp/idp_test.go +++ b/pkg/idp/idp_test.go @@ -52,10 +52,12 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str store := cookie.NewStore(secret[:]) router.Use(sessions.Sessions("test_session", store)) - // Mock auth middleware that always passes + // Mock auth middleware that always passes with user identity router.Use(func(c *gin.Context) { session := sessions.Default(c) session.Set(auth.SessionKeyAuthorized, true) + session.Set(auth.SessionKeyUserID, "test-user@example.com") + session.Set(auth.SessionKeyUserInfo, `{"email":"test-user@example.com","name":"Test User"}`) err := session.Save() if err != nil { c.JSON(500, gin.H{"error": "Failed to save session"}) @@ -466,3 +468,62 @@ func TestAccessTokenAudienceClaim(t *testing.T) { require.True(t, ok, "aud claim should be present as an array") require.Contains(t, aud, "http://localhost:8080", "aud should contain the external URL") } + +func TestAccessTokenPreservesUserIdentity(t *testing.T) { + server, _, _ := setupTestServer(t) + regResp := registerTestClient(t, server.URL) + + config := &oauth2.Config{ + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{}, + Endpoint: oauth2.Endpoint{ + AuthURL: server.URL + AuthorizationEndpoint, + TokenURL: server.URL + TokenEndpoint, + }, + } + + callbackURL := testAuthFlowWithURL(t, server.URL, config.AuthCodeURL("test-state")) + code := callbackURL.Query().Get("code") + + tokenReq := url.Values{} + tokenReq.Set("grant_type", "authorization_code") + tokenReq.Set("code", code) + tokenReq.Set("redirect_uri", "http://localhost:8080/callback") + tokenReq.Set("client_id", regResp.ClientID) + tokenReq.Set("client_secret", regResp.ClientSecret) + + tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq) + require.NoError(t, err) + defer tokenResp.Body.Close() + require.Equal(t, http.StatusOK, tokenResp.StatusCode) + + var tokenResult map[string]any + err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult) + require.NoError(t, err) + + accessToken := tokenResult["access_token"].(string) + + // Decode JWT payload + parts := strings.Split(accessToken, ".") + require.Len(t, parts, 3) + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + + var claims map[string]any + err = json.Unmarshal(payload, &claims) + require.NoError(t, err) + + // Verify sub claim is preserved + sub, ok := claims["sub"].(string) + require.True(t, ok, "sub claim should be present") + require.Equal(t, "test-user@example.com", sub) + + // Verify userinfo claim is preserved + userinfo, ok := claims["userinfo"].(map[string]any) + require.True(t, ok, "userinfo claim should be present") + require.Equal(t, "test-user@example.com", userinfo["email"]) + require.Equal(t, "Test User", userinfo["name"]) +} diff --git a/pkg/models/models.go b/pkg/models/models.go index 2fef36b..7c4507b 100644 --- a/pkg/models/models.go +++ b/pkg/models/models.go @@ -1,6 +1,7 @@ package models import ( + "encoding/json" "net/url" "time" @@ -17,6 +18,7 @@ type Request struct { RequestedAudience []string GrantedAudience []string RotatedAt time.Time + SessionData json.RawMessage `json:",omitempty"` } type Client struct { @@ -43,7 +45,7 @@ type AuthorizeRequest struct { func FromFositeReq(reqester fosite.Requester) *Request { req := reqester.(*fosite.Request) - return &Request{ + r := &Request{ ID: req.ID, RequestedAt: req.RequestedAt, Client: FromFositeClient(req.Client), @@ -53,6 +55,12 @@ func FromFositeReq(reqester fosite.Requester) *Request { RequestedAudience: req.RequestedAudience, GrantedAudience: req.GrantedAudience, } + if sess := req.GetSession(); sess != nil { + if data, err := json.Marshal(sess); err == nil { + r.SessionData = data + } + } + return r } func (r *Request) ToFositeReq() *fosite.Request { diff --git a/pkg/repository/interface.go b/pkg/repository/interface.go index 3bc71bd..ea47cc1 100644 --- a/pkg/repository/interface.go +++ b/pkg/repository/interface.go @@ -2,6 +2,8 @@ package repository import ( "context" + "encoding/json" + "fmt" "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" @@ -27,3 +29,13 @@ type AuthorizeRequestStorage interface { GetAuthorizeRequest(ctx context.Context, requestID string) (fosite.AuthorizeRequester, error) DeleteAuthorizeRequest(ctx context.Context, requestID string) error } + +func restoreSession(req *fosite.Request, sessionData json.RawMessage, sess fosite.Session) error { + if len(sessionData) > 0 && sess != nil { + if err := json.Unmarshal(sessionData, sess); err != nil { + return fmt.Errorf("failed to unmarshal session data: %w", err) + } + req.SetSession(sess) + } + return nil +} diff --git a/pkg/repository/kvs.go b/pkg/repository/kvs.go index 61ec413..90dae76 100644 --- a/pkg/repository/kvs.go +++ b/pkg/repository/kvs.go @@ -89,7 +89,9 @@ func (r *kvsRepository) GetAuthorizeCodeSession(ctx context.Context, code string return nil, err } fositeReq := req.ToFositeReq() - fositeReq.SetSession(sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err + } return fositeReq, nil } @@ -107,7 +109,9 @@ func (r *kvsRepository) GetAccessTokenSession(ctx context.Context, signature str return nil, err } fositeReq := req.ToFositeReq() - fositeReq.SetSession(sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err + } return fositeReq, nil } @@ -125,7 +129,9 @@ func (r *kvsRepository) GetRefreshTokenSession(ctx context.Context, signature st return nil, err } fositeReq := req.ToFositeReq() - fositeReq.SetSession(sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err + } return fositeReq, nil } @@ -200,7 +206,9 @@ func (r *kvsRepository) GetPKCERequestSession(ctx context.Context, signature str return nil, err } fositeReq := req.ToFositeReq() - fositeReq.SetSession(sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err + } return fositeReq, nil } diff --git a/pkg/repository/sql.go b/pkg/repository/sql.go index 830093f..3ce041b 100644 --- a/pkg/repository/sql.go +++ b/pkg/repository/sql.go @@ -356,8 +356,8 @@ func unmarshalRequest(data []byte, sess fosite.Session) (fosite.Requester, error return nil, fmt.Errorf("failed to unmarshal request: %w", err) } fositeReq := req.ToFositeReq() - if sess != nil { - fositeReq.SetSession(sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err } return fositeReq, nil } diff --git a/pkg/repository/sql_test.go b/pkg/repository/sql_test.go index 1c1fd41..59ded95 100644 --- a/pkg/repository/sql_test.go +++ b/pkg/repository/sql_test.go @@ -52,6 +52,88 @@ func TestSQLRepositoryAccessTokenSession(t *testing.T) { } } +func TestSQLRepositorySessionPersistence(t *testing.T) { + repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000") + if err != nil { + t.Fatalf("failed to create sql repository: %v", err) + } + defer repo.Close() + + sess := &fosite.DefaultSession{ + Username: "test-user", + Subject: "test-user", + } + + ctx := context.Background() + client := &fosite.DefaultClient{ + ID: "client-sess", + Secret: []byte("secret"), + RedirectURIs: []string{"https://example.com/callback"}, + } + + req := &fosite.Request{ + ID: "req-sess", + RequestedAt: time.Now().UTC().Round(time.Second), + Client: client, + RequestedScope: []string{"openid"}, + Form: url.Values{"code": {"value"}}, + Session: sess, + } + + if err := repo.CreateAuthorizeCodeSession(ctx, "code-sess", req); err != nil { + t.Fatalf("CreateAuthorizeCodeSession failed: %v", err) + } + + result, err := repo.GetAuthorizeCodeSession(ctx, "code-sess", &fosite.DefaultSession{}) + if err != nil { + t.Fatalf("GetAuthorizeCodeSession failed: %v", err) + } + + restored := result.GetSession().(*fosite.DefaultSession) + if restored.GetSubject() != "test-user" { + t.Fatalf("expected subject 'test-user', got '%s'", restored.GetSubject()) + } + if restored.GetUsername() != "test-user" { + t.Fatalf("expected username 'test-user', got '%s'", restored.GetUsername()) + } +} + +func TestSQLRepositorySessionPersistence_NilSessionStored(t *testing.T) { + repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000") + if err != nil { + t.Fatalf("failed to create sql repository: %v", err) + } + defer repo.Close() + + ctx := context.Background() + client := &fosite.DefaultClient{ + ID: "client-old", + Secret: []byte("secret"), + RedirectURIs: []string{"https://example.com/callback"}, + } + + req := &fosite.Request{ + ID: "req-old", + RequestedAt: time.Now().UTC().Round(time.Second), + Client: client, + RequestedScope: []string{"openid"}, + Form: url.Values{"code": {"value"}}, + } + + if err := repo.CreateAuthorizeCodeSession(ctx, "code-old", req); err != nil { + t.Fatalf("CreateAuthorizeCodeSession failed: %v", err) + } + + result, err := repo.GetAuthorizeCodeSession(ctx, "code-old", &fosite.DefaultSession{}) + if err != nil { + t.Fatalf("GetAuthorizeCodeSession failed: %v", err) + } + + if result.GetSession() != nil { + t.Fatalf("expected nil session when no session was stored, got %v", result.GetSession()) + } +} + func TestSQLRepositoryUnsupportedDriver(t *testing.T) { if _, err := NewSQLRepository("unsupported", "dsn"); err == nil { t.Fatalf("expected error for unsupported driver but got nil")