Skip to content

Commit

Permalink
introduce get_nwd function (#726)
Browse files Browse the repository at this point in the history
* introduce get_nwd function

* fix creating nwd, even if it should not be created
  • Loading branch information
PythonFZ authored Oct 2, 2023
1 parent 33cd32b commit af558f4
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 40 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ZnTrack"
version = "0.7.0"
version = "0.7.1"
description = "Create, Run and Benchmark DVC Pipelines in Python"
authors = ["zincwarecode <[email protected]>"]
license = "Apache-2.0"
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,12 @@ def test_build_groups(tmp_path_2):
with pytest.raises(ValueError):
project.run(nodes=[42])

# assert that the only directories in "nodes/" are "Group1" and "Group2"
assert set(path.name for path in (tmp_path_2 / "nodes").iterdir()) == {
"Group1",
"Group2",
}


def test_groups_nwd(tmp_path_2):
with zntrack.Project(automatic_node_names=True) as project:
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/test_zntrack_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ def test_named_parent(proj_path):

project.run()

# assert that the only directory in ./nodes is 'c'
assert set(path.name for path in (proj_path / "nodes").iterdir()) == {
"c",
}

c.load()
assert c.name == "c"
assert c.result == 11
Expand Down
2 changes: 1 addition & 1 deletion tests/test_zntrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

def test_version():
"""Test 'ZnTrack' version."""
assert __version__ == "0.7.0"
assert __version__ == "0.7.1"
38 changes: 13 additions & 25 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@

from zntrack import exceptions
from zntrack.notebooks.jupyter import jupyter_class_to_file
from zntrack.utils import NodeName, NodeStatusResults, config, file_io, module_handler
from zntrack.utils import (
NodeName,
NodeStatusResults,
config,
file_io,
get_nwd,
module_handler,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -197,26 +204,7 @@ def state(self) -> NodeStatus:
@property
def nwd(self) -> pathlib.Path:
"""Get the node working directory."""
try:
nwd = self.__dict__["nwd"]
except KeyError:
if (
self.state.remote is None
and self.state.rev is None
and not self.state.loaded
):
nwd = pathlib.Path("nodes", znflow.get_attribute(self, "name"))
else:
try:
with self.state.fs.open(config.files.zntrack) as f:
zntrack_config = json.load(f)
nwd = zntrack_config[znflow.get_attribute(self, "name")]["nwd"]
nwd = json.loads(json.dumps(nwd), cls=znjson.ZnDecoder)
except (FileNotFoundError, KeyError):
nwd = pathlib.Path("nodes", znflow.get_attribute(self, "name"))
if not nwd.exists():
nwd.mkdir(parents=True)
return nwd
return get_nwd(self, mkdir=True)

def save(
self, parameter: bool = True, results: bool = True, meta_only: bool = False
Expand Down Expand Up @@ -258,7 +246,7 @@ def save(
file=config.files.zntrack,
node_name=self.name,
value_name="nwd",
value=self.nwd,
value=get_nwd(self),
)

def run(self) -> None:
Expand Down Expand Up @@ -292,7 +280,7 @@ def load(self, lazy: bool = None, results: bool = True) -> None:
with contextlib.suppress(FileNotFoundError):
# If the uuid is available, we can assume that all data for
# this Node is available.
with self.state.fs.open(self.nwd / "node-meta.json") as f:
with self.state.fs.open(get_nwd(self) / "node-meta.json") as f:
node_meta = json.load(f)
self._uuid = uuid.UUID(node_meta["uuid"])
self.state.results = NodeStatusResults.AVAILABLE
Expand Down Expand Up @@ -369,9 +357,9 @@ def get_dvc_cmd(
cmd += list(field_cmd)

if git_only_repo:
cmd += ["--metrics-no-cache", f"{(node.nwd /'node-meta.json').as_posix()}"]
cmd += ["--metrics-no-cache", f"{(get_nwd(node) /'node-meta.json').as_posix()}"]
else:
cmd += ["--outs", f"{(node.nwd /'node-meta.json').as_posix()}"]
cmd += ["--outs", f"{(get_nwd(node) /'node-meta.json').as_posix()}"]

module = module_handler(node.__class__)
cmd += [f"zntrack run {module}.{node.__class__.__name__} --name {node.name}"]
Expand Down
6 changes: 3 additions & 3 deletions zntrack/fields/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
_default,
_get_all_connections_and_instances,
)
from zntrack.utils import config, update_key_val
from zntrack.utils import config, get_nwd, update_key_val

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -161,7 +161,7 @@ def get_files(self, instance) -> list:
cmd = [
"import",
node.state.remote if node.state.remote is not None else ".",
(node.nwd / "node-meta.json").as_posix(),
(get_nwd(node) / "node-meta.json").as_posix(),
"-o",
deps_file.as_posix(),
]
Expand All @@ -180,7 +180,7 @@ def get_files(self, instance) -> list:
# # nodes with the same name...)
# # and make the uuid a dependency of the node.
# continue
files.append(node.nwd / "node-meta.json")
files.append(get_nwd(node) / "node-meta.json")
for field in zninit.get_descriptors(Field, self=node):
if field.dvc_option in ["params", "deps"]:
# We do not want to depend on parameter files or
Expand Down
4 changes: 2 additions & 2 deletions zntrack/fields/dvc/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import znjson

from zntrack.fields.field import Field, FieldGroup, PlotsMixin
from zntrack.utils import node_wd
from zntrack.utils import get_nwd, node_wd

if typing.TYPE_CHECKING:
from zntrack import Node
Expand Down Expand Up @@ -134,7 +134,7 @@ def __get__(self, instance: "Node", owner=None):
if instance is None:
return self
value = super().__get__(instance, owner)
return node_wd.ReplaceNWD()(value, nwd=instance.nwd)
return node_wd.ReplaceNWD()(value, nwd=get_nwd(instance))


class PlotsOption(PlotsMixin, DVCOption):
Expand Down
18 changes: 10 additions & 8 deletions zntrack/fields/zn/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
LazyField,
PlotsMixin,
)
from zntrack.utils import config, module_handler, update_key_val
from zntrack.utils import config, get_nwd, module_handler, update_key_val

if typing.TYPE_CHECKING:
from zntrack import Node
Expand Down Expand Up @@ -197,7 +197,7 @@ def get_files(self, instance) -> list:
list
A list containing the path of the file.
"""
return [instance.nwd / f"{self.name}.json"]
return [get_nwd(instance) / f"{self.name}.json"]

def save(self, instance: "Node"):
"""Save the field to disk.
Expand Down Expand Up @@ -249,7 +249,7 @@ class Plots(PlotsMixin, LazyField):

def get_files(self, instance) -> list:
"""Get the path of the file in the node directory."""
return [instance.nwd / f"{self.name}.csv"]
return [get_nwd(instance) / f"{self.name}.csv"]

def save(self, instance: "Node"):
"""Save the field to disk."""
Expand Down Expand Up @@ -364,7 +364,7 @@ def get_files(self, instance) -> list:
cmd = [
"import",
node.state.remote if node.state.remote is not None else ".",
(node.nwd / "node-meta.json").as_posix(),
(get_nwd(node) / "node-meta.json").as_posix(),
"-o",
deps_file.as_posix(),
]
Expand All @@ -383,7 +383,7 @@ def get_files(self, instance) -> list:
# # nodes with the same name...)
# # and make the uuid a dependency of the node.
# continue
files.append(node.nwd / "node-meta.json")
files.append(get_nwd(node) / "node-meta.json")
for field in zninit.get_descriptors(Field, self=node):
if field.dvc_option in ["params", "deps"]:
# We do not want to depend on parameter files or
Expand Down Expand Up @@ -504,10 +504,12 @@ def _get_nwd(self, instance: "Node", name: str) -> pathlib.Path:
# get the name of the parent directory as string
# e.g. we have nodes/AL_0/AL_0_ASEMD_checker_list_0
# but want nodes/AL_0/ASEMD_checker_list_0
if name.startswith(instance.nwd.parent.name):
return instance.nwd.parent / name[len(instance.nwd.parent.name) + 1 :]
if name.startswith(get_nwd(instance).parent.name):
return (
get_nwd(instance).parent / name[len(get_nwd(instance).parent.name) + 1 :]
)
else:
return instance.nwd.parent / name
return get_nwd(instance).parent / name

def get_optional_dvc_cmd(
self, instance: "Node", git_only_repo: bool
Expand Down
1 change: 1 addition & 0 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def run(
node: Node = self.graph.nodes[node_uuid]["value"]
if node_names is not None and node.name not in node_names:
continue
node.nwd # create the node working directory (property-access will create it)
if node._external_:
continue
if eager:
Expand Down
36 changes: 36 additions & 0 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Standard python init file for the utils directory."""
import dataclasses
import enum
import json
import logging
import os
import pathlib
Expand All @@ -10,6 +11,8 @@
import typing as t

import dvc.cli
import znflow
import znjson

from zntrack.utils import cli
from zntrack.utils.config import config
Expand Down Expand Up @@ -249,3 +252,36 @@ def update_suffix(self, project: "Project", node: "Node") -> None:
if project.automatic_node_names:
while str(self) in node_names:
self.suffix += 1


def get_nwd(node: "Node", mkdir: bool = False) -> pathlib.Path:
"""Get the node working directory.
This is used instead of `node.nwd` because it allows
for parameters to define if the nwd should be created.
Attributes
----------
node: Node
The node instance for which the nwd should be returned.
mkdir: bool, optional
If True, the nwd is created if it does not exist.
"""
try:
nwd = node.__dict__["nwd"]
except KeyError:
if node.state.remote is None and node.state.rev is None and not node.state.loaded:
nwd = pathlib.Path("nodes", znflow.get_attribute(node, "name"))
else:
try:
with node.state.fs.open(config.files.zntrack) as f:
zntrack_config = json.load(f)
nwd = zntrack_config[znflow.get_attribute(node, "name")]["nwd"]
nwd = json.loads(json.dumps(nwd), cls=znjson.ZnDecoder)
except (FileNotFoundError, KeyError):
nwd = pathlib.Path("nodes", znflow.get_attribute(node, "name"))

if mkdir:
nwd.mkdir(parents=True, exist_ok=True)
return nwd

0 comments on commit af558f4

Please sign in to comment.