Skip to content

Commit

Permalink
Style deserializing reimplementation
Browse files Browse the repository at this point in the history
  • Loading branch information
p1c2u committed Oct 17, 2023
1 parent df1f1e1 commit 45cd3f9
Show file tree
Hide file tree
Showing 12 changed files with 721 additions and 204 deletions.
22 changes: 21 additions & 1 deletion openapi_core/deserializing/styles/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
from openapi_core.deserializing.styles.datatypes import StyleDeserializersDict
from openapi_core.deserializing.styles.factories import (
StyleDeserializersFactory,
)
from openapi_core.deserializing.styles.util import deep_object_loads
from openapi_core.deserializing.styles.util import form_loads
from openapi_core.deserializing.styles.util import label_loads
from openapi_core.deserializing.styles.util import matrix_loads
from openapi_core.deserializing.styles.util import pipe_delimited_loads
from openapi_core.deserializing.styles.util import simple_loads
from openapi_core.deserializing.styles.util import space_delimited_loads

__all__ = ["style_deserializers_factory"]

style_deserializers_factory = StyleDeserializersFactory()
style_deserializers: StyleDeserializersDict = {
"matrix": matrix_loads,
"label": label_loads,
"form": form_loads,
"simple": simple_loads,
"spaceDelimited": space_delimited_loads,
"pipeDelimited": pipe_delimited_loads,
"deepObject": deep_object_loads,
}

style_deserializers_factory = StyleDeserializersFactory(
style_deserializers=style_deserializers,
)
6 changes: 5 additions & 1 deletion openapi_core/deserializing/styles/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping

DeserializerCallable = Callable[[str], List[str]]
DeserializerCallable = Callable[[bool, str, str, Mapping[str, Any]], Any]
StyleDeserializersDict = Dict[str, DeserializerCallable]
45 changes: 16 additions & 29 deletions openapi_core/deserializing/styles/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any
from typing import Callable
from typing import List
from typing import Mapping
from typing import Optional

from jsonschema_path import SchemaPath
Expand All @@ -11,46 +12,32 @@
from openapi_core.deserializing.styles.exceptions import (
EmptyQueryParameterValue,
)
from openapi_core.schema.parameters import get_aslist
from openapi_core.schema.parameters import get_explode


class CallableStyleDeserializer:
class StyleDeserializer:
def __init__(
self,
param_or_header: SchemaPath,
style: str,
explode: bool,
name: str,
schema_type: str,
deserializer_callable: Optional[DeserializerCallable] = None,
):
self.param_or_header = param_or_header
self.style = style
self.explode = explode
self.name = name
self.schema_type = schema_type
self.deserializer_callable = deserializer_callable

self.aslist = get_aslist(self.param_or_header)
self.explode = get_explode(self.param_or_header)

def deserialize(self, value: Any) -> Any:
def deserialize(self, location: Mapping[str, Any]) -> Any:
if self.deserializer_callable is None:
warnings.warn(f"Unsupported {self.style} style")
return value
return location[self.name]

# if "in" not defined then it's a Header
if "allowEmptyValue" in self.param_or_header:
warnings.warn(
"Use of allowEmptyValue property is deprecated",
DeprecationWarning,
# try:
if True:
return self.deserializer_callable(
self.explode, self.name, self.schema_type, location
)
allow_empty_values = self.param_or_header.getkey(
"allowEmptyValue", False
)
location_name = self.param_or_header.getkey("in", "header")
if location_name == "query" and value == "" and not allow_empty_values:
name = self.param_or_header["name"]
raise EmptyQueryParameterValue(name)

if not self.aslist or self.explode:
return value
try:
return self.deserializer_callable(value)
except (ValueError, TypeError, AttributeError):
raise DeserializeError(location_name, self.style, value)
# except (ValueError, TypeError, AttributeError):
# raise DeserializeError(self.style, self.name)
37 changes: 23 additions & 14 deletions openapi_core/deserializing/styles/factories.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,39 @@
import re
from functools import partial
from typing import Any
from typing import Dict
from typing import Mapping
from typing import Optional

from jsonschema_path import SchemaPath

from openapi_core.deserializing.styles.datatypes import DeserializerCallable
from openapi_core.deserializing.styles.deserializers import (
CallableStyleDeserializer,
)
from openapi_core.deserializing.styles.datatypes import StyleDeserializersDict
from openapi_core.deserializing.styles.deserializers import StyleDeserializer
from openapi_core.deserializing.styles.util import split
from openapi_core.schema.parameters import get_explode
from openapi_core.schema.parameters import get_style


class StyleDeserializersFactory:
STYLE_DESERIALIZERS: Dict[str, DeserializerCallable] = {
"form": partial(split, separator=","),
"simple": partial(split, separator=","),
"spaceDelimited": partial(split, separator=" "),
"pipeDelimited": partial(split, separator="|"),
"deepObject": partial(re.split, pattern=r"\[|\]"),
}
def __init__(
self,
style_deserializers: Optional[StyleDeserializersDict] = None,
):
if style_deserializers is None:
style_deserializers = {}
self.style_deserializers = style_deserializers

def create(self, param_or_header: SchemaPath) -> CallableStyleDeserializer:
def create(
self, param_or_header: SchemaPath, name: Optional[str] = None
) -> StyleDeserializer:
name = name or param_or_header["name"]
style = get_style(param_or_header)
explode = get_explode(param_or_header)
schema = param_or_header / "schema"
schema_type = schema.getkey("type", "")

deserialize_callable = self.STYLE_DESERIALIZERS.get(style)
return CallableStyleDeserializer(
param_or_header, style, deserialize_callable
deserialize_callable = self.style_deserializers.get(style)
return StyleDeserializer(
style, explode, name, schema_type, deserialize_callable
)
210 changes: 208 additions & 2 deletions openapi_core/deserializing/styles/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,211 @@
import re
from functools import partial
from typing import Any
from typing import List
from typing import Mapping
from typing import Optional

from openapi_core.schema.protocols import SuportsGetAll
from openapi_core.schema.protocols import SuportsGetList

def split(value: str, separator: str = ",") -> List[str]:
return value.split(separator)

def split(value: str, separator: str = ",", step: int = 1) -> List[str]:
parts = value.split(separator)

if step == 1:
return parts

result = []
for i in range(len(parts)):
if i % step == 0:
if i + 1 < len(parts):
result.append(parts[i] + separator + parts[i + 1])
return result


def delimited_loads(
explode: bool,
name: str,
schema_type: str,
location: Mapping[str, Any],
delimiter: str,
) -> Any:
value = location[name]

explode_type = (explode, schema_type)
if explode_type == (False, "array"):
return split(value, separator=delimiter)
if explode_type == (False, "object"):
return dict(
map(
partial(split, separator=delimiter),
split(value, separator=delimiter, step=2),
)
)

return value


def matrix_loads(
explode: bool, name: str, schema_type: str, location: Mapping[str, Any]
) -> Any:
if explode == False:
m = re.match(rf"^;{name}=(.*)$", location[f";{name}"])
if m is None:
raise KeyError(name)
value = m.group(1)
# .;color=blue
if schema_type == "string":
return value
# ;color=blue,black,brown
if schema_type == "array":
return split(value)
# ;color=R,100,G,200,B,150
if schema_type == "object":
return dict(map(split, split(value, step=2)))
else:
# ;color=blue
if schema_type == "string":
m = re.match(rf"^;{name}=(.*)$", location[f";{name}*"])
if m is None:
raise KeyError(name)
value = m.group(1)
return value
# ;color=blue;color=black;color=brown
if schema_type == "array":
return re.findall(rf";{name}=([^;]*)", location[f";{name}*"])
# ;R=100;G=200;B=150
if schema_type == "object":
value = location[f";{name}*"]
return dict(
map(
partial(split, separator="="),
split(value[1:], separator=";"),
)
)
return location[name]


def label_loads(
explode: bool, name: str, schema_type: str, location: Mapping[str, Any]
) -> Any:
if explode == False:
value = location[f".{name}"]
# .blue
if schema_type == "string":
return value[1:]
# .blue,black,brown
if schema_type == "array":
return split(value[1:])
# .R,100,G,200,B,150
if schema_type == "object":
return dict(map(split, split(value[1:], separator=",", step=2)))
else:
value = location[f".{name}*"]
# .blue
if schema_type == "string":
return value[1:]
# .blue.black.brown
if schema_type == "array":
return split(value[1:], separator=".")
# .R=100.G=200.B=150
if schema_type == "object":
return dict(
map(
partial(split, separator="="),
split(value[1:], separator="."),
)
)
return location[name]


def form_loads(
explode: bool, name: str, schema_type: str, location: Mapping[str, Any]
) -> Any:
if schema_type == "string":
return location[name]

explode_type = (explode, schema_type)
# color=blue,black,brown
if explode_type == (False, "array"):
return split(location[name], separator=",")
# color=blue&color=black&color=brown
elif explode_type == (True, "array"):
if name not in location:
raise KeyError(name)
if isinstance(location, SuportsGetAll):
return location.getall(name)
if isinstance(location, SuportsGetList):
return location.getlist(name)
return location[name]

value = location[name]
# color=R,100,G,200,B,150
if explode_type == (False, "object"):
return dict(map(split, split(value, separator=",", step=2)))
# R=100&G=200&B=150
elif explode_type == (True, "object"):
return dict(
map(partial(split, separator="="), split(value, separator="&"))
)

return value


def simple_loads(
explode: bool, name: str, schema_type: str, location: Mapping[str, Any]
) -> Any:
value = location[name]
if schema_type == "string":
return value

# blue,black,brown
if schema_type == "array":
return split(value, separator=",")

explode_type = (explode, schema_type)
# R,100,G,200,B,150
if explode_type == (False, "object"):
return dict(map(split, split(value, separator=",", step=2)))
# R=100,G=200,B=150
elif explode_type == (True, "object"):
return dict(
map(partial(split, separator="="), split(value, separator=","))
)

return value


def space_delimited_loads(
explode: bool, name: str, schema_type: str, location: Mapping[str, Any]
) -> Any:
return delimited_loads(
explode, name, schema_type, location, delimiter="%20"
)


def pipe_delimited_loads(
explode: bool, name: str, schema_type: str, location: Mapping[str, Any]
) -> Any:
return delimited_loads(explode, name, schema_type, location, delimiter="|")


def deep_object_loads(
explode: bool, name: str, schema_type: str, location: Mapping[str, Any]
) -> Any:
explode_type = (explode, schema_type)

if explode_type != (True, "object"):
return location[name]

keys_str = " ".join(location.keys())
if not re.search(rf"{name}\[\w+\]", keys_str):
raise KeyError(name)

values = {}
for key, value in location.items():
# Split the key from the brackets.
key_split = re.split(pattern=r"\[|\]", string=key)
if key_split[0] == name:
values[key_split[1]] = value
return values
Loading

0 comments on commit 45cd3f9

Please sign in to comment.