diff --git a/.envs/local/django.sample b/.envs/local/django.sample index 556c4670..4677551d 100644 --- a/.envs/local/django.sample +++ b/.envs/local/django.sample @@ -37,3 +37,8 @@ METABASE_SECRET_KEY=000000000000000000000000000000000000000000000000000000000000 # This is a workaround for some celery issues that are likely fixed in future versions. # https://github.com/celery/celery/issues/5761 COLUMNS=80 + +# Analyzer +# ------------------------------------------------------------------------------ +# See ``adserver.analyzer.backends`` for available backends +# ADSERVER_ANALYZER_BACKEND= diff --git a/adserver/analyzer/backends/st.py b/adserver/analyzer/backends/st.py index 33f5843a..f5339214 100644 --- a/adserver/analyzer/backends/st.py +++ b/adserver/analyzer/backends/st.py @@ -1,6 +1,7 @@ import logging import os +import trafilatura from bs4 import BeautifulSoup from sentence_transformers import SentenceTransformer from textacy import preprocessing @@ -16,10 +17,11 @@ class SentenceTransformerAnalyzerBackend(BaseAnalyzerBackend): Quick and dirty analyzer that uses the SentenceTransformer library """ - MODEL_NAME = "multi-qa-MiniLM-L6-cos-v1" + MODEL_NAME = os.getenv("SENTENCE_TRANSFORMERS_MODEL", "multi-qa-MiniLM-L6-cos-v1") MODEL_HOME = os.getenv("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers") def preprocess_text(self, text): + log.info("Preprocessing text: %s", text) self.preprocessor = preprocessing.make_pipeline( preprocessing.normalize.unicode, preprocessing.remove.punctuation, @@ -28,14 +30,22 @@ def preprocess_text(self, text): return self.preprocessor(text).lower()[: self.MAX_INPUT_LENGTH] def analyze_response(self, resp): + # Disable the analysis for now return [] + def get_content(self, *args): + downloaded = trafilatura.fetch_url(self.url) + result = trafilatura.extract( + downloaded, include_comments=False, include_tables=False + ) + return self.preprocess_text(result) + def embed_response(self, resp) -> list: """Analyze an HTTP response and return a list of keywords/topics for the URL.""" model = SentenceTransformer(self.MODEL_NAME, cache_folder=self.MODEL_HOME) text = self.get_content(resp) if text: - log.info("Embedding text: %s", text[:100]) + log.info("Postprocessed text: %s", text) embedding = model.encode(text) return embedding.tolist() diff --git a/adserver/analyzer/views.py b/adserver/analyzer/views.py index 3c7b6b75..239c1b0a 100644 --- a/adserver/analyzer/views.py +++ b/adserver/analyzer/views.py @@ -1,6 +1,8 @@ from django.conf import settings from pgvector.django import CosineDistance from rest_framework import status +from rest_framework.permissions import AllowAny +from rest_framework.renderers import StaticHTMLRenderer from rest_framework.response import Response from rest_framework.views import APIView @@ -26,6 +28,9 @@ class EmbeddingViewSet(APIView): :>json array results: An array of similar URLs and scores """ + permission_classes = [AllowAny] + renderer_classes = [StaticHTMLRenderer] + def get(self, request): """Return a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL.""" url = request.query_params.get("url") @@ -40,7 +45,7 @@ def get(self, request): if not response: return Response( {"error": "Not able to fetch content from URL"}, - status=status.HTTP_404_NOT_FOUND, + status=status.HTTP_400_BAD_REQUEST, ) processed_text = backend_instance.get_content(response) analyzed_embedding = backend_instance.embedding(response) @@ -52,9 +57,19 @@ def get(self, request): ) return Response( - { - "count": len(urls), - "text": processed_text[:500], - "results": [[url.url, url.distance] for url in urls], - } + f""" +