From d7b6df86a7cf8c9aba9696293e109e4bb91a1716 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Mon, 5 Feb 2018 23:05:29 +0000 Subject: [PATCH 01/10] Fix typos. --- debugger_protocol/arg/_datatype.py | 2 +- tests/debugger_protocol/arg/test__decl.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/debugger_protocol/arg/_datatype.py b/debugger_protocol/arg/_datatype.py index 575eb8d96..908ce947c 100644 --- a/debugger_protocol/arg/_datatype.py +++ b/debugger_protocol/arg/_datatype.py @@ -22,7 +22,7 @@ def _coerce(datatype, value, call=True): # decl types elif isinstance(datatype, Enum): value = _coerce(datatype.datatype, value, call=False) - if value in datatype.choices: + if value in datatype.choice: return value elif isinstance(datatype, Union): for dt in datatype: diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index 079dc77d4..b5dd8ed8a 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -136,7 +136,7 @@ class Spam(FieldsNamespace): # ... str, Field('a'), - Fields(Field('a')), + #Fields(Field('a')), Spam, Array(Spam), Union(Array(Spam)), @@ -431,7 +431,7 @@ def test_as_dict(self): Field('ham'), Field('eggs', Array(str)), ) - result = fields.as_dict + result = fields.as_dict() self.assertEqual(result, { 'spam': fields[0], From b5f7e17e987cec3d7cced3bd8edfada0d18f2096 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Mon, 5 Feb 2018 23:06:34 +0000 Subject: [PATCH 02/10] Do not try to bind when already bound. --- debugger_protocol/arg/_datatype.py | 2 + tests/debugger_protocol/arg/test__datatype.py | 57 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/debugger_protocol/arg/_datatype.py b/debugger_protocol/arg/_datatype.py index 908ce947c..3e8d5205e 100644 --- a/debugger_protocol/arg/_datatype.py +++ b/debugger_protocol/arg/_datatype.py @@ -117,6 +117,8 @@ def _normalize(cls, fields): @classmethod def bind(cls, ns, **kwargs): + if isinstance(ns, cls): + return ns param = cls.PARAM if param is None: if cls.PARAM_TYPE is None: diff --git a/tests/debugger_protocol/arg/test__datatype.py b/tests/debugger_protocol/arg/test__datatype.py index 3ddfa6afb..756b1a46f 100644 --- a/tests/debugger_protocol/arg/test__datatype.py +++ b/tests/debugger_protocol/arg/test__datatype.py @@ -4,6 +4,7 @@ from debugger_protocol.arg._common import ANY from debugger_protocol.arg._datatype import FieldsNamespace from debugger_protocol.arg._decl import Array, Field, Fields +from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg from ._common import ( BASIC_FULL, BASIC_MIN, Basic, @@ -172,6 +173,62 @@ def test_normalize_missing(self): ####### + def test_bind_no_param(self): + class Spam(FieldsNamespace): + FIELDS = [ + Field('a'), + ] + + arg = Spam.bind({'a': 'x'}) + + self.assertIsInstance(arg, Spam) + self.assertEqual(arg, Spam(a='x')) + + def test_bind_with_param_obj(self): + class Param(Parameter): + HANDLER = DatatypeHandler(ANY) + match_type = (lambda self, raw: self.HANDLER) + + class Spam(FieldsNamespace): + PARAM = Param(ANY) + FIELDS = [ + Field('a'), + ] + + arg = Spam.bind({'a': 'x'}) + + self.assertIsInstance(arg, Arg) + self.assertEqual(arg, Arg(Param(ANY), {'a': 'x'})) + + def test_bind_with_param_type(self): + class Param(Parameter): + HANDLER = DatatypeHandler(ANY) + match_type = (lambda self, raw: self.HANDLER) + + class Spam(FieldsNamespace): + PARAM_TYPE = Param + FIELDS = [ + Field('a'), + ] + + arg = Spam.bind({'a': 'x'}) + + self.assertIsInstance(arg, Arg) + self.assertEqual(arg, Arg(Param(Spam.FIELDS), {'a': 'x'})) + + def test_bind_already_bound(self): + class Spam(FieldsNamespace): + FIELDS = [ + Field('a'), + ] + + spam = Spam(a='x') + arg = Spam.bind(spam) + + self.assertIs(arg, spam) + + ####### + def test_fields_full(self): class Spam(FieldsNamespace): FIELDS = FIELDS_EXTENDED From f3e9b59f46174f9b3c3425ccecef712d506cc510 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Mon, 5 Feb 2018 23:24:30 +0000 Subject: [PATCH 03/10] Ignore missing optional args. --- debugger_protocol/arg/_params.py | 4 +++- tests/debugger_protocol/arg/test__params.py | 15 ++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/debugger_protocol/arg/_params.py b/debugger_protocol/arg/_params.py index f988539a1..1499b43fc 100644 --- a/debugger_protocol/arg/_params.py +++ b/debugger_protocol/arg/_params.py @@ -1,4 +1,4 @@ -from ._common import ANY, SIMPLE_TYPES +from ._common import NOT_SET, ANY, SIMPLE_TYPES from ._datatype import FieldsNamespace from ._decl import Enum, Union, Array, Field, Fields from ._errors import ArgTypeMismatchError @@ -372,6 +372,8 @@ def match_type(self, raw): if not field.optional: return None value = field.default + if value is NOT_SET: + continue param = self.params[field.name] handler = param.match_type(value) if handler is None: diff --git a/tests/debugger_protocol/arg/test__params.py b/tests/debugger_protocol/arg/test__params.py index 2a05d69f2..d0c31aec9 100644 --- a/tests/debugger_protocol/arg/test__params.py +++ b/tests/debugger_protocol/arg/test__params.py @@ -669,7 +669,7 @@ def test_as_data_complicated(self): class ComplexParameterTests(unittest.TestCase): - def test_match_type(self): + def test_match_type_none_missing(self): fields = Fields(*FIELDS_BASIC) param = ComplexParameter(fields) handler = param.match_type(BASIC_FULL) @@ -677,6 +677,19 @@ def test_match_type(self): self.assertIs(type(handler), ComplexParameter.HANDLER) self.assertEqual(handler.datatype.FIELDS, fields) + def test_match_type_missing_optional(self): + fields = Fields( + Field('name'), + Field.START_OPTIONAL, + Field('value'), + ) + param = ComplexParameter(fields) + handler = param.match_type({'name': 'spam'}) + + self.assertIs(type(handler), ComplexParameter.HANDLER) + self.assertEqual(handler.datatype.FIELDS, fields) + self.assertNotIn('value', handler.handlers) + def test_coerce(self): handler = ComplexParameter.HANDLER(Basic) coerced = handler.coerce(BASIC_FULL) From 8704c47dd249f404c6d60bce033e35262ad68520 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Mon, 5 Feb 2018 23:25:36 +0000 Subject: [PATCH 04/10] Respect datatype.normalize(). --- debugger_protocol/arg/_datatype.py | 1 + debugger_protocol/arg/_decl.py | 16 +++++++++++++++- tests/debugger_protocol/arg/test__decl.py | 7 +++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/debugger_protocol/arg/_datatype.py b/debugger_protocol/arg/_datatype.py index 3e8d5205e..05c7f2645 100644 --- a/debugger_protocol/arg/_datatype.py +++ b/debugger_protocol/arg/_datatype.py @@ -97,6 +97,7 @@ def normalize(cls, *transforms): fields = _transform_datatype(fields, transform) fields = cls._normalize(fields) cls.FIELDS = fields + return cls @classmethod def _normalize(cls, fields): diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 50a1d33f7..e1770bd83 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -22,18 +22,26 @@ def _is_simple(datatype): def _normalize_datatype(datatype): cls = type(datatype) + # convert to canonical types (part 1): if datatype == REF or datatype is TYPE_REFERENCE: return TYPE_REFERENCE + # do not need normalization: elif datatype is ANY: return ANY elif datatype in list(SIMPLE_TYPES): return datatype elif isinstance(datatype, Enum): return datatype + # normalized when instantiated: elif isinstance(datatype, Union): return datatype elif isinstance(datatype, Array): return datatype + elif isinstance(datatype, Field): + return datatype + elif isinstance(datatype, Fields): + return datatype + # convert to canonical types (part 2): elif cls is set or cls is frozenset: return Union(*datatype) elif cls is list or cls is tuple: @@ -41,8 +49,14 @@ def _normalize_datatype(datatype): return Array(datatype) elif cls is dict: raise NotImplementedError + # fallback: else: - return datatype + try: + normalize = datatype.normalize + except AttributeError: + return datatype + else: + return normalize() def _transform_datatype(datatype, op): diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index b5dd8ed8a..40a94b077 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -13,6 +13,12 @@ class ModuleTests(unittest.TestCase): def test_normalize_datatype(self): + class Spam: + @classmethod + def normalize(cls): + return OKAY + + OKAY = object() NOOP = object() param = SimpleParameter(str) tests = [ @@ -45,6 +51,7 @@ def test_normalize_datatype(self): (object(), NOOP), (object, NOOP), (type, NOOP), + (Spam, OKAY), ] for datatype, expected in tests: if expected is NOOP: From 604859c01c2abebe5943a1868179541214f71010 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 6 Feb 2018 00:55:01 +0000 Subject: [PATCH 05/10] Recursively normalize in Array() and Field(). --- debugger_protocol/arg/_decl.py | 4 +-- tests/debugger_protocol/arg/test__decl.py | 30 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index e1770bd83..5169341ff 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -193,7 +193,7 @@ class Array(Readonly): def __init__(self, itemtype, _normalize=True): if _normalize: - itemtype = _normalize_datatype(itemtype) + itemtype = _transform_datatype(itemtype, _normalize_datatype) self._bind_attrs( itemtype=itemtype, ) @@ -234,7 +234,7 @@ def __new__(cls, name, datatype=str, enum=None, default=NOT_SET, enum = None if _normalize: - datatype = _normalize_datatype(datatype) + datatype = _transform_datatype(datatype, _normalize_datatype) self = super(Field, cls).__new__( cls, name=str(name) if name else None, diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index 40a94b077..a3ced351e 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -274,6 +274,21 @@ def test_normalized(self): with self.assertRaises(NotImplementedError): Array({1: 2}) + def test_normalized_transformed(self): + calls = 0 + + class Spam: + @classmethod + def traverse(cls, op): + nonlocal calls + calls += 1 + return cls + + array = Array(Spam) + + self.assertIs(array.itemtype, Spam) + self.assertEqual(calls, 1) + def test_traverse_noop(self): calls = [] op = (lambda dt: calls.append(dt) or dt) @@ -332,6 +347,21 @@ def test_normalized(self): with self.assertRaises(NotImplementedError): Field('spam', {1: 2}) + def test_normalized_transformed(self): + calls = 0 + + class Spam: + @classmethod + def traverse(cls, op): + nonlocal calls + calls += 1 + return cls + + field = Field('spam', Spam) + + self.assertIs(field.datatype, Spam) + self.assertEqual(calls, 1) + def test_traverse_noop(self): calls = [] op = (lambda dt: calls.append(dt) or dt) From baa3ecc0849cf67bfb91333023513b6fb9d24391 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 6 Feb 2018 03:51:29 +0000 Subject: [PATCH 06/10] Apply the op to the datatype before traversing. --- debugger_protocol/arg/_datatype.py | 39 ++++++---- debugger_protocol/arg/_decl.py | 55 +++++++++----- tests/debugger_protocol/arg/test__datatype.py | 75 ++++++++++++++++--- tests/debugger_protocol/arg/test__decl.py | 45 ++++++----- 4 files changed, 146 insertions(+), 68 deletions(-) diff --git a/debugger_protocol/arg/_datatype.py b/debugger_protocol/arg/_datatype.py index 05c7f2645..80cd99f3a 100644 --- a/debugger_protocol/arg/_datatype.py +++ b/debugger_protocol/arg/_datatype.py @@ -79,20 +79,31 @@ class FieldsNamespace(Readonly, WithRepr): PARAM_TYPE = None PARAM = None + _TRAVERSING = False + @classmethod def traverse(cls, op, **kwargs): """Apply op to each field in cls.FIELDS.""" + if cls._TRAVERSING: # recursion check + return cls + cls._TRAVERSING = True + fields = cls._normalize(cls.FIELDS) - fields = fields.traverse(op) - cls.FIELDS = cls._normalize(fields) + try: + fields_traverse = fields.traverse + except AttributeError: + # must be normalizing right now... + return cls + fields = fields_traverse(op) + cls.FIELDS = cls._normalize(fields, force=True) + + cls._TRAVERSING = False return cls @classmethod def normalize(cls, *transforms): """Normalize FIELDS and apply the given ops.""" fields = cls._normalize(cls.FIELDS) - if not isinstance(fields, Fields): - fields = Fields(*fields) for transform in transforms: fields = _transform_datatype(fields, transform) fields = cls._normalize(fields) @@ -100,18 +111,18 @@ def normalize(cls, *transforms): return cls @classmethod - def _normalize(cls, fields): + def _normalize(cls, fields, force=False): if fields is None: raise TypeError('missing FIELDS') - if isinstance(fields, Fields): - try: - normalized = cls._normalized - except AttributeError: - normalized = cls._normalized = False - else: - fields = Fields(*fields) - normalized = cls._normalized = False - if not normalized: + + try: + fixref = cls._fixref + except AttributeError: + fixref = cls._fixref = True + if not isinstance(fields, Fields): + fields = Fields(*fields) + if fixref or force: + cls._fixref = False fields = _transform_datatype(fields, lambda dt: _replace_ref(dt, cls)) return fields diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 5169341ff..2fd275d2a 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -22,18 +22,8 @@ def _is_simple(datatype): def _normalize_datatype(datatype): cls = type(datatype) - # convert to canonical types (part 1): - if datatype == REF or datatype is TYPE_REFERENCE: - return TYPE_REFERENCE - # do not need normalization: - elif datatype is ANY: - return ANY - elif datatype in list(SIMPLE_TYPES): - return datatype - elif isinstance(datatype, Enum): - return datatype # normalized when instantiated: - elif isinstance(datatype, Union): + if isinstance(datatype, Union): return datatype elif isinstance(datatype, Array): return datatype @@ -41,9 +31,20 @@ def _normalize_datatype(datatype): return datatype elif isinstance(datatype, Fields): return datatype - # convert to canonical types (part 2): + # do not need normalization: + elif datatype is TYPE_REFERENCE: + return TYPE_REFERENCE + elif datatype is ANY: + return ANY + elif datatype in list(SIMPLE_TYPES): + return datatype + elif isinstance(datatype, Enum): + return datatype + # convert to canonical types: + elif type(datatype) == type(REF) and datatype == REF: + return TYPE_REFERENCE elif cls is set or cls is frozenset: - return Union(*datatype) + return Union.unordered(*datatype) elif cls is list or cls is tuple: datatype, = datatype return Array(datatype) @@ -60,13 +61,14 @@ def _normalize_datatype(datatype): def _transform_datatype(datatype, op): + datatype = op(datatype) try: dt_traverse = datatype.traverse except AttributeError: pass else: datatype = dt_traverse(lambda dt: _transform_datatype(dt, op)) - return op(datatype) + return datatype def _replace_ref(datatype, target): @@ -127,6 +129,13 @@ class Union(tuple): Sets and frozensets are treated Unions in declarations. """ + @classmethod + def unordered(cls, *datatypes, **kwargs): + """Return an unordered union of the given datatypes.""" + self = cls(*datatypes, **kwargs) + self._ordered = False + return self + @classmethod def _traverse(cls, datatypes, op): changed = False @@ -150,6 +159,7 @@ def __new__(cls, *datatypes, **kwargs): ) self = super(Union, cls).__new__(cls, datatypes) self._simple = all(_is_simple(dt) for dt in datatypes) + self._ordered = True return self def __repr__(self): @@ -161,13 +171,16 @@ def __hash__(self): def __eq__(self, other): # honors order if not isinstance(other, Union): return NotImplemented - if super(Union, self).__eq__(other): + elif super(Union, self).__eq__(other): return True - if set(self) != set(other): + elif set(self) != set(other): return False - if self._simple and other._simple: + elif self._simple and other._simple: + return True + elif not self._ordered and not other._ordered: return True - return NotImplemented + else: + return NotImplemented def __ne__(self, other): return not (self == other) @@ -181,7 +194,10 @@ def traverse(self, op, **kwargs): datatypes, changed = self._traverse(self, op) if not changed and not kwargs: return self - return self.__class__(*datatypes, **kwargs) + updated = self.__class__(*datatypes, **kwargs) + if not self._ordered: + updated._ordered = False + return updated class Array(Readonly): @@ -329,4 +345,5 @@ def traverse(self, op, **kwargs): if not changed and not kwargs: return self + kwargs['_normalize'] = False return self.__class__(*updated, **kwargs) diff --git a/tests/debugger_protocol/arg/test__datatype.py b/tests/debugger_protocol/arg/test__datatype.py index 756b1a46f..326299d0a 100644 --- a/tests/debugger_protocol/arg/test__datatype.py +++ b/tests/debugger_protocol/arg/test__datatype.py @@ -3,7 +3,7 @@ from debugger_protocol.arg._common import ANY from debugger_protocol.arg._datatype import FieldsNamespace -from debugger_protocol.arg._decl import Array, Field, Fields +from debugger_protocol.arg._decl import Union, Array, Field, Fields from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg from ._common import ( @@ -135,23 +135,25 @@ class Spam(FieldsNamespace): self.assertIs(Spam.FIELDS, fields) for i, field in enumerate(Spam.FIELDS): self.assertIs(field, fieldlist[i]) + self.maxDiff = None self.assertEqual(calls, [ - (op1, str), + (op1, fields), (op1, Field('spam')), - (op1, int), + (op1, str), (op1, Field('ham', int)), - (op1, ANY), - (op1, Array(ANY)), + (op1, int), (op1, Field('eggs', Array(ANY))), - (op1, fields), - (op2, str), + (op1, Array(ANY)), + (op1, ANY), + + (op2, fields), (op2, Field('spam')), - (op2, int), + (op2, str), (op2, Field('ham', int)), - (op2, ANY), - (op2, Array(ANY)), + (op2, int), (op2, Field('eggs', Array(ANY))), - (op2, fields), + (op2, Array(ANY)), + (op2, ANY), ]) def test_normalize_with_op_changed(self): @@ -167,6 +169,57 @@ class Spam(FieldsNamespace): Field('spam', Array(int)), )) + def test_normalize_declarative(self): + class Spam(FieldsNamespace): + FIELDS = [ + Field('a'), + Field('b', bool), + Field.START_OPTIONAL, + Field('c', {int, str}), + Field('d', [int]), + Field('e', ANY), + Field('f', ''), + ] + + class Ham(FieldsNamespace): + FIELDS = [ + Field('w', Spam), + Field('x', int), + Field('y', frozenset({int, str})), + Field('z', (int,)), + ] + + class Eggs(FieldsNamespace): + FIELDS = [ + Field('b', [Ham]), + Field('x', [{str, ('',)}], optional=True), + Field('d', {Spam, ''}, optional=True), + ] + + Eggs.normalize() + + self.assertEqual(Spam.FIELDS, Fields( + Field('a'), + Field('b', bool), + Field('c', Union(int, str), optional=True), + Field('d', Array(int), optional=True), + Field('e', ANY, optional=True), + Field('f', Spam, optional=True), + )) + self.assertEqual(Ham.FIELDS, Fields( + Field('w', Spam), + Field('x', int), + Field('y', Union(int, str)), + Field('z', Array(int)), + )) + self.assertEqual(Eggs.FIELDS, Fields( + Field('b', Array(Ham)), + Field('x', + Array(Union.unordered(str, Array(Eggs))), + optional=True), + Field('d', Union.unordered(Spam, Eggs), optional=True), + )) + def test_normalize_missing(self): with self.assertRaises(TypeError): FieldsNamespace.normalize() diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index a3ced351e..8164c09fa 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -104,8 +104,6 @@ class Spam(FieldsNamespace): Field('a'), ] - Spam.normalize() - fields = Fields(Field('...')) field_spam = Field('spam', ANY) field_ham = Field('ham', Union( @@ -120,42 +118,43 @@ class Spam(FieldsNamespace): ) tests = { Array(str): [ - str, Array(str), + str, ], Field('...'): [ - str, Field('...'), + str, ], fields: [ - str, - Field('...'), fields, + Field('...'), + str, ], nested: [ - str, - Field('...'), - fields, + nested, + # ... Field('???', fields), + fields, + Field('...'), + str, # ... - ANY, Field('spam', ANY), + ANY, # ... - str, - Field('a'), - #Fields(Field('a')), - Spam, - Array(Spam), - Union(Array(Spam)), field_ham, + Union(Array(Spam)), + Array(Spam), + Spam, + #Fields(Field('a')), + Field('a'), + str, # ... - TYPE_REFERENCE, - Array(TYPE_REFERENCE), field_eggs, - # ... - nested, + Array(TYPE_REFERENCE), + TYPE_REFERENCE, ], } + self.maxDiff = None for datatype, expected in tests.items(): calls = [] op = (lambda dt: calls.append(dt) or dt) @@ -172,10 +171,8 @@ class Spam(FieldsNamespace): transformed = _transform_datatype(datatype, op) self.assertIs(transformed, datatype) - self.assertEqual(set(calls[:2]), {str, int}) - self.assertEqual(calls[2:], [ - Union(str, int), - ]) + self.assertEqual(calls[0], Union(str, int)) + self.assertEqual(set(calls[1:]), {str, int}) class EnumTests(unittest.TestCase): From 215f91d0441c1b60216473d4cb41d975911b3277 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 6 Feb 2018 04:40:17 +0000 Subject: [PATCH 07/10] Avoid infinite recursion in param_from_datatype(). --- debugger_protocol/arg/_datatype.py | 15 +++++++++++---- debugger_protocol/arg/_params.py | 4 +++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/debugger_protocol/arg/_datatype.py b/debugger_protocol/arg/_datatype.py index 80cd99f3a..b4d3000e4 100644 --- a/debugger_protocol/arg/_datatype.py +++ b/debugger_protocol/arg/_datatype.py @@ -128,14 +128,21 @@ def _normalize(cls, fields, force=False): return fields @classmethod - def bind(cls, ns, **kwargs): - if isinstance(ns, cls): - return ns + def param(cls): param = cls.PARAM if param is None: if cls.PARAM_TYPE is None: - return cls(**ns) + return None param = cls.PARAM_TYPE(cls.FIELDS, cls) + return param + + @classmethod + def bind(cls, ns, **kwargs): + if isinstance(ns, cls): + return ns + param = cls.param() + if param is None: + return cls(**ns) return param.bind(ns, **kwargs) @classmethod diff --git a/debugger_protocol/arg/_params.py b/debugger_protocol/arg/_params.py index 1499b43fc..eb0087c3d 100644 --- a/debugger_protocol/arg/_params.py +++ b/debugger_protocol/arg/_params.py @@ -49,7 +49,8 @@ def param_from_datatype(datatype, **kwargs): elif not isinstance(datatype, type): raise NotImplementedError elif issubclass(datatype, FieldsNamespace): - return ComplexParameter(datatype, **kwargs) + param = datatype.param() + return param or ComplexParameter(datatype, **kwargs) else: raise NotImplementedError @@ -344,6 +345,7 @@ class ArgNamespace(FieldsNamespace): msg = 'expected Fields or FieldsNamespace, got {!r}' raise ValueError(msg.format(datatype)) datatype.normalize() + datatype.PARAM = self # We set handler later in match_type(). super(ComplexParameter, self).__init__(datatype) From 800e05cd7acd0056a8d180e740e0083f5bda2bb5 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 6 Feb 2018 08:37:33 +0000 Subject: [PATCH 08/10] Add a Mapping declaration type. --- debugger_protocol/arg/__init__.py | 2 +- debugger_protocol/arg/_decl.py | 60 ++++++++++++++- tests/debugger_protocol/arg/test__decl.py | 92 ++++++++++++++++++----- 3 files changed, 131 insertions(+), 23 deletions(-) diff --git a/debugger_protocol/arg/__init__.py b/debugger_protocol/arg/__init__.py index f294ec8a2..ec58e37aa 100644 --- a/debugger_protocol/arg/__init__.py +++ b/debugger_protocol/arg/__init__.py @@ -1,6 +1,6 @@ from ._common import NOT_SET, ANY # noqa from ._datatype import FieldsNamespace # noqa -from ._decl import Union, Array, Field # noqa +from ._decl import Enum, Union, Array, Mapping, Field # noqa from ._errors import ( # noqa ArgumentError, ArgMissingError, IncompleteArgError, ArgTypeMismatchError, diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 2fd275d2a..049b2b77e 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -27,7 +27,9 @@ def _normalize_datatype(datatype): return datatype elif isinstance(datatype, Array): return datatype - elif isinstance(datatype, Field): + elif isinstance(datatype, Array): + return datatype + elif isinstance(datatype, Mapping): return datatype elif isinstance(datatype, Fields): return datatype @@ -49,7 +51,10 @@ def _normalize_datatype(datatype): datatype, = datatype return Array(datatype) elif cls is dict: - raise NotImplementedError + if len(datatype) != 1: + raise NotImplementedError + [keytype, valuetype], = datatype.items() + return Mapping(valuetype, keytype) # fallback: else: try: @@ -215,7 +220,7 @@ def __init__(self, itemtype, _normalize=True): ) def __repr__(self): - return '{}(datatype={!r})'.format(type(self).__name__, self.itemtype) + return '{}(itemtype={!r})'.format(type(self).__name__, self.itemtype) def __hash__(self): return hash(self.itemtype) @@ -238,6 +243,55 @@ def traverse(self, op, **kwargs): return self.__class__(datatype, **kwargs) +class Mapping(Readonly): + """Declare a mapping (to a single type).""" + + def __init__(self, valuetype, keytype=str, _normalize=True): + if _normalize: + keytype = _transform_datatype(keytype, _normalize_datatype) + valuetype = _transform_datatype(valuetype, _normalize_datatype) + self._bind_attrs( + keytype=keytype, + valuetype=valuetype, + ) + + def __repr__(self): + if self.keytype is str: + return '{}(valuetype={!r})'.format(type(self).__name__, self.valuetype) + else: + return '{}(keytype={!r}, valuetype={!r})'.format( + type(self).__name__, self.keytype, self.valuetype) + + def __hash__(self): + return hash((self.keytype, self.valuetype)) + + def __eq__(self, other): + try: + other_keytype = other.keytype + other_valuetype = other.valuetype + except AttributeError: + return False + if self.keytype != other_keytype: + return False + if self.valuetype != other_valuetype: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def traverse(self, op, **kwargs): + """Return a copy with op applied to the item datatype.""" + keytype = op(self.keytype) + valuetype = op(self.valuetype) + if (keytype is self.keytype and + valuetype is self.valuetype and + not kwargs + ): + return self + return self.__class__(valuetype, keytype, **kwargs) + + class Field(namedtuple('Field', 'name datatype default optional')): """Declare a field in a data map param.""" diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index 8164c09fa..2aa953ee4 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -4,7 +4,7 @@ from debugger_protocol.arg._datatype import FieldsNamespace from debugger_protocol.arg._decl import ( REF, TYPE_REFERENCE, _normalize_datatype, _transform_datatype, - Enum, Union, Array, Field, Fields) + Enum, Union, Array, Mapping, Field, Fields) from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg from debugger_protocol.arg._params import ( SimpleParameter, UnionParameter, ArrayParameter, ComplexParameter) @@ -37,6 +37,8 @@ def normalize(cls): (Array(str), NOOP), ([str], Array(str)), ((str,), Array(str)), + (Mapping(str), NOOP), + ({str: str}, Mapping(str)), # others (Field('spam'), NOOP), (Fields(Field('spam')), NOOP), @@ -61,9 +63,6 @@ def normalize(cls): self.assertEqual(datatype, expected) - with self.assertRaises(NotImplementedError): - _normalize_datatype({1: 2}) - def test_transform_datatype_simple(self): datatypes = [ REF, @@ -102,6 +101,7 @@ def test_transform_datatype_container(self): class Spam(FieldsNamespace): FIELDS = [ Field('a'), + Field('b', {str: str}) ] fields = Fields(Field('...')) @@ -145,9 +145,12 @@ class Spam(FieldsNamespace): Union(Array(Spam)), Array(Spam), Spam, - #Fields(Field('a')), Field('a'), str, + Field('b', Mapping(str)), + Mapping(str), + str, + str, # ... field_eggs, Array(TYPE_REFERENCE), @@ -213,6 +216,7 @@ def test_normalized(self): (frozenset([str, int]), Union(*frozenset([str, int]))), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -221,9 +225,6 @@ def test_normalized(self): self.assertEqual(union, Union(int, expected, str)) - with self.assertRaises(NotImplementedError): - Union({1: 2}) - def test_traverse_noop(self): calls = [] op = (lambda dt: calls.append(dt) or dt) @@ -260,6 +261,7 @@ def test_normalized(self): (frozenset([str, int]), Union(str, int)), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -268,9 +270,6 @@ def test_normalized(self): self.assertEqual(array, Array(expected)) - with self.assertRaises(NotImplementedError): - Array({1: 2}) - def test_normalized_transformed(self): calls = 0 @@ -311,6 +310,67 @@ def test_traverse_changed(self): ]) +class MappingTests(unittest.TestCase): + + def test_normalized(self): + tests = [ + (REF, TYPE_REFERENCE), + ({str, int}, Union(str, int)), + (frozenset([str, int]), Union(str, int)), + ([str], Array(str)), + ((str,), Array(str)), + ({str: str}, Mapping(str)), + (None, None), + ] + for datatype, expected in tests: + with self.subTest(datatype): + mapping = Mapping(datatype) + + self.assertEqual(mapping, Mapping(expected)) + + def test_normalized_transformed(self): + calls = 0 + + class Spam: + @classmethod + def traverse(cls, op): + nonlocal calls + calls += 1 + return cls + + mapping = Mapping(Spam) + + self.assertIs(mapping.keytype, str) + self.assertIs(mapping.valuetype, Spam) + self.assertEqual(calls, 1) + + def test_traverse_noop(self): + calls = [] + op = (lambda dt: calls.append(dt) or dt) + mapping = Mapping(Union(str, int)) + transformed = mapping.traverse(op) + + self.assertIs(transformed, mapping) + self.assertCountEqual(calls, [ + str, + # Note that it did not recurse into Union(str, int). + Union(str, int), + ]) + + def test_traverse_changed(self): + calls = [] + op = (lambda dt: calls.append(dt) or str) + mapping = Mapping(ANY) + transformed = mapping.traverse(op) + + self.assertIsNot(transformed, mapping) + self.assertEqual(transformed, Mapping(str)) + self.assertEqual(calls, [ + str, + ANY, + ]) + + class FieldTests(unittest.TestCase): def test_defaults(self): @@ -333,6 +393,7 @@ def test_normalized(self): (frozenset([str, int]), Union(str, int)), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -341,9 +402,6 @@ def test_normalized(self): self.assertEqual(field, Field('spam', expected)) - with self.assertRaises(NotImplementedError): - Field('spam', {1: 2}) - def test_normalized_transformed(self): calls = 0 @@ -420,6 +478,7 @@ def test_normalized(self): (frozenset([str, int]), Union(str, int)), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -432,11 +491,6 @@ def test_normalized(self): Field('spam', expected), ]) - with self.assertRaises(NotImplementedError): - Fields( - Field('spam', {1: 2}), - ) - def test_with_START_OPTIONAL(self): fields = Fields( Field('spam'), From b9274f074d4d27fdf55a43f389235aa2177ee63f Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 6 Feb 2018 09:14:08 +0000 Subject: [PATCH 09/10] Add MappingParameter. --- debugger_protocol/arg/_params.py | 94 ++++++++++++++++++++- tests/debugger_protocol/arg/test__params.py | 93 +++++++++++++++++++- 2 files changed, 184 insertions(+), 3 deletions(-) diff --git a/debugger_protocol/arg/_params.py b/debugger_protocol/arg/_params.py index eb0087c3d..5562d6235 100644 --- a/debugger_protocol/arg/_params.py +++ b/debugger_protocol/arg/_params.py @@ -1,6 +1,6 @@ from ._common import NOT_SET, ANY, SIMPLE_TYPES from ._datatype import FieldsNamespace -from ._decl import Enum, Union, Array, Field, Fields +from ._decl import Enum, Union, Array, Mapping, Field, Fields from ._errors import ArgTypeMismatchError from ._param import Parameter, DatatypeHandler @@ -277,6 +277,98 @@ def match_type(self, raw): return self.HANDLER(self.datatype, handlers) +class MappingParameter(Parameter): + """A parameter that is a mapping of some fixed type.""" + + class HANDLER(DatatypeHandler): + + def __init__(self, datatype, handlers=None, + keyparam=None, valueparam=None): + if not isinstance(datatype, Mapping): + raise ValueError( + 'expected an Mapping, got {!r}'.format(datatype)) + super(MappingParameter.HANDLER, self).__init__(datatype) + self.handlers = handlers + self.keyparam = keyparam + self.valueparam = valueparam + + def coerce(self, raw): + if self.handlers is None: + if self.keyparam is None: + keytype = self.datatype.keytype + self.keyparam = param_from_datatype(keytype) + if self.valueparam is None: + valuetype = self.datatype.valuetype + self.valueparam = param_from_datatype(valuetype) + handlers = {} + for key, value in raw.items(): + keyhandler = self.keyparam.match_type(key) + if keyhandler is None: + raise ArgTypeMismatchError(key) + valuehandler = self.valueparam.match_type(value) + if valuehandler is None: + raise ArgTypeMismatchError(value) + handlers[key] = (keyhandler, valuehandler) + self.handlers = handlers + + result = {} + for key, value in raw.items(): + keyhandler, valuehandler = self.handlers[key] + key = keyhandler.coerce(key) + value = valuehandler.coerce(value) + result[key] = value + return result + + def validate(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + for key, value in coerced.items(): + keyhandler, valuehandler = self.handlers[key] + keyhandler.validate(key) + valuehandler.validate(value) + + def as_data(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + data = {} + for key, value in coerced.items(): + keyhandler, valuehandler = self.handlers[key] + key = keyhandler.as_data(key) + value = valuehandler.as_data(value) + data[key] = value + return data + + @classmethod + def from_valuetype(cls, valuetype, keytype=str, **kwargs): + datatype = Mapping(valuetype, keytype) + return cls(datatype, **kwargs) + + def __init__(self, datatype): + if not isinstance(datatype, Mapping): + raise ValueError('expected Mapping, got {!r}'.format(datatype)) + keyparam = param_from_datatype(datatype.keytype) + valueparam = param_from_datatype(datatype.valuetype) + handler = self.HANDLER(datatype, None, keyparam, valueparam) + super(MappingParameter, self).__init__(datatype, handler) + + self.keyparam = keyparam + self.valueparam = valueparam + + def match_type(self, raw): + if not isinstance(raw, dict): + return None + handlers = {} + for key, value in raw.items(): + keyhandler = self.keyparam.match_type(key) + if keyhandler is None: + return None + valuehandler = self.valueparam.match_type(value) + if valuehandler is None: + return None + handlers[key] = (keyhandler, valuehandler) + return self.HANDLER(self.datatype, handlers) + + class ComplexParameter(Parameter): class HANDLER(DatatypeHandler): diff --git a/tests/debugger_protocol/arg/test__params.py b/tests/debugger_protocol/arg/test__params.py index d0c31aec9..64aa59f37 100644 --- a/tests/debugger_protocol/arg/test__params.py +++ b/tests/debugger_protocol/arg/test__params.py @@ -1,13 +1,14 @@ import unittest from debugger_protocol.arg._common import NOT_SET, ANY -from debugger_protocol.arg._decl import Enum, Union, Array, Field, Fields +from debugger_protocol.arg._decl import ( + Enum, Union, Array, Mapping, Field, Fields) from debugger_protocol.arg._param import Parameter, DatatypeHandler from debugger_protocol.arg._params import ( param_from_datatype, NoopParameter, SingletonParameter, SimpleParameter, EnumParameter, - UnionParameter, ArrayParameter, ComplexParameter) + UnionParameter, ArrayParameter, MappingParameter, ComplexParameter) from ._common import FIELDS_BASIC, BASIC_FULL, Basic @@ -667,6 +668,94 @@ def test_as_data_complicated(self): self.assertEqual(data, value) +class MappingParameterTests(unittest.TestCase): + + def test_match_type_match(self): + param = MappingParameter(Mapping(int)) + expected = MappingParameter.HANDLER(Mapping(int)) + values = [ + {'a': 1, 'b': 2, 'c': 3}, + {}, + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertEqual(handler, expected) + + def test_match_type_no_match(self): + param = MappingParameter(Mapping(int)) + values = [ + {'a': 1, 'b': '2', 'c': 3}, + [('a', 1), ('b', 2), ('c', 3)], + 'spam', + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce_simple(self): + param = MappingParameter(Mapping(int)) + values = [ + {'a': 1, 'b': 2, 'c': 3}, + {}, + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertEqual(coerced, value) + + def test_coerce_complicated(self): + param = MappingParameter(Mapping(Union(int, Basic))) + value = { + 'a': 1, + 'b': BASIC_FULL, + 'c': 3, + } + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertEqual(coerced, { + 'a': 1, + 'b': Basic(name='spam', value='eggs'), + 'c': 3, + }) + + def test_validate(self): + raw = {'a': 1, 'b': 2, 'c': 3} + param = MappingParameter(Mapping(int)) + handler = param.match_type(raw) + handler.validate(raw) + + def test_as_data_simple(self): + raw = {'a': 1, 'b': 2, 'c': 3} + param = MappingParameter(Mapping(int)) + handler = param.match_type(raw) + data = handler.as_data(raw) + + self.assertEqual(data, raw) + + def test_as_data_complicated(self): + param = MappingParameter(Mapping(Union(int, Basic))) + value = { + 'a': 1, + 'b': BASIC_FULL, + 'c': 3, + } + handler = param.match_type(value) + data = handler.as_data({ + 'a': 1, + 'b': Basic(name='spam', value='eggs'), + 'c': 3, + }) + + self.assertEqual(data, value) + + class ComplexParameterTests(unittest.TestCase): def test_match_type_none_missing(self): From 1f732ea686a78bb61d31d58dbdd8cd2b3dd6ccbd Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 6 Feb 2018 09:19:45 +0000 Subject: [PATCH 10/10] Fix lint. --- debugger_protocol/arg/_decl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 049b2b77e..1466c6d86 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -257,10 +257,11 @@ def __init__(self, valuetype, keytype=str, _normalize=True): def __repr__(self): if self.keytype is str: - return '{}(valuetype={!r})'.format(type(self).__name__, self.valuetype) + return '{}(valuetype={!r})'.format( + type(self).__name__, self.valuetype) else: return '{}(keytype={!r}, valuetype={!r})'.format( - type(self).__name__, self.keytype, self.valuetype) + type(self).__name__, self.keytype, self.valuetype) def __hash__(self): return hash((self.keytype, self.valuetype))