From c0cc71ac236256bbe7c834957883301d84c64b32 Mon Sep 17 00:00:00 2001
From: lindsay stevens <lindsay.stevens.au@gmail.com>
Date: Wed, 10 May 2023 20:50:36 +1000
Subject: [PATCH] fix: url-encode parameters URL paths and the fallback_id
 header

- per pyodk/#53, pyodk/#54
---
 pyodk/_endpoints/comments.py               |  8 +++--
 pyodk/_endpoints/form_assignments.py       |  4 +--
 pyodk/_endpoints/form_draft_attachments.py |  4 ++-
 pyodk/_endpoints/form_drafts.py            |  8 +++--
 pyodk/_endpoints/forms.py                  |  4 +--
 pyodk/_endpoints/project_app_users.py      |  4 +--
 pyodk/_endpoints/projects.py               |  2 +-
 pyodk/_endpoints/submissions.py            | 20 +++++++----
 pyodk/_utils/session.py                    | 24 ++++++++++++-
 tests/endpoints/test_forms.py              | 42 +++++++++++++++++++++-
 tests/test_session.py                      | 42 ++++++++++++++++++++++
 11 files changed, 141 insertions(+), 21 deletions(-)

diff --git a/pyodk/_endpoints/comments.py b/pyodk/_endpoints/comments.py
index a5cd02d..1bab2ba 100644
--- a/pyodk/_endpoints/comments.py
+++ b/pyodk/_endpoints/comments.py
@@ -70,7 +70,9 @@ def list(
 
         response = self.session.response_or_error(
             method="GET",
-            url=self.urls.list.format(project_id=pid, form_id=fid, instance_id=iid),
+            url=self.session.urlformat(
+                self.urls.list, project_id=pid, form_id=fid, instance_id=iid
+            ),
             logger=log,
         )
         data = response.json()
@@ -105,7 +107,9 @@ def post(
 
         response = self.session.response_or_error(
             method="POST",
-            url=self.urls.post.format(project_id=pid, form_id=fid, instance_id=iid),
+            url=self.session.urlformat(
+                self.urls.post, project_id=pid, form_id=fid, instance_id=iid
+            ),
             logger=log,
             json=json,
         )
diff --git a/pyodk/_endpoints/form_assignments.py b/pyodk/_endpoints/form_assignments.py
index 8376091..4968b82 100644
--- a/pyodk/_endpoints/form_assignments.py
+++ b/pyodk/_endpoints/form_assignments.py
@@ -58,8 +58,8 @@ def assign(
 
         response = self.session.response_or_error(
             method="POST",
-            url=self.urls.post.format(
-                project_id=pid, form_id=fid, role_id=rid, user_id=uid
+            url=self.session.urlformat(
+                self.urls.post, project_id=pid, form_id=fid, role_id=rid, user_id=uid
             ),
             logger=log,
         )
diff --git a/pyodk/_endpoints/form_draft_attachments.py b/pyodk/_endpoints/form_draft_attachments.py
index b376082..0fc4bb6 100644
--- a/pyodk/_endpoints/form_draft_attachments.py
+++ b/pyodk/_endpoints/form_draft_attachments.py
@@ -61,7 +61,9 @@ def upload(
         with open(file_path, "rb") as fd:
             response = self.session.response_or_error(
                 method="POST",
-                url=self.urls.post.format(project_id=pid, form_id=fid, fname=file_name),
+                url=self.session.urlformat(
+                    self.urls.post, project_id=pid, form_id=fid, fname=file_name
+                ),
                 logger=log,
                 data=fd,
             )
diff --git a/pyodk/_endpoints/form_drafts.py b/pyodk/_endpoints/form_drafts.py
index 0c71da9..ad504a5 100644
--- a/pyodk/_endpoints/form_drafts.py
+++ b/pyodk/_endpoints/form_drafts.py
@@ -76,7 +76,7 @@ def create(
                     )
                 headers = {
                     "Content-Type": content_type,
-                    "X-XlsForm-FormId-Fallback": file_path.stem,
+                    "X-XlsForm-FormId-Fallback": self.session.urlquote(file_path.stem),
                 }
         except PyODKError as err:
             log.error(err, exc_info=True)
@@ -85,7 +85,7 @@ def create(
         with open(file_path, "rb") if file_path is not None else nullcontext() as fd:
             response = self.session.response_or_error(
                 method="POST",
-                url=self.urls.post.format(project_id=pid, form_id=fid),
+                url=self.session.urlformat(self.urls.post, project_id=pid, form_id=fid),
                 logger=log,
                 headers=headers,
                 params=params,
@@ -121,7 +121,9 @@ def publish(
 
         response = self.session.response_or_error(
             method="POST",
-            url=self.urls.post_publish.format(project_id=pid, form_id=fid),
+            url=self.session.urlformat(
+                self.urls.post_publish, project_id=pid, form_id=fid
+            ),
             logger=log,
             params=params,
         )
diff --git a/pyodk/_endpoints/forms.py b/pyodk/_endpoints/forms.py
index bf57f31..87dcb22 100644
--- a/pyodk/_endpoints/forms.py
+++ b/pyodk/_endpoints/forms.py
@@ -85,7 +85,7 @@ def list(self, project_id: Optional[int] = None) -> List[Form]:
         else:
             response = self.session.response_or_error(
                 method="GET",
-                url=self.urls.list.format(project_id=pid),
+                url=self.session.urlformat(self.urls.list, project_id=pid),
                 logger=log,
             )
             data = response.json()
@@ -113,7 +113,7 @@ def get(
         else:
             response = self.session.response_or_error(
                 method="GET",
-                url=self.urls.get.format(project_id=pid, form_id=fid),
+                url=self.session.urlformat(self.urls.get, project_id=pid, form_id=fid),
                 logger=log,
             )
             data = response.json()
diff --git a/pyodk/_endpoints/project_app_users.py b/pyodk/_endpoints/project_app_users.py
index 7ba17ee..5de7d6d 100644
--- a/pyodk/_endpoints/project_app_users.py
+++ b/pyodk/_endpoints/project_app_users.py
@@ -63,7 +63,7 @@ def list(
 
         response = self.session.response_or_error(
             method="GET",
-            url=self.urls.list.format(project_id=pid),
+            url=self.session.urlformat(self.urls.list, project_id=pid),
             logger=log,
         )
         data = response.json()
@@ -92,7 +92,7 @@ def create(
 
         response = self.session.response_or_error(
             method="POST",
-            url=self.urls.post.format(project_id=pid),
+            url=self.session.urlformat(self.urls.post, project_id=pid),
             logger=log,
             json=json,
         )
diff --git a/pyodk/_endpoints/projects.py b/pyodk/_endpoints/projects.py
index 901de2d..159e3fb 100644
--- a/pyodk/_endpoints/projects.py
+++ b/pyodk/_endpoints/projects.py
@@ -95,7 +95,7 @@ def get(self, project_id: Optional[int] = None) -> Project:
         else:
             response = self.session.response_or_error(
                 method="GET",
-                url=self.urls.get.format(project_id=pid),
+                url=self.session.urlformat(self.urls.get, project_id=pid),
                 logger=log,
             )
             data = response.json()
diff --git a/pyodk/_endpoints/submissions.py b/pyodk/_endpoints/submissions.py
index 41ff07e..26f1489 100644
--- a/pyodk/_endpoints/submissions.py
+++ b/pyodk/_endpoints/submissions.py
@@ -88,7 +88,7 @@ def list(
 
         response = self.session.response_or_error(
             method="GET",
-            url=self.urls.list.format(project_id=pid, form_id=fid),
+            url=self.session.urlformat(self.urls.list, project_id=pid, form_id=fid),
             logger=log,
         )
         data = response.json()
@@ -119,7 +119,9 @@ def get(
 
         response = self.session.response_or_error(
             method="GET",
-            url=self.urls.get.format(project_id=pid, form_id=fid, instance_id=iid),
+            url=self.session.urlformat(
+                self.urls.get, project_id=pid, form_id=fid, instance_id=iid
+            ),
             logger=log,
         )
         data = response.json()
@@ -180,7 +182,9 @@ def get_table(
 
         response = self.session.response_or_error(
             method="GET",
-            url=self.urls.get_table.format(project_id=pid, form_id=fid, table_name=table),
+            url=self.session.urlformat(
+                self.urls.get_table, project_id=pid, form_id=fid, table_name=table
+            ),
             logger=log,
             params=params,
         )
@@ -225,7 +229,7 @@ def create(
 
         response = self.session.response_or_error(
             method="POST",
-            url=self.urls.post.format(project_id=pid, form_id=fid),
+            url=self.session.urlformat(self.urls.post, project_id=pid, form_id=fid),
             logger=log,
             headers={"Content-Type": "application/xml"},
             params=params,
@@ -272,7 +276,9 @@ def _put(
 
         response = self.session.response_or_error(
             method="PUT",
-            url=self.urls.put.format(project_id=pid, form_id=fid, instance_id=iid),
+            url=self.session.urlformat(
+                self.urls.put, project_id=pid, form_id=fid, instance_id=iid
+            ),
             logger=log,
             headers={"Content-Type": "application/xml"},
             data=xml,
@@ -308,7 +314,9 @@ def _patch(
 
         response = self.session.response_or_error(
             method="PATCH",
-            url=self.urls.patch.format(project_id=pid, form_id=fid, instance_id=iid),
+            url=self.session.urlformat(
+                self.urls.patch, project_id=pid, form_id=fid, instance_id=iid
+            ),
             logger=log,
             json=json,
         )
diff --git a/pyodk/_utils/session.py b/pyodk/_utils/session.py
index 3b92610..f0d1f90 100644
--- a/pyodk/_utils/session.py
+++ b/pyodk/_utils/session.py
@@ -1,5 +1,7 @@
 from logging import Logger
-from urllib.parse import urljoin
+from string import Formatter
+from typing import Any
+from urllib.parse import quote_plus, urljoin
 
 from requests import PreparedRequest, Response
 from requests import Session as RequestsSession
@@ -12,6 +14,18 @@
 from pyodk.errors import PyODKError
 
 
+class URLFormatter(Formatter):
+    """
+    Makes a valid URL by sending each format input field through urllib.parse.quote_plus.
+    """
+
+    def format_field(self, value: Any, format_spec: str) -> Any:
+        return format(quote_plus(str(value)), format_spec)
+
+
+_URL_FORMATTER = URLFormatter()
+
+
 class Adapter(HTTPAdapter):
     def __init__(self, *args, **kwargs):
         if "timeout" in kwargs:
@@ -100,6 +114,14 @@ def base_url_validate(base_url: str, api_version: str):
     def urljoin(self, url: str) -> str:
         return urljoin(self.base_url, url.lstrip("/"))
 
+    @staticmethod
+    def urlformat(url: str, *args, **kwargs) -> str:
+        return _URL_FORMATTER.format(url, *args, **kwargs)
+
+    @staticmethod
+    def urlquote(url: str) -> str:
+        return _URL_FORMATTER.format_field(url, format_spec="")
+
     def request(self, method, url, *args, **kwargs):
         return super().request(method, self.urljoin(url), *args, **kwargs)
 
diff --git a/tests/endpoints/test_forms.py b/tests/endpoints/test_forms.py
index cf0abaf..e145a87 100644
--- a/tests/endpoints/test_forms.py
+++ b/tests/endpoints/test_forms.py
@@ -3,10 +3,11 @@
 from functools import wraps
 from typing import Callable
 from unittest import TestCase
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, mock_open, patch
 
 from pyodk._endpoints.form_draft_attachments import FormDraftAttachmentService
 from pyodk._endpoints.form_drafts import FormDraftService
+from pyodk._endpoints.form_drafts import log as form_drafts_log
 from pyodk._endpoints.forms import Form, FormService
 from pyodk._utils.session import Session
 from pyodk.client import Client
@@ -179,6 +180,45 @@ def test_update__def_and_attach__create_upload_publish(self, ctx: MockContext):
             form_id="foo", version=None, project_id=None
         )
 
+    @staticmethod
+    def update__def_encoding_steps(
+        form_id: str, definition: str, expected_url: str, expected_fallback_id: str
+    ):
+        client = Client()
+
+        def mock_wrap_error(**kwargs):
+            return kwargs["value"]
+
+        with patch.object(Session, "response_or_error") as mock_response, patch(
+            "pyodk._utils.validators.wrap_error", mock_wrap_error
+        ), patch("builtins.open", mock_open(), create=True) as mock_open_patch:
+            client.forms.update(form_id, definition=definition)
+        mock_response.assert_any_call(
+            method="POST",
+            url=expected_url,
+            logger=form_drafts_log,
+            headers={
+                "Content-Type": (
+                    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+                ),
+                "X-XlsForm-FormId-Fallback": expected_fallback_id,
+            },
+            params={"ignoreWarnings": True},
+            data=mock_open_patch.return_value,
+        )
+
+    def test_update__def_encoding(self):
+        """Should find that the URL and fallback header are url-encoded."""
+        test_cases = (
+            ("foo", "/some/path/foo.xlsx", "projects/1/forms/foo/draft", "foo"),
+            ("foo", "/some/path/✅.xlsx", "projects/1/forms/foo/draft", "%E2%9C%85"),
+            ("✅", "/some/path/✅.xlsx", "projects/1/forms/%E2%9C%85/draft", "%E2%9C%85"),
+            ("✅", "/some/path/foo.xlsx", "projects/1/forms/%E2%9C%85/draft", "foo"),
+        )
+        for case in test_cases:
+            with self.subTest(msg=str(case)):
+                self.update__def_encoding_steps(*case)
+
     def test_update__no_def_no_attach__raises(self):
         """Should raise an error if there is no definition or attachment."""
         client = Client()
diff --git a/tests/test_session.py b/tests/test_session.py
index ff56555..bafe41f 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 from unittest import TestCase
 
 from pyodk._utils.session import Session
@@ -18,3 +19,44 @@ def test_base_url_validate(self):
             with self.subTest(msg=f"{base_url}"):
                 observed = Session.base_url_validate(base_url=base_url, api_version="v1")
                 self.assertEqual(expected, observed)
+
+    def test_urlformat(self):
+        """Should replace input fields with url-encoded values."""
+        url = "projects/{project_id}/forms/{form_id}"
+        test_cases = (
+            # Basic latin string
+            ({"project_id": 1, "form_id": "a"}, "projects/1/forms/a"),
+            # integer
+            ({"project_id": 1, "form_id": 1}, "projects/1/forms/1"),
+            # latin symbols
+            ({"project_id": 1, "form_id": "+-_*%*"}, "projects/1/forms/%2B-_%2A%25%2A"),
+            # lower case e, with combining acute accent (2 symbols)
+            ({"project_id": 1, "form_id": "tést"}, "projects/1/forms/te%CC%81st"),
+            # lower case e with acute (1 symbol)
+            ({"project_id": 1, "form_id": "tést"}, "projects/1/forms/t%C3%A9st"),
+            # white heavy check mark
+            ({"project_id": 1, "form_id": "✅"}, "projects/1/forms/%E2%9C%85"),
+        )
+        for params, expected in test_cases:
+            with self.subTest(msg=str(params)):
+                self.assertEqual(expected, Session.urlformat(url, **params))
+
+    def test_urlquote(self):
+        """Should url-encode input values."""
+        test_cases = (
+            # Basic latin string
+            ("test.xlsx", "test"),
+            # integer
+            ("1.xls", "1"),
+            # latin symbols
+            ("+-_*%*.xls", "%2B-_%2A%25%2A"),
+            # lower case e, with combining acute accent (2 symbols)
+            ("tést.xlsx", "te%CC%81st"),
+            # lower case e with acute (1 symbol)
+            ("tést", "t%C3%A9st"),
+            # white heavy check mark
+            ("✅.xlsx", "%E2%9C%85"),
+        )
+        for params, expected in test_cases:
+            with self.subTest(msg=str(params)):
+                self.assertEqual(expected, Session.urlquote(Path(params).stem))