Skip to content

Commit

Permalink
Use TaskMixin (#10930)
Browse files: browse the repository at this point in the history
  • Loading branch information
yuqian90 committed Sep 18, 2020
1 parent 9157991 commit 4c9464a
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 87 deletions.
28 changes: 11 additions & 17 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,21 +1132,25 @@ def roots(self) -> List["BaseOperator"]:
"""Required by TaskMixin"""
return [self]

@property
def leaves(self) -> List["BaseOperator"]:
"""Required by TaskMixin. Returns a one-element list containing this object itself."""
return [self]

def _set_relatives(
self,
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
upstream: bool = False,
) -> None:
"""Sets relatives for the task or task list."""

if isinstance(task_or_task_list, Sequence):
task_like_object_list = task_or_task_list
else:
task_like_object_list = [task_or_task_list]
if not isinstance(task_or_task_list, Sequence):
task_or_task_list = [task_or_task_list]

task_list: List["BaseOperator"] = []
for task_object in task_like_object_list:
task_list.extend(task_object.roots)
for task_object in task_or_task_list:
task_object.update_relative(self, not upstream)
relatives = task_object.leaves if upstream else task_object.roots
task_list.extend(relatives)

for task in task_list:
if not isinstance(task, BaseOperator):
Expand Down Expand Up @@ -1190,23 +1194,13 @@ def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]
Set a task or a task list to be directly downstream from the current
task. Required by TaskMixin.
"""
from airflow.utils.task_group import TaskGroup

if isinstance(task_or_task_list, TaskGroup):
task_or_task_list.upstream_task_ids.add(self.task_id)
task_or_task_list = list(task_or_task_list.get_roots())
self._set_relatives(task_or_task_list, upstream=False)

def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None:
"""
Set a task or a task list to be directly upstream from the current
task. Required by TaskMixin.
"""
from airflow.utils.task_group import TaskGroup

if isinstance(task_or_task_list, TaskGroup):
task_or_task_list.downstream_task_ids.add(self.task_id)
task_or_task_list = list(task_or_task_list.get_leaves())
self._set_relatives(task_or_task_list, upstream=True)

@property
Expand Down
11 changes: 11 additions & 0 deletions airflow/models/taskmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def roots(self):
"""Should return list of root operator List[BaseOperator]"""
raise NotImplementedError()

@property
def leaves(self):
"""Should return the list of leaf operators (List[BaseOperator]). This base version always raises NotImplementedError."""
raise NotImplementedError()

@abstractmethod
def set_upstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Expand All @@ -47,6 +52,12 @@ def set_downstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
raise NotImplementedError()

def update_relative(self, other: "TaskMixin", upstream=True) -> None:
"""
Update relationship information about another TaskMixin. Default is no-op.
Override if necessary.

:param other: the TaskMixin whose relationship to this object is being recorded
:param upstream: True when ``other`` is upstream of this object, False when it
is downstream (this is how the TaskGroup override interprets the flag)
"""

def __lshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Implements Task << Task
Expand Down
10 changes: 6 additions & 4 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@
# specific language governing permissions and limitations
# under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Sequence, Union

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator # pylint: disable=R0401
from airflow.models.taskmixin import TaskMixin
from airflow.models.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
from airflow.utils.task_group import TaskGroup # pylint: disable=cyclic-import


class XComArg(TaskMixin):
"""
Expand Down Expand Up @@ -105,6 +102,11 @@ def roots(self) -> List[BaseOperator]:
"""Required by TaskMixin"""
return [self._operator]

@property
def leaves(self) -> List[BaseOperator]:
"""Required by TaskMixin. Returns a one-element list containing ``self._operator``."""
return [self._operator]

@property
def key(self) -> str:
"""Returns keys of this XComArg"""
Expand Down
115 changes: 49 additions & 66 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union

from airflow.exceptions import AirflowException, DuplicateTaskIdFound
from airflow.models.taskmixin import TaskMixin

if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG


class TaskGroup:
class TaskGroup(TaskMixin):
"""
A collection of tasks. When set_downstream() or set_upstream() are called on the
TaskGroup, it is applied across all tasks within the group if necessary.
Expand Down Expand Up @@ -168,66 +169,68 @@ def label(self) -> Optional[str]:
"""
return self._group_id

def _set_relative(
self,
task_or_task_list: Union['BaseOperator', Sequence['BaseOperator'], "TaskGroup"],
upstream: bool = False
) -> None:
def update_relative(self, other: "TaskMixin", upstream=True) -> None:
"""
Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
Overrides TaskMixin.update_relative.
Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
accordingly so that we can reduce the number of edges when displaying Graph View.
"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.xcom_arg import XComArg

if upstream:
for task in self.get_roots():
task.set_upstream(task_or_task_list)
else:
for task in self.get_leaves():
task.set_downstream(task_or_task_list)

# Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
# accordingly so that we can reduce the number of edges when displaying Graph View.
if isinstance(task_or_task_list, TaskGroup):
# Handles TaskGroup and TaskGroup
if isinstance(other, TaskGroup):
# Handles setting relationship between a TaskGroup and another TaskGroup
if upstream:
parent, child = (self, task_or_task_list)
parent, child = (self, other)
else:
parent, child = (task_or_task_list, self)
parent, child = (other, self)

parent.upstream_group_ids.add(child.group_id)
child.downstream_group_ids.add(parent.group_id)
else:
if isinstance(task_or_task_list, XComArg):
task_list = [task_or_task_list.operator]
else:
# Handles TaskGroup and task or list of tasks
try:
task_list = list(task_or_task_list) # type: ignore
except TypeError:
task_list = [task_or_task_list] # type: ignore

for task in task_list:
# Handles setting relationship between a TaskGroup and a task
for task in other.roots:
if not isinstance(task, BaseOperator):
raise AirflowException("Relationships can only be set between TaskGroup or operators; "
f"received {task.__class__.__name__}")
raise AirflowException("Relationships can only be set between TaskGroup "
f"or operators; received {task.__class__.__name__}")

if upstream:
self.upstream_task_ids.add(task.task_id)
else:
self.downstream_task_ids.add(task.task_id)

def _set_relative(
self,
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
upstream: bool = False
) -> None:
"""
Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.

:param task_or_task_list: a single TaskMixin, or a sequence of them, to relate
to this TaskGroup
:param upstream: if True, the argument is passed to ``set_upstream`` of every task
yielded by ``get_roots()``; otherwise it is passed to ``set_downstream`` of every
task yielded by ``get_leaves()``
"""
# First wire the task-level dependencies on the boundary tasks of the group.
if upstream:
for task in self.get_roots():
task.set_upstream(task_or_task_list)
else:
for task in self.get_leaves():
task.set_downstream(task_or_task_list)

# Normalise a single object into a list so both input shapes share the loop below.
if not isinstance(task_or_task_list, Sequence):
task_or_task_list = [task_or_task_list]

# Then record the relationship on the group itself, one related object at a time.
for task_like in task_or_task_list:
self.update_relative(task_like, upstream)

def set_downstream(
self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator'], "TaskGroup"]
self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]
) -> None:
"""
Set a TaskGroup/task/list of task downstream of this TaskGroup.
"""
self._set_relative(task_or_task_list, upstream=False)

def set_upstream(
self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator'], "TaskGroup"]
self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]
) -> None:
"""
Set a TaskGroup/task/list of task upstream of this TaskGroup.
Expand All @@ -250,6 +253,16 @@ def has_task(self, task: "BaseOperator") -> bool:

return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

@property
def roots(self) -> List["BaseOperator"]:
"""Required by TaskMixin. Collects everything yielded by ``get_roots()`` into a list."""
return list(self.get_roots())

@property
def leaves(self) -> List["BaseOperator"]:
"""Required by TaskMixin. Collects everything yielded by ``get_leaves()`` into a list."""
return list(self.get_leaves())

def get_roots(self) -> Generator["BaseOperator", None, None]:
"""
Returns a generator of tasks that are root tasks, i.e. those with no upstream
Expand All @@ -268,36 +281,6 @@ def get_leaves(self) -> Generator["BaseOperator", None, None]:
if not any(self.has_task(child) for child in task.get_direct_relatives(upstream=False)):
yield task

def __rshift__(self, other):
"""
Implements Self >> Other == self.set_downstream(other)

Returns ``other`` (not ``self``), so in ``a >> b >> c`` the second ``>>`` is
applied to ``b``.
"""
self.set_downstream(other)
return other

def __lshift__(self, other):
"""
Implements Self << Other == self.set_upstream(other)

Returns ``other`` (not ``self``), so in ``a << b << c`` the second ``<<`` is
applied to ``b``.
"""
self.set_upstream(other)
return other

def __rrshift__(self, other):
"""
Reflected right shift: called for ``other >> self`` when ``other`` does not
implement ``__rshift__`` itself, e.g. ``[Operator] >> self``, because lists
don't have an ``__rshift__`` operator.

Delegates to ``self.__lshift__(other)`` and returns ``self``.
"""
self.__lshift__(other)
return self

def __rlshift__(self, other):
"""
Reflected left shift: called for ``other << self`` when ``other`` does not
implement ``__lshift__`` itself, e.g. ``[Operator] << self``, because lists
don't have an ``__lshift__`` operator.

Delegates to ``self.__rshift__(other)`` and returns ``self``.
"""
self.__rshift__(other)
return self

def child_id(self, label):
"""
Prefix label with group_id if prefix_group_id is True. Otherwise return the label
Expand Down

0 comments on commit 4c9464a

Please sign in to comment.