Skip to content

Commit

Permalink
dvc: use protected mode by default
Browse files Browse the repository at this point in the history
Fixes #3261
Fixes #2041
  • Loading branch information
efiop committed Mar 17, 2020
1 parent e133444 commit a1a497d
Show file tree
Hide file tree
Showing 20 changed files with 189 additions and 118 deletions.
2 changes: 1 addition & 1 deletion dvc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class RelPath(str):
}
LOCAL_COMMON = {
"type": supported_cache_type,
Optional("protected", default=False): Bool,
Optional("protected", default=False): Bool, # obsoleted
"shared": All(Lower, Choices("group")),
Optional("slow_link_warning", default=True): Bool,
}
Expand Down
41 changes: 25 additions & 16 deletions dvc/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,24 @@ class RemoteBASE(object):
DEFAULT_NO_TRAVERSE = True
DEFAULT_VERIFY = False

CACHE_MODE = None
SHARED_MODE_MAP = {None: (None, None), "group": (None, None)}

state = StateNoop()

def __init__(self, repo, config):
self.repo = repo

self._check_requires(config)

shared = config.get("shared")
self._file_mode, self._dir_mode = self.SHARED_MODE_MAP[shared]

self.checksum_jobs = (
config.get("checksum_jobs")
or (self.repo and self.repo.config["core"].get("checksum_jobs"))
or self.CHECKSUM_JOBS
)
self.protected = False
self.no_traverse = config.get("no_traverse", self.DEFAULT_NO_TRAVERSE)
self.verify = config.get("verify", self.DEFAULT_VERIFY)
self._dir_info = {}
Expand Down Expand Up @@ -221,7 +226,7 @@ def get_dir_checksum(self, path_info):
new_info = self.cache.checksum_to_path_info(checksum)
if self.cache.changed_cache_file(checksum):
self.cache.makedirs(new_info.parent)
self.cache.move(tmp_info, new_info)
self.cache.move(tmp_info, new_info, mode=self.CACHE_MODE)

self.state.save(path_info, checksum)
self.state.save(new_info, checksum)
Expand Down Expand Up @@ -409,30 +414,20 @@ def _do_link(self, from_info, to_info, link_method):

link_method(from_info, to_info)

if self.protected:
self.protect(to_info)

logger.debug(
"Created %s'%s': %s -> %s",
"protected " if self.protected else "",
self.cache_types[0],
from_info,
to_info,
"Created %s'%s': %s -> %s", self.cache_types[0], from_info, to_info,
)

def _save_file(self, path_info, checksum, save_link=True):
assert checksum

cache_info = self.checksum_to_path_info(checksum)
if self.changed_cache(checksum):
self.move(path_info, cache_info)
self.move(path_info, cache_info, mode=self.CACHE_MODE)
self.link(cache_info, path_info)
elif self.iscopy(path_info) and self._cache_is_copy(path_info):
# Default relink procedure involves unneeded copy
if self.protected:
self.protect(path_info)
else:
self.unprotect(path_info)
self.unprotect(path_info)
else:
self.remove(path_info)
self.link(cache_info, path_info)
Expand Down Expand Up @@ -656,7 +651,8 @@ def open(self, path_info, mode="r", encoding=None):
def remove(self, path_info):
    """Remove ``path_info``; not implemented by the base remote.

    Raises:
        RemoteActionNotImplemented: always, naming the "remove" action and
            this remote's ``scheme``.
    """
    raise RemoteActionNotImplemented("remove", self.scheme)

def move(self, from_info, to_info):
def move(self, from_info, to_info, mode=None):
    """Move ``from_info`` to ``to_info`` by copying, then removing the source.

    ``mode`` is accepted so the signature matches callers that pass
    ``mode=self.CACHE_MODE``; this generic implementation only accepts
    ``None``.
    """
    # This copy+remove fallback never applies permissions, so an explicit
    # mode is rejected. NOTE: asserts are stripped when Python runs with -O.
    assert mode is None
    self.copy(from_info, to_info)
    self.remove(from_info)

Expand Down Expand Up @@ -718,6 +714,9 @@ def gc(self, named_cache):
removed = True
return removed

def is_protected(self, path_info):
    """Report whether ``path_info`` is protected (read-only).

    The base implementation always returns ``False``, so
    ``changed_cache_file`` never skips its checksum comparison on the
    strength of this check.
    """
    return False

def changed_cache_file(self, checksum):
"""Compare the given checksum with the (corresponding) actual one.
Expand All @@ -730,7 +729,14 @@ def changed_cache_file(self, checksum):
- Remove the file from cache if it doesn't match the actual checksum
"""

cache_info = self.checksum_to_path_info(checksum)
if self.is_protected(cache_info):
logger.debug(
"Assuming '{}' is unchanged since it is read-only", cache_info
)
return False

actual = self.get_checksum(cache_info)

logger.debug(
Expand All @@ -744,6 +750,9 @@ def changed_cache_file(self, checksum):
return True

if actual.split(".")[0] == checksum.split(".")[0]:
# making cache file read-only so we don't need to check it
# next time
self.protect(cache_info)
return False

if self.exists(cache_info):
Expand Down
54 changes: 29 additions & 25 deletions dvc/remote/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,11 @@ class RemoteLOCAL(RemoteBASE):

DEFAULT_CACHE_TYPES = ["reflink", "copy"]

CACHE_MODE = 0o444
SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)}

def __init__(self, repo, config):
super().__init__(repo, config)
self.protected = config.get("protected", False)

shared = config.get("shared")
self._file_mode, self._dir_mode = self.SHARED_MODE_MAP[shared]

if self.protected:
# cache files are set to be read-only for everyone
self._file_mode = stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH

self.cache_dir = config.get("url")
self._dir_info = {}

Expand Down Expand Up @@ -142,23 +134,25 @@ def remove(self, path_info):
if self.exists(path_info):
remove(path_info.fspath)

def move(self, from_info, to_info):
def move(self, from_info, to_info, mode=None):
if from_info.scheme != "local" or to_info.scheme != "local":
raise NotImplementedError

self.makedirs(to_info.parent)

if self.isfile(from_info):
mode = self._file_mode
else:
mode = self._dir_mode
if mode is None:
if self.isfile(from_info):
mode = self._file_mode
else:
mode = self._dir_mode

move(from_info, to_info, mode=mode)

def copy(self, from_info, to_info):
tmp_info = to_info.parent / tmp_fname(to_info.name)
try:
System.copy(from_info, tmp_info)
os.chmod(fspath_py35(tmp_info), self._file_mode)
os.rename(fspath_py35(tmp_info), fspath_py35(to_info))
except Exception:
self.remove(tmp_info)
Expand Down Expand Up @@ -202,9 +196,13 @@ def hardlink(self, from_info, to_info):
def is_hardlink(path_info):
return System.is_hardlink(path_info)

@staticmethod
def reflink(from_info, to_info):
System.reflink(from_info, to_info)
def reflink(self, from_info, to_info):
    """Reflink ``from_info`` to ``to_info`` with ``self._file_mode`` permissions.

    The reflink is created under a temporary name next to ``to_info``, its
    mode is set, and only then is it renamed into place, so ``to_info`` never
    appears with the wrong permissions.

    Raises:
        Exception: whatever ``System.reflink``, ``os.chmod`` or ``os.rename``
            raised; the temporary file is removed before re-raising.
    """
    tmp_info = to_info.parent / tmp_fname(to_info.name)
    try:
        System.reflink(from_info, tmp_info)
        # NOTE: reflink has its own separate inode, so you can set permissions
        # that are different from the source.
        os.chmod(fspath_py35(tmp_info), self._file_mode)
        os.rename(fspath_py35(tmp_info), fspath_py35(to_info))
    except Exception:
        # Like copy(), do not leave a stray temporary file behind on failure.
        self.remove(tmp_info)
        raise

def cache_exists(self, checksums, jobs=None, name=None):
return [
Expand Down Expand Up @@ -402,8 +400,7 @@ def _log_missing_caches(checksum_info_dict):
)
logger.warning(msg)

@staticmethod
def _unprotect_file(path):
def _unprotect_file(self, path):
if System.is_symlink(path) or System.is_hardlink(path):
logger.debug("Unprotecting '{}'".format(path))
tmp = os.path.join(os.path.dirname(path), "." + uuid())
Expand All @@ -423,13 +420,13 @@ def _unprotect_file(path):
"a symlink or a hardlink.".format(path)
)

os.chmod(path, os.stat(path).st_mode | stat.S_IWRITE)
os.chmod(path, self._file_mode)

def _unprotect_dir(self, path):
assert is_working_tree(self.repo.tree)

for fname in self.repo.tree.walk_files(path):
RemoteLOCAL._unprotect_file(fname)
self._unprotect_file(fname)

def unprotect(self, path_info):
path = path_info.fspath
Expand All @@ -441,12 +438,11 @@ def unprotect(self, path_info):
if os.path.isdir(path):
self._unprotect_dir(path)
else:
RemoteLOCAL._unprotect_file(path)
self._unprotect_file(path)

@staticmethod
def protect(path_info):
def protect(self, path_info):
path = fspath_py35(path_info)
mode = stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH
mode = self.CACHE_MODE

try:
os.chmod(path, mode)
Expand Down Expand Up @@ -519,3 +515,11 @@ def _get_unpacked_dir_names(self, checksums):
if self.is_dir_checksum(c):
unpacked.add(c + self.UNPACKED_DIR_SUFFIX)
return unpacked

def is_protected(self, path_info):
    """Return True if ``path_info`` has exactly ``CACHE_MODE`` permissions.

    A missing path is reported as not protected. The file is stat'ed once
    and a ``FileNotFoundError`` is handled, rather than checking existence
    first, so a file removed between the check and the stat cannot raise.
    """
    try:
        mode = os.stat(fspath_py35(path_info)).st_mode
    except FileNotFoundError:
        return False

    # Compare only the permission bits, ignoring the file-type bits.
    return stat.S_IMODE(mode) == self.CACHE_MODE
3 changes: 2 additions & 1 deletion dvc/remote/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ def remove(self, path_info):
with self.ssh(path_info) as ssh:
ssh.remove(path_info.path)

def move(self, from_info, to_info):
def move(self, from_info, to_info, mode=None):
assert mode is None
if from_info.scheme != self.scheme or to_info.scheme != self.scheme:
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions tests/func/remote/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_dir_cache_changed_on_single_cache_file_modification(tmp_dir, dvc):
assert os.path.exists(unpacked_dir)

cache_file_path = dvc.cache.local.get(file_md5)
os.chmod(cache_file_path, 0o644)
with open(cache_file_path, "a") as fobj:
fobj.write("modification")

Expand Down
19 changes: 13 additions & 6 deletions tests/func/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,12 +489,12 @@ def test(self):
ret = main(["config", "cache.type", "hardlink"])
self.assertEqual(ret, 0)

ret = main(["config", "cache.protected", "true"])
self.assertEqual(ret, 0)

ret = main(["add", self.FOO])
self.assertEqual(ret, 0)

self.assertFalse(os.access(self.FOO, os.W_OK))
self.assertTrue(System.is_hardlink(self.FOO))

ret = main(["unprotect", self.FOO])
self.assertEqual(ret, 0)

Expand Down Expand Up @@ -561,7 +561,6 @@ def test_readding_dir_should_not_unprotect_all(tmp_dir, dvc, mocker):
tmp_dir.gen("dir/data", "data")

dvc.cache.local.cache_types = ["symlink"]
dvc.cache.local.protected = True

dvc.add("dir")
tmp_dir.gen("dir/new_file", "new_file_content")
Expand Down Expand Up @@ -618,15 +617,23 @@ def test_should_relink_on_repeated_add(
@pytest.mark.parametrize("link", ["hardlink", "symlink", "copy"])
def test_should_protect_on_repeated_add(link, tmp_dir, dvc):
dvc.cache.local.cache_types = [link]
dvc.cache.local.protected = True

tmp_dir.dvc_gen({"foo": "foo"})

dvc.unprotect("foo")

dvc.add("foo")

assert not os.access("foo", os.W_OK)
assert not os.access(
os.path.join(".dvc", "cache", "ac", "bd18db4cc2f85cedef654fccc4a4d8"),
os.W_OK,
)

# NOTE: Windows symlink perms don't propagate to the target
if link == "copy" or (link == "symlink" and os.name == "nt"):
assert os.access("foo", os.W_OK)
else:
assert not os.access("foo", os.W_OK)


def test_escape_gitignore_entries(tmp_dir, scm, dvc):
Expand Down
9 changes: 5 additions & 4 deletions tests/func/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import shutil

import pytest

Expand All @@ -9,6 +8,8 @@
from dvc.exceptions import FileMissingError
from dvc.main import main
from dvc.path_info import URLInfo
from dvc.utils.fs import remove

from tests.remotes import Azure, GCP, HDFS, Local, OSS, S3, SSH


Expand Down Expand Up @@ -65,7 +66,7 @@ def test_open(remote_url, tmp_dir, dvc):
run_dvc("push")

# Remove cache to force download
shutil.rmtree(dvc.cache.local.cache_dir)
remove(dvc.cache.local.cache_dir)

with api.open("foo") as fd:
assert fd.read() == "foo-text"
Expand All @@ -85,7 +86,7 @@ def test_open_external(remote_url, erepo_dir):
erepo_dir.dvc.push(all_branches=True)

# Remove cache to force download
shutil.rmtree(erepo_dir.dvc.cache.local.cache_dir)
remove(erepo_dir.dvc.cache.local.cache_dir)

# Using file url to force clone to tmp repo
repo_url = "file://{}".format(erepo_dir)
Expand All @@ -101,7 +102,7 @@ def test_missing(remote_url, tmp_dir, dvc):
run_dvc("remote", "add", "-d", "upstream", remote_url)

# Remove cache to make foo missing
shutil.rmtree(dvc.cache.local.cache_dir)
remove(dvc.cache.local.cache_dir)

with pytest.raises(FileMissingError):
api.read("foo")
Expand Down
12 changes: 6 additions & 6 deletions tests/func/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ def test_default_cache_type(dvc):

@pytest.mark.skipif(os.name == "nt", reason="Not supported for Windows.")
@pytest.mark.parametrize(
"protected,dir_mode,file_mode",
[(False, 0o775, 0o664), (True, 0o775, 0o444)],
"group, dir_mode", [(False, 0o755), (True, 0o775)],
)
def test_shared_cache(tmp_dir, dvc, protected, dir_mode, file_mode):
with dvc.config.edit() as conf:
conf["cache"].update({"shared": "group", "protected": str(protected)})
def test_shared_cache(tmp_dir, dvc, group, dir_mode):
if group:
with dvc.config.edit() as conf:
conf["cache"].update({"shared": "group"})
dvc.cache = Cache(dvc)

tmp_dir.dvc_gen(
Expand All @@ -203,4 +203,4 @@ def test_shared_cache(tmp_dir, dvc, protected, dir_mode, file_mode):

for fname in fnames:
path = os.path.join(root, fname)
assert stat.S_IMODE(os.stat(path).st_mode) == file_mode
assert stat.S_IMODE(os.stat(path).st_mode) == 0o444
Loading

0 comments on commit a1a497d

Please sign in to comment.