Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix various issues that existed in the rotate_db_snapshots DAG #2158

Merged
merged 6 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions catalog/DAGs.md
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,8 @@ than it is configured to retain.

Requires two variables:

`AIRFLOW_RDS_ARN`: The ARN of the RDS DB instance that needs snapshots.
`AIRFLOW_RDS_SNAPSHOTS_TO_RETAIN`: How many historical snapshots to retain.
`CATALOG_RDS_DB_IDENTIFIER`: The "DBIdentifier" of the RDS DB instance.
`CATALOG_RDS_SNAPSHOTS_TO_RETAIN`: How many historical snapshots to retain.

## `science_museum_workflow`

Expand Down
1 change: 1 addition & 0 deletions catalog/dags/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@
POSTGRES_CONN_ID = os.getenv("OPENLEDGER_CONN_ID", "postgres_openledger_testing")
OPENLEDGER_API_CONN_ID = os.getenv("OPENLEDGER_API_CONN_ID", "postgres_openledger_api")
AWS_CONN_ID = os.getenv("AWS_CONN_ID", "aws_conn_id")
AWS_RDS_CONN_ID = os.environ.get("AWS_RDS_CONN_ID", AWS_CONN_ID)
6 changes: 0 additions & 6 deletions catalog/dags/database/staging_database_restore/constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
import os

from common.constants import AWS_CONN_ID


_ID_FORMAT = "{}-openverse-db"

DAG_ID = "staging_database_restore"
Expand All @@ -14,6 +9,5 @@
SAFE_TO_MUTATE = {STAGING_IDENTIFIER, TEMP_IDENTIFIER, OLD_IDENTIFIER}

SKIP_VARIABLE = "SKIP_STAGING_DATABASE_RESTORE"
AWS_RDS_CONN_ID = os.environ.get("AWS_RDS_CONN_ID", AWS_CONN_ID)
SLACK_USERNAME = "Staging Database Restore"
SLACK_ICON = ":database-pink:"
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from airflow.utils.trigger_rule import TriggerRule

from common import slack
from common.constants import AWS_RDS_CONN_ID
from database.staging_database_restore import constants
from database.staging_database_restore.utils import (
ensure_mutate_allowed,
Expand Down Expand Up @@ -170,7 +171,7 @@ def make_rds_sensor(task_id: str, db_identifier: str, retries: int = 0) -> RdsDb
task_id=task_id,
db_identifier=db_identifier,
target_statuses=["available"],
aws_conn_id=constants.AWS_RDS_CONN_ID,
aws_conn_id=AWS_RDS_CONN_ID,
mode="reschedule",
timeout=60 * 60, # 1 hour
retries=retries,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from airflow.providers.amazon.aws.sensors.rds import RdsSnapshotExistenceSensor
from airflow.utils.trigger_rule import TriggerRule

from common.constants import DAG_DEFAULT_ARGS
from common.constants import AWS_RDS_CONN_ID, DAG_DEFAULT_ARGS
from database.staging_database_restore import constants
from database.staging_database_restore.staging_database_restore import (
get_latest_prod_snapshot,
Expand Down Expand Up @@ -66,7 +66,7 @@ def restore_staging_database():
task_id="ensure_snapshot_ready",
db_type="instance",
db_snapshot_identifier=latest_snapshot,
aws_conn_id=constants.AWS_RDS_CONN_ID,
aws_conn_id=AWS_RDS_CONN_ID,
mode="reschedule",
timeout=60 * 60 * 4, # 4 hours
)
Expand Down Expand Up @@ -118,7 +118,7 @@ def restore_staging_database():
task_id="delete_old",
db_instance_identifier=constants.OLD_IDENTIFIER,
rds_kwargs={"SkipFinalSnapshot": True, "DeleteAutomatedBackups": False},
aws_conn_id=constants.AWS_RDS_CONN_ID,
aws_conn_id=AWS_RDS_CONN_ID,
wait_for_completion=True,
)

Expand Down
5 changes: 2 additions & 3 deletions catalog/dags/database/staging_database_restore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from airflow.providers.amazon.aws.hooks.rds import RdsHook

from common.constants import AWS_RDS_CONN_ID
from database.staging_database_restore import constants


Expand All @@ -14,9 +15,7 @@ def setup_rds_hook(func: callable) -> callable:

@functools.wraps(func)
def wrapped(*args, **kwargs):
rds_hook = kwargs.pop("rds_hook", None) or RdsHook(
aws_conn_id=constants.AWS_RDS_CONN_ID
)
rds_hook = kwargs.pop("rds_hook", None) or RdsHook(aws_conn_id=AWS_RDS_CONN_ID)
return func(*args, **kwargs, rds_hook=rds_hook)

return wrapped
Expand Down
56 changes: 38 additions & 18 deletions catalog/dags/maintenance/rotate_db_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,54 @@

Requires two variables:

`AIRFLOW_RDS_ARN`: The ARN of the RDS DB instance that needs snapshots.
`AIRFLOW_RDS_SNAPSHOTS_TO_RETAIN`: How many historical snapshots to retain.
`CATALOG_RDS_DB_IDENTIFIER`: The "DBIdentifier" of the RDS DB instance.
`CATALOG_RDS_SNAPSHOTS_TO_RETAIN`: How many historical snapshots to retain.
"""

import logging
from datetime import datetime

import boto3
from airflow.decorators import dag, task
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.operators.rds import RdsCreateDbSnapshotOperator
from airflow.providers.amazon.aws.sensors.rds import RdsSnapshotExistenceSensor

from common.constants import AWS_RDS_CONN_ID


logger = logging.getLogger(__name__)

DAG_ID = "rotate_db_snapshots"
MAX_ACTIVE = 1


AIRFLOW_MANAGED_SNAPSHOT_ID_PREFIX = "airflow-managed"


@task()
def delete_previous_snapshots(rds_arn: str, snapshots_to_retain: int, rds_region: str):
rds = boto3.client("rds", region_name=rds_region)
def delete_previous_snapshots(db_identifier: str, snapshots_to_retain: int):
hook = RdsHook(aws_conn_id=AWS_RDS_CONN_ID)

# Snapshot object documentation:
# https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DBSnapshot.html
snapshots = rds.describe_db_snapshots(
DBInstanceIdentifier=rds_arn,
snapshots = hook.conn.describe_db_snapshots(
DBInstanceIdentifier=db_identifier,
SnapshotType="manual", # Automated backups cannot be manually managed
)["DBSnapshots"]

# Other manual snapshots may exist; we only want to automatically manage
# ones this DAG created
snapshots = [
snapshot
for snapshot in snapshots
if snapshot["DBSnapshotIdentifier"].startswith(
AIRFLOW_MANAGED_SNAPSHOT_ID_PREFIX
)
]
Comment on lines +49 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice detail!


snapshots.sort(
key=lambda x: datetime.fromisoformat(x["SnapshotCreateTime"]), reverse=True
key=lambda x: x["SnapshotCreateTime"], # boto3 casts this to datetime
reverse=True,
)

if len(snapshots) <= snapshots_to_retain or not (
Expand All @@ -52,7 +70,9 @@ def delete_previous_snapshots(rds_arn: str, snapshots_to_retain: int, rds_region
logger.info(f"Deleting {len(snapshots_to_delete)} snapshots.")
for snapshot in snapshots_to_delete:
logger.info(f"Deleting {snapshot['DBSnapshotIdentifier']}.")
rds.delete_db_snapshot(DBSnapshotIdentifier=snapshot["DBSnapshotIdentifier"])
hook.conn.delete_db_snapshot(
DBSnapshotIdentifier=snapshot["DBSnapshotIdentifier"]
)


@dag(
Expand All @@ -69,16 +89,16 @@ def delete_previous_snapshots(rds_arn: str, snapshots_to_retain: int, rds_region
render_template_as_native_obj=True,
)
def rotate_db_snapshots():
snapshot_id = "airflow-{{ ds }}"
db_identifier_template = "{{ var.value.AIRFLOW_RDS_ARN }}"
hook_params = {"region_name": "{{ var.value.AIRFLOW_RDS_REGION }}"}
snapshot_id = f"{AIRFLOW_MANAGED_SNAPSHOT_ID_PREFIX}-{{{{ ts_nodash }}}}"
db_identifier = "{{ var.value.CATALOG_RDS_DB_IDENTIFIER }}"

create_db_snapshot = RdsCreateDbSnapshotOperator(
task_id="create_snapshot",
db_type="instance",
db_identifier=db_identifier_template,
db_identifier=db_identifier,
db_snapshot_identifier=snapshot_id,
hook_params=hook_params,
aws_conn_id=AWS_RDS_CONN_ID,
wait_for_completion=False,
)

wait_for_snapshot_availability = RdsSnapshotExistenceSensor(
Expand All @@ -87,16 +107,16 @@ def rotate_db_snapshots():
db_snapshot_identifier=snapshot_id,
# This is the default for ``target_statuses`` but making it explicit is clearer
target_statuses=["available"],
hook_params=hook_params,
aws_conn_id=AWS_RDS_CONN_ID,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for missing this in previous reviews - this should have the mode=reschedule parameter set as well so it doesn't take up a worker slot while waiting for the snapshot. (in general this reduces overhead for Airflow). Unfortunately Airflow doesn't have a way to set this as the default globally.

Copy link
Collaborator Author

@sarayourfriend sarayourfriend May 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really wish there was an easy thing like the eslint "no-restricted-syntax" rule for Python to encode these sorts of things. It's so tedious to try to remember all the dozens of little rules and best practices and for reviewers to have to try to catch them. Not to mention when the reviewer isn't someone who is so familiar with Airflow best practices!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking the same as I was typing this out 😞 Way way easier to remember that way. We might be able to implement custom ruff linting checks perhaps, I'll try to look into it a bit!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently not yet: astral-sh/ruff#283

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would allow us to do this: https://github.com/hchasestevens/bellybutton

I remember now that I opened an issue for it: #1317

mode="reschedule",
)

(
create_db_snapshot
>> wait_for_snapshot_availability
>> delete_previous_snapshots(
rds_arn=db_identifier_template,
snapshots_to_retain="{{ var.json.AIRFLOW_RDS_SNAPSHOTS_TO_RETAIN }}",
rds_region=hook_params["region_name"],
db_identifier=db_identifier,
snapshots_to_retain="{{ var.json.CATALOG_RDS_SNAPSHOTS_TO_RETAIN }}",
)
)

Expand Down
78 changes: 47 additions & 31 deletions catalog/tests/dags/maintenance/test_rotate_db_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,28 @@
from datetime import datetime, timedelta
from unittest import mock

import boto3
import pytest

from maintenance.rotate_db_snapshots import delete_previous_snapshots
from common.constants import AWS_RDS_CONN_ID
from maintenance.rotate_db_snapshots import (
AIRFLOW_MANAGED_SNAPSHOT_ID_PREFIX,
delete_previous_snapshots,
)


@pytest.fixture
def boto_client(monkeypatch):
get_client = mock.MagicMock()
def rds_hook(monkeypatch):
RdsHook = mock.MagicMock()

monkeypatch.setattr(boto3, "client", get_client)
return get_client
monkeypatch.setattr("maintenance.rotate_db_snapshots.RdsHook", RdsHook)
return RdsHook


@pytest.fixture
def rds_client(boto_client):
rds = mock.MagicMock()

boto_client.return_value = rds
return rds
def hook(rds_hook):
hook_instance = mock.MagicMock()
rds_hook.return_value = hook_instance
return hook_instance


def _make_snapshots(count: int, shuffle=False) -> dict:
Expand All @@ -31,8 +33,8 @@ def _make_snapshots(count: int, shuffle=False) -> dict:
date = date - timedelta(days=1)
snaps.append(
{
"DBSnapshotIdentifier": _id,
"SnapshotCreateTime": date.isoformat(),
"DBSnapshotIdentifier": f"{AIRFLOW_MANAGED_SNAPSHOT_ID_PREFIX}-{_id}",
"SnapshotCreateTime": date, # boto3 returns datetime objects
}
)
return {"DBSnapshots": snaps}
Expand All @@ -41,7 +43,7 @@ def _make_snapshots(count: int, shuffle=False) -> dict:
@pytest.mark.parametrize(
("snapshots", "snapshots_to_retain"),
(
# Less than 7
# Less than the number we want to keep
(_make_snapshots(1), 2),
(_make_snapshots(1), 5),
# Exactly the number we want to keep
Expand All @@ -50,29 +52,44 @@ def _make_snapshots(count: int, shuffle=False) -> dict:
),
)
def test_delete_previous_snapshots_no_snapshots_to_delete(
snapshots, snapshots_to_retain, rds_client
snapshots, snapshots_to_retain, hook
):
rds_client.describe_db_snapshots.return_value = snapshots
delete_previous_snapshots.function("fake_arn", snapshots_to_retain, "fake_region")
rds_client.delete_db_snapshot.assert_not_called()
hook.conn.describe_db_snapshots.return_value = snapshots
delete_previous_snapshots.function("fake_db_identifier", snapshots_to_retain)
hook.conn.delete_db_snapshot.assert_not_called()


def test_delete_previous_snapshots(rds_client):
def test_delete_previous_snapshots(hook):
snapshots_to_retain = 6
snapshots = _make_snapshots(10)
snapshots_to_delete = snapshots["DBSnapshots"][snapshots_to_retain:]
rds_client.describe_db_snapshots.return_value = snapshots
hook.conn.describe_db_snapshots.return_value = snapshots

delete_previous_snapshots.function("fake_arn", snapshots_to_retain, "fake_region")
rds_client.delete_db_snapshot.assert_has_calls(
delete_previous_snapshots.function("fake_db_identifier", snapshots_to_retain)
hook.conn.delete_db_snapshot.assert_has_calls(
[
mock.call(DBSnapshotIdentifier=snapshot["DBSnapshotIdentifier"])
for snapshot in snapshots_to_delete
]
)


def test_sorts_snapshots(rds_client):
def test_delete_previous_snapshots_ignores_non_airflow_managed_ones(hook):
    """Snapshots without the Airflow-managed prefix must never be deleted.

    Setup: 4 managed snapshots, retention count of 2, and the oldest
    snapshot renamed to an unmanaged identifier. That leaves 3 managed
    snapshots, of which the 2 newest are retained, so exactly one managed
    snapshot (the second-oldest overall) should be deleted.
    """
    snapshots_to_retain = 2
    snapshots = _make_snapshots(4)
    # _make_snapshots returns newest-first, so [-1] is the oldest; renaming
    # it makes it invisible to the rotation logic.
    snapshots["DBSnapshots"][-1]["DBSnapshotIdentifier"] = "not-managed-by-airflow-123"
    snapshot_to_delete = snapshots["DBSnapshots"][-2]

    hook.conn.describe_db_snapshots.return_value = snapshots

    delete_previous_snapshots.function("fake_db_identifier", snapshots_to_retain)
    # assert_called_once_with is stronger than assert_has_calls here: it also
    # proves the unmanaged snapshot was NOT deleted (exactly one delete call).
    hook.conn.delete_db_snapshot.assert_called_once_with(
        DBSnapshotIdentifier=snapshot_to_delete["DBSnapshotIdentifier"]
    )


def test_sorts_snapshots(hook):
"""
As far as we can tell the API does return them pre-sorted but it isn't documented
so just to be sure we'll sort them anyway.
Expand All @@ -84,19 +101,18 @@ def test_sorts_snapshots(rds_client):
# shuffle the snapshots to mimic an unstable API return order
random.shuffle(snapshots["DBSnapshots"])

rds_client.describe_db_snapshots.return_value = snapshots
delete_previous_snapshots.function("fake_arn", snapshots_to_retain, "fake_region")
rds_client.delete_db_snapshot.assert_has_calls(
hook.conn.describe_db_snapshots.return_value = snapshots
delete_previous_snapshots.function("fake_db_identifier", snapshots_to_retain)
hook.conn.delete_db_snapshot.assert_has_calls(
[
mock.call(DBSnapshotIdentifier=snapshot["DBSnapshotIdentifier"])
for snapshot in snapshots_to_delete
]
)


def test_instantiates_rds_client_with_region(boto_client, rds_client):
rds_client.describe_db_snapshots.return_value = _make_snapshots(0)
def test_instantiates_rds_hook_with_rds_connection_id(rds_hook, hook):
hook.conn.describe_db_snapshots.return_value = _make_snapshots(0)

region = "fake_region"
delete_previous_snapshots.function("fake_arn", 0, region)
boto_client.assert_called_once_with("rds", region_name=region)
delete_previous_snapshots.function("fake_db_identifier", 0)
rds_hook.assert_called_once_with(aws_conn_id=AWS_RDS_CONN_ID)