diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 17a4837432..a182ead2cc 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -548,11 +548,21 @@ def __rshift__(self, other: Any): return Output(*promises) # type: ignore +def binding_from_flyte_std( + ctx: _flyte_context.FlyteContext, + var_name: str, + expected_literal_type: _type_models.LiteralType, + t_value: typing.Any, +) -> _literals_models.Binding: + binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type=None) + return _literals_models.Binding(var=var_name, binding=binding_data) + + def binding_data_from_python_std( ctx: _flyte_context.FlyteContext, expected_literal_type: _type_models.LiteralType, t_value: typing.Any, - t_value_type: type, + t_value_type: Optional[type] = None, ) -> _literals_models.BindingData: # This handles the case where the given value is the output of another task if isinstance(t_value, Promise): @@ -568,7 +578,7 @@ def binding_data_from_python_std( if expected_literal_type.collection_type is None: raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}") - sub_type = ListTransformer.get_sub_type(t_value_type) + sub_type = ListTransformer.get_sub_type(t_value_type) if t_value_type else None collection = _literals_models.BindingDataCollection( bindings=[ binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type) for t in t_value @@ -585,11 +595,11 @@ def binding_data_from_python_std( raise AssertionError( f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}" ) - k_type, v_type = DictTransformer.get_dict_types(t_value_type) if expected_literal_type.simple == _type_models.SimpleType.STRUCT: lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type) return _literals_models.BindingData(scalar=lit.scalar) else: + _, v_type = DictTransformer.get_dict_types(t_value_type) if t_value_type else None, None m = _literals_models.BindingDataMap( bindings={ k: binding_data_from_python_std(ctx, expected_literal_type.map_value_type, v, v_type) @@ -607,7 +617,7 @@ def binding_data_from_python_std( ) # This is the scalar case - e.g. my_task(in1=5) - scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar + scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar return _literals_models.BindingData(scalar=scalar) @@ -703,7 +713,8 @@ def __init__(self, node: Node, var: str): @property def node_id(self): """ - Override the underlying node_id property to refer to SdkNode. + Override the underlying node_id property to refer to the Node's id. This is to make sure that overriding + node IDs from with_overrides gets serialized correctly. :rtype: Text """ return self.node.id @@ -731,6 +742,19 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata: ... +class HasFlyteInterface(Protocol): + @property + def name(self) -> str: + ... + + @property + def interface(self) -> _interface_models.TypedInterface: + ... + + def construct_node_metadata(self) -> _workflow_model.NodeMetadata: + ... + + def extract_obj_name(name: str) -> str: """ Generates a shortened name, without the module information. Useful for node-names etc. Only extracts the final @@ -743,6 +767,87 @@ def extract_obj_name(name: str) -> str: return name +def create_and_link_node_from_remote( + ctx: FlyteContext, + entity: HasFlyteInterface, + **kwargs, +): + """ + This method is used to generate a node with bindings. This is not used in the execution path. + """ + if ctx.compilation_state is None: + raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...") + + used_inputs = set() + bindings = [] + + typed_interface = entity.interface + + for k in sorted(typed_interface.inputs): + var = typed_interface.inputs[k] + if k not in kwargs: + raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) + v = kwargs[k] + # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte + # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed + # into the function. + if isinstance(v, tuple): + raise AssertionError( + f"Variable({k}) for function({entity.name}) cannot receive a multi-valued tuple {v}." + f" Check if the predecessor function returning more than one value?" + ) + try: + bindings.append( + binding_from_flyte_std( + ctx, + var_name=k, + expected_literal_type=var.type, + t_value=v, + ) + ) + used_inputs.add(k) + except Exception as e: + raise AssertionError(f"Failed to Bind variable {k} for function {entity.name}.") from e + + extra_inputs = used_inputs ^ set(kwargs.keys()) + if len(extra_inputs) > 0: + raise _user_exceptions.FlyteAssertion( + "Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs) + ) + + # Detect upstream nodes + # These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes + upstream_nodes = list( + set( + [ + input_val.ref.node + for input_val in kwargs.values() + if isinstance(input_val, Promise) and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID + ] + ) + ) + + flytekit_node = Node( + # TODO: Better naming, probably a derivative of the function name. + id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}", + metadata=entity.construct_node_metadata(), + bindings=sorted(bindings, key=lambda b: b.var), + upstream_nodes=upstream_nodes, + flyte_entity=entity, + ) + ctx.compilation_state.add_node(flytekit_node) + + if len(typed_interface.outputs) == 0: + return VoidPromise(entity.name) + + # Create a node output object for each output, they should all point to this node of course. + node_outputs = [] + for output_name, output_var_model in typed_interface.outputs.items(): + node_outputs.append(Promise(output_name, NodeOutput(node=flytekit_node, var=output_name))) + + return create_task_output(node_outputs) + + def create_and_link_node( ctx: FlyteContext, entity: SupportsNodeCreation, @@ -819,8 +924,6 @@ def create_and_link_node( # Create a node output object for each output, they should all point to this node of course. node_outputs = [] for output_name, output_var_model in typed_interface.outputs.items(): - # TODO: If node id gets updated later, we have to make sure to update the NodeOutput model's ID, which - # is currently just a static str node_outputs.append(Promise(output_name, NodeOutput(node=flytekit_node, var=output_name))) # Don't print this, it'll crash cuz sdk_node._upstream_node_ids might be None, but idl code will break diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 35016f9571..2e102c8b62 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -210,7 +210,12 @@ def compile_into_workflow( for entity, model in model_entities.items(): # We only care about gathering tasks here. Launch plans are handled by # propeller. Subworkflows should already be in the workflow spec. - if not isinstance(entity, Task): + if not isinstance(entity, Task) and not isinstance(entity, task_models.TaskTemplate): + continue + + # Handle FlyteTask + if isinstance(entity, task_models.TaskTemplate): + tts.append(entity) continue # We are currently not supporting reference tasks since these will diff --git a/flytekit/remote/launch_plan.py b/flytekit/remote/launch_plan.py index 3bf845fcab..32b3a93aec 100644 --- a/flytekit/remote/launch_plan.py +++ b/flytekit/remote/launch_plan.py @@ -1,60 +1,32 @@ +from __future__ import annotations + from typing import Optional +from flytekit.core import hash as hash_mixin from flytekit.core.interface import Interface -from flytekit.core.launch_plan import ReferenceLaunchPlan -from flytekit.core.type_engine import TypeEngine -from flytekit.loggers import remote_logger as logger from flytekit.models import interface as _interface_models from flytekit.models import launch_plan as _launch_plan_models from flytekit.models.core import identifier as id_models from flytekit.remote import interface as _interface +from flytekit.remote.remote_callable import RemoteEntity -class FlyteLaunchPlan(_launch_plan_models.LaunchPlanSpec): +class FlyteLaunchPlan(hash_mixin.HashOnReferenceMixin, RemoteEntity, _launch_plan_models.LaunchPlanSpec): """A class encapsulating a remote Flyte launch plan.""" def __init__(self, id, *args, **kwargs): super(FlyteLaunchPlan, self).__init__(*args, **kwargs) # Set all the attributes we expect this class to have self._id = id + self._name = id.name # The interface is not set explicitly unless fetched in an engine context self._interface = None self._python_interface = None - self._reference_entity = None - - def __call__(self, *args, **kwargs): - if self.reference_entity is None: - logger.warning( - f"FlyteLaunchPlan {self} is not callable, most likely because flytekit could not " - f"guess the python interface. The workflow calling this launch plan may not behave correctly." - ) - return - return self.reference_entity(*args, **kwargs) - # TODO: Refactor behind mixin @property - def reference_entity(self) -> Optional[ReferenceLaunchPlan]: - if self._reference_entity is None: - if self.guessed_python_interface is None: - try: - self.guessed_python_interface = Interface( - TypeEngine.guess_python_types(self.interface.inputs), - TypeEngine.guess_python_types(self.interface.outputs), - ) - except Exception as e: - logger.warning(f"Error backing out interface {e}, Flyte interface {self.interface}") - return None - - self._reference_entity = ReferenceLaunchPlan( - self.id.project, - self.id.domain, - self.id.name, - self.id.version, - inputs=self.guessed_python_interface.inputs, - outputs=self.guessed_python_interface.outputs, - ) - return self._reference_entity + def name(self) -> str: + return self._name @classmethod def promote_from_model( @@ -71,7 +43,6 @@ def promote_from_model( auth_role=model.auth_role, raw_output_data_config=model.raw_output_data_config, ) - return lp @property diff --git a/flytekit/remote/remote_callable.py b/flytekit/remote/remote_callable.py new file mode 100644 index 0000000000..699d4d8140 --- /dev/null +++ b/flytekit/remote/remote_callable.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Tuple, Union + +from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext +from flytekit.core.promise import Promise, VoidPromise, create_and_link_node_from_remote, extract_obj_name +from flytekit.exceptions import user as user_exceptions +from flytekit.loggers import remote_logger as logger +from flytekit.models.core.workflow import NodeMetadata + + +class RemoteEntity(ABC): + @property + @abstractmethod + def name(self) -> str: + ... + + def construct_node_metadata(self) -> NodeMetadata: + """ + Used when constructing the node that encapsulates this task as part of a broader workflow definition. + """ + return NodeMetadata( + name=extract_obj_name(self.name), + ) + + def compile(self, ctx: FlyteContext, *args, **kwargs): + return create_and_link_node_from_remote(ctx, entity=self, **kwargs) # noqa + + def __call__(self, *args, **kwargs): + # When a Task is () aka __called__, there are three things we may do: + # a. Plain execution Mode - just run the execute function. If not overridden, we should raise an exception + # b. Compilation Mode - this happens when the function is called as part of a workflow (potentially + # dynamic task). Produce promise objects and create a node. + # c. Workflow Execution Mode - when a workflow is being run locally. Even though workflows are functions + # and everything should be able to be passed through naturally, we'll want to wrap output values of the + # function into objects, so that potential .with_cpu or other ancillary functions can be attached to do + # nothing. Subsequent tasks will have to know how to unwrap these. If by chance a non-Flyte task uses a + # task output as an input, things probably will fail pretty obviously. + # Since this is a reference entity, it still needs to be mocked otherwise an exception will be raised. + if len(args) > 0: + raise user_exceptions.FlyteAssertion( + f"Cannot call remotely fetched entity with args - detected {len(args)} positional args {args}" + ) + + ctx = FlyteContext.current_context() + if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: + return self.compile(ctx, *args, **kwargs) + elif ( + ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION + ): + if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED: + return + return self.local_execute(ctx, **kwargs) + else: + logger.debug("Fetched entity, running raw execute.") + return self.execute(**kwargs) + + def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: + raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.") + + def execute(self, **kwargs) -> Any: + raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.") diff --git a/flytekit/remote/task.py b/flytekit/remote/task.py index 34b4f1d9c2..028719ec7f 100644 --- a/flytekit/remote/task.py +++ b/flytekit/remote/task.py @@ -1,16 +1,16 @@ from typing import Optional -from flytekit.core import hash as _hash_mixin +from flytekit.core import hash as hash_mixin from flytekit.core.interface import Interface -from flytekit.core.task import ReferenceTask from flytekit.core.type_engine import TypeEngine from flytekit.loggers import remote_logger as logger from flytekit.models import task as _task_model from flytekit.models.core import identifier as _identifier_model from flytekit.remote import interface as _interfaces +from flytekit.remote.remote_callable import RemoteEntity -class FlyteTask(_hash_mixin.HashOnReferenceMixin, _task_model.TaskTemplate): +class FlyteTask(hash_mixin.HashOnReferenceMixin, RemoteEntity, _task_model.TaskTemplate): """A class encapsulating a remote Flyte task.""" def __init__(self, id, type, metadata, interface, custom, container=None, task_type_version=0, config=None): @@ -25,44 +25,11 @@ def __init__(self, id, type, metadata, interface, custom, container=None, task_t config=config, ) self._python_interface = None - self._reference_entity = None - - def __call__(self, *args, **kwargs): - if self.reference_entity is None: - logger.warning( - f"FlyteTask {self} is not callable, most likely because flytekit could not " - f"guess the python interface. The workflow calling this task may not behave correctly" - ) - return - return self.reference_entity(*args, **kwargs) - - # TODO: Refactor behind mixin - @property - def reference_entity(self) -> Optional[ReferenceTask]: - if self._reference_entity is None: - if self.guessed_python_interface is None: - try: - self.guessed_python_interface = Interface( - TypeEngine.guess_python_types(self.interface.inputs), - TypeEngine.guess_python_types(self.interface.outputs), - ) - except Exception as e: - logger.warning(f"Error backing out interface {e}, Flyte interface {self.interface}") - return None - - self._reference_entity = ReferenceTask( - self.id.project, - self.id.domain, - self.id.name, - self.id.version, - inputs=self.guessed_python_interface.inputs, - outputs=self.guessed_python_interface.outputs, - ) - return self._reference_entity + self._name = id.name @property - def interface(self) -> _interfaces.TypedInterface: - return super(FlyteTask, self).interface + def name(self) -> str: + return self._name @property def resource_type(self) -> _identifier_model.ResourceType: diff --git a/flytekit/remote/workflow.py b/flytekit/remote/workflow.py index 14ef6c91bf..2ff74376c8 100644 --- a/flytekit/remote/workflow.py +++ b/flytekit/remote/workflow.py @@ -14,21 +14,22 @@ from flytekit.models.core import workflow as _workflow_models from flytekit.remote import interface as _interfaces from flytekit.remote import nodes as _nodes +from flytekit.remote.remote_callable import RemoteEntity -class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, _workflow_models.WorkflowTemplate): +class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, RemoteEntity, _workflow_models.WorkflowTemplate): """A class encapsulating a remote Flyte workflow.""" def __init__( self, + id: id_models.Identifier, nodes: List[_nodes.FlyteNode], interface, output_bindings, - id: id_models.Identifier, metadata, metadata_defaults, subworkflows: Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]] = None, - tasks: Optional[Dict[id_models.Identifier, _task_models.TaskSpec]] = None, + tasks: Optional[Dict[id_models.Identifier, _task_models.TaskTemplate]] = None, launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, compiled_closure: Optional[compiler_models.CompiledWorkflowClosure] = None, ): @@ -58,10 +59,15 @@ def __init__( self._launch_plans = launch_plans self._compiled_closure = compiled_closure self._node_map = None + self._name = id.name + + @property + def name(self) -> str: + return self._name @property - def interface(self) -> _interfaces.TypedInterface: - return super(FlyteWorkflow, self).interface + def sub_workflows(self) -> Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]]: + return self._subworkflows @property def entity_type_text(self) -> str: @@ -114,8 +120,8 @@ def promote_from_model( # No inputs/outputs specified, see the constructor for more information on the overrides. wf = cls( - nodes=list(node_map.values()), id=base_model.id, + nodes=list(node_map.values()), metadata=base_model.metadata, metadata_defaults=base_model.metadata_defaults, interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), @@ -159,6 +165,3 @@ def promote_from_closure( ) flyte_wf._compiled_closure = closure return flyte_wf - - def __call__(self, *args, **input_map): - raise NotImplementedError diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 8c9750e8cd..05b4224cc9 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -126,6 +126,9 @@ def get_serializable_workflow( settings: SerializationSettings, entity: WorkflowBase, ) -> admin_workflow_models.WorkflowSpec: + # TODO: Try to move up following config refactor - https://github.com/flyteorg/flyte/issues/2214 + from flytekit.remote.workflow import FlyteWorkflow + # Get node models upstream_node_models = [ get_serializable(entity_mapping, settings, n) @@ -151,6 +154,11 @@ def get_serializable_workflow( sub_wfs.append(sub_wf_spec.template) sub_wfs.extend(sub_wf_spec.sub_workflows) + if isinstance(n.flyte_entity, FlyteWorkflow): + get_serializable(entity_mapping, settings, n.flyte_entity) + sub_wfs.append(n.flyte_entity) + sub_wfs.extend([s for s in n.flyte_entity.sub_workflows.values()]) + if isinstance(n.flyte_entity, BranchNode): if_else: workflow_model.IfElseBlock = n.flyte_entity._ifelse_block # See comment in get_serializable_branch_node also. Again this is a List[Node] even though it's supposed @@ -168,6 +176,10 @@ def get_serializable_workflow( sub_wf_spec = get_serializable(entity_mapping, settings, leaf_node.flyte_entity) sub_wfs.append(sub_wf_spec.template) sub_wfs.extend(sub_wf_spec.sub_workflows) + elif isinstance(leaf_node.flyte_entity, FlyteWorkflow): + get_serializable(entity_mapping, settings, leaf_node.flyte_entity) + sub_wfs.append(leaf_node.flyte_entity) + sub_wfs.extend([s for s in leaf_node.flyte_entity.sub_workflows.values()]) wf_id = _identifier_model.Identifier( resource_type=_identifier_model.ResourceType.WORKFLOW, @@ -237,6 +249,11 @@ def get_serializable_node( if entity.flyte_entity is None: raise Exception(f"Node {entity.id} has no flyte entity") + # TODO: Try to move back up following config refactor - https://github.com/flyteorg/flyte/issues/2214 + from flytekit.remote.launch_plan import FlyteLaunchPlan + from flytekit.remote.task import FlyteTask + from flytekit.remote.workflow import FlyteWorkflow + upstream_sdk_nodes = [ get_serializable(entity_mapping, settings, n) for n in entity.upstream_nodes @@ -319,6 +336,49 @@ def get_serializable_node( output_aliases=[], workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id), ) + + elif isinstance(entity.flyte_entity, FlyteTask): + # Recursive call doesn't do anything except put the entity on the map. + get_serializable(entity_mapping, settings, entity.flyte_entity) + node_model = workflow_model.Node( + id=_dnsify(entity.id), + metadata=entity.metadata, + inputs=entity.bindings, + upstream_node_ids=[n.id for n in upstream_sdk_nodes], + output_aliases=[], + task_node=workflow_model.TaskNode( + reference_id=entity.flyte_entity.id, overrides=TaskNodeOverrides(resources=entity._resources) + ), + ) + elif isinstance(entity.flyte_entity, FlyteWorkflow): + wf_template = get_serializable(entity_mapping, settings, entity.flyte_entity) + for _, sub_wf in entity.flyte_entity.sub_workflows.items(): + get_serializable(entity_mapping, settings, sub_wf) + node_model = workflow_model.Node( + id=_dnsify(entity.id), + metadata=entity.metadata, + inputs=entity.bindings, + upstream_node_ids=[n.id for n in upstream_sdk_nodes], + output_aliases=[], + workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_template.id), + ) + elif isinstance(entity.flyte_entity, FlyteLaunchPlan): + # Recursive call doesn't do anything except put the entity on the map. + get_serializable(entity_mapping, settings, entity.flyte_entity) + # Node's inputs should not contain the data which is fixed input + node_input = [] + for b in entity.bindings: + if b.var not in entity.flyte_entity.fixed_inputs.literals: + node_input.append(b) + + node_model = workflow_model.Node( + id=_dnsify(entity.id), + metadata=entity.metadata, + inputs=node_input, + upstream_node_ids=[n.id for n in upstream_sdk_nodes], + output_aliases=[], + workflow_node=workflow_model.WorkflowNode(launchplan_ref=entity.flyte_entity.id), + ) else: raise Exception(f"Node contained non-serializable entity {entity._flyte_entity}") @@ -375,6 +435,11 @@ def get_serializable( :return: The resulting control plane entity, in addition to being added to the mutable entity_mapping parameter is also returned. """ + # TODO: Try to replace following config refactor - https://github.com/flyteorg/flyte/issues/2214 + from flytekit.remote.launch_plan import FlyteLaunchPlan + from flytekit.remote.task import FlyteTask + from flytekit.remote.workflow import FlyteWorkflow + if entity in entity_mapping: return entity_mapping[entity] @@ -395,6 +460,10 @@ def get_serializable( elif isinstance(entity, BranchNode): cp_entity = get_serializable_branch_node(entity_mapping, settings, entity) + + elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow) or isinstance(entity, FlyteLaunchPlan): + cp_entity = entity + else: raise Exception(f"Non serializable type found {type(entity)} Entity {entity}") diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index d3395e9fd5..11616203cd 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -354,7 +354,7 @@ def middle_subwf() -> typing.Tuple[int, int]: @workflow def parent_wf() -> typing.Tuple[int, int, int, int]: m1, m2 = middle_subwf() - l1, l2 = leaf_subwf() + l1, l2 = leaf_subwf().with_overrides(node_name="foo-node") return m1, m2, l1, l2 wf_spec = get_serializable(OrderedDict(), serialization_settings, parent_wf) @@ -366,6 +366,8 @@ def parent_wf() -> typing.Tuple[int, int, int, int]: assert len(midwf.nodes) == 1 assert midwf.nodes[0].workflow_node is not None assert midwf.nodes[0].workflow_node.sub_workflow_ref.name == "test_serialization.leaf_subwf" + assert wf_spec.template.nodes[1].id == "foo-node" + assert wf_spec.template.outputs[2].binding.promise.node_id == "foo-node" def test_serialization_named_outputs_single(): diff --git a/tests/flytekit/unit/remote/test_calling.py b/tests/flytekit/unit/remote/test_calling.py index 97a1a001dc..07483487ab 100644 --- a/tests/flytekit/unit/remote/test_calling.py +++ b/tests/flytekit/unit/remote/test_calling.py @@ -3,14 +3,19 @@ import pytest +from flytekit import dynamic from flytekit.core import context_manager -from flytekit.core.context_manager import Image, ImageConfig +from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, Image, ImageConfig from flytekit.core.launch_plan import LaunchPlan -from flytekit.core.reference_entity import ReferenceSpec from flytekit.core.task import task +from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.exceptions.user import FlyteAssertion +from flytekit.models.core.workflow import WorkflowTemplate +from flytekit.models.task import TaskTemplate from flytekit.remote import FlyteLaunchPlan, FlyteTask from flytekit.remote.interface import TypedInterface +from flytekit.remote.workflow import FlyteWorkflow from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") @@ -48,18 +53,28 @@ def sub_wf(a: int, b: str) -> (int, str): def test_fetched_task(): @workflow def wf(a: int) -> int: - return ft(a=a) + return ft(a=a).with_overrides(node_name="foobar") # Should not work unless mocked out. with pytest.raises(Exception, match="cannot be run locally"): wf(a=3) - # Should have one reference entity + # Should have one task template serialized = OrderedDict() - get_serializable(serialized, serialization_settings, wf) + wf_spec = get_serializable(serialized, serialization_settings, wf) vals = [v for v in serialized.values()] - refs = [f for f in filter(lambda x: isinstance(x, ReferenceSpec), vals)] - assert len(refs) == 1 + tts = [f for f in filter(lambda x: isinstance(x, TaskTemplate), vals)] + assert len(tts) == 1 + assert wf_spec.template.nodes[0].id == "foobar" + assert wf_spec.template.outputs[0].binding.promise.node_id == "foobar" + + +def test_misnamed(): + with pytest.raises(FlyteAssertion): + + @workflow + def wf(a: int) -> int: + return ft(b=a) def test_calling_lp(): @@ -83,3 +98,84 @@ def wf2(a: int) -> typing.Tuple[int, str]: wf_spec = get_serializable(serialized, serialization_settings, wf2) print(wf_spec.template.nodes[0].workflow_node.launchplan_ref) assert wf_spec.template.nodes[0].workflow_node.launchplan_ref == lp_model.id + + +def test_dynamic(): + @dynamic + def my_subwf(a: int) -> typing.List[int]: + s = [] + for i in range(a): + s.append(ft(a=i)) + return s + + with context_manager.FlyteContextManager.with_context( + context_manager.FlyteContextManager.current_context().with_serialization_settings( + context_manager.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + fast_serialization_settings=FastSerializationSettings(enabled=True), + ) + ) + ) as ctx: + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=ExecutionState.Mode.TASK_EXECUTION, + additional_context={ + "dynamic_addl_distro": "s3://my-s3-bucket/fast/123", + "dynamic_dest_dir": "/User/flyte/workflows", + }, + ) + ) + ) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 2}) + # Test that it works + dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map) + assert len(dynamic_job_spec._nodes) == 2 + assert len(dynamic_job_spec.tasks) == 1 + assert dynamic_job_spec.tasks[0].id == ft.id + + # Test that the fast execute stuff does not get applied because the commands of tasks fetched from + # Admin should never change. + args = " ".join(dynamic_job_spec.tasks[0].container.args) + assert not args.startswith("pyflyte-fast-execute") + + +def test_calling_wf(): + # No way to fetch from Admin in unit tests so we serialize and then promote back + serialized = OrderedDict() + wf_spec = get_serializable(serialized, serialization_settings, sub_wf) + task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) + fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=task_templates) + + @workflow + def parent_1(a: int, b: str) -> typing.Tuple[int, str]: + y = t1(a=a) + return fwf(a=y, b=b) + + # No way to fetch from Admin in unit tests so we serialize and then promote back + serialized = OrderedDict() + wf_spec = get_serializable(serialized, serialization_settings, parent_1) + # Get task_specs from the second one, merge with the first one. Admin normally would be the one to do this. + task_templates_p1, wf_specs, lp_specs = gather_dependent_entities(serialized) + for k, v in task_templates.items(): + task_templates_p1[k] = v + + # Pick out the subworkflow templates from the ordereddict. We can't use the output of the gather_dependent_entities + # function because that only looks for WorkflowSpecs + subwf_templates = {x.id: x for x in list(filter(lambda x: isinstance(x, WorkflowTemplate), serialized.values()))} + fwf_p1 = FlyteWorkflow.promote_from_model(wf_spec.template, sub_workflows=subwf_templates, tasks=task_templates_p1) + + @workflow + def parent_2(a: int, b: str) -> typing.Tuple[int, str]: + x, y = fwf_p1(a=a, b=b) + z = t1(a=x) + return z, y + + serialized = OrderedDict() + wf_spec = get_serializable(serialized, serialization_settings, parent_2) + # Make sure both were picked up. + assert len(wf_spec.sub_workflows) == 2