-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathconftest.py
More file actions
218 lines (162 loc) · 5.87 KB
/
conftest.py
File metadata and controls
218 lines (162 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import os
from collections.abc import Iterator
from urllib.parse import quote

import pytest
from pytest import FixtureRequest
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, Session

from pgmq_sqlalchemy import PGMQueue
from tests.constant import ASYNC_DRIVERS, SYNC_DRIVERS
# Names of the fixtures (defined below) that build a PGMQueue from an async
# DSN, an AsyncEngine, or an AsyncSession factory. Intended for test
# filtering, i.e. telling the async fixture variants apart from the sync ones.
ASYNC_FIXTURE_NAMES = [
    "pgmq_by_async_dsn", "pgmq_by_async_engine", "pgmq_by_async_session_maker",
]
def pytest_addoption(parser: pytest.Parser):
    """Register the ``--driver`` and ``--db-name`` command-line options.

    Both are plain ``store`` options defaulting to ``None``. When they are
    absent, the fixtures in this module fall back to their own defaults
    (first supported driver, ``$SQLALCHEMY_DB`` / ``"postgres"``).
    """
    option_specs = (
        (
            "--driver",
            "Specify the database driver to use for testing (e.g., psycopg2, asyncpg, pg8000, etc.)",
        ),
        ("--db-name", "Specify the database name to use for testing"),
    )
    for flag, help_text in option_specs:
        parser.addoption(flag, action="store", default=None, help=help_text)
def pytest_generate_tests(metafunc: pytest.Metafunc):
    """Parametrize ``pgmq_all_variants`` over the PGMQueue fixtures matching ``--driver``.

    Only tests that request ``pgmq_all_variants`` are touched. The parameter
    is indirect, so each value is the *name* of one of the fixtures in this
    module (e.g. ``"pgmq_by_engine"``). Presumably a ``pgmq_all_variants``
    fixture elsewhere resolves it from ``request.param``; TODO confirm.

    Selection rules:
    - no ``--driver``: every sync and async variant;
    - ``--driver`` listed in ``ASYNC_DRIVERS``: async variants only;
    - any other ``--driver``: sync variants only. An unrecognised name also
      lands here, and ``get_dsn`` then falls back to ``SYNC_DRIVERS[0]``.
    """
    if "pgmq_all_variants" not in metafunc.fixturenames:
        return

    driver_from_cli = metafunc.config.getoption("--driver")

    sync_fixtures = [
        "pgmq_by_dsn",
        "pgmq_by_engine",
        "pgmq_by_session_maker",
        "pgmq_by_dsn_and_engine",
        "pgmq_by_dsn_and_session_maker",
    ]
    # Reuse the module-level list (copied, so pytest never holds the shared
    # object) instead of re-declaring the async names here.
    async_fixtures = list(ASYNC_FIXTURE_NAMES)

    if not driver_from_cli:
        fixture_params = sync_fixtures + async_fixtures
    elif driver_from_cli in ASYNC_DRIVERS:
        fixture_params = async_fixtures
    else:
        fixture_params = sync_fixtures

    metafunc.parametrize("pgmq_all_variants", fixture_params, indirect=True)
@pytest.fixture(scope="module")
def get_sa_host():
    """Database host: ``$SQLALCHEMY_HOST``, or ``"localhost"`` when unset."""
    host = os.environ.get("SQLALCHEMY_HOST", "localhost")
    return host
@pytest.fixture(scope="module")
def get_sa_port():
    """Database port as a string: ``$SQLALCHEMY_PORT``, or ``"5432"`` when unset."""
    port = os.environ.get("SQLALCHEMY_PORT", "5432")
    return port
@pytest.fixture(scope="module")
def get_sa_user():
    """Database user name: ``$SQLALCHEMY_USER``, or ``"postgres"`` when unset."""
    user = os.environ.get("SQLALCHEMY_USER", "postgres")
    return user
@pytest.fixture(scope="module")
def get_sa_password():
    """Database password: ``$SQLALCHEMY_PASSWORD``, or ``"postgres"`` when unset."""
    password = os.environ.get("SQLALCHEMY_PASSWORD", "postgres")
    return password
@pytest.fixture(scope="module")
def get_sa_db(request: pytest.FixtureRequest):
    """Database name, resolved in priority order.

    1. the ``--db-name`` CLI option, if given and non-empty;
    2. ``$SQLALCHEMY_DB``;
    3. ``"postgres"``.
    """
    # ``or`` only consults the environment when the CLI value is falsy.
    return request.config.getoption("--db-name") or os.getenv("SQLALCHEMY_DB", "postgres")
def _select_driver(request: FixtureRequest, supported_drivers) -> str:
    """Pick the SQLAlchemy driver name to put in a DSN.

    Returns the ``--driver`` CLI value when it is one of *supported_drivers*.
    Otherwise it returns ``supported_drivers[0]``: this covers the option being
    absent, naming a driver from the other (sync/async) family, or being unknown.
    """
    driver_from_cli = request.config.getoption("--driver")
    if driver_from_cli and driver_from_cli in supported_drivers:
        return driver_from_cli
    return supported_drivers[0]


def _build_dsn(driver, user, password, host, port, db) -> str:
    """Return ``postgresql+<driver>://<user>:<password>@<host>:<port>/<db>``.

    *user* and *password* are percent-encoded with ``quote(..., safe="")``.
    Characters such as ``@``, ``:``, ``/`` or ``%`` in credentials are then not
    parsed as URL delimiters. Plain alphanumeric credentials come out unchanged.
    """
    user_part = quote(user, safe="")
    password_part = quote(password, safe="")
    return f"postgresql+{driver}://{user_part}:{password_part}@{host}:{port}/{db}"


@pytest.fixture(scope="function")
def get_dsn(
    request: FixtureRequest,
    get_sa_host,
    get_sa_port,
    get_sa_user,
    get_sa_password,
    get_sa_db,
):
    """DSN for a sync driver.

    Uses ``--driver`` if it is listed in ``SYNC_DRIVERS``, otherwise ``SYNC_DRIVERS[0]``.
    """
    driver = _select_driver(request, SYNC_DRIVERS)
    return _build_dsn(driver, get_sa_user, get_sa_password, get_sa_host, get_sa_port, get_sa_db)


@pytest.fixture(scope="function")
def get_async_dsn(
    request: FixtureRequest,
    get_sa_host,
    get_sa_port,
    get_sa_user,
    get_sa_password,
    get_sa_db,
):
    """DSN for an async driver.

    Uses ``--driver`` if it is listed in ``ASYNC_DRIVERS``, otherwise ``ASYNC_DRIVERS[0]``.
    """
    driver = _select_driver(request, ASYNC_DRIVERS)
    return _build_dsn(driver, get_sa_user, get_sa_password, get_sa_host, get_sa_port, get_sa_db)
@pytest.fixture(scope="function")
def get_engine(get_dsn):
    """Yield a sync Engine built from ``get_dsn`` and dispose of it after the test.

    The fixture is function-scoped, so a fresh Engine (and connection pool) is
    created for every test. Disposing it at teardown releases its pooled
    connections immediately rather than whenever the Engine is garbage-collected.
    """
    engine = create_engine(get_dsn)
    yield engine
    # Teardown: dependent fixtures (session makers, PGMQueue instances) have
    # already been finalized by the time pytest resumes here.
    engine.dispose()
@pytest.fixture(scope="function")
def get_async_engine(get_async_dsn):
    """AsyncEngine built from ``get_async_dsn``; a new one is created per test."""
    # NOTE(review): never disposed here. AsyncEngine.dispose() is a coroutine,
    # so cleanup would need an async-capable fixture.
    async_engine = create_async_engine(get_async_dsn)
    return async_engine
@pytest.fixture(scope="function")
def get_session_maker(get_engine):
    """Factory producing plain ``Session`` objects bound to the sync engine fixture."""
    factory = sessionmaker(class_=Session, bind=get_engine)
    return factory
@pytest.fixture(scope="function")
def get_async_session_maker(get_async_engine):
    """Factory producing ``AsyncSession`` objects bound to the async engine fixture."""
    factory = sessionmaker(class_=AsyncSession, bind=get_async_engine)
    return factory
@pytest.fixture(scope="function")
def pgmq_by_dsn(get_dsn):
    """PGMQueue configured only with the sync DSN string."""
    return PGMQueue(dsn=get_dsn)
@pytest.fixture(scope="function")
def pgmq_by_async_dsn(get_async_dsn):
    """PGMQueue configured only with the async DSN string."""
    return PGMQueue(dsn=get_async_dsn)
@pytest.fixture(scope="function")
def pgmq_by_engine(get_engine):
    """PGMQueue configured only with the sync Engine fixture."""
    return PGMQueue(engine=get_engine)
@pytest.fixture(scope="function")
def pgmq_by_async_engine(get_async_engine):
    """PGMQueue configured only with the AsyncEngine fixture."""
    return PGMQueue(engine=get_async_engine)
@pytest.fixture(scope="function")
def pgmq_by_session_maker(get_session_maker):
    """PGMQueue configured only with the sync ``Session`` factory."""
    return PGMQueue(session_maker=get_session_maker)
@pytest.fixture(scope="function")
def pgmq_by_async_session_maker(get_async_session_maker):
    """PGMQueue configured only with the ``AsyncSession`` factory."""
    return PGMQueue(session_maker=get_async_session_maker)
@pytest.fixture(scope="function")
def pgmq_by_dsn_and_engine(get_dsn, get_engine):
    """PGMQueue given both the sync DSN and the sync Engine.

    Which of the two PGMQueue actually uses is decided inside PGMQueue and is
    not visible here.
    """
    return PGMQueue(dsn=get_dsn, engine=get_engine)
@pytest.fixture(scope="function")
def pgmq_by_dsn_and_session_maker(get_dsn, get_session_maker):
    """PGMQueue given both the sync DSN and the sync ``Session`` factory.

    Which of the two PGMQueue actually uses is decided inside PGMQueue and is
    not visible here.
    """
    return PGMQueue(dsn=get_dsn, session_maker=get_session_maker)
@pytest.fixture(scope="function")
def db_session(get_session_maker) -> Iterator[Session]:
    """Yield a sync ``Session`` from ``get_session_maker`` and close it after the test.

    The previous version returned the Session without ever closing it, so its
    connection stayed checked out of the pool. Closing it at teardown also
    rolls back any transaction the test left uncommitted.
    """
    # Session is a context manager; leaving the ``with`` calls Session.close().
    with get_session_maker() as session:
        yield session