Skip to content

Commit

Permalink
Add post_to_studio (#297)
Browse files Browse the repository at this point in the history
* Add `post_to_studio`

Co-authored-by: daniele <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: daniele <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 28, 2022
1 parent 74e4209 commit a4ca09e
Show file tree
Hide file tree
Showing 13 changed files with 388 additions and 188 deletions.
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def lint(session: nox.Session) -> None:

args = *(session.posargs or ("--show-diff-on-failure",)), "--all-files"
session.run("pre-commit", "run", *args)
session.run("python", "-m", "mypy")
session.run("python", "-m", "mypy", "--install-types", "--non-interactive")
session.run("python", "-m", "pylint", *locations)


Expand Down
3 changes: 3 additions & 0 deletions src/dvclive/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@
DVCLIVE_HTML = "DVCLIVE_HTML"
DVCLIVE_RESUME = "DVCLIVE_RESUME"
DVC_CHECKPOINT = "DVC_CHECKPOINT"
STUDIO_ENDPOINT = "STUDIO_ENDPOINT"
STUDIO_REPO_URL = "STUDIO_REPO_URL"
STUDIO_TOKEN = "STUDIO_TOKEN"
53 changes: 44 additions & 9 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from .report import make_report
from .serialize import dump_yaml, load_yaml
from .studio import post_to_studio
from .utils import env2bool, nested_update, open_file_in_browser

logging.basicConfig()
Expand All @@ -44,8 +45,14 @@ def __init__(
self._path: Optional[str] = path
self._resume: bool = resume or env2bool(env.DVCLIVE_RESUME)

self.studio_url = os.getenv(env.STUDIO_REPO_URL, None)
self.studio_token = os.getenv(env.STUDIO_TOKEN, None)
self.rev = None

if report == "auto":
if env2bool("CI"):
if self.studio_url and self.studio_token:
report = "studio"
elif env2bool("CI"):
report = "md"
else:
report = "html"
Expand All @@ -63,19 +70,20 @@ def __init__(
if self._path is None:
self._path = self.DEFAULT_DIR

if self._report is not None:
if not self.report_path:
self.report_path = os.path.join(self.dir, f"report.{report}")
out = Path(self.report_path).resolve()
logger.info(f"Report path (if generated): {out}")

self._step: Optional[int] = None
self._scalars: Dict[str, Any] = OrderedDict()
self._images: Dict[str, Any] = OrderedDict()
self._plots: Dict[str, Any] = OrderedDict()
self._params: Dict[str, Any] = OrderedDict()

self._init_paths()

if self._report in ("html", "md"):
if not self.report_path:
self.report_path = os.path.join(self.dir, f"report.{report}")
out = Path(self.report_path).resolve()
logger.info(f"Report path (if generated): {out}")

if self._resume:
self._read_params()
self._step = self.read_step()
Expand All @@ -85,6 +93,20 @@ def __init__(
else:
self._cleanup()

self._latest_studio_step = self.get_step()

if self._report == "studio":
from scmrepo.git import Git

self.rev = Git().get_rev()

if not post_to_studio(self, "start", logger):
logger.warning(
"`post_to_studio` `start` event failed. "
"`studio` report cancelled."
)
self._report = None

def _cleanup(self):
for data_type in DATA_TYPES:
shutil.rmtree(
Expand Down Expand Up @@ -153,7 +175,16 @@ def set_step(self, step: int) -> None:
data.dump(data.val, self._step)
self.make_summary()

self.make_report()
if self._report == "studio":
if not post_to_studio(self, "data", logger):
logger.warning(
"`post_to_studio` `data` event failed."
" Data will be resent on next call."
)
else:
self._latest_studio_step = step
else:
self.make_report()

self.make_checkpoint()

Expand Down Expand Up @@ -258,10 +289,14 @@ def make_report(self):
make_report(
self.dir, self.summary_path, self.report_path, self._report
)

if self._report == "html" and env2bool(env.DVCLIVE_OPEN):
open_file_in_browser(self.report_path)

def end(self):
if self._report == "studio":
if not post_to_studio(self, "done", logger):
logger.warning("`post_to_studio` `done` event failed.")

def make_checkpoint(self):
if env2bool(env.DVC_CHECKPOINT):
make_checkpoint()
Expand Down
82 changes: 82 additions & 0 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from os import getenv

from dvclive.env import STUDIO_ENDPOINT
from dvclive.utils import parse_scalars


def _get_unsent_datapoints(plot, latest_step):
return [x for x in plot if int(x["step"]) >= latest_step]


def _cast_to_numbers(datapoints):
for datapoint in datapoints:
for k, v in datapoint.items():
if k == "step":
datapoint[k] = int(v)
elif k == "timestamp":
continue
else:
datapoint[k] = float(v)
return datapoints


def _to_dvc_format(plots):
formatted = {}
for k, v in plots.items():
formatted[k] = {"data": v}
return formatted


def _get_updates(live):
plots, metrics = parse_scalars(live)
latest_step = live._latest_studio_step # pylint: disable=protected-access

for name, plot in plots.items():
datapoints = _get_unsent_datapoints(plot, latest_step)
plots[name] = _cast_to_numbers(datapoints)

metrics = {live.summary_path: {"data": metrics}}
plots = _to_dvc_format(plots)
return metrics, plots


def post_to_studio(live, event_type, logger) -> bool:
import requests
from requests.exceptions import RequestException

data = {
"type": event_type,
"repo_url": live.studio_url,
"rev": live.rev,
"client": "dvclive",
}

if event_type == "data":
metrics, plots = _get_updates(live)
data["metrics"] = metrics
data["plots"] = plots
data["step"] = live.get_step()

logger.debug(f"post_to_studio `{event_type=}`")

try:
response = requests.post(
getenv(STUDIO_ENDPOINT, "https://studio.iterative.ai/api/live"),
json=data,
headers={
"Content-type": "application/json",
"Authorization": f"token {live.studio_token}",
},
timeout=5,
)
except RequestException:
return False

message = response.content.decode()
logger.debug(
f"post_to_studio: {response.status_code=}" f", {message=}"
if message
else ""
)

return response.status_code == 200
30 changes: 24 additions & 6 deletions src/dvclive/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import json
import os
import re
import webbrowser
Expand Down Expand Up @@ -34,12 +35,6 @@ def nested_update(d, u):
return d


def parse_tsv(path):
with open(path, "r", encoding="utf-8") as fd:
reader = csv.DictReader(fd, delimiter="\t")
return list(reader)


def run_once(f):
def wrapper(*args, **kwargs):
if not wrapper.has_run:
Expand Down Expand Up @@ -99,3 +94,26 @@ def standardize_metric_name(metric_name: str, framework: str) -> str:
metric_name = f"{split}/{freq}/{'_'.join(rest)}"

return metric_name


def parse_tsv(path):
with open(path, "r", encoding="utf-8") as fd:
reader = csv.DictReader(fd, delimiter="\t")
return list(reader)


def parse_json(path):
with open(path, "r", encoding="utf-8") as fd:
return json.load(fd)


def parse_scalars(live):
from .data import Scalar

live_dir = Path(live.dir)
history = {}
for suffix in Scalar.suffixes:
for scalar_file in live_dir.rglob(f"*{suffix}"):
history[str(scalar_file)] = parse_tsv(scalar_file)
latest = parse_json(live.summary_path)
return history, latest
19 changes: 11 additions & 8 deletions tests/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from dvclive.data.scalar import Scalar
from dvclive.huggingface import DvcLiveCallback
from tests.test_main import read_logs
from dvclive.utils import parse_scalars

# pylint: disable=redefined-outer-name, unused-argument, no-value-for-parameter

Expand Down Expand Up @@ -74,19 +74,22 @@ def test_huggingface_integration(tmp_dir, model, args, data, tokenizer):
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
trainer.add_callback(DvcLiveCallback())
callback = DvcLiveCallback()
trainer.add_callback(callback)
trainer.train()

assert os.path.exists("dvclive")

logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
logs, _ = parse_scalars(callback.dvclive)

assert len(logs) == 10
assert os.path.join("eval", "matthews_correlation") in logs
assert os.path.join("eval", "loss") in logs
assert os.path.join("train", "loss") in logs
assert len(logs["epoch"]) == 3
assert len(logs[os.path.join("eval", "loss")]) == 2

scalars = os.path.join("dvclive", Scalar.subfolder)
assert os.path.join(scalars, "eval", "matthews_correlation.tsv") in logs
assert os.path.join(scalars, "eval", "loss.tsv") in logs
assert os.path.join(scalars, "train", "loss.tsv") in logs
assert len(logs[os.path.join(scalars, "epoch.tsv")]) == 3
assert len(logs[os.path.join(scalars, "eval", "loss.tsv")]) == 2


def test_huggingface_model_file(tmp_dir, model, args, data, tokenizer, mocker):
Expand Down
21 changes: 12 additions & 9 deletions tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dvclive import Live
from dvclive.data.scalar import Scalar
from dvclive.keras import DvcLiveCallback
from tests.test_main import read_logs
from dvclive.utils import parse_scalars

# pylint: disable=unused-argument, no-name-in-module, redefined-outer-name

Expand Down Expand Up @@ -38,20 +38,22 @@ def make():
def test_keras_callback(tmp_dir, xor_model, capture_wrap):
model, x, y = xor_model()

callback = DvcLiveCallback()
model.fit(
x,
y,
epochs=1,
batch_size=1,
validation_split=0.2,
callbacks=[DvcLiveCallback()],
callbacks=[callback],
)

assert os.path.exists("dvclive")
logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
logs, _ = parse_scalars(callback.dvclive)

assert os.path.join("train", "accuracy") in logs
assert os.path.join("eval", "accuracy") in logs
scalars = os.path.join("dvclive", Scalar.subfolder)
assert os.path.join(scalars, "train", "accuracy.tsv") in logs
assert os.path.join(scalars, "eval", "accuracy.tsv") in logs


def test_keras_callback_pass_logger(tmp_dir, xor_model, capture_wrap):
Expand All @@ -67,11 +69,12 @@ def test_keras_callback_pass_logger(tmp_dir, xor_model, capture_wrap):
validation_split=0.2,
callbacks=[DvcLiveCallback(dvclive=logger)],
)
assert os.path.exists("train_logs")
logs, _ = read_logs(tmp_dir / "train_logs" / Scalar.subfolder)
assert os.path.exists(logger.dir)
logs, _ = parse_scalars(logger)

assert os.path.join("train", "accuracy") in logs
assert os.path.join("eval", "accuracy") in logs
scalars = os.path.join(logger.dir, Scalar.subfolder)
assert os.path.join(scalars, "train", "accuracy.tsv") in logs
assert os.path.join(scalars, "eval", "accuracy.tsv") in logs


@pytest.mark.parametrize("save_weights_only", (True, False))
Expand Down
8 changes: 4 additions & 4 deletions tests/test_lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from sklearn import datasets
from sklearn.model_selection import train_test_split

from dvclive.data.scalar import Scalar
from dvclive.lgbm import DvcLiveCallback
from tests.test_main import read_logs
from dvclive.utils import parse_scalars

# pylint: disable=redefined-outer-name, unused-argument

Expand Down Expand Up @@ -38,17 +37,18 @@ def test_lgbm_integration(tmp_dir, model_params, iris_data):
model = lgbm.LGBMClassifier()
model.set_params(**model_params)

callback = DvcLiveCallback()
model.fit(
iris_data[0][0],
iris_data[0][1],
eval_set=(iris_data[1][0], iris_data[1][1]),
eval_metric=["multi_logloss"],
callbacks=[DvcLiveCallback()],
callbacks=[callback],
)

assert os.path.exists("dvclive")

logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
logs, _ = parse_scalars(callback.dvclive)
assert len(logs) == 1
assert len(list(logs.values())[0]) == 5

Expand Down
11 changes: 6 additions & 5 deletions tests/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from dvclive.data.scalar import Scalar
from dvclive.lightning import DvcLiveLogger
from tests.test_main import read_logs
from dvclive.utils import parse_scalars

# pylint: disable=redefined-outer-name, unused-argument

Expand Down Expand Up @@ -93,9 +93,10 @@ def test_lightning_integration(tmp_dir):
assert os.path.exists("logs")
assert not os.path.exists("DvcLiveLogger")

logs, _ = read_logs(tmp_dir / "logs" / Scalar.subfolder)
scalars = os.path.join(dvclive_logger.experiment.dir, Scalar.subfolder)
logs, _ = parse_scalars(dvclive_logger.experiment)

assert len(logs) == 3
assert os.path.join("train", "epoch", "loss") in logs
assert os.path.join("train", "step", "loss") in logs
assert "epoch" in logs
assert os.path.join(scalars, "train", "epoch", "loss.tsv") in logs
assert os.path.join(scalars, "train", "step", "loss.tsv") in logs
assert os.path.join(scalars, "epoch.tsv") in logs
Loading

0 comments on commit a4ca09e

Please sign in to comment.