Skip to content

Commit

Permalink
Merge pull request #1649 from aboutcode-org/add-bulk-search-v2
Browse files Browse the repository at this point in the history
Add bulk search in v2
  • Loading branch information
TG1999 authored Nov 14, 2024
2 parents 8bca5cc + b071a51 commit ae605f4
Show file tree
Hide file tree
Showing 3 changed files with 559 additions and 8 deletions.
300 changes: 297 additions & 3 deletions vulnerabilities/api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,20 @@
#


from django_filters import rest_framework as filters
from drf_spectacular.utils import OpenApiParameter
from drf_spectacular.utils import extend_schema
from drf_spectacular.utils import extend_schema_view
from packageurl import PackageURL
from rest_framework import serializers
from rest_framework import status
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.reverse import reverse

from vulnerabilities.api import PackageFilterSet
from vulnerabilities.api import VulnerabilitySeveritySerializer
from vulnerabilities.models import Package
from vulnerabilities.models import Vulnerability
from vulnerabilities.models import VulnerabilityReference
Expand Down Expand Up @@ -90,6 +99,26 @@ def get_url(self, obj):
)


@extend_schema_view(
list=extend_schema(
parameters=[
OpenApiParameter(
name="vulnerability_id",
description="Filter by one or more vulnerability IDs",
required=False,
type={"type": "array", "items": {"type": "string"}},
location=OpenApiParameter.QUERY,
),
OpenApiParameter(
name="alias",
description="Filter by alias (CVE or other unique identifier)",
required=False,
type={"type": "array", "items": {"type": "string"}},
location=OpenApiParameter.QUERY,
),
]
)
)
class VulnerabilityV2ViewSet(viewsets.ReadOnlyModelViewSet):
queryset = Vulnerability.objects.all()
serializer_class = VulnerabilityV2Serializer
Expand Down Expand Up @@ -142,6 +171,7 @@ def list(self, request, *args, **kwargs):

class PackageV2Serializer(serializers.ModelSerializer):
purl = serializers.CharField(source="package_url")
risk_score = serializers.FloatField(read_only=True)
affected_by_vulnerabilities = serializers.SerializerMethodField()
fixing_vulnerabilities = serializers.SerializerMethodField()
next_non_vulnerable_version = serializers.CharField(read_only=True)
Expand All @@ -155,6 +185,7 @@ class Meta:
"fixing_vulnerabilities",
"next_non_vulnerable_version",
"latest_non_vulnerable_version",
"risk_score",
]

def get_affected_by_vulnerabilities(self, obj):
Expand All @@ -164,9 +195,39 @@ def get_fixing_vulnerabilities(self, obj):
return [vuln.vulnerability_id for vuln in obj.fixing_vulnerabilities.all()]


class PackageurlListSerializer(serializers.Serializer):
    """Request payload holding a non-empty list of PackageURL strings."""

    # ``allow_empty=False`` makes validation reject an empty ``purls`` list.
    purls = serializers.ListField(
        help_text="List of PackageURL strings in canonical form.",
        allow_empty=False,
        child=serializers.CharField(),
    )


class PackageBulkSearchRequestSerializer(PackageurlListSerializer):
    """Bulk-search request: the inherited ``purls`` list plus two optional flags."""

    # Both flags are optional and fall back to False when omitted.
    purl_only = serializers.BooleanField(default=False, required=False)
    plain_purl = serializers.BooleanField(default=False, required=False)


class LookupRequestSerializer(serializers.Serializer):
    """Request payload for a lookup by a single, required ``purl`` string."""

    purl = serializers.CharField(
        help_text="PackageURL strings in canonical form.",
        required=True,
    )


class PackageV2FilterSet(filters.FilterSet):
    """Character filters mapping public query parameter names to model lookups."""

    # Filter through the vulnerability relations by their ``vulnerability_id``.
    affected_by_vulnerability = filters.CharFilter(
        field_name="affected_by_vulnerabilities__vulnerability_id",
    )
    fixing_vulnerability = filters.CharFilter(
        field_name="fixing_vulnerabilities__vulnerability_id",
    )
    # ``purl`` is matched against the ``package_url`` field.
    purl = filters.CharFilter(
        field_name="package_url",
    )


class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet):
queryset = Package.objects.all()
serializer_class = PackageV2Serializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = PackageV2FilterSet

def get_queryset(self):
queryset = super().get_queryset()
Expand All @@ -188,15 +249,248 @@ def get_queryset(self):

def list(self, request, *args, **kwargs):
    """
    Return packages together with the vulnerabilities they reference.

    The response payload has two keys:

    - ``vulnerabilities``: mapping of ``vulnerability_id`` to the serialized
      vulnerability, covering every vulnerability that affects or is fixed
      by one of the returned packages.
    - ``packages``: the serialized packages.

    When pagination is active only the current page is processed and the
    payload is wrapped with ``get_paginated_response``; otherwise the whole
    queryset is processed and returned in a plain ``Response``.
    """
    queryset = self.get_queryset()

    # ``paginate_queryset`` returns None when pagination is not applied; in
    # that case fall back to the full queryset.
    page = self.paginate_queryset(queryset)
    packages = queryset if page is None else page

    # Collect only the vulnerabilities related to the packages being returned.
    vulnerabilities = set()
    for package in packages:
        vulnerabilities.update(package.affected_by_vulnerabilities.all())
        vulnerabilities.update(package.fixing_vulnerabilities.all())

    # Serialize the vulnerabilities with vulnerability_id as keys.
    vulnerability_data = {
        vuln.vulnerability_id: VulnerabilityV2Serializer(vuln).data for vuln in vulnerabilities
    }

    data = self.get_serializer(packages, many=True).data
    payload = {"vulnerabilities": vulnerability_data, "packages": data}

    if page is not None:
        # Use 'self.get_paginated_response' to include pagination data.
        return self.get_paginated_response(payload)
    return Response(payload)

@extend_schema(
    request=PackageurlListSerializer,
    responses={200: PackageV2Serializer(many=True)},
)
@action(
    detail=False,
    methods=["post"],
    serializer_class=PackageurlListSerializer,
    filter_backends=[],
    pagination_class=None,
)
def bulk_lookup(self, request):
    """
    Return the packages found for the requested PackageURLs.

    The request body must carry a non-empty ``purls`` list; otherwise a
    400 response with the validation errors is returned. On success the
    payload contains ``vulnerabilities`` (keyed by ``vulnerability_id``)
    and ``packages``.
    """
    serializer = self.serializer_class(data=request.data)
    if serializer.is_valid():
        requested_purls = serializer.validated_data.get("purls")

        # Packages selected by the ``for_purls`` queryset helper.
        matched = Package.objects.for_purls(requested_purls).with_is_vulnerable()

        # Every vulnerability that affects, or is fixed by, a matched package.
        related = {
            vuln
            for pkg in matched
            for relation in (pkg.affected_by_vulnerabilities, pkg.fixing_vulnerabilities)
            for vuln in relation.all()
        }
        vulnerability_data = {}
        for vuln in related:
            vulnerability_data[vuln.vulnerability_id] = VulnerabilityV2Serializer(vuln).data

        package_serializer = PackageV2Serializer(
            matched,
            many=True,
            context={"request": request},
        )
        body = {
            "vulnerabilities": vulnerability_data,
            "packages": package_serializer.data,
        }
        return Response(body)

    failure = {
        "error": serializer.errors,
        "message": "A non-empty 'purls' list of PURLs is required.",
    }
    return Response(data=failure, status=status.HTTP_400_BAD_REQUEST)

@extend_schema(
    request=PackageBulkSearchRequestSerializer,
    responses={200: PackageV2Serializer(many=True)},
)
@action(
    detail=False,
    methods=["post"],
    serializer_class=PackageBulkSearchRequestSerializer,
    filter_backends=[],
    pagination_class=None,
)
def bulk_search(self, request):
    """
    Lookup for vulnerable packages using many Package URLs at once.

    Request body fields:

    - ``purls``: non-empty list of PackageURL strings (required).
    - ``purl_only``: when true, respond with a flat list of PURL strings
      taken from ``query.vulnerable()`` instead of full package data.
    - ``plain_purl``: when true, each PURL is reduced to its type,
      namespace, name and version and matched against
      ``plain_package_url``, keeping one row per plain PURL.

    Returns a 400 response when the body fails validation or, in
    ``plain_purl`` mode, when one of the PURLs cannot be parsed. Otherwise
    returns either the list of PURL strings (``purl_only``) or a mapping
    with ``vulnerabilities`` (keyed by ``vulnerability_id``) and
    ``packages``.
    """
    serializer = self.serializer_class(data=request.data)
    if not serializer.is_valid():
        return Response(
            status=status.HTTP_400_BAD_REQUEST,
            data={
                "error": serializer.errors,
                "message": "A non-empty 'purls' list of PURLs is required.",
            },
        )
    validated_data = serializer.validated_data
    purls = validated_data.get("purls")
    purl_only = validated_data.get("purl_only", False)
    plain_purl = validated_data.get("plain_purl", False)

    if plain_purl:
        # ``purls`` is client input: a malformed PURL must produce a 400,
        # not an unhandled exception.
        try:
            purl_objects = [PackageURL.from_string(purl) for purl in purls]
        except ValueError as exc:
            return Response(
                status=status.HTTP_400_BAD_REQUEST,
                data={
                    "error": {"purls": [str(exc)]},
                    "message": "Every entry of 'purls' must be a valid PackageURL.",
                },
            )

        # Rebuild each PURL from type, namespace, name and version only.
        plain_purls = [
            str(
                PackageURL(
                    type=purl.type,
                    namespace=purl.namespace,
                    name=purl.name,
                    version=purl.version,
                )
            )
            for purl in purl_objects
        ]

        purl_field = "plain_package_url"
        # Using order by and distinct because there will be
        # many fully qualified purl for a single plain purl
        query = (
            Package.objects.filter(plain_package_url__in=plain_purls)
            .order_by("plain_package_url")
            .distinct("plain_package_url")
            .with_is_vulnerable()
        )
    else:
        purl_field = "package_url"
        query = Package.objects.filter(package_url__in=purls).distinct().with_is_vulnerable()

    if purl_only:
        # Only the PURL strings are needed here, so skip collecting and
        # serializing vulnerabilities altogether.
        vulnerable_packages = query.vulnerable().only(purl_field)
        return Response(
            data=[str(getattr(package, purl_field)) for package in vulnerable_packages]
        )

    packages = query

    # Collect vulnerabilities associated with these packages
    vulnerabilities = set()
    for package in packages:
        vulnerabilities.update(package.affected_by_vulnerabilities.all())
        vulnerabilities.update(package.fixing_vulnerabilities.all())

    vulnerability_data = {
        vuln.vulnerability_id: VulnerabilityV2Serializer(vuln).data for vuln in vulnerabilities
    }

    package_data = PackageV2Serializer(packages, many=True, context={"request": request}).data
    return Response(
        {
            "vulnerabilities": vulnerability_data,
            "packages": package_data,
        }
    )

@action(detail=False, methods=["get"])
def all(self, request):
    """
    Return the ``package_url`` values of vulnerable packages.

    The values come from ``Package.objects.vulnerable()``, ordered by
    ``package_url`` with duplicates removed.
    """
    vulnerable_packages = Package.objects.vulnerable()
    ordered_packages = vulnerable_packages.only("package_url").order_by("package_url")
    purl_values = ordered_packages.distinct().values_list("package_url", flat=True)
    return Response(purl_values)

@extend_schema(
    request=LookupRequestSerializer,
    responses={200: PackageV2Serializer(many=True)},
)
@action(
    detail=False,
    methods=["post"],
    serializer_class=LookupRequestSerializer,
    filter_backends=[],
    pagination_class=None,
)
def lookup(self, request):
    """
    Return the serialized packages found for a single PackageURL.

    The request body must contain a ``purl`` string; otherwise a 400
    response carrying the validation errors is returned.
    """
    serializer = self.serializer_class(data=request.data)
    if serializer.is_valid():
        requested_purl = serializer.validated_data.get("purl")
        # Start from this viewset's queryset, then narrow it with ``for_purls``.
        matches = self.get_queryset().for_purls([requested_purl]).with_is_vulnerable()
        package_serializer = PackageV2Serializer(
            matches,
            many=True,
            context={"request": request},
        )
        return Response(package_serializer.data)

    failure = {
        "error": serializer.errors,
        "message": "A 'purl' is required.",
    }
    return Response(data=failure, status=status.HTTP_400_BAD_REQUEST)
Loading

0 comments on commit ae605f4

Please sign in to comment.