From a085789b7dfdff34b3d5d74f0c45fcc1b44dc071 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Fri, 14 May 2021 09:59:39 -0700 Subject: [PATCH 1/7] performance refactor for handoff to protobuf layer --- proto/fields.py | 203 ++++++++++++++++++++++ proto/marshal/collections/maps.py | 16 +- proto/marshal/marshal.py | 74 ++++++++ proto/message.py | 143 +++++++++------ tests/test_fields_optional.py | 1 + tests/test_marshal_types_wrappers_bool.py | 3 + 6 files changed, 381 insertions(+), 59 deletions(-) diff --git a/proto/fields.py b/proto/fields.py index cc98e8b0..f873462b 100644 --- a/proto/fields.py +++ b/proto/fields.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime from enum import EnumMeta +from re import L +from proto.marshal.rules import wrappers +from proto.datetime_helpers import DatetimeWithNanoseconds +from proto.marshal.rules.dates import DurationRule +from typing import Any, Callable, Optional, Union from google.protobuf import descriptor_pb2 +from google.protobuf import duration_pb2 +from google.protobuf import struct_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf import wrappers_pb2 from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper +import proto from proto.primitives import ProtoType @@ -124,6 +135,10 @@ def package(self) -> str: """Return the package of the field.""" return self.mcls_data["package"] + @property + def inner_pb_type(self): + return None + @property def pb_type(self): """Return the composite type of the field, or None for primitives.""" @@ -140,6 +155,86 @@ def pb_type(self): return self.message.pb() return self.message + @property + def can_get_natively(self) -> bool: + if self.proto_type == ProtoType.MESSAGE and self.message == struct_pb2.Value: + return False + return True + + def can_set_natively(self, val: Any) -> bool: + if self.proto_type == ProtoType.MESSAGE and self.message == struct_pb2.Value: + return False + return True + # return self.pb_type is None and not self.repeated + + def contribute_to_class(self, cls, name: str): + set_coercion = None + if self.proto_type == ProtoType.STRING: + set_coercion = self._bytes_to_str + if self.pb_type == timestamp_pb2.Timestamp: + set_coercion = self._timestamp_to_datetime + if self.proto_type == ProtoType.MESSAGE and self.message == duration_pb2.Duration: + set_coercion = self._duration_to_timedelta + if self.proto_type == ProtoType.MESSAGE and self.message == wrappers_pb2.BoolValue: + set_coercion = self._bool_value_to_bool + if self.enum: + set_coercion = self._literal_to_enum + setattr(cls, name, self._get_field_descriptor_class()(name, cls=cls, set_coercion=set_coercion)) + + @staticmethod + def _get_field_descriptor_class(): + return _FieldDescriptor + + @property + def reverse_enum_map(self): + if not self.enum: + return None + if not getattr(self, '_reverse_enum_map', None): + self._reverse_enum_map = {e.value: e for e in self.enum} + return self._reverse_enum_map + + @property + def reverse_enum_names_map(self): + if not self.enum: + return None + if not getattr(self, '_reverse_enum_names_map', None): + self._reverse_enum_names_map = {e.name: e for e in self.enum} + return self._reverse_enum_names_map + + def _literal_to_enum(self, val: Any): + if isinstance(val, self.enum): + return val + return ( + self.reverse_enum_map.get(val, None) or + self.reverse_enum_names_map.get(val, None) + ) + + @staticmethod + def _bytes_to_str(val: Union[bytes, str]) -> str: + if type(val) == bytes: + val = val.decode('utf-8') + return val + + @staticmethod + def _timestamp_to_datetime(val: Union[timestamp_pb2.Timestamp, datetime.datetime]) -> datetime.datetime: + if type(val) == timestamp_pb2.Timestamp: + val = DatetimeWithNanoseconds.from_timestamp_pb(val) + return val + + @staticmethod + def _duration_to_timedelta(val: Union[duration_pb2.Duration, datetime.timedelta]) -> datetime.datetime: + if type(val) == duration_pb2.Duration: + val = DurationRule().to_python(val) + return val + + @staticmethod + def _bool_value_to_bool(val: Union[wrappers_pb2.BoolValue, bool]) -> Optional[bool]: + if val is None: + return None + if type(val) == wrappers_pb2.BoolValue: + val = val.value + return val + class RepeatedField(Field): """A representation of a repeated field in protocol buffers.""" @@ -154,6 +249,114 @@ def __init__(self, key_type, value_type, *, number: int, message=None, enum=None super().__init__(value_type, number=number, message=message, enum=enum) self.map_key_type = key_type + @property + def inner_pb_type(self): + return + + +class _FieldDescriptor: + def __init__(self, name: str, *, cls, set_coercion: Optional[Callable] = None): + # something like "id". required whenever reach back to the pb2 object. + self.original_name = name + # something like "_cached_id" + self.instance_attr_name = f'_cached_fields__{name}' + + # simple types coercion for setting attributes (for example, bytes -> str) + self._set_coercion = set_coercion or _FieldDescriptor._noop + self.cls = cls + + @property + def field(self): + return self.cls._meta.fields[self.original_name] + + def _hydrate_dicts(self, value: Any, instance): + if not isinstance(value, dict): + return value + + if self.field.proto_type == proto.MESSAGE: + _pb = self.field.message._meta.pb(**value) + value = self.field.message.wrap(_pb) + + return value + + def _clear_oneofs(self, instance): + if not self.field.oneof: + return + + for field_name, field in self.cls._meta.fields.items(): + # Don't clear this field + if field_name == self.original_name: + continue + + # Don't clear other fields with different oneof values, or with + # no such values at all + if field.oneof != self.field.oneof: + continue + + delattr(instance, field_name) + + def __set__(self, instance, value): + value = self._set_coercion(value) + value = self._hydrate_dicts(value, instance) + + always_commit: bool = getattr(instance, '_always_commit', False) + if always_commit or not self.field.can_set_natively(value): # or self.field.pb_type is not None or self.field.repeated: + # MapFields have a unique requirement to always eagerly commit their + # writes, as the MapComposite implementation does not used long-lived + # proto-plus types, and thus any information about dirty fields is + # discarded and leads to irrepairable de-sync between proto-plus and + # protobuf. + pb_value = instance._meta.marshal.to_proto(self.field.pb_type, value) + _pb = instance._meta.pb(**{self.original_name: pb_value}) + instance._pb.ClearField(self.original_name) + instance._pb.MergeFrom(_pb) + else: + + if value is None: + self.__delete__(instance) + instance._pb.MergeFrom(instance._meta._pb(**{self.original_name: None})) + return + + instance._meta.marshal.validate_primitives(self.field, value) + instance._mark_pb_stale(self.original_name) + + setattr(instance, self.instance_attr_name, value) + self._clear_oneofs(instance) + + def __get__(self, instance: 'proto.Message', owner): # type: ignore + # If `instance` is None, then we are accessing this field directly + # off the class itself instead of off an instance. + if instance is None: + return self.original_name + + value = getattr(instance, self.instance_attr_name, None) + if self.field.can_get_natively and value is not None: + return value + else: + # For the most part, only primitive values can be returned natively, meaning + # this is either a Message itself, in which case, since we're dealing with the + # underlying pb object, we need to sync all deferred fields. + # This is functionally a no-op if no fields have been deferred. + if hasattr(value, '_update_pb'): + value._update_pb() + + pb_value = getattr(instance._pb, self.original_name, None) + value = instance._meta.marshal.to_python(self.field.pb_type, pb_value, absent=self.original_name not in instance) + + setattr(instance, self.instance_attr_name, value) + return value + + def __delete__(self, instance): + if hasattr(instance, self.instance_attr_name): + delattr(instance, self.instance_attr_name) + instance._pb.ClearField(self.original_name) + if self.original_name in getattr(instance, '_stale_fields', []): + instance._stale_fields.remove(self.original_name) + + @staticmethod + def _noop(val): + return val + __all__ = ( "Field", diff --git a/proto/marshal/collections/maps.py b/proto/marshal/collections/maps.py index 8ed11349..a521c2e8 100644 --- a/proto/marshal/collections/maps.py +++ b/proto/marshal/collections/maps.py @@ -14,6 +14,7 @@ import collections +import proto from proto.utils import cached_property @@ -24,11 +25,19 @@ class MapComposite(collections.abc.MutableMapping): modify the underlying field container directly. """ + @cached_property + def entry_class(self): + return self.pb.GetEntryClass() + @cached_property def _pb_type(self): """Return the protocol buffer type for this sequence.""" # Huzzah, another hack. Still less bad than RepeatedComposite. - return type(self.pb.GetEntryClass()().value) + return type(self.entry_class().value) + + @cached_property + def _override_pb_type(self): + return getattr(self.entry_class._meta["fields"]["value"], "pb_override", None) def __init__(self, sequence, *, marshal): """Initialize a wrapper around a protobuf map. @@ -54,7 +63,10 @@ def __getitem__(self, key): # buffers will create the key if it does not exist. if key not in self: raise KeyError(key) - return self._marshal.to_python(self._pb_type, self.pb[key]) + obj = self._marshal.to_python(self._pb_type, self.pb[key]) + if isinstance(obj, proto.Message): + obj._always_commit = True + return obj def __setitem__(self, key, value): pb_value = self._marshal.to_proto(self._pb_type, value, strict=True) diff --git a/proto/marshal/marshal.py b/proto/marshal/marshal.py index c8224ce2..c4ce6e78 100644 --- a/proto/marshal/marshal.py +++ b/proto/marshal/marshal.py @@ -21,6 +21,7 @@ from google.protobuf import struct_pb2 from google.protobuf import wrappers_pb2 +import proto from proto.marshal import compat from proto.marshal.collections import MapComposite from proto.marshal.collections import Repeated @@ -30,6 +31,10 @@ from proto.marshal.rules import wrappers +max_uint_32 = 1<<32 - 1 +max_int_32 = 1<<31 - 1 +min_int_32 = max_int_32 * -1 + class Rule(abc.ABC): """Abstract class definition for marshal rules.""" @@ -122,6 +127,75 @@ def register_rule_class(rule_class: type): return register_rule_class + @staticmethod + def _throw_type_error(primitive): + raise TypeError(f"Unacceptable value of type {type(primitive)}") + + @staticmethod + def _throw_value_error(proto_type, primitive): + raise ValueError(f"Unacceptable value {type(primitive)} for type {proto_type}") + + def validate_primitives(self, field, primitive): + proto_type = field.proto_type + if field.repeated: + if primitive and not isinstance(primitive, (list, Repeated, RepeatedComposite,)): + self._throw_type_error(primitive) + if primitive: + primitive = primitive[0] + + _type = type(primitive) + if proto_type == proto.DOUBLE and _type != float: + self._throw_type_error(primitive) + elif proto_type == proto.FLOAT and _type != float: + self._throw_type_error(primitive) + elif proto_type == proto.INT64 and _type != int: + self._throw_type_error(primitive) + elif proto_type == proto.UINT64: + if _type != int: + self._throw_type_error(primitive) + if primitive < 0: + self._throw_value_error(proto_type, primitive) + elif proto_type == proto.INT32: + if _type != int: + self._throw_type_error(primitive) + if primitive < min_int_32 or primitive > max_int_32: + self._throw_value_error(proto_type, primitive) + elif proto_type == proto.FIXED64 and _type != int: + self._throw_type_error(primitive) + elif proto_type == proto.FIXED32 and _type != int: + self._throw_type_error(primitive) + elif proto_type == proto.BOOL and _type != bool: + self._throw_type_error(primitive) + elif proto_type == proto.STRING and _type not in (str, bytes,): + self._throw_type_error(primitive) + # elif proto_type == proto.MESSAGE: + # # This is not a primitive! + # self._throw_type_error(primitive) + elif proto_type == proto.BYTES and _type != bytes: + self._throw_type_error(primitive) + elif proto_type == proto.UINT32: + if _type != int: + self._throw_type_error(primitive) + if primitive < 0 or primitive > max_int_32: + self._throw_value_error(proto_type, primitive) + # elif proto_type == proto.ENUM: + # # This is not a primitive! + # self._throw_type_error(primitive) + elif proto_type == proto.SFIXED32: + if _type != int: + self._throw_type_error(primitive) + if primitive < min_int_32 or primitive > max_int_32: + self._throw_value_error(proto_type, primitive) + elif proto_type == proto.SFIXED64 and _type != int: + self._throw_type_error(primitive) + elif proto_type == proto.SINT32: + if _type != int: + self._throw_type_error(primitive) + if primitive < min_int_32 or primitive > max_int_32: + self._throw_value_error(proto_type, primitive) + elif proto_type == proto.SINT64 and _type != int: + self._throw_type_error(primitive) + def reset(self): """Reset the registry to its initial state.""" self._rules.clear() diff --git a/proto/message.py b/proto/message.py index e046d628..31911d5f 100644 --- a/proto/message.py +++ b/proto/message.py @@ -15,13 +15,15 @@ import collections import collections.abc import copy +import inspect import re -from typing import List, Type +from typing import Callable, List, Optional, Type from google.protobuf import descriptor_pb2 from google.protobuf import message from google.protobuf.json_format import MessageToDict, MessageToJson, Parse +import proto from proto import _file_info from proto import _package_info from proto.fields import Field @@ -31,6 +33,13 @@ from proto.primitives import ProtoType +def _has_contribute_to_class(value): + # Only call contribute_to_class() if it's bound. + return not inspect.isclass(value) and hasattr(value, 'contribute_to_class') + +class MapValueMessage: + isMapValue: bool = True + class MessageMeta(type): """A metaclass for building and registering Message subclasses.""" @@ -105,6 +114,7 @@ def __new__(mcls, name, bases, attrs): # Okay, now we deal with all the rest of the fields. # Iterate over all the attributes and separate the fields into # their own sequence. + contributable_attrs = {} fields = [] new_attrs = {} oneofs = collections.OrderedDict() @@ -127,6 +137,9 @@ def __new__(mcls, name, bases, attrs): "package": package, } + if _has_contribute_to_class(field): + contributable_attrs[key] = field + # Add the field to the list of fields. fields.append(field) # If this field is part of a "oneof", ensure the oneof itself @@ -248,6 +261,9 @@ def __new__(mcls, name, bases, attrs): # Run the superclass constructor. cls = super().__new__(mcls, name, bases, new_attrs) + for field_name, field in contributable_attrs.items(): + cls.add_to_class(field_name, field) + # The info class and fields need a reference to the class just created. cls._meta.parent = cls for field in cls._meta.fields.values(): @@ -269,6 +285,12 @@ def __new__(mcls, name, bases, attrs): def __prepare__(mcls, name, bases, **kwargs): return collections.OrderedDict() + def add_to_class(cls, name, value): + if _has_contribute_to_class(value): + value.contribute_to_class(cls, name) + else: + setattr(cls, name, value) + @property def meta(cls): return cls._meta @@ -289,6 +311,10 @@ def pb(cls, obj=None, *, coerce: bool = False): obj = cls(obj) else: raise TypeError("%r is not an instance of %s" % (obj, cls.__name__,)) + if hasattr(obj, '_update_pb'): + obj._update_pb() + if hasattr(obj, '_update_nested_pb'): + obj._update_nested_pb() return obj._pb def wrap(cls, pb): @@ -313,6 +339,8 @@ def serialize(cls, instance) -> bytes: Returns: bytes: The serialized representation of the protocol buffer. """ + if instance and type(instance) == cls: + instance._update_pb() return cls.pb(instance, coerce=True).SerializeToString() def deserialize(cls, payload: bytes) -> "Message": @@ -446,6 +474,14 @@ class Message(metaclass=MessageMeta): """ def __init__(self, mapping=None, *, ignore_unknown_fields=False, **kwargs): + + self._cached_pb = None + # Tracks any new values have had their serialization deferred, and thus whether we + # are ready to serialize immediately, or need to sync from the instance back to the + # underlying `_pb` + self._stale_fields: List[str] = [] + + # We accept several things for `mapping`: # * An instance of this class. # * An instance of the underlying protobuf descriptor class. @@ -454,7 +490,7 @@ def __init__(self, mapping=None, *, ignore_unknown_fields=False, **kwargs): if mapping is None: if not kwargs: # Special fast path for empty construction. - super().__setattr__("_pb", self._meta.pb()) + # `self._pb` is lazily initialized only when needed return mapping = kwargs @@ -539,6 +575,8 @@ def __contains__(self, key): bool: Whether the field's value corresponds to a non-empty wire serialization. """ + if key in getattr(self, '_stale_fields', []): + return True pb_value = getattr(self._pb, key) try: # Protocol buffers "HasField" is unfriendly; it only works @@ -552,12 +590,6 @@ def __contains__(self, key): except ValueError: return bool(pb_value) - def __delattr__(self, key): - """Delete the value on the given field. - - This is generally equivalent to setting a falsy value. - """ - self._pb.ClearField(key) def __eq__(self, other): """Return True if the messages are equal, False otherwise.""" @@ -572,38 +604,6 @@ def __eq__(self, other): # Ask the other object. return NotImplemented - def __getattr__(self, key): - """Retrieve the given field's value. - - In protocol buffers, the presence of a field on a message is - sufficient for it to always be "present". - - For primitives, a value of the correct type will always be returned - (the "falsy" values in protocol buffers consistently match those - in Python). For repeated fields, the falsy value is always an empty - sequence. - - For messages, protocol buffers does distinguish between an empty - message and absence, but this distinction is subtle and rarely - relevant. Therefore, this method always returns an empty message - (following the official implementation). To check for message - presence, use ``key in self`` (in other words, ``__contains__``). - - .. note:: - - Some well-known protocol buffer types - (e.g. ``google.protobuf.Timestamp``) will be converted to - their Python equivalents. See the ``marshal`` module for - more details. - """ - try: - pb_type = self._meta.fields[key].pb_type - pb_value = getattr(self._pb, key) - marshal = self._meta.marshal - return marshal.to_python(pb_type, pb_value, absent=key not in self) - except KeyError as ex: - raise AttributeError(str(ex)) - def __ne__(self, other): """Return True if the messages are unequal, False otherwise.""" return not self == other @@ -611,26 +611,48 @@ def __ne__(self, other): def __repr__(self): return repr(self._pb) - def __setattr__(self, key, value): - """Set the value on the given field. + @property + def _pb(self): + if self._cached_pb is None: + self._cached_pb = self._meta.pb() + return self._cached_pb - For well-known protocol buffer types which are marshalled, either - the protocol buffer object or the Python equivalent is accepted. - """ - if key[0] == "_": - return super().__setattr__(key, value) - marshal = self._meta.marshal - pb_type = self._meta.fields[key].pb_type - pb_value = marshal.to_proto(pb_type, value) + @_pb.setter + def _pb(self, value): + self._cached_pb = value + # return super().__setattr__('_cached_pb', value) - # Clear the existing field. - # This is the only way to successfully write nested falsy values, - # because otherwise MergeFrom will no-op on them. - self._pb.ClearField(key) + def _mark_pb_stale(self, field_name: str): + if not hasattr(self, '_stale_fields'): + self._stale_fields = [] + self._stale_fields.append(field_name) - # Merge in the value being set. - if pb_value is not None: - self._pb.MergeFrom(self._meta.pb(**{key: pb_value})) + def _mark_pb_synced(self): + self._stale_fields = [] + + def _update_nested_pb(self): + for field_name, field in self._meta.fields.items(): + if field.proto_type == proto.MESSAGE: + obj = getattr(self, field_name, None) + if obj and hasattr(obj, '_update_pb'): + obj._update_pb() + if obj and hasattr(obj, '_update_nested_pb'): + obj._update_nested_pb() + + def _update_pb(self): + merge_params = {} + for field_name in getattr(self, '_stale_fields', []): + wrapper_value = getattr(self, field_name, None) + + field = self._meta.fields[field_name] + pb_value = self._meta.marshal.to_proto(field.pb_type, wrapper_value) + merge_params[field_name] = pb_value + + self._pb.ClearField(field_name) + + if merge_params: + self._pb.MergeFrom(self._meta.pb(**merge_params)) + self._mark_pb_synced() class _MessageInfo: @@ -665,6 +687,13 @@ def __init__( self.marshal = marshal self._pb = None + def __repr__(self) -> str: + return ( + f"_MessageInfo(fields={self.fields}, package={self.package}, full_name={self.full_name}, " + f"options={self.options}, fields_by_number={self.fields_by_number}, marshal={self.marshal}, " + f"_pb={self._pb})" + ) + @property def pb(self) -> Type[message.Message]: """Return the protobuf message type for this descriptor. diff --git a/tests/test_fields_optional.py b/tests/test_fields_optional.py index a15904c1..bc4defed 100644 --- a/tests/test_fields_optional.py +++ b/tests/test_fields_optional.py @@ -59,6 +59,7 @@ class Squid(proto.Message): assert s.mass_kg == 20 assert not s.mass_lbs assert Squid.iridiphore_num in s + assert s.iridiphore_num == 600 s = Squid(mass_lbs=40, iridiphore_num=600) assert not s.mass_kg diff --git a/tests/test_marshal_types_wrappers_bool.py b/tests/test_marshal_types_wrappers_bool.py index ffe206a6..9f8d25dd 100644 --- a/tests/test_marshal_types_wrappers_bool.py +++ b/tests/test_marshal_types_wrappers_bool.py @@ -66,8 +66,11 @@ class Foo(proto.Message): bar = proto.Field(proto.MESSAGE, message=wrappers_pb2.BoolValue, number=1,) foo = Foo(bar=True) + assert foo.bar is True foo.bar = wrappers_pb2.BoolValue() assert foo.bar is False + foo.bar = wrappers_pb2.BoolValue(value=True) + assert foo.bar is True def test_bool_value_del(): From ff3c6f084ba0389d51a12bb2a5f1304a9396c2c1 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Fri, 14 May 2021 10:36:35 -0700 Subject: [PATCH 2/7] added comments and docstrings --- proto/fields.py | 88 +++++++++++++++++++++++++++++++++++++++--------- proto/message.py | 15 ++++++++- 2 files changed, 87 insertions(+), 16 deletions(-) diff --git a/proto/fields.py b/proto/fields.py index f873462b..0cb3283b 100644 --- a/proto/fields.py +++ b/proto/fields.py @@ -168,8 +168,17 @@ def can_set_natively(self, val: Any) -> bool: # return self.pb_type is None and not self.repeated def contribute_to_class(self, cls, name: str): + """Attaches a descriptor to the top-level proto.Message class, so that attribute + reads and writes can be specially handled in `_FieldDescriptor.__get__` and + `FieldDescriptor.__set__`. + + Also contains hooks for write-time type-coersion to translate special cases between + pure Pythonic objects and pb2-compatible structs or values. + """ set_coercion = None if self.proto_type == ProtoType.STRING: + # Bytes are accepted for string values, but strings are not accepted for byte values. + # This is an artifact of older Python2 implementations. set_coercion = self._bytes_to_str if self.pb_type == timestamp_pb2.Timestamp: set_coercion = self._timestamp_to_datetime @@ -179,14 +188,15 @@ def contribute_to_class(self, cls, name: str): set_coercion = self._bool_value_to_bool if self.enum: set_coercion = self._literal_to_enum - setattr(cls, name, self._get_field_descriptor_class()(name, cls=cls, set_coercion=set_coercion)) - - @staticmethod - def _get_field_descriptor_class(): - return _FieldDescriptor + setattr(cls, name, _FieldDescriptor(name, cls=cls, set_coercion=set_coercion)) @property def reverse_enum_map(self): + """Helper that allows for constant-time lookup on self.enum, used to hydrate + primitives that are supplied but which stand for their official enum types. + + This is used when a developer supplies the literal value for an enum type (often an int). + """ if not self.enum: return None if not getattr(self, '_reverse_enum_map', None): @@ -195,6 +205,11 @@ def reverse_enum_map(self): @property def reverse_enum_names_map(self): + """Helper that allows for constant-time lookup on self.enum, used to hydrate + primitives that are supplied but which stand for their official enum types. + + This is used when a developer supplies the string value for an enum type's name. + """ if not self.enum: return None if not getattr(self, '_reverse_enum_names_map', None): @@ -255,6 +270,18 @@ def inner_pb_type(self): class _FieldDescriptor: + """Handler for proto.Field access on any proto.Message object. + + Wraps each proto.Field instance within a given proto.Message subclass's definition + with getters and setters that allow for caching of values on the proto-plus object, + deferment of syncing to the underlying pb2 object, and tracking of the current state. + + Special treatment is given to MapFields, nested Messages, and certain data types, as + their various implementations within pb2 (which for our purposes is mostly a black box) + sometimes mandate immediate syncing. This is usually because proto-plus objects are not + long-lived, and thus information about which fields are stale would be lost if syncing + was left for serialization time. + """ def __init__(self, name: str, *, cls, set_coercion: Optional[Callable] = None): # something like "id". required whenever reach back to the pb2 object. self.original_name = name @@ -269,7 +296,10 @@ def __init__(self, name: str, *, cls, set_coercion: Optional[Callable] = None): def field(self): return self.cls._meta.fields[self.original_name] - def _hydrate_dicts(self, value: Any, instance): + def _hydrate_dicts(self, value: Any): + """Turns a dictionary assigned to a nested Message into a full instance of + that Message type. + """ if not isinstance(value, dict): return value @@ -296,16 +326,31 @@ def _clear_oneofs(self, instance): delattr(instance, field_name) def __set__(self, instance, value): - value = self._set_coercion(value) - value = self._hydrate_dicts(value, instance) + """Called whenever a value is assigned to a proto.Field attribute on an instantiated + proto.Message object. + + Usage: + class MyMessage(proto.Message): + name = proto.Field(proto.STRING, number=1) + + my_message = MyMessage() + my_message.name = "Frodo" + + In the above scenario, `__set__` is called with "Frodo" passed as `value` and `my_instance` + passed as `instance`. + """ + value = self._set_coercion(value) + value = self._hydrate_dicts(value) + + # Warning: `always_commit` is hacky! + # Some contexts, particularly instances created from MapFields, require immediate syncing. + # It is impossible to deduce such a scenario purely from logic available to this function, + # so instead we set a flag on instances when a MapField yields them, and then when those + # instances receive attribute updates, immediately syncing those values to the underlying + # pb2 instance is sufficient. always_commit: bool = getattr(instance, '_always_commit', False) - if always_commit or not self.field.can_set_natively(value): # or self.field.pb_type is not None or self.field.repeated: - # MapFields have a unique requirement to always eagerly commit their - # writes, as the MapComposite implementation does not used long-lived - # proto-plus types, and thus any information about dirty fields is - # discarded and leads to irrepairable de-sync between proto-plus and - # protobuf. + if always_commit or not self.field.can_set_natively(value): pb_value = instance._meta.marshal.to_proto(self.field.pb_type, value) _pb = instance._meta.pb(**{self.original_name: pb_value}) instance._pb.ClearField(self.original_name) @@ -323,7 +368,20 @@ def __set__(self, instance, value): setattr(instance, self.instance_attr_name, value) self._clear_oneofs(instance) - def __get__(self, instance: 'proto.Message', owner): # type: ignore + def __get__(self, instance: 'proto.Message', _): # type: ignore + """Called whenever a value is read from a proto.Field attribute on an instantiated + proto.Message object. + + Usage: + + class MyMessage(proto.Message): + name = proto.Field(proto.STRING, number=1) + + my_message = MyMessage(name="Frodo") + print(my_message.name) + + In the above scenario, `__get__` is called with "my_message" passed as `instance`. + """ # If `instance` is None, then we are accessing this field directly # off the class itself instead of off an instance. if instance is None: diff --git a/proto/message.py b/proto/message.py index 31911d5f..1200cb78 100644 --- a/proto/message.py +++ b/proto/message.py @@ -261,6 +261,8 @@ def __new__(mcls, name, bases, attrs): # Run the superclass constructor. cls = super().__new__(mcls, name, bases, new_attrs) + # Now that the class officially exists, but is not yet finalized, allow + # individual attributes to run functions to attach themselves in special ways. for field_name, field in contributable_attrs.items(): cls.add_to_class(field_name, field) @@ -286,6 +288,8 @@ def __prepare__(mcls, name, bases, **kwargs): return collections.OrderedDict() def add_to_class(cls, name, value): + """Hook for attributes on the class definition to attach themselves + in special ways.""" if _has_contribute_to_class(value): value.contribute_to_class(cls, name) else: @@ -620,9 +624,12 @@ def _pb(self): @_pb.setter def _pb(self, value): self._cached_pb = value - # return super().__setattr__('_cached_pb', value) def _mark_pb_stale(self, field_name: str): + """We often set fields on the proto-plus object (this), and do not immediately + sync them down to the underlying pb2 object. This mechanism tracks which + fields are stale in this way. + """ if not hasattr(self, '_stale_fields'): self._stale_fields = [] self._stale_fields.append(field_name) @@ -631,6 +638,10 @@ def _mark_pb_synced(self): self._stale_fields = [] def _update_nested_pb(self): + """When it is time to serialize a pb2 object, it does not do to sync just ourselves - + we must also recursively search for nested proto-plus objects that may also require + a sync. + """ for field_name, field in self._meta.fields.items(): if field.proto_type == proto.MESSAGE: obj = getattr(self, field_name, None) @@ -640,6 +651,8 @@ def _update_nested_pb(self): obj._update_nested_pb() def _update_pb(self): + """Loops over any stale fields and syncs them to the underlying pb2 object. + """ merge_params = {} for field_name in getattr(self, '_stale_fields', []): wrapper_value = getattr(self, field_name, None) From 1109927b0dfb895c68545484a979d668359d4541 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Fri, 14 May 2021 10:53:54 -0700 Subject: [PATCH 3/7] removed relic from abandoned experiment --- proto/fields.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/proto/fields.py b/proto/fields.py index 0cb3283b..bcd78d82 100644 --- a/proto/fields.py +++ b/proto/fields.py @@ -135,10 +135,6 @@ def package(self) -> str: """Return the package of the field.""" return self.mcls_data["package"] - @property - def inner_pb_type(self): - return None - @property def pb_type(self): """Return the composite type of the field, or None for primitives.""" @@ -264,10 +260,6 @@ def __init__(self, key_type, value_type, *, number: int, message=None, enum=None super().__init__(value_type, number=number, message=message, enum=enum) self.map_key_type = key_type - @property - def inner_pb_type(self): - return - class _FieldDescriptor: """Handler for proto.Field access on any proto.Message object. From d04f72ea69b460be02e0a879cc4b86b140ff8061 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Fri, 14 May 2021 11:00:14 -0700 Subject: [PATCH 4/7] more clean up on abandoned efforts --- proto/marshal/collections/maps.py | 10 +--------- proto/marshal/marshal.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/proto/marshal/collections/maps.py b/proto/marshal/collections/maps.py index a521c2e8..c43712c3 100644 --- a/proto/marshal/collections/maps.py +++ b/proto/marshal/collections/maps.py @@ -25,19 +25,11 @@ class MapComposite(collections.abc.MutableMapping): modify the underlying field container directly. """ - @cached_property - def entry_class(self): - return self.pb.GetEntryClass() - @cached_property def _pb_type(self): """Return the protocol buffer type for this sequence.""" # Huzzah, another hack. Still less bad than RepeatedComposite. - return type(self.entry_class().value) - - @cached_property - def _override_pb_type(self): - return getattr(self.entry_class._meta["fields"]["value"], "pb_override", None) + return type(self.self.pb.GetEntryClass()().value) def __init__(self, sequence, *, marshal): """Initialize a wrapper around a protobuf map. diff --git a/proto/marshal/marshal.py b/proto/marshal/marshal.py index c4ce6e78..b0212953 100644 --- a/proto/marshal/marshal.py +++ b/proto/marshal/marshal.py @@ -136,6 +136,11 @@ def _throw_value_error(proto_type, primitive): raise ValueError(f"Unacceptable value {type(primitive)} for type {proto_type}") def validate_primitives(self, field, primitive): + """Replicates validation logic when assigning values to proto.Field attributes + on instantiated proto.Message objects. This is required to recreate the immediacy + of ValueError and TypeError checks that the pb2 layer made, but which are now + deferred because syncing to the pb2 layer itself is often entirely deferred. + """ proto_type = field.proto_type if field.repeated: if primitive and not isinstance(primitive, (list, Repeated, RepeatedComposite,)): @@ -168,9 +173,9 @@ def validate_primitives(self, field, primitive): self._throw_type_error(primitive) elif proto_type == proto.STRING and _type not in (str, bytes,): self._throw_type_error(primitive) - # elif proto_type == proto.MESSAGE: - # # This is not a primitive! - # self._throw_type_error(primitive) + elif proto_type == proto.MESSAGE: + # Do nothing - this is not a primitive + pass elif proto_type == proto.BYTES and _type != bytes: self._throw_type_error(primitive) elif proto_type == proto.UINT32: @@ -178,9 +183,9 @@ def validate_primitives(self, field, primitive): self._throw_type_error(primitive) if primitive < 0 or primitive > max_int_32: self._throw_value_error(proto_type, primitive) - # elif proto_type == proto.ENUM: - # # This is not a primitive! - # self._throw_type_error(primitive) + elif proto_type == proto.ENUM: + # Do nothing - this is not a primitive + pass elif proto_type == proto.SFIXED32: if _type != int: self._throw_type_error(primitive) From f57c932f6cf160f3b757f9bbe55d17a8b3c4c757 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Tue, 18 May 2021 13:09:02 -0700 Subject: [PATCH 5/7] responded to code review --- proto/fields.py | 44 ++++++++++++++----------------- proto/marshal/collections/maps.py | 2 +- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/proto/fields.py b/proto/fields.py index bcd78d82..40df3681 100644 --- a/proto/fields.py +++ b/proto/fields.py @@ -157,11 +157,11 @@ def can_get_natively(self) -> bool: return False return True - def can_set_natively(self, val: Any) -> bool: + @property + def can_set_natively(self) -> bool: if self.proto_type == ProtoType.MESSAGE and self.message == struct_pb2.Value: return False return True - # return self.pb_type is None and not self.repeated def contribute_to_class(self, cls, name: str): """Attaches a descriptor to the top-level proto.Message class, so that attribute @@ -176,13 +176,13 @@ def contribute_to_class(self, cls, name: str): # Bytes are accepted for string values, but strings are not accepted for byte values. # This is an artifact of older Python2 implementations. set_coercion = self._bytes_to_str - if self.pb_type == timestamp_pb2.Timestamp: + elif self.pb_type == timestamp_pb2.Timestamp: set_coercion = self._timestamp_to_datetime - if self.proto_type == ProtoType.MESSAGE and self.message == duration_pb2.Duration: + elif self.proto_type == ProtoType.MESSAGE and self.message == duration_pb2.Duration: set_coercion = self._duration_to_timedelta - if self.proto_type == ProtoType.MESSAGE and self.message == wrappers_pb2.BoolValue: + elif self.proto_type == ProtoType.MESSAGE and self.message == wrappers_pb2.BoolValue: set_coercion = self._bool_value_to_bool - if self.enum: + elif self.enum: set_coercion = self._literal_to_enum setattr(cls, name, _FieldDescriptor(name, cls=cls, set_coercion=set_coercion)) @@ -281,7 +281,7 @@ def __init__(self, name: str, *, cls, set_coercion: Optional[Callable] = None): self.instance_attr_name = f'_cached_fields__{name}' # simple types coercion for setting attributes (for example, bytes -> str) - self._set_coercion = set_coercion or _FieldDescriptor._noop + self._set_coercion: Optional[Callable] = set_coercion self.cls = cls @property @@ -318,8 +318,8 @@ def _clear_oneofs(self, instance): delattr(instance, field_name) def __set__(self, instance, value): - """Called whenever a value is assigned to a proto.Field attribute on an instantiated - proto.Message object. + """Called whenever a value is assigned to a `proto.Field` attribute on an instantiated + `proto.Message` object. Usage: @@ -329,10 +329,10 @@ class MyMessage(proto.Message): my_message = MyMessage() my_message.name = "Frodo" - In the above scenario, `__set__` is called with "Frodo" passed as `value` and `my_instance` - passed as `instance`. + In the above scenario, `__set__` is called with "Frodo" passed as `value` and + `my_message` passed as `instance`. """ - value = self._set_coercion(value) + value = self._set_coercion(value) if self._set_coercion is not None else value value = self._hydrate_dicts(value) # Warning: `always_commit` is hacky! @@ -342,7 +342,7 @@ class MyMessage(proto.Message): # instances receive attribute updates, immediately syncing those values to the underlying # pb2 instance is sufficient. always_commit: bool = getattr(instance, '_always_commit', False) - if always_commit or not self.field.can_set_natively(value): + if always_commit or not self.field.can_set_natively: pb_value = instance._meta.marshal.to_proto(self.field.pb_type, value) _pb = instance._meta.pb(**{self.original_name: pb_value}) instance._pb.ClearField(self.original_name) @@ -382,13 +382,13 @@ class MyMessage(proto.Message): value = getattr(instance, self.instance_attr_name, None) if self.field.can_get_natively and value is not None: return value - else: - # For the most part, only primitive values can be returned natively, meaning - # this is either a Message itself, in which case, since we're dealing with the - # underlying pb object, we need to sync all deferred fields. - # This is functionally a no-op if no fields have been deferred. - if hasattr(value, '_update_pb'): - value._update_pb() + + # For the most part, only primitive values can be returned natively, meaning + # this is either a Message itself, in which case, since we're dealing with the + # underlying pb object, we need to sync all deferred fields. + # This is functionally a no-op if no fields have been deferred. + if hasattr(value, '_update_pb'): + value._update_pb() pb_value = getattr(instance._pb, self.original_name, None) value = instance._meta.marshal.to_python(self.field.pb_type, pb_value, absent=self.original_name not in instance) @@ -403,10 +403,6 @@ def __delete__(self, instance): if self.original_name in getattr(instance, '_stale_fields', []): instance._stale_fields.remove(self.original_name) - @staticmethod - def _noop(val): - return val - __all__ = ( "Field", diff --git a/proto/marshal/collections/maps.py b/proto/marshal/collections/maps.py index c43712c3..722a34be 100644 --- a/proto/marshal/collections/maps.py +++ b/proto/marshal/collections/maps.py @@ -29,7 +29,7 @@ class MapComposite(collections.abc.MutableMapping): def _pb_type(self): """Return the protocol buffer type for this sequence.""" # Huzzah, another hack. Still less bad than RepeatedComposite. - return type(self.self.pb.GetEntryClass()().value) + return type(self.pb.GetEntryClass()().value) def __init__(self, sequence, *, marshal): """Initialize a wrapper around a protobuf map. From fe17aaea7a574df1b76fe77bab19495fbff47023 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Thu, 20 May 2021 11:01:11 -0700 Subject: [PATCH 6/7] responses to code review --- proto/fields.py | 53 ++++++++++++++++++++---------- proto/marshal/marshal.py | 20 ++++++----- proto/message.py | 16 ++++----- tests/test_fields_map_composite.py | 17 ++++++++++ 4 files changed, 72 insertions(+), 34 deletions(-) diff --git a/proto/fields.py b/proto/fields.py index 40df3681..9414b317 100644 --- a/proto/fields.py +++ b/proto/fields.py @@ -15,9 +15,10 @@ import datetime from enum import EnumMeta from re import L -from proto.marshal.rules import wrappers from proto.datetime_helpers import DatetimeWithNanoseconds +from proto.marshal.rules import wrappers from proto.marshal.rules.dates import DurationRule +from proto.utils import cached_property from typing import Any, Callable, Optional, Union from google.protobuf import descriptor_pb2 @@ -186,31 +187,23 @@ def contribute_to_class(self, cls, name: str): set_coercion = self._literal_to_enum setattr(cls, name, _FieldDescriptor(name, cls=cls, set_coercion=set_coercion)) - @property + @cached_property def reverse_enum_map(self): """Helper that allows for constant-time lookup on self.enum, used to hydrate primitives that are supplied but which stand for their official enum types. This is used when a developer supplies the literal value for an enum type (often an int). """ - if not self.enum: - return None - if not getattr(self, '_reverse_enum_map', None): - self._reverse_enum_map = {e.value: e for e in self.enum} - return self._reverse_enum_map + return {e.value: e for e in self.enum} if self.enum else None - @property + @cached_property def reverse_enum_names_map(self): """Helper that allows for constant-time lookup on self.enum, used to hydrate primitives that are supplied but which stand for their official enum types. This is used when a developer supplies the string value for an enum type's name. """ - if not self.enum: - return None - if not getattr(self, '_reverse_enum_names_map', None): - self._reverse_enum_names_map = {e.name: e for e in self.enum} - return self._reverse_enum_names_map + return {e.name: e for e in self.enum} if self.enum else None def _literal_to_enum(self, val: Any): if isinstance(val, self.enum): @@ -274,13 +267,27 @@ class _FieldDescriptor: long-lived, and thus information about which fields are stale would be lost if syncing was left for serialization time. """ + + # Namespace for attributes where we will store the Pythonic values of + # various `proto.Field` classes on instantiated `proto.Message` objects. + # For example, in the following scenario, this attribute is involved in + # saving the value "Homer Simpson" to `my_message._cached_fields__name`. + # + # class MyMessage(proto.Message): + # name = proto.Field(proto.STRING, ...) + # + # my_message = MyMessage() + # my_message.name = "Homer Simpson" # saves to `_cached_fields__name` + cached_fields_prefix = '_cached_fields__' + def __init__(self, name: str, *, cls, set_coercion: Optional[Callable] = None): # something like "id". required whenever reach back to the pb2 object. self.original_name = name # something like "_cached_id" - self.instance_attr_name = f'_cached_fields__{name}' + self.instance_attr_name = f'{self.cached_fields_prefix}{name}' - # simple types coercion for setting attributes (for example, bytes -> str) + # simple types coercion for setting attributes + # (e.g., bytes -> str if our type is string, but we are supplied bytes) self._set_coercion: Optional[Callable] = set_coercion self.cls = cls @@ -379,8 +386,8 @@ class MyMessage(proto.Message): if instance is None: return self.original_name - value = getattr(instance, self.instance_attr_name, None) - if self.field.can_get_natively and value is not None: + value = getattr(instance, self.instance_attr_name, _none) + if self.field.can_get_natively and value is not _none: return value # For the most part, only primitive values can be returned natively, meaning @@ -404,6 +411,18 @@ def __delete__(self, instance): instance._stale_fields.remove(self.original_name) +class _NoneType: + def __bool__(self): + return False + + def __eq__(self, other): + """All _NoneType instances are equal""" + return isinstance(other, _NoneType) + + +_none = _NoneType() + + __all__ = ( "Field", "MapField", diff --git a/proto/marshal/marshal.py b/proto/marshal/marshal.py index b0212953..8af25141 100644 --- a/proto/marshal/marshal.py +++ b/proto/marshal/marshal.py @@ -33,7 +33,7 @@ max_uint_32 = 1<<32 - 1 max_int_32 = 1<<31 - 1 -min_int_32 = max_int_32 * -1 +min_int_32 = (max_int_32 * -1) + 1 class Rule(abc.ABC): """Abstract class definition for marshal rules.""" @@ -143,13 +143,21 @@ def validate_primitives(self, field, primitive): """ proto_type = field.proto_type if field.repeated: - if primitive and not isinstance(primitive, (list, Repeated, RepeatedComposite,)): + if primitive and not isinstance(primitive, (list, tuple, Repeated, RepeatedComposite,)): self._throw_type_error(primitive) - if primitive: + + # Typically, checking a sequence's length after checking its Truthiness is + # unnecessary, but this will make us safe from any unexpected behavior from + # `Repeated` or `RepeatedComposite`. + if primitive and len(primitive) > 0: primitive = primitive[0] _type = type(primitive) - if proto_type == proto.DOUBLE and _type != float: + if proto_type == proto.BOOL and _type != bool: + self._throw_type_error(primitive) + elif proto_type == proto.STRING and _type not in (str, bytes,): + self._throw_type_error(primitive) + elif proto_type == proto.DOUBLE and _type != float: self._throw_type_error(primitive) elif proto_type == proto.FLOAT and _type != float: self._throw_type_error(primitive) @@ -169,10 +177,6 @@ def validate_primitives(self, field, primitive): self._throw_type_error(primitive) elif proto_type == proto.FIXED32 and _type != int: self._throw_type_error(primitive) - elif proto_type == proto.BOOL and _type != bool: - self._throw_type_error(primitive) - elif proto_type == proto.STRING and _type not in (str, bytes,): - self._throw_type_error(primitive) elif proto_type == proto.MESSAGE: # Do nothing - this is not a primitive pass diff --git a/proto/message.py b/proto/message.py index 1200cb78..c179bbb5 100644 --- a/proto/message.py +++ b/proto/message.py @@ -17,7 +17,7 @@ import copy import inspect import re -from typing import Callable, List, Optional, Type +from typing import List, Set, Type from google.protobuf import descriptor_pb2 from google.protobuf import message @@ -37,8 +37,6 @@ def _has_contribute_to_class(value): # Only call contribute_to_class() if it's bound. return not inspect.isclass(value) and hasattr(value, 'contribute_to_class') -class MapValueMessage: - isMapValue: bool = True class MessageMeta(type): """A metaclass for building and registering Message subclasses.""" @@ -483,7 +481,7 @@ def __init__(self, mapping=None, *, ignore_unknown_fields=False, **kwargs): # Tracks any new values have had their serialization deferred, and thus whether we # are ready to serialize immediately, or need to sync from the instance back to the # underlying `_pb` - self._stale_fields: List[str] = [] + self._stale_fields: Set[str] = set() # We accept several things for `mapping`: @@ -579,7 +577,7 @@ def __contains__(self, key): bool: Whether the field's value corresponds to a non-empty wire serialization. """ - if key in getattr(self, '_stale_fields', []): + if key in getattr(self, '_stale_fields', set()): return True pb_value = getattr(self._pb, key) try: @@ -631,11 +629,11 @@ def _mark_pb_stale(self, field_name: str): fields are stale in this way. """ if not hasattr(self, '_stale_fields'): - self._stale_fields = [] - self._stale_fields.append(field_name) + self._stale_fields = set() + self._stale_fields.add(field_name) def _mark_pb_synced(self): - self._stale_fields = [] + self._stale_fields = set() def _update_nested_pb(self): """When it is time to serialize a pb2 object, it does not do to sync just ourselves - @@ -654,7 +652,7 @@ def _update_pb(self): """Loops over any stale fields and syncs them to the underlying pb2 object. """ merge_params = {} - for field_name in getattr(self, '_stale_fields', []): + for field_name in getattr(self, '_stale_fields', set()): wrapper_value = getattr(self, field_name, None) field = self._meta.fields[field_name] diff --git a/tests/test_fields_map_composite.py b/tests/test_fields_map_composite.py index 13bf0c5f..601cfd64 100644 --- a/tests/test_fields_map_composite.py +++ b/tests/test_fields_map_composite.py @@ -31,6 +31,23 @@ class Baz(proto.Message): assert "k" not in baz.foos +def test_composite_map_with_stale_fields(): + class Foo(proto.Message): + bar = proto.Field(proto.INT32, number=1) + + class Baz(proto.Message): + foos = proto.MapField(proto.STRING, proto.MESSAGE, number=1, message=Foo,) + name = proto.Field(proto.STRING, number=2) + + baz = Baz(foos={"i": Foo(bar=42), "j": Foo(bar=24)}) + baz.name = "New Value" + assert "name" in baz._stale_fields + foo = baz.foos["i"] + foo.bar = 100 + assert Baz.to_dict(baz)["name"] == "New Value" + assert Baz.to_dict(baz)["foos"]["i"]["bar"] == 100 + + def test_composite_map_dict(): class Foo(proto.Message): bar = proto.Field(proto.INT32, number=1) From d723f18b92a49244a0c3609506ff77b92a5ec4b6 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Mon, 24 May 2021 14:43:25 -0700 Subject: [PATCH 7/7] responded to more code review --- proto/fields.py | 49 +++++++++++++++--------- proto/marshal/collections/maps.py | 63 +++++++++++++++++++++++++++---- proto/message.py | 8 ++-- tests/test_fields_composite.py | 1 + 4 files changed, 92 insertions(+), 29 deletions(-) diff --git a/proto/fields.py b/proto/fields.py index 9414b317..b0d9a742 100644 --- a/proto/fields.py +++ b/proto/fields.py @@ -18,6 +18,7 @@ from proto.datetime_helpers import DatetimeWithNanoseconds from proto.marshal.rules import wrappers from proto.marshal.rules.dates import DurationRule +from proto.marshal.collections.maps import MapComposite from proto.utils import cached_property from typing import Any, Callable, Optional, Union @@ -153,16 +154,11 @@ def pb_type(self): return self.message @property - def can_get_natively(self) -> bool: - if self.proto_type == ProtoType.MESSAGE and self.message == struct_pb2.Value: - return False - return True - - @property - def can_set_natively(self) -> bool: - if self.proto_type == ProtoType.MESSAGE and self.message == struct_pb2.Value: - return False - return True + def can_represent_natively(self) -> bool: + return not ( + self.proto_type == ProtoType.MESSAGE and + self.message == struct_pb2.Value + ) def contribute_to_class(self, cls, name: str): """Attaches a descriptor to the top-level proto.Message class, so that attribute @@ -288,6 +284,9 @@ def __init__(self, name: str, *, cls, set_coercion: Optional[Callable] = None): # simple types coercion for setting attributes # (e.g., bytes -> str if our type is string, but we are supplied bytes) + # the signature of `set_coercion` is dependent on the field's data types + # and is always handled by `contribute_to_class` which pairs data types + # to appropriate write-time coercions. self._set_coercion: Optional[Callable] = set_coercion self.cls = cls @@ -349,7 +348,7 @@ class MyMessage(proto.Message): # instances receive attribute updates, immediately syncing those values to the underlying # pb2 instance is sufficient. always_commit: bool = getattr(instance, '_always_commit', False) - if always_commit or not self.field.can_set_natively: + if always_commit or not self.field.can_represent_natively: pb_value = instance._meta.marshal.to_proto(self.field.pb_type, value) _pb = instance._meta.pb(**{self.original_name: pb_value}) instance._pb.ClearField(self.original_name) @@ -379,7 +378,8 @@ class MyMessage(proto.Message): my_message = MyMessage(name="Frodo") print(my_message.name) - In the above scenario, `__get__` is called with "my_message" passed as `instance`. + In the above scenario, `__get__` is called with "my_message" passed as + `instance`. """ # If `instance` is None, then we are accessing this field directly # off the class itself instead of off an instance. @@ -387,18 +387,31 @@ class MyMessage(proto.Message): return self.original_name value = getattr(instance, self.instance_attr_name, _none) - if self.field.can_get_natively and value is not _none: + is_map: bool = isinstance(value, MapComposite) + + # Return any values that do not require immediate rehydration. + # A few notes: + # * primitives are simple, and so can be returned + # * `Messages` are already Pythonic, and so can be returned + # * `Values` are wrappers and so have to be unwrapped + # * The exception to this is MapComposites, which have the same + # types as Values, but which handle their own field caching and + # thus can be returned when pulled off the `instance`. + if value is not _none and (is_map or self.field.can_represent_natively): return value - # For the most part, only primitive values can be returned natively, meaning - # this is either a Message itself, in which case, since we're dealing with the - # underlying pb object, we need to sync all deferred fields. - # This is functionally a no-op if no fields have been deferred. + # For the most part, only primitive values can be returned natively, + # meaning this is either a Message itself, in which case, since we're + # dealing with the underlying pb object, we need to sync all deferred + # fields. This is functionally a no-op if no fields have been deferred. if hasattr(value, '_update_pb'): value._update_pb() pb_value = getattr(instance._pb, self.original_name, None) - value = instance._meta.marshal.to_python(self.field.pb_type, pb_value, absent=self.original_name not in instance) + value = instance._meta.marshal.to_python( + self.field.pb_type, pb_value, + absent=self.original_name not in instance, + ) setattr(instance, self.instance_attr_name, value) return value diff --git a/proto/marshal/collections/maps.py b/proto/marshal/collections/maps.py index 722a34be..ecf928bb 100644 --- a/proto/marshal/collections/maps.py +++ b/proto/marshal/collections/maps.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +from typing import Dict, Set import proto from proto.utils import cached_property @@ -41,6 +42,8 @@ def __init__(self, sequence, *, marshal): """ self._pb = sequence self._marshal = marshal + self._item_cache: Dict = {} + self._stale_keys: Set[str] = set() def __contains__(self, key): # Protocol buffers is so permissive that querying for the existence @@ -48,19 +51,50 @@ def __contains__(self, key): # # By taking a tuple of the keys and querying that, we avoid sending # the lookup to protocol buffers and therefore avoid creating the key. + if key in self._item_cache: + return True return key in tuple(self.keys()) def __getitem__(self, key): # We handle raising KeyError ourselves, because otherwise protocol # buffers will create the key if it does not exist. - if key not in self: - raise KeyError(key) - obj = self._marshal.to_python(self._pb_type, self.pb[key]) - if isinstance(obj, proto.Message): - obj._always_commit = True - return obj + value = self._item_cache.get(key, _Empty.shared) + + if isinstance(value, _Empty): + if key not in self: + raise KeyError(key) + value = self._marshal.to_python(self._pb_type, self.pb[key]) + + # This is the first domino in a hacky workaround that is completed + # in `fields._FieldDescriptor.__set__`. Because of the by-value nature + # of protobufs (which conflicts with the by-reference nature of Python), + # proto-plus objects that are yielded from MapFields must immediately + # write to their internal pb2 object whenever their fields are updated. + # This is a new requirement as always writing to a proto-plus object's + # inner pb2 protobuf used to be the default, but has been moved to a + # lazy-syncing system for performance reasons. + if isinstance(value, proto.Message): + value._always_commit = True + + self._item_cache[key] = value + + return value def __setitem__(self, key, value): + self._item_cache[key] = value + self._stale_keys.add(key) + # self._sync_key(key, value) + + def _sync_all_keys(self): + for key in self._stale_keys: + self._sync_key(key) + self._stale_keys.clear() + + def _sync_key(self, key, value = None): + value = value or self._item_cache.pop(key, None) + if value is None: + self.pb.pop(key) + return pb_value = self._marshal.to_proto(self._pb_type, value, strict=True) # Directly setting a key is not allowed; however, protocol buffers @@ -73,14 +107,27 @@ def __setitem__(self, key, value): self.pb[key].MergeFrom(pb_value) def __delitem__(self, key): - self.pb.pop(key) + self._item_cache.pop(key, None) + self._stale_keys.add(key) + # self.pb.pop(key) def __len__(self): - return len(self.pb) + _all_keys = set(list(self._item_cache.keys())) + _all_keys = _all_keys.union(list(self.pb.keys())) + return len(_all_keys) + # return len(self.pb) def __iter__(self): + self._sync_all_keys() return iter(self.pb) @property def pb(self): return self._pb + + +class _Empty: + pass + + +_Empty.shared = _Empty() \ No newline at end of file diff --git a/proto/message.py b/proto/message.py index c179bbb5..2eb09ada 100644 --- a/proto/message.py +++ b/proto/message.py @@ -483,7 +483,6 @@ def __init__(self, mapping=None, *, ignore_unknown_fields=False, **kwargs): # underlying `_pb` self._stale_fields: Set[str] = set() - # We accept several things for `mapping`: # * An instance of this class. # * An instance of the underlying protobuf descriptor class. @@ -592,7 +591,6 @@ def __contains__(self, key): except ValueError: return bool(pb_value) - def __eq__(self, other): """Return True if the messages are equal, False otherwise.""" # If these are the same type, use internal protobuf's equality check. @@ -633,7 +631,8 @@ def _mark_pb_stale(self, field_name: str): self._stale_fields.add(field_name) def _mark_pb_synced(self): - self._stale_fields = set() + if hasattr(self, '_stale_fields'): + self._stale_fields.clear() def _update_nested_pb(self): """When it is time to serialize a pb2 object, it does not do to sync just ourselves - @@ -643,8 +642,11 @@ def _update_nested_pb(self): for field_name, field in self._meta.fields.items(): if field.proto_type == proto.MESSAGE: obj = getattr(self, field_name, None) + # Traversing the tree of objects eventually encounters primitives + # that do not have these functions if obj and hasattr(obj, '_update_pb'): obj._update_pb() + # Same here - many values will not have these methods if obj and hasattr(obj, '_update_nested_pb'): obj._update_nested_pb() diff --git a/tests/test_fields_composite.py b/tests/test_fields_composite.py index 60d29e12..3e0bb736 100644 --- a/tests/test_fields_composite.py +++ b/tests/test_fields_composite.py @@ -24,6 +24,7 @@ class Spam(proto.Message): foo = proto.Field(proto.MESSAGE, number=1, message=Foo) eggs = proto.Field(proto.BOOL, number=2) + spam = Spam(foo=Foo(bar="str", baz=42)) assert spam.foo.bar == "str" assert spam.foo.baz == 42