Skip to content

Commit

Permalink
Even more typing in operators (template_fields/ext) (#20608)
Browse files Browse the repository at this point in the history
Part of #19891
There were few more places where I missed adding Sequence
typing - including examples (also converted to tuples) and
also template_ext. Also in a few places iterable was left

GitOrigin-RevId: 83f8e178ba7a3d4ca012c831a5bfc2cade9e812d
  • Loading branch information
potiuk authored and Cloud Composer Team committed Nov 7, 2024
1 parent a2af606 commit b4296c1
Show file tree
Hide file tree
Showing 89 changed files with 150 additions and 146 deletions.
6 changes: 3 additions & 3 deletions airflow/decorators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import Callable, Optional, TypeVar
from typing import Callable, Optional, Sequence, TypeVar

from airflow.decorators.base import DecoratedOperator, task_decorator_factory
from airflow.operators.python import PythonOperator
Expand All @@ -39,12 +39,12 @@ class _PythonDecoratedOperator(DecoratedOperator, PythonOperator):
:type multiple_outputs: bool
"""

template_fields = ('op_args', 'op_kwargs')
template_fields: Sequence[str] = ('op_args', 'op_kwargs')
template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

# since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects (e.g protobuf).
shallow_copy_attrs = ('python_callable',)
shallow_copy_attrs: Sequence[str] = ('python_callable',)

def __init__(
self,
Expand Down
6 changes: 3 additions & 3 deletions airflow/decorators/python_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import inspect
from textwrap import dedent
from typing import Callable, Optional, TypeVar
from typing import Callable, Optional, Sequence, TypeVar

from airflow.decorators.base import DecoratedOperator, task_decorator_factory
from airflow.operators.python import PythonVirtualenvOperator
Expand All @@ -42,12 +42,12 @@ class _PythonVirtualenvDecoratedOperator(DecoratedOperator, PythonVirtualenvOper
:type multiple_outputs: bool
"""

template_fields = ('op_args', 'op_kwargs')
template_fields: Sequence[str] = ('op_args', 'op_kwargs')
template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

# since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects (e.g protobuf).
shallow_copy_attrs = ('python_callable',)
shallow_copy_attrs: Sequence[str] = ('python_callable',)

def __init__(
self,
Expand Down
6 changes: 3 additions & 3 deletions airflow/operators/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import os
from typing import Dict, Optional
from typing import Dict, Optional, Sequence

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowSkipException
Expand Down Expand Up @@ -128,9 +128,9 @@ class BashOperator(BaseOperator):
"""

template_fields = ('bash_command', 'env')
template_fields: Sequence[str] = ('bash_command', 'env')
template_fields_renderers = {'bash_command': 'bash', 'env': 'json'}
template_ext = (
template_ext: Sequence[str] = (
'.sh',
'.bash',
)
Expand Down
6 changes: 3 additions & 3 deletions airflow/operators/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.utils.context import Context
Expand Down Expand Up @@ -48,9 +48,9 @@ class EmailOperator(BaseOperator):
:type custom_headers: dict
"""

template_fields = ('to', 'subject', 'html_content', 'files')
template_fields: Sequence[str] = ('to', 'subject', 'html_content', 'files')
template_fields_renderers = {"html_content": "html"}
template_ext = ('.html',)
template_ext: Sequence[str] = ('.html',)
ui_color = '#e6faf9'

def __init__(
Expand Down
6 changes: 3 additions & 3 deletions airflow/operators/generic_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List, Optional, Union
from typing import List, Optional, Sequence, Union

from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
Expand Down Expand Up @@ -46,8 +46,8 @@ class GenericTransfer(BaseOperator):
:type insert_args: dict
"""

template_fields = ('sql', 'destination_table', 'preoperator')
template_ext = (
template_fields: Sequence[str] = ('sql', 'destination_table', 'preoperator')
template_ext: Sequence[str] = (
'.sql',
'.hql',
)
Expand Down
28 changes: 14 additions & 14 deletions airflow/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Iterable, List, Mapping, Optional, SupportsAbs, Union
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -124,8 +124,8 @@ class SQLCheckOperator(BaseSQLOperator):
:type database: str
"""

template_fields: Iterable[str] = ("sql",)
template_ext: Iterable[str] = (
template_fields: Sequence[str] = ("sql",)
template_ext: Sequence[str] = (
".hql",
".sql",
)
Expand Down Expand Up @@ -178,14 +178,14 @@ class SQLValueCheckOperator(BaseSQLOperator):
"""

__mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
template_fields = (
template_fields: Sequence[str] = (
"sql",
"pass_value",
) # type: Iterable[str]
template_ext = (
)
template_ext: Sequence[str] = (
".hql",
".sql",
) # type: Iterable[str]
)
ui_color = "#fff7e6"

def __init__(
Expand Down Expand Up @@ -289,8 +289,8 @@ class SQLIntervalCheckOperator(BaseSQLOperator):
"""

__mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
template_fields: Iterable[str] = ("sql1", "sql2")
template_ext: Iterable[str] = (
template_fields: Sequence[str] = ("sql1", "sql2")
template_ext: Sequence[str] = (
".hql",
".sql",
)
Expand Down Expand Up @@ -418,11 +418,11 @@ class SQLThresholdCheckOperator(BaseSQLOperator):
:type max_threshold: numeric or str
"""

template_fields = ("sql", "min_threshold", "max_threshold")
template_ext = (
template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold")
template_ext: Sequence[str] = (
".hql",
".sql",
) # type: Iterable[str]
)

def __init__(
self,
Expand Down Expand Up @@ -505,8 +505,8 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
:type parameters: mapping or iterable
"""

template_fields = ("sql",)
template_ext = (".sql",)
template_fields: Sequence[str] = ("sql",)
template_ext: Sequence[str] = (".sql",)
ui_color = "#a22034"
ui_fgcolor = "#F7F7F7"

Expand Down
4 changes: 2 additions & 2 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import datetime
import json
import time
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Sequence, Union

from airflow.api.common.trigger_dag import trigger_dag
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
Expand Down Expand Up @@ -83,7 +83,7 @@ class TriggerDagRunOperator(BaseOperator):
:type failed_states: list
"""

template_fields = ("trigger_dag_id", "trigger_run_id", "execution_date", "conf")
template_fields: Sequence[str] = ("trigger_dag_id", "trigger_run_id", "execution_date", "conf")
template_fields_renderers = {"conf": "py"}
ui_color = "#ffefeb"

Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class AthenaOperator(BaseOperator):

ui_color = '#44b5e2'
template_fields: Sequence[str] = ('query', 'database', 'output_location')
template_ext = ('.sql',)
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {"query": "sql"}

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CloudFormationCreateStackOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('stack_name',)
template_ext = ()
template_ext: Sequence[str] = ()
ui_color = '#6b9659'

def __init__(self, *, stack_name: str, params: dict, aws_conn_id: str = 'aws_default', **kwargs):
Expand Down Expand Up @@ -73,7 +73,7 @@ class CloudFormationDeleteStackOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('stack_name',)
template_ext = ()
template_ext: Sequence[str] = ()
ui_color = '#1d472b'
ui_fgcolor = '#FFF'

Expand Down
10 changes: 5 additions & 5 deletions airflow/providers/amazon/aws/operators/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class DmsCreateTaskOperator(BaseOperator):
'migration_type',
'create_task_kwargs',
)
template_ext = ()
template_ext: Sequence[str] = ()
template_fields_renderers = {
"table_mappings": "json",
"create_task_kwargs": "json",
Expand Down Expand Up @@ -135,7 +135,7 @@ class DmsDeleteTaskOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('replication_task_arn',)
template_ext = ()
template_ext: Sequence[str] = ()
template_fields_renderers: Dict[str, str] = {}

def __init__(
Expand Down Expand Up @@ -175,7 +175,7 @@ class DmsDescribeTasksOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('describe_tasks_kwargs',)
template_ext = ()
template_ext: Sequence[str] = ()
template_fields_renderers: Dict[str, str] = {'describe_tasks_kwargs': 'json'}

def __init__(
Expand Down Expand Up @@ -228,7 +228,7 @@ class DmsStartTaskOperator(BaseOperator):
'start_replication_task_type',
'start_task_kwargs',
)
template_ext = ()
template_ext: Sequence[str] = ()
template_fields_renderers = {'start_task_kwargs': 'json'}

def __init__(
Expand Down Expand Up @@ -277,7 +277,7 @@ class DmsStopTaskOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('replication_task_arn',)
template_ext = ()
template_ext: Sequence[str] = ()
template_fields_renderers: Dict[str, str] = {}

def __init__(
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class EmrAddStepsOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('job_flow_id', 'job_flow_name', 'cluster_states', 'steps')
template_ext = ('.json',)
template_ext: Sequence[str] = ('.json',)
template_fields_renderers = {"steps": "json"}
ui_color = '#f9c915'

Expand Down Expand Up @@ -281,7 +281,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('job_flow_overrides',)
template_ext = ('.json',)
template_ext: Sequence[str] = ('.json',)
template_fields_renderers = {"job_flow_overrides": "json"}
ui_color = '#f9c915'
operator_extra_links = (EmrClusterLink(),)
Expand Down Expand Up @@ -340,7 +340,7 @@ class EmrModifyClusterOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('cluster_id', 'step_concurrency_level')
template_ext = ()
template_ext: Sequence[str] = ()
ui_color = '#f9c915'

def __init__(
Expand Down Expand Up @@ -384,7 +384,7 @@ class EmrTerminateJobFlowOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('job_flow_id',)
template_ext = ()
template_ext: Sequence[str] = ()
ui_color = '#f9c915'

def __init__(self, *, job_flow_id: str, aws_conn_id: str = 'aws_default', **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class GlueJobOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('script_args',)
template_ext = ()
template_ext: Sequence[str] = ()
template_fields_renderers = {
"script_args": "json",
"create_job_kwargs": "json",
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class RedshiftSQLOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('sql',)
template_ext = ('.sql',)
template_ext: Sequence[str] = ('.sql',)

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ class S3FileTransformOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('source_s3_key', 'dest_s3_key', 'script_args')
template_ext = ()
template_ext: Sequence[str] = ()
ui_color = '#f9c915'

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SageMakerBaseOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('config',)
template_ext = ()
template_ext: Sequence[str] = ()
template_fields_renderers = {'config': 'json'}
ui_color = '#ededed'
integer_fields = []
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class SnsPublishOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('message', 'subject', 'message_attributes')
template_ext = ()
template_ext: Sequence[str] = ()
template_fields_renderers = {"message_attributes": "json"}

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class StepFunctionStartExecutionOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('state_machine_arn', 'name', 'input')
template_ext = ()
template_ext: Sequence[str] = ()
ui_color = '#f9c915'

def __init__(
Expand Down Expand Up @@ -98,7 +98,7 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('execution_arn',)
template_ext = ()
template_ext: Sequence[str] = ()
ui_color = '#f9c915'

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class AthenaSensor(BaseSensorOperator):
SUCCESS_STATES = ('SUCCEEDED',)

template_fields: Sequence[str] = ('query_execution_id',)
template_ext = ()
template_ext: Sequence[str] = ()
ui_color = '#66c3ff'

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class BatchSensor(BaseSensorOperator):
"""

template_fields: Sequence[str] = ('job_id',)
template_ext = ()
template_ext: Sequence[str] = ()
ui_color = '#66c3ff'

def __init__(
Expand Down
Loading

0 comments on commit b4296c1

Please sign in to comment.