Skip to content

Commit

Permalink
Extend hooks arguments into AwsBaseWaiterTrigger
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Oct 12, 2023
1 parent 946b539 commit 997495d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
15 changes: 14 additions & 1 deletion airflow/providers/amazon/aws/triggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class AwsBaseWaiterTrigger(BaseTrigger):
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
:param region_name: The AWS region where the resources to watch are. To be used to build the hook.
:param verify: Whether or not to verify SSL certificates. To be used to build the hook.
:param botocore_config: Configuration dictionary (key-values) for botocore client.
To be used to build the hook.
"""

def __init__(
Expand All @@ -72,6 +75,8 @@ def __init__(
waiter_max_attempts: int,
aws_conn_id: str | None,
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
):
# parameters that should be hardcoded in the child's implem
self.serialized_fields = serialized_fields
Expand All @@ -90,6 +95,8 @@ def __init__(
self.attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.verify = verify
self.botocore_config = botocore_config

def serialize(self) -> tuple[str, dict[str, Any]]:
# here we put together the "common" params,
Expand All @@ -102,9 +109,15 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
},
**self.serialized_fields,
)

# if we serialize the None value from this, it breaks subclasses that don't have it in their ctor.
if self.region_name:
# if we serialize the None value from this, it breaks subclasses that don't have it in their ctor.
params["region_name"] = self.region_name
if self.verify is not None:
params["verify"] = self.verify
if self.botocore_config is not None:
params["botocore_config"] = self.botocore_config

return (
# remember that self is an instance of the subclass here, not of this class.
self.__class__.__module__ + "." + self.__class__.__qualname__,
Expand Down
27 changes: 25 additions & 2 deletions tests/providers/amazon/aws/triggers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,33 @@ def test_region_serialized(self):
assert "region_name" in args
assert args["region_name"] == "my_region"

def test_region_not_serialized_if_omitted(self):
@pytest.mark.parametrize("verify", [True, False, pytest.param("/foo/bar.pem", id="path")])
def test_verify_serialized(self, verify):
self.trigger.verify = verify
_, args = self.trigger.serialize()

assert "region_name" not in args
assert "verify" in args
assert args["verify"] == verify

@pytest.mark.parametrize(
"botocore_config",
[
pytest.param({"read_timeout": 10, "connect_timeout": 42, "keepalive": True}, id="non-empty-dict"),
pytest.param({}, id="empty-dict"),
],
)
def test_botocore_config_serialized(self, botocore_config):
self.trigger.botocore_config = botocore_config
_, args = self.trigger.serialize()

assert "botocore_config" in args
assert args["botocore_config"] == botocore_config

@pytest.mark.parametrize("param_name", ["region_name", "verify", "botocore_config"])
def test_hooks_args_not_serialized_if_omitted(self, param_name):
_, args = self.trigger.serialize()

assert param_name not in args

def test_serialize_extra_fields(self):
self.trigger.serialized_fields = {"foo": "bar", "foz": "baz"}
Expand Down

0 comments on commit 997495d

Please sign in to comment.