Skip to content

Commit

Permalink
post_live_metrics: Add _post_in_chunks.
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed Aug 21, 2023
1 parent 4fd7f1a commit 848ad45
Showing 1 changed file with 72 additions and 24 deletions.
96 changes: 72 additions & 24 deletions src/dvc_studio_client/post_live_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
from collections import defaultdict
from os import getenv
from typing import Any, Dict, Literal, Optional
from urllib.parse import urljoin
Expand All @@ -12,6 +14,8 @@
from .env import DVC_STUDIO_TOKEN, STUDIO_ENDPOINT, STUDIO_TOKEN
from .schema import SCHEMAS_BY_TYPE

# Size threshold used when splitting plots into separate requests to Studio;
# a single plot whose measured size exceeds it is not sent at all.
MAX_CHUNK_SIZE = 29_000_000


def get_studio_token_and_repo_url(studio_token=None, studio_repo_url=None):
studio_token = studio_token or getenv(DVC_STUDIO_TOKEN) or getenv(STUDIO_TOKEN)
Expand All @@ -22,6 +26,67 @@ def get_studio_token_and_repo_url(studio_token=None, studio_repo_url=None):
return config.get("token"), config.get("repo_url")


def _single_post(url, body, token):
    """Send one JSON POST request to Studio and report whether it succeeded.

    Args:
        url: URL the request is posted to.
        body: JSON-serializable payload sent as the request body.
        token: Studio token, sent as ``Authorization: token <token>``.

    Returns:
        ``True`` when the response status code is 200. ``False`` when the
        request raises ``RequestException`` or any other status code is
        returned; both failures are logged as warnings instead of raised.
    """
    try:
        response = requests.post(
            url,
            json=body,
            headers={
                "Content-type": "application/json",
                "Authorization": f"token {token}",
            },
            # (connect, read) timeouts in seconds, per the Requests API.
            timeout=(30, 5),
        )
    except RequestException as e:
        logger.warning(f"Failed to post to Studio: {e}")
        return False

    # errors="replace" keeps a non-UTF-8 response body from raising here;
    # the decoded text is only used for log messages.
    message = response.content.decode(errors="replace")

    # Build the debug line in two steps. Writing it as a single expression
    # `f"..." f"..." if message else ""` applies the conditional to the whole
    # concatenated string, which dropped the status code for empty bodies.
    debug_message = f"post_to_studio: {response.status_code=}"
    if message:
        debug_message += f", {message=}"
    logger.debug(debug_message)

    if response.status_code != 200:
        logger.warning(f"Failed to post to Studio: {message}")
        return False

    return True


def _post_in_chunks(url, body, token):
    """Post a live-metrics body to Studio, splitting its plots across requests.

    The body is first posted without its plots (so metrics and params go out
    in a request of their own). The plots then follow in one or more requests
    that carry neither ``metrics`` nor ``params``, packed so that the measured
    size of the plots in a single request does not exceed ``MAX_CHUNK_SIZE``.

    Args:
        url: URL the requests are posted to.
        body: Request payload; its ``"plots"`` entry maps plot names to dicts
            holding either a ``"data"`` or an ``"image"`` value. The dict is
            not modified.
        token: Studio token forwarded to ``_single_post``.

    Returns:
        ``True`` if every request succeeded, ``False`` as soon as one fails
        (remaining requests are then not attempted).
    """
    # Work on copies so the caller's dict is left untouched.
    plots = body.get("plots") or {}
    base_body = {key: value for key, value in body.items() if key != "plots"}

    # First post only metrics and params
    if not _single_post(url, base_body, token):
        return False

    plots_body = {
        key: value
        for key, value in base_body.items()
        if key not in ("metrics", "params")
    }

    # Studio backend has a limitation on the size of the request body.
    # So we split the plots into chunks and post them separately.
    chunks = []
    current_chunk = {}
    current_size = 0
    for plot_name, plot_data in plots.items():
        if "data" in plot_data:
            size = len(json.dumps(plot_data["data"]).encode("utf-8"))
        elif "image" in plot_data:
            size = len(plot_data["image"])
        else:
            # Neither known key is present: measure the whole entry rather
            # than reusing a stale (or undefined) size from a previous plot.
            size = len(json.dumps(plot_data).encode("utf-8"))

        if size > MAX_CHUNK_SIZE:
            logger.warning(f"Plot {plot_name} is too large to be sent to Studio.")
            continue

        # Start a new chunk when adding this plot would overflow the current
        # one, so no chunk's measured size ever exceeds the limit.
        if current_chunk and current_size + size > MAX_CHUNK_SIZE:
            chunks.append(current_chunk)
            current_chunk = {}
            current_size = 0

        current_chunk[plot_name] = plot_data
        current_size += size

    if current_chunk:
        chunks.append(current_chunk)

    for n, chunk in enumerate(chunks):
        logger.debug("Posting chunk: " + str(n))
        if not _single_post(url, {**plots_body, "plots": chunk}, token):
            return False

    return True


def post_live_metrics( # noqa: C901
event_type: Literal["start", "data", "done"],
baseline_sha: str,
Expand Down Expand Up @@ -98,6 +163,9 @@ def post_live_metrics( # noqa: C901
plots={
"dvclive/plots/metrics/foo.tsv": {
"data": [{"step": 0, "foo": 1.0}]
},
"dvclive/plots/images/bar.png": {
"image": "base64-string"
}
}
```
Expand Down Expand Up @@ -156,7 +224,6 @@ def post_live_metrics( # noqa: C901
body["step"] = step
if plots:
body["plots"] = plots

elif event_type == "done":
if experiment_rev:
body["experiment_rev"] = experiment_rev
Expand All @@ -172,31 +239,12 @@ def post_live_metrics( # noqa: C901
return None

logger.debug(f"post_studio_live_metrics `{event_type=}`")
logger.debug(f"JSON body `{body=}`")

path = getenv(STUDIO_ENDPOINT) or "api/live"
url = urljoin(config["url"], path)
try:
response = requests.post(
url,
json=body,
headers={
"Content-type": "application/json",
"Authorization": f"token {config['token']}",
},
timeout=(30, 5),
)
except RequestException as e:
logger.warning(f"Failed to post to Studio: {e}")
return False

message = response.content.decode()
logger.debug(
f"post_to_studio: {response.status_code=}" f", {message=}" if message else ""
)
token = config["token"]

if response.status_code != 200:
logger.warning(f"Failed to post to Studio: {message}")
return False
if body["type"] != "data" or "plots" not in body:
return _single_post(url, body, token)

return True
return _post_in_chunks(url, body, token)

0 comments on commit 848ad45

Please sign in to comment.