From 1e470ba8f9523eed7009d9177b931aae6ea98916 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Fri, 17 Apr 2026 00:44:58 +0900 Subject: [PATCH 1/2] fix: persist JWT session in storage so user identity survives token exchange MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The session (sub, userinfo) was lost during authorization code → access token exchange because models.Request did not serialize the session. The storage layer then overwrote the stored session with an empty one passed from handleToken. Add SessionData field to models.Request to serialize/deserialize the session. On restore, only populate the session when stored data exists (no fallback to empty session for old data). --- pkg/idp/idp_test.go | 63 +++++++++++++++++++++++++++- pkg/models/models.go | 10 ++++- pkg/repository/interface.go | 9 ++++ pkg/repository/kvs.go | 8 ++-- pkg/repository/sql.go | 4 +- pkg/repository/sql_test.go | 84 +++++++++++++++++++++++++++++++++++++ 6 files changed, 169 insertions(+), 9 deletions(-) diff --git a/pkg/idp/idp_test.go b/pkg/idp/idp_test.go index 5bf8194..8c7a5d6 100644 --- a/pkg/idp/idp_test.go +++ b/pkg/idp/idp_test.go @@ -52,10 +52,12 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str store := cookie.NewStore(secret[:]) router.Use(sessions.Sessions("test_session", store)) - // Mock auth middleware that always passes + // Mock auth middleware that always passes with user identity router.Use(func(c *gin.Context) { session := sessions.Default(c) session.Set(auth.SessionKeyAuthorized, true) + session.Set(auth.SessionKeyUserID, "test-user@example.com") + session.Set(auth.SessionKeyUserInfo, `{"email":"test-user@example.com","name":"Test User"}`) err := session.Save() if err != nil { c.JSON(500, gin.H{"error": "Failed to save session"}) @@ -466,3 +468,62 @@ func TestAccessTokenAudienceClaim(t *testing.T) { require.True(t, ok, "aud claim should be present as an array") require.Contains(t, aud, "http://localhost:8080", "aud should contain the external URL") } + +func TestAccessTokenPreservesUserIdentity(t *testing.T) { + server, _, _ := setupTestServer(t) + regResp := registerTestClient(t, server.URL) + + config := &oauth2.Config{ + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{}, + Endpoint: oauth2.Endpoint{ + AuthURL: server.URL + AuthorizationEndpoint, + TokenURL: server.URL + TokenEndpoint, + }, + } + + callbackURL := testAuthFlowWithURL(t, server.URL, config.AuthCodeURL("test-state")) + code := callbackURL.Query().Get("code") + + tokenReq := url.Values{} + tokenReq.Set("grant_type", "authorization_code") + tokenReq.Set("code", code) + tokenReq.Set("redirect_uri", "http://localhost:8080/callback") + tokenReq.Set("client_id", regResp.ClientID) + tokenReq.Set("client_secret", regResp.ClientSecret) + + tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq) + require.NoError(t, err) + defer tokenResp.Body.Close() + require.Equal(t, http.StatusOK, tokenResp.StatusCode) + + var tokenResult map[string]any + err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult) + require.NoError(t, err) + + accessToken := tokenResult["access_token"].(string) + + // Decode JWT payload + parts := strings.Split(accessToken, ".") + require.Len(t, parts, 3) + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + + var claims map[string]any + err = json.Unmarshal(payload, &claims) + require.NoError(t, err) + + // Verify sub claim is preserved + sub, ok := claims["sub"].(string) + require.True(t, ok, "sub claim should be present") + require.Equal(t, "test-user@example.com", sub) + + // Verify userinfo claim is preserved + userinfo, ok := claims["userinfo"].(map[string]any) + require.True(t, ok, "userinfo claim should be present") + require.Equal(t, "test-user@example.com", userinfo["email"]) + require.Equal(t, "Test User", userinfo["name"]) +} diff --git a/pkg/models/models.go b/pkg/models/models.go index 2fef36b..7c4507b 100644 --- a/pkg/models/models.go +++ b/pkg/models/models.go @@ -1,6 +1,7 @@ package models import ( + "encoding/json" "net/url" "time" @@ -17,6 +18,7 @@ type Request struct { RequestedAudience []string GrantedAudience []string RotatedAt time.Time + SessionData json.RawMessage `json:",omitempty"` } type Client struct { @@ -43,7 +45,7 @@ type AuthorizeRequest struct { func FromFositeReq(reqester fosite.Requester) *Request { req := reqester.(*fosite.Request) - return &Request{ + r := &Request{ ID: req.ID, RequestedAt: req.RequestedAt, Client: FromFositeClient(req.Client), @@ -53,6 +55,12 @@ func FromFositeReq(reqester fosite.Requester) *Request { RequestedAudience: req.RequestedAudience, GrantedAudience: req.GrantedAudience, } + if sess := req.GetSession(); sess != nil { + if data, err := json.Marshal(sess); err == nil { + r.SessionData = data + } + } + return r } func (r *Request) ToFositeReq() *fosite.Request { diff --git a/pkg/repository/interface.go b/pkg/repository/interface.go index 3bc71bd..e576bf1 100644 --- a/pkg/repository/interface.go +++ b/pkg/repository/interface.go @@ -2,6 +2,7 @@ package repository import ( "context" + "encoding/json" "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" @@ -27,3 +28,11 @@ type AuthorizeRequestStorage interface { GetAuthorizeRequest(ctx context.Context, requestID string) (fosite.AuthorizeRequester, error) DeleteAuthorizeRequest(ctx context.Context, requestID string) error } + +func restoreSession(req *fosite.Request, sessionData json.RawMessage, sess fosite.Session) { + if len(sessionData) > 0 && sess != nil { + if json.Unmarshal(sessionData, sess) == nil { + req.SetSession(sess) + } + } +} diff --git a/pkg/repository/kvs.go b/pkg/repository/kvs.go index 61ec413..409f617 100644 --- a/pkg/repository/kvs.go +++ b/pkg/repository/kvs.go @@ -89,7 +89,7 @@ func (r *kvsRepository) GetAuthorizeCodeSession(ctx context.Context, code string return nil, err } fositeReq := req.ToFositeReq() - fositeReq.SetSession(sess) + restoreSession(fositeReq, req.SessionData, sess) return fositeReq, nil } @@ -107,7 +107,7 @@ func (r *kvsRepository) GetAccessTokenSession(ctx context.Context, signature str return nil, err } fositeReq := req.ToFositeReq() - fositeReq.SetSession(sess) + restoreSession(fositeReq, req.SessionData, sess) return fositeReq, nil } @@ -125,7 +125,7 @@ func (r *kvsRepository) GetRefreshTokenSession(ctx context.Context, signature st return nil, err } fositeReq := req.ToFositeReq() - fositeReq.SetSession(sess) + restoreSession(fositeReq, req.SessionData, sess) return fositeReq, nil } @@ -200,7 +200,7 @@ func (r *kvsRepository) GetPKCERequestSession(ctx context.Context, signature str return nil, err } fositeReq := req.ToFositeReq() - fositeReq.SetSession(sess) + restoreSession(fositeReq, req.SessionData, sess) return fositeReq, nil } diff --git a/pkg/repository/sql.go b/pkg/repository/sql.go index 830093f..eb84938 100644 --- a/pkg/repository/sql.go +++ b/pkg/repository/sql.go @@ -356,9 +356,7 @@ func unmarshalRequest(data []byte, sess fosite.Session) (fosite.Requester, error return nil, fmt.Errorf("failed to unmarshal request: %w", err) } fositeReq := req.ToFositeReq() - if sess != nil { - fositeReq.SetSession(sess) - } + restoreSession(fositeReq, req.SessionData, sess) return fositeReq, nil } diff --git a/pkg/repository/sql_test.go b/pkg/repository/sql_test.go index 1c1fd41..c80e2e5 100644 --- a/pkg/repository/sql_test.go +++ b/pkg/repository/sql_test.go @@ -52,6 +52,90 @@ func TestSQLRepositoryAccessTokenSession(t *testing.T) { } } +func TestSQLRepositorySessionPersistence(t *testing.T) { + repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000") + if err != nil { + t.Fatalf("failed to create sql repository: %v", err) + } + defer repo.Close() + + sess := &fosite.DefaultSession{ + Username: "test-user", + Subject: "test-user", + } + + ctx := context.Background() + client := &fosite.DefaultClient{ + ID: "client-sess", + Secret: []byte("secret"), + RedirectURIs: []string{"https://example.com/callback"}, + } + + req := &fosite.Request{ + ID: "req-sess", + RequestedAt: time.Now().UTC().Round(time.Second), + Client: client, + RequestedScope: []string{"openid"}, + Form: url.Values{"code": {"value"}}, + Session: sess, + } + + if err := repo.CreateAuthorizeCodeSession(ctx, "code-sess", req); err != nil { + t.Fatalf("CreateAuthorizeCodeSession failed: %v", err) + } + + result, err := repo.GetAuthorizeCodeSession(ctx, "code-sess", &fosite.DefaultSession{}) + if err != nil { + t.Fatalf("GetAuthorizeCodeSession failed: %v", err) + } + + restored := result.GetSession().(*fosite.DefaultSession) + if restored.GetSubject() != "test-user" { + t.Fatalf("expected subject 'test-user', got '%s'", restored.GetSubject()) + } + if restored.GetUsername() != "test-user" { + t.Fatalf("expected username 'test-user', got '%s'", restored.GetUsername()) + } +} + +func TestSQLRepositorySessionPersistence_BackwardsCompatible(t *testing.T) { + repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000") + if err != nil { + t.Fatalf("failed to create sql repository: %v", err) + } + defer repo.Close() + + ctx := context.Background() + client := &fosite.DefaultClient{ + ID: "client-old", + Secret: []byte("secret"), + RedirectURIs: []string{"https://example.com/callback"}, + } + + // Simulate old data without session + req := &fosite.Request{ + ID: "req-old", + RequestedAt: time.Now().UTC().Round(time.Second), + Client: client, + RequestedScope: []string{"openid"}, + Form: url.Values{"code": {"value"}}, + } + + if err := repo.CreateAuthorizeCodeSession(ctx, "code-old", req); err != nil { + t.Fatalf("CreateAuthorizeCodeSession failed: %v", err) + } + + result, err := repo.GetAuthorizeCodeSession(ctx, "code-old", &fosite.DefaultSession{}) + if err != nil { + t.Fatalf("GetAuthorizeCodeSession failed: %v", err) + } + + // Session should be nil (no data to restore, no fallback) + if result.GetSession() != nil { + t.Fatalf("expected nil session for old data, got %v", result.GetSession()) + } +} + func TestSQLRepositoryUnsupportedDriver(t *testing.T) { if _, err := NewSQLRepository("unsupported", "dsn"); err == nil { t.Fatalf("expected error for unsupported driver but got nil") From eecbd18f113fac9abcbb8cefe31a744701655413 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Fri, 17 Apr 2026 00:47:49 +0900 Subject: [PATCH 2/2] fix: propagate session unmarshal errors instead of silently discarding --- pkg/repository/interface.go | 9 ++++++--- pkg/repository/kvs.go | 16 ++++++++++++---- pkg/repository/sql.go | 4 +++- pkg/repository/sql_test.go | 6 ++---- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/pkg/repository/interface.go b/pkg/repository/interface.go index e576bf1..ea47cc1 100644 --- a/pkg/repository/interface.go +++ b/pkg/repository/interface.go @@ -3,6 +3,7 @@ package repository import ( "context" "encoding/json" + "fmt" "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" @@ -29,10 +30,12 @@ type AuthorizeRequestStorage interface { DeleteAuthorizeRequest(ctx context.Context, requestID string) error } -func restoreSession(req *fosite.Request, sessionData json.RawMessage, sess fosite.Session) { +func restoreSession(req *fosite.Request, sessionData json.RawMessage, sess fosite.Session) error { if len(sessionData) > 0 && sess != nil { - if json.Unmarshal(sessionData, sess) == nil { - req.SetSession(sess) + if err := json.Unmarshal(sessionData, sess); err != nil { + return fmt.Errorf("failed to unmarshal session data: %w", err) } + req.SetSession(sess) } + return nil } diff --git a/pkg/repository/kvs.go b/pkg/repository/kvs.go index 409f617..90dae76 100644 --- a/pkg/repository/kvs.go +++ b/pkg/repository/kvs.go @@ -89,7 +89,9 @@ func (r *kvsRepository) GetAuthorizeCodeSession(ctx context.Context, code string return nil, err } fositeReq := req.ToFositeReq() - restoreSession(fositeReq, req.SessionData, sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err + } return fositeReq, nil } @@ -107,7 +109,9 @@ func (r *kvsRepository) GetAccessTokenSession(ctx context.Context, signature str return nil, err } fositeReq := req.ToFositeReq() - restoreSession(fositeReq, req.SessionData, sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err + } return fositeReq, nil } @@ -125,7 +129,9 @@ func (r *kvsRepository) GetRefreshTokenSession(ctx context.Context, signature st return nil, err } fositeReq := req.ToFositeReq() - restoreSession(fositeReq, req.SessionData, sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err + } return fositeReq, nil } @@ -200,7 +206,9 @@ func (r *kvsRepository) GetPKCERequestSession(ctx context.Context, signature str return nil, err } fositeReq := req.ToFositeReq() - restoreSession(fositeReq, req.SessionData, sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err + } return fositeReq, nil } diff --git a/pkg/repository/sql.go b/pkg/repository/sql.go index eb84938..3ce041b 100644 --- a/pkg/repository/sql.go +++ b/pkg/repository/sql.go @@ -356,7 +356,9 @@ func unmarshalRequest(data []byte, sess fosite.Session) (fosite.Requester, error return nil, fmt.Errorf("failed to unmarshal request: %w", err) } fositeReq := req.ToFositeReq() - restoreSession(fositeReq, req.SessionData, sess) + if err := restoreSession(fositeReq, req.SessionData, sess); err != nil { + return nil, err + } return fositeReq, nil } diff --git a/pkg/repository/sql_test.go b/pkg/repository/sql_test.go index c80e2e5..59ded95 100644 --- a/pkg/repository/sql_test.go +++ b/pkg/repository/sql_test.go @@ -98,7 +98,7 @@ func TestSQLRepositorySessionPersistence(t *testing.T) { } } -func TestSQLRepositorySessionPersistence_BackwardsCompatible(t *testing.T) { +func TestSQLRepositorySessionPersistence_NilSessionStored(t *testing.T) { repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000") if err != nil { t.Fatalf("failed to create sql repository: %v", err) @@ -112,7 +112,6 @@ func TestSQLRepositorySessionPersistence_BackwardsCompatible(t *testing.T) { RedirectURIs: []string{"https://example.com/callback"}, } - // Simulate old data without session req := &fosite.Request{ ID: "req-old", RequestedAt: time.Now().UTC().Round(time.Second), @@ -130,9 +129,8 @@ func TestSQLRepositorySessionPersistence_BackwardsCompatible(t *testing.T) { t.Fatalf("GetAuthorizeCodeSession failed: %v", err) } - // Session should be nil (no data to restore, no fallback) if result.GetSession() != nil { - t.Fatalf("expected nil session for old data, got %v", result.GetSession()) + t.Fatalf("expected nil session when no session was stored, got %v", result.GetSession()) } }