diff --git a/catalystwan/endpoints/__init__.py b/catalystwan/endpoints/__init__.py index e8cc249c..0ecc89c2 100644 --- a/catalystwan/endpoints/__init__.py +++ b/catalystwan/endpoints/__init__.py @@ -131,9 +131,7 @@ def resolve_nested_base_model_unions( if annotated_origin := get_args(annotation): if (len(annotated_origin) >= 1) and get_origin(annotated_origin[0]) == Union: type_args = get_args(annotated_origin[0]) - if all(isclass(t) for t in type_args) and all( - issubclass(t, BaseModel) for t in type_args - ): + if all(isclass(t) for t in type_args) and all(issubclass(t, BaseModel) for t in type_args): models_types.extend(list(type_args)) return models_types else: @@ -203,16 +201,18 @@ class APIEndpoints: """ @classmethod - def _prepare_payload(cls, payload: PayloadType, force_json: bool = False) -> PreparedPayload: + def _prepare_payload( + cls, payload: PayloadType, force_json: bool = False, context: Dict[str, Any] = {} + ) -> PreparedPayload: """Helper method to prepare data for sending based on type""" if force_json or isinstance(payload, dict): return PreparedPayload(data=json.dumps(payload), headers={"content-type": "application/json"}) if isinstance(payload, (str, bytes)): return PreparedPayload(data=payload) elif isinstance(payload, (BaseModel)): - return cls._prepare_basemodel_payload(payload) + return cls._prepare_basemodel_payload(payload, context) elif isinstance(payload, Sequence) and not isinstance(payload, (str, bytes)): - return cls._prepare_sequence_payload(payload) # type: ignore[arg-type] + return cls._prepare_sequence_payload(payload, context) # type: ignore[arg-type] # offender is List[JSON] which is also a Sequence can be ignored as long as force_json is passed correctly elif isinstance(payload, CustomPayloadType): return payload.prepared() @@ -220,26 +220,27 @@ def _prepare_payload(cls, payload: PayloadType, force_json: bool = False) -> Pre raise APIRequestPayloadTypeError(payload) @classmethod - def _prepare_basemodel_payload(cls, payload: BaseModel) -> PreparedPayload: + def _prepare_basemodel_payload(cls, payload: BaseModel, context: Dict[str, Any] = {}) -> PreparedPayload: """Helper method to prepare BaseModel instance for sending""" return PreparedPayload( - data=payload.model_dump_json(exclude_none=True, by_alias=True), headers={"content-type": "application/json"} + data=payload.model_dump_json(exclude_none=True, by_alias=True, context=context), + headers={"content-type": "application/json"}, ) @classmethod - def _prepare_sequence_payload(cls, payload: Iterable[BaseModel]) -> PreparedPayload: + def _prepare_sequence_payload(cls, payload: Iterable[BaseModel], context: Dict[str, Any] = {}) -> PreparedPayload: """Helper method to prepare sequences for sending""" items = [] for item in payload: - items.append(item.model_dump(exclude_none=True, by_alias=True)) + items.append(item.model_dump(exclude_none=True, by_alias=True, context=context)) data = json.dumps(items) return PreparedPayload(data=data, headers={"content-type": "application/json"}) @classmethod - def _prepare_params(cls, params: RequestParamsType) -> Dict[str, Any]: + def _prepare_params(cls, params: RequestParamsType, context: Dict[str, Any] = {}) -> Dict[str, Any]: """Helper method to prepare params for sending""" if isinstance(params, BaseModel): - return params.model_dump(exclude_none=True, by_alias=True) + return params.model_dump(exclude_none=True, by_alias=True, context=context) return params def __init__(self, client: APIEndpointClient): @@ -257,10 +258,11 @@ def _request( ) -> APIEndpointClientResponse: """Prepares and sends request using client protocol""" _kwargs = dict(kwargs) + context = dict(api_version=self._api_version) if payload is not None: - _kwargs.update(self._prepare_payload(payload, force_json_payload).asdict()) + _kwargs.update(self._prepare_payload(payload, force_json_payload, context).asdict()) if params is not None: - _kwargs.update({"params": self._prepare_params(params)}) + _kwargs.update({"params": self._prepare_params(params, context)}) return self._client.request(method, self._basepath + url, **_kwargs) @property diff --git a/catalystwan/models/common.py b/catalystwan/models/common.py index b8f7cc29..6a1565ff 100644 --- a/catalystwan/models/common.py +++ b/catalystwan/models/common.py @@ -1,13 +1,72 @@ # Copyright 2023 Cisco Systems, Inc. and its affiliates -from typing import Dict, List, Literal, Optional, Sequence, Set, Tuple, Union +from dataclasses import InitVar, dataclass, field +from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union from uuid import UUID -from pydantic import PlainSerializer +from packaging.specifiers import SpecifierSet # type: ignore +from packaging.version import Version # type: ignore +from pydantic import PlainSerializer, SerializationInfo +from pydantic.fields import FieldInfo from pydantic.functional_validators import BeforeValidator from typing_extensions import Annotated +@dataclass() +class VersionedField: + """ + This class could be used as field type annotation for pydantic.BaseModel fields. + Together with dedicated @model_serializer it allows pick different serialization alias. + When version provided as specifier set eg. ">=20.13" matches Manager API version detected at runtime + original serialization_alias will be overriden. + + Example: + >>> from catalystwan.models.common import VersionedField + >>> from pydantic import BaseModel, SerializationInfo, SerializerFunctionWrapHandler, model_serializer + >>> from typing_extensions import Annotated + >>> + >>> class Payload(BaseModel): + >>> snake_case: Annotated[int, VersionedField(versions="<=20.12", serialization_alias="kebab-case")] + >>> + >>> @model_serializer(mode="wrap", when_used="json") + >>> def serialize(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> Dict[str, Any]: + >>> return VersionedField.update_model_fields(self.model_fields, handler(self), info) + """ + + versions: InitVar[str] + versions_set: SpecifierSet = field(init=False) + serialization_alias: str + + def __post_init__(self, versions): + self.versions_set = SpecifierSet(versions) + + @staticmethod + def update_model_fields( + model_fields: Dict[str, FieldInfo], model_dict: Dict[str, Any], serialization_info: SerializationInfo + ) -> Dict[str, Any]: + """To be reused in methods decorated with pydantic.model_serializer + Args: + model_fields (Dict[str, FieldInfo]): obtained from BaseModel class + model_dict (Dict[str, Any]): obtained from serialized BaseModel instance + serialization_info (SerializationInfo): passed from serializer + + Returns: + Dict[str, Any]: model_dict with updated field names according to matching runtime version + """ + if serialization_info.context is not None: + api_version: Optional[Version] = serialization_info.context.get("api_version") + if api_version is not None: + for field_name, field_info in model_fields.items(): + versioned_fields = [meta for meta in field_info.metadata if isinstance(meta, VersionedField)] + for versioned_field in versioned_fields: + if api_version in versioned_field.versions_set: + current_field_name = field_info.serialization_alias or field_info.alias or field_name + if model_dict.get(current_field_name) is not None: + model_dict[versioned_field.serialization_alias] = model_dict[current_field_name] + del model_dict[current_field_name] + return model_dict + + def check_fields_exclusive(values: Dict, field_names: Set[str], at_least_one: bool = False) -> bool: """Helper method to check fields are mutually exclusive diff --git a/catalystwan/session.py b/catalystwan/session.py index 0f95edcc..592eb41e 100644 --- a/catalystwan/session.py +++ b/catalystwan/session.py @@ -173,7 +173,7 @@ def __init__( self.api = APIContainer(self) self.endpoints = APIEndpointContainter(self) self._platform_version: str = "" - self._api_version: Version + self._api_version: Version = NullVersion self._state: ManagerSessionState = ManagerSessionState.OPERATIVE self.restart_timeout: int = 1200 self.polling_requests_timeout: int = 10 diff --git a/catalystwan/tests/test_endpoints.py b/catalystwan/tests/test_endpoints.py index 0216bbb6..b40511f8 100644 --- a/catalystwan/tests/test_endpoints.py +++ b/catalystwan/tests/test_endpoints.py @@ -13,7 +13,7 @@ import pytest # type: ignore from packaging.version import Version # type: ignore from parameterized import parameterized # type: ignore -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, model_serializer from typing_extensions import Annotated from catalystwan.endpoints import ( @@ -29,6 +29,7 @@ from catalystwan.endpoints import logger as endpoints_logger from catalystwan.endpoints import post, put, request, versions, view from catalystwan.exceptions import APIEndpointError, APIRequestPayloadTypeError, APIVersionError, APIViewError +from catalystwan.models.common import VersionedField from catalystwan.typed_list import DataSequence from catalystwan.utils.session_type import ProviderAsTenantView, ProviderView, TenantView @@ -874,3 +875,36 @@ class TestAPI(APIEndpoints): @request("POST", "/v1/data") def create(self, payload: AnyBaseModel) -> None: # type: ignore [empty-body] ... + + @parameterized.expand( + [ + ("1.3", '{"name":"John"}'), + ("1.9", '{"newName":"John"}'), + ] + ) + def test_api_version_passed_in_dump_context(self, version, expected_payload_json): + # Arrange + class Payload(BaseModel): + model_config = ConfigDict(populate_by_name=True) + name: Annotated[str, VersionedField(versions=">1.6", serialization_alias="newName")] + + @model_serializer(mode="wrap") + def serialize(self, handler, info): + return VersionedField.update_model_fields(self.model_fields, handler(self), info) + + class ExampleAPI(APIEndpoints): + @request("POST", "/v1/data") + def create(self, payload: Payload) -> None: # type: ignore [empty-body] + ... + + self.session_mock.api_version = Version(version) + api = ExampleAPI(self.session_mock) + # Act + api.create(Payload(name="John")) + # Assert + self.session_mock.request.assert_called_once_with( + "POST", + self.base_path + "/v1/data", + data=expected_payload_json, + headers={"content-type": "application/json"}, + ) diff --git a/catalystwan/tests/test_models_common.py b/catalystwan/tests/test_models_common.py new file mode 100644 index 00000000..8139e302 --- /dev/null +++ b/catalystwan/tests/test_models_common.py @@ -0,0 +1,53 @@ +# Copyright 2024 Cisco Systems, Inc. and its affiliates + +import unittest +from typing import Any, Dict, Set + +from packaging.version import Version # type: ignore +from parameterized import parameterized # type: ignore +from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, SerializerFunctionWrapHandler, model_serializer +from typing_extensions import Annotated + +from catalystwan.models.common import VersionedField + +A = Annotated[ + int, VersionedField(versions=">=1", serialization_alias="a-kebab"), Field(default=0, serialization_alias="aCamel") +] + +B = Annotated[ + float, + VersionedField(versions=">=2", serialization_alias="b-kebab"), +] + + +class VersionedFieldsModel(BaseModel): + model_config = ConfigDict(populate_by_name=True) + a: A + b: B = Field(default=0.0, serialization_alias="bCamel") + c: Annotated[bool, VersionedField(versions=">=3", serialization_alias="c-kebab")] = False + + @model_serializer(mode="wrap") + def dump(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> Dict[str, Any]: + return VersionedField.update_model_fields(self.model_fields, handler(self), info) + + +class Payload(BaseModel): + model_config = ConfigDict(populate_by_name=True) + data: VersionedFieldsModel + + +class TestModelsCommonVersionedField(unittest.TestCase): + def setUp(self): + self.model = Payload(data=VersionedFieldsModel()) + + @parameterized.expand( + [ + ("0.9", {"aCamel", "bCamel", "c"}), + ("1.0", {"a-kebab", "bCamel", "c"}), + ("2.1", {"a-kebab", "b-kebab", "c"}), + ("3.0.1", {"a-kebab", "b-kebab", "c-kebab"}), + ] + ) + def test_versioned_field_model_serialize(self, version: str, expected_fields: Set[str]): + data_dict = self.model.model_dump(by_alias=True, context={"api_version": Version(version)}).get("data") + assert expected_fields == data_dict.keys() diff --git a/pyproject.toml b/pyproject.toml index 1d92303c..72a7c245 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ flake8-quotes = "^3.3.1" clint = "^0.5.1" requests-toolbelt = "^1.0.0" packaging = "^23.0" -pydantic = "^2.5" +pydantic = "^2.7" typing-extensions = "^4.6.1" [tool.poetry.dev-dependencies]