Skip to content

Commit

Permalink
Determine RootModel complexity from root type (#344)
Browse files Browse the repository at this point in the history
Co-authored-by: Hasan Ramezani <[email protected]>
  • Loading branch information
user1584 and hramezani authored Aug 13, 2024
1 parent 5c3a817 commit 8fb9abb
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
12 changes: 11 additions & 1 deletion pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

import typing_extensions
from dotenv import dotenv_values
from pydantic import AliasChoices, AliasPath, BaseModel, Json
from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel
from pydantic._internal._repr import Representation
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
Expand Down Expand Up @@ -1904,6 +1904,16 @@ def read_env_file(


def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
# If the model is a root model, the root annotation should be used to
# evaluate the complexity.
if isinstance(annotation, type) and issubclass(annotation, RootModel):
# In some rare cases (see test_root_model_as_field),
# the root attribute is not available. For these cases, python 3.8 and 3.9
# return 'RootModelRootType'.
root_annotation = annotation.__annotations__.get('root', None)
if root_annotation is not None and root_annotation != 'RootModelRootType':
annotation = root_annotation

if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
return False
# Check if annotation is of the form Annotated[type, metadata].
Expand Down
38 changes: 38 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import json
import os
import pathlib
import re
import sys
import typing
Expand Down Expand Up @@ -2185,6 +2186,43 @@ class Settings(BaseSettings):
assert s.model_dump() == {'z': [{'x': 1, 'y': {'foo': 1}}, {'x': 2, 'y': {'foo': 2}}]}


def test_str_based_root_model(env):
"""Testing to pass string directly to root model."""

class Foo(RootModel[str]):
root: str

class Settings(BaseSettings):
foo: Foo
plain: str

TEST_STR = 'hello world'
env.set('foo', TEST_STR)
env.set('plain', TEST_STR)
s = Settings()
assert s.model_dump() == {'foo': TEST_STR, 'plain': TEST_STR}


def test_path_based_root_model(env):
"""Testing to pass path directly to root model."""

class Foo(RootModel[pathlib.PurePosixPath]):
root: pathlib.PurePosixPath

class Settings(BaseSettings):
foo: Foo
plain: pathlib.PurePosixPath

TEST_PATH: str = '/hello/world'
env.set('foo', TEST_PATH)
env.set('plain', TEST_PATH)
s = Settings()
assert s.model_dump() == {
'foo': pathlib.PurePosixPath(TEST_PATH),
'plain': pathlib.PurePosixPath(TEST_PATH),
}


def test_optional_field_from_env(env):
class Settings(BaseSettings):
x: Optional[str] = None
Expand Down

0 comments on commit 8fb9abb

Please sign in to comment.