Skip to content

Commit 0284e85

Browse files
Tasssadarbdraco
andcommitted
[PR aio-libs#11150/996ad00 backport][3.12] fix: leak of aiodns.DNSResolver when ClientSession is closed (aio-libs#11150)
Co-authored-by: J. Nick Koston <[email protected]> (cherry picked from commit 996ad00)
1 parent c0e04a2 commit 0284e85

File tree

5 files changed

+41
-5
lines changed

5 files changed

+41
-5
lines changed

CHANGES/11150.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fixed leak of ``aiodns.DNSResolver`` when :py:class:`~aiohttp.TCPConnector` is closed and no resolver was passed when creating the connector -- by :user:`Tasssadar`.
2+
3+
This was a regression introduced in version 3.12.0 (:pr:`10897`).

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ Vladimir Shulyak
368368
Vladimir Zakharov
369369
Vladyslav Bohaichuk
370370
Vladyslav Bondar
371+
Vojtěch Boček
371372
W. Trevor King
372373
Wei Lin
373374
Weiwei Wang

aiohttp/connector.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -926,9 +926,14 @@ def __init__(
926926
)
927927

928928
self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)
929+
930+
self._resolver: AbstractResolver
929931
if resolver is None:
930-
resolver = DefaultResolver(loop=self._loop)
931-
self._resolver = resolver
932+
self._resolver = DefaultResolver(loop=self._loop)
933+
self._resolver_owner = True
934+
else:
935+
self._resolver = resolver
936+
self._resolver_owner = False
932937

933938
self._use_dns_cache = use_dns_cache
934939
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
@@ -956,6 +961,12 @@ def _close(self) -> List[Awaitable[object]]:
956961

957962
return waiters
958963

964+
async def close(self) -> None:
965+
"""Close all opened transports."""
966+
if self._resolver_owner:
967+
await self._resolver.close()
968+
await super().close()
969+
959970
@property
960971
def family(self) -> int:
961972
"""Socket family like AF_INET."""
@@ -1709,7 +1720,8 @@ def __init__(
17091720
loop=loop,
17101721
)
17111722
if not isinstance(
1712-
self._loop, asyncio.ProactorEventLoop # type: ignore[attr-defined]
1723+
self._loop,
1724+
asyncio.ProactorEventLoop, # type: ignore[attr-defined]
17131725
):
17141726
raise RuntimeError(
17151727
"Named Pipes only available in proactor loop under windows"

aiohttp/resolver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,10 @@ def release_resolver(
258258
loop: The event loop the resolver was using.
259259
"""
260260
# Remove client from its loop's tracking
261-
if loop not in self._loop_data:
261+
current_loop_data = self._loop_data.get(loop)
262+
if current_loop_data is None:
262263
return
263-
resolver, client_set = self._loop_data[loop]
264+
resolver, client_set = current_loop_data
264265
client_set.discard(client)
265266
# If no more clients for this loop, cancel and remove its resolver
266267
if not client_set:

tests/test_connector.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,7 @@ async def test_tcp_connector_dns_cache_not_expired(loop, dns_response) -> None:
12701270
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
12711271
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
12721272
m_resolver().resolve.return_value = dns_response()
1273+
m_resolver().close = mock.AsyncMock()
12731274
await conn._resolve_host("localhost", 8080)
12741275
await conn._resolve_host("localhost", 8080)
12751276
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)
@@ -1281,6 +1282,7 @@ async def test_tcp_connector_dns_cache_forever(loop, dns_response) -> None:
12811282
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
12821283
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
12831284
m_resolver().resolve.return_value = dns_response()
1285+
m_resolver().close = mock.AsyncMock()
12841286
await conn._resolve_host("localhost", 8080)
12851287
await conn._resolve_host("localhost", 8080)
12861288
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)
@@ -1292,6 +1294,7 @@ async def test_tcp_connector_use_dns_cache_disabled(loop, dns_response) -> None:
12921294
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
12931295
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)
12941296
m_resolver().resolve.side_effect = [dns_response(), dns_response()]
1297+
m_resolver().close = mock.AsyncMock()
12951298
await conn._resolve_host("localhost", 8080)
12961299
await conn._resolve_host("localhost", 8080)
12971300
m_resolver().resolve.assert_has_calls(
@@ -1308,6 +1311,7 @@ async def test_tcp_connector_dns_throttle_requests(loop, dns_response) -> None:
13081311
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
13091312
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
13101313
m_resolver().resolve.return_value = dns_response()
1314+
m_resolver().close = mock.AsyncMock()
13111315
loop.create_task(conn._resolve_host("localhost", 8080))
13121316
loop.create_task(conn._resolve_host("localhost", 8080))
13131317
await asyncio.sleep(0)
@@ -1322,6 +1326,7 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop) -> Non
13221326
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
13231327
e = Exception()
13241328
m_resolver().resolve.side_effect = e
1329+
m_resolver().close = mock.AsyncMock()
13251330
r1 = loop.create_task(conn._resolve_host("localhost", 8080))
13261331
r2 = loop.create_task(conn._resolve_host("localhost", 8080))
13271332
await asyncio.sleep(0)
@@ -1341,6 +1346,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
13411346
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
13421347
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
13431348
m_resolver().resolve.return_value = dns_response()
1349+
m_resolver().close = mock.AsyncMock()
13441350
loop.create_task(conn._resolve_host("localhost", 8080))
13451351
f = loop.create_task(conn._resolve_host("localhost", 8080))
13461352

@@ -1384,6 +1390,7 @@ def exception_handler(loop, context):
13841390
use_dns_cache=False,
13851391
)
13861392
m_resolver().resolve.return_value = dns_response_error()
1393+
m_resolver().close = mock.AsyncMock()
13871394
f = loop.create_task(conn._create_direct_connection(req, [], ClientTimeout(0)))
13881395

13891396
await asyncio.sleep(0)
@@ -1419,6 +1426,7 @@ async def test_tcp_connector_dns_tracing(loop, dns_response) -> None:
14191426
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
14201427

14211428
m_resolver().resolve.return_value = dns_response()
1429+
m_resolver().close = mock.AsyncMock()
14221430

14231431
await conn._resolve_host("localhost", 8080, traces=traces)
14241432
on_dns_resolvehost_start.assert_called_once_with(
@@ -1460,6 +1468,7 @@ async def test_tcp_connector_dns_tracing_cache_disabled(loop, dns_response) -> N
14601468
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)
14611469

14621470
m_resolver().resolve.side_effect = [dns_response(), dns_response()]
1471+
m_resolver().close = mock.AsyncMock()
14631472

14641473
await conn._resolve_host("localhost", 8080, traces=traces)
14651474

@@ -1514,6 +1523,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -
15141523
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
15151524
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
15161525
m_resolver().resolve.return_value = dns_response()
1526+
m_resolver().close = mock.AsyncMock()
15171527
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
15181528
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
15191529
await asyncio.sleep(0)
@@ -1528,6 +1538,14 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -
15281538
await conn.close()
15291539

15301540

1541+
async def test_tcp_connector_close_resolver() -> None:
1542+
m_resolver = mock.AsyncMock()
1543+
with mock.patch("aiohttp.connector.DefaultResolver", return_value=m_resolver):
1544+
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
1545+
await conn.close()
1546+
m_resolver.close.assert_awaited_once()
1547+
1548+
15311549
async def test_dns_error(loop) -> None:
15321550
connector = aiohttp.TCPConnector(loop=loop)
15331551
connector._resolve_host = mock.AsyncMock(
@@ -3691,6 +3709,7 @@ async def resolve_response() -> List[ResolveResult]:
36913709

36923710
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
36933711
m_resolver().resolve.return_value = resolve_response()
3712+
m_resolver().close = mock.AsyncMock()
36943713

36953714
connector = TCPConnector()
36963715
traces = [DummyTracer()]

0 commit comments

Comments
 (0)