Skip to content

Commit

Permalink
Soft deprecate python MessageFactory
Browse files Browse the repository at this point in the history
Soft deprecate python MessageFactory. Added new replacement APIs GetMessageClass(descriptor) and GetMessagesFromFiles(files, pool)

PiperOrigin-RevId: 501802633
  • Loading branch information
anandolee authored and copybara-github committed Jan 13, 2023
1 parent f95aafd commit c80e7ef
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 103 deletions.
3 changes: 1 addition & 2 deletions python/google/protobuf/internal/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,8 +806,7 @@ def DecodeItem(buffer, pos, end, message, field_dict):
if value is None:
message_type = extension.message_type
if not hasattr(message_type, '_concrete_class'):
# pylint: disable=protected-access
message._FACTORY.GetPrototype(message_type)
message_factory.GetMessageClass(message_type)
value = field_dict.setdefault(
extension, message_type._concrete_class())
if value._InternalParse(buffer, message_start,message_end) != message_end:
Expand Down
5 changes: 3 additions & 2 deletions python/google/protobuf/internal/extension_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ def __getitem__(self, extension_handle):
elif extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
message_type = extension_handle.message_type
if not hasattr(message_type, '_concrete_class'):
# pylint: disable=protected-access
self._extended_message._FACTORY.GetPrototype(message_type)
# pylint: disable=g-import-not-at-top
from google.protobuf import message_factory
message_factory.GetMessageClass(message_type)
assert getattr(extension_handle.message_type, '_concrete_class', None), (
'Uninitialized concrete class found for field %r (message type %r)'
% (extension_handle.full_name,
Expand Down
40 changes: 10 additions & 30 deletions python/google/protobuf/internal/message_factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,36 +92,17 @@ def testGetPrototype(self):
pool = descriptor_pool.DescriptorPool(db)
db.Add(self.factory_test1_fd)
db.Add(self.factory_test2_fd)
factory = message_factory.MessageFactory()
cls = factory.GetPrototype(pool.FindMessageTypeByName(
cls = message_factory.GetMessageClass(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message'))
self.assertFalse(cls is factory_test2_pb2.Factory2Message)
self._ExerciseDynamicClass(cls)
cls2 = factory.GetPrototype(pool.FindMessageTypeByName(
cls2 = message_factory.GetMessageClass(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message'))
self.assertTrue(cls is cls2)

def testCreatePrototypeOverride(self):
class MyMessageFactory(message_factory.MessageFactory):

def CreatePrototype(self, descriptor):
cls = super(MyMessageFactory, self).CreatePrototype(descriptor)
cls.additional_field = 'Some value'
return cls

db = descriptor_database.DescriptorDatabase()
pool = descriptor_pool.DescriptorPool(db)
db.Add(self.factory_test1_fd)
db.Add(self.factory_test2_fd)
factory = MyMessageFactory()
cls = factory.GetPrototype(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message'))
self.assertTrue(hasattr(cls, 'additional_field'))

def testGetExistingPrototype(self):
factory = message_factory.MessageFactory()
# Get Existing Prototype should not create a new class.
cls = factory.GetPrototype(
cls = message_factory.GetMessageClass(
descriptor=factory_test2_pb2.Factory2Message.DESCRIPTOR)
msg = factory_test2_pb2.Factory2Message()
self.assertIsInstance(msg, cls)
Expand Down Expand Up @@ -181,15 +162,14 @@ def testGetMessages(self):

def testDuplicateExtensionNumber(self):
pool = descriptor_pool.DescriptorPool()
factory = message_factory.MessageFactory(pool=pool)

# Add Container message.
f = descriptor_pb2.FileDescriptorProto(
name='google/protobuf/internal/container.proto',
package='google.protobuf.python.internal')
f.message_type.add(name='Container').extension_range.add(start=1, end=10)
pool.Add(f)
msgs = factory.GetMessages([f.name])
msgs = message_factory.GetMessageClassesForFiles([f.name], pool)
self.assertIn('google.protobuf.python.internal.Container', msgs)

# Extend container.
Expand All @@ -205,7 +185,7 @@ def testDuplicateExtensionNumber(self):
type_name='Extension',
extendee='Container')
pool.Add(f)
msgs = factory.GetMessages([f.name])
msgs = message_factory.GetMessageClassesForFiles([f.name], pool)
self.assertIn('google.protobuf.python.internal.Extension', msgs)

# Add Duplicate extending the same field number.
Expand All @@ -223,7 +203,7 @@ def testDuplicateExtensionNumber(self):
pool.Add(f)

with self.assertRaises(Exception) as cm:
factory.GetMessages([f.name])
message_factory.GetMessageClassesForFiles([f.name], pool)

self.assertIn(str(cm.exception),
['Extensions '
Expand Down Expand Up @@ -281,8 +261,8 @@ def FindFileByName(self, name):
db = SimpleDescriptorDB({f1.name: f1, f2.name: f2, f3.name: f3})

pool = descriptor_pool.DescriptorPool(db)
factory = message_factory.MessageFactory(pool=pool)
msgs = factory.GetMessages([f1.name, f3.name]) # Deliberately not f2.
msgs = message_factory.GetMessageClassesForFiles(
[f1.name, f3.name], pool) # Deliberately not f2.
msg = msgs['google.protobuf.python.internal.Container']
desc = msgs['google.protobuf.python.internal.Extension'].DESCRIPTOR
ext1 = desc.file.extensions_by_name['top_level_extension_field']
Expand All @@ -293,8 +273,8 @@ def FindFileByName(self, name):
serialized = m.SerializeToString()

pool = descriptor_pool.DescriptorPool(db)
factory = message_factory.MessageFactory(pool=pool)
msgs = factory.GetMessages([f1.name, f3.name]) # Deliberately not f2.
msgs = message_factory.GetMessageClassesForFiles(
[f1.name, f3.name], pool) # Deliberately not f2.
msg = msgs['google.protobuf.python.internal.Container']
desc = msgs['google.protobuf.python.internal.Extension'].DESCRIPTOR
ext1 = desc.file.extensions_by_name['top_level_extension_field']
Expand Down
3 changes: 2 additions & 1 deletion python/google/protobuf/json_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
from google.protobuf import message_factory
from google.protobuf import symbol_database


Expand Down Expand Up @@ -409,7 +410,7 @@ def _CreateMessageFromTypeUrl(type_url, descriptor_pool):
raise TypeError(
'Can not find message descriptor by type_url: {0}'.format(type_url)
) from e
message_class = db.GetPrototype(message_descriptor)
message_class = message_factory.GetMessageClass(message_descriptor)
return message_class()


Expand Down
167 changes: 114 additions & 53 deletions python/google/protobuf/message_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

__author__ = '[email protected] (Matt Toia)'

import warnings

from google.protobuf.internal import api_implementation
from google.protobuf import descriptor_pool
from google.protobuf import message
Expand All @@ -53,6 +55,95 @@
_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType


def GetMessageClass(descriptor):
"""Obtains a proto2 message class based on the passed in descriptor.
Passing a descriptor with a fully qualified name matching a previous
invocation will cause the same class to be returned.
Args:
descriptor: The descriptor to build from.
Returns:
A class describing the passed in descriptor.
"""
concrete_class = getattr(descriptor, '_concrete_class', None)
if concrete_class:
return concrete_class
return _InternalCreateMessageClass(descriptor)


def GetMessageClassesForFiles(files, pool):
"""Gets all the messages from specified files.
This will find and resolve dependencies, failing if the descriptor
pool cannot satisfy them.
Args:
files: The file names to extract messages from.
pool: The descriptor pool to find the files including the dependent
files.
Returns:
A dictionary mapping proto names to the message classes.
"""
result = {}
for file_name in files:
file_desc = pool.FindFileByName(file_name)
for desc in file_desc.message_types_by_name.values():
result[desc.full_name] = GetMessageClass(desc)

# While the extension FieldDescriptors are created by the descriptor pool,
# the python classes created in the factory need them to be registered
# explicitly, which is done below.
#
# The call to RegisterExtension will specifically check if the
# extension was already registered on the object and either
# ignore the registration if the original was the same, or raise
# an error if they were different.

for extension in file_desc.extensions_by_name.values():
extended_class = GetMessageClass(extension.containing_type)
extended_class.RegisterExtension(extension)
# Recursively load protos for extension field, in order to be able to
# fully represent the extension. This matches the behavior for regular
# fields too.
if extension.message_type:
GetMessageClass(extension.message_type)
return result


def _InternalCreateMessageClass(descriptor):
"""Builds a proto2 message class based on the passed in descriptor.
Args:
descriptor: The descriptor to build from.
Returns:
A class describing the passed in descriptor.
"""
descriptor_name = descriptor.name
result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
descriptor_name,
(message.Message,),
{
'DESCRIPTOR': descriptor,
# If module not set, it wrongly points to message_factory module.
'__module__': None,
})
for field in descriptor.fields:
if field.message_type:
GetMessageClass(field.message_type)
for extension in result_class.DESCRIPTOR.extensions:
extended_class = GetMessageClass(extension.containing_type)
extended_class.RegisterExtension(extension)
if extension.message_type:
GetMessageClass(extension.message_type)
return result_class


# Deprecated. Please use GetMessageClass() or GetMessageClassesForFiles()
# method above instead.
class MessageFactory(object):
"""Factory for creating Proto2 messages from descriptors in a pool."""

Expand All @@ -72,44 +163,29 @@ def GetPrototype(self, descriptor):
Returns:
A class describing the passed in descriptor.
"""
concrete_class = getattr(descriptor, '_concrete_class', None)
if concrete_class:
return concrete_class
result_class = self.CreatePrototype(descriptor)
return result_class
# TODO(b/258832141): add this warning
# warnings.warn('MessageFactory class is deprecated. Please use '
# 'GetMessageClass() instead of MessageFactory.GetPrototype. '
# 'MessageFactory class will be removed after 2024.')
return GetMessageClass(descriptor)

def CreatePrototype(self, descriptor):
"""Builds a proto2 message class based on the passed in descriptor.
Don't call this function directly, it always creates a new class. Call
GetPrototype() instead. This method is meant to be overridden in subblasses
to perform additional operations on the newly constructed class.
GetMessageClass() instead.
Args:
descriptor: The descriptor to build from.
Returns:
A class describing the passed in descriptor.
"""
descriptor_name = descriptor.name
result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
descriptor_name,
(message.Message,),
{
'DESCRIPTOR': descriptor,
# If module not set, it wrongly points to message_factory module.
'__module__': None,
})
result_class._FACTORY = self # pylint: disable=protected-access
for field in descriptor.fields:
if field.message_type:
self.GetPrototype(field.message_type)
for extension in result_class.DESCRIPTOR.extensions:
extended_class = self.GetPrototype(extension.containing_type)
extended_class.RegisterExtension(extension)
if extension.message_type:
self.GetPrototype(extension.message_type)
return result_class
# TODO(b/258832141): add this warning
# warnings.warn('Directly call CreatePrototype is wrong. Please use '
# 'GetMessageClass() method instead. Directly use '
# 'CreatePrototype will raise error after July 2023.')
return _InternalCreateMessageClass(descriptor)

def GetMessages(self, files):
"""Gets all the messages from a specified file.
Expand All @@ -125,37 +201,20 @@ def GetMessages(self, files):
any dependent messages as well as any messages defined in the same file as
a specified message.
"""
result = {}
for file_name in files:
file_desc = self.pool.FindFileByName(file_name)
for desc in file_desc.message_types_by_name.values():
result[desc.full_name] = self.GetPrototype(desc)

# While the extension FieldDescriptors are created by the descriptor pool,
# the python classes created in the factory need them to be registered
# explicitly, which is done below.
#
# The call to RegisterExtension will specifically check if the
# extension was already registered on the object and either
# ignore the registration if the original was the same, or raise
# an error if they were different.

for extension in file_desc.extensions_by_name.values():
extended_class = self.GetPrototype(extension.containing_type)
extended_class.RegisterExtension(extension)
if extension.message_type:
self.GetPrototype(extension.message_type)
return result


_FACTORY = MessageFactory()
# TODO(b/258832141): add this warning
# warnings.warn('MessageFactory class is deprecated. Please use '
# 'GetMessageClassesForFiles() instead of '
# 'MessageFactory.GetMessages(). MessageFactory class '
# 'will be removed after 2024.')
return GetMessageClassesForFiles(files, self.pool)


def GetMessages(file_protos):
def GetMessages(file_protos, pool=None):
"""Builds a dictionary of all the messages available in a set of files.
Args:
file_protos: Iterable of FileDescriptorProto to build messages out of.
pool: The descriptor pool to add the file protos.
Returns:
A dictionary mapping proto names to the message classes. This will include
Expand All @@ -164,13 +223,15 @@ def GetMessages(file_protos):
"""
# The cpp implementation of the protocol buffer library requires to add the
# message in topological order of the dependency graph.
des_pool = pool or descriptor_pool.DescriptorPool()
file_by_name = {file_proto.name: file_proto for file_proto in file_protos}
def _AddFile(file_proto):
for dependency in file_proto.dependency:
if dependency in file_by_name:
# Remove from elements to be visited, in order to cut cycles.
_AddFile(file_by_name.pop(dependency))
_FACTORY.pool.Add(file_proto)
des_pool.Add(file_proto)
while file_by_name:
_AddFile(file_by_name.popitem()[1])
return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos])
return GetMessageClassesForFiles(
[file_proto.name for file_proto in file_protos], des_pool)
Loading

0 comments on commit c80e7ef

Please sign in to comment.