Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix auth reset logic during redirects to different origin when _base_url set #8966

Merged
merged 4 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/8966.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updated ClientSession's auth logic to include default auth only if the request URL's origin matches _base_url; otherwise, the auth will not be included -- by :user:`MaximZemskov`
5 changes: 4 additions & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,10 @@ async def _request(

if auth is None:
auth = auth_from_url
if auth is None:

if auth is None and (
not self._base_url or self._base_url.origin() == url.origin()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling .origin() is a bit expensive #7583 (comment)

We should probably cache self._base_url.origin() as self._base_url_origin` so we don't have to build it every time

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably could guard this with self._default_auth being set as well

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

):
auth = self._default_auth
# It would be confusing if we support explicit
# Authorization header with auth argument
Expand Down
8 changes: 5 additions & 3 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ The client session supports the context manager protocol for self closing.

:param aiohttp.BasicAuth auth: an object that represents HTTP Basic
Authorization (optional). It will be included
with any request to any origin and will not be
removed, event during redirect to a different
origin.
with any request. However, if the
``_base_url`` parameter is set, the request
URL's origin must match the base URL's origin;
otherwise, the default auth will not be
included.

:param version: supported HTTP version, ``HTTP 1.1`` by default.

Expand Down
132 changes: 132 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2961,6 +2961,138 @@ async def close(self) -> None:
assert resp.status == 200


async def test_auth_persist_on_redirect_to_other_host_with_global_auth(
create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]],
) -> None:
url_from = URL("http://host1.com/path1")
url_to = URL("http://host2.com/path2")

async def srv_from(request: web.Request) -> NoReturn:
assert request.host == url_from.host
assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz"
raise web.HTTPFound(url_to)

async def srv_to(request: web.Request) -> web.Response:
assert request.host == url_to.host
assert "Authorization" in request.headers, "Header was dropped"
return web.Response()

server_from = await create_server_for_url_and_handler(url_from, srv_from)
server_to = await create_server_for_url_and_handler(url_to, srv_to)

assert (
url_from.host != url_to.host or server_from.scheme != server_to.scheme
), "Invalid test case, host or scheme must differ"

protocol_port_map = {
"http": 80,
"https": 443,
}
etc_hosts = {
(url_from.host, protocol_port_map[server_from.scheme]): server_from,
(url_to.host, protocol_port_map[server_to.scheme]): server_to,
}

class FakeResolver(AbstractResolver):
async def resolve(
self,
host: str,
port: int = 0,
family: socket.AddressFamily = socket.AF_INET,
) -> List[ResolveResult]:
server = etc_hosts[(host, port)]
assert server.port is not None

return [
{
"hostname": host,
"host": server.host,
"port": server.port,
"family": socket.AF_INET,
"proto": 0,
"flags": socket.AI_NUMERICHOST,
}
]

async def close(self) -> None:
"""Dummy"""

connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False)

async with aiohttp.ClientSession(
connector=connector, auth=aiohttp.BasicAuth("user", "pass")
) as client:
resp = await client.get(url_from)
assert resp.status == 200


async def test_drop_auth_on_redirect_to_other_host_with_global_auth_and_base_url(
create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]],
) -> None:
url_from = URL("http://host1.com/path1")
url_to = URL("http://host2.com/path2")

async def srv_from(request: web.Request) -> NoReturn:
assert request.host == url_from.host
assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz"
raise web.HTTPFound(url_to)

async def srv_to(request: web.Request) -> web.Response:
assert request.host == url_to.host
assert "Authorization" not in request.headers, "Header was not dropped"
return web.Response()

server_from = await create_server_for_url_and_handler(url_from, srv_from)
server_to = await create_server_for_url_and_handler(url_to, srv_to)

assert (
url_from.host != url_to.host or server_from.scheme != server_to.scheme
), "Invalid test case, host or scheme must differ"

protocol_port_map = {
"http": 80,
"https": 443,
}
etc_hosts = {
(url_from.host, protocol_port_map[server_from.scheme]): server_from,
(url_to.host, protocol_port_map[server_to.scheme]): server_to,
}

class FakeResolver(AbstractResolver):
async def resolve(
self,
host: str,
port: int = 0,
family: socket.AddressFamily = socket.AF_INET,
) -> List[ResolveResult]:
server = etc_hosts[(host, port)]
assert server.port is not None

return [
{
"hostname": host,
"host": server.host,
"port": server.port,
"family": socket.AF_INET,
"proto": 0,
"flags": socket.AI_NUMERICHOST,
}
]

async def close(self) -> None:
"""Dummy"""

connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False)

async with aiohttp.ClientSession(
connector=connector,
base_url="http://host1.com",
auth=aiohttp.BasicAuth("user", "pass"),
) as client:
resp = await client.get("/path1")
assert resp.status == 200


async def test_async_with_session() -> None:
async with aiohttp.ClientSession() as session:
pass
Expand Down
Loading