Skip to content

Commit

Permalink
Workaround lack of hook_params templating
Browse files Browse the repository at this point in the history
  • Loading branch information
sarayourfriend committed May 22, 2023
1 parent 8bd7c45 commit c07749c
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions catalog/dags/maintenance/rotate_db_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
The DAG will automatically delete the oldest snapshots when more snaphots
exist than it is configured to retain.
Requires two variables:
Requires three variables:
`AIRFLOW_RDS_ARN`: The ARN of the RDS DB instance that needs snapshots.
`AIRFLOW_RDS_SNAPSHOTS_TO_RETAIN`: How many historical snapshots to retain.
`AIRFLOW_RDS_REGION`: The region of the RDS DB instance.
"""

import logging
from datetime import datetime

import boto3
from airflow.decorators import dag, task
from airflow.models import Variable
from airflow.providers.amazon.aws.operators.rds import RdsCreateDbSnapshotOperator
from airflow.providers.amazon.aws.sensors.rds import RdsSnapshotExistenceSensor

Expand All @@ -28,6 +30,9 @@

DAG_ID = "rotate_db_snapshots"
MAX_ACTIVE = 1
# This cannot be pulled in the DAG itself because ``hook_params``
# on the operators provided by the amazon provider is not templated
RDS_REGION = Variable.get("AIRFLOW_RDS_REGION")


@task()
Expand Down Expand Up @@ -70,13 +75,13 @@ def delete_previous_snapshots(rds_arn: str, snapshots_to_retain: int, rds_region
)
def rotate_db_snapshots():
snapshot_id = "airflow-{{ ds }}"
db_identifier_template = "{{ var.value.AIRFLOW_RDS_ARN }}"
hook_params = {"region_name": "{{ var.value.AIRFLOW_RDS_REGION }}"}
db_identifier = "{{ var.value.AIRFLOW_RDS_ARN }}"
hook_params = {"region_name": RDS_REGION}

create_db_snapshot = RdsCreateDbSnapshotOperator(
task_id="create_snapshot",
db_type="instance",
db_identifier=db_identifier_template,
db_identifier=db_identifier,
db_snapshot_identifier=snapshot_id,
hook_params=hook_params,
)
Expand All @@ -94,9 +99,9 @@ def rotate_db_snapshots():
create_db_snapshot
>> wait_for_snapshot_availability
>> delete_previous_snapshots(
rds_arn=db_identifier_template,
rds_arn=db_identifier,
snapshots_to_retain="{{ var.json.AIRFLOW_RDS_SNAPSHOTS_TO_RETAIN }}",
rds_region=hook_params["region_name"],
rds_region=RDS_REGION,
)
)

Expand Down

0 comments on commit c07749c

Please sign in to comment.