|
| 1 | +"""Tests for concurrency utilities.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +import logging |
| 5 | + |
| 6 | +import pytest |
| 7 | + |
| 8 | +from docbuild.utils.concurrency import TaskFailedError, process_unordered |
| 9 | + |
| 10 | + |
| 11 | +async def test_process_unordered_basic(): |
| 12 | + """Test basic parallel processing of a list of numbers.""" |
| 13 | + async def square(n: int) -> int: |
| 14 | + await asyncio.sleep(0.01) |
| 15 | + return n * n |
| 16 | + |
| 17 | + items = [1, 2, 3, 4, 5] |
| 18 | + result = await process_unordered(items, square, limit=2) |
| 19 | + |
| 20 | + val_set = set() |
| 21 | + for r in result: |
| 22 | + assert isinstance(r, int) |
| 23 | + val_set.add(r) |
| 24 | + |
| 25 | + assert val_set == {1, 4, 9, 16, 25} |
| 26 | + |
| 27 | + |
| 28 | +async def test_process_unordered_concurrency_limit(): |
| 29 | + """Verify that concurrency limit is respected.""" |
| 30 | + active_workers = 0 |
| 31 | + max_active = 0 |
| 32 | + lock = asyncio.Lock() |
| 33 | + |
| 34 | + async def track_concurrency(n: int) -> int: |
| 35 | + nonlocal active_workers, max_active |
| 36 | + async with lock: |
| 37 | + active_workers += 1 |
| 38 | + max_active = max(max_active, active_workers) |
| 39 | + |
| 40 | + await asyncio.sleep(0.05) |
| 41 | + |
| 42 | + async with lock: |
| 43 | + active_workers -= 1 |
| 44 | + return n |
| 45 | + |
| 46 | + items = range(10) |
| 47 | + limit = 3 |
| 48 | + # Use higher limit in worker fn but restrict at call site |
| 49 | + await process_unordered(items, track_concurrency, limit=limit) |
| 50 | + |
| 51 | + assert max_active <= limit |
| 52 | + |
| 53 | + |
| 54 | +async def test_process_unordered_exceptions(): |
| 55 | + """Test exception handling returning TaskFailedError.""" |
| 56 | + async def fail_on_even(n: int) -> int: |
| 57 | + if n % 2 == 0: |
| 58 | + raise ValueError(f"Even number: n={n}") |
| 59 | + return n |
| 60 | + |
| 61 | + items = [1, 2, 3] |
| 62 | + results = await process_unordered(items, fail_on_even, limit=2) |
| 63 | + |
| 64 | + assert len(results) == 3 |
| 65 | + |
| 66 | + success_vals = [] |
| 67 | + failed_items = [] |
| 68 | + |
| 69 | + for r in results: |
| 70 | + match r: |
| 71 | + case TaskFailedError(item=item, original_exception=exc): |
| 72 | + failed_items.append(item) |
| 73 | + assert isinstance(exc, ValueError) |
| 74 | + case _: |
| 75 | + success_vals.append(r) |
| 76 | + |
| 77 | + assert set(success_vals) == {1, 3} |
| 78 | + assert failed_items == [2] |
| 79 | + |
| 80 | + |
| 81 | +async def test_process_unordered_empty(): |
| 82 | + """Test processing an empty list.""" |
| 83 | + async def identity(n): return n |
| 84 | + results = await process_unordered([], identity, limit=5) |
| 85 | + assert results == [] |
| 86 | + |
| 87 | + |
| 88 | +async def test_process_unordered_kwargs(): |
| 89 | + """Test passing kwargs to worker function.""" |
| 90 | + async def multiply(n: int, factor: int = 1) -> int: |
| 91 | + return n * factor |
| 92 | + |
| 93 | + items = [1, 2, 3] |
| 94 | + results = await process_unordered(items, multiply, limit=2, factor=3) |
| 95 | + |
| 96 | + # We might get exceptions if anything failed, but expecting ints |
| 97 | + int_results = [r for r in results if isinstance(r, int)] |
| 98 | + assert set(int_results) == {3, 6, 9} |
0 commit comments