Skip to content

Commit 4004e4d

Browse files
authored
fix: persist JWT session in storage so user identity survives token exchange (#146)
* fix: persist JWT session in storage so user identity survives token exchange The session (sub, userinfo) was lost during authorization code → access token exchange because models.Request did not serialize the session. The storage layer then overwrote the stored session with an empty one passed from handleToken. Add SessionData field to models.Request to serialize/deserialize the session. On restore, only populate the session when stored data exists (no fallback to empty session for old data). * fix: propagate session unmarshal errors instead of silently discarding
1 parent 3d6e35c commit 4004e4d

6 files changed

Lines changed: 179 additions & 8 deletions

File tree

pkg/idp/idp_test.go

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str
5252
store := cookie.NewStore(secret[:])
5353
router.Use(sessions.Sessions("test_session", store))
5454

55-
// Mock auth middleware that always passes
55+
// Mock auth middleware that always passes with user identity
5656
router.Use(func(c *gin.Context) {
5757
session := sessions.Default(c)
5858
session.Set(auth.SessionKeyAuthorized, true)
59+
session.Set(auth.SessionKeyUserID, "[email protected]")
60+
session.Set(auth.SessionKeyUserInfo, `{"email":"[email protected]","name":"Test User"}`)
5961
err := session.Save()
6062
if err != nil {
6163
c.JSON(500, gin.H{"error": "Failed to save session"})
@@ -466,3 +468,62 @@ func TestAccessTokenAudienceClaim(t *testing.T) {
466468
require.True(t, ok, "aud claim should be present as an array")
467469
require.Contains(t, aud, "http://localhost:8080", "aud should contain the external URL")
468470
}
471+
472+
func TestAccessTokenPreservesUserIdentity(t *testing.T) {
473+
server, _, _ := setupTestServer(t)
474+
regResp := registerTestClient(t, server.URL)
475+
476+
config := &oauth2.Config{
477+
ClientID: regResp.ClientID,
478+
ClientSecret: regResp.ClientSecret,
479+
RedirectURL: "http://localhost:8080/callback",
480+
Scopes: []string{},
481+
Endpoint: oauth2.Endpoint{
482+
AuthURL: server.URL + AuthorizationEndpoint,
483+
TokenURL: server.URL + TokenEndpoint,
484+
},
485+
}
486+
487+
callbackURL := testAuthFlowWithURL(t, server.URL, config.AuthCodeURL("test-state"))
488+
code := callbackURL.Query().Get("code")
489+
490+
tokenReq := url.Values{}
491+
tokenReq.Set("grant_type", "authorization_code")
492+
tokenReq.Set("code", code)
493+
tokenReq.Set("redirect_uri", "http://localhost:8080/callback")
494+
tokenReq.Set("client_id", regResp.ClientID)
495+
tokenReq.Set("client_secret", regResp.ClientSecret)
496+
497+
tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq)
498+
require.NoError(t, err)
499+
defer tokenResp.Body.Close()
500+
require.Equal(t, http.StatusOK, tokenResp.StatusCode)
501+
502+
var tokenResult map[string]any
503+
err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult)
504+
require.NoError(t, err)
505+
506+
accessToken := tokenResult["access_token"].(string)
507+
508+
// Decode JWT payload
509+
parts := strings.Split(accessToken, ".")
510+
require.Len(t, parts, 3)
511+
512+
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
513+
require.NoError(t, err)
514+
515+
var claims map[string]any
516+
err = json.Unmarshal(payload, &claims)
517+
require.NoError(t, err)
518+
519+
// Verify sub claim is preserved
520+
sub, ok := claims["sub"].(string)
521+
require.True(t, ok, "sub claim should be present")
522+
require.Equal(t, "[email protected]", sub)
523+
524+
// Verify userinfo claim is preserved
525+
userinfo, ok := claims["userinfo"].(map[string]any)
526+
require.True(t, ok, "userinfo claim should be present")
527+
require.Equal(t, "[email protected]", userinfo["email"])
528+
require.Equal(t, "Test User", userinfo["name"])
529+
}

pkg/models/models.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package models
22

33
import (
4+
"encoding/json"
45
"net/url"
56
"time"
67

@@ -17,6 +18,7 @@ type Request struct {
1718
RequestedAudience []string
1819
GrantedAudience []string
1920
RotatedAt time.Time
21+
SessionData json.RawMessage `json:",omitempty"`
2022
}
2123

2224
type Client struct {
@@ -43,7 +45,7 @@ type AuthorizeRequest struct {
4345

4446
func FromFositeReq(reqester fosite.Requester) *Request {
4547
req := reqester.(*fosite.Request)
46-
return &Request{
48+
r := &Request{
4749
ID: req.ID,
4850
RequestedAt: req.RequestedAt,
4951
Client: FromFositeClient(req.Client),
@@ -53,6 +55,12 @@ func FromFositeReq(reqester fosite.Requester) *Request {
5355
RequestedAudience: req.RequestedAudience,
5456
GrantedAudience: req.GrantedAudience,
5557
}
58+
if sess := req.GetSession(); sess != nil {
59+
if data, err := json.Marshal(sess); err == nil {
60+
r.SessionData = data
61+
}
62+
}
63+
return r
5664
}
5765

5866
func (r *Request) ToFositeReq() *fosite.Request {

pkg/repository/interface.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package repository
22

33
import (
44
"context"
5+
"encoding/json"
6+
"fmt"
57

68
"github.com/ory/fosite"
79
"github.com/ory/fosite/handler/oauth2"
@@ -27,3 +29,13 @@ type AuthorizeRequestStorage interface {
2729
GetAuthorizeRequest(ctx context.Context, requestID string) (fosite.AuthorizeRequester, error)
2830
DeleteAuthorizeRequest(ctx context.Context, requestID string) error
2931
}
32+
33+
func restoreSession(req *fosite.Request, sessionData json.RawMessage, sess fosite.Session) error {
34+
if len(sessionData) > 0 && sess != nil {
35+
if err := json.Unmarshal(sessionData, sess); err != nil {
36+
return fmt.Errorf("failed to unmarshal session data: %w", err)
37+
}
38+
req.SetSession(sess)
39+
}
40+
return nil
41+
}

pkg/repository/kvs.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ func (r *kvsRepository) GetAuthorizeCodeSession(ctx context.Context, code string
8989
return nil, err
9090
}
9191
fositeReq := req.ToFositeReq()
92-
fositeReq.SetSession(sess)
92+
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
93+
return nil, err
94+
}
9395
return fositeReq, nil
9496
}
9597

@@ -107,7 +109,9 @@ func (r *kvsRepository) GetAccessTokenSession(ctx context.Context, signature str
107109
return nil, err
108110
}
109111
fositeReq := req.ToFositeReq()
110-
fositeReq.SetSession(sess)
112+
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
113+
return nil, err
114+
}
111115
return fositeReq, nil
112116
}
113117

@@ -125,7 +129,9 @@ func (r *kvsRepository) GetRefreshTokenSession(ctx context.Context, signature st
125129
return nil, err
126130
}
127131
fositeReq := req.ToFositeReq()
128-
fositeReq.SetSession(sess)
132+
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
133+
return nil, err
134+
}
129135
return fositeReq, nil
130136
}
131137

@@ -200,7 +206,9 @@ func (r *kvsRepository) GetPKCERequestSession(ctx context.Context, signature str
200206
return nil, err
201207
}
202208
fositeReq := req.ToFositeReq()
203-
fositeReq.SetSession(sess)
209+
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
210+
return nil, err
211+
}
204212
return fositeReq, nil
205213
}
206214

pkg/repository/sql.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,8 @@ func unmarshalRequest(data []byte, sess fosite.Session) (fosite.Requester, error
356356
return nil, fmt.Errorf("failed to unmarshal request: %w", err)
357357
}
358358
fositeReq := req.ToFositeReq()
359-
if sess != nil {
360-
fositeReq.SetSession(sess)
359+
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
360+
return nil, err
361361
}
362362
return fositeReq, nil
363363
}

pkg/repository/sql_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,88 @@ func TestSQLRepositoryAccessTokenSession(t *testing.T) {
5252
}
5353
}
5454

55+
func TestSQLRepositorySessionPersistence(t *testing.T) {
56+
repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000")
57+
if err != nil {
58+
t.Fatalf("failed to create sql repository: %v", err)
59+
}
60+
defer repo.Close()
61+
62+
sess := &fosite.DefaultSession{
63+
Username: "test-user",
64+
Subject: "test-user",
65+
}
66+
67+
ctx := context.Background()
68+
client := &fosite.DefaultClient{
69+
ID: "client-sess",
70+
Secret: []byte("secret"),
71+
RedirectURIs: []string{"https://example.com/callback"},
72+
}
73+
74+
req := &fosite.Request{
75+
ID: "req-sess",
76+
RequestedAt: time.Now().UTC().Round(time.Second),
77+
Client: client,
78+
RequestedScope: []string{"openid"},
79+
Form: url.Values{"code": {"value"}},
80+
Session: sess,
81+
}
82+
83+
if err := repo.CreateAuthorizeCodeSession(ctx, "code-sess", req); err != nil {
84+
t.Fatalf("CreateAuthorizeCodeSession failed: %v", err)
85+
}
86+
87+
result, err := repo.GetAuthorizeCodeSession(ctx, "code-sess", &fosite.DefaultSession{})
88+
if err != nil {
89+
t.Fatalf("GetAuthorizeCodeSession failed: %v", err)
90+
}
91+
92+
restored := result.GetSession().(*fosite.DefaultSession)
93+
if restored.GetSubject() != "test-user" {
94+
t.Fatalf("expected subject 'test-user', got '%s'", restored.GetSubject())
95+
}
96+
if restored.GetUsername() != "test-user" {
97+
t.Fatalf("expected username 'test-user', got '%s'", restored.GetUsername())
98+
}
99+
}
100+
101+
func TestSQLRepositorySessionPersistence_NilSessionStored(t *testing.T) {
102+
repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000")
103+
if err != nil {
104+
t.Fatalf("failed to create sql repository: %v", err)
105+
}
106+
defer repo.Close()
107+
108+
ctx := context.Background()
109+
client := &fosite.DefaultClient{
110+
ID: "client-old",
111+
Secret: []byte("secret"),
112+
RedirectURIs: []string{"https://example.com/callback"},
113+
}
114+
115+
req := &fosite.Request{
116+
ID: "req-old",
117+
RequestedAt: time.Now().UTC().Round(time.Second),
118+
Client: client,
119+
RequestedScope: []string{"openid"},
120+
Form: url.Values{"code": {"value"}},
121+
}
122+
123+
if err := repo.CreateAuthorizeCodeSession(ctx, "code-old", req); err != nil {
124+
t.Fatalf("CreateAuthorizeCodeSession failed: %v", err)
125+
}
126+
127+
result, err := repo.GetAuthorizeCodeSession(ctx, "code-old", &fosite.DefaultSession{})
128+
if err != nil {
129+
t.Fatalf("GetAuthorizeCodeSession failed: %v", err)
130+
}
131+
132+
if result.GetSession() != nil {
133+
t.Fatalf("expected nil session when no session was stored, got %v", result.GetSession())
134+
}
135+
}
136+
55137
func TestSQLRepositoryUnsupportedDriver(t *testing.T) {
56138
if _, err := NewSQLRepository("unsupported", "dsn"); err == nil {
57139
t.Fatalf("expected error for unsupported driver but got nil")

0 commit comments

Comments
 (0)