Skip to content

Commit

Permalink
Image support (#166)
Browse files Browse the repository at this point in the history
* Added image_pil and image_numpy

* Use DATA_TYPES list in metrics

* Use subdir structure

* Use data subdirs in init_path

* Fix test_logging

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added test_image

* pre-commit

* Fix catalyst and fastai

* Make pillow optional dep

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Renamed scalar -> scalars

* Raise exception

* Fix pylint

* Old summary

* Removed subdirs

* Add image summary

* Fix test subdirs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Include step in image summary

* lint

* Raise Error on lazy PIL import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix setup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed merge

* Fixed tests

* Fixed step formatting

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
daavoo and pre-commit-ci[bot] authored Nov 8, 2021
1 parent f6739c5 commit 5227d6e
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 24 deletions.
6 changes: 5 additions & 1 deletion dvclive/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .scalar import Scalar # noqa: F401
from .image_numpy import ImageNumpy
from .image_pil import ImagePIL
from .scalar import Scalar

DATA_TYPES = [ImageNumpy, ImagePIL, Scalar]
24 changes: 24 additions & 0 deletions dvclive/data/image_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from dvclive.error import DvcLiveError

from .image_pil import ImagePIL


class ImageNumpy(ImagePIL):
@staticmethod
def could_log(val: object) -> bool:
if val.__class__.__module__ == "numpy":
return True
return False

def dump(self, val, step) -> None:
try:
from PIL import Image
except ImportError as e:
raise DvcLiveError(
"'pillow' is required for logging images."
" You can install it by running"
" 'pip install pillow'"
) from e

val = Image.fromarray(val)
super().dump(val, step)
33 changes: 33 additions & 0 deletions dvclive/data/image_pil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from pathlib import Path

from .base import Data


class ImagePIL(Data):
suffixes = [".jpg", ".jpeg", ".gif", ".png"]

@staticmethod
def could_log(val: object) -> bool:
if val.__class__.__module__ == "PIL.Image":
return True
return False

@property
def output_path(self) -> Path:
if Path(self.name).suffix not in self.suffixes:
raise ValueError(
f"Invalid image suffix '{Path(self.name).suffix}'"
f" Must be one of {self.suffixes}"
)
return self.output_folder / "{step}" / self.name

def dump(self, val, step) -> None:
super().dump(val, step)
output_path = Path(str(self.output_path).format(step=step))
output_path.parent.mkdir(exist_ok=True, parents=True)

val.save(output_path)

@property
def summary(self):
return {self.name: str(self.output_path).format(step=self.step)}
2 changes: 2 additions & 0 deletions dvclive/data/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@


class Scalar(Data):
suffixes = [".csv", ".tsv"]

@staticmethod
def could_log(val: object) -> bool:
if isinstance(val, (int, float)):
Expand Down
28 changes: 18 additions & 10 deletions dvclive/live.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
import logging
import os
import shutil
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union

from .data import Scalar
from .data import DATA_TYPES
from .dvc import make_checkpoint, make_html
from .error import ConfigMismatchError, InvalidDataTypeError

Expand Down Expand Up @@ -46,19 +47,23 @@ def __init__(

def _cleanup(self):

for dvclive_file in Path(self.dir).rglob("*.tsv"):
dvclive_file.unlink()
for data_type in DATA_TYPES:
for suffix in data_type.suffixes:
for data_file in Path(self.dir).rglob(f"*{suffix}"):
data_file.unlink()

if os.path.exists(self.summary_path):
os.remove(self.summary_path)

if os.path.exists(self.html_path):
os.remove(self.html_path)
shutil.rmtree(Path(self.html_path).parent, ignore_errors=True)

def _init_paths(self):
os.makedirs(self.dir, exist_ok=True)
if self._summary:
self.make_summary()
if self._html:
os.makedirs(Path(self.html_path).parent, exist_ok=True)

def init_from_env(self) -> None:
from . import env
Expand Down Expand Up @@ -96,11 +101,11 @@ def exists(self):

@property
def summary_path(self):
return self.dir + ".json"
return str(self.dir) + ".json"

@property
def html_path(self):
return self.dir + "_dvc_plots/index.html"
return str(self.dir) + "_dvc_plots/index.html"

def get_step(self) -> int:
return self._step
Expand All @@ -119,15 +124,18 @@ def next_step(self):

def log(self, name: str, val: Union[int, float]):

data = None
if name in self._data:
data = self._data[name]
elif Scalar.could_log(val):
data = Scalar(name, self.dir)
self._data[name] = data
else:
for data_type in DATA_TYPES:
if data_type.could_log(val):
data = data_type(name, self.dir)
self._data[name] = data
if data is None:
raise InvalidDataTypeError(name, type(val))

data.dump(val, self._step)

if self._summary:
self.make_summary()

Expand Down
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@ def run(self):
_build_py.run(self)


mmcv = ["mmcv", "torch", "torchvision"]
mmcv = ["mmcv"]
tf = ["tensorflow"]
xgb = ["xgboost"]
lgbm = ["lightgbm"]
hugginface = ["transformers", "datasets"]
catalyst = ["catalyst"]
fastai = ["fastai"]
pl = ["pytorch_lightning"]
image = ["pillow"]

all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst + fastai + pl
all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst + fastai + pl + image

tests_requires = [
"pylint==2.5.3",
Expand Down Expand Up @@ -74,9 +75,11 @@ def run(self):
"tf": tf,
"xgb": xgb,
"lgbm": lgbm,
"mmcv": mmcv,
"huggingface": hugginface,
"catalyst": catalyst,
"fastai": fastai,
"image": image,
"pytorch_lightning": pl,
},
keywords="data-science metrics machine-learning developer-tools ai",
Expand Down
53 changes: 53 additions & 0 deletions tests/test_data/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os

import numpy as np
import pytest
from PIL import Image

# pylint: disable=unused-argument
from dvclive import Live
from tests.test_main import _parse_json


def test_PIL(tmp_dir):
dvclive = Live()
img = Image.new("RGB", (500, 500), (250, 250, 250))
dvclive.log("image.png", img)

assert (tmp_dir / dvclive.dir / "0" / "image.png").exists()
summary = _parse_json("dvclive.json")

assert summary["image.png"] == os.path.join(dvclive.dir, "0", "image.png")


def test_invalid_extension(tmp_dir):
dvclive = Live()
img = Image.new("RGB", (500, 500), (250, 250, 250))
with pytest.raises(ValueError):
dvclive.log("image.foo", img)


@pytest.mark.parametrize("shape", [(500, 500), (500, 500, 3), (500, 500, 4)])
def test_numpy(tmp_dir, shape):
dvclive = Live()
img = np.ones(shape, np.uint8) * 255
dvclive.log("image.png", img)

assert (tmp_dir / dvclive.dir / "0" / "image.png").exists()


def test_step_formatting(tmp_dir):
dvclive = Live()
img = np.ones((500, 500, 3), np.uint8)
for _ in range(3):
dvclive.log("image.png", img)
dvclive.next_step()

for step in range(3):
assert (tmp_dir / dvclive.dir / str(step) / "image.png").exists()

summary = _parse_json("dvclive.json")

assert summary["image.png"] == os.path.join(
dvclive.dir, str(step), "image.png"
)
19 changes: 8 additions & 11 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@


def read_logs(path: str):
assert os.path.isdir(path)
path = Path(path)
assert path.is_dir()
history = {}
for metric_file in Path(path).rglob("*.tsv"):
metric_name = str(metric_file).replace(path + os.path.sep, "")
for metric_file in path.rglob("*.tsv"):
metric_name = str(metric_file).replace(str(path) + os.path.sep, "")
metric_name = metric_name.replace(".tsv", "")
history[metric_name] = _parse_tsv(metric_file)
latest = _parse_json(path + ".json")
latest = _parse_json(str(path) + ".json")
return history, latest


Expand Down Expand Up @@ -67,9 +68,8 @@ def test_logging(tmp_dir, summary):

dvclive.log("m1", 1)

assert (tmp_dir / "logs").is_dir()
assert (tmp_dir / "logs" / "m1.tsv").is_file()
assert (tmp_dir / "logs.json").is_file() == summary
assert (tmp_dir / dvclive.summary_path).is_file() == summary

if summary:
_, s = read_logs("logs")
Expand All @@ -82,8 +82,6 @@ def test_nested_logging(tmp_dir):
dvclive.log("train/m1", 1)
dvclive.log("val/val_1/m1", 1)

assert (tmp_dir / "logs").is_dir()
assert (tmp_dir / "logs" / "train").is_dir()
assert (tmp_dir / "logs" / "val" / "val_1").is_dir()
assert (tmp_dir / "logs" / "train" / "m1.tsv").is_file()
assert (tmp_dir / "logs" / "val" / "val_1" / "m1.tsv").is_file()
Expand Down Expand Up @@ -129,20 +127,19 @@ def test_cleanup(tmp_dir, summary, html):

html_path = tmp_dir / dvclive.html_path
if html:
html_path.parent.mkdir()
html_path.touch()

(tmp_dir / "logs" / "some_user_file.txt").touch()

assert (tmp_dir / "logs" / "m1.tsv").is_file()
assert (tmp_dir / "logs.json").is_file() == summary
assert (tmp_dir / dvclive.summary_path).is_file() == summary
assert html_path.is_file() == html

dvclive = Live("logs", summary=summary)

assert (tmp_dir / "logs" / "some_user_file.txt").is_file()
assert not (tmp_dir / "logs" / "m1.tsv").is_file()
assert (tmp_dir / "logs.json").is_file() == summary
assert (tmp_dir / dvclive.summary_path).is_file() == summary
assert not (html_path).is_file()


Expand Down

0 comments on commit 5227d6e

Please sign in to comment.