-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
32 changed files
with
1,371 additions
and
841 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
import datetime | ||
import logging | ||
import os | ||
import re | ||
import subprocess | ||
import typing | ||
from dataclasses import dataclass | ||
|
||
from flytekit.core.context_manager import ExecutionParameters | ||
from flytekit.core.interface import Interface | ||
from flytekit.core.python_function_task import PythonInstanceTask | ||
from flytekit.core.task import TaskPlugins | ||
from flytekit.types.directory import FlyteDirectory | ||
from flytekit.types.file import FlyteFile | ||
|
||
|
||
@dataclass
class OutputLocation:
    """
    Describes where one output of a shell script is expected on the file system.

    Args:
        var: str The name of the output variable
        var_type: typing.Type The type of the output variable
        location: os.PathLike The path the output will be written to, or a template that refers to input
            variables and is expanded into that path. Of the form ``"{{ .inputs.v }}.tmp.md"``.
            For an input ``v`` at path ``/tmp/abc.csv`` this example resolves to ``/tmp/abc.csv.tmp.md``
    """

    var: str
    var_type: typing.Type
    location: typing.Union[os.PathLike, str]
|
||
|
||
def _stringify(v: typing.Any) -> str:
    """
    Return the string that stands in for ``v`` when it is substituted into a template.

    FlyteFile and FlyteDirectory values are downloaded first and rendered as their ``path``, datetimes are
    rendered with ``isoformat()``, and every other value goes through ``str``.
    """
    if isinstance(v, (FlyteFile, FlyteDirectory)):
        # Download before handing out the path, so the returned path is the downloaded one.
        v.download()
        return v.path
    return v.isoformat() if isinstance(v, datetime.datetime) else str(v)
|
||
|
||
def _interpolate(tmpl: str, regex: re.Pattern, validate_all_match: bool = True, **kwargs) -> str:
    """
    Substitutes all templates that match the supplied regex with the given inputs and returns the substituted
    string. The given string itself is not modified.

    The substitution is done in a single pass over ``tmpl``: text that comes from an input value is never
    scanned again for template expressions, and each variable is converted to a string at most once even when
    the template refers to it several times.

    Args:
        tmpl: The template text to expand
        regex: Compiled pattern whose first group is the template expression to be replaced and whose second
            group is the name of the variable to look up in ``kwargs``
        validate_all_match: When True, every entry of ``kwargs`` has to be referenced by the template
        **kwargs: Variable name to value

    Returns:
        ``tmpl`` with every matched expression replaced by the string form of its variable

    Raises:
        ValueError: if the template refers to a variable that is not in ``kwargs``, or if
            ``validate_all_match`` is set and some entries of ``kwargs`` are never referenced
    """
    # Variable name -> string form, filled lazily as variables are encountered in the template.
    rendered = {}

    def _substitute(match) -> str:
        expr, var = match.group(1), match.group(2)
        if var not in kwargs:
            raise ValueError(f"Variable {var} in Query (part of {expr}) not found in inputs {kwargs.keys()}")
        if var not in rendered:
            # str conversion should be deliberate, with right conversion for each type
            rendered[var] = _stringify(kwargs[var])
        # Only group 1 is replaced; anything else the pattern consumed around it is kept unchanged.
        whole, offset = match.group(0), match.start(0)
        return whole[: match.start(1) - offset] + rendered[var] + whole[match.end(1) - offset :]

    modified = regex.sub(_substitute, tmpl)

    if validate_all_match:
        diff = set(kwargs).difference(rendered)
        if diff:
            raise ValueError(f"Extra Inputs have no matches in script template - missing {diff}")
    return modified
|
||
|
||
def _dummy_task_func():
    """
    Placeholder function that exists only to satisfy the inner PythonTask requirements; it does nothing.
    """
    return
|
||
|
||
T = typing.TypeVar("T") | ||
|
||
|
||
class ShellTask(PythonInstanceTask[T]):
    """
    A task that runs a shell script.

    ``{{ .inputs.<name> }}`` and ``{{ .outputs.<name> }}`` placeholders in the script are substituted before
    the script is run, and the declared output locations are returned as FlyteFile / FlyteDirectory values.
    """

    # Group 1 is the whole placeholder (the text that gets replaced), group 2 is the variable name.
    # The dots are escaped so that only a literal ``.inputs.`` / ``.outputs.`` prefix is recognised.
    _INPUT_REGEX = re.compile(r"({{\s*\.inputs\.(\w+)\s*}})", re.IGNORECASE)
    _OUTPUT_REGEX = re.compile(r"({{\s*\.outputs\.(\w+)\s*}})", re.IGNORECASE)

    def __init__(
        self,
        name: str,
        debug: bool = False,
        script: typing.Optional[str] = None,
        script_file: typing.Optional[str] = None,
        task_config: T = None,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        output_locs: typing.Optional[typing.List[OutputLocation]] = None,
        **kwargs,
    ):
        """
        Args:
            name: str Name of the Task. Should be unique in the project
            debug: bool Print the generated script and other debugging information
            script: The actual script specified as a string
            script_file: A path to the file that contains the script (Only script or script_file) can be provided
            task_config: T Configuration for the task, can be either a Pod (or coming soon, BatchJob) config
            inputs: A Dictionary of input names to types
            output_locs: A list of :py:class:`OutputLocations`
            **kwargs: Other arguments that can be passed to :ref:class:`PythonInstanceTask`

        Raises:
            ValueError: if both or neither of ``script`` / ``script_file`` are given, if ``script_file`` does
                not exist, if ``task_config`` is neither None nor a Pod config, or if an output location is
                invalid
        """
        if script and script_file:
            raise ValueError("Only either of script or script_file can be provided")
        if not script and not script_file:
            raise ValueError("Either a script or script_file is needed")
        if script_file:
            if not os.path.exists(script_file):
                raise ValueError(f"FileNotFound: the specified Script file at path {script_file} cannot be loaded")
            script_file = os.path.abspath(script_file)

        if task_config is not None:
            # str(type(x)) renders as "<class '...'>" and can never equal the bare dotted name, so the fully
            # qualified class name is assembled explicitly. The check is by name rather than isinstance;
            # presumably so that the pod plugin does not have to be imported here.
            config_cls = type(task_config)
            if f"{config_cls.__module__}.{config_cls.__qualname__}" != "flytekitplugins.pod.task.Pod":
                raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.")

        # Each instance of ShellTask instantiates an underlying task with a dummy function that will only be used
        # to run pre- and post- execute functions using the corresponding task plugin.
        # We rename the function name here to ensure the generated task has a unique name and avoid duplicate task name
        # errors.
        # This seem like a hack. We should use a plugin_class that doesn't require a fake-function to make work.
        plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
        self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
        # Rename the internal task so that there are no conflicts at serialization time. Technically these internal
        # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities
        # at serialization time.
        self._config_task_instance._name = f"_bash.{name}"
        self._script = script
        self._script_file = script_file
        self._debug = debug
        self._output_locs = output_locs if output_locs else []
        outputs = self._validate_output_locs()
        super().__init__(
            name,
            task_config,
            task_type=self._config_task_instance.task_type,
            interface=Interface(inputs=inputs, outputs=outputs),
            **kwargs,
        )

    def _validate_output_locs(self) -> typing.Dict[str, typing.Type]:
        """
        Checks every configured output location and returns the output name to type mapping for the interface.

        Raises:
            ValueError: if an entry is None, is not an OutputLocation, has no location, or has a type that is
                not FlyteFile / FlyteDirectory or derived from them
        """
        outputs = {}
        for v in self._output_locs:
            if v is None:
                raise ValueError("OutputLocation cannot be none")
            if not isinstance(v, OutputLocation):
                raise ValueError("Every output type should be an output location on the file-system")
            if v.location is None:
                raise ValueError(f"Output Location not provided for output var {v.var}")
            if not issubclass(v.var_type, FlyteFile) and not issubclass(v.var_type, FlyteDirectory):
                raise ValueError(
                    "Currently only outputs of type FlyteFile/FlyteDirectory and their derived types are supported"
                )
            outputs[v.var] = v.var_type
        return outputs

    @property
    def script(self) -> typing.Optional[str]:
        """The script text; when a script file is used this is only set once ``execute`` has read the file."""
        return self._script

    @property
    def script_file(self) -> typing.Optional[os.PathLike]:
        """Absolute path of the script file, or None when the script was given inline."""
        return self._script_file

    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        """Delegates to the inner task created for the given task_config."""
        return self._config_task_instance.pre_execute(user_params)

    def execute(self, **kwargs) -> typing.Any:
        """
        Executes the given script by substituting the inputs and outputs and extracts the outputs from the filesystem

        Returns:
            None when no outputs are declared, the single output when exactly one is declared, otherwise a tuple
            of the outputs in declaration order

        Raises:
            ValueError: if the templates and the given inputs do not line up
            subprocess.CalledProcessError: if the script exits with a non-zero return code
        """
        logging.info(f"Running shell script as type {self.task_type}")
        if self.script_file:
            with open(self.script_file) as f:
                self._script = f.read()

        # Output locations may themselves refer to inputs; not every input has to appear in every location.
        outputs: typing.Dict[str, str] = {}
        if self._output_locs:
            for v in self._output_locs:
                outputs[v.var] = _interpolate(v.location, self._INPUT_REGEX, validate_all_match=False, **kwargs)

        gen_script = _interpolate(self._script, self._INPUT_REGEX, **kwargs)
        # For outputs it is not necessary that all outputs are used in the script, some are implicit outputs
        # for example gcc main.c will generate a.out automatically
        gen_script = _interpolate(gen_script, self._OUTPUT_REGEX, validate_all_match=False, **outputs)
        if self._debug:
            print("\n==============================================\n")
            print(gen_script)
            print("\n==============================================\n")

        try:
            subprocess.check_call(gen_script, shell=True)
        except subprocess.CalledProcessError as e:
            files = os.listdir("./")
            fstr = "\n-".join(files)
            # NOTE: check_call does not capture the child's streams, so e.stderr / e.stdout are None here; the
            # script's own output has already gone to this process's stdout / stderr.
            logging.error(
                f"Failed to Execute Script, return-code {e.returncode} \n"
                f"StdErr: {e.stderr}\n"
                f"StdOut: {e.stdout}\n"
                f" Current directory contents: .\n-{fstr}"
            )
            raise

        final_outputs = []
        for v in self._output_locs:
            if issubclass(v.var_type, FlyteFile):
                final_outputs.append(FlyteFile(outputs[v.var]))
            if issubclass(v.var_type, FlyteDirectory):
                final_outputs.append(FlyteDirectory(outputs[v.var]))
        if len(final_outputs) == 1:
            return final_outputs[0]
        if len(final_outputs) > 1:
            return tuple(final_outputs)
        return None

    def post_execute(self, user_params: ExecutionParameters, rval: typing.Any) -> typing.Any:
        """Delegates to the inner task created for the given task_config."""
        return self._config_task_instance.post_execute(user_params, rval)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from flyteidl.core import catalog_pb2 | ||
|
||
from flytekit.models import common as _common_models | ||
from flytekit.models.core import identifier as _identifier | ||
|
||
|
||
class CatalogArtifactTag(_common_models.FlyteIdlEntity):
    """Python-side counterpart of the ``CatalogArtifactTag`` message: an artifact id paired with a name."""

    def __init__(self, artifact_id: str, name: str):
        self._name = name
        self._artifact_id = artifact_id

    @property
    def artifact_id(self) -> str:
        """The artifact id given at construction."""
        return self._artifact_id

    @property
    def name(self) -> str:
        """The name given at construction."""
        return self._name

    def to_flyte_idl(self) -> catalog_pb2.CatalogArtifactTag:
        """Build the protobuf message carrying the same two fields."""
        artifact_id = self.artifact_id
        tag_name = self.name
        return catalog_pb2.CatalogArtifactTag(artifact_id=artifact_id, name=tag_name)

    @classmethod
    def from_flyte_idl(cls, p: catalog_pb2.CatalogArtifactTag) -> "CatalogArtifactTag":
        """Build an instance from the fields of the protobuf message ``p``."""
        return cls(artifact_id=p.artifact_id, name=p.name)
|
||
|
||
class CatalogMetadata(_common_models.FlyteIdlEntity):
    """
    Python-side counterpart of the ``CatalogMetadata`` message: a dataset id, an artifact tag and the task
    execution the entry came from.
    """

    def __init__(
        self,
        dataset_id: _identifier.Identifier,
        artifact_tag: CatalogArtifactTag,
        source_task_execution: _identifier.TaskExecutionIdentifier,
    ):
        self._source_task_execution = source_task_execution
        self._artifact_tag = artifact_tag
        self._dataset_id = dataset_id

    @property
    def dataset_id(self) -> _identifier.Identifier:
        """The dataset identifier given at construction."""
        return self._dataset_id

    @property
    def artifact_tag(self) -> CatalogArtifactTag:
        """The artifact tag given at construction."""
        return self._artifact_tag

    @property
    def source_task_execution(self) -> _identifier.TaskExecutionIdentifier:
        """The source task execution identifier given at construction."""
        return self._source_task_execution

    @property
    def source_execution(self) -> _identifier.TaskExecutionIdentifier:
        """
        This is a one of but for now there's only one thing in the one of
        """
        return self._source_task_execution

    def to_flyte_idl(self) -> catalog_pb2.CatalogMetadata:
        """Convert each member to its protobuf form and assemble the ``CatalogMetadata`` message."""
        dataset_pb = self.dataset_id.to_flyte_idl()
        tag_pb = self.artifact_tag.to_flyte_idl()
        source_pb = self.source_task_execution.to_flyte_idl()
        return catalog_pb2.CatalogMetadata(
            dataset_id=dataset_pb,
            artifact_tag=tag_pb,
            source_task_execution=source_pb,
        )

    @classmethod
    def from_flyte_idl(cls, pb: catalog_pb2.CatalogMetadata) -> "CatalogMetadata":
        """Build an instance by converting each field of the protobuf message ``pb``."""
        dataset = _identifier.Identifier.from_flyte_idl(pb.dataset_id)
        tag = CatalogArtifactTag.from_flyte_idl(pb.artifact_tag)
        # Add HasField check if more things are ever added to the one of
        source = _identifier.TaskExecutionIdentifier.from_flyte_idl(pb.source_task_execution)
        return cls(dataset_id=dataset, artifact_tag=tag, source_task_execution=source)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,239 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from flytekit.common.exceptions import user as _user_exceptions | ||
from flytekit.common.exceptions import user as user_exceptions | ||
from flytekit.models import execution as execution_models | ||
from flytekit.models import node_execution as node_execution_models | ||
from flytekit.models.admin import task_execution as admin_task_execution_models | ||
from flytekit.models.core import execution as core_execution_models | ||
from flytekit.remote.workflow import FlyteWorkflow | ||
|
||
|
||
class FlyteTaskExecution(admin_task_execution_models.TaskExecution):
    """A class encapsulating a task execution being run on a Flyte remote backend."""

    def __init__(self, *args, **kwargs):
        super(FlyteTaskExecution, self).__init__(*args, **kwargs)
        # Start out unset; this class only exposes them read-only through the properties below.
        self._inputs = None
        self._outputs = None

    @property
    def is_complete(self) -> bool:
        """Whether or not the execution is complete."""
        return self.closure.phase in {
            core_execution_models.TaskExecutionPhase.ABORTED,
            core_execution_models.TaskExecutionPhase.FAILED,
            core_execution_models.TaskExecutionPhase.SUCCEEDED,
        }

    @property
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs of the task execution in the standard Python format that is produced by
        the type engine. This is ``None`` for as long as the inputs have not been set on the object.
        """
        return self._inputs

    @property
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs of the task execution, if available, in the standard Python format that is produced by
        the type engine.
        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            raise user_exceptions.FlyteAssertion(
                "Please wait until the task execution has completed before requesting the outputs."
            )
        if self.error:
            raise user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
        return self._outputs

    @property
    def error(self) -> Optional[core_execution_models.ExecutionError]:
        """
        If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
        reaching completion.
        """
        if not self.is_complete:
            raise user_exceptions.FlyteAssertion(
                "Please wait until the task execution has completed before requesting error information."
            )
        return self.closure.error

    @classmethod
    def promote_from_model(cls, base_model: admin_task_execution_models.TaskExecution) -> "FlyteTaskExecution":
        """Build a FlyteTaskExecution from the closure, id, input_uri and is_parent of ``base_model``."""
        return cls(
            closure=base_model.closure,
            id=base_model.id,
            input_uri=base_model.input_uri,
            is_parent=base_model.is_parent,
        )
|
||
|
||
class FlyteWorkflowExecution(execution_models.Execution):
    """A class encapsulating a workflow execution being run on a Flyte remote backend."""

    def __init__(self, *args, **kwargs):
        super(FlyteWorkflowExecution, self).__init__(*args, **kwargs)
        # Start out unset; this class only exposes them read-only through the properties below.
        self._node_executions = None
        self._inputs = None
        self._outputs = None
        self._flyte_workflow: Optional[FlyteWorkflow] = None

    @property
    def node_executions(self) -> Dict[str, "FlyteNodeExecution"]:
        """Get a dictionary of node executions that are a part of this workflow execution."""
        return self._node_executions or {}

    @property
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs to the execution in the standard python format as dictated by the type engine.
        This is ``None`` for as long as the inputs have not been set on the object.
        """
        return self._inputs

    @property
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs to the execution in the standard python format as dictated by the type engine.
        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until the workflow execution has completed before requesting the outputs."
            )
        if self.error:
            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
        return self._outputs

    @property
    def error(self) -> Optional[core_execution_models.ExecutionError]:
        """
        If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
        reaching completion.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until a workflow has completed before checking for an error."
            )
        return self.closure.error

    @property
    def is_complete(self) -> bool:
        """
        Whether or not the execution is complete.
        """
        return self.closure.phase in {
            core_execution_models.WorkflowExecutionPhase.ABORTED,
            core_execution_models.WorkflowExecutionPhase.FAILED,
            core_execution_models.WorkflowExecutionPhase.SUCCEEDED,
            core_execution_models.WorkflowExecutionPhase.TIMED_OUT,
        }

    @classmethod
    def promote_from_model(cls, base_model: execution_models.Execution) -> "FlyteWorkflowExecution":
        """Build a FlyteWorkflowExecution from the closure, id and spec of ``base_model``."""
        return cls(
            closure=base_model.closure,
            id=base_model.id,
            spec=base_model.spec,
        )
|
||
|
||
class FlyteNodeExecution(node_execution_models.NodeExecution):
    """A class encapsulating a node execution being run on a Flyte remote backend."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Start out unset / empty; this class only exposes them read-only through the properties below.
        self._task_executions = None
        self._workflow_executions = []
        self._underlying_node_executions = None
        self._inputs = None
        self._outputs = None
        self._interface = None

    @property
    def task_executions(self) -> List[FlyteTaskExecution]:
        """The task executions of this node, or an empty list when there are none."""
        executions = self._task_executions
        return executions if executions else []

    @property
    def workflow_executions(self) -> List[FlyteWorkflowExecution]:
        """The workflow executions recorded on this node."""
        return self._workflow_executions

    @property
    def subworkflow_node_executions(self) -> Dict[str, FlyteNodeExecution]:
        """
        This returns underlying node executions in instances where the current node execution is
        a parent node. This happens when it's either a static or dynamic subworkflow.
        The result is keyed by each underlying execution's ``id.node_id``.
        """
        children = self._underlying_node_executions
        if children is None:
            return {}
        return {child.id.node_id: child for child in children}

    @property
    def executions(self) -> List[Union[FlyteTaskExecution, FlyteWorkflowExecution]]:
        """
        The task executions when there are any, otherwise the underlying node executions, otherwise an
        empty list.
        """
        # NOTE(review): the fallback is the underlying *node* executions, which does not look like what the
        # return annotation describes - confirm which one is intended.
        tasks = self.task_executions
        if tasks:
            return tasks
        underlying = self._underlying_node_executions
        if underlying:
            return underlying
        return []

    @property
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs to the execution in the standard python format as dictated by the type engine.
        """
        return self._inputs

    @property
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs to the execution in the standard python format as dictated by the type engine.
        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            message = "Please wait until the node execution has completed before requesting the outputs."
            raise _user_exceptions.FlyteAssertion(message)
        if self.error:
            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
        return self._outputs

    @property
    def error(self) -> core_execution_models.ExecutionError:
        """
        If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
        reaching completion.
        """
        if self.is_complete:
            return self.closure.error
        raise _user_exceptions.FlyteAssertion(
            "Please wait until the node execution has completed before requesting error information."
        )

    @property
    def is_complete(self) -> bool:
        """Whether or not the execution is complete."""
        phases = core_execution_models.NodeExecutionPhase
        terminal_phases = {
            phases.ABORTED,
            phases.FAILED,
            phases.SKIPPED,
            phases.SUCCEEDED,
            phases.TIMED_OUT,
        }
        return self.closure.phase in terminal_phases

    @classmethod
    def promote_from_model(cls, base_model: node_execution_models.NodeExecution) -> "FlyteNodeExecution":
        """Build a FlyteNodeExecution from the closure, id, input_uri and metadata of ``base_model``."""
        return cls(
            closure=base_model.closure,
            id=base_model.id,
            input_uri=base_model.input_uri,
            metadata=base_model.metadata,
        )

    @property
    def interface(self) -> "flytekit.remote.interface.TypedInterface":
        """
        Return the interface of the task or subworkflow associated with this node execution.
        """
        return self._interface
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,76 +0,0 @@ | ||
from typing import Any, Dict | ||
|
||
from flytekit.common.exceptions import user as _user_exceptions | ||
from flytekit.models import execution as _execution_models | ||
from flytekit.models.core import execution as _core_execution_models | ||
from flytekit.remote import identifier as _core_identifier | ||
from flytekit.remote import nodes as _nodes | ||
|
||
|
||
class FlyteWorkflowExecution(_execution_models.Execution):
    """A class encapsulating a workflow execution being run on a Flyte remote backend."""

    def __init__(self, *args, **kwargs):
        super(FlyteWorkflowExecution, self).__init__(*args, **kwargs)
        # Start out unset; this class only exposes them read-only through the properties below.
        self._node_executions = None
        self._inputs = None
        self._outputs = None

    @property
    def node_executions(self) -> Dict[str, _nodes.FlyteNodeExecution]:
        """Get a dictionary of node executions that are a part of this workflow execution."""
        return self._node_executions or {}

    @property
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs to the execution in the standard python format as dictated by the type engine.
        This is ``None`` for as long as the inputs have not been set on the object.
        """
        return self._inputs

    @property
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs to the execution in the standard python format as dictated by the type engine.
        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until the workflow execution has completed before requesting the outputs."
            )
        if self.error:
            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
        return self._outputs

    @property
    def error(self) -> _core_execution_models.ExecutionError:
        """
        If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
        reaching completion.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until a workflow has completed before checking for an error."
            )
        return self.closure.error

    @property
    def is_complete(self) -> bool:
        """
        Whether or not the execution is complete.
        """
        return self.closure.phase in {
            _core_execution_models.WorkflowExecutionPhase.ABORTED,
            _core_execution_models.WorkflowExecutionPhase.FAILED,
            _core_execution_models.WorkflowExecutionPhase.SUCCEEDED,
            _core_execution_models.WorkflowExecutionPhase.TIMED_OUT,
        }

    @classmethod
    def promote_from_model(cls, base_model: _execution_models.Execution) -> "FlyteWorkflowExecution":
        """
        Build a FlyteWorkflowExecution from ``base_model``, promoting its id to the remote
        WorkflowExecutionIdentifier type.
        """
        return cls(
            closure=base_model.closure,
            id=_core_identifier.WorkflowExecutionIdentifier.promote_from_model(base_model.id),
            spec=base_model.spec,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
import datetime | ||
import os | ||
import tempfile | ||
from subprocess import CalledProcessError | ||
|
||
import pytest | ||
|
||
from flytekit import kwtypes | ||
from flytekit.extras.tasks.shell import OutputLocation, ShellTask | ||
from flytekit.types.directory import FlyteDirectory | ||
from flytekit.types.file import CSVFile, FlyteFile | ||
|
||
test_file_path = os.path.dirname(os.path.realpath(__file__)) | ||
testdata = os.path.join(test_file_path, "testdata") | ||
script_sh = os.path.join(testdata, "script.sh") | ||
test_csv = os.path.join(testdata, "test.csv") | ||
|
||
|
||
def test_shell_task_no_io():
    """A script that declares no inputs or outputs can be run by calling the task without arguments."""
    hello_task = ShellTask(
        name="test",
        script="""
echo "Hello World!"
""",
    )

    hello_task()
|
||
|
||
def test_shell_task_fail():
    """Calling a task whose script invokes ``non-existent`` must raise."""
    failing_task = ShellTask(
        name="test",
        script="""
non-existent blah
""",
    )

    with pytest.raises(Exception):
        failing_task()
|
||
|
||
def test_input_substitution_primitive():
    """str, int and datetime inputs are substituted; a path that makes the script fail raises CalledProcessError."""
    when = datetime.datetime(2021, 11, 10, 12, 15, 0)
    task = ShellTask(
        name="test",
        script="""
set -ex
cat {{ .inputs.f }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}"
""",
        inputs=kwtypes(f=str, y=int, j=datetime.datetime),
    )

    # Both files sit next to this test module.
    for existing in ("__init__.py", "test_shell.py"):
        task(f=os.path.join(test_file_path, existing), y=5, j=when)
    with pytest.raises(CalledProcessError):
        task(f="non_exist.py", y=5, j=when)
|
||
|
||
def test_input_substitution_files():
    """File and directory typed inputs are accepted; with no outputs declared the call returns None."""
    task = ShellTask(
        name="test",
        script="""
cat {{ .inputs.f }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}"
""",
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
    )

    result = task(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
    assert result is None
|
||
|
||
def test_input_output_substitution_files():
    """An output whose location is derived from an input path is written by the script and returned."""
    script = """
cat {{ .inputs.f }} > {{ .outputs.y }}
"""
    task = ShellTask(
        name="test",
        debug=True,
        script=script,
        inputs=kwtypes(f=CSVFile),
        output_locs=[
            OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.mod"),
        ],
    )

    assert task.script == script

    contents = "1,2,3,4\n"
    with tempfile.TemporaryDirectory() as tmp:
        csv = os.path.join(tmp, "abc.csv")
        print(csv)
        with open(csv, "w") as src:
            src.write(contents)
        produced = task(f=csv)
        assert produced.path[-4:] == ".mod"
        assert os.path.exists(produced.path)
        # The script copies the input verbatim, so the output must read back identical.
        with open(produced.path) as dst:
            copied = dst.read()
        assert copied == contents
|
||
|
||
def test_input_single_output_substitution_files():
    """Run a script with three inputs and a single templated file output.

    Checks that the task keeps the script unchanged and that the returned
    output path ends in ``.pyc``, as declared in the output location.
    """
    script_text = """
    cat {{ .inputs.f }} >> {{ .outputs.y }}
    echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}"
    """
    single_output = OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc")
    task = ShellTask(
        name="test",
        debug=True,
        script=script_text,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        output_locs=[single_output],
    )

    assert task.script == script_text

    when = datetime.datetime(2021, 11, 10, 12, 15, 0)
    produced = task(f=test_csv, y=testdata, j=when)
    assert produced.path[-4:] == ".pyc"
|
||
|
||
def test_input_output_extra_var_in_template():
    """Expect ``ValueError`` when the script references an undeclared input.

    The template uses ``{{ .inputs.missing }}`` alongside declared inputs,
    but ``missing`` is not part of the task's ``inputs``.
    """
    declared_outputs = [
        OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"),
        OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"),
    ]
    task = ShellTask(
        name="test",
        debug=True,
        script="""
        cat {{ .inputs.f }} {{ .inputs.missing }} >> {{ .outputs.y }}
        echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}"
        """,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        output_locs=declared_outputs,
    )

    when = datetime.datetime(2021, 11, 10, 12, 15, 0)
    with pytest.raises(ValueError):
        task(f=test_csv, y=testdata, j=when)
|
||
|
||
def test_input_output_extra_input():
    """Expect ``ValueError`` when the script only reads an undeclared input.

    Here the ``cat`` line uses ``{{ .inputs.missing }}`` and never references
    the declared input ``f`` directly in the script.
    """
    declared_outputs = [
        OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"),
        OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"),
    ]
    task = ShellTask(
        name="test",
        debug=True,
        script="""
        cat {{ .inputs.missing }} >> {{ .outputs.y }}
        echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}"
        """,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        output_locs=declared_outputs,
    )

    when = datetime.datetime(2021, 11, 10, 12, 15, 0)
    with pytest.raises(ValueError):
        task(f=test_csv, y=testdata, j=when)
|
||
|
||
def test_shell_script():
    """Build a task from a script file instead of an inline script and run it.

    Checks that the task exposes the same ``script_file`` it was given, then
    runs it with a file, a directory and a datetime input.
    """
    declared_outputs = [
        OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"),
        OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"),
    ]
    task = ShellTask(
        name="test2",
        debug=True,
        script_file=script_sh,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        output_locs=declared_outputs,
    )

    assert task.script_file == script_sh

    when = datetime.datetime(2021, 11, 10, 12, 15, 0)
    task(f=test_csv, y=testdata, j=when)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash

# Exit on the first failing command (-e) and echo each command as it runs (-x).
set -ex

# NOTE(review): the {{ .inputs.* }} / {{ .outputs.* }} placeholders are presumably
# replaced with concrete values by the task that loads this file — confirm against
# the ShellTask implementation.
# Append the contents of input `f` to output `y`.
cat "{{ .inputs.f }}" >> "{{ .outputs.y }}"
# Print a greeting that mentions inputs `y` and `j` and output `x`.
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}"
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,51 @@ | ||
from flytekit.models import node_execution as node_execution_models | ||
from flytekit.models.core import catalog, identifier | ||
from tests.flytekit.unit.common_tests.test_workflow_promote import get_compiled_workflow_closure | ||
|
||
|
||
def test_metadata():
    """NodeExecutionMetaData equals itself after a round trip through Flyte IDL."""
    original = node_execution_models.NodeExecutionMetaData(
        retry_group="0",
        is_parent_node=True,
        spec_node_id="n0",
    )
    idl_message = original.to_flyte_idl()
    restored = node_execution_models.NodeExecutionMetaData.from_flyte_idl(idl_message)
    assert original == restored
|
||
|
||
def test_workflow_node_metadata():
    """WorkflowNodeMetadata keeps its execution id and round-trips through IDL."""
    execution_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")

    metadata = node_execution_models.WorkflowNodeMetadata(execution_id=execution_id)
    # The model is expected to hand back the very same identifier object.
    assert metadata.execution_id is execution_id

    idl_message = metadata.to_flyte_idl()
    restored = node_execution_models.WorkflowNodeMetadata.from_flyte_idl(idl_message)
    assert metadata == restored
|
||
|
||
def test_task_node_metadata():
    """TaskNodeMetadata exposes its fields and round-trips through Flyte IDL."""
    # Build the identifier chain needed for a task execution id.
    task_identifier = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    workflow_execution = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    node_execution = identifier.NodeExecutionIdentifier("node_id", workflow_execution)
    task_execution = identifier.TaskExecutionIdentifier(task_identifier, node_execution, 3)

    # Catalog metadata that ties a dataset and artifact tag to that execution.
    dataset_identifier = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "t1", "abcdef")
    artifact_tag = catalog.CatalogArtifactTag("my-artifact-id", "some name")
    catalog_metadata = catalog.CatalogMetadata(
        dataset_id=dataset_identifier,
        artifact_tag=artifact_tag,
        source_task_execution=task_execution,
    )

    metadata = node_execution_models.TaskNodeMetadata(cache_status=0, catalog_key=catalog_metadata)
    assert metadata.cache_status == 0
    assert metadata.catalog_key == catalog_metadata

    idl_message = metadata.to_flyte_idl()
    restored = node_execution_models.TaskNodeMetadata.from_flyte_idl(idl_message)
    assert restored == metadata
|
||
|
||
def test_dynamic_wf_node_metadata():
    """DynamicWorkflowNodeMetadata exposes its fields and round-trips through IDL."""
    workflow_identifier = identifier.Identifier(
        identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"
    )
    compiled_closure = get_compiled_workflow_closure()

    metadata = node_execution_models.DynamicWorkflowNodeMetadata(
        id=workflow_identifier,
        compiled_workflow=compiled_closure,
    )
    assert metadata.id == workflow_identifier
    assert metadata.compiled_workflow == compiled_closure

    idl_message = metadata.to_flyte_idl()
    restored = node_execution_models.DynamicWorkflowNodeMetadata.from_flyte_idl(idl_message)
    assert restored == metadata
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from flytekit.models.core import catalog, identifier | ||
|
||
|
||
def test_catalog_artifact_tag():
    """CatalogArtifactTag exposes id and name before and after an IDL round trip."""
    expected_id = "my-artifact-id"
    expected_name = "some name"

    tag = catalog.CatalogArtifactTag("my-artifact-id", "some name")
    assert tag.artifact_id == expected_id
    assert tag.name == expected_name

    idl_message = tag.to_flyte_idl()
    restored = catalog.CatalogArtifactTag.from_flyte_idl(idl_message)
    assert tag == restored
    assert restored.artifact_id == expected_id
    assert restored.name == expected_name
|
||
|
||
def test_catalog_metadata():
    """CatalogMetadata returns the objects it was built from and round-trips through IDL."""
    # Build the identifier chain needed for a task execution id.
    task_identifier = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    workflow_execution = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    node_execution = identifier.NodeExecutionIdentifier("node_id", workflow_execution)
    task_execution = identifier.TaskExecutionIdentifier(task_identifier, node_execution, 3)

    dataset_identifier = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "t1", "abcdef")
    artifact_tag = catalog.CatalogArtifactTag("my-artifact-id", "some name")

    metadata = catalog.CatalogMetadata(
        dataset_id=dataset_identifier,
        artifact_tag=artifact_tag,
        source_task_execution=task_execution,
    )
    # Identity checks: the accessors are expected to return the same objects passed in.
    assert metadata.dataset_id is dataset_identifier
    assert metadata.source_execution is task_execution
    assert metadata.source_task_execution is task_execution
    assert metadata.artifact_tag is artifact_tag

    idl_message = metadata.to_flyte_idl()
    restored = catalog.CatalogMetadata.from_flyte_idl(idl_message)
    assert metadata == restored
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters