Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
dynamic types support array, dict, ndarray (#304)
Browse files Browse the repository at this point in the history
* dynamic types support array, dict, ndarray

* dynamic types support array, dict, ndarray

* dynamic types support array, dict, ndarray

* fix lint issue

* incorporated review comments

* incorporated review comments

* reverted licence

* correct fixture name

* fix windows numpy int dtype issue

Co-authored-by: Alexander Guschin <[email protected]>
  • Loading branch information
maheshambule and aguschin authored Jun 29, 2022
1 parent 3df18ab commit 9c4c650
Show file tree
Hide file tree
Showing 5 changed files with 770 additions and 96 deletions.
29 changes: 24 additions & 5 deletions mlem/contrib/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class NumpyNdarrayType(
type: ClassVar[str] = "ndarray"
libraries: ClassVar[List[ModuleType]] = [np]

shape: Tuple[Optional[int], ...]
shape: Optional[Tuple[Optional[int], ...]]
dtype: str

@staticmethod
Expand Down Expand Up @@ -133,7 +133,16 @@ def _subtype(self, subshape: Tuple[Optional[int], ...]):

def get_model(self, prefix: str = "") -> Type[BaseModel]:
return create_model(
prefix + "NumpyNdarray", __root__=(List[self._subtype(self.shape[1:])], ...) # type: ignore
prefix + "NumpyNdarray",
__root__=(
self._subtype(self.shape)
if self.shape is not None
else List[
Union[python_type_from_np_string_repr(self.dtype), List]
], # type: ignore
...,
)
# type: ignore
)

def serialize(self, instance: np.ndarray):
Expand All @@ -147,10 +156,20 @@ def serialize(self, instance: np.ndarray):
return instance.tolist()

def _check_shape(self, array, exc_type):
if tuple(array.shape)[1:] != self.shape[1:]:
raise exc_type(
f"given array is of shape: {(None,) + tuple(array.shape)[1:]}, expected: {self.shape}"
if self.shape is not None:
if len(array.shape) != len(self.shape):
raise exc_type(
f"given array is of rank: {len(array.shape)}, expected: {len(self.shape)}"
)

array_shape = tuple(
None if expected_dim is None else array_dim
for array_dim, expected_dim in zip(array.shape, self.shape)
)
if tuple(array_shape) != self.shape:
raise exc_type(
f"given array is of shape: {array_shape}, expected: {self.shape}"
)

def get_writer(self, project: str = None, filename: str = None, **kwargs):
return NumpyArrayWriter()
Expand Down
216 changes: 198 additions & 18 deletions mlem/core/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Base classes for working with data in MLEM
"""
import builtins
import json
import posixpath
from abc import ABC, abstractmethod
from typing import (
Expand All @@ -18,7 +19,7 @@
)

import flatdict
from pydantic import BaseModel
from pydantic import BaseModel, validator
from pydantic.main import create_model

from mlem.core.artifacts import Artifacts, Storage
Expand Down Expand Up @@ -70,8 +71,10 @@ def bind(self, data: Any):
return self

@classmethod
def create(cls, obj: Any, **kwargs):
return DataAnalyzer.analyze(obj, **kwargs).bind(obj)
def create(cls, obj: Any, is_dynamic: bool = False, **kwargs):
return DataAnalyzer.analyze(obj, is_dynamic=is_dynamic, **kwargs).bind(
obj
)


class DataSerializer(ABC):
Expand Down Expand Up @@ -367,7 +370,7 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]:

def _check_type_and_size(obj, dtype, size, exc_type):
DataType.check_type(obj, dtype, exc_type)
if size != -1 and len(obj) != size:
if size is not None and len(obj) != size:
raise exc_type(
f"given {dtype.__name__} has len: {len(obj)}, expected: {size}"
)
Expand Down Expand Up @@ -449,44 +452,77 @@ def is_object_valid(cls, obj: Any) -> bool:
return isinstance(obj, (list, tuple))

@classmethod
def process(cls, obj, **kwargs) -> DataType:
def process(cls, obj, is_dynamic: bool = False, **kwargs) -> DataType:
if isinstance(obj, tuple):
return TupleType(items=[DataAnalyzer.analyze(o) for o in obj])
return TupleType(
items=[
DataAnalyzer.analyze(o, is_dynamic=is_dynamic, **kwargs)
for o in obj
]
)

py_types = {type(o) for o in obj}
if len(obj) <= 1 or len(py_types) > 1:
return ListType(items=[DataAnalyzer.analyze(o) for o in obj])
return ListType(
items=[
DataAnalyzer.analyze(o, is_dynamic=is_dynamic, **kwargs)
for o in obj
]
)

size = None if is_dynamic else len(obj)

if not py_types.intersection(
PrimitiveType.PRIMITIVES
): # py_types is guaranteed to be singleton set here
items_types = [DataAnalyzer.analyze(o) for o in obj]
items_types = [
DataAnalyzer.analyze(o, is_dynamic=is_dynamic, **kwargs)
for o in obj
]
first, *others = items_types
for other in others:
if first != other:
return ListType(items=items_types)
return ArrayType(dtype=first, size=len(obj))
return ArrayType(dtype=first, size=size)

# optimization for large lists of same primitive type elements
return ArrayType(dtype=DataAnalyzer.analyze(obj[0]), size=len(obj))
return ArrayType(
dtype=DataAnalyzer.analyze(
obj[0], is_dynamic=is_dynamic, **kwargs
),
size=size,
)


class DictTypeHook(DataHook):
@classmethod
def is_object_valid(cls, obj: Any) -> bool:
return isinstance(obj, dict)

@classmethod
def process(
cls, obj: Any, is_dynamic: bool = False, **kwargs
) -> Union["DictType", "DynamicDictType"]:
if not is_dynamic:
return DictType.process(obj, **kwargs)
return DynamicDictType.process(obj, **kwargs)


class DictType(DataType, DataSerializer, DataHook):
class DictType(DataType, DataSerializer):
"""
DataType for dict
DataType for dict with fixed set of keys
"""

type: ClassVar[str] = "dict"
item_types: Dict[str, DataType]

@classmethod
def is_object_valid(cls, obj: Any) -> bool:
return isinstance(obj, dict)

@classmethod
def process(cls, obj: Any, **kwargs) -> "DictType":
def process(cls, obj, **kwargs):
return DictType(
item_types={k: DataAnalyzer.analyze(v) for (k, v) in obj.items()}
item_types={
k: DataAnalyzer.analyze(v, is_dynamic=False, **kwargs)
for (k, v) in obj.items()
}
)

def deserialize(self, obj):
Expand Down Expand Up @@ -577,6 +613,150 @@ def read_batch(
raise NotImplementedError


class DynamicDictType(DataType, DataSerializer):
"""
Dynamic DataType for dict without fixed set of keys
"""

type: ClassVar[str] = "d_dict"

key_type: PrimitiveType
value_type: DataType

@validator("key_type")
def is_valid_key_type( # pylint: disable=no-self-argument
cls, key_type # noqa: B902
):
if key_type.ptype not in ["str", "int", "float"]:
raise ValueError(f"key_type {key_type.ptype} is not supported")
return key_type

def deserialize(self, obj):
self.check_type(obj, dict, DeserializationError)
return {
self.key_type.get_serializer()
.deserialize(
k,
): self.value_type.get_serializer()
.deserialize(
v,
)
for k, v in obj.items()
}

def serialize(self, instance: dict):
self._check_types(instance, SerializationError)

return {
self.key_type.get_serializer()
.serialize(
k,
): self.value_type.get_serializer()
.serialize(
v,
)
for k, v in instance.items()
}

@classmethod
def process(cls, obj, **kwargs) -> "DynamicDictType":
return DynamicDictType(
key_type=DataAnalyzer.analyze(
next(iter(obj.keys())), is_dynamic=True, **kwargs
),
value_type=DataAnalyzer.analyze(
next(iter(obj.values())), is_dynamic=True, **kwargs
),
)

def _check_types(self, obj, exc_type, ignore_key_type: bool = False):
self.check_type(obj, dict, exc_type)

obj_type = self.process(obj)
if ignore_key_type:
obj_types: Union[
Tuple[PrimitiveType, DataType], Tuple[DataType]
] = (obj_type.value_type,)
expected_types: Union[
Tuple[PrimitiveType, DataType], Tuple[DataType]
] = (self.value_type,)
else:
obj_types = (obj_type.key_type, obj_type.value_type)
expected_types = (self.key_type, self.value_type)
if obj_types != expected_types:
raise exc_type(
f"given dict has type: {obj_types}, expected: {expected_types}"
)

def get_requirements(self) -> Requirements:
return sum(
[
self.key_type.get_requirements(),
self.value_type.get_requirements(),
],
Requirements.new(),
)

def get_writer(
self, project: str = None, filename: str = None, **kwargs
) -> "DynamicDictWriter":
return DynamicDictWriter(**kwargs)

def get_model(self, prefix="") -> Type[BaseModel]:
field_type = (
Dict[ # type: ignore
self.key_type.get_serializer().get_model(
prefix + "_key_" # noqa: F821
),
self.value_type.get_serializer().get_model(
prefix + "_val_" # noqa: F821
),
],
...,
)
return create_model(prefix + "DynamicDictType", __root__=field_type) # type: ignore


class DynamicDictWriter(DataWriter):
type: ClassVar[str] = "d_dict"

def write(
self, data: DataType, storage: Storage, path: str
) -> Tuple[DataReader, Artifacts]:
if not isinstance(data, DynamicDictType):
raise ValueError(
f"expected data to be of DynamicDictTypeWriter, got {type(data)} instead"
)
with storage.open(path) as (f, art):
f.write(
json.dumps(data.get_serializer().serialize(data.data)).encode(
"utf-8"
)
)
return DynamicDictReader(data_type=data), {DataWriter.art_name: art}


class DynamicDictReader(DataReader):
type: ClassVar[str] = "d_dict"
data_type: DynamicDictType

def read(self, artifacts: Artifacts) -> DataType:
if DataWriter.art_name not in artifacts:
raise ValueError(
f"Wrong artifacts {artifacts}: should be one {DataWriter.art_name} file"
)
with artifacts[DataWriter.art_name].open() as f:
data = json.load(f)
# json stores keys as strings. Deserialize string keys as well as values.
data = self.data_type.deserialize(data)
return self.data_type.copy().bind(data)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DataType]:
raise NotImplementedError


#
#
# class BytesDataType(DataType):
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Type

import git
import numpy as np
import pandas as pd
import pytest
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -426,3 +427,10 @@ def disable_colorama():
import colorama

colorama.init = lambda: None


@pytest.fixture
def numpy_default_int_dtype():
# default int type is platform dependent.
# For windows 64 it is int32 and for linux 64 it is int64
return str(np.array([1]).dtype)
Loading

0 comments on commit 9c4c650

Please sign in to comment.