Skip to content

Commit

Permalink
Fix set not being auto-tracked
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Oct 20, 2020
1 parent 856864d commit add12f9
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 32 deletions.
7 changes: 3 additions & 4 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def _resolve_entry(self, name: str, definition):
def resolve(self):
stages = self.data.get(STAGES_KWD, {})
data = join(starmap(self._resolve_entry, stages.items()))
logger.trace( # pytype: disable=attribute-error
"Resolved dvc.yaml:\n%s", dumps_yaml(data)
)
logger.trace("Resolved dvc.yaml:\n%s", dumps_yaml(data))
return {STAGES_KWD: data}

def _resolve_stage(self, context: Context, name: str, definition) -> dict:
Expand Down Expand Up @@ -108,8 +106,9 @@ def _resolve_stage(self, context: Context, name: str, definition) -> dict:
logger.trace( # pytype: disable=attribute-error
"Context during resolution of stage %s:\n%s", name, context
)

with context.track():
stage_d = resolve(definition, context)
stage_d = resolve(definition, context, unwrap=True)

params = stage_d.get(PARAMS_KWD, []) + context.tracked

Expand Down
62 changes: 56 additions & 6 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,53 @@ def path(self):
@dataclass
class Value:
value: Any
meta: Optional[Meta] = field(compare=False, default=None, repr=False)
meta: Meta = field(compare=False, repr=False)

def __repr__(self):
return f"'{self}'"

def __str__(self) -> str:
return str(self.value)

def get_sources(self):
return {self.meta.source: self.meta.path()}


class String:
    """
    Wrapper around an interpolated string that keeps track of the original
    sources of the values interpolated into it.

    Attributes:
        value: the template with every match replaced by ``str()`` of the
            value resolved for it; escaped markers (a backslash followed
            by ``${``) are turned into a plain ``${``.
        meta: ``defaultdict(set)`` mapping each source to the set of key
            paths that contributed to ``value``.
    """

    def __init__(self, template, matches, context):
        """
        Build the interpolated string.

        Args:
            template: the raw string containing interpolation markers.
            matches: regex match objects found in ``template``; each one is
                resolved against ``context`` and spliced in at its span.
            context: object handed to ``_resolve_value`` to look values up.
        """
        # Imported lazily: interpolate.py imports this module at its top
        # level, so a module-level import here would be circular.
        from .interpolate import _resolve_value

        self.meta = defaultdict(set)
        pieces = []
        index = 0
        for match in matches:
            start, end = match.span(0)
            val = _resolve_value(match, context)
            self._add_source(val)
            # literal text before the match, then the resolved value
            pieces.append(template[index:start])
            pieces.append(str(val))
            index = end
        pieces.append(template[index:])
        self.value = "".join(pieces).replace(r"\${", "${")

    def __repr__(self) -> str:
        # No __str__ is defined, so str(self) falls back to this as well.
        return str(self.value)

    def _add_source(self, val: "Union[Value, String]"):
        """Record where ``val`` came from in ``self.meta``."""
        # string might have been built from multiple sources
        if isinstance(val, String):
            # Merge the sources of the nested string. This must iterate
            # over `val.meta`: iterating over `self.meta` only merges each
            # set into itself and drops the nested sources.
            if val.meta:
                for source, keys in val.meta.items():
                    self.meta[source].update(keys)
        elif isinstance(val, Value) and val.meta and val.meta.source:
            self.meta[val.meta.source].add(val.meta.path())

    def get_sources(self):
        """Return the mapping of source -> set of key paths."""
        return self.meta


class Container:
meta: Meta
Expand All @@ -66,7 +105,7 @@ def _convert(self, key, value):
meta = Meta.update_path(self.meta, key)
if value is None or isinstance(value, (int, float, str, bytes, bool)):
return Value(value, meta=meta)
elif isinstance(value, (CtxList, CtxDict, Value)):
elif isinstance(value, (CtxList, CtxDict, Value, String)):
return value
elif isinstance(value, (list, dict)):
container = CtxDict if isinstance(value, dict) else CtxList
Expand Down Expand Up @@ -108,6 +147,9 @@ def select(self, key: str):
) from exc
return d.select(rems[0]) if rems else d

def get_sources(self):
return {}


class CtxList(Container, MutableSequence):
_key_transform = staticmethod(int)
Expand All @@ -120,6 +162,9 @@ def __init__(self, values: Sequence, meta: Meta = None):
def insert(self, index: int, value):
self.data.insert(index, self._convert(index, value))

def get_sources(self):
return {self.meta.source: self.meta.path()}


class CtxDict(Container, MutableMapping):
def __init__(self, mapping: Mapping = None, meta: Meta = None, **kwargs):
Expand Down Expand Up @@ -158,10 +203,15 @@ def track(self):
self._track = False

def _track_data(self, node):
if isinstance(node, (Value, CtxList)):
meta = node.meta
if meta and meta.source and self._track:
self._tracked_data[meta.source].add(meta.path())
if not self._track:
return

for source, keys in node.get_sources().items():
if not source:
continue
params_file = self._tracked_data[source]
keys = [keys] if isinstance(keys, str) else keys
params_file.update(keys)

@property
def tracked(self):
Expand Down
37 changes: 15 additions & 22 deletions dvc/parsing/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from funcy import rpartial

from dvc.parsing.context import Context, Value
from dvc.parsing.context import Context, String, Value

KEYCRE = re.compile(
r"""
Expand All @@ -16,6 +16,8 @@
re.VERBOSE,
)

UNWRAP_DEFAULT = False


def _get_matches(template):
    """Collect every ``KEYCRE`` match found in ``template`` into a list."""
    return [match for match in KEYCRE.finditer(template)]
Expand All @@ -24,43 +26,34 @@ def _get_matches(template):
def _resolve_value(match, context: Context):
_, _, inner = match.groups()
value = context.select(inner)
if isinstance(value, Value):
return value.value
return value


def _str_interpolate(template, matches, context):
    """
    Splice the resolved value of every match into ``template``.

    Text between matches is copied verbatim; each matched span is replaced
    by ``str()`` of what ``_resolve_value`` returns for it.
    """
    pieces = []
    cursor = 0
    for match in matches:
        begin, finish = match.span(0)
        pieces.append(template[cursor:begin])
        pieces.append(str(_resolve_value(match, context)))
        cursor = finish
    # whatever follows the last match (or the whole template if none)
    pieces.append(template[cursor:])
    return "".join(pieces)
def _unwrap(value):
    """Return ``value.value`` for a ``Value`` or ``String``; anything else as is."""
    return value.value if isinstance(value, (Value, String)) else value


def _resolve_str(src: str, context):
def _resolve_str(src: str, context, unwrap=UNWRAP_DEFAULT):
matches = _get_matches(src)
if len(matches) == 1 and src == matches[0].group(0):
# replace "${enabled}", if `enabled` is a boolean, with its actual
# value rather than its string counterpart.
return _resolve_value(matches[0], context)
elif matches:
# but not "${num} days"
src = _str_interpolate(src, matches, context)

# regex already backtracks and avoids any `${` starting with
# backslashes(`\`). We just need to replace those by `${`.
return src.replace(r"\${", "${")
value = _resolve_value(matches[0], context)
else:
value = String(src, matches, context)
return _unwrap(value) if unwrap else value


def resolve(src, context):
def resolve(src, context, unwrap=UNWRAP_DEFAULT):
Seq = (list, tuple, set)

apply_value = rpartial(resolve, context)
apply_value = rpartial(resolve, context, unwrap=unwrap)
if isinstance(src, Mapping):
return {key: apply_value(value) for key, value in src.items()}
elif isinstance(src, Seq):
return type(src)(map(apply_value, src))
elif isinstance(src, str):
return _resolve_str(src, context)
return _resolve_str(src, context, unwrap=unwrap)
return src

0 comments on commit add12f9

Please sign in to comment.