Skip to content

Commit 59f52ca

Browse files
committed
fix(connector): propagate proxy headers on connection reuse
1 parent cfdafac commit 59f52ca

4 files changed

Lines changed: 121 additions & 20 deletions

File tree

CHANGES/2596.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Pass proxy authorization headers when reusing a connection. Prevents 407 (Proxy authentication required) when reusing a connection.

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ Georges Dubus
151151
Greg Holt
152152
Gregory Haynes
153153
Grigoriy Soldatov
154+
Guillaume Leurquin
154155
Gus Goulart
155156
Gustavo Carneiro
156157
Günther Jena

aiohttp/connector.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,30 @@ def _available_connections(self, key: "ConnectionKey") -> int:
550550

551551
return total_remain
552552

553+
def _update_proxy_auth_header_and_build_proxy_req(
554+
self, req: ClientRequest
555+
) -> ClientRequestBase:
556+
"""Set Proxy-Authorization header for non-SSL proxy requests and builds the proxy request for SSL proxy requests."""
557+
url = req.proxy
558+
assert url is not None
559+
headers = req.proxy_headers or CIMultiDict[str]()
560+
headers[hdrs.HOST] = req.headers[hdrs.HOST]
561+
proxy_req = ClientRequestBase(
562+
hdrs.METH_GET,
563+
url,
564+
headers=headers,
565+
auth=req.proxy_auth,
566+
loop=self._loop,
567+
ssl=req.ssl,
568+
)
569+
auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
570+
if auth is not None:
571+
if not req.is_ssl():
572+
req.headers[hdrs.PROXY_AUTHORIZATION] = auth
573+
else:
574+
proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth
575+
return proxy_req
576+
553577
async def connect(
554578
self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout"
555579
) -> Connection:
@@ -558,12 +582,16 @@ async def connect(
558582
if (conn := await self._get(key, traces)) is not None:
559583
# If we do not have to wait and we can get a connection from the pool
560584
# we can avoid the timeout ceil logic and directly return the connection
585+
if req.proxy:
586+
self._update_proxy_auth_header_and_build_proxy_req(req)
561587
return conn
562588

563589
async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
564590
if self._available_connections(key) <= 0:
565591
await self._wait_for_available_connection(key, traces)
566592
if (conn := await self._get(key, traces)) is not None:
593+
if req.proxy:
594+
self._update_proxy_auth_header_and_build_proxy_req(req)
567595
return conn
568596

569597
placeholder = cast(
@@ -1453,32 +1481,13 @@ async def _create_direct_connection(
14531481
async def _create_proxy_connection(
14541482
self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout"
14551483
) -> tuple[asyncio.BaseTransport, ResponseHandler]:
1456-
headers = CIMultiDict[str]() if req.proxy_headers is None else req.proxy_headers
1457-
headers[hdrs.HOST] = req.headers[hdrs.HOST]
1458-
1459-
url = req.proxy
1460-
assert url is not None
1461-
proxy_req = ClientRequestBase(
1462-
hdrs.METH_GET,
1463-
url,
1464-
headers=headers,
1465-
auth=req.proxy_auth,
1466-
loop=self._loop,
1467-
ssl=req.ssl,
1468-
)
1484+
proxy_req = self._update_proxy_auth_header_and_build_proxy_req(req)
14691485

14701486
# create connection to proxy server
14711487
transport, proto = await self._create_direct_connection(
14721488
proxy_req, [], timeout, client_error=ClientProxyConnectionError
14731489
)
14741490

1475-
auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
1476-
if auth is not None:
1477-
if not req.is_ssl():
1478-
req.headers[hdrs.PROXY_AUTHORIZATION] = auth
1479-
else:
1480-
proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth
1481-
14821491
if req.is_ssl():
14831492
self._warn_about_tls_in_tls(transport, req)
14841493

tests/test_connector.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Tests of http client with custom Connector
22
import asyncio
3+
import contextlib
34
import gc
45
import hashlib
56
import platform
@@ -16,6 +17,7 @@
1617
from unittest import mock
1718

1819
import pytest
20+
from multidict import CIMultiDict
1921
from pytest_mock import MockerFixture
2022
from yarl import URL
2123

@@ -25,6 +27,7 @@
2527
ClientSession,
2628
ClientTimeout,
2729
connector as connector_module,
30+
hdrs,
2831
web,
2932
)
3033
from aiohttp.abc import ResolveResult
@@ -3299,6 +3302,93 @@ async def test_connect_reuseconn_tracing(
32993302
await conn.close()
33003303

33013304

3305+
@pytest.mark.parametrize(
3306+
"test_case,wait_for_con,expect_proxy_auth_header",
3307+
[
3308+
("use_proxy_with_embedded_auth", False, True),
3309+
("use_proxy_with_auth_headers", True, True),
3310+
("use_proxy_no_auth", False, False),
3311+
("dont_use_proxy", False, False),
3312+
],
3313+
)
3314+
async def test_connect_reuse_proxy_headers( # type: ignore[misc]
3315+
loop: asyncio.AbstractEventLoop,
3316+
make_client_request: _RequestMaker,
3317+
test_case: str,
3318+
wait_for_con: bool,
3319+
expect_proxy_auth_header: bool,
3320+
) -> None:
3321+
proto = create_mocked_conn(loop)
3322+
proto.is_connected.return_value = True
3323+
3324+
if test_case != "dont_use_proxy":
3325+
proxy = (
3326+
URL("http://user:[email protected]")
3327+
if test_case == "use_proxy_with_embedded_auth"
3328+
else URL("http://example.com")
3329+
)
3330+
proxy_headers = (
3331+
CIMultiDict({hdrs.AUTHORIZATION: "Basic dXNlcjpwYXNzd29yZA=="})
3332+
if test_case == "use_proxy_with_auth_headers"
3333+
else None
3334+
)
3335+
else:
3336+
proxy = None
3337+
proxy_headers = None
3338+
key = ConnectionKey(
3339+
"localhost",
3340+
80,
3341+
False,
3342+
True,
3343+
proxy,
3344+
None,
3345+
hash(tuple(proxy_headers.items())) if proxy_headers else None,
3346+
)
3347+
req = make_client_request(
3348+
"GET",
3349+
URL("http://localhost:80"),
3350+
loop=loop,
3351+
response_class=mock.Mock(),
3352+
proxy=proxy,
3353+
proxy_headers=proxy_headers,
3354+
)
3355+
3356+
conn = aiohttp.BaseConnector(limit=1)
3357+
3358+
async def _create_con(*args: Any, **kwargs: Any) -> None:
3359+
conn._conns[key] = deque([(proto, loop.time())])
3360+
3361+
with contextlib.ExitStack() as stack:
3362+
if wait_for_con:
3363+
# Simulate no available connections
3364+
stack.enter_context(
3365+
mock.patch.object(
3366+
conn, "_available_connections", autospec=True, return_value=0
3367+
)
3368+
)
3369+
# Upon waiting for a connection, populate _conns with our proto,
3370+
# mocking a connection becoming immediately available
3371+
stack.enter_context(
3372+
mock.patch.object(
3373+
conn,
3374+
"_wait_for_available_connection",
3375+
autospec=True,
3376+
side_effect=_create_con,
3377+
)
3378+
)
3379+
else:
3380+
await _create_con()
3381+
# Call function to test
3382+
conn2 = await conn.connect(req, [], ClientTimeout())
3383+
conn2.release()
3384+
await conn.close()
3385+
3386+
if expect_proxy_auth_header:
3387+
assert req.headers[hdrs.PROXY_AUTHORIZATION] == "Basic dXNlcjpwYXNzd29yZA=="
3388+
else:
3389+
assert hdrs.PROXY_AUTHORIZATION not in req.headers
3390+
3391+
33023392
async def test_connect_with_limit_and_limit_per_host(
33033393
loop: asyncio.AbstractEventLoop,
33043394
key: ConnectionKey,

0 commit comments

Comments
 (0)