diff --git a/proto/message.py b/proto/message.py index 97ee3814..5135a542 100644 --- a/proto/message.py +++ b/proto/message.py @@ -273,6 +273,27 @@ def __prepare__(mcls, name, bases, **kwargs): def meta(cls): return cls._meta + def __dir__(self): + names = set(dir(type)) + names.update( + ( + "meta", + "pb", + "wrap", + "serialize", + "deserialize", + "to_json", + "from_json", + "to_dict", + "copy_from", + ) + ) + desc = self.pb().DESCRIPTOR + names.update(t.name for t in desc.nested_types) + names.update(e.name for e in desc.enum_types) + + return names + def pb(cls, obj=None, *, coerce: bool = False): """Return the underlying protobuf Message class or instance. @@ -520,6 +541,29 @@ def __init__( # Create the internal protocol buffer. super().__setattr__("_pb", self._meta.pb(**params)) + def __dir__(self): + desc = type(self).pb().DESCRIPTOR + names = {f_name for f_name in self._meta.fields.keys()} + names.update(m.name for m in desc.nested_types) + names.update(e.name for e in desc.enum_types) + names.update(dir(object())) + # Can't think of a better way of determining + # the special methods than manually listing them. + names.update( + ( + "__bool__", + "__contains__", + "__dict__", + "__getattr__", + "__getstate__", + "__module__", + "__setstate__", + "__weakref__", + ) + ) + + return names + def __bool__(self): """Return True if any field is truthy, False otherwise.""" return any(k in self and getattr(self, k) for k in self._meta.fields.keys()) diff --git a/tests/test_message.py b/tests/test_message.py index 5351dbd7..843fad22 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -346,3 +346,72 @@ class Squid(proto.Message): with pytest.raises(TypeError): Mollusc.Squid.copy_from(m.squid, (("mass_kg", 20))) + + +def test_dir(): + class Mollusc(proto.Message): + class Class(proto.Enum): + UNKNOWN = 0 + GASTROPOD = 1 + BIVALVE = 2 + CEPHALOPOD = 3 + + class Arm(proto.Message): + length_cm = proto.Field(proto.INT32, number=1) + + mass_kg = proto.Field(proto.INT32, number=1) + class_ = proto.Field(Class, number=2) + arms = proto.RepeatedField(Arm, number=3) + + expected = ( + { + # Fields and nested message and enum types + "arms", + "class_", + "mass_kg", + "Arm", + "Class", + } + | { + # Other methods and attributes + "__bool__", + "__contains__", + "__dict__", + "__getattr__", + "__getstate__", + "__module__", + "__setstate__", + "__weakref__", + } + | set(dir(object)) + ) # Gets the long tail of dunder methods and attributes. + + actual = set(dir(Mollusc())) + + # Check instance names + assert actual == expected + + # Check type names + expected = ( + set(dir(type)) + | { + # Class methods from the MessageMeta metaclass + "copy_from", + "deserialize", + "from_json", + "meta", + "pb", + "serialize", + "to_dict", + "to_json", + "wrap", + } + | { + # Nested message and enum types + "Arm", + "Class", + } + ) + + actual = set(dir(Mollusc)) + assert actual == expected