Skip to content

Commit

Permalink
add nprobe option
Browse files Browse the repository at this point in the history
  • Loading branch information
datelier committed Aug 15, 2024
1 parent 3e32660 commit ccfc94b
Show file tree
Hide file tree
Showing 12 changed files with 591 additions and 537 deletions.
1 change: 1 addition & 0 deletions apis/docs/v1/docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,7 @@ Represent search configuration.
| min_num | [uint32](#uint32) | | Minimum number of results to be returned. |
| aggregation_algorithm | [Search.AggregationAlgorithm](#payload-v1-Search-AggregationAlgorithm) | | Aggregation Algorithm |
| ratio | [google.protobuf.FloatValue](#google-protobuf-FloatValue) | | Search ratio for agent return result number. |
| nprobe | [uint32](#uint32) | | Number of inverted-list clusters to probe per search (Faiss IVF `nprobe`); the Faiss agent treats 0 as 1. |

<a name="payload-v1-Search-IDRequest"></a>

Expand Down
1,035 changes: 522 additions & 513 deletions apis/grpc/v1/payload/payload.pb.go

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions apis/grpc/v1/payload/payload_vtproto.pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ func (m *Search_Config) CloneVT() *Search_Config {
r.MinNum = m.MinNum
r.AggregationAlgorithm = m.AggregationAlgorithm
r.Ratio = (*wrapperspb.FloatValue)((*wrapperspb1.FloatValue)(m.Ratio).CloneVT())
r.Nprobe = m.Nprobe

Check warning on line 187 in apis/grpc/v1/payload/payload_vtproto.pb.go

View check run for this annotation

Codecov / codecov/patch

apis/grpc/v1/payload/payload_vtproto.pb.go#L187

Added line #L187 was not covered by tests
if len(m.unknownFields) > 0 {
r.unknownFields = make([]byte, len(m.unknownFields))
copy(r.unknownFields, m.unknownFields)
Expand Down Expand Up @@ -2223,6 +2224,9 @@ func (this *Search_Config) EqualVT(that *Search_Config) bool {
if !(*wrapperspb1.FloatValue)(this.Ratio).EqualVT((*wrapperspb1.FloatValue)(that.Ratio)) {
return false
}
if this.Nprobe != that.Nprobe {
return false

Check warning on line 2228 in apis/grpc/v1/payload/payload_vtproto.pb.go

View check run for this annotation

Codecov / codecov/patch

apis/grpc/v1/payload/payload_vtproto.pb.go#L2227-L2228

Added lines #L2227 - L2228 were not covered by tests
}
return string(this.unknownFields) == string(that.unknownFields)
}

Expand Down Expand Up @@ -5166,6 +5170,11 @@ func (m *Search_Config) MarshalToSizedBufferVT(dAtA []byte) (int, error) {
i -= len(m.unknownFields)
copy(dAtA[i:], m.unknownFields)
}
if m.Nprobe != 0 {
i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Nprobe))
i--
dAtA[i] = 0x58

Check warning on line 5176 in apis/grpc/v1/payload/payload_vtproto.pb.go

View check run for this annotation

Codecov / codecov/patch

apis/grpc/v1/payload/payload_vtproto.pb.go#L5173-L5176

Added lines #L5173 - L5176 were not covered by tests
}
if m.Ratio != nil {
size, err := (*wrapperspb1.FloatValue)(m.Ratio).MarshalToSizedBufferVT(dAtA[:i])
if err != nil {
Expand Down Expand Up @@ -9783,6 +9792,9 @@ func (m *Search_Config) SizeVT() (n int) {
l = (*wrapperspb1.FloatValue)(m.Ratio).SizeVT()
n += 1 + l + protohelpers.SizeOfVarint(uint64(l))
}
if m.Nprobe != 0 {
n += 1 + protohelpers.SizeOfVarint(uint64(m.Nprobe))

Check warning on line 9796 in apis/grpc/v1/payload/payload_vtproto.pb.go

View check run for this annotation

Codecov / codecov/patch

apis/grpc/v1/payload/payload_vtproto.pb.go#L9795-L9796

Added lines #L9795 - L9796 were not covered by tests
}
n += len(m.unknownFields)
return n
}
Expand Down Expand Up @@ -12382,6 +12394,25 @@ func (m *Search_Config) UnmarshalVT(dAtA []byte) error {
return err
}
iNdEx = postIndex
case 11:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Nprobe", wireType)

Check warning on line 12399 in apis/grpc/v1/payload/payload_vtproto.pb.go

View check run for this annotation

Codecov / codecov/patch

apis/grpc/v1/payload/payload_vtproto.pb.go#L12397-L12399

Added lines #L12397 - L12399 were not covered by tests
}
m.Nprobe = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow

Check warning on line 12404 in apis/grpc/v1/payload/payload_vtproto.pb.go

View check run for this annotation

Codecov / codecov/patch

apis/grpc/v1/payload/payload_vtproto.pb.go#L12401-L12404

Added lines #L12401 - L12404 were not covered by tests
}
if iNdEx >= l {
return io.ErrUnexpectedEOF

Check warning on line 12407 in apis/grpc/v1/payload/payload_vtproto.pb.go

View check run for this annotation

Codecov / codecov/patch

apis/grpc/v1/payload/payload_vtproto.pb.go#L12406-L12407

Added lines #L12406 - L12407 were not covered by tests
}
b := dAtA[iNdEx]
iNdEx++
m.Nprobe |= uint32(b&0x7F) << shift
if b < 0x80 {
break

Check warning on line 12413 in apis/grpc/v1/payload/payload_vtproto.pb.go

View check run for this annotation

Codecov / codecov/patch

apis/grpc/v1/payload/payload_vtproto.pb.go#L12409-L12413

Added lines #L12409 - L12413 were not covered by tests
}
}
default:
iNdEx = preIndex
skippy, err := protohelpers.Skip(dAtA[iNdEx:])
Expand Down
2 changes: 2 additions & 0 deletions apis/proto/v1/payload/payload.proto
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ message Search {
AggregationAlgorithm aggregation_algorithm = 9;
// Search ratio for agent return result number.
google.protobuf.FloatValue ratio = 10;
// Number of inverted-list clusters to probe per search (Faiss IVF nprobe).
uint32 nprobe = 11;
}

// AggregationAlgorithm is enum of each aggregation algorithms
Expand Down
5 changes: 5 additions & 0 deletions apis/swagger/v1/vald/filter.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,11 @@
"type": "number",
"format": "float",
"description": "Search ratio for agent return result number."
},
"nprobe": {
"type": "integer",
"format": "int64",
"description": "Search nprobe."
}
},
"description": "Represent search configuration."
Expand Down
5 changes: 5 additions & 0 deletions apis/swagger/v1/vald/search.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,11 @@
"type": "number",
"format": "float",
"description": "Search ratio for agent return result number."
},
"nprobe": {
"type": "integer",
"format": "int64",
"description": "Search nprobe."
}
},
"description": "Represent search configuration."
Expand Down
25 changes: 7 additions & 18 deletions internal/core/algorithm/faiss/Capi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,8 @@ FaissStruct* faiss_create_index(
switch (method_type) {
case IVFPQ:
return faiss_create_index_ivfpq(d, nlist, m, nbits_per_idx, metric_type);
break;
case BINARYIVF:
return faiss_create_index_binaryivf(d*8, nlist);
break;
default:
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: no method type.";
Expand All @@ -75,10 +73,8 @@ FaissStruct* faiss_create_index_ivfpq(
switch (metric_type) {
case faiss::METRIC_INNER_PRODUCT:
quantizer = new faiss::IndexFlat(d, faiss::METRIC_INNER_PRODUCT);
break;
case faiss::METRIC_L2:
quantizer = new faiss::IndexFlat(d, faiss::METRIC_L2);
break;
default:
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: no metric type.";
Expand Down Expand Up @@ -134,10 +130,8 @@ FaissStruct* faiss_read_index(const char* fname, const int method_type) {
switch (method_type) {
case IVFPQ:
return faiss_read_index_ivfpq(fname);
break;
case BINARYIVF:
return faiss_read_index_binaryindex(fname);
break;
default:
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: no method type.";
Expand Down Expand Up @@ -197,10 +191,8 @@ bool faiss_write_index(
switch (method_type) {
case IVFPQ:
return faiss_write_index_ivfpq(st, fname);
break;
case BINARYIVF:
return faiss_write_index_binaryivf(st, fname);
break;
default:
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: no method type.";
Expand Down Expand Up @@ -261,10 +253,8 @@ bool faiss_train(
switch (method_type) {
case IVFPQ:
return faiss_train_ivfpq(st, nb, xb);
break;
case BINARYIVF:
return faiss_train_binaryivf(st, nb, reinterpret_cast<const uint8_t*>(xb));
break;
default:
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: no method type.";
Expand Down Expand Up @@ -332,10 +322,8 @@ int faiss_add(
switch (method_type) {
case IVFPQ:
return faiss_add_ivfpq(st, nb, xb, xids);
break;
case BINARYIVF:
return faiss_add_binaryivf(st, nb, reinterpret_cast<const uint8_t*>(xb), xids);
break;
default:
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: no method type.";
Expand Down Expand Up @@ -395,6 +383,7 @@ int faiss_add_binaryivf(
bool faiss_search(
const FaissStruct* st,
const int k,
const int nprobe,
const int nq,
const float* xq,
long* I,
Expand All @@ -406,11 +395,9 @@ bool faiss_search(

switch (method_type) {
case IVFPQ:
return faiss_search_ivfpq(st, k, nq, xq, I, D);
break;
return faiss_search_ivfpq(st, k, nprobe, nq, xq, I, D);
case BINARYIVF:
return faiss_search_binaryivf(st, k, nq, reinterpret_cast<const uint8_t*>(xq), I, D);
break;
return faiss_search_binaryivf(st, k, nprobe, nq, reinterpret_cast<const uint8_t*>(xq), I, D);
default:
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: no method type.";
Expand All @@ -422,6 +409,7 @@ bool faiss_search(
bool faiss_search_ivfpq(
const FaissStruct* st,
const int k,
const int nprobe,
const int nq,
const float* xq,
long* I,
Expand All @@ -433,6 +421,7 @@ bool faiss_search_ivfpq(
try {
//printf("is_trained: %d\n", (static_cast<faiss::IndexIVFPQ*>(st->faiss_index))->is_trained);
//printf("ntotal: %ld\n", (static_cast<faiss::IndexIVFPQ*>(st->faiss_index))->ntotal);
(static_cast<faiss::IndexIVFPQ*>(st->faiss_index))->nprobe = nprobe;
(static_cast<faiss::IndexIVFPQ*>(st->faiss_index))->search(nq, xq, k, D, I);
//printf("I=\n");
//for(int i = 0; i < nq; i++) {
Expand Down Expand Up @@ -461,6 +450,7 @@ bool faiss_search_ivfpq(
bool faiss_search_binaryivf(
const FaissStruct* st,
const int k,
const int nprobe,
const int nq,
const uint8_t* xq,
long* I,
Expand All @@ -471,6 +461,7 @@ bool faiss_search_binaryivf(

int32_t* tmpD = new int32_t[nq*k];
try {
(static_cast<faiss::IndexBinaryIVF*>(st->faiss_index))->nprobe = nprobe;
(static_cast<faiss::IndexBinaryIVF*>(st->faiss_index))->search(nq, xq, k, tmpD, I);
} catch(std::exception &err) {
std::stringstream ss;
Expand All @@ -497,10 +488,8 @@ int faiss_remove(
switch (method_type) {
case IVFPQ:
return faiss_remove_ivfpq(st, size, ids);
break;
case BINARYIVF:
return faiss_remove_binaryivf(st, size, ids);
break;
default:
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: no method type.";
Expand Down
3 changes: 3 additions & 0 deletions internal/core/algorithm/faiss/Capi.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ extern "C" {
bool faiss_search(
const FaissStruct* st,
const int k,
const int nprobe,
const int nq,
const float* xq,
long* I,
Expand All @@ -98,13 +99,15 @@ extern "C" {
bool faiss_search_ivfpq(
const FaissStruct* st,
const int k,
const int nprobe,
const int nq,
const float* xq,
long* I,
float* D);
bool faiss_search_binaryivf(
const FaissStruct* st,
const int k,
const int nprobe,
const int nq,
const uint8_t* xq,
long* I,
Expand Down
6 changes: 3 additions & 3 deletions internal/core/algorithm/faiss/faiss.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type (
Add(nb int, xb []float32, xids []int64) (int, error)

// Search returns search result as []algorithm.SearchResult.
Search(k, nq int, xq []float32) ([]algorithm.SearchResult, error)
Search(k, nprobe, nq int, xq []float32) ([]algorithm.SearchResult, error)

// Remove removes from faiss index.
Remove(size int, ids []int64) (int, error)
Expand Down Expand Up @@ -215,15 +215,15 @@ func (f *faiss) Add(nb int, xb []float32, xids []int64) (int, error) {
}

// Search returns search result as []algorithm.SearchResult.
func (f *faiss) Search(k, nq int, xq []float32) ([]algorithm.SearchResult, error) {
func (f *faiss) Search(k, nprobe, nq int, xq []float32) ([]algorithm.SearchResult, error) {

Check warning on line 218 in internal/core/algorithm/faiss/faiss.go

View check run for this annotation

Codecov / codecov/patch

internal/core/algorithm/faiss/faiss.go#L218

Added line #L218 was not covered by tests
if len(xq) != nq*int(f.dimension) {
return nil, errors.ErrIncompatibleDimensionSize(len(xq), int(f.dimension))
}

I := make([]int64, k*nq)
D := make([]float32, k*nq)
f.mu.RLock()
ret := C.faiss_search(f.st, (C.int)(k), (C.int)(nq), (*C.float)(&xq[0]), (*C.long)(&I[0]), (*C.float)(&D[0]), C.int(f.methodType))
ret := C.faiss_search(f.st, (C.int)(k), (C.int)(nprobe), (C.int)(nq), (*C.float)(&xq[0]), (*C.long)(&I[0]), (*C.float)(&D[0]), C.int(f.methodType))

Check warning on line 226 in internal/core/algorithm/faiss/faiss.go

View check run for this annotation

Codecov / codecov/patch

internal/core/algorithm/faiss/faiss.go#L226

Added line #L226 was not covered by tests
f.mu.RUnlock()
if ret == ErrorCode {
return nil, errors.NewFaissError("failed to faiss_search")
Expand Down
1 change: 1 addition & 0 deletions pkg/agent/core/faiss/handler/grpc/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func (s *server) Search(
}
res, err = s.faiss.Search(
req.GetConfig().GetNum(),
req.GetConfig().GetNprobe(),

Check warning on line 72 in pkg/agent/core/faiss/handler/grpc/search.go

View check run for this annotation

Codecov / codecov/patch

pkg/agent/core/faiss/handler/grpc/search.go#L72

Added line #L72 was not covered by tests
1,
req.GetVector())
if err == nil && res == nil {
Expand Down
11 changes: 8 additions & 3 deletions pkg/agent/core/faiss/service/faiss.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type (
CreateIndex(ctx context.Context) error
SaveIndex(ctx context.Context) error
CreateAndSaveIndex(ctx context.Context) error
Search(k, nq uint32, xq []float32) (*payload.Search_Response, error)
Search(k, nprobe, nq uint32, xq []float32) (*payload.Search_Response, error)
Delete(uuid string) error
DeleteWithTime(uuid string, t int64) error
Exists(uuid string) (uint32, bool)
Expand Down Expand Up @@ -1123,12 +1123,17 @@ func (f *faiss) CreateAndSaveIndex(ctx context.Context) error {
return f.SaveIndex(ctx)
}

func (f *faiss) Search(k, nq uint32, xq []float32) (res *payload.Search_Response, err error) {
func (f *faiss) Search(
k, nprobe, nq uint32, xq []float32,
) (res *payload.Search_Response, err error) {

Check warning on line 1128 in pkg/agent/core/faiss/service/faiss.go

View check run for this annotation

Codecov / codecov/patch

pkg/agent/core/faiss/service/faiss.go#L1128

Added line #L1128 was not covered by tests
if f.IsIndexing() {
return nil, errors.ErrCreateIndexingIsInProgress
}
if nprobe == 0 {
nprobe = 1

Check warning on line 1133 in pkg/agent/core/faiss/service/faiss.go

View check run for this annotation

Codecov / codecov/patch

pkg/agent/core/faiss/service/faiss.go#L1132-L1133

Added lines #L1132 - L1133 were not covered by tests
}

sr, err := f.core.Search(int(k), int(nq), xq)
sr, err := f.core.Search(int(k), int(nprobe), int(nq), xq)

Check warning on line 1136 in pkg/agent/core/faiss/service/faiss.go

View check run for this annotation

Codecov / codecov/patch

pkg/agent/core/faiss/service/faiss.go#L1136

Added line #L1136 was not covered by tests
if err != nil {
if f.IsIndexing() {
return nil, errors.ErrCreateIndexingIsInProgress
Expand Down
3 changes: 3 additions & 0 deletions rust/libs/proto/src/payload.v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ pub mod search {
/// Search ratio for agent return result number.
#[prost(message, optional, tag="10")]
pub ratio: ::core::option::Option<f32>,
/// Search nprobe.
#[prost(uint32, tag="11")]
pub nprobe: u32,
}
/// Represent a search response.
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down

0 comments on commit ccfc94b

Please sign in to comment.