From 4024583de2a97d6c0aa27ad670c2a1573072c310 Mon Sep 17 00:00:00 2001 From: kevindiu Date: Tue, 12 Sep 2023 01:34:59 +0000 Subject: [PATCH] impl test Signed-off-by: kevindiu --- apis/grpc/v1/payload/payload.pb.go | 1 - .../core/ngt/handler/grpc/handler_test.go | 2 +- .../core/ngt/handler/grpc/update_test.go | 119 +++++++++++++----- 3 files changed, 88 insertions(+), 34 deletions(-) diff --git a/apis/grpc/v1/payload/payload.pb.go b/apis/grpc/v1/payload/payload.pb.go index 40295c75efa..726e4e2d1ef 100644 --- a/apis/grpc/v1/payload/payload.pb.go +++ b/apis/grpc/v1/payload/payload.pb.go @@ -27,7 +27,6 @@ import ( _ "github.com/envoyproxy/protoc-gen-validate/validate" _ "github.com/planetscale/vtprotobuf/vtproto" - sync "github.com/vdaas/vald/internal/sync" status "google.golang.org/genproto/googleapis/rpc/status" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" diff --git a/pkg/agent/core/ngt/handler/grpc/handler_test.go b/pkg/agent/core/ngt/handler/grpc/handler_test.go index 62748b5eacc..71fa62950c3 100644 --- a/pkg/agent/core/ngt/handler/grpc/handler_test.go +++ b/pkg/agent/core/ngt/handler/grpc/handler_test.go @@ -66,7 +66,7 @@ func newIndexedNGTService(ctx context.Context, eg errgroup.Group, t request.Obje // insert and create index for _, req := range reqs.GetRequests() { - err := ngt.Insert(req.GetVector().GetId(), req.GetVector().GetVector()) + err := ngt.InsertWithTime(req.GetVector().GetId(), req.GetVector().GetVector(), req.GetVector().GetTimestamp()) if err != nil { return nil, err } diff --git a/pkg/agent/core/ngt/handler/grpc/update_test.go b/pkg/agent/core/ngt/handler/grpc/update_test.go index ae65db655c4..592e0d45bf9 100644 --- a/pkg/agent/core/ngt/handler/grpc/update_test.go +++ b/pkg/agent/core/ngt/handler/grpc/update_test.go @@ -17,6 +17,7 @@ import ( "context" "math" "testing" + "time" "github.com/vdaas/vald/apis/grpc/v1/payload" "github.com/vdaas/vald/internal/config" @@ -36,6 +37,7 @@ func Test_server_Update(t *testing.T) { type args struct { indexID string indexVector []float32 + indexTS int64 req *payload.Update_Request } type fields struct { @@ -50,11 +52,11 @@ func Test_server_Update(t *testing.T) { args args fields fields want want - checkFunc func(want, *payload.Object_Location, error) error - beforeFunc func(*testing.T, args) (Server, error) + checkFunc func(want, *payload.Object_Location, Server, error) error + beforeFunc func(*testing.T, context.Context, args, string) (Server, error) afterFunc func(args) } - defaultCheckFunc := func(w want, gotRes *payload.Object_Location, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Object_Location, s Server, err error) error { if err != nil { st, ok := status.FromError(err) if !ok { @@ -89,10 +91,7 @@ func Test_server_Update(t *testing.T) { defaultUpdateConfig := &payload.Update_Config{ SkipStrictExistCheck: true, } - defaultInsertConfig := &payload.Insert_Config{ - SkipStrictExistCheck: true, - } - beforeFunc := func(t *testing.T, ctx context.Context, objectType string) func(*testing.T, args) (Server, error) { + defaultBeforeFunc := func(t *testing.T, ctx context.Context, a args, objectType string) (Server, error) { t.Helper() if objectType == "" { objectType = ngt.Float.String() @@ -113,26 +112,28 @@ func Test_server_Update(t *testing.T) { }, } - return func(t *testing.T, a args) (Server, error) { - t.Helper() - var overwriteVec [][]float32 - if a.indexVector != nil { - overwriteVec = [][]float32{ - a.indexVector, - } - } + insertCfg := &payload.Insert_Config{ + SkipStrictExistCheck: true, + Timestamp: a.indexTS, + } - eg, ctx := errgroup.New(ctx) - ngt, err := newIndexedNGTService(ctx, eg, request.Float, vector.Gaussian, insertNum, defaultInsertConfig, cfg, nil, []string{a.indexID}, overwriteVec) - if err != nil { - return nil, err - } - s, err := New(WithErrGroup(eg), WithNGT(ngt)) - if err != nil { - return nil, err + var overwriteVec [][]float32 + if a.indexVector != nil { + overwriteVec = [][]float32{ + a.indexVector, } - return s, nil } + + eg, ctx := errgroup.New(ctx) + ngt, err := newIndexedNGTService(ctx, eg, request.Float, vector.Gaussian, insertNum, insertCfg, cfg, nil, []string{a.indexID}, overwriteVec) + if err != nil { + return nil, err + } + s, err := New(WithErrGroup(eg), WithNGT(ngt)) + if err != nil { + return nil, err + } + return s, nil } /* @@ -179,6 +180,7 @@ func Test_server_Update(t *testing.T) { - case 2.1: fail update with one duplicated vector, duplicated ID and SkipStrictExistCheck is false - case 2.2: success update with one different vector, duplicated ID and SkipStrictExistsCheck is false - case 2.3: success update with one duplicated vector, different ID and SkipStrictExistCheck is false + - case 3.1: success update timestamp with one same ID and vector, and UpdateTimestampIfExists is true */ tests := []test{ { @@ -832,6 +834,56 @@ func Test_server_Update(t *testing.T) { wantUUID: "uuid-2", }, }, + func() test { + indexID := "test" + indexVector := vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0] + newTs := time.Now() + ts := newTs.Add(-2 * time.Minute) + + return test{ + name: "Decision Table Testing case 3.1: success update timestamp with one same ID and vector, and UpdateTimestampIfExists is true", + args: args{ + indexID: indexID, + indexVector: indexVector, + indexTS: ts.UnixNano(), + + req: &payload.Update_Request{ + Vector: &payload.Object_Vector{ + Id: indexID, + Vector: indexVector, + Timestamp: newTs.UnixNano(), + }, + Config: &payload.Update_Config{ + SkipStrictExistCheck: false, + UpdateTimestampIfExists: true, + }, + }, + }, + want: want{ + wantUUID: indexID, + }, + checkFunc: func(w want, o *payload.Object_Location, s Server, err error) error { + if err := defaultCheckFunc(w, o, s, err); err != nil { + return err + } + ov, err := s.GetObject(context.Background(), &payload.Object_VectorRequest{ + Id: &payload.Object_ID{ + Id: indexID, + }, + }) + if err != nil { + return err + } + + got := ov.GetTimestamp() + want := newTs.UnixNano() + if got != want { + return errors.Errorf("timestamp is not updated, got: %v, want: %v", got, want) + } + return nil + }, + } + }(), } for _, tc := range tests { @@ -842,23 +894,26 @@ func Test_server_Update(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + if test.afterFunc != nil { + defer test.afterFunc(test.args) + } + + beforeFunc := test.beforeFunc if test.beforeFunc == nil { - test.beforeFunc = beforeFunc(tt, ctx, tc.fields.objectType) + beforeFunc = defaultBeforeFunc } - s, err := test.beforeFunc(tt, test.args) + s, err := beforeFunc(tt, ctx, test.args, test.fields.objectType) if err != nil { tt.Errorf("error = %v", err) } - if test.afterFunc != nil { - defer test.afterFunc(test.args) - } + + gotRes, err := s.Update(ctx, test.args.req) + checkFunc := test.checkFunc if test.checkFunc == nil { checkFunc = defaultCheckFunc } - - gotRes, err := s.Update(ctx, test.args.req) - if err := checkFunc(test.want, gotRes, err); err != nil { + if err := checkFunc(test.want, gotRes, s, err); err != nil { tt.Errorf("error = %v", err) } })