Skip to content

Commit

Permalink
exp list: show experiment shas (#9501)
Browse files Browse the repository at this point in the history
* exp list: include sha; still needs remote support

* exp list: ignore rev for remotes

* exp list: update output format

* refactor
  • Loading branch information
Dave Berenbaum authored May 25, 2023
1 parent 79c8800 commit 8acd943
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 25 deletions.
34 changes: 26 additions & 8 deletions dvc/commands/experiments/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import InvalidArgumentError
from dvc.ui import ui

logger = logging.getLogger(__name__)
Expand All @@ -11,19 +12,29 @@
class CmdExperimentsList(CmdBase):
def run(self):
name_only = self.args.name_only
sha_only = self.args.sha_only
git_remote = self.args.git_remote
if sha_only and git_remote:
raise InvalidArgumentError("--sha-only not supported with git_remote.")
exps = self.repo.experiments.ls(
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
git_remote=self.args.git_remote,
git_remote=git_remote,
)

for baseline in exps:
if not name_only:
ui.write(f"{baseline}:")
for exp_name in exps[baseline]:
indent = "" if name_only else "\t"
ui.write(f"{indent}{exp_name}")
if not (name_only or sha_only):
ui.write(f"{baseline[:7]}:")
for exp_name, rev in exps[baseline]:
if name_only:
ui.write(exp_name)
elif sha_only:
ui.write(rev)
elif rev:
ui.write(f"\t{rev[:7]} [{exp_name}]")
else:
ui.write(f"\t{exp_name}")

return 0

Expand All @@ -40,11 +51,18 @@ def add_parser(experiments_subparsers, parent_parser):
formatter_class=argparse.RawDescriptionHelpFormatter,
)
add_rev_selection_flags(experiments_list_parser, "List")
experiments_list_parser.add_argument(
display_group = experiments_list_parser.add_mutually_exclusive_group()
display_group.add_argument(
"--name-only",
"--names-only",
action="store_true",
help="Only output experiment names (without parent commits).",
help="Only output experiment names (without SHAs or parent commits).",
)
display_group.add_argument(
"--sha-only",
"--shas-only",
action="store_true",
help="Only output experiment commit SHAs (without names or parent commits).",
)
experiments_list_parser.add_argument(
"git_remote",
Expand Down
17 changes: 13 additions & 4 deletions dvc/repo/experiments/ls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from collections import defaultdict
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

from dvc.repo import locked
from dvc.repo.scm_context import scm_context
Expand All @@ -19,7 +19,11 @@ def ls(
all_commits: bool = False,
num: int = 1,
git_remote: Optional[str] = None,
):
) -> Dict[str, List[Tuple[str, Optional[str]]]]:
"""List experiments.
Returns a dict mapping baseline revs to a list of (exp_name, exp_sha) tuples.
"""
rev_set = None
if not all_commits:
rev = rev or "HEAD"
Expand All @@ -36,9 +40,14 @@ def ls(

results = defaultdict(list)
for baseline in ref_info_dict:
name = baseline[:7]
name = baseline
if tags[baseline] or ref_heads[baseline]:
name = tags[baseline] or ref_heads[baseline][len(base) + 1 :]
results[name] = [info.name for info in ref_info_dict[baseline]]
for info in ref_info_dict[baseline]:
if git_remote:
exp_rev = None
else:
exp_rev = repo.scm.get_ref(str(info))
results[name].append((info.name, exp_rev))

return results
16 changes: 8 additions & 8 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,30 +362,30 @@ def test_list(tmp_dir, scm, dvc, exp_stage):
exp_c = first(results)
ref_info_c = first(exp_refs_by_rev(scm, exp_c))

assert dvc.experiments.ls() == {"master": [ref_info_c.name]}
assert dvc.experiments.ls() == {"master": [(ref_info_c.name, exp_c)]}

exp_list = dvc.experiments.ls(rev=ref_info_a.baseline_sha)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name}
baseline_a: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)}
}

exp_list = dvc.experiments.ls(rev=[baseline_a, scm.get_rev()])
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
"master": {ref_info_c.name},
baseline_a: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)},
"master": {(ref_info_c.name, exp_c)},
}

exp_list = dvc.experiments.ls(all_commits=True)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
"master": {ref_info_c.name},
baseline_a: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)},
"master": {(ref_info_c.name, exp_c)},
}

scm.checkout("branch", True)
exp_list = dvc.experiments.ls(all_commits=True)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
"branch": {ref_info_c.name},
baseline_a: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)},
"branch": {(ref_info_c.name, exp_c)},
}


Expand Down
6 changes: 3 additions & 3 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def test_list_remote(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url):
git_downstream.tmp_dir.scm.fetch_refspecs(remote, ["master:master"])
exp_list = downstream_exp.ls(rev=baseline_a, git_remote=remote)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name}
baseline_a: {(ref_info_a.name, None), (ref_info_b.name, None)}
}

exp_list = downstream_exp.ls(all_commits=True, git_remote=remote)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
"master": {ref_info_c.name},
baseline_a: {(ref_info_a.name, None), (ref_info_b.name, None)},
"master": {(ref_info_c.name, None)},
}


Expand Down
4 changes: 2 additions & 2 deletions tests/func/experiments/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_exp_save_after_commit(tmp_dir, dvc, scm):
dvc.experiments.save(name="exp-2", force=True)

all_exps = dvc.experiments.ls(all_commits=True)
assert all_exps[baseline[:7]] == ["exp-1"]
assert all_exps["master"] == ["exp-2"]
assert all_exps[baseline][0][0] == "exp-1"
assert all_exps["master"][0][0] == "exp-2"


def test_exp_save_with_staged_changes(tmp_dir, dvc, scm):
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,62 @@ def test_experiments_list(dvc, scm, mocker):
)


@pytest.mark.parametrize(
"args,expected",
[
([], "main:\n\tsha-a [exp-a]\n"),
(["--name-only"], "exp-a\n"),
(["--sha-only"], "sha-a\n"),
],
)
def test_experiments_list_format(mocker, capsys, args, expected):
mocker.patch(
"dvc.repo.experiments.ls.ls",
return_value={
"main": [
("exp-a", "sha-a"),
]
},
)
raw_args = ["experiments", "list", *args]
cli_args = parse_args(raw_args)

cmd = cli_args.func(cli_args)

capsys.readouterr()
assert cmd.run() == 0
cap = capsys.readouterr()
assert cap.out == expected


def test_experiments_list_remote(mocker, capsys):
mocker.patch(
"dvc.repo.experiments.ls.ls",
return_value={
"main": [
("exp-a", None),
]
},
)
cli_args = parse_args(["experiments", "list", "git_remote"])

cmd = cli_args.func(cli_args)

capsys.readouterr()
assert cmd.run() == 0
cap = capsys.readouterr()
assert cap.out == "main:\n\texp-a\n"

cli_args = parse_args(["experiments", "list", "git_remote", "--sha-only"])

cmd = cli_args.func(cli_args)

capsys.readouterr()

with pytest.raises(InvalidArgumentError):
cmd.run()


def test_experiments_push(dvc, scm, mocker):
cli_args = parse_args(
[
Expand Down

0 comments on commit 8acd943

Please sign in to comment.