Skip to content

Commit

Permalink
feature: callable for template_fields (#37028)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: raphaelauv <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
  • Loading branch information
3 people authored Jun 11, 2024
1 parent 52f858c commit 6c7aa4b
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 46 deletions.
15 changes: 9 additions & 6 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,12 +731,15 @@ def _do_render_template_fields(
pass

try:
rendered_content = self.render_template(
value,
context,
jinja_env,
seen_oids,
)
if callable(value):
rendered_content = value(context=context, jinja_env=jinja_env)
else:
rendered_content = self.render_template(
value,
context,
jinja_env,
seen_oids,
)
except Exception:
value_masked = redact(name=attr_name, value=value)
self.log.exception(
Expand Down
105 changes: 65 additions & 40 deletions docs/apache-airflow/core-concepts/operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,33 @@ For example, say you want to pass the start of the data interval as an environme
Here, ``{{ ds }}`` is a templated variable, and because the ``env`` parameter of the ``BashOperator`` is templated with Jinja, the data interval's start date will be available as an environment variable named ``DATA_INTERVAL_START`` in your Bash script.

You can use Jinja templating with every parameter that is marked as "templated" in the documentation. Template substitution occurs just before the ``pre_execute`` function of your operator is called.
You can also pass in a callable instead when Python is more readable than a Jinja template. The callable must accept two named arguments ``context`` and ``jinja_env``:

You can also use Jinja templating with nested fields, as long as these nested fields are marked as templated in the structure they belong to: fields registered in ``template_fields`` property will be submitted to template substitution, like the ``path`` field in the example below:
.. code-block:: python
def build_complex_command(context, jinja_env):
with open("file.csv") as f:
return do_complex_things(f)
t = BashOperator(
task_id="complex_templated_echo",
bash_command=build_complex_command,
dag=dag,
)
Since each template field is only rendered once, the callable's return value will not go through rendering again. Therefore, the callable must manually render any templates. This can be done by calling ``render_template()`` on the current task like this:

.. code-block:: python
def build_complex_command(context, jinja_env):
with open("file.csv") as f:
data = do_complex_things(f)
return context["task"].render_template(data, context, jinja_env)
You can use templating with every parameter that is marked as "templated" in the documentation. Template substitution occurs just before the ``pre_execute`` function of your operator is called.

You can also use templating with nested fields, as long as these nested fields are marked as templated in the structure they belong to: fields registered in ``template_fields`` property will be submitted to template substitution, like the ``path`` field in the example below:

.. code-block:: python
Expand Down Expand Up @@ -211,64 +235,65 @@ Alternatively, if you want to prevent Airflow from treating a value as a referen
Rendering Fields as Native Python Objects
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, all the ``template_fields`` are rendered as strings.

Example, let's say ``extract`` task pushes a dictionary
(Example: ``{"1001": 301.27, "1002": 433.21, "1003": 502.22}``) to :ref:`XCom <concepts:xcom>` table.
Now, when the following task is run, ``order_data`` argument is passed a string, example:
``'{"1001": 301.27, "1002": 433.21, "1003": 502.22}'``.
By default, all Jinja templates in ``template_fields`` are rendered as strings. This however is not always desired. For example, let's say an ``extract`` task pushes a dictionary ``{"1001": 301.27, "1002": 433.21, "1003": 502.22}`` to :ref:`XCom <concepts:xcom>`:

.. code-block:: python
transform = PythonOperator(
task_id="transform",
op_kwargs={"order_data": "{{ti.xcom_pull('extract')}}"},
python_callable=transform,
)
If you instead want the rendered template field to return a Native Python object (``dict`` in our example),
you can pass ``render_template_as_native_obj=True`` to the DAG as follows:

.. code-block:: python
dag = DAG(
dag_id="example_template_as_python_object",
schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
render_template_as_native_obj=True,
)
@task(task_id="extract")
def extract():
data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
return json.loads(data_string)
If a task depends on ``extract``, ``order_data`` argument is passed a string ``"{'1001': 301.27, '1002': 433.21, '1003': 502.22}"``:

.. code-block:: python
def transform(order_data):
print(type(order_data))
total_order_value = 0
for value in order_data.values():
total_order_value += value
total_order_value = sum(order_data.values()) # Fails because order_data is a str :(
return {"total_order_value": total_order_value}
extract_task = extract()
transform = PythonOperator(
task_id="transform",
op_kwargs={"order_data": "{{ ti.xcom_pull('extract') }}"},
python_callable=transform,
)
extract() >> transform
There are two solutions if we want to get the actual dict instead. The first is to use a callable:

transform_task = PythonOperator(
.. code-block:: python
def render_transform_op_kwargs(context, jinja_env):
order_data = context["ti"].xcom_pull("extract")
return {"order_data": order_data}
transform = PythonOperator(
task_id="transform",
op_kwargs={"order_data": "{{ti.xcom_pull('extract')}}"},
op_kwargs=render_transform_op_kwargs,
python_callable=transform,
)
extract_task >> transform_task
Alternatively, Jinja can also be instructed to render a native Python object. This is done by passing ``render_template_as_native_obj=True`` to the DAG. This makes Airflow use `NativeEnvironment <https://jinja.palletsprojects.com/en/2.11.x/nativetypes/>`_ instead of the default ``SandboxedEnvironment``:

.. code-block:: python
In this case, ``order_data`` argument is passed: ``{"1001": 301.27, "1002": 433.21, "1003": 502.22}``.
with DAG(
dag_id="example_template_as_python_object",
schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
render_template_as_native_obj=True,
):
transform = PythonOperator(
task_id="transform",
op_kwargs={"order_data": "{{ ti.xcom_pull('extract') }}"},
python_callable=transform,
)
Airflow uses Jinja's `NativeEnvironment <https://jinja.palletsprojects.com/en/2.11.x/nativetypes/>`_
when ``render_template_as_native_obj`` is set to ``True``.
With ``NativeEnvironment``, rendering a template produces a native Python type.
.. _concepts:reserved-keywords:

Expand Down
26 changes: 26 additions & 0 deletions tests/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,32 @@ def test_render_template_fields(self):
assert task.arg1 == "footemplated"
assert task.arg2 == "bartemplated"

@pytest.mark.db_test
def test_render_template_fields_func_using_context(self):
"""Verify if operator attributes are correctly templated."""

def fn_to_template(context, jinja_env):
tmp = context["task"].render_template("{{ bar }}", context, jinja_env)
return "foo_" + tmp

task = MockOperator(task_id="op1", arg2=fn_to_template)

# Trigger templating and verify if attributes are templated correctly
task.render_template_fields(context={"bar": "bartemplated", "task": task})
assert task.arg2 == "foo_bartemplated"

@pytest.mark.db_test
def test_render_template_fields_simple_func(self):
"""Verify if operator attributes are correctly templated."""

def fn_to_template(**kwargs):
a = "foo_" + ("bar" * 3)
return a

task = MockOperator(task_id="op1", arg2=fn_to_template)
task.render_template_fields({})
assert task.arg2 == "foo_barbarbar"

@pytest.mark.parametrize(("content",), [(object(),), (uuid.uuid4(),)])
def test_render_template_fields_no_change(self, content):
"""Tests if non-templatable types remain unchanged."""
Expand Down
15 changes: 15 additions & 0 deletions tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,21 @@ def test_python_callable_keyword_arguments_are_templatized(self):
assert rendered_op_kwargs["a_date"] == date(2019, 1, 1)
assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}."

def test_python_callable_keyword_arguments_callable_not_templatized(self):
"""Test PythonOperator op_kwargs are not templatized if it's a callable"""

def a_fn():
return 4

task = self.render_templates(
lambda: 0,
op_kwargs={
"a_callable": a_fn,
},
)
rendered_op_kwargs = task.op_kwargs
assert rendered_op_kwargs["a_callable"] == a_fn

def test_python_operator_shallow_copy_attr(self):
def not_callable(x):
raise RuntimeError("Should not be triggered")
Expand Down

0 comments on commit 6c7aa4b

Please sign in to comment.