diff --git a/vectordb_orm/indexes.py b/vectordb_orm/indexes.py index e02bd84..2bd01c2 100644 --- a/vectordb_orm/indexes.py +++ b/vectordb_orm/indexes.py @@ -49,6 +49,11 @@ def _assert_metric_type(self, metric_type: FloatSimilarityMetric | BinarySimilar if not isinstance (self, tuple(BINARY_INDEXES)): raise ValueError(f"Index type {self} is not supported for metric type {metric_type}") + def _assert_cluster_units_and_inference_comparison(self, cluster_units: int, inference_comparison: int | None) -> tuple[int, int]: + if not (cluster_units >= 1 and cluster_units <= 65536): + raise ValueError("cluster_units must be between 1 and 65536") + if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units): + raise ValueError("inference_comparison must be between 1 and cluster_units") class FLAT(IndexBase): """ @@ -84,10 +89,8 @@ def __init__( """ super().__init__(metric_type=metric_type) - if not (cluster_units >= 1 and cluster_units <= 65536): - raise ValueError("cluster_units must be between 1 and 65536") - if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units): - raise ValueError("inference_comparison must be between 1 and cluster_units") + self._assert_cluster_units_and_inference_comparison(cluster_units, inference_comparison) + self.nlist = cluster_units self.nprobe = inference_comparison or cluster_units @@ -119,10 +122,8 @@ def __init__( """ super().__init__(metric_type=metric_type) - if not (cluster_units >= 1 and cluster_units <= 65536): - raise ValueError("cluster_units must be between 1 and 65536") - if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units): - raise ValueError("inference_comparison must be between 1 and cluster_units") + self._assert_cluster_units_and_inference_comparison(cluster_units, inference_comparison) + self.nlist = cluster_units self.nprobe = inference_comparison or cluster_units @@ -156,12 +157,9 @@ def __init__( """ super().__init__(metric_type=metric_type) - if not (cluster_units >= 1 and cluster_units <= 65536): - raise ValueError("cluster_units must be between 1 and 65536") - if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units): - raise ValueError("inference_comparison must be between 1 and cluster_units") - if low_dimension_bits is not None and not (low_dimension_bits >= 1 and low_dimension_bits <= 16): - raise ValueError("low_dimension_bits must be between 1 and 16") + self._assert_cluster_units_and_inference_comparison(cluster_units, inference_comparison) + self._assert_low_dimension_bits(low_dimension_bits) + self.m = product_quantization self.nbits = low_dimension_bits or 8 self.nlist = cluster_units @@ -173,6 +171,9 @@ def get_index_parameters(self): def get_inference_parameters(self): return {"nprobe": self.nprobe} + def _assert_low_dimension_bits(self, low_dimension_bits: int | None): + if low_dimension_bits is not None and not (low_dimension_bits >= 1 and low_dimension_bits <= 16): + raise ValueError("low_dimension_bits must be between 1 and 16") class HNSW(IndexBase): """ @@ -195,14 +196,10 @@ def __init__( """ super().__init__(metric_type=metric_type) - if not (max_degree >= 4 and max_degree <= 64): - raise ValueError("max_degree must be between 4 and 64") - if not (search_scope_index >= 8 and search_scope_index <= 512): - raise ValueError("search_scope must be between 1 and 512") - if not (search_scope_inference >= 1 and search_scope_inference <= 32768): - # NOTE: Technically this needs to be between [top_k, 32768], but we don't know what top_k is - # at index time - raise ValueError("search_scope must be between 1 and 32768") + self._assert_max_degree(max_degree) + self._assert_search_scope_index(search_scope_index) + self._assert_search_scope_inference(search_scope_inference) + self.m = max_degree self.efConstruction = search_scope_index self.ef = search_scope_inference @@ -213,6 +210,20 @@ def get_index_parameters(self): def get_inference_parameters(self): return {"ef": self.ef} + def _assert_max_degree(self, max_degree: int): + if not (max_degree >= 4 and max_degree <= 64): + raise ValueError("max_degree must be between 4 and 64") + + def _assert_search_scope_index(self, search_scope_index: int): + if not (search_scope_index >= 8 and search_scope_index <= 512): + raise ValueError("search_scope must be between 1 and 512") + + def _assert_search_scope_inference(self, search_scope_inference: int): + if not (search_scope_inference >= 1 and search_scope_inference <= 32768): + # NOTE: Technically this needs to be between [top_k, 32768], but we don't know what top_k is + # at index time + raise ValueError("search_scope must be between 1 and 32768") + class BIN_FLAT(IndexBase): """ @@ -248,10 +259,8 @@ def __init__( """ super().__init__(metric_type=metric_type) - if not (cluster_units >= 1 and cluster_units <= 65536): - raise ValueError("cluster_units must be between 1 and 65536") - if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units): - raise ValueError("inference_comparison must be between 1 and cluster_units") + self._assert_cluster_units_and_inference_comparison(cluster_units, inference_comparison) + self.nlist = cluster_units self.nprobe = inference_comparison or cluster_units