diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/decorators.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/decorators.py index a07ff9cfd..4ae64258c 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/decorators.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/decorators.py @@ -1,4 +1,5 @@ from typing import Any, Awaitable, Callable, TypeVar, Union +from urllib.parse import unquote from forestadmin.agent_toolkit.resources.collections.base_collection_resource import BaseCollectionResource from forestadmin.agent_toolkit.resources.collections.filter import parse_timezone @@ -42,6 +43,10 @@ async def _authenticate( except JWTError: return Response(status=401) + context_url = None + if "Forest-Context-Url" in request.headers: + context_url = unquote(request.headers["Forest-Context-Url"]) + request.user = User( rendering_id=int(user["rendering_id"]), user_id=int(user["id"]), @@ -51,6 +56,7 @@ async def _authenticate( last_name=user["last_name"], team=user["team"], timezone=parse_timezone(request), + context_url=context_url, ) return await decorated_fn(self, request) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py index dab2e96b7..c4f2921bc 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py @@ -32,6 +32,7 @@ class User: last_name: str team: str timezone: ZoneInfo + context_url: Optional[str] = None # permission_level # role diff --git a/src/agent_toolkit/tests/resources/collections/test_decorators.py b/src/agent_toolkit/tests/resources/collections/test_decorators.py index 4758d3f09..8eb29ea9f 100644 --- a/src/agent_toolkit/tests/resources/collections/test_decorators.py +++ b/src/agent_toolkit/tests/resources/collections/test_decorators.py @@ -183,6 +183,39 @@ async def _decorated_fn(resource, request): self.assertEqual(response, True) + def test_should_parse_forest_context_url_if_present(self): + user = { + "rendering_id": "1", + "id": "1", + "tags": {"test": "tag"}, + "email": "user@company.com", + "first_name": "first_name", + "last_name": "last_name", + "team": "best_team", + } + encoded_user = jwt.encode(user, "auth_secret") + request = RequestCollection( + RequestMethod.GET, + self.book_collection, + body=None, + query={"timezone": "Europe/Paris"}, + headers={ + "Authorization": f"Bearer {encoded_user}", + "Forest-Context-Url": "http://localhost/?param%3D%2Ftest%2F", + }, + ) + + async def _decorated_fn(resource, request): + self.assertEqual(request.user.context_url, "http://localhost/?param=/test/") + + return True + + decorated_fn = AsyncMock(wraps=_decorated_fn) + response = self.loop.run_until_complete(_authenticate(self.collection_resource, request, decorated_fn)) + decorated_fn.assert_awaited_once_with(self.collection_resource, request) + + self.assertEqual(response, True) + class TestAuthorizeDecorators(TestDecorators): @classmethod diff --git a/src/django_agent/forestadmin/django_agent/apps.py b/src/django_agent/forestadmin/django_agent/apps.py index 2e86ad016..32b98e229 100644 --- a/src/django_agent/forestadmin/django_agent/apps.py +++ b/src/django_agent/forestadmin/django_agent/apps.py @@ -4,6 +4,7 @@ import threading from typing import Callable, Optional, Union +from corsheaders import defaults as default_cors_settings from django.apps import AppConfig, apps from django.conf import settings from forestadmin.agent_toolkit.forest_logger import ForestLogger @@ -77,6 +78,7 @@ def get_agent(cls) -> DjangoAgent: def ready(self): # we need to wait for other apps to be ready, for this forest app must be ready # that's why we need another thread waiting for every app to be ready + self.setup_cors_settings() t = threading.Thread(name="forest.wait_and_launch_agent", target=self._wait_for_all_apps_ready_and_launch_agent) t.start() @@ -99,3 +101,13 @@ def _wait_for_all_apps_ready_and_launch_agent(self): ) DjangoAgentApp._DJANGO_AGENT = init_app_agent() + + def setup_cors_settings(self): + # headers + if getattr(settings, "CORS_ALLOW_HEADERS", None): + allowed_headers = settings.CORS_ALLOW_HEADERS + else: + allowed_headers = default_cors_settings.default_headers + + if "Forest-Context-Url" not in allowed_headers: + settings.CORS_ALLOW_HEADERS = (*allowed_headers, "Forest-Context-Url")