Skip to content

Commit

Permalink
Update arg check to raise exception is index build args don't match r…
Browse files Browse the repository at this point in the history
…equest index
  • Loading branch information
leothomas committed Nov 6, 2023
1 parent d6618b7 commit 2bb7c95
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 35 deletions.
71 changes: 55 additions & 16 deletions src/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,44 +626,83 @@ def test_index_build_args(client: vecs.Client) -> None:
bar.upsert([("a", [1, 2, 3, 4], {})])

# Test that default value for nlists is used in absence of index build args
bar.create_index(method="ivfflat") # type: ignore
[nlists] = [i for i in bar.index.split("_") if i.startswith("nl")]
assert int(nlists.strip("nl")) == 30

# Test that default value for nlists is used when when incorrect index build
# args are supplied
bar.create_index(method=IndexMethod.ivfflat, index_arguments=IndexArgsHNSW(), replace=True) # type: ignore
bar.create_index(method="ivfflat")
[nlists] = [i for i in bar.index.split("_") if i.startswith("nl")]
assert int(nlists.strip("nl")) == 30

# Test nlists is honored when supplied
bar.create_index(method=IndexMethod.ivfflat, index_arguments=IndexArgsIVFFlat(n_lists=123), replace=True) # type: ignore
bar.create_index(
method=IndexMethod.ivfflat,
index_arguments=IndexArgsIVFFlat(n_lists=123),
replace=True,
)
[nlists] = [i for i in bar.index.split("_") if i.startswith("nl")]
assert int(nlists.strip("nl")) == 123

# Test that default values for m and ef_construction are used in absence of
# index build args
bar.create_index(method="hnsw", replace=True) # type: ignore
bar.create_index(method="hnsw", replace=True)
[m] = [i for i in bar.index.split("_") if i.startswith("m")]
[ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")]
assert int(m.strip("m")) == 16
assert int(ef_construction.strip("efc")) == 64

# Test that the default values for m and ef_construction are used when
# incorrect index build args are supplied
bar.create_index(method="hnsw", index_arguments=IndexArgsIVFFlat(n_lists=123), replace=True) # type: ignore
# Test m and ef_construction is honored when supplied
bar.create_index(
method="hnsw",
index_arguments=IndexArgsHNSW(m=8, ef_construction=123),
replace=True,
)
[m] = [i for i in bar.index.split("_") if i.startswith("m")]
[ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")]
assert int(m.strip("m")) == 8
assert int(ef_construction.strip("efc")) == 123

# Test m is honored and ef_construction is default when _only_ m is supplied
bar.create_index(method="hnsw", index_arguments=IndexArgsHNSW(m=8), replace=True)
[m] = [i for i in bar.index.split("_") if i.startswith("m")]
[ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")]
assert int(m.strip("m")) == 16
assert int(m.strip("m")) == 8
assert int(ef_construction.strip("efc")) == 64

# Test m and ef_construnction is honored when supplied
bar.create_index(method="hnsw", index_arguments=IndexArgsHNSW(m=12, ef_construction=123), replace=True) # type: ignore
# Test m is default and ef_construction is honoured when _only_
# ef_construction is supplied
bar.create_index(
method="hnsw", index_arguments=IndexArgsHNSW(ef_construction=123), replace=True
)
[m] = [i for i in bar.index.split("_") if i.startswith("m")]
[ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")]
assert int(m.strip("m")) == 12
assert int(m.strip("m")) == 16
assert int(ef_construction.strip("efc")) == 123

# Test that exception is raised when index build args don't match
# the requested index type
with pytest.raises(vecs.exc.ArgError):
bar.create_index(
method=IndexMethod.ivfflat, index_arguments=IndexArgsHNSW(), replace=True
)
with pytest.raises(vecs.exc.ArgError):
bar.create_index(
method=IndexMethod.hnsw,
index_arguments=IndexArgsIVFFlat(n_lists=123),
replace=True,
)

# Test that excpetion is raised index build args are supplied by the
# IndexMethod.auto index is specified
with pytest.raises(vecs.exc.ArgError):
bar.create_index(
method=IndexMethod.auto,
index_arguments=IndexArgsIVFFlat(n_lists=123),
replace=True,
)
with pytest.raises(vecs.exc.ArgError):
bar.create_index(
method=IndexMethod.auto,
index_arguments=IndexArgsHNSW(),
replace=True,
)


def test_cosine_index_query(client: vecs.Client) -> None:
dim = 4
Expand Down
41 changes: 22 additions & 19 deletions src/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,30 @@ def create_index(
ArgError: If an invalid index method is used, or if *replace* is False and an index already exists.
"""

if not method in (IndexMethod.ivfflat, IndexMethod.hnsw, IndexMethod.auto):
if method not in (IndexMethod.ivfflat, IndexMethod.hnsw, IndexMethod.auto):
raise ArgError("invalid index method")

if index_arguments:
# Disallow case where user submits index arguments but uses the
# IndexMethod.auto index (index build arguments should only be
# used with a specific index)
if method == IndexMethod.auto:
raise ArgError(
"Index build parameters are not allowed when using the IndexMethod.auto index."
)
# Disallow case where user specifies one index type but submits
# index build arguments for the other index type
if (
isinstance(index_arguments, IndexArgsHNSW)
and method != IndexMethod.hnsw
) or (
isinstance(index_arguments, IndexArgsIVFFlat)
and method != IndexMethod.ivfflat
):
raise ArgError(
f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified."
)

if method == IndexMethod.auto:
if self.client._supports_hnsw():
method = IndexMethod.hnsw
Expand All @@ -704,24 +725,6 @@ def create_index(
"HNSW Unavailable. Upgrade your pgvector installation to > 0.5.0 to enable HNSW support"
)

# Catch case where use submits index build args for one index type but
# method defines a different index type.
if index_arguments and (
(isinstance(index_arguments, IndexArgsHNSW) and method != IndexMethod.hnsw)
or (
isinstance(index_arguments, IndexArgsIVFFlat)
and method != IndexMethod.ivfflat
)
):
warnings.warn(
UserWarning(
f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified. Default parameters for {method} index will be used instead."
)
)
# set index_arguments to None in order to instantiate
# with the default values later
index_arguments = None

ops = INDEX_MEASURE_TO_OPS.get(measure)
if ops is None:
raise ArgError("Unknown index measure")
Expand Down

0 comments on commit 2bb7c95

Please sign in to comment.