From 6aa6844d4a83641b932827299cb22d591137fba3 Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Thu, 2 Dec 2021 10:08:28 -0800 Subject: [PATCH] Get raw input/output from remote execution (#675) * [wip] for feast demo Signed-off-by: Ketan Umare * clean up a bit Signed-off-by: Yee Hing Tong * add a test and move where constructor is called Signed-off-by: Yee Hing Tong * remove unneeded import Signed-off-by: Yee Hing Tong * add a part of a test Signed-off-by: Yee Hing Tong * Added tests Signed-off-by: Kevin Su * Fixed lint Signed-off-by: Kevin Su * typo Signed-off-by: Kevin Su Co-authored-by: Yee Hing Tong Co-authored-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 23 ++++++++ flytekit/remote/executions.py | 15 +++++ flytekit/remote/remote.py | 10 +++- flytekit/types/file/file.py | 6 +- flytekit/types/pickle/pickle.py | 6 ++ .../integration/remote/test_remote.py | 2 + tests/flytekit/unit/core/test_type_engine.py | 59 +++++++++++++++++++ 7 files changed, 117 insertions(+), 4 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index ff9b257b568..268de66692c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -965,4 +965,27 @@ def _register_default_type_transformers(): TypeEngine.register_restricted_type("named tuple", NamedTuple) +class LiteralsResolver(object): + """ + LiteralsResolver is a helper class meant primarily for use with the FlyteRemote experience or any other situation + where you might be working with LiteralMaps. This object allows the caller to specify the Python type that should + correspond to an element of the map. + TODO: Add an optional Flyte idl interface model object to the constructor + """ + + def __init__(self, literals: typing.Dict[str, Literal]): + self._literals = literals + + @property + def literals(self): + return self._literals + + def get(self, attr: str, as_type: Optional[typing.Type] = None): + if attr not in self._literals: + raise AttributeError(f"Attribute {attr} not found") + if as_type is None: + raise ValueError("as_type argument can't be None yet.") + return TypeEngine.to_python_value(FlyteContext.current_context(), self._literals[attr], as_type) + + _register_default_type_transformers() diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py index 05b94b23021..b581769f52d 100644 --- a/flytekit/remote/executions.py +++ b/flytekit/remote/executions.py @@ -4,6 +4,7 @@ from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.exceptions import user as user_exceptions +from flytekit.core.type_engine import LiteralsResolver from flytekit.models import execution as execution_models from flytekit.models import node_execution as node_execution_models from flytekit.models.admin import task_execution as admin_task_execution_models @@ -83,6 +84,8 @@ def __init__(self, *args, **kwargs): self._inputs = None self._outputs = None self._flyte_workflow: Optional[FlyteWorkflow] = None + self._raw_inputs: Optional[LiteralsResolver] = None + self._raw_outputs: Optional[LiteralsResolver] = None @property def node_executions(self) -> Dict[str, "FlyteNodeExecution"]: @@ -111,6 +114,18 @@ def outputs(self) -> Dict[str, Any]: raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") return self._outputs + @property + def raw_outputs(self) -> LiteralsResolver: + if self._raw_outputs is None: + raise ValueError(f"WF execution: {self} doesn't have raw outputs set") + return self._raw_outputs + + @property + def raw_inputs(self) -> LiteralsResolver: + if self._raw_inputs is None: + raise ValueError(f"WF execution: {self} doesn't have raw inputs set") + return self._raw_inputs + @property def error(self) -> core_execution_models.ExecutionError: """ diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 04e18d67940..117abf27e3b 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -48,7 +48,7 @@ from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings, get_image_config from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.launch_plan import LaunchPlan -from flytekit.core.type_engine import TypeEngine +from flytekit.core.type_engine import LiteralsResolver, TypeEngine from flytekit.core.workflow import WorkflowBase from flytekit.models import common as common_models from flytekit.models import launch_plan as launch_plan_models @@ -1255,15 +1255,19 @@ def _assign_inputs_and_outputs( ): """Helper for assigning synced inputs and outputs to an execution object.""" with self.remote_context() as ctx: + input_literal_map = self._get_input_literal_map(execution_data) + execution._raw_inputs = LiteralsResolver(input_literal_map.literals) execution._inputs = TypeEngine.literal_map_to_kwargs( ctx=ctx, - lm=self._get_input_literal_map(execution_data), + lm=input_literal_map, python_types=TypeEngine.guess_python_types(interface.inputs), ) if execution.is_complete and not execution.error: + output_literal_map = self._get_output_literal_map(execution_data) + execution._raw_outputs = LiteralsResolver(output_literal_map.literals) execution._outputs = TypeEngine.literal_map_to_kwargs( ctx=ctx, - lm=self._get_output_literal_map(execution_data), + lm=output_literal_map, python_types=TypeEngine.guess_python_types(interface.outputs), ) return execution diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 1fbcee049a0..08296785568 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -348,7 +348,11 @@ def _downloader(): return ff def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]: - if literal_type.blob is not None and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format != FlytePickleTransformer.PYTHON_PICKLE_FORMAT + ): return FlyteFile[typing.TypeVar(literal_type.blob.format)] raise ValueError(f"Transformer {self} cannot reverse {literal_type}") diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 8251d111bb2..4a011d2762a 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -78,6 +78,12 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp ctx.file_access.put_data(uri, remote_path, is_multipart=False) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickle[typing.Any]]: + if literal_type.blob is not None and literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT: + return FlytePickle + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 9dc32a2e544..a45c037279e 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -156,6 +156,8 @@ def test_fetch_execute_task(flyteclient, flyte_workflows_register): execution = remote.execute(flyte_task, {"a": 10}, wait=True) assert execution.outputs["t1_int_output"] == 12 assert execution.outputs["c"] == "world" + assert execution.raw_inputs.get("a", int) == 10 + assert execution.raw_outputs.get("c", str) == "world" def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote_env): diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 826c8df1fb9..50a77c43d9d 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -21,6 +21,7 @@ DataclassTransformer, DictTransformer, ListTransformer, + LiteralsResolver, SimpleTransformer, TypeEngine, convert_json_schema_to_python_class, @@ -691,3 +692,61 @@ def test_schema_in_dataclass(): ot = tf.to_python_value(ctx, lv=lv, expected_python_type=Result) assert o == ot + + +@pytest.mark.parametrize( + "literal_value,python_type,expected_python_value", + [ + ( + Literal( + collection=LiteralCollection( + literals=[ + Literal(scalar=Scalar(primitive=Primitive(integer=1))), + Literal(scalar=Scalar(primitive=Primitive(integer=2))), + Literal(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + ), + typing.List[int], + [1, 2, 3], + ), + ( + Literal( + map=LiteralMap( + literals={ + "k1": Literal(scalar=Scalar(primitive=Primitive(string_value="v1"))), + "k2": Literal(scalar=Scalar(primitive=Primitive(string_value="2"))), + }, + ) + ), + typing.Dict[str, str], + {"k1": "v1", "k2": "2"}, + ), + ], +) +def test_literals_resolver(literal_value, python_type, expected_python_value): + lit_dict = {"a": literal_value} + + lr = LiteralsResolver(lit_dict) + out = lr.get("a", python_type) + assert out == expected_python_value + + +def test_guess_of_dataclass(): + @dataclass_json + @dataclass() + class Foo(object): + x: int + y: str + z: typing.Dict[str, int] + + def hello(self): + ... + + lt = TypeEngine.to_literal_type(Foo) + foo = Foo(1, "hello", {"world": 3}) + lv = TypeEngine.to_literal(FlyteContext.current_context(), foo, Foo, lt) + lit_dict = {"a": lv} + lr = LiteralsResolver(lit_dict) + assert lr.get("a", Foo) == foo + assert hasattr(lr.get("a", Foo), "hello") is True