Skip to content

Commit

Permalink
Add Python support for retention attribute
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 511914565
  • Loading branch information
acozzette authored and copybara-github committed Feb 24, 2023
1 parent bcb20bb commit 63389c0
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 74 deletions.
77 changes: 77 additions & 0 deletions python/google/protobuf/internal/generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_mset_wire_format_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_retention_pb2
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_no_generic_services_pb2

Expand Down Expand Up @@ -152,6 +153,82 @@ def testMessageWithCustomOptions(self):
# TODO(gps): We really should test for the presence of the enum_opt1
# extension and for its value to be set to -789.

# Descriptors embedded in the generated module must not carry any option that
# was explicitly marked RETENTION_SOURCE.
def testOptionRetention(self):
  """Checks that source-retention options are absent from the descriptors."""
  pb = unittest_retention_pb2
  file_options = pb.DESCRIPTOR.GetOptions()

  # Options set directly on the file: only the source-retention one is gone.
  self.assertTrue(file_options.HasExtension(pb.plain_option))
  self.assertTrue(file_options.HasExtension(pb.runtime_retention_option))
  self.assertFalse(file_options.HasExtension(pb.source_retention_option))

  def assert_stripped(options_message):
    # The plain and runtime-retention fields keep their values, while the
    # source-retention field is unset and reads back as 0.
    self.assertEqual(options_message.plain_field, 1)
    self.assertEqual(options_message.runtime_retention_field, 2)
    self.assertFalse(options_message.HasField('source_retention_field'))
    self.assertEqual(options_message.source_retention_field, 0)

  # The file-level options message reuses the options object fetched above.
  assert_stripped(file_options.Extensions[pb.file_option])

  # Every remaining entity type, paired with the extension that holds the test
  # options message on it. The order matches the sequence of checks above:
  # messages, enums, enum value, fields, oneof, service, method.
  message_descriptor = pb.TopLevelMessage.DESCRIPTOR
  service_descriptor = pb.DESCRIPTOR.services_by_name['Service']
  entities = [
      (message_descriptor, pb.message_option),
      (pb.TopLevelMessage.NestedMessage.DESCRIPTOR, pb.message_option),
      (pb._TOPLEVELENUM, pb.enum_option),
      (pb._TOPLEVELMESSAGE_NESTEDENUM, pb.enum_option),
      (pb._TOPLEVELENUM.values[0], pb.enum_entry_option),
      (pb.DESCRIPTOR.extensions_by_name['i'], pb.field_option),
      (message_descriptor.fields[0], pb.field_option),
      (message_descriptor.oneofs[0], pb.oneof_option),
      (service_descriptor, pb.service_option),
      (service_descriptor.methods[0], pb.method_option),
  ]
  for descriptor, extension in entities:
    assert_stripped(descriptor.GetOptions().Extensions[extension])

def testNestedTypes(self):
self.assertEqual(
set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
Expand Down
12 changes: 8 additions & 4 deletions src/google/protobuf/compiler/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ cc_library(
deps = [
"//src/google/protobuf:protobuf_nowkt",
"//src/google/protobuf/compiler:code_generator",
"//src/google/protobuf/compiler:retention",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
],
Expand Down Expand Up @@ -61,9 +62,12 @@ pkg_files(

filegroup(
name = "test_srcs",
srcs = glob([
"*_test.cc",
"*unittest.cc",
], allow_empty = True),
srcs = glob(
[
"*_test.cc",
"*unittest.cc",
],
allow_empty = True,
),
visibility = ["//src/google/protobuf/compiler:__pkg__"],
)
100 changes: 54 additions & 46 deletions src/google/protobuf/compiler/python/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include "absl/strings/substitute.h"
#include "google/protobuf/compiler/python/helpers.h"
#include "google/protobuf/compiler/python/pyi_generator.h"
#include "google/protobuf/compiler/retention.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/io/printer.h"
Expand Down Expand Up @@ -249,8 +250,7 @@ bool Generator::Generate(const FileDescriptor* file,

std::string filename = GetFileName(file, ".py");

FileDescriptorProto fdp;
file_->CopyTo(&fdp);
FileDescriptorProto fdp = StripSourceRetentionOptions(*file_);
fdp.SerializeToString(&file_descriptor_serialized_);

if (!opensource_runtime_ && GeneratingDescriptorProto()) {
Expand Down Expand Up @@ -342,7 +342,7 @@ bool Generator::Generate(const FileDescriptor* file,
FixAllDescriptorOptions();

// Set serialized_start and serialized_end.
SetSerializedPbInterval();
SetSerializedPbInterval(fdp);

printer_->Outdent();
if (HasGenericServices(file)) {
Expand Down Expand Up @@ -442,7 +442,8 @@ void Generator::PrintFileDescriptor() const {
m["name"] = file_->name();
m["package"] = file_->package();
m["syntax"] = StringifySyntax(file_->syntax());
m["options"] = OptionsValue(file_->options().SerializeAsString());
m["options"] = OptionsValue(
StripLocalSourceRetentionOptions(*file_).SerializeAsString());
m["serialized_descriptor"] = absl::CHexEscape(file_descriptor_serialized_);
if (GeneratingDescriptorProto()) {
printer_->Print("if _descriptor._USE_C_DESCRIPTORS == False:\n");
Expand Down Expand Up @@ -528,7 +529,8 @@ void Generator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
" create_key=_descriptor._internal_create_key,\n"
" values=[\n";
std::string options_string;
enum_descriptor.options().SerializeToString(&options_string);
StripLocalSourceRetentionOptions(enum_descriptor)
.SerializeToString(&options_string);
printer_->Print(m, enum_descriptor_template);
printer_->Indent();
printer_->Indent();
Expand Down Expand Up @@ -681,7 +683,8 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const {
printer_->Outdent();
printer_->Print("],\n");
std::string options_string;
message_descriptor.options().SerializeToString(&options_string);
StripLocalSourceRetentionOptions(message_descriptor)
.SerializeToString(&options_string);
printer_->Print(
"serialized_options=$options_value$,\n"
"is_extendable=$extendable$,\n"
Expand All @@ -708,7 +711,8 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const {
m["name"] = desc->name();
m["full_name"] = desc->full_name();
m["index"] = absl::StrCat(desc->index());
options_string = OptionsValue(desc->options().SerializeAsString());
options_string = OptionsValue(
StripLocalSourceRetentionOptions(*desc).SerializeAsString());
if (options_string == "None") {
m["serialized_options"] = "";
} else {
Expand Down Expand Up @@ -1050,7 +1054,8 @@ void Generator::PrintEnumValueDescriptor(
// TODO(robinson): Fix up EnumValueDescriptor "type" fields.
// More circular references. ::sigh::
std::string options_string;
descriptor.options().SerializeToString(&options_string);
StripLocalSourceRetentionOptions(descriptor)
.SerializeToString(&options_string);
absl::flat_hash_map<absl::string_view, std::string> m;
m["name"] = descriptor.name();
m["index"] = absl::StrCat(descriptor.index());
Expand Down Expand Up @@ -1078,7 +1083,7 @@ std::string Generator::OptionsValue(
void Generator::PrintFieldDescriptor(const FieldDescriptor& field,
bool is_extension) const {
std::string options_string;
field.options().SerializeToString(&options_string);
StripLocalSourceRetentionOptions(field).SerializeToString(&options_string);
absl::flat_hash_map<absl::string_view, std::string> m;
m["name"] = field.name();
m["full_name"] = field.full_name();
Expand Down Expand Up @@ -1216,21 +1221,17 @@ std::string Generator::InternalPackage() const {
: "google3.net.google.protobuf.python.internal";
}

// Prints standard constructor arguments serialized_start and serialized_end.
// Prints descriptor offsets _serialized_start and _serialized_end.
// Args:
// descriptor: The cpp descriptor to have a serialized reference.
// proto: A proto
// descriptor_proto: The descriptor proto to have a serialized reference.
// Example printer output:
// serialized_start=41,
// serialized_end=43,
//
template <typename DescriptorT, typename DescriptorProtoT>
void Generator::PrintSerializedPbInterval(const DescriptorT& descriptor,
DescriptorProtoT& proto,
absl::string_view name) const {
descriptor.CopyTo(&proto);
// _globals['_MYMESSAGE']._serialized_start=47
// _globals['_MYMESSAGE']._serialized_end=76
template <typename DescriptorProtoT>
void Generator::PrintSerializedPbInterval(
const DescriptorProtoT& descriptor_proto, absl::string_view name) const {
std::string sp;
proto.SerializeToString(&sp);
descriptor_proto.SerializeToString(&sp);
int offset = file_descriptor_serialized_.find(sp);
ABSL_CHECK_GE(offset, 0);

Expand All @@ -1254,51 +1255,56 @@ void PrintDescriptorOptionsFixingCode(absl::string_view descriptor,
}
} // namespace

void Generator::SetSerializedPbInterval() const {
// Generates the start and end offsets for each entity in the serialized file
// descriptor. The file argument must exactly match what was serialized into
// file_descriptor_serialized_, and should already have had any
// source-retention options stripped out. This is important because we need an
// exact byte-for-byte match so that we can successfully find the correct
// offsets in the serialized descriptors.
void Generator::SetSerializedPbInterval(const FileDescriptorProto& file) const {
// Top level enums.
for (int i = 0; i < file_->enum_type_count(); ++i) {
EnumDescriptorProto proto;
const EnumDescriptor& descriptor = *file_->enum_type(i);
PrintSerializedPbInterval(descriptor, proto,
PrintSerializedPbInterval(file.enum_type(i),
ModuleLevelDescriptorName(descriptor));
}

// Messages.
for (int i = 0; i < file_->message_type_count(); ++i) {
SetMessagePbInterval(*file_->message_type(i));
SetMessagePbInterval(file.message_type(i), *file_->message_type(i));
}

// Services.
for (int i = 0; i < file_->service_count(); ++i) {
ServiceDescriptorProto proto;
const ServiceDescriptor& service = *file_->service(i);
PrintSerializedPbInterval(service, proto,
PrintSerializedPbInterval(file.service(i),
ModuleLevelServiceDescriptorName(service));
}
}

void Generator::SetMessagePbInterval(const Descriptor& descriptor) const {
DescriptorProto message_proto;
PrintSerializedPbInterval(descriptor, message_proto,
void Generator::SetMessagePbInterval(const DescriptorProto& message_proto,
const Descriptor& descriptor) const {
PrintSerializedPbInterval(message_proto,
ModuleLevelDescriptorName(descriptor));

// Nested messages.
for (int i = 0; i < descriptor.nested_type_count(); ++i) {
SetMessagePbInterval(*descriptor.nested_type(i));
SetMessagePbInterval(message_proto.nested_type(i),
*descriptor.nested_type(i));
}

for (int i = 0; i < descriptor.enum_type_count(); ++i) {
EnumDescriptorProto proto;
const EnumDescriptor& enum_des = *descriptor.enum_type(i);
PrintSerializedPbInterval(enum_des, proto,
PrintSerializedPbInterval(message_proto.enum_type(i),
ModuleLevelDescriptorName(enum_des));
}
}

// Prints expressions that set the options field of all descriptors.
void Generator::FixAllDescriptorOptions() const {
// Prints an expression that sets the file descriptor's options.
std::string file_options = OptionsValue(file_->options().SerializeAsString());
std::string file_options = OptionsValue(
StripLocalSourceRetentionOptions(*file_).SerializeAsString());
if (file_options != "None") {
PrintDescriptorOptionsFixingCode(kDescriptorKey, file_options, printer_);
} else {
Expand Down Expand Up @@ -1326,7 +1332,8 @@ void Generator::FixAllDescriptorOptions() const {
}

void Generator::FixOptionsForOneof(const OneofDescriptor& oneof) const {
std::string oneof_options = OptionsValue(oneof.options().SerializeAsString());
std::string oneof_options =
OptionsValue(StripLocalSourceRetentionOptions(oneof).SerializeAsString());
if (oneof_options != "None") {
std::string oneof_name = absl::Substitute(
"$0.$1['$2']", ModuleLevelDescriptorName(*oneof.containing_type()),
Expand All @@ -1339,15 +1346,15 @@ void Generator::FixOptionsForOneof(const OneofDescriptor& oneof) const {
// value descriptors.
void Generator::FixOptionsForEnum(const EnumDescriptor& enum_descriptor) const {
std::string descriptor_name = ModuleLevelDescriptorName(enum_descriptor);
std::string enum_options =
OptionsValue(enum_descriptor.options().SerializeAsString());
std::string enum_options = OptionsValue(
StripLocalSourceRetentionOptions(enum_descriptor).SerializeAsString());
if (enum_options != "None") {
PrintDescriptorOptionsFixingCode(descriptor_name, enum_options, printer_);
}
for (int i = 0; i < enum_descriptor.value_count(); ++i) {
const EnumValueDescriptor& value_descriptor = *enum_descriptor.value(i);
std::string value_options =
OptionsValue(value_descriptor.options().SerializeAsString());
std::string value_options = OptionsValue(
StripLocalSourceRetentionOptions(value_descriptor).SerializeAsString());
if (value_options != "None") {
PrintDescriptorOptionsFixingCode(
absl::StrFormat("%s.values_by_name[\"%s\"]", descriptor_name.c_str(),
Expand All @@ -1363,17 +1370,17 @@ void Generator::FixOptionsForService(
const ServiceDescriptor& service_descriptor) const {
std::string descriptor_name =
ModuleLevelServiceDescriptorName(service_descriptor);
std::string service_options =
OptionsValue(service_descriptor.options().SerializeAsString());
std::string service_options = OptionsValue(
StripLocalSourceRetentionOptions(service_descriptor).SerializeAsString());
if (service_options != "None") {
PrintDescriptorOptionsFixingCode(descriptor_name, service_options,
printer_);
}

for (int i = 0; i < service_descriptor.method_count(); ++i) {
const MethodDescriptor* method = service_descriptor.method(i);
std::string method_options =
OptionsValue(method->options().SerializeAsString());
std::string method_options = OptionsValue(
StripLocalSourceRetentionOptions(*method).SerializeAsString());
if (method_options != "None") {
std::string method_name = absl::StrCat(
descriptor_name, ".methods_by_name['", method->name(), "']");
Expand All @@ -1385,7 +1392,8 @@ void Generator::FixOptionsForService(
// Prints expressions that set the options for field descriptors (including
// extensions).
void Generator::FixOptionsForField(const FieldDescriptor& field) const {
std::string field_options = OptionsValue(field.options().SerializeAsString());
std::string field_options =
OptionsValue(StripLocalSourceRetentionOptions(field).SerializeAsString());
if (field_options != "None") {
std::string field_name;
if (field.is_extension()) {
Expand Down Expand Up @@ -1430,8 +1438,8 @@ void Generator::FixOptionsForMessage(const Descriptor& descriptor) const {
FixOptionsForField(field);
}
// Message option for this message.
std::string message_options =
OptionsValue(descriptor.options().SerializeAsString());
std::string message_options = OptionsValue(
StripLocalSourceRetentionOptions(descriptor).SerializeAsString());
if (message_options != "None") {
std::string descriptor_name = ModuleLevelDescriptorName(descriptor);
PrintDescriptorOptionsFixingCode(descriptor_name, message_options,
Expand Down
Loading

0 comments on commit 63389c0

Please sign in to comment.