diff --git a/CveXplore/VERSION b/CveXplore/VERSION index 99a89b94..a487dd1d 100644 --- a/CveXplore/VERSION +++ b/CveXplore/VERSION @@ -1 +1 @@ -0.3.11 \ No newline at end of file +0.3.12 \ No newline at end of file diff --git a/CveXplore/main.py b/CveXplore/main.py index 1b0b529b..c3dd2356 100644 --- a/CveXplore/main.py +++ b/CveXplore/main.py @@ -171,6 +171,16 @@ def get_single_store_entries( ) ) + if entry_type == "cves": + if "id" in dict_filter: + if isinstance(dict_filter["id"], str): + dict_filter["id"] = self._validate_cve_id(dict_filter["id"]) + elif isinstance(dict_filter["id"], dict): + if "$in" in dict_filter["id"]: + dict_filter["id"]["$in"] = [ + self._validate_cve_id(x) for x in dict_filter["id"]["$in"] + ] + results = ( getattr(self.datasource, "store_{}".format(entry_type)) .find(dict_filter) @@ -207,12 +217,12 @@ def get_multi_store_entries( functools.partial(self.get_single_store_entries, limit=limit), *queries ) - joined_list = [] - - for result_list in results: - joined_list += result_list + the_results = [ + result_list for result_list in results if result_list is not None + ] - the_results = list(joined_list) + # flatten results + the_results = [item for row in the_results for item in row] if len(the_results) != 0: return the_results @@ -248,13 +258,7 @@ def cves_for_cpe(self, cpe_string: str) -> List[CveXploreObject] | None: return cves - def cve_by_id(self, cve_id: str) -> CveXploreObject | None: - """ - Method to retrieve a single CVE from the database by its CVE ID number. - The number format should be either CVE-2000-0001, cve-2000-0001 or 2000-0001. - """ - - # first try to match full cve number format + def _validate_cve_id(self, cve_id: str): reg_match = re.compile(r"[cC][vV][eE]-\d{4}-\d{4,10}") if reg_match.match(cve_id) is not None: cve_id = cve_id.upper() @@ -264,10 +268,19 @@ def cve_by_id(self, cve_id: str) -> CveXploreObject | None: cve_id = f"CVE-{cve_id}" else: raise CveNumberValidationError( - "Could not validate the CVE number. The number format should be either " + f"Could not validate the CVE number: {cve_id}. The number format should be either " "CVE-2000-0001, cve-2000-0001 or 2000-0001." ) + return cve_id + + def cve_by_id(self, cve_id: str) -> CveXploreObject | None: + """ + Method to retrieve a single CVE from the database by its CVE ID number. + The number format should be either CVE-2000-0001, cve-2000-0001 or 2000-0001. + """ + cve_id = self._validate_cve_id(cve_id=cve_id) + return self.get_single_store_entry("cves", {"id": cve_id}) def cves_by_id(self, *cve_ids: str) -> Union[Iterable[CveXploreObject], Iterable]: