Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions pgmq_sqlalchemy/queue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from typing import List, Optional

from sqlalchemy import create_engine
Expand All @@ -23,14 +22,12 @@ class PGMQueue:

is_async: bool = False
is_pg_partman_ext_checked: bool = False
loop: asyncio.AbstractEventLoop = None

def __init__(
self,
dsn: Optional[str] = None,
engine: Optional[ENGINE_TYPE] = None,
session_maker: Optional[sessionmaker] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""

Expand Down Expand Up @@ -79,8 +76,6 @@ def __init__(
dsn (Optional[str]): Database connection string.
engine (Optional[ENGINE_TYPE]): SQLAlchemy engine (sync or async).
session_maker (Optional[sessionmaker]): SQLAlchemy session maker.
loop (Optional[asyncio.AbstractEventLoop]): Event loop for async operations.
If not provided, a new event loop will be created for async engines.

.. note::
| ``PGMQueue`` will **auto create** the ``pgmq`` extension ( and ``pg_partman`` extension if the method is related with **partitioned_queue** ) if it does not exist in the Postgres.
Expand All @@ -107,27 +102,6 @@ def __init__(
bind=self.engine, class_=get_session_type(self.engine)
)

if self.is_async:
if loop is not None:
# Use the provided event loop
self.loop = loop
else:
# Create a new event loop
self.loop = asyncio.new_event_loop()

# create pgmq extension if not exists
self._check_pgmq_ext()

async def _check_pgmq_ext_async(self) -> None:
"""Check if the pgmq extension exists."""
async with self.session_maker() as session:
await PGMQOperation.check_pgmq_ext_async(session=session, commit=True)

def _check_pgmq_ext_sync(self) -> None:
"""Check if the pgmq extension exists."""
with self.session_maker() as session:
PGMQOperation.check_pgmq_ext(session=session, commit=True)

def _check_pgmq_ext(self) -> None:
"""Check if the pgmq extension exists."""
if self.is_async:
Expand Down
52 changes: 26 additions & 26 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
from tests.constant import ASYNC_DRIVERS, SYNC_DRIVERS

# Async fixture names for test filtering
ASYNC_FIXTURE_NAMES = ['pgmq_by_async_dsn', 'pgmq_by_async_engine', 'pgmq_by_async_session_maker']
ASYNC_FIXTURE_NAMES = [
"pgmq_by_async_dsn",
"pgmq_by_async_engine",
"pgmq_by_async_session_maker",
]


def pytest_addoption(parser):
def pytest_addoption(parser: pytest.Parser):
"""Add custom command-line options for pytest."""
parser.addoption(
"--driver",
Expand All @@ -29,30 +33,30 @@ def pytest_addoption(parser):
)


def pytest_generate_tests(metafunc):
def pytest_generate_tests(metafunc: pytest.Metafunc):
"""
Dynamically generate test parametrization based on CLI options.

This allows us to parametrize fixtures based on the --driver option.
"""
if "pgmq_all_variants" in metafunc.fixturenames:
driver_from_cli = metafunc.config.getoption("--driver")

# Define sync and async fixture variants
sync_fixtures = [
'pgmq_by_dsn',
'pgmq_by_engine',
'pgmq_by_session_maker',
'pgmq_by_dsn_and_engine',
'pgmq_by_dsn_and_session_maker',
"pgmq_by_dsn",
"pgmq_by_engine",
"pgmq_by_session_maker",
"pgmq_by_dsn_and_engine",
"pgmq_by_dsn_and_session_maker",
]

async_fixtures = [
'pgmq_by_async_dsn',
'pgmq_by_async_engine',
'pgmq_by_async_session_maker',
"pgmq_by_async_dsn",
"pgmq_by_async_engine",
"pgmq_by_async_session_maker",
]

# Determine which fixtures to use
if not driver_from_cli:
# No driver specified, use all fixtures
Expand All @@ -63,13 +67,9 @@ def pytest_generate_tests(metafunc):
else:
# Sync driver specified
fixture_params = sync_fixtures

# Parametrize the test
metafunc.parametrize(
"pgmq_all_variants",
fixture_params,
indirect=True
)
metafunc.parametrize("pgmq_all_variants", fixture_params, indirect=True)


@pytest.fixture(scope="module")
Expand All @@ -93,7 +93,7 @@ def get_sa_password():


@pytest.fixture(scope="module")
def get_sa_db(request):
def get_sa_db(request: pytest.FixtureRequest):
"""Get database name from CLI argument or environment variable."""
db_name_from_cli = request.config.getoption("--db-name")
if db_name_from_cli:
Expand All @@ -112,14 +112,14 @@ def get_dsn(
):
"""Get DSN for sync drivers based on CLI option."""
driver_from_cli = request.config.getoption("--driver")

# Use CLI driver if specified and it's a sync driver
if driver_from_cli and driver_from_cli in SYNC_DRIVERS:
driver = driver_from_cli
else:
# Default to first sync driver if no CLI option or invalid
driver = SYNC_DRIVERS[0]

return f"postgresql+{driver}://{get_sa_user}:{get_sa_password}@{get_sa_host}:{get_sa_port}/{get_sa_db}"


Expand All @@ -134,14 +134,14 @@ def get_async_dsn(
):
"""Get DSN for async drivers based on CLI option."""
driver_from_cli = request.config.getoption("--driver")

# Use CLI driver if specified and it's an async driver
if driver_from_cli and driver_from_cli in ASYNC_DRIVERS:
driver = driver_from_cli
else:
# Default to first async driver if no CLI option or invalid
driver = ASYNC_DRIVERS[0]

return f"postgresql+{driver}://{get_sa_user}:{get_sa_password}@{get_sa_host}:{get_sa_port}/{get_sa_db}"


Expand Down
14 changes: 11 additions & 3 deletions tests/fixture_deps.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import uuid
from typing import Tuple
from inspect import iscoroutinefunction

import pytest

from pgmq_sqlalchemy import PGMQueue
from tests.constant import ASYNC_DRIVERS
from tests._utils import check_queue_exists

PGMQ_WITH_QUEUE = Tuple[PGMQueue, str]
Expand All @@ -13,18 +15,24 @@
def pgmq_all_variants(request: pytest.FixtureRequest) -> PGMQueue:
"""
Fixture that parametrizes tests across all appropriate PGMQueue initialization methods.

When --driver is specified, only fixtures matching that driver type (sync/async) are used.
Without --driver, all fixtures are used.

The parametrization is handled by pytest_generate_tests in conftest.py.

Usage:
def test_something(pgmq_all_variants):
pgmq: PGMQueue = pgmq_all_variants
# test code here
"""
# The param is set by pytest_generate_tests via indirect parametrization
is_async_test = iscoroutinefunction(request.function)
driver_from_cli = request.config.getoption("--driver")
if driver_from_cli and (driver_from_cli in ASYNC_DRIVERS and not is_async_test):
pytest.skip(
reason=f"Skip sync test: {request.function.__name__}, as driver: {driver_from_cli} is async"
)
return request.getfixturevalue(request.param)


Expand Down
2 changes: 2 additions & 0 deletions tests/test_construct_pgmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from tests.fixture_deps import pgmq_all_variants

use_fixtures = [pgmq_all_variants]


def test_construct_pgmq(pgmq_all_variants):
pgmq: PGMQueue = pgmq_all_variants
Expand Down
58 changes: 0 additions & 58 deletions tests/test_event_loop.py

This file was deleted.

Loading