diff --git a/docs/api.md b/docs/api.md index 45aedb8..ffcb3d8 100644 --- a/docs/api.md +++ b/docs/api.md @@ -101,16 +101,19 @@ Available options for index `method` are: Where `auto` selects the best available index method, `hnsw` uses the [HNSW](https://github.com/pgvector/pgvector#hnsw) method and `ivfflat` uses [IVFFlat](https://github.com/pgvector/pgvector#ivfflat). +HNSW and IVFFlat indexes both allow for parameterization to control the speed/accuracy tradeoff. vecs provides sane defaults for these parameters. For a greater level of control you can optionally pass an instance of `vecs.IndexArgsIVFFlat` or `vecs.IndexArgsHNSW` to `create_index`'s `index_arguments` argument. Descriptions of the impact for each parameter are available in the [pgvector docs](https://github.com/pgvector/pgvector). + When using IVFFlat indexes, the index must be created __after__ the collection has been populated with records. Building an IVFFlat index on an empty collection will result in significantly reduced recall. You can continue upserting new documents after the index has been created, but should rebuild the index if the size of the collection more than doubles since the last index operation. HNSW indexes can be created immediately after the collection without populating records. -To manually specify `method` and `measure`, add them as arguments to `create_index` for example: +To manually specify `method`, `measure`, and `index_arguments` add them as arguments to `create_index` for example: ```python docs.create_index( method=IndexMethod.hnsw, measure=IndexMeasure.cosine_distance, + measure=IndexArgsHNSW(m=8), ) ``` diff --git a/docs/concepts_indexes.md b/docs/concepts_indexes.md index 3256ddf..97623d8 100644 --- a/docs/concepts_indexes.md +++ b/docs/concepts_indexes.md @@ -40,16 +40,19 @@ Available options for index `method` are: Where `auto` selects the best available index method, `hnsw` uses the [HNSW](https://github.com/pgvector/pgvector#hnsw) method and `ivfflat` uses [IVFFlat](https://github.com/pgvector/pgvector#ivfflat). +HNSW and IVFFlat indexes both allow for parameterization to control the speed/accuracy tradeoff. vecs provides sane defaults for these parameters. For a greater level of control you can optionally pass an instance of `vecs.IndexArgsIVFFlat` or `vecs.IndexArgsHNSW` to `create_index`'s `index_arguments` argument. Descriptions of the impact for each parameter are available in the [pgvector docs](https://github.com/pgvector/pgvector). + When using IVFFlat indexes, the index must be created __after__ the collection has been populated with records. Building an IVFFlat index on an empty collection will result in significantly reduced recall. You can continue upserting new documents after the index has been created, but should rebuild the index if the size of the collection more than doubles since the last index operation. HNSW indexes can be created immediately after the collection without populating records. -To manually specify `method` and `measure`, ass them as arguments to `create_index` for example: +To manually specify `method`, `measure`, and `index_arguments` add them as arguments to `create_index` for example: ```python docs.create_index( method=IndexMethod.hnsw, measure=IndexMeasure.cosine_distance, + measure=IndexArgsHNSW(m=8), ) ``` diff --git a/docs/support_changelog.md b/docs/support_changelog.md index ecc9b6e..25309e1 100644 --- a/docs/support_changelog.md +++ b/docs/support_changelog.md @@ -32,3 +32,5 @@ - Bugfix: removed errant print statement ## master + +- Feature: Parameterized IVFFlat and HNSW indexes diff --git a/src/tests/test_collection.py b/src/tests/test_collection.py index 8112151..ad30685 100644 --- a/src/tests/test_collection.py +++ b/src/tests/test_collection.py @@ -4,7 +4,7 @@ import pytest import vecs -from vecs import IndexMethod +from vecs import IndexArgsHNSW, IndexArgsIVFFlat, IndexMethod from vecs.exc import ArgError @@ -620,6 +620,90 @@ def test_hnsw(client: vecs.Client) -> None: assert len(results) == 1 +def test_index_build_args(client: vecs.Client) -> None: + dim = 4 + bar = client.get_or_create_collection(name="bar", dimension=dim) + bar.upsert([("a", [1, 2, 3, 4], {})]) + + # Test that default value for nlists is used in absence of index build args + bar.create_index(method="ivfflat") + [nlists] = [i for i in bar.index.split("_") if i.startswith("nl")] + assert int(nlists.strip("nl")) == 30 + + # Test nlists is honored when supplied + bar.create_index( + method=IndexMethod.ivfflat, + index_arguments=IndexArgsIVFFlat(n_lists=123), + replace=True, + ) + [nlists] = [i for i in bar.index.split("_") if i.startswith("nl")] + assert int(nlists.strip("nl")) == 123 + + # Test that default values for m and ef_construction are used in absence of + # index build args + bar.create_index(method="hnsw", replace=True) + [m] = [i for i in bar.index.split("_") if i.startswith("m")] + [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] + assert int(m.strip("m")) == 16 + assert int(ef_construction.strip("efc")) == 64 + + # Test m and ef_construction is honored when supplied + bar.create_index( + method="hnsw", + index_arguments=IndexArgsHNSW(m=8, ef_construction=123), + replace=True, + ) + [m] = [i for i in bar.index.split("_") if i.startswith("m")] + [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] + assert int(m.strip("m")) == 8 + assert int(ef_construction.strip("efc")) == 123 + + # Test m is honored and ef_construction is default when _only_ m is supplied + bar.create_index(method="hnsw", index_arguments=IndexArgsHNSW(m=8), replace=True) + [m] = [i for i in bar.index.split("_") if i.startswith("m")] + [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] + assert int(m.strip("m")) == 8 + assert int(ef_construction.strip("efc")) == 64 + + # Test m is default and ef_construction is honoured when _only_ + # ef_construction is supplied + bar.create_index( + method="hnsw", index_arguments=IndexArgsHNSW(ef_construction=123), replace=True + ) + [m] = [i for i in bar.index.split("_") if i.startswith("m")] + [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] + assert int(m.strip("m")) == 16 + assert int(ef_construction.strip("efc")) == 123 + + # Test that exception is raised when index build args don't match + # the requested index type + with pytest.raises(vecs.exc.ArgError): + bar.create_index( + method=IndexMethod.ivfflat, index_arguments=IndexArgsHNSW(), replace=True + ) + with pytest.raises(vecs.exc.ArgError): + bar.create_index( + method=IndexMethod.hnsw, + index_arguments=IndexArgsIVFFlat(n_lists=123), + replace=True, + ) + + # Test that excpetion is raised index build args are supplied by the + # IndexMethod.auto index is specified + with pytest.raises(vecs.exc.ArgError): + bar.create_index( + method=IndexMethod.auto, + index_arguments=IndexArgsIVFFlat(n_lists=123), + replace=True, + ) + with pytest.raises(vecs.exc.ArgError): + bar.create_index( + method=IndexMethod.auto, + index_arguments=IndexArgsHNSW(), + replace=True, + ) + + def test_cosine_index_query(client: vecs.Client) -> None: dim = 4 bar = client.get_or_create_collection(name="bar", dimension=dim) diff --git a/src/vecs/__init__.py b/src/vecs/__init__.py index d05370d..55f80e5 100644 --- a/src/vecs/__init__.py +++ b/src/vecs/__init__.py @@ -1,12 +1,26 @@ from vecs import exc from vecs.client import Client -from vecs.collection import Collection, IndexMeasure, IndexMethod +from vecs.collection import ( + Collection, + IndexArgsHNSW, + IndexArgsIVFFlat, + IndexMeasure, + IndexMethod, +) __project__ = "vecs" __version__ = "0.4.1" -__all__ = ["IndexMethod", "IndexMeasure", "Collection", "Client", "exc"] +__all__ = [ + "IndexArgsIVFFlat", + "IndexArgsHSNW", + "IndexMethod", + "IndexMeasure", + "Collection", + "Client", + "exc", +] def create_client(connection_string: str) -> Client: diff --git a/src/vecs/collection.py b/src/vecs/collection.py index 64c7f36..d236d7a 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -9,6 +9,7 @@ import math import uuid import warnings +from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union @@ -82,6 +83,40 @@ class IndexMeasure(str, Enum): max_inner_product = "max_inner_product" +@dataclass +class IndexArgsIVFFlat: + """ + A class for arguments that can optionally be supplied to the index creation + method when building an IVFFlat type index. + + Attributes: + nlist (int): The number of IVF centroids that the index should use + """ + + n_lists: int + + +@dataclass +class IndexArgsHNSW: + """ + A class for arguments that can optionally be supplied to the index creation + method when building an HNSW type index. + + Ref: https://github.com/pgvector/pgvector#index-options + + Both attributes are Optional in case the user only wants to specify one and + leave the other as default + + Attributes: + m (int): Maximum number of connections per node per layer (default: 16) + ef_construction (int): Size of the dynamic candidate list for + constructing the graph (default: 64) + """ + + m: Optional[int] = 16 + ef_construction: Optional[int] = 64 + + INDEX_MEASURE_TO_OPS = { # Maps the IndexMeasure enum options to the SQL ops string required by # the pgvector `create index` statement @@ -621,6 +656,7 @@ def create_index( self, measure: IndexMeasure = IndexMeasure.cosine_distance, method: IndexMethod = IndexMethod.auto, + index_arguments: Optional[Union[IndexArgsIVFFlat, IndexArgsHNSW]] = None, replace=True, ) -> None: """ @@ -648,14 +684,37 @@ def create_index( Args: measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'. method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'. + index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments replace (bool, optional): Whether to replace the existing index. Defaults to True. Raises: ArgError: If an invalid index method is used, or if *replace* is False and an index already exists. """ - if not method in (IndexMethod.ivfflat, IndexMethod.hnsw, IndexMethod.auto): + + if method not in (IndexMethod.ivfflat, IndexMethod.hnsw, IndexMethod.auto): raise ArgError("invalid index method") + if index_arguments: + # Disallow case where user submits index arguments but uses the + # IndexMethod.auto index (index build arguments should only be + # used with a specific index) + if method == IndexMethod.auto: + raise ArgError( + "Index build parameters are not allowed when using the IndexMethod.auto index." + ) + # Disallow case where user specifies one index type but submits + # index build arguments for the other index type + if ( + isinstance(index_arguments, IndexArgsHNSW) + and method != IndexMethod.hnsw + ) or ( + isinstance(index_arguments, IndexArgsIVFFlat) + and method != IndexMethod.ivfflat + ): + raise ArgError( + f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified." + ) + if method == IndexMethod.auto: if self.client._supports_hnsw(): method = IndexMethod.hnsw @@ -683,18 +742,27 @@ def create_index( raise ArgError("replace is set to False but an index exists") if method == IndexMethod.ivfflat: - n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore + if not index_arguments: + n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore - n_lists = ( - int(max(n_records / 1000, 30)) - if n_records < 1_000_000 - else int(math.sqrt(n_records)) - ) + n_lists = ( + int(max(n_records / 1000, 30)) + if n_records < 1_000_000 + else int(math.sqrt(n_records)) + ) + else: + # The following mypy error is ignored because mypy + # complains that `index_arguments` is typed as a union + # of IndexArgsIVFFlat and IndexArgsHNSW types, + # which both don't necessarily contain the `n_lists` + # parameter, however we have validated that the + # correct type is being used above. + n_lists = index_arguments.n_lists # type: ignore sess.execute( text( f""" - create index ix_{ops}_ivfflat_{n_lists}_{unique_string} + create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string} on vecs."{self.table.name}" using ivfflat (vec {ops}) with (lists={n_lists}) """ @@ -702,12 +770,20 @@ def create_index( ) if method == IndexMethod.hnsw: + if not index_arguments: + index_arguments = IndexArgsHNSW() + + # See above for explanation of why the following lines + # are ignored + m = index_arguments.m # type: ignore + ef_construction = index_arguments.ef_construction # type: ignore + sess.execute( text( f""" - create index ix_{ops}_hnsw_{unique_string} + create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string} on vecs."{self.table.name}" - using hnsw (vec {ops}); + using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction}); """ ) )