diff --git a/pyodk/_endpoints/comments.py b/pyodk/_endpoints/comments.py index a5cd02d..1bab2ba 100644 --- a/pyodk/_endpoints/comments.py +++ b/pyodk/_endpoints/comments.py @@ -70,7 +70,9 @@ def list( response = self.session.response_or_error( method="GET", - url=self.urls.list.format(project_id=pid, form_id=fid, instance_id=iid), + url=self.session.urlformat( + self.urls.list, project_id=pid, form_id=fid, instance_id=iid + ), logger=log, ) data = response.json() @@ -105,7 +107,9 @@ def post( response = self.session.response_or_error( method="POST", - url=self.urls.post.format(project_id=pid, form_id=fid, instance_id=iid), + url=self.session.urlformat( + self.urls.post, project_id=pid, form_id=fid, instance_id=iid + ), logger=log, json=json, ) diff --git a/pyodk/_endpoints/form_assignments.py b/pyodk/_endpoints/form_assignments.py index 8376091..4968b82 100644 --- a/pyodk/_endpoints/form_assignments.py +++ b/pyodk/_endpoints/form_assignments.py @@ -58,8 +58,8 @@ def assign( response = self.session.response_or_error( method="POST", - url=self.urls.post.format( - project_id=pid, form_id=fid, role_id=rid, user_id=uid + url=self.session.urlformat( + self.urls.post, project_id=pid, form_id=fid, role_id=rid, user_id=uid ), logger=log, ) diff --git a/pyodk/_endpoints/form_draft_attachments.py b/pyodk/_endpoints/form_draft_attachments.py index b376082..0fc4bb6 100644 --- a/pyodk/_endpoints/form_draft_attachments.py +++ b/pyodk/_endpoints/form_draft_attachments.py @@ -61,7 +61,9 @@ def upload( with open(file_path, "rb") as fd: response = self.session.response_or_error( method="POST", - url=self.urls.post.format(project_id=pid, form_id=fid, fname=file_name), + url=self.session.urlformat( + self.urls.post, project_id=pid, form_id=fid, fname=file_name + ), logger=log, data=fd, ) diff --git a/pyodk/_endpoints/form_drafts.py b/pyodk/_endpoints/form_drafts.py index 0c71da9..ad504a5 100644 --- a/pyodk/_endpoints/form_drafts.py +++ b/pyodk/_endpoints/form_drafts.py @@ -76,7 +76,7 @@ def create( ) headers = { "Content-Type": content_type, - "X-XlsForm-FormId-Fallback": file_path.stem, + "X-XlsForm-FormId-Fallback": self.session.urlquote(file_path.stem), } except PyODKError as err: log.error(err, exc_info=True) @@ -85,7 +85,7 @@ def create( with open(file_path, "rb") if file_path is not None else nullcontext() as fd: response = self.session.response_or_error( method="POST", - url=self.urls.post.format(project_id=pid, form_id=fid), + url=self.session.urlformat(self.urls.post, project_id=pid, form_id=fid), logger=log, headers=headers, params=params, @@ -121,7 +121,9 @@ def publish( response = self.session.response_or_error( method="POST", - url=self.urls.post_publish.format(project_id=pid, form_id=fid), + url=self.session.urlformat( + self.urls.post_publish, project_id=pid, form_id=fid + ), logger=log, params=params, ) diff --git a/pyodk/_endpoints/forms.py b/pyodk/_endpoints/forms.py index bf57f31..87dcb22 100644 --- a/pyodk/_endpoints/forms.py +++ b/pyodk/_endpoints/forms.py @@ -85,7 +85,7 @@ def list(self, project_id: Optional[int] = None) -> List[Form]: else: response = self.session.response_or_error( method="GET", - url=self.urls.list.format(project_id=pid), + url=self.session.urlformat(self.urls.list, project_id=pid), logger=log, ) data = response.json() @@ -113,7 +113,7 @@ def get( else: response = self.session.response_or_error( method="GET", - url=self.urls.get.format(project_id=pid, form_id=fid), + url=self.session.urlformat(self.urls.get, project_id=pid, form_id=fid), logger=log, ) data = response.json() diff --git a/pyodk/_endpoints/project_app_users.py b/pyodk/_endpoints/project_app_users.py index 7ba17ee..5de7d6d 100644 --- a/pyodk/_endpoints/project_app_users.py +++ b/pyodk/_endpoints/project_app_users.py @@ -63,7 +63,7 @@ def list( response = self.session.response_or_error( method="GET", - url=self.urls.list.format(project_id=pid), + url=self.session.urlformat(self.urls.list, project_id=pid), logger=log, ) data = response.json() @@ -92,7 +92,7 @@ def create( response = self.session.response_or_error( method="POST", - url=self.urls.post.format(project_id=pid), + url=self.session.urlformat(self.urls.post, project_id=pid), logger=log, json=json, ) diff --git a/pyodk/_endpoints/projects.py b/pyodk/_endpoints/projects.py index 901de2d..159e3fb 100644 --- a/pyodk/_endpoints/projects.py +++ b/pyodk/_endpoints/projects.py @@ -95,7 +95,7 @@ def get(self, project_id: Optional[int] = None) -> Project: else: response = self.session.response_or_error( method="GET", - url=self.urls.get.format(project_id=pid), + url=self.session.urlformat(self.urls.get, project_id=pid), logger=log, ) data = response.json() diff --git a/pyodk/_endpoints/submissions.py b/pyodk/_endpoints/submissions.py index 41ff07e..26f1489 100644 --- a/pyodk/_endpoints/submissions.py +++ b/pyodk/_endpoints/submissions.py @@ -88,7 +88,7 @@ def list( response = self.session.response_or_error( method="GET", - url=self.urls.list.format(project_id=pid, form_id=fid), + url=self.session.urlformat(self.urls.list, project_id=pid, form_id=fid), logger=log, ) data = response.json() @@ -119,7 +119,9 @@ def get( response = self.session.response_or_error( method="GET", - url=self.urls.get.format(project_id=pid, form_id=fid, instance_id=iid), + url=self.session.urlformat( + self.urls.get, project_id=pid, form_id=fid, instance_id=iid + ), logger=log, ) data = response.json() @@ -180,7 +182,9 @@ def get_table( response = self.session.response_or_error( method="GET", - url=self.urls.get_table.format(project_id=pid, form_id=fid, table_name=table), + url=self.session.urlformat( + self.urls.get_table, project_id=pid, form_id=fid, table_name=table + ), logger=log, params=params, ) @@ -225,7 +229,7 @@ def create( response = self.session.response_or_error( method="POST", - url=self.urls.post.format(project_id=pid, form_id=fid), + url=self.session.urlformat(self.urls.post, project_id=pid, form_id=fid), logger=log, headers={"Content-Type": "application/xml"}, params=params, @@ -272,7 +276,9 @@ def _put( response = self.session.response_or_error( method="PUT", - url=self.urls.put.format(project_id=pid, form_id=fid, instance_id=iid), + url=self.session.urlformat( + self.urls.put, project_id=pid, form_id=fid, instance_id=iid + ), logger=log, headers={"Content-Type": "application/xml"}, data=xml, @@ -308,7 +314,9 @@ def _patch( response = self.session.response_or_error( method="PATCH", - url=self.urls.patch.format(project_id=pid, form_id=fid, instance_id=iid), + url=self.session.urlformat( + self.urls.patch, project_id=pid, form_id=fid, instance_id=iid + ), logger=log, json=json, ) diff --git a/pyodk/_utils/session.py b/pyodk/_utils/session.py index 3b92610..f0d1f90 100644 --- a/pyodk/_utils/session.py +++ b/pyodk/_utils/session.py @@ -1,5 +1,7 @@ from logging import Logger -from urllib.parse import urljoin +from string import Formatter +from typing import Any +from urllib.parse import quote_plus, urljoin from requests import PreparedRequest, Response from requests import Session as RequestsSession @@ -12,6 +14,18 @@ from pyodk.errors import PyODKError +class URLFormatter(Formatter): + """ + Makes a valid URL by sending each format input field through urllib.parse.quote_plus. + """ + + def format_field(self, value: Any, format_spec: str) -> Any: + return format(quote_plus(str(value)), format_spec) + + +_URL_FORMATTER = URLFormatter() + + class Adapter(HTTPAdapter): def __init__(self, *args, **kwargs): if "timeout" in kwargs: @@ -100,6 +114,14 @@ def base_url_validate(base_url: str, api_version: str): def urljoin(self, url: str) -> str: return urljoin(self.base_url, url.lstrip("/")) + @staticmethod + def urlformat(url: str, *args, **kwargs) -> str: + return _URL_FORMATTER.format(url, *args, **kwargs) + + @staticmethod + def urlquote(url: str) -> str: + return _URL_FORMATTER.format_field(url, format_spec="") + def request(self, method, url, *args, **kwargs): return super().request(method, self.urljoin(url), *args, **kwargs) diff --git a/tests/endpoints/test_forms.py b/tests/endpoints/test_forms.py index cf0abaf..e145a87 100644 --- a/tests/endpoints/test_forms.py +++ b/tests/endpoints/test_forms.py @@ -3,10 +3,11 @@ from functools import wraps from typing import Callable from unittest import TestCase -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, mock_open, patch from pyodk._endpoints.form_draft_attachments import FormDraftAttachmentService from pyodk._endpoints.form_drafts import FormDraftService +from pyodk._endpoints.form_drafts import log as form_drafts_log from pyodk._endpoints.forms import Form, FormService from pyodk._utils.session import Session from pyodk.client import Client @@ -179,6 +180,45 @@ def test_update__def_and_attach__create_upload_publish(self, ctx: MockContext): form_id="foo", version=None, project_id=None ) + @staticmethod + def update__def_encoding_steps( + form_id: str, definition: str, expected_url: str, expected_fallback_id: str + ): + client = Client() + + def mock_wrap_error(**kwargs): + return kwargs["value"] + + with patch.object(Session, "response_or_error") as mock_response, patch( + "pyodk._utils.validators.wrap_error", mock_wrap_error + ), patch("builtins.open", mock_open(), create=True) as mock_open_patch: + client.forms.update(form_id, definition=definition) + mock_response.assert_any_call( + method="POST", + url=expected_url, + logger=form_drafts_log, + headers={ + "Content-Type": ( + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ), + "X-XlsForm-FormId-Fallback": expected_fallback_id, + }, + params={"ignoreWarnings": True}, + data=mock_open_patch.return_value, + ) + + def test_update__def_encoding(self): + """Should find that the URL and fallback header are url-encoded.""" + test_cases = ( + ("foo", "/some/path/foo.xlsx", "projects/1/forms/foo/draft", "foo"), + ("foo", "/some/path/✅.xlsx", "projects/1/forms/foo/draft", "%E2%9C%85"), + ("✅", "/some/path/✅.xlsx", "projects/1/forms/%E2%9C%85/draft", "%E2%9C%85"), + ("✅", "/some/path/foo.xlsx", "projects/1/forms/%E2%9C%85/draft", "foo"), + ) + for case in test_cases: + with self.subTest(msg=str(case)): + self.update__def_encoding_steps(*case) + def test_update__no_def_no_attach__raises(self): """Should raise an error if there is no definition or attachment.""" client = Client() diff --git a/tests/test_session.py b/tests/test_session.py index ff56555..bafe41f 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest import TestCase from pyodk._utils.session import Session @@ -18,3 +19,44 @@ def test_base_url_validate(self): with self.subTest(msg=f"{base_url}"): observed = Session.base_url_validate(base_url=base_url, api_version="v1") self.assertEqual(expected, observed) + + def test_urlformat(self): + """Should replace input fields with url-encoded values.""" + url = "projects/{project_id}/forms/{form_id}" + test_cases = ( + # Basic latin string + ({"project_id": 1, "form_id": "a"}, "projects/1/forms/a"), + # integer + ({"project_id": 1, "form_id": 1}, "projects/1/forms/1"), + # latin symbols + ({"project_id": 1, "form_id": "+-_*%*"}, "projects/1/forms/%2B-_%2A%25%2A"), + # lower case e, with combining acute accent (2 symbols) + ({"project_id": 1, "form_id": "tést"}, "projects/1/forms/te%CC%81st"), + # lower case e with acute (1 symbol) + ({"project_id": 1, "form_id": "tést"}, "projects/1/forms/t%C3%A9st"), + # white heavy check mark + ({"project_id": 1, "form_id": "✅"}, "projects/1/forms/%E2%9C%85"), + ) + for params, expected in test_cases: + with self.subTest(msg=str(params)): + self.assertEqual(expected, Session.urlformat(url, **params)) + + def test_urlquote(self): + """Should url-encode input values.""" + test_cases = ( + # Basic latin string + ("test.xlsx", "test"), + # integer + ("1.xls", "1"), + # latin symbols + ("+-_*%*.xls", "%2B-_%2A%25%2A"), + # lower case e, with combining acute accent (2 symbols) + ("tést.xlsx", "te%CC%81st"), + # lower case e with acute (1 symbol) + ("tést", "t%C3%A9st"), + # white heavy check mark + ("✅.xlsx", "%E2%9C%85"), + ) + for params, expected in test_cases: + with self.subTest(msg=str(params)): + self.assertEqual(expected, Session.urlquote(Path(params).stem))