Skip to content
This repository has been archived by the owner on Feb 22, 2023. It is now read-only.

Commit

Permalink
Add option to sort search results by created_on (#916)
Browse files Browse the repository at this point in the history
  • Loading branch information
zackkrida authored Feb 15, 2023
1 parent 9f2b831 commit 2ea8704
Show file tree
Hide file tree
Showing 11 changed files with 10,114 additions and 10,011 deletions.
1 change: 1 addition & 0 deletions api/catalog/api/constants/field_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
json_fields = [
"id",
"title",
"indexed_on",
"foreign_landing_url",
"url",
"creator",
Expand Down
13 changes: 13 additions & 0 deletions api/catalog/api/constants/sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Keys accepted for the "sort by" search parameter.
RELEVANCE = "relevance"
INDEXED_ON = "indexed_on"

# Keys accepted for the "sort direction" search parameter.
DESCENDING = "desc"
ASCENDING = "asc"

# ``(value, human-readable label)`` pairs; the first entry is the default.
SORT_FIELDS = [
    # default
    (RELEVANCE, "Relevance"),
    # date on which media was indexed into Openverse
    (INDEXED_ON, "Indexing date"),
]

# ``(value, human-readable label)`` pairs; the first entry is the default.
SORT_DIRECTIONS = [
    # default
    (DESCENDING, "Descending"),
    (ASCENDING, "Ascending"),
]
15 changes: 10 additions & 5 deletions api/catalog/api/controllers/search_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pprint
from itertools import accumulate
from math import ceil
from typing import Any, Literal
from typing import Literal

from django.conf import settings
from django.core.cache import cache
Expand All @@ -17,6 +17,8 @@
from elasticsearch_dsl.response import Hit, Response

import catalog.api.models as models
from catalog.api.constants.sorting import INDEXED_ON
from catalog.api.serializers import media_serializers
from catalog.api.utils import tallies
from catalog.api.utils.dead_link_mask import get_query_hash, get_query_mask
from catalog.api.utils.validate_images import validate_images
Expand Down Expand Up @@ -220,8 +222,7 @@ def _post_process_results(

def _apply_filter(
s: Search,
# Any is used here to avoid a circular import
search_params: Any, # MediaSearchRequestSerializer
search_params: media_serializers.MediaSearchRequestSerializer,
serializer_field: str,
es_field: str | None = None,
behaviour: Literal["filter", "exclude"] = "filter",
Expand Down Expand Up @@ -278,8 +279,7 @@ def _exclude_mature_by_param(s: Search, search_params):


def search(
# Any is used here to avoid a circular import
search_params: Any, # MediaSearchRequestSerializer
search_params: media_serializers.MediaSearchRequestSerializer,
index: Literal["image", "audio"],
page_size: int,
ip: int,
Expand Down Expand Up @@ -390,6 +390,11 @@ def search(
# Route users to the same Elasticsearch worker node to reduce
# pagination inconsistencies and increase cache hits.
s = s.params(preference=str(ip), request_timeout=7)

# Sort by new
if search_params.validated_data["sort_by"] == INDEXED_ON:
s = s.sort({"created_on": {"order": search_params.validated_data["sort_dir"]}})

# Paginate
start, end = _get_query_slice(s, page_size, page, filter_dead)
s = s[start:end]
Expand Down
1 change: 1 addition & 0 deletions api/catalog/api/examples/audio_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
base_audio = {
"id": identifier,
"title": "Wish You Were Here",
"indexed_on": "2022-12-06T06:54:25Z",
"foreign_landing_url": "https://www.jamendo.com/track/1214935",
"url": "https://mp3d.jamendo.com/download/track/1214935/mp32",
"creator": "The.madpix.project",
Expand Down
1 change: 1 addition & 0 deletions api/catalog/api/examples/image_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
base_image = {
"id": identifier,
"title": "Tree Bark Photo",
"indexed_on": "2022-08-27T17:39:48Z",
"foreign_landing_url": "https://stocksnap.io/photo/XNVBVXO3B7",
"url": "https://cdn.stocksnap.io/img-thumbs/960w/XNVBVXO3B7.jpg",
"creator": "Tim Sullivan",
Expand Down
46 changes: 46 additions & 0 deletions api/catalog/api/serializers/media_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
from rest_framework.exceptions import NotAuthenticated

from catalog.api.constants.licenses import LICENSE_GROUPS
from catalog.api.constants.sorting import (
DESCENDING,
RELEVANCE,
SORT_DIRECTIONS,
SORT_FIELDS,
)
from catalog.api.controllers import search_controller
from catalog.api.models.media import AbstractMedia
from catalog.api.serializers.base import BaseModelSerializer
Expand Down Expand Up @@ -42,6 +48,8 @@ class MediaSearchRequestSerializer(serializers.Serializer):
"extension",
"mature",
"qa",
# "unstable__sort_by", # excluding unstable fields
# "unstable__sort_dir", # excluding unstable fields
"page_size",
"page",
]
Expand Down Expand Up @@ -109,6 +117,28 @@ class MediaSearchRequestSerializer(serializers.Serializer):
required=False,
default=False,
)

# The ``unstable__`` prefix is used in the query params.
# The validated data does not contain the ``unstable__`` prefix.
# If you rename these fields, update the following references:
# - ``field_names`` in ``MediaSearchRequestSerializer``
# - validators for these fields in ``MediaSearchRequestSerializer``
unstable__sort_by = serializers.ChoiceField(
source="sort_by",
help_text="The field which should be the basis for sorting results.",
choices=SORT_FIELDS,
required=False,
default=RELEVANCE,
)
unstable__sort_dir = serializers.ChoiceField(
source="sort_dir",
help_text="The direction of sorting. Cannot be applied when sorting by "
"`relevance`.",
choices=SORT_DIRECTIONS,
required=False,
default=DESCENDING,
)

page_size = serializers.IntegerField(
label="page_size",
help_text="Number of results to return per page.",
Expand Down Expand Up @@ -170,6 +200,16 @@ def validate_tags(self, value):
def validate_title(self, value):
return self._truncate(value)

def validate_unstable__sort_by(self, value):
    """
    Validate the ``unstable__sort_by`` parameter.

    Anonymous requests are always forced back to sorting by relevance;
    any other request (including the case where no request is present in
    the serializer context) keeps the submitted value.

    :param value: the sort field chosen by the client
    :return: ``RELEVANCE`` for anonymous users, otherwise ``value``
    """
    req = self.context.get("request")
    if req and req.user and req.user.is_anonymous:
        return RELEVANCE
    return value

def validate_unstable__sort_dir(self, value):
    """
    Validate the ``unstable__sort_dir`` parameter.

    Anonymous requests are always forced back to descending order; any
    other request (including the case where no request is present in the
    serializer context) keeps the submitted value.

    :param value: the sort direction chosen by the client
    :return: ``DESCENDING`` for anonymous users, otherwise ``value``
    """
    req = self.context.get("request")
    if req and req.user and req.user.is_anonymous:
        return DESCENDING
    return value

def validate_page_size(self, value):
request = self.context.get("request")
is_anonymous = bool(request and request.user and request.user.is_anonymous)
Expand Down Expand Up @@ -314,6 +354,7 @@ class Meta:
model = AbstractMedia
fields = [
"id",
"indexed_on",
"title",
"foreign_landing_url",
"url",
Expand Down Expand Up @@ -345,6 +386,11 @@ class Meta:
source="identifier",
)

indexed_on = serializers.DateTimeField(
source="created_on",
help_text="The timestamp of when the media was indexed by Openverse.",
)

tags = TagSerializer(
allow_null=True, # replaced with ``[]`` in ``to_representation`` below
many=True,
Expand Down
17 changes: 11 additions & 6 deletions api/catalog/api/views/media_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,17 @@ def _get_request_serializer(self, request):
return req_serializer

def get_db_results(self, results):
    """
    Replace search hits with their database rows, keeping the hit order.

    The rows come from ``self.get_queryset()`` filtered by the identifiers
    of the hits. Each returned row gets a ``fields_matched`` attribute
    copied from the hit with the same identifier (``None`` when the hit
    has no such attribute).

    Rows are paired with hits by identifier rather than by position, so a
    hit that has no matching database row is simply dropped and cannot
    shift ``fields_matched`` onto the wrong row.

    :param results: iterable of search hits exposing ``identifier``
    :return: list of database objects in the same order as the hits
    """
    # Dicts keep insertion order, so this also records the hit ordering.
    hit_map = {str(hit.identifier): hit for hit in results}
    # Precompute positions once instead of a linear ``index`` call per row.
    position = {identifier: idx for idx, identifier in enumerate(hit_map)}

    db_results = list(self.get_queryset().filter(identifier__in=list(hit_map)))
    db_results.sort(key=lambda obj: position[str(obj.identifier)])
    for obj in db_results:
        hit = hit_map[str(obj.identifier)]
        obj.fields_matched = getattr(hit, "fields_matched", None)

    return db_results

# Standard actions
Expand Down
6 changes: 6 additions & 0 deletions api/catalog/templates/drf-yasg/redoc.html
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,11 @@
img[alt="logo"] {
padding: 20px; /* same as other sidebar items */
}

/* Hide fields that are unstable and likely to change: every field cell whose
   title starts with "unstable__", plus the sibling cells that follow it. */
td[kind="field"][title^="unstable__"],
td[kind="field"][title^="unstable__"] ~ td {
display: none
}
</style>
{% endblock %}
25 changes: 25 additions & 0 deletions api/test/auth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,31 @@ def test_auth_rate_limit_reporting(
assert res_data["verified"] is False


@pytest.mark.django_db
@pytest.mark.parametrize(
    "sort_dir, exp_indexed_on",
    [
        ("desc", "2022-12-31"),
        ("asc", "2022-01-01"),
    ],
)
def test_sorting_authed(
    client, monkeypatch, test_auth_token_exchange, sort_dir, exp_indexed_on
):
    """
    Authenticated image searches sorted by ``indexed_on`` honour ``sort_dir``.

    The first result's ``indexed_on`` date must match the expected date for
    the requested direction.
    """
    # The DB is empty, so stop the view from looking ES results up in it.
    monkeypatch.setattr("catalog.api.views.image_views.ImageSerializer.needs_db", False)

    # NOTE(review): pause kept as-is; presumably gives the freshly issued
    # token time to become usable -- confirm before removing.
    time.sleep(1)
    token = test_auth_token_exchange["access_token"]
    res = client.get(
        "/v1/images/",
        {"unstable__sort_by": "indexed_on", "unstable__sort_dir": sort_dir},
        HTTP_AUTHORIZATION=f"Bearer {token}",
    )
    assert res.status_code == 200

    first_result = res.json()["results"][0]
    # ``indexed_on`` is an ISO timestamp; the first 10 chars are the date.
    assert first_result["indexed_on"][:10] == exp_indexed_on


@pytest.mark.django_db
def test_page_size_limit_unauthed(client):
query_params = {"page_size": 20}
Expand Down
Loading

0 comments on commit 2ea8704

Please sign in to comment.