Skip to content

Commit

Permalink
exp: add exp run --machine flag
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrowla committed Jan 13, 2022
1 parent f4368f3 commit e94900c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
10 changes: 10 additions & 0 deletions dvc/command/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def run(self):
checkpoint_resume=self.args.checkpoint_resume,
reset=self.args.reset,
tmp_dir=self.args.tmp_dir,
machine=self.args.machine,
**self._repro_kwargs,
)

Expand Down Expand Up @@ -130,3 +131,12 @@ def _add_run_common(parser):
"your workspace."
),
)
parser.add_argument(
"--machine",
default=None,
help=argparse.SUPPRESS,
# help=(
# "Run this experiment on the specified 'dvc machine' instance."
# )
# metavar="<name>",
)
39 changes: 38 additions & 1 deletion tests/func/experiments/executor/test_ssh.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import posixpath
from contextlib import contextmanager
from functools import partial
from urllib.parse import urlparse

import pytest
from dvc_ssh.tests.cloud import TEST_SSH_KEY_PATH, TEST_SSH_USER

from dvc.fs.ssh import SSHFileSystem
from dvc.repo.experiments.base import EXEC_HEAD, EXEC_MERGE
from dvc.repo.experiments.executor.base import ExecutorInfo
from dvc.repo.experiments.executor.base import ExecutorInfo, ExecutorResult
from dvc.repo.experiments.executor.ssh import SSHExecutor
from tests.func.machine.conftest import * # noqa, pylint: disable=wildcard-import

Expand Down Expand Up @@ -122,3 +123,39 @@ def test_reproduce(tmp_dir, scm, dvc, cloud, exp_stage, mocker):
assert mock_execute.called_once()
_name, args, _kwargs = mock_execute.mock_calls[0]
assert f"dvc exp exec-run --infofile {infofile}" in args[0]


@pytest.mark.needs_internet
@pytest.mark.parametrize("cloud", [pytest.lazy_fixture("git_ssh")])
def test_run_machine(tmp_dir, scm, dvc, cloud, exp_stage, mocker):
baseline = scm.get_rev()
factory = partial(_ssh_factory, cloud)
mocker.patch.object(
dvc.machine,
"get_executor_kwargs",
return_value={
"host": cloud.host,
"port": cloud.port,
"username": TEST_SSH_USER,
"fs_factory": factory,
},
)
mocker.patch.object(dvc.machine, "get_setup_script", return_value=None)
mock_repro = mocker.patch.object(
SSHExecutor,
"reproduce",
return_value=ExecutorResult("abc123", None, False),
)

tmp_dir.gen("params.yaml", "foo: 2")
dvc.experiments.run(exp_stage.addressing, machine="foo")
assert mock_repro.called_once()
_name, _args, kwargs = mock_repro.mock_calls[0]
info = kwargs["info"]
url = urlparse(info.git_url)
assert url.scheme == "ssh"
assert url.hostname == cloud.host
assert url.port == cloud.port
assert info.baseline_rev == baseline
assert kwargs["infofile"] is not None
assert kwargs["fs_factory"] is not None
1 change: 1 addition & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def test_experiments_run(dvc, scm, mocker):
"tmp_dir": False,
"checkpoint_resume": None,
"reset": False,
"machine": None,
}
default_arguments.update(repro_arguments)

Expand Down

0 comments on commit e94900c

Please sign in to comment.