Skip to content

Commit

Permalink
fix(typing): limit supported dtypes to subclasses of np.generic
Browse files Browse the repository at this point in the history
test: remove a few "pyright: ignore" instructions where we previously
failed for common Python types
chore: bump PATCH
  • Loading branch information
futurwasfree authored and Can H. Tartanoglu committed Jan 3, 2024
1 parent d5f9e02 commit dfc63d9
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 55 deletions.
25 changes: 14 additions & 11 deletions pydantic_numpy/helper/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from typing import Any, Callable, ClassVar, Optional, Union

import numpy as np
from numpy.typing import DTypeLike
from pydantic import FilePath, GetJsonSchemaHandler, PositiveInt, validate_call
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from typing_extensions import Annotated

from pydantic_numpy.helper.typing import NumpyDataDict
from pydantic_numpy.helper.typing import NumpyDataDict, SupportedDTypes
from pydantic_numpy.helper.validation import (
create_array_validator,
validate_multi_array_numpy_file,
Expand All @@ -21,12 +20,16 @@
class NpArrayPydanticAnnotation:
dimensions: ClassVar[Optional[PositiveInt]]

data_type: ClassVar[DTypeLike]
data_type: ClassVar[SupportedDTypes]
strict_data_typing: ClassVar[bool]

@classmethod
def factory(
cls, *, data_type: DTypeLike, dimensions: Optional[int] = None, strict_data_typing: bool = False
cls,
*,
data_type: Optional[SupportedDTypes] = None,
dimensions: Optional[int] = None,
strict_data_typing: bool = False,
) -> type:
"""
Create an NpArrayPydanticAnnotation type that is configured for a specific dimension and dtype.
Expand All @@ -36,7 +39,7 @@ def factory(
Parameters
----------
data_type: DTypeLike
data_type: SupportedDTypes
dimensions: Optional[int]
The number of dimensions determines the depth of the numpy array.
strict_data_typing: bool
Expand All @@ -47,7 +50,7 @@ def factory(
NpArrayPydanticAnnotation
"""
if strict_data_typing and not data_type:
msg = "Strict data typing requires data_type (DTypeLike) definition"
msg = "Strict data typing requires data_type (SupportedDTypes) definition"
raise ValueError(msg)

return type(
Expand Down Expand Up @@ -90,14 +93,14 @@ def __get_pydantic_json_schema__(


def np_array_pydantic_annotated_typing(
data_type: DTypeLike = None, dimensions: Optional[int] = None, strict_data_typing: bool = False
data_type: Optional[SupportedDTypes] = None, dimensions: Optional[int] = None, strict_data_typing: bool = False
):
"""
Generates typing and pydantic annotation of a np.ndarray parametrized with given constraints
Parameters
----------
data_type: DTypeLike
data_type: SupportedDTypes
dimensions: Optional[int]
The number of dimensions determines the depth of the numpy array.
strict_data_typing: bool
Expand All @@ -113,7 +116,7 @@ def np_array_pydantic_annotated_typing(
MultiArrayNumpyFile,
np.ndarray[ # type: ignore[misc]
_int_to_dim_type[dimensions] if dimensions else Any, # pyright: ignore
np.dtype[data_type] if _data_type_resolver(data_type) else data_type, # type: ignore[misc]
np.dtype[data_type] if _data_type_resolver(data_type) else data_type, # type: ignore[valid-type]
],
],
NpArrayPydanticAnnotation.factory(
Expand All @@ -122,8 +125,8 @@ def np_array_pydantic_annotated_typing(
]


def _data_type_resolver(data_type: DTypeLike) -> bool:
return data_type is not None and issubclass(data_type, np.generic) # pyright: ignore
def _data_type_resolver(data_type: Optional[SupportedDTypes]) -> bool:
return data_type is not None and issubclass(data_type, np.generic)


def _serialize_numpy_array_to_data_dict(array: np.ndarray) -> NumpyDataDict:
Expand Down
3 changes: 3 additions & 0 deletions pydantic_numpy/helper/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
from typing_extensions import TypedDict

SupportedDTypes = type[np.generic]


class NumpyDataDict(TypedDict):
data_type: str
Expand Down
30 changes: 4 additions & 26 deletions pydantic_numpy/helper/validation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Callable, Optional, Union, cast
from typing import Callable, Optional, Union

import numpy as np
import numpy.typing as npt
from numpy import floating, integer
from numpy.lib.npyio import NpzFile
from pydantic import FilePath

from pydantic_numpy.helper.typing import NumpyDataDict
from pydantic_numpy.helper.typing import NumpyDataDict, SupportedDTypes
from pydantic_numpy.model import MultiArrayNumpyFile


Expand All @@ -15,7 +15,7 @@ class PydanticNumpyMultiArrayNumpyFileOnFilePath(Exception):


def create_array_validator(
dimensions: Optional[int], target_data_type: npt.DTypeLike, strict_data_typing: bool
dimensions: Optional[int], target_data_type: SupportedDTypes, strict_data_typing: bool
) -> Callable[[npt.NDArray], npt.NDArray]:
"""
Creates a validator that ensures the numpy array has the defined dimensions and dtype (data_type).
Expand Down Expand Up @@ -52,9 +52,7 @@ def array_validator(array_data: Union[npt.NDArray, NumpyDataDict]) -> npt.NDArra
msg = f"The data_type {array.dtype.type} does not coincide with type hint; {target_data_type}"
raise ValueError(msg)

if issubclass(_resolve_type_of_array_dtype(target_data_type), integer) and issubclass(
_resolve_type_of_array_dtype(array.dtype), floating
):
if issubclass(target_data_type, integer) and issubclass(array.dtype.type, floating):
array = np.round(array).astype(target_data_type, copy=False)
else:
array = array.astype(target_data_type, copy=True)
Expand Down Expand Up @@ -109,23 +107,3 @@ def validate_multi_array_numpy_file(v: MultiArrayNumpyFile) -> npt.NDArray:
NDArray from MultiArrayNumpyFile
"""
return v.load()


def _resolve_type_of_array_dtype(array_dtype: npt.DTypeLike) -> type:
"""
np.dtype have the type stored in the type attribute, function to extract that type.
If the DTypelike isn't np.dtype we just return what is already a type.
Parameters
----------
array_dtype: DTypeLike
Returns
-------
type
"""
if hasattr(array_dtype, "type"):
assert array_dtype is not None
return array_dtype.type # pyright: ignore
else:
return cast(type, array_dtype)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pydantic_numpy"
version = "4.1.2"
version = "4.1.3"
description = "Pydantic Model integration of the NumPy array"
authors = ["Can H. Tartanoglu", "Christoph Heindl"]
maintainers = ["Can H. Tartanoglu <[email protected]>"]
Expand Down
5 changes: 3 additions & 2 deletions tests/helper/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from typing import Optional

import numpy as np
import numpy.typing as npt
from hypothesis.extra.numpy import arrays
from hypothesis.strategies import floats
from pydantic import BaseModel

from pydantic_numpy.helper.typing import SupportedDTypes


@cache
def cached_calculation(array_type_hint) -> type[BaseModel]:
Expand All @@ -17,7 +18,7 @@ class ModelForTesting(BaseModel):


@cache
def cached_hyp_array(numpy_dtype: npt.DTypeLike, dimensions: Optional[int] = None, *, _axis_length: int = 1):
def cached_hyp_array(numpy_dtype: SupportedDTypes, dimensions: Optional[int] = None, *, _axis_length: int = 1):
if np.issubdtype(numpy_dtype, np.floating):
if numpy_dtype == np.float16:
width = 16
Expand Down
28 changes: 13 additions & 15 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Hashable, Optional, cast

import numpy as np
import numpy.typing as npt
import orjson
import pytest
from numpy.testing import assert_almost_equal
from pydantic import ValidationError

from pydantic_numpy.helper.typing import SupportedDTypes
from pydantic_numpy.helper.validation import PydanticNumpyMultiArrayNumpyFileOnFilePath
from pydantic_numpy.model import MultiArrayNumpyFile
from pydantic_numpy.util import np_general_all_close
Expand All @@ -24,15 +24,15 @@


@pytest.mark.parametrize("numpy_dtype,pydantic_typing,dimensions", data_type_array_typing_dimensions)
def test_correct_type(numpy_dtype: npt.DTypeLike, pydantic_typing, dimensions: Optional[int]):
def test_correct_type(numpy_dtype: SupportedDTypes, pydantic_typing, dimensions: Optional[int]):
assert cached_calculation(pydantic_typing)(
array_field=cached_hyp_array(cast(Hashable, numpy_dtype), dimensions).example()
)


@pytest.mark.parametrize("numpy_dtype,pydantic_typing,dimensions", strict_data_type_nd_array_typing_dimensions)
@pytest.mark.parametrize("wrong_numpy_type", supported_data_types)
def test_wrong_dtype_type(numpy_dtype: npt.DTypeLike, pydantic_typing, dimensions: Optional[int], wrong_numpy_type):
def test_wrong_dtype_type(numpy_dtype: SupportedDTypes, pydantic_typing, dimensions: Optional[int], wrong_numpy_type):
if wrong_numpy_type == numpy_dtype:
return

Expand All @@ -42,7 +42,7 @@ def test_wrong_dtype_type(numpy_dtype: npt.DTypeLike, pydantic_typing, dimension


@pytest.mark.parametrize("numpy_dtype,pydantic_typing,dimensions", dimension_testing_group)
def test_wrong_dimension(numpy_dtype: npt.DTypeLike, pydantic_typing, dimensions: Optional[int]):
def test_wrong_dimension(numpy_dtype: SupportedDTypes, pydantic_typing, dimensions: Optional[int]):
assert dimensions is not None
wrong_dimension = dimensions + 1

Expand All @@ -56,34 +56,32 @@ def test_wrong_dimension(numpy_dtype: npt.DTypeLike, pydantic_typing, dimensions
@pytest.mark.parametrize(
"numpy_dtype,pydantic_typing,dimensions", data_type_nd_array_typing_dimensions_without_complex
)
def test_json_serialize_deserialize(numpy_dtype: npt.DTypeLike, pydantic_typing, dimensions: Optional[int]):
def test_json_serialize_deserialize(numpy_dtype: SupportedDTypes, pydantic_typing, dimensions: Optional[int]):
hyp_array = cached_hyp_array(cast(Hashable, numpy_dtype), dimensions).example()

numpy_model = cached_calculation(pydantic_typing)
dumped_model_json_loaded = orjson.loads(numpy_model(array_field=hyp_array).model_dump_json())

round_trip_result = numpy_model(
array_field=dumped_model_json_loaded["array_field"]
).array_field # pyright: ignore
round_trip_result = numpy_model(array_field=dumped_model_json_loaded["array_field"]).array_field

if issubclass(numpy_dtype, np.timedelta64) or issubclass(numpy_dtype, np.datetime64): # pyright: ignore
if issubclass(numpy_dtype, np.timedelta64) or issubclass(numpy_dtype, np.datetime64):
assert hyp_array == round_trip_result
else:
assert_almost_equal(hyp_array, round_trip_result)

@pytest.mark.parametrize("numpy_dtype,pydantic_typing,dimensions", data_type_array_typing_dimensions)
def test_file_path_passing_validation(numpy_dtype: npt.DTypeLike, pydantic_typing, dimensions: Optional[int]):
def test_file_path_passing_validation(numpy_dtype: SupportedDTypes, pydantic_typing, dimensions: Optional[int]):
hyp_array = cached_hyp_array(cast(Hashable, numpy_dtype), dimensions).example()

with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".npz") as tf:
np.savez_compressed(tf.name, my_array=hyp_array)
numpy_model = cached_calculation(pydantic_typing)(array_field=Path(tf.name))

assert np_general_all_close(numpy_model.array_field, hyp_array) # pyright: ignore
assert np_general_all_close(numpy_model.array_field, hyp_array)

@pytest.mark.parametrize("numpy_dtype,pydantic_typing,dimensions", data_type_array_typing_dimensions)
def test_file_path_error_on_reading_single_array_file(
numpy_dtype: npt.DTypeLike, pydantic_typing, dimensions: Optional[int]
numpy_dtype: SupportedDTypes, pydantic_typing, dimensions: Optional[int]
):
hyp_array = cached_hyp_array(cast(Hashable, numpy_dtype), dimensions).example()

Expand All @@ -96,7 +94,7 @@ def test_file_path_error_on_reading_single_array_file(

@pytest.mark.parametrize("numpy_dtype,pydantic_typing,dimensions", data_type_array_typing_dimensions)
def test_multi_array_numpy_passing_validation(
numpy_dtype: npt.DTypeLike, pydantic_typing, dimensions: Optional[int]
numpy_dtype: SupportedDTypes, pydantic_typing, dimensions: Optional[int]
):
hyp_array = cached_hyp_array(cast(Hashable, numpy_dtype), dimensions).example()

Expand All @@ -105,11 +103,11 @@ def test_multi_array_numpy_passing_validation(
numpy_model = cached_calculation(pydantic_typing)(
array_field=MultiArrayNumpyFile(path=Path(tf.name), key="my_array")
)
assert np_general_all_close(numpy_model.array_field, hyp_array) # pyright: ignore
assert np_general_all_close(numpy_model.array_field, hyp_array)

@pytest.mark.parametrize("numpy_dtype,pydantic_typing,dimensions", data_type_array_typing_dimensions)
def test_multi_array_numpy_error_on_reading_single_array_file(
numpy_dtype: npt.DTypeLike, pydantic_typing, dimensions: Optional[int]
numpy_dtype: SupportedDTypes, pydantic_typing, dimensions: Optional[int]
):
hyp_array = cached_hyp_array(cast(Hashable, numpy_dtype), dimensions).example()

Expand Down

0 comments on commit dfc63d9

Please sign in to comment.