Skip to content

Commit 1e470ba

Browse files
committed
fix: persist JWT session in storage so user identity survives token exchange
The session (sub, userinfo) was lost during authorization code → access token exchange because models.Request did not serialize the session. The storage layer then overwrote the stored session with an empty one passed from handleToken. Add SessionData field to models.Request to serialize/deserialize the session. On restore, only populate the session when stored data exists (no fallback to empty session for old data).
1 parent 6f6d916 commit 1e470ba

6 files changed

Lines changed: 169 additions & 9 deletions

File tree

pkg/idp/idp_test.go

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str
5252
store := cookie.NewStore(secret[:])
5353
router.Use(sessions.Sessions("test_session", store))
5454

55-
// Mock auth middleware that always passes
55+
// Mock auth middleware that always passes with user identity
5656
router.Use(func(c *gin.Context) {
5757
session := sessions.Default(c)
5858
session.Set(auth.SessionKeyAuthorized, true)
59+
session.Set(auth.SessionKeyUserID, "[email protected]")
60+
session.Set(auth.SessionKeyUserInfo, `{"email":"[email protected]","name":"Test User"}`)
5961
err := session.Save()
6062
if err != nil {
6163
c.JSON(500, gin.H{"error": "Failed to save session"})
@@ -466,3 +468,62 @@ func TestAccessTokenAudienceClaim(t *testing.T) {
466468
require.True(t, ok, "aud claim should be present as an array")
467469
require.Contains(t, aud, "http://localhost:8080", "aud should contain the external URL")
468470
}
471+
472+
func TestAccessTokenPreservesUserIdentity(t *testing.T) {
473+
server, _, _ := setupTestServer(t)
474+
regResp := registerTestClient(t, server.URL)
475+
476+
config := &oauth2.Config{
477+
ClientID: regResp.ClientID,
478+
ClientSecret: regResp.ClientSecret,
479+
RedirectURL: "http://localhost:8080/callback",
480+
Scopes: []string{},
481+
Endpoint: oauth2.Endpoint{
482+
AuthURL: server.URL + AuthorizationEndpoint,
483+
TokenURL: server.URL + TokenEndpoint,
484+
},
485+
}
486+
487+
callbackURL := testAuthFlowWithURL(t, server.URL, config.AuthCodeURL("test-state"))
488+
code := callbackURL.Query().Get("code")
489+
490+
tokenReq := url.Values{}
491+
tokenReq.Set("grant_type", "authorization_code")
492+
tokenReq.Set("code", code)
493+
tokenReq.Set("redirect_uri", "http://localhost:8080/callback")
494+
tokenReq.Set("client_id", regResp.ClientID)
495+
tokenReq.Set("client_secret", regResp.ClientSecret)
496+
497+
tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq)
498+
require.NoError(t, err)
499+
defer tokenResp.Body.Close()
500+
require.Equal(t, http.StatusOK, tokenResp.StatusCode)
501+
502+
var tokenResult map[string]any
503+
err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult)
504+
require.NoError(t, err)
505+
506+
accessToken := tokenResult["access_token"].(string)
507+
508+
// Decode JWT payload
509+
parts := strings.Split(accessToken, ".")
510+
require.Len(t, parts, 3)
511+
512+
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
513+
require.NoError(t, err)
514+
515+
var claims map[string]any
516+
err = json.Unmarshal(payload, &claims)
517+
require.NoError(t, err)
518+
519+
// Verify sub claim is preserved
520+
sub, ok := claims["sub"].(string)
521+
require.True(t, ok, "sub claim should be present")
522+
require.Equal(t, "[email protected]", sub)
523+
524+
// Verify userinfo claim is preserved
525+
userinfo, ok := claims["userinfo"].(map[string]any)
526+
require.True(t, ok, "userinfo claim should be present")
527+
require.Equal(t, "[email protected]", userinfo["email"])
528+
require.Equal(t, "Test User", userinfo["name"])
529+
}

pkg/models/models.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package models
22

33
import (
4+
"encoding/json"
45
"net/url"
56
"time"
67

@@ -17,6 +18,7 @@ type Request struct {
1718
RequestedAudience []string
1819
GrantedAudience []string
1920
RotatedAt time.Time
21+
SessionData json.RawMessage `json:",omitempty"`
2022
}
2123

2224
type Client struct {
@@ -43,7 +45,7 @@ type AuthorizeRequest struct {
4345

4446
func FromFositeReq(reqester fosite.Requester) *Request {
4547
req := reqester.(*fosite.Request)
46-
return &Request{
48+
r := &Request{
4749
ID: req.ID,
4850
RequestedAt: req.RequestedAt,
4951
Client: FromFositeClient(req.Client),
@@ -53,6 +55,12 @@ func FromFositeReq(reqester fosite.Requester) *Request {
5355
RequestedAudience: req.RequestedAudience,
5456
GrantedAudience: req.GrantedAudience,
5557
}
58+
if sess := req.GetSession(); sess != nil {
59+
if data, err := json.Marshal(sess); err == nil {
60+
r.SessionData = data
61+
}
62+
}
63+
return r
5664
}
5765

5866
func (r *Request) ToFositeReq() *fosite.Request {

pkg/repository/interface.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package repository
22

33
import (
44
"context"
5+
"encoding/json"
56

67
"github.com/ory/fosite"
78
"github.com/ory/fosite/handler/oauth2"
@@ -27,3 +28,11 @@ type AuthorizeRequestStorage interface {
2728
GetAuthorizeRequest(ctx context.Context, requestID string) (fosite.AuthorizeRequester, error)
2829
DeleteAuthorizeRequest(ctx context.Context, requestID string) error
2930
}
31+
32+
func restoreSession(req *fosite.Request, sessionData json.RawMessage, sess fosite.Session) {
33+
if len(sessionData) > 0 && sess != nil {
34+
if json.Unmarshal(sessionData, sess) == nil {
35+
req.SetSession(sess)
36+
}
37+
}
38+
}

pkg/repository/kvs.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func (r *kvsRepository) GetAuthorizeCodeSession(ctx context.Context, code string
8989
return nil, err
9090
}
9191
fositeReq := req.ToFositeReq()
92-
fositeReq.SetSession(sess)
92+
restoreSession(fositeReq, req.SessionData, sess)
9393
return fositeReq, nil
9494
}
9595

@@ -107,7 +107,7 @@ func (r *kvsRepository) GetAccessTokenSession(ctx context.Context, signature str
107107
return nil, err
108108
}
109109
fositeReq := req.ToFositeReq()
110-
fositeReq.SetSession(sess)
110+
restoreSession(fositeReq, req.SessionData, sess)
111111
return fositeReq, nil
112112
}
113113

@@ -125,7 +125,7 @@ func (r *kvsRepository) GetRefreshTokenSession(ctx context.Context, signature st
125125
return nil, err
126126
}
127127
fositeReq := req.ToFositeReq()
128-
fositeReq.SetSession(sess)
128+
restoreSession(fositeReq, req.SessionData, sess)
129129
return fositeReq, nil
130130
}
131131

@@ -200,7 +200,7 @@ func (r *kvsRepository) GetPKCERequestSession(ctx context.Context, signature str
200200
return nil, err
201201
}
202202
fositeReq := req.ToFositeReq()
203-
fositeReq.SetSession(sess)
203+
restoreSession(fositeReq, req.SessionData, sess)
204204
return fositeReq, nil
205205
}
206206

pkg/repository/sql.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,7 @@ func unmarshalRequest(data []byte, sess fosite.Session) (fosite.Requester, error
356356
return nil, fmt.Errorf("failed to unmarshal request: %w", err)
357357
}
358358
fositeReq := req.ToFositeReq()
359-
if sess != nil {
360-
fositeReq.SetSession(sess)
361-
}
359+
restoreSession(fositeReq, req.SessionData, sess)
362360
return fositeReq, nil
363361
}
364362

pkg/repository/sql_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,90 @@ func TestSQLRepositoryAccessTokenSession(t *testing.T) {
5252
}
5353
}
5454

55+
func TestSQLRepositorySessionPersistence(t *testing.T) {
56+
repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000")
57+
if err != nil {
58+
t.Fatalf("failed to create sql repository: %v", err)
59+
}
60+
defer repo.Close()
61+
62+
sess := &fosite.DefaultSession{
63+
Username: "test-user",
64+
Subject: "test-user",
65+
}
66+
67+
ctx := context.Background()
68+
client := &fosite.DefaultClient{
69+
ID: "client-sess",
70+
Secret: []byte("secret"),
71+
RedirectURIs: []string{"https://example.com/callback"},
72+
}
73+
74+
req := &fosite.Request{
75+
ID: "req-sess",
76+
RequestedAt: time.Now().UTC().Round(time.Second),
77+
Client: client,
78+
RequestedScope: []string{"openid"},
79+
Form: url.Values{"code": {"value"}},
80+
Session: sess,
81+
}
82+
83+
if err := repo.CreateAuthorizeCodeSession(ctx, "code-sess", req); err != nil {
84+
t.Fatalf("CreateAuthorizeCodeSession failed: %v", err)
85+
}
86+
87+
result, err := repo.GetAuthorizeCodeSession(ctx, "code-sess", &fosite.DefaultSession{})
88+
if err != nil {
89+
t.Fatalf("GetAuthorizeCodeSession failed: %v", err)
90+
}
91+
92+
restored := result.GetSession().(*fosite.DefaultSession)
93+
if restored.GetSubject() != "test-user" {
94+
t.Fatalf("expected subject 'test-user', got '%s'", restored.GetSubject())
95+
}
96+
if restored.GetUsername() != "test-user" {
97+
t.Fatalf("expected username 'test-user', got '%s'", restored.GetUsername())
98+
}
99+
}
100+
101+
func TestSQLRepositorySessionPersistence_BackwardsCompatible(t *testing.T) {
102+
repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000")
103+
if err != nil {
104+
t.Fatalf("failed to create sql repository: %v", err)
105+
}
106+
defer repo.Close()
107+
108+
ctx := context.Background()
109+
client := &fosite.DefaultClient{
110+
ID: "client-old",
111+
Secret: []byte("secret"),
112+
RedirectURIs: []string{"https://example.com/callback"},
113+
}
114+
115+
// Simulate old data without session
116+
req := &fosite.Request{
117+
ID: "req-old",
118+
RequestedAt: time.Now().UTC().Round(time.Second),
119+
Client: client,
120+
RequestedScope: []string{"openid"},
121+
Form: url.Values{"code": {"value"}},
122+
}
123+
124+
if err := repo.CreateAuthorizeCodeSession(ctx, "code-old", req); err != nil {
125+
t.Fatalf("CreateAuthorizeCodeSession failed: %v", err)
126+
}
127+
128+
result, err := repo.GetAuthorizeCodeSession(ctx, "code-old", &fosite.DefaultSession{})
129+
if err != nil {
130+
t.Fatalf("GetAuthorizeCodeSession failed: %v", err)
131+
}
132+
133+
// Session should be nil (no data to restore, no fallback)
134+
if result.GetSession() != nil {
135+
t.Fatalf("expected nil session for old data, got %v", result.GetSession())
136+
}
137+
}
138+
55139
func TestSQLRepositoryUnsupportedDriver(t *testing.T) {
56140
if _, err := NewSQLRepository("unsupported", "dsn"); err == nil {
57141
t.Fatalf("expected error for unsupported driver but got nil")

0 commit comments

Comments
 (0)