diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5305c8b9..826372fc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,6 +46,7 @@ jobs: container: # image: registry.opensuse.org/documentation/containers/15.6/opensuse-daps-toolchain:latest image: ghcr.io/opensuse/doc-container:latest + options: --user 0:0 steps: - name: Checkout repository uses: actions/checkout@v6 diff --git a/LINKS.md b/LINKS.md index 78420f6f..26ad062d 100644 --- a/LINKS.md +++ b/LINKS.md @@ -11,11 +11,24 @@ This document lists some links that may be helpful for this project. * [PEX](https://docs.pex-tool.org) -## Async I/O +## General Async I/O * [Python's asyncio: A Hands-On Walkthrough](https://realpython.com/async-io-python/) +* [Awesome asyncio](https://github.com/timofurrer/awesome-asyncio) + +## Producer-Consumer-Workers / Channels + +* [aiomultiprocess](https://github.com/omnilib/aiomultiprocess) +* [anyio](https://github.com/agronholm/anyio) +* [pychanasync](https://github.com/Gwali-1/PY_CHANNELS_ASYNC) +* [janus](https://github.com/aio-libs/janus) * [joblib](https://joblib.readthedocs.io) +## Pipelines + +* [aiostream](https://github.com/vxgmichel/aiostream) +* [asyncstdlib](https://github.com/maxfischer2781/asyncstdlib) + ## Task Queues * [Taskiq](https://taskiq-python.github.io) diff --git a/changelog.d/191.feature.rst b/changelog.d/191.feature.rst new file mode 100644 index 00000000..fa198f44 --- /dev/null +++ b/changelog.d/191.feature.rst @@ -0,0 +1 @@ +Add a producer/consumer model implementation. diff --git a/src/docbuild/utils/concurrency.py b/src/docbuild/utils/concurrency.py new file mode 100644 index 00000000..4106f65f --- /dev/null +++ b/src/docbuild/utils/concurrency.py @@ -0,0 +1,314 @@ +"""Concurrency utilities using producer-consumer patterns. + +This module provides helpers for managing concurrent asyncio tasks with +strict concurrency limits, backpressure handling, and robust exception tracking. + +It is designed to handle both I/O-bound tasks (via native asyncio coroutines) and +CPU-bound tasks (via `loop.run_in_executor`) while keeping resource usage deterministic. +""" + +import asyncio +from collections.abc import ( + AsyncIterable as AsyncIterableABC, + AsyncIterator, + Awaitable, + Callable, + Iterable, +) +from contextlib import suppress +import functools +import logging +from typing import Concatenate + +log = logging.getLogger(__name__) + +#: Sentinel value for internal use when needed (e.g., to signal completion). +SENTINEL = object() + + +class TaskFailedError[T](Exception): + """Exception raised when a task fails during processing. + + This wrapper preserves the context of a failure in concurrent processing pipelines. + Since results may be returned out of order or aggregated later, wrapping the + exception allows the caller to link a failure back to the specific input item + that caused it. + + :param item: The item that was being processed. + :param original_exception: The exception that caused the failure. + """ + + def __init__(self, item: T, original_exception: Exception) -> None: + super().__init__(f"Task failed for item {item}: {original_exception}") + self.item = item + self.original_exception = original_exception + + +async def producer[T]( + items: Iterable[T] | AsyncIterableABC[T], + input_queue: asyncio.Queue, + num_workers: int, +) -> None: + """Feed items into the input queue, then send one sentinel per worker. + + :param items: An iterable or async iterable of items to be processed. + :param input_queue: The queue for items to be processed by workers. + :param num_workers: The number of workers, used to send the correct number of sentinels. + """ + try: + if isinstance(items, AsyncIterableABC): + async for item in items: + await input_queue.put(item) + else: + for item in items: + await input_queue.put(item) + finally: + # Use put_nowait and we must not block here. + # If the queue is full, skip. Workers don't need more than one + # sentinel to know it's time to quit. + for _ in range(num_workers): + try: + input_queue.put_nowait(SENTINEL) + except (asyncio.QueueFull, Exception): + break + + +async def worker[T, R]( + worker_fn: Callable[[T], Awaitable[R]], + input_queue: asyncio.Queue, + result_queue: asyncio.Queue, +) -> None: + """Pull items from the input queue, process them, push results out. + + :param worker_fn: The asynchronous function that processes a single item. + :param input_queue: The queue for items to be processed by workers. + :param result_queue: The queue for results from the workers. + """ + while True: + # If the loop is closing, get() might raise CancelledError + try: + item = await input_queue.get() + except asyncio.CancelledError: + return + + try: + if item is SENTINEL: + return + + result = await worker_fn(item) + await result_queue.put(result) + except Exception as exc: + # If putting an error fails (queue full), don't deadlock. + try: + result_queue.put_nowait(TaskFailedError(item, exc)) + except (asyncio.QueueFull, Exception): + pass + finally: + input_queue.task_done() + + +async def run_all[T, R]( + items: Iterable[T] | AsyncIterableABC[T], + worker_fn: Callable[[T], Awaitable[R]], + input_queue: asyncio.Queue, + result_queue: asyncio.Queue, + limit: int, +) -> None: + """Orchestrate producer + workers, then signal the consumer when done. + + :param items: An iterable or async iterable of items to be processed. + :param worker_fn: The asynchronous function that processes a single item. + :param input_queue: The queue for items to be processed by workers. + :param result_queue: The queue for results from the workers. + :param limit: The maximum number of concurrent workers. + """ + # Remove the internal .join() and let TaskGroup manage the lifecycle + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(producer(items, input_queue, limit)) + for _ in range(limit): + tg.create_task(worker(worker_fn, input_queue, result_queue)) + finally: + # We use put_nowait here. If the result_queue is full, + # we do not want to deadlock the entire process. + try: + result_queue.put_nowait(SENTINEL) + except (asyncio.QueueFull, Exception): + pass + + +async def run_parallel[T, R, **P]( + items: Iterable[T] | AsyncIterableABC[T], + worker_fn: Callable[Concatenate[T, P], Awaitable[R]], + limit: int, + *worker_args: P.args, + **worker_kwargs: P.kwargs, +) -> AsyncIterator[R | TaskFailedError[T]]: + """Process items concurrently with bounded parallelism. + + Uses a producer/worker/consumer pipeline: + + - A single **producer** task feeds items into a bounded input queue. + - ``limit`` **worker** tasks pull from the input queue, call ``worker_fn``, + and push results into a bounded result queue. + - The **caller** consumes results by iterating this async generator. + + All three stages run concurrently. Backpressure propagates naturally: + a slow consumer stalls workers; stalled workers stall the producer. + Order of results is NOT guaranteed. + + If ``worker_fn`` raises, the exception is wrapped in + :class:`TaskFailedError` and yielded rather than re-raised, so one + failing item does not abort the pipeline. + + Performance characteristics + --------------------------- + - **Throughput:** approaches ``limit * per-worker-throughput`` for + I/O-bound workloads where workers spend most time awaiting external + resources. CPU-bound work gains little due to the GIL; use + ``ProcessPoolExecutor`` wrapped in ``asyncio.run_in_executor`` instead. + - **Startup cost:** O(limit) — one asyncio task per worker, each cheap + to create (~microseconds). + - **Memory:** O(limit). Both the input queue (``maxsize=limit * 2``) + and the result queue (``maxsize=limit * 2``) are bounded. At most + ``limit`` items are in-flight inside workers at any time, giving a + total live-item count of roughly ``5 * limit``. + Note: each item itself may be arbitrarily large; the O(limit) bound + refers to the *number* of items held in memory, not their byte size. + - **Latency:** time-to-first-result equals one worker's latency. + Remaining results stream out as workers complete, with no polling + delay (sentinel-based signalling, zero busy-wait). + - **Cancellation:** if the caller abandons the generator (e.g. ``break`` + in an ``async for`` loop), the internal runner task is cancelled and + all worker tasks are cleaned up promptly via ``TaskGroup``. + + :param items: Iterable or async iterable of items to process. + :param worker_fn: Async callable invoked as ``worker_fn(item)`` for + each item. Must be safe to call concurrently from ``limit`` tasks. + :param limit: Maximum number of concurrent workers. Must be >= 1. + Higher values increase throughput up to the point where the event + loop, network, or downstream service becomes the bottleneck. + :param worker_args: Positional arguments to pass to ``worker_fn``. + :param worker_kwargs: Keyword arguments to pass to ``worker_fn``. + :raises ValueError: If ``limit`` is less than 1. + :yields: Results in completion order (not input order). Failed items + are yielded as :class:`TaskFailedError` instances rather than + raising, so the caller can handle partial failures inline. + """ + if limit <= 0: + raise ValueError("limit must be >= 1") + + bound_fn = ( + functools.partial(worker_fn, *worker_args, **worker_kwargs) if worker_kwargs else worker_fn + ) + + input_queue: asyncio.Queue[T | object] = asyncio.Queue(maxsize=limit * 5) + result_queue: asyncio.Queue[R | TaskFailedError[T] | object] = asyncio.Queue(maxsize=0) + + runner = asyncio.create_task( + run_all(items, bound_fn, input_queue, result_queue, limit) + ) + + try: + while True: + result = await result_queue.get() + if result is SENTINEL: + break + yield result # type: ignore[misc] + + finally: + if not runner.done(): + runner.cancel() + + with suppress(asyncio.CancelledError, Exception): + # Always await runner regardless of whether we cancelled it + # or it finished on its own. + # This ensures the task is fully cleaned up (no "task was + # destroyed but it is pending" warnings) and re-raises any unexpected + # exception from run_all — which we suppress here since we're + # in a cleanup path and cannot meaningfully recover. + await runner + + +if __name__ == "__main__": + import time + + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + async def sample_worker(num: int) -> int: + """Create a simple worker that simulates some I/O-bound work.""" + if num in (5, 8): + log.warning("Simulating failure for item %d", num) + # HINT: This is the "alternative" implementation. + # Instead of having a Failure class, we just raise the exception + # and add the item into the exception as an additional metadata + raise ValueError("Item 5 is not allowed!", num) + # Alternative: + # raise ValueError("Item 5 is not allowed!", {"item": num}) + + log.info("Processing item %d", num) + await asyncio.sleep(0.1) # Simulate I/O delay + return num * 2 + + # Make process intensive tasks in a executor + # 1. Define the heavy lifting function (must be at module level for pickle) + def heavy_cpu_math(item: int) -> int: + """Simulate a CPU-bound task.""" + return item * item + + async def main() -> None: + """Run the example.""" + async def generate_items() -> AsyncIterableABC[int]: + for i in range(10): + yield i + # yield from range(10) + + log.info("--- Running process_unordered ---") + start_time = time.monotonic() + task_results = ( + res async for res in run_parallel(generate_items(), sample_worker, limit=3) + ) + end_time = time.monotonic() + log.info("Finished in %.2f seconds\n", end_time - start_time) + + successful_results = [] + failed_tasks = [] + async for res in task_results: + if isinstance(res, TaskFailedError): + failed_tasks.append((res.item, res.original_exception)) + else: + successful_results.append(res) + + log.info("Successful results (unordered): %s", (successful_results)) + log.info("Caught exceptions: %s", failed_tasks) + + ## ------------------- + log.info("--- Running process executor ---") + from concurrent.futures import Executor, ProcessPoolExecutor + + # 2. Create the wrapper + async def cpu_worker_wrapper( + item: int, executor: Executor | None = None + ) -> int: + loop = asyncio.get_running_loop() + # Use the passed executor + return await loop.run_in_executor(executor, heavy_cpu_math, item) + + # 3. Use your existing utility with the executor passed as a kwarg + items = range(10) + with ProcessPoolExecutor() as process_pool: + + successful_results = [] + failed_tasks = [] + async for res in run_parallel( + items, cpu_worker_wrapper, limit=4, executor=process_pool + ): + if isinstance(res, TaskFailedError): + failed_tasks.append((res.item, res.original_exception)) + else: + successful_results.append(res) + + log.info("Successful results (unordered): %s", (successful_results)) + log.info("Caught exceptions: %s", failed_tasks) + + asyncio.run(main()) diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py new file mode 100644 index 00000000..ed75cdfd --- /dev/null +++ b/tests/utils/test_concurrency.py @@ -0,0 +1,236 @@ +"""Tests for concurrency utilities.""" + +import asyncio +from collections.abc import AsyncGenerator +from contextlib import suppress +from typing import cast + +import pytest + +from docbuild.utils.concurrency import ( + SENTINEL, + TaskFailedError, + producer, + run_all, + run_parallel, +) + + +@pytest.mark.parametrize("limit", (0, -1)) +async def test_wrong_limit(limit: int): + async def square(n: int) -> int: + await asyncio.sleep(0.01) + return n * n + + async def async_items_generator(): + for i in [1, 2, 3, 4, 5]: + yield i + + items = async_items_generator() + with pytest.raises(ValueError, match="limit must be >= 1"): + results_gen = run_parallel(items, square, limit=limit) + async for _ in results_gen: + pass + + +async def test_process_unordered_basic(): + """Test basic parallel processing of a list of numbers.""" + async def square(n: int) -> int: + await asyncio.sleep(0.01) + return n * n + + items = [1, 2, 3, 4, 5] + results_gen = run_parallel(items, square, limit=2) + + val_set = set() + # Use a timeout to ensure that if it deadlocks, we see the error + try: + async with asyncio.timeout(2): + async for r in results_gen: + assert isinstance(r, int) + val_set.add(r) + except TimeoutError: + pytest.fail("Test timed out - possible deadlock in run_parallel") + + assert val_set == {1, 4, 9, 16, 25} + + +async def test_process_unordered_concurrency_limit(): + """Verify that concurrency limit is respected.""" + active_workers = 0 + max_active = 0 + lock = asyncio.Lock() + + async def track_concurrency(n: int) -> int: + nonlocal active_workers, max_active + async with lock: + active_workers += 1 + max_active = max(max_active, active_workers) + + await asyncio.sleep(0.05) + + async with lock: + active_workers -= 1 + return n + + items = range(10) + limit = 3 + # Consume the async generator to ensure all workers run and concurrency is tracked. + _ = [r async for r in run_parallel(items, track_concurrency, limit=limit)] + + assert max_active <= limit + + +async def test_process_unordered_exceptions(): + """Test exception handling returning TaskFailedError.""" + async def fail_on_even(n: int) -> int: + if n % 2 == 0: + raise ValueError(f"Even number: n={n}") + return n + + items = [1, 2, 3] + results_gen = run_parallel(items, fail_on_even, limit=2) + results = [r async for r in results_gen] + assert len(results) == 3 + + success_vals = [] + failed_items = [] + + for r in results: + match r: + case TaskFailedError(item=item, original_exception=exc): + failed_items.append(item) + assert isinstance(exc, ValueError) + case _: + success_vals.append(r) + + assert set(success_vals) == {1, 3} + assert failed_items == [2] + + +async def test_process_unordered_empty(): + """Test processing an empty list.""" + async def identity(n): return n + results_gen = run_parallel([], identity, limit=5) + collected_results = [r async for r in results_gen] + assert collected_results == [] + + +async def test_process_unordered_empty_async_iterable(): + """Test processing an empty async iterable, ensuring run_all completes gracefully.""" + async def async_empty_generator(): + # This async generator yields nothing, simulating an empty async iterable. + if False: + yield 1 + + async def identity(n): return n + results_gen = run_parallel(async_empty_generator(), identity, limit=5) + collected_results = [r async for r in results_gen] + assert collected_results == [] + + +async def test_process_unordered_kwargs(): + """Test passing kwargs to worker function.""" + async def multiply(n: int, factor: int = 1) -> int: + return n * factor + + items = [1, 2, 3] + results_gen = run_parallel(items, multiply, limit=2, factor=3) + collected_results = [r async for r in results_gen] + # We might get exceptions if anything failed, but expecting ints + int_results = [r for r in collected_results if isinstance(r, int)] + assert set(int_results) == {3, 6, 9} + + +async def test_finally_calls_cancel_on_early_exit(): + """Verify that if the caller stops iterating, the runner task is cancelled.""" + worker_started = asyncio.Event() + worker_cancelled = False + + async def slow_worker(x): + nonlocal worker_cancelled + try: + worker_started.set() + await asyncio.sleep(10) + return x + except asyncio.CancelledError: + worker_cancelled = True + raise + + # 1. Start generator + gen = run_parallel(range(10), slow_worker, limit=1) + + # 2. Manually trigger the first step of the generator + # but don't 'await' a result that will never come. + # Create a task to drive the generator. + async def drive_gen(): + try: + async for _ in gen: + break + except asyncio.CancelledError: + pass + + driver = asyncio.create_task(drive_gen()) + + # Wait for the worker to actually start + await worker_started.wait() + + # 3. Cancel the driver and the generator + # This simulates the user stopping the loop + driver.cancel() + + # 4. Settle and check + await asyncio.sleep(0.1) + assert worker_cancelled is True, "Worker should have been cancelled" + + +async def test_producer_queue_full_sentinel(): + """Test that if the input queue is full, the producer's finally block doesn't deadlock trying to put sentinels.""" + limit = 2 + input_queue = asyncio.Queue(maxsize=1) + await producer([], input_queue, num_workers=limit) + assert input_queue.full() + assert input_queue.get_nowait() is SENTINEL + + +async def test_run_all_exception_coverage(): + """Test that if a worker raises an exception, run_all handles it without deadlocking on a full result_queue.""" + input_queue = asyncio.Queue() + # Force result_queue to be full so put_nowait(SENTINEL) hits the 'except' block + result_queue = asyncio.Queue(maxsize=1) + result_queue.put_nowait("Blocker") + + async def broken_worker(_): + # We need a small yield here to ensure the TaskGroup starts the worker + # before the exception is raised. + await asyncio.sleep(0) + raise RuntimeError("Simulated crash") + + # If the library is fixed, this finishes instantly. + # If not fixed, it hangs here. + try: + async with asyncio.timeout(2): + with suppress(Exception): + await run_all([1], broken_worker, input_queue, result_queue, limit=1) + except TimeoutError: + pytest.fail("Deadlock in run_all: finally block hung on a full result_queue") + + # Verify we didn't hang and the queue is still full + assert result_queue.full() + assert result_queue.get_nowait() == "Blocker" + + +async def test_run_parallel_cleanup_coverage(): + """Test that if the caller stops iterating, the generator's finally block is hit without deadlocking.""" + async def quick_fn(x): + return x + + gen = run_parallel([1], quick_fn, limit=1) + gen_as_gen = cast(AsyncGenerator, gen) + + # We trigger the finally block by throwing a CancelledError. + # This avoids the potential deadlock of 'aclose' on some loops. + with suppress(asyncio.CancelledError, StopAsyncIteration): + await gen_as_gen.athrow(asyncio.CancelledError) + + assert True