From e15e1e24b82fb19d388ce27c9a5a372615f19ed2 Mon Sep 17 00:00:00 2001 From: Antoine DECHAUME <> Date: Mon, 5 Dec 2022 10:58:49 +0100 Subject: [PATCH] Add type annotations --- pydantic_numpy/dtype.py | 13 ++++++++++--- pydantic_numpy/ndarray.py | 39 ++++++++++++++++++++++++++++----------- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/pydantic_numpy/dtype.py b/pydantic_numpy/dtype.py index 8baafa5..81cd31f 100644 --- a/pydantic_numpy/dtype.py +++ b/pydantic_numpy/dtype.py @@ -1,4 +1,8 @@ +from __future__ import annotations + from typing import Any +from typing import Dict +from typing import TYPE_CHECKING import numpy as np from pydantic import ValidationError @@ -6,18 +10,21 @@ from pydantic_numpy.ndarray import NDArray +if TYPE_CHECKING: + from pydantic.typing import CallableGenerator + class _BaseDType: @classmethod - def __modify_schema__(cls, field_schema): + def __modify_schema__(cls, field_schema: dict[str, Any]) -> None: field_schema.update({"type": cls.__name__}) @classmethod - def __get_validators__(cls): + def __get_validators__(cls) -> CallableGenerator: yield cls.validate @classmethod - def validate(cls, val: Any, field: ModelField): + def validate(cls, val: Any, field: ModelField) -> "_BaseDType": if field.sub_fields: msg = f"{cls.__name__} has no subfields" raise ValidationError(msg) diff --git a/pydantic_numpy/ndarray.py b/pydantic_numpy/ndarray.py index cf0107d..88f92ea 100644 --- a/pydantic_numpy/ndarray.py +++ b/pydantic_numpy/ndarray.py @@ -1,12 +1,26 @@ -from abc import ABC, abstractmethod -from typing import Any, Generic, Mapping, Optional, Type, TypeVar +from __future__ import annotations + import sys +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from typing import Any +from typing import Generic +from typing import Mapping +from typing import Optional +from typing import TypeVar +from typing import TYPE_CHECKING import numpy as np from numpy.lib import NumpyVersion -from pydantic import BaseModel, FilePath, validator +from pydantic import BaseModel +from pydantic import FilePath +from pydantic import validator from pydantic.fields import ModelField +if TYPE_CHECKING: + from pydantic.typing import CallableGenerator + T = TypeVar("T", bound=np.generic) if sys.version_info < (3, 9) or NumpyVersion(np.__version__) < "1.22.0": @@ -20,18 +34,21 @@ class NPFileDesc(BaseModel): key: Optional[str] = None @validator("path") - def absolute(cls, value): + def absolute(cls, value: Path) -> Path: return value.resolve().absolute() class _CommonNDArray(ABC): + @classmethod @abstractmethod - def validate(cls, val: Any, field: ModelField): + def validate(cls, val: Any, field: ModelField) -> nd_array_type: ... @classmethod - def __modify_schema__(cls, field_schema, field: Optional[ModelField]): + def __modify_schema__( + cls, field_schema: dict[str, Any], field: ModelField | None + ) -> None: if field and field.sub_fields: type_with_potential_subtype = f"np.ndarray[{field.sub_fields[0]}]" else: @@ -39,11 +56,11 @@ def __modify_schema__(cls, field_schema, field: Optional[ModelField]): field_schema.update({"type": type_with_potential_subtype}) @classmethod - def __get_validators__(cls): + def __get_validators__(cls) -> CallableGenerator: yield cls.validate - @classmethod - def _validate(cls: Type, val: Any, field: ModelField) -> "NDArray": + @staticmethod + def _validate(val: Any, field: ModelField) -> nd_array_type: if isinstance(val, Mapping): val = NPFileDesc(**val) @@ -81,7 +98,7 @@ def _validate(cls: Type, val: Any, field: ModelField) -> "NDArray": class NDArray(Generic[T], nd_array_type, _CommonNDArray): @classmethod - def validate(cls, val: Any, field: ModelField) -> np.ndarray: + def validate(cls, val: Any, field: ModelField) -> nd_array_type: return cls._validate(val, field) @@ -89,7 +106,7 @@ class PotentialNDArray(Generic[T], nd_array_type, _CommonNDArray): """Like NDArray, but validation errors result in None.""" @classmethod - def validate(cls, val: Any, field: ModelField) -> Optional[np.ndarray]: + def validate(cls, val: Any, field: ModelField) -> Optional[nd_array_type]: try: return cls._validate(val, field) except ValueError: