Skip to content

Commit

Permalink
remote_source lost on serialization of @dataclass_json with FlyteFile
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Dec 13, 2021
1 parent fed03d3 commit b86ff89
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
20 changes: 15 additions & 5 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,22 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):

for f in dataclasses.fields(python_type):
v = python_val.__getattribute__(f.name)
if inspect.isclass(f.type) and (
issubclass(f.type, FlyteSchema) or issubclass(f.type, FlyteFile) or issubclass(f.type, FlyteDirectory)
field_type = f.type
if inspect.isclass(field_type) and (
issubclass(field_type, FlyteSchema)
or issubclass(field_type, FlyteFile)
or issubclass(field_type, FlyteDirectory)
):
TypeEngine.to_literal(FlyteContext.current_context(), v, f.type, None)
elif dataclasses.is_dataclass(f.type):
self._serialize_flyte_type(v, f.type)
lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None)
# dataclass_json package will extract the "path" in FlyteFile, FlyteDirectory, and write it to a
# json file which will be stored in IDL. The path here should always be a remote path, but sometimes the
# path in FlyteFile and FlyteDirectory could be a local path. Therefore, We reset the python value here,
# so that dataclass_json can always get a remote path.
if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory):
python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri))

elif dataclasses.is_dataclass(field_type):
self._serialize_flyte_type(v, field_type)

def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T:
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
Expand Down
12 changes: 12 additions & 0 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,25 @@ def t1(path: str) -> FileStruct:
fs = FileStruct(a=file, b=InnerFileStruct(a=file, b=PNGImageFile(path)))
return fs

@dynamic
def dyn(fs: FileStruct):
t2(fs=fs)

@task
def t2(fs: FileStruct) -> os.PathLike:
assert fs.a.remote_source == "s3://somewhere"
assert fs.b.a.remote_source == "s3://somewhere"
assert fs.b.b.remote_source == "s3://somewhere"
assert "/tmp/flyte/" in fs.a.path
assert "/tmp/flyte/" in fs.b.a.path
assert "/tmp/flyte/" in fs.b.b.path

return fs.a.path

@workflow
def wf(path: str) -> os.PathLike:
n1 = t1(path=path)
dyn(fs=n1)
return t2(fs=n1)

assert "/tmp/flyte/" in wf(path="s3://somewhere").path
Expand Down

0 comments on commit b86ff89

Please sign in to comment.