From 2a10bbecaf8955c7bf1956086aef42630112788b Mon Sep 17 00:00:00 2001 From: Dov Shlachter Date: Wed, 3 Mar 2021 17:00:12 -0800 Subject: [PATCH] fix: adding enums to a repeated field does not raise a TypeError (#202) Fixes issue #201, where enums added to a repeated field triggered a TypeError because they were coverted to integers during marshaling. --- proto/marshal/collections/repeated.py | 2 +- tests/test_marshal_strict.py | 24 ++++++++++++++++++++ tests/test_marshal_types_enum.py | 32 +++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 tests/test_marshal_strict.py diff --git a/proto/marshal/collections/repeated.py b/proto/marshal/collections/repeated.py index 30fa68d0..01b5d2fd 100644 --- a/proto/marshal/collections/repeated.py +++ b/proto/marshal/collections/repeated.py @@ -174,5 +174,5 @@ def __setitem__(self, key, value): def insert(self, index: int, value): """Insert ``value`` in the sequence before ``index``.""" - pb_value = self._marshal.to_proto(self._pb_type, value, strict=True) + pb_value = self._marshal.to_proto(self._pb_type, value) self.pb.insert(index, pb_value) diff --git a/tests/test_marshal_strict.py b/tests/test_marshal_strict.py new file mode 100644 index 00000000..75e2f05d --- /dev/null +++ b/tests/test_marshal_strict.py @@ -0,0 +1,24 @@ +# Copyright (C) 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import proto +from proto.marshal.marshal import BaseMarshal +import pytest + + +def test_strict_to_proto(): + m = BaseMarshal() + + with pytest.raises(TypeError): + m.to_proto(dict, None, strict=True) diff --git a/tests/test_marshal_types_enum.py b/tests/test_marshal_types_enum.py index 1d302053..6cd348c3 100644 --- a/tests/test_marshal_types_enum.py +++ b/tests/test_marshal_types_enum.py @@ -58,3 +58,35 @@ class Foo(proto.Enum): with mock.patch.object(warnings, "warn") as warn: assert enum_rule.to_python(4) == 4 warn.assert_called_once_with("Unrecognized Foo enum value: 4") + + +def test_enum_append(): + class Bivalve(proto.Enum): + CLAM = 0 + OYSTER = 1 + + class MolluscContainer(proto.Message): + bivalves = proto.RepeatedField(proto.ENUM, number=1, enum=Bivalve,) + + mc = MolluscContainer() + clam = Bivalve.CLAM + mc.bivalves.append(clam) + mc.bivalves.append(1) + + assert mc.bivalves == [clam, Bivalve.OYSTER] + + +def test_enum_map_insert(): + class Bivalve(proto.Enum): + CLAM = 0 + OYSTER = 1 + + class MolluscContainer(proto.Message): + bivalves = proto.MapField(proto.STRING, proto.ENUM, number=1, enum=Bivalve,) + + mc = MolluscContainer() + clam = Bivalve.CLAM + mc.bivalves["clam"] = clam + mc.bivalves["oyster"] = 1 + + assert mc.bivalves == {"clam": clam, "oyster": Bivalve.OYSTER}