Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add delete_many to support for bulk deletes #305

Merged
merged 13 commits into from
Aug 10, 2022
66 changes: 42 additions & 24 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
no_type_check,
)

from more_itertools import ichunked
from pydantic import BaseModel, validator
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
Expand Down Expand Up @@ -1117,9 +1118,17 @@ def key(self):
return self.make_primary_key(pk)

@classmethod
async def delete(cls, pk: Any) -> int:
async def _delete(cls, db, *pks):
return await db.delete(*pks)

@classmethod
async def delete(
cls, pk: Any, pipeline: Optional[redis.client.Pipeline] = None
) -> int:
"""Delete data at this key."""
return await cls.db().delete(cls.make_primary_key(pk))
db = cls._get_db(pipeline)

return await cls._delete(db, cls.make_primary_key(pk))

@classmethod
async def get(cls, pk: Any) -> "RedisModel":
Expand All @@ -1137,10 +1146,7 @@ async def save(
async def expire(
self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
):
if pipeline is None:
db = self.db()
else:
db = pipeline
db = self._get_db(pipeline)

# TODO: Wrap any Redis response errors in a custom exception?
await db.expire(self.make_primary_key(self.pk), num_seconds)
Expand Down Expand Up @@ -1232,16 +1238,7 @@ async def add(
pipeline: Optional[redis.client.Pipeline] = None,
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
) -> Sequence["RedisModel"]:
if pipeline is None:
# By default, send commands in a pipeline. Saving each model will
# be atomic, but Redis may process other commands in between
# these saves.
db = cls.db().pipeline(transaction=False)
else:
# If the user gave us a pipeline, add our commands to that. The user
# will be responsible for executing the pipeline after they've accumulated
# the commands they want to send.
db = pipeline
db = cls._get_db(pipeline, bulk=True)

for model in models:
# save() just returns the model, we don't need that here.
Expand All @@ -1255,6 +1252,31 @@ async def add(

return models

@classmethod
def _get_db(
self, pipeline: Optional[redis.client.Pipeline] = None, bulk: bool = False
):
if pipeline is not None:
return pipeline
elif bulk:
return self.db().pipeline(transaction=False)
else:
return self.db()

@classmethod
async def delete_many(
cls,
models: Sequence["RedisModel"],
pipeline: Optional[redis.client.Pipeline] = None,
) -> int:
db = cls._get_db(pipeline)

for chunk in ichunked(models, 100):
pks = [cls.make_primary_key(model.pk) for model in chunk]
await cls._delete(db, *pks)

return len(models)

@classmethod
def redisearch_schema(cls):
raise NotImplementedError
Expand Down Expand Up @@ -1293,10 +1315,8 @@ async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "HashModel":
self.check()
if pipeline is None:
db = self.db()
else:
db = pipeline
db = self._get_db(pipeline)

document = jsonable_encoder(self.dict())
# TODO: Wrap any Redis response errors in a custom exception?
await db.hset(self.key(), mapping=document)
Expand Down Expand Up @@ -1467,10 +1487,8 @@ async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "JsonModel":
self.check()
if pipeline is None:
db = self.db()
else:
db = pipeline
db = self._get_db(pipeline)

# TODO: Wrap response errors in a custom exception?
await db.execute_command("JSON.SET", self.key(), ".", self.json())
return self
Expand Down
Loading