diff --git a/pkg/gateway/mirror/handler/grpc/handler.go b/pkg/gateway/mirror/handler/grpc/handler.go index 014b0fa551..cd3a43cf7c 100644 --- a/pkg/gateway/mirror/handler/grpc/handler.go +++ b/pkg/gateway/mirror/handler/grpc/handler.go @@ -943,59 +943,31 @@ func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc * reqSrcPodName := s.gateway.FromForwardedContext(ctx) - // When this condition is matched, the request is proxied to another Mirror gateway. + // If this condition is matched, it means that the request was proxied from another Mirror Gateway. // So this component sends requests only to the Vald gateway (LB gateway) of its own cluster. if len(reqSrcPodName) != 0 { - _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { - loc, err = vc.Insert(ctx, req, copts...) - return loc, err + loc, err = s.doInsert(ctx, req, func(ctx context.Context) (*payload.Object_Location, error) { + _, derr := s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + loc, err = vc.Insert(ctx, req, copts...) + return loc, err + }) + return loc, errors.Join(derr, err) }) if err != nil { reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), + RequestId: req.GetVector().GetId(), } resInfo := &errdetails.ResourceInfo{ ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName, ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), } - var attrs trace.Attributes - - switch { - case errors.Is(err, context.Canceled): - err = status.WrapWithCanceled( - vald.InsertRPCName+" API canceld", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeCancelled(err.Error()) - case errors.Is(err, context.DeadlineExceeded): - err = status.WrapWithDeadlineExceeded( - vald.InsertRPCName+" API deadline exceeded", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeDeadlineExceeded(err.Error()) - case errors.Is(err, errors.ErrTargetNotFound): - err = status.WrapWithInternal( - vald.InsertRPCName+" API target not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): - err = status.WrapWithInternal( - vald.InsertRPCName+" API connection not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - default: - var ( - st *status.Status - msg string - ) - st, msg, err = status.ParseError(err, codes.Internal, - "failed to parse "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, - ) - attrs = trace.FromGRPCStatus(st.Code(), msg) - } + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, + ) log.Warn(err) if span != nil { span.RecordError(err) - span.SetAttributes(attrs...) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) span.SetStatus(trace.StatusError, err.Error()) } return nil, err @@ -1004,8 +976,11 @@ func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc * return loc, nil } + // If this condition is matched, it means the request from user. + // So this component sends requests to other Mirror gateways and the Vald gateway (LB gateway) of its own cluster. + var mu sync.Mutex - var result sync.Map[string, error] + var result sync.Map[string, *errorState] // map[target host: error state] loc = &payload.Object_Location{ Uuid: req.GetVector().GetId(), Ips: make([]string, 0), @@ -1018,23 +993,32 @@ func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc * } }() - ce, err := s.insert(ctx, vc, req, copts...) + code := codes.OK + ce, err := s.doInsert(ctx, req, func(ctx context.Context) (*payload.Object_Location, error) { + return vc.Insert(ctx, req, copts...) + }) if err != nil { - st, _, _ := status.ParseError(err, codes.Internal, + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, "failed to parse "+vald.InsertRPCName+" gRPC error response", &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), + RequestId: req.GetVector().GetId(), }, &errdetails.ResourceInfo{ ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + ".BroadCast/" + target, ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), }, ) - if st.Code() == codes.AlreadyExists { - // NOTE: If it is strictly necessary to check, fix this logic. - return nil + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) } + code = st.Code() } if err == nil && ce != nil { mu.Lock() @@ -1042,13 +1026,12 @@ func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc * loc.Ips = append(loc.Ips, ce.GetIps()...) mu.Unlock() } - result.Store(target, err) + result.Store(target, &errorState{err, code}) return err }) if err != nil { reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), + RequestId: req.GetVector().GetId(), } resInfo := &errdetails.ResourceInfo{ ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + ".BroadCast", @@ -1080,33 +1063,47 @@ func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc * return nil, err } - var errs error - targets := make([]string, 0, 10) - result.Range(func(target string, err error) bool { - if err == nil { - targets = append(targets, target) - } else { - errs = errors.Join(errs, err) + alreadyExistsTgts := make([]string, 0, result.Len()/2) + successTgts := make([]string, 0, result.Len()/2) + result.Range(func(target string, es *errorState) bool { + switch { + case es.err == nil: + successTgts = append(successTgts, target) + case es.code == codes.AlreadyExists: + alreadyExistsTgts = append(alreadyExistsTgts, target) + err = errors.Join(err, es.err) + default: + err = errors.Join(es.err, err) } return true }) - switch { - case errs == nil: - log.Debugf("Insert API mirror request succeeded to %#v", loc) + if err == nil { + log.Debugf(vald.InsertRPCName+" API request succeeded to %#v", loc) return loc, nil - case len(targets) == 0 && errs != nil: - log.Error("failed to Insert API mirror request: %v and can not rollback because success target length is 0", errs) - st, msg, err := status.ParseError(errs, codes.Internal, - "failed to parse "+vald.InsertRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, + } + + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + + switch { + case result.Len() == len(alreadyExistsTgts): + err = status.WrapWithAlreadyExists(vald.InsertRPCName+" API target same vector already exists", err, reqInfo, resInfo) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeAlreadyExists(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + case result.Len() > len(successTgts)+len(alreadyExistsTgts): // Contains errors except for ALREADY_EXIST. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, ) + log.Warn(err) if span != nil { span.RecordError(err) span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) @@ -1114,60 +1111,72 @@ func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc * } return nil, err } - log.Error("failed to Insert API mirror request: %v, so starts the rollback request", errs) - var emu sync.Mutex - var rerrs error - rmReq := &payload.Remove_Request{ - Id: &payload.Object_ID{ - Id: req.GetVector().GetId(), + // In this case, the status code in the result object contains only OK or ALREADY_EXIST. + // And send Update API requst to ALREADY_EXIST cluster using the query requested by the user. + log.Warnf("failed to "+vald.InsertRPCName+" API: %#v", err) + + updateReq := &payload.Update_Request{ + Vector: req.GetVector(), + Config: &payload.Update_Config{ + Timestamp: req.GetConfig().GetTimestamp(), }, } - err = s.gateway.DoMulti(ctx, targets, - func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "rollback/BroadCast/"+target), apiName+"/"+vald.InsertRPCName+"/rollback/"+target) - defer func() { - if span != nil { - span.End() - } - }() + err = s.gateway.DoMulti(ctx, alreadyExistsTgts, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "DoMulti/"+target), apiName+"/"+vald.UpdateRPCName+"/"+target) + defer func() { + if span != nil { + span.End() + } + }() - _, err := s.remove(ctx, vc, rmReq, copts...) - if err != nil { - st, _, err := status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" for "+vald.InsertRPCName+" error response for "+target, - &errdetails.RequestInfo{ - RequestId: rmReq.GetId().GetId(), - ServingData: errdetails.Serialize(rmReq), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + "." + vald.RemoveRPCName + ".BroadCast/" + target, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), - }, - ) - if st.Code() == codes.NotFound { - return nil - } - emu.Lock() - rerrs = errors.Join(rerrs, err) - emu.Unlock() - return err + code := codes.OK + ce, err := s.doUpdate(ctx, updateReq, func(ctx context.Context) (*payload.Object_Location, error) { + return vc.Update(ctx, updateReq, copts...) + }) + if err != nil { + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + ".DoMulti/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) } - return nil - }, - ) + code = st.Code() + } + if err == nil && ce != nil { + mu.Lock() + loc.Name = ce.GetName() + loc.Ips = append(loc.Ips, ce.GetIps()...) + mu.Unlock() + } + result.Store(target, &errorState{err, code}) + return err + }) if err != nil { reqInfo := &errdetails.RequestInfo{ - RequestId: rmReq.GetId().GetId(), - ServingData: errdetails.Serialize(rmReq), + RequestId: updateReq.GetVector().GetId(), } resInfo := &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + "." + vald.RemoveRPCName + ".BroadCast", + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + "." + vald.InsertRPCName + ".DoMulti", ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), } if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { err = status.WrapWithInternal( - vald.RemoveRPCName+" for "+vald.InsertRPCName+" API connection not found", err, reqInfo, resInfo, + vald.UpdateRPCName+" for "+vald.InsertRPCName+" API connection not found", err, reqInfo, resInfo, ) log.Warn(err) if span != nil { @@ -1180,7 +1189,7 @@ func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc * // There is no possibility to reach this part, but we add error handling just in case. st, msg, err := status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" for "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, + "failed to parse "+vald.UpdateRPCName+" for "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, ) log.Warn(err) if span != nil { @@ -1190,19 +1199,49 @@ func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc * } return nil, err } - if rerrs == nil { - log.Debugf("rollback for Insert API mirror request succeeded to %v", targets) - st, msg, err := status.ParseError(errs, codes.Internal, - "failed to parse "+vald.InsertRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, + + alreadyExistsTgts = alreadyExistsTgts[0:0] + successTgts = successTgts[0:0] + result.Range(func(target string, es *errorState) bool { + switch { + case es.err == nil: + successTgts = append(successTgts, target) + case es.code == codes.AlreadyExists: + alreadyExistsTgts = append(alreadyExistsTgts, target) + err = errors.Join(err, es.err) + default: + err = errors.Join(es.err, err) + } + return true + }) + if err == nil || (len(successTgts) > 0 && result.Len() == len(successTgts)+len(alreadyExistsTgts)) { + log.Debugf(vald.UpdateRPCName+"for "+vald.InsertRPCName+" API request succeeded to %#v", loc) + return loc, nil + } + + reqInfo = &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + } + resInfo = &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + "." + vald.UpdateRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + + switch { + case result.Len() == len(alreadyExistsTgts): + err = status.WrapWithAlreadyExists(vald.UpdateRPCName+" for "+vald.InsertRPCName+" API target same vector already exists", err, reqInfo, resInfo) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeAlreadyExists(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + case result.Len() > len(successTgts)+len(alreadyExistsTgts): // Contains errors except for ALREADY_EXIST. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" for "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, ) + log.Warn(err) if span != nil { span.RecordError(err) span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) @@ -1210,35 +1249,19 @@ func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc * } return nil, err } - log.Debugf("failed to rollback for Insert API mirror request succeeded to %v: %v", targets, rerrs) - st, msg, err := status.ParseError(rerrs, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" for "+vald.InsertRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: rmReq.GetId().GetId(), - ServingData: errdetails.Serialize(rmReq), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + "." + vald.RemoveRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s) %v", apiName, s.name, s.ip, targets), - }, - ) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err + log.Debugf(vald.UpdateRPCName+"for "+vald.InsertRPCName+" API request succeeded to %#v, err: %v", loc, err) + return loc, nil } -func (s *server) insert(ctx context.Context, client vald.InsertClient, req *payload.Insert_Request, opts ...grpc.CallOption) (loc *payload.Object_Location, err error) { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "insert"), apiName+"/insert") +func (s *server) doInsert(ctx context.Context, req *payload.Insert_Request, f func(ctx context.Context) (*payload.Object_Location, error)) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "doInsert"), apiName+"/doInsert") defer func() { if span != nil { span.End() } }() - loc, err = client.Insert(ctx, req, opts...) + loc, err = f(ctx) if err != nil { reqInfo := &errdetails.RequestInfo{ RequestId: req.GetVector().GetId(), @@ -1261,6 +1284,11 @@ func (s *server) insert(ctx context.Context, client vald.InsertClient, req *payl vald.InsertRPCName+" API deadline exceeded", err, reqInfo, resInfo, ) attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.InsertRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): err = status.WrapWithInternal( vald.InsertRPCName+" API connection not found", err, reqInfo, resInfo, @@ -1413,62 +1441,31 @@ func (s *server) Update(ctx context.Context, req *payload.Update_Request) (loc * reqSrcPodName := s.gateway.FromForwardedContext(ctx) - // When this condition is matched, the request is proxied to another Mirror gateway. + // If this condition is matched, it means that the request was proxied from another Mirror Gateway. // So this component sends requests only to the Vald gateway (LB gateway) of its own cluster. if len(reqSrcPodName) != 0 { - _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { - loc, err = vc.Update(ctx, req, copts...) - if err != nil { - return nil, err - } - return loc, nil + loc, err = s.doUpdate(ctx, req, func(ctx context.Context) (*payload.Object_Location, error) { + _, derr := s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + loc, err = vc.Update(ctx, req, copts...) + return loc, err + }) + return loc, errors.Join(derr, err) }) if err != nil { reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), + RequestId: req.GetVector().GetId(), } resInfo := &errdetails.ResourceInfo{ ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName, ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), } - var attrs trace.Attributes - - switch { - case errors.Is(err, context.Canceled): - err = status.WrapWithCanceled( - vald.UpdateRPCName+" API canceld", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeCancelled(err.Error()) - case errors.Is(err, context.DeadlineExceeded): - err = status.WrapWithDeadlineExceeded( - vald.UpdateRPCName+" API deadline exceeded", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeDeadlineExceeded(err.Error()) - case errors.Is(err, errors.ErrTargetNotFound): - err = status.WrapWithInternal( - vald.UpdateRPCName+" API target not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): - err = status.WrapWithInternal( - vald.UpdateRPCName+" API connection not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - default: - var ( - st *status.Status - msg string - ) - st, msg, err = status.ParseError(err, codes.Internal, - "failed to parse "+vald.UpdateRPCName+" gRPC error response", reqInfo, resInfo, - ) - attrs = trace.FromGRPCStatus(st.Code(), msg) - } + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" gRPC error response", reqInfo, resInfo, + ) log.Warn(err) if span != nil { span.RecordError(err) - span.SetAttributes(attrs...) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) span.SetStatus(trace.StatusError, err.Error()) } return nil, err @@ -1477,60 +1474,60 @@ func (s *server) Update(ctx context.Context, req *payload.Update_Request) (loc * return loc, nil } - objReq := &payload.Object_VectorRequest{ - Id: &payload.Object_ID{ - Id: req.GetVector().GetId(), - }, - } - oldVecs, err := s.getObjects(ctx, objReq) - if err != nil { - return nil, err - } + // If this condition is matched, it means the request from user. + // So this component sends requests to other Mirror gateways and the Vald gateway (LB gateway) of its own cluster. var mu sync.Mutex - var result sync.Map[string, error] + var result sync.Map[string, *errorState] // map[target host: error state] loc = &payload.Object_Location{ Uuid: req.GetVector().GetId(), Ips: make([]string, 0), } - - err = s.gateway.BroadCast(ctx, - func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.UpdateRPCName+"/"+target) - defer func() { - if span != nil { - span.End() - } - }() - - ce, err := s.update(ctx, vc, req, copts...) - if err != nil { - st, _, _ := status.ParseError(err, codes.Internal, - "failed to parse "+vald.UpdateRPCName+" API error response for "+target, - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + ".BroadCast/" + target, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), - }, - ) - if st.Code() == codes.AlreadyExists { - // NOTE: If it is strictly necessary to check, fix this logic. - return nil - } + err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.UpdateRPCName+"/"+target) + defer func() { + if span != nil { + span.End() } - if err == nil && ce != nil { - mu.Lock() - loc.Name = ce.GetName() - loc.Ips = append(loc.Ips, ce.GetIps()...) - mu.Unlock() + }() + + code := codes.OK + ce, err := s.doUpdate(ctx, req, func(ctx context.Context) (*payload.Object_Location, error) { + return vc.Update(ctx, req, copts...) + }) + if err != nil { + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) } - result.Store(target, err) - return err - }, - ) + code = st.Code() + } + if err == nil && ce != nil { + mu.Lock() + loc.Name = ce.GetName() + loc.Ips = append(loc.Ips, ce.GetIps()...) + mu.Unlock() + } + result.Store(target, &errorState{err, code}) + return err + }) if err != nil { reqInfo := &errdetails.RequestInfo{ RequestId: req.GetVector().GetId(), @@ -1566,33 +1563,60 @@ func (s *server) Update(ctx context.Context, req *payload.Update_Request) (loc * return nil, err } - var errs error - targets := make([]string, 0, 10) - result.Range(func(target string, err error) bool { - if err == nil { - targets = append(targets, target) - } else { - errs = errors.Join(errs, err) + var alreadyExistsCnt int + notFoundTgts := make([]string, 0, result.Len()/2) + successTgts := make([]string, 0, result.Len()/2) + result.Range(func(target string, es *errorState) bool { + switch { + case es.err == nil: + successTgts = append(successTgts, target) + case es.code == codes.AlreadyExists: + alreadyExistsCnt++ + err = errors.Join(err, es.err) + case es.code == codes.NotFound: + notFoundTgts = append(notFoundTgts, target) + err = errors.Join(err, es.err) + default: + err = errors.Join(es.err, err) } return true }) - switch { - case errs == nil: - log.Debugf("Update API mirror request succeeded to %#v", loc) + if err == nil || (len(successTgts) > 0 && result.Len() == len(successTgts)+alreadyExistsCnt) { + log.Debugf(vald.UpdateRPCName+" API request succeeded to %#v", loc) return loc, nil - case len(targets) == 0 && errs != nil: - log.Error("failed to Update API mirror request: %v and can not rollback because success target length is 0", errs) - st, msg, err := status.ParseError(errs, codes.Internal, - "failed to parse "+vald.UpdateRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, - ) + } + + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + + switch { + case result.Len() == len(notFoundTgts): + err = status.WrapWithNotFound(vald.UpdateRPCName+" API target not found", err, reqInfo, resInfo) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeAlreadyExists(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + case result.Len() == alreadyExistsCnt: + err = status.WrapWithAlreadyExists(vald.UpdateRPCName+" API target same vector already exists", err, reqInfo, resInfo) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeNotFound(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + case result.Len() > len(successTgts)+len(notFoundTgts)+alreadyExistsCnt: // Contains errors except for NOT_FOUND and ALREADY_EXIST. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" gRPC error response", reqInfo, resInfo) + log.Warn(err) if span != nil { span.RecordError(err) span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) @@ -1600,92 +1624,74 @@ func (s *server) Update(ctx context.Context, req *payload.Update_Request) (loc * } return nil, err } - log.Error("failed to Update API mirror request: %v, so starts the rollback request", errs) - var emu sync.Mutex - var rerrs error - rmReq := &payload.Remove_Request{ - Id: &payload.Object_ID{ - Id: req.GetVector().GetId(), + // In this case, the status code in the result object contains only OK or ALREADY_EXIST or NOT_FOUND. + // And send Insert API requst to NOT_FOUND cluster using query requested by the user. + log.Warnf("failed to "+vald.UpdateRPCName+" API: %#v", err) + + insReq := &payload.Insert_Request{ + Vector: req.GetVector(), + Config: &payload.Insert_Config{ + Timestamp: req.GetConfig().GetTimestamp(), }, } - - err = s.gateway.DoMulti(ctx, targets, - func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "rollback/BroadCast/"+target), apiName+"/"+vald.RemoveRPCName+"/rollback/"+target) - defer func() { - if span != nil { - span.End() - } - }() - - oldVec, ok := oldVecs.Load(target) - if !ok || oldVec == nil { - _, err := s.remove(ctx, vc, rmReq, copts...) - if err != nil { - st, _, _ := status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" for "+vald.UpdateRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: rmReq.GetId().GetId(), - ServingData: errdetails.Serialize(rmReq), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + "." + vald.RemoveRPCName + ".BroadCast/" + target, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), - }, - ) - if st.Code() == codes.NotFound { - return nil - } - emu.Lock() - rerrs = errors.Join(rerrs, err) - emu.Unlock() - return err - } - return nil + err = s.gateway.DoMulti(ctx, notFoundTgts, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.InsertRPCName+"/"+target) + defer func() { + if span != nil { + span.End() } + }() - req := &payload.Update_Request{ - Vector: oldVec, - Config: &payload.Update_Config{ - SkipStrictExistCheck: true, + code := codes.OK + ce, err := s.doInsert(ctx, insReq, func(ctx context.Context) (*payload.Object_Location, error) { + return vc.Insert(ctx, insReq, copts...) + }) + if err != nil { + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.InsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) } - _, err := s.update(ctx, vc, req, copts...) - if err != nil { - st, _, _ := status.ParseError(err, codes.Internal, - "failed to parse "+vald.UpdateRPCName+" for "+vald.UpdateRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + "." + vald.UpdateRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), - }, - ) - if st.Code() == codes.AlreadyExists { - return nil - } - emu.Lock() - rerrs = errors.Join(rerrs, err) - emu.Unlock() - return err - } - return nil - }, - ) + code = st.Code() + } + if err == nil && ce != nil { + mu.Lock() + loc.Name = ce.GetName() + loc.Ips = append(loc.Ips, ce.GetIps()...) + mu.Unlock() + } + result.Store(target, &errorState{err, code}) + return err + }) if err != nil { reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), } resInfo := &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + ".Rollback.BroadCast", + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + "." + vald.InsertRPCName + ".BroadCast", ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), } if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { err = status.WrapWithInternal( - vald.UpdateRPCName+" for Rollback connection not found", err, reqInfo, resInfo, + vald.InsertRPCName+" for "+vald.UpdateRPCName+" API connection not found", err, reqInfo, resInfo, ) log.Warn(err) if span != nil { @@ -1696,31 +1702,74 @@ func (s *server) Update(ctx context.Context, req *payload.Update_Request) (loc * return nil, err } - // There is no possibility to reach this part, but we add error handling just in case. - st, msg, err := status.ParseError(err, codes.Internal, - "failed to parse "+vald.UpdateRPCName+" for Rollback gRPC error response", reqInfo, resInfo, - ) + // There is no possibility to reach this part, but we add error handling just in case. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.InsertRPCName+" for "+vald.UpdateRPCName+" gRPC error response", reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + alreadyExistsCnt = 0 + notFoundTgts = notFoundTgts[0:0] + successTgts = successTgts[0:0] + result.Range(func(target string, em *errorState) bool { + switch { + case em.err == nil: + successTgts = append(successTgts, target) + case em.code == codes.AlreadyExists: + alreadyExistsCnt++ + err = errors.Join(err, em.err) + case em.code == codes.NotFound: + notFoundTgts = append(notFoundTgts, target) + err = errors.Join(err, em.err) + default: + err = errors.Join(em.err, err) + } + return true + }) + if err == nil || (len(successTgts) > 0 && result.Len() == len(successTgts)+alreadyExistsCnt) { + log.Debugf(vald.InsertRPCName+" for "+vald.UpdateRPCName+" API request succeeded to %#v", loc) + return loc, nil + } + + reqInfo = &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo = &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + "." + vald.InsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + + switch { + case result.Len() == len(notFoundTgts): + err = status.WrapWithNotFound(vald.InsertRPCName+" for "+vald.UpdateRPCName+" API target not found", err, reqInfo, resInfo) log.Warn(err) if span != nil { span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetAttributes(trace.StatusCodeAlreadyExists(err.Error())...) span.SetStatus(trace.StatusError, err.Error()) } return nil, err - } - if rerrs == nil { - log.Debugf("rollback for Update API mirror request succeeded to %v", targets) - st, msg, err := status.ParseError(errs, codes.Internal, - "failed to parse "+vald.UpdateRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, - ) + case result.Len() == alreadyExistsCnt: + err = status.WrapWithAlreadyExists(vald.InsertRPCName+" for "+vald.UpdateRPCName+" API target same vector already exists", err, reqInfo, resInfo) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeNotFound(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + case result.Len() > len(successTgts)+len(notFoundTgts)+alreadyExistsCnt: // Contains errors except for NOT_FOUND and ALREADY_EXIST. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.InsertRPCName+" for "+vald.UpdateRPCName+" gRPC error response", reqInfo, resInfo) + log.Warn(err) if span != nil { span.RecordError(err) span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) @@ -1728,34 +1777,19 @@ func (s *server) Update(ctx context.Context, req *payload.Update_Request) (loc * } return nil, err } - log.Debugf("failed to rollback for Update API mirror request succeeded to %v: %v", targets, rerrs) - st, msg, err := status.ParseError(rerrs, codes.Internal, - "failed to parse "+vald.UpdateRPCName+" for Rollback gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + ".Rollback", - ResourceName: fmt.Sprintf("%s: %s(%s) %v", apiName, s.name, s.ip, targets), - }, - ) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err + log.Debugf(vald.InsertRPCName+" for "+vald.UpdateRPCName+" API request succeeded to %#v, err: %v", loc, err) + return loc, nil } -func (s *server) update(ctx context.Context, client vald.UpdateClient, req *payload.Update_Request, opts ...grpc.CallOption) (loc *payload.Object_Location, err error) { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "update"), apiName+"/update") +func (s *server) doUpdate(ctx context.Context, req *payload.Update_Request, f func(ctx context.Context) (*payload.Object_Location, error)) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "doUpdate"), apiName+"/doUpdate") defer func() { if span != nil { span.End() } }() - loc, err = client.Update(ctx, req, opts...) + loc, err = f(ctx) if err != nil { reqInfo := &errdetails.RequestInfo{ RequestId: req.GetVector().GetId(), @@ -1778,6 +1812,11 @@ func (s *server) update(ctx context.Context, client vald.UpdateClient, req *payl vald.UpdateRPCName+" API deadline exceeded", err, reqInfo, resInfo, ) attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.UpdateRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): err = status.WrapWithInternal( vald.UpdateRPCName+" API connection not found", err, reqInfo, resInfo, @@ -1930,15 +1969,15 @@ func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc * reqSrcPodName := s.gateway.FromForwardedContext(ctx) - // When this condition is matched, the request is proxied to another Mirror gateway. + // If this condition is matched, it means that the request was proxied from another Mirror Gateway. // So this component sends requests only to the Vald gateway (LB gateway) of its own cluster. if len(reqSrcPodName) != 0 { - _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { - loc, err = vc.Upsert(ctx, req, copts...) - if err != nil { - return nil, err - } - return loc, nil + loc, err = s.doUpsert(ctx, req, func(ctx context.Context) (*payload.Object_Location, error) { + s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + loc, err = vc.Upsert(ctx, req, copts...) + return loc, err + }) + return loc, err }) if err != nil { reqInfo := &errdetails.RequestInfo{ @@ -1949,43 +1988,13 @@ func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc * ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName, ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), } - var attrs trace.Attributes - - switch { - case errors.Is(err, context.Canceled): - err = status.WrapWithCanceled( - vald.UpsertRPCName+" API canceld", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeCancelled(err.Error()) - case errors.Is(err, context.DeadlineExceeded): - err = status.WrapWithDeadlineExceeded( - vald.UpsertRPCName+" API deadline exceeded", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeDeadlineExceeded(err.Error()) - case errors.Is(err, errors.ErrTargetNotFound): - err = status.WrapWithInternal( - vald.UpsertRPCName+" API target not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): - err = status.WrapWithInternal( - vald.UpsertRPCName+" API connection not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - default: - var ( - st *status.Status - msg string - ) - st, msg, err = status.ParseError(err, codes.Internal, - "failed to parse "+vald.UpsertRPCName+" gRPC error response", reqInfo, resInfo, - ) - attrs = trace.FromGRPCStatus(st.Code(), msg) - } + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" gRPC error response", reqInfo, resInfo, + ) log.Warn(err) if span != nil { span.RecordError(err) - span.SetAttributes(attrs...) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) span.SetStatus(trace.StatusError, err.Error()) } return nil, err @@ -1994,18 +2003,11 @@ func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc * return loc, nil } - objReq := &payload.Object_VectorRequest{ - Id: &payload.Object_ID{ - Id: req.GetVector().GetId(), - }, - } - oldVecs, err := s.getObjects(ctx, objReq) - if err != nil { - return nil, err - } + // If this condition is matched, it means the request from user. + // So this component sends requests to other Mirror gateways and the Vald gateway (LB gateway) of its own cluster. var mu sync.Mutex - var result sync.Map[string, error] + var result sync.Map[string, *errorState] // map[target host: error state] loc = &payload.Object_Location{ Uuid: req.GetVector().GetId(), Ips: make([]string, 0), @@ -2018,9 +2020,16 @@ func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc * } }() - ce, err := s.upsert(ctx, vc, req, copts...) + code := codes.OK + ce, err := s.doUpsert(ctx, req, func(ctx context.Context) (*payload.Object_Location, error) { + return vc.Upsert(ctx, req, copts...) + }) if err != nil { - st, _, _ := status.ParseError(err, codes.Internal, + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, "failed to parse "+vald.UpsertRPCName+" gRPC error response", &errdetails.RequestInfo{ RequestId: req.GetVector().GetId(), @@ -2031,10 +2040,13 @@ func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc * ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), }, ) - if st.Code() == codes.AlreadyExists { - // NOTE: If it is strictly necessary to check, fix this logic. - return nil + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) } + code = st.Code() } if err == nil && ce != nil { mu.Lock() @@ -2042,7 +2054,7 @@ func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc * loc.Ips = append(loc.Ips, ce.GetIps()...) mu.Unlock() } - result.Store(target, err) + result.Store(target, &errorState{err, code}) return err }) if err != nil { @@ -2080,160 +2092,47 @@ func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc * return nil, err } - var errs error - targets := make([]string, 0, 10) - result.Range(func(target string, err error) bool { - if err == nil { - targets = append(targets, target) - } else { - errs = errors.Join(errs, err) + var alreadyExistsCnt int + successTgts := make([]string, 0, result.Len()/2) + result.Range(func(target string, es *errorState) bool { + switch { + case es.err == nil: + successTgts = append(successTgts, target) + case es.code == codes.AlreadyExists: + alreadyExistsCnt++ + err = errors.Join(err, es.err) + default: + err = errors.Join(es.err, err) } return true }) - switch { - case errs == nil: - log.Debugf("Upsert API mirror request succeeded to %#v", loc) + if err == nil || (len(successTgts) > 0 && result.Len() == len(successTgts)+alreadyExistsCnt) { + log.Debugf(vald.UpsertRPCName+" API request succeeded to %#v", loc) return loc, nil - case len(targets) == 0 && errs != nil: - log.Error("failed to Upsert API mirror request: %v and can not rollback because success target length is 0", errs) - st, msg, err := status.ParseError(errs, codes.Internal, - "failed to parse "+vald.UpsertRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, - ) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err } - log.Error("failed to Upsert API mirror request: %v, so starts the rollback request", errs) - - var emu sync.Mutex - var rerrs error - rmReq := &payload.Remove_Request{ - Id: &payload.Object_ID{ - Id: req.GetVector().GetId(), - }, + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), } - err = s.gateway.DoMulti(ctx, targets, - func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "rollback/BroadCast/"+target), apiName+"/"+vald.UpsertRPCName+"/rollback/"+target) - defer func() { - if span != nil { - span.End() - } - }() - - oldVec, ok := oldVecs.Load(target) - if !ok || oldVec == nil { - _, err := s.remove(ctx, vc, rmReq, copts...) - if err != nil { - st, _, _ := status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" for "+vald.UpsertRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: rmReq.GetId().GetId(), - ServingData: errdetails.Serialize(rmReq), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + "." + vald.RemoveRPCName + ".BroadCast/" + target, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), - }, - ) - if st.Code() == codes.NotFound { - return nil - } - emu.Lock() - rerrs = errors.Join(rerrs, err) - emu.Unlock() - return err - } - return nil - } - - req := &payload.Update_Request{ - Vector: oldVec, - Config: &payload.Update_Config{ - SkipStrictExistCheck: true, - }, - } - _, err := s.update(ctx, vc, req, copts...) - if err != nil { - st, _, _ := status.ParseError(err, codes.Internal, - "failed to parse "+vald.UpdateRPCName+" for "+vald.UpsertRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + "." + vald.UpdateRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), - }, - ) - if st.Code() == codes.AlreadyExists { - return nil - } - emu.Lock() - rerrs = errors.Join(rerrs, err) - emu.Unlock() - return err - } - return nil - }, - ) - if err != nil { - reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - } - resInfo := &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + ".Rollback.BroadCast", - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - } - if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { - err = status.WrapWithInternal( - vald.UpsertRPCName+" for Rollback connection not found", err, reqInfo, resInfo, - ) - log.Warn(err) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.StatusCodeInternal(err.Error())...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err - } - // There is no possibility to reach this part, but we add error handling just in case. - st, msg, err := status.ParseError(err, codes.Internal, - "failed to parse "+vald.UpsertRPCName+" for Rollback gRPC error response", reqInfo, resInfo, - ) + switch { + case result.Len() == alreadyExistsCnt: + err = status.WrapWithAlreadyExists(vald.UpsertRPCName+" API target same vector already exists", err, reqInfo, resInfo) log.Warn(err) if span != nil { span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetAttributes(trace.StatusCodeAlreadyExists(err.Error())...) span.SetStatus(trace.StatusError, err.Error()) } return nil, err - } - if rerrs == nil { - log.Debugf("rollback for Upsert API mirror request succeeded to %v", targets) - st, msg, err := status.ParseError(errs, codes.Internal, - "failed to parse "+vald.UpsertRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, - ) + default: + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" gRPC error response", reqInfo, resInfo) + log.Warn(err) if span != nil { span.RecordError(err) span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) @@ -2241,34 +2140,17 @@ func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc * } return nil, err } - log.Debugf("failed to rollback for Upsert API mirror request succeeded to %v: %v", targets, rerrs) - st, msg, err := status.ParseError(rerrs, codes.Internal, - "failed to parse "+vald.UpsertRPCName+" for Rollback gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + ".Rollback", - ResourceName: fmt.Sprintf("%s: %s(%s) %v", apiName, s.name, s.ip, targets), - }, - ) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err } -func (s *server) upsert(ctx context.Context, client vald.UpsertClient, req *payload.Upsert_Request, opts ...grpc.CallOption) (loc *payload.Object_Location, err error) { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "upsert"), apiName+"/upsert") +func (s *server) doUpsert(ctx context.Context, req *payload.Upsert_Request, f func(ctx context.Context) (*payload.Object_Location, error)) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "doUpsert"), apiName+"/doUpsert") defer func() { if span != nil { span.End() } }() - loc, err = client.Upsert(ctx, req, opts...) + loc, err = f(ctx) if err != nil { reqInfo := &errdetails.RequestInfo{ RequestId: req.GetVector().GetId(), @@ -2291,6 +2173,11 @@ func (s *server) upsert(ctx context.Context, client vald.UpsertClient, req *payl vald.UpsertRPCName+" API deadline exceeded", err, reqInfo, resInfo, ) attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.UpsertRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): err = status.WrapWithInternal( vald.UpsertRPCName+" API connection not found", err, reqInfo, resInfo, @@ -2306,7 +2193,7 @@ func (s *server) upsert(ctx context.Context, client vald.UpsertClient, req *payl ) attrs = trace.FromGRPCStatus(st.Code(), msg) } - log.Warn("failed to process Upsert request\terror: %s", err.Error()) + log.Warn(err) if span != nil { span.RecordError(err) span.SetAttributes(attrs...) @@ -2443,62 +2330,31 @@ func (s *server) Remove(ctx context.Context, req *payload.Remove_Request) (loc * reqSrcPodName := s.gateway.FromForwardedContext(ctx) - // When this condition is matched, the request is proxied to another Mirror gateway. - // So this component sends the request only to the Vald gateway (LB gateway) of own cluster. + // If this condition is matched, it means that the request was proxied from another Mirror Gateway. + // So this component sends requests only to the Vald gateway (LB gateway) of its own cluster. if len(reqSrcPodName) != 0 { - _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { - loc, err = vc.Remove(ctx, req, copts...) - if err != nil { - return nil, err - } - return loc, nil + loc, err = s.doRemove(ctx, req, func(ctx context.Context) (*payload.Object_Location, error) { + s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + loc, err = vc.Remove(ctx, req, copts...) + return loc, err + }) + return loc, err }) if err != nil { reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - ServingData: errdetails.Serialize(req), + RequestId: req.GetId().GetId(), } resInfo := &errdetails.ResourceInfo{ ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName, ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), } - var attrs trace.Attributes - - switch { - case errors.Is(err, context.Canceled): - err = status.WrapWithCanceled( - vald.RemoveRPCName+" API canceld", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeCancelled(err.Error()) - case errors.Is(err, context.DeadlineExceeded): - err = status.WrapWithDeadlineExceeded( - vald.RemoveRPCName+" API deadline exceeded", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeDeadlineExceeded(err.Error()) - case errors.Is(err, errors.ErrTargetNotFound): - err = status.WrapWithInternal( - vald.RemoveRPCName+" API target not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): - err = status.WrapWithInternal( - vald.RemoveRPCName+" API connection not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - default: - var ( - st *status.Status - msg string - ) - st, msg, err = status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, - ) - attrs = trace.FromGRPCStatus(st.Code(), msg) - } + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, + ) log.Warn(err) if span != nil { span.RecordError(err) - span.SetAttributes(attrs...) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) span.SetStatus(trace.StatusError, err.Error()) } return nil, err @@ -2507,186 +2363,70 @@ func (s *server) Remove(ctx context.Context, req *payload.Remove_Request) (loc * return loc, nil } - objReq := &payload.Object_VectorRequest{ - Id: &payload.Object_ID{ - Id: req.GetId().GetId(), - }, - } - oldVecs, err := s.getObjects(ctx, objReq) - if err != nil { - return nil, err - } + // If this condition is matched, it means the request from user. + // So this component sends requests to other Mirror gateways and the Vald gateway (LB gateway) of its own cluster. var mu sync.Mutex - var result sync.Map[string, error] + var result sync.Map[string, *errorState] // map[target host: error state] loc = &payload.Object_Location{ Uuid: req.GetId().GetId(), Ips: make([]string, 0), } - err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.RemoveRPCName+"/"+target) defer func() { if span != nil { span.End() } - }() - - ce, err := s.remove(ctx, vc, req, copts...) - if err != nil { - st, _, _ := status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" gRPC error response for "+target, - &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + ".BroadCast/" + target, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), - }, - ) - if st.Code() == codes.NotFound { - // NOTE: If it is strictly necessary to check, fix this logic. - return nil - } - } - if err == nil && ce != nil { - mu.Lock() - loc.Name = ce.GetName() - loc.Ips = append(loc.Ips, ce.GetIps()...) - mu.Unlock() - } - result.Store(target, err) - return err - }) - if err != nil { - reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - ServingData: errdetails.Serialize(req), - } - resInfo := &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + ".BroadCast", - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - } - if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { - err = status.WrapWithInternal( - vald.RemoveRPCName+" API connection not found", err, reqInfo, resInfo, - ) - log.Warn(err) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.StatusCodeInternal(err.Error())...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err - } - - // There is no possibility to reach this part, but we add error handling just in case. - st, msg, err := status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, - ) - log.Warn(err) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err - } - - var errs error - targets := make([]string, 0, 10) - result.Range(func(target string, err error) bool { - if err == nil { - targets = append(targets, target) - } else { - errs = errors.Join(errs, err) - } - return true - }) - switch { - case errs == nil: - log.Debugf("Remove API mirror request succeeded to %#v", loc) - return loc, nil - case len(targets) == 0 && errs != nil: - log.Error("failed to Remove API mirror request: %v and can not rollback because success target length is 0", errs) - st, msg, err := status.ParseError(errs, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, - ) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err - } - log.Error("failed to Remove API mirror request: %v, so starts the rollback request", errs) - - var emu sync.Mutex - var rerrs error - err = s.gateway.DoMulti(ctx, targets, - func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "rollback/BroadCast/"+target), apiName+"/"+vald.RemoveRPCName+"/rollback/"+target) - defer func() { - if span != nil { - span.End() - } - }() - - objv, ok := oldVecs.Load(target) - if !ok || objv == nil { - log.Debug("failed to load old vector from %s", target) - return nil - } - req := &payload.Upsert_Request{ - Vector: objv, - Config: &payload.Upsert_Config{ - SkipStrictExistCheck: true, + }() + + code := codes.OK + ce, err := s.doRemove(ctx, req, func(ctx context.Context) (*payload.Object_Location, error) { + return vc.Remove(ctx, req, copts...) + }) + if err != nil { + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) } - _, err := s.upsert(ctx, vc, req, copts...) - if err != nil { - st, _, _ := status.ParseError(err, codes.Internal, - "failed to parse "+vald.UpsertRPCName+" for "+vald.RemoveRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetVector().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + "." + vald.UpsertRPCName + ".BroadCast/" + target, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), - }, - ) - if st.Code() == codes.AlreadyExists { - return nil - } - emu.Lock() - rerrs = errors.Join(rerrs, err) - emu.Unlock() - return err - } - return nil - }, - ) + code = st.Code() + } + if err == nil && ce != nil { + mu.Lock() + loc.Name = ce.GetName() + loc.Ips = append(loc.Ips, ce.GetIps()...) + mu.Unlock() + } + result.Store(target, &errorState{err, code}) + return err + }) if err != nil { reqInfo := &errdetails.RequestInfo{ RequestId: req.GetId().GetId(), } resInfo := &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + "." + vald.UpsertRPCName + ".BroadCast", + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + ".BroadCast", ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), } if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { err = status.WrapWithInternal( - vald.UpsertRPCName+" for "+vald.RemoveRPCName+" API connection not found", err, reqInfo, resInfo, + vald.RemoveRPCName+" API connection not found", err, reqInfo, resInfo, ) log.Warn(err) if span != nil { @@ -2699,7 +2439,7 @@ func (s *server) Remove(ctx context.Context, req *payload.Remove_Request) (loc * // There is no possibility to reach this part, but we add error handling just in case. st, msg, err := status.ParseError(err, codes.Internal, - "failed to parse "+vald.UpsertRPCName+" for "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, ) log.Warn(err) if span != nil { @@ -2709,19 +2449,47 @@ func (s *server) Remove(ctx context.Context, req *payload.Remove_Request) (loc * } return nil, err } - if rerrs == nil { - log.Debugf("rollback for Remove API mirror request succeeded to %v", targets) - st, msg, err := status.ParseError(errs, codes.Internal, - "failed to parse "+vald.RemoveRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, - ) + + var notFoundCnt int + successTgts := make([]string, 0, result.Len()/2) + result.Range(func(target string, es *errorState) bool { + switch { + case es.err == nil: + successTgts = append(successTgts, target) + case es.code == codes.NotFound: + notFoundCnt++ + err = errors.Join(err, es.err) + default: + err = errors.Join(es.err, err) + } + return true + }) + if err == nil || (len(successTgts) > 0 && result.Len() == len(successTgts)+notFoundCnt) { + log.Debugf(vald.RemoveRPCName+" API request succeeded to %#v", loc) + return loc, nil + } + + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + + switch { + case result.Len() == notFoundCnt: + err = status.WrapWithNotFound(vald.RemoveRPCName+" API id "+req.GetId().GetId()+" not found", err, reqInfo, resInfo) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeAlreadyExists(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + default: + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo) + log.Warn(err) if span != nil { span.RecordError(err) span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) @@ -2729,38 +2497,20 @@ func (s *server) Remove(ctx context.Context, req *payload.Remove_Request) (loc * } return nil, err } - log.Debugf("failed to rollback for Remove API mirror request succeeded to %v: %v", targets, rerrs) - st, msg, err := status.ParseError(rerrs, codes.Internal, - "failed to parse "+vald.UpsertRPCName+" for "+vald.RemoveRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + "." + vald.UpsertRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s) %v", apiName, s.name, s.ip, targets), - }, - ) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err } -func (s *server) remove(ctx context.Context, client vald.RemoveClient, req *payload.Remove_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "remove"), apiName+"/remove") +func (s *server) doRemove(ctx context.Context, req *payload.Remove_Request, f func(ctx context.Context) (*payload.Object_Location, error)) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "doRemove"), apiName+"/doRemove") defer func() { if span != nil { span.End() } }() - loc, err := client.Remove(ctx, req, opts...) + loc, err = f(ctx) if err != nil { reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - ServingData: errdetails.Serialize(req), + RequestId: req.GetId().GetId(), } resInfo := &errdetails.ResourceInfo{ ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName, @@ -2779,6 +2529,11 @@ func (s *server) remove(ctx context.Context, client vald.RemoveClient, req *payl vald.RemoveRPCName+" API deadline exceeded", err, reqInfo, resInfo, ) attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.RemoveRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): err = status.WrapWithInternal( vald.RemoveRPCName+" API connection not found", err, reqInfo, resInfo, @@ -2931,12 +2686,15 @@ func (s *server) RemoveByTimestamp(ctx context.Context, req *payload.Remove_Time reqSrcPodName := s.gateway.FromForwardedContext(ctx) - // When this condition is matched, the request is proxied to another Mirror gateway. + // If this condition is matched, it means that the request was proxied from another Mirror Gateway. // So this component sends requests only to the Vald gateway (LB gateway) of its own cluster. if len(reqSrcPodName) != 0 { - _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, opts ...grpc.CallOption) (interface{}, error) { - locs, err = vc.RemoveByTimestamp(ctx, req, opts...) - return locs, err + locs, err = s.doRemoveByTimestamp(ctx, req, func(ctx context.Context) (*payload.Object_Locations, error) { + _, derr := s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + locs, err = vc.RemoveByTimestamp(ctx, req, copts...) + return locs, err + }) + return locs, errors.Join(derr, err) }) if err != nil { reqInfo := &errdetails.RequestInfo{ @@ -2946,55 +2704,29 @@ func (s *server) RemoveByTimestamp(ctx context.Context, req *payload.Remove_Time ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveByTimestampRPCName, ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), } - var attrs trace.Attributes - - switch { - case errors.Is(err, context.Canceled): - err = status.WrapWithCanceled( - vald.RemoveByTimestampRPCName+" API canceld", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeCancelled(err.Error()) - case errors.Is(err, context.DeadlineExceeded): - err = status.WrapWithDeadlineExceeded( - vald.RemoveByTimestampRPCName+" API deadline exceeded", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeDeadlineExceeded(err.Error()) - case errors.Is(err, errors.ErrTargetNotFound): - err = status.WrapWithInternal( - vald.RemoveByTimestampRPCName+" API target not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): - err = status.WrapWithInternal( - vald.RemoveByTimestampRPCName+" API connection not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - default: - var ( - st *status.Status - msg string - ) - st, msg, err = status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveByTimestampRPCName+" gRPC error response", reqInfo, resInfo, - ) - attrs = trace.FromGRPCStatus(st.Code(), msg) - } + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, + ) log.Warn(err) if span != nil { span.RecordError(err) - span.SetAttributes(attrs...) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) span.SetStatus(trace.StatusError, err.Error()) } return nil, err } + log.Debugf("RemoveByTimestamp API remove succeeded to %#v", locs) return locs, nil } + // If this condition is matched, it means the request from user. + // So this component sends requests to other Mirror gateways and the Vald gateway (LB gateway) of its own cluster. + var mu sync.Mutex - var result sync.Map[string, error] + var result sync.Map[string, *errorState] // map[target host: error state] locs = new(payload.Object_Locations) - err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.ClientWithMirror, opts ...grpc.CallOption) error { + err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.RemoveByTimestampRPCName+"/"+target) defer func() { if span != nil { @@ -3002,61 +2734,37 @@ func (s *server) RemoveByTimestamp(ctx context.Context, req *payload.Remove_Time } }() - res, err := vc.RemoveByTimestamp(ctx, req, opts...) + code := codes.OK + res, err := s.doRemoveByTimestamp(ctx, req, func(ctx context.Context) (*payload.Object_Locations, error) { + return vc.RemoveByTimestamp(ctx, req, copts...) + }) if err != nil { - reqInfo := &errdetails.RequestInfo{ - ServingData: errdetails.Serialize(req), - } - resInfo := &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveByTimestampRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - } - var attrs trace.Attributes - var code codes.Code - - switch { - case errors.Is(err, context.Canceled): - err = status.WrapWithCanceled( - vald.RemoveByTimestampRPCName+" API canceld", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeCancelled(err.Error()) - case errors.Is(err, context.DeadlineExceeded): - err = status.WrapWithDeadlineExceeded( - vald.RemoveByTimestampRPCName+" API deadline exceeded", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeDeadlineExceeded(err.Error()) - case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): - err = status.WrapWithInternal( - vald.RemoveByTimestampRPCName+" API connection not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - default: - var ( - st *status.Status - msg string - ) - st, msg, err = status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveByTimestampRPCName+" gRPC error response", reqInfo, resInfo, - ) - attrs = trace.FromGRPCStatus(st.Code(), msg) - code = st.Code() - } + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveByTimestampRPCName+" gRPC error response", + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveByTimestampRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) log.Warn(err) if span != nil { span.RecordError(err) - span.SetAttributes(attrs...) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) span.SetStatus(trace.StatusError, err.Error()) } - if code == codes.NotFound { - return nil - } - result.Store(target, err) - return err + code = st.Code() } - mu.Lock() - locs.Locations = append(locs.Locations, res.GetLocations()...) - mu.Unlock() - return nil + if err == nil && res != nil { + mu.Lock() + locs.Locations = append(locs.Locations, res.GetLocations()...) + mu.Unlock() + } + result.Store(target, &errorState{err, code}) + return err }) if err != nil { reqInfo := &errdetails.RequestInfo{ @@ -3084,80 +2792,106 @@ func (s *server) RemoveByTimestamp(ctx context.Context, req *payload.Remove_Time "failed to parse "+vald.RemoveByTimestampRPCName+" gRPC error response", reqInfo, resInfo, ) log.Warn(err) - if err != nil { - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) - span.SetStatus(trace.StatusError, err.Error()) - } + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) } return nil, err } - result.Range(func(_ string, rerr error) bool { - if rerr != nil { - err = errors.Join(err, rerr) + var notFoundCnt int + successTgts := make([]string, 0, result.Len()/2) + result.Range(func(target string, es *errorState) bool { + switch { + case es.err == nil: + successTgts = append(successTgts, target) + case es.code == codes.NotFound: + notFoundCnt++ + err = errors.Join(err, es.err) + default: + err = errors.Join(es.err, err) } return true }) - if err != nil { + if err == nil || (len(successTgts) > 0 && result.Len() == len(successTgts)+notFoundCnt) { + log.Debugf(vald.RemoveByTimestampRPCName+" API request succeeded to %#v", locs) + return locs, nil + } + + reqInfo := &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveByTimestampRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + + switch { + case result.Len() == notFoundCnt: + err = status.WrapWithNotFound(vald.RemoveByTimestampRPCName+" API target not found", err, reqInfo, resInfo) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeAlreadyExists(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + default: st, msg, err := status.ParseError(err, codes.Internal, - "failed to parse "+vald.RemoveByTimestampRPCName+" gRPC error response") - if err != nil { + "failed to parse "+vald.RemoveByTimestampRPCName+" gRPC error response", reqInfo, resInfo) + log.Warn(err) + if span != nil { span.RecordError(err) span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) span.SetStatus(trace.StatusError, err.Error()) } return nil, err } - return locs, nil } -func (s *server) GetObject(ctx context.Context, req *payload.Object_VectorRequest) (vec *payload.Object_Vector, err error) { - ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.ObjectRPCServiceName+"/"+vald.GetObjectRPCName), apiName+"/"+vald.GetObjectRPCName) +func (s *server) doRemoveByTimestamp( + ctx context.Context, + req *payload.Remove_TimestampRequest, + f func(ctx context.Context) (*payload.Object_Locations, error), +) (locs *payload.Object_Locations, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "doRemoveByTimestamp"), apiName+"/doRemoveByTimestamp") defer func() { if span != nil { span.End() } }() - _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { - vec, err = vc.GetObject(ctx, req, copts...) - if err != nil { - return nil, err - } - return vec, nil - }) + locs, err = f(ctx) if err != nil { reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), ServingData: errdetails.Serialize(req), } resInfo := &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.GetObjectRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveByTimestampRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), } var attrs trace.Attributes switch { case errors.Is(err, context.Canceled): err = status.WrapWithCanceled( - vald.GetObjectRPCName+" API canceld", err, reqInfo, resInfo, + vald.RemoveByTimestampRPCName+" API canceld", err, reqInfo, resInfo, ) attrs = trace.StatusCodeCancelled(err.Error()) case errors.Is(err, context.DeadlineExceeded): err = status.WrapWithDeadlineExceeded( - vald.GetObjectRPCName+" API deadline exceeded", err, reqInfo, resInfo, + vald.RemoveByTimestampRPCName+" API deadline exceeded", err, reqInfo, resInfo, ) attrs = trace.StatusCodeDeadlineExceeded(err.Error()) case errors.Is(err, errors.ErrTargetNotFound): err = status.WrapWithInternal( - vald.GetObjectRPCName+" API target not found", err, reqInfo, resInfo, + vald.RemoveByTimestampRPCName+" API target not found", err, reqInfo, resInfo, ) attrs = trace.StatusCodeInternal(err.Error()) case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): err = status.WrapWithInternal( - vald.GetObjectRPCName+" API connection not found", err, reqInfo, resInfo, + vald.RemoveByTimestampRPCName+" API connection not found", err, reqInfo, resInfo, ) attrs = trace.StatusCodeInternal(err.Error()) default: @@ -3166,7 +2900,7 @@ func (s *server) GetObject(ctx context.Context, req *payload.Object_VectorReques msg string ) st, msg, err = status.ParseError(err, codes.Internal, - "failed to parse "+vald.GetObjectRPCName+" gRPC error response", reqInfo, resInfo, + "failed to parse "+vald.RemoveByTimestampRPCName+" gRPC error response", reqInfo, resInfo, ) attrs = trace.FromGRPCStatus(st.Code(), msg) } @@ -3178,138 +2912,72 @@ func (s *server) GetObject(ctx context.Context, req *payload.Object_VectorReques } return nil, err } - return vec, nil + return locs, nil } -func (s *server) getObjects(ctx context.Context, req *payload.Object_VectorRequest) (vecs *sync.Map[string, *payload.Object_Vector], err error) { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "getObjects"), apiName+"/"+vald.GetObjectRPCName+"/getObjects") +func (s *server) GetObject(ctx context.Context, req *payload.Object_VectorRequest) (vec *payload.Object_Vector, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.ObjectRPCServiceName+"/"+vald.GetObjectRPCName), apiName+"/"+vald.GetObjectRPCName) defer func() { if span != nil { span.End() } }() - var errs error - var emu sync.Mutex - vecs = new(sync.Map[string, *payload.Object_Vector]) - err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { - ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.GetObjectRPCName+"/getObjects/"+target) - defer func() { - if span != nil { - span.End() - } - }() - - vec, err := vc.GetObject(ctx, req, copts...) - if err != nil { - reqInfo := &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - ServingData: errdetails.Serialize(req), - } - resInfo := &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.GetObjectRPCName, - ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), - } - var attrs trace.Attributes - var code codes.Code - - switch { - case errors.Is(err, context.Canceled): - err = status.WrapWithCanceled( - vald.GetObjectRPCName+" API canceld", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeCancelled(err.Error()) - code = codes.Canceled - case errors.Is(err, context.DeadlineExceeded): - err = status.WrapWithDeadlineExceeded( - vald.GetObjectRPCName+" API deadline exceeded", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeDeadlineExceeded(err.Error()) - code = codes.DeadlineExceeded - case errors.Is(err, errors.ErrTargetNotFound): - err = status.WrapWithInternal( - vald.GetObjectRPCName+" API target not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - code = codes.Internal - case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): - err = status.WrapWithInternal( - vald.GetObjectRPCName+" API connection not found", err, reqInfo, resInfo, - ) - attrs = trace.StatusCodeInternal(err.Error()) - code = codes.Internal - default: - var ( - st *status.Status - msg string - ) - st, msg, err = status.ParseError(err, codes.Internal, - "failed to parse "+vald.GetObjectRPCName+" gRPC error response", reqInfo, resInfo, - ) - attrs = trace.FromGRPCStatus(st.Code(), msg) - code = st.Code() - } - log.Warn(err) - if span != nil { - span.RecordError(err) - span.SetAttributes(attrs...) - span.SetStatus(trace.StatusError, err.Error()) - } - if code == codes.NotFound { - return nil - } - emu.Lock() - errs = errors.Join(errs, err) - emu.Unlock() - return err - } - vecs.Store(target, vec) - return nil + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + vec, err = vc.GetObject(ctx, req, copts...) + return vec, err }) if err != nil { - if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.GetObjectRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.GetObjectRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.GetObjectRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): err = status.WrapWithInternal( - vald.GetObjectRPCName+" API connection not found", err, - &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.GetObjectRPCName + ".BroadCast", - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, + vald.GetObjectRPCName+" API target not found", err, reqInfo, resInfo, ) - log.Warn(err) - if span != nil { - span.RecordError(err) - span.SetAttributes(trace.StatusCodeInternal(err.Error())...) - span.SetStatus(trace.StatusError, err.Error()) - } - return nil, err + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.GetObjectRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.GetObjectRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) } - errs = errors.Join(errs, err) - } - if errs != nil { - st, msg, err := status.ParseError(errs, codes.Internal, - "failed to parse "+vald.GetObjectRPCName+" gRPC error response", - &errdetails.RequestInfo{ - RequestId: req.GetId().GetId(), - ServingData: errdetails.Serialize(req), - }, - &errdetails.ResourceInfo{ - ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.GetObjectRPCName + "." + "BroadCast", - ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), - }, - ) log.Warn(err) if span != nil { span.RecordError(err) - span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetAttributes(attrs...) span.SetStatus(trace.StatusError, err.Error()) } return nil, err } - return vecs, nil + return vec, nil } func (s *server) StreamGetObject(stream vald.Object_StreamGetObjectServer) (err error) { @@ -3359,3 +3027,8 @@ func (s *server) StreamGetObject(stream vald.Object_StreamGetObjectServer) (err } return nil } + +type errorState struct { + err error + code codes.Code +} diff --git a/pkg/gateway/mirror/handler/grpc/handler_test.go b/pkg/gateway/mirror/handler/grpc/handler_test.go index 84968025c4..29d97f9664 100644 --- a/pkg/gateway/mirror/handler/grpc/handler_test.go +++ b/pkg/gateway/mirror/handler/grpc/handler_test.go @@ -16,7 +16,6 @@ package grpc import ( "context" "reflect" - "sync/atomic" "testing" "github.com/vdaas/vald/apis/grpc/v1/payload" @@ -25,7 +24,6 @@ import ( "github.com/vdaas/vald/internal/net/grpc" "github.com/vdaas/vald/internal/net/grpc/codes" "github.com/vdaas/vald/internal/net/grpc/status" - "github.com/vdaas/vald/internal/sync" "github.com/vdaas/vald/internal/sync/errgroup" "github.com/vdaas/vald/internal/test/data/vector" "github.com/vdaas/vald/internal/test/goleak" @@ -66,7 +64,11 @@ func Test_server_Insert(t *testing.T) { } defaultCheckFunc := func(w want, gotCe *payload.Object_Location, err error) error { if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + gotSt, gotOk := status.FromError(err) + wantSt, wantOk := status.FromError(w.err) + if gotOk != wantOk || gotSt.Code() != wantSt.Code() { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } } if !reflect.DeepEqual(gotCe, w.wantCe) { return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotCe, w.wantCe) @@ -83,24 +85,23 @@ func Test_server_Insert(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } + targets := []string{ + "vald-01", "vald-02", + } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ + targets[0]: &mockClient{ InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, }, - "vald-lb-gateway-01": &mockClient{ + targets[1]: &mockClient{ InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, }, } - wantLoc := &payload.Object_Location{ - Uuid: uuid, - Ips: []string{"127.0.0.1", "127.0.0.1"}, - } return test{ - name: "success insert with new ID", + name: "Success: insert with new ID", args: args{ ctx: egctx, req: &payload.Insert_Request{ @@ -118,15 +119,18 @@ func Test_server_Insert(t *testing.T) { return "" }, BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - wantCe: wantLoc, + wantCe: &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1", "127.0.0.1"}, + }, }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -143,23 +147,103 @@ func Test_server_Insert(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } + targets := []string{ + "vald-01", "vald-02", + } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ + targets[0]: &mockClient{ InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return loc, nil + return &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + }, nil }, - RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + }, + targets[1]: &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, }, - "vald-lb-gateway-01": &mockClient{ + } + return test{ + name: "Success: when the last status codes are (OK, OK) after updating the target that returned AlreadyExists", + args: args{ + ctx: egctx, + req: &payload.Insert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultInsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, target := range targets { + if c, ok := cmap[target]; !ok { + return errors.ErrTargetNotFound + } else { + f(ctx, target, c) + } + } + return nil + }, + }, + }, + want: want{ + wantCe: &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1", "127.0.0.1"}, + }, + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + targets := []string{ + "vald-01", "vald-02", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + }, nil + }, + }, + targets[1]: &mockClient{ InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) }, }, } return test{ - name: "fail insert with new ID but remove rollback success", + name: "Success: when the last status codes are (OK, AlreadyExists) after updating the target that returned AlreadyExists", args: args{ ctx: egctx, req: &payload.Insert_Request{ @@ -176,25 +260,84 @@ func Test_server_Insert(t *testing.T) { FromForwardedContextFunc: func(_ context.Context) string { return "" }, - BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") + for _, target := range targets { + if c, ok := cmap[target]; !ok { + return errors.New("target not found") + } else { + f(ctx, target, c) + } } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + return nil + }, + }, + }, + want: want{ + wantCe: &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + }, + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + targets := []string{ + "vald-01", "vald-02", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) + }, + }, + targets[1]: &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) + }, + }, + } + return test{ + name: "Fail: when the status codes are (AlreadyExists, AlreadyExists)", + args: args{ + ctx: egctx, + req: &payload.Insert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultInsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + err: status.Error(codes.AlreadyExists, vald.InsertRPCName+" API target same vector already exists"), }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -211,23 +354,78 @@ func Test_server_Insert(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } + targets := []string{ + "vald-01", "vald-02", + } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ + targets[0]: &mockClient{ InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, - RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + }, + targets[1]: &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + } + return test{ + name: "Fail: when the status codes are (OK, Internal)", + args: args{ + ctx: egctx, + req: &payload.Insert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultInsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + targets := []string{ + "vald-01", "vald-02", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) }, }, - "vald-lb-gateway-01": &mockClient{ + targets[1]: &mockClient{ InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) }, }, } return test{ - name: "fail insert with new ID and fail remove rollback", + name: "Fail: when the status codes are (Internal, Internal)", args: args{ ctx: egctx, req: &payload.Insert_Request{ @@ -244,18 +442,83 @@ func Test_server_Insert(t *testing.T) { FromForwardedContextFunc: func(_ context.Context) string { return "" }, - BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, - DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.Join( + status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()), + status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + ).Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + targets := []string{ + "vald-01", "vald-02", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + }, nil + }, + }, + targets[1]: &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + } + return test{ + name: "Fail: when the last status codes are (OK, Internal) after updating the target that returned AlreadyExists", + args: args{ + ctx: egctx, + req: &payload.Insert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultInsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, target := range targets { + if c, ok := cmap[target]; !ok { + return errors.New("target not found") + } else { + f(ctx, target, c) + } } return nil }, @@ -340,7 +603,11 @@ func Test_server_Update(t *testing.T) { } defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + gotSt, gotOk := status.FromError(err) + wantSt, wantOk := status.FromError(w.err) + if gotOk != wantOk || gotSt.Code() != wantSt.Code() { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } } if !reflect.DeepEqual(gotLoc, w.wantLoc) { return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) @@ -357,30 +624,23 @@ func Test_server_Update(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } + targets := []string{ + "vald-01", "vald-02", + } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[0]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[1]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, }, } - wantLoc := &payload.Object_Location{ - Uuid: uuid, - Ips: []string{"127.0.0.1", "127.0.0.1"}, - } return test{ - name: "success update with new ID", + name: "Success: update with new ID", args: args{ ctx: egctx, req: &payload.Update_Request{ @@ -398,15 +658,18 @@ func Test_server_Update(t *testing.T) { return "" }, BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - wantLoc: wantLoc, + wantLoc: &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1", "127.0.0.1"}, + }, }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -423,29 +686,23 @@ func Test_server_Update(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } + targets := []string{ + "vald-01", "vald-02", + } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[0]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, - RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return loc, nil - }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[1]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) }, }, } return test{ - name: "fail update with new ID but remove rollback success", + name: "Success: when the status codes are (AlreadyExists, OK)", args: args{ ctx: egctx, req: &payload.Update_Request{ @@ -462,25 +719,19 @@ func Test_server_Update(t *testing.T) { FromForwardedContextFunc: func(_ context.Context) string { return "" }, - BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) - } - return nil - }, - DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") - } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + wantLoc: &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + }, }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -497,34 +748,38 @@ func Test_server_Update(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } - ovec := &payload.Object_Vector{ - Id: uuid, - Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + targets := []string{ + "vald-01", "vald-02", "vald-03", } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return ovec, nil - }, + targets[0]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + targets[1]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) }, + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + targets[2]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + return loc, nil }, }, } return test{ - name: "fail update with new ID but update rollback success", + name: "Success: when the last status codes are (OK, OK, OK) after inserting the target that returned NotFound", args: args{ ctx: egctx, req: &payload.Update_Request{ - Vector: ovec, + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, Config: defaultUpdateConfig, }, }, @@ -535,24 +790,30 @@ func Test_server_Update(t *testing.T) { return "" }, BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") - } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + for _, target := range targets { + if c, ok := cmap[target]; !ok { + return errors.ErrTargetNotFound + } else { + f(ctx, target, c) + } } return nil }, }, }, want: want{ - err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + wantLoc: &payload.Object_Location{ + Uuid: uuid, + Ips: []string{ + "127.0.0.1", "127.0.0.1", "127.0.0.1", + }, + }, }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -569,29 +830,31 @@ func Test_server_Update(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } + targets := []string{ + "vald-01", "vald-02", "vald-03", + } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[0]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, - RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) - }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(ctx context.Context, in *payload.Object_VectorRequest, opts ...grpc.CallOption) (*payload.Object_Vector, error) { + targets[1]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) }, + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + targets[2]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) }, }, } return test{ - name: "fail update with new ID and fail remove rollback", + name: "Success: when the last status codes are (OK, OK, AlreadyExists) after inserting the target that returned NotFound", args: args{ ctx: egctx, req: &payload.Update_Request{ @@ -608,25 +871,86 @@ func Test_server_Update(t *testing.T) { FromForwardedContextFunc: func(_ context.Context) string { return "" }, - BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") + for _, target := range targets { + if c, ok := cmap[target]; !ok { + return errors.ErrTargetNotFound + } else { + f(ctx, target, c) + } } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + return nil + }, + }, + }, + want: want{ + wantLoc: &payload.Object_Location{ + Uuid: uuid, + Ips: []string{ + "127.0.0.1", "127.0.0.1", + }, + }, + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + targets := []string{ + "vald-01", "vald-02", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + }, + targets[1]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + }, + } + return test{ + name: "Fail: when the status codes are (NotFound, NotFound)", + args: args{ + ctx: egctx, + req: &payload.Update_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpdateConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + err: status.Error(codes.NotFound, vald.UpdateRPCName+" API id "+uuid+" not found"), }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -643,38 +967,151 @@ func Test_server_Update(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } - ovec := &payload.Object_Vector{ - Id: uuid, - Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + targets := []string{ + "vald-01", "vald-02", } - var cnt uint32 cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return ovec, nil + targets[0]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + targets[1]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + } + return test{ + name: "Fail: when the status codes are (Internal, OK)", + args: args{ + ctx: egctx, + req: &payload.Update_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpdateConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + targets := []string{ + "vald-01", "vald-02", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - if atomic.AddUint32(&cnt, 1) == 1 { - return loc, nil - } return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + targets[1]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) + }, + }, + } + return test{ + name: "Fail: when the status codes are (Internal, AlreadyExists)", + args: args{ + ctx: egctx, + req: &payload.Update_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpdateConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.Join( + status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()), + ).Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + targets := []string{ + "vald-01", "vald-02", "vald-03", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) + }, + }, + targets[1]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) }, + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) + }, + }, + targets[2]: &mockClient{ UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) }, }, } return test{ - name: "fail update with new ID and fail update rollback", + name: "Fail: when the last status codes are (AlreadyExists, AlreadyExists, AlreadyExists) after inserting the target that returned NotFound", args: args{ ctx: egctx, req: &payload.Update_Request{ - Vector: ovec, + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, Config: defaultUpdateConfig, }, }, @@ -685,17 +1122,95 @@ func Test_server_Update(t *testing.T) { return "" }, BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, target := range targets { + if c, ok := cmap[target]; !ok { + return errors.ErrTargetNotFound + } else { + f(ctx, target, c) + } } return nil }, - DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") + }, + }, + want: want{ + err: status.Error(codes.AlreadyExists, vald.InsertRPCName+" for "+vald.UpdateRPCName+" API target same vector already exists"), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + targets := []string{ + "vald-01", "vald-02", "vald-03", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + targets[1]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + targets[2]: &mockClient{ + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + } + return test{ + name: "Fail: when the last status codes are (OK, OK, Internal) after inserting the target that returned NotFound", + args: args{ + ctx: egctx, + req: &payload.Update_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpdateConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, target := range targets { + if c, ok := cmap[target]; !ok { + return errors.New("target not found") + } else { + f(ctx, target, c) + } } return nil }, @@ -780,7 +1295,11 @@ func Test_server_Upsert(t *testing.T) { } defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + gotSt, gotOk := status.FromError(err) + wantSt, wantOk := status.FromError(w.err) + if gotOk != wantOk || gotSt.Code() != wantSt.Code() { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } } if !reflect.DeepEqual(gotLoc, w.wantLoc) { return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) @@ -797,30 +1316,23 @@ func Test_server_Upsert(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } + targets := []string{ + "vald-01", "vald-02", + } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[0]: &mockClient{ UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[1]: &mockClient{ UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, }, } - wantLoc := &payload.Object_Location{ - Uuid: uuid, - Ips: []string{"127.0.0.1", "127.0.0.1"}, - } return test{ - name: "success upsert with new ID", + name: "Success: upsert with new ID", args: args{ ctx: egctx, req: &payload.Upsert_Request{ @@ -838,15 +1350,18 @@ func Test_server_Upsert(t *testing.T) { return "" }, BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - wantLoc: wantLoc, + wantLoc: &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1", "127.0.0.1"}, + }, }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -863,29 +1378,23 @@ func Test_server_Upsert(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } + targets := []string{ + "vald-01", "vald-02", + } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[0]: &mockClient{ UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, - RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return loc, nil - }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[1]: &mockClient{ UpsertFunc: func(ctx context.Context, in *payload.Upsert_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) }, }, } return test{ - name: "fail upsert with new ID but remove rollback success", + name: "Success: when the status codes are (AlreadyExists, OK)", args: args{ ctx: egctx, req: &payload.Upsert_Request{ @@ -903,24 +1412,15 @@ func Test_server_Upsert(t *testing.T) { return "" }, BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) - } - return nil - }, - DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") - } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + wantLoc: loc, }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -933,41 +1433,30 @@ func Test_server_Upsert(t *testing.T) { eg, egctx := errgroup.New(ctx) uuid := "test" - loc := &payload.Object_Location{ - Uuid: uuid, - Ips: []string{"127.0.0.1"}, - } - ovec := &payload.Object_Vector{ - Id: uuid, - Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + targets := []string{ + "vald-01", "vald-02", } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return ovec, nil - }, + targets[0]: &mockClient{ UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return loc, nil - }, - UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return loc, nil + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, - UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + targets[1]: &mockClient{ + UpsertFunc: func(ctx context.Context, in *payload.Upsert_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.AlreadyExists, errors.ErrMetaDataAlreadyExists(uuid).Error()) }, }, } return test{ - name: "fail upsert with new ID but update rollback success", + name: "Fail: when the status codes are (AlreadyExists, AlreadyExists)", args: args{ ctx: egctx, req: &payload.Upsert_Request{ - Vector: ovec, + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, Config: defaultUpsertConfig, }, }, @@ -977,25 +1466,16 @@ func Test_server_Upsert(t *testing.T) { FromForwardedContextFunc: func(_ context.Context) string { return "" }, - BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) - } - return nil - }, - DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") - } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + err: status.Error(codes.AlreadyExists, vald.UpsertRPCName+" API target same vector already exists"), }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -1012,29 +1492,23 @@ func Test_server_Upsert(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } + targets := []string{ + "vald-01", "vald-02", + } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[0]: &mockClient{ UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, - RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) - }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(ctx context.Context, in *payload.Object_VectorRequest, opts ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, - UpsertFunc: func(ctx context.Context, in *payload.Upsert_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + targets[1]: &mockClient{ + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) }, }, } return test{ - name: "fail upsert with new ID and fail remove rollback", + name: "Fail: when the status codes are (Internal, OK)", args: args{ ctx: egctx, req: &payload.Upsert_Request{ @@ -1051,18 +1525,9 @@ func Test_server_Upsert(t *testing.T) { FromForwardedContextFunc: func(_ context.Context) string { return "" }, - BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) - } - return nil - }, - DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") - } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, @@ -1082,41 +1547,30 @@ func Test_server_Upsert(t *testing.T) { eg, egctx := errgroup.New(ctx) uuid := "test" - loc := &payload.Object_Location{ - Uuid: uuid, - Ips: []string{"127.0.0.1"}, - } - ovec := &payload.Object_Vector{ - Id: uuid, - Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + targets := []string{ + "vald-01", "vald-02", } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return ovec, nil - }, + targets[0]: &mockClient{ UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return loc, nil - }, - UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[1]: &mockClient{ UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) }, }, } return test{ - name: "fail upsert with new ID and fail update rollback", + name: "Fail: upsert when the status codes are (Internal, Internal)", args: args{ ctx: egctx, req: &payload.Upsert_Request{ - Vector: ovec, + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, Config: defaultUpsertConfig, }, }, @@ -1127,24 +1581,18 @@ func Test_server_Upsert(t *testing.T) { return "" }, BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) - } - return nil - }, - DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") - } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + err: status.Error(codes.Internal, errors.Join( + status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()), + ).Error()), }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -1222,7 +1670,11 @@ func Test_server_Remove(t *testing.T) { } defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + gotSt, gotOk := status.FromError(err) + wantSt, wantOk := status.FromError(w.err) + if gotOk != wantOk || gotSt.Code() != wantSt.Code() { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } } if !reflect.DeepEqual(gotLoc, w.wantLoc) { return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) @@ -1239,34 +1691,23 @@ func Test_server_Remove(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } - ovec := &payload.Object_Vector{ - Id: uuid, - Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + targets := []string{ + "vald-01", "vald-02", } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return ovec, nil - }, + targets[0]: &mockClient{ RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[1]: &mockClient{ RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + return loc, nil }, }, } - wantLoc := &payload.Object_Location{ - Uuid: uuid, - Ips: []string{"127.0.0.1"}, - } return test{ - name: "success remove with existing ID", + name: "Success: remove with existing ID", args: args{ ctx: egctx, req: &payload.Remove_Request{ @@ -1283,15 +1724,18 @@ func Test_server_Remove(t *testing.T) { return "" }, BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - wantLoc: wantLoc, + wantLoc: &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1", "127.0.0.1"}, + }, }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -1308,33 +1752,23 @@ func Test_server_Remove(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } - ovec := &payload.Object_Vector{ - Id: uuid, - Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + targets := []string{ + "vald-01", "vald-02", } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return ovec, nil - }, + targets[0]: &mockClient{ RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, - UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return loc, nil - }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) - }, + targets[1]: &mockClient{ RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) }, }, } return test{ - name: "fail remove with existing ID but upsert rollback success", + name: "Success: when the status codes are (NotFound, OK)", args: args{ ctx: egctx, req: &payload.Remove_Request{ @@ -1350,25 +1784,16 @@ func Test_server_Remove(t *testing.T) { FromForwardedContextFunc: func(_ context.Context) string { return "" }, - BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) - } - return nil - }, - DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") - } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + wantLoc: loc, }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -1385,33 +1810,77 @@ func Test_server_Remove(t *testing.T) { Uuid: uuid, Ips: []string{"127.0.0.1"}, } - ovec := &payload.Object_Vector{ - Id: uuid, - Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + targets := []string{ + "vald-01", "vald-02", } cmap := map[string]vald.ClientWithMirror{ - "vald-mirror-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return ovec, nil - }, + targets[0]: &mockClient{ RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return loc, nil }, - UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + }, + targets[1]: &mockClient{ + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) }, }, - "vald-lb-gateway-01": &mockClient{ - GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { - return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + } + return test{ + name: "Fail: when the status codes are (Internal, OK)", + args: args{ + ctx: egctx, + req: &payload.Remove_Request{ + Id: &payload.Object_ID{ + Id: uuid, + }, + Config: defaultRemoveConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + targets := []string{ + "vald-01", "vald-02", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) }, + }, + targets[1]: &mockClient{ RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { - return loc, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) }, }, } return test{ - name: "fail remove with existing ID and fail upsert rollback", + name: "Fail: when the status codes are (Internal, Internal)", args: args{ ctx: egctx, req: &payload.Remove_Request{ @@ -1428,24 +1897,72 @@ func Test_server_Remove(t *testing.T) { return "" }, BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { - for tgt, c := range cmap { - f(ctx, tgt, c) + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, - DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, _ ...grpc.CallOption) error) error { - if len(targets) != 1 { - return errors.New("invalid target") - } - if c, ok := cmap[targets[0]]; ok { - f(ctx, targets[0], c) + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.Join( + status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()), + ).Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + targets := []string{ + "vald-01", "vald-02", + } + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.NotFound, errors.ErrIndexNotFound.Error()) + }, + }, + targets[1]: &mockClient{ + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.NotFound, errors.ErrIndexNotFound.Error()) + }, + }, + } + return test{ + name: "Fail: when the status codes are (NotFound, NotFound)", + args: args{ + ctx: egctx, + req: &payload.Remove_Request{ + Id: &payload.Object_ID{ + Id: uuid, + }, + Config: defaultRemoveConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) } return nil }, }, }, want: want{ - err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + err: status.Error(codes.NotFound, vald.RemoveRPCName+" API id "+uuid+" not found"), }, afterFunc: func(t *testing.T, args args) { t.Helper() @@ -1489,593 +2006,286 @@ func Test_server_Remove(t *testing.T) { } } -// NOT IMPLEMENTED BELOW - -func TestNew(t *testing.T) { +func Test_server_RemoveByTimestamp(t *testing.T) { + defaultRemoveByTimestampReq := &payload.Remove_TimestampRequest{ + Timestamps: []*payload.Remove_Timestamp{}, + } type args struct { - opts []Option + ctx context.Context + req *payload.Remove_TimestampRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - want vald.ServerWithMirror - err error + wantLocs *payload.Object_Locations + err error } type test struct { name string args args + fields fields want want - checkFunc func(want, vald.Server, error) error + checkFunc func(want, *payload.Object_Locations, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, got vald.Server, err error) error { + defaultCheckFunc := func(w want, gotLocs *payload.Object_Locations, err error) error { if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + gotSt, gotOk := status.FromError(err) + wantSt, wantOk := status.FromError(w.err) + if gotOk != wantOk || gotSt.Code() != wantSt.Code() { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } } - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) + if !reflect.DeepEqual(gotLocs, w.wantLocs) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLocs, w.wantLocs) } return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - opts:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - opts:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - } - }(), - */ - } + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc(tt, test.args) + loc := &payload.Object_Location{ + Uuid: "test", + Ips: []string{ + "127.0.0.1", + }, } - if test.afterFunc != nil { - defer test.afterFunc(tt, test.args) + loc2 := &payload.Object_Location{ + Uuid: "test02", + Ips: []string{ + "127.0.0.1", + }, } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc + targets := []string{ + "vald-01", "vald-02", } - - got, err := New(test.args.opts...) - if err := checkFunc(test.want, got, err); err != nil { - tt.Errorf("error = %v", err) + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + RemoveByTimestampFunc: func(_ context.Context, _ *payload.Remove_TimestampRequest, _ ...grpc.CallOption) (*payload.Object_Locations, error) { + return &payload.Object_Locations{ + Locations: []*payload.Object_Location{ + loc, + }, + }, nil + }, + }, + targets[1]: &mockClient{ + RemoveByTimestampFunc: func(_ context.Context, _ *payload.Remove_TimestampRequest, _ ...grpc.CallOption) (*payload.Object_Locations, error) { + return &payload.Object_Locations{ + Locations: []*payload.Object_Location{ + loc2, + }, + }, nil + }, + }, } - }) - } -} - -func Test_server_Register(t *testing.T) { - type args struct { - ctx context.Context - req *payload.Mirror_Targets - } - type fields struct { - eg errgroup.Group - gateway service.Gateway - mirror service.Mirror - vAddr string - streamConcurrency int - name string - ip string - UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror - } - type want struct { - want *payload.Mirror_Targets - err error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, *payload.Mirror_Targets, error) error - beforeFunc func(*testing.T, args) - afterFunc func(*testing.T, args) - } - defaultCheckFunc := func(w want, got *payload.Mirror_Targets, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) - } - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ctx:nil, - req:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ctx:nil, - req:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - } - }(), - */ - } + return test{ + name: "Success: removeByTimestamp", + args: args{ + ctx: egctx, + req: defaultRemoveByTimestampReq, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, + }, + }, + want: want{ + wantLocs: &payload.Object_Locations{ + Locations: []*payload.Object_Location{ + loc, loc2, + }, + }, + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc(tt, test.args) + loc := &payload.Object_Location{ + Uuid: "test", + Ips: []string{ + "127.0.0.1", + }, } - if test.afterFunc != nil { - defer test.afterFunc(tt, test.args) + targets := []string{ + "vald-01", "vald-02", } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + RemoveByTimestampFunc: func(_ context.Context, _ *payload.Remove_TimestampRequest, _ ...grpc.CallOption) (*payload.Object_Locations, error) { + return &payload.Object_Locations{ + Locations: []*payload.Object_Location{ + loc, + }, + }, nil + }, + }, + targets[1]: &mockClient{ + RemoveByTimestampFunc: func(_ context.Context, _ *payload.Remove_TimestampRequest, _ ...grpc.CallOption) (*payload.Object_Locations, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound("test02").Error()) + }, + }, } - s := &server{ - eg: test.fields.eg, - gateway: test.fields.gateway, - mirror: test.fields.mirror, - vAddr: test.fields.vAddr, - streamConcurrency: test.fields.streamConcurrency, - name: test.fields.name, - ip: test.fields.ip, - UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + return test{ + name: "Success: when the status codes are (NotFound, OK)", + args: args{ + ctx: egctx, + req: defaultRemoveByTimestampReq, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, + }, + }, + want: want{ + wantLocs: &payload.Object_Locations{ + Locations: []*payload.Object_Location{ + loc, + }, + }, + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) - got, err := s.Register(test.args.ctx, test.args.req) - if err := checkFunc(test.want, got, err); err != nil { - tt.Errorf("error = %v", err) + targets := []string{ + "vald-01", "vald-02", } - }) - } -} + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + RemoveByTimestampFunc: func(_ context.Context, _ *payload.Remove_TimestampRequest, _ ...grpc.CallOption) (*payload.Object_Locations, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + targets[1]: &mockClient{ + RemoveByTimestampFunc: func(_ context.Context, _ *payload.Remove_TimestampRequest, _ ...grpc.CallOption) (*payload.Object_Locations, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + }, + }, + } + return test{ + name: "Fail: when the status codes are (Internal, Internal)", + args: args{ + ctx: egctx, + req: defaultRemoveByTimestampReq, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.Join( + status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()), + ).Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) -func Test_server_Advertise(t *testing.T) { - type args struct { - ctx context.Context - req *payload.Mirror_Targets - } - type fields struct { - eg errgroup.Group - gateway service.Gateway - mirror service.Mirror - vAddr string - streamConcurrency int - name string - ip string - UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror - } - type want struct { - wantRes *payload.Mirror_Targets - err error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, *payload.Mirror_Targets, error) error - beforeFunc func(*testing.T, args) - afterFunc func(*testing.T, args) - } - defaultCheckFunc := func(w want, gotRes *payload.Mirror_Targets, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) - } - if !reflect.DeepEqual(gotRes, w.wantRes) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ctx:nil, - req:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ctx:nil, - req:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - } - }(), - */ - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc(tt, test.args) + uuid1 := "test01" + uuid2 := "test02" + targets := []string{ + "vald-01", "vald-02", } - if test.afterFunc != nil { - defer test.afterFunc(tt, test.args) - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - s := &server{ - eg: test.fields.eg, - gateway: test.fields.gateway, - mirror: test.fields.mirror, - vAddr: test.fields.vAddr, - streamConcurrency: test.fields.streamConcurrency, - name: test.fields.name, - ip: test.fields.ip, - UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, - } - - gotRes, err := s.Advertise(test.args.ctx, test.args.req) - if err := checkFunc(test.want, gotRes, err); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} - -func Test_server_Exists(t *testing.T) { - type args struct { - ctx context.Context - meta *payload.Object_ID - } - type fields struct { - eg errgroup.Group - gateway service.Gateway - mirror service.Mirror - vAddr string - streamConcurrency int - name string - ip string - UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror - } - type want struct { - wantId *payload.Object_ID - err error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, *payload.Object_ID, error) error - beforeFunc func(*testing.T, args) - afterFunc func(*testing.T, args) - } - defaultCheckFunc := func(w want, gotId *payload.Object_ID, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) - } - if !reflect.DeepEqual(gotId, w.wantId) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotId, w.wantId) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ctx:nil, - meta:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ctx:nil, - meta:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - } - }(), - */ - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc(tt, test.args) - } - if test.afterFunc != nil { - defer test.afterFunc(tt, test.args) - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - s := &server{ - eg: test.fields.eg, - gateway: test.fields.gateway, - mirror: test.fields.mirror, - vAddr: test.fields.vAddr, - streamConcurrency: test.fields.streamConcurrency, - name: test.fields.name, - ip: test.fields.ip, - UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, - } - - gotId, err := s.Exists(test.args.ctx, test.args.meta) - if err := checkFunc(test.want, gotId, err); err != nil { - tt.Errorf("error = %v", err) + cmap := map[string]vald.ClientWithMirror{ + targets[0]: &mockClient{ + RemoveByTimestampFunc: func(_ context.Context, _ *payload.Remove_TimestampRequest, _ ...grpc.CallOption) (*payload.Object_Locations, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid1).Error()) + }, + }, + targets[1]: &mockClient{ + RemoveByTimestampFunc: func(_ context.Context, _ *payload.Remove_TimestampRequest, _ ...grpc.CallOption) (*payload.Object_Locations, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid2).Error()) + }, + }, } - }) - } -} - -func Test_server_Search(t *testing.T) { - type args struct { - ctx context.Context - req *payload.Search_Request - } - type fields struct { - eg errgroup.Group - gateway service.Gateway - mirror service.Mirror - vAddr string - streamConcurrency int - name string - ip string - UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror - } - type want struct { - wantRes *payload.Search_Response - err error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, *payload.Search_Response, error) error - beforeFunc func(*testing.T, args) - afterFunc func(*testing.T, args) - } - defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) - } - if !reflect.DeepEqual(gotRes, w.wantRes) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ctx:nil, - req:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ctx:nil, - req:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - beforeFunc: func(t *testing.T, args args) { - t.Helper() - }, - afterFunc: func(t *testing.T, args args) { - t.Helper() - }, - } - }(), - */ + return test{ + name: "Fail: when the status codes are (NotFound, NotFound)", + args: args{ + ctx: egctx, + req: defaultRemoveByTimestampReq, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for _, tgt := range targets { + f(ctx, tgt, cmap[tgt]) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.NotFound, vald.RemoveByTimestampRPCName+" API target not found"), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), } for _, tc := range tests { @@ -2104,48 +2314,38 @@ func Test_server_Search(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.Search(test.args.ctx, test.args.req) - if err := checkFunc(test.want, gotRes, err); err != nil { + gotLocs, err := s.RemoveByTimestamp(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotLocs, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_SearchByID(t *testing.T) { +// NOT IMPLEMENTED BELOW + +func TestNew(t *testing.T) { type args struct { - ctx context.Context - req *payload.Search_IDRequest - } - type fields struct { - eg errgroup.Group - gateway service.Gateway - mirror service.Mirror - vAddr string - streamConcurrency int - name string - ip string - UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + opts []Option } type want struct { - wantRes *payload.Search_Response - err error + want vald.ServerWithMirror + err error } type test struct { name string args args - fields fields want want - checkFunc func(want, *payload.Search_Response, error) error + checkFunc func(want, vald.Server, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { + defaultCheckFunc := func(w want, got vald.Server, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotRes, w.wantRes) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + if !reflect.DeepEqual(got, w.want) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) } return nil } @@ -2155,18 +2355,7 @@ func Test_server_SearchByID(t *testing.T) { { name: "test_case_1", args: args { - ctx:nil, - req:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, + opts:nil, }, want: want{}, checkFunc: defaultCheckFunc, @@ -2185,18 +2374,7 @@ func Test_server_SearchByID(t *testing.T) { return test { name: "test_case_2", args: args { - ctx:nil, - req:nil, - }, - fields: fields { - eg:nil, - gateway:nil, - mirror:nil, - vAddr:"", - streamConcurrency:0, - name:"", - ip:"", - UnimplementedValdServerWithMirror:nil, + opts:nil, }, want: want{}, checkFunc: defaultCheckFunc, @@ -2226,28 +2404,19 @@ func Test_server_SearchByID(t *testing.T) { if test.checkFunc == nil { checkFunc = defaultCheckFunc } - s := &server{ - eg: test.fields.eg, - gateway: test.fields.gateway, - mirror: test.fields.mirror, - vAddr: test.fields.vAddr, - streamConcurrency: test.fields.streamConcurrency, - name: test.fields.name, - ip: test.fields.ip, - UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, - } - gotRes, err := s.SearchByID(test.args.ctx, test.args.req) - if err := checkFunc(test.want, gotRes, err); err != nil { + got, err := New(test.args.opts...) + if err := checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_StreamSearch(t *testing.T) { +func Test_server_Register(t *testing.T) { type args struct { - stream vald.Search_StreamSearchServer + ctx context.Context + req *payload.Mirror_Targets } type fields struct { eg errgroup.Group @@ -2260,21 +2429,25 @@ func Test_server_StreamSearch(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - err error + want *payload.Mirror_Targets + err error } type test struct { name string args args fields fields want want - checkFunc func(want, error) error + checkFunc func(want, *payload.Mirror_Targets, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, err error) error { + defaultCheckFunc := func(w want, got *payload.Mirror_Targets, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } + if !reflect.DeepEqual(got, w.want) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) + } return nil } tests := []test{ @@ -2283,7 +2456,8 @@ func Test_server_StreamSearch(t *testing.T) { { name: "test_case_1", args: args { - stream:nil, + ctx:nil, + req:nil, }, fields: fields { eg:nil, @@ -2312,7 +2486,8 @@ func Test_server_StreamSearch(t *testing.T) { return test { name: "test_case_2", args: args { - stream:nil, + ctx:nil, + req:nil, }, fields: fields { eg:nil, @@ -2363,17 +2538,18 @@ func Test_server_StreamSearch(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - err := s.StreamSearch(test.args.stream) - if err := checkFunc(test.want, err); err != nil { + got, err := s.Register(test.args.ctx, test.args.req) + if err := checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_StreamSearchByID(t *testing.T) { +func Test_server_Advertise(t *testing.T) { type args struct { - stream vald.Search_StreamSearchByIDServer + ctx context.Context + req *payload.Mirror_Targets } type fields struct { eg errgroup.Group @@ -2386,21 +2562,25 @@ func Test_server_StreamSearchByID(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - err error + wantRes *payload.Mirror_Targets + err error } type test struct { name string args args fields fields want want - checkFunc func(want, error) error + checkFunc func(want, *payload.Mirror_Targets, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Mirror_Targets, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } return nil } tests := []test{ @@ -2409,7 +2589,8 @@ func Test_server_StreamSearchByID(t *testing.T) { { name: "test_case_1", args: args { - stream:nil, + ctx:nil, + req:nil, }, fields: fields { eg:nil, @@ -2438,7 +2619,8 @@ func Test_server_StreamSearchByID(t *testing.T) { return test { name: "test_case_2", args: args { - stream:nil, + ctx:nil, + req:nil, }, fields: fields { eg:nil, @@ -2489,18 +2671,18 @@ func Test_server_StreamSearchByID(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - err := s.StreamSearchByID(test.args.stream) - if err := checkFunc(test.want, err); err != nil { + gotRes, err := s.Advertise(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_MultiSearch(t *testing.T) { +func Test_server_Exists(t *testing.T) { type args struct { - ctx context.Context - req *payload.Search_MultiRequest + ctx context.Context + meta *payload.Object_ID } type fields struct { eg errgroup.Group @@ -2513,24 +2695,24 @@ func Test_server_MultiSearch(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantRes *payload.Search_Responses - err error + wantId *payload.Object_ID + err error } type test struct { name string args args fields fields want want - checkFunc func(want, *payload.Search_Responses, error) error + checkFunc func(want, *payload.Object_ID, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { + defaultCheckFunc := func(w want, gotId *payload.Object_ID, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotRes, w.wantRes) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + if !reflect.DeepEqual(gotId, w.wantId) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotId, w.wantId) } return nil } @@ -2541,7 +2723,7 @@ func Test_server_MultiSearch(t *testing.T) { name: "test_case_1", args: args { ctx:nil, - req:nil, + meta:nil, }, fields: fields { eg:nil, @@ -2571,7 +2753,7 @@ func Test_server_MultiSearch(t *testing.T) { name: "test_case_2", args: args { ctx:nil, - req:nil, + meta:nil, }, fields: fields { eg:nil, @@ -2622,18 +2804,18 @@ func Test_server_MultiSearch(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.MultiSearch(test.args.ctx, test.args.req) - if err := checkFunc(test.want, gotRes, err); err != nil { + gotId, err := s.Exists(test.args.ctx, test.args.meta) + if err := checkFunc(test.want, gotId, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_MultiSearchByID(t *testing.T) { +func Test_server_Search(t *testing.T) { type args struct { ctx context.Context - req *payload.Search_MultiIDRequest + req *payload.Search_Request } type fields struct { eg errgroup.Group @@ -2646,7 +2828,7 @@ func Test_server_MultiSearchByID(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantRes *payload.Search_Responses + wantRes *payload.Search_Response err error } type test struct { @@ -2654,11 +2836,11 @@ func Test_server_MultiSearchByID(t *testing.T) { args args fields fields want want - checkFunc func(want, *payload.Search_Responses, error) error + checkFunc func(want, *payload.Search_Response, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -2755,7 +2937,7 @@ func Test_server_MultiSearchByID(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.MultiSearchByID(test.args.ctx, test.args.req) + gotRes, err := s.Search(test.args.ctx, test.args.req) if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } @@ -2763,10 +2945,10 @@ func Test_server_MultiSearchByID(t *testing.T) { } } -func Test_server_LinearSearch(t *testing.T) { +func Test_server_SearchByID(t *testing.T) { type args struct { ctx context.Context - req *payload.Search_Request + req *payload.Search_IDRequest } type fields struct { eg errgroup.Group @@ -2888,7 +3070,7 @@ func Test_server_LinearSearch(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.LinearSearch(test.args.ctx, test.args.req) + gotRes, err := s.SearchByID(test.args.ctx, test.args.req) if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } @@ -2896,10 +3078,9 @@ func Test_server_LinearSearch(t *testing.T) { } } -func Test_server_LinearSearchByID(t *testing.T) { +func Test_server_StreamSearch(t *testing.T) { type args struct { - ctx context.Context - req *payload.Search_IDRequest + stream vald.Search_StreamSearchServer } type fields struct { eg errgroup.Group @@ -2912,25 +3093,21 @@ func Test_server_LinearSearchByID(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantRes *payload.Search_Response - err error + err error } type test struct { name string args args fields fields want want - checkFunc func(want, *payload.Search_Response, error) error + checkFunc func(want, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { + defaultCheckFunc := func(w want, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotRes, w.wantRes) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) - } return nil } tests := []test{ @@ -2939,8 +3116,7 @@ func Test_server_LinearSearchByID(t *testing.T) { { name: "test_case_1", args: args { - ctx:nil, - req:nil, + stream:nil, }, fields: fields { eg:nil, @@ -2969,8 +3145,7 @@ func Test_server_LinearSearchByID(t *testing.T) { return test { name: "test_case_2", args: args { - ctx:nil, - req:nil, + stream:nil, }, fields: fields { eg:nil, @@ -3021,17 +3196,17 @@ func Test_server_LinearSearchByID(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.LinearSearchByID(test.args.ctx, test.args.req) - if err := checkFunc(test.want, gotRes, err); err != nil { + err := s.StreamSearch(test.args.stream) + if err := checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_StreamLinearSearch(t *testing.T) { +func Test_server_StreamSearchByID(t *testing.T) { type args struct { - stream vald.Search_StreamLinearSearchServer + stream vald.Search_StreamSearchByIDServer } type fields struct { eg errgroup.Group @@ -3147,7 +3322,7 @@ func Test_server_StreamLinearSearch(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - err := s.StreamLinearSearch(test.args.stream) + err := s.StreamSearchByID(test.args.stream) if err := checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } @@ -3155,9 +3330,10 @@ func Test_server_StreamLinearSearch(t *testing.T) { } } -func Test_server_StreamLinearSearchByID(t *testing.T) { +func Test_server_MultiSearch(t *testing.T) { type args struct { - stream vald.Search_StreamLinearSearchByIDServer + ctx context.Context + req *payload.Search_MultiRequest } type fields struct { eg errgroup.Group @@ -3170,21 +3346,25 @@ func Test_server_StreamLinearSearchByID(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - err error + wantRes *payload.Search_Responses + err error } type test struct { name string args args fields fields want want - checkFunc func(want, error) error + checkFunc func(want, *payload.Search_Responses, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } return nil } tests := []test{ @@ -3193,7 +3373,8 @@ func Test_server_StreamLinearSearchByID(t *testing.T) { { name: "test_case_1", args: args { - stream:nil, + ctx:nil, + req:nil, }, fields: fields { eg:nil, @@ -3222,7 +3403,8 @@ func Test_server_StreamLinearSearchByID(t *testing.T) { return test { name: "test_case_2", args: args { - stream:nil, + ctx:nil, + req:nil, }, fields: fields { eg:nil, @@ -3273,18 +3455,18 @@ func Test_server_StreamLinearSearchByID(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - err := s.StreamLinearSearchByID(test.args.stream) - if err := checkFunc(test.want, err); err != nil { + gotRes, err := s.MultiSearch(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_MultiLinearSearch(t *testing.T) { +func Test_server_MultiSearchByID(t *testing.T) { type args struct { ctx context.Context - req *payload.Search_MultiRequest + req *payload.Search_MultiIDRequest } type fields struct { eg errgroup.Group @@ -3406,7 +3588,7 @@ func Test_server_MultiLinearSearch(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.MultiLinearSearch(test.args.ctx, test.args.req) + gotRes, err := s.MultiSearchByID(test.args.ctx, test.args.req) if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } @@ -3414,10 +3596,10 @@ func Test_server_MultiLinearSearch(t *testing.T) { } } -func Test_server_MultiLinearSearchByID(t *testing.T) { +func Test_server_LinearSearch(t *testing.T) { type args struct { ctx context.Context - req *payload.Search_MultiIDRequest + req *payload.Search_Request } type fields struct { eg errgroup.Group @@ -3430,7 +3612,7 @@ func Test_server_MultiLinearSearchByID(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantRes *payload.Search_Responses + wantRes *payload.Search_Response err error } type test struct { @@ -3438,11 +3620,11 @@ func Test_server_MultiLinearSearchByID(t *testing.T) { args args fields fields want want - checkFunc func(want, *payload.Search_Responses, error) error + checkFunc func(want, *payload.Search_Response, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -3539,7 +3721,7 @@ func Test_server_MultiLinearSearchByID(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.MultiLinearSearchByID(test.args.ctx, test.args.req) + gotRes, err := s.LinearSearch(test.args.ctx, test.args.req) if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } @@ -3547,12 +3729,10 @@ func Test_server_MultiLinearSearchByID(t *testing.T) { } } -func Test_server_insert(t *testing.T) { +func Test_server_LinearSearchByID(t *testing.T) { type args struct { - ctx context.Context - client vald.InsertClient - req *payload.Insert_Request - opts []grpc.CallOption + ctx context.Context + req *payload.Search_IDRequest } type fields struct { eg errgroup.Group @@ -3565,7 +3745,7 @@ func Test_server_insert(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantLoc *payload.Object_Location + wantRes *payload.Search_Response err error } type test struct { @@ -3573,16 +3753,16 @@ func Test_server_insert(t *testing.T) { args args fields fields want want - checkFunc func(want, *payload.Object_Location, error) error + checkFunc func(want, *payload.Search_Response, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotLoc, w.wantLoc) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) } return nil } @@ -3593,9 +3773,7 @@ func Test_server_insert(t *testing.T) { name: "test_case_1", args: args { ctx:nil, - client:nil, req:nil, - opts:nil, }, fields: fields { eg:nil, @@ -3625,9 +3803,7 @@ func Test_server_insert(t *testing.T) { name: "test_case_2", args: args { ctx:nil, - client:nil, req:nil, - opts:nil, }, fields: fields { eg:nil, @@ -3678,17 +3854,17 @@ func Test_server_insert(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotLoc, err := s.insert(test.args.ctx, test.args.client, test.args.req, test.args.opts...) - if err := checkFunc(test.want, gotLoc, err); err != nil { + gotRes, err := s.LinearSearchByID(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_StreamInsert(t *testing.T) { +func Test_server_StreamLinearSearch(t *testing.T) { type args struct { - stream vald.Insert_StreamInsertServer + stream vald.Search_StreamLinearSearchServer } type fields struct { eg errgroup.Group @@ -3804,7 +3980,7 @@ func Test_server_StreamInsert(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - err := s.StreamInsert(test.args.stream) + err := s.StreamLinearSearch(test.args.stream) if err := checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } @@ -3812,10 +3988,9 @@ func Test_server_StreamInsert(t *testing.T) { } } -func Test_server_MultiInsert(t *testing.T) { +func Test_server_StreamLinearSearchByID(t *testing.T) { type args struct { - ctx context.Context - reqs *payload.Insert_MultiRequest + stream vald.Search_StreamLinearSearchByIDServer } type fields struct { eg errgroup.Group @@ -3828,25 +4003,21 @@ func Test_server_MultiInsert(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantRes *payload.Object_Locations - err error + err error } type test struct { name string args args fields fields want want - checkFunc func(want, *payload.Object_Locations, error) error + checkFunc func(want, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { + defaultCheckFunc := func(w want, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotRes, w.wantRes) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) - } return nil } tests := []test{ @@ -3855,8 +4026,7 @@ func Test_server_MultiInsert(t *testing.T) { { name: "test_case_1", args: args { - ctx:nil, - reqs:nil, + stream:nil, }, fields: fields { eg:nil, @@ -3885,8 +4055,7 @@ func Test_server_MultiInsert(t *testing.T) { return test { name: "test_case_2", args: args { - ctx:nil, - reqs:nil, + stream:nil, }, fields: fields { eg:nil, @@ -3937,20 +4106,18 @@ func Test_server_MultiInsert(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.MultiInsert(test.args.ctx, test.args.reqs) - if err := checkFunc(test.want, gotRes, err); err != nil { + err := s.StreamLinearSearchByID(test.args.stream) + if err := checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_update(t *testing.T) { +func Test_server_MultiLinearSearch(t *testing.T) { type args struct { - ctx context.Context - client vald.UpdateClient - req *payload.Update_Request - opts []grpc.CallOption + ctx context.Context + req *payload.Search_MultiRequest } type fields struct { eg errgroup.Group @@ -3963,7 +4130,7 @@ func Test_server_update(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantLoc *payload.Object_Location + wantRes *payload.Search_Responses err error } type test struct { @@ -3971,16 +4138,16 @@ func Test_server_update(t *testing.T) { args args fields fields want want - checkFunc func(want, *payload.Object_Location, error) error + checkFunc func(want, *payload.Search_Responses, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotLoc, w.wantLoc) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) } return nil } @@ -3991,9 +4158,7 @@ func Test_server_update(t *testing.T) { name: "test_case_1", args: args { ctx:nil, - client:nil, req:nil, - opts:nil, }, fields: fields { eg:nil, @@ -4023,9 +4188,7 @@ func Test_server_update(t *testing.T) { name: "test_case_2", args: args { ctx:nil, - client:nil, req:nil, - opts:nil, }, fields: fields { eg:nil, @@ -4076,17 +4239,18 @@ func Test_server_update(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotLoc, err := s.update(test.args.ctx, test.args.client, test.args.req, test.args.opts...) - if err := checkFunc(test.want, gotLoc, err); err != nil { + gotRes, err := s.MultiLinearSearch(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_StreamUpdate(t *testing.T) { +func Test_server_MultiLinearSearchByID(t *testing.T) { type args struct { - stream vald.Update_StreamUpdateServer + ctx context.Context + req *payload.Search_MultiIDRequest } type fields struct { eg errgroup.Group @@ -4099,21 +4263,25 @@ func Test_server_StreamUpdate(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - err error + wantRes *payload.Search_Responses + err error } type test struct { name string args args fields fields want want - checkFunc func(want, error) error + checkFunc func(want, *payload.Search_Responses, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } return nil } tests := []test{ @@ -4122,7 +4290,8 @@ func Test_server_StreamUpdate(t *testing.T) { { name: "test_case_1", args: args { - stream:nil, + ctx:nil, + req:nil, }, fields: fields { eg:nil, @@ -4151,7 +4320,8 @@ func Test_server_StreamUpdate(t *testing.T) { return test { name: "test_case_2", args: args { - stream:nil, + ctx:nil, + req:nil, }, fields: fields { eg:nil, @@ -4202,18 +4372,17 @@ func Test_server_StreamUpdate(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - err := s.StreamUpdate(test.args.stream) - if err := checkFunc(test.want, err); err != nil { + gotRes, err := s.MultiLinearSearchByID(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_MultiUpdate(t *testing.T) { +func Test_server_StreamInsert(t *testing.T) { type args struct { - ctx context.Context - reqs *payload.Update_MultiRequest + stream vald.Insert_StreamInsertServer } type fields struct { eg errgroup.Group @@ -4226,25 +4395,21 @@ func Test_server_MultiUpdate(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantRes *payload.Object_Locations - err error + err error } type test struct { name string args args fields fields want want - checkFunc func(want, *payload.Object_Locations, error) error + checkFunc func(want, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { + defaultCheckFunc := func(w want, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotRes, w.wantRes) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) - } return nil } tests := []test{ @@ -4253,8 +4418,7 @@ func Test_server_MultiUpdate(t *testing.T) { { name: "test_case_1", args: args { - ctx:nil, - reqs:nil, + stream:nil, }, fields: fields { eg:nil, @@ -4283,8 +4447,7 @@ func Test_server_MultiUpdate(t *testing.T) { return test { name: "test_case_2", args: args { - ctx:nil, - reqs:nil, + stream:nil, }, fields: fields { eg:nil, @@ -4335,20 +4498,18 @@ func Test_server_MultiUpdate(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.MultiUpdate(test.args.ctx, test.args.reqs) - if err := checkFunc(test.want, gotRes, err); err != nil { + err := s.StreamInsert(test.args.stream) + if err := checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_upsert(t *testing.T) { +func Test_server_MultiInsert(t *testing.T) { type args struct { - ctx context.Context - client vald.UpsertClient - req *payload.Upsert_Request - opts []grpc.CallOption + ctx context.Context + reqs *payload.Insert_MultiRequest } type fields struct { eg errgroup.Group @@ -4361,7 +4522,7 @@ func Test_server_upsert(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantLoc *payload.Object_Location + wantRes *payload.Object_Locations err error } type test struct { @@ -4369,16 +4530,16 @@ func Test_server_upsert(t *testing.T) { args args fields fields want want - checkFunc func(want, *payload.Object_Location, error) error + checkFunc func(want, *payload.Object_Locations, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotLoc, w.wantLoc) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) } return nil } @@ -4389,9 +4550,7 @@ func Test_server_upsert(t *testing.T) { name: "test_case_1", args: args { ctx:nil, - client:nil, - req:nil, - opts:nil, + reqs:nil, }, fields: fields { eg:nil, @@ -4421,9 +4580,7 @@ func Test_server_upsert(t *testing.T) { name: "test_case_2", args: args { ctx:nil, - client:nil, - req:nil, - opts:nil, + reqs:nil, }, fields: fields { eg:nil, @@ -4474,17 +4631,17 @@ func Test_server_upsert(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotLoc, err := s.upsert(test.args.ctx, test.args.client, test.args.req, test.args.opts...) - if err := checkFunc(test.want, gotLoc, err); err != nil { + gotRes, err := s.MultiInsert(test.args.ctx, test.args.reqs) + if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_StreamUpsert(t *testing.T) { +func Test_server_StreamUpdate(t *testing.T) { type args struct { - stream vald.Upsert_StreamUpsertServer + stream vald.Update_StreamUpdateServer } type fields struct { eg errgroup.Group @@ -4600,7 +4757,7 @@ func Test_server_StreamUpsert(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - err := s.StreamUpsert(test.args.stream) + err := s.StreamUpdate(test.args.stream) if err := checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } @@ -4608,10 +4765,10 @@ func Test_server_StreamUpsert(t *testing.T) { } } -func Test_server_MultiUpsert(t *testing.T) { +func Test_server_MultiUpdate(t *testing.T) { type args struct { ctx context.Context - reqs *payload.Upsert_MultiRequest + reqs *payload.Update_MultiRequest } type fields struct { eg errgroup.Group @@ -4733,7 +4890,7 @@ func Test_server_MultiUpsert(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.MultiUpsert(test.args.ctx, test.args.reqs) + gotRes, err := s.MultiUpdate(test.args.ctx, test.args.reqs) if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } @@ -4741,12 +4898,9 @@ func Test_server_MultiUpsert(t *testing.T) { } } -func Test_server_remove(t *testing.T) { +func Test_server_StreamUpsert(t *testing.T) { type args struct { - ctx context.Context - client vald.RemoveClient - req *payload.Remove_Request - opts []grpc.CallOption + stream vald.Upsert_StreamUpsertServer } type fields struct { eg errgroup.Group @@ -4759,25 +4913,21 @@ func Test_server_remove(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - want *payload.Object_Location - err error + err error } type test struct { name string args args fields fields want want - checkFunc func(want, *payload.Object_Location, error) error + checkFunc func(want, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, got *payload.Object_Location, err error) error { + defaultCheckFunc := func(w want, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) - } return nil } tests := []test{ @@ -4786,10 +4936,7 @@ func Test_server_remove(t *testing.T) { { name: "test_case_1", args: args { - ctx:nil, - client:nil, - req:nil, - opts:nil, + stream:nil, }, fields: fields { eg:nil, @@ -4818,10 +4965,7 @@ func Test_server_remove(t *testing.T) { return test { name: "test_case_2", args: args { - ctx:nil, - client:nil, - req:nil, - opts:nil, + stream:nil, }, fields: fields { eg:nil, @@ -4872,17 +5016,18 @@ func Test_server_remove(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - got, err := s.remove(test.args.ctx, test.args.client, test.args.req, test.args.opts...) - if err := checkFunc(test.want, got, err); err != nil { + err := s.StreamUpsert(test.args.stream) + if err := checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_StreamRemove(t *testing.T) { +func Test_server_MultiUpsert(t *testing.T) { type args struct { - stream vald.Remove_StreamRemoveServer + ctx context.Context + reqs *payload.Upsert_MultiRequest } type fields struct { eg errgroup.Group @@ -4895,21 +5040,25 @@ func Test_server_StreamRemove(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - err error + wantRes *payload.Object_Locations + err error } type test struct { name string args args fields fields want want - checkFunc func(want, error) error + checkFunc func(want, *payload.Object_Locations, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } return nil } tests := []test{ @@ -4918,7 +5067,8 @@ func Test_server_StreamRemove(t *testing.T) { { name: "test_case_1", args: args { - stream:nil, + ctx:nil, + reqs:nil, }, fields: fields { eg:nil, @@ -4947,7 +5097,8 @@ func Test_server_StreamRemove(t *testing.T) { return test { name: "test_case_2", args: args { - stream:nil, + ctx:nil, + reqs:nil, }, fields: fields { eg:nil, @@ -4998,18 +5149,17 @@ func Test_server_StreamRemove(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - err := s.StreamRemove(test.args.stream) - if err := checkFunc(test.want, err); err != nil { + gotRes, err := s.MultiUpsert(test.args.ctx, test.args.reqs) + if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_MultiRemove(t *testing.T) { +func Test_server_StreamRemove(t *testing.T) { type args struct { - ctx context.Context - reqs *payload.Remove_MultiRequest + stream vald.Remove_StreamRemoveServer } type fields struct { eg errgroup.Group @@ -5022,25 +5172,21 @@ func Test_server_MultiRemove(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantRes *payload.Object_Locations - err error + err error } type test struct { name string args args fields fields want want - checkFunc func(want, *payload.Object_Locations, error) error + checkFunc func(want, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { + defaultCheckFunc := func(w want, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotRes, w.wantRes) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) - } return nil } tests := []test{ @@ -5049,8 +5195,7 @@ func Test_server_MultiRemove(t *testing.T) { { name: "test_case_1", args: args { - ctx:nil, - reqs:nil, + stream:nil, }, fields: fields { eg:nil, @@ -5079,8 +5224,7 @@ func Test_server_MultiRemove(t *testing.T) { return test { name: "test_case_2", args: args { - ctx:nil, - reqs:nil, + stream:nil, }, fields: fields { eg:nil, @@ -5131,18 +5275,18 @@ func Test_server_MultiRemove(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotRes, err := s.MultiRemove(test.args.ctx, test.args.reqs) - if err := checkFunc(test.want, gotRes, err); err != nil { + err := s.StreamRemove(test.args.stream) + if err := checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_GetObject(t *testing.T) { +func Test_server_MultiRemove(t *testing.T) { type args struct { - ctx context.Context - req *payload.Object_VectorRequest + ctx context.Context + reqs *payload.Remove_MultiRequest } type fields struct { eg errgroup.Group @@ -5155,7 +5299,7 @@ func Test_server_GetObject(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantVec *payload.Object_Vector + wantRes *payload.Object_Locations err error } type test struct { @@ -5163,16 +5307,16 @@ func Test_server_GetObject(t *testing.T) { args args fields fields want want - checkFunc func(want, *payload.Object_Vector, error) error + checkFunc func(want, *payload.Object_Locations, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotVec *payload.Object_Vector, err error) error { + defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotVec, w.wantVec) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotVec, w.wantVec) + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) } return nil } @@ -5183,7 +5327,7 @@ func Test_server_GetObject(t *testing.T) { name: "test_case_1", args: args { ctx:nil, - req:nil, + reqs:nil, }, fields: fields { eg:nil, @@ -5213,7 +5357,7 @@ func Test_server_GetObject(t *testing.T) { name: "test_case_2", args: args { ctx:nil, - req:nil, + reqs:nil, }, fields: fields { eg:nil, @@ -5264,15 +5408,15 @@ func Test_server_GetObject(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotVec, err := s.GetObject(test.args.ctx, test.args.req) - if err := checkFunc(test.want, gotVec, err); err != nil { + gotRes, err := s.MultiRemove(test.args.ctx, test.args.reqs) + if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } }) } } -func Test_server_getObjects(t *testing.T) { +func Test_server_GetObject(t *testing.T) { type args struct { ctx context.Context req *payload.Object_VectorRequest @@ -5288,24 +5432,24 @@ func Test_server_getObjects(t *testing.T) { UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror } type want struct { - wantVecs *sync.Map[string, *payload.Object_Vector] - err error + wantVec *payload.Object_Vector + err error } type test struct { name string args args fields fields want want - checkFunc func(want, *sync.Map[string, *payload.Object_Vector], error) error + checkFunc func(want, *payload.Object_Vector, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, args) } - defaultCheckFunc := func(w want, gotVecs *sync.Map[string, *payload.Object_Vector], err error) error { + defaultCheckFunc := func(w want, gotVec *payload.Object_Vector, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } - if !reflect.DeepEqual(gotVecs, w.wantVecs) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotVecs, w.wantVecs) + if !reflect.DeepEqual(gotVec, w.wantVec) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotVec, w.wantVec) } return nil } @@ -5397,8 +5541,8 @@ func Test_server_getObjects(t *testing.T) { UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, } - gotVecs, err := s.getObjects(test.args.ctx, test.args.req) - if err := checkFunc(test.want, gotVecs, err); err != nil { + gotVec, err := s.GetObject(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotVec, err); err != nil { tt.Errorf("error = %v", err) } }) diff --git a/pkg/gateway/mirror/handler/grpc/mock_test.go b/pkg/gateway/mirror/handler/grpc/mock_test.go index 4a5d93b336..0bfb3548c1 100644 --- a/pkg/gateway/mirror/handler/grpc/mock_test.go +++ b/pkg/gateway/mirror/handler/grpc/mock_test.go @@ -151,7 +151,7 @@ func (m *mockClient) Remove(ctx context.Context, in *payload.Remove_Request, opt } func (m *mockClient) RemoveByTimestamp(ctx context.Context, in *payload.Remove_TimestampRequest, opts ...grpc.CallOption) (*payload.Object_Locations, error) { - return m.RemoveByTimestamp(ctx, in, opts...) + return m.RemoveByTimestampFunc(ctx, in, opts...) } func (m *mockClient) StreamRemove(ctx context.Context, opts ...grpc.CallOption) (vald.Remove_StreamRemoveClient, error) {