Skip to content

Commit

Permalink
Fetch all add-on instances in one DB request (#5289)
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvkb authored Dec 20, 2024
1 parent c4bdd35 commit 4a85e14
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 49 deletions.
13 changes: 0 additions & 13 deletions api/api/models/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,6 @@ def duration_in_s(self):
def audio_set(self):
return getattr(self, "audioset")

def get_waveform(self) -> list[float]:
"""
Get the waveform if it exists. Return a blank list otherwise.
:return: the waveform, if it exists; empty list otherwise
"""

try:
add_on = AudioAddOn.objects.get(audio_identifier=self.identifier)
return add_on.waveform_peaks or []
except AudioAddOn.DoesNotExist:
return []

def get_or_create_waveform(self):
add_on, _ = AudioAddOn.objects.get_or_create(audio_identifier=self.identifier)

Expand Down
6 changes: 3 additions & 3 deletions api/api/serializers/audio_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_peaks(self, obj) -> list[int]:
if isinstance(obj, Hit):
obj = Audio.objects.get(identifier=obj.identifier)
return obj.get_waveform()
audio_addon = self.context.get("addons", {}).get(obj.identifier)
if audio_addon:
return audio_addon.waveform_peaks

def to_representation(self, instance):
# Get the original representation
Expand Down
5 changes: 5 additions & 0 deletions api/api/views/audio_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from api.docs.audio_docs import thumbnail as thumbnail_docs
from api.models import Audio
from api.models.audio import AudioAddOn
from api.serializers.audio_serializers import (
AudioReportRequestSerializer,
AudioSearchRequestSerializer,
Expand All @@ -38,6 +39,7 @@ class AudioViewSet(MediaViewSet):
"""Viewset for all endpoints pertaining to audio."""

model_class = Audio
addon_model_class = AudioAddOn
media_type = AUDIO_TYPE
query_serializer_class = AudioSearchRequestSerializer
default_index = settings.MEDIA_INDEX_MAPPING[AUDIO_TYPE]
Expand All @@ -47,6 +49,9 @@ class AudioViewSet(MediaViewSet):
def get_queryset(self):
return super().get_queryset().select_related("sensitive_audio", "audioset")

def include_addons(self, serializer):
return serializer.validated_data.get("peaks")

# Extra actions

async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo:
Expand Down
42 changes: 36 additions & 6 deletions api/api/views/media_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from api.controllers import search_controller
from api.controllers.elasticsearch.related import related_media
from api.models import ContentSource
from api.models.base import OpenLedgerModel
from api.models.media import AbstractMedia
from api.serializers import media_serializers
from api.serializers.source_serializers import SourceSerializer
Expand Down Expand Up @@ -51,6 +52,7 @@ class MediaViewSet(AsyncViewSetMixin, AsyncAPIView, ReadOnlyModelViewSet):

# Populate these in the corresponding subclass
model_class: type[AbstractMedia] = None
addon_model_class: type[OpenLedgerModel] = None
media_type: MediaType | None = None
query_serializer_class = None
default_index = None
Expand Down Expand Up @@ -97,7 +99,11 @@ def _get_request_serializer(self, request):
req_serializer.is_valid(raise_exception=True)
return req_serializer

def get_db_results(self, results):
def get_db_results(
self,
results,
include_addons=False,
) -> tuple[list[AbstractMedia], list[OpenLedgerModel]]:
"""
Map ES hits to ORM model instances.
Expand All @@ -107,6 +113,7 @@ def get_db_results(self, results):
which is both unique and indexed, so it's quite performant.
:param results: the list of ES hits
:param include_addons: whether to include add-ons with results
:return: the corresponding list of ORM model instances
"""

Expand All @@ -121,7 +128,12 @@ def get_db_results(self, results):
for result, hit in zip(results, hits):
result.fields_matched = getattr(hit.meta, "highlight", None)

return results
if include_addons and self.addon_model_class:
addons = list(self.addon_model_class.objects.filter(pk__in=identifiers))
else:
addons = []

return (results, addons)

# Standard actions

Expand All @@ -147,6 +159,20 @@ def _validate_source(self, source):
detail=f"Invalid source '{source}'. Valid sources are: {valid_string}.",
)

def include_addons(self, serializer):
"""
Whether to include objects of the addon model when mapping hits to
objects of the media model.
If the media type has an addon model, this method should be overridden
in the subclass to return ``True`` based on serializer input.
:param serializer: the validated serializer instance
:return: whether to include addon model objects
"""

return False

def get_media_results(
self,
request,
Expand Down Expand Up @@ -188,9 +214,13 @@ def get_media_results(
except ValueError as e:
raise APIException(getattr(e, "message", str(e)))

serializer_context = search_context | self.get_serializer_context()

results = self.get_db_results(results)
include_addons = self.include_addons(params)
results, addons = self.get_db_results(results, include_addons)
serializer_context = (
search_context
| self.get_serializer_context()
| {"addons": {addon.audio_identifier: addon for addon in addons}}
)

serializer = self.get_serializer(results, many=True, context=serializer_context)
return self.get_paginated_response(serializer.data)
Expand Down Expand Up @@ -231,7 +261,7 @@ def related(self, request, identifier=None, *_, **__):

serializer_context = self.get_serializer_context()

results = self.get_db_results(results)
results, _ = self.get_db_results(results)

serializer = self.get_serializer(results, many=True, context=serializer_context)
return self.get_paginated_response(serializer.data)
Expand Down
8 changes: 6 additions & 2 deletions api/test/integration/test_dead_link_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def get_empty_cached_statuses(_, image_urls):
_MAKE_HEAD_REQUESTS_MODULE_PATH = "api.utils.check_dead_links._make_head_requests"


def _mock_get_db_results(results, include_addons=False):
return (results, [])


def _patch_make_head_requests():
def _make_head_requests(urls, *args, **kwargs):
responses = []
Expand Down Expand Up @@ -67,7 +71,7 @@ def test_dead_link_filtering(mocked_map, api_client):
with patch(
"api.views.image_views.ImageViewSet.get_db_results"
) as mock_get_db_result:
mock_get_db_result.side_effect = lambda value: value
mock_get_db_result.side_effect = _mock_get_db_results
res_with_dead_links = api_client.get(
path,
query_params | {"filter_dead": False},
Expand Down Expand Up @@ -121,7 +125,7 @@ def test_dead_link_filtering_all_dead_links(
with patch(
"api.views.image_views.ImageViewSet.get_db_results"
) as mock_get_db_result:
mock_get_db_result.side_effect = lambda value: value
mock_get_db_result.side_effect = _mock_get_db_results
with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO):
response = api_client.get(
path,
Expand Down
25 changes: 0 additions & 25 deletions api/test/unit/models/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,3 @@ def test_audio_waveform_caches(generate_peaks_mock, audio_fixture):
audio_fixture.delete()

assert AudioAddOn.objects.count() == 1


@pytest.mark.django_db
@mock.patch("api.models.audio.AudioAddOn.objects.get")
def test_audio_waveform_sent_when_present(get_mock, audio_fixture):
# When ``AudioAddOn.waveform_peaks`` exists, waveform is filled
peaks = [0, 0.25, 0.5, 0.25, 0.1]
get_mock.return_value = mock.Mock(waveform_peaks=peaks)
assert audio_fixture.get_waveform() == peaks


@pytest.mark.django_db
@mock.patch("api.models.audio.AudioAddOn.objects.get")
def test_audio_waveform_blank_when_absent(get_mock, audio_fixture):
# When ``AudioAddOn`` does not exist, waveform is blank
get_mock.side_effect = AudioAddOn.DoesNotExist()
assert audio_fixture.get_waveform() == []


@pytest.mark.django_db
@mock.patch("api.models.audio.AudioAddOn.objects.get")
def test_audio_waveform_blank_when_none(get_mock, audio_fixture):
# When ``AudioAddOn.waveform_peaks`` is None, waveform is blank
get_mock.return_value = mock.Mock(waveform_peaks=None)
assert audio_fixture.get_waveform() == []
39 changes: 39 additions & 0 deletions api/test/unit/views/test_audio_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from unittest.mock import MagicMock, patch

import pytest
import pytest_django.asserts

from test.factory.models import AudioFactory


@pytest.mark.parametrize("peaks, query_count", [(True, 2), (False, 1)])
@pytest.mark.django_db
def test_peaks_param_determines_addons(api_client, peaks, query_count):
num_results = 20

# Since controller returns a list of ``Hit``s, not model instances, we must
# set the ``meta`` param on each of them to match the shape of ``Hit``.
results = AudioFactory.create_batch(size=num_results)
for result in results:
result.meta = None

controller_ret = (
results,
1, # num_pages
num_results,
{}, # search_context
)
with (
patch(
"api.views.media_views.search_controller",
query_media=MagicMock(return_value=controller_ret),
),
patch(
"api.serializers.media_serializers.search_controller",
get_sources=MagicMock(return_value={}),
),
pytest_django.asserts.assertNumQueries(query_count),
):
res = api_client.get(f"/v1/audio/?peaks={peaks}")

assert res.status_code == 200

0 comments on commit 4a85e14

Please sign in to comment.