diff --git a/.envs/local/django.sample b/.envs/local/django.sample index 556c4670..4677551d 100644 --- a/.envs/local/django.sample +++ b/.envs/local/django.sample @@ -37,3 +37,8 @@ METABASE_SECRET_KEY=000000000000000000000000000000000000000000000000000000000000 # This is a workaround for some celery issues that are likely fixed in future versions. # https://github.com/celery/celery/issues/5761 COLUMNS=80 + +# Analyzer +# ------------------------------------------------------------------------------ +# See ``adserver.analyzer.backends`` for available backends +# ADSERVER_ANALYZER_BACKEND= diff --git a/adserver/analyzer/backends/st.py b/adserver/analyzer/backends/st.py index 33f5843a..f5339214 100644 --- a/adserver/analyzer/backends/st.py +++ b/adserver/analyzer/backends/st.py @@ -1,6 +1,7 @@ import logging import os +import trafilatura from bs4 import BeautifulSoup from sentence_transformers import SentenceTransformer from textacy import preprocessing @@ -16,10 +17,11 @@ class SentenceTransformerAnalyzerBackend(BaseAnalyzerBackend): Quick and dirty analyzer that uses the SentenceTransformer library """ - MODEL_NAME = "multi-qa-MiniLM-L6-cos-v1" + MODEL_NAME = os.getenv("SENTENCE_TRANSFORMERS_MODEL", "multi-qa-MiniLM-L6-cos-v1") MODEL_HOME = os.getenv("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers") def preprocess_text(self, text): + log.info("Preprocessing text: %s", text) self.preprocessor = preprocessing.make_pipeline( preprocessing.normalize.unicode, preprocessing.remove.punctuation, @@ -28,14 +30,22 @@ def preprocess_text(self, text): return self.preprocessor(text).lower()[: self.MAX_INPUT_LENGTH] def analyze_response(self, resp): + # Disable the analysis for now return [] + def get_content(self, *args): + downloaded = trafilatura.fetch_url(self.url) + result = trafilatura.extract( + downloaded, include_comments=False, include_tables=False + ) + return self.preprocess_text(result) + def embed_response(self, resp) -> list: """Analyze an HTTP response and return a list of keywords/topics for the URL.""" model = SentenceTransformer(self.MODEL_NAME, cache_folder=self.MODEL_HOME) text = self.get_content(resp) if text: - log.info("Embedding text: %s", text[:100]) + log.info("Postprocessed text: %s", text) embedding = model.encode(text) return embedding.tolist() diff --git a/adserver/analyzer/views.py b/adserver/analyzer/views.py index 3c7b6b75..239c1b0a 100644 --- a/adserver/analyzer/views.py +++ b/adserver/analyzer/views.py @@ -1,6 +1,8 @@ from django.conf import settings from pgvector.django import CosineDistance from rest_framework import status +from rest_framework.permissions import AllowAny +from rest_framework.renderers import StaticHTMLRenderer from rest_framework.response import Response from rest_framework.views import APIView @@ -26,6 +28,9 @@ class EmbeddingViewSet(APIView): :>json array results: An array of similar URLs and scores """ + permission_classes = [AllowAny] + renderer_classes = [StaticHTMLRenderer] + def get(self, request): """Return a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL.""" url = request.query_params.get("url") @@ -40,7 +45,7 @@ def get(self, request): if not response: return Response( {"error": "Not able to fetch content from URL"}, - status=status.HTTP_404_NOT_FOUND, + status=status.HTTP_400_BAD_REQUEST, ) processed_text = backend_instance.get_content(response) analyzed_embedding = backend_instance.embedding(response) @@ -52,9 +57,19 @@ def get(self, request): ) return Response( - { - "count": len(urls), - "text": processed_text[:500], - "results": [[url.url, url.distance] for url in urls], - } + f""" +

Results:

+ +

+ Text: +

+ + """ ) diff --git a/config/settings/base.py b/config/settings/base.py index 999d8392..ce0de2c4 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -66,6 +66,7 @@ "simple_history", "django_slack", "djstripe", + "corsheaders", ] MIDDLEWARE = [ @@ -73,6 +74,7 @@ "enforce_host.EnforceHostMiddleware", "whitenoise.middleware.WhiteNoiseMiddleware", "django.contrib.sessions.middleware.SessionMiddleware", + "corsheaders.middleware.CorsMiddleware", "django.middleware.common.CommonMiddleware", "django.middleware.csrf.CsrfViewMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", @@ -538,3 +540,10 @@ MIDDLEWARE.append(ADSERVER_IPADDRESS_MIDDLEWARE) if ADSERVER_GEOIP_MIDDLEWARE: MIDDLEWARE.append(ADSERVER_GEOIP_MIDDLEWARE) + + +CORS_ALLOWED_ORIGINS = [ + "http://localhost:8000", + "http://127.0.0.1:8000", +] +CORS_ALLOW_HEADERS = ["*"] diff --git a/docker-compose/django/start b/docker-compose/django/start index 41248d85..9c93e0d6 100755 --- a/docker-compose/django/start +++ b/docker-compose/django/start @@ -5,7 +5,7 @@ set -o pipefail set -o nounset # Reinstall dependencies without rebuilding docker image -# pip install -r /app/requirements/production.txt -r /app/requirements/analyzer.txt +pip install -r /app/requirements/development.txt # Don't auto-migrate locally because this can cause weird issues when testing migrations # python manage.py migrate diff --git a/requirements/development.txt b/requirements/development.txt index 8626e531..9ca55392 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -35,3 +35,7 @@ django_coverage_plugin==3.1.0 # Used to build the ML model requests-cache==0.9.5 + +# CORS headers +django-cors-headers==3.8.0 +trafilatura