Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
fix conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-tandon committed Oct 27, 2022
2 parents 2846cc6 + 86105ee commit 9234258
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 44 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/check-test-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.10.0
uses: styfle/cancel-workflow-action@0.11.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v3
Expand Down Expand Up @@ -78,7 +78,7 @@ jobs:
echo "::set-output name=dir::$(pip cache dir)"
- name: set pip cache
id: pip-cache
uses: actions/[email protected].8
uses: actions/[email protected].11
with:
path: ${{ steps.pip-cache-dir.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.cfg') }}
Expand Down
96 changes: 77 additions & 19 deletions mlem/contrib/lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import os
import posixpath
import tempfile
from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type
from typing import Any, ClassVar, Iterator, Optional, Tuple, Type

import flatdict
import lightgbm as lgb
from pydantic import BaseModel

Expand All @@ -34,6 +35,8 @@
)

LGB_REQUIREMENT = UnixPackageRequirement(package_name="libgomp1")
LIGHTGBM_DATA = "inner"
LIGHTGBM_LABEL = "label"


class LightGBMDataType(
Expand All @@ -43,21 +46,38 @@ class LightGBMDataType(
:class:`.DataType` implementation for `lightgbm.Dataset` type
:param inner: :class:`.DataType` instance for underlying data
:param labels: :class:`.DataType` instance for underlying labels
"""

type: ClassVar[str] = "lightgbm"
valid_types: ClassVar = (lgb.Dataset,)
inner: DataType
"""Inner DataType"""
labels: Optional[DataType]

def serialize(self, instance: Any) -> dict:
    """Serialize a ``lightgbm.Dataset`` into plain (JSON-able) data.

    :param instance: the ``lgb.Dataset`` to serialize; anything else
        raises ``SerializationError`` via ``check_type``.
    :returns: when a labels DataType is attached, a dict with the raw
        data under ``LIGHTGBM_DATA`` and labels under ``LIGHTGBM_LABEL``,
        each produced by its own DataType serializer; otherwise just the
        serialized raw data payload (backward-compatible flat form).
    """
    self.check_type(instance, lgb.Dataset, SerializationError)
    if self.labels is not None:
        # Split payload: data and labels are serialized independently so
        # each side can round-trip through its own DataType serializer.
        return {
            LIGHTGBM_DATA: self.inner.get_serializer().serialize(
                instance.data
            ),
            LIGHTGBM_LABEL: self.labels.get_serializer().serialize(
                instance.label
            ),
        }
    # No labels: keep the legacy single-payload shape.
    return self.inner.get_serializer().serialize(instance.data)

def deserialize(self, obj: dict) -> Any:
v = self.inner.get_serializer().deserialize(obj)
if self.labels is not None:
data = self.inner.get_serializer().deserialize(obj[LIGHTGBM_DATA])
label = self.labels.get_serializer().deserialize(
obj[LIGHTGBM_LABEL]
)
else:
data = self.inner.get_serializer().deserialize(obj)
label = None
try:
return lgb.Dataset(v, free_raw_data=False)
return lgb.Dataset(data, label=label, free_raw_data=False)
except ValueError as e:
raise DeserializationError(
f"object: {obj} could not be converted to lightgbm dataset"
Expand All @@ -77,7 +97,12 @@ def get_writer(

@classmethod
def process(cls, obj: Any, **kwargs) -> DataType:
    """Build a :class:`LightGBMDataType` describing ``obj``.

    :param obj: a ``lightgbm.Dataset``; its raw ``.data`` is analyzed
        for the inner DataType, and ``.label`` (when set) for labels.
    :returns: the resulting :class:`LightGBMDataType`.
    """
    # NOTE(review): the rendered diff left a stale pre-merge
    # `return LightGBMDataType(inner=...)` line here, which made the
    # labels-aware return below unreachable — removed.
    return LightGBMDataType(
        inner=DataAnalyzer.analyze(obj.data),
        labels=DataAnalyzer.analyze(obj.label)
        if obj.label is not None
        else None,
    )

def get_model(self, prefix: str = "") -> Type[BaseModel]:
return self.inner.get_serializer().get_model(prefix)
Expand All @@ -95,19 +120,42 @@ def write(
raise ValueError(
f"expected data to be of LightGBMDataType, got {type(data)} instead"
)
lightgbm_construct = data.data.construct()
raw_data = lightgbm_construct.get_data()
underlying_labels = lightgbm_construct.get_label().tolist()
inner_reader, art = data.inner.get_writer().write(
data.inner.copy().bind(raw_data), storage, path
)

lightgbm_raw = data.data

if data.labels is not None:
inner_reader, inner_art = data.inner.get_writer().write(
data.inner.copy().bind(lightgbm_raw.data),
storage,
posixpath.join(path, LIGHTGBM_DATA),
)
labels_reader, labels_art = data.labels.get_writer().write(
data.labels.copy().bind(lightgbm_raw.label),
storage,
posixpath.join(path, LIGHTGBM_LABEL),
)
res = dict(
flatdict.FlatterDict(
{LIGHTGBM_DATA: inner_art, LIGHTGBM_LABEL: labels_art},
delimiter="/",
)
)
else:
inner_reader, inner_art = data.inner.get_writer().write(
data.inner.copy().bind(lightgbm_raw.data),
storage,
path,
)
res = inner_art
labels_reader = None

return (
LightGBMDataReader(
data_type=data,
inner=inner_reader,
label=underlying_labels,
labels=labels_reader,
),
art,
res,
)


Expand All @@ -117,15 +165,25 @@ class LightGBMDataReader(DataReader):
type: ClassVar[str] = "lightgbm"
data_type: LightGBMDataType
inner: DataReader
"""Inner reader"""
label: List
"""List of labels"""
labels: Optional[DataReader]

def read(self, artifacts: Artifacts) -> DataType:
    """Read stored artifacts back into a bound :class:`LightGBMDataType`.

    :param artifacts: artifact mapping; when labels were written, keys
        are nested under ``LIGHTGBM_DATA``/``LIGHTGBM_LABEL`` (flattened
        with ``/`` as delimiter — see the matching writer).
    :returns: a :class:`LightGBMDataType` bound to a reconstructed
        ``lgb.Dataset`` (``free_raw_data=False`` keeps the raw data
        accessible for later serialization).
    """
    # NOTE(review): stale pre-merge lines (an unconditional
    # `inner_data_type = self.inner.read(artifacts)` plus an early
    # `return` using `label=self.label`) were interleaved here by the
    # conflict; this is the labels-aware post-merge version.
    if self.labels is not None:
        # Un-flatten so each sub-reader sees only its own artifacts.
        artifacts = flatdict.FlatterDict(artifacts, delimiter="/")
        inner_data_type = self.inner.read(artifacts[LIGHTGBM_DATA])  # type: ignore[arg-type]
        labels_data_type = self.labels.read(artifacts[LIGHTGBM_LABEL])  # type: ignore[arg-type]
    else:
        inner_data_type = self.inner.read(artifacts)
        labels_data_type = None
    return LightGBMDataType(
        inner=inner_data_type, labels=labels_data_type
    ).bind(
        lgb.Dataset(
            inner_data_type.data,
            label=labels_data_type.data
            if labels_data_type is not None
            else None,
            free_raw_data=False,
        )
    )

Expand Down
7 changes: 5 additions & 2 deletions mlem/contrib/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
from typing import Any, ClassVar, Iterator, List, Optional, Tuple

import cloudpickle
import torch
from pydantic import conlist, create_model

Expand Down Expand Up @@ -146,9 +147,11 @@ class TorchModelIO(ModelIO):

def dump(self, storage: Storage, path, model) -> Artifacts:
    """Dump a torch model to ``storage`` at ``path``.

    :param storage: target Storage to open a writable stream in.
    :param path: destination path within the storage.
    :param model: the torch model; JIT ``ScriptModule``s go through
        ``torch.jit.save``, everything else through ``torch.save`` with
        ``cloudpickle`` so locally-defined classes survive pickling.
    :returns: artifact mapping under ``self.art_name``.
    """
    # NOTE(review): stale pre-merge lines (`save = ...` and a first
    # `save(model, f)`) caused the model to be written twice — removed.
    self.is_jit = isinstance(model, torch.jit.ScriptModule)
    with storage.open(path) as (f, art):
        if self.is_jit:
            torch.jit.save(model, f)
        else:
            torch.save(model, f, pickle_module=cloudpickle)
    return {self.art_name: art}

def load(self, artifacts: Artifacts):
Expand Down
14 changes: 9 additions & 5 deletions mlem/core/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
searching for MLEM object by given path.
"""
import logging
import os
import posixpath
from typing import Any, Dict, Optional, Type, TypeVar, Union, overload

Expand Down Expand Up @@ -43,7 +44,7 @@ def get_object_metadata(

def save(
obj: Any,
path: str,
path: Union[str, os.PathLike],
project: Optional[str] = None,
sample_data=None,
fs: Optional[AbstractFileSystem] = None,
Expand All @@ -70,12 +71,13 @@ def save(
sample_data,
params=params,
)
path = os.fspath(path)
meta.dump(path, fs=fs, project=project)
return meta


def load(
path: str,
path: Union[str, os.PathLike],
project: Optional[str] = None,
rev: Optional[str] = None,
batch_size: Optional[int] = None,
Expand All @@ -93,6 +95,7 @@ def load(
Returns:
Any: Python object saved by MLEM
"""
path = os.fspath(path)
meta = load_meta(
path,
project=project,
Expand All @@ -110,7 +113,7 @@ def load(

@overload
def load_meta(
path: str,
path: Union[str, os.PathLike],
project: Optional[str] = None,
rev: Optional[str] = None,
follow_links: bool = True,
Expand All @@ -124,7 +127,7 @@ def load_meta(

@overload
def load_meta(
path: str,
path: Union[str, os.PathLike],
project: Optional[str] = None,
rev: Optional[str] = None,
follow_links: bool = True,
Expand All @@ -137,7 +140,7 @@ def load_meta(


def load_meta(
path: str,
path: Union[str, os.PathLike],
project: Optional[str] = None,
rev: Optional[str] = None,
follow_links: bool = True,
Expand All @@ -160,6 +163,7 @@ def load_meta(
Returns:
MlemObject: Saved MlemObject
"""
path = os.fspath(path)
location = Location.resolve(
path=make_posix(path),
project=make_posix(project),
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from setuptools import find_packages, setup

install_requires = [
"cloudpickle",
"dill",
"requests",
"isort>=5.10",
Expand Down
Loading

0 comments on commit 9234258

Please sign in to comment.