Skip to content

Commit

Permalink
Refactor index init validation
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Apr 19, 2023
1 parent 7048eb5 commit 05ae1a3
Showing 1 changed file with 35 additions and 26 deletions.
61 changes: 35 additions & 26 deletions vectordb_orm/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def _assert_metric_type(self, metric_type: FloatSimilarityMetric | BinarySimilar
if not isinstance (self, tuple(BINARY_INDEXES)):
raise ValueError(f"Index type {self} is not supported for metric type {metric_type}")

def _assert_cluster_units_and_inference_comparison(self, cluster_units: int, inference_comparison: int | None) -> tuple[int, int]:
if not (cluster_units >= 1 and cluster_units <= 65536):
raise ValueError("cluster_units must be between 1 and 65536")
if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units):
raise ValueError("inference_comparison must be between 1 and cluster_units")

class FLAT(IndexBase):
"""
Expand Down Expand Up @@ -84,10 +89,8 @@ def __init__(
"""
super().__init__(metric_type=metric_type)

if not (cluster_units >= 1 and cluster_units <= 65536):
raise ValueError("cluster_units must be between 1 and 65536")
if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units):
raise ValueError("inference_comparison must be between 1 and cluster_units")
self._assert_cluster_units_and_inference_comparison(cluster_units, inference_comparison)

self.nlist = cluster_units
self.nprobe = inference_comparison or cluster_units

Expand Down Expand Up @@ -119,10 +122,8 @@ def __init__(
"""
super().__init__(metric_type=metric_type)

if not (cluster_units >= 1 and cluster_units <= 65536):
raise ValueError("cluster_units must be between 1 and 65536")
if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units):
raise ValueError("inference_comparison must be between 1 and cluster_units")
self._assert_cluster_units_and_inference_comparison(cluster_units, inference_comparison)

self.nlist = cluster_units
self.nprobe = inference_comparison or cluster_units

Expand Down Expand Up @@ -156,12 +157,9 @@ def __init__(
"""
super().__init__(metric_type=metric_type)

if not (cluster_units >= 1 and cluster_units <= 65536):
raise ValueError("cluster_units must be between 1 and 65536")
if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units):
raise ValueError("inference_comparison must be between 1 and cluster_units")
if low_dimension_bits is not None and not (low_dimension_bits >= 1 and low_dimension_bits <= 16):
raise ValueError("low_dimension_bits must be between 1 and 16")
self._assert_cluster_units_and_inference_comparison(cluster_units, inference_comparison)
self._assert_low_dimension_bits(low_dimension_bits)

self.m = product_quantization
self.nbits = low_dimension_bits or 8
self.nlist = cluster_units
Expand All @@ -173,6 +171,9 @@ def get_index_parameters(self):
def get_inference_parameters(self):
return {"nprobe": self.nprobe}

def _assert_low_dimension_bits(self, low_dimension_bits: int | None):
if low_dimension_bits is not None and not (low_dimension_bits >= 1 and low_dimension_bits <= 16):
raise ValueError("low_dimension_bits must be between 1 and 16")

class HNSW(IndexBase):
"""
Expand All @@ -195,14 +196,10 @@ def __init__(
"""
super().__init__(metric_type=metric_type)

if not (max_degree >= 4 and max_degree <= 64):
raise ValueError("max_degree must be between 4 and 64")
if not (search_scope_index >= 8 and search_scope_index <= 512):
raise ValueError("search_scope must be between 1 and 512")
if not (search_scope_inference >= 1 and search_scope_inference <= 32768):
# NOTE: Technically this needs to be between [top_k, 32768], but we don't know what top_k is
# at index time
raise ValueError("search_scope must be between 1 and 32768")
self._assert_max_degree(max_degree)
self._assert_search_scope_index(search_scope_index)
self._assert_search_scope_inference(search_scope_inference)

self.m = max_degree
self.efConstruction = search_scope_index
self.ef = search_scope_inference
Expand All @@ -213,6 +210,20 @@ def get_index_parameters(self):
def get_inference_parameters(self):
return {"ef": self.ef}

def _assert_max_degree(self, max_degree: int):
if not (max_degree >= 4 and max_degree <= 64):
raise ValueError("max_degree must be between 4 and 64")

def _assert_search_scope_index(self, search_scope_index: int):
if not (search_scope_index >= 8 and search_scope_index <= 512):
raise ValueError("search_scope must be between 1 and 512")

def _assert_search_scope_inference(self, search_scope_inference: int):
if not (search_scope_inference >= 1 and search_scope_inference <= 32768):
# NOTE: Technically this needs to be between [top_k, 32768], but we don't know what top_k is
# at index time
raise ValueError("search_scope must be between 1 and 32768")


class BIN_FLAT(IndexBase):
"""
Expand Down Expand Up @@ -248,10 +259,8 @@ def __init__(
"""
super().__init__(metric_type=metric_type)

if not (cluster_units >= 1 and cluster_units <= 65536):
raise ValueError("cluster_units must be between 1 and 65536")
if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units):
raise ValueError("inference_comparison must be between 1 and cluster_units")
self._assert_cluster_units_and_inference_comparison(cluster_units, inference_comparison)

self.nlist = cluster_units
self.nprobe = inference_comparison or cluster_units

Expand Down

0 comments on commit 05ae1a3

Please sign in to comment.