Skip to content

Commit

Permalink
Allow users to define a new primary key. (#347)
Browse files Browse the repository at this point in the history
* make primary key programmable

Signed-off-by: wiseaidev <[email protected]>

* get primary key field using the `key` method

Signed-off-by: wiseaidev <[email protected]>

* adjust delete_many & expire methods

Signed-off-by: wiseaidev <[email protected]>

* fix query for int primary key

Signed-off-by: wiseaidev <[email protected]>

* fix grammar

Signed-off-by: wiseaidev <[email protected]>

* add unit tests

Signed-off-by: wiseaidev <[email protected]>

Signed-off-by: wiseaidev <[email protected]>
Co-authored-by: Chayim <[email protected]>
Co-authored-by: dvora-h <[email protected]>
  • Loading branch information
3 people authored Sep 8, 2022
1 parent 2e09234 commit 551429c
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 13 deletions.
17 changes: 11 additions & 6 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,10 @@ def resolve_value(
separator_char,
)
return ""
if separator_char in value:
if isinstance(value, int):
# This if will hit only if the field is a primary key of type int
result = f"@{field_name}:[{value} {value}]"
elif separator_char in value:
# The value contains the TAG field separator. We can work
# around this by breaking apart the values and unioning them
# with multiple field:{} queries.
Expand Down Expand Up @@ -1106,12 +1109,12 @@ class Config:
extra = "allow"

def __init__(__pydantic_self__, **data: Any) -> None:
super().__init__(**data)
__pydantic_self__.validate_primary_key()
super().__init__(**data)

def __lt__(self, other):
"""Default sort: compare primary key of models."""
return self.pk < other.pk
return self.key() < other.key()

def key(self):
"""Return the Redis key for this model."""
Expand Down Expand Up @@ -1150,7 +1153,7 @@ async def expire(
db = self._get_db(pipeline)

# TODO: Wrap any Redis response errors in a custom exception?
await db.expire(self.make_primary_key(self.pk), num_seconds)
await db.expire(self.key(), num_seconds)

@validator("pk", always=True, allow_reuse=True)
def validate_pk(cls, v):
Expand All @@ -1167,7 +1170,9 @@ def validate_primary_key(cls):
primary_keys += 1
if primary_keys == 0:
raise RedisModelError("You must define a primary key for the model")
elif primary_keys > 1:
elif primary_keys == 2:
cls.__fields__.pop('pk')
elif primary_keys > 2:
raise RedisModelError("You must define only one primary key for a model")

@classmethod
Expand Down Expand Up @@ -1275,7 +1280,7 @@ async def delete_many(
db = cls._get_db(pipeline)

for chunk in ichunked(models, 100):
pks = [cls.make_primary_key(model.pk) for model in chunk]
pks = [model.key() for model in chunk]
await cls._delete(db, *pks)

return len(models)
Expand Down
70 changes: 63 additions & 7 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Order(BaseHashModel):
created_on: datetime.datetime

class Member(BaseHashModel):
id: int = Field(index=True)
id: int = Field(index=True, primary_key=True)
first_name: str = Field(index=True)
last_name: str = Field(index=True)
email: str = Field(index=True)
Expand Down Expand Up @@ -445,7 +445,7 @@ async def test_saves_model_and_creates_pk(m):
# Save a model instance to Redis
await member.save()

member2 = await m.Member.get(member.pk)
member2 = await m.Member.get(pk=member.id)
assert member2 == member


Expand Down Expand Up @@ -495,7 +495,7 @@ async def test_delete(m):
)

await member.save()
response = await m.Member.delete(member.pk)
response = await m.Member.delete(pk=member.id)
assert response == 1


Expand Down Expand Up @@ -588,8 +588,8 @@ async def test_saves_many(m):
result = await m.Member.add(members)
assert result == [member1, member2]

assert await m.Member.get(pk=member1.pk) == member1
assert await m.Member.get(pk=member2.pk) == member2
assert await m.Member.get(pk=member1.id) == member1
assert await m.Member.get(pk=member2.id) == member2


@py_test_mark_asyncio
Expand Down Expand Up @@ -618,14 +618,14 @@ async def test_delete_many(m):
result = await m.Member.delete_many(members)
assert result == 2
with pytest.raises(NotFoundError):
await m.Member.get(pk=member1.pk)
await m.Member.get(pk=member1.key())


@py_test_mark_asyncio
async def test_updates_a_model(members, m):
member1, member2, member3 = members
await member1.update(last_name="Smith")
member = await m.Member.get(member1.pk)
member = await m.Member.get(member1.id)
assert member.last_name == "Smith"


Expand Down Expand Up @@ -681,3 +681,59 @@ class Address(m.BaseHashModel):
Address.redisearch_schema()
== f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC"
)


@py_test_mark_asyncio
async def test_primary_key_model_error(m):

class Customer(m.BaseHashModel):
id: int = Field(primary_key=True, index=True)
first_name: str = Field(primary_key=True, index=True)
last_name: str
bio: Optional[str]

await Migrator().run()

with pytest.raises(RedisModelError, match="You must define only one primary key for a model"):
_ = Customer(
id=0,
first_name="Mahmoud",
last_name="Harmouch",
bio="Python developer, wanna work at Redis, Inc."
)


@py_test_mark_asyncio
async def test_primary_pk_exists(m):

class Customer1(m.BaseHashModel):
id: int
first_name: str
last_name: str
bio: Optional[str]

class Customer2(m.BaseHashModel):
id: int = Field(primary_key=True, index=True)
first_name: str
last_name: str
bio: Optional[str]

await Migrator().run()

customer = Customer1(
id=0,
first_name="Mahmoud",
last_name="Harmouch",
bio="Python developer, wanna work at Redis, Inc."
)

assert 'pk' in customer.__fields__

customer = Customer2(
id=1,
first_name="Kim",
last_name="Brookins",
bio="This is member 2 who can be quite anxious until you get to know them.",
)

assert 'pk' not in customer.__fields__

0 comments on commit 551429c

Please sign in to comment.