Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "absl::StatusOr<PythonConstMessagePointer> GetConstMessagePointer(PyObject* msg)" in proto_api which works with cpp extension, upb and pure python. #19302

Merged
merged 1 commit into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/build_targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def build_targets(name):
visibility = ["//visibility:public"],
deps = [
"//src/google/protobuf",
"//src/google/protobuf/io",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status",
"@system_python//:python_headers",
Expand Down
85 changes: 85 additions & 0 deletions python/google/protobuf/proto_api.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#include "google/protobuf/proto_api.h"

#include <Python.h>

#include <memory>
#include <string>

#include "absl/log/absl_check.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message.h"
namespace google {
namespace protobuf {
Expand Down Expand Up @@ -52,6 +56,87 @@ PythonMessageMutator PyProto_API::CreatePythonMessageMutator(
return PythonMessageMutator(owned_msg, msg, py_msg);
}

// Constructs a read-only handle for `py_msg`.
//
// `owned_msg` is stored in the `owned_msg_` unique_ptr, so this object takes
// ownership of it. `message` is the pointer later exposed through get().
// Both `py_msg` and `message` must be non-null (checked in debug builds).
PythonConstMessagePointer::PythonConstMessagePointer(Message* owned_msg,
const Message* message,
PyObject* py_msg)
: owned_msg_(owned_msg), message_(message), py_msg_(py_msg) {
ABSL_DCHECK(py_msg != nullptr);
ABSL_DCHECK(message != nullptr);
// Hold a reference on the Python message for the lifetime of this handle;
// the destructor releases it.
Py_INCREF(py_msg_);
}

// Move constructor: takes over the owned C++ message, the exposed message
// pointer and the Python reference held by `other`.
PythonConstMessagePointer::PythonConstMessagePointer(
    PythonConstMessagePointer&& other)
    // release() on an empty unique_ptr already yields nullptr, so no explicit
    // null test is needed before transferring ownership.
    : owned_msg_(other.owned_msg_.release()),
      message_(other.message_),
      py_msg_(other.py_msg_) {
  // Leave `other` empty so that its destructor takes the "moved-from" branch
  // and does not drop the Python reference a second time.
  other.py_msg_ = nullptr;
  other.message_ = nullptr;
}

bool PythonConstMessagePointer::NotChanged() {
ABSL_DCHECK(!PyErr_Occurred());
if (owned_msg_ == nullptr) {
return false;
}

PyObject* py_serialized_pb(
PyObject_CallMethod(py_msg_, "SerializeToString", nullptr));
if (py_serialized_pb == nullptr) {
PyErr_Format(PyExc_ValueError, "Fail to serialize py_msg");
return false;
}
char* data;
Py_ssize_t len;
if (PyBytes_AsStringAndSize(py_serialized_pb, &data, &len) < 0) {
Py_DECREF(py_serialized_pb);
PyErr_Format(PyExc_ValueError, "Fail to get bytes from serialized data");
return false;
}

// Even if serialize python message deterministic above, the
// serialize result may still diff between languages. So parse to
// another c++ message for compare.
std::unique_ptr<google::protobuf::Message> parsed_msg(owned_msg_->New());
parsed_msg->ParseFromArray(data, static_cast<int>(len));
std::string wire_other;
google::protobuf::io::StringOutputStream stream_other(&wire_other);
google::protobuf::io::CodedOutputStream output_other(&stream_other);
output_other.SetSerializationDeterministic(true);
parsed_msg->SerializeToCodedStream(&output_other);

std::string wire;
google::protobuf::io::StringOutputStream stream(&wire);
google::protobuf::io::CodedOutputStream output(&stream);
output.SetSerializationDeterministic(true);
owned_msg_->SerializeToCodedStream(&output);

if (wire == wire_other) {
Py_DECREF(py_serialized_pb);
return true;
}
PyErr_Format(PyExc_ValueError, "pymessage has been changed");
Py_DECREF(py_serialized_pb);
return false;
}

// Releases the Python reference taken at construction. In debug builds it
// also verifies that the Python message was not modified while this handle
// was alive. A moved-from handle holds nothing and only checks that it is
// fully empty.
PythonConstMessagePointer::~PythonConstMessagePointer() {
  if (py_msg_ != nullptr) {
    // Live handle: a snapshot is expected to exist and must still match the
    // Python message before the reference is dropped.
    ABSL_DCHECK(owned_msg_ != nullptr);
    ABSL_DCHECK(NotChanged());
    Py_DECREF(py_msg_);
    return;
  }
  // Moved-from handle: every member must have been reset.
  ABSL_DCHECK(message_ == nullptr);
  ABSL_DCHECK(owned_msg_ == nullptr);
}

// Factory for PythonConstMessagePointer: forwards the three arguments
// unchanged to the PythonConstMessagePointer constructor and returns the
// resulting handle by value.
PythonConstMessagePointer PyProto_API::CreatePythonConstMessagePointer(
Message* owned_msg, const Message* msg, PyObject* py_msg) const {
return PythonConstMessagePointer(owned_msg, msg, py_msg);
}

} // namespace python
} // namespace protobuf
} // namespace google
44 changes: 41 additions & 3 deletions python/google/protobuf/proto_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
// PyProtoAPICapsuleName(), 0));
// if (!py_proto_api) { ...handle ImportError... }
// Then use the methods of the returned class:
// py_proto_api->GetMessagePointer(...);
// py_proto_api->GetConstMessagePointer(...);

#ifndef GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
#define GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
Expand All @@ -31,11 +31,14 @@
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/message.h"

PyObject* pymessage_mutate_const(PyObject* self, PyObject* args);

namespace google {
namespace protobuf {
namespace python {

class PythonMessageMutator;
class PythonConstMessagePointer;

// Note on the implementation:
// This API is designed after
Expand All @@ -55,23 +58,36 @@ struct PyProto_API {
// Side-effect: The message will definitely be cleared. *When* the message
// gets cleared is undefined (C++ will clear it up-front, python/upb will
// clear it on destruction). Nothing should rely on the python message
// during the lifetime of this object
// during the lifetime of this object.
// User should not hold onto the returned PythonMessageMutator while
// calling back into Python
// calling back into Python.
// Warning: there is a risk of deadlock with Python/C++ if users use the
// returned message->GetDescriptor()->file->pool()
virtual absl::StatusOr<PythonMessageMutator> GetClearedMessageMutator(
PyObject* msg) const = 0;

// Returns a PythonConstMessagePointer. For UPB and pure Python, it points
// to a new C++ message copied from the Python message. For the cpp
// extension, it points to the internal C++ message.
// User should not hold onto the returned PythonConstMessagePointer
// while calling back into Python.
virtual absl::StatusOr<PythonConstMessagePointer> GetConstMessagePointer(
PyObject* msg) const = 0;

// If the passed object is a Python Message, returns its internal pointer.
// Otherwise, returns NULL with an exception set.
// TODO: Remove deprecated GetMessagePointer().
[[deprecated(
"GetMessagePointer() only work with Cpp Extension, "
"please migrate to GetConstMessagePointer().")]]
virtual const Message* GetMessagePointer(PyObject* msg) const = 0;

// If the passed object is a Python Message, returns a mutable pointer.
// Otherwise, returns NULL with an exception set.
// This function will succeed only if there are no other Python objects
// pointing to the message, like submessages or repeated containers.
// With the current implementation, only empty messages are in this case.
// TODO: Remove deprecated GetMutableMessagePointer().
[[deprecated(
"GetMutableMessagePointer() only work with Cpp Extension, "
"please migrate to GetClearedMessageMutator().")]]
Expand Down Expand Up @@ -133,6 +149,8 @@ struct PyProto_API {
PythonMessageMutator CreatePythonMessageMutator(Message* owned_msg,
Message* msg,
PyObject* py_msg) const;
PythonConstMessagePointer CreatePythonConstMessagePointer(
Message* owned_msg, const Message* msg, PyObject* py_msg) const;
};

// User should not hold onto this object while calling back into Python
Expand Down Expand Up @@ -161,6 +179,26 @@ class PythonMessageMutator {
PyObject* py_msg_;
};

// Move-only, read-only handle to the C++ message behind a Python message.
//
// The handle keeps a reference on the Python message from construction until
// destruction and exposes a const C++ message through get(). Instances are
// created through PyProto_API (the constructor is private).
class PythonConstMessagePointer {
 public:
  // Transfers the owned message, the exposed pointer and the Python
  // reference out of `other`, leaving `other` empty.
  PythonConstMessagePointer(PythonConstMessagePointer&& other);
  ~PythonConstMessagePointer();

  // Returns the wrapped message. Must not be called on a moved-from handle,
  // whose message pointer is null.
  const Message& get() const { return *message_; }

 private:
  friend struct google::protobuf::python::PyProto_API;
  // Takes ownership of `owned_msg` and holds a reference on `py_msg`.
  // `message` is the pointer returned by get().
  PythonConstMessagePointer(Message* owned_msg, const Message* message,
                            PyObject* py_msg);

  friend PyObject* ::pymessage_mutate_const(PyObject* self, PyObject* args);
  // Check if the const message has been changed. Returns true when the
  // Python message still matches `owned_msg_`.
  bool NotChanged();
  // C++ message owned by this handle; compared against the Python message
  // by NotChanged().
  std::unique_ptr<Message> owned_msg_;
  // Message exposed through get(); null once moved from.
  const Message* message_;
  // Python message this handle refers to; null once moved from.
  PyObject* py_msg_;
};

inline const char* PyProtoAPICapsuleName() {
static const char kCapsuleName[] = "google.protobuf.pyext._message.proto_API";
return kCapsuleName;
Expand Down
115 changes: 82 additions & 33 deletions python/google/protobuf/pyext/message_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,54 +158,103 @@ google::protobuf::DynamicMessageFactory* GetFactory() {
return factory;
}

// Creates a new, empty C++ message of the same type as the Python message
// `py_msg`.
//
// The type is identified by `py_msg.DESCRIPTOR.full_name`. It is looked up in
// the generated descriptor pool first; if it is not found there, the
// descriptor is resolved from `py_msg.DESCRIPTOR.file` via
// FindMessageDescriptor() and the message is created from GetFactory().
//
// Returns the message as a raw pointer obtained from `New()`, or an error
// status when a required attribute is missing, the name cannot be converted
// to UTF-8, or the descriptor cannot be resolved. This function does not
// clear the Python error indicator after a failed Python call.
//
// Every Python reference acquired here is released on all return paths.
absl::StatusOr<google::protobuf::Message*> CreateNewMessage(PyObject* py_msg) {
  PyObject* pyd = PyObject_GetAttrString(py_msg, "DESCRIPTOR");
  if (pyd == nullptr) {
    return absl::InvalidArgumentError("py_msg has no attribute 'DESCRIPTOR'");
  }

  PyObject* fn = PyObject_GetAttrString(pyd, "full_name");
  if (fn == nullptr) {
    Py_DECREF(pyd);
    return absl::InvalidArgumentError(
        "DESCRIPTOR has no attribute 'full_name'");
  }

  // The returned buffer belongs to `fn`, so `fn` must stay alive for as long
  // as `descriptor_full_name` is in use.
  const char* descriptor_full_name = PyUnicode_AsUTF8(fn);
  if (descriptor_full_name == nullptr) {
    Py_DECREF(fn);
    Py_DECREF(pyd);
    return absl::InternalError("Fail to convert descriptor full name");
  }

  PyObject* pyfile = PyObject_GetAttrString(pyd, "file");
  Py_DECREF(pyd);
  if (pyfile == nullptr) {
    Py_DECREF(fn);
    return absl::InvalidArgumentError("DESCRIPTOR has no attribute 'file'");
  }

  auto gen_d = google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
      descriptor_full_name);
  if (gen_d) {
    Py_DECREF(pyfile);
    Py_DECREF(fn);
    return google::protobuf::MessageFactory::generated_factory()
        ->GetPrototype(gen_d)
        ->New();
  }

  auto d = FindMessageDescriptor(pyfile, descriptor_full_name);
  // The lookup is complete; neither Python object is needed any more. Release
  // both before the status check so the early return cannot leak them.
  Py_DECREF(pyfile);
  Py_DECREF(fn);
  RETURN_IF_ERROR(d.status());
  return GetFactory()->GetPrototype(*d)->New();
}

// Stores in `*copy` a newly created message of the same type as `message`,
// populated by serializing `message` and parsing the bytes back.
//
// Always returns true; the results of serialization and parsing are not
// inspected.
bool CopyToOwnedMsg(google::protobuf::Message** copy, const google::protobuf::Message& message) {
  google::protobuf::Message* const duplicate = message.New();
  *copy = duplicate;

  std::string serialized;
  message.SerializeToString(&serialized);
  duplicate->ParseFromArray(serialized.data(), serialized.size());
  return true;
}

// C++ API. Clients get at this via proto_api.h
struct ApiImplementation : google::protobuf::python::PyProto_API {
absl::StatusOr<google::protobuf::python::PythonMessageMutator> GetClearedMessageMutator(
PyObject* py_msg) const override {
if (PyObject_TypeCheck(py_msg, google::protobuf::python::CMessage_Type)) {
google::protobuf::Message* message =
google::protobuf::python::PyMessage_GetMutableMessagePointer(py_msg);
if (message == nullptr) {
return absl::InternalError(
"Fail to get message pointer. The message "
"may already had a reference.");
}
message->Clear();
return CreatePythonMessageMutator(nullptr, message, py_msg);
}
PyObject* pyd = PyObject_GetAttrString(py_msg, "DESCRIPTOR");
if (pyd == nullptr) {
return absl::InvalidArgumentError("py_msg has no attribute 'DESCRIPTOR'");
}

PyObject* fn = PyObject_GetAttrString(pyd, "full_name");
if (fn == nullptr) {
return absl::InvalidArgumentError(
"DESCRIPTOR has no attribute 'full_name'");
}
auto msg = CreateNewMessage(py_msg);
RETURN_IF_ERROR(msg.status());
return CreatePythonMessageMutator(*msg, *msg, py_msg);
}

const char* descriptor_full_name = PyUnicode_AsUTF8(fn);
if (descriptor_full_name == nullptr) {
return absl::InternalError("Fail to convert descriptor full name");
absl::StatusOr<google::protobuf::python::PythonConstMessagePointer>
GetConstMessagePointer(PyObject* py_msg) const override {
if (PyObject_TypeCheck(py_msg, google::protobuf::python::CMessage_Type)) {
const google::protobuf::Message* message =
google::protobuf::python::PyMessage_GetMessagePointer(py_msg);
google::protobuf::Message* owned_msg = nullptr;
ABSL_DCHECK(CopyToOwnedMsg(&owned_msg, *message));
return CreatePythonConstMessagePointer(owned_msg, message, py_msg);
}

PyObject* pyfile = PyObject_GetAttrString(pyd, "file");
Py_DECREF(pyd);
if (pyfile == nullptr) {
return absl::InvalidArgumentError("DESCRIPTOR has no attribute 'file'");
auto msg = CreateNewMessage(py_msg);
RETURN_IF_ERROR(msg.status());
PyObject* serialized_pb(
PyObject_CallMethod(py_msg, "SerializeToString", nullptr));
if (serialized_pb == nullptr) {
return absl::InternalError("Fail to serialize py_msg");
}
auto gen_d =
google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
descriptor_full_name);
if (gen_d) {
Py_DECREF(pyfile);
Py_DECREF(fn);
google::protobuf::Message* msg = google::protobuf::MessageFactory::generated_factory()
->GetPrototype(gen_d)
->New();
return CreatePythonMessageMutator(msg, msg, py_msg);
char* data;
Py_ssize_t len;
if (PyBytes_AsStringAndSize(serialized_pb, &data, &len) < 0) {
Py_DECREF(serialized_pb);
return absl::InternalError(
"Fail to get bytes from py_msg serialized data");
}
auto d = FindMessageDescriptor(pyfile, descriptor_full_name);
Py_DECREF(pyfile);
RETURN_IF_ERROR(d.status());
Py_DECREF(fn);
google::protobuf::Message* msg = GetFactory()->GetPrototype(*d)->New();
return CreatePythonMessageMutator(msg, msg, py_msg);
if (!(*msg)->ParseFromArray(data, len)) {
Py_DECREF(serialized_pb);
return absl::InternalError(
"Couldn't parse py_message to google::protobuf::Message*!");
}
Py_DECREF(serialized_pb);
return CreatePythonConstMessagePointer(*msg, *msg, py_msg);
}

const google::protobuf::Message* GetMessagePointer(PyObject* msg) const override {
Expand Down
Loading