diff --git a/internal/net/grpc/stream.go b/internal/net/grpc/stream.go index eac59c976c..631d139576 100644 --- a/internal/net/grpc/stream.go +++ b/internal/net/grpc/stream.go @@ -22,6 +22,7 @@ import ( "context" "fmt" "runtime" + "slices" "sync/atomic" "github.com/vdaas/vald/internal/errors" @@ -32,7 +33,6 @@ import ( "github.com/vdaas/vald/internal/net/grpc/status" "github.com/vdaas/vald/internal/observability/trace" "github.com/vdaas/vald/internal/safety" - "github.com/vdaas/vald/internal/slices" "github.com/vdaas/vald/internal/sync" "github.com/vdaas/vald/internal/sync/errgroup" "google.golang.org/grpc" @@ -75,7 +75,7 @@ func BidirectionalStream[Q any, R any](ctx context.Context, stream ServerStream, errs = append(errs, err) emu.Unlock() } - slices.RemoveDuplicates(errs, func(left, right error) int { + removeDuplicates(errs, func(left, right error) int { return cmp.Compare(left.Error(), right.Error()) }) emu.Lock() @@ -229,3 +229,11 @@ func BidirectionalStreamClient(stream ClientStream, } }() } + +func removeDuplicates[E comparable](x []E, less func(left, right E) int) []E { + if len(x) < 2 { + return x + } + slices.SortStableFunc(x, less) + return slices.Compact(x) +} diff --git a/pkg/gateway/lb/handler/grpc/aggregation.go b/pkg/gateway/lb/handler/grpc/aggregation.go index e29a4a46bc..8a700c8edd 100644 --- a/pkg/gateway/lb/handler/grpc/aggregation.go +++ b/pkg/gateway/lb/handler/grpc/aggregation.go @@ -18,6 +18,7 @@ import ( "fmt" "math" "math/big" + "slices" "sync/atomic" "time" @@ -31,7 +32,6 @@ import ( "github.com/vdaas/vald/internal/net/grpc/errdetails" "github.com/vdaas/vald/internal/net/grpc/status" "github.com/vdaas/vald/internal/observability/trace" - "github.com/vdaas/vald/internal/slices" "github.com/vdaas/vald/internal/sync" ) @@ -557,7 +557,7 @@ func newSlice(num, replica int) Aggregator { } } -func (_ *valdSliceAggr) Start(_ context.Context) {} +func (*valdSliceAggr) Start(_ context.Context) {} func (v *valdSliceAggr) Send(ctx context.Context, data *payload.Search_Response) { result := data.GetResults() @@ -580,7 +580,7 @@ func (v *valdSliceAggr) Send(ctx context.Context, data *payload.Search_Response) } func (v *valdSliceAggr) Result() (res *payload.Search_Response) { - slices.RemoveDuplicates(v.result, func(l, r *DistPayload) int { + removeDuplicates(v.result, func(l, r *DistPayload) int { return l.distance.Cmp(r.distance) }) @@ -646,7 +646,7 @@ func (v *valdPoolSliceAggr) Send(ctx context.Context, data *payload.Search_Respo } func (v *valdPoolSliceAggr) Result() (res *payload.Search_Response) { - slices.RemoveDuplicates(v.result, func(l, r *DistPayload) int { + removeDuplicates(v.result, func(l, r *DistPayload) int { return l.distance.Cmp(r.distance) }) @@ -662,3 +662,11 @@ func (v *valdPoolSliceAggr) Result() (res *payload.Search_Response) { poolDist.Put(v.result[:0]) return res } + +func removeDuplicates[E comparable](x []E, less func(left, right E) int) []E { + if len(x) < 2 { + return x + } + slices.SortStableFunc(x, less) + return slices.Compact(x) +}