Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Commit

Permalink
enable context based serialization alias
Browse files Browse the repository at this point in the history
  • Loading branch information
sbasan committed Apr 18, 2024
1 parent cafc368 commit 3dfd23b
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 19 deletions.
30 changes: 16 additions & 14 deletions catalystwan/endpoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ def resolve_nested_base_model_unions(
if annotated_origin := get_args(annotation):
if (len(annotated_origin) >= 1) and get_origin(annotated_origin[0]) == Union:
type_args = get_args(annotated_origin[0])
if all(isclass(t) for t in type_args) and all(
issubclass(t, BaseModel) for t in type_args
):
if all(isclass(t) for t in type_args) and all(issubclass(t, BaseModel) for t in type_args):
models_types.extend(list(type_args))
return models_types
else:
Expand Down Expand Up @@ -203,43 +201,46 @@ class APIEndpoints:
"""

@classmethod
def _prepare_payload(cls, payload: PayloadType, force_json: bool = False) -> PreparedPayload:
def _prepare_payload(
cls, payload: PayloadType, force_json: bool = False, context: Dict[str, Any] = {}
) -> PreparedPayload:
"""Helper method to prepare data for sending based on type"""
if force_json or isinstance(payload, dict):
return PreparedPayload(data=json.dumps(payload), headers={"content-type": "application/json"})
if isinstance(payload, (str, bytes)):
return PreparedPayload(data=payload)
elif isinstance(payload, (BaseModel)):
return cls._prepare_basemodel_payload(payload)
return cls._prepare_basemodel_payload(payload, context)
elif isinstance(payload, Sequence) and not isinstance(payload, (str, bytes)):
return cls._prepare_sequence_payload(payload) # type: ignore[arg-type]
return cls._prepare_sequence_payload(payload, context) # type: ignore[arg-type]
# offender is List[JSON] which is also a Sequence can be ignored as long as force_json is passed correctly
elif isinstance(payload, CustomPayloadType):
return payload.prepared()
else:
raise APIRequestPayloadTypeError(payload)

@classmethod
def _prepare_basemodel_payload(cls, payload: BaseModel) -> PreparedPayload:
def _prepare_basemodel_payload(cls, payload: BaseModel, context: Dict[str, Any] = {}) -> PreparedPayload:
"""Helper method to prepare BaseModel instance for sending"""
return PreparedPayload(
data=payload.model_dump_json(exclude_none=True, by_alias=True), headers={"content-type": "application/json"}
data=payload.model_dump_json(exclude_none=True, by_alias=True, context=context),
headers={"content-type": "application/json"},
)

@classmethod
def _prepare_sequence_payload(cls, payload: Iterable[BaseModel]) -> PreparedPayload:
def _prepare_sequence_payload(cls, payload: Iterable[BaseModel], context: Dict[str, Any] = {}) -> PreparedPayload:
"""Helper method to prepare sequences for sending"""
items = []
for item in payload:
items.append(item.model_dump(exclude_none=True, by_alias=True))
items.append(item.model_dump(exclude_none=True, by_alias=True, context=context))
data = json.dumps(items)
return PreparedPayload(data=data, headers={"content-type": "application/json"})

@classmethod
def _prepare_params(cls, params: RequestParamsType) -> Dict[str, Any]:
def _prepare_params(cls, params: RequestParamsType, context: Dict[str, Any] = {}) -> Dict[str, Any]:
"""Helper method to prepare params for sending"""
if isinstance(params, BaseModel):
return params.model_dump(exclude_none=True, by_alias=True)
return params.model_dump(exclude_none=True, by_alias=True, context=context)
return params

def __init__(self, client: APIEndpointClient):
Expand All @@ -257,10 +258,11 @@ def _request(
) -> APIEndpointClientResponse:
"""Prepares and sends request using client protocol"""
_kwargs = dict(kwargs)
context = dict(api_version=self._api_version)
if payload is not None:
_kwargs.update(self._prepare_payload(payload, force_json_payload).asdict())
_kwargs.update(self._prepare_payload(payload, force_json_payload, context).asdict())
if params is not None:
_kwargs.update({"params": self._prepare_params(params)})
_kwargs.update({"params": self._prepare_params(params, context)})
return self._client.request(method, self._basepath + url, **_kwargs)

@property
Expand Down
63 changes: 61 additions & 2 deletions catalystwan/models/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,72 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates

from typing import Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
from dataclasses import InitVar, dataclass, field
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
from uuid import UUID

from pydantic import PlainSerializer
from packaging.specifiers import SpecifierSet # type: ignore
from packaging.version import Version # type: ignore
from pydantic import PlainSerializer, SerializationInfo
from pydantic.fields import FieldInfo
from pydantic.functional_validators import BeforeValidator
from typing_extensions import Annotated


@dataclass()
class VersionedField:
"""
This class could be used as field type annotation for pydantic.BaseModel fields.
Together with dedicated @model_serializer it allows pick different serialization alias.
When version provided as specifier set eg. ">=20.13" matches Manager API version detected at runtime
original serialization_alias will be overriden.
Example:
>>> from catalystwan.models.common import VersionedField
>>> from pydantic import BaseModel, SerializationInfo, SerializerFunctionWrapHandler, model_serializer
>>> from typing_extensions import Annotated
>>>
>>> class Payload(BaseModel):
>>> snake_case: Annotated[int, VersionedField(versions="<=20.12", serialization_alias="kebab-case")]
>>>
>>> @model_serializer(mode="wrap", when_used="json")
>>> def serialize(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> Dict[str, Any]:
>>> return VersionedField.update_model_fields(self.model_fields, handler(self), info)
"""

versions: InitVar[str]
versions_set: SpecifierSet = field(init=False)
serialization_alias: str

def __post_init__(self, versions):
self.versions_set = SpecifierSet(versions)

@staticmethod
def update_model_fields(
model_fields: Dict[str, FieldInfo], model_dict: Dict[str, Any], serialization_info: SerializationInfo
) -> Dict[str, Any]:
"""To be reused in methods decorated with pydantic.model_serializer
Args:
model_fields (Dict[str, FieldInfo]): obtained from BaseModel class
model_dict (Dict[str, Any]): obtained from serialized BaseModel instance
serialization_info (SerializationInfo): passed from serializer
Returns:
Dict[str, Any]: model_dict with updated field names according to matching runtime version
"""
if serialization_info.context is not None:
api_version: Optional[Version] = serialization_info.context.get("api_version")
if api_version is not None:
for field_name, field_info in model_fields.items():
versioned_fields = [meta for meta in field_info.metadata if isinstance(meta, VersionedField)]
for versioned_field in versioned_fields:
if api_version in versioned_field.versions_set:
current_field_name = field_info.serialization_alias or field_info.alias or field_name
if model_dict.get(current_field_name) is not None:
model_dict[versioned_field.serialization_alias] = model_dict[current_field_name]
del model_dict[current_field_name]
return model_dict


def check_fields_exclusive(values: Dict, field_names: Set[str], at_least_one: bool = False) -> bool:
"""Helper method to check fields are mutually exclusive
Expand Down
2 changes: 1 addition & 1 deletion catalystwan/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __init__(
self.api = APIContainer(self)
self.endpoints = APIEndpointContainter(self)
self._platform_version: str = ""
self._api_version: Version
self._api_version: Version = NullVersion # type: ignore
self._state: ManagerSessionState = ManagerSessionState.OPERATIVE
self.restart_timeout: int = 1200
self.polling_requests_timeout: int = 10
Expand Down
36 changes: 35 additions & 1 deletion catalystwan/tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytest # type: ignore
from packaging.version import Version # type: ignore
from parameterized import parameterized # type: ignore
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, model_serializer
from typing_extensions import Annotated

from catalystwan.endpoints import (
Expand All @@ -29,6 +29,7 @@
from catalystwan.endpoints import logger as endpoints_logger
from catalystwan.endpoints import post, put, request, versions, view
from catalystwan.exceptions import APIEndpointError, APIRequestPayloadTypeError, APIVersionError, APIViewError
from catalystwan.models.common import VersionedField
from catalystwan.typed_list import DataSequence
from catalystwan.utils.session_type import ProviderAsTenantView, ProviderView, TenantView

Expand Down Expand Up @@ -874,3 +875,36 @@ class TestAPI(APIEndpoints):
@request("POST", "/v1/data")
def create(self, payload: AnyBaseModel) -> None: # type: ignore [empty-body]
...

@parameterized.expand(
[
("1.3", '{"name":"John"}'),
("1.9", '{"newName":"John"}'),
]
)
def test_api_version_passed_in_dump_context(self, version, expected_payload_json):
# Arrange
class Payload(BaseModel):
model_config = ConfigDict(populate_by_name=True)
name: Annotated[str, VersionedField(versions=">1.6", serialization_alias="newName")]

@model_serializer(mode="wrap")
def serialize(self, handler, info):
return VersionedField.update_model_fields(self.model_fields, handler(self), info)

class ExampleAPI(APIEndpoints):
@request("POST", "/v1/data")
def create(self, payload: Payload) -> None: # type: ignore [empty-body]
...

self.session_mock.api_version = Version(version)
api = ExampleAPI(self.session_mock)
# Act
api.create(Payload(name="John"))
# Assert
self.session_mock.request.assert_called_once_with(
"POST",
self.base_path + "/v1/data",
data=expected_payload_json,
headers={"content-type": "application/json"},
)
53 changes: 53 additions & 0 deletions catalystwan/tests/test_models_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2024 Cisco Systems, Inc. and its affiliates

import unittest
from typing import Any, Dict, Set

from packaging.version import Version # type: ignore
from parameterized import parameterized # type: ignore
from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, SerializerFunctionWrapHandler, model_serializer
from typing_extensions import Annotated

from catalystwan.models.common import VersionedField

A = Annotated[
int, VersionedField(versions=">=1", serialization_alias="a-kebab"), Field(default=0, serialization_alias="aCamel")
]

B = Annotated[
float,
VersionedField(versions=">=2", serialization_alias="b-kebab"),
]


class VersionedFieldsModel(BaseModel):
model_config = ConfigDict(populate_by_name=True)
a: A
b: B = Field(default=0.0, serialization_alias="bCamel")
c: Annotated[bool, VersionedField(versions=">=3", serialization_alias="c-kebab")] = False

@model_serializer(mode="wrap")
def dump(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> Dict[str, Any]:
return VersionedField.update_model_fields(self.model_fields, handler(self), info)


class Payload(BaseModel):
model_config = ConfigDict(populate_by_name=True)
data: VersionedFieldsModel


class TestModelsCommonVersionedField(unittest.TestCase):
def setUp(self):
self.model = Payload(data=VersionedFieldsModel())

@parameterized.expand(
[
("0.9", {"aCamel", "bCamel", "c"}),
("1.0", {"a-kebab", "bCamel", "c"}),
("2.1", {"a-kebab", "b-kebab", "c"}),
("3.0.1", {"a-kebab", "b-kebab", "c-kebab"}),
]
)
def test_versioned_field_model_serialize(self, version: str, expected_fields: Set[str]):
data_dict = self.model.model_dump(by_alias=True, context={"api_version": Version(version)}).get("data")
assert expected_fields == data_dict.keys()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ flake8-quotes = "^3.3.1"
clint = "^0.5.1"
requests-toolbelt = "^1.0.0"
packaging = "^23.0"
pydantic = "^2.5"
pydantic = "^2.7"
typing-extensions = "^4.6.1"

[tool.poetry.dev-dependencies]
Expand Down

0 comments on commit 3dfd23b

Please sign in to comment.