Make fetched entities callable within workflows #867

Merged
merged 12 commits into from
Mar 2, 2022
117 changes: 110 additions & 7 deletions flytekit/core/promise.py
@@ -548,11 +548,21 @@ def __rshift__(self, other: Any):
return Output(*promises) # type: ignore


def binding_from_flyte_std(
ctx: _flyte_context.FlyteContext,
var_name: str,
expected_literal_type: _type_models.LiteralType,
t_value: typing.Any,
) -> _literals_models.Binding:
binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type=None)
return _literals_models.Binding(var=var_name, binding=binding_data)


def binding_data_from_python_std(
ctx: _flyte_context.FlyteContext,
expected_literal_type: _type_models.LiteralType,
t_value: typing.Any,
t_value_type: type,
t_value_type: Optional[type] = None,
) -> _literals_models.BindingData:
# This handles the case where the given value is the output of another task
if isinstance(t_value, Promise):
@@ -568,7 +578,7 @@ def binding_data_from_python_std(
if expected_literal_type.collection_type is None:
raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}")

sub_type = ListTransformer.get_sub_type(t_value_type)
sub_type = ListTransformer.get_sub_type(t_value_type) if t_value_type else None
collection = _literals_models.BindingDataCollection(
bindings=[
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type) for t in t_value
@@ -585,11 +595,11 @@ def binding_data_from_python_std(
raise AssertionError(
f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}"
)
k_type, v_type = DictTransformer.get_dict_types(t_value_type)
if expected_literal_type.simple == _type_models.SimpleType.STRUCT:
lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
return _literals_models.BindingData(scalar=lit.scalar)
else:
_, v_type = DictTransformer.get_dict_types(t_value_type) if t_value_type else (None, None)
m = _literals_models.BindingDataMap(
bindings={
k: binding_data_from_python_std(ctx, expected_literal_type.map_value_type, v, v_type)
@@ -607,7 +617,7 @@ def binding_data_from_python_std(
)

# This is the scalar case - e.g. my_task(in1=5)
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar
return _literals_models.BindingData(scalar=scalar)


@@ -703,7 +713,8 @@ def __init__(self, node: Node, var: str):
@property
def node_id(self):
"""
Override the underlying node_id property to refer to SdkNode.
Override the underlying node_id property to refer to the Node's id. This ensures that node IDs
overridden via with_overrides are serialized correctly.
:rtype: Text
"""
return self.node.id
Expand Down Expand Up @@ -731,6 +742,19 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
...


class HasFlyteInterface(Protocol):
@property
def name(self) -> str:
...

@property
def interface(self) -> _interface_models.TypedInterface:
...

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
...
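
Because HasFlyteInterface is a typing.Protocol, remotely fetched entities satisfy it structurally rather than by inheritance. A minimal conformance sketch, using stand-in types so the snippet stays self-contained (the real protocol expects flytekit's TypedInterface and NodeMetadata model classes):

```python
# Illustration only: StubTypedInterface / StubNodeMetadata stand in for
# flytekit's _interface_models.TypedInterface and _workflow_model.NodeMetadata.
class StubTypedInterface:
    pass


class StubNodeMetadata:
    pass


class StubRemoteTask:
    """Structurally satisfies HasFlyteInterface without inheriting from it."""

    @property
    def name(self) -> str:
        return "project.domain.example_task"

    @property
    def interface(self) -> StubTypedInterface:
        return StubTypedInterface()

    def construct_node_metadata(self) -> StubNodeMetadata:
        return StubNodeMetadata()
```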


def extract_obj_name(name: str) -> str:
"""
Generates a shortened name, without the module information. Useful for node-names etc. Only extracts the final
@@ -743,6 +767,87 @@ def extract_obj_name(name: str) -> str:
return name


def create_and_link_node_from_remote(
ctx: FlyteContext,
entity: HasFlyteInterface,
**kwargs,
):
"""
This method is used to generate a node with bindings. This is not used in the execution path.
"""
if ctx.compilation_state is None:
raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...")

used_inputs = set()
bindings = []

typed_interface = entity.interface

for k in sorted(typed_interface.inputs):
var = typed_interface.inputs[k]
if k not in kwargs:
raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
# Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
# into the function.
if isinstance(v, tuple):
raise AssertionError(
f"Variable({k}) for function({entity.name}) cannot receive a multi-valued tuple {v}."
f" Check if the predecessor function is returning more than one value."
)
try:
bindings.append(
binding_from_flyte_std(
ctx,
var_name=k,
expected_literal_type=var.type,
t_value=v,
)
)
used_inputs.add(k)
except Exception as e:
raise AssertionError(f"Failed to bind variable {k} for function {entity.name}.") from e

extra_inputs = used_inputs ^ set(kwargs.keys())
if len(extra_inputs) > 0:
raise _user_exceptions.FlyteAssertion(
"Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs)
)

# Detect upstream nodes
# These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
upstream_nodes = list(
set(
[
input_val.ref.node
for input_val in kwargs.values()
if isinstance(input_val, Promise) and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID
]
)
)

flytekit_node = Node(
# TODO: Better naming, probably a derivative of the function name.
id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
metadata=entity.construct_node_metadata(),
bindings=sorted(bindings, key=lambda b: b.var),
upstream_nodes=upstream_nodes,
flyte_entity=entity,
)
ctx.compilation_state.add_node(flytekit_node)

if len(typed_interface.outputs) == 0:
return VoidPromise(entity.name)

# Create a node output object for each output, they should all point to this node of course.
node_outputs = []
for output_name, output_var_model in typed_interface.outputs.items():
node_outputs.append(Promise(output_name, NodeOutput(node=flytekit_node, var=output_name)))

return create_task_output(node_outputs)


def create_and_link_node(
ctx: FlyteContext,
entity: SupportsNodeCreation,
@@ -819,8 +924,6 @@ def create_and_link_node(
# Create a node output object for each output, they should all point to this node of course.
node_outputs = []
for output_name, output_var_model in typed_interface.outputs.items():
# TODO: If node id gets updated later, we have to make sure to update the NodeOutput model's ID, which
# is currently just a static str
node_outputs.append(Promise(output_name, NodeOutput(node=flytekit_node, var=output_name)))
# Don't print this; it may crash because sdk_node._upstream_node_ids might be None, and the IDL code will break

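Together with RemoteEntity.__call__ (added below in remote_callable.py), create_and_link_node_from_remote is what lets a fetched entity be invoked like a locally defined one at workflow compile time. A rough usage sketch; the remote construction, project/domain, task name, version, and input/output names are placeholders and may differ across flytekit versions:

```python
from flytekit import workflow
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

# Placeholder setup; adjust for your deployment and flytekit version.
remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
fetched_task = remote.fetch_task(name="my.tasks.double", version="v1")


@workflow
def wf(a: int) -> int:
    # During workflow compilation this call dispatches to compile(), which in
    # turn calls create_and_link_node_from_remote() to create a node with bindings.
    return fetched_task(x=a)
```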
7 changes: 6 additions & 1 deletion flytekit/core/python_function_task.py
@@ -210,7 +210,12 @@ def compile_into_workflow(
for entity, model in model_entities.items():
# We only care about gathering tasks here. Launch plans are handled by
# propeller. Subworkflows should already be in the workflow spec.
if not isinstance(entity, Task):
if not isinstance(entity, Task) and not isinstance(entity, task_models.TaskTemplate):
continue

# Handle FlyteTask
if isinstance(entity, task_models.TaskTemplate):
tts.append(entity)
continue

# We are currently not supporting reference tasks since these will
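This compile_into_workflow change is what picks up the TaskTemplate of a fetched task when a @dynamic workflow references it, so the serialized dynamic job spec carries the template along. A hedged sketch of such a dynamic workflow; fetched_task is the placeholder from the earlier sketch:

```python
from typing import List

from flytekit import dynamic


@dynamic
def fan_out(n: int) -> List[int]:
    # Each call creates a node in the dynamically compiled sub-workflow; the
    # fetched task's TaskTemplate is gathered by compile_into_workflow() above.
    results = []
    for i in range(n):
        results.append(fetched_task(x=i))  # placeholder fetched entity
    return results
```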
45 changes: 8 additions & 37 deletions flytekit/remote/launch_plan.py
@@ -1,60 +1,32 @@
from __future__ import annotations

from typing import Optional

from flytekit.core import hash as hash_mixin
from flytekit.core.interface import Interface
from flytekit.core.launch_plan import ReferenceLaunchPlan
from flytekit.core.type_engine import TypeEngine
from flytekit.loggers import remote_logger as logger
from flytekit.models import interface as _interface_models
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models.core import identifier as id_models
from flytekit.remote import interface as _interface
from flytekit.remote.remote_callable import RemoteEntity


class FlyteLaunchPlan(_launch_plan_models.LaunchPlanSpec):
class FlyteLaunchPlan(hash_mixin.HashOnReferenceMixin, RemoteEntity, _launch_plan_models.LaunchPlanSpec):
"""A class encapsulating a remote Flyte launch plan."""

def __init__(self, id, *args, **kwargs):
super(FlyteLaunchPlan, self).__init__(*args, **kwargs)
# Set all the attributes we expect this class to have
self._id = id
self._name = id.name

# The interface is not set explicitly unless fetched in an engine context
self._interface = None
self._python_interface = None
self._reference_entity = None

def __call__(self, *args, **kwargs):
if self.reference_entity is None:
logger.warning(
f"FlyteLaunchPlan {self} is not callable, most likely because flytekit could not "
f"guess the python interface. The workflow calling this launch plan may not behave correctly."
)
return
return self.reference_entity(*args, **kwargs)

# TODO: Refactor behind mixin
@property
def reference_entity(self) -> Optional[ReferenceLaunchPlan]:
if self._reference_entity is None:
if self.guessed_python_interface is None:
try:
self.guessed_python_interface = Interface(
TypeEngine.guess_python_types(self.interface.inputs),
TypeEngine.guess_python_types(self.interface.outputs),
)
except Exception as e:
logger.warning(f"Error backing out interface {e}, Flyte interface {self.interface}")
return None

self._reference_entity = ReferenceLaunchPlan(
self.id.project,
self.id.domain,
self.id.name,
self.id.version,
inputs=self.guessed_python_interface.inputs,
outputs=self.guessed_python_interface.outputs,
)
return self._reference_entity
def name(self) -> str:
return self._name

@classmethod
def promote_from_model(
@@ -71,7 +43,6 @@ def promote_from_model(
auth_role=model.auth_role,
raw_output_data_config=model.raw_output_data_config,
)

return lp

@property
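With FlyteLaunchPlan now mixing in RemoteEntity, a fetched launch plan can be called inside a workflow directly, instead of routing through the removed reference_entity shim. A rough sketch, reusing the placeholder remote object from the earlier example:

```python
from flytekit import workflow

# Placeholder names; `remote` is the FlyteRemote instance from the earlier sketch.
fetched_lp = remote.fetch_launch_plan(name="my.workflows.child_wf", version="v1")


@workflow
def parent_wf(a: int) -> int:
    # The call goes through RemoteEntity.__call__ and, when compiled, produces a
    # node that launches the child execution on the Flyte cluster.
    return fetched_lp(x=a)
```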
61 changes: 61 additions & 0 deletions flytekit/remote/remote_callable.py
@@ -0,0 +1,61 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple, Union

from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext
from flytekit.core.promise import Promise, VoidPromise, create_and_link_node_from_remote, extract_obj_name
from flytekit.exceptions import user as user_exceptions
from flytekit.loggers import remote_logger as logger
from flytekit.models.core.workflow import NodeMetadata


class RemoteEntity(ABC):
@property
@abstractmethod
def name(self) -> str:
...

def construct_node_metadata(self) -> NodeMetadata:
"""
Used when constructing the node that encapsulates this task as part of a broader workflow definition.
"""
return NodeMetadata(
name=extract_obj_name(self.name),
)

def compile(self, ctx: FlyteContext, *args, **kwargs):
return create_and_link_node_from_remote(ctx, entity=self, **kwargs) # noqa

def __call__(self, *args, **kwargs):
# When a Task is () aka __called__, there are three things we may do:
# a. Plain execution Mode - just run the execute function. If not overridden, we should raise an exception
# b. Compilation Mode - this happens when the function is called as part of a workflow (potentially
# dynamic task). Produce promise objects and create a node.
# c. Workflow Execution Mode - when a workflow is being run locally. Even though workflows are functions
# and everything should be able to be passed through naturally, we'll want to wrap output values of the
# function into objects, so that potential .with_cpu or other ancillary functions can be attached to do
# nothing. Subsequent tasks will have to know how to unwrap these. If by chance a non-Flyte task uses a
# task output as an input, things probably will fail pretty obviously.
# Since this is a reference entity, it still needs to be mocked; otherwise an exception will be raised.
if len(args) > 0:
raise user_exceptions.FlyteAssertion(
f"Cannot call remotely fetched entity with args - detected {len(args)} positional args {args}"
)

ctx = FlyteContext.current_context()
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
return self.compile(ctx, *args, **kwargs)
elif (
ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
):
if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
return
return self.local_execute(ctx, **kwargs)
else:
logger.debug("Fetched entity, running raw execute.")
return self.execute(**kwargs)

def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]:
raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.")

def execute(self, **kwargs) -> Any:
raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.")
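
Outside of compilation and local workflow execution, __call__ falls through to execute(), which raises by design; a quick sketch of that behavior with the placeholder fetched_task from above:

```python
try:
    # No compilation or local-execution context is active here, so __call__
    # dispatches to execute(), which raises.
    fetched_task(x=3)
except Exception as e:
    # "Remotely fetched entities cannot be run locally. You have to mock this out."
    print(e)
```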