From 236230046d40c1d0f7f4b444d6e9a033df700c2d Mon Sep 17 00:00:00 2001 From: Matt Craig Date: Thu, 1 Feb 2024 10:41:43 -0600 Subject: [PATCH 01/15] Initial attempt at pydantic 2 rewrite --- pyproject.toml | 6 +- stellarphot/settings/__init__.py | 1 - stellarphot/settings/astropy_pydantic.py | 252 ++++++++----- stellarphot/settings/autowidgets.py | 17 - stellarphot/settings/models.py | 339 +++++++++--------- .../settings/tests/test_astropy_pydantic.py | 121 +++++++ stellarphot/settings/tests/test_models.py | 99 ++--- 7 files changed, 519 insertions(+), 316 deletions(-) delete mode 100644 stellarphot/settings/autowidgets.py create mode 100644 stellarphot/settings/tests/test_astropy_pydantic.py diff --git a/pyproject.toml b/pyproject.toml index 180786e9..507d2cb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,13 +19,13 @@ dependencies = [ "bottleneck", "ccdproc", "ginga", - "ipyautoui >=0.5.9", + "ipyautoui >=0.7", "ipyfilechooser", "ipywidgets", "matplotlib", "pandas", "photutils >=1.9", - "pydantic<2", + "pydantic >=2", "pyyaml", ] @@ -155,4 +155,6 @@ filterwarnings = [ 'ignore:Passing unrecognized arguments to super:DeprecationWarning', # pandas will require pyarrow at some point, which is good to know, I guess... 'ignore:[.\n]*Pyarrow will become a required dependency of pandas[.\n]*:DeprecationWarning', + # ipyautoui is generating this on import because they still have some pydantic changes to make + 'ignore:Using extra keyword arguments on `Field` is deprecated:' ] diff --git a/stellarphot/settings/__init__.py b/stellarphot/settings/__init__.py index 56c8955b..378409df 100644 --- a/stellarphot/settings/__init__.py +++ b/stellarphot/settings/__init__.py @@ -1,3 +1,2 @@ -from .autowidgets import * from .models import * from .views import * diff --git a/stellarphot/settings/astropy_pydantic.py b/stellarphot/settings/astropy_pydantic.py index 452a3923..11aa00b1 100644 --- a/stellarphot/settings/astropy_pydantic.py +++ b/stellarphot/settings/astropy_pydantic.py @@ -1,106 +1,184 @@ -from astropy.units import Quantity, Unit - -__all__ = ["UnitType", "QuantityType", "PixelScaleType"] +from dataclasses import dataclass +from typing import Annotated, Any + +from astropy.units import ( + PhysicalType, + Quantity, + Unit, + UnitBase, + UnitConversionError, + get_physical_type, +) +from pydantic import ( + GetCoreSchemaHandler, +) +from pydantic_core import core_schema + +__all__ = ["UnitType", "QuantityType", "EquivalentTo", "WithPhysicalType"] + +_PHYSICAL_TYPES_URL = "https://docs.astropy.org/en/stable/units/ref_api.html#module-astropy.units.physical" # Approach to validation of units was inspired by the GammaPy project # which did it before we did: # https://docs.gammapy.org/dev/_modules/gammapy/analysis/config.html +# Update for pydantic 2.0, based on the pydantic docs: +# https://docs.pydantic.dev/latest/concepts/types/#handling-third-party-types -class UnitType(Unit): - # Validator for Unit type - @classmethod - def __get_validators__(cls): - yield cls.validate - @classmethod - def validate(cls, v): - return Unit(v) +class _UnitQuantTypePydanticAnnotation: + """ + This class is used to annotate fields where validation consists of checking + whether an instance can be created. + + In astropy, this includes `astropy.units.Unit` and `astropy.units.Quantity`. + """ @classmethod - def __modify_schema__(cls, field_schema, field): - # Set default values for the schema in case the field doesn't provide them - name = "Unit" - description = "An astropy unit" - - name = field.name or name - description = field.field_info.description or description - examples = field.field_info.extra.get("examples", []) - - field_schema.update( - { - "title": name, - "description": description, - "examples": examples, - "type": "string", - } + def __get_pydantic_core_schema__( + cls, + source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + """ + We return a pydantic_core.CoreSchema that behaves in the following ways: + + * A Unit or a Quantity will pass validation and be returned as-is + * A string or a float will be used to create a Unit or a Quantity. + * Nothing else will pass validation + * Serialization will always return just a string + """ + + def validate_by_instantiating(value): + # If the value is valid we will be able to create an instance of the + # source_type from it. For example, if source_type is astropy.units.Unit, + # then we should be able to create a Unit from the value. + try: + result = source_type(value) + except TypeError as err: + raise ValueError(str(err)) from err + return result + + # Both Unit and Qunatity can be created from a string or a float or an + # instance of the same type. So we need to check for all of those cases. + + # core_schema.chain_schema runs the value through each of the schema + # in the list, in order. The output of one schema is the input to the next. + + # When you do `model_json_schema` with a `chain_schema`, then the first entry is + # used if `mode="validation"` and the last is used if `mode="serialization"` + # from the schema used to serialize json. + + # I guess this makes sense, since the first thing in the chain has to handle the + # value coming from json, while the last thing generates the python value for + # the input. + from_str_schema = core_schema.chain_schema( + [ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function(validate_by_instantiating), + ] ) + from_float_schema = core_schema.chain_schema( + [ + core_schema.float_schema(), + core_schema.no_info_plain_validator_function(validate_by_instantiating), + ] + ) + return core_schema.json_or_python_schema( + json_schema=from_str_schema, + # union_schema takes a list of schemas and returns a schema that + # is the "best" match. See the link below for a description of + # what counts as "best": + # https://docs.pydantic.dev/dev/concepts/unions/#smart-mode + # + # In short, schemas are tried from left-to-right, and an exact type match + # wins. + python_schema=core_schema.union_schema( + [ + # Check if it's an instance first before doing any further work. + # Would be nice to provide a list of classes here instead + # of one-by-one. + core_schema.is_instance_schema(UnitBase), + core_schema.is_instance_schema(Quantity), + from_str_schema, + from_float_schema, + ] + ), + # Serialization by converting to a string. + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: str(instance) + ), + ) -class QuantityType(Quantity): - # Validator for Quantity type - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, v): - try: - v = Quantity(v) - except TypeError as err: - raise ValueError(f"Invalid value for Quantity: {v}") from err - else: - if not v.unit.bases: - raise ValueError("Must provided a unit") - return v - @classmethod - def __modify_schema__(cls, field_schema, field): - # Set default values for the schema in case the field doesn't provide them - name = "Quantity" - description = "An astropy Quantity with units" - - name = field.name or name - description = field.field_info.description or description - examples = field.field_info.extra.get("examples", []) - - field_schema.update( - { - "title": name, - "description": description, - "examples": examples, - "type": "string", - } +@dataclass +class EquivalentTo: + equivalent_unit: Unit + + def __get_pydantic_core_schema__( + self, source_type: Any, handler: GetCoreSchemaHandler + ): + def check_equivalent(value): + if isinstance(value, UnitBase): + value_unit = value + else: + value_unit = value.unit + + try: + value.to(self.equivalent_unit) + except UnitConversionError: + raise ValueError( + f"Unit {value_unit} is not equivalent to {self.equivalent_unit}" + ) from None + return value + + return core_schema.no_info_after_validator_function( + check_equivalent, handler(source_type) ) -class PixelScaleType(Quantity): - # Validator for pixel scale type - @classmethod - def __get_validators__(cls): - yield cls.validate +@dataclass +class WithPhysicalType: + physical_type: str | PhysicalType - @classmethod - def validate(cls, v): + def __post_init__(self): try: - v = Quantity(v) - except TypeError as err: - raise ValueError(f"Invalid value for Quantity: {v}") from err - if ( - len(v.unit.bases) != 2 - or v.unit.bases[0].physical_type != "angle" - or v.unit.bases[1].name != "pix" - ): - raise ValueError(f"Invalid unit for pixel scale: {v.unit!r}") - return v - - @classmethod - def __modify_schema__(cls, field_schema): - field_schema.update( - { - "title": "PixelScale", - "description": "An astropy Quantity with units of angle per pixel", - "examples": ["0.563 arcsec / pix"], - "type": "string", - } + get_physical_type(self.physical_type) + except ValueError as err: + raise ValueError( + str(err) + + f"\nSee {_PHYSICAL_TYPES_URL} for a list of valid physical types." + ) from err + + def __get_pydantic_core_schema__( + self, source_type: Any, handler: GetCoreSchemaHandler + ): + def check_physical_type(value): + is_same = get_physical_type(value) == get_physical_type(self.physical_type) + if is_same: + return value + else: + raise ValueError( + f"Unit of {value} is not equivalent to {self.physical_type}" + ) from None + + return core_schema.no_info_after_validator_function( + check_physical_type, handler(source_type) ) + + +# We have lost default titles and exmples, but that is maybe not so bad + +# This is really nice compared to pydantiv v1... +UnitType = Annotated[Unit, _UnitQuantTypePydanticAnnotation] + +# Quantity type is really clean too +QuantityType = Annotated[Quantity, _UnitQuantTypePydanticAnnotation] + + +# It turns out all (almost all?) astropy types have, buried in them, a representation +# that is a python dictionary info._represent_as_dict and info._construct_from_dict. +# This is what is used to represent astropy +# objects in Tables and FITS files. So we can use this to create a json schema +# for the astropy types. diff --git a/stellarphot/settings/autowidgets.py b/stellarphot/settings/autowidgets.py deleted file mode 100644 index e6dccbeb..00000000 --- a/stellarphot/settings/autowidgets.py +++ /dev/null @@ -1,17 +0,0 @@ -# Some classes for ipyautoui that really belong there, not here - -import ipywidgets as w -from ipyautoui.autowidgets import create_widget_caller - -__all__ = ["CustomBoundedIntTex"] - - -class CustomBoundedIntTex(w.BoundedIntText): - """ - A BoundedIntText widget adapted for use in ipyautoui. - """ - - def __init__(self, schema): - self.schema = schema - self.caller = create_widget_caller(schema) - super().__init__(**self.caller) diff --git a/stellarphot/settings/models.py b/stellarphot/settings/models.py index f808ac10..ea4448af 100644 --- a/stellarphot/settings/models.py +++ b/stellarphot/settings/models.py @@ -1,18 +1,29 @@ # Objects that contains the user settings for the program. from pathlib import Path +from typing import Annotated -import astropy.units as u -from astropy.coordinates import SkyCoord from astropy.io.misc.yaml import AstropyDumper, AstropyLoader -from astropy.time import Time -from astropy.units import IrreducibleUnit, Quantity, Unit -from pydantic import BaseModel, Field, confloat, conint, root_validator, validator +from astropy.units import Quantity, Unit +from pydantic import BaseModel, ConfigDict, Field, confloat, conint, model_validator -from .astropy_pydantic import PixelScaleType, QuantityType, UnitType -from .autowidgets import CustomBoundedIntTex +from .astropy_pydantic import EquivalentTo, QuantityType, UnitType -__all__ = ["Camera", "PhotometryApertures", "PhotometryFileSettings", "Exoplanet"] +__all__ = [ + "Camera", + "PhotometryApertures", + "PhotometryFileSettings", +] # "Exoplanet"] + +# Most models should use the default configuration, but it can be customized if needed. +MODEL_DEFAULT_CONFIGURATION = ConfigDict( + # Make sure default values are valid + validate_default=True, + # Make sure changes to values made after initialization are valid + validate_assignment=True, + # Make sure there are no extra fields + extra="forbid", +) class Camera(BaseModel): @@ -107,6 +118,7 @@ class Camera(BaseModel): """ + model_config = MODEL_DEFAULT_CONFIGURATION data_unit: UnitType = Field( description="units of the data", examples=["adu", "counts", "DN", "electrons"] ) @@ -123,41 +135,34 @@ class Camera(BaseModel): description="unit consistent with read noise, per unit time", examples=["0.01 electron / second"], ) - pixel_scale: PixelScaleType = Field( - description="units of angle per pixel", examples=["0.6 arcsec / pix"] - ) - max_data_value: QuantityType = Field( - description="maximum data value while performing photometry", - examples=["50000 adu"], - ) - - class Config: - validate_all = True - validate_assignment = True - extra = "forbid" - json_encoders = { - Quantity: lambda v: f"{v.value} {v.unit}", - QuantityType: lambda v: f"{v.value} {v.unit}", - Unit: lambda v: f"{v}", - IrreducibleUnit: lambda v: f"{v}", - PixelScaleType: lambda v: f"{v.value} {v.unit}", - } - - # When the switch to pydantic v2 happens, this root_validator will need - # to be replaced by a model_validator decorator. - @root_validator(skip_on_failure=True) - @classmethod - def validate_gain(cls, values): + pixel_scale: Annotated[ + QuantityType, + EquivalentTo(Unit("arcsec / pix")), + Field(description="units of angle per pixel", examples=["0.6 arcsec / pix"]), + ] + max_data_value: Annotated[ + QuantityType, + Field( + description="maximum data value while performing photometry", + examples=["50000 adu"], + gt=0, + ), + ] + + # Run the model validator after the default validator. Unlike in pydantic 1, + # mode="after" passes in an instance as an argument not a value. + @model_validator(mode="after") + def validate_gain(self): # Get read noise units - rn_unit = Quantity(values["read_noise"]).unit + rn_unit = Quantity(self.read_noise).unit # Check that gain and read noise have compatible units, that is that # gain is read noise per data unit. - gain = values["gain"] + gain = self.gain if ( len(gain.unit.bases) != 2 or gain.unit.bases[0] != rn_unit - or gain.unit.bases[1] != values["data_unit"] + or gain.unit.bases[1] != self.data_unit ): raise ValueError( f"Gain units {gain.unit} are not compatible with " @@ -166,11 +171,11 @@ def validate_gain(cls, values): # Check that dark current and read noise have compatible units, that is # that dark current is read noise per second. - dark_current = values["dark_current"] + dark_current = self.dark_current if ( len(dark_current.unit.bases) != 2 or dark_current.unit.bases[0] != rn_unit - or dark_current.unit.bases[1] != u.s + or dark_current.unit.bases[1] != Unit("s") ): raise ValueError( f"Dark current units {dark_current.unit} are not " @@ -178,24 +183,17 @@ def validate_gain(cls, values): ) # Check that maximum data value is consistent with data units - if values["max_data_value"].unit != values["data_unit"]: + if self.max_data_value.unit != self.data_unit: raise ValueError( - f"Maximum data value units {values['max_data_value'].unit} " - f"are not consistent with data units {values['data_unit']}." + f"Maximum data value units {self.max_data_value.unit} " + f"are not consistent with data units {self.data_unit}." ) - return values - - @validator("max_data_value") - @classmethod - def validate_max_data_value(cls, v): - if v.value <= 0: - raise ValueError("max_data_value must be positive") - return v + return self # Add YAML round-tripping for Camera -def _camera_representer(dumper, cam): - return dumper.represent_mapping("!Camera", cam.dict()) +def camera_representer(dumper, cam): + return dumper.represent_mapping("!Camera", cam.model_dump()) def _camera_constructor(loader, node): @@ -248,17 +246,20 @@ class PhotometryApertures(BaseModel): ... ) """ - radius: conint(ge=1) = Field(autoui=CustomBoundedIntTex, default=1) - gap: conint(ge=1) = Field(autoui=CustomBoundedIntTex, default=1) - annulus_width: conint(ge=1) = Field(autoui=CustomBoundedIntTex, default=1) + # model_config = MODEL_DEFAULT_CONFIGURATION + + radius: conint(ge=1) = Field( + default=1, json_schema_extra=dict(autoui="ipywidgets.BoundedIntText") + ) + gap: conint(ge=1) = Field( + default=1, json_schema_extra=dict(autoui="ipywidgets.BoundedIntText") + ) + annulus_width: conint(ge=1) = Field( + default=1, json_schema_extra=dict(autoui="ipywidgets.BoundedIntText") + ) # Disable the UI element by default because it is often calculate from an image fwhm: confloat(gt=0) = Field(disabled=True, default=1.0) - class Config: - validate_assignment = True - validate_all = True - extra = "forbid" - @property def inner_annulus(self): """ @@ -279,6 +280,8 @@ class PhotometryFileSettings(BaseModel): An evolutionary step on the way to having a monolithic set of photometry settings. """ + model_config = MODEL_DEFAULT_CONFIGURATION + image_folder: Path = Field( show_only_dirs=True, default="", @@ -290,111 +293,113 @@ class PhotometryFileSettings(BaseModel): ) -class TimeType(Time): - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, v): - return Time(v) - - -class SkyCoordType(SkyCoord): - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, v): - return SkyCoord(v) - - -class Exoplanet(BaseModel): - """ - Create an object representing an Exoplanet. - - Parameters - ---------- - - epoch : `astropy.time.Time`, optional - Epoch of the exoplanet. - - period : `astropy.units.Quantity`, optional - Period of the exoplanet. - - Identifier : str - Identifier of the exoplanet. - - coordinate : `astropy.coordinates.SkyCoord` - Coordinates of the exoplanet. - - depth : float - Depth of the exoplanet. - - duration : `astropy.units.Quantity`, optional - Duration of the exoplanet transit. - - Examples - -------- - - To create an `Exoplanet` object, you can pass in the epoch, - period, identifier, coordinate, depth, and duration as keyword arguments: - - >>> from astropy.time import Time - >>> from astropy.coordinates import SkyCoord - >>> from astropy import units as u - >>> planet = Exoplanet(epoch=Time(2455909.29280, format="jd"), - ... period=1.21749 * u.day, - ... identifier="KELT-1b", - ... coordinate=SkyCoord(ra="00:01:26.9169", - ... dec="+39:23:01.7821", - ... frame="icrs", - ... unit=("hour", "degree")), - ... depth=0.006, - ... duration=120 * u.min) - """ - - epoch: TimeType | None = None - period: QuantityType | None = None - identifier: str - coordinate: SkyCoordType - depth: float | None = None - duration: QuantityType | None = None - - class Config: - validate_all = True - validate_assignment = True - extra = "forbid" - json_encoders = { - Quantity: lambda v: f"{v.value} {v.unit}", - QuantityType: lambda v: f"{v.value} {v.unit}", - Time: lambda v: f"{v.value}", - } - - @validator("period") - @classmethod - def validate_period(cls, value): - """ - Checks that the period has physical units of time and raises an error - if that is not true. - """ - if u.get_physical_type(value) != "time": - raise ValueError( - f"Period does not have time units," f"currently has {value.unit} units." - ) - return value - - @validator("duration") - @classmethod - def validate_duration(cls, value): - """ - Checks that the duration has physical units of time and raises an error - if that is not true. - """ - if u.get_physical_type(value) != "time": - raise ValueError( - f"Duration does not have time units," - f"currently has {value.unit} units." - ) - return value +# class TimeType(Time): +# @classmethod +# def __get_validators__(cls): +# yield cls.validate + +# @classmethod +# def validate(cls, v): +# return Time(v) + + +# class SkyCoordType(SkyCoord): +# @classmethod +# def __get_validators__(cls): +# yield cls.validate + +# @classmethod +# def validate(cls, v): +# return SkyCoord(v) + + +# class Exoplanet(BaseModel): +# """ +# Create an object representing an Exoplanet. + +# Parameters +# ---------- + +# epoch : `astropy.time.Time`, optional +# Epoch of the exoplanet. + +# period : `astropy.units.Quantity`, optional +# Period of the exoplanet. + +# Identifier : str +# Identifier of the exoplanet. + +# coordinate : `astropy.coordinates.SkyCoord` +# Coordinates of the exoplanet. + +# depth : float +# Depth of the exoplanet. + +# duration : `astropy.units.Quantity`, optional +# Duration of the exoplanet transit. + +# Examples +# -------- + +# To create an `Exoplanet` object, you can pass in the epoch, +# period, identifier, coordinate, depth, and duration as keyword arguments: + +# >>> from astropy.time import Time +# >>> from astropy.coordinates import SkyCoord +# >>> from astropy import units as u +# >>> planet = Exoplanet(epoch=Time(2455909.29280, format="jd"), +# ... period=1.21749 * u.day, +# ... identifier="KELT-1b", +# ... coordinate=SkyCoord(ra="00:01:26.9169", +# ... dec="+39:23:01.7821", +# ... frame="icrs", +# ... unit=("hour", "degree")), +# ... depth=0.006, +# ... duration=120 * u.min) +# """ +# model_config = MODEL_DEFAULT_CONFIGURATION + +# epoch: TimeType | None = None +# period: QuantityType | None = None +# identifier: str +# coordinate: SkyCoordType +# depth: float | None = None +# duration: QuantityType | None = None + +# # class Config: +# # validate_all = True +# # validate_assignment = True +# # extra = "forbid" +# # json_encoders = { +# # Quantity: lambda v: f"{v.value} {v.unit}", +# # QuantityType: lambda v: f"{v.value} {v.unit}", +# # Time: lambda v: f"{v.value}", +# # } + +# @validator("period") +# @classmethod +# def validate_period(cls, value): +# """ +# Checks that the period has physical units of time and raises an error +# if that is not true. +# """ +# if u.get_physical_type(value) != "time": +# raise ValueError( +# f"Period does not have time units," +# f"currently has {value.unit} units." +# ) +# return value + +# @validator("duration") +# @classmethod +# def validate_duration(cls, value): +# """ +# Checks that the duration has physical units of time and raises an error +# if that is not true. +# """ +# if u.get_physical_type(value) != "time": +# raise ValueError( +# f"Duration does not have time units," +# f"currently has {value.unit} units." +# ) +# return value diff --git a/stellarphot/settings/tests/test_astropy_pydantic.py b/stellarphot/settings/tests/test_astropy_pydantic.py new file mode 100644 index 00000000..e8649502 --- /dev/null +++ b/stellarphot/settings/tests/test_astropy_pydantic.py @@ -0,0 +1,121 @@ +from typing import Annotated + +import pytest +from astropy.units import Quantity, Unit, get_physical_type +from pydantic import BaseModel, ValidationError + +from stellarphot.settings.astropy_pydantic import ( + EquivalentTo, + QuantityType, + UnitType, + WithPhysicalType, +) + + +class _UnitModel(BaseModel): + # Dummy class for testing + unit: UnitType + + +class _QuantityModel(BaseModel): + # Dummy class for testing + quantity: QuantityType + + +@pytest.mark.parametrize( + "init", + [ + Unit("m"), + 1, + "meter", + "parsec / fortnight", + "", + ], +) +def test_unit_initialization(init): + # Make sure we can initialize from each of the ways a Unit can + # be initiailized + expected = Unit(init) + unit = _UnitModel(unit=init) + assert expected == unit.unit + + +@pytest.mark.parametrize( + "init", + [ + -2 * Unit("m"), + 1, + "5 meter", + "13 parsec / fortnight", + "42", + ], +) +def test_quantity_intialization(init): + # Make sure we can initialize from each of the ways a Quantity can + # be initiailized + expected = Quantity(init) + quantity = _QuantityModel(quantity=init) + assert expected == quantity.quantity + + +class _ModelEquivalentTo(BaseModel): + unit_meter: Annotated[UnitType, EquivalentTo(equivalent_unit="m")] + quantity_meter: Annotated[QuantityType, EquivalentTo(equivalent_unit="m")] + + +class _ModelWithPhysicalType(BaseModel): + quant_physical_length: Annotated[QuantityType, WithPhysicalType("length")] + unit_physical_time: Annotated[UnitType, WithPhysicalType("time")] + + +def test_equivalent_to(): + # Make sure we can annotate with an equivalent unit + + # This should succeed + model = _ModelEquivalentTo( + unit_meter="km", + quantity_meter=Quantity(1, "mm"), + ) + assert model.unit_meter == Unit("km") + assert model.quantity_meter.to("m").value == model.quantity_meter.value * 1e-3 + + # Now some failures + + with pytest.raises(ValidationError, match="Unit s is not equivalent to"): + _ModelEquivalentTo(unit_meter="km", quantity_meter=Quantity("1 s")) + + with pytest.raises(ValidationError, match="Unit s is not equivalent to"): + _ModelEquivalentTo(unit_meter="s", quantity_meter=Quantity("1 m")) + + +def test_with_physical_type(): + # Make sure we can annotate with a physical type + model = _ModelWithPhysicalType( + quant_physical_length=Quantity(1, "m"), + unit_physical_time=17 * Unit("second"), + ) + + assert get_physical_type(model.quant_physical_length) == "length" + assert get_physical_type(model.unit_physical_time) == "time" + + # Now some failures + # Pass a time in for physical type of length + with pytest.raises( + ValidationError, match="Unit of 1.0 s is not equivalent to length" + ): + _ModelWithPhysicalType( + quant_physical_length=Quantity(1, "s"), + unit_physical_time=Unit("second"), + ) + + # Pass a length in for physical type of time + with pytest.raises(ValidationError, match="Unit of m is not equivalent to time"): + _ModelWithPhysicalType( + quant_physical_length=Quantity(1, "m"), + unit_physical_time=Unit("meter"), + ) + + +def test_quantity_type_with_invalid_quantity(): + with pytest.raises(ValidationError, match="It does not start with a number"): + _QuantityModel(quantity="meter") diff --git a/stellarphot/settings/tests/test_models.py b/stellarphot/settings/tests/test_models.py index 5984d9e8..d1add1de 100644 --- a/stellarphot/settings/tests/test_models.py +++ b/stellarphot/settings/tests/test_models.py @@ -3,11 +3,12 @@ import astropy.units as u import pytest from astropy.coordinates import SkyCoord +from astropy.table import Table from astropy.time import Time from pydantic import ValidationError from stellarphot.settings import ui_generator -from stellarphot.settings.models import Camera, Exoplanet, PhotometryApertures +from stellarphot.settings.models import Camera, PhotometryApertures DEFAULT_APERTURE_SETTINGS = dict(radius=5, gap=10, annulus_width=15, fwhm=3.2) @@ -39,23 +40,27 @@ def test_camera_attributes(): c = Camera( **TEST_CAMERA_VALUES, ) - assert c.dict() == TEST_CAMERA_VALUES + assert c.model_dump() == TEST_CAMERA_VALUES def test_camera_unitscheck(): # Check that the units are checked properly - # Remove units from all of the Quantity types - camera_dict_no_units = { - k: v.value if hasattr(v, "value") else v for k, v in TEST_CAMERA_VALUES.items() + # Set a clearly incorrect Quantity. Simply removing the units does not lead + # to an invalid Quantity -- it turns out Quantity(5) is valid, with units of + # dimensionless_unscaled. So we need to set the units to something that is + # invalid. + camera_dict_bad_unit = { + k: "5 cows" if hasattr(v, "value") else v for k, v in TEST_CAMERA_VALUES.items() } # All 5 of the attributes after data_unit will be checked for units # and noted in the ValidationError message. Rather than checking # separately for all 5, we just check for the presence of the - # right number of errors - with pytest.raises(ValidationError, match="5 validation errors"): + # right number of errors, which is currently 20 -- 4 for each of the + # 5 attributes, because of the union schema in _UnitTypePydanticAnnotation + with pytest.raises(ValidationError, match="20 validation errors"): Camera( - **camera_dict_no_units, + **camera_dict_bad_unit, ) @@ -65,7 +70,7 @@ def test_camera_negative_max_adu(): camera_for_test["max_data_value"] = -1 * camera_for_test["max_data_value"] # Make sure that a negative max_adu raises an error - with pytest.raises(ValidationError, match="must be positive"): + with pytest.raises(ValidationError, match="Input should be greater than 0"): Camera( **camera_for_test, ) @@ -103,7 +108,7 @@ def test_camera_copy(): c = Camera( **TEST_CAMERA_VALUES, ) - c2 = c.copy() + c2 = c.model_copy() assert c2 == c @@ -122,14 +127,14 @@ def test_camera_altunitscheck(): c = Camera( **camera_for_test, ) - assert c.dict() == camera_for_test + assert c.model_dump() == camera_for_test def test_camera_schema(): # Check that we can generate a schema for a Camera and that it # has the right number of attributes c = Camera(**TEST_CAMERA_VALUES) - schema = c.schema() + schema = c.model_json_schema() assert len(schema["properties"]) == len(TEST_CAMERA_VALUES) @@ -138,10 +143,23 @@ def test_camera_json_round_trip(): c = Camera(**TEST_CAMERA_VALUES) - c2 = Camera.parse_raw(c.json()) + c2 = Camera.model_validate_json(c.model_dump_json()) assert c2 == c +def test_camera_table_round_trip(tmp_path): + # Check that a camera can be stored as part of an astropy.table.Table + # metadata and retrieved + table = Table({"data": [1, 2, 3]}) + c = Camera(**TEST_CAMERA_VALUES) + table.meta["camera"] = c + table_path = tmp_path / "test_table.ecsv" + table.write(table_path) + new_table = Table.read(table_path) + + assert new_table.meta["camera"] == c + + def test_create_aperture_settings_correctly(): ap_set = PhotometryApertures(**DEFAULT_APERTURE_SETTINGS) assert ap_set.radius == DEFAULT_APERTURE_SETTINGS["radius"] @@ -173,21 +191,18 @@ def test_aperture_settings_ui_generation(class_, defaults): # 2) The UI model matches our input # 3) The UI widgets contains the titles we expect # - + instance = class_(**defaults) + instance.model_json_schema() # 1) The UI is generated from the class ui = ui_generator(class_) - print(f"{class_=}") - print(f"{defaults=}") # 2) The UI model matches our input # Set the ui values to the defaults -- the value needs to be whatever would # go into a **widget** though, not a **model**. It is easiest to create # a model and then use its dict() method to get the widget values. - values_dict_as_strings = json.loads(class_(**defaults).json()) - print(f"{values_dict_as_strings=}") + values_dict_as_strings = json.loads(class_(**defaults).model_dump_json()) ui.value = values_dict_as_strings - print(f"{ui.value=}") - assert class_(**ui.value).dict() == defaults + assert class_(**ui.value).model_dump() == defaults # 3) The UI widgets contains the titles generated from pydantic. # Pydantic generically is supposed to generate titles from the field names, @@ -202,11 +217,11 @@ def test_aperture_settings_ui_generation(class_, defaults): pydantic_titles = { f: [f.replace("_", " "), f.replace("_", "")] for f in defaults.keys() } - # pydantic_titles = defaults.keys() title_present = [] - print(f"{ui.di_labels=}") + for title in pydantic_titles.keys(): - for label in ui.di_labels.values(): + for box in ui.di_boxes.values(): + label = box.html_title.value present = ( title.lower() in label.lower() or pydantic_titles[title][0].lower() in label.lower() @@ -230,22 +245,22 @@ def test_create_invalid_values(bad_one): PhotometryApertures(**bad_settings) -def test_create_exoplanet_correctly(): - planet = Exoplanet(**DEFAULT_EXOPLANET_SETTINGS) - print(planet) - assert planet.epoch == DEFAULT_EXOPLANET_SETTINGS["epoch"] - assert u.get_physical_type(planet.period) == "time" - assert planet.identifier == DEFAULT_EXOPLANET_SETTINGS["identifier"] - assert planet.coordinate == DEFAULT_EXOPLANET_SETTINGS["coordinate"] - assert planet.depth == DEFAULT_EXOPLANET_SETTINGS["depth"] - assert u.get_physical_type(planet.duration) == "time" - - -def test_create_invalid_exoplanet(): - values = DEFAULT_EXOPLANET_SETTINGS.copy() - # Make pediod and duration have invalid units for a time - values["period"] = values["period"].value * u.m - values["duration"] = values["duration"].value * u.m - # Check that individual values that are bad raise an error - with pytest.raises(ValidationError, match="2 validation errors"): - Exoplanet(**values) +# def test_create_exoplanet_correctly(): +# planet = Exoplanet(**DEFAULT_EXOPLANET_SETTINGS) +# print(planet) +# assert planet.epoch == DEFAULT_EXOPLANET_SETTINGS["epoch"] +# assert u.get_physical_type(planet.period) == "time" +# assert planet.identifier == DEFAULT_EXOPLANET_SETTINGS["identifier"] +# assert planet.coordinate == DEFAULT_EXOPLANET_SETTINGS["coordinate"] +# assert planet.depth == DEFAULT_EXOPLANET_SETTINGS["depth"] +# assert u.get_physical_type(planet.duration) == "time" + + +# def test_create_invalid_exoplanet(): +# values = DEFAULT_EXOPLANET_SETTINGS.copy() +# # Make pediod and duration have invalid units for a time +# values["period"] = values["period"].value * u.m +# values["duration"] = values["duration"].value * u.m +# # Check that individual values that are bad raise an error +# with pytest.raises(ValidationError, match="2 validation errors"): +# Exoplanet(**values) From a5b73660170b85be46e11b6828a6c70893b71893 Mon Sep 17 00:00:00 2001 From: Matt Craig Date: Mon, 29 Jan 2024 11:06:38 -0600 Subject: [PATCH 02/15] Update photometry and other tests for pydantic --- .../gui_tools/tests/test_seeing_profile.py | 6 +- .../photometry/tests/test_photometry.py | 12 ++-- stellarphot/settings/astropy_pydantic.py | 2 +- .../settings/tests/test_astropy_pydantic.py | 55 +++++++++++++++++++ 4 files changed, 64 insertions(+), 11 deletions(-) diff --git a/stellarphot/gui_tools/tests/test_seeing_profile.py b/stellarphot/gui_tools/tests/test_seeing_profile.py index 5dd375b9..873eab85 100644 --- a/stellarphot/gui_tools/tests/test_seeing_profile.py +++ b/stellarphot/gui_tools/tests/test_seeing_profile.py @@ -76,7 +76,7 @@ def test_seeing_profile_properties(tmp_path): # Check that the photometry apertures have defaulted to the # default values in the model. - assert profile_widget.aperture_settings.value == PhotometryApertures().dict() + assert profile_widget.aperture_settings.value == PhotometryApertures().model_dump() # Get the event handler that updates plots handler = profile_widget._make_show_event() @@ -102,9 +102,7 @@ def test_seeing_profile_properties(tmp_path): new_radius = phot_aps["radius"] - 2 # Change the radius by directly setting the value of the widget that holds # the value. That ends up being nested fairly deeply... - profile_widget.aperture_settings.autowidget.children[0].children[ - 0 - ].value = new_radius + profile_widget.aperture_settings.di_widgets["radius"].value = new_radius # Make sure the settings are updated phot_aps["radius"] = new_radius diff --git a/stellarphot/photometry/tests/test_photometry.py b/stellarphot/photometry/tests/test_photometry.py index 83b0f190..d767f73a 100644 --- a/stellarphot/photometry/tests/test_photometry.py +++ b/stellarphot/photometry/tests/test_photometry.py @@ -87,7 +87,7 @@ def test_calc_noise_source_only(gain, aperture_area): expected = np.sqrt(gain * counts) # Create camera instance - camera = ZERO_CAMERA.copy() + camera = ZERO_CAMERA.model_copy() camera.gain = gain * camera.gain.unit np.testing.assert_allclose( @@ -104,7 +104,7 @@ def test_calc_noise_dark_only(gain, aperture_area): exposure = 20 # Create camera instance - camera = ZERO_CAMERA.copy() + camera = ZERO_CAMERA.model_copy() # Set gain and dark current to values for test camera.dark_current = dark_current * camera.dark_current.unit camera.gain = gain * camera.gain.unit @@ -126,7 +126,7 @@ def test_calc_read_noise_only(gain, aperture_area): expected = np.sqrt(aperture_area * read_noise**2) # Create camera instance - camera = ZERO_CAMERA.copy() + camera = ZERO_CAMERA.model_copy() camera.read_noise = read_noise * camera.read_noise.unit camera.gain = gain * camera.gain.unit @@ -143,7 +143,7 @@ def test_calc_sky_only(gain, aperture_area): expected = np.sqrt(gain * aperture_area * sky) # Create camera instance - camera = ZERO_CAMERA.copy() + camera = ZERO_CAMERA.model_copy() camera.gain = gain * camera.gain.unit np.testing.assert_allclose( @@ -162,7 +162,7 @@ def test_annulus_area_term(): expected = np.sqrt(gain * aperture_area * (1 + aperture_area / annulus_area) * sky) # Create camera instance - camera = ZERO_CAMERA.copy() + camera = ZERO_CAMERA.model_copy() camera.gain = gain * camera.gain.unit np.testing.assert_allclose( @@ -192,7 +192,7 @@ def test_calc_noise_messy_case(digit, expected): read_noise = 12 # Create camera instance - camera = ZERO_CAMERA.copy() + camera = ZERO_CAMERA.model_copy() camera.gain = gain * camera.gain.unit camera.dark_current = dark_current * camera.dark_current.unit camera.read_noise = read_noise * camera.read_noise.unit diff --git a/stellarphot/settings/astropy_pydantic.py b/stellarphot/settings/astropy_pydantic.py index 11aa00b1..ac2c239b 100644 --- a/stellarphot/settings/astropy_pydantic.py +++ b/stellarphot/settings/astropy_pydantic.py @@ -59,7 +59,7 @@ def validate_by_instantiating(value): raise ValueError(str(err)) from err return result - # Both Unit and Qunatity can be created from a string or a float or an + # Both Unit and Quantity can be created from a string or a float or an # instance of the same type. So we need to check for all of those cases. # core_schema.chain_schema runs the value through each of the schema diff --git a/stellarphot/settings/tests/test_astropy_pydantic.py b/stellarphot/settings/tests/test_astropy_pydantic.py index e8649502..d0fd18d8 100644 --- a/stellarphot/settings/tests/test_astropy_pydantic.py +++ b/stellarphot/settings/tests/test_astropy_pydantic.py @@ -1,3 +1,4 @@ +import json from typing import Annotated import pytest @@ -119,3 +120,57 @@ def test_with_physical_type(): def test_quantity_type_with_invalid_quantity(): with pytest.raises(ValidationError, match="It does not start with a number"): _QuantityModel(quantity="meter") + + +@pytest.mark.parametrize( + "input_json_string", + [ + '{"quantity": "1 m"}', + '{"quantity": "1"}', + '{"quantity": "3 second"}', + ], +) +def test_initialize_quantity_with_json(input_json_string): + # Make sure we can initialize a Quantity from a json string + # where the quantity value is stored in the json as a string. + model = _QuantityModel.model_validate_json(input_json_string) + model_json = json.loads(model.model_dump_json()) + input_json = json.loads(input_json_string) + + assert Quantity(model_json["quantity"]) == Quantity(input_json["quantity"]) + + +def test_initialize_quantity_with_json_invalid(): + # Make sure we get an error when the json string is has a value + # that is a float (same fail happens for integer). + # Since our json validation assumes the value is a string, this + # should fail. + with pytest.raises(ValidationError, match="Input should be a valid string"): + _QuantityModel.model_validate_json('{"quantity": 14.0}') + + +@pytest.mark.parametrize( + "input_json_string", + [ + '{"unit": "m"}', + '{"unit": "1"}', + '{"unit": "parsec / fortnight"}', + ], +) +def test_initialize_unit_with_json(input_json_string): + # Make sure we can initialize a Unit from a json string + # where the quantity value is stored in the json as a string. + model = _UnitModel.model_validate_json(input_json_string) + model_json = json.loads(model.model_dump_json()) + input_json = json.loads(input_json_string) + + assert Unit(model_json["unit"]) == Unit(input_json["unit"]) + + +def test_initialize_unit_with_json_invalid(): + # Make sure we get an error when the json string is has a value + # that is a float (same fail happens for integer). + # Since our json validation assumes the value is a string, this + # should fail. + with pytest.raises(ValidationError, match="Input should be a valid string"): + _UnitModel.model_validate_json('{"unit": 14.0}') From ba90eebe18a7e177122d4e74f8a32634e7aad17d Mon Sep 17 00:00:00 2001 From: Matt Craig Date: Mon, 29 Jan 2024 13:08:52 -0600 Subject: [PATCH 03/15] Add more documentation to new classes --- stellarphot/settings/astropy_pydantic.py | 151 +++++++++++++++++++++-- 1 file changed, 144 insertions(+), 7 deletions(-) diff --git a/stellarphot/settings/astropy_pydantic.py b/stellarphot/settings/astropy_pydantic.py index ac2c239b..3168a4e2 100644 --- a/stellarphot/settings/astropy_pydantic.py +++ b/stellarphot/settings/astropy_pydantic.py @@ -32,6 +32,15 @@ class _UnitQuantTypePydanticAnnotation: whether an instance can be created. In astropy, this includes `astropy.units.Unit` and `astropy.units.Quantity`. + + We return a pydantic_core.CoreSchema that behaves in the following ways: + * When initializing from python: + * A Unit or a Quantity will pass validation and be returned as-is + * A string or a float will be used to create a Unit or a Quantity. + * Nothing else will pass validation + * When initializing from json: + * The Unit or Quantity must be represented as a string in the json. + * Serialization will always represent the Unit or Quantity as a string. """ @classmethod @@ -49,6 +58,7 @@ def __get_pydantic_core_schema__( * Serialization will always return just a string """ + # This function will end up creating the instance from the input value. def validate_by_instantiating(value): # If the value is valid we will be able to create an instance of the # source_type from it. For example, if source_type is astropy.units.Unit, @@ -56,6 +66,8 @@ def validate_by_instantiating(value): try: result = source_type(value) except TypeError as err: + # Need to raise a ValueError for pydantic to catch it as a + # validation error. raise ValueError(str(err)) from err return result @@ -65,13 +77,14 @@ def validate_by_instantiating(value): # core_schema.chain_schema runs the value through each of the schema # in the list, in order. The output of one schema is the input to the next. - # When you do `model_json_schema` with a `chain_schema`, then the first entry is - # used if `mode="validation"` and the last is used if `mode="serialization"` - # from the schema used to serialize json. + # The schema below validates a Unit or Quantity from a string. The first link + # in the chain is a schema that validates a string. The second link is a + # schema that takes that string value and creates a Unit or Quantity from it. - # I guess this makes sense, since the first thing in the chain has to handle the - # value coming from json, while the last thing generates the python value for - # the input. + # This schema is used in three places by pydantic: + # 1. When validating a python value + # 2. When validating a json value + # 3. When constructing a JSON schema for the model from_str_schema = core_schema.chain_schema( [ core_schema.str_schema(), @@ -79,13 +92,37 @@ def validate_by_instantiating(value): ] ) + # This schema validates a Unit or Quantity from a float. The first link in the + # chain is a schema that validates the input as a number. The second link is a + # schema that takes that numeric value and creates a Unit or Quantity from it. + + # This schema is used in just one place by pydantic: + # 1. When validating a python value from_float_schema = core_schema.chain_schema( [ core_schema.float_schema(), core_schema.no_info_plain_validator_function(validate_by_instantiating), ] ) + return core_schema.json_or_python_schema( + # The next line specifies two things: + # 1. The schema used to validate value from JSON. Since we are using the + # schema for a string value, the values in the JSON file must be + # strings, even though something like "{quantity: 1}" i.e. initializing + # from a number, would work in python. The reason for this choice is + # that the serialization is to a string, so that is what we expect from + # JSON. If we wanted to allow initialization from a number in JSON, we + # would need to use a union schema that consisted of + # from_str_schema and from_float_schema. + # 2. The schema used to construct a JSON schema for the model. When you do + # `model_json_schema` with a `chain_schema`, then the first entry of + # the chain is used if `mode="validation"` and the last entry of the + # chain is used if `mode="serialization"`. I guess this makes sense, + # since the first thing in the chain has to handle the value coming + # from json,while the last thing generates the python value for the + # input. With the choice below we will *always* want`mode="validation"` + # because pydantic cannot generate a schema fora Unit or Quantity. json_schema=from_str_schema, # union_schema takes a list of schemas and returns a schema that # is the "best" match. See the link below for a description of @@ -94,6 +131,10 @@ def validate_by_instantiating(value): # # In short, schemas are tried from left-to-right, and an exact type match # wins. + # + # The construction below tries to make a value starting from a Unit, + # a Quantity, a string, or a float. The first two are instances, so we + # use is_instance_schema. python_schema=core_schema.union_schema( [ # Check if it's an instance first before doing any further work. @@ -105,7 +146,7 @@ def validate_by_instantiating(value): from_float_schema, ] ), - # Serialization by converting to a string. + # Serialize by converting the Unit or Quantity to a string. serialization=core_schema.plain_serializer_function_ser_schema( lambda instance: str(instance) ), @@ -114,12 +155,58 @@ def validate_by_instantiating(value): @dataclass class EquivalentTo: + """ + This class is a pydantic "marker" (their word for this kind of thing) that + can be used to annotate fields that should be equivalent to a given unit. + + Parameters + ---------- + equivalent_unit : `astropy.units.Unit` + The unit that the annotated field should be equivalent to. + + Examples + -------- + >>> from typing import Annotated + >>> from pydantic import BaseModel, ValidationError + >>> from stellarphot.settings.astropy_pydantic import ( + ... WithPhysicalType, + ... UnitType, + ... QuantityType + ... ) + >>> class UnitModel(BaseModel): + ... length: Annotated[UnitType, EquivalentTo("m")] + >>> UnitModel(length="lightyear") + UnitModel(length=Unit("lyr")) + >>> try: + ... UnitModel(length="second") + ... except ValidationError as err: + ... print(err) + 1 validation error for UnitModel + length + Value error, Unit s is not equivalent to m... + >>> # Next let's do a Quantity + >>> class QuantityModel(BaseModel): + ... velocity: Annotated[QuantityType, EquivalentTo("m / s")] + >>> QuantityModel(velocity="3 lightyear / year") + QuantityModel(velocity=) + >>> try: + ... QuantityModel(velocity="3 parsec / lightyear") + ... except ValidationError as err: + ... print(err) + 1 validation error for QuantityModel + velocity + Value error, Unit pc / lyr is not equivalent to m / s... + """ + equivalent_unit: Unit + """Unit that the annotated field should be equivalent to.""" def __get_pydantic_core_schema__( self, source_type: Any, handler: GetCoreSchemaHandler ): def check_equivalent(value): + # We are getting either a Unit or a Quantity. If it's a Quantity, we + # need to get the unit from it. if isinstance(value, UnitBase): value_unit = value else: @@ -128,11 +215,14 @@ def check_equivalent(value): try: value.to(self.equivalent_unit) except UnitConversionError: + # Raise a ValueError for pydantic to catch it as a validation error. raise ValueError( f"Unit {value_unit} is not equivalent to {self.equivalent_unit}" ) from None return value + # Calling handler(source_type) will pass the result of this annotation + # to the next annotation in the chain. return core_schema.no_info_after_validator_function( check_equivalent, handler(source_type) ) @@ -140,12 +230,57 @@ def check_equivalent(value): @dataclass class WithPhysicalType: + """ + This class is a pydantic "marker" (their word for this kind of thing) that + can be used to annotate fields that should be of a specific physical type. + + Parameters + ---------- + physical_type : str or `astropy.units.physical.PhysicalType` + The physical type of the annotated field. + + Examples + -------- + >>> from typing import Annotated + >>> from pydantic import BaseModel, ValidationError + >>> from stellarphot.settings.astropy_pydantic import ( + ... WithPhysicalType, + ... UnitType, + ... QuantityType + ... ) + >>> class UnitModel(BaseModel): + ... length: Annotated[UnitType, WithPhysicalType("length")] + >>> UnitModel(length="meter") + UnitModel(length=Unit("m")) + >>> try: + ... UnitModel(length="second") + ... except ValidationError as err: + ... print(err) + 1 validation error for UnitModel + length + Value error, Unit of s is not equivalent to length... + >>> # Next let's do a Quantity + >>> class QuantityModel(BaseModel): + ... velocity: Annotated[QuantityType, WithPhysicalType("speed")] + >>> QuantityModel(velocity="3 meter / second") + QuantityModel(velocity=) + >>> try: + ... QuantityModel(velocity="3 meter") + ... except ValidationError as err: + ... print(err) + 1 validation error for QuantityModel + velocity + Value error, Unit of 3.0 m is not equivalent to speed... + """ + physical_type: str | PhysicalType def __post_init__(self): try: get_physical_type(self.physical_type) except ValueError as err: + # Add a link to the astropy documentation for physical types + # to the error message. raise ValueError( str(err) + f"\nSee {_PHYSICAL_TYPES_URL} for a list of valid physical types." @@ -163,6 +298,8 @@ def check_physical_type(value): f"Unit of {value} is not equivalent to {self.physical_type}" ) from None + # As in the EquivalentTo annotation, calling handler(source_type) will pass + # the result of this annotation to the next annotation in the chain. return core_schema.no_info_after_validator_function( check_physical_type, handler(source_type) ) From 176420b2ca102a3e1a6b6111b0719b9f64ec7f09 Mon Sep 17 00:00:00 2001 From: Matt Craig Date: Wed, 31 Jan 2024 09:28:50 -0600 Subject: [PATCH 04/15] Add generic astropy validator This only sort-of works, but can do Time and SkyCoord, which is what we need immediately. --- stellarphot/settings/astropy_pydantic.py | 127 +++++++++++-- stellarphot/settings/models.py | 179 ++++++------------ .../settings/tests/test_astropy_pydantic.py | 44 +++++ 3 files changed, 217 insertions(+), 133 deletions(-) diff --git a/stellarphot/settings/astropy_pydantic.py b/stellarphot/settings/astropy_pydantic.py index 3168a4e2..d02d055d 100644 --- a/stellarphot/settings/astropy_pydantic.py +++ b/stellarphot/settings/astropy_pydantic.py @@ -14,7 +14,13 @@ ) from pydantic_core import core_schema -__all__ = ["UnitType", "QuantityType", "EquivalentTo", "WithPhysicalType"] +__all__ = [ + "UnitType", + "QuantityType", + "EquivalentTo", + "WithPhysicalType", + "AstropyValidator", +] _PHYSICAL_TYPES_URL = "https://docs.astropy.org/en/stable/units/ref_api.html#module-astropy.units.physical" @@ -26,6 +32,39 @@ # https://docs.pydantic.dev/latest/concepts/types/#handling-third-party-types +# This function will end up creating the instance from the input value. +def validate_by_instantiating(source_type): + """ + Return a function that tries to create an instance of source_type from a value. + The intended use of this is as a vallidotr in pydantic. + + Parameters + ---------- + + source_type : Any + The type to create an instance of. + + Returns + ------- + function + A function that tries to create an instance of source_type from a value. + """ + + def _validator(value): + # If the value is valid we will be able to create an instance of the + # source_type from it. For example, if source_type is astropy.units.Unit, + # then we should be able to create a Unit from the value. + try: + result = source_type(value) + except TypeError as err: + # Need to raise a ValueError for pydantic to catch it as a + # validation error. + raise ValueError(str(err)) from err + return result + + return _validator + + class _UnitQuantTypePydanticAnnotation: """ This class is used to annotate fields where validation consists of checking @@ -57,20 +96,6 @@ def __get_pydantic_core_schema__( * Nothing else will pass validation * Serialization will always return just a string """ - - # This function will end up creating the instance from the input value. - def validate_by_instantiating(value): - # If the value is valid we will be able to create an instance of the - # source_type from it. For example, if source_type is astropy.units.Unit, - # then we should be able to create a Unit from the value. - try: - result = source_type(value) - except TypeError as err: - # Need to raise a ValueError for pydantic to catch it as a - # validation error. - raise ValueError(str(err)) from err - return result - # Both Unit and Quantity can be created from a string or a float or an # instance of the same type. So we need to check for all of those cases. @@ -88,7 +113,9 @@ def validate_by_instantiating(value): from_str_schema = core_schema.chain_schema( [ core_schema.str_schema(), - core_schema.no_info_plain_validator_function(validate_by_instantiating), + core_schema.no_info_plain_validator_function( + validate_by_instantiating(source_type) + ), ] ) @@ -101,7 +128,9 @@ def validate_by_instantiating(value): from_float_schema = core_schema.chain_schema( [ core_schema.float_schema(), - core_schema.no_info_plain_validator_function(validate_by_instantiating), + core_schema.no_info_plain_validator_function( + validate_by_instantiating(source_type) + ), ] ) @@ -319,3 +348,67 @@ def check_physical_type(value): # This is what is used to represent astropy # objects in Tables and FITS files. So we can use this to create a json schema # for the astropy types. +def serialize_astropy_type(value): + """ + Two things might happen here: + + 1. value serializes to JSON because each value in the dict reperesentation + is a type JSON knows how to represent, or + 2. value does not serialize because one or more of the values in the dict + representation is itself an astropy class. + """ + + def dict_rep(instance): + return instance.info._represent_as_dict() + + if isinstance(value, UnitBase | Quantity): + return str(value) + try: + rep = dict_rep(value) + except AttributeError: + # Either this is not an astropy thing, in which case just return the + # value, or this is an astropy unit. Happily, we can already serialize + # that. + return value if not hasattr(value, "to_string") else value.to_string() + + result = {} + for k, v in rep.items(): + result[k] = serialize_astropy_type(v) + + return result + + +class AstropyValidator: + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type, + _handler, + ): + def astropy_object_from_dict(value): + """ + This is NOT the right way to be doing this when there are nested + definitions, e.g. in a SkyCoord where the RA and Dec are each + an angle, which is not a native python type. + """ + return source_type.info._construct_from_dict(value) + + from_dict_schema = core_schema.chain_schema( + [ + core_schema.dict_schema(), + core_schema.no_info_plain_validator_function(astropy_object_from_dict), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=from_dict_schema, + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(source_type), + from_dict_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + serialize_astropy_type + ), + ) diff --git a/stellarphot/settings/models.py b/stellarphot/settings/models.py index ea4448af..19653df5 100644 --- a/stellarphot/settings/models.py +++ b/stellarphot/settings/models.py @@ -3,17 +3,21 @@ from pathlib import Path from typing import Annotated +from astropy.coordinates import SkyCoord from astropy.io.misc.yaml import AstropyDumper, AstropyLoader +from astropy.time import Time from astropy.units import Quantity, Unit from pydantic import BaseModel, ConfigDict, Field, confloat, conint, model_validator -from .astropy_pydantic import EquivalentTo, QuantityType, UnitType +from .astropy_pydantic import ( + AstropyValidator, + EquivalentTo, + QuantityType, + UnitType, + WithPhysicalType, +) -__all__ = [ - "Camera", - "PhotometryApertures", - "PhotometryFileSettings", -] # "Exoplanet"] +__all__ = ["Camera", "PhotometryApertures", "PhotometryFileSettings", "Exoplanet"] # Most models should use the default configuration, but it can be customized if needed. MODEL_DEFAULT_CONFIGURATION = ConfigDict( @@ -293,113 +297,56 @@ class PhotometryFileSettings(BaseModel): ) -# class TimeType(Time): -# @classmethod -# def __get_validators__(cls): -# yield cls.validate - -# @classmethod -# def validate(cls, v): -# return Time(v) - - -# class SkyCoordType(SkyCoord): -# @classmethod -# def __get_validators__(cls): -# yield cls.validate - -# @classmethod -# def validate(cls, v): -# return SkyCoord(v) - - -# class Exoplanet(BaseModel): -# """ -# Create an object representing an Exoplanet. - -# Parameters -# ---------- - -# epoch : `astropy.time.Time`, optional -# Epoch of the exoplanet. - -# period : `astropy.units.Quantity`, optional -# Period of the exoplanet. - -# Identifier : str -# Identifier of the exoplanet. - -# coordinate : `astropy.coordinates.SkyCoord` -# Coordinates of the exoplanet. - -# depth : float -# Depth of the exoplanet. - -# duration : `astropy.units.Quantity`, optional -# Duration of the exoplanet transit. - -# Examples -# -------- - -# To create an `Exoplanet` object, you can pass in the epoch, -# period, identifier, coordinate, depth, and duration as keyword arguments: - -# >>> from astropy.time import Time -# >>> from astropy.coordinates import SkyCoord -# >>> from astropy import units as u -# >>> planet = Exoplanet(epoch=Time(2455909.29280, format="jd"), -# ... period=1.21749 * u.day, -# ... identifier="KELT-1b", -# ... coordinate=SkyCoord(ra="00:01:26.9169", -# ... dec="+39:23:01.7821", -# ... frame="icrs", -# ... unit=("hour", "degree")), -# ... depth=0.006, -# ... duration=120 * u.min) -# """ -# model_config = MODEL_DEFAULT_CONFIGURATION - -# epoch: TimeType | None = None -# period: QuantityType | None = None -# identifier: str -# coordinate: SkyCoordType -# depth: float | None = None -# duration: QuantityType | None = None - -# # class Config: -# # validate_all = True -# # validate_assignment = True -# # extra = "forbid" -# # json_encoders = { -# # Quantity: lambda v: f"{v.value} {v.unit}", -# # QuantityType: lambda v: f"{v.value} {v.unit}", -# # Time: lambda v: f"{v.value}", -# # } - -# @validator("period") -# @classmethod -# def validate_period(cls, value): -# """ -# Checks that the period has physical units of time and raises an error -# if that is not true. -# """ -# if u.get_physical_type(value) != "time": -# raise ValueError( -# f"Period does not have time units," -# f"currently has {value.unit} units." -# ) -# return value - -# @validator("duration") -# @classmethod -# def validate_duration(cls, value): -# """ -# Checks that the duration has physical units of time and raises an error -# if that is not true. -# """ -# if u.get_physical_type(value) != "time": -# raise ValueError( -# f"Duration does not have time units," -# f"currently has {value.unit} units." -# ) -# return value +class Exoplanet(BaseModel): + """ + Create an object representing an Exoplanet. + + Parameters + ---------- + + epoch : `astropy.time.Time`, optional + Epoch of the exoplanet. + + period : `astropy.units.Quantity`, optional + Period of the exoplanet. + + Identifier : str + Identifier of the exoplanet. + + coordinate : `astropy.coordinates.SkyCoord` + Coordinates of the exoplanet. + + depth : float + Depth of the exoplanet. + + duration : `astropy.units.Quantity`, optional + Duration of the exoplanet transit. + + Examples + -------- + + To create an `Exoplanet` object, you can pass in the epoch, + period, identifier, coordinate, depth, and duration as keyword arguments: + + >>> from astropy.time import Time + >>> from astropy.coordinates import SkyCoord + >>> from astropy import units as u + >>> planet = Exoplanet(epoch=Time(2455909.29280, format="jd"), + ... period=1.21749 * u.day, + ... identifier="KELT-1b", + ... coordinate=SkyCoord(ra="00:01:26.9169", + ... dec="+39:23:01.7821", + ... frame="icrs", + ... unit=("hour", "degree")), + ... depth=0.006, + ... duration=120 * u.min) + """ + + model_config = MODEL_DEFAULT_CONFIGURATION + + epoch: Annotated[Time, AstropyValidator] | None = None + period: Annotated[QuantityType, WithPhysicalType("time")] | None = None + identifier: str + coordinate: Annotated[SkyCoord, AstropyValidator] + depth: float | None = None + duration: Annotated[QuantityType, WithPhysicalType("time")] | None = None diff --git a/stellarphot/settings/tests/test_astropy_pydantic.py b/stellarphot/settings/tests/test_astropy_pydantic.py index d0fd18d8..18d5b546 100644 --- a/stellarphot/settings/tests/test_astropy_pydantic.py +++ b/stellarphot/settings/tests/test_astropy_pydantic.py @@ -1,15 +1,20 @@ import json from typing import Annotated +import numpy as np import pytest +from astropy.coordinates import SkyCoord +from astropy.time import Time from astropy.units import Quantity, Unit, get_physical_type from pydantic import BaseModel, ValidationError from stellarphot.settings.astropy_pydantic import ( + AstropyValidator, EquivalentTo, QuantityType, UnitType, WithPhysicalType, + serialize_astropy_type, ) @@ -89,6 +94,11 @@ def test_equivalent_to(): _ModelEquivalentTo(unit_meter="s", quantity_meter=Quantity("1 m")) +def test_equiv_physical_type_can_be_used_in_union(): + print("https://github.com/pydantic/pydantic/discussions/6412") + assert 0 + + def test_with_physical_type(): # Make sure we can annotate with a physical type model = _ModelWithPhysicalType( @@ -174,3 +184,37 @@ def test_initialize_unit_with_json_invalid(): # should fail. with pytest.raises(ValidationError, match="Input should be a valid string"): _UnitModel.model_validate_json('{"unit": 14.0}') + + +@pytest.mark.parametrize( + "klass,input", + [ + (Time, "2021-01-01T00:00:00"), + (SkyCoord, "00h42m44.3s +41d16m9s"), + ], +) +def test_time_quant_pydantic(klass, input): + class Model(BaseModel): + value: Annotated[klass, AstropyValidator] + + val = klass(input) + model = Model(value=val) + + # Value should be corret + assert model.value == val + + # model dump should fully serialize to standard python types + assert model.model_dump()["value"] == serialize_astropy_type(val) + + # We should be able to create a new model from the dumped json... + # ...but we can't because we apparently aren't serializing right. + model2 = Model.model_validate_json(model.model_dump_json()) + + if klass is SkyCoord: + np.testing.assert_almost_equal( + model2.value.separation(model.value).arcsec, + 0, + decimal=10, + ) + else: + assert model2.value == model.value From 931b6a3151c6a146b7ef8f91fd6a9b8027cc4a83 Mon Sep 17 00:00:00 2001 From: Matt Craig Date: Thu, 1 Feb 2024 10:38:45 -0600 Subject: [PATCH 05/15] Add hash to validators This is necessary for them to be able to used in type unions like Annotated[QuantityType, EquivalentTo("second")] | None --- stellarphot/settings/astropy_pydantic.py | 6 ++++++ .../settings/tests/test_astropy_pydantic.py | 21 ++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/stellarphot/settings/astropy_pydantic.py b/stellarphot/settings/astropy_pydantic.py index d02d055d..94d91a4b 100644 --- a/stellarphot/settings/astropy_pydantic.py +++ b/stellarphot/settings/astropy_pydantic.py @@ -256,6 +256,9 @@ def check_equivalent(value): check_equivalent, handler(source_type) ) + def __hash__(self): + return hash(self.equivalent_unit) + @dataclass class WithPhysicalType: @@ -333,6 +336,9 @@ def check_physical_type(value): check_physical_type, handler(source_type) ) + def __hash__(self): + return hash(self.physical_type) + # We have lost default titles and exmples, but that is maybe not so bad diff --git a/stellarphot/settings/tests/test_astropy_pydantic.py b/stellarphot/settings/tests/test_astropy_pydantic.py index 18d5b546..2a0be35d 100644 --- a/stellarphot/settings/tests/test_astropy_pydantic.py +++ b/stellarphot/settings/tests/test_astropy_pydantic.py @@ -94,9 +94,24 @@ def test_equivalent_to(): _ModelEquivalentTo(unit_meter="s", quantity_meter=Quantity("1 m")) -def test_equiv_physical_type_can_be_used_in_union(): - print("https://github.com/pydantic/pydantic/discussions/6412") - assert 0 +def test_equivalent_to_can_be_used_in_union(): + class ModelWithUnion(BaseModel): + may_be_none: Annotated[QuantityType, EquivalentTo("second")] | None + + model = ModelWithUnion(may_be_none=None) + assert model.may_be_none is None + model = ModelWithUnion(may_be_none=Quantity(1, "s")) + assert model.may_be_none == Quantity(1, "s") + + +def test_physical_type_can_be_used_in_union(): + class ModelWithUnion(BaseModel): + may_be_none: Annotated[QuantityType, WithPhysicalType("time")] | None + + model = ModelWithUnion(may_be_none=None) + assert model.may_be_none is None + model = ModelWithUnion(may_be_none=Quantity(1, "s")) + assert model.may_be_none == Quantity(1, "s") def test_with_physical_type(): From fd580d8e080ae88235baa952b42c4e718455d31d Mon Sep 17 00:00:00 2001 From: Matt Craig Date: Thu, 1 Feb 2024 10:38:57 -0600 Subject: [PATCH 06/15] Turn model tests back on --- stellarphot/settings/tests/test_models.py | 40 +++++++++++------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/stellarphot/settings/tests/test_models.py b/stellarphot/settings/tests/test_models.py index d1add1de..ec6a137b 100644 --- a/stellarphot/settings/tests/test_models.py +++ b/stellarphot/settings/tests/test_models.py @@ -8,7 +8,7 @@ from pydantic import ValidationError from stellarphot.settings import ui_generator -from stellarphot.settings.models import Camera, PhotometryApertures +from stellarphot.settings.models import Camera, Exoplanet, PhotometryApertures DEFAULT_APERTURE_SETTINGS = dict(radius=5, gap=10, annulus_width=15, fwhm=3.2) @@ -245,22 +245,22 @@ def test_create_invalid_values(bad_one): PhotometryApertures(**bad_settings) -# def test_create_exoplanet_correctly(): -# planet = Exoplanet(**DEFAULT_EXOPLANET_SETTINGS) -# print(planet) -# assert planet.epoch == DEFAULT_EXOPLANET_SETTINGS["epoch"] -# assert u.get_physical_type(planet.period) == "time" -# assert planet.identifier == DEFAULT_EXOPLANET_SETTINGS["identifier"] -# assert planet.coordinate == DEFAULT_EXOPLANET_SETTINGS["coordinate"] -# assert planet.depth == DEFAULT_EXOPLANET_SETTINGS["depth"] -# assert u.get_physical_type(planet.duration) == "time" - - -# def test_create_invalid_exoplanet(): -# values = DEFAULT_EXOPLANET_SETTINGS.copy() -# # Make pediod and duration have invalid units for a time -# values["period"] = values["period"].value * u.m -# values["duration"] = values["duration"].value * u.m -# # Check that individual values that are bad raise an error -# with pytest.raises(ValidationError, match="2 validation errors"): -# Exoplanet(**values) +def test_create_exoplanet_correctly(): + planet = Exoplanet(**DEFAULT_EXOPLANET_SETTINGS) + print(planet) + assert planet.epoch == DEFAULT_EXOPLANET_SETTINGS["epoch"] + assert u.get_physical_type(planet.period) == "time" + assert planet.identifier == DEFAULT_EXOPLANET_SETTINGS["identifier"] + assert planet.coordinate == DEFAULT_EXOPLANET_SETTINGS["coordinate"] + assert planet.depth == DEFAULT_EXOPLANET_SETTINGS["depth"] + assert u.get_physical_type(planet.duration) == "time" + + +def test_create_invalid_exoplanet(): + values = DEFAULT_EXOPLANET_SETTINGS.copy() + # Make pediod and duration have invalid units for a time + values["period"] = values["period"].value * u.m + values["duration"] = values["duration"].value * u.m + # Check that individual values that are bad raise an error + with pytest.raises(ValidationError, match="2 validation errors"): + Exoplanet(**values) From 16e0fc7fc5fd9e393a706ec0c5a31fa72211114c Mon Sep 17 00:00:00 2001 From: Matt Craig Date: Thu, 1 Feb 2024 11:16:47 -0600 Subject: [PATCH 07/15] Improve documentation of some new code --- stellarphot/settings/astropy_pydantic.py | 61 +++++++++++++++++++++--- 1 file changed, 55 insertions(+), 6 deletions(-) diff --git a/stellarphot/settings/astropy_pydantic.py b/stellarphot/settings/astropy_pydantic.py index 94d91a4b..e35165a1 100644 --- a/stellarphot/settings/astropy_pydantic.py +++ b/stellarphot/settings/astropy_pydantic.py @@ -356,19 +356,38 @@ def __hash__(self): # for the astropy types. def serialize_astropy_type(value): """ - Two things might happen here: + Serialize astropy objects like Time and SkyCoord to a dictionary. - 1. value serializes to JSON because each value in the dict reperesentation - is a type JSON knows how to represent, or - 2. value does not serialize because one or more of the values in the dict - representation is itself an astropy class. + In principle, we ought to be able to use the astropy serialization stuff + that is used in writing tables to ecsv here, but that is not quite working + yet. + + Parameters + ---------- + value : Any + The value to serialize. + + Returns + ------- + dict + A dictionary representation of the astropy object. """ def dict_rep(instance): + """ + This is in a function so it only shows up in one place in the event the + private astropy API changes. + """ return instance.info._represent_as_dict() + # The if statement below is a bit of a hack. It's not clear to me how to + # use the _represent_as_dict stuff to serialize something like a SkyCoord that + # has nested astropy objects in it. So for now, we just return the string + # representation of the objects, like Angle (a type of Quantity), that are + # entries in the dict representation of a SKyCoord. if isinstance(value, UnitBase | Quantity): return str(value) + try: rep = dict_rep(value) except AttributeError: @@ -378,6 +397,8 @@ def dict_rep(instance): return value if not hasattr(value, "to_string") else value.to_string() result = {} + + # Recurse to handle nested astropy objects for k, v in rep.items(): result[k] = serialize_astropy_type(v) @@ -385,6 +406,32 @@ def dict_rep(instance): class AstropyValidator: + """ + This class is a pydantic "marker" (their word for this kind of thing) that + can be used to annotate fields that should be of an astropy type that can be + serialized. + + Examples + -------- + >>> from typing import Annotated + >>> from pydantic import BaseModel + >>> from astropy.time import Time + >>> from astropy.coordinates import SkyCoord + >>> from stellarphot.settings.astropy_pydantic import AstropyValidator + >>> # Making a model with a Time field + >>> class TimeModel(BaseModel): + ... time: Annotated[Time, AstropyValidator] + >>> # The time must be either a Time object or a dictionary that can be used to + >>> TimeModel(time=Time("2021-01-01T00:00:00")) + TimeModel(time=