Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nextgen Proto Pythonic API: “Add-on” proto for length prefixed serialize/parse #16965

Merged
1 commit merged into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/build_targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,11 @@ def build_targets(name):
srcs = ["google/protobuf/internal/well_known_types_test.py"],
)

internal_py_test(
name = "decoder_test",
srcs = ["google/protobuf/internal/decoder_test.py"],
)

internal_py_test(
name = "wire_format_test",
srcs = ["google/protobuf/internal/wire_format_test.py"],
Expand Down
22 changes: 17 additions & 5 deletions python/google/protobuf/internal/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@
import math
import struct

from google.protobuf import message
from google.protobuf.internal import containers
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message


# This is not for optimization, but rather to avoid conflicts with local
Expand All @@ -81,20 +81,32 @@ def _VarintDecoder(mask, result_type):
decoder returns a (value, new_pos) pair.
"""

def DecodeVarint(buffer, pos: int = None):
  """Decodes one varint from a bytes-like buffer or from a byte stream.

  Args:
    buffer: An indexable bytes-like object when `pos` is given, or a stream
      object exposing `read(n)` (for example io.BytesIO) when `pos` is None.
    pos: Index of the first varint byte inside `buffer`, or None to consume
      the bytes from the stream `buffer` instead.

  Returns:
    With `pos`: a `(value, new_pos)` pair where `new_pos` is the index just
    past the varint. Without `pos`: the decoded value alone, or None when
    the stream had no bytes left before the first varint byte.

  Raises:
    IndexError: `pos` was given and indexing ran past the end of `buffer`.
    ValueError: The stream ended in the middle of a varint.
    _DecodeError: No terminating byte was found within the first 10 bytes.
  """
  result = 0
  shift = 0
  while True:
    if pos is None:
      # Stream mode: pull exactly one byte per iteration. An empty read
      # yields b'' and indexing it raises IndexError.
      try:
        b = buffer.read(1)[0]
      except IndexError as e:
        if shift == 0:
          # Nothing consumed yet, so this is a clean end of stream.
          return None
        raise ValueError('Fail to read varint %s' % str(e)) from e
    else:
      b = buffer[pos]
      pos += 1
    # The low 7 bits carry payload; the high bit marks continuation.
    result |= ((b & 0x7f) << shift)
    if not (b & 0x80):
      result &= mask
      result = result_type(result)
      return result if pos is None else (result, pos)
    shift += 7
    if shift >= 64:
      raise _DecodeError('Too many bytes when decoding varint.')

return DecodeVarint


Expand Down
57 changes: 57 additions & 0 deletions python/google/protobuf/internal/decoder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Test decoder."""

import io
import unittest

from google.protobuf.internal import decoder
from google.protobuf.internal import testing_refleaks


# Two back-to-back varints: bytes 0-1 (0x84 0x72) encode 14596 and byte 2
# (0x12) encodes 18.
_INPUT_BYTES = b'\x84r\x12'
# Values the tests expect from decoding _INPUT_BYTES, in stream order.
_EXPECTED = (14596, 18)


@testing_refleaks.TestCase
class DecoderTest(unittest.TestCase):
  """Exercises decoder._DecodeVarint with bytes and with BytesIO inputs."""

  def test_decode_varint_bytes(self):
    """Position-based decoding yields (value, next_position) pairs."""
    (size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 0)
    self.assertEqual(size, _EXPECTED[0])
    self.assertEqual(pos, 2)

    (size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 2)
    self.assertEqual(size, _EXPECTED[1])
    self.assertEqual(pos, 3)

  def test_decode_varint_bytes_empty(self):
    """Indexing into an empty bytes object surfaces as IndexError."""
    with self.assertRaises(IndexError) as context:
      decoder._DecodeVarint(b'', 0)
    self.assertIn('index out of range', str(context.exception))

  def test_decode_varint_bytesio(self):
    """Stream-based decoding yields bare values until None marks the end."""
    input_io = io.BytesIO(_INPUT_BYTES)
    decoded = []
    while True:
      size = decoder._DecodeVarint(input_io)
      if size is None:
        break
      decoded.append(size)
    # Comparing the whole sequence reports surplus or missing values as a
    # regular assertion failure.
    self.assertEqual(tuple(decoded), _EXPECTED)

  def test_decode_varint_bytesio_empty(self):
    """An exhausted stream decodes to None rather than raising."""
    input_io = io.BytesIO(b'')
    size = decoder._DecodeVarint(input_io)
    self.assertIsNone(size)


if __name__ == '__main__':
unittest.main()
85 changes: 84 additions & 1 deletion python/google/protobuf/internal/proto_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@

"""Tests Nextgen Pythonic protobuf APIs."""

import io
import unittest

from google.protobuf import proto

from google.protobuf.internal import encoder
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks

from google.protobuf.internal import _parameterized
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2


@_parameterized.named_parameters(('_proto2', unittest_pb2),
('_proto3', unittest_proto3_arena_pb2))
@testing_refleaks.TestCase
Expand All @@ -30,6 +33,86 @@ def test_simple_serialize_parse(self, message_module):
parsed_msg = proto.parse(message_module.TestAllTypes, serialized_data)
self.assertEqual(msg, parsed_msg)

def test_serialize_parse_length_prefixed_empty(self, message_module):
  """An empty message survives a length-prefixed round trip unchanged."""
  empty_alltypes = message_module.TestAllTypes()

  sink = io.BytesIO()
  proto.serialize_length_prefixed(empty_alltypes, sink)

  # Re-wrap the written bytes so parsing starts from offset 0.
  source = io.BytesIO(sink.getvalue())
  round_tripped = proto.parse_length_prefixed(
      message_module.TestAllTypes, source
  )

  self.assertEqual(round_tripped, empty_alltypes)

def test_parse_length_prefixed_truncated(self, message_module):
  """A length prefix larger than the available payload raises ValueError."""
  payload = proto.serialize(message_module.TestAllTypes(optional_int32=1))

  stream = io.BytesIO()
  # Announce 9999 bytes, then supply only the short serialized payload.
  encoder._VarintEncoder()(stream.write, 9999)
  stream.write(payload)

  truncated_input = io.BytesIO(stream.getvalue())
  with self.assertRaises(ValueError) as context:
    proto.parse_length_prefixed(message_module.TestAllTypes, truncated_input)

  expected_message = (
      'Truncated message or non-buffered input_bytes: '
      'Expected 9999 bytes but only 2 bytes parsed for '
      'TestAllTypes.'
  )
  self.assertEqual(str(context.exception), expected_message)

def test_serialize_length_prefixed_fake_io(self, message_module):
  """A writer that reports zero bytes written triggers the TypeError path."""

  class FakeBytesIO(io.BytesIO):
    """BytesIO stand-in whose write() discards data and reports 0 bytes."""

    def write(self, b: bytes) -> int:
      return 0

  sink = FakeBytesIO()
  message = message_module.TestAllTypes(optional_int32=123)

  with self.assertRaises(TypeError) as context:
    proto.serialize_length_prefixed(message, sink)

  self.assertIn(
      'Failed to write complete message (wrote: 0, expected: 2)',
      str(context.exception),
  )


# Golden byte streams for three consecutive length-prefixed TestAllTypes
# messages with optional_int32 = 0, 1, 2 and optional_string = 'hi' (see
# LengthPrefixedWithGolden). Each record is a one-byte length followed by the
# payload. The two streams differ only in the first record: the proto2 stream
# carries the bytes \x08\x00 for optional_int32 = 0 (record length 6), while
# the proto3 stream has no bytes for that field (record length 4).
_EXPECTED_PROTO3 = b'\x04r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi'
_EXPECTED_PROTO2 = b'\x06\x08\x00r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi'


@_parameterized.named_parameters(
    ('_proto2', unittest_pb2, _EXPECTED_PROTO2),
    ('_proto3', unittest_proto3_arena_pb2, _EXPECTED_PROTO3),
)
@testing_refleaks.TestCase
class LengthPrefixedWithGolden(unittest.TestCase):
  """Checks length-prefixed serialize/parse against golden byte streams."""

  def test_serialize_length_prefixed(self, message_module, expected):
    """Writing three messages back to back reproduces the golden bytes."""
    message_count = 3

    sink = io.BytesIO()
    for value in range(message_count):
      proto.serialize_length_prefixed(
          message_module.TestAllTypes(
              optional_int32=value, optional_string='hi'
          ),
          sink,
      )

    self.assertEqual(sink.getvalue(), expected)

  def test_parse_length_prefixed(self, message_module, input_bytes):
    """Reading the golden bytes yields the three messages, then None."""
    expected_count = 3

    stream = io.BytesIO(input_bytes)
    seen = 0
    msg = proto.parse_length_prefixed(message_module.TestAllTypes, stream)
    while msg is not None:
      # Messages were written with optional_int32 equal to their position.
      self.assertEqual(msg.optional_int32, seen)
      self.assertEqual(msg.optional_string, 'hi')
      seen += 1
      msg = proto.parse_length_prefixed(message_module.TestAllTypes, stream)

    self.assertEqual(seen, expected_count)


if __name__ == '__main__':
unittest.main()
83 changes: 80 additions & 3 deletions python/google/protobuf/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@

"""Contains the Nextgen Pythonic protobuf APIs."""

import io
import typing
from typing import Optional, Type, TypeVar

from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.message import Message

def serialize(message: Message, deterministic: bool=None) -> bytes:
_MESSAGE = TypeVar('_MESSAGE', bound='Message')


def serialize(message: _MESSAGE, deterministic: bool = None) -> bytes:
"""Return the serialized proto.

Args:
Expand All @@ -24,7 +30,8 @@ def serialize(message: Message, deterministic: bool=None) -> bytes:
"""
return message.SerializeToString(deterministic=deterministic)

def parse(message_class: typing.Type[Message], payload: bytes) -> Message:

def parse(message_class: Type[_MESSAGE], payload: bytes) -> _MESSAGE:
"""Given a serialized data in binary form, deserialize it into a Message.

Args:
Expand All @@ -37,3 +44,73 @@ def parse(message_class: typing.Type[Message], payload: bytes) -> Message:
new_message = message_class()
new_message.ParseFromString(payload)
return new_message


def serialize_length_prefixed(message: _MESSAGE, output: io.BytesIO) -> None:
  """Writes the size of the message as a varint and the serialized message.

  Writes the size of the message as a varint and then the serialized message.
  This allows more data to be written to the output after the message. Use
  parse_length_prefixed to parse messages written by this method.

  The output stream must be buffered, e.g. using
  https://docs.python.org/3/library/io.html#buffered-streams.

  Example usage:
    out = io.BytesIO()
    for msg in message_list:
      proto.serialize_length_prefixed(msg, out)

  Args:
    message: The protocol buffer message that should be serialized.
    output: BytesIO or custom buffered IO that data should be written to.

  Raises:
    TypeError: `output.write` reported a byte count different from the
      payload size. The length prefix has already been written by then.
  """
  # Serialize before touching `output`, so that a serialization failure
  # cannot leave a length prefix in the stream with no payload behind it.
  payload = serialize(message)
  size = len(payload)

  encoder._VarintEncoder()(output.write, size)
  out_size = output.write(payload)

  if out_size != size:
    raise TypeError(
        'Failed to write complete message (wrote: %d, expected: %d)'
        '. Ensure output is using buffered IO.' % (out_size, size)
    )


def parse_length_prefixed(
    message_class: Type[_MESSAGE], input_bytes: io.BytesIO
) -> Optional[_MESSAGE]:
  """Parse a message from input_bytes.

  Reads a varint size from input_bytes, then reads that many bytes and parses
  them into a new instance of message_class.

  Args:
    message_class: The protocol buffer message class that parser should parse.
    input_bytes: A buffered input.

  Example usage:
    while True:
      msg = proto.parse_length_prefixed(message_class, input_bytes)
      if msg is None:
        break
      ...

  Returns:
    A parsed message if successful. None if input_bytes is at EOF.

  Raises:
    ValueError: The number of bytes reported as parsed differs from the size
      announced by the length prefix.
  """
  size = decoder._DecodeVarint(input_bytes)
  if size is None:
    # It is the end of buffered input. See example usage in the
    # API description.
    return None

  message = message_class()

  if size == 0:
    # A zero-length record is an empty message; there is no payload to read.
    return message

  parsed_size = message.ParseFromString(input_bytes.read(size))
  if parsed_size != size:
    raise ValueError(
        'Truncated message or non-buffered input_bytes: '
        'Expected {0} bytes but only {1} bytes parsed for '
        '{2}.'.format(size, parsed_size, message.DESCRIPTOR.name)
    )
  return message
Loading