diff --git a/proto/fields.py b/proto/fields.py index cc98e8b0..b0d9a742 100644 --- a/proto/fields.py +++ b/proto/fields.py @@ -12,11 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime from enum import EnumMeta +from re import L +from proto.datetime_helpers import DatetimeWithNanoseconds +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule +from proto.marshal.collections.maps import MapComposite +from proto.utils import cached_property +from typing import Any, Callable, Optional, Union from google.protobuf import descriptor_pb2 +from google.protobuf import duration_pb2 +from google.protobuf import struct_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf import wrappers_pb2 from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper +import proto from proto.primitives import ProtoType @@ -140,6 +153,88 @@ def pb_type(self): return self.message.pb() return self.message + @property + def can_represent_natively(self) -> bool: + return not ( + self.proto_type == ProtoType.MESSAGE and + self.message == struct_pb2.Value + ) + + def contribute_to_class(self, cls, name: str): + """Attaches a descriptor to the top-level proto.Message class, so that attribute + reads and writes can be specially handled in `_FieldDescriptor.__get__` and + `FieldDescriptor.__set__`. + + Also contains hooks for write-time type-coersion to translate special cases between + pure Pythonic objects and pb2-compatible structs or values. + """ + set_coercion = None + if self.proto_type == ProtoType.STRING: + # Bytes are accepted for string values, but strings are not accepted for byte values. + # This is an artifact of older Python2 implementations. + set_coercion = self._bytes_to_str + elif self.pb_type == timestamp_pb2.Timestamp: + set_coercion = self._timestamp_to_datetime + elif self.proto_type == ProtoType.MESSAGE and self.message == duration_pb2.Duration: + set_coercion = self._duration_to_timedelta + elif self.proto_type == ProtoType.MESSAGE and self.message == wrappers_pb2.BoolValue: + set_coercion = self._bool_value_to_bool + elif self.enum: + set_coercion = self._literal_to_enum + setattr(cls, name, _FieldDescriptor(name, cls=cls, set_coercion=set_coercion)) + + @cached_property + def reverse_enum_map(self): + """Helper that allows for constant-time lookup on self.enum, used to hydrate + primitives that are supplied but which stand for their official enum types. + + This is used when a developer supplies the literal value for an enum type (often an int). + """ + return {e.value: e for e in self.enum} if self.enum else None + + @cached_property + def reverse_enum_names_map(self): + """Helper that allows for constant-time lookup on self.enum, used to hydrate + primitives that are supplied but which stand for their official enum types. + + This is used when a developer supplies the string value for an enum type's name. + """ + return {e.name: e for e in self.enum} if self.enum else None + + def _literal_to_enum(self, val: Any): + if isinstance(val, self.enum): + return val + return ( + self.reverse_enum_map.get(val, None) or + self.reverse_enum_names_map.get(val, None) + ) + + @staticmethod + def _bytes_to_str(val: Union[bytes, str]) -> str: + if type(val) == bytes: + val = val.decode('utf-8') + return val + + @staticmethod + def _timestamp_to_datetime(val: Union[timestamp_pb2.Timestamp, datetime.datetime]) -> datetime.datetime: + if type(val) == timestamp_pb2.Timestamp: + val = DatetimeWithNanoseconds.from_timestamp_pb(val) + return val + + @staticmethod + def _duration_to_timedelta(val: Union[duration_pb2.Duration, datetime.timedelta]) -> datetime.datetime: + if type(val) == duration_pb2.Duration: + val = DurationRule().to_python(val) + return val + + @staticmethod + def _bool_value_to_bool(val: Union[wrappers_pb2.BoolValue, bool]) -> Optional[bool]: + if val is None: + return None + if type(val) == wrappers_pb2.BoolValue: + val = val.value + return val + class RepeatedField(Field): """A representation of a repeated field in protocol buffers.""" @@ -155,6 +250,192 @@ def __init__(self, key_type, value_type, *, number: int, message=None, enum=None self.map_key_type = key_type +class _FieldDescriptor: + """Handler for proto.Field access on any proto.Message object. + + Wraps each proto.Field instance within a given proto.Message subclass's definition + with getters and setters that allow for caching of values on the proto-plus object, + deferment of syncing to the underlying pb2 object, and tracking of the current state. + + Special treatment is given to MapFields, nested Messages, and certain data types, as + their various implementations within pb2 (which for our purposes is mostly a black box) + sometimes mandate immediate syncing. This is usually because proto-plus objects are not + long-lived, and thus information about which fields are stale would be lost if syncing + was left for serialization time. + """ + + # Namespace for attributes where we will store the Pythonic values of + # various `proto.Field` classes on instantiated `proto.Message` objects. + # For example, in the following scenario, this attribute is involved in + # saving the value "Homer Simpson" to `my_message._cached_fields__name`. + # + # class MyMessage(proto.Message): + # name = proto.Field(proto.STRING, ...) + # + # my_message = MyMessage() + # my_message.name = "Homer Simpson" # saves to `_cached_fields__name` + cached_fields_prefix = '_cached_fields__' + + def __init__(self, name: str, *, cls, set_coercion: Optional[Callable] = None): + # something like "id". required whenever reach back to the pb2 object. + self.original_name = name + # something like "_cached_id" + self.instance_attr_name = f'{self.cached_fields_prefix}{name}' + + # simple types coercion for setting attributes + # (e.g., bytes -> str if our type is string, but we are supplied bytes) + # the signature of `set_coercion` is dependent on the field's data types + # and is always handled by `contribute_to_class` which pairs data types + # to appropriate write-time coercions. + self._set_coercion: Optional[Callable] = set_coercion + self.cls = cls + + @property + def field(self): + return self.cls._meta.fields[self.original_name] + + def _hydrate_dicts(self, value: Any): + """Turns a dictionary assigned to a nested Message into a full instance of + that Message type. + """ + if not isinstance(value, dict): + return value + + if self.field.proto_type == proto.MESSAGE: + _pb = self.field.message._meta.pb(**value) + value = self.field.message.wrap(_pb) + + return value + + def _clear_oneofs(self, instance): + if not self.field.oneof: + return + + for field_name, field in self.cls._meta.fields.items(): + # Don't clear this field + if field_name == self.original_name: + continue + + # Don't clear other fields with different oneof values, or with + # no such values at all + if field.oneof != self.field.oneof: + continue + + delattr(instance, field_name) + + def __set__(self, instance, value): + """Called whenever a value is assigned to a `proto.Field` attribute on an instantiated + `proto.Message` object. + + Usage: + + class MyMessage(proto.Message): + name = proto.Field(proto.STRING, number=1) + + my_message = MyMessage() + my_message.name = "Frodo" + + In the above scenario, `__set__` is called with "Frodo" passed as `value` and + `my_message` passed as `instance`. + """ + value = self._set_coercion(value) if self._set_coercion is not None else value + value = self._hydrate_dicts(value) + + # Warning: `always_commit` is hacky! + # Some contexts, particularly instances created from MapFields, require immediate syncing. + # It is impossible to deduce such a scenario purely from logic available to this function, + # so instead we set a flag on instances when a MapField yields them, and then when those + # instances receive attribute updates, immediately syncing those values to the underlying + # pb2 instance is sufficient. + always_commit: bool = getattr(instance, '_always_commit', False) + if always_commit or not self.field.can_represent_natively: + pb_value = instance._meta.marshal.to_proto(self.field.pb_type, value) + _pb = instance._meta.pb(**{self.original_name: pb_value}) + instance._pb.ClearField(self.original_name) + instance._pb.MergeFrom(_pb) + else: + + if value is None: + self.__delete__(instance) + instance._pb.MergeFrom(instance._meta._pb(**{self.original_name: None})) + return + + instance._meta.marshal.validate_primitives(self.field, value) + instance._mark_pb_stale(self.original_name) + + setattr(instance, self.instance_attr_name, value) + self._clear_oneofs(instance) + + def __get__(self, instance: 'proto.Message', _): # type: ignore + """Called whenever a value is read from a proto.Field attribute on an instantiated + proto.Message object. + + Usage: + + class MyMessage(proto.Message): + name = proto.Field(proto.STRING, number=1) + + my_message = MyMessage(name="Frodo") + print(my_message.name) + + In the above scenario, `__get__` is called with "my_message" passed as + `instance`. + """ + # If `instance` is None, then we are accessing this field directly + # off the class itself instead of off an instance. + if instance is None: + return self.original_name + + value = getattr(instance, self.instance_attr_name, _none) + is_map: bool = isinstance(value, MapComposite) + + # Return any values that do not require immediate rehydration. + # A few notes: + # * primitives are simple, and so can be returned + # * `Messages` are already Pythonic, and so can be returned + # * `Values` are wrappers and so have to be unwrapped + # * The exception to this is MapComposites, which have the same + # types as Values, but which handle their own field caching and + # thus can be returned when pulled off the `instance`. + if value is not _none and (is_map or self.field.can_represent_natively): + return value + + # For the most part, only primitive values can be returned natively, + # meaning this is either a Message itself, in which case, since we're + # dealing with the underlying pb object, we need to sync all deferred + # fields. This is functionally a no-op if no fields have been deferred. + if hasattr(value, '_update_pb'): + value._update_pb() + + pb_value = getattr(instance._pb, self.original_name, None) + value = instance._meta.marshal.to_python( + self.field.pb_type, pb_value, + absent=self.original_name not in instance, + ) + + setattr(instance, self.instance_attr_name, value) + return value + + def __delete__(self, instance): + if hasattr(instance, self.instance_attr_name): + delattr(instance, self.instance_attr_name) + instance._pb.ClearField(self.original_name) + if self.original_name in getattr(instance, '_stale_fields', []): + instance._stale_fields.remove(self.original_name) + + +class _NoneType: + def __bool__(self): + return False + + def __eq__(self, other): + """All _NoneType instances are equal""" + return isinstance(other, _NoneType) + + +_none = _NoneType() + + __all__ = ( "Field", "MapField", diff --git a/proto/marshal/collections/maps.py b/proto/marshal/collections/maps.py index 8ed11349..ecf928bb 100644 --- a/proto/marshal/collections/maps.py +++ b/proto/marshal/collections/maps.py @@ -13,7 +13,9 @@ # limitations under the License. import collections +from typing import Dict, Set +import proto from proto.utils import cached_property @@ -40,6 +42,8 @@ def __init__(self, sequence, *, marshal): """ self._pb = sequence self._marshal = marshal + self._item_cache: Dict = {} + self._stale_keys: Set[str] = set() def __contains__(self, key): # Protocol buffers is so permissive that querying for the existence @@ -47,16 +51,50 @@ def __contains__(self, key): # # By taking a tuple of the keys and querying that, we avoid sending # the lookup to protocol buffers and therefore avoid creating the key. + if key in self._item_cache: + return True return key in tuple(self.keys()) def __getitem__(self, key): # We handle raising KeyError ourselves, because otherwise protocol # buffers will create the key if it does not exist. - if key not in self: - raise KeyError(key) - return self._marshal.to_python(self._pb_type, self.pb[key]) + value = self._item_cache.get(key, _Empty.shared) + + if isinstance(value, _Empty): + if key not in self: + raise KeyError(key) + value = self._marshal.to_python(self._pb_type, self.pb[key]) + + # This is the first domino in a hacky workaround that is completed + # in `fields._FieldDescriptor.__set__`. Because of the by-value nature + # of protobufs (which conflicts with the by-reference nature of Python), + # proto-plus objects that are yielded from MapFields must immediately + # write to their internal pb2 object whenever their fields are updated. + # This is a new requirement as always writing to a proto-plus object's + # inner pb2 protobuf used to be the default, but has been moved to a + # lazy-syncing system for performance reasons. + if isinstance(value, proto.Message): + value._always_commit = True + + self._item_cache[key] = value + + return value def __setitem__(self, key, value): + self._item_cache[key] = value + self._stale_keys.add(key) + # self._sync_key(key, value) + + def _sync_all_keys(self): + for key in self._stale_keys: + self._sync_key(key) + self._stale_keys.clear() + + def _sync_key(self, key, value = None): + value = value or self._item_cache.pop(key, None) + if value is None: + self.pb.pop(key) + return pb_value = self._marshal.to_proto(self._pb_type, value, strict=True) # Directly setting a key is not allowed; however, protocol buffers @@ -69,14 +107,27 @@ def __setitem__(self, key, value): self.pb[key].MergeFrom(pb_value) def __delitem__(self, key): - self.pb.pop(key) + self._item_cache.pop(key, None) + self._stale_keys.add(key) + # self.pb.pop(key) def __len__(self): - return len(self.pb) + _all_keys = set(list(self._item_cache.keys())) + _all_keys = _all_keys.union(list(self.pb.keys())) + return len(_all_keys) + # return len(self.pb) def __iter__(self): + self._sync_all_keys() return iter(self.pb) @property def pb(self): return self._pb + + +class _Empty: + pass + + +_Empty.shared = _Empty() \ No newline at end of file diff --git a/proto/marshal/marshal.py b/proto/marshal/marshal.py index c8224ce2..8af25141 100644 --- a/proto/marshal/marshal.py +++ b/proto/marshal/marshal.py @@ -21,6 +21,7 @@ from google.protobuf import struct_pb2 from google.protobuf import wrappers_pb2 +import proto from proto.marshal import compat from proto.marshal.collections import MapComposite from proto.marshal.collections import Repeated @@ -30,6 +31,10 @@ from proto.marshal.rules import wrappers +max_uint_32 = 1<<32 - 1 +max_int_32 = 1<<31 - 1 +min_int_32 = (max_int_32 * -1) + 1 + class Rule(abc.ABC): """Abstract class definition for marshal rules.""" @@ -122,6 +127,84 @@ def register_rule_class(rule_class: type): return register_rule_class + @staticmethod + def _throw_type_error(primitive): + raise TypeError(f"Unacceptable value of type {type(primitive)}") + + @staticmethod + def _throw_value_error(proto_type, primitive): + raise ValueError(f"Unacceptable value {type(primitive)} for type {proto_type}") + + def validate_primitives(self, field, primitive): + """Replicates validation logic when assigning values to proto.Field attributes + on instantiated proto.Message objects. This is required to recreate the immediacy + of ValueError and TypeError checks that the pb2 layer made, but which are now + deferred because syncing to the pb2 layer itself is often entirely deferred. + """ + proto_type = field.proto_type + if field.repeated: + if primitive and not isinstance(primitive, (list, tuple, Repeated, RepeatedComposite,)): + self._throw_type_error(primitive) + + # Typically, checking a sequence's length after checking its Truthiness is + # unnecessary, but this will make us safe from any unexpected behavior from + # `Repeated` or `RepeatedComposite`. + if primitive and len(primitive) > 0: + primitive = primitive[0] + + _type = type(primitive) + if proto_type == proto.BOOL and _type != bool: + self._throw_type_error(primitive) + elif proto_type == proto.STRING and _type not in (str, bytes,): + self._throw_type_error(primitive) + elif proto_type == proto.DOUBLE and _type != float: + self._throw_type_error(primitive) + elif proto_type == proto.FLOAT and _type != float: + self._throw_type_error(primitive) + elif proto_type == proto.INT64 and _type != int: + self._throw_type_error(primitive) + elif proto_type == proto.UINT64: + if _type != int: + self._throw_type_error(primitive) + if primitive < 0: + self._throw_value_error(proto_type, primitive) + elif proto_type == proto.INT32: + if _type != int: + self._throw_type_error(primitive) + if primitive < min_int_32 or primitive > max_int_32: + self._throw_value_error(proto_type, primitive) + elif proto_type == proto.FIXED64 and _type != int: + self._throw_type_error(primitive) + elif proto_type == proto.FIXED32 and _type != int: + self._throw_type_error(primitive) + elif proto_type == proto.MESSAGE: + # Do nothing - this is not a primitive + pass + elif proto_type == proto.BYTES and _type != bytes: + self._throw_type_error(primitive) + elif proto_type == proto.UINT32: + if _type != int: + self._throw_type_error(primitive) + if primitive < 0 or primitive > max_int_32: + self._throw_value_error(proto_type, primitive) + elif proto_type == proto.ENUM: + # Do nothing - this is not a primitive + pass + elif proto_type == proto.SFIXED32: + if _type != int: + self._throw_type_error(primitive) + if primitive < min_int_32 or primitive > max_int_32: + self._throw_value_error(proto_type, primitive) + elif proto_type == proto.SFIXED64 and _type != int: + self._throw_type_error(primitive) + elif proto_type == proto.SINT32: + if _type != int: + self._throw_type_error(primitive) + if primitive < min_int_32 or primitive > max_int_32: + self._throw_value_error(proto_type, primitive) + elif proto_type == proto.SINT64 and _type != int: + self._throw_type_error(primitive) + def reset(self): """Reset the registry to its initial state.""" self._rules.clear() diff --git a/proto/message.py b/proto/message.py index e046d628..2eb09ada 100644 --- a/proto/message.py +++ b/proto/message.py @@ -15,13 +15,15 @@ import collections import collections.abc import copy +import inspect import re -from typing import List, Type +from typing import List, Set, Type from google.protobuf import descriptor_pb2 from google.protobuf import message from google.protobuf.json_format import MessageToDict, MessageToJson, Parse +import proto from proto import _file_info from proto import _package_info from proto.fields import Field @@ -31,6 +33,11 @@ from proto.primitives import ProtoType +def _has_contribute_to_class(value): + # Only call contribute_to_class() if it's bound. + return not inspect.isclass(value) and hasattr(value, 'contribute_to_class') + + class MessageMeta(type): """A metaclass for building and registering Message subclasses.""" @@ -105,6 +112,7 @@ def __new__(mcls, name, bases, attrs): # Okay, now we deal with all the rest of the fields. # Iterate over all the attributes and separate the fields into # their own sequence. + contributable_attrs = {} fields = [] new_attrs = {} oneofs = collections.OrderedDict() @@ -127,6 +135,9 @@ def __new__(mcls, name, bases, attrs): "package": package, } + if _has_contribute_to_class(field): + contributable_attrs[key] = field + # Add the field to the list of fields. fields.append(field) # If this field is part of a "oneof", ensure the oneof itself @@ -248,6 +259,11 @@ def __new__(mcls, name, bases, attrs): # Run the superclass constructor. cls = super().__new__(mcls, name, bases, new_attrs) + # Now that the class officially exists, but is not yet finalized, allow + # individual attributes to run functions to attach themselves in special ways. + for field_name, field in contributable_attrs.items(): + cls.add_to_class(field_name, field) + # The info class and fields need a reference to the class just created. cls._meta.parent = cls for field in cls._meta.fields.values(): @@ -269,6 +285,14 @@ def __new__(mcls, name, bases, attrs): def __prepare__(mcls, name, bases, **kwargs): return collections.OrderedDict() + def add_to_class(cls, name, value): + """Hook for attributes on the class definition to attach themselves + in special ways.""" + if _has_contribute_to_class(value): + value.contribute_to_class(cls, name) + else: + setattr(cls, name, value) + @property def meta(cls): return cls._meta @@ -289,6 +313,10 @@ def pb(cls, obj=None, *, coerce: bool = False): obj = cls(obj) else: raise TypeError("%r is not an instance of %s" % (obj, cls.__name__,)) + if hasattr(obj, '_update_pb'): + obj._update_pb() + if hasattr(obj, '_update_nested_pb'): + obj._update_nested_pb() return obj._pb def wrap(cls, pb): @@ -313,6 +341,8 @@ def serialize(cls, instance) -> bytes: Returns: bytes: The serialized representation of the protocol buffer. """ + if instance and type(instance) == cls: + instance._update_pb() return cls.pb(instance, coerce=True).SerializeToString() def deserialize(cls, payload: bytes) -> "Message": @@ -446,6 +476,13 @@ class Message(metaclass=MessageMeta): """ def __init__(self, mapping=None, *, ignore_unknown_fields=False, **kwargs): + + self._cached_pb = None + # Tracks any new values have had their serialization deferred, and thus whether we + # are ready to serialize immediately, or need to sync from the instance back to the + # underlying `_pb` + self._stale_fields: Set[str] = set() + # We accept several things for `mapping`: # * An instance of this class. # * An instance of the underlying protobuf descriptor class. @@ -454,7 +491,7 @@ def __init__(self, mapping=None, *, ignore_unknown_fields=False, **kwargs): if mapping is None: if not kwargs: # Special fast path for empty construction. - super().__setattr__("_pb", self._meta.pb()) + # `self._pb` is lazily initialized only when needed return mapping = kwargs @@ -539,6 +576,8 @@ def __contains__(self, key): bool: Whether the field's value corresponds to a non-empty wire serialization. """ + if key in getattr(self, '_stale_fields', set()): + return True pb_value = getattr(self._pb, key) try: # Protocol buffers "HasField" is unfriendly; it only works @@ -552,13 +591,6 @@ def __contains__(self, key): except ValueError: return bool(pb_value) - def __delattr__(self, key): - """Delete the value on the given field. - - This is generally equivalent to setting a falsy value. - """ - self._pb.ClearField(key) - def __eq__(self, other): """Return True if the messages are equal, False otherwise.""" # If these are the same type, use internal protobuf's equality check. @@ -572,38 +604,6 @@ def __eq__(self, other): # Ask the other object. return NotImplemented - def __getattr__(self, key): - """Retrieve the given field's value. - - In protocol buffers, the presence of a field on a message is - sufficient for it to always be "present". - - For primitives, a value of the correct type will always be returned - (the "falsy" values in protocol buffers consistently match those - in Python). For repeated fields, the falsy value is always an empty - sequence. - - For messages, protocol buffers does distinguish between an empty - message and absence, but this distinction is subtle and rarely - relevant. Therefore, this method always returns an empty message - (following the official implementation). To check for message - presence, use ``key in self`` (in other words, ``__contains__``). - - .. note:: - - Some well-known protocol buffer types - (e.g. ``google.protobuf.Timestamp``) will be converted to - their Python equivalents. See the ``marshal`` module for - more details. - """ - try: - pb_type = self._meta.fields[key].pb_type - pb_value = getattr(self._pb, key) - marshal = self._meta.marshal - return marshal.to_python(pb_type, pb_value, absent=key not in self) - except KeyError as ex: - raise AttributeError(str(ex)) - def __ne__(self, other): """Return True if the messages are unequal, False otherwise.""" return not self == other @@ -611,26 +611,61 @@ def __ne__(self, other): def __repr__(self): return repr(self._pb) - def __setattr__(self, key, value): - """Set the value on the given field. - - For well-known protocol buffer types which are marshalled, either - the protocol buffer object or the Python equivalent is accepted. + @property + def _pb(self): + if self._cached_pb is None: + self._cached_pb = self._meta.pb() + return self._cached_pb + + @_pb.setter + def _pb(self, value): + self._cached_pb = value + + def _mark_pb_stale(self, field_name: str): + """We often set fields on the proto-plus object (this), and do not immediately + sync them down to the underlying pb2 object. This mechanism tracks which + fields are stale in this way. """ - if key[0] == "_": - return super().__setattr__(key, value) - marshal = self._meta.marshal - pb_type = self._meta.fields[key].pb_type - pb_value = marshal.to_proto(pb_type, value) + if not hasattr(self, '_stale_fields'): + self._stale_fields = set() + self._stale_fields.add(field_name) + + def _mark_pb_synced(self): + if hasattr(self, '_stale_fields'): + self._stale_fields.clear() + + def _update_nested_pb(self): + """When it is time to serialize a pb2 object, it does not do to sync just ourselves - + we must also recursively search for nested proto-plus objects that may also require + a sync. + """ + for field_name, field in self._meta.fields.items(): + if field.proto_type == proto.MESSAGE: + obj = getattr(self, field_name, None) + # Traversing the tree of objects eventually encounters primitives + # that do not have these functions + if obj and hasattr(obj, '_update_pb'): + obj._update_pb() + # Same here - many values will not have these methods + if obj and hasattr(obj, '_update_nested_pb'): + obj._update_nested_pb() + + def _update_pb(self): + """Loops over any stale fields and syncs them to the underlying pb2 object. + """ + merge_params = {} + for field_name in getattr(self, '_stale_fields', set()): + wrapper_value = getattr(self, field_name, None) - # Clear the existing field. - # This is the only way to successfully write nested falsy values, - # because otherwise MergeFrom will no-op on them. - self._pb.ClearField(key) + field = self._meta.fields[field_name] + pb_value = self._meta.marshal.to_proto(field.pb_type, wrapper_value) + merge_params[field_name] = pb_value - # Merge in the value being set. - if pb_value is not None: - self._pb.MergeFrom(self._meta.pb(**{key: pb_value})) + self._pb.ClearField(field_name) + + if merge_params: + self._pb.MergeFrom(self._meta.pb(**merge_params)) + self._mark_pb_synced() class _MessageInfo: @@ -665,6 +700,13 @@ def __init__( self.marshal = marshal self._pb = None + def __repr__(self) -> str: + return ( + f"_MessageInfo(fields={self.fields}, package={self.package}, full_name={self.full_name}, " + f"options={self.options}, fields_by_number={self.fields_by_number}, marshal={self.marshal}, " + f"_pb={self._pb})" + ) + @property def pb(self) -> Type[message.Message]: """Return the protobuf message type for this descriptor. diff --git a/tests/test_fields_composite.py b/tests/test_fields_composite.py index 60d29e12..3e0bb736 100644 --- a/tests/test_fields_composite.py +++ b/tests/test_fields_composite.py @@ -24,6 +24,7 @@ class Spam(proto.Message): foo = proto.Field(proto.MESSAGE, number=1, message=Foo) eggs = proto.Field(proto.BOOL, number=2) + spam = Spam(foo=Foo(bar="str", baz=42)) assert spam.foo.bar == "str" assert spam.foo.baz == 42 diff --git a/tests/test_fields_map_composite.py b/tests/test_fields_map_composite.py index 13bf0c5f..601cfd64 100644 --- a/tests/test_fields_map_composite.py +++ b/tests/test_fields_map_composite.py @@ -31,6 +31,23 @@ class Baz(proto.Message): assert "k" not in baz.foos +def test_composite_map_with_stale_fields(): + class Foo(proto.Message): + bar = proto.Field(proto.INT32, number=1) + + class Baz(proto.Message): + foos = proto.MapField(proto.STRING, proto.MESSAGE, number=1, message=Foo,) + name = proto.Field(proto.STRING, number=2) + + baz = Baz(foos={"i": Foo(bar=42), "j": Foo(bar=24)}) + baz.name = "New Value" + assert "name" in baz._stale_fields + foo = baz.foos["i"] + foo.bar = 100 + assert Baz.to_dict(baz)["name"] == "New Value" + assert Baz.to_dict(baz)["foos"]["i"]["bar"] == 100 + + def test_composite_map_dict(): class Foo(proto.Message): bar = proto.Field(proto.INT32, number=1) diff --git a/tests/test_fields_optional.py b/tests/test_fields_optional.py index a15904c1..bc4defed 100644 --- a/tests/test_fields_optional.py +++ b/tests/test_fields_optional.py @@ -59,6 +59,7 @@ class Squid(proto.Message): assert s.mass_kg == 20 assert not s.mass_lbs assert Squid.iridiphore_num in s + assert s.iridiphore_num == 600 s = Squid(mass_lbs=40, iridiphore_num=600) assert not s.mass_kg diff --git a/tests/test_marshal_types_wrappers_bool.py b/tests/test_marshal_types_wrappers_bool.py index ffe206a6..9f8d25dd 100644 --- a/tests/test_marshal_types_wrappers_bool.py +++ b/tests/test_marshal_types_wrappers_bool.py @@ -66,8 +66,11 @@ class Foo(proto.Message): bar = proto.Field(proto.MESSAGE, message=wrappers_pb2.BoolValue, number=1,) foo = Foo(bar=True) + assert foo.bar is True foo.bar = wrappers_pb2.BoolValue() assert foo.bar is False + foo.bar = wrappers_pb2.BoolValue(value=True) + assert foo.bar is True def test_bool_value_del():