Skip to content

Commit

Permalink
Add type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Antoine DECHAUME committed Dec 5, 2022
1 parent c820100 commit e15e1e2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
13 changes: 10 additions & 3 deletions pydantic_numpy/dtype.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
from __future__ import annotations

from typing import Any
from typing import Dict
from typing import TYPE_CHECKING

import numpy as np
from pydantic import ValidationError
from pydantic.fields import ModelField

from pydantic_numpy.ndarray import NDArray

if TYPE_CHECKING:
from pydantic.typing import CallableGenerator


class _BaseDType:
@classmethod
def __modify_schema__(cls, field_schema):
def __modify_schema__(cls, field_schema: dict[str, Any]) -> None:
field_schema.update({"type": cls.__name__})

@classmethod
def __get_validators__(cls):
def __get_validators__(cls) -> CallableGenerator:
yield cls.validate

@classmethod
def validate(cls, val: Any, field: ModelField):
def validate(cls, val: Any, field: ModelField) -> "_BaseDType":
if field.sub_fields:
msg = f"{cls.__name__} has no subfields"
raise ValidationError(msg)
Expand Down
39 changes: 28 additions & 11 deletions pydantic_numpy/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, Mapping, Optional, Type, TypeVar
from __future__ import annotations

import sys
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Any
from typing import Generic
from typing import Mapping
from typing import Optional
from typing import TypeVar
from typing import TYPE_CHECKING

import numpy as np
from numpy.lib import NumpyVersion
from pydantic import BaseModel, FilePath, validator
from pydantic import BaseModel
from pydantic import FilePath
from pydantic import validator
from pydantic.fields import ModelField

if TYPE_CHECKING:
from pydantic.typing import CallableGenerator

T = TypeVar("T", bound=np.generic)

if sys.version_info < (3, 9) or NumpyVersion(np.__version__) < "1.22.0":
Expand All @@ -20,30 +34,33 @@ class NPFileDesc(BaseModel):
key: Optional[str] = None

@validator("path")
def absolute(cls, value):
def absolute(cls, value: Path) -> Path:
return value.resolve().absolute()


class _CommonNDArray(ABC):

@classmethod
@abstractmethod
def validate(cls, val: Any, field: ModelField):
def validate(cls, val: Any, field: ModelField) -> nd_array_type:
...

@classmethod
def __modify_schema__(cls, field_schema, field: Optional[ModelField]):
def __modify_schema__(
cls, field_schema: dict[str, Any], field: ModelField | None
) -> None:
if field and field.sub_fields:
type_with_potential_subtype = f"np.ndarray[{field.sub_fields[0]}]"
else:
type_with_potential_subtype = "np.ndarray"
field_schema.update({"type": type_with_potential_subtype})

@classmethod
def __get_validators__(cls):
def __get_validators__(cls) -> CallableGenerator:
yield cls.validate

@classmethod
def _validate(cls: Type, val: Any, field: ModelField) -> "NDArray":
@staticmethod
def _validate(val: Any, field: ModelField) -> nd_array_type:
if isinstance(val, Mapping):
val = NPFileDesc(**val)

Expand Down Expand Up @@ -81,15 +98,15 @@ def _validate(cls: Type, val: Any, field: ModelField) -> "NDArray":

class NDArray(Generic[T], nd_array_type, _CommonNDArray):
@classmethod
def validate(cls, val: Any, field: ModelField) -> np.ndarray:
def validate(cls, val: Any, field: ModelField) -> nd_array_type:
return cls._validate(val, field)


class PotentialNDArray(Generic[T], nd_array_type, _CommonNDArray):
"""Like NDArray, but validation errors result in None."""

@classmethod
def validate(cls, val: Any, field: ModelField) -> Optional[np.ndarray]:
def validate(cls, val: Any, field: ModelField) -> Optional[nd_array_type]:
try:
return cls._validate(val, field)
except ValueError:
Expand Down

0 comments on commit e15e1e2

Please sign in to comment.