-
Notifications
You must be signed in to change notification settings - Fork 213
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add first pass of a db snapshot rotation DAG * Add unit tests * Fix DAG documentation * Add db snapshots DAG to parsing test * Add missing attributes to DAG * Fix DAG_ID Co-authored-by: Madison Swain-Bowden <[email protected]> * Fix template variable * Remove redundant parameter * Update openverse_catalog/dags/maintenance/rotate_db_snapshots.py Co-authored-by: Madison Swain-Bowden <[email protected]> * Use Airflow template strings to get variables Co-authored-by: Madison Swain-Bowden <[email protected]> * Fix dag name * Sort describe snapshots return value (just to make sure) Also fixes the usage of `describe_db_snapshots` to retrieve the actual list of snapshots on the pagination object. * Lint generated DAG file Co-authored-by: Madison Swain-Bowden <[email protected]>
- Loading branch information
1 parent
f654446
commit 3821925
Showing
6 changed files
with
204 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Manages weekly database snapshots. RDS does not support weekly snapshot | ||
schedules on its own, so we need a DAG to manage this for us. | ||
It runs on Saturdays at 00:00 UTC in order to happen before the data refresh. | ||
The DAG will automatically delete the oldest snapshots when more snapshots | ||
exist than it is configured to retain. | ||
Requires two variables: | ||
`AIRFLOW_RDS_ARN`: The ARN of the RDS DB instance that needs snapshots. | ||
`AIRFLOW_RDS_SNAPSHOTS_TO_RETAIN`: How many historical snapshots to retain. | ||
""" | ||
|
||
import logging | ||
from datetime import datetime | ||
|
||
import boto3 | ||
from airflow.decorators import dag, task | ||
from airflow.providers.amazon.aws.operators.rds import RdsCreateDbSnapshotOperator | ||
from airflow.providers.amazon.aws.sensors.rds import RdsSnapshotExistenceSensor | ||
|
||
|
||
# Identifier passed to the ``@dag`` decorator as ``dag_id``.
DAG_ID = "rotate_db_snapshots"
# Shared limit used for both ``max_active_tasks`` and ``max_active_runs``.
MAX_ACTIVE = 1

# Module-level logger used by the tasks defined in this file.
logger = logging.getLogger(__name__)
|
||
|
||
def _snapshot_created_at(snapshot: dict) -> datetime:
    """
    Return the creation time of a snapshot as a ``datetime``.

    boto3 deserialises ``SnapshotCreateTime`` into a ``datetime`` object, which
    ``datetime.fromisoformat`` rejects with a ``TypeError``. ISO 8601 strings are
    still parsed so that serialised responses keep working.
    """
    created_at = snapshot["SnapshotCreateTime"]
    if isinstance(created_at, datetime):
        return created_at
    return datetime.fromisoformat(created_at)


@task()
def delete_previous_snapshots(rds_arn: str, snapshots_to_retain: int) -> None:
    """
    Delete the oldest snapshots of a DB instance beyond the retention limit.

    All snapshots of the instance are listed and ordered newest first; the
    first ``snapshots_to_retain`` are kept and every remaining one is deleted.

    :param rds_arn: identifier of the DB instance whose snapshots are rotated,
        passed to ``describe_db_snapshots`` as ``DBInstanceIdentifier``.
    :param snapshots_to_retain: number of most recent snapshots to keep.
    :raises ValueError: if ``snapshots_to_retain`` is negative.
    """
    # The value arrives through a rendered template; normalise it before slicing.
    snapshots_to_retain = int(snapshots_to_retain)
    if snapshots_to_retain < 0:
        # A negative value would slice from the end of the list and silently
        # delete an arbitrary number of the oldest snapshots.
        raise ValueError(
            f"snapshots_to_retain must not be negative, got {snapshots_to_retain}."
        )

    rds = boto3.client("rds")
    # Snapshot object documentation:
    # https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DBSnapshot.html
    snapshots = []
    request = {"DBInstanceIdentifier": rds_arn}
    while True:
        response = rds.describe_db_snapshots(**request)
        snapshots.extend(response["DBSnapshots"])
        # The API is paginated: a ``Marker`` in the response means more
        # snapshots are available and must be requested explicitly.
        marker = response.get("Marker")
        if not marker:
            break
        request["Marker"] = marker

    # The return order of the API is not documented, so sort newest first.
    snapshots.sort(key=_snapshot_created_at, reverse=True)

    snapshots_to_delete = snapshots[snapshots_to_retain:]
    if not snapshots_to_delete:
        logger.info("No snapshots to delete.")
        return

    logger.info("Deleting %s snapshots.", len(snapshots_to_delete))
    for snapshot in snapshots_to_delete:
        snapshot_id = snapshot["DBSnapshotIdentifier"]
        logger.info("Deleting %s.", snapshot_id)
        rds.delete_db_snapshot(DBSnapshotIdentifier=snapshot_id)
|
||
|
||
@dag(
    dag_id=DAG_ID,
    # Saturdays at 00:00, so that the snapshot is taken before the data refresh
    schedule="0 0 * * 6",
    start_date=datetime(2022, 12, 2),
    tags=["maintenance"],
    max_active_tasks=MAX_ACTIVE,
    max_active_runs=MAX_ACTIVE,
    catchup=False,
    # The module docstring doubles as the markdown docs shown in the UI
    doc_md=__doc__,
    render_template_as_native_obj=True,
)
def rotate_db_snapshots():
    # Template strings; the values are only filled in when the tasks are rendered.
    # NOTE(review): ``var.json`` presumably JSON-decodes the variable; confirm that
    # AIRFLOW_RDS_ARN is stored in a form that survives JSON decoding.
    rds_arn_template = "{{ var.json.AIRFLOW_RDS_ARN }}"
    retention_template = "{{ var.json.AIRFLOW_RDS_SNAPSHOTS_TO_RETAIN }}"
    snapshot_name = "airflow-{{ ds }}"

    # Step 1: request a new snapshot of the DB instance.
    create_snapshot = RdsCreateDbSnapshotOperator(
        task_id="create_snapshot",
        db_type="instance",
        db_identifier=rds_arn_template,
        db_snapshot_identifier=snapshot_name,
    )

    # Step 2: wait until that snapshot reports the "available" status.
    await_snapshot = RdsSnapshotExistenceSensor(
        task_id="await_snapshot_availability",
        db_type="instance",
        db_snapshot_identifier=snapshot_name,
        # Already the default for ``target_statuses``; spelled out for clarity
        target_statuses=["available"],
    )

    # Step 3: drop the snapshots that exceed the retention limit.
    prune_snapshots = delete_previous_snapshots(
        rds_arn=rds_arn_template,
        snapshots_to_retain=retention_template,
    )

    create_snapshot >> await_snapshot >> prune_snapshots


rotate_db_snapshots()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import random | ||
from datetime import datetime, timedelta | ||
from unittest import mock | ||
|
||
import boto3 | ||
import pytest | ||
from maintenance.rotate_db_snapshots import delete_previous_snapshots | ||
|
||
|
||
@pytest.fixture
def rds_client(monkeypatch):
    """
    Replace ``boto3.client`` with a factory that always hands back one shared
    ``MagicMock`` and return that mock so tests can configure and inspect it.
    """
    fake_rds = mock.MagicMock()
    monkeypatch.setattr(boto3, "client", lambda *args, **kwargs: fake_rds)
    return fake_rds
|
||
|
||
def _make_snapshots(count: int, shuffle=False) -> dict: | ||
date = datetime.now() | ||
snaps = [] | ||
for _id in range(count): | ||
date = date - timedelta(days=1) | ||
snaps.append( | ||
{ | ||
"DBSnapshotIdentifier": _id, | ||
"SnapshotCreateTime": date.isoformat(), | ||
} | ||
) | ||
return {"DBSnapshots": snaps} | ||
|
||
|
||
@pytest.mark.parametrize(
    ("snapshots", "snapshots_to_retain"),
    [
        # Fewer snapshots than the number we want to keep
        (_make_snapshots(1), 2),
        (_make_snapshots(1), 5),
        # Exactly as many snapshots as we want to keep
        (_make_snapshots(7), 7),
        (_make_snapshots(2), 2),
    ],
)
def test_delete_previous_snapshots_no_snapshots_to_delete(
    snapshots, snapshots_to_retain, rds_client
):
    """
    Nothing is deleted while the snapshot count does not exceed the limit.
    """
    rds_client.describe_db_snapshots.return_value = snapshots

    delete_previous_snapshots.function("fake_arn", snapshots_to_retain)

    rds_client.delete_db_snapshot.assert_not_called()
|
||
|
||
def test_delete_previous_snapshots(rds_client):
    """
    Exactly the snapshots beyond the retention limit are deleted.
    """
    snapshots_to_retain = 6
    snapshots = _make_snapshots(10)
    snapshots_to_delete = snapshots["DBSnapshots"][snapshots_to_retain:]
    rds_client.describe_db_snapshots.return_value = snapshots

    delete_previous_snapshots.function("fake_arn", snapshots_to_retain)

    expected_calls = [
        mock.call(DBSnapshotIdentifier=snapshot["DBSnapshotIdentifier"])
        for snapshot in snapshots_to_delete
    ]
    # ``assert_has_calls`` tolerates additional calls, so it would not notice a
    # retained snapshot being deleted as well; compare the full call list.
    assert rds_client.delete_db_snapshot.call_args_list == expected_calls
|
||
|
||
def test_sorts_snapshots(rds_client):
    """
    As far as we can tell the API does return them pre-sorted but it isn't documented
    so just to be sure we'll sort them anyway.
    """
    snapshots_to_retain = 6
    # _make_snapshots returns them ordered by date reversed
    snapshots = _make_snapshots(10)
    # Slicing copies the list, so the expectation keeps the sorted order even
    # after the response itself is shuffled below.
    snapshots_to_delete = snapshots["DBSnapshots"][snapshots_to_retain:]
    # shuffle the snapshots to mimic an unstable API return order
    random.shuffle(snapshots["DBSnapshots"])

    rds_client.describe_db_snapshots.return_value = snapshots
    delete_previous_snapshots.function("fake_arn", snapshots_to_retain)

    expected_calls = [
        mock.call(DBSnapshotIdentifier=snapshot["DBSnapshotIdentifier"])
        for snapshot in snapshots_to_delete
    ]
    # ``assert_has_calls`` tolerates additional calls, so it would not notice a
    # retained snapshot being deleted as well; compare the full call list.
    assert rds_client.delete_db_snapshot.call_args_list == expected_calls
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters