Skip to content

Commit 27e9d82

Browse files
Retrofit assert type for remaining type-stub checks (#953)
* Retrofit assert type for: - factory.py - list.py - object.py - provider.py - resource.py - singleton.py
1 parent 13847c0 commit 27e9d82

10 files changed

Lines changed: 250 additions & 133 deletions

File tree

tests/typing/configuration.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@
8181
config5_pydantic.set_pydantic_settings([PydanticSettings()])
8282

8383
# NOTE: Using assignment since PydanticSettings is context-sensitive: conditional on whether pydantic is installed
84-
config5_pydantic_settings: list[PydanticSettings] = (config5_pydantic.get_pydantic_settings())
84+
config5_pydantic_settings: list[PydanticSettings] = (
85+
config5_pydantic.get_pydantic_settings()
86+
)
8587

8688
# Test 6: to check init arguments
8789
config6 = providers.Configuration(

tests/typing/declarative_container.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class Container5(containers.DeclarativeContainer):
5252
dependencies = Container5.dependencies
5353
assert_type(dependencies, Dict[str, providers.Provider[Any]])
5454

55+
5556
# Test 6: to check base class
5657
class Container6(containers.DeclarativeContainer):
5758
provider = providers.Factory(int)

tests/typing/dict.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
assert_type(var1, Dict[Any, Any])
1313

1414

15-
1615
# Test 2: to check init with non-string keys
1716
provider2 = providers.Dict({object(): providers.Factory(object)})
1817
var2 = provider2()
@@ -42,7 +41,7 @@
4241
a2=providers.Factory(object),
4342
)
4443
provided5 = provider5.provided()
45-
assert_type(provided5, Any)
44+
assert_type(provided5, Any)
4645

4746

4847
# Test 6: to check the return type with await

tests/typing/factory.py

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Callable, Dict, Optional, Tuple, Type
2+
from typing_extensions import assert_type
23

34
from dependency_injector import providers
45

@@ -17,103 +18,126 @@ def create(cls) -> Animal:
1718

1819
# Test 1: to check the return type (class)
1920
provider1 = providers.Factory(Cat)
20-
animal1: Animal = provider1(1, 2, 3, b="1", c=2, e=0.0)
21+
animal1 = provider1(1, 2, 3, b="1", c=2, e=0.0)
22+
assert_type(animal1, Cat)
2123

2224
# Test 2: to check the return type (class factory method)
2325
provider2 = providers.Factory(Cat.create)
24-
animal2: Animal = provider2()
26+
animal2 = provider2()
27+
assert_type(animal2, Animal)
2528

2629
# Test 3: to check the .override() method
2730
provider3 = providers.Factory(Animal)
2831
with provider3.override(providers.Factory(Cat)):
29-
provider3()
32+
animal3 = provider3()
33+
assert_type(animal3, Animal)
3034

3135
# Test 4: to check the .args, .kwargs, .attributes attributes
3236
provider4 = providers.Factory(Animal)
33-
args4: Tuple[Any] = provider4.args
34-
kwargs4: Dict[str, Any] = provider4.kwargs
35-
attributes4: Dict[str, Any] = provider4.attributes
37+
args4 = provider4.args
38+
kwargs4 = provider4.kwargs
39+
attributes4 = provider4.attributes
40+
assert_type(args4, Tuple[Any])
41+
assert_type(kwargs4, Dict[str, Any])
42+
assert_type(attributes4, Dict[str, Any])
3643

3744
# Test 5: to check the provided instance interface
3845
provider5 = providers.Factory(Animal)
39-
provided5: Animal = provider5.provided()
40-
attr_getter5: providers.AttributeGetter = provider5.provided.attr
41-
item_getter5: providers.ItemGetter = provider5.provided["item"]
42-
method_caller5: providers.MethodCaller = provider5.provided.method.call(123, arg=324)
46+
provided5 = provider5.provided
47+
provided_val5 = provided5()
48+
attr_getter5 = provider5.provided.attr
49+
item_getter5 = provider5.provided["item"]
50+
method_caller5 = provider5.provided.method.call(123, arg=324)
51+
assert_type(provided5, providers.ProvidedInstance)
52+
assert_type(provided_val5, Any)
53+
assert_type(attr_getter5, providers.AttributeGetter)
54+
assert_type(item_getter5, providers.ItemGetter)
55+
assert_type(method_caller5, providers.MethodCaller)
4356

4457
# Test 6: to check the DelegatedFactory
4558
provider6 = providers.DelegatedFactory(Cat)
46-
animal6: Animal = provider6(1, 2, 3, b="1", c=2, e=0.0)
59+
animal6 = provider6(1, 2, 3, b="1", c=2, e=0.0)
60+
assert_type(animal6, Cat)
4761

4862
# Test 7: to check the AbstractFactory
4963
provider7 = providers.AbstractFactory(Animal)
5064
provider7.override(providers.Factory(Cat))
51-
animal7: Animal = provider7(1, 2, 3, b="1", c=2, e=0.0)
65+
animal7 = provider7(1, 2, 3, b="1", c=2, e=0.0)
66+
assert_type(animal7, Animal)
5267

5368
# Test 8: to check the FactoryDelegate __init__
5469
provider8 = providers.FactoryDelegate(providers.Factory(object))
5570

5671
# Test 9: to check FactoryAggregate provider
57-
provider9: providers.FactoryAggregate[str] = providers.FactoryAggregate(
72+
provider9 = providers.FactoryAggregate(
5873
a=providers.Factory(str, "str1"),
5974
b=providers.Factory(str, "str2"),
6075
)
61-
factory_a_9: providers.Factory[str] = provider9.a
62-
factory_b_9: providers.Factory[str] = provider9.b
63-
val9: str = provider9("a")
64-
65-
provider9_set_non_string_keys: providers.FactoryAggregate[str] = (
66-
providers.FactoryAggregate()
67-
)
76+
factory_a_9 = provider9.a
77+
factory_b_9 = provider9.b
78+
val9 = provider9("a")
79+
assert_type(provider9, providers.FactoryAggregate[str])
80+
assert_type(factory_a_9, providers.Factory[str])
81+
assert_type(factory_b_9, providers.Factory[str])
82+
assert_type(val9, str)
83+
84+
provider9_set_non_string_keys = providers.FactoryAggregate[str]()
6885
provider9_set_non_string_keys.set_factories({Cat: providers.Factory(str, "str")})
69-
factory_set_non_string_9: providers.Factory[str] = (
70-
provider9_set_non_string_keys.factories[Cat]
71-
)
86+
factory_set_non_string_9 = provider9_set_non_string_keys.factories[Cat]
87+
assert_type(provider9_set_non_string_keys, providers.FactoryAggregate[str])
88+
assert_type(factory_set_non_string_9, providers.Factory[str])
7289

73-
provider9_new_non_string_keys: providers.FactoryAggregate[str] = (
74-
providers.FactoryAggregate(
75-
{Cat: providers.Factory(str, "str")},
76-
)
77-
)
78-
factory_new_non_string_9: providers.Factory[str] = (
79-
provider9_new_non_string_keys.factories[Cat]
90+
provider9_new_non_string_keys = providers.FactoryAggregate(
91+
{Cat: providers.Factory(str, "str")},
8092
)
93+
factory_new_non_string_9 = provider9_new_non_string_keys.factories[Cat]
94+
assert_type(provider9_new_non_string_keys, providers.FactoryAggregate[str])
95+
assert_type(factory_new_non_string_9, providers.Factory[str])
8196

8297
provider9_no_explicit_typing = providers.FactoryAggregate(
8398
a=providers.Factory(str, "str")
8499
)
85-
provider9_no_explicit_typing_factory: providers.Factory[str] = (
86-
provider9_no_explicit_typing.factories["a"]
87-
)
88-
provider9_no_explicit_typing_object: str = provider9_no_explicit_typing("a")
100+
provider9_no_explicit_typing_factory = provider9_no_explicit_typing.factories["a"]
101+
provider9_no_explicit_typing_object = provider9_no_explicit_typing("a")
102+
assert_type(provider9_no_explicit_typing, providers.FactoryAggregate[str])
103+
assert_type(provider9_no_explicit_typing_factory, providers.Factory[str])
104+
assert_type(provider9_no_explicit_typing_object, str)
89105

90106
# Test 10: to check the explicit typing
91-
factory10: providers.Provider[Animal] = providers.Factory(Cat)
92-
animal10: Animal = factory10()
107+
factory10 = providers.Factory[Animal](Cat)
108+
animal10 = factory10()
109+
assert_type(factory10, providers.Factory[Animal])
110+
assert_type(animal10, Animal)
93111

94112
# Test 11: to check the return type with await
95113
provider11 = providers.Factory(Cat)
96114

97115

98116
async def _async11() -> None:
99-
animal1: Animal = await provider11(1, 2, 3, b="1", c=2, e=0.0) # type: ignore
100-
animal2: Animal = await provider11.async_(1, 2, 3, b="1", c=2, e=0.0)
117+
animal1 = await provider11(1, 2, 3, b="1", c=2, e=0.0) # type: ignore
118+
animal2 = await provider11.async_(1, 2, 3, b="1", c=2, e=0.0)
119+
assert_type(animal2, Cat)
101120

102121

103122
# Test 12: to check class type from .provides
104123
provider12 = providers.Factory(Cat)
105-
provided_cls12: Type[Animal] = provider12.cls
124+
provided_cls12 = provider12.cls
106125
assert issubclass(provided_cls12, Animal)
107-
provided_provides12: Optional[Callable[..., Animal]] = provider12.provides
126+
provided_provides12 = provider12.provides
108127
assert provided_provides12 is not None and provided_provides12() == Cat()
128+
assert_type(provided_cls12, Type[Cat])
129+
assert_type(provided_provides12, Callable[..., Cat])
130+
109131

110132
# Test 13: to check class from .provides with explicit typevar
111133
provider13 = providers.Factory[Animal](Cat)
112-
provided_cls13: Type[Animal] = provider13.cls
134+
provided_cls13 = provider13.cls
113135
assert issubclass(provided_cls13, Animal)
114-
provided_provides13: Optional[Callable[..., Animal]] = provider13.provides
136+
provided_provides13 = provider13.provides
115137
assert provided_provides13 is not None and provided_provides13() == Cat()
138+
assert_type(provided_cls13, Type[Animal])
139+
assert_type(provided_provides13, Callable[..., Animal])
116140

117141
# Test 14: to check string imports
118-
provider14: providers.Factory[Dict[Any, Any]] = providers.Factory("builtins.dict")
142+
provider14 = providers.Factory[Any]("builtins.dict")
119143
provider14.set_provides("builtins.dict")

tests/typing/list.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, List, Tuple
2+
from typing_extensions import assert_type
23

34
from dependency_injector import providers
45

@@ -7,25 +8,33 @@
78
providers.Factory(object),
89
providers.Factory(object),
910
)
10-
var1: List[Any] = provider1()
11+
var1 = provider1()
12+
assert_type(var1, List[Any])
1113

1214

1315
# Test 2: to check the .args attributes
1416
provider2 = providers.List(
1517
providers.Factory(object),
1618
providers.Factory(object),
1719
)
18-
args2: Tuple[Any] = provider2.args
20+
args2 = provider2.args
21+
assert_type(args2, Tuple[Any])
1922

2023
# Test 3: to check the provided instance interface
2124
provider3 = providers.List(
2225
providers.Factory(object),
2326
providers.Factory(object),
2427
)
25-
provided3: List[Any] = provider3.provided()
26-
attr_getter3: providers.AttributeGetter = provider3.provided.attr
27-
item_getter3: providers.ItemGetter = provider3.provided["item"]
28-
method_caller3: providers.MethodCaller = provider3.provided.method.call(123, arg=324)
28+
provided3 = provider3.provided
29+
provided_val3 = provided3()
30+
attr_getter3 = provider3.provided.attr
31+
item_getter3 = provider3.provided["item"]
32+
method_caller3 = provider3.provided.method.call(123, arg=324)
33+
assert_type(provided3, providers.ProvidedInstance)
34+
assert_type(provided_val3, Any)
35+
assert_type(attr_getter3, providers.AttributeGetter)
36+
assert_type(item_getter3, providers.ItemGetter)
37+
assert_type(method_caller3, providers.MethodCaller)
2938

3039
# Test 4: to check the return type with await
3140
provider4 = providers.List(
@@ -35,5 +44,6 @@
3544

3645

3746
async def _async4() -> None:
38-
var1: List[Any] = await provider4() # type: ignore
39-
var2: List[Any] = await provider4.async_()
47+
var1 = await provider4() # type: ignore
48+
var2 = await provider4.async_()
49+
assert_type(var2, List[Any])

tests/typing/object.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,38 @@
1-
from typing import Optional, Type
1+
from typing import Optional, Any
2+
from typing_extensions import assert_type
23

34
from dependency_injector import providers
45

56
# Test 1: to check the return type
67
provider1 = providers.Object(int(3))
7-
var1: int = provider1()
8+
var1 = provider1()
9+
assert_type(var1, int)
810

911
# Test 2: to check the provided instance interface
1012
provider2 = providers.Object(int)
11-
provided2: Type[int] = provider2.provided()
12-
attr_getter2: providers.AttributeGetter = provider2.provided.attr
13-
item_getter2: providers.ItemGetter = provider2.provided["item"]
14-
method_caller2: providers.MethodCaller = provider2.provided.method.call(123, arg=324)
13+
provided2 = provider2.provided
14+
provided_val2 = provided2()
15+
attr_getter2 = provider2.provided.attr
16+
item_getter2 = provider2.provided["item"]
17+
method_caller2 = provider2.provided.method.call(123, arg=324)
18+
assert_type(provided2, providers.ProvidedInstance)
19+
assert_type(provided_val2, Any)
20+
assert_type(attr_getter2, providers.AttributeGetter)
21+
assert_type(item_getter2, providers.ItemGetter)
22+
assert_type(method_caller2, providers.MethodCaller)
23+
1524

1625
# Test 3: to check the return type with await
1726
provider3 = providers.Object(int(3))
1827

1928

2029
async def _async3() -> None:
21-
var1: int = await provider3() # type: ignore
22-
var2: int = await provider3.async_()
30+
var1 = await provider3() # type: ignore
31+
var2 = await provider3.async_()
32+
assert_type(var2, int)
2333

2434

2535
# Test 4: to check class type from provider
2636
provider4 = providers.Object(int("1"))
27-
provided_provides: Optional[int] = provider4.provides
37+
provided_provides4 = provider4.provides
38+
assert_type(provided_provides4, Optional[int])

tests/typing/provider.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
11
from typing import Any
2+
from typing_extensions import assert_type
23

34
from dependency_injector import providers
45

56
# Test 1: to check .provided attribute
67
provider1: providers.Provider[int] = providers.Object(1)
7-
provided: int = provider1.provided()
8-
provider1_delegate: providers.Provider[int] = provider1.provider
8+
provided1 = provider1.provided
9+
provided_val1 = provided1()
10+
provider1_delegate = provider1.provider
11+
assert_type(provider1, providers.Provider[int])
12+
assert_type(provided1, providers.ProvidedInstance)
13+
assert_type(provided_val1, Any)
14+
assert_type(provider1_delegate, providers.Provider[int])
915

1016
# Test 2: to check async mode API
11-
provider2: providers.Provider[Any] = providers.Provider()
17+
provider2 = providers.Provider[Any]()
1218
provider2.enable_async_mode()
1319
provider2.disable_async_mode()
1420
provider2.reset_async_mode()
15-
r1: bool = provider2.is_async_mode_enabled()
16-
r2: bool = provider2.is_async_mode_disabled()
17-
r3: bool = provider2.is_async_mode_undefined()
21+
r1 = provider2.is_async_mode_enabled()
22+
r2 = provider2.is_async_mode_disabled()
23+
r3 = provider2.is_async_mode_undefined()
24+
assert_type(r1, bool)
25+
assert_type(r2, bool)
26+
assert_type(r3, bool)

0 commit comments

Comments
 (0)