Skip to content

Commit

Permalink
[EG] Regenerate beta (#35014)
Browse files Browse the repository at this point in the history
* update generation

* version

* skip
  • Loading branch information
l0lawrence committed Apr 22, 2024
1 parent 7c8ce19 commit 1060de0
Show file tree
Hide file tree
Showing 9 changed files with 1,635 additions and 243 deletions.
129 changes: 85 additions & 44 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
# license information.
# --------------------------------------------------------------------------
# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
# pyright: reportGeneralTypeIssues=false

import calendar
import decimal
import functools
import sys
import logging
import base64
import re
import copy
import typing
import email
import enum
import email.utils
from datetime import datetime, date, time, timedelta, timezone
from json import JSONEncoder
from typing_extensions import Self
import isodate
from azure.core.exceptions import DeserializationError
from azure.core import CaseInsensitiveEnumMeta
Expand All @@ -34,6 +36,7 @@
__all__ = ["SdkJSONEncoder", "Model", "rest_field", "rest_discriminator"]

TZ_UTC = timezone.utc
_T = typing.TypeVar("_T")


def _timedelta_as_isostr(td: timedelta) -> str:
Expand Down Expand Up @@ -144,6 +147,8 @@ def default(self, o): # pylint: disable=too-many-return-statements
except TypeError:
if isinstance(o, _Null):
return None
if isinstance(o, decimal.Decimal):
return float(o)
if isinstance(o, (bytes, bytearray)):
return _serialize_bytes(o, self.format)
try:
Expand Down Expand Up @@ -239,7 +244,7 @@ def _deserialize_date(attr: typing.Union[str, date]) -> date:
# This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
if isinstance(attr, date):
return attr
return isodate.parse_date(attr, defaultmonth=None, defaultday=None)
return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore


def _deserialize_time(attr: typing.Union[str, time]) -> time:
Expand Down Expand Up @@ -275,6 +280,12 @@ def _deserialize_duration(attr):
return isodate.parse_duration(attr)


def _deserialize_decimal(attr):
if isinstance(attr, decimal.Decimal):
return attr
return decimal.Decimal(str(attr))


_DESERIALIZE_MAPPING = {
datetime: _deserialize_datetime,
date: _deserialize_date,
Expand All @@ -283,6 +294,7 @@ def _deserialize_duration(attr):
bytearray: _deserialize_bytes,
timedelta: _deserialize_duration,
typing.Any: lambda x: x,
decimal.Decimal: _deserialize_decimal,
}

_DESERIALIZE_MAPPING_WITHFORMAT = {
Expand Down Expand Up @@ -373,8 +385,12 @@ def get(self, key: str, default: typing.Any = None) -> typing.Any:
except KeyError:
return default

@typing.overload # type: ignore
def pop(self, key: str) -> typing.Any: # pylint: disable=no-member
@typing.overload
def pop(self, key: str) -> typing.Any:
...

@typing.overload
def pop(self, key: str, default: _T) -> _T:
...

@typing.overload
Expand All @@ -395,8 +411,8 @@ def clear(self) -> None:
def update(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self._data.update(*args, **kwargs)

@typing.overload # type: ignore
def setdefault(self, key: str) -> typing.Any:
@typing.overload
def setdefault(self, key: str, default: None = None) -> None:
...

@typing.overload
Expand Down Expand Up @@ -434,6 +450,10 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m
return tuple(_serialize(x, format) for x in o)
if isinstance(o, (bytes, bytearray)):
return _serialize_bytes(o, format)
if isinstance(o, decimal.Decimal):
return float(o)
if isinstance(o, enum.Enum):
return o.value
try:
# First try datetime.datetime
return _serialize_datetime(o, format)
Expand All @@ -458,7 +478,13 @@ def _get_rest_field(


def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typing.Any:
return _deserialize(rf._type, value) if (rf and rf._is_model) else _serialize(value, rf._format if rf else None)
if not rf:
return _serialize(value, None)
if rf._is_multipart_file_input:
return value
if rf._is_model:
return _deserialize(rf._type, value)
return _serialize(value, rf._format)


class Model(_MyMutableMapping):
Expand Down Expand Up @@ -494,7 +520,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
def copy(self) -> "Model":
return Model(self.__dict__)

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> "Model": # pylint: disable=unused-argument
def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: disable=unused-argument
# we know the last three classes in mro are going to be 'Model', 'dict', and 'object'
mros = cls.__mro__[:-3][::-1] # ignore model, dict, and object parents, and reverse the mro order
attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property
Expand Down Expand Up @@ -536,7 +562,7 @@ def _deserialize(cls, data, exist_discriminators):
exist_discriminators.append(discriminator)
mapped_cls = cls.__mapping__.get(
data.get(discriminator), cls
) # pylint: disable=no-member
) # pyright: ignore # pylint: disable=no-member
if mapped_cls == cls:
return cls(data)
return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access
Expand All @@ -553,20 +579,25 @@ def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.
if exclude_readonly:
readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
for k, v in self.items():
if exclude_readonly and k in readonly_props: # pyright: ignore[reportUnboundVariable]
if exclude_readonly and k in readonly_props: # pyright: ignore
continue
result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly)
is_multipart_file_input = False
try:
is_multipart_file_input = next(rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k)._is_multipart_file_input
except StopIteration:
pass
result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly)
return result

@staticmethod
def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any:
if v is None or isinstance(v, _Null):
return None
if isinstance(v, (list, tuple, set)):
return [
return type(v)(
Model._as_dict_value(x, exclude_readonly=exclude_readonly)
for x in v
]
)
if isinstance(v, dict):
return {
dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly)
Expand Down Expand Up @@ -607,29 +638,22 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj
return obj
return _deserialize(model_deserializer, obj)

return functools.partial(_deserialize_model, annotation)
return functools.partial(_deserialize_model, annotation) # pyright: ignore
except Exception:
pass

# is it a literal?
try:
if sys.version_info >= (3, 8):
from typing import (
Literal,
) # pylint: disable=no-name-in-module, ungrouped-imports
else:
from typing_extensions import Literal # type: ignore # pylint: disable=ungrouped-imports

if annotation.__origin__ == Literal:
if annotation.__origin__ is typing.Literal: # pyright: ignore
return None
except AttributeError:
pass

# is it optional?
try:
if any(a for a in annotation.__args__ if a == type(None)):
if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
if_obj_deserializer = _get_deserialize_callable_from_annotation(
next(a for a in annotation.__args__ if a != type(None)), module, rf
next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
)

def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
Expand All @@ -642,7 +666,13 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla
pass

if getattr(annotation, "__origin__", None) is typing.Union:
deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__]
# initial ordering is we make `string` the last deserialization option, because it is often them most generic
deserializers = [
_get_deserialize_callable_from_annotation(arg, module, rf)
for arg in sorted(
annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore
)
]

def _deserialize_with_union(deserializers, obj):
for deserializer in deserializers:
Expand All @@ -655,32 +685,31 @@ def _deserialize_with_union(deserializers, obj):
return functools.partial(_deserialize_with_union, deserializers)

try:
if annotation._name == "Dict":
key_deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[0], module, rf)
value_deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[1], module, rf)
if annotation._name == "Dict": # pyright: ignore
value_deserializer = _get_deserialize_callable_from_annotation(
annotation.__args__[1], module, rf # pyright: ignore
)

def _deserialize_dict(
key_deserializer: typing.Optional[typing.Callable],
value_deserializer: typing.Optional[typing.Callable],
obj: typing.Dict[typing.Any, typing.Any],
):
if obj is None:
return obj
return {
_deserialize(key_deserializer, k, module): _deserialize(value_deserializer, v, module)
k: _deserialize(value_deserializer, v, module)
for k, v in obj.items()
}

return functools.partial(
_deserialize_dict,
key_deserializer,
value_deserializer,
)
except (AttributeError, IndexError):
pass
try:
if annotation._name in ["List", "Set", "Tuple", "Sequence"]:
if len(annotation.__args__) > 1:
if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore
if len(annotation.__args__) > 1: # pyright: ignore

def _deserialize_multiple_sequence(
entry_deserializers: typing.List[typing.Optional[typing.Callable]],
Expand All @@ -694,10 +723,12 @@ def _deserialize_multiple_sequence(
)

entry_deserializers = [
_get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__
_get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__ # pyright: ignore
]
return functools.partial(_deserialize_multiple_sequence, entry_deserializers)
deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[0], module, rf)
deserializer = _get_deserialize_callable_from_annotation(
annotation.__args__[0], module, rf # pyright: ignore
)

def _deserialize_sequence(
deserializer: typing.Optional[typing.Callable],
Expand All @@ -712,27 +743,29 @@ def _deserialize_sequence(
pass

def _deserialize_default(
annotation,
deserializer_from_mapping,
deserializer,
obj,
):
if obj is None:
return obj
try:
return _deserialize_with_callable(annotation, obj)
return _deserialize_with_callable(deserializer, obj)
except Exception:
pass
return _deserialize_with_callable(deserializer_from_mapping, obj)
return obj

return functools.partial(_deserialize_default, annotation, get_deserializer(annotation, rf))
if get_deserializer(annotation, rf):
return functools.partial(_deserialize_default, get_deserializer(annotation, rf))

return functools.partial(_deserialize_default, annotation)


def _deserialize_with_callable(
deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]],
value: typing.Any,
):
try:
if value is None:
if value is None or isinstance(value, _Null):
return None
if deserializer is None:
return value
Expand Down Expand Up @@ -760,7 +793,8 @@ def _deserialize(
value = value.http_response.json()
if rf is None and format:
rf = _RestField(format=format)
deserializer = _get_deserialize_callable_from_annotation(deserializer, module, rf)
if not isinstance(deserializer, functools.partial):
deserializer = _get_deserialize_callable_from_annotation(deserializer, module, rf)
return _deserialize_with_callable(deserializer, value)


Expand All @@ -774,6 +808,7 @@ def __init__(
visibility: typing.Optional[typing.List[str]] = None,
default: typing.Any = _UNSET,
format: typing.Optional[str] = None,
is_multipart_file_input: bool = False,
):
self._type = type
self._rest_name_input = name
Expand All @@ -783,6 +818,11 @@ def __init__(
self._is_model = False
self._default = default
self._format = format
self._is_multipart_file_input = is_multipart_file_input

@property
def _class_type(self) -> typing.Any:
return getattr(self._type, "args", [None])[0]

@property
def _rest_name(self) -> str:
Expand Down Expand Up @@ -828,8 +868,9 @@ def rest_field(
visibility: typing.Optional[typing.List[str]] = None,
default: typing.Any = _UNSET,
format: typing.Optional[str] = None,
is_multipart_file_input: bool = False,
) -> typing.Any:
return _RestField(name=name, type=type, visibility=visibility, default=default, format=format)
return _RestField(name=name, type=type, visibility=visibility, default=default, format=format, is_multipart_file_input=is_multipart_file_input)


def rest_discriminator(
Expand Down
Loading

0 comments on commit 1060de0

Please sign in to comment.