-
-
Notifications
You must be signed in to change notification settings - Fork 140
/
middleware.py
118 lines (97 loc) · 3.64 KB
/
middleware.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from __future__ import annotations

import asyncio
import inspect
import json
from typing import Any
from typing import Awaitable
from typing import Callable
from urllib.parse import unquote
from urllib.parse import urlsplit
from urllib.parse import urlunsplit

from django.http import HttpRequest
from django.http.response import HttpResponseBase
from django.utils.functional import cached_property
class HtmxMiddleware:
    """Attach an ``HtmxDetails`` instance to every request as ``request.htmx``.

    The middleware supports both sync and async handler chains. The mode is
    decided once, in ``__init__``, from whether the downstream
    ``get_response`` callable is a coroutine function. Dispatch happens in
    ``__call__`` rather than by swapping out dunder methods.
    """

    sync_capable = True
    async_capable = True

    def __init__(
        self,
        get_response: (
            Callable[[HttpRequest], HttpResponseBase]
            | Callable[[HttpRequest], Awaitable[HttpResponseBase]]
        ),
    ) -> None:
        """Store *get_response* and pick sync or async mode to match it."""
        self.get_response = get_response
        # True when the downstream handler is async; drives dispatch in __call__.
        self.async_mode = self._is_async_callable(get_response)
        if self.async_mode:
            # Mark this instance as a coroutine function so callers that probe
            # it (e.g. when building the middleware chain) treat it as async.
            # Legacy marker recognised by asyncio.iscoroutinefunction; looked
            # up defensively because it is a private asyncio attribute.
            self._is_coroutine = getattr(asyncio.coroutines, "_is_coroutine", None)
            # Python 3.12+: inspect.iscoroutinefunction only recognises
            # inspect's own marker, not the asyncio one above.
            mark = getattr(inspect, "markcoroutinefunction", None)
            if mark is not None:
                mark(self)
        else:
            self._is_coroutine = None

    @staticmethod
    def _is_async_callable(func: object) -> bool:
        """Return True if *func* is a coroutine function or marked as one.

        Behaves like ``asyncio.iscoroutinefunction`` (deprecated in Python
        3.14) without triggering its DeprecationWarning. It checks the
        ``inspect`` code flag/marker first, then the legacy ``asyncio``
        marker when that still exists.
        """
        if inspect.iscoroutinefunction(func):
            return True
        legacy_marker = getattr(asyncio.coroutines, "_is_coroutine", None)
        return (
            legacy_marker is not None
            and getattr(func, "_is_coroutine", None) is legacy_marker
        )

    def __call__(
        self, request: HttpRequest
    ) -> HttpResponseBase | Awaitable[HttpResponseBase]:
        """Attach ``request.htmx`` and delegate to the next handler.

        In async mode, return the awaitable produced by ``__acall__``.
        """
        if self.async_mode:
            return self.__acall__(request)
        request.htmx = HtmxDetails(request)  # type: ignore [attr-defined]
        return self.get_response(request)

    async def __acall__(self, request: HttpRequest) -> HttpResponseBase:
        """Async counterpart of ``__call__``: attach ``request.htmx``, await the response."""
        request.htmx = HtmxDetails(request)  # type: ignore [attr-defined]
        result = self.get_response(request)
        assert not isinstance(result, HttpResponseBase)  # type narrow
        return await result
class HtmxDetails:
    """Typed accessors for the htmx-related headers of one request.

    Truthiness reports whether the ``HX-Request`` header equals ``"true"``.
    The other attributes read one header each and are cached after first
    access.
    """

    def __init__(self, request: HttpRequest) -> None:
        self.request = request

    def _get_header_value(self, name: str) -> str | None:
        """Return header *name*, or None when it is missing or empty.

        If the companion header ``<name>-URI-AutoEncoded`` equals ``"true"``,
        the value is percent-decoded with ``unquote`` before being returned.
        """
        value = self.request.headers.get(name) or None
        if value and self.request.headers.get(f"{name}-URI-AutoEncoded") == "true":
            value = unquote(value)
        return value

    def __bool__(self) -> bool:
        return self._get_header_value("HX-Request") == "true"

    @cached_property
    def boosted(self) -> bool:
        """True when the ``HX-Boosted`` header equals ``"true"``."""
        return self._get_header_value("HX-Boosted") == "true"

    @cached_property
    def current_url(self) -> str | None:
        """Value of the ``HX-Current-URL`` header, or None."""
        return self._get_header_value("HX-Current-URL")

    @cached_property
    def current_url_abs_path(self) -> str | None:
        """Return ``current_url`` reduced to path, query and fragment.

        Returns None when there is no current URL, or when its scheme or host
        differ from this request's. An empty path is normalised to ``"/"`` so
        the result is always an absolute path.
        """
        url = self.current_url
        if url is None:
            return None
        split = urlsplit(url)
        # `or` short-circuits, so get_host() is only called when the scheme
        # already matches.
        if (
            split.scheme != self.request.scheme
            or split.netloc != self.request.get_host()
        ):
            return None
        # "https://host" or "https://host?q=1" have an empty path; without
        # this the result would be "" or "?q=1", which is not an absolute path.
        return urlunsplit(
            split._replace(scheme="", netloc="", path=split.path or "/")
        )

    @cached_property
    def history_restore_request(self) -> bool:
        """True when the ``HX-History-Restore-Request`` header equals ``"true"``."""
        return self._get_header_value("HX-History-Restore-Request") == "true"

    @cached_property
    def prompt(self) -> str | None:
        """Value of the ``HX-Prompt`` header, or None."""
        return self._get_header_value("HX-Prompt")

    @cached_property
    def target(self) -> str | None:
        """Value of the ``HX-Target`` header, or None."""
        return self._get_header_value("HX-Target")

    @cached_property
    def trigger(self) -> str | None:
        """Value of the ``HX-Trigger`` header, or None."""
        return self._get_header_value("HX-Trigger")

    @cached_property
    def trigger_name(self) -> str | None:
        """Value of the ``HX-Trigger-Name`` header, or None."""
        return self._get_header_value("HX-Trigger-Name")

    @cached_property
    def triggering_event(self) -> Any:
        """JSON-decoded ``Triggering-Event`` header.

        Returns None when the header is absent or is not valid JSON.
        """
        value = self._get_header_value("Triggering-Event")
        if value is None:
            return None
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return None