Skip to content

Commit

Permalink
Send context using in venv operator (apache#41039)
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday authored Aug 8, 2024
1 parent 08589a7 commit da55393
Show file tree
Hide file tree
Showing 8 changed files with 462 additions and 1 deletion.
6 changes: 6 additions & 0 deletions airflow/decorators/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class TaskDecoratorCollection:
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
) -> TaskDecorator:
"""Create a decorator to convert the decorated callable to a virtual environment task.
Expand Down Expand Up @@ -176,6 +177,7 @@ class TaskDecoratorCollection:
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""
@overload
def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
Expand All @@ -192,6 +194,7 @@ class TaskDecoratorCollection:
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
) -> TaskDecorator:
"""Create a decorator to convert the decorated callable to a virtual environment task.
Expand Down Expand Up @@ -225,6 +228,7 @@ class TaskDecoratorCollection:
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""
@overload
def branch( # type: ignore[misc]
Expand Down Expand Up @@ -258,6 +262,7 @@ class TaskDecoratorCollection:
venv_cache_path: None | str = None,
show_return_value_in_logs: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
) -> TaskDecorator:
"""Create a decorator to wrap the decorated callable into a BranchPythonVirtualenvOperator.
Expand Down Expand Up @@ -299,6 +304,7 @@ class TaskDecoratorCollection:
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""
@overload
def branch_virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
Expand Down
92 changes: 92 additions & 0 deletions airflow/example_dags/example_python_context_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example DAG demonstrating how to read the current Airflow context with `get_current_context()`
from regular, virtualenv and external-Python tasks defined with the TaskFlow API.
"""

from __future__ import annotations

import sys

import pendulum

from airflow.decorators import dag, task

SOME_EXTERNAL_PYTHON = sys.executable


@dag(
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    tags=["example"],
)
def example_python_context_decorator():
    # Three tasks that each dump the Airflow context: one in the worker's own
    # interpreter, one in a freshly built virtualenv, and one in an external
    # Python interpreter. The two isolated variants opt in with
    # ``use_airflow_context=True`` so that ``get_current_context()`` is
    # provided to the callable.

    # [START get_current_context]
    @task(task_id="print_the_context")
    def print_context() -> str:
        """Print the Airflow context."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    plain_task = print_context()
    # [END get_current_context]

    # [START get_current_context_venv]
    @task.virtualenv(
        task_id="print_the_context_venv",
        use_airflow_context=True,
    )
    def print_context_venv() -> str:
        """Print the Airflow context in venv."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    venv_task = print_context_venv()
    # [END get_current_context_venv]

    # [START get_current_context_external]
    @task.external_python(
        task_id="print_the_context_external",
        python=SOME_EXTERNAL_PYTHON,
        use_airflow_context=True,
    )
    def print_context_external() -> str:
        """Print the Airflow context in external python."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    external_task = print_context_external()
    # [END get_current_context_external]

    # The in-process task runs first, then both isolated variants fan out from it.
    _ = plain_task >> [venv_task, external_task]


example_python_context_decorator()
91 changes: 91 additions & 0 deletions airflow/example_dags/example_python_context_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example DAG demonstrating how to read the current Airflow context with `get_current_context()`
from the classic `PythonOperator`, `PythonVirtualenvOperator` and `ExternalPythonOperator`.
"""

from __future__ import annotations

import sys

import pendulum

from airflow import DAG
from airflow.operators.python import ExternalPythonOperator, PythonOperator, PythonVirtualenvOperator

SOME_EXTERNAL_PYTHON = sys.executable

with DAG(
    dag_id="example_python_context_operator",
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    tags=["example"],
) as dag:
    # Each callable below dumps the context returned by ``get_current_context()``.
    # The virtualenv and external-Python operators opt in with
    # ``use_airflow_context=True`` so the context is provided to the callable.

    # [START get_current_context]
    def print_context() -> str:
        """Print the Airflow context."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    print_the_context = PythonOperator(
        task_id="print_the_context",
        python_callable=print_context,
    )
    # [END get_current_context]

    # [START get_current_context_venv]
    def print_context_venv() -> str:
        """Print the Airflow context in venv."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    print_the_context_venv = PythonVirtualenvOperator(
        task_id="print_the_context_venv",
        python_callable=print_context_venv,
        use_airflow_context=True,
    )
    # [END get_current_context_venv]

    # [START get_current_context_external]
    def print_context_external() -> str:
        """Print the Airflow context in external python."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    print_the_context_external = ExternalPythonOperator(
        task_id="print_the_context_external",
        python=SOME_EXTERNAL_PYTHON,
        python_callable=print_context_external,
        use_airflow_context=True,
    )
    # [END get_current_context_external]

    # The in-process task runs first, then both isolated variants fan out from it.
    _ = print_the_context >> [print_the_context_venv, print_the_context_external]
36 changes: 36 additions & 0 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@
from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
from airflow.utils.session import create_session

log = logging.getLogger(__name__)

if TYPE_CHECKING:
from pendulum.datetime import DateTime

from airflow.serialization.enums import Encoding
from airflow.utils.context import Context


Expand Down Expand Up @@ -442,6 +444,7 @@ def __init__(
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
):
if (
Expand Down Expand Up @@ -494,6 +497,7 @@ def __init__(
)
self.env_vars = env_vars
self.inherit_env = inherit_env
self.use_airflow_context = use_airflow_context

@abstractmethod
def _iter_serializable_context_keys(self):
Expand Down Expand Up @@ -540,6 +544,7 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
string_args_path = tmp_dir / "string_args.txt"
script_path = tmp_dir / "script.py"
termination_log_path = tmp_dir / "termination.log"
airflow_context_path = tmp_dir / "airflow_context.json"

self._write_args(input_path)
self._write_string_args(string_args_path)
Expand All @@ -551,6 +556,7 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
"pickling_library": self.serializer,
"python_callable": self.python_callable.__name__,
"python_callable_source": self.get_python_source(),
"use_airflow_context": self.use_airflow_context,
}

if inspect.getfile(self.python_callable) == self.dag.fileloc:
Expand All @@ -561,6 +567,23 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
filename=os.fspath(script_path),
render_template_as_native_obj=self.dag.render_template_as_native_obj,
)
if self.use_airflow_context:
from airflow.serialization.serialized_objects import BaseSerialization

context = get_current_context()
# TODO: `TaskInstance` will also soon be serialized as expected.
# see more:
# https://github.com/apache/airflow/issues/40974
# https://github.com/apache/airflow/pull/41067
with create_session() as session:
# FIXME: DetachedInstanceError
dag_run, task_instance = context["dag_run"], context["task_instance"]
session.add_all([dag_run, task_instance])
serializable_context: dict[Encoding, Any] = BaseSerialization.serialize(
context, use_pydantic_models=True
)
with airflow_context_path.open("w+") as file:
json.dump(serializable_context, file)

env_vars = dict(os.environ) if self.inherit_env else {}
if self.env_vars:
Expand All @@ -575,6 +598,7 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
os.fspath(output_path),
os.fspath(string_args_path),
os.fspath(termination_log_path),
os.fspath(airflow_context_path),
],
env=env_vars,
)
Expand Down Expand Up @@ -666,6 +690,7 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""

template_fields: Sequence[str] = tuple(
Expand Down Expand Up @@ -694,6 +719,7 @@ def __init__(
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
):
if (
Expand All @@ -715,6 +741,9 @@ def __init__(
)
if not is_venv_installed():
raise AirflowException("PythonVirtualenvOperator requires virtualenv, please install it.")
if use_airflow_context and (not expect_airflow and not system_site_packages):
error_msg = "use_airflow_context is set to True, but expect_airflow and system_site_packages are set to False."
raise AirflowException(error_msg)
if not requirements:
self.requirements: list[str] = []
elif isinstance(requirements, str):
Expand Down Expand Up @@ -744,6 +773,7 @@ def __init__(
env_vars=env_vars,
inherit_env=inherit_env,
use_dill=use_dill,
use_airflow_context=use_airflow_context,
**kwargs,
)

Expand Down Expand Up @@ -962,6 +992,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""

template_fields: Sequence[str] = tuple({"python"}.union(PythonOperator.template_fields))
Expand All @@ -983,10 +1014,14 @@ def __init__(
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
):
if not python:
raise ValueError("Python Path must be defined in ExternalPythonOperator")
if use_airflow_context and not expect_airflow:
error_msg = "use_airflow_context is set to True, but expect_airflow is set to False."
raise AirflowException(error_msg)
self.python = python
self.expect_pendulum = expect_pendulum
super().__init__(
Expand All @@ -1002,6 +1037,7 @@ def __init__(
env_vars=env_vars,
inherit_env=inherit_env,
use_dill=use_dill,
use_airflow_context=use_airflow_context,
**kwargs,
)

Expand Down
Loading

0 comments on commit da55393

Please sign in to comment.