diff --git a/proto/marshal/marshal.py b/proto/marshal/marshal.py index 0a12752e..bd21cbc1 100644 --- a/proto/marshal/marshal.py +++ b/proto/marshal/marshal.py @@ -18,6 +18,7 @@ from google.protobuf import message from google.protobuf import duration_pb2 from google.protobuf import timestamp_pb2 +from google.protobuf import field_mask_pb2 from google.protobuf import struct_pb2 from google.protobuf import wrappers_pb2 @@ -31,6 +32,7 @@ from proto.marshal.rules import dates from proto.marshal.rules import struct from proto.marshal.rules import wrappers +from proto.marshal.rules import field_mask from proto.primitives import ProtoType @@ -126,6 +128,9 @@ def reset(self): self.register(timestamp_pb2.Timestamp, dates.TimestampRule()) self.register(duration_pb2.Duration, dates.DurationRule()) + # Register FieldMask wrappers. + self.register(field_mask_pb2.FieldMask, field_mask.FieldMaskRule()) + # Register nullable primitive wrappers. self.register(wrappers_pb2.BoolValue, wrappers.BoolValueRule()) self.register(wrappers_pb2.BytesValue, wrappers.BytesValueRule()) diff --git a/proto/marshal/rules/dates.py b/proto/marshal/rules/dates.py index 5145bcf8..33d12829 100644 --- a/proto/marshal/rules/dates.py +++ b/proto/marshal/rules/dates.py @@ -47,6 +47,10 @@ def to_proto(self, value) -> timestamp_pb2.Timestamp: seconds=int(value.timestamp()), nanos=value.microsecond * 1000, ) + if isinstance(value, str): + timestamp_value = timestamp_pb2.Timestamp() + timestamp_value.FromJsonString(value=value) + return timestamp_value return value @@ -74,4 +78,8 @@ def to_proto(self, value) -> duration_pb2.Duration: seconds=value.days * 86400 + value.seconds, nanos=value.microseconds * 1000, ) + if isinstance(value, str): + duration_value = duration_pb2.Duration() + duration_value.FromJsonString(value=value) + return duration_value return value diff --git a/proto/marshal/rules/field_mask.py b/proto/marshal/rules/field_mask.py new file mode 100644 index 00000000..348e7e39 --- /dev/null +++ b/proto/marshal/rules/field_mask.py @@ -0,0 +1,36 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.protobuf import field_mask_pb2 + + +class FieldMaskRule: + """A marshal between FieldMask and strings. + + See https://github.com/googleapis/proto-plus-python/issues/333 + and + https://developers.google.com/protocol-buffers/docs/proto3#json + for more details. + """ + + def to_python(self, value, *, absent: bool = None): + return value + + def to_proto(self, value): + if isinstance(value, str): + field_mask_value = field_mask_pb2.FieldMask() + field_mask_value.FromJsonString(value=value) + return field_mask_value + + return value diff --git a/tests/test_marshal_field_mask.py b/tests/test_marshal_field_mask.py new file mode 100644 index 00000000..ffb36a2f --- /dev/null +++ b/tests/test_marshal_field_mask.py @@ -0,0 +1,100 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.protobuf import field_mask_pb2 + +import proto +from proto.marshal.marshal import BaseMarshal + + +def test_field_mask_read(): + class Foo(proto.Message): + mask = proto.Field( + proto.MESSAGE, + number=1, + message=field_mask_pb2.FieldMask, + ) + + foo = Foo(mask=field_mask_pb2.FieldMask(paths=["f.b.d", "f.c"])) + + assert isinstance(foo.mask, field_mask_pb2.FieldMask) + assert foo.mask.paths == ["f.b.d", "f.c"] + + +def test_field_mask_write_string(): + class Foo(proto.Message): + mask = proto.Field( + proto.MESSAGE, + number=1, + message=field_mask_pb2.FieldMask, + ) + + foo = Foo() + foo.mask = "f.b.d,f.c" + + assert isinstance(foo.mask, field_mask_pb2.FieldMask) + assert foo.mask.paths == ["f.b.d", "f.c"] + + +def test_field_mask_write_pb2(): + class Foo(proto.Message): + mask = proto.Field( + proto.MESSAGE, + number=1, + message=field_mask_pb2.FieldMask, + ) + + foo = Foo() + foo.mask = field_mask_pb2.FieldMask(paths=["f.b.d", "f.c"]) + + assert isinstance(foo.mask, field_mask_pb2.FieldMask) + assert foo.mask.paths == ["f.b.d", "f.c"] + + +def test_field_mask_absence(): + class Foo(proto.Message): + mask = proto.Field( + proto.MESSAGE, + number=1, + message=field_mask_pb2.FieldMask, + ) + + foo = Foo() + assert not foo.mask.paths + + +def test_timestamp_del(): + class Foo(proto.Message): + mask = proto.Field( + proto.MESSAGE, + number=1, + message=field_mask_pb2.FieldMask, + ) + + foo = Foo() + foo.mask = field_mask_pb2.FieldMask(paths=["f.b.d", "f.c"]) + + del foo.mask + assert not foo.mask.paths + + +def test_timestamp_to_python_idempotent(): + # This path can never run in the current configuration because proto + # values are the only thing ever saved, and `to_python` is a read method. + # + # However, we test idempotency for consistency with `to_proto` and + # general resiliency. + marshal = BaseMarshal() + py_value = "f.b.d,f.c" + assert marshal.to_python(field_mask_pb2.FieldMask, py_value) is py_value diff --git a/tests/test_marshal_types_dates.py b/tests/test_marshal_types_dates.py index 5fe09ab5..21841b99 100644 --- a/tests/test_marshal_types_dates.py +++ b/tests/test_marshal_types_dates.py @@ -98,6 +98,24 @@ class Foo(proto.Message): assert Foo.pb(foo).event_time.seconds == 1335020400 +def test_timestamp_write_string(): + class Foo(proto.Message): + event_time = proto.Field( + proto.MESSAGE, + number=1, + message=timestamp_pb2.Timestamp, + ) + + foo = Foo() + foo.event_time = "2012-04-21T15:00:00Z" + assert isinstance(foo.event_time, DatetimeWithNanoseconds) + assert isinstance(Foo.pb(foo).event_time, timestamp_pb2.Timestamp) + assert foo.event_time.year == 2012 + assert foo.event_time.month == 4 + assert foo.event_time.hour == 15 + assert Foo.pb(foo).event_time.seconds == 1335020400 + + def test_timestamp_rmw_nanos(): class Foo(proto.Message): event_time = proto.Field( @@ -207,6 +225,22 @@ class Foo(proto.Message): assert Foo.pb(foo).ttl.seconds == 120 +def test_duration_write_string(): + class Foo(proto.Message): + ttl = proto.Field( + proto.MESSAGE, + number=1, + message=duration_pb2.Duration, + ) + + foo = Foo() + foo.ttl = "120s" + assert isinstance(foo.ttl, timedelta) + assert isinstance(Foo.pb(foo).ttl, duration_pb2.Duration) + assert foo.ttl.seconds == 120 + assert Foo.pb(foo).ttl.seconds == 120 + + def test_duration_del(): class Foo(proto.Message): ttl = proto.Field(