Skip to content

Commit

Permalink
add core ngt benchmark
Browse files Browse the repository at this point in the history
Signed-off-by: Kosuke Morimoto <[email protected]>
  • Loading branch information
kmrmt committed Sep 13, 2023
1 parent 631e1ff commit e47b1ca
Showing 1 changed file with 148 additions and 0 deletions.
148 changes: 148 additions & 0 deletions internal/core/algorithm/ngt/ngt_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
//
// Copyright (C) 2019-2023 vdaas.org vald team <[email protected]>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// You may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

// Package ngt provides implementation of Go API for https://github.com/yahoojapan/NGT
package ngt

import (
"testing"

"gonum.org/v1/hdf5"
)

var (
vectors [][]float32
n NGT
ids []uint
)

func init() {
vectors, _, _ = load("sift-128-euclidean.hdf5")
n, _ = New(
WithDimension(len(vectors[0])),
WithDefaultPoolSize(8),
WithObjectType(Float),
WithDistanceType(L2),
)
}

func RunNGT1(b *testing.B) error {
b.Helper()

ids = make([]uint, len(vectors))
b.ResetTimer()
for i, vector := range vectors {
id, err := n.Insert(vector)
if err != nil {
return err
}
ids[i] = id
}

if err := n.CreateIndex(8); err != nil {
return err
}
return nil
}

func RunNGT2(b *testing.B) error {
b.Helper()

b.ResetTimer()
for _, id := range ids {
if err := n.Remove(id); err != nil {
return err
}
}
return nil
}

func BenchmarkNGT(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := RunNGT1(b); err != nil {
b.Fatal(err)
}

if err := RunNGT2(b); err != nil {
b.Fatal(err)
}
}
}

// load function loads training and test vector from hdf file. The size of ids is same to the number of training data.
// Each id, which is an element of ids, will be set a random number.
func load(path string) (train, test [][]float32, err error) {
var f *hdf5.File
f, err = hdf5.OpenFile(path, hdf5.F_ACC_RDONLY)
if err != nil {
return nil, nil, err
}
defer f.Close()

// readFn function reads vectors of the hierarchy with the given the name.
readFn := func(name string) ([][]float32, error) {
// Opens and returns a named Dataset.
// The returned dataset must be closed by the user when it is no longer needed.
d, err := f.OpenDataset(name)
if err != nil {
return nil, err
}
defer d.Close()

// Space returns an identifier for a copy of the dataspace for a dataset.
sp := d.Space()
defer sp.Close()

// SimpleExtentDims returns dataspace dimension size and maximum size.
dims, _, _ := sp.SimpleExtentDims()
row, dim := int(dims[0]), int(dims[1])

// Gets the stored vector. All are represented as one-dimensional arrays.
// The type of the slice depends on your dataset.
// For fashion-mnist-784-euclidean.hdf5, the datatype is float32.
vec := make([]float32, sp.SimpleExtentNPoints())
if err := d.Read(&vec); err != nil {
return nil, err
}

// Converts a one-dimensional array to a two-dimensional array.
// Use the `dim` variable as a separator.
vecs := make([][]float32, row)
for i := 0; i < row; i++ {
vecs[i] = make([]float32, dim)
for j := 0; j < dim; j++ {
vecs[i][j] = float32(vec[i*dim+j])
}
}

return vecs, nil
}

// Gets vector of `train` hierarchy.
train, err = readFn("train")
if err != nil {
return nil, nil, err
}

// Gets vector of `test` hierarchy.
test, err = readFn("test")
if err != nil {
return nil, nil, err
}

return
}

0 comments on commit e47b1ca

Please sign in to comment.