From e47b1ca1af2cc262b94d4acacfcc44cb6f51531e Mon Sep 17 00:00:00 2001 From: Kosuke Morimoto Date: Thu, 7 Sep 2023 22:11:55 +0900 Subject: [PATCH] add core ngt benchmark Signed-off-by: Kosuke Morimoto --- internal/core/algorithm/ngt/ngt_bench_test.go | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 internal/core/algorithm/ngt/ngt_bench_test.go diff --git a/internal/core/algorithm/ngt/ngt_bench_test.go b/internal/core/algorithm/ngt/ngt_bench_test.go new file mode 100644 index 00000000000..17617719b59 --- /dev/null +++ b/internal/core/algorithm/ngt/ngt_bench_test.go @@ -0,0 +1,148 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+// Package ngt provides implementation of Go API for https://github.com/yahoojapan/NGT
+package ngt
+
+import (
+	"errors"
+	"fmt"
+	"sync"
+	"testing"
+
+	"gonum.org/v1/hdf5"
+)
+
+// datasetPath is the HDF5 file the benchmark reads its vectors from.
+const datasetPath = "sift-128-euclidean.hdf5"
+
+var (
+	vectors [][]float32 // rows of the "train" dataset, filled by setup
+	n       NGT         // instance under benchmark, created by setup
+	ids     []uint      // IDs returned by the latest RunNGT1 call
+
+	setupOnce sync.Once // guards the single setup attempt
+	setupErr  error     // outcome of that attempt
+)
+
+// setup loads the dataset and creates the NGT instance on first use and
+// returns the outcome of that single attempt on every call. Doing this lazily
+// instead of in init keeps a missing dataset from panicking the whole test
+// binary of the package and avoids loading the file when no benchmark runs.
+func setup() error {
+	setupOnce.Do(func() {
+		if vectors, _, setupErr = load(datasetPath); setupErr != nil {
+			return
+		}
+		if len(vectors) == 0 {
+			// vectors[0] below would panic on an empty dataset.
+			setupErr = errors.New("train dataset is empty")
+			return
+		}
+		n, setupErr = New(
+			WithDimension(len(vectors[0])),
+			WithDefaultPoolSize(8),
+			WithObjectType(Float),
+			WithDistanceType(L2),
+		)
+	})
+	return setupErr
+}
+
+// RunNGT1 inserts every loaded vector into n, records the returned IDs in ids
+// and then calls CreateIndex(8). It returns the first error encountered, in
+// which case ids holds only the IDs inserted so far.
+// It does not touch the benchmark timer; the caller owns it.
+func RunNGT1(b *testing.B) error {
+	b.Helper()
+
+	if err := setup(); err != nil {
+		return err
+	}
+
+	ids = ids[:0] // reuse the backing array across benchmark iterations
+	for _, vector := range vectors {
+		id, err := n.Insert(vector)
+		if err != nil {
+			return err
+		}
+		ids = append(ids, id)
+	}
+	return n.CreateIndex(8)
+}
+
+// RunNGT2 removes every ID recorded by the latest RunNGT1 call from n and
+// returns the first error encountered. It is a no-op when ids is empty.
+// It does not touch the benchmark timer; the caller owns it.
+func RunNGT2(b *testing.B) error {
+	b.Helper()
+
+	for _, id := range ids {
+		if err := n.Remove(id); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// BenchmarkNGT measures one insert + CreateIndex + remove cycle per iteration.
+// It is skipped when the dataset or the NGT instance cannot be prepared.
+func BenchmarkNGT(b *testing.B) {
+	if err := setup(); err != nil {
+		b.Skipf("benchmark setup failed: %v", err)
+	}
+
+	// Reset exactly once, after setup. Resetting inside the loop would throw
+	// away the time of all earlier work while the remaining total is still
+	// divided by b.N, making the reported ns/op meaningless.
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		if err := RunNGT1(b); err != nil {
+			b.Fatal(err)
+		}
+		if err := RunNGT2(b); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// load reads the "train" and "test" datasets from the HDF5 file at path and
+// returns each of them as a slice of rows. Both datasets are read as float32
+// and must be two-dimensional. On error both returned slices are nil.
+func load(path string) (train, test [][]float32, err error) {
+	f, err := hdf5.OpenFile(path, hdf5.F_ACC_RDONLY)
+	if err != nil {
+		return nil, nil, fmt.Errorf("opening %q: %w", path, err)
+	}
+	defer f.Close()
+
+	// readFn reads the dataset with the given name and splits it into rows.
+	readFn := func(name string) ([][]float32, error) {
+		d, err := f.OpenDataset(name)
+		if err != nil {
+			return nil, err
+		}
+		defer d.Close()
+
+		sp := d.Space()
+		defer sp.Close()
+
+		dims, _, err := sp.SimpleExtentDims()
+		if err != nil {
+			return nil, err
+		}
+		if len(dims) != 2 {
+			// dims[1] below would panic for any other rank.
+			return nil, fmt.Errorf("dataset %q has rank %d, want 2", name, len(dims))
+		}
+		row, dim := int(dims[0]), int(dims[1])
+
+		// The whole dataset is read into one flat buffer.
+		vec := make([]float32, sp.SimpleExtentNPoints())
+		if err := d.Read(&vec); err != nil {
+			return nil, err
+		}
+
+		// Each row is a capacity-limited window into the flat buffer, so no
+		// per-row copy is made and an append cannot spill into the next row.
+		vecs := make([][]float32, row)
+		for i := range vecs {
+			vecs[i] = vec[i*dim : (i+1)*dim : (i+1)*dim]
+		}
+		return vecs, nil
+	}
+
+	if train, err = readFn("train"); err != nil {
+		return nil, nil, fmt.Errorf("reading dataset %q of %q: %w", "train", path, err)
+	}
+	if test, err = readFn("test"); err != nil {
+		return nil, nil, fmt.Errorf("reading dataset %q of %q: %w", "test", path, err)
+	}
+	return train, test, nil
+}