expose subworkflow inputs, outputs in node execution #503

Merged (8 commits) on Jun 4, 2021
Changes from 3 commits
11 changes: 7 additions & 4 deletions flytekit/clients/friendly.py
@@ -670,16 +670,18 @@ def list_node_executions(
token=None,
filters=None,
sort_by=None,
unique_parent_id: str = None,
):
"""
TODO: Comment
"""Get node executions associated with a given workflow execution.

:param flytekit.models.core.identifier.WorkflowExecutionIdentifier workflow_execution_identifier:
:param int limit:
:param Text token: [Optional] If specified, this specifies where in the rows of results to skip before reading.
If you previously retrieved a page response with token="foo" and you want the next page,
specify token="foo".
If you previously retrieved a page response with token="foo" and you want the next page,
specify token="foo".
:param list[flytekit.models.filters.Filter] filters:
:param flytekit.models.admin.common.Sort sort_by: [Optional] If provided, the results will be sorted.
:param Text unique_parent_id: [Optional] If specified, only node executions whose parent node has this id are returned.
:rtype: list[flytekit.models.node_execution.NodeExecution], Text
"""
exec_list = super(SynchronousFlyteClient, self).list_node_executions_paginated(
@@ -689,6 +691,7 @@ def list_node_executions(
token=token,
filters=_filters.FilterList(filters or []).to_flyte_idl(),
sort_by=None if sort_by is None else sort_by.to_flyte_idl(),
unique_parent_id=unique_parent_id,
)
)
return (
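For context, a minimal sketch of how a caller might page through only the node executions spawned under a given parent node using the new unique_parent_id argument; the endpoint, identifiers, and parent node id here are illustrative assumptions, not part of this diff.

from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.models.core.identifier import WorkflowExecutionIdentifier

# hypothetical Admin endpoint and workflow execution identifier
client = SynchronousFlyteClient("localhost:30081", insecure=True)
wf_exec_id = WorkflowExecutionIdentifier(project="flytesnacks", domain="development", name="example-exec-name")

children, token = [], ""
while True:
    page, token = client.list_node_executions(
        workflow_execution_identifier=wf_exec_id,
        limit=100,
        token=token,
        unique_parent_id="n1",  # only children of parent node "n1"
    )
    children.extend(page)
    if not token:
        break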
2 changes: 2 additions & 0 deletions flytekit/clients/helpers.py
@@ -4,6 +4,7 @@ def iterate_node_executions(
task_execution_identifier=None,
limit=None,
filters=None,
unique_parent_id=None,
):
"""
This returns a generator for node executions.
@@ -26,6 +27,7 @@
limit=num_to_fetch,
token=token,
filters=filters,
unique_parent_id=unique_parent_id,
)
else:
node_execs, next_token = client.list_node_executions_for_task_paginated(
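The generator variant wraps the same call and handles token paging internally; a minimal sketch, reusing the client and wf_exec_id assumed in the previous snippet.

from flytekit.clients.helpers import iterate_node_executions

# lazily yields every node execution under parent node "n1", fetching pages as needed
for node_execution in iterate_node_executions(
    client,
    workflow_execution_identifier=wf_exec_id,
    unique_parent_id="n1",
):
    print(node_execution.id.node_id, node_execution.closure.phase)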
2 changes: 1 addition & 1 deletion flytekit/control_plane/component_nodes.py
@@ -112,7 +112,7 @@ def promote_from_model(
base_model.reference.version,
)

if base_model.launch_plan_ref is not None:
if base_model.launchplan_ref is not None:
return cls(flyte_launch_plan=_launch_plan.FlyteLaunchPlan.fetch(*fetch_args))
elif base_model.sub_workflow_ref is not None:
# the workflow templates for sub-workflows should have been included in the original response
79 changes: 52 additions & 27 deletions flytekit/control_plane/nodes.py
@@ -5,7 +5,7 @@
from flyteidl.core import literals_pb2 as _literals_pb2

import flytekit
from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.common import constants as _constants
from flytekit.common import utils as _common_utils
from flytekit.common.exceptions import system as _system_exceptions
@@ -15,7 +15,7 @@
from flytekit.common.utils import _dnsify
from flytekit.control_plane import component_nodes as _component_nodes
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane.tasks import executions as _task_executions
from flytekit.control_plane.tasks.executions import FlyteTaskExecution
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.promise import NodeOutput
from flytekit.core.type_engine import TypeEngine
@@ -154,14 +154,14 @@ def with_overrides(self, *args, **kwargs):
raise NotImplementedError("Overrides are not supported in Flyte yet.")

def __repr__(self) -> str:
return f"Node(ID: {self.id} Executable: {self._executable_flyte_object})"
return f"Node(ID: {self.id})"


class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact):
def __init__(self, *args, **kwargs):
super(FlyteNodeExecution, self).__init__(*args, **kwargs)
self._task_executions = None
self._workflow_executions = None
self._subworkflow_node_executions = None
self._inputs = None
self._outputs = None

@@ -170,42 +170,42 @@ def task_executions(self) -> List["flytekit.control_plane.tasks.executions.Flyte
return self._task_executions or []

@property
def workflow_executions(self) -> List["flytekit.control_plane.workflow_executions.FlyteWorkflowExecution"]:
return self._workflow_executions or []
def subworkflow_node_executions(self) -> Dict[str, "flytekit.control_plane.nodes.FlyteNodeExecution"]:
return (
{}
if self._subworkflow_node_executions is None
else {n.id.node_id: n for n in self._subworkflow_node_executions}
)

@property
def executions(self) -> _artifact_mixin.ExecutionArtifact:
return self.task_executions or self.workflow_executions or []
def executions(self) -> List[_artifact_mixin.ExecutionArtifact]:
return self.task_executions or list(self.subworkflow_node_executions.values()) or []
Contributor: is the latter half of this just self._subworkflow_node_executions? why the transformations?

Contributor (author): _subworkflow_node_executions is a Dict mapping node id to node execution, so it needs to be transformed to output a list of node executions.

We could alternatively have executions output a Union[List, Dict] depending on whether it's task executions or subworkflow node executions.

Contributor (author): but the property subworkflow_node_executions is a Dict (see above)... we don't have to convert it into a dict, but I found it convenient to be able to access the subworkflow node executions by node_id

Contributor: right but we don't necessarily need to access these through the property since it's all in the same class, right? sorry, this is a minor performance nit and not really worth refactoring :)

@property
def inputs(self) -> Dict[str, Any]:
"""
Returns the inputs to the execution in the standard python format as dictated by the type engine.
"""
from flytekit.control_plane.tasks.task import FlyteTask

if self._inputs is None:
client = _flyte_engine.get_client()
execution_data = client.get_node_execution_data(self.id)
node_execution_data = client.get_node_execution_data(self.id)

# Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
if bool(execution_data.full_inputs.literals):
input_map = execution_data.full_inputs
elif execution_data.inputs.bytes > 0:
if bool(node_execution_data.full_inputs.literals):
input_map = node_execution_data.full_inputs
elif node_execution_data.inputs.bytes > 0:
with _common_utils.AutoDeletingTempDir() as tmp_dir:
tmp_name = _os.path.join(tmp_dir.name, "inputs.pb")
_data_proxy.Data.get_data(execution_data.inputs.url, tmp_name)
_data_proxy.Data.get_data(node_execution_data.inputs.url, tmp_name)
input_map = _literal_models.LiteralMap.from_flyte_idl(
_common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
)

task_id = self.task_executions[0].id.task_id
task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
self._inputs = TypeEngine.literal_map_to_kwargs(
ctx=FlyteContextManager.current_context(),
lm=input_map,
python_types=TypeEngine.guess_python_types(task.interface.inputs),
python_types=TypeEngine.guess_python_types(self.get_interface().inputs),
)
return self._inputs

@@ -216,8 +216,6 @@ def outputs(self) -> Dict[str, Any]:

:raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
"""
from flytekit.control_plane.tasks.task import FlyteTask

if not self.is_complete:
raise _user_exceptions.FlyteAssertion(
"Please wait until the node execution has completed before requesting the outputs."
@@ -241,12 +239,10 @@ def outputs(self) -> Dict[str, Any]:
_common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
)

task_id = self.task_executions[0].id.task_id
task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
self._outputs = TypeEngine.literal_map_to_kwargs(
ctx=FlyteContextManager.current_context(),
lm=output_map,
python_types=TypeEngine.guess_python_types(task.interface.outputs),
python_types=TypeEngine.guess_python_types(self.get_interface().outputs),
)
return self._outputs

@@ -279,18 +275,47 @@ def promote_from_model(cls, base_model: _node_execution_models.NodeExecution) ->
closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri, metadata=base_model.metadata
)

def get_interface(self) -> "flytekit.control_plane.interface.TypedInterface":
"""
Return the interface of the task or subworkflow associated with this node execution.
"""
from flytekit.control_plane.tasks.task import FlyteTask
from flytekit.control_plane.workflow import FlyteWorkflow

if not self.metadata.is_parent_node:
# if not a parent node, assume a task execution node
task_id = self.task_executions[0].id.task_id
task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
return task.interface

# otherwise assume the node is associated with a subworkflow
client = _flyte_engine.get_client()
lp_id = client.get_execution(self.id.execution_id).spec.launch_plan
workflow = FlyteWorkflow.fetch(lp_id.project, lp_id.domain, lp_id.name, lp_id.version)
flyte_subworkflow_node: FlyteNode = [n for n in workflow.nodes if n.id == self.id.node_id][0]
return flyte_subworkflow_node.target.flyte_workflow.interface

def sync(self):
"""
Syncs the state of the underlying execution artifact with the state observed by the platform.
"""
if not self.is_complete or self.task_executions is not None:
if not self.is_complete or self._task_executions is None:
client = _flyte_engine.get_client()
self._closure = client.get_node_execution(self.id).closure
self._task_executions = [
_task_executions.FlyteTaskExecution.promote_from_model(t)
for t in _iterate_task_executions(client, self.id)
FlyteTaskExecution.promote_from_model(t) for t in iterate_task_executions(client, self.id)
]
# TODO: sync sub-workflows as well

if self.metadata.is_parent_node and (not self.is_complete or self._subworkflow_node_executions is None):
self._subworkflow_node_executions = [
FlyteNodeExecution.promote_from_model(n)
for n in iterate_node_executions(
client, workflow_execution_identifier=self.id.execution_id, unique_parent_id=self.id.node_id
)
]

for execution in self.executions:
execution.sync()

def _sync_closure(self):
"""
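Putting the new surface area together: a minimal sketch of the access pattern discussed in the review thread above, assuming execution is a completed FlyteWorkflowExecution of parent_wf (as in the integration test at the end of this diff); the node ids and values are taken from that test.

# "n1" is the parent node that launched the my_subwf subworkflow
parent_node = execution.node_executions["n1"]
parent_node.sync()  # fills task executions and, for parent nodes, subworkflow node executions

# inputs/outputs are decoded via get_interface(), which resolves either the task
# interface or the subworkflow interface depending on metadata.is_parent_node
assert parent_node.inputs == {"a": 103}
assert parent_node.outputs == {"o0": "world", "o1": "world"}

# child node executions of the subworkflow, keyed by node id
for node_id, child in parent_node.subworkflow_node_executions.items():
    print(node_id, child.outputs)

# executions flattens either the task executions or the subworkflow children into a list,
# which is why the property converts the dict back into its values
for child in parent_node.executions:
    child.sync()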
1 change: 1 addition & 0 deletions flytekit/control_plane/tasks/task.py
@@ -1,4 +1,5 @@
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane import interface as _interfaces
1 change: 1 addition & 0 deletions flytekit/control_plane/workflow.py
@@ -3,6 +3,7 @@
from flytekit.common import constants as _constants
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane import interface as _interfaces
11 changes: 2 additions & 9 deletions flytekit/models/node_execution.py
@@ -156,10 +156,7 @@ def closure(self):
def metadata(self) -> NodeExecutionMetaData:
return self._metadata

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.node_execution_pb2.NodeExecution
"""
def to_flyte_idl(self) -> _node_execution_pb2.NodeExecution:
return _node_execution_pb2.NodeExecution(
id=self.id.to_flyte_idl(),
input_uri=self.input_uri,
@@ -168,11 +165,7 @@ def to_flyte_idl(self):
)

@classmethod
def from_flyte_idl(cls, p):
"""
:param flyteidl.admin.node_execution_pb2.NodeExecution p:
:rtype: NodeExecution
"""
def from_flyte_idl(cls, p: _node_execution_pb2.NodeExecution) -> "NodeExecution":
return cls(
id=_identifier.NodeExecutionIdentifier.from_flyte_idl(p.id),
input_uri=p.input_uri,
26 changes: 26 additions & 0 deletions (new file)
@@ -0,0 +1,26 @@
import typing

from flytekit import task, workflow


@task
def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
return a + 2, "world"


@workflow
def my_subwf(a: int = 42) -> (str, str):
x, y = t1(a=a)
u, v = t1(a=x)
return y, v


@workflow
def parent_wf(a: int) -> (int, str, str):
x, y = t1(a=a)
u, v = my_subwf(a=x)
return x, u, v


if __name__ == "__main__":
print(f"Running my_wf(a=3) {parent_wf(a=3)}")
21 changes: 21 additions & 0 deletions tests/flytekit/integration/control_plane/test_workflow.py
@@ -102,3 +102,24 @@ def test_monitor_workflow(flyteclient, flyte_workflows_register):
assert execution.node_executions["n0"].task_executions[0].outputs["o0"] == "hello world"
assert execution.inputs == {}
assert execution.outputs["o0"] == "hello world"


def test_launch_workflow_with_subworkflows(flyteclient, flyte_workflows_register):
execution = launch_plan.FlyteLaunchPlan.fetch(
PROJECT, "development", "workflows.basic.subworkflows.parent_wf", f"v{VERSION}"
).launch_with_literals(
PROJECT,
"development",
literals.LiteralMap({"a": literals.Literal(literals.Scalar(literals.Primitive(integer=101)))}),
)
execution.wait_for_completion()
# check node execution inputs and outputs
assert execution.node_executions["n0"].inputs == {"a": 101}
assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"}
assert execution.node_executions["n1"].inputs == {"a": 103}
assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"}

# check subworkflow node execution inputs and outputs
subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}