Skip to content

Commit

Permalink
custom waiters with dynamic values, applied to appflow (#29911)
Browse files Browse the repository at this point in the history
  • Loading branch information
vandonr-amz authored Mar 21, 2023
1 parent 1ab105a commit 05c0841
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 54 deletions.
28 changes: 11 additions & 17 deletions airflow/providers/amazon/aws/hooks/appflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# under the License.
from __future__ import annotations

import json
from time import sleep
from typing import TYPE_CHECKING

from airflow.compat.functools import cached_property
Expand Down Expand Up @@ -64,24 +62,20 @@ def run_flow(self, flow_name: str, poll_interval: int = 20, wait_for_completion:
self.log.info("executionId: %s", execution_id)

if wait_for_completion:
last_execs: dict = {}
self.log.info("Waiting for flow run to complete...")
while (
execution_id not in last_execs or last_execs[execution_id]["executionStatus"] == "InProgress"
):
sleep(poll_interval)
# queries the last 20 runs, which should contain ours.
response_desc = self.conn.describe_flow_execution_records(flowName=flow_name)
last_execs = {fe["executionId"]: fe for fe in response_desc["flowExecutions"]}

exec_details = last_execs[execution_id]
self.log.info("Run complete, execution details: %s", exec_details)

if exec_details["executionStatus"] == "Error":
raise Exception(f"Flow error:\n{json.dumps(exec_details, default=str)}")
self.get_waiter("run_complete", {"EXECUTION_ID": execution_id}).wait(
flowName=flow_name,
WaiterConfig={"Delay": poll_interval},
)
self._log_execution_description(flow_name, execution_id)

return execution_id

def _log_execution_description(self, flow_name: str, execution_id: str):
response_desc = self.conn.describe_flow_execution_records(flowName=flow_name)
last_execs = {fe["executionId"]: fe for fe in response_desc["flowExecutions"]}
exec_details = last_execs[execution_id]
self.log.info("Run complete, execution details: %s", exec_details)

def update_flow_filter(
self, flow_name: str, filter_tasks: list[TaskTypeDef], set_trigger_ondemand: bool = False
) -> None:
Expand Down
28 changes: 25 additions & 3 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import boto3
import botocore
import botocore.session
import jinja2
import requests
import tenacity
from botocore.client import ClientMeta
Expand Down Expand Up @@ -796,25 +797,46 @@ def waiter_path(self) -> PathLike[str] | None:
path = Path(__file__).parents[1].joinpath(f"waiters/{self.client_type}.json").resolve()
return path if path.exists() else None

def get_waiter(self, waiter_name: str) -> Waiter:
def get_waiter(self, waiter_name: str, parameters: dict[str, str] | None = None) -> Waiter:
"""
First checks if there is a custom waiter with the provided waiter_name and
uses that if it exists, otherwise it will check the service client for a
waiter that matches the name and pass that through.
:param waiter_name: The name of the waiter. The name should exactly match the
name of the key in the waiter model file (typically this is CamelCase).
:param parameters: will scan the waiter config for the keys of that dict, and replace them with the
corresponding value. If a custom waiter has such keys to be expanded, they need to be provided
here.
"""
if self.waiter_path and (waiter_name in self._list_custom_waiters()):
# Technically if waiter_name is in custom_waiters then self.waiter_path must
# exist but MyPy doesn't like the fact that self.waiter_path could be None.
with open(self.waiter_path) as config_file:
config = json.load(config_file)
return BaseBotoWaiter(client=self.conn, model_config=config).waiter(waiter_name)
config = json.loads(config_file.read())

config = self._apply_parameters_value(config, waiter_name, parameters)
return BaseBotoWaiter(client=self.conn, model_config=config).waiter(waiter_name)
# If there is no custom waiter found for the provided name,
# then try checking the service's official waiters.
return self.conn.get_waiter(waiter_name)

@staticmethod
def _apply_parameters_value(config: dict, waiter_name: str, parameters: dict[str, str] | None) -> dict:
"""Replaces potential jinja templates in acceptors definition"""
# only process the waiter we're going to use to not raise errors for missing params for other waiters.
acceptors = config["waiters"][waiter_name]["acceptors"]
for a in acceptors:
arg = a["argument"]
template = jinja2.Template(arg, autoescape=False, undefined=jinja2.StrictUndefined)
try:
a["argument"] = template.render(parameters or {})
except jinja2.UndefinedError as e:
raise AirflowException(
f"Parameter was not supplied for templated waiter's acceptor '{arg}'", e
)
return config

def list_waiters(self) -> list[str]:
"""Returns a list containing the names of all waiters for the service, official and custom."""
return [*self._list_official_waiters(), *self._list_custom_waiters()]
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/hooks/batch_waiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def waiter_model(self) -> botocore.waiter.WaiterModel:
"""
return self._waiter_model

def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter:
def get_waiter(self, waiter_name: str, _: dict[str, str] | None = None) -> botocore.waiter.Waiter:
"""
Get an AWS Batch service waiter, using the configured ``.waiter_model``.
Expand Down Expand Up @@ -168,6 +168,8 @@ def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter:
the name (including the casing) of the key name in the waiter
model file (typically this is CamelCasing); see ``.list_waiters``.
:param _: unused, just here to match the method signature in base_aws
:return: a waiter object for the named AWS Batch service
"""
return botocore.waiter.create_waiter_with_client(waiter_name, self.waiter_model, self.client)
Expand Down
30 changes: 30 additions & 0 deletions airflow/providers/amazon/aws/waiters/appflow.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"version": 2,
"waiters": {
"run_complete": {
"operation": "DescribeFlowExecutionRecords",
"delay": 15,
"maxAttempts": 60,
"acceptors": [
{
"expected": "Successful",
"matcher": "path",
"state": "success",
"argument": "flowExecutions[?executionId=='{{EXECUTION_ID}}'].executionStatus"
},
{
"expected": "Error",
"matcher": "path",
"state": "failure",
"argument": "flowExecutions[?executionId=='{{EXECUTION_ID}}'].executionStatus"
},
{
"expected": true,
"matcher": "path",
"state": "failure",
"argument": "length(flowExecutions[?executionId=='{{EXECUTION_ID}}']) > `1`"
}
]
}
}
}
52 changes: 26 additions & 26 deletions tests/providers/amazon/aws/hooks/test_appflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,30 @@

@pytest.fixture
def hook():
with mock.patch("airflow.providers.amazon.aws.hooks.appflow.AppflowHook.__init__", return_value=None):
with mock.patch("airflow.providers.amazon.aws.hooks.appflow.AppflowHook.conn") as mock_conn:
mock_conn.describe_flow.return_value = {
"sourceFlowConfig": {"connectorType": CONNECTION_TYPE},
"tasks": [],
"triggerConfig": {"triggerProperties": None},
"flowName": FLOW_NAME,
"destinationFlowConfigList": {},
"lastRunExecutionDetails": {
"mostRecentExecutionStatus": "Successful",
"mostRecentExecutionTime": datetime(3000, 1, 1, tzinfo=timezone.utc),
},
}
mock_conn.update_flow.return_value = {}
mock_conn.start_flow.return_value = {"executionId": EXECUTION_ID}
mock_conn.describe_flow_execution_records.return_value = {
"flowExecutions": [
{
"executionId": EXECUTION_ID,
"executionResult": {"recordsProcessed": 1},
"executionStatus": "Successful",
}
]
}
yield AppflowHook(aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME)
with mock.patch("airflow.providers.amazon.aws.hooks.appflow.AppflowHook.conn") as mock_conn:
mock_conn.describe_flow.return_value = {
"sourceFlowConfig": {"connectorType": CONNECTION_TYPE},
"tasks": [],
"triggerConfig": {"triggerProperties": None},
"flowName": FLOW_NAME,
"destinationFlowConfigList": {},
"lastRunExecutionDetails": {
"mostRecentExecutionStatus": "Successful",
"mostRecentExecutionTime": datetime(3000, 1, 1, tzinfo=timezone.utc),
},
}
mock_conn.update_flow.return_value = {}
mock_conn.start_flow.return_value = {"executionId": EXECUTION_ID}
mock_conn.describe_flow_execution_records.return_value = {
"flowExecutions": [
{
"executionId": EXECUTION_ID,
"executionResult": {"recordsProcessed": 1},
"executionStatus": "Successful",
}
]
}
yield AppflowHook(aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME)


def test_conn_attributes(hook):
Expand All @@ -69,7 +68,8 @@ def test_conn_attributes(hook):


def test_run_flow(hook):
hook.run_flow(flow_name=FLOW_NAME, poll_interval=0)
with mock.patch("airflow.providers.amazon.aws.waiters.base_waiter.BaseBotoWaiter.waiter"):
hook.run_flow(flow_name=FLOW_NAME, poll_interval=0)
hook.conn.describe_flow_execution_records.assert_called_with(flowName=FLOW_NAME)
assert hook.conn.describe_flow_execution_records.call_count == 1
hook.conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)
Expand Down
53 changes: 51 additions & 2 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import os
from base64 import b64encode
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest import mock
from unittest.mock import mock_open
from unittest.mock import MagicMock, PropertyMock, mock_open
from uuid import UUID

import boto3
import jinja2
import pytest
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
Expand All @@ -34,9 +36,11 @@
from moto import mock_dynamodb, mock_emr, mock_iam, mock_sts
from moto.core import DEFAULT_ACCOUNT_ID

from airflow import AirflowException
from airflow.models.connection import Connection
from airflow.providers.amazon.aws.hooks.base_aws import (
AwsBaseHook,
AwsGenericHook,
BaseSessionFactory,
resolve_session_factory,
)
Expand All @@ -47,7 +51,6 @@
MOCK_CONN_TYPE = "aws"
MOCK_BOTO3_SESSION = mock.MagicMock(return_value="Mock boto3.session.Session")


SAML_ASSERTION = """
<?xml version="1.0"?>
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" ID="_00000000-0000-0000-0000-000000000000" Version="2.0" IssueInstant="2012-01-01T12:00:00.000Z" Destination="https://signin.aws.amazon.com/saml" Consent="urn:oasis:names:tc:SAML:2.0:consent:unspecified">
Expand Down Expand Up @@ -978,3 +981,49 @@ def test_raise_no_creds_default_credentials_strategy(tmp_path_factory, monkeypat
# In normal circumstances lines below should not execute.
# We want to show additional information why this test not passed
assert not result, f"Credentials Method: {hook.get_session().get_credentials().method}"


TEST_WAITER_CONFIG_LOCATION = Path(__file__).parents[1].joinpath("waiters/test.json")


@mock.patch.object(AwsGenericHook, "waiter_path", new_callable=PropertyMock)
def test_waiter_config_params_not_provided(waiter_path_mock: MagicMock, caplog):
waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
hook = AwsBaseHook(client_type="mwaa") # needs to be a real client type

with pytest.raises(AirflowException) as ae:
hook.get_waiter("wait_for_test")

# should warn about missing param
assert "PARAM_1" in str(ae.value)


@mock.patch.object(AwsGenericHook, "waiter_path", new_callable=PropertyMock)
def test_waiter_config_no_params_needed(waiter_path_mock: MagicMock, caplog):
waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
hook = AwsBaseHook(client_type="mwaa") # needs to be a real client type

with caplog.at_level("WARN"):
hook.get_waiter("other_wait")

# other waiters in the json need params, but not this one, so we shouldn't warn about it.
assert len(caplog.text) == 0


@mock.patch.object(AwsGenericHook, "waiter_path", new_callable=PropertyMock)
def test_waiter_config_with_parameters_specified(waiter_path_mock: MagicMock):
waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
hook = AwsBaseHook(client_type="mwaa") # needs to be a real client type

waiter = hook.get_waiter("wait_for_test", {"PARAM_1": "hello", "PARAM_2": "world"})

assert waiter.config.acceptors[0].argument == "'hello' == 'world'"


@mock.patch.object(AwsGenericHook, "waiter_path", new_callable=PropertyMock)
def test_waiter_config_param_wrong_format(waiter_path_mock: MagicMock):
waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
hook = AwsBaseHook(client_type="mwaa") # needs to be a real client type

with pytest.raises(jinja2.TemplateSyntaxError):
hook.get_waiter("bad_param_wait")
16 changes: 11 additions & 5 deletions tests/providers/amazon/aws/operators/test_appflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def appflow_conn():
yield mock_conn


@pytest.fixture
def waiter_mock():
with mock.patch("airflow.providers.amazon.aws.waiters.base_waiter.BaseBotoWaiter.waiter") as waiter:
yield waiter


def run_assertions_base(appflow_conn, tasks):
appflow_conn.describe_flow.assert_called_with(flowName=FLOW_NAME)
assert appflow_conn.describe_flow.call_count == 2
Expand All @@ -105,21 +111,21 @@ def run_assertions_base(appflow_conn, tasks):
appflow_conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)


def test_run(appflow_conn, ctx):
def test_run(appflow_conn, ctx, waiter_mock):
operator = AppflowRunOperator(**DUMP_COMMON_ARGS)
operator.execute(ctx) # type: ignore
appflow_conn.describe_flow.assert_called_once_with(flowName=FLOW_NAME)
appflow_conn.describe_flow_execution_records.assert_called_once()
appflow_conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)


def test_run_full(appflow_conn, ctx):
def test_run_full(appflow_conn, ctx, waiter_mock):
operator = AppflowRunFullOperator(**DUMP_COMMON_ARGS)
operator.execute(ctx) # type: ignore
run_assertions_base(appflow_conn, [])


def test_run_after(appflow_conn, ctx):
def test_run_after(appflow_conn, ctx, waiter_mock):
operator = AppflowRunAfterOperator(
source_field="col0", filter_date="2022-05-26T00:00+00:00", **DUMP_COMMON_ARGS
)
Expand All @@ -137,7 +143,7 @@ def test_run_after(appflow_conn, ctx):
)


def test_run_before(appflow_conn, ctx):
def test_run_before(appflow_conn, ctx, waiter_mock):
operator = AppflowRunBeforeOperator(
source_field="col0", filter_date="2022-05-26T00:00+00:00", **DUMP_COMMON_ARGS
)
Expand All @@ -155,7 +161,7 @@ def test_run_before(appflow_conn, ctx):
)


def test_run_daily(appflow_conn, ctx):
def test_run_daily(appflow_conn, ctx, waiter_mock):
operator = AppflowRunDailyOperator(
source_field="col0", filter_date="2022-05-26T00:00+00:00", **DUMP_COMMON_ARGS
)
Expand Down
Loading

0 comments on commit 05c0841

Please sign in to comment.