diff --git a/vulnerabilities/data_inference.py b/vulnerabilities/data_inference.py
index 6e54fcdd5..bace8dba2 100644
--- a/vulnerabilities/data_inference.py
+++ b/vulnerabilities/data_inference.py
@@ -30,9 +30,9 @@ class Inference:
     relationship is to be inserted into the database
     """
 
+    vulnerability_id: str
     confidence: int
     summary: Optional[str] = None
-    vulnerability_id: Optional[str] = None
     affected_packages: List[PackageURL] = dataclasses.field(default_factory=list)
     fixed_packages: List[PackageURL] = dataclasses.field(default_factory=list)
     references: List[Reference] = dataclasses.field(default_factory=list)
@@ -47,9 +47,12 @@ def __post_init__(self):
 
 class Improver:
     """
-    All improvers should inherit this class and implement inferences method to return
-    new inferences for packages or vulnerabilities
+    All improvers must inherit this class and implement the infer method to
+    return new inferences for packages or vulnerabilities
     """
 
-    def inferences(self) -> List[Inference]:
+    def infer(self) -> List[Inference]:
+        """
+        Implement this method to generate and return Inferences
+        """
         raise NotImplementedError
diff --git a/vulnerabilities/data_source.py b/vulnerabilities/data_source.py
index 52f59ef02..385870b30 100644
--- a/vulnerabilities/data_source.py
+++ b/vulnerabilities/data_source.py
@@ -81,7 +81,7 @@ class AffectedPackage:
     # the version specifier contains the version scheme as is: semver:>=1,3,4
     version_specifier: VersionSpecifier
 
-    def json(self):
+    def to_dict(self):
         # TODO: VersionSpecifier.__str__ is not working
         # https://github.com/nexB/univers/issues/7
         # Adjust following code when it is fixed
@@ -89,12 +89,11 @@ def json(self):
         ranges = ",".join(
             [f"{rng.operator}{rng.version.value}" for rng in self.version_specifier.ranges]
         )
-        return json.dumps({"package": self.package, "version_specifier": f"{scheme}:{ranges}"})
+        return {"package": self.package, "version_specifier": f"{scheme}:{ranges}"}
 
     @staticmethod
-    def from_json(affected_package_json):
-        obj = json.loads(affected_package_json)
-        affected_package = AffectedPackage(**obj)
+    def from_dict(affected_package_dict):
+        affected_package = AffectedPackage(**affected_package_dict)
         package = PackageURL(*affected_package.package)
         version_specifier = VersionSpecifier.from_version_spec_string(
             affected_package.version_specifier
@@ -117,33 +116,29 @@ class AdvisoryData:
     affected_packages: List[AffectedPackage] = dataclasses.field(default_factory=list)
     fixed_packages: List[AffectedPackage] = dataclasses.field(default_factory=list)
     references: List[Reference] = dataclasses.field(default_factory=list)
-    date_published: Optional[str] = None
+    date_published: Optional[datetime.date] = None
 
     def normalized(self):
         ...
 
-    def serializable(self, o):
-        if isinstance(o, AffectedPackage):
-            return o.json()
-        if isinstance(o, Reference):
-            return vars(o)
-        if isinstance(o, datetime):
-            return o.isoformat()
-
-        return json.JSONEncoder.default(self, o)
-
-    def json(self):
-        return json.dumps(vars(self), default=self.serializable)
+    def to_dict(self):
+        return {
+            "summary": self.summary,
+            "vulnerability_id": self.vulnerability_id,
+            "affected_packages": [pkg.to_dict() for pkg in self.affected_packages],
+            "fixed_packages": [pkg.to_dict() for pkg in self.fixed_packages],
+            "references": [vars(ref) for ref in self.references],
+            "date_published": self.date_published.isoformat() if self.date_published else None,
+        }
 
     @staticmethod
-    def from_json(advisory_data_json: str):
-        obj = json.loads(advisory_data_json)
-        advisory_data = AdvisoryData(**obj)
+    def from_dict(advisory_data_dict: dict):
+        advisory_data = AdvisoryData(**advisory_data_dict)
         advisory_data.affected_packages = [
-            AffectedPackage.from_json(p) for p in advisory_data.affected_packages
+            AffectedPackage.from_dict(p) for p in advisory_data.affected_packages
         ]
         advisory_data.fixed_packages = [
-            AffectedPackage.from_json(p) for p in advisory_data.fixed_packages
+            AffectedPackage.from_dict(p) for p in advisory_data.fixed_packages
         ]
         advisory_data.references = [Reference(**ref) for ref in advisory_data.references]
         return advisory_data
diff --git a/vulnerabilities/import_runner.py b/vulnerabilities/import_runner.py
index 50e93a55b..d8a9cfd80 100644
--- a/vulnerabilities/import_runner.py
+++ b/vulnerabilities/import_runner.py
@@ -124,7 +124,7 @@ def process_advisories(source: str, advisory_data: Set[AdvisoryData]) -> None:
                 date_published=data.date_published,
                 date_collected=datetime.datetime.now(tz=datetime.timezone.utc),
                 source=source,
-                data=data.json(),
+                data=json.dumps(data.to_dict()),
             )
         )
 
diff --git a/vulnerabilities/improve_runner.py b/vulnerabilities/improve_runner.py
index 6f76d25f9..717faa18b 100644
--- a/vulnerabilities/improve_runner.py
+++ b/vulnerabilities/improve_runner.py
@@ -38,7 +38,9 @@ def run(self) -> None:
 def process_inferences(source: str, inferences: Set[Inference]):
     bulk_create_vuln_pkg_refs = set()
     for inference in inferences:
-        vuln, vuln_created = _get_or_create_vulnerability(inference.vulnerability_id, inference.summary)
+        vuln, vuln_created = _get_or_create_vulnerability(
+            inference.vulnerability_id, inference.summary
+        )
         for vuln_ref in inference.references:
             ref, _ = models.VulnerabilityReference.objects.get_or_create(
                 vulnerability=vuln, reference_id=vuln_ref.reference_id, url=vuln_ref.url
@@ -77,9 +79,7 @@ def process_inferences(source: str, inferences: Set[Inference]):
     )
 
 
-def _get_or_create_vulnerability(
-    vulnerability_id, summary
-) -> Tuple[models.Vulnerability, bool]:
+def _get_or_create_vulnerability(vulnerability_id, summary) -> Tuple[models.Vulnerability, bool]:
 
     vuln, created = models.Vulnerability.objects.get_or_create(
         vulnerability_id=vulnerability_id
diff --git a/vulnerabilities/improvers/default.py b/vulnerabilities/improvers/default.py
index 52d6deaa6..4eb2c41e0 100644
--- a/vulnerabilities/improvers/default.py
+++ b/vulnerabilities/improvers/default.py
@@ -1,3 +1,4 @@
+import json
 from typing import List
 from itertools import chain
 
@@ -11,6 +12,7 @@
 from vulnerabilities.data_inference import MAX_CONFIDENCE
 from vulnerabilities.models import Advisory
 
+
 class DefaultImprover(Improver):
     def inferences(self) -> List[Inference]:
         advisories = Advisory.objects.all()
@@ -18,7 +20,7 @@ def inferences(self) -> List[Inference]:
 
         inferences = []
         for advisory in advisories:
-            advisory_data = AdvisoryData.from_json(advisory.data)
+            advisory_data = AdvisoryData.from_dict(json.loads(advisory.data))
 
             affected_packages = chain.from_iterable(
                 [exact_purls(pkg) for pkg in advisory_data.affected_packages]
@@ -55,7 +57,7 @@ def exact_purls(pkg: AffectedPackage) -> List[PackageURL]:
     purls = []
     for rng in vs.ranges:
         if "=" in rng.operator and not "!" in rng.operator:
-            purl = pkg.package._replace(version = rng.version.value)
+            purl = pkg.package._replace(version=rng.version.value)
             purls.append(purl)
 
     return purls
diff --git a/vulnerabilities/improvers/nginx.py b/vulnerabilities/improvers/nginx.py
deleted file mode 100644
index e97e30953..000000000
--- a/vulnerabilities/improvers/nginx.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from packageurl import PackageURL
-
-from vulnerabilities.data_inference import Improver
-from vulnerabilities.data_inference import Advisory
-from vulnerabilities.data_inference import Inference
-from vulnerabilities.helpers import nearest_patched_package
-from vulnerabilities.models import Vulnerability
-from vulnerabilities.models import Package
-
-class NginxTimeTravel(Improver):
-    def updated_inferences(self):
-        inferences = []
-
-        vulnerabilities = set(Vulnerability.objects.filter(vulnerable_packages__name="nginx"))
-        vulnerabilities.union(Vulnerability.objects.filter(patched_packages__name="nginx"))
-
-        for vulnerability in vulnerabilities:
-            affected_packages = map(package_url, Package.objects.filter(vulnerable_package__package__name="nginx", vulnerabilities = vulnerability))
-            fixed_packages = map(package_url, Package.objects.filter(patched_package__package__name="nginx", vulnerabilities = vulnerability))
-
-            time_traveller = nearest_patched_package(affected_packages, fixed_packages)
-            affected_packages = [ affected_package.vulnerable_package for affected_package in time_traveller]
-            fixed_packages = [ affected_package.patched_package for affected_package in time_traveller if affected_package.patched_package is not None]
-
-            inference = Inference(advisory = Advisory(
-                vulnerability_id=vulnerability.vulnerability_id,
-                summary=vulnerability.summary,
-                affected_package_urls=fixed_packages,
-                ), source="time travel", confidence=30)
-            inferences.append(inference)
-
-        return inferences
-
-
-def package_url(package):
-    return PackageURL(
-        type=package.type,
-        namespace=package.namespace,
-        name=package.name,
-        version=package.version,
-        subpath=package.subpath,
-        qualifiers=package.qualifiers
-    )
-
diff --git a/vulnerabilities/models.py b/vulnerabilities/models.py
index e7bd4ac4f..8a6c4f06c 100644
--- a/vulnerabilities/models.py
+++ b/vulnerabilities/models.py
@@ -320,10 +320,19 @@ class Meta:
 
 
 class Advisory(models.Model):
-    date_published = models.DateField()
-    date_collected = models.DateField()
-    source = models.CharField(max_length=100)
-    improved_on = models.DateTimeField(null=True)
-    improved_times = models.IntegerField(default=0)
-    # data would contain a data_source.AdvisoryData
-    data = models.JSONField()
+    """
+    An advisory directly obtained from upstream without any modifications.
+    """
+
+    date_published = models.DateField(help_text="Date of publication of the advisory")
+    date_collected = models.DateField(help_text="Date on which the advisory was collected")
+    source = models.CharField(
+        max_length=100,
+        help_text="Fully qualified name of the importer prefixed with the module name importing the advisory. Eg: vulnerabilities.importers.nginx.NginxDataSource",
+    )
+    date_improved = models.DateTimeField(
+        null=True, help_text="Latest date on which the advisory was improved by an improver"
+    )
+    data = models.JSONField(
+        help_text="Contents of data_source.AdvisoryData serialized as a JSON object"
+    )
diff --git a/vulnerabilities/views.py b/vulnerabilities/views.py
index 6bdc6c285..d018ef944 100644
--- a/vulnerabilities/views.py
+++ b/vulnerabilities/views.py
@@ -74,8 +74,14 @@ def request_to_queryset(request):
             models.Package.objects.all()
             .filter(name__icontains=package_name, type__icontains=package_type)
             .annotate(
-                vulnerability_count=Count("vulnerabilities", filter=Q(vulnerabilities__packagerelatedvulnerability__fix=False)),
-                patched_vulnerability_count=Count("vulnerabilities",filter=Q(vulnerabilities__packagerelatedvulnerability__fix=True)),
+                vulnerability_count=Count(
+                    "vulnerabilities",
+                    filter=Q(vulnerabilities__packagerelatedvulnerability__fix=False),
+                ),
+                patched_vulnerability_count=Count(
+                    "vulnerabilities",
+                    filter=Q(vulnerabilities__packagerelatedvulnerability__fix=True),
+                ),
             )
             .prefetch_related()
         )
@@ -102,8 +108,12 @@ def request_to_vulnerabilities(request):
         vuln_id = request.GET["vuln_id"]
         return list(
             models.Vulnerability.objects.filter(vulnerability_id__icontains=vuln_id).annotate(
-                vulnerable_package_count=Count("packages", filter=Q(packagerelatedvulnerability__fix=False)),
-                patched_package_count=Count("packages", filter=Q(packagerelatedvulnerability__fix=True)),
+                vulnerable_package_count=Count(
+                    "packages", filter=Q(packagerelatedvulnerability__fix=False)
+                ),
+                patched_package_count=Count(
+                    "packages", filter=Q(packagerelatedvulnerability__fix=True)
+                ),
             )
         )
 