From f3a3f60b7a4b181272472b188e67d3b5f4ad3306 Mon Sep 17 00:00:00 2001 From: Cycloctane Date: Tue, 29 Oct 2024 04:21:55 +0800 Subject: [PATCH] Allow URLs with paths that end with `/` as base_url in ClientSession (#9530) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sam Bull Co-authored-by: J. Nick Koston --- CHANGES/9530.feature.rst | 2 ++ aiohttp/client.py | 8 ++++---- docs/client_reference.rst | 17 ++++++++++++++--- tests/test_client_session.py | 23 +++++++++++++++++++++++ 4 files changed, 43 insertions(+), 7 deletions(-) create mode 100644 CHANGES/9530.feature.rst diff --git a/CHANGES/9530.feature.rst b/CHANGES/9530.feature.rst new file mode 100644 index 00000000000..cc4e75a13ca --- /dev/null +++ b/CHANGES/9530.feature.rst @@ -0,0 +1,2 @@ +Updated :py:class:`~aiohttp.ClientSession` to support paths in ``base_url`` parameter. +``base_url`` paths must end with a ``/`` -- by :user:`Cycloctane`. diff --git a/aiohttp/client.py b/aiohttp/client.py index 15c89018ffb..7e466876f1d 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -300,9 +300,9 @@ def __init__( else: self._base_url = URL(base_url) self._base_url_origin = self._base_url.origin() - assert ( - self._base_url_origin == self._base_url - ), "Only absolute URLs without path part are supported" + assert self._base_url.absolute, "Only absolute URLs are supported" + if self._base_url is not None and not self._base_url.path.endswith("/"): + raise ValueError("base_url must have a trailing '/'") loop = asyncio.get_running_loop() @@ -415,7 +415,7 @@ def _build_url(self, str_or_url: StrOrURL) -> URL: if self._base_url is None: return url else: - assert not url.absolute and url.path.startswith("/") + assert not url.absolute return self._base_url.join(url) async def _request( diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 8cb6c91c2bc..ed290277b50 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -62,9 +62,20 @@ The client session supports the context manager protocol for self closing. :param base_url: Base part of the URL (optional) - If set, it allows to skip the base part (https://docs.aiohttp.org) in - request calls. It must not include a path (as in - https://docs.aiohttp.org/en/stable). + If set, allows to join a base part to relative URLs in request calls. + If the URL has a path it must have a trailing ``/`` (as in + https://docs.aiohttp.org/en/stable/). + + Note that URL joining follows :rfc:`3986`. This means, in the most + common case the request URLs should have no leading slash, e.g.:: + + session = ClientSession(base_url="http://example.com/foo/") + + await session.request("GET", "bar") + # request for http://example.com/foo/bar + + await session.request("GET", "/bar") + # request for http://example.com/bar .. versionadded:: 3.8 diff --git a/tests/test_client_session.py b/tests/test_client_session.py index c7d907b0010..54e8ff8c658 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -1122,6 +1122,24 @@ async def test_requote_redirect_url_default_disable() -> None: URL("http://example.com/test"), id="base_url=URL('http://example.com') url='/test'", ), + pytest.param( + URL("http://example.com/test1/"), + "test2", + URL("http://example.com/test1/test2"), + id="base_url=URL('http://example.com/test1/') url='test2'", + ), + pytest.param( + URL("http://example.com/test1/"), + "/test2", + URL("http://example.com/test2"), + id="base_url=URL('http://example.com/test1/') url='/test2'", + ), + pytest.param( + URL("http://example.com/test1/"), + "test2?q=foo#bar", + URL("http://example.com/test1/test2?q=foo#bar"), + id="base_url=URL('http://example.com/test1/') url='test2?q=foo#bar'", + ), ], ) async def test_build_url_returns_expected_url( @@ -1134,6 +1152,11 @@ async def test_build_url_returns_expected_url( assert session._build_url(url) == expected_url +async def test_base_url_without_trailing_slash() -> None: + with pytest.raises(ValueError, match="base_url must have a trailing '/'"): + ClientSession(base_url="http://example.com/test") + + async def test_instantiation_with_invalid_timeout_value( loop: asyncio.AbstractEventLoop, ) -> None: