Skip to content

Commit

Permalink
Fix auth reset logic during redirects to different origin when _base_…
Browse files Browse the repository at this point in the history
…url set (#8966)
  • Loading branch information
MaximZemskov authored Sep 1, 2024
1 parent 45d6e4f commit f569894
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGES/8966.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updated ClientSession's auth logic to include default auth only if the request URL's origin matches _base_url; otherwise, the auth will not be included -- by :user:`MaximZemskov`
5 changes: 4 additions & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,10 @@ async def _request(

if auth is None:
auth = auth_from_url
if auth is None:

if auth is None and (
not self._base_url or self._base_url.origin() == url.origin()
):
auth = self._default_auth
# It would be confusing if we support explicit
# Authorization header with auth argument
Expand Down
8 changes: 5 additions & 3 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ The client session supports the context manager protocol for self closing.

:param aiohttp.BasicAuth auth: an object that represents HTTP Basic
Authorization (optional). It will be included
with any request to any origin and will not be
removed, event during redirect to a different
origin.
with any request. However, if the
``_base_url`` parameter is set, the request
URL's origin must match the base URL's origin;
otherwise, the default auth will not be
included.

:param version: supported HTTP version, ``HTTP 1.1`` by default.

Expand Down
132 changes: 132 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2961,6 +2961,138 @@ async def close(self) -> None:
assert resp.status == 200


async def test_auth_persist_on_redirect_to_other_host_with_global_auth(
create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]],
) -> None:
url_from = URL("http://host1.com/path1")
url_to = URL("http://host2.com/path2")

async def srv_from(request: web.Request) -> NoReturn:
assert request.host == url_from.host
assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz"
raise web.HTTPFound(url_to)

async def srv_to(request: web.Request) -> web.Response:
assert request.host == url_to.host
assert "Authorization" in request.headers, "Header was dropped"
return web.Response()

server_from = await create_server_for_url_and_handler(url_from, srv_from)
server_to = await create_server_for_url_and_handler(url_to, srv_to)

assert (
url_from.host != url_to.host or server_from.scheme != server_to.scheme
), "Invalid test case, host or scheme must differ"

protocol_port_map = {
"http": 80,
"https": 443,
}
etc_hosts = {
(url_from.host, protocol_port_map[server_from.scheme]): server_from,
(url_to.host, protocol_port_map[server_to.scheme]): server_to,
}

class FakeResolver(AbstractResolver):
async def resolve(
self,
host: str,
port: int = 0,
family: socket.AddressFamily = socket.AF_INET,
) -> List[ResolveResult]:
server = etc_hosts[(host, port)]
assert server.port is not None

return [
{
"hostname": host,
"host": server.host,
"port": server.port,
"family": socket.AF_INET,
"proto": 0,
"flags": socket.AI_NUMERICHOST,
}
]

async def close(self) -> None:
"""Dummy"""

connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False)

async with aiohttp.ClientSession(
connector=connector, auth=aiohttp.BasicAuth("user", "pass")
) as client:
resp = await client.get(url_from)
assert resp.status == 200


async def test_drop_auth_on_redirect_to_other_host_with_global_auth_and_base_url(
create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]],
) -> None:
url_from = URL("http://host1.com/path1")
url_to = URL("http://host2.com/path2")

async def srv_from(request: web.Request) -> NoReturn:
assert request.host == url_from.host
assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz"
raise web.HTTPFound(url_to)

async def srv_to(request: web.Request) -> web.Response:
assert request.host == url_to.host
assert "Authorization" not in request.headers, "Header was not dropped"
return web.Response()

server_from = await create_server_for_url_and_handler(url_from, srv_from)
server_to = await create_server_for_url_and_handler(url_to, srv_to)

assert (
url_from.host != url_to.host or server_from.scheme != server_to.scheme
), "Invalid test case, host or scheme must differ"

protocol_port_map = {
"http": 80,
"https": 443,
}
etc_hosts = {
(url_from.host, protocol_port_map[server_from.scheme]): server_from,
(url_to.host, protocol_port_map[server_to.scheme]): server_to,
}

class FakeResolver(AbstractResolver):
async def resolve(
self,
host: str,
port: int = 0,
family: socket.AddressFamily = socket.AF_INET,
) -> List[ResolveResult]:
server = etc_hosts[(host, port)]
assert server.port is not None

return [
{
"hostname": host,
"host": server.host,
"port": server.port,
"family": socket.AF_INET,
"proto": 0,
"flags": socket.AI_NUMERICHOST,
}
]

async def close(self) -> None:
"""Dummy"""

connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False)

async with aiohttp.ClientSession(
connector=connector,
base_url="http://host1.com",
auth=aiohttp.BasicAuth("user", "pass"),
) as client:
resp = await client.get("/path1")
assert resp.status == 200


async def test_async_with_session() -> None:
async with aiohttp.ClientSession() as session:
pass
Expand Down

0 comments on commit f569894

Please sign in to comment.