Skip to content

Commit

Permalink
fix: refactor stream list object rpc
Browse files Browse the repository at this point in the history
Signed-off-by: hlts2 <[email protected]>
  • Loading branch information
hlts2 committed Nov 21, 2023
1 parent 374dc57 commit 9e87530
Showing 1 changed file with 112 additions and 47 deletions.
159 changes: 112 additions & 47 deletions pkg/gateway/mirror/handler/grpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"
"io"
"reflect"
"sync/atomic"

"github.com/vdaas/vald/apis/grpc/v1/payload"
"github.com/vdaas/vald/apis/grpc/v1/vald"
Expand Down Expand Up @@ -3037,7 +3038,7 @@ func (s *server) StreamListObject(req *payload.Object_List_Request, stream vald.
}
}()

_, err := s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) {
_, err := s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) (obj interface{}, err error) {
ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "Do/"+target), apiName+"/"+vald.StreamListObjectRPCName+"/"+target)
defer func() {
if span != nil {
Expand All @@ -3048,52 +3049,7 @@ func (s *server) StreamListObject(req *payload.Object_List_Request, stream vald.
if err != nil {
return nil, err
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()
eg, egctx := errgroup.WithContext(ctx)
eg.SetLimit(s.streamConcurrency)

var (
mu sync.Mutex
emu sync.Mutex
errs = make([]error, 0, s.streamConcurrency)
)
finalize := func() error {
err := eg.Wait()
if err != nil {

}
return nil
}
for {
select {
case <-egctx.Done():
return nil, finalize()
default:
res, err := client.Recv()
if err != nil {
if err != io.EOF && errors.Is(err, io.EOF) {
return nil, finalize()
}
return nil, errors.Join(err, finalize())
}
if res != nil {
eg.Go(safety.RecoverFunc(func() (err error) {
// TODO: add trace
mu.Lock()
err = stream.Send(res)
mu.Unlock()
if err != nil {
emu.Lock()
errs = append(errs, err)
emu.Unlock()
}
return nil
}))
}
}
}
return obj, s.streamListObject(ctx, client, stream)
})
if err != nil {
if span != nil {
Expand All @@ -3107,6 +3063,115 @@ func (s *server) StreamListObject(req *payload.Object_List_Request, stream vald.
return nil
}

func (s *server) streamListObject(ctx context.Context, clientS vald.Object_StreamListObjectClient, serverS vald.Object_StreamListObjectServer) (err error) {
cctx, cancel := context.WithCancel(ctx)
defer cancel()
eg, egctx := errgroup.WithContext(cctx)
eg.SetLimit(s.streamConcurrency)

var mu, rmu sync.Mutex
var egCnt int64
for {
select {
case <-egctx.Done():
// If the root context is not canceld error, it is treated as an error.
if ctx.Err() != nil && !errors.Is(ctx.Err(), context.Canceled) {
err = errors.Join(ctx.Err(), err)
}
if egerr := eg.Wait(); egerr != nil {
err = errors.Join(err, egerr)
}
return err
default:
eg.Go(safety.RecoverFunc(func() (err error) {
id := fmt.Sprintf("stream-%020d", atomic.AddInt64(&egCnt, 1))
_, span := trace.StartSpan(egctx, apiName+"/streamListObject/"+id)
defer func() {
if span != nil {
span.End()
}
}()

rmu.Lock()
res, err := clientS.Recv()
rmu.Unlock()
if err != nil {
if errors.Is(err, io.EOF) {
cancel()
return nil
}
err = errors.ErrServerStreamClientRecv(err)
var attr trace.Attributes
switch {
case errors.Is(err, context.Canceled):
err = status.WrapWithCanceled("Stream Recv returned canceld error at "+id, err)
attr = trace.StatusCodeCancelled(err.Error())
case errors.Is(err, context.DeadlineExceeded):
err = status.WrapWithDeadlineExceeded("Stream Recv returned deadlin exceeded error at "+id, err)
attr = trace.StatusCodeDeadlineExceeded(err.Error())
default:
var (
st *status.Status
msg string
)
st, msg, err = status.ParseError(err, codes.Internal, "Stream Recv returned an error at "+id)
if st != nil {
attr = trace.FromGRPCStatus(st.Code(), msg)
}
}
log.Warn(err)
if span != nil {
span.RecordError(err)
span.SetAttributes(attr...)
span.SetStatus(trace.StatusError, err.Error())
}
return err
}
if res.GetVector() == nil {
return nil
}

mu.Lock()
err = serverS.Send(res)
mu.Unlock()
if err != nil {
if errors.Is(err, io.EOF) {
cancel()
return nil
}
err = errors.ErrServerStreamServerSend(err)
var attr trace.Attributes
switch {
case errors.Is(err, context.Canceled):
err = status.WrapWithCanceled("Stream Send returned canceld error at "+id, err)
attr = trace.StatusCodeCancelled(err.Error())
case errors.Is(err, context.DeadlineExceeded):
err = status.WrapWithDeadlineExceeded("Stream Send returned deadlin exceeded error at "+id, err)
attr = trace.StatusCodeDeadlineExceeded(err.Error())
default:
var (
st *status.Status
msg string
)
st, msg, err = status.ParseError(err, codes.Internal, "Stream Send returned an error at "+id)
if st != nil {
attr = trace.FromGRPCStatus(st.Code(), msg)
}
}
log.Warn(err)
if span != nil {
span.RecordError(err)
span.SetAttributes(attr...)
span.SetStatus(trace.StatusError, err.Error())
}
return err
}
return nil
}))
}
}
}

type errorState struct {
err error
code codes.Code
Expand Down

0 comments on commit 9e87530

Please sign in to comment.