Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for pydantic v2 #619

Merged
merged 9 commits into from
Jul 7, 2023
4 changes: 2 additions & 2 deletions ci/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- myst-nb
- netcdf4!=1.6.1
- pip
- pydantic>=1.9
- pydantic>=2.0
- python-graphviz
- python=3.11
- s3fs >=2023.05
Expand All @@ -27,6 +27,6 @@ dependencies:
- zarr>=2.12
- furo>=2022.09.15
- pip:
- git+https://github.com/ncar-xdev/ecgtools
- sphinxext-opengraph
- autodoc_pydantic
- -e ..
2 changes: 1 addition & 1 deletion ci/environment-upstream-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies:
- pooch
- pre-commit
- psutil
- pydantic>=1.9
- pydantic>=2.0
- pydap
- pyproj
- pytest
Expand Down
2 changes: 1 addition & 1 deletion ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- pip
- pooch
- pre-commit
- pydantic>=1.9
- pydantic>=2.0
- pydap
- pytest
- pytest-cov
Expand Down
4 changes: 0 additions & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
'myst_nb',
'sphinxext.opengraph',
'sphinx_copybutton',
'sphinxcontrib.autodoc_pydantic',
'sphinx_design',
]

Expand All @@ -29,9 +28,6 @@
copybutton_prompt_text = r'>>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: '
copybutton_prompt_is_regexp = True

autodoc_pydantic_model_show_json = True
autodoc_pydantic_model_show_config = False

nb_execution_mode = 'cache'
nb_execution_timeout = 600
nb_execution_raise_on_error = True
Expand Down
10 changes: 8 additions & 2 deletions docs/source/reference/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,19 @@ For more details and examples, refer to the relevant chapters in the main part o
## ESM Catalog

```{eval-rst}
.. autopydantic_model:: intake_esm.cat.ESMCatalogModel
.. autoclass:: intake_esm.cat.ESMCatalogModel
:members:
:noindex:
:special-members: __init__
```

## Query Model

```{eval-rst}
.. autopydantic_model:: intake_esm.cat.QueryModel
.. autoclass:: intake_esm.cat.QueryModel
:members:
:noindex:
:special-members: __init__
```

## Derived Variable Registry
Expand Down
95 changes: 41 additions & 54 deletions intake_esm/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
import pydantic
import tlz
from pydantic import ConfigDict

from ._search import search, search_apply_require_all_on

Expand Down Expand Up @@ -40,9 +41,7 @@ class AggregationType(str, enum.Enum):
join_existing = 'join_existing'
union = 'union'

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class DataFormat(str, enum.Enum):
Expand All @@ -51,57 +50,47 @@ class DataFormat(str, enum.Enum):
reference = 'reference'
opendap = 'opendap'

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class Attribute(pydantic.BaseModel):
column_name: pydantic.StrictStr
vocabulary: pydantic.StrictStr = ''

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class Assets(pydantic.BaseModel):
column_name: pydantic.StrictStr
format: typing.Optional[DataFormat]
format_column_name: typing.Optional[pydantic.StrictStr]
format: typing.Optional[DataFormat] = None
format_column_name: typing.Optional[pydantic.StrictStr] = None

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)

@pydantic.root_validator
def _validate_data_format(cls, values):
data_format, format_column_name = values.get('format'), values.get('format_column_name')
@pydantic.model_validator(mode='after')
def _validate_data_format(cls, model):
data_format, format_column_name = model.format, model.format_column_name
if data_format is not None and format_column_name is not None:
raise ValueError('Cannot set both format and format_column_name')
elif data_format is None and format_column_name is None:
raise ValueError('Must set one of format or format_column_name')
return values
return model


class Aggregation(pydantic.BaseModel):
type: AggregationType
attribute_name: pydantic.StrictStr
options: typing.Optional[dict] = {}
options: dict = {}

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class AggregationControl(pydantic.BaseModel):
variable_column_name: pydantic.StrictStr
groupby_attrs: list[pydantic.StrictStr]
aggregations: list[Aggregation] = []

class Config:
validate_all = True
validate_assignment = True
model_config = ConfigDict(validate_default=True, validate_assignment=True)


class ESMCatalogModel(pydantic.BaseModel):
Expand All @@ -113,35 +102,33 @@ class ESMCatalogModel(pydantic.BaseModel):
attributes: list[Attribute]
assets: Assets
aggregation_control: typing.Optional[AggregationControl] = None
id: typing.Optional[str] = ''
id: str = ''
catalog_dict: typing.Optional[list[dict]] = None
catalog_file: pydantic.StrictStr = None
description: pydantic.StrictStr = None
title: pydantic.StrictStr = None
catalog_file: typing.Optional[pydantic.StrictStr] = None
description: typing.Optional[pydantic.StrictStr] = None
title: typing.Optional[pydantic.StrictStr] = None
last_updated: typing.Optional[typing.Union[datetime.datetime, datetime.date]] = None
_df: typing.Optional[pd.DataFrame] = pydantic.PrivateAttr()
_df: pd.DataFrame = pydantic.PrivateAttr()

class Config:
arbitrary_types_allowed = True
underscore_attrs_are_private = True
validate_all = True
validate_assignment = True
model_config = ConfigDict(
arbitrary_types_allowed=True, validate_default=True, validate_assignment=True
)

@pydantic.root_validator
def validate_catalog(cls, values):
catalog_dict, catalog_file = values.get('catalog_dict'), values.get('catalog_file')
@pydantic.model_validator(mode='after')
def validate_catalog(cls, model):
catalog_dict, catalog_file = model.catalog_dict, model.catalog_file
if catalog_dict is not None and catalog_file is not None:
raise ValueError('catalog_dict and catalog_file cannot be set at the same time')

return values
return model

@classmethod
def from_dict(cls, data: dict) -> 'ESMCatalogModel':
esmcat = data['esmcat']
df = data['df']
if 'last_updated' not in esmcat:
esmcat['last_updated'] = None
cat = cls.parse_obj(esmcat)
cat = cls.model_validate(esmcat)
cat._df = df
return cat

Expand Down Expand Up @@ -254,7 +241,7 @@ def load(
data = json.loads(fobj.read())
if 'last_updated' not in data:
data['last_updated'] = None
cat = cls.parse_obj(data)
cat = cls.model_validate(data)
if cat.catalog_file:
if _mapper.fs.exists(cat.catalog_file):
csv_path = cat.catalog_file
Expand Down Expand Up @@ -417,32 +404,32 @@ class QueryModel(pydantic.BaseModel):

query: dict[pydantic.StrictStr, typing.Union[typing.Any, list[typing.Any]]]
columns: list[str]
require_all_on: typing.Union[str, list[typing.Any]] = None
require_all_on: typing.Optional[typing.Union[str, list[typing.Any]]] = None

class Config:
validate_all = True
validate_assignment = True
# TODO: Seem to be unable to modify fields in model_validator with
# validate_assignment=True since it leads to recursion
model_config = ConfigDict(validate_default=True, validate_assignment=False)

@pydantic.root_validator(pre=False)
def validate_query(cls, values):
query = values.get('query', {})
columns = values.get('columns')
require_all_on = values.get('require_all_on', [])
@pydantic.model_validator(mode='after')
def validate_query(cls, model):
query = model.query
columns = model.columns
require_all_on = model.require_all_on

if query:
for key in query:
if key not in columns:
raise ValueError(f'Column {key} not in columns {columns}')
if isinstance(require_all_on, str):
values['require_all_on'] = [require_all_on]
model.require_all_on = [require_all_on]
if require_all_on is not None:
for key in values['require_all_on']:
for key in model.require_all_on:
if key not in columns:
raise ValueError(f'Column {key} not in columns {columns}')
_query = query.copy()
for key, value in _query.items():
if isinstance(value, (str, int, float, bool)) or value is None or value is pd.NA:
_query[key] = [value]

values['query'] = _query
return values
model.query = _query
return model
38 changes: 21 additions & 17 deletions intake_esm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,11 @@ def _ipython_key_completions_(self):
return self.__dir__()

@pydantic.validate_arguments
def search(self, require_all_on: typing.Union[str, list[str]] = None, **query: typing.Any):
def search(
self,
require_all_on: typing.Optional[typing.Union[str, list[str]]] = None,
**query: typing.Any,
):
"""Search for entries in the catalog.

Parameters
Expand Down Expand Up @@ -443,11 +447,11 @@ def search(self, require_all_on: typing.Union[str, list[str]] = None, **query: t
def serialize(
self,
name: pydantic.StrictStr,
directory: typing.Union[pydantic.DirectoryPath, pydantic.StrictStr] = None,
directory: typing.Optional[typing.Union[pydantic.DirectoryPath, pydantic.StrictStr]] = None,
catalog_type: str = 'dict',
to_csv_kwargs: dict[typing.Any, typing.Any] = None,
json_dump_kwargs: dict[typing.Any, typing.Any] = None,
storage_options: dict[str, typing.Any] = None,
to_csv_kwargs: typing.Optional[dict[typing.Any, typing.Any]] = None,
json_dump_kwargs: typing.Optional[dict[typing.Any, typing.Any]] = None,
storage_options: typing.Optional[dict[str, typing.Any]] = None,
) -> None:
"""Serialize catalog to corresponding json and csv files.

Expand Down Expand Up @@ -536,12 +540,12 @@ def unique(self) -> pd.Series:
@pydantic.validate_arguments
def to_dataset_dict(
self,
xarray_open_kwargs: dict[str, typing.Any] = None,
xarray_combine_by_coords_kwargs: dict[str, typing.Any] = None,
preprocess: typing.Callable = None,
storage_options: dict[pydantic.StrictStr, typing.Any] = None,
progressbar: pydantic.StrictBool = None,
aggregate: pydantic.StrictBool = None,
xarray_open_kwargs: typing.Optional[dict[str, typing.Any]] = None,
xarray_combine_by_coords_kwargs: typing.Optional[dict[str, typing.Any]] = None,
preprocess: typing.Optional[typing.Callable] = None,
storage_options: typing.Optional[dict[pydantic.StrictStr, typing.Any]] = None,
progressbar: typing.Optional[pydantic.StrictBool] = None,
aggregate: typing.Optional[pydantic.StrictBool] = None,
skip_on_error: pydantic.StrictBool = False,
**kwargs,
) -> dict[str, xr.Dataset]:
Expand Down Expand Up @@ -686,12 +690,12 @@ def to_dataset_dict(
@pydantic.validate_arguments
def to_datatree(
self,
xarray_open_kwargs: dict[str, typing.Any] = None,
xarray_combine_by_coords_kwargs: dict[str, typing.Any] = None,
preprocess: typing.Callable = None,
storage_options: dict[pydantic.StrictStr, typing.Any] = None,
progressbar: pydantic.StrictBool = None,
aggregate: pydantic.StrictBool = None,
xarray_open_kwargs: typing.Optional[dict[str, typing.Any]] = None,
xarray_combine_by_coords_kwargs: typing.Optional[dict[str, typing.Any]] = None,
preprocess: typing.Optional[typing.Callable] = None,
storage_options: typing.Optional[dict[pydantic.StrictStr, typing.Any]] = None,
progressbar: typing.Optional[pydantic.StrictBool] = None,
aggregate: typing.Optional[pydantic.StrictBool] = None,
skip_on_error: pydantic.StrictBool = False,
**kwargs,
):
Expand Down
4 changes: 2 additions & 2 deletions intake_esm/derived.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class DerivedVariable(pydantic.BaseModel):
query: dict[pydantic.StrictStr, typing.Union[typing.Any, list[typing.Any]]]
prefer_derived: bool

@pydantic.validator('query')
@pydantic.field_validator('query')
def validate_query(cls, values):
_query = values.copy()
for key, value in _query.items():
Expand Down Expand Up @@ -46,7 +46,7 @@ def __call__(self, *args, variable_key_name: str = None, **kwargs) -> xr.Dataset
class DerivedVariableRegistry:
"""Registry of derived variables"""

def __post_init_post_parse__(self):
def __post_init__(self):
self._registry = {}

@classmethod
Expand Down
12 changes: 6 additions & 6 deletions intake_esm/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ def __init__(
*,
variable_column_name: typing.Optional[pydantic.StrictStr] = None,
aggregations: typing.Optional[list[Aggregation]] = None,
requested_variables: list[str] = None,
preprocess: typing.Callable = None,
storage_options: dict[str, typing.Any] = None,
xarray_open_kwargs: dict[str, typing.Any] = None,
xarray_combine_by_coords_kwargs: dict[str, typing.Any] = None,
intake_kwargs: dict[str, typing.Any] = None,
requested_variables: typing.Optional[list[str]] = None,
preprocess: typing.Optional[typing.Callable] = None,
storage_options: typing.Optional[dict[str, typing.Any]] = None,
xarray_open_kwargs: typing.Optional[dict[str, typing.Any]] = None,
xarray_combine_by_coords_kwargs: typing.Optional[dict[str, typing.Any]] = None,
intake_kwargs: typing.Optional[dict[str, typing.Any]] = None,
):
"""An intake compatible Data Source for ESM data.

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ netCDF4>=1.5.5
requests>=2.24.0
xarray>=2022.06
zarr>=2.12
pydantic>=1.9
pydantic>=2.0