diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 95506ac565..ab93fc3c26 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -100,6 +100,9 @@ jobs: GDRIVE_CREDENTIALS_DATA: ${{ secrets.GDRIVE_CREDENTIALS_DATA }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID}} + OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET}} + OSS_ENDPOINT: ${{ secrets.OSS_ENDPOINT}} run: >- python -m tests -n=4 --cov-report=xml --cov-report=term diff --git a/dvc/fs/oss.py b/dvc/fs/oss.py index bf97b9b955..d0f52454c5 100644 --- a/dvc/fs/oss.py +++ b/dvc/fs/oss.py @@ -8,124 +8,73 @@ from dvc.progress import Tqdm from dvc.scheme import Schemes -from .base import BaseFileSystem +from .fsspec_wrapper import ObjectFSWrapper logger = logging.getLogger(__name__) -class OSSFileSystem(BaseFileSystem): # pylint:disable=abstract-method - """ - oss2 document: - https://www.alibabacloud.com/help/doc-detail/32026.htm - - - Examples - ---------- - $ dvc remote add myremote oss://my-bucket/path - Set key id, key secret and endpoint using modify command - $ dvc remote modify myremote oss_key_id my-key-id - $ dvc remote modify myremote oss_key_secret my-key-secret - $ dvc remote modify myremote oss_endpoint endpoint - or environment variables - $ export OSS_ACCESS_KEY_ID="my-key-id" - $ export OSS_ACCESS_KEY_SECRET="my-key-secret" - $ export OSS_ENDPOINT="endpoint" - """ - +# pylint:disable=abstract-method +class OSSFileSystem(ObjectFSWrapper): scheme = Schemes.OSS PATH_CLS = CloudURLInfo - REQUIRES = {"oss2": "oss2"} + REQUIRES = {"ossfs": "ossfs"} PARAM_CHECKSUM = "etag" COPY_POLL_SECONDS = 5 LIST_OBJECT_PAGE_SIZE = 100 + DETAIL_FIELDS = frozenset(("etag", "size")) - def __init__(self, **config): - super().__init__(**config) - - self.endpoint = config.get("oss_endpoint") or os.getenv("OSS_ENDPOINT") - - self.key_id = ( - config.get("oss_key_id") - or os.getenv("OSS_ACCESS_KEY_ID") - or "defaultId" + def _prepare_credentials(self, **config): + login_info = {} + login_info["key"] = config.get("oss_key_id") or os.getenv( + "OSS_ACCESS_KEY_ID" ) - - self.key_secret = ( - config.get("oss_key_secret") - or os.getenv("OSS_ACCESS_KEY_SECRET") - or "defaultSecret" + login_info["secret"] = config.get("oss_key_secret") or os.getenv( + "OSS_ACCESS_KEY_SECRET" ) + login_info["endpoint"] = config.get("oss_endpoint") + return login_info @wrap_prop(threading.Lock()) @cached_property - def oss_service(self): - import oss2 - - logger.debug(f"key id: {self.key_id}") - logger.debug(f"key secret: {self.key_secret}") - - return oss2.Auth(self.key_id, self.key_secret) - - def _get_bucket(self, bucket): - import oss2 - - return oss2.Bucket(self.oss_service, self.endpoint, bucket) + def fs(self): + from ossfs import OSSFileSystem as _OSSFileSystem - def _generate_download_url(self, path_info, expires=3600): - return self._get_bucket(path_info.bucket).sign_url( - "GET", path_info.path, expires - ) - - def exists(self, path_info) -> bool: - paths = self._list_paths(path_info) - return any(path_info.path == path for path in paths) - - def _list_paths(self, path_info): - import oss2 - - for blob in oss2.ObjectIterator( - self._get_bucket(path_info.bucket), prefix=path_info.path - ): - yield blob.key - - def walk_files(self, path_info, **kwargs): - if not kwargs.pop("prefix", False): - path_info = path_info / "" - for fname in self._list_paths(path_info): - if fname.endswith("/"): - continue - - yield path_info.replace(path=fname) + return _OSSFileSystem(**self.fs_args) def remove(self, path_info): - if path_info.scheme != self.scheme: - raise NotImplementedError - - logger.debug(f"Removing oss://{path_info}") - self._get_bucket(path_info.bucket).delete_object(path_info.path) - - def _upload_fobj(self, fobj, to_info, **kwargs): - self._get_bucket(to_info.bucket).put_object(to_info.path, fobj) + self.fs.rm_file(self._with_bucket(path_info)) def _upload( - self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs + self, from_file, to_info, name=None, no_progress_bar=False, **kwargs ): - with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: - bucket = self._get_bucket(to_info.bucket) - bucket.put_object_from_file( - to_info.path, from_file, progress_callback=pbar.update_to + total = os.path.getsize(from_file) + with Tqdm( + disable=no_progress_bar, + total=total, + bytes=True, + desc=name, + **kwargs, + ) as pbar: + self.fs.put_file( + from_file, + self._with_bucket(to_info), + progress_callback=pbar.update_to, ) + self.fs.invalidate_cache(self._with_bucket(to_info.parent)) def _download( - self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs + self, from_info, to_file, name=None, no_progress_bar=False, **pbar_args ): - with Tqdm(desc=name, disable=no_progress_bar, bytes=True) as pbar: - import oss2 - - bucket = self._get_bucket(from_info.bucket) - oss2.resumable_download( - bucket, - from_info.path, + total = self.fs.size(self._with_bucket(from_info)) + with Tqdm( + disable=no_progress_bar, + total=total, + bytes=True, + desc=name, + **pbar_args, + ) as pbar: + self.fs.get_file( + self._with_bucket(from_info), to_file, progress_callback=pbar.update_to, ) diff --git a/dvc/objects/db/__init__.py b/dvc/objects/db/__init__.py index 38fb441bea..1534956143 100644 --- a/dvc/objects/db/__init__.py +++ b/dvc/objects/db/__init__.py @@ -5,6 +5,7 @@ def get_odb(fs, path_info, **config): from .base import ObjectDB from .gdrive import GDriveObjectDB from .local import LocalObjectDB + from .oss import OSSObjectDB from .ssh import SSHObjectDB if fs.scheme == Schemes.LOCAL: @@ -16,6 +17,9 @@ def get_odb(fs, path_info, **config): if fs.scheme == Schemes.GDRIVE: return GDriveObjectDB(fs, path_info, **config) + if fs.scheme == Schemes.OSS: + return OSSObjectDB(fs, path_info, **config) + return ObjectDB(fs, path_info, **config) diff --git a/dvc/objects/db/oss.py b/dvc/objects/db/oss.py new file mode 100644 index 0000000000..226fcc1b4b --- /dev/null +++ b/dvc/objects/db/oss.py @@ -0,0 +1,9 @@ +from .base import ObjectDB + + +class OSSObjectDB(ObjectDB): + """ + Temporary extra verification + """ + + DEFAULT_VERIFY = True diff --git a/setup.py b/setup.py index 60aef8d5a7..c0b1dd05d3 100644 --- a/setup.py +++ b/setup.py @@ -101,7 +101,7 @@ def run(self): s3 = ["s3fs==2021.6.1", "aiobotocore[boto3]==1.3.0"] azure = ["adlfs==2021.7.0", "azure-identity>=1.4.0", "knack"] # https://github.com/Legrandin/pycryptodome/issues/465 -oss = ["oss2==2.6.1", "pycryptodome>=3.10"] +oss = ["ossfs==2021.7.3"] ssh = ["paramiko[invoke]>=2.7.0"] # Remove the env marker if/when pyarrow is available for Python3.9 diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index c1c97632f6..85cb8617eb 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -30,14 +30,8 @@ "hdfs", "webdav", "webhdfs", + "oss", ] -] + [ - pytest.param( - pytest.lazy_fixture("oss"), - marks=pytest.mark.xfail( - reason="https://github.com/iterative/dvc/issues/4633" - ), - ) ] # Clouds that implement the general methods that can be tested diff --git a/tests/remotes/oss.py b/tests/remotes/oss.py index 42ee82f855..16d929deb9 100644 --- a/tests/remotes/oss.py +++ b/tests/remotes/oss.py @@ -9,7 +9,7 @@ from .base import Base -TEST_OSS_REPO_BUCKET = "dvc-test" +TEST_OSS_REPO_BUCKET = "dvc-test-github" EMULATOR_OSS_ENDPOINT = "127.0.0.1:{port}" EMULATOR_OSS_ACCESS_KEY_ID = "AccessKeyID" EMULATOR_OSS_ACCESS_KEY_SECRET = "AccessKeySecret" @@ -63,22 +63,15 @@ def _check(): @pytest.fixture -def oss(oss_server): +def oss(real_oss): import oss2 - url = OSS.get_url() - ret = OSS(url) - ret.config = { - "url": url, - "oss_key_id": EMULATOR_OSS_ACCESS_KEY_ID, - "oss_key_secret": EMULATOR_OSS_ACCESS_KEY_SECRET, - "oss_endpoint": oss_server, - } + ret = real_oss - auth = oss2.Auth( - EMULATOR_OSS_ACCESS_KEY_ID, EMULATOR_OSS_ACCESS_KEY_SECRET + auth = oss2.Auth(ret.config["oss_key_id"], ret.config["oss_key_secret"]) + bucket = oss2.Bucket( + auth, ret.config["oss_endpoint"], TEST_OSS_REPO_BUCKET ) - bucket = oss2.Bucket(auth, oss_server, TEST_OSS_REPO_BUCKET) try: bucket.get_bucket_info() except oss2.exceptions.NoSuchBucket: diff --git a/tests/unit/remote/test_oss.py b/tests/unit/remote/test_oss.py index da3db3fec6..f208a6d37c 100644 --- a/tests/unit/remote/test_oss.py +++ b/tests/unit/remote/test_oss.py @@ -16,6 +16,6 @@ def test_init(dvc): "oss_endpoint": endpoint, } fs = OSSFileSystem(**config) - assert fs.endpoint == endpoint - assert fs.key_id == key_id - assert fs.key_secret == key_secret + assert fs.fs._endpoint == endpoint + assert fs.fs._auth.id == key_id + assert fs.fs._auth.secret == key_secret