Skip to content

Commit

Permalink
exp push: Handle rev arg as list.
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed May 18, 2023
1 parent e6e3912 commit ecb71c8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
8 changes: 5 additions & 3 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def exp_refs_from_names(scm: "Git", exp_names: List[str]) -> Set["ExpRefInfo"]:
return exp_ref_set


def exp_refs_from_rev(scm: "Git", rev: str, num: int = 1) -> Set["ExpRefInfo"]:
def exp_refs_from_rev(scm: "Git", rev: List[str], num: int = 1) -> Set["ExpRefInfo"]:
exp_ref_set = set()
rev_dict = iter_revs(scm, [rev], num)
rev_dict = iter_revs(scm, rev, num)
rev_set = set(rev_dict.keys())
ref_info_dict = exp_refs_by_baseline(scm, rev_set)
for _, ref_info_list in ref_info_dict.items():
Expand All @@ -99,7 +99,7 @@ def push(
git_remote: str,
exp_names: Union[List[str], str],
all_commits: bool = False,
rev: Optional[str] = None,
rev: Optional[Union[List[str], str]] = None,
num: int = 1,
force: bool = False,
push_cache: bool = False,
Expand All @@ -112,6 +112,8 @@ def push(
if exp_names:
exp_ref_set.update(exp_refs_from_names(repo.scm, ensure_list(exp_names)))
if rev:
if isinstance(rev, str):
rev = [rev]
exp_ref_set.update(exp_refs_from_rev(repo.scm, rev, num=num))

push_result = _push(repo, git_remote, exp_ref_set, force)
Expand Down
23 changes: 23 additions & 0 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,29 @@ def test_push_args(tmp_dir, scm, dvc, git_upstream, exp_stage, all_, rev, result
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info3)) == result3


def test_push_multi_rev(tmp_dir, scm, dvc, git_upstream, exp_stage):
remote = git_upstream.url
baseline = scm.get_rev()

results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"])
exp1 = first(results)
ref_info1 = first(exp_refs_by_rev(scm, exp1))
results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
exp2 = first(results)
ref_info2 = first(exp_refs_by_rev(scm, exp2))

scm.commit("new_baseline")

results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"])
exp3 = first(results)
ref_info3 = first(exp_refs_by_rev(scm, exp3))

dvc.experiments.push(remote, [], rev=[baseline, scm.get_rev()])
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info2)) == exp2
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info3)) == exp3


def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage):
git_upstream.tmp_dir.scm_gen("foo", "foo", commit="init")
remote_rev = git_upstream.tmp_dir.scm.get_rev()
Expand Down

0 comments on commit ecb71c8

Please sign in to comment.