Skip to content

Commit

Permalink
exp run: add experiment name check. (#6848)
Browse files Browse the repository at this point in the history
* Add experiment name check.

1. Add experiment name check (https://git-scm.com/docs/git-check-ref-format)
2. Add duplicate exp name check.
3. Add some unit test for it.

* Ban slash / in dvc exp names

* Use dulwich backend for ref name checking

* Some bug fix

* Update dvc/repo/experiments/__init__.py

Co-authored-by: Peter Rowlands (변기호) <[email protected]>

* Update dvc/repo/experiments/__init__.py

Co-authored-by: Peter Rowlands (변기호) <[email protected]>

* Some review changes

* Make some funtion more reusable.

Co-authored-by: Peter Rowlands (변기호) <[email protected]>
  • Loading branch information
karajan1001 and pmrowla authored Oct 25, 2021
1 parent 52a9cdd commit c075921
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 1 deletion.
21 changes: 21 additions & 0 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,17 @@ def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False):
"\tdvc exp branch <exp> <branch>\n"
)

def _validate_new_ref(self, exp_ref: ExpRefInfo):
from .utils import check_ref_format

if not exp_ref.name:
return

check_ref_format(self.scm, exp_ref)

if self.scm.get_ref(str(exp_ref)):
raise ExperimentExistsError(exp_ref.name)

@scm_locked
def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs):
"""Create a new experiment.
Expand All @@ -485,6 +496,16 @@ def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs):
*args, resume_rev=checkpoint_resume, **kwargs
)

name = kwargs.get("name", None)
baseline_sha = kwargs.get("baseline_rev") or self.repo.scm.get_rev()
exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name)

try:
self._validate_new_ref(exp_ref)
except ExperimentExistsError as err:
if not (kwargs.get("force", False) or kwargs.get("reset", False)):
raise err

return self._stash_exp(*args, **kwargs)

def _resume_checkpoint(
Expand Down
9 changes: 9 additions & 0 deletions dvc/repo/experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,12 @@ def resolve_exp_ref(
msg.extend([f"\t{info}" for info in exp_ref_list])
raise InvalidArgumentError("\n".join(msg))
return exp_ref_list[0]


def check_ref_format(scm: "Git", ref: ExpRefInfo):
# "/" forbidden, only in dvc exp as we didn't support it for now.
if not scm.check_ref_format(str(ref)) or "/" in ref.name:
raise InvalidArgumentError(
f"Invalid exp name {ref.name}, the exp name must follow rules in "
"https://git-scm.com/docs/git-check-ref-format"
)
1 change: 1 addition & 0 deletions dvc/scm/git/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def get_fs(self, rev: str):
status = partialmethod(_backend_func, "status")
merge = partialmethod(_backend_func, "merge")
validate_git_remote = partialmethod(_backend_func, "validate_git_remote")
check_ref_format = partialmethod(_backend_func, "check_ref_format")

def resolve_rev(self, rev: str) -> str:
from dvc.repo.experiments.utils import exp_refs_by_name
Expand Down
5 changes: 5 additions & 0 deletions dvc/scm/git/backend/dulwich/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,8 @@ def validate_git_remote(self, url: str, **kwargs):
os.path.join("", path)
):
raise InvalidRemoteSCMRepo(url)

def check_ref_format(self, refname: str):
from dulwich.refs import check_ref_format

return check_ref_format(refname.encode())
15 changes: 15 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace):
tmp_dir=not workspace,
)

new_mock = mocker.spy(dvc.experiments, "_stash_exp")
with pytest.raises(ExperimentExistsError):
dvc.experiments.run(
exp_stage.addressing,
name="foo",
params=["foo=3"],
tmp_dir=not workspace,
)
new_mock.assert_not_called()

results = dvc.experiments.run(
exp_stage.addressing,
Expand Down Expand Up @@ -685,3 +687,16 @@ def test_exp_run_recursive(tmp_dir, scm, dvc, run_copy_metrics):
)
assert dvc.experiments.run(".", recursive=True)
assert (tmp_dir / "metric.json").parse() == {"foo": 1}


def test_experiment_name_invalid(tmp_dir, scm, dvc, exp_stage, mocker):
from dvc.exceptions import InvalidArgumentError

new_mock = mocker.spy(dvc.experiments, "_stash_exp")
with pytest.raises(InvalidArgumentError):
dvc.experiments.run(
exp_stage.addressing,
name="fo^o",
params=["foo=3"],
)
new_mock.assert_not_called()
16 changes: 15 additions & 1 deletion tests/unit/repo/experiments/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

from dvc.exceptions import InvalidArgumentError
from dvc.repo.experiments.base import EXPS_NAMESPACE, ExpRefInfo
from dvc.repo.experiments.utils import resolve_exp_ref
from dvc.repo.experiments.utils import check_ref_format, resolve_exp_ref


def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"):
Expand All @@ -25,3 +26,16 @@ def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url):
remote_ref_info = resolve_exp_ref(scm, "foo" if name_only else ref, remote)
assert isinstance(remote_ref_info, ExpRefInfo)
assert str(remote_ref_info) == ref


@pytest.mark.parametrize(
"name,result", [("name", True), ("group/name", False), ("na me", False)]
)
def test_run_check_ref_format(scm, name, result):

ref = ExpRefInfo("abc123", name)
if result:
check_ref_format(scm, ref)
else:
with pytest.raises(InvalidArgumentError):
check_ref_format(scm, ref)

0 comments on commit c075921

Please sign in to comment.