diff --git a/debugger_protocol/arg/__init__.py b/debugger_protocol/arg/__init__.py index f294ec8a2..ec58e37aa 100644 --- a/debugger_protocol/arg/__init__.py +++ b/debugger_protocol/arg/__init__.py @@ -1,6 +1,6 @@ from ._common import NOT_SET, ANY # noqa from ._datatype import FieldsNamespace # noqa -from ._decl import Union, Array, Field # noqa +from ._decl import Enum, Union, Array, Mapping, Field # noqa from ._errors import ( # noqa ArgumentError, ArgMissingError, IncompleteArgError, ArgTypeMismatchError, diff --git a/debugger_protocol/arg/_datatype.py b/debugger_protocol/arg/_datatype.py index 575eb8d96..b4d3000e4 100644 --- a/debugger_protocol/arg/_datatype.py +++ b/debugger_protocol/arg/_datatype.py @@ -22,7 +22,7 @@ def _coerce(datatype, value, call=True): # decl types elif isinstance(datatype, Enum): value = _coerce(datatype.datatype, value, call=False) - if value in datatype.choices: + if value in datatype.choice: return value elif isinstance(datatype, Union): for dt in datatype: @@ -79,49 +79,70 @@ class FieldsNamespace(Readonly, WithRepr): PARAM_TYPE = None PARAM = None + _TRAVERSING = False + @classmethod def traverse(cls, op, **kwargs): """Apply op to each field in cls.FIELDS.""" + if cls._TRAVERSING: # recursion check + return cls + cls._TRAVERSING = True + fields = cls._normalize(cls.FIELDS) - fields = fields.traverse(op) - cls.FIELDS = cls._normalize(fields) + try: + fields_traverse = fields.traverse + except AttributeError: + # must be normalizing right now... + return cls + fields = fields_traverse(op) + cls.FIELDS = cls._normalize(fields, force=True) + + cls._TRAVERSING = False return cls @classmethod def normalize(cls, *transforms): """Normalize FIELDS and apply the given ops.""" fields = cls._normalize(cls.FIELDS) - if not isinstance(fields, Fields): - fields = Fields(*fields) for transform in transforms: fields = _transform_datatype(fields, transform) fields = cls._normalize(fields) cls.FIELDS = fields + return cls @classmethod - def _normalize(cls, fields): + def _normalize(cls, fields, force=False): if fields is None: raise TypeError('missing FIELDS') - if isinstance(fields, Fields): - try: - normalized = cls._normalized - except AttributeError: - normalized = cls._normalized = False - else: - fields = Fields(*fields) - normalized = cls._normalized = False - if not normalized: + + try: + fixref = cls._fixref + except AttributeError: + fixref = cls._fixref = True + if not isinstance(fields, Fields): + fields = Fields(*fields) + if fixref or force: + cls._fixref = False fields = _transform_datatype(fields, lambda dt: _replace_ref(dt, cls)) return fields @classmethod - def bind(cls, ns, **kwargs): + def param(cls): param = cls.PARAM if param is None: if cls.PARAM_TYPE is None: - return cls(**ns) + return None param = cls.PARAM_TYPE(cls.FIELDS, cls) + return param + + @classmethod + def bind(cls, ns, **kwargs): + if isinstance(ns, cls): + return ns + param = cls.param() + if param is None: + return cls(**ns) return param.bind(ns, **kwargs) @classmethod diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 50a1d33f7..1466c6d86 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -22,7 +22,19 @@ def _is_simple(datatype): def _normalize_datatype(datatype): cls = type(datatype) - if datatype == REF or datatype is TYPE_REFERENCE: + # normalized when instantiated: + if isinstance(datatype, Union): + return datatype + elif isinstance(datatype, Array): + return datatype + elif isinstance(datatype, Array): + return datatype + elif isinstance(datatype, Mapping): + return datatype + elif isinstance(datatype, Fields): + return datatype + # do not need normalization: + elif datatype is TYPE_REFERENCE: return TYPE_REFERENCE elif datatype is ANY: return ANY @@ -30,29 +42,38 @@ def _normalize_datatype(datatype): return datatype elif isinstance(datatype, Enum): return datatype - elif isinstance(datatype, Union): - return datatype - elif isinstance(datatype, Array): - return datatype + # convert to canonical types: + elif type(datatype) == type(REF) and datatype == REF: + return TYPE_REFERENCE elif cls is set or cls is frozenset: - return Union(*datatype) + return Union.unordered(*datatype) elif cls is list or cls is tuple: datatype, = datatype return Array(datatype) elif cls is dict: - raise NotImplementedError + if len(datatype) != 1: + raise NotImplementedError + [keytype, valuetype], = datatype.items() + return Mapping(valuetype, keytype) + # fallback: else: - return datatype + try: + normalize = datatype.normalize + except AttributeError: + return datatype + else: + return normalize() def _transform_datatype(datatype, op): + datatype = op(datatype) try: dt_traverse = datatype.traverse except AttributeError: pass else: datatype = dt_traverse(lambda dt: _transform_datatype(dt, op)) - return op(datatype) + return datatype def _replace_ref(datatype, target): @@ -113,6 +134,13 @@ class Union(tuple): Sets and frozensets are treated Unions in declarations. """ + @classmethod + def unordered(cls, *datatypes, **kwargs): + """Return an unordered union of the given datatypes.""" + self = cls(*datatypes, **kwargs) + self._ordered = False + return self + @classmethod def _traverse(cls, datatypes, op): changed = False @@ -136,6 +164,7 @@ def __new__(cls, *datatypes, **kwargs): ) self = super(Union, cls).__new__(cls, datatypes) self._simple = all(_is_simple(dt) for dt in datatypes) + self._ordered = True return self def __repr__(self): @@ -147,13 +176,16 @@ def __hash__(self): def __eq__(self, other): # honors order if not isinstance(other, Union): return NotImplemented - if super(Union, self).__eq__(other): + elif super(Union, self).__eq__(other): return True - if set(self) != set(other): + elif set(self) != set(other): return False - if self._simple and other._simple: + elif self._simple and other._simple: return True - return NotImplemented + elif not self._ordered and not other._ordered: + return True + else: + return NotImplemented def __ne__(self, other): return not (self == other) @@ -167,7 +199,10 @@ def traverse(self, op, **kwargs): datatypes, changed = self._traverse(self, op) if not changed and not kwargs: return self - return self.__class__(*datatypes, **kwargs) + updated = self.__class__(*datatypes, **kwargs) + if not self._ordered: + updated._ordered = False + return updated class Array(Readonly): @@ -179,13 +214,13 @@ class Array(Readonly): def __init__(self, itemtype, _normalize=True): if _normalize: - itemtype = _normalize_datatype(itemtype) + itemtype = _transform_datatype(itemtype, _normalize_datatype) self._bind_attrs( itemtype=itemtype, ) def __repr__(self): - return '{}(datatype={!r})'.format(type(self).__name__, self.itemtype) + return '{}(itemtype={!r})'.format(type(self).__name__, self.itemtype) def __hash__(self): return hash(self.itemtype) @@ -208,6 +243,56 @@ def traverse(self, op, **kwargs): return self.__class__(datatype, **kwargs) +class Mapping(Readonly): + """Declare a mapping (to a single type).""" + + def __init__(self, valuetype, keytype=str, _normalize=True): + if _normalize: + keytype = _transform_datatype(keytype, _normalize_datatype) + valuetype = _transform_datatype(valuetype, _normalize_datatype) + self._bind_attrs( + keytype=keytype, + valuetype=valuetype, + ) + + def __repr__(self): + if self.keytype is str: + return '{}(valuetype={!r})'.format( + type(self).__name__, self.valuetype) + else: + return '{}(keytype={!r}, valuetype={!r})'.format( + type(self).__name__, self.keytype, self.valuetype) + + def __hash__(self): + return hash((self.keytype, self.valuetype)) + + def __eq__(self, other): + try: + other_keytype = other.keytype + other_valuetype = other.valuetype + except AttributeError: + return False + if self.keytype != other_keytype: + return False + if self.valuetype != other_valuetype: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def traverse(self, op, **kwargs): + """Return a copy with op applied to the item datatype.""" + keytype = op(self.keytype) + valuetype = op(self.valuetype) + if (keytype is self.keytype and + valuetype is self.valuetype and + not kwargs + ): + return self + return self.__class__(valuetype, keytype, **kwargs) + + class Field(namedtuple('Field', 'name datatype default optional')): """Declare a field in a data map param.""" @@ -220,7 +305,7 @@ def __new__(cls, name, datatype=str, enum=None, default=NOT_SET, enum = None if _normalize: - datatype = _normalize_datatype(datatype) + datatype = _transform_datatype(datatype, _normalize_datatype) self = super(Field, cls).__new__( cls, name=str(name) if name else None, @@ -315,4 +400,5 @@ def traverse(self, op, **kwargs): if not changed and not kwargs: return self + kwargs['_normalize'] = False return self.__class__(*updated, **kwargs) diff --git a/debugger_protocol/arg/_params.py b/debugger_protocol/arg/_params.py index f988539a1..5562d6235 100644 --- a/debugger_protocol/arg/_params.py +++ b/debugger_protocol/arg/_params.py @@ -1,6 +1,6 @@ -from ._common import ANY, SIMPLE_TYPES +from ._common import NOT_SET, ANY, SIMPLE_TYPES from ._datatype import FieldsNamespace -from ._decl import Enum, Union, Array, Field, Fields +from ._decl import Enum, Union, Array, Mapping, Field, Fields from ._errors import ArgTypeMismatchError from ._param import Parameter, DatatypeHandler @@ -49,7 +49,8 @@ def param_from_datatype(datatype, **kwargs): elif not isinstance(datatype, type): raise NotImplementedError elif issubclass(datatype, FieldsNamespace): - return ComplexParameter(datatype, **kwargs) + param = datatype.param() + return param or ComplexParameter(datatype, **kwargs) else: raise NotImplementedError @@ -276,6 +277,98 @@ def match_type(self, raw): return self.HANDLER(self.datatype, handlers) +class MappingParameter(Parameter): + """A parameter that is a mapping of some fixed type.""" + + class HANDLER(DatatypeHandler): + + def __init__(self, datatype, handlers=None, + keyparam=None, valueparam=None): + if not isinstance(datatype, Mapping): + raise ValueError( + 'expected an Mapping, got {!r}'.format(datatype)) + super(MappingParameter.HANDLER, self).__init__(datatype) + self.handlers = handlers + self.keyparam = keyparam + self.valueparam = valueparam + + def coerce(self, raw): + if self.handlers is None: + if self.keyparam is None: + keytype = self.datatype.keytype + self.keyparam = param_from_datatype(keytype) + if self.valueparam is None: + valuetype = self.datatype.valuetype + self.valueparam = param_from_datatype(valuetype) + handlers = {} + for key, value in raw.items(): + keyhandler = self.keyparam.match_type(key) + if keyhandler is None: + raise ArgTypeMismatchError(key) + valuehandler = self.valueparam.match_type(value) + if valuehandler is None: + raise ArgTypeMismatchError(value) + handlers[key] = (keyhandler, valuehandler) + self.handlers = handlers + + result = {} + for key, value in raw.items(): + keyhandler, valuehandler = self.handlers[key] + key = keyhandler.coerce(key) + value = valuehandler.coerce(value) + result[key] = value + return result + + def validate(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + for key, value in coerced.items(): + keyhandler, valuehandler = self.handlers[key] + keyhandler.validate(key) + valuehandler.validate(value) + + def as_data(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + data = {} + for key, value in coerced.items(): + keyhandler, valuehandler = self.handlers[key] + key = keyhandler.as_data(key) + value = valuehandler.as_data(value) + data[key] = value + return data + + @classmethod + def from_valuetype(cls, valuetype, keytype=str, **kwargs): + datatype = Mapping(valuetype, keytype) + return cls(datatype, **kwargs) + + def __init__(self, datatype): + if not isinstance(datatype, Mapping): + raise ValueError('expected Mapping, got {!r}'.format(datatype)) + keyparam = param_from_datatype(datatype.keytype) + valueparam = param_from_datatype(datatype.valuetype) + handler = self.HANDLER(datatype, None, keyparam, valueparam) + super(MappingParameter, self).__init__(datatype, handler) + + self.keyparam = keyparam + self.valueparam = valueparam + + def match_type(self, raw): + if not isinstance(raw, dict): + return None + handlers = {} + for key, value in raw.items(): + keyhandler = self.keyparam.match_type(key) + if keyhandler is None: + return None + valuehandler = self.valueparam.match_type(value) + if valuehandler is None: + return None + handlers[key] = (keyhandler, valuehandler) + return self.HANDLER(self.datatype, handlers) + + class ComplexParameter(Parameter): class HANDLER(DatatypeHandler): @@ -344,6 +437,7 @@ class ArgNamespace(FieldsNamespace): msg = 'expected Fields or FieldsNamespace, got {!r}' raise ValueError(msg.format(datatype)) datatype.normalize() + datatype.PARAM = self # We set handler later in match_type(). super(ComplexParameter, self).__init__(datatype) @@ -372,6 +466,8 @@ def match_type(self, raw): if not field.optional: return None value = field.default + if value is NOT_SET: + continue param = self.params[field.name] handler = param.match_type(value) if handler is None: diff --git a/tests/debugger_protocol/arg/test__datatype.py b/tests/debugger_protocol/arg/test__datatype.py index 3ddfa6afb..326299d0a 100644 --- a/tests/debugger_protocol/arg/test__datatype.py +++ b/tests/debugger_protocol/arg/test__datatype.py @@ -3,7 +3,8 @@ from debugger_protocol.arg._common import ANY from debugger_protocol.arg._datatype import FieldsNamespace -from debugger_protocol.arg._decl import Array, Field, Fields +from debugger_protocol.arg._decl import Union, Array, Field, Fields +from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg from ._common import ( BASIC_FULL, BASIC_MIN, Basic, @@ -134,23 +135,25 @@ class Spam(FieldsNamespace): self.assertIs(Spam.FIELDS, fields) for i, field in enumerate(Spam.FIELDS): self.assertIs(field, fieldlist[i]) + self.maxDiff = None self.assertEqual(calls, [ - (op1, str), + (op1, fields), (op1, Field('spam')), - (op1, int), + (op1, str), (op1, Field('ham', int)), - (op1, ANY), - (op1, Array(ANY)), + (op1, int), (op1, Field('eggs', Array(ANY))), - (op1, fields), - (op2, str), + (op1, Array(ANY)), + (op1, ANY), + + (op2, fields), (op2, Field('spam')), - (op2, int), + (op2, str), (op2, Field('ham', int)), - (op2, ANY), - (op2, Array(ANY)), + (op2, int), (op2, Field('eggs', Array(ANY))), - (op2, fields), + (op2, Array(ANY)), + (op2, ANY), ]) def test_normalize_with_op_changed(self): @@ -166,12 +169,119 @@ class Spam(FieldsNamespace): Field('spam', Array(int)), )) + def test_normalize_declarative(self): + class Spam(FieldsNamespace): + FIELDS = [ + Field('a'), + Field('b', bool), + Field.START_OPTIONAL, + Field('c', {int, str}), + Field('d', [int]), + Field('e', ANY), + Field('f', ''), + ] + + class Ham(FieldsNamespace): + FIELDS = [ + Field('w', Spam), + Field('x', int), + Field('y', frozenset({int, str})), + Field('z', (int,)), + ] + + class Eggs(FieldsNamespace): + FIELDS = [ + Field('b', [Ham]), + Field('x', [{str, ('',)}], optional=True), + Field('d', {Spam, ''}, optional=True), + ] + + Eggs.normalize() + + self.assertEqual(Spam.FIELDS, Fields( + Field('a'), + Field('b', bool), + Field('c', Union(int, str), optional=True), + Field('d', Array(int), optional=True), + Field('e', ANY, optional=True), + Field('f', Spam, optional=True), + )) + self.assertEqual(Ham.FIELDS, Fields( + Field('w', Spam), + Field('x', int), + Field('y', Union(int, str)), + Field('z', Array(int)), + )) + self.assertEqual(Eggs.FIELDS, Fields( + Field('b', Array(Ham)), + Field('x', + Array(Union.unordered(str, Array(Eggs))), + optional=True), + Field('d', Union.unordered(Spam, Eggs), optional=True), + )) + def test_normalize_missing(self): with self.assertRaises(TypeError): FieldsNamespace.normalize() ####### + def test_bind_no_param(self): + class Spam(FieldsNamespace): + FIELDS = [ + Field('a'), + ] + + arg = Spam.bind({'a': 'x'}) + + self.assertIsInstance(arg, Spam) + self.assertEqual(arg, Spam(a='x')) + + def test_bind_with_param_obj(self): + class Param(Parameter): + HANDLER = DatatypeHandler(ANY) + match_type = (lambda self, raw: self.HANDLER) + + class Spam(FieldsNamespace): + PARAM = Param(ANY) + FIELDS = [ + Field('a'), + ] + + arg = Spam.bind({'a': 'x'}) + + self.assertIsInstance(arg, Arg) + self.assertEqual(arg, Arg(Param(ANY), {'a': 'x'})) + + def test_bind_with_param_type(self): + class Param(Parameter): + HANDLER = DatatypeHandler(ANY) + match_type = (lambda self, raw: self.HANDLER) + + class Spam(FieldsNamespace): + PARAM_TYPE = Param + FIELDS = [ + Field('a'), + ] + + arg = Spam.bind({'a': 'x'}) + + self.assertIsInstance(arg, Arg) + self.assertEqual(arg, Arg(Param(Spam.FIELDS), {'a': 'x'})) + + def test_bind_already_bound(self): + class Spam(FieldsNamespace): + FIELDS = [ + Field('a'), + ] + + spam = Spam(a='x') + arg = Spam.bind(spam) + + self.assertIs(arg, spam) + + ####### + def test_fields_full(self): class Spam(FieldsNamespace): FIELDS = FIELDS_EXTENDED diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index 079dc77d4..2aa953ee4 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -4,7 +4,7 @@ from debugger_protocol.arg._datatype import FieldsNamespace from debugger_protocol.arg._decl import ( REF, TYPE_REFERENCE, _normalize_datatype, _transform_datatype, - Enum, Union, Array, Field, Fields) + Enum, Union, Array, Mapping, Field, Fields) from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg from debugger_protocol.arg._params import ( SimpleParameter, UnionParameter, ArrayParameter, ComplexParameter) @@ -13,6 +13,12 @@ class ModuleTests(unittest.TestCase): def test_normalize_datatype(self): + class Spam: + @classmethod + def normalize(cls): + return OKAY + + OKAY = object() NOOP = object() param = SimpleParameter(str) tests = [ @@ -31,6 +37,8 @@ def test_normalize_datatype(self): (Array(str), NOOP), ([str], Array(str)), ((str,), Array(str)), + (Mapping(str), NOOP), + ({str: str}, Mapping(str)), # others (Field('spam'), NOOP), (Fields(Field('spam')), NOOP), @@ -45,6 +53,7 @@ def test_normalize_datatype(self): (object(), NOOP), (object, NOOP), (type, NOOP), + (Spam, OKAY), ] for datatype, expected in tests: if expected is NOOP: @@ -54,9 +63,6 @@ def test_normalize_datatype(self): self.assertEqual(datatype, expected) - with self.assertRaises(NotImplementedError): - _normalize_datatype({1: 2}) - def test_transform_datatype_simple(self): datatypes = [ REF, @@ -95,10 +101,9 @@ def test_transform_datatype_container(self): class Spam(FieldsNamespace): FIELDS = [ Field('a'), + Field('b', {str: str}) ] - Spam.normalize() - fields = Fields(Field('...')) field_spam = Field('spam', ANY) field_ham = Field('ham', Union( @@ -113,42 +118,46 @@ class Spam(FieldsNamespace): ) tests = { Array(str): [ - str, Array(str), + str, ], Field('...'): [ - str, Field('...'), + str, ], fields: [ - str, - Field('...'), fields, + Field('...'), + str, ], nested: [ - str, - Field('...'), - fields, + nested, + # ... Field('???', fields), + fields, + Field('...'), + str, # ... - ANY, Field('spam', ANY), + ANY, # ... - str, - Field('a'), - Fields(Field('a')), - Spam, - Array(Spam), - Union(Array(Spam)), field_ham, + Union(Array(Spam)), + Array(Spam), + Spam, + Field('a'), + str, + Field('b', Mapping(str)), + Mapping(str), + str, + str, # ... - TYPE_REFERENCE, - Array(TYPE_REFERENCE), field_eggs, - # ... - nested, + Array(TYPE_REFERENCE), + TYPE_REFERENCE, ], } + self.maxDiff = None for datatype, expected in tests.items(): calls = [] op = (lambda dt: calls.append(dt) or dt) @@ -165,10 +174,8 @@ class Spam(FieldsNamespace): transformed = _transform_datatype(datatype, op) self.assertIs(transformed, datatype) - self.assertEqual(set(calls[:2]), {str, int}) - self.assertEqual(calls[2:], [ - Union(str, int), - ]) + self.assertEqual(calls[0], Union(str, int)) + self.assertEqual(set(calls[1:]), {str, int}) class EnumTests(unittest.TestCase): @@ -209,6 +216,7 @@ def test_normalized(self): (frozenset([str, int]), Union(*frozenset([str, int]))), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -217,9 +225,6 @@ def test_normalized(self): self.assertEqual(union, Union(int, expected, str)) - with self.assertRaises(NotImplementedError): - Union({1: 2}) - def test_traverse_noop(self): calls = [] op = (lambda dt: calls.append(dt) or dt) @@ -256,6 +261,7 @@ def test_normalized(self): (frozenset([str, int]), Union(str, int)), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -264,8 +270,20 @@ def test_normalized(self): self.assertEqual(array, Array(expected)) - with self.assertRaises(NotImplementedError): - Array({1: 2}) + def test_normalized_transformed(self): + calls = 0 + + class Spam: + @classmethod + def traverse(cls, op): + nonlocal calls + calls += 1 + return cls + + array = Array(Spam) + + self.assertIs(array.itemtype, Spam) + self.assertEqual(calls, 1) def test_traverse_noop(self): calls = [] @@ -292,6 +310,67 @@ def test_traverse_changed(self): ]) +class MappingTests(unittest.TestCase): + + def test_normalized(self): + tests = [ + (REF, TYPE_REFERENCE), + ({str, int}, Union(str, int)), + (frozenset([str, int]), Union(str, int)), + ([str], Array(str)), + ((str,), Array(str)), + ({str: str}, Mapping(str)), + (None, None), + ] + for datatype, expected in tests: + with self.subTest(datatype): + mapping = Mapping(datatype) + + self.assertEqual(mapping, Mapping(expected)) + + def test_normalized_transformed(self): + calls = 0 + + class Spam: + @classmethod + def traverse(cls, op): + nonlocal calls + calls += 1 + return cls + + mapping = Mapping(Spam) + + self.assertIs(mapping.keytype, str) + self.assertIs(mapping.valuetype, Spam) + self.assertEqual(calls, 1) + + def test_traverse_noop(self): + calls = [] + op = (lambda dt: calls.append(dt) or dt) + mapping = Mapping(Union(str, int)) + transformed = mapping.traverse(op) + + self.assertIs(transformed, mapping) + self.assertCountEqual(calls, [ + str, + # Note that it did not recurse into Union(str, int). + Union(str, int), + ]) + + def test_traverse_changed(self): + calls = [] + op = (lambda dt: calls.append(dt) or str) + mapping = Mapping(ANY) + transformed = mapping.traverse(op) + + self.assertIsNot(transformed, mapping) + self.assertEqual(transformed, Mapping(str)) + self.assertEqual(calls, [ + str, + ANY, + ]) + + class FieldTests(unittest.TestCase): def test_defaults(self): @@ -314,6 +393,7 @@ def test_normalized(self): (frozenset([str, int]), Union(str, int)), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -322,8 +402,20 @@ def test_normalized(self): self.assertEqual(field, Field('spam', expected)) - with self.assertRaises(NotImplementedError): - Field('spam', {1: 2}) + def test_normalized_transformed(self): + calls = 0 + + class Spam: + @classmethod + def traverse(cls, op): + nonlocal calls + calls += 1 + return cls + + field = Field('spam', Spam) + + self.assertIs(field.datatype, Spam) + self.assertEqual(calls, 1) def test_traverse_noop(self): calls = [] @@ -386,6 +478,7 @@ def test_normalized(self): (frozenset([str, int]), Union(str, int)), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -398,11 +491,6 @@ def test_normalized(self): Field('spam', expected), ]) - with self.assertRaises(NotImplementedError): - Fields( - Field('spam', {1: 2}), - ) - def test_with_START_OPTIONAL(self): fields = Fields( Field('spam'), @@ -431,7 +519,7 @@ def test_as_dict(self): Field('ham'), Field('eggs', Array(str)), ) - result = fields.as_dict + result = fields.as_dict() self.assertEqual(result, { 'spam': fields[0], diff --git a/tests/debugger_protocol/arg/test__params.py b/tests/debugger_protocol/arg/test__params.py index 2a05d69f2..64aa59f37 100644 --- a/tests/debugger_protocol/arg/test__params.py +++ b/tests/debugger_protocol/arg/test__params.py @@ -1,13 +1,14 @@ import unittest from debugger_protocol.arg._common import NOT_SET, ANY -from debugger_protocol.arg._decl import Enum, Union, Array, Field, Fields +from debugger_protocol.arg._decl import ( + Enum, Union, Array, Mapping, Field, Fields) from debugger_protocol.arg._param import Parameter, DatatypeHandler from debugger_protocol.arg._params import ( param_from_datatype, NoopParameter, SingletonParameter, SimpleParameter, EnumParameter, - UnionParameter, ArrayParameter, ComplexParameter) + UnionParameter, ArrayParameter, MappingParameter, ComplexParameter) from ._common import FIELDS_BASIC, BASIC_FULL, Basic @@ -667,9 +668,97 @@ def test_as_data_complicated(self): self.assertEqual(data, value) +class MappingParameterTests(unittest.TestCase): + + def test_match_type_match(self): + param = MappingParameter(Mapping(int)) + expected = MappingParameter.HANDLER(Mapping(int)) + values = [ + {'a': 1, 'b': 2, 'c': 3}, + {}, + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertEqual(handler, expected) + + def test_match_type_no_match(self): + param = MappingParameter(Mapping(int)) + values = [ + {'a': 1, 'b': '2', 'c': 3}, + [('a', 1), ('b', 2), ('c', 3)], + 'spam', + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce_simple(self): + param = MappingParameter(Mapping(int)) + values = [ + {'a': 1, 'b': 2, 'c': 3}, + {}, + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertEqual(coerced, value) + + def test_coerce_complicated(self): + param = MappingParameter(Mapping(Union(int, Basic))) + value = { + 'a': 1, + 'b': BASIC_FULL, + 'c': 3, + } + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertEqual(coerced, { + 'a': 1, + 'b': Basic(name='spam', value='eggs'), + 'c': 3, + }) + + def test_validate(self): + raw = {'a': 1, 'b': 2, 'c': 3} + param = MappingParameter(Mapping(int)) + handler = param.match_type(raw) + handler.validate(raw) + + def test_as_data_simple(self): + raw = {'a': 1, 'b': 2, 'c': 3} + param = MappingParameter(Mapping(int)) + handler = param.match_type(raw) + data = handler.as_data(raw) + + self.assertEqual(data, raw) + + def test_as_data_complicated(self): + param = MappingParameter(Mapping(Union(int, Basic))) + value = { + 'a': 1, + 'b': BASIC_FULL, + 'c': 3, + } + handler = param.match_type(value) + data = handler.as_data({ + 'a': 1, + 'b': Basic(name='spam', value='eggs'), + 'c': 3, + }) + + self.assertEqual(data, value) + + class ComplexParameterTests(unittest.TestCase): - def test_match_type(self): + def test_match_type_none_missing(self): fields = Fields(*FIELDS_BASIC) param = ComplexParameter(fields) handler = param.match_type(BASIC_FULL) @@ -677,6 +766,19 @@ def test_match_type(self): self.assertIs(type(handler), ComplexParameter.HANDLER) self.assertEqual(handler.datatype.FIELDS, fields) + def test_match_type_missing_optional(self): + fields = Fields( + Field('name'), + Field.START_OPTIONAL, + Field('value'), + ) + param = ComplexParameter(fields) + handler = param.match_type({'name': 'spam'}) + + self.assertIs(type(handler), ComplexParameter.HANDLER) + self.assertEqual(handler.datatype.FIELDS, fields) + self.assertNotIn('value', handler.handlers) + def test_coerce(self): handler = ComplexParameter.HANDLER(Basic) coerced = handler.coerce(BASIC_FULL)