Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expose subworkflow inputs, outputs in node execution #503

Merged
merged 8 commits into from
Jun 4, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing

import six as _six
from flyteidl.admin import common_pb2 as _common_pb2
from flyteidl.admin import execution_pb2 as _execution_pb2
Expand All @@ -19,6 +21,7 @@
from flytekit.models import node_execution as _node_execution
from flytekit.models import project as _project
from flytekit.models import task as _task
from flytekit.models.admin import common as _admin_common
from flytekit.models.admin import task_execution as _task_execution
from flytekit.models.admin import workflow as _workflow
from flytekit.models.core import identifier as _identifier
Expand Down Expand Up @@ -666,20 +669,22 @@ def get_node_execution_data(self, node_execution_identifier):
def list_node_executions(
self,
workflow_execution_identifier,
limit=100,
token=None,
filters=None,
sort_by=None,
limit: int = 100,
token: typing.Optional[str] = None,
filters: typing.List[_filters.Filter] = None,
sort_by: _admin_common.Sort = None,
unique_parent_id: str = None,
):
"""
TODO: Comment
"""Get node executions associated with a given workflow execution.

:param flytekit.models.core.identifier.WorkflowExecutionIdentifier workflow_execution_identifier:
:param int limit:
:param Text token: [Optional] If specified, this specifies where in the rows of results to skip before reading.
If you previously retrieved a page response with token="foo" and you want the next page,
specify token="foo".
:param limit: Limit the number of items returned in the response.
:param token: If specified, this specifies where in the rows of results to skip before reading.
If you previously retrieved a page response with token="foo" and you want the next page,
specify ``token="foo"``.
:param list[flytekit.models.filters.Filter] filters:
:param flytekit.models.admin.common.Sort sort_by: [Optional] If provided, the results will be sorted.
:param unique_parent_id: If specified, returns the node executions for the ``unique_parent_id`` node id.
:rtype: list[flytekit.models.node_execution.NodeExecution], Text
"""
exec_list = super(SynchronousFlyteClient, self).list_node_executions_paginated(
Expand All @@ -689,6 +694,7 @@ def list_node_executions(
token=token,
filters=_filters.FilterList(filters or []).to_flyte_idl(),
sort_by=None if sort_by is None else sort_by.to_flyte_idl(),
unique_parent_id=unique_parent_id,
)
)
return (
Expand Down
2 changes: 2 additions & 0 deletions flytekit/clients/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ def iterate_node_executions(
task_execution_identifier=None,
limit=None,
filters=None,
unique_parent_id=None,
):
"""
This returns a generator for node executions.
Expand All @@ -26,6 +27,7 @@ def iterate_node_executions(
limit=num_to_fetch,
token=token,
filters=filters,
unique_parent_id=unique_parent_id,
)
else:
node_execs, next_token = client.list_node_executions_for_task_paginated(
Expand Down
2 changes: 1 addition & 1 deletion flytekit/control_plane/component_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def promote_from_model(
base_model.reference.version,
)

if base_model.launch_plan_ref is not None:
if base_model.launchplan_ref is not None:
return cls(flyte_launch_plan=_launch_plan.FlyteLaunchPlan.fetch(*fetch_args))
elif base_model.sub_workflow_ref is not None:
# the workflow tempaltes for sub-workflows should have been included in the original response
Expand Down
101 changes: 70 additions & 31 deletions flytekit/control_plane/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from flyteidl.core import literals_pb2 as _literals_pb2

import flytekit
from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.common import constants as _constants
from flytekit.common import utils as _common_utils
from flytekit.common.exceptions import system as _system_exceptions
Expand All @@ -15,7 +15,7 @@
from flytekit.common.utils import _dnsify
from flytekit.control_plane import component_nodes as _component_nodes
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane.tasks import executions as _task_executions
from flytekit.control_plane.tasks.executions import FlyteTaskExecution
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.promise import NodeOutput
from flytekit.core.type_engine import TypeEngine
Expand Down Expand Up @@ -154,58 +154,59 @@ def with_overrides(self, *args, **kwargs):
raise NotImplementedError("Overrides are not supported in Flyte yet.")

def __repr__(self) -> str:
return f"Node(ID: {self.id} Executable: {self._executable_flyte_object})"
return f"Node(ID: {self.id})"


class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact):
def __init__(self, *args, **kwargs):
super(FlyteNodeExecution, self).__init__(*args, **kwargs)
self._task_executions = None
self._workflow_executions = None
self._subworkflow_node_executions = None
self._inputs = None
self._outputs = None
self._interface = None

@property
def task_executions(self) -> List["flytekit.control_plane.tasks.executions.FlyteTaskExecution"]:
return self._task_executions or []

@property
def workflow_executions(self) -> List["flytekit.control_plane.workflow_executions.FlyteWorkflowExecution"]:
return self._workflow_executions or []
def subworkflow_node_executions(self) -> Dict[str, "flytekit.control_plane.nodes.FlyteNodeExecution"]:
return (
{}
if self._subworkflow_node_executions is None
else {n.id.node_id: n for n in self._subworkflow_node_executions}
)

@property
def executions(self) -> _artifact_mixin.ExecutionArtifact:
return self.task_executions or self.workflow_executions or []
def executions(self) -> List[_artifact_mixin.ExecutionArtifact]:
return self.task_executions or self._subworkflow_node_executions or []

@property
def inputs(self) -> Dict[str, Any]:
"""
Returns the inputs to the execution in the standard python format as dictated by the type engine.
"""
from flytekit.control_plane.tasks.task import FlyteTask

if self._inputs is None:
client = _flyte_engine.get_client()
execution_data = client.get_node_execution_data(self.id)
node_execution_data = client.get_node_execution_data(self.id)

# Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
if bool(execution_data.full_inputs.literals):
input_map = execution_data.full_inputs
elif execution_data.inputs.bytes > 0:
if bool(node_execution_data.full_inputs.literals):
input_map = node_execution_data.full_inputs
elif node_execution_data.inputs.bytes > 0:
with _common_utils.AutoDeletingTempDir() as tmp_dir:
tmp_name = _os.path.join(tmp_dir.name, "inputs.pb")
_data_proxy.Data.get_data(execution_data.inputs.url, tmp_name)
_data_proxy.Data.get_data(node_execution_data.inputs.url, tmp_name)
input_map = _literal_models.LiteralMap.from_flyte_idl(
_common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
)

task_id = self.task_executions[0].id.task_id
task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
self._inputs = TypeEngine.literal_map_to_kwargs(
ctx=FlyteContextManager.current_context(),
lm=input_map,
python_types=TypeEngine.guess_python_types(task.interface.inputs),
python_types=TypeEngine.guess_python_types(self.interface.inputs),
)
return self._inputs

Expand All @@ -216,8 +217,6 @@ def outputs(self) -> Dict[str, Any]:

:raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
"""
from flytekit.control_plane.tasks.task import FlyteTask

if not self.is_complete:
raise _user_exceptions.FlyteAssertion(
"Please wait until the node execution has completed before requesting the outputs."
Expand All @@ -241,12 +240,10 @@ def outputs(self) -> Dict[str, Any]:
_common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
)

task_id = self.task_executions[0].id.task_id
task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
self._outputs = TypeEngine.literal_map_to_kwargs(
ctx=FlyteContextManager.current_context(),
lm=output_map,
python_types=TypeEngine.guess_python_types(task.interface.outputs),
python_types=TypeEngine.guess_python_types(self.interface.outputs),
)
return self._outputs

Expand Down Expand Up @@ -279,18 +276,60 @@ def promote_from_model(cls, base_model: _node_execution_models.NodeExecution) ->
closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri, metadata=base_model.metadata
)

@property
def interface(self) -> "flytekit.control_plane.interface.TypedInterface":
"""
Return the interface of the task or subworkflow associated with this node execution.
"""
if self._interface is None:

from flytekit.control_plane.tasks.task import FlyteTask
from flytekit.control_plane.workflow import FlyteWorkflow

if not self.metadata.is_parent_node:
# if not a parent node, assume a task execution node
task_id = self.task_executions[0].id.task_id
task = FlyteTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
self._interface = task.interface
else:
# otherwise assume the node is associated with a subworkflow
client = _flyte_engine.get_client()

# need to get the FlyteWorkflow associated with this node execution (self), so we need to fetch the
# parent workflow and iterate through the parent's FlyteNodes to get the the FlyteWorkflow object
# representing the subworkflow. This allows us to get the interface for guessing the types of the
# inputs/outputs.
lp_id = client.get_execution(self.id.execution_id).spec.launch_plan
workflow = FlyteWorkflow.fetch(lp_id.project, lp_id.domain, lp_id.name, lp_id.version)
flyte_subworkflow_node: FlyteNode = [n for n in workflow.nodes if n.id == self.id.node_id][0]
self._interface = flyte_subworkflow_node.target.flyte_workflow.interface

return self._interface

def sync(self):
"""
Syncs the state of the underlying execution artifact with the state observed by the platform.
"""
if not self.is_complete or self.task_executions is not None:
client = _flyte_engine.get_client()
self._closure = client.get_node_execution(self.id).closure
self._task_executions = [
_task_executions.FlyteTaskExecution.promote_from_model(t)
for t in _iterate_task_executions(client, self.id)
]
# TODO: sync sub-workflows as well
if self.metadata.is_parent_node:
Copy link
Contributor

@wild-endeavor wild-endeavor Jun 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this field true for the parent node of a dynamic workflow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still need to test that... in another PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be false for dynamic workflows

if not self.is_complete or self._subworkflow_node_executions is None:
self._subworkflow_node_executions = [
FlyteNodeExecution.promote_from_model(n)
for n in iterate_node_executions(
_flyte_engine.get_client(),
workflow_execution_identifier=self.id.execution_id,
unique_parent_id=self.id.node_id,
)
]
else:
if not self.is_complete or self._task_executions is None:
self._task_executions = [
FlyteTaskExecution.promote_from_model(t)
for t in iterate_task_executions(_flyte_engine.get_client(), self.id)
]

self._sync_closure()
for execution in self.executions:
execution.sync()

def _sync_closure(self):
"""
Expand Down
1 change: 1 addition & 0 deletions flytekit/control_plane/tasks/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane import interface as _interfaces
Expand Down
1 change: 1 addition & 0 deletions flytekit/control_plane/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from flytekit.common import constants as _constants
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane import interface as _interfaces
Expand Down
11 changes: 2 additions & 9 deletions flytekit/models/node_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,7 @@ def closure(self):
def metadata(self) -> NodeExecutionMetaData:
return self._metadata

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.node_execution_pb2.NodeExecution
"""
def to_flyte_idl(self) -> _node_execution_pb2.NodeExecution:
return _node_execution_pb2.NodeExecution(
id=self.id.to_flyte_idl(),
input_uri=self.input_uri,
Expand All @@ -168,11 +165,7 @@ def to_flyte_idl(self):
)

@classmethod
def from_flyte_idl(cls, p):
"""
:param flyteidl.admin.node_execution_pb2.NodeExecution p:
:rtype: NodeExecution
"""
def from_flyte_idl(cls, p: _node_execution_pb2.NodeExecution) -> "NodeExecution":
return cls(
id=_identifier.NodeExecutionIdentifier.from_flyte_idl(p.id),
input_uri=p.input_uri,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import typing

from flytekit import task, workflow


@task
def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
return a + 2, "world"


@workflow
def my_subwf(a: int = 42) -> (str, str):
x, y = t1(a=a)
u, v = t1(a=x)
return y, v


@workflow
def parent_wf(a: int) -> (int, str, str):
x, y = t1(a=a)
u, v = my_subwf(a=x)
return x, u, v


if __name__ == "__main__":
print(f"Running my_wf(a=3) {parent_wf(a=3)}")
21 changes: 21 additions & 0 deletions tests/flytekit/integration/control_plane/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,24 @@ def test_monitor_workflow(flyteclient, flyte_workflows_register):
assert execution.node_executions["n0"].task_executions[0].outputs["o0"] == "hello world"
assert execution.inputs == {}
assert execution.outputs["o0"] == "hello world"


def test_launch_workflow_with_subworkflows(flyteclient, flyte_workflows_register):
execution = launch_plan.FlyteLaunchPlan.fetch(
PROJECT, "development", "workflows.basic.subworkflows.parent_wf", f"v{VERSION}"
).launch_with_literals(
PROJECT,
"development",
literals.LiteralMap({"a": literals.Literal(literals.Scalar(literals.Primitive(integer=101)))}),
)
execution.wait_for_completion()
# check node execution inputs and outputs
assert execution.node_executions["n0"].inputs == {"a": 101}
assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"}
assert execution.node_executions["n1"].inputs == {"a": 103}
assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"}

# check subworkflow task execution inputs and outputs
subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}