From 97ab89f94c4d7c2e392384ca65bdaf58edd4de0b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 13 Dec 2021 18:23:26 +0800 Subject: [PATCH 1/3] remote_source lost on serialization of @dataclass_json with FlyteFile Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 20 +++++++++++++++----- tests/flytekit/unit/core/test_type_hints.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d11e795d20..06ec8b1499 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -278,12 +278,22 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]): for f in dataclasses.fields(python_type): v = python_val.__getattribute__(f.name) - if inspect.isclass(f.type) and ( - issubclass(f.type, FlyteSchema) or issubclass(f.type, FlyteFile) or issubclass(f.type, FlyteDirectory) + field_type = f.type + if inspect.isclass(field_type) and ( + issubclass(field_type, FlyteSchema) + or issubclass(field_type, FlyteFile) + or issubclass(field_type, FlyteDirectory) ): - TypeEngine.to_literal(FlyteContext.current_context(), v, f.type, None) - elif dataclasses.is_dataclass(f.type): - self._serialize_flyte_type(v, f.type) + lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None) + # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a + # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the + # path in FlyteFile and FlyteDirectory could be a local path. Therefore, We reset the python value here, + # so that dataclass_json can always get a remote path. + if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory): + python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri)) + + elif dataclasses.is_dataclass(field_type): + self._serialize_flyte_type(v, field_type) def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T: from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index fcebeb53d6..81e2f4954a 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -369,13 +369,25 @@ def t1(path: str) -> FileStruct: fs = FileStruct(a=file, b=InnerFileStruct(a=file, b=PNGImageFile(path))) return fs + @dynamic + def dyn(fs: FileStruct): + t2(fs=fs) + @task def t2(fs: FileStruct) -> os.PathLike: + assert fs.a.remote_source == "s3://somewhere" + assert fs.b.a.remote_source == "s3://somewhere" + assert fs.b.b.remote_source == "s3://somewhere" + assert "/tmp/flyte/" in fs.a.path + assert "/tmp/flyte/" in fs.b.a.path + assert "/tmp/flyte/" in fs.b.b.path + return fs.a.path @workflow def wf(path: str) -> os.PathLike: n1 = t1(path=path) + dyn(fs=n1) return t2(fs=n1) assert "/tmp/flyte/" in wf(path="s3://somewhere").path From cd98150ddf7908dcc931561b7a276e68cada900b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 15 Dec 2021 00:43:42 +0800 Subject: [PATCH 2/3] updated tests Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 6 +++++- tests/flytekit/unit/core/test_type_hints.py | 13 ++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 06ec8b1499..939b0531df 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -287,8 +287,12 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]): lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None) # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the - # path in FlyteFile and FlyteDirectory could be a local path. Therefore, We reset the python value here, + # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here, # so that dataclass_json can always get a remote path. + # In other words, the file transformer has special code that handles the fact that if remote_source is + # set, then the real uri in the literal should be the remote source, not the path (which may be an + # auto-generated random local path). To be sure we're writing the right path to the json, use the uri + # as determined by the transformer. if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory): python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri)) diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 81e2f4954a..3f7b4c901c 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -372,6 +372,7 @@ def t1(path: str) -> FileStruct: @dynamic def dyn(fs: FileStruct): t2(fs=fs) + t3(fs=fs) @task def t2(fs: FileStruct) -> os.PathLike: @@ -384,13 +385,19 @@ def t2(fs: FileStruct) -> os.PathLike: return fs.a.path + @task + def t3(fs: FileStruct) -> FlyteFile: + return fs.a + @workflow - def wf(path: str) -> os.PathLike: + def wf(path: str) -> (os.PathLike, FlyteFile): n1 = t1(path=path) dyn(fs=n1) - return t2(fs=n1) + return t2(fs=n1), t3(fs=n1) - assert "/tmp/flyte/" in wf(path="s3://somewhere").path + assert "/tmp/flyte/" in wf(path="s3://somewhere")[0].path + assert "/tmp/flyte/" in wf(path="s3://somewhere")[1].path + assert "s3://somewhere" == wf(path="s3://somewhere")[1].remote_source def test_flyte_directory_in_dataclass(): From 0c0b6500da5f36c96f65f9b05b9b7a2a150f484c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 15 Dec 2021 00:51:01 +0800 Subject: [PATCH 3/3] updated tests Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_flyte_file.py | 25 +++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 3a34445f08..05d8156ece 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -407,3 +407,28 @@ def test_file_guess(): fft = transformer.guess_python_type(lt) assert issubclass(fft, FlyteFile) assert fft.extension() == "" + + +def test_flyte_file_in_dyn(): + @task + def t1(path: str) -> FlyteFile: + return FlyteFile(path) + + @dynamic + def dyn(fs: FlyteFile): + t2(ff=fs) + + @task + def t2(ff: FlyteFile) -> os.PathLike: + assert ff.remote_source == "s3://somewhere" + assert "/tmp/flyte/" in ff.path + + return ff.path + + @workflow + def wf(path: str) -> os.PathLike: + n1 = t1(path=path) + dyn(fs=n1) + return t2(ff=n1) + + assert "/tmp/flyte/" in wf(path="s3://somewhere").path