diff --git a/internal/errors/grpc.go b/internal/errors/grpc.go index b14c754fd5..a5675de8d6 100644 --- a/internal/errors/grpc.go +++ b/internal/errors/grpc.go @@ -70,4 +70,14 @@ var ( ErrInvalidProtoMessageType = func(v interface{}) error { return Errorf("failed to marshal/unmarshal proto message, message type is %T (missing vtprotobuf/protobuf helpers)", v) } + + // ErrServerStreamClientRecv represents a function to generate an error that the gRPC client couldn't receive from stream. + ErrServerStreamClientRecv = func(err error) error { + return Wrap(err, "gRPC client failed to receive from stream") + } + + // ErrServerStreamClientSend represents a function to generate an error that the gRPC server couldn't send to stream. + ErrServerStreamServerSend = func(err error) error { + return Wrap(err, "gRPC server failed to send to stream") + } ) diff --git a/pkg/gateway/lb/handler/grpc/handler.go b/pkg/gateway/lb/handler/grpc/handler.go index 27f5bf8c90..0554a9c2cb 100644 --- a/pkg/gateway/lb/handler/grpc/handler.go +++ b/pkg/gateway/lb/handler/grpc/handler.go @@ -20,6 +20,7 @@ package grpc import ( "context" "fmt" + "io" "slices" "strconv" "sync/atomic" @@ -2907,7 +2908,7 @@ func (s *server) getObject(ctx context.Context, uuid string) (vec *payload.Objec ech <- s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { sctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/getObject/BroadCast/"+target) defer func() { - if span != nil { + if sspan != nil { sspan.End() } }() @@ -3134,3 +3135,90 @@ func (s *server) StreamGetObject(stream vald.Object_StreamGetObjectServer) (err } return nil } + +func (s *server) StreamListObject(req *payload.Object_List_Request, stream vald.Object_StreamListObjectServer) error { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.ObjectRPCServiceName+"/"+vald.StreamListObjectRPCName), apiName+"/"+vald.StreamListObjectRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var rmu, smu sync.Mutex + err := s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.StreamListObjectRPCName+"/"+target) + defer func() { + if sspan != nil { + sspan.End() + } + }() + + client, err := vc.StreamListObject(ctx, req, copts...) + if err != nil { + log.Errorf("failed to get StreamListObject client for agent(%s): %v", target, err) + return err + } + + eg, ctx := errgroup.WithContext(ctx) + ectx, ecancel := context.WithCancel(ctx) + defer ecancel() + eg.SetLimit(s.streamConcurrency) + + for { + select { + case <-ectx.Done(): + var err error + if !errors.Is(ctx.Err(), context.Canceled) { + err = errors.Join(err, ctx.Err()) + } + if egerr := eg.Wait(); err != nil { + err = errors.Join(err, egerr) + } + return err + default: + eg.Go(safety.RecoverFunc(func() error { + rmu.Lock() + res, err := client.Recv() + rmu.Unlock() + if err != nil { + if errors.Is(err, io.EOF) { + ecancel() + return nil + } + return errors.ErrServerStreamClientRecv(err) + } + + vec := res.GetVector() + if vec == nil { + st := res.GetStatus() + log.Warnf("received empty vector: code %v: details %v: message %v", + st.GetCode(), + st.GetDetails(), + st.GetMessage(), + ) + return nil + } + + smu.Lock() + err = stream.Send(res) + smu.Unlock() + if err != nil { + if sspan != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse StreamListObject send gRPC error response") + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return errors.ErrServerStreamServerSend(err) + } + + return nil + })) + } + } + }) + return err +} diff --git a/pkg/gateway/lb/service/gateway.go b/pkg/gateway/lb/service/gateway.go index a58cd8a5f3..7d32364df0 100644 --- a/pkg/gateway/lb/service/gateway.go +++ b/pkg/gateway/lb/service/gateway.go @@ -36,9 +36,9 @@ type Gateway interface { GetAgentCount(ctx context.Context) int Addrs(ctx context.Context) []string DoMulti(ctx context.Context, num int, - f func(ctx context.Context, tgt string, ac vald.Client, copts ...grpc.CallOption) error) error + f func(ctx context.Context, target string, ac vald.Client, copts ...grpc.CallOption) error) error BroadCast(ctx context.Context, - f func(ctx context.Context, tgt string, ac vald.Client, copts ...grpc.CallOption) error) error + f func(ctx context.Context, target string, ac vald.Client, copts ...grpc.CallOption) error) error } type gateway struct { diff --git a/tests/e2e/crud/crud_test.go b/tests/e2e/crud/crud_test.go index f78a99c6aa..b5a774b88b 100644 --- a/tests/e2e/crud/crud_test.go +++ b/tests/e2e/crud/crud_test.go @@ -375,6 +375,13 @@ func TestE2EStandardCRUD(t *testing.T) { t.Fatalf("an error occurred: %s", err) } + err = op.StreamListObject(t, ctx, operation.Dataset{ + Train: ds.Train[insertFrom : insertFrom+insertNum], + }) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + err = op.Update(t, ctx, operation.Dataset{ Train: ds.Train[updateFrom : updateFrom+updateNum], }) diff --git a/tests/e2e/operation/operation.go b/tests/e2e/operation/operation.go index 6cfc078ee8..24cc988f7e 100644 --- a/tests/e2e/operation/operation.go +++ b/tests/e2e/operation/operation.go @@ -127,6 +127,7 @@ type Client interface { MultiUpsert(t *testing.T, ctx context.Context, ds Dataset) error MultiRemove(t *testing.T, ctx context.Context, ds Dataset) error GetObject(t *testing.T, ctx context.Context, ds Dataset) error + StreamListObject(t *testing.T, ctx context.Context, ds Dataset) error Exists(t *testing.T, ctx context.Context, id string) error CreateIndex(t *testing.T, ctx context.Context) error SaveIndex(t *testing.T, ctx context.Context) error diff --git a/tests/e2e/operation/stream.go b/tests/e2e/operation/stream.go index 585039dc2b..dd37d6ad67 100644 --- a/tests/e2e/operation/stream.go +++ b/tests/e2e/operation/stream.go @@ -17,6 +17,7 @@ package operation import ( "context" + "fmt" "reflect" "strconv" "testing" @@ -1167,3 +1168,61 @@ func (c *client) GetObject( return rerr } + +func (c *client) StreamListObject( + t *testing.T, + ctx context.Context, + ds Dataset, +) error { + t.Log("StreamListObject operation started") + + client, err := c.getClient(ctx) + if err != nil { + return err + } + + sc, err := client.StreamListObject(ctx, &payload.Object_List_Request{}) + if err != nil { + return err + } + + // kv : [indexId]count + indexCnt := make(map[string]int) +exit_loop: + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + res, err := sc.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break exit_loop + } + return err + } + vec := res.GetVector() + if vec == nil { + st := res.GetStatus() + return fmt.Errorf("returned vector is empty: code: %v, msg: %v, details: %v", st.GetCode(), st.GetMessage(), st.GetDetails()) + } + indexCnt[vec.GetId()]++ + } + } + + if len(indexCnt) != len(ds.Train) { + return fmt.Errorf("the number of vectors returned is different: got %v, want %v", len(indexCnt), len(ds.Train)) + } + + replica := -1 + for k, v := range indexCnt { + if replica == -1 { + replica = v + continue + } + if v != replica { + return fmt.Errorf("the number of vectors returned is different at index id %v: got %v, want %v", k, v, replica) + } + } + return nil +}