Skip to content

Commit

Permalink
Get raw input/output from remote execution (flyteorg#675)
Browse files Browse the repository at this point in the history
* [wip] for feast demo

Signed-off-by: Ketan Umare <[email protected]>

* clean up a bit

Signed-off-by: Yee Hing Tong <[email protected]>

* add a test and move where constructor is called

Signed-off-by: Yee Hing Tong <[email protected]>

* remove unneeded import

Signed-off-by: Yee Hing Tong <[email protected]>

* add a part of a test

Signed-off-by: Yee Hing Tong <[email protected]>

* Added tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

* typo

Signed-off-by: Kevin Su <[email protected]>

Co-authored-by: Yee Hing Tong <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
3 people authored Dec 2, 2021
1 parent 75c7b04 commit 742f40e
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 3 deletions.
23 changes: 23 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,4 +965,27 @@ def _register_default_type_transformers():
TypeEngine.register_restricted_type("named tuple", NamedTuple)


class LiteralsResolver(object):
"""
LiteralsResolver is a helper class meant primarily for use with the FlyteRemote experience or any other situation
where you might be working with LiteralMaps. This object allows the caller to specify the Python type that should
correspond to an element of the map.
TODO: Add an optional Flyte idl interface model object to the constructor
"""

def __init__(self, literals: typing.Dict[str, Literal]):
self._literals = literals

@property
def literals(self):
return self._literals

def get(self, attr: str, as_type: Optional[typing.Type] = None):
if attr not in self._literals:
raise AttributeError(f"Attribute {attr} not found")
if as_type is None:
raise ValueError("as_type argument can't be None yet.")
return TypeEngine.to_python_value(FlyteContext.current_context(), self._literals[attr], as_type)


_register_default_type_transformers()
15 changes: 15 additions & 0 deletions flytekit/remote/executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.exceptions import user as user_exceptions
from flytekit.core.type_engine import LiteralsResolver
from flytekit.models import execution as execution_models
from flytekit.models import node_execution as node_execution_models
from flytekit.models.admin import task_execution as admin_task_execution_models
Expand Down Expand Up @@ -83,6 +84,8 @@ def __init__(self, *args, **kwargs):
self._inputs = None
self._outputs = None
self._flyte_workflow: Optional[FlyteWorkflow] = None
self._raw_inputs: Optional[LiteralsResolver] = None
self._raw_outputs: Optional[LiteralsResolver] = None

@property
def node_executions(self) -> Dict[str, "FlyteNodeExecution"]:
Expand Down Expand Up @@ -111,6 +114,18 @@ def outputs(self) -> Dict[str, Any]:
raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
return self._outputs

@property
def raw_outputs(self) -> LiteralsResolver:
if self._raw_outputs is None:
raise ValueError(f"WF execution: {self} doesn't have raw outputs set")
return self._raw_outputs

@property
def raw_inputs(self) -> LiteralsResolver:
if self._raw_inputs is None:
raise ValueError(f"WF execution: {self} doesn't have raw inputs set")
return self._raw_inputs

@property
def error(self) -> core_execution_models.ExecutionError:
"""
Expand Down
10 changes: 7 additions & 3 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings, get_image_config
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.type_engine import TypeEngine
from flytekit.core.type_engine import LiteralsResolver, TypeEngine
from flytekit.core.workflow import WorkflowBase
from flytekit.models import common as common_models
from flytekit.models import launch_plan as launch_plan_models
Expand Down Expand Up @@ -1255,15 +1255,19 @@ def _assign_inputs_and_outputs(
):
"""Helper for assigning synced inputs and outputs to an execution object."""
with self.remote_context() as ctx:
input_literal_map = self._get_input_literal_map(execution_data)
execution._raw_inputs = LiteralsResolver(input_literal_map.literals)
execution._inputs = TypeEngine.literal_map_to_kwargs(
ctx=ctx,
lm=self._get_input_literal_map(execution_data),
lm=input_literal_map,
python_types=TypeEngine.guess_python_types(interface.inputs),
)
if execution.is_complete and not execution.error:
output_literal_map = self._get_output_literal_map(execution_data)
execution._raw_outputs = LiteralsResolver(output_literal_map.literals)
execution._outputs = TypeEngine.literal_map_to_kwargs(
ctx=ctx,
lm=self._get_output_literal_map(execution_data),
lm=output_literal_map,
python_types=TypeEngine.guess_python_types(interface.outputs),
)
return execution
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def test_fetch_execute_task(flyteclient, flyte_workflows_register):
execution = remote.execute(flyte_task, {"a": 10}, wait=True)
assert execution.outputs["t1_int_output"] == 12
assert execution.outputs["c"] == "world"
assert execution.raw_inputs.get("a", int) == 10
assert execution.raw_outputs.get("c", str) == "world"


def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote_env):
Expand Down
59 changes: 59 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DataclassTransformer,
DictTransformer,
ListTransformer,
LiteralsResolver,
SimpleTransformer,
TypeEngine,
convert_json_schema_to_python_class,
Expand Down Expand Up @@ -691,3 +692,61 @@ def test_schema_in_dataclass():
ot = tf.to_python_value(ctx, lv=lv, expected_python_type=Result)

assert o == ot


@pytest.mark.parametrize(
"literal_value,python_type,expected_python_value",
[
(
Literal(
collection=LiteralCollection(
literals=[
Literal(scalar=Scalar(primitive=Primitive(integer=1))),
Literal(scalar=Scalar(primitive=Primitive(integer=2))),
Literal(scalar=Scalar(primitive=Primitive(integer=3))),
]
)
),
typing.List[int],
[1, 2, 3],
),
(
Literal(
map=LiteralMap(
literals={
"k1": Literal(scalar=Scalar(primitive=Primitive(string_value="v1"))),
"k2": Literal(scalar=Scalar(primitive=Primitive(string_value="2"))),
},
)
),
typing.Dict[str, str],
{"k1": "v1", "k2": "2"},
),
],
)
def test_literals_resolver(literal_value, python_type, expected_python_value):
lit_dict = {"a": literal_value}

lr = LiteralsResolver(lit_dict)
out = lr.get("a", python_type)
assert out == expected_python_value


def test_guess_of_dataclass():
@dataclass_json
@dataclass()
class Foo(object):
x: int
y: str
z: typing.Dict[str, int]

def hello(self):
...

lt = TypeEngine.to_literal_type(Foo)
foo = Foo(1, "hello", {"world": 3})
lv = TypeEngine.to_literal(FlyteContext.current_context(), foo, Foo, lt)
lit_dict = {"a": lv}
lr = LiteralsResolver(lit_dict)
assert lr.get("a", Foo) == foo
assert hasattr(lr.get("a", Foo), "hello") is True

0 comments on commit 742f40e

Please sign in to comment.