Skip to content

Commit

Permalink
Use TypeVars for return types of RedisModel and its subtype's methods (
Browse files Browse the repository at this point in the history
…#476)

Co-authored-by: Chayim <[email protected]>
  • Loading branch information
marianhlavac and chayim authored Jul 12, 2023
1 parent 89b6c84 commit c68adac
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

model_registry = {}
_T = TypeVar("_T")
Model = TypeVar("Model", bound="RedisModel")
log = logging.getLogger(__name__)
escaper = TokenEscaper()

Expand Down Expand Up @@ -1310,16 +1311,16 @@ async def delete(
return await cls._delete(db, cls.make_primary_key(pk))

@classmethod
async def get(cls, pk: Any) -> "RedisModel":
async def get(cls: Type["Model"], pk: Any) -> "Model":
raise NotImplementedError

async def update(self, **field_values):
"""Update this model instance with the specified key-value pairs."""
raise NotImplementedError

async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "RedisModel":
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "Model":
raise NotImplementedError

async def expire(
Expand Down Expand Up @@ -1423,11 +1424,11 @@ def get_annotations(cls):

@classmethod
async def add(
cls,
models: Sequence["RedisModel"],
cls: Type["Model"],
models: Sequence["Model"],
pipeline: Optional[redis.client.Pipeline] = None,
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
) -> Sequence["RedisModel"]:
) -> Sequence["Model"]:
db = cls._get_db(pipeline, bulk=True)

for model in models:
Expand Down Expand Up @@ -1502,8 +1503,8 @@ def __init_subclass__(cls, **kwargs):
)

async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "HashModel":
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "Model":
self.check()
db = self._get_db(pipeline)

Expand All @@ -1525,7 +1526,7 @@ async def all_pks(cls): # type: ignore
)

@classmethod
async def get(cls, pk: Any) -> "HashModel":
async def get(cls: Type["Model"], pk: Any) -> "Model":
document = await cls.db().hgetall(cls.make_primary_key(pk))
if not document:
raise NotFoundError
Expand Down Expand Up @@ -1676,8 +1677,8 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "JsonModel":
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "Model":
self.check()
db = self._get_db(pipeline)

Expand Down Expand Up @@ -1722,7 +1723,7 @@ async def update(self, **field_values):
await self.save()

@classmethod
async def get(cls, pk: Any) -> "JsonModel":
async def get(cls: Type["Model"], pk: Any) -> "Model":
document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
if document == "null":
raise NotFoundError
Expand Down

0 comments on commit c68adac

Please sign in to comment.