diff --git a/backend/chainlit/auth/cookie.py b/backend/chainlit/auth/cookie.py index 1cb4cc5ace..81c6c8979d 100644 --- a/backend/chainlit/auth/cookie.py +++ b/backend/chainlit/auth/cookie.py @@ -1,8 +1,9 @@ import os -from typing import Literal, Optional, cast +from typing import Any, Literal, Optional, cast from fastapi import Request, Response from fastapi.exceptions import HTTPException +from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel from fastapi.security.base import SecurityBase from fastapi.security.utils import get_authorization_scheme_param from starlette.status import HTTP_401_UNAUTHORIZED @@ -48,6 +49,11 @@ def __init__( self.tokenUrl = tokenUrl self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + self.model = OAuth2Model( + flows=OAuthFlowsModel( + password=cast(Any, {"tokenUrl": tokenUrl, "scopes": {}}) + ) + ) async def __call__(self, request: Request) -> Optional[str]: # First try to get the token from the cookie diff --git a/backend/tests/auth/test_cookie.py b/backend/tests/auth/test_cookie.py index 5f5c3848a5..af73f088b9 100644 --- a/backend/tests/auth/test_cookie.py +++ b/backend/tests/auth/test_cookie.py @@ -1,13 +1,14 @@ import importlib import pytest -from fastapi import FastAPI, Form +from fastapi import Depends, FastAPI, Form from fastapi.testclient import TestClient from starlette.requests import Request from starlette.responses import Response import chainlit.auth.cookie as cookie_module from chainlit.auth import ( + OAuth2PasswordBearerWithCookie, clear_auth_cookie, get_token_from_cookies, set_auth_cookie, @@ -163,3 +164,22 @@ def test_clear_auth_cookie(client): assert len(clear_response.cookies) == 0 final_response = client.get("/get-token") assert final_response.json()["token"] is None + + +def test_cookie_oauth_generates_openapi_security_scheme(): + auth_scheme = OAuth2PasswordBearerWithCookie(tokenUrl="/login", auto_error=False) + app = FastAPI() + + @app.get("/protected") + async def protected(token: str = Depends(auth_scheme)): + return {"token": token} + + schema = app.openapi() + security_scheme = schema["components"]["securitySchemes"][ + "OAuth2PasswordBearerWithCookie" + ] + + assert security_scheme == { + "type": "oauth2", + "flows": {"password": {"scopes": {}, "tokenUrl": "/login"}}, + }