Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
P-T-I committed Oct 31, 2023
1 parent 5b93c1f commit b44f9d8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CveXplore/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.11
0.3.12
39 changes: 26 additions & 13 deletions CveXplore/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ def get_single_store_entries(
)
)

if entry_type == "cves":
if "id" in dict_filter:
if isinstance(dict_filter["id"], str):
dict_filter["id"] = self._validate_cve_id(dict_filter["id"])
elif isinstance(dict_filter["id"], dict):
if "$in" in dict_filter["id"]:
dict_filter["id"]["$in"] = [
self._validate_cve_id(x) for x in dict_filter["id"]["$in"]
]

results = (
getattr(self.datasource, "store_{}".format(entry_type))
.find(dict_filter)
Expand Down Expand Up @@ -207,12 +217,12 @@ def get_multi_store_entries(
functools.partial(self.get_single_store_entries, limit=limit), *queries
)

joined_list = []

for result_list in results:
joined_list += result_list
the_results = [
result_list for result_list in results if result_list is not None
]

the_results = list(joined_list)
# flatten results
the_results = [item for row in the_results for item in row]

if len(the_results) != 0:
return the_results
Expand Down Expand Up @@ -248,13 +258,7 @@ def cves_for_cpe(self, cpe_string: str) -> List[CveXploreObject] | None:

return cves

def cve_by_id(self, cve_id: str) -> CveXploreObject | None:
"""
Method to retrieve a single CVE from the database by its CVE ID number.
The number format should be either CVE-2000-0001, cve-2000-0001 or 2000-0001.
"""

# first try to match full cve number format
def _validate_cve_id(self, cve_id: str):
reg_match = re.compile(r"[cC][vV][eE]-\d{4}-\d{4,10}")
if reg_match.match(cve_id) is not None:
cve_id = cve_id.upper()
Expand All @@ -264,10 +268,19 @@ def cve_by_id(self, cve_id: str) -> CveXploreObject | None:
cve_id = f"CVE-{cve_id}"
else:
raise CveNumberValidationError(
"Could not validate the CVE number. The number format should be either "
f"Could not validate the CVE number: {cve_id}. The number format should be either "
"CVE-2000-0001, cve-2000-0001 or 2000-0001."
)

return cve_id

def cve_by_id(self, cve_id: str) -> CveXploreObject | None:
"""
Method to retrieve a single CVE from the database by its CVE ID number.
The number format should be either CVE-2000-0001, cve-2000-0001 or 2000-0001.
"""
cve_id = self._validate_cve_id(cve_id=cve_id)

return self.get_single_store_entry("cves", {"id": cve_id})

def cves_by_id(self, *cve_ids: str) -> Union[Iterable[CveXploreObject], Iterable]:
Expand Down

0 comments on commit b44f9d8

Please sign in to comment.