From f0960ea6c7447eb4bfc3d9c37f9f796b46e23257 Mon Sep 17 00:00:00 2001 From: phi-friday Date: Thu, 13 Jun 2024 10:36:09 +0900 Subject: [PATCH 01/12] feat: add annotations in template --- airflow/utils/python_virtualenv_script.jinja2 | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/utils/python_virtualenv_script.jinja2 b/airflow/utils/python_virtualenv_script.jinja2 index 4199a47130fb..2015fb7aa7ed 100644 --- a/airflow/utils/python_virtualenv_script.jinja2 +++ b/airflow/utils/python_virtualenv_script.jinja2 @@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License. -#} - +{% if import_annotations | default(False) %} +from __future__ import annotations +{% endif %} import {{ pickling_library }} import sys From 397234df1fc6140237cd5f8efefa30b5b53983ec Mon Sep 17 00:00:00 2001 From: phi-friday Date: Thu, 13 Jun 2024 10:36:24 +0900 Subject: [PATCH 02/12] feat: apply check_import_future_annotations --- airflow/providers/docker/decorators/docker.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/airflow/providers/docker/decorators/docker.py b/airflow/providers/docker/decorators/docker.py index 9aafdd1d79bf..03933b5351bd 100644 --- a/airflow/providers/docker/decorators/docker.py +++ b/airflow/providers/docker/decorators/docker.py @@ -16,11 +16,14 @@ # under the License. from __future__ import annotations +import ast import base64 +import inspect import os import pickle +from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence import dill @@ -101,6 +104,7 @@ def execute(self, context: Context): if self.op_args or self.op_kwargs: self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file) py_source = self.get_python_source() + import_annotations = _check_source_import_future_annotations(self.python_callable) write_python_script( jinja_context={ "op_args": self.op_args, @@ -110,6 +114,7 @@ def execute(self, context: Context): "python_callable_source": py_source, "expect_airflow": self.expect_airflow, "string_args_global": False, + "import_annotations": import_annotations, }, filename=script_filename, ) @@ -153,3 +158,34 @@ def docker_task( decorated_operator_class=_DockerDecoratedOperator, **kwargs, ) + + +def _check_import_future_annotations(py_source: str) -> bool: + module = ast.parse(py_source) + for node in module.body: + if not isinstance(node, ast.ImportFrom): + continue + + if node.module != "__future__": + continue + + if any(name.name == "annotations" for name in node.names): + return True + + return False + + +def _check_source_import_future_annotations(obj: Any) -> bool: + file = inspect.getsourcefile(obj) + + if file is None: + return False + + path = Path(file) + if not path.exists(): + return False + + with path.open("r") as f: + py_source = f.read() + + return _check_import_future_annotations(py_source) From 8f7e8755991981cb28baddd724a51e372d1c762a Mon Sep 17 00:00:00 2001 From: phi-friday Date: Thu, 13 Jun 2024 10:37:11 +0900 Subject: [PATCH 03/12] test: add test test_import_annotations --- .../docker/decorators/_with_annotations.py | 20 +++++++++++++++++ .../docker/decorators/test_docker.py | 22 +++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 tests/providers/docker/decorators/_with_annotations.py diff --git a/tests/providers/docker/decorators/_with_annotations.py b/tests/providers/docker/decorators/_with_annotations.py new file mode 100644 index 000000000000..fad5df70bae5 --- /dev/null +++ b/tests/providers/docker/decorators/_with_annotations.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.decorators import task + +if TYPE_CHECKING: + from airflow.decorators.base import Task + +__all__ = [] + + +def create_task_factory(image: str, **kwargs: Any) -> Task[..., Any]: # type: ignore # noqa: F821 + kwargs.setdefault("multiple_outputs", False) + + @task.docker(image=image, auto_remove="force", **kwargs) + def func(value: dict[str, Any]) -> Any: # type: ignore # noqa: F821 + assert isinstance(value, dict) + + return func diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py index e4fbe15fc324..bfd8b57d1837 100644 --- a/tests/providers/docker/decorators/test_docker.py +++ b/tests/providers/docker/decorators/test_docker.py @@ -284,3 +284,25 @@ def f(): ret = f() assert ret.operator.docker_url == "unix://var/run/docker.sock" + + def test_import_annotations(self, dag_maker): + from typing import Any + + from airflow.models.dagrun import DagRun # noqa: TCH001 + from airflow.utils.state import DagRunState + + from ._with_annotations import create_task_factory # noqa: TID252 + + task_factory = create_task_factory("python:3.9-slim") + + with dag_maker(): + + @task.python(multiple_outputs=False) + def create_dummy_value() -> dict[str, Any]: + return {} + + value = create_dummy_value() + _ = task_factory(value) + + dagrun: DagRun = dag_maker.create_dagrun() + assert DagRunState(dagrun.state) == DagRunState.SUCCESS From 09be7e074ed718bfb803ef82489f37ab481092d4 Mon Sep 17 00:00:00 2001 From: phi-friday Date: Thu, 13 Jun 2024 17:55:17 +0900 Subject: [PATCH 04/12] fix: add annotations as default --- airflow/providers/docker/decorators/docker.py | 38 +------------------ airflow/utils/python_virtualenv_script.jinja2 | 3 +- 2 files changed, 2 insertions(+), 39 deletions(-) diff --git a/airflow/providers/docker/decorators/docker.py b/airflow/providers/docker/decorators/docker.py index 03933b5351bd..9aafdd1d79bf 100644 --- a/airflow/providers/docker/decorators/docker.py +++ b/airflow/providers/docker/decorators/docker.py @@ -16,14 +16,11 @@ # under the License. from __future__ import annotations -import ast import base64 -import inspect import os import pickle -from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Callable, Sequence +from typing import TYPE_CHECKING, Callable, Sequence import dill @@ -104,7 +101,6 @@ def execute(self, context: Context): if self.op_args or self.op_kwargs: self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file) py_source = self.get_python_source() - import_annotations = _check_source_import_future_annotations(self.python_callable) write_python_script( jinja_context={ "op_args": self.op_args, @@ -114,7 +110,6 @@ def execute(self, context: Context): "python_callable_source": py_source, "expect_airflow": self.expect_airflow, "string_args_global": False, - "import_annotations": import_annotations, }, filename=script_filename, ) @@ -158,34 +153,3 @@ def docker_task( decorated_operator_class=_DockerDecoratedOperator, **kwargs, ) - - -def _check_import_future_annotations(py_source: str) -> bool: - module = ast.parse(py_source) - for node in module.body: - if not isinstance(node, ast.ImportFrom): - continue - - if node.module != "__future__": - continue - - if any(name.name == "annotations" for name in node.names): - return True - - return False - - -def _check_source_import_future_annotations(obj: Any) -> bool: - file = inspect.getsourcefile(obj) - - if file is None: - return False - - path = Path(file) - if not path.exists(): - return False - - with path.open("r") as f: - py_source = f.read() - - return _check_import_future_annotations(py_source) diff --git a/airflow/utils/python_virtualenv_script.jinja2 b/airflow/utils/python_virtualenv_script.jinja2 index 2015fb7aa7ed..2ff417985e88 100644 --- a/airflow/utils/python_virtualenv_script.jinja2 +++ b/airflow/utils/python_virtualenv_script.jinja2 @@ -16,9 +16,8 @@ specific language governing permissions and limitations under the License. -#} -{% if import_annotations | default(False) %} from __future__ import annotations -{% endif %} + import {{ pickling_library }} import sys From e3ac8a555d7c7a6a6f96fea97b200c7b91a171e5 Mon Sep 17 00:00:00 2001 From: phi-friday Date: Thu, 13 Jun 2024 18:04:49 +0900 Subject: [PATCH 05/12] fix: revert test --- .../docker/decorators/_with_annotations.py | 20 ----------------- .../docker/decorators/test_docker.py | 22 ------------------- 2 files changed, 42 deletions(-) delete mode 100644 tests/providers/docker/decorators/_with_annotations.py diff --git a/tests/providers/docker/decorators/_with_annotations.py b/tests/providers/docker/decorators/_with_annotations.py deleted file mode 100644 index fad5df70bae5..000000000000 --- a/tests/providers/docker/decorators/_with_annotations.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from airflow.decorators import task - -if TYPE_CHECKING: - from airflow.decorators.base import Task - -__all__ = [] - - -def create_task_factory(image: str, **kwargs: Any) -> Task[..., Any]: # type: ignore # noqa: F821 - kwargs.setdefault("multiple_outputs", False) - - @task.docker(image=image, auto_remove="force", **kwargs) - def func(value: dict[str, Any]) -> Any: # type: ignore # noqa: F821 - assert isinstance(value, dict) - - return func diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py index bfd8b57d1837..e4fbe15fc324 100644 --- a/tests/providers/docker/decorators/test_docker.py +++ b/tests/providers/docker/decorators/test_docker.py @@ -284,25 +284,3 @@ def f(): ret = f() assert ret.operator.docker_url == "unix://var/run/docker.sock" - - def test_import_annotations(self, dag_maker): - from typing import Any - - from airflow.models.dagrun import DagRun # noqa: TCH001 - from airflow.utils.state import DagRunState - - from ._with_annotations import create_task_factory # noqa: TID252 - - task_factory = create_task_factory("python:3.9-slim") - - with dag_maker(): - - @task.python(multiple_outputs=False) - def create_dummy_value() -> dict[str, Any]: - return {} - - value = create_dummy_value() - _ = task_factory(value) - - dagrun: DagRun = dag_maker.create_dagrun() - assert DagRunState(dagrun.state) == DagRunState.SUCCESS From 977564457d0fd6ded23f0b73cbc400069e4b9548 Mon Sep 17 00:00:00 2001 From: phi Date: Thu, 13 Jun 2024 21:41:37 +0900 Subject: [PATCH 06/12] test: test_invalid_annotation --- .../docker/decorators/test_docker.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py index e4fbe15fc324..b3690421113b 100644 --- a/tests/providers/docker/decorators/test_docker.py +++ b/tests/providers/docker/decorators/test_docker.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +from typing import Any + import pytest from airflow.decorators import setup, task, teardown @@ -284,3 +286,40 @@ def f(): ret = f() assert ret.operator.docker_url == "unix://var/run/docker.sock" + + def test_invalid_annotation(self, dag_maker): + @task.python(multiple_outputs=True, do_xcom_push=True) + def create_dummy() -> dict[str, Any]: + import uuid + + return {"unique_id": uuid.uuid4().hex} + + # Functions that throw an error + # if `from __future__ import annotations` is missing + @task.docker(image="python:3.9-slim", auto_remove="force", multiple_outputs=False, do_xcom_push=True) + def in_docker(value: dict[str, Invalid]) -> Invalid: # type: ignore[name-defined] # noqa: F821 + assert isinstance(value, dict) + return value["unique_id"] + + with dag_maker(): + value = create_dummy() + ret = in_docker(value) + + dr = dag_maker.create_dagrun() + value.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) + ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) + tis = dr.get_task_instances() + + assert len(tis) == 2 + value_ti = next(x for x in tis if x.task_id == value.operator.task_id) + ret_ti = next(x for x in tis if x.task_id == ret.operator.task_id) + assert value_ti.state == TaskInstanceState.SUCCESS + assert ret_ti.state == TaskInstanceState.SUCCESS + + ti = tis[0] + value_xcom = ti.xcom_pull(task_ids=value_ti.task_id, key="return_value") + ret_xcom = ti.xcom_pull(task_ids=ret_ti.task_id, key="return_value") + assert isinstance(value_xcom, dict) + assert "unique_id" in value_xcom + assert isinstance(ret_xcom, str) + assert value_xcom["unique_id"] == ret_xcom From 84b46c2c50a91aefead1a1ffc4bc31a29dfa3b81 Mon Sep 17 00:00:00 2001 From: phi Date: Fri, 14 Jun 2024 09:01:02 +0900 Subject: [PATCH 07/12] fix: simpler test --- .../docker/decorators/test_docker.py | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py index b3690421113b..b505abc2b258 100644 --- a/tests/providers/docker/decorators/test_docker.py +++ b/tests/providers/docker/decorators/test_docker.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -from typing import Any - import pytest from airflow.decorators import setup, task, teardown @@ -288,11 +286,10 @@ def f(): assert ret.operator.docker_url == "unix://var/run/docker.sock" def test_invalid_annotation(self, dag_maker): - @task.python(multiple_outputs=True, do_xcom_push=True) - def create_dummy() -> dict[str, Any]: - import uuid + import uuid - return {"unique_id": uuid.uuid4().hex} + unique_id = uuid.uuid4().hex + value = {"unique_id": unique_id} # Functions that throw an error # if `from __future__ import annotations` is missing @@ -302,24 +299,14 @@ def in_docker(value: dict[str, Invalid]) -> Invalid: # type: ignore[name-define return value["unique_id"] with dag_maker(): - value = create_dummy() ret = in_docker(value) dr = dag_maker.create_dagrun() - value.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) - tis = dr.get_task_instances() - - assert len(tis) == 2 - value_ti = next(x for x in tis if x.task_id == value.operator.task_id) - ret_ti = next(x for x in tis if x.task_id == ret.operator.task_id) - assert value_ti.state == TaskInstanceState.SUCCESS - assert ret_ti.state == TaskInstanceState.SUCCESS - - ti = tis[0] - value_xcom = ti.xcom_pull(task_ids=value_ti.task_id, key="return_value") - ret_xcom = ti.xcom_pull(task_ids=ret_ti.task_id, key="return_value") - assert isinstance(value_xcom, dict) - assert "unique_id" in value_xcom - assert isinstance(ret_xcom, str) - assert value_xcom["unique_id"] == ret_xcom + ti = dr.get_task_instances()[0] + + assert ti.state == TaskInstanceState.SUCCESS + + xcom = ti.xcom_pull(task_ids=ti.task_id, key="return_value") + assert isinstance(xcom, str) + assert xcom == unique_id From ca3f90baaabb935d89f1e5beced4ffe4166c542f Mon Sep 17 00:00:00 2001 From: phi Date: Fri, 14 Jun 2024 09:20:02 +0900 Subject: [PATCH 08/12] fix: declare _Invalid --- tests/providers/docker/decorators/test_docker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py index b505abc2b258..152ee7463fbd 100644 --- a/tests/providers/docker/decorators/test_docker.py +++ b/tests/providers/docker/decorators/test_docker.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +from typing import Any + import pytest from airflow.decorators import setup, task, teardown @@ -29,6 +31,7 @@ DEFAULT_DATE = timezone.datetime(2021, 9, 1) +_Invalid = Any class TestDockerDecorator: @@ -294,7 +297,7 @@ def test_invalid_annotation(self, dag_maker): # Functions that throw an error # if `from __future__ import annotations` is missing @task.docker(image="python:3.9-slim", auto_remove="force", multiple_outputs=False, do_xcom_push=True) - def in_docker(value: dict[str, Invalid]) -> Invalid: # type: ignore[name-defined] # noqa: F821 + def in_docker(value: dict[str, _Invalid]) -> _Invalid: assert isinstance(value, dict) return value["unique_id"] From da9afe2d1c0b1e7a4c084f9e1b1be2b69cd07496 Mon Sep 17 00:00:00 2001 From: phi Date: Fri, 14 Jun 2024 09:20:09 +0900 Subject: [PATCH 09/12] fix: allow test only airflow 2.10+ --- tests/providers/docker/decorators/test_docker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py index 152ee7463fbd..a7649ac4a6ab 100644 --- a/tests/providers/docker/decorators/test_docker.py +++ b/tests/providers/docker/decorators/test_docker.py @@ -291,6 +291,11 @@ def f(): def test_invalid_annotation(self, dag_maker): import uuid + from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS + + if not AIRFLOW_V_2_10_PLUS: + pytest.skip("This test is only for Airflow 2.10+") + unique_id = uuid.uuid4().hex value = {"unique_id": unique_id} From 1f3b622a71d577bd0df635d36cd5823b3fb76a20 Mon Sep 17 00:00:00 2001 From: phi Date: Fri, 14 Jun 2024 09:38:55 +0900 Subject: [PATCH 10/12] fix: test_docker -> test_python_virtualenv --- .../docker/decorators/test_docker.py | 34 ------------------- tests/utils/test_python_virtualenv.py | 31 +++++++++++++++++ 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py index a7649ac4a6ab..e4fbe15fc324 100644 --- a/tests/providers/docker/decorators/test_docker.py +++ b/tests/providers/docker/decorators/test_docker.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -from typing import Any - import pytest from airflow.decorators import setup, task, teardown @@ -31,7 +29,6 @@ DEFAULT_DATE = timezone.datetime(2021, 9, 1) -_Invalid = Any class TestDockerDecorator: @@ -287,34 +284,3 @@ def f(): ret = f() assert ret.operator.docker_url == "unix://var/run/docker.sock" - - def test_invalid_annotation(self, dag_maker): - import uuid - - from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS - - if not AIRFLOW_V_2_10_PLUS: - pytest.skip("This test is only for Airflow 2.10+") - - unique_id = uuid.uuid4().hex - value = {"unique_id": unique_id} - - # Functions that throw an error - # if `from __future__ import annotations` is missing - @task.docker(image="python:3.9-slim", auto_remove="force", multiple_outputs=False, do_xcom_push=True) - def in_docker(value: dict[str, _Invalid]) -> _Invalid: - assert isinstance(value, dict) - return value["unique_id"] - - with dag_maker(): - ret = in_docker(value) - - dr = dag_maker.create_dagrun() - ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) - ti = dr.get_task_instances()[0] - - assert ti.state == TaskInstanceState.SUCCESS - - xcom = ti.xcom_pull(task_ids=ti.task_id, key="return_value") - assert isinstance(xcom, str) - assert xcom == unique_id diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py index 38cda4854baf..7e2a46b6ce2f 100644 --- a/tests/utils/test_python_virtualenv.py +++ b/tests/utils/test_python_virtualenv.py @@ -19,12 +19,17 @@ import sys from pathlib import Path +from typing import Any from unittest import mock import pytest +from airflow.decorators import task from airflow.utils.decorators import remove_task_decorator from airflow.utils.python_virtualenv import _generate_pip_conf, prepare_virtualenv +from airflow.utils.state import TaskInstanceState + +_Invalid = Any class TestPrepareVirtualenv: @@ -138,3 +143,29 @@ def test_remove_decorator_nested(self): py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + + def test_invalid_annotation(self, dag_maker): + import uuid + + unique_id = uuid.uuid4().hex + value = {"unique_id": unique_id} + + # Functions that throw an error + # if `from __future__ import annotations` is missing + @task.docker(image="python:3.9-slim", auto_remove="force", multiple_outputs=False, do_xcom_push=True) + def in_docker(value: dict[str, _Invalid]) -> _Invalid: + assert isinstance(value, dict) + return value["unique_id"] + + with dag_maker(): + ret = in_docker(value) + + dr = dag_maker.create_dagrun() + ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) + ti = dr.get_task_instances()[0] + + assert ti.state == TaskInstanceState.SUCCESS + + xcom = ti.xcom_pull(task_ids=ti.task_id, key="return_value") + assert isinstance(xcom, str) + assert xcom == unique_id From 32876b998980a6bc41b680538b79d79d47ac0845 Mon Sep 17 00:00:00 2001 From: phi Date: Fri, 14 Jun 2024 09:53:48 +0900 Subject: [PATCH 11/12] fix: task.docker -> task.virtualenv --- tests/utils/test_python_virtualenv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py index 7e2a46b6ce2f..140decf6235f 100644 --- a/tests/utils/test_python_virtualenv.py +++ b/tests/utils/test_python_virtualenv.py @@ -152,13 +152,13 @@ def test_invalid_annotation(self, dag_maker): # Functions that throw an error # if `from __future__ import annotations` is missing - @task.docker(image="python:3.9-slim", auto_remove="force", multiple_outputs=False, do_xcom_push=True) - def in_docker(value: dict[str, _Invalid]) -> _Invalid: + @task.virtualenv(multiple_outputs=False, do_xcom_push=True) + def in_venv(value: dict[str, _Invalid]) -> _Invalid: assert isinstance(value, dict) return value["unique_id"] with dag_maker(): - ret = in_docker(value) + ret = in_venv(value) dr = dag_maker.create_dagrun() ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) From b82e21029868c2ceb60f091f1ddaaf7aae1a8992 Mon Sep 17 00:00:00 2001 From: phi Date: Fri, 14 Jun 2024 10:10:09 +0900 Subject: [PATCH 12/12] fix: utils.test_python_virtualenv -> decorators.test_python_virtualenv --- tests/decorators/test_python_virtualenv.py | 30 +++++++++++++++++++++ tests/utils/test_python_virtualenv.py | 31 ---------------------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py index b91bcaae36be..0f7ab6918dd2 100644 --- a/tests/decorators/test_python_virtualenv.py +++ b/tests/decorators/test_python_virtualenv.py @@ -21,12 +21,14 @@ import sys from importlib.util import find_spec from subprocess import CalledProcessError +from typing import Any import pytest from airflow.decorators import setup, task, teardown from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils import timezone +from airflow.utils.state import TaskInstanceState pytestmark = pytest.mark.db_test @@ -37,6 +39,8 @@ CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed") +_Invalid = Any + class TestPythonVirtualenvDecorator: @CLOUDPICKLE_MARKER @@ -350,3 +354,29 @@ def f(): assert teardown_task.is_teardown assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + def test_invalid_annotation(self, dag_maker): + import uuid + + unique_id = uuid.uuid4().hex + value = {"unique_id": unique_id} + + # Functions that throw an error + # if `from __future__ import annotations` is missing + @task.virtualenv(multiple_outputs=False, do_xcom_push=True) + def in_venv(value: dict[str, _Invalid]) -> _Invalid: + assert isinstance(value, dict) + return value["unique_id"] + + with dag_maker(): + ret = in_venv(value) + + dr = dag_maker.create_dagrun() + ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) + ti = dr.get_task_instances()[0] + + assert ti.state == TaskInstanceState.SUCCESS + + xcom = ti.xcom_pull(task_ids=ti.task_id, key="return_value") + assert isinstance(xcom, str) + assert xcom == unique_id diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py index 140decf6235f..38cda4854baf 100644 --- a/tests/utils/test_python_virtualenv.py +++ b/tests/utils/test_python_virtualenv.py @@ -19,17 +19,12 @@ import sys from pathlib import Path -from typing import Any from unittest import mock import pytest -from airflow.decorators import task from airflow.utils.decorators import remove_task_decorator from airflow.utils.python_virtualenv import _generate_pip_conf, prepare_virtualenv -from airflow.utils.state import TaskInstanceState - -_Invalid = Any class TestPrepareVirtualenv: @@ -143,29 +138,3 @@ def test_remove_decorator_nested(self): py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") assert res == "@foo\n@bar\ndef f():\nimport funcsigs" - - def test_invalid_annotation(self, dag_maker): - import uuid - - unique_id = uuid.uuid4().hex - value = {"unique_id": unique_id} - - # Functions that throw an error - # if `from __future__ import annotations` is missing - @task.virtualenv(multiple_outputs=False, do_xcom_push=True) - def in_venv(value: dict[str, _Invalid]) -> _Invalid: - assert isinstance(value, dict) - return value["unique_id"] - - with dag_maker(): - ret = in_venv(value) - - dr = dag_maker.create_dagrun() - ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) - ti = dr.get_task_instances()[0] - - assert ti.state == TaskInstanceState.SUCCESS - - xcom = ti.xcom_pull(task_ids=ti.task_id, key="return_value") - assert isinstance(xcom, str) - assert xcom == unique_id