Skip to content

Commit

Permalink
Add StreamListObject to LB (#2203)
Browse files Browse the repository at this point in the history
* Add StreamListObject to LB

* Add E2E for StreamListObject

* Update error handling

* Fix StreamListObject e2e verification

* Update internal/errors/grpc.go

Co-authored-by: Kiichiro YUKAWA <[email protected]>

---------

Co-authored-by: Kiichiro YUKAWA <[email protected]>
  • Loading branch information
ykadowak and vankichi authored Oct 11, 2023
1 parent 366cd02 commit 9f27ff2
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 3 deletions.
10 changes: 10 additions & 0 deletions internal/errors/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,14 @@ var (
ErrInvalidProtoMessageType = func(v interface{}) error {
return Errorf("failed to marshal/unmarshal proto message, message type is %T (missing vtprotobuf/protobuf helpers)", v)
}

// ErrServerStreamClientRecv represents a function to generate an error that the gRPC client couldn't receive from stream.
ErrServerStreamClientRecv = func(err error) error {
return Wrap(err, "gRPC client failed to receive from stream")
}

// ErrServerStreamClientSend represents a function to generate an error that the gRPC server couldn't send to stream.
ErrServerStreamServerSend = func(err error) error {
return Wrap(err, "gRPC server failed to send to stream")
}
)
90 changes: 89 additions & 1 deletion pkg/gateway/lb/handler/grpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package grpc
import (
"context"
"fmt"
"io"
"slices"
"strconv"
"sync/atomic"
Expand Down Expand Up @@ -2907,7 +2908,7 @@ func (s *server) getObject(ctx context.Context, uuid string) (vec *payload.Objec
ech <- s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error {
sctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/getObject/BroadCast/"+target)
defer func() {
if span != nil {
if sspan != nil {
sspan.End()
}
}()
Expand Down Expand Up @@ -3134,3 +3135,90 @@ func (s *server) StreamGetObject(stream vald.Object_StreamGetObjectServer) (err
}
return nil
}

func (s *server) StreamListObject(req *payload.Object_List_Request, stream vald.Object_StreamListObjectServer) error {
ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.ObjectRPCServiceName+"/"+vald.StreamListObjectRPCName), apiName+"/"+vald.StreamListObjectRPCName)
defer func() {
if span != nil {
span.End()
}
}()

ctx, cancel := context.WithCancel(ctx)
defer cancel()

var rmu, smu sync.Mutex
err := s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error {
ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.StreamListObjectRPCName+"/"+target)
defer func() {
if sspan != nil {
sspan.End()
}
}()

client, err := vc.StreamListObject(ctx, req, copts...)
if err != nil {
log.Errorf("failed to get StreamListObject client for agent(%s): %v", target, err)
return err
}

eg, ctx := errgroup.WithContext(ctx)
ectx, ecancel := context.WithCancel(ctx)
defer ecancel()
eg.SetLimit(s.streamConcurrency)

for {
select {
case <-ectx.Done():
var err error
if !errors.Is(ctx.Err(), context.Canceled) {
err = errors.Join(err, ctx.Err())
}
if egerr := eg.Wait(); err != nil {
err = errors.Join(err, egerr)
}
return err
default:
eg.Go(safety.RecoverFunc(func() error {
rmu.Lock()
res, err := client.Recv()
rmu.Unlock()
if err != nil {
if errors.Is(err, io.EOF) {
ecancel()
return nil
}
return errors.ErrServerStreamClientRecv(err)
}

vec := res.GetVector()
if vec == nil {
st := res.GetStatus()
log.Warnf("received empty vector: code %v: details %v: message %v",
st.GetCode(),
st.GetDetails(),
st.GetMessage(),
)
return nil
}

smu.Lock()
err = stream.Send(res)
smu.Unlock()
if err != nil {
if sspan != nil {
st, msg, err := status.ParseError(err, codes.Internal, "failed to parse StreamListObject send gRPC error response")
sspan.RecordError(err)
sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...)
sspan.SetStatus(trace.StatusError, err.Error())
}
return errors.ErrServerStreamServerSend(err)
}

return nil
}))
}
}
})
return err
}
4 changes: 2 additions & 2 deletions pkg/gateway/lb/service/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ type Gateway interface {
GetAgentCount(ctx context.Context) int
Addrs(ctx context.Context) []string
DoMulti(ctx context.Context, num int,
f func(ctx context.Context, tgt string, ac vald.Client, copts ...grpc.CallOption) error) error
f func(ctx context.Context, target string, ac vald.Client, copts ...grpc.CallOption) error) error
BroadCast(ctx context.Context,
f func(ctx context.Context, tgt string, ac vald.Client, copts ...grpc.CallOption) error) error
f func(ctx context.Context, target string, ac vald.Client, copts ...grpc.CallOption) error) error
}

type gateway struct {
Expand Down
7 changes: 7 additions & 0 deletions tests/e2e/crud/crud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,13 @@ func TestE2EStandardCRUD(t *testing.T) {
t.Fatalf("an error occurred: %s", err)
}

err = op.StreamListObject(t, ctx, operation.Dataset{
Train: ds.Train[insertFrom : insertFrom+insertNum],
})
if err != nil {
t.Fatalf("an error occurred: %s", err)
}

err = op.Update(t, ctx, operation.Dataset{
Train: ds.Train[updateFrom : updateFrom+updateNum],
})
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/operation/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type Client interface {
MultiUpsert(t *testing.T, ctx context.Context, ds Dataset) error
MultiRemove(t *testing.T, ctx context.Context, ds Dataset) error
GetObject(t *testing.T, ctx context.Context, ds Dataset) error
StreamListObject(t *testing.T, ctx context.Context, ds Dataset) error
Exists(t *testing.T, ctx context.Context, id string) error
CreateIndex(t *testing.T, ctx context.Context) error
SaveIndex(t *testing.T, ctx context.Context) error
Expand Down
59 changes: 59 additions & 0 deletions tests/e2e/operation/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package operation

import (
"context"
"fmt"
"reflect"
"strconv"
"testing"
Expand Down Expand Up @@ -1167,3 +1168,61 @@ func (c *client) GetObject(

return rerr
}

func (c *client) StreamListObject(
t *testing.T,
ctx context.Context,
ds Dataset,
) error {
t.Log("StreamListObject operation started")

client, err := c.getClient(ctx)
if err != nil {
return err
}

sc, err := client.StreamListObject(ctx, &payload.Object_List_Request{})
if err != nil {
return err
}

// kv : [indexId]count
indexCnt := make(map[string]int)
exit_loop:
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
res, err := sc.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break exit_loop
}
return err
}
vec := res.GetVector()
if vec == nil {
st := res.GetStatus()
return fmt.Errorf("returned vector is empty: code: %v, msg: %v, details: %v", st.GetCode(), st.GetMessage(), st.GetDetails())
}
indexCnt[vec.GetId()]++
}
}

if len(indexCnt) != len(ds.Train) {
return fmt.Errorf("the number of vectors returned is different: got %v, want %v", len(indexCnt), len(ds.Train))
}

replica := -1
for k, v := range indexCnt {
if replica == -1 {
replica = v
continue
}
if v != replica {
return fmt.Errorf("the number of vectors returned is different at index id %v: got %v, want %v", k, v, replica)
}
}
return nil
}

0 comments on commit 9f27ff2

Please sign in to comment.