From 27554bd347df9f7886c20ec2ad09c2a817e08424 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <ambulemahesh@gmail.com> Date: Mon, 20 Jun 2022 15:53:43 +0530 Subject: [PATCH 1/9] dynamic types support array, dict, ndarray --- mlem/contrib/lightgbm.py | 8 +- mlem/contrib/numpy.py | 29 +++- mlem/core/data_type.py | 201 +++++++++++++++++++++--- tests/contrib/test_numpy.py | 216 ++++++++++++++++++++------ tests/core/test_data_type.py | 288 ++++++++++++++++++++++++++++++----- 5 files changed, 626 insertions(+), 116 deletions(-) diff --git a/mlem/contrib/lightgbm.py b/mlem/contrib/lightgbm.py index b45fad44..17e9b263 100644 --- a/mlem/contrib/lightgbm.py +++ b/mlem/contrib/lightgbm.py @@ -69,8 +69,12 @@ def get_writer( return LightGBMDataWriter(**kwargs) @classmethod - def process(cls, obj: Any, **kwargs) -> DataType: - return LightGBMDataType(inner=DataAnalyzer.analyze(obj.data)) + def process(cls, obj: Any, is_dynamic: bool = False, **kwargs) -> DataType: + return LightGBMDataType( + inner=DataAnalyzer.analyze( + obj.data, is_dynamic=is_dynamic, **kwargs + ) + ) def get_model(self, prefix: str = "") -> Type[BaseModel]: return self.inner.get_serializer().get_model(prefix) diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index e532d34d..232220d8 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -94,7 +94,7 @@ class NumpyNdarrayType( type: ClassVar[str] = "ndarray" libraries: ClassVar[List[ModuleType]] = [np] - shape: Tuple[Optional[int], ...] + shape: Optional[Tuple[Optional[int], ...]] dtype: str @staticmethod @@ -134,7 +134,16 @@ def _subtype(self, subshape: Tuple[Optional[int], ...]): def get_model(self, prefix: str = "") -> Type[BaseModel]: # TODO: https://github.com/iterative/mlem/issues/33 return create_model( - prefix + "NumpyNdarray", __root__=(List[self._subtype(self.shape[1:])], ...) # type: ignore + prefix + "NumpyNdarray", + __root__=( + self._subtype(self.shape) + if self.shape + else List[ + Union[python_type_from_np_string_repr(self.dtype), List] + ], # type: ignore + ..., + ) + # type: ignore ) def serialize(self, instance: np.ndarray): @@ -148,10 +157,20 @@ def serialize(self, instance: np.ndarray): return instance.tolist() def _check_shape(self, array, exc_type): - if tuple(array.shape)[1:] != self.shape[1:]: - raise exc_type( - f"given array is of shape: {(None,) + tuple(array.shape)[1:]}, expected: {self.shape}" + if self.shape: + if len(array.shape) != len(self.shape): + raise exc_type( + f"given array is of rank: {len(array.shape)}, expected: {len(self.shape)}" + ) + + array_shape = tuple( + None if expected_dim is None else array_dim + for array_dim, expected_dim in zip(array.shape, self.shape) ) + if tuple(array_shape) != self.shape: + raise exc_type( + f"given array is of shape: {array_shape}, expected: {self.shape}" + ) def get_writer(self, project: str = None, filename: str = None, **kwargs): return NumpyArrayWriter() diff --git a/mlem/core/data_type.py b/mlem/core/data_type.py index 4356fbf9..00a5094e 100644 --- a/mlem/core/data_type.py +++ b/mlem/core/data_type.py @@ -2,6 +2,7 @@ Base classes for working with data in MLEM """ import builtins +import json import posixpath from abc import ABC, abstractmethod from typing import ( @@ -70,8 +71,10 @@ def bind(self, data: Any): return self @classmethod - def create(cls, obj: Any, **kwargs): - return DataAnalyzer.analyze(obj, **kwargs).bind(obj) + def create(cls, obj: Any, is_dynamic: bool = False, **kwargs): + return DataAnalyzer.analyze(obj, is_dynamic=is_dynamic, **kwargs).bind( + obj + ) class DataSerializer(ABC): @@ -367,7 +370,7 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]: def _check_type_and_size(obj, dtype, size, exc_type): DataType.check_type(obj, dtype, exc_type) - if size != -1 and len(obj) != size: + if size is not None and len(obj) != size: raise exc_type( f"given {dtype.__name__} has len: {len(obj)}, expected: {size}" ) @@ -449,45 +452,83 @@ def is_object_valid(cls, obj: Any) -> bool: return isinstance(obj, (list, tuple)) @classmethod - def process(cls, obj, **kwargs) -> DataType: + def process(cls, obj, is_dynamic: bool = False, **kwargs) -> DataType: if isinstance(obj, tuple): - return TupleType(items=[DataAnalyzer.analyze(o) for o in obj]) + return TupleType( + items=[ + DataAnalyzer.analyze(o, is_dynamic=is_dynamic, **kwargs) + for o in obj + ] + ) py_types = {type(o) for o in obj} if len(obj) <= 1 or len(py_types) > 1: - return ListType(items=[DataAnalyzer.analyze(o) for o in obj]) + return ListType( + items=[ + DataAnalyzer.analyze(o, is_dynamic=is_dynamic, **kwargs) + for o in obj + ] + ) + + size = None if is_dynamic else len(obj) if not py_types.intersection( PrimitiveType.PRIMITIVES ): # py_types is guaranteed to be singleton set here - items_types = [DataAnalyzer.analyze(o) for o in obj] + items_types = [ + DataAnalyzer.analyze(o, is_dynamic=is_dynamic, **kwargs) + for o in obj + ] first, *others = items_types for other in others: if first != other: return ListType(items=items_types) - return ArrayType(dtype=first, size=len(obj)) + return ArrayType(dtype=first, size=size) # optimization for large lists of same primitive type elements - return ArrayType(dtype=DataAnalyzer.analyze(obj[0]), size=len(obj)) - + return ArrayType( + dtype=DataAnalyzer.analyze( + obj[0], is_dynamic=is_dynamic, **kwargs + ), + size=size, + ) -class DictType(DataType, DataSerializer, DataHook): - """ - DataType for dict - """ - - type: ClassVar[str] = "dict" - item_types: Dict[str, DataType] +class DictTypeHook(DataHook): @classmethod def is_object_valid(cls, obj: Any) -> bool: return isinstance(obj, dict) @classmethod - def process(cls, obj: Any, **kwargs) -> "DictType": - return DictType( - item_types={k: DataAnalyzer.analyze(v) for (k, v) in obj.items()} - ) + def process( + cls, obj: Any, is_dynamic: bool = False, **kwargs + ) -> Union["DictType", "DynamicDictType"]: + + if not is_dynamic: + return DictType( + item_types={ + k: DataAnalyzer.analyze(v, is_dynamic=is_dynamic, **kwargs) + for (k, v) in obj.items() + } + ) + else: + return DynamicDictType( + key_type=DataAnalyzer.analyze( + next(iter(obj.keys())), is_dynamic=is_dynamic, **kwargs + ), + value_type=DataAnalyzer.analyze( + next(iter(obj.values())), is_dynamic=is_dynamic, **kwargs + ), + ) + + +class DictType(DataType, DataSerializer): + """ + DataType for dict with fixed set of keys + """ + + type: ClassVar[str] = "dict" + item_types: Dict[str, DataType] def deserialize(self, obj): self._check_type_and_keys(obj, DeserializationError) @@ -577,6 +618,124 @@ def read_batch( raise NotImplementedError +class DynamicDictType(DataType, DataSerializer): + """ + Dynamic DataType for dict without fixed set of keys + """ + + type: ClassVar[str] = "d_dict" + + key_type: DataType + value_type: DataType + + def deserialize(self, obj): + self._check_type_and_keys(obj, DeserializationError) + return { + self.key_type.get_serializer() + .deserialize( + k, + ): self.value_type.get_serializer() + .deserialize( + v, + ) + for k, v in obj.items() + } + + def serialize(self, instance: dict): + self._check_type_and_keys(instance, SerializationError) + if self.key_type == PrimitiveType and self.value_type == PrimitiveType: + return instance + else: + return { + self.key_type.get_serializer() + .serialize( + k, + ): self.value_type.get_serializer() + .serialize( + v, + ) + for k, v in instance.items() + } + + def _check_type_and_keys(self, obj, exc_type): + self.check_type(obj, dict, exc_type) + obj_type = DictTypeHook.process(obj, is_dynamic=True) + obj_types = (obj_type.key_type, obj_type.value_type) + expected_types = (self.key_type, self.value_type) + if obj_types != expected_types: + raise exc_type( + f"given dict has type: {obj_types}, expected: {expected_types}" + ) + + # TODO - should we check for type of all items of dict? + + def get_requirements(self) -> Requirements: + return sum( + [ + self.key_type.get_requirements(), + self.value_type.get_requirements(), + ], + Requirements.new(), + ) + + def get_writer( + self, project: str = None, filename: str = None, **kwargs + ) -> "DynamicDictWriter": + return DynamicDictWriter(**kwargs) + + def get_model(self, prefix="") -> Type[BaseModel]: + field_type = ( + Dict[ # type: ignore + self.key_type.get_serializer().get_model( + prefix + "_key_" # noqa: F821 + ), + self.value_type.get_serializer().get_model( + prefix + "_val_" # noqa: F821 + ), + ], + ..., + ) + return create_model(prefix + "DynamicDictType", __root__=field_type) # type: ignore + + +class DynamicDictWriter(DataWriter): + type: ClassVar[str] = "d_dict" + + def write( + self, data: DataType, storage: Storage, path: str + ) -> Tuple[DataReader, Artifacts]: + if not isinstance(data, DynamicDictType): + raise ValueError( + f"expected data to be of DynamicDictTypeWriter, got {type(data)} instead" + ) + with storage.open(path) as (f, art): + f.write( + json.dumps(data.get_serializer().serialize(data.data)).encode( + "utf-8" + ) + ) + return DynamicDictReader(data_type=data), {DataWriter.art_name: art} + + +class DynamicDictReader(DataReader): + type: ClassVar[str] = "d_dict" + data_type: DynamicDictType + + def read(self, artifacts: Artifacts) -> DataType: + if DataWriter.art_name not in artifacts: + raise ValueError( + f"Wrong artifacts {artifacts}: should be one {DataWriter.art_name} file" + ) + with artifacts[DataWriter.art_name].open() as f: + data = json.load(f) + return self.data_type.copy().bind(data) + + def read_batch( + self, artifacts: Artifacts, batch_size: int + ) -> Iterator[DataType]: + raise NotImplementedError + + # # # class BytesDataType(DataType): diff --git a/tests/contrib/test_numpy.py b/tests/contrib/test_numpy.py index 5b9252a1..127b5135 100644 --- a/tests/contrib/test_numpy.py +++ b/tests/contrib/test_numpy.py @@ -1,8 +1,9 @@ -from json import loads +import re import numpy as np import pytest from pydantic import parse_obj_as +from pytest_lazyfixture import lazy_fixture from mlem.contrib.numpy import ( NumpyNdarrayType, @@ -36,15 +37,120 @@ def custom_assert(x, y): ) -def test_ndarray_source(): - data = np.array([1, 2, 3]) - data_type = DataType.create(data) - data_write_read_check(data_type, custom_eq=np.array_equal) +@pytest.fixture +def nat(): + data = np.array([[1, 2], [3, 4]]) + dtype = DataType.create(data) + payload = {"shape": (None, 2), "dtype": "int64", "type": "ndarray"} + schema = { + "title": "NumpyNdarray", + "type": "array", + "items": { + "type": "array", + "items": {"type": "integer"}, + "minItems": 2, + "maxItems": 2, + }, + } + test_data1 = data + test_data2 = np.array([[10, 20], [30, 40]]) + test_data3 = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + return False, dtype, payload, schema, test_data1, test_data2, test_data3 @pytest.fixture -def nat(): - return DataType.create(np.array([[1, 2], [3, 4]])) +def nat_dynamic(): + dtype = NumpyNdarrayType(shape=[2, None, None], dtype="int") + payload = {"dtype": "int", "shape": (2, None, None), "type": "ndarray"} + schema = { + "items": { + "items": {"items": {"type": "integer"}, "type": "array"}, + "type": "array", + }, + "maxItems": 2, + "minItems": 2, + "title": "NumpyNdarray", + "type": "array", + } + test_data1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + test_data2 = np.array( + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + ) + test_data3 = np.array([[[1, 2, 2], [3, 4, 4]], [[5, 6, 6], [7, 8, 8]]]) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def nat_dynamic_float(): + dtype = NumpyNdarrayType(shape=[2, None, None, 1], dtype="float") + payload = { + "dtype": "float", + "shape": (2, None, None, 1), + "type": "ndarray", + } + schema = { + "items": { + "items": { + "items": { + "items": {"type": "number"}, + "maxItems": 1, + "minItems": 1, + "type": "array", + }, + "type": "array", + }, + "type": "array", + }, + "maxItems": 2, + "minItems": 2, + "title": "NumpyNdarray", + "type": "array", + } + test_data1 = np.array([[[[1.0]], [[3.0]]], [[[5.1]], [[7.1]]]]) + test_data2 = np.array([[[[1.1], [3.0], [5.0]]], [[[7.1], [9.99], [11.2]]]]) + test_data3 = np.array( + [[[[1.1], [3.2]], [[5.33], [7.1]]], [[[1.11], [3.4]], [[5.3], [7.2]]]] + ) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def nat_dynamic_all_none_dims(): + dtype = NumpyNdarrayType(shape=[None, None, None], dtype="int") + payload = {"dtype": "int", "shape": (None, None, None), "type": "ndarray"} + schema = { + "items": { + "items": {"items": {"type": "integer"}, "type": "array"}, + "type": "array", + }, + "title": "NumpyNdarray", + "type": "array", + } + test_data1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + test_data2 = np.array( + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + ) + test_data3 = np.array([[[1, 2, 2], [3, 4, 4]], [[5, 6, 6], [7, 8, 8]]]) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def nat_dynamic_shape_none(): + dtype = NumpyNdarrayType(shape=None, dtype="int") + payload = {"dtype": "int", "type": "ndarray"} + schema = { + "items": { + "anyOf": [{"type": "integer"}, {"items": {}, "type": "array"}] + }, + "title": "NumpyNdarray", + "type": "array", + } + test_data1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + test_data2 = np.array( + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + ) + test_data3 = np.array([[[1, 2, 2], [3, 4, 4]], [[5, 6, 6], [7, 8, 8]]]) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 def test_python_type_from_np_string_repr(): @@ -81,41 +187,73 @@ def test_number(): assert ndt.get_serializer().deserialize(n_payload) == value -def test_ndarray(nat): - value = nat.data +@pytest.mark.parametrize("test_data_idx", [4, 5, 6]) +@pytest.mark.parametrize( + "data", + [ + (lazy_fixture("nat")), + (lazy_fixture("nat_dynamic")), + (lazy_fixture("nat_dynamic_all_none_dims")), + (lazy_fixture("nat_dynamic_shape_none")), + (lazy_fixture("nat_dynamic_float")), + ], +) +def test_ndarray(data, test_data_idx): + nat, payload, schema, value = ( + data[1], + data[2], + data[3], + data[test_data_idx], + ) assert isinstance(nat, NumpyNdarrayType) - assert nat.shape == (None, 2) - assert python_type_from_np_string_repr(nat.dtype) == int assert nat.get_requirements().modules == ["numpy"] - payload = nat.json() - nat2 = parse_obj_as(DataType, loads(payload)) + assert nat.dict() == payload + nat2 = parse_obj_as(DataType, payload) assert nat == nat2 assert nat.get_model().__name__ == nat2.get_model().__name__ - assert nat.get_model().schema() == { - "title": "NumpyNdarray", - "type": "array", - "items": { - "type": "array", - "items": {"type": "integer"}, - "minItems": 2, - "maxItems": 2, - }, - } + assert nat.get_model().schema() == schema n_payload = nat.get_serializer().serialize(value) assert (nat.get_serializer().deserialize(n_payload) == value).all() + model = parse_obj_as(nat.get_model(), n_payload) + assert model.__root__ == n_payload + + nat = nat.bind(value) + data_write_read_check(nat, custom_eq=np.array_equal) @pytest.mark.parametrize( - "obj", + "nddtype,obj,err_msg", [ - {}, # wrong type - np.array([[1, 2], [3, 4]], dtype=np.float32), # wrong data type - np.array([1, 2]), # wrong shape + [ + lazy_fixture("nat"), + {}, + "given data is of type: <class 'dict'>, expected: <class 'numpy.ndarray'>", + ], + [ + lazy_fixture("nat"), + np.array([[1, 2], [3, 4]], dtype=np.float32), + "given array is of type: float32, " "expected: int64", + ], + [ + lazy_fixture("nat"), + np.array([1, 2]), + "given array is of rank: 1, expected: 2", + ], + [ + lazy_fixture("nat_dynamic"), + np.array([1, 2]), + "given array is of rank: 1, " "expected: 3", + ], + [ + lazy_fixture("nat_dynamic_float"), + np.array([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]), + "given array is of shape: (1, None, None, 2), expected: (2, None, None, 1)", + ], ], ) -def test_ndarray_serialize_failure(nat, obj): - with pytest.raises(SerializationError): - nat.serialize(obj) +def test_ndarray_serialize_failure(nddtype, obj, err_msg): + with pytest.raises(SerializationError, match=re.escape(err_msg)): + nddtype[1].serialize(obj) @pytest.mark.parametrize( @@ -124,26 +262,10 @@ def test_ndarray_serialize_failure(nat, obj): ) def test_ndarray_deserialize_failure(nat, obj): with pytest.raises(DeserializationError): - nat.deserialize(obj) + nat[1].deserialize(obj) def test_requirements(): assert get_object_requirements( NumpyNdarrayType(shape=(0,), dtype="int") ).modules == ["numpy"] - - -# Copyright 2019 Zyfra -# Copyright 2021 Iterative -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/core/test_data_type.py b/tests/core/test_data_type.py index f79ddb84..6e34383b 100644 --- a/tests/core/test_data_type.py +++ b/tests/core/test_data_type.py @@ -1,5 +1,8 @@ +import copy + import pytest from pydantic import parse_obj_as +from pytest_lazyfixture import lazy_fixture from mlem.core.data_type import ( ArrayReader, @@ -9,6 +12,8 @@ DataType, DictReader, DictType, + DynamicDictReader, + DynamicDictType, ListType, PrimitiveReader, PrimitiveType, @@ -27,6 +32,32 @@ def test_primitives_not_ok(): assert not PrimitiveType.is_object_valid(NotPrimitive()) +@pytest.fixture +def array(): + is_dynamic = False + array = [1, 2, 3, 4, 5] + payload = { + "dtype": {"ptype": "int", "type": "primitive"}, + "size": 5, + "type": "array", + } + schema = { + "items": {"type": "integer"}, + "title": "Array", + "type": "array", + } + + return is_dynamic, array, payload, schema + + +@pytest.fixture +def array_dynamic(array): + is_dynamic = True + payload = copy.deepcopy(array[2]) + del payload["size"] + return is_dynamic, array[1], payload, array[3] + + @pytest.mark.parametrize("ptype", PrimitiveType.PRIMITIVES) def test_primitive_source(ptype): if ptype is type(None): # noqa: E721 @@ -63,31 +94,39 @@ def test_primitives(ptype): assert dt.get_model() is ptype -def test_array(): - l_value = [1, 2, 3, 4, 5] - dt = DataAnalyzer.analyze(l_value) +@pytest.mark.parametrize( + "array_data,value", + [ + (lazy_fixture("array"), None), + (lazy_fixture("array_dynamic"), None), + (lazy_fixture("array_dynamic"), [1, 2, 3]), + ], +) +def test_array(array_data, value): + dt = DataAnalyzer.analyze(array_data[1], is_dynamic=array_data[0]) + l_value = array_data[1] if value is None else value assert isinstance(dt, ArrayType) - payload = { - "dtype": {"ptype": "int", "type": "primitive"}, - "size": 5, - "type": "array", - } - assert dt.dict() == payload - dt2 = parse_obj_as(ArrayType, payload) + assert dt.dict() == array_data[2] + dt2 = parse_obj_as(ArrayType, array_data[2]) assert dt2 == dt assert l_value == dt.serialize(l_value) assert l_value == dt.deserialize(l_value) assert dt.get_model().__name__ == "Array" - assert dt.get_model().schema() == { - "items": {"type": "integer"}, - "title": "Array", - "type": "array", - } + assert dt.get_model().schema() == array_data[3] -def test_list_source(): - l_value = [1, 2, 3, 4, 5] - dt = DataType.create(l_value) +@pytest.mark.parametrize( + "is_dynamic,array_data,value", + [ + (False, lazy_fixture("array"), None), + (True, lazy_fixture("array_dynamic"), None), + (True, lazy_fixture("array_dynamic"), [1, 2, 3]), + ], +) +def test_list_source(is_dynamic, array_data, value): + dt = DataType.create(array_data[0]) + l_value = array_data[0] if value is None else value + dt.bind(l_value) artifacts = data_write_read_check( dt, @@ -95,11 +134,8 @@ def test_list_source(): ) assert list(artifacts.keys()) == [f"{x}/data" for x in range(len(l_value))] - assert artifacts["0/data"].uri.endswith("data/0") - assert artifacts["1/data"].uri.endswith("data/1") - assert artifacts["2/data"].uri.endswith("data/2") - assert artifacts["3/data"].uri.endswith("data/3") - assert artifacts["4/data"].uri.endswith("data/4") + for x in range(len(l_value)): + assert artifacts[f"{x}/data"].uri.endswith(f"data/{x}") def test_tuple(): @@ -198,10 +234,10 @@ def test_mixed_list_source(): assert artifacts["5/data"].uri.endswith("data/5") -def test_dict(): +@pytest.fixture +def dict_data(): + is_dynamic = False d = {"1": 1, "2": "a"} - dt = DataAnalyzer.analyze(d) - assert isinstance(dt, DictType) payload = { "item_types": { "1": {"ptype": "int", "type": "primitive"}, @@ -209,13 +245,8 @@ def test_dict(): }, "type": "dict", } - assert dt.dict() == payload - dt2 = parse_obj_as(DictType, payload) - assert dt2 == dt - assert d == dt.serialize(d) - assert d == dt.deserialize(d) - assert dt.get_model().__name__ == "DictType" - assert dt.get_model().schema() == { + + schema = { "title": "DictType", "type": "object", "properties": { @@ -225,10 +256,182 @@ def test_dict(): "required": ["1", "2"], } + test_data1 = {"1": 1, "2": "a"} + test_data2 = {"1": 2, "2": "b"} + test_data3 = {"1": 3, "2": "c"} + + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_data(): + is_dynamic = True + d = {"a": 1, "b": 2} + payload = { + "key_type": {"ptype": "str", "type": "primitive"}, + "value_type": {"ptype": "int", "type": "primitive"}, + "type": "d_dict", + } + schema = { + "title": "DynamicDictType", + "type": "object", + "additionalProperties": {"type": "integer"}, + } + + test_data1 = {"a": 1, "b": 2} + test_data2 = {"a": 1} + test_data3 = {"a": 1, "b": 2, "c": 3, "d": 1} + + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_str_val_type_data(): + is_dynamic = True + d = {"a": "1", "b": "2"} + payload = { + "key_type": {"ptype": "str", "type": "primitive"}, + "value_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + } + schema = { + "title": "DynamicDictType", + "type": "object", + "additionalProperties": {"type": "string"}, + } + + test_data1 = {"a": "1", "b": "2"} + test_data2 = {"a": "1"} + test_data3 = {"a": "1", "b": "2", "c": "3", "d": "1"} -def test_dict_source(): - d_value = {"1": 1.5, "2": "a", "3": {"1": False}} - data_type = DataType.create(d_value) + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_array_type(): + is_dynamic = True + d = {"a": [1, 2, 3], "b": [3, 4, 5]} + payload = { + "key_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + "value_type": { + "dtype": {"ptype": "int", "type": "primitive"}, + "type": "array", + }, + } + schema = { + "additionalProperties": {"$ref": "#/definitions/_val_Array"}, + "definitions": { + "_val_Array": { + "items": {"type": "integer"}, + "title": "_val_Array", + "type": "array", + } + }, + "title": "DynamicDictType", + "type": "object", + } + + test_data1 = {"a": [1, 2, 3], "b": [3, 4, 5]} + test_data2 = {"a": [1, 2, 3]} + test_data3 = {"a": [1, 2, 3], "b": [3, 4, 5], "d": [6, 7, 8]} + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_dict_type(): + is_dynamic = True + d = {"a": {"l": [1, 2]}, "b": {"l": [3, 4]}} + payload = { + "key_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + "value_type": { + "key_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + "value_type": { + "dtype": {"ptype": "int", "type": "primitive"}, + "type": "array", + }, + }, + } + schema = { + "additionalProperties": {"$ref": "#/definitions/_val_DynamicDictType"}, + "definitions": { + "_val_DynamicDictType": { + "additionalProperties": { + "$ref": "#/definitions/_val__val_Array" + }, + "title": "_val_DynamicDictType", + "type": "object", + }, + "_val__val_Array": { + "items": {"type": "integer"}, + "title": "_val__val_Array", + "type": "array", + }, + }, + "title": "DynamicDictType", + "type": "object", + } + test_data1 = {"a": {"l": [1, 2]}, "b": {"l": [3, 4]}} + test_data2 = {"a": {"l": [1, 2]}} + test_data3 = {"a": {"l": [1, 2]}, "b": {"l": [3, 4]}, "c": {"k": [3, 4]}} + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.mark.parametrize("test_data_idx", [4, 5, 6]) +@pytest.mark.parametrize( + "data", + [ + lazy_fixture("dict_data"), + lazy_fixture("dynamic_dict_data"), + lazy_fixture("dynamic_dict_str_val_type_data"), + lazy_fixture("dynamic_dict_array_type"), + lazy_fixture("dynamic_dict_dict_type"), + ], +) +def test_dict(data, test_data_idx): + is_dynamic, d, payload, schema, test_data = ( + data[0], + data[1], + data[2], + data[3], + data[test_data_idx], + ) + dt = DataAnalyzer.analyze(d, is_dynamic=is_dynamic) + dtype = DictType if not is_dynamic else DynamicDictType + assert isinstance(dt, dtype) + + assert dt.dict() == payload + dt2 = parse_obj_as(dtype, payload) + assert dt2 == dt + assert test_data == dt.serialize(test_data) + assert test_data == dt.deserialize(test_data) + assert dt.get_model().__name__ == dtype.__name__ + assert dt.get_model().schema() == schema + assert parse_obj_as(dt.get_model(), test_data) + + +@pytest.mark.parametrize("test_data_idx", [4, 5, 6]) +@pytest.mark.parametrize( + "data", + [ + lazy_fixture("dict_data"), + lazy_fixture("dynamic_dict_data"), + lazy_fixture("dynamic_dict_str_val_type_data"), + lazy_fixture("dynamic_dict_array_type"), + lazy_fixture("dynamic_dict_dict_type"), + ], +) +def test_dict_source(data, test_data_idx): + is_dynamic, d, test_data = ( + data[0], + data[1], + data[test_data_idx], + ) + data_type = DataType.create(d, is_dynamic=is_dynamic) + data_type = data_type.bind(test_data) + dtype_reader = DynamicDictReader if is_dynamic else DictReader def custom_assert(x, y): assert x == y @@ -238,11 +441,14 @@ def custom_assert(x, y): artifacts = data_write_read_check( data_type, - reader_type=DictReader, + reader_type=dtype_reader, custom_assert=custom_assert, ) - assert list(artifacts.keys()) == ["1/data", "2/data", "3/1/data"] - assert artifacts["1/data"].uri.endswith("data/1") - assert artifacts["2/data"].uri.endswith("data/2") - assert artifacts["3/1/data"].uri.endswith("data/3/1") + if not is_dynamic: + assert list(artifacts.keys()) == ["1/data", "2/data"] + assert artifacts["1/data"].uri.endswith("data/1") + assert artifacts["2/data"].uri.endswith("data/2") + else: + assert list(artifacts.keys()) == ["data"] + assert artifacts["data"].uri.endswith("data") From c6e8d72595010de0fe25a24df8b19f58219fe342 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <ambulemahesh@gmail.com> Date: Mon, 20 Jun 2022 16:23:23 +0530 Subject: [PATCH 2/9] dynamic types support array, dict, ndarray --- mlem/contrib/lightgbm.py | 8 ++------ tests/core/test_data_type.py | 14 +++++++------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/mlem/contrib/lightgbm.py b/mlem/contrib/lightgbm.py index 17e9b263..b45fad44 100644 --- a/mlem/contrib/lightgbm.py +++ b/mlem/contrib/lightgbm.py @@ -69,12 +69,8 @@ def get_writer( return LightGBMDataWriter(**kwargs) @classmethod - def process(cls, obj: Any, is_dynamic: bool = False, **kwargs) -> DataType: - return LightGBMDataType( - inner=DataAnalyzer.analyze( - obj.data, is_dynamic=is_dynamic, **kwargs - ) - ) + def process(cls, obj: Any, **kwargs) -> DataType: + return LightGBMDataType(inner=DataAnalyzer.analyze(obj.data)) def get_model(self, prefix: str = "") -> Type[BaseModel]: return self.inner.get_serializer().get_model(prefix) diff --git a/tests/core/test_data_type.py b/tests/core/test_data_type.py index 6e34383b..f4d1c06d 100644 --- a/tests/core/test_data_type.py +++ b/tests/core/test_data_type.py @@ -116,16 +116,16 @@ def test_array(array_data, value): @pytest.mark.parametrize( - "is_dynamic,array_data,value", + "array_data,value", [ - (False, lazy_fixture("array"), None), - (True, lazy_fixture("array_dynamic"), None), - (True, lazy_fixture("array_dynamic"), [1, 2, 3]), + (lazy_fixture("array"), None), + (lazy_fixture("array_dynamic"), None), + (lazy_fixture("array_dynamic"), [1, 2, 3]), ], ) -def test_list_source(is_dynamic, array_data, value): - dt = DataType.create(array_data[0]) - l_value = array_data[0] if value is None else value +def test_list_source(array_data, value): + dt = DataType.create(array_data[1]) + l_value = array_data[1] if value is None else value dt.bind(l_value) artifacts = data_write_read_check( From fe53e93c0a18830bf67041d770bf1e3bb9d7dfb2 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <ambulemahesh@gmail.com> Date: Mon, 20 Jun 2022 21:27:19 +0530 Subject: [PATCH 3/9] dynamic types support array, dict, ndarray --- mlem/core/data_type.py | 62 +++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/mlem/core/data_type.py b/mlem/core/data_type.py index 00a5094e..1304b55e 100644 --- a/mlem/core/data_type.py +++ b/mlem/core/data_type.py @@ -503,23 +503,9 @@ def is_object_valid(cls, obj: Any) -> bool: def process( cls, obj: Any, is_dynamic: bool = False, **kwargs ) -> Union["DictType", "DynamicDictType"]: - if not is_dynamic: - return DictType( - item_types={ - k: DataAnalyzer.analyze(v, is_dynamic=is_dynamic, **kwargs) - for (k, v) in obj.items() - } - ) - else: - return DynamicDictType( - key_type=DataAnalyzer.analyze( - next(iter(obj.keys())), is_dynamic=is_dynamic, **kwargs - ), - value_type=DataAnalyzer.analyze( - next(iter(obj.values())), is_dynamic=is_dynamic, **kwargs - ), - ) + return DictType.analyze(obj, **kwargs) + return DynamicDictType.analyze(obj, **kwargs) class DictType(DataType, DataSerializer): @@ -530,6 +516,15 @@ class DictType(DataType, DataSerializer): type: ClassVar[str] = "dict" item_types: Dict[str, DataType] + @classmethod + def analyze(cls, obj, **kwargs): + return DictType( + item_types={ + k: DataAnalyzer.analyze(v, is_dynamic=False, **kwargs) + for (k, v) in obj.items() + } + ) + def deserialize(self, obj): self._check_type_and_keys(obj, DeserializationError) return { @@ -645,21 +640,32 @@ def serialize(self, instance: dict): self._check_type_and_keys(instance, SerializationError) if self.key_type == PrimitiveType and self.value_type == PrimitiveType: return instance - else: - return { - self.key_type.get_serializer() - .serialize( - k, - ): self.value_type.get_serializer() - .serialize( - v, - ) - for k, v in instance.items() - } + + return { + self.key_type.get_serializer() + .serialize( + k, + ): self.value_type.get_serializer() + .serialize( + v, + ) + for k, v in instance.items() + } + + @classmethod + def analyze(cls, obj, **kwargs): + return DynamicDictType( + key_type=DataAnalyzer.analyze( + next(iter(obj.keys())), is_dynamic=True, **kwargs + ), + value_type=DataAnalyzer.analyze( + next(iter(obj.values())), is_dynamic=True, **kwargs + ), + ) def _check_type_and_keys(self, obj, exc_type): self.check_type(obj, dict, exc_type) - obj_type = DictTypeHook.process(obj, is_dynamic=True) + obj_type: DynamicDictType = self.analyze(obj) obj_types = (obj_type.key_type, obj_type.value_type) expected_types = (self.key_type, self.value_type) if obj_types != expected_types: From 07fcd8180e6ecab73ad9f3daaecadbe95173939f Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <ambulemahesh@gmail.com> Date: Mon, 20 Jun 2022 22:24:07 +0530 Subject: [PATCH 4/9] fix lint issue --- tests/contrib/test_numpy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/contrib/test_numpy.py b/tests/contrib/test_numpy.py index 127b5135..511a3b3d 100644 --- a/tests/contrib/test_numpy.py +++ b/tests/contrib/test_numpy.py @@ -232,7 +232,7 @@ def test_ndarray(data, test_data_idx): [ lazy_fixture("nat"), np.array([[1, 2], [3, 4]], dtype=np.float32), - "given array is of type: float32, " "expected: int64", + "given array is of type: float32, expected: int64", ], [ lazy_fixture("nat"), @@ -242,7 +242,7 @@ def test_ndarray(data, test_data_idx): [ lazy_fixture("nat_dynamic"), np.array([1, 2]), - "given array is of rank: 1, " "expected: 3", + "given array is of rank: 1, expected: 3", ], [ lazy_fixture("nat_dynamic_float"), From 4095937281e6816de099ee01338c30e112635d18 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <ambulemahesh@gmail.com> Date: Wed, 22 Jun 2022 02:45:27 +0530 Subject: [PATCH 5/9] incorporated review comments --- mlem/contrib/numpy.py | 4 +- mlem/core/data_type.py | 49 +++++++++++++------ tests/contrib/test_numpy.py | 20 ++++++++ tests/core/test_data_type.py | 95 ++++++++++++++++++++++++++++++++++-- 4 files changed, 146 insertions(+), 22 deletions(-) diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 232220d8..c944c671 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -137,7 +137,7 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]: prefix + "NumpyNdarray", __root__=( self._subtype(self.shape) - if self.shape + if self.shape is not None else List[ Union[python_type_from_np_string_repr(self.dtype), List] ], # type: ignore @@ -157,7 +157,7 @@ def serialize(self, instance: np.ndarray): return instance.tolist() def _check_shape(self, array, exc_type): - if self.shape: + if self.shape is not None: if len(array.shape) != len(self.shape): raise exc_type( f"given array is of rank: {len(array.shape)}, expected: {len(self.shape)}" diff --git a/mlem/core/data_type.py b/mlem/core/data_type.py index 1304b55e..13d76067 100644 --- a/mlem/core/data_type.py +++ b/mlem/core/data_type.py @@ -19,7 +19,7 @@ ) import flatdict -from pydantic import BaseModel +from pydantic import BaseModel, validator from pydantic.main import create_model from mlem.core.artifacts import Artifacts, Storage @@ -504,8 +504,8 @@ def process( cls, obj: Any, is_dynamic: bool = False, **kwargs ) -> Union["DictType", "DynamicDictType"]: if not is_dynamic: - return DictType.analyze(obj, **kwargs) - return DynamicDictType.analyze(obj, **kwargs) + return DictType.create(obj, **kwargs) + return DynamicDictType.create(obj, **kwargs) class DictType(DataType, DataSerializer): @@ -517,7 +517,7 @@ class DictType(DataType, DataSerializer): item_types: Dict[str, DataType] @classmethod - def analyze(cls, obj, **kwargs): + def create(cls, obj, **kwargs): return DictType( item_types={ k: DataAnalyzer.analyze(v, is_dynamic=False, **kwargs) @@ -620,11 +620,19 @@ class DynamicDictType(DataType, DataSerializer): type: ClassVar[str] = "d_dict" - key_type: DataType + key_type: PrimitiveType value_type: DataType + @validator("key_type") + def is_valid_key_type( # pylint: disable=no-self-argument + cls, key_type # noqa: B902 + ): + if key_type.ptype not in ["str", "int", "float"]: + raise ValueError(f"key_type {key_type.ptype} is not supported") + return key_type + def deserialize(self, obj): - self._check_type_and_keys(obj, DeserializationError) + self.check_type(obj, dict, DeserializationError) return { self.key_type.get_serializer() .deserialize( @@ -637,9 +645,7 @@ def deserialize(self, obj): } def serialize(self, instance: dict): - self._check_type_and_keys(instance, SerializationError) - if self.key_type == PrimitiveType and self.value_type == PrimitiveType: - return instance + self._check_types(instance, SerializationError) return { self.key_type.get_serializer() @@ -653,7 +659,9 @@ def serialize(self, instance: dict): } @classmethod - def analyze(cls, obj, **kwargs): + def create( + cls, obj, is_dynamic: bool = True, **kwargs + ) -> "DynamicDictType": return DynamicDictType( key_type=DataAnalyzer.analyze( next(iter(obj.keys())), is_dynamic=True, **kwargs @@ -663,18 +671,25 @@ def analyze(cls, obj, **kwargs): ), ) - def _check_type_and_keys(self, obj, exc_type): + def _check_types(self, obj, exc_type, ignore_key_type: bool = False): self.check_type(obj, dict, exc_type) - obj_type: DynamicDictType = self.analyze(obj) - obj_types = (obj_type.key_type, obj_type.value_type) - expected_types = (self.key_type, self.value_type) + + obj_type = self.create(obj) + if ignore_key_type: + obj_types: Union[ + Tuple[PrimitiveType, DataType], Tuple[DataType] + ] = (obj_type.value_type,) + expected_types: Union[ + Tuple[PrimitiveType, DataType], Tuple[DataType] + ] = (self.value_type,) + else: + obj_types = (obj_type.key_type, obj_type.value_type) + expected_types = (self.key_type, self.value_type) if obj_types != expected_types: raise exc_type( f"given dict has type: {obj_types}, expected: {expected_types}" ) - # TODO - should we check for type of all items of dict? - def get_requirements(self) -> Requirements: return sum( [ @@ -734,6 +749,8 @@ def read(self, artifacts: Artifacts) -> DataType: ) with artifacts[DataWriter.art_name].open() as f: data = json.load(f) + # json stores keys as strings. Deserialize string keys as well as values. + data = self.data_type.deserialize(data) return self.data_type.copy().bind(data) def read_batch( diff --git a/tests/contrib/test_numpy.py b/tests/contrib/test_numpy.py index 511a3b3d..ce669ea5 100644 --- a/tests/contrib/test_numpy.py +++ b/tests/contrib/test_numpy.py @@ -153,6 +153,20 @@ def nat_dynamic_shape_none(): return True, dtype, payload, schema, test_data1, test_data2, test_data3 +@pytest.fixture +def nat_shape_empty(): + dtype = NumpyNdarrayType(shape=(), dtype="int") + payload = {"dtype": "int", "shape": (), "type": "ndarray"} + schema = { + "title": "NumpyNdarray", + "type": "integer", + } + test_data1 = np.array(1) + test_data2 = np.array(3) + test_data3 = np.array(4) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 + + def test_python_type_from_np_string_repr(): assert python_type_from_np_string_repr("int64") == int @@ -196,6 +210,7 @@ def test_number(): (lazy_fixture("nat_dynamic_all_none_dims")), (lazy_fixture("nat_dynamic_shape_none")), (lazy_fixture("nat_dynamic_float")), + (lazy_fixture("nat_shape_empty")), ], ) def test_ndarray(data, test_data_idx): @@ -244,6 +259,11 @@ def test_ndarray(data, test_data_idx): np.array([1, 2]), "given array is of rank: 1, expected: 3", ], + [ + lazy_fixture("nat_empty_shape"), + np.array([1, 2]), + "given array is of rank: 1, expected: 0", + ], [ lazy_fixture("nat_dynamic_float"), np.array([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]), diff --git a/tests/core/test_data_type.py b/tests/core/test_data_type.py index f4d1c06d..a94f6054 100644 --- a/tests/core/test_data_type.py +++ b/tests/core/test_data_type.py @@ -1,5 +1,6 @@ import copy +import numpy as np import pytest from pydantic import parse_obj_as from pytest_lazyfixture import lazy_fixture @@ -307,6 +308,50 @@ def dynamic_dict_str_val_type_data(): return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 +@pytest.fixture +def dynamic_dict_int_key_type_data(): + is_dynamic = True + d = {1: "1", 2: "2"} + payload = { + "key_type": {"ptype": "int", "type": "primitive"}, + "value_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + } + schema = { + "title": "DynamicDictType", + "type": "object", + "additionalProperties": {"type": "string"}, + } + + test_data1 = {1: "1", 2: "2"} + test_data2 = {2: "1"} + test_data3 = {3: "1", 4: "2", 5: "3", 6: "1"} + + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_float_key_type_data(): + is_dynamic = True + d = {1.0: "1", 2.0: "2"} + payload = { + "key_type": {"ptype": "float", "type": "primitive"}, + "value_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + } + schema = { + "title": "DynamicDictType", + "type": "object", + "additionalProperties": {"type": "string"}, + } + + test_data1 = {1.0: "1", 2.0: "2"} + test_data2 = {2.9999999999: "1"} + test_data3 = {3.8998: "1", 4.0001: "2", 5.2: "3", 6.9: "1"} + + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + @pytest.fixture def dynamic_dict_array_type(): is_dynamic = True @@ -379,6 +424,41 @@ def dynamic_dict_dict_type(): return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 +@pytest.fixture +def dynamic_dict_ndarray_type(): + is_dynamic = True + d = {"a": np.array([1, 2]), "b": np.array([3, 4])} + payload = { + "key_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + "value_type": {"dtype": "int64", "shape": (None,), "type": "ndarray"}, + } + + schema = { + "additionalProperties": {"$ref": "#/definitions/_val_NumpyNdarray"}, + "definitions": { + "_val_NumpyNdarray": { + "items": {"type": "integer"}, + "title": "_val_NumpyNdarray", + "type": "array", + } + }, + "title": "DynamicDictType", + "type": "object", + } + + test_data1 = {"a": np.array([1, 2]), "b": np.array([3, 4])} + test_data2 = { + "a": np.array([1, 2]), + } + test_data3 = { + "a": np.array([1, 2]), + "b": np.array([3, 4]), + "c": np.array([5, 6]), + } + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + @pytest.mark.parametrize("test_data_idx", [4, 5, 6]) @pytest.mark.parametrize( "data", @@ -386,8 +466,11 @@ def dynamic_dict_dict_type(): lazy_fixture("dict_data"), lazy_fixture("dynamic_dict_data"), lazy_fixture("dynamic_dict_str_val_type_data"), + lazy_fixture("dynamic_dict_int_key_type_data"), + lazy_fixture("dynamic_dict_float_key_type_data"), lazy_fixture("dynamic_dict_array_type"), lazy_fixture("dynamic_dict_dict_type"), + lazy_fixture("dynamic_dict_ndarray_type"), ], ) def test_dict(data, test_data_idx): @@ -405,11 +488,12 @@ def test_dict(data, test_data_idx): assert dt.dict() == payload dt2 = parse_obj_as(dtype, payload) assert dt2 == dt - assert test_data == dt.serialize(test_data) - assert test_data == dt.deserialize(test_data) + serialised_test_data = dt.serialize(test_data) + deserialised_test_data = dt.deserialize(serialised_test_data) + assert serialised_test_data == dt.serialize(deserialised_test_data) assert dt.get_model().__name__ == dtype.__name__ assert dt.get_model().schema() == schema - assert parse_obj_as(dt.get_model(), test_data) + assert parse_obj_as(dt.get_model(), serialised_test_data) @pytest.mark.parametrize("test_data_idx", [4, 5, 6]) @@ -419,8 +503,11 @@ def test_dict(data, test_data_idx): lazy_fixture("dict_data"), lazy_fixture("dynamic_dict_data"), lazy_fixture("dynamic_dict_str_val_type_data"), + lazy_fixture("dynamic_dict_int_key_type_data"), + lazy_fixture("dynamic_dict_float_key_type_data"), lazy_fixture("dynamic_dict_array_type"), lazy_fixture("dynamic_dict_dict_type"), + lazy_fixture("dynamic_dict_ndarray_type"), ], ) def test_dict_source(data, test_data_idx): @@ -434,7 +521,7 @@ def test_dict_source(data, test_data_idx): dtype_reader = DynamicDictReader if is_dynamic else DictReader def custom_assert(x, y): - assert x == y + np.testing.assert_equal(x, y) assert len(x) == len(y) assert isinstance(x, dict) assert isinstance(y, dict) From 4605d0f1221797f4306dfd865d54023b08a8a642 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <ambulemahesh@gmail.com> Date: Sat, 25 Jun 2022 01:45:13 +0530 Subject: [PATCH 6/9] incorporated review comments --- mlem/core/data_type.py | 12 +++++------ tests/core/test_data_type.py | 42 +++++++++++++++++++++++------------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/mlem/core/data_type.py b/mlem/core/data_type.py index 13d76067..2b19fac7 100644 --- a/mlem/core/data_type.py +++ b/mlem/core/data_type.py @@ -504,8 +504,8 @@ def process( cls, obj: Any, is_dynamic: bool = False, **kwargs ) -> Union["DictType", "DynamicDictType"]: if not is_dynamic: - return DictType.create(obj, **kwargs) - return DynamicDictType.create(obj, **kwargs) + return DictType.process(obj, **kwargs) + return DynamicDictType.process(obj, **kwargs) class DictType(DataType, DataSerializer): @@ -517,7 +517,7 @@ class DictType(DataType, DataSerializer): item_types: Dict[str, DataType] @classmethod - def create(cls, obj, **kwargs): + def process(cls, obj, **kwargs): return DictType( item_types={ k: DataAnalyzer.analyze(v, is_dynamic=False, **kwargs) @@ -659,9 +659,7 @@ def serialize(self, instance: dict): } @classmethod - def create( - cls, obj, is_dynamic: bool = True, **kwargs - ) -> "DynamicDictType": + def process(cls, obj, **kwargs) -> "DynamicDictType": return DynamicDictType( key_type=DataAnalyzer.analyze( next(iter(obj.keys())), is_dynamic=True, **kwargs @@ -674,7 +672,7 @@ def create( def _check_types(self, obj, exc_type, ignore_key_type: bool = False): self.check_type(obj, dict, exc_type) - obj_type = self.create(obj) + obj_type = self.process(obj) if ignore_key_type: obj_types: Union[ Tuple[PrimitiveType, DataType], Tuple[DataType] diff --git a/tests/core/test_data_type.py b/tests/core/test_data_type.py index a94f6054..9988c7fd 100644 --- a/tests/core/test_data_type.py +++ b/tests/core/test_data_type.py @@ -427,34 +427,46 @@ def dynamic_dict_dict_type(): @pytest.fixture def dynamic_dict_ndarray_type(): is_dynamic = True - d = {"a": np.array([1, 2]), "b": np.array([3, 4])} + d = {11: {1: np.array([1, 2])}, 22: {2: np.array([3, 4])}} payload = { - "key_type": {"ptype": "str", "type": "primitive"}, + "key_type": {"ptype": "int", "type": "primitive"}, "type": "d_dict", - "value_type": {"dtype": "int64", "shape": (None,), "type": "ndarray"}, + "value_type": { + "key_type": {"ptype": "int", "type": "primitive"}, + "type": "d_dict", + "value_type": { + "dtype": "int64", + "shape": (None,), + "type": "ndarray", + }, + }, } - schema = { - "additionalProperties": {"$ref": "#/definitions/_val_NumpyNdarray"}, + "additionalProperties": {"$ref": "#/definitions/_val_DynamicDictType"}, "definitions": { - "_val_NumpyNdarray": { + "_val_DynamicDictType": { + "additionalProperties": { + "$ref": "#/definitions/_val__val_NumpyNdarray" + }, + "title": "_val_DynamicDictType", + "type": "object", + }, + "_val__val_NumpyNdarray": { "items": {"type": "integer"}, - "title": "_val_NumpyNdarray", + "title": "_val__val_NumpyNdarray", "type": "array", - } + }, }, "title": "DynamicDictType", "type": "object", } - test_data1 = {"a": np.array([1, 2]), "b": np.array([3, 4])} - test_data2 = { - "a": np.array([1, 2]), - } + test_data1 = {11: {1: np.array([1, 2])}, 22: {2: np.array([3, 4])}} + test_data2 = {11: {1: np.array([1, 2])}} test_data3 = { - "a": np.array([1, 2]), - "b": np.array([3, 4]), - "c": np.array([5, 6]), + 11: {1: np.array([1, 2])}, + 22: {2: np.array([3, 4])}, + 33: {2: np.array([5, 6])}, } return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 From 11951777043b7dbc80c720b0bd43f32cf22e1bb5 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <ambulemahesh@gmail.com> Date: Sat, 25 Jun 2022 01:51:07 +0530 Subject: [PATCH 7/9] reverted licence --- tests/contrib/test_numpy.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/contrib/test_numpy.py b/tests/contrib/test_numpy.py index ce669ea5..d05c2daf 100644 --- a/tests/contrib/test_numpy.py +++ b/tests/contrib/test_numpy.py @@ -289,3 +289,19 @@ def test_requirements(): assert get_object_requirements( NumpyNdarrayType(shape=(0,), dtype="int") ).modules == ["numpy"] + + +# Copyright 2019 Zyfra +# Copyright 2021 Iterative +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From a71d8fe04d3f137c8d4cb900610d2a87e316c73d Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <ambulemahesh@gmail.com> Date: Tue, 28 Jun 2022 17:56:43 +0530 Subject: [PATCH 8/9] correct fixture name --- tests/contrib/test_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/contrib/test_numpy.py b/tests/contrib/test_numpy.py index d05c2daf..d12b4735 100644 --- a/tests/contrib/test_numpy.py +++ b/tests/contrib/test_numpy.py @@ -260,7 +260,7 @@ def test_ndarray(data, test_data_idx): "given array is of rank: 1, expected: 3", ], [ - lazy_fixture("nat_empty_shape"), + lazy_fixture("nat_shape_empty"), np.array([1, 2]), "given array is of rank: 1, expected: 0", ], From 40677def2381fcc2a9f82153e05168579f39a14a Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <ambulemahesh@gmail.com> Date: Wed, 29 Jun 2022 19:25:10 +0530 Subject: [PATCH 9/9] fix windows numpy int dtype issue --- tests/conftest.py | 8 ++++++++ tests/contrib/test_numpy.py | 10 +++++++--- tests/core/test_data_type.py | 4 ++-- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 014f3d2f..305662e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Type import git +import numpy as np import pandas as pd import pytest from fastapi.testclient import TestClient @@ -426,3 +427,10 @@ def disable_colorama(): import colorama colorama.init = lambda: None + + +@pytest.fixture +def numpy_default_int_dtype(): + # default int type is platform dependent. + # For windows 64 it is int32 and for linux 64 it is int64 + return str(np.array([1]).dtype) diff --git a/tests/contrib/test_numpy.py b/tests/contrib/test_numpy.py index d12b4735..4a66df95 100644 --- a/tests/contrib/test_numpy.py +++ b/tests/contrib/test_numpy.py @@ -38,10 +38,14 @@ def custom_assert(x, y): @pytest.fixture -def nat(): +def nat(numpy_default_int_dtype): data = np.array([[1, 2], [3, 4]]) dtype = DataType.create(data) - payload = {"shape": (None, 2), "dtype": "int64", "type": "ndarray"} + payload = { + "shape": (None, 2), + "dtype": numpy_default_int_dtype, + "type": "ndarray", + } schema = { "title": "NumpyNdarray", "type": "array", @@ -247,7 +251,7 @@ def test_ndarray(data, test_data_idx): [ lazy_fixture("nat"), np.array([[1, 2], [3, 4]], dtype=np.float32), - "given array is of type: float32, expected: int64", + f"given array is of type: float32, expected: {np.array([[1, 2], [3, 4]]).dtype}", ], [ lazy_fixture("nat"), diff --git a/tests/core/test_data_type.py b/tests/core/test_data_type.py index 9988c7fd..1a7e3740 100644 --- a/tests/core/test_data_type.py +++ b/tests/core/test_data_type.py @@ -425,7 +425,7 @@ def dynamic_dict_dict_type(): @pytest.fixture -def dynamic_dict_ndarray_type(): +def dynamic_dict_ndarray_type(numpy_default_int_dtype): is_dynamic = True d = {11: {1: np.array([1, 2])}, 22: {2: np.array([3, 4])}} payload = { @@ -435,7 +435,7 @@ def dynamic_dict_ndarray_type(): "key_type": {"ptype": "int", "type": "primitive"}, "type": "d_dict", "value_type": { - "dtype": "int64", + "dtype": numpy_default_int_dtype, "shape": (None,), "type": "ndarray", },