diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 8561ec0157..15113f8f01 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -153,7 +153,7 @@ def new_remote_file(cls, name: typing.Optional[str] = None) -> FlyteFile: Create a new FlyteFile object with a remote path. """ ctx = FlyteContextManager.current_context() - r = ctx.file_access.get_random_string() + r = name or ctx.file_access.get_random_string() remote_path = ctx.file_access.join(ctx.file_access.raw_output_prefix, r) return cls(path=remote_path) diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 33a796b875..ff038c164b 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -678,3 +678,9 @@ def test_join(): def test_headers(): assert FlyteFilePathTransformer.get_additional_headers("xyz") == {} assert len(FlyteFilePathTransformer.get_additional_headers(".gz")) == 1 + + +def test_new_remote_file(): + nf = FlyteFile.new_remote_file(name="foo.txt") + assert isinstance(nf, FlyteFile) + assert nf.path.endswith('foo.txt')