Skip to content

Commit

Permalink
feat(bindings): Task arguments default value binding (#2401)
Browse files Browse the repository at this point in the history
flyteorg/flyte#5321

if the key is not in `kwargs` but in `interface.inputs_with_defaults`, add the value in `interface.inputs_with_defaults` to `kwargs`.

Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness authored Jun 6, 2024
1 parent c03eaad commit 47f2a29
Show file tree
Hide file tree
Showing 8 changed files with 607 additions and 46 deletions.
56 changes: 29 additions & 27 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import inspect
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast
from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args

from google.protobuf import struct_pb2 as _struct
from typing_extensions import Protocol, get_args
from typing_extensions import Protocol

from flytekit.core import constants as _common_constants
from flytekit.core import context_manager as _flyte_context
Expand All @@ -23,7 +23,13 @@
)
from flytekit.core.interface import Interface
from flytekit.core.node import Node
from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.core.type_engine import (
DictTransformer,
ListTransformer,
TypeEngine,
TypeTransformerFailedError,
UnionTransformer,
)
from flytekit.exceptions import user as _user_exceptions
from flytekit.exceptions.user import FlytePromiseAttributeResolveException
from flytekit.loggers import logger
Expand Down Expand Up @@ -774,7 +780,13 @@ def binding_from_python_std(
t_value_type: type,
) -> Tuple[_literals_models.Binding, List[Node]]:
nodes: List[Node] = []
binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type, nodes)
binding_data = binding_data_from_python_std(
ctx,
expected_literal_type,
t_value,
t_value_type,
nodes,
)
return _literals_models.Binding(var=var_name, binding=binding_data), nodes


Expand Down Expand Up @@ -1060,32 +1072,22 @@ def create_and_link_node(

for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
if var.type.simple == SimpleType.NONE:
raise TypeError("Arguments do not have type annotation")
if k not in kwargs:
is_optional = False
if var.type.union_type:
for variant in var.type.union_type.variants:
if variant.simple == SimpleType.NONE:
val, _default = interface.inputs_with_defaults[k]
if _default is not None:
raise ValueError(
f"The default value for the optional type must be None, but got {_default}"
)
is_optional = True
if not is_optional:
from flytekit.core.base_task import Task

# interface.inputs_with_defaults[k][0] is the type of the default argument
# interface.inputs_with_defaults[k][1] is the value of the default argument
if k in interface.inputs_with_defaults and (
interface.inputs_with_defaults[k][1] is not None
or UnionTransformer.is_optional_type(interface.inputs_with_defaults[k][0])
):
default_val = interface.inputs_with_defaults[k][1]
if not isinstance(default_val, Hashable):
raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument")
kwargs[k] = default_val
else:
error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}"

_, _default = interface.inputs_with_defaults[k]
if isinstance(entity, Task) and _default is not None:
error_msg += (
". Flyte workflow syntax is a domain-specific language (DSL) for building execution graphs which "
"supports a subset of Python’s semantics. When calling tasks, all kwargs have to be provided."
)

raise _user_exceptions.FlyteAssertion(error_msg)
else:
continue
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
# Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,7 +1554,7 @@ def __init__(self):
super().__init__("Typed Union", typing.Union)

@staticmethod
def is_optional_type(t: Type[T]) -> bool:
def is_optional_type(t: Type) -> bool:
"""Return True if `t` is a Union or Optional type."""
return _is_union_type(t) or type(None) in get_args(t)

Expand Down
14 changes: 0 additions & 14 deletions tests/flytekit/unit/core/test_composition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Dict, List, NamedTuple, Optional, Union

import pytest

from flytekit.core import launch_plan
from flytekit.core.task import task
from flytekit.core.workflow import workflow
Expand Down Expand Up @@ -186,15 +184,3 @@ def wf(a: Optional[int] = 1) -> Optional[int]:
return t2(a=a)

assert wf() is None

with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"):

@task()
def t3(c: Optional[int] = 3) -> Optional[int]:
...

@workflow
def wf():
return t3()

wf()
37 changes: 37 additions & 0 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,43 @@ def ranged_int_to_str(a: int) -> typing.List[str]:
assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"]


@pytest.mark.parametrize(
"input_val,output_val",
[
(4, 0),
(5, 5),
],
)
def test_dynamic_local_default_args_task(input_val, output_val):
@task
def t1(a: int = 0) -> int:
return a

@dynamic
def dt(a: int) -> int:
if a % 2 == 0:
return t1()
return t1(a=a)

assert dt(a=input_val) == output_val

with context_manager.FlyteContextManager.with_context(
context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
) as ctx:
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
mode=ExecutionState.Mode.TASK_EXECUTION,
)
)
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": input_val})
dynamic_job_spec = dt.dispatch_execute(ctx, input_literal_map)
assert len(dynamic_job_spec.nodes) == 1
assert len(dynamic_job_spec.tasks) == 1
assert dynamic_job_spec.nodes[0].inputs[0].binding.scalar.primitive is not None


def test_nested_dynamic_local():
@task
def t1(a: int) -> str:
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def t2(a: typing.Optional[int] = None) -> typing.Optional[int]:

p = create_and_link_node(ctx, t2)
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 0
assert len(p.ref.node.bindings) == 1


def test_create_and_link_node_from_remote():
Expand Down
Loading

0 comments on commit 47f2a29

Please sign in to comment.