From 4c88d48830fb5a5532da56d1f5ccf20a4c1bfa74 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Fri, 7 Jun 2024 12:45:39 +0200 Subject: [PATCH] fix: fix integer index conversion --- qdrant_client/conversions/conversion.py | 24 ++++++++++++++++ tests/conversions/fixtures.py | 37 +++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/qdrant_client/conversions/conversion.py b/qdrant_client/conversions/conversion.py index 8530cd2d..fd4dd060 100644 --- a/qdrant_client/conversions/conversion.py +++ b/qdrant_client/conversions/conversion.py @@ -350,6 +350,9 @@ def convert_payload_schema_params( if model.HasField("text_index_params"): text_index_params = model.text_index_params return cls.convert_text_index_params(text_index_params) + if model.HasField("integer_index_params"): + integer_index_params = model.integer_index_params + return cls.convert_integer_index_params(integer_index_params) raise ValueError(f"invalid PayloadIndexParams model: {model}") # pragma: no cover @@ -988,6 +991,16 @@ def convert_text_index_params(cls, model: grpc.TextIndexParams) -> rest.TextInde lowercase=model.lowercase if model.HasField("lowercase") else None, ) + @classmethod + def convert_integer_index_params( + cls, model: grpc.IntegerIndexParams + ) -> rest.IntegerIndexParams: + return rest.IntegerIndexParams( + type=rest.IntegerIndexType.INTEGER, + range=model.range, + lookup=model.lookup, + ) + @classmethod def convert_collection_params_diff( cls, model: grpc.CollectionParamsDiff @@ -1578,6 +1591,11 @@ def convert_payload_schema_params( if isinstance(model, rest.TextIndexParams): return grpc.PayloadIndexParams(text_index_params=cls.convert_text_index_params(model)) + if isinstance(model, rest.IntegerIndexParams): + return grpc.PayloadIndexParams( + integer_index_params=cls.convert_integer_index_params(model) + ) + raise ValueError(f"invalid PayloadSchemaParams model: {model}") # pragma: no cover @classmethod @@ -2402,6 +2420,12 @@ def convert_text_index_params(cls, model: rest.TextIndexParams) -> grpc.TextInde max_token_len=model.max_token_len, ) + @classmethod + def convert_integer_index_params( + cls, model: rest.IntegerIndexParams + ) -> grpc.IntegerIndexParams: + return grpc.IntegerIndexParams(lookup=model.lookup, range=model.range) + @classmethod def convert_collection_params_diff( cls, model: rest.CollectionParamsDiff diff --git a/tests/conversions/fixtures.py b/tests/conversions/fixtures.py index 1771de6e..16896d5b 100644 --- a/tests/conversions/fixtures.py +++ b/tests/conversions/fixtures.py @@ -340,6 +340,11 @@ text_index_params_4 = grpc.TextIndexParams(tokenizer=grpc.TokenizerType.Multilingual) +integer_index_params_0 = grpc.IntegerIndexParams(lookup=True, range=False) +integer_index_params_1 = grpc.IntegerIndexParams(lookup=False, range=True) +integer_index_params_2 = grpc.IntegerIndexParams(lookup=True, range=True) + + payload_schema_text_prefix = grpc.PayloadSchemaInfo( data_type=grpc.PayloadSchemaType.Text, params=grpc.PayloadIndexParams(text_index_params=text_index_params_1), @@ -362,6 +367,24 @@ points=0, ) +payload_schema_integer_lookup = grpc.PayloadSchemaInfo( + data_type=grpc.PayloadSchemaType.Integer, + params=grpc.PayloadIndexParams(integer_index_params=integer_index_params_0), + points=0, +) + +payload_schema_integer_range = grpc.PayloadSchemaInfo( + data_type=grpc.PayloadSchemaType.Integer, + params=grpc.PayloadIndexParams(integer_index_params=integer_index_params_1), + points=0, +) + +payload_schema_integer_lookup_and_range = grpc.PayloadSchemaInfo( + data_type=grpc.PayloadSchemaType.Integer, + params=grpc.PayloadIndexParams(integer_index_params=integer_index_params_2), + points=0, +) + collection_info_grey = grpc.CollectionInfo( status=collection_status_grey, optimizer_status=optimizer_status_error, @@ -391,6 +414,9 @@ "text_field_multilingual": payload_schema_text_multilingual, "bool_field": payload_schema_bool, "datetime_field": payload_schema_datetime, + "integer_lookup": payload_schema_integer_lookup, + "integer_range": payload_schema_integer_range, + "integer_lookup_and_range": payload_schema_integer_lookup_and_range, }, ) @@ -413,6 +439,9 @@ "text_field_multilingual": payload_schema_text_multilingual, "bool_field": payload_schema_bool, "datetime_field": payload_schema_datetime, + "integer_lookup": payload_schema_integer_lookup, + "integer_range": payload_schema_integer_range, + "integer_lookup_and_range": payload_schema_integer_lookup_and_range, }, ) @@ -435,6 +464,9 @@ "text_field_multilingual": payload_schema_text_multilingual, "bool_field": payload_schema_bool, "datetime_field": payload_schema_datetime, + "integer_lookup": payload_schema_integer_lookup, + "integer_range": payload_schema_integer_range, + "integer_lookup_and_range": payload_schema_integer_lookup_and_range, }, ) quantization_config = grpc.QuantizationConfig( @@ -967,6 +999,11 @@ text_index_params_2, text_index_params_3, ], + "IntegerIndexParams": [ + integer_index_params_0, + integer_index_params_1, + integer_index_params_2, + ], "CollectionParamsDiff": [collections_params_diff], "LookupLocation": [lookup_location_1, lookup_location_2], "ReadConsistency": [