diff --git a/openapi_core/deserializing/styles/__init__.py b/openapi_core/deserializing/styles/__init__.py index b5191289..f9ecef06 100644 --- a/openapi_core/deserializing/styles/__init__.py +++ b/openapi_core/deserializing/styles/__init__.py @@ -1,7 +1,27 @@ +from openapi_core.deserializing.styles.datatypes import StyleDeserializersDict from openapi_core.deserializing.styles.factories import ( StyleDeserializersFactory, ) +from openapi_core.deserializing.styles.util import deep_object_loads +from openapi_core.deserializing.styles.util import form_loads +from openapi_core.deserializing.styles.util import label_loads +from openapi_core.deserializing.styles.util import matrix_loads +from openapi_core.deserializing.styles.util import pipe_delimited_loads +from openapi_core.deserializing.styles.util import simple_loads +from openapi_core.deserializing.styles.util import space_delimited_loads __all__ = ["style_deserializers_factory"] -style_deserializers_factory = StyleDeserializersFactory() +style_deserializers: StyleDeserializersDict = { + "matrix": matrix_loads, + "label": label_loads, + "form": form_loads, + "simple": simple_loads, + "spaceDelimited": space_delimited_loads, + "pipeDelimited": pipe_delimited_loads, + "deepObject": deep_object_loads, +} + +style_deserializers_factory = StyleDeserializersFactory( + style_deserializers=style_deserializers, +) diff --git a/openapi_core/deserializing/styles/datatypes.py b/openapi_core/deserializing/styles/datatypes.py index f2a47c29..6e0b99f7 100644 --- a/openapi_core/deserializing/styles/datatypes.py +++ b/openapi_core/deserializing/styles/datatypes.py @@ -1,4 +1,8 @@ +from typing import Any from typing import Callable +from typing import Dict from typing import List +from typing import Mapping -DeserializerCallable = Callable[[str], List[str]] +DeserializerCallable = Callable[[bool, str, str, Mapping[str, Any]], Any] +StyleDeserializersDict = Dict[str, DeserializerCallable] diff --git a/openapi_core/deserializing/styles/deserializers.py b/openapi_core/deserializing/styles/deserializers.py index b29078a1..b6dbfd93 100644 --- a/openapi_core/deserializing/styles/deserializers.py +++ b/openapi_core/deserializing/styles/deserializers.py @@ -2,6 +2,7 @@ from typing import Any from typing import Callable from typing import List +from typing import Mapping from typing import Optional from jsonschema_path import SchemaPath @@ -11,46 +12,31 @@ from openapi_core.deserializing.styles.exceptions import ( EmptyQueryParameterValue, ) -from openapi_core.schema.parameters import get_aslist -from openapi_core.schema.parameters import get_explode -class CallableStyleDeserializer: +class StyleDeserializer: def __init__( self, - param_or_header: SchemaPath, style: str, + explode: bool, + name: str, + schema_type: str, deserializer_callable: Optional[DeserializerCallable] = None, ): - self.param_or_header = param_or_header self.style = style + self.explode = explode + self.name = name + self.schema_type = schema_type self.deserializer_callable = deserializer_callable - self.aslist = get_aslist(self.param_or_header) - self.explode = get_explode(self.param_or_header) - - def deserialize(self, value: Any) -> Any: + def deserialize(self, location: Mapping[str, Any]) -> Any: if self.deserializer_callable is None: warnings.warn(f"Unsupported {self.style} style") - return value - - # if "in" not defined then it's a Header - if "allowEmptyValue" in self.param_or_header: - warnings.warn( - "Use of allowEmptyValue property is deprecated", - DeprecationWarning, - ) - allow_empty_values = self.param_or_header.getkey( - "allowEmptyValue", False - ) - location_name = self.param_or_header.getkey("in", "header") - if location_name == "query" and value == "" and not allow_empty_values: - name = self.param_or_header["name"] - raise EmptyQueryParameterValue(name) + return location[self.name] - if not self.aslist or self.explode: - return value try: - return self.deserializer_callable(value) + return self.deserializer_callable( + self.explode, self.name, self.schema_type, location + ) except (ValueError, TypeError, AttributeError): - raise DeserializeError(location_name, self.style, value) + raise DeserializeError(self.style, self.name) diff --git a/openapi_core/deserializing/styles/factories.py b/openapi_core/deserializing/styles/factories.py index 578316bf..26a5f61e 100644 --- a/openapi_core/deserializing/styles/factories.py +++ b/openapi_core/deserializing/styles/factories.py @@ -1,30 +1,39 @@ import re from functools import partial +from typing import Any from typing import Dict +from typing import Mapping +from typing import Optional from jsonschema_path import SchemaPath from openapi_core.deserializing.styles.datatypes import DeserializerCallable -from openapi_core.deserializing.styles.deserializers import ( - CallableStyleDeserializer, -) +from openapi_core.deserializing.styles.datatypes import StyleDeserializersDict +from openapi_core.deserializing.styles.deserializers import StyleDeserializer from openapi_core.deserializing.styles.util import split +from openapi_core.schema.parameters import get_explode from openapi_core.schema.parameters import get_style class StyleDeserializersFactory: - STYLE_DESERIALIZERS: Dict[str, DeserializerCallable] = { - "form": partial(split, separator=","), - "simple": partial(split, separator=","), - "spaceDelimited": partial(split, separator=" "), - "pipeDelimited": partial(split, separator="|"), - "deepObject": partial(re.split, pattern=r"\[|\]"), - } + def __init__( + self, + style_deserializers: Optional[StyleDeserializersDict] = None, + ): + if style_deserializers is None: + style_deserializers = {} + self.style_deserializers = style_deserializers - def create(self, param_or_header: SchemaPath) -> CallableStyleDeserializer: + def create( + self, param_or_header: SchemaPath, name: Optional[str] = None + ) -> StyleDeserializer: + name = name or param_or_header["name"] style = get_style(param_or_header) + explode = get_explode(param_or_header) + schema = param_or_header / "schema" + schema_type = schema.getkey("type", "") - deserialize_callable = self.STYLE_DESERIALIZERS.get(style) - return CallableStyleDeserializer( - param_or_header, style, deserialize_callable + deserialize_callable = self.style_deserializers.get(style) + return StyleDeserializer( + style, explode, name, schema_type, deserialize_callable ) diff --git a/openapi_core/deserializing/styles/util.py b/openapi_core/deserializing/styles/util.py index 1f484f21..e04728a9 100644 --- a/openapi_core/deserializing/styles/util.py +++ b/openapi_core/deserializing/styles/util.py @@ -1,5 +1,202 @@ +import re +from functools import partial +from typing import Any from typing import List +from typing import Mapping +from typing import Optional +from openapi_core.schema.protocols import SuportsGetAll +from openapi_core.schema.protocols import SuportsGetList -def split(value: str, separator: str = ",") -> List[str]: - return value.split(separator) + +def split(value: str, separator: str = ",", step: int = 1) -> List[str]: + parts = value.split(separator) + + if step == 1: + return parts + + result = [] + for i in range(len(parts)): + if i % step == 0: + if i + 1 < len(parts): + result.append(parts[i] + separator + parts[i + 1]) + return result + + +def delimited_loads( + explode: bool, + name: str, + schema_type: str, + location: Mapping[str, Any], + delimiter: str, +) -> Any: + value = location[name] + + explode_type = (explode, schema_type) + if explode_type == (False, "array"): + return split(value, separator=delimiter) + if explode_type == (False, "object"): + return dict( + map( + partial(split, separator=delimiter), + split(value, separator=delimiter, step=2), + ) + ) + + raise ValueError("not available") + + +def matrix_loads( + explode: bool, name: str, schema_type: str, location: Mapping[str, Any] +) -> Any: + if explode == False: + m = re.match(rf"^;{name}=(.*)$", location[f";{name}"]) + if m is None: + raise KeyError(name) + value = m.group(1) + # ;color=blue,black,brown + if schema_type == "array": + return split(value) + # ;color=R,100,G,200,B,150 + if schema_type == "object": + return dict(map(split, split(value, step=2))) + # .;color=blue + return value + else: + # ;color=blue;color=black;color=brown + if schema_type == "array": + return re.findall(rf";{name}=([^;]*)", location[f";{name}*"]) + # ;R=100;G=200;B=150 + if schema_type == "object": + value = location[f";{name}*"] + return dict( + map( + partial(split, separator="="), + split(value[1:], separator=";"), + ) + ) + # ;color=blue + m = re.match(rf"^;{name}=(.*)$", location[f";{name}*"]) + if m is None: + raise KeyError(name) + value = m.group(1) + return value + + +def label_loads( + explode: bool, name: str, schema_type: str, location: Mapping[str, Any] +) -> Any: + if explode == False: + value = location[f".{name}"] + # .blue,black,brown + if schema_type == "array": + return split(value[1:]) + # .R,100,G,200,B,150 + if schema_type == "object": + return dict(map(split, split(value[1:], separator=",", step=2))) + # .blue + return value[1:] + else: + value = location[f".{name}*"] + # .blue.black.brown + if schema_type == "array": + return split(value[1:], separator=".") + # .R=100.G=200.B=150 + if schema_type == "object": + return dict( + map( + partial(split, separator="="), + split(value[1:], separator="."), + ) + ) + # .blue + return value[1:] + + +def form_loads( + explode: bool, name: str, schema_type: str, location: Mapping[str, Any] +) -> Any: + explode_type = (explode, schema_type) + # color=blue,black,brown + if explode_type == (False, "array"): + return split(location[name], separator=",") + # color=blue&color=black&color=brown + elif explode_type == (True, "array"): + if name not in location: + raise KeyError(name) + if isinstance(location, SuportsGetAll): + return location.getall(name) + if isinstance(location, SuportsGetList): + return location.getlist(name) + return location[name] + + value = location[name] + # color=R,100,G,200,B,150 + if explode_type == (False, "object"): + return dict(map(split, split(value, separator=",", step=2))) + # R=100&G=200&B=150 + elif explode_type == (True, "object"): + return dict( + map(partial(split, separator="="), split(value, separator="&")) + ) + + # color=blue + return value + + +def simple_loads( + explode: bool, name: str, schema_type: str, location: Mapping[str, Any] +) -> Any: + value = location[name] + + # blue,black,brown + if schema_type == "array": + return split(value, separator=",") + + explode_type = (explode, schema_type) + # R,100,G,200,B,150 + if explode_type == (False, "object"): + return dict(map(split, split(value, separator=",", step=2))) + # R=100,G=200,B=150 + elif explode_type == (True, "object"): + return dict( + map(partial(split, separator="="), split(value, separator=",")) + ) + + # blue + return value + + +def space_delimited_loads( + explode: bool, name: str, schema_type: str, location: Mapping[str, Any] +) -> Any: + return delimited_loads( + explode, name, schema_type, location, delimiter="%20" + ) + + +def pipe_delimited_loads( + explode: bool, name: str, schema_type: str, location: Mapping[str, Any] +) -> Any: + return delimited_loads(explode, name, schema_type, location, delimiter="|") + + +def deep_object_loads( + explode: bool, name: str, schema_type: str, location: Mapping[str, Any] +) -> Any: + explode_type = (explode, schema_type) + + if explode_type != (True, "object"): + raise ValueError("not available") + + keys_str = " ".join(location.keys()) + if not re.search(rf"{name}\[\w+\]", keys_str): + raise KeyError(name) + + values = {} + for key, value in location.items(): + # Split the key from the brackets. + key_split = re.split(pattern=r"\[|\]", string=key) + if key_split[0] == name: + values[key_split[1]] = value + return values diff --git a/openapi_core/schema/parameters.py b/openapi_core/schema/parameters.py index ec69cdf2..da1a5f16 100644 --- a/openapi_core/schema/parameters.py +++ b/openapi_core/schema/parameters.py @@ -1,4 +1,3 @@ -import re from typing import Any from typing import Dict from typing import Mapping @@ -10,18 +9,6 @@ from openapi_core.schema.protocols import SuportsGetList -def get_aslist(param_or_header: SchemaPath) -> bool: - """Checks if parameter/header is described as list for simpler scenarios""" - # if schema is not defined it's a complex scenario - if "schema" not in param_or_header: - return False - - schema = param_or_header / "schema" - schema_type = schema.getkey("type", "any") - # TODO: resolve for 'any' schema type - return schema_type in ["array", "object"] - - def get_style(param_or_header: SchemaPath) -> str: """Checks parameter/header style for simpler scenarios""" if "style" in param_or_header: @@ -44,16 +31,3 @@ def get_explode(param_or_header: SchemaPath) -> bool: # determine default style = get_style(param_or_header) return style == "form" - - -def get_deep_object_value( - location: Mapping[str, Any], - name: Optional[str] = None, -) -> Dict[str, Any]: - values = {} - for key, value in location.items(): - # Split the key from the brackets. - key_split = re.split(pattern=r"\[|\]", string=key) - if key_split[0] == name: - values[key_split[1]] = value - return values diff --git a/openapi_core/unmarshalling/unmarshallers.py b/openapi_core/unmarshalling/unmarshallers.py index 42c5a6b6..858b36a2 100644 --- a/openapi_core/unmarshalling/unmarshallers.py +++ b/openapi_core/unmarshalling/unmarshallers.py @@ -91,24 +91,25 @@ def _unmarshal_schema(self, schema: SchemaPath, value: Any) -> Any: ) return unmarshaller.unmarshal(value) - def _convert_schema_style_value( + def _get_param_or_header_and_schema( self, - raw: Any, param_or_header: SchemaPath, - ) -> Any: - casted, schema = self._convert_schema_style_value_and_schema( - raw, param_or_header + location: Mapping[str, Any], + name: Optional[str] = None, + ) -> Tuple[Any, Optional[SchemaPath]]: + casted, schema = super()._get_param_or_header_and_schema( + param_or_header, location, name=name ) if schema is None: - return casted - return self._unmarshal_schema(schema, casted) + return casted, None + return self._unmarshal_schema(schema, casted), schema - def _convert_content_schema_value( + def _get_content_and_schema( self, raw: Any, content: SchemaPath, mimetype: Optional[str] = None - ) -> Any: - casted, schema = self._convert_content_schema_value_and_schema( + ) -> Tuple[Any, Optional[SchemaPath]]: + casted, schema = super()._get_content_and_schema( raw, content, mimetype ) if schema is None: - return casted - return self._unmarshal_schema(schema, casted) + return casted, None + return self._unmarshal_schema(schema, casted), schema diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 61a76149..19a59228 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -197,12 +197,14 @@ def _get_parameter( location = parameters[param_location] try: - return self._get_param_or_header(param, location, name=name) + value, _ = self._get_param_or_header_and_schema(param, location) except KeyError: required = param.getkey("required", False) if required: raise MissingRequiredParameter(name, param_location) raise MissingParameter(name, param_location) + else: + return value @ValidationErrorWrapper(SecurityValidationError, InvalidSecurity) def _get_security( @@ -255,7 +257,8 @@ def _get_body( content = request_body / "content" raw_body = self._get_body_value(body, request_body) - return self._convert_content_schema_value(raw_body, content, mimetype) + value, _ = self._get_content_and_schema(raw_body, content, mimetype) + return value def _get_body_value( self, body: Optional[str], request_body: SchemaPath diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index b5ff7088..078dd483 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -120,7 +120,8 @@ def _get_data( content = operation_response / "content" raw_data = self._get_data_value(data) - return self._convert_content_schema_value(raw_data, content, mimetype) + value, _ = self._get_content_and_schema(raw_data, content, mimetype) + return value def _get_data_value(self, data: str) -> Any: if not data: @@ -169,12 +170,16 @@ def _get_header( ) try: - return self._get_param_or_header(header, headers, name=name) + value, _ = self._get_param_or_header_and_schema( + header, headers, name=name + ) except KeyError: required = header.getkey("required", False) if required: raise MissingRequiredHeader(name) raise MissingHeader(name) + else: + return value class BaseAPICallResponseValidator( diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index a6e549cf..f1a34a63 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -1,5 +1,6 @@ """OpenAPI core validation validators module""" import re +import warnings from functools import cached_property from typing import Any from typing import Mapping @@ -22,17 +23,14 @@ MediaTypeDeserializersFactory, ) from openapi_core.deserializing.styles import style_deserializers_factory +from openapi_core.deserializing.styles.exceptions import ( + EmptyQueryParameterValue, +) from openapi_core.deserializing.styles.factories import ( StyleDeserializersFactory, ) from openapi_core.protocols import Request from openapi_core.protocols import WebhookRequest -from openapi_core.schema.parameters import get_aslist -from openapi_core.schema.parameters import get_deep_object_value -from openapi_core.schema.parameters import get_explode -from openapi_core.schema.parameters import get_style -from openapi_core.schema.protocols import SuportsGetAll -from openapi_core.schema.protocols import SuportsGetList from openapi_core.templating.media_types.datatypes import MediaType from openapi_core.templating.paths.datatypes import PathOperationServer from openapi_core.templating.paths.finders import APICallPathFinder @@ -109,10 +107,15 @@ def _deserialise_media_type( return deserializer.deserialize(value) def _deserialise_style( - self, param_or_header: SchemaPath, value: Any + self, + param_or_header: SchemaPath, + location: Mapping[str, Any], + name: Optional[str] = None, ) -> Any: - deserializer = self.style_deserializers_factory.create(param_or_header) - return deserializer.deserialize(value) + deserializer = self.style_deserializers_factory.create( + param_or_header, name=name + ) + return deserializer.deserialize(location) def _cast(self, schema: SchemaPath, value: Any) -> Any: caster = self.schema_casters_factory.create(schema) @@ -126,86 +129,80 @@ def _validate_schema(self, schema: SchemaPath, value: Any) -> None: ) validator.validate(value) - def _get_param_or_header( + def _get_param_or_header_and_schema( self, param_or_header: SchemaPath, location: Mapping[str, Any], name: Optional[str] = None, - ) -> Any: + ) -> Tuple[Any, Optional[SchemaPath]]: + schema: Optional[SchemaPath] = None # Simple scenario if "content" not in param_or_header: - return self._get_simple_param_or_header( + casted, schema = self._get_simple_param_or_header( param_or_header, location, name=name ) - # Complex scenario - return self._get_complex_param_or_header( - param_or_header, location, name=name - ) + else: + casted, schema = self._get_complex_param_or_header( + param_or_header, location, name=name + ) + + if schema is None: + return casted, None + self._validate_schema(schema, casted) + return casted, schema def _get_simple_param_or_header( self, param_or_header: SchemaPath, location: Mapping[str, Any], name: Optional[str] = None, - ) -> Any: + ) -> Tuple[Any, SchemaPath]: + allow_empty_values = param_or_header.getkey("allowEmptyValue") + if allow_empty_values: + warnings.warn( + "Use of allowEmptyValue property is deprecated", + DeprecationWarning, + ) + # in simple scenrios schema always exist + schema = param_or_header / "schema" try: - raw = self._get_style_value(param_or_header, location, name=name) + deserialised = self._deserialise_style( + param_or_header, location, name=name + ) except KeyError: - # in simple scenrios schema always exist - schema = param_or_header / "schema" if "default" not in schema: raise - raw = schema["default"] - return self._convert_schema_style_value(raw, param_or_header) + return schema["default"], schema + if allow_empty_values is not None: + warnings.warn( + "Use of allowEmptyValue property is deprecated", + DeprecationWarning, + ) + if allow_empty_values is None or not allow_empty_values: + # if "in" not defined then it's a Header + location_name = param_or_header.getkey("in", "header") + if ( + location_name == "query" + and deserialised == "" + and not allow_empty_values + ): + param_or_header_name = param_or_header["name"] + raise EmptyQueryParameterValue(param_or_header_name) + casted = self._cast(schema, deserialised) + return casted, schema def _get_complex_param_or_header( self, param_or_header: SchemaPath, location: Mapping[str, Any], name: Optional[str] = None, - ) -> Any: + ) -> Tuple[Any, Optional[SchemaPath]]: content = param_or_header / "content" - # no point to catch KetError - # in complex scenrios schema doesn't exist raw = self._get_media_type_value(param_or_header, location, name=name) - return self._convert_content_schema_value(raw, content) + return self._get_content_schema_value_and_schema(raw, content) - def _convert_schema_style_value( - self, - raw: Any, - param_or_header: SchemaPath, - ) -> Any: - casted, schema = self._convert_schema_style_value_and_schema( - raw, param_or_header - ) - if schema is None: - return casted - self._validate_schema(schema, casted) - return casted - - def _convert_content_schema_value( - self, raw: Any, content: SchemaPath, mimetype: Optional[str] = None - ) -> Any: - casted, schema = self._convert_content_schema_value_and_schema( - raw, content, mimetype - ) - if schema is None: - return casted - self._validate_schema(schema, casted) - return casted - - def _convert_schema_style_value_and_schema( - self, - raw: Any, - param_or_header: SchemaPath, - ) -> Tuple[Any, SchemaPath]: - deserialised = self._deserialise_style(param_or_header, raw) - schema = param_or_header / "schema" - casted = self._cast(schema, deserialised) - return casted, schema - - def _convert_content_schema_value_and_schema( + def _get_content_schema_value_and_schema( self, raw: Any, content: SchemaPath, @@ -214,6 +211,8 @@ def _convert_content_schema_value_and_schema( mime_type, parameters, media_type = self._find_media_type( content, mimetype ) + # no point to catch KetError + # in complex scenrios schema doesn't exist deserialised = self._deserialise_media_type(mime_type, parameters, raw) casted = self._cast(media_type, deserialised) @@ -223,35 +222,16 @@ def _convert_content_schema_value_and_schema( schema = media_type / "schema" return casted, schema - def _get_style_value( - self, - param_or_header: SchemaPath, - location: Mapping[str, Any], - name: Optional[str] = None, - ) -> Any: - name = name or param_or_header["name"] - style = get_style(param_or_header) - if name not in location: - # Only check if the name is not in the location if the style of - # the param is deepObject,this is because deepObjects will never be found - # as their key also includes the properties of the object already. - if style != "deepObject": - raise KeyError - keys_str = " ".join(location.keys()) - if not re.search(rf"{name}\[\w+\]", keys_str): - raise KeyError - - aslist = get_aslist(param_or_header) - explode = get_explode(param_or_header) - if aslist and explode: - if style == "deepObject": - return get_deep_object_value(location, name) - if isinstance(location, SuportsGetAll): - return location.getall(name) - if isinstance(location, SuportsGetList): - return location.getlist(name) - - return location[name] + def _get_content_and_schema( + self, raw: Any, content: SchemaPath, mimetype: Optional[str] = None + ) -> Tuple[Any, Optional[SchemaPath]]: + casted, schema = self._get_content_schema_value_and_schema( + raw, content, mimetype + ) + if schema is None: + return casted, None + self._validate_schema(schema, casted) + return casted, schema def _get_media_type_value( self, diff --git a/tests/integration/test_petstore.py b/tests/integration/test_petstore.py index 1c28dc36..81a78e68 100644 --- a/tests/integration/test_petstore.py +++ b/tests/integration/test_petstore.py @@ -406,12 +406,12 @@ def test_get_pets_tags_param(self, spec): assert is_dataclass(response_result.data) assert response_result.data.data == [] - def test_get_pets_parameter_deserialization_error(self, spec): + def test_get_pets_parameter_schema_error(self, spec): host_url = "http://petstore.swagger.io/v1" path_pattern = "/v1/pets" query_params = { - "limit": 1, - "tags": 12, + "limit": "1", + "tags": ",,", } request = MockRequest( @@ -428,7 +428,7 @@ def test_get_pets_parameter_deserialization_error(self, spec): spec=spec, cls=V30RequestParametersUnmarshaller, ) - assert type(exc_info.value.__cause__) is DeserializeError + assert type(exc_info.value.__cause__) is InvalidSchemaValue result = unmarshal_request( request, spec=spec, cls=V30RequestBodyUnmarshaller @@ -492,7 +492,8 @@ def test_get_pets_empty_value(self, spec): host_url = "http://petstore.swagger.io/v1" path_pattern = "/v1/pets" query_params = { - "limit": "", + "limit": "1", + "order": "", } request = MockRequest( diff --git a/tests/unit/deserializing/test_styles_deserializers.py b/tests/unit/deserializing/test_styles_deserializers.py index eed4130e..a6895a3a 100644 --- a/tests/unit/deserializing/test_styles_deserializers.py +++ b/tests/unit/deserializing/test_styles_deserializers.py @@ -1,54 +1,456 @@ import pytest from jsonschema_path import SchemaPath +from werkzeug.datastructures import ImmutableMultiDict +from openapi_core.deserializing.exceptions import DeserializeError +from openapi_core.deserializing.styles import style_deserializers_factory from openapi_core.deserializing.styles.exceptions import ( EmptyQueryParameterValue, ) -from openapi_core.deserializing.styles.factories import ( - StyleDeserializersFactory, -) -class TestStyleDeserializer: +class TestParameterStyleDeserializer: @pytest.fixture def deserializer_factory(self): - def create_deserializer(param): - return StyleDeserializersFactory().create(param) + def create_deserializer(param, name=None): + return style_deserializers_factory.create(param, name=name) return create_deserializer - def test_unsupported(self, deserializer_factory): - spec = {"name": "param", "in": "header", "style": "unsupported"} + @pytest.mark.parametrize( + "location_name", ["cookie", "header", "query", "path"] + ) + @pytest.mark.parametrize("value", ["", "test"]) + def test_unsupported(self, deserializer_factory, location_name, value): + name = "param" + schema_type = "string" + spec = { + "name": name, + "in": location_name, + "style": "unsupported", + "schema": { + "type": schema_type, + }, + } param = SchemaPath.from_dict(spec) deserializer = deserializer_factory(param) - value = "" + location = {name: value} with pytest.warns(UserWarning): - result = deserializer.deserialize(value) + result = deserializer.deserialize(location) assert result == value - def test_query_empty(self, deserializer_factory): + @pytest.mark.parametrize( + "location_name,style,explode,schema_type,location", + [ + ("query", "matrix", False, "string", {";param": "invalid"}), + ("query", "matrix", False, "array", {";param": "invalid"}), + ("query", "matrix", False, "object", {";param": "invalid"}), + ("query", "matrix", True, "string", {";param*": "invalid"}), + ("query", "deepObject", True, "object", {"param": "invalid"}), + ("query", "form", True, "array", {}), + ], + ) + def test_name_not_found( + self, + deserializer_factory, + location_name, + style, + explode, + schema_type, + location, + ): + name = "param" + spec = { + "name": name, + "in": location_name, + "style": style, + "explode": explode, + "schema": { + "type": schema_type, + }, + } + param = SchemaPath.from_dict(spec) + deserializer = deserializer_factory(param) + + with pytest.raises(KeyError): + deserializer.deserialize(location) + + @pytest.mark.parametrize( + "location_name,style,explode,schema_type,location", + [ + ("path", "deepObject", False, "string", {"param": "invalid"}), + ("path", "deepObject", False, "array", {"param": "invalid"}), + ("path", "deepObject", False, "object", {"param": "invalid"}), + ("path", "deepObject", True, "string", {"param": "invalid"}), + ("path", "deepObject", True, "array", {"param": "invalid"}), + ("path", "spaceDelimited", False, "string", {"param": "invalid"}), + ("path", "pipeDelimited", False, "string", {"param": "invalid"}), + ], + ) + def test_combination_not_available( + self, + deserializer_factory, + location_name, + style, + explode, + schema_type, + location, + ): + name = "param" + spec = { + "name": name, + "in": location_name, + "style": style, + "explode": explode, + "schema": { + "type": schema_type, + }, + } + param = SchemaPath.from_dict(spec) + deserializer = deserializer_factory(param) + + with pytest.raises(DeserializeError): + deserializer.deserialize(location) + + @pytest.mark.parametrize( + "explode,schema_type,location,expected", + [ + (False, "string", {";param": ";param=blue"}, "blue"), + (True, "string", {";param*": ";param=blue"}, "blue"), + ( + False, + "array", + {";param": ";param=blue,black,brown"}, + ["blue", "black", "brown"], + ), + ( + True, + "array", + {";param*": ";param=blue;param=black;param=brown"}, + ["blue", "black", "brown"], + ), + ( + False, + "object", + {";param": ";param=R,100,G,200,B,150"}, + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ( + True, + "object", + {";param*": ";R=100;G=200;B=150"}, + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ], + ) + def test_matrix_valid( + self, deserializer_factory, explode, schema_type, location, expected + ): + name = "param" + spec = { + "name": name, + "in": "path", + "style": "matrix", + "explode": explode, + "schema": { + "type": schema_type, + }, + } + param = SchemaPath.from_dict(spec) + deserializer = deserializer_factory(param) + + result = deserializer.deserialize(location) + + assert result == expected + + @pytest.mark.parametrize( + "explode,schema_type,location,expected", + [ + (False, "string", {".param": ".blue"}, "blue"), + (True, "string", {".param*": ".blue"}, "blue"), + ( + False, + "array", + {".param": ".blue,black,brown"}, + ["blue", "black", "brown"], + ), + ( + True, + "array", + {".param*": ".blue.black.brown"}, + ["blue", "black", "brown"], + ), + ( + False, + "object", + {".param": ".R,100,G,200,B,150"}, + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ( + True, + "object", + {".param*": ".R=100.G=200.B=150"}, + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ], + ) + def test_label_valid( + self, deserializer_factory, explode, schema_type, location, expected + ): + name = "param" spec = { - "name": "param", + "name": name, + "in": "path", + "style": "label", + "explode": explode, + "schema": { + "type": schema_type, + }, + } + param = SchemaPath.from_dict(spec) + deserializer = deserializer_factory(param) + + result = deserializer.deserialize(location) + + assert result == expected + + @pytest.mark.parametrize("location_name", ["query", "cookie"]) + @pytest.mark.parametrize( + "explode,schema_type,location,expected", + [ + (False, "string", {"param": "blue"}, "blue"), + (True, "string", {"param": "blue"}, "blue"), + ( + False, + "array", + {"param": "blue,black,brown"}, + ["blue", "black", "brown"], + ), + ( + True, + "array", + ImmutableMultiDict( + [("param", "blue"), ("param", "black"), ("param", "brown")] + ), + ["blue", "black", "brown"], + ), + ( + False, + "object", + {"param": "R,100,G,200,B,150"}, + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ( + True, + "object", + {"param": "R=100&G=200&B=150"}, + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ], + ) + def test_form_valid( + self, + deserializer_factory, + location_name, + explode, + schema_type, + location, + expected, + ): + name = "param" + spec = { + "name": name, + "in": location_name, + "explode": explode, + "schema": { + "type": schema_type, + }, + } + param = SchemaPath.from_dict(spec) + deserializer = deserializer_factory(param) + + result = deserializer.deserialize(location) + + assert result == expected + + @pytest.mark.parametrize("location_name", ["path", "header"]) + @pytest.mark.parametrize( + "explode,schema_type,value,expected", + [ + (False, "string", "blue", "blue"), + (True, "string", "blue", "blue"), + (False, "array", "blue,black,brown", ["blue", "black", "brown"]), + (True, "array", "blue,black,brown", ["blue", "black", "brown"]), + ( + False, + "object", + "R,100,G,200,B,150", + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ( + True, + "object", + "R=100,G=200,B=150", + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ], + ) + def test_simple_valid( + self, + deserializer_factory, + location_name, + explode, + schema_type, + value, + expected, + ): + name = "param" + spec = { + "name": name, + "in": location_name, + "explode": explode, + "schema": { + "type": schema_type, + }, + } + param = SchemaPath.from_dict(spec) + deserializer = deserializer_factory(param) + location = {name: value} + + result = deserializer.deserialize(location) + + assert result == expected + + @pytest.mark.parametrize( + "schema_type,value,expected", + [ + ("array", "blue%20black%20brown", ["blue", "black", "brown"]), + ( + "object", + "R%20100%20G%20200%20B%20150", + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ], + ) + def test_space_delimited_valid( + self, deserializer_factory, schema_type, value, expected + ): + name = "param" + spec = { + "name": name, "in": "query", + "style": "spaceDelimited", + "explode": False, + "schema": { + "type": schema_type, + }, } param = SchemaPath.from_dict(spec) deserializer = deserializer_factory(param) - value = "" + location = {name: value} + + result = deserializer.deserialize(location) - with pytest.raises(EmptyQueryParameterValue): - deserializer.deserialize(value) + assert result == expected - def test_query_valid(self, deserializer_factory): + @pytest.mark.parametrize( + "schema_type,value,expected", + [ + ("array", "blue|black|brown", ["blue", "black", "brown"]), + ( + "object", + "R|100|G|200|B|150", + { + "R": "100", + "G": "200", + "B": "150", + }, + ), + ], + ) + def test_pipe_delimited_valid( + self, deserializer_factory, schema_type, value, expected + ): + name = "param" spec = { - "name": "param", + "name": name, "in": "query", + "style": "pipeDelimited", + "explode": False, + "schema": { + "type": schema_type, + }, } param = SchemaPath.from_dict(spec) deserializer = deserializer_factory(param) - value = "test" + location = {name: value} - result = deserializer.deserialize(value) + result = deserializer.deserialize(location) - assert result == value + assert result == expected + + def test_deep_object_valid(self, deserializer_factory): + name = "param" + spec = { + "name": name, + "in": "query", + "style": "deepObject", + "explode": True, + "schema": { + "type": "object", + }, + } + param = SchemaPath.from_dict(spec) + deserializer = deserializer_factory(param) + location = { + "param[R]": "100", + "param[G]": "200", + "param[B]": "150", + "other[0]": "value", + } + + result = deserializer.deserialize(location) + + assert result == { + "R": "100", + "G": "200", + "B": "150", + }