Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improvements of subclasses and jsonization time by add two options in global_config #521

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dataclasses_json/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(self):
Union[type, Optional[type]],
MarshmallowField
] = {}
self.enable_cache = False
self.include_class_info = False
# self._json_module = json

# TODO: #180
Expand Down
66 changes: 48 additions & 18 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import importlib
import json
import sys
import warnings
Expand All @@ -23,14 +24,18 @@
_is_optional, _isinstance_safe,
_get_type_arg_param,
_get_type_args, _is_counter,
_NO_ARGS,
_NO_ARGS,_cache,
_issubclass_safe, _is_tuple)

Json = Union[dict, list, str, int, float, bool, None]

confs = ['encoder', 'decoder', 'mm_field', 'letter_case', 'exclude']
FieldOverride = namedtuple('FieldOverride', confs) # type: ignore

_fields = fields
@_cache(2**12)
def _cached_fields(cls):
return _fields(cls)

class _ExtendedEncoder(json.JSONEncoder):
def default(self, o) -> Json:
Expand All @@ -52,13 +57,13 @@ def default(self, o) -> Json:
result = json.JSONEncoder.default(self, o)
return result


@_cache(2**12)
def _user_overrides_or_exts(cls):
global_metadata = defaultdict(dict)
encoders = cfg.global_config.encoders
decoders = cfg.global_config.decoders
mm_fields = cfg.global_config.mm_fields
for field in fields(cls):
for field in _cached_fields(cls):
if field.type in encoders:
global_metadata[field.name]['encoder'] = encoders[field.type]
if field.type in decoders:
Expand All @@ -72,7 +77,7 @@ def _user_overrides_or_exts(cls):
cls_config = {}

overrides = {}
for field in fields(cls):
for field in _cached_fields(cls):
field_config = {}
# first apply global overrides or extensions
field_metadata = global_metadata[field.name]
Expand Down Expand Up @@ -140,15 +145,36 @@ def _decode_letter_case_overrides(field_names, overrides):
return names



@_cache(2**14)
def _load_type(package_name, class_name):
try:
module = importlib.import_module(package_name)
type_ = getattr(module, class_name)
return type_
except (ImportError, AttributeError) as e:
raise ImportError(f"Failed to load type {class_name} from package {package_name}: {e}")

@_cache(maxsize=2**12)
def _cached_get_type_hints(cls):
return get_type_hints(cls)

def get_class(base_class,kvs):
if cfg.global_config.include_class_info:
if '__module__' in kvs and "__name__" in kvs:
return _load_type(kvs['__module__'],kvs['__name__'])
return base_class

def _decode_dataclass(cls, kvs, infer_missing):
cls = get_class(cls,kvs)
if _isinstance_safe(kvs, cls):
return kvs
overrides = _user_overrides_or_exts(cls)
kvs = {} if kvs is None and infer_missing else kvs
field_names = [field.name for field in fields(cls)]
field_names = [field.name for field in _cached_fields(cls)]
decode_names = _decode_letter_case_overrides(field_names, overrides)
kvs = {decode_names.get(k, k): v for k, v in kvs.items()}
missing_fields = {field for field in fields(cls) if field.name not in kvs}
missing_fields = {field for field in _cached_fields(cls) if field.name not in kvs}

for field in missing_fields:
if field.default is not MISSING:
Expand All @@ -162,15 +188,16 @@ def _decode_dataclass(cls, kvs, infer_missing):
kvs = _handle_undefined_parameters_safe(cls, kvs, usage="from")

init_kwargs = {}
types = get_type_hints(cls)
for field in fields(cls):
types = _cached_get_type_hints(cls)
for field in _cached_fields(cls):
# The field should be skipped from being added
# to init_kwargs as it's not intended as a constructor argument.
if not field.init:
continue

field_value = kvs[field.name]
field_type = types[field.name]

if field_value is None:
if not _is_optional(field_type):
warning = (
Expand Down Expand Up @@ -253,7 +280,7 @@ def _support_extended_types(field_type, field_value):
res = field_value
return res


@_cache(2**12)
def _is_supported_generic(type_):
if type_ is _NO_ARGS:
return False
Expand Down Expand Up @@ -406,27 +433,30 @@ def _asdict(obj, encode_json=False):
"""
if is_dataclass(obj):
result = []
overrides = _user_overrides_or_exts(obj)
for field in fields(obj):
if cfg.global_config.include_class_info:
result.append(('__module__',obj.__class__.__module__))
result.append(('__name__',obj.__class__.__name__))
overrides = _user_overrides_or_exts(obj.__class__)
for field in _cached_fields(obj.__class__):
if overrides[field.name].encoder:
value = getattr(obj, field.name)
else:
value = _asdict(
getattr(obj, field.name),
encode_json=encode_json
getattr(obj, field.name)
)
result.append((field.name, value))

result = _handle_undefined_parameters_safe(cls=obj, kvs=dict(result),
usage="to")
return _encode_overrides(dict(result), _user_overrides_or_exts(obj),
return _encode_overrides(dict(result), _user_overrides_or_exts(obj.__class__),
encode_json=encode_json)
elif isinstance(obj, Mapping):
return dict((_asdict(k, encode_json=encode_json),
_asdict(v, encode_json=encode_json)) for k, v in
return dict((_asdict(k),
_asdict(v)) for k, v in
obj.items())
# enum.IntFlag and enum.Flag are regarded as collections in Python 3.11, thus a check against Enum is needed
elif isinstance(obj, Collection) and not isinstance(obj, (str, bytes, Enum)):
return list(_asdict(v, encode_json=encode_json) for v in obj)
return list(_asdict(v) for v in obj)
else:
return copy.deepcopy(obj)
# return copy.deepcopy(obj)
return obj
22 changes: 19 additions & 3 deletions dataclasses_json/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,21 @@
from collections import Counter
from typing import (Collection, Mapping, Optional, TypeVar, Any, Type, Tuple,
Union, cast)

from dataclasses import _FIELDS
from dataclasses_json import cfg
import functools

def _cache(maxsize=128):
def decorator(func):
cached_func = functools.lru_cache(maxsize=maxsize)(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
if cfg.global_config.enable_cache:
return cached_func(*args, **kwargs)
else:
return func(*args, **kwargs)
return wrapper
return decorator

def _get_type_cons(type_):
"""More spaghetti logic for 3.6 vs. 3.7"""
Expand Down Expand Up @@ -66,6 +80,8 @@ def _hasargs(type_, *args):
else:
return res

def _is_dataclass(obj):
return hasattr(obj.__class__,_FIELDS) or hasattr(obj,_FIELDS)

class _NoArgs(object):
def __bool__(self):
Expand Down Expand Up @@ -111,7 +127,7 @@ def _isinstance_safe(o, t):
else:
return result


@_cache(maxsize=2**12)
def _issubclass_safe(cls, classinfo):
try:
return issubclass(cls, classinfo)
Expand All @@ -136,7 +152,7 @@ def _is_new_type_subclass_safe(cls, classinfo):
def _is_new_type(type_):
return inspect.isfunction(type_) and hasattr(type_, "__supertype__")


@_cache(maxsize=2**12)
def _is_optional(type_):
return (_issubclass_safe(type_, Optional) or
_hasargs(type_, type(None)) or
Expand Down
Loading
Loading