-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
kira
committed
Jan 24, 2024
1 parent
e1bbf5c
commit fb97191
Showing
4 changed files
with
189 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
FROM neurips23 | ||
|
||
RUN apt update | ||
RUN apt install -y software-properties-common | ||
RUN add-apt-repository -y ppa:git-core/ppa | ||
RUN apt update | ||
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev libcpprest-dev python3.10 | ||
|
||
RUN git clone https://github.com/hhy3/zilliz-bigann.git --branch streaming | ||
RUN pip install ./zilliz-bigann/*.whl | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
random-xs-clustered: | ||
zilliz: | ||
docker-tag: neurips23-streaming-zilliz | ||
module: neurips23.streaming.zilliz.zilliz | ||
constructor: Zilliz | ||
base-args: ["@metric"] | ||
run-groups: | ||
base: | ||
args: | | ||
[{"R":32, "L":100, "insert_threads":16, "consolidate_threads":16}] | ||
query-args: | | ||
[{"Ls":200, "T":8}] | ||
random-xs: | ||
zilliz: | ||
docker-tag: neurips23-streaming-zilliz | ||
module: neurips23.streaming.zilliz.zilliz | ||
constructor: Zilliz | ||
base-args: ["@metric"] | ||
run-groups: | ||
base: | ||
args: | | ||
[{"R":32, "L":50, "insert_threads":16, "consolidate_threads":16}] | ||
query-args: | | ||
[{"Ls":50, "T":8}] | ||
msturing-10M-clustered: | ||
zilliz: | ||
docker-tag: neurips23-streaming-zilliz | ||
module: neurips23.streaming.zilliz.zilliz | ||
constructor: Zilliz | ||
base-args: ["@metric"] | ||
run-groups: | ||
base: | ||
args: | | ||
[{"R":16, "L":10, "insert_threads":8, "consolidate_threads":8}] | ||
query-args: | | ||
[ | ||
{"Ls":100, "T":8} | ||
] | ||
msturing-30M-clustered: | ||
zilliz: | ||
docker-tag: neurips23-streaming-zilliz | ||
module: neurips23.streaming.zilliz.zilliz | ||
constructor: Zilliz | ||
base-args: ["@metric"] | ||
run-groups: | ||
base: | ||
args: | | ||
[ | ||
{"R":32, "L":110, "insert_threads":8, "consolidate_threads":8} | ||
] | ||
query-args: | | ||
[ | ||
{"Ls":400, "T":8}, | ||
{"Ls":450, "T":8}, | ||
{"Ls":500, "T":8}, | ||
{"Ls":550, "T":8} | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from __future__ import absolute_import | ||
import psutil | ||
import os | ||
import time | ||
import numpy as np | ||
|
||
import diskannpy | ||
import fast_refine | ||
|
||
from neurips23.streaming.base import BaseStreamingANN | ||
|
||
class Zilliz(BaseStreamingANN): | ||
def __init__(self, metric, index_params): | ||
self.name = "pyanns" | ||
if (index_params.get("R")==None): | ||
print("Error: missing parameter R") | ||
return | ||
if (index_params.get("L")==None): | ||
print("Error: missing parameter L") | ||
return | ||
self._index_params = index_params | ||
self._metric = metric | ||
|
||
self.R = index_params.get("R") | ||
self.L = index_params.get("L") | ||
self.insert_threads = index_params.get("insert_threads") | ||
self.consolidate_threads = index_params.get("consolidate_threads") | ||
self.mx = None | ||
self.mi = None | ||
|
||
def index_name(self): | ||
return f"R{self.R}_L{self.L}" | ||
|
||
def create_index_dir(self, dataset): | ||
index_dir = os.path.join(os.getcwd(), "data", "indices", "streaming") | ||
os.makedirs(index_dir, mode=0o777, exist_ok=True) | ||
index_dir = os.path.join(index_dir, 'pyanns') | ||
os.makedirs(index_dir, mode=0o777, exist_ok=True) | ||
index_dir = os.path.join(index_dir, dataset.short_name()) | ||
os.makedirs(index_dir, mode=0o777, exist_ok=True) | ||
index_dir = os.path.join(index_dir, self.index_name()) | ||
os.makedirs(index_dir, mode=0o777, exist_ok=True) | ||
return index_dir | ||
|
||
def translate_dist_fn(self, metric): | ||
if metric == 'euclidean': | ||
return 'l2' | ||
elif metric == 'ip': | ||
return 'mips' | ||
else: | ||
raise Exception('Invalid metric') | ||
|
||
def translate_dtype(self, dtype:str): | ||
return np.uint8 | ||
|
||
def setup(self, dtype, max_pts, ndim): | ||
self.index = diskannpy.DynamicMemoryIndex( | ||
distance_metric = self.translate_dist_fn(self._metric), | ||
vector_dtype = self.translate_dtype(dtype), | ||
max_vectors = max_pts, | ||
dimensions = ndim, | ||
graph_degree = self.R, | ||
complexity=self.L, | ||
num_threads = self.insert_threads, #to allocate scratch space for up to 64 search threads | ||
initial_search_complexity = 100 | ||
) | ||
self.refiner = fast_refine.Refiner(ndim, max_pts) | ||
self.max_pts = max_pts | ||
print('Index class constructed and ready for update/search') | ||
self.active_indices = set() | ||
self.num_unprocessed_deletes = 0 | ||
|
||
def quant(self, X, mi, mx): | ||
return np.round(np.clip((X - mi) / (mx - mi) * 127.0, 0.0, 127.0)).astype('uint8') | ||
|
||
def insert(self, X, ids): | ||
if self.mi is None: | ||
self.mi = X.min() | ||
self.mx = X.max() | ||
|
||
self.refiner.batch_insert(X, ids) | ||
X = self.quant(X, self.mi, self.mx) | ||
self.active_indices.update(ids+1) | ||
print('#active pts', len(self.active_indices), '#unprocessed deletes', self.num_unprocessed_deletes) | ||
if len(self.active_indices) + self.num_unprocessed_deletes > self.max_pts: | ||
self.index.consolidate_delete() | ||
self.num_unprocessed_deletes = 0 | ||
|
||
self.index.batch_insert(X, ids+1) | ||
|
||
def delete(self, ids): | ||
self.refiner.batch_delete(ids) | ||
for id in ids: | ||
self.index.mark_deleted(id+1) | ||
self.active_indices.difference_update(ids+1) | ||
self.num_unprocessed_deletes += len(ids) | ||
|
||
def query(self, X, k): | ||
"""Carry out a batch query for k-NN of query set X.""" | ||
nq, d = X.shape | ||
Xq = self.quant(X, self.mi, self.mx) | ||
k_mul = 5 | ||
k_reorder = k * k_mul | ||
I, _ = self.index.batch_search( | ||
Xq, k_reorder, self.Ls, self.search_threads) | ||
I = I - 1 | ||
self.res = self.refiner.batch_refine(X, I, k).reshape(nq, k) | ||
|
||
def set_query_arguments(self, query_args): | ||
self._query_args = query_args | ||
self.Ls = 0 if query_args.get("Ls") == None else query_args.get("Ls") | ||
self.search_threads = self._query_args.get("T") | ||
|
||
def __str__(self): | ||
return f'zilliz({self.index_name(), self._query_args})' |