Add DAG for filtering archived providers in catalog (#3259)
* Add deleted media tables

* Set up new columns

* Add dag for deleting records

* Add tests

* Update dag docs

* Remove unnecessary retries

* Pull RETURN_ROW_COUNT out into utility

* Clean up comments

* Update dag docs

* Simplify table creation
stacimc authored Nov 13, 2023
1 parent fcbc519 commit 6ede4ad
Showing 18 changed files with 617 additions and 32 deletions.
3 changes: 1 addition & 2 deletions catalog/dags/common/loader/sql.py
@@ -7,7 +7,7 @@
 from common.constants import IMAGE, MediaType, SQLInfo
 from common.loader import provider_details as prov
 from common.loader.paths import _extract_media_type
-from common.sql import PostgresHook
+from common.sql import RETURN_ROW_COUNT, PostgresHook
 from common.storage import columns as col
 from common.storage.columns import NULL, Column, UpsertStrategy
 from common.storage.db_columns import setup_db_columns_for_media_type
@@ -40,7 +40,6 @@
 }

 CURRENT_TSV_VERSION = "001"
-RETURN_ROW_COUNT = lambda c: c.rowcount  # noqa: E731


 def create_column_definitions(table_columns: list[Column], is_loading=True):
3 changes: 3 additions & 0 deletions catalog/dags/common/sql.py
@@ -27,6 +27,9 @@
 # https://airflow.apache.org/docs/apache-airflow-providers-postgres/stable/_api/airflow/providers/postgres/hooks/postgres/index.html#airflow.providers.postgres.hooks.postgres.PostgresHook.copy_expert # noqa


+RETURN_ROW_COUNT = lambda c: c.rowcount  # noqa: E731
+
+
 def single_value(cursor):
     try:
         row = cursor.fetchone()
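For reference, a handler here is just a callable that `PostgresHook.run` applies to the cursor after a statement executes, so hoisting `RETURN_ROW_COUNT` into `common.sql` makes it shareable across DAGs. A minimal sketch of the pattern (the `FakeCursor` and `run` stand-ins below are illustrative only, not part of the codebase):

```python
# Sketch of the cursor-handler pattern, assuming a DB-API-style cursor.
RETURN_ROW_COUNT = lambda c: c.rowcount  # noqa: E731


class FakeCursor:
    """Stand-in exposing the DB-API `rowcount` attribute set after execute()."""

    rowcount = 42


def run(sql: str, handler=None):
    cursor = FakeCursor()  # a real hook would execute `sql` against Postgres here
    return handler(cursor) if handler is not None else None


print(run("DELETE FROM image WHERE provider = 'foo'", handler=RETURN_ROW_COUNT))  # 42
```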
10 changes: 10 additions & 0 deletions catalog/dags/common/storage/columns.py
@@ -767,3 +767,13 @@ def prepare_string(self, value):
     size=1000,
     truncate=False,
 )
+
+# Columns used by the Deleted Media tables
+
+DELETED_ON = TimestampColumn(
+    name="deleted_on", required=True, upsert_strategy=UpsertStrategy.no_change
+)
+
+DELETED_REASON = StringColumn(
+    name="deleted_reason", required=True, size=80, truncate=True
+)
20 changes: 20 additions & 0 deletions catalog/dags/common/storage/db_columns.py
@@ -82,9 +82,29 @@
     col.AUDIO_SET_FOREIGN_IDENTIFIER,
 ]


 DB_COLUMNS_BY_MEDIA_TYPE = {AUDIO: AUDIO_TABLE_COLUMNS, IMAGE: IMAGE_TABLE_COLUMNS}


 def setup_db_columns_for_media_type(func: callable) -> callable:
     """Provide media-type-specific DB columns as a kwarg to the decorated function."""
     return setup_kwargs_for_media_type(DB_COLUMNS_BY_MEDIA_TYPE, "db_columns")(func)
+
+
+DELETED_IMAGE_TABLE_COLUMNS = IMAGE_TABLE_COLUMNS + [col.DELETED_ON, col.DELETED_REASON]
+DELETED_AUDIO_TABLE_COLUMNS = AUDIO_TABLE_COLUMNS + [col.DELETED_ON, col.DELETED_REASON]
+
+DELETED_MEDIA_DB_COLUMNS_BY_MEDIA_TYPE = {
+    AUDIO: DELETED_AUDIO_TABLE_COLUMNS,
+    IMAGE: DELETED_IMAGE_TABLE_COLUMNS,
+}
+
+
+def setup_deleted_db_columns_for_media_type(func: callable) -> callable:
+    """
+    Provide media-type-specific deleted media DB columns as a kwarg to the decorated
+    function.
+    """
+    return setup_kwargs_for_media_type(
+        DELETED_MEDIA_DB_COLUMNS_BY_MEDIA_TYPE, "deleted_db_columns"
+    )(func)
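Both decorators delegate to `setup_kwargs_for_media_type`, which is defined elsewhere in the catalog and not shown in this diff. As a rough, hypothetical reimplementation of the pattern (illustrative only; the real helper may differ in details), it looks up a value keyed on the `media_type` kwarg and injects it under the given kwarg name:

```python
import functools


# Hypothetical sketch of setup_kwargs_for_media_type, for illustration only.
def setup_kwargs_for_media_type(values_by_media_type: dict, kwarg_name: str):
    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, media_type: str, **kwargs):
            # Inject the media-type-specific value unless the caller supplied one.
            kwargs.setdefault(kwarg_name, values_by_media_type[media_type])
            return func(*args, media_type=media_type, **kwargs)

        return wrapped

    return decorator


@setup_kwargs_for_media_type({"image": ["identifier"], "audio": ["identifier"]}, "db_columns")
def show(*, media_type: str, db_columns=None):
    print(media_type, db_columns)


show(media_type="image")  # prints: image ['identifier']
```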
4 changes: 2 additions & 2 deletions catalog/dags/database/batched_update/batched_update.py
@@ -7,7 +7,7 @@

 from common import slack
 from common.constants import POSTGRES_CONN_ID
-from common.sql import PostgresHook, single_value
+from common.sql import RETURN_ROW_COUNT, PostgresHook, single_value
 from database.batched_update import constants


@@ -57,7 +57,7 @@ def run_sql(
     postgres_conn_id: str = POSTGRES_CONN_ID,
     task: AbstractOperator = None,
     timeout: timedelta = None,
-    handler: callable = constants.RETURN_ROW_COUNT,
+    handler: callable = RETURN_ROW_COUNT,
     **kwargs,
 ):
     query = sql_template.format(
1 change: 0 additions & 1 deletion catalog/dags/database/batched_update/constants.py
@@ -40,4 +40,3 @@
 );
 """
 DROP_TABLE_QUERY = "DROP TABLE IF EXISTS {temp_table_name} CASCADE;"
-RETURN_ROW_COUNT = lambda c: c.rowcount  # noqa: E731
21 changes: 21 additions & 0 deletions catalog/dags/database/delete_records/constants.py
@@ -0,0 +1,21 @@
from datetime import datetime, timedelta


DAG_ID = "delete_records"
SLACK_USERNAME = "Upstream Delete Records"
SLACK_ICON = ":database:"
START_DATE = datetime(2023, 10, 25)
DAGRUN_TIMEOUT = timedelta(days=31 * 3)
CREATE_TIMEOUT = timedelta(hours=6)
DELETE_TIMEOUT = timedelta(hours=1)

CREATE_RECORDS_QUERY = """
INSERT INTO {destination_table} ({destination_cols})
SELECT {source_cols}
FROM {source_table}
{select_query}
"""
DELETE_RECORDS_QUERY = """
DELETE FROM {table}
{select_query}
"""
103 changes: 103 additions & 0 deletions catalog/dags/database/delete_records/delete_records.py
@@ -0,0 +1,103 @@
import logging
from datetime import timedelta

from airflow.decorators import task
from airflow.models.abstractoperator import AbstractOperator

from common import slack
from common.constants import POSTGRES_CONN_ID
from common.sql import RETURN_ROW_COUNT, PostgresHook
from common.storage.columns import DELETED_ON, Column
from common.storage.db_columns import (
setup_db_columns_for_media_type,
setup_deleted_db_columns_for_media_type,
)
from database.delete_records import constants


logger = logging.getLogger(__name__)


def run_sql(
    sql_template: str,
    postgres_conn_id: str = POSTGRES_CONN_ID,
    task: AbstractOperator = None,
    timeout: timedelta = None,
    handler: callable = RETURN_ROW_COUNT,
    **kwargs,
):
    query = sql_template.format(**kwargs)

    postgres = PostgresHook(
        postgres_conn_id=postgres_conn_id,
        default_statement_timeout=(
            timeout if timeout else PostgresHook.get_execution_timeout(task)
        ),
    )

    return postgres.run(query, handler=handler)


@task
@setup_deleted_db_columns_for_media_type
@setup_db_columns_for_media_type
def create_deleted_records(
    *,
    select_query: str,
    deleted_reason: str,
    media_type: str,
    db_columns: list[Column] = None,
    deleted_db_columns: list[Column] = None,
    task: AbstractOperator = None,
    postgres_conn_id: str = POSTGRES_CONN_ID,
):
    """
    Select records from the given media table using the select query, and then for each
    record create a corresponding record in the Deleted Media table.
    """

    destination_cols = ", ".join([col.db_name for col in deleted_db_columns])

    # To build the source columns, we first list all columns in the main media table
    source_cols = ", ".join([col.db_name for col in db_columns])
    # Then add the deleted-media specific columns.
    # `deleted_on` is set to its insert value to get the current timestamp:
    source_cols += f", {DELETED_ON.get_insert_value()}"
    # `deleted_reason` is set to the given string
    source_cols += f", '{deleted_reason}'"

    return run_sql(
        sql_template=constants.CREATE_RECORDS_QUERY,
        postgres_conn_id=postgres_conn_id,
        task=task,
        destination_table=f"deleted_{media_type}",
        destination_cols=destination_cols,
        source_table=media_type,
        source_cols=source_cols,
        select_query=select_query,
    )


@task
def delete_records_from_media_table(
    table: str, select_query: str, postgres_conn_id: str = POSTGRES_CONN_ID
):
    """Delete records matching the select_query from the given media table."""
    return run_sql(
        sql_template=constants.DELETE_RECORDS_QUERY,
        table=table,
        select_query=select_query,
    )


@task
def notify_slack(text: str) -> str:
    """Send a message to Slack."""
    slack.send_message(
        text,
        username=constants.SLACK_USERNAME,
        icon_emoji=constants.SLACK_ICON,
        dag_id=constants.DAG_ID,
    )

    return text
114 changes: 114 additions & 0 deletions catalog/dags/database/delete_records/delete_records_dag.py
@@ -0,0 +1,114 @@
"""
# Delete Records DAG
This DAG is used to delete records from the Catalog media tables, after creating a
corresponding record in the associated `deleted_<media_type>` table for each record
to be deleted. It is important to note that records deleted by this DAG will still be
available in the API until the next data refresh runs.
Required Dagrun Configuration parameters:
* table_name: the name of the table to delete from. Must be a valid media table
* select_query: a SQL `WHERE` clause used to select the rows that will be deleted
* reason: a string explaining the reason for deleting the records. Ex ('deadlink')
An example dag_run configuration used to delete all records for the "foo" image provider
due to deadlinks would look like this:
```
{
"table_name": "image",
"select_query": "WHERE provider='foo'",
"reason": "deadlink"
}
```
## Warnings
Presently, there is no logic to prevent records that have an entry in a Deleted Media
table from simply being reingested during provider ingestion. Therefore in its current
state, the DAG should _only_ be used to delete records that we can guarantee will not
be reingested (for example, because the provider is archived).
This DAG does not have automated handling for deadlocks, so you must be certain that
records selected for deletion in this DAG are not also being written to by a provider
DAG, for instance. The simplest way to do this is to ensure that any affected provider
DAGs are not currently running.
"""


import logging

from airflow.decorators import dag
from airflow.models.param import Param

from common.constants import AUDIO, DAG_DEFAULT_ARGS, MEDIA_TYPES
from database.delete_records import constants
from database.delete_records.delete_records import (
create_deleted_records,
delete_records_from_media_table,
notify_slack,
)


logger = logging.getLogger(__name__)


@dag(
    dag_id=constants.DAG_ID,
    schedule=None,
    start_date=constants.START_DATE,
    tags=["database"],
    dagrun_timeout=constants.DAGRUN_TIMEOUT,
    doc_md=__doc__,
    default_args={**DAG_DEFAULT_ARGS, "retries": 0},
    render_template_as_native_obj=True,
    params={
        "table_name": Param(
            default=AUDIO,
            enum=MEDIA_TYPES,
            description="The name of the media table from which to select records.",
        ),
        "select_query": Param(
            default="WHERE...",
            type="string",
            description=(
                "The `WHERE` clause of a query that selects all the rows to"
                " be deleted."
            ),
            pattern="^WHERE",
        ),
        "reason": Param(
            default="",
            type="string",
            description="Short descriptor of the reason for deleting the records.",
        ),
    },
)
def delete_records():
    # Create the records in the Deleted Media table
    insert_into_deleted_media_table = create_deleted_records.override(
        task_id="update_deleted_media_table", execution_timeout=constants.CREATE_TIMEOUT
    )(
        select_query="{{ params.select_query }}",
        deleted_reason="{{ params.reason }}",
        media_type="{{ params.table_name }}",
    )

    # If successful, delete the records from the media table
    delete_records = delete_records_from_media_table.override(
        execution_timeout=constants.DELETE_TIMEOUT
    )(table="{{ params.table_name }}", select_query="{{ params.select_query }}")

    notify_complete = notify_slack(
        text=(
            f"Deleted {delete_records} records from the"
            " {{ params.table_name }} table matching query: `{{ params.select_query }}`"
        ),
    )

    insert_into_deleted_media_table >> delete_records >> notify_complete


delete_records()
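As a usage sketch, the DAG could be exercised locally with `DAG.test()` (available since Airflow 2.5; this assumes a reachable local catalog database). The conf values mirror the documented example and are illustrative only:

```python
# Illustrative local run; assumes Airflow 2.5+ and a working Postgres connection.
from database.delete_records.delete_records_dag import delete_records

dag = delete_records()  # the @dag-decorated factory returns the DAG object
dag.test(
    run_conf={
        "table_name": "image",
        "select_query": "WHERE provider='foo'",
        "reason": "deadlink",
    }
)
```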
3 changes: 1 addition & 2 deletions catalog/dags/maintenance/add_license_url.py
@@ -25,9 +25,8 @@

 from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID, XCOM_PULL_TEMPLATE
 from common.licenses import get_license_info_from_license_pair
-from common.loader.sql import RETURN_ROW_COUNT
 from common.slack import send_message
-from common.sql import PostgresHook
+from common.sql import RETURN_ROW_COUNT, PostgresHook
 from providers.provider_dag_factory import AWS_CONN_ID, OPENVERSE_BUCKET

@@ -17,9 +17,8 @@
     OPENLEDGER_API_CONN_ID,
     XCOM_PULL_TEMPLATE,
 )
-from common.loader.sql import RETURN_ROW_COUNT
 from common.slack import send_message
-from common.sql import PostgresHook
+from common.sql import PostgresHook, RETURN_ROW_COUNT


 logger = logging.getLogger(__name__)
