From 985bb06ba57ab67bb218ca9ca7549a81bea88f87 Mon Sep 17 00:00:00 2001 From: eladkal <45845474+eladkal@users.noreply.github.com> Date: Thu, 9 Dec 2021 19:00:29 +0200 Subject: [PATCH] Organize EC2 classes in Amazon provider (#20157) * Organize EC2 classes in Amazon provider --- airflow/providers/amazon/aws/operators/ec2.py | 115 ++++++++++++++++++ .../aws/operators/ec2_start_instance.py | 56 ++------- .../amazon/aws/operators/ec2_stop_instance.py | 56 ++------- airflow/providers/amazon/aws/sensors/ec2.py | 65 ++++++++++ .../amazon/aws/sensors/ec2_instance_state.py | 53 ++------ airflow/providers/amazon/provider.yaml | 2 + .../prepare_provider_packages.py | 2 + tests/deprecated_classes.py | 12 ++ ...test_ec2_start_instance.py => test_ec2.py} | 39 +++++- .../aws/operators/test_ec2_stop_instance.py | 60 --------- ...test_ec2_instance_state.py => test_ec2.py} | 2 +- 11 files changed, 258 insertions(+), 204 deletions(-) create mode 100644 airflow/providers/amazon/aws/operators/ec2.py create mode 100644 airflow/providers/amazon/aws/sensors/ec2.py rename tests/providers/amazon/aws/operators/{test_ec2_start_instance.py => test_ec2.py} (60%) delete mode 100644 tests/providers/amazon/aws/operators/test_ec2_stop_instance.py rename tests/providers/amazon/aws/sensors/{test_ec2_instance_state.py => test_ec2.py} (97%) diff --git a/airflow/providers/amazon/aws/operators/ec2.py b/airflow/providers/amazon/aws/operators/ec2.py new file mode 100644 index 0000000000000..1e58390c3ae6a --- /dev/null +++ b/airflow/providers/amazon/aws/operators/ec2.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook + + +class EC2StartInstanceOperator(BaseOperator): + """ + Start AWS EC2 instance using boto3. + + :param instance_id: id of the AWS EC2 instance + :type instance_id: str + :param aws_conn_id: aws connection to use + :type aws_conn_id: str + :param region_name: (optional) aws region name associated with the client + :type region_name: Optional[str] + :param check_interval: time in seconds that the job should wait in + between each instance state checks until operation is completed + :type check_interval: float + """ + + template_fields = ("instance_id", "region_name") + ui_color = "#eeaa11" + ui_fgcolor = "#ffffff" + + def __init__( + self, + *, + instance_id: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + check_interval: float = 15, + **kwargs, + ): + super().__init__(**kwargs) + self.instance_id = instance_id + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.check_interval = check_interval + + def execute(self, context): + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + self.log.info("Starting EC2 instance %s", self.instance_id) + instance = ec2_hook.get_instance(instance_id=self.instance_id) + instance.start() + ec2_hook.wait_for_state( + instance_id=self.instance_id, + target_state="running", + check_interval=self.check_interval, + ) + + +class EC2StopInstanceOperator(BaseOperator): + """ + Stop AWS EC2 instance using boto3. + + :param instance_id: id of the AWS EC2 instance + :type instance_id: str + :param aws_conn_id: aws connection to use + :type aws_conn_id: str + :param region_name: (optional) aws region name associated with the client + :type region_name: Optional[str] + :param check_interval: time in seconds that the job should wait in + between each instance state checks until operation is completed + :type check_interval: float + """ + + template_fields = ("instance_id", "region_name") + ui_color = "#eeaa11" + ui_fgcolor = "#ffffff" + + def __init__( + self, + *, + instance_id: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + check_interval: float = 15, + **kwargs, + ): + super().__init__(**kwargs) + self.instance_id = instance_id + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.check_interval = check_interval + + def execute(self, context): + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + self.log.info("Stopping EC2 instance %s", self.instance_id) + instance = ec2_hook.get_instance(instance_id=self.instance_id) + instance.stop() + ec2_hook.wait_for_state( + instance_id=self.instance_id, + target_state="stopped", + check_interval=self.check_interval, + ) diff --git a/airflow/providers/amazon/aws/operators/ec2_start_instance.py b/airflow/providers/amazon/aws/operators/ec2_start_instance.py index f6e324146d66a..c2c25e5708b0a 100644 --- a/airflow/providers/amazon/aws/operators/ec2_start_instance.py +++ b/airflow/providers/amazon/aws/operators/ec2_start_instance.py @@ -16,54 +16,14 @@ # specific language governing permissions and limitations # under the License. # +"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.ec2`.""" -from typing import Optional +import warnings -from airflow.models import BaseOperator -from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator # noqa - -class EC2StartInstanceOperator(BaseOperator): - """ - Start AWS EC2 instance using boto3. - - :param instance_id: id of the AWS EC2 instance - :type instance_id: str - :param aws_conn_id: aws connection to use - :type aws_conn_id: str - :param region_name: (optional) aws region name associated with the client - :type region_name: Optional[str] - :param check_interval: time in seconds that the job should wait in - between each instance state checks until operation is completed - :type check_interval: float - """ - - template_fields = ("instance_id", "region_name") - ui_color = "#eeaa11" - ui_fgcolor = "#ffffff" - - def __init__( - self, - *, - instance_id: str, - aws_conn_id: str = "aws_default", - region_name: Optional[str] = None, - check_interval: float = 15, - **kwargs, - ): - super().__init__(**kwargs) - self.instance_id = instance_id - self.aws_conn_id = aws_conn_id - self.region_name = region_name - self.check_interval = check_interval - - def execute(self, context): - ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) - self.log.info("Starting EC2 instance %s", self.instance_id) - instance = ec2_hook.get_instance(instance_id=self.instance_id) - instance.start() - ec2_hook.wait_for_state( - instance_id=self.instance_id, - target_state="running", - check_interval=self.check_interval, - ) +warnings.warn( + "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ec2`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py b/airflow/providers/amazon/aws/operators/ec2_stop_instance.py index e32090739a466..ddafa21c5bd5e 100644 --- a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py +++ b/airflow/providers/amazon/aws/operators/ec2_stop_instance.py @@ -16,54 +16,14 @@ # specific language governing permissions and limitations # under the License. # +"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.ec2`.""" -from typing import Optional +import warnings -from airflow.models import BaseOperator -from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.providers.amazon.aws.operators.ec2 import EC2StopInstanceOperator # noqa - -class EC2StopInstanceOperator(BaseOperator): - """ - Stop AWS EC2 instance using boto3. - - :param instance_id: id of the AWS EC2 instance - :type instance_id: str - :param aws_conn_id: aws connection to use - :type aws_conn_id: str - :param region_name: (optional) aws region name associated with the client - :type region_name: Optional[str] - :param check_interval: time in seconds that the job should wait in - between each instance state checks until operation is completed - :type check_interval: float - """ - - template_fields = ("instance_id", "region_name") - ui_color = "#eeaa11" - ui_fgcolor = "#ffffff" - - def __init__( - self, - *, - instance_id: str, - aws_conn_id: str = "aws_default", - region_name: Optional[str] = None, - check_interval: float = 15, - **kwargs, - ): - super().__init__(**kwargs) - self.instance_id = instance_id - self.aws_conn_id = aws_conn_id - self.region_name = region_name - self.check_interval = check_interval - - def execute(self, context): - ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) - self.log.info("Stopping EC2 instance %s", self.instance_id) - instance = ec2_hook.get_instance(instance_id=self.instance_id) - instance.stop() - ec2_hook.wait_for_state( - instance_id=self.instance_id, - target_state="stopped", - check_interval=self.check_interval, - ) +warnings.warn( + "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ec2`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py new file mode 100644 index 0000000000000..83c12d413876e --- /dev/null +++ b/airflow/providers/amazon/aws/sensors/ec2.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import Optional + +from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.sensors.base import BaseSensorOperator + + +class EC2InstanceStateSensor(BaseSensorOperator): + """ + Check the state of the AWS EC2 instance until + state of the instance become equal to the target state. + + :param target_state: target state of instance + :type target_state: str + :param instance_id: id of the AWS EC2 instance + :type instance_id: str + :param region_name: (optional) aws region name associated with the client + :type region_name: Optional[str] + """ + + template_fields = ("target_state", "instance_id", "region_name") + ui_color = "#cc8811" + ui_fgcolor = "#ffffff" + valid_states = ["running", "stopped", "terminated"] + + def __init__( + self, + *, + target_state: str, + instance_id: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + **kwargs, + ): + if target_state not in self.valid_states: + raise ValueError(f"Invalid target_state: {target_state}") + super().__init__(**kwargs) + self.target_state = target_state + self.instance_id = instance_id + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + def poke(self, context): + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + instance_state = ec2_hook.get_instance_state(instance_id=self.instance_id) + self.log.info("instance state: %s", instance_state) + return instance_state == self.target_state diff --git a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py b/airflow/providers/amazon/aws/sensors/ec2_instance_state.py index 83c12d413876e..d166b69d20ffb 100644 --- a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py +++ b/airflow/providers/amazon/aws/sensors/ec2_instance_state.py @@ -15,51 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - -from typing import Optional - -from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook -from airflow.sensors.base import BaseSensorOperator - - -class EC2InstanceStateSensor(BaseSensorOperator): - """ - Check the state of the AWS EC2 instance until - state of the instance become equal to the target state. - - :param target_state: target state of instance - :type target_state: str - :param instance_id: id of the AWS EC2 instance - :type instance_id: str - :param region_name: (optional) aws region name associated with the client - :type region_name: Optional[str] - """ +"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.ec2`.""" - template_fields = ("target_state", "instance_id", "region_name") - ui_color = "#cc8811" - ui_fgcolor = "#ffffff" - valid_states = ["running", "stopped", "terminated"] +import warnings - def __init__( - self, - *, - target_state: str, - instance_id: str, - aws_conn_id: str = "aws_default", - region_name: Optional[str] = None, - **kwargs, - ): - if target_state not in self.valid_states: - raise ValueError(f"Invalid target_state: {target_state}") - super().__init__(**kwargs) - self.target_state = target_state - self.instance_id = instance_id - self.aws_conn_id = aws_conn_id - self.region_name = region_name +from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor # noqa - def poke(self, context): - ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) - instance_state = ec2_hook.get_instance_state(instance_id=self.instance_id) - self.log.info("instance state: %s", instance_state) - return instance_state == self.target_state +warnings.warn( + "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.ec2`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index f7c392f5e228d..31f05a26563a3 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -192,6 +192,7 @@ operators: python-modules: - airflow.providers.amazon.aws.operators.ec2_start_instance - airflow.providers.amazon.aws.operators.ec2_stop_instance + - airflow.providers.amazon.aws.operators.ec2 - integration-name: Amazon ECS python-modules: - airflow.providers.amazon.aws.operators.ecs @@ -263,6 +264,7 @@ sensors: - integration-name: Amazon EC2 python-modules: - airflow.providers.amazon.aws.sensors.ec2_instance_state + - airflow.providers.amazon.aws.sensors.ec2 - integration-name: Amazon Elastic Kubernetes Service (EKS) python-modules: - airflow.providers.amazon.aws.sensors.eks diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py index 007f677bedb22..6565410d7bb3e 100755 --- a/dev/provider_packages/prepare_provider_packages.py +++ b/dev/provider_packages/prepare_provider_packages.py @@ -2137,6 +2137,8 @@ def summarise_total_vs_bad_and_warnings(total: int, bad: int, warns: List[warnin "This module is deprecated. Please use `kubernetes.client.models.V1VolumeMount`.", 'numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header,' ' got 216 from PyObject', + 'This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ec2`.', + 'This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.ec2`.', } diff --git a/tests/deprecated_classes.py b/tests/deprecated_classes.py index 2cb927947f716..5fcd2f4c47ec5 100644 --- a/tests/deprecated_classes.py +++ b/tests/deprecated_classes.py @@ -1343,6 +1343,14 @@ "airflow.operators.dummy.DummyOperator", "airflow.operators.dummy_operator.DummyOperator", ), + ( + "airflow.providers.amazon.aws.operators.ec2.EC2StartInstanceOperator", + "airflow.providers.amazon.aws.operators.ec2_start_instance.EC2StartInstanceOperator", + ), + ( + "airflow.providers.amazon.aws.operators.ec2.EC2StopInstanceOperator", + "airflow.providers.amazon.aws.operators.ec2_stop_instance.EC2StopInstanceOperator", + ), ] SECRETS = [ @@ -1586,6 +1594,10 @@ 'airflow.providers.sftp.sensors.sftp.SFTPSensor', 'airflow.contrib.sensors.sftp_sensor.SFTPSensor', ), + ( + 'airflow.providers.amazon.aws.sensors.ec2.EC2InstanceStateSensor', + 'airflow.providers.amazon.aws.sensors.ec2_instance_state.EC2InstanceStateSensor', + ), ] TRANSFERS = [ diff --git a/tests/providers/amazon/aws/operators/test_ec2_start_instance.py b/tests/providers/amazon/aws/operators/test_ec2.py similarity index 60% rename from tests/providers/amazon/aws/operators/test_ec2_start_instance.py rename to tests/providers/amazon/aws/operators/test_ec2.py index 994cf90769e46..fc21a27c9e542 100644 --- a/tests/providers/amazon/aws/operators/test_ec2_start_instance.py +++ b/tests/providers/amazon/aws/operators/test_ec2.py @@ -22,10 +22,10 @@ from moto import mock_ec2 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook -from airflow.providers.amazon.aws.operators.ec2_start_instance import EC2StartInstanceOperator +from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator, EC2StopInstanceOperator -class TestEC2Operator(unittest.TestCase): +class TestEC2StartInstanceOperator(unittest.TestCase): def test_init(self): ec2_operator = EC2StartInstanceOperator( task_id="task_test", @@ -58,3 +58,38 @@ def test_start_instance(self): start_test.execute(None) # assert instance state is running assert ec2_hook.get_instance_state(instance_id=instance_id) == "running" + + +class TestEC2StopInstanceOperator(unittest.TestCase): + def test_init(self): + ec2_operator = EC2StopInstanceOperator( + task_id="task_test", + instance_id="i-123abc", + aws_conn_id="aws_conn_test", + region_name="region-test", + check_interval=3, + ) + assert ec2_operator.task_id == "task_test" + assert ec2_operator.instance_id == "i-123abc" + assert ec2_operator.aws_conn_id == "aws_conn_test" + assert ec2_operator.region_name == "region-test" + assert ec2_operator.check_interval == 3 + + @mock_ec2 + def test_stop_instance(self): + # create instance + ec2_hook = EC2Hook() + instances = ec2_hook.conn.create_instances( + MaxCount=1, + MinCount=1, + ) + instance_id = instances[0].instance_id + + # stop instance + stop_test = EC2StopInstanceOperator( + task_id="stop_test", + instance_id=instance_id, + ) + stop_test.execute(None) + # assert instance state is running + assert ec2_hook.get_instance_state(instance_id=instance_id) == "stopped" diff --git a/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py b/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py deleted file mode 100644 index 6bc591b1eaab6..0000000000000 --- a/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py +++ /dev/null @@ -1,60 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import unittest - -from moto import mock_ec2 - -from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook -from airflow.providers.amazon.aws.operators.ec2_stop_instance import EC2StopInstanceOperator - - -class TestEC2Operator(unittest.TestCase): - def test_init(self): - ec2_operator = EC2StopInstanceOperator( - task_id="task_test", - instance_id="i-123abc", - aws_conn_id="aws_conn_test", - region_name="region-test", - check_interval=3, - ) - assert ec2_operator.task_id == "task_test" - assert ec2_operator.instance_id == "i-123abc" - assert ec2_operator.aws_conn_id == "aws_conn_test" - assert ec2_operator.region_name == "region-test" - assert ec2_operator.check_interval == 3 - - @mock_ec2 - def test_stop_instance(self): - # create instance - ec2_hook = EC2Hook() - instances = ec2_hook.conn.create_instances( - MaxCount=1, - MinCount=1, - ) - instance_id = instances[0].instance_id - - # stop instance - stop_test = EC2StopInstanceOperator( - task_id="stop_test", - instance_id=instance_id, - ) - stop_test.execute(None) - # assert instance state is running - assert ec2_hook.get_instance_state(instance_id=instance_id) == "stopped" diff --git a/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py b/tests/providers/amazon/aws/sensors/test_ec2.py similarity index 97% rename from tests/providers/amazon/aws/sensors/test_ec2_instance_state.py rename to tests/providers/amazon/aws/sensors/test_ec2.py index e715da291f30d..f7e3fdba026d0 100644 --- a/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py +++ b/tests/providers/amazon/aws/sensors/test_ec2.py @@ -23,7 +23,7 @@ from moto import mock_ec2 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook -from airflow.providers.amazon.aws.sensors.ec2_instance_state import EC2InstanceStateSensor +from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor class TestEC2InstanceStateSensor(unittest.TestCase):