Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Fix threat #781

Merged
merged 3 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions catalystwan/api/administration.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,8 +391,8 @@ def update(self, payload: ThreatGridApi) -> bool:

def update(self, payload: Union[Organization, Certificate, Password, Vbond, ThreatGridApi]) -> bool:
if isinstance(payload, ThreatGridApi):
dataseq = self.__update_thread_grid_api(payload)
return len(dataseq) == 1
created_threat = self.__update_thread_grid_api(payload)
return created_threat is not None
json_payload = asdict(payload) # type: ignore
if isinstance(payload, Organization):
response = self.__update_organization(json_payload)
Expand Down Expand Up @@ -423,7 +423,7 @@ def __update_vbond(self, payload: dict) -> Response:
return self.session.post(endpoint, json=payload)

def __update_thread_grid_api(self, payload: ThreatGridApi) -> DataSequence[ThreatGridApi]:
return self.session.endpoints.configuration_settings.create_threat_grid_api_key(payload)
return self.session.endpoints.configuration_settings.edit_threat_grid_api_key(payload)

@deprecated(
"Use .endpoints.configuration_settings.edit_organizations() instead", category=CatalystwanDeprecationWarning
Expand Down
8 changes: 6 additions & 2 deletions catalystwan/endpoints/configuration_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@ def edit_cloud_credentials(self, payload: CloudCredentials) -> DataSequence[Clou
def create_cloud_credentials(self, payload: CloudCredentials) -> DataSequence[CloudCredentials]:
...

@post("/settings/configuration/threatGridApiKey", "data")
def create_threat_grid_api_key(self, payload: ThreatGridApi) -> DataSequence[ThreatGridApi]:
@put("/settings/configuration/threatGridApiKey")
def edit_threat_grid_api_key(self, payload: ThreatGridApi) -> ThreatGridApi:
...

@post("/settings/configuration/threatGridApiKey")
def create_threat_grid_api_key(self, payload: ThreatGridApi) -> ThreatGridApi:
...
15 changes: 8 additions & 7 deletions catalystwan/integration_tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ def test_thread_grid_api(self):
thread.set_region_api_key("eur", "1234567890")
thread.set_region_api_key("nam", "0987654321")
# Act
response = self.session.endpoints.configuration_settings.create_threat_grid_api_key(thread)
self.created_thread = response.single_or_default()
created_thread = self.session.endpoints.configuration_settings.edit_threat_grid_api_key(thread)
dump = created_thread.model_dump_json(by_alias=True, exclude_none=True)
self.created_thread = ThreatGridApi.model_validate_json(dump)
# Assert
assert self.created_thread is not None
assert self.created_thread.entries[0].region == "nam"
assert self.created_thread.entries[0].apikey == "0987654321"
assert self.created_thread.entries[1].region == "eur"
assert self.created_thread.entries[1].apikey == "1234567890"
assert self.created_thread.data[0].entries[0].region == "nam"
assert self.created_thread.data[0].entries[0].apikey == "0987654321"
assert self.created_thread.data[0].entries[1].region == "eur"
assert self.created_thread.data[0].entries[1].apikey == "1234567890"

def tearDown(self) -> None:
if self.created_thread:
self.session.endpoints.configuration_settings.create_threat_grid_api_key(ThreatGridApi())
self.session.endpoints.configuration_settings.edit_threat_grid_api_key(ThreatGridApi())
return super().tearDown()
32 changes: 20 additions & 12 deletions catalystwan/models/settings.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
# Copyright 2024 Cisco Systems, Inc. and its affiliates
from typing import Any, Dict, List, Literal, Optional
from typing import List, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, SerializerFunctionWrapHandler, model_serializer
from pydantic import BaseModel, ConfigDict, Field

Region = Literal["nam", "eur"]


class ThreadGridApiEntires(BaseModel):
class ThreatGridApiEntires(BaseModel):
model_config = ConfigDict(extra="forbid")
region: Region
apikey: Optional[str] = Field(default="")


class ThreatGridApi(BaseModel):
class ThreatGridApiData(BaseModel):
model_config = ConfigDict(extra="forbid")
entries: List[ThreadGridApiEntires] = Field(
entries: List[ThreatGridApiEntires] = Field(
default_factory=lambda: [
ThreadGridApiEntires(region="nam"),
ThreadGridApiEntires(region="eur"),
]
ThreatGridApiEntires(region="nam"),
ThreatGridApiEntires(region="eur"),
],
)


class ThreatGridApi(BaseModel):
model_config = ConfigDict(extra="forbid")
data: List[ThreatGridApiData] = Field(default_factory=lambda: [ThreatGridApiData()])

def set_region_api_key(self, region: Region, apikey: str) -> None:
for entry in self.entries:
for entry in self.data[0].entries:
if entry.region == region:
entry.apikey = apikey
return
raise ValueError(f"Region {region} not found in ThreatGridApi")

@model_serializer(mode="wrap")
def envelope_data(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> Dict[str, Any]:
return {"data": [handler(self)]}
def get_region_api_key(self, region: Region) -> str:
for entry in self.data[0].entries:
if entry.region == region:
apikey = entry.apikey
return "" if apikey is None else apikey
raise ValueError(f"Region {region} not found in ThreatGridApi")
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,5 @@ def test_threat_grid_api_conversion(self):
convert(policy, context=self.context)
threat = self.context.threat_grid_api
# Assert
assert len(threat.entries) == 2
assert threat.entries[0].region == "nam"
assert threat.entries[0].apikey == "456"
assert threat.entries[1].region == "eur"
assert threat.entries[1].apikey == "123"
assert threat.get_region_api_key("eur") == "123"
assert threat.get_region_api_key("nam") == "456"