diff --git a/catalog/DAGs.md b/catalog/DAGs.md
index 2405915d043..f26d87939b2 100644
--- a/catalog/DAGs.md
+++ b/catalog/DAGs.md
@@ -743,10 +743,13 @@ https://www.rawpixel.com/api/v1/search?tags=$publicdomain&page=1&pagesize=100
 ## `recreate_audio_popularity_calculation`
 
 This file generates Apache Airflow DAGs that, for the given media type,
-completely wipe out the PostgreSQL relations and functions involved in
-calculating our standardized popularity metric. It then recreates relations and
-functions to make the calculation, and performs an initial calculation. The
-results are available in the materialized view for that media type.
+completely wipe out and recreate the PostgreSQL functions involved in
+calculating our standardized popularity metric.
+
+Note that these DAGs do not drop any tables or views related to popularity,
+nor do they perform any popularity calculations. Once one of these DAGs has run,
+the associated popularity refresh DAG must be run to actually recalculate
+popularity constants and standardized popularity scores using the new functions.
 
 These DAGs are not on a schedule, and should only be run manually when new SQL
 code is deployed for the calculation.
@@ -754,10 +757,13 @@ code is deployed for the calculation.
 ## `recreate_image_popularity_calculation`
 
 This file generates Apache Airflow DAGs that, for the given media type,
-completely wipe out the PostgreSQL relations and functions involved in
-calculating our standardized popularity metric. It then recreates relations and
-functions to make the calculation, and performs an initial calculation. The
-results are available in the materialized view for that media type.
+completely wipe out and recreate the PostgreSQL functions involved in
+calculating our standardized popularity metric.
+
+Note that these DAGs do not drop any tables or views related to popularity,
+nor do they perform any popularity calculations. Once one of these DAGs has run,
+the associated popularity refresh DAG must be run to actually recalculate
+popularity constants and standardized popularity scores using the new functions.
 
 These DAGs are not on a schedule, and should only be run manually when new SQL
 code is deployed for the calculation.
diff --git a/catalog/dags/common/constants.py b/catalog/dags/common/constants.py
index 5ec93f6e9e6..f660d49abd4 100644
--- a/catalog/dags/common/constants.py
+++ b/catalog/dags/common/constants.py
@@ -1,4 +1,5 @@
 import os
+from dataclasses import dataclass
 from datetime import datetime, timedelta
 from typing import Literal
 
@@ -36,3 +37,41 @@
 AWS_RDS_CONN_ID = os.environ.get("AWS_RDS_CONN_ID", AWS_CONN_ID)
 ES_PROD_HTTP_CONN_ID = "elasticsearch_http_production"
 REFRESH_POKE_INTERVAL = int(os.getenv("DATA_REFRESH_POKE_INTERVAL", 60 * 30))
+
+
+@dataclass
+class SQLInfo:
+    """
+    Configuration object for a media type's popularity SQL info.
+ + Required Constructor Arguments: + + media_table: name of the main media table + metrics_table: name of the popularity metrics table + standardized_popularity_fn: name of the standardized_popularity sql + function + popularity_percentile_fn: name of the popularity percentile sql + function + + """ + + media_table: str + metrics_table: str + standardized_popularity_fn: str + popularity_percentile_fn: str + + +SQL_INFO_BY_MEDIA_TYPE = { + AUDIO: SQLInfo( + media_table=AUDIO, + metrics_table="audio_popularity_metrics", + standardized_popularity_fn="standardized_audio_popularity", + popularity_percentile_fn="audio_popularity_percentile", + ), + IMAGE: SQLInfo( + media_table=IMAGE, + metrics_table="image_popularity_metrics", + standardized_popularity_fn="standardized_image_popularity", + popularity_percentile_fn="image_popularity_percentile", + ), +} diff --git a/catalog/dags/common/loader/sql.py b/catalog/dags/common/loader/sql.py index 1e6089fbdb0..9a5fd33dbf9 100644 --- a/catalog/dags/common/loader/sql.py +++ b/catalog/dags/common/loader/sql.py @@ -4,29 +4,24 @@ from airflow.models.abstractoperator import AbstractOperator from psycopg2.errors import InvalidTextRepresentation -from common.constants import AUDIO, IMAGE, MediaType +from common.constants import IMAGE, MediaType, SQLInfo from common.loader import provider_details as prov from common.loader.paths import _extract_media_type -from common.popularity.constants import ( - STANDARDIZED_AUDIO_POPULARITY_FUNCTION, - STANDARDIZED_IMAGE_POPULARITY_FUNCTION, -) from common.sql import PostgresHook from common.storage import columns as col from common.storage.columns import NULL, Column, UpsertStrategy -from common.storage.db_columns import AUDIO_TABLE_COLUMNS, IMAGE_TABLE_COLUMNS +from common.storage.db_columns import setup_db_columns_for_media_type from common.storage.tsv_columns import ( COLUMNS, - CURRENT_AUDIO_TSV_COLUMNS, - CURRENT_IMAGE_TSV_COLUMNS, - required_columns, + REQUIRED_COLUMNS, + setup_tsv_columns_for_media_type, ) +from common.utils import setup_sql_info_for_media_type logger = logging.getLogger(__name__) LOAD_TABLE_NAME_STUB = "load_" -TABLE_NAMES = {AUDIO: AUDIO, IMAGE: IMAGE} DB_USER_NAME = "deploy" NOW = "NOW()" FALSE = "'f'" @@ -44,14 +39,6 @@ prov.SMK_DEFAULT_PROVIDER: "1 month 3 days", } -DB_COLUMNS = { - IMAGE: IMAGE_TABLE_COLUMNS, - AUDIO: AUDIO_TABLE_COLUMNS, -} -TSV_COLUMNS = { - AUDIO: CURRENT_AUDIO_TSV_COLUMNS, - IMAGE: CURRENT_IMAGE_TSV_COLUMNS, -} CURRENT_TSV_VERSION = "001" RETURN_ROW_COUNT = lambda c: c.rowcount # noqa: E731 @@ -67,10 +54,13 @@ def create_column_definitions(table_columns: list[Column], is_loading=True): return ",\n ".join(definitions) +@setup_tsv_columns_for_media_type def create_loading_table( postgres_conn_id: str, identifier: str, - media_type: str = IMAGE, + *, + media_type: str, + tsv_columns: list[Column], ): """Create intermediary table and indices if they do not exist.""" load_table = _get_load_table_name(identifier, media_type=media_type) @@ -78,8 +68,7 @@ def create_loading_table( postgres_conn_id=postgres_conn_id, default_statement_timeout=10.0, ) - loading_table_columns = TSV_COLUMNS[media_type] - columns_definition = f"{create_column_definitions(loading_table_columns)}" + columns_definition = f"{create_column_definitions(tsv_columns)}" table_creation_query = dedent( f""" CREATE UNLOGGED TABLE public.{load_table}( @@ -216,7 +205,7 @@ def clean_intermediate_table_data( ) missing_columns = 0 - for column in required_columns: + for column in REQUIRED_COLUMNS: missing_columns += 
postgres.run( f"DELETE FROM {load_table} WHERE {column.db_name} IS NULL;", handler=RETURN_ROW_COUNT, @@ -268,13 +257,16 @@ def _is_tsv_column_from_different_version( ) +@setup_sql_info_for_media_type +@setup_db_columns_for_media_type def upsert_records_to_db_table( postgres_conn_id: str, identifier: str, - db_table: str = None, - media_type: str = IMAGE, + *, + media_type: str, + db_columns: list[Column], + sql_info: SQLInfo, tsv_version: str = CURRENT_TSV_VERSION, - popularity_function: str = STANDARDIZED_IMAGE_POPULARITY_FUNCTION, task: AbstractOperator = None, ): """ @@ -285,35 +277,28 @@ def upsert_records_to_db_table( :param postgres_conn_id :param identifier - :param db_table :param media_type :param tsv_version: The version of TSV being processed. This determines which columns are used in the upsert query. :param task To be automagically passed by airflow. :return: """ - if db_table is None: - db_table = TABLE_NAMES.get(media_type, TABLE_NAMES[IMAGE]) - - if media_type is AUDIO: - popularity_function = STANDARDIZED_AUDIO_POPULARITY_FUNCTION - load_table = _get_load_table_name(identifier, media_type=media_type) - logger.info(f"Upserting new records into {db_table}.") + logger.info(f"Upserting new records into {sql_info.media_table}.") postgres = PostgresHook( postgres_conn_id=postgres_conn_id, default_statement_timeout=PostgresHook.get_execution_timeout(task), ) # Remove identifier column - db_columns: list[Column] = DB_COLUMNS[media_type][1:] + db_columns = db_columns[1:] column_inserts = {} column_conflict_values = {} for column in db_columns: args = [] if column.db_name == col.STANDARDIZED_POPULARITY.db_name: args = [ - popularity_function, + sql_info.standardized_popularity_fn, ] if column.upsert_strategy == UpsertStrategy.no_change: @@ -331,13 +316,13 @@ def upsert_records_to_db_table( upsert_conflict_string = ",\n ".join(column_conflict_values.values()) upsert_query = dedent( f""" - INSERT INTO {db_table} AS old + INSERT INTO {sql_info.media_table} AS old ({col.DIRECT_URL.name}, {', '.join(column_inserts.keys())}) SELECT DISTINCT ON ({col.DIRECT_URL.name}) {col.DIRECT_URL.name}, {', '.join(column_inserts.values())} FROM {load_table} as new WHERE NOT EXISTS ( - SELECT {col.DIRECT_URL.name} from {db_table} + SELECT {col.DIRECT_URL.name} from {sql_info.media_table} WHERE {col.DIRECT_URL.name} = new.{col.DIRECT_URL.name} AND MD5({col.FOREIGN_ID.name}) <> MD5(new.{col.FOREIGN_ID.name}) ) diff --git a/catalog/dags/common/popularity/README.md b/catalog/dags/common/popularity/README.md deleted file mode 100644 index 2209225cac0..00000000000 --- a/catalog/dags/common/popularity/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# Popularity - -This code allows for the calculation of image popularity within a provider. For -example, this allows us to boost Flickr results that have more views than -others. - -## What this code does - -1. Dump the popularity metrics for each row into a TSV. -2. Compute the 85th percentile for each metric, which is required for the - popularity calculation. This is a heavyweight database calculation, so we - cache it for a really long time. -3. Iterate through the TSV calculating the popularity for each row. -4. UPDATE all rows, setting the `normalized_popularity` key in the `meta_data` - column. - -## To start calculating popularity data for a provider - -1. In your provider script, store the popularity metric you'd like to track in - the `meta_data` column. 
See - [stocksnap](https://github.com/WordPress/openverse-catalog/blob/6c172033e42a91bcd8f9bf78fd6b933a70bd88bf/openverse_catalog/dags/provider_api_scripts/stocksnap.py#L175-L185) - as an example. -2. Add the provider name and metric to the `IMAGE_POPULARITY_METRICS` constant - in the [popularity/sql.py](sql.py) file. -3. Add the new provider and metric to the - `INSERT INTO public.image_popularity_metrics` statement in - [0004_openledger_image_view.sql](../../../../docker/upstream_db/0004_openledger_image_view.sql#L45). - For now all percentiles should be set to `.85`, this may be adjusted in the - future. diff --git a/catalog/dags/common/popularity/constants.py b/catalog/dags/common/popularity/constants.py deleted file mode 100644 index c2588840e44..00000000000 --- a/catalog/dags/common/popularity/constants.py +++ /dev/null @@ -1,7 +0,0 @@ -IMAGE_VIEW_NAME = "image_view" -AUDIO_VIEW_NAME = "audio_view" -AUDIOSET_VIEW_NAME = "audioset_view" -IMAGE_POPULARITY_PERCENTILE_FUNCTION = "image_popularity_percentile" -AUDIO_POPULARITY_PERCENTILE_FUNCTION = "audio_popularity_percentile" -STANDARDIZED_IMAGE_POPULARITY_FUNCTION = "standardized_image_popularity" -STANDARDIZED_AUDIO_POPULARITY_FUNCTION = "standardized_audio_popularity" diff --git a/catalog/dags/common/popularity/sql.py b/catalog/dags/common/popularity/sql.py deleted file mode 100644 index 939a9f65f95..00000000000 --- a/catalog/dags/common/popularity/sql.py +++ /dev/null @@ -1,507 +0,0 @@ -from collections import namedtuple -from datetime import timedelta -from textwrap import dedent - -from airflow.decorators import task, task_group -from airflow.models.abstractoperator import AbstractOperator - -from common.constants import AUDIO, DAG_DEFAULT_ARGS, IMAGE -from common.loader.sql import TABLE_NAMES -from common.popularity.constants import ( - AUDIO_POPULARITY_PERCENTILE_FUNCTION, - AUDIO_VIEW_NAME, - IMAGE_POPULARITY_PERCENTILE_FUNCTION, - IMAGE_VIEW_NAME, - STANDARDIZED_AUDIO_POPULARITY_FUNCTION, - STANDARDIZED_IMAGE_POPULARITY_FUNCTION, -) -from common.sql import PostgresHook, _single_value -from common.storage import columns as col -from common.storage.db_columns import AUDIO_TABLE_COLUMNS, IMAGE_TABLE_COLUMNS - - -DEFAULT_PERCENTILE = 0.85 - -IMAGE_VIEW_ID_IDX = "image_view_identifier_idx" -AUDIO_VIEW_ID_IDX = "audio_view_identifier_idx" -IMAGE_VIEW_PROVIDER_FID_IDX = "image_view_provider_fid_idx" -AUDIO_VIEW_PROVIDER_FID_IDX = "audio_view_provider_fid_idx" - -# Column name constants -VALUE = "val" -CONSTANT = "constant" -FID = col.FOREIGN_ID.db_name -IDENTIFIER = col.IDENTIFIER.db_name -METADATA_COLUMN = col.META_DATA.db_name -METRIC = "metric" -PARTITION = col.PROVIDER.db_name -PERCENTILE = "percentile" -PROVIDER = col.PROVIDER.db_name - -Column = namedtuple("Column", ["name", "definition"]) - -IMAGE_POPULARITY_METRICS_TABLE_NAME = "image_popularity_metrics" -AUDIO_POPULARITY_METRICS_TABLE_NAME = "audio_popularity_metrics" - -IMAGE_POPULARITY_METRICS = { - "flickr": {"metric": "views"}, - "nappy": {"metric": "downloads"}, - "rawpixel": {"metric": "download_count"}, - "stocksnap": {"metric": "downloads_raw"}, - "wikimedia": {"metric": "global_usage_count"}, -} - -AUDIO_POPULARITY_METRICS = { - "jamendo": {"metric": "listens"}, - "wikimedia_audio": {"metric": "global_usage_count"}, - "freesound": {"metric": "num_downloads"}, -} - -POPULARITY_METRICS_TABLE_COLUMNS = [ - Column(name=PARTITION, definition="character varying(80) PRIMARY KEY"), - Column(name=METRIC, definition="character varying(80)"), - Column(name=PERCENTILE, 
definition="float"), - Column(name=VALUE, definition="float"), - Column(name=CONSTANT, definition="float"), -] - -# Further refactoring of this nature will be done in -# https://github.com/WordPress/openverse/issues/2678. -POPULARITY_METRICS_BY_MEDIA_TYPE = { - AUDIO: AUDIO_POPULARITY_METRICS, - IMAGE: IMAGE_POPULARITY_METRICS, -} - - -def drop_media_matview( - postgres_conn_id: str, - media_type: str = IMAGE, - db_view: str = IMAGE_VIEW_NAME, - pg_timeout: float = timedelta(minutes=10).total_seconds(), -): - if media_type == AUDIO: - db_view = AUDIO_VIEW_NAME - - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, default_statement_timeout=pg_timeout - ) - postgres.run(f"DROP MATERIALIZED VIEW IF EXISTS public.{db_view} CASCADE;") - - -def drop_media_popularity_relations( - postgres_conn_id, - media_type=IMAGE, - db_view=IMAGE_VIEW_NAME, - metrics=IMAGE_POPULARITY_METRICS_TABLE_NAME, - pg_timeout: float = timedelta(minutes=10).total_seconds(), -): - if media_type == AUDIO: - db_view = AUDIO_VIEW_NAME - metrics = AUDIO_POPULARITY_METRICS_TABLE_NAME - - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, default_statement_timeout=pg_timeout - ) - drop_media_view = f"DROP MATERIALIZED VIEW IF EXISTS public.{db_view} CASCADE;" - drop_popularity_metrics = f"DROP TABLE IF EXISTS public.{metrics} CASCADE;" - postgres.run(drop_media_view) - postgres.run(drop_popularity_metrics) - - -def drop_media_popularity_functions( - postgres_conn_id, - media_type=IMAGE, - standardized_popularity=STANDARDIZED_IMAGE_POPULARITY_FUNCTION, - popularity_percentile=IMAGE_POPULARITY_PERCENTILE_FUNCTION, -): - if media_type == AUDIO: - popularity_percentile = AUDIO_POPULARITY_PERCENTILE_FUNCTION - standardized_popularity = STANDARDIZED_AUDIO_POPULARITY_FUNCTION - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, default_statement_timeout=10.0 - ) - drop_standardized_popularity = ( - f"DROP FUNCTION IF EXISTS public.{standardized_popularity} CASCADE;" - ) - drop_popularity_percentile = ( - f"DROP FUNCTION IF EXISTS public.{popularity_percentile} CASCADE;" - ) - postgres.run(drop_standardized_popularity) - postgres.run(drop_popularity_percentile) - - -def create_media_popularity_metrics( - postgres_conn_id, - media_type=IMAGE, - popularity_metrics_table=IMAGE_POPULARITY_METRICS_TABLE_NAME, -): - if media_type == AUDIO: - popularity_metrics_table = AUDIO_POPULARITY_METRICS_TABLE_NAME - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, default_statement_timeout=10.0 - ) - popularity_metrics_columns_string = ",\n ".join( - f"{c.name} {c.definition}" for c in POPULARITY_METRICS_TABLE_COLUMNS - ) - query = dedent( - f""" - CREATE TABLE public.{popularity_metrics_table} ( - {popularity_metrics_columns_string} - ); - """ - ) - postgres.run(query) - - -@task -def update_media_popularity_metrics( - postgres_conn_id, - media_type=IMAGE, - popularity_metrics=None, - popularity_metrics_table=IMAGE_POPULARITY_METRICS_TABLE_NAME, - popularity_percentile=IMAGE_POPULARITY_PERCENTILE_FUNCTION, - task: AbstractOperator = None, -): - if popularity_metrics is None: - popularity_metrics = POPULARITY_METRICS_BY_MEDIA_TYPE[media_type] - if media_type == AUDIO: - popularity_metrics_table = AUDIO_POPULARITY_METRICS_TABLE_NAME - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, - default_statement_timeout=PostgresHook.get_execution_timeout(task), - ) - - column_names = [c.name for c in POPULARITY_METRICS_TABLE_COLUMNS] - - # Note that we do not update the val and constant. 
That is only done during the - # calculation tasks. In other words, we never want to clear out the current value of - # the popularity constant unless we're already done calculating the new one, since - # that can be a time consuming process. - updates_string = ",\n ".join( - f"{c}=EXCLUDED.{c}" - for c in column_names - if c not in [PARTITION, CONSTANT, VALUE] - ) - popularity_metric_inserts = _get_popularity_metric_insert_values_string( - popularity_metrics - ) - - query = dedent( - f""" - INSERT INTO public.{popularity_metrics_table} ( - {', '.join(column_names)} - ) VALUES - {popularity_metric_inserts} - ON CONFLICT ({PARTITION}) - DO UPDATE SET - {updates_string} - ; - """ - ) - return postgres.run(query) - - -@task -def calculate_media_popularity_percentile_value( - postgres_conn_id, - provider, - media_type=IMAGE, - popularity_metrics_table=IMAGE_POPULARITY_METRICS_TABLE_NAME, - popularity_percentile=IMAGE_POPULARITY_PERCENTILE_FUNCTION, - task: AbstractOperator = None, -): - if media_type == AUDIO: - popularity_metrics_table = AUDIO_POPULARITY_METRICS_TABLE_NAME - popularity_percentile = AUDIO_POPULARITY_PERCENTILE_FUNCTION - - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, - default_statement_timeout=PostgresHook.get_execution_timeout(task), - ) - - # Calculate the percentile value. E.g. if `percentile` = 0.80, then we'll - # calculate the _value_ of the 80th percentile for this provider's - # popularity metric. - calculate_new_percentile_value_query = dedent( - f""" - SELECT {popularity_percentile}({PARTITION}, {METRIC}, {PERCENTILE}) - FROM {popularity_metrics_table} - WHERE {col.PROVIDER.db_name}='{provider}'; - """ - ) - - return postgres.run(calculate_new_percentile_value_query, handler=_single_value) - - -@task -def update_percentile_and_constants_values_for_provider( - postgres_conn_id, - provider, - raw_percentile_value, - media_type=IMAGE, - popularity_metrics=None, - popularity_metrics_table=IMAGE_POPULARITY_METRICS_TABLE_NAME, - task: AbstractOperator = None, -): - if popularity_metrics is None: - popularity_metrics = POPULARITY_METRICS_BY_MEDIA_TYPE.get(media_type, {}) - if media_type == AUDIO: - popularity_metrics_table = AUDIO_POPULARITY_METRICS_TABLE_NAME - - if raw_percentile_value is None: - # Occurs when a provider has a metric configured, but there are no records - # with any data for that metric. 
- return - - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, - default_statement_timeout=PostgresHook.get_execution_timeout(task), - ) - - provider_info = popularity_metrics.get(provider) - percentile = provider_info.get("percentile", DEFAULT_PERCENTILE) - - # Calculate the popularity constant using the percentile value - percentile_value = raw_percentile_value or 1 - new_constant = ((1 - percentile) / (percentile)) * percentile_value - - # Update the percentile value and constant in the metrics table - update_constant_query = dedent( - f""" - UPDATE public.{popularity_metrics_table} - SET {VALUE} = {percentile_value}, {CONSTANT} = {new_constant} - WHERE {col.PROVIDER.db_name} = '{provider}'; - """ - ) - return postgres.run(update_constant_query) - - -@task_group -def update_percentile_and_constants_for_provider( - postgres_conn_id, provider, media_type=IMAGE, execution_timeout=None -): - calculate_percentile_val = calculate_media_popularity_percentile_value.override( - task_id="calculate_percentile_value", - execution_timeout=execution_timeout - or DAG_DEFAULT_ARGS.get("execution_timeout"), - )( - postgres_conn_id=postgres_conn_id, - provider=provider, - media_type=media_type, - ) - calculate_percentile_val.doc = ( - "Calculate the percentile popularity value for this provider. For" - " example, if this provider has `percentile`=0.80 and `metric`='views'," - " calculate the 80th percentile value of views for all records for this" - " provider." - ) - - update_metrics_table = update_percentile_and_constants_values_for_provider.override( - task_id="update_percentile_values_and_constant", - )( - postgres_conn_id=postgres_conn_id, - provider=provider, - raw_percentile_value=calculate_percentile_val, - media_type=media_type, - ) - update_metrics_table.doc = ( - "Given the newly calculated percentile value, calculate the" - " popularity constant and update the metrics table with the newly" - " calculated values." 
- ) - - -def _get_popularity_metric_insert_values_string( - popularity_metrics, - default_percentile=DEFAULT_PERCENTILE, -): - return ",\n ".join( - _format_popularity_metric_insert_tuple_string( - provider, - provider_info["metric"], - provider_info.get("percentile", default_percentile), - ) - for provider, provider_info in popularity_metrics.items() - ) - - -def _format_popularity_metric_insert_tuple_string( - provider, - metric, - percentile, -): - # Default null val and constant - return f"('{provider}', '{metric}', {percentile}, null, null)" - - -def create_media_popularity_percentile_function( - postgres_conn_id, - media_type=IMAGE, - popularity_percentile=IMAGE_POPULARITY_PERCENTILE_FUNCTION, - media_table=TABLE_NAMES[IMAGE], -): - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, default_statement_timeout=10.0 - ) - if media_type == AUDIO: - popularity_percentile = AUDIO_POPULARITY_PERCENTILE_FUNCTION - media_table = TABLE_NAMES[AUDIO] - query = dedent( - f""" - CREATE OR REPLACE FUNCTION public.{popularity_percentile}( - provider text, pop_field text, percentile float - ) RETURNS FLOAT AS $$ - SELECT percentile_disc($3) WITHIN GROUP ( - ORDER BY ({METADATA_COLUMN}->>$2)::float - ) - FROM {media_table} WHERE {PARTITION}=$1; - $$ - LANGUAGE SQL - STABLE - RETURNS NULL ON NULL INPUT; - """ - ) - postgres.run(query) - - -def create_standardized_media_popularity_function( - postgres_conn_id, - media_type=IMAGE, - function_name=STANDARDIZED_IMAGE_POPULARITY_FUNCTION, - popularity_metrics=IMAGE_POPULARITY_METRICS_TABLE_NAME, -): - if media_type == AUDIO: - popularity_metrics = AUDIO_POPULARITY_METRICS_TABLE_NAME - function_name = STANDARDIZED_AUDIO_POPULARITY_FUNCTION - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, default_statement_timeout=10.0 - ) - query = dedent( - f""" - CREATE OR REPLACE FUNCTION public.{function_name}( - provider text, meta_data jsonb - ) RETURNS FLOAT AS $$ - SELECT ($2->>{METRIC})::float / (($2->>{METRIC})::float + {CONSTANT}) - FROM {popularity_metrics} WHERE provider=$1; - $$ - LANGUAGE SQL - STABLE - RETURNS NULL ON NULL INPUT; - """ - ) - postgres.run(query) - - -def create_media_view( - postgres_conn_id, - media_type=IMAGE, - standardized_popularity_func=STANDARDIZED_IMAGE_POPULARITY_FUNCTION, - table_name=TABLE_NAMES[IMAGE], - db_columns=IMAGE_TABLE_COLUMNS, - db_view_name=IMAGE_VIEW_NAME, - db_view_id_idx=IMAGE_VIEW_ID_IDX, - db_view_provider_fid_idx=IMAGE_VIEW_PROVIDER_FID_IDX, - task: AbstractOperator = None, -): - if media_type == AUDIO: - table_name = TABLE_NAMES[AUDIO] - db_columns = AUDIO_TABLE_COLUMNS - db_view_name = AUDIO_VIEW_NAME - db_view_id_idx = AUDIO_VIEW_ID_IDX - db_view_provider_fid_idx = AUDIO_VIEW_PROVIDER_FID_IDX - standardized_popularity_func = STANDARDIZED_AUDIO_POPULARITY_FUNCTION - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, - default_statement_timeout=PostgresHook.get_execution_timeout(task), - ) - # We want to copy all columns except standardized popularity, which is calculated - columns_to_select = (", ").join( - [ - column.db_name - for column in db_columns - if column.db_name != col.STANDARDIZED_POPULARITY.db_name - ] - ) - create_view_query = dedent( - f""" - CREATE MATERIALIZED VIEW public.{db_view_name} AS - SELECT - {columns_to_select}, - {standardized_popularity_func}( - {table_name}.{PARTITION}, - {table_name}.{METADATA_COLUMN} - ) AS standardized_popularity - FROM {table_name}; - """ - ) - add_idx_query = dedent( - f""" - CREATE UNIQUE INDEX {db_view_id_idx} - ON 
public.{db_view_name} ({IDENTIFIER}); - CREATE UNIQUE INDEX {db_view_provider_fid_idx} - ON public.{db_view_name} - USING btree({PROVIDER}, md5({FID})); - """ - ) - postgres.run(create_view_query) - postgres.run(add_idx_query) - - -def get_providers_with_popularity_data_for_media_type( - postgres_conn_id: str, - media_type: str = IMAGE, - popularity_metrics: str = IMAGE_POPULARITY_METRICS_TABLE_NAME, - pg_timeout: float = timedelta(minutes=10).total_seconds(), -): - """ - Return a list of distinct `provider`s that support popularity data, - for the given media type. - """ - if media_type == AUDIO: - popularity_metrics = AUDIO_POPULARITY_METRICS_TABLE_NAME - - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, default_statement_timeout=pg_timeout - ) - providers = postgres.get_records( - f"SELECT DISTINCT provider FROM public.{popularity_metrics};" - ) - - return [x[0] for x in providers] - - -def format_update_standardized_popularity_query( - media_type=IMAGE, - standardized_popularity_func=STANDARDIZED_IMAGE_POPULARITY_FUNCTION, - table_name=TABLE_NAMES[IMAGE], - db_columns=IMAGE_TABLE_COLUMNS, - db_view_name=IMAGE_VIEW_NAME, - db_view_id_idx=IMAGE_VIEW_ID_IDX, - db_view_provider_fid_idx=IMAGE_VIEW_PROVIDER_FID_IDX, - task: AbstractOperator = None, -): - """ - Create a SQL query for updating the standardized popularity for the given - media type. Only the `SET ...` portion of the query is returned, to be used - by a `batched_update` DagRun. - """ - if media_type == AUDIO: - table_name = TABLE_NAMES[AUDIO] - standardized_popularity_func = STANDARDIZED_AUDIO_POPULARITY_FUNCTION - - return ( - f"SET {col.STANDARDIZED_POPULARITY.db_name} = {standardized_popularity_func}" - f"({table_name}.{PARTITION}, {table_name}.{METADATA_COLUMN})" - ) - - -def update_db_view( - postgres_conn_id, media_type=IMAGE, db_view_name=IMAGE_VIEW_NAME, task=None -): - if media_type == AUDIO: - db_view_name = AUDIO_VIEW_NAME - postgres = PostgresHook( - postgres_conn_id=postgres_conn_id, - default_statement_timeout=PostgresHook.get_execution_timeout(task), - ) - postgres.run(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {db_view_name};") diff --git a/catalog/dags/common/sql.py b/catalog/dags/common/sql.py index 15e7940099a..f9f2f308bcf 100644 --- a/catalog/dags/common/sql.py +++ b/catalog/dags/common/sql.py @@ -27,7 +27,7 @@ # https://airflow.apache.org/docs/apache-airflow-providers-postgres/stable/_api/airflow/providers/postgres/hooks/postgres/index.html#airflow.providers.postgres.hooks.postgres.PostgresHook.copy_expert # noqa -def _single_value(cursor): +def single_value(cursor): try: row = cursor.fetchone() return row[0] diff --git a/catalog/dags/common/storage/db_columns.py b/catalog/dags/common/storage/db_columns.py index ad3c1bef602..f96c26915d2 100644 --- a/catalog/dags/common/storage/db_columns.py +++ b/catalog/dags/common/storage/db_columns.py @@ -2,7 +2,9 @@ This module contains the lists of database columns in the same order as in the main media tables within the database. 
""" +from common.constants import AUDIO, IMAGE from common.storage import columns as col +from common.utils import setup_kwargs_for_media_type # Columns that are only in the main table; @@ -79,3 +81,10 @@ col.STANDARDIZED_POPULARITY, col.AUDIO_SET_FOREIGN_IDENTIFIER, ] + +DB_COLUMNS_BY_MEDIA_TYPE = {AUDIO: AUDIO_TABLE_COLUMNS, IMAGE: IMAGE_TABLE_COLUMNS} + + +def setup_db_columns_for_media_type(func: callable) -> callable: + """Provide media-type-specific DB columns as a kwarg to the decorated function.""" + return setup_kwargs_for_media_type(DB_COLUMNS_BY_MEDIA_TYPE, "db_columns")(func) diff --git a/catalog/dags/common/storage/tsv_columns.py b/catalog/dags/common/storage/tsv_columns.py index 4d956822864..397d262e843 100644 --- a/catalog/dags/common/storage/tsv_columns.py +++ b/catalog/dags/common/storage/tsv_columns.py @@ -1,6 +1,7 @@ from common.constants import AUDIO, IMAGE from common.storage import columns as col from common.storage.columns import Column +from common.utils import setup_kwargs_for_media_type # Image has 'legacy' 000 version @@ -120,6 +121,13 @@ CURRENT_AUDIO_TSV_COLUMNS: list[Column] = COLUMNS[AUDIO][CURRENT_VERSION[AUDIO]] CURRENT_IMAGE_TSV_COLUMNS: list[Column] = COLUMNS[IMAGE][CURRENT_VERSION[IMAGE]] - # This list is the same for all media types -required_columns = [col for col in CURRENT_IMAGE_TSV_COLUMNS if col.required] +REQUIRED_COLUMNS = [col for col in CURRENT_IMAGE_TSV_COLUMNS if col.required] + + +def setup_tsv_columns_for_media_type(func: callable) -> callable: + """Provide media-type-specific TSV columns as a kwarg to the decorated function.""" + return setup_kwargs_for_media_type( + {AUDIO: CURRENT_AUDIO_TSV_COLUMNS, IMAGE: CURRENT_IMAGE_TSV_COLUMNS}, + "tsv_columns", + )(func) diff --git a/catalog/dags/common/utils.py b/catalog/dags/common/utils.py new file mode 100644 index 00000000000..aa6f238b1b8 --- /dev/null +++ b/catalog/dags/common/utils.py @@ -0,0 +1,80 @@ +import functools +from inspect import _ParameterKind, signature +from typing import Any + +from common.constants import SQL_INFO_BY_MEDIA_TYPE + + +def setup_kwargs_for_media_type( + values_by_media_type: dict[str, Any], kwarg_name: str +) -> callable: + """ + Create a decorator which provides media_type-specific information as parameters + for the called function. The called function must itself have a media_type kwarg, + which is used to select values. + + Required arguments: + + values_by_media_type: A dict mapping media types to arbitrary values, which may + themselves be of any type + kwarg_name: The name of the kwarg that will be passed to the called + function + + Usage example: + + @setup_kwargs_for_media_type(MY_VALS_BY_MEDIA_TYPE, 'foo') + def my_fun(media_type, foo = None): + ... + + When `my_fun` is called, if the `foo` kwarg is not passed explicitly, it will be set + to the value of MY_VALS_BY_MEDIA_TYPE[media_type]. An error is raised for an invalid + media type. + """ + + def wrap(func: callable) -> callable: + """ + Provide the appropriate value for the media_type passed in the called function. + If the called function is already explicitly passed a value for `kwarg_name`, + use that value instead. + """ + + # The called function must be supplied a `media_type` keyword-only argument. It + # cannot allow the value to be supplied as a positional argument. 
+ if ( + media_type := signature(func).parameters.get("media_type") + ) is None or media_type.kind != _ParameterKind.KEYWORD_ONLY: + raise Exception( + f"Improperly configured function `{func.__qualname__}`:" + " `media_type` must be a keyword-only argument." + ) + + @functools.wraps(func) + def wrapped(*args, **kwargs): + # First check to see if the called function was already passed a value + # for the given kwarg name. If so, simply use this. + if (media_info := kwargs.pop(kwarg_name, None)) is None: + # The called function should be passed a `media_type`, whose value + # is a key in the values dict + media_type = kwargs.get("media_type", None) + + if media_type not in values_by_media_type.keys(): + raise ValueError( + f"{func.__qualname__}: No values matching media type" + f" `{media_type}`" + ) + + # Get the value corresponding to the media type + media_info = values_by_media_type.get(media_type) + + # Add the media-type-specific info to kwargs, using the passed kwarg name + kwargs[kwarg_name] = media_info + return func(*args, **kwargs) + + return wrapped + + return wrap + + +def setup_sql_info_for_media_type(func: callable) -> callable: + """Provide media-type-specific SQLInfo as a kwarg to the decorated function.""" + return setup_kwargs_for_media_type(SQL_INFO_BY_MEDIA_TYPE, "sql_info")(func) diff --git a/catalog/dags/data_refresh/dag_factory.py b/catalog/dags/data_refresh/dag_factory.py index de8291d6aad..f2fd382574e 100644 --- a/catalog/dags/data_refresh/dag_factory.py +++ b/catalog/dags/data_refresh/dag_factory.py @@ -34,7 +34,7 @@ OPENLEDGER_API_CONN_ID, XCOM_PULL_TEMPLATE, ) -from common.sql import PGExecuteQueryOperator, _single_value +from common.sql import PGExecuteQueryOperator, single_value from data_refresh.data_refresh_task_factory import create_data_refresh_task_group from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh from data_refresh.reporting import report_record_difference @@ -93,7 +93,7 @@ def create_data_refresh_dag(data_refresh: DataRefresh, external_dag_ids: Sequenc task_id="get_before_record_count", conn_id=OPENLEDGER_API_CONN_ID, sql=count_sql, - handler=_single_value, + handler=single_value, return_last=True, ) @@ -108,7 +108,7 @@ def create_data_refresh_dag(data_refresh: DataRefresh, external_dag_ids: Sequenc task_id="get_after_record_count", conn_id=OPENLEDGER_API_CONN_ID, sql=count_sql, - handler=_single_value, + handler=single_value, return_last=True, ) diff --git a/catalog/dags/data_refresh/recreate_view_data_task_factory.py b/catalog/dags/data_refresh/recreate_view_data_task_factory.py deleted file mode 100644 index 21834ed4f1b..00000000000 --- a/catalog/dags/data_refresh/recreate_view_data_task_factory.py +++ /dev/null @@ -1,81 +0,0 @@ -""" -# Recreate Materialized View Task Factory -This file generates a TaskGroup that recreates the materialized view for a -given media type, using a factory function. - -The task drops and recreates the materialized view, but not the underlying tables. This -means that the only effect is to add or update data (including popularity data) -for records which have been ingested since the last time the view was -refreshed. - -This should be run every time before a data refresh is triggered. 
-""" -from airflow.operators.python import PythonOperator -from airflow.utils.task_group import TaskGroup -from airflow.utils.trigger_rule import TriggerRule - -from common.constants import POSTGRES_CONN_ID -from common.popularity import sql -from data_refresh import reporting -from data_refresh.data_refresh_types import DataRefresh - - -GROUP_ID = "recreate_matview" -DROP_DB_VIEW_TASK_ID = "drop_materialized_popularity_view" -CREATE_DB_VIEW_TASK_ID = "create_materialized_popularity_view" - - -def create_recreate_view_data_task(data_refresh: DataRefresh): - """ - Create the recreate related tasks. - - The task drops and recreates the materialized view for the given media type. The - view collates popularity data for each record. Recreating has the effect of adding - popularity data for records that were ingested since the last time the view was - created or refreshed, and updating popularity data for existing records. It also - creates a reporting task which will report the status of the matview refresh once - it is complete. - - The view is dropped and recreated rather than refreshed, because refreshing the view - takes much longer in production and times out. - - Required Arguments: - - data_refresh: configuration information for the data refresh - """ - with TaskGroup(group_id=GROUP_ID) as recreate_matview: - drop_matview = PythonOperator( - task_id=DROP_DB_VIEW_TASK_ID, - python_callable=sql.drop_media_matview, - op_kwargs={ - "postgres_conn_id": POSTGRES_CONN_ID, - "media_type": data_refresh.media_type, - }, - trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, - retries=0, - ) - create_matview = PythonOperator( - task_id=CREATE_DB_VIEW_TASK_ID, - python_callable=sql.create_media_view, - op_kwargs={ - "postgres_conn_id": POSTGRES_CONN_ID, - "media_type": data_refresh.media_type, - }, - execution_timeout=data_refresh.create_materialized_view_timeout, - retries=0, - doc_md=create_recreate_view_data_task.__doc__, - ) - recreate_status = PythonOperator( - task_id=f"report_{GROUP_ID}_status", - python_callable=reporting.report_status, - op_kwargs={ - "media_type": data_refresh.media_type, - "dag_id": data_refresh.dag_id, - "message": "Matview refresh complete | " - "_Next: ingestion server data refresh_", - }, - ) - - drop_matview >> create_matview >> recreate_status - - return recreate_matview diff --git a/catalog/dags/database/batched_update/batched_update.py b/catalog/dags/database/batched_update/batched_update.py index 74c2f789396..d09c2df1d14 100644 --- a/catalog/dags/database/batched_update/batched_update.py +++ b/catalog/dags/database/batched_update/batched_update.py @@ -7,7 +7,7 @@ from common import slack from common.constants import POSTGRES_CONN_ID -from common.sql import PostgresHook, _single_value +from common.sql import PostgresHook, single_value from database.batched_update import constants @@ -44,7 +44,7 @@ def get_expected_update_count( dry_run=dry_run, sql_template=constants.SELECT_TEMP_TABLE_COUNT_QUERY, query_id=query_id, - handler=_single_value, + handler=single_value, ) diff --git a/catalog/dags/database/recreate_popularity_calculation_dag_factory.py b/catalog/dags/database/recreate_popularity_calculation_dag_factory.py deleted file mode 100644 index 7381b1d28aa..00000000000 --- a/catalog/dags/database/recreate_popularity_calculation_dag_factory.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -This file generates Apache Airflow DAGs that, for the given media type, -completely wipe out the PostgreSQL relations and functions involved in -calculating our standardized 
popularity metric. It then recreates relations -and functions to make the calculation, and performs an initial calculation. -The results are available in the materialized view for that media type. - -These DAGs are not on a schedule, and should only be run manually when new -SQL code is deployed for the calculation. -""" -from airflow import DAG -from airflow.operators.python import PythonOperator - -from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID -from common.popularity import sql -from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh - - -def create_recreate_popularity_calculation_dag(data_refresh: DataRefresh): - media_type = data_refresh.media_type - default_args = { - **DAG_DEFAULT_ARGS, - **data_refresh.default_args, - } - - dag = DAG( - dag_id=f"recreate_{media_type}_popularity_calculation", - default_args=default_args, - max_active_runs=1, - schedule=None, - catchup=False, - doc_md=__doc__, - tags=["database", "data_refresh"], - ) - with dag: - drop_relations = PythonOperator( - task_id="drop_popularity_relations", - python_callable=sql.drop_media_popularity_relations, - op_kwargs={ - "postgres_conn_id": POSTGRES_CONN_ID, - "media_type": media_type, - }, - doc="Drop the existing popularity views and tables.", - ) - - drop_functions = PythonOperator( - task_id="drop_popularity_functions", - python_callable=sql.drop_media_popularity_functions, - op_kwargs={ - "postgres_conn_id": POSTGRES_CONN_ID, - "media_type": media_type, - }, - doc="Drop the existing popularity functions.", - ) - - create_metrics_table = PythonOperator( - task_id="create_popularity_metrics_table", - python_callable=sql.create_media_popularity_metrics, - op_kwargs={ - "postgres_conn_id": POSTGRES_CONN_ID, - "media_type": media_type, - }, - doc=( - "Create the popularity metrics table, which stores popularity " - "metrics and target percentiles per provider." - ), - ) - - update_metrics_table = PythonOperator( - task_id="update_popularity_metrics_table", - python_callable=sql.update_media_popularity_metrics, - op_kwargs={ - "postgres_conn_id": POSTGRES_CONN_ID, - "media_type": media_type, - }, - doc="Update the popularity metrics table with values for each provider.", - ) - - create_percentile_function = PythonOperator( - task_id="create_popularity_percentile_function", - python_callable=sql.create_media_popularity_percentile_function, - op_kwargs={ - "postgres_conn_id": POSTGRES_CONN_ID, - "media_type": media_type, - }, - doc=( - "Create the function for calculating popularity percentile values, " - "used for calculating the popularity constants for each provider." - ), - ) - - create_popularity_function = PythonOperator( - task_id="create_standardized_popularity_function", - python_callable=sql.create_standardized_media_popularity_function, - op_kwargs={ - "postgres_conn_id": POSTGRES_CONN_ID, - "media_type": media_type, - }, - doc=( - "Create the function that calculates popularity data for a given " - "record, standardizing across providers with the generated popularity " - "constants." - ), - ) - - create_matview = PythonOperator( - task_id="create_materialized_popularity_view", - python_callable=sql.create_media_view, - op_kwargs={ - "postgres_conn_id": POSTGRES_CONN_ID, - "media_type": media_type, - }, - execution_timeout=data_refresh.create_materialized_view_timeout, - doc=( - "Create the materialized view containing standardized popularity data " - "for all records." 
- ), - ) - - ( - [drop_relations, drop_functions] - >> create_metrics_table - >> [update_metrics_table, create_percentile_function] - >> create_popularity_function - >> create_matview - ) - - return dag - - -# Generate a recreate_popularity_calculation DAG for each DATA_REFRESH_CONFIG. -for data_refresh in DATA_REFRESH_CONFIGS: - recreate_popularity_calculation_dag = create_recreate_popularity_calculation_dag( - data_refresh - ) - globals()[ - recreate_popularity_calculation_dag.dag_id - ] = recreate_popularity_calculation_dag diff --git a/catalog/dags/popularity/dag_factory.py b/catalog/dags/popularity/popularity_refresh_dag_factory.py similarity index 66% rename from catalog/dags/popularity/dag_factory.py rename to catalog/dags/popularity/popularity_refresh_dag_factory.py index dfb55645e04..567e907d2d1 100644 --- a/catalog/dags/popularity/dag_factory.py +++ b/catalog/dags/popularity/popularity_refresh_dag_factory.py @@ -28,17 +28,14 @@ from airflow import DAG from airflow.decorators import task from airflow.operators.trigger_dagrun import TriggerDagRunOperator +from popularity import sql from popularity.popularity_refresh_types import ( POPULARITY_REFRESH_CONFIGS, PopularityRefresh, ) -from popularity.refresh_popularity_metrics_task_factory import ( - create_refresh_popularity_metrics_task_group, -) from common import slack from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID -from common.popularity import sql from database.batched_update.constants import DAG_ID as BATCHED_UPDATE_DAG_ID @@ -70,52 +67,6 @@ def get_last_updated_time(): return datetime.utcnow() -@task -def get_providers_update_confs( - postgres_conn_id: str, - popularity_refresh: PopularityRefresh, - last_updated_time: datetime, -): - """ - Build a list of DagRun confs for each provider of this media type. The confs will - be used by the `batched_update` DAG to perform a batched update of all existing - records, to recalculate their standardized_popularity with the new popularity - constant. Providers that do not support popularity data are omitted. - """ - # For the media type, get a list of the providers who support popularity data - providers = sql.get_providers_with_popularity_data_for_media_type( - postgres_conn_id, popularity_refresh.media_type - ) - - # For each provider, create a conf that will be used by the batched_update to - # refresh standardized popularity scores. - return [ - { - # Uniquely identify the query - "query_id": ( - f"{provider}_popularity_refresh_{last_updated_time.strftime('%Y%m%d')}" - ), - "table_name": popularity_refresh.media_type, - # Query used to select records that should be refreshed - "select_query": ( - f"WHERE provider='{provider}' AND updated_on <" - f" '{last_updated_time.strftime('%Y-%m-%d %H:%M:%S')}'" - ), - # Query used to update the standardized_popularity - "update_query": sql.format_update_standardized_popularity_query( - popularity_refresh.media_type - ), - "batch_size": 10_000, - "update_timeout": ( - popularity_refresh.refresh_popularity_batch_timeout.total_seconds() - ), - "dry_run": False, - "resume_update": False, - } - for provider in providers - ] - - def create_popularity_refresh_dag(popularity_refresh: PopularityRefresh): """ Instantiate a DAG for a popularity refresh. @@ -144,11 +95,54 @@ def create_popularity_refresh_dag(popularity_refresh: PopularityRefresh): ) with dag: - # Refresh the underlying popularity tables. This step recalculates the - # popularity constants, which will later be used to calculate updated - # standardized popularity scores. 
-        refresh_popularity_metrics = create_refresh_popularity_metrics_task_group(
-            popularity_refresh
+        update_metrics = sql.update_media_popularity_metrics.override(
+            task_id="update_popularity_metrics",
+        )(
+            postgres_conn_id=POSTGRES_CONN_ID,
+            media_type=popularity_refresh.media_type,
+            popularity_metrics=popularity_refresh.popularity_metrics,
+        )
+        update_metrics.doc = (
+            "Update the metrics and target percentiles. If a popularity"
+            " metric is configured for a new provider, this step will add it"
+            " to the metrics table."
+        )
+
+        update_metrics_status = notify_slack.override(
+            task_id="report_update_popularity_metrics_status"
+        )(
+            text="Popularity metrics update complete | _Next: popularity"
+            " constants update_",
+            media_type=popularity_refresh.media_type,
+            dag_id=popularity_refresh.dag_id,
+        )
+
+        update_constants = (
+            sql.update_percentile_and_constants_for_provider.override(
+                group_id="refresh_popularity_metrics_and_constants",
+            )
+            .partial(
+                postgres_conn_id=POSTGRES_CONN_ID,
+                media_type=popularity_refresh.media_type,
+                execution_timeout=popularity_refresh.refresh_metrics_timeout,
+                popularity_metrics=popularity_refresh.popularity_metrics,
+            )
+            .expand(provider=list(popularity_refresh.popularity_metrics.keys()))
+        )
+        update_constants.doc = (
+            "Recalculate the percentile values and popularity constants"
+            " for each provider, and update them in the metrics table. The"
+            " popularity constants will be used to calculate standardized"
+            " popularity scores."
+        )
+
+        update_constants_status = notify_slack.override(
+            task_id="report_update_popularity_constants_status"
+        )(
+            text="Popularity constants update complete | _Next: refresh"
+            " popularity scores_",
+            media_type=popularity_refresh.media_type,
+            dag_id=popularity_refresh.dag_id,
+        )
 
         # Once popularity constants have been calculated, establish the cutoff time
@@ -170,7 +164,7 @@
             retries=0,
         ).expand(
             # Build the conf for each provider
-            conf=get_providers_update_confs(
+            conf=sql.get_providers_update_confs(
                 POSTGRES_CONN_ID, popularity_refresh, get_cutoff_time
             )
         )
@@ -185,12 +179,9 @@
         )
 
         # Set up task dependencies
-        (
-            refresh_popularity_metrics
-            >> get_cutoff_time
-            >> refresh_popularity_scores
-            >> notify_complete
-        )
+        update_metrics >> [update_metrics_status, update_constants]
+        update_constants >> [update_constants_status, get_cutoff_time]
+        get_cutoff_time >> refresh_popularity_scores >> notify_complete
 
     return dag
diff --git a/catalog/dags/popularity/popularity_refresh_types.py b/catalog/dags/popularity/popularity_refresh_types.py
index 76d3b4631b2..b25867e9e27 100644
--- a/catalog/dags/popularity/popularity_refresh_types.py
+++ b/catalog/dags/popularity/popularity_refresh_types.py
@@ -19,7 +19,12 @@ class PopularityRefresh:
 
     Required Constructor Arguments:
 
-    media_type: str describing the media type to be refreshed.
+    media_type: str describing the media type to be refreshed.
+    popularity_metrics: dictionary mapping providers of this media type
+        to their popularity metrics and, optionally, a percentile. If
+        the percentile key is not included, the default value will
+        be used.
+        Ex: {"my_provider": {"metric": "views", "percentile": 0.5}}
 
     Optional Constructor Arguments:
 
@@ -45,6 +50,7 @@ class PopularityRefresh:
 
     dag_id: str = field(init=False)
     media_type: str
+    popularity_metrics: dict
     default_args: dict | None = field(default_factory=dict)
     start_date: datetime = datetime(2023, 1, 1)
     schedule: str | None = "@monthly"
@@ -61,11 +67,23 @@ def __post_init__(self):
     PopularityRefresh(
         media_type="image",
         refresh_metrics_timeout=timedelta(hours=24),
+        popularity_metrics={
+            "flickr": {"metric": "views"},
+            "nappy": {"metric": "downloads"},
+            "rawpixel": {"metric": "download_count"},
+            "stocksnap": {"metric": "downloads_raw"},
+            "wikimedia": {"metric": "global_usage_count"},
+        },
     ),
     PopularityRefresh(
         media_type="audio",
         # Poke every minute, instead of every thirty minutes
         poke_interval=int(os.getenv("DATA_REFRESH_POKE_INTERVAL", 60)),
         refresh_popularity_timeout=timedelta(days=1),
+        popularity_metrics={
+            "jamendo": {"metric": "listens"},
+            "wikimedia_audio": {"metric": "global_usage_count"},
+            "freesound": {"metric": "num_downloads"},
+        },
     ),
 ]
diff --git a/catalog/dags/popularity/recreate_popularity_calculation_dag_factory.py b/catalog/dags/popularity/recreate_popularity_calculation_dag_factory.py
new file mode 100644
index 00000000000..31261ec38fc
--- /dev/null
+++ b/catalog/dags/popularity/recreate_popularity_calculation_dag_factory.py
@@ -0,0 +1,79 @@
+"""
+This file generates Apache Airflow DAGs that, for the given media type,
+completely wipe out and recreate the PostgreSQL functions involved in
+calculating our standardized popularity metric.
+
+Note that these DAGs do not drop any tables or views related to popularity,
+nor do they perform any popularity calculations. Once one of these DAGs has
+run, the associated popularity refresh DAG must be run to actually
+recalculate popularity constants and standardized popularity scores using
+the new functions.
+
+These DAGs are not on a schedule, and should only be run manually when new
+SQL code is deployed for the calculation.
+"""
+from airflow import DAG
+from popularity import sql
+from popularity.popularity_refresh_types import (
+    POPULARITY_REFRESH_CONFIGS,
+    PopularityRefresh,
+)
+
+from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID
+
+
+def create_recreate_popularity_calculation_dag(popularity_refresh: PopularityRefresh):
+    media_type = popularity_refresh.media_type
+    default_args = {
+        **DAG_DEFAULT_ARGS,
+        **popularity_refresh.default_args,
+    }
+
+    dag = DAG(
+        dag_id=f"recreate_{media_type}_popularity_calculation",
+        default_args=default_args,
+        max_active_runs=1,
+        schedule=None,
+        catchup=False,
+        doc_md=__doc__,
+        tags=["database", "data_refresh"],
+    )
+    with dag:
+        drop_functions = sql.drop_media_popularity_functions(
+            postgres_conn_id=POSTGRES_CONN_ID,
+            media_type=media_type,
+        )
+        drop_functions.doc = "Drop the existing popularity functions."
+
+        create_percentile_function = sql.create_media_popularity_percentile_function(
+            postgres_conn_id=POSTGRES_CONN_ID,
+            media_type=media_type,
+        )
+        create_percentile_function.doc = (
+            "Create the function for calculating popularity percentile values, "
+            "used for calculating the popularity constants for each provider."
+        )
+
+        create_popularity_function = sql.create_standardized_media_popularity_function(
+            postgres_conn_id=POSTGRES_CONN_ID,
+            media_type=media_type,
+        )
+        create_popularity_function.doc = (
+            "Create the function that calculates popularity data for a given "
+            "record, standardizing across providers with the generated popularity "
+            "constants."
+        )
+
+        (drop_functions >> create_percentile_function >> create_popularity_function)
+
+    return dag
+
+
+# Generate a recreate_popularity_calculation DAG for each config in POPULARITY_REFRESH_CONFIGS.
+for popularity_refresh in POPULARITY_REFRESH_CONFIGS:
+    recreate_popularity_calculation_dag = create_recreate_popularity_calculation_dag(
+        popularity_refresh
+    )
+    globals()[
+        recreate_popularity_calculation_dag.dag_id
+    ] = recreate_popularity_calculation_dag
diff --git a/catalog/dags/popularity/refresh_popularity_metrics_task_factory.py b/catalog/dags/popularity/refresh_popularity_metrics_task_factory.py
deleted file mode 100644
index 428e23f7456..00000000000
--- a/catalog/dags/popularity/refresh_popularity_metrics_task_factory.py
+++ /dev/null
@@ -1,108 +0,0 @@
-"""
-# Refresh Popularity Metrics TaskGroup Factory
-This file generates a TaskGroup that refreshes the underlying popularity DB
-tables, using a factory function.
-
-This step updates any changes to popularity metrics, and recalculates the
-popularity constants. It should be run at least once every month, or whenever
-a new popularity metric is added. Scheduling is handled in the parent data
-refresh DAG.
-"""
-from airflow.operators.python import PythonOperator
-from airflow.utils.task_group import TaskGroup
-from popularity.popularity_refresh_types import PopularityRefresh
-
-from common.constants import POSTGRES_CONN_ID
-from common.popularity import sql
-from data_refresh import reporting
-from data_refresh.data_refresh_types import DataRefresh
-
-
-GROUP_ID = "refresh_popularity_metrics_and_constants"
-UPDATE_MEDIA_POPULARITY_METRICS_TASK_ID = "update_media_popularity_metrics"
-UPDATE_MEDIA_POPULARITY_CONSTANTS_TASK_ID = "update_media_popularity_constants"
-
-
-def create_refresh_popularity_metrics_task_group(
-    refresh_config: DataRefresh | PopularityRefresh,
-):
-    """
-    Create tasks related to refreshing popularity statistics.
-
-    This factory method instantiates a TaskGroup that will update the popularity
-    DB tables for the given media type, including percentiles and popularity
-    metrics. It also creates a reporting tasks which will report the status of the
-    various steps once they complete.
-
-    Required Arguments:
-
-    refresh_config: configuration data for the refresh
-    """
-    media_type = refresh_config.media_type
-    execution_timeout = refresh_config.refresh_metrics_timeout
-
-    with TaskGroup(group_id=GROUP_ID) as refresh_all_popularity_data:
-        update_metrics = sql.update_media_popularity_metrics.override(
-            task_id=UPDATE_MEDIA_POPULARITY_METRICS_TASK_ID,
-            execution_timeout=execution_timeout,
-        )(
-            postgres_conn_id=POSTGRES_CONN_ID,
-            media_type=media_type,
-        )
-        update_metrics.doc = (
-            "Updates the metrics and target percentiles. If a popularity"
-            " metric is configured for a new provider, this step will add it"
-            " to the metrics table."
- ) - - update_metrics_status = PythonOperator( - task_id=f"report_{UPDATE_MEDIA_POPULARITY_METRICS_TASK_ID}_status", - python_callable=reporting.report_status, - op_kwargs={ - "media_type": media_type, - "dag_id": refresh_config.dag_id, - "message": "Popularity metrics update complete | " - "_Next: popularity constants update_", - }, - ) - - update_constants = ( - sql.update_percentile_and_constants_for_provider.override( - group_id=UPDATE_MEDIA_POPULARITY_CONSTANTS_TASK_ID, - ) - .partial( - postgres_conn_id=POSTGRES_CONN_ID, - media_type=media_type, - execution_timeout=execution_timeout, - ) - .expand( - provider=[ - provider - for provider in sql.POPULARITY_METRICS_BY_MEDIA_TYPE[ - media_type - ].keys() - ] - ) - ) - update_constants.doc = ( - "Recalculate the percentile values and popularity constants" - " for each provider, and update them in the metrics table. The" - " popularity constants will be used to calculate standardized" - " popularity scores." - ) - - update_constants_status = PythonOperator( - task_id=f"report_{UPDATE_MEDIA_POPULARITY_CONSTANTS_TASK_ID}_status", - python_callable=reporting.report_status, - op_kwargs={ - "media_type": media_type, - "dag_id": refresh_config.dag_id, - "message": "Popularity constants update complete | " - "_Next: refresh matview_", - }, - ) - - update_metrics >> [update_metrics_status, update_constants] - update_constants >> update_constants_status - - return refresh_all_popularity_data diff --git a/catalog/dags/popularity/sql.py b/catalog/dags/popularity/sql.py new file mode 100644 index 00000000000..bacf3bf6291 --- /dev/null +++ b/catalog/dags/popularity/sql.py @@ -0,0 +1,366 @@ +from collections import namedtuple +from datetime import datetime, timedelta +from textwrap import dedent + +from airflow.decorators import task, task_group +from airflow.models.abstractoperator import AbstractOperator +from popularity.popularity_refresh_types import PopularityRefresh + +from common.constants import DAG_DEFAULT_ARGS, SQLInfo +from common.sql import PostgresHook, single_value +from common.storage import columns as col +from common.utils import setup_sql_info_for_media_type + + +DEFAULT_PERCENTILE = 0.85 + + +# Column name constants +VALUE = "val" +CONSTANT = "constant" +FID = col.FOREIGN_ID.db_name +IDENTIFIER = col.IDENTIFIER.db_name +METADATA_COLUMN = col.META_DATA.db_name +METRIC = "metric" +PARTITION = col.PROVIDER.db_name +PERCENTILE = "percentile" +PROVIDER = col.PROVIDER.db_name + +Column = namedtuple("Column", ["name", "definition"]) + + +POPULARITY_METRICS_TABLE_COLUMNS = [ + Column(name=PARTITION, definition="character varying(80) PRIMARY KEY"), + Column(name=METRIC, definition="character varying(80)"), + Column(name=PERCENTILE, definition="float"), + Column(name=VALUE, definition="float"), + Column(name=CONSTANT, definition="float"), +] + + +@task +@setup_sql_info_for_media_type +def drop_media_popularity_functions( + postgres_conn_id: str, *, media_type: str, sql_info: SQLInfo = None +): + postgres = PostgresHook( + postgres_conn_id=postgres_conn_id, default_statement_timeout=10.0 + ) + postgres.run( + f"DROP FUNCTION IF EXISTS public.{sql_info.standardized_popularity_fn} CASCADE;" + ) + postgres.run( + f"DROP FUNCTION IF EXISTS public.{sql_info.popularity_percentile_fn} CASCADE;" + ) + + +@task +@setup_sql_info_for_media_type +def create_media_popularity_metrics( + postgres_conn_id: str, *, media_type: str, sql_info: SQLInfo = None +): + postgres = PostgresHook( + postgres_conn_id=postgres_conn_id, default_statement_timeout=10.0 + ) + 
popularity_metrics_columns_string = ",\n ".join( + f"{c.name} {c.definition}" for c in POPULARITY_METRICS_TABLE_COLUMNS + ) + query = dedent( + f""" + CREATE TABLE public.{sql_info.metrics_table} ( + {popularity_metrics_columns_string} + ); + """ + ) + postgres.run(query) + + +@task +@setup_sql_info_for_media_type +def update_media_popularity_metrics( + postgres_conn_id: str, + popularity_metrics: dict, + *, + media_type: str, + sql_info: SQLInfo = None, + task: AbstractOperator = None, +): + postgres = PostgresHook( + postgres_conn_id=postgres_conn_id, + default_statement_timeout=PostgresHook.get_execution_timeout(task), + ) + + column_names = [c.name for c in POPULARITY_METRICS_TABLE_COLUMNS] + + # Note that we do not update the val and constant. That is only done during the + # calculation tasks. In other words, we never want to clear out the current value of + # the popularity constant unless we're already done calculating the new one, since + # that can be a time consuming process. + updates_string = ",\n ".join( + f"{c}=EXCLUDED.{c}" + for c in column_names + if c not in [PARTITION, CONSTANT, VALUE] + ) + popularity_metric_inserts = _get_popularity_metric_insert_values_string( + popularity_metrics + ) + + query = dedent( + f""" + INSERT INTO public.{sql_info.metrics_table} ( + {', '.join(column_names)} + ) VALUES + {popularity_metric_inserts} + ON CONFLICT ({PARTITION}) + DO UPDATE SET + {updates_string} + ; + """ + ) + return postgres.run(query) + + +@task +@setup_sql_info_for_media_type +def calculate_media_popularity_percentile_value( + postgres_conn_id: str, + provider: str, + *, + media_type: str, + sql_info: SQLInfo = None, + task: AbstractOperator = None, +): + postgres = PostgresHook( + postgres_conn_id=postgres_conn_id, + default_statement_timeout=PostgresHook.get_execution_timeout(task), + ) + + # Calculate the percentile value. E.g. if `percentile` = 0.80, then we'll + # calculate the _value_ of the 80th percentile for this provider's + # popularity metric. + calculate_new_percentile_value_query = dedent( + f""" + SELECT {sql_info.popularity_percentile_fn}({PARTITION}, {METRIC}, {PERCENTILE}) + FROM {sql_info.metrics_table} + WHERE {col.PROVIDER.db_name}='{provider}'; + """ + ) + + return postgres.run(calculate_new_percentile_value_query, handler=single_value) + + +@task +@setup_sql_info_for_media_type +def update_percentile_and_constants_values_for_provider( + postgres_conn_id: str, + provider: str, + raw_percentile_value: float, + popularity_metrics: dict, + *, + media_type: str, + sql_info: SQLInfo = None, + task: AbstractOperator = None, +): + if raw_percentile_value is None: + # Occurs when a provider has a metric configured, but there are no records + # with any data for that metric. 
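+        # In that case the provider's "val" and "constant" are simply left
+        # null (see test_metrics_table_handles_zeros_and_missing_in_constants
+        # below) and no update is performed.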
+ return + + postgres = PostgresHook( + postgres_conn_id=postgres_conn_id, + default_statement_timeout=PostgresHook.get_execution_timeout(task), + ) + + provider_info = popularity_metrics.get(provider) + percentile = provider_info.get("percentile", DEFAULT_PERCENTILE) + + # Calculate the popularity constant using the percentile value + percentile_value = raw_percentile_value or 1 + new_constant = ((1 - percentile) / (percentile)) * percentile_value + + # Update the percentile value and constant in the metrics table + update_constant_query = dedent( + f""" + UPDATE public.{sql_info.metrics_table} + SET {VALUE} = {percentile_value}, {CONSTANT} = {new_constant} + WHERE {col.PROVIDER.db_name} = '{provider}'; + """ + ) + return postgres.run(update_constant_query) + + +@task_group +def update_percentile_and_constants_for_provider( + postgres_conn_id: str, + provider: str, + media_type: str, + popularity_metrics: dict, + execution_timeout: timedelta = None, +): + calculate_percentile_val = calculate_media_popularity_percentile_value.override( + task_id="calculate_percentile_value", + execution_timeout=execution_timeout + or DAG_DEFAULT_ARGS.get("execution_timeout"), + )( + postgres_conn_id=postgres_conn_id, + provider=provider, + media_type=media_type, + ) + calculate_percentile_val.doc = ( + "Calculate the percentile popularity value for this provider. For" + " example, if this provider has `percentile`=0.80 and `metric`='views'," + " calculate the 80th percentile value of views for all records for this" + " provider." + ) + + update_metrics_table = update_percentile_and_constants_values_for_provider.override( + task_id="update_percentile_values_and_constant", + )( + postgres_conn_id=postgres_conn_id, + provider=provider, + raw_percentile_value=calculate_percentile_val, + media_type=media_type, + popularity_metrics=popularity_metrics, + ) + update_metrics_table.doc = ( + "Given the newly calculated percentile value, calculate the" + " popularity constant and update the metrics table with the newly" + " calculated values." 
+ ) + + +def _get_popularity_metric_insert_values_string( + popularity_metrics: dict, + default_percentile: float = DEFAULT_PERCENTILE, +) -> str: + return ",\n ".join( + _format_popularity_metric_insert_tuple_string( + provider, + provider_info["metric"], + provider_info.get("percentile", default_percentile), + ) + for provider, provider_info in popularity_metrics.items() + ) + + +def _format_popularity_metric_insert_tuple_string( + provider: str, + metric: str, + percentile: float, +): + # Default null val and constant + return f"('{provider}', '{metric}', {percentile}, null, null)" + + +@task +@setup_sql_info_for_media_type +def create_media_popularity_percentile_function( + postgres_conn_id: str, + *, + media_type: str, + sql_info: SQLInfo = None, +): + postgres = PostgresHook( + postgres_conn_id=postgres_conn_id, default_statement_timeout=10.0 + ) + + query = dedent( + f""" + CREATE OR REPLACE FUNCTION public.{sql_info.popularity_percentile_fn}( + provider text, pop_field text, percentile float + ) RETURNS FLOAT AS $$ + SELECT percentile_disc($3) WITHIN GROUP ( + ORDER BY ({METADATA_COLUMN}->>$2)::float + ) + FROM {sql_info.media_table} WHERE {PARTITION}=$1; + $$ + LANGUAGE SQL + STABLE + RETURNS NULL ON NULL INPUT; + """ + ) + postgres.run(query) + + +@task +@setup_sql_info_for_media_type +def create_standardized_media_popularity_function( + postgres_conn_id: str, *, media_type: str, sql_info: SQLInfo = None +): + postgres = PostgresHook( + postgres_conn_id=postgres_conn_id, default_statement_timeout=10.0 + ) + query = dedent( + f""" + CREATE OR REPLACE FUNCTION public.{sql_info.standardized_popularity_fn}( + provider text, meta_data jsonb + ) RETURNS FLOAT AS $$ + SELECT ($2->>{METRIC})::float / (($2->>{METRIC})::float + {CONSTANT}) + FROM {sql_info.metrics_table} WHERE provider=$1; + $$ + LANGUAGE SQL + STABLE + RETURNS NULL ON NULL INPUT; + """ + ) + postgres.run(query) + + +@setup_sql_info_for_media_type +def format_update_standardized_popularity_query( + *, + media_type: str, + sql_info: SQLInfo = None, +) -> str: + """ + Create a SQL query for updating the standardized popularity for the given + media type. Only the `SET ...` portion of the query is returned, to be used + by a `batched_update` DagRun. + """ + return ( + f"SET {col.STANDARDIZED_POPULARITY.db_name} =" + f" {sql_info.standardized_popularity_fn}({sql_info.media_table}.{PARTITION}," + f" {sql_info.media_table}.{METADATA_COLUMN})" + ) + + +@task +def get_providers_update_confs( + postgres_conn_id: str, + popularity_refresh: PopularityRefresh, + last_updated_time: datetime, +) -> list[dict]: + """ + Build a list of DagRun confs for each provider of this media type. The confs will + be used by the `batched_update` DAG to perform a batched update of all existing + records, to recalculate their standardized_popularity with the new popularity + constant. Providers that do not support popularity data are omitted. + """ + + # For each provider, create a conf that will be used by the batched_update to + # refresh standardized popularity scores. 
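+    # For example, provider "foo_provider" on an image refresh with a one-hour
+    # batch timeout, last updated 2023-01-01, yields a conf like (this mirrors
+    # the expected confs in the tests further down in this diff):
+    #   {"query_id": "foo_provider_popularity_refresh_20230101",
+    #    "table_name": "image",
+    #    "select_query": "WHERE provider='foo_provider' AND updated_on < '2023-01-01 00:00:00'",
+    #    "update_query": "SET standardized_popularity = standardized_image_popularity(image.provider, image.meta_data)",
+    #    "batch_size": 10000, "update_timeout": 3600.0,
+    #    "dry_run": False, "resume_update": False}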
+ return [ + { + # Uniquely identify the query + "query_id": ( + f"{provider}_popularity_refresh_{last_updated_time.strftime('%Y%m%d')}" + ), + "table_name": popularity_refresh.media_type, + # Query used to select records that should be refreshed + "select_query": ( + f"WHERE provider='{provider}' AND updated_on <" + f" '{last_updated_time.strftime('%Y-%m-%d %H:%M:%S')}'" + ), + # Query used to update the standardized_popularity + "update_query": format_update_standardized_popularity_query( + media_type=popularity_refresh.media_type + ), + "batch_size": 10_000, + "update_timeout": ( + popularity_refresh.refresh_popularity_batch_timeout.total_seconds() + ), + "dry_run": False, + "resume_update": False, + } + for provider in popularity_refresh.popularity_metrics.keys() + ] diff --git a/catalog/tests/dags/common/conftest.py b/catalog/tests/dags/common/conftest.py index 1a97d5b3a77..c36f42a7209 100644 --- a/catalog/tests/dags/common/conftest.py +++ b/catalog/tests/dags/common/conftest.py @@ -1,20 +1,14 @@ import socket -from datetime import datetime, timedelta from urllib.parse import urlparse import boto3 import pytest -from airflow import DAG -from airflow.models.abstractoperator import AbstractOperator -from airflow.operators.python import PythonOperator from catalog.tests.dags.common.loader.test_s3 import ( ACCESS_KEY, S3_LOCAL_ENDPOINT, SECRET_KEY, ) -from common.constants import POSTGRES_CONN_ID -from common.sql import PGExecuteQueryOperator, PostgresHook POSTGRES_TEST_CONN_ID = "postgres_openledger_testing" @@ -56,88 +50,3 @@ def empty_s3_bucket(request): bucket.create() yield bucket _delete_bucket(bucket) - - -@pytest.fixture -def identifier(request): - return f"{hash(request.node.name)}".replace("-", "_") - - -@pytest.fixture -def image_table(identifier): - # Parallelized tests need to use distinct database tables - return f"image_{identifier}" - - -TEST_SQL = "SELECT PG_SLEEP(1);" - - -def timed_pg_hook_sleeper( - task, - statement_timeout: float = None, -): - pg = PostgresHook( - default_statement_timeout=PostgresHook.get_execution_timeout(task), - conn_id=POSTGRES_CONN_ID, - ) - pg.run(sql=TEST_SQL, statement_timeout=statement_timeout) - - -def mapped_select_pg_hook( - select_val: int, - task: AbstractOperator, -): - pg = PostgresHook( - default_statement_timeout=PostgresHook.get_execution_timeout(task), - conn_id=POSTGRES_CONN_ID, - ) - return pg.run(f"select {select_val};") - - -def create_pg_timeout_tester_dag(): - with DAG( - dag_id="a_pg_timeout_tester", - schedule=None, - doc_md="DAG to test query timeouts in postgres", - start_date=datetime(2023, 1, 1), - ) as dag: - pg_operator_happy = PGExecuteQueryOperator( - task_id="pg_operator_happy", - retries=0, - conn_id=POSTGRES_CONN_ID, - sql=TEST_SQL, - execution_timeout=timedelta(seconds=2), - doc_md="Custom PG operator, with query finished before execution timeout", - ) - pg_hook_happy = PythonOperator( - task_id="pg_hook_happy", - retries=0, - python_callable=timed_pg_hook_sleeper, - execution_timeout=timedelta(hours=2), - doc_md="Custom PG hook, with query finished before execution timeout", - ) - pg_hook_no_timeout = PythonOperator( - task_id="pg_hook_no_timeout", - retries=0, - python_callable=timed_pg_hook_sleeper, - doc_md="Custom PG hook, with no execution timeout", - ) - pg_operator_mapped = PythonOperator.partial( - task_id="pg_operator_mapped", - retries=0, - execution_timeout=timedelta(minutes=1), - doc_md="Custom PG operator, mapped to list", - python_callable=mapped_select_pg_hook, - ).expand(op_args=[(1,), (2,)]) 
- [pg_operator_happy, pg_hook_happy, pg_hook_no_timeout, pg_operator_mapped] - return dag - - -@pytest.fixture(scope="session") -def mock_timeout_dag(): - return create_pg_timeout_tester_dag() - - -@pytest.fixture(scope="session") -def mock_pg_hook_task(mock_timeout_dag) -> PythonOperator: - return mock_timeout_dag.get_task("pg_hook_happy") diff --git a/catalog/tests/dags/common/loader/test_sql.py b/catalog/tests/dags/common/loader/test_sql.py index 4569c77e46a..ede644e31d3 100644 --- a/catalog/tests/dags/common/loader/test_sql.py +++ b/catalog/tests/dags/common/loader/test_sql.py @@ -11,10 +11,7 @@ from psycopg2.errors import InvalidTextRepresentation from catalog.tests.dags.common.conftest import POSTGRES_TEST_CONN_ID as POSTGRES_CONN_ID -from catalog.tests.dags.common.popularity.test_sql import ( - TableInfo, - _set_up_std_popularity_func, -) +from catalog.tests.dags.popularity.test_sql import _set_up_std_popularity_func from catalog.tests.test_utils import sql as utils from common.loader import sql from common.storage import columns as col @@ -29,22 +26,6 @@ def load_table(identifier): return f"load_image_{identifier}" -@pytest.fixture -def table_info( - image_table, - identifier, -) -> TableInfo: - return TableInfo( - image=image_table, - image_view=f"image_view_{identifier}", - metrics=f"image_popularity_metrics_{identifier}", - standardized_popularity=f"standardized_popularity_{identifier}", - popularity_percentile=f"popularity_percentile_{identifier}", - image_view_idx=f"test_view_id_{identifier}_idx", - provider_fid_idx=f"test_view_provider_fid_{identifier}_idx", - ) - - @pytest.fixture def postgres(load_table) -> utils.PostgresRef: conn = psycopg2.connect(utils.POSTGRES_TEST_URI) @@ -73,26 +54,24 @@ def postgres_with_load_table( @pytest.fixture -def postgres_with_load_and_image_table( - load_table, image_table, table_info, mock_pg_hook_task -): +def postgres_with_load_and_image_table(load_table, sql_info, mock_pg_hook_task): conn = psycopg2.connect(utils.POSTGRES_TEST_URI) cur = conn.cursor() drop_test_relations_query = f""" DROP TABLE IF EXISTS {load_table} CASCADE; - DROP TABLE IF EXISTS {image_table} CASCADE; - DROP INDEX IF EXISTS {image_table}_provider_fid_idx; - DROP TABLE IF EXISTS {table_info.metrics} CASCADE; - DROP FUNCTION IF EXISTS {table_info.standardized_popularity} CASCADE; - DROP FUNCTION IF EXISTS {table_info.popularity_percentile} CASCADE; + DROP TABLE IF EXISTS {sql_info.media_table} CASCADE; + DROP INDEX IF EXISTS {sql_info.media_table}_provider_fid_idx; + DROP TABLE IF EXISTS {sql_info.metrics_table} CASCADE; + DROP FUNCTION IF EXISTS {sql_info.standardized_popularity_fn} CASCADE; + DROP FUNCTION IF EXISTS {sql_info.popularity_percentile_fn} CASCADE; """ cur.execute(drop_test_relations_query) cur.execute(utils.CREATE_LOAD_TABLE_QUERY.format(load_table)) cur.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA public;') - cur.execute(utils.CREATE_IMAGE_TABLE_QUERY.format(image_table)) - cur.execute(utils.UNIQUE_CONDITION_QUERY.format(table=image_table)) + cur.execute(utils.CREATE_IMAGE_TABLE_QUERY.format(sql_info.media_table)) + cur.execute(utils.UNIQUE_CONDITION_QUERY.format(table=sql_info.media_table)) conn.commit() @@ -142,7 +121,7 @@ def test_create_loading_table_creates_table( postgres, load_table, identifier, mock_pg_hook_task ): postgres_conn_id = POSTGRES_CONN_ID - sql.create_loading_table(postgres_conn_id, identifier) + sql.create_loading_table(postgres_conn_id, identifier, media_type="image") check_query = ( f"SELECT EXISTS (SELECT FROM 
pg_tables WHERE tablename='{load_table}');" @@ -154,9 +133,9 @@ def test_create_loading_table_creates_table( def test_create_loading_table_errors_if_run_twice_with_same_id(postgres, identifier): postgres_conn_id = POSTGRES_CONN_ID - sql.create_loading_table(postgres_conn_id, identifier) + sql.create_loading_table(postgres_conn_id, identifier, media_type="image") with pytest.raises(Exception): - sql.create_loading_table(postgres_conn_id, identifier) + sql.create_loading_table(postgres_conn_id, identifier, media_type="image") @flaky @@ -421,6 +400,7 @@ def test_upsert_records_inserts_one_record_to_empty_image_table( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -470,10 +450,21 @@ def test_upsert_records_inserts_one_record_to_empty_image_table( load_data_query = f"""INSERT INTO {load_table} VALUES( {query_values} );""" - postgres_with_load_and_image_table.cursor.execute(load_data_query) - postgres_with_load_and_image_table.connection.commit() + + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + load_data_query, + {}, + sql_info, + mock_pg_hook_task, + ) + sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") actual_rows = postgres_with_load_and_image_table.cursor.fetchall() @@ -504,6 +495,7 @@ def test_upsert_records_inserts_two_records_to_image_table( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -534,8 +526,21 @@ def test_upsert_records_inserts_two_records_to_image_table( );""" postgres_with_load_and_image_table.cursor.execute(load_data_query) postgres_with_load_and_image_table.connection.commit() + + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + None, + {}, + sql_info, + mock_pg_hook_task, + ) + sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") actual_rows = postgres_with_load_and_image_table.cursor.fetchall() @@ -548,6 +553,7 @@ def test_upsert_records_replaces_updated_on_and_last_synced_with_source( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -567,11 +573,21 @@ def test_upsert_records_replaces_updated_on_and_last_synced_with_source( '{FID}','{LAND_URL}','{IMG_URL}','{LICENSE}','{VERSION}', '{PROVIDER}','{PROVIDER}' );""" - postgres_with_load_and_image_table.cursor.execute(load_data_query) - postgres_with_load_and_image_table.connection.commit() + + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + load_data_query, + {}, + sql_info, + mock_pg_hook_task, + ) sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") original_row = postgres_with_load_and_image_table.cursor.fetchall() @@ -585,7 +601,11 @@ def test_upsert_records_replaces_updated_on_and_last_synced_with_source( time.sleep(0.5) sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + 
postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") updated_result = postgres_with_load_and_image_table.cursor.fetchall() @@ -606,6 +626,7 @@ def test_upsert_records_replaces_data( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -668,10 +689,20 @@ def test_upsert_records_replaces_data( load_data_query_a = f"""INSERT INTO {load_table} VALUES( {query_values} );""" - postgres_with_load_and_image_table.cursor.execute(load_data_query_a) - postgres_with_load_and_image_table.connection.commit() + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + load_data_query_a, + {}, + sql_info, + mock_pg_hook_task, + ) + sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() @@ -705,7 +736,11 @@ def test_upsert_records_replaces_data( postgres_with_load_and_image_table.cursor.execute(load_data_query_b) postgres_with_load_and_image_table.connection.commit() sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") @@ -730,6 +765,7 @@ def test_upsert_records_does_not_replace_with_nulls( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -784,10 +820,20 @@ def test_upsert_records_does_not_replace_with_nulls( load_data_query_a = f"""INSERT INTO {load_table} VALUES( {query_values_a} );""" - postgres_with_load_and_image_table.cursor.execute(load_data_query_a) - postgres_with_load_and_image_table.connection.commit() + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + load_data_query_a, + {}, + sql_info, + mock_pg_hook_task, + ) + sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() @@ -811,7 +857,11 @@ def test_upsert_records_does_not_replace_with_nulls( postgres_with_load_and_image_table.cursor.execute(load_data_query_b) postgres_with_load_and_image_table.connection.commit() sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") @@ -837,6 +887,7 @@ def test_upsert_records_merges_meta_data( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -875,10 +926,20 @@ def test_upsert_records_merges_meta_data( load_data_query_b = f"""INSERT INTO {load_table} VALUES( {query_values_b} );""" - postgres_with_load_and_image_table.cursor.execute(load_data_query_a) - postgres_with_load_and_image_table.connection.commit() + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + load_data_query_a, + {}, + sql_info, + mock_pg_hook_task, + ) + 
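+    # The upsert itself now computes standardized popularity inline, so the
+    # standardized popularity function must exist before records can be
+    # upserted, which is why these tests first create it with an empty
+    # metrics dict.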
sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"DELETE FROM {load_table};") @@ -886,7 +947,11 @@ def test_upsert_records_merges_meta_data( postgres_with_load_and_image_table.cursor.execute(load_data_query_b) postgres_with_load_and_image_table.connection.commit() sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") @@ -903,6 +968,7 @@ def test_upsert_records_does_not_replace_with_null_values_in_meta_data( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -941,10 +1007,20 @@ def test_upsert_records_does_not_replace_with_null_values_in_meta_data( load_data_query_b = f"""INSERT INTO {load_table} VALUES( {query_values_b} );""" - postgres_with_load_and_image_table.cursor.execute(load_data_query_a) - postgres_with_load_and_image_table.connection.commit() + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + load_data_query_a, + {}, + sql_info, + mock_pg_hook_task, + ) + sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"DELETE FROM {load_table};") @@ -952,7 +1028,11 @@ def test_upsert_records_does_not_replace_with_null_values_in_meta_data( postgres_with_load_and_image_table.cursor.execute(load_data_query_b) postgres_with_load_and_image_table.connection.commit() sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") @@ -971,6 +1051,7 @@ def test_upsert_records_merges_tags( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -1016,10 +1097,19 @@ def test_upsert_records_merges_tags( load_data_query_b = f"""INSERT INTO {load_table} VALUES( {query_values_b} );""" - postgres_with_load_and_image_table.cursor.execute(load_data_query_a) - postgres_with_load_and_image_table.connection.commit() + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + load_data_query_a, + {}, + sql_info, + mock_pg_hook_task, + ) sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"DELETE FROM {load_table};") @@ -1027,7 +1117,11 @@ def test_upsert_records_merges_tags( postgres_with_load_and_image_table.cursor.execute(load_data_query_b) postgres_with_load_and_image_table.connection.commit() sql.upsert_records_to_db_table( - postgres_conn_id, identifier, 
db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") @@ -1050,6 +1144,7 @@ def test_upsert_records_does_not_replace_tags_with_null( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -1089,10 +1184,19 @@ def test_upsert_records_does_not_replace_tags_with_null( load_data_query_b = f"""INSERT INTO {load_table} VALUES( {query_values_b} );""" - postgres_with_load_and_image_table.cursor.execute(load_data_query_a) - postgres_with_load_and_image_table.connection.commit() + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + load_data_query_a, + {}, + sql_info, + mock_pg_hook_task, + ) sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"DELETE FROM {load_table};") @@ -1100,7 +1204,11 @@ def test_upsert_records_does_not_replace_tags_with_null( postgres_with_load_and_image_table.cursor.execute(load_data_query_b) postgres_with_load_and_image_table.connection.commit() sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") @@ -1122,6 +1230,7 @@ def test_upsert_records_replaces_null_tags( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -1160,10 +1269,19 @@ def test_upsert_records_replaces_null_tags( {query_values_b} );""" - postgres_with_load_and_image_table.cursor.execute(load_data_query_a) - postgres_with_load_and_image_table.connection.commit() + _set_up_std_popularity_func( + postgres_with_load_and_image_table, + load_data_query_a, + {}, + sql_info, + mock_pg_hook_task, + ) sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"DELETE FROM {load_table};") @@ -1171,7 +1289,11 @@ def test_upsert_records_replaces_null_tags( postgres_with_load_and_image_table.cursor.execute(load_data_query_b) postgres_with_load_and_image_table.connection.commit() sql.upsert_records_to_db_table( - postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task + postgres_conn_id, + identifier, + media_type="image", + sql_info=sql_info, + task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() postgres_with_load_and_image_table.cursor.execute(f"SELECT * FROM {image_table};") @@ -1194,6 +1316,7 @@ def test_upsert_records_handles_duplicate_url_and_does_not_merge( tmpdir, load_table, image_table, + sql_info, identifier, mock_pg_hook_task, ): @@ -1238,10 +1361,19 @@ def test_upsert_records_handles_duplicate_url_and_does_not_merge( # Simulate a DAG run where A is ingested into the loading table, upserted into # the image table, and finally the loading table is 
cleared for the next DAG run.
-    postgres_with_load_and_image_table.cursor.execute(load_data_query_a)
-    postgres_with_load_and_image_table.connection.commit()
+    _set_up_std_popularity_func(
+        postgres_with_load_and_image_table,
+        load_data_query_a,
+        {},
+        sql_info,
+        mock_pg_hook_task,
+    )
     sql.upsert_records_to_db_table(
-        postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task
+        postgres_conn_id,
+        identifier,
+        media_type="image",
+        sql_info=sql_info,
+        task=mock_pg_hook_task,
     )
     postgres_with_load_and_image_table.connection.commit()
     postgres_with_load_and_image_table.cursor.execute(f"DELETE FROM {load_table};")
@@ -1252,7 +1384,11 @@
     postgres_with_load_and_image_table.cursor.execute(load_data_query_b)
     postgres_with_load_and_image_table.connection.commit()
     sql.upsert_records_to_db_table(
-        postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task
+        postgres_conn_id,
+        identifier,
+        media_type="image",
+        sql_info=sql_info,
+        task=mock_pg_hook_task,
     )
     postgres_with_load_and_image_table.connection.commit()
@@ -1274,6 +1410,7 @@ def test_upsert_records_handles_duplicate_urls_in_a_single_batch_and_does_not_me
     tmpdir,
     load_table,
     image_table,
+    sql_info,
     identifier,
     mock_pg_hook_task,
 ):
@@ -1345,9 +1482,20 @@ def test_upsert_records_handles_duplicate_urls_in_a_single_batch_and_does_not_me
     rows = postgres_with_load_and_image_table.cursor.fetchall()
     assert len(rows) == 3
+    _set_up_std_popularity_func(
+        postgres_with_load_and_image_table,
+        None,
+        {},
+        sql_info,
+        mock_pg_hook_task,
+    )
     # Now try upserting the records from the loading table to the final image table.
     sql.upsert_records_to_db_table(
-        postgres_conn_id, identifier, db_table=image_table, task=mock_pg_hook_task
+        postgres_conn_id,
+        identifier,
+        media_type="image",
+        sql_info=sql_info,
+        task=mock_pg_hook_task,
     )
     postgres_with_load_and_image_table.connection.commit()
@@ -1370,7 +1518,7 @@ def test_upsert_records_calculates_standardized_popularity(
     load_table,
     image_table,
     identifier,
-    table_info,
+    sql_info,
     mock_pg_hook_task,
 ):
     postgres_conn_id = POSTGRES_CONN_ID
@@ -1407,17 +1555,17 @@
         PROVIDER: {"metric": "views", "percentile": 0.8},
     }
-    # Now we set up the popularity constants tables, views, and functions. This method will
-    # run the `data_query` to insert our test rows, which will initially have `null` standardized
+    # Now we set up the popularity metrics table and functions after running
+    # the `data_query` to insert our test rows, which will initially have `null` standardized
     # popularity (because no popularity constants exist). Then it will insert `metrics` into
-    # the `image_popularity_metrics` table, and create the `image_popularity_constants` view,
-    # calculating a value for the popularity constant for PROVIDER using those initial records.
+    # the `image_popularity_metrics` table, and calculate a value for the popularity constant
+    # for PROVIDER using those initial records.
     # Then it sets up the standardized popularity function itself.
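+    # As a worked example of the math: with percentile=0.8 and an
+    # 80th-percentile metric value of 50, the constant is
+    # ((1 - 0.8) / 0.8) * 50 = 12.5, and a record with 50 views then has
+    # standardized popularity 50 / (50 + 12.5) = 0.8; that is, a record at the
+    # percentile value scores exactly the percentile.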
_set_up_std_popularity_func( postgres_with_load_and_image_table, data_query, metrics, - table_info, + sql_info, mock_pg_hook_task, ) @@ -1471,8 +1619,8 @@ def test_upsert_records_calculates_standardized_popularity( sql.upsert_records_to_db_table( postgres_conn_id, identifier, - db_table=image_table, - popularity_function=table_info.standardized_popularity, + media_type="image", + sql_info=sql_info, task=mock_pg_hook_task, ) postgres_with_load_and_image_table.connection.commit() diff --git a/catalog/tests/dags/common/popularity/__init__.py b/catalog/tests/dags/common/popularity/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/catalog/tests/dags/common/popularity/test_dag_factory.py b/catalog/tests/dags/common/popularity/test_dag_factory.py deleted file mode 100644 index ba173a24dfe..00000000000 --- a/catalog/tests/dags/common/popularity/test_dag_factory.py +++ /dev/null @@ -1,101 +0,0 @@ -from datetime import timedelta -from unittest import mock - -import pytest -from airflow.models import DagRun -from airflow.models.dag import DAG -from airflow.utils.session import create_session -from airflow.utils.timezone import datetime -from airflow.utils.types import DagRunType -from popularity.dag_factory import get_providers_update_confs -from popularity.popularity_refresh_types import PopularityRefresh - -from catalog.tests.test_utils.sql import POSTGRES_CONN_ID - - -TEST_DAG_ID = "popularity_refresh_dag_factory_test_dag" -TEST_DAG = DAG(TEST_DAG_ID, default_args={"owner": "airflow"}) -TEST_DAY = datetime(2023, 1, 1) - - -@pytest.fixture(autouse=True) -def clean_db(): - with create_session() as session: - session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete() - - -def _create_dagrun(start_date, dag_state, conf={}): - return TEST_DAG.create_dagrun( - start_date=start_date, - execution_date=start_date, - data_interval=(start_date, start_date), - state=dag_state, - run_type=DagRunType.MANUAL, - conf=conf, - ) - - -@pytest.mark.parametrize( - "providers, media_type, expected_confs", - [ - # No providers for this media type - ([], "image", []), - ( - ["foo_provider"], - "image", - [ - { - "query_id": "foo_provider_popularity_refresh_20230101", - "table_name": "image", - "select_query": "WHERE provider='foo_provider' AND updated_on < '2023-01-01 00:00:00'", - "update_query": "SET standardized_popularity = standardized_image_popularity(image.provider, image.meta_data)", - "batch_size": 10000, - "update_timeout": 3600.0, - "dry_run": False, - "resume_update": False, - }, - ], - ), - ( - ["my_provider", "your_provider"], - "audio", - [ - { - "query_id": "my_provider_popularity_refresh_20230101", - "table_name": "audio", - "select_query": "WHERE provider='my_provider' AND updated_on < '2023-01-01 00:00:00'", - "update_query": "SET standardized_popularity = standardized_audio_popularity(audio.provider, audio.meta_data)", - "batch_size": 10000, - "update_timeout": 3600.0, - "dry_run": False, - "resume_update": False, - }, - { - "query_id": "your_provider_popularity_refresh_20230101", - "table_name": "audio", - "select_query": "WHERE provider='your_provider' AND updated_on < '2023-01-01 00:00:00'", - "update_query": "SET standardized_popularity = standardized_audio_popularity(audio.provider, audio.meta_data)", - "batch_size": 10000, - "update_timeout": 3600.0, - "dry_run": False, - "resume_update": False, - }, - ], - ), - ], -) -def test_get_providers_update_confs(providers, media_type, expected_confs): - with mock.patch( - 
"common.popularity.sql.get_providers_with_popularity_data_for_media_type", - return_value=providers, - ): - actual_confs = get_providers_update_confs.function( - POSTGRES_CONN_ID, - PopularityRefresh( - media_type=media_type, - refresh_popularity_batch_timeout=timedelta(hours=1), - ), - TEST_DAY, - ) - - assert actual_confs == expected_confs diff --git a/catalog/tests/dags/common/popularity/test_resources/mock_popularity_dump.tsv b/catalog/tests/dags/common/popularity/test_resources/mock_popularity_dump.tsv deleted file mode 100644 index b2e3a7cf65c..00000000000 --- a/catalog/tests/dags/common/popularity/test_resources/mock_popularity_dump.tsv +++ /dev/null @@ -1,4 +0,0 @@ -identifier provider global_usage_count views -00000000-0000-0000-0000-000000000001 foo 500 -00000000-0000-0000-0000-000000000002 foo 10 -00000000-0000-0000-0000-000000000003 foo 10 500 diff --git a/catalog/tests/dags/common/test_utils.py b/catalog/tests/dags/common/test_utils.py new file mode 100644 index 00000000000..aec2401e1fc --- /dev/null +++ b/catalog/tests/dags/common/test_utils.py @@ -0,0 +1,101 @@ +import pytest + +from common.utils import setup_kwargs_for_media_type + + +TEST_VALS_BY_MEDIA_TYPE = {"audio": "foo", "image": "bar"} +p = pytest.param + + +@pytest.mark.parametrize( + "media_type, my_param, expected_param", + ( + ("audio", None, "foo"), + ("image", None, "bar"), + # Pass in an explicit value for my_param; this should be returned + p( + "audio", + "hello world", + "hello world", + id="explicitly passed value should be returned", + ), + p( + "foo", + "hello world", + "hello world", + id="explicitly passed value is returned, even if the values dict does not have a key for the media type", + ), + # No media type + p( + None, + None, + None, + marks=pytest.mark.raises(exception=ValueError), + id="raises error when no media type passed", + ), + p( + "foo", + None, + None, + marks=pytest.mark.raises(exception=ValueError), + id="raises error when no matching key in values dict", + ), + ), +) +def test_setup_kwargs_for_media_type(media_type, my_param, expected_param): + @setup_kwargs_for_media_type(TEST_VALS_BY_MEDIA_TYPE, "my_param") + def test_fn(*, media_type: str, my_param: str = None): + assert my_param == expected_param + + test_fn(media_type=media_type, my_param=my_param) + + +def test_setup_kwargs_for_media_type_creates_new_decorator(): + # Create a new decorator using the factory + new_decorator = setup_kwargs_for_media_type(TEST_VALS_BY_MEDIA_TYPE, "new_param") + + # New function decorated with this decorator + @new_decorator + def test_fn(*, media_type: str, new_param: str = None): + return new_param + + assert test_fn(media_type="audio") == "foo" + + +def test_setup_kwargs_for_media_type_fails_without_media_type_kwarg(): + with pytest.raises(Exception, match="Improperly configured"): + # Decorated function does not have a media_type kwarg + @setup_kwargs_for_media_type(TEST_VALS_BY_MEDIA_TYPE, "my_param") + def test_fn(*, my_param: str = None): + pass + + +def test_setup_kwargs_for_media_type_fails_with_media_type_arg(): + with pytest.raises(Exception, match="Improperly configured"): + # Decorate a function that allows media_type to be passed as a keyword + # or as a positional argument + @setup_kwargs_for_media_type(TEST_VALS_BY_MEDIA_TYPE, "my_param") + def test_fn(media_type, my_param: str = None): + pass + + +def test_setup_kwargs_for_media_type_fails_with_var_kwargs(): + with pytest.raises(Exception, match="Improperly configured"): + # Decorate a function that has var kwargs but does not 
explicitly + # require a keyword-only `media_type` arg + @setup_kwargs_for_media_type(TEST_VALS_BY_MEDIA_TYPE, "my_param") + def test_fn(**kwargs): + pass + + +def test_setup_kwargs_for_media_type_fails_without_kwarg(): + # Decorated function does not have the kwarg we want populated + @setup_kwargs_for_media_type(TEST_VALS_BY_MEDIA_TYPE, "my_param") + def test_fn(*, media_type: str): + pass + + with pytest.raises( + TypeError, + match="got an unexpected keyword argument 'my_param'", + ): + test_fn(media_type="audio") diff --git a/catalog/tests/dags/conftest.py b/catalog/tests/dags/conftest.py index d994ae6729b..a9e53f3b60b 100644 --- a/catalog/tests/dags/conftest.py +++ b/catalog/tests/dags/conftest.py @@ -1,8 +1,14 @@ +from datetime import datetime, timedelta from unittest import mock import pytest +from airflow import DAG +from airflow.models.abstractoperator import AbstractOperator +from airflow.operators.python import PythonOperator from requests import Response +from common.constants import POSTGRES_CONN_ID, SQLInfo +from common.sql import PGExecuteQueryOperator, PostgresHook from oauth2 import oauth2 @@ -28,6 +34,104 @@ def _var_get_replacement(*args, **kwargs): return values[args[0]] +@pytest.fixture +def identifier(request): + return f"{hash(request.node.name)}".replace("-", "_") + + +@pytest.fixture +def image_table(identifier): + # Parallelized tests need to use distinct database tables + return f"image_{identifier}" + + +@pytest.fixture +def sql_info( + image_table, + identifier, +) -> SQLInfo: + return SQLInfo( + media_table=image_table, + metrics_table=f"image_popularity_metrics_{identifier}", + standardized_popularity_fn=f"standardized_image_popularity_{identifier}", + popularity_percentile_fn=f"image_popularity_percentile_{identifier}", + ) + + +TEST_SQL = "SELECT PG_SLEEP(1);" + + +def timed_pg_hook_sleeper( + task, + statement_timeout: float = None, +): + pg = PostgresHook( + default_statement_timeout=PostgresHook.get_execution_timeout(task), + conn_id=POSTGRES_CONN_ID, + ) + pg.run(sql=TEST_SQL, statement_timeout=statement_timeout) + + +def mapped_select_pg_hook( + select_val: int, + task: AbstractOperator, +): + pg = PostgresHook( + default_statement_timeout=PostgresHook.get_execution_timeout(task), + conn_id=POSTGRES_CONN_ID, + ) + return pg.run(f"select {select_val};") + + +def create_pg_timeout_tester_dag(): + with DAG( + dag_id="a_pg_timeout_tester", + schedule=None, + doc_md="DAG to test query timeouts in postgres", + start_date=datetime(2023, 1, 1), + ) as dag: + pg_operator_happy = PGExecuteQueryOperator( + task_id="pg_operator_happy", + retries=0, + conn_id=POSTGRES_CONN_ID, + sql=TEST_SQL, + execution_timeout=timedelta(seconds=2), + doc_md="Custom PG operator, with query finished before execution timeout", + ) + pg_hook_happy = PythonOperator( + task_id="pg_hook_happy", + retries=0, + python_callable=timed_pg_hook_sleeper, + execution_timeout=timedelta(hours=2), + doc_md="Custom PG hook, with query finished before execution timeout", + ) + pg_hook_no_timeout = PythonOperator( + task_id="pg_hook_no_timeout", + retries=0, + python_callable=timed_pg_hook_sleeper, + doc_md="Custom PG hook, with no execution timeout", + ) + pg_operator_mapped = PythonOperator.partial( + task_id="pg_operator_mapped", + retries=0, + execution_timeout=timedelta(minutes=1), + doc_md="Custom PG operator, mapped to list", + python_callable=mapped_select_pg_hook, + ).expand(op_args=[(1,), (2,)]) + [pg_operator_happy, pg_hook_happy, pg_hook_no_timeout, pg_operator_mapped] + return dag + + 
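Editor's note: `common.utils` is imported throughout this diff but its implementation is not part of the excerpt. The tests in `catalog/tests/dags/common/test_utils.py` above pin down the behavior of `setup_kwargs_for_media_type` fairly precisely, so the following is a minimal sketch consistent with those tests. The names, error messages, and internals are inferred, not authoritative, and the `common.constants` import assumes the repository context.

```python
import functools
import inspect

from common.constants import SQL_INFO_BY_MEDIA_TYPE  # assumed import path


def setup_kwargs_for_media_type(values_by_media_type: dict, kwarg_name: str):
    """Return a decorator that fills `kwarg_name` based on the media_type kwarg."""

    def decorator(func):
        # Require a keyword-only `media_type` argument, so it can never be
        # passed positionally or swallowed silently by **kwargs.
        media_param = inspect.signature(func).parameters.get("media_type")
        if media_param is None or media_param.kind != inspect.Parameter.KEYWORD_ONLY:
            raise Exception(
                f"Improperly configured function {func.__name__}: expected a"
                " keyword-only `media_type` argument."
            )

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # An explicitly passed value always wins; otherwise look up a value
            # for this media type, failing loudly if none is configured.
            if kwargs.get(kwarg_name) is None:
                media_type = kwargs.get("media_type")
                if media_type not in values_by_media_type:
                    raise ValueError(
                        f"No {kwarg_name} configured for media type {media_type}."
                    )
                kwargs[kwarg_name] = values_by_media_type[media_type]
            return func(*args, **kwargs)

        return wrapped

    return decorator


# `setup_sql_info_for_media_type`, used by popularity/sql.py above, is then
# plausibly just this factory applied to the SQL_INFO_BY_MEDIA_TYPE mapping:
setup_sql_info_for_media_type = setup_kwargs_for_media_type(
    SQL_INFO_BY_MEDIA_TYPE, "sql_info"
)
```

The keyword-only requirement is exactly what the "Improperly configured" tests above exercise; it guarantees the decorator can always find `media_type` in `kwargs` rather than buried in positional arguments.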
+@pytest.fixture(scope="session") +def mock_timeout_dag(): + return create_pg_timeout_tester_dag() + + +@pytest.fixture(scope="session") +def mock_pg_hook_task(mock_timeout_dag) -> PythonOperator: + return mock_timeout_dag.get_task("pg_hook_happy") + + @pytest.fixture def oauth_provider_var_mock(): with mock.patch("oauth2.oauth2.Variable") as MockVariable: diff --git a/catalog/dags/common/popularity/__init__.py b/catalog/tests/dags/popularity/__init__.py similarity index 100% rename from catalog/dags/common/popularity/__init__.py rename to catalog/tests/dags/popularity/__init__.py diff --git a/catalog/tests/dags/popularity/test_popularity_refresh_types.py b/catalog/tests/dags/popularity/test_popularity_refresh_types.py new file mode 100644 index 00000000000..6cece979c9f --- /dev/null +++ b/catalog/tests/dags/popularity/test_popularity_refresh_types.py @@ -0,0 +1,36 @@ +import re +from pathlib import Path + +import pytest +from popularity.popularity_refresh_types import POPULARITY_REFRESH_CONFIGS + + +DDL_DEFINITIONS_PATH = Path(__file__).parents[4] / "docker" / "upstream_db" + + +@pytest.mark.parametrize( + "ddl_filename, metrics", + [ + ( + "0004_openledger_image_view.sql", + POPULARITY_REFRESH_CONFIGS[0].popularity_metrics, + ), + ( + "0007_openledger_audio_view.sql", + POPULARITY_REFRESH_CONFIGS[1].popularity_metrics, + ), + ], +) +def test_ddl_matches_definitions(ddl_filename, metrics): + ddl = (DDL_DEFINITIONS_PATH / ddl_filename).read_text() + if not ( + match := re.search( + r"INSERT INTO public.\w+_popularity_metrics.*?;", + ddl, + re.MULTILINE | re.DOTALL, + ) + ): + raise ValueError(f"Could not find insert statement in ddl file {ddl_filename}") + + for provider in metrics: + assert provider in match.group(0) diff --git a/catalog/tests/dags/common/popularity/test_sql.py b/catalog/tests/dags/popularity/test_sql.py similarity index 57% rename from catalog/tests/dags/common/popularity/test_sql.py rename to catalog/tests/dags/popularity/test_sql.py index 12b9285e3e3..3792b22061f 100644 --- a/catalog/tests/dags/common/popularity/test_sql.py +++ b/catalog/tests/dags/popularity/test_sql.py @@ -1,73 +1,58 @@ import os -import re from collections import namedtuple -from pathlib import Path +from datetime import datetime, timedelta from textwrap import dedent -from typing import NamedTuple import psycopg2 import pytest +from popularity import sql +from popularity.popularity_refresh_types import PopularityRefresh from catalog.tests.dags.common.conftest import POSTGRES_TEST_CONN_ID as POSTGRES_CONN_ID +from common.constants import SQLInfo from common.loader.sql import create_column_definitions -from common.popularity import sql from common.storage.db_columns import IMAGE_TABLE_COLUMNS -DDL_DEFINITIONS_PATH = Path(__file__).parents[5] / "docker" / "upstream_db" POSTGRES_TEST_URI = os.getenv("AIRFLOW_CONN_POSTGRES_OPENLEDGER_TESTING") -class TableInfo(NamedTuple): - image: str - image_view: str - metrics: str - standardized_popularity: str - popularity_percentile: str - image_view_idx: str - provider_fid_idx: str - - @pytest.fixture -def table_info( +def sql_info( image_table, identifier, -) -> TableInfo: - return TableInfo( - image=image_table, - image_view=f"image_view_{identifier}", - metrics=f"image_popularity_metrics_{identifier}", - standardized_popularity=f"standardized_popularity_{identifier}", - popularity_percentile=f"popularity_percentile_{identifier}", - image_view_idx=f"test_view_id_{identifier}_idx", - provider_fid_idx=f"test_view_provider_fid_{identifier}_idx", +) -> SQLInfo: + 
return SQLInfo( + media_table=image_table, + metrics_table=f"image_popularity_metrics_{identifier}", + standardized_popularity_fn=f"standardized_image_popularity_{identifier}", + popularity_percentile_fn=f"image_popularity_percentile_{identifier}", ) @pytest.fixture -def postgres_with_image_table(table_info): +def postgres_with_image_table(sql_info): Postgres = namedtuple("Postgres", ["cursor", "connection"]) conn = psycopg2.connect(POSTGRES_TEST_URI) cur = conn.cursor() drop_test_relations_query = f""" - DROP MATERIALIZED VIEW IF EXISTS {table_info.image_view} CASCADE; - DROP TABLE IF EXISTS {table_info.metrics} CASCADE; - DROP TABLE IF EXISTS {table_info.image} CASCADE; - DROP FUNCTION IF EXISTS {table_info.standardized_popularity} CASCADE; - DROP FUNCTION IF EXISTS {table_info.popularity_percentile} CASCADE; + DROP TABLE IF EXISTS {sql_info.metrics_table} CASCADE; + DROP TABLE IF EXISTS {sql_info.media_table} CASCADE; + DROP FUNCTION IF EXISTS {sql_info.standardized_popularity_fn} CASCADE; + DROP FUNCTION IF EXISTS {sql_info.popularity_percentile_fn} CASCADE; """ cur.execute(drop_test_relations_query) cur.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA public;') image_columns = create_column_definitions(IMAGE_TABLE_COLUMNS) - cur.execute(f"CREATE TABLE public.{table_info.image} ({image_columns});") + cur.execute(f"CREATE TABLE public.{sql_info.media_table} ({image_columns});") cur.execute( f""" -CREATE UNIQUE INDEX {table_info.image}_provider_fid_idx -ON public.{table_info.image} +CREATE UNIQUE INDEX {sql_info.media_table}_provider_fid_idx +ON public.{sql_info.media_table} USING btree (provider, md5(foreign_identifier)); """ ) @@ -82,20 +67,21 @@ def postgres_with_image_table(table_info): conn.close() -def _set_up_popularity_metrics(metrics_dict, table_info, mock_pg_hook_task): +def _set_up_popularity_metrics(metrics_dict, sql_info, mock_pg_hook_task): conn_id = POSTGRES_CONN_ID # Create metrics table - sql.create_media_popularity_metrics( - postgres_conn_id=conn_id, - popularity_metrics_table=table_info.metrics, + sql.create_media_popularity_metrics.function( + postgres_conn_id=conn_id, media_type="image", sql_info=sql_info ) # Insert values from metrics_dict into metrics table - sql.update_media_popularity_metrics.function( - postgres_conn_id=conn_id, - popularity_metrics=metrics_dict, - popularity_metrics_table=table_info.metrics, - task=mock_pg_hook_task, - ) + if metrics_dict: + sql.update_media_popularity_metrics.function( + postgres_conn_id=conn_id, + media_type="image", + popularity_metrics=metrics_dict, + sql_info=sql_info, + task=mock_pg_hook_task, + ) # For each provider in metrics_dict, calculate the percentile and then # update the percentile and popularity constant @@ -103,95 +89,68 @@ def _set_up_popularity_metrics(metrics_dict, table_info, mock_pg_hook_task): percentile_val = sql.calculate_media_popularity_percentile_value.function( postgres_conn_id=conn_id, provider=provider, + media_type="image", task=mock_pg_hook_task, - popularity_metrics_table=table_info.metrics, - popularity_percentile=table_info.popularity_percentile, + sql_info=sql_info, ) sql.update_percentile_and_constants_values_for_provider.function( postgres_conn_id=conn_id, provider=provider, raw_percentile_value=percentile_val, - popularity_metrics_table=table_info.metrics, + media_type="image", popularity_metrics=metrics_dict, + sql_info=sql_info, ) -def _set_up_popularity_percentile_function(table_info): +def _set_up_popularity_percentile_function(sql_info): conn_id = POSTGRES_CONN_ID - 
sql.create_media_popularity_percentile_function( - conn_id, - popularity_percentile=table_info.popularity_percentile, - media_table=table_info.image, + sql.create_media_popularity_percentile_function.function( + conn_id, media_type="image", sql_info=sql_info ) -def _set_up_popularity_constants( +def _set_up_popularity_metrics_and_constants( pg, data_query, metrics_dict, - table_info, + sql_info, mock_pg_hook_task, ): # Execute the data query first (typically, loads sample data into the media table) - pg.cursor.execute(data_query) - pg.connection.commit() + if data_query: + pg.cursor.execute(data_query) + pg.connection.commit() # Then set up functions, metrics, and constants - _set_up_popularity_percentile_function(table_info) - _set_up_popularity_metrics(metrics_dict, table_info, mock_pg_hook_task) + _set_up_popularity_percentile_function(sql_info) + _set_up_popularity_metrics(metrics_dict, sql_info, mock_pg_hook_task) def _set_up_std_popularity_func( pg, data_query, metrics_dict, - table_info, + sql_info, mock_pg_hook_task, ): conn_id = POSTGRES_CONN_ID - _set_up_popularity_constants( + _set_up_popularity_metrics_and_constants( pg, data_query, metrics_dict, - table_info, - mock_pg_hook_task, - ) - sql.create_standardized_media_popularity_function( - conn_id, + sql_info, mock_pg_hook_task, - function_name=table_info.standardized_popularity, - popularity_metrics=table_info.metrics, - ) - - -def _set_up_image_view( - pg, - data_query, - metrics_dict, - table_info, - mock_pg_hook_task, -): - conn_id = POSTGRES_CONN_ID - _set_up_std_popularity_func( - pg, data_query, metrics_dict, table_info, mock_pg_hook_task ) - sql.create_media_view( - conn_id, - standardized_popularity_func=table_info.standardized_popularity, - table_name=table_info.image, - db_view_name=table_info.image_view, - db_view_id_idx=table_info.image_view_idx, - db_view_provider_fid_idx=table_info.provider_fid_idx, - task=mock_pg_hook_task, + sql.create_standardized_media_popularity_function.function( + conn_id, media_type="image", sql_info=sql_info ) -def test_popularity_percentile_function_calculates( - postgres_with_image_table, table_info -): +def test_popularity_percentile_function_calculates(postgres_with_image_table, sql_info): data_query = dedent( f""" - INSERT INTO {table_info.image} ( + INSERT INTO {sql_info.media_table} ( created_on, updated_on, provider, foreign_identifier, url, meta_data, license, removed_from_source ) @@ -225,10 +184,10 @@ def test_popularity_percentile_function_calculates( ) postgres_with_image_table.cursor.execute(data_query) postgres_with_image_table.connection.commit() - _set_up_popularity_percentile_function(table_info) + _set_up_popularity_percentile_function(sql_info) mp_perc_1 = dedent( f""" - SELECT {table_info.popularity_percentile}('my_provider', 'views', 0.5); + SELECT {sql_info.popularity_percentile_fn}('my_provider', 'views', 0.5); """ ) postgres_with_image_table.cursor.execute(mp_perc_1) @@ -237,7 +196,7 @@ def test_popularity_percentile_function_calculates( assert actual_percentile_val == expect_percentile_val mp_perc_2 = dedent( f""" - SELECT {table_info.popularity_percentile}('diff_provider', 'comments', 0.3); + SELECT {sql_info.popularity_percentile_fn}('diff_provider', 'comments', 0.3); """ ) postgres_with_image_table.cursor.execute(mp_perc_2) @@ -247,11 +206,11 @@ def test_popularity_percentile_function_calculates( def test_popularity_percentile_function_nones_when_missing_type( - postgres_with_image_table, table_info + postgres_with_image_table, sql_info ): data_query = dedent( 
f""" - INSERT INTO {table_info.image} ( + INSERT INTO {sql_info.media_table} ( created_on, updated_on, provider, foreign_identifier, url, meta_data, license, removed_from_source ) @@ -269,10 +228,10 @@ def test_popularity_percentile_function_nones_when_missing_type( ) postgres_with_image_table.cursor.execute(data_query) postgres_with_image_table.connection.commit() - _set_up_popularity_percentile_function(table_info) + _set_up_popularity_percentile_function(sql_info) mp_perc_3 = dedent( f""" - SELECT {table_info.popularity_percentile}('diff_provider', 'views', 0.3); + SELECT {sql_info.popularity_percentile_fn}('diff_provider', 'views', 0.3); """ ) postgres_with_image_table.cursor.execute(mp_perc_3) @@ -281,11 +240,11 @@ def test_popularity_percentile_function_nones_when_missing_type( def test_metrics_table_adds_values_and_constants( - postgres_with_image_table, table_info, mock_pg_hook_task + postgres_with_image_table, sql_info, mock_pg_hook_task ): data_query = dedent( f""" - INSERT INTO {table_info.image} ( + INSERT INTO {sql_info.media_table} ( created_on, updated_on, provider, foreign_identifier, url, meta_data, license, removed_from_source ) @@ -321,11 +280,11 @@ def test_metrics_table_adds_values_and_constants( "my_provider": {"metric": "views", "percentile": 0.5}, "diff_provider": {"metric": "comments", "percentile": 0.8}, } - _set_up_popularity_constants( - postgres_with_image_table, data_query, metrics, table_info, mock_pg_hook_task + _set_up_popularity_metrics_and_constants( + postgres_with_image_table, data_query, metrics, sql_info, mock_pg_hook_task ) - check_query = f"SELECT * FROM {table_info.metrics};" + check_query = f"SELECT * FROM {sql_info.metrics_table};" postgres_with_image_table.cursor.execute(check_query) expect_rows = [ ("diff_provider", "comments", 0.8, 50.0, 12.5), @@ -337,11 +296,11 @@ def test_metrics_table_adds_values_and_constants( def test_metrics_table_handles_zeros_and_missing_in_constants( - postgres_with_image_table, table_info, mock_pg_hook_task + postgres_with_image_table, sql_info, mock_pg_hook_task ): data_query = dedent( f""" - INSERT INTO {table_info.image} ( + INSERT INTO {sql_info.media_table} ( created_on, updated_on, provider, foreign_identifier, url, meta_data, license, removed_from_source ) @@ -379,11 +338,11 @@ def test_metrics_table_handles_zeros_and_missing_in_constants( # Provider that has a metric configured, but no records with data for that metric "diff_provider": {"metric": "comments", "percentile": 0.8}, } - _set_up_popularity_constants( - postgres_with_image_table, data_query, metrics, table_info, mock_pg_hook_task + _set_up_popularity_metrics_and_constants( + postgres_with_image_table, data_query, metrics, sql_info, mock_pg_hook_task ) - check_query = f"SELECT * FROM {table_info.metrics};" + check_query = f"SELECT * FROM {sql_info.metrics_table};" postgres_with_image_table.cursor.execute(check_query) expect_rows = [ ("diff_provider", "comments", 0.8, None, None), @@ -394,53 +353,12 @@ def test_metrics_table_handles_zeros_and_missing_in_constants( assert expect_row == pytest.approx(sorted_row) -def test_get_providers_with_popularity_data_for_media_type( - postgres_with_image_table, table_info, mock_pg_hook_task -): - data_query = dedent( - f""" - INSERT INTO {table_info.image} ( - created_on, updated_on, provider, foreign_identifier, url, - meta_data, license, removed_from_source - ) - VALUES - ( - NOW(), NOW(), 'my_provider', 'fid_a', 'https://test.com/a.jpg', - '{{"views": 0, "description": "cats"}}', 'cc0', false - ), - ( - NOW(), 
NOW(), 'diff_provider', 'fid_b', 'https://test.com/b.jpg', - '{{"views": 50, "description": "cats"}}', 'cc0', false - ), - ( - NOW(), NOW(), 'provider_without_popularity', 'fid_b', 'https://test.com/b.jpg', - '{{"views": 50, "description": "cats"}}', 'cc0', false - ) - ; - """ - ) - metrics = { - "my_provider": {"metric": "views", "percentile": 0.8}, - "diff_provider": {"metric": "comments", "percentile": 0.8}, - } - _set_up_popularity_constants( - postgres_with_image_table, data_query, metrics, table_info, mock_pg_hook_task - ) - - expected_providers = ["diff_provider", "my_provider"] - actual_providers = sql.get_providers_with_popularity_data_for_media_type( - POSTGRES_CONN_ID, media_type="image", popularity_metrics=table_info.metrics - ) - - assert actual_providers == expected_providers - - def test_standardized_popularity_function_calculates( - postgres_with_image_table, table_info, mock_pg_hook_task + postgres_with_image_table, sql_info, mock_pg_hook_task ): data_query = dedent( f""" - INSERT INTO {table_info.image} ( + INSERT INTO {sql_info.media_table} ( created_on, updated_on, provider, foreign_identifier, url, meta_data, license, removed_from_source ) @@ -466,9 +384,9 @@ def test_standardized_popularity_function_calculates( "other_provider": {"metric": "likes", "percentile": 0.5}, } _set_up_std_popularity_func( - postgres_with_image_table, data_query, metrics, table_info, mock_pg_hook_task + postgres_with_image_table, data_query, metrics, sql_info, mock_pg_hook_task ) - check_query = f"SELECT * FROM {table_info.metrics};" + check_query = f"SELECT * FROM {sql_info.metrics_table};" postgres_with_image_table.cursor.execute(check_query) print(list(postgres_with_image_table.cursor)) arg_list = [ @@ -487,7 +405,7 @@ def test_standardized_popularity_function_calculates( print(arg_list[i]) std_pop_query = dedent( f""" - SELECT {table_info.standardized_popularity}( + SELECT {sql_info.standardized_popularity_fn}( '{arg_list[i][0]}', '{arg_list[i][1]}'::jsonb ); @@ -499,73 +417,67 @@ def test_standardized_popularity_function_calculates( assert actual_std_pop_val == expect_std_pop_val -def test_image_view_calculates_std_pop( - postgres_with_image_table, table_info, mock_pg_hook_task -): - data_query = dedent( - f""" - INSERT INTO {table_info.image} ( - created_on, updated_on, provider, foreign_identifier, url, - meta_data, license, removed_from_source - ) - VALUES - ( - NOW(), NOW(), 'my_provider', 'fid_a', 'https://test.com/a.jpg', - '{{"views": 0, "description": "cats"}}', 'cc0', false - ), - ( - NOW(), NOW(), 'my_provider', 'fid_b', 'https://test.com/b.jpg', - '{{"views": 50, "description": "cats"}}', 'cc0', false - ), - ( - NOW(), NOW(), 'my_provider', 'fid_c', 'https://test.com/c.jpg', - '{{"views": 75, "description": "cats"}}', 'cc0', false - ), - ( - NOW(), NOW(), 'my_provider', 'fid_d', 'https://test.com/d.jpg', - '{{"views": 150, "description": "cats"}}', 'cc0', false - ) - """ - ) - metrics = {"my_provider": {"metric": "views", "percentile": 0.5}} - _set_up_image_view( - postgres_with_image_table, data_query, metrics, table_info, mock_pg_hook_task - ) - check_query = dedent( - f""" - SELECT foreign_identifier, standardized_popularity - FROM {table_info.image_view}; - """ - ) - postgres_with_image_table.cursor.execute(check_query) - rd = dict(postgres_with_image_table.cursor) - assert all( - [ - rd["fid_a"] == 0.0, - rd["fid_b"] == 0.5, - rd["fid_c"] == 0.6, - rd["fid_d"] == 0.75, - ] - ) - - @pytest.mark.parametrize( - "ddl_filename, metrics", + "providers, media_type, 
expected_confs", [ - ("0004_openledger_image_view.sql", sql.IMAGE_POPULARITY_METRICS), - ("0007_openledger_audio_view.sql", sql.AUDIO_POPULARITY_METRICS), + # No providers for this media type + ([], "image", []), + ( + ["foo_provider"], + "image", + [ + { + "query_id": "foo_provider_popularity_refresh_20230101", + "table_name": "image", + "select_query": "WHERE provider='foo_provider' AND updated_on < '2023-01-01 00:00:00'", + "update_query": "SET standardized_popularity = standardized_image_popularity(image.provider, image.meta_data)", + "batch_size": 10000, + "update_timeout": 3600.0, + "dry_run": False, + "resume_update": False, + }, + ], + ), + ( + ["my_provider", "your_provider"], + "audio", + [ + { + "query_id": "my_provider_popularity_refresh_20230101", + "table_name": "audio", + "select_query": "WHERE provider='my_provider' AND updated_on < '2023-01-01 00:00:00'", + "update_query": "SET standardized_popularity = standardized_audio_popularity(audio.provider, audio.meta_data)", + "batch_size": 10000, + "update_timeout": 3600.0, + "dry_run": False, + "resume_update": False, + }, + { + "query_id": "your_provider_popularity_refresh_20230101", + "table_name": "audio", + "select_query": "WHERE provider='your_provider' AND updated_on < '2023-01-01 00:00:00'", + "update_query": "SET standardized_popularity = standardized_audio_popularity(audio.provider, audio.meta_data)", + "batch_size": 10000, + "update_timeout": 3600.0, + "dry_run": False, + "resume_update": False, + }, + ], + ), ], ) -def test_ddl_matches_definitions(ddl_filename, metrics): - ddl = (DDL_DEFINITIONS_PATH / ddl_filename).read_text() - if not ( - match := re.search( - r"INSERT INTO public.\w+_popularity_metrics.*?;", - ddl, - re.MULTILINE | re.DOTALL, - ) - ): - raise ValueError(f"Could not find insert statement in ddl file {ddl_filename}") +def test_get_providers_update_confs(providers, media_type, expected_confs): + TEST_DAY = datetime(2023, 1, 1) + config = PopularityRefresh( + media_type=media_type, + refresh_popularity_batch_timeout=timedelta(hours=1), + popularity_metrics={provider: {"metric": "views"} for provider in providers}, + ) + + actual_confs = sql.get_providers_update_confs.function( + POSTGRES_CONN_ID, + config, + TEST_DAY, + ) - for provider in metrics: - assert provider in match.group(0) + assert actual_confs == expected_confs diff --git a/catalog/tests/dags/test_dag_parsing.py b/catalog/tests/dags/test_dag_parsing.py index e778f4a1f4e..eb90f9b8774 100644 --- a/catalog/tests/dags/test_dag_parsing.py +++ b/catalog/tests/dags/test_dag_parsing.py @@ -20,7 +20,8 @@ "maintenance/airflow_log_cleanup_workflow.py", "maintenance/pr_review_reminders/pr_review_reminders_dag.py", "maintenance/rotate_db_snapshots.py", - "database/recreate_popularity_calculation_dag_factory.py", + "popularity/recreate_popularity_calculation_dag_factory.py", + "popularity/popularity_refresh_dag_factory.py", "data_refresh/dag_factory.py", "data_refresh/create_filtered_index_dag.py", "oauth2/authorize_dag.py", @@ -34,7 +35,8 @@ "providers/provider_ingestion_workflow_dag_factory.py": len( REINGESTION_WORKFLOW_CONFIGS ), - "database/recreate_popularity_calculation_dag_factory.py": len(MEDIA_TYPES), + "popularity/recreate_popularity_calculation_dag_factory.py": len(MEDIA_TYPES), + "popularity/popularity_refresh_dag_factory.py": len(MEDIA_TYPES), "data_refresh/dag_factory.py": len(MEDIA_TYPES), "data_refresh/create_filtered_index_dag.py": len(MEDIA_TYPES), } diff --git a/catalog/tests/test_utils/sql.py b/catalog/tests/test_utils/sql.py index 
b58c11a4efb..2072c33f510 100644 --- a/catalog/tests/test_utils/sql.py +++ b/catalog/tests/test_utils/sql.py @@ -4,10 +4,10 @@ from airflow.models import TaskInstance -from common.constants import IMAGE -from common.loader.sql import TSV_COLUMNS, create_column_definitions +from common.loader.sql import create_column_definitions from common.storage import columns as col from common.storage.db_columns import IMAGE_TABLE_COLUMNS +from common.storage.tsv_columns import CURRENT_IMAGE_TSV_COLUMNS POSTGRES_CONN_ID = os.getenv("TEST_CONN_ID") @@ -18,7 +18,7 @@ LOADING_TABLE_COLUMN_DEFINITIONS = create_column_definitions( - TSV_COLUMNS[IMAGE], is_loading=True + CURRENT_IMAGE_TSV_COLUMNS, is_loading=True ) CREATE_LOAD_TABLE_QUERY = f"""CREATE TABLE public.{{}} ( @@ -80,7 +80,7 @@ def create_query_values( columns=None, ): if columns is None: - columns = TSV_COLUMNS[IMAGE] + columns = CURRENT_IMAGE_TSV_COLUMNS result = [] for column in columns: val = column_values.get(column.db_name) diff --git a/docker/upstream_db/0004_openledger_image_view.sql b/docker/upstream_db/0004_openledger_image_view.sql index efdf2060749..e3802f01874 100644 --- a/docker/upstream_db/0004_openledger_image_view.sql +++ b/docker/upstream_db/0004_openledger_image_view.sql @@ -42,38 +42,3 @@ $$ LANGUAGE SQL STABLE RETURNS NULL ON NULL INPUT; - - -CREATE MATERIALIZED VIEW image_view AS - SELECT - identifier, - created_on, - updated_on, - ingestion_type, - provider, - source, - foreign_identifier, - foreign_landing_url, - url, - thumbnail, - width, - height, - filesize, - license, - license_version, - creator, - creator_url, - title, - meta_data, - tags, - watermarked, - last_synced_with_source, - removed_from_source, - filetype, - category, - standardized_image_popularity( - image.provider, image.meta_data - ) AS standardized_popularity - FROM image; - -CREATE UNIQUE INDEX ON image_view (identifier); diff --git a/docker/upstream_db/0007_openledger_audio_view.sql b/docker/upstream_db/0007_openledger_audio_view.sql index f73748e7f89..af625f4c261 100644 --- a/docker/upstream_db/0007_openledger_audio_view.sql +++ b/docker/upstream_db/0007_openledger_audio_view.sql @@ -38,46 +38,6 @@ STABLE RETURNS NULL ON NULL INPUT; -CREATE MATERIALIZED VIEW audio_view AS - SELECT - identifier, - created_on, - updated_on, - ingestion_type, - provider, - source, - foreign_identifier, - foreign_landing_url, - url, - thumbnail, - filetype, - duration, - bit_rate, - sample_rate, - category, - genres, - audio_set, - alt_files, - filesize, - license, - license_version, - creator, - creator_url, - title, - meta_data, - tags, - watermarked, - last_synced_with_source, - removed_from_source, - audio_set ->> 'foreign_identifier' AS audio_set_foreign_identifier, - standardized_audio_popularity( - audio.provider, audio.meta_data - ) AS standardized_popularity - FROM audio; - -CREATE UNIQUE INDEX ON audio_view (identifier); - - CREATE VIEW audioset_view AS -- DISTINCT clause exists to ensure that only one record is present for a given -- foreign identifier/provider pair. 
This exists as a hard constraint in the API table diff --git a/ingestion_server/test/integration_test.py b/ingestion_server/test/integration_test.py index 3c9678143dd..cf43d539b67 100644 --- a/ingestion_server/test/integration_test.py +++ b/ingestion_server/test/integration_test.py @@ -167,7 +167,6 @@ def _load_data(cls, conn, table_names): f"COPY {table_name} FROM STDIN WITH (FORMAT csv, HEADER true)", data, ) - cur.execute(f"REFRESH MATERIALIZED VIEW {table_name}_view") conn.commit() cur.close() diff --git a/ingestion_server/test/mock_schemas/audio_view.sql b/ingestion_server/test/mock_schemas/audio_view.sql deleted file mode 100644 index fa5e991a903..00000000000 --- a/ingestion_server/test/mock_schemas/audio_view.sql +++ /dev/null @@ -1,192 +0,0 @@ --- --- PostgreSQL database dump --- - --- Dumped from database version 13.2 --- Dumped by pg_dump version 13.3 (Debian 13.3-1.pgdg100+1) - -SET statement_timeout = 0; -SET lock_timeout = 0; -SET idle_in_transaction_session_timeout = 0; -SET client_encoding = 'UTF8'; -SET standard_conforming_strings = on; -SET check_function_bodies = false; -SET xmloption = content; -SET client_min_messages = warning; -SET row_security = off; - -SET default_tablespace = ''; - -SET default_table_access_method = heap; - --- --- Name: audio_view; Type: TABLE; Schema: public; Owner: deploy --- - -CREATE TABLE public.audio_view ( - id integer NOT NULL, - created_on timestamp with time zone NOT NULL, - updated_on timestamp with time zone NOT NULL, - identifier uuid NOT NULL, - foreign_identifier character varying(1000), - title character varying(2000), - foreign_landing_url character varying(1000), - creator character varying(2000), - creator_url character varying(2000), - url character varying(1000), - filesize integer, - watermarked boolean, - license character varying(50) NOT NULL, - license_version character varying(25), - provider character varying(80), - source character varying(80), - last_synced_with_source timestamp with time zone, - removed_from_source boolean NOT NULL, - view_count integer, - tags jsonb, - meta_data jsonb, - audio_set_position integer, - genres character varying(80)[], - category character varying(80), - duration integer, - bit_rate integer, - sample_rate integer, - alt_files jsonb, - thumbnail character varying(1000), - filetype character varying(80), - audio_set_foreign_identifier character varying(1000), - standardized_popularity double precision, - ingestion_type character varying(1000), - audio_set jsonb -); - - -ALTER TABLE public.audio_view OWNER TO deploy; - --- --- Name: audio_id_seq; Type: SEQUENCE; Schema: public; Owner: deploy --- - -CREATE SEQUENCE public.audio_id_seq - AS integer - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; - - -ALTER TABLE public.audio_id_seq OWNER TO deploy; - --- --- Name: audio_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: deploy --- - -ALTER SEQUENCE public.audio_id_seq OWNED BY public.audio_view.id; - - --- --- Name: audio_view id; Type: DEFAULT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.audio_view ALTER COLUMN id SET DEFAULT nextval('public.audio_id_seq'::regclass); - - --- --- Name: audio_view audio_identifier_key; Type: CONSTRAINT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.audio_view - ADD CONSTRAINT audio_identifier_key UNIQUE (identifier); - - --- --- Name: audio_view audio_pkey; Type: CONSTRAINT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.audio_view - ADD CONSTRAINT audio_pkey PRIMARY KEY (id); - - --- 
--- Name: audio_view audio_url_key; Type: CONSTRAINT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.audio_view - ADD CONSTRAINT audio_url_key UNIQUE (url); - - --- --- Name: audio_view unique_provider_audio; Type: CONSTRAINT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.audio_view - ADD CONSTRAINT unique_provider_audio UNIQUE (foreign_identifier, provider); - - --- --- Name: audio_category_ceb7d386; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX audio_category_ceb7d386 ON public.audio_view USING btree (category); - - --- --- Name: audio_category_ceb7d386_like; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX audio_category_ceb7d386_like ON public.audio_view USING btree (category varchar_pattern_ops); - - --- --- Name: audio_foreign_identifier_617f66ad; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX audio_foreign_identifier_617f66ad ON public.audio_view USING btree (foreign_identifier); - - --- --- Name: audio_foreign_identifier_617f66ad_like; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX audio_foreign_identifier_617f66ad_like ON public.audio_view USING btree (foreign_identifier varchar_pattern_ops); - - --- --- Name: audio_genres_e34cc474; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX audio_genres_e34cc474 ON public.audio_view USING btree (genres); - - --- --- Name: audio_last_synced_with_source_94c4a383; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX audio_last_synced_with_source_94c4a383 ON public.audio_view USING btree (last_synced_with_source); - - --- --- Name: audio_provider_8fe1eb54; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX audio_provider_8fe1eb54 ON public.audio_view USING btree (provider); - - --- --- Name: audio_provider_8fe1eb54_like; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX audio_provider_8fe1eb54_like ON public.audio_view USING btree (provider varchar_pattern_ops); - - --- --- Name: audio_source_e9ccc813; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX audio_source_e9ccc813 ON public.audio_view USING btree (source); - - --- --- PostgreSQL database dump complete --- diff --git a/ingestion_server/test/mock_schemas/image_view.sql b/ingestion_server/test/mock_schemas/image_view.sql deleted file mode 100644 index 6d94493b81a..00000000000 --- a/ingestion_server/test/mock_schemas/image_view.sql +++ /dev/null @@ -1,178 +0,0 @@ --- --- PostgreSQL database dump --- - --- Dumped from database version 13.2 --- Dumped by pg_dump version 13.3 (Debian 13.3-1.pgdg100+1) - -SET statement_timeout = 0; -SET lock_timeout = 0; -SET idle_in_transaction_session_timeout = 0; -SET client_encoding = 'UTF8'; -SET standard_conforming_strings = on; -SET check_function_bodies = false; -SET xmloption = content; -SET client_min_messages = warning; -SET row_security = off; - -SET default_tablespace = ''; - -SET default_table_access_method = heap; - --- --- Name: image_view; Type: TABLE; Schema: public; Owner: deploy --- - -CREATE TABLE public.image_view ( - id integer NOT NULL, - created_on timestamp with time zone NOT NULL, - updated_on timestamp with time zone NOT NULL, - identifier uuid NOT NULL, - provider character varying(80), - source character varying(80), - foreign_identifier character varying(1000), - foreign_landing_url character varying(1000), - url character varying(1000), - thumbnail character varying(1000), - width integer, - height integer, - filesize integer, - license character varying(50) NOT NULL, - 
license_version character varying(25), - creator character varying(2000), - creator_url character varying(2000), - title character varying(2000), - last_synced_with_source timestamp with time zone, - removed_from_source boolean NOT NULL, - meta_data jsonb, - view_count integer DEFAULT 0, - tags jsonb, - watermarked boolean, - filetype character varying(80), - standardized_popularity double precision, - ingestion_type character varying(1000) -); - - -ALTER TABLE public.image_view OWNER TO deploy; - --- --- Name: image_id_seq; Type: SEQUENCE; Schema: public; Owner: deploy --- - -CREATE SEQUENCE public.image_id_seq - AS integer - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; - - -ALTER TABLE public.image_id_seq OWNER TO deploy; - --- --- Name: image_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: deploy --- - -ALTER SEQUENCE public.image_id_seq OWNED BY public.image_view.id; - - --- --- Name: image_view id; Type: DEFAULT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.image_view ALTER COLUMN id SET DEFAULT nextval('public.image_id_seq'::regclass); - - --- --- Name: image_view image_identifier_key; Type: CONSTRAINT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.image_view - ADD CONSTRAINT image_identifier_key UNIQUE (identifier); - - --- --- Name: image_view image_pkey; Type: CONSTRAINT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.image_view - ADD CONSTRAINT image_pkey PRIMARY KEY (id); - - --- --- Name: image_view image_url_key; Type: CONSTRAINT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.image_view - ADD CONSTRAINT image_url_key UNIQUE (url); - - --- --- Name: image_view unique_provider_image; Type: CONSTRAINT; Schema: public; Owner: deploy --- - -ALTER TABLE ONLY public.image_view - ADD CONSTRAINT unique_provider_image UNIQUE (foreign_identifier, provider); - - --- --- Name: image_foreign_identifier_4c72d3ee; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX image_foreign_identifier_4c72d3ee ON public.image_view USING btree (foreign_identifier); - - --- --- Name: image_foreign_identifier_4c72d3ee_like; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX image_foreign_identifier_4c72d3ee_like ON public.image_view USING btree (foreign_identifier varchar_pattern_ops); - - --- --- Name: image_last_synced_with_source_187adf09; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX image_last_synced_with_source_187adf09 ON public.image_view USING btree (last_synced_with_source); - - --- --- Name: image_provider_7d11f847; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX image_provider_7d11f847 ON public.image_view USING btree (provider); - - --- --- Name: image_provider_7d11f847_like; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX image_provider_7d11f847_like ON public.image_view USING btree (provider varchar_pattern_ops); - - --- --- Name: image_source_d5a89e97; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX image_source_d5a89e97 ON public.image_view USING btree (source); - - --- --- Name: image_source_d5a89e97_like; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX image_source_d5a89e97_like ON public.image_view USING btree (source varchar_pattern_ops); - - --- --- Name: image_url_c6aabda2_like; Type: INDEX; Schema: public; Owner: deploy --- - -CREATE INDEX image_url_c6aabda2_like ON public.image_view USING btree (url varchar_pattern_ops); - - --- --- PostgreSQL database dump complete --- diff --git 
a/load_sample_data.sh b/load_sample_data.sh index bc55042aedc..d0576429cfe 100755 --- a/load_sample_data.sh +++ b/load_sample_data.sh @@ -16,7 +16,6 @@ function load_sample_data { \copy $1 \ from './sample_data/sample_$1.csv' \ with (FORMAT csv, HEADER true); - REFRESH MATERIALIZED VIEW $1_view; EOF" }