Skip to content

Commit

Permalink
add Flyte*Execution and control_plane.identifier classes
Browse files Browse the repository at this point in the history
Signed-off-by: cosmicBboy <[email protected]>
  • Loading branch information
cosmicBboy committed Apr 1, 2021
1 parent 5353d98 commit df001ed
Show file tree
Hide file tree
Showing 7 changed files with 398 additions and 7 deletions.
2 changes: 1 addition & 1 deletion flytekit/control_plane/component_nodes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging as _logging
from typing import Dict

from flytekit.common.core import identifier as _identifier
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.control_plane import identifier as _identifier
from flytekit.models import task as _task_model
from flytekit.models.core import workflow as _workflow_model

Expand Down
138 changes: 138 additions & 0 deletions flytekit/control_plane/identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import six as _six

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.models.core import identifier as _core_identifier


class Identifier(_core_identifier.Identifier):

_STRING_TO_TYPE_MAP = {
"lp": _core_identifier.ResourceType.LAUNCH_PLAN,
"wf": _core_identifier.ResourceType.WORKFLOW,
"tsk": _core_identifier.ResourceType.TASK,
}
_TYPE_TO_STRING_MAP = {v: k for k, v in _six.iteritems(_STRING_TO_TYPE_MAP)}

@classmethod
def promote_from_model(cls, base_model: _core_identifier.Identifier) -> "Identifier":
return cls(
base_model.response_type, base_model.project, base_model.domain, base_model.name, base_model.version
)

@classmethod
def from_python_std(cls, string: str) -> "Identifier":
"""
Parses a string in the correct format into an identifier
"""
segments = string.split(":")
if len(segments) != 5:
raise _user_exceptions.FlyteValueException(
"The provided string was not in a parseable format. The string for an identifier must be in the format"
" entity_type:project:domain:name:version. Received: {}".format(string)
)

resource_type, project, domain, name, version = segments

if resource_type not in cls._STRING_TO_TYPE_MAP:
raise _user_exceptions.FlyteValueException(
"The provided string could not be parsed. The first element of an identifier must be one of: {}. "
"Received: {}".format(list(cls._STRING_TO_TYPE_MAP.keys()), resource_type)
)

return cls(cls._STRING_TO_TYPE_MAP[resource_type], project, domain, name, version)

def __str__(self):
return "{}:{}:{}:{}:{}".format(
type(self)._TYPE_TO_STRING_MAP.get(self.resource_type, "<unknown>"),
self.project,
self.domain,
self.name,
self.version,
)


class WorkflowExecutionIdentifier(_core_identifier.WorkflowExecutionIdentifier):
@classmethod
def promote_from_model(
cls, base_model: _core_identifier.WorkflowExecutionIdentifier
) -> "WorkflowExecutionIdentifier":
return cls(base_model.project, base_model.domain, base_model.name)

@classmethod
def from_python_std(cls, string: str) -> "WorkflowExecutionIdentifier":
"""
Parses a string in the correct format into an identifier
"""
segments = string.split(":")
if len(segments) != 4:
raise _user_exceptions.FlyteValueException(
string,
"The provided string was not in a parseable format. The string for an identifier must be in the format"
" ex:project:domain:name.",
)

resource_type, project, domain, name = segments

if resource_type != "ex":
raise _user_exceptions.FlyteValueException(
resource_type,
"The provided string could not be parsed. The first element of an execution identifier must be 'ex'.",
)

return cls(project, domain, name)

def __str__(self):
return f"ex:{self.project}:{self.domain}:{self.name}"


class TaskExecutionIdentifier(_core_identifier.TaskExecutionIdentifier):
@classmethod
def promote_from_model(cls, base_model: _core_identifier.TaskExecutionIdentifier) -> "TaskExecutionIdentifier":
return cls(
task_id=base_model.task_id,
node_execution_id=base_model.node_execution_id,
retry_attempt=base_model.retry_attempt,
)

@classmethod
def from_python_std(cls, string: str) -> "TaskExecutionIdentifier":
"""
Parses a string in the correct format into an identifier
"""
segments = string.split(":")
if len(segments) != 10:
raise _user_exceptions.FlyteValueException(
string,
"The provided string was not in a parseable format. The string for an identifier must be in the format"
" te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry.",
)

resource_type, ep, ed, en, node_id, tp, td, tn, tv, retry = segments

if resource_type != "te":
raise _user_exceptions.FlyteValueException(
resource_type,
"The provided string could not be parsed. The first element of an execution identifier must be 'ex'.",
)

return cls(
task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv),
node_execution_id=_core_identifier.NodeExecutionIdentifier(
node_id=node_id, execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en),
),
retry_attempt=int(retry),
)

def __str__(self):
return (
"te:"
f"{self.node_execution_id.execution_id.project}:"
f"{self.node_execution_id.execution_id.domain}:"
f"{self.node_execution_id.execution_id.name}:"
f"{self.node_execution_id.node_id}:"
f"{self.task_id.project}:"
f"{self.task_id.domain}:"
f"{self.task_id.name}:"
f"{self.task_id.version}:"
f"{self.retry_attempt}"
)
87 changes: 84 additions & 3 deletions flytekit/control_plane/nodes.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import logging as _logging
from typing import Dict, List, Optional

from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions
from flytekit.common import constants as _constants
from flytekit.common import promise as _promise
from flytekit.common.core import identifier as _identifier
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import artifact as _artifact_mixin
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.common.nodes import OutputParameterMapper, ParameterMapper
from flytekit.common.nodes import OutputParameterMapper
from flytekit.common.utils import _dnsify
from flytekit.control_plane import component_nodes as _component_nodes
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane.tasks import executions as _task_executions
from flytekit.control_plane.tasks.task import FlyteTask
from flytekit.models import common as _common_models
from flytekit.engines.flyte import engine as _flyte_engine
from flytekit.models import node_execution as _node_execution_models
from flytekit.models import task as _task_model
from flytekit.models.core import execution as _execution_models
from flytekit.models.core import workflow as _workflow_model


Expand Down Expand Up @@ -175,3 +180,79 @@ def __rshift__(self, other: "FlyteNode") -> "FlyteNode":

def __repr__(self) -> str:
return f"Node(ID: {self.id} Executable: {self._executable_flyte_object})"


class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact):
def __init__(self, *args, **kwargs):
super(FlyteNodeExecution, self).__init__(*args, **kwargs)
self._task_executions = None
self._workflow_executions = None
self._inputs = None
self._outputs = None

@property
def task_executions(self) -> List["flytekit.control_plane.tasks.executions.FlyteTaskExecution"]:
return self._task_executions or []

@property
def workflow_executions(self) -> List["flytekit.control_plane.workflow_executions.FlyteWorkflowExecution"]:
return self._workflow_executions or []

@property
def executions(self) -> _artifact_mixin.ExecutionArtifact:
return self.task_executions or self.workflow_executions or []

@property
def inputs(self):
# TODO
pass

@property
def outputs(self):
# TODO
pass

@property
def error(self) -> _execution_models.ExecutionError:
"""
If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
reaching completion.
"""
if not self.is_complete:
raise _user_exceptions.FlyteAssertion(
"Please wait until the node execution has completed before requesting error information."
)
return self.closure.error

@property
def is_complete(self) -> bool:
"""Whether or not the execution is complete."""
return self.closure.phase in {
_execution_models.NodeExecutionPhase.ABORTED,
_execution_models.NodeExecutionPhase.FAILED,
_execution_models.NodeExecutionPhase.SKIPPED,
_execution_models.NodeExecutionPhase.SUCCEEDED,
_execution_models.NodeExecutionPhase.TIMED_OUT,
}

@classmethod
def promote_from_model(cls, base_model: _node_execution_models.NodeExecution) -> "FlyteNodeExecution":
return cls(closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri)

def sync(self):
"""
Syncs the state of the underlying execution artifact with the state observed by the platform.
"""
if not self.is_complete or self.task_executions is not None:
client = _flyte_engine.get_client()
self._closure = client.get_node_execution(self.id).closure
self._task_executions = [
_task_executions.FlyteTaskExecution.promote_from_model(t)
for t in _iterate_task_executions(client, self.id)
]

def _sync_closure(self):
"""
Syncs the closure of the underlying execution artifact with the state observed by the platform.
"""
self._closure = _flyte_engine.get_client().get_node_execution(self.id).closure
79 changes: 79 additions & 0 deletions flytekit/control_plane/tasks/executions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Optional

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import artifact as _artifact_mixin
from flytekit.engines.flyte import engine as _flyte_engine
from flytekit.models.admin import task_execution as _task_execution_model
from flytekit.models.core import execution as _execution_models


class FlyteTaskExecution(_task_execution_model.TaskExecution, _artifact_mixin.ExecutionArtifact):
def __init__(self, *args, **kwargs):
super(FlyteTaskExecution, self).__init__(*args, **kwargs)
self._inputs = None
self._outputs = None

@property
def is_complete(self) -> bool:
"""Whether or not the execution is complete."""
return self.closure.phase in {
_execution_models.TaskExecutionPhase.ABORTED,
_execution_models.TaskExecutionPhase.FAILED,
_execution_models.TaskExecutionPhase.SUCCEEDED,
}

@property
def inputs(self):
# TODO
pass

@property
def outputs(self):
# TODO
pass

@property
def error(self) -> Optional[_execution_models.ExecutionError]:
"""
If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
reaching completion.
"""
if not self.is_complete:
raise _user_exceptions.FlyteAssertion(
"Please what until the task execution has completed before requesting error information."
)
return self.closure.error

def get_child_executions(self, filters=None):
from flytekit.control_plane import nodes as _nodes

if not self.is_parent:
raise _user_exceptions.FlyteAssertion("Only task executions marked with 'is_parent' have child executions.")
client = _flyte_engine.get_client()
models = {
v.id.node_id: v
for v in _iterate_node_executions(client, task_execution_identifier=self.id, filters=filters)
}

return {k: _nodes.FlyteNodeExecution.promote_from_model(v) for k, v in _six.iteritems(models)}

@classmethod
def promote_from_model(cls, base_model: _task_execution_model.TaskExecution) -> "FlyteTaskExecution":
return cls(
closure=base_model.closure,
id=base_model.id,
input_uri=base_model.input_uri,
is_parent=base_model.is_parent,
)

def sync(self):
"""
Syncs the state of the underlying execution artifact with the state observed by the platform.
"""
self._sync_closure()

def _sync_closure(self):
"""
Syncs the closure of the underlying execution artifact with the state observed by the platform.
"""
self._closure = _flyte_engine.get_client().get_task_execution(self.id).closure
2 changes: 1 addition & 1 deletion flytekit/control_plane/tasks/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flytekit.common.core import identifier as _identifier
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane import interface as _interfaces
from flytekit.engines.flyte import engine as _flyte_engine
from flytekit.models import common as _common_model
Expand Down
4 changes: 2 additions & 2 deletions flytekit/control_plane/workflow.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict, List, Optional

from flytekit.common import constants as _constants
from flytekit.common.core import identifier as _identifier
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.mixins import hash as _hash_mixin
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane import interface as _interfaces
from flytekit.control_plane import nodes as _nodes
from flytekit.engines.flyte import engine as _flyte_engine
Expand Down Expand Up @@ -87,7 +87,7 @@ def get_sub_workflows(self) -> List["FlyteWorkflow"]:

@classmethod
@_exception_scopes.system_entry_point
def fetch(cls, project, domain, name, version):
def fetch(cls, project: str, domain: str, name: str, version: str):
workflow_id = _identifier.Identifier(_identifier_model.ResourceType.WORKFLOW, project, domain, name, version)
admin_workflow = _flyte_engine.get_client().get_workflow(workflow_id)
cwc = admin_workflow.closure.compiled_workflow
Expand Down
Loading

0 comments on commit df001ed

Please sign in to comment.