diff --git a/.github/workflows/check-test-release.yml b/.github/workflows/check-test-release.yml
index 0d886d04..d09cd2ff 100644
--- a/.github/workflows/check-test-release.yml
+++ b/.github/workflows/check-test-release.yml
@@ -29,7 +29,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Cancel Previous Runs
-        uses: styfle/cancel-workflow-action@0.10.0
+        uses: styfle/cancel-workflow-action@0.11.0
         with:
           access_token: ${{ github.token }}
       - uses: actions/checkout@v3
@@ -78,7 +78,7 @@ jobs:
         echo "::set-output name=dir::$(pip cache dir)"
     - name: set pip cache
       id: pip-cache
-      uses: actions/cache@v3.0.8
+      uses: actions/cache@v3.0.11
      with:
        path: ${{ steps.pip-cache-dir.outputs.dir }}
        key: ${{ runner.os }}-pip-${{ hashFiles('setup.cfg') }}
diff --git a/mlem/contrib/lightgbm.py b/mlem/contrib/lightgbm.py
index 4914d9db..8e7b117b 100644
--- a/mlem/contrib/lightgbm.py
+++ b/mlem/contrib/lightgbm.py
@@ -7,8 +7,9 @@
 import os
 import posixpath
 import tempfile
-from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type
+from typing import Any, ClassVar, Iterator, Optional, Tuple, Type
 
+import flatdict
 import lightgbm as lgb
 from pydantic import BaseModel
@@ -34,6 +35,8 @@
 )
 
 LGB_REQUIREMENT = UnixPackageRequirement(package_name="libgomp1")
+LIGHTGBM_DATA = "inner"
+LIGHTGBM_LABEL = "label"
 
 
 class LightGBMDataType(
@@ -43,21 +46,38 @@ class LightGBMDataType(
     :class:`.DataType` implementation for `lightgbm.Dataset` type
 
     :param inner: :class:`.DataType` instance for underlying data
+    :param labels: :class:`.DataType` instance for underlying labels
     """
 
     type: ClassVar[str] = "lightgbm"
     valid_types: ClassVar = (lgb.Dataset,)
     inner: DataType
     """Inner DataType"""
+    labels: Optional[DataType]
+    """Labels DataType"""
 
     def serialize(self, instance: Any) -> dict:
         self.check_type(instance, lgb.Dataset, SerializationError)
+        if self.labels is not None:
+            return {
+                LIGHTGBM_DATA: self.inner.get_serializer().serialize(
+                    instance.data
+                ),
+                LIGHTGBM_LABEL: self.labels.get_serializer().serialize(
+                    instance.label
+                ),
+            }
         return self.inner.get_serializer().serialize(instance.data)
 
     def deserialize(self, obj: dict) -> Any:
-        v = self.inner.get_serializer().deserialize(obj)
+        if self.labels is not None:
+            data = self.inner.get_serializer().deserialize(obj[LIGHTGBM_DATA])
+            label = self.labels.get_serializer().deserialize(
+                obj[LIGHTGBM_LABEL]
+            )
+        else:
+            data = self.inner.get_serializer().deserialize(obj)
+            label = None
         try:
-            return lgb.Dataset(v, free_raw_data=False)
+            return lgb.Dataset(data, label=label, free_raw_data=False)
         except ValueError as e:
             raise DeserializationError(
                 f"object: {obj} could not be converted to lightgbm dataset"
@@ -77,7 +97,12 @@ def get_writer(
 
     @classmethod
     def process(cls, obj: Any, **kwargs) -> DataType:
-        return LightGBMDataType(inner=DataAnalyzer.analyze(obj.data))
+        return LightGBMDataType(
+            inner=DataAnalyzer.analyze(obj.data),
+            labels=DataAnalyzer.analyze(obj.label)
+            if obj.label is not None
+            else None,
+        )
 
     def get_model(self, prefix: str = "") -> Type[BaseModel]:
         return self.inner.get_serializer().get_model(prefix)
@@ -95,19 +120,42 @@ def write(
             raise ValueError(
                 f"expected data to be of LightGBMDataType, got {type(data)} instead"
             )
-        lightgbm_construct = data.data.construct()
-        raw_data = lightgbm_construct.get_data()
-        underlying_labels = lightgbm_construct.get_label().tolist()
-        inner_reader, art = data.inner.get_writer().write(
-            data.inner.copy().bind(raw_data), storage, path
-        )
+
+        lightgbm_raw = data.data
+
+        if data.labels is not None:
+            inner_reader, inner_art = data.inner.get_writer().write(
+                data.inner.copy().bind(lightgbm_raw.data),
+                storage,
+                posixpath.join(path, LIGHTGBM_DATA),
+            )
+            labels_reader, labels_art = data.labels.get_writer().write(
+                data.labels.copy().bind(lightgbm_raw.label),
+                storage,
+                posixpath.join(path, LIGHTGBM_LABEL),
+            )
+            res = dict(
+                flatdict.FlatterDict(
+                    {LIGHTGBM_DATA: inner_art, LIGHTGBM_LABEL: labels_art},
+                    delimiter="/",
+                )
+            )
+        else:
+            inner_reader, inner_art = data.inner.get_writer().write(
+                data.inner.copy().bind(lightgbm_raw.data),
+                storage,
+                path,
+            )
+            res = inner_art
+            labels_reader = None
+
         return (
             LightGBMDataReader(
                 data_type=data,
                 inner=inner_reader,
-                label=underlying_labels,
+                labels=labels_reader,
             ),
-            art,
+            res,
         )
@@ -117,15 +165,25 @@ class LightGBMDataReader(DataReader):
     type: ClassVar[str] = "lightgbm"
     data_type: LightGBMDataType
     inner: DataReader
     """Inner reader"""
-    label: List
-    """List of labels"""
+    labels: Optional[DataReader]
+    """Labels reader"""
 
     def read(self, artifacts: Artifacts) -> DataType:
-        inner_data_type = self.inner.read(artifacts)
-        return LightGBMDataType(inner=inner_data_type).bind(
+        if self.labels is not None:
+            artifacts = flatdict.FlatterDict(artifacts, delimiter="/")
+            inner_data_type = self.inner.read(artifacts[LIGHTGBM_DATA])  # type: ignore[arg-type]
+            labels_data_type = self.labels.read(artifacts[LIGHTGBM_LABEL])  # type: ignore[arg-type]
+        else:
+            inner_data_type = self.inner.read(artifacts)
+            labels_data_type = None
+        return LightGBMDataType(
+            inner=inner_data_type, labels=labels_data_type
+        ).bind(
             lgb.Dataset(
-                inner_data_type.data, label=self.label, free_raw_data=False
+                inner_data_type.data,
+                label=labels_data_type.data
+                if labels_data_type is not None
+                else None,
+                free_raw_data=False,
             )
         )
diff --git a/mlem/contrib/torch.py b/mlem/contrib/torch.py
index 9f379684..897cbcb3 100644
--- a/mlem/contrib/torch.py
+++ b/mlem/contrib/torch.py
@@ -7,6 +7,7 @@
 """
 from typing import Any, ClassVar, Iterator, List, Optional, Tuple
 
+import cloudpickle
 import torch
 from pydantic import conlist, create_model
@@ -146,9 +147,11 @@ class TorchModelIO(ModelIO):
 
     def dump(self, storage: Storage, path, model) -> Artifacts:
         self.is_jit = isinstance(model, torch.jit.ScriptModule)
-        save = torch.jit.save if self.is_jit else torch.save
         with storage.open(path) as (f, art):
-            save(model, f)
+            if self.is_jit:
+                torch.jit.save(model, f)
+            else:
+                torch.save(model, f, pickle_module=cloudpickle)
         return {self.art_name: art}
 
     def load(self, artifacts: Artifacts):
diff --git a/mlem/core/metadata.py b/mlem/core/metadata.py
index 7aebf208..07c35d48 100644
--- a/mlem/core/metadata.py
+++ b/mlem/core/metadata.py
@@ -3,6 +3,7 @@
 searching for MLEM object by given path.
""" import logging +import os import posixpath from typing import Any, Dict, Optional, Type, TypeVar, Union, overload @@ -43,7 +44,7 @@ def get_object_metadata( def save( obj: Any, - path: str, + path: Union[str, os.PathLike], project: Optional[str] = None, sample_data=None, fs: Optional[AbstractFileSystem] = None, @@ -70,12 +71,13 @@ def save( sample_data, params=params, ) + path = os.fspath(path) meta.dump(path, fs=fs, project=project) return meta def load( - path: str, + path: Union[str, os.PathLike], project: Optional[str] = None, rev: Optional[str] = None, batch_size: Optional[int] = None, @@ -93,6 +95,7 @@ def load( Returns: Any: Python object saved by MLEM """ + path = os.fspath(path) meta = load_meta( path, project=project, @@ -110,7 +113,7 @@ def load( @overload def load_meta( - path: str, + path: Union[str, os.PathLike], project: Optional[str] = None, rev: Optional[str] = None, follow_links: bool = True, @@ -124,7 +127,7 @@ def load_meta( @overload def load_meta( - path: str, + path: Union[str, os.PathLike], project: Optional[str] = None, rev: Optional[str] = None, follow_links: bool = True, @@ -137,7 +140,7 @@ def load_meta( def load_meta( - path: str, + path: Union[str, os.PathLike], project: Optional[str] = None, rev: Optional[str] = None, follow_links: bool = True, @@ -160,6 +163,7 @@ def load_meta( Returns: MlemObject: Saved MlemObject """ + path = os.fspath(path) location = Location.resolve( path=make_posix(path), project=make_posix(project), diff --git a/setup.py b/setup.py index 2be1ea5a..1966e64d 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ from setuptools import find_packages, setup install_requires = [ + "cloudpickle", "dill", "requests", "isort>=5.10", diff --git a/tests/contrib/test_lightgbm.py b/tests/contrib/test_lightgbm.py index bd2e193a..6addd714 100644 --- a/tests/contrib/test_lightgbm.py +++ b/tests/contrib/test_lightgbm.py @@ -4,6 +4,8 @@ import pytest from mlem.contrib.lightgbm import ( + LIGHTGBM_DATA, + LIGHTGBM_LABEL, LightGBMDataReader, LightGBMDataType, LightGBMDataWriter, @@ -12,7 +14,12 @@ from mlem.contrib.numpy import NumpyNdarrayType from mlem.contrib.pandas import DataFrameType from mlem.core.artifacts import LOCAL_STORAGE -from mlem.core.data_type import DataAnalyzer, DataType +from mlem.core.data_type import ( + ArrayType, + DataAnalyzer, + DataType, + PrimitiveType, +) from mlem.core.errors import DeserializationError, SerializationError from mlem.core.model import ModelAnalyzer, ModelType from mlem.core.requirements import UnixPackageRequirement @@ -46,7 +53,7 @@ def df_payload(): def data_df(df_payload): return lgb.Dataset( df_payload, - label=np.array([0, 1]).tolist(), + label=np.array([0, 1]), free_raw_data=False, ) @@ -75,6 +82,8 @@ def test_hook_np(dtype_np: DataType): assert set(dtype_np.get_requirements().modules) == {"lightgbm", "numpy"} assert isinstance(dtype_np, LightGBMDataType) assert isinstance(dtype_np.inner, NumpyNdarrayType) + assert isinstance(dtype_np.labels, ArrayType) + assert dtype_np.labels.dtype == PrimitiveType(data=None, ptype="float") assert dtype_np.get_model().__name__ == dtype_np.inner.get_model().__name__ assert dtype_np.get_model().schema() == { "title": "NumpyNdarray", @@ -92,6 +101,7 @@ def test_hook_df(dtype_df: DataType): assert set(dtype_df.get_requirements().modules) == {"lightgbm", "pandas"} assert isinstance(dtype_df, LightGBMDataType) assert isinstance(dtype_df.inner, DataFrameType) + assert isinstance(dtype_df.labels, NumpyNdarrayType) assert dtype_df.get_model().__name__ == 
     assert dtype_df.get_model().schema() == {
         "title": "DataFrame",
@@ -116,54 +126,131 @@
 
 @pytest.mark.parametrize(
-    "lgb_dtype, data_type",
-    [("dtype_np", NumpyNdarrayType), ("dtype_df", DataFrameType)],
+    "lgb_dtype, data_type, label_type",
+    [
+        ("dtype_np", NumpyNdarrayType, ArrayType),
+        ("dtype_df", DataFrameType, NumpyNdarrayType),
+    ],
 )
-def test_lightgbm_source(lgb_dtype, data_type, request):
+def test_lightgbm_source(lgb_dtype, data_type, label_type, request):
     lgb_dtype = request.getfixturevalue(lgb_dtype)
     assert isinstance(lgb_dtype, LightGBMDataType)
     assert isinstance(lgb_dtype.inner, data_type)
+    assert isinstance(lgb_dtype.labels, label_type)
 
     def custom_assert(x, y):
         assert hasattr(x, "data")
         assert hasattr(y, "data")
         assert all(x.data == y.data)
-        assert all(x.label == y.label)
+        label_check = x.label == y.label
+        if isinstance(label_check, (list, np.ndarray)):
+            assert all(label_check)
+        else:
+            assert label_check
 
-    data_write_read_check(
+    artifacts = data_write_read_check(
         lgb_dtype,
         writer=LightGBMDataWriter(),
         reader_type=LightGBMDataReader,
         custom_assert=custom_assert,
     )
 
+    if isinstance(lgb_dtype.inner, NumpyNdarrayType):
+        assert list(artifacts.keys()) == [
+            f"{LIGHTGBM_DATA}/data",
+            f"{LIGHTGBM_LABEL}/0/data",
+            f"{LIGHTGBM_LABEL}/1/data",
+            f"{LIGHTGBM_LABEL}/2/data",
+            f"{LIGHTGBM_LABEL}/3/data",
+            f"{LIGHTGBM_LABEL}/4/data",
+        ]
+        assert artifacts[f"{LIGHTGBM_DATA}/data"].uri.endswith(
+            f"data/{LIGHTGBM_DATA}"
+        )
+        assert artifacts[f"{LIGHTGBM_LABEL}/0/data"].uri.endswith(
+            f"data/{LIGHTGBM_LABEL}/0"
+        )
+        assert artifacts[f"{LIGHTGBM_LABEL}/1/data"].uri.endswith(
+            f"data/{LIGHTGBM_LABEL}/1"
+        )
+        assert artifacts[f"{LIGHTGBM_LABEL}/2/data"].uri.endswith(
+            f"data/{LIGHTGBM_LABEL}/2"
+        )
+        assert artifacts[f"{LIGHTGBM_LABEL}/3/data"].uri.endswith(
+            f"data/{LIGHTGBM_LABEL}/3"
+        )
+        assert artifacts[f"{LIGHTGBM_LABEL}/4/data"].uri.endswith(
+            f"data/{LIGHTGBM_LABEL}/4"
+        )
+    else:
+        assert list(artifacts.keys()) == [
+            f"{LIGHTGBM_DATA}/data",
+            f"{LIGHTGBM_LABEL}/data",
+        ]
+        assert artifacts[f"{LIGHTGBM_DATA}/data"].uri.endswith(
+            f"data/{LIGHTGBM_DATA}"
+        )
+        assert artifacts[f"{LIGHTGBM_LABEL}/data"].uri.endswith(
+            f"data/{LIGHTGBM_LABEL}"
+        )
+
 
 def test_serialize__np(dtype_np, np_payload):
-    ds = lgb.Dataset(np_payload)
+    ds = lgb.Dataset(np_payload, label=np_payload.reshape((-1,)).tolist())
     payload = dtype_np.serialize(ds)
-    assert payload == np_payload.tolist()
+    assert payload[LIGHTGBM_DATA] == np_payload.tolist()
+    assert payload[LIGHTGBM_LABEL] == np_payload.reshape((-1,)).tolist()
 
     with pytest.raises(SerializationError):
         dtype_np.serialize({"abc": 123})  # wrong type
 
 
 def test_deserialize__np(dtype_np, np_payload):
-    ds = dtype_np.deserialize(np_payload)
+    ds = dtype_np.deserialize(
+        {
+            LIGHTGBM_DATA: np_payload,
+            LIGHTGBM_LABEL: np_payload.reshape((-1,)).tolist(),
+        }
+    )
     assert isinstance(ds, lgb.Dataset)
     assert np.all(ds.data == np_payload)
+    assert np.all(ds.label == np_payload.reshape((-1,)).tolist())
 
     with pytest.raises(DeserializationError):
-        dtype_np.deserialize([[1], ["abc"]])  # illegal matrix
+        dtype_np.deserialize({LIGHTGBM_DATA: [[1], ["abc"]]})  # illegal matrix
 
 
-def test_serialize__df(dtype_df, df_payload):
-    ds = lgb.Dataset(df_payload)
-    payload = dtype_df.serialize(ds)
-    assert payload["values"] == df_payload.to_dict("records")
+def test_serialize__df(df_payload):
+    ds = lgb.Dataset(df_payload, label=None, free_raw_data=False)
+    payload = DataType.create(obj=ds)
+    assert payload.serialize(ds)["values"] == df_payload.to_dict("records")
+    assert LIGHTGBM_LABEL not in payload.serialize(ds)
+
+    def custom_assert(x, y):
+        assert hasattr(x, "data")
+        assert hasattr(y, "data")
+        assert all(x.data == y.data)
+        assert x.label == y.label
+
+    artifacts = data_write_read_check(
+        payload,
+        writer=LightGBMDataWriter(),
+        reader_type=LightGBMDataReader,
+        custom_assert=custom_assert,
+    )
+
+    assert len(artifacts.keys()) == 1
+    assert list(artifacts.keys()) == ["data"]
+    assert artifacts["data"].uri.endswith("/data")
 
 
 def test_deserialize__df(dtype_df, df_payload):
-    ds = dtype_df.deserialize({"values": df_payload})
+    ds = dtype_df.deserialize(
+        {
+            LIGHTGBM_DATA: {"values": df_payload},
+            LIGHTGBM_LABEL: np.array([0, 1]).tolist(),
+        }
+    )
     assert isinstance(ds, lgb.Dataset)
     assert ds.data.equals(df_payload)
diff --git a/tests/contrib/test_torch.py b/tests/contrib/test_torch.py
index f8615f3a..7565a2e1 100644
--- a/tests/contrib/test_torch.py
+++ b/tests/contrib/test_torch.py
@@ -1,4 +1,5 @@
 import os
+import subprocess
 
 import pytest
 import torch
@@ -16,6 +17,7 @@
 from mlem.core.errors import DeserializationError, SerializationError
 from mlem.core.model import ModelAnalyzer
 from mlem.core.objects import MlemModel
+from mlem.utils.path import make_posix
 from tests.conftest import data_write_read_check
@@ -174,6 +176,21 @@ def test_torch_import(tmp_path, net, torchsave):
     assert isinstance(meta.model_type, TorchModel)
 
 
+def test_torch_import_in_separate_shell(tmp_path):
+    path = make_posix(os.path.join(str(tmp_path), "model"))
+    m = MyNet()
+    save(m, path)
+    x = subprocess.run(
+        [
+            "python",
+            "-c",
+            f"from mlem.api import load; loaded = load('{path}')",
+        ],
+        check=True,
+    )
+    assert x.returncode == 0
+
+
 # Copyright 2019 Zyfra
 # Copyright 2021 Iterative
 #
diff --git a/tests/core/test_metadata.py b/tests/core/test_metadata.py
index 89f4a03c..c8572220 100644
--- a/tests/core/test_metadata.py
+++ b/tests/core/test_metadata.py
@@ -1,6 +1,7 @@
 import os
 import posixpath
 import tempfile
+from pathlib import Path
 from urllib.parse import quote_plus
 
 import pytest
@@ -38,6 +39,20 @@ def test_saving_with_project(model, tmpdir):
     load_meta(path)
 
 
+def test_saving_with_pathlib(model, tmpdir):
+    # by default, tmpdir is of type `py._path.local.LocalPath`;
+    # see the test below
+    path = Path(tmpdir) / "obj"
+    save(model, path)
+    load_meta(path)
+
+
+def test_saving_with_localpath(model, tmpdir):
+    path = tmpdir / "obj"
+    save(model, path)
+    load_meta(path)
+
+
 def test_model_saving_without_sample_data(model, tmpdir_factory):
     path = str(
         tmpdir_factory.mktemp("saving-models-without-sample-data") / "model"
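Taken together, the patch does three things: LightGBM `Dataset` labels now survive a save/load round trip, torch models are pickled with `cloudpickle` so they can be re-loaded in a fresh interpreter, and `save`/`load`/`load_meta` accept any `os.PathLike` path. A minimal usage sketch of the new behavior (illustrative only: the paths, data, and training parameters below are made up, and it assumes mlem, lightgbm, and numpy are installed):

import lightgbm as lgb
import numpy as np
from pathlib import Path

from mlem.api import load, save

# Labels passed to lgb.Dataset are now analyzed and persisted next to the data.
train = lgb.Dataset(
    np.array([[1.0, 2.0], [3.0, 4.0]]),
    label=np.array([0, 1]),
    free_raw_data=False,
)
booster = lgb.train({"objective": "binary"}, train, num_boost_round=1)

# pathlib.Path (or any os.PathLike) is accepted; save()/load() call
# os.fspath() on the argument before resolving the location.
model_path = Path("models") / "clf"
save(booster, model_path, sample_data=np.array([[1.0, 2.0]]))
loaded = load(model_path)  # also works from a separate process, per the new test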