Skip to content

Commit

Permalink
Merge pull request #1318 from jlowin/infer_dag
Browse files Browse the repository at this point in the history
Syntactic Sugar! Dag inference, operator composition, and a big docs update
  • Loading branch information
jlowin committed Apr 12, 2016
2 parents f8d19b4 + 9b6c84d commit 9689159
Show file tree
Hide file tree
Showing 7 changed files with 561 additions and 142 deletions.
13 changes: 13 additions & 0 deletions airflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def run_command(command):
'dagbag_import_timeout': 30,
'non_pooled_task_slot_count': 128,
},
'operators': {
'default_owner': 'airflow'
},
'webserver': {
'base_url': 'http://localhost:8080',
'web_server_host': '0.0.0.0',
Expand Down Expand Up @@ -217,6 +220,13 @@ def run_command(command):
# How long before timing out a python file import while filling the DagBag
dagbag_import_timeout = 30
[operators]
# The default owner assigned to each new operator, unless
# provided explicitly or passed via `default_args`
default_owner = Airflow
[webserver]
# The base url of your website as airflow cannot guess what domain or
# cname you are using. This is used in automated emails that
Expand Down Expand Up @@ -373,6 +383,9 @@ def run_command(command):
fernet_key = {FERNET_KEY}
non_pooled_task_slot_count = 128
[operators]
default_owner = airflow
[webserver]
base_url = http://localhost:8080
web_server_host = 0.0.0.0
Expand Down
161 changes: 147 additions & 14 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
from airflow.utils.timeout import timeout
from airflow.utils.trigger_rule import TriggerRule


Base = declarative_base()
ID_LEN = 250
SQL_ALCHEMY_CONN = configuration.get('core', 'SQL_ALCHEMY_CONN')
Expand All @@ -85,6 +84,9 @@
else:
LongText = Text

# used by DAG context_managers
_CONTEXT_MANAGER_DAG = None


def clear_task_instances(tis, session, activate_dag_runs=True):
'''
Expand Down Expand Up @@ -1604,7 +1606,7 @@ class derived from this one results in the creation of a task object,
def __init__(
self,
task_id,
owner,
owner=configuration.get('operators', 'DEFAULT_OWNER'),
email=None,
email_on_retry=True,
email_on_failure=True,
Expand Down Expand Up @@ -1643,7 +1645,6 @@ def __init__(
)

validate_key(task_id)
self.dag_id = dag.dag_id if dag else 'adhoc_' + owner
self.task_id = task_id
self.owner = owner
self.email = email
Expand Down Expand Up @@ -1689,14 +1690,16 @@ def __init__(
self.params = params or {} # Available in templates!
self.adhoc = adhoc
self.priority_weight = priority_weight
if dag:
dag.add_task(self)
self.dag = dag

# Private attributes
self._upstream_task_ids = []
self._downstream_task_ids = []

if not dag and _CONTEXT_MANAGER_DAG:
dag = _CONTEXT_MANAGER_DAG
if dag:
self.dag = dag

self._comps = {
'task_id',
'dag_id',
Expand Down Expand Up @@ -1740,14 +1743,104 @@ def __hash__(self):
hash_components.append(repr(val))
return hash(tuple(hash_components))

# Composing Operators -----------------------------------------------

def __rshift__(self, other):
"""
Implements Self >> Other == self.set_downstream(other)
If "Other" is a DAG, the DAG is assigned to the Operator.
"""
if isinstance(other, DAG):
# if this dag is already assigned, do nothing
# otherwise, do normal dag assignment
if not (self.has_dag() and self.dag is other):
self.dag = other
else:
self.set_downstream(other)
return other

def __lshift__(self, other):
"""
Implements Self << Other == self.set_upstream(other)
If "Other" is a DAG, the DAG is assigned to the Operator.
"""
if isinstance(other, DAG):
# if this dag is already assigned, do nothing
# otherwise, do normal dag assignment
if not (self.has_dag() and self.dag is other):
self.dag = other
else:
self.set_upstream(other)
return other

def __rrshift__(self, other):
"""
Called for [DAG] >> [Operator] because DAGs don't have
__rshift__ operators.
"""
self.__lshift__(other)
return self

def __rlshift__(self, other):
"""
Called for [DAG] << [Operator] because DAGs don't have
__lshift__ operators.
"""
self.__rshift__(other)
return self

# /Composing Operators ---------------------------------------------

@property
def dag(self):
"""
Returns the Operator's DAG if set, otherwise raises an error
"""
if self.has_dag():
return self._dag
else:
raise AirflowException(
'Operator {} has not been assigned to a DAG yet'.format(self))

@dag.setter
def dag(self, dag):
"""
Operators can be assigned to one DAG, one time. Repeat assignments to
that same DAG are ok.
"""
if not isinstance(dag, DAG):
raise TypeError(
'Expected DAG; received {}'.format(dag.__class__.__name__))
elif self.has_dag() and self.dag is not dag:
raise AirflowException(
"The DAG assigned to {} can not be changed.".format(self))
elif self.task_id not in [t.task_id for t in dag.tasks]:
dag.add_task(self)
self._dag = dag

def has_dag(self):
"""
Returns True if the Operator has been assigned to a DAG.
"""
return getattr(self, '_dag', None) is not None

@property
def dag_id(self):
if self.has_dag():
return self.dag.dag_id
else:
return 'adhoc_' + self.owner

@property
def schedule_interval(self):
"""
The schedule interval of the DAG always wins over individual tasks so
that tasks within a DAG always line up. The task still needs a
schedule_interval as it may not be attached to a DAG.
"""
if hasattr(self, 'dag') and self.dag:
if self.has_dag():
return self.dag._schedule_interval
else:
return self._schedule_interval
Expand Down Expand Up @@ -1998,7 +2091,6 @@ def dry_run(self):
logging.info('Rendering template for {0}'.format(attr))
logging.info(content)


def get_direct_relatives(self, upstream=False):
"""
Get the direct relatives to the current task, upstream or
Expand All @@ -2010,7 +2102,8 @@ def get_direct_relatives(self, upstream=False):
return self.downstream_list

def __repr__(self):
return "<Task({self.__class__.__name__}): {self.task_id}>".format(self=self)
return "<Task({self.__class__.__name__}): {self.task_id}>".format(
self=self)

@property
def task_type(self):
Expand All @@ -2029,9 +2122,35 @@ def _set_relatives(self, task_or_task_list, upstream=False):
task_list = list(task_or_task_list)
except TypeError:
task_list = [task_or_task_list]

for t in task_list:
if not isinstance(t, BaseOperator):
raise AirflowException(
"Relationships can only be set between "
"Operators; received {}".format(t.__class__.__name__))

# relationships can only be set if the tasks share a single DAG. Tasks
# without a DAG are assigned to that DAG.
dags = set(t.dag for t in [self] + task_list if t.has_dag())

if len(dags) > 1:
raise AirflowException(
'Tried to set relationships between tasks in '
'more than one DAG: {}'.format(dags))
elif len(dags) == 1:
dag = list(dags)[0]
else:
raise AirflowException(
"Tried to create relationships between tasks that don't have "
"DAGs yet. Set the DAG for at least one "
"task and try again: {}".format([self] + task_list))

if dag and not self.has_dag():
self.dag = dag

for task in task_list:
if not isinstance(task, BaseOperator):
raise AirflowException('Expecting a task')
if dag and not task.has_dag():
task.dag = dag
if upstream:
task.append_only_new(task._downstream_task_ids, self.task_id)
self.append_only_new(self._upstream_task_ids, task.task_id)
Expand Down Expand Up @@ -2279,6 +2398,20 @@ def __hash__(self):
hash_components.append(repr(val))
return hash(tuple(hash_components))

# Context Manager -----------------------------------------------

def __enter__(self):
global _CONTEXT_MANAGER_DAG
self._old_context_manager_dag = _CONTEXT_MANAGER_DAG
_CONTEXT_MANAGER_DAG = self
return self

def __exit__(self, _type, _value, _tb):
global _CONTEXT_MANAGER_DAG
_CONTEXT_MANAGER_DAG = self._old_context_manager_dag

# /Context Manager ----------------------------------------------

def date_range(self, start_date, num=None, end_date=datetime.now()):
if num:
end_date = None
Expand Down Expand Up @@ -2739,8 +2872,8 @@ def add_task(self, task):
"to the DAG ".format(task.task_id))
else:
self.tasks.append(task)
task.dag_id = self.dag_id
task.dag = self

self.task_count = len(self.tasks)

def add_tasks(self, tasks):
Expand Down Expand Up @@ -3057,8 +3190,8 @@ def delete(cls, xcoms, session=None):
xcoms = [xcoms]
for xcom in xcoms:
if not isinstance(xcom, XCom):
raise TypeError(
'Expected XCom; received {}'.format(type(xcom)))
raise TypeError('Expected XCom; received {}'.format(
xcom.__class__.__name__))
session.delete(xcom)
session.commit()

Expand Down
5 changes: 3 additions & 2 deletions airflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def wrapper(*args, **kwargs):
"Use keyword arguments when initializing operators")
dag_args = {}
dag_params = {}
if 'dag' in kwargs and kwargs['dag']:
dag = kwargs['dag']
import airflow.models
if kwargs.get('dag', None) or airflow.models._CONTEXT_MANAGER_DAG:
dag = kwargs.get('dag', None) or airflow.models._CONTEXT_MANAGER_DAG
dag_args = copy(dag.default_args) or {}
dag_params = copy(dag.params) or {}

Expand Down
Loading

0 comments on commit 9689159

Please sign in to comment.