Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion backend/chainlit/auth/cookie.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
from typing import Literal, Optional, cast
from typing import Any, Literal, Optional, cast

from fastapi import Request, Response
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.status import HTTP_401_UNAUTHORIZED
Expand Down Expand Up @@ -48,6 +49,11 @@ def __init__(
self.tokenUrl = tokenUrl
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
self.model = OAuth2Model(
flows=OAuthFlowsModel(
password=cast(Any, {"tokenUrl": tokenUrl, "scopes": {}})
)
)

async def __call__(self, request: Request) -> Optional[str]:
# First try to get the token from the cookie
Expand Down
22 changes: 21 additions & 1 deletion backend/tests/auth/test_cookie.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import importlib

import pytest
from fastapi import FastAPI, Form
from fastapi import Depends, FastAPI, Form
from fastapi.testclient import TestClient
from starlette.requests import Request
from starlette.responses import Response

import chainlit.auth.cookie as cookie_module
from chainlit.auth import (
OAuth2PasswordBearerWithCookie,
clear_auth_cookie,
get_token_from_cookies,
set_auth_cookie,
Expand Down Expand Up @@ -163,3 +164,22 @@ def test_clear_auth_cookie(client):
assert len(clear_response.cookies) == 0
final_response = client.get("/get-token")
assert final_response.json()["token"] is None


def test_cookie_oauth_generates_openapi_security_scheme():
auth_scheme = OAuth2PasswordBearerWithCookie(tokenUrl="/login", auto_error=False)
app = FastAPI()

@app.get("/protected")
async def protected(token: str = Depends(auth_scheme)):
return {"token": token}

schema = app.openapi()
security_scheme = schema["components"]["securitySchemes"][
"OAuth2PasswordBearerWithCookie"
]

assert security_scheme == {
"type": "oauth2",
"flows": {"password": {"scopes": {}, "tokenUrl": "/login"}},
}