Skip to content

Commit

Permalink
fix: Handle forward refs in pydantic.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Apr 30, 2024
1 parent e1f9db4 commit ad5a6c6
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dataclass-settings"
version = "0.3.0"
version = "0.3.1"
description = "Declarative dataclass settings."

repository = "https://github.com/dancardin/dataclass-settings"
Expand Down
78 changes: 48 additions & 30 deletions src/dataclass_settings/class_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
from enum import Enum
from typing import Any, Callable, Type
from typing import Any, Callable, Tuple, Type

import typing_inspect
from typing_extensions import Annotated, Self, get_args, get_origin, get_type_hints
Expand Down Expand Up @@ -71,82 +71,88 @@ class Field:
mapper: Callable[..., Any] | None = None

@classmethod
def from_dataclass(cls, typ: Type) -> list[Self]:
def from_dataclass(cls, typ: Type, type_hints: dict[str, Type]) -> list[Self]:
fields = []
for f in typ.__dataclass_fields__.values():
type_ = get_origin(f.type) or f.type
args = get_args(f.type) or ()
if type_ is Annotated:
type_, *_args = args
args = tuple(_args)
annotation = get_type(type_hints[f.name])

annotation, args = get_annotation_args(annotation)

field = cls(
name=f.name,
type=type_,
type=annotation,
annotations=args,
mapper=type_,
mapper=annotation,
)
fields.append(field)
return fields

@classmethod
def from_pydantic(cls, typ: Type) -> list[Self]:
def from_pydantic(cls, typ: Type, type_hints: dict[str, Type]) -> list[Self]:
fields = []
for name, f in typ.model_fields.items():
annotation_type = get_type(f.annotation)
mapper = annotation_type if detect(annotation_type) else None
annotation = get_type(type_hints[name])
mapper = annotation if detect(annotation) else None

field = cls(
name=name,
type=f.annotation,
type=annotation,
annotations=tuple(f.metadata),
mapper=mapper,
)
fields.append(field)
return fields

@classmethod
def from_pydantic_v1(cls, typ: Type) -> list[Self]:
def from_pydantic_v1(cls, typ: Type, type_hints: dict[str, Type]) -> list[Self]:
fields = []
type_hints = get_type_hints(typ, include_extras=True)
for name, f in typ.__fields__.items():
annotation = get_type(type_hints[name])
annotation, args = get_annotation_args(annotation)

mapper = annotation if detect(annotation) else None

field = cls(
name=name,
type=f.annotation,
annotations=get_args(annotation) or (),
type=annotation,
annotations=args,
mapper=mapper,
)
fields.append(field)
return fields

@classmethod
def from_pydantic_dataclass(cls, typ: Type) -> list[Self]:
def from_pydantic_dataclass(
cls, typ: Type, type_hints: dict[str, Type]
) -> list[Self]:
fields = []

for name, f in typ.__pydantic_fields__.items():
annotation_type = get_type(f.annotation)
mapper = annotation_type if detect(annotation_type) else None
annotation = get_type(type_hints[name])
mapper = annotation if detect(annotation) else None

field = cls(
name=name,
type=f.annotation,
type=annotation,
annotations=tuple(f.metadata),
mapper=mapper,
)
fields.append(field)
return fields

@classmethod
def from_attrs(cls, typ: Type) -> list[Self]:
def from_attrs(cls, typ: Type, type_hints: dict[str, Type]) -> list[Self]:
fields = []

for f in typ.__attrs_attrs__:
annotation = get_type(type_hints[f.name])
annotation, args = get_annotation_args(annotation)

field = cls(
name=f.name,
type=get_origin(f.type) or f.type,
annotations=get_args(f.type) or (),
mapper=f.type,
type=annotation,
annotations=args,
mapper=annotation,
)
fields.append(field)
return fields
Expand Down Expand Up @@ -189,20 +195,22 @@ def map_value(self, value: str | dict[str, Any]):

def fields(cls: type):
class_type = ClassTypes.from_cls(cls)

type_hints = get_type_hints(cls, include_extras=True)
if class_type == ClassTypes.dataclass:
return Field.from_dataclass(cls)
return Field.from_dataclass(cls, type_hints)

if class_type == ClassTypes.pydantic:
return Field.from_pydantic(cls)
return Field.from_pydantic(cls, type_hints)

if class_type == ClassTypes.pydantic_v1:
return Field.from_pydantic_v1(cls)
return Field.from_pydantic_v1(cls, type_hints)

if class_type == ClassTypes.pydantic_dataclass:
return Field.from_pydantic_dataclass(cls)
return Field.from_pydantic_dataclass(cls, type_hints)

if class_type == ClassTypes.attrs:
return Field.from_attrs(cls)
return Field.from_attrs(cls, type_hints)

raise NotImplementedError() # pragma: no cover

Expand All @@ -211,3 +219,13 @@ def get_type(typ):
if typing_inspect.is_optional_type(typ):
return get_args(typ)[0]
return typ


def get_annotation_args(annotation) -> Tuple[Type, Tuple[Any, ...]]:
args: Tuple[Any, ...] = ()
if get_origin(annotation) is Annotated:
args = get_args(annotation)
annotation, *_args = args
args = tuple(_args)

return annotation, args
66 changes: 66 additions & 0 deletions tests/test_forward_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from dataclasses import dataclass

from attr import dataclass as attr_dataclass
from dataclass_settings import load_settings
from pydantic import BaseModel
from pydantic.dataclasses import dataclass as pydantic_dataclass


class PydanticConfig(BaseModel):
foo: PydanticFoo


class PydanticFoo(BaseModel):
foo: int = 0


def test_pydantic():
config = load_settings(PydanticConfig)
assert config == PydanticConfig(foo=PydanticFoo(foo=0))


@dataclass
class DataclassConfig:
foo: DataclassFoo


@dataclass
class DataclassFoo:
foo: int = 0


def test_dataclass():
config = load_settings(DataclassConfig)
assert config == DataclassConfig(foo=DataclassFoo(foo=0))


@pydantic_dataclass
class PDataclassConfig:
foo: PDataclassFoo


@pydantic_dataclass
class PDataclassFoo:
foo: int = 0


def test_pydantic_dataclass():
config = load_settings(PDataclassConfig)
assert config == PDataclassConfig(foo=PDataclassFoo(foo=0))


@attr_dataclass
class AttrConfig:
foo: AttrFoo


@attr_dataclass
class AttrFoo:
foo: int = 0


def test_attr_dataclass():
config = load_settings(AttrConfig)
assert config == AttrConfig(foo=AttrFoo(foo=0))

0 comments on commit ad5a6c6

Please sign in to comment.