Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Protobuf Performance Refactor #230

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 253 additions & 0 deletions proto/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from enum import EnumMeta
from re import L
from proto.marshal.rules import wrappers
from proto.datetime_helpers import DatetimeWithNanoseconds
from proto.marshal.rules.dates import DurationRule
from typing import Any, Callable, Optional, Union

from google.protobuf import descriptor_pb2
from google.protobuf import duration_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import wrappers_pb2
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper

import proto
from proto.primitives import ProtoType


Expand Down Expand Up @@ -140,6 +151,101 @@ def pb_type(self):
return self.message.pb()
return self.message

@property
def can_get_natively(self) -> bool:
if self.proto_type == ProtoType.MESSAGE and self.message == struct_pb2.Value:
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
return False
return True
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved

def can_set_natively(self, val: Any) -> bool:
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
if self.proto_type == ProtoType.MESSAGE and self.message == struct_pb2.Value:
return False
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
return True
# return self.pb_type is None and not self.repeated
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved

def contribute_to_class(self, cls, name: str):
"""Attaches a descriptor to the top-level proto.Message class, so that attribute
reads and writes can be specially handled in `_FieldDescriptor.__get__` and
`_FieldDescriptor.__set__`.

Also contains hooks for write-time type-coercion to translate special cases between
pure Pythonic objects and pb2-compatible structs or values.
"""
set_coercion = None
if self.proto_type == ProtoType.STRING:
# Bytes are accepted for string values, but strings are not accepted for byte values.
# This is an artifact of older Python2 implementations.
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
set_coercion = self._bytes_to_str
if self.pb_type == timestamp_pb2.Timestamp:
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
set_coercion = self._timestamp_to_datetime
if self.proto_type == ProtoType.MESSAGE and self.message == duration_pb2.Duration:
set_coercion = self._duration_to_timedelta
if self.proto_type == ProtoType.MESSAGE and self.message == wrappers_pb2.BoolValue:
set_coercion = self._bool_value_to_bool
if self.enum:
set_coercion = self._literal_to_enum
setattr(cls, name, _FieldDescriptor(name, cls=cls, set_coercion=set_coercion))

@property
def reverse_enum_map(self):
    """Map each enum member's value to the member itself.

    Lets a raw literal (often an int) supplied in place of an enum member be
    resolved with a single dict lookup. The mapping is built on first access
    and stored on ``self._reverse_enum_map``.

    Returns ``None`` when this field has no (truthy) ``enum``.
    """
    if self.enum:
        cached = getattr(self, '_reverse_enum_map', None)
        if cached:
            return cached
        # Nothing cached yet (or the cached mapping is empty): build it now.
        self._reverse_enum_map = {member.value: member for member in self.enum}
        return self._reverse_enum_map
    return None
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved

@property
def reverse_enum_names_map(self):
    """Map each enum member's name to the member itself.

    Lets the string name of an enum member supplied in place of the member be
    resolved with a single dict lookup. The mapping is built on first access
    and stored on ``self._reverse_enum_names_map``.

    Returns ``None`` when this field has no (truthy) ``enum``.
    """
    if self.enum:
        cached = getattr(self, '_reverse_enum_names_map', None)
        if cached:
            return cached
        # Nothing cached yet (or the cached mapping is empty): build it now.
        self._reverse_enum_names_map = {member.name: member for member in self.enum}
        return self._reverse_enum_names_map
    return None

def _literal_to_enum(self, val: Any):
    """Resolve ``val`` to a member of ``self.enum``.

    Args:
        val: An existing member of ``self.enum``, a member's value, or a
            member's name.

    Returns:
        The matching enum member, or ``None`` when ``val`` matches neither a
        value nor a name.
    """
    if isinstance(val, self.enum):
        return val

    # Compare against None explicitly rather than chaining with `or`: an
    # int-backed member whose value is 0 is falsy, so `a or b` would discard
    # a correctly resolved member and fall through to the name lookup.
    by_value = self.reverse_enum_map.get(val)
    if by_value is not None:
        return by_value
    return self.reverse_enum_names_map.get(val)

@staticmethod
def _bytes_to_str(val: Union[bytes, str]) -> str:
if type(val) == bytes:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably worth profiling here vs. isinstance(val, bytes), and maybe type(val) is bytes. E.g.:

$ python3.9 -m timeit "isinstance('', bytes)"
2000000 loops, best of 5: 127 nsec per loop

$ python3.9 -m timeit "type('') == bytes"
2000000 loops, best of 5: 129 nsec per loop

$ python3.9 -m timeit "type('') is bytes"
2000000 loops, best of 5: 123 nsec per loop

and similarly for the other conversions below.

val = val.decode('utf-8')
return val

@staticmethod
def _timestamp_to_datetime(val: Union[timestamp_pb2.Timestamp, datetime.datetime]) -> datetime.datetime:
    """Write-time coercion for Timestamp fields.

    A value whose type is exactly ``timestamp_pb2.Timestamp`` is converted with
    ``DatetimeWithNanoseconds.from_timestamp_pb``; any other value is returned
    unchanged.
    """
    # Exact type comparison: only plain Timestamp messages are converted.
    if type(val) == timestamp_pb2.Timestamp:
        return DatetimeWithNanoseconds.from_timestamp_pb(val)
    return val

@staticmethod
def _duration_to_timedelta(val: Union[duration_pb2.Duration, datetime.timedelta]) -> datetime.timedelta:
    """Write-time coercion for Duration fields.

    A value whose type is exactly ``duration_pb2.Duration`` is passed through
    ``DurationRule().to_python``; any other value is returned unchanged.
    """
    # Exact type comparison: only plain Duration messages are converted.
    if type(val) == duration_pb2.Duration:
        # presumably yields a datetime.timedelta -- confirm against DurationRule
        val = DurationRule().to_python(val)
    return val

@staticmethod
def _bool_value_to_bool(val: Union[wrappers_pb2.BoolValue, bool]) -> Optional[bool]:
    """Write-time coercion for BoolValue wrapper fields.

    A value whose type is exactly ``wrappers_pb2.BoolValue`` is unwrapped to
    its ``.value``. ``None`` and every other value are returned unchanged.
    """
    if val is not None and type(val) == wrappers_pb2.BoolValue:
        return val.value
    return val


class RepeatedField(Field):
"""A representation of a repeated field in protocol buffers."""
Expand All @@ -155,6 +261,153 @@ def __init__(self, key_type, value_type, *, number: int, message=None, enum=None
self.map_key_type = key_type


class _FieldDescriptor:
"""Handler for proto.Field access on any proto.Message object.

Wraps each proto.Field instance within a given proto.Message subclass's definition
with getters and setters that allow for caching of values on the proto-plus object,
deferment of syncing to the underlying pb2 object, and tracking of the current state.

Special treatment is given to MapFields, nested Messages, and certain data types, as
their various implementations within pb2 (which for our purposes is mostly a black box)
sometimes mandate immediate syncing. This is usually because proto-plus objects are not
long-lived, and thus information about which fields are stale would be lost if syncing
was left for serialization time.
"""
def __init__(self, name: str, *, cls, set_coercion: Optional[Callable] = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make it more clear what signature set_coercion is expected to have? What types does it take and return?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, from usage it looks like it takes one argument and returns one value, with the types dependent on the field type. Can you clarify that in a comment?
Also, what do you think about providing an empty default instead of checking set_coercion == None?

, set_coercion: Callable = lambda v: v)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There used to be a no-op default, but @tseaver opined that this pattern might be cleaner. Either way, expanding on the comments.

# Something like "id". Required whenever we reach back to the pb2 object.
self.original_name = name
# something like "_cached_id"
self.instance_attr_name = f'_cached_fields__{name}'
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved

# simple types coercion for setting attributes (for example, bytes -> str)
self._set_coercion = set_coercion or _FieldDescriptor._noop
self.cls = cls

@property
def field(self):
    """Return the field registered under ``self.original_name`` on ``self.cls``.

    The lookup goes through ``self.cls._meta.fields`` on every access; a
    missing name raises ``KeyError``.
    """
    declared_fields = self.cls._meta.fields
    return declared_fields[self.original_name]

def _hydrate_dicts(self, value: Any):
    """Promote a dict assigned to a nested Message field into that Message.

    Non-dict values, and dicts assigned to fields whose ``proto_type`` is not
    ``proto.MESSAGE``, are returned untouched. Otherwise the dict is expanded
    as keyword arguments into the field message's pb constructor and the
    resulting pb object is wrapped via ``message.wrap``.
    """
    if isinstance(value, dict) and self.field.proto_type == proto.MESSAGE:
        message_cls = self.field.message
        raw_pb = message_cls._meta.pb(**value)
        return message_cls.wrap(raw_pb)
    return value

def _clear_oneofs(self, instance):
    """Delete every sibling attribute that shares this field's oneof group.

    Does nothing when this descriptor's field has no (truthy) ``oneof``.
    Otherwise each other field on ``self.cls`` whose ``oneof`` equals this
    field's is removed from ``instance`` with ``delattr``; this field itself
    and fields in a different (or no) oneof group are left alone.
    """
    group = self.field.oneof
    if not group:
        return

    for sibling_name, sibling in self.cls._meta.fields.items():
        is_other_field = sibling_name != self.original_name
        if is_other_field and sibling.oneof == group:
            delattr(instance, sibling_name)

def __set__(self, instance, value):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Call instance message instead? You seemed to prefer that below.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An interesting idea. It seems to be convention when using attribute descriptors like this to name the variable instance, so I think I prefer that here, but I could be convinced otherwise.

"""Called whenever a value is assigned to a proto.Field attribute on an instantiated
proto.Message object.

Usage:

class MyMessage(proto.Message):
name = proto.Field(proto.STRING, number=1)

my_message = MyMessage()
my_message.name = "Frodo"

In the above scenario, `__set__` is called with "Frodo" passed as `value` and `my_message`
passed as `instance`.
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
"""
value = self._set_coercion(value)
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
value = self._hydrate_dicts(value)

# Warning: `always_commit` is hacky!
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the comment says, this is hacky. Sadly, it is a linchpin for the whole refactor.

# Some contexts, particularly instances created from MapFields, require immediate syncing.
# It is impossible to deduce such a scenario purely from logic available to this function,
# so instead we set a flag on instances when a MapField yields them, and then when those
# instances receive attribute updates, immediately syncing those values to the underlying
# pb2 instance is sufficient.
always_commit: bool = getattr(instance, '_always_commit', False)
if always_commit or not self.field.can_set_natively(value):
pb_value = instance._meta.marshal.to_proto(self.field.pb_type, value)
_pb = instance._meta.pb(**{self.original_name: pb_value})
instance._pb.ClearField(self.original_name)
instance._pb.MergeFrom(_pb)
else:

if value is None:
self.__delete__(instance)
instance._pb.MergeFrom(instance._meta._pb(**{self.original_name: None}))
return

instance._meta.marshal.validate_primitives(self.field, value)
instance._mark_pb_stale(self.original_name)

setattr(instance, self.instance_attr_name, value)
self._clear_oneofs(instance)

def __get__(self, instance: 'proto.Message', _): # type: ignore
"""Called whenever a value is read from a proto.Field attribute on an instantiated
proto.Message object.

Usage:

class MyMessage(proto.Message):
name = proto.Field(proto.STRING, number=1)

my_message = MyMessage(name="Frodo")
print(my_message.name)

In the above scenario, `__get__` is called with "my_message" passed as `instance`.
"""
# If `instance` is None, then we are accessing this field directly
# off the class itself instead of off an instance.
if instance is None:
return self.original_name

value = getattr(instance, self.instance_attr_name, None)
if self.field.can_get_natively and value is not None:
return value
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
else:
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
# For the most part, only primitive values can be returned natively, so a value
# reaching this branch is probably a Message itself. Since we're dealing with the
# underlying pb object, we need to sync all deferred fields.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is a little confusing. The 'either' has no second branch.

# This is functionally a no-op if no fields have been deferred.
if hasattr(value, '_update_pb'):
value._update_pb()

pb_value = getattr(instance._pb, self.original_name, None)
value = instance._meta.marshal.to_python(self.field.pb_type, pb_value, absent=self.original_name not in instance)

setattr(instance, self.instance_attr_name, value)
return value

def __delete__(self, instance):
    """Clear this field on ``instance``.

    Removes the cached attribute from ``instance`` if present, clears the
    field on the underlying ``instance._pb``, and removes the field name from
    ``instance._stale_fields`` when it is listed there.
    """
    cache_attr = self.instance_attr_name
    if hasattr(instance, cache_attr):
        delattr(instance, cache_attr)

    field_name = self.original_name
    instance._pb.ClearField(field_name)

    # Instances without a `_stale_fields` attribute are treated as having none.
    if field_name in getattr(instance, '_stale_fields', []):
        instance._stale_fields.remove(field_name)

@staticmethod
def _noop(val):
    """Identity coercion: hand ``val`` back untouched.

    Used as the fallback when a descriptor is created without a
    ``set_coercion`` callable.
    """
    return val


__all__ = (
"Field",
"MapField",
Expand Down
8 changes: 6 additions & 2 deletions proto/marshal/collections/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import collections

import proto
from proto.utils import cached_property


Expand All @@ -28,7 +29,7 @@ class MapComposite(collections.abc.MutableMapping):
def _pb_type(self):
"""Return the protocol buffer type for this sequence."""
# Huzzah, another hack. Still less bad than RepeatedComposite.
return type(self.pb.GetEntryClass()().value)
return type(self.self.pb.GetEntryClass()().value)
tseaver marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, sequence, *, marshal):
"""Initialize a wrapper around a protobuf map.
Expand All @@ -54,7 +55,10 @@ def __getitem__(self, key):
# buffers will create the key if it does not exist.
if key not in self:
raise KeyError(key)
return self._marshal.to_python(self._pb_type, self.pb[key])
obj = self._marshal.to_python(self._pb_type, self.pb[key])
if isinstance(obj, proto.Message):
obj._always_commit = True
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
return obj

def __setitem__(self, key, value):
pb_value = self._marshal.to_proto(self._pb_type, value, strict=True)
Expand Down
Loading