From 4dca420bbbd37f1d2ff11e3214ec5c0d7712ce5b Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Sat, 7 Aug 2021 11:39:57 +0800 Subject: [PATCH 01/10] Remove a special queued experiments fix #6157 1. dvc exp remove not accept queued experiments name 2. add some tests for this feature --- dvc/repo/experiments/remove.py | 61 ++++++++++++++-------- tests/func/experiments/test_experiments.py | 19 ++++++- 2 files changed, 57 insertions(+), 23 deletions(-) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index a919608bb6..307abf715c 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from dvc.exceptions import InvalidArgumentError from dvc.repo import locked @@ -21,32 +22,35 @@ def remove(repo, exp_names=None, queue=False, **kwargs): removed += len(repo.experiments.stash) repo.experiments.stash.clear() if exp_names: - ref_infos = list(_get_exp_refs(repo, exp_names)) - remove_exp_refs(repo.scm, ref_infos) - removed += len(ref_infos) + for exp_name in exp_names: + _remove_exp_by_name(repo, exp_name) + removed += 1 return removed -def _get_exp_refs(repo, exp_names): +def _get_exp_stash_index(repo, exp_name: str) -> Optional[int]: + stash_ref_infos = repo.experiments.stash_revs + for _, ref_info in stash_ref_infos.items(): + if ref_info.name == exp_name: + return ref_info.index + return None + + +def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]: cur_rev = repo.scm.get_rev() - for name in exp_names: - if name.startswith(EXPS_NAMESPACE): - if not repo.scm.get_ref(name): - raise InvalidArgumentError( - f"'{name}' is not a valid experiment name" - ) - yield ExpRefInfo.from_ref(name) - else: - - exp_refs = list(exp_refs_by_name(repo.scm, name)) - if not exp_refs: - raise InvalidArgumentError( - f"'{name}' is not a valid experiment name" - ) - yield _get_ref(exp_refs, name, cur_rev) - - -def _get_ref(ref_infos, name, cur_rev): + if exp_name.startswith(EXPS_NAMESPACE): + if repo.scm.get_ref(exp_name): + return ExpRefInfo.from_ref(exp_name) + else: + exp_refs = list(exp_refs_by_name(repo.scm, exp_name)) + if exp_refs: + return _get_ref(exp_refs, exp_name, cur_rev) + return None + + +def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: + if len(ref_infos) == 0: + return None if len(ref_infos) > 1: for info in ref_infos: if info.baseline_sha == cur_rev: @@ -61,3 +65,16 @@ def _get_ref(ref_infos, name, cur_rev): msg.extend([f"\t{info}" for info in ref_infos]) raise InvalidArgumentError("\n".join(msg)) return ref_infos[0] + + +def _remove_exp_by_name(repo, exp_name: str): + ref_info = _get_exp_ref(repo, exp_name) + if ref_info is not None: + remove_exp_refs(repo.scm, [ref_info]) + else: + stash_index = _get_exp_stash_index(repo, exp_name) + if stash_index is None: + raise InvalidArgumentError( + f"'{exp_name}' is not a valid experiment name" + ) + repo.experiments.stash.drop(stash_index) diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 47451f4cb3..7cd63f499f 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -580,12 +580,29 @@ def test_remove(tmp_dir, scm, dvc, exp_stage): results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) exp = first(results) ref_info = first(exp_refs_by_rev(scm, exp)) - dvc.experiments.run(exp_stage.addressing, params=["foo=3"], queue=True) + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=3"], queue=True, name="queue1" + ) + queue1 = first(results) + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=4"], queue=True, name="queue2" + ) + queue2 = first(results) + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=5"], queue=True, name="queue3" + ) + queue3 = first(results) removed = dvc.experiments.remove([str(ref_info)]) assert removed == 1 assert scm.get_ref(str(ref_info)) is None + removed = dvc.experiments.remove(["queue1", "queue2"]) + assert removed == 2 + assert queue1 not in dvc.experiments.stash_revs + assert queue2 not in dvc.experiments.stash_revs + assert queue3 in dvc.experiments.stash_revs + removed = dvc.experiments.remove(queue=True) assert removed == 1 assert len(dvc.experiments.stash) == 0 From 25fe483db55722ed9d72702d6095d7f85ab9353c Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Tue, 10 Aug 2021 15:47:07 +0800 Subject: [PATCH 02/10] Extract tests and add revision support 1. Extract remove experiments to a new file. 2. revision can be used to remove special queued experiment --- dvc/repo/experiments/remove.py | 29 ++++++++------ tests/func/experiments/test_experiments.py | 32 --------------- tests/func/experiments/test_remove.py | 46 ++++++++++++++++++++++ 3 files changed, 63 insertions(+), 44 deletions(-) create mode 100644 tests/func/experiments/test_remove.py diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 307abf715c..9adb2a6755 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -13,25 +13,29 @@ @locked @scm_context -def remove(repo, exp_names=None, queue=False, **kwargs): - if not exp_names and not queue: +def remove(repo, refs_or_revs=None, queue=False, **kwargs): + if not refs_or_revs and not queue: return 0 removed = 0 if queue: removed += len(repo.experiments.stash) repo.experiments.stash.clear() - if exp_names: - for exp_name in exp_names: - _remove_exp_by_name(repo, exp_name) + if refs_or_revs: + for ref_or_rev in refs_or_revs: + _remove_exp_by_ref_or_rev(repo, ref_or_rev) removed += 1 return removed -def _get_exp_stash_index(repo, exp_name: str) -> Optional[int]: +def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: stash_ref_infos = repo.experiments.stash_revs - for _, ref_info in stash_ref_infos.items(): - if ref_info.name == exp_name: + print("*" * 100) + for rev, ref_info in stash_ref_infos.items(): + print(rev, ref_info) + if ref_info.name == ref_or_rev: + return ref_info.index + if rev == ref_or_rev: return ref_info.index return None @@ -67,14 +71,15 @@ def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: return ref_infos[0] -def _remove_exp_by_name(repo, exp_name: str): - ref_info = _get_exp_ref(repo, exp_name) +def _remove_exp_by_ref_or_rev(repo, ref_or_rev: str): + ref_info = _get_exp_ref(repo, ref_or_rev) if ref_info is not None: remove_exp_refs(repo.scm, [ref_info]) else: - stash_index = _get_exp_stash_index(repo, exp_name) + stash_index = _get_exp_stash_index(repo, ref_or_rev) if stash_index is None: raise InvalidArgumentError( - f"'{exp_name}' is not a valid experiment name" + f"'{ref_or_rev}' is neither a valid experiment reference" + " nor a queued experiment revision" ) repo.experiments.stash.drop(stash_index) diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 7cd63f499f..5a755f8042 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -576,38 +576,6 @@ def test_run_metrics(tmp_dir, scm, dvc, exp_stage, mocker): assert show_mock.called_once() -def test_remove(tmp_dir, scm, dvc, exp_stage): - results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) - exp = first(results) - ref_info = first(exp_refs_by_rev(scm, exp)) - results = dvc.experiments.run( - exp_stage.addressing, params=["foo=3"], queue=True, name="queue1" - ) - queue1 = first(results) - results = dvc.experiments.run( - exp_stage.addressing, params=["foo=4"], queue=True, name="queue2" - ) - queue2 = first(results) - results = dvc.experiments.run( - exp_stage.addressing, params=["foo=5"], queue=True, name="queue3" - ) - queue3 = first(results) - - removed = dvc.experiments.remove([str(ref_info)]) - assert removed == 1 - assert scm.get_ref(str(ref_info)) is None - - removed = dvc.experiments.remove(["queue1", "queue2"]) - assert removed == 2 - assert queue1 not in dvc.experiments.stash_revs - assert queue2 not in dvc.experiments.stash_revs - assert queue3 in dvc.experiments.stash_revs - - removed = dvc.experiments.remove(queue=True) - assert removed == 1 - assert len(dvc.experiments.stash) == 0 - - def test_checkout_targets_deps(tmp_dir, scm, dvc, exp_stage): from dvc.utils.fs import remove diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py new file mode 100644 index 0000000000..484135d15e --- /dev/null +++ b/tests/func/experiments/test_remove.py @@ -0,0 +1,46 @@ +from funcy import first + +from dvc.repo.experiments.utils import exp_refs_by_rev + + +def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage): + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) + exp = first(results) + ref_info = first(exp_refs_by_rev(scm, exp)) + + removed = dvc.experiments.remove([str(ref_info)]) + assert removed == 1 + assert scm.get_ref(str(ref_info)) is None + + +def test_remove_all_queued_experiments(tmp_dir, scm, dvc, exp_stage): + queue_length = 3 + for i in range(queue_length): + dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"], queue=True + ) + + assert len(dvc.experiments.stash) == queue_length + removed = dvc.experiments.remove(queue=True) + assert removed == queue_length + assert len(dvc.experiments.stash) == 0 + + +def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage): + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=1"], queue=True, name="queue1" + ) + rev1 = first(results) + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=2"], queue=True, name="queue2" + ) + rev2 = first(results) + assert rev1 in dvc.experiments.stash_revs + assert rev2 in dvc.experiments.stash_revs + + assert dvc.experiments.remove(["queue1"]) == 1 + assert rev1 not in dvc.experiments.stash_revs + assert rev2 in dvc.experiments.stash_revs + + assert dvc.experiments.remove([rev2]) == 1 + assert len(dvc.experiments.stash) == 0 From 443e17b572fccda0a6142a7d86eda45faf0653d4 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Tue, 10 Aug 2021 16:37:41 +0800 Subject: [PATCH 03/10] Accept shortened revisions --- dvc/repo/experiments/remove.py | 12 ++++++------ tests/func/experiments/test_remove.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 9adb2a6755..edda105a95 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -29,14 +29,14 @@ def remove(repo, refs_or_revs=None, queue=False, **kwargs): def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: - stash_ref_infos = repo.experiments.stash_revs - print("*" * 100) - for rev, ref_info in stash_ref_infos.items(): - print(rev, ref_info) + stash_revs = repo.experiments.stash_revs + for _, ref_info in stash_revs.items(): if ref_info.name == ref_or_rev: return ref_info.index - if rev == ref_or_rev: - return ref_info.index + rev = repo.scm.resolve_rev(ref_or_rev) + if rev in stash_revs: + return stash_revs.get(rev).index + return None diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 484135d15e..57facd231e 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -42,5 +42,5 @@ def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage): assert rev1 not in dvc.experiments.stash_revs assert rev2 in dvc.experiments.stash_revs - assert dvc.experiments.remove([rev2]) == 1 + assert dvc.experiments.remove([rev2[:5]]) == 1 assert len(dvc.experiments.stash) == 0 From ad59b4bd858751315d350d01c2720b7b40b5e0ff Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Tue, 10 Aug 2021 18:16:36 +0800 Subject: [PATCH 04/10] Split removing committed and queued exp functions --- dvc/repo/experiments/remove.py | 57 ++++++++++++++++++--------- tests/func/experiments/test_remove.py | 8 +++- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index edda105a95..bb8b1886e1 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -1,9 +1,10 @@ import logging -from typing import Optional +from typing import List, Optional from dvc.exceptions import InvalidArgumentError from dvc.repo import locked from dvc.repo.scm_context import scm_context +from dvc.scm.base import RevError from .base import EXPS_NAMESPACE, ExpRefInfo from .utils import exp_refs_by_name, remove_exp_refs @@ -22,9 +23,16 @@ def remove(repo, refs_or_revs=None, queue=False, **kwargs): removed += len(repo.experiments.stash) repo.experiments.stash.clear() if refs_or_revs: - for ref_or_rev in refs_or_revs: - _remove_exp_by_ref_or_rev(repo, ref_or_rev) - removed += 1 + remained = _remove_commited_experiments(repo, refs_or_revs) + remained = _remove_queued_experiements(repo, remained) + if remained: + logger.warning( + "'{}' is neither a valid experiment reference" + " nor a revision of queued experiment".format( + ";".join(remained) + ) + ) + removed += len(refs_or_revs) - len(remained) return removed @@ -33,10 +41,12 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: for _, ref_info in stash_revs.items(): if ref_info.name == ref_or_rev: return ref_info.index - rev = repo.scm.resolve_rev(ref_or_rev) - if rev in stash_revs: - return stash_revs.get(rev).index - + try: + rev = repo.scm.resolve_rev(ref_or_rev) + if rev in stash_revs: + return stash_revs.get(rev).index + except RevError: + pass return None @@ -71,15 +81,26 @@ def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: return ref_infos[0] -def _remove_exp_by_ref_or_rev(repo, ref_or_rev: str): - ref_info = _get_exp_ref(repo, ref_or_rev) - if ref_info is not None: - remove_exp_refs(repo.scm, [ref_info]) - else: +def _remove_commited_experiments(repo, refs: List[str]) -> List[str]: + remain_list = [] + remove_list = [] + for ref in refs: + ref_info = _get_exp_ref(repo, ref) + if ref_info: + remove_list.append(ref_info) + else: + remain_list.append(ref) + if remove_list: + remove_exp_refs(repo.scm, remove_list) + return remain_list + + +def _remove_queued_experiements(repo, refs_or_revs: List[str]) -> List[str]: + remain_list = [] + for ref_or_rev in refs_or_revs: stash_index = _get_exp_stash_index(repo, ref_or_rev) if stash_index is None: - raise InvalidArgumentError( - f"'{ref_or_rev}' is neither a valid experiment reference" - " nor a queued experiment revision" - ) - repo.experiments.stash.drop(stash_index) + remain_list.append(ref_or_rev) + else: + repo.experiments.stash.drop(stash_index) + return remain_list diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 57facd231e..c3f6f80245 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -3,11 +3,17 @@ from dvc.repo.experiments.utils import exp_refs_by_rev -def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage): +def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage, caplog): results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) exp = first(results) ref_info = first(exp_refs_by_rev(scm, exp)) + assert dvc.experiments.remove(["non-exist"]) == 0 + assert ( + "'non-exist' is neither a valid experiment " + "reference nor a revision of queued experiment" + ) in caplog.text + removed = dvc.experiments.remove([str(ref_info)]) assert removed == 1 assert scm.get_ref(str(ref_info)) is None From d56aae8b5efacfae8baef7cbf6263b1e9b730211 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Tue, 10 Aug 2021 20:30:13 +0800 Subject: [PATCH 05/10] Api name change --- dvc/command/experiments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index fa79072cf5..73910ba959 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -747,7 +747,7 @@ class CmdExperimentsRemove(CmdBase): def run(self): self.repo.experiments.remove( - exp_names=self.args.experiment, queue=self.args.queue + refs_or_revs=self.args.experiment, queue=self.args.queue ) return 0 From e2b3cba282ca82eb428b424222b8a427e0a3f77e Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Tue, 10 Aug 2021 20:31:35 +0800 Subject: [PATCH 06/10] Return to the old API name --- dvc/command/experiments.py | 2 +- dvc/repo/experiments/remove.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 73910ba959..fa79072cf5 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -747,7 +747,7 @@ class CmdExperimentsRemove(CmdBase): def run(self): self.repo.experiments.remove( - refs_or_revs=self.args.experiment, queue=self.args.queue + exp_names=self.args.experiment, queue=self.args.queue ) return 0 diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index bb8b1886e1..f83952356a 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -14,16 +14,16 @@ @locked @scm_context -def remove(repo, refs_or_revs=None, queue=False, **kwargs): - if not refs_or_revs and not queue: +def remove(repo, exp_names=None, queue=False, **kwargs): + if not exp_names and not queue: return 0 removed = 0 if queue: removed += len(repo.experiments.stash) repo.experiments.stash.clear() - if refs_or_revs: - remained = _remove_commited_experiments(repo, refs_or_revs) + if exp_names: + remained = _remove_commited_experiments(repo, exp_names) remained = _remove_queued_experiements(repo, remained) if remained: logger.warning( @@ -32,7 +32,7 @@ def remove(repo, refs_or_revs=None, queue=False, **kwargs): ";".join(remained) ) ) - removed += len(refs_or_revs) - len(remained) + removed += len(exp_names) - len(remained) return removed From afb596d074b68835db74441b262fa9163df17f19 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Tue, 10 Aug 2021 21:38:42 +0800 Subject: [PATCH 07/10] shorten some functions --- dvc/repo/experiments/remove.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index f83952356a..32af07d7d1 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -23,8 +23,8 @@ def remove(repo, exp_names=None, queue=False, **kwargs): removed += len(repo.experiments.stash) repo.experiments.stash.clear() if exp_names: - remained = _remove_commited_experiments(repo, exp_names) - remained = _remove_queued_experiements(repo, remained) + remained = _remove_commited_exps(repo, exp_names) + remained = _remove_queued_exps(repo, remained) if remained: logger.warning( "'{}' is neither a valid experiment reference" @@ -81,7 +81,7 @@ def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: return ref_infos[0] -def _remove_commited_experiments(repo, refs: List[str]) -> List[str]: +def _remove_commited_exps(repo, refs: List[str]) -> List[str]: remain_list = [] remove_list = [] for ref in refs: @@ -95,7 +95,7 @@ def _remove_commited_experiments(repo, refs: List[str]) -> List[str]: return remain_list -def _remove_queued_experiements(repo, refs_or_revs: List[str]) -> List[str]: +def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]: remain_list = [] for ref_or_rev in refs_or_revs: stash_index = _get_exp_stash_index(repo, ref_or_rev) From 870a04cb050dbf3bf5b2827e7ade48091b1108fa Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Wed, 11 Aug 2021 16:39:55 +0800 Subject: [PATCH 08/10] Error message change --- dvc/repo/experiments/remove.py | 7 +------ tests/func/experiments/test_remove.py | 5 +---- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 32af07d7d1..e781dab1f7 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -27,10 +27,7 @@ def remove(repo, exp_names=None, queue=False, **kwargs): remained = _remove_queued_exps(repo, remained) if remained: logger.warning( - "'{}' is neither a valid experiment reference" - " nor a revision of queued experiment".format( - ";".join(remained) - ) + "'{}' is not a valid experiment".format(";".join(remained)) ) removed += len(exp_names) - len(remained) return removed @@ -63,8 +60,6 @@ def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]: def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: - if len(ref_infos) == 0: - return None if len(ref_infos) > 1: for info in ref_infos: if info.baseline_sha == cur_rev: diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index c3f6f80245..aeabc9f169 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -9,10 +9,7 @@ def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage, caplog): ref_info = first(exp_refs_by_rev(scm, exp)) assert dvc.experiments.remove(["non-exist"]) == 0 - assert ( - "'non-exist' is neither a valid experiment " - "reference nor a revision of queued experiment" - ) in caplog.text + assert ("'non-exist' is not a valid experiment") in caplog.text removed = dvc.experiments.remove([str(ref_info)]) assert removed == 1 From da7849f88d93bc48df99cb831b3cb1a2049b6ff4 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 12 Aug 2021 16:08:15 +0800 Subject: [PATCH 09/10] Better test cases, more corner case --- tests/func/experiments/test_remove.py | 53 +++++++++++++++++++-------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index aeabc9f169..1c3755446e 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -4,29 +4,40 @@ def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage, caplog): - results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) - exp = first(results) - ref_info = first(exp_refs_by_rev(scm, exp)) + queue_length = 3 + ref_list = [] - assert dvc.experiments.remove(["non-exist"]) == 0 - assert ("'non-exist' is not a valid experiment") in caplog.text + for i in range(queue_length): + results = dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"] + ) + ref_info = first(exp_refs_by_rev(scm, first(results))) + ref_list.append(str(ref_info)) - removed = dvc.experiments.remove([str(ref_info)]) - assert removed == 1 - assert scm.get_ref(str(ref_info)) is None + assert dvc.experiments.remove(ref_list[:2] + ["non-exist"]) == 2 + assert ("'non-exist' is not a valid experiment") in caplog.text + assert scm.get_ref(str(ref_list[0])) is None + assert scm.get_ref(str(ref_list[1])) is None + assert scm.get_ref(str(ref_list[2])) is not None def test_remove_all_queued_experiments(tmp_dir, scm, dvc, exp_stage): queue_length = 3 + for i in range(queue_length): dvc.experiments.run( exp_stage.addressing, params=[f"foo={i}"], queue=True ) + results = dvc.experiments.run( + exp_stage.addressing, params=[f"foo={queue_length}"] + ) + ref_info = first(exp_refs_by_rev(scm, first(results))) + assert len(dvc.experiments.stash) == queue_length - removed = dvc.experiments.remove(queue=True) - assert removed == queue_length + assert dvc.experiments.remove(queue=True) == queue_length assert len(dvc.experiments.stash) == 0 + assert scm.get_ref(str(ref_info)) is not None def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage): @@ -38,12 +49,24 @@ def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage): exp_stage.addressing, params=["foo=2"], queue=True, name="queue2" ) rev2 = first(results) + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=3"], queue=True, name="queue3" + ) + rev3 = first(results) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=4"]) + ref_info1 = first(exp_refs_by_rev(scm, first(results))) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=5"]) + ref_info2 = first(exp_refs_by_rev(scm, first(results))) + assert rev1 in dvc.experiments.stash_revs assert rev2 in dvc.experiments.stash_revs + assert rev3 in dvc.experiments.stash_revs + assert scm.get_ref(str(ref_info1)) is not None + assert scm.get_ref(str(ref_info2)) is not None - assert dvc.experiments.remove(["queue1"]) == 1 + assert dvc.experiments.remove(["queue1", rev2[:5], str(ref_info1)]) == 3 assert rev1 not in dvc.experiments.stash_revs - assert rev2 in dvc.experiments.stash_revs - - assert dvc.experiments.remove([rev2[:5]]) == 1 - assert len(dvc.experiments.stash) == 0 + assert rev2 not in dvc.experiments.stash_revs + assert rev3 in dvc.experiments.stash_revs + assert scm.get_ref(str(ref_info1)) is None + assert scm.get_ref(str(ref_info2)) is not None From 8af4038e8969925995ed417a0fac73f62c94eb96 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 12 Aug 2021 16:17:27 +0800 Subject: [PATCH 10/10] Still raise exception in a mixed case --- dvc/repo/experiments/remove.py | 2 +- tests/func/experiments/test_remove.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index e781dab1f7..8755ac3004 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -26,7 +26,7 @@ def remove(repo, exp_names=None, queue=False, **kwargs): remained = _remove_commited_exps(repo, exp_names) remained = _remove_queued_exps(repo, remained) if remained: - logger.warning( + raise InvalidArgumentError( "'{}' is not a valid experiment".format(";".join(remained)) ) removed += len(exp_names) - len(remained) diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 1c3755446e..f12f142532 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -1,5 +1,7 @@ +import pytest from funcy import first +from dvc.exceptions import InvalidArgumentError from dvc.repo.experiments.utils import exp_refs_by_rev @@ -14,8 +16,8 @@ def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage, caplog): ref_info = first(exp_refs_by_rev(scm, first(results))) ref_list.append(str(ref_info)) - assert dvc.experiments.remove(ref_list[:2] + ["non-exist"]) == 2 - assert ("'non-exist' is not a valid experiment") in caplog.text + with pytest.raises(InvalidArgumentError): + assert dvc.experiments.remove(ref_list[:2] + ["non-exist"]) assert scm.get_ref(str(ref_list[0])) is None assert scm.get_ref(str(ref_list[1])) is None assert scm.get_ref(str(ref_list[2])) is not None