Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read offloaded literals #2685

Merged
merged 10 commits into from
Sep 18, 2024
10 changes: 9 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from flytekit.core.context_manager import FlyteContext
from flytekit.core.hash import HashMethod
from flytekit.core.type_helpers import load_type_from_tag
from flytekit.core.utils import timeit
from flytekit.core.utils import load_proto_from_file, timeit
from flytekit.exceptions import user as user_exceptions
from flytekit.interaction.string_literals import literal_map_string_repr
from flytekit.lazy_import.lazy_module import is_imported
Expand Down Expand Up @@ -1155,6 +1155,14 @@
"""
Converts a Literal value with an expected python type into a python value.
"""
# Initiate the process of loading the offloaded literal if offloaded_metadata is set
if lv.offloaded_metadata:
literal_local_file = ctx.file_access.get_random_local_path()
assert lv.offloaded_metadata.uri, "missing offloaded uri"
ctx.file_access.download(lv.offloaded_metadata.uri, literal_local_file)
input_proto = load_proto_from_file(literals_pb2.Literal, literal_local_file)
lv = Literal.from_flyte_idl(input_proto)

Check warning on line 1164 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1160-L1164

Added lines #L1160 - L1164 were not covered by tests

transformer = cls.get_transformer(expected_python_type)
return transformer.to_python_value(ctx, lv, expected_python_type)

Expand Down
63 changes: 61 additions & 2 deletions flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
from flytekit.models.core import types as _core_types
from flytekit.models.types import Error, StructuredDatasetType
from flytekit.models.types import Error, LiteralType, StructuredDatasetType
from flytekit.models.types import LiteralType as _LiteralType
from flytekit.models.types import OutputReference as _OutputReference
from flytekit.models.types import SchemaType as _SchemaType
Expand Down Expand Up @@ -852,6 +852,52 @@
)


class LiteralOffloadedMetadata(_common.FlyteIdlEntity):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not your change, but a generator class for this could be useful

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean like a utility function to help in tests?

def __init__(
self,
uri: Optional[str] = None,
size_bytes: Optional[int] = None,
inferred_type: Optional[LiteralType] = None,
):
"""
:param Text uri: URI of the offloaded literal
:param int size_bytes: Size in bytes of the offloaded literal proto
:param LiteralType inferred_type: Inferred type of the offloaded literal
"""
self._uri = uri
self._size_bytes = size_bytes
self._inferred_type = inferred_type

Check warning on line 869 in flytekit/models/literals.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/literals.py#L867-L869

Added lines #L867 - L869 were not covered by tests

@property
def uri(self):
return self._uri

Check warning on line 873 in flytekit/models/literals.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/literals.py#L873

Added line #L873 was not covered by tests

@property
def size_bytes(self):
return self._size_bytes

Check warning on line 877 in flytekit/models/literals.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/literals.py#L877

Added line #L877 was not covered by tests

@property
def inferred_type(self):
return self._inferred_type

Check warning on line 881 in flytekit/models/literals.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/literals.py#L881

Added line #L881 was not covered by tests

def to_flyte_idl(self):
return _literals_pb2.LiteralOffloadedMetadata(

Check warning on line 884 in flytekit/models/literals.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/literals.py#L884

Added line #L884 was not covered by tests
uri=self.uri,
size_bytes=self.size_bytes,
inferred_type=self.inferred_type.to_flyte_idl() if self.inferred_type else None,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
return cls(

Check warning on line 892 in flytekit/models/literals.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/literals.py#L892

Added line #L892 was not covered by tests
uri=pb2_object.uri,
size_bytes=pb2_object.size_bytes,
inferred_type=_LiteralType.from_flyte_idl(pb2_object.inferred_type)
if pb2_object.HasField("inferred_type")
else None,
)


class Literal(_common.FlyteIdlEntity):
def __init__(
self,
Expand All @@ -860,6 +906,7 @@
map: Optional[LiteralMap] = None,
hash: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
offloaded_metadata: Optional[LiteralOffloadedMetadata] = None,
):
"""
This IDL message represents a literal value in the Flyte ecosystem.
Expand All @@ -873,6 +920,7 @@
self._map = map
self._hash = hash
self._metadata = metadata
self._offloaded_metadata = offloaded_metadata

@property
def scalar(self):
Expand Down Expand Up @@ -925,6 +973,13 @@
"""
return self._metadata

@property
def offloaded_metadata(self) -> Optional[LiteralOffloadedMetadata]:
"""
This value holds metadata about the offloaded literal.
"""
return self._offloaded_metadata

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.literals_pb2.Literal
Expand All @@ -935,10 +990,11 @@
map=self.map.to_flyte_idl() if self.map is not None else None,
hash=self.hash,
metadata=self.metadata,
offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
def from_flyte_idl(cls, pb2_object: _literals_pb2.Literal):
"""
:param flyteidl.core.literals_pb2.Literal pb2_object:
:rtype: Literal
Expand All @@ -953,6 +1009,9 @@
map=LiteralMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None,
hash=pb2_object.hash if pb2_object.hash else None,
metadata={k: v for k, v in pb2_object.metadata.items()} if pb2_object.metadata else None,
offloaded_metadata=LiteralOffloadedMetadata.from_flyte_idl(pb2_object.offloaded_metadata)
if pb2_object.HasField("offloaded_metadata")
else None,
)

def set_metadata(self, metadata: Dict[str, str]):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.13.1",
"flyteidl>=1.13.4",
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down
179 changes: 179 additions & 0 deletions tests/flytekit/unit/core/test_offloaded_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from dataclasses import dataclass
import typing

from mashumaro.mixins.json import DataClassJSONMixin
import pytest
from flytekit import task
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.models import literals as literal_models
from flytekit.core import context_manager
from flytekit.models.types import SimpleType
from flytekit.core.type_engine import TypeEngine

@pytest.fixture
def flyte_ctx():
return context_manager.FlyteContext.current_context()


def test_task_offloaded_literal_single_input(tmp_path):
@task
def t1(a: int) -> str:
return str(a)

original_input_literal = literal_models.Literal(
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3))
)

# Write offloaded_lv as bytes to a temp file
with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised using / works with windows. I usually go with:

Suggested change
with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f:
with (tmp_path / "offloaded_proto.pb").open("wb") as f:

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a good call. I'm going to fix this in a separate PR.

As for what happens on windows, my guess is that the / becomes part of the file name.

f.write(original_input_literal.to_flyte_idl().SerializeToString())

offloaded_input_literal = literal_models.Literal(
offloaded_metadata=literal_models.LiteralOffloadedMetadata(
uri=f"{tmp_path}/offloaded_proto.pb",
inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER),
)
)

ctx = context_manager.FlyteContextManager.current_context()
output_lm = t1.dispatch_execute(
ctx,
literal_models.LiteralMap(
literals={
"a": offloaded_input_literal,
}
),
)
assert output_lm.literals["o0"].scalar.primitive.string_value == "3"


def test_task_offloaded_literal_multiple_input(tmp_path):
@task
def t1(a: int, b: int) -> int:
return a + b

original_input_literal_a = literal_models.Literal(
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3))
)
original_input_literal_b = literal_models.Literal(
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=4))
)

# Write offloaded_lv as bytes to a temp file
with open(f"{tmp_path}/offloaded_proto_a.pb", "wb") as f:
f.write(original_input_literal_a.to_flyte_idl().SerializeToString())
with open(f"{tmp_path}/offloaded_proto_b.pb", "wb") as f:
f.write(original_input_literal_b.to_flyte_idl().SerializeToString())

offloaded_input_literal_a = literal_models.Literal(
offloaded_metadata=literal_models.LiteralOffloadedMetadata(
uri=f"{tmp_path}/offloaded_proto_a.pb",
inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER),
)
)
offloaded_input_literal_b = literal_models.Literal(
offloaded_metadata=literal_models.LiteralOffloadedMetadata(
uri=f"{tmp_path}/offloaded_proto_b.pb",
inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER),
)
)

ctx = context_manager.FlyteContextManager.current_context()
output_lm = t1.dispatch_execute(
ctx,
literal_models.LiteralMap(
literals={
"a": offloaded_input_literal_a,
"b": offloaded_input_literal_b,
}
),
)
assert output_lm.literals["o0"].scalar.primitive.integer == 7


def test_task_offloaded_literal_single_dataclass(tmp_path, flyte_ctx):
@dataclass
class DC(DataClassJSONMixin):
x: int
y: str
z: typing.List[int]

@task
def t1(dc: DC) -> DC:
return dc

lt = TypeEngine.to_literal_type(DC)
original_input_literal = TypeEngine.to_literal(flyte_ctx, DC(x=3, y="hello", z=[1, 2, 3]), DC, lt)

with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f:
f.write(original_input_literal.to_flyte_idl().SerializeToString())

offloaded_input_literal = literal_models.Literal(
offloaded_metadata=literal_models.LiteralOffloadedMetadata(
uri=f"{tmp_path}/offloaded_proto.pb",
inferred_type=lt,
)
)

ctx = context_manager.FlyteContextManager.current_context()
output_lm = t1.dispatch_execute(
ctx,
literal_models.LiteralMap(
literals={
"dc": offloaded_input_literal,
}
),
)
assert output_lm.literals["o0"] == original_input_literal


def test_task_offloaded_literal_list_int(tmp_path):
@task
def t1(xs: typing.List[int]) -> typing.List[str]:
return [str(a) for a in xs]

original_input_literal = literal_models.Literal(
collection=literal_models.LiteralCollection(
literals=[
literal_models.Literal(
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3))
),
literal_models.Literal(
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=4))
),
]
)
)
expected_output_literal = literal_models.Literal(
collection=literal_models.LiteralCollection(
literals=[
literal_models.Literal(
scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="3"))
),
literal_models.Literal(
scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="4"))
),
]
)
)

with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f:
f.write(original_input_literal.to_flyte_idl().SerializeToString())

offloaded_input_literal = literal_models.Literal(
offloaded_metadata=literal_models.LiteralOffloadedMetadata(
uri=f"{tmp_path}/offloaded_proto.pb",
inferred_type=literal_models.LiteralType(collection_type=SimpleType.INTEGER),
)
)

ctx = context_manager.FlyteContextManager.current_context()
output_lm = t1.dispatch_execute(
ctx,
literal_models.LiteralMap(
literals={
"xs": offloaded_input_literal,
}
),
)
assert output_lm.literals["o0"] == expected_output_literal
Loading
Loading