From 990b450ea57539271a9fbc7aff7e49aea7f33bd0 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 15 Aug 2024 09:28:26 -0400 Subject: [PATCH 1/8] [WIP] - Read offloaded literals Signed-off-by: Eduardo Apolinario --- flytekit/core/type_engine.py | 13 +++++++++++- flytekit/models/literals.py | 22 ++++++++++++++++++++ pyproject.toml | 2 +- tests/flytekit/unit/core/test_type_engine.py | 21 +++++++++++++++++++ 4 files changed, 56 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d66bc8a956..7fb3929bbb 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -8,6 +8,7 @@ import inspect import json import mimetypes +import os import sys import textwrap import typing @@ -32,7 +33,7 @@ from flytekit.core.context_manager import FlyteContext from flytekit.core.hash import HashMethod from flytekit.core.type_helpers import load_type_from_tag -from flytekit.core.utils import timeit +from flytekit.core.utils import load_proto_from_file, timeit from flytekit.exceptions import user as user_exceptions from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.lazy_import.lazy_module import is_imported @@ -1097,6 +1098,16 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ Converts a Literal value with an expected python type into a python value. """ + # Initiate the process of loading the offloaded literal if uri is set + if lv.uri: + # TODO: fail fast if size is larger than X MB + literal_random_path = ctx.file_access.get_random_local_path() + # TODO: Loading a literal from bytes requires writing it to a file + local_literal_file = os.path.join(ctx.execution_state.working_dir, literal_random_path) + ctx.file_access.download(lv.uri, local_literal_file) + input_proto = load_proto_from_file(literals_pb2.Literal, local_literal_file) + lv = Literal.from_flyte_idl(input_proto) + transformer = cls.get_transformer(expected_python_type) return transformer.to_python_value(ctx, lv, expected_python_type) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index e08c495b67..cdc855e5a3 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -860,6 +860,8 @@ def __init__( map: Optional[LiteralMap] = None, hash: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, + uri: Optional[str] = None, + size_bytes: Optional[int] = None, ): """ This IDL message represents a literal value in the Flyte ecosystem. @@ -873,6 +875,8 @@ def __init__( self._map = map self._hash = hash self._metadata = metadata + self._uri = uri + self._size_bytes = size_bytes @property def scalar(self): @@ -925,6 +929,20 @@ def metadata(self) -> Optional[Dict[str, str]]: """ return self._metadata + @property + def uri(self) -> Optional[str]: + """ + If set, this value holds the URI of the offloaded literal. + """ + return self._uri + + @property + def size_bytes(self) -> Optional[int]: + """ + If set, this value holds the size in bytes of the offloaded literal proto. + """ + return self._size_bytes + def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.Literal @@ -935,6 +953,8 @@ def to_flyte_idl(self): map=self.map.to_flyte_idl() if self.map is not None else None, hash=self.hash, metadata=self.metadata, + uri=self.uri, + size_bytes=self.size_bytes, ) @classmethod @@ -953,6 +973,8 @@ def from_flyte_idl(cls, pb2_object): map=LiteralMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None, hash=pb2_object.hash if pb2_object.hash else None, metadata={k: v for k, v in pb2_object.metadata.items()} if pb2_object.metadata else None, + uri=pb2_object.uri if pb2_object.uri else None, + size_bytes=pb2_object.size_bytes if pb2_object.size_bytes is not None else None, ) def set_metadata(self, metadata: Dict[str, str]): diff --git a/pyproject.toml b/pyproject.toml index 8e8fcef90f..4a378e9278 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.13.1", + "flyteidl>=1.13.2", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 0cde27c619..034b0628a9 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3075,3 +3075,24 @@ def test_union_file_directory(): pv = union_trans.to_python_value(ctx, lv, typing.Union[FlyteFile, FlyteDirectory]) assert pv._remote_source == s3_dir + + +def test_offloaded_literal(tmp_path): + ctx = FlyteContext.current_context() + + pt = typing.List[int] + lt = TypeEngine.to_literal_type(pt) + pv = [1, 2, 3] + offloaded_lv = TypeEngine.to_literal(ctx, pv, pt, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + uri=f"{tmp_path}/offloaded_proto.pb", + size_bytes=100, + ) + + loaded_literal = TypeEngine.to_python_value(ctx, literal, pt) + assert loaded_literal == pv From 5f82ca2266269738d98a6dbe3f7846171aa3357b Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 17 Sep 2024 17:48:56 -0400 Subject: [PATCH 2/8] Use LiteralOffloadedMetadata field Signed-off-by: Eduardo Apolinario --- flytekit/core/type_engine.py | 7 +- flytekit/models/literals.py | 77 +++++++++++++++----- pyproject.toml | 2 +- tests/flytekit/unit/core/test_type_engine.py | 8 +- 4 files changed, 68 insertions(+), 26 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 7fb3929bbb..c7cae7fe0e 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1099,12 +1099,13 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T Converts a Literal value with an expected python type into a python value. """ # Initiate the process of loading the offloaded literal if uri is set - if lv.uri: - # TODO: fail fast if size is larger than X MB + if lv.offloaded_metadata: literal_random_path = ctx.file_access.get_random_local_path() # TODO: Loading a literal from bytes requires writing it to a file + assert ctx.execution_state local_literal_file = os.path.join(ctx.execution_state.working_dir, literal_random_path) - ctx.file_access.download(lv.uri, local_literal_file) + assert lv.offloaded_metadata.uri + ctx.file_access.download(lv.offloaded_metadata.uri, local_literal_file) input_proto = load_proto_from_file(literals_pb2.Literal, local_literal_file) lv = Literal.from_flyte_idl(input_proto) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index cdc855e5a3..4b96f50b82 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -8,7 +8,7 @@ from flytekit.exceptions import user as _user_exceptions from flytekit.models import common as _common from flytekit.models.core import types as _core_types -from flytekit.models.types import Error, StructuredDatasetType +from flytekit.models.types import Error, LiteralType, StructuredDatasetType from flytekit.models.types import LiteralType as _LiteralType from flytekit.models.types import OutputReference as _OutputReference from flytekit.models.types import SchemaType as _SchemaType @@ -852,6 +852,52 @@ def from_flyte_idl(cls, pb2_object): ) +class LiteralOffloadedMetadata(_common.FlyteIdlEntity): + def __init__( + self, + uri: Optional[str], + size_bytes: Optional[int], + inferred_type: Optional[LiteralType], + ): + """ + :param Text uri: URI of the offloaded literal + :param int size_bytes: Size in bytes of the offloaded literal proto + :param LiteralType inferred_type: Inferred type of the offloaded literal + """ + self._uri = uri + self._size_bytes = size_bytes + self._inferred_type = inferred_type + + @property + def uri(self): + return self._uri + + @property + def size_bytes(self): + return self._size_bytes + + @property + def inferred_type(self): + return self._inferred_type + + def to_flyte_idl(self): + return _literals_pb2.LiteralOffloadedMetadata( + uri=self.uri, + size_bytes=self.size_bytes, + inferred_type=self.inferred_type.to_flyte_idl() if self.inferred_type else None, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + return cls( + uri=pb2_object.uri, + size_bytes=pb2_object.size_bytes, + inferred_type=_LiteralType.from_flyte_idl(pb2_object.inferred_type) + if pb2_object.HasField("inferred_type") + else None, + ) + + class Literal(_common.FlyteIdlEntity): def __init__( self, @@ -860,8 +906,7 @@ def __init__( map: Optional[LiteralMap] = None, hash: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, - uri: Optional[str] = None, - size_bytes: Optional[int] = None, + offloaded_metadata: Optional[LiteralOffloadedMetadata] = None, ): """ This IDL message represents a literal value in the Flyte ecosystem. @@ -875,8 +920,7 @@ def __init__( self._map = map self._hash = hash self._metadata = metadata - self._uri = uri - self._size_bytes = size_bytes + self._offloaded_metadata = offloaded_metadata @property def scalar(self): @@ -930,18 +974,11 @@ def metadata(self) -> Optional[Dict[str, str]]: return self._metadata @property - def uri(self) -> Optional[str]: + def offloaded_metadata(self) -> Optional[LiteralOffloadedMetadata]: """ - If set, this value holds the URI of the offloaded literal. + This value holds metadata about the offloaded literal. """ - return self._uri - - @property - def size_bytes(self) -> Optional[int]: - """ - If set, this value holds the size in bytes of the offloaded literal proto. - """ - return self._size_bytes + return self._offloaded_metadata def to_flyte_idl(self): """ @@ -953,12 +990,11 @@ def to_flyte_idl(self): map=self.map.to_flyte_idl() if self.map is not None else None, hash=self.hash, metadata=self.metadata, - uri=self.uri, - size_bytes=self.size_bytes, + offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None, ) @classmethod - def from_flyte_idl(cls, pb2_object): + def from_flyte_idl(cls, pb2_object: _literals_pb2.Literal): """ :param flyteidl.core.literals_pb2.Literal pb2_object: :rtype: Literal @@ -973,8 +1009,9 @@ def from_flyte_idl(cls, pb2_object): map=LiteralMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None, hash=pb2_object.hash if pb2_object.hash else None, metadata={k: v for k, v in pb2_object.metadata.items()} if pb2_object.metadata else None, - uri=pb2_object.uri if pb2_object.uri else None, - size_bytes=pb2_object.size_bytes if pb2_object.size_bytes is not None else None, + offloaded_metadata=LiteralOffloadedMetadata.from_flyte_idl(pb2_object.offloaded_metadata) + if pb2_object.HasField("offloaded_metadata") + else None, ) def set_metadata(self, metadata: Dict[str, str]): diff --git a/pyproject.toml b/pyproject.toml index 4a378e9278..ba2cc46e83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.13.2", + "flyteidl>=1.13.4", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 034b0628a9..6391250b90 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -56,6 +56,7 @@ Literal, LiteralCollection, LiteralMap, + LiteralOffloadedMetadata, Primitive, Scalar, Void, @@ -3090,8 +3091,11 @@ def test_offloaded_literal(tmp_path): f.write(offloaded_lv.to_flyte_idl().SerializeToString()) literal = Literal( - uri=f"{tmp_path}/offloaded_proto.pb", - size_bytes=100, + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + size_bytes=100, + inferred_type=lt, + ), ) loaded_literal = TypeEngine.to_python_value(ctx, literal, pt) From 60a1abcf7ae4343c02927d748c960f23cf71a4d3 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 17 Sep 2024 20:07:27 -0400 Subject: [PATCH 3/8] Assert use of offloaded uri to get around typing constraint Signed-off-by: Eduardo Apolinario --- flytekit/core/type_engine.py | 2 +- flytekit/models/literals.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index c7cae7fe0e..2a23c242dd 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1104,7 +1104,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T # TODO: Loading a literal from bytes requires writing it to a file assert ctx.execution_state local_literal_file = os.path.join(ctx.execution_state.working_dir, literal_random_path) - assert lv.offloaded_metadata.uri + assert lv.offloaded_metadata.uri, "missing offloaded uri" ctx.file_access.download(lv.offloaded_metadata.uri, local_literal_file) input_proto = load_proto_from_file(literals_pb2.Literal, local_literal_file) lv = Literal.from_flyte_idl(input_proto) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 4b96f50b82..41c02a4a3c 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -855,9 +855,9 @@ def from_flyte_idl(cls, pb2_object): class LiteralOffloadedMetadata(_common.FlyteIdlEntity): def __init__( self, - uri: Optional[str], - size_bytes: Optional[int], - inferred_type: Optional[LiteralType], + uri: Optional[str] = None, + size_bytes: Optional[int] = None, + inferred_type: Optional[LiteralType] = None, ): """ :param Text uri: URI of the offloaded literal From 03873df27009874ca7592f5810c2c2c2d1448f87 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 17 Sep 2024 20:08:08 -0400 Subject: [PATCH 4/8] Add a bunch of unit tests Signed-off-by: Eduardo Apolinario --- .../unit/core/test_offloaded_literals.py | 179 ++++++++++++++++++ tests/flytekit/unit/core/test_type_engine.py | 131 ++++++++++++- 2 files changed, 302 insertions(+), 8 deletions(-) create mode 100644 tests/flytekit/unit/core/test_offloaded_literals.py diff --git a/tests/flytekit/unit/core/test_offloaded_literals.py b/tests/flytekit/unit/core/test_offloaded_literals.py new file mode 100644 index 0000000000..97fd6e97c1 --- /dev/null +++ b/tests/flytekit/unit/core/test_offloaded_literals.py @@ -0,0 +1,179 @@ +from dataclasses import dataclass +import typing + +from mashumaro.mixins.json import DataClassJSONMixin +import pytest +from flytekit import task +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.models import literals as literal_models +from flytekit.core import context_manager +from flytekit.models.types import SimpleType +from flytekit.core.type_engine import TypeEngine + +@pytest.fixture +def flyte_ctx(): + return context_manager.FlyteContext.current_context() + + +def test_task_offloaded_literal_single_input(tmp_path): + @task + def t1(a: int) -> str: + return str(a) + + original_input_literal = literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3)) + ) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(original_input_literal.to_flyte_idl().SerializeToString()) + + offloaded_input_literal = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER), + ) + ) + + ctx = context_manager.FlyteContextManager.current_context() + output_lm = t1.dispatch_execute( + ctx, + literal_models.LiteralMap( + literals={ + "a": offloaded_input_literal, + } + ), + ) + assert output_lm.literals["o0"].scalar.primitive.string_value == "3" + + +def test_task_offloaded_literal_multiple_input(tmp_path): + @task + def t1(a: int, b: int) -> int: + return a + b + + original_input_literal_a = literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3)) + ) + original_input_literal_b = literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=4)) + ) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto_a.pb", "wb") as f: + f.write(original_input_literal_a.to_flyte_idl().SerializeToString()) + with open(f"{tmp_path}/offloaded_proto_b.pb", "wb") as f: + f.write(original_input_literal_b.to_flyte_idl().SerializeToString()) + + offloaded_input_literal_a = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto_a.pb", + inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER), + ) + ) + offloaded_input_literal_b = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto_b.pb", + inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER), + ) + ) + + ctx = context_manager.FlyteContextManager.current_context() + output_lm = t1.dispatch_execute( + ctx, + literal_models.LiteralMap( + literals={ + "a": offloaded_input_literal_a, + "b": offloaded_input_literal_b, + } + ), + ) + assert output_lm.literals["o0"].scalar.primitive.integer == 7 + + +def test_task_offloaded_literal_single_dataclass(tmp_path, flyte_ctx): + @dataclass + class DC(DataClassJSONMixin): + x: int + y: str + z: typing.List[int] + + @task + def t1(dc: DC) -> DC: + return dc + + lt = TypeEngine.to_literal_type(DC) + original_input_literal = TypeEngine.to_literal(flyte_ctx, DC(x=3, y="hello", z=[1, 2, 3]), DC, lt) + + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(original_input_literal.to_flyte_idl().SerializeToString()) + + offloaded_input_literal = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ) + ) + + ctx = context_manager.FlyteContextManager.current_context() + output_lm = t1.dispatch_execute( + ctx, + literal_models.LiteralMap( + literals={ + "dc": offloaded_input_literal, + } + ), + ) + assert output_lm.literals["o0"] == original_input_literal + + +def test_task_offloaded_literal_list_int(tmp_path): + @task + def t1(xs: typing.List[int]) -> typing.List[str]: + return [str(a) for a in xs] + + original_input_literal = literal_models.Literal( + collection=literal_models.LiteralCollection( + literals=[ + literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3)) + ), + literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=4)) + ), + ] + ) + ) + expected_output_literal = literal_models.Literal( + collection=literal_models.LiteralCollection( + literals=[ + literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="3")) + ), + literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="4")) + ), + ] + ) + ) + + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(original_input_literal.to_flyte_idl().SerializeToString()) + + offloaded_input_literal = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=literal_models.LiteralType(collection_type=SimpleType.INTEGER), + ) + ) + + ctx = context_manager.FlyteContextManager.current_context() + output_lm = t1.dispatch_execute( + ctx, + literal_models.LiteralMap( + literals={ + "xs": offloaded_input_literal, + } + ), + ) + assert output_lm.literals["o0"] == expected_output_literal diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 6391250b90..d07bb1488a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3078,25 +3078,140 @@ def test_union_file_directory(): assert pv._remote_source == s3_dir -def test_offloaded_literal(tmp_path): +@pytest.mark.parametrize( + "pt,pv", + [ + (bool, True), + (bool, False), + (int, 42), + (str, "hello"), + (Annotated[int, "tag"], 42), + (typing.List[int], [1, 2, 3]), + (typing.List[str], ["a", "b", "c"]), + (typing.List[Color], [Color.RED, Color.GREEN, Color.BLUE]), + (typing.List[Annotated[int, "tag"]], [1, 2, 3]), + (typing.List[Annotated[str, "tag"]], ["a", "b", "c"]), + (typing.Dict[int, str], {"1": "a", "2": "b", "3": "c"}), + (typing.Dict[str, int], {"a": 1, "b": 2, "c": 3}), + (typing.Dict[str, typing.List[int]], {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), + (typing.Dict[str, typing.Dict[int, str]], {"a": {"1": "a", "2": "b", "3": "c"}, "b": {"4": "d", "5": "e", "6": "f"}}), + (typing.Union[int, str], 42), + (typing.Union[int, str], "hello"), + (typing.Union[typing.List[int], typing.List[str]], [1, 2, 3]), + (typing.Union[typing.List[int], typing.List[str]], ["a", "b", "c"]), + (typing.Union[typing.List[int], str], [1, 2, 3]), + (typing.Union[typing.List[int], str], "hello"), + ], +) +def test_offloaded_literal(tmp_path, pt, pv): ctx = FlyteContext.current_context() - pt = typing.List[int] lt = TypeEngine.to_literal_type(pt) - pv = [1, 2, 3] - offloaded_lv = TypeEngine.to_literal(ctx, pv, pt, lt) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, pv, pt, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_pv = TypeEngine.to_python_value(ctx, literal, pt) + assert loaded_pv == pv + + +def test_offloaded_literal_with_inferred_type(): + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(str) + offloaded_literal_missing_uri = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + inferred_type=lt, + ), + ) + with pytest.raises(AssertionError): + TypeEngine.to_python_value(ctx, offloaded_literal_missing_uri, str) + + +def test_offloaded_literal_dataclass(tmp_path): + @dataclass + class InnerDatum(DataClassJsonMixin): + x: int + y: str + + @dataclass + class Datum(DataClassJsonMixin): + inner: InnerDatum + x: int + y: str + z: typing.Dict[int, int] + w: List[int] + + datum = Datum( + inner=InnerDatum(x=1, y="1"), + x=1, + y="1", + z={1: 1}, + w=[1, 1, 1, 1], + ) + + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(Datum) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, datum, Datum, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_datum = TypeEngine.to_python_value(ctx, literal, Datum) + assert loaded_datum == datum + + +def test_offloaded_literal_flytefile(tmp_path): + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(FlyteFile) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, "s3://my-file", FlyteFile, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_pv = TypeEngine.to_python_value(ctx, literal, FlyteFile) + assert loaded_pv._remote_source == "s3://my-file" + + +def test_offloaded_literal_flytedirectory(tmp_path): + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(FlyteDirectory) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, "s3://my-dir", FlyteDirectory, lt) # Write offloaded_lv as bytes to a temp file with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: - f.write(offloaded_lv.to_flyte_idl().SerializeToString()) + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) literal = Literal( offloaded_metadata=LiteralOffloadedMetadata( uri=f"{tmp_path}/offloaded_proto.pb", - size_bytes=100, inferred_type=lt, ), ) - loaded_literal = TypeEngine.to_python_value(ctx, literal, pt) - assert loaded_literal == pv + loaded_pv: FlyteDirectory = TypeEngine.to_python_value(ctx, literal, FlyteDirectory) + assert loaded_pv._remote_source == "s3://my-dir" From 639277f553a9050207e7f57a9ec03551ea208e21 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 17 Sep 2024 20:12:41 -0400 Subject: [PATCH 5/8] Remove TODO and fix comment Signed-off-by: Eduardo Apolinario --- flytekit/core/type_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f528462fd4..f1fb51370d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1156,10 +1156,9 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ Converts a Literal value with an expected python type into a python value. """ - # Initiate the process of loading the offloaded literal if uri is set + # Initiate the process of loading the offloaded literal if offloaded_metadata is set if lv.offloaded_metadata: literal_random_path = ctx.file_access.get_random_local_path() - # TODO: Loading a literal from bytes requires writing it to a file assert ctx.execution_state local_literal_file = os.path.join(ctx.execution_state.working_dir, literal_random_path) assert lv.offloaded_metadata.uri, "missing offloaded uri" From cd3a129e721c74edbc6e827502dfeeb5c30dc61e Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 17 Sep 2024 20:29:33 -0400 Subject: [PATCH 6/8] Simplify generation of local file to store literal Signed-off-by: Eduardo Apolinario --- flytekit/core/type_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f1fb51370d..5a6a16592a 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1158,9 +1158,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ # Initiate the process of loading the offloaded literal if offloaded_metadata is set if lv.offloaded_metadata: - literal_random_path = ctx.file_access.get_random_local_path() - assert ctx.execution_state - local_literal_file = os.path.join(ctx.execution_state.working_dir, literal_random_path) + local_literal_file = ctx.file_access.get_random_local_path() assert lv.offloaded_metadata.uri, "missing offloaded uri" ctx.file_access.download(lv.offloaded_metadata.uri, local_literal_file) input_proto = load_proto_from_file(literals_pb2.Literal, local_literal_file) From d1435468ef37023dd0adaa1c93c0c3600ebdf74b Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 17 Sep 2024 20:30:28 -0400 Subject: [PATCH 7/8] Rename variable: `local_literal_file` to `literal_local_file` Signed-off-by: Eduardo Apolinario --- flytekit/core/type_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 5a6a16592a..20be34e4fe 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1158,10 +1158,10 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ # Initiate the process of loading the offloaded literal if offloaded_metadata is set if lv.offloaded_metadata: - local_literal_file = ctx.file_access.get_random_local_path() + literal_local_file = ctx.file_access.get_random_local_path() assert lv.offloaded_metadata.uri, "missing offloaded uri" - ctx.file_access.download(lv.offloaded_metadata.uri, local_literal_file) - input_proto = load_proto_from_file(literals_pb2.Literal, local_literal_file) + ctx.file_access.download(lv.offloaded_metadata.uri, literal_local_file) + input_proto = load_proto_from_file(literals_pb2.Literal, literal_local_file) lv = Literal.from_flyte_idl(input_proto) transformer = cls.get_transformer(expected_python_type) From c400f7293d8a6210cf2c6bdb0af3847df1a3d772 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Wed, 18 Sep 2024 10:32:33 -0400 Subject: [PATCH 8/8] Fix lint errors Signed-off-by: Eduardo Apolinario --- flytekit/core/type_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 20be34e4fe..861909eedd 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -8,7 +8,6 @@ import inspect import json import mimetypes -import os import sys import textwrap import threading