Skip to content

Commit

Permalink
[AIRFLOW-103] Allow jinja templates to be used in task params
Browse files Browse the repository at this point in the history
  • Loading branch information
withnale committed Jun 3, 2016
1 parent 89edb6f commit d1387b3
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 9 deletions.
28 changes: 19 additions & 9 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,10 +1417,7 @@ def get_template_context(self, session=None):
session.expunge_all()
session.commit()

if task.params:
params.update(task.params)

return {
context = {
'dag': task.dag,
'ds': ds,
'ds_nodash': ds_nodash,
Expand All @@ -1447,6 +1444,18 @@ def get_template_context(self, session=None):
'test_mode': self.test_mode,
}

# Allow task level param definitions to be rendered via jinja
# using the context available up until this point
if task.params:
if task.render_params:
rt = self.task.render_template # shortcut to method
rendered_content = rt('params', task.params, context)
params.update(rendered_content)
else:
params.update(task.params)

return context

def render_templates(self):
task = self.task
jinja_context = self.get_template_context()
Expand Down Expand Up @@ -1712,6 +1721,9 @@ class derived from this one results in the creation of a task object,
:param on_success_callback: much like the ``on_failure_callback`` excepts
that it is executed when the task succeeds.
:type on_success_callback: callable
:param render_params: set this to true to allow params to be rendered
and available to other templates
:type render_params: bool
:param trigger_rule: defines the rule by which dependencies are applied
for the task to get triggered. Options are:
``{ all_success | all_failed | all_done | one_success |
Expand Down Expand Up @@ -1757,6 +1769,7 @@ def __init__(
on_failure_callback=None,
on_success_callback=None,
on_retry_callback=None,
render_params=False,
trigger_rule=TriggerRule.ALL_SUCCESS,
*args,
**kwargs):
Expand Down Expand Up @@ -1790,6 +1803,7 @@ def __init__(
.format(all_triggers=TriggerRule.all_triggers,
d=dag.dag_id, t=task_id, tr = trigger_rule))

self.render_params = render_params
self.trigger_rule = trigger_rule
self.depends_on_past = depends_on_past
self.wait_for_downstream = wait_for_downstream
Expand Down Expand Up @@ -2048,11 +2062,7 @@ def render_template_from_field(self, attr, content, context, jinja_env):
k: rt("{}[{}]".format(attr, k), v, context)
for k, v in list(content.items())}
else:
param_type = type(content)
msg = (
"Type '{param_type}' used for parameter '{attr}' is "
"not supported for templating").format(**locals())
raise AirflowException(msg)
result = content
return result

def render_template(self, attr, content, context):
Expand Down
68 changes: 68 additions & 0 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,74 @@ def test_py_op(templates_dict, ds, **kwargs):
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)


def test_template_context_simple(self):
TI = models.TaskInstance
dag = models.DAG('test-dag', start_date=DEFAULT_DATE)
task = operators.DummyOperator(task_id='task', owner='unittest', dag=dag)
ti = TI(task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()

expected = {
'tomorrow_ds': '2015-01-02',
'task_instance_key_str': 'test-dag__task__20150101',
'test_mode': False,
'params': {}
}
self.assertDictContainsSubset(expected, context)

def test_template_context_with_static_params(self):
TI = models.TaskInstance
dag = models.DAG('test-dag', start_date=DEFAULT_DATE,
params={'foo': 'dag', 'bar': 'dag'})
task = operators.DummyOperator(task_id='task', owner='unittest', dag=dag,
params={'foo': 'task', 'boolean': True})
ti = TI(task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()

expected = {
'tomorrow_ds': '2015-01-02',
'task_instance_key_str': 'test-dag__task__20150101',
'test_mode': False,
'params': {'foo': 'task', 'bar': 'dag', 'boolean': True}
}
self.assertDictContainsSubset(expected, context)

def test_template_context_with_dynamic_params(self):
TI = models.TaskInstance
dag = models.DAG('test-dag', start_date=DEFAULT_DATE,
params={'foo': 'dag', 'bar': 'dag'})
task = operators.DummyOperator(task_id='task', owner='unittest', dag=dag,
params={'foo': '{{ ds }}'})
ti = TI(task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()

expected = {
'tomorrow_ds': '2015-01-02',
'task_instance_key_str': 'test-dag__task__20150101',
'test_mode': False,
'params': {'foo': '{{ ds }}', 'bar': 'dag'}
}
self.assertDictContainsSubset(expected, context)

def test_template_context_with_dynamic_params_and_render_params(self):
TI = models.TaskInstance
dag = models.DAG('test-dag', start_date=DEFAULT_DATE,
params={'foo': 'dag', 'bar': 'dag'})
task = operators.DummyOperator(task_id='task', owner='unittest', dag=dag,
render_params=True,
params={'foo': '{{ ds }}'})
ti = TI(task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()

expected = {
'tomorrow_ds': '2015-01-02',
'task_instance_key_str': 'test-dag__task__20150101',
'test_mode': False,
'params': {'foo': '2015-01-01', 'bar': 'dag'}
}
self.assertDictContainsSubset(expected, context)

def test_complex_template(self):
class OperatorSubclass(operators.BaseOperator):
template_fields = ['some_templated_field']
Expand Down

0 comments on commit d1387b3

Please sign in to comment.