diff --git a/pydantic_numpy/__init__.py b/pydantic_numpy/__init__.py index 43e0a93..1a6cc26 100644 --- a/pydantic_numpy/__init__.py +++ b/pydantic_numpy/__init__.py @@ -9,6 +9,16 @@ T = TypeVar("T", bound=np.generic) +class BaseNDArray: + @classmethod + def __modify_schema__(cls, field_schema): + # __modify_schema__ should mutate the dict it receives in place, + # the returned value will be ignored + field_schema.update({ + "type": "np.ndarray" + }) + + class NPFileDesc(BaseModel): path: Path key: Optional[str] @@ -16,7 +26,7 @@ class NPFileDesc(BaseModel): if NumpyVersion(np.__version__) < "1.22.0": - class NDArray(Generic[T], np.ndarray): + class NDArray(Generic[T], np.ndarray, BaseNDArray): @classmethod def __get_validators__(cls): yield cls.validate @@ -25,7 +35,7 @@ def __get_validators__(cls): def validate(cls, val: Any, field: ModelField) -> np.ndarray: return _validate(cls, val, field) - class PotentialNDArray(Generic[T], np.ndarray): + class PotentialNDArray(Generic[T], np.ndarray, BaseNDArray): """Like NDArray, but validation errors result in None.""" @classmethod @@ -42,7 +52,7 @@ def validate(cls, val: Any, field: ModelField) -> Optional[np.ndarray]: else: - class NDArray(Generic[T], np.ndarray[Any, T]): + class NDArray(Generic[T], np.ndarray[Any, T], BaseNDArray): @classmethod def __get_validators__(cls): yield cls.validate @@ -51,7 +61,7 @@ def __get_validators__(cls): def validate(cls, val: Any, field: ModelField) -> Optional[np.ndarray]: return _validate(cls, val, field) - class PotentialNDArray(Generic[T], np.ndarray[Any, T]): + class PotentialNDArray(Generic[T], np.ndarray[Any, T], BaseNDArray): @classmethod def __get_validators__(cls): yield cls.validate