From d6618b7d25ae506681c321a096a8029b0c4e6fbb Mon Sep 17 00:00:00 2001 From: Leo Thomas Date: Mon, 6 Nov 2023 09:48:32 -0500 Subject: [PATCH 1/3] Add index build arguments --- src/tests/test_collection.py | 47 ++++++++++++++++++- src/vecs/__init__.py | 18 +++++++- src/vecs/collection.py | 90 ++++++++++++++++++++++++++++++++---- 3 files changed, 143 insertions(+), 12 deletions(-) diff --git a/src/tests/test_collection.py b/src/tests/test_collection.py index 8112151..8ce2cb0 100644 --- a/src/tests/test_collection.py +++ b/src/tests/test_collection.py @@ -4,7 +4,7 @@ import pytest import vecs -from vecs import IndexMethod +from vecs import IndexArgsHNSW, IndexArgsIVFFlat, IndexMethod from vecs.exc import ArgError @@ -620,6 +620,51 @@ def test_hnsw(client: vecs.Client) -> None: assert len(results) == 1 +def test_index_build_args(client: vecs.Client) -> None: + dim = 4 + bar = client.get_or_create_collection(name="bar", dimension=dim) + bar.upsert([("a", [1, 2, 3, 4], {})]) + + # Test that default value for nlists is used in absence of index build args + bar.create_index(method="ivfflat") # type: ignore + [nlists] = [i for i in bar.index.split("_") if i.startswith("nl")] + assert int(nlists.strip("nl")) == 30 + + # Test that default value for nlists is used when when incorrect index build + # args are supplied + bar.create_index(method=IndexMethod.ivfflat, index_arguments=IndexArgsHNSW(), replace=True) # type: ignore + [nlists] = [i for i in bar.index.split("_") if i.startswith("nl")] + assert int(nlists.strip("nl")) == 30 + + # Test nlists is honored when supplied + bar.create_index(method=IndexMethod.ivfflat, index_arguments=IndexArgsIVFFlat(n_lists=123), replace=True) # type: ignore + [nlists] = [i for i in bar.index.split("_") if i.startswith("nl")] + assert int(nlists.strip("nl")) == 123 + + # Test that default values for m and ef_construction are used in absence of + # index build args + bar.create_index(method="hnsw", replace=True) # type: ignore + [m] = [i for i 
in bar.index.split("_") if i.startswith("m")] + [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] + assert int(m.strip("m")) == 16 + assert int(ef_construction.strip("efc")) == 64 + + # Test that the default values for m and ef_construction are used when + # incorrect index build args are supplied + bar.create_index(method="hnsw", index_arguments=IndexArgsIVFFlat(n_lists=123), replace=True) # type: ignore + [m] = [i for i in bar.index.split("_") if i.startswith("m")] + [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] + assert int(m.strip("m")) == 16 + assert int(ef_construction.strip("efc")) == 64 + + # Test m and ef_construnction is honored when supplied + bar.create_index(method="hnsw", index_arguments=IndexArgsHNSW(m=12, ef_construction=123), replace=True) # type: ignore + [m] = [i for i in bar.index.split("_") if i.startswith("m")] + [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] + assert int(m.strip("m")) == 12 + assert int(ef_construction.strip("efc")) == 123 + + def test_cosine_index_query(client: vecs.Client) -> None: dim = 4 bar = client.get_or_create_collection(name="bar", dimension=dim) diff --git a/src/vecs/__init__.py b/src/vecs/__init__.py index d05370d..55f80e5 100644 --- a/src/vecs/__init__.py +++ b/src/vecs/__init__.py @@ -1,12 +1,26 @@ from vecs import exc from vecs.client import Client -from vecs.collection import Collection, IndexMeasure, IndexMethod +from vecs.collection import ( + Collection, + IndexArgsHNSW, + IndexArgsIVFFlat, + IndexMeasure, + IndexMethod, +) __project__ = "vecs" __version__ = "0.4.1" -__all__ = ["IndexMethod", "IndexMeasure", "Collection", "Client", "exc"] +__all__ = [ + "IndexArgsIVFFlat", + "IndexArgsHNSW", + "IndexMethod", + "IndexMeasure", + "Collection", + "Client", + "exc", +] def create_client(connection_string: str) -> Client: diff --git a/src/vecs/collection.py b/src/vecs/collection.py index 64c7f36..2b0fe06 100644 --- 
a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -9,6 +9,7 @@ import math import uuid import warnings +from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union @@ -82,6 +83,40 @@ class IndexMeasure(str, Enum): max_inner_product = "max_inner_product" +@dataclass +class IndexArgsIVFFlat: + """ + A class for arguments that can optionally be supplied to the index creation + method when building an IVFFlat type index. + + Attributes: + n_lists (int): The number of IVF centroids that the index should use + """ + + n_lists: int + + +@dataclass +class IndexArgsHNSW: + """ + A class for arguments that can optionally be supplied to the index creation + method when building an HNSW type index. + + Ref: https://github.com/pgvector/pgvector#index-options + + Both attributes are Optional in case the user only wants to specify one and + leave the other as default + + Attributes: + m (int): Maximum number of connections per node per layer (default: 16) + ef_construction (int): Size of the dynamic candidate list for + constructing the graph (default: 64) + """ + + m: Optional[int] = 16 + ef_construction: Optional[int] = 64 + + INDEX_MEASURE_TO_OPS = { # Maps the IndexMeasure enum options to the SQL ops string required by # the pgvector `create index` statement @@ -621,6 +656,7 @@ def create_index( self, measure: IndexMeasure = IndexMeasure.cosine_distance, method: IndexMethod = IndexMethod.auto, + index_arguments: Optional[Union[IndexArgsIVFFlat, IndexArgsHNSW]] = None, replace=True, ) -> None: """ @@ -653,6 +689,7 @@ def create_index( Raises: ArgError: If an invalid index method is used, or if *replace* is False and an index already exists. """ + if not method in (IndexMethod.ivfflat, IndexMethod.hnsw, IndexMethod.auto): raise ArgError("invalid index method") @@ -667,6 +704,24 @@ def create_index( "HNSW Unavailable. 
Upgrade your pgvector installation to > 0.5.0 to enable HNSW support" ) + # Catch case where use submits index build args for one index type but + # method defines a different index type. + if index_arguments and ( + (isinstance(index_arguments, IndexArgsHNSW) and method != IndexMethod.hnsw) + or ( + isinstance(index_arguments, IndexArgsIVFFlat) + and method != IndexMethod.ivfflat + ) + ): + warnings.warn( + UserWarning( + f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified. Default parameters for {method} index will be used instead." + ) + ) + # set index_arguments to None in order to instantiate + # with the default values later + index_arguments = None + ops = INDEX_MEASURE_TO_OPS.get(measure) if ops is None: raise ArgError("Unknown index measure") @@ -683,18 +738,27 @@ def create_index( raise ArgError("replace is set to False but an index exists") if method == IndexMethod.ivfflat: - n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore + if not index_arguments: + n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore - n_lists = ( - int(max(n_records / 1000, 30)) - if n_records < 1_000_000 - else int(math.sqrt(n_records)) - ) + n_lists = ( + int(max(n_records / 1000, 30)) + if n_records < 1_000_000 + else int(math.sqrt(n_records)) + ) + else: + # The following mypy error is ignored because mypy + # complains that `index_arguments` is typed as a union + # of IndexArgsIVFFlat and IndexArgsHNSW types, + # which both don't necessarily contain the `n_lists` + # parameter, however we have validated that the + # correct type is being used above. 
+ n_lists = index_arguments.n_lists # type: ignore sess.execute( text( f""" - create index ix_{ops}_ivfflat_{n_lists}_{unique_string} + create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string} on vecs."{self.table.name}" using ivfflat (vec {ops}) with (lists={n_lists}) """ @@ -702,12 +766,20 @@ def create_index( ) if method == IndexMethod.hnsw: + if not index_arguments: + index_arguments = IndexArgsHNSW() + + # See above for explanation of why the following lines + # are ignored + m = index_arguments.m # type: ignore + ef_construction = index_arguments.ef_construction # type: ignore + sess.execute( text( f""" - create index ix_{ops}_hnsw_{unique_string} + create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string} on vecs."{self.table.name}" - using hnsw (vec {ops}); + using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction}); """ ) ) From 2bb7c95a525132568d6ffad36ee3c914db92226d Mon Sep 17 00:00:00 2001 From: Leo Thomas Date: Mon, 6 Nov 2023 12:34:01 -0500 Subject: [PATCH 2/3] Update arg check to raise exception is index build args don't match request index --- src/tests/test_collection.py | 71 ++++++++++++++++++++++++++++-------- src/vecs/collection.py | 41 +++++++++++---------- 2 files changed, 77 insertions(+), 35 deletions(-) diff --git a/src/tests/test_collection.py b/src/tests/test_collection.py index 8ce2cb0..ad30685 100644 --- a/src/tests/test_collection.py +++ b/src/tests/test_collection.py @@ -626,44 +626,83 @@ def test_index_build_args(client: vecs.Client) -> None: bar.upsert([("a", [1, 2, 3, 4], {})]) # Test that default value for nlists is used in absence of index build args - bar.create_index(method="ivfflat") # type: ignore - [nlists] = [i for i in bar.index.split("_") if i.startswith("nl")] - assert int(nlists.strip("nl")) == 30 - - # Test that default value for nlists is used when when incorrect index build - # args are supplied - bar.create_index(method=IndexMethod.ivfflat, index_arguments=IndexArgsHNSW(), replace=True) # 
type: ignore + bar.create_index(method="ivfflat") [nlists] = [i for i in bar.index.split("_") if i.startswith("nl")] assert int(nlists.strip("nl")) == 30 # Test nlists is honored when supplied - bar.create_index(method=IndexMethod.ivfflat, index_arguments=IndexArgsIVFFlat(n_lists=123), replace=True) # type: ignore + bar.create_index( + method=IndexMethod.ivfflat, + index_arguments=IndexArgsIVFFlat(n_lists=123), + replace=True, + ) [nlists] = [i for i in bar.index.split("_") if i.startswith("nl")] assert int(nlists.strip("nl")) == 123 # Test that default values for m and ef_construction are used in absence of # index build args - bar.create_index(method="hnsw", replace=True) # type: ignore + bar.create_index(method="hnsw", replace=True) [m] = [i for i in bar.index.split("_") if i.startswith("m")] [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] assert int(m.strip("m")) == 16 assert int(ef_construction.strip("efc")) == 64 - # Test that the default values for m and ef_construction are used when - # incorrect index build args are supplied - bar.create_index(method="hnsw", index_arguments=IndexArgsIVFFlat(n_lists=123), replace=True) # type: ignore + # Test m and ef_construction is honored when supplied + bar.create_index( + method="hnsw", + index_arguments=IndexArgsHNSW(m=8, ef_construction=123), + replace=True, + ) + [m] = [i for i in bar.index.split("_") if i.startswith("m")] + [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] + assert int(m.strip("m")) == 8 + assert int(ef_construction.strip("efc")) == 123 + + # Test m is honored and ef_construction is default when _only_ m is supplied + bar.create_index(method="hnsw", index_arguments=IndexArgsHNSW(m=8), replace=True) [m] = [i for i in bar.index.split("_") if i.startswith("m")] [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] - assert int(m.strip("m")) == 16 + assert int(m.strip("m")) == 8 assert int(ef_construction.strip("efc")) == 64 
- # Test m and ef_construnction is honored when supplied - bar.create_index(method="hnsw", index_arguments=IndexArgsHNSW(m=12, ef_construction=123), replace=True) # type: ignore + # Test m is default and ef_construction is honoured when _only_ + # ef_construction is supplied + bar.create_index( + method="hnsw", index_arguments=IndexArgsHNSW(ef_construction=123), replace=True + ) [m] = [i for i in bar.index.split("_") if i.startswith("m")] [ef_construction] = [i for i in bar.index.split("_") if i.startswith("efc")] - assert int(m.strip("m")) == 12 + assert int(m.strip("m")) == 16 assert int(ef_construction.strip("efc")) == 123 + # Test that exception is raised when index build args don't match + # the requested index type + with pytest.raises(vecs.exc.ArgError): + bar.create_index( + method=IndexMethod.ivfflat, index_arguments=IndexArgsHNSW(), replace=True + ) + with pytest.raises(vecs.exc.ArgError): + bar.create_index( + method=IndexMethod.hnsw, + index_arguments=IndexArgsIVFFlat(n_lists=123), + replace=True, + ) + + # Test that exception is raised when index build args are supplied but the + # IndexMethod.auto index is specified + with pytest.raises(vecs.exc.ArgError): + bar.create_index( + method=IndexMethod.auto, + index_arguments=IndexArgsIVFFlat(n_lists=123), + replace=True, + ) + with pytest.raises(vecs.exc.ArgError): + bar.create_index( + method=IndexMethod.auto, + index_arguments=IndexArgsHNSW(), + replace=True, + ) + def test_cosine_index_query(client: vecs.Client) -> None: dim = 4 diff --git a/src/vecs/collection.py b/src/vecs/collection.py index 2b0fe06..4b4d38c 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -690,9 +690,30 @@ def create_index( ArgError: If an invalid index method is used, or if *replace* is False and an index already exists. 
""" - if not method in (IndexMethod.ivfflat, IndexMethod.hnsw, IndexMethod.auto): + if method not in (IndexMethod.ivfflat, IndexMethod.hnsw, IndexMethod.auto): raise ArgError("invalid index method") + if index_arguments: + # Disallow case where user submits index arguments but uses the + # IndexMethod.auto index (index build arguments should only be + # used with a specific index) + if method == IndexMethod.auto: + raise ArgError( + "Index build parameters are not allowed when using the IndexMethod.auto index." + ) + # Disallow case where user specifies one index type but submits + # index build arguments for the other index type + if ( + isinstance(index_arguments, IndexArgsHNSW) + and method != IndexMethod.hnsw + ) or ( + isinstance(index_arguments, IndexArgsIVFFlat) + and method != IndexMethod.ivfflat + ): + raise ArgError( + f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified." + ) + if method == IndexMethod.auto: if self.client._supports_hnsw(): method = IndexMethod.hnsw @@ -704,24 +725,6 @@ def create_index( "HNSW Unavailable. Upgrade your pgvector installation to > 0.5.0 to enable HNSW support" ) - # Catch case where use submits index build args for one index type but - # method defines a different index type. - if index_arguments and ( - (isinstance(index_arguments, IndexArgsHNSW) and method != IndexMethod.hnsw) - or ( - isinstance(index_arguments, IndexArgsIVFFlat) - and method != IndexMethod.ivfflat - ) - ): - warnings.warn( - UserWarning( - f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified. Default parameters for {method} index will be used instead." 
- ) - ) - # set index_arguments to None in order to instantiate - # with the default values later - index_arguments = None - ops = INDEX_MEASURE_TO_OPS.get(measure) if ops is None: raise ArgError("Unknown index measure") From 95c31f37b54077d55e067d70ab0d1ee9bd8bc1ff Mon Sep 17 00:00:00 2001 From: Oliver Rice Date: Tue, 7 Nov 2023 14:10:22 -0600 Subject: [PATCH 3/3] add docs for index_arguments --- docs/api.md | 5 ++++- docs/concepts_indexes.md | 5 ++++- docs/support_changelog.md | 2 ++ src/vecs/collection.py | 1 + 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/api.md b/docs/api.md index 45aedb8..ffcb3d8 100644 --- a/docs/api.md +++ b/docs/api.md @@ -101,16 +101,19 @@ Available options for index `method` are: Where `auto` selects the best available index method, `hnsw` uses the [HNSW](https://github.com/pgvector/pgvector#hnsw) method and `ivfflat` uses [IVFFlat](https://github.com/pgvector/pgvector#ivfflat). +HNSW and IVFFlat indexes both allow for parameterization to control the speed/accuracy tradeoff. vecs provides sane defaults for these parameters. For a greater level of control you can optionally pass an instance of `vecs.IndexArgsIVFFlat` or `vecs.IndexArgsHNSW` to `create_index`'s `index_arguments` argument. Descriptions of the impact for each parameter are available in the [pgvector docs](https://github.com/pgvector/pgvector). + When using IVFFlat indexes, the index must be created __after__ the collection has been populated with records. Building an IVFFlat index on an empty collection will result in significantly reduced recall. You can continue upserting new documents after the index has been created, but should rebuild the index if the size of the collection more than doubles since the last index operation. HNSW indexes can be created immediately after the collection without populating records. 
-To manually specify `method` and `measure`, add them as arguments to `create_index` for example: +To manually specify `method`, `measure`, and `index_arguments`, add them as arguments to `create_index` for example: ```python docs.create_index( method=IndexMethod.hnsw, measure=IndexMeasure.cosine_distance, + index_arguments=IndexArgsHNSW(m=8), ) ``` diff --git a/docs/concepts_indexes.md b/docs/concepts_indexes.md index 3256ddf..97623d8 100644 --- a/docs/concepts_indexes.md +++ b/docs/concepts_indexes.md @@ -40,16 +40,19 @@ Available options for index `method` are: Where `auto` selects the best available index method, `hnsw` uses the [HNSW](https://github.com/pgvector/pgvector#hnsw) method and `ivfflat` uses [IVFFlat](https://github.com/pgvector/pgvector#ivfflat). +HNSW and IVFFlat indexes both allow for parameterization to control the speed/accuracy tradeoff. vecs provides sane defaults for these parameters. For a greater level of control you can optionally pass an instance of `vecs.IndexArgsIVFFlat` or `vecs.IndexArgsHNSW` to `create_index`'s `index_arguments` argument. Descriptions of the impact for each parameter are available in the [pgvector docs](https://github.com/pgvector/pgvector). + When using IVFFlat indexes, the index must be created __after__ the collection has been populated with records. Building an IVFFlat index on an empty collection will result in significantly reduced recall. You can continue upserting new documents after the index has been created, but should rebuild the index if the size of the collection more than doubles since the last index operation. HNSW indexes can be created immediately after the collection without populating records. 
-To manually specify `method` and `measure`, ass them as arguments to `create_index` for example: +To manually specify `method`, `measure`, and `index_arguments`, add them as arguments to `create_index` for example: ```python docs.create_index( method=IndexMethod.hnsw, measure=IndexMeasure.cosine_distance, + index_arguments=IndexArgsHNSW(m=8), ) ``` diff --git a/docs/support_changelog.md b/docs/support_changelog.md index ecc9b6e..25309e1 100644 --- a/docs/support_changelog.md +++ b/docs/support_changelog.md @@ -32,3 +32,5 @@ - Bugfix: removed errant print statement ## master + +- Feature: Parameterized IVFFlat and HNSW indexes diff --git a/src/vecs/collection.py b/src/vecs/collection.py index 4b4d38c..d236d7a 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -684,6 +684,7 @@ def create_index( Args: measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'. method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'. + index_arguments (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments replace (bool, optional): Whether to replace the existing index. Defaults to True. Raises: