11import random
2+ from typing import Any , Callable , Coroutine , TypeVar
23
34import pytest
45
56import asyncstdlib as a
67
78from .utility import sync , asyncify , awaitify
89
10+ COR = TypeVar ("COR" , bound = Callable [..., Coroutine [Any , Any , Any ]])
911
10- def hide_coroutine (corofunc ):
11- def wrapper (* args , ** kwargs ):
12+
13+ def hide_coroutine (corofunc : COR ) -> COR :
14+ """Make a coroutine function look like a regular function returning a coroutine"""
15+
16+ def wrapper (* args , ** kwargs ): # type: ignore
1217 return corofunc (* args , ** kwargs )
1318
14- return wrapper
19+ return wrapper # type: ignore
1520
1621
1722@sync
@@ -94,7 +99,7 @@ async def __aiter__(self):
9499
95100@sync
96101async def test_map_as ():
97- async def map_op (value ) :
102+ async def map_op (value : int ) -> int :
98103 return value * 2
99104
100105 assert [value async for value in a .map (map_op , range (5 ))] == list (range (0 , 10 , 2 ))
@@ -105,7 +110,7 @@ async def map_op(value):
105110
106111@sync
107112async def test_map_sa ():
108- def map_op (value ) :
113+ async def map_op (value : int ) -> int :
109114 return value * 2
110115
111116 assert [value async for value in a .map (map_op , asyncify (range (5 )))] == list (
@@ -118,7 +123,7 @@ def map_op(value):
118123
119124@sync
120125async def test_map_aa ():
121- async def map_op (value ) :
126+ async def map_op (value : int ) -> int :
122127 return value * 2
123128
124129 assert [value async for value in a .map (map_op , asyncify (range (5 )))] == list (
@@ -130,6 +135,28 @@ async def map_op(value):
130135 ] == list (range (10 , 20 , 4 ))
131136
132137
138+ @pytest .mark .parametrize (
139+ "itrs" ,
140+ [
141+ (range (4 ), range (5 ), range (5 )),
142+ (range (5 ), range (4 ), range (5 )),
143+ (range (5 ), range (5 ), range (4 )),
144+ ],
145+ )
146+ @sync
147+ async def test_map_strict_unequal (itrs : "tuple[range, ...]" ):
148+ def triple_sum (x : int , y : int , z : int ) -> int :
149+ return x + y + z
150+
151+ # no error without strict
152+ async for _ in a .map (triple_sum , * itrs ):
153+ pass
154+ # error with strict
155+ with pytest .raises (ValueError ):
156+ async for _ in a .map (triple_sum , * itrs , strict = True ):
157+ pass
158+
159+
133160@sync
134161async def test_max_default ():
135162 assert await a .max ((), default = 3 ) == 3
@@ -142,7 +169,7 @@ async def test_max_default():
142169
143170@sync
144171async def test_max_sa ():
145- async def minus (x ) :
172+ async def minus (x : int ) -> int :
146173 return - x
147174
148175 assert await a .max (asyncify ((1 , 2 , 3 , 4 ))) == 4
@@ -167,7 +194,7 @@ async def test_min_default():
167194
168195@sync
169196async def test_min_sa ():
170- async def minus (x ) :
197+ async def minus (x : int ) -> int :
171198 return - x
172199
173200 assert await a .min (asyncify ((1 , 2 , 3 , 4 ))) == 1
@@ -180,7 +207,7 @@ async def minus(x):
180207
181208@sync
182209async def test_filter_as ():
183- async def map_op (value ) :
210+ async def map_op (value : int ) -> bool :
184211 return value % 2 == 0
185212
186213 assert [value async for value in a .filter (map_op , range (5 ))] == list (range (0 , 5 , 2 ))
@@ -194,7 +221,7 @@ async def map_op(value):
194221
195222@sync
196223async def test_filter_sa ():
197- def map_op (value ) :
224+ def map_op (value : int ) -> bool :
198225 return value % 2 == 0
199226
200227 assert [value async for value in a .filter (map_op , asyncify (range (5 )))] == list (
@@ -208,7 +235,7 @@ def map_op(value):
208235
209236@sync
210237async def test_filter_aa ():
211- async def map_op (value ) :
238+ async def map_op (value : int ) -> bool :
212239 return value % 2 == 0
213240
214241 assert [value async for value in a .filter (map_op , asyncify (range (5 )))] == list (
@@ -286,7 +313,7 @@ async def test_types():
286313@pytest .mark .parametrize ("sortable" , sortables )
287314@pytest .mark .parametrize ("reverse" , [True , False ])
288315@sync
289- async def test_sorted_direct (sortable , reverse ):
316+ async def test_sorted_direct (sortable : "list[int] | list[float]" , reverse : bool ):
290317 assert await a .sorted (sortable , reverse = reverse ) == sorted (
291318 sortable , reverse = reverse
292319 )
@@ -305,12 +332,12 @@ async def test_sorted_direct(sortable, reverse):
305332async def test_sorted_stable ():
306333 values = [- i for i in range (20 )]
307334
308- def collision_key (x ) :
335+ def collision_key (x : int ) -> int :
309336 return x // 2
310337
311338 # test the test...
312339 assert sorted (values , key = collision_key ) != [
313- item for key , item in sorted ([(collision_key (i ), i ) for i in values ])
340+ item for _ , item in sorted ([(collision_key (i ), i ) for i in values ])
314341 ]
315342 # test the implementation
316343 assert await a .sorted (values , key = awaitify (collision_key )) == sorted (
0 commit comments