Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Commit

Permalink
Fix threat (#781)
Browse files Browse the repository at this point in the history
* Fix threat

* Fix tests

* Fix return type
  • Loading branch information
jpkrajewski authored Jul 25, 2024
1 parent de1a188 commit 80b4c4b
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 29 deletions.
6 changes: 3 additions & 3 deletions catalystwan/api/administration.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,8 +391,8 @@ def update(self, payload: ThreatGridApi) -> bool:

def update(self, payload: Union[Organization, Certificate, Password, Vbond, ThreatGridApi]) -> bool:
if isinstance(payload, ThreatGridApi):
dataseq = self.__update_thread_grid_api(payload)
return len(dataseq) == 1
created_threat = self.__update_thread_grid_api(payload)
return created_threat is not None
json_payload = asdict(payload) # type: ignore
if isinstance(payload, Organization):
response = self.__update_organization(json_payload)
Expand Down Expand Up @@ -423,7 +423,7 @@ def __update_vbond(self, payload: dict) -> Response:
return self.session.post(endpoint, json=payload)

def __update_thread_grid_api(self, payload: ThreatGridApi) -> DataSequence[ThreatGridApi]:
return self.session.endpoints.configuration_settings.create_threat_grid_api_key(payload)
return self.session.endpoints.configuration_settings.edit_threat_grid_api_key(payload)

@deprecated(
"Use .endpoints.configuration_settings.edit_organizations() instead", category=CatalystwanDeprecationWarning
Expand Down
8 changes: 6 additions & 2 deletions catalystwan/endpoints/configuration_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@ def edit_cloud_credentials(self, payload: CloudCredentials) -> DataSequence[Clou
def create_cloud_credentials(self, payload: CloudCredentials) -> DataSequence[CloudCredentials]:
...

@post("/settings/configuration/threatGridApiKey", "data")
def create_threat_grid_api_key(self, payload: ThreatGridApi) -> DataSequence[ThreatGridApi]:
@put("/settings/configuration/threatGridApiKey")
def edit_threat_grid_api_key(self, payload: ThreatGridApi) -> ThreatGridApi:
...

@post("/settings/configuration/threatGridApiKey")
def create_threat_grid_api_key(self, payload: ThreatGridApi) -> ThreatGridApi:
...
15 changes: 8 additions & 7 deletions catalystwan/integration_tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ def test_thread_grid_api(self):
thread.set_region_api_key("eur", "1234567890")
thread.set_region_api_key("nam", "0987654321")
# Act
response = self.session.endpoints.configuration_settings.create_threat_grid_api_key(thread)
self.created_thread = response.single_or_default()
created_thread = self.session.endpoints.configuration_settings.edit_threat_grid_api_key(thread)
dump = created_thread.model_dump_json(by_alias=True, exclude_none=True)
self.created_thread = ThreatGridApi.model_validate_json(dump)
# Assert
assert self.created_thread is not None
assert self.created_thread.entries[0].region == "nam"
assert self.created_thread.entries[0].apikey == "0987654321"
assert self.created_thread.entries[1].region == "eur"
assert self.created_thread.entries[1].apikey == "1234567890"
assert self.created_thread.data[0].entries[0].region == "nam"
assert self.created_thread.data[0].entries[0].apikey == "0987654321"
assert self.created_thread.data[0].entries[1].region == "eur"
assert self.created_thread.data[0].entries[1].apikey == "1234567890"

def tearDown(self) -> None:
if self.created_thread:
self.session.endpoints.configuration_settings.create_threat_grid_api_key(ThreatGridApi())
self.session.endpoints.configuration_settings.edit_threat_grid_api_key(ThreatGridApi())
return super().tearDown()
32 changes: 20 additions & 12 deletions catalystwan/models/settings.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
# Copyright 2024 Cisco Systems, Inc. and its affiliates
from typing import Any, Dict, List, Literal, Optional
from typing import List, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, SerializerFunctionWrapHandler, model_serializer
from pydantic import BaseModel, ConfigDict, Field

Region = Literal["nam", "eur"]


class ThreadGridApiEntires(BaseModel):
class ThreatGridApiEntires(BaseModel):
model_config = ConfigDict(extra="forbid")
region: Region
apikey: Optional[str] = Field(default="")


class ThreatGridApi(BaseModel):
class ThreatGridApiData(BaseModel):
model_config = ConfigDict(extra="forbid")
entries: List[ThreadGridApiEntires] = Field(
entries: List[ThreatGridApiEntires] = Field(
default_factory=lambda: [
ThreadGridApiEntires(region="nam"),
ThreadGridApiEntires(region="eur"),
]
ThreatGridApiEntires(region="nam"),
ThreatGridApiEntires(region="eur"),
],
)


class ThreatGridApi(BaseModel):
model_config = ConfigDict(extra="forbid")
data: List[ThreatGridApiData] = Field(default_factory=lambda: [ThreatGridApiData()])

def set_region_api_key(self, region: Region, apikey: str) -> None:
for entry in self.entries:
for entry in self.data[0].entries:
if entry.region == region:
entry.apikey = apikey
return
raise ValueError(f"Region {region} not found in ThreatGridApi")

@model_serializer(mode="wrap")
def envelope_data(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> Dict[str, Any]:
return {"data": [handler(self)]}
def get_region_api_key(self, region: Region) -> str:
for entry in self.data[0].entries:
if entry.region == region:
apikey = entry.apikey
return "" if apikey is None else apikey
raise ValueError(f"Region {region} not found in ThreatGridApi")
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,5 @@ def test_threat_grid_api_conversion(self):
convert(policy, context=self.context)
threat = self.context.threat_grid_api
# Assert
assert len(threat.entries) == 2
assert threat.entries[0].region == "nam"
assert threat.entries[0].apikey == "456"
assert threat.entries[1].region == "eur"
assert threat.entries[1].apikey == "123"
assert threat.get_region_api_key("eur") == "123"
assert threat.get_region_api_key("nam") == "456"

0 comments on commit 80b4c4b

Please sign in to comment.