Skip to content

Commit

Permalink
implement pkg handler remove test cases (#1644)
Browse files Browse the repository at this point in the history
* add drafts of remove test cases

Signed-off-by: Kosuke Morimoto <[email protected]>

* implement remove handler test cases

Signed-off-by: Kosuke Morimoto <[email protected]>

* add error processing

Signed-off-by: Kosuke Morimoto <[email protected]>

* fix handler's error processing

Signed-off-by: Kosuke Morimoto <[email protected]>

* fix handler's error processing to GetObject

Signed-off-by: Kosuke Morimoto <[email protected]>

* add ErrInvalidUUID test case

Signed-off-by: Kosuke Morimoto <[email protected]>

* Update pkg/agent/core/ngt/handler/grpc/handler.go

Co-authored-by: Yusuke Kato <[email protected]>

Co-authored-by: Yusuke Kato <[email protected]>
  • Loading branch information
kmrmt and kpango authored May 17, 2022
1 parent 6d6a6c4 commit d55e2db
Show file tree
Hide file tree
Showing 4 changed files with 564 additions and 74 deletions.
5 changes: 5 additions & 0 deletions internal/errors/ngt.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ var (
return Errorf("dimension size %d is invalid, the supporting dimension size must be between 2 ~ %d", current, limit)
}

// ErrInvalidUUID represents a function to generate an error that the uuid is invalid.
ErrInvalidUUID = func(uuid string) error {
return Errorf("uuid \"%s\" is invalid", uuid)
}

// ErrDimensionLimitExceed represents a function to generate an error that the supported dimension limit exceeded.
ErrDimensionLimitExceed = func(current, limit int) error {
return Errorf("supported dimension limit exceed:\trequired = %d,\tlimit = %d", current, limit)
Expand Down
63 changes: 63 additions & 0 deletions internal/errors/ngt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,69 @@ func TestErrInvalidDimensionSize(t *testing.T) {
}
}

func TestErrInvalidUUID(t *testing.T) {
type args struct {
uuid string
}
type want struct {
want error
}
type test struct {
name string
args args
want want
checkFunc func(want, error) error
beforeFunc func(args)
afterFunc func(args)
}
defaultCheckFunc := func(w want, got error) error {
if !Is(got, w.want) {
return Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want)
}
return nil
}
tests := []test{
{
name: "return an ErrInvalidUUID error when uuid is empty string",
args: args{
uuid: "",
},
want: want{
want: New("uuid \"\" is invalid"),
},
},
{
name: "return an ErrInvalidUUID error when uuid is foo",
args: args{
uuid: "foo",
},
want: want{
want: New("uuid \"foo\" is invalid"),
},
},
}

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
if test.afterFunc != nil {
defer test.afterFunc(test.args)
}
checkFunc := test.checkFunc
if test.checkFunc == nil {
checkFunc = defaultCheckFunc
}

got := ErrInvalidUUID(test.args.uuid)
if err := checkFunc(test.want, got); err != nil {
tt.Errorf("error = %v", err)
}
})
}
}

func TestErrDimensionLimitExceed(t *testing.T) {
type args struct {
current int
Expand Down
164 changes: 161 additions & 3 deletions pkg/agent/core/ngt/handler/grpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,28 @@ func (s *server) Exists(ctx context.Context, uid *payload.Object_ID) (res *paylo
}
}()
uuid := uid.GetId()
if len(uuid) == 0 {
err = errors.ErrInvalidUUID(uuid)
err = status.WrapWithInvalidArgument(fmt.Sprintf("Exists API invalid argument for uuid \"%s\" detected", uuid), err,
&errdetails.RequestInfo{
RequestId: uuid,
ServingData: errdetails.Serialize(uid),
},
&errdetails.BadRequest{
FieldViolations: []*errdetails.BadRequestFieldViolation{
{
Field: "uuid",
Description: err.Error(),
},
},
},
&errdetails.ResourceInfo{
ResourceType: ngtResourceType + "/ngt.Exists",
ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
})
log.Warn(err)
return nil, err
}
oid, ok := s.ngt.Exists(uuid)
if !ok {
err = errors.ErrObjectIDNotFound(uid.GetId())
Expand Down Expand Up @@ -267,8 +289,31 @@ func (s *server) SearchByID(ctx context.Context, req *payload.Search_IDRequest)
span.End()
}
}()
uuid := req.GetId()
if len(uuid) == 0 {
err = errors.ErrInvalidUUID(uuid)
err = status.WrapWithInvalidArgument(fmt.Sprintf("SearchByID API invalid argument for uuid \"%s\" detected", uuid), err,
&errdetails.RequestInfo{
RequestId: uuid,
ServingData: errdetails.Serialize(req),
},
&errdetails.BadRequest{
FieldViolations: []*errdetails.BadRequestFieldViolation{
{
Field: "uuid",
Description: err.Error(),
},
},
},
&errdetails.ResourceInfo{
ResourceType: ngtResourceType + "/ngt.SearchByID",
ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
})
log.Warn(err)
return nil, err
}
vec, dst, err := s.ngt.SearchByID(
req.GetId(),
uuid,
req.GetConfig().GetNum(),
req.GetConfig().GetEpsilon(),
req.GetConfig().GetRadius())
Expand Down Expand Up @@ -747,8 +792,31 @@ func (s *server) LinearSearchByID(ctx context.Context, req *payload.Search_IDReq
span.End()
}
}()
uuid := req.GetId()
if len(uuid) == 0 {
err = errors.ErrInvalidUUID(uuid)
err = status.WrapWithInvalidArgument(fmt.Sprintf("LinearSearchByID API invalid argument for uuid \"%s\" detected", uuid), err,
&errdetails.RequestInfo{
RequestId: uuid,
ServingData: errdetails.Serialize(req),
},
&errdetails.BadRequest{
FieldViolations: []*errdetails.BadRequestFieldViolation{
{
Field: "uuid",
Description: err.Error(),
},
},
},
&errdetails.ResourceInfo{
ResourceType: ngtResourceType + "/ngt.LinearSearchByID",
ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
})
log.Warn(err)
return nil, err
}
vec, dst, err := s.ngt.LinearSearchByID(
req.GetId(),
uuid,
req.GetConfig().GetNum())
res, err = toSearchResponse(dst, err)
if err != nil || res == nil {
Expand Down Expand Up @@ -1361,7 +1429,30 @@ func (s *server) Update(ctx context.Context, req *payload.Update_Request) (res *
}
return nil, err
}
err = s.ngt.UpdateWithTime(vec.GetId(), vec.GetVector(), req.GetConfig().GetTimestamp())
uuid := vec.GetId()
if len(uuid) == 0 {
err = errors.ErrInvalidUUID(uuid)
err = status.WrapWithInvalidArgument(fmt.Sprintf("Update API invalid argument for uuid \"%s\" detected", uuid), err,
&errdetails.RequestInfo{
RequestId: uuid,
ServingData: errdetails.Serialize(req),
},
&errdetails.BadRequest{
FieldViolations: []*errdetails.BadRequestFieldViolation{
{
Field: "uuid",
Description: err.Error(),
},
},
},
&errdetails.ResourceInfo{
ResourceType: ngtResourceType + "/ngt.Update",
ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
})
log.Warn(err)
return nil, err
}
err = s.ngt.UpdateWithTime(uuid, vec.GetVector(), req.GetConfig().GetTimestamp())
if err != nil {
var code trace.Status
if errors.Is(err, errors.ErrObjectIDNotFound(vec.GetId())) {
Expand Down Expand Up @@ -1643,6 +1734,29 @@ func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc *
}
return nil, err
}
uuid := vec.GetId()
if len(uuid) == 0 {
err = errors.ErrInvalidUUID(uuid)
err = status.WrapWithInvalidArgument(fmt.Sprintf("Upsert API invalid argument for uuid \"%s\" detected", uuid), err,
&errdetails.RequestInfo{
RequestId: uuid,
ServingData: errdetails.Serialize(req),
},
&errdetails.BadRequest{
FieldViolations: []*errdetails.BadRequestFieldViolation{
{
Field: "uuid",
Description: err.Error(),
},
},
},
&errdetails.ResourceInfo{
ResourceType: ngtResourceType + "/ngt.Upsert",
ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
})
log.Warn(err)
return nil, err
}

rtName := "/ngt.Upsert"
_, exists := s.ngt.Exists(req.GetVector().GetId())
Expand Down Expand Up @@ -1888,6 +2002,28 @@ func (s *server) Remove(ctx context.Context, req *payload.Remove_Request) (res *
}()
id := req.GetId()
uuid := id.GetId()
if len(uuid) == 0 {
err = errors.ErrInvalidUUID(uuid)
err = status.WrapWithInvalidArgument(fmt.Sprintf("Remove API invalid argument for uuid \"%s\" detected", uuid), err,
&errdetails.RequestInfo{
RequestId: uuid,
ServingData: errdetails.Serialize(req),
},
&errdetails.BadRequest{
FieldViolations: []*errdetails.BadRequestFieldViolation{
{
Field: "uuid",
Description: err.Error(),
},
},
},
&errdetails.ResourceInfo{
ResourceType: ngtResourceType + "/ngt.Remove",
ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
})
log.Warn(err)
return nil, err
}
err = s.ngt.DeleteWithTime(uuid, req.GetConfig().GetTimestamp())
if err != nil {
var code trace.Status
Expand Down Expand Up @@ -2073,6 +2209,28 @@ func (s *server) GetObject(ctx context.Context, id *payload.Object_VectorRequest
}
}()
uuid := id.GetId().GetId()
if len(uuid) == 0 {
err = errors.ErrInvalidUUID(uuid)
err = status.WrapWithInvalidArgument(fmt.Sprintf("GetObject API invalid argument for uuid \"%s\" detected", uuid), err,
&errdetails.RequestInfo{
RequestId: uuid,
ServingData: errdetails.Serialize(id),
},
&errdetails.BadRequest{
FieldViolations: []*errdetails.BadRequestFieldViolation{
{
Field: "uuid",
Description: err.Error(),
},
},
},
&errdetails.ResourceInfo{
ResourceType: ngtResourceType + "/ngt.GetObject",
ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
})
log.Warn(err)
return nil, err
}
vec, err := s.ngt.GetObject(uuid)
if err != nil || vec == nil {
err = errors.ErrObjectNotFound(err, uuid)
Expand Down
Loading

0 comments on commit d55e2db

Please sign in to comment.