Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

Fix the declarative code. #43

Merged
merged 10 commits into from
Feb 6, 2018
2 changes: 1 addition & 1 deletion debugger_protocol/arg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._common import NOT_SET, ANY # noqa
from ._datatype import FieldsNamespace # noqa
from ._decl import Union, Array, Field # noqa
from ._decl import Enum, Union, Array, Mapping, Field # noqa
from ._errors import ( # noqa
ArgumentError,
ArgMissingError, IncompleteArgError, ArgTypeMismatchError,
Expand Down
55 changes: 38 additions & 17 deletions debugger_protocol/arg/_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _coerce(datatype, value, call=True):
# decl types
elif isinstance(datatype, Enum):
value = _coerce(datatype.datatype, value, call=False)
if value in datatype.choices:
if value in datatype.choice:
return value
elif isinstance(datatype, Union):
for dt in datatype:
Expand Down Expand Up @@ -79,49 +79,70 @@ class FieldsNamespace(Readonly, WithRepr):
PARAM_TYPE = None
PARAM = None

_TRAVERSING = False

@classmethod
def traverse(cls, op, **kwargs):
"""Apply op to each field in cls.FIELDS."""
if cls._TRAVERSING: # recursion check
return cls
cls._TRAVERSING = True

fields = cls._normalize(cls.FIELDS)
fields = fields.traverse(op)
cls.FIELDS = cls._normalize(fields)
try:
fields_traverse = fields.traverse
except AttributeError:
# must be normalizing right now...
return cls
fields = fields_traverse(op)
cls.FIELDS = cls._normalize(fields, force=True)

cls._TRAVERSING = False
return cls

@classmethod
def normalize(cls, *transforms):
"""Normalize FIELDS and apply the given ops."""
fields = cls._normalize(cls.FIELDS)
if not isinstance(fields, Fields):
fields = Fields(*fields)
for transform in transforms:
fields = _transform_datatype(fields, transform)
fields = cls._normalize(fields)
cls.FIELDS = fields
return cls

@classmethod
def _normalize(cls, fields):
def _normalize(cls, fields, force=False):
if fields is None:
raise TypeError('missing FIELDS')
if isinstance(fields, Fields):
try:
normalized = cls._normalized
except AttributeError:
normalized = cls._normalized = False
else:
fields = Fields(*fields)
normalized = cls._normalized = False
if not normalized:

try:
fixref = cls._fixref
except AttributeError:
fixref = cls._fixref = True
if not isinstance(fields, Fields):
fields = Fields(*fields)
if fixref or force:
cls._fixref = False
fields = _transform_datatype(fields,
lambda dt: _replace_ref(dt, cls))
return fields

@classmethod
def bind(cls, ns, **kwargs):
def param(cls):
param = cls.PARAM
if param is None:
if cls.PARAM_TYPE is None:
return cls(**ns)
return None
param = cls.PARAM_TYPE(cls.FIELDS, cls)
return param

@classmethod
def bind(cls, ns, **kwargs):
if isinstance(ns, cls):
return ns
param = cls.param()
if param is None:
return cls(**ns)
return param.bind(ns, **kwargs)

@classmethod
Expand Down
120 changes: 103 additions & 17 deletions debugger_protocol/arg/_decl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,58 @@ def _is_simple(datatype):

def _normalize_datatype(datatype):
cls = type(datatype)
if datatype == REF or datatype is TYPE_REFERENCE:
# normalized when instantiated:
if isinstance(datatype, Union):
return datatype
elif isinstance(datatype, Array):
return datatype
elif isinstance(datatype, Array):
return datatype
elif isinstance(datatype, Mapping):
return datatype
elif isinstance(datatype, Fields):
return datatype
# do not need normalization:
elif datatype is TYPE_REFERENCE:
return TYPE_REFERENCE
elif datatype is ANY:
return ANY
elif datatype in list(SIMPLE_TYPES):
return datatype
elif isinstance(datatype, Enum):
return datatype
elif isinstance(datatype, Union):
return datatype
elif isinstance(datatype, Array):
return datatype
# convert to canonical types:
elif type(datatype) == type(REF) and datatype == REF:
return TYPE_REFERENCE
elif cls is set or cls is frozenset:
return Union(*datatype)
return Union.unordered(*datatype)
elif cls is list or cls is tuple:
datatype, = datatype
return Array(datatype)
elif cls is dict:
raise NotImplementedError
if len(datatype) != 1:
raise NotImplementedError
[keytype, valuetype], = datatype.items()
return Mapping(valuetype, keytype)
# fallback:
else:
return datatype
try:
normalize = datatype.normalize
except AttributeError:
return datatype
else:
return normalize()


def _transform_datatype(datatype, op):
datatype = op(datatype)
try:
dt_traverse = datatype.traverse
except AttributeError:
pass
else:
datatype = dt_traverse(lambda dt: _transform_datatype(dt, op))
return op(datatype)
return datatype


def _replace_ref(datatype, target):
Expand Down Expand Up @@ -113,6 +134,13 @@ class Union(tuple):
Sets and frozensets are treated Unions in declarations.
"""

@classmethod
def unordered(cls, *datatypes, **kwargs):
"""Return an unordered union of the given datatypes."""
self = cls(*datatypes, **kwargs)
self._ordered = False
return self

@classmethod
def _traverse(cls, datatypes, op):
changed = False
Expand All @@ -136,6 +164,7 @@ def __new__(cls, *datatypes, **kwargs):
)
self = super(Union, cls).__new__(cls, datatypes)
self._simple = all(_is_simple(dt) for dt in datatypes)
self._ordered = True
return self

def __repr__(self):
Expand All @@ -147,13 +176,16 @@ def __hash__(self):
def __eq__(self, other): # honors order
if not isinstance(other, Union):
return NotImplemented
if super(Union, self).__eq__(other):
elif super(Union, self).__eq__(other):
return True
if set(self) != set(other):
elif set(self) != set(other):
return False
if self._simple and other._simple:
elif self._simple and other._simple:
return True
return NotImplemented
elif not self._ordered and not other._ordered:
return True
else:
return NotImplemented

def __ne__(self, other):
return not (self == other)
Expand All @@ -167,7 +199,10 @@ def traverse(self, op, **kwargs):
datatypes, changed = self._traverse(self, op)
if not changed and not kwargs:
return self
return self.__class__(*datatypes, **kwargs)
updated = self.__class__(*datatypes, **kwargs)
if not self._ordered:
updated._ordered = False
return updated


class Array(Readonly):
Expand All @@ -179,13 +214,13 @@ class Array(Readonly):

def __init__(self, itemtype, _normalize=True):
if _normalize:
itemtype = _normalize_datatype(itemtype)
itemtype = _transform_datatype(itemtype, _normalize_datatype)
self._bind_attrs(
itemtype=itemtype,
)

def __repr__(self):
return '{}(datatype={!r})'.format(type(self).__name__, self.itemtype)
return '{}(itemtype={!r})'.format(type(self).__name__, self.itemtype)

def __hash__(self):
return hash(self.itemtype)
Expand All @@ -208,6 +243,56 @@ def traverse(self, op, **kwargs):
return self.__class__(datatype, **kwargs)


class Mapping(Readonly):
"""Declare a mapping (to a single type)."""

def __init__(self, valuetype, keytype=str, _normalize=True):
if _normalize:
keytype = _transform_datatype(keytype, _normalize_datatype)
valuetype = _transform_datatype(valuetype, _normalize_datatype)
self._bind_attrs(
keytype=keytype,
valuetype=valuetype,
)

def __repr__(self):
if self.keytype is str:
return '{}(valuetype={!r})'.format(
type(self).__name__, self.valuetype)
else:
return '{}(keytype={!r}, valuetype={!r})'.format(
type(self).__name__, self.keytype, self.valuetype)

def __hash__(self):
return hash((self.keytype, self.valuetype))

def __eq__(self, other):
try:
other_keytype = other.keytype
other_valuetype = other.valuetype
except AttributeError:
return False
if self.keytype != other_keytype:
return False
if self.valuetype != other_valuetype:
return False
return True

def __ne__(self, other):
return not (self == other)

def traverse(self, op, **kwargs):
"""Return a copy with op applied to the item datatype."""
keytype = op(self.keytype)
valuetype = op(self.valuetype)
if (keytype is self.keytype and
valuetype is self.valuetype and
not kwargs
):
return self
return self.__class__(valuetype, keytype, **kwargs)


class Field(namedtuple('Field', 'name datatype default optional')):
"""Declare a field in a data map param."""

Expand All @@ -220,7 +305,7 @@ def __new__(cls, name, datatype=str, enum=None, default=NOT_SET,
enum = None

if _normalize:
datatype = _normalize_datatype(datatype)
datatype = _transform_datatype(datatype, _normalize_datatype)
self = super(Field, cls).__new__(
cls,
name=str(name) if name else None,
Expand Down Expand Up @@ -315,4 +400,5 @@ def traverse(self, op, **kwargs):

if not changed and not kwargs:
return self
kwargs['_normalize'] = False
return self.__class__(*updated, **kwargs)
Loading