diff --git a/adserver/analyzer/management/commands/runmodel.py b/adserver/analyzer/management/commands/runmodel.py index baa7a387..66354521 100644 --- a/adserver/analyzer/management/commands/runmodel.py +++ b/adserver/analyzer/management/commands/runmodel.py @@ -45,7 +45,7 @@ def handle_url(self, url): self.stdout.write( _("Keywords from '%s': %s") % (backend.__name__, analyzed_keywords) ) - analyzed_embedding = backend.embedding(response) + analyzed_embedding = backend_instance.embedding(response) self.stdout.write( _("Embeddings from '%s': %s") % (backend.__name__, analyzed_embedding) ) diff --git a/adserver/analyzer/views.py b/adserver/analyzer/views.py index 46f5cc57..28a70b3a 100644 --- a/adserver/analyzer/views.py +++ b/adserver/analyzer/views.py @@ -1 +1,53 @@ -"""Intentionally blank.""" +from django.conf import settings +from pgvector.django import CosineDistance +from rest_framework import status +from rest_framework.response import Response +from rest_framework.views import APIView + +from adserver.analyzer.backends.st import SentenceTransformerAnalyzerBackend +from adserver.analyzer.models import AnalyzedUrl + + +if "adserver.analyzer" in settings.INSTALLED_APPS: + + class EmbeddingViewSet(APIView): + """ + Returns a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL. + + .. 
http:get:: /api/v1/similar/ + + Return a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL + + :>json int count: The number of similar URLs returned + :>json array results: An array of similar URLs and scores + """ + + def get(self, request): + """Return a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL.""" + url = request.query_params.get("url") + + if not url: + return Response( + {"error": "url is required"}, status=status.HTTP_400_BAD_REQUEST + ) + + backend_instance = SentenceTransformerAnalyzerBackend(url) + response = backend_instance.fetch() + processed_text = backend_instance.get_content(response) + analyzed_embedding = backend_instance.embedding(response) + + urls = ( + AnalyzedUrl.objects.exclude(embedding=None) + .annotate(distance=CosineDistance("embedding", analyzed_embedding)) + .order_by("distance")[:10] + ) + + return Response( + { + "count": len(urls), + "text": processed_text[:500], + "results": [[url.url, url.distance] for url in urls], + } + ) diff --git a/adserver/api/urls.py b/adserver/api/urls.py index e04804fc..27a63e30 100644 --- a/adserver/api/urls.py +++ b/adserver/api/urls.py @@ -1,4 +1,5 @@ """API Urls for the ad server.""" +from django.conf import settings from django.urls import path from rest_framework import routers @@ -14,4 +15,11 @@ router = routers.SimpleRouter() router.register(r"advertisers", AdvertiserViewSet, basename="advertisers") router.register(r"publishers", PublisherViewSet, basename="publishers") + +if "adserver.analyzer" in settings.INSTALLED_APPS: + from adserver.analyzer.views import EmbeddingViewSet + + urlpatterns += [path(r"similar/", EmbeddingViewSet.as_view(), name="similar")] + + urlpatterns += router.urls diff --git a/adserver/api/views.py b/adserver/api/views.py index c484d60c..ed6d9b04 100644 --- a/adserver/api/views.py +++ b/adserver/api/views.py @@ -37,7 +37,6 @@ class AdDecisionView(GeoIpMixin, APIView): - """ 
Make a decision on an `Advertisement` to show. @@ -360,7 +359,6 @@ def decision(self, request, data): class AdvertiserViewSet(viewsets.ReadOnlyModelViewSet): - """ Advertiser API calls. @@ -470,7 +468,6 @@ def report(self, request, slug=None): # pylint: disable=unused-argument class PublisherViewSet(viewsets.ReadOnlyModelViewSet): - """ Publisher API calls.