Skip to content

Commit

Permalink
feat: add casesensitive support (#608)
Browse files Browse the repository at this point in the history
* feat: add casesensitive support

* snake_case

---------

Co-authored-by: slorello89 <[email protected]>
  • Loading branch information
ninoseki and slorello89 authored May 3, 2024
1 parent f1ed5b2 commit 5ef3d27
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 7 deletions.
16 changes: 16 additions & 0 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,12 +1032,14 @@ class FieldInfo(PydanticFieldInfo):
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False)
sortable = kwargs.pop("sortable", Undefined)
case_sensitive = kwargs.pop("case_sensitive", Undefined)
index = kwargs.pop("index", Undefined)
full_text_search = kwargs.pop("full_text_search", Undefined)
vector_options = kwargs.pop("vector_options", None)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.sortable = sortable
self.case_sensitive = case_sensitive
self.index = index
self.full_text_search = full_text_search
self.vector_options = vector_options
Expand Down Expand Up @@ -1169,6 +1171,7 @@ def Field(
regex: Optional[str] = None,
primary_key: bool = False,
sortable: Union[bool, UndefinedType] = Undefined,
case_sensitive: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
full_text_search: Union[bool, UndefinedType] = Undefined,
vector_options: Optional[VectorFieldOptions] = None,
Expand Down Expand Up @@ -1197,6 +1200,7 @@ def Field(
regex=regex,
primary_key=primary_key,
sortable=sortable,
case_sensitive=case_sensitive,
index=index,
full_text_search=full_text_search,
vector_options=vector_options,
Expand Down Expand Up @@ -1764,6 +1768,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
# TODO: Abstract string-building logic for each type (TAG, etc.) into
# classes that take a field name.
sortable = getattr(field_info, "sortable", False)
case_sensitive = getattr(field_info, "case_sensitive", False)

if is_supported_container_type(typ):
embedded_cls = get_args(typ)
Expand Down Expand Up @@ -1804,6 +1809,9 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
schema = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if schema and sortable is True:
schema += " SORTABLE"
if schema and case_sensitive is True:
schema += " CASESENSITIVE"

return schema


Expand Down Expand Up @@ -2046,6 +2054,7 @@ def schema_for_type(
else:
path = f"{json_path}.{name}"
sortable = getattr(field_info, "sortable", False)
case_sensitive = getattr(field_info, "case_sensitive", False)
full_text_search = getattr(field_info, "full_text_search", False)
sortable_tag_error = RedisModelError(
"In this Preview release, TAG fields cannot "
Expand Down Expand Up @@ -2076,6 +2085,8 @@ def schema_for_type(
schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if sortable is True:
raise sortable_tag_error
if case_sensitive is True:
schema += " CASESENSITIVE"
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
schema = f"{path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str):
Expand All @@ -2091,14 +2102,19 @@ def schema_for_type(
# search queries can be sorted, but not exact match
# queries.
schema += " SORTABLE"
if case_sensitive is True:
raise RedisModelError("Text fields cannot be case-sensitive.")
else:
schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if sortable is True:
raise sortable_tag_error
if case_sensitive is True:
schema += " CASESENSITIVE"
else:
schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if sortable is True:
raise sortable_tag_error

return schema
return ""

Expand Down
13 changes: 12 additions & 1 deletion tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Order(BaseHashModel):

class Member(BaseHashModel):
id: int = Field(index=True, primary_key=True)
first_name: str = Field(index=True)
first_name: str = Field(index=True, case_sensitive=True)
last_name: str = Field(index=True)
email: str = Field(index=True)
join_date: datetime.date
Expand Down Expand Up @@ -385,6 +385,17 @@ async def test_sorting(members, m):
await m.Member.find().sort_by("join_date").all()


@py_test_mark_asyncio
async def test_case_sensitive(members, m):
member1, member2, member3 = members

actual = await m.Member.find(m.Member.first_name == "Andrew").all()
assert actual == [member1, member3]

actual = await m.Member.find(m.Member.first_name == "andrew").all()
assert actual == []


def test_validates_required_fields(m):
# Raises ValidationError: last_name is required
# TODO: Test the error value
Expand Down
23 changes: 17 additions & 6 deletions tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Order(EmbeddedJsonModel):
created_on: datetime.datetime

class Member(BaseJsonModel):
first_name: str = Field(index=True)
first_name: str = Field(index=True, case_sensitive=True)
last_name: str = Field(index=True)
email: Optional[EmailStr] = Field(index=True, default=None)
join_date: datetime.date
Expand Down Expand Up @@ -454,15 +454,15 @@ async def test_in_query(members, m):
)
assert actual == [member2, member1, member3]


@py_test_mark_asyncio
async def test_not_in_query(members, m):
member1, member2, member3 = members
actual = await (
m.Member.find(m.Member.pk >> [member2.pk, member3.pk])
.sort_by("age")
.all()
m.Member.find(m.Member.pk >> [member2.pk, member3.pk]).sort_by("age").all()
)
assert actual == [ member1]
assert actual == [member1]


@py_test_mark_asyncio
async def test_update_query(members, m):
Expand Down Expand Up @@ -749,6 +749,17 @@ async def test_sorting(members, m):
await m.Member.find().sort_by("join_date").all()


@py_test_mark_asyncio
async def test_case_sensitive(members, m):
member1, member2, member3 = members

actual = await m.Member.find(m.Member.first_name == "Andrew").all()
assert actual == [member1, member3]

actual = await m.Member.find(m.Member.first_name == "andrew").all()
assert actual == []


@py_test_mark_asyncio
async def test_not_found(m):
with pytest.raises(NotFoundError):
Expand Down Expand Up @@ -873,7 +884,7 @@ async def test_schema(m, key_prefix):
key_prefix = m.Member.make_key(m.Member._meta.primary_key_pattern.format(pk=""))
assert (
m.Member.redisearch_schema()
== f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR | $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |"
== f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | CASESENSITIVE $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR | $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |"
)


Expand Down

0 comments on commit 5ef3d27

Please sign in to comment.