From 54489781392598514d0dacff3a5bb76bdbd7a54a Mon Sep 17 00:00:00 2001 From: henribru <6639509+henribru@users.noreply.github.com> Date: Sat, 11 Nov 2023 15:31:45 +0100 Subject: [PATCH] Improve message base class --- proto-stubs/message.pyi | 48 ++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/proto-stubs/message.pyi b/proto-stubs/message.pyi index 8cbd6f7..823769b 100644 --- a/proto-stubs/message.pyi +++ b/proto-stubs/message.pyi @@ -1,49 +1,53 @@ -from typing import Any, List, Type +from typing import Any, TypeVar, overload +from collections.abc import Mapping from google.protobuf import descriptor_pb2, message from proto.fields import Field from proto.marshal import Marshal +_M = TypeVar("_M") + class MessageMeta(type): def __new__(mcls, name, bases, attrs): ... @classmethod def __prepare__(mcls, name, bases, **kwargs): ... @property def meta(cls): ... - def pb(cls, obj: Any | None = ..., *, coerce: bool = ...): ... - def wrap(cls, pb): ... - def serialize(cls, instance) -> bytes: ... - def deserialize(cls, payload: bytes) -> Message: ... + @overload + def pb(cls: type[_M], obj: None = ..., *, coerce: bool = ...) -> type[message.Message]: ... + @overload + def pb(cls: type[_M], obj: _M, *, coerce: bool = ...) -> message.Message: ... + def wrap(cls: type[_M], pb: message.Message) -> _M: ... + def serialize(cls: type[_M], instance: _M | Mapping | message.Message) -> bytes: ... + def deserialize(cls: type[_M], payload: bytes) -> _M: ... def to_json( - cls, - instance, + cls: type[_M], + instance: _M, *, use_integers_for_enums: bool = ..., including_default_value_fields: bool = ..., preserving_proto_field_name: bool = ... ) -> str: ... - def from_json(cls, payload, *, ignore_unknown_fields: bool = ...) -> Message: ... + def from_json(cls: type[_M], payload: str, *, ignore_unknown_fields: bool = ...) -> _M: ... def to_dict( - cls, - instance, + cls: type[_M], + instance: _M, *, use_integers_for_enums: bool = ..., preserving_proto_field_name: bool = ... - ) -> Message: ... - def copy_from(cls, instance, other) -> None: ... + ) -> dict[str, Any]: ... + def copy_from(cls: type[_M], instance: _M | Mapping | message.Message, other) -> None: ... class Message(metaclass=MessageMeta): def __init__( - self, mapping: Any | None = ..., *, ignore_unknown_fields: bool = ..., **kwargs + self: _M, mapping: _M | Mapping | message.Message | None = ..., *, ignore_unknown_fields: bool = ..., **kwargs ) -> None: ... - def __bool__(self): ... - def __contains__(self, key): ... - def __delattr__(self, key) -> None: ... - def __eq__(self, other): ... - def __getattr__(self, key): ... - def __ne__(self, other): ... - def __setattr__(self, key, value): ... + def __bool__(self) -> bool: ... + def __contains__(self, key: str) -> bool: ... + def __delattr__(self, key: str) -> None: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... class _MessageInfo: package: Any @@ -54,11 +58,11 @@ class _MessageInfo: marshal: Any def __init__( self, - fields: List[Field], + fields: list[Field], package: str, full_name: str, marshal: Marshal, options: descriptor_pb2.MessageOptions, ) -> None: ... @property - def pb(self) -> Type[message.Message]: ... + def pb(self) -> type[message.Message]: ...