Skip to content

Commit

Permalink
Orca: add get remote file to local decorator in orca file utils. (#5508)
Browse files Browse the repository at this point in the history
* feat: add get remote file to local decorator in orca file utils.

* feat: add unit-test for local and s3 FS.

* fix: update func name and comments.

* fix: fix code style.

* fix: fix code style.

* fix: fix typo.

* fix: fix import.
  • Loading branch information
lalalapotter authored Aug 24, 2022
1 parent 2966d49 commit c0c69fe
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
41 changes: 41 additions & 0 deletions python/orca/src/bigdl/orca/data/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import subprocess
import logging
import shutil
import functools
import glob
from distutils.dir_util import copy_tree
from bigdl.dllib.utils.log4Error import *
Expand Down Expand Up @@ -537,3 +538,43 @@ def get_remote_files_with_prefix_to_local(remote_path_prefix, local_dir):
except Exception as e:
invalidOperationError(False, str(e), cause=e)
return os.path.join(local_dir, prefix)


def multi_fs_load(load_func):
"""
Enable loading file or directory in multiple file systems.
It supports local, hdfs, s3 file systems.
Note: this decorator is different from dllib decorator @enable_multi_fs_load.
This decorator can load on each worker while @enable_multi_fs_load can only load on driver.
:param load_func: load file or directory function
:return: load file or directory function for the specific file system
"""
@functools.wraps(load_func)
def fs_load(path, *args, **kwargs):
from bigdl.dllib.utils.file_utils import is_local_path
if is_local_path(path):
return load_func(path, *args, **kwargs)
else:
import uuid
import tempfile
from bigdl.dllib.utils.file_utils import append_suffix
file_name = str(uuid.uuid1())
file_name = append_suffix(file_name, path.strip("/").split("/")[-1])
temp_path = os.path.join(tempfile.gettempdir(), file_name)
if is_file(path):
get_remote_file_to_local(path, temp_path)
else:
os.mkdir(temp_path)
get_remote_dir_to_local(path, temp_path)
try:
return load_func(temp_path, *args, **kwargs)
finally:
if os.path.isdir(temp_path):
import shutil
shutil.rmtree(temp_path)
else:
os.remove(temp_path)

return fs_load
23 changes: 22 additions & 1 deletion python/orca/test/bigdl/orca/data/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import shutil
import tempfile

from bigdl.orca.data.file import open_image, open_text, load_numpy, exists, makedirs, write_text
from bigdl.orca.data.file import open_image, open_text, load_numpy, exists, makedirs, write_text, multi_fs_load


class TestFile:
Expand Down Expand Up @@ -160,3 +160,24 @@ def test_write_text_s3(self):
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key).client('s3', verify=False)
s3_client.delete_object(Bucket='analytics-zoo-data', Key='test.txt')

def test_multi_fs_load_local(self):

@multi_fs_load
def mock_func(path):
assert exists(path)

file_path = os.path.join(self.resource_path, "orca/data/random.npy")
mock_func(file_path)

def test_multi_fs_load_s3(self):

@multi_fs_load
def mock_func(path):
assert exists(path)

access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
if access_key_id and secret_access_key:
file_path = "s3://analytics-zoo-data/hyperseg/VGGcompression/core1.npy"
mock_func(file_path)

0 comments on commit c0c69fe

Please sign in to comment.