Skip to content

Commit

Permalink
Use base aws classes in AWS Glue Crawlers Operators/Sensors/Triggers (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
gopidesupavan authored and romsharon98 committed Jul 26, 2024
1 parent 9c91f97 commit e88264c
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 48 deletions.
41 changes: 22 additions & 19 deletions airflow/providers/amazon/aws/operators/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook


class GlueCrawlerOperator(BaseOperator):
class GlueCrawlerOperator(AwsBaseOperator[GlueCrawlerHook]):
"""
Creates, updates and triggers an AWS Glue Crawler.
Expand All @@ -45,45 +45,45 @@ class GlueCrawlerOperator(BaseOperator):
:ref:`howto/operator:GlueCrawlerOperator`
:param config: Configurations for the AWS Glue crawler
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status
:param wait_for_completion: Whether to wait for crawl execution completion. (default: True)
:param deferrable: If True, the operator will wait asynchronously for the crawl to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("config",)
aws_hook_class = GlueCrawlerHook

template_fields: Sequence[str] = aws_template_fields(
"config",
)
ui_color = "#ededed"

def __init__(
self,
config,
aws_conn_id="aws_default",
region_name: str | None = None,
poll_interval: int = 5,
wait_for_completion: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval
self.wait_for_completion = wait_for_completion
self.deferrable = deferrable
self.region_name = region_name
self.config = config

@cached_property
def hook(self) -> GlueCrawlerHook:
"""Create and return a GlueCrawlerHook."""
return GlueCrawlerHook(self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
def execute(self, context: Context) -> str:
"""
Execute AWS Glue Crawler from Airflow.
Expand All @@ -103,6 +103,9 @@ def execute(self, context: Context):
crawler_name=crawler_name,
waiter_delay=self.poll_interval,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
Expand Down
28 changes: 16 additions & 12 deletions airflow/providers/amazon/aws/sensors/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from deprecated import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class GlueCrawlerSensor(BaseSensorOperator):
class GlueCrawlerSensor(AwsBaseSensor[GlueCrawlerHook]):
"""
Waits for an AWS Glue crawler to reach any of the statuses below.
Expand All @@ -41,19 +41,27 @@ class GlueCrawlerSensor(BaseSensorOperator):
:ref:`howto/sensor:GlueCrawlerSensor`
:param crawler_name: The AWS Glue crawler unique name
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
If this is None or empty then the default boto3 behaviour is used. If
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("crawler_name",)
aws_hook_class = GlueCrawlerHook

def __init__(self, *, crawler_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None:
template_fields: Sequence[str] = aws_template_fields(
"crawler_name",
)

def __init__(self, *, crawler_name: str, **kwargs) -> None:
super().__init__(**kwargs)
self.crawler_name = crawler_name
self.aws_conn_id = aws_conn_id
self.success_statuses = "SUCCEEDED"
self.errored_statuses = ("FAILED", "CANCELLED")

Expand All @@ -79,7 +87,3 @@ def poke(self, context: Context):
def get_hook(self) -> GlueCrawlerHook:
"""Return a new or pre-existing GlueCrawlerHook."""
return self.hook

@cached_property
def hook(self) -> GlueCrawlerHook:
return GlueCrawlerHook(aws_conn_id=self.aws_conn_id)
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/triggers/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
aws_conn_id: str | None = "aws_default",
waiter_delay: int = 5,
waiter_max_attempts: int = 1500,
**kwargs,
):
if poll_interval is not None:
warnings.warn(
Expand All @@ -62,7 +63,13 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return GlueCrawlerHook(aws_conn_id=self.aws_conn_id)
return GlueCrawlerHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/glue.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
105 changes: 90 additions & 15 deletions tests/providers/amazon/aws/operators/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Generator
from unittest import mock

import pytest
from moto import mock_aws

from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
from airflow.providers.amazon.aws.hooks.sts import StsHook
from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection

mock_crawler_name = "test-crawler"
mock_role_name = "test-role"
mock_config = {
Expand Down Expand Up @@ -81,20 +90,86 @@


class TestGlueCrawlerOperator:
@pytest.fixture
def mock_conn(self) -> Generator[BaseAwsConnection, None, None]:
with mock.patch.object(GlueCrawlerHook, "get_conn") as _conn:
_conn.create_crawler.return_value = mock_crawler_name
yield _conn

@pytest.fixture
def crawler_hook(self) -> Generator[GlueCrawlerHook, None, None]:
with mock_aws():
hook = GlueCrawlerHook(aws_conn_id="aws_default")
yield hook

def setup_method(self):
self.glue = GlueCrawlerOperator(task_id="test_glue_crawler_operator", config=mock_config)

@mock.patch("airflow.providers.amazon.aws.operators.glue_crawler.GlueCrawlerHook")
def test_execute_without_failure(self, mock_hook):
mock_hook.return_value.has_crawler.return_value = True
self.glue.execute({})

mock_hook.assert_has_calls(
[
mock.call("aws_default", region_name=None),
mock.call().has_crawler("test-crawler"),
mock.call().update_crawler(**mock_config),
mock.call().start_crawler(mock_crawler_name),
mock.call().wait_for_crawler_completion(crawler_name=mock_crawler_name, poll_interval=5),
]
self.op = GlueCrawlerOperator(task_id="test_glue_crawler_operator", config=mock_config)

def test_init(self):
op = GlueCrawlerOperator(
task_id="test_glue_crawler_operator",
aws_conn_id="fake-conn-id",
region_name="eu-west-2",
verify=True,
botocore_config={"read_timeout": 42},
config=mock_config,
)

assert op.hook.client_type == "glue"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "eu-west-2"
assert op.hook._verify is True
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = GlueCrawlerOperator(task_id="test_glue_crawler_operator", config=mock_config)

assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(GlueCrawlerHook, "glue_client")
@mock.patch.object(StsHook, "get_account_number")
def test_execute_update_and_start_crawler(self, sts_mock, mock_glue_client):
sts_mock.get_account_number.return_value = 123456789012
mock_glue_client.get_crawler.return_value = {"Crawler": {}}
self.op.wait_for_completion = False
crawler_name = self.op.execute({})

mock_glue_client.get_crawler.call_count = 2
mock_glue_client.update_crawler.call_count = 1
mock_glue_client.start_crawler.call_count = 1
assert crawler_name == mock_crawler_name

@mock.patch.object(GlueCrawlerHook, "has_crawler")
@mock.patch.object(GlueCrawlerHook, "glue_client")
def test_execute_create_and_start_crawler(self, mock_glue_client, mock_has_crawler):
mock_has_crawler.return_value = False
mock_glue_client.create_crawler.return_value = {}
self.op.wait_for_completion = False
crawler_name = self.op.execute({})

assert crawler_name == mock_crawler_name
mock_glue_client.create_crawler.assert_called_once()

    @pytest.mark.parametrize(
        "wait_for_completion, deferrable",
        [
            pytest.param(False, False, id="no_wait"),
            pytest.param(True, False, id="wait"),
            pytest.param(False, True, id="defer"),
        ],
    )
    @mock.patch.object(GlueCrawlerHook, "get_waiter")
    def test_crawler_wait_combinations(self, _, wait_for_completion, deferrable, mock_conn, crawler_hook):
        # Verify that post-start behaviour is driven purely by the two flags:
        # the boto3 waiter runs only when wait_for_completion is set, and
        # defer() (mocked out so no TaskDeferred is raised) fires only in
        # deferrable mode. mock_conn/crawler_hook fixtures supply a patched
        # connection and a moto-backed hook.
        self.op.defer = mock.MagicMock()
        self.op.wait_for_completion = wait_for_completion
        self.op.deferrable = deferrable

        response = self.op.execute({})

        assert response == mock_crawler_name
        # bool compares equal to int: call_count is 1 iff waiting was requested…
        assert crawler_hook.get_waiter.call_count == wait_for_completion
        # …and defer() was invoked once iff running deferrable.
        assert self.op.defer.call_count == deferrable
25 changes: 25 additions & 0 deletions tests/providers/amazon/aws/sensors/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,28 @@ def test_fail_poke(self, get_crawler, soft_fail, expected_exception):
message = f"Status: {crawler_status}"
with pytest.raises(expected_exception, match=message):
self.sensor.poke(context={})

def test_base_aws_op_attributes(self):
op = GlueCrawlerSensor(
task_id="test_glue_crawler_sensor",
crawler_name="aws_test_glue_crawler",
)
assert op.hook.client_type == "glue"
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

op = GlueCrawlerSensor(
task_id="test_glue_crawler_sensor",
crawler_name="aws_test_glue_crawler",
aws_conn_id="aws-test-custom-conn",
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)
assert op.hook.aws_conn_id == "aws-test-custom-conn"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42
24 changes: 23 additions & 1 deletion tests/providers/amazon/aws/triggers/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@
# under the License.
from __future__ import annotations

from unittest.mock import patch
from unittest import mock
from unittest.mock import AsyncMock, patch

import pytest

from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
from airflow.triggers.base import TriggerEvent
from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type


class TestGlueCrawlerCompleteTrigger:
Expand Down Expand Up @@ -47,3 +53,19 @@ def test_serialization(self, mock_warn):
"waiter_max_attempts": 1500,
"aws_conn_id": "aws_default",
}

    @pytest.mark.asyncio
    @mock.patch.object(GlueCrawlerHook, "get_waiter")
    @mock.patch.object(GlueCrawlerHook, "async_conn")
    async def test_run_success(self, mock_async_conn, mock_get_waiter):
        # run() enters the hook's async client as a context manager; give it a
        # harmless mock so no real AWS connection is attempted.
        mock_async_conn.__aenter__.return_value = mock.MagicMock()
        # wait() is awaited inside run(), so it must be an AsyncMock.
        mock_get_waiter().wait = AsyncMock()
        crawler_name = "test_crawler"
        trigger = GlueCrawlerCompleteTrigger(crawler_name=crawler_name)

        # Drive the async generator one step to get the first TriggerEvent.
        generator = trigger.run()
        response = await generator.asend(None)

        # A completed wait yields a success event with no payload, having used
        # the "crawler_ready" waiter exactly once.
        assert response == TriggerEvent({"status": "success", "value": None})
        assert_expected_waiter_type(mock_get_waiter, "crawler_ready")
        mock_get_waiter().wait.assert_called_once()

0 comments on commit e88264c

Please sign in to comment.