Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion pkg/idp/idp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str
store := cookie.NewStore(secret[:])
router.Use(sessions.Sessions("test_session", store))

// Mock auth middleware that always passes
// Mock auth middleware that always passes with user identity
router.Use(func(c *gin.Context) {
session := sessions.Default(c)
session.Set(auth.SessionKeyAuthorized, true)
session.Set(auth.SessionKeyUserID, "[email protected]")
session.Set(auth.SessionKeyUserInfo, `{"email":"[email protected]","name":"Test User"}`)
err := session.Save()
if err != nil {
c.JSON(500, gin.H{"error": "Failed to save session"})
Expand Down Expand Up @@ -466,3 +468,62 @@ func TestAccessTokenAudienceClaim(t *testing.T) {
require.True(t, ok, "aud claim should be present as an array")
require.Contains(t, aud, "http://localhost:8080", "aud should contain the external URL")
}

func TestAccessTokenPreservesUserIdentity(t *testing.T) {
server, _, _ := setupTestServer(t)
regResp := registerTestClient(t, server.URL)

config := &oauth2.Config{
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
RedirectURL: "http://localhost:8080/callback",
Scopes: []string{},
Endpoint: oauth2.Endpoint{
AuthURL: server.URL + AuthorizationEndpoint,
TokenURL: server.URL + TokenEndpoint,
},
}

callbackURL := testAuthFlowWithURL(t, server.URL, config.AuthCodeURL("test-state"))
code := callbackURL.Query().Get("code")

tokenReq := url.Values{}
tokenReq.Set("grant_type", "authorization_code")
tokenReq.Set("code", code)
tokenReq.Set("redirect_uri", "http://localhost:8080/callback")
tokenReq.Set("client_id", regResp.ClientID)
tokenReq.Set("client_secret", regResp.ClientSecret)

tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq)
require.NoError(t, err)
defer tokenResp.Body.Close()
require.Equal(t, http.StatusOK, tokenResp.StatusCode)

var tokenResult map[string]any
err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult)
require.NoError(t, err)

accessToken := tokenResult["access_token"].(string)

// Decode JWT payload
parts := strings.Split(accessToken, ".")
require.Len(t, parts, 3)

payload, err := base64.RawURLEncoding.DecodeString(parts[1])
require.NoError(t, err)

var claims map[string]any
err = json.Unmarshal(payload, &claims)
require.NoError(t, err)

// Verify sub claim is preserved
sub, ok := claims["sub"].(string)
require.True(t, ok, "sub claim should be present")
require.Equal(t, "[email protected]", sub)

// Verify userinfo claim is preserved
userinfo, ok := claims["userinfo"].(map[string]any)
require.True(t, ok, "userinfo claim should be present")
require.Equal(t, "[email protected]", userinfo["email"])
require.Equal(t, "Test User", userinfo["name"])
}
10 changes: 9 additions & 1 deletion pkg/models/models.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package models

import (
"encoding/json"
"net/url"
"time"

Expand All @@ -17,6 +18,7 @@ type Request struct {
RequestedAudience []string
GrantedAudience []string
RotatedAt time.Time
SessionData json.RawMessage `json:",omitempty"`
}

type Client struct {
Expand All @@ -43,7 +45,7 @@ type AuthorizeRequest struct {

func FromFositeReq(reqester fosite.Requester) *Request {
req := reqester.(*fosite.Request)
return &Request{
r := &Request{
ID: req.ID,
RequestedAt: req.RequestedAt,
Client: FromFositeClient(req.Client),
Expand All @@ -53,6 +55,12 @@ func FromFositeReq(reqester fosite.Requester) *Request {
RequestedAudience: req.RequestedAudience,
GrantedAudience: req.GrantedAudience,
}
if sess := req.GetSession(); sess != nil {
if data, err := json.Marshal(sess); err == nil {
r.SessionData = data
}
}
return r
}

func (r *Request) ToFositeReq() *fosite.Request {
Expand Down
12 changes: 12 additions & 0 deletions pkg/repository/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package repository

import (
"context"
"encoding/json"
"fmt"

"github.com/ory/fosite"
"github.com/ory/fosite/handler/oauth2"
Expand All @@ -27,3 +29,13 @@ type AuthorizeRequestStorage interface {
GetAuthorizeRequest(ctx context.Context, requestID string) (fosite.AuthorizeRequester, error)
DeleteAuthorizeRequest(ctx context.Context, requestID string) error
}

func restoreSession(req *fosite.Request, sessionData json.RawMessage, sess fosite.Session) error {
if len(sessionData) > 0 && sess != nil {
if err := json.Unmarshal(sessionData, sess); err != nil {
return fmt.Errorf("failed to unmarshal session data: %w", err)
}
req.SetSession(sess)
}
return nil
}
16 changes: 12 additions & 4 deletions pkg/repository/kvs.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ func (r *kvsRepository) GetAuthorizeCodeSession(ctx context.Context, code string
return nil, err
}
fositeReq := req.ToFositeReq()
fositeReq.SetSession(sess)
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
return nil, err
}
return fositeReq, nil
}

Expand All @@ -107,7 +109,9 @@ func (r *kvsRepository) GetAccessTokenSession(ctx context.Context, signature str
return nil, err
}
fositeReq := req.ToFositeReq()
fositeReq.SetSession(sess)
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
return nil, err
}
return fositeReq, nil
}

Expand All @@ -125,7 +129,9 @@ func (r *kvsRepository) GetRefreshTokenSession(ctx context.Context, signature st
return nil, err
}
fositeReq := req.ToFositeReq()
fositeReq.SetSession(sess)
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
return nil, err
}
return fositeReq, nil
}

Expand Down Expand Up @@ -200,7 +206,9 @@ func (r *kvsRepository) GetPKCERequestSession(ctx context.Context, signature str
return nil, err
}
fositeReq := req.ToFositeReq()
fositeReq.SetSession(sess)
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
return nil, err
}
return fositeReq, nil
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/repository/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,8 @@ func unmarshalRequest(data []byte, sess fosite.Session) (fosite.Requester, error
return nil, fmt.Errorf("failed to unmarshal request: %w", err)
}
fositeReq := req.ToFositeReq()
if sess != nil {
fositeReq.SetSession(sess)
if err := restoreSession(fositeReq, req.SessionData, sess); err != nil {
return nil, err
}
return fositeReq, nil
}
Expand Down
82 changes: 82 additions & 0 deletions pkg/repository/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,88 @@ func TestSQLRepositoryAccessTokenSession(t *testing.T) {
}
}

func TestSQLRepositorySessionPersistence(t *testing.T) {
repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000")
if err != nil {
t.Fatalf("failed to create sql repository: %v", err)
}
defer repo.Close()

sess := &fosite.DefaultSession{
Username: "test-user",
Subject: "test-user",
}

ctx := context.Background()
client := &fosite.DefaultClient{
ID: "client-sess",
Secret: []byte("secret"),
RedirectURIs: []string{"https://example.com/callback"},
}

req := &fosite.Request{
ID: "req-sess",
RequestedAt: time.Now().UTC().Round(time.Second),
Client: client,
RequestedScope: []string{"openid"},
Form: url.Values{"code": {"value"}},
Session: sess,
}

if err := repo.CreateAuthorizeCodeSession(ctx, "code-sess", req); err != nil {
t.Fatalf("CreateAuthorizeCodeSession failed: %v", err)
}

result, err := repo.GetAuthorizeCodeSession(ctx, "code-sess", &fosite.DefaultSession{})
if err != nil {
t.Fatalf("GetAuthorizeCodeSession failed: %v", err)
}

restored := result.GetSession().(*fosite.DefaultSession)
if restored.GetSubject() != "test-user" {
t.Fatalf("expected subject 'test-user', got '%s'", restored.GetSubject())
}
if restored.GetUsername() != "test-user" {
t.Fatalf("expected username 'test-user', got '%s'", restored.GetUsername())
}
}

func TestSQLRepositorySessionPersistence_NilSessionStored(t *testing.T) {
repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared&_busy_timeout=5000")
if err != nil {
t.Fatalf("failed to create sql repository: %v", err)
}
defer repo.Close()

ctx := context.Background()
client := &fosite.DefaultClient{
ID: "client-old",
Secret: []byte("secret"),
RedirectURIs: []string{"https://example.com/callback"},
}

req := &fosite.Request{
ID: "req-old",
RequestedAt: time.Now().UTC().Round(time.Second),
Client: client,
RequestedScope: []string{"openid"},
Form: url.Values{"code": {"value"}},
}

if err := repo.CreateAuthorizeCodeSession(ctx, "code-old", req); err != nil {
t.Fatalf("CreateAuthorizeCodeSession failed: %v", err)
}

result, err := repo.GetAuthorizeCodeSession(ctx, "code-old", &fosite.DefaultSession{})
if err != nil {
t.Fatalf("GetAuthorizeCodeSession failed: %v", err)
}

if result.GetSession() != nil {
t.Fatalf("expected nil session when no session was stored, got %v", result.GetSession())
}
}

func TestSQLRepositoryUnsupportedDriver(t *testing.T) {
if _, err := NewSQLRepository("unsupported", "dsn"); err == nil {
t.Fatalf("expected error for unsupported driver but got nil")
Expand Down
Loading