Skip to content

Commit

Permalink
Add initial similarity API
Browse files Browse the repository at this point in the history
This ports over the logic from the CLI into a DRF API view

It's a basic start for now,
but gets at least a working example for local dev.
  • Loading branch information
ericholscher committed Feb 19, 2024
1 parent 198834e commit 7c3cd9e
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 5 deletions.
2 changes: 1 addition & 1 deletion adserver/analyzer/management/commands/runmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def handle_url(self, url):
self.stdout.write(
_("Keywords from '%s': %s") % (backend.__name__, analyzed_keywords)
)
analyzed_embedding = backend.embedding(response)
analyzed_embedding = backend_instance.embedding(response)
self.stdout.write(
_("Embeddings from '%s': %s") % (backend.__name__, analyzed_embedding)
)
Expand Down
54 changes: 53 additions & 1 deletion adserver/analyzer/views.py
Original file line number Diff line number Diff line change
@@ -1 +1,53 @@
"""Intentionally blank."""
from django.conf import settings
from pgvector.django import CosineDistance
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView

from adserver.analyzer.backends.st import SentenceTransformerAnalyzerBackend
from adserver.analyzer.models import AnalyzedUrl


if "adserver.analyzer" in settings.INSTALLED_APPS:

class EmbeddingViewSet(APIView):
"""
Returns a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL.
.. http:get:: /api/v1/embedding/
Return a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL
:<json string url: **Required**. The URL to query for similar URLs and scores
:>json int count: The number of similar URLs returned
:>json array results: An array of similar URLs and scores
"""

def get(self, request):
"""Return a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL."""
url = request.query_params.get("url")

if not url:
return Response(
{"error": "url is required"}, status=status.HTTP_400_BAD_REQUEST
)

backend_instance = SentenceTransformerAnalyzerBackend(url)
response = backend_instance.fetch()
processed_text = backend_instance.get_content(response)
analyzed_embedding = backend_instance.embedding(response)

urls = (
AnalyzedUrl.objects.exclude(embedding=None)
.annotate(distance=CosineDistance("embedding", analyzed_embedding))
.order_by("distance")[:10]
)

return Response(
{
"count": len(urls),
"text": processed_text[:500],
"results": [[url.url, url.distance] for url in urls],
}
)
8 changes: 8 additions & 0 deletions adserver/api/urls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""API Urls for the ad server."""
from django.conf import settings
from django.urls import path
from rest_framework import routers

Expand All @@ -14,4 +15,11 @@
router = routers.SimpleRouter()
router.register(r"advertisers", AdvertiserViewSet, basename="advertisers")
router.register(r"publishers", PublisherViewSet, basename="publishers")

if "adserver.analyzer" in settings.INSTALLED_APPS:
from adserver.analyzer.views import EmbeddingViewSet

urlpatterns += [path(r"similar/", EmbeddingViewSet.as_view(), name="similar")]


urlpatterns += router.urls
3 changes: 0 additions & 3 deletions adserver/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@


class AdDecisionView(GeoIpMixin, APIView):

"""
Make a decision on an `Advertisement` to show.
Expand Down Expand Up @@ -360,7 +359,6 @@ def decision(self, request, data):


class AdvertiserViewSet(viewsets.ReadOnlyModelViewSet):

"""
Advertiser API calls.
Expand Down Expand Up @@ -470,7 +468,6 @@ def report(self, request, slug=None): # pylint: disable=unused-argument


class PublisherViewSet(viewsets.ReadOnlyModelViewSet):

"""
Publisher API calls.
Expand Down

0 comments on commit 7c3cd9e

Please sign in to comment.