Skip to content

Commit 8e72f05

Browse files
author
rodrigo.nogueira
committed
feat: Allow alru_cache to automatically clear and rebind to the current event loop on cross-loop access instead of raising a RuntimeError.
1 parent 7ef00b7 commit 8e72f05

2 files changed

Lines changed: 53 additions & 30 deletions

File tree

async_lru/__init__.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,10 @@ def _check_loop(self, loop: asyncio.AbstractEventLoop) -> None:
130130
if self.__first_loop is None:
131131
self.__first_loop = loop
132132
elif self.__first_loop is not loop:
133-
raise RuntimeError(
134-
"alru_cache is not safe to use across event loops: this cache "
135-
"instance was first used with a different event loop. "
136-
"Use separate cache instances per event loop."
137-
)
133+
# Old cache entries hold tasks/handles bound to the previous
134+
# loop and are invalid here. Clear and rebind.
135+
self.cache_clear()
136+
self.__first_loop = loop
138137

139138
def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
140139
key = _make_key(args, kwargs, self.__typed)
@@ -156,8 +155,6 @@ def cache_clear(self) -> None:
156155
self.__cache.clear()
157156

158157
async def cache_close(self, *, wait: bool = False) -> None:
159-
loop = asyncio.get_running_loop()
160-
self._check_loop(loop)
161158
self.__closed = True
162159

163160
tasks = self.__tasks
@@ -236,9 +233,10 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
236233
raise RuntimeError(f"alru_cache is closed for {self}")
237234

238235
loop = asyncio.get_running_loop()
236+
self._check_loop(loop)
237+
239238
key = _make_key(fn_args, fn_kwargs, self.__typed)
240239
cache_item = self.__cache.get(key)
241-
self._check_loop(loop)
242240

243241
if cache_item is not None:
244242
self._cache_hit(key)

tests/test_thread_safety.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from async_lru import alru_cache
44

55

6-
def test_cross_loop_access_raises_error() -> None:
6+
def test_cross_loop_auto_resets_cache() -> None:
77
@alru_cache(maxsize=100)
88
async def cached_func(key: str) -> str:
99
return f"data_{key}"
@@ -12,19 +12,39 @@ async def cached_func(key: str) -> str:
1212
loop1.run_until_complete(cached_func("test"))
1313
loop1.close()
1414

15+
assert cached_func.cache_info().currsize == 1
16+
1517
loop2 = asyncio.new_event_loop()
16-
error_raised = False
17-
error_message = ""
18-
try:
19-
loop2.run_until_complete(cached_func("test"))
20-
except RuntimeError as e:
21-
error_raised = True
22-
error_message = str(e)
23-
finally:
24-
loop2.close()
18+
result = loop2.run_until_complete(cached_func("test"))
19+
loop2.close()
20+
21+
assert result == "data_test"
22+
# Cache was cleared on loop change, so the old entry is gone.
23+
# The new call re-populated it as a miss.
24+
assert cached_func.cache_info().hits == 0
25+
assert cached_func.cache_info().misses == 1
26+
2527

26-
assert error_raised, "RuntimeError should be raised for cross-loop access"
27-
assert "event loop" in error_message.lower()
28+
def test_cross_loop_preserves_stats_reset() -> None:
29+
@alru_cache(maxsize=100)
30+
async def cached_func(key: str) -> str:
31+
return f"data_{key}"
32+
33+
loop1 = asyncio.new_event_loop()
34+
loop1.run_until_complete(cached_func("a"))
35+
loop1.run_until_complete(cached_func("a"))
36+
loop1.close()
37+
38+
assert cached_func.cache_info().hits == 1
39+
assert cached_func.cache_info().misses == 1
40+
41+
loop2 = asyncio.new_event_loop()
42+
loop2.run_until_complete(cached_func("a"))
43+
loop2.close()
44+
45+
# Stats were reset on loop change (cache_clear resets hits/misses)
46+
assert cached_func.cache_info().hits == 0
47+
assert cached_func.cache_info().misses == 1
2848

2949

3050
def test_invalid_key_does_not_bind_loop() -> None:
@@ -72,7 +92,7 @@ async def run_test() -> list[str]:
7292
assert cached_func.cache_info().hits == 1
7393

7494

75-
def test_cross_loop_cache_close_raises_error() -> None:
95+
def test_cross_loop_cache_close_works() -> None:
7696
@alru_cache(maxsize=100)
7797
async def cached_func(key: str) -> str:
7898
return f"data_{key}"
@@ -82,15 +102,8 @@ async def cached_func(key: str) -> str:
82102
loop1.close()
83103

84104
loop2 = asyncio.new_event_loop()
85-
error_raised = False
86-
try:
87-
loop2.run_until_complete(cached_func.cache_close())
88-
except RuntimeError:
89-
error_raised = True
90-
finally:
91-
loop2.close()
92-
93-
assert error_raised, "RuntimeError should be raised for cross-loop cache_close"
105+
loop2.run_until_complete(cached_func.cache_close())
106+
loop2.close()
94107

95108

96109
def test_sync_methods_work_without_loop_check() -> None:
@@ -125,3 +138,15 @@ async def run_concurrent() -> list[str]:
125138

126139
assert results == ["data_test"] * 3
127140
assert cached_func.cache_info().hits == 2
141+
142+
143+
def test_multiple_loop_transitions() -> None:
144+
@alru_cache(maxsize=100)
145+
async def cached_func(key: str) -> str:
146+
return f"data_{key}"
147+
148+
for i in range(5):
149+
loop = asyncio.new_event_loop()
150+
result = loop.run_until_complete(cached_func("test"))
151+
loop.close()
152+
assert result == "data_test"

0 commit comments

Comments
 (0)