Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support adding directories in google cloud storage remote #2853

Merged
merged 19 commits into from
Dec 1, 2019
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dvc/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def get_checksum(self, path_info):
return checksum

if self.isdir(path_info):
checksum = self.get_dir_checksum(path_info)
checksum = self.get_dir_checksum(path_info / "")
skshetry marked this conversation as resolved.
Show resolved Hide resolved
else:
checksum = self.get_file_checksum(path_info)

Expand Down
28 changes: 23 additions & 5 deletions dvc/remote/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,34 @@ def remove(self, path_info):

blob.delete()

def _list_paths(self, bucket, prefix, max_items=None):
    """Yield the names of all blobs in *bucket* whose name starts with
    *prefix*.

    Args:
        bucket: name of the GCS bucket to list.
        prefix: blob-name prefix to filter by.
        max_items: optional cap on the number of results; forwarded to
            the GCS API as ``max_results`` so the listing stops server
            side (used by ``isdir``/``exists`` which only need one hit).
    """
    for blob in self.gs.bucket(bucket).list_blobs(
        prefix=prefix, max_results=max_items
    ):
        yield blob.name

def list_cache_paths(self):
    """Yield the path (bucket-relative blob name) of every cache file
    stored under this remote's root.
    """
    # walk_files yields PathInfo objects; callers only need the path
    # component.
    for cache_info in self.walk_files(self.path_info):
        yield cache_info.path

def walk_files(self, path_info):
    """Yield a PathInfo for every blob found under *path_info*."""
    blob_names = self._list_paths(path_info.bucket, path_info.path)
    yield from (path_info.replace(name) for name in blob_names)

def isdir(self, path_info):
    """Return True if *path_info* behaves like a directory.

    GCS has no real directories: a path counts as a directory when at
    least one blob exists under the ``"<path>/"`` prefix. Listing is
    capped at a single result, so this is a cheap check.
    """
    prefix = (path_info / "").path
    first_hit = next(
        self._list_paths(path_info.bucket, prefix, max_items=1), None
    )
    return first_hit is not None

def exists(self, path_info):
    """Check whether *path_info* exists, either as a blob (file) or as
    a "directory" — i.e. a prefix with at least one blob under it.

    Only one listing call with ``max_items=1`` is made: the first blob
    name returned for the prefix is either an exact match (a file) or
    starts with ``"<path>/"`` (a directory). An empty string is used as
    the sentinel when nothing matches, which satisfies neither test.
    """
    dir_path = path_info / ""
    fname = next(
        self._list_paths(path_info.bucket, path_info.path, max_items=1), ""
    )
    return path_info.path == fname or fname.startswith(dir_path.path)

def _upload(self, from_file, to_info, name=None, no_progress_bar=True):
bucket = self.gs.bucket(to_info.bucket)
Expand Down
82 changes: 74 additions & 8 deletions tests/unit/remote/test_s3.py β†’ tests/func/test_remote_dir.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
# -*- coding: utf-8 -*-
skshetry marked this conversation as resolved.
Show resolved Hide resolved
import pytest
import uuid

from moto import mock_s3

from dvc.remote.gs import RemoteGS
from dvc.remote.s3 import RemoteS3

from tests.func.test_data_cloud import _should_test_gcp

test_gs = pytest.mark.skipif(not _should_test_gcp(), reason="Skipping on gs.")


def create_object_gs(client, bucket, key, body):
    """Upload *body* as the blob named *key* inside *bucket*, using the
    given GCS *client*.
    """
    target = client.get_bucket(bucket).blob(key)
    target.upload_from_string(body)


@pytest.fixture
def remote():
def remote_s3():
"""Returns a RemoteS3 connected to a bucket with the following structure:

bucket
Expand Down Expand Up @@ -38,7 +51,54 @@ def remote():
yield remote


def test_isdir(remote):
@pytest.fixture
def remote_gs():
    """Returns a RemoteGS connected to a bucket with the following structure:
    bucket
    ├── data
    │  ├── alice
    │  ├── alpha
    │  └── subdir
    │    ├── 1
    │    ├── 2
    │    └── 3
    ├── empty_dir
    ├── empty_file
    └── foo
    """
    # A unique prefix isolates this test run from concurrent runs
    # against the shared "dvc-test" bucket.
    prefix = str(uuid.uuid4())
    REMOTE_URL = "gs://dvc-test/" + prefix
    remote = RemoteGS(None, {"url": REMOTE_URL})
    teardowns = []

    def put_object(file, content):
        # Upload and register a matching cleanup so the bucket is left
        # untouched even if a test fails mid-way.
        create_object_gs(remote.gs, "dvc-test", prefix + "/" + file, content)
        teardowns.append(lambda: remote.remove(remote.path_info / file))

    put_object("empty_dir/", "")
    put_object("empty_file", "")
    put_object("foo", "foo")
    put_object("data/alice", "alice")
    put_object("data/alpha", "alpha")
    put_object("data/subdir/1", "1")
    put_object("data/subdir/2", "2")
    put_object("data/subdir/3", "3")

    yield remote

    for teardown in teardowns:
        teardown()


# Shared decorator: runs the decorated test once per backing remote,
# resolved by fixture name via request.getfixturevalue(). The GS
# variant is skipped unless the environment is configured for GCP
# (see the test_gs skipif marker above).
remote_parameterized = pytest.mark.parametrize(
    "remote_name", [pytest.param("remote_gs", marks=test_gs), "remote_s3"]
)
skshetry marked this conversation as resolved.
Show resolved Hide resolved


@remote_parameterized
def test_isdir(request, remote_name):
remote = request.getfixturevalue(remote_name)

test_cases = [
(True, "data"),
(True, "data/"),
Expand All @@ -54,7 +114,10 @@ def test_isdir(remote):
assert remote.isdir(remote.path_info / path) == expected


def test_exists(remote):
@remote_parameterized
def test_exists(request, remote_name):
remote = request.getfixturevalue(remote_name)

test_cases = [
(True, "data"),
(True, "data/"),
Expand All @@ -72,7 +135,10 @@ def test_exists(remote):
assert remote.exists(remote.path_info / path) == expected


def test_walk_files(remote):
@remote_parameterized
def test_walk_files(request, remote_name):
remote = request.getfixturevalue(remote_name)

files = [
remote.path_info / "data/alice",
remote.path_info / "data/alpha",
Expand All @@ -84,16 +150,16 @@ def test_walk_files(remote):
assert list(remote.walk_files(remote.path_info / "data")) == files


def test_copy_preserve_etag_across_buckets(remote):
s3 = remote.s3
def test_copy_preserve_etag_across_buckets(remote_s3):
s3 = remote_s3.s3
s3.create_bucket(Bucket="another")

another = RemoteS3(None, {"url": "s3://another", "region": "us-east-1"})

from_info = remote.path_info / "foo"
from_info = remote_s3.path_info / "foo"
to_info = another.path_info / "foo"

remote.copy(from_info, to_info)
remote_s3.copy(from_info, to_info)

from_etag = RemoteS3.get_etag(s3, "bucket", "foo")
to_etag = RemoteS3.get_etag(s3, "another", "foo")
Expand Down