Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stage:add: use remainder as a command script #5350

Merged
merged 3 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 4 additions & 27 deletions dvc/command/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging

from dvc.command.base import CmdBase, append_doc_link
from dvc.command.stage import parse_cmd
from dvc.exceptions import DvcException

logger = logging.getLogger(__name__)
Expand All @@ -22,7 +23,7 @@ def run(self):
self.args.outs_persist_no_cache,
self.args.checkpoints,
self.args.params,
self.args.command,
self.args.cmd,
]
): # pragma: no cover
logger.error(
Expand All @@ -34,7 +35,7 @@ def run(self):

try:
self.repo.run(
cmd=self._parsed_cmd(),
cmd=parse_cmd(self.args.cmd),
outs=self.args.outs,
outs_no_cache=self.args.outs_no_cache,
metrics=self.args.metrics,
Expand Down Expand Up @@ -67,27 +68,6 @@ def run(self):

return 0

def _parsed_cmd(self):
"""
We need to take into account two cases:

- ['python code.py foo bar']: Used mainly with dvc as a library
- ['echo', 'foo bar']: List of arguments received from the CLI

The second case would need quoting, as it was passed through:
dvc run echo "foo bar"
"""
if len(self.args.command) < 2:
return " ".join(self.args.command)

return " ".join(self._quote_argument(arg) for arg in self.args.command)

def _quote_argument(self, argument):
if " " not in argument or '"' in argument:
return argument

return f'"{argument}"'


def add_parser(subparsers, parent_parser):
from dvc.command.stage import _add_common_args
Expand All @@ -103,7 +83,6 @@ def add_parser(subparsers, parent_parser):
run_parser.add_argument(
"-n", "--name", help="Stage name.",
)
_add_common_args(run_parser)
run_parser.add_argument(
"--file", metavar="<filename>", help=argparse.SUPPRESS,
)
Expand Down Expand Up @@ -134,7 +113,5 @@ def add_parser(subparsers, parent_parser):
default=False,
help=argparse.SUPPRESS,
)
run_parser.add_argument(
"command", nargs=argparse.REMAINDER, help="Command to execute."
)
_add_common_args(run_parser)
run_parser.set_defaults(func=CmdRun)
39 changes: 30 additions & 9 deletions dvc/command/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,32 @@ def log_error(relpath: str, exc: Exception):
return 0


def parse_cmd(commands: List[str]) -> str:
"""
We need to take into account two cases:

- ['python code.py foo bar']: Used mainly with dvc as a library
- ['echo', 'foo bar']: List of arguments received from the CLI

The second case would need quoting, as it was passed through:
dvc run echo "foo bar"
"""

def quote_argument(arg: str):
should_quote = " " in arg and '"' not in arg
return f'"{arg}"' if should_quote else arg

if len(commands) < 2:
return " ".join(commands)
return " ".join(map(quote_argument, commands))


class CmdStageAdd(CmdBase):
def run(self):
repo = self.repo
kwargs = vars(self.args)
kwargs["cmd"] = parse_cmd(kwargs.pop("cmd"))

stage = repo.stage.create_from_cli(validate=True, **kwargs)

with repo.scm.track_file_changes(config=repo.config):
Expand Down Expand Up @@ -226,7 +248,7 @@ def _add_common_args(parser):
metavar="<filename>",
)
parser.add_argument(
"-C",
"-c",
"--checkpoints",
action="append",
default=[],
Expand Down Expand Up @@ -254,6 +276,12 @@ def _add_common_args(parser):
"This doesn't affect any DVC operations."
),
)
parser.add_argument(
"cmd",
nargs=argparse.REMAINDER,
skshetry marked this conversation as resolved.
Show resolved Hide resolved
help="Command to execute.",
metavar="command",
)


def add_parser(subparsers, parent_parser):
Expand Down Expand Up @@ -282,15 +310,8 @@ def add_parser(subparsers, parent_parser):
help=STAGE_ADD_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
stage_add_parser.add_argument("name", help="Name of the stage to add")
stage_add_parser.add_argument(
"-c",
"--command",
action="append",
default=[],
dest="cmd",
help="Command to execute.",
required=True,
"-n", "--name", help="Name of the stage to add", required=True
)
_add_common_args(stage_add_parser)
stage_add_parser.set_defaults(func=CmdStageAdd)
Expand Down
2 changes: 1 addition & 1 deletion tests/func/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test(self):
self.assertEqual(args.outs, [out1, out2])
self.assertEqual(args.outs_no_cache, [out_no_cache1, out_no_cache2])
self.assertEqual(args.file, fname)
self.assertEqual(args.command, [cmd, arg1, arg2])
self.assertEqual(args.cmd, [cmd, arg1, arg2])


class TestPull(TestDvc):
Expand Down
65 changes: 36 additions & 29 deletions tests/unit/command/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@


@pytest.mark.parametrize(
"extra_args, expected_extra",
"command, parsed_command",
[
(["-c", "cmd1", "-c", "cmd2"], {"cmd": ["cmd1", "cmd2"]}),
(["-c", "cmd1"], {"cmd": ["cmd1"]}),
(["echo", "foo", "bar"], "echo foo bar"),
(["echo", '"foo bar"'], 'echo "foo bar"'),
(["echo", "foo bar"], 'echo "foo bar"'),
],
)
def test_stage_add(mocker, dvc, extra_args, expected_extra):
def test_stage_add(mocker, dvc, command, parsed_command):
cli_args = parse_args(
[
"stage",
"add",
"--name",
"name",
"--deps",
"deps",
Expand Down Expand Up @@ -53,7 +55,7 @@ def test_stage_add(mocker, dvc, extra_args, expected_extra):
"--desc",
"description",
"--force",
*extra_args,
*command,
]
)
assert cli_args.func == CmdStageAdd
Expand All @@ -62,27 +64,32 @@ def test_stage_add(mocker, dvc, extra_args, expected_extra):
m = mocker.patch.object(cmd.repo.stage, "create_from_cli")

assert cmd.run() == 0
expected = dict(
name="name",
deps=["deps"],
outs=["outs"],
outs_no_cache=["outs-no-cache"],
params=["file:param1,param2", "param3"],
metrics=["metrics"],
metrics_no_cache=["metrics-no-cache"],
plots=["plots"],
plots_no_cache=["plots-no-cache"],
live="live",
live_no_summary=True,
live_no_report=True,
wdir="wdir",
outs_persist=["outs-persist"],
outs_persist_no_cache=["outs-persist-no-cache"],
checkpoints=["checkpoints"],
always_changed=True,
external=True,
desc="description",
**expected_extra
)
# expected values should be a subset of what's in the call args list
assert expected.items() <= m.call_args[1].items()
expected = {
"name": "name",
"deps": ["deps"],
"outs": ["outs"],
"outs_no_cache": ["outs-no-cache"],
"params": ["file:param1,param2", "param3"],
"metrics": ["metrics"],
"metrics_no_cache": ["metrics-no-cache"],
"plots": ["plots"],
"plots_no_cache": ["plots-no-cache"],
"live": "live",
"live_no_summary": True,
"live_no_report": True,
"wdir": "wdir",
"outs_persist": ["outs-persist"],
"outs_persist_no_cache": ["outs-persist-no-cache"],
"checkpoints": ["checkpoints"],
"always_changed": True,
"external": True,
"desc": "description",
"cmd": parsed_command,
}
args, kwargs = m.call_args
assert not args

for key, val in expected.items():
# expected values should be a subset of what's in the call args list
assert key in kwargs
assert kwargs[key] == val