Skip to content

Commit 58e7f8f

Browse files
patchback[bot], Dreamsorcerer, and gonas
authored
[PR #12106/8ab84c52 backport][3.14] Bound DNS cache (#12116)
**This is a backport of PR #12106 as merged into master (8ab84c5).**

---------

Co-authored-by: Sam Bull <[email protected]>
Co-authored-by: gonas <[email protected]>
1 parent f82bc8a commit 58e7f8f

3 files changed

Lines changed: 95 additions & 6 deletions

File tree

CHANGES/12106.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added a ``dns_cache_max_size`` parameter to ``TCPConnector`` to limit the size of the cache -- by :user:`Dreamsorcerer`.

aiohttp/connector.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -839,25 +839,33 @@ async def _create_connection(
839839

840840

841841
class _DNSCacheTable:
842-
def __init__(self, ttl: float | None = None) -> None:
843-
self._addrs_rr: dict[tuple[str, int], tuple[Iterator[ResolveResult], int]] = {}
842+
def __init__(self, ttl: float | None = None, max_size: int = 1000) -> None:
843+
self._addrs_rr: OrderedDict[
844+
tuple[str, int], tuple[Iterator[ResolveResult], int]
845+
] = OrderedDict()
844846
self._timestamps: dict[tuple[str, int], float] = {}
845847
self._ttl = ttl
848+
self._max_size = max_size
846849

847850
def __contains__(self, host: object) -> bool:
848851
return host in self._addrs_rr
849852

850853
def add(self, key: tuple[str, int], addrs: list[ResolveResult]) -> None:
854+
if key in self._addrs_rr:
855+
self._addrs_rr.move_to_end(key)
856+
851857
self._addrs_rr[key] = (cycle(addrs), len(addrs))
852858

853859
if self._ttl is not None:
854860
self._timestamps[key] = monotonic()
855861

862+
if len(self._addrs_rr) > self._max_size:
863+
oldest_key, _ = self._addrs_rr.popitem(last=False)
864+
self._timestamps.pop(oldest_key, None)
865+
856866
def remove(self, key: tuple[str, int]) -> None:
857867
self._addrs_rr.pop(key, None)
858-
859-
if self._ttl is not None:
860-
self._timestamps.pop(key, None)
868+
self._timestamps.pop(key, None)
861869

862870
def clear(self) -> None:
863871
self._addrs_rr.clear()
@@ -868,6 +876,7 @@ def next_addrs(self, key: tuple[str, int]) -> list[ResolveResult]:
868876
addrs = list(islice(loop, length))
869877
# Consume one more element to shift internal state of `cycle`
870878
next(loop)
879+
self._addrs_rr.move_to_end(key)
871880
return addrs
872881

873882
def expired(self, key: tuple[str, int]) -> bool:
@@ -956,6 +965,7 @@ def __init__(
956965
fingerprint: bytes | None = None,
957966
use_dns_cache: bool = True,
958967
ttl_dns_cache: int | None = 10,
968+
dns_cache_max_size: int = 1000,
959969
family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC,
960970
ssl_context: SSLContext | None = None,
961971
ssl: bool | Fingerprint | SSLContext = True,
@@ -994,7 +1004,9 @@ def __init__(
9941004
self._resolver_owner = False
9951005

9961006
self._use_dns_cache = use_dns_cache
997-
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
1007+
self._cached_hosts = _DNSCacheTable(
1008+
ttl=ttl_dns_cache, max_size=dns_cache_max_size
1009+
)
9981010
self._throttle_dns_futures: dict[tuple[str, int], set[asyncio.Future[None]]] = (
9991011
{}
10001012
)

tests/test_connector.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4027,6 +4027,25 @@ async def handler(request):
40274027

40284028

40294029
class TestDNSCacheTable:
4030+
host1 = ("localhost", 80)
4031+
host2 = ("foo", 80)
4032+
result1: ResolveResult = {
4033+
"hostname": "localhost",
4034+
"host": "127.0.0.1",
4035+
"port": 80,
4036+
"family": socket.AF_INET,
4037+
"proto": 0,
4038+
"flags": socket.AI_NUMERICHOST,
4039+
}
4040+
result2: ResolveResult = {
4041+
"hostname": "foo",
4042+
"host": "127.0.0.2",
4043+
"port": 80,
4044+
"family": socket.AF_INET,
4045+
"proto": 0,
4046+
"flags": socket.AI_NUMERICHOST,
4047+
}
4048+
40304049
@pytest.fixture
40314050
def dns_cache_table(self):
40324051
return _DNSCacheTable()
@@ -4112,6 +4131,63 @@ def test_next_addrs_single(self, dns_cache_table) -> None:
41124131
addrs = dns_cache_table.next_addrs("foo")
41134132
assert addrs == ["127.0.0.1"]
41144133

4134+
def test_max_size_eviction(self) -> None:
4135+
table = _DNSCacheTable(max_size=2)
4136+
4137+
table.add(self.host1, [self.result1])
4138+
table.add(self.host2, [self.result2])
4139+
4140+
host3 = ("example.com", 80)
4141+
result3: ResolveResult = {
4142+
**self.result1,
4143+
"hostname": "example.com",
4144+
"host": "1.2.3.4",
4145+
}
4146+
table.add(host3, [result3])
4147+
4148+
assert len(table._addrs_rr) == 2
4149+
assert self.host1 not in table._addrs_rr
4150+
assert host3 in table._addrs_rr
4151+
4152+
def test_lru_eviction(self) -> None:
4153+
table = _DNSCacheTable(max_size=2)
4154+
4155+
table.add(self.host1, [self.result1])
4156+
table.add(self.host2, [self.result2])
4157+
4158+
table.next_addrs(self.host1)
4159+
4160+
host3 = ("example.com", 80)
4161+
result3: ResolveResult = {
4162+
**self.result1,
4163+
"hostname": "example.com",
4164+
"host": "1.2.3.4",
4165+
}
4166+
table.add(host3, [result3])
4167+
4168+
assert self.host1 in table._addrs_rr
4169+
assert self.host2 not in table._addrs_rr
4170+
4171+
def test_lru_eviction_add(self) -> None:
4172+
table = _DNSCacheTable(max_size=2)
4173+
4174+
table.add(self.host1, [self.result1])
4175+
table.add(self.host2, [self.result2])
4176+
4177+
# Re-add, thus making host1 the most recently used.
4178+
table.add(self.host1, [self.result1])
4179+
4180+
host3 = ("example.com", 80)
4181+
result3: ResolveResult = {
4182+
**self.result1,
4183+
"hostname": "example.com",
4184+
"host": "1.2.3.4",
4185+
}
4186+
table.add(host3, [result3])
4187+
4188+
assert self.host1 in table._addrs_rr
4189+
assert self.host2 not in table._addrs_rr
4190+
41154191

41164192
async def test_connector_cache_trace_race():
41174193
class DummyTracer:

0 commit comments

Comments
 (0)