Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

live: Moved all make_{x} to next_step(). #353

Merged
merged 1 commit into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/dvclive/catalyst.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@ def on_epoch_end(self, runner) -> None:
scheduler=runner.scheduler,
)
utils.save_checkpoint(checkpoint, self.model_file)
self.live.make_report()
self.live.next_step()
1 change: 0 additions & 1 deletion src/dvclive/fastai.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,4 @@ def after_epoch(self):

if self.model_file:
self.learn.save(self.model_file)
self.live.make_report()
self.live.next_step()
1 change: 0 additions & 1 deletion src/dvclive/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def on_log(
logs = kwargs["logs"]
for key, value in logs.items():
self.live.log_metric(standardize_metric_name(key, __name__), value)
self.live.make_report()
self.live.next_step()

def on_epoch_end(
Expand Down
1 change: 0 additions & 1 deletion src/dvclive/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,4 @@ def on_epoch_end(
self.model.save_weights(self.model_file)
else:
self.model.save(self.model_file)
self.live.make_report()
self.live.next_step()
1 change: 0 additions & 1 deletion src/dvclive/lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ def __call__(self, env):

if self.model_file:
env.model.save_model(self.model_file)
self.live.make_report()
self.live.next_step()
1 change: 0 additions & 1 deletion src/dvclive/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,4 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
metric_val = metric_val.cpu().detach().item()
metric_name = standardize_metric_name(metric_name, __name__)
self.experiment.log_metric(name=metric_name, val=metric_val)
self.experiment.make_report()
self.experiment.next_step()
33 changes: 17 additions & 16 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
else:
self._cleanup()

self._latest_studio_step = self.get_step()
self._latest_studio_step = self.get_step() if resume else -1
if self.report_mode == "studio":
from scmrepo.git import Git

Expand Down Expand Up @@ -134,23 +134,16 @@ def get_step(self) -> int:
return self._step or 0

def set_step(self, step: int) -> None:
if self._step is None:
self._step = 0
self.make_summary()

if self.report_mode == "studio":
if not post_to_studio(self, "data", logger):
logger.warning(
"`post_to_studio` `data` event failed."
" Data will be resent on next call."
)
else:
self._latest_studio_step = step

self._step = step
logger.debug(f"Step: {self._step}")

def next_step(self):
if self._step is None:
self._step = 0

self.make_summary()
self.make_report()
self.make_checkpoint()
self.set_step(self.get_step() + 1)

def log_metric(
Expand All @@ -169,7 +162,6 @@ def log_metric(
data.dump(val, timestamp=timestamp)

self.summary = nested_update(self.summary, data.to_summary(val))
self.make_summary()
logger.debug(f"Logged {name}: {val}")

def log_image(self, name: str, val):
Expand Down Expand Up @@ -233,12 +225,21 @@ def make_summary(self):
dump_json(self.summary, self.metrics_file, cls=NumpyEncoder)

def make_report(self):
if self.report_mode is not None:
if self.report_mode == "studio":
if not post_to_studio(self, "data", logger):
logger.warning(
"`post_to_studio` `data` event failed."
" Data will be resent on next call."
)
else:
self._latest_studio_step = self.get_step()
elif self.report_mode is not None:
make_report(self)
if self.report_mode == "html" and env2bool(env.DVCLIVE_OPEN):
open_file_in_browser(self.report_file)

def end(self):
self.make_summary()
if self.report_mode == "studio":
if not post_to_studio(self, "done", logger):
logger.warning("`post_to_studio` `done` event failed.")
Expand Down
2 changes: 1 addition & 1 deletion src/dvclive/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def _get_unsent_datapoints(plot, latest_step):
return [x for x in plot if int(x["step"]) >= latest_step]
return [x for x in plot if int(x["step"]) > latest_step]


def _cast_to_numbers(datapoints):
Expand Down
1 change: 0 additions & 1 deletion src/dvclive/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,4 @@ def after_iteration(self, model, epoch, evals_log):
self.live.log_metric(key, latest_metric)
if self.model_file:
model.save_model(self.model_file)
self.live.make_report()
self.live.next_step()
18 changes: 18 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_logging_no_step(tmp_dir):
dvclive = Live("logs")

dvclive.log_metric("m1", 1)
dvclive.make_summary()

assert not (tmp_dir / "logs" / "m1.tsv").is_file()
assert (tmp_dir / dvclive.metrics_file).is_file()
Expand Down Expand Up @@ -254,6 +255,7 @@ def test_custom_steps(tmp_dir):
for step, metric in zip(steps, metrics):
dvclive.set_step(step)
dvclive.log_metric("m", metric)
dvclive.make_summary()

assert read_history(dvclive, "m") == (steps, metrics)
assert read_latest(dvclive, "m") == (last(steps), last(metrics))
Expand All @@ -265,10 +267,12 @@ def test_log_reset_with_set_step(tmp_dir):
for i in range(3):
dvclive.set_step(i)
dvclive.log_metric("train_m", 1)
dvclive.make_summary()

for i in range(3):
dvclive.set_step(i)
dvclive.log_metric("val_m", 1)
dvclive.make_summary()

assert read_history(dvclive, "train_m") == ([0, 1, 2], [1, 1, 1])
assert read_history(dvclive, "val_m") == ([0, 1, 2], [1, 1, 1])
Expand Down Expand Up @@ -366,3 +370,17 @@ def test_log_metric_timestamp(timestamp):
history, _ = parse_metrics(live)
logged = next(iter(history.values()))
assert ("timestamp" in logged[0]) == timestamp


def test_make_summary_is_called_on_end(tmp_dir):
live = Live()

live.summary["foo"] = 1.0
live.end()

assert json.loads((tmp_dir / live.metrics_file).read_text()) == {
# no `step`
"foo": 1.0
}
log_file = tmp_dir / live.plots_dir / Metric.subfolder / "foo.tsv"
assert not log_file.exists()