@@ -62,17 +62,15 @@ async def producer[T](
6262 else :
6363 for item in items :
6464 await input_queue .put (item )
65-
66- except asyncio .CancelledError :
67- # We were cancelled — workers are being cancelled too, so there's
68- # nobody left to consume sentinels. Don't bother sending them.
69- raise
70-
7165 finally :
72- # Normal completion only — workers are still running and need sentinels.
66+ # Use put_nowait and we must not block here.
67+ # If the queue is full, skip. Workers don't need more than one
68+ # sentinel to know it's time to quit.
7369 for _ in range (num_workers ):
74- await input_queue .put (SENTINEL )
75-
70+ try :
71+ input_queue .put_nowait (SENTINEL )
72+ except (asyncio .QueueFull , Exception ):
73+ break
7674
7775async def worker [T , R ](
7876 worker_fn : Callable [[T ], Awaitable [R ]],
@@ -86,21 +84,23 @@ async def worker[T, R](
8684 :param result_queue: The queue for results from the workers.
8785 """
8886 while True :
89- item = await input_queue .get ()
90- if item is SENTINEL :
87+ # If the loop is closing, get() might raise CancelledError
88+ try :
89+ item = await input_queue .get ()
90+ except asyncio .CancelledError :
9191 return
92+
9293 try :
94+ if item is SENTINEL :
95+ return
96+
9397 result = await worker_fn (item )
9498 await result_queue .put (result )
95- except asyncio .CancelledError :
96- raise
9799 except Exception as exc :
98100 await result_queue .put (TaskFailedError (item , exc ))
99-
100101 finally :
101102 input_queue .task_done ()
102103
103-
104104async def run_all [T , R ](
105105 items : Iterable [T ] | AsyncIterableABC [T ],
106106 worker_fn : Callable [[T ], Awaitable [R ]],
@@ -116,14 +116,17 @@ async def run_all[T, R](
116116 :param result_queue: The queue for results from the workers.
117117 :param limit: The maximum number of concurrent workers.
118118 """
119- try :
120- async with asyncio .TaskGroup () as tg :
121- tg .create_task (producer (items , input_queue , limit ))
122- for _ in range (limit ):
123- tg .create_task (worker (worker_fn , input_queue , result_queue ))
119+ # Remove the internal .join() and let TaskGroup manage the lifecycle
120+ async with asyncio .TaskGroup () as tg :
121+ tg .create_task (producer (items , input_queue , limit ))
122+ for _ in range (limit ):
123+ tg .create_task (worker (worker_fn , input_queue , result_queue ))
124124
125- finally :
126- await result_queue .put (SENTINEL )
125+ # Once we are here, TaskGroup has successfully joined all tasks.
126+ try :
127+ result_queue .put_nowait (SENTINEL )
128+ except (asyncio .QueueFull , Exception ):
129+ pass
127130
128131
129132async def run_parallel [T , R , ** P ](
@@ -191,10 +194,8 @@ async def run_parallel[T, R, **P](
191194 functools .partial (worker_fn , * worker_args , ** worker_kwargs ) if worker_kwargs else worker_fn
192195 )
193196
194- input_queue : asyncio .Queue [T | object ] = asyncio .Queue (maxsize = limit * 2 )
195- result_queue : asyncio .Queue [R | TaskFailedError [T ] | object ] = asyncio .Queue (
196- maxsize = limit * 2
197- )
197+ input_queue : asyncio .Queue [T | object ] = asyncio .Queue (maxsize = limit * 5 )
198+ result_queue : asyncio .Queue [R | TaskFailedError [T ] | object ] = asyncio .Queue (maxsize = 0 )
198199
199200 runner = asyncio .create_task (
200201 run_all (items , bound_fn , input_queue , result_queue , limit )
0 commit comments