From 8c362a6527d01e5658ae1387b907b27fae32f345 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Fri, 19 May 2023 09:46:35 +0545 Subject: [PATCH] stage add: introduce --run option Part of #5846. --- dvc/commands/stage.py | 16 +++++++++++++++- dvc/repo/stage.py | 19 +++++++++---------- tests/unit/command/test_stage.py | 12 ++++++++++++ 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/dvc/commands/stage.py b/dvc/commands/stage.py index f3c11aec3c..4f427fb25c 100644 --- a/dvc/commands/stage.py +++ b/dvc/commands/stage.py @@ -121,6 +121,8 @@ def quote_argument(arg: str): class CmdStageAdd(CmdBase): def run(self): + from dvc.repo import lock_repo + kwargs = vars(self.args) kwargs.update( { @@ -128,7 +130,13 @@ def run(self): "params": parse_params(self.args.params), } ) - self.repo.stage.add(**kwargs) + + with self.repo.scm_context, lock_repo(self.repo): + stage = self.repo.stage.add(**kwargs) + if self.args.run: + stage.run() + stage.dump(update_pipeline=False) + return 0 @@ -257,6 +265,12 @@ def _add_common_args(parser): "This doesn't affect any DVC operations." ), ) + parser.add_argument( + "--run", + action="store_true", + default=False, + help="Execute the stage after generating it.", + ) parser.add_argument( "command", nargs=argparse.REMAINDER, diff --git a/dvc/repo/stage.py b/dvc/repo/stage.py index 29765ddf8e..321101934d 100644 --- a/dvc/repo/stage.py +++ b/dvc/repo/stage.py @@ -122,16 +122,15 @@ def add( force=force, **stage_data, ) - with self.repo.scm_context: - stage.dump(update_lock=update_lock) - try: - stage.ignore_outs() - except FileNotFoundError as exc: - ui.warn( - f"Could not create .gitignore entry in {exc.filename}." - " DVC will attempt to create .gitignore entry again when" - " the stage is run." - ) + stage.dump(update_lock=update_lock) + try: + stage.ignore_outs() + except FileNotFoundError as exc: + ui.warn( + f"Could not create .gitignore entry in {exc.filename}." + " DVC will attempt to create .gitignore entry again when" + " the stage is run." + ) return stage diff --git a/tests/unit/command/test_stage.py b/tests/unit/command/test_stage.py index 32eb9aef3d..2b090a31bd 100644 --- a/tests/unit/command/test_stage.py +++ b/tests/unit/command/test_stage.py @@ -86,3 +86,15 @@ def test_stage_add(mocker, dvc, command, parsed_command): cmd=parsed_command, force=True, ) + + +def test_stage_add_and_run(mocker, dvc): + cli_args = parse_args(["stage", "add", "--run", "-n", "foo", "-o", "foo", "cmd"]) + cmd = cli_args.func(cli_args) + add_mock = mocker.patch.object(cmd.repo.stage, "add") + + assert cmd.run() == 0 + + assert called_once_with_subset(add_mock, name="foo", outs=["foo"], cmd="cmd") + add_mock.return_value.run.assert_called_once() + add_mock.return_value.dump.assert_called_once_with(update_pipeline=False)