From facf2a7b7253554af36715ae32016f596cb940b7 Mon Sep 17 00:00:00 2001 From: Hritik Vijay Date: Tue, 17 Jan 2023 21:29:04 +0530 Subject: [PATCH] Add tests for improve_runner Signed-off-by: Hritik Vijay --- vulnerabilities/improve_runner.py | 22 ++- vulnerabilities/tests/test_improve_runner.py | 167 +++++++++++++++++++ vulnerabilities/tests/test_improver.py | 2 + 3 files changed, 183 insertions(+), 8 deletions(-) diff --git a/vulnerabilities/improve_runner.py b/vulnerabilities/improve_runner.py index e678d6562..a2325bab6 100644 --- a/vulnerabilities/improve_runner.py +++ b/vulnerabilities/improve_runner.py @@ -57,6 +57,7 @@ def run(self) -> None: @transaction.atomic def process_inferences(inferences: List[Inference], advisory: Advisory, improver_name: str): """ + Return number of inferences processed. An atomic transaction that updates both the Advisory (e.g. date_improved) and processes the given inferences to create or update corresponding database fields. @@ -65,10 +66,11 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver erroneous. Also, the atomic transaction for every advisory and its inferences makes sure that date_improved of advisory is consistent. """ + inferences_processed_count = 0 if not inferences: - logger.warn(f"Nothing to improve. Source: {improver_name} Advisory id: {advisory.id}") - return + logger.warning(f"Nothing to improve. Source: {improver_name} Advisory id: {advisory.id}") + return inferences_processed_count logger.info(f"Improving advisory id: {advisory.id}") @@ -80,7 +82,7 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver ) if not vulnerability: - logger.warn(f"Unable to get vulnerability for inference: {inference!r}") + logger.warning(f"Unable to get vulnerability for inference: {inference!r}") continue for ref in inference.references: @@ -143,8 +145,12 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver cwe_obj, created = Weakness.objects.get_or_create(cwe_id=cwe_id) cwe_obj.vulnerabilities.add(vulnerability) cwe_obj.save() + + inferences_processed_count += 1 + advisory.date_improved = datetime.now(timezone.utc) advisory.save() + return inferences_processed_count def create_valid_vulnerability_reference(url, reference_id=None): @@ -168,7 +174,7 @@ def create_valid_vulnerability_reference(url, reference_id=None): return reference -def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summary): +def get_or_create_vulnerability_and_aliases(alias_names, vulnerability_id=None, summary=None): """ Get or create vulnerabilitiy and aliases such that all existing and new aliases point to the same vulnerability @@ -188,7 +194,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa # TODO: It is possible that all those vulnerabilities are actually # the same at data level, figure out a way to merge them if len(existing_vulns) > 1: - logger.warn( + logger.warning( f"Given aliases {alias_names} already exist and do not point " f"to a single vulnerability. Cannot improve. Skipped." ) @@ -201,7 +207,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa and vulnerability_id and existing_alias_vuln.vulnerability_id != vulnerability_id ): - logger.warn( + logger.warning( f"Given aliases {alias_names!r} already exist and point to existing" f"vulnerability {existing_alias_vuln}. Unable to create Vulnerability " f"with vulnerability_id {vulnerability_id}. Skipped" @@ -214,7 +220,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa try: vulnerability = Vulnerability.objects.get(vulnerability_id=vulnerability_id) except Vulnerability.DoesNotExist: - logger.warn( + logger.warning( f"Given vulnerability_id: {vulnerability_id} does not exist in the database" ) return @@ -223,7 +229,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa vulnerability.save() if summary and summary != vulnerability.summary: - logger.warn( + logger.warning( f"Inconsistent summary for {vulnerability!r}. " f"Existing: {vulnerability.summary}, provided: {summary}" ) diff --git a/vulnerabilities/tests/test_improve_runner.py b/vulnerabilities/tests/test_improve_runner.py index 1d425b211..9354b4688 100644 --- a/vulnerabilities/tests/test_improve_runner.py +++ b/vulnerabilities/tests/test_improve_runner.py @@ -7,9 +7,27 @@ # See https://aboutcode.org for more information about nexB OSS projects. # +from collections import Counter + import pytest +from django.utils import timezone +from packageurl import PackageURL +from pytest_django.asserts import assertQuerysetEqual +from vulnerabilities.importer import Reference from vulnerabilities.improve_runner import create_valid_vulnerability_reference +from vulnerabilities.improve_runner import get_or_create_vulnerability_and_aliases +from vulnerabilities.improve_runner import process_inferences +from vulnerabilities.improver import Improver +from vulnerabilities.improver import Inference +from vulnerabilities.models import Advisory +from vulnerabilities.models import Alias +from vulnerabilities.models import Package +from vulnerabilities.models import PackageRelatedVulnerability +from vulnerabilities.models import Vulnerability +from vulnerabilities.models import VulnerabilityReference +from vulnerabilities.models import VulnerabilityRelatedReference +from vulnerabilities.models import VulnerabilitySeverity @pytest.mark.django_db @@ -37,3 +55,152 @@ def test_create_valid_vulnerability_reference_accepts_long_references(): url="https://foo.bar", ) assert result + + +@pytest.mark.django_db +def test_get_or_create_vulnerability_and_aliases_with_new_vulnerability_and_new_aliases(): + alias_names = ["TAYLOR-1337", "SWIFT-1337"] + summary = "Melodious vulnerability" + vulnerability = get_or_create_vulnerability_and_aliases( + alias_names=alias_names, summary=summary + ) + assert vulnerability + alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True) + assert Counter(alias_names_in_db) == Counter(alias_names) + + +@pytest.mark.django_db +def test_get_or_create_vulnerability_and_aliases_with_different_vulnerability_and_existing_aliases(): + existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing") + existing_vulnerability.save() + existing_aliases = [] + existing_alias_names = ["ALIAS-1", "ALIAS-2"] + for alias in existing_alias_names: + existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability)) + Alias.objects.bulk_create(existing_aliases) + + different_vulnerability = Vulnerability(vulnerability_id="VCID-New") + different_vulnerability.save() + assert not get_or_create_vulnerability_and_aliases( + alias_names=existing_alias_names, vulnerability_id=different_vulnerability.vulnerability_id + ) + + +@pytest.mark.django_db +def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_new_aliases(): + existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing") + existing_vulnerability.save() + + existing_alias_names = ["ALIAS-1", "ALIAS-2"] + vulnerability = get_or_create_vulnerability_and_aliases( + vulnerability_id="VCID-Existing", alias_names=existing_alias_names + ) + assert existing_vulnerability == vulnerability + + alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True) + assert Counter(alias_names_in_db) == Counter(existing_alias_names) + + +@pytest.mark.django_db +def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_existing_aliases(): + existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing") + existing_vulnerability.save() + + existing_aliases = [] + existing_alias_names = ["ALIAS-1", "ALIAS-2"] + for alias in existing_alias_names: + existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability)) + Alias.objects.bulk_create(existing_aliases) + + vulnerability = get_or_create_vulnerability_and_aliases( + vulnerability_id="VCID-Existing", alias_names=existing_alias_names + ) + assert existing_vulnerability == vulnerability + + alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True) + assert Counter(alias_names_in_db) == Counter(existing_alias_names) + + +@pytest.mark.django_db +def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_existing_and_new_aliases(): + existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing") + existing_vulnerability.save() + + existing_aliases = [] + existing_alias_names = ["ALIAS-1", "ALIAS-2"] + for alias in existing_alias_names: + existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability)) + Alias.objects.bulk_create(existing_aliases) + + new_alias_names = ["ALIAS-3", "ALIAS-4"] + alias_names = existing_alias_names + new_alias_names + vulnerability = get_or_create_vulnerability_and_aliases( + vulnerability_id="VCID-Existing", alias_names=alias_names + ) + assert existing_vulnerability == vulnerability + + alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True) + assert Counter(alias_names_in_db) == Counter(alias_names) + + +DUMMY_ADVISORY = Advisory(summary="dummy", created_by="tests", date_collected=timezone.now()) + + +@pytest.mark.django_db +def test_process_inferences_with_no_inference(): + assert not process_inferences( + inferences=[], advisory=DUMMY_ADVISORY, improver_name="test_improver" + ) + + +@pytest.mark.django_db +def test_process_inferences_with_unknown_but_specified_vulnerability(): + inference = Inference(vulnerability_id="VCID-Does-Not-Exist-In-DB", aliases=["MATRIX-Neo"]) + assert not process_inferences( + inferences=[inference], advisory=DUMMY_ADVISORY, improver_name="test_improver" + ) + + +INFERENCES = [ + Inference( + aliases=["CVE-1", "CVE-2"], + summary="One upon a time, in a package far far away", + affected_purls=[ + PackageURL(type="character", namespace="star-wars", name="anakin", version="1") + ], + fixed_purl=PackageURL( + type="character", namespace="star-wars", name="darth-vader", version="1" + ), + references=[Reference(reference_id="imperial-vessel-1", url="https://m47r1x.github.io")], + ) +] + + +def get_objects_in_all_tables_used_by_process_inferences(): + return { + "vulnerabilities": list(Vulnerability.objects.all()), + "aliases": list(Alias.objects.all()), + "references": list(VulnerabilityReference.objects.all()), + "advisories": list(Advisory.objects.all()), + "packages": list(Package.objects.all()), + "references": list(VulnerabilityReference.objects.all()), + "severity": list(VulnerabilitySeverity.objects.all()), + } + + +@pytest.mark.django_db +def test_process_inferences_idempotency(): + process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver") + all_objects = get_objects_in_all_tables_used_by_process_inferences() + process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver") + process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver") + assert all_objects == get_objects_in_all_tables_used_by_process_inferences() + + +@pytest.mark.django_db +def test_process_inference_idempotency_with_different_improver_names(): + process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_one") + all_objects = get_objects_in_all_tables_used_by_process_inferences() + process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_two") + process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_three") + assert all_objects == get_objects_in_all_tables_used_by_process_inferences() diff --git a/vulnerabilities/tests/test_improver.py b/vulnerabilities/tests/test_improver.py index ed524a991..792fefc63 100644 --- a/vulnerabilities/tests/test_improver.py +++ b/vulnerabilities/tests/test_improver.py @@ -31,6 +31,7 @@ def test_inference_to_dict_method_with_vulnerability_id(): "affected_purls": [], "fixed_purl": None, "references": [], + "weaknesses": [], } assert expected == inference.to_dict() @@ -46,6 +47,7 @@ def test_inference_to_dict_method_with_purls(): "affected_purls": [purl.to_dict()], "fixed_purl": purl.to_dict(), "references": [], + "weaknesses": [], } assert expected == inference.to_dict()