-
Notifications
You must be signed in to change notification settings - Fork 300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Read offloaded literals #2685
Read offloaded literals #2685
Changes from all commits
990b450
5f82ca2
60a1abc
03873df
8f2299c
639277f
cd3a129
d143546
c400f72
cf51973
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,179 @@ | ||||||
from dataclasses import dataclass | ||||||
import typing | ||||||
|
||||||
from mashumaro.mixins.json import DataClassJSONMixin | ||||||
import pytest | ||||||
from flytekit import task | ||||||
from flytekit.configuration import Image, ImageConfig, SerializationSettings | ||||||
from flytekit.models import literals as literal_models | ||||||
from flytekit.core import context_manager | ||||||
from flytekit.models.types import SimpleType | ||||||
from flytekit.core.type_engine import TypeEngine | ||||||
|
||||||
@pytest.fixture | ||||||
def flyte_ctx(): | ||||||
return context_manager.FlyteContext.current_context() | ||||||
|
||||||
|
||||||
def test_task_offloaded_literal_single_input(tmp_path): | ||||||
@task | ||||||
def t1(a: int) -> str: | ||||||
return str(a) | ||||||
|
||||||
original_input_literal = literal_models.Literal( | ||||||
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3)) | ||||||
) | ||||||
|
||||||
# Write offloaded_lv as bytes to a temp file | ||||||
with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm surprised using
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that's a good call. I'm going to fix this in a separate PR. As for what happens on windows, my guess is that the |
||||||
f.write(original_input_literal.to_flyte_idl().SerializeToString()) | ||||||
|
||||||
offloaded_input_literal = literal_models.Literal( | ||||||
offloaded_metadata=literal_models.LiteralOffloadedMetadata( | ||||||
uri=f"{tmp_path}/offloaded_proto.pb", | ||||||
inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER), | ||||||
) | ||||||
) | ||||||
|
||||||
ctx = context_manager.FlyteContextManager.current_context() | ||||||
output_lm = t1.dispatch_execute( | ||||||
ctx, | ||||||
literal_models.LiteralMap( | ||||||
literals={ | ||||||
"a": offloaded_input_literal, | ||||||
} | ||||||
), | ||||||
) | ||||||
assert output_lm.literals["o0"].scalar.primitive.string_value == "3" | ||||||
|
||||||
|
||||||
def test_task_offloaded_literal_multiple_input(tmp_path): | ||||||
@task | ||||||
def t1(a: int, b: int) -> int: | ||||||
return a + b | ||||||
|
||||||
original_input_literal_a = literal_models.Literal( | ||||||
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3)) | ||||||
) | ||||||
original_input_literal_b = literal_models.Literal( | ||||||
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=4)) | ||||||
) | ||||||
|
||||||
# Write offloaded_lv as bytes to a temp file | ||||||
with open(f"{tmp_path}/offloaded_proto_a.pb", "wb") as f: | ||||||
f.write(original_input_literal_a.to_flyte_idl().SerializeToString()) | ||||||
with open(f"{tmp_path}/offloaded_proto_b.pb", "wb") as f: | ||||||
f.write(original_input_literal_b.to_flyte_idl().SerializeToString()) | ||||||
|
||||||
offloaded_input_literal_a = literal_models.Literal( | ||||||
offloaded_metadata=literal_models.LiteralOffloadedMetadata( | ||||||
uri=f"{tmp_path}/offloaded_proto_a.pb", | ||||||
inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER), | ||||||
) | ||||||
) | ||||||
offloaded_input_literal_b = literal_models.Literal( | ||||||
offloaded_metadata=literal_models.LiteralOffloadedMetadata( | ||||||
uri=f"{tmp_path}/offloaded_proto_b.pb", | ||||||
inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER), | ||||||
) | ||||||
) | ||||||
|
||||||
ctx = context_manager.FlyteContextManager.current_context() | ||||||
output_lm = t1.dispatch_execute( | ||||||
ctx, | ||||||
literal_models.LiteralMap( | ||||||
literals={ | ||||||
"a": offloaded_input_literal_a, | ||||||
"b": offloaded_input_literal_b, | ||||||
} | ||||||
), | ||||||
) | ||||||
assert output_lm.literals["o0"].scalar.primitive.integer == 7 | ||||||
|
||||||
|
||||||
def test_task_offloaded_literal_single_dataclass(tmp_path, flyte_ctx): | ||||||
@dataclass | ||||||
class DC(DataClassJSONMixin): | ||||||
x: int | ||||||
y: str | ||||||
z: typing.List[int] | ||||||
|
||||||
@task | ||||||
def t1(dc: DC) -> DC: | ||||||
return dc | ||||||
|
||||||
lt = TypeEngine.to_literal_type(DC) | ||||||
original_input_literal = TypeEngine.to_literal(flyte_ctx, DC(x=3, y="hello", z=[1, 2, 3]), DC, lt) | ||||||
|
||||||
with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: | ||||||
f.write(original_input_literal.to_flyte_idl().SerializeToString()) | ||||||
|
||||||
offloaded_input_literal = literal_models.Literal( | ||||||
offloaded_metadata=literal_models.LiteralOffloadedMetadata( | ||||||
uri=f"{tmp_path}/offloaded_proto.pb", | ||||||
inferred_type=lt, | ||||||
) | ||||||
) | ||||||
|
||||||
ctx = context_manager.FlyteContextManager.current_context() | ||||||
output_lm = t1.dispatch_execute( | ||||||
ctx, | ||||||
literal_models.LiteralMap( | ||||||
literals={ | ||||||
"dc": offloaded_input_literal, | ||||||
} | ||||||
), | ||||||
) | ||||||
assert output_lm.literals["o0"] == original_input_literal | ||||||
|
||||||
|
||||||
def test_task_offloaded_literal_list_int(tmp_path): | ||||||
@task | ||||||
def t1(xs: typing.List[int]) -> typing.List[str]: | ||||||
return [str(a) for a in xs] | ||||||
|
||||||
original_input_literal = literal_models.Literal( | ||||||
collection=literal_models.LiteralCollection( | ||||||
literals=[ | ||||||
literal_models.Literal( | ||||||
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3)) | ||||||
), | ||||||
literal_models.Literal( | ||||||
scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=4)) | ||||||
), | ||||||
] | ||||||
) | ||||||
) | ||||||
expected_output_literal = literal_models.Literal( | ||||||
collection=literal_models.LiteralCollection( | ||||||
literals=[ | ||||||
literal_models.Literal( | ||||||
scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="3")) | ||||||
), | ||||||
literal_models.Literal( | ||||||
scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="4")) | ||||||
), | ||||||
] | ||||||
) | ||||||
) | ||||||
|
||||||
with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: | ||||||
f.write(original_input_literal.to_flyte_idl().SerializeToString()) | ||||||
|
||||||
offloaded_input_literal = literal_models.Literal( | ||||||
offloaded_metadata=literal_models.LiteralOffloadedMetadata( | ||||||
uri=f"{tmp_path}/offloaded_proto.pb", | ||||||
inferred_type=literal_models.LiteralType(collection_type=SimpleType.INTEGER), | ||||||
) | ||||||
) | ||||||
|
||||||
ctx = context_manager.FlyteContextManager.current_context() | ||||||
output_lm = t1.dispatch_execute( | ||||||
ctx, | ||||||
literal_models.LiteralMap( | ||||||
literals={ | ||||||
"xs": offloaded_input_literal, | ||||||
} | ||||||
), | ||||||
) | ||||||
assert output_lm.literals["o0"] == expected_output_literal |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not your change, but a generator class for this could be useful
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you mean like a utility function to help in tests?