From 5227d6e6863c61d95a8861d7c4392d6ed9ee5120 Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Mon, 8 Nov 2021 16:44:21 +0100 Subject: [PATCH] Image support (#166) * Added image_pil and image_numpy * Use DATA_TYPES list in metrics * Use subdir structure * Use data subdirs in init_path * Fix test_logging * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added test_image * pre-commit * Fix catalyst and fastai * Make pillow optional dep * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Renamed scalar -> scalars * Raise exception * Fix pylint * Old summary * Removed subdirs * Add image summary * Fix test subdirs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Include step in image summary * lint * Raise Error on lazy PIL import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix setup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed merge * Fixed tests * Fixed step formatting Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- dvclive/data/__init__.py | 6 +++- dvclive/data/image_numpy.py | 24 ++++++++++++++++ dvclive/data/image_pil.py | 33 ++++++++++++++++++++++ dvclive/data/scalar.py | 2 ++ dvclive/live.py | 28 +++++++++++------- setup.py | 7 +++-- tests/test_data/test_image.py | 53 +++++++++++++++++++++++++++++++++++ tests/test_main.py | 19 ++++++------- 8 files changed, 148 insertions(+), 24 deletions(-) create mode 100644 dvclive/data/image_numpy.py create mode 100644 dvclive/data/image_pil.py create mode 100644 tests/test_data/test_image.py diff --git a/dvclive/data/__init__.py b/dvclive/data/__init__.py index 8dc7d5f7..2be9d386 100644 --- a/dvclive/data/__init__.py +++ b/dvclive/data/__init__.py @@ -1 +1,5 @@ -from .scalar import Scalar # noqa: F401 +from .image_numpy import ImageNumpy +from .image_pil import ImagePIL +from .scalar import Scalar + +DATA_TYPES = [ImageNumpy, ImagePIL, Scalar] diff --git a/dvclive/data/image_numpy.py b/dvclive/data/image_numpy.py new file mode 100644 index 00000000..7b5b9630 --- /dev/null +++ b/dvclive/data/image_numpy.py @@ -0,0 +1,24 @@ +from dvclive.error import DvcLiveError + +from .image_pil import ImagePIL + + +class ImageNumpy(ImagePIL): + @staticmethod + def could_log(val: object) -> bool: + if val.__class__.__module__ == "numpy": + return True + return False + + def dump(self, val, step) -> None: + try: + from PIL import Image + except ImportError as e: + raise DvcLiveError( + "'pillow' is required for logging images." + " You can install it by running" + " 'pip install pillow'" + ) from e + + val = Image.fromarray(val) + super().dump(val, step) diff --git a/dvclive/data/image_pil.py b/dvclive/data/image_pil.py new file mode 100644 index 00000000..3f6d4d93 --- /dev/null +++ b/dvclive/data/image_pil.py @@ -0,0 +1,33 @@ +from pathlib import Path + +from .base import Data + + +class ImagePIL(Data): + suffixes = [".jpg", ".jpeg", ".gif", ".png"] + + @staticmethod + def could_log(val: object) -> bool: + if val.__class__.__module__ == "PIL.Image": + return True + return False + + @property + def output_path(self) -> Path: + if Path(self.name).suffix not in self.suffixes: + raise ValueError( + f"Invalid image suffix '{Path(self.name).suffix}'" + f" Must be one of {self.suffixes}" + ) + return self.output_folder / "{step}" / self.name + + def dump(self, val, step) -> None: + super().dump(val, step) + output_path = Path(str(self.output_path).format(step=step)) + output_path.parent.mkdir(exist_ok=True, parents=True) + + val.save(output_path) + + @property + def summary(self): + return {self.name: str(self.output_path).format(step=self.step)} diff --git a/dvclive/data/scalar.py b/dvclive/data/scalar.py index c70d44a0..d85f6631 100644 --- a/dvclive/data/scalar.py +++ b/dvclive/data/scalar.py @@ -10,6 +10,8 @@ class Scalar(Data): + suffixes = [".csv", ".tsv"] + @staticmethod def could_log(val: object) -> bool: if isinstance(val, (int, float)): diff --git a/dvclive/live.py b/dvclive/live.py index 96c56c60..64b83c55 100644 --- a/dvclive/live.py +++ b/dvclive/live.py @@ -1,11 +1,12 @@ import json import logging import os +import shutil from collections import OrderedDict from pathlib import Path from typing import Any, Dict, Optional, Union -from .data import Scalar +from .data import DATA_TYPES from .dvc import make_checkpoint, make_html from .error import ConfigMismatchError, InvalidDataTypeError @@ -46,19 +47,23 @@ def __init__( def _cleanup(self): - for dvclive_file in Path(self.dir).rglob("*.tsv"): - dvclive_file.unlink() + for data_type in DATA_TYPES: + for suffix in data_type.suffixes: + for data_file in Path(self.dir).rglob(f"*{suffix}"): + data_file.unlink() if os.path.exists(self.summary_path): os.remove(self.summary_path) if os.path.exists(self.html_path): - os.remove(self.html_path) + shutil.rmtree(Path(self.html_path).parent, ignore_errors=True) def _init_paths(self): os.makedirs(self.dir, exist_ok=True) if self._summary: self.make_summary() + if self._html: + os.makedirs(Path(self.html_path).parent, exist_ok=True) def init_from_env(self) -> None: from . import env @@ -96,11 +101,11 @@ def exists(self): @property def summary_path(self): - return self.dir + ".json" + return str(self.dir) + ".json" @property def html_path(self): - return self.dir + "_dvc_plots/index.html" + return str(self.dir) + "_dvc_plots/index.html" def get_step(self) -> int: return self._step @@ -119,15 +124,18 @@ def next_step(self): def log(self, name: str, val: Union[int, float]): + data = None if name in self._data: data = self._data[name] - elif Scalar.could_log(val): - data = Scalar(name, self.dir) - self._data[name] = data else: + for data_type in DATA_TYPES: + if data_type.could_log(val): + data = data_type(name, self.dir) + self._data[name] = data + if data is None: raise InvalidDataTypeError(name, type(val)) - data.dump(val, self._step) + if self._summary: self.make_summary() diff --git a/setup.py b/setup.py index 2cebf563..d84ef395 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ def run(self): _build_py.run(self) -mmcv = ["mmcv", "torch", "torchvision"] +mmcv = ["mmcv"] tf = ["tensorflow"] xgb = ["xgboost"] lgbm = ["lightgbm"] @@ -44,8 +44,9 @@ def run(self): catalyst = ["catalyst"] fastai = ["fastai"] pl = ["pytorch_lightning"] +image = ["pillow"] -all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst + fastai + pl +all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst + fastai + pl + image tests_requires = [ "pylint==2.5.3", @@ -74,9 +75,11 @@ def run(self): "tf": tf, "xgb": xgb, "lgbm": lgbm, + "mmcv": mmcv, "huggingface": hugginface, "catalyst": catalyst, "fastai": fastai, + "image": image, "pytorch_lightning": pl, }, keywords="data-science metrics machine-learning developer-tools ai", diff --git a/tests/test_data/test_image.py b/tests/test_data/test_image.py new file mode 100644 index 00000000..5e7a8792 --- /dev/null +++ b/tests/test_data/test_image.py @@ -0,0 +1,53 @@ +import os + +import numpy as np +import pytest +from PIL import Image + +# pylint: disable=unused-argument +from dvclive import Live +from tests.test_main import _parse_json + + +def test_PIL(tmp_dir): + dvclive = Live() + img = Image.new("RGB", (500, 500), (250, 250, 250)) + dvclive.log("image.png", img) + + assert (tmp_dir / dvclive.dir / "0" / "image.png").exists() + summary = _parse_json("dvclive.json") + + assert summary["image.png"] == os.path.join(dvclive.dir, "0", "image.png") + + +def test_invalid_extension(tmp_dir): + dvclive = Live() + img = Image.new("RGB", (500, 500), (250, 250, 250)) + with pytest.raises(ValueError): + dvclive.log("image.foo", img) + + +@pytest.mark.parametrize("shape", [(500, 500), (500, 500, 3), (500, 500, 4)]) +def test_numpy(tmp_dir, shape): + dvclive = Live() + img = np.ones(shape, np.uint8) * 255 + dvclive.log("image.png", img) + + assert (tmp_dir / dvclive.dir / "0" / "image.png").exists() + + +def test_step_formatting(tmp_dir): + dvclive = Live() + img = np.ones((500, 500, 3), np.uint8) + for _ in range(3): + dvclive.log("image.png", img) + dvclive.next_step() + + for step in range(3): + assert (tmp_dir / dvclive.dir / str(step) / "image.png").exists() + + summary = _parse_json("dvclive.json") + + assert summary["image.png"] == os.path.join( + dvclive.dir, str(step), "image.png" + ) diff --git a/tests/test_main.py b/tests/test_main.py index a756076f..0403fe92 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -18,13 +18,14 @@ def read_logs(path: str): - assert os.path.isdir(path) + path = Path(path) + assert path.is_dir() history = {} - for metric_file in Path(path).rglob("*.tsv"): - metric_name = str(metric_file).replace(path + os.path.sep, "") + for metric_file in path.rglob("*.tsv"): + metric_name = str(metric_file).replace(str(path) + os.path.sep, "") metric_name = metric_name.replace(".tsv", "") history[metric_name] = _parse_tsv(metric_file) - latest = _parse_json(path + ".json") + latest = _parse_json(str(path) + ".json") return history, latest @@ -67,9 +68,8 @@ def test_logging(tmp_dir, summary): dvclive.log("m1", 1) - assert (tmp_dir / "logs").is_dir() assert (tmp_dir / "logs" / "m1.tsv").is_file() - assert (tmp_dir / "logs.json").is_file() == summary + assert (tmp_dir / dvclive.summary_path).is_file() == summary if summary: _, s = read_logs("logs") @@ -82,8 +82,6 @@ def test_nested_logging(tmp_dir): dvclive.log("train/m1", 1) dvclive.log("val/val_1/m1", 1) - assert (tmp_dir / "logs").is_dir() - assert (tmp_dir / "logs" / "train").is_dir() assert (tmp_dir / "logs" / "val" / "val_1").is_dir() assert (tmp_dir / "logs" / "train" / "m1.tsv").is_file() assert (tmp_dir / "logs" / "val" / "val_1" / "m1.tsv").is_file() @@ -129,20 +127,19 @@ def test_cleanup(tmp_dir, summary, html): html_path = tmp_dir / dvclive.html_path if html: - html_path.parent.mkdir() html_path.touch() (tmp_dir / "logs" / "some_user_file.txt").touch() assert (tmp_dir / "logs" / "m1.tsv").is_file() - assert (tmp_dir / "logs.json").is_file() == summary + assert (tmp_dir / dvclive.summary_path).is_file() == summary assert html_path.is_file() == html dvclive = Live("logs", summary=summary) assert (tmp_dir / "logs" / "some_user_file.txt").is_file() assert not (tmp_dir / "logs" / "m1.tsv").is_file() - assert (tmp_dir / "logs.json").is_file() == summary + assert (tmp_dir / dvclive.summary_path).is_file() == summary assert not (html_path).is_file()