From 25246ac86cace73bd167f84f30b8739ec48c3c3b Mon Sep 17 00:00:00 2001 From: Manabu Niseki Date: Fri, 3 May 2024 11:15:09 +0900 Subject: [PATCH 1/2] feat: add casesensitive support --- aredis_om/model/model.py | 16 ++++++++++++++++ tests/test_hash_model.py | 13 ++++++++++++- tests/test_json_model.py | 23 +++++++++++++++++------ 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 31c42bd..5748a22 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1032,12 +1032,14 @@ class FieldInfo(PydanticFieldInfo): def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) sortable = kwargs.pop("sortable", Undefined) + casesensitive = kwargs.pop("casesensitive", Undefined) index = kwargs.pop("index", Undefined) full_text_search = kwargs.pop("full_text_search", Undefined) vector_options = kwargs.pop("vector_options", None) super().__init__(default=default, **kwargs) self.primary_key = primary_key self.sortable = sortable + self.casesensitive = casesensitive self.index = index self.full_text_search = full_text_search self.vector_options = vector_options @@ -1169,6 +1171,7 @@ def Field( regex: Optional[str] = None, primary_key: bool = False, sortable: Union[bool, UndefinedType] = Undefined, + casesensitive: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, full_text_search: Union[bool, UndefinedType] = Undefined, vector_options: Optional[VectorFieldOptions] = None, @@ -1197,6 +1200,7 @@ def Field( regex=regex, primary_key=primary_key, sortable=sortable, + casesensitive=casesensitive, index=index, full_text_search=full_text_search, vector_options=vector_options, @@ -1764,6 +1768,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): # TODO: Abstract string-building logic for each type (TAG, etc.) into # classes that take a field name. sortable = getattr(field_info, "sortable", False) + casesensitive = getattr(field_info, "casesensitive", False) if is_supported_container_type(typ): embedded_cls = get_args(typ) @@ -1804,6 +1809,9 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): schema = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" if schema and sortable is True: schema += " SORTABLE" + if schema and casesensitive is True: + schema += " CASESENSITIVE" + return schema @@ -2046,6 +2054,7 @@ def schema_for_type( else: path = f"{json_path}.{name}" sortable = getattr(field_info, "sortable", False) + casesensitive = getattr(field_info, "casesensitive", False) full_text_search = getattr(field_info, "full_text_search", False) sortable_tag_error = RedisModelError( "In this Preview release, TAG fields cannot " @@ -2076,6 +2085,8 @@ def schema_for_type( schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" if sortable is True: raise sortable_tag_error + if casesensitive is True: + schema += " CASESENSITIVE" elif any(issubclass(typ, t) for t in NUMERIC_TYPES): schema = f"{path} AS {index_field_name} NUMERIC" elif issubclass(typ, str): @@ -2091,14 +2102,19 @@ def schema_for_type( # search queries can be sorted, but not exact match # queries. schema += " SORTABLE" + if casesensitive is True: + raise RedisModelError("Text fields cannot be case-sensitive.") else: schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" if sortable is True: raise sortable_tag_error + if casesensitive is True: + schema += " CASESENSITIVE" else: schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" if sortable is True: raise sortable_tag_error + return schema return "" diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index f7aee62..74d299c 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -48,7 +48,7 @@ class Order(BaseHashModel): class Member(BaseHashModel): id: int = Field(index=True, primary_key=True) - first_name: str = Field(index=True) + first_name: str = Field(index=True, casesensitive=True) last_name: str = Field(index=True) email: str = Field(index=True) join_date: datetime.date @@ -385,6 +385,17 @@ async def test_sorting(members, m): await m.Member.find().sort_by("join_date").all() +@py_test_mark_asyncio +async def test_casesensitive(members, m): + member1, member2, member3 = members + + actual = await m.Member.find(m.Member.first_name == "Andrew").all() + assert actual == [member1, member3] + + actual = await m.Member.find(m.Member.first_name == "andrew").all() + assert actual == [] + + def test_validates_required_fields(m): # Raises ValidationError: last_name is required # TODO: Test the error value diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 55e8b0f..ca87c85 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -67,7 +67,7 @@ class Order(EmbeddedJsonModel): created_on: datetime.datetime class Member(BaseJsonModel): - first_name: str = Field(index=True) + first_name: str = Field(index=True, casesensitive=True) last_name: str = Field(index=True) email: Optional[EmailStr] = Field(index=True, default=None) join_date: datetime.date @@ -454,15 +454,15 @@ async def test_in_query(members, m): ) assert actual == [member2, member1, member3] + @py_test_mark_asyncio async def test_not_in_query(members, m): member1, member2, member3 = members actual = await ( - m.Member.find(m.Member.pk >> [member2.pk, member3.pk]) - .sort_by("age") - .all() + m.Member.find(m.Member.pk >> [member2.pk, member3.pk]).sort_by("age").all() ) - assert actual == [ member1] + assert actual == [member1] + @py_test_mark_asyncio async def test_update_query(members, m): @@ -749,6 +749,17 @@ async def test_sorting(members, m): await m.Member.find().sort_by("join_date").all() +@py_test_mark_asyncio +async def test_casesensitive(members, m): + member1, member2, member3 = members + + actual = await m.Member.find(m.Member.first_name == "Andrew").all() + assert actual == [member1, member3] + + actual = await m.Member.find(m.Member.first_name == "andrew").all() + assert actual == [] + + @py_test_mark_asyncio async def test_not_found(m): with pytest.raises(NotFoundError): @@ -873,7 +884,7 @@ async def test_schema(m, key_prefix): key_prefix = m.Member.make_key(m.Member._meta.primary_key_pattern.format(pk="")) assert ( m.Member.redisearch_schema() - == f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR | $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |" + == f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | CASESENSITIVE $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR | $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |" ) From 9b7c73abbc2b63d30160e13e14040b9926aec2c3 Mon Sep 17 00:00:00 2001 From: slorello89 Date: Fri, 3 May 2024 10:33:44 -0400 Subject: [PATCH 2/2] snake_case --- aredis_om/model/model.py | 20 ++++++++++---------- tests/test_hash_model.py | 4 ++-- tests/test_json_model.py | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 5748a22..acfd956 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1032,14 +1032,14 @@ class FieldInfo(PydanticFieldInfo): def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) sortable = kwargs.pop("sortable", Undefined) - casesensitive = kwargs.pop("casesensitive", Undefined) + case_sensitive = kwargs.pop("case_sensitive", Undefined) index = kwargs.pop("index", Undefined) full_text_search = kwargs.pop("full_text_search", Undefined) vector_options = kwargs.pop("vector_options", None) super().__init__(default=default, **kwargs) self.primary_key = primary_key self.sortable = sortable - self.casesensitive = casesensitive + self.case_sensitive = case_sensitive self.index = index self.full_text_search = full_text_search self.vector_options = vector_options @@ -1171,7 +1171,7 @@ def Field( regex: Optional[str] = None, primary_key: bool = False, sortable: Union[bool, UndefinedType] = Undefined, - casesensitive: Union[bool, UndefinedType] = Undefined, + case_sensitive: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, full_text_search: Union[bool, UndefinedType] = Undefined, vector_options: Optional[VectorFieldOptions] = None, @@ -1200,7 +1200,7 @@ def Field( regex=regex, primary_key=primary_key, sortable=sortable, - casesensitive=casesensitive, + case_sensitive=case_sensitive, index=index, full_text_search=full_text_search, vector_options=vector_options, @@ -1768,7 +1768,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): # TODO: Abstract string-building logic for each type (TAG, etc.) into # classes that take a field name. sortable = getattr(field_info, "sortable", False) - casesensitive = getattr(field_info, "casesensitive", False) + case_sensitive = getattr(field_info, "case_sensitive", False) if is_supported_container_type(typ): embedded_cls = get_args(typ) @@ -1809,7 +1809,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): schema = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" if schema and sortable is True: schema += " SORTABLE" - if schema and casesensitive is True: + if schema and case_sensitive is True: schema += " CASESENSITIVE" return schema @@ -2054,7 +2054,7 @@ def schema_for_type( else: path = f"{json_path}.{name}" sortable = getattr(field_info, "sortable", False) - casesensitive = getattr(field_info, "casesensitive", False) + case_sensitive = getattr(field_info, "case_sensitive", False) full_text_search = getattr(field_info, "full_text_search", False) sortable_tag_error = RedisModelError( "In this Preview release, TAG fields cannot " @@ -2085,7 +2085,7 @@ def schema_for_type( schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" if sortable is True: raise sortable_tag_error - if casesensitive is True: + if case_sensitive is True: schema += " CASESENSITIVE" elif any(issubclass(typ, t) for t in NUMERIC_TYPES): schema = f"{path} AS {index_field_name} NUMERIC" @@ -2102,13 +2102,13 @@ def schema_for_type( # search queries can be sorted, but not exact match # queries. schema += " SORTABLE" - if casesensitive is True: + if case_sensitive is True: raise RedisModelError("Text fields cannot be case-sensitive.") else: schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" if sortable is True: raise sortable_tag_error - if casesensitive is True: + if case_sensitive is True: schema += " CASESENSITIVE" else: schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 74d299c..05d9872 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -48,7 +48,7 @@ class Order(BaseHashModel): class Member(BaseHashModel): id: int = Field(index=True, primary_key=True) - first_name: str = Field(index=True, casesensitive=True) + first_name: str = Field(index=True, case_sensitive=True) last_name: str = Field(index=True) email: str = Field(index=True) join_date: datetime.date @@ -386,7 +386,7 @@ async def test_sorting(members, m): @py_test_mark_asyncio -async def test_casesensitive(members, m): +async def test_case_sensitive(members, m): member1, member2, member3 = members actual = await m.Member.find(m.Member.first_name == "Andrew").all() diff --git a/tests/test_json_model.py b/tests/test_json_model.py index ca87c85..e7c6ca6 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -67,7 +67,7 @@ class Order(EmbeddedJsonModel): created_on: datetime.datetime class Member(BaseJsonModel): - first_name: str = Field(index=True, casesensitive=True) + first_name: str = Field(index=True, case_sensitive=True) last_name: str = Field(index=True) email: Optional[EmailStr] = Field(index=True, default=None) join_date: datetime.date @@ -750,7 +750,7 @@ async def test_sorting(members, m): @py_test_mark_asyncio -async def test_casesensitive(members, m): +async def test_case_sensitive(members, m): member1, member2, member3 = members actual = await m.Member.find(m.Member.first_name == "Andrew").all()