Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] dependency: fine grained (user cmd filter) #4363

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dvc/dependency/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,15 @@
del SCHEMA[BaseOutput.PARAM_METRIC]
SCHEMA.update(RepoDependency.REPO_SCHEMA)
SCHEMA.update(ParamsDependency.PARAM_SCHEMA)
SCHEMA.update({BaseOutput.PARAM_FILTER: str})


def _get(stage, p, info):
if isinstance(p, dict):
p = list(p.items())
assert len(p) == 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe assert not required (should be handled by schema)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other CLI commands can create/load dependencies, skipping the schema. Good to have an assert.

p, extra_info = p[0] # PARAM_FILTER
info.update(extra_info)
parsed = urlparse(p) if p else None
if parsed and parsed.scheme == "remote":
tree = get_cloud_tree(stage.repo, name=parsed.netloc)
Expand Down
9 changes: 8 additions & 1 deletion dvc/output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class BaseOutput:

PARAM_PATH = "path"
PARAM_CACHE = "cache"
PARAM_FILTER = "cmd"
PARAM_METRIC = "metric"
PARAM_METRIC_TYPE = "type"
PARAM_METRIC_XPATH = "xpath"
Expand Down Expand Up @@ -176,6 +177,10 @@ def checksum(self):
def checksum(self, checksum):
self.info[self.tree.PARAM_CHECKSUM] = checksum

@property
def filter_cmd(self):
    """Return the filter command recorded in this output's saved info.

    Looks up ``PARAM_FILTER`` ("cmd") in ``self.info``; ``None`` when no
    filter command was recorded for this dependency/output.
    """
    saved_info = self.info
    return saved_info.get(self.PARAM_FILTER)

def get_checksum(self):
return self.tree.get_hash(self.path_info)

Expand All @@ -188,7 +193,7 @@ def exists(self):
return self.tree.exists(self.path_info)

def save_info(self):
    """Compute and return the saved info for this output.

    Forwards this output's filter command (``filter_cmd``, possibly
    ``None``) to the tree so the hash can be computed over filtered
    content. NOTE: the scraped diff left both the old and new return
    lines in place; only the updated call is kept here.
    """
    return self.tree.save_info(self.path_info, cmd=self.filter_cmd)

def changed_checksum(self):
return self.checksum != self.get_checksum()
Expand Down Expand Up @@ -313,6 +318,8 @@ def dumpd(self):
if self.persist:
ret[self.PARAM_PERSIST] = self.persist

if self.filter_cmd:
ret[self.PARAM_FILTER] = self.filter_cmd
return ret

def verify_metric(self):
Expand Down
7 changes: 4 additions & 3 deletions dvc/repo/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def isdvc(self, path, **kwargs):
def isexec(self, path): # pylint: disable=unused-argument
return False

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
assert not cmd, NotImplementedError
outs = self._find_outs(path_info, strict=False)
if len(outs) != 1:
raise OutputNotFoundError
Expand Down Expand Up @@ -404,7 +405,7 @@ def walk_files(self, top, **kwargs): # pylint: disable=arguments-differ
for fname in files:
yield PathInfo(root) / fname

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
"""Return file checksum for specified path.

If path_info is a DVC out, the pre-computed checksum for the file
Expand All @@ -418,7 +419,7 @@ def get_file_hash(self, path_info):
return self.dvctree.get_file_hash(path_info)
except OutputNotFoundError:
pass
return file_md5(path_info, self)[0]
return file_md5(path_info, self, cmd=cmd)[0]

def copytree(self, top, dest):
top = PathInfo(top)
Expand Down
10 changes: 7 additions & 3 deletions dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
StageParams.PARAM_ALWAYS_CHANGED: bool,
}

DATA_SCHEMA = {**CHECKSUMS_SCHEMA, Required("path"): str}
DATA_SCHEMA = {**CHECKSUMS_SCHEMA, Required(BaseOutput.PARAM_PATH): str}
LOCK_FILE_STAGE_SCHEMA = {
Required(StageParams.PARAM_CMD): str,
StageParams.PARAM_DEPS: [DATA_SCHEMA],
StageParams.PARAM_DEPS: [
{**DATA_SCHEMA, Optional(BaseOutput.PARAM_FILTER): str}
],
StageParams.PARAM_PARAMS: {str: {str: object}},
StageParams.PARAM_OUTS: [DATA_SCHEMA],
}
Expand Down Expand Up @@ -51,7 +53,9 @@
str: {
StageParams.PARAM_CMD: str,
Optional(StageParams.PARAM_WDIR): str,
Optional(StageParams.PARAM_DEPS): [str],
Optional(StageParams.PARAM_DEPS): [
Any(str, {str: {BaseOutput.PARAM_FILTER: str}})
],
Optional(StageParams.PARAM_PARAMS): [
Any(str, PARAM_PSTAGE_NON_DEFAULT_SCHEMA)
],
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def remove(self, path_info):
logger.debug(f"Removing {path_info}")
self.blob_service.delete_blob(path_info.bucket, path_info.path)

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
    """Return the ETag of the blob at ``path_info``.

    ``cmd`` (dependency content filtering) is not supported for this
    remote. The original used ``assert not cmd, NotImplementedError``,
    which passes the exception *class* as the assert message and is
    stripped entirely under ``python -O``; raise explicitly instead.
    """
    if cmd:
        raise NotImplementedError(
            "filter cmd is not supported for this remote"
        )
    return self.get_etag(path_info)

def _upload(
Expand Down
14 changes: 9 additions & 5 deletions dvc/tree/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class BaseTree:
CACHE_MODE = None
SHARED_MODE_MAP = {None: (None, None), "group": (None, None)}
PARAM_CHECKSUM = None
PARAM_FILTER = None

state = StateNoop()

Expand Down Expand Up @@ -236,7 +237,7 @@ def is_dir_hash(cls, hash_):
return False
return hash_.endswith(cls.CHECKSUM_DIR_SUFFIX)

def get_hash(self, path_info, **kwargs):
def get_hash(self, path_info, cmd=None, **kwargs):
assert path_info and (
isinstance(path_info, str) or path_info.scheme == self.scheme
)
Expand Down Expand Up @@ -265,14 +266,14 @@ def get_hash(self, path_info, **kwargs):
if self.isdir(path_info):
hash_ = self.get_dir_hash(path_info, **kwargs)
else:
hash_ = self.get_file_hash(path_info)
hash_ = self.get_file_hash(path_info, cmd=cmd)

if hash_ and self.exists(path_info):
self.state.save(path_info, hash_)

return hash_

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
raise NotImplementedError

def get_dir_hash(self, path_info, **kwargs):
Expand All @@ -293,8 +294,11 @@ def path_to_hash(self, path):

return "".join(parts)

def save_info(self, path_info, **kwargs):
return {self.PARAM_CHECKSUM: self.get_hash(path_info, **kwargs)}
def save_info(self, path_info, cmd=None, **kwargs):
    """Return the info dict (checksum plus optional filter command) for a path.

    Bug fix: ``cmd`` is now forwarded to ``get_hash`` so the checksum is
    actually computed over the filtered content. Previously the filter
    command was recorded in the returned dict but silently ignored while
    hashing, making the stored checksum inconsistent with the filter.
    """
    ret = {self.PARAM_CHECKSUM: self.get_hash(path_info, cmd=cmd, **kwargs)}
    if cmd:
        # Persist the filter command alongside the checksum so it can be
        # re-applied (and dumped via dumpd) on later hash computations.
        ret[self.PARAM_FILTER] = cmd
    return ret

def _calculate_hashes(self, file_infos):
file_infos = list(file_infos)
Expand Down
2 changes: 1 addition & 1 deletion dvc/tree/gdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def remove(self, path_info):
item_id = self._get_item_id(path_info)
self.gdrive_delete_file(item_id)

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
raise NotImplementedError

def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def copy(self, from_info, to_info):
to_bucket = self.gs.bucket(to_info.bucket)
from_bucket.copy_blob(blob, to_bucket, new_name=to_info.path)

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
assert not cmd, NotImplementedError
import base64
import codecs

Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def _group(regex, s, gname):
assert match is not None
return match.group(gname)

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
assert not cmd, NotImplementedError
# NOTE: pyarrow doesn't support checksum, so we need to use hadoop
regex = r".*\t.*\t(?P<checksum>.*)"
stdout = self.hadoop_fs(
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def request(self, method, url, **kwargs):
def exists(self, path_info, use_dvcignore=True):
return bool(self.request("HEAD", path_info.url))

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
assert not cmd, NotImplementedError
url = path_info.url
headers = self.request("HEAD", url).headers
etag = headers.get("ETag") or headers.get("Content-MD5")
Expand Down
5 changes: 3 additions & 2 deletions dvc/tree/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class LocalTree(BaseTree):
scheme = Schemes.LOCAL
PATH_CLS = PathInfo
PARAM_CHECKSUM = "md5"
PARAM_FILTER = "cmd"
PARAM_PATH = "path"
TRAVERSE_PREFIX_LEN = 2
UNPACKED_DIR_SUFFIX = ".unpacked"
Expand Down Expand Up @@ -297,8 +298,8 @@ def is_protected(self, path_info):

return stat.S_IMODE(mode) == self.CACHE_MODE

def get_file_hash(self, path_info):
return file_md5(path_info)[0]
def get_file_hash(self, path_info, cmd=None):
    """Return the md5 hexdigest of the local file at ``path_info``.

    When ``cmd`` is given, ``file_md5`` hashes the output of running the
    filter command over the file instead of the raw content.
    """
    hexdigest, _digest = file_md5(path_info, cmd=cmd)
    return hexdigest

@staticmethod
def getsize(path_info):
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ def _copy(cls, s3, from_info, to_info, extra_args):
if etag != cached_etag:
raise ETagMismatchError(etag, cached_etag)

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
    """Return the S3 ETag for the object at ``path_info``.

    ``cmd`` (dependency content filtering) is not supported for this
    remote. The original ``assert not cmd, NotImplementedError`` used the
    exception class as an assert message and is removed under
    ``python -O``; raise explicitly so the guard always holds.
    """
    if cmd:
        raise NotImplementedError(
            "filter cmd is not supported for this remote"
        )
    return self.get_etag(self.s3, path_info.bucket, path_info.path)

def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def reflink(self, from_info, to_info):
with self.ssh(from_info) as ssh:
ssh.reflink(from_info.path, to_info.path)

def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
assert not cmd, NotImplementedError
if path_info.scheme != self.scheme:
raise NotImplementedError

Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/webdav.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def exists(self, path_info, use_dvcignore=True):
return self._client.check(path_info.path)

# Gets file hash 'etag'
def get_file_hash(self, path_info):
def get_file_hash(self, path_info, cmd=None):
assert not cmd, NotImplementedError
# Use webdav client info method to get etag
etag = self._client.info(path_info.path)["etag"].strip('"')

Expand Down
27 changes: 25 additions & 2 deletions dvc/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import math
import os
import re
import subprocess
import sys
import tempfile
import time

import colorama
Expand Down Expand Up @@ -43,8 +45,10 @@ def _fobj_md5(fobj, hash_md5, binary, progress_func=None):
progress_func(len(data))


def file_md5(fname, tree=None):
""" get the (md5 hexdigest, md5 digest) of a file """
def file_md5(fname, tree=None, cmd=None):
"""
Returns (md5_hexdigest, md5_digest) of `cmd file` (default: `cmd=cat`)
"""
from dvc.progress import Tqdm
from dvc.istextfile import istextfile

Expand All @@ -58,6 +62,21 @@ def file_md5(fname, tree=None):
open_func = open

if exists_func(fname):
filtered = None
if cmd:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if there are aspects of stage.run.cmd_run which should be used here

p = subprocess.Popen(
cmd.split() + [fname],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
out, err = p.communicate()
if p.returncode != 0:
logger.error("filtering:%s %s", cmd, fname)
raise RuntimeError(err)
with tempfile.NamedTemporaryFile(delete=False) as fobj:
logger.debug("filtering:%s %s > %s", cmd, fname, fobj.name)
fobj.write(out)
fname = filtered = fobj.name
hash_md5 = hashlib.md5()
binary = not istextfile(fname, tree=tree)
size = stat_func(fname).st_size
Expand All @@ -80,6 +99,10 @@ def file_md5(fname, tree=None):
with open_func(fname, "rb") as fobj:
_fobj_md5(fobj, hash_md5, binary, pbar.update)

if filtered is not None:
from dvc.utils.fs import remove

remove(filtered)
Comment on lines +102 to +105
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this is required - maybe automatically handled elsewhere (i.e. entire tmpdir deleted before exit)

return (hash_md5.hexdigest(), hash_md5.digest())

return (None, None)
Expand Down