Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix import future annotations in venv jinja template #40208

Merged
merged 12 commits into from
Jun 14, 2024
1 change: 1 addition & 0 deletions airflow/utils/python_virtualenv_script.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
specific language governing permissions and limitations
under the License.
-#}
from __future__ import annotations
phi-friday marked this conversation as resolved.
Show resolved Hide resolved

import {{ pickling_library }}
import sys
Expand Down
30 changes: 30 additions & 0 deletions tests/decorators/test_python_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import sys
from importlib.util import find_spec
from subprocess import CalledProcessError
from typing import Any

import pytest

from airflow.decorators import setup, task, teardown
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.utils import timezone
from airflow.utils.state import TaskInstanceState

pytestmark = pytest.mark.db_test

Expand All @@ -37,6 +39,8 @@
CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None
CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed")

_Invalid = Any


class TestPythonVirtualenvDecorator:
@CLOUDPICKLE_MARKER
Expand Down Expand Up @@ -350,3 +354,29 @@ def f():
assert teardown_task.is_teardown
assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

def test_invalid_annotation(self, dag_maker):
import uuid

unique_id = uuid.uuid4().hex
value = {"unique_id": unique_id}

# Functions that throw an error
# if `from __future__ import annotations` is missing
@task.virtualenv(multiple_outputs=False, do_xcom_push=True)
def in_venv(value: dict[str, _Invalid]) -> _Invalid:
assert isinstance(value, dict)
return value["unique_id"]

with dag_maker():
ret = in_venv(value)

dr = dag_maker.create_dagrun()
ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
ti = dr.get_task_instances()[0]

assert ti.state == TaskInstanceState.SUCCESS

xcom = ti.xcom_pull(task_ids=ti.task_id, key="return_value")
assert isinstance(xcom, str)
assert xcom == unique_id