diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d11e795d20..939b0531df 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -278,12 +278,26 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]): for f in dataclasses.fields(python_type): v = python_val.__getattribute__(f.name) - if inspect.isclass(f.type) and ( - issubclass(f.type, FlyteSchema) or issubclass(f.type, FlyteFile) or issubclass(f.type, FlyteDirectory) + field_type = f.type + if inspect.isclass(field_type) and ( + issubclass(field_type, FlyteSchema) + or issubclass(field_type, FlyteFile) + or issubclass(field_type, FlyteDirectory) ): - TypeEngine.to_literal(FlyteContext.current_context(), v, f.type, None) - elif dataclasses.is_dataclass(f.type): - self._serialize_flyte_type(v, f.type) + lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None) + # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a + # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the + # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here, + # so that dataclass_json can always get a remote path. + # In other words, the file transformer has special code that handles the fact that if remote_source is + # set, then the real uri in the literal should be the remote source, not the path (which may be an + # auto-generated random local path). To be sure we're writing the right path to the json, use the uri + # as determined by the transformer. + if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory): + python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri)) + + elif dataclasses.is_dataclass(field_type): + self._serialize_flyte_type(v, field_type) def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T: from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 3a34445f08..05d8156ece 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -407,3 +407,28 @@ def test_file_guess(): fft = transformer.guess_python_type(lt) assert issubclass(fft, FlyteFile) assert fft.extension() == "" + + +def test_flyte_file_in_dyn(): + @task + def t1(path: str) -> FlyteFile: + return FlyteFile(path) + + @dynamic + def dyn(fs: FlyteFile): + t2(ff=fs) + + @task + def t2(ff: FlyteFile) -> os.PathLike: + assert ff.remote_source == "s3://somewhere" + assert "/tmp/flyte/" in ff.path + + return ff.path + + @workflow + def wf(path: str) -> os.PathLike: + n1 = t1(path=path) + dyn(fs=n1) + return t2(ff=n1) + + assert "/tmp/flyte/" in wf(path="s3://somewhere").path diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index fcebeb53d6..3f7b4c901c 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -369,16 +369,35 @@ def t1(path: str) -> FileStruct: fs = FileStruct(a=file, b=InnerFileStruct(a=file, b=PNGImageFile(path))) return fs + @dynamic + def dyn(fs: FileStruct): + t2(fs=fs) + t3(fs=fs) + @task def t2(fs: FileStruct) -> os.PathLike: + assert fs.a.remote_source == "s3://somewhere" + assert fs.b.a.remote_source == "s3://somewhere" + assert fs.b.b.remote_source == "s3://somewhere" + assert "/tmp/flyte/" in fs.a.path + assert "/tmp/flyte/" in fs.b.a.path + assert "/tmp/flyte/" in fs.b.b.path + return fs.a.path + @task + def t3(fs: FileStruct) -> FlyteFile: + return fs.a + @workflow - def wf(path: str) -> os.PathLike: + def wf(path: str) -> (os.PathLike, FlyteFile): n1 = t1(path=path) - return t2(fs=n1) + dyn(fs=n1) + return t2(fs=n1), t3(fs=n1) - assert "/tmp/flyte/" in wf(path="s3://somewhere").path + assert "/tmp/flyte/" in wf(path="s3://somewhere")[0].path + assert "/tmp/flyte/" in wf(path="s3://somewhere")[1].path + assert "s3://somewhere" == wf(path="s3://somewhere")[1].remote_source def test_flyte_directory_in_dataclass():