diff --git a/airflow/contrib/sensors/aws_redshift_cluster_sensor.py b/airflow/contrib/sensors/aws_redshift_cluster_sensor.py index c6f3c92ee80ee..dc0da1372b400 100644 --- a/airflow/contrib/sensors/aws_redshift_cluster_sensor.py +++ b/airflow/contrib/sensors/aws_redshift_cluster_sensor.py @@ -20,7 +20,9 @@ import warnings -from airflow.providers.amazon.aws.sensors.redshift_cluster import AwsRedshiftClusterSensor +from airflow.providers.amazon.aws.sensors.redshift_cluster import ( + RedshiftClusterSensor as AwsRedshiftClusterSensor, +) warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.", diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py index 8fde39cb7bbd7..e58e2de729b62 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py @@ -34,7 +34,7 @@ CLUSTER_NAME = 'fargate-demo' FARGATE_PROFILE_NAME = f'{CLUSTER_NAME}-profile' -SELECTORS = environ.get('FARGATE_SELECTORS', [{'namespace': 'default'}]) +SELECTORS = [{'namespace': 'default'}] ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name') SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ') diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py index 3aa78754a796a..f19eec622f295 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py @@ -94,4 +94,10 @@ target_state=ClusterStates.NONEXISTENT, ) - create_cluster_and_nodegroup >> await_create_nodegroup >> start_pod >> delete_all >> await_delete_cluster + ( + create_cluster_and_nodegroup + >> await_create_nodegroup + >> start_pod + >> delete_all + >> await_delete_cluster + ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py index ee4ebc499b5e5..3ec6a3ac459a0 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py @@ -55,7 +55,7 @@ ) as dag: # [START howto_operator_eks_create_cluster] - # Create an Amazon EKS Cluster control plane without attaching a compute service. + # Create an Amazon EKS Cluster control plane without attaching compute service. create_cluster = EksCreateClusterOperator( task_id='create_eks_cluster', cluster_role_arn=ROLE_ARN, diff --git a/airflow/providers/amazon/aws/hooks/eks.py b/airflow/providers/amazon/aws/hooks/eks.py index d21d905929d95..1b22bfecec84c 100644 --- a/airflow/providers/amazon/aws/hooks/eks.py +++ b/airflow/providers/amazon/aws/hooks/eks.py @@ -24,7 +24,7 @@ from contextlib import contextmanager from enum import Enum from functools import partial -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, Generator, List, Optional import yaml from botocore.exceptions import ClientError @@ -124,7 +124,7 @@ def create_nodegroup( clusterName: str, nodegroupName: str, subnets: List[str], - nodeRole: str, + nodeRole: Optional[str], *, tags: Optional[Dict] = None, **kwargs, @@ -170,8 +170,8 @@ def create_nodegroup( def create_fargate_profile( self, clusterName: str, - fargateProfileName: str, - podExecutionRoleArn: str, + fargateProfileName: Optional[str], + podExecutionRoleArn: Optional[str], selectors: List, **kwargs, ) -> Dict: @@ -498,10 +498,10 @@ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> Lis def generate_config_file( self, eks_cluster_name: str, - pod_namespace: str, + pod_namespace: Optional[str], pod_username: Optional[str] = None, pod_context: Optional[str] = None, - ) -> str: + ) -> Generator[str, None, None]: """ Writes the kubeconfig file given an EKS Cluster. diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index 612842462c2e4..b163fb7919db7 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -21,8 +21,9 @@ import tempfile import time import warnings +from datetime import datetime from functools import partial -from typing import Any, Callable, Dict, Generator, List, Optional, Set +from typing import Any, Callable, Dict, Generator, List, Optional, Set, cast from botocore.exceptions import ClientError @@ -93,7 +94,7 @@ def secondary_training_status_changed(current_job_description: dict, prev_job_de def secondary_training_status_message( - job_description: Dict[str, List[dict]], prev_description: Optional[dict] + job_description: Dict[str, List[Any]], prev_description: Optional[dict] ) -> str: """ Returns a string contains start time and the secondary training job status message. @@ -121,7 +122,9 @@ def secondary_training_status_message( status_strs = [] for transition in transitions_to_print: message = transition['StatusMessage'] - time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S') + time_str = timezone.convert_to_utc(cast(datetime, job_description['LastModifiedTime'])).strftime( + '%Y-%m-%d %H:%M:%S' + ) status_strs.append(f"{time_str} {transition['Status']} - {message}") return '\n'.join(status_strs) diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py index cb17cdfb0ab02..acc323ea176d4 100644 --- a/airflow/providers/amazon/aws/operators/eks.py +++ b/airflow/providers/amazon/aws/operators/eks.py @@ -17,8 +17,9 @@ """This module contains Amazon EKS operators.""" import warnings +from ast import literal_eval from time import sleep -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast from airflow import AirflowException from airflow.models import BaseOperator @@ -127,13 +128,13 @@ def __init__( self, cluster_name: str, cluster_role_arn: str, - resources_vpc_config: Dict, + resources_vpc_config: Dict[str, Any], compute: Optional[str] = DEFAULT_COMPUTE_TYPE, create_cluster_kwargs: Optional[Dict] = None, - nodegroup_name: Optional[str] = DEFAULT_NODEGROUP_NAME, + nodegroup_name: str = DEFAULT_NODEGROUP_NAME, nodegroup_role_arn: Optional[str] = None, create_nodegroup_kwargs: Optional[Dict] = None, - fargate_profile_name: Optional[str] = DEFAULT_FARGATE_PROFILE_NAME, + fargate_profile_name: str = DEFAULT_FARGATE_PROFILE_NAME, fargate_pod_execution_role_arn: Optional[str] = None, fargate_selectors: Optional[List] = None, create_fargate_profile_kwargs: Optional[Dict] = None, @@ -211,7 +212,7 @@ def execute(self, context: 'Context'): eks_hook.create_nodegroup( clusterName=self.cluster_name, nodegroupName=self.nodegroup_name, - subnets=self.resources_vpc_config.get('subnetIds'), + subnets=cast(List[str], self.resources_vpc_config.get('subnetIds')), nodeRole=self.nodegroup_role_arn, **self.create_nodegroup_kwargs, ) @@ -264,21 +265,34 @@ class EksCreateNodegroupOperator(BaseOperator): def __init__( self, cluster_name: str, - nodegroup_subnets: List[str], + nodegroup_subnets: Union[List[str], str], nodegroup_role_arn: str, - nodegroup_name: Optional[str] = DEFAULT_NODEGROUP_NAME, + nodegroup_name: str = DEFAULT_NODEGROUP_NAME, create_nodegroup_kwargs: Optional[Dict] = None, aws_conn_id: str = DEFAULT_CONN_ID, region: Optional[str] = None, **kwargs, ) -> None: self.cluster_name = cluster_name - self.nodegroup_subnets = nodegroup_subnets self.nodegroup_role_arn = nodegroup_role_arn self.nodegroup_name = nodegroup_name self.create_nodegroup_kwargs = create_nodegroup_kwargs or {} self.aws_conn_id = aws_conn_id self.region = region + nodegroup_subnets_list: List[str] = [] + if isinstance(nodegroup_subnets, str): + if nodegroup_subnets != "": + try: + nodegroup_subnets_list = cast(List, literal_eval(nodegroup_subnets)) + except ValueError: + self.log.warning( + "The nodegroup_subnets should be List or string representing " + "Python list and is %s. Defaulting to []", + nodegroup_subnets, + ) + else: + nodegroup_subnets_list = nodegroup_subnets + self.nodegroup_subnets = nodegroup_subnets_list super().__init__(**kwargs) def execute(self, context: 'Context'): @@ -286,7 +300,6 @@ def execute(self, context: 'Context'): aws_conn_id=self.aws_conn_id, region_name=self.region, ) - eks_hook.create_nodegroup( clusterName=self.cluster_name, nodegroupName=self.nodegroup_name, diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 390259cc07165..29bac5cb65a32 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -168,7 +168,7 @@ def __init__( self.client_request_token = client_request_token or str(uuid4()) self.poll_interval = poll_interval self.max_tries = max_tries - self.job_id = None + self.job_id: Optional[str] = None @cached_property def hook(self) -> EmrContainerHook: diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index 75f989e30194f..fbbc29c62a918 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -62,10 +62,10 @@ def __init__( *, job_name: str = 'aws_glue_default_job', job_desc: str = 'AWS Glue Job with Airflow', - script_location: Optional[str] = None, + script_location: str, concurrent_run_limit: Optional[int] = None, script_args: Optional[dict] = None, - retry_limit: Optional[int] = None, + retry_limit: int = 0, num_of_dpus: int = 6, aws_conn_id: str = 'aws_default', region_name: Optional[str] = None, @@ -100,7 +100,7 @@ def execute(self, context: 'Context'): :return: the id of the current glue job. """ - if self.script_location and not self.script_location.startswith(self.s3_protocol): + if not self.script_location.startswith(self.s3_protocol): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) script_name = os.path.basename(self.script_location) s3_hook.load_file( diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py index 4ba03cb419cc1..a05393b43998c 100644 --- a/airflow/providers/amazon/aws/operators/s3.py +++ b/airflow/providers/amazon/aws/operators/s3.py @@ -486,8 +486,9 @@ def execute(self, context: 'Context'): close_fds=True, ) as process: self.log.info("Output:") - for line in iter(process.stdout.readline, b''): - self.log.info(line.decode(self.output_encoding).rstrip()) + if process.stdout is not None: + for line in iter(process.stdout.readline, b''): + self.log.info(line.decode(self.output_encoding).rstrip()) process.wait()