Skip to content

Commit

Permalink
expose subworkflow inputs, outputs in node execution (#503)
Browse files Browse the repository at this point in the history
* expose subworkflow inputs, outputs in node execution

Signed-off-by: cosmicBboy <[email protected]>

* fix lint

Signed-off-by: cosmicBboy <[email protected]>

* add docstring

Signed-off-by: cosmicBboy <[email protected]>

* add docstring to list_node_executions, update sync logic

Signed-off-by: cosmicBboy <[email protected]>

* add comment to subworkflow get_interface @wild-endeavor

Signed-off-by: cosmicBboy <[email protected]>

* update executions property

Signed-off-by: cosmicBboy <[email protected]>

* cache interface for node execution, redefine as property

Signed-off-by: cosmicBboy <[email protected]>

* fix lint

Signed-off-by: cosmicBboy <[email protected]>
Signed-off-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
cosmicBboy authored and EngHabu committed Jun 25, 2021
1 parent ba1784b commit 7a13759
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 51 deletions.
26 changes: 16 additions & 10 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing

import six as _six
from flyteidl.admin import common_pb2 as _common_pb2
from flyteidl.admin import execution_pb2 as _execution_pb2
Expand All @@ -19,6 +21,7 @@
from flytekit.models import node_execution as _node_execution
from flytekit.models import project as _project
from flytekit.models import task as _task
from flytekit.models.admin import common as _admin_common
from flytekit.models.admin import task_execution as _task_execution
from flytekit.models.admin import workflow as _workflow
from flytekit.models.core import identifier as _identifier
Expand Down Expand Up @@ -666,20 +669,22 @@ def get_node_execution_data(self, node_execution_identifier):
def list_node_executions(
self,
workflow_execution_identifier,
limit=100,
token=None,
filters=None,
sort_by=None,
limit: int = 100,
token: typing.Optional[str] = None,
filters: typing.List[_filters.Filter] = None,
sort_by: _admin_common.Sort = None,
unique_parent_id: str = None,
):
"""
TODO: Comment
"""Get node executions associated with a given workflow execution.
:param flytekit.models.core.identifier.WorkflowExecutionIdentifier workflow_execution_identifier:
:param int limit:
:param Text token: [Optional] If specified, this specifies where in the rows of results to skip before reading.
If you previously retrieved a page response with token="foo" and you want the next page,
specify token="foo".
:param limit: Limit the number of items returned in the response.
:param token: If specified, this specifies where in the rows of results to skip before reading.
If you previously retrieved a page response with token="foo" and you want the next page,
specify ``token="foo"``.
:param list[flytekit.models.filters.Filter] filters:
:param flytekit.models.admin.common.Sort sort_by: [Optional] If provided, the results will be sorted.
:param unique_parent_id: If specified, returns the node executions for the ``unique_parent_id`` node id.
:rtype: list[flytekit.models.node_execution.NodeExecution], Text
"""
exec_list = super(SynchronousFlyteClient, self).list_node_executions_paginated(
Expand All @@ -689,6 +694,7 @@ def list_node_executions(
token=token,
filters=_filters.FilterList(filters or []).to_flyte_idl(),
sort_by=None if sort_by is None else sort_by.to_flyte_idl(),
unique_parent_id=unique_parent_id,
)
)
return (
Expand Down
2 changes: 2 additions & 0 deletions flytekit/clients/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ def iterate_node_executions(
task_execution_identifier=None,
limit=None,
filters=None,
unique_parent_id=None,
):
"""
This returns a generator for node executions.
Expand All @@ -26,6 +27,7 @@ def iterate_node_executions(
limit=num_to_fetch,
token=token,
filters=filters,
unique_parent_id=unique_parent_id,
)
else:
node_execs, next_token = client.list_node_executions_for_task_paginated(
Expand Down
2 changes: 1 addition & 1 deletion flytekit/control_plane/component_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def promote_from_model(
base_model.reference.version,
)

if base_model.launch_plan_ref is not None:
if base_model.launchplan_ref is not None:
return cls(flyte_launch_plan=_launch_plan.FlyteLaunchPlan.fetch(*fetch_args))
elif base_model.sub_workflow_ref is not None:
# the workflow tempaltes for sub-workflows should have been included in the original response
Expand Down
101 changes: 70 additions & 31 deletions flytekit/control_plane/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from flyteidl.core import literals_pb2 as _literals_pb2

import flytekit
from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.common import constants as _constants
from flytekit.common import utils as _common_utils
from flytekit.common.exceptions import system as _system_exceptions
Expand All @@ -15,7 +15,7 @@
from flytekit.common.utils import _dnsify
from flytekit.control_plane import component_nodes as _component_nodes
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane.tasks import executions as _task_executions
from flytekit.control_plane.tasks.executions import FlyteTaskExecution
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.promise import NodeOutput
from flytekit.core.type_engine import TypeEngine
Expand Down Expand Up @@ -154,58 +154,59 @@ def with_overrides(self, *args, **kwargs):
raise NotImplementedError("Overrides are not supported in Flyte yet.")

def __repr__(self) -> str:
return f"Node(ID: {self.id} Executable: {self._executable_flyte_object})"
return f"Node(ID: {self.id})"


class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact):
def __init__(self, *args, **kwargs):
super(FlyteNodeExecution, self).__init__(*args, **kwargs)
self._task_executions = None
self._workflow_executions = None
self._subworkflow_node_executions = None
self._inputs = None
self._outputs = None
self._interface = None

@property
def task_executions(self) -> List["flytekit.control_plane.tasks.executions.FlyteTaskExecution"]:
return self._task_executions or []

@property
def workflow_executions(self) -> List["flytekit.control_plane.workflow_executions.FlyteWorkflowExecution"]:
return self._workflow_executions or []
def subworkflow_node_executions(self) -> Dict[str, "flytekit.control_plane.nodes.FlyteNodeExecution"]:
return (
{}
if self._subworkflow_node_executions is None
else {n.id.node_id: n for n in self._subworkflow_node_executions}
)

@property
def executions(self) -> _artifact_mixin.ExecutionArtifact:
return self.task_executions or self.workflow_executions or []
def executions(self) -> List[_artifact_mixin.ExecutionArtifact]:
return self.task_executions or self._subworkflow_node_executions or []

@property
def inputs(self) -> Dict[str, Any]:
"""
Returns the inputs to the execution in the standard python format as dictated by the type engine.
"""
from flytekit.control_plane.tasks.task import FlyteTask

if self._inputs is None:
client = _flyte_engine.get_client()
execution_data = client.get_node_execution_data(self.id)
node_execution_data = client.get_node_execution_data(self.id)

# Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
if bool(execution_data.full_inputs.literals):
input_map = execution_data.full_inputs
elif execution_data.inputs.bytes > 0:
if bool(node_execution_data.full_inputs.literals):
input_map = node_execution_data.full_inputs
elif node_execution_data.inputs.bytes > 0:
with _common_utils.AutoDeletingTempDir() as tmp_dir:
tmp_name = _os.path.join(tmp_dir.name, "inputs.pb")
_data_proxy.Data.get_data(execution_data.inputs.url, tmp_name)
_data_proxy.Data.get_data(node_execution_data.inputs.url, tmp_name)
input_map = _literal_models.LiteralMap.from_flyte_idl(
_common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
)

task_id = self.task_executions[0].id.task_id
task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
self._inputs = TypeEngine.literal_map_to_kwargs(
ctx=FlyteContextManager.current_context(),
lm=input_map,
python_types=TypeEngine.guess_python_types(task.interface.inputs),
python_types=TypeEngine.guess_python_types(self.interface.inputs),
)
return self._inputs

Expand All @@ -216,8 +217,6 @@ def outputs(self) -> Dict[str, Any]:
:raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
"""
from flytekit.control_plane.tasks.task import FlyteTask

if not self.is_complete:
raise _user_exceptions.FlyteAssertion(
"Please wait until the node execution has completed before requesting the outputs."
Expand All @@ -241,12 +240,10 @@ def outputs(self) -> Dict[str, Any]:
_common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
)

task_id = self.task_executions[0].id.task_id
task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
self._outputs = TypeEngine.literal_map_to_kwargs(
ctx=FlyteContextManager.current_context(),
lm=output_map,
python_types=TypeEngine.guess_python_types(task.interface.outputs),
python_types=TypeEngine.guess_python_types(self.interface.outputs),
)
return self._outputs

Expand Down Expand Up @@ -279,18 +276,60 @@ def promote_from_model(cls, base_model: _node_execution_models.NodeExecution) ->
closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri, metadata=base_model.metadata
)

@property
def interface(self) -> "flytekit.control_plane.interface.TypedInterface":
"""
Return the interface of the task or subworkflow associated with this node execution.
"""
if self._interface is None:

from flytekit.control_plane.tasks.task import FlyteTask
from flytekit.control_plane.workflow import FlyteWorkflow

if not self.metadata.is_parent_node:
# if not a parent node, assume a task execution node
task_id = self.task_executions[0].id.task_id
task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
self._interface = task.interface
else:
# otherwise assume the node is associated with a subworkflow
client = _flyte_engine.get_client()

# need to get the FlyteWorkflow associated with this node execution (self), so we need to fetch the
# parent workflow and iterate through the parent's FlyteNodes to get the the FlyteWorkflow object
# representing the subworkflow. This allows us to get the interface for guessing the types of the
# inputs/outputs.
lp_id = client.get_execution(self.id.execution_id).spec.launch_plan
workflow = FlyteWorkflow.fetch(lp_id.project, lp_id.domain, lp_id.name, lp_id.version)
flyte_subworkflow_node: FlyteNode = [n for n in workflow.nodes if n.id == self.id.node_id][0]
self._interface = flyte_subworkflow_node.target.flyte_workflow.interface

return self._interface

def sync(self):
"""
Syncs the state of the underlying execution artifact with the state observed by the platform.
"""
if not self.is_complete or self.task_executions is not None:
client = _flyte_engine.get_client()
self._closure = client.get_node_execution(self.id).closure
self._task_executions = [
_task_executions.FlyteTaskExecution.promote_from_model(t)
for t in _iterate_task_executions(client, self.id)
]
# TODO: sync sub-workflows as well
if self.metadata.is_parent_node:
if not self.is_complete or self._subworkflow_node_executions is None:
self._subworkflow_node_executions = [
FlyteNodeExecution.promote_from_model(n)
for n in iterate_node_executions(
_flyte_engine.get_client(),
workflow_execution_identifier=self.id.execution_id,
unique_parent_id=self.id.node_id,
)
]
else:
if not self.is_complete or self._task_executions is None:
self._task_executions = [
FlyteTaskExecution.promote_from_model(t)
for t in iterate_task_executions(_flyte_engine.get_client(), self.id)
]

self._sync_closure()
for execution in self.executions:
execution.sync()

def _sync_closure(self):
"""
Expand Down
1 change: 1 addition & 0 deletions flytekit/control_plane/tasks/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane import interface as _interfaces
Expand Down
1 change: 1 addition & 0 deletions flytekit/control_plane/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from flytekit.common import constants as _constants
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane import interface as _interfaces
Expand Down
11 changes: 2 additions & 9 deletions flytekit/models/node_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,7 @@ def closure(self):
def metadata(self) -> NodeExecutionMetaData:
return self._metadata

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.node_execution_pb2.NodeExecution
"""
def to_flyte_idl(self) -> _node_execution_pb2.NodeExecution:
return _node_execution_pb2.NodeExecution(
id=self.id.to_flyte_idl(),
input_uri=self.input_uri,
Expand All @@ -168,11 +165,7 @@ def to_flyte_idl(self):
)

@classmethod
def from_flyte_idl(cls, p):
"""
:param flyteidl.admin.node_execution_pb2.NodeExecution p:
:rtype: NodeExecution
"""
def from_flyte_idl(cls, p: _node_execution_pb2.NodeExecution) -> "NodeExecution":
return cls(
id=_identifier.NodeExecutionIdentifier.from_flyte_idl(p.id),
input_uri=p.input_uri,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import typing

from flytekit import task, workflow


@task
def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
return a + 2, "world"


@workflow
def my_subwf(a: int = 42) -> (str, str):
x, y = t1(a=a)
u, v = t1(a=x)
return y, v


@workflow
def parent_wf(a: int) -> (int, str, str):
x, y = t1(a=a)
u, v = my_subwf(a=x)
return x, u, v


if __name__ == "__main__":
print(f"Running my_wf(a=3) {parent_wf(a=3)}")
21 changes: 21 additions & 0 deletions tests/flytekit/integration/control_plane/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,24 @@ def test_monitor_workflow(flyteclient, flyte_workflows_register):
assert execution.node_executions["n0"].task_executions[0].outputs["o0"] == "hello world"
assert execution.inputs == {}
assert execution.outputs["o0"] == "hello world"


def test_launch_workflow_with_subworkflows(flyteclient, flyte_workflows_register):
execution = launch_plan.FlyteLaunchPlan.fetch(
PROJECT, "development", "workflows.basic.subworkflows.parent_wf", f"v{VERSION}"
).launch_with_literals(
PROJECT,
"development",
literals.LiteralMap({"a": literals.Literal(literals.Scalar(literals.Primitive(integer=101)))}),
)
execution.wait_for_completion()
# check node execution inputs and outputs
assert execution.node_executions["n0"].inputs == {"a": 101}
assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"}
assert execution.node_executions["n1"].inputs == {"a": 103}
assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"}

# check subworkflow task execution inputs and outputs
subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}

0 comments on commit 7a13759

Please sign in to comment.