From 80b4c4b97642d954461b0640bc5708226e3d3615 Mon Sep 17 00:00:00 2001 From: Jakub Krajewski <95274389+jpkrajewski@users.noreply.github.com> Date: Thu, 25 Jul 2024 12:22:41 +0200 Subject: [PATCH] Fix threat (#781) * Fix threat * Fix tests * Fix return type --- catalystwan/api/administration.py | 6 ++-- .../endpoints/configuration_settings.py | 8 +++-- .../integration_tests/test_settings.py | 15 +++++---- catalystwan/models/settings.py | 32 ++++++++++++------- .../policy_converters/test_threat_grid_api.py | 7 ++-- 5 files changed, 39 insertions(+), 29 deletions(-) diff --git a/catalystwan/api/administration.py b/catalystwan/api/administration.py index cd9e6135..89c9956f 100644 --- a/catalystwan/api/administration.py +++ b/catalystwan/api/administration.py @@ -391,8 +391,8 @@ def update(self, payload: ThreatGridApi) -> bool: def update(self, payload: Union[Organization, Certificate, Password, Vbond, ThreatGridApi]) -> bool: if isinstance(payload, ThreatGridApi): - dataseq = self.__update_thread_grid_api(payload) - return len(dataseq) == 1 + created_threat = self.__update_thread_grid_api(payload) + return created_threat is not None json_payload = asdict(payload) # type: ignore if isinstance(payload, Organization): response = self.__update_organization(json_payload) @@ -423,7 +423,7 @@ def __update_vbond(self, payload: dict) -> Response: return self.session.post(endpoint, json=payload) def __update_thread_grid_api(self, payload: ThreatGridApi) -> DataSequence[ThreatGridApi]: - return self.session.endpoints.configuration_settings.create_threat_grid_api_key(payload) + return self.session.endpoints.configuration_settings.edit_threat_grid_api_key(payload) @deprecated( "Use .endpoints.configuration_settings.edit_organizations() instead", category=CatalystwanDeprecationWarning diff --git a/catalystwan/endpoints/configuration_settings.py b/catalystwan/endpoints/configuration_settings.py index a3274b7c..abed1be7 100644 --- a/catalystwan/endpoints/configuration_settings.py +++ b/catalystwan/endpoints/configuration_settings.py @@ -722,6 +722,10 @@ def edit_cloud_credentials(self, payload: CloudCredentials) -> DataSequence[Clou def create_cloud_credentials(self, payload: CloudCredentials) -> DataSequence[CloudCredentials]: ... - @post("/settings/configuration/threatGridApiKey", "data") - def create_threat_grid_api_key(self, payload: ThreatGridApi) -> DataSequence[ThreatGridApi]: + @put("/settings/configuration/threatGridApiKey") + def edit_threat_grid_api_key(self, payload: ThreatGridApi) -> ThreatGridApi: + ... + + @post("/settings/configuration/threatGridApiKey") + def create_threat_grid_api_key(self, payload: ThreatGridApi) -> ThreatGridApi: ... diff --git a/catalystwan/integration_tests/test_settings.py b/catalystwan/integration_tests/test_settings.py index c04c4008..f7e503c9 100644 --- a/catalystwan/integration_tests/test_settings.py +++ b/catalystwan/integration_tests/test_settings.py @@ -13,16 +13,17 @@ def test_thread_grid_api(self): thread.set_region_api_key("eur", "1234567890") thread.set_region_api_key("nam", "0987654321") # Act - response = self.session.endpoints.configuration_settings.create_threat_grid_api_key(thread) - self.created_thread = response.single_or_default() + created_thread = self.session.endpoints.configuration_settings.edit_threat_grid_api_key(thread) + dump = created_thread.model_dump_json(by_alias=True, exclude_none=True) + self.created_thread = ThreatGridApi.model_validate_json(dump) # Assert assert self.created_thread is not None - assert self.created_thread.entries[0].region == "nam" - assert self.created_thread.entries[0].apikey == "0987654321" - assert self.created_thread.entries[1].region == "eur" - assert self.created_thread.entries[1].apikey == "1234567890" + assert self.created_thread.data[0].entries[0].region == "nam" + assert self.created_thread.data[0].entries[0].apikey == "0987654321" + assert self.created_thread.data[0].entries[1].region == "eur" + assert self.created_thread.data[0].entries[1].apikey == "1234567890" def tearDown(self) -> None: if self.created_thread: - self.session.endpoints.configuration_settings.create_threat_grid_api_key(ThreatGridApi()) + self.session.endpoints.configuration_settings.edit_threat_grid_api_key(ThreatGridApi()) return super().tearDown() diff --git a/catalystwan/models/settings.py b/catalystwan/models/settings.py index 903d8c63..a06740fc 100644 --- a/catalystwan/models/settings.py +++ b/catalystwan/models/settings.py @@ -1,33 +1,41 @@ # Copyright 2024 Cisco Systems, Inc. and its affiliates -from typing import Any, Dict, List, Literal, Optional +from typing import List, Literal, Optional -from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, SerializerFunctionWrapHandler, model_serializer +from pydantic import BaseModel, ConfigDict, Field Region = Literal["nam", "eur"] -class ThreadGridApiEntires(BaseModel): +class ThreatGridApiEntires(BaseModel): model_config = ConfigDict(extra="forbid") region: Region apikey: Optional[str] = Field(default="") -class ThreatGridApi(BaseModel): +class ThreatGridApiData(BaseModel): model_config = ConfigDict(extra="forbid") - entries: List[ThreadGridApiEntires] = Field( + entries: List[ThreatGridApiEntires] = Field( default_factory=lambda: [ - ThreadGridApiEntires(region="nam"), - ThreadGridApiEntires(region="eur"), - ] + ThreatGridApiEntires(region="nam"), + ThreatGridApiEntires(region="eur"), + ], ) + +class ThreatGridApi(BaseModel): + model_config = ConfigDict(extra="forbid") + data: List[ThreatGridApiData] = Field(default_factory=lambda: [ThreatGridApiData()]) + def set_region_api_key(self, region: Region, apikey: str) -> None: - for entry in self.entries: + for entry in self.data[0].entries: if entry.region == region: entry.apikey = apikey return raise ValueError(f"Region {region} not found in ThreatGridApi") - @model_serializer(mode="wrap") - def envelope_data(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> Dict[str, Any]: - return {"data": [handler(self)]} + def get_region_api_key(self, region: Region) -> str: + for entry in self.data[0].entries: + if entry.region == region: + apikey = entry.apikey + return "" if apikey is None else apikey + raise ValueError(f"Region {region} not found in ThreatGridApi") diff --git a/catalystwan/tests/config_migration/policy_converters/test_threat_grid_api.py b/catalystwan/tests/config_migration/policy_converters/test_threat_grid_api.py index 998359c4..f7ebee2d 100644 --- a/catalystwan/tests/config_migration/policy_converters/test_threat_grid_api.py +++ b/catalystwan/tests/config_migration/policy_converters/test_threat_grid_api.py @@ -24,8 +24,5 @@ def test_threat_grid_api_conversion(self): convert(policy, context=self.context) threat = self.context.threat_grid_api # Assert - assert len(threat.entries) == 2 - assert threat.entries[0].region == "nam" - assert threat.entries[0].apikey == "456" - assert threat.entries[1].region == "eur" - assert threat.entries[1].apikey == "123" + assert threat.get_region_api_key("eur") == "123" + assert threat.get_region_api_key("nam") == "456"