Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remote_source lost on serialization of @dataclass_json with FlyteFile #774

Merged
merged 3 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,26 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):

for f in dataclasses.fields(python_type):
v = python_val.__getattribute__(f.name)
if inspect.isclass(f.type) and (
issubclass(f.type, FlyteSchema) or issubclass(f.type, FlyteFile) or issubclass(f.type, FlyteDirectory)
field_type = f.type
if inspect.isclass(field_type) and (
issubclass(field_type, FlyteSchema)
or issubclass(field_type, FlyteFile)
or issubclass(field_type, FlyteDirectory)
):
TypeEngine.to_literal(FlyteContext.current_context(), v, f.type, None)
elif dataclasses.is_dataclass(f.type):
self._serialize_flyte_type(v, f.type)
lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None)
# dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
# JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the
# path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here,
# so that dataclass_json can always get a remote path.
# In other words, the file transformer has special code that handles the fact that if remote_source is
# set, then the real uri in the literal should be the remote source, not the path (which may be an
# auto-generated random local path). To be sure we're writing the right path to the json, use the uri
# as determined by the transformer.
if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory):
python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri))

elif dataclasses.is_dataclass(field_type):
self._serialize_flyte_type(v, field_type)

def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T:
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
Expand Down
25 changes: 25 additions & 0 deletions tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,28 @@ def test_file_guess():
fft = transformer.guess_python_type(lt)
assert issubclass(fft, FlyteFile)
assert fft.extension() == ""


def test_flyte_file_in_dyn():
@task
def t1(path: str) -> FlyteFile:
return FlyteFile(path)

@dynamic
def dyn(fs: FlyteFile):
t2(ff=fs)

@task
def t2(ff: FlyteFile) -> os.PathLike:
assert ff.remote_source == "s3://somewhere"
assert "/tmp/flyte/" in ff.path

return ff.path

@workflow
def wf(path: str) -> os.PathLike:
n1 = t1(path=path)
dyn(fs=n1)
return t2(ff=n1)

assert "/tmp/flyte/" in wf(path="s3://somewhere").path
25 changes: 22 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,16 +369,35 @@ def t1(path: str) -> FileStruct:
fs = FileStruct(a=file, b=InnerFileStruct(a=file, b=PNGImageFile(path)))
return fs

@dynamic
def dyn(fs: FileStruct):
t2(fs=fs)
t3(fs=fs)

@task
def t2(fs: FileStruct) -> os.PathLike:
assert fs.a.remote_source == "s3://somewhere"
assert fs.b.a.remote_source == "s3://somewhere"
assert fs.b.b.remote_source == "s3://somewhere"
assert "/tmp/flyte/" in fs.a.path
assert "/tmp/flyte/" in fs.b.a.path
assert "/tmp/flyte/" in fs.b.b.path

return fs.a.path

@task
def t3(fs: FileStruct) -> FlyteFile:
return fs.a

@workflow
def wf(path: str) -> os.PathLike:
def wf(path: str) -> (os.PathLike, FlyteFile):
n1 = t1(path=path)
return t2(fs=n1)
dyn(fs=n1)
return t2(fs=n1), t3(fs=n1)

assert "/tmp/flyte/" in wf(path="s3://somewhere").path
assert "/tmp/flyte/" in wf(path="s3://somewhere")[0].path
assert "/tmp/flyte/" in wf(path="s3://somewhere")[1].path
assert "s3://somewhere" == wf(path="s3://somewhere")[1].remote_source


def test_flyte_directory_in_dataclass():
Expand Down