Skip to content

Commit

Permalink
live: Explicitly require exp for sending Studio updates.
Browse files Browse the repository at this point in the history
Requires using either `dvc exp run` or `save_dvc_exp=True`.

Closes #474
  • Loading branch information
daavoo committed Mar 27, 2023
1 parent 32d5778 commit 62880aa
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
13 changes: 11 additions & 2 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self._init_cleanup()

self._baseline_rev: Optional[str] = None
self._exp_name: str = "dvclive-exp"
self._exp_name: Optional[str] = None
self._experiment_rev: Optional[str] = None
self._inside_dvc_exp: bool = False
self._dvc_repo = None
Expand Down Expand Up @@ -146,12 +146,21 @@ def _init_studio(self):
self._studio_events_to_skip.add("done")
elif self._dvc_repo is None:
logger.warning(
"Can't send updates to Studio without a DVC Repo."
"Can't connect to Studio without a DVC Repo."
"\nYou can create a DVC Repo by calling `dvc init`."
)
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("data")
self._studio_events_to_skip.add("done")
elif not self._save_dvc_exp:
logger.warning(
"Can't connect to Studio without creating a DVC experiment."
"\nIf you have a DVC Pipeline, run it with `dvc exp run`."
"\nIf you are using DVCLive alone, use `save_dvc_exp=True`."
)
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("data")
self._studio_events_to_skip.add("done")
else:
response = post_live_metrics(
"start", self._baseline_rev, self._exp_name, "dvclive"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,15 @@ def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo):
live.end()
if save:
assert live._baseline_rev is not None
assert live._exp_name != "dvclive-exp"
assert live._exp_name is not None
mocked_dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir],
force=True,
)
else:
assert live._baseline_rev is not None
assert live._exp_name == "dvclive-exp"
assert live._exp_name is None
mocked_dvc_repo.experiments.save.assert_not_called()


Expand Down Expand Up @@ -131,7 +131,7 @@ def test_exp_save_run_on_dvc_repro(tmp_dir, mocker):
live = Live(save_dvc_exp=True)
assert live._save_dvc_exp
assert live._baseline_rev is not None
assert live._exp_name != "dvclive-exp"
assert live._exp_name is not None
live.end()

dvc_repo.experiments.save.assert_called_with(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_frameworks/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio
mocked_post, _ = mocked_studio_post

model = ValLitXOR()
dvclive_logger = DVCLiveLogger()
dvclive_logger = DVCLiveLogger(save_dvc_exp=True)
trainer = Trainer(
logger=dvclive_logger,
max_steps=4,
Expand Down
27 changes: 16 additions & 11 deletions tests/test_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
live = Live()
live = Live(save_dvc_exp=True)
live.log_param("fooparam", 1)

dvc_path = Path(live.dvc_file).as_posix()
Expand All @@ -25,7 +25,7 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
"type": "start",
"repo_url": "STUDIO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "dvclive-exp",
"name": live._exp_name,
"client": "dvclive",
},
headers={
Expand All @@ -44,7 +44,7 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
"type": "data",
"repo_url": "STUDIO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "dvclive-exp",
"name": live._exp_name,
"step": 0,
"metrics": {metrics_path: {"data": {"step": 0, "foo": 1}}},
"params": {params_path: {"fooparam": 1}},
Expand All @@ -67,7 +67,7 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
"type": "data",
"repo_url": "STUDIO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "dvclive-exp",
"name": live._exp_name,
"step": 1,
"metrics": {metrics_path: {"data": {"step": 1, "foo": 2}}},
"params": {params_path: {"fooparam": 1}},
Expand All @@ -88,7 +88,7 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
"type": "done",
"repo_url": "STUDIO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "dvclive-exp",
"name": live._exp_name,
"client": "dvclive",
},
headers={
Expand All @@ -104,7 +104,7 @@ def test_post_to_studio_failed_data_request(
):
mocked_post, valid_response = mocked_studio_post

live = Live()
live = Live(save_dvc_exp=True)

dvc_path = Path(live.dvc_file).as_posix()
metrics_path = Path(live.metrics_file).as_posix()
Expand All @@ -125,7 +125,7 @@ def test_post_to_studio_failed_data_request(
"type": "data",
"repo_url": "STUDIO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "dvclive-exp",
"name": live._exp_name,
"step": 1,
"metrics": {metrics_path: {"data": {"step": 1, "foo": 2}}},
"plots": {
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_post_to_studio_failed_start_request(
mocked_response.status_code = 400
mocked_post = mocker.patch("requests.post", return_value=mocked_response)

live = Live()
live = Live(save_dvc_exp=True)

live.log_metric("foo", 1)
live.next_step()
Expand All @@ -166,7 +166,7 @@ def test_post_to_studio_failed_start_request(

def test_post_to_studio_end_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post):
mocked_post, _ = mocked_studio_post
with Live() as live:
with Live(save_dvc_exp=True) as live:
live.log_metric("foo", 1)
live.next_step()

Expand Down Expand Up @@ -245,7 +245,7 @@ def test_post_to_studio_include_prefix_if_needed(
def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_post):
mocked_post, _ = mocked_studio_post

live = Live()
live = Live(save_dvc_exp=True)
live.log_metric("eval/loss", 1)
live.next_step()

Expand Down Expand Up @@ -299,7 +299,7 @@ def test_post_to_studio_inside_subdir(
subdir.mkdir()
monkeypatch.chdir(subdir)

live = Live()
live = Live(save_dvc_exp=True)
live.log_metric("foo", 1)
live.next_step()

Expand Down Expand Up @@ -373,3 +373,8 @@ def test_post_to_studio_inside_subdir_dvc_exp(
},
timeout=5,
)


def test_post_to_studio_requires_exp(tmp_dir, mocked_dvc_repo, mocked_studio_post):
assert Live()._studio_events_to_skip == {"start", "data", "done"}
assert not Live(save_dvc_exp=True)._studio_events_to_skip

0 comments on commit 62880aa

Please sign in to comment.