Skip to content

Commit

Permalink
live: Moved all make_{x} to next_step(). (#353)
Browse files Browse the repository at this point in the history
`log_metric` now stores logged value in `live.summary` but doesn't call `make_summary` as before.

Call `make_summary` inside `live.end`.

Closes #238
Closes #232
  • Loading branch information
daavoo authored Nov 4, 2022
1 parent e0fbcc1 commit f985113
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 24 deletions.
1 change: 0 additions & 1 deletion src/dvclive/catalyst.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@ def on_epoch_end(self, runner) -> None:
scheduler=runner.scheduler,
)
utils.save_checkpoint(checkpoint, self.model_file)
self.live.make_report()
self.live.next_step()
1 change: 0 additions & 1 deletion src/dvclive/fastai.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,4 @@ def after_epoch(self):

if self.model_file:
self.learn.save(self.model_file)
self.live.make_report()
self.live.next_step()
1 change: 0 additions & 1 deletion src/dvclive/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def on_log(
logs = kwargs["logs"]
for key, value in logs.items():
self.live.log_metric(standardize_metric_name(key, __name__), value)
self.live.make_report()
self.live.next_step()

def on_epoch_end(
Expand Down
1 change: 0 additions & 1 deletion src/dvclive/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,4 @@ def on_epoch_end(
self.model.save_weights(self.model_file)
else:
self.model.save(self.model_file)
self.live.make_report()
self.live.next_step()
1 change: 0 additions & 1 deletion src/dvclive/lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ def __call__(self, env):

if self.model_file:
env.model.save_model(self.model_file)
self.live.make_report()
self.live.next_step()
1 change: 0 additions & 1 deletion src/dvclive/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,4 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
metric_val = metric_val.cpu().detach().item()
metric_name = standardize_metric_name(metric_name, __name__)
self.experiment.log_metric(name=metric_name, val=metric_val)
self.experiment.make_report()
self.experiment.next_step()
33 changes: 17 additions & 16 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
else:
self._cleanup()

self._latest_studio_step = self.get_step()
self._latest_studio_step = self.get_step() if resume else -1
if self.report_mode == "studio":
from scmrepo.git import Git

Expand Down Expand Up @@ -134,23 +134,16 @@ def get_step(self) -> int:
return self._step or 0

def set_step(self, step: int) -> None:
if self._step is None:
self._step = 0
self.make_summary()

if self.report_mode == "studio":
if not post_to_studio(self, "data", logger):
logger.warning(
"`post_to_studio` `data` event failed."
" Data will be resent on next call."
)
else:
self._latest_studio_step = step

self._step = step
logger.debug(f"Step: {self._step}")

def next_step(self):
if self._step is None:
self._step = 0

self.make_summary()
self.make_report()
self.make_checkpoint()
self.set_step(self.get_step() + 1)

def log_metric(
Expand All @@ -169,7 +162,6 @@ def log_metric(
data.dump(val, timestamp=timestamp)

self.summary = nested_update(self.summary, data.to_summary(val))
self.make_summary()
logger.debug(f"Logged {name}: {val}")

def log_image(self, name: str, val):
Expand Down Expand Up @@ -233,12 +225,21 @@ def make_summary(self):
dump_json(self.summary, self.metrics_file, cls=NumpyEncoder)

def make_report(self):
if self.report_mode is not None:
if self.report_mode == "studio":
if not post_to_studio(self, "data", logger):
logger.warning(
"`post_to_studio` `data` event failed."
" Data will be resent on next call."
)
else:
self._latest_studio_step = self.get_step()
elif self.report_mode is not None:
make_report(self)
if self.report_mode == "html" and env2bool(env.DVCLIVE_OPEN):
open_file_in_browser(self.report_file)

def end(self):
self.make_summary()
if self.report_mode == "studio":
if not post_to_studio(self, "done", logger):
logger.warning("`post_to_studio` `done` event failed.")
Expand Down
2 changes: 1 addition & 1 deletion src/dvclive/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def _get_unsent_datapoints(plot, latest_step):
return [x for x in plot if int(x["step"]) >= latest_step]
return [x for x in plot if int(x["step"]) > latest_step]


def _cast_to_numbers(datapoints):
Expand Down
1 change: 0 additions & 1 deletion src/dvclive/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,4 @@ def after_iteration(self, model, epoch, evals_log):
self.live.log_metric(key, latest_metric)
if self.model_file:
model.save_model(self.model_file)
self.live.make_report()
self.live.next_step()
18 changes: 18 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_logging_no_step(tmp_dir):
dvclive = Live("logs")

dvclive.log_metric("m1", 1)
dvclive.make_summary()

assert not (tmp_dir / "logs" / "m1.tsv").is_file()
assert (tmp_dir / dvclive.metrics_file).is_file()
Expand Down Expand Up @@ -254,6 +255,7 @@ def test_custom_steps(tmp_dir):
for step, metric in zip(steps, metrics):
dvclive.set_step(step)
dvclive.log_metric("m", metric)
dvclive.make_summary()

assert read_history(dvclive, "m") == (steps, metrics)
assert read_latest(dvclive, "m") == (last(steps), last(metrics))
Expand All @@ -265,10 +267,12 @@ def test_log_reset_with_set_step(tmp_dir):
for i in range(3):
dvclive.set_step(i)
dvclive.log_metric("train_m", 1)
dvclive.make_summary()

for i in range(3):
dvclive.set_step(i)
dvclive.log_metric("val_m", 1)
dvclive.make_summary()

assert read_history(dvclive, "train_m") == ([0, 1, 2], [1, 1, 1])
assert read_history(dvclive, "val_m") == ([0, 1, 2], [1, 1, 1])
Expand Down Expand Up @@ -366,3 +370,17 @@ def test_log_metric_timestamp(timestamp):
history, _ = parse_metrics(live)
logged = next(iter(history.values()))
assert ("timestamp" in logged[0]) == timestamp


def test_make_summary_is_called_on_end(tmp_dir):
live = Live()

live.summary["foo"] = 1.0
live.end()

assert json.loads((tmp_dir / live.metrics_file).read_text()) == {
# no `step`
"foo": 1.0
}
log_file = tmp_dir / live.plots_dir / Metric.subfolder / "foo.tsv"
assert not log_file.exists()

0 comments on commit f985113

Please sign in to comment.