
Commit

Fix all Amazon Provider MyPy errors
Part of #19891
potiuk committed Jan 20, 2022
1 parent 8ff5fe2 commit 157b1f5
Showing 10 changed files with 53 additions and 28 deletions.
4 changes: 3 additions & 1 deletion airflow/contrib/sensors/aws_redshift_cluster_sensor.py
@@ -20,7 +20,9 @@

 import warnings

-from airflow.providers.amazon.aws.sensors.redshift_cluster import AwsRedshiftClusterSensor
+from airflow.providers.amazon.aws.sensors.redshift_cluster import (
+    RedshiftClusterSensor as AwsRedshiftClusterSensor,
+)

 warnings.warn(
     "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.",
@@ -34,7 +34,7 @@

 CLUSTER_NAME = 'fargate-demo'
 FARGATE_PROFILE_NAME = f'{CLUSTER_NAME}-profile'
-SELECTORS = environ.get('FARGATE_SELECTORS', [{'namespace': 'default'}])
+SELECTORS = [{'namespace': 'default'}]

 ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name')
 SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ')
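
The SELECTORS change above drops the environ.get() lookup: environment variables are always strings, so the expression with a list default is typed as a union of str and the list, which then fails wherever a plain list of selector dicts is expected. A minimal sketch of the pattern under the usual typeshed stubs, with illustrative names rather than the example DAG's actual code:

from os import environ
from typing import Dict, List

# environ.get() returns the str value or the given default, so mypy types
# this as Union[str, List[Dict[str, str]]]; the str branch then fails
# anywhere a list of selector dicts is required.
maybe_selectors = environ.get('FARGATE_SELECTORS', [{'namespace': 'default'}])

# Pinning the value to a literal keeps it a plain List[Dict[str, str]].
SELECTORS: List[Dict[str, str]] = [{'namespace': 'default'}]
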
@@ -94,4 +94,10 @@
     target_state=ClusterStates.NONEXISTENT,
 )

-create_cluster_and_nodegroup >> await_create_nodegroup >> start_pod >> delete_all >> await_delete_cluster
+(
+    create_cluster_and_nodegroup
+    >> await_create_nodegroup
+    >> start_pod
+    >> delete_all
+    >> await_delete_cluster
+)
@@ -55,7 +55,7 @@
 ) as dag:

     # [START howto_operator_eks_create_cluster]
-    # Create an Amazon EKS Cluster control plane without attaching a compute service.
+    # Create an Amazon EKS Cluster control plane without attaching compute service.
     create_cluster = EksCreateClusterOperator(
         task_id='create_eks_cluster',
         cluster_role_arn=ROLE_ARN,
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/hooks/eks.py
@@ -24,7 +24,7 @@
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Dict, Generator, List, Optional

 import yaml
 from botocore.exceptions import ClientError
@@ -127,7 +127,7 @@ def create_nodegroup(
         clusterName: str,
         nodegroupName: str,
         subnets: List[str],
-        nodeRole: str,
+        nodeRole: Optional[str],
         *,
         tags: Optional[Dict] = None,
         **kwargs,
@@ -178,8 +178,8 @@ def create_nodegroup(
     def create_fargate_profile(
         self,
         clusterName: str,
-        fargateProfileName: str,
-        podExecutionRoleArn: str,
+        fargateProfileName: Optional[str],
+        podExecutionRoleArn: Optional[str],
         selectors: List,
         **kwargs,
     ) -> Dict:
@@ -536,10 +536,10 @@ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> Lis
     def generate_config_file(
         self,
         eks_cluster_name: str,
-        pod_namespace: str,
+        pod_namespace: Optional[str],
         pod_username: Optional[str] = None,
         pod_context: Optional[str] = None,
-    ) -> str:
+    ) -> Generator[str, None, None]:
         """
         Writes the kubeconfig file given an EKS Cluster.
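
The return annotation of generate_config_file changes from str to Generator[str, None, None]; the same hook imports contextmanager, and a generator function wrapped by @contextmanager must be annotated with the generator type it returns, not the value it yields. A generic sketch of that typing rule, not the hook's actual implementation:

from contextlib import contextmanager
from typing import Generator

@contextmanager
def config_file(path: str) -> Generator[str, None, None]:
    # The annotation describes the generator this function produces;
    # annotating it as "-> str" is what mypy rejects.
    yield path

# Inside the with-block the yielded value is still a plain str.
with config_file('/tmp/kubeconfig') as path:
    print(path.upper())
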
9 changes: 6 additions & 3 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -21,8 +21,9 @@
 import tempfile
 import time
 import warnings
+from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Dict, Generator, List, Optional, Set
+from typing import Any, Callable, Dict, Generator, List, Optional, Set, cast

 from botocore.exceptions import ClientError

@@ -95,7 +96,7 @@ def secondary_training_status_changed(current_job_description: dict, prev_job_de


 def secondary_training_status_message(
-    job_description: Dict[str, List[dict]], prev_description: Optional[dict]
+    job_description: Dict[str, List[Any]], prev_description: Optional[dict]
 ) -> str:
     """
     Returns a string contains start time and the secondary training job status message.
Expand Down Expand Up @@ -125,7 +126,9 @@ def secondary_training_status_message(
status_strs = []
for transition in transitions_to_print:
message = transition['StatusMessage']
time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')
time_str = timezone.convert_to_utc(cast(datetime, job_description['LastModifiedTime'])).strftime(
'%Y-%m-%d %H:%M:%S'
)
status_strs.append(f"{time_str} {transition['Status']} - {message}")

return '\n'.join(status_strs)
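
With job_description loosened to Dict[str, List[Any]], the 'LastModifiedTime' lookup no longer type-checks as a datetime, so the code narrows it with typing.cast, which does nothing at runtime. A standalone sketch of the same idea using only the standard library (the real code goes through Airflow's timezone.convert_to_utc helper):

from datetime import datetime, timezone
from typing import Any, Dict, List, cast

def last_modified_str(job_description: Dict[str, List[Any]]) -> str:
    # The describe-job payload mixes value types under one dict annotation,
    # so this key is typed List[Any] even though it really holds a datetime.
    # cast() only informs the type checker; it changes nothing at runtime.
    last_modified = cast(datetime, job_description['LastModifiedTime'])
    return last_modified.astimezone(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
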
31 changes: 22 additions & 9 deletions airflow/providers/amazon/aws/operators/eks.py
@@ -17,8 +17,9 @@

 """This module contains Amazon EKS operators."""
 import warnings
+from ast import literal_eval
 from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast

 from airflow import AirflowException
 from airflow.models import BaseOperator
@@ -138,13 +139,13 @@ def __init__(
         self,
         cluster_name: str,
         cluster_role_arn: str,
-        resources_vpc_config: Dict,
+        resources_vpc_config: Dict[str, Any],
         compute: Optional[str] = DEFAULT_COMPUTE_TYPE,
         create_cluster_kwargs: Optional[Dict] = None,
-        nodegroup_name: Optional[str] = DEFAULT_NODEGROUP_NAME,
+        nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
         nodegroup_role_arn: Optional[str] = None,
         create_nodegroup_kwargs: Optional[Dict] = None,
-        fargate_profile_name: Optional[str] = DEFAULT_FARGATE_PROFILE_NAME,
+        fargate_profile_name: str = DEFAULT_FARGATE_PROFILE_NAME,
         fargate_pod_execution_role_arn: Optional[str] = None,
         fargate_selectors: Optional[List] = None,
         create_fargate_profile_kwargs: Optional[Dict] = None,
@@ -222,7 +223,7 @@ def execute(self, context: 'Context'):
             eks_hook.create_nodegroup(
                 clusterName=self.cluster_name,
                 nodegroupName=self.nodegroup_name,
-                subnets=self.resources_vpc_config.get('subnetIds'),
+                subnets=cast(List[str], self.resources_vpc_config.get('subnetIds')),
                 nodeRole=self.nodegroup_role_arn,
                 **self.create_nodegroup_kwargs,
             )
@@ -281,29 +282,41 @@ class EksCreateNodegroupOperator(BaseOperator):
     def __init__(
         self,
         cluster_name: str,
-        nodegroup_subnets: List[str],
+        nodegroup_subnets: Union[List[str], str],
         nodegroup_role_arn: str,
-        nodegroup_name: Optional[str] = DEFAULT_NODEGROUP_NAME,
+        nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
         create_nodegroup_kwargs: Optional[Dict] = None,
         aws_conn_id: str = DEFAULT_CONN_ID,
         region: Optional[str] = None,
         **kwargs,
     ) -> None:
         self.cluster_name = cluster_name
-        self.nodegroup_subnets = nodegroup_subnets
         self.nodegroup_role_arn = nodegroup_role_arn
         self.nodegroup_name = nodegroup_name
         self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
         self.aws_conn_id = aws_conn_id
         self.region = region
+        nodegroup_subnets_list: List[str] = []
+        if isinstance(nodegroup_subnets, str):
+            if nodegroup_subnets != "":
+                try:
+                    nodegroup_subnets_list = cast(List, literal_eval(nodegroup_subnets))
+                except ValueError:
+                    self.log.warning(
+                        "The nodegroup_subnets should be List or string representing "
+                        "Python list and is %s. Defaulting to []",
+                        nodegroup_subnets,
+                    )
+        else:
+            nodegroup_subnets_list = nodegroup_subnets
+        self.nodegroup_subnets = nodegroup_subnets_list
         super().__init__(**kwargs)

     def execute(self, context: 'Context'):
         eks_hook = EksHook(
             aws_conn_id=self.aws_conn_id,
             region_name=self.region,
         )

         eks_hook.create_nodegroup(
             clusterName=self.cluster_name,
             nodegroupName=self.nodegroup_name,
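
EksCreateNodegroupOperator now accepts nodegroup_subnets either as a list or as a string, presumably to tolerate values that arrive as the textual form of a list, and parses the string form with ast.literal_eval, falling back to an empty list with a warning. A self-contained sketch of that parsing logic, using a hypothetical helper name:

from ast import literal_eval
from typing import List, Union, cast

def parse_subnets(value: Union[List[str], str]) -> List[str]:
    # A string is expected to be the textual form of a Python list; anything
    # unparsable (or an empty string) falls back to an empty list.
    if isinstance(value, str):
        if not value:
            return []
        try:
            return cast(List[str], literal_eval(value))
        except ValueError:
            return []
    return value

print(parse_subnets("['subnet-12345ab', 'subnet-67890cd']"))  # ['subnet-12345ab', 'subnet-67890cd']
print(parse_subnets(['subnet-12345ab']))                      # ['subnet-12345ab']
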
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/emr.py
@@ -184,7 +184,7 @@ def __init__(
         self.client_request_token = client_request_token or str(uuid4())
         self.poll_interval = poll_interval
         self.max_tries = max_tries
-        self.job_id = None
+        self.job_id: Optional[str] = None

     @cached_property
     def hook(self) -> EmrContainerHook:
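
The emr.py change annotates job_id at its first assignment: without the annotation, mypy infers the attribute's type from the initial None, and a later assignment of the run's string id is rejected. A minimal illustration of the rule with a hypothetical class:

from typing import Optional

class JobRunner:
    def __init__(self) -> None:
        # Annotate at first assignment; otherwise mypy infers the attribute
        # as None and rejects any later string assignment.
        self.job_id: Optional[str] = None

    def start(self) -> None:
        self.job_id = 'job-123'  # fine: str is compatible with Optional[str]
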
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/operators/glue.py
@@ -75,10 +75,10 @@ def __init__(
         *,
         job_name: str = 'aws_glue_default_job',
         job_desc: str = 'AWS Glue Job with Airflow',
-        script_location: Optional[str] = None,
+        script_location: str,
         concurrent_run_limit: Optional[int] = None,
         script_args: Optional[dict] = None,
-        retry_limit: Optional[int] = None,
+        retry_limit: int = 0,
         num_of_dpus: int = 6,
         aws_conn_id: str = 'aws_default',
         region_name: Optional[str] = None,
@@ -113,7 +113,7 @@ def execute(self, context: 'Context'):
         :return: the id of the current glue job.
         """
-        if self.script_location and not self.script_location.startswith(self.s3_protocol):
+        if not self.script_location.startswith(self.s3_protocol):
             s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
             script_name = os.path.basename(self.script_location)
             s3_hook.load_file(
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/s3.py
@@ -523,8 +523,9 @@ def execute(self, context: 'Context'):
                 close_fds=True,
             ) as process:
                 self.log.info("Output:")
-                for line in iter(process.stdout.readline, b''):
-                    self.log.info(line.decode(self.output_encoding).rstrip())
+                if process.stdout is not None:
+                    for line in iter(process.stdout.readline, b''):
+                        self.log.info(line.decode(self.output_encoding).rstrip())

                 process.wait()

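
The s3.py change guards process.stdout before reading from it: Popen.stdout is typed Optional[IO[bytes]] because it is only set when the process is started with stdout=PIPE, so mypy requires the None check. A small standalone sketch of the same pattern (assumes a POSIX echo command is available):

import subprocess

with subprocess.Popen(['echo', 'hello'], stdout=subprocess.PIPE, close_fds=True) as process:
    # stdout is Optional[IO[bytes]]; the guard narrows it for mypy and is
    # also a safe runtime check.
    if process.stdout is not None:
        for line in iter(process.stdout.readline, b''):
            print(line.decode('utf-8').rstrip())
    process.wait()
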
