From 275c05eea52cf829a7e3b14c1cb602147a546ecd Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Tue, 24 Feb 2026 13:23:14 +0100 Subject: [PATCH 01/26] Add producer/consumer model with create_task() --- src/docbuild/utils/concurrency.py | 106 ++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 src/docbuild/utils/concurrency.py diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py new file mode 100644 index 00000000..deb9dd1e --- /dev/null +++ b/src/docbuild/utils/concurrency.py @@ -0,0 +1,106 @@ + +# src/docbuild/utils/concurrency.py +from multiprocessing.sharedctypes import Value +import asyncio +from collections.abc import Awaitable, Callable, Iterable +import logging +from typing import TypeVar + +T = TypeVar("T") # Input type +R = TypeVar("R") # Result type + +log = logging.getLogger(__name__) + + +async def parallel_process( + items: Iterable[T], + worker_fn: Callable[[T], Awaitable[R]], + *, + limit: int, + return_exceptions: bool = False, + name: str | None = None, +) -> list[R | Exception]: + """Process a list of items in parallel using a fixed number of workers. + + :param items: An iterable of items to process. + :param worker_fn: An async function that processes a single item. + :param limit: The maximum number of concurrent workers. + :param return_exceptions: If True, exceptions are returned as results + instead of raised. + :param name: Optional name for the task. + :return: A list of results (unordered unless you track indices). + """ + queue: asyncio.Queue[T] = asyncio.Queue() + results: list[R | Exception] = [] + + # 1. Populate Queue + for item in items: + queue.put_nowait(item) + + # 2. Define Worker + async def worker() -> None: + while True: + try: + item = await queue.get() + except asyncio.CancelledError: + return + + try: + res = await worker_fn(item) + results.append(res) + + except Exception as e: + if return_exceptions: + results.append(e) + else: + log.error("Worker failed: %s", e) + # Optional: Cancel all other workers here if you want fail-fast + finally: + queue.task_done() + + # 3. Start Workers + workers = [asyncio.create_task(worker(), name=name) + for _ in range(limit)] + + # 4. Wait & Cleanup + await queue.join() + for w in workers: + w.cancel() + await asyncio.gather(*workers, return_exceptions=True) + + return results + + +if __name__ == "__main__": + import random + import time + + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + async def sample_worker(num: int) -> int: + """Create a simple worker that simulates some I/O-bound work.""" + log.info("Processing item %d", num) + x = random.randint(0, 10) + if x in(1, 5, 6): + raise ValueError("Oh no! 
Wrong value!") + await asyncio.sleep(0.1* x) # Simulate I/O delay + return num * 2 + + async def main() -> None: + """Run the example.""" + start_time = time.monotonic() + + log.info("Starting parallel processing with a limit of 3 workers...") + results = await parallel_process( + range(20), + sample_worker, + limit=5) + end_time = time.monotonic() + + log.info("Processing finished in %.2f seconds", end_time - start_time) + log.info("Results (unordered): %s", results) + + print("Starting example...") + asyncio.run(main()) + print("Example finished.") + From 7557ae1405caaaea441f7d778490da4684140506 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Tue, 24 Feb 2026 16:14:12 +0100 Subject: [PATCH 02/26] Use TaskGroup as producer/consumer --- src/docbuild/utils/concurrency.py | 153 ++++++++++++++++++------------ 1 file changed, 93 insertions(+), 60 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index deb9dd1e..0fb88cc4 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -1,8 +1,7 @@ - -# src/docbuild/utils/concurrency.py -from multiprocessing.sharedctypes import Value +from abc import ABC, abstractmethod import asyncio from collections.abc import Awaitable, Callable, Iterable +from dataclasses import dataclass import logging from typing import TypeVar @@ -12,95 +11,129 @@ log = logging.getLogger(__name__) -async def parallel_process( +# HINT(toms): DESIGN Is this really needed? +# This looks like it's a bit overengineered. On the other side, it +# contains the original item of the problem. +# +# Alternative implementation +# Maybe it's enough to add additional args to the exception(s) +# to hold the original item? Something like: +# >>> ValueError("the error message", item) +class Result[R](ABC): + """Abstract base class for the result of a task.""" + + @abstractmethod + def __init__(self) -> None: + pass + + +@dataclass(frozen=True) +class Success[R](Result[R]): + """Represents a successful task result.""" + + result: R + + +@dataclass(frozen=True) +class Failure[R, T](Result[R]): + """Represents a failed task result.""" + + item: T | None = None + exception: Exception | None = None + + +async def map_concurrent( items: Iterable[T], worker_fn: Callable[[T], Awaitable[R]], - *, limit: int, - return_exceptions: bool = False, - name: str | None = None, -) -> list[R | Exception]: - """Process a list of items in parallel using a fixed number of workers. +) -> list[Result[R]]: + """Apply an async worker function to an iterable of items concurrently. + + This function uses a producer-consumer model with a bounded number of + concurrent workers managed by an asyncio.TaskGroup. It always waits for + all tasks to complete. :param items: An iterable of items to process. :param worker_fn: An async function that processes a single item. - :param limit: The maximum number of concurrent workers. - :param return_exceptions: If True, exceptions are returned as results - instead of raised. - :param name: Optional name for the task. - :return: A list of results (unordered unless you track indices). + :param limit: The maximum number of concurrent workers (consumers). + :return: A list of Success or Failure result objects. The order is not guaranteed. """ - queue: asyncio.Queue[T] = asyncio.Queue() - results: list[R | Exception] = [] + queue: asyncio.Queue[T | None] = asyncio.Queue() + results: list[Result[R]] = [] - # 1. 
Populate Queue - for item in items: - queue.put_nowait(item) + async def producer() -> None: + for item in items: + await queue.put(item) + # After producing all items, send "poison pills" to consumers + for _ in range(limit): + await queue.put(None) - # 2. Define Worker - async def worker() -> None: + async def consumer() -> None: while True: - try: - item = await queue.get() - except asyncio.CancelledError: - return + item = await queue.get() + if item is None: + # "Poison pill" received, exit the loop + break try: - res = await worker_fn(item) - results.append(res) + result_val = await worker_fn(item) + # TODO: What do add here? + results.append(Success(result=result_val)) except Exception as e: - if return_exceptions: - results.append(e) - else: - log.error("Worker failed: %s", e) - # Optional: Cancel all other workers here if you want fail-fast - finally: - queue.task_done() - - # 3. Start Workers - workers = [asyncio.create_task(worker(), name=name) - for _ in range(limit)] - - # 4. Wait & Cleanup - await queue.join() - for w in workers: - w.cancel() - await asyncio.gather(*workers, return_exceptions=True) + # TODO: What do add here? + results.append(Failure(item=item, exception=e)) + + # Note: asyncio.TaskGroup requires Python 3.11+ + async with asyncio.TaskGroup() as tg: + tg.create_task(producer()) + for _ in range(limit): + tg.create_task(consumer()) return results if __name__ == "__main__": - import random import time logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") async def sample_worker(num: int) -> int: """Create a simple worker that simulates some I/O-bound work.""" + if num in (5, 8): + log.warning("Simulating failure for item %d", num) + # HINT: This is the "alternative" implementation. + # Instead of having a Failure class, we just raise the exception + # and add the item into the exception as an additional metadata + raise ValueError("Item 5 is not allowed!", num) + # Alternative: + # raise ValueError("Item 5 is not allowed!", {"item": num}) + log.info("Processing item %d", num) - x = random.randint(0, 10) - if x in(1, 5, 6): - raise ValueError("Oh no! 
Wrong value!") - await asyncio.sleep(0.1* x) # Simulate I/O delay + await asyncio.sleep(0.1) # Simulate I/O delay return num * 2 async def main() -> None: """Run the example.""" - start_time = time.monotonic() + items_to_process = list(range(10)) - log.info("Starting parallel processing with a limit of 3 workers...") - results = await parallel_process( - range(20), - sample_worker, - limit=5) + log.info("--- Running map_concurrent ---") + start_time = time.monotonic() + task_results = await map_concurrent(items_to_process, sample_worker, limit=3) end_time = time.monotonic() + log.info("Finished in %.2f seconds\n", end_time - start_time) - log.info("Processing finished in %.2f seconds", end_time - start_time) - log.info("Results (unordered): %s", results) + successful_results = [] + failed_tasks = [] + for res in task_results: + match res: + case Success(item=i, result=r): + successful_results.append((i, res)) + case Failure(item=i, exception=e): + failed_tasks.append(res ) # , i, e)) + print(">>", e.args) - print("Starting example...") - asyncio.run(main()) - print("Example finished.") + log.info("Successful results (unordered): %s", (successful_results)) + log.info("Caught exceptions: %s", failed_tasks) + asyncio.run(main()) From fab651d976b2858b944b701215c1dcc67bb680f1 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Wed, 25 Feb 2026 13:05:18 +0100 Subject: [PATCH 03/26] Simplify the producer/consumer model * Rename function to `process_unordered` * Remove Result, Success, and Failure * For the result, just add the result from the worker or the exception. --- src/docbuild/utils/concurrency.py | 103 ++++++++++++------------------ 1 file changed, 41 insertions(+), 62 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 0fb88cc4..59d9e83b 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -1,7 +1,7 @@ -from abc import ABC, abstractmethod +"""Concurrency utilities.""" + import asyncio from collections.abc import Awaitable, Callable, Iterable -from dataclasses import dataclass import logging from typing import TypeVar @@ -11,85 +11,64 @@ log = logging.getLogger(__name__) -# HINT(toms): DESIGN Is this really needed? -# This looks like it's a bit overengineered. On the other side, it -# contains the original item of the problem. -# -# Alternative implementation -# Maybe it's enough to add additional args to the exception(s) -# to hold the original item? Something like: -# >>> ValueError("the error message", item) -class Result[R](ABC): - """Abstract base class for the result of a task.""" - - @abstractmethod - def __init__(self) -> None: - pass - - -@dataclass(frozen=True) -class Success[R](Result[R]): - """Represents a successful task result.""" - - result: R - - -@dataclass(frozen=True) -class Failure[R, T](Result[R]): - """Represents a failed task result.""" - - item: T | None = None - exception: Exception | None = None - - -async def map_concurrent( +async def process_unordered( items: Iterable[T], worker_fn: Callable[[T], Awaitable[R]], limit: int, -) -> list[Result[R]]: - """Apply an async worker function to an iterable of items concurrently. +) -> list[R | Exception]: + """Process items concurrently with a worker limit. - This function uses a producer-consumer model with a bounded number of - concurrent workers managed by an asyncio.TaskGroup. It always waits for - all tasks to complete. + Uses a producer-consumer model via asyncio.TaskGroup. + Order of results is NOT guaranteed. 
+ If an exception occurs, the exception object is returned in the list. + The original item is attached to the exception as `e.item`. - :param items: An iterable of items to process. - :param worker_fn: An async function that processes a single item. - :param limit: The maximum number of concurrent workers (consumers). - :return: A list of Success or Failure result objects. The order is not guaranteed. + :param items: Iterable of items to process. + :param worker_fn: Async function processing a single item. + :param limit: Max concurrent workers. """ - queue: asyncio.Queue[T | None] = asyncio.Queue() - results: list[Result[R]] = [] + # Limit queue size to prevent memory explosion if producer is faster than consumers + queue: asyncio.Queue[T | None] = asyncio.Queue(maxsize=limit * 2) + results: list[R | Exception] = [] async def producer() -> None: for item in items: await queue.put(item) - # After producing all items, send "poison pills" to consumers - for _ in range(limit): - await queue.put(None) async def consumer() -> None: while True: item = await queue.get() if item is None: - # "Poison pill" received, exit the loop + queue.task_done() break try: result_val = await worker_fn(item) - # TODO: What do add here? - results.append(Success(result=result_val)) + results.append(result_val) except Exception as e: - # TODO: What do add here? - results.append(Failure(item=item, exception=e)) + # Attach the item to the exception for tracking + e.item = item # type: ignore[attr-defined] + results.append(e) + + finally: + queue.task_done() - # Note: asyncio.TaskGroup requires Python 3.11+ async with asyncio.TaskGroup() as tg: - tg.create_task(producer()) + # Start consumers for _ in range(limit): tg.create_task(consumer()) + # Push items (blocks if queue is full, providing backpressure) + await producer() + + # Signal shutdown + for _ in range(limit): + await queue.put(None) + + # Wait for all items to be processed + await queue.join() + return results @@ -117,9 +96,9 @@ async def main() -> None: """Run the example.""" items_to_process = list(range(10)) - log.info("--- Running map_concurrent ---") + log.info("--- Running process_unordered ---") start_time = time.monotonic() - task_results = await map_concurrent(items_to_process, sample_worker, limit=3) + task_results = await process_unordered(items_to_process, sample_worker, limit=3) end_time = time.monotonic() log.info("Finished in %.2f seconds\n", end_time - start_time) @@ -127,11 +106,11 @@ async def main() -> None: failed_tasks = [] for res in task_results: match res: - case Success(item=i, result=r): - successful_results.append((i, res)) - case Failure(item=i, exception=e): - failed_tasks.append(res ) # , i, e)) - print(">>", e.args) + case Exception(item=i): + failed_tasks.append((i, res)) + case _: + # Order lost, but we have results + successful_results.append(res) log.info("Successful results (unordered): %s", (successful_results)) log.info("Caught exceptions: %s", failed_tasks) From 7735a6cb14e4cb78d0c676fa8452e254c01b102d Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Wed, 25 Feb 2026 14:31:42 +0100 Subject: [PATCH 04/26] Replace Exception -> TaskFailedError --- src/docbuild/utils/concurrency.py | 77 +++++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 18 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 59d9e83b..6d5f6f77 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -3,25 +3,32 @@ import asyncio from collections.abc import 
Awaitable, Callable, Iterable import logging -from typing import TypeVar - -T = TypeVar("T") # Input type -R = TypeVar("R") # Result type log = logging.getLogger(__name__) -async def process_unordered( +class TaskFailedError[T](Exception): + """Exception raised when a task fails during processing. + + :param item: The item that was being processed. + :param original_exception: The exception that caused the failure. + """ + + def __init__(self, item: T, original_exception: Exception) -> None: + super().__init__(f"Task failed for item {item}: {original_exception}") + self.item = item + self.original_exception = original_exception + + items: Iterable[T], worker_fn: Callable[[T], Awaitable[R]], limit: int, -) -> list[R | Exception]: +) -> list[R | TaskFailedError[T]]: """Process items concurrently with a worker limit. Uses a producer-consumer model via asyncio.TaskGroup. Order of results is NOT guaranteed. - If an exception occurs, the exception object is returned in the list. - The original item is attached to the exception as `e.item`. + If an exception occurs, it is wrapped in `TaskFailedError`. :param items: Iterable of items to process. :param worker_fn: Async function processing a single item. @@ -29,7 +36,7 @@ async def process_unordered( """ # Limit queue size to prevent memory explosion if producer is faster than consumers queue: asyncio.Queue[T | None] = asyncio.Queue(maxsize=limit * 2) - results: list[R | Exception] = [] + results: list[R | TaskFailedError[T]] = [] async def producer() -> None: for item in items: @@ -47,9 +54,8 @@ async def consumer() -> None: results.append(result_val) except Exception as e: - # Attach the item to the exception for tracking - e.item = item # type: ignore[attr-defined] - results.append(e) + # Wrap the exception in TaskFailedError + results.append(TaskFailedError(item, e)) finally: queue.task_done() @@ -92,6 +98,12 @@ async def sample_worker(num: int) -> int: await asyncio.sleep(0.1) # Simulate I/O delay return num * 2 + # Make process intensive tasks in a executor + # 1. Define the heavy lifting function (must be at module level for pickle) + def heavy_cpu_math(item: int) -> int: + """Simulate a CPU-bound task.""" + return item * item + async def main() -> None: """Run the example.""" items_to_process = list(range(10)) @@ -105,12 +117,41 @@ async def main() -> None: successful_results = [] failed_tasks = [] for res in task_results: - match res: - case Exception(item=i): - failed_tasks.append((i, res)) - case _: - # Order lost, but we have results - successful_results.append(res) + if isinstance(res, TaskFailedError): + failed_tasks.append((res.item, res.original_exception)) + else: + successful_results.append(res) + + log.info("Successful results (unordered): %s", (successful_results)) + log.info("Caught exceptions: %s", failed_tasks) + + ## ------------------- + log.info("--- Running process executor ---") + from concurrent.futures import ProcessPoolExecutor + + # 2. Create the wrapper + async def cpu_worker_wrapper(item, executor: None|ProcessPoolExecutor=None) -> int: + loop = asyncio.get_running_loop() + # Use the passed executor + return await loop.run_in_executor(executor, heavy_cpu_math, item) + + # 3. 
Use your existing utility with the executor passed as a kwarg + items = range(10) + with ProcessPoolExecutor() as process_pool: + results = await process_unordered( + items, + cpu_worker_wrapper, + limit=4, + executor=process_pool + ) + + successful_results = [] + failed_tasks = [] + for res in results: + if isinstance(res, TaskFailedError): + failed_tasks.append((res.item, res.original_exception)) + else: + successful_results.append(res) log.info("Successful results (unordered): %s", (successful_results)) log.info("Caught exceptions: %s", failed_tasks) From 67c8a176686b2926137347b698a06d2e63aaa072 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Wed, 25 Feb 2026 14:36:24 +0100 Subject: [PATCH 05/26] Use ParamSpec, extend worker_fn --- src/docbuild/utils/concurrency.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 6d5f6f77..3e673582 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -3,6 +3,7 @@ import asyncio from collections.abc import Awaitable, Callable, Iterable import logging +from typing import Concatenate, ParamSpec log = logging.getLogger(__name__) @@ -20,9 +21,12 @@ def __init__(self, item: T, original_exception: Exception) -> None: self.original_exception = original_exception +async def process_unordered[T, R, **P]( items: Iterable[T], - worker_fn: Callable[[T], Awaitable[R]], + worker_fn: Callable[Concatenate[T, P], Awaitable[R]], limit: int, + *worker_args: P.args, + **worker_kwargs: P.kwargs, ) -> list[R | TaskFailedError[T]]: """Process items concurrently with a worker limit. @@ -32,7 +36,10 @@ def __init__(self, item: T, original_exception: Exception) -> None: :param items: Iterable of items to process. :param worker_fn: Async function processing a single item. + Result signature: `worker_fn(item, *worker_args, **worker_kwargs)`. :param limit: Max concurrent workers. + :param worker_args: Additional positional arguments passed to `worker_fn`. + :param worker_kwargs: Additional keyword arguments passed to `worker_fn`. 
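+
+    Example (an illustrative sketch; ``fetch`` and its ``delay`` keyword
+    are made up for this note, not part of the original patch)::
+
+        async def fetch(item: int, *, delay: float = 0.01) -> int:
+            await asyncio.sleep(delay)  # stand-in for real I/O work
+            return item * 2
+
+        # 'delay' is forwarded to every fetch() call via **worker_kwargs
+        results = await process_unordered(range(5), fetch, limit=2, delay=0.01)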
""" # Limit queue size to prevent memory explosion if producer is faster than consumers queue: asyncio.Queue[T | None] = asyncio.Queue(maxsize=limit * 2) @@ -50,7 +57,7 @@ async def consumer() -> None: break try: - result_val = await worker_fn(item) + result_val = await worker_fn(item, *worker_args, **worker_kwargs) results.append(result_val) except Exception as e: From 7ebaedfd9d4519c22a3ae4a6a8e675463861d9bf Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Wed, 25 Feb 2026 14:39:28 +0100 Subject: [PATCH 06/26] Add tests --- tests/utils/test_concurrency.py | 98 +++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 tests/utils/test_concurrency.py diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py new file mode 100644 index 00000000..4375dab2 --- /dev/null +++ b/tests/utils/test_concurrency.py @@ -0,0 +1,98 @@ +"""Tests for concurrency utilities.""" + +import asyncio +import logging + +import pytest + +from docbuild.utils.concurrency import TaskFailedError, process_unordered + + +async def test_process_unordered_basic(): + """Test basic parallel processing of a list of numbers.""" + async def square(n: int) -> int: + await asyncio.sleep(0.01) + return n * n + + items = [1, 2, 3, 4, 5] + result = await process_unordered(items, square, limit=2) + + val_set = set() + for r in result: + assert isinstance(r, int) + val_set.add(r) + + assert val_set == {1, 4, 9, 16, 25} + + +async def test_process_unordered_concurrency_limit(): + """Verify that concurrency limit is respected.""" + active_workers = 0 + max_active = 0 + lock = asyncio.Lock() + + async def track_concurrency(n: int) -> int: + nonlocal active_workers, max_active + async with lock: + active_workers += 1 + max_active = max(max_active, active_workers) + + await asyncio.sleep(0.05) + + async with lock: + active_workers -= 1 + return n + + items = range(10) + limit = 3 + # Use higher limit in worker fn but restrict at call site + await process_unordered(items, track_concurrency, limit=limit) + + assert max_active <= limit + + +async def test_process_unordered_exceptions(): + """Test exception handling returning TaskFailedError.""" + async def fail_on_even(n: int) -> int: + if n % 2 == 0: + raise ValueError(f"Even number: n={n}") + return n + + items = [1, 2, 3] + results = await process_unordered(items, fail_on_even, limit=2) + + assert len(results) == 3 + + success_vals = [] + failed_items = [] + + for r in results: + match r: + case TaskFailedError(item=item, original_exception=exc): + failed_items.append(item) + assert isinstance(exc, ValueError) + case _: + success_vals.append(r) + + assert set(success_vals) == {1, 3} + assert failed_items == [2] + + +async def test_process_unordered_empty(): + """Test processing an empty list.""" + async def identity(n): return n + results = await process_unordered([], identity, limit=5) + assert results == [] + + +async def test_process_unordered_kwargs(): + """Test passing kwargs to worker function.""" + async def multiply(n: int, factor: int = 1) -> int: + return n * factor + + items = [1, 2, 3] + results = await process_unordered(items, multiply, limit=2, factor=3) + + # We might get exceptions if anything failed, but expecting ints + int_results = [r for r in results if isinstance(r, int)] + assert set(int_results) == {3, 6, 9} From 7064c2c1450382588c168b50154aa457700118c9 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Wed, 25 Feb 2026 14:50:29 +0100 Subject: [PATCH 07/26] Fix Ruff errors --- src/docbuild/utils/concurrency.py | 8 
+++++--- tests/utils/test_concurrency.py | 3 --- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 3e673582..2adfd94a 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Awaitable, Callable, Iterable import logging -from typing import Concatenate, ParamSpec +from typing import Concatenate log = logging.getLogger(__name__) @@ -134,10 +134,12 @@ async def main() -> None: ## ------------------- log.info("--- Running process executor ---") - from concurrent.futures import ProcessPoolExecutor + from concurrent.futures import Executor, ProcessPoolExecutor # 2. Create the wrapper - async def cpu_worker_wrapper(item, executor: None|ProcessPoolExecutor=None) -> int: + async def cpu_worker_wrapper( + item: int, executor: Executor | None = None + ) -> int: loop = asyncio.get_running_loop() # Use the passed executor return await loop.run_in_executor(executor, heavy_cpu_math, item) diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py index 4375dab2..68a97d94 100644 --- a/tests/utils/test_concurrency.py +++ b/tests/utils/test_concurrency.py @@ -1,9 +1,6 @@ """Tests for concurrency utilities.""" import asyncio -import logging - -import pytest from docbuild.utils.concurrency import TaskFailedError, process_unordered From 3f1414750a6af99b8799d497b96b23fc32179b84 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Wed, 25 Feb 2026 15:22:09 +0100 Subject: [PATCH 08/26] Improve docstring --- src/docbuild/utils/concurrency.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 2adfd94a..ebfe6633 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -1,4 +1,11 @@ -"""Concurrency utilities.""" +"""Concurrency utilities using producer-consumer patterns. + +This module provides helpers for managing concurrent asyncio tasks with +strict concurrency limits, backpressure handling, and robust exception tracking. + +It is designed to handle both I/O-bound tasks (via native asyncio coroutines) and +CPU-bound tasks (via `loop.run_in_executor`) while keeping resource usage deterministic. +""" import asyncio from collections.abc import Awaitable, Callable, Iterable @@ -11,6 +18,11 @@ class TaskFailedError[T](Exception): """Exception raised when a task fails during processing. + This wrapper preserves the context of a failure in concurrent processing pipelines. + Since results may be returned out of order or aggregated later, wrapping the + exception allows the caller to link a failure back to the specific input item + that caused it. + :param item: The item that was being processed. :param original_exception: The exception that caused the failure. """ @@ -32,14 +44,14 @@ async def process_unordered[T, R, **P]( Uses a producer-consumer model via asyncio.TaskGroup. Order of results is NOT guaranteed. - If an exception occurs, it is wrapped in `TaskFailedError`. + If an exception occurs, it is wrapped in :class:`~docbuild.utils.concurrency.TaskFailedError`. :param items: Iterable of items to process. :param worker_fn: Async function processing a single item. - Result signature: `worker_fn(item, *worker_args, **worker_kwargs)`. + Result signature: ``worker_fn(item, *worker_args, **worker_kwargs)``. :param limit: Max concurrent workers. 
- :param worker_args: Additional positional arguments passed to `worker_fn`. - :param worker_kwargs: Additional keyword arguments passed to `worker_fn`. + :param worker_args: Additional positional arguments passed to ``worker_fn``. + :param worker_kwargs: Additional keyword arguments passed to ``worker_fn``. """ # Limit queue size to prevent memory explosion if producer is faster than consumers queue: asyncio.Queue[T | None] = asyncio.Queue(maxsize=limit * 2) From 432bd3b248747999b1de15bfd44621e13edbaee5 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Wed, 25 Feb 2026 15:23:27 +0100 Subject: [PATCH 09/26] Add newsfragment --- changelog.d/191.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/191.feature.rst diff --git a/changelog.d/191.feature.rst b/changelog.d/191.feature.rst new file mode 100644 index 00000000..fa198f44 --- /dev/null +++ b/changelog.d/191.feature.rst @@ -0,0 +1 @@ +Add a producer/consumer model implementation. From cd792ffa81427ec8d8ee26d447f7c05ef1ea1702 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Fri, 27 Feb 2026 13:49:14 +0100 Subject: [PATCH 10/26] Rename function to run_parallel --- src/docbuild/utils/concurrency.py | 6 +++--- tests/utils/test_concurrency.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index ebfe6633..2fedf36a 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -33,7 +33,7 @@ def __init__(self, item: T, original_exception: Exception) -> None: self.original_exception = original_exception -async def process_unordered[T, R, **P]( +async def run_parallel[T, R, **P]( items: Iterable[T], worker_fn: Callable[Concatenate[T, P], Awaitable[R]], limit: int, @@ -129,7 +129,7 @@ async def main() -> None: log.info("--- Running process_unordered ---") start_time = time.monotonic() - task_results = await process_unordered(items_to_process, sample_worker, limit=3) + task_results = await run_parallel(items_to_process, sample_worker, limit=3) end_time = time.monotonic() log.info("Finished in %.2f seconds\n", end_time - start_time) @@ -159,7 +159,7 @@ async def cpu_worker_wrapper( # 3. 
Use your existing utility with the executor passed as a kwarg items = range(10) with ProcessPoolExecutor() as process_pool: - results = await process_unordered( + results = await run_parallel( items, cpu_worker_wrapper, limit=4, diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py index 68a97d94..dd6b185d 100644 --- a/tests/utils/test_concurrency.py +++ b/tests/utils/test_concurrency.py @@ -2,7 +2,7 @@ import asyncio -from docbuild.utils.concurrency import TaskFailedError, process_unordered +from docbuild.utils.concurrency import TaskFailedError, run_parallel async def test_process_unordered_basic(): @@ -12,7 +12,7 @@ async def square(n: int) -> int: return n * n items = [1, 2, 3, 4, 5] - result = await process_unordered(items, square, limit=2) + result = await run_parallel(items, square, limit=2) val_set = set() for r in result: @@ -43,7 +43,7 @@ async def track_concurrency(n: int) -> int: items = range(10) limit = 3 # Use higher limit in worker fn but restrict at call site - await process_unordered(items, track_concurrency, limit=limit) + await run_parallel(items, track_concurrency, limit=limit) assert max_active <= limit @@ -56,7 +56,7 @@ async def fail_on_even(n: int) -> int: return n items = [1, 2, 3] - results = await process_unordered(items, fail_on_even, limit=2) + results = await run_parallel(items, fail_on_even, limit=2) assert len(results) == 3 @@ -78,7 +78,7 @@ async def fail_on_even(n: int) -> int: async def test_process_unordered_empty(): """Test processing an empty list.""" async def identity(n): return n - results = await process_unordered([], identity, limit=5) + results = await run_parallel([], identity, limit=5) assert results == [] @@ -88,7 +88,7 @@ async def multiply(n: int, factor: int = 1) -> int: return n * factor items = [1, 2, 3] - results = await process_unordered(items, multiply, limit=2, factor=3) + results = await run_parallel(items, multiply, limit=2, factor=3) # We might get exceptions if anything failed, but expecting ints int_results = [r for r in results if isinstance(r, int)] From 6e1b093300350e022c16ffd5d0c6888eea180dfd Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Fri, 27 Feb 2026 15:08:55 +0100 Subject: [PATCH 11/26] Make it more scalable --- src/docbuild/utils/concurrency.py | 114 +++++++++++++++++------------- 1 file changed, 65 insertions(+), 49 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 2fedf36a..19f2d65b 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -8,7 +8,13 @@ """ import asyncio -from collections.abc import Awaitable, Callable, Iterable +from collections.abc import ( + AsyncIterable as AsyncIterableABC, + AsyncIterator, + Awaitable, + Callable, + Iterable, +) import logging from typing import Concatenate @@ -34,12 +40,12 @@ def __init__(self, item: T, original_exception: Exception) -> None: async def run_parallel[T, R, **P]( - items: Iterable[T], + items: Iterable[T] | AsyncIterableABC[T], worker_fn: Callable[Concatenate[T, P], Awaitable[R]], limit: int, *worker_args: P.args, **worker_kwargs: P.kwargs, -) -> list[R | TaskFailedError[T]]: +) -> AsyncIterator[R | TaskFailedError[T]]: """Process items concurrently with a worker limit. Uses a producer-consumer model via asyncio.TaskGroup. @@ -53,48 +59,57 @@ async def run_parallel[T, R, **P]( :param worker_args: Additional positional arguments passed to ``worker_fn``. :param worker_kwargs: Additional keyword arguments passed to ``worker_fn``. 
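+
+    Example of consuming the stream (an illustrative sketch, not part of
+    the original change; ``fetch`` is a made-up worker)::
+
+        async def fetch(item: int) -> int:
+            await asyncio.sleep(0.1)  # simulate I/O
+            return item * 2
+
+        async for res in run_parallel(range(10), fetch, limit=3):
+            if isinstance(res, TaskFailedError):
+                log.warning("item %r failed: %s", res.item, res.original_exception)
+            else:
+                log.info("result: %r", res)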
""" - # Limit queue size to prevent memory explosion if producer is faster than consumers - queue: asyncio.Queue[T | None] = asyncio.Queue(maxsize=limit * 2) - results: list[R | TaskFailedError[T]] = [] + if limit <= 0: + raise ValueError("limit must be >= 1") - async def producer() -> None: - for item in items: - await queue.put(item) + result_queue: asyncio.Queue[R | TaskFailedError[T]] = asyncio.Queue() + iterator_lock = asyncio.Lock() - async def consumer() -> None: + # Normalize iterable + if isinstance(items, AsyncIterableABC): + async_iter = items + + else: + async def async_wrapper(): + for item in items: + yield item + + async_iter = async_wrapper() + + async def worker(): while True: - item = await queue.get() - if item is None: - queue.task_done() - break + async with iterator_lock: + try: + item = await anext(async_iter) + except StopAsyncIteration: + return try: - result_val = await worker_fn(item, *worker_args, **worker_kwargs) - results.append(result_val) + result = await worker_fn(item, *worker_args, **worker_kwargs) + await result_queue.put(result) - except Exception as e: - # Wrap the exception in TaskFailedError - results.append(TaskFailedError(item, e)) + except asyncio.CancelledError: + raise - finally: - queue.task_done() + except Exception as e: + await result_queue.put(TaskFailedError(item, e)) async with asyncio.TaskGroup() as tg: - # Start consumers - for _ in range(limit): - tg.create_task(consumer()) - - # Push items (blocks if queue is full, providing backpressure) - await producer() - - # Signal shutdown for _ in range(limit): - await queue.put(None) + tg.create_task(worker()) - # Wait for all items to be processed - await queue.join() + active_workers = limit + while active_workers > 0: + try: + result = await asyncio.wait_for(result_queue.get(), timeout=0.1) + yield result - return results + except asyncio.TimeoutError: + # Check if workers done + active_workers = sum( + not t.done() + for t in tg._tasks # internal but safe in practice + ) if __name__ == "__main__": @@ -125,17 +140,22 @@ def heavy_cpu_math(item: int) -> int: async def main() -> None: """Run the example.""" - items_to_process = list(range(10)) + async def generate_items() -> AsyncIterableABC[int]: + for i in range(10): + yield i + # yield from range(10) log.info("--- Running process_unordered ---") start_time = time.monotonic() - task_results = await run_parallel(items_to_process, sample_worker, limit=3) + task_results = ( + res async for res in run_parallel(generate_items(), sample_worker, limit=3) + ) end_time = time.monotonic() log.info("Finished in %.2f seconds\n", end_time - start_time) successful_results = [] failed_tasks = [] - for res in task_results: + async for res in task_results: if isinstance(res, TaskFailedError): failed_tasks.append((res.item, res.original_exception)) else: @@ -159,20 +179,16 @@ async def cpu_worker_wrapper( # 3. 
Use your existing utility with the executor passed as a kwarg items = range(10) with ProcessPoolExecutor() as process_pool: - results = await run_parallel( - items, - cpu_worker_wrapper, - limit=4, - executor=process_pool - ) - successful_results = [] - failed_tasks = [] - for res in results: - if isinstance(res, TaskFailedError): - failed_tasks.append((res.item, res.original_exception)) - else: - successful_results.append(res) + successful_results = [] + failed_tasks = [] + async for res in run_parallel( + items, cpu_worker_wrapper, limit=4, executor=process_pool + ): + if isinstance(res, TaskFailedError): + failed_tasks.append((res.item, res.original_exception)) + else: + successful_results.append(res) log.info("Successful results (unordered): %s", (successful_results)) log.info("Caught exceptions: %s", failed_tasks) From 02ead56888a6804d5c189589310aa5c7c44a5d1b Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Fri, 27 Feb 2026 15:42:02 +0100 Subject: [PATCH 12/26] Split it in different subfunction --- src/docbuild/utils/concurrency.py | 187 +++++++++++++++++++++--------- 1 file changed, 129 insertions(+), 58 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 19f2d65b..105fdea3 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -15,11 +15,15 @@ Callable, Iterable, ) +import functools import logging -from typing import Concatenate +from typing import Any, Concatenate log = logging.getLogger(__name__) +#: Sentinel value for internal use when needed (e.g., to signal completion). +SENTINEL = object() + class TaskFailedError[T](Exception): """Exception raised when a task fails during processing. @@ -39,77 +43,144 @@ def __init__(self, item: T, original_exception: Exception) -> None: self.original_exception = original_exception -async def run_parallel[T, R, **P]( +async def producer[T]( + items: Iterable[T] | AsyncIterableABC[T], + input_queue: asyncio.Queue, + num_workers: int, +) -> None: + """Feed items into the input queue, then send one sentinel per worker.""" + try: + if isinstance(items, AsyncIterableABC): + async for item in items: + await input_queue.put(item) + else: + for item in items: + await input_queue.put(item) + + finally: + for _ in range(num_workers): + await input_queue.put(SENTINEL) + + +async def worker[T, R]( + worker_fn: Callable[[T], Awaitable[R]], + input_queue: asyncio.Queue, + result_queue: asyncio.Queue, +) -> None: + """Pull items from the input queue, process them, push results out.""" + while True: + item = await input_queue.get() + if item is SENTINEL: + return + try: + result = await worker_fn(item) + await result_queue.put(result) + except asyncio.CancelledError: + raise + except Exception as exc: + await result_queue.put(TaskFailedError(item, exc)) + + finally: + input_queue.task_done() + + +async def run_all[T, R]( items: Iterable[T] | AsyncIterableABC[T], - worker_fn: Callable[Concatenate[T, P], Awaitable[R]], + worker_fn: Callable[[T], Awaitable[R]], + input_queue: asyncio.Queue, + result_queue: asyncio.Queue, limit: int, - *worker_args: P.args, - **worker_kwargs: P.kwargs, -) -> AsyncIterator[R | TaskFailedError[T]]: - """Process items concurrently with a worker limit. 
+) -> None: + """Orchestrate producer + workers, then signal the consumer when done.""" + async with asyncio.TaskGroup() as tg: + tg.create_task(producer(items, input_queue, limit)) + for _ in range(limit): + tg.create_task(worker(worker_fn, input_queue, result_queue)) - Uses a producer-consumer model via asyncio.TaskGroup. - Order of results is NOT guaranteed. - If an exception occurs, it is wrapped in :class:`~docbuild.utils.concurrency.TaskFailedError`. - - :param items: Iterable of items to process. - :param worker_fn: Async function processing a single item. - Result signature: ``worker_fn(item, *worker_args, **worker_kwargs)``. - :param limit: Max concurrent workers. - :param worker_args: Additional positional arguments passed to ``worker_fn``. - :param worker_kwargs: Additional keyword arguments passed to ``worker_fn``. - """ - if limit <= 0: - raise ValueError("limit must be >= 1") + await result_queue.put(SENTINEL) - result_queue: asyncio.Queue[R | TaskFailedError[T]] = asyncio.Queue() - iterator_lock = asyncio.Lock() - # Normalize iterable - if isinstance(items, AsyncIterableABC): - async_iter = items +async def run_parallel[T, R]( + items: Iterable[T] | AsyncIterableABC[T], + worker_fn: Callable[[T], Awaitable[R]], + limit: int, + **worker_kwargs: Any, # noqa: ANN401 +) -> AsyncIterator[R | TaskFailedError[T]]: + """Process items concurrently with bounded parallelism. - else: - async def async_wrapper(): - for item in items: - yield item + Uses a producer/worker/consumer pipeline: - async_iter = async_wrapper() + - A single **producer** task feeds items into a bounded input queue. + - ``limit`` **worker** tasks pull from the input queue, call ``worker_fn``, + and push results into a bounded result queue. + - The **caller** consumes results by iterating this async generator. - async def worker(): - while True: - async with iterator_lock: - try: - item = await anext(async_iter) - except StopAsyncIteration: - return + All three stages run concurrently. Backpressure propagates naturally: + a slow consumer stalls workers; stalled workers stall the producer. + Order of results is NOT guaranteed. - try: - result = await worker_fn(item, *worker_args, **worker_kwargs) - await result_queue.put(result) + If ``worker_fn`` raises, the exception is wrapped in + :class:`TaskFailedError` and yielded rather than re-raised, so one + failing item does not abort the pipeline. + + Performance characteristics + --------------------------- + - **Throughput:** approaches ``limit × per-worker throughput`` for + I/O-bound workloads where workers spend most time awaiting external + resources. CPU-bound work gains little due to the GIL; use + ``ProcessPoolExecutor`` wrapped in ``asyncio.run_in_executor`` instead. + - **Startup cost:** O(limit) — one asyncio task per worker, each cheap + to create (~microseconds). + - **Memory:** O(limit). Both the input queue (``maxsize=limit * 2``) + and the result queue (``maxsize=limit * 2``) are bounded. At most + ``limit`` items are in-flight inside workers at any time, giving a + total live-item count of roughly ``5 × limit``. + Note: each item itself may be arbitrarily large; the O(limit) bound + refers to the *number* of items held in memory, not their byte size. + - **Latency:** time-to-first-result equals one worker's latency. + Remaining results stream out as workers complete, with no polling + delay (sentinel-based signalling, zero busy-wait). + - **Cancellation:** if the caller abandons the generator (e.g. 
``break`` + in an ``async for`` loop), the internal runner task is cancelled and + all worker tasks are cleaned up promptly via ``TaskGroup``. + + :param items: Iterable or async iterable of items to process. + :param worker_fn: Async callable invoked as ``worker_fn(item)`` for + each item. Must be safe to call concurrently from ``limit`` tasks. + :param limit: Maximum number of concurrent workers. Must be >= 1. + Higher values increase throughput up to the point where the event + loop, network, or downstream service becomes the bottleneck. + :raises ValueError: If ``limit`` is less than 1. + :yields: Results in completion order (not input order). Failed items + are yielded as :class:`TaskFailedError` instances rather than + raising, so the caller can handle partial failures inline. + """ + if limit <= 0: + raise ValueError("limit must be >= 1") - except asyncio.CancelledError: - raise + bound_fn = ( + functools.partial(worker_fn, **worker_kwargs) if worker_kwargs else worker_fn + ) - except Exception as e: - await result_queue.put(TaskFailedError(item, e)) + input_queue: asyncio.Queue[T | object] = asyncio.Queue(maxsize=limit * 2) + result_queue: asyncio.Queue[R | TaskFailedError[T] | object] = asyncio.Queue() - async with asyncio.TaskGroup() as tg: - for _ in range(limit): - tg.create_task(worker()) + runner = asyncio.create_task(run_all(items, bound_fn, input_queue, result_queue, limit)) - active_workers = limit - while active_workers > 0: + try: + while True: + result = await result_queue.get() + if result is SENTINEL: + break + yield result # type: ignore[misc] + + finally: + if not runner.done(): + runner.cancel() try: - result = await asyncio.wait_for(result_queue.get(), timeout=0.1) - yield result - - except asyncio.TimeoutError: - # Check if workers done - active_workers = sum( - not t.done() - for t in tg._tasks # internal but safe in practice - ) + await runner + except (asyncio.CancelledError, Exception): + pass if __name__ == "__main__": From b27237939bb3375e8c7184af03e55218213ac3fd Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Mon, 2 Mar 2026 08:42:21 +0100 Subject: [PATCH 13/26] Add docstrings --- src/docbuild/utils/concurrency.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 105fdea3..f2a5958e 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -17,7 +17,7 @@ ) import functools import logging -from typing import Any, Concatenate +from typing import Any log = logging.getLogger(__name__) @@ -48,7 +48,12 @@ async def producer[T]( input_queue: asyncio.Queue, num_workers: int, ) -> None: - """Feed items into the input queue, then send one sentinel per worker.""" + """Feed items into the input queue, then send one sentinel per worker. + + :param items: An iterable or async iterable of items to be processed. + :param input_queue: The queue for items to be processed by workers. + :param num_workers: The number of workers, used to send the correct number of sentinels. + """ try: if isinstance(items, AsyncIterableABC): async for item in items: @@ -67,7 +72,12 @@ async def worker[T, R]( input_queue: asyncio.Queue, result_queue: asyncio.Queue, ) -> None: - """Pull items from the input queue, process them, push results out.""" + """Pull items from the input queue, process them, push results out. + + :param worker_fn: The asynchronous function that processes a single item. 
+ :param input_queue: The queue for items to be processed by workers. + :param result_queue: The queue for results from the workers. + """ while True: item = await input_queue.get() if item is SENTINEL: @@ -91,7 +101,14 @@ async def run_all[T, R]( result_queue: asyncio.Queue, limit: int, ) -> None: - """Orchestrate producer + workers, then signal the consumer when done.""" + """Orchestrate producer + workers, then signal the consumer when done. + + :param items: An iterable or async iterable of items to be processed. + :param worker_fn: The asynchronous function that processes a single item. + :param input_queue: The queue for items to be processed by workers. + :param result_queue: The queue for results from the workers. + :param limit: The maximum number of concurrent workers. + """ async with asyncio.TaskGroup() as tg: tg.create_task(producer(items, input_queue, limit)) for _ in range(limit): @@ -125,7 +142,7 @@ async def run_parallel[T, R]( Performance characteristics --------------------------- - - **Throughput:** approaches ``limit × per-worker throughput`` for + - **Throughput:** approaches ``limit * per-worker-throughput`` for I/O-bound workloads where workers spend most time awaiting external resources. CPU-bound work gains little due to the GIL; use ``ProcessPoolExecutor`` wrapped in ``asyncio.run_in_executor`` instead. @@ -134,7 +151,7 @@ async def run_parallel[T, R]( - **Memory:** O(limit). Both the input queue (``maxsize=limit * 2``) and the result queue (``maxsize=limit * 2``) are bounded. At most ``limit`` items are in-flight inside workers at any time, giving a - total live-item count of roughly ``5 × limit``. + total live-item count of roughly ``5 * limit``. Note: each item itself may be arbitrarily large; the O(limit) bound refers to the *number* of items held in memory, not their byte size. - **Latency:** time-to-first-result equals one worker's latency. From 95ac498687f46bba95efd17566193143c8f55eeb Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Mon, 2 Mar 2026 08:47:39 +0100 Subject: [PATCH 14/26] Fix failed tests --- tests/utils/test_concurrency.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py index dd6b185d..ed1a1014 100644 --- a/tests/utils/test_concurrency.py +++ b/tests/utils/test_concurrency.py @@ -12,10 +12,10 @@ async def square(n: int) -> int: return n * n items = [1, 2, 3, 4, 5] - result = await run_parallel(items, square, limit=2) + results_gen = run_parallel(items, square, limit=2) val_set = set() - for r in result: + async for r in results_gen: assert isinstance(r, int) val_set.add(r) @@ -42,8 +42,8 @@ async def track_concurrency(n: int) -> int: items = range(10) limit = 3 - # Use higher limit in worker fn but restrict at call site - await run_parallel(items, track_concurrency, limit=limit) + # Consume the async generator to ensure all workers run and concurrency is tracked. 
+ _ = [r async for r in run_parallel(items, track_concurrency, limit=limit)] assert max_active <= limit @@ -56,8 +56,8 @@ async def fail_on_even(n: int) -> int: return n items = [1, 2, 3] - results = await run_parallel(items, fail_on_even, limit=2) - + results_gen = run_parallel(items, fail_on_even, limit=2) + results = [r async for r in results_gen] assert len(results) == 3 success_vals = [] @@ -78,8 +78,9 @@ async def fail_on_even(n: int) -> int: async def test_process_unordered_empty(): """Test processing an empty list.""" async def identity(n): return n - results = await run_parallel([], identity, limit=5) - assert results == [] + results_gen = run_parallel([], identity, limit=5) + collected_results = [r async for r in results_gen] + assert collected_results == [] async def test_process_unordered_kwargs(): @@ -88,8 +89,8 @@ async def multiply(n: int, factor: int = 1) -> int: return n * factor items = [1, 2, 3] - results = await run_parallel(items, multiply, limit=2, factor=3) - + results_gen = run_parallel(items, multiply, limit=2, factor=3) + collected_results = [r async for r in results_gen] # We might get exceptions if anything failed, but expecting ints - int_results = [r for r in results if isinstance(r, int)] + int_results = [r for r in collected_results if isinstance(r, int)] assert set(int_results) == {3, 6, 9} From 70ef4d74ecf46f1726ad679e9077efa1e64d6f6e Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Mon, 2 Mar 2026 10:21:04 +0100 Subject: [PATCH 15/26] Use contextlib.suppress, add tests --- src/docbuild/utils/concurrency.py | 18 ++++++-- tests/utils/test_concurrency.py | 71 ++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 4 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index f2a5958e..9c035f59 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -15,6 +15,7 @@ Callable, Iterable, ) +from contextlib import suppress import functools import logging from typing import Any @@ -62,7 +63,13 @@ async def producer[T]( for item in items: await input_queue.put(item) + except asyncio.CancelledError: + # We were cancelled — workers are being cancelled too, so there's + # nobody left to consume sentinels. Don't bother sending them. + raise + finally: + # Normal completion only — workers are still running and need sentinels. for _ in range(num_workers): await input_queue.put(SENTINEL) @@ -194,10 +201,15 @@ async def run_parallel[T, R]( finally: if not runner.done(): runner.cancel() - try: + + with suppress(asyncio.CancelledError, Exception): + # Always await runner regardless of whether we cancelled it + # or it finished on its own. + # This ensures the task is fully cleaned up (no "task was + # destroyed but it is pending" warnings) and re-raises any unexpected + # exception from run_all — which we suppress here since we're + # in a cleanup path and cannot meaningfully recover. 
await runner - except (asyncio.CancelledError, Exception): - pass if __name__ == "__main__": diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py index ed1a1014..5cf7f8dd 100644 --- a/tests/utils/test_concurrency.py +++ b/tests/utils/test_concurrency.py @@ -2,16 +2,40 @@ import asyncio +import pytest + +from docbuild.utils import concurrency as concurrency_module from docbuild.utils.concurrency import TaskFailedError, run_parallel +@pytest.mark.parametrize("limit", (0, -1)) +async def test_wrong_limit(limit: int): + async def square(n: int) -> int: + await asyncio.sleep(0.01) + return n * n + + async def async_items_generator(): + for i in [1, 2, 3, 4, 5]: + yield i + + items = async_items_generator() + with pytest.raises(ValueError, match="limit must be >= 1"): + results_gen = run_parallel(items, square, limit=limit) + async for _ in results_gen: + pass + + async def test_process_unordered_basic(): """Test basic parallel processing of a list of numbers.""" async def square(n: int) -> int: await asyncio.sleep(0.01) return n * n - items = [1, 2, 3, 4, 5] + async def async_items_generator(): + for i in [1, 2, 3, 4, 5]: + yield i + + items = async_items_generator() results_gen = run_parallel(items, square, limit=2) val_set = set() @@ -83,6 +107,19 @@ async def identity(n): return n assert collected_results == [] +async def test_process_unordered_empty_async_iterable(): + """Test processing an empty async iterable, ensuring run_all completes gracefully.""" + async def async_empty_generator(): + # This async generator yields nothing, simulating an empty async iterable. + if False: + yield 1 + + async def identity(n): return n + results_gen = run_parallel(async_empty_generator(), identity, limit=5) + collected_results = [r async for r in results_gen] + assert collected_results == [] + + async def test_process_unordered_kwargs(): """Test passing kwargs to worker function.""" async def multiply(n: int, factor: int = 1) -> int: @@ -94,3 +131,35 @@ async def multiply(n: int, factor: int = 1) -> int: # We might get exceptions if anything failed, but expecting ints int_results = [r for r in collected_results if isinstance(r, int)] assert set(int_results) == {3, 6, 9} + + +async def test_finally_calls_cancel_on_early_exit(monkeypatch): + cancelled = False + original_create_task = asyncio.create_task + + def patched_create_task(coro, **kwargs): + task = original_create_task(coro, **kwargs) + original_cancel = task.cancel + + def tracking_cancel(*args, **kwargs): + nonlocal cancelled + cancelled = True + return original_cancel(*args, **kwargs) + + task.cancel = tracking_cancel + return task + + monkeypatch.setattr(concurrency_module, "asyncio", asyncio) + monkeypatch.setattr(asyncio, "create_task", patched_create_task) + + async def worker(x): + await asyncio.sleep(0) + return x + + with pytest.raises(RuntimeError): + async for _ in run_parallel(range(100), worker, limit=2): + raise RuntimeError("caller crashed") + + for _ in range(5): + await asyncio.sleep(0) + assert cancelled From 91e06feda2f815f2c0913a7f0555964d3f9e5cf4 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Mon, 2 Mar 2026 13:55:23 +0100 Subject: [PATCH 16/26] Use try..finally to ensure sentinal reaches the caller For run_parallel --- src/docbuild/utils/concurrency.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 9c035f59..3b2acf21 100644 --- a/src/docbuild/utils/concurrency.py +++ 
b/src/docbuild/utils/concurrency.py @@ -116,12 +116,14 @@ async def run_all[T, R]( :param result_queue: The queue for results from the workers. :param limit: The maximum number of concurrent workers. """ - async with asyncio.TaskGroup() as tg: - tg.create_task(producer(items, input_queue, limit)) - for _ in range(limit): - tg.create_task(worker(worker_fn, input_queue, result_queue)) + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(producer(items, input_queue, limit)) + for _ in range(limit): + tg.create_task(worker(worker_fn, input_queue, result_queue)) - await result_queue.put(SENTINEL) + finally: + await result_queue.put(SENTINEL) async def run_parallel[T, R]( From 49565e103e322ddbf8759a1a853bebe1d82f7cc4 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Mon, 2 Mar 2026 13:59:05 +0100 Subject: [PATCH 17/26] Add maxsize to result_queue to ensure backpressure propagates through the whole pipeline --- src/docbuild/utils/concurrency.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 3b2acf21..e3a6e19f 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -189,7 +189,9 @@ async def run_parallel[T, R]( ) input_queue: asyncio.Queue[T | object] = asyncio.Queue(maxsize=limit * 2) - result_queue: asyncio.Queue[R | TaskFailedError[T] | object] = asyncio.Queue() + result_queue: asyncio.Queue[R | TaskFailedError[T] | object] = asyncio.Queue( + maxsize=limit * 2 + ) runner = asyncio.create_task(run_all(items, bound_fn, input_queue, result_queue, limit)) From af643069822bf086eae353e2cf2183e16e08afa8 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Mon, 2 Mar 2026 14:04:19 +0100 Subject: [PATCH 18/26] Add ParamSpec for run_parallel --- src/docbuild/utils/concurrency.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index e3a6e19f..1baa5421 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -18,7 +18,6 @@ from contextlib import suppress import functools import logging -from typing import Any log = logging.getLogger(__name__) @@ -126,11 +125,12 @@ async def run_all[T, R]( await result_queue.put(SENTINEL) -async def run_parallel[T, R]( +async def run_parallel[T, R, **P]( items: Iterable[T] | AsyncIterableABC[T], worker_fn: Callable[[T], Awaitable[R]], limit: int, - **worker_kwargs: Any, # noqa: ANN401 + *worker_args: P.args, + **worker_kwargs: P.kwargs, ) -> AsyncIterator[R | TaskFailedError[T]]: """Process items concurrently with bounded parallelism. 
@@ -185,7 +185,7 @@ async def run_parallel[T, R](
     raise ValueError("limit must be >= 1")
 
     bound_fn = (
-        functools.partial(worker_fn, **worker_kwargs) if worker_kwargs else worker_fn
+        functools.partial(worker_fn, *worker_args, **worker_kwargs) if (worker_args or worker_kwargs) else worker_fn
     )
 
     input_queue: asyncio.Queue[T | object] = asyncio.Queue(maxsize=limit * 2)
@@ -193,7 +193,9 @@ async def run_parallel[T, R](
         maxsize=limit * 2
     )
 
-    runner = asyncio.create_task(run_all(items, bound_fn, input_queue, result_queue, limit))
+    runner = asyncio.create_task(
+        run_all(items, bound_fn, input_queue, result_queue, limit)
+    )
 
     try:
         while True:

From d79831f75ee6a9386064e2250d5ccd7d1ab23dbf Mon Sep 17 00:00:00 2001
From: Tom Schraitle
Date: Mon, 2 Mar 2026 14:08:33 +0100
Subject: [PATCH 19/26] Correct indentation after if not runner.done()

---
 src/docbuild/utils/concurrency.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py
index 1baa5421..0c6c7ff5 100644
--- a/src/docbuild/utils/concurrency.py
+++ b/src/docbuild/utils/concurrency.py
@@ -208,14 +208,14 @@ async def run_parallel[T, R, **P](
 
         if not runner.done():
             runner.cancel()
-            with suppress(asyncio.CancelledError, Exception):
-                # Always await runner regardless of whether we cancelled it
-                # or it finished on its own.
-                # This ensures the task is fully cleaned up (no "task was
-                # destroyed but it is pending" warnings) and re-raises any unexpected
-                # exception from run_all — which we suppress here since we're
-                # in a cleanup path and cannot meaningfully recover.
-                await runner
+        with suppress(asyncio.CancelledError, Exception):
+            # Always await runner regardless of whether we cancelled it
+            # or it finished on its own.
+            # This ensures the task is fully cleaned up (no "task was
+            # destroyed but it is pending" warnings) and re-raises any unexpected
+            # exception from run_all — which we suppress here since we're
+            # in a cleanup path and cannot meaningfully recover.
+            await runner
 

From bcb681944d2c16ec550550088a499f66a350ed3d Mon Sep 17 00:00:00 2001
From: Tom Schraitle
Date: Mon, 2 Mar 2026 14:44:29 +0100
Subject: [PATCH 20/26] Fix a type error by using Concatenate

The intention is to pass *worker_args and **worker_kwargs to
worker_fn. We need to explicitly tell the type checker that
worker_fn is expected to accept the arguments from P in addition to the
standard item of type T. We use typing.Concatenate.
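A rough sketch of the idea (illustrative only; `fetch`, `check`, and
their arguments are made-up names, not part of this patch):

    from collections.abc import Awaitable, Callable
    from typing import Concatenate

    async def fetch(item: str, base_url: str, timeout: float = 5.0) -> str:
        # 'item' is supplied per queue entry; 'base_url' and 'timeout'
        # arrive via *worker_args / **worker_kwargs.
        return f"{base_url}/{item}?timeout={timeout}"

    def check[T, R, **P](
        worker_fn: Callable[Concatenate[T, P], Awaitable[R]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        # Mirrors run_parallel's signature: the checker validates that
        # args/kwargs satisfy worker_fn once the leading item is bound.
        ...

    check(fetch, "https://example.org", timeout=1.0)  # OK
    check(fetch, timeout=1.0)  # flagged by the type checker: 'base_url' missing

run_parallel has the same shape, so mismatched worker arguments are
caught at type-check time instead of at runtime.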
--- src/docbuild/utils/concurrency.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 0c6c7ff5..696ef809 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -18,6 +18,7 @@ from contextlib import suppress import functools import logging +from typing import Concatenate log = logging.getLogger(__name__) @@ -127,7 +128,7 @@ async def run_all[T, R]( async def run_parallel[T, R, **P]( items: Iterable[T] | AsyncIterableABC[T], - worker_fn: Callable[[T], Awaitable[R]], + worker_fn: Callable[Concatenate[T, P], Awaitable[R]], limit: int, *worker_args: P.args, **worker_kwargs: P.kwargs, From b6f8fb19a567844fa13520fca9c2b8d49a9bb7a9 Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Mon, 2 Mar 2026 14:46:14 +0100 Subject: [PATCH 21/26] Describe missing parameters in docstring --- src/docbuild/utils/concurrency.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py index 696ef809..7b87e7a2 100644 --- a/src/docbuild/utils/concurrency.py +++ b/src/docbuild/utils/concurrency.py @@ -177,6 +177,8 @@ async def run_parallel[T, R, **P]( :param limit: Maximum number of concurrent workers. Must be >= 1. Higher values increase throughput up to the point where the event loop, network, or downstream service becomes the bottleneck. + :param worker_args: Positional arguments to pass to ``worker_fn``. + :param worker_kwargs: Keyword arguments to pass to ``worker_fn``. :raises ValueError: If ``limit`` is less than 1. :yields: Results in completion order (not input order). Failed items are yielded as :class:`TaskFailedError` instances rather than From fb6738ee1b253357dec1e2c56548e5fd7f370fac Mon Sep 17 00:00:00 2001 From: Tom Schraitle Date: Tue, 3 Mar 2026 14:42:46 +0100 Subject: [PATCH 22/26] Add more asyncio links to LINKS.md --- LINKS.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/LINKS.md b/LINKS.md index 78420f6f..26ad062d 100644 --- a/LINKS.md +++ b/LINKS.md @@ -11,11 +11,24 @@ This document lists some links that may be helpful for this project. 
* [PEX](https://docs.pex-tool.org) -## Async I/O +## General Async I/O * [Python's asyncio: A Hands-On Walkthrough](https://realpython.com/async-io-python/) +* [Awesome asyncio](https://github.com/timofurrer/awesome-asyncio) + +## Producer-Consumer-Workers / Channels + +* [aiomultiprocess](https://github.com/omnilib/aiomultiprocess) +* [anyio](https://github.com/agronholm/anyio) +* [pychanasync](https://github.com/Gwali-1/PY_CHANNELS_ASYNC) +* [janus](https://github.com/aio-libs/janus) * [joblib](https://joblib.readthedocs.io) +## Pipelines + +* [aiostream](https://github.com/vxgmichel/aiostream) +* [asyncstdlib](https://github.com/maxfischer2781/asyncstdlib) + ## Task Queues * [Taskiq](https://taskiq-python.github.io) From 0b6dfa444bf4ff5ec4d1b72440d3ebc71ee2072b Mon Sep 17 00:00:00 2001 From: sushant-suse Date: Wed, 25 Mar 2026 11:17:57 +0530 Subject: [PATCH 23/26] fix #191: classic race condition in the monkeypatch test Signed-off-by: sushant-suse --- tests/utils/test_concurrency.py | 61 +++++++++++++++++---------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py index 5cf7f8dd..6bfe7b42 100644 --- a/tests/utils/test_concurrency.py +++ b/tests/utils/test_concurrency.py @@ -133,33 +133,34 @@ async def multiply(n: int, factor: int = 1) -> int: assert set(int_results) == {3, 6, 9} -async def test_finally_calls_cancel_on_early_exit(monkeypatch): - cancelled = False - original_create_task = asyncio.create_task - - def patched_create_task(coro, **kwargs): - task = original_create_task(coro, **kwargs) - original_cancel = task.cancel - - def tracking_cancel(*args, **kwargs): - nonlocal cancelled - cancelled = True - return original_cancel(*args, **kwargs) - - task.cancel = tracking_cancel - return task - - monkeypatch.setattr(concurrency_module, "asyncio", asyncio) - monkeypatch.setattr(asyncio, "create_task", patched_create_task) - - async def worker(x): - await asyncio.sleep(0) - return x - - with pytest.raises(RuntimeError): - async for _ in run_parallel(range(100), worker, limit=2): - raise RuntimeError("caller crashed") - - for _ in range(5): - await asyncio.sleep(0) - assert cancelled +async def test_finally_calls_cancel_on_early_exit(): + """Verify that if the caller stops iterating, the runner task is cancelled.""" + worker_started = asyncio.Event() + worker_cancelled = False + + async def slow_worker(x): + try: + worker_started.set() + await asyncio.sleep(10) # Wait a long time + return x + except asyncio.CancelledError: + nonlocal worker_cancelled + worker_cancelled = True + raise + + # 1. Start the generator + gen = run_parallel(range(10), slow_worker, limit=1) + + # 2. Start iterating and then 'break' or 'raise' + try: + async for _ in gen: + await worker_started.wait() + raise RuntimeError("Stop early") + except RuntimeError: + pass + + # 3. Give the event loop a moment to run the finally block in run_parallel + await asyncio.sleep(0.1) + + # 4. 
Verify cleanup happened
+    assert worker_cancelled, "Worker was not cancelled after early exit"

From caf29828223d471e0df320a274f0cc62fd502c3a Mon Sep 17 00:00:00 2001
From: sushant-suse
Date: Wed, 25 Mar 2026 12:35:12 +0530
Subject: [PATCH 24/26] ci #191: fix failing test

Signed-off-by: sushant-suse
---
 src/docbuild/utils/concurrency.py | 53 ++++++++++++++------------
 tests/utils/test_concurrency.py   | 61 ++++++++++++++++++-------------
 2 files changed, 62 insertions(+), 52 deletions(-)

diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py
index 7b87e7a2..f2e26b95 100644
--- a/src/docbuild/utils/concurrency.py
+++ b/src/docbuild/utils/concurrency.py
@@ -62,17 +62,15 @@ async def producer[T](
     else:
         for item in items:
             await input_queue.put(item)
-
-    except asyncio.CancelledError:
-        # We were cancelled — workers are being cancelled too, so there's
-        # nobody left to consume sentinels. Don't bother sending them.
-        raise
-
     finally:
-        # Normal completion only — workers are still running and need sentinels.
+        # Use put_nowait; we must not block here.
+        # If the queue is full, skip. Workers don't need more than one
+        # sentinel to know it's time to quit.
         for _ in range(num_workers):
-            await input_queue.put(SENTINEL)
-
+            try:
+                input_queue.put_nowait(SENTINEL)
+            except (asyncio.QueueFull, Exception):
+                break
 
 async def worker[T, R](
     worker_fn: Callable[[T], Awaitable[R]],
@@ -86,21 +84,23 @@ async def worker[T, R](
     :param result_queue: The queue for results from the workers.
     """
     while True:
-        item = await input_queue.get()
-        if item is SENTINEL:
+        # If the loop is closing, get() might raise CancelledError.
+        try:
+            item = await input_queue.get()
+        except asyncio.CancelledError:
             return
+
         try:
+            if item is SENTINEL:
+                return
+
             result = await worker_fn(item)
             await result_queue.put(result)
-        except asyncio.CancelledError:
-            raise
         except Exception as exc:
             await result_queue.put(TaskFailedError(item, exc))
-
         finally:
             input_queue.task_done()
-
 async def run_all[T, R](
     items: Iterable[T] | AsyncIterableABC[T],
     worker_fn: Callable[[T], Awaitable[R]],
@@ -116,14 +116,17 @@ async def run_all[T, R](
     :param result_queue: The queue for results from the workers.
     :param limit: The maximum number of concurrent workers.
     """
-    try:
-        async with asyncio.TaskGroup() as tg:
-            tg.create_task(producer(items, input_queue, limit))
-            for _ in range(limit):
-                tg.create_task(worker(worker_fn, input_queue, result_queue))
+    # Remove the internal .join() and let TaskGroup manage the lifecycle.
+    async with asyncio.TaskGroup() as tg:
+        tg.create_task(producer(items, input_queue, limit))
+        for _ in range(limit):
+            tg.create_task(worker(worker_fn, input_queue, result_queue))
 
-    finally:
-        await result_queue.put(SENTINEL)
+    # Once we are here, TaskGroup has successfully joined all tasks.
+    try:
+        result_queue.put_nowait(SENTINEL)
+    except (asyncio.QueueFull, Exception):
+        pass
 
 
 async def run_parallel[T, R, **P](
@@ -191,10 +194,8 @@ async def run_parallel[T, R, **P](
         functools.partial(worker_fn, *worker_args, **worker_kwargs) if (worker_args or worker_kwargs) else worker_fn
     )
 
-    input_queue: asyncio.Queue[T | object] = asyncio.Queue(maxsize=limit * 2)
-    result_queue: asyncio.Queue[R | TaskFailedError[T] | object] = asyncio.Queue(
-        maxsize=limit * 2
-    )
+    input_queue: asyncio.Queue[T | object] = asyncio.Queue(maxsize=limit * 5)
+    result_queue: asyncio.Queue[R | TaskFailedError[T] | object] = asyncio.Queue(maxsize=0)
 
     runner = asyncio.create_task(
         run_all(items, bound_fn, input_queue, result_queue, limit)
diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py
index 6bfe7b42..2bfb0740 100644
--- a/tests/utils/test_concurrency.py
+++ b/tests/utils/test_concurrency.py
@@ -4,7 +4,6 @@
 
 import pytest
 
-from docbuild.utils import concurrency as concurrency_module
 from docbuild.utils.concurrency import TaskFailedError, run_parallel
 
 
@@ -31,17 +30,18 @@ async def square(n: int) -> int:
         await asyncio.sleep(0.01)
         return n * n
 
-    async def async_items_generator():
-        for i in [1, 2, 3, 4, 5]:
-            yield i
-
-    items = async_items_generator()
+    items = [1, 2, 3, 4, 5]
     results_gen = run_parallel(items, square, limit=2)
 
     val_set = set()
-    async for r in results_gen:
-        assert isinstance(r, int)
-        val_set.add(r)
+    # Use a timeout to ensure that if it deadlocks, we see the error.
+    try:
+        async with asyncio.timeout(2):
+            async for r in results_gen:
+                assert isinstance(r, int)
+                val_set.add(r)
+    except TimeoutError:
+        pytest.fail("Test timed out - possible deadlock in run_parallel")
 
     assert val_set == {1, 4, 9, 16, 25}
 
@@ -139,28 +139,37 @@ async def test_finally_calls_cancel_on_early_exit():
     worker_cancelled = False
 
     async def slow_worker(x):
+        nonlocal worker_cancelled
         try:
             worker_started.set()
-            await asyncio.sleep(10)  # Wait a long time
+            await asyncio.sleep(10)
             return x
         except asyncio.CancelledError:
-            nonlocal worker_cancelled
             worker_cancelled = True
             raise
 
-    # 1. Start the generator
+    # 1. Start generator
     gen = run_parallel(range(10), slow_worker, limit=1)
-
-    # 2. Start iterating and then 'break' or 'raise'
-    try:
-        async for _ in gen:
-            await worker_started.wait()
-            raise RuntimeError("Stop early")
-    except RuntimeError:
-        pass
-
-    # 3. Give the event loop a moment to run the finally block in run_parallel
-    await asyncio.sleep(0.1)
-
-    # 4. Verify cleanup happened
-    assert worker_cancelled, "Worker was not cancelled after early exit"
+
+    # 2. Manually trigger the first step of the generator,
+    # but don't 'await' a result that will never come.
+    # Create a task to drive the generator.
+    async def drive_gen():
+        try:
+            async for _ in gen:
+                break
+        except asyncio.CancelledError:
+            pass
+
+    driver = asyncio.create_task(drive_gen())
+
+    # Wait for the worker to actually start.
+    await worker_started.wait()
+
+    # 3. Cancel the driver and the generator.
+    # This simulates the user stopping the loop.
+    driver.cancel()
+
+    # 4.
Settle and check
+    await asyncio.sleep(0.2)
+    assert worker_cancelled is True, "Worker should have been cancelled"

From fbe899ee580d51daafecde072869a314c13c1d7d Mon Sep 17 00:00:00 2001
From: sushant-suse
Date: Mon, 13 Apr 2026 19:23:38 +0530
Subject: [PATCH 25/26] feat #191: improve test coverage

Signed-off-by: sushant-suse
---
 src/docbuild/utils/concurrency.py | 24 +++++++-----
 tests/utils/test_concurrency.py   | 65 ++++++++++++++++++++++++++++++-
 2 files changed, 78 insertions(+), 11 deletions(-)

diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py
index f2e26b95..b3114014 100644
--- a/src/docbuild/utils/concurrency.py
+++ b/src/docbuild/utils/concurrency.py
@@ -97,7 +97,11 @@ async def worker[T, R](
             result = await worker_fn(item)
             await result_queue.put(result)
         except Exception as exc:
-            await result_queue.put(TaskFailedError(item, exc))
+            # If putting an error fails (queue full), don't deadlock.
+            try:
+                result_queue.put_nowait(TaskFailedError(item, exc))
+            except (asyncio.QueueFull, Exception):
+                pass
         finally:
             input_queue.task_done()
 
@@ -117,15 +121,17 @@ async def run_all[T, R](
     :param limit: The maximum number of concurrent workers.
     """
     # Remove the internal .join() and let TaskGroup manage the lifecycle.
-    async with asyncio.TaskGroup() as tg:
-        tg.create_task(producer(items, input_queue, limit))
-        for _ in range(limit):
-            tg.create_task(worker(worker_fn, input_queue, result_queue))
-
-    # Once we are here, TaskGroup has successfully joined all tasks.
     try:
-        result_queue.put_nowait(SENTINEL)
-    except (asyncio.QueueFull, Exception):
-        pass
+        async with asyncio.TaskGroup() as tg:
+            tg.create_task(producer(items, input_queue, limit))
+            for _ in range(limit):
+                tg.create_task(worker(worker_fn, input_queue, result_queue))
+    finally:
+        # We use put_nowait here. If the result_queue is full,
+        # we do not want to deadlock the entire process.
+        try:
+            result_queue.put_nowait(SENTINEL)
+        except (asyncio.QueueFull, Exception):
+            pass
 
 
diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py
index 2bfb0740..ed75cdfd 100644
--- a/tests/utils/test_concurrency.py
+++ b/tests/utils/test_concurrency.py
@@ -1,10 +1,19 @@
 """Tests for concurrency utilities."""
 
 import asyncio
+from collections.abc import AsyncGenerator
+from contextlib import suppress
+from typing import cast
 
 import pytest
 
-from docbuild.utils.concurrency import TaskFailedError, run_parallel
+from docbuild.utils.concurrency import (
+    SENTINEL,
+    TaskFailedError,
+    producer,
+    run_all,
+    run_parallel,
+)
 
 
 @pytest.mark.parametrize("limit", (0, -1))
@@ -171,5 +180,57 @@ async def drive_gen():
     driver.cancel()
 
     # 4.
Settle and check
-    await asyncio.sleep(0.2)
+    await asyncio.sleep(0.1)
     assert worker_cancelled is True, "Worker should have been cancelled"
+
+
+async def test_producer_queue_full_sentinel():
+    """Test that the producer's finally block doesn't deadlock putting sentinels into a full input queue."""
+    limit = 2
+    input_queue = asyncio.Queue(maxsize=1)
+    await producer([], input_queue, num_workers=limit)
+    assert input_queue.full()
+    assert input_queue.get_nowait() is SENTINEL
+
+
+async def test_run_all_exception_coverage():
+    """Test that if a worker raises an exception, run_all handles it without deadlocking on a full result_queue."""
+    input_queue = asyncio.Queue()
+    # Force result_queue to be full so put_nowait(SENTINEL) hits the 'except' block.
+    result_queue = asyncio.Queue(maxsize=1)
+    result_queue.put_nowait("Blocker")
+
+    async def broken_worker(_):
+        # We need a small yield here to ensure the TaskGroup starts the worker
+        # before the exception is raised.
+        await asyncio.sleep(0)
+        raise RuntimeError("Simulated crash")
+
+    # If the library is fixed, this finishes instantly.
+    # If not fixed, it hangs here.
+    try:
+        async with asyncio.timeout(2):
+            with suppress(Exception):
+                await run_all([1], broken_worker, input_queue, result_queue, limit=1)
+    except TimeoutError:
+        pytest.fail("Deadlock in run_all: finally block hung on a full result_queue")
+
+    # Verify we didn't hang and the queue is still full.
+    assert result_queue.full()
+    assert result_queue.get_nowait() == "Blocker"
+
+
+async def test_run_parallel_cleanup_coverage():
+    """Test that if the caller stops iterating, the generator's finally block is hit without deadlocking."""
+    async def quick_fn(x):
+        return x
+
+    gen = run_parallel([1], quick_fn, limit=1)
+    gen_as_gen = cast(AsyncGenerator, gen)
+
+    # We trigger the finally block by throwing a CancelledError.
+    # This avoids the potential deadlock of 'aclose' on some loops.
+    with suppress(asyncio.CancelledError, StopAsyncIteration):
+        await gen_as_gen.athrow(asyncio.CancelledError)
+
+    assert True

From 691014b5b970dfbb33655bff90fbb467673bb4f3 Mon Sep 17 00:00:00 2001
From: sushant-suse
Date: Mon, 13 Apr 2026 19:27:07 +0530
Subject: [PATCH 26/26] ci: fix container user issue on Ubuntu

Signed-off-by: sushant-suse
---
 .github/workflows/ci.yml          | 1 +
 src/docbuild/utils/concurrency.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 5305c8b9..826372fc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -46,6 +46,7 @@ jobs:
     container:
       # image: registry.opensuse.org/documentation/containers/15.6/opensuse-daps-toolchain:latest
       image: ghcr.io/opensuse/doc-container:latest
+      options: --user 0:0
     steps:
       - name: Checkout repository
         uses: actions/checkout@v6
diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py
index b3114014..4106f65f 100644
--- a/src/docbuild/utils/concurrency.py
+++ b/src/docbuild/utils/concurrency.py
@@ -72,6 +72,7 @@ async def producer[T](
         except (asyncio.QueueFull, Exception):
             break
 
+
 async def worker[T, R](
     worker_fn: Callable[[T], Awaitable[R]],
     input_queue: asyncio.Queue,
@@ -105,6 +106,7 @@ async def worker[T, R](
     finally:
         input_queue.task_done()
 
+
 async def run_all[T, R](
     items: Iterable[T] | AsyncIterableABC[T],
     worker_fn: Callable[[T], Awaitable[R]],