Skip to content

Commit

Permalink
impl test
Browse files Browse the repository at this point in the history
Signed-off-by: kevindiu <[email protected]>
  • Loading branch information
kevindiu committed Sep 19, 2023
1 parent 4186158 commit 4024583
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 34 deletions.
1 change: 0 additions & 1 deletion apis/grpc/v1/payload/payload.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/agent/core/ngt/handler/grpc/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func newIndexedNGTService(ctx context.Context, eg errgroup.Group, t request.Obje

// insert and create index
for _, req := range reqs.GetRequests() {
err := ngt.Insert(req.GetVector().GetId(), req.GetVector().GetVector())
err := ngt.InsertWithTime(req.GetVector().GetId(), req.GetVector().GetVector(), req.GetVector().GetTimestamp())
if err != nil {
return nil, err
}
Expand Down
119 changes: 87 additions & 32 deletions pkg/agent/core/ngt/handler/grpc/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"context"
"math"
"testing"
"time"

"github.com/vdaas/vald/apis/grpc/v1/payload"
"github.com/vdaas/vald/internal/config"
Expand All @@ -36,6 +37,7 @@ func Test_server_Update(t *testing.T) {
type args struct {
indexID string
indexVector []float32
indexTS int64
req *payload.Update_Request
}
type fields struct {
Expand All @@ -50,11 +52,11 @@ func Test_server_Update(t *testing.T) {
args args
fields fields
want want
checkFunc func(want, *payload.Object_Location, error) error
beforeFunc func(*testing.T, args) (Server, error)
checkFunc func(want, *payload.Object_Location, Server, error) error
beforeFunc func(*testing.T, context.Context, args, string) (Server, error)
afterFunc func(args)
}
defaultCheckFunc := func(w want, gotRes *payload.Object_Location, err error) error {
defaultCheckFunc := func(w want, gotRes *payload.Object_Location, s Server, err error) error {
if err != nil {
st, ok := status.FromError(err)
if !ok {
Expand Down Expand Up @@ -89,10 +91,7 @@ func Test_server_Update(t *testing.T) {
defaultUpdateConfig := &payload.Update_Config{
SkipStrictExistCheck: true,
}
defaultInsertConfig := &payload.Insert_Config{
SkipStrictExistCheck: true,
}
beforeFunc := func(t *testing.T, ctx context.Context, objectType string) func(*testing.T, args) (Server, error) {
defaultBeforeFunc := func(t *testing.T, ctx context.Context, a args, objectType string) (Server, error) {
t.Helper()
if objectType == "" {
objectType = ngt.Float.String()
Expand All @@ -113,26 +112,28 @@ func Test_server_Update(t *testing.T) {
},
}

return func(t *testing.T, a args) (Server, error) {
t.Helper()
var overwriteVec [][]float32
if a.indexVector != nil {
overwriteVec = [][]float32{
a.indexVector,
}
}
insertCfg := &payload.Insert_Config{
SkipStrictExistCheck: true,
Timestamp: a.indexTS,
}

eg, ctx := errgroup.New(ctx)
ngt, err := newIndexedNGTService(ctx, eg, request.Float, vector.Gaussian, insertNum, defaultInsertConfig, cfg, nil, []string{a.indexID}, overwriteVec)
if err != nil {
return nil, err
}
s, err := New(WithErrGroup(eg), WithNGT(ngt))
if err != nil {
return nil, err
var overwriteVec [][]float32
if a.indexVector != nil {
overwriteVec = [][]float32{
a.indexVector,
}
return s, nil
}

eg, ctx := errgroup.New(ctx)
ngt, err := newIndexedNGTService(ctx, eg, request.Float, vector.Gaussian, insertNum, insertCfg, cfg, nil, []string{a.indexID}, overwriteVec)
if err != nil {
return nil, err
}
s, err := New(WithErrGroup(eg), WithNGT(ngt))
if err != nil {
return nil, err
}
return s, nil
}

/*
Expand Down Expand Up @@ -179,6 +180,7 @@ func Test_server_Update(t *testing.T) {
- case 2.1: fail update with one duplicated vector, duplicated ID and SkipStrictExistCheck is false
- case 2.2: success update with one different vector, duplicated ID and SkipStrictExistsCheck is false
- case 2.3: success update with one duplicated vector, different ID and SkipStrictExistCheck is false
- case 3.1: success update timestamp with one same ID and vector, and UpdateTimestampIfExists is true
*/
tests := []test{
{
Expand Down Expand Up @@ -832,6 +834,56 @@ func Test_server_Update(t *testing.T) {
wantUUID: "uuid-2",
},
},
func() test {
indexID := "test"
indexVector := vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0]
newTs := time.Now()
ts := newTs.Add(-2 * time.Minute)

return test{
name: "Decision Table Testing case 3.1: success update timestamp with one same ID and vector, and UpdateTimestampIfExists is true",
args: args{
indexID: indexID,
indexVector: indexVector,
indexTS: ts.UnixNano(),

req: &payload.Update_Request{
Vector: &payload.Object_Vector{
Id: indexID,
Vector: indexVector,
Timestamp: newTs.UnixNano(),
},
Config: &payload.Update_Config{
SkipStrictExistCheck: false,
UpdateTimestampIfExists: true,
},
},
},
want: want{
wantUUID: indexID,
},
checkFunc: func(w want, o *payload.Object_Location, s Server, err error) error {
if err := defaultCheckFunc(w, o, s, err); err != nil {
return err
}
ov, err := s.GetObject(context.Background(), &payload.Object_VectorRequest{
Id: &payload.Object_ID{
Id: indexID,
},
})
if err != nil {
return err
}

got := ov.GetTimestamp()
want := newTs.UnixNano()
if got != want {
return errors.Errorf("timestamp is not updated, got: %v, want: %v", got, want)
}
return nil
},
}
}(),
}

for _, tc := range tests {
Expand All @@ -842,23 +894,26 @@ func Test_server_Update(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

if test.afterFunc != nil {
defer test.afterFunc(test.args)
}

beforeFunc := test.beforeFunc
if test.beforeFunc == nil {
test.beforeFunc = beforeFunc(tt, ctx, tc.fields.objectType)
beforeFunc = defaultBeforeFunc
}
s, err := test.beforeFunc(tt, test.args)
s, err := beforeFunc(tt, ctx, test.args, test.fields.objectType)
if err != nil {
tt.Errorf("error = %v", err)
}
if test.afterFunc != nil {
defer test.afterFunc(test.args)
}

gotRes, err := s.Update(ctx, test.args.req)

checkFunc := test.checkFunc
if test.checkFunc == nil {
checkFunc = defaultCheckFunc
}

gotRes, err := s.Update(ctx, test.args.req)
if err := checkFunc(test.want, gotRes, err); err != nil {
if err := checkFunc(test.want, gotRes, s, err); err != nil {
tt.Errorf("error = %v", err)
}
})
Expand Down

0 comments on commit 4024583

Please sign in to comment.