diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index a5128f7e6d..3dc78db395 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -1,3 +1,5 @@ +import typing + import six as _six from flyteidl.admin import common_pb2 as _common_pb2 from flyteidl.admin import execution_pb2 as _execution_pb2 @@ -19,6 +21,7 @@ from flytekit.models import node_execution as _node_execution from flytekit.models import project as _project from flytekit.models import task as _task +from flytekit.models.admin import common as _admin_common from flytekit.models.admin import task_execution as _task_execution from flytekit.models.admin import workflow as _workflow from flytekit.models.core import identifier as _identifier @@ -666,20 +669,22 @@ def get_node_execution_data(self, node_execution_identifier): def list_node_executions( self, workflow_execution_identifier, - limit=100, - token=None, - filters=None, - sort_by=None, + limit: int = 100, + token: typing.Optional[str] = None, + filters: typing.List[_filters.Filter] = None, + sort_by: _admin_common.Sort = None, + unique_parent_id: str = None, ): - """ - TODO: Comment + """Get node executions associated with a given workflow execution. + :param flytekit.models.core.identifier.WorkflowExecutionIdentifier workflow_execution_identifier: - :param int limit: - :param Text token: [Optional] If specified, this specifies where in the rows of results to skip before reading. - If you previously retrieved a page response with token="foo" and you want the next page, - specify token="foo". + :param limit: Limit the number of items returned in the response. + :param token: If specified, this specifies where in the rows of results to skip before reading. + If you previously retrieved a page response with token="foo" and you want the next page, + specify ``token="foo"``. :param list[flytekit.models.filters.Filter] filters: :param flytekit.models.admin.common.Sort sort_by: [Optional] If provided, the results will be sorted. + :param unique_parent_id: If specified, returns the node executions for the ``unique_parent_id`` node id. :rtype: list[flytekit.models.node_execution.NodeExecution], Text """ exec_list = super(SynchronousFlyteClient, self).list_node_executions_paginated( @@ -689,6 +694,7 @@ def list_node_executions( token=token, filters=_filters.FilterList(filters or []).to_flyte_idl(), sort_by=None if sort_by is None else sort_by.to_flyte_idl(), + unique_parent_id=unique_parent_id, ) ) return ( diff --git a/flytekit/clients/helpers.py b/flytekit/clients/helpers.py index 4d8a7912a4..4e2dc71e25 100644 --- a/flytekit/clients/helpers.py +++ b/flytekit/clients/helpers.py @@ -4,6 +4,7 @@ def iterate_node_executions( task_execution_identifier=None, limit=None, filters=None, + unique_parent_id=None, ): """ This returns a generator for node executions. @@ -26,6 +27,7 @@ def iterate_node_executions( limit=num_to_fetch, token=token, filters=filters, + unique_parent_id=unique_parent_id, ) else: node_execs, next_token = client.list_node_executions_for_task_paginated( diff --git a/flytekit/control_plane/component_nodes.py b/flytekit/control_plane/component_nodes.py index 10434ab830..a0a28aac3b 100644 --- a/flytekit/control_plane/component_nodes.py +++ b/flytekit/control_plane/component_nodes.py @@ -112,7 +112,7 @@ def promote_from_model( base_model.reference.version, ) - if base_model.launch_plan_ref is not None: + if base_model.launchplan_ref is not None: return cls(flyte_launch_plan=_launch_plan.FlyteLaunchPlan.fetch(*fetch_args)) elif base_model.sub_workflow_ref is not None: # the workflow tempaltes for sub-workflows should have been included in the original response diff --git a/flytekit/control_plane/nodes.py b/flytekit/control_plane/nodes.py index fa53194c28..3dd787019a 100644 --- a/flytekit/control_plane/nodes.py +++ b/flytekit/control_plane/nodes.py @@ -5,7 +5,7 @@ from flyteidl.core import literals_pb2 as _literals_pb2 import flytekit -from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions +from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions from flytekit.common import constants as _constants from flytekit.common import utils as _common_utils from flytekit.common.exceptions import system as _system_exceptions @@ -15,7 +15,7 @@ from flytekit.common.utils import _dnsify from flytekit.control_plane import component_nodes as _component_nodes from flytekit.control_plane import identifier as _identifier -from flytekit.control_plane.tasks import executions as _task_executions +from flytekit.control_plane.tasks.executions import FlyteTaskExecution from flytekit.core.context_manager import FlyteContextManager from flytekit.core.promise import NodeOutput from flytekit.core.type_engine import TypeEngine @@ -154,58 +154,59 @@ def with_overrides(self, *args, **kwargs): raise NotImplementedError("Overrides are not supported in Flyte yet.") def __repr__(self) -> str: - return f"Node(ID: {self.id} Executable: {self._executable_flyte_object})" + return f"Node(ID: {self.id})" class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact): def __init__(self, *args, **kwargs): super(FlyteNodeExecution, self).__init__(*args, **kwargs) self._task_executions = None - self._workflow_executions = None + self._subworkflow_node_executions = None self._inputs = None self._outputs = None + self._interface = None @property def task_executions(self) -> List["flytekit.control_plane.tasks.executions.FlyteTaskExecution"]: return self._task_executions or [] @property - def workflow_executions(self) -> List["flytekit.control_plane.workflow_executions.FlyteWorkflowExecution"]: - return self._workflow_executions or [] + def subworkflow_node_executions(self) -> Dict[str, "flytekit.control_plane.nodes.FlyteNodeExecution"]: + return ( + {} + if self._subworkflow_node_executions is None + else {n.id.node_id: n for n in self._subworkflow_node_executions} + ) @property - def executions(self) -> _artifact_mixin.ExecutionArtifact: - return self.task_executions or self.workflow_executions or [] + def executions(self) -> List[_artifact_mixin.ExecutionArtifact]: + return self.task_executions or self._subworkflow_node_executions or [] @property def inputs(self) -> Dict[str, Any]: """ Returns the inputs to the execution in the standard python format as dictated by the type engine. """ - from flytekit.control_plane.tasks.task import FlyteTask - if self._inputs is None: client = _flyte_engine.get_client() - execution_data = client.get_node_execution_data(self.id) + node_execution_data = client.get_node_execution_data(self.id) # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({}) - if bool(execution_data.full_inputs.literals): - input_map = execution_data.full_inputs - elif execution_data.inputs.bytes > 0: + if bool(node_execution_data.full_inputs.literals): + input_map = node_execution_data.full_inputs + elif node_execution_data.inputs.bytes > 0: with _common_utils.AutoDeletingTempDir() as tmp_dir: tmp_name = _os.path.join(tmp_dir.name, "inputs.pb") - _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) + _data_proxy.Data.get_data(node_execution_data.inputs.url, tmp_name) input_map = _literal_models.LiteralMap.from_flyte_idl( _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) - task_id = self.task_executions[0].id.task_id - task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version) self._inputs = TypeEngine.literal_map_to_kwargs( ctx=FlyteContextManager.current_context(), lm=input_map, - python_types=TypeEngine.guess_python_types(task.interface.inputs), + python_types=TypeEngine.guess_python_types(self.interface.inputs), ) return self._inputs @@ -216,8 +217,6 @@ def outputs(self) -> Dict[str, Any]: :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. """ - from flytekit.control_plane.tasks.task import FlyteTask - if not self.is_complete: raise _user_exceptions.FlyteAssertion( "Please wait until the node execution has completed before requesting the outputs." @@ -241,12 +240,10 @@ def outputs(self) -> Dict[str, Any]: _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) - task_id = self.task_executions[0].id.task_id - task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version) self._outputs = TypeEngine.literal_map_to_kwargs( ctx=FlyteContextManager.current_context(), lm=output_map, - python_types=TypeEngine.guess_python_types(task.interface.outputs), + python_types=TypeEngine.guess_python_types(self.interface.outputs), ) return self._outputs @@ -279,18 +276,60 @@ def promote_from_model(cls, base_model: _node_execution_models.NodeExecution) -> closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri, metadata=base_model.metadata ) + @property + def interface(self) -> "flytekit.control_plane.interface.TypedInterface": + """ + Return the interface of the task or subworkflow associated with this node execution. + """ + if self._interface is None: + + from flytekit.control_plane.tasks.task import FlyteTask + from flytekit.control_plane.workflow import FlyteWorkflow + + if not self.metadata.is_parent_node: + # if not a parent node, assume a task execution node + task_id = self.task_executions[0].id.task_id + task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version) + self._interface = task.interface + else: + # otherwise assume the node is associated with a subworkflow + client = _flyte_engine.get_client() + + # need to get the FlyteWorkflow associated with this node execution (self), so we need to fetch the + # parent workflow and iterate through the parent's FlyteNodes to get the the FlyteWorkflow object + # representing the subworkflow. This allows us to get the interface for guessing the types of the + # inputs/outputs. + lp_id = client.get_execution(self.id.execution_id).spec.launch_plan + workflow = FlyteWorkflow.fetch(lp_id.project, lp_id.domain, lp_id.name, lp_id.version) + flyte_subworkflow_node: FlyteNode = [n for n in workflow.nodes if n.id == self.id.node_id][0] + self._interface = flyte_subworkflow_node.target.flyte_workflow.interface + + return self._interface + def sync(self): """ Syncs the state of the underlying execution artifact with the state observed by the platform. """ - if not self.is_complete or self.task_executions is not None: - client = _flyte_engine.get_client() - self._closure = client.get_node_execution(self.id).closure - self._task_executions = [ - _task_executions.FlyteTaskExecution.promote_from_model(t) - for t in _iterate_task_executions(client, self.id) - ] - # TODO: sync sub-workflows as well + if self.metadata.is_parent_node: + if not self.is_complete or self._subworkflow_node_executions is None: + self._subworkflow_node_executions = [ + FlyteNodeExecution.promote_from_model(n) + for n in iterate_node_executions( + _flyte_engine.get_client(), + workflow_execution_identifier=self.id.execution_id, + unique_parent_id=self.id.node_id, + ) + ] + else: + if not self.is_complete or self._task_executions is None: + self._task_executions = [ + FlyteTaskExecution.promote_from_model(t) + for t in iterate_task_executions(_flyte_engine.get_client(), self.id) + ] + + self._sync_closure() + for execution in self.executions: + execution.sync() def _sync_closure(self): """ diff --git a/flytekit/control_plane/tasks/task.py b/flytekit/control_plane/tasks/task.py index 71f159e523..9797b94fa1 100644 --- a/flytekit/control_plane/tasks/task.py +++ b/flytekit/control_plane/tasks/task.py @@ -1,4 +1,5 @@ from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.mixins import hash as _hash_mixin from flytekit.control_plane import identifier as _identifier from flytekit.control_plane import interface as _interfaces diff --git a/flytekit/control_plane/workflow.py b/flytekit/control_plane/workflow.py index 9b9dd26de7..810afeeb38 100644 --- a/flytekit/control_plane/workflow.py +++ b/flytekit/control_plane/workflow.py @@ -3,6 +3,7 @@ from flytekit.common import constants as _constants from flytekit.common.exceptions import scopes as _exception_scopes from flytekit.common.exceptions import system as _system_exceptions +from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.mixins import hash as _hash_mixin from flytekit.control_plane import identifier as _identifier from flytekit.control_plane import interface as _interfaces diff --git a/flytekit/models/node_execution.py b/flytekit/models/node_execution.py index 65a8ef015a..5a0cda6a6e 100644 --- a/flytekit/models/node_execution.py +++ b/flytekit/models/node_execution.py @@ -156,10 +156,7 @@ def closure(self): def metadata(self) -> NodeExecutionMetaData: return self._metadata - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.node_execution_pb2.NodeExecution - """ + def to_flyte_idl(self) -> _node_execution_pb2.NodeExecution: return _node_execution_pb2.NodeExecution( id=self.id.to_flyte_idl(), input_uri=self.input_uri, @@ -168,11 +165,7 @@ def to_flyte_idl(self): ) @classmethod - def from_flyte_idl(cls, p): - """ - :param flyteidl.admin.node_execution_pb2.NodeExecution p: - :rtype: NodeExecution - """ + def from_flyte_idl(cls, p: _node_execution_pb2.NodeExecution) -> "NodeExecution": return cls( id=_identifier.NodeExecutionIdentifier.from_flyte_idl(p.id), input_uri=p.input_uri, diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/subworkflows.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/subworkflows.py new file mode 100644 index 0000000000..ec96bedd42 --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/subworkflows.py @@ -0,0 +1,26 @@ +import typing + +from flytekit import task, workflow + + +@task +def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + return a + 2, "world" + + +@workflow +def my_subwf(a: int = 42) -> (str, str): + x, y = t1(a=a) + u, v = t1(a=x) + return y, v + + +@workflow +def parent_wf(a: int) -> (int, str, str): + x, y = t1(a=a) + u, v = my_subwf(a=x) + return x, u, v + + +if __name__ == "__main__": + print(f"Running my_wf(a=3) {parent_wf(a=3)}") diff --git a/tests/flytekit/integration/control_plane/test_workflow.py b/tests/flytekit/integration/control_plane/test_workflow.py index 90c51be020..bb2281bc02 100644 --- a/tests/flytekit/integration/control_plane/test_workflow.py +++ b/tests/flytekit/integration/control_plane/test_workflow.py @@ -102,3 +102,24 @@ def test_monitor_workflow(flyteclient, flyte_workflows_register): assert execution.node_executions["n0"].task_executions[0].outputs["o0"] == "hello world" assert execution.inputs == {} assert execution.outputs["o0"] == "hello world" + + +def test_launch_workflow_with_subworkflows(flyteclient, flyte_workflows_register): + execution = launch_plan.FlyteLaunchPlan.fetch( + PROJECT, "development", "workflows.basic.subworkflows.parent_wf", f"v{VERSION}" + ).launch_with_literals( + PROJECT, + "development", + literals.LiteralMap({"a": literals.Literal(literals.Scalar(literals.Primitive(integer=101)))}), + ) + execution.wait_for_completion() + # check node execution inputs and outputs + assert execution.node_executions["n0"].inputs == {"a": 101} + assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"} + assert execution.node_executions["n1"].inputs == {"a": 103} + assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"} + + # check subworkflow task execution inputs and outputs + subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions + subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103} + subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}