diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 733c9a6ac8dd..aa90cce0d12d 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -805,12 +805,17 @@ def _ConvertFileProtoToFileDescriptor(self, file_proto): self._file_descriptors[file_proto.name] = file_descriptor # Add extensions to the pool + def AddExtensionForNested(message_type): + for nested in message_type.nested_types: + AddExtensionForNested(nested) + for extension in message_type.extensions: + self._AddExtensionDescriptor(extension) + file_desc = self._file_descriptors[file_proto.name] for extension in file_desc.extensions_by_name.values(): self._AddExtensionDescriptor(extension) for message_type in file_desc.message_types_by_name.values(): - for extension in message_type.extensions: - self._AddExtensionDescriptor(extension) + AddExtensionForNested(message_type) return file_desc diff --git a/python/google/protobuf/internal/builder.py b/python/google/protobuf/internal/builder.py index 64353ee4af60..5beb21678363 100644 --- a/python/google/protobuf/internal/builder.py +++ b/python/google/protobuf/internal/builder.py @@ -38,6 +38,7 @@ __author__ = 'jieluo@google.com (Jie Luo)' from google.protobuf.internal import enum_type_wrapper +from google.protobuf.internal import python_message from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database @@ -108,6 +109,28 @@ def BuildMessage(msg_des): module[name] = BuildMessage(msg_des) +def AddHelpersToExtensions(file_des): + """Adds field helpers to extensions. + + Args: + file_des: FileDescriptor of the .proto file + """ + def AddHelpersToExtension(extension): + python_message._AttachFieldHelpers( + extension.containing_type._concrete_class, extension) + + def AddHelpersToNestedExtensions(msg_des): + for nested_type in msg_des.nested_types: + AddHelpersToNestedExtensions(nested_type) + for extension in msg_des.extensions: + AddHelpersToExtension(extension) + + for extension in file_des.extensions_by_name.values(): + AddHelpersToExtension(extension) + for message_type in file_des.message_types_by_name.values(): + AddHelpersToNestedExtensions(message_type) + + def BuildServices(file_des, module_name, module): """Builds services classes and services stub class. diff --git a/src/google/protobuf/compiler/python/generator.cc b/src/google/protobuf/compiler/python/generator.cc index 636808a3ae20..fc63c34d6b5b 100644 --- a/src/google/protobuf/compiler/python/generator.cc +++ b/src/google/protobuf/compiler/python/generator.cc @@ -331,10 +331,11 @@ bool Generator::Generate(const FileDescriptor* file, printer.Print("if _descriptor._USE_C_DESCRIPTORS == False:\n"); printer_->Indent(); - // We have to fix up the extensions after the message classes themselves, - // since they need to call static RegisterExtension() methods on these - // classes. - FixForeignFieldsInExtensions(); + // We have to fix up the extensions after the message classes themselves + if (HasExtensions()) { + printer.Print("_builder.AddHelpersToExtensions(DESCRIPTOR)\n"); + } + // Descriptor options may have custom extensions. These custom options // can only be successfully parsed after we register corresponding // extensions. Therefore we parse all options again here to recognize @@ -1006,47 +1007,20 @@ void Generator::FixForeignFieldsInDescriptors() const { printer_->Print("\n"); } -// We need to not only set any necessary message_type fields, but -// also need to call RegisterExtension() on each message we're -// extending. -void Generator::FixForeignFieldsInExtensions() const { - // Top-level extensions. - for (int i = 0; i < file_->extension_count(); ++i) { - FixForeignFieldsInExtension(*file_->extension(i)); - } - // Nested extensions. +bool Generator::HasExtensions() const { + if (file_->extension_count() > 0) return true; for (int i = 0; i < file_->message_type_count(); ++i) { - FixForeignFieldsInNestedExtensions(*file_->message_type(i)); + if (HasExtensionsInMessage(*file_->message_type(i))) return true; } - printer_->Print("\n"); -} - -void Generator::FixForeignFieldsInExtension( - const FieldDescriptor& extension_field) const { - ABSL_CHECK(extension_field.is_extension()); - - absl::flat_hash_map m; - // Confusingly, for FieldDescriptors that happen to be extensions, - // containing_type() means "extended type." - // On the other hand, extension_scope() will give us what we normally - // mean by containing_type(). - m["extended_message_class"] = - ModuleLevelMessageName(*extension_field.containing_type()); - m["field"] = FieldReferencingExpression( - extension_field.extension_scope(), extension_field, "extensions_by_name"); - printer_->Print(m, "$extended_message_class$.RegisterExtension($field$)\n"); + return false; } -void Generator::FixForeignFieldsInNestedExtensions( - const Descriptor& descriptor) const { - // Recursively fix up extensions in all nested types. +bool Generator::HasExtensionsInMessage(const Descriptor& descriptor) const { + if (descriptor.extension_count() > 0) return true; for (int i = 0; i < descriptor.nested_type_count(); ++i) { - FixForeignFieldsInNestedExtensions(*descriptor.nested_type(i)); - } - // Fix up extensions directly contained within this type. - for (int i = 0; i < descriptor.extension_count(); ++i) { - FixForeignFieldsInExtension(*descriptor.extension(i)); + if (HasExtensionsInMessage(*descriptor.nested_type(i))) return true; } + return false; } // Returns a Python expression that instantiates a Python EnumValueDescriptor diff --git a/src/google/protobuf/compiler/python/generator.h b/src/google/protobuf/compiler/python/generator.h index 65c16f591eb0..99537f1e3be6 100644 --- a/src/google/protobuf/compiler/python/generator.h +++ b/src/google/protobuf/compiler/python/generator.h @@ -141,10 +141,8 @@ class PROTOC_EXPORT Generator : public CodeGenerator { const DescriptorT& descriptor, const Descriptor* containing_descriptor) const; - void FixForeignFieldsInExtensions() const; - void FixForeignFieldsInExtension( - const FieldDescriptor& extension_field) const; - void FixForeignFieldsInNestedExtensions(const Descriptor& descriptor) const; + bool HasExtensions() const; + bool HasExtensionsInMessage(const Descriptor& descriptor) const; void PrintTopBoilerplate() const; void PrintServices() const;