From 9c4c65073721457bd7e0443c5f3044f38646b840 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule Date: Wed, 29 Jun 2022 20:33:10 +0530 Subject: [PATCH] dynamic types support array, dict, ndarray (#304) * dynamic types support array, dict, ndarray * dynamic types support array, dict, ndarray * dynamic types support array, dict, ndarray * fix lint issue * incorporated review comments * incorporated review comments * reverted licence * correct fixture name * fix windows numpy int dtype issue Co-authored-by: Alexander Guschin <1aguschin@gmail.com> --- mlem/contrib/numpy.py | 29 ++- mlem/core/data_type.py | 216 +++++++++++++++++-- tests/conftest.py | 8 + tests/contrib/test_numpy.py | 224 +++++++++++++++++--- tests/core/test_data_type.py | 389 +++++++++++++++++++++++++++++++---- 5 files changed, 770 insertions(+), 96 deletions(-) diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 63169805..6dbdd697 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -94,7 +94,7 @@ class NumpyNdarrayType( type: ClassVar[str] = "ndarray" libraries: ClassVar[List[ModuleType]] = [np] - shape: Tuple[Optional[int], ...] + shape: Optional[Tuple[Optional[int], ...]] dtype: str @staticmethod @@ -133,7 +133,16 @@ def _subtype(self, subshape: Tuple[Optional[int], ...]): def get_model(self, prefix: str = "") -> Type[BaseModel]: return create_model( - prefix + "NumpyNdarray", __root__=(List[self._subtype(self.shape[1:])], ...) # type: ignore + prefix + "NumpyNdarray", + __root__=( + self._subtype(self.shape) + if self.shape is not None + else List[ + Union[python_type_from_np_string_repr(self.dtype), List] + ], # type: ignore + ..., + ) + # type: ignore ) def serialize(self, instance: np.ndarray): @@ -147,10 +156,20 @@ def serialize(self, instance: np.ndarray): return instance.tolist() def _check_shape(self, array, exc_type): - if tuple(array.shape)[1:] != self.shape[1:]: - raise exc_type( - f"given array is of shape: {(None,) + tuple(array.shape)[1:]}, expected: {self.shape}" + if self.shape is not None: + if len(array.shape) != len(self.shape): + raise exc_type( + f"given array is of rank: {len(array.shape)}, expected: {len(self.shape)}" + ) + + array_shape = tuple( + None if expected_dim is None else array_dim + for array_dim, expected_dim in zip(array.shape, self.shape) ) + if tuple(array_shape) != self.shape: + raise exc_type( + f"given array is of shape: {array_shape}, expected: {self.shape}" + ) def get_writer(self, project: str = None, filename: str = None, **kwargs): return NumpyArrayWriter() diff --git a/mlem/core/data_type.py b/mlem/core/data_type.py index 4356fbf9..2b19fac7 100644 --- a/mlem/core/data_type.py +++ b/mlem/core/data_type.py @@ -2,6 +2,7 @@ Base classes for working with data in MLEM """ import builtins +import json import posixpath from abc import ABC, abstractmethod from typing import ( @@ -18,7 +19,7 @@ ) import flatdict -from pydantic import BaseModel +from pydantic import BaseModel, validator from pydantic.main import create_model from mlem.core.artifacts import Artifacts, Storage @@ -70,8 +71,10 @@ def bind(self, data: Any): return self @classmethod - def create(cls, obj: Any, **kwargs): - return DataAnalyzer.analyze(obj, **kwargs).bind(obj) + def create(cls, obj: Any, is_dynamic: bool = False, **kwargs): + return DataAnalyzer.analyze(obj, is_dynamic=is_dynamic, **kwargs).bind( + obj + ) class DataSerializer(ABC): @@ -367,7 +370,7 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]: def _check_type_and_size(obj, dtype, size, exc_type): DataType.check_type(obj, 
dtype, exc_type) - if size != -1 and len(obj) != size: + if size is not None and len(obj) != size: raise exc_type( f"given {dtype.__name__} has len: {len(obj)}, expected: {size}" ) @@ -449,44 +452,77 @@ def is_object_valid(cls, obj: Any) -> bool: return isinstance(obj, (list, tuple)) @classmethod - def process(cls, obj, **kwargs) -> DataType: + def process(cls, obj, is_dynamic: bool = False, **kwargs) -> DataType: if isinstance(obj, tuple): - return TupleType(items=[DataAnalyzer.analyze(o) for o in obj]) + return TupleType( + items=[ + DataAnalyzer.analyze(o, is_dynamic=is_dynamic, **kwargs) + for o in obj + ] + ) py_types = {type(o) for o in obj} if len(obj) <= 1 or len(py_types) > 1: - return ListType(items=[DataAnalyzer.analyze(o) for o in obj]) + return ListType( + items=[ + DataAnalyzer.analyze(o, is_dynamic=is_dynamic, **kwargs) + for o in obj + ] + ) + + size = None if is_dynamic else len(obj) if not py_types.intersection( PrimitiveType.PRIMITIVES ): # py_types is guaranteed to be singleton set here - items_types = [DataAnalyzer.analyze(o) for o in obj] + items_types = [ + DataAnalyzer.analyze(o, is_dynamic=is_dynamic, **kwargs) + for o in obj + ] first, *others = items_types for other in others: if first != other: return ListType(items=items_types) - return ArrayType(dtype=first, size=len(obj)) + return ArrayType(dtype=first, size=size) # optimization for large lists of same primitive type elements - return ArrayType(dtype=DataAnalyzer.analyze(obj[0]), size=len(obj)) + return ArrayType( + dtype=DataAnalyzer.analyze( + obj[0], is_dynamic=is_dynamic, **kwargs + ), + size=size, + ) + + +class DictTypeHook(DataHook): + @classmethod + def is_object_valid(cls, obj: Any) -> bool: + return isinstance(obj, dict) + + @classmethod + def process( + cls, obj: Any, is_dynamic: bool = False, **kwargs + ) -> Union["DictType", "DynamicDictType"]: + if not is_dynamic: + return DictType.process(obj, **kwargs) + return DynamicDictType.process(obj, **kwargs) -class DictType(DataType, DataSerializer, DataHook): +class DictType(DataType, DataSerializer): """ - DataType for dict + DataType for dict with fixed set of keys """ type: ClassVar[str] = "dict" item_types: Dict[str, DataType] @classmethod - def is_object_valid(cls, obj: Any) -> bool: - return isinstance(obj, dict) - - @classmethod - def process(cls, obj: Any, **kwargs) -> "DictType": + def process(cls, obj, **kwargs): return DictType( - item_types={k: DataAnalyzer.analyze(v) for (k, v) in obj.items()} + item_types={ + k: DataAnalyzer.analyze(v, is_dynamic=False, **kwargs) + for (k, v) in obj.items() + } ) def deserialize(self, obj): @@ -577,6 +613,150 @@ def read_batch( raise NotImplementedError +class DynamicDictType(DataType, DataSerializer): + """ + Dynamic DataType for dict without fixed set of keys + """ + + type: ClassVar[str] = "d_dict" + + key_type: PrimitiveType + value_type: DataType + + @validator("key_type") + def is_valid_key_type( # pylint: disable=no-self-argument + cls, key_type # noqa: B902 + ): + if key_type.ptype not in ["str", "int", "float"]: + raise ValueError(f"key_type {key_type.ptype} is not supported") + return key_type + + def deserialize(self, obj): + self.check_type(obj, dict, DeserializationError) + return { + self.key_type.get_serializer() + .deserialize( + k, + ): self.value_type.get_serializer() + .deserialize( + v, + ) + for k, v in obj.items() + } + + def serialize(self, instance: dict): + self._check_types(instance, SerializationError) + + return { + self.key_type.get_serializer() + .serialize( + k, + ): 
self.value_type.get_serializer() + .serialize( + v, + ) + for k, v in instance.items() + } + + @classmethod + def process(cls, obj, **kwargs) -> "DynamicDictType": + return DynamicDictType( + key_type=DataAnalyzer.analyze( + next(iter(obj.keys())), is_dynamic=True, **kwargs + ), + value_type=DataAnalyzer.analyze( + next(iter(obj.values())), is_dynamic=True, **kwargs + ), + ) + + def _check_types(self, obj, exc_type, ignore_key_type: bool = False): + self.check_type(obj, dict, exc_type) + + obj_type = self.process(obj) + if ignore_key_type: + obj_types: Union[ + Tuple[PrimitiveType, DataType], Tuple[DataType] + ] = (obj_type.value_type,) + expected_types: Union[ + Tuple[PrimitiveType, DataType], Tuple[DataType] + ] = (self.value_type,) + else: + obj_types = (obj_type.key_type, obj_type.value_type) + expected_types = (self.key_type, self.value_type) + if obj_types != expected_types: + raise exc_type( + f"given dict has type: {obj_types}, expected: {expected_types}" + ) + + def get_requirements(self) -> Requirements: + return sum( + [ + self.key_type.get_requirements(), + self.value_type.get_requirements(), + ], + Requirements.new(), + ) + + def get_writer( + self, project: str = None, filename: str = None, **kwargs + ) -> "DynamicDictWriter": + return DynamicDictWriter(**kwargs) + + def get_model(self, prefix="") -> Type[BaseModel]: + field_type = ( + Dict[ # type: ignore + self.key_type.get_serializer().get_model( + prefix + "_key_" # noqa: F821 + ), + self.value_type.get_serializer().get_model( + prefix + "_val_" # noqa: F821 + ), + ], + ..., + ) + return create_model(prefix + "DynamicDictType", __root__=field_type) # type: ignore + + +class DynamicDictWriter(DataWriter): + type: ClassVar[str] = "d_dict" + + def write( + self, data: DataType, storage: Storage, path: str + ) -> Tuple[DataReader, Artifacts]: + if not isinstance(data, DynamicDictType): + raise ValueError( + f"expected data to be of DynamicDictTypeWriter, got {type(data)} instead" + ) + with storage.open(path) as (f, art): + f.write( + json.dumps(data.get_serializer().serialize(data.data)).encode( + "utf-8" + ) + ) + return DynamicDictReader(data_type=data), {DataWriter.art_name: art} + + +class DynamicDictReader(DataReader): + type: ClassVar[str] = "d_dict" + data_type: DynamicDictType + + def read(self, artifacts: Artifacts) -> DataType: + if DataWriter.art_name not in artifacts: + raise ValueError( + f"Wrong artifacts {artifacts}: should be one {DataWriter.art_name} file" + ) + with artifacts[DataWriter.art_name].open() as f: + data = json.load(f) + # json stores keys as strings. Deserialize string keys as well as values. + data = self.data_type.deserialize(data) + return self.data_type.copy().bind(data) + + def read_batch( + self, artifacts: Artifacts, batch_size: int + ) -> Iterator[DataType]: + raise NotImplementedError + + # # # class BytesDataType(DataType): diff --git a/tests/conftest.py b/tests/conftest.py index 014f3d2f..305662e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Type import git +import numpy as np import pandas as pd import pytest from fastapi.testclient import TestClient @@ -426,3 +427,10 @@ def disable_colorama(): import colorama colorama.init = lambda: None + + +@pytest.fixture +def numpy_default_int_dtype(): + # default int type is platform dependent. 
+ # For windows 64 it is int32 and for linux 64 it is int64 + return str(np.array([1]).dtype) diff --git a/tests/contrib/test_numpy.py b/tests/contrib/test_numpy.py index 5b9252a1..4a66df95 100644 --- a/tests/contrib/test_numpy.py +++ b/tests/contrib/test_numpy.py @@ -1,8 +1,9 @@ -from json import loads +import re import numpy as np import pytest from pydantic import parse_obj_as +from pytest_lazyfixture import lazy_fixture from mlem.contrib.numpy import ( NumpyNdarrayType, @@ -36,15 +37,138 @@ def custom_assert(x, y): ) -def test_ndarray_source(): - data = np.array([1, 2, 3]) - data_type = DataType.create(data) - data_write_read_check(data_type, custom_eq=np.array_equal) +@pytest.fixture +def nat(numpy_default_int_dtype): + data = np.array([[1, 2], [3, 4]]) + dtype = DataType.create(data) + payload = { + "shape": (None, 2), + "dtype": numpy_default_int_dtype, + "type": "ndarray", + } + schema = { + "title": "NumpyNdarray", + "type": "array", + "items": { + "type": "array", + "items": {"type": "integer"}, + "minItems": 2, + "maxItems": 2, + }, + } + test_data1 = data + test_data2 = np.array([[10, 20], [30, 40]]) + test_data3 = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + return False, dtype, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def nat_dynamic(): + dtype = NumpyNdarrayType(shape=[2, None, None], dtype="int") + payload = {"dtype": "int", "shape": (2, None, None), "type": "ndarray"} + schema = { + "items": { + "items": {"items": {"type": "integer"}, "type": "array"}, + "type": "array", + }, + "maxItems": 2, + "minItems": 2, + "title": "NumpyNdarray", + "type": "array", + } + test_data1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + test_data2 = np.array( + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + ) + test_data3 = np.array([[[1, 2, 2], [3, 4, 4]], [[5, 6, 6], [7, 8, 8]]]) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def nat_dynamic_float(): + dtype = NumpyNdarrayType(shape=[2, None, None, 1], dtype="float") + payload = { + "dtype": "float", + "shape": (2, None, None, 1), + "type": "ndarray", + } + schema = { + "items": { + "items": { + "items": { + "items": {"type": "number"}, + "maxItems": 1, + "minItems": 1, + "type": "array", + }, + "type": "array", + }, + "type": "array", + }, + "maxItems": 2, + "minItems": 2, + "title": "NumpyNdarray", + "type": "array", + } + test_data1 = np.array([[[[1.0]], [[3.0]]], [[[5.1]], [[7.1]]]]) + test_data2 = np.array([[[[1.1], [3.0], [5.0]]], [[[7.1], [9.99], [11.2]]]]) + test_data3 = np.array( + [[[[1.1], [3.2]], [[5.33], [7.1]]], [[[1.11], [3.4]], [[5.3], [7.2]]]] + ) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def nat_dynamic_all_none_dims(): + dtype = NumpyNdarrayType(shape=[None, None, None], dtype="int") + payload = {"dtype": "int", "shape": (None, None, None), "type": "ndarray"} + schema = { + "items": { + "items": {"items": {"type": "integer"}, "type": "array"}, + "type": "array", + }, + "title": "NumpyNdarray", + "type": "array", + } + test_data1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + test_data2 = np.array( + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + ) + test_data3 = np.array([[[1, 2, 2], [3, 4, 4]], [[5, 6, 6], [7, 8, 8]]]) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 @pytest.fixture -def nat(): - return DataType.create(np.array([[1, 2], [3, 4]])) +def nat_dynamic_shape_none(): + dtype = NumpyNdarrayType(shape=None, dtype="int") + 
payload = {"dtype": "int", "type": "ndarray"} + schema = { + "items": { + "anyOf": [{"type": "integer"}, {"items": {}, "type": "array"}] + }, + "title": "NumpyNdarray", + "type": "array", + } + test_data1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + test_data2 = np.array( + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + ) + test_data3 = np.array([[[1, 2, 2], [3, 4, 4]], [[5, 6, 6], [7, 8, 8]]]) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def nat_shape_empty(): + dtype = NumpyNdarrayType(shape=(), dtype="int") + payload = {"dtype": "int", "shape": (), "type": "ndarray"} + schema = { + "title": "NumpyNdarray", + "type": "integer", + } + test_data1 = np.array(1) + test_data2 = np.array(3) + test_data3 = np.array(4) + return True, dtype, payload, schema, test_data1, test_data2, test_data3 def test_python_type_from_np_string_repr(): @@ -81,41 +205,79 @@ def test_number(): assert ndt.get_serializer().deserialize(n_payload) == value -def test_ndarray(nat): - value = nat.data +@pytest.mark.parametrize("test_data_idx", [4, 5, 6]) +@pytest.mark.parametrize( + "data", + [ + (lazy_fixture("nat")), + (lazy_fixture("nat_dynamic")), + (lazy_fixture("nat_dynamic_all_none_dims")), + (lazy_fixture("nat_dynamic_shape_none")), + (lazy_fixture("nat_dynamic_float")), + (lazy_fixture("nat_shape_empty")), + ], +) +def test_ndarray(data, test_data_idx): + nat, payload, schema, value = ( + data[1], + data[2], + data[3], + data[test_data_idx], + ) assert isinstance(nat, NumpyNdarrayType) - assert nat.shape == (None, 2) - assert python_type_from_np_string_repr(nat.dtype) == int assert nat.get_requirements().modules == ["numpy"] - payload = nat.json() - nat2 = parse_obj_as(DataType, loads(payload)) + assert nat.dict() == payload + nat2 = parse_obj_as(DataType, payload) assert nat == nat2 assert nat.get_model().__name__ == nat2.get_model().__name__ - assert nat.get_model().schema() == { - "title": "NumpyNdarray", - "type": "array", - "items": { - "type": "array", - "items": {"type": "integer"}, - "minItems": 2, - "maxItems": 2, - }, - } + assert nat.get_model().schema() == schema n_payload = nat.get_serializer().serialize(value) assert (nat.get_serializer().deserialize(n_payload) == value).all() + model = parse_obj_as(nat.get_model(), n_payload) + assert model.__root__ == n_payload + + nat = nat.bind(value) + data_write_read_check(nat, custom_eq=np.array_equal) @pytest.mark.parametrize( - "obj", + "nddtype,obj,err_msg", [ - {}, # wrong type - np.array([[1, 2], [3, 4]], dtype=np.float32), # wrong data type - np.array([1, 2]), # wrong shape + [ + lazy_fixture("nat"), + {}, + "given data is of type: <class 'dict'>, expected: <class 'numpy.ndarray'>", + ], + [ + lazy_fixture("nat"), + np.array([[1, 2], [3, 4]], dtype=np.float32), + f"given array is of type: float32, expected: {np.array([[1, 2], [3, 4]]).dtype}", + ], + [ + lazy_fixture("nat"), + np.array([1, 2]), + "given array is of rank: 1, expected: 2", + ], + [ + lazy_fixture("nat_dynamic"), + np.array([1, 2]), + "given array is of rank: 1, expected: 3", + ], + [ + lazy_fixture("nat_shape_empty"), + np.array([1, 2]), + "given array is of rank: 1, expected: 0", + ], + [ + lazy_fixture("nat_dynamic_float"), + np.array([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]), + "given array is of shape: (1, None, None, 2), expected: (2, None, None, 1)", + ], ], ) -def test_ndarray_serialize_failure(nat, obj): - with pytest.raises(SerializationError): - nat.serialize(obj) +def test_ndarray_serialize_failure(nddtype, obj, err_msg): + with 
pytest.raises(SerializationError, match=re.escape(err_msg)): + nddtype[1].serialize(obj) @pytest.mark.parametrize( @@ -124,7 +286,7 @@ def test_ndarray_serialize_failure(nat, obj): ) def test_ndarray_deserialize_failure(nat, obj): with pytest.raises(DeserializationError): - nat.deserialize(obj) + nat[1].deserialize(obj) def test_requirements(): diff --git a/tests/core/test_data_type.py b/tests/core/test_data_type.py index f79ddb84..1a7e3740 100644 --- a/tests/core/test_data_type.py +++ b/tests/core/test_data_type.py @@ -1,5 +1,9 @@ +import copy + +import numpy as np import pytest from pydantic import parse_obj_as +from pytest_lazyfixture import lazy_fixture from mlem.core.data_type import ( ArrayReader, @@ -9,6 +13,8 @@ DataType, DictReader, DictType, + DynamicDictReader, + DynamicDictType, ListType, PrimitiveReader, PrimitiveType, @@ -27,6 +33,32 @@ def test_primitives_not_ok(): assert not PrimitiveType.is_object_valid(NotPrimitive()) +@pytest.fixture +def array(): + is_dynamic = False + array = [1, 2, 3, 4, 5] + payload = { + "dtype": {"ptype": "int", "type": "primitive"}, + "size": 5, + "type": "array", + } + schema = { + "items": {"type": "integer"}, + "title": "Array", + "type": "array", + } + + return is_dynamic, array, payload, schema + + +@pytest.fixture +def array_dynamic(array): + is_dynamic = True + payload = copy.deepcopy(array[2]) + del payload["size"] + return is_dynamic, array[1], payload, array[3] + + @pytest.mark.parametrize("ptype", PrimitiveType.PRIMITIVES) def test_primitive_source(ptype): if ptype is type(None): # noqa: E721 @@ -63,31 +95,39 @@ def test_primitives(ptype): assert dt.get_model() is ptype -def test_array(): - l_value = [1, 2, 3, 4, 5] - dt = DataAnalyzer.analyze(l_value) +@pytest.mark.parametrize( + "array_data,value", + [ + (lazy_fixture("array"), None), + (lazy_fixture("array_dynamic"), None), + (lazy_fixture("array_dynamic"), [1, 2, 3]), + ], +) +def test_array(array_data, value): + dt = DataAnalyzer.analyze(array_data[1], is_dynamic=array_data[0]) + l_value = array_data[1] if value is None else value assert isinstance(dt, ArrayType) - payload = { - "dtype": {"ptype": "int", "type": "primitive"}, - "size": 5, - "type": "array", - } - assert dt.dict() == payload - dt2 = parse_obj_as(ArrayType, payload) + assert dt.dict() == array_data[2] + dt2 = parse_obj_as(ArrayType, array_data[2]) assert dt2 == dt assert l_value == dt.serialize(l_value) assert l_value == dt.deserialize(l_value) assert dt.get_model().__name__ == "Array" - assert dt.get_model().schema() == { - "items": {"type": "integer"}, - "title": "Array", - "type": "array", - } + assert dt.get_model().schema() == array_data[3] -def test_list_source(): - l_value = [1, 2, 3, 4, 5] - dt = DataType.create(l_value) +@pytest.mark.parametrize( + "array_data,value", + [ + (lazy_fixture("array"), None), + (lazy_fixture("array_dynamic"), None), + (lazy_fixture("array_dynamic"), [1, 2, 3]), + ], +) +def test_list_source(array_data, value): + dt = DataType.create(array_data[1]) + l_value = array_data[1] if value is None else value + dt.bind(l_value) artifacts = data_write_read_check( dt, @@ -95,11 +135,8 @@ def test_list_source(): ) assert list(artifacts.keys()) == [f"{x}/data" for x in range(len(l_value))] - assert artifacts["0/data"].uri.endswith("data/0") - assert artifacts["1/data"].uri.endswith("data/1") - assert artifacts["2/data"].uri.endswith("data/2") - assert artifacts["3/data"].uri.endswith("data/3") - assert artifacts["4/data"].uri.endswith("data/4") + for x in range(len(l_value)): + assert 
artifacts[f"{x}/data"].uri.endswith(f"data/{x}") def test_tuple(): @@ -198,10 +235,10 @@ def test_mixed_list_source(): assert artifacts["5/data"].uri.endswith("data/5") -def test_dict(): +@pytest.fixture +def dict_data(): + is_dynamic = False d = {"1": 1, "2": "a"} - dt = DataAnalyzer.analyze(d) - assert isinstance(dt, DictType) payload = { "item_types": { "1": {"ptype": "int", "type": "primitive"}, @@ -209,13 +246,8 @@ def test_dict(): }, "type": "dict", } - assert dt.dict() == payload - dt2 = parse_obj_as(DictType, payload) - assert dt2 == dt - assert d == dt.serialize(d) - assert d == dt.deserialize(d) - assert dt.get_model().__name__ == "DictType" - assert dt.get_model().schema() == { + + schema = { "title": "DictType", "type": "object", "properties": { @@ -225,24 +257,297 @@ def test_dict(): "required": ["1", "2"], } + test_data1 = {"1": 1, "2": "a"} + test_data2 = {"1": 2, "2": "b"} + test_data3 = {"1": 3, "2": "c"} + + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_data(): + is_dynamic = True + d = {"a": 1, "b": 2} + payload = { + "key_type": {"ptype": "str", "type": "primitive"}, + "value_type": {"ptype": "int", "type": "primitive"}, + "type": "d_dict", + } + schema = { + "title": "DynamicDictType", + "type": "object", + "additionalProperties": {"type": "integer"}, + } + + test_data1 = {"a": 1, "b": 2} + test_data2 = {"a": 1} + test_data3 = {"a": 1, "b": 2, "c": 3, "d": 1} + + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_str_val_type_data(): + is_dynamic = True + d = {"a": "1", "b": "2"} + payload = { + "key_type": {"ptype": "str", "type": "primitive"}, + "value_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + } + schema = { + "title": "DynamicDictType", + "type": "object", + "additionalProperties": {"type": "string"}, + } + + test_data1 = {"a": "1", "b": "2"} + test_data2 = {"a": "1"} + test_data3 = {"a": "1", "b": "2", "c": "3", "d": "1"} + + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 -def test_dict_source(): - d_value = {"1": 1.5, "2": "a", "3": {"1": False}} - data_type = DataType.create(d_value) + +@pytest.fixture +def dynamic_dict_int_key_type_data(): + is_dynamic = True + d = {1: "1", 2: "2"} + payload = { + "key_type": {"ptype": "int", "type": "primitive"}, + "value_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + } + schema = { + "title": "DynamicDictType", + "type": "object", + "additionalProperties": {"type": "string"}, + } + + test_data1 = {1: "1", 2: "2"} + test_data2 = {2: "1"} + test_data3 = {3: "1", 4: "2", 5: "3", 6: "1"} + + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_float_key_type_data(): + is_dynamic = True + d = {1.0: "1", 2.0: "2"} + payload = { + "key_type": {"ptype": "float", "type": "primitive"}, + "value_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + } + schema = { + "title": "DynamicDictType", + "type": "object", + "additionalProperties": {"type": "string"}, + } + + test_data1 = {1.0: "1", 2.0: "2"} + test_data2 = {2.9999999999: "1"} + test_data3 = {3.8998: "1", 4.0001: "2", 5.2: "3", 6.9: "1"} + + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_array_type(): + is_dynamic = True + d = {"a": [1, 2, 3], "b": [3, 4, 5]} + payload = { + "key_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", 
+ "value_type": { + "dtype": {"ptype": "int", "type": "primitive"}, + "type": "array", + }, + } + schema = { + "additionalProperties": {"$ref": "#/definitions/_val_Array"}, + "definitions": { + "_val_Array": { + "items": {"type": "integer"}, + "title": "_val_Array", + "type": "array", + } + }, + "title": "DynamicDictType", + "type": "object", + } + + test_data1 = {"a": [1, 2, 3], "b": [3, 4, 5]} + test_data2 = {"a": [1, 2, 3]} + test_data3 = {"a": [1, 2, 3], "b": [3, 4, 5], "d": [6, 7, 8]} + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_dict_type(): + is_dynamic = True + d = {"a": {"l": [1, 2]}, "b": {"l": [3, 4]}} + payload = { + "key_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + "value_type": { + "key_type": {"ptype": "str", "type": "primitive"}, + "type": "d_dict", + "value_type": { + "dtype": {"ptype": "int", "type": "primitive"}, + "type": "array", + }, + }, + } + schema = { + "additionalProperties": {"$ref": "#/definitions/_val_DynamicDictType"}, + "definitions": { + "_val_DynamicDictType": { + "additionalProperties": { + "$ref": "#/definitions/_val__val_Array" + }, + "title": "_val_DynamicDictType", + "type": "object", + }, + "_val__val_Array": { + "items": {"type": "integer"}, + "title": "_val__val_Array", + "type": "array", + }, + }, + "title": "DynamicDictType", + "type": "object", + } + test_data1 = {"a": {"l": [1, 2]}, "b": {"l": [3, 4]}} + test_data2 = {"a": {"l": [1, 2]}} + test_data3 = {"a": {"l": [1, 2]}, "b": {"l": [3, 4]}, "c": {"k": [3, 4]}} + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.fixture +def dynamic_dict_ndarray_type(numpy_default_int_dtype): + is_dynamic = True + d = {11: {1: np.array([1, 2])}, 22: {2: np.array([3, 4])}} + payload = { + "key_type": {"ptype": "int", "type": "primitive"}, + "type": "d_dict", + "value_type": { + "key_type": {"ptype": "int", "type": "primitive"}, + "type": "d_dict", + "value_type": { + "dtype": numpy_default_int_dtype, + "shape": (None,), + "type": "ndarray", + }, + }, + } + schema = { + "additionalProperties": {"$ref": "#/definitions/_val_DynamicDictType"}, + "definitions": { + "_val_DynamicDictType": { + "additionalProperties": { + "$ref": "#/definitions/_val__val_NumpyNdarray" + }, + "title": "_val_DynamicDictType", + "type": "object", + }, + "_val__val_NumpyNdarray": { + "items": {"type": "integer"}, + "title": "_val__val_NumpyNdarray", + "type": "array", + }, + }, + "title": "DynamicDictType", + "type": "object", + } + + test_data1 = {11: {1: np.array([1, 2])}, 22: {2: np.array([3, 4])}} + test_data2 = {11: {1: np.array([1, 2])}} + test_data3 = { + 11: {1: np.array([1, 2])}, + 22: {2: np.array([3, 4])}, + 33: {2: np.array([5, 6])}, + } + return is_dynamic, d, payload, schema, test_data1, test_data2, test_data3 + + +@pytest.mark.parametrize("test_data_idx", [4, 5, 6]) +@pytest.mark.parametrize( + "data", + [ + lazy_fixture("dict_data"), + lazy_fixture("dynamic_dict_data"), + lazy_fixture("dynamic_dict_str_val_type_data"), + lazy_fixture("dynamic_dict_int_key_type_data"), + lazy_fixture("dynamic_dict_float_key_type_data"), + lazy_fixture("dynamic_dict_array_type"), + lazy_fixture("dynamic_dict_dict_type"), + lazy_fixture("dynamic_dict_ndarray_type"), + ], +) +def test_dict(data, test_data_idx): + is_dynamic, d, payload, schema, test_data = ( + data[0], + data[1], + data[2], + data[3], + data[test_data_idx], + ) + dt = DataAnalyzer.analyze(d, is_dynamic=is_dynamic) + dtype = DictType if not is_dynamic 
else DynamicDictType + assert isinstance(dt, dtype) + + assert dt.dict() == payload + dt2 = parse_obj_as(dtype, payload) + assert dt2 == dt + serialised_test_data = dt.serialize(test_data) + deserialised_test_data = dt.deserialize(serialised_test_data) + assert serialised_test_data == dt.serialize(deserialised_test_data) + assert dt.get_model().__name__ == dtype.__name__ + assert dt.get_model().schema() == schema + assert parse_obj_as(dt.get_model(), serialised_test_data) + + +@pytest.mark.parametrize("test_data_idx", [4, 5, 6]) +@pytest.mark.parametrize( + "data", + [ + lazy_fixture("dict_data"), + lazy_fixture("dynamic_dict_data"), + lazy_fixture("dynamic_dict_str_val_type_data"), + lazy_fixture("dynamic_dict_int_key_type_data"), + lazy_fixture("dynamic_dict_float_key_type_data"), + lazy_fixture("dynamic_dict_array_type"), + lazy_fixture("dynamic_dict_dict_type"), + lazy_fixture("dynamic_dict_ndarray_type"), + ], +) +def test_dict_source(data, test_data_idx): + is_dynamic, d, test_data = ( + data[0], + data[1], + data[test_data_idx], + ) + data_type = DataType.create(d, is_dynamic=is_dynamic) + data_type = data_type.bind(test_data) + dtype_reader = DynamicDictReader if is_dynamic else DictReader def custom_assert(x, y): - assert x == y + np.testing.assert_equal(x, y) assert len(x) == len(y) assert isinstance(x, dict) assert isinstance(y, dict) artifacts = data_write_read_check( data_type, - reader_type=DictReader, + reader_type=dtype_reader, custom_assert=custom_assert, ) - assert list(artifacts.keys()) == ["1/data", "2/data", "3/1/data"] - assert artifacts["1/data"].uri.endswith("data/1") - assert artifacts["2/data"].uri.endswith("data/2") - assert artifacts["3/1/data"].uri.endswith("data/3/1") + if not is_dynamic: + assert list(artifacts.keys()) == ["1/data", "2/data"] + assert artifacts["1/data"].uri.endswith("data/1") + assert artifacts["2/data"].uri.endswith("data/2") + else: + assert list(artifacts.keys()) == ["data"] + assert artifacts["data"].uri.endswith("data")
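A minimal usage sketch of the behavior this patch adds (illustrative, not part of the commit), assuming a mlem build with this change applied. The entry points shown, DataType.create with the new is_dynamic flag, DynamicDictType, and NumpyNdarrayType with None dimensions or shape=None, come straight from the diff above; the variable names and data values are made up for the example.

import numpy as np

from mlem.contrib.numpy import NumpyNdarrayType
from mlem.core.data_type import DataType, DynamicDictType

# Default (static) analysis pins the exact dict keys, as before:
static_dt = DataType.create({"a": 1, "b": 2})  # DictType with keys "a", "b"

# Dynamic analysis records only the key and value types, so any str -> int
# dict serializes, whatever keys it has:
dynamic_dt = DataType.create({"a": 1, "b": 2}, is_dynamic=True)
assert isinstance(dynamic_dt, DynamicDictType)
dynamic_dt.serialize({"x": 10, "y": 20, "z": 30})  # new keys are accepted

# For ndarrays, None dimensions relax the shape check; shape=None skips it
# entirely (any rank). int64 is pinned explicitly because the default int
# dtype is platform dependent, as noted in the conftest.py fixture above.
any_rank = NumpyNdarrayType(shape=None, dtype="int64")
any_rank.serialize(np.array([1, 2, 3], dtype="int64"))         # rank 1: ok
any_rank.serialize(np.array([[1, 2], [3, 4]], dtype="int64"))  # rank 2: ok

two_rows = NumpyNdarrayType(shape=(2, None), dtype="int64")
two_rows.serialize(np.array([[1], [2]], dtype="int64"))  # (2, 1) fits (2, None)
# A (3, 1) array would raise SerializationError here: per _check_shape above,
# concrete dimensions are still enforced, only None dimensions are free.

Note on the storage side: DynamicDictWriter dumps the whole dict as one JSON artifact and DynamicDictReader deserializes keys as well as values, which is why test_dict_source above expects a single "data" artifact in the dynamic case instead of one artifact per key.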