From b2856837dfdca3eb2de50565bf445420e2b55388 Mon Sep 17 00:00:00 2001 From: Don Setiawan Date: Wed, 16 Aug 2023 10:19:44 -0700 Subject: [PATCH] refactor: update codebase to use Pydantic v2 (#133) * chore(deps): bump pydantic deps to v2 * refactor: update pydantic models to use v2 and clean up * docs: add more docstring for class --- ref: Issue #85 --- pyproject.toml | 5 +- src/gnatss/configs/io.py | 14 ++- src/gnatss/configs/main.py | 163 +++++++++++++++++------------------ src/gnatss/configs/solver.py | 52 +++++------ src/gnatss/loaders.py | 2 +- 5 files changed, 112 insertions(+), 124 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 12de31f..973c9b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,8 @@ dependencies = [ "nptyping>=2.5.0,<3", "cftime>=1.6.2,<2", "pandas>=2.0.3,<3", - "pydantic>=1.10.6,<2", + "pydantic>=2.1.1,<3", + "pydantic-settings>=2.0.3,<3", "pyyaml>=6.0.1,<7", "pymap3d>=3.0.1,<4", "pluggy>=1.2.0,<2", @@ -46,7 +47,7 @@ docs = [ "sphinx-panels", "sphinx_rtd_theme", "sphinxcontrib-mermaid", - "autodoc_pydantic" + "autodoc_pydantic>=2.0.1,<3.0" ] lint = [ "black", diff --git a/src/gnatss/configs/io.py b/src/gnatss/configs/io.py index 2e0fa62..20de577 100644 --- a/src/gnatss/configs/io.py +++ b/src/gnatss/configs/io.py @@ -21,13 +21,11 @@ class InputData(BaseModel): ), ) - def __init__(__pydantic_self__, **data: Any) -> None: + def __init__(self, **data: Any) -> None: super().__init__(**data) # Checks the file - if not check_file_exists( - __pydantic_self__.path, __pydantic_self__.storage_options - ): + if not check_file_exists(self.path, self.storage_options): raise FileNotFoundError("The specified file doesn't exist!") @@ -48,11 +46,9 @@ class OutputPath(BaseModel): _fsmap: str = PrivateAttr() - def __init__(__pydantic_self__, **data: Any) -> None: + def __init__(self, **data: Any) -> None: super().__init__(**data) - __pydantic_self__._fsmap = fsspec.get_mapper( - __pydantic_self__.path, **__pydantic_self__.storage_options - ) + self._fsmap = fsspec.get_mapper(self.path, **self.storage_options) # Checks the file permission as the object is being created - check_permission(__pydantic_self__._fsmap) + check_permission(self._fsmap) diff --git a/src/gnatss/configs/main.py b/src/gnatss/configs/main.py index b8874f8..76e148a 100644 --- a/src/gnatss/configs/main.py +++ b/src/gnatss/configs/main.py @@ -5,11 +5,16 @@ """ import warnings from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Type import yaml -from pydantic import BaseSettings, Field -from pydantic.fields import ModelField +from pydantic import Field +from pydantic.fields import FieldInfo +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, +) from .io import OutputPath from .solver import Solver @@ -17,98 +22,94 @@ CONFIG_FILE = "config.yaml" -def yaml_config_settings_source(settings: BaseSettings) -> Dict[str, Any]: +class YamlConfigSettingsSource(PydanticBaseSettingsSource): """ + A simple settings source that reads from a yaml file. + Read config settings form a local yaml file where the software runs + """ - Parameters - ---------- - settings : pydantic.BaseSettings - The base settings class + def get_field_value( + self, field: FieldInfo, field_name: str + ) -> tuple[Any, str, bool]: + """ + Gets the value, + the key for model creation, + and a flag to determine whether value is complex. - Returns - ------- - dict - The configuration dictionary based on inputs from the yaml - file - """ - encoding = settings.__config__.env_file_encoding - config_path = Path(CONFIG_FILE) - if config_path.exists(): - # Only load config.yaml when it exists - return yaml.safe_load(config_path.read_text(encoding)) - else: - warnings.warn( - ( - f"Configuration file `{CONFIG_FILE}` not found. " - "Will attempt to retrieve configuration from environment variables." + *This is an override for the pydantic abstract method.* + """ + encoding = self.config.get("env_file_encoding") + config_path = Path(CONFIG_FILE) + file_content_yaml = {} + if config_path.exists(): + # Only load config.yaml when it exists + file_content_yaml = yaml.safe_load(config_path.read_text(encoding)) + else: + # Warn user when config.yaml is not found + warnings.warn( + ( + f"Configuration file `{CONFIG_FILE}` not found. " + "Will attempt to retrieve configuration from environment variables." + ) ) - ) - return {} + field_value = file_content_yaml.get(field_name) + return field_value, field_name, False -class BaseConfiguration(BaseSettings): - """Base configuration class""" + def prepare_field_value( + self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool + ) -> Any: + """ + Prepares the value of a field. - @classmethod - def add_fields(cls, **field_definitions: Any) -> None: + *This is an override for the pydantic abstract method.* """ - Adds additional configuration field on the fly (inplace) + return value - Parameters - ---------- - **field_definitions - Keyword arguments of the new field to be added + def __call__(self) -> Dict[str, Any]: + """ + Allows the class to be called as a function. """ - new_fields: Dict[str, ModelField] = {} - new_annotations: Dict[str, Optional[type]] = {} - - for f_name, f_def in field_definitions.items(): - if isinstance(f_def, tuple): - try: - f_annotation, f_value = f_def - except ValueError as e: - raise Exception( - "field definitions should either be a tuple of" - " (, ) or just a " - "default value, unfortunately this means tuples as " - "default values are not allowed" - ) from e - else: - f_annotation, f_value = None, f_def - - if f_annotation: - new_annotations[f_name] = f_annotation - - new_fields[f_name] = ModelField.infer( - name=f_name, - value=f_value, - annotation=f_annotation, - class_validators=None, - config=cls.__config__, + d: Dict[str, Any] = {} + + for field_name, field in self.settings_cls.model_fields.items(): + field_value, field_key, value_is_complex = self.get_field_value( + field, field_name + ) + field_value = self.prepare_field_value( + field_name, field, field_value, value_is_complex ) + if field_value is not None: + d[field_key] = field_value + + return d - cls.__fields__.update(new_fields) - cls.__annotations__.update(new_annotations) - class Config: - env_file_encoding = "utf-8" - env_nested_delimiter = "__" - env_prefix = "gnatss_" +class BaseConfiguration(BaseSettings): + """Base configuration class""" + + model_config = SettingsConfigDict( + env_file_encoding="utf-8", + env_nested_delimiter="__", + env_prefix="gnatss_", + ) - @classmethod - def customise_sources( - cls, + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, # noqa + file_secret_settings: PydanticBaseSettingsSource, + ): + return ( init_settings, + YamlConfigSettingsSource(settings_cls), env_settings, file_secret_settings, - ): - return ( - init_settings, - yaml_config_settings_source, - env_settings, - file_secret_settings, - ) + ) class Configuration(BaseConfiguration): @@ -119,12 +120,10 @@ class Configuration(BaseConfiguration): solver: Optional[Solver] = Field(None, description="Solver configurations") output: OutputPath - def __init__(__pydantic_self__, **data): + def __init__(self, **data): super().__init__(**data) # Set the transponders pxp id based on the site id - transponders = __pydantic_self__.solver.transponders + transponders = self.solver.transponders for idx in range(len(transponders)): - transponders[idx].pxp_id = "-".join( - [__pydantic_self__.site_id, str(idx + 1)] - ) + transponders[idx].pxp_id = "-".join([self.site_id, str(idx + 1)]) diff --git a/src/gnatss/configs/solver.py b/src/gnatss/configs/solver.py index a52d996..04ea327 100644 --- a/src/gnatss/configs/solver.py +++ b/src/gnatss/configs/solver.py @@ -3,10 +3,11 @@ The solver module containing base models for solver configuration """ -from typing import Any, List, Literal, Optional +from functools import cached_property +from typing import List, Literal, Optional from uuid import uuid4 -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import BaseModel, Field, computed_field from .io import InputData @@ -16,20 +17,13 @@ class ReferenceEllipsoid(BaseModel): semi_major_axis: float = Field(..., description="Semi-major axis (m)") reverse_flattening: float = Field(..., description="Reverse flattening") - eccentricity: Optional[float] = Field( - None, - description="Eccentricity. **This field will be computed during object creation**", - ) - - def __init__(__pydantic_self__, **data: Any) -> None: - super().__init__(**data) - # Note: Potential improvement with computed value - # https://github.com/pydantic/pydantic/pull/2625 - __pydantic_self__.eccentricity = ( - 2.0 / __pydantic_self__.reverse_flattening - - (1.0 / __pydantic_self__.reverse_flattening) ** 2.0 - ) + @computed_field( + description="Eccentricity. **This field will be computed during object creation**" # noqa + ) + @cached_property + def eccentricity(self) -> Optional[float]: + return 2.0 / self.reverse_flattening - (1.0 / self.reverse_flattening) ** 2.0 class ArrayCenter(BaseModel): @@ -59,14 +53,14 @@ class SolverInputs(BaseModel): class SolverGlobal(BaseModel): """Solver global base model for inversion process.""" - max_dat = 45000 - max_gps = 423000 - max_del = 15000 - max_brk = 20 - max_surv = 10 - max_sdt_obs = 2000 - max_obm = 472 - max_unmm = 9 + max_dat: int = 45000 + max_gps: int = 423000 + max_del: int = 15000 + max_brk: int = 20 + max_surv: int = 10 + max_sdt_obs: int = 2000 + max_obm: int = 472 + max_unmm: int = 9 class SolverTransponder(BaseModel): @@ -94,14 +88,12 @@ class SolverTransponder(BaseModel): "**This field will be computed during object creation**" ), ) - # Auto generated uuid per transponder for unique identifier - _uuid: str = PrivateAttr() - - def __init__(__pydantic_self__, **data: Any) -> None: - super().__init__(**data) - # A solver transponder unique identifier - __pydantic_self__._uuid = uuid4().hex + @computed_field(repr=False, description="Transponder unique identifier") + @cached_property + def _uuid(self) -> str: + """Auto generated uuid per transponder for unique identifier""" + return uuid4().hex class Solver(BaseModel): diff --git a/src/gnatss/loaders.py b/src/gnatss/loaders.py index eb19173..7f77209 100644 --- a/src/gnatss/loaders.py +++ b/src/gnatss/loaders.py @@ -4,7 +4,7 @@ import pandas as pd import yaml -from pydantic.error_wrappers import ValidationError +from pydantic import ValidationError from . import constants from .configs.main import Configuration