diff --git a/openverse_catalog/dags/providers/factory_utils.py b/openverse_catalog/dags/providers/factory_utils.py index 89cbab88d..a2fac4e62 100644 --- a/openverse_catalog/dags/providers/factory_utils.py +++ b/openverse_catalog/dags/providers/factory_utils.py @@ -4,7 +4,7 @@ from types import FunctionType from typing import Callable, Sequence -from airflow.models import TaskInstance +from airflow.models import DagRun, TaskInstance from airflow.utils.dates import cron_presets from common.constants import MediaType from common.storage.media import MediaStore @@ -45,6 +45,7 @@ def generate_tsv_filenames( ingestion_callable: Callable, media_types: list[MediaType], ti: TaskInstance, + dag_run: DagRun, args: Sequence = None, ) -> None: """ @@ -71,7 +72,7 @@ def generate_tsv_filenames( f"Initializing ProviderIngester {ingestion_callable.__name__} in" f"order to generate store filenames." ) - ingester = ingestion_callable(*args) + ingester = ingestion_callable(dag_run.conf, *args) stores = ingester.media_stores # Push the media store output paths to XComs. @@ -87,6 +88,7 @@ def pull_media_wrapper( media_types: list[MediaType], tsv_filenames: list[str], ti: TaskInstance, + dag_run: DagRun, args: Sequence = None, ): """ @@ -116,7 +118,7 @@ def pull_media_wrapper( # A ProviderDataIngester class was passed instead. First we initialize the # class, which will initialize the media stores and DelayedRequester. logger.info(f"Initializing ProviderIngester {ingestion_callable.__name__}") - ingester = ingestion_callable(*args) + ingester = ingestion_callable(dag_run.conf, *args) stores = ingester.media_stores run_func = ingester.ingest_records # args have already been passed into the ingester, we don't need them passed diff --git a/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py b/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py index cf1531455..773b38d6d 100644 --- a/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py +++ b/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py @@ -1,7 +1,10 @@ +import json import logging +import traceback from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple +from airflow.exceptions import AirflowException from airflow.models import Variable from common.requester import DelayedRequester, RetriesExceeded from common.storage.media import MediaStore @@ -12,6 +15,35 @@ logger = logging.getLogger(__name__) +class AggregateIngestionError(Exception): + """ + Custom exception when multiple ingestion errors are skipped and then + raised in aggregate at the end of ingestion. + """ + + pass + + +class IngestionError(Exception): + """ + Custom exception which includes information about the query_params that + were being used when the error was encountered. 
+ """ + + def __init__(self, error, traceback, query_params): + self.error = error + self.traceback = traceback + self.query_params = json.dumps(query_params) + + def __str__(self): + # Append query_param info to error message + return f"{self.error}\nquery_params: {self.query_params}" + + def repr_with_traceback(self): + # Append traceback + return f"{str(self)}\n{self.traceback}" + + class ProviderDataIngester(ABC): """ An abstract base class that initializes media stores and ingests records @@ -51,9 +83,10 @@ def endpoint(self): """ pass - def __init__(self, date: str = None): + def __init__(self, conf: dict = None, date: str = None): """ Optional Arguments: + conf: The configuration dict for the running DagRun date: Date String in the form YYYY-MM-DD. This is the date for which running the script will pull data """ @@ -75,6 +108,24 @@ def __init__(self, date: str = None): self.media_stores = self.init_media_stores() self.date = date + # dag_run configuration options + conf = conf or {} + + # Used to skip over errors and continue ingestion. When enabled, errors + # are not reported until ingestion has completed. + self.skip_ingestion_errors = conf.get("skip_ingestion_errors", False) + self.ingestion_errors: List[IngestionError] = [] # Keep track of errors + + # An optional set of initial query params from which to begin ingestion. + self.initial_query_params = conf.get("initial_query_params") + + # An optional list of `query_params`. When provided, ingestion will be run for + # just these sets of params. + self.override_query_params = None + if query_params_list := conf.get("query_params_list"): + # Create a generator to facilitate fetching the next set of query_params. + self.override_query_params = (qp for qp in query_params_list) + def init_media_stores(self) -> dict[str, MediaStore]: """ Initialize a media store for each media type supported by this @@ -101,10 +152,14 @@ def ingest_records(self, **kwargs) -> None: logger.info(f"Begin ingestion for {self.__class__.__name__}") - try: - while should_continue: - query_params = self.get_next_query_params(query_params, **kwargs) + while should_continue: + query_params = self.get_query_params(query_params, **kwargs) + if query_params is None: + # Break out of ingestion if no query_params are supplied. This can + # happen when the final `override_query_params` is processed. + break + try: batch, should_continue = self.get_batch(query_params) if batch and len(batch) > 0: @@ -114,12 +169,100 @@ def ingest_records(self, **kwargs) -> None: logger.info("Batch complete.") should_continue = False - if self.limit and record_count >= self.limit: - logger.info(f"Ingestion limit of {self.limit} has been reached.") - should_continue = False - finally: - total = self.commit_records() - logger.info(f"Committed {total} records") + except AirflowException as error: + # AirflowExceptions should not be caught, as execution should not + # continue when the task is being stopped by Airflow. + + # If errors have already been caught during processing, raise them + # as well. 
+                if error_summary := self.get_ingestion_errors():
+                    raise error_summary from error
+                raise
+
+            except Exception as error:
+                ingestion_error = IngestionError(
+                    error, traceback.format_exc(), query_params
+                )
+
+                if self.skip_ingestion_errors:
+                    # Add this to the errors list but continue processing
+                    self.ingestion_errors.append(ingestion_error)
+                    logger.error(f"Skipping batch due to ingestion error: {error}")
+                    continue
+
+                # Commit whatever records we were able to process, and rethrow the
+                # exception so the taskrun fails.
+                self.commit_records()
+                raise error from ingestion_error
+
+            if self.limit and record_count >= self.limit:
+                logger.info(f"Ingestion limit of {self.limit} has been reached.")
+                should_continue = False
+
+        # Commit whatever records we were able to process
+        self.commit_records()
+
+        # If errors were caught during processing, raise them now
+        if error_summary := self.get_ingestion_errors():
+            raise error_summary
+
+    def get_ingestion_errors(self) -> AggregateIngestionError | None:
+        """
+        If any errors were skipped during ingestion, log them as well as the
+        associated query parameters. Then return an AggregateIngestionError.
+
+        If there are no errors to report, returns None.
+        """
+        if self.ingestion_errors:
+            # Log the affected query_params
+            bad_query_params = ", \n".join(
+                [f"{e.query_params}" for e in self.ingestion_errors]
+            )
+            logger.info(
+                "The following query_params resulted in errors: \n"
+                f"{bad_query_params}"
+            )
+            errors_str = "\n".join(
+                e.repr_with_traceback() for e in self.ingestion_errors
+            )
+            logger.error(
+                f"The following errors were encountered during ingestion:\n{errors_str}"
+            )
+            return AggregateIngestionError(
+                f"{len(self.ingestion_errors)} query batches were skipped due to "
+                "errors during ingestion using the `skip_ingestion_errors` flag. "
+                "See the log for more details."
+            )
+        return None
+
+    def get_query_params(
+        self, prev_query_params: Optional[Dict], **kwargs
+    ) -> Optional[Dict]:
+        """
+        Returns the next set of query_params for the next request, handling
+        optional overrides via the dag_run conf.
+        """
+        # If we are getting query_params for the first batch and initial_query_params
+        # have been set, return them.
+        if prev_query_params is None and self.initial_query_params:
+            logger.info(
+                "Using initial_query_params from dag_run conf:"
+                f" {json.dumps(self.initial_query_params)}"
+            )
+            return self.initial_query_params
+
+        # If a list of query_params was provided, return the next value.
+        if self.override_query_params:
+            next_params = next(self.override_query_params, None)
+            logger.info(
+                "Using query params from `query_params_list` set in dag_run conf:"
+                f" {next_params}"
+            )
+            return next_params
+
+        # Default behavior when no conf options are provided; build the next
+        # set of query params, given the previous.
+        return self.get_next_query_params(prev_query_params, **kwargs)
 
     @abstractmethod
     def get_next_query_params(
@@ -154,18 +297,9 @@ def get_batch(self, query_params: Dict) -> Tuple[Optional[List], bool]:
         batch = None
         should_continue = True
 
+        # Get the API response
         try:
-            # Get the API response
             response_json = self.get_response_json(query_params)
-
-            # Build a list of records from the response
-            batch = self.get_batch_data(response_json)
-
-            # Optionally, apply some logic to the response to determine whether
-            # ingestion should continue or if should be short-circuited. By default
-            # this will return True and ingestion continues.
-            should_continue = self.get_should_continue(response_json)
-
         except (
             RequestException,
             RetriesExceeded,
             JSONDecodeError,
             ValueError,
             TypeError,
         ) as e:
-            logger.error(f"Error getting next query parameters due to {e}")
+            logger.error(f"Error getting response due to {e}")
+            response_json = None
+
+        # Build a list of records from the response
+        batch = self.get_batch_data(response_json)
+
+        # Optionally, apply some logic to the response to determine whether
+        # ingestion should continue or if it should be short-circuited. By default
+        # this will return True and ingestion continues.
+        should_continue = self.get_should_continue(response_json)
 
         return batch, should_continue
 
@@ -260,4 +403,5 @@ def commit_records(self) -> int:
         total = 0
         for store in self.media_stores.values():
             total += store.commit()
+        logger.info(f"Committed {total} records")
         return total
diff --git a/openverse_catalog/dags/providers/provider_dag_factory.py b/openverse_catalog/dags/providers/provider_dag_factory.py
index 3aeb8c59d..28cafc4ef 100644
--- a/openverse_catalog/dags/providers/provider_dag_factory.py
+++ b/openverse_catalog/dags/providers/provider_dag_factory.py
@@ -39,6 +39,17 @@
 previously downloaded data, and update any data that needs updating (eg.
 popularity metrics).
 
+Provider workflows which extend the ProviderDataIngester class support a few DagRun
+configuration variables:
+
+* `skip_ingestion_errors`: When set to true, errors encountered during ingestion will
+be caught to allow ingestion to continue. The `pull_data` task will still fail when
+ingestion is complete, and report a summary of all encountered errors. By default
+`skip_ingestion_errors` is False.
+* `initial_query_params`: An optional dict of query parameters with which to begin
+ingestion. This allows a user to manually force ingestion to resume from a particular
+batch, for example when retrying after an error.
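Alongside the two options documented above, the ingester changes in this patch also read a third key from `dag_run.conf`, `query_params_list`, which restricts ingestion to exactly the listed sets of query parameters. As a rough illustration of how these options fit together when triggering a manual DagRun, the sketch below uses placeholder parameter values and a placeholder DAG id; the real query keys depend on the provider API being ingested.

# Illustrative conf for a manual DagRun; only the top-level keys come from this
# patch, the parameter values and DAG id are placeholders.
conf = {
    "skip_ingestion_errors": True,          # defer per-batch errors to the end
    "initial_query_params": {"page": 5},    # resume from a specific batch, or...
    # "query_params_list": [{"page": 5}, {"page": 6}],  # ...run only these batches
}
# e.g. airflow dags trigger <provider_dag_id> --conf '{"skip_ingestion_errors": true}'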
+ You can find more background information on the loading process in the following issues and related PRs: diff --git a/tests/dags/providers/provider_api_scripts/test_provider_data_ingester.py b/tests/dags/providers/provider_api_scripts/test_provider_data_ingester.py index a3963c9e9..d59b36078 100644 --- a/tests/dags/providers/provider_api_scripts/test_provider_data_ingester.py +++ b/tests/dags/providers/provider_api_scripts/test_provider_data_ingester.py @@ -1,11 +1,14 @@ import json import os -import unittest from unittest.mock import call, patch import pytest +from airflow.exceptions import AirflowException from common.storage.audio import AudioStore, MockAudioStore from common.storage.image import ImageStore, MockImageStore +from providers.provider_api_scripts.provider_data_ingester import ( + AggregateIngestionError, +) from tests.dags.providers.provider_api_scripts.resources.provider_data_ingester.mock_provider_data_ingester import ( AUDIO_PROVIDER, @@ -20,223 +23,340 @@ os.path.abspath(os.path.dirname(__file__)), "resources/provider_data_ingester" ) +ingester = MockProviderDataIngester() +audio_store = MockAudioStore(AUDIO_PROVIDER) +image_store = MockImageStore(IMAGE_PROVIDER) +ingester.media_stores = {"audio": audio_store, "image": image_store} -class TestProviderDataIngester(unittest.TestCase): - def setUp(self): - self.ingester = MockProviderDataIngester() - # Use mock media stores - self.audio_store = MockAudioStore(AUDIO_PROVIDER) - self.image_store = MockImageStore(IMAGE_PROVIDER) - self.ingester.media_stores = { - "audio": self.audio_store, - "image": self.image_store, - } +def _get_resource_json(json_name): + with open(os.path.join(RESOURCES, json_name)) as f: + resource_json = json.load(f) + return resource_json - def _get_resource_json(self, json_name): - with open(os.path.join(RESOURCES, json_name)) as f: - resource_json = json.load(f) - return resource_json - def test_init_media_stores(self): - ingester = MockProviderDataIngester() +def test_init_media_stores(): + ingester = MockProviderDataIngester() + + # We should have two media stores, with the correct types + assert len(ingester.media_stores) == 2 + assert isinstance(ingester.media_stores["audio"], AudioStore) + assert isinstance(ingester.media_stores["image"], ImageStore) + + +def test_init_with_date(): + ingester = MockProviderDataIngester(date="2020-06-27") + assert ingester.date == "2020-06-27" + - # We should have two media stores, with the correct types - assert len(ingester.media_stores) == 2 - assert isinstance(ingester.media_stores["audio"], AudioStore) - assert isinstance(ingester.media_stores["image"], ImageStore) +def test_init_without_date(): + ingester = MockProviderDataIngester() + assert ingester.date is None - def test_init_with_date(self): - ingester = MockProviderDataIngester(date="2020-06-27") - assert ingester.date == "2020-06-27" - def test_init_without_date(self): +def test_batch_limit_is_capped_to_ingestion_limit(): + with patch( + "providers.provider_api_scripts.provider_data_ingester.Variable" + ) as MockVariable: + MockVariable.get.side_effect = [20] + ingester = MockProviderDataIngester() - assert ingester.date is None + assert ingester.batch_limit == 20 + assert ingester.limit == 20 - def test_batch_limit_is_capped_to_ingestion_limit(self): - with patch( - "providers.provider_api_scripts.provider_data_ingester.Variable" - ) as MockVariable: - MockVariable.get.side_effect = [20] - ingester = MockProviderDataIngester() - assert ingester.batch_limit == 20 - assert ingester.limit == 20 +def 
test_get_batch_data(): + response_json = _get_resource_json("complete_response.json") + batch = ingester.get_batch_data(response_json) - def test_get_batch_data(self): - response_json = self._get_resource_json("complete_response.json") - batch = self.ingester.get_batch_data(response_json) + assert batch == EXPECTED_BATCH_DATA - assert batch == EXPECTED_BATCH_DATA - def test_process_batch_adds_items_to_correct_media_stores(self): - with ( - patch.object(self.audio_store, "add_item") as audio_store_mock, - patch.object(self.image_store, "add_item") as image_store_mock, - ): - record_count = self.ingester.process_batch(EXPECTED_BATCH_DATA) +def test_process_batch_adds_items_to_correct_media_stores(): + with ( + patch.object(audio_store, "add_item") as audio_store_mock, + patch.object(image_store, "add_item") as image_store_mock, + ): + record_count = ingester.process_batch(EXPECTED_BATCH_DATA) - assert record_count == 3 - assert audio_store_mock.call_count == 1 - assert image_store_mock.call_count == 2 + assert record_count == 3 + assert audio_store_mock.call_count == 1 + assert image_store_mock.call_count == 2 - def test_process_batch_handles_list_of_records(self): - with ( - patch.object(self.audio_store, "add_item") as audio_store_mock, - patch.object(self.image_store, "add_item") as image_store_mock, - patch.object(self.ingester, "get_record_data") as get_record_data_mock, - ): - # Mock `get_record_data` to return a list of records - get_record_data_mock.return_value = MOCK_RECORD_DATA_LIST - record_count = self.ingester.process_batch(EXPECTED_BATCH_DATA[:1]) +def test_process_batch_handles_list_of_records(): + with ( + patch.object(audio_store, "add_item") as audio_store_mock, + patch.object(image_store, "add_item") as image_store_mock, + patch.object(ingester, "get_record_data") as get_record_data_mock, + ): + # Mock `get_record_data` to return a list of records + get_record_data_mock.return_value = MOCK_RECORD_DATA_LIST - # Both records are added, and to the appropriate stores - assert record_count == 2 - assert audio_store_mock.call_count == 1 - assert image_store_mock.call_count == 1 + record_count = ingester.process_batch(EXPECTED_BATCH_DATA[:1]) - def test_ingest_records(self): - with ( - patch.object(self.ingester, "get_batch") as get_batch_mock, - patch.object( - self.ingester, "process_batch", return_value=3 - ) as process_batch_mock, - patch.object(self.ingester, "commit_records") as commit_mock, - ): - get_batch_mock.side_effect = [ - (EXPECTED_BATCH_DATA, True), # First batch - (EXPECTED_BATCH_DATA, True), # Second batch - (None, True), # Final batch - ] + # Both records are added, and to the appropriate stores + assert record_count == 2 + assert audio_store_mock.call_count == 1 + assert image_store_mock.call_count == 1 - self.ingester.ingest_records() - # get_batch is not called again after getting None - assert get_batch_mock.call_count == 3 +def test_ingest_records(): + with ( + patch.object(ingester, "get_batch") as get_batch_mock, + patch.object(ingester, "process_batch", return_value=3) as process_batch_mock, + patch.object(ingester, "commit_records") as commit_mock, + ): + get_batch_mock.side_effect = [ + (EXPECTED_BATCH_DATA, True), # First batch + (EXPECTED_BATCH_DATA, True), # Second batch + (None, True), # Final batch + ] - # process_batch is called for each batch - process_batch_mock.assert_has_calls( - [ - call(EXPECTED_BATCH_DATA), - call(EXPECTED_BATCH_DATA), - ] - ) - # process_batch is not called for a third time with None - assert 
process_batch_mock.call_count == 2 + ingester.ingest_records() - assert commit_mock.called + # get_batch is not called again after getting None + assert get_batch_mock.call_count == 3 - def test_ingest_records_halts_ingestion_when_should_continue_is_false(self): - with ( - patch.object(self.ingester, "get_batch") as get_batch_mock, - patch.object( - self.ingester, "process_batch", return_value=3 - ) as process_batch_mock, - ): - get_batch_mock.side_effect = [ - (EXPECTED_BATCH_DATA, False), # First batch, should_continue is False + # process_batch is called for each batch + process_batch_mock.assert_has_calls( + [ + call(EXPECTED_BATCH_DATA), + call(EXPECTED_BATCH_DATA), ] + ) + # process_batch is not called for a third time with None + assert process_batch_mock.call_count == 2 - self.ingester.ingest_records() + assert commit_mock.called - # get_batch is not called a second time - assert get_batch_mock.call_count == 1 - assert process_batch_mock.call_count == 1 - process_batch_mock.assert_has_calls( - [ - call(EXPECTED_BATCH_DATA), - ] - ) +def test_ingest_records_halts_ingestion_when_should_continue_is_false(): + with ( + patch.object(ingester, "get_batch") as get_batch_mock, + patch.object(ingester, "process_batch", return_value=3) as process_batch_mock, + ): + get_batch_mock.side_effect = [ + (EXPECTED_BATCH_DATA, False), # First batch, should_continue is False + ] - def test_ingest_records_does_not_process_empty_batch(self): - with ( - patch.object(self.ingester, "get_batch") as get_batch_mock, - patch.object( - self.ingester, "process_batch", return_value=3 - ) as process_batch_mock, - ): - get_batch_mock.side_effect = [ - ([], True), # Empty batch + ingester.ingest_records() + + # get_batch is not called a second time + assert get_batch_mock.call_count == 1 + + assert process_batch_mock.call_count == 1 + process_batch_mock.assert_has_calls( + [ + call(EXPECTED_BATCH_DATA), ] + ) - self.ingester.ingest_records() - # get_batch is not called a second time - assert get_batch_mock.call_count == 1 - # process_batch is not called with an empty batch - assert not process_batch_mock.called - - def test_ingest_records_stops_after_reaching_limit(self): - # Set the ingestion limit for the test to one batch - with patch( - "providers.provider_api_scripts.provider_data_ingester.Variable" - ) as MockVariable: - # Mock the calls to Variable.get, in order - MockVariable.get.side_effect = [3] - - ingester = MockProviderDataIngester() - - with ( - patch.object(ingester, "get_batch") as get_batch_mock, - patch.object( - ingester, "process_batch", return_value=3 - ) as process_batch_mock, - ): - get_batch_mock.side_effect = [ - (EXPECTED_BATCH_DATA, True), # First batch - (EXPECTED_BATCH_DATA, True), # Second batch - (None, True), # Final batch - ] - - ingester.ingest_records() - - # get_batch is not called again after the first batch - assert get_batch_mock.call_count == 1 - assert process_batch_mock.call_count == 1 - - def test_ingest_records_commits_on_exception(self): +def test_ingest_records_does_not_process_empty_batch(): + with ( + patch.object(ingester, "get_batch") as get_batch_mock, + patch.object(ingester, "process_batch", return_value=3) as process_batch_mock, + ): + get_batch_mock.side_effect = [ + ([], True), # Empty batch + ] + + ingester.ingest_records() + + # get_batch is not called a second time + assert get_batch_mock.call_count == 1 + # process_batch is not called with an empty batch + assert not process_batch_mock.called + + +def test_ingest_records_stops_after_reaching_limit(): + # 
Set the ingestion limit for the test to one batch + with patch( + "providers.provider_api_scripts.provider_data_ingester.Variable" + ) as MockVariable: + # Mock the calls to Variable.get, in order + MockVariable.get.side_effect = [3] + + ingester = MockProviderDataIngester() + with ( - patch.object(self.ingester, "get_batch") as get_batch_mock, + patch.object(ingester, "get_batch") as get_batch_mock, patch.object( - self.ingester, "process_batch", return_value=3 + ingester, "process_batch", return_value=3 ) as process_batch_mock, - patch.object(self.ingester, "commit_records") as commit_mock, ): get_batch_mock.side_effect = [ (EXPECTED_BATCH_DATA, True), # First batch (EXPECTED_BATCH_DATA, True), # Second batch - ValueError("Whoops :C"), # Problem batch - (EXPECTED_BATCH_DATA, True), # Fourth batch, should not be reached + (None, True), # Final batch ] - with pytest.raises(ValueError, match="Whoops :C"): - self.ingester.ingest_records() + ingester.ingest_records() - # Check that get batch was only called thrice - assert get_batch_mock.call_count == 3 - - # process_batch is called for each successful batch - process_batch_mock.assert_has_calls( - [ - call(EXPECTED_BATCH_DATA), - call(EXPECTED_BATCH_DATA), - ] - ) - # process_batch is not called for a third time with exception - assert process_batch_mock.call_count == 2 + # get_batch is not called again after the first batch + assert get_batch_mock.call_count == 1 + assert process_batch_mock.call_count == 1 - # Even with the exception, records were still saved - assert commit_mock.called - def test_commit_commits_all_stores(self): - with ( - patch.object(self.audio_store, "commit") as audio_store_mock, - patch.object(self.image_store, "commit") as image_store_mock, - ): - self.ingester.commit_records() - - assert audio_store_mock.called - assert image_store_mock.called +def test_ingest_records_commits_on_exception(): + with ( + patch.object(ingester, "get_batch") as get_batch_mock, + patch.object(ingester, "process_batch", return_value=3) as process_batch_mock, + patch.object(ingester, "commit_records") as commit_mock, + ): + get_batch_mock.side_effect = [ + (EXPECTED_BATCH_DATA, True), # First batch + (EXPECTED_BATCH_DATA, True), # Second batch + ValueError("Whoops :C"), # Problem batch + (EXPECTED_BATCH_DATA, True), # Fourth batch, should not be reached + ] + + with pytest.raises(ValueError, match="Whoops :C"): + ingester.ingest_records() + + # Check that get batch was only called thrice + assert get_batch_mock.call_count == 3 + + # process_batch is called for each successful batch + process_batch_mock.assert_has_calls( + [ + call(EXPECTED_BATCH_DATA), + call(EXPECTED_BATCH_DATA), + ] + ) + # process_batch is not called for a third time with exception + assert process_batch_mock.call_count == 2 + + # Even with the exception, records were still saved + assert commit_mock.called + + +def test_ingest_records_uses_initial_query_params_from_dagrun_conf(): + # Initialize the ingester with a conf + ingester = MockProviderDataIngester( + {"initial_query_params": {"has_image": 1, "page": 5}} + ) + + # Mock get_batch to halt ingestion after a single batch + with ( + patch.object(ingester, "get_batch", return_value=([], False)) as get_batch_mock, + ): + ingester.ingest_records() + + # get_batch is called with the query_params supplied in the conf + get_batch_mock.assert_called_with({"has_image": 1, "page": 5}) + + +def test_ingest_records_uses_query_params_list_from_dagrun_conf(): + # Initialize the ingester with a conf + ingester = 
MockProviderDataIngester( + { + "query_params_list": [ + {"has_image": 1, "page": 5}, + {"has_image": 1, "page": 12}, + {"has_image": 1, "page": 142}, + ] + } + ) + + with ( + patch.object( + ingester, "get_batch", return_value=(EXPECTED_BATCH_DATA, True) + ) as get_batch_mock, + patch.object(ingester, "process_batch", return_value=3), + ): + ingester.ingest_records() + + # get_batch is called only three times, for each set of query_params + # in the list, even though `should_continue` is still True + assert get_batch_mock.call_count == 3 + get_batch_mock.assert_has_calls( + [ + call({"has_image": 1, "page": 5}), + call({"has_image": 1, "page": 12}), + call({"has_image": 1, "page": 142}), + ] + ) + + +def test_ingest_records_raises_IngestionError(): + with (patch.object(ingester, "get_batch") as get_batch_mock,): + get_batch_mock.side_effect = [ + Exception("Mock exception message"), + (EXPECTED_BATCH_DATA, True), # Second batch should not be reached + ] + + with pytest.raises(Exception, match="Mock exception message"): + ingester.ingest_records() + + # By default, `skip_ingestion_errors` is False and get_batch_data + # is no longer called after encountering an error + assert get_batch_mock.call_count == 1 + + +@pytest.mark.parametrize( + "batches, expected_call_count, expected_error", + [ + # Multiple errors are skipped + ( + [ + Exception("Mock exception 1"), + (EXPECTED_BATCH_DATA, True), # First error + Exception("Mock exception 2"), + (EXPECTED_BATCH_DATA, False), # Second error, `should_continue` False + ], + 4, # get_batch is called until `should_continue` is False, ignoring errors + AggregateIngestionError, + ), + # An AirflowException should not be skipped + ( + [ + (EXPECTED_BATCH_DATA, True), + AirflowException("An Airflow exception"), # Second batch, should raise + (EXPECTED_BATCH_DATA, True), # This batch should not be reached + ], + 2, # The final batch should not be reached + AirflowException, + ), + # An AirflowException is raised, but there were already other ingestion errors + ( + [ + Exception("Some other exception"), # First batch, should be skipped + AirflowException("An Airflow exception"), # Second batch, should raise + (EXPECTED_BATCH_DATA, True), # This batch should not be reached + ], + 2, # The final batch should not be reached + AggregateIngestionError, # Ingestion errors reported + ), + ], +) +def test_ingest_records_with_skip_ingestion_errors( + batches, expected_call_count, expected_error +): + ingester = MockProviderDataIngester({"skip_ingestion_errors": True}) + + with ( + patch.object(ingester, "get_batch") as get_batch_mock, + patch.object(ingester, "process_batch", return_value=10), + ): + get_batch_mock.side_effect = batches + + # ingest_records ultimately raises an exception + with pytest.raises(expected_error): + ingester.ingest_records() + + # get_batch was called four times before the exception was thrown, + # despite errors being raised + assert get_batch_mock.call_count == expected_call_count + + +def test_commit_commits_all_stores(): + with ( + patch.object(audio_store, "commit") as audio_store_mock, + patch.object(image_store, "commit") as image_store_mock, + ): + ingester.commit_records() + + assert audio_store_mock.called + assert image_store_mock.called diff --git a/tests/dags/providers/test_factory_utils.py b/tests/dags/providers/test_factory_utils.py index ab33c08c4..7a2521fba 100644 --- a/tests/dags/providers/test_factory_utils.py +++ b/tests/dags/providers/test_factory_utils.py @@ -2,7 +2,7 @@ import pytest import requests -from 
airflow.models import TaskInstance +from airflow.models import DagRun, TaskInstance from pendulum import datetime from providers import factory_utils @@ -17,6 +17,11 @@ def ti_mock() -> TaskInstance: return mock.MagicMock(spec=TaskInstance) +@pytest.fixture +def dagrun_mock() -> DagRun: + return mock.MagicMock(spec=DagRun) + + @pytest.fixture def internal_func_mock(): """ @@ -29,7 +34,7 @@ def internal_func_mock(): fdi = FakeDataIngester() -def _set_up_ingester(mock_func, value): +def _set_up_ingester(mock_conf, mock_func, value): """ Set up ingest records as a proxy for calling the mock function, then return the instance. This is necessary because the args are only handed in during @@ -111,12 +116,15 @@ def test_load_provider_script(func, media_types, stores): (FakeDataIngesterClass, 2, list(fdi.media_stores.values())), ], ) -def test_generate_tsv_filenames(func, media_types, stores, ti_mock, internal_func_mock): +def test_generate_tsv_filenames( + func, media_types, stores, ti_mock, dagrun_mock, internal_func_mock +): value = 42 factory_utils.generate_tsv_filenames( func, media_types, ti_mock, + dagrun_mock, args=[internal_func_mock, value], ) # There should be one call to xcom_push for each provided store @@ -189,7 +197,7 @@ def test_generate_tsv_filenames(func, media_types, stores, ti_mock, internal_fun ], ) def test_pull_media_wrapper( - func, media_types, tsv_filenames, stores, ti_mock, internal_func_mock + func, media_types, tsv_filenames, stores, ti_mock, dagrun_mock, internal_func_mock ): value = 42 factory_utils.pull_media_wrapper( @@ -197,6 +205,7 @@ def test_pull_media_wrapper( media_types, tsv_filenames, ti_mock, + dagrun_mock, args=[internal_func_mock, value], ) # We should have one XCom push for duration
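As background for the tests above, the precedence that the new `get_query_params` method applies is: `initial_query_params` for the first batch only, then the `query_params_list` generator if one was supplied, and otherwise the provider's own `get_next_query_params`. A stripped-down sketch of that precedence, with the class machinery and logging omitted (an illustration, not the patch code):

from typing import Callable, Iterator, Optional


def next_params(
    prev: Optional[dict],
    initial_query_params: Optional[dict],
    override_query_params: Optional[Iterator[dict]],
    get_next_query_params: Callable[[Optional[dict]], dict],
) -> Optional[dict]:
    # First batch only: honor `initial_query_params` from the DagRun conf.
    if prev is None and initial_query_params:
        return initial_query_params
    # A `query_params_list` was supplied: walk it; returning None ends ingestion.
    if override_query_params is not None:
        return next(override_query_params, None)
    # Default: let the provider derive the next params from the previous ones.
    return get_next_query_params(prev)

This is why `test_ingest_records_uses_query_params_list_from_dagrun_conf` expects exactly one `get_batch` call per entry in the list, even though `should_continue` is still True.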
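The parametrized `skip_ingestion_errors` cases reduce to a collect-then-raise pattern: failed batches are recorded, ingestion continues, and a single aggregate error is raised once the loop ends. The real `ingest_records` additionally commits whatever has been stored and always re-raises `AirflowException` immediately; the minimal sketch below (an assumed simplification, not the patch code) shows only the deferral logic.

import traceback


def ingest_all(batches, process_batch, skip_ingestion_errors=False):
    """Process each (query_params, batch) pair, optionally deferring errors."""
    skipped = []  # (query_params, formatted traceback) for every skipped batch
    for query_params, batch in batches:
        try:
            process_batch(batch)
        except Exception:
            if not skip_ingestion_errors:
                raise  # default behaviour: fail on the first bad batch
            # Record the failure and keep going with the remaining batches.
            skipped.append((query_params, traceback.format_exc()))
    if skipped:
        # Surface the deferred failures so the task still fails overall.
        raise RuntimeError(f"{len(skipped)} batches were skipped due to errors")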