Skip to content

Commit

Permalink
Remove RegisterExtension() in python generated code
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 530452253
  • Loading branch information
anandolee authored and copybara-github committed May 9, 2023
1 parent 4c79444 commit e5a7a2e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 45 deletions.
9 changes: 7 additions & 2 deletions python/google/protobuf/descriptor_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,12 +805,17 @@ def _ConvertFileProtoToFileDescriptor(self, file_proto):
self._file_descriptors[file_proto.name] = file_descriptor

# Add extensions to the pool
def AddExtensionForNested(message_type):
for nested in message_type.nested_types:
AddExtensionForNested(nested)
for extension in message_type.extensions:
self._AddExtensionDescriptor(extension)

file_desc = self._file_descriptors[file_proto.name]
for extension in file_desc.extensions_by_name.values():
self._AddExtensionDescriptor(extension)
for message_type in file_desc.message_types_by_name.values():
for extension in message_type.extensions:
self._AddExtensionDescriptor(extension)
AddExtensionForNested(message_type)

return file_desc

Expand Down
23 changes: 23 additions & 0 deletions python/google/protobuf/internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
__author__ = '[email protected] (Jie Luo)'

from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import python_message
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
Expand Down Expand Up @@ -108,6 +109,28 @@ def BuildMessage(msg_des):
module[name] = BuildMessage(msg_des)


def AddHelpersToExtensions(file_des):
"""Adds field helpers to extensions.
Args:
file_des: FileDescriptor of the .proto file
"""
def AddHelpersToExtension(extension):
python_message._AttachFieldHelpers(
extension.containing_type._concrete_class, extension)

def AddHelpersToNestedExtensions(msg_des):
for nested_type in msg_des.nested_types:
AddHelpersToNestedExtensions(nested_type)
for extension in msg_des.extensions:
AddHelpersToExtension(extension)

for extension in file_des.extensions_by_name.values():
AddHelpersToExtension(extension)
for message_type in file_des.message_types_by_name.values():
AddHelpersToNestedExtensions(message_type)


def BuildServices(file_des, module_name, module):
"""Builds services classes and services stub class.
Expand Down
52 changes: 13 additions & 39 deletions src/google/protobuf/compiler/python/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,11 @@ bool Generator::Generate(const FileDescriptor* file,
printer.Print("if _descriptor._USE_C_DESCRIPTORS == False:\n");
printer_->Indent();

// We have to fix up the extensions after the message classes themselves,
// since they need to call static RegisterExtension() methods on these
// classes.
FixForeignFieldsInExtensions();
// We have to fix up the extensions after the message classes themselves
if (HasExtensions()) {
printer.Print("_builder.AddHelpersToExtensions(DESCRIPTOR)\n");
}

// Descriptor options may have custom extensions. These custom options
// can only be successfully parsed after we register corresponding
// extensions. Therefore we parse all options again here to recognize
Expand Down Expand Up @@ -1006,47 +1007,20 @@ void Generator::FixForeignFieldsInDescriptors() const {
printer_->Print("\n");
}

// We need to not only set any necessary message_type fields, but
// also need to call RegisterExtension() on each message we're
// extending.
void Generator::FixForeignFieldsInExtensions() const {
// Top-level extensions.
for (int i = 0; i < file_->extension_count(); ++i) {
FixForeignFieldsInExtension(*file_->extension(i));
}
// Nested extensions.
bool Generator::HasExtensions() const {
if (file_->extension_count() > 0) return true;
for (int i = 0; i < file_->message_type_count(); ++i) {
FixForeignFieldsInNestedExtensions(*file_->message_type(i));
if (HasExtensionsInMessage(*file_->message_type(i))) return true;
}
printer_->Print("\n");
}

void Generator::FixForeignFieldsInExtension(
const FieldDescriptor& extension_field) const {
ABSL_CHECK(extension_field.is_extension());

absl::flat_hash_map<absl::string_view, std::string> m;
// Confusingly, for FieldDescriptors that happen to be extensions,
// containing_type() means "extended type."
// On the other hand, extension_scope() will give us what we normally
// mean by containing_type().
m["extended_message_class"] =
ModuleLevelMessageName(*extension_field.containing_type());
m["field"] = FieldReferencingExpression(
extension_field.extension_scope(), extension_field, "extensions_by_name");
printer_->Print(m, "$extended_message_class$.RegisterExtension($field$)\n");
return false;
}

void Generator::FixForeignFieldsInNestedExtensions(
const Descriptor& descriptor) const {
// Recursively fix up extensions in all nested types.
bool Generator::HasExtensionsInMessage(const Descriptor& descriptor) const {
if (descriptor.extension_count() > 0) return true;
for (int i = 0; i < descriptor.nested_type_count(); ++i) {
FixForeignFieldsInNestedExtensions(*descriptor.nested_type(i));
}
// Fix up extensions directly contained within this type.
for (int i = 0; i < descriptor.extension_count(); ++i) {
FixForeignFieldsInExtension(*descriptor.extension(i));
if (HasExtensionsInMessage(*descriptor.nested_type(i))) return true;
}
return false;
}

// Returns a Python expression that instantiates a Python EnumValueDescriptor
Expand Down
6 changes: 2 additions & 4 deletions src/google/protobuf/compiler/python/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,8 @@ class PROTOC_EXPORT Generator : public CodeGenerator {
const DescriptorT& descriptor,
const Descriptor* containing_descriptor) const;

void FixForeignFieldsInExtensions() const;
void FixForeignFieldsInExtension(
const FieldDescriptor& extension_field) const;
void FixForeignFieldsInNestedExtensions(const Descriptor& descriptor) const;
bool HasExtensions() const;
bool HasExtensionsInMessage(const Descriptor& descriptor) const;

void PrintTopBoilerplate() const;
void PrintServices() const;
Expand Down

0 comments on commit e5a7a2e

Please sign in to comment.