Skip to content

Commit

Permalink
fix: merge conflict
Browse files Browse the repository at this point in the history
kennyworkman committed Dec 9, 2021
2 parents ebdadf3 + f28bb74 commit fe378f5
Showing 32 changed files with 1,371 additions and 841 deletions.
3 changes: 1 addition & 2 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
@@ -676,12 +676,11 @@ def get_node_execution(self, node_execution_identifier):
)
)

def get_node_execution_data(self, node_execution_identifier):
def get_node_execution_data(self, node_execution_identifier) -> _execution.NodeExecutionGetDataResponse:
"""
Returns signed URLs to LiteralMap blobs for a node execution's inputs and outputs (when available).
:param flytekit.models.core.identifier.NodeExecutionIdentifier node_execution_identifier:
:rtype: flytekit.models.execution.NodeExecutionGetDataResponse
"""
return _execution.NodeExecutionGetDataResponse.from_flyte_idl(
super(SynchronousFlyteClient, self).get_node_execution_data(
42 changes: 40 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
from google.protobuf.json_format import MessageToDict as _MessageToDict
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from typing_extensions import get_origin

@@ -30,7 +31,7 @@
from flytekit.models import types as _type_models
from flytekit.models.annotation import TypeAnnotation as _annotation_model
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Literal, LiteralCollection, LiteralMap, Primitive, Scalar
from flytekit.models.literals import Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Schema
from flytekit.models.types import LiteralType, SimpleType

T = typing.TypeVar("T")
@@ -229,7 +230,13 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
)
schema = None
try:
schema = JSONSchema().dump(cast(DataClassJsonMixin, t).schema())
s = cast(DataClassJsonMixin, t).schema()
for _, v in s.fields.items():
# marshmallow-jsonschema only supports enums loaded by name.
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
schema = JSONSchema().dump(s)
except Exception as e:
logger.warn("failed to extract schema for object %s, (will run schemaless) error: %s", str(t), e)

@@ -245,10 +252,39 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
raise AssertionError(
f"Dataclass {python_type} should be decorated with @dataclass_json to be " f"serialized correctly"
)
self._serialize_flyte_type(python_val, python_type)
return Literal(
scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct()))
)

def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
    """
    Run the Flyte type transformer over every Flyte-typed field of a dataclass value.

    Fields whose declared type is a ``FlyteSchema`` subclass are passed to
    ``FlyteSchemaTransformer.to_literal`` (the returned literal is discarded here);
    fields whose declared type is itself a dataclass are processed recursively.

    :param python_val: The dataclass instance being serialized.
    :param python_type: The dataclass type whose fields are inspected.
    """
    from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer

    for field in dataclasses.fields(python_type):
        field_value = getattr(python_val, field.name)
        field_type = field.type
        if inspect.isclass(field_type) and issubclass(field_type, FlyteSchema):
            FlyteSchemaTransformer().to_literal(FlyteContext.current_context(), field_value, field_type, None)
            continue
        if dataclasses.is_dataclass(field_type):
            # Nested dataclass: descend into its own fields.
            self._serialize_flyte_type(field_value, field_type)

def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type["FlyteSchema"]):
    """
    Run the Flyte type transformer over every Flyte-typed field of a deserialized dataclass.

    For each field declared as a ``FlyteSchema`` subclass, a schema ``Literal`` is rebuilt from the
    field value's ``remote_path`` and passed to ``FlyteSchemaTransformer.to_python_value`` (the result
    is discarded here). Fields declared as dataclasses are processed recursively.

    :param python_val: The dataclass instance produced from JSON.
    :param expected_python_type: The dataclass type whose fields are inspected.
    """
    from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer

    for field in dataclasses.fields(expected_python_type):
        field_value = getattr(python_val, field.name)
        field_type = field.type
        if inspect.isclass(field_type) and issubclass(field_type, FlyteSchema):
            transformer = FlyteSchemaTransformer()
            ctx = FlyteContext.current_context()
            schema_literal = Literal(
                scalar=Scalar(schema=Schema(field_value.remote_path, transformer._get_schema_type(field_type)))
            )
            transformer.to_python_value(ctx, schema_literal, field_type)
            continue
        if dataclasses.is_dataclass(field_type):
            # Nested dataclass: descend into its own fields.
            self._deserialize_flyte_type(field_value, field_type)

def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
if t == int:
return int(val)
@@ -291,7 +327,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Dataclass {expected_python_type} should be decorated with @dataclass_json to be "
f"serialized correctly"
)

dc = cast(DataClassJsonMixin, expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic))
self._deserialize_flyte_type(dc, expected_python_type)
return self._fix_dataclass_int(expected_python_type, dc)

def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
File renamed without changes.
224 changes: 224 additions & 0 deletions flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import datetime
import logging
import os
import re
import subprocess
import typing
from dataclasses import dataclass

from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.interface import Interface
from flytekit.core.python_function_task import PythonInstanceTask
from flytekit.core.task import TaskPlugins
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile


@dataclass
class OutputLocation:
    """
    Describes a single output of a shell task and where to find it on the file system.

    Args:
        var: str Name of the output variable.
        var_type: typing.Type Type of the output variable.
        location: os.PathLike Path the output is written to, or a template that references input
            variables and produces the path, e.g. ``"{{ .inputs.v }}.tmp.md"``.
            For an input ``v`` with path ``/tmp/abc.csv`` that template resolves to
            ``/tmp/abc.csv.tmp.md``.
    """

    var: str
    var_type: typing.Type
    location: typing.Union[os.PathLike, str]


def _stringify(v: typing.Any) -> str:
    """
    Convert a value to the string that gets substituted into a script template.

    ``FlyteFile`` and ``FlyteDirectory`` values are downloaded first and their ``path`` is returned;
    ``datetime.datetime`` values are rendered with ``isoformat()``; everything else goes through ``str``.
    """
    if isinstance(v, (FlyteFile, FlyteDirectory)):
        # Make the data available locally before handing its path to the script.
        v.download()
        return v.path
    if isinstance(v, datetime.datetime):
        return v.isoformat()
    return str(v)


def _interpolate(tmpl: str, regex: re.Pattern, validate_all_match: bool = True, **kwargs) -> str:
    """
    Return a copy of ``tmpl`` in which every placeholder matched by ``regex`` is replaced by its value.

    ``regex`` must expose two groups: the full placeholder expression and the variable name inside it.
    The variable name is looked up in ``kwargs`` and its stringified value replaces the placeholder.
    The input string itself is never modified.

    :raises ValueError: if a placeholder names a variable missing from ``kwargs``, or, when
        ``validate_all_match`` is set, if some entries of ``kwargs`` were never referenced by the template.
    """
    rendered = tmpl
    used = set()
    for m in regex.finditer(tmpl):
        placeholder = m.group(1)
        name = m.group(2)
        if name not in kwargs:
            raise ValueError(f"Variable {name} in Query (part of {placeholder}) not found in inputs {kwargs.keys()}")
        used.add(name)
        # str conversion should be deliberate, with right conversion for each type
        rendered = rendered.replace(placeholder, _stringify(kwargs[name]))

    if validate_all_match and len(used) < len(kwargs):
        unused = set(kwargs) - used
        raise ValueError(f"Extra Inputs have no matches in script template - missing {unused}")
    return rendered


def _dummy_task_func():
    """
    Placeholder function handed to the inner task plugin, which requires a task function.

    It does nothing and always returns ``None``.
    """


T = typing.TypeVar("T")


class ShellTask(PythonInstanceTask[T]):
    """
    Task that runs a shell script, supplied inline or as a file, after filling in template placeholders.

    Placeholders matching ``_INPUT_REGEX`` (e.g. ``{{ .inputs.x }}``) are replaced with the stringified
    task inputs; placeholders matching ``_OUTPUT_REGEX`` (e.g. ``{{ .outputs.y }}``) are replaced with the
    resolved output locations declared through ``output_locs``.
    """

    _INPUT_REGEX = re.compile(r"({{\s*.inputs.(\w+)\s*}})", re.IGNORECASE)
    _OUTPUT_REGEX = re.compile(r"({{\s*.outputs.(\w+)\s*}})", re.IGNORECASE)

    def __init__(
        self,
        name: str,
        debug: bool = False,
        script: typing.Optional[str] = None,
        script_file: typing.Optional[str] = None,
        task_config: T = None,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        output_locs: typing.Optional[typing.List[OutputLocation]] = None,
        **kwargs,
    ):
        """
        Args:
            name: str Name of the Task. Should be unique in the project
            debug: bool Print the generated script and other debugging information
            script: The actual script specified as a string
            script_file: A path to the file that contains the script (Only script or script_file) can be provided
            task_config: T Configuration for the task, can be either a Pod (or coming soon, BatchJob) config
            inputs: A Dictionary of input names to types
            output_locs: A list of :py:class:`OutputLocations`
            **kwargs: Other arguments that can be passed to :ref:class:`PythonInstanceTask`

        Raises:
            ValueError: if both or neither of ``script``/``script_file`` are given, if ``script_file`` does not
                exist, if ``task_config`` is not a Pod config, or if an entry of ``output_locs`` is invalid.
        """
        if script and script_file:
            raise ValueError("Only either of script or script_file can be provided")
        if not script and not script_file:
            raise ValueError("Either a script or script_file is needed")
        if script_file:
            if not os.path.exists(script_file):
                raise ValueError(f"FileNotFound: the specified Script file at path {script_file} cannot be loaded")
            script_file = os.path.abspath(script_file)

        if task_config is not None:
            # The config type is identified by its fully qualified class name rather than by isinstance
            # (presumably so the pod plugin does not have to be imported here).
            # str(type(obj)) renders as "<class 'pkg.mod.Name'>", so it can never equal the bare dotted
            # name; build the dotted name explicitly instead.
            config_cls = type(task_config)
            if f"{config_cls.__module__}.{config_cls.__name__}" != "flytekitplugins.pod.task.Pod":
                raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.")

        # Each instance of ShellTask instantiates an underlying task with a dummy function that will only be used
        # to run pre- and post- execute functions using the corresponding task plugin.
        # We rename the function name here to ensure the generated task has a unique name and avoid duplicate task name
        # errors.
        # This seem like a hack. We should use a plugin_class that doesn't require a fake-function to make work.
        plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
        self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
        # Rename the internal task so that there are no conflicts at serialization time. Technically these internal
        # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities
        # at serialization time.
        self._config_task_instance._name = f"_bash.{name}"
        self._script = script
        self._script_file = script_file
        self._debug = debug
        self._output_locs = output_locs if output_locs else []
        outputs = self._validate_output_locs()
        super().__init__(
            name,
            task_config,
            task_type=self._config_task_instance.task_type,
            interface=Interface(inputs=inputs, outputs=outputs),
            **kwargs,
        )

    def _validate_output_locs(self) -> typing.Dict[str, typing.Type]:
        """
        Check every configured output location and build the output-name to type mapping for the interface.

        Raises:
            ValueError: if an entry is ``None``, is not an ``OutputLocation``, has no location, or has a type
                that is neither a ``FlyteFile`` nor a ``FlyteDirectory`` subclass.
        """
        outputs = {}
        for v in self._output_locs:
            if v is None:
                raise ValueError("OutputLocation cannot be none")
            if not isinstance(v, OutputLocation):
                raise ValueError("Every output type should be an output location on the file-system")
            if v.location is None:
                raise ValueError(f"Output Location not provided for output var {v.var}")
            if not issubclass(v.var_type, FlyteFile) and not issubclass(v.var_type, FlyteDirectory):
                raise ValueError(
                    "Currently only outputs of type FlyteFile/FlyteDirectory and their derived types are supported"
                )
            outputs[v.var] = v.var_type
        return outputs

    @property
    def script(self) -> typing.Optional[str]:
        """The script text, if one was given inline or has been read from ``script_file`` by ``execute``."""
        return self._script

    @property
    def script_file(self) -> typing.Optional[os.PathLike]:
        """Absolute path of the script file, or ``None`` when the script was given inline."""
        return self._script_file

    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        """Delegate to the pre-execute hook of the inner config task instance."""
        return self._config_task_instance.pre_execute(user_params)

    def execute(self, **kwargs) -> typing.Any:
        """
        Executes the given script by substituting the inputs and outputs and extracts the outputs from the filesystem

        Returns:
            ``None`` when there are no outputs, the single output object when there is exactly one,
            otherwise a tuple of output objects in the order of ``output_locs``.

        Raises:
            subprocess.CalledProcessError: re-raised when the script exits with a non-zero return code.
        """
        logging.info(f"Running shell script as type {self.task_type}")
        if self.script_file:
            with open(self.script_file) as f:
                self._script = f.read()

        # Output locations may themselves reference inputs, so resolve them against the inputs first.
        outputs: typing.Dict[str, str] = {}
        for v in self._output_locs:
            outputs[v.var] = _interpolate(v.location, self._INPUT_REGEX, validate_all_match=False, **kwargs)

        gen_script = _interpolate(self._script, self._INPUT_REGEX, **kwargs)
        # For outputs it is not necessary that all outputs are used in the script, some are implicit outputs
        # for example gcc main.c will generate a.out automatically
        gen_script = _interpolate(gen_script, self._OUTPUT_REGEX, validate_all_match=False, **outputs)
        if self._debug:
            print("\n==============================================\n")
            print(gen_script)
            print("\n==============================================\n")

        try:
            # NOTE(review): input values are substituted verbatim into a command line run with shell=True;
            # callers must not feed untrusted input into a ShellTask.
            subprocess.check_call(gen_script, shell=True)
        except subprocess.CalledProcessError as e:
            files = os.listdir("./")
            fstr = "\n-".join(files)
            # NOTE: check_call does not capture output, so e.stderr / e.stdout are None here; the script's
            # own output goes straight to this process's stdout/stderr.
            logging.error(
                f"Failed to Execute Script, return-code {e.returncode} \n"
                f"StdErr: {e.stderr}\n"
                f"StdOut: {e.stdout}\n"
                f" Current directory contents: .\n-{fstr}"
            )
            raise

        final_outputs = []
        for v in self._output_locs:
            if issubclass(v.var_type, FlyteFile):
                final_outputs.append(FlyteFile(outputs[v.var]))
            if issubclass(v.var_type, FlyteDirectory):
                final_outputs.append(FlyteDirectory(outputs[v.var]))
        if len(final_outputs) == 1:
            return final_outputs[0]
        if len(final_outputs) > 1:
            return tuple(final_outputs)
        return None

    def post_execute(self, user_params: ExecutionParameters, rval: typing.Any) -> typing.Any:
        """Delegate to the post-execute hook of the inner config task instance."""
        return self._config_task_instance.post_execute(user_params, rval)
75 changes: 75 additions & 0 deletions flytekit/models/core/catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from flyteidl.core import catalog_pb2

from flytekit.models import common as _common_models
from flytekit.models.core import identifier as _identifier


class CatalogArtifactTag(_common_models.FlyteIdlEntity):
    """Python wrapper around the ``catalog_pb2.CatalogArtifactTag`` message."""

    def __init__(self, artifact_id: str, name: str):
        """
        :param artifact_id: Value stored in the ``artifact_id`` field of the message.
        :param name: Value stored in the ``name`` field of the message.
        """
        self._artifact_id = artifact_id
        self._name = name

    @property
    def artifact_id(self) -> str:
        return self._artifact_id

    @property
    def name(self) -> str:
        return self._name

    def to_flyte_idl(self) -> catalog_pb2.CatalogArtifactTag:
        """Build the protobuf message from this object."""
        pb = catalog_pb2.CatalogArtifactTag(artifact_id=self.artifact_id, name=self.name)
        return pb

    @classmethod
    def from_flyte_idl(cls, p: catalog_pb2.CatalogArtifactTag) -> "CatalogArtifactTag":
        """Build this wrapper from the protobuf message."""
        return cls(artifact_id=p.artifact_id, name=p.name)


class CatalogMetadata(_common_models.FlyteIdlEntity):
    """Python wrapper around the ``catalog_pb2.CatalogMetadata`` message."""

    def __init__(
        self,
        dataset_id: _identifier.Identifier,
        artifact_tag: CatalogArtifactTag,
        source_task_execution: _identifier.TaskExecutionIdentifier,
    ):
        """
        :param dataset_id: Identifier stored in the ``dataset_id`` field.
        :param artifact_tag: Tag stored in the ``artifact_tag`` field.
        :param source_task_execution: Task execution identifier stored in the ``source_task_execution`` field.
        """
        self._dataset_id = dataset_id
        self._artifact_tag = artifact_tag
        self._source_task_execution = source_task_execution

    @property
    def dataset_id(self) -> _identifier.Identifier:
        return self._dataset_id

    @property
    def artifact_tag(self) -> CatalogArtifactTag:
        return self._artifact_tag

    @property
    def source_task_execution(self) -> _identifier.TaskExecutionIdentifier:
        return self._source_task_execution

    @property
    def source_execution(self) -> _identifier.TaskExecutionIdentifier:
        """
        The source execution is a oneof in the IDL, but it currently has a single member,
        so this simply returns the source task execution.
        """
        return self._source_task_execution

    def to_flyte_idl(self) -> catalog_pb2.CatalogMetadata:
        """Build the protobuf message from this object."""
        dataset_pb = self.dataset_id.to_flyte_idl()
        tag_pb = self.artifact_tag.to_flyte_idl()
        source_pb = self.source_task_execution.to_flyte_idl()
        return catalog_pb2.CatalogMetadata(
            dataset_id=dataset_pb,
            artifact_tag=tag_pb,
            source_task_execution=source_pb,
        )

    @classmethod
    def from_flyte_idl(cls, pb: catalog_pb2.CatalogMetadata) -> "CatalogMetadata":
        """Build this wrapper from the protobuf message."""
        dataset_id = _identifier.Identifier.from_flyte_idl(pb.dataset_id)
        artifact_tag = CatalogArtifactTag.from_flyte_idl(pb.artifact_tag)
        # Add HasField check if more things are ever added to the one of
        source_task_execution = _identifier.TaskExecutionIdentifier.from_flyte_idl(pb.source_task_execution)
        return cls(
            dataset_id=dataset_id,
            artifact_tag=artifact_tag,
            source_task_execution=source_task_execution,
        )
24 changes: 19 additions & 5 deletions flytekit/models/execution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing

import flyteidl.admin.execution_pb2 as _execution_pb2
import flyteidl.admin.node_execution_pb2 as _node_execution_pb2
import flyteidl.admin.task_execution_pb2 as _task_execution_pb2
@@ -7,6 +9,7 @@
from flytekit.models import literals as _literals_models
from flytekit.models.core import execution as _core_execution
from flytekit.models.core import identifier as _identifier
from flytekit.models.node_execution import DynamicWorkflowNodeMetadata


class ExecutionMetadata(_common_models.FlyteIdlEntity):
@@ -238,7 +241,6 @@ class Execution(_common_models.FlyteIdlEntity):
def __init__(self, id, spec, closure):
"""
:param flytekit.models.core.identifier.WorkflowExecutionIdentifier id:
:param Text id:
:param ExecutionSpec spec:
:param ExecutionClosure closure:
"""
@@ -403,8 +405,8 @@ def __init__(self, inputs, outputs, full_inputs, full_outputs):
"""
:param _common_models.UrlBlob inputs:
:param _common_models.UrlBlob outputs:
:param _literals_pb2.LiteralMap full_inputs:
:param _literals_pb2.LiteralMap full_outputs:
:param _literals_models.LiteralMap full_inputs:
:param _literals_models.LiteralMap full_outputs:
"""
self._inputs = inputs
self._outputs = outputs
@@ -428,14 +430,14 @@ def outputs(self):
@property
def full_inputs(self):
"""
:rtype: _literals_pb2.LiteralMap
:rtype: _literals_models.LiteralMap
"""
return self._full_inputs

@property
def full_outputs(self):
"""
:rtype: _literals_pb2.LiteralMap
:rtype: _literals_models.LiteralMap
"""
return self._full_outputs

@@ -493,6 +495,14 @@ def to_flyte_idl(self):


class NodeExecutionGetDataResponse(_CommonDataResponse):
def __init__(self, *args, dynamic_workflow: typing.Optional[DynamicWorkflowNodeMetadata] = None, **kwargs):
super().__init__(*args, **kwargs)
self._dynamic_workflow = dynamic_workflow

@property
def dynamic_workflow(self) -> typing.Optional[DynamicWorkflowNodeMetadata]:
return self._dynamic_workflow

@classmethod
def from_flyte_idl(cls, pb2_object):
"""
@@ -504,6 +514,9 @@ def from_flyte_idl(cls, pb2_object):
outputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.outputs),
full_inputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_inputs),
full_outputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_outputs),
dynamic_workflow=DynamicWorkflowNodeMetadata.from_flyte_idl(pb2_object.dynamic_workflow)
if pb2_object.HasField("dynamic_workflow")
else None,
)

def to_flyte_idl(self):
@@ -515,4 +528,5 @@ def to_flyte_idl(self):
outputs=self.outputs.to_flyte_idl(),
full_inputs=self.full_inputs.to_flyte_idl(),
full_outputs=self.full_outputs.to_flyte_idl(),
dynamic_workflow=self.dynamic_workflow.to_flyte_idl() if self.dynamic_workflow else None,
)
115 changes: 114 additions & 1 deletion flytekit/models/node_execution.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,101 @@
import typing

import flyteidl.admin.node_execution_pb2 as _node_execution_pb2
import pytz as _pytz

from flytekit.models import common as _common_models
from flytekit.models.core import catalog as catalog_models
from flytekit.models.core import compiler as core_compiler_models
from flytekit.models.core import execution as _core_execution
from flytekit.models.core import identifier as _identifier


class WorkflowNodeMetadata(_common_models.FlyteIdlEntity):
    """Python wrapper around the ``node_execution_pb2.WorkflowNodeMetadata`` message."""

    def __init__(self, execution_id: _identifier.WorkflowExecutionIdentifier):
        """
        :param execution_id: Workflow execution identifier stored in the message's ``executionId`` field.
        """
        self._execution_id = execution_id

    @property
    def execution_id(self) -> _identifier.WorkflowExecutionIdentifier:
        return self._execution_id

    def to_flyte_idl(self) -> _node_execution_pb2.WorkflowNodeMetadata:
        """Build the protobuf message from this object."""
        execution_id_pb = self.execution_id.to_flyte_idl()
        return _node_execution_pb2.WorkflowNodeMetadata(executionId=execution_id_pb)

    @classmethod
    def from_flyte_idl(cls, p: _node_execution_pb2.WorkflowNodeMetadata) -> "WorkflowNodeMetadata":
        """Build this wrapper from the protobuf message."""
        execution_id = _identifier.WorkflowExecutionIdentifier.from_flyte_idl(p.executionId)
        return cls(execution_id=execution_id)


class DynamicWorkflowNodeMetadata(_common_models.FlyteIdlEntity):
    """Python wrapper around the ``node_execution_pb2.DynamicWorkflowNodeMetadata`` message."""

    def __init__(self, id: _identifier.Identifier, compiled_workflow: core_compiler_models.CompiledWorkflowClosure):
        """
        :param id: Identifier stored in the ``id`` field.
        :param compiled_workflow: Compiled workflow closure stored in the ``compiled_workflow`` field.
        """
        self._id = id
        self._compiled_workflow = compiled_workflow

    @property
    def id(self) -> _identifier.Identifier:
        return self._id

    @property
    def compiled_workflow(self) -> core_compiler_models.CompiledWorkflowClosure:
        return self._compiled_workflow

    def to_flyte_idl(self) -> _node_execution_pb2.DynamicWorkflowNodeMetadata:
        """Build the protobuf message from this object."""
        id_pb = self.id.to_flyte_idl()
        closure_pb = self.compiled_workflow.to_flyte_idl()
        return _node_execution_pb2.DynamicWorkflowNodeMetadata(id=id_pb, compiled_workflow=closure_pb)

    @classmethod
    def from_flyte_idl(cls, p: _node_execution_pb2.DynamicWorkflowNodeMetadata) -> "DynamicWorkflowNodeMetadata":
        """Build this wrapper from the protobuf message."""
        return cls(
            id=_identifier.Identifier.from_flyte_idl(p.id),
            compiled_workflow=core_compiler_models.CompiledWorkflowClosure.from_flyte_idl(p.compiled_workflow),
        )


class TaskNodeMetadata(_common_models.FlyteIdlEntity):
    """Python wrapper around the ``node_execution_pb2.TaskNodeMetadata`` message."""

    def __init__(self, cache_status: int, catalog_key: catalog_models.CatalogMetadata):
        """
        :param cache_status: Integer stored in the ``cache_status`` field.
        :param catalog_key: Catalog metadata stored in the ``catalog_key`` field.
        """
        self._cache_status = cache_status
        self._catalog_key = catalog_key

    @property
    def cache_status(self) -> int:
        return self._cache_status

    @property
    def catalog_key(self) -> catalog_models.CatalogMetadata:
        return self._catalog_key

    def to_flyte_idl(self) -> _node_execution_pb2.TaskNodeMetadata:
        """Build the protobuf message from this object."""
        status = self.cache_status
        catalog_key_pb = self.catalog_key.to_flyte_idl()
        return _node_execution_pb2.TaskNodeMetadata(cache_status=status, catalog_key=catalog_key_pb)

    @classmethod
    def from_flyte_idl(cls, p: _node_execution_pb2.TaskNodeMetadata) -> "TaskNodeMetadata":
        """Build this wrapper from the protobuf message."""
        status = p.cache_status
        catalog_key = catalog_models.CatalogMetadata.from_flyte_idl(p.catalog_key)
        return cls(cache_status=status, catalog_key=catalog_key)


class NodeExecutionClosure(_common_models.FlyteIdlEntity):
def __init__(self, phase, started_at, duration, output_uri=None, error=None):
def __init__(
self,
phase,
started_at,
duration,
output_uri=None,
error=None,
workflow_node_metadata: typing.Optional[WorkflowNodeMetadata] = None,
task_node_metadata: typing.Optional[TaskNodeMetadata] = None,
):
"""
:param int phase:
:param datetime.datetime started_at:
@@ -20,6 +108,9 @@ def __init__(self, phase, started_at, duration, output_uri=None, error=None):
self._duration = duration
self._output_uri = output_uri
self._error = error
self._workflow_node_metadata = workflow_node_metadata
self._task_node_metadata = task_node_metadata
# TODO: Add output_data field as well.

@property
def phase(self):
@@ -56,6 +147,18 @@ def error(self):
"""
return self._error

@property
def workflow_node_metadata(self) -> typing.Optional[WorkflowNodeMetadata]:
return self._workflow_node_metadata

@property
def task_node_metadata(self) -> typing.Optional[TaskNodeMetadata]:
return self._task_node_metadata

@property
def target_metadata(self) -> typing.Union[WorkflowNodeMetadata, TaskNodeMetadata]:
return self.workflow_node_metadata or self.task_node_metadata

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.node_execution_pb2.NodeExecutionClosure
@@ -64,6 +167,10 @@ def to_flyte_idl(self):
phase=self.phase,
output_uri=self.output_uri,
error=self.error.to_flyte_idl() if self.error is not None else None,
workflow_node_metadata=self.workflow_node_metadata.to_flyte_idl()
if self.workflow_node_metadata is not None
else None,
task_node_metadata=self.task_node_metadata.to_flyte_idl() if self.task_node_metadata is not None else None,
)
obj.started_at.FromDatetime(self.started_at.astimezone(_pytz.UTC).replace(tzinfo=None))
obj.duration.FromTimedelta(self.duration)
@@ -81,6 +188,12 @@ def from_flyte_idl(cls, p):
error=_core_execution.ExecutionError.from_flyte_idl(p.error) if p.HasField("error") else None,
started_at=p.started_at.ToDatetime().replace(tzinfo=_pytz.UTC),
duration=p.duration.ToTimedelta(),
workflow_node_metadata=WorkflowNodeMetadata.from_flyte_idl(p.workflow_node_metadata)
if p.HasField("workflow_node_metadata")
else None,
task_node_metadata=TaskNodeMetadata.from_flyte_idl(p.task_node_metadata)
if p.HasField("task_node_metadata")
else None,
)


7 changes: 3 additions & 4 deletions flytekit/remote/__init__.py
Original file line number Diff line number Diff line change
@@ -79,10 +79,9 @@
"""

from flytekit.remote.component_nodes import FlyteTaskNode, FlyteWorkflowNode
from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
from flytekit.remote.launch_plan import FlyteLaunchPlan
from flytekit.remote.nodes import FlyteNode, FlyteNodeExecution
from flytekit.remote.nodes import FlyteNode
from flytekit.remote.remote import FlyteRemote
from flytekit.remote.tasks.executions import FlyteTaskExecution
from flytekit.remote.tasks.task import FlyteTask
from flytekit.remote.task import FlyteTask
from flytekit.remote.workflow import FlyteWorkflow
from flytekit.remote.workflow_execution import FlyteWorkflowExecution
21 changes: 10 additions & 11 deletions flytekit/remote/component_nodes.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import logging as _logging
from typing import Dict

import flytekit
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.models import launch_plan as _launch_plan_model
from flytekit.models import task as _task_model
from flytekit.models.core import identifier as id_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.remote import identifier as _identifier


class FlyteTaskNode(_workflow_model.TaskNode):
"""A class encapsulating a task that a Flyte node needs to execute."""

def __init__(self, flyte_task: "flytekit.remote.tasks.task.FlyteTask"):
def __init__(self, flyte_task: "flytekit.remote.task.FlyteTask"):
self._flyte_task = flyte_task
super(FlyteTaskNode, self).__init__(None)

@property
def reference_id(self) -> _identifier.Identifier:
def reference_id(self) -> id_models.Identifier:
"""A globally unique identifier for the task."""
return self._flyte_task.id

@@ -29,7 +28,7 @@ def flyte_task(self) -> "flytekit.remote.tasks.task.FlyteTask":
def promote_from_model(
cls,
base_model: _workflow_model.TaskNode,
tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate],
tasks: Dict[id_models.Identifier, _task_model.TaskTemplate],
) -> "FlyteTaskNode":
"""
Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the
@@ -38,12 +37,12 @@ def promote_from_model(
:param base_model:
:param tasks:
"""
from flytekit.remote.tasks import task as _task
from flytekit.remote.task import FlyteTask

if base_model.reference_id in tasks:
task = tasks[base_model.reference_id]
_logging.info(f"Found existing task template for {task.id}, will not retrieve from Admin")
flyte_task = _task.FlyteTask.promote_from_model(task)
flyte_task = FlyteTask.promote_from_model(task)
return cls(flyte_task)

raise _system_exceptions.FlyteSystemException(f"Task template {base_model.reference_id} not found.")
@@ -76,7 +75,7 @@ def __repr__(self) -> str:
return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}"

@property
def launchplan_ref(self) -> _identifier.Identifier:
def launchplan_ref(self) -> id_models.Identifier:
"""A globally unique identifier for the launch plan, which should map to Admin."""
return self._flyte_launch_plan.id if self._flyte_launch_plan else None

@@ -96,9 +95,9 @@ def flyte_workflow(self) -> "flytekit.remote.workflow.FlyteWorkflow":
def promote_from_model(
cls,
base_model: _workflow_model.WorkflowNode,
sub_workflows: Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate],
node_launch_plans: Dict[_identifier.Identifier, _launch_plan_model.LaunchPlanSpec],
tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate],
sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate],
node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec],
tasks: Dict[id_models.Identifier, _task_model.TaskTemplate],
) -> "FlyteWorkflowNode":
from flytekit.remote import launch_plan as _launch_plan
from flytekit.remote import workflow as _workflow
239 changes: 239 additions & 0 deletions flytekit/remote/executions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional, Union

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.exceptions import user as user_exceptions
from flytekit.models import execution as execution_models
from flytekit.models import node_execution as node_execution_models
from flytekit.models.admin import task_execution as admin_task_execution_models
from flytekit.models.core import execution as core_execution_models
from flytekit.remote.workflow import FlyteWorkflow


class FlyteTaskExecution(admin_task_execution_models.TaskExecution):
    """A class encapsulating a task execution being run on a Flyte remote backend."""

    def __init__(self, *args, **kwargs):
        super(FlyteTaskExecution, self).__init__(*args, **kwargs)
        # Not populated here; both start as None and are expected to be filled in from outside this class.
        self._inputs = None
        self._outputs = None

    @property
    def is_complete(self) -> bool:
        """Whether or not the execution is complete."""
        return self.closure.phase in {
            core_execution_models.TaskExecutionPhase.ABORTED,
            core_execution_models.TaskExecutionPhase.FAILED,
            core_execution_models.TaskExecutionPhase.SUCCEEDED,
        }

    @property
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs of the task execution in the standard Python format that is produced by
        the type engine.
        """
        return self._inputs

    @property
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs of the task execution, if available, in the standard Python format that is produced by
        the type engine.
        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            raise user_exceptions.FlyteAssertion(
                "Please wait until the node execution has completed before requesting the outputs."
            )
        if self.error:
            raise user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
        return self._outputs

    @property
    def error(self) -> Optional[core_execution_models.ExecutionError]:
        """
        If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
        reaching completion.
        :raises: ``FlyteAssertion`` error if execution is still in progress.
        """
        if not self.is_complete:
            raise user_exceptions.FlyteAssertion(
                "Please wait until the task execution has completed before requesting error information."
            )
        return self.closure.error

    @classmethod
    def promote_from_model(cls, base_model: admin_task_execution_models.TaskExecution) -> "FlyteTaskExecution":
        """Create a ``FlyteTaskExecution`` carrying the closure, id, input URI and parent flag of ``base_model``."""
        return cls(
            closure=base_model.closure,
            id=base_model.id,
            input_uri=base_model.input_uri,
            is_parent=base_model.is_parent,
        )


class FlyteWorkflowExecution(execution_models.Execution):
    """Wraps a workflow execution that is running on a Flyte remote backend."""

    def __init__(self, *args, **kwargs):
        super(FlyteWorkflowExecution, self).__init__(*args, **kwargs)
        # Nothing is populated here; every cached attribute starts out as None.
        self._node_executions = None
        self._inputs = None
        self._outputs = None
        self._flyte_workflow: Optional[FlyteWorkflow] = None

    @property
    def node_executions(self) -> Dict[str, "FlyteNodeExecution"]:
        """Node executions belonging to this workflow execution; an empty dict when none are set."""
        return self._node_executions or {}

    @property
    def inputs(self) -> Dict[str, Any]:
        """
        Inputs of the execution, in the standard python format dictated by the type engine.
        """
        return self._inputs

    @property
    def outputs(self) -> Dict[str, Any]:
        """
        Outputs of the execution, in the standard python format dictated by the type engine.
        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until the node execution has completed before requesting the outputs."
            )
        if self.error:
            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
        return self._outputs

    @property
    def error(self) -> core_execution_models.ExecutionError:
        """
        The error recorded on the closure once the execution has completed (None when there was no error).
        :raises: ``FlyteAssertion`` error if execution is still in progress.
        """
        if self.is_complete:
            return self.closure.error
        raise _user_exceptions.FlyteAssertion(
            "Please wait until a workflow has completed before checking for an error."
        )

    @property
    def is_complete(self) -> bool:
        """
        True once the closure's phase is one of ABORTED, FAILED, SUCCEEDED or TIMED_OUT.
        """
        phases = core_execution_models.WorkflowExecutionPhase
        terminal_phases = {phases.ABORTED, phases.FAILED, phases.SUCCEEDED, phases.TIMED_OUT}
        return self.closure.phase in terminal_phases

    @classmethod
    def promote_from_model(cls, base_model: execution_models.Execution) -> "FlyteWorkflowExecution":
        """Create a ``FlyteWorkflowExecution`` carrying the closure, id and spec of ``base_model``."""
        return cls(closure=base_model.closure, id=base_model.id, spec=base_model.spec)


class FlyteNodeExecution(node_execution_models.NodeExecution):
    """A class encapsulating a node execution being run on a Flyte remote backend."""

    def __init__(self, *args, **kwargs):
        super(FlyteNodeExecution, self).__init__(*args, **kwargs)
        # Nothing in this class populates these; they keep these defaults until assigned from outside.
        self._task_executions = None
        self._workflow_executions = []
        self._underlying_node_executions = None
        self._inputs = None
        self._outputs = None
        self._interface = None

    @property
    def task_executions(self) -> List[FlyteTaskExecution]:
        """Task executions attached to this node execution (empty list when unset)."""
        return self._task_executions or []

    @property
    def workflow_executions(self) -> List[FlyteWorkflowExecution]:
        """Workflow executions attached to this node execution (starts out as an empty list)."""
        return self._workflow_executions

    @property
    def subworkflow_node_executions(self) -> Dict[str, "FlyteNodeExecution"]:
        """
        This returns underlying node executions in instances where the current node execution is
        a parent node. This happens when it's either a static or dynamic subworkflow.

        The result is keyed by each underlying execution's ``id.node_id``; it is empty when unset.
        """
        return (
            {}
            if self._underlying_node_executions is None
            else {n.id.node_id: n for n in self._underlying_node_executions}
        )

    @property
    def executions(self) -> List[Union[FlyteTaskExecution, "FlyteNodeExecution"]]:
        """
        The task executions if there are any, otherwise the underlying node executions, otherwise ``[]``.

        The elements are task executions or node executions; workflow executions are never returned here.
        """
        return self.task_executions or self._underlying_node_executions or []

    @property
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs to the execution in the standard python format as dictated by the type engine.
        """
        return self._inputs

    @property
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs to the execution in the standard python format as dictated by the type engine.

        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until the node execution has completed before requesting the outputs."
            )
        if self.error:
            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
        return self._outputs

    @property
    def error(self) -> core_execution_models.ExecutionError:
        """
        If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
        reaching completion.

        :raises: ``FlyteAssertion`` error if execution is in progress.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until the node execution has completed before requesting error information."
            )
        return self.closure.error

    @property
    def is_complete(self) -> bool:
        """
        Whether or not the execution is complete, i.e. its closure phase is one of
        ABORTED, FAILED, SKIPPED, SUCCEEDED or TIMED_OUT.
        """
        return self.closure.phase in {
            core_execution_models.NodeExecutionPhase.ABORTED,
            core_execution_models.NodeExecutionPhase.FAILED,
            core_execution_models.NodeExecutionPhase.SKIPPED,
            core_execution_models.NodeExecutionPhase.SUCCEEDED,
            core_execution_models.NodeExecutionPhase.TIMED_OUT,
        }

    @classmethod
    def promote_from_model(cls, base_model: node_execution_models.NodeExecution) -> "FlyteNodeExecution":
        """Build a ``FlyteNodeExecution`` from the closure, id, input_uri and metadata of a plain model."""
        return cls(
            closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri, metadata=base_model.metadata
        )

    @property
    def interface(self) -> "flytekit.remote.interface.TypedInterface":
        """
        Return the interface of the task or subworkflow associated with this node execution.
        """
        return self._interface
137 changes: 0 additions & 137 deletions flytekit/remote/identifier.py

This file was deleted.

13 changes: 0 additions & 13 deletions flytekit/remote/interface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from typing import Any, Dict, List, Tuple

from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.remote import nodes as _nodes


class TypedInterface(_interface_models.TypedInterface):
@@ -13,12 +9,3 @@ def promote_from_model(cls, model):
:rtype: TypedInterface
"""
return cls(model.inputs, model.outputs)

def create_bindings_for_inputs(
    self, map_of_bindings: Dict[str, Any]
) -> Tuple[List[_literal_models.Binding], List[_nodes.FlyteNode]]:
    """
    Stub: ignores ``map_of_bindings`` and always yields an empty binding list and an empty node list.

    :param map_of_bindings: this can be scalar primitives, it can be node output references, lists, etc.
    :return: a ``(bindings, nodes)`` pair, both currently empty. This stub body never raises.
    """
    bindings: List[_literal_models.Binding] = []
    upstream_nodes: List[_nodes.FlyteNode] = []
    return bindings, upstream_nodes
15 changes: 7 additions & 8 deletions flytekit/remote/launch_plan.py
Original file line number Diff line number Diff line change
@@ -7,8 +7,7 @@
from flytekit.engines.flyte import engine as _flyte_engine
from flytekit.models import interface as _interface_models
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models.core import identifier as _identifier_model
from flytekit.remote import identifier as _identifier
from flytekit.models.core import identifier as id_models
from flytekit.remote import interface as _interface


@@ -27,11 +26,11 @@ def __init__(self, id, *args, **kwargs):

@classmethod
def promote_from_model(
cls, id: _identifier.Identifier, model: _launch_plan_models.LaunchPlanSpec
cls, id: id_models.Identifier, model: _launch_plan_models.LaunchPlanSpec
) -> "FlyteLaunchPlan":
lp = cls(
id=id,
workflow_id=_identifier.Identifier.promote_from_model(model.workflow_id),
workflow_id=model.workflow_id,
default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters),
fixed_inputs=model.fixed_inputs,
entity_metadata=model.entity_metadata,
@@ -50,7 +49,7 @@ def promote_from_model(
return lp

@property
def id(self) -> _identifier.Identifier:
def id(self) -> id_models.Identifier:
return self._id

@property
@@ -65,7 +64,7 @@ def is_scheduled(self) -> bool:
return False

@property
def workflow_id(self) -> _identifier.Identifier:
def workflow_id(self) -> id_models.Identifier:
return self._workflow_id

@property
@@ -78,8 +77,8 @@ def interface(self) -> _interface.TypedInterface:
return self._interface

@property
def resource_type(self) -> _identifier_model.ResourceType:
return _identifier_model.ResourceType.LAUNCH_PLAN
def resource_type(self) -> id_models.ResourceType:
return id_models.ResourceType.LAUNCH_PLAN

@property
def entity_type_text(self) -> str:
177 changes: 23 additions & 154 deletions flytekit/remote/nodes.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
from __future__ import annotations

import logging as _logging
from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union

import flytekit
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.common import constants as _constants
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import artifact as _artifact_mixin
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.common.utils import _dnsify
from flytekit.core.promise import NodeOutput
from flytekit.engines.flyte import engine as _flyte_engine
from flytekit.models import launch_plan as _launch_plan_model
from flytekit.models import node_execution as _node_execution_models
from flytekit.models import task as _task_model
from flytekit.models.core import execution as _execution_models
from flytekit.models.core import identifier as id_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.remote import component_nodes as _component_nodes
from flytekit.remote import identifier as _identifier
from flytekit.remote.tasks.executions import FlyteTaskExecution


class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node):
@@ -34,7 +28,6 @@ def __init__(
flyte_workflow: Optional["FlyteWorkflow"] = None,
flyte_launch_plan: Optional["FlyteLaunchPlan"] = None,
flyte_branch=None,
parameter_mapping=True,
):
non_none_entities = list(filter(None, [flyte_task, flyte_workflow, flyte_launch_plan, flyte_branch]))
if len(non_none_entities) != 1:
@@ -50,15 +43,20 @@ def __init__(
elif flyte_launch_plan is not None:
workflow_node = _component_nodes.FlyteWorkflowNode(flyte_launch_plan=flyte_launch_plan)

task_node = None
if flyte_task:
task_node = _component_nodes.FlyteTaskNode(flyte_task)
branch_node = None

super(FlyteNode, self).__init__(
id=_dnsify(id) if id else None,
id=id,
metadata=metadata,
inputs=bindings,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
task_node=_component_nodes.FlyteTaskNode(flyte_task) if flyte_task else None,
task_node=task_node,
workflow_node=workflow_node,
branch_node=flyte_branch,
branch_node=branch_node,
)
self._upstream = upstream_nodes

@@ -70,11 +68,12 @@ def flyte_entity(self) -> Union["FlyteTask", "FlyteWorkflow", "FlyteLaunchPlan"]
def promote_from_model(
cls,
model: _workflow_model.Node,
sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate]],
node_launch_plans: Optional[Dict[_identifier.Identifier, _launch_plan_model.LaunchPlanSpec]],
tasks: Optional[Dict[_identifier.Identifier, _task_model.TaskTemplate]],
) -> "FlyteNode":
id = model.id
sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]],
node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]],
tasks: Optional[Dict[id_models.Identifier, _task_model.TaskTemplate]],
) -> FlyteNode:
node_model_id = model.id
# TODO: Consider removing
if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}:
_logging.warning(f"Should not call promote from model on a start node or end node {model}")
return None
@@ -97,6 +96,7 @@ def promote_from_model(

# When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a
# start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out.
# TODO: Consider removing
for model_input in model.inputs:
if (
model_input.binding.promise is not None
@@ -106,7 +106,7 @@ def promote_from_model(

if flyte_task_node is not None:
return cls(
id=id,
id=node_model_id,
upstream_nodes=[], # set downstream, model doesn't contain this information
bindings=model.inputs,
metadata=model.metadata,
@@ -115,15 +115,15 @@ def promote_from_model(
elif flyte_workflow_node is not None:
if flyte_workflow_node.flyte_workflow is not None:
return cls(
id=id,
id=node_model_id,
upstream_nodes=[], # set downstream, model doesn't contain this information
bindings=model.inputs,
metadata=model.metadata,
flyte_workflow=flyte_workflow_node.flyte_workflow,
)
elif flyte_workflow_node.flyte_launch_plan is not None:
return cls(
id=id,
id=node_model_id,
upstream_nodes=[], # set downstream, model doesn't contain this information
bindings=model.inputs,
metadata=model.metadata,
@@ -135,7 +135,7 @@ def promote_from_model(
raise _system_exceptions.FlyteSystemException("Bad FlyteNode model, both task and workflow nodes are empty")

@property
def upstream_nodes(self) -> List["FlyteNode"]:
def upstream_nodes(self) -> List[FlyteNode]:
return self._upstream

@property
@@ -146,136 +146,5 @@ def upstream_node_ids(self) -> List[str]:
def outputs(self) -> Dict[str, NodeOutput]:
return self._outputs

def assign_id_and_return(self, id: str):
    """
    Give this node an ID, provided it does not have one yet, and return the node itself.

    :param id: the requested ID. It is stored passed through ``_dnsify`` (or as ``None`` when falsy),
        and stored unchanged as the metadata name.
    :raises: ``FlyteAssertion`` if this node already has a truthy ID.
    :return: this node.
    """
    if self.id:
        raise _user_exceptions.FlyteAssertion(
            f"Error assigning ID: {id} because {self} is already assigned. Has this node been assigned to another "
            "workflow already?"
        )
    self._id = _dnsify(id) if id else None
    self._metadata.name = id
    return self

def with_overrides(self, *args, **kwargs):
    """Placeholder for node overrides; every call raises ``NotImplementedError``."""
    # TODO: Implement overrides
    message = "Overrides are not supported in Flyte yet."
    raise NotImplementedError(message)

def __repr__(self) -> str:
    """Return a short debug string that carries only this node's ID."""
    node_id = self.id
    return "Node(ID: {})".format(node_id)


class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact):
    """A class encapsulating a node execution being run on a Flyte remote backend."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Child executions are filled in by sync(); the remaining fields are not set by this class.
        self._task_executions = None
        self._subworkflow_node_executions = None
        self._inputs = None
        self._outputs = None
        self._interface = None

    @property
    def task_executions(self) -> List["flytekit.remote.tasks.executions.FlyteTaskExecution"]:
        """Task executions of this node (an empty list when none have been loaded)."""
        return self._task_executions or []

    @property
    def subworkflow_node_executions(self) -> Dict[str, "flytekit.remote.nodes.FlyteNodeExecution"]:
        """Child node executions keyed by their ``id.node_id`` (an empty dict when none have been loaded)."""
        children = self._subworkflow_node_executions
        if children is None:
            return {}
        return {child.id.node_id: child for child in children}

    @property
    def executions(self) -> List[_artifact_mixin.ExecutionArtifact]:
        """The task executions if any, otherwise the child node executions, otherwise an empty list."""
        tasks = self.task_executions
        if tasks:
            return tasks
        return self._subworkflow_node_executions or []

    @property
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs to the execution in the standard python format as dictated by the type engine.
        """
        return self._inputs

    @property
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs to the execution in the standard python format as dictated by the type engine.

        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until the node execution has completed before requesting the outputs."
            )
        if self.error:
            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
        return self._outputs

    @property
    def error(self) -> _execution_models.ExecutionError:
        """
        The error recorded on the closure, or None when the finished execution has no error.

        :raises: ``FlyteAssertion`` error if execution is still in progress.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until the node execution has completed before requesting error information."
            )
        return self.closure.error

    @property
    def is_complete(self) -> bool:
        """Whether the closure phase is one of ABORTED, FAILED, SKIPPED, SUCCEEDED or TIMED_OUT."""
        terminal_phases = {
            _execution_models.NodeExecutionPhase.ABORTED,
            _execution_models.NodeExecutionPhase.FAILED,
            _execution_models.NodeExecutionPhase.SKIPPED,
            _execution_models.NodeExecutionPhase.SUCCEEDED,
            _execution_models.NodeExecutionPhase.TIMED_OUT,
        }
        return self.closure.phase in terminal_phases

    @classmethod
    def promote_from_model(cls, base_model: _node_execution_models.NodeExecution) -> "FlyteNodeExecution":
        """Build a ``FlyteNodeExecution`` from the closure, id, input_uri and metadata of a plain model."""
        return cls(
            closure=base_model.closure,
            id=base_model.id,
            input_uri=base_model.input_uri,
            metadata=base_model.metadata,
        )

    @property
    def interface(self) -> "flytekit.remote.interface.TypedInterface":
        """
        Return the interface of the task or subworkflow associated with this node execution.
        """
        return self._interface

    def sync(self):
        """
        Syncs the state of the underlying execution artifact with the state observed by the platform.

        Child executions are re-fetched while this execution is incomplete, or when they were never loaded.
        The closure is then refreshed and every child execution is synced in turn.
        """
        if self.metadata.is_parent_node:
            # Parent nodes own child node executions rather than task executions.
            if not self.is_complete or self._subworkflow_node_executions is None:
                child_models = iterate_node_executions(
                    _flyte_engine.get_client(),
                    workflow_execution_identifier=self.id.execution_id,
                    unique_parent_id=self.id.node_id,
                )
                self._subworkflow_node_executions = [
                    FlyteNodeExecution.promote_from_model(model) for model in child_models
                ]
        elif not self.is_complete or self._task_executions is None:
            task_models = iterate_task_executions(_flyte_engine.get_client(), self.id)
            self._task_executions = [FlyteTaskExecution.promote_from_model(model) for model in task_models]

        self._sync_closure()
        for child in self.executions:
            child.sync()

    def _sync_closure(self):
        """
        Syncs the closure of the underlying execution artifact with the state observed by the platform.
        """
        client = _flyte_engine.get_client()
        self._closure = client.get_node_execution(self.id).closure
283 changes: 183 additions & 100 deletions flytekit/remote/remote.py

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions flytekit/remote/tasks/task.py → flytekit/remote/task.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,6 @@
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.core import identifier as _identifier_model
from flytekit.remote import identifier as _identifier
from flytekit.remote import interface as _interfaces


@@ -61,7 +60,7 @@ def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> "FlyteTask"
)
# Override the newly generated name if one exists in the base model
if not base_model.id.is_empty:
t._id = _identifier.Identifier.promote_from_model(base_model.id)
t._id = base_model.id

if t.interface is not None:
try:
96 changes: 0 additions & 96 deletions flytekit/remote/tasks/executions.py

This file was deleted.

99 changes: 54 additions & 45 deletions flytekit/remote/workflow.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import annotations

from typing import Dict, List, Optional

from flytekit.common import constants as _constants
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.core.interface import Interface
from flytekit.core.type_engine import TypeEngine
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models import launch_plan as launch_plan_models
from flytekit.models import task as _task_models
from flytekit.models.core import identifier as _identifier_model
from flytekit.models.core import compiler as compiler_models
from flytekit.models.core import identifier as id_models
from flytekit.models.core import workflow as _workflow_models
from flytekit.remote import identifier as _identifier
from flytekit.remote import interface as _interfaces
from flytekit.remote import nodes as _nodes

@@ -23,10 +24,15 @@ def __init__(
nodes: List[_nodes.FlyteNode],
interface,
output_bindings,
id,
id: id_models.Identifier,
metadata,
metadata_defaults,
subworkflows: Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]] = None,
tasks: Optional[Dict[id_models.Identifier, _task_models.TaskSpec]] = None,
launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None,
compiled_closure: Optional[compiler_models.CompiledWorkflowClosure] = None,
):
# TODO: Remove check
for node in nodes:
for upstream in node.upstream_nodes:
if upstream.id is None:
@@ -46,9 +52,13 @@ def __init__(
self._flyte_nodes = nodes
self._python_interface = None

@property
def upstream_entities(self):
    """Return the set of ``executable_flyte_object`` values read from each of this workflow's nodes."""
    # NOTE(review): other code in this file reads ``flyte_entity`` from nodes; confirm that
    # ``executable_flyte_object`` actually exists on the node type.
    return {node.executable_flyte_object for node in self._flyte_nodes}
# Optional things that we save for ease of access when promoting from a model or CompiledWorkflowClosure
self._subworkflows = subworkflows
self._tasks = tasks
self._launch_plans = launch_plans
self._compiled_closure = compiled_closure

self._node_map = None

@property
def interface(self) -> _interfaces.TypedInterface:
@@ -60,7 +70,7 @@ def entity_type_text(self) -> str:

@property
def resource_type(self):
return _identifier_model.ResourceType.WORKFLOW
return id_models.ResourceType.WORKFLOW

@property
def flyte_nodes(self) -> List[_nodes.FlyteNode]:
@@ -76,37 +86,6 @@ def guessed_python_interface(self, value):
return
self._python_interface = value

def get_sub_workflows(self) -> List["FlyteWorkflow"]:
    """
    Collect the sub-workflows reachable from this workflow's nodes, recursing into each one found.

    Each sub-workflow is appended before its own nested sub-workflows. Two places are inspected per node:
    a workflow node carrying a ``sub_workflow_ref``, and the then/else leaves of a branch node.

    :raises: ``FlyteSystemException`` if a node references a sub-workflow but its ``flyte_entity`` is missing
        or does not report ``entity_type_text == "Workflow"``.
    """
    collected = []
    for node in self.flyte_nodes:
        workflow_node = node.workflow_node
        if workflow_node is not None and workflow_node.sub_workflow_ref is not None:
            entity = node.flyte_entity
            if entity is None or entity.entity_type_text != "Workflow":
                raise _system_exceptions.FlyteSystemException(
                    "workflow node with subworkflow found but bad executable object {}".format(entity)
                )
            collected.append(entity)
            collected.extend(entity.get_sub_workflows())

        # get subworkflows in conditional branches
        branch_node = node.branch_node
        if branch_node is None:
            continue

        if_else: _workflow_models.IfElseBlock = branch_node.if_else
        candidates = [if_else.case.then_node]
        if if_else.other is not None:
            candidates.extend(block.then_node for block in if_else.other)
        candidates.append(if_else.else_node)

        for leaf_node in candidates:
            # Falsy leaves (e.g. a missing else node) are skipped.
            if not leaf_node:
                continue
            leaf_entity = leaf_node.flyte_entity
            if leaf_entity is not None and leaf_entity.entity_type_text == "Workflow":
                collected.append(leaf_entity)
                collected.extend(leaf_entity.get_sub_workflows())

    return collected

@classmethod
def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workflow_models.Node]:
    """Return ``nodes`` minus any whose ID is the start-node or end-node constant, keeping their order."""
    system_node_ids = {_constants.START_NODE_ID, _constants.END_NODE_ID}
    return [node for node in nodes if node.id not in system_node_ids]
@@ -115,10 +94,10 @@ def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workf
def promote_from_model(
cls,
base_model: _workflow_models.WorkflowTemplate,
sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_models.WorkflowTemplate]] = None,
node_launch_plans: Optional[Dict[_identifier.Identifier, _launch_plan_models.LaunchPlanSpec]] = None,
tasks: Optional[Dict[_identifier.Identifier, _task_models.TaskTemplate]] = None,
) -> "FlyteWorkflow":
sub_workflows: Optional[Dict[id_models, _workflow_models.WorkflowTemplate]] = None,
node_launch_plans: Optional[Dict[id_models, launch_plan_models.LaunchPlanSpec]] = None,
tasks: Optional[Dict[id_models, _task_models.TaskTemplate]] = None,
) -> FlyteWorkflow:
base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes)
sub_workflows = sub_workflows or {}
tasks = tasks or {}
@@ -137,20 +116,50 @@ def promote_from_model(
# No inputs/outputs specified, see the constructor for more information on the overrides.
wf = cls(
nodes=list(node_map.values()),
id=_identifier.Identifier.promote_from_model(base_model.id),
id=base_model.id,
metadata=base_model.metadata,
metadata_defaults=base_model.metadata_defaults,
interface=_interfaces.TypedInterface.promote_from_model(base_model.interface),
output_bindings=base_model.outputs,
subworkflows=sub_workflows,
tasks=tasks,
launch_plans=node_launch_plans,
)

if wf.interface is not None:
wf.guessed_python_interface = Interface(
inputs=TypeEngine.guess_python_types(wf.interface.inputs),
outputs=TypeEngine.guess_python_types(wf.interface.outputs),
)
wf._node_map = node_map

return wf

@classmethod
def promote_from_closure(
    cls,
    closure: compiler_models.CompiledWorkflowClosure,
    node_launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None,
) -> "FlyteWorkflow":
    """
    Extracts out the relevant portions of a FlyteWorkflow from a closure from the control plane.

    :param closure: This is the closure returned by Admin
    :param node_launch_plans: The reason this exists is because the compiled closure doesn't have launch plans.
        It only has subworkflows and tasks. Why this is is unclear. If supplied, this map of launch plans will be
        forwarded unchanged to ``promote_from_model`` as its ``node_launch_plans`` argument.
    :return: the workflow promoted from the closure's primary template, with the closure saved on it as
        ``_compiled_closure``.
    """
    # Index the closure's sub-workflow and task templates by their template IDs.
    sub_workflows = {sw.template.id: sw.template for sw in closure.sub_workflows}
    tasks = {t.template.id: t.template for t in closure.tasks}

    flyte_wf = FlyteWorkflow.promote_from_model(
        base_model=closure.primary.template,
        sub_workflows=sub_workflows,
        node_launch_plans=node_launch_plans,
        tasks=tasks,
    )
    flyte_wf._compiled_closure = closure
    return flyte_wf

def __call__(self, *args, **input_map):
    """Calling this object is not implemented; every call raises ``NotImplementedError``."""
    raise NotImplementedError()
76 changes: 0 additions & 76 deletions flytekit/remote/workflow_execution.py
Original file line number Diff line number Diff line change
@@ -1,76 +0,0 @@
from typing import Any, Dict

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.models import execution as _execution_models
from flytekit.models.core import execution as _core_execution_models
from flytekit.remote import identifier as _core_identifier
from flytekit.remote import nodes as _nodes


class FlyteWorkflowExecution(_execution_models.Execution):
    """A class encapsulating a workflow execution being run on a Flyte remote backend."""

    def __init__(self, *args, **kwargs):
        super(FlyteWorkflowExecution, self).__init__(*args, **kwargs)
        # Nothing in this class populates these; they stay None until assigned from outside.
        self._node_executions = None
        self._inputs = None
        self._outputs = None

    @property
    def node_executions(self) -> Dict[str, _nodes.FlyteNodeExecution]:
        """Get a dictionary of node executions that are a part of this workflow execution (empty when unset)."""
        return self._node_executions or {}

    @property
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs to the execution in the standard python format as dictated by the type engine.
        """
        return self._inputs

    @property
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs to the execution in the standard python format as dictated by the type engine.

        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            # This is a workflow execution, so the message must not talk about a node execution.
            raise _user_exceptions.FlyteAssertion(
                "Please wait until the workflow execution has completed before requesting the outputs."
            )
        if self.error:
            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
        return self._outputs

    @property
    def error(self) -> _core_execution_models.ExecutionError:
        """
        If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
        reaching completion.

        :raises: ``FlyteAssertion`` error if execution is in progress.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until a workflow has completed before checking for an error."
            )
        return self.closure.error

    @property
    def is_complete(self) -> bool:
        """
        Whether or not the execution is complete, i.e. its closure phase is one of
        ABORTED, FAILED, SUCCEEDED or TIMED_OUT.
        """
        return self.closure.phase in {
            _core_execution_models.WorkflowExecutionPhase.ABORTED,
            _core_execution_models.WorkflowExecutionPhase.FAILED,
            _core_execution_models.WorkflowExecutionPhase.SUCCEEDED,
            _core_execution_models.WorkflowExecutionPhase.TIMED_OUT,
        }

    @classmethod
    def promote_from_model(cls, base_model: _execution_models.Execution) -> "FlyteWorkflowExecution":
        """
        Build a ``FlyteWorkflowExecution`` from a plain execution model, promoting its id to the
        remote ``WorkflowExecutionIdentifier`` type.
        """
        return cls(
            closure=base_model.closure,
            id=_core_identifier.WorkflowExecutionIdentifier.promote_from_model(base_model.id),
            spec=base_model.spec,
        )
20 changes: 11 additions & 9 deletions flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
@@ -4,11 +4,13 @@
import os
import typing
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import Type

import numpy as _np
from dataclasses_json import config, dataclass_json
from marshmallow import fields

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import T, TypeEngine, TypeTransformer
@@ -167,7 +169,10 @@ def get_handler(cls, t: Type) -> SchemaHandler:
return cls._SCHEMA_HANDLERS[t]


@dataclass_json
@dataclass
class FlyteSchema(object):
remote_path: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))
"""
This is the main schema class that users should use.
"""
@@ -220,7 +225,7 @@ def format(cls) -> SchemaFormat:
def __init__(
self,
local_path: os.PathLike = None,
remote_path: str = None,
remote_path: os.PathLike = None,
supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE,
downloader: typing.Callable[[str, os.PathLike], None] = None,
):
@@ -234,10 +239,11 @@ def __init__(
):
raise ValueError("To create a FlyteSchema in write mode, local_path is required")

if local_path is None:
local_path = FlyteContextManager.current_context().file_access.get_random_local_directory()
local_path = local_path or FlyteContextManager.current_context().file_access.get_random_local_directory()
self._local_path = local_path
self._remote_path = remote_path
# Make this field public, so that the dataclass transformer can set a value for it
# https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
self.remote_path = remote_path or FlyteContextManager.current_context().file_access.get_random_remote_path()
self._supported_mode = supported_mode
# This is a special attribute that indicates if the data was either downloaded or uploaded
self._downloaded = False
@@ -247,10 +253,6 @@ def __init__(
def local_path(self) -> os.PathLike:
return self._local_path

@property
def remote_path(self) -> str:
return typing.cast(str, self._remote_path)

@property
def supported_mode(self) -> SchemaOpenMode:
return self._supported_mode
4 changes: 2 additions & 2 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
@@ -82,7 +82,7 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte
poll_interval = datetime.timedelta(seconds=1)
time_to_give_up = datetime.datetime.utcnow() + datetime.timedelta(seconds=60)

execution = remote.sync_workflow_execution(execution)
execution = remote.sync_workflow_execution(execution, sync_nodes=True)
while datetime.datetime.utcnow() < time_to_give_up:

if execution.is_complete:
@@ -94,7 +94,7 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte
execution.outputs

time.sleep(poll_interval.total_seconds())
execution = remote.sync_workflow_execution(execution)
execution = remote.sync_workflow_execution(execution, sync_nodes=True)

if execution.node_executions:
assert execution.node_executions["start-node"].closure.phase == 3 # SUCCEEEDED
55 changes: 55 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
@@ -5,13 +5,16 @@
from datetime import timedelta
from enum import Enum

import pandas as pd
import pytest
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import errors_pb2
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from marshmallow_enum import LoadDumpOptions
from marshmallow_jsonschema import JSONSchema

from flytekit import kwtypes
from flytekit.common.exceptions import user as user_exceptions
from flytekit.core.annotation import FlyteAnnotation
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
@@ -34,6 +37,7 @@
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.types.schema import FlyteSchema


def test_type_engine():
@@ -549,6 +553,28 @@ def test_enum_type():
TypeEngine.to_literal_type(UnsupportedEnumValues)


def test_enum_in_dataclass():
    """A dataclass with an enum field gets the expected JSON-schema metadata and converts to and from a literal."""

    @dataclass_json
    @dataclass
    class Datum(object):
        x: int
        y: Color

    literal_type = TypeEngine.to_literal_type(Datum)

    # The expected metadata is the dumped marshmallow schema with the enum field loaded by name.
    expected_schema = Datum.schema()
    expected_schema.fields["y"].load_by = LoadDumpOptions.name
    assert literal_type.metadata == JSONSchema().dump(expected_schema)

    transformer = DataclassTransformer()
    ctx = FlyteContext.current_context()
    original = Datum(5, Color.RED)
    literal = transformer.to_literal(ctx, original, Datum, literal_type)
    guessed_type = transformer.guess_python_type(literal_type)
    restored = transformer.to_python_value(ctx, literal, expected_python_type=guessed_type)

    assert original.x == restored.x
    # The restored object is built from the guessed type, so ``y`` is compared against the enum's value.
    assert original.y.value == restored.y


@pytest.mark.parametrize(
"python_value,python_types,expected_literal_map",
[
@@ -663,3 +689,32 @@ def test_multiple_annotations():
t = typing.Annotated[int, FlyteAnnotation({"foo": "bar"}), FlyteAnnotation({"anotha": "one"})]
with pytest.raises(Exception):
TypeEngine.to_literal_type(t)
TestSchema = FlyteSchema[kwtypes(some_str=str)]


# Nested payload for test_schema_in_dataclass: an int plus a schema-typed field.
@dataclass_json
@dataclass
class InnerResult:
    number: int
    # Typed with the module-level TestSchema alias.
    schema: TestSchema


# Outer payload for test_schema_in_dataclass: wraps an InnerResult and carries its own schema field.
@dataclass_json
@dataclass
class Result:
    result: InnerResult
    # Typed with the module-level TestSchema alias.
    schema: TestSchema


def test_schema_in_dataclass():
    """A dataclass nesting schema fields should compare equal after a literal round trip."""
    schema = TestSchema()
    frame = pd.DataFrame(data={"some_str": ["a", "b", "c"]})
    schema.open().write(frame)

    inner = InnerResult(number=1, schema=schema)
    original = Result(result=inner, schema=schema)

    ctx = FlyteContext.current_context()
    transformer = DataclassTransformer()
    literal_type = transformer.get_literal_type(Result)
    literal = transformer.to_literal(ctx, original, Result, literal_type)
    restored = transformer.to_python_value(ctx, lv=literal, expected_python_type=Result)

    assert original == restored
assert o == ot
53 changes: 53 additions & 0 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
import typing
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum

import pandas
import pytest
@@ -1063,6 +1064,58 @@ def wf(x: int, y: int) -> Datum:
wf(x=10, y=20)


def test_enum_in_dataclass():
    """A task may return a dataclass with an enum field, and a workflow may pass it through unchanged."""

    class Color(Enum):
        RED = "red"
        GREEN = "green"
        BLUE = "blue"

    @dataclass_json
    @dataclass
    class Datum(object):
        x: int
        y: Color

    @task
    def t1(x: int) -> Datum:
        return Datum(x=x, y=Color.RED)

    @workflow
    def wf(x: int) -> Datum:
        return t1(x=x)

    expected = Datum(10, Color.RED)
    assert wf(x=10) == expected


def test_flyte_schema_dataclass():
    """A task may return nested dataclasses holding schema fields, and a workflow may pass them through."""
    TestSchema = FlyteSchema[kwtypes(some_str=str)]

    @dataclass_json
    @dataclass
    class InnerResult:
        number: int
        schema: TestSchema

    @dataclass_json
    @dataclass
    class Result:
        result: InnerResult
        schema: TestSchema

    # One schema instance is shared by the task's return value and by the expected value.
    schema = TestSchema()

    @task
    def t1(x: int) -> Result:
        inner = InnerResult(number=x, schema=schema)
        return Result(result=inner, schema=schema)

    @workflow
    def wf(x: int) -> Result:
        return t1(x=x)

    expected = Result(result=InnerResult(number=10, schema=schema), schema=schema)
    assert wf(x=10) == expected


def test_environment():
@task(environment={"FOO": "foofoo", "BAZ": "baz"})
def t1(a: int) -> str:
Empty file.
171 changes: 171 additions & 0 deletions tests/flytekit/unit/extras/tasks/test_shell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import datetime
import os
import tempfile
from subprocess import CalledProcessError

import pytest

from flytekit import kwtypes
from flytekit.extras.tasks.shell import OutputLocation, ShellTask
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import CSVFile, FlyteFile

# Directory containing this test module; the fixture paths below are resolved relative to it.
test_file_path = os.path.dirname(os.path.realpath(__file__))
# Fixture directory and the individual fixture files referenced by the ShellTask tests.
testdata = os.path.join(test_file_path, "testdata")
script_sh = os.path.join(testdata, "script.sh")
test_csv = os.path.join(testdata, "test.csv")


def test_shell_task_no_io():
    """A ShellTask declared without inputs or outputs can be invoked with no arguments."""
    hello_script = """
echo "Hello World!"
"""
    task = ShellTask(name="test", script=hello_script)
    task()


def test_shell_task_fail():
    """A script that invokes an unknown command must surface the shell's failure."""
    t = ShellTask(
        name="test",
        script="""
non-existent blah
""",
    )

    # Other tests in this module expect CalledProcessError when the script exits non-zero.
    # Asserting on it (rather than bare Exception) keeps unrelated errors, such as a
    # TypeError from a broken call signature, from making this test pass.
    with pytest.raises(CalledProcessError):
        t()


def test_input_substitution_primitive():
    """Primitive inputs (str, int, datetime) are interpolated into the script template."""
    template = """
set -ex
cat {{ .inputs.f }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}"
"""
    task = ShellTask(
        name="test",
        script=template,
        inputs=kwtypes(f=str, y=int, j=datetime.datetime),
    )
    when = datetime.datetime(2021, 11, 10, 12, 15, 0)

    # These two files are expected to sit next to this module, so the calls must not raise.
    for existing in ("__init__.py", "test_shell.py"):
        task(f=os.path.join(test_file_path, existing), y=5, j=when)

    # A path that does not exist must make the call fail.
    with pytest.raises(CalledProcessError):
        task(f="non_exist.py", y=5, j=when)


def test_input_substitution_files():
    """File and directory inputs are accepted; with no output locations the call returns None."""
    template = """
cat {{ .inputs.f }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}"
"""
    task = ShellTask(
        name="test",
        script=template,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
    )

    result = task(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
    assert result is None


def test_input_output_substitution_files():
    """The output location is templated from the input path and receives the script's output."""
    script = """
cat {{ .inputs.f }} > {{ .outputs.y }}
"""
    t = ShellTask(
        name="test",
        debug=True,
        script=script,
        inputs=kwtypes(f=CSVFile),
        output_locs=[
            OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.mod"),
        ],
    )

    assert t.script == script

    contents = "1,2,3,4\n"
    # Work inside a throwaway directory so the generated "<input>.mod" file is cleaned up.
    with tempfile.TemporaryDirectory() as tmp:
        # Named csv_path (not csv) so the stdlib module name is not shadowed.
        csv_path = os.path.join(tmp, "abc.csv")
        with open(csv_path, "w") as src:
            src.write(contents)

        y = t(f=csv_path)
        assert y.path.endswith(".mod")
        assert os.path.exists(y.path)
        # The script copies the input verbatim, so the output must match it exactly.
        with open(y.path) as produced:
            assert produced.read() == contents


def test_input_single_output_substitution_files():
    """With a single output location declared, the call returns an object exposing that output's path."""
    template = """
cat {{ .inputs.f }} >> {{ .outputs.y }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}"
"""
    pyc_output = OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc")
    task = ShellTask(
        name="test",
        debug=True,
        script=template,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        output_locs=[pyc_output],
    )
    assert task.script == template

    produced = task(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
    assert produced.path[-4:] == ".pyc"


def test_input_output_extra_var_in_template():
    """A template that references an undeclared input alongside declared ones raises ValueError."""
    template = """
cat {{ .inputs.f }} {{ .inputs.missing }} >> {{ .outputs.y }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}"
"""
    outputs = [
        OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"),
        OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"),
    ]
    task = ShellTask(
        name="test",
        debug=True,
        script=template,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        output_locs=outputs,
    )

    # `.inputs.missing` is not among the declared inputs (f, y, j).
    with pytest.raises(ValueError):
        task(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))


def test_input_output_extra_input():
    """A template whose only file argument is an undeclared input raises ValueError when called."""
    template = """
cat {{ .inputs.missing }} >> {{ .outputs.y }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}"
"""
    outputs = [
        OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"),
        OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"),
    ]
    task = ShellTask(
        name="test",
        debug=True,
        script=template,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        output_locs=outputs,
    )

    # The script never uses the declared input `f` and instead names `.inputs.missing`.
    with pytest.raises(ValueError):
        task(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))


def test_shell_script():
    """A ShellTask can be built from ``script_file`` instead of an inline script and then invoked."""
    outputs = [
        OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"),
        OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"),
    ]
    task = ShellTask(
        name="test2",
        debug=True,
        script_file=script_sh,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        output_locs=outputs,
    )
    assert task.script_file == script_sh

    task(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
6 changes: 6 additions & 0 deletions tests/flytekit/unit/extras/tasks/testdata/script.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash

# Abort on the first failing command (-e) and echo each command as it runs (-x).
set -ex

# The {{ ... }} placeholders are template variables; presumably the ShellTask that loads
# this file substitutes them before execution -- confirm against flytekit.extras.tasks.shell.
# Append the input file's contents to the output file, then print a summary line.
cat "{{ .inputs.f }}" >> "{{ .outputs.y }}"
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}"
Empty file.
44 changes: 44 additions & 0 deletions tests/flytekit/unit/models/admin/test_node_executions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,51 @@
from flytekit.models import node_execution as node_execution_models
from flytekit.models.core import catalog, identifier
from tests.flytekit.unit.common_tests.test_workflow_promote import get_compiled_workflow_closure


def test_metadata():
    """``NodeExecutionMetaData`` survives a round trip through its IDL representation."""
    original = node_execution_models.NodeExecutionMetaData(
        retry_group="0",
        is_parent_node=True,
        spec_node_id="n0",
    )
    idl = original.to_flyte_idl()
    restored = node_execution_models.NodeExecutionMetaData.from_flyte_idl(idl)
    assert original == restored


def test_workflow_node_metadata():
    """``WorkflowNodeMetadata`` keeps the given execution id and round-trips through IDL."""
    execution_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    original = node_execution_models.WorkflowNodeMetadata(execution_id=execution_id)

    # The model must hand back the very same identifier object it was built with.
    assert original.execution_id is execution_id

    idl = original.to_flyte_idl()
    restored = node_execution_models.WorkflowNodeMetadata.from_flyte_idl(idl)
    assert original == restored


def test_task_node_metadata():
    """``TaskNodeMetadata`` exposes its cache status and catalog key and round-trips through IDL."""
    source_task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    workflow_execution = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    node_execution = identifier.NodeExecutionIdentifier("node_id", workflow_execution)
    task_execution = identifier.TaskExecutionIdentifier(source_task_id, node_execution, 3)
    dataset_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "t1", "abcdef")
    artifact_tag = catalog.CatalogArtifactTag("my-artifact-id", "some name")
    catalog_metadata = catalog.CatalogMetadata(
        dataset_id=dataset_id,
        artifact_tag=artifact_tag,
        source_task_execution=task_execution,
    )

    original = node_execution_models.TaskNodeMetadata(cache_status=0, catalog_key=catalog_metadata)
    assert original.cache_status == 0
    assert original.catalog_key == catalog_metadata

    restored = node_execution_models.TaskNodeMetadata.from_flyte_idl(original.to_flyte_idl())
    assert restored == original


def test_dynamic_wf_node_metadata():
    """``DynamicWorkflowNodeMetadata`` exposes its id and compiled workflow and round-trips through IDL."""
    workflow_id = identifier.Identifier(identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version")
    compiled_closure = get_compiled_workflow_closure()

    original = node_execution_models.DynamicWorkflowNodeMetadata(id=workflow_id, compiled_workflow=compiled_closure)
    assert original.id == workflow_id
    assert original.compiled_workflow == compiled_closure

    idl = original.to_flyte_idl()
    restored = node_execution_models.DynamicWorkflowNodeMetadata.from_flyte_idl(idl)
    assert restored == original
32 changes: 32 additions & 0 deletions tests/flytekit/unit/models/core/test_catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from flytekit.models.core import catalog, identifier


def test_catalog_artifact_tag():
    """``CatalogArtifactTag`` exposes its fields and keeps them across an IDL round trip."""
    artifact_id, tag_name = "my-artifact-id", "some name"

    original = catalog.CatalogArtifactTag(artifact_id, tag_name)
    assert original.artifact_id == "my-artifact-id"
    assert original.name == "some name"

    restored = catalog.CatalogArtifactTag.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    assert restored.artifact_id == "my-artifact-id"
    assert restored.name == "some name"


def test_catalog_metadata():
    """``CatalogMetadata`` returns the objects it was built with and round-trips through IDL."""
    source_task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    workflow_execution = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    node_execution = identifier.NodeExecutionIdentifier("node_id", workflow_execution)
    task_execution = identifier.TaskExecutionIdentifier(source_task_id, node_execution, 3)
    dataset_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "t1", "abcdef")
    artifact_tag = catalog.CatalogArtifactTag("my-artifact-id", "some name")

    original = catalog.CatalogMetadata(
        dataset_id=dataset_id,
        artifact_tag=artifact_tag,
        source_task_execution=task_execution,
    )

    # Identity (not just equality): the accessors must return the exact objects passed in.
    assert original.dataset_id is dataset_id
    assert original.source_execution is task_execution
    assert original.source_task_execution is task_execution
    assert original.artifact_tag is artifact_tag

    restored = catalog.CatalogMetadata.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
77 changes: 0 additions & 77 deletions tests/flytekit/unit/remote/test_identifier.py

This file was deleted.

96 changes: 4 additions & 92 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
@@ -6,20 +6,8 @@
from flytekit.common.exceptions import user as user_exceptions
from flytekit.configuration import internal
from flytekit.models import common as common_models
from flytekit.models.admin.workflow import Workflow
from flytekit.models.core.identifier import (
Identifier,
NodeExecutionIdentifier,
ResourceType,
WorkflowExecutionIdentifier,
)
from flytekit.models.core.identifier import ResourceType, WorkflowExecutionIdentifier
from flytekit.models.execution import Execution
from flytekit.models.interface import TypedInterface, Variable
from flytekit.models.launch_plan import LaunchPlan
from flytekit.models.node_execution import NodeExecution, NodeExecutionMetaData
from flytekit.models.task import Task
from flytekit.models.types import LiteralType, SimpleType
from flytekit.remote import FlyteWorkflow
from flytekit.remote.remote import FlyteRemote

CLIENT_METHODS = {
@@ -41,50 +29,6 @@
}


@patch("flytekit.clients.friendly.SynchronousFlyteClient")
@patch("flytekit.configuration.platform.URL")
@patch("flytekit.configuration.platform.INSECURE")
@pytest.mark.parametrize(
    "entity_cls,resource_type",
    [
        [Workflow, ResourceType.WORKFLOW],
        [Task, ResourceType.TASK],
        [LaunchPlan, ResourceType.LAUNCH_PLAN],
    ],
)
def test_remote_fetch_execute_entities_task_workflow_launchplan(
    mock_insecure,
    mock_url,
    mock_client,
    entity_cls,
    resource_type,
):
    """Fetching by name resolves to the first listed entity when no version is given.

    Runs once per entity kind (workflow, task, launch plan) against a mocked admin client
    that lists a "latest" and an "old" version of the same entity.
    """
    admin_entities = []
    for version in ["latest", "old"]:
        entity_id = Identifier(resource_type, "p1", "d1", "n1", version)
        # Launch plan models are built with two extra positional arguments; the others with one.
        if resource_type != ResourceType.LAUNCH_PLAN:
            extra_args = [MagicMock()]
        else:
            extra_args = [MagicMock(), MagicMock()]
        admin_entities.append(entity_cls(entity_id, *extra_args))

    mock_url.get.return_value = "localhost"
    mock_insecure.get.return_value = True

    # A fresh mock replaces the remote's client; its list method returns (entities, token).
    mock_client = MagicMock()
    list_method = getattr(mock_client, CLIENT_METHODS[resource_type])
    list_method.return_value = (admin_entities, "")

    remote = FlyteRemote.from_config("p1", "d1")
    remote._client = mock_client
    fetch = getattr(remote, REMOTE_METHODS[resource_type])

    flyte_entity_latest = fetch(name="n1", version="latest")
    flyte_entity_latest_implicit = fetch(name="n1")
    flyte_entity_old = fetch(name="n1", version="old")

    assert flyte_entity_latest.entity_type_text == ENTITY_TYPE_TEXT[resource_type]
    assert flyte_entity_latest.id == admin_entities[0].id
    assert flyte_entity_latest.id == flyte_entity_latest_implicit.id
    assert flyte_entity_latest.id != flyte_entity_old.id


@patch("flytekit.clients.friendly.SynchronousFlyteClient")
@patch("flytekit.configuration.platform.URL")
@patch("flytekit.configuration.platform.INSECURE")
@@ -106,39 +50,7 @@ def test_remote_fetch_workflow_execution(mock_insecure, mock_url, mock_client_ma
assert flyte_workflow_execution.id == admin_workflow_execution.id


@patch("flytekit.configuration.platform.URL")
@patch("flytekit.configuration.platform.INSECURE")
def test_get_node_execution_interface(mock_insecure, mock_url):
    """``_get_node_execution_interface`` returns the interface of the task behind the matching node."""
    expected_interface = TypedInterface(
        {"in1": Variable(LiteralType(simple=SimpleType.STRING), "in1 description")},
        {"out1": Variable(LiteralType(simple=SimpleType.INTEGER), "out1 description")},
    )
    workflow_execution_id = WorkflowExecutionIdentifier("p1", "d1", "exec_name")
    node_exec_id = NodeExecutionIdentifier("node_id", workflow_execution_id)

    # Build the node from the inside out: task -> task node -> workflow node with a matching id.
    flyte_task = MagicMock()
    flyte_task.interface = expected_interface
    task_node = MagicMock()
    task_node.flyte_task = flyte_task
    mock_node = MagicMock()
    mock_node.id = node_exec_id.node_id
    mock_node.task_node = task_node

    flyte_workflow = FlyteWorkflow([mock_node], None, None, None, None, None)

    mock_url.get.return_value = "localhost"
    mock_insecure.get.return_value = True

    remote = FlyteRemote.from_config("p1", "d1")
    remote._client = MagicMock()

    node_execution = NodeExecution(node_exec_id, None, None, NodeExecutionMetaData(None, True, None))
    actual_interface = remote._get_node_execution_interface(node_execution, flyte_workflow)
    assert actual_interface == expected_interface


@patch("flytekit.remote.workflow_execution.FlyteWorkflowExecution.promote_from_model")
@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
@patch("flytekit.configuration.platform.URL")
@patch("flytekit.configuration.platform.INSECURE")
def test_underscore_execute_uses_launch_plan_attributes(mock_insecure, mock_url, mock_wf_exec):
@@ -171,7 +83,7 @@ def local_assertions(*args, **kwargs):
)


@patch("flytekit.remote.workflow_execution.FlyteWorkflowExecution.promote_from_model")
@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
@patch("flytekit.configuration.auth.ASSUMABLE_IAM_ROLE")
@patch("flytekit.configuration.platform.URL")
@patch("flytekit.configuration.platform.INSECURE")
@@ -201,7 +113,7 @@ def local_assertions(*args, **kwargs):
)


@patch("flytekit.remote.workflow_execution.FlyteWorkflowExecution.promote_from_model")
@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
@patch("flytekit.configuration.platform.URL")
@patch("flytekit.configuration.platform.INSECURE")
def test_execute_with_wrong_input_key(mock_insecure, mock_url, mock_wf_exec):
5 changes: 0 additions & 5 deletions tests/flytekit/unit/remote/test_wrapper_classes.py
Original file line number Diff line number Diff line change
@@ -68,9 +68,6 @@ def wf(b: int) -> int:
assert list(fwf.interface.inputs.keys()) == ["b"]
assert len(fwf.nodes) == 1
assert len(fwf.flyte_nodes) == 1
flyte_subwfs = fwf.get_sub_workflows()
assert len(flyte_subwfs) == 1
assert fwf.nodes[0].workflow_node.sub_workflow_ref == flyte_subwfs[0].id

# Test another subwf that calls a launch plan instead of the sub_wf directly
@workflow
@@ -88,8 +85,6 @@ def wf2(b: int) -> int:
assert list(fwf.interface.inputs.keys()) == ["b"]
assert len(fwf.nodes) == 1
assert len(fwf.flyte_nodes) == 1
flyte_subwfs = fwf.get_sub_workflows()
assert len(flyte_subwfs) == 0
# The resource type will be different, so just check the name
assert fwf.nodes[0].workflow_node.launchplan_ref.name == list(lp_specs.values())[0].workflow_id.name

0 comments on commit fe378f5

Please sign in to comment.