diff --git a/.github/workflows/dockers-gateway-mirror-image.yaml b/.github/workflows/dockers-gateway-mirror-image.yaml new file mode 100644 index 0000000000..5c3fe5592e --- /dev/null +++ b/.github/workflows/dockers-gateway-mirror-image.yaml @@ -0,0 +1,156 @@ +# +# Copyright (C) 2019-2022 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +name: "Build docker image: gateway-mirror" +on: + push: + branches: + - main + tags: + - "*.*.*" + - "v*.*.*" + - "*.*.*-*" + - "v*.*.*-*" + paths: + - ".github/actions/docker-build/actions.yaml" + - ".github/workflows/dockers-gateway-mirror-image.yml" + - "go.mod" + - "go.sum" + - "internal/**" + - "!internal/**/*_test.go" + - "!internal/**/*_mock.go" + - "!internal/db/**" + - "!internal/k8s/**" + - "apis/grpc/**" + - "pkg/gateway/mirror/**" + - "cmd/gateway/mirror/**" + - "pkg/gateway/internal/**" + - "dockers/gateway/mirror/Dockerfile" + - "versions/GO_VERSION" + pull_request: + paths: + - ".github/actions/docker-build/actions.yaml" + - ".github/workflows/dockers-gateway-mirror-image.yml" + - "go.mod" + - "go.sum" + - "internal/**" + - "!internal/**/*_test.go" + - "!internal/**/*_mock.go" + - "!internal/db/**" + - "!internal/k8s/**" + - "apis/grpc/**" + - "pkg/gateway/mirror/**" + - "cmd/gateway/mirror/**" + - "pkg/gateway/internal/**" + - "dockers/gateway/mirror/Dockerfile" + - "versions/GO_VERSION" + pull_request_target: + paths: + - ".github/actions/docker-build/actions.yaml" + - ".github/workflows/dockers-gateway-mirror-image.yml" + - "go.mod" + - "go.sum" + - "internal/**" + - "!internal/**/*_test.go" + - "!internal/**/*_mock.go" + - "!internal/db/**" + - "!internal/k8s/**" + - "apis/grpc/**" + - "pkg/gateway/mirror/**" + - "cmd/gateway/nirror/**" + - "pkg/gateway/internal/**" + - "dockers/gateway/mirror/Dockerfile" + - "versions/GO_VERSION" + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ github.event_name }} + cancel-in-progress: true + +jobs: + build: + strategy: + max-parallel: 4 + runs-on: ubuntu-latest + if: ${{ (github.event_name == 'pull_request' && github.event.pull_request.head.repo.fork == false) || (github.event.pull_request.head.repo.fork == true && github.event_name == 'pull_request_target' && contains(github.event.pull_request.labels.*.name, 'ci/approved')) }} + steps: + - uses: actions/checkout@v3 + - name: set git config + run: | + git config --global --add safe.directory ${GITHUB_WORKSPACE} + - name: Setup QEMU + uses: docker/setup-qemu-action@v2 + with: + platforms: all + - name: Setup Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v2 + with: + buildkitd-flags: "--debug" + - name: Login to DockerHub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USER }} + password: ${{ secrets.DOCKERHUB_PASS }} + - name: Login to GitHub Container Registry + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ secrets.PACKAGE_USER }} + password: ${{ secrets.PACKAGE_TOKEN }} + - name: Build and Publish + id: build_and_publish + uses: ./.github/actions/docker-build + with: + target: gateway-mirror + builder: ${{ steps.buildx.outputs.name }} + - name: Initialize CodeQL + if: startsWith( github.ref, 'refs/tags/') + uses: github/codeql-action/init@v2 + - name: Run vulnerability scanner (table) + if: startsWith( github.ref, 'refs/tags/') + uses: aquasecurity/trivy-action@master + with: + image-ref: "${{ steps.build_and_publish.outputs.IMAGE_NAME }}:${{ steps.build_and_publish.outputs.PRIMARY_TAG }}" + format: "table" + - name: Run vulnerability scanner (sarif) + if: startsWith( github.ref, 'refs/tags/') + uses: aquasecurity/trivy-action@master + with: + image-ref: "${{ steps.build_and_publish.outputs.IMAGE_NAME }}:${{ steps.build_and_publish.outputs.PRIMARY_TAG }}" + format: "template" + template: "@/contrib/sarif.tpl" + output: "trivy-results.sarif" + - name: Upload Trivy scan results to Security tab + if: startsWith( github.ref, 'refs/tags/') + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: "trivy-results.sarif" + slack: + name: Slack notification + needs: build + runs-on: ubuntu-latest + if: github.ref == 'refs/heads/main' || startsWith( github.ref, 'refs/tags/') + steps: + - uses: technote-space/workflow-conclusion-action@v2 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - uses: 8398a7/action-slack@v3 + with: + author_name: vald-mirror-gateway image build + status: ${{ env.WORKFLOW_CONCLUSION }} + only_mention_fail: channel + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_NOTIFY_WEBHOOK_URL }} diff --git a/Makefile b/Makefile index 9d326895fa..cf6fbc956d 100644 --- a/Makefile +++ b/Makefile @@ -28,6 +28,7 @@ CI_CONTAINER_IMAGE = $(NAME)-ci-container DEV_CONTAINER_IMAGE = $(NAME)-dev-container DISCOVERER_IMAGE = $(NAME)-discoverer-k8s FILTER_GATEWAY_IMAGE = $(NAME)-filter-gateway +MIRROR_GATEWAY_IMAGE = $(NAME)-mirror-gateway HELM_OPERATOR_IMAGE = $(NAME)-helm-operator LB_GATEWAY_IMAGE = $(NAME)-lb-gateway LOADTEST_IMAGE = $(NAME)-loadtest diff --git a/Makefile.d/build.mk b/Makefile.d/build.mk index 7c46a488a7..2f6ffe7424 100644 --- a/Makefile.d/build.mk +++ b/Makefile.d/build.mk @@ -177,6 +177,34 @@ cmd/gateway/filter/filter: \ $(dir $@)main.go $@ -version +cmd/gateway/mirror/mirror: \ + $(GO_SOURCES_INTERNAL) \ + $(PBGOS) \ + $(shell find ./cmd/gateway/mirror -type f -name '*.go' -not -name '*_test.go' -not -name 'doc.go') \ + $(shell find ./pkg/gateway/mirror -type f -name '*.go' -not -name '*_test.go' -not -name 'doc.go') + CGO_ENABLED=0 \ + GO111MODULE=on \ + GOPRIVATE=$(GOPRIVATE) \ + go build \ + --ldflags "-w -extldflags=-static \ + -X '$(GOPKG)/internal/info.Version=$(VERSION)' \ + -X '$(GOPKG)/internal/info.GitCommit=$(GIT_COMMIT)' \ + -X '$(GOPKG)/internal/info.BuildTime=$(DATETIME)' \ + -X '$(GOPKG)/internal/info.GoVersion=$(GO_VERSION)' \ + -X '$(GOPKG)/internal/info.GoOS=$(GOOS)' \ + -X '$(GOPKG)/internal/info.GoArch=$(GOARCH)' \ + -X '$(GOPKG)/internal/info.CGOEnabled=$${CGO_ENABLED}' \ + -X '$(GOPKG)/internal/info.BuildCPUInfoFlags=$(CPU_INFO_FLAGS)' \ + -buildid=" \ + -mod=readonly \ + -modcacherw \ + -a \ + -tags "osusergo netgo static_build" \ + -trimpath \ + -o $@ \ + $(dir $@)main.go + $@ -version + cmd/manager/index/index: \ $(GO_SOURCES_INTERNAL) \ $(PBGOS) \ diff --git a/Makefile.d/docker.mk b/Makefile.d/docker.mk index a8858a27a6..763ede9be1 100644 --- a/Makefile.d/docker.mk +++ b/Makefile.d/docker.mk @@ -114,6 +114,21 @@ docker/build/gateway-filter: --build-arg DISTROLESS_IMAGE=$(DISTROLESS_IMAGE) \ --build-arg DISTROLESS_IMAGE_TAG=$(DISTROLESS_IMAGE_TAG) +.PHONY: docker/name/gateway-mirror +docker/name/gateway-mirror: + @echo "$(ORG)/$(MIRROR_GATEWAY_IMAGE)" + +.PHONY: docker/build/gateway-mirror +## build gateway-mirror image +docker/build/gateway-mirror: + $(DOCKER) build \ + $(DOCKER_OPTS) \ + -f dockers/gateway/mirror/Dockerfile \ + -t $(ORG)/$(MIRROR_GATEWAY_IMAGE):$(TAG) . \ + --build-arg GO_VERSION=$(GO_VERSION) \ + --build-arg DISTROLESS_IMAGE=$(DISTROLESS_IMAGE) \ + --build-arg DISTROLESS_IMAGE_TAG=$(DISTROLESS_IMAGE_TAG) + .PHONY: docker/name/manager-index docker/name/manager-index: @echo "$(ORG)/$(MANAGER_INDEX_IMAGE)" diff --git a/Makefile.d/k8s.mk b/Makefile.d/k8s.mk index 3f5f85ed58..4f0c9edd6b 100644 --- a/Makefile.d/k8s.mk +++ b/Makefile.d/k8s.mk @@ -13,6 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +MIRROR01_NAMESPACE = vald-01 +MIRROR02_NAMESPACE = vald-02 +MIRROR03_NAMESPACE = vald-03 + .PHONY: k8s/manifest/clean ## clean k8s manifests k8s/manifest/clean: @@ -100,6 +105,33 @@ k8s/vald/delete: kubectl delete -f $(TEMP_DIR)/vald/templates/agent rm -rf $(TEMP_DIR) +.PHONY: k8s/multi/vald/deploy +## deploy multiple vald sample clusters to k8s +k8s/multi/vald/deploy: + -@kubectl create ns $(MIRROR01_NAMESPACE) + -@kubectl create ns $(MIRROR02_NAMESPACE) + -@kubectl create ns $(MIRROR03_NAMESPACE) + helm install vald-cluster-01 charts/vald \ + -f ./charts/vald/values/multi-vald/dev-vald-with-mirror.yaml \ + -f ./charts/vald/values/multi-vald/dev-vald-01.yaml \ + -n $(MIRROR01_NAMESPACE) + helm install vald-cluster-02 charts/vald \ + -f ./charts/vald/values/multi-vald/dev-vald-with-mirror.yaml \ + -f ./charts/vald/values/multi-vald/dev-vald-02.yaml \ + -n $(MIRROR02_NAMESPACE) + helm install vald-cluster-03 charts/vald \ + -f ./charts/vald/values/multi-vald/dev-vald-with-mirror.yaml \ + -f ./charts/vald/values/multi-vald/dev-vald-03.yaml \ + -n $(MIRROR03_NAMESPACE) + +.PHONY: k8s/multi/vald/delete +## delete multiple vald sample clusters to k8s +k8s/multi/vald/delete: + helm uninstall vald-cluster-01 -n vald-01 + helm uninstall vald-cluster-02 -n vald-02 + helm uninstall vald-cluster-03 -n vald-03 + -@kubectl delete ns vald-01 vald-02 vald-03 + .PHONY: k8s/vald-helm-operator/deploy ## deploy vald-helm-operator to k8s k8s/vald-helm-operator/deploy: diff --git a/apis/grpc/v1/vald/vald.go b/apis/grpc/v1/vald/vald.go index 01c12f8666..5819594535 100644 --- a/apis/grpc/v1/vald/vald.go +++ b/apis/grpc/v1/vald/vald.go @@ -18,6 +18,7 @@ package vald import ( + "github.com/vdaas/vald/apis/grpc/v1/mirror" grpc "google.golang.org/grpc" ) @@ -35,6 +36,11 @@ type ServerWithFilter interface { FilterServer } +type ServerWithMirror interface { + Server + mirror.MirrorServer +} + type UnimplementedValdServer struct { UnimplementedInsertServer UnimplementedUpdateServer @@ -49,6 +55,11 @@ type UnimplementedValdServerWithFilter struct { UnimplementedFilterServer } +type UnimplementedValdServerWithMirror struct { + UnimplementedValdServer + mirror.UnimplementedMirrorServer +} + type Client interface { InsertClient UpdateClient @@ -63,6 +74,11 @@ type ClientWithFilter interface { FilterClient } +type ClientWithMirror interface { + Client + mirror.MirrorClient +} + const PackageName = "vald.v1" const ( @@ -73,6 +89,7 @@ const ( RemoveRPCServiceName = "Remove" ObjectRPCServiceName = "Object" FilterRPCServiceName = "Filter" + MirrorRPCServiceName = "Mirror" ) const ( @@ -123,6 +140,9 @@ const ( ExistsRPCName = "Exists" GetObjectRPCName = "GetObject" StreamGetObjectRPCName = "StreamGetObject" + + RegisterRPCName = "Register" + AdvertiseRPCName = "Advertise" ) type client struct { @@ -134,6 +154,11 @@ type client struct { ObjectClient } +type clientWithMirror struct { + Client + mirror.MirrorClient +} + func RegisterValdServer(s *grpc.Server, srv Server) { RegisterInsertServer(s, srv) RegisterUpdateServer(s, srv) @@ -148,6 +173,11 @@ func RegisterValdServerWithFilter(s *grpc.Server, srv ServerWithFilter) { RegisterFilterServer(s, srv) } +func RegisterValdServerWithMirror(s *grpc.Server, srv ServerWithMirror) { + RegisterValdServer(s, srv) + mirror.RegisterMirrorServer(s, srv) +} + func NewValdClient(conn *grpc.ClientConn) Client { return &client{ NewInsertClient(conn), @@ -158,3 +188,10 @@ func NewValdClient(conn *grpc.ClientConn) Client { NewObjectClient(conn), } } + +func NewValdClientWithMirror(conn *grpc.ClientConn) ClientWithMirror { + return &clientWithMirror{ + Client: NewValdClient(conn), + MirrorClient: mirror.NewMirrorClient(conn), + } +} diff --git a/charts/vald/templates/gateway/mirror/configmap.yaml b/charts/vald/templates/gateway/mirror/configmap.yaml new file mode 100644 index 0000000000..24e8c87117 --- /dev/null +++ b/charts/vald/templates/gateway/mirror/configmap.yaml @@ -0,0 +1,80 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +{{- $gateway := .Values.gateway.mirror -}} +{{- $lb := .Values.gateway.lb -}} +{{- if $gateway.enabled }} +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ $gateway.name }}-config + labels: + app.kubernetes.io/name: {{ include "vald.name" . }} + helm.sh/chart: {{ include "vald.chart" . }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/version: {{ .Chart.Version }} + app.kubernetes.io/component: gateway-mirror +data: + config.yaml: | + --- + version: {{ $gateway.version }} + time_zone: {{ default .Values.defaults.time_zone $gateway.time_zone }} + logging: + {{- $logging := dict "Values" $gateway.logging "default" .Values.defaults.logging }} + {{- include "vald.logging" $logging | nindent 6 }} + server_config: + {{- $servers := dict "Values" $gateway.server_config "default" .Values.defaults.server_config }} + {{- include "vald.servers" $servers | nindent 6 }} + observability: + {{- $observability := dict "Values" $gateway.observability "default" .Values.defaults.observability }} + {{- include "vald.observability" $observability | nindent 6 }} + gateway: + pod_name: {{ $gateway.gateway_config.pod_name }} + advertise_interval: {{ $gateway.gateway_config.advertise_interval }} + client: + {{- $client := $gateway.gateway_config.client }} + {{- $addrs := default list $client.addrs }} + {{- if $lb.enabled -}} + {{- $defaultHost := printf "%s.%s.svc.cluster.local" $lb.name .Release.Namespace }} + {{- $defaultPort := default .Values.defaults.server_config.servers.grpc.port $lb.server_config.servers.grpc.port }} + {{- $defaultAddr := (list (printf "%s:%d" $defaultHost (int64 $defaultPort))) }} + {{- $addrs = (concat $addrs $defaultAddr) }} + {{- end -}} + {{- if $addrs }} + addrs: + {{- toYaml $addrs | nindent 10 }} + {{- else }} + addrs: [] + {{- end -}} + {{- $GRPCClient := dict "Values" $client "default" .Values.defaults.grpc.client }} + {{- include "vald.grpc.client" $GRPCClient | nindent 8 }} + self_mirror_addr: + {{- if $gateway.ingress.enabled -}} + {{- $defaultHost := $gateway.ingress.host }} + {{- $defaultPort := default .Values.defaults.server_config.servers.grpc.port $gateway.server_config.servers.grpc.port }} + {{- printf "%s:%d" $defaultHost (int64 $defaultPort) | indent 1 }} + {{- else -}} + {{- $defaultHost := printf "%s.%s.svc.cluster.local" $gateway.name .Release.Namespace }} + {{- $defaultPort := default .Values.defaults.server_config.servers.grpc.port $gateway.server_config.servers.grpc.port }} + {{- printf "%s:%d" $defaultHost (int64 $defaultPort) | indent 1 }} + {{- end }} + gateway_addr: + {{- if $lb.enabled -}} + {{- $defaultHost := printf "%s.%s.svc.cluster.local" $lb.name .Release.Namespace }} + {{- $defaultPort := default .Values.defaults.server_config.servers.grpc.port $lb.server_config.servers.grpc.port }} + {{- printf "%s:%d" $defaultHost (int64 $defaultPort) | indent 1 }} + {{- end }} +{{- end }} diff --git a/charts/vald/templates/gateway/mirror/daemonset.yaml b/charts/vald/templates/gateway/mirror/daemonset.yaml new file mode 100644 index 0000000000..bd0e539455 --- /dev/null +++ b/charts/vald/templates/gateway/mirror/daemonset.yaml @@ -0,0 +1,134 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +{{- $gateway := .Values.gateway.mirror -}} +{{- if and $gateway.enabled (eq $gateway.kind "DaemonSet") }} +apiVersion: apps/v1 +kind: DaemonSet +metadata: + name: {{ $gateway.name }} + labels: + app: {{ $gateway.name }} + app.kubernetes.io/name: {{ include "vald.name" . }} + helm.sh/chart: {{ include "vald.chart" . }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/version: {{ .Chart.Version }} + app.kubernetes.io/component: gateway-mirror + {{- if $gateway.annotations }} + annotations: + {{- toYaml $gateway.annotations | nindent 4 }} + {{- end }} +spec: + revisionHistoryLimit: {{ $gateway.revisionHistoryLimit }} + selector: + matchLabels: + app: {{ $gateway.name }} + updateStrategy: + type: RollingUpdate + rollingUpdate: + maxSurge: {{ $gateway.rollingUpdate.maxSurge }} + maxUnavailable: {{ $gateway.rollingUpdate.maxUnavailable }} + template: + metadata: + creationTimestamp: null + labels: + app: {{ $gateway.name }} + app.kubernetes.io/name: {{ include "vald.name" . }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/component: gateway-mirror + annotations: + checksum/configmap: {{ include (print $.Template.BasePath "/gateway/mirror/configmap.yaml") . | sha256sum }} + {{- if $gateway.podAnnotations }} + {{- toYaml $gateway.podAnnotations | nindent 8 }} + {{- end }} + {{- $pprof := default .Values.defaults.server_config.metrics.pprof $gateway.server_config.metrics.pprof }} + {{- if $pprof.enabled }} + pyroscope.io/scrape: "true" + pyroscope.io/application-name: {{ $gateway.name }} + pyroscope.io/profile-cpu-enabled: "true" + pyroscope.io/profile-mem-enabled: "true" + pyroscope.io/port: "{{ $pprof.port }}" + {{- end }} + spec: + {{- if $gateway.initContainers }} + initContainers: + {{- $initContainers := dict "initContainers" $gateway.initContainers "Values" .Values "namespace" .Release.Namespace -}} + {{- include "vald.initContainers" $initContainers | trim | nindent 8 }} + {{- end }} + affinity: + {{- include "vald.affinity" $gateway.affinity | nindent 8 }} + {{- if $gateway.topologySpreadConstraints }} + topologySpreadConstraints: + {{- toYaml $gateway.topologySpreadConstraints | nindent 8 }} + {{- end }} + containers: + - name: {{ $gateway.name }} + image: "{{ $gateway.image.repository }}:{{ default .Values.defaults.image.tag $gateway.image.tag }}" + imagePullPolicy: {{ $gateway.image.pullPolicy }} + {{- $servers := dict "Values" $gateway.server_config "default" .Values.defaults.server_config -}} + {{- include "vald.containerPorts" $servers | trim | nindent 10 }} + resources: + {{- toYaml $gateway.resources | nindent 12 }} + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: File + {{- if $gateway.securityContext }} + securityContext: + {{- toYaml $gateway.securityContext | nindent 12 }} + {{- end }} + {{- if $gateway.env }} + env: + {{- toYaml $gateway.env | nindent 12 }} + {{- end }} + volumeMounts: + - name: {{ $gateway.name }}-config + mountPath: /etc/server/ + {{- if $gateway.volumeMounts }} + {{- toYaml $gateway.volumeMounts | nindent 12 }} + {{- end }} + dnsPolicy: ClusterFirst + restartPolicy: Always + schedulerName: default-scheduler + {{- if $gateway.podSecurityContext }} + securityContext: + {{- toYaml $gateway.podSecurityContext | nindent 8 }} + {{- end }} + terminationGracePeriodSeconds: {{ $gateway.terminationGracePeriodSeconds }} + volumes: + - name: {{ $gateway.name }}-config + configMap: + defaultMode: 420 + name: {{ $gateway.name }}-config + {{- if $gateway.volumes }} + {{- toYaml $gateway.volumes | nindent 8 }} + {{- end }} + {{- if $gateway.nodeName }} + nodeName: {{ $gateway.nodeName }} + {{- end }} + {{- if $gateway.nodeSelector }} + nodeSelector: + {{- toYaml $gateway.nodeSelector | nindent 8 }} + {{- end }} + {{- if $gateway.tolerations }} + tolerations: + {{- toYaml $gateway.tolerations | nindent 8 }} + {{- end }} + {{- if $gateway.podPriority }} + {{- if $gateway.podPriority.enabled }} + priorityClassName: {{ .Release.Namespace }}-{{ $gateway.name }}-priority + {{- end }} + {{- end }} +status: +{{- end }} diff --git a/charts/vald/templates/gateway/mirror/deployment.yaml b/charts/vald/templates/gateway/mirror/deployment.yaml new file mode 100644 index 0000000000..e11df4d3f0 --- /dev/null +++ b/charts/vald/templates/gateway/mirror/deployment.yaml @@ -0,0 +1,138 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +{{- $gateway := .Values.gateway.mirror -}} +{{- if and $gateway.enabled (eq $gateway.kind "Deployment") }} +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ $gateway.name }} + labels: + app: {{ $gateway.name }} + app.kubernetes.io/name: {{ include "vald.name" . }} + helm.sh/chart: {{ include "vald.chart" . }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/version: {{ .Chart.Version }} + app.kubernetes.io/component: gateway-mirror + {{- if $gateway.annotations }} + annotations: + {{- toYaml $gateway.annotations | nindent 4 }} + {{- end }} +spec: + progressDeadlineSeconds: {{ $gateway.progressDeadlineSeconds }} + {{- if not $gateway.hpa.enabled }} + replicas: {{ $gateway.minReplicas }} + {{- end }} + revisionHistoryLimit: {{ $gateway.revisionHistoryLimit }} + selector: + matchLabels: + app: {{ $gateway.name }} + strategy: + rollingUpdate: + maxSurge: {{ $gateway.rollingUpdate.maxSurge }} + maxUnavailable: {{ $gateway.rollingUpdate.maxUnavailable }} + type: RollingUpdate + template: + metadata: + creationTimestamp: null + labels: + app: {{ $gateway.name }} + app.kubernetes.io/name: {{ include "vald.name" . }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/component: gateway-mirror + annotations: + checksum/configmap: {{ include (print $.Template.BasePath "/gateway/mirror/configmap.yaml") . | sha256sum }} + {{- if $gateway.podAnnotations }} + {{- toYaml $gateway.podAnnotations | nindent 8 }} + {{- end }} + {{- $pprof := default .Values.defaults.server_config.metrics.pprof $gateway.server_config.metrics.pprof }} + {{- if $pprof.enabled }} + pyroscope.io/scrape: "true" + pyroscope.io/application-name: {{ $gateway.name }} + pyroscope.io/profile-cpu-enabled: "true" + pyroscope.io/profile-mem-enabled: "true" + pyroscope.io/port: "{{ $pprof.port }}" + {{- end }} + spec: + {{- if $gateway.initContainers }} + initContainers: + {{- $initContainers := dict "initContainers" $gateway.initContainers "Values" .Values "namespace" .Release.Namespace -}} + {{- include "vald.initContainers" $initContainers | trim | nindent 8 }} + {{- end }} + affinity: + {{- include "vald.affinity" $gateway.affinity | nindent 8 }} + {{- if $gateway.topologySpreadConstraints }} + topologySpreadConstraints: + {{- toYaml $gateway.topologySpreadConstraints | nindent 8 }} + {{- end }} + containers: + - name: {{ $gateway.name }} + image: "{{ $gateway.image.repository }}:{{ default .Values.defaults.image.tag $gateway.image.tag }}" + imagePullPolicy: {{ $gateway.image.pullPolicy }} + {{- $servers := dict "Values" $gateway.server_config "default" .Values.defaults.server_config -}} + {{- include "vald.containerPorts" $servers | trim | nindent 10 }} + resources: + {{- toYaml $gateway.resources | nindent 12 }} + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: File + {{- if $gateway.securityContext }} + securityContext: + {{- toYaml $gateway.securityContext | nindent 12 }} + {{- end }} + {{- if $gateway.env }} + env: + {{- toYaml $gateway.env | nindent 12 }} + {{- end }} + volumeMounts: + - name: {{ $gateway.name }}-config + mountPath: /etc/server/ + {{- if $gateway.volumeMounts }} + {{- toYaml $gateway.volumeMounts | nindent 12 }} + {{- end }} + dnsPolicy: ClusterFirst + restartPolicy: Always + schedulerName: default-scheduler + {{- if $gateway.podSecurityContext }} + securityContext: + {{- toYaml $gateway.podSecurityContext | nindent 8 }} + {{- end }} + terminationGracePeriodSeconds: {{ $gateway.terminationGracePeriodSeconds }} + volumes: + - name: {{ $gateway.name }}-config + configMap: + defaultMode: 420 + name: {{ $gateway.name }}-config + {{- if $gateway.volumes }} + {{- toYaml $gateway.volumes | nindent 8 }} + {{- end }} + {{- if $gateway.nodeName }} + nodeName: {{ $gateway.nodeName }} + {{- end }} + {{- if $gateway.nodeSelector }} + nodeSelector: + {{- toYaml $gateway.nodeSelector | nindent 8 }} + {{- end }} + {{- if $gateway.tolerations }} + tolerations: + {{- toYaml $gateway.tolerations | nindent 8 }} + {{- end }} + {{- if $gateway.podPriority }} + {{- if $gateway.podPriority.enabled }} + priorityClassName: {{ .Release.Namespace }}-{{ $gateway.name }}-priority + {{- end }} + {{- end }} +status: +{{- end }} diff --git a/charts/vald/templates/gateway/mirror/hpa.yaml b/charts/vald/templates/gateway/mirror/hpa.yaml new file mode 100644 index 0000000000..c4227170fa --- /dev/null +++ b/charts/vald/templates/gateway/mirror/hpa.yaml @@ -0,0 +1,38 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +{{- $gateway := .Values.gateway.mirror -}} +{{- if and $gateway.enabled $gateway.hpa.enabled }} +apiVersion: autoscaling/v1 +kind: HorizontalPodAutoscaler +metadata: + name: {{ $gateway.name }} + labels: + app.kubernetes.io/name: {{ include "vald.name" . }} + helm.sh/chart: {{ include "vald.chart" . }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/version: {{ .Chart.Version }} + app.kubernetes.io/component: gateway-mirror +spec: + maxReplicas: {{ $gateway.maxReplicas }} + minReplicas: {{ $gateway.minReplicas }} + scaleTargetRef: + apiVersion: apps/v1 + kind: {{ $gateway.kind }} + name: {{ $gateway.name }} + targetCPUUtilizationPercentage: {{ $gateway.hpa.targetCPUUtilizationPercentage }} +status: +{{- end }} diff --git a/charts/vald/templates/gateway/mirror/ing.yaml b/charts/vald/templates/gateway/mirror/ing.yaml new file mode 100644 index 0000000000..05022fd515 --- /dev/null +++ b/charts/vald/templates/gateway/mirror/ing.yaml @@ -0,0 +1,47 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +{{- $gateway := .Values.gateway.mirror -}} +{{- if and $gateway.enabled $gateway.ingress.enabled }} +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + annotations: + {{- toYaml $gateway.ingress.annotations | nindent 4 }} + labels: + name: {{ $gateway.name }}-ingress + app: {{ $gateway.name }}-ingress + app.kubernetes.io/name: {{ include "vald.name" . }} + helm.sh/chart: {{ include "vald.chart" . }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/version: {{ .Chart.Version }} + app.kubernetes.io/component: gateway-mirror + name: {{ $gateway.name }}-ingress +spec: + defaultBackend: + service: + name: {{ $gateway.name }} + {{- include "vald.ingressPort" (dict "Values" $gateway.ingress) | nindent 6 }} + rules: + - host: {{ $gateway.ingress.host }} + http: + paths: + - backend: + service: + name: {{ $gateway.name }} + {{- include "vald.ingressPort" (dict "Values" $gateway.ingress) | nindent 12 }} + pathType: {{ $gateway.ingress.pathType }} +{{- end }} diff --git a/charts/vald/templates/gateway/mirror/pdb.yaml b/charts/vald/templates/gateway/mirror/pdb.yaml new file mode 100644 index 0000000000..d516368ae5 --- /dev/null +++ b/charts/vald/templates/gateway/mirror/pdb.yaml @@ -0,0 +1,34 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +{{- $gateway := .Values.gateway.mirror -}} +{{- if $gateway.enabled }} +apiVersion: policy/v1 +kind: PodDisruptionBudget +metadata: + name: {{ $gateway.name }} + labels: + app.kubernetes.io/name: {{ include "vald.name" . }} + helm.sh/chart: {{ include "vald.chart" . }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/version: {{ .Chart.Version }} + app.kubernetes.io/component: gateway-mirror +spec: + maxUnavailable: {{ $gateway.maxUnavailable }} + selector: + matchLabels: + app: {{ $gateway.name }} +{{- end }} diff --git a/charts/vald/templates/gateway/mirror/priorityclass.yaml b/charts/vald/templates/gateway/mirror/priorityclass.yaml new file mode 100644 index 0000000000..cb9b3bfa7e --- /dev/null +++ b/charts/vald/templates/gateway/mirror/priorityclass.yaml @@ -0,0 +1,32 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +{{- $gateway := .Values.gateway.mirror -}} +{{- if and $gateway.enabled $gateway.podPriority.enabled }} +apiVersion: scheduling.k8s.io/v1 +kind: PriorityClass +metadata: + name: {{ .Release.Namespace }}-{{ $gateway.name }}-priority + labels: + app.kubernetes.io/name: {{ include "vald.name" . }} + helm.sh/chart: {{ include "vald.chart" . }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/version: {{ .Chart.Version }} + app.kubernetes.io/component: gateway-mirror +value: {{ $gateway.podPriority.value }} +globalDefault: false +description: "A priority class for Vald mirror gateway." +{{- end }} diff --git a/charts/vald/templates/gateway/mirror/svc.yaml b/charts/vald/templates/gateway/mirror/svc.yaml new file mode 100644 index 0000000000..68191b2c62 --- /dev/null +++ b/charts/vald/templates/gateway/mirror/svc.yaml @@ -0,0 +1,52 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +{{- $gateway := .Values.gateway.mirror -}} +{{- if $gateway.enabled }} +apiVersion: v1 +kind: Service +metadata: + name: {{ $gateway.name }} + {{- if $gateway.service.annotations }} + annotations: + {{- toYaml $gateway.service.annotations | nindent 4 }} + {{- end }} + labels: + app.kubernetes.io/name: {{ include "vald.name" . }} + helm.sh/chart: {{ include "vald.chart" . }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/version: {{ .Chart.Version }} + app.kubernetes.io/component: gateway-mirror + {{- if $gateway.service.labels }} + {{- toYaml $gateway.service.labels | nindent 4 }} + {{- end }} +spec: + {{- $servers := dict "Values" $gateway.server_config "default" .Values.defaults.server_config }} + {{- include "vald.servicePorts" $servers | nindent 2 }} + selector: + app.kubernetes.io/name: {{ include "vald.name" . }} + app.kubernetes.io/component: gateway-mirror + {{- if eq $gateway.serviceType "ClusterIP" }} + clusterIP: None + {{- end }} + type: {{ $gateway.serviceType }} + {{- if $gateway.externalTrafficPolicy }} + externalTrafficPolicy: {{ $gateway.externalTrafficPolicy }} + {{- end }} + {{- if $gateway.internalTrafficPolicy }} + internalTrafficPolicy: {{ $gateway.internalTrafficPolicy }} + {{- end }} +{{- end }} diff --git a/charts/vald/values.yaml b/charts/vald/values.yaml index 50e9c03cce..5eda3e9227 100644 --- a/charts/vald/values.yaml +++ b/charts/vald/values.yaml @@ -1413,7 +1413,266 @@ gateway: # @schema {"name": "gateway.filter.gateway_config.egress_filter.distance_filters", "type": "array", "items": {"type": "string"}} # gateway.filter.gateway_config.egress_filter.distance_filters -- distance egress vector filter targets distance_filters: [] - + # @schema {"name": "gateway.mirror", "type": "object"} + mirror: + # @schema {"name": "gateway.mirror.enabled", "type": "boolean"} + # gateway.mirror.enabled -- gateway enabled + enabled: false + # @schema {"name": "gateway.mirror.version", "type": "string", "pattern": "^v[0-9]+\\.[0-9]+\\.[0-9]$", "anchor": "version"} + # gateway.mirror.version -- version of gateway config + version: v0.0.0 + # @schema {"name": "gateway.mirror.time_zone", "type": "string"} + # gateway.mirror.time_zone -- Time zone + time_zone: "" + # @schema {"name": "gateway.mirror.logging", "alias": "logging"} + # gateway.mirror.logging -- logging config (overrides defaults.logging) + logging: {} + # @schema {"name": "gateway.mirror.name", "type": "string"} + # gateway.mirror.name -- name of gateway deployment + name: vald-mirror-gateway + # @schema {"name": "gateway.mirror.kind", "type": "string", "enum": ["Deployment", "DaemonSet"]} + # gateway.mirror.kind -- deployment kind: Deployment or DaemonSet + kind: Deployment + # @schema {"name": "gateway.mirror.serviceType", "type": "string", "enum": ["ClusterIP", "LoadBalancer", "NodePort"]} + # gateway.mirror.serviceType -- service type: ClusterIP, LoadBalancer or NodePort + serviceType: ClusterIP + # @schema {"name": "gateway.mirror.externalTrafficPolicy", "type": "string"} + # gateway.mirror.externalTrafficPolicy -- external traffic policy (can be specified when service type is LoadBalancer or NodePort) : Cluster or Local + externalTrafficPolicy: "" + # @schema {"name": "gateway.mirror.internalTrafficPolicy", "type": "string"} + # gateway.mirror.internalTrafficPolicy -- internal traffic policy (can be specified when service type is LoadBalancer or NodePort) : Cluster or Local + internalTrafficPolicy: "" + # @schema {"name": "gateway.mirror.progressDeadlineSeconds", "type": "integer"} + # gateway.mirror.progressDeadlineSeconds -- progress deadline seconds + progressDeadlineSeconds: 600 + # @schema {"name": "gateway.mirror.minReplicas", "type": "integer", "minimum": 0} + # gateway.mirror.minReplicas -- minimum number of replicas. + # if HPA is disabled, the replicas will be set to this value + minReplicas: 3 + # @schema {"name": "gateway.mirror.maxReplicas", "type": "integer", "minimum": 0} + # gateway.mirror.maxReplicas -- maximum number of replicas. + # if HPA is disabled, this value will be ignored. + maxReplicas: 9 + # @schema {"name": "gateway.mirror.maxUnavailable", "type": "string"} + # gateway.mirror.maxUnavailable -- maximum number of unavailable replicas + maxUnavailable: 50% + # @schema {"name": "gateway.mirror.revisionHistoryLimit", "type": "integer", "minimum": 0} + # gateway.mirror.revisionHistoryLimit -- number of old history to retain to allow rollback + revisionHistoryLimit: 2 + # @schema {"name": "gateway.mirror.terminationGracePeriodSeconds", "type": "integer", "minimum": 0} + # gateway.mirror.terminationGracePeriodSeconds -- duration in seconds pod needs to terminate gracefully + terminationGracePeriodSeconds: 30 + # @schema {"name": "gateway.mirror.podSecurityContext", "type": "object"} + # gateway.mirror.podSecurityContext -- security context for pod + podSecurityContext: + runAsUser: 65532 + runAsNonRoot: true + runAsGroup: 65532 + fsGroup: 65532 + fsGroupChangePolicy: "OnRootMismatch" + # @schema {"name": "gateway.mirror.securityContext", "type": "object"} + # gateway.mirror.securityContext -- security context for container + securityContext: + runAsUser: 65532 + runAsNonRoot: true + runAsGroup: 65532 + privileged: false + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: + - ALL + # @schema {"name": "gateway.mirror.podPriority", "type": "object", "anchor": "podPriority"} + podPriority: + # @schema {"name": "gateway.mirror.podPriority.enabled", "type": "boolean"} + # gateway.mirror.podPriority.enabled -- gateway pod PriorityClass enabled + enabled: true + # @schema {"name": "gateway.mirror.podPriority.value", "type": "integer"} + # gateway.mirror.podPriority.value -- gateway pod PriorityClass value + value: 1000000 + # @schema {"name": "gateway.mirror.annotations", "type": "object"} + # gateway.mirror.annotations -- deployment annotations + annotations: {} + # @schema {"name": "gateway.mirror.podAnnotations", "type": "object"} + # gateway.mirror.podAnnotations -- pod annotations + podAnnotations: {} + # @schema {"name": "gateway.mirror.service", "type": "object", "anchor": "service"} + service: + # @schema {"name": "gateway.mirror.service.annotations", "type": "object"} + # gateway.mirror.service.annotations -- service annotations + annotations: {} + # @schema {"name": "gateway.mirror.service.labels", "type": "object"} + # gateway.mirror.service.labels -- service labels + labels: {} + # @schema {"name": "gateway.mirror.hpa", "type": "object", "anchor": "hpa"} + hpa: + # @schema {"name": "gateway.mirror.hpa.enabled", "type": "boolean"} + # gateway.mirror.hpa.enabled -- HPA enabled + enabled: true + # @schema {"name": "gateway.mirror.hpa.targetCPUUtilizationPercentage", "type": "integer"} + # gateway.mirror.hpa.targetCPUUtilizationPercentage -- HPA CPU utilization percentage + targetCPUUtilizationPercentage: 80 + # @schema {"name": "gateway.mirror.image", "type": "object", "anchor": "image"} + image: + # @schema {"name": "gateway.mirror.image.repository", "type": "string"} + # gateway.mirror.image.repository -- image repository + repository: vdaas/vald-mirror-gateway + # @schema {"name": "gateway.mirror.image.tag", "type": "string"} + # gateway.mirror.image.tag -- image tag (overrides defaults.image.tag) + tag: "" + # @schema {"name": "gateway.mirror.image.pullPolicy", "type": "string", "enum": ["Always", "Never", "IfNotPresent"]} + # gateway.mirror.image.pullPolicy -- image pull policy + pullPolicy: Always + # @schema {"name": "gateway.mirror.rollingUpdate", "type": "object", "anchor": "rollingUpdate"} + rollingUpdate: + # @schema {"name": "gateway.mirror.rollingUpdate.maxSurge", "type": "string"} + # gateway.mirror.rollingUpdate.maxSurge -- max surge of rolling update + maxSurge: 25% + # @schema {"name": "gateway.mirror.rollingUpdate.maxUnavailable", "type": "string"} + # gateway.mirror.rollingUpdate.maxUnavailable -- max unavailable of rolling update + maxUnavailable: 25% + # @schema {"name": "gateway.mirror.initContainers", "type": "array", "items": {"type": "object"}, "anchor": "initContainers"} + # gateway.mirror.initContainers -- init containers + initContainers: + - type: wait-for + name: wait-for-gateway-lb + target: gateway-lb + image: busybox:stable + sleepDuration: 2 + # @schema {"name": "gateway.mirror.env", "type": "array", "items": {"type": "object"}, "anchor": "env"} + # gateway.mirror.env -- environment variables + env: + - name: MY_NODE_NAME + valueFrom: + fieldRef: + fieldPath: spec.nodeName + - name: MY_POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: MY_POD_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + # @schema {"name": "gateway.mirror.volumeMounts", "type": "array", "items": {"type": "object"}, "anchor": "volumeMounts"} + # gateway.mirror.volumeMounts -- volume mounts + volumeMounts: [] + # @schema {"name": "gateway.mirror.volumes", "type": "array", "items": {"type": "object"}, "anchor": "volumes"} + # gateway.mirror.volumes -- volumes + volumes: [] + # @schema {"name": "gateway.mirror.nodeName", "type": "string"} + # gateway.mirror.nodeName -- node name + nodeName: "" + # @schema {"name": "gateway.mirror.nodeSelector", "type": "object", "anchor": "nodeSelector"} + # gateway.mirror.nodeSelector -- node selector + nodeSelector: {} + # @schema {"name": "gateway.mirror.tolerations", "type": "array", "items": {"type": "object"}, "anchor": "tolerations"} + # gateway.mirror.tolerations -- tolerations + tolerations: [] + # @schema {"name": "gateway.mirror.affinity", "type": "object", "anchor": "affinity"} + affinity: + # @schema {"name": "gateway.mirror.affinity.nodeAffinity", "type": "object"} + nodeAffinity: + # @schema {"name": "gateway.mirror.affinity.nodeAffinity.preferredDuringSchedulingIgnoredDuringExecution", "type": "array", "items": {"type": "object"}} + # gateway.mirror.affinity.nodeAffinity.preferredDuringSchedulingIgnoredDuringExecution -- node affinity preferred scheduling terms + preferredDuringSchedulingIgnoredDuringExecution: [] + # @schema {"name": "gateway.mirror.affinity.nodeAffinity.requiredDuringSchedulingIgnoredDuringExecution", "type": "object"} + requiredDuringSchedulingIgnoredDuringExecution: + # @schema {"name": "gateway.mirror.affinity.nodeAffinity.requiredDuringSchedulingIgnoredDuringExecution.nodeSelectorTerms", "type": "array", "items": {"type": "object"}} + # gateway.mirror.affinity.nodeAffinity.requiredDuringSchedulingIgnoredDuringExecution.nodeSelectorTerms -- node affinity required node selectors + nodeSelectorTerms: [] + # @schema {"name": "gateway.mirror.affinity.podAffinity", "type": "object"} + podAffinity: + # @schema {"name": "gateway.mirror.affinity.podAffinity.preferredDuringSchedulingIgnoredDuringExecution", "type": "array", "items": {"type": "object"}} + # gateway.mirror.affinity.podAffinity.preferredDuringSchedulingIgnoredDuringExecution -- pod affinity preferred scheduling terms + preferredDuringSchedulingIgnoredDuringExecution: [] + # @schema {"name": "gateway.mirror.affinity.podAffinity.requiredDuringSchedulingIgnoredDuringExecution", "type": "array", "items": {"type": "object"}} + # gateway.mirror.affinity.podAffinity.requiredDuringSchedulingIgnoredDuringExecution -- pod affinity required scheduling terms + requiredDuringSchedulingIgnoredDuringExecution: [] + # @schema {"name": "gateway.mirror.affinity.podAntiAffinity", "type": "object"} + podAntiAffinity: + # @schema {"name": "gateway.mirror.affinity.podAntiAffinity.preferredDuringSchedulingIgnoredDuringExecution", "type": "array", "items": {"type": "object"}} + # gateway.mirror.affinity.podAntiAffinity.preferredDuringSchedulingIgnoredDuringExecution -- pod anti-affinity preferred scheduling terms + preferredDuringSchedulingIgnoredDuringExecution: + - weight: 100 + podAffinityTerm: + topologyKey: kubernetes.io/hostname + labelSelector: + matchExpressions: + - key: app + operator: In + values: + - vald-mirror-gateway + # @schema {"name": "gateway.mirror.affinity.podAntiAffinity.requiredDuringSchedulingIgnoredDuringExecution", "type": "array", "items": {"type": "object"}} + # gateway.mirror.affinity.podAntiAffinity.requiredDuringSchedulingIgnoredDuringExecution -- pod anti-affinity required scheduling terms + requiredDuringSchedulingIgnoredDuringExecution: [] + # @schema {"name": "gateway.mirror.topologySpreadConstraints", "type": "array", "items": {"type": "object"}, "anchor": "topologySpreadConstraints"} + # gateway.mirror.topologySpreadConstraints -- topology spread constraints of gateway pods + topologySpreadConstraints: [] + # @schema {"name": "gateway.mirror.server_config", "alias": "server_config"} + # gateway.mirror.server_config -- server config (overrides defaults.server_config) + server_config: + servers: + rest: {} + grpc: {} + healths: + liveness: {} + readiness: {} + startup: {} + metrics: + pprof: {} + # @schema {"name": "gateway.mirror.observability", "alias": "observability"} + # gateway.mirror.observability -- observability config (overrides defaults.observability) + observability: + otlp: + attribute: + service_name: vald-mirror-gateway + # @schema {"name": "gateway.mirror.ingress", "type": "object"} + ingress: + # @schema {"name": "gateway.mirror.ingress.pathType", "type": "string"} + # gateway.mirror.ingress.pathType -- gateway ingress pathType + pathType: ImplementationSpecific + # @schema {"name": "gateway.mirror.ingress.enabled", "type": "boolean"} + # gateway.mirror.ingress.enabled -- gateway ingress enabled + enabled: false + # @schema {"name": "gateway.mirror.ingress.annotations", "type": "object"} + # gateway.mirror.ingress.annotations -- annotations for ingress + annotations: + nginx.ingress.kubernetes.io/grpc-backend: "true" + # @schema {"name": "gateway.mirror.ingress.host", "type": "string"} + # gateway.mirror.ingress.host -- ingress hostname + host: mirror.gateway.vald.vdaas.org + # @schema {"name": "gateway.mirror.ingress.servicePort", "type": "string"} + # gateway.mirror.ingress.servicePort -- service port to be exposed by ingress + servicePort: grpc + # @schema {"name": "gateway.mirror.resources", "type": "object", "anchor": "resources"} + # gateway.mirror.resources -- compute resources + resources: + # @schema {"name": "gateway.mirror.resources.requests", "type": "object"} + requests: + cpu: 200m + memory: 150Mi + # @schema {"name": "gateway.mirror.resources.limits", "type": "object"} + limits: + cpu: 2000m + memory: 700Mi + # @schema {"name": "gateway.mirror.gateway_config", "type": "object"} + gateway_config: + # @schema {"name": "gateway.mirror.gateway_config.client", "alias": "grpc.client"} + # gateway.mirror.gateway_config.client -- gRPC client (overrides defaults.grpc.client) + client: {} + # @schema {"name": "gateway.mirror.gateway_config.self_mirror_addr", "type": "string"} + # gateway.mirror.gateway_config.self_mirror_addr -- address for self mirror-gateway + self_mirror_addr: "" + # @schema {"name": "gateway.mirror.gateway_config.gateway_addr", "type": "string"} + # gateway.mirror.gateway_config.gateway_addr -- address for lb-gateway + gateway_addr: "" + # @schema {"name": "gateway.mirror.gateway_config.pod_name", "type": "string"} + # gateway.mirror.gateway_config.pod_name -- self mirror gateway pod name + pod_name: _MY_POD_NAME_ + # @schema {"name": "gateway.mirror.gateway_config.advertise_interval", "type": "string"} + # gateway.mirror.gateway_config.advertise_interval -- interval to advertise mirror-gateway information to other mirror-gateway. + advertise_interval: "1s" # @schema {"name": "agent", "type": "object"} agent: # @schema {"name": "agent.enabled", "type": "boolean"} diff --git a/charts/vald/values/multi-vald/dev-vald-01.yaml b/charts/vald/values/multi-vald/dev-vald-01.yaml new file mode 100644 index 0000000000..0849bb2433 --- /dev/null +++ b/charts/vald/values/multi-vald/dev-vald-01.yaml @@ -0,0 +1,5 @@ +discoverer: + clusterRoleBinding: + name: vald-01 + serviceAccount: + name: vald-01 diff --git a/charts/vald/values/multi-vald/dev-vald-02.yaml b/charts/vald/values/multi-vald/dev-vald-02.yaml new file mode 100644 index 0000000000..7e60a18b34 --- /dev/null +++ b/charts/vald/values/multi-vald/dev-vald-02.yaml @@ -0,0 +1,7 @@ +discoverer: + clusterRole: + enabled: false + clusterRoleBinding: + name: vald-02 + serviceAccount: + name: vald-02 diff --git a/charts/vald/values/multi-vald/dev-vald-03.yaml b/charts/vald/values/multi-vald/dev-vald-03.yaml new file mode 100644 index 0000000000..27a37502bb --- /dev/null +++ b/charts/vald/values/multi-vald/dev-vald-03.yaml @@ -0,0 +1,15 @@ +gateway: + mirror: + gateway_config: + client: + addrs: + - vald-mirror-gateway.vald-01.svc.cluster.local:8081 + - vald-mirror-gateway.vald-02.svc.cluster.local:8081 + +discoverer: + clusterRole: + enabled: false + clusterRoleBinding: + name: vald-03 + serviceAccount: + name: vald-03 diff --git a/charts/vald/values/multi-vald/dev-vald-with-mirror.yaml b/charts/vald/values/multi-vald/dev-vald-with-mirror.yaml new file mode 100644 index 0000000000..7c243cffde --- /dev/null +++ b/charts/vald/values/multi-vald/dev-vald-with-mirror.yaml @@ -0,0 +1,88 @@ +defaults: + image: + tag: pr-1949 # TODO: Delete it later. + server_config: + metrics: + pprof: + enabled: true + servers: + grpc: + server: + grpc: + interceptors: + - RecoverInterceptor + - TraceInterceptor + - AccessLogInterceptor + - MetricInterceptor + grpc: + client: + dial_option: + interceptors: + - TraceInterceptor + observability: + enabled: true + otlp: + collector_endpoint: "opentelemetry-collector-collector.default.svc.cluster.local:4317" + trace: + enabled: true + +gateway: + lb: + podAnnotations: + profefe.com/enable: "true" + profefe.com/port: "6060" + profefe.com/service: "vald-lb-gateway" + resources: + requests: + cpu: 100m + memory: 50Mi + gateway_config: + index_replica: 2 + + # NOTE: Will work with multiple replicas in the future. + mirror: + enabled: true + minReplicas: 1 + maxReplicas: 1 + ingress: + enabled: false + gateway_config: + lb_client: {} + mirror_client: {} + self_mirror_client: {} + +agent: + podAnnotations: + profefe.com/enable: "true" + profefe.com/port: "6060" + profefe.com/service: "vald-agent-ngt" + minReplicas: 5 + maxReplicas: 10 + podManagementPolicy: Parallel + resources: + requests: + cpu: 100m + memory: 50Mi + ngt: + dimension: 784 + +discoverer: + podAnnotations: + profefe.com/enable: "true" + profefe.com/port: "6060" + profefe.com/service: "vald-discoverer" + resources: + requests: + cpu: 100m + memory: 50Mi + +manager: + index: + podAnnotations: + profefe.com/enable: "true" + profefe.com/port: "6060" + profefe.com/service: "vald-manager-index" + resources: + requests: + cpu: 100m + memory: 30Mi diff --git a/cmd/gateway/mirror/main.go b/cmd/gateway/mirror/main.go new file mode 100644 index 0000000000..9f595fdc81 --- /dev/null +++ b/cmd/gateway/mirror/main.go @@ -0,0 +1,58 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package main provides program main +package main + +import ( + "context" + + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/info" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/runner" + "github.com/vdaas/vald/internal/safety" + "github.com/vdaas/vald/pkg/gateway/mirror/config" + "github.com/vdaas/vald/pkg/gateway/mirror/usecase" +) + +const ( + maxVersion = "v0.0.10" + minVersion = "v0.0.0" + name = "gateway mirror" +) + +func main() { + if err := safety.RecoverFunc(func() error { + return runner.Do( + context.Background(), + runner.WithName(name), + runner.WithVersion(info.Version, maxVersion, minVersion), + runner.WithConfigLoader(func(path string) (interface{}, *config.GlobalConfig, error) { + cfg, err := config.NewConfig(path) + if err != nil { + return nil, nil, errors.Wrap(err, "failed to load "+name+"'s configuration") + } + return cfg, &cfg.GlobalConfig, nil + }), + runner.WithDaemonInitializer(func(cfg interface{}) (runner.Runner, error) { + return usecase.New(cfg.(*config.Data)) + }), + ) + })(); err != nil { + log.Fatal(err, info.Get()) + } +} diff --git a/cmd/gateway/mirror/sample.yaml b/cmd/gateway/mirror/sample.yaml new file mode 100644 index 0000000000..ffcd6c861d --- /dev/null +++ b/cmd/gateway/mirror/sample.yaml @@ -0,0 +1,148 @@ +# +# Copyright (C) 2019-2022 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +version: v0.0.0 +time_zone: JST +logging: + format: raw + level: debug + logger: glg +server_config: + servers: + - name: grpc + host: 0.0.0.0 + port: 8081 + grpc: + bidirectional_stream_concurrency: 20 + connection_timeout: "" + header_table_size: 0 + initial_conn_window_size: 0 + initial_window_size: 0 + interceptors: [] + keepalive: + max_conn_age: "" + max_conn_age_grace: "" + max_conn_idle: "" + time: "" + timeout: "" + max_header_list_size: 0 + max_receive_message_size: 0 + max_send_message_size: 0 + read_buffer_size: 0 + write_buffer_size: 0 + mode: GRPC + probe_wait_time: 3s + restart: true + health_check_servers: + - name: liveness + host: 0.0.0.0 + port: 3000 + http: + handler_timeout: "" + idle_timeout: "" + read_header_timeout: "" + read_timeout: "" + shutdown_duration: 5s + write_timeout: "" + mode: "" + probe_wait_time: 3s + - name: readiness + host: 0.0.0.0 + port: 3001 + http: + handler_timeout: "" + idle_timeout: "" + read_header_timeout: "" + read_timeout: "" + shutdown_duration: 0s + write_timeout: "" + mode: "" + probe_wait_time: 3s + metrics_servers: + startup_strategy: + - liveness + - grpc + - readiness + full_shutdown_duration: 600s + tls: + ca: /path/to/ca + cert: /path/to/cert + enabled: false + key: /path/to/key +observability: + enabled: false + trace: + enabled: false +gateway: + client: + addrs: + - vald-discoverer.default.svc.cluster.local:8081 + health_check_duration: "1s" + connection_pool: + enable_dns_resolver: true + enable_rebalance: true + old_conn_close_duration: 3s + rebalance_duration: 30m + size: 3 + backoff: + backoff_factor: 1.1 + backoff_time_limit: 5s + enable_error_log: true + initial_duration: 5ms + jitter_limit: 100ms + maximum_duration: 5s + retry_count: 100 + call_option: + max_recv_msg_size: 0 + max_retry_rpc_buffer_size: 0 + max_send_msg_size: 0 + wait_for_ready: true + dial_option: + backoff_base_delay: 1s + backoff_jitter: 0.2 + backoff_max_delay: 120s + backoff_multiplier: 1.6 + enable_backoff: false + initial_connection_window_size: 0 + initial_window_size: 0 + insecure: true + keepalive: + permit_without_stream: false + time: "" + timeout: "" + max_msg_size: 0 + min_connection_timeout: 20s + read_buffer_size: 0 + tcp: + dialer: + dual_stack_enabled: true + keepalive: "" + timeout: "" + dns: + cache_enabled: true + cache_expiration: 1h + refresh_duration: 30m + tls: + enabled: false + ca: /path/to/ca + cert: /path/to/cert + key: /path/to/key + timeout: "" + write_buffer_size: 0 + tls: + enabled: false + ca: /path/to/ca + cert: /path/to/cert + key: /path/to/key diff --git a/dockers/gateway/mirror/Dockerfile b/dockers/gateway/mirror/Dockerfile new file mode 100644 index 0000000000..e21ac01dbf --- /dev/null +++ b/dockers/gateway/mirror/Dockerfile @@ -0,0 +1,81 @@ +# +# Copyright (C) 2019-2022 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +ARG GO_VERSION=latest +ARG DISTROLESS_IMAGE=gcr.io/distroless/static +ARG DISTROLESS_IMAGE_TAG=nonroot +ARG MAINTAINER="vdaas.org vald team " + +FROM golang:${GO_VERSION} AS builder + +ENV GO111MODULE on +ENV LANG en_US.UTF-8 +ENV ORG vdaas +ENV REPO vald +ENV PKG gateway/mirror +ENV APP_NAME mirror + +# skipcq: DOK-DL3008 +RUN apt-get update && apt-get install -y --no-install-recommends \ + upx \ + git \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN mkdir -p "$GOPATH/src" + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO} + +COPY go.mod . +COPY go.sum . + +RUN go mod download + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/internal +COPY internal . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/apis/grpc +COPY apis/grpc . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/pkg/${PKG} +COPY pkg/${PKG} . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/cmd/${PKG} +COPY cmd/${PKG} . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/versions +COPY versions . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/Makefile.d +COPY Makefile.d . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO} +COPY Makefile . +COPY .git . + +RUN make REPO=${ORG} NAME=${REPO} cmd/${PKG}/${APP_NAME} \ + && mv "cmd/${PKG}/${APP_NAME}" "/usr/bin/${APP_NAME}" + +FROM ${DISTROLESS_IMAGE}:${DISTROLESS_IMAGE_TAG} +LABEL maintainer "${MAINTAINER}" + +ENV APP_NAME mirror + +COPY --from=builder /usr/bin/${APP_NAME} /go/bin/${APP_NAME} + +USER nonroot:nonroot + +ENTRYPOINT ["/go/bin/mirror"] diff --git a/internal/circuitbreaker/manager.go b/internal/circuitbreaker/manager.go index a7846b9aaa..7c0b5a2282 100644 --- a/internal/circuitbreaker/manager.go +++ b/internal/circuitbreaker/manager.go @@ -94,7 +94,8 @@ func (bm *breakerManager) Do(ctx context.Context, key string, fn func(ctx contex if err != nil { switch st { case StateClosed: - err = errors.Wrapf(err, "circuitbreaker state is %s, this error is not caused by circuitbreaker", st.String()) + cerr := errors.Wrapf(err, "circuitbreaker state is %s, this error is not caused by circuitbreaker", st.String()) + log.Debug(cerr) case StateOpen: if !errors.Is(err, errors.ErrCircuitBreakerOpenState) { err = errors.Join(err, errors.ErrCircuitBreakerOpenState) diff --git a/internal/client/v1/client/mirror/mirror.go b/internal/client/v1/client/mirror/mirror.go new file mode 100644 index 0000000000..30c6b25425 --- /dev/null +++ b/internal/client/v1/client/mirror/mirror.go @@ -0,0 +1,98 @@ +package mirror + +import ( + "context" + + "github.com/vdaas/vald/apis/grpc/v1/mirror" + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/net/grpc" + "github.com/vdaas/vald/internal/observability/trace" +) + +const ( + apiName = "vald/internal/client/v1/client/mirror" +) + +type Client interface { + mirror.MirrorClient + GRPCClient() grpc.Client + Start(context.Context) (<-chan error, error) + Stop(context.Context) error +} + +type client struct { + addrs []string + c grpc.Client +} + +func New(opts ...Option) (Client, error) { + c := new(client) + for _, opt := range append(defaultOpts, opts...) { + if err := opt(c); err != nil { + return nil, err + } + } + if c.c == nil { + if len(c.addrs) == 0 { + return nil, errors.ErrGRPCTargetAddrNotFound + } + c.c = grpc.New(grpc.WithAddrs(c.addrs...)) + } + return c, nil +} + +func (c *client) Start(ctx context.Context) (<-chan error, error) { + return c.c.StartConnectionMonitor(ctx) +} + +func (c *client) Stop(ctx context.Context) error { + return c.Stop(ctx) +} + +func (c *client) GRPCClient() grpc.Client { + return c.c +} + +func (c *client) Register(ctx context.Context, in *payload.Mirror_Targets, opts ...grpc.CallOption) (res *payload.Mirror_Targets, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "internal/client/"+vald.RegisterRPCName), apiName+"/"+vald.RegisterRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = c.c.RoundRobin(ctx, func(ctx context.Context, conn *grpc.ClientConn, copts ...grpc.CallOption) (interface{}, error) { + res, err = mirror.NewMirrorClient(conn).Register(ctx, in, append(copts, opts...)...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + return nil, err + } + return res, nil +} + +func (c *client) Advertise(ctx context.Context, in *payload.Mirror_Targets, opts ...grpc.CallOption) (res *payload.Mirror_Targets, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "internal/client/"+vald.AdvertiseRPCName), apiName+"/"+vald.AdvertiseRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = c.c.RoundRobin(ctx, func(ctx context.Context, conn *grpc.ClientConn, copts ...grpc.CallOption) (interface{}, error) { + res, err = mirror.NewMirrorClient(conn).Advertise(ctx, in, append(copts, opts...)...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + return nil, err + } + return res, nil +} diff --git a/internal/client/v1/client/mirror/option.go b/internal/client/v1/client/mirror/option.go new file mode 100644 index 0000000000..894e76216b --- /dev/null +++ b/internal/client/v1/client/mirror/option.go @@ -0,0 +1,32 @@ +package mirror + +import ( + "github.com/vdaas/vald/internal/net/grpc" +) + +type Option func(c *client) error + +var defaultOpts = []Option{} + +func WithAddrs(addrs ...string) Option { + return func(c *client) error { + if addrs == nil { + return nil + } + if c.addrs != nil { + c.addrs = append(c.addrs, addrs...) + } else { + c.addrs = addrs + } + return nil + } +} + +func WithClient(gc grpc.Client) Option { + return func(c *client) error { + if gc != nil { + c.c = gc + } + return nil + } +} diff --git a/internal/config/mirror.go b/internal/config/mirror.go new file mode 100644 index 0000000000..43db1ab4ab --- /dev/null +++ b/internal/config/mirror.go @@ -0,0 +1,46 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package config providers configuration type and load configuration logic +package config + +// Mirror represents the Mirror Gateway configuration. +type Mirror struct { + // Client represents the gRPC client configuration for connecting the LB Gateway. + Client *GRPCClient `json:"client" yaml:"client"` + // SelfMirrorAddr represents the address for the self Mirror Gateway. + SelfMirrorAddr string `json:"self_mirror_addr" yaml:"self_mirror_addr"` + // GatewayAddr represents the address for the Vald Gateway (e.g lb-gateway). + GatewayAddr string `json:"gateway_addr" yaml:"gateway_addr"` + // PodName represents self Mirror Gateway Pod name. + PodName string `json:"pod_name" yaml:"pod_name"` + // AdvertiseInterval represents interval to advertise Mirror Gateway information to other mirror gateway. + AdvertiseInterval string `json:"advertise_interval" yaml:"advertise_interval"` +} + +// Bind binds the actual data from the Mirror receiver fields. +func (m *Mirror) Bind() *Mirror { + m.SelfMirrorAddr = GetActualValue(m.SelfMirrorAddr) + m.GatewayAddr = GetActualValue(m.GatewayAddr) + m.PodName = GetActualValue(m.PodName) + m.AdvertiseInterval = GetActualValue(m.AdvertiseInterval) + if m.Client != nil { + m.Client = m.Client.Bind() + } else { + m.Client = new(GRPCClient).Bind() + } + return m +} diff --git a/internal/net/grpc/interceptor/server/metric/metric.go b/internal/net/grpc/interceptor/server/metric/metric.go index f7a2e7cb81..7eb5477ace 100644 --- a/internal/net/grpc/interceptor/server/metric/metric.go +++ b/internal/net/grpc/interceptor/server/metric/metric.go @@ -65,7 +65,7 @@ func MetricInterceptors() (grpc.UnaryServerInterceptor, grpc.StreamServerInterce elapsedTime := time.Since(now) record(ctx, info.FullMethod, err, float64(elapsedTime)/float64(time.Millisecond)) return resp, err - }, func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + }, func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { now := time.Now() err = handler(srv, ss) elapsedTime := time.Since(now) diff --git a/internal/net/grpc/metadata.go b/internal/net/grpc/metadata.go new file mode 100644 index 0000000000..afc51d6e96 --- /dev/null +++ b/internal/net/grpc/metadata.go @@ -0,0 +1,34 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package grpc provides generic functionality for grpc +package grpc + +import ( + "context" + + "google.golang.org/grpc/metadata" +) + +type MD = metadata.MD + +func NewOutgoingContext(ctx context.Context, md MD) context.Context { + return metadata.NewOutgoingContext(ctx, md) +} + +func FromIncomingContext(ctx context.Context) (metadata.MD, bool) { + return metadata.FromIncomingContext(ctx) +} diff --git a/pkg/gateway/mirror/README.md b/pkg/gateway/mirror/README.md new file mode 100755 index 0000000000..668dce401c --- /dev/null +++ b/pkg/gateway/mirror/README.md @@ -0,0 +1 @@ +# vald Mirror gateway diff --git a/pkg/gateway/mirror/config/config.go b/pkg/gateway/mirror/config/config.go new file mode 100644 index 0000000000..fec0bf3618 --- /dev/null +++ b/pkg/gateway/mirror/config/config.go @@ -0,0 +1,156 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package setting stores all server application settings +package config + +import ( + "github.com/vdaas/vald/internal/config" + "github.com/vdaas/vald/internal/errors" +) + +type ( + GlobalConfig = config.GlobalConfig + Server = config.Server +) + +// Config represent a application setting data content (config.yaml). +// In K8s environment, this configuration is stored in K8s ConfigMap. +type Data struct { + config.GlobalConfig `json:",inline" yaml:",inline"` + + // Server represent all server configurations + Server *config.Servers `json:"server_config" yaml:"server_config"` + + // Observability represent observability configurations + Observability *config.Observability `json:"observability" yaml:"observability"` + + // Mirror represent mirror gateway service configuration + Mirror *config.Mirror `json:"gateway" yaml:"gateway"` +} + +func NewConfig(path string) (cfg *Data, err error) { + cfg = new(Data) + + if err = config.Read(path, &cfg); err != nil { + return nil, err + } + + if cfg != nil { + cfg.Bind() + } else { + return nil, errors.ErrInvalidConfig + } + + if cfg.Server != nil { + cfg.Server = cfg.Server.Bind() + } else { + return nil, errors.ErrInvalidConfig + } + + if cfg.Observability != nil { + cfg.Observability = cfg.Observability.Bind() + } else { + cfg.Observability = new(config.Observability).Bind() + } + + if cfg.Mirror != nil { + cfg.Mirror = cfg.Mirror.Bind() + } else { + return nil, errors.ErrInvalidConfig + } + return cfg, nil +} + +// func FakeData() { +// d := Data{ +// Version: "v0.0.1", +// Server: &config.Servers{ +// Servers: []*config.Server{ +// { +// Name: "agent-rest", +// Host: "127.0.0.1", +// Port: 8080, +// Mode: "REST", +// ProbeWaitTime: "3s", +// ShutdownDuration: "5s", +// HandlerTimeout: "5s", +// IdleTimeout: "2s", +// ReadHeaderTimeout: "1s", +// ReadTimeout: "1s", +// WriteTimeout: "1s", +// }, +// { +// Name: "agent-grpc", +// Host: "127.0.0.1", +// Port: 8082, +// Mode: "GRPC", +// }, +// }, +// MetricsServers: []*config.Server{ +// { +// Name: "pprof", +// Host: "127.0.0.1", +// Port: 6060, +// Mode: "REST", +// ProbeWaitTime: "3s", +// ShutdownDuration: "5s", +// HandlerTimeout: "5s", +// IdleTimeout: "2s", +// ReadHeaderTimeout: "1s", +// ReadTimeout: "1s", +// WriteTimeout: "1s", +// }, +// }, +// HealthCheckServers: []*config.Server{ +// { +// Name: "livenesss", +// Host: "127.0.0.1", +// Port: 3000, +// }, +// { +// Name: "readiness", +// Host: "127.0.0.1", +// Port: 3001, +// }, +// }, +// StartUpStrategy: []string{ +// "livenesss", +// "pprof", +// "agent-grpc", +// "agent-rest", +// "readiness", +// }, +// ShutdownStrategy: []string{ +// "readiness", +// "agent-rest", +// "agent-grpc", +// "pprof", +// "livenesss", +// }, +// FullShutdownDuration: "30s", +// TLS: &config.TLS{ +// Enabled: false, +// Cert: "/path/to/cert", +// Key: "/path/to/key", +// CA: "/path/to/ca", +// }, +// }, +// Mirror: &config.Mirror{ +// }, +// } +// fmt.Println(config.ToRawYaml(d)) +// } diff --git a/pkg/gateway/mirror/handler/doc.go b/pkg/gateway/mirror/handler/doc.go new file mode 100644 index 0000000000..f1014141ea --- /dev/null +++ b/pkg/gateway/mirror/handler/doc.go @@ -0,0 +1,17 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler diff --git a/pkg/gateway/mirror/handler/grpc/handler.go b/pkg/gateway/mirror/handler/grpc/handler.go new file mode 100644 index 0000000000..7b65ce351c --- /dev/null +++ b/pkg/gateway/mirror/handler/grpc/handler.go @@ -0,0 +1,3283 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package grpc provides grpc server logic +package grpc + +import ( + "context" + "fmt" + "reflect" + "sync" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net/grpc" + "github.com/vdaas/vald/internal/net/grpc/codes" + "github.com/vdaas/vald/internal/net/grpc/errdetails" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/internal/observability/trace" + "github.com/vdaas/vald/internal/safety" + "github.com/vdaas/vald/pkg/gateway/mirror/service" +) + +type server struct { + eg errgroup.Group + gateway service.Gateway // Mirror Gateway client service. + mirror service.Mirror + vAddr string // Vald Gateway address (lb-gateway). + streamConcurrency int + name string + ip string + vald.UnimplementedValdServerWithMirror +} + +const apiName = "vald/gateway/mirror" + +func New(opts ...Option) (vald.ServerWithMirror, error) { + s := new(server) + for _, opt := range append(defaultOptions, opts...) { + if err := opt(s); err != nil { + oerr := errors.ErrOptionFailed(err, reflect.ValueOf(opt)) + e := &errors.ErrCriticalOption{} + if errors.As(err, &e) { + log.Error(oerr) + return nil, oerr + } + log.Warn(oerr) + } + } + return s, nil +} + +func (s *server) Register(ctx context.Context, req *payload.Mirror_Targets) (*payload.Mirror_Targets, error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.MirrorRPCServiceName+"/"+vald.RegisterRPCName), apiName+"/"+vald.RegisterRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err := s.mirror.Connect(ctx, req.GetTargets()...) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RegisterRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.RegisterRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithCanceled( + vald.RegisterRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.RegisterRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInvalidArgument( + vald.RegisterRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInvalidArgument(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.RegisterRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return req, nil +} + +func (s *server) Advertise(ctx context.Context, req *payload.Mirror_Targets) (res *payload.Mirror_Targets, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.MirrorRPCServiceName+"/"+vald.AdvertiseRPCName), apiName+"/"+vald.AdvertiseRPCName) + defer func() { + if span != nil { + span.End() + } + }() + _, err = s.Register(ctx, req) + if err != nil { + return nil, err + } + tgts, err := s.mirror.MirrorTargets() + if err != nil { + err = status.WrapWithInternal(vald.AdvertiseRPCName+" API failed to get connected vald gateway targets", err, + &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(req), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "mirror gateway targets", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.AdvertiseRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return &payload.Mirror_Targets{ + Targets: tgts, + }, nil +} + +func (s *server) Exists(ctx context.Context, meta *payload.Object_ID) (id *payload.Object_ID, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.ObjectRPCServiceName+"/"+vald.ExistsRPCName), apiName+"/"+vald.ExistsRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + id, err = vc.Exists(ctx, meta, copts...) + if err != nil { + return nil, err + } + return id, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(meta), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.ExistsRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.ExistsRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.ExistsRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.ExistsRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.ExistsRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.ExistsRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return id, nil +} + +func (s *server) Search(ctx context.Context, req *payload.Search_Request) (res *payload.Search_Response, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.SearchRPCName), apiName+"/"+vald.SearchRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + res, err = vc.Search(ctx, req, copts...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetConfig().GetRequestId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.SearchRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.SearchRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.SearchRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.SearchRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.SearchRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.SearchRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) SearchByID(ctx context.Context, req *payload.Search_IDRequest) ( + res *payload.Search_Response, err error, +) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.SearchByIDRPCName), apiName+"/"+vald.SearchByIDRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + res, err = vc.SearchByID(ctx, req, copts...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetConfig().GetRequestId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.SearchByIDRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.SearchByIDRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.SearchByIDRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.SearchByIDRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.SearchByIDRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.SearchByIDRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) StreamSearch(stream vald.Search_StreamSearchServer) (err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.StreamSearchRPCName), apiName+"/"+vald.StreamSearchRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err = grpc.BidirectionalStream(ctx, stream, s.streamConcurrency, + func(ctx context.Context, req *payload.Search_Request) (*payload.Search_StreamResponse, error) { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BidirectionalStream"), apiName+"/"+vald.StreamSearchRPCName+"/requestID-"+req.GetConfig().GetRequestId()) + defer func() { + if sspan != nil { + sspan.End() + } + }() + res, err := s.Search(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.SearchRPCName+" gRPC error response") + if sspan != nil { + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return &payload.Search_StreamResponse{ + Payload: &payload.Search_StreamResponse_Status{ + Status: st.Proto(), + }, + }, err + } + return &payload.Search_StreamResponse{ + Payload: &payload.Search_StreamResponse_Response{ + Response: res, + }, + }, nil + }, + ) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.StreamSearchRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + return nil +} + +func (s *server) StreamSearchByID(stream vald.Search_StreamSearchByIDServer) (err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.StreamSearchByIDRPCName), apiName+"/"+vald.StreamSearchByIDRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err = grpc.BidirectionalStream(ctx, stream, s.streamConcurrency, + func(ctx context.Context, req *payload.Search_IDRequest) (*payload.Search_StreamResponse, error) { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BidirectionalStream"), apiName+"."+vald.StreamSearchByIDRPCName+"/id-"+req.GetId()) + defer func() { + if sspan != nil { + sspan.End() + } + }() + res, err := s.SearchByID(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.SearchByIDRPCName+" gRPC error response") + if sspan != nil { + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return &payload.Search_StreamResponse{ + Payload: &payload.Search_StreamResponse_Status{ + Status: st.Proto(), + }, + }, err + } + return &payload.Search_StreamResponse{ + Payload: &payload.Search_StreamResponse_Response{ + Response: res, + }, + }, nil + }, + ) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.StreamSearchByIDRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + return nil +} + +func (s *server) MultiSearch(ctx context.Context, req *payload.Search_MultiRequest) (res *payload.Search_Responses, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.MultiSearchRPCName), apiName+"/"+vald.MultiSearchRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + res, err = vc.MultiSearch(ctx, req, copts...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.MultiSearchRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.MultiSearchRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.MultiSearchRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.MultiSearchRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.MultiSearchRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.MultiSearchRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) MultiSearchByID(ctx context.Context, req *payload.Search_MultiIDRequest) (res *payload.Search_Responses, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.MultiSearchByIDRPCName), apiName+"/"+vald.MultiSearchByIDRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + res, err = vc.MultiSearchByID(ctx, req, copts...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.MultiSearchByIDRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.MultiSearchByIDRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.MultiSearchByIDRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.MultiSearchByIDRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.MultiSearchByIDRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.MultiSearchByIDRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) LinearSearch(ctx context.Context, req *payload.Search_Request) (res *payload.Search_Response, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.LinearSearchRPCName), apiName+"/"+vald.LinearSearchRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + res, err = vc.LinearSearch(ctx, req, copts...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetConfig().GetRequestId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.LinearSearchRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.LinearSearchRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.LinearSearchRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.LinearSearchRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.LinearSearchRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.LinearSearchRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) LinearSearchByID(ctx context.Context, req *payload.Search_IDRequest) ( + res *payload.Search_Response, err error, +) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.LinearSearchByIDRPCName), apiName+"/"+vald.LinearSearchByIDRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + res, err = vc.LinearSearchByID(ctx, req, copts...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetConfig().GetRequestId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.LinearSearchByIDRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.LinearSearchByIDRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.LinearSearchByIDRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.LinearSearchByIDRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.LinearSearchByIDRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.LinearSearchByIDRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) StreamLinearSearch(stream vald.Search_StreamLinearSearchServer) (err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.StreamLinearSearchRPCName), apiName+"/"+vald.StreamLinearSearchRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err = grpc.BidirectionalStream(ctx, stream, s.streamConcurrency, + func(ctx context.Context, req *payload.Search_Request) (*payload.Search_StreamResponse, error) { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BidirectionalStream"), apiName+"/"+vald.StreamLinearSearchRPCName+"/requestID-"+req.GetConfig().GetRequestId()) + defer func() { + if sspan != nil { + sspan.End() + } + }() + res, err := s.LinearSearch(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.LinearSearchRPCName+" gRPC error response") + if sspan != nil { + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return &payload.Search_StreamResponse{ + Payload: &payload.Search_StreamResponse_Status{ + Status: st.Proto(), + }, + }, err + } + return &payload.Search_StreamResponse{ + Payload: &payload.Search_StreamResponse_Response{ + Response: res, + }, + }, nil + }, + ) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.StreamLinearSearchRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + return nil +} + +func (s *server) StreamLinearSearchByID(stream vald.Search_StreamLinearSearchByIDServer) (err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.StreamLinearSearchByIDRPCName), apiName+"/"+vald.StreamLinearSearchByIDRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err = grpc.BidirectionalStream(ctx, stream, s.streamConcurrency, + func(ctx context.Context, req *payload.Search_IDRequest) (*payload.Search_StreamResponse, error) { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BidirectionalStream"), apiName+"."+vald.StreamLinearSearchByIDRPCName+"/id-"+req.GetId()) + defer func() { + if sspan != nil { + sspan.End() + } + }() + res, err := s.LinearSearchByID(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.LinearSearchByIDRPCName+" gRPC error response") + if sspan != nil { + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return &payload.Search_StreamResponse{ + Payload: &payload.Search_StreamResponse_Status{ + Status: st.Proto(), + }, + }, err + } + return &payload.Search_StreamResponse{ + Payload: &payload.Search_StreamResponse_Response{ + Response: res, + }, + }, nil + }, + ) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.StreamLinearSearchByIDRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + return nil +} + +func (s *server) MultiLinearSearch(ctx context.Context, req *payload.Search_MultiRequest) (res *payload.Search_Responses, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.MultiLinearSearchRPCName), apiName+"/"+vald.MultiLinearSearchRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + res, err = vc.MultiLinearSearch(ctx, req, copts...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.MultiLinearSearchRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.MultiLinearSearchRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.MultiLinearSearchRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.MultiLinearSearchRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.MultiLinearSearchRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.MultiLinearSearchRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) MultiLinearSearchByID(ctx context.Context, req *payload.Search_MultiIDRequest) (res *payload.Search_Responses, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.SearchRPCServiceName+"/"+vald.MultiLinearSearchByIDRPCName), apiName+"/"+vald.MultiLinearSearchByIDRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + res, err = vc.MultiLinearSearchByID(ctx, req, copts...) + if err != nil { + return nil, err + } + return res, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.MultiLinearSearchByIDRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.MultiLinearSearchByIDRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.MultiLinearSearchByIDRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.MultiLinearSearchByIDRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.MultiLinearSearchByIDRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.MultiLinearSearchByIDRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.InsertRPCServiceName+"/"+vald.InsertRPCName), apiName+"/"+vald.InsertRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + reqSrcPodName := s.gateway.FromForwardedContext(ctx) + + // When this condition is matched, the request is proxied to another Mirror gateway. + // So this component sends requests only to the Vald gateway (LB gateway) of its own cluster. + if len(reqSrcPodName) != 0 { + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + loc, err = vc.Insert(ctx, req, copts...) + if err != nil { + return nil, err + } + return loc, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.InsertRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.InsertRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.InsertRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.InsertRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Debugf("Insert API succeeded to %#v", loc) + return loc, nil + } + + var mu sync.Mutex + var result sync.Map + loc = &payload.Object_Location{ + Uuid: req.GetVector().GetId(), + Ips: make([]string, 0), + } + err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.InsertRPCName+"/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + ce, err := s.insert(ctx, vc, req, copts...) + if err != nil { + st, _, _ := status.ParseError(err, codes.Internal, + "failed to parse "+vald.InsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.AlreadyExists { + // NOTE: If it is strictly necessary to check, fix this logic. + return nil + } + } + if err == nil && ce != nil { + mu.Lock() + loc.Name = ce.GetName() + loc.Ips = append(loc.Ips, ce.GetIps()...) + mu.Unlock() + } + result.Store(target, err) + return err + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + ".BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + err = status.WrapWithInternal( + vald.InsertRPCName+" API connection not found", err, reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + // There is no possibility to reach this part, but we add error handling just in case. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + var errs error + targets := make([]string, 0, 10) + result.Range(func(target, err any) bool { + if err == nil { + targets = append(targets, target.(string)) + } else { + if err, ok := err.(error); ok && err != nil { + if errs != nil { + errs = errors.Join(errs, err) + } else { + errs = err + } + } + } + return true + }) + switch { + case errs == nil: + log.Debugf("Insert API mirror request succeeded to %#v", loc) + return loc, nil + case len(targets) == 0 && errs != nil: + log.Error("failed to Insert API mirror request: %v and can not rollback because success target length is 0", errs) + st, msg, err := status.ParseError(errs, codes.Internal, + "failed to parse "+vald.InsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Error("failed to Insert API mirror request: %v, so starts the rollback request", errs) + + var emu sync.Mutex + var rerrs error + rmReq := &payload.Remove_Request{ + Id: &payload.Object_ID{ + Id: req.GetVector().GetId(), + }, + } + err = s.gateway.DoMulti(ctx, targets, + func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "rollback/BroadCast/"+target), apiName+"/"+vald.InsertRPCName+"/rollback/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + _, err := s.remove(ctx, vc, rmReq, copts...) + if err != nil { + st, _, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" for "+vald.InsertRPCName+" error response for "+target, + &errdetails.RequestInfo{ + RequestId: rmReq.GetId().GetId(), + ServingData: errdetails.Serialize(rmReq), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + "." + vald.RemoveRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.NotFound { + return nil + } + emu.Lock() + if rerrs != nil { + rerrs = errors.Join(rerrs, err) + } else { + rerrs = err + } + emu.Unlock() + return err + } + return nil + }, + ) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: rmReq.GetId().GetId(), + ServingData: errdetails.Serialize(rmReq), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + "." + vald.RemoveRPCName + ".BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + err = status.WrapWithInternal( + vald.RemoveRPCName+" for "+vald.InsertRPCName+" API connection not found", err, reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + // There is no possibility to reach this part, but we add error handling just in case. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" for "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + if rerrs == nil { + log.Debugf("rollback for Insert API mirror request succeeded to %v", targets) + st, msg, err := status.ParseError(errs, codes.Internal, + "failed to parse "+vald.InsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Debugf("failed to rollback for Insert API mirror request succeeded to %v: %v", targets, rerrs) + st, msg, err := status.ParseError(rerrs, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" for "+vald.InsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: rmReq.GetId().GetId(), + ServingData: errdetails.Serialize(rmReq), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName + "." + vald.RemoveRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) %v", apiName, s.name, s.ip, targets), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err +} + +func (s *server) insert(ctx context.Context, client vald.InsertClient, req *payload.Insert_Request, opts ...grpc.CallOption) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "insert"), apiName+"/insert") + defer func() { + if span != nil { + span.End() + } + }() + + loc, err = client.Insert(ctx, req, opts...) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.InsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.InsertRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.InsertRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.InsertRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.InsertRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return loc, nil +} + +func (s *server) StreamInsert(stream vald.Insert_StreamInsertServer) (err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.InsertRPCServiceName+"/"+vald.StreamInsertRPCName), apiName+"/"+vald.StreamInsertRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err = grpc.BidirectionalStream(ctx, stream, s.streamConcurrency, + func(ctx context.Context, req *payload.Insert_Request) (*payload.Object_StreamLocation, error) { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BidirectionalStream"), apiName+"/"+vald.StreamInsertRPCName+"/id-"+req.GetVector().GetId()) + defer func() { + if sspan != nil { + sspan.End() + } + }() + res, err := s.Insert(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.InsertRPCName+" gRPC error response") + if sspan != nil { + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return &payload.Object_StreamLocation{ + Payload: &payload.Object_StreamLocation_Status{ + Status: st.Proto(), + }, + }, err + } + return &payload.Object_StreamLocation{ + Payload: &payload.Object_StreamLocation_Location{ + Location: res, + }, + }, nil + }, + ) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.StreamInsertRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + return nil +} + +func (s *server) MultiInsert(ctx context.Context, reqs *payload.Insert_MultiRequest) (res *payload.Object_Locations, errs error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.InsertRPCServiceName+"/"+vald.MultiInsertRPCName), apiName+"/"+vald.MultiInsertRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + res = &payload.Object_Locations{ + Locations: make([]*payload.Object_Location, len(reqs.GetRequests())), + } + + var mu, emu sync.Mutex + var wg sync.WaitGroup + + for i, r := range reqs.GetRequests() { + idx, req := i, r + wg.Add(1) + s.eg.Go(safety.RecoverFunc(func() error { + defer wg.Done() + ti := "errgroup.Go/id-" + req.GetVector().GetId() + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, ti), apiName+"/"+vald.MultiInsertRPCName+"/"+ti) + defer func() { + if span != nil { + span.End() + } + }() + + loc, err := s.Insert(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.InsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + emu.Lock() + if errs != nil { + errs = errors.Join(errs, err) + } else { + errs = err + } + emu.Unlock() + return nil + } + mu.Lock() + res.Locations[idx] = loc + mu.Unlock() + return nil + })) + } + wg.Wait() + if errs != nil { + st, msg, err := status.ParseError(errs, codes.Internal, "failed to parse "+vald.MultiInsertRPCName+" gRPC error response", + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.MultiInsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return res, err + } + return res, nil +} + +func (s *server) Update(ctx context.Context, req *payload.Update_Request) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.UpdateRPCServiceName+"/"+vald.UpdateRPCName), apiName+"/"+vald.UpdateRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + reqSrcPodName := s.gateway.FromForwardedContext(ctx) + + // When this condition is matched, the request is proxied to another Mirror gateway. + // So this component sends requests only to the Vald gateway (LB gateway) of its own cluster. + if len(reqSrcPodName) != 0 { + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + loc, err = vc.Update(ctx, req, copts...) + if err != nil { + return nil, err + } + return loc, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.UpdateRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.UpdateRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.UpdateRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.UpdateRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Debugf("Update API succeeded to %#v", loc) + return loc, nil + } + + objReq := &payload.Object_VectorRequest{ + Id: &payload.Object_ID{ + Id: req.GetVector().GetId(), + }, + } + oldVecs, err := s.getObjects(ctx, objReq) + if err != nil { + return nil, err + } + + var mu sync.Mutex + var result sync.Map + loc = &payload.Object_Location{ + Uuid: req.GetVector().GetId(), + Ips: make([]string, 0), + } + + err = s.gateway.BroadCast(ctx, + func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.UpdateRPCName+"/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + ce, err := s.update(ctx, vc, req, copts...) + if err != nil { + st, _, _ := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" API error response for "+target, + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.AlreadyExists { + // NOTE: If it is strictly necessary to check, fix this logic. + return nil + } + } + if err == nil && ce != nil { + mu.Lock() + loc.Name = ce.GetName() + loc.Ips = append(loc.Ips, ce.GetIps()...) + mu.Unlock() + } + result.Store(target, err) + return err + }, + ) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + ".BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + err = status.WrapWithInternal( + vald.UpdateRPCName+" API connection not found", err, reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + // There is no possibility to reach this part, but we add error handling just in case. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" gRPC error response", reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + var errs error + targets := make([]string, 0, 10) + result.Range(func(target, err any) bool { + if err == nil { + targets = append(targets, target.(string)) + } else { + if err, ok := err.(error); ok && err != nil { + if errs != nil { + errs = errors.Join(errs, err) + } else { + errs = err + } + } + } + return true + }) + switch { + case errs == nil: + log.Debugf("Update API mirror request succeeded to %#v", loc) + return loc, nil + case len(targets) == 0 && errs != nil: + log.Error("failed to Update API mirror request: %v and can not rollback because success target length is 0", errs) + st, msg, err := status.ParseError(errs, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Error("failed to Update API mirror request: %v, so starts the rollback request", errs) + + var emu sync.Mutex + var rerrs error + rmReq := &payload.Remove_Request{ + Id: &payload.Object_ID{ + Id: req.GetVector().GetId(), + }, + } + + err = s.gateway.DoMulti(ctx, targets, + func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "rollback/BroadCast/"+target), apiName+"/"+vald.RemoveRPCName+"/rollback/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + oldVec, ok := oldVecs.Load(target) + if !ok || oldVec == nil { + _, err := s.remove(ctx, vc, rmReq, copts...) + if err != nil { + st, _, _ := status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" for "+vald.UpdateRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: rmReq.GetId().GetId(), + ServingData: errdetails.Serialize(rmReq), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + "." + vald.RemoveRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.NotFound { + return nil + } + emu.Lock() + if rerrs != nil { + rerrs = errors.Join(rerrs, err) + } else { + rerrs = err + } + emu.Unlock() + return err + } + return nil + } + + req := &payload.Update_Request{ + Vector: oldVec.(*payload.Object_Vector), + Config: &payload.Update_Config{ + SkipStrictExistCheck: true, + }, + } + _, err := s.update(ctx, vc, req, copts...) + if err != nil { + st, _, _ := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" for "+vald.UpdateRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + "." + vald.UpdateRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.AlreadyExists { + return nil + } + emu.Lock() + if rerrs != nil { + rerrs = errors.Join(rerrs, err) + } else { + rerrs = err + } + emu.Unlock() + return err + } + return nil + }, + ) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + ".Rollback.BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + err = status.WrapWithInternal( + vald.UpdateRPCName+" for Rollback connection not found", err, reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + // There is no possibility to reach this part, but we add error handling just in case. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" for Rollback gRPC error response", reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + if rerrs == nil { + log.Debugf("rollback for Update API mirror request succeeded to %v", targets) + st, msg, err := status.ParseError(errs, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Debugf("failed to rollback for Update API mirror request succeeded to %v: %v", targets, rerrs) + st, msg, err := status.ParseError(rerrs, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" for Rollback gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName + ".Rollback", + ResourceName: fmt.Sprintf("%s: %s(%s) %v", apiName, s.name, s.ip, targets), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err +} + +func (s *server) update(ctx context.Context, client vald.UpdateClient, req *payload.Update_Request, opts ...grpc.CallOption) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "update"), apiName+"/update") + defer func() { + if span != nil { + span.End() + } + }() + + loc, err = client.Update(ctx, req, opts...) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpdateRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.UpdateRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.UpdateRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.UpdateRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return loc, nil +} + +func (s *server) StreamUpdate(stream vald.Update_StreamUpdateServer) (err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.UpdateRPCServiceName+"/"+vald.StreamUpdateRPCName), apiName+"/"+vald.StreamUpdateRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err = grpc.BidirectionalStream(ctx, stream, s.streamConcurrency, + func(ctx context.Context, req *payload.Update_Request) (*payload.Object_StreamLocation, error) { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BidirectionalStream"), apiName+"/"+vald.StreamUpdateRPCName+"/id-"+req.GetVector().GetId()) + defer func() { + if sspan != nil { + sspan.End() + } + }() + res, err := s.Update(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.UpdateRPCName+" gRPC error response") + if sspan != nil { + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return &payload.Object_StreamLocation{ + Payload: &payload.Object_StreamLocation_Status{ + Status: st.Proto(), + }, + }, err + } + return &payload.Object_StreamLocation{ + Payload: &payload.Object_StreamLocation_Location{ + Location: res, + }, + }, nil + }, + ) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.StreamUpdateRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + return nil +} + +func (s *server) MultiUpdate(ctx context.Context, reqs *payload.Update_MultiRequest) (res *payload.Object_Locations, errs error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.UpdateRPCServiceName+"/"+vald.MultiUpdateRPCName), apiName+"/"+vald.MultiUpdateRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + res = &payload.Object_Locations{ + Locations: make([]*payload.Object_Location, len(reqs.GetRequests())), + } + + var mu, emu sync.Mutex + var wg sync.WaitGroup + + for i, r := range reqs.GetRequests() { + idx, req := i, r + wg.Add(1) + s.eg.Go(safety.RecoverFunc(func() error { + defer wg.Done() + ti := "errgroup.Go/id-" + req.GetVector().GetId() + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, ti), apiName+"/"+vald.MultiUpdateRPCName+"/"+ti) + defer func() { + if span != nil { + span.End() + } + }() + + loc, err := s.Update(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.UpdateRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + emu.Lock() + if errs != nil { + errs = errors.Join(errs, err) + } else { + errs = err + } + emu.Unlock() + return nil + } + mu.Lock() + res.Locations[idx] = loc + mu.Unlock() + return nil + })) + } + wg.Wait() + if errs != nil { + st, msg, err := status.ParseError(errs, codes.Internal, "failed to parse "+vald.MultiUpdateRPCName+" gRPC error response", + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.MultiUpdateRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return res, err + } + return res, nil +} + +func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.UpsertRPCServiceName+"/"+vald.UpsertRPCName), apiName+"/"+vald.UpsertRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + reqSrcPodName := s.gateway.FromForwardedContext(ctx) + + // When this condition is matched, the request is proxied to another Mirror gateway. + // So this component sends requests only to the Vald gateway (LB gateway) of its own cluster. + if len(reqSrcPodName) != 0 { + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + loc, err = vc.Upsert(ctx, req, copts...) + if err != nil { + return nil, err + } + return loc, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.UpsertRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.UpsertRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.UpsertRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.UpsertRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Debugf("Upsert API succeeded to %#v", loc) + return loc, nil + } + + objReq := &payload.Object_VectorRequest{ + Id: &payload.Object_ID{ + Id: req.GetVector().GetId(), + }, + } + oldVecs, err := s.getObjects(ctx, objReq) + if err != nil { + return nil, err + } + + var mu sync.Mutex + var result sync.Map + loc = &payload.Object_Location{ + Uuid: req.GetVector().GetId(), + Ips: make([]string, 0), + } + err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.UpsertRPCName+"/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + ce, err := s.upsert(ctx, vc, req, copts...) + if err != nil { + st, _, _ := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.AlreadyExists { + // NOTE: If it is strictly necessary to check, fix this logic. + return nil + } + } + if err == nil && ce != nil { + mu.Lock() + loc.Name = ce.GetName() + loc.Ips = append(loc.Ips, ce.GetIps()...) + mu.Unlock() + } + result.Store(target, err) + return err + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + ".BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + err = status.WrapWithInternal( + vald.UpsertRPCName+" API connection not found", err, reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + // There is no possibility to reach this part, but we add error handling just in case. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" gRPC error response", reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + var errs error + targets := make([]string, 0, 10) + result.Range(func(target, err any) bool { + if err == nil { + targets = append(targets, target.(string)) + } else { + if err, ok := err.(error); ok && err != nil { + if errs != nil { + errs = errors.Join(errs, err) + } else { + errs = err + } + } + } + return true + }) + switch { + case errs == nil: + log.Debugf("Upsert API mirror request succeeded to %#v", loc) + return loc, nil + case len(targets) == 0 && errs != nil: + log.Error("failed to Upsert API mirror request: %v and can not rollback because success target length is 0", errs) + st, msg, err := status.ParseError(errs, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Error("failed to Upsert API mirror request: %v, so starts the rollback request", errs) + + var emu sync.Mutex + var rerrs error + rmReq := &payload.Remove_Request{ + Id: &payload.Object_ID{ + Id: req.GetVector().GetId(), + }, + } + err = s.gateway.DoMulti(ctx, targets, + func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "rollback/BroadCast/"+target), apiName+"/"+vald.UpsertRPCName+"/rollback/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + oldVec, ok := oldVecs.Load(target) + if !ok || oldVec == nil { + _, err := s.remove(ctx, vc, rmReq, copts...) + if err != nil { + st, _, _ := status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" for "+vald.UpsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: rmReq.GetId().GetId(), + ServingData: errdetails.Serialize(rmReq), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + "." + vald.RemoveRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.NotFound { + return nil + } + emu.Lock() + if rerrs != nil { + rerrs = errors.Join(rerrs, err) + } else { + rerrs = err + } + emu.Unlock() + return err + } + return nil + } + + req := &payload.Update_Request{ + Vector: oldVec.(*payload.Object_Vector), + Config: &payload.Update_Config{ + SkipStrictExistCheck: true, + }, + } + _, err := s.update(ctx, vc, req, copts...) + if err != nil { + st, _, _ := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpdateRPCName+" for "+vald.UpsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + "." + vald.UpdateRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.AlreadyExists { + return nil + } + emu.Lock() + if rerrs != nil { + rerrs = errors.Join(rerrs, err) + } else { + rerrs = err + } + emu.Unlock() + return err + } + return nil + }, + ) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + ".Rollback.BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + err = status.WrapWithInternal( + vald.UpsertRPCName+" for Rollback connection not found", err, reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + // There is no possibility to reach this part, but we add error handling just in case. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" for Rollback gRPC error response", reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + if rerrs == nil { + log.Debugf("rollback for Upsert API mirror request succeeded to %v", targets) + st, msg, err := status.ParseError(errs, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Debugf("failed to rollback for Upsert API mirror request succeeded to %v: %v", targets, rerrs) + st, msg, err := status.ParseError(rerrs, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" for Rollback gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName + ".Rollback", + ResourceName: fmt.Sprintf("%s: %s(%s) %v", apiName, s.name, s.ip, targets), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err +} + +func (s *server) upsert(ctx context.Context, client vald.UpsertClient, req *payload.Upsert_Request, opts ...grpc.CallOption) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "upsert"), apiName+"/upsert") + defer func() { + if span != nil { + span.End() + } + }() + + loc, err = client.Upsert(ctx, req, opts...) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.UpsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.UpsertRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.UpsertRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.UpsertRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn("failed to process Upsert request\terror: %s", err.Error()) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return loc, nil +} + +func (s *server) StreamUpsert(stream vald.Upsert_StreamUpsertServer) (err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.UpsertRPCServiceName+"/"+vald.StreamUpsertRPCName), apiName+"/"+vald.StreamUpsertRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err = grpc.BidirectionalStream(ctx, stream, s.streamConcurrency, + func(ctx context.Context, req *payload.Upsert_Request) (*payload.Object_StreamLocation, error) { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BidirectionalStream"), apiName+"/"+vald.StreamUpsertRPCName+"/id-"+req.GetVector().GetId()) + defer func() { + if sspan != nil { + sspan.End() + } + }() + res, err := s.Upsert(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.UpsertRPCName+" gRPC error response") + if sspan != nil { + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return &payload.Object_StreamLocation{ + Payload: &payload.Object_StreamLocation_Status{ + Status: st.Proto(), + }, + }, err + } + return &payload.Object_StreamLocation{ + Payload: &payload.Object_StreamLocation_Location{ + Location: res, + }, + }, nil + }, + ) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.StreamUpsertRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + return nil +} + +func (s *server) MultiUpsert(ctx context.Context, reqs *payload.Upsert_MultiRequest) (res *payload.Object_Locations, errs error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.UpsertRPCServiceName+"/"+vald.MultiUpsertRPCName), apiName+"/"+vald.MultiUpsertRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + res = &payload.Object_Locations{ + Locations: make([]*payload.Object_Location, len(reqs.GetRequests())), + } + + var mu, emu sync.Mutex + var wg sync.WaitGroup + + for i, r := range reqs.GetRequests() { + idx, req := i, r + wg.Add(1) + s.eg.Go(safety.RecoverFunc(func() error { + defer wg.Done() + ti := "errgroup.Go/id-" + req.GetVector().GetId() + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, ti), apiName+"/"+vald.MultiUpsertRPCName+"/"+ti) + defer func() { + if span != nil { + span.End() + } + }() + + loc, err := s.Upsert(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.UpsertRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + emu.Lock() + if errs != nil { + errs = errors.Join(errs, err) + } else { + errs = err + } + emu.Unlock() + return nil + } + mu.Lock() + res.Locations[idx] = loc + mu.Unlock() + return nil + })) + } + wg.Wait() + if errs != nil { + st, msg, err := status.ParseError(errs, codes.Internal, "failed to parse "+vald.MultiUpsertRPCName+" gRPC error response", + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.MultiUpsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return res, err + } + return res, nil +} + +func (s *server) Remove(ctx context.Context, req *payload.Remove_Request) (loc *payload.Object_Location, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.RemoveRPCServiceName+"/"+vald.RemoveRPCName), apiName+"/"+vald.RemoveRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + reqSrcPodName := s.gateway.FromForwardedContext(ctx) + + // When this condition is matched, the request is proxied to another Mirror gateway. + // So this component sends the request only to the Vald gateway (LB gateway) of own cluster. + if len(reqSrcPodName) != 0 { + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + loc, err = vc.Remove(ctx, req, copts...) + if err != nil { + return nil, err + } + return loc, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.RemoveRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.RemoveRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.RemoveRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.RemoveRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Debugf("Remove API remove succeeded to %#v", loc) + return loc, nil + } + + objReq := &payload.Object_VectorRequest{ + Id: &payload.Object_ID{ + Id: req.GetId().GetId(), + }, + } + oldVecs, err := s.getObjects(ctx, objReq) + if err != nil { + return nil, err + } + + var mu sync.Mutex + var result sync.Map + loc = &payload.Object_Location{ + Uuid: req.GetId().GetId(), + Ips: make([]string, 0), + } + + err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.RemoveRPCName+"/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + ce, err := s.remove(ctx, vc, req, copts...) + if err != nil { + st, _, _ := status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response for "+target, + &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.NotFound { + // NOTE: If it is strictly necessary to check, fix this logic. + return nil + } + } + if err == nil && ce != nil { + mu.Lock() + loc.Name = ce.GetName() + loc.Ips = append(loc.Ips, ce.GetIps()...) + mu.Unlock() + } + result.Store(target, err) + return err + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + ".BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + err = status.WrapWithInternal( + vald.RemoveRPCName+" API connection not found", err, reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + // There is no possibility to reach this part, but we add error handling just in case. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + var errs error + targets := make([]string, 0, 10) + result.Range(func(target, err any) bool { + if err == nil { + targets = append(targets, target.(string)) + } else { + if err, ok := err.(error); ok && err != nil { + if errs != nil { + errs = errors.Join(errs, err) + } else { + errs = err + } + } + } + return true + }) + switch { + case errs == nil: + log.Debugf("Remove API mirror request succeeded to %#v", loc) + return loc, nil + case len(targets) == 0 && errs != nil: + log.Error("failed to Remove API mirror request: %v and can not rollback because success target length is 0", errs) + st, msg, err := status.ParseError(errs, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Error("failed to Remove API mirror request: %v, so starts the rollback request", errs) + + var emu sync.Mutex + var rerrs error + err = s.gateway.DoMulti(ctx, targets, + func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "rollback/BroadCast/"+target), apiName+"/"+vald.RemoveRPCName+"/rollback/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + objv, ok := oldVecs.Load(target) + if !ok || objv == nil { + log.Debug("failed to load old vector from %s", target) + return nil + } + req := &payload.Upsert_Request{ + Vector: objv.(*payload.Object_Vector), + Config: &payload.Upsert_Config{ + SkipStrictExistCheck: true, + }, + } + _, err := s.upsert(ctx, vc, req, copts...) + if err != nil { + st, _, _ := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" for "+vald.RemoveRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + "." + vald.UpsertRPCName + ".BroadCast/" + target, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + }, + ) + if st.Code() == codes.AlreadyExists { + return nil + } + emu.Lock() + if rerrs != nil { + rerrs = errors.Join(rerrs, err) + } else { + rerrs = err + } + emu.Unlock() + return err + } + return nil + }, + ) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + "." + vald.UpsertRPCName + ".BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + err = status.WrapWithInternal( + vald.UpsertRPCName+" for "+vald.RemoveRPCName+" API connection not found", err, reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + // There is no possibility to reach this part, but we add error handling just in case. + st, msg, err := status.ParseError(err, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" for "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + if rerrs == nil { + log.Debugf("rollback for Remove API mirror request succeeded to %v", targets) + st, msg, err := status.ParseError(errs, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Debugf("failed to rollback for Remove API mirror request succeeded to %v: %v", targets, rerrs) + st, msg, err := status.ParseError(rerrs, codes.Internal, + "failed to parse "+vald.UpsertRPCName+" for "+vald.RemoveRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName + "." + vald.UpsertRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) %v", apiName, s.name, s.ip, targets), + }, + ) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err +} + +func (s *server) remove(ctx context.Context, client vald.RemoveClient, req *payload.Remove_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "remove"), apiName+"/remove") + defer func() { + if span != nil { + span.End() + } + }() + + loc, err := client.Remove(ctx, req, opts...) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RemoveRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.RemoveRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.RemoveRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.RemoveRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.RemoveRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return loc, nil +} + +func (s *server) StreamRemove(stream vald.Remove_StreamRemoveServer) (err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.RemoveRPCServiceName+"/"+vald.StreamRemoveRPCName), apiName+"/"+vald.StreamRemoveRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err = grpc.BidirectionalStream(ctx, stream, s.streamConcurrency, + func(ctx context.Context, req *payload.Remove_Request) (*payload.Object_StreamLocation, error) { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BidirectionalStream"), apiName+"/"+vald.StreamRemoveRPCName+"/id-"+req.GetId().GetId()) + defer func() { + if sspan != nil { + sspan.End() + } + }() + res, err := s.Remove(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.RemoveRPCName+" gRPC error response") + if sspan != nil { + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return &payload.Object_StreamLocation{ + Payload: &payload.Object_StreamLocation_Status{ + Status: st.Proto(), + }, + }, err + } + return &payload.Object_StreamLocation{ + Payload: &payload.Object_StreamLocation_Location{ + Location: res, + }, + }, nil + }, + ) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.StreamRemoveRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + return nil +} + +func (s *server) MultiRemove(ctx context.Context, reqs *payload.Remove_MultiRequest) (res *payload.Object_Locations, errs error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.RemoveRPCServiceName+"/"+vald.MultiRemoveRPCName), apiName+"/"+vald.MultiRemoveRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + res = &payload.Object_Locations{ + Locations: make([]*payload.Object_Location, len(reqs.GetRequests())), + } + + var mu, emu sync.Mutex + var wg sync.WaitGroup + + for i, r := range reqs.GetRequests() { + idx, req := i, r + wg.Add(1) + s.eg.Go(safety.RecoverFunc(func() error { + defer wg.Done() + ti := "errgroup.Go/id-" + req.GetId().GetId() + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, ti), apiName+"/"+vald.MultiRemoveRPCName+"/"+ti) + defer func() { + if span != nil { + span.End() + } + }() + + loc, err := s.Remove(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.RemoveRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + }) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + emu.Lock() + if errs != nil { + errs = errors.Join(errs, err) + } else { + errs = err + } + emu.Unlock() + return nil + } + mu.Lock() + res.Locations[idx] = loc + mu.Unlock() + return nil + })) + } + wg.Wait() + if errs != nil { + st, msg, err := status.ParseError(errs, codes.Internal, "failed to parse "+vald.MultiRemoveRPCName+" gRPC error response", + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.MultiRemoveRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return res, err + } + return res, nil +} + +func (s *server) GetObject(ctx context.Context, req *payload.Object_VectorRequest) (vec *payload.Object_Vector, err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.ObjectRPCServiceName+"/"+vald.GetObjectRPCName), apiName+"/"+vald.GetObjectRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + _, err = s.gateway.Do(ctx, s.vAddr, func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error) { + vec, err = vc.GetObject(ctx, req, copts...) + if err != nil { + return nil, err + } + return vec, nil + }) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.GetObjectRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, s.vAddr), + } + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.GetObjectRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.GetObjectRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.GetObjectRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.GetObjectRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.GetObjectRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return vec, nil +} + +func (s *server) getObjects(ctx context.Context, req *payload.Object_VectorRequest) (vecs *sync.Map, err error) { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "getObjects"), apiName+"/"+vald.GetObjectRPCName+"/getObjects") + defer func() { + if span != nil { + span.End() + } + }() + + var errs error + var emu sync.Mutex + vecs = new(sync.Map) + err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.GetObjectRPCName+"/getObjects/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + vec, err := vc.GetObject(ctx, req, copts...) + if err != nil { + reqInfo := &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.GetObjectRPCName, + ResourceName: fmt.Sprintf("%s: %s(%s) to %s", apiName, s.name, s.ip, target), + } + var attrs trace.Attributes + var code codes.Code + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.GetObjectRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + code = codes.Canceled + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.GetObjectRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + code = codes.DeadlineExceeded + case errors.Is(err, errors.ErrTargetNotFound): + err = status.WrapWithInternal( + vald.GetObjectRPCName+" API target not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + code = codes.Internal + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.GetObjectRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + code = codes.Internal + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.GetObjectRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + code = st.Code() + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + if code == codes.NotFound { + return nil + } + emu.Lock() + if errs == nil { + errs = err + } else { + errs = errors.Join(errs, err) + } + emu.Unlock() + return err + } + vecs.Store(target, vec) + return nil + }) + if err != nil { + if errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + err = status.WrapWithInternal( + vald.GetObjectRPCName+" API connection not found", err, + &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.GetObjectRPCName + ".BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + errs = errors.Join(errs, err) + } + if errs != nil { + st, msg, err := status.ParseError(errs, codes.Internal, + "failed to parse "+vald.GetObjectRPCName+" gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetId().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.GetObjectRPCName + "." + "BroadCast", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + ) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return vecs, nil +} + +func (s *server) StreamGetObject(stream vald.Object_StreamGetObjectServer) (err error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.ObjectRPCServiceName+"/"+vald.StreamGetObjectRPCName), apiName+"/"+vald.StreamGetObjectRPCName) + defer func() { + if span != nil { + span.End() + } + }() + err = grpc.BidirectionalStream(ctx, stream, s.streamConcurrency, + func(ctx context.Context, req *payload.Object_VectorRequest) (*payload.Object_StreamVector, error) { + ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BidirectionalStream"), apiName+"/"+vald.StreamInsertRPCName+"/id-"+req.GetId().GetId()) + defer func() { + if sspan != nil { + sspan.End() + } + }() + res, err := s.GetObject(ctx, req) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.GetObjectRPCName+" gRPC error response") + if sspan != nil { + sspan.RecordError(err) + sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + sspan.SetStatus(trace.StatusError, err.Error()) + } + return &payload.Object_StreamVector{ + Payload: &payload.Object_StreamVector_Status{ + Status: st.Proto(), + }, + }, err + } + return &payload.Object_StreamVector{ + Payload: &payload.Object_StreamVector_Vector{ + Vector: res, + }, + }, nil + }, + ) + if err != nil { + st, msg, err := status.ParseError(err, codes.Internal, "failed to parse "+vald.StreamGetObjectRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + return nil +} diff --git a/pkg/gateway/mirror/handler/grpc/handler_test.go b/pkg/gateway/mirror/handler/grpc/handler_test.go new file mode 100644 index 0000000000..6127d02463 --- /dev/null +++ b/pkg/gateway/mirror/handler/grpc/handler_test.go @@ -0,0 +1,5519 @@ +package grpc + +import ( + "context" + "reflect" + "sync" + "sync/atomic" + "testing" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/net/grpc" + "github.com/vdaas/vald/internal/net/grpc/codes" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/internal/test/data/vector" + "github.com/vdaas/vald/internal/test/goleak" + "github.com/vdaas/vald/pkg/gateway/mirror/service" +) + +func Test_server_Insert(t *testing.T) { + const dimension = 128 + defaultInsertConfig := &payload.Insert_Config{ + SkipStrictExistCheck: true, + } + type args struct { + ctx context.Context + req *payload.Insert_Request + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantCe *payload.Object_Location + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Location, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotCe *payload.Object_Location, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotCe, w.wantCe) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotCe, w.wantCe) + } + return nil + } + tests := []test{ + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + } + wantLoc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1", "127.0.0.1"}, + } + return test{ + name: "success insert with new ID", + args: args{ + ctx: egctx, + req: &payload.Insert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultInsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + }, + }, + want: want{ + wantCe: wantLoc, + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + } + return test{ + name: "fail insert with new ID but remove rollback success", + args: args{ + ctx: egctx, + req: &payload.Insert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultInsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + "vald-lb-gateway-01": &mockClient{ + InsertFunc: func(_ context.Context, _ *payload.Insert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + }, + }, + } + return test{ + name: "fail insert with new ID and fail remove rollback", + args: args{ + ctx: egctx, + req: &payload.Insert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultInsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotCe, err := s.Insert(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotCe, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_Update(t *testing.T) { + const dimension = 128 + defaultUpdateConfig := &payload.Update_Config{ + SkipStrictExistCheck: true, + } + type args struct { + ctx context.Context + req *payload.Update_Request + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantLoc *payload.Object_Location + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Location, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotLoc, w.wantLoc) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) + } + return nil + } + tests := []test{ + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + } + wantLoc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1", "127.0.0.1"}, + } + return test{ + name: "success update with new ID", + args: args{ + ctx: egctx, + req: &payload.Update_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpdateConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + }, + }, + want: want{ + wantLoc: wantLoc, + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + } + return test{ + name: "fail update with new ID but remove rollback success", + args: args{ + ctx: egctx, + req: &payload.Update_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpdateConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + ovec := &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return ovec, nil + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + } + return test{ + name: "fail update with new ID but update rollback success", + args: args{ + ctx: egctx, + req: &payload.Update_Request{ + Vector: ovec, + Config: defaultUpdateConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(ctx context.Context, in *payload.Object_VectorRequest, opts ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + }, + }, + } + return test{ + name: "fail update with new ID and fail remove rollback", + args: args{ + ctx: egctx, + req: &payload.Update_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpdateConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + ovec := &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + } + var cnt uint32 + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return ovec, nil + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + if atomic.AddUint32(&cnt, 1) == 1 { + return loc, nil + } + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + }, + }, + } + return test{ + name: "fail update with new ID and fail update rollback", + args: args{ + ctx: egctx, + req: &payload.Update_Request{ + Vector: ovec, + Config: defaultUpdateConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotLoc, err := s.Update(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotLoc, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_Upsert(t *testing.T) { + const dimension = 128 + defaultUpsertConfig := &payload.Upsert_Config{ + SkipStrictExistCheck: true, + } + type args struct { + ctx context.Context + req *payload.Upsert_Request + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantLoc *payload.Object_Location + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Location, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotLoc, w.wantLoc) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) + } + return nil + } + tests := []test{ + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + } + wantLoc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1", "127.0.0.1"}, + } + return test{ + name: "success upsert with new ID", + args: args{ + ctx: egctx, + req: &payload.Upsert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + }, + }, + want: want{ + wantLoc: wantLoc, + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpsertFunc: func(ctx context.Context, in *payload.Upsert_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + } + return test{ + name: "fail upsert with new ID but remove rollback success", + args: args{ + ctx: egctx, + req: &payload.Upsert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + ovec := &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return ovec, nil + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + } + return test{ + name: "fail upsert with new ID but update rollback success", + args: args{ + ctx: egctx, + req: &payload.Upsert_Request{ + Vector: ovec, + Config: defaultUpsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(ctx context.Context, in *payload.Object_VectorRequest, opts ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpsertFunc: func(ctx context.Context, in *payload.Upsert_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + }, + }, + } + return test{ + name: "fail upsert with new ID and fail remove rollback", + args: args{ + ctx: egctx, + req: &payload.Upsert_Request{ + Vector: &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + }, + Config: defaultUpsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + ovec := &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return ovec, nil + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + UpdateFunc: func(_ context.Context, _ *payload.Update_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + }, + }, + } + return test{ + name: "fail upsert with new ID and fail update rollback", + args: args{ + ctx: egctx, + req: &payload.Upsert_Request{ + Vector: ovec, + Config: defaultUpsertConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotLoc, err := s.Upsert(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotLoc, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_Remove(t *testing.T) { + const dimension = 128 + defaultRemoveConfig := &payload.Remove_Config{ + SkipStrictExistCheck: true, + } + type args struct { + ctx context.Context + req *payload.Remove_Request + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantLoc *payload.Object_Location + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Location, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotLoc, w.wantLoc) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) + } + return nil + } + tests := []test{ + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + ovec := &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return ovec, nil + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + }, + } + wantLoc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + return test{ + name: "success remove with existing ID", + args: args{ + ctx: egctx, + req: &payload.Remove_Request{ + Id: &payload.Object_ID{ + Id: uuid, + }, + Config: defaultRemoveConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + }, + }, + want: want{ + wantLoc: wantLoc, + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + ovec := &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return ovec, nil + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + } + return test{ + name: "fail remove with existing ID but upsert rollback success", + args: args{ + ctx: egctx, + req: &payload.Remove_Request{ + Id: &payload.Object_ID{ + Id: uuid, + }, + Config: defaultRemoveConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(_ context.Context, _ string, _ vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, egctx := errgroup.New(ctx) + + uuid := "test" + loc := &payload.Object_Location{ + Uuid: uuid, + Ips: []string{"127.0.0.1"}, + } + ovec := &payload.Object_Vector{ + Id: uuid, + Vector: vector.GaussianDistributedFloat32VectorGenerator(1, dimension)[0], + } + cmap := map[string]vald.ClientWithMirror{ + "vald-mirror-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return ovec, nil + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, nil + }, + UpsertFunc: func(_ context.Context, _ *payload.Upsert_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return nil, status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()) + }, + }, + "vald-lb-gateway-01": &mockClient{ + GetObjectFunc: func(_ context.Context, _ *payload.Object_VectorRequest, _ ...grpc.CallOption) (*payload.Object_Vector, error) { + return nil, status.Error(codes.NotFound, errors.ErrObjectIDNotFound(uuid).Error()) + }, + RemoveFunc: func(_ context.Context, _ *payload.Remove_Request, _ ...grpc.CallOption) (*payload.Object_Location, error) { + return loc, status.Error(codes.Internal, errors.ErrCircuitBreakerOpenState.Error()) + }, + }, + } + return test{ + name: "fail remove with existing ID and fail upsert rollback", + args: args{ + ctx: egctx, + req: &payload.Remove_Request{ + Id: &payload.Object_ID{ + Id: uuid, + }, + Config: defaultRemoveConfig, + }, + }, + fields: fields{ + eg: eg, + gateway: &mockGateway{ + FromForwardedContextFunc: func(_ context.Context) string { + return "" + }, + BroadCastFunc: func(ctx context.Context, f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error { + for tgt, c := range cmap { + f(ctx, tgt, c) + } + return nil + }, + DoMultiFunc: func(ctx context.Context, targets []string, f func(ctx context.Context, target string, vc vald.ClientWithMirror, _ ...grpc.CallOption) error) error { + if len(targets) != 1 { + return errors.New("invalid target") + } + if c, ok := cmap[targets[0]]; ok { + f(ctx, targets[0], c) + } + return nil + }, + }, + }, + want: want{ + err: status.Error(codes.Internal, errors.ErrCircuitBreakerHalfOpenFlowLimitation.Error()), + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + cancel() + }, + } + }(), + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotLoc, err := s.Remove(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotLoc, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +// NOT IMPLEMENTED BELOW + +func TestNew(t *testing.T) { + type args struct { + opts []Option + } + type want struct { + want vald.ServerWithMirror + err error + } + type test struct { + name string + args args + want want + checkFunc func(want, vald.Server, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, got vald.Server, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(got, w.want) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + opts:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + opts:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + + got, err := New(test.args.opts...) + if err := checkFunc(test.want, got, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_Register(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Mirror_Targets + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + want *payload.Mirror_Targets + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Mirror_Targets, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, got *payload.Mirror_Targets, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(got, w.want) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + got, err := s.Register(test.args.ctx, test.args.req) + if err := checkFunc(test.want, got, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_Advertise(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Mirror_Targets + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Mirror_Targets + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Mirror_Targets, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Mirror_Targets, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.Advertise(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_Exists(t *testing.T) { + type args struct { + ctx context.Context + meta *payload.Object_ID + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantId *payload.Object_ID + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_ID, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotId *payload.Object_ID, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotId, w.wantId) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotId, w.wantId) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + meta:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + meta:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotId, err := s.Exists(test.args.ctx, test.args.meta) + if err := checkFunc(test.want, gotId, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_Search(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Search_Request + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Search_Response + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Search_Response, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.Search(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_SearchByID(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Search_IDRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Search_Response + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Search_Response, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.SearchByID(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_StreamSearch(t *testing.T) { + type args struct { + stream vald.Search_StreamSearchServer + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + err := s.StreamSearch(test.args.stream) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_StreamSearchByID(t *testing.T) { + type args struct { + stream vald.Search_StreamSearchByIDServer + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + err := s.StreamSearchByID(test.args.stream) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_MultiSearch(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Search_MultiRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Search_Responses + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Search_Responses, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.MultiSearch(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_MultiSearchByID(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Search_MultiIDRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Search_Responses + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Search_Responses, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.MultiSearchByID(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_LinearSearch(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Search_Request + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Search_Response + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Search_Response, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.LinearSearch(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_LinearSearchByID(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Search_IDRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Search_Response + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Search_Response, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Search_Response, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.LinearSearchByID(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_StreamLinearSearch(t *testing.T) { + type args struct { + stream vald.Search_StreamLinearSearchServer + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + err := s.StreamLinearSearch(test.args.stream) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_StreamLinearSearchByID(t *testing.T) { + type args struct { + stream vald.Search_StreamLinearSearchByIDServer + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + err := s.StreamLinearSearchByID(test.args.stream) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_MultiLinearSearch(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Search_MultiRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Search_Responses + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Search_Responses, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.MultiLinearSearch(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_MultiLinearSearchByID(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Search_MultiIDRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Search_Responses + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Search_Responses, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Search_Responses, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.MultiLinearSearchByID(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_insert(t *testing.T) { + type args struct { + ctx context.Context + client vald.InsertClient + req *payload.Insert_Request + opts []grpc.CallOption + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantLoc *payload.Object_Location + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Location, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotLoc, w.wantLoc) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + client:nil, + req:nil, + opts:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + client:nil, + req:nil, + opts:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotLoc, err := s.insert(test.args.ctx, test.args.client, test.args.req, test.args.opts...) + if err := checkFunc(test.want, gotLoc, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_StreamInsert(t *testing.T) { + type args struct { + stream vald.Insert_StreamInsertServer + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + err := s.StreamInsert(test.args.stream) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_MultiInsert(t *testing.T) { + type args struct { + ctx context.Context + reqs *payload.Insert_MultiRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Object_Locations + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Locations, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + reqs:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + reqs:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.MultiInsert(test.args.ctx, test.args.reqs) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_update(t *testing.T) { + type args struct { + ctx context.Context + client vald.UpdateClient + req *payload.Update_Request + opts []grpc.CallOption + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantLoc *payload.Object_Location + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Location, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotLoc, w.wantLoc) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + client:nil, + req:nil, + opts:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + client:nil, + req:nil, + opts:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotLoc, err := s.update(test.args.ctx, test.args.client, test.args.req, test.args.opts...) + if err := checkFunc(test.want, gotLoc, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_StreamUpdate(t *testing.T) { + type args struct { + stream vald.Update_StreamUpdateServer + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + err := s.StreamUpdate(test.args.stream) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_MultiUpdate(t *testing.T) { + type args struct { + ctx context.Context + reqs *payload.Update_MultiRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Object_Locations + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Locations, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + reqs:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + reqs:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.MultiUpdate(test.args.ctx, test.args.reqs) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_upsert(t *testing.T) { + type args struct { + ctx context.Context + client vald.UpsertClient + req *payload.Upsert_Request + opts []grpc.CallOption + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantLoc *payload.Object_Location + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Location, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotLoc *payload.Object_Location, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotLoc, w.wantLoc) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotLoc, w.wantLoc) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + client:nil, + req:nil, + opts:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + client:nil, + req:nil, + opts:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotLoc, err := s.upsert(test.args.ctx, test.args.client, test.args.req, test.args.opts...) + if err := checkFunc(test.want, gotLoc, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_StreamUpsert(t *testing.T) { + type args struct { + stream vald.Upsert_StreamUpsertServer + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + err := s.StreamUpsert(test.args.stream) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_MultiUpsert(t *testing.T) { + type args struct { + ctx context.Context + reqs *payload.Upsert_MultiRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Object_Locations + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Locations, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + reqs:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + reqs:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.MultiUpsert(test.args.ctx, test.args.reqs) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_remove(t *testing.T) { + type args struct { + ctx context.Context + client vald.RemoveClient + req *payload.Remove_Request + opts []grpc.CallOption + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + want *payload.Object_Location + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Location, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, got *payload.Object_Location, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(got, w.want) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + client:nil, + req:nil, + opts:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + client:nil, + req:nil, + opts:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + got, err := s.remove(test.args.ctx, test.args.client, test.args.req, test.args.opts...) + if err := checkFunc(test.want, got, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_StreamRemove(t *testing.T) { + type args struct { + stream vald.Remove_StreamRemoveServer + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + err := s.StreamRemove(test.args.stream) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_MultiRemove(t *testing.T) { + type args struct { + ctx context.Context + reqs *payload.Remove_MultiRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantRes *payload.Object_Locations + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Locations, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotRes *payload.Object_Locations, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotRes, w.wantRes) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotRes, w.wantRes) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + reqs:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + reqs:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotRes, err := s.MultiRemove(test.args.ctx, test.args.reqs) + if err := checkFunc(test.want, gotRes, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_GetObject(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Object_VectorRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantVec *payload.Object_Vector + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *payload.Object_Vector, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotVec *payload.Object_Vector, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotVec, w.wantVec) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotVec, w.wantVec) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotVec, err := s.GetObject(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotVec, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_getObjects(t *testing.T) { + type args struct { + ctx context.Context + req *payload.Object_VectorRequest + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + wantVecs *sync.Map + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *sync.Map, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, gotVecs *sync.Map, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + if !reflect.DeepEqual(gotVecs, w.wantVecs) { + return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", gotVecs, w.wantVecs) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + ctx:nil, + req:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + gotVecs, err := s.getObjects(test.args.ctx, test.args.req) + if err := checkFunc(test.want, gotVecs, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_server_StreamGetObject(t *testing.T) { + type args struct { + stream vald.Object_StreamGetObjectServer + } + type fields struct { + eg errgroup.Group + gateway service.Gateway + mirror service.Mirror + vAddr string + streamConcurrency int + name string + ip string + UnimplementedValdServerWithMirror vald.UnimplementedValdServerWithMirror + } + type want struct { + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + // TODO test cases + /* + { + name: "test_case_1", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + }, + */ + + // TODO test cases + /* + func() test { + return test { + name: "test_case_2", + args: args { + stream:nil, + }, + fields: fields { + eg:nil, + gateway:nil, + mirror:nil, + vAddr:"", + streamConcurrency:0, + name:"", + ip:"", + UnimplementedValdServerWithMirror:nil, + }, + want: want{}, + checkFunc: defaultCheckFunc, + beforeFunc: func(t *testing.T, args args) { + t.Helper() + }, + afterFunc: func(t *testing.T, args args) { + t.Helper() + }, + } + }(), + */ + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + s := &server{ + eg: test.fields.eg, + gateway: test.fields.gateway, + mirror: test.fields.mirror, + vAddr: test.fields.vAddr, + streamConcurrency: test.fields.streamConcurrency, + name: test.fields.name, + ip: test.fields.ip, + UnimplementedValdServerWithMirror: test.fields.UnimplementedValdServerWithMirror, + } + + err := s.StreamGetObject(test.args.stream) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} diff --git a/pkg/gateway/mirror/handler/grpc/mock_test.go b/pkg/gateway/mirror/handler/grpc/mock_test.go new file mode 100644 index 0000000000..f62078105a --- /dev/null +++ b/pkg/gateway/mirror/handler/grpc/mock_test.go @@ -0,0 +1,215 @@ +package grpc + +import ( + "context" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/net/grpc" + "github.com/vdaas/vald/pkg/gateway/mirror/service" +) + +type mockClient struct { + InsertFunc func(ctx context.Context, in *payload.Insert_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) + StreamInsertFunc func(ctx context.Context, opts ...grpc.CallOption) (vald.Insert_StreamInsertClient, error) + MultiInsertFunc func(ctx context.Context, in *payload.Insert_MultiRequest, opts ...grpc.CallOption) (*payload.Object_Locations, error) + + UpdateFunc func(ctx context.Context, in *payload.Update_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) + StreamUpdateFunc func(ctx context.Context, opts ...grpc.CallOption) (vald.Update_StreamUpdateClient, error) + MultiUpdateFunc func(ctx context.Context, in *payload.Update_MultiRequest, opts ...grpc.CallOption) (*payload.Object_Locations, error) + + UpsertFunc func(ctx context.Context, in *payload.Upsert_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) + StreamUpsertFunc func(ctx context.Context, opts ...grpc.CallOption) (vald.Upsert_StreamUpsertClient, error) + MultiUpsertFunc func(ctx context.Context, in *payload.Upsert_MultiRequest, opts ...grpc.CallOption) (*payload.Object_Locations, error) + + SearchFunc func(ctx context.Context, in *payload.Search_Request, opts ...grpc.CallOption) (*payload.Search_Response, error) + SearchByIDFunc func(ctx context.Context, in *payload.Search_IDRequest, opts ...grpc.CallOption) (*payload.Search_Response, error) + StreamSearchFunc func(ctx context.Context, opts ...grpc.CallOption) (vald.Search_StreamSearchClient, error) + StreamSearchByIDFunc func(ctx context.Context, opts ...grpc.CallOption) (vald.Search_StreamSearchByIDClient, error) + MultiSearchFunc func(ctx context.Context, in *payload.Search_MultiRequest, opts ...grpc.CallOption) (*payload.Search_Responses, error) + MultiSearchByIDFunc func(ctx context.Context, in *payload.Search_MultiIDRequest, opts ...grpc.CallOption) (*payload.Search_Responses, error) + LinearSearchFunc func(ctx context.Context, in *payload.Search_Request, opts ...grpc.CallOption) (*payload.Search_Response, error) + LinearSearchByIDFunc func(ctx context.Context, in *payload.Search_IDRequest, opts ...grpc.CallOption) (*payload.Search_Response, error) + StreamLinearSearchFunc func(ctx context.Context, opts ...grpc.CallOption) (vald.Search_StreamLinearSearchClient, error) + StreamLinearSearchByIDFunc func(ctx context.Context, opts ...grpc.CallOption) (vald.Search_StreamLinearSearchByIDClient, error) + MultiLinearSearchFunc func(ctx context.Context, in *payload.Search_MultiRequest, opts ...grpc.CallOption) (*payload.Search_Responses, error) + MultiLinearSearchByIDFunc func(ctx context.Context, in *payload.Search_MultiIDRequest, opts ...grpc.CallOption) (*payload.Search_Responses, error) + + RemoveFunc func(ctx context.Context, in *payload.Remove_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) + StreamRemoveFunc func(ctx context.Context, opts ...grpc.CallOption) (vald.Remove_StreamRemoveClient, error) + MultiRemoveFunc func(ctx context.Context, in *payload.Remove_MultiRequest, opts ...grpc.CallOption) (*payload.Object_Locations, error) + + ExistsFunc func(ctx context.Context, in *payload.Object_ID, opts ...grpc.CallOption) (*payload.Object_ID, error) + GetObjectFunc func(ctx context.Context, in *payload.Object_VectorRequest, opts ...grpc.CallOption) (*payload.Object_Vector, error) + StreamGetObjectFunc func(ctx context.Context, opts ...grpc.CallOption) (vald.Object_StreamGetObjectClient, error) + + RegisterFunc func(ctx context.Context, in *payload.Mirror_Targets, opts ...grpc.CallOption) (*payload.Mirror_Targets, error) + AdvertiseFunc func(ctx context.Context, in *payload.Mirror_Targets, opts ...grpc.CallOption) (*payload.Mirror_Targets, error) +} + +func (m *mockClient) Insert(ctx context.Context, in *payload.Insert_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { + return m.InsertFunc(ctx, in, opts...) +} + +func (m *mockClient) StreamInsert(ctx context.Context, opts ...grpc.CallOption) (vald.Insert_StreamInsertClient, error) { + return m.StreamInsertFunc(ctx, opts...) +} + +func (m *mockClient) MultiInsert(ctx context.Context, in *payload.Insert_MultiRequest, opts ...grpc.CallOption) (*payload.Object_Locations, error) { + return m.MultiInsertFunc(ctx, in, opts...) +} + +func (m *mockClient) Update(ctx context.Context, in *payload.Update_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { + return m.UpdateFunc(ctx, in, opts...) +} + +func (m *mockClient) StreamUpdate(ctx context.Context, opts ...grpc.CallOption) (vald.Update_StreamUpdateClient, error) { + return m.StreamUpdateFunc(ctx, opts...) +} + +func (m *mockClient) MultiUpdate(ctx context.Context, in *payload.Update_MultiRequest, opts ...grpc.CallOption) (*payload.Object_Locations, error) { + return m.MultiUpdateFunc(ctx, in, opts...) +} + +func (m *mockClient) Upsert(ctx context.Context, in *payload.Upsert_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { + return m.UpsertFunc(ctx, in, opts...) +} + +func (m *mockClient) StreamUpsert(ctx context.Context, opts ...grpc.CallOption) (vald.Upsert_StreamUpsertClient, error) { + return m.StreamUpsertFunc(ctx, opts...) +} + +func (m *mockClient) MultiUpsert(ctx context.Context, in *payload.Upsert_MultiRequest, opts ...grpc.CallOption) (*payload.Object_Locations, error) { + return m.MultiUpsertFunc(ctx, in, opts...) +} + +func (m *mockClient) Search(ctx context.Context, in *payload.Search_Request, opts ...grpc.CallOption) (*payload.Search_Response, error) { + return m.SearchFunc(ctx, in, opts...) +} + +func (m *mockClient) SearchByID(ctx context.Context, in *payload.Search_IDRequest, opts ...grpc.CallOption) (*payload.Search_Response, error) { + return m.SearchByIDFunc(ctx, in, opts...) +} + +func (m *mockClient) StreamSearch(ctx context.Context, opts ...grpc.CallOption) (vald.Search_StreamSearchClient, error) { + return m.StreamSearchFunc(ctx, opts...) +} + +func (m *mockClient) StreamSearchByID(ctx context.Context, opts ...grpc.CallOption) (vald.Search_StreamSearchByIDClient, error) { + return m.StreamSearchByIDFunc(ctx, opts...) +} + +func (m *mockClient) MultiSearch(ctx context.Context, in *payload.Search_MultiRequest, opts ...grpc.CallOption) (*payload.Search_Responses, error) { + return m.MultiSearchFunc(ctx, in, opts...) +} + +func (m *mockClient) MultiSearchByID(ctx context.Context, in *payload.Search_MultiIDRequest, opts ...grpc.CallOption) (*payload.Search_Responses, error) { + return m.MultiSearchByIDFunc(ctx, in, opts...) +} + +func (m *mockClient) LinearSearch(ctx context.Context, in *payload.Search_Request, opts ...grpc.CallOption) (*payload.Search_Response, error) { + return m.LinearSearchFunc(ctx, in, opts...) +} + +func (m *mockClient) LinearSearchByID(ctx context.Context, in *payload.Search_IDRequest, opts ...grpc.CallOption) (*payload.Search_Response, error) { + return m.LinearSearchByIDFunc(ctx, in, opts...) +} + +func (m *mockClient) StreamLinearSearch(ctx context.Context, opts ...grpc.CallOption) (vald.Search_StreamLinearSearchClient, error) { + return m.StreamLinearSearchFunc(ctx, opts...) +} + +func (m *mockClient) StreamLinearSearchByID(ctx context.Context, opts ...grpc.CallOption) (vald.Search_StreamLinearSearchByIDClient, error) { + return m.StreamLinearSearchByIDFunc(ctx, opts...) +} + +func (m *mockClient) MultiLinearSearch(ctx context.Context, in *payload.Search_MultiRequest, opts ...grpc.CallOption) (*payload.Search_Responses, error) { + return m.MultiLinearSearchFunc(ctx, in, opts...) +} + +func (m *mockClient) MultiLinearSearchByID(ctx context.Context, in *payload.Search_MultiIDRequest, opts ...grpc.CallOption) (*payload.Search_Responses, error) { + return m.MultiLinearSearchByIDFunc(ctx, in, opts...) +} + +func (m *mockClient) Remove(ctx context.Context, in *payload.Remove_Request, opts ...grpc.CallOption) (*payload.Object_Location, error) { + return m.RemoveFunc(ctx, in, opts...) +} + +func (m *mockClient) StreamRemove(ctx context.Context, opts ...grpc.CallOption) (vald.Remove_StreamRemoveClient, error) { + return m.StreamRemoveFunc(ctx, opts...) +} + +func (m *mockClient) MultiRemove(ctx context.Context, in *payload.Remove_MultiRequest, opts ...grpc.CallOption) (*payload.Object_Locations, error) { + return m.MultiRemoveFunc(ctx, in, opts...) +} + +func (m *mockClient) Exists(ctx context.Context, in *payload.Object_ID, opts ...grpc.CallOption) (*payload.Object_ID, error) { + return m.ExistsFunc(ctx, in, opts...) +} + +func (m *mockClient) GetObject(ctx context.Context, in *payload.Object_VectorRequest, opts ...grpc.CallOption) (*payload.Object_Vector, error) { + return m.GetObjectFunc(ctx, in, opts...) +} + +func (m *mockClient) StreamGetObject(ctx context.Context, opts ...grpc.CallOption) (vald.Object_StreamGetObjectClient, error) { + return m.StreamGetObjectFunc(ctx, opts...) +} + +func (m *mockClient) Register(ctx context.Context, in *payload.Mirror_Targets, opts ...grpc.CallOption) (*payload.Mirror_Targets, error) { + return m.RegisterFunc(ctx, in) +} + +func (m *mockClient) Advertise(ctx context.Context, in *payload.Mirror_Targets, opts ...grpc.CallOption) (*payload.Mirror_Targets, error) { + return m.AdvertiseFunc(ctx, in) +} + +var _ vald.ClientWithMirror = (*mockClient)(nil) + +type mockGateway struct { + StartFunc func(ctx context.Context) (<-chan error, error) + ForwardedContextFunc func(ctx context.Context, podName string) context.Context + FromForwardedContextFunc func(ctx context.Context) string + BroadCastFunc func(ctx context.Context, + f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error + DoFunc func(ctx context.Context, target string, + f func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error)) (interface{}, error) + DoMultiFunc func(ctx context.Context, targets []string, + f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error + GRPCClientFunc func() grpc.Client +} + +func (m *mockGateway) Start(ctx context.Context) (<-chan error, error) { + return m.StartFunc(ctx) +} + +func (m *mockGateway) ForwardedContext(ctx context.Context, podName string) context.Context { + return m.ForwardedContextFunc(ctx, podName) +} + +func (m *mockGateway) FromForwardedContext(ctx context.Context) string { + return m.FromForwardedContextFunc(ctx) +} + +func (m *mockGateway) BroadCast(ctx context.Context, + f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error, +) error { + return m.BroadCastFunc(ctx, f) +} + +func (m *mockGateway) Do(ctx context.Context, target string, + f func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error), +) (interface{}, error) { + return m.DoFunc(ctx, target, f) +} + +func (m *mockGateway) DoMulti(ctx context.Context, targets []string, + f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error, +) error { + return m.DoMultiFunc(ctx, targets, f) +} + +func (m *mockGateway) GRPCClient() grpc.Client { + return m.GRPCClientFunc() +} + +var _ service.Gateway = (*mockGateway)(nil) diff --git a/pkg/gateway/mirror/handler/grpc/option.go b/pkg/gateway/mirror/handler/grpc/option.go new file mode 100644 index 0000000000..09f0996e17 --- /dev/null +++ b/pkg/gateway/mirror/handler/grpc/option.go @@ -0,0 +1,110 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package grpc provides grpc server logic +package grpc + +import ( + "os" + "runtime" + + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net" + "github.com/vdaas/vald/pkg/gateway/mirror/service" +) + +type Option func(*server) error + +var defaultOptions = []Option{ + WithErrGroup(errgroup.Get()), + WithStreamConcurrency(runtime.GOMAXPROCS(-1) * 10), + WithName(func() string { + name, err := os.Hostname() + if err != nil { + log.Warn(err) + } + return name + }()), + WithIP(net.LoadLocalIP()), +} + +// WithIP returns the option to set the IP for server. +func WithIP(ip string) Option { + return func(s *server) error { + if len(ip) != 0 { + s.ip = ip + } + return nil + } +} + +// WithName returns the option to set the name for server. +func WithName(name string) Option { + return func(s *server) error { + if len(name) != 0 { + s.name = name + } + return nil + } +} + +func WithGateway(g service.Gateway) Option { + return func(s *server) error { + if g != nil { + s.gateway = g + } + return nil + } +} + +func WithMirror(m service.Mirror) Option { + return func(s *server) error { + if m != nil { + s.mirror = m + } + return nil + } +} + +func WithErrGroup(eg errgroup.Group) Option { + return func(s *server) error { + if eg != nil { + s.eg = eg + } + return nil + } +} + +func WithStreamConcurrency(c int) Option { + return func(s *server) error { + if c > 0 { + s.streamConcurrency = c + } + return nil + } +} + +func WithValdAddr(addr string) Option { + return func(s *server) error { + if len(addr) == 0 { + return errors.NewErrCriticalOption("valdAddr", addr) + } + s.vAddr = addr + return nil + } +} diff --git a/pkg/gateway/mirror/handler/rest/handler.go b/pkg/gateway/mirror/handler/rest/handler.go new file mode 100644 index 0000000000..55a856e928 --- /dev/null +++ b/pkg/gateway/mirror/handler/rest/handler.go @@ -0,0 +1,211 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package rest provides rest api logic +package rest + +import ( + "net/http" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/net/http/dump" + "github.com/vdaas/vald/internal/net/http/json" +) + +type Handler interface { + Register(w http.ResponseWriter, r *http.Request) (int, error) + Advertise(w http.ResponseWriter, r *http.Request) (int, error) + Index(w http.ResponseWriter, r *http.Request) (int, error) + Exists(w http.ResponseWriter, r *http.Request) (int, error) + Search(w http.ResponseWriter, r *http.Request) (int, error) + SearchByID(w http.ResponseWriter, r *http.Request) (int, error) + MultiSearch(w http.ResponseWriter, r *http.Request) (int, error) + MultiSearchByID(w http.ResponseWriter, r *http.Request) (int, error) + LinearSearch(w http.ResponseWriter, r *http.Request) (int, error) + LinearSearchByID(w http.ResponseWriter, r *http.Request) (int, error) + MultiLinearSearch(w http.ResponseWriter, r *http.Request) (int, error) + MultiLinearSearchByID(w http.ResponseWriter, r *http.Request) (int, error) + Insert(w http.ResponseWriter, r *http.Request) (int, error) + MultiInsert(w http.ResponseWriter, r *http.Request) (int, error) + Update(w http.ResponseWriter, r *http.Request) (int, error) + MultiUpdate(w http.ResponseWriter, r *http.Request) (int, error) + Upsert(w http.ResponseWriter, r *http.Request) (int, error) + MultiUpsert(w http.ResponseWriter, r *http.Request) (int, error) + Remove(w http.ResponseWriter, r *http.Request) (int, error) + MultiRemove(w http.ResponseWriter, r *http.Request) (int, error) + GetObject(w http.ResponseWriter, r *http.Request) (int, error) +} + +type handler struct { + vald vald.ServerWithMirror +} + +func New(opts ...Option) Handler { + h := new(handler) + + for _, opt := range append(defaultOptions, opts...) { + opt(h) + } + return h +} + +func (h *handler) Register(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Mirror_Targets + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.Register(r.Context(), req) + }) +} + +func (h *handler) Advertise(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Mirror_Targets + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.Advertise(r.Context(), req) + }) +} + +func (h *handler) Index(w http.ResponseWriter, r *http.Request) (int, error) { + data := make(map[string]interface{}) + return json.Handler(w, r, &data, func() (interface{}, error) { + return dump.Request(nil, data, r) + }) +} + +func (h *handler) Search(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.Search(r.Context(), req) + }) +} + +func (h *handler) SearchByID(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_IDRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.SearchByID(r.Context(), req) + }) +} + +func (h *handler) MultiSearch(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_MultiRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.MultiSearch(r.Context(), req) + }) +} + +func (h *handler) MultiSearchByID(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_MultiIDRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.MultiSearchByID(r.Context(), req) + }) +} + +func (h *handler) LinearSearch(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.LinearSearch(r.Context(), req) + }) +} + +func (h *handler) LinearSearchByID(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_IDRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.LinearSearchByID(r.Context(), req) + }) +} + +func (h *handler) MultiLinearSearch(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_MultiRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.MultiLinearSearch(r.Context(), req) + }) +} + +func (h *handler) MultiLinearSearchByID(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_MultiIDRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.MultiLinearSearchByID(r.Context(), req) + }) +} + +func (h *handler) Insert(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Insert_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.Insert(r.Context(), req) + }) +} + +func (h *handler) MultiInsert(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Insert_MultiRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.MultiInsert(r.Context(), req) + }) +} + +func (h *handler) Update(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Update_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.Update(r.Context(), req) + }) +} + +func (h *handler) MultiUpdate(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Update_MultiRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.MultiUpdate(r.Context(), req) + }) +} + +func (h *handler) Upsert(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Upsert_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.Upsert(r.Context(), req) + }) +} + +func (h *handler) MultiUpsert(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Upsert_MultiRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.MultiUpsert(r.Context(), req) + }) +} + +func (h *handler) Remove(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Remove_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.Remove(r.Context(), req) + }) +} + +func (h *handler) MultiRemove(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Remove_MultiRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.MultiRemove(r.Context(), req) + }) +} + +func (h *handler) GetObject(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Object_VectorRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.GetObject(r.Context(), req) + }) +} + +func (h *handler) Exists(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Object_ID + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.vald.Exists(r.Context(), req) + }) +} diff --git a/pkg/gateway/mirror/handler/rest/option.go b/pkg/gateway/mirror/handler/rest/option.go new file mode 100644 index 0000000000..109a418fe3 --- /dev/null +++ b/pkg/gateway/mirror/handler/rest/option.go @@ -0,0 +1,32 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package rest provides rest api logic +package rest + +import ( + "github.com/vdaas/vald/apis/grpc/v1/vald" +) + +type Option func(*handler) + +var defaultOptions = []Option{} + +func WithVald(v vald.ServerWithMirror) Option { + return func(h *handler) { + h.vald = v + } +} diff --git a/pkg/gateway/mirror/router/option.go b/pkg/gateway/mirror/router/option.go new file mode 100644 index 0000000000..d27a01fa5a --- /dev/null +++ b/pkg/gateway/mirror/router/option.go @@ -0,0 +1,40 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package router provides implementation of Go API for routing http Handler wrapped by rest.Func +package router + +import ( + "github.com/vdaas/vald/pkg/gateway/mirror/handler/rest" +) + +type Option func(*router) + +var defaultOptions = []Option{ + WithTimeout("3s"), +} + +func WithHandler(h rest.Handler) Option { + return func(r *router) { + r.handler = h + } +} + +func WithTimeout(timeout string) Option { + return func(r *router) { + r.timeout = timeout + } +} diff --git a/pkg/gateway/mirror/router/router.go b/pkg/gateway/mirror/router/router.go new file mode 100644 index 0000000000..b2b0d06a24 --- /dev/null +++ b/pkg/gateway/mirror/router/router.go @@ -0,0 +1,183 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package router provides implementation of Go API for routing http Handler wrapped by rest.Func +package router + +import ( + "net/http" + + "github.com/vdaas/vald/internal/net/http/routing" + "github.com/vdaas/vald/pkg/gateway/mirror/handler/rest" +) + +type router struct { + handler rest.Handler + timeout string +} + +// New returns REST route&method information from handler interface. +func New(opts ...Option) http.Handler { + r := new(router) + + for _, opt := range append(defaultOptions, opts...) { + opt(r) + } + + h := r.handler + + return routing.New( + routing.WithRoutes([]routing.Route{ + { + "Index", + []string{ + http.MethodGet, + }, + "/", + h.Index, + }, + { + "Register", + []string{ + http.MethodPost, + }, + "/register", + h.Register, + }, + { + "Advertise", + []string{ + http.MethodPost, + }, + "/register", + h.Advertise, + }, + { + "Search", + []string{ + http.MethodPost, + }, + "/search", + h.Search, + }, + { + "Search By ID", + []string{ + http.MethodGet, + }, + "/search/{id}", + h.SearchByID, + }, + + { + "Multi Search", + []string{ + http.MethodPost, + }, + "/search/multi", + h.MultiSearch, + }, + { + "Multi Search By ID", + []string{ + http.MethodGet, + }, + "/search/multi/{id}", + h.MultiSearchByID, + }, + { + "Insert", + []string{ + http.MethodPost, + }, + "/insert", + h.Insert, + }, + { + "Multiple Insert", + []string{ + http.MethodPost, + }, + "/insert/multi", + h.MultiInsert, + }, + { + "Update", + []string{ + http.MethodPost, + http.MethodPatch, + http.MethodPut, + }, + "/update", + h.Update, + }, + { + "Multiple Update", + []string{ + http.MethodPost, + http.MethodPatch, + http.MethodPut, + }, + "/update/multi", + h.MultiUpdate, + }, + { + "Upsert", + []string{ + http.MethodPost, + http.MethodPatch, + http.MethodPut, + }, + "/upsert", + h.Upsert, + }, + { + "Multiple Upsert", + []string{ + http.MethodPost, + http.MethodPatch, + http.MethodPut, + }, + "/upsert/multi", + h.MultiUpsert, + }, + { + "Remove", + []string{ + http.MethodDelete, + }, + "/delete/{id}", + h.Remove, + }, + { + "Multiple Remove", + []string{ + http.MethodDelete, + http.MethodPost, + }, + "/delete/multi", + h.MultiRemove, + }, + { + "GetObject", + []string{ + http.MethodGet, + }, + "/object/{id}", + h.GetObject, + }, + }...)) +} diff --git a/pkg/gateway/mirror/service/doc.go b/pkg/gateway/mirror/service/doc.go new file mode 100644 index 0000000000..266389f73d --- /dev/null +++ b/pkg/gateway/mirror/service/doc.go @@ -0,0 +1,18 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package service manages the main logic of server. +package service diff --git a/pkg/gateway/mirror/service/gateway.go b/pkg/gateway/mirror/service/gateway.go new file mode 100644 index 0000000000..40690d4dc0 --- /dev/null +++ b/pkg/gateway/mirror/service/gateway.go @@ -0,0 +1,172 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package service +package service + +import ( + "context" + "reflect" + + "github.com/vdaas/vald/apis/grpc/v1/vald" + client "github.com/vdaas/vald/internal/client/v1/client/mirror" + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net/grpc" + "github.com/vdaas/vald/internal/observability/trace" +) + +const ( + forwardedContextKey = "forwarded-for" + forwardedContextValue = "gateway mirror" +) + +type Gateway interface { + ForwardedContext(ctx context.Context, podName string) context.Context + FromForwardedContext(ctx context.Context) string + BroadCast(ctx context.Context, + f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error + Do(ctx context.Context, target string, + f func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error)) (interface{}, error) + DoMulti(ctx context.Context, targets []string, + f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error) error + GRPCClient() grpc.Client +} + +type gateway struct { + client client.Client // Mirror Gateway client for other clusters and to the Vald gateway (LB gateway) client for own cluster. + eg errgroup.Group + podName string +} + +func NewGateway(opts ...Option) (Gateway, error) { + g := new(gateway) + for _, opt := range append(defaultGWOpts, opts...) { + if err := opt(g); err != nil { + oerr := errors.ErrOptionFailed(err, reflect.ValueOf(opt)) + e := &errors.ErrCriticalOption{} + if errors.As(err, &e) { + log.Error(oerr) + return nil, oerr + } + log.Warn(oerr) + return nil, oerr + } + } + return g, nil +} + +func (g *gateway) GRPCClient() grpc.Client { + return g.client.GRPCClient() +} + +func (g *gateway) ForwardedContext(ctx context.Context, podName string) context.Context { + return grpc.NewOutgoingContext(ctx, grpc.MD{ + forwardedContextKey: []string{ + podName, + }, + }) +} + +func (g *gateway) FromForwardedContext(ctx context.Context) string { + md, ok := grpc.FromIncomingContext(ctx) + if !ok { + return "" + } + vals, ok := md[forwardedContextKey] + if !ok { + return "" + } + if len(vals) > 0 { + return vals[0] + } + return "" +} + +func (g *gateway) BroadCast(ctx context.Context, + f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error, +) (err error) { + ctx, span := trace.StartSpan(ctx, "vald/gateway/mirror/service/Gateway.BroadCast") + defer func() { + if span != nil { + span.End() + } + }() + return g.client.GRPCClient().RangeConcurrent(g.ForwardedContext(ctx, g.podName), -1, func(ictx context.Context, + addr string, conn *grpc.ClientConn, copts ...grpc.CallOption, + ) (err error) { + select { + case <-ictx.Done(): + return nil + default: + err = f(ictx, addr, vald.NewValdClientWithMirror(conn), copts...) + if err != nil { + return err + } + } + return nil + }) +} + +func (g *gateway) Do(ctx context.Context, target string, + f func(ctx context.Context, vc vald.ClientWithMirror, copts ...grpc.CallOption) (interface{}, error), +) (res interface{}, err error) { + ctx, span := trace.StartSpan(ctx, "vald/gateway/mirror/service/Gateway.Do") + defer func() { + if span != nil { + span.End() + } + }() + + if len(target) == 0 { + return nil, errors.ErrTargetNotFound + } + return g.client.GRPCClient().Do(g.ForwardedContext(ctx, g.podName), target, + func(ictx context.Context, conn *grpc.ClientConn, copts ...grpc.CallOption) (interface{}, error) { + return f(ictx, vald.NewValdClientWithMirror(conn), copts...) + }, + ) +} + +func (g *gateway) DoMulti(ctx context.Context, targets []string, + f func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error, +) error { + ctx, span := trace.StartSpan(ctx, "vald/gateway/mirror/service/Gateway.DoMulti") + defer func() { + if span != nil { + span.End() + } + }() + + if len(targets) == 0 { + return errors.ErrTargetNotFound + } + return g.client.GRPCClient().OrderedRangeConcurrent(g.ForwardedContext(ctx, g.podName), targets, -1, + func(ictx context.Context, addr string, conn *grpc.ClientConn, copts ...grpc.CallOption) (err error) { + select { + case <-ictx.Done(): + return nil + default: + err = f(ictx, addr, vald.NewValdClientWithMirror(conn), copts...) + if err != nil { + return err + } + } + return nil + }, + ) +} diff --git a/pkg/gateway/mirror/service/mirror.go b/pkg/gateway/mirror/service/mirror.go new file mode 100644 index 0000000000..31d68340e8 --- /dev/null +++ b/pkg/gateway/mirror/service/mirror.go @@ -0,0 +1,432 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package service +package service + +import ( + "context" + "reflect" + "sync" + "time" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net" + "github.com/vdaas/vald/internal/net/grpc" + "github.com/vdaas/vald/internal/net/grpc/codes" + "github.com/vdaas/vald/internal/net/grpc/errdetails" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/internal/observability/trace" +) + +// Mirror manages other mirror gateway connection. +// If there is a new Mirror Gateway components, registers new connection. +type Mirror interface { + Start(ctx context.Context) (<-chan error, error) + Connect(ctx context.Context, targets ...*payload.Mirror_Target) error + MirrorTargets() ([]*payload.Mirror_Target, error) +} + +type mirr struct { + addrl sync.Map // List of all connected addresses + selfMirrAddrs []string // Address of self mirror gateway + selfMirrAddrl sync.Map // List of self Mirror gateway addresses + gwAddrs []string // Address of Vald Gateway (LB gateway) + gwAddrl sync.Map // List of Vald Gateway addresses + gateway Gateway + eg errgroup.Group + advertiseDur time.Duration +} + +func NewMirror(opts ...MirrorOption) (Mirror, error) { + m := new(mirr) + for _, opt := range append(defaultMirrOpts, opts...) { + if err := opt(m); err != nil { + oerr := errors.ErrOptionFailed(err, reflect.ValueOf(opt)) + e := &errors.ErrCriticalOption{} + if errors.As(err, &e) { + log.Error(oerr) + return nil, oerr + } + log.Warn(oerr) + } + } + for _, addr := range m.selfMirrAddrs { + m.selfMirrAddrl.Store(addr, struct{}{}) + } + for _, addr := range m.gwAddrs { + m.gwAddrl.Store(addr, struct{}{}) + } + return m, nil +} + +func (m *mirr) Start(ctx context.Context) (<-chan error, error) { + ech := make(chan error, 100) + + aech, err := m.startAdvertise(ctx) + if err != nil { + close(ech) + return nil, err + } + + m.eg.Go(func() (err error) { + defer close(ech) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case err = <-aech: + } + if err != nil { + select { + case <-ctx.Done(): + case ech <- err: + } + err = nil + } + } + }) + return ech, nil +} + +func (m *mirr) startAdvertise(ctx context.Context) (<-chan error, error) { + ctx, span := trace.StartSpan(ctx, "vald/gateway/mirror/service/Mirror.startAdvertise") + defer func() { + if span != nil { + span.End() + } + }() + ech := make(chan error, 100) + + tgts, err := m.toMirrorTargets(m.selfMirrAddrs...) + if err != nil { + close(ech) + return nil, err + } + err = m.registers(ctx, tgts) + if err != nil && + !errors.Is(err, errors.ErrTargetNotFound) && + !errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { + var attrs trace.Attributes + + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.InsertRPCName+" API canceld", err, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithDeadlineExceeded( + vald.InsertRPCName+" API deadline exceeded", err, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, "failed to parse "+vald.RegisterRPCName+" gRPC error response") + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...) + span.SetStatus(trace.StatusError, err.Error()) + } + } + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + close(ech) + return nil, err + } + + m.eg.Go(func() (err error) { + tic := time.NewTicker(m.advertiseDur) + defer close(ech) + defer tic.Stop() + + for { + select { + case <-ctx.Done(): + return err + case <-tic.C: + tgts, err := m.toMirrorTargets(append(m.selfMirrAddrs, m.gateway.GRPCClient().ConnectedAddrs()...)...) + if err != nil || len(tgts.GetTargets()) == 0 { + if err == nil { + err = errors.ErrTargetNotFound + } + select { + case <-ctx.Done(): + return ctx.Err() + case ech <- err: + } + continue + } + resTgts, err := m.advertises(ctx, tgts) + if err != nil || len(resTgts) == 0 { + if err == nil { + err = errors.ErrTargetNotFound + } + select { + case <-ctx.Done(): + return ctx.Err() + case ech <- err: + } + continue + } + if err = m.Connect(ctx, resTgts...); err != nil { + select { + case <-ctx.Done(): + return ctx.Err() + case ech <- err: + } + } + log.Debugf("[mirror]: connected mirror gateway targets: %v", m.gateway.GRPCClient().ConnectedAddrs()) + } + } + }) + return ech, nil +} + +func (m *mirr) registers(ctx context.Context, tgts *payload.Mirror_Targets) error { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.MirrorRPCServiceName+"/"+vald.RegisterRPCName), "vald/gateway/mirror/service/Mirror.registers") + defer func() { + if span != nil { + span.End() + } + }() + reqInfo := &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(tgts), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.RegisterRPCName, + } + + return m.gateway.DoMulti(ctx, m.connectedMirrorAddrs(), func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(ctx, "vald/gateway/mirror/service/Mirror.registers/"+target) + defer func() { + if span != nil { + span.End() + } + }() + + _, err := vc.Register(ctx, tgts, copts...) + if err != nil { + var attrs trace.Attributes + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.RegisterRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithCanceled( + vald.RegisterRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.RegisterRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.RegisterRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Error("failed to send Register API to %s\t: %v", target, err) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + + return nil + }) +} + +func (m *mirr) advertises(ctx context.Context, tgts *payload.Mirror_Targets) ([]*payload.Mirror_Target, error) { + ctx, span := trace.StartSpan(grpc.WithGRPCMethod(ctx, vald.PackageName+"."+vald.MirrorRPCServiceName+"/"+vald.AdvertiseRPCName), "vald/gateway/vald/service/Mirror.advertises") + defer func() { + if span != nil { + span.End() + } + }() + reqInfo := &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(tgts), + } + resInfo := &errdetails.ResourceInfo{ + ResourceType: errdetails.ValdGRPCResourceTypePrefix + "/vald.v1." + vald.AdvertiseRPCName, + } + resTgts := make([]*payload.Mirror_Target, 0, len(tgts.GetTargets())) + var mu sync.Mutex + + err := m.gateway.DoMulti(ctx, m.connectedMirrorAddrs(), func(ctx context.Context, target string, vc vald.ClientWithMirror, copts ...grpc.CallOption) error { + ctx, span := trace.StartSpan(ctx, "vald/gateway/mirror/service/Mirror.advertises/"+target) + defer func() { + if span != nil { + span.End() + } + }() + res, err := vc.Advertise(ctx, tgts) + if err != nil { + var attrs trace.Attributes + switch { + case errors.Is(err, context.Canceled): + err = status.WrapWithCanceled( + vald.AdvertiseRPCName+" API canceld", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeCancelled(err.Error()) + case errors.Is(err, context.DeadlineExceeded): + err = status.WrapWithCanceled( + vald.AdvertiseRPCName+" API deadline exceeded", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeDeadlineExceeded(err.Error()) + case errors.Is(err, errors.ErrGRPCClientConnNotFound("*")): + err = status.WrapWithInternal( + vald.AdvertiseRPCName+" API connection not found", err, reqInfo, resInfo, + ) + attrs = trace.StatusCodeInternal(err.Error()) + default: + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse "+vald.AdvertiseRPCName+" gRPC error response", reqInfo, resInfo, + ) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + log.Errorf("failed to process advertise requst to %s\terror: %s", target, err.Error()) + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return err + } + if res != nil && len(res.GetTargets()) > 0 { + mu.Lock() + resTgts = append(resTgts, res.GetTargets()...) + mu.Unlock() + } + return nil + }) + return resTgts, err +} + +func (m *mirr) Connect(ctx context.Context, targets ...*payload.Mirror_Target) error { + ctx, span := trace.StartSpan(ctx, "vald/gateway/mirror/service/Mirror.Connect") + defer func() { + if span != nil { + span.End() + } + }() + if len(targets) == 0 { + return errors.ErrTargetNotFound + } + for _, target := range targets { + addr := net.JoinHostPort(target.GetIp(), uint16(target.GetPort())) // addr: host:port + if !m.isSelfMirrorAddr(addr) && !m.isGatewayAddr(addr) { + _, ok := m.addrl.Load(addr) + if !ok || !m.gateway.GRPCClient().IsConnected(ctx, addr) { + _, err := m.gateway.GRPCClient().Connect(ctx, addr) + if err != nil { + m.addrl.Delete(addr) + return err + } + } + m.addrl.Store(addr, struct{}{}) + } + } + return nil +} + +func (m *mirr) MirrorTargets() ([]*payload.Mirror_Target, error) { + addrs := append(m.selfMirrAddrs, m.gateway.GRPCClient().ConnectedAddrs()...) + tgts := make([]*payload.Mirror_Target, 0, len(addrs)) + for _, addr := range addrs { + if !m.isGatewayAddr(addr) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + tgts = append(tgts, &payload.Mirror_Target{ + Ip: host, + Port: uint32(port), + }) + } + } + return tgts, nil +} + +func (m *mirr) isSelfMirrorAddr(addr string) bool { + if _, ok := m.selfMirrAddrl.Load(addr); ok { + return true + } + return false +} + +func (m *mirr) isGatewayAddr(addr string) bool { + if _, ok := m.gwAddrl.Load(addr); ok { + return true + } + return false +} + +func (m *mirr) connectedMirrorAddrs() []string { + connectedAddrs := m.gateway.GRPCClient().ConnectedAddrs() + addrs := make([]string, 0, len(connectedAddrs)) + for _, addr := range connectedAddrs { + if !m.isSelfMirrorAddr(addr) && + !m.isGatewayAddr(addr) { + addrs = append(addrs, addr) + } + } + return addrs +} + +func (m *mirr) toMirrorTargets(addrs ...string) (*payload.Mirror_Targets, error) { + tgts := make([]*payload.Mirror_Target, 0, len(addrs)) + for _, addr := range addrs { + if ok := m.isGatewayAddr(addr); !ok { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + tgts = append(tgts, &payload.Mirror_Target{ + Ip: host, + Port: uint32(port), + }) + } + } + return &payload.Mirror_Targets{ + Targets: tgts, + }, nil +} diff --git a/pkg/gateway/mirror/service/mirror_option.go b/pkg/gateway/mirror/service/mirror_option.go new file mode 100644 index 0000000000..e942322c7f --- /dev/null +++ b/pkg/gateway/mirror/service/mirror_option.go @@ -0,0 +1,72 @@ +package service + +import ( + "time" + + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" +) + +type MirrorOption func(m *mirr) error + +var defaultMirrOpts = []MirrorOption{ + WithAdvertiseInterval("1s"), +} + +func WithErrorGroup(eg errgroup.Group) MirrorOption { + return func(m *mirr) error { + if eg != nil { + m.eg = eg + } + return nil + } +} + +func WithValdAddrs(addrs ...string) MirrorOption { + return func(m *mirr) error { + if len(addrs) == 0 { + return errors.NewErrCriticalOption("lbAddrs", addrs) + } + if m.gwAddrs == nil { + m.gwAddrs = make([]string, 0, len(addrs)) + } + m.gwAddrs = append(m.gwAddrs, addrs...) + return nil + } +} + +func WithSelfMirrorAddrs(addrs ...string) MirrorOption { + return func(m *mirr) error { + if len(addrs) == 0 { + return errors.NewErrCriticalOption("selfMirrorAddrs", addrs) + } + if m.selfMirrAddrs == nil { + m.selfMirrAddrs = make([]string, 0, len(addrs)) + } + m.selfMirrAddrs = append(m.selfMirrAddrs, addrs...) + return nil + } +} + +func WithGateway(g Gateway) MirrorOption { + return func(m *mirr) error { + if g != nil { + m.gateway = g + } + return nil + } +} + +func WithAdvertiseInterval(s string) MirrorOption { + return func(m *mirr) error { + if len(s) == 0 { + return errors.NewErrInvalidOption("advertiseInterval", s) + } + dur, err := time.ParseDuration(s) + if err != nil { + return errors.NewErrInvalidOption("advertiseInterval", s, err) + } + m.advertiseDur = dur + return nil + } +} diff --git a/pkg/gateway/mirror/service/option.go b/pkg/gateway/mirror/service/option.go new file mode 100644 index 0000000000..6ee360463d --- /dev/null +++ b/pkg/gateway/mirror/service/option.go @@ -0,0 +1,58 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package service represents gateway's service logic +package service + +import ( + "github.com/vdaas/vald/internal/client/v1/client/mirror" + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" +) + +type Option func(g *gateway) error + +var defaultGWOpts = []Option{ + WithErrGroup(errgroup.Get()), +} + +func WithMirrorClient(c mirror.Client) Option { + return func(g *gateway) error { + if c != nil { + g.client = c + } + return nil + } +} + +func WithErrGroup(eg errgroup.Group) Option { + return func(g *gateway) error { + if eg != nil { + g.eg = eg + } + return nil + } +} + +func WithPodName(s string) Option { + return func(g *gateway) error { + if len(s) == 0 { + return errors.NewErrCriticalOption("podName", s) + } + g.podName = s + return nil + } +} diff --git a/pkg/gateway/mirror/usecase/vald.go b/pkg/gateway/mirror/usecase/vald.go new file mode 100644 index 0000000000..bde3b59b3c --- /dev/null +++ b/pkg/gateway/mirror/usecase/vald.go @@ -0,0 +1,220 @@ +// +// Copyright (C) 2019-2022 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package usecase represents gateways usecase layer +package usecase + +import ( + "context" + + "github.com/vdaas/vald/apis/grpc/v1/vald" + mclient "github.com/vdaas/vald/internal/client/v1/client/mirror" + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/net/grpc" + "github.com/vdaas/vald/internal/observability" + backoffmetrics "github.com/vdaas/vald/internal/observability/metrics/backoff" + cbmetrics "github.com/vdaas/vald/internal/observability/metrics/circuitbreaker" + "github.com/vdaas/vald/internal/runner" + "github.com/vdaas/vald/internal/safety" + "github.com/vdaas/vald/internal/servers/server" + "github.com/vdaas/vald/internal/servers/starter" + "github.com/vdaas/vald/pkg/gateway/mirror/config" + handler "github.com/vdaas/vald/pkg/gateway/mirror/handler/grpc" + "github.com/vdaas/vald/pkg/gateway/mirror/handler/rest" + "github.com/vdaas/vald/pkg/gateway/mirror/router" + "github.com/vdaas/vald/pkg/gateway/mirror/service" +) + +type run struct { + eg errgroup.Group + cfg *config.Data + server starter.Server + c mclient.Client + gw service.Gateway + mgw service.Mirror + observability observability.Observability +} + +func New(cfg *config.Data) (r runner.Runner, err error) { + eg := errgroup.Get() + + cOpts, err := cfg.Mirror.Client.Opts() + if err != nil { + return nil, err + } + cOpts = append(cOpts, grpc.WithErrGroup(eg)) + + c, err := mclient.New( + mclient.WithAddrs(cfg.Mirror.Client.Addrs...), + mclient.WithClient(grpc.New(cOpts...)), + ) + if err != nil { + return nil, err + } + + gw, err := service.NewGateway( + service.WithErrGroup(eg), + service.WithMirrorClient(c), + service.WithPodName(cfg.Mirror.PodName), + ) + if err != nil { + return nil, err + } + mgw, err := service.NewMirror( + service.WithErrorGroup(eg), + service.WithAdvertiseInterval(cfg.Mirror.AdvertiseInterval), + service.WithValdAddrs(cfg.Mirror.GatewayAddr), + service.WithSelfMirrorAddrs(cfg.Mirror.SelfMirrorAddr), + service.WithGateway(gw), + ) + if err != nil { + return nil, err + } + + v, err := handler.New( + handler.WithValdAddr(cfg.Mirror.GatewayAddr), + handler.WithErrGroup(eg), + handler.WithGateway(gw), + handler.WithMirror(mgw), + handler.WithStreamConcurrency(cfg.Server.GetGRPCStreamConcurrency()), + ) + if err != nil { + return nil, err + } + + grpcServerOptions := []server.Option{ + server.WithGRPCRegistFunc(func(srv *grpc.Server) { + vald.RegisterValdServerWithMirror(srv, v) + }), + server.WithPreStopFunction(func() error { + return nil + }), + } + + var obs observability.Observability + if cfg.Observability.Enabled { + obs, err = observability.NewWithConfig( + cfg.Observability, + backoffmetrics.New(), + cbmetrics.New(), + ) + if err != nil { + return nil, err + } + } + + srv, err := starter.New( + starter.WithConfig(cfg.Server), + starter.WithREST(func(sc *config.Server) []server.Option { + return []server.Option{ + server.WithHTTPHandler( + router.New( + router.WithHandler( + rest.New( + rest.WithVald(v), + ), + ), + ), + ), + } + }), + starter.WithGRPC(func(sc *config.Server) []server.Option { + return grpcServerOptions + }), + ) + if err != nil { + return nil, err + } + + return &run{ + eg: eg, + cfg: cfg, + server: srv, + c: c, + gw: gw, + mgw: mgw, + observability: obs, + }, nil +} + +func (r *run) PreStart(ctx context.Context) error { + if r.observability != nil { + return r.observability.PreStart(ctx) + } + return nil +} + +func (r *run) Start(ctx context.Context) (<-chan error, error) { + ech := make(chan error, 6) + var mech, cech, sech, oech <-chan error + var err error + + sech = r.server.ListenAndServe(ctx) + if r.c != nil { + cech, err = r.c.Start(ctx) + if err != nil { + close(ech) + return nil, err + } + } + if r.mgw != nil { + mech, err = r.mgw.Start(ctx) + if err != nil { + close(ech) + return nil, err + } + } + if r.observability != nil { + oech = r.observability.Start(ctx) + } + + r.eg.Go(safety.RecoverFunc(func() (err error) { + defer close(ech) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case err = <-mech: + case err = <-cech: + case err = <-sech: + case err = <-oech: + } + if err != nil { + select { + case <-ctx.Done(): + return ctx.Err() + case ech <- err: + } + } + } + })) + return ech, nil +} + +func (r *run) PreStop(ctx context.Context) error { + return nil +} + +func (r *run) Stop(ctx context.Context) error { + if r.observability != nil { + r.observability.Stop(ctx) + } + return r.server.Shutdown(ctx) +} + +func (*run) PostStop(_ context.Context) error { + return nil +}