diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 7e420269d3..e02034a32e 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -4,6 +4,11 @@ from functools import update_wrapper from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload +try: + from typing import ParamSpec +except ImportError: + from typing_extensions import ParamSpec # type: ignore + from flytekit.core import launch_plan as _annotated_launchplan from flytekit.core import workflow as _annotated_workflow from flytekit.core.base_task import TaskMetadata, TaskResolverMixin @@ -80,6 +85,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction return PythonFunctionTask +P = ParamSpec("P") T = TypeVar("T") FuncOut = TypeVar("FuncOut") @@ -124,7 +130,7 @@ def task( @overload def task( - _task_function: Callable[..., FuncOut], + _task_function: Callable[P, FuncOut], task_config: Optional[T] = ..., cache: bool = ..., cache_serialize: bool = ..., @@ -157,11 +163,11 @@ def task( pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., accelerator: Optional[BaseAccelerator] = ..., -) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]: ... +) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ... def task( - _task_function: Optional[Callable[..., FuncOut]] = None, + _task_function: Optional[Callable[P, FuncOut]] = None, task_config: Optional[T] = None, cache: bool = False, cache_serialize: bool = False, @@ -201,9 +207,9 @@ def task( pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, ) -> Union[ - Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], + Callable[P, FuncOut], + Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], - Callable[..., FuncOut], ]: """ This is the core decorator to use for any task type in flytekit. @@ -324,7 +330,7 @@ def launch_dynamically(): :param accelerator: The accelerator to use for this task. """ - def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: + def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: _metadata = TaskMetadata( cache=cache, cache_serialize=cache_serialize, diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 58f8157983..b8c0703f04 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -8,6 +8,11 @@ from functools import update_wrapper from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload +try: + from typing import ParamSpec +except ImportError: + from typing_extensions import ParamSpec # type: ignore + from flytekit.core import constants as _common_constants from flytekit.core import launch_plan as _annotated_launch_plan from flytekit.core.base_task import PythonTask, Task @@ -58,6 +63,7 @@ flyte_entity=None, ) +P = ParamSpec("P") T = typing.TypeVar("T") FuncOut = typing.TypeVar("FuncOut") @@ -809,21 +815,21 @@ def workflow( @overload def workflow( - _workflow_function: Callable[..., FuncOut], + _workflow_function: Callable[P, FuncOut], failure_policy: Optional[WorkflowFailurePolicy] = ..., interruptible: bool = ..., on_failure: Optional[Union[WorkflowBase, Task]] = ..., docs: Optional[Documentation] = ..., -) -> Union[PythonFunctionWorkflow, Callable[..., FuncOut]]: ... +) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ... def workflow( - _workflow_function: Optional[Callable[..., Any]] = None, + _workflow_function: Optional[Callable[P, FuncOut]] = None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, on_failure: Optional[Union[WorkflowBase, Task]] = None, docs: Optional[Documentation] = None, -) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]: +) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. @@ -856,7 +862,7 @@ def workflow( :param docs: Description entity for the workflow """ - def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow: + def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow: workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)