Skip to content

Commit

Permalink
chore(internal): add support for TypeAliasType (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] committed Dec 13, 2024
1 parent d3a4d21 commit 012f185
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
dependencies = [
"httpx>=0.23.0, <1",
"pydantic>=1.9.0, <3",
"typing-extensions>=4.7, <5",
"typing-extensions>=4.10, <5",
"anyio>=3.5.0, <5",
"distro>=1.7.0, <2",
"sniffio",
Expand Down
3 changes: 3 additions & 0 deletions src/saturn_sdk/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
strip_not_given,
extract_type_arg,
is_annotated_type,
is_type_alias_type,
strip_annotated_type,
)
from ._compat import (
Expand Down Expand Up @@ -428,6 +429,8 @@ def construct_type(*, value: object, type_: object) -> object:
# we allow `object` as the input type because otherwise, passing things like
# `Literal['value']` will be reported as a type error by type checkers
type_ = cast("type[object]", type_)
if is_type_alias_type(type_):
type_ = type_.__value__ # type: ignore[unreachable]

# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
Expand Down
20 changes: 10 additions & 10 deletions src/saturn_sdk/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pydantic

from ._types import NoneType
from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base
from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
Expand Down Expand Up @@ -126,9 +126,15 @@ def __repr__(self) -> str:
)

def _parse(self, *, to: type[_T] | None = None) -> R | _T:
cast_to = to if to is not None else self._cast_to

# unwrap `TypeAlias('Name', T)` -> `T`
if is_type_alias_type(cast_to):
cast_to = cast_to.__value__ # type: ignore[unreachable]

# unwrap `Annotated[T, ...]` -> `T`
if to and is_annotated_type(to):
to = extract_type_arg(to, 0)
if cast_to and is_annotated_type(cast_to):
cast_to = extract_type_arg(cast_to, 0)

if self._is_sse_stream:
if to:
Expand Down Expand Up @@ -164,18 +170,12 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
return cast(
R,
stream_cls(
cast_to=self._cast_to,
cast_to=cast_to,
response=self.http_response,
client=cast(Any, self._client),
),
)

cast_to = to if to is not None else self._cast_to

# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(cast_to):
cast_to = extract_type_arg(cast_to, 0)

if cast_to is NoneType:
return cast(R, None)

Expand Down
1 change: 1 addition & 0 deletions src/saturn_sdk/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
is_iterable_type as is_iterable_type,
is_required_type as is_required_type,
is_annotated_type as is_annotated_type,
is_type_alias_type as is_type_alias_type,
strip_annotated_type as strip_annotated_type,
extract_type_var_from_base as extract_type_var_from_base,
)
Expand Down
31 changes: 30 additions & 1 deletion src/saturn_sdk/_utils/_typing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from __future__ import annotations

import sys
import typing
import typing_extensions
from typing import Any, TypeVar, Iterable, cast
from collections import abc as _c_abc
from typing_extensions import Required, Annotated, get_args, get_origin
from typing_extensions import (
TypeIs,
Required,
Annotated,
get_args,
get_origin,
)

from .._types import InheritsGeneric
from .._compat import is_union as _is_union
Expand Down Expand Up @@ -36,6 +45,26 @@ def is_typevar(typ: type) -> bool:
return type(typ) == TypeVar # type: ignore


_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
if sys.version_info >= (3, 12):
_TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)


def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
"""Return whether the provided argument is an instance of `TypeAliasType`.
```python
type Int = int
is_type_alias_type(Int)
# > True
Str = TypeAliasType("Str", str)
is_type_alias_type(Str)
# > True
```
"""
return isinstance(tp, _TYPE_ALIAS_TYPES)


# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
Expand Down
18 changes: 17 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Any, Dict, List, Union, Optional, cast
from datetime import datetime, timezone
from typing_extensions import Literal, Annotated
from typing_extensions import Literal, Annotated, TypeAliasType

import pytest
import pydantic
Expand Down Expand Up @@ -828,3 +828,19 @@ class B(BaseModel):
# if the discriminator details object stays the same between invocations then
# we hit the cache
assert UnionType.__discriminator__ is discriminator


@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1")
def test_type_alias_type() -> None:
Alias = TypeAliasType("Alias", str)

class Model(BaseModel):
alias: Alias
union: Union[int, Alias]

m = construct_type(value={"alias": "foo", "union": "bar"}, type_=Model)
assert isinstance(m, Model)
assert isinstance(m.alias, str)
assert m.alias == "foo"
assert isinstance(m.union, str)
assert m.union == "bar"
4 changes: 4 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
is_union_type,
extract_type_arg,
is_annotated_type,
is_type_alias_type,
)
from saturn_sdk._compat import PYDANTIC_V2, field_outer_type, get_model_fields
from saturn_sdk._models import BaseModel
Expand Down Expand Up @@ -51,6 +52,9 @@ def assert_matches_type(
path: list[str],
allow_none: bool = False,
) -> None:
if is_type_alias_type(type_):
type_ = type_.__value__

# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
type_ = extract_type_arg(type_, 0)
Expand Down

0 comments on commit 012f185

Please sign in to comment.