diff --git a/dgraph/cmd/zero/oracle.go b/dgraph/cmd/zero/oracle.go index 01c82d3ec4f..97b4dd07e72 100644 --- a/dgraph/cmd/zero/oracle.go +++ b/dgraph/cmd/zero/oracle.go @@ -30,6 +30,7 @@ import ( "github.com/dgraph-io/badger/v4/y" "github.com/dgraph-io/dgo/v230/protos/api" "github.com/dgraph-io/dgraph/protos/pb" + "github.com/dgraph-io/dgraph/tok/hnsw" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" ) @@ -376,6 +377,9 @@ func (s *Server) commit(ctx context.Context, src *api.TxnContext) error { return errors.Wrapf(err, "unable to parse group id from %s", pkey) } pred := splits[1] + if strings.Contains(pred, hnsw.VecKeyword) { + pred = pred[0:strings.Index(pred, hnsw.VecKeyword)] + } tablet := s.ServingTablet(pred) if tablet == nil { return errors.Errorf("Tablet for %s is nil", pred) diff --git a/dgraphtest/paths.go b/dgraphtest/paths.go index d3b5db032fb..c6486a4f2de 100644 --- a/dgraphtest/paths.go +++ b/dgraphtest/paths.go @@ -51,9 +51,9 @@ func init() { if err != nil { panic(err) } - if err := ensureDgraphClone(); err != nil { - panic(err) - } + //if err := ensureDgraphClone(); err != nil { + // panic(err) + //} log.Printf("[INFO] baseRepoDir: %v", baseRepoDir) log.Printf("[INFO] repoDir: %v", repoDir) diff --git a/dql/parser.go b/dql/parser.go index 3f1fa1acb55..08d51556c3f 100644 --- a/dql/parser.go +++ b/dql/parser.go @@ -33,12 +33,13 @@ import ( ) const ( - uidFunc = "uid" - valueFunc = "val" - typFunc = "type" - lenFunc = "len" - countFunc = "count" - uidInFunc = "uid_in" + uidFunc = "uid" + valueFunc = "val" + typFunc = "type" + lenFunc = "len" + countFunc = "count" + uidInFunc = "uid_in" + similarToFn = "similar_to" ) var ( @@ -356,7 +357,7 @@ func parseValue(v varInfo) (types.Val, error) { }, nil } } - case "vfloat": + case "vector32float": { if i, err := types.ParseVFloat(v.Value); err != nil { return types.Val{}, errors.Wrapf(err, "Expected a vfloat but got %v", v.Value) @@ -1711,7 +1712,7 @@ func validFuncName(name string) bool { 
switch name { case "regexp", "anyofterms", "allofterms", "alloftext", "anyoftext", - "has", "uid", "uid_in", "anyof", "allof", "type", "match": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": return true } return false @@ -1884,7 +1885,7 @@ L: case IsInequalityFn(function.Name): err = parseFuncArgs(it, function) - case function.Name == "uid_in": + case function.Name == "uid_in" || function.Name == "similar_to": err = parseFuncArgs(it, function) default: diff --git a/go.mod b/go.mod index 2789c74f869..3da8a99fca6 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,7 @@ require ( go.opencensus.io v0.24.0 go.uber.org/zap v1.16.0 golang.org/x/crypto v0.21.0 + golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 golang.org/x/net v0.22.0 golang.org/x/sync v0.6.0 golang.org/x/sys v0.18.0 @@ -141,7 +142,6 @@ require ( github.com/xdg/stringprep v1.0.3 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.10.0 // indirect - golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect golang.org/x/mod v0.16.0 // indirect golang.org/x/time v0.3.0 // indirect google.golang.org/api v0.122.0 // indirect diff --git a/posting/index.go b/posting/index.go index e21d4b6c975..c4a7820014f 100644 --- a/posting/index.go +++ b/posting/index.go @@ -21,6 +21,7 @@ import ( "context" "encoding/binary" "encoding/hex" + "encoding/json" "fmt" "math" "os" @@ -38,6 +39,7 @@ import ( "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/schema" "github.com/dgraph-io/dgraph/tok" + "github.com/dgraph-io/dgraph/tok/hnsw" "github.com/dgraph-io/dgraph/types" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" @@ -46,10 +48,11 @@ import ( var emptyCountParams countParams type indexMutationInfo struct { - tokenizers []tok.Tokenizer - edge *pb.DirectedEdge // Represents the original uid -> value edge. 
- val types.Val - op pb.DirectedEdge_Op + tokenizers []tok.Tokenizer + factorySpecs []*tok.FactoryCreateSpec + edge *pb.DirectedEdge // Represents the original uid -> value edge. + val types.Val + op pb.DirectedEdge_Op } // indexTokens return tokens, without the predicate prefix and @@ -85,20 +88,104 @@ func indexTokens(ctx context.Context, info *indexMutationInfo) ([]string, error) // addIndexMutations adds mutation(s) for a single term, to maintain the index, // but only for the given tokenizers. // TODO - See if we need to pass op as argument as t should already have Op. -func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) error { + +func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) ([]*pb.DirectedEdge, error) { if info.tokenizers == nil { info.tokenizers = schema.State().Tokenizer(ctx, info.edge.Attr) } + if info.factorySpecs == nil { + specs, err := schema.State().FactoryCreateSpec(ctx, info.edge.Attr) + if err != nil { + return nil, err + } + info.factorySpecs = specs + } + attr := info.edge.Attr uid := info.edge.Entity if uid == 0 { - return errors.New("invalid UID with value 0") + return []*pb.DirectedEdge{}, errors.New("invalid UID with value 0") + } + + inKey := x.DataKey(info.edge.Attr, uid) + pl, err := txn.Get(inKey) + if err != nil { + return []*pb.DirectedEdge{}, err + } + data, err := pl.AllValues(txn.StartTs) + if err != nil { + return []*pb.DirectedEdge{}, err + } + + if info.op == pb.DirectedEdge_DEL && + len(data) > 0 && data[0].Tid == types.VFloatID { + // TODO look into better alternatives + // The issue here is that we will create dead nodes in the Vector Index + // assuming an HNSW index type. What we should do instead is invoke + // index.Remove(). However, we currently do + // not support this in VectorIndex code!! + // if a delete & dealing with vfloats, add this to dead node in persistent store. + // What we should do instead is invoke the factory.Remove(key) operation. 
+ deadAttr := hnsw.ConcatStrings(info.edge.Attr, hnsw.VecDead) + deadKey := x.DataKey(deadAttr, 1) + pl, err := txn.Get(deadKey) + if err != nil { + return []*pb.DirectedEdge{}, err + } + var deadNodes []uint64 + deadData, _ := pl.Value(txn.StartTs) + if deadData.Value == nil { + deadNodes = append(deadNodes, uid) + } else { + deadNodes, err = hnsw.ParseEdges(string(deadData.Value.([]byte))) + if err != nil { + return []*pb.DirectedEdge{}, err + } + deadNodes = append(deadNodes, uid) + } + deadNodesBytes, marshalErr := json.Marshal(deadNodes) + if marshalErr != nil { + return []*pb.DirectedEdge{}, marshalErr + } + edge := &pb.DirectedEdge{ + Entity: 1, + Attr: deadAttr, + Value: deadNodesBytes, + ValueType: pb.Posting_ValType(0), + } + pl.addMutation(ctx, txn, edge) + } + + // TODO: As stated earlier, we need to validate that it is okay to assume + // that we care about just data[0]. + // Similarly, the current assumption is that we have at most one + // Vector Index, but this assumption may break later. + if info.op == pb.DirectedEdge_SET && + len(data) > 0 && data[0].Tid == types.VFloatID && + len(info.factorySpecs) > 0 { + // retrieve vector from inUuid save as inVec + inVec := types.BytesAsFloatArray(data[0].Value.([]byte)) + tc := hnsw.NewTxnCache(NewViTxn(txn), txn.StartTs) + indexer, err := info.factorySpecs[0].CreateIndex(attr) + if err != nil { + return []*pb.DirectedEdge{}, err + } + edges, err := indexer.Insert(ctx, tc, uid, inVec) + if err != nil { + return []*pb.DirectedEdge{}, err + } + pbEdges := []*pb.DirectedEdge{} + for _, e := range edges { + pbe := indexEdgeToPbEdge(e) + pbEdges = append(pbEdges, pbe) + } + return pbEdges, nil } tokens, err := indexTokens(ctx, info) if err != nil { // This data is not indexable - return err + return []*pb.DirectedEdge{}, err } // Create a value token -> uid edge. 
@@ -110,10 +197,10 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) for _, token := range tokens { if err := txn.addIndexMutation(ctx, edge, token); err != nil { - return err + return []*pb.DirectedEdge{}, err } } - return nil + return []*pb.DirectedEdge{}, nil } func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, token string) error { @@ -291,12 +378,18 @@ func (l *List) handleDeleteAll(ctx context.Context, edge *pb.DirectedEdge, txn * Tid: types.TypeID(p.ValType), Value: p.Value, } - return txn.addIndexMutations(ctx, &indexMutationInfo{ - tokenizers: schema.State().Tokenizer(ctx, edge.Attr), - edge: edge, - val: val, - op: pb.DirectedEdge_DEL, + factorySpecs, err := schema.State().FactoryCreateSpec(ctx, edge.Attr) + if err != nil { + return err + } + _, err = txn.addIndexMutations(ctx, &indexMutationInfo{ + tokenizers: schema.State().Tokenizer(ctx, edge.Attr), + factorySpecs: factorySpecs, + edge: edge, + val: val, + op: pb.DirectedEdge_DEL, }) + return err default: return nil } @@ -490,7 +583,7 @@ func (l *List) AddMutationWithIndex(ctx context.Context, edge *pb.DirectedEdge, if doUpdateIndex { // Exact matches. if found && val.Value != nil { - if err := txn.addIndexMutations(ctx, &indexMutationInfo{ + if _, err := txn.addIndexMutations(ctx, &indexMutationInfo{ tokenizers: schema.State().Tokenizer(ctx, edge.Attr), edge: edge, val: val, @@ -504,7 +597,7 @@ func (l *List) AddMutationWithIndex(ctx context.Context, edge *pb.DirectedEdge, Tid: types.TypeID(edge.ValueType), Value: edge.Value, } - if err := txn.addIndexMutations(ctx, &indexMutationInfo{ + if _, err := txn.addIndexMutations(ctx, &indexMutationInfo{ tokenizers: schema.State().Tokenizer(ctx, edge.Attr), edge: edge, val: val, @@ -551,7 +644,7 @@ type rebuilder struct { // The posting list passed here is the on disk version. It is not coming // from the LRU cache. 
- fn func(uid uint64, pl *List, txn *Txn) error + fn func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) } func (r *rebuilder) Run(ctx context.Context) error { @@ -560,6 +653,11 @@ func (r *rebuilder) Run(ctx context.Context) error { return nil } + pred, ok := schema.State().Get(ctx, r.attr) + if !ok { + return errors.Errorf("Rebuilder.run: Unable to find schema for %s", r.attr) + } + // We write the index in a temporary badger first and then, // merge entries before writing them to p directory. tmpIndexDir, err := os.MkdirTemp(x.WorkerConfig.TmpDir, "dgraph_index_") @@ -597,10 +695,16 @@ func (r *rebuilder) Run(ctx context.Context) error { // We set it to 1 in case there are no keys found and NewStreamAt is called with ts=0. var counter uint64 = 1 + var txn *Txn + tmpWriter := tmpDB.NewManagedWriteBatch() stream := pstore.NewStreamAt(r.startTs) stream.LogPrefix = fmt.Sprintf("Rebuilding index for predicate %s (1/2):", r.attr) stream.Prefix = r.prefix + //TODO We need to create a single transaction irrespective of the type of the predicate + if pred.ValueType == pb.Posting_VFLOAT { + txn = NewTxn(r.startTs) + } stream.KeyToList = func(key []byte, itr *badger.Iterator) (*bpb.KVList, error) { // We should return quickly if the context is no longer valid. select { @@ -622,17 +726,42 @@ func (r *rebuilder) Run(ctx context.Context) error { // We are using different transactions in each call to KeyToList function. This could // be a problem for computing reverse count indexes if deltas for same key are added // in different transactions. Such a case doesn't occur for now. - txn := NewTxn(r.startTs) - if err := r.fn(pk.Uid, l, txn); err != nil { + // TODO: Maybe we can always use txn initialized in rebuilder.Run(). + streamTxn := txn + if streamTxn == nil { + streamTxn = NewTxn(r.startTs) + } + edges, err := r.fn(pk.Uid, l, streamTxn) + if err != nil { return nil, err } - // Convert data into deltas. 
- txn.Update() + if txn != nil { + kvs := make([]*bpb.KV, 0, len(edges)) + for _, edge := range edges { + version := atomic.AddUint64(&counter, 1) + key := x.DataKey(edge.Attr, edge.Entity) + pl, err := txn.GetFromDelta(key) + if err != nil { + return &bpb.KVList{}, nil + } + data := pl.getMutation(r.startTs) + kv := bpb.KV{ + Key: x.DataKey(edge.Attr, edge.Entity), + Value: data, + UserMeta: []byte{BitDeltaPosting}, + Version: version, + } + kvs = append(kvs, &kv) + } + return &bpb.KVList{Kv: kvs}, nil + } + // Convert data into deltas. + streamTxn.Update() // txn.cache.Lock() is not required because we are the only one making changes to txn. - kvs := make([]*bpb.KV, 0, len(txn.cache.deltas)) - for key, data := range txn.cache.deltas { + kvs := make([]*bpb.KV, 0, len(streamTxn.cache.deltas)) + for key, data := range streamTxn.cache.deltas { version := atomic.AddUint64(&counter, 1) kv := bpb.KV{ Key: []byte(key), @@ -771,14 +900,26 @@ func (rb *IndexRebuild) GetQuerySchema() *pb.SchemaUpdate { // DropIndexes drops the indexes that need to be rebuilt. func (rb *IndexRebuild) DropIndexes(ctx context.Context) error { - prefixes, err := prefixesForTokIndexes(ctx, rb) + rebuildInfo := rb.needsTokIndexRebuild() + prefixes, err := rebuildInfo.prefixesForTokIndexes() if err != nil { return err } + vectorIndexPrefixes, err := rebuildInfo.prefixesForVectorIndexes() + if err != nil { + return nil + } + prefixes = append(prefixes, vectorIndexPrefixes...) prefixes = append(prefixes, prefixesToDropReverseEdges(ctx, rb)...) prefixes = append(prefixes, prefixesToDropCountIndex(ctx, rb)...) - glog.Infof("Deleting indexes for %s", rb.Attr) - return pstore.DropPrefix(prefixes...) + prefixes = append(prefixes, prefixesToDropVectorIndexEdges(ctx, rb)...) + if len(prefixes) > 0 { + // This trace message now gets logged only if there are any prefixes to + // to be deleted + glog.Infof("Deleting indexes for %s", rb.Attr) + return pstore.DropPrefix(prefixes...) 
+ } + return nil } // BuildData updates data. @@ -806,12 +947,15 @@ func (rb *IndexRebuild) BuildIndexes(ctx context.Context) error { } type indexRebuildInfo struct { - op indexOp - tokenizersToDelete []string - tokenizersToRebuild []string + op indexOp + attr string + tokenizersToDelete []string + tokenizersToRebuild []string + vectorIndexesToDelete []*pb.VectorIndexSpec + vectorIndexesToRebuild []*pb.VectorIndexSpec } -func (rb *IndexRebuild) needsTokIndexRebuild() indexRebuildInfo { +func (rb *IndexRebuild) needsTokIndexRebuild() *indexRebuildInfo { x.AssertTruef(rb.CurrentSchema != nil, "Current schema cannot be nil.") // If the old schema is nil, we can treat it as an empty schema. Copy it @@ -827,8 +971,9 @@ func (rb *IndexRebuild) needsTokIndexRebuild() indexRebuildInfo { // Index does not need to be rebuilt or deleted if the scheme directive // did not require an index before and now. if !currIndex && !prevIndex { - return indexRebuildInfo{ - op: indexNoop, + return &indexRebuildInfo{ + op: indexNoop, + attr: rb.Attr, } } @@ -837,19 +982,24 @@ func (rb *IndexRebuild) needsTokIndexRebuild() indexRebuildInfo { // prevIndex since the previous if statement guarantees both values are // different. if !currIndex { - return indexRebuildInfo{ - op: indexDelete, - tokenizersToDelete: old.Tokenizer, + return &indexRebuildInfo{ + op: indexDelete, + attr: rb.Attr, + tokenizersToDelete: old.Tokenizer, + vectorIndexesToDelete: old.IndexSpecs, } } // All tokenizers in the index need to be deleted and rebuilt if the value // types have changed. 
if currIndex && rb.CurrentSchema.ValueType != old.ValueType { - return indexRebuildInfo{ - op: indexRebuild, - tokenizersToDelete: old.Tokenizer, - tokenizersToRebuild: rb.CurrentSchema.Tokenizer, + return &indexRebuildInfo{ + op: indexRebuild, + attr: rb.Attr, + tokenizersToDelete: old.Tokenizer, + tokenizersToRebuild: rb.CurrentSchema.Tokenizer, + vectorIndexesToDelete: old.IndexSpecs, + vectorIndexesToRebuild: rb.CurrentSchema.IndexSpecs, } } @@ -865,63 +1015,129 @@ func (rb *IndexRebuild) needsTokIndexRebuild() indexRebuildInfo { newTokenizers, deletedTokenizers := x.Diff(currTokens, prevTokens) - // If the tokenizers are the same, nothing needs to be done. - if len(newTokenizers) == 0 && len(deletedTokenizers) == 0 { - return indexRebuildInfo{ - op: indexNoop, + prevFactoryNames := make(map[string]struct{}) + prevFactories := make(map[string]*pb.VectorIndexSpec) + for _, t := range old.IndexSpecs { + prevFactoryNames[t.Name] = struct{}{} + prevFactories[t.Name] = t + } + currFactoryNames := make(map[string]struct{}) + currFactories := make(map[string]*pb.VectorIndexSpec) + for _, t := range rb.CurrentSchema.IndexSpecs { + currFactoryNames[t.Name] = struct{}{} + currFactories[t.Name] = t + } + + newFactoryNames, deletedFactoryNames := x.Diff(currFactoryNames, prevFactoryNames) + + // If the tokenizers and factories are the same, nothing needs to be done. 
+ if len(newTokenizers) == 0 && len(deletedTokenizers) == 0 && + len(newFactoryNames) == 0 && len(deletedFactoryNames) == 0 { + return &indexRebuildInfo{ + op: indexNoop, + attr: rb.Attr, } } + newFactories := []*pb.VectorIndexSpec{} + for _, name := range newFactoryNames { + newFactories = append(newFactories, currFactories[name]) + } + deletedFactories := []*pb.VectorIndexSpec{} + for _, name := range deletedFactoryNames { + deletedFactories = append(deletedFactories, prevFactories[name]) + } - return indexRebuildInfo{ - op: indexRebuild, - tokenizersToDelete: deletedTokenizers, - tokenizersToRebuild: newTokenizers, + return &indexRebuildInfo{ + op: indexRebuild, + attr: rb.Attr, + tokenizersToDelete: deletedTokenizers, + tokenizersToRebuild: newTokenizers, + vectorIndexesToDelete: deletedFactories, + vectorIndexesToRebuild: newFactories, } } -func prefixesForTokIndexes(ctx context.Context, rb *IndexRebuild) ([][]byte, error) { - rebuildInfo := rb.needsTokIndexRebuild() - prefixes := [][]byte{} +func (rb *indexRebuildInfo) appendTokenizerPrefixesToDelete( + tokenizer string, + priorPrefixes [][]byte) ([][]byte, error) { + retVal := priorPrefixes + prefixesNonLang, err := prefixesToDeleteTokensFor(rb.attr, tokenizer, false) + if err != nil { + return nil, err + } + retVal = append(retVal, prefixesNonLang...) + if tokenizer != "exact" { + return retVal, nil + } + prefixesWithLang, err := prefixesToDeleteTokensFor(rb.attr, tokenizer, true) + if err != nil { + return nil, err + } + return append(retVal, prefixesWithLang...), nil +} - if rebuildInfo.op == indexNoop { +// TODO: Kill this function. Rather than calculating prefixes -- like we do +// +// for tokenizers -- we should instead invoke the Remove(indexName) +// operation of the VectorIndexFactory, and have it do all the deletion. +// At the moment however, the Remove operation does not interact with +// Dgraph transactions, so this is not yet possible. 
+func (rb *indexRebuildInfo) prefixesForVectorIndexes() ([][]byte, error) { + prefixes := [][]byte{} + var err error + if rb.op == indexNoop { return prefixes, nil } - glog.Infof("Computing prefix index for attr %s and tokenizers %s", rb.Attr, - rebuildInfo.tokenizersToDelete) - for _, tokenizer := range rebuildInfo.tokenizersToDelete { - prefixesNonLang, err := prefixesToDeleteTokensFor(rb.Attr, tokenizer, false) + for _, vectorSpec := range rb.vectorIndexesToDelete { + glog.Infof("Computing prefix index for attr %s and index factory %s", + rb.attr, vectorSpec.Name) + // The mechanism currently is the same for tokenizers and + // vector factories. + prefixes, err = rb.appendTokenizerPrefixesToDelete(vectorSpec.Name, prefixes) if err != nil { return nil, err } - prefixes = append(prefixes, prefixesNonLang...) - if tokenizer != "exact" { - continue - } - prefixesWithLang, err := prefixesToDeleteTokensFor(rb.Attr, tokenizer, true) + } + + for _, vectorSpec := range rb.vectorIndexesToRebuild { + glog.Infof("Computing prefix index for attr %s and index factory %s", + rb.attr, vectorSpec.Name) + // The mechanism currently is the same for tokenizers and + // vector factories. + prefixes, err = rb.appendTokenizerPrefixesToDelete(vectorSpec.Name, prefixes) if err != nil { return nil, err } - prefixes = append(prefixes, prefixesWithLang...) } - glog.Infof("Deleting index for attr %s and tokenizers %s", rb.Attr, - rebuildInfo.tokenizersToRebuild) - // Before rebuilding, the existing index needs to be deleted. 
- for _, tokenizer := range rebuildInfo.tokenizersToRebuild { - prefixesNonLang, err := prefixesToDeleteTokensFor(rb.Attr, tokenizer, false) + return prefixes, nil +} + +func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) { + prefixes := [][]byte{} + var err error + if rb.op == indexNoop { + return prefixes, nil + } + + glog.Infof("Computing prefix index for attr %s and tokenizers %s", rb.attr, + rb.tokenizersToDelete) + for _, tokenizer := range rb.tokenizersToDelete { + prefixes, err = rb.appendTokenizerPrefixesToDelete(tokenizer, prefixes) if err != nil { return nil, err } - prefixes = append(prefixes, prefixesNonLang...) - if tokenizer != "exact" { - continue - } - prefixesWithLang, err := prefixesToDeleteTokensFor(rb.Attr, tokenizer, true) + } + + glog.Infof("Deleting index for attr %s and tokenizers %s", rb.attr, + rb.tokenizersToRebuild) + // Before rebuilding, the existing index needs to be deleted. + for _, tokenizer := range rb.tokenizersToRebuild { + prefixes, err = rb.appendTokenizerPrefixesToDelete(tokenizer, prefixes) if err != nil { return nil, err } - prefixes = append(prefixes, prefixesWithLang...) } return prefixes, nil @@ -936,7 +1152,7 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { } // Exit early if there are no tokenizers to rebuild. 
- if len(rebuildInfo.tokenizersToRebuild) == 0 { + if len(rebuildInfo.tokenizersToRebuild) == 0 && len(rebuildInfo.vectorIndexesToRebuild) == 0 { return nil } @@ -947,11 +1163,22 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { return err } + var factorySpecs []*tok.FactoryCreateSpec + if len(rebuildInfo.vectorIndexesToRebuild) > 0 { + factorySpec, err := tok.GetFactoryCreateSpecFromSpec( + rebuildInfo.vectorIndexesToRebuild[0]) + if err != nil { + return err + } + factorySpecs = []*tok.FactoryCreateSpec{factorySpec} + } + pk := x.ParsedKey{Attr: rb.Attr} builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} - builder.fn = func(uid uint64, pl *List, txn *Txn) error { + builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { edge := pb.DirectedEdge{Attr: rb.Attr, Entity: uid} - return pl.Iterate(txn.StartTs, 0, func(p *pb.Posting) error { + edges := []*pb.DirectedEdge{} + err := pl.Iterate(txn.StartTs, 0, func(p *pb.Posting) error { // Add index entries based on p. val := types.Val{ Value: p.Value, @@ -960,20 +1187,26 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { edge.Lang = string(p.LangTag) for { - err := txn.addIndexMutations(ctx, &indexMutationInfo{ - tokenizers: tokenizers, - edge: &edge, - val: val, - op: pb.DirectedEdge_SET, + newEdges, err := txn.addIndexMutations(ctx, &indexMutationInfo{ + tokenizers: tokenizers, + factorySpecs: factorySpecs, + edge: &edge, + val: val, + op: pb.DirectedEdge_SET, }) switch err { case ErrRetry: time.Sleep(10 * time.Millisecond) default: + edges = append(edges, newEdges...) 
return err } } }) + if err != nil { + return []*pb.DirectedEdge{}, err + } + return edges, err } return builder.Run(ctx) } @@ -1038,7 +1271,7 @@ func rebuildCountIndex(ctx context.Context, rb *IndexRebuild) error { glog.Infof("Rebuilding count index for %s", rb.Attr) var reverse bool - fn := func(uid uint64, pl *List, txn *Txn) error { + fn := func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { t := &pb.DirectedEdge{ ValueId: uid, Attr: rb.Attr, @@ -1046,7 +1279,7 @@ func rebuildCountIndex(ctx context.Context, rb *IndexRebuild) error { } sz := pl.Length(rb.StartTs, 0) if sz == -1 { - return nil + return []*pb.DirectedEdge{}, nil } for { err := txn.addCountMutation(ctx, t, uint32(sz), reverse) @@ -1054,7 +1287,7 @@ func rebuildCountIndex(ctx context.Context, rb *IndexRebuild) error { case ErrRetry: time.Sleep(10 * time.Millisecond) default: - return err + return []*pb.DirectedEdge{}, err } } } @@ -1077,6 +1310,53 @@ func rebuildCountIndex(ctx context.Context, rb *IndexRebuild) error { return builder.Run(ctx) } +func (rb *IndexRebuild) needsVectorIndexEdgesRebuild() indexOp { + x.AssertTruef(rb.CurrentSchema != nil, "Current schema cannot be nil.") + + // If old schema is nil, treat it as an empty schema. Copy it to avoid + // overwriting it in rb. + old := rb.OldSchema + if old == nil { + old = &pb.SchemaUpdate{} + } + + currIndex := rb.CurrentSchema.Directive == pb.SchemaUpdate_INDEX && + rb.CurrentSchema.ValueType == pb.Posting_VFLOAT + prevIndex := old.Directive == pb.SchemaUpdate_INDEX && + old.ValueType == pb.Posting_VFLOAT + + // If the schema directive did not change, return indexNoop. + if currIndex == prevIndex { + return indexNoop + } + + // If the current schema requires an index, index should be rebuilt. + if currIndex { + return indexRebuild + } + // Otherwise, index should only be deleted. 
+ return indexDelete +} + +// This needs to be moved to the implementation of vector-indexer API +func prefixesToDropVectorIndexEdges(ctx context.Context, rb *IndexRebuild) [][]byte { + // Exit early if indices do not need to be rebuilt. + op := rb.needsVectorIndexEdgesRebuild() + if op == indexNoop { + return nil + } + + prefixes := append([][]byte{}, x.PredicatePrefix(hnsw.ConcatStrings(rb.Attr, hnsw.VecEntry))) + prefixes = append(prefixes, x.PredicatePrefix(hnsw.ConcatStrings(rb.Attr, hnsw.VecDead))) + prefixes = append(prefixes, x.PredicatePrefix(hnsw.ConcatStrings(rb.Attr, hnsw.VecKeyword))) + + for i := 0; i < hnsw.VectorIndexMaxLevels; i++ { + prefixes = append(prefixes, x.PredicatePrefix(hnsw.ConcatStrings(rb.Attr, hnsw.VecKeyword, fmt.Sprint(i)))) + } + + return prefixes +} + func (rb *IndexRebuild) needsReverseEdgesRebuild() indexOp { x.AssertTruef(rb.CurrentSchema != nil, "Current schema cannot be nil.") @@ -1132,9 +1412,9 @@ func rebuildReverseEdges(ctx context.Context, rb *IndexRebuild) error { glog.Infof("Rebuilding reverse index for %s", rb.Attr) pk := x.ParsedKey{Attr: rb.Attr} builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} - builder.fn = func(uid uint64, pl *List, txn *Txn) error { + builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { edge := pb.DirectedEdge{Attr: rb.Attr, Entity: uid} - return pl.Iterate(txn.StartTs, 0, func(pp *pb.Posting) error { + return []*pb.DirectedEdge{}, pl.Iterate(txn.StartTs, 0, func(pp *pb.Posting) error { puid := pp.Uid // Add reverse entries based on p. 
edge.ValueId = puid @@ -1185,7 +1465,7 @@ func rebuildListType(ctx context.Context, rb *IndexRebuild) error { pk := x.ParsedKey{Attr: rb.Attr} builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} - builder.fn = func(uid uint64, pl *List, txn *Txn) error { + builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { var mpost *pb.Posting err := pl.Iterate(txn.StartTs, 0, func(p *pb.Posting) error { // We only want to modify the untagged value. There could be other values with a @@ -1196,10 +1476,10 @@ func rebuildListType(ctx context.Context, rb *IndexRebuild) error { return nil }) if err != nil { - return err + return []*pb.DirectedEdge{}, err } if mpost == nil { - return nil + return []*pb.DirectedEdge{}, nil } // Delete the old edge corresponding to ValueId math.MaxUint64 t := &pb.DirectedEdge{ @@ -1212,7 +1492,7 @@ func rebuildListType(ctx context.Context, rb *IndexRebuild) error { // get updated. pl = txn.cache.SetIfAbsent(string(pl.key), pl) if err := pl.addMutation(ctx, txn, t); err != nil { - return err + return []*pb.DirectedEdge{}, err } // Add the new edge with the fingerprinted value id. newEdge := &pb.DirectedEdge{ @@ -1222,7 +1502,7 @@ func rebuildListType(ctx context.Context, rb *IndexRebuild) error { Op: pb.DirectedEdge_SET, Facets: mpost.Facets, } - return pl.addMutation(ctx, txn, newEdge) + return []*pb.DirectedEdge{}, pl.addMutation(ctx, txn, newEdge) } return builder.Run(ctx) } @@ -1243,12 +1523,18 @@ func DeleteData(ns uint64) error { // DeletePredicate deletes all entries and indices for a given predicate. 
func DeletePredicate(ctx context.Context, attr string, ts uint64) error { glog.Infof("Dropping predicate: [%s]", attr) - prefix := x.PredicatePrefix(attr) - if err := pstore.DropPrefix(prefix); err != nil { - return err + preds := schema.State().PredicatesToDelete(attr) + for _, pred := range preds { + prefix := x.PredicatePrefix(pred) + if err := schema.State().Delete(pred, ts); err != nil { + return err + } + if err := pstore.DropPrefix(prefix); err != nil { + return err + } } - return schema.State().Delete(attr, ts) + return nil } // DeleteNamespace bans the namespace and deletes its predicates/types from the schema. diff --git a/posting/index_test.go b/posting/index_test.go index b8098e644a7..03f2f74cc0e 100644 --- a/posting/index_test.go +++ b/posting/index_test.go @@ -267,7 +267,8 @@ func TestRebuildTokIndex(t *testing.T) { OldSchema: nil, CurrentSchema: ¤tSchema, } - prefixes, err := prefixesForTokIndexes(context.Background(), &rb) + rebuildInfo := rb.needsTokIndexRebuild() + prefixes, err := rebuildInfo.prefixesForTokIndexes() require.NoError(t, err) require.NoError(t, pstore.DropPrefix(prefixes...)) require.NoError(t, rebuildTokIndex(context.Background(), &rb)) @@ -320,7 +321,8 @@ func TestRebuildTokIndexWithDeletion(t *testing.T) { OldSchema: nil, CurrentSchema: ¤tSchema, } - prefixes, err := prefixesForTokIndexes(context.Background(), &rb) + rebuildInfo := rb.needsTokIndexRebuild() + prefixes, err := rebuildInfo.prefixesForTokIndexes() require.NoError(t, err) require.NoError(t, pstore.DropPrefix(prefixes...)) require.NoError(t, rebuildTokIndex(context.Background(), &rb)) @@ -334,7 +336,8 @@ func TestRebuildTokIndexWithDeletion(t *testing.T) { OldSchema: ¤tSchema, CurrentSchema: &newSchema, } - prefixes, err = prefixesForTokIndexes(context.Background(), &rb) + rebuildInfo = rb.needsTokIndexRebuild() + prefixes, err = rebuildInfo.prefixesForTokIndexes() require.NoError(t, err) require.NoError(t, pstore.DropPrefix(prefixes...)) require.NoError(t, 
rebuildTokIndex(context.Background(), &rb)) diff --git a/posting/list.go b/posting/list.go index cba042d7126..bc4ca8e61bf 100644 --- a/posting/list.go +++ b/posting/list.go @@ -38,6 +38,7 @@ import ( "github.com/dgraph-io/dgraph/types/facets" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" + "github.com/dgraph-io/dgraph/tok/index" ) var ( @@ -78,6 +79,16 @@ type List struct { maxTs uint64 // max commit timestamp seen for this list. } +func indexEdgeToPbEdge(t *index.KeyValue) *pb.DirectedEdge { + return &pb.DirectedEdge{ + Entity: t.Entity, + Attr: t.Attr, + Value: t.Value, + ValueType: pb.Posting_ValType(0), + Op: pb.DirectedEdge_SET, + } +} + // NewList returns a new list with an immutable layer set to plist and the // timestamp of the immutable layer set to minTs. func NewList(key []byte, plist *pb.PostingList, minTs uint64) *List { @@ -1298,9 +1309,14 @@ func (l *List) GetLangTags(readTs uint64) ([]string, error) { // Value returns the default value from the posting list. The default value is // defined as the value without a language tag. 
+// Value cannot be used to read from cache func (l *List) Value(readTs uint64) (rval types.Val, rerr error) { l.RLock() defer l.RUnlock() + return l.ValueWithLockHeld(readTs) +} + +func (l *List) ValueWithLockHeld(readTs uint64) (rval types.Val, rerr error) { val, found, err := l.findValue(readTs, math.MaxUint64) if err != nil { return val, errors.Wrapf(err, diff --git a/posting/lists.go b/posting/lists.go index 5c70fc5914b..8c73d45ed56 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -17,6 +17,7 @@ package posting import ( + "bytes" "context" "fmt" "sync" @@ -30,6 +31,7 @@ import ( "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto/z" + "github.com/dgraph-io/dgraph/tok/index" ) const ( @@ -113,6 +115,44 @@ type LocalCache struct { plists map[string]*List } +// struct to implement LocalCache interface from vector-indexer +// acts as wrapper for dgraph *LocalCache +type viLocalCache struct { + delegate *LocalCache +} + +func (vc *viLocalCache) Find(prefix []byte, filter func([]byte) bool) (uint64, error) { + return vc.delegate.Find(prefix, filter) +} + +func (vc *viLocalCache) Get(key []byte) (rval index.Value, rerr error) { + pl, err := vc.delegate.Get(key) + if err != nil { + return nil, err + } + pl.Lock() + defer pl.Unlock() + return vc.GetValueFromPostingList(pl) +} + +func (vc *viLocalCache) GetWithLockHeld(key []byte) (rval index.Value, rerr error) { + pl, err := vc.delegate.Get(key) + if err != nil { + return nil, err + } + return vc.GetValueFromPostingList(pl) +} + +func (vc *viLocalCache) GetValueFromPostingList(pl *List) (rval index.Value, rerr error) { + val, err := pl.ValueWithLockHeld(vc.delegate.startTs) + rval = val.Value + return rval, err +} + +func NewViLocalCache(delegate *LocalCache) *viLocalCache { + return &viLocalCache{delegate: delegate} +} + // NewLocalCache returns a new LocalCache instance. 
func NewLocalCache(startTs uint64) *LocalCache { return &LocalCache{ @@ -129,6 +169,89 @@ func NoCache(startTs uint64) *LocalCache { return &LocalCache{startTs: startTs} } +func (lc *LocalCache) Find(pred []byte, filter func([]byte) bool) (uint64, error) { + txn := pstore.NewTransactionAt(lc.startTs, false) + defer txn.Discard() + + attr := string(pred) + + initKey := x.ParsedKey{ + Attr: attr, + } + startKey := x.DataKey(attr, 0) + prefix := initKey.DataPrefix() + + result := &pb.List{} + var prevKey []byte + itOpt := badger.DefaultIteratorOptions + itOpt.PrefetchValues = false + itOpt.AllVersions = true + itOpt.Prefix = prefix + it := txn.NewIterator(itOpt) + defer it.Close() + + for it.Seek(startKey); it.Valid(); { + item := it.Item() + if bytes.Equal(item.Key(), prevKey) { + it.Next() + continue + } + prevKey = append(prevKey[:0], item.Key()...) + + // Parse the key upfront, otherwise ReadPostingList would advance the + // iterator. + pk, err := x.Parse(item.Key()) + if err != nil { + return 0, err + } + + // If we have moved to the next attribute, break + if pk.Attr != attr { + break + } + + if pk.HasStartUid { + // The keys holding parts of a split key should not be accessed here because + // they have a different prefix. However, the check is being added to guard + // against future bugs. + continue + } + + switch { + case item.UserMeta()&BitEmptyPosting > 0: + // This is an empty posting list. So, it should not be included. + continue + default: + // This bit would only be set if there are valid uids in UidPack. 
+ key := x.DataKey(attr, pk.Uid) + pl, err := lc.Get(key) + if err != nil { + return 0, err + } + vals, err := pl.Value(lc.startTs) + switch { + case err == ErrNoValue: + continue + case err != nil: + return 0, err + } + + if filter(vals.Value.([]byte)) { + result.Uids = append(result.Uids, pk.Uid) + break + } + + continue + } + } + + if len(result.Uids) > 0 { + return result.Uids[0], nil + } + + return 0, badger.ErrKeyNotFound +} + func (lc *LocalCache) getNoStore(key string) *List { lc.RLock() defer lc.RUnlock() diff --git a/posting/oracle.go b/posting/oracle.go index 435a7853602..e5a384715b6 100644 --- a/posting/oracle.go +++ b/posting/oracle.go @@ -29,6 +29,7 @@ import ( "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/x" + "github.com/dgraph-io/dgraph/tok/index" ) var o *oracle @@ -65,6 +66,74 @@ type Txn struct { cache *LocalCache // This pointer does not get modified. } +// struct to implement Txn interface from vector-indexer +// acts as wrapper for dgraph *Txn +type viTxn struct { + delegate *Txn +} + +func NewViTxn(delegate *Txn) *viTxn { + return &viTxn{delegate: delegate} +} + +func (vt *viTxn) Find(prefix []byte, filter func([]byte) bool) (uint64, error) { + return vt.delegate.cache.Find(prefix, filter) +} + +func (vt *viTxn) StartTs() uint64 { + return vt.delegate.StartTs +} + +func (vt *viTxn) Get(key []byte) (rval index.Value, rerr error) { + pl, err := vt.delegate.cache.Get(key) + if err != nil { + return nil, err + } + pl.Lock() + defer pl.Unlock() + return vt.GetValueFromPostingList(pl) +} + +func (vt *viTxn) GetWithLockHeld(key []byte) (rval index.Value, rerr error) { + pl, err := vt.delegate.cache.Get(key) + if err != nil { + return nil, err + } + return vt.GetValueFromPostingList(pl) +} + +func (vt *viTxn) GetValueFromPostingList(pl *List) (rval index.Value, rerr error) { + val, err := pl.ValueWithLockHeld(vt.delegate.StartTs) + rval = val.Value + return rval, err +} + +func (vt *viTxn) AddMutation(ctx context.Context, key 
[]byte, t *index.KeyValue) error { + pl, err := vt.delegate.cache.Get(key) + if err != nil { + return err + } + return pl.addMutation(ctx, vt.delegate, indexEdgeToPbEdge(t)) +} + +func (vt *viTxn) AddMutationWithLockHeld(ctx context.Context, key []byte, t *index.KeyValue) error { + pl, err := vt.delegate.cache.Get(key) + if err != nil { + return err + } + return pl.addMutationInternal(ctx, vt.delegate, indexEdgeToPbEdge(t)) +} + +func (vt *viTxn) LockKey(key []byte) { + pl, _ := vt.delegate.cache.Get(key) + pl.Lock() +} + +func (vt *viTxn) UnlockKey(key []byte) { + pl, _ := vt.delegate.cache.Get(key) + pl.Unlock() +} + // NewTxn returns a new Txn instance. func NewTxn(startTs uint64) *Txn { return &Txn{ diff --git a/query/cloud_test.go b/query/cloud_test.go index 9aea8b04e3e..fae5ff42e17 100644 --- a/query/cloud_test.go +++ b/query/cloud_test.go @@ -41,6 +41,6 @@ func TestMain(m *testing.M) { dc = c client = dg.Dgraph - populateCluster() + populateCluster(dc) m.Run() } diff --git a/query/common_test.go b/query/common_test.go index a20ea9bcf8f..a11063be65f 100644 --- a/query/common_test.go +++ b/query/common_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/require" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) @@ -253,14 +254,6 @@ type SchoolInfo { county } -type User { - name - password - gender - friend - alive -} - type Node { node name @@ -344,10 +337,9 @@ tweet-c : string @index(fulltext) . tweet-d : string @index(trigram) . name2 : string @index(term) . age2 : int @index(int) . -vectorNonIndex : float32vector . ` -func populateCluster() { +func populateCluster(dc dgraphtest.Cluster) { x.Panic(client.Alter(context.Background(), &api.Operation{DropAll: true})) // In the query package, we test using hard coded UIDs so that we know what results @@ -355,10 +347,30 @@ func populateCluster() { // all the UIDs we are using during the tests. 
x.Panic(dc.AssignUids(client.Dgraph, 65536)) - setSchema(testSchema) - err := addTriplesToCluster(` - <1> "[1.0, 1.0, 2.0, 2.0]" . - <2> "[2.0, 1.0, 2.0, 2.0]" . + higher, err := dgraphtest.IsHigherVersion(dc.GetVersion(), "160a0faa5fc6233fdc5a4caa4a7a3d1591f460d0") + x.Panic(err) + var ts string + if higher { + ts = testSchema + `type User { + name + password + gender + friend + alive + user_profile + } + user_profile : float32vector @index(hnsw(metric:"euclidian")) .` + } else { + ts = testSchema + `type User { + name + password + gender + friend + alive + }` + } + setSchema(ts) + err = addTriplesToCluster(` <1> "Michonne" . <2> "King Lear" . <3> "Margaret" . diff --git a/query/integration_test.go b/query/integration_test.go index 901a1d11442..b05b2ae8bed 100644 --- a/query/integration_test.go +++ b/query/integration_test.go @@ -37,6 +37,6 @@ func TestMain(m *testing.M) { x.Panic(client.LoginIntoNamespace(context.Background(), dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) - populateCluster() + populateCluster(dc) m.Run() } diff --git a/query/query.go b/query/query.go index b0efa7f0eed..d11f1cb6900 100644 --- a/query/query.go +++ b/query/query.go @@ -298,6 +298,8 @@ type SubGraph struct { List bool // whether predicate is of list type pathMeta *pathMetadata + + vectorMetrics map[string]uint64 } func (sg *SubGraph) recurse(set func(sg *SubGraph)) { @@ -1177,6 +1179,12 @@ func (sg *SubGraph) transformVars(doneVars map[string]varValue, path []*SubGraph mt.Const = val continue } + // TODO: Need to understand why certain aggregations map to uid = 0 + // while others map to uid = MaxUint64 + if val, ok := newMap[0]; ok && len(newMap) == 1 { + mt.Const = val + continue + } mt.Val = newMap } @@ -1259,8 +1267,10 @@ func (sg *SubGraph) valueVarAggregation(doneVars map[string]varValue, path []*Su } if rangeOver == nil { it := doneVars[sg.Params.Var] + mp[0] = sg.MathExp.Const it.Vals = mp doneVars[sg.Params.Var] = it + sg.Params.UidToVal = mp return 
nil } for _, uid := range rangeOver.Uids { @@ -2170,6 +2180,7 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { sg.counts = result.Counts sg.LangTags = result.LangMatrix sg.List = result.List + sg.vectorMetrics = result.VectorMetrics if sg.Params.DoCount { if len(sg.Filters) == 0 { @@ -2648,7 +2659,7 @@ func isValidArg(a string) bool { func isValidFuncName(f string) bool { switch f { case "anyofterms", "allofterms", "val", "regexp", "anyoftext", "alloftext", - "has", "uid", "uid_in", "anyof", "allof", "type", "match": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": return true } return isInequalityFn(f) || types.IsGeoFunc(f) @@ -2846,6 +2857,14 @@ func (req *Request) ProcessQuery(ctx context.Context) (err error) { continue } + // Just as above, no need to execute "similar_to" query if the + // vector parameter was a Var and evaluated as empty + if sg.SrcFunc != nil && sg.SrcFunc.Name == "similar_to" && + len(sg.SrcFunc.Args) == 1 && len(sg.Params.NeedsVar) > 0 { + errChan <- nil + continue + } + switch { case sg.Params.Alias == "shortest": // We allow only one shortest path block per query. 
@@ -3030,4 +3049,9 @@ func calculateMetrics(sg *SubGraph, metrics map[string]uint64) { for _, child := range sg.Children { calculateMetrics(child, metrics) } + if sg.vectorMetrics != nil { + for key, value := range sg.vectorMetrics { + metrics[key] += value + } + } } diff --git a/query/query1_test.go b/query/query1_test.go index a6c54cbdff5..67ddf3195fe 100644 --- a/query/query1_test.go +++ b/query/query1_test.go @@ -1649,6 +1649,23 @@ func TestAggregateEmpty3(t *testing.T) { require.JSONEq(t, `{"data": {"all":[]}}`, js) } +func TestAggregateEmpty4(t *testing.T) { + query := ` + { + var(func: type(User)) + { + up as user_profile + } + similar(func: similar_to(user_profile, 4, val(up))) + { + uid + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"similar":[]}}`, js) +} + func TestFilterLang(t *testing.T) { // This tests the fix for #1334. While getting uids for filter, we fetch data keys when number // of uids is less than number of tokens. Lang tag was not passed correctly while fetching these diff --git a/query/shortest.go b/query/shortest.go index be099e26e5a..262332e5676 100644 --- a/query/shortest.go +++ b/query/shortest.go @@ -242,7 +242,7 @@ func (sg *SubGraph) expandOut(ctx context.Context, if numEdges > x.Config.LimitQueryEdge { // If we've seen too many edges, stop the query. rch <- errors.Errorf("Exceeded query edge limit = %v. 
Found %v edges.", - x.Config.LimitQueryEdge, numEdges) + x.Config.LimitMutationsNquad, numEdges) return } diff --git a/query/upgrade_test.go b/query/upgrade_test.go index 51c3e7253f5..69dd62e4c44 100644 --- a/query/upgrade_test.go +++ b/query/upgrade_test.go @@ -39,7 +39,7 @@ func TestMain(m *testing.M) { client = dg dc = c - populateCluster() + populateCluster(dc) } query := func(c dgraphtest.Cluster) int { diff --git a/query/vector_test.go b/query/vector_test.go new file mode 100644 index 00000000000..8ce61fdc40d --- /dev/null +++ b/query/vector_test.go @@ -0,0 +1,493 @@ +//go:build integration + +/* + * Copyright 2023 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package query + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "strings" + "testing" + + "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/stretchr/testify/require" +) + +var ( + vectorSchemaWithIndex = `%v: float32vector @index(hnsw(exponent: "%v", metric: "%v")) .` +) + +const ( + vectorSchemaWithoutIndex = `%v: float32vector .` +) + +func updateVector(t *testing.T, triple string, pred string) []float32 { + uid := strings.Split(triple, " ")[0] + randomVec := generateRandomVector(10) + updatedTriple := fmt.Sprintf("%s <%s> \"%v\" .", uid, pred, randomVec) + require.NoError(t, addTriplesToCluster(updatedTriple)) + + updatedVec, err := queryVectorUsingUid(t, uid, pred) + require.NoError(t, err) + require.Equal(t, randomVec, updatedVec) + return updatedVec +} + +func queryVectorUsingUid(t *testing.T, uid, pred string) ([]float32, error) { + vectorQuery := fmt.Sprintf(` + { + vector(func: uid(%v)) { + %v + } + }`, uid, pred) + + resp, err := client.Query(vectorQuery) + require.NoError(t, err) + + type VectorData struct { + VTest []float32 `json:"vtest"` + } + + type Data struct { + Vector []VectorData `json:"vector"` + } + + var data Data + + err = json.Unmarshal([]byte(resp.Json), &data) + if err != nil { + return []float32{}, err + } + + return data.Vector[0].VTest, nil + +} + +func queryMultipleVectorsUsingSimilarTo(t *testing.T, vector []float32, pred string, topK int) ([][]float32, error) { + vectorQuery := fmt.Sprintf(` + { + vector(func: similar_to(%v, %v, "%v")) { + uid + %v + } + }`, pred, topK, vector, pred) + + resp, err := client.Query(vectorQuery) + require.NoError(t, err) + + type VectorData struct { + UID string `json:"uid"` + VTest []float32 `json:"vtest"` + } + + type Data struct { + Vector []VectorData `json:"vector"` + } + + var data Data + + err = json.Unmarshal([]byte(resp.Json), &data) + if err != nil { + return [][]float32{}, err + } + + var vectors [][]float32 + for _, vector := range data.Vector { + vectors = 
append(vectors, vector.VTest) + } + return vectors, nil +} + +func querySingleVectorError(t *testing.T, vector, pred string, validateError bool) ([]float32, error) { + + vectorQuery := fmt.Sprintf(` + { + vector(func: similar_to(%v, 1, "%v")) { + uid + %v + } + }`, pred, vector, pred) + + resp, err := client.Query(vectorQuery) + if validateError { + require.NoError(t, err) + } else if err != nil { + return []float32{}, err + } + + type VectorData struct { + UID string `json:"uid"` + VTest []float32 `json:"vtest"` + } + + type Data struct { + Vector []VectorData `json:"vector"` + } + + var data Data + + err = json.Unmarshal([]byte(resp.Json), &data) + if err != nil { + return []float32{}, err + } + + return data.Vector[0].VTest, nil +} + +func querySingleVector(t *testing.T, vector, pred string) ([]float32, error) { + return querySingleVectorError(t, vector, pred, true) +} + +func queryAllVectorsPred(t *testing.T, pred string) ([][]float32, error) { + vectorQuery := fmt.Sprintf(` + { + vector(func: has(%v)) { + uid + %v + } + }`, pred, pred) + + resp, err := client.Query(vectorQuery) + require.NoError(t, err) + + type VectorData struct { + UID string `json:"uid"` + VTest []float32 `json:"vtest"` + } + + type Data struct { + Vector []VectorData `json:"vector"` + } + + var data Data + + err = json.Unmarshal([]byte(resp.Json), &data) + if err != nil { + return [][]float32{}, err + } + + var vectors [][]float32 + for _, vector := range data.Vector { + vectors = append(vectors, vector.VTest) + } + return vectors, nil +} + +func generateRandomVector(size int) []float32 { + vector := make([]float32, size) + for i := 0; i < size; i++ { + vector[i] = rand.Float32() * 10 + } + return vector +} + +func formatVector(label string, vector []float32, index int) string { + vectorString := fmt.Sprintf(`"[%s]"`, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(vector)), ", "), "[]")) + return fmt.Sprintf("<0x%x> <%s> %s . 
\n", index+10, label, vectorString) +} + +func generateRandomVectors(numVectors, vectorSize int, label string) (string, [][]float32) { + var builder strings.Builder + var vectors [][]float32 + // builder.WriteString("`") + for i := 0; i < numVectors; i++ { + randomVector := generateRandomVector(vectorSize) + vectors = append(vectors, randomVector) + formattedVector := formatVector(label, randomVector, i) + builder.WriteString(formattedVector) + } + + return builder.String(), vectors +} + +func testVectorMutationSameLength(t *testing.T) { + rdf, vectors := generateRandomVectors(10, 5, "vtest") + require.NoError(t, addTriplesToCluster(rdf)) + + allVectors, err := queryAllVectorsPred(t, "vtest") + require.NoError(t, err) + + require.Equal(t, vectors, allVectors) + + triple := strings.Split(rdf, "\n")[1] + vector, err := querySingleVector(t, strings.Split(triple, `"`)[1], "vtest") + require.NoError(t, err) + require.Contains(t, allVectors, vector) + + triple = strings.Split(rdf, "\n")[3] + vector, err = querySingleVector(t, strings.Split(triple, `"`)[1], "vtest") + require.NoError(t, err) + require.Contains(t, allVectors, vector) + + triple = strings.Split(rdf, "\n")[5] + vector, err = querySingleVector(t, strings.Split(triple, `"`)[1], "vtest") + require.NoError(t, err) + require.Contains(t, allVectors, vector) + + triple = strings.Split(rdf, "\n")[7] + vector, err = querySingleVector(t, strings.Split(triple, `"`)[1], "vtest") + require.NoError(t, err) + require.Contains(t, allVectors, vector) + + triple = strings.Split(rdf, "\n")[9] + vector, err = querySingleVector(t, strings.Split(triple, `"`)[1], "vtest") + require.NoError(t, err) + require.Contains(t, allVectors, vector) +} + +func testVectorMutationDiffrentLength(t *testing.T, err string) { + rdf := `<0x1> "[1.5]" . + <0x2> "[1.5, 2.0]" . + <0x3> "[1.5, 2.0, 3.0]" . + <0x4> "[1.5, 2.0, 3.0, 4.5]" . + <0x5> "[1.5, 2.0, 3.0, 4.5, 5.0]" . + <0x6> "[1.5, 2.0, 3.0, 4.5, 5.0, 6.5]" . 
+ <0x7> "[1.5, 2.0, 3.0, 4.5, 5.0, 6.5, 7.0]" . + <0x8> "[1.5, 2.0, 3.0, 4.5, 5.0, 6.5, 7.0, 8.5]" . + <0x9> "[1.5, 2.0, 3.0, 4.5, 5.0, 6.5, 7.0, 8.5, 9.0]" . + <0xA> "[1.5, 2.0, 3.0, 4.5, 5.0, 6.5, 7.0, 8.5, 9.0, 10.5]" .` + + require.ErrorContains(t, addTriplesToCluster(rdf), err) +} + +func TestVectorsMutateFixedLengthWithDiffrentIndexes(t *testing.T) { + dropPredicate("vtest") + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "euclidian")) + testVectorMutationSameLength(t) + dropPredicate("vtest") + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "cosine")) + testVectorMutationSameLength(t) + dropPredicate("vtest") + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dot_product")) + testVectorMutationSameLength(t) + dropPredicate("vtest") +} + +func TestVectorMutateDiffrentLengthWithDiffrentIndexes(t *testing.T) { + dropPredicate("vtest") + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "euclidian")) + testVectorMutationDiffrentLength(t, "can not subtract vectors of different lengths") + dropPredicate("vtest") + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "cosine")) + testVectorMutationDiffrentLength(t, "can not compute dot product on vectors of different lengths") + dropPredicate("vtest") + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dot_product")) + testVectorMutationDiffrentLength(t, "can not subtract vectors of different lengths") + dropPredicate("vtest") +} + +func TestVectorReindex(t *testing.T) { + dropPredicate("vtest") + + pred := "vtest" + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, pred, "4", "euclidian")) + + numVectors := 100 + vectorSize := 4 + + randomVectors, allVectors := generateRandomVectors(numVectors, vectorSize, pred) + require.NoError(t, addTriplesToCluster(randomVectors)) + + setSchema(fmt.Sprintf(vectorSchemaWithoutIndex, pred)) + + query := `{ + vector(func: has(vtest)) { + count(uid) + } + }` + + result := processQueryNoErr(t, query) + 
require.JSONEq(t, `{"data": {"vector":[{"count":100}]}}`, result) + + triple := strings.Split(randomVectors, "\n")[0] + _, err := querySingleVectorError(t, strings.Split(triple, `"`)[1], "vtest", false) + require.NotNil(t, err) + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, pred, "4", "euclidian")) + vector, err := querySingleVector(t, strings.Split(triple, `"`)[1], "vtest") + require.NoError(t, err) + require.Contains(t, allVectors, vector) +} + +func TestVectorMutationWithoutIndex(t *testing.T) { + dropPredicate("vtest") + + pred := "vtest" + setSchema(fmt.Sprintf(vectorSchemaWithoutIndex, pred)) + + numVectors := 1000 + vectorSize := 4 + + randomVectors, _ := generateRandomVectors(numVectors, vectorSize, pred) + require.NoError(t, addTriplesToCluster(randomVectors)) + + query := `{ + vector(func: has(vtest)) { + count(uid) + } + }` + + result := processQueryNoErr(t, query) + require.JSONEq(t, fmt.Sprintf(`{"data": {"vector":[{"count":%d}]}}`, numVectors), result) + + dropPredicate("vtest") + + pred = "vtest2" + setSchema(fmt.Sprintf(vectorSchemaWithoutIndex, pred)) + + randomVectors, _ = generateRandomVectors(numVectors, vectorSize, pred) + require.NoError(t, addTriplesToCluster(randomVectors)) + + query = `{ + vector(func: has(vtest2)) { + count(uid) + } + }` + + result = processQueryNoErr(t, query) + require.JSONEq(t, fmt.Sprintf(`{"data": {"vector":[{"count":%d}]}}`, numVectors), result) + dropPredicate("vtest2") +} + +func TestVectorDelete(t *testing.T) { + pred := "vtest" + dropPredicate(pred) + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, pred, "4", "euclidian")) + + numVectors := 1000 + rdf, vectors := generateRandomVectors(numVectors, 10, "vtest") + require.NoError(t, addTriplesToCluster(rdf)) + + query := `{ + vector(func: has(vtest)) { + count(uid) + } + }` + + result := processQueryNoErr(t, query) + require.JSONEq(t, fmt.Sprintf(`{"data": {"vector":[{"count":%d}]}}`, numVectors), result) + + allVectors, err := queryAllVectorsPred(t, "vtest") + 
require.NoError(t, err) + + require.Equal(t, vectors, allVectors) + + triples := strings.Split(rdf, "\n") + + deleteTriple := func(idx int) string { + triple := triples[idx] + + deleteTriplesInCluster(triple) + uid := strings.Split(triple, " ")[0] + query = fmt.Sprintf(`{ + vector(func: uid(%s)) { + vtest + } + }`, uid[1:len(uid)-1]) + + result = processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"vector":[]}}`, result) + return triple + + } + + for i := 0; i < len(triples)-2; i++ { + triple := deleteTriple(i) + vector, err := querySingleVector(t, strings.Split(triple, `"`)[1], "vtest") + require.NoError(t, err) + require.Contains(t, allVectors, vector) + } + + triple := deleteTriple(len(triples) - 2) + _, err = querySingleVectorError(t, strings.Split(triple, `"`)[1], "vtest", false) + require.NotNil(t, err) +} + +func TestVectorUpdate(t *testing.T) { + pred := "vtest" + dropPredicate(pred) + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, pred, "4", "euclidian")) + + numVectors := 1000 + rdf, vectors := generateRandomVectors(1000, 10, "vtest") + require.NoError(t, addTriplesToCluster(rdf)) + + allVectors, err := queryAllVectorsPred(t, "vtest") + require.NoError(t, err) + + require.Equal(t, vectors, allVectors) + + updateVectorQuery := func(idx int) { + triple := strings.Split(rdf, "\n")[idx] + updatedVec := updateVector(t, triple, "vtest") + allVectors[idx] = updatedVec + + updatedVectors, err := queryMultipleVectorsUsingSimilarTo(t, allVectors[0], "vtest", 100) + require.NoError(t, err) + + for _, i := range updatedVectors { + require.Contains(t, allVectors, i) + } + } + + for i := 0; i < 1000; i++ { + idx := rand.Intn(numVectors) + updateVectorQuery(idx) + } +} + +func TestVectorTwoTxnWithoutCommit(t *testing.T) { + pred := "vtest" + dropPredicate(pred) + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, pred, "4", "euclidian")) + + rdf, vectors := generateRandomVectors(5, 5, "vtest") + txn1 := client.NewTxn() + _, err := txn1.Mutate(context.Background(), 
&api.Mutation{ + SetNquads: []byte(rdf), + }) + require.NoError(t, err) + + rdf, _ = generateRandomVectors(5, 5, "vtest") + txn2 := client.NewTxn() + _, err = txn2.Mutate(context.Background(), &api.Mutation{ + SetNquads: []byte(rdf), + }) + require.NoError(t, err) + + require.NoError(t, txn1.Commit(context.Background())) + require.Error(t, txn2.Commit(context.Background())) + resp, err := queryMultipleVectorsUsingSimilarTo(t, vectors[0], "vtest", 5) + require.NoError(t, err) + + for i := 0; i < len(vectors); i++ { + require.Contains(t, resp, vectors[i]) + } +} diff --git a/tok/hnsw/helper.go b/tok/hnsw/helper.go index 52eac7e0321..03be45906e3 100644 --- a/tok/hnsw/helper.go +++ b/tok/hnsw/helper.go @@ -1,13 +1,20 @@ package hnsw import ( + "context" + "encoding/binary" + "encoding/json" + "log" "math" + "math/rand" "sort" "strconv" "strings" "github.com/chewxy/math32" c "github.com/dgraph-io/dgraph/tok/constraints" + "github.com/dgraph-io/dgraph/tok/index" + "github.com/getsentry/sentry-go" "github.com/pkg/errors" ) @@ -261,3 +268,417 @@ func GetSimType[T c.Float](indexType string, floatBits int) SimilarityType[T] { insortHeap: insortPersistentHeapAscending[T], isBetterScore: isBetterScoreForDistance[T]} } } + +// implements CacheType interface +type TxnCache struct { + txn index.Txn + startTs uint64 +} + +func (tc *TxnCache) Get(key []byte) (rval index.Value, rerr error) { + return tc.txn.Get(key) +} + +func (tc *TxnCache) Ts() uint64 { + return tc.startTs +} + +func (tc *TxnCache) Find(prefix []byte, filter func([]byte) bool) (uint64, error) { + return tc.txn.Find(prefix, filter) +} + +func NewTxnCache(txn index.Txn, startTs uint64) *TxnCache { + return &TxnCache{ + txn: txn, + startTs: startTs, + } +} + +// implements index.CacheType interface +type QueryCache struct { + cache index.LocalCache + readTs uint64 +} + +func (qc *QueryCache) Find(prefix []byte, filter func([]byte) bool) (uint64, error) { + return qc.cache.Find(prefix, filter) +} + +func (qc 
*QueryCache) Get(key []byte) (rval index.Value, rerr error) { + return qc.cache.Get(key) +} + +func (qc *QueryCache) Ts() uint64 { + return qc.readTs +} + +func NewQueryCache(cache index.LocalCache, readTs uint64) *QueryCache { + return &QueryCache{ + cache: cache, + readTs: readTs, + } +} + +// getDataFromKeyWithCacheType(keyString, uid, c) looks up data in c +// associated with keyString and uid. +func getDataFromKeyWithCacheType(keyString string, uid uint64, c index.CacheType) (index.Value, error) { + key := DataKey(keyString, uid) + data, err := c.Get(key) + if err != nil { + return nil, errors.New(err.Error() + plError + keyString + " with uid" + strconv.FormatUint(uid, 10)) + } + return data, nil +} + +// populateEdgeDataFromStore(keyString, uid, c, edgeData) +// will fill edgeData with the contents of the neighboring edges for +// a given DataKey by looking into the given cache (which may result +// in a call to the underlying persistent storage). +// If data is found for the key, this returns true, otherwise, it +// returns false. If the data was found (and there were no errors), +// it populates edgeData with the found contents. +func populateEdgeDataFromKeyWithCacheType( + keyString string, + uid uint64, + c index.CacheType, + edgeData *[][]uint64) (bool, error) { + data, err := getDataFromKeyWithCacheType(keyString, uid, c) + // Note that "dataError" errors are treated as just not having + // found the data -- no harm, no foul, as it is probably a + // dead reference that we can ignore. 
+ if err != nil && !strings.Contains(err.Error(), dataError) { + return false, err + } + if data == nil { + return false, nil + } + err = json.Unmarshal(data.([]byte), &edgeData) + return true, err +} + +// entryUuidInsert adds the entry uuid to the given key +func entryUuidInsert( + ctx context.Context, + key []byte, + txn index.Txn, + predEntryKey string, + entryUuid []byte) (*index.KeyValue, error) { + edge := &index.KeyValue{ + Entity: 1, + Attr: predEntryKey, + Value: entryUuid, + } + err := txn.AddMutationWithLockHeld(ctx, key, edge) + return edge, err +} + +func ConcatStrings(strs ...string) string { + total := "" + for _, s := range strs { + total += s + } + return total +} + +func getInsertLayer(maxLevels int) int { + // multFactor is a multiplicative factor used to normalize the distribution + var level int + randFloat := rand.Float64() + for i := 0; i < maxLevels; i++ { + // calculate level based on section 3.1 here + if randFloat < math.Pow(1.0/float64(5), float64(maxLevels-1-i)) { + level = i + break + } + } + return level +} + +// adds the data corresponding to a uid to the given vec variable in the form of []T +// this does not allocate memory for vec, so it must be allocated before calling this function +func (ph *persistentHNSW[T]) getVecFromUid(uid uint64, c index.CacheType, vec *[]T) error { + data, err := getDataFromKeyWithCacheType(ph.pred, uid, c) + if err != nil { + if strings.Contains(err.Error(), plError) { + // no vector. Return empty array of floats + index.BytesAsFloatArray([]byte{}, vec, ph.floatBits) + return errors.New("Nil vector returned") + } + return err + } + if data != nil { + index.BytesAsFloatArray(data.([]byte), vec, ph.floatBits) + return nil + + } else { + index.BytesAsFloatArray([]byte{}, vec, ph.floatBits) + return errors.New("Nil vector returned") + } +} + +// chooses whether to create the entry and start nodes based on if it already +// exists, and if it hasnt been created yet, it adds the startNode to all +// levels. 
+func (ph *persistentHNSW[T]) createEntryAndStartNodes( + ctx context.Context, + c *TxnCache, + inUuid uint64, + vec *[]T) (uint64, []*index.KeyValue, error) { + txn := c.txn + edges := []*index.KeyValue{} + entryKey := DataKey(ph.vecEntryKey, 1) // 0-profile_vector_entry + txn.LockKey(entryKey) + defer txn.UnlockKey(entryKey) + data, _ := txn.GetWithLockHeld(entryKey) + + create_edges := func(inUuid uint64) (uint64, []*index.KeyValue, error) { + startEdges, err := ph.addStartNodeToAllLevels(ctx, entryKey, txn, inUuid) + if err != nil { + return 0, []*index.KeyValue{}, err + } + // return entry node at all levels + edges = append(edges, startEdges...) + return 0, edges, nil + } + + if data == nil { + // no entries in vector index yet b/c no entry exists, so put in all levels + return create_edges(inUuid) + } + + entry := BytesToUint64(data.([]byte)) // convert entry Uuid returned from Get to uint64 + err := ph.getVecFromUid(entry, c, vec) + if err != nil || len(*vec) == 0 { + // The entry vector has been deleted. We have to create a new entry vector. 
+ entry, err := ph.PickStartNode(ctx, c, vec) + if err != nil { + return 0, []*index.KeyValue{}, err + } + return create_edges(entry) + } + + return entry, edges, nil +} + +// adds empty layers to all levels +func (ph *persistentHNSW[T]) addStartNodeToAllLevels( + ctx context.Context, + entryKey []byte, + txn index.Txn, + inUuid uint64) ([]*index.KeyValue, error) { + edges := []*index.KeyValue{} + key := DataKey(ph.vecKey, inUuid) + emptyEdges := make([][]uint64, ph.maxLevels) + emptyEdgesBytes, err := json.Marshal(emptyEdges) + if err != nil { + return []*index.KeyValue{}, err + } + // creates empty at all levels only for entry node + edge, err := ph.newPersistentEdgeKeyValueEntry(ctx, key, txn, inUuid, emptyEdgesBytes) + if err != nil { + return []*index.KeyValue{}, err + } + edges = append(edges, edge) + inUuidByte := Uint64ToBytes(inUuid) + // add inUuid as entry for this structure from now on + edge, err = entryUuidInsert(ctx, entryKey, txn, ph.vecEntryKey, inUuidByte) + if err != nil { + return []*index.KeyValue{}, err + } + edges = append(edges, edge) + return edges, nil +} + +// creates a new edge with the given uuid and edges. Lock must be held before calling this function +func (ph *persistentHNSW[T]) newPersistentEdgeKeyValueEntry(ctx context.Context, key []byte, + txn index.Txn, uuid uint64, edges []byte) (*index.KeyValue, error) { + txn.LockKey(key) + defer txn.UnlockKey(key) + edge := &index.KeyValue{ + Entity: uuid, + Attr: ph.vecKey, + Value: edges, + } + if err := txn.AddMutationWithLockHeld(ctx, key, edge); err != nil { + return nil, err + } + return edge, nil +} + +// addNeighbors adds the neighbors of the given uuid to the given level. +// It returns the edge created and the error if any. 
+func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache, + uuid uint64, allLayerNeighbors [][]uint64) (*index.KeyValue, error) { + + txn := tc.txn + keyPred := ph.vecKey + key := DataKey(keyPred, uuid) + txn.LockKey(key) + defer txn.UnlockKey(key) + var nnEdgesErr error + var allLayerEdges [][]uint64 + var ok bool + allLayerEdges, ok = ph.nodeAllEdges[uuid] + if !ok { + data, _ := txn.GetWithLockHeld(key) + if data == nil { + allLayerEdges = allLayerNeighbors + } else { + // all edges of nearest neighbor + err := json.Unmarshal(data.([]byte), &allLayerEdges) + if err != nil { + return nil, err + } + } + } + for level := 0; level < ph.maxLevels; level++ { + allLayerEdges[level], nnEdgesErr = ph.removeDeadNodes(allLayerEdges[level], tc) + if nnEdgesErr != nil { + return nil, nnEdgesErr + } + // This adds at most efConstruction number of edges for each layer for this node + allLayerEdges[level] = append(allLayerEdges[level], allLayerNeighbors[level]...) + } + + // on every modification of the layer edges, add it to in mem map so you dont have to always be reading + // from persistent storage + ph.nodeAllEdges[uuid] = allLayerEdges + inboundEdgesBytes, marshalErr := json.Marshal(allLayerEdges) + if marshalErr != nil { + return nil, marshalErr + } + + edge := &index.KeyValue{ + Entity: uuid, + Attr: ph.vecKey, + Value: inboundEdgesBytes, + } + if err := txn.AddMutationWithLockHeld(ctx, key, edge); err != nil { + return nil, err + } + return edge, nil +} + +// removeDeadNodes(nnEdges, tc) removes dead nodes from nnEdges and returns the new nnEdges +func (ph *persistentHNSW[T]) removeDeadNodes(nnEdges []uint64, tc *TxnCache) ([]uint64, error) { + data, err := getDataFromKeyWithCacheType(ph.vecDead, 1, tc) + if err != nil && err.Error() == plError { + return []uint64{}, err + } + var deadNodes []uint64 + if data != nil { // if dead nodes exist, convert to []uint64 + deadNodes, err = ParseEdges(string(data.([]byte))) + if err != nil { + return 
[]uint64{}, err + } + nnEdges = diff(nnEdges, deadNodes) // set nnEdges to be all elements not contained in deadNodes + } + return nnEdges, nil +} + +func Uint64ToBytes(key uint64) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, key) + return b +} + +func BytesToUint64(bytes []byte) uint64 { + return binary.BigEndian.Uint64(bytes) +} + +func isEqual[T c.Float](a []T, b []T) bool { + if len(a) != len(b) { + return false + } + for i, val := range a { + if val != b[i] { + return false + } + } + return true +} + +// DataKey generates a data key with the given attribute and UID. +// The structure of a data key is as follows: +// +// byte 0: key type prefix (set to DefaultPrefix or ByteSplit if part of a multi-part list) +// byte 1-2: length of attr +// next len(attr) bytes: value of attr +// next byte: data type prefix (set to ByteData) +// next eight bytes: value of uid +// next eight bytes (optional): if the key corresponds to a split list, the startUid of +// the split stored in this key and the first byte will be sets to ByteSplit. +func DataKey(attr string, uid uint64) []byte { + extra := 1 + 8 // ByteData + UID + buf, prefixLen := generateKey(DefaultPrefix, attr, extra) + + rest := buf[prefixLen:] + rest[0] = ByteData + + rest = rest[1:] + binary.BigEndian.PutUint64(rest, uid) + return buf +} + +// genKey creates the key and writes the initial bytes (type byte, length of attribute, +// and the attribute itself). It leaves the rest of the key empty for further processing +// if necessary. It also returns next index from where further processing should be done. +func generateKey(typeByte byte, attr string, extra int) ([]byte, int) { + // Separate namespace and attribute from attr and write namespace in the first 8 bytes of key. 
+ namespace, attr := ParseNamespaceBytes(attr) + prefixLen := 1 + 8 + 2 + len(attr) // byteType + ns + len(pred) + pred + buf := make([]byte, prefixLen+extra) + buf[0] = typeByte + AssertTrue(copy(buf[1:], namespace) == 8) + rest := buf[9:] + + writeAttr(rest, attr) + return buf, prefixLen +} + +func ParseNamespaceBytes(attr string) ([]byte, string) { + splits := strings.SplitN(attr, NsSeparator, 2) + ns := make([]byte, 8) + binary.BigEndian.PutUint64(ns, strToUint(splits[0])) + return ns, splits[1] +} + +// AssertTrue asserts that b is true. Otherwise, it would log fatal. +func AssertTrue(b bool) { + if !b { + log.Fatalf("%+v", errors.Errorf("Assert failed")) + } +} + +func writeAttr(buf []byte, attr string) []byte { + AssertTrue(len(attr) < math.MaxUint16) + binary.BigEndian.PutUint16(buf[:2], uint16(len(attr))) + + rest := buf[2:] + AssertTrue(len(attr) == copy(rest, attr)) + + return rest[len(attr):] +} + +// For consistency, use base16 to encode/decode the namespace. +func strToUint(s string) uint64 { + ns, err := strconv.ParseUint(s, 16, 64) + Check(err) + return ns +} + +// Check logs fatal if err != nil. +func Check(err error) { + if err != nil { + err = errors.Wrap(err, "") + CaptureSentryException(err) + log.Fatalf("%+v", err) + } +} + +// CaptureSentryException sends the error report to Sentry. 
+func CaptureSentryException(err error) { + if err != nil { + sentry.CaptureException(err) + } +} diff --git a/tok/hnsw/persistent_factory.go b/tok/hnsw/persistent_factory.go index b1ccb724e06..8de9467f8f4 100644 --- a/tok/hnsw/persistent_factory.go +++ b/tok/hnsw/persistent_factory.go @@ -93,8 +93,20 @@ func (hf *persistentIndexFactory[T]) createWithLock( err := errors.New("index with name " + name + " already exists") return nil, err } - // Not implemented yet - return nil, nil + retVal := &persistentHNSW[T]{ + pred: name, + vecEntryKey: ConcatStrings(name, VecEntry), + vecKey: ConcatStrings(name, VecKeyword), + vecDead: ConcatStrings(name, VecDead), + floatBits: floatBits, + nodeAllEdges: map[uint64][][]uint64{}, + } + err := retVal.applyOptions(o) + if err != nil { + return nil, err + } + hf.indexMap[name] = retVal + return retVal, nil } // Find is an implementation of the IndexFactory interface function, invoked by an persistentIndexFactory diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index f931ae638b2..2a3fec0f09e 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -1,10 +1,15 @@ package hnsw import ( + "context" "fmt" + "strings" + "time" c "github.com/dgraph-io/dgraph/tok/constraints" + "github.com/dgraph-io/dgraph/tok/index" opt "github.com/dgraph-io/dgraph/tok/options" + "github.com/pkg/errors" ) type persistentHNSW[T c.Float] struct { @@ -17,6 +22,10 @@ type persistentHNSW[T c.Float] struct { vecDead string simType SimilarityType[T] floatBits int + // nodeAllEdges[65443][1][3] indicates the 3rd neighbor in the first + // layer for uuid 65443. The result will be a neighboring uuid. 
+ nodeAllEdges map[uint64][][]uint64 + visitedUids []uint64 } func (ph *persistentHNSW[T]) applyOptions(o opt.Options) error { @@ -63,3 +72,377 @@ func (ph *persistentHNSW[T]) applyOptions(o opt.Options) error { } return nil } + +func (ph *persistentHNSW[T]) emptyFinalResultWithError(e error) ( + *index.SearchPathResult, error) { + return index.NewSearchPathResult(), e +} + +func (ph *persistentHNSW[T]) emptySearchResultWithError(e error) (*searchLayerResult[T], error) { + return newLayerResult[T](0), e +} + +// fillNeighborEdges(uuid, c, edges) will "fill" edges with the neighbors for +// all levels associated with given uuid and CacheType. +// It returns true when we were able to find the node (either in cache or +// in persistent store) and false otherwise. +// (Of course, it may also return an error if a problem was encountered). +func (ph *persistentHNSW[T]) fillNeighborEdges(uuid uint64, c index.CacheType, edges *[][]uint64) (bool, error) { + var ok bool + *edges, ok = ph.nodeAllEdges[uuid] + if ok { + return true, nil + } + + ok, err := populateEdgeDataFromKeyWithCacheType(ph.vecKey, uuid, c, edges) + if err != nil { + return false, err + } + if !ok { + return false, nil + } + + // add this to in mem storage of uid -> edges + ph.nodeAllEdges[uuid] = *edges + return true, nil +} + +// searchPersistentLayer searches a layer of the hnsw graph for the nearest +// neighbors of the query vector and returns the traversal path and the nearest +// neighbors +func (ph *persistentHNSW[T]) searchPersistentLayer( + c index.CacheType, + level int, + entry uint64, + startVec, query []T, + entryIsFilteredOut bool, + expectedNeighbors int, + filter index.SearchFilter[T]) (*searchLayerResult[T], error) { + r := newLayerResult[T](level) + + bestDist, err := ph.simType.distanceScore(startVec, query, ph.floatBits) + r.markFirstDistanceComputation() + if err != nil { + return ph.emptySearchResultWithError(err) + } + best := minPersistentHeapElement[T]{ + value: bestDist, + index: 
entry, + filteredOut: entryIsFilteredOut, + } + r.setFirstPathNode(best) + //create set using map to append to on future visited nodes + ph.visitedUids = append(ph.visitedUids, best.index) + candidateHeap := *buildPersistentHeapByInit([]minPersistentHeapElement[T]{best}) + for candidateHeap.Len() != 0 { + currCandidate := candidateHeap.Pop().(minPersistentHeapElement[T]) + if r.numNeighbors() < expectedNeighbors && + ph.simType.isBetterScore(r.lastNeighborScore(), currCandidate.value) { + // If the "worst score" in our neighbors list is deemed to have + // a better score than the current candidate -- and if we have at + // least our expected number of nearest results -- we discontinue + // the search. + // Note that while this is faithful to the published + // HNSW algorithms insofar as we stop when we reach a local + // minimum, it leaves something to be desired in terms of + // guarantees of getting best results. + break + } + var allLayerEdges [][]uint64 + + found, err := ph.fillNeighborEdges(currCandidate.index, c, &allLayerEdges) + if err != nil { + return ph.emptySearchResultWithError(err) + } + if !found { + continue + } + currLayerEdges := allLayerEdges[level] + currLayerEdges = diff(currLayerEdges, ph.visitedUids) + var eVec []T + for i := range currLayerEdges { + // iterate over candidate's neighbors distances to get + // best ones + _ = ph.getVecFromUid(currLayerEdges[i], c, &eVec) + // intentionally ignoring error -- we catch it + // indirectly via eVec == nil check. 
+ if len(eVec) == 0 { + continue + } + currDist, err := ph.simType.distanceScore(eVec, query, ph.floatBits) + ph.visitedUids = append(ph.visitedUids, currLayerEdges[i]) + r.incrementDistanceComputations() + if err != nil { + return ph.emptySearchResultWithError(err) + } + filteredOut := !filter(query, eVec, currLayerEdges[i]) + currElement := initPersistentHeapElement( + currDist, currLayerEdges[i], filteredOut) + nodeVisited := r.nodeVisited(*currElement) + if !nodeVisited { + r.addToVisited(*currElement) + + // If we have not yet found k candidates, we can consider + // any candidate. Otherwise, only consider those that + // are better than our current k nearest neighbors. + // Note that the "numNeighbors" function is a bit tricky: + // If we previously added to the heap M elements that should + // be filtered out, we ignore M elements in the numNeighbors + // check! In this way, we can make sure to allow in up to + // expectedNeighbors "unfiltered" elements. + if ph.simType.isBetterScore(currDist, r.lastNeighborScore()) || + r.numNeighbors() < expectedNeighbors { + candidateHeap.Push(*currElement) + r.addPathNode(*currElement, ph.simType, expectedNeighbors) + } + } + } + } + return r, nil +} + +// Search searches the hnsw graph for the nearest neighbors of the query vector +// and returns the traversal path and the nearest neighbors +func (ph *persistentHNSW[T]) Search(ctx context.Context, c index.CacheType, query []T, + maxResults int, filter index.SearchFilter[T]) (nnUids []uint64, err error) { + r, err := ph.SearchWithPath(ctx, c, query, maxResults, filter) + return r.Neighbors, err +} + +// Search searches the hnsw graph for the nearest neighbors of the query uid +// and returns the traversal path and the nearest neighbors +func (ph *persistentHNSW[T]) SearchWithUid(ctx context.Context, c index.CacheType, queryUid uint64, + maxResults int, filter index.SearchFilter[T]) (nnUids []uint64, err error) { + var queryVec []T + err = ph.getVecFromUid(queryUid, c, 
&queryVec) + if err != nil { + if strings.Contains(err.Error(), plError) { + // No vector. return empty result + return []uint64{}, nil + } + return []uint64{}, err + } + + if len(queryVec) == 0 { + // No vector. return empty result + return []uint64{}, nil + } + + shouldFilterOutQueryVec := !filter(queryVec, queryVec, queryUid) + + // how normal search works is by cotinuously searching higher layers + // for the best entry node to the last layer since we already know the + // best entry node (since it already exists in the lowest level), we + // can just search the last layer and return the results. + r, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, queryUid, queryVec, queryVec, + shouldFilterOutQueryVec, maxResults, filter) + for _, n := range r.neighbors { + nnUids = append(nnUids, n.index) + } + return nnUids, err +} + +// There will be times when the entry node has been deleted. In that case, we want to make a new node +// the first vector. +func (ph *persistentHNSW[T]) calculateNewEntryVec( + ctx context.Context, + c index.CacheType, + startVec *[]T) (uint64, error) { + + itr, err := c.Find([]byte(ph.pred), func(value []byte) bool { + index.BytesAsFloatArray(value, startVec, ph.floatBits) + return len(*startVec) != 0 + }) + + if err != nil { + return 0, errors.Wrapf(err, "HNSW tree has no elements") + } + if itr == 0 { + return itr, errors.New("HNSW tree has no elements") + } + + return itr, nil +} + +func (ph *persistentHNSW[T]) PickStartNode( + ctx context.Context, + c index.CacheType, + startVec *[]T) (uint64, error) { + + data, err := getDataFromKeyWithCacheType(ph.vecEntryKey, 1, c) + if err != nil { + if strings.Contains(err.Error(), plError) { + // The index might be empty + return ph.calculateNewEntryVec(ctx, c, startVec) + } + return 0, err + } + + entry := BytesToUint64(data.([]byte)) + err = ph.getVecFromUid(entry, c, startVec) + if err != nil { + fmt.Println(err) + } + + if len(*startVec) == 0 { + return ph.calculateNewEntryVec(ctx, c, 
startVec) + } + return entry, err +} + +// SearchWithPath allows persistentHNSW to implement index.OptionalIndexSupport. +// See index.OptionalIndexSupport.SearchWithPath for more info. +func (ph *persistentHNSW[T]) SearchWithPath( + ctx context.Context, + c index.CacheType, + query []T, + maxResults int, + filter index.SearchFilter[T]) (r *index.SearchPathResult, err error) { + start := time.Now().UnixMilli() + r = index.NewSearchPathResult() + + // 0-profile_vector_entry + var startVec []T + entry, err := ph.PickStartNode(ctx, c, &startVec) + if err != nil { + return ph.emptyFinalResultWithError(err) + } + + // Calculates best entry for last level (maxLevels-1) by searching each + // layer and using new best entry. + for level := 0; level < ph.maxLevels-1; level++ { + if isEqual(startVec, query) { + break + } + filterOut := !filter(query, startVec, entry) + layerResult, err := ph.searchPersistentLayer( + c, level, entry, startVec, query, filterOut, ph.efSearch, filter) + if err != nil { + return ph.emptyFinalResultWithError(err) + } + layerResult.updateFinalMetrics(r) + entry = layerResult.bestNeighbor().index + layerResult.updateFinalPath(r) + err = ph.getVecFromUid(entry, c, &startVec) + if err != nil { + return ph.emptyFinalResultWithError(err) + } + } + filterOut := !filter(query, startVec, entry) + layerResult, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, entry, startVec, query, filterOut, maxResults, filter) + if err != nil { + return ph.emptyFinalResultWithError(err) + } + layerResult.updateFinalMetrics(r) + layerResult.updateFinalPath(r) + layerResult.addFinalNeighbors(r) + t := time.Now().UnixMilli() + elapsed := t - start + r.Metrics[searchTime] = uint64(elapsed) + return r, nil +} + +// InsertToPersistentStorage inserts a node into the hnsw graph and returns the +// traversal path and the edges created +func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, + inUuid uint64, inVec []T) ([]*index.KeyValue, error) { + tc, ok 
:= c.(*TxnCache) + if !ok { + return []*index.KeyValue{}, nil + } + _, edges, err := ph.insertHelper(ctx, tc, inUuid, inVec) + return edges, err +} + +// InsertToPersistentStorage inserts a node into the hnsw graph and returns the +// traversal path and the edges created +func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, + inUuid uint64, inVec []T) ([]minPersistentHeapElement[T], []*index.KeyValue, error) { + + // return all the new edges created at all HNSW levels + var startVec []T + entry, edges, err := ph.createEntryAndStartNodes(ctx, tc, inUuid, &startVec) + if err != nil || len(edges) > 0 { + return []minPersistentHeapElement[T]{}, edges, err + } + + if entry == inUuid { + // something interesting is you physically cannot add duplicate nodes, + // it'll just overwrite w the same info + // only situation where you can add duplicate nodes is if your + // mutation adds the same node as entry + return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, nil + } + + // startVecs: vectors used to calc where to start up until inLevel, + // nns: nearest neighbors to return, + // visited: all visited nodes + // var nns []minPersistentHeapElement[T] + visited := []minPersistentHeapElement[T]{} + inLevel := getInsertLayer(ph.maxLevels) // calculate layer to insert node at (randomized every time) + var layerErr error + + for level := 0; level < inLevel; level++ { + // perform insertion for layers [level, max_level) only, when level < inLevel just find better start + err := ph.getVecFromUid(entry, tc, &startVec) + if err != nil { + return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err + } + layerResult, err := ph.searchPersistentLayer(tc, level, entry, startVec, + inVec, false, 1, index.AcceptAll[T]) + if err != nil { + return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err + } + entry = layerResult.bestNeighbor().index + } + + emptyEdges := make([][]uint64, ph.maxLevels) + _, err = ph.addNeighbors(ctx, tc, inUuid, 
emptyEdges) + if err != nil { + return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err + } + + var outboundEdgesAllLayers = make([][]uint64, ph.maxLevels) + var inboundEdgesAllLayersMap = make(map[uint64][][]uint64) + nnUidArray := []uint64{} + for level := inLevel; level < ph.maxLevels; level++ { + err := ph.getVecFromUid(entry, tc, &startVec) + if err != nil { + return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err + } + layerResult, err := ph.searchPersistentLayer(tc, level, entry, startVec, + inVec, false, ph.efConstruction, index.AcceptAll[T]) + if err != nil { + return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, layerErr + } + + nns := layerResult.neighbors + for i := 0; i < len(nns); i++ { + nnUidArray = append(nnUidArray, nns[i].index) + inboundEdgesAllLayersMap[nns[i].index] = make([][]uint64, ph.maxLevels) + inboundEdgesAllLayersMap[nns[i].index][level] = + append(inboundEdgesAllLayersMap[nns[i].index][level], inUuid) + // add nn to outboundEdges. + // These should already be correctly ordered. 
+ outboundEdgesAllLayers[level] = + append(outboundEdgesAllLayers[level], nns[i].index) + } + } + edge, err := ph.addNeighbors(ctx, tc, inUuid, outboundEdgesAllLayers) + for i := 0; i < len(nnUidArray); i++ { + edge, err := ph.addNeighbors( + ctx, tc, nnUidArray[i], inboundEdgesAllLayersMap[nnUidArray[i]]) + if err != nil { + return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err + } + edges = append(edges, edge) + } + if err != nil { + return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err + } + edges = append(edges, edge) + + return visited, edges, nil +} diff --git a/tok/hnsw/persistent_hnsw_test.go b/tok/hnsw/persistent_hnsw_test.go new file mode 100644 index 00000000000..a4c813e7a5f --- /dev/null +++ b/tok/hnsw/persistent_hnsw_test.go @@ -0,0 +1,614 @@ +package hnsw + +import ( + "context" + "fmt" + "sync" + "testing" + + c "github.com/dgraph-io/dgraph/tok/constraints" + "github.com/dgraph-io/dgraph/tok/index" + opt "github.com/dgraph-io/dgraph/tok/options" + "golang.org/x/exp/slices" +) + +type createpersistentHNSWTest[T c.Float] struct { + maxLevels int + efSearch int + efConstruction int + pred string + indexType string + expectedIndexType string + floatBits int +} + +var createpersistentHNSWTests = []createpersistentHNSWTest[float64]{ + { + maxLevels: 1, + efSearch: 1, + efConstruction: 1, + pred: "a", + indexType: "b", + expectedIndexType: Euclidian, + floatBits: 64, + }, + { + maxLevels: 1, + efSearch: 1, + efConstruction: 1, + pred: "a", + indexType: Euclidian, + expectedIndexType: Euclidian, + floatBits: 64, + }, + { + maxLevels: 1, + efSearch: 1, + efConstruction: 1, + pred: "a", + indexType: Cosine, + expectedIndexType: Cosine, + floatBits: 64, + }, + { + maxLevels: 1, + efSearch: 1, + efConstruction: 1, + pred: "a", + indexType: DotProd, + expectedIndexType: DotProd, + floatBits: 64, + }, +} + +func optionsFromCreateTestCase[T c.Float](tc createpersistentHNSWTest[T]) opt.Options { + retVal := opt.NewOptions() + 
retVal.SetOpt(MaxLevelsOpt, tc.maxLevels) + retVal.SetOpt(EfSearchOpt, tc.efSearch) + retVal.SetOpt(EfConstructionOpt, tc.efConstruction) + retVal.SetOpt(MetricOpt, GetSimType[T](tc.indexType, tc.floatBits)) + return retVal +} + +func TestRaceCreateOrReplace(t *testing.T) { + f := CreateFactory[float64](64) + test := createpersistentHNSWTests[0] + opts := optionsFromCreateTestCase(test) + + var wg sync.WaitGroup + run := func() { + for i := 0; i < 10; i++ { + vIndex, err := f.CreateOrReplace(test.pred, opts, nil, 64) + if err != nil { + t.Errorf("Error creating index: %s for test case %d (%+v)", + err, i, test) + } + if vIndex == nil { + t.Errorf("TestCreatepersistentHNSW test case %d (%+v) generated nil index", + i, test) + } + } + wg.Done() + } + + for i := 0; i < 5; i++ { + wg.Add(1) + go run() + } + + wg.Wait() +} + +func TestCreatepersistentHNSW(t *testing.T) { + f := CreateFactory[float64](64) + for i, test := range createpersistentHNSWTests { + opts := optionsFromCreateTestCase(test) + vIndex, err := f.CreateOrReplace(test.pred, opts, nil, 64) + if err != nil { + t.Errorf("Error creating index: %s for test case %d (%+v)", + err, i, test) + return + } + if vIndex == nil { + t.Errorf("TestCreatepersistentHNSW test case %d (%+v) generated nil index", + i, test) + return + } + flatPh := vIndex.(*persistentHNSW[float64]) + if flatPh.simType.indexType != test.expectedIndexType { + t.Errorf("output %q not equal to expected %q", flatPh.simType.indexType, test.expectedIndexType) + return + } + } +} + +type flatInMemListAddMutationTest struct { + key string + startTs uint64 + finishTs uint64 + t *index.KeyValue + expectedErr error +} + +var flatInMemListAddMutationTests = []flatInMemListAddMutationTest{ + {key: "a", startTs: 0, finishTs: 5, t: &index.KeyValue{Value: []byte("abc")}, expectedErr: nil}, + {key: "b", startTs: 1, finishTs: 2, t: &index.KeyValue{Value: []byte("123")}, expectedErr: nil}, + {key: "c", startTs: 0, finishTs: 99, t: &index.KeyValue{Value: 
[]byte("xyz")}, expectedErr: nil}, +} + +// TODO: It seriously seems wrong to have a transactional concept so tightly coupled with +// +// Dgraph product here! We should expect that we are using this module for completely +// independent use, possibly having nothing to do with Dgraph. +func flatInMemListWriteMutation(test flatInMemListAddMutationTest, t *testing.T) { + l := newInMemList(test.key, test.startTs, test.finishTs) + err := l.AddMutation(context.TODO(), nil, test.t) + if err != nil { + if err.Error() != test.expectedErr.Error() { + t.Errorf("Output %q not equal to expected %q", err.Error(), test.expectedErr.Error()) + } + } else { + if err != test.expectedErr { + t.Errorf("Output %q not equal to expected %q", err, test.expectedErr) + } + } + // should not modify db [test.startTs, test.finishTs) + if tsDbs[test.finishTs-1].inMemTestDb[test.key] != tsDbs[test.startTs].inMemTestDb[test.key] { + t.Errorf( + "Database at time %q not equal to expected database at time %q. Expected: %q, Got: %q", + test.finishTs-1, test.startTs, + tsDbs[test.startTs].inMemTestDb[test.key], + tsDbs[test.finishTs-1].inMemTestDb[test.key]) + } + if string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]) != string(test.t.Value[:]) { + t.Errorf("The database at time %q for key %q gave value of %q instead of %q", test.finishTs, + test.key, string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]), string(test.t.Value[:])) + } + if string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]) != + string(tsDbs[99].inMemTestDb[test.key].([]byte)[:]) { + t.Errorf("The database at time %q for key %q gave value of %q instead of %q", test.finishTs, + test.key, string(tsDbs[99].inMemTestDb[test.key].([]byte)[:]), + string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:])) + } +} + +func TestFlatInMemListAddMutation(t *testing.T) { + emptyTsDbs() + for _, test := range flatInMemListAddMutationTests { + flatInMemListWriteMutation(test, t) + } +} + +var 
flatInMemListAddMutationOverwriteTests = []flatInMemListAddMutationTest{ + {key: "a", startTs: 0, finishTs: 5, t: &index.KeyValue{Value: []byte("abc")}, expectedErr: nil}, + {key: "a", startTs: 0, finishTs: 5, t: &index.KeyValue{Value: []byte("123")}, expectedErr: nil}, + {key: "a", startTs: 0, finishTs: 5, t: &index.KeyValue{Value: []byte("xyz")}, expectedErr: nil}, +} + +func TestFlatInMemListAddOverwriteMutation(t *testing.T) { + emptyTsDbs() + for _, test := range flatInMemListAddMutationOverwriteTests { + flatInMemListWriteMutation(test, t) + } +} + +type flatInMemListAddMutationTestBranchDependent struct { + key string + startTs uint64 + finishTs uint64 + t *index.KeyValue + expectedErr error + currIteration int +} + +var flatInMemListAddMultipleWritesMutationTests = []flatInMemListAddMutationTestBranchDependent{ + {key: "a", startTs: 0, finishTs: 2, t: &index.KeyValue{Value: []byte("abc")}, expectedErr: nil, currIteration: 0}, + {key: "a", startTs: 1, finishTs: 3, t: &index.KeyValue{Value: []byte("123")}, expectedErr: nil, currIteration: 1}, + {key: "a", startTs: 2, finishTs: 4, t: &index.KeyValue{Value: []byte("xyz")}, expectedErr: nil, currIteration: 2}, +} + +func TestFlatInMemListAddMultipleWritesMutation(t *testing.T) { + emptyTsDbs() + for _, test := range flatInMemListAddMultipleWritesMutationTests { + l := newInMemList(test.key, test.startTs, test.finishTs) + err := l.AddMutation(context.TODO(), nil, test.t) + if err != nil { + if err.Error() != test.expectedErr.Error() { + t.Errorf("Output %q not equal to expected %q", err.Error(), test.expectedErr.Error()) + } + } else { + if err != test.expectedErr { + t.Errorf("Output %q not equal to expected %q", err, test.expectedErr) + } + } + if test.currIteration == 0 { + conv := flatInMemListAddMutationTest{test.key, test.startTs, test.finishTs, test.t, test.expectedErr} + flatInMemListWriteMutation(conv, t) + } else { + if string(tsDbs[test.finishTs-1].inMemTestDb[test.key].([]byte)[:]) != + 
string(flatInMemListAddMultipleWritesMutationTests[test.currIteration-1].t.Value[:]) { + t.Errorf("The database at time %q for key %q gave value of %q instead of %q", test.finishTs, + test.key, string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]), string(test.t.Value[:])) + } + if string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]) != string(test.t.Value[:]) { + t.Errorf("The database at time %q for key %q gave value of %q instead of %q", test.finishTs, + test.key, string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]), string(test.t.Value[:])) + } + } + } +} + +type insertToPersistentFlatStorageTest struct { + tc *TxnCache + inUuid uint64 + inVec []float64 + expectedErr error + expectedEdgesList []string + minExpectedEdge string +} + +var flatPhs = []*persistentHNSW[float64]{ + { + maxLevels: 5, + efConstruction: 16, + efSearch: 12, + pred: "0-a", + vecEntryKey: ConcatStrings("0-a", VecEntry), + vecKey: ConcatStrings("0-a", VecKeyword), + vecDead: ConcatStrings("0-a", VecDead), + floatBits: 64, + simType: GetSimType[float64](Euclidian, 64), + nodeAllEdges: make(map[uint64][][]uint64), + }, + { + maxLevels: 5, + efConstruction: 16, + efSearch: 12, + pred: "0-a", + vecEntryKey: ConcatStrings("0-a", VecEntry), + vecKey: ConcatStrings("0-a", VecKeyword), + vecDead: ConcatStrings("0-a", VecDead), + floatBits: 64, + simType: GetSimType[float64](Cosine, 64), + nodeAllEdges: make(map[uint64][][]uint64), + }, + { + maxLevels: 5, + efConstruction: 16, + efSearch: 12, + pred: "0-a", + vecEntryKey: ConcatStrings("0-a", VecEntry), + vecKey: ConcatStrings("0-a", VecKeyword), + vecDead: ConcatStrings("0-a", VecDead), + floatBits: 64, + simType: GetSimType[float64](DotProd, 64), + nodeAllEdges: make(map[uint64][][]uint64), + }, +} + +var flatPh = &persistentHNSW[float64]{ + maxLevels: 5, + efConstruction: 16, + efSearch: 12, + pred: "0-a", + vecEntryKey: ConcatStrings("0-a", VecEntry), + vecKey: ConcatStrings("0-a", VecKeyword), + vecDead: 
ConcatStrings("0-a", VecDead), + floatBits: 64, + simType: GetSimType[float64](Euclidian, 64), + nodeAllEdges: make(map[uint64][][]uint64), +} + +var flatEntryInsertToPersistentFlatStorageTests = []insertToPersistentFlatStorageTest{ + { + tc: NewTxnCache(&inMemTxn{startTs: 12, commitTs: 40}, 12), + inUuid: uint64(123), + inVec: []float64{0.824, 0.319, 0.111}, + expectedErr: nil, + expectedEdgesList: []string{"0-a__vector__123", "0-a__vector_entry_1"}, + minExpectedEdge: "", + }, + { + tc: NewTxnCache(&inMemTxn{startTs: 11, commitTs: 37}, 11), + inUuid: uint64(1), + inVec: []float64{0.3, 0.5, 0.7}, + expectedErr: nil, + expectedEdgesList: []string{"0-a__vector__1", "0-a__vector_entry_1"}, + minExpectedEdge: "", + }, + { + tc: NewTxnCache(&inMemTxn{startTs: 0, commitTs: 1}, 0), + inUuid: uint64(5), + inVec: []float64{0.1, 0.1, 0.1}, + expectedErr: nil, + expectedEdgesList: []string{"0-a__vector__5", "0-a__vector_entry_1"}, + minExpectedEdge: "", + }, +} + +func TestFlatEntryInsertToPersistentFlatStorage(t *testing.T) { + emptyTsDbs() + for _, test := range flatEntryInsertToPersistentFlatStorageTests { + emptyTsDbs() + key := DataKey(flatPh.pred, test.inUuid) + for i := range tsDbs { + tsDbs[i].inMemTestDb[string(key[:])] = floatArrayAsBytes(test.inVec) + } + edges, err := flatPh.Insert(context.TODO(), test.tc, test.inUuid, test.inVec) + if err != nil { + if err.Error() != test.expectedErr.Error() { + t.Errorf("Output %q not equal to expected %q", err.Error(), test.expectedErr.Error()) + } + } else { + if err != test.expectedErr { + t.Errorf("Output %q not equal to expected %q", err, test.expectedErr) + } + } + var float1, float2 = []float64{}, []float64{} + index.BytesAsFloatArray(tsDbs[0].inMemTestDb[string(key[:])].([]byte), &float1, 64) + index.BytesAsFloatArray(tsDbs[99].inMemTestDb[string(key[:])].([]byte), &float2, 64) + if !equalFloat64Slice(float1, float2) { + t.Errorf("Vector value for predicate %q at beginning and end of database were "+ + "not equivalent. 
Start Value: %v, End Value: %v", flatPh.pred, tsDbs[0].inMemTestDb[flatPh.pred].([]float64), + tsDbs[99].inMemTestDb[flatPh.pred].([]float64)) + } + edgesNameList := []string{} + for _, edge := range edges { + edgeName := edge.Attr + "_" + fmt.Sprint(edge.Entity) + edgesNameList = append(edgesNameList, edgeName) + } + if !equalStringSlice(edgesNameList, test.expectedEdgesList) { + t.Errorf("Edges created during insert is incorrect. Expected: %v, Got: %v", test.expectedEdgesList, edgesNameList) + } + entryKey := DataKey(ConcatStrings(flatPh.pred, VecEntry), 1) + entryVal := BytesToUint64(tsDbs[99].inMemTestDb[string(entryKey[:])].([]byte)) + if entryVal != test.inUuid { + t.Errorf("entry value stored is incorrect. Expected: %q, Got: %q", test.inUuid, entryVal) + } + } +} + +var flatEntryInsert = insertToPersistentFlatStorageTest{ + tc: NewTxnCache(&inMemTxn{startTs: 0, commitTs: 1}, 0), + inUuid: uint64(5), + inVec: []float64{0.1, 0.1, 0.1}, + expectedErr: nil, + expectedEdgesList: []string{ + "0-a__vector__5", + "0-a__vector__5", + "0-a__vector__5", + "0-a__vector__5", + "0-a__vector__5", + "0-a__vector_entry_1", + }, + minExpectedEdge: "", +} + +var nonflatEntryInsertToPersistentFlatStorageTests = []insertToPersistentFlatStorageTest{ + { + tc: NewTxnCache(&inMemTxn{startTs: 12, commitTs: 40}, 12), + inUuid: uint64(123), + inVec: []float64{0.824, 0.319, 0.111}, + expectedErr: nil, + expectedEdgesList: []string{}, + minExpectedEdge: "0-a__vector__123", + }, + { + tc: NewTxnCache(&inMemTxn{startTs: 11, commitTs: 37}, 11), + inUuid: uint64(1), + inVec: []float64{0.3, 0.5, 0.7}, + expectedErr: nil, + expectedEdgesList: []string{}, + minExpectedEdge: "0-a__vector__1", + }, +} + +func TestNonflatEntryInsertToPersistentFlatStorage(t *testing.T) { + emptyTsDbs() + key := DataKey(flatPh.pred, flatEntryInsert.inUuid) + for i := range tsDbs { + tsDbs[i].inMemTestDb[string(key[:])] = floatArrayAsBytes(flatEntryInsert.inVec) + } + _, err := flatPh.Insert(context.TODO(), + 
flatEntryInsert.tc, + flatEntryInsert.inUuid, + flatEntryInsert.inVec) + if err != nil { + t.Errorf("Encountered error on initial insert: %s", err) + return + } + // testKey := DataKey(BuildDataKeyPred(flatPh.pred, VecKeyword, fmt.Sprint(0)), flatEntryInsert.inUuid) + // fmt.Print(tsDbs[1].inMemTestDb[string(testKey[:])]) + for _, test := range nonflatEntryInsertToPersistentFlatStorageTests { + entryKey := DataKey(ConcatStrings(flatPh.pred, VecEntry), 1) + entryVal := BytesToUint64(tsDbs[99].inMemTestDb[string(entryKey[:])].([]byte)) + if entryVal != 5 { + t.Errorf("entry value stored is incorrect. Expected: %q, Got: %q", 5, entryVal) + } + for i := range tsDbs { + key := DataKey(flatPh.pred, test.inUuid) + tsDbs[i].inMemTestDb[string(key[:])] = floatArrayAsBytes(test.inVec) + } + edges, err := flatPh.Insert(context.TODO(), test.tc, test.inUuid, test.inVec) + if err != nil && test.expectedErr != nil { + if err.Error() != test.expectedErr.Error() { + t.Errorf("Output %q not equal to expected %q", err.Error(), test.expectedErr.Error()) + } + } else { + if err != test.expectedErr { + t.Errorf("Output %q not equal to expected %q", err, test.expectedErr) + } + } + var float1, float2 = []float64{}, []float64{} + index.BytesAsFloatArray(tsDbs[0].inMemTestDb[string(key[:])].([]byte), &float1, 64) + index.BytesAsFloatArray(tsDbs[99].inMemTestDb[string(key[:])].([]byte), &float2, 64) + if !equalFloat64Slice(float1, float2) { + t.Errorf("Vector value for predicate %q at beginning and end of database were "+ + "not equivalent. 
Start Value: %v, End Value: %v", flatPh.pred, tsDbs[0].inMemTestDb[flatPh.pred].([]float64), + tsDbs[99].inMemTestDb[flatPh.pred].([]float64)) + } + edgesNameList := []string{} + for _, edge := range edges { + edgeName := edge.Attr + "_" + fmt.Sprint(edge.Entity) + edgesNameList = append(edgesNameList, edgeName) + } + if !slices.Contains(edgesNameList, test.minExpectedEdge) { + t.Errorf("Expected at least %q in list of edges %v", test.minExpectedEdge, edgesNameList) + } + } +} + +type searchPersistentFlatStorageTest struct { + qc *QueryCache + query []float64 + maxResults int + expectedErr error + expectedNns []uint64 +} + +var searchPersistentFlatStorageTests = []searchPersistentFlatStorageTest{ + { + qc: NewQueryCache(&inMemLocalCache{readTs: 45}, 45), + query: []float64{0.3, 0.5, 0.7}, + maxResults: 1, + expectedErr: nil, + expectedNns: []uint64{1}, + }, + { + qc: NewQueryCache(&inMemLocalCache{readTs: 93}, 93), + query: []float64{0.824, 0.319, 0.111}, + maxResults: 1, + expectedErr: nil, + expectedNns: []uint64{5}, + }, +} + +var flatPopulateBasicInsertsForSearch = []insertToPersistentFlatStorageTest{ + { + tc: NewTxnCache(&inMemTxn{startTs: 0, commitTs: 1}, 0), + inUuid: uint64(5), + inVec: []float64{0.1, 0.1, 0.1}, + expectedErr: nil, + expectedEdgesList: nil, + minExpectedEdge: "", + }, + { + tc: NewTxnCache(&inMemTxn{startTs: 11, commitTs: 15}, 11), + inUuid: uint64(123), + inVec: []float64{0.824, 0.319, 0.111}, + expectedErr: nil, + expectedEdgesList: nil, + minExpectedEdge: "", + }, + { + tc: NewTxnCache(&inMemTxn{startTs: 20, commitTs: 37}, 20), + inUuid: uint64(1), + inVec: []float64{0.3, 0.5, 0.7}, + expectedErr: nil, + expectedEdgesList: nil, + minExpectedEdge: "", + }, +} + +func flatPopulateInserts(insertArr []insertToPersistentFlatStorageTest) error { + emptyTsDbs() + for _, in := range insertArr { + for i := range tsDbs { + key := DataKey(flatPh.pred, in.inUuid) + tsDbs[i].inMemTestDb[string(key[:])] = floatArrayAsBytes(in.inVec) + } + _, err := 
flatPh.Insert(context.TODO(), in.tc, in.inUuid, in.inVec) + if err != nil { + return err + } + } + return nil +} + +func RunFlatSearchTests(t *testing.T, test searchPersistentFlatStorageTest, flatPh *persistentHNSW[float64]) { + nns, err := flatPh.Search(context.TODO(), test.qc, test.query, test.maxResults, index.AcceptAll[float64]) + if err != nil && test.expectedErr != nil { + if err.Error() != test.expectedErr.Error() { + t.Errorf("Output %q not equal to expected %q", err.Error(), test.expectedErr.Error()) + } + } else { + if err != test.expectedErr { + t.Errorf("Output %q not equal to expected %q", err, test.expectedErr) + } + } + if !equalUint64Slice(nns, test.expectedNns) { + t.Errorf("Nearest neighbors expected value: %v, Got: %v", test.expectedNns, nns) + } +} + +func TestBasicSearchPersistentFlatStorage(t *testing.T) { + for _, flatPh := range flatPhs { + emptyTsDbs() + err := flatPopulateInserts(flatPopulateBasicInsertsForSearch) + if err != nil { + t.Errorf("Error populating inserts: %s", err) + return + } + for _, test := range searchPersistentFlatStorageTests { + RunFlatSearchTests(t, test, flatPh) + } + } +} + +var flatPopulateOverlappingInserts = []insertToPersistentFlatStorageTest{ + { + tc: NewTxnCache(&inMemTxn{startTs: 0, commitTs: 5}, 0), + inUuid: uint64(5), + inVec: []float64{0.1, 0.1, 0.1}, + expectedErr: nil, + expectedEdgesList: nil, + minExpectedEdge: "", + }, + { + tc: NewTxnCache(&inMemTxn{startTs: 3, commitTs: 9}, 3), + inUuid: uint64(123), + inVec: []float64{0.824, 0.319, 0.111}, + expectedErr: nil, + expectedEdgesList: nil, + minExpectedEdge: "", + }, + { + tc: NewTxnCache(&inMemTxn{startTs: 8, commitTs: 37}, 8), + inUuid: uint64(1), + inVec: []float64{0.3, 0.5, 0.7}, + expectedErr: nil, + expectedEdgesList: nil, + minExpectedEdge: "", + }, +} + +var overlappingSearchPersistentFlatStorageTests = []searchPersistentFlatStorageTest{ + { + qc: NewQueryCache(&inMemLocalCache{readTs: 45}, 45), + query: []float64{0.3, 0.5, 0.7}, + 
maxResults: 1, + expectedErr: nil, + expectedNns: []uint64{123}, + }, + { + qc: NewQueryCache(&inMemLocalCache{readTs: 93}, 93), + query: []float64{0.824, 0.319, 0.111}, + maxResults: 1, + expectedErr: nil, + expectedNns: []uint64{123}, + }, +} + +func TestOverlappingInsertsAndSearchPersistentFlatStorage(t *testing.T) { + for _, flatPh := range flatPhs { + emptyTsDbs() + err := flatPopulateInserts(flatPopulateOverlappingInserts) + if err != nil { + t.Errorf("Error from flatPopulateInserts: %s", err) + return + } + for _, test := range overlappingSearchPersistentFlatStorageTests { + RunFlatSearchTests(t, test, flatPh) + } + } +} diff --git a/tok/hnsw/search_layer.go b/tok/hnsw/search_layer.go new file mode 100644 index 00000000000..49f129648bb --- /dev/null +++ b/tok/hnsw/search_layer.go @@ -0,0 +1,120 @@ +package hnsw + +import ( + c "github.com/dgraph-io/dgraph/tok/constraints" + "github.com/dgraph-io/dgraph/tok/index" + + "fmt" +) + +type searchLayerResult[T c.Float] struct { + // neighbors represents the candidates with the best scores so far. + neighbors []minPersistentHeapElement[T] + // visited represents elements seen (so we don't try to re-visit). + visited []minPersistentHeapElement[T] + path []uint64 + metrics map[string]uint64 + level int + // filtered represents num elements of meighbors that don't + // belong in final return set since they should be filtered out. + // When we encounter a node that we consider a "best" node, but where + // it should be filtered out, we allow it to enter the "neighbors" + // attribute as an element. However, we then allow neighbors to + // grow by this extra "filtered" amount. Theoretically, it could be + // pushed out, but that will be okay! At the end, we grab all + // non-filtered elements up to the limit of what is expected. 
+ filtered int +} + +func newLayerResult[T c.Float](level int) *searchLayerResult[T] { + return &searchLayerResult[T]{ + neighbors: []minPersistentHeapElement[T]{}, + visited: []minPersistentHeapElement[T]{}, + path: []uint64{}, + metrics: make(map[string]uint64), + level: level, + } +} + +func (slr *searchLayerResult[T]) setFirstPathNode(n minPersistentHeapElement[T]) { + slr.neighbors = []minPersistentHeapElement[T]{n} + slr.visited = []minPersistentHeapElement[T]{n} + slr.path = []uint64{n.index} +} + +func (slr *searchLayerResult[T]) addPathNode( + n minPersistentHeapElement[T], + simType SimilarityType[T], + maxResults int) { + slr.neighbors = simType.insortHeap(slr.neighbors, n) + if n.filteredOut { + slr.filtered++ + } + effectiveMaxLen := maxResults + slr.filtered + if len(slr.neighbors) > effectiveMaxLen { + slr.neighbors = slr.neighbors[:effectiveMaxLen] + } + + if slr.neighbors[0].index == n.index { + slr.path = append(slr.path, slr.neighbors[0].index) + } +} + +func (slr *searchLayerResult[T]) numNeighbors() int { + return len(slr.neighbors) - slr.filtered +} + +func (slr *searchLayerResult[T]) markFirstDistanceComputation() { + slr.metrics[distanceComputations] = 1 +} + +func (slr *searchLayerResult[T]) incrementDistanceComputations() { + slr.metrics[distanceComputations]++ +} + +// slr.lastNeighborScore() returns the "score" (based on similarity type) +// of the last neighbor being tracked. The score is reflected as a value +// of the minPersistentHeapElement. +// If slr is empty, this will panic. +func (slr *searchLayerResult[T]) lastNeighborScore() T { + return slr.neighbors[len(slr.neighbors)-1].value +} + +// slr.bestNeighbor() returns the heap element with the "best" score. +// panics if there is no such element. 
+func (slr *searchLayerResult[T]) bestNeighbor() minPersistentHeapElement[T] { + return slr.neighbors[0] +} + +func (slr *searchLayerResult[T]) nodeVisited(n minPersistentHeapElement[T]) bool { + for _, visitedNode := range slr.visited { + if visitedNode.index == n.index { + return true + } + } + return false +} + +func (slr *searchLayerResult[T]) addToVisited(n minPersistentHeapElement[T]) { + slr.visited = append(slr.visited, n) +} + +func (slr *searchLayerResult[T]) updateFinalMetrics(r *index.SearchPathResult) { + visitName := ConcatStrings(visitedVectorsLevel, fmt.Sprint(slr.level)) + r.Metrics[visitName] += uint64(len(slr.visited)) + for k, v := range slr.metrics { + r.Metrics[k] += v + } +} + +func (slr *searchLayerResult[T]) updateFinalPath(r *index.SearchPathResult) { + r.Path = append(r.Path, slr.path...) +} + +func (slr *searchLayerResult[T]) addFinalNeighbors(r *index.SearchPathResult) { + for _, n := range slr.neighbors { + if !n.filteredOut { + r.Neighbors = append(r.Neighbors, n.index) + } + } +} diff --git a/tok/hnsw/test_helper.go b/tok/hnsw/test_helper.go new file mode 100644 index 00000000000..036ac3f03e4 --- /dev/null +++ b/tok/hnsw/test_helper.go @@ -0,0 +1,281 @@ +package hnsw + +import ( + "context" + "encoding/binary" + "math" + "strings" + "sync" + + "github.com/dgraph-io/dgraph/tok/index" + "github.com/pkg/errors" +) + +// holds an map in memory that is a string (which will be []bytes as string) +// as the key, with an index.Val as the value +type indexStorage struct { + inMemTestDb map[string]index.Value + + //Two locks allow for lock promotion when writing, so we promote a read lock + //between the start and finish times to a full lock on the finish time + + // readMu acquires read locks when accessing values + readMu sync.RWMutex + // writeMu acquires write locks on mutations + writeMu sync.Mutex +} + +// datastructure visualization of persistent db over 100 units of time +// within this, we will conduct all testing, i.e. 
reads at 1 Ts = tsDbs[1], +// writes at 4 Ts = tsDbs[4] +var tsDbs [100]indexStorage + +func emptyTsDbs() { + for i := range tsDbs { + tsDbs[i] = indexStorage{inMemTestDb: make(map[string]index.Value)} + } +} + +type inMemList struct { + key string + startTs uint64 + finishTs uint64 +} + +// creates a new inMem list with the list's corresponding key, +// when it's action was started and when it will conclude. +// for mutations startTs will be txn.StartTs and finishTs will be txn.commitTs +// for reads, they both start and finish at c.ReadTs +// finishTs is unknown in real scenarios, this is for testing purposes +func newInMemList(key string, startTs, finishTs uint64) *inMemList { + return &inMemList{ + key: key, + startTs: startTs, + finishTs: finishTs, + } +} + +// locks the posting list & invokes ValueWithLockHeld +func (l *inMemList) Value(readTs uint64) (rval index.Value, rerr error) { + // reading should only lock the db at current instance in time + tsDbs[readTs].readMu.RLock() + defer tsDbs[readTs].readMu.RUnlock() + return l.ValueWithLockHeld(readTs) +} + +// reads value from the database at readTs corresponding to List's key +func (l *inMemList) ValueWithLockHeld(readTs uint64) (rval index.Value, rerr error) { + val, ok := tsDbs[readTs].inMemTestDb[l.key] + if !ok { + return nil, errors.New("Could not find data with key " + l.key) + } + return val, nil +} + +// locks the posting list and invokes AddMutationWithLockHeld +func (l *inMemList) AddMutation(ctx context.Context, txn index.Txn, t *index.KeyValue) error { + // locks from the txn.StartTs up to txn.CommitTs + l.Lock() + defer l.Unlock() + return l.AddMutationWithLockHeld(ctx, txn, t) +} + +// adds mutation to the database at the txn's commitTs +func (l *inMemList) AddMutationWithLockHeld(ctx context.Context, txn index.Txn, t *index.KeyValue) error { + // creates key from directedEdge + //builds value + val := t.Value + // a mutation persists from the moment the txn gets committed until the "rest of 
time" + for i := l.finishTs; i < uint64(len(tsDbs)); i++ { + tsDbs[i].inMemTestDb[l.key] = val + } + return nil +} + +// if youre locking at a certain point in time, the lock should be held for this moment +// and all future moments until your commitTs +func (l *inMemList) Lock() { + if !strings.Contains(l.key, "entry") { + for i := l.startTs; i <= l.finishTs; i++ { + tsDbs[i].readMu.RLock() + } + for i := l.finishTs; i < uint64(len(tsDbs)); i++ { + tsDbs[i].writeMu.Lock() + } + } +} + +// undoes lock +func (l *inMemList) Unlock() { + if !strings.Contains(l.key, "entry") { + for i := l.startTs; i <= l.finishTs; i++ { + tsDbs[i].readMu.RUnlock() + } + for i := l.finishTs; i < uint64(len(tsDbs)); i++ { + tsDbs[i].writeMu.Unlock() + } + } +} + +// a txn has a startTs (when the txn started) and commitTs (when the txn changes were committed) +type inMemTxn struct { + startTs uint64 + commitTs uint64 +} + +func (t *inMemTxn) Find(prefix []byte, filter func([]byte) bool) (uint64, error) { + tsDb := tsDbs[t.startTs] + tsDb.readMu.RLock() + defer tsDb.readMu.RUnlock() + for _, b := range tsDb.inMemTestDb { + if filter(b.([]byte)) { + return 1, nil + } + } + return 0, nil +} + +func (t *inMemTxn) StartTs() uint64 { + return t.startTs +} + +// locks the txn and invokes GetWithLockHeld +func (t *inMemTxn) Get(key []byte) (rval index.Value, rerr error) { + tsDbs[t.startTs].readMu.RLock() + defer tsDbs[t.startTs].readMu.RUnlock() + return t.GetWithLockHeld(key) +} + +// reads value from the database at txn's startTs +func (t *inMemTxn) GetWithLockHeld(key []byte) (rval index.Value, rerr error) { + val, ok := tsDbs[t.startTs].inMemTestDb[string(key[:])] + if !ok { + return nil, errors.New("Could not find data with key " + string(key[:])) + } + return val, nil +} + +// locks the txn and invokes AddMutationWithLockHeld +func (t *inMemTxn) AddMutation(ctx context.Context, key []byte, t1 *index.KeyValue) error { + tsDbs[t.startTs].writeMu.Lock() + defer 
tsDbs[t.startTs].writeMu.Unlock() + return t.AddMutationWithLockHeld(ctx, key, t1) +} + +// adds mutation to the database at the txn's commitTs +func (t *inMemTxn) AddMutationWithLockHeld(ctx context.Context, key []byte, t1 *index.KeyValue) error { + val := t1.Value + for i := t.commitTs; i < uint64(len(tsDbs)); i++ { + tsDbs[i].inMemTestDb[string(key[:])] = val + } + return nil +} + +// locks the txn +func (t *inMemTxn) LockKey(key []byte) { + if !strings.Contains(string(key[:]), "entry") { + // locks from the txn.StartTs up to txn.CommitTs + for i := t.startTs; i <= t.commitTs; i++ { + tsDbs[i].readMu.RLock() + } + for i := t.commitTs; i < uint64(len(tsDbs)); i++ { + tsDbs[i].writeMu.Lock() + } + } +} + +// undoes lock +func (t *inMemTxn) UnlockKey(key []byte) { + if !strings.Contains(string(key[:]), "entry") { + // locks from the txn.StartTs up to txn.CommitTs + for i := t.startTs; i <= t.commitTs; i++ { + tsDbs[i].readMu.RUnlock() + } + for i := t.commitTs; i < uint64(len(tsDbs)); i++ { + tsDbs[i].writeMu.Unlock() + } + } +} + +type inMemLocalCache struct { + readTs uint64 +} + +// locks the local cache and invokes GetWithLockHeld +func (c *inMemLocalCache) Get(key []byte) (rval index.Value, rerr error) { + tsDbs[c.readTs].readMu.RLock() + defer tsDbs[c.readTs].readMu.RUnlock() + return c.GetWithLockHeld(key) +} + +func (c *inMemLocalCache) Find(prefix []byte, filter func([]byte) bool) (uint64, error) { + tsDb := tsDbs[c.readTs] + tsDb.readMu.RLock() + defer tsDb.readMu.RUnlock() + for _, b := range tsDb.inMemTestDb { + if filter(b.([]byte)) { + return 1, nil + } + } + return 0, nil +} + +// reads value from the database at c's readTs +func (c *inMemLocalCache) GetWithLockHeld(key []byte) (rval index.Value, rerr error) { + val, ok := tsDbs[c.readTs].inMemTestDb[string(key[:])] + if !ok { + return nil, errors.New("Could not find data with key " + string(key[:])) + } + return val, nil +} + +func equalFloat64Slice(a, b []float64) bool { + if len(a) != len(b) { + 
// floatArrayAsBytes encodes v as a byte slice of length 8*len(v), writing the
// IEEE-754 bit pattern of each element in little-endian order. It is the
// encoding counterpart of BytesAsFloatArray for 64-bit floats; encoding can
// never fail, whereas decoding arbitrary bytes is not always meaningful.
func floatArrayAsBytes(v []float64) []byte {
	const width = 8 // bytes per encoded float64
	encoded := make([]byte, width*len(v))
	for idx, f := range v {
		binary.LittleEndian.PutUint64(encoded[idx*width:], math.Float64bits(f))
	}
	return encoded
}
+// Let floatBytes = floatBits/8. If len(encoded) % floatBytes is +// not 0, it will ignore any trailing bytes, and simply convert floatBytes +// bytes at a time to generate the entries. +// The result is appended to the given retVal slice. If retVal is nil +// then a new slice is created and appended to. +func BytesAsFloatArray[T c.Float](encoded []byte, retVal *[]T, floatBits int) { + // Unfortunately, this is not as simple as casting the result, + // and it is also not possible to directly use the + // golang "unsafe" library to directly do the conversion. + // The machine where this operation gets run might prefer + // BigEndian/LittleEndian, but the machine that sent it may have + // preferred the other, and there is no way to tell! + // + // The solution below, unfortunately, requires another memory + // allocation. + // TODO Potential optimization: If we detect that current machine is + // using LittleEndian format, there might be a way of making this + // work with the golang "unsafe" library. + floatBytes := floatBits / 8 + + *retVal = (*retVal)[:0] + resultLen := len(encoded) / floatBytes + if resultLen == 0 { + return + } + for i := 0; i < resultLen; i++ { + // Assume LittleEndian for encoding since this is + // the assumption elsewhere when reading from client. + // See dgraph-io/dgo/protos/api.pb.go + // See also dgraph-io/dgraph/types/conversion.go + // This also seems to be the preference from many examples + // I have found via Google search. It's unclear why this + // should be a preference. 
+ if retVal == nil { + retVal = &[]T{} + } + *retVal = append(*retVal, BytesToFloat[T](encoded, floatBits)) + + encoded = encoded[(floatBytes):] + } +} + +func BytesToFloat[T c.Float](encoded []byte, floatBits int) T { + if floatBits == 32 { + bits := binary.LittleEndian.Uint32(encoded) + return T(math.Float32frombits(bits)) + } else if floatBits == 64 { + bits := binary.LittleEndian.Uint64(encoded) + return T(math.Float64frombits(bits)) + } + panic("Invalid floatBits") +} diff --git a/tok/index/index.go b/tok/index/index.go index e87faba9e91..65caf65a8b5 100644 --- a/tok/index/index.go +++ b/tok/index/index.go @@ -79,15 +79,35 @@ func AcceptAll[T c.Float](_, _ []T, _ uint64) bool { return true } // AcceptNone implements SearchFilter by way of rejecting all results. func AcceptNone[T c.Float](_, _ []T, _ uint64) bool { return false } +// OptionalIndexSupport defines abilities that might not be universally +// supported by all VectorIndex types. A VectorIndex will technically +// define the functions required by OptionalIndexSupport, but may do so +// by way of simply returning an errors.ErrUnsupported result. +type OptionalIndexSupport[T c.Float] interface { + // SearchWithPath(ctx, c, query, maxResults, filter) is similar to + // Search(ctx, c, query, maxResults, filter), but returns an extended + // set of content in the search results. + // The full contents returned are indicated by the SearchPathResult. + // See the description there for more info. + SearchWithPath( + ctx context.Context, + c CacheType, + query []T, + maxResults int, + filter SearchFilter[T]) (*SearchPathResult, error) +} + // A VectorIndex can be used to Search for vectors and add vectors to an index. type VectorIndex[T c.Float] interface { + OptionalIndexSupport[T] + // Search will find the uids for a given set of vectors based on the // input query, limiting to the specified maximum number of results. 
// A Txn is an interface representation of a persistent storage transaction,
// where multiple operations are performed on a database.
type Txn interface {
	// StartTs gets the exact time that the transaction started, returned in uint64 format
	StartTs() uint64
	// Get uses a []byte key to return the Value corresponding to the key
	Get(key []byte) (rval Value, rerr error)
	// GetWithLockHeld uses a []byte key to return the Value corresponding to the key with a mutex lock held
	GetWithLockHeld(key []byte) (rval Value, rerr error)
	// Find applies filter to stored values and returns a uint64 result.
	// NOTE(review): the meaning of the returned uint64 and the role of prefix
	// are not specified here — presumably the uid of a matching entry under
	// prefix; confirm against the implementations.
	Find(prefix []byte, filter func(val []byte) bool) (uint64, error)
	// AddMutation adds a mutation operation on an index.Txn interface, where
	// the mutation is represented in the form of a *KeyValue applied to key.
	AddMutation(ctx context.Context, key []byte, t *KeyValue) error
	// Same as AddMutation but with a mutex lock held
	AddMutationWithLockHeld(ctx context.Context, key []byte, t *KeyValue) error
	// LockKey acquires the mutex lock associated with key.
	LockKey(key []byte)
	// UnlockKey releases the mutex lock associated with key.
	UnlockKey(key []byte)
}

// LocalCache is an interface representation of the local cache of a persistent storage system.
type LocalCache interface {
	// Get uses a []byte key to return the Value corresponding to the key
	Get(key []byte) (rval Value, rerr error)
	// GetWithLockHeld uses a []byte key to return the Value corresponding to the key with a mutex lock held
	GetWithLockHeld(key []byte) (rval Value, rerr error)
	// Find applies filter to stored values and returns a uint64 result.
	// NOTE(review): return-value semantics are not specified here; confirm
	// against the implementations.
	Find(prefix []byte, filter func(val []byte) bool) (uint64, error)
}

// Value is an interface representation of the value of a persistent storage system.
// It is an empty interface, so callers must type-assert to the concrete type they expect.
type Value interface{}

// CacheType is an interface representation of the cache of a persistent storage system.
// It is the cache handle passed to the VectorIndex Search, SearchWithUid and Insert methods.
type CacheType interface {
	// Get uses a []byte key to return the Value corresponding to the key
	Get(key []byte) (rval Value, rerr error)
	// Ts returns a uint64 timestamp for the cache.
	// NOTE(review): presumably the read or start timestamp the cache is
	// pinned to — confirm against the implementations.
	Ts() uint64
	// Find applies filter to stored values and returns a uint64 result.
	// NOTE(review): return-value semantics are not specified here; confirm
	// against the implementations.
	Find(prefix []byte, filter func(val []byte) bool) (uint64, error)
}
// SearchPathResult is what the optional SearchWithPath function of a
// VectorIndex returns (SearchWithPath is declared by OptionalIndexSupport).
type SearchPathResult struct {
	// Neighbors holds the nearest neighbors in sorted order, after filtering
	// out any neighbor that fails the Filter criteria.
	Neighbors []uint64
	// Path is the route taken from the start of the search to the closest
	// neighbor vector.
	Path []uint64
	// Metrics holds named counters captured during this particular search.
	Metrics map[string]uint64
}

// NewSearchPathResult returns a freshly allocated *SearchPathResult whose
// Neighbors, Path and Metrics are all non-nil and empty.
func NewSearchPathResult() *SearchPathResult {
	result := new(SearchPathResult)
	result.Neighbors = make([]uint64, 0)
	result.Path = make([]uint64, 0)
	result.Metrics = map[string]uint64{}
	return result
}
case storageType == schemaType && schemaType != types.DefaultID: return nil @@ -557,6 +558,7 @@ func ValidateAndConvert(edge *pb.DirectedEdge, su *pb.SchemaUpdate) error { src := types.Val{Tid: types.TypeID(edge.ValueType), Value: edge.Value} // check compatibility of schema type and storage type + // The goal is to convert value on edge to value type defined by schema. if dst, err = types.Convert(src, schemaType); err != nil { return err } @@ -578,8 +580,16 @@ func ValidateAndConvert(edge *pb.DirectedEdge, su *pb.SchemaUpdate) error { } } + // TODO: Figure out why this is Enum. It really seems like an odd choice -- rather than + // specifying it as the same type as presented in su. edge.ValueType = schemaType.Enum() - edge.Value = b.Value.([]byte) + var ok bool + edge.Value, ok = b.Value.([]byte) + if !ok { + return errors.Errorf("failure to convert edge type: '%+v' to schema type: '%+v'", + storageType, schemaType) + } + return nil } diff --git a/worker/predicate_move.go b/worker/predicate_move.go index f5ad429009b..1d21abb4fe2 100644 --- a/worker/predicate_move.go +++ b/worker/predicate_move.go @@ -22,6 +22,7 @@ import ( "io" "math" "strconv" + "strings" "github.com/dustin/go-humanize" "github.com/golang/glog" @@ -36,6 +37,7 @@ import ( "github.com/dgraph-io/dgraph/schema" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" + "github.com/dgraph-io/dgraph/tok/hnsw" ) var ( @@ -216,7 +218,8 @@ func (w *grpcWorker) MovePredicate(ctx context.Context, return &emptyPayload, errEmptyPredicate } - if in.DestGid == 0 { + //TODO: need to find possibly a better way to not move __vector_ predicates + if in.DestGid == 0 && !strings.Contains(in.Predicate, hnsw.VecKeyword) { glog.Infof("Was instructed to delete tablet: %v", in.Predicate) // Expected Checksum ensures that all the members of this group would block until they get // the latest membership status where this predicate now belongs to another group. 
So they @@ -230,6 +233,11 @@ func (w *grpcWorker) MovePredicate(ctx context.Context, } return &emptyPayload, groups().Node.proposeAndWait(ctx, p) } + + if strings.Contains(in.Predicate, hnsw.VecKeyword) { + return &emptyPayload, nil + } + if err := posting.Oracle().WaitForTs(ctx, in.TxnTs); err != nil { return &emptyPayload, errors.Errorf("While waiting for txn ts: %d. Error: %v", in.TxnTs, err) } diff --git a/worker/task.go b/worker/task.go index b34e089213f..7761033e2e0 100644 --- a/worker/task.go +++ b/worker/task.go @@ -19,6 +19,7 @@ package worker import ( "bytes" "context" + "fmt" "sort" "strconv" "strings" @@ -41,6 +42,8 @@ import ( "github.com/dgraph-io/dgraph/schema" ctask "github.com/dgraph-io/dgraph/task" "github.com/dgraph-io/dgraph/tok" + "github.com/dgraph-io/dgraph/tok/hnsw" + "github.com/dgraph-io/dgraph/tok/index" "github.com/dgraph-io/dgraph/types" "github.com/dgraph-io/dgraph/types/facets" "github.com/dgraph-io/dgraph/x" @@ -220,6 +223,7 @@ const ( uidInFn customIndexFn matchFn + similarToFn standardFn = 100 ) @@ -258,6 +262,8 @@ func parseFuncTypeHelper(name string) (FuncType, string) { return hasFn, f case "uid_in": return uidInFn, f + case "similar_to": + return similarToFn, f case "anyof", "allof": return customIndexFn, f case "match": @@ -282,6 +288,8 @@ func needsIndex(fnType FuncType, uidList *pb.List) bool { return true case geoFn, fullTextSearchFn, standardFn, matchFn: return true + case similarToFn: + return true } return false } @@ -317,7 +325,7 @@ func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error) case uidInFn, compareScalarFn: // Operate on uid postings return false, nil - case notAFunction: + case notAFunction, similarToFn: return typ.IsScalar(), nil } return false, errors.Errorf("Unhandled case in fetchValuePostings for fn: %s", srcFn.fname) @@ -341,11 +349,46 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er } switch srcFn.fnType { - case notAFunction, aggregatorFn, 
passwordFn, compareAttrFn: + case notAFunction, aggregatorFn, passwordFn, compareAttrFn, similarToFn: default: return errors.Errorf("Unhandled function in handleValuePostings: %s", srcFn.fname) } + if srcFn.fnType == similarToFn { + numNeighbors, err := strconv.ParseInt(q.SrcFunc.Args[0], 10, 32) + if err != nil { + return fmt.Errorf("invalid value for number of neighbors: %s", q.SrcFunc.Args[0]) + } + cspec, err := pickFactoryCreateSpec(ctx, args.q.Attr) + if err != nil { + return err + } + //TODO: generate maxLevels from schema, filter, etc. + qc := hnsw.NewQueryCache( + posting.NewViLocalCache(qs.cache), + args.q.ReadTs, + ) + indexer, err := cspec.CreateIndex(args.q.Attr) + if err != nil { + return err + } + var nnUids []uint64 + if srcFn.vectorInfo != nil { + nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, + int(numNeighbors), index.AcceptAll[float32]) + } else { + nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, + int(numNeighbors), index.AcceptAll[float32]) + } + + if err != nil { + return err + } + sort.Slice(nnUids, func(i, j int) bool { return nnUids[i] < nnUids[j] }) + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: nnUids}) + return nil + } + if srcFn.atype == types.PasswordID && srcFn.fnType != passwordFn { // Silently skip if the user is trying to fetch an attribute of type password. 
return nil
@@ -1026,7 +1069,7 @@ func (qs *queryState) helpProcessTask(ctx context.Context, q *pb.Query, gid uint
 	}
 	if needsValPostings {
 		span.Annotate(nil, "handleValuePostings")
-		if err = qs.handleValuePostings(ctx, args); err != nil {
+		if err := qs.handleValuePostings(ctx, args); err != nil {
 			return nil, err
 		}
 	} else {
@@ -1684,6 +1727,8 @@ type functionContext struct {
 	isFuncAtRoot   bool
 	isStringFn     bool
 	atype          types.TypeID
+	vectorInfo     []float32
+	vectorUid      uint64
 }
 
 const (
@@ -1941,6 +1986,14 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) {
 			return nil, err
 		}
 		checkRoot(q, fc)
+	case similarToFn:
+		if err = ensureArgsCount(q.SrcFunc, 2); err != nil {
+			return nil, err
+		}
+		fc.vectorInfo, fc.vectorUid, err = interpretVFloatOrUid(q.SrcFunc.Args[1])
+		if err != nil {
+			return nil, err
+		}
 	case uidInFn:
 		for _, arg := range q.SrcFunc.Args {
 			uidParsed, err := strconv.ParseUint(arg, 0, 64)
@@ -1966,6 +2019,21 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) {
 	return fc, nil
 }
 
+// interpretVFloatOrUid interprets val as a float32 vector or, failing that, as a uid.
+// It tries types.ParseVFloat first and falls back to strconv.ParseUint with base 0.
+// It returns (vector, 0, nil) for a vector, (nil, uid, nil) for a uid, and
+// (nil, 0, err) when val is neither.
+func interpretVFloatOrUid(val string) ([]float32, uint64, error) {
+	if vf, err := types.ParseVFloat(val); err == nil {
+		return vf, 0, nil
+	}
+	uid, err := strconv.ParseUint(val, 0, 64)
+	if err != nil {
+		// Never leak ParseUint's partial result (MaxUint64 on range errors) next to an error.
+		return nil, 0, errors.Errorf("Value %q is not a uid or vector", val)
+	}
+	return nil, uid, nil
+}
+
 // ServeTask is used to respond to a query.
 func (w *grpcWorker) ServeTask(ctx context.Context, q *pb.Query) (*pb.Result, error) {
 	ctx, span := otrace.StartSpan(ctx, "worker.ServeTask")
diff --git a/worker/tokens.go b/worker/tokens.go
index b930974d018..13f7f2cf703 100644
--- a/worker/tokens.go
+++ b/worker/tokens.go
@@ -122,6 +122,30 @@ func pickTokenizer(ctx context.Context, attr string, f string) (tok.Tokenizer, e
 	return tokenizers[0], nil
 }
 
+// pickFactoryCreateSpec(ctx, attr) will find the FactoryCreateSpec (i.e.,
+// index name + options) for the given attribute "attr".
+// Note that unlike pickTokenizer(ctx, attr, f), we do not include the
+// parameter "f" (for function name), as it is not used here.
+// This is otherwise similar to pickTokenizer.
+func pickFactoryCreateSpec(ctx context.Context, attr string) (*tok.FactoryCreateSpec, error) {
+	// Reject attributes without any index before looking up their create specs.
+	if !schema.State().IsIndexed(ctx, attr) {
+		return nil, errors.Errorf("Attribute %s is not indexed.", attr)
+	}
+
+	cspecs, err := schema.State().FactoryCreateSpec(ctx, attr)
+	if err != nil {
+		return nil, err
+	}
+	if len(cspecs) == 0 {
+		return nil, errors.Errorf("Schema state not found for %s.", attr)
+	}
+
+	// Only the first spec is returned; any additional specs are ignored.
+	// This is similar to pickTokenizer in behavior.
+	return cspecs[0], nil
+}
+
 // getInequalityTokens gets tokens ge/le/between compared to given tokens using the first sortable
 // index that is found for the predicate.
 // In case of ge/gt/le/lt/eq len(ineqValues) should be 1, else(between) len(ineqValues) should be 2.