def multi_fs_load(load_func):
    """
    Enable loading a file or directory from multiple file systems.

    It supports local, hdfs, s3 file systems.
    Note: this decorator is different from dllib decorator @enable_multi_fs_load.
    This decorator can load on each worker while @enable_multi_fs_load can only
    load on driver.

    Remote paths are first staged to a unique temporary local path, handed to
    ``load_func``, and the staged copy is removed afterwards.

    :param load_func: load file or directory function taking a local path.
    :return: load file or directory function for the specific file system.
    """
    @functools.wraps(load_func)
    def fs_load(path, *args, **kwargs):
        from bigdl.dllib.utils.file_utils import is_local_path
        # Local paths need no staging; call through directly.
        if is_local_path(path):
            return load_func(path, *args, **kwargs)

        import uuid
        import tempfile
        from bigdl.dllib.utils.file_utils import append_suffix
        # Unique temp name that keeps the original basename as a suffix so
        # load_func can still rely on the file extension.
        file_name = append_suffix(str(uuid.uuid1()),
                                  path.strip("/").split("/")[-1])
        temp_path = os.path.join(tempfile.gettempdir(), file_name)
        if is_file(path):
            get_remote_file_to_local(path, temp_path)
        else:
            os.mkdir(temp_path)
            get_remote_dir_to_local(path, temp_path)
        try:
            return load_func(temp_path, *args, **kwargs)
        finally:
            # Always clean up the staged copy. shutil is already imported at
            # module level, so the previous local re-import was redundant.
            if os.path.isdir(temp_path):
                shutil.rmtree(temp_path)
            else:
                os.remove(temp_path)

    return fs_load
a/python/orca/test/bigdl/orca/data/test_file.py +++ b/python/orca/test/bigdl/orca/data/test_file.py @@ -18,7 +18,7 @@ import shutil import tempfile -from bigdl.orca.data.file import open_image, open_text, load_numpy, exists, makedirs, write_text +from bigdl.orca.data.file import open_image, open_text, load_numpy, exists, makedirs, write_text, multi_fs_load class TestFile: @@ -160,3 +160,24 @@ def test_write_text_s3(self): aws_access_key_id=access_key_id, aws_secret_access_key=secret_access_key).client('s3', verify=False) s3_client.delete_object(Bucket='analytics-zoo-data', Key='test.txt') + + def test_multi_fs_load_local(self): + + @multi_fs_load + def mock_func(path): + assert exists(path) + + file_path = os.path.join(self.resource_path, "orca/data/random.npy") + mock_func(file_path) + + def test_multi_fs_load_s3(self): + + @multi_fs_load + def mock_func(path): + assert exists(path) + + access_key_id = os.getenv("AWS_ACCESS_KEY_ID") + secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY") + if access_key_id and secret_access_key: + file_path = "s3://analytics-zoo-data/hyperseg/VGGcompression/core1.npy" + mock_func(file_path)