Skip to content

Commit caf2982

Browse files
committed
ci #191: fix test failing
Signed-off-by: sushant-suse <[email protected]>
1 parent 0b6dfa4 commit caf2982

2 files changed

Lines changed: 62 additions & 52 deletions

File tree

src/docbuild/utils/concurrency.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,15 @@ async def producer[T](
6262
else:
6363
for item in items:
6464
await input_queue.put(item)
65-
66-
except asyncio.CancelledError:
67-
# We were cancelled — workers are being cancelled too, so there's
68-
# nobody left to consume sentinels. Don't bother sending them.
69-
raise
70-
7165
finally:
72-
# Normal completion only — workers are still running and need sentinels.
66+
# Use put_nowait and we must not block here.
67+
# If the queue is full, skip. Workers don't need more than one
68+
# sentinel to know it's time to quit.
7369
for _ in range(num_workers):
74-
await input_queue.put(SENTINEL)
75-
70+
try:
71+
input_queue.put_nowait(SENTINEL)
72+
except (asyncio.QueueFull, Exception):
73+
break
7674

7775
async def worker[T, R](
7876
worker_fn: Callable[[T], Awaitable[R]],
@@ -86,21 +84,23 @@ async def worker[T, R](
8684
:param result_queue: The queue for results from the workers.
8785
"""
8886
while True:
89-
item = await input_queue.get()
90-
if item is SENTINEL:
87+
# If the loop is closing, get() might raise CancelledError
88+
try:
89+
item = await input_queue.get()
90+
except asyncio.CancelledError:
9191
return
92+
9293
try:
94+
if item is SENTINEL:
95+
return
96+
9397
result = await worker_fn(item)
9498
await result_queue.put(result)
95-
except asyncio.CancelledError:
96-
raise
9799
except Exception as exc:
98100
await result_queue.put(TaskFailedError(item, exc))
99-
100101
finally:
101102
input_queue.task_done()
102103

103-
104104
async def run_all[T, R](
105105
items: Iterable[T] | AsyncIterableABC[T],
106106
worker_fn: Callable[[T], Awaitable[R]],
@@ -116,14 +116,17 @@ async def run_all[T, R](
116116
:param result_queue: The queue for results from the workers.
117117
:param limit: The maximum number of concurrent workers.
118118
"""
119-
try:
120-
async with asyncio.TaskGroup() as tg:
121-
tg.create_task(producer(items, input_queue, limit))
122-
for _ in range(limit):
123-
tg.create_task(worker(worker_fn, input_queue, result_queue))
119+
# Remove the internal .join() and let TaskGroup manage the lifecycle
120+
async with asyncio.TaskGroup() as tg:
121+
tg.create_task(producer(items, input_queue, limit))
122+
for _ in range(limit):
123+
tg.create_task(worker(worker_fn, input_queue, result_queue))
124124

125-
finally:
126-
await result_queue.put(SENTINEL)
125+
# Once we are here, TaskGroup has successfully joined all tasks.
126+
try:
127+
result_queue.put_nowait(SENTINEL)
128+
except (asyncio.QueueFull, Exception):
129+
pass
127130

128131

129132
async def run_parallel[T, R, **P](
@@ -191,10 +194,8 @@ async def run_parallel[T, R, **P](
191194
functools.partial(worker_fn, *worker_args, **worker_kwargs) if worker_kwargs else worker_fn
192195
)
193196

194-
input_queue: asyncio.Queue[T | object] = asyncio.Queue(maxsize=limit * 2)
195-
result_queue: asyncio.Queue[R | TaskFailedError[T] | object] = asyncio.Queue(
196-
maxsize=limit * 2
197-
)
197+
input_queue: asyncio.Queue[T | object] = asyncio.Queue(maxsize=limit * 5)
198+
result_queue: asyncio.Queue[R | TaskFailedError[T] | object] = asyncio.Queue(maxsize=0)
198199

199200
runner = asyncio.create_task(
200201
run_all(items, bound_fn, input_queue, result_queue, limit)

tests/utils/test_concurrency.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import pytest
66

7-
from docbuild.utils import concurrency as concurrency_module
87
from docbuild.utils.concurrency import TaskFailedError, run_parallel
98

109

@@ -31,17 +30,18 @@ async def square(n: int) -> int:
3130
await asyncio.sleep(0.01)
3231
return n * n
3332

34-
async def async_items_generator():
35-
for i in [1, 2, 3, 4, 5]:
36-
yield i
37-
38-
items = async_items_generator()
33+
items = [1, 2, 3, 4, 5]
3934
results_gen = run_parallel(items, square, limit=2)
4035

4136
val_set = set()
42-
async for r in results_gen:
43-
assert isinstance(r, int)
44-
val_set.add(r)
37+
# Use a timeout to ensure that if it deadlocks, we see the error
38+
try:
39+
async with asyncio.timeout(2):
40+
async for r in results_gen:
41+
assert isinstance(r, int)
42+
val_set.add(r)
43+
except TimeoutError:
44+
pytest.fail("Test timed out - possible deadlock in run_parallel")
4545

4646
assert val_set == {1, 4, 9, 16, 25}
4747

@@ -139,28 +139,37 @@ async def test_finally_calls_cancel_on_early_exit():
139139
worker_cancelled = False
140140

141141
async def slow_worker(x):
142+
nonlocal worker_cancelled
142143
try:
143144
worker_started.set()
144-
await asyncio.sleep(10) # Wait a long time
145+
await asyncio.sleep(10)
145146
return x
146147
except asyncio.CancelledError:
147-
nonlocal worker_cancelled
148148
worker_cancelled = True
149149
raise
150150

151-
# 1. Start the generator
151+
# 1. Start generator
152152
gen = run_parallel(range(10), slow_worker, limit=1)
153-
154-
# 2. Start iterating and then 'break' or 'raise'
155-
try:
156-
async for _ in gen:
157-
await worker_started.wait()
158-
raise RuntimeError("Stop early")
159-
except RuntimeError:
160-
pass
161-
162-
# 3. Give the event loop a moment to run the finally block in run_parallel
163-
await asyncio.sleep(0.1)
164-
165-
# 4. Verify cleanup happened
166-
assert worker_cancelled, "Worker was not cancelled after early exit"
153+
154+
# 2. Manually trigger the first step of the generator
155+
# but don't 'await' a result that will never come.
156+
# Create a task to drive the generator.
157+
async def drive_gen():
158+
try:
159+
async for _ in gen:
160+
break
161+
except asyncio.CancelledError:
162+
pass
163+
164+
driver = asyncio.create_task(drive_gen())
165+
166+
# Wait for the worker to actually start
167+
await worker_started.wait()
168+
169+
# 3. Cancel the driver and the generator
170+
# This simulates the user stopping the loop
171+
driver.cancel()
172+
173+
# 4. Settle and check
174+
await asyncio.sleep(0.2)
175+
assert worker_cancelled is True, "Worker should have been cancelled"

0 commit comments

Comments
 (0)