diff --git a/Makefile b/Makefile index 7a07e11a2c..380a67c0c5 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,7 @@ VERSION ?= $(eval VERSION := $(shell cat versions/VALD_VERSION))$(VERSION) NGT_VERSION := $(eval NGT_VERSION := $(shell cat versions/NGT_VERSION))$(NGT_VERSION) NGT_REPO = github.com/yahoojapan/NGT +FAISS_VERSION := $(eval FAISS_VERSION := $(shell cat versions/FAISS_VERSION))$(FAISS_VERSION) GOPROXY=direct GO_VERSION := $(eval GO_VERSION := $(shell cat versions/GO_VERSION))$(GO_VERSION) @@ -508,6 +509,18 @@ ngt/install: /usr/local/include/NGT/Capi.h rm -rf $(TEMP_DIR)/NGT-$(NGT_VERSION) ldconfig +.PHONY: faiss/install +## install Faiss +faiss/install: /usr/local/lib/libfaiss.so +/usr/local/lib/libfaiss.so: + curl -LO https://github.com/facebookresearch/faiss/archive/v$(FAISS_VERSION).tar.gz + tar zxf v$(FAISS_VERSION).tar.gz -C $(TEMP_DIR)/ + cd $(TEMP_DIR)/faiss-$(FAISS_VERSION) && \ + cmake -DFAISS_ENABLE_GPU=OFF -DFAISS_ENABLE_PYTHON=OFF -DBUILD_TESTING=OFF -DBUILD_SHARED_LIBS=ON -B build . && \ + make -C build -j faiss && \ + make -C build install + ldconfig + .PHONY: lint ## run lints lint: vet diff --git a/Makefile.d/build.mk b/Makefile.d/build.mk index 81f4bac8d9..d34f91f89b 100644 --- a/Makefile.d/build.mk +++ b/Makefile.d/build.mk @@ -61,6 +61,42 @@ cmd/agent/core/ngt/ngt: \ $(dir $@)main.go $@ -version +cmd/agent/core/faiss/faiss: \ + faiss/install \ + $(GO_SOURCES_INTERNAL) \ + $(PBGOS) \ + $(shell find ./cmd/agent/core/faiss -type f -name '*.go' -not -name '*_test.go' -not -name 'doc.go') \ + $(shell find ./pkg/agent/core/faiss ./pkg/agent/core/ngt/service/kvs ./pkg/agent/core/ngt/service/vqueue ./pkg/agent/internal -type f -name '*.go' -not -name '*_test.go' -not -name 'doc.go') + CFLAGS="$(CFLAGS)" \ + CXXFLAGS="$(CXXFLAGS)" \ + CGO_ENABLED=1 \ + CGO_CXXFLAGS="-g -Ofast -march=native" \ + CGO_FFLAGS="-g -Ofast -march=native" \ + CGO_LDFLAGS="-g -Ofast -march=native" \ + GO111MODULE=on \ + GOPRIVATE=$(GOPRIVATE) \ + go build \ + --ldflags "-w -linkmode 'external' \ + -extldflags '-fPIC -pthread -fopenmp -std=gnu++20 -lstdc++ -lm -z relro -z now $(EXTLDFLAGS)' \ + -X '$(GOPKG)/internal/info.Version=$(VERSION)' \ + -X '$(GOPKG)/internal/info.GitCommit=$(GIT_COMMIT)' \ + -X '$(GOPKG)/internal/info.BuildTime=$(DATETIME)' \ + -X '$(GOPKG)/internal/info.GoVersion=$(GO_VERSION)' \ + -X '$(GOPKG)/internal/info.GoOS=$(GOOS)' \ + -X '$(GOPKG)/internal/info.GoArch=$(GOARCH)' \ + -X '$(GOPKG)/internal/info.CGOEnabled=$${CGO_ENABLED}' \ + -X '$(GOPKG)/internal/info.FaissVersion=$(FAISS_VERSION)' \ + -X '$(GOPKG)/internal/info.BuildCPUInfoFlags=$(CPU_INFO_FLAGS)' \ + -buildid=" \ + -mod=readonly \ + -modcacherw \ + -a \ + -tags "cgo osusergo netgo static_build" \ + -trimpath \ + -o $@ \ + $(dir $@)main.go + $@ -version + cmd/agent/sidecar/sidecar: \ $(GO_SOURCES_INTERNAL) \ $(PBGOS) \ diff --git a/Makefile.d/docker.mk b/Makefile.d/docker.mk index 0f4b14b786..0c1c10d261 100644 --- a/Makefile.d/docker.mk +++ b/Makefile.d/docker.mk @@ -52,6 +52,18 @@ docker/build/agent-ngt: --build-arg DISTROLESS_IMAGE_TAG=$(DISTROLESS_IMAGE_TAG) \ --build-arg MAINTAINER=$(MAINTAINER) +.PHONY: docker/build/agent-faiss +## build agent-faiss image +docker/build/agent-faiss: + $(DOCKER) build \ + $(DOCKER_OPTS) \ + -f dockers/agent/core/faiss/Dockerfile \ + -t $(ORG)/vald-agent-faiss:$(TAG) . 
\ + --build-arg GO_VERSION=$(GO_VERSION) \ + --build-arg DISTROLESS_IMAGE=$(DISTROLESS_IMAGE) \ + --build-arg DISTROLESS_IMAGE_TAG=$(DISTROLESS_IMAGE_TAG) \ + --build-arg MAINTAINER=$(MAINTAINER) + .PHONY: docker/name/agent-sidecar docker/name/agent-sidecar: @echo "$(ORG)/$(AGENT_SIDECAR_IMAGE)" diff --git a/Makefile.d/e2e.mk b/Makefile.d/e2e.mk index 669adb79b4..5c76000459 100644 --- a/Makefile.d/e2e.mk +++ b/Makefile.d/e2e.mk @@ -19,6 +19,16 @@ e2e: $(call run-e2e-crud-test,-run TestE2EStandardCRUD) +.PHONY: e2e/faiss +## run e2e/faiss +e2e/faiss: + #$(call run-e2e-crud-faiss-test,-run TestE2EInsertOnly) + #$(call run-e2e-crud-faiss-test,-run TestE2ESearchOnly) + #$(call run-e2e-crud-faiss-test,-run TestE2EUpdateOnly) + #$(call run-e2e-crud-faiss-test,-run TestE2ERemoveOnly) + #$(call run-e2e-crud-faiss-test,-run TestE2EInsertAndSearch) + $(call run-e2e-crud-faiss-test,-run TestE2EStandardCRUD) + .PHONY: e2e/multi ## run e2e multiple apis e2e/multi: diff --git a/Makefile.d/functions.mk b/Makefile.d/functions.mk index 0310625be1..6d8a83d4bd 100644 --- a/Makefile.d/functions.mk +++ b/Makefile.d/functions.mk @@ -113,6 +113,28 @@ define run-e2e-crud-test -namespace=$(E2E_TARGET_NAMESPACE) endef +define run-e2e-crud-faiss-test + go test \ + -race \ + -mod=readonly \ + $1 \ + -v $(ROOTDIR)/tests/e2e/crud/crud_faiss_test.go \ + -tags "e2e" \ + -timeout $(E2E_TIMEOUT) \ + -host=$(E2E_BIND_HOST) \ + -port=$(E2E_BIND_PORT) \ + -dataset=$(ROOTDIR)/hack/benchmark/assets/dataset/$(E2E_DATASET_NAME).hdf5 \ + -insert-num=$(E2E_INSERT_COUNT) \ + -search-num=$(E2E_SEARCH_COUNT) \ + -update-num=$(E2E_UPDATE_COUNT) \ + -remove-num=$(E2E_REMOVE_COUNT) \ + -wait-after-insert=$(E2E_WAIT_FOR_CREATE_INDEX_DURATION) \ + -portforward=$(E2E_PORTFORWARD_ENABLED) \ + -portforward-pod-name=$(E2E_TARGET_POD_NAME) \ + -portforward-pod-port=$(E2E_TARGET_PORT) \ + -namespace=$(E2E_TARGET_NAMESPACE) +endef + define run-e2e-multi-crud-test go test \ -race \ diff --git a/cmd/agent/core/faiss/main.go b/cmd/agent/core/faiss/main.go new file mode 100644 index 0000000000..d2aba32de1 --- /dev/null +++ b/cmd/agent/core/faiss/main.go @@ -0,0 +1,59 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +// Package main provides program main +package main + +import ( + "context" + + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/info" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/runner" + "github.com/vdaas/vald/internal/safety" + "github.com/vdaas/vald/pkg/agent/core/faiss/config" + "github.com/vdaas/vald/pkg/agent/core/faiss/usecase" +) + +const ( + maxVersion = "v0.0.10" + minVersion = "v0.0.0" + name = "agent faiss" +) + +func main() { + if err := safety.RecoverFunc(func() error { + return runner.Do( + context.Background(), + runner.WithName(name), + runner.WithVersion(info.Version, maxVersion, minVersion), + runner.WithConfigLoader(func(path string) (interface{}, *config.GlobalConfig, error) { + cfg, err := config.NewConfig(path) + if err != nil { + return nil, nil, errors.Wrap(err, "failed to load "+name+"'s configuration") + } + return cfg, &cfg.GlobalConfig, nil + }), + runner.WithDaemonInitializer(func(cfg interface{}) (runner.Runner, error) { + return usecase.New(cfg.(*config.Data)) + }), + ) + })(); err != nil { + log.Fatal(err, info.Get()) + return + } +} diff --git a/cmd/agent/core/faiss/sample.yaml b/cmd/agent/core/faiss/sample.yaml new file mode 100644 index 0000000000..b146729db7 --- /dev/null +++ b/cmd/agent/core/faiss/sample.yaml @@ -0,0 +1,123 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +--- +version: v0.0.0 +time_zone: JST +logging: + format: raw + level: debug + logger: glg +server_config: + servers: + - name: grpc + host: 0.0.0.0 + port: 8081 + grpc: + bidirectional_stream_concurrency: 20 + connection_timeout: "" + header_table_size: 0 + initial_conn_window_size: 0 + initial_window_size: 0 + interceptors: [] + keepalive: + max_conn_age: "" + max_conn_age_grace: "" + max_conn_idle: "" + time: "" + timeout: "" + max_header_list_size: 0 + max_receive_message_size: 0 + max_send_message_size: 0 + read_buffer_size: 0 + write_buffer_size: 0 + mode: GRPC + probe_wait_time: 3s + restart: true + health_check_servers: + - name: readiness + host: 0.0.0.0 + port: 3001 + http: + handler_timeout: "" + idle_timeout: "" + read_header_timeout: "" + read_timeout: "" + shutdown_duration: 0s + write_timeout: "" + mode: "" + probe_wait_time: 3s + metrics_servers: + startup_strategy: + - grpc + - readiness + full_shutdown_duration: 600s + tls: + ca: /path/to/ca + cert: /path/to/cert + enabled: false + key: /path/to/key +observability: + enabled: false + collector: + duration: 5s + metrics: + enable_cgo: true + enable_goroutine: true + enable_memory: true + enable_version_info: true + version_info_labels: + - vald_version + - server_name + - git_commit + - build_time + - go_version + - go_os + - go_arch + - faiss_version + trace: + enabled: false + sampling_rate: 1 + prometheus: + enabled: false + endpoint: /metrics + namespace: vald + jaeger: + enabled: false + collector_endpoint: "" + agent_endpoint: "jaeger-agent.default.svc.cluster.local:6831" + username: "" + password: "" + service_name: "vald-agent-faiss" + buffer_max_count: 10 +faiss: + auto_index_check_duration: 30m + auto_index_duration_limit: 24h + auto_index_length: 100 + auto_save_index_duration: 35m + dimension: 64 + enable_copy_on_write: false + enable_in_memory_mode: true + enable_proactive_gc: true + index_path: "" + initial_delay_max_duration: 3m + load_index_timeout_factor: 1ms + m: 8 # dimension % m == 0, train size >= 2^m(or nlist) * minPointsPerCentroid + max_load_index_timeout: 10m + metric_type: "inner_product" + min_load_index_timeout: 3m + nbits_per_idx: 8 + nlist: 100 diff --git a/dockers/agent/core/faiss/Dockerfile b/dockers/agent/core/faiss/Dockerfile new file mode 100644 index 0000000000..dcf915b824 --- /dev/null +++ b/dockers/agent/core/faiss/Dockerfile @@ -0,0 +1,119 @@ +# +# Copyright (C) 2019-2023 vdaas.org vald team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +ARG GO_VERSION=latest +ARG MAINTAINER="vdaas.org vald team " + +FROM golang:${GO_VERSION} AS golang + +FROM ubuntu:devel AS builder + +ENV GO111MODULE on +ENV DEBIAN_FRONTEND noninteractive +ENV INITRD No +ENV LANG en_US.UTF-8 +ENV GOROOT /opt/go +ENV GOPATH /go +ENV PATH ${PATH}:${GOROOT}/bin:${GOPATH}/bin +ENV ORG vdaas +ENV REPO vald +ENV PKG agent/core/faiss +ENV PKG_INTERNAL agent/internal +ENV APP_NAME faiss + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + cmake \ + curl \ + gcc \ + git \ + g++ \ + intel-mkl \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +COPY --from=golang /usr/local/go $GOROOT +RUN mkdir $GOPATH + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO} + +COPY go.mod . +COPY go.sum . + +RUN go mod download + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/internal +COPY internal . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/apis/grpc +COPY apis/grpc . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/pkg/${PKG} +COPY pkg/${PKG} . +# copy ngt/service/kvs and ngt/service/vqueue +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/pkg/agent/core/ngt/service +COPY pkg/agent/core/ngt/service/kvs ./kvs +COPY pkg/agent/core/ngt/service/vqueue ./vqueue + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/pkg/${PKG_INTERNAL} +COPY pkg/${PKG_INTERNAL} . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/cmd/${PKG} +COPY cmd/${PKG} . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/versions +COPY versions . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/Makefile.d +COPY Makefile.d . + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO} +COPY Makefile . +RUN update-alternatives --set libblas.so-x86_64-linux-gnu /usr/lib/x86_64-linux-gnu/libmkl_rt.so \ + && make faiss/install + +COPY .git . 
+ +RUN make REPO=${ORG} NAME=${REPO} cmd/${PKG}/${APP_NAME} \ + && mv "cmd/${PKG}/${APP_NAME}" "/usr/bin/${APP_NAME}" + +WORKDIR ${GOPATH}/src/github.com/${ORG}/${REPO}/cmd/${PKG} +RUN cp sample.yaml /tmp/config.yaml + +FROM ubuntu:devel +LABEL maintainer "${MAINTAINER}" + +ENV APP_NAME faiss + +COPY --from=builder /usr/bin/${APP_NAME} /go/bin/${APP_NAME} +COPY --from=builder /tmp/config.yaml /etc/server/config.yaml + +COPY --from=builder /usr/local/lib/libfaiss.so /usr/local/lib/libfaiss.so +COPY --from=builder /lib/x86_64-linux-gnu/libstdc++.so.6 /lib/x86_64-linux-gnu/libstdc++.so.6 +COPY --from=builder /lib/x86_64-linux-gnu/libgcc_s.so.1 /lib/x86_64-linux-gnu/libgcc_s.so.1 +COPY --from=builder /lib/x86_64-linux-gnu/libc.so.6 /lib/x86_64-linux-gnu/libc.so.6 +COPY --from=builder /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so +COPY --from=builder /lib/x86_64-linux-gnu/libmkl_sequential.so /lib/x86_64-linux-gnu/libmkl_sequential.so +COPY --from=builder /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_core.so +COPY --from=builder /lib/x86_64-linux-gnu/libgomp.so.1 /lib/x86_64-linux-gnu/libgomp.so.1 +COPY --from=builder /lib/x86_64-linux-gnu/libm.so.6 /lib/x86_64-linux-gnu/libm.so.6 +COPY --from=builder /lib/x86_64-linux-gnu/libdl.so.2 /lib/x86_64-linux-gnu/libdl.so.2 +COPY --from=builder /lib/x86_64-linux-gnu/libpthread.so.0 /lib/x86_64-linux-gnu/libpthread.so.0 +COPY --from=builder /lib/x86_64-linux-gnu/libmkl_avx2.so /lib/x86_64-linux-gnu/libmkl_avx2.so +RUN ldconfig -v + +ENTRYPOINT ["/go/bin/faiss"] diff --git a/docs/overview/component/README.md b/docs/overview/component/README.md index 49461c6b9b..58e03e547e 100644 --- a/docs/overview/component/README.md +++ b/docs/overview/component/README.md @@ -67,6 +67,7 @@ In this section, we will describe what is Vald Agent and the corresponding compo Vald Agent provides functionalities to perform approximate nearest neighbor search. Agent-NGT uses [yahoojapan/NGT](https://github.com/yahoojapan/NGT) as a core library. +Agent-Faiss uses [facebookresearch/faiss](https://github.com/facebookresearch/faiss) as a core library. Each Vald Agent pod has its own vector data space because only several Vald Agents are selected to be inserted/updated in a single insert/update request. diff --git a/docs/overview/component/agent.md b/docs/overview/component/agent.md index be0326bb7f..1cf33ad534 100644 --- a/docs/overview/component/agent.md +++ b/docs/overview/component/agent.md @@ -78,6 +78,56 @@ This image shows the mechanism to create NGT index. Please refer to [Go Doc](https://pkg.go.dev/github.com/vdaas/vald@v1.3.1/pkg/agent/core/ngt/service) for other functions. +#### Vald Agent Faiss + +Vald Agent Faiss uses [Faiss](https://github.com/facebookresearch/faiss) as its core library. + +The main functions are as follows: + +- Insert + - Request to insert new vectors into Faiss. + - Requested vectors are stored in the `vqueue`. + - Cache a fixed number of vectors for Faiss training. + - Once Faiss has been trained in CreateIndex, vectors are no longer cached for training. +- Search + - Get the nearest neighbor vectors of the request vector from Faiss indexes. + - radius and epsilon are NGT search parameters and have no effect in Faiss. +- Update + - Create a request to update the specific vectors to the new vectors. + - Requested vectors are stored in the `vqueue`. +- Remove + - Create a request to remove the specific vectors from Faiss indexes. + - Requested vectors are stored in the `vqueue`.
+- Exist + - Check whether the specific vectors are already inserted or not. +- CreateIndex + - Create a new Faiss index structure in memory using vectors stored in the `vqueue` and the existing Faiss index structure if it exists. + - If the number of cached vectors has not reached the amount required for Faiss training, training is skipped. + - If Faiss has not been trained, no index is generated. +- SaveIndex + - Save metadata about Faiss index information to the internal storage. + +The following functions are not implemented: + +- GetObject +- SearchByID +- StreamXXX +- MultiXXX + +
+As with Agent NGT, you have to control the durations of CreateIndex and SaveIndex through configuration. + +These methods do not necessarily run every time a request arrives. + +
+ +
+As described above, Vald Agent Faiss can only search the nearest neighbors from the Faiss index. + +You have to wait for the CreateIndex and SaveIndex functions to complete before searching (a minimal sketch of this train-then-index flow follows below). + +
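To make the train-then-index requirement concrete, here is a minimal sketch of how the low-level Go wrapper added in this pull request (`internal/core/algorithm/faiss`) could be driven. It assumes libfaiss is installed (see the `faiss/install` Makefile target); the data, ids, and index path are illustrative, while `New`, `Train`, `Add`, and `Search` are the functions defined by the wrapper:

```go
package main

import (
	"fmt"
	"math/rand"

	"github.com/vdaas/vald/internal/core/algorithm/faiss"
)

func main() {
	const dim, nb = 64, 10000

	// Create an IVFPQ index wrapper; nlist/m/nbits/metric mirror the sample.yaml defaults.
	f, err := faiss.New(
		faiss.WithDimension(dim),
		faiss.WithNlist(100),
		faiss.WithM(8), // dim % m must be 0
		faiss.WithNbitsPerIdx(8),
		faiss.WithMetricType("inner_product"),
		faiss.WithIndexPath("/tmp/faiss.idx"), // illustrative path
	)
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// Dummy data: nb vectors of dimension dim, flattened row-major, with sequential ids.
	xb := make([]float32, nb*dim)
	for i := range xb {
		xb[i] = rand.Float32()
	}
	ids := make([]int64, nb)
	for i := range ids {
		ids[i] = int64(i + 1)
	}

	// Faiss must be trained before vectors can be added and searched.
	if err := f.Train(nb, xb); err != nil {
		panic(err)
	}
	if _, err := f.Add(nb, xb, ids); err != nil {
		panic(err)
	}

	// Search the 10 nearest neighbors of one query vector.
	res, err := f.Search(10, 1, xb[:dim])
	if err != nil {
		panic(err)
	}
	fmt.Println(len(res), "results; first id:", res[0].ID)
}
```

In the agent itself this flow is wrapped by the `vqueue` and the CreateIndex/SaveIndex machinery described above, so applications normally talk to the agent over gRPC rather than calling the wrapper directly.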
+ ### Sidecar `Sidecar` saves the index metadata file to external storage like Amazon S3 or Google Cloud Storage. diff --git a/docs/tutorial/get-started-with-faiss-agent.md b/docs/tutorial/get-started-with-faiss-agent.md new file mode 100644 index 0000000000..87d7562160 --- /dev/null +++ b/docs/tutorial/get-started-with-faiss-agent.md @@ -0,0 +1,393 @@ +# Get Started + +This tutorial is for those who have already completed [Get Started](https://github.com/vdaas/vald/blob/main/docs/tutorial/get-started.md). +Please refer to Prepare the Kubernetes Cluster and others there. + +## Deploy Vald on Kubernetes Cluster + +This chapter shows how to deploy Vald using Helm and run it on your Kubernetes cluster.
+In this tutorial, you will deploy the basic configuration of Vald, which consists of vald-agent-faiss, vald-lb-gateway, vald-discoverer, and vald-manager-index.
+ +1. Clone the repository + + ```bash + git clone https://github.com/vdaas/vald.git && \ + cd vald + ``` + +1. Confirm which cluster to deploy + + ```bash + kubectl cluster-info + ``` + +1. Edit Configurations + + Set the parameters for connecting to the vald-lb-gateway through Kubernetes ingress from the external network. + Please set these parameters. + + ```bash + vim example/helm/values.yaml + === + ## vald-lb-gateway settings + gateway: + lb: + ... + ingress: + enabled: true + # TODO: Set your ingress host. + host: localhost + # TODO: Set annotations which you have to set for your k8s cluster. + annotations: + ... + ## vald-agent-faiss settings + agent: + algorithm: faiss + image: + repository: vdaas/vald-agent-faiss + tag: latest + faiss: + auto_index_check_duration: 1m + auto_index_duration_limit: 24h + auto_index_length: 10 + auto_save_index_duration: 35m + dimension: 784 + enable_copy_on_write: false + enable_in_memory_mode: true + enable_proactive_gc: true + index_path: "" + initial_delay_max_duration: 3m + load_index_timeout_factor: 1ms + m: 8 # dimension % m == 0, train size >= 2^m(or nlist) * minPointsPerCentroid + max_load_index_timeout: 10m + metric_type: "inner_product" + min_load_index_timeout: 3m + nbits_per_idx: 8 + nlist: 100 + ... + ``` + + Note:
+ If you decide to use port-forward instead of ingress, please set `gateway.lb.ingress.enabled` to `false`. + +1. Deploy Vald using Helm + + Add the vald repo to Helm. + + ```bash + helm repo add vald https://vald.vdaas.org/charts + ``` + + Deploy Vald on your Kubernetes cluster. + + ```bash + helm install vald vald/vald --values example/helm/values.yaml + ``` + +1. Verify + + When the deployment finishes, you can check the status of the Vald pods with the following command. + + ```bash + kubectl get pods + ``` + +
Example output
+ If the deployment is successful, all Vald components should be running. + + ```bash + NAME READY STATUS RESTARTS AGE + vald-agent-faiss-0 1/1 Running 0 7m12s + vald-agent-faiss-1 1/1 Running 0 7m12s + vald-agent-faiss-2 1/1 Running 0 7m12s + vald-agent-faiss-3 1/1 Running 0 7m12s + vald-agent-faiss-4 1/1 Running 0 7m12s + vald-discoverer-7f9f697dbb-q44qh 1/1 Running 0 7m11s + vald-lb-gateway-6b7b9f6948-4z5md 1/1 Running 0 7m12s + vald-lb-gateway-6b7b9f6948-68g94 1/1 Running 0 6m56s + vald-lb-gateway-6b7b9f6948-cvspq 1/1 Running 0 6m56s + vald-manager-index-74c7b5ddd6-jrnlw 1/1 Running 0 7m12s + ``` + +
+ + ```bash + kubectl get ingress + ``` + +
Example output
+ + ```bash + NAME CLASS HOSTS ADDRESS PORTS AGE + vald-lb-gateway-ingress localhost 192.168.16.2 80 7m43s + ``` + +
+ + ```bash + kubectl get svc + ``` + +
Example output
+ + ```bash + NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE + kubernetes ClusterIP 10.43.0.1 443/TCP 9m29s + vald-agent-faiss ClusterIP None 8081/TCP,3001/TCP 8m48s + vald-discoverer ClusterIP None 8081/TCP,3001/TCP 8m48s + vald-manager-index ClusterIP None 8081/TCP,3001/TCP 8m48s + vald-lb-gateway ClusterIP None 8081/TCP,3001/TCP 8m48s + ``` + +
+ +## Run Example Code + +In this chapter, you will insert, search, and delete vectors in your Vald cluster using the example code.
+The [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset is used for indexing and for the search queries. + +The example code is implemented in Go and uses [vald-client-go](https://github.com/vdaas/vald-client-go), one of the official Vald client libraries, to send requests to the Vald cluster. +Vald provides client libraries for multiple languages such as Go, Java, Node.js, and Python. +If you are interested, please refer to [SDKs](../user-guides/sdks.md).
+ +1. Port Forward (optional) + + If you do not use Kubernetes Ingress, port-forward is required to make requests from your local environment. + + ```bash + kubectl port-forward deployment/vald-lb-gateway 8081:8081 + ``` + +1. Download dataset + + Download [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist), which is used as the dataset for indexing and search queries. + + Move to the working directory: + + ```bash + cd example/client + ``` + + Download the Fashion-MNIST test dataset: + + ```bash + wget http://ann-benchmarks.com/fashion-mnist-784-euclidean.hdf5 + ``` + +1. Run Example + + We use [`example/client/main.go`](https://github.com/vdaas/vald/blob/main/example/client/main.go) to run the example.
+ This example will insert and index 400 vectors into Vald from the Fashion-MNIST dataset via [gRPC](https://grpc.io/). + Then, after waiting for indexing to finish, it will request a search for the nearest vectors 10 times. + You will get the 10 nearest neighbor vectors for each search query.
+ Run the example code by executing the command below. + + ```bash + go run main.go + ``` + +
The detailed explanation of the example code is here
+ This will execute 6 steps. + + 1. init + + - Import packages +
example code
+ + ```go + package main + + import ( + "context" + "encoding/json" + "flag" + "time" + + "github.com/kpango/fuid" + "github.com/kpango/glg" + "github.com/vdaas/vald-client-go/v1/payload" + "github.com/vdaas/vald-client-go/v1/vald" + + "gonum.org/v1/hdf5" + "google.golang.org/grpc" + ) + ``` + +
+ + - Set variables + + - The constants for the number of training and test vectors. +
example code
+ + ```go + const ( + insertCount = 400 + testCount = 20 + ) + ``` + +
+ + - The variables for configuration. +
example code
+ + ```go + var ( + datasetPath string + grpcServerAddr string + indexingWaitSeconds uint + ) + ``` + +
+ + - Register and parse the command-line flags. +
example code
+ + ```go + func init() { + flag.StringVar(&datasetPath, "path", "fashion-mnist-784-euclidean.hdf5", "set dataset path") + flag.StringVar(&grpcServerAddr, "addr", "127.0.0.1:8081", "set gRPC server address") + flag.UintVar(&indexingWaitSeconds, "wait", 60, "set indexing wait seconds") + flag.Parse() + } + ``` + +
+ + 1. load + + - Load the Fashion-MNIST dataset and assign an id to each loaded vector. When loading completes successfully, this step returns the training dataset, the test dataset, and the list of ids. +
example code
+ + ```go + ids, train, test, err := load(datasetPath) + if err != nil { + glg.Fatal(err) + } + ``` + +
+ + 1. Create the gRPC connection and the Vald client using that connection. + +
example code
+ + ```go + ctx := context.Background() + + conn, err := grpc.DialContext(ctx, grpcServerAddr, grpc.WithInsecure()) + if err != nil { + glg.Fatal(err) + } + + client := vald.NewValdClient(conn) + ``` + +
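One small addition that is not part of the original example: releasing the gRPC connection when the program exits. Assuming the `conn` and `glg` variables from the snippet above, it might look like the following (newer grpc-go versions also prefer `grpc.WithTransportCredentials` with insecure credentials over the deprecated `grpc.WithInsecure()`, but the original call still works):

```go
// Close the gRPC connection when main returns.
defer func() {
	if err := conn.Close(); err != nil {
		glg.Warn(err)
	}
}()
```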
+ + 1. Insert and Index + + - Insert and index 400 training vectors into the Vald agent. +
example code
+ + ```go + for i := range ids [:insertCount] { + _, err := client.Insert(ctx, &payload.Insert_Request{ + Vector: &payload.Object_Vector{ + Id: ids[i], + Vector: train[i], + }, + Config: &payload.Insert_Config{ + SkipStrictExistCheck: true, + }, + }) + if err != nil { + glg.Fatal(err) + } + if i%10 == 0 { + glg.Infof("Inserted %d", i) + } + } + ``` + +
+ + - Wait until indexing finishes. +
example code
+ + ```go + wt := time.Duration(indexingWaitSeconds) * time.Second + glg.Infof("Wait %s for indexing to finish", wt) + time.Sleep(wt) + ``` + +
+ + 1. Search + + - Search the 10 nearest neighbor vectors for each of the 20 test vectors and return a list of the neighbors. + + - To get the approximate neighbors, the Vald client sends the search config and the query vector to the server via gRPC. +
example code
+ + ```go + glg.Infof("Start search %d times", testCount) + for i, vec := range test[:testCount] { + res, err := client.Search(ctx, &payload.Search_Request{ + Vector: vec, + Config: &payload.Search_Config{ + Num: 10, + Radius: -1, + Epsilon: 0.1, + Timeout: 100000000, + }, + }) + if err != nil { + glg.Fatal(err) + } + + b, _ := json.MarshalIndent(res.GetResults(), "", " ") + glg.Infof("%d - Results : %s\n\n", i+1, string(b)) + time.Sleep(1 * time.Second) + } + ``` + +
+ + 1. Remove + + - Remove the 400 indexed training vectors from the Vald agent. +
example code
+ + ```go + for i := range ids [:insertCount] { + _, err := client.Remove(ctx, &payload.Remove_Request{ + Id: &payload.Object_ID{ + Id: ids[i], + }, + }) + if err != nil { + glg.Fatal(err) + } + if i%10 == 0 { + glg.Infof("Removed %d", i) + } + } + ``` + +
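Optionally, since Exist is implemented for the Faiss agent (see the component document), you can verify the removal afterwards. A minimal sketch, assuming the same `ctx`, `client`, and `ids` as above and that your vald-client-go version exposes the Exists RPC:

```go
// After removal, Exists is expected to return a not-found error for removed ids.
if _, err := client.Exists(ctx, &payload.Object_ID{Id: ids[0]}); err != nil {
	glg.Infof("id %s is no longer indexed: %v", ids[0], err)
}
```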
+ +## Cleanup + +Finally, you can remove the deployed Vald cluster by executing the command below. + +```bash +helm uninstall vald +``` + +## References + +- [Get Started with NGT agent by default](https://github.com/vdaas/vald/blob/main/docs/tutorial/get-started.md) +- [Faiss](https://github.com/facebookresearch/faiss) diff --git a/docs/user-guides/configuration.md b/docs/user-guides/configuration.md index b3b9e06d88..f54def2a17 100644 --- a/docs/user-guides/configuration.md +++ b/docs/user-guides/configuration.md @@ -175,6 +175,42 @@ When the setting parameter of Vald Agent NGT is shorter than the setting value o If this happens, the Index Manager may not function properly. +#### Faiss + +Vald Agent Faiss uses [facebookresearch/faiss][faiss] as a core library for searching vectors. +The behavior of Faiss can be configured via the `agent.faiss` field object. + +The important parameters are as follows: + +- `agent.faiss.dimension` +- `agent.faiss.distance_type` +- `agent.faiss.m` +- `agent.faiss.metric_type` +- `agent.faiss.nbits_per_idx` +- `agent.faiss.nlist` + +Users should configure these parameters first to fit their use case (see the sketch at the end of this section). +For further details, please read [the Faiss wiki][faiss-wiki]. + +Vald Agent Faiss has a feature to start indexing automatically. +The behavior of this feature can be configured with these parameters: + +- `agent.faiss.auto_index_duration_limit` +- `agent.faiss.auto_index_check_duration` +- `agent.faiss.auto_index_length` + +
+While the Vald Agent Faiss is in the process of creating indexes, it will ignore all search requests to the target pods. +
+ +
+When deploying Vald Index Manager, the above parameters should be set to values much longer than the Vald Index Manager settings (please refer to the Vald Index Manager section) or to a negative value. +E.g., set `agent.faiss.auto_index_duration_limit` to "720h" or "-1h" and `agent.faiss.auto_index_check_duration` to "24h" or "-1h". +This is because the Vald Index Manager keeps track of the index state of each Vald Agent Faiss and controls when indexing is executed. +When the Vald Agent Faiss settings are shorter than the Vald Index Manager settings, Vald Agent Faiss may start indexing on its own without an execution command from the Vald Index Manager. +If this happens, the Index Manager may not function properly. +
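For reference, the `agent.faiss` keys above map onto the `Faiss` configuration struct added to `internal/config` in this pull request. The sketch below shows that mapping with purely illustrative values (note the `dimension % m == 0` constraint from the sample configuration); `Bind()` resolves environment-variable placeholders in the string fields:

```go
package main

import (
	"fmt"

	"github.com/vdaas/vald/internal/config"
)

func main() {
	// Illustrative values only; in a real deployment these come from the Helm values / ConfigMap.
	f := (&config.Faiss{
		Dimension:              784, // 784 % 8 == 0, so m=8 is valid
		Nlist:                  100,
		M:                      8,
		NbitsPerIdx:            8,
		MetricType:             "inner_product",
		AutoIndexCheckDuration: "24h",  // longer than the Index Manager interval, or negative
		AutoIndexDurationLimit: "720h", // e.g. "720h" or "-1h"
		AutoIndexLength:        100,
	}).Bind()

	fmt.Println(f.MetricType, f.AutoIndexCheckDuration, f.AutoIndexDurationLimit)
}
```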
+ #### Resource requests and limits, Pod priorities Because the Vald Agent pod places indexes on memory, termination of agent pods causes loss of indexes. @@ -371,3 +407,5 @@ For further details, there are references to the Helm values in the Vald GitHub [kubernetes-topology-spread-constraints]: https://kubernetes.io/docs/concepts/workloads/pods/pod-topology-spread-constraints/ [yj-ngt]: https://github.com/yahoojapan/NGT [yj-ngt-wiki]: https://github.com/yahoojapan/NGT/wiki +[faiss]: https://github.com/facebookresearch/faiss +[faiss-wiki]: https://github.com/facebookresearch/faiss/wiki diff --git a/example/client/main.go b/example/client/main.go index 3a608b41d5..c5c8a3f7e2 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -28,24 +28,26 @@ import ( ) const ( - insertCount = 400 testCount = 20 ) var ( datasetPath string grpcServerAddr string + insertCount uint indexingWaitSeconds uint ) func init() { /** Path option specifies hdf file by path. Default value is `fashion-mnist-784-euclidean.hdf5`. - Addr option specifies grpc server address. Default value is `127.0.0.1:8080`. - Wait option specifies indexing wait time (in seconds). Default value is `60`. + Addr option specifies grpc server address. Default value is `127.0.0.1:8081`. + Insert option specifies insert count. Default value is `400`. + Wait option specifies indexing wait time (in seconds). Default value is `60`. **/ flag.StringVar(&datasetPath, "path", "fashion-mnist-784-euclidean.hdf5", "dataset path") - flag.StringVar(&grpcServerAddr, "addr", "localhost:8080", "gRPC server address") + flag.StringVar(&grpcServerAddr, "addr", "localhost:8081", "gRPC server address") + flag.UintVar(&insertCount, "insert", 400, "insert count") flag.UintVar(&indexingWaitSeconds, "wait", 60, "indexing wait seconds") flag.Parse() } diff --git a/internal/config/faiss.go b/internal/config/faiss.go new file mode 100644 index 0000000000..732d5a26be --- /dev/null +++ b/internal/config/faiss.go @@ -0,0 +1,118 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package config providers configuration type and load configuration logic +package config + +// Faiss represent the faiss core configuration for server. 
+type Faiss struct { + // IndexPath represents the faiss index file path + IndexPath string `yaml:"index_path" json:"index_path,omitempty"` + + // Dimension represents the faiss index dimension + Dimension int `yaml:"dimension" json:"dimension,omitempty" info:"dimension"` + + // Nlist represents the number of Voronoi cells + // ref: https://github.com/facebookresearch/faiss/wiki/Faster-search + Nlist int `yaml:"nlist" json:"nlist,omitempty" info:"nlist"` + + // M represents the number of subquantizers + // ref: https://github.com/facebookresearch/faiss/wiki/Faiss-indexes-(composite)#cell-probe-method-with-a-pq-index-as-coarse-quantizer + M int `yaml:"m" json:"m,omitempty" info:"m"` + + // NbitsPerIdx represents the number of bit per subvector index + // ref: https://github.com/facebookresearch/faiss/wiki/FAQ#can-i-ignore-warning-clustering-xxx-points-to-yyy-centroids + NbitsPerIdx int `yaml:"nbits_per_idx" json:"nbits_per_idx,omitempty" info:"nbits_per_idx"` + + // MetricType represents the metric type + MetricType string `yaml:"metric_type" json:"metric_type,omitempty" info:"metric_type"` + + // EnableInMemoryMode enables on memory faiss indexing mode + EnableInMemoryMode bool `yaml:"enable_in_memory_mode" json:"enable_in_memory_mode,omitempty"` + + // AutoIndexCheckDuration represents checking loop duration about auto indexing execution + AutoIndexCheckDuration string `yaml:"auto_index_check_duration" json:"auto_index_check_duration,omitempty"` + + // AutoSaveIndexDuration represents checking loop duration about auto save index execution + AutoSaveIndexDuration string `yaml:"auto_save_index_duration" json:"auto_save_index_duration,omitempty"` + + // AutoIndexDurationLimit represents auto indexing duration limit + AutoIndexDurationLimit string `yaml:"auto_index_duration_limit" json:"auto_index_duration_limit,omitempty"` + + // AutoIndexLength represents auto index length limit + AutoIndexLength int `yaml:"auto_index_length" json:"auto_index_length,omitempty"` + + // InitialDelayMaxDuration represents maximum duration for initial delay + InitialDelayMaxDuration string `yaml:"initial_delay_max_duration" json:"initial_delay_max_duration,omitempty"` + + // MinLoadIndexTimeout represents minimum duration of load index timeout + MinLoadIndexTimeout string `yaml:"min_load_index_timeout" json:"min_load_index_timeout,omitempty"` + + // MaxLoadIndexTimeout represents maximum duration of load index timeout + MaxLoadIndexTimeout string `yaml:"max_load_index_timeout" json:"max_load_index_timeout,omitempty"` + + // LoadIndexTimeoutFactor represents a factor of load index timeout + LoadIndexTimeoutFactor string `yaml:"load_index_timeout_factor" json:"load_index_timeout_factor,omitempty"` + + // EnableProactiveGC enables more proactive GC call for reducing heap memory allocation + EnableProactiveGC bool `yaml:"enable_proactive_gc" json:"enable_proactive_gc,omitempty"` + + // EnableCopyOnWrite enables copy on write saving + EnableCopyOnWrite bool `yaml:"enable_copy_on_write" json:"enable_copy_on_write,omitempty"` + + // VQueue represents the faiss vector queue buffer size + VQueue *VQueue `json:"vqueue,omitempty" yaml:"vqueue"` + + // KVSDB represents the faiss bidirectional kv store configuration + KVSDB *KVSDB `json:"kvsdb,omitempty" yaml:"kvsdb"` +} + +//// KVSDB represent the faiss vector bidirectional kv store configuration +//type KVSDB struct { +// // Concurrency represents kvsdb range loop processing concurrency +// Concurrency int `json:"concurrency,omitempty" yaml:"concurrency,omitempty"` +//} 
+// +//// VQueue represent the faiss vector queue buffer size +//type VQueue struct { +// // InsertBufferPoolSize represents insert time ordered slice buffer size +// InsertBufferPoolSize int `json:"insert_buffer_pool_size,omitempty" yaml:"insert_buffer_pool_size"` +// +// // DeleteBufferPoolSize represents delete time ordered slice buffer size +// DeleteBufferPoolSize int `json:"delete_buffer_pool_size,omitempty" yaml:"delete_buffer_pool_size"` +//} + +// Bind returns Faiss object whose some string value is filed value or environment value. +func (f *Faiss) Bind() *Faiss { + f.IndexPath = GetActualValue(f.IndexPath) + f.MetricType = GetActualValue(f.MetricType) + f.AutoIndexCheckDuration = GetActualValue(f.AutoIndexCheckDuration) + f.AutoIndexDurationLimit = GetActualValue(f.AutoIndexDurationLimit) + f.AutoSaveIndexDuration = GetActualValue(f.AutoSaveIndexDuration) + f.InitialDelayMaxDuration = GetActualValue(f.InitialDelayMaxDuration) + f.MinLoadIndexTimeout = GetActualValue(f.MinLoadIndexTimeout) + f.MaxLoadIndexTimeout = GetActualValue(f.MaxLoadIndexTimeout) + f.LoadIndexTimeoutFactor = GetActualValue(f.LoadIndexTimeoutFactor) + + if f.VQueue == nil { + f.VQueue = new(VQueue) + } + if f.KVSDB == nil { + f.KVSDB = new(KVSDB) + } + + return f +} diff --git a/internal/core/algorithm/faiss/Capi.cpp b/internal/core/algorithm/faiss/Capi.cpp new file mode 100644 index 0000000000..85cdbc7ad3 --- /dev/null +++ b/internal/core/algorithm/faiss/Capi.cpp @@ -0,0 +1,231 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "Capi.h" + +FaissStruct* faiss_create_index( + const int d, + const int nlist, + const int m, + const int nbits_per_idx, + const int metric_type) { + //printf(__FUNCTION__); + //printf("\n"); + //fflush(stdout); + + FaissStruct *st = NULL; + try { + faiss::IndexFlat *quantizer; + switch (metric_type) { + case faiss::METRIC_INNER_PRODUCT: + quantizer = new faiss::IndexFlat(d, faiss::METRIC_INNER_PRODUCT); + break; + case faiss::METRIC_L2: + quantizer = new faiss::IndexFlat(d, faiss::METRIC_L2); + break; + default: + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: no metric type."; + std::cerr << ss.str() << std::endl; + return NULL; + } + faiss::IndexIVFPQ *index = new faiss::IndexIVFPQ(quantizer, d, nlist, m, nbits_per_idx); + //index->verbose = true; + st = new FaissStruct{ + static_cast(quantizer), + static_cast(index) + }; + } catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + std::cerr << ss.str() << std::endl; + } + + return st; +} + +FaissStruct* faiss_read_index(const char* fname) { + //printf(__FUNCTION__); + //printf("\n"); + //fflush(stdout); + + FaissStruct *st = NULL; + try { + st = new FaissStruct{ + static_cast(NULL), + static_cast(faiss::read_index(fname)) + }; + } catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + std::cerr << ss.str() << std::endl; + } + + return st; +} + +bool faiss_write_index( + const FaissStruct* st, + const char* fname) { + //printf(__FUNCTION__); + //printf("\n"); + //fflush(stdout); + + try { + faiss::write_index(static_cast(st->faiss_index), fname); + } catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + std::cerr << ss.str() << std::endl; + return false; + } + + fflush(stdout); + return true; +} + +bool faiss_train( + const FaissStruct* st, + const int nb, + const float* xb) { + //printf(__FUNCTION__); + //printf("\n"); + //fflush(stdout); + + try { + //printf("is_trained: %d\n", (static_cast(st->faiss_index))->is_trained); + //printf("ntotal: %ld\n", (static_cast(st->faiss_index))->ntotal); + (static_cast(st->faiss_index))->train(nb, xb); + //printf("is_trained: %d\n", (static_cast(st->faiss_index))->is_trained); + //printf("ntotal: %ld\n", (static_cast(st->faiss_index))->ntotal); + } catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + std::cerr << ss.str() << std::endl; + return false; + } + + fflush(stdout); + return true; +} + +int faiss_add( + const FaissStruct* st, + const int nb, + const float* xb, + const long int* xids ) { + //printf(__FUNCTION__); + //printf("\n"); + //fflush(stdout); + + try { + //printf("is_trained: %d\n", (static_cast(st->faiss_index))->is_trained); + //printf("ntotal: %ld\n", (static_cast(st->faiss_index))->ntotal); + (static_cast(st->faiss_index))->add_with_ids(nb, xb, xids); + //printf("is_trained: %d\n", (static_cast(st->faiss_index))->is_trained); + //printf("ntotal: %ld\n", (static_cast(st->faiss_index))->ntotal); + } catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + std::cerr << ss.str() << std::endl; + return -1; + } + + fflush(stdout); + return (static_cast(st->faiss_index))->ntotal; +} + +bool faiss_search( + const FaissStruct* st, + 
const int k, + const int nq, + const float* xq, + long* I, + float* D) { + //printf(__FUNCTION__); + //printf("\n"); + //fflush(stdout); + + try { + //printf("is_trained: %d\n", (static_cast(st->faiss_index))->is_trained); + //printf("ntotal: %ld\n", (static_cast(st->faiss_index))->ntotal); + (static_cast(st->faiss_index))->search(nq, xq, k, D, I); + //printf("I=\n"); + //for(int i = 0; i < nq; i++) { + // for(int j = 0; j < k; j++) { + // printf("%5ld ", I[i * k + j]); + // } + // printf("\n"); + //} + //printf("D=\n"); + //for(int i = 0; i < nq; i++) { + // for(int j = 0; j < k; j++) { + // printf("%7g ", D[i * k + j]); + // } + // printf("\n"); + //} + } catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + std::cerr << ss.str() << std::endl; + return false; + } + + return true; +} + +int faiss_remove( + const FaissStruct* st, + const int size, + const long int* ids) { + //printf(__FUNCTION__); + //printf("\n"); + //fflush(stdout); + + try { + //printf("is_trained: %d\n", (static_cast(st->faiss_index))->is_trained); + //printf("ntotal: %ld\n", (static_cast(st->faiss_index))->ntotal); + faiss::IDSelectorArray sel(size, ids); + (static_cast(st->faiss_index))->remove_ids(sel); + //printf("is_trained: %d\n", (static_cast(st->faiss_index))->is_trained); + //printf("ntotal: %ld\n", (static_cast(st->faiss_index))->ntotal); + } catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + std::cerr << ss.str() << std::endl; + return -1; + } + + return (static_cast(st->faiss_index))->ntotal; +} + +void faiss_free(FaissStruct* st) { + //printf(__FUNCTION__); + //printf("\n"); + //fflush(stdout); + + free(st); + return; +} diff --git a/internal/core/algorithm/faiss/Capi.h b/internal/core/algorithm/faiss/Capi.h new file mode 100644 index 0000000000..c1df817d35 --- /dev/null +++ b/internal/core/algorithm/faiss/Capi.h @@ -0,0 +1,64 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#ifdef __cplusplus +extern "C" { +#endif + #include + #include + #include + + typedef void* FaissQuantizer; + typedef void* FaissIndex; + typedef struct { + FaissQuantizer faiss_quantizer; + FaissIndex faiss_index; + } FaissStruct; + + FaissStruct* faiss_create_index( + const int d, + const int nlist, + const int m, + const int nbits_per_idx, + const int metric_type); + FaissStruct* faiss_read_index(const char* fname); + bool faiss_write_index( + const FaissStruct* st, + const char* fname); + bool faiss_train( + const FaissStruct* st, + const int nb, + const float* xb); + int faiss_add( + const FaissStruct* st, + const int nb, + const float* xb, + const long int* xids); + bool faiss_search( + const FaissStruct* st, + const int k, + const int nq, + const float* xq, + long* I, + float* D); + int faiss_remove( + const FaissStruct* st, + const int size, + const long int* ids); + void faiss_free(FaissStruct* st); +#ifdef __cplusplus +} +#endif diff --git a/internal/core/algorithm/faiss/faiss.go b/internal/core/algorithm/faiss/faiss.go new file mode 100644 index 0000000000..b7a0e60518 --- /dev/null +++ b/internal/core/algorithm/faiss/faiss.go @@ -0,0 +1,255 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package faiss provides implementation of Go API for https://github.com/facebookresearch/faiss +package faiss + +/* +#cgo LDFLAGS: -lfaiss +#include +*/ +import "C" + +import ( + "sync" + "unsafe" + + "github.com/vdaas/vald/internal/errors" +) + +type ( + // Faiss is core interface. + Faiss interface { + // SaveIndex stores faiss index to strage. + SaveIndex() error + + // SaveIndexWithPath stores faiss index to specified storage. + SaveIndexWithPath(idxPath string) error + + // Train trains faiss index. + Train(nb int, xb []float32) error + + // Add returns faiss ntotal. + Add(nb int, xb []float32, xids []int64) (int, error) + + // Search returns search result as []SearchResult. + Search(k, nq int, xq []float32) ([]SearchResult, error) + + // Remove removes from faiss index. + Remove(size int, ids []int64) (int, error) + + // Close faiss index. + Close() + } + + faiss struct { + st *C.FaissStruct + dimension C.int + nlist C.int + m C.int + nbitsPerIdx C.int + metricType metricType + idxPath string + mu *sync.RWMutex + } + + SearchResult struct { + ID uint32 + Distance float32 + Error error + } +) + +// metricType is alias of metric type in Faiss. +type metricType int + +const ( + // ------------------------------------------------------------- + // Metric Type Definition + // (https://github.com/facebookresearch/faiss/wiki/MetricType-and-distances) + // ------------------------------------------------------------- + // DistanceNone is unknown distance type. + DistanceNone metricType = iota - 1 + // InnerProduct is inner product. + InnerProduct + // L2 is l2 norm. + L2 + // -------------------------------------------------------------. 
+ + // ------------------------------------------------------------- + // ErrorCode is false + // -------------------------------------------------------------. + ErrorCode = C._Bool(false) + // -------------------------------------------------------------. +) + +// New returns Faiss instance with recreating empty index file. +func New(opts ...Option) (Faiss, error) { + return gen(false, opts...) +} + +func Load(opts ...Option) (Faiss, error) { + return gen(true, opts...) +} + +func gen(isLoad bool, opts ...Option) (Faiss, error) { + var ( + f = new(faiss) + err error + ) + f.mu = new(sync.RWMutex) + + defer func() { + if err != nil { + f.Close() + } + }() + + for _, opt := range append(defaultOptions, opts...) { + if err = opt(f); err != nil { + return nil, errors.NewFaissError("faiss option error") + } + } + + if isLoad { + path := C.CString(f.idxPath) + defer C.free(unsafe.Pointer(path)) + f.st = C.faiss_read_index(path) + if f.st == nil { + return nil, errors.NewFaissError("faiss load index error") + } + } else { + switch f.metricType { + case InnerProduct: + f.st = C.faiss_create_index(f.dimension, f.nlist, f.m, f.nbitsPerIdx, C.int(InnerProduct)) + case L2: + f.st = C.faiss_create_index(f.dimension, f.nlist, f.m, f.nbitsPerIdx, C.int(L2)) + default: + return nil, errors.NewFaissError("faiss create index error: no metric type") + } + if f.st == nil { + return nil, errors.NewFaissError("faiss create index error: nil pointer") + } + } + + return f, nil +} + +// SaveIndex stores faiss index to storage. +func (f *faiss) SaveIndex() error { + path := C.CString(f.idxPath) + defer C.free(unsafe.Pointer(path)) + + f.mu.Lock() + ret := C.faiss_write_index(f.st, path) + f.mu.Unlock() + if ret == ErrorCode { + return errors.NewFaissError("failed to faiss_write_index") + } + + return nil +} + +// SaveIndexWithPath stores faiss index to specified storage. +func (f *faiss) SaveIndexWithPath(idxPath string) error { + path := C.CString(idxPath) + defer C.free(unsafe.Pointer(path)) + + f.mu.Lock() + ret := C.faiss_write_index(f.st, path) + f.mu.Unlock() + if ret == ErrorCode { + return errors.NewFaissError("failed to faiss_write_index") + } + + return nil +} + +// Train trains faiss index. +func (f *faiss) Train(nb int, xb []float32) error { + f.mu.Lock() + ret := C.faiss_train(f.st, (C.int)(nb), (*C.float)(&xb[0])) + f.mu.Unlock() + if ret == ErrorCode { + return errors.NewFaissError("failed to faiss_train") + } + + return nil +} + +// Add returns faiss ntotal. +func (f *faiss) Add(nb int, xb []float32, xids []int64) (int, error) { + dim := int(f.dimension) + if len(xb) != dim*nb || len(xb) != dim*len(xids) { + return -1, errors.ErrIncompatibleDimensionSize(len(xb)/nb, dim) + } + + f.mu.Lock() + ntotal := int(C.faiss_add(f.st, (C.int)(nb), (*C.float)(&xb[0]), (*C.long)(&xids[0]))) + f.mu.Unlock() + if ntotal < 0 { + return ntotal, errors.NewFaissError("failed to faiss_add") + } + + return ntotal, nil +} + +// Search returns search result as []SearchResult. 
+func (f *faiss) Search(k, nq int, xq []float32) ([]SearchResult, error) { + if len(xq) != nq*int(f.dimension) { + return nil, errors.ErrIncompatibleDimensionSize(len(xq), int(f.dimension)) + } + + I := make([]int64, k*nq) + D := make([]float32, k*nq) + f.mu.RLock() + ret := C.faiss_search(f.st, (C.int)(k), (C.int)(nq), (*C.float)(&xq[0]), (*C.long)(&I[0]), (*C.float)(&D[0])) + f.mu.RUnlock() + if ret == ErrorCode { + return nil, errors.NewFaissError("failed to faiss_search") + } + + if len(I) == 0 || len(D) == 0 { + return nil, errors.ErrEmptySearchResult + } + + result := make([]SearchResult, k) + for i := range result { + result[i] = SearchResult{uint32(I[i]), D[i], nil} + } + + return result, nil +} + +// Remove removes from faiss index. +func (f *faiss) Remove(size int, ids []int64) (int, error) { + f.mu.Lock() + ntotal := int(C.faiss_remove(f.st, (C.int)(size), (*C.long)(&ids[0]))) + f.mu.Unlock() + if ntotal < 0 { + return ntotal, errors.NewFaissError("failed to faiss_remove") + } + + return ntotal, nil +} + +// Close faiss index. +func (f *faiss) Close() { + if f.st != nil { + C.faiss_free(f.st) + f.st = nil + } +} diff --git a/internal/core/algorithm/faiss/option.go b/internal/core/algorithm/faiss/option.go new file mode 100644 index 0000000000..40c9a3ad5a --- /dev/null +++ b/internal/core/algorithm/faiss/option.go @@ -0,0 +1,120 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package faiss provides implementation of Go API for https://github.com/facebookresearch/faiss +package faiss + +import "C" + +import ( + "strings" + + "github.com/vdaas/vald/internal/core/algorithm" + "github.com/vdaas/vald/internal/errors" +) + +// Option represents the functional option for faiss. +type Option func(*faiss) error + +var defaultOptions = []Option{ + WithDimension(64), + WithNlist(100), + WithM(8), + WithNbitsPerIdx(8), + WithMetricType("l2"), +} + +// WithDimension represents the option to set the dimension for faiss. +func WithDimension(dim int) Option { + return func(f *faiss) error { + if dim > algorithm.MaximumVectorDimensionSize || dim < algorithm.MinimumVectorDimensionSize { + err := errors.ErrInvalidDimensionSize(dim, algorithm.MaximumVectorDimensionSize) + return errors.NewErrCriticalOption("dimension", dim, err) + } + + f.dimension = (C.int)(dim) + return nil + } +} + +// WithNlist represents the option to set the nlist for faiss. +func WithNlist(nlist int) Option { + return func(f *faiss) error { + if nlist <= 0 { + return errors.NewErrInvalidOption("nlist", nlist) + } + + f.nlist = (C.int)(nlist) + return nil + } +} + +// WithM represents the option to set the m for faiss. +func WithM(m int) Option { + return func(f *faiss) error { + if m <= 0 || int(f.dimension)%m != 0 { + return errors.NewErrInvalidOption("m", m) + } + + f.m = (C.int)(m) + return nil + } +} + +// WithNbitsPerIdx represents the option to set the n bits per index for faiss. 
+func WithNbitsPerIdx(nbitsPerIdx int) Option { + return func(f *faiss) error { + if nbitsPerIdx <= 0 { + return errors.NewErrInvalidOption("nbitsPerIdx", nbitsPerIdx) + } + + f.nbitsPerIdx = (C.int)(nbitsPerIdx) + return nil + } +} + +// WithMetricType represents the option to set the metric type for faiss. +func WithMetricType(metricType string) Option { + return func(f *faiss) error { + if len(metricType) == 0 { + return errors.NewErrIgnoredOption("metricType") + } + + switch strings.NewReplacer("-", "", "_", "", " ", "").Replace(strings.ToLower(metricType)) { + case "innerproduct": + f.metricType = InnerProduct + case "l2": + f.metricType = L2 + default: + err := errors.ErrUnsupportedDistanceType + return errors.NewErrCriticalOption("metricType", metricType, err) + } + + return nil + } +} + +// WithIndexPath represents the option to set the index path for faiss. +func WithIndexPath(idxPath string) Option { + return func(f *faiss) error { + if len(idxPath) == 0 { + return errors.NewErrIgnoredOption("indexPath") + } + + f.idxPath = idxPath + return nil + } +} diff --git a/internal/errors/faiss.go b/internal/errors/faiss.go new file mode 100644 index 0000000000..735942079c --- /dev/null +++ b/internal/errors/faiss.go @@ -0,0 +1,32 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package errors provides error types and function +package errors + +type FaissError struct { + Msg string +} + +func NewFaissError(msg string) error { + return FaissError{ + Msg: msg, + } +} + +func (f FaissError) Error() string { + return f.Msg +} diff --git a/internal/observability/metrics/agent/core/faiss/faiss.go b/internal/observability/metrics/agent/core/faiss/faiss.go new file mode 100644 index 0000000000..fb0a68c34c --- /dev/null +++ b/internal/observability/metrics/agent/core/faiss/faiss.go @@ -0,0 +1,267 @@ +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package faiss + +import ( + "context" + + "github.com/vdaas/vald/internal/observability/metrics" + "github.com/vdaas/vald/pkg/agent/core/faiss/service" + api "go.opentelemetry.io/otel/metric" + view "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/aggregation" +) + +const ( + indexCountMetricsName = "agent_core_faiss_index_count" + indexCountMetricsDescription = "Agent Faiss index count" + + uncommittedIndexCountMetricsName = "agent_core_faiss_uncommitted_index_count" + uncommittedIndexCountMetricsDescription = "Agent Faiss index count" + + insertVQueueCountMetricsName = "agent_core_faiss_insert_vqueue_count" + insertVQueueCountMetricsDescription = "Agent Faiss insert vqueue count" + + deleteVQueueCountMetricsName = "agent_core_faiss_delete_vqueue_count" + deleteVQueueCountMetricsDescription = "Agent Faiss delete vqueue count" + + completedCreateIndexTotalMetricsName = "agent_core_faiss_completed_create_index_total" + completedCreateIndexTotalMetricsDescription = "The cumulative count of completed create index execution" + + executedProactiveGCTotalMetricsName = "agent_core_faiss_executed_proactive_gc_total" + executedProactiveGCTotalMetricsDescription = "The cumulative count of proactive GC execution" + + isIndexingMetricsName = "agent_core_faiss_is_indexing" + isIndexingMetricsDescription = "Currently indexing or no" + + isSavingMetricsName = "agent_core_faiss_is_saving" + isSavingMetricsDescription = "Currently saving or not" + + trainCountMetricsName = "agent_core_faiss_train_count" + trainCountMetricsDescription = "Agent Faiss train count" +) + +type faissMetrics struct { + faiss service.Faiss +} + +func New(f service.Faiss) metrics.Metric { + return &faissMetrics{ + faiss: f, + } +} + +func (f *faissMetrics) View() ([]metrics.View, error) { + return []metrics.View{ + view.NewView( + view.Instrument{ + Name: indexCountMetricsName, + Description: indexCountMetricsDescription, + }, + view.Stream{ + Aggregation: aggregation.LastValue{}, + }, + ), + view.NewView( + view.Instrument{ + Name: uncommittedIndexCountMetricsName, + Description: uncommittedIndexCountMetricsDescription, + }, + view.Stream{ + Aggregation: aggregation.LastValue{}, + }, + ), + view.NewView( + view.Instrument{ + Name: insertVQueueCountMetricsName, + Description: insertVQueueCountMetricsDescription, + }, + view.Stream{ + Aggregation: aggregation.LastValue{}, + }, + ), + view.NewView( + view.Instrument{ + Name: deleteVQueueCountMetricsName, + Description: deleteVQueueCountMetricsDescription, + }, + view.Stream{ + Aggregation: aggregation.LastValue{}, + }, + ), + view.NewView( + view.Instrument{ + Name: completedCreateIndexTotalMetricsName, + Description: completedCreateIndexTotalMetricsDescription, + }, + view.Stream{ + Aggregation: aggregation.LastValue{}, + }, + ), + view.NewView( + view.Instrument{ + Name: executedProactiveGCTotalMetricsName, + Description: executedProactiveGCTotalMetricsDescription, + }, + view.Stream{ + Aggregation: aggregation.LastValue{}, + }, + ), + view.NewView( + view.Instrument{ + Name: isIndexingMetricsName, + Description: isIndexingMetricsDescription, + }, + view.Stream{ + Aggregation: aggregation.LastValue{}, + }, + ), + view.NewView( + view.Instrument{ + Name: isSavingMetricsName, + Description: isSavingMetricsDescription, + }, + view.Stream{ + Aggregation: aggregation.LastValue{}, + }, + ), + view.NewView( + view.Instrument{ + Name: trainCountMetricsName, + Description: trainCountMetricsDescription, + }, + view.Stream{ + Aggregation: aggregation.LastValue{}, 
+ }, + ), + }, nil +} + +func (f *faissMetrics) Register(m metrics.Meter) error { + indexCount, err := m.Int64ObservableGauge( + indexCountMetricsName, + metrics.WithDescription(indexCountMetricsDescription), + metrics.WithUnit(metrics.Dimensionless), + ) + if err != nil { + return err + } + + uncommittedIndexCount, err := m.Int64ObservableGauge( + uncommittedIndexCountMetricsName, + metrics.WithDescription(uncommittedIndexCountMetricsDescription), + metrics.WithUnit(metrics.Dimensionless), + ) + if err != nil { + return err + } + + insertVQueueCount, err := m.Int64ObservableGauge( + insertVQueueCountMetricsName, + metrics.WithDescription(insertVQueueCountMetricsDescription), + metrics.WithUnit(metrics.Dimensionless), + ) + if err != nil { + return err + } + + deleteVQueueCount, err := m.Int64ObservableGauge( + deleteVQueueCountMetricsName, + metrics.WithDescription(deleteVQueueCountMetricsDescription), + metrics.WithUnit(metrics.Dimensionless), + ) + if err != nil { + return err + } + + completedCreateIndexTotal, err := m.Int64ObservableGauge( + completedCreateIndexTotalMetricsName, + metrics.WithDescription(completedCreateIndexTotalMetricsDescription), + metrics.WithUnit(metrics.Dimensionless), + ) + if err != nil { + return err + } + + executedProactiveGCTotal, err := m.Int64ObservableGauge( + executedProactiveGCTotalMetricsName, + metrics.WithDescription(executedProactiveGCTotalMetricsDescription), + metrics.WithUnit(metrics.Dimensionless), + ) + if err != nil { + return err + } + + isIndexing, err := m.Int64ObservableGauge( + isIndexingMetricsName, + metrics.WithDescription(isIndexingMetricsDescription), + metrics.WithUnit(metrics.Dimensionless), + ) + if err != nil { + return err + } + + isSaving, err := m.Int64ObservableGauge( + isSavingMetricsName, + metrics.WithDescription(isSavingMetricsDescription), + metrics.WithUnit(metrics.Dimensionless), + ) + if err != nil { + return err + } + + trainCount, err := m.Int64ObservableGauge( + trainCountMetricsName, + metrics.WithDescription(trainCountMetricsDescription), + metrics.WithUnit(metrics.Dimensionless), + ) + if err != nil { + return err + } + + _, err = m.RegisterCallback( + func(_ context.Context, o api.Observer) error { + var indexing int64 + if f.faiss.IsIndexing() { + indexing = 1 + } + var saving int64 + if f.faiss.IsSaving() { + saving = 1 + } + + o.ObserveInt64(indexCount, int64(f.faiss.Len())) + o.ObserveInt64(uncommittedIndexCount, int64(f.faiss.InsertVQueueBufferLen()+f.faiss.DeleteVQueueBufferLen())) + o.ObserveInt64(insertVQueueCount, int64(f.faiss.InsertVQueueBufferLen())) + o.ObserveInt64(deleteVQueueCount, int64(int64(f.faiss.DeleteVQueueBufferLen()))) + o.ObserveInt64(completedCreateIndexTotal, int64(f.faiss.NumberOfCreateIndexExecution())) + o.ObserveInt64(executedProactiveGCTotal, int64(f.faiss.NumberOfProactiveGCExecution())) + o.ObserveInt64(isIndexing, int64(indexing)) + o.ObserveInt64(isSaving, int64(saving)) + o.ObserveInt64(trainCount, int64(f.faiss.GetTrainSize())) + + return nil + }, + indexCount, + uncommittedIndexCount, + insertVQueueCount, + deleteVQueueCount, + completedCreateIndexTotal, + executedProactiveGCTotal, + isIndexing, + isSaving, + trainCount, + ) + return err +} diff --git a/pkg/agent/core/faiss/config/config.go b/pkg/agent/core/faiss/config/config.go new file mode 100644 index 0000000000..8952abd715 --- /dev/null +++ b/pkg/agent/core/faiss/config/config.go @@ -0,0 +1,77 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package setting stores all server application settings +package config + +import ( + "github.com/vdaas/vald/internal/config" + "github.com/vdaas/vald/internal/errors" +) + +// GlobalConfig is type alias for config.GlobalConfig. +type GlobalConfig = config.GlobalConfig + +// Data represent a application setting data content (config.yaml). +// In K8s environment, this configuration is stored in K8s ConfigMap. +type Data struct { + GlobalConfig `json:",inline" yaml:",inline"` + + // Server represent all server configurations + Server *config.Servers `json:"server_config" yaml:"server_config"` + + // Observability represent observability configurations + Observability *config.Observability `json:"observability" yaml:"observability"` + + // Faiss represent faiss core configuration + Faiss *config.Faiss `json:"faiss" yaml:"faiss"` +} + +// NewConfig returns the Data struct or error from the given file path. +func NewConfig(path string) (cfg *Data, err error) { + cfg = new(Data) + + err = config.Read(path, &cfg) + if err != nil { + return nil, err + } + + if cfg != nil { + cfg.Bind() + } else { + return nil, errors.ErrInvalidConfig + } + + if cfg.Server != nil { + cfg.Server = cfg.Server.Bind() + } else { + return nil, errors.ErrInvalidConfig + } + + if cfg.Observability != nil { + cfg.Observability = cfg.Observability.Bind() + } else { + cfg.Observability = new(config.Observability).Bind() + } + + if cfg.Faiss != nil { + cfg.Faiss = cfg.Faiss.Bind() + } else { + return nil, errors.ErrInvalidConfig + } + + return cfg, nil +} diff --git a/pkg/agent/core/faiss/handler/doc.go b/pkg/agent/core/faiss/handler/doc.go new file mode 100644 index 0000000000..825bcc7f61 --- /dev/null +++ b/pkg/agent/core/faiss/handler/doc.go @@ -0,0 +1,17 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler diff --git a/pkg/agent/core/faiss/handler/grpc/handler.go b/pkg/agent/core/faiss/handler/grpc/handler.go new file mode 100644 index 0000000000..89debe0b09 --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/handler.go @@ -0,0 +1,95 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package grpc provides grpc server logic +package grpc + +import ( + "reflect" + + agent "github.com/vdaas/vald/apis/grpc/v1/agent/core" + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/pkg/agent/core/faiss/service" +) + +type Server interface { + agent.AgentServer + vald.Server +} + +type server struct { + name string + ip string + faiss service.Faiss + eg errgroup.Group + streamConcurrency int + agent.UnimplementedAgentServer + vald.UnimplementedValdServer +} + +const ( + apiName = "vald/agent/core/faiss" + faissResourceType = "vald/internal/core/algorithm" +) + +var errFaiss = new(errors.FaissError) + +func New(opts ...Option) (Server, error) { + s := new(server) + + for _, opt := range append(defaultOptions, opts...) { + if err := opt(s); err != nil { + werr := errors.ErrOptionFailed(err, reflect.ValueOf(opt)) + + e := new(errors.ErrCriticalOption) + if errors.As(err, &e) { + log.Error(werr) + return nil, werr + } + log.Warn(werr) + } + } + return s, nil +} + +func (s *server) newLocations(uuids ...string) (locs *payload.Object_Locations) { + if len(uuids) == 0 { + return nil + } + locs = &payload.Object_Locations{ + Locations: make([]*payload.Object_Location, 0, len(uuids)), + } + for _, uuid := range uuids { + locs.Locations = append(locs.GetLocations(), &payload.Object_Location{ + Name: s.name, + Uuid: uuid, + Ips: []string{s.ip}, + }) + } + return locs +} + +func (s *server) newLocation(uuid string) *payload.Object_Location { + locs := s.newLocations(uuid) + if locs != nil && locs.GetLocations() != nil && len(locs.GetLocations()) > 0 { + return locs.Locations[0] + } + return nil +} diff --git a/pkg/agent/core/faiss/handler/grpc/index.go b/pkg/agent/core/faiss/handler/grpc/index.go new file mode 100644 index 0000000000..4e95aab65f --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/index.go @@ -0,0 +1,180 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
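For readers unfamiliar with the construction style used by `New` above: option functions may fail, and only failures wrapped as a critical option error abort construction, while the rest are logged as warnings and skipped. The following is a minimal, self-contained sketch of that flow; every name in it (`server`, `Option`, `errCriticalOption`, the default name) is illustrative and not part of this patch.

```go
package main

import (
	"errors"
	"fmt"
	"log"
)

// server stands in for the gRPC server struct built above.
type server struct {
	name string
	ip   string
}

// Option mirrors the functional-option signature used in the handler package.
type Option func(*server) error

// errCriticalOption marks option failures that must abort construction.
type errCriticalOption struct{ err error }

func (e *errCriticalOption) Error() string { return "critical option failed: " + e.err.Error() }

// WithName is non-critical here: an empty name is reported but tolerated.
func WithName(name string) Option {
	return func(s *server) error {
		if name == "" {
			return errors.New("invalid option: empty name")
		}
		s.name = name
		return nil
	}
}

// WithIP is treated as critical: construction fails without a usable IP.
func WithIP(ip string) Option {
	return func(s *server) error {
		if ip == "" {
			return &errCriticalOption{err: errors.New("invalid option: empty ip")}
		}
		s.ip = ip
		return nil
	}
}

// New applies defaults first, then caller options; critical failures are
// returned, non-critical ones are only logged, matching the flow above.
func New(opts ...Option) (*server, error) {
	s := new(server)
	for _, opt := range append([]Option{WithName("agent-faiss")}, opts...) {
		if err := opt(s); err != nil {
			var ce *errCriticalOption
			if errors.As(err, &ce) {
				return nil, err
			}
			log.Println("warn:", err)
		}
	}
	return s, nil
}

func main() {
	s, err := New(WithIP("127.0.0.1"), WithName("")) // empty name only warns
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(s.name, s.ip) // agent-faiss 127.0.0.1
}
```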
+// + +package grpc + +import ( + "context" + "fmt" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/info" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net/grpc/errdetails" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/internal/observability/trace" +) + +func (s *server) CreateIndex(ctx context.Context, c *payload.Control_CreateIndexRequest) (res *payload.Empty, err error) { + ctx, span := trace.StartSpan(ctx, apiName+".CreateIndex") + defer func() { + if span != nil { + span.End() + } + }() + res = new(payload.Empty) + err = s.faiss.CreateIndex(ctx) + if err != nil { + if errors.Is(err, errors.ErrUncommittedIndexNotFound) { + err = status.WrapWithFailedPrecondition(fmt.Sprintf("CreateIndex API failed"), err, + &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(c), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.CreateIndex", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + &errdetails.PreconditionFailure{ + Violations: []*errdetails.PreconditionFailureViolation{ + { + Type: "uncommitted index is empty", + Subject: "failed to CreateIndex operation caused by empty uncommitted indices", + }, + }, + }, info.Get()) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeFailedPrecondition(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + log.Error(err) + err = status.WrapWithInternal(fmt.Sprintf("CreateIndex API failed"), err, + &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(c), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.CreateIndex", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, info.Get()) + log.Error(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) SaveIndex(ctx context.Context, _ *payload.Empty) (res *payload.Empty, err error) { + ctx, span := trace.StartSpan(ctx, apiName+".SaveIndex") + defer func() { + if span != nil { + span.End() + } + }() + res = new(payload.Empty) + err = s.faiss.SaveIndex(ctx) + if err != nil { + log.Error(err) + err = status.WrapWithInternal("SaveIndex API failed to save indices", err, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.SaveIndex", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, info.Get()) + log.Error(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) 
+ span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) CreateAndSaveIndex(ctx context.Context, c *payload.Control_CreateIndexRequest) (res *payload.Empty, err error) { + ctx, span := trace.StartSpan(ctx, apiName+".CreateAndSaveIndex") + defer func() { + if span != nil { + span.End() + } + }() + res = new(payload.Empty) + err = s.faiss.CreateAndSaveIndex(ctx) + if err != nil { + if errors.Is(err, errors.ErrUncommittedIndexNotFound) { + err = status.WrapWithFailedPrecondition(fmt.Sprintf("CreateAndSaveIndex API failed to create indexes pool_size = %d", c.GetPoolSize()), err, + &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(c), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.CreateAndSaveIndex", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + &errdetails.PreconditionFailure{ + Violations: []*errdetails.PreconditionFailureViolation{ + { + Type: "uncommitted index is empty", + Subject: "failed to CreateAndSaveIndex operation caused by empty uncommitted indices", + }, + }, + }, info.Get()) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeFailedPrecondition(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + err = status.WrapWithInternal(fmt.Sprintf("CreateAndSaveIndex API failed to create indexes pool_size = %d", c.GetPoolSize()), err, + &errdetails.RequestInfo{ + ServingData: errdetails.Serialize(c), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.CreateAndSaveIndex", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, info.Get()) + log.Error(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInternal(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return res, nil +} + +func (s *server) IndexInfo(ctx context.Context, c *payload.Empty) (res *payload.Info_Index_Count, err error) { + _, span := trace.StartSpan(ctx, apiName+".IndexInfo") + defer func() { + if span != nil { + span.End() + } + }() + + return &payload.Info_Index_Count{ + Stored: uint32(s.faiss.Len()), + Uncommitted: uint32(s.faiss.InsertVQueueBufferLen() + s.faiss.DeleteVQueueBufferLen()), + Indexing: s.faiss.IsIndexing(), + Saving: s.faiss.IsSaving(), + }, nil +} diff --git a/pkg/agent/core/faiss/handler/grpc/insert.go b/pkg/agent/core/faiss/handler/grpc/insert.go new file mode 100644 index 0000000000..434b1c15d7 --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/insert.go @@ -0,0 +1,141 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
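As a rough usage sketch, the index RPCs implemented above can be driven from the generated agent client in `apis/grpc/v1/agent/core`, whose server side this patch implements. The target address, timeout, and pool size below are placeholders, and plaintext credentials are used only for brevity.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	agent "github.com/vdaas/vald/apis/grpc/v1/agent/core"
	"github.com/vdaas/vald/apis/grpc/v1/payload"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// The address is a placeholder for a running vald-agent-faiss instance.
	conn, err := grpc.DialContext(ctx, "localhost:8081",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := agent.NewAgentClient(conn)

	// Build and persist the index in one call; PoolSize is illustrative.
	if _, err := client.CreateAndSaveIndex(ctx, &payload.Control_CreateIndexRequest{PoolSize: 10000}); err != nil {
		log.Fatal(err)
	}

	// IndexInfo reports the stored/uncommitted counts exposed by the handler above.
	info, err := client.IndexInfo(ctx, &payload.Empty{})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("stored:", info.GetStored(), "uncommitted:", info.GetUncommitted())
}
```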
+// + +package grpc + +import ( + "context" + "fmt" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/info" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net/grpc/codes" + "github.com/vdaas/vald/internal/net/grpc/errdetails" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/internal/observability/trace" + "go.opentelemetry.io/otel/attribute" +) + +func (s *server) Insert(ctx context.Context, req *payload.Insert_Request) (res *payload.Object_Location, err error) { + _, span := trace.StartSpan(ctx, apiName+"/"+vald.InsertRPCName) + defer func() { + if span != nil { + span.End() + } + }() + vec := req.GetVector() + if len(vec.GetVector()) != s.faiss.GetDimensionSize() { + err = errors.ErrIncompatibleDimensionSize(len(vec.GetVector()), int(s.faiss.GetDimensionSize())) + err = status.WrapWithInvalidArgument("Insert API Incompatible Dimension Size detected", + err, + &errdetails.RequestInfo{ + RequestId: vec.GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "vector dimension size", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Insert", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInvalidArgument(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + err = s.faiss.InsertWithTime(vec.GetId(), vec.GetVector(), req.GetConfig().GetTimestamp()) + if err != nil { + var attrs []attribute.KeyValue + + if errors.Is(err, errors.ErrUUIDAlreadyExists(vec.GetId())) { + err = status.WrapWithAlreadyExists(fmt.Sprintf("Insert API uuid %s already exists", vec.GetId()), err, + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Insert", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + attrs = trace.StatusCodeAlreadyExists(err.Error()) + } else if errors.Is(err, errors.ErrUUIDNotFound(0)) { + err = status.WrapWithInvalidArgument(fmt.Sprintf("Insert API empty uuid \"%s\" was given", vec.GetId()), err, + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "uuid", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Insert", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + attrs = trace.StatusCodeInvalidArgument(err.Error()) + } else { + var ( + st *status.Status + msg string + ) + st, msg, err = status.ParseError(err, codes.Internal, + "failed to parse Insert gRPC error response", + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Insert", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, info.Get()) + attrs = trace.FromGRPCStatus(st.Code(), msg) + } + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) 
+ span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return s.newLocation(vec.GetId()), nil +} + +func (s *server) StreamInsert(stream vald.Insert_StreamInsertServer) (err error) { + return s.UnimplementedValdServer.UnimplementedInsertServer.StreamInsert(stream) +} + +func (s *server) MultiInsert(ctx context.Context, reqs *payload.Insert_MultiRequest) (res *payload.Object_Locations, err error) { + return s.UnimplementedValdServer.UnimplementedInsertServer.MultiInsert(ctx, reqs) +} diff --git a/pkg/agent/core/faiss/handler/grpc/linear_search.go b/pkg/agent/core/faiss/handler/grpc/linear_search.go new file mode 100644 index 0000000000..3bdf2515d8 --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/linear_search.go @@ -0,0 +1,48 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package grpc + +import ( + "context" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" +) + +func (s *server) LinearSearch(ctx context.Context, req *payload.Search_Request) (res *payload.Search_Response, err error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.LinearSearch(ctx, req) +} + +func (s *server) LinearSearchByID(ctx context.Context, req *payload.Search_IDRequest) (res *payload.Search_Response, err error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.LinearSearchByID(ctx, req) +} + +func (s *server) StreamLinearSearch(stream vald.Search_StreamLinearSearchServer) (err error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.StreamLinearSearch(stream) +} + +func (s *server) StreamLinearSearchByID(stream vald.Search_StreamLinearSearchByIDServer) (err error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.StreamLinearSearchByID(stream) +} + +func (s *server) MultiLinearSearch(ctx context.Context, reqs *payload.Search_MultiRequest) (res *payload.Search_Responses, errs error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.MultiLinearSearch(ctx, reqs) +} + +func (s *server) MultiLinearSearchByID(ctx context.Context, reqs *payload.Search_MultiIDRequest) (res *payload.Search_Responses, errs error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.MultiLinearSearchByID(ctx, reqs) +} diff --git a/pkg/agent/core/faiss/handler/grpc/object.go b/pkg/agent/core/faiss/handler/grpc/object.go new file mode 100644 index 0000000000..214016d9eb --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/object.go @@ -0,0 +1,96 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package grpc + +import ( + "context" + "fmt" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net/grpc/errdetails" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/internal/observability/trace" +) + +func (s *server) Exists(ctx context.Context, uid *payload.Object_ID) (res *payload.Object_ID, err error) { + _, span := trace.StartSpan(ctx, apiName+"/"+vald.ExistsRPCName) + defer func() { + if span != nil { + span.End() + } + }() + uuid := uid.GetId() + if len(uuid) == 0 { + err = errors.ErrInvalidUUID(uuid) + err = status.WrapWithInvalidArgument(fmt.Sprintf("Exists API invalid argument for uuid \"%s\" detected", uuid), err, + &errdetails.RequestInfo{ + RequestId: uuid, + ServingData: errdetails.Serialize(uid), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "uuid", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Exists", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInvalidArgument(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + log.Warn(err) + return nil, err + } + if _, ok := s.faiss.Exists(uuid); !ok { + err = errors.ErrObjectIDNotFound(uid.GetId()) + err = status.WrapWithNotFound(fmt.Sprintf("Exists API meta %s's uuid not found", uid.GetId()), err, + &errdetails.RequestInfo{ + RequestId: uid.GetId(), + ServingData: errdetails.Serialize(uid), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Exists", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, + uid.GetId()) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeNotFound(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return uid, nil +} + +func (s *server) GetObject(ctx context.Context, id *payload.Object_VectorRequest) (res *payload.Object_Vector, err error) { + return s.UnimplementedValdServer.UnimplementedObjectServer.GetObject(ctx, id) +} + +func (s *server) StreamGetObject(stream vald.Object_StreamGetObjectServer) (err error) { + return s.UnimplementedValdServer.UnimplementedObjectServer.StreamGetObject(stream) +} diff --git a/pkg/agent/core/faiss/handler/grpc/option.go b/pkg/agent/core/faiss/handler/grpc/option.go new file mode 100644 index 0000000000..2b443b1cfe --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/option.go @@ -0,0 +1,100 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package grpc provides grpc server logic +package grpc + +import ( + "os" + "runtime" + + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net" + "github.com/vdaas/vald/pkg/agent/core/faiss/service" +) + +// Option represents the functional option for server. +type Option func(*server) error + +var defaultOptions = []Option{ + WithName(func() string { + name, err := os.Hostname() + if err != nil { + log.Warn(err) + } + return name + }()), + WithIP(net.LoadLocalIP()), + WithStreamConcurrency(runtime.GOMAXPROCS(-1) * 10), + WithErrGroup(errgroup.Get()), +} + +// WithIP returns the option to set the IP for server. +func WithIP(ip string) Option { + return func(s *server) error { + if len(ip) == 0 { + return errors.NewErrInvalidOption("ip", ip) + } + s.ip = ip + return nil + } +} + +// WithName returns the option to set the name for server. +func WithName(name string) Option { + return func(s *server) error { + if len(name) == 0 { + return errors.NewErrInvalidOption("name", name) + } + s.name = name + return nil + } +} + +// WithFaiss returns the option to set the Faiss service for server. +func WithFaiss(f service.Faiss) Option { + return func(s *server) error { + if f == nil { + return errors.NewErrInvalidOption("faiss", f) + } + s.faiss = f + return nil + } +} + +// WithStreamConcurrency returns the option to set the stream concurrency for server. +func WithStreamConcurrency(c int) Option { + return func(s *server) error { + if c <= 0 { + return errors.NewErrInvalidOption("streamConcurrency", c) + } + s.streamConcurrency = c + return nil + } +} + +// WithErrGroup returns the option to set the error group for server. +func WithErrGroup(eg errgroup.Group) Option { + return func(s *server) error { + if eg == nil { + return errors.NewErrInvalidOption("errGroup", eg) + } + s.eg = eg + return nil + } +} diff --git a/pkg/agent/core/faiss/handler/grpc/remove.go b/pkg/agent/core/faiss/handler/grpc/remove.go new file mode 100644 index 0000000000..1d7342c488 --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/remove.go @@ -0,0 +1,134 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
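The `Exists` handler above answers with `codes.NotFound` when the uuid has not been indexed, so callers typically branch on the returned status code rather than treating every error as fatal. A hedged sketch of that check follows; the endpoint and id are placeholders, and the generated Object client in `apis/grpc/v1/vald` is assumed.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/vdaas/vald/apis/grpc/v1/payload"
	"github.com/vdaas/vald/apis/grpc/v1/vald"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
)

func main() {
	conn, err := grpc.Dial("localhost:8081",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// Object_ID carries only the uuid; the id used here is a placeholder.
	res, err := vald.NewObjectClient(conn).Exists(context.Background(), &payload.Object_ID{Id: "vector-0001"})
	switch status.Code(err) {
	case codes.OK:
		fmt.Println("indexed:", res.GetId())
	case codes.NotFound:
		fmt.Println("not indexed yet")
	default:
		log.Fatal(err)
	}
}
```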
+// + +package grpc + +import ( + "context" + "fmt" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/info" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net/grpc/errdetails" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/internal/observability/trace" + "go.opentelemetry.io/otel/attribute" +) + +func (s *server) Remove(ctx context.Context, req *payload.Remove_Request) (res *payload.Object_Location, err error) { + _, span := trace.StartSpan(ctx, apiName+"/"+vald.RemoveRPCName) + defer func() { + if span != nil { + span.End() + } + }() + id := req.GetId() + uuid := id.GetId() + if len(uuid) == 0 { + err = errors.ErrInvalidUUID(uuid) + err = status.WrapWithInvalidArgument(fmt.Sprintf("Remove API invalid argument for uuid \"%s\" detected", uuid), err, + &errdetails.RequestInfo{ + RequestId: uuid, + ServingData: errdetails.Serialize(req), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "uuid", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Remove", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInvalidArgument(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + log.Warn(err) + return nil, err + } + err = s.faiss.DeleteWithTime(uuid, req.GetConfig().GetTimestamp()) + if err != nil { + var attrs []attribute.KeyValue + if errors.Is(err, errors.ErrObjectIDNotFound(uuid)) { + err = status.WrapWithNotFound(fmt.Sprintf("Remove API uuid %s not found", uuid), err, + &errdetails.RequestInfo{ + RequestId: uuid, + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Remove", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + attrs = trace.StatusCodeNotFound(err.Error()) + } else if errors.Is(err, errors.ErrUUIDNotFound(0)) { + err = status.WrapWithInvalidArgument(fmt.Sprintf("Remove API invalid argument for uuid \"%s\" detected", uuid), err, + &errdetails.RequestInfo{ + RequestId: uuid, + ServingData: errdetails.Serialize(req), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "uuid", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Remove", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + attrs = trace.StatusCodeInvalidArgument(err.Error()) + } else { + err = status.WrapWithInternal("Remove API failed", err, + &errdetails.RequestInfo{ + RequestId: uuid, + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Remove", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, info.Get()) + log.Error(err) + attrs = trace.StatusCodeInternal(err.Error()) + } + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) 
+ span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + return s.newLocation(uuid), nil +} + +func (s *server) StreamRemove(stream vald.Remove_StreamRemoveServer) (err error) { + return s.UnimplementedValdServer.UnimplementedRemoveServer.StreamRemove(stream) +} + +func (s *server) MultiRemove(ctx context.Context, reqs *payload.Remove_MultiRequest) (res *payload.Object_Locations, err error) { + return s.UnimplementedValdServer.UnimplementedRemoveServer.MultiRemove(ctx, reqs) +} diff --git a/pkg/agent/core/faiss/handler/grpc/search.go b/pkg/agent/core/faiss/handler/grpc/search.go new file mode 100644 index 0000000000..46aac1e55f --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/search.go @@ -0,0 +1,196 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package grpc + +import ( + "context" + "fmt" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/info" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net/grpc/errdetails" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/internal/observability/trace" + "github.com/vdaas/vald/pkg/agent/core/faiss/model" + "go.opentelemetry.io/otel/attribute" +) + +func (s *server) Search(ctx context.Context, req *payload.Search_Request) (res *payload.Search_Response, err error) { + _, span := trace.StartSpan(ctx, apiName+"/"+vald.SearchRPCName) + defer func() { + if span != nil { + span.End() + } + }() + if len(req.GetVector()) != s.faiss.GetDimensionSize() { + err = errors.ErrIncompatibleDimensionSize(len(req.GetVector()), int(s.faiss.GetDimensionSize())) + err = status.WrapWithInvalidArgument("Search API Incompatible Dimension Size detected", + err, + &errdetails.RequestInfo{ + RequestId: req.GetConfig().GetRequestId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "vector dimension size", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Search", + }) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInvalidArgument(err.Error())...) 
+			span.SetStatus(trace.StatusError, err.Error())
+		}
+		return nil, err
+	}
+	res, err = toSearchResponse(
+		s.faiss.Search(
+			req.GetConfig().GetNum(),
+			1,
+			req.GetVector()))
+	if err != nil || res == nil {
+		var attrs []attribute.KeyValue
+		switch {
+		case errors.Is(err, errors.ErrCreateIndexingIsInProgress):
+			err = status.WrapWithAborted("Search API aborted to process search request due to creating indices is in progress", err,
+				&errdetails.RequestInfo{
+					RequestId:   req.GetConfig().GetRequestId(),
+					ServingData: errdetails.Serialize(req),
+				},
+				&errdetails.ResourceInfo{
+					ResourceType: faissResourceType + "/faiss.Search",
+					ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
+				})
+			log.Debug(err)
+			attrs = trace.StatusCodeAborted(err.Error())
+		case errors.Is(err, errors.ErrEmptySearchResult),
+			err == nil && res == nil,
+			0 < req.GetConfig().GetMinNum() && len(res.GetResults()) < int(req.GetConfig().GetMinNum()):
+			err = status.WrapWithNotFound(fmt.Sprintf("Search API requestID %s's search result not found", req.GetConfig().GetRequestId()), err,
+				&errdetails.RequestInfo{
+					RequestId:   req.GetConfig().GetRequestId(),
+					ServingData: errdetails.Serialize(req),
+				},
+				&errdetails.ResourceInfo{
+					ResourceType: faissResourceType + "/faiss.Search",
+					ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
+				})
+			log.Debug(err)
+			attrs = trace.StatusCodeNotFound(err.Error())
+		case errors.As(err, &errFaiss):
+			log.Errorf("faiss core process returned error: %v", err)
+			err = status.WrapWithInternal("Search API failed to process search request due to faiss core process returned error", err,
+				&errdetails.RequestInfo{
+					RequestId:   req.GetConfig().GetRequestId(),
+					ServingData: errdetails.Serialize(req),
+				},
+				&errdetails.ResourceInfo{
+					ResourceType: faissResourceType + "/faiss.Search/core.faiss",
+					ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
+				}, info.Get())
+			log.Error(err)
+			attrs = trace.StatusCodeInternal(err.Error())
+		case errors.Is(err, errors.ErrIncompatibleDimensionSize(len(req.GetVector()), int(s.faiss.GetDimensionSize()))):
+			err = status.WrapWithInvalidArgument("Search API Incompatible Dimension Size detected",
+				err,
+				&errdetails.RequestInfo{
+					RequestId:   req.GetConfig().GetRequestId(),
+					ServingData: errdetails.Serialize(req),
+				},
+				&errdetails.BadRequest{
+					FieldViolations: []*errdetails.BadRequestFieldViolation{
+						{
+							Field:       "vector dimension size",
+							Description: err.Error(),
+						},
+					},
+				},
+				&errdetails.ResourceInfo{
+					ResourceType: faissResourceType + "/faiss.Search",
+				})
+			log.Warn(err)
+			attrs = trace.StatusCodeInvalidArgument(err.Error())
+		default:
+			err = status.WrapWithInternal("Search API failed to process search request", err,
+				&errdetails.RequestInfo{
+					RequestId:   req.GetConfig().GetRequestId(),
+					ServingData: errdetails.Serialize(req),
+				},
+				&errdetails.ResourceInfo{
+					ResourceType: faissResourceType + "/faiss.Search",
+					ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip),
+				}, info.Get())
+			log.Error(err)
+			attrs = trace.StatusCodeInternal(err.Error())
+		}
+		if span != nil {
+			span.RecordError(err)
+			span.SetAttributes(attrs...)
+ span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + res.RequestId = req.GetConfig().GetRequestId() + return res, nil +} + +func (s *server) SearchByID(ctx context.Context, req *payload.Search_IDRequest) (res *payload.Search_Response, err error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.SearchByID(ctx, req) +} + +func (s *server) StreamSearch(stream vald.Search_StreamSearchServer) (err error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.StreamSearch(stream) +} + +func (s *server) StreamSearchByID(stream vald.Search_StreamSearchByIDServer) (err error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.StreamSearchByID(stream) +} + +func (s *server) MultiSearch(ctx context.Context, reqs *payload.Search_MultiRequest) (res *payload.Search_Responses, errs error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.MultiSearch(ctx, reqs) +} + +func (s *server) MultiSearchByID(ctx context.Context, reqs *payload.Search_MultiIDRequest) (res *payload.Search_Responses, errs error) { + return s.UnimplementedValdServer.UnimplementedSearchServer.MultiSearchByID(ctx, reqs) +} + +func toSearchResponse(dists []model.Distance, err error) (res *payload.Search_Response, rerr error) { + if err != nil { + return nil, err + } + if len(dists) == 0 { + return nil, errors.ErrEmptySearchResult + } + res = new(payload.Search_Response) + res.Results = make([]*payload.Object_Distance, 0, len(dists)) + for _, dist := range dists { + res.Results = append(res.GetResults(), &payload.Object_Distance{ + Id: dist.ID, + Distance: dist.Distance, + }) + } + return res, nil +} diff --git a/pkg/agent/core/faiss/handler/grpc/update.go b/pkg/agent/core/faiss/handler/grpc/update.go new file mode 100644 index 0000000000..ccf88a6d27 --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/update.go @@ -0,0 +1,173 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
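To illustrate the request shape the `Search` handler above expects (and the NotFound result when nothing satisfies the configured minimum), here is a hedged client-side sketch. The endpoint, vector dimension, and k are placeholders, and the generated Search client in `apis/grpc/v1/vald` is assumed.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/vdaas/vald/apis/grpc/v1/payload"
	"github.com/vdaas/vald/apis/grpc/v1/vald"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	conn, err := grpc.Dial("localhost:8081",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// The query vector must match the agent's configured dimension (784 is a placeholder).
	query := make([]float32, 784)

	res, err := vald.NewSearchClient(conn).Search(context.Background(), &payload.Search_Request{
		Vector: query,
		Config: &payload.Search_Config{
			RequestId: "search-0001", // echoed back in the response by the handler above
			Num:       10,            // top-k passed through to faiss.Search
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	for _, d := range res.GetResults() {
		fmt.Println(d.GetId(), d.GetDistance())
	}
}
```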
+// + +package grpc + +import ( + "context" + "fmt" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/info" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/net/grpc/errdetails" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/internal/observability/trace" + "go.opentelemetry.io/otel/attribute" +) + +func (s *server) Update(ctx context.Context, req *payload.Update_Request) (res *payload.Object_Location, err error) { + _, span := trace.StartSpan(ctx, apiName+"/"+vald.UpdateRPCName) + defer func() { + if span != nil { + span.End() + } + }() + + vec := req.GetVector() + if len(vec.GetVector()) != s.faiss.GetDimensionSize() { + err = errors.ErrIncompatibleDimensionSize(len(vec.GetVector()), int(s.faiss.GetDimensionSize())) + err = status.WrapWithInvalidArgument("Update API Incompatible Dimension Size detected", + err, + &errdetails.RequestInfo{ + RequestId: vec.GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "vector dimension size", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Update", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + if span != nil { + span.RecordError(err) + span.SetAttributes(trace.StatusCodeInvalidArgument(err.Error())...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + uuid := vec.GetId() + if len(uuid) == 0 { + err = errors.ErrInvalidUUID(uuid) + err = status.WrapWithInvalidArgument(fmt.Sprintf("Update API invalid argument for uuid \"%s\" detected", uuid), err, + &errdetails.RequestInfo{ + RequestId: uuid, + ServingData: errdetails.Serialize(req), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "uuid", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Update", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + return nil, err + } + + err = s.faiss.UpdateWithTime(uuid, vec.GetVector(), req.GetConfig().GetTimestamp()) + if err != nil { + var attrs []attribute.KeyValue + if errors.Is(err, errors.ErrObjectIDNotFound(vec.GetId())) { + err = status.WrapWithNotFound(fmt.Sprintf("Update API uuid %s not found", vec.GetId()), err, + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Update", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + attrs = trace.StatusCodeNotFound(err.Error()) + } else if errors.Is(err, errors.ErrUUIDNotFound(0)) || errors.Is(err, errors.ErrInvalidDimensionSize(len(vec.GetVector()), s.faiss.GetDimensionSize())) { + err = status.WrapWithInvalidArgument(fmt.Sprintf("Update API invalid argument for uuid \"%s\" vec \"%v\" detected", vec.GetId(), vec.GetVector()), err, + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.BadRequest{ + FieldViolations: []*errdetails.BadRequestFieldViolation{ + { + Field: "uuid or vector", + Description: err.Error(), + }, + }, + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Update", + ResourceName: 
fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + attrs = trace.StatusCodeInvalidArgument(err.Error()) + } else if errors.Is(err, errors.ErrUUIDAlreadyExists(vec.GetId())) { + err = status.WrapWithAlreadyExists(fmt.Sprintf("Update API uuid %s's same data already exists", vec.GetId()), err, + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Update", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }) + log.Warn(err) + attrs = trace.StatusCodeAlreadyExists(err.Error()) + } else { + err = status.WrapWithInternal("Update API failed", err, + &errdetails.RequestInfo{ + RequestId: req.GetVector().GetId(), + ServingData: errdetails.Serialize(req), + }, + &errdetails.ResourceInfo{ + ResourceType: faissResourceType + "/faiss.Update", + ResourceName: fmt.Sprintf("%s: %s(%s)", apiName, s.name, s.ip), + }, info.Get()) + log.Error(err) + attrs = trace.StatusCodeInternal(err.Error()) + } + if span != nil { + span.RecordError(err) + span.SetAttributes(attrs...) + span.SetStatus(trace.StatusError, err.Error()) + } + return nil, err + } + + return s.newLocation(vec.GetId()), nil +} + +func (s *server) StreamUpdate(stream vald.Update_StreamUpdateServer) (err error) { + return s.UnimplementedValdServer.UnimplementedUpdateServer.StreamUpdate(stream) +} + +func (s *server) MultiUpdate(ctx context.Context, reqs *payload.Update_MultiRequest) (res *payload.Object_Locations, err error) { + return s.UnimplementedValdServer.UnimplementedUpdateServer.MultiUpdate(ctx, reqs) +} diff --git a/pkg/agent/core/faiss/handler/grpc/upsert.go b/pkg/agent/core/faiss/handler/grpc/upsert.go new file mode 100644 index 0000000000..a6ad7960ed --- /dev/null +++ b/pkg/agent/core/faiss/handler/grpc/upsert.go @@ -0,0 +1,36 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
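The Insert, Remove, and Update handlers above all follow the same shape: classify the service-layer error against known sentinels and wrap it with the matching gRPC status, falling back to Internal. A stripped-down, illustrative version of that mapping follows; the sentinel values and the `toStatus` helper are hypothetical stand-ins, not code from this patch.

```go
package main

import (
	"errors"
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// Hypothetical sentinels standing in for errors.ErrObjectIDNotFound,
// errors.ErrInvalidDimensionSize and errors.ErrUUIDAlreadyExists.
var (
	errNotFound      = errors.New("object id not found")
	errBadDimension  = errors.New("invalid dimension size")
	errAlreadyExists = errors.New("uuid already exists")
)

// toStatus maps a service-layer error onto a gRPC status, defaulting to
// Internal, mirroring the branching used by the Update handler above.
func toStatus(rpc string, err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, errNotFound):
		return status.Errorf(codes.NotFound, "%s API: %v", rpc, err)
	case errors.Is(err, errBadDimension):
		return status.Errorf(codes.InvalidArgument, "%s API: %v", rpc, err)
	case errors.Is(err, errAlreadyExists):
		return status.Errorf(codes.AlreadyExists, "%s API: %v", rpc, err)
	default:
		return status.Errorf(codes.Internal, "%s API failed: %v", rpc, err)
	}
}

func main() {
	fmt.Println(toStatus("Update", errNotFound))
	fmt.Println(toStatus("Update", errors.New("faiss core failure")))
}
```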
+// + +package grpc + +import ( + "context" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/apis/grpc/v1/vald" +) + +func (s *server) Upsert(ctx context.Context, req *payload.Upsert_Request) (loc *payload.Object_Location, err error) { + return s.UnimplementedValdServer.UnimplementedUpsertServer.Upsert(ctx, req) +} + +func (s *server) StreamUpsert(stream vald.Upsert_StreamUpsertServer) (err error) { + return s.UnimplementedValdServer.UnimplementedUpsertServer.StreamUpsert(stream) +} + +func (s *server) MultiUpsert(ctx context.Context, reqs *payload.Upsert_MultiRequest) (res *payload.Object_Locations, err error) { + return s.UnimplementedValdServer.UnimplementedUpsertServer.MultiUpsert(ctx, reqs) +} diff --git a/pkg/agent/core/faiss/handler/rest/handler.go b/pkg/agent/core/faiss/handler/rest/handler.go new file mode 100644 index 0000000000..89e4c20d56 --- /dev/null +++ b/pkg/agent/core/faiss/handler/rest/handler.go @@ -0,0 +1,175 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package rest provides rest api logic +package rest + +import ( + "net/http" + + "github.com/vdaas/vald/apis/grpc/v1/payload" + "github.com/vdaas/vald/internal/net/http/dump" + "github.com/vdaas/vald/internal/net/http/json" + "github.com/vdaas/vald/pkg/agent/core/faiss/handler/grpc" +) + +type Handler interface { + Index(w http.ResponseWriter, r *http.Request) (int, error) + Exists(w http.ResponseWriter, r *http.Request) (int, error) + Search(w http.ResponseWriter, r *http.Request) (int, error) + SearchByID(w http.ResponseWriter, r *http.Request) (int, error) + LinearSearch(w http.ResponseWriter, r *http.Request) (int, error) + LinearSearchByID(w http.ResponseWriter, r *http.Request) (int, error) + Insert(w http.ResponseWriter, r *http.Request) (int, error) + MultiInsert(w http.ResponseWriter, r *http.Request) (int, error) + Update(w http.ResponseWriter, r *http.Request) (int, error) + MultiUpdate(w http.ResponseWriter, r *http.Request) (int, error) + Remove(w http.ResponseWriter, r *http.Request) (int, error) + MultiRemove(w http.ResponseWriter, r *http.Request) (int, error) + CreateIndex(w http.ResponseWriter, r *http.Request) (int, error) + SaveIndex(w http.ResponseWriter, r *http.Request) (int, error) + CreateAndSaveIndex(w http.ResponseWriter, r *http.Request) (int, error) + GetObject(w http.ResponseWriter, r *http.Request) (int, error) +} + +type handler struct { + agent grpc.Server +} + +func New(opts ...Option) Handler { + h := new(handler) + + for _, opt := range append(defaultOptions, opts...) 
{ + opt(h) + } + return h +} + +func (h *handler) Index(w http.ResponseWriter, r *http.Request) (int, error) { + data := make(map[string]interface{}) + return json.Handler(w, r, &data, func() (interface{}, error) { + return dump.Request(nil, data, r) + }) +} + +func (h *handler) Search(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.Search(r.Context(), req) + }) +} + +func (h *handler) SearchByID(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_IDRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.SearchByID(r.Context(), req) + }) +} + +func (h *handler) LinearSearch(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.LinearSearch(r.Context(), req) + }) +} + +func (h *handler) LinearSearchByID(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Search_IDRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.LinearSearchByID(r.Context(), req) + }) +} + +func (h *handler) Insert(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Insert_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.Insert(r.Context(), req) + }) +} + +func (h *handler) MultiInsert(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Insert_MultiRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.MultiInsert(r.Context(), req) + }) +} + +func (h *handler) Update(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Update_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.Update(r.Context(), req) + }) +} + +func (h *handler) MultiUpdate(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Update_MultiRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.MultiUpdate(r.Context(), req) + }) +} + +func (h *handler) Remove(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Remove_Request + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.Remove(r.Context(), req) + }) +} + +func (h *handler) MultiRemove(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Remove_MultiRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.MultiRemove(r.Context(), req) + }) +} + +func (h *handler) CreateIndex(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Control_CreateIndexRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.CreateIndex(r.Context(), req) + }) +} + +func (h *handler) SaveIndex(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Empty + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.SaveIndex(r.Context(), req) + }) +} + +func (h *handler) CreateAndSaveIndex(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Control_CreateIndexRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + _, err = h.agent.CreateIndex(r.Context(), req) + if err != nil { + return nil, err + } + 
return h.agent.SaveIndex(r.Context(), nil) + }) +} + +func (h *handler) GetObject(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Object_VectorRequest + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.GetObject(r.Context(), req) + }) +} + +func (h *handler) Exists(w http.ResponseWriter, r *http.Request) (code int, err error) { + var req *payload.Object_ID + return json.Handler(w, r, &req, func() (interface{}, error) { + return h.agent.Exists(r.Context(), req) + }) +} diff --git a/pkg/agent/core/faiss/handler/rest/option.go b/pkg/agent/core/faiss/handler/rest/option.go new file mode 100644 index 0000000000..9b3b41e7d6 --- /dev/null +++ b/pkg/agent/core/faiss/handler/rest/option.go @@ -0,0 +1,30 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package rest provides rest api logic +package rest + +import "github.com/vdaas/vald/pkg/agent/core/faiss/handler/grpc" + +type Option func(*handler) + +var defaultOptions = []Option{} + +func WithAgent(a grpc.Server) Option { + return func(h *handler) { + h.agent = a + } +} diff --git a/pkg/agent/core/faiss/model/faiss.go b/pkg/agent/core/faiss/model/faiss.go new file mode 100644 index 0000000000..5d751a83e3 --- /dev/null +++ b/pkg/agent/core/faiss/model/faiss.go @@ -0,0 +1,23 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package model defines object structure +package model + +type Distance struct { + ID string + Distance float32 +} diff --git a/pkg/agent/core/faiss/router/option.go b/pkg/agent/core/faiss/router/option.go new file mode 100644 index 0000000000..4db2794819 --- /dev/null +++ b/pkg/agent/core/faiss/router/option.go @@ -0,0 +1,51 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
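Because every REST handler above simply decodes a JSON body into the corresponding protobuf request and delegates to the gRPC server, the HTTP surface can be exercised with plain JSON. Below is a hedged sketch of calling the `/search` route from Go; the address and vector dimension are placeholders, and the JSON field names assume the usual protobuf JSON mapping rather than anything verified against this patch.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Request body mirroring payload.Search_Request; the zero-filled vector and
	// its 784 dimensions are placeholders for a real query.
	body, err := json.Marshal(map[string]any{
		"vector": make([]float32, 784),
		"config": map[string]any{
			"request_id": "rest-search-0001",
			"num":        10,
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	// The address is a placeholder for the agent's REST server.
	resp, err := http.Post("http://localhost:8080/search", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	out, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Status, string(out))
}
```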
+// + +// Package router provides implementation of Go API for routing http Handler wrapped by rest.Func +package router + +import ( + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/pkg/agent/core/faiss/handler/rest" +) + +// Option represents the functional option for router. +type Option func(*router) + +var defaultOptions = []Option{ + WithTimeout("3s"), +} + +// WithHandler returns the option to set the handler for the router. +func WithHandler(h rest.Handler) Option { + return func(r *router) { + r.handler = h + } +} + +// WithTimeout returns the option to set the timeout for the router. +func WithTimeout(timeout string) Option { + return func(r *router) { + r.timeout = timeout + } +} + +// WithErrGroup returns the option to set the error group for the router. +func WithErrGroup(eg errgroup.Group) Option { + return func(r *router) { + r.eg = eg + } +} diff --git a/pkg/agent/core/faiss/router/router.go b/pkg/agent/core/faiss/router/router.go new file mode 100644 index 0000000000..4b3cb2cc83 --- /dev/null +++ b/pkg/agent/core/faiss/router/router.go @@ -0,0 +1,170 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package router provides implementation of Go API for routing http Handler wrapped by rest.Func +package router + +import ( + "net/http" + + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/net/http/middleware" + "github.com/vdaas/vald/internal/net/http/routing" + "github.com/vdaas/vald/pkg/agent/core/faiss/handler/rest" +) + +type router struct { + handler rest.Handler + eg errgroup.Group + timeout string +} + +// New returns REST route&method information from handler interface. +func New(opts ...Option) http.Handler { + r := new(router) + + for _, opt := range append(defaultOptions, opts...) 
{ + opt(r) + } + + h := r.handler + + return routing.New( + routing.WithMiddleware( + middleware.NewTimeout( + middleware.WithTimeout(r.timeout), + middleware.WithErrorGroup(r.eg), + )), + routing.WithRoutes([]routing.Route{ + { + "Index", + []string{ + http.MethodGet, + }, + "/", + h.Index, + }, + { + "Search", + []string{ + http.MethodPost, + }, + "/search", + h.Search, + }, + { + "Search By ID", + []string{ + http.MethodPost, + }, + "/id/search", + h.SearchByID, + }, + { + "LinearSearch", + []string{ + http.MethodPost, + }, + "/linearsearch", + h.LinearSearch, + }, + { + "LinearSearch By ID", + []string{ + http.MethodPost, + }, + "/id/linearsearch", + h.LinearSearchByID, + }, + { + "Insert", + []string{ + http.MethodPost, + }, + "/insert", + h.Insert, + }, + { + "Multiple Insert", + []string{ + http.MethodPost, + }, + "/insert/multi", + h.MultiInsert, + }, + { + "Update", + []string{ + http.MethodPost, + http.MethodPatch, + http.MethodPut, + }, + "/update", + h.Update, + }, + { + "Multiple Update", + []string{ + http.MethodPost, + http.MethodPatch, + http.MethodPut, + }, + "/update/multi", + h.MultiUpdate, + }, + { + "Remove", + []string{ + http.MethodDelete, + }, + "/delete", + h.Remove, + }, + { + "Multiple Remove", + []string{ + http.MethodDelete, + http.MethodPost, + }, + "/delete/multi", + h.MultiRemove, + }, + { + "Create Index", + []string{ + http.MethodPost, + }, + "/index/create", + h.CreateIndex, + }, + { + "Save Index", + []string{ + http.MethodGet, + }, + "/index/save", + h.SaveIndex, + }, + { + "GetObject", + []string{ + http.MethodGet, + }, + "/object/{id}", + h.GetObject, + }, + }...)) +} diff --git a/pkg/agent/core/faiss/service/faiss.go b/pkg/agent/core/faiss/service/faiss.go new file mode 100644 index 0000000000..d915240762 --- /dev/null +++ b/pkg/agent/core/faiss/service/faiss.go @@ -0,0 +1,1284 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package service manages the main logic of server. 
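+//
+// The Faiss-backed service mirrors the layout of the NGT agent service: incoming
+// vectors are staged in an insert/delete vqueue, uuid <-> object-ID mappings live
+// in the shared kvs.BidiMap, and CreateIndex drains the queues into the Faiss
+// core. Because an IVF-style index has to be trained before vectors can be added,
+// vectors are additionally buffered in addVecs/addIds until enough data has been
+// collected for training.
+//
+// A minimal usage sketch (hypothetical values; error handling elided; cfg is a
+// *config.Faiss with Dimension, Nlist, M and NbitsPerIdx populated):
+//
+//	f, _ := service.New(cfg, service.WithIndexPath("/var/vald/index"))
+//	ech := f.Start(ctx)              // background auto create/save index loops
+//	_ = f.Insert("uuid-1", vec)      // staged in the insert vqueue
+//	_ = f.CreateIndex(ctx)           // train (if needed) + add + commit
+//	res, _ := f.Search(10, 1, query) // k=10 neighbours for a single query vector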
+package service + +import ( + "context" + "encoding/gob" + "io/fs" + "math" + "os" + "path/filepath" + "reflect" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/vdaas/vald/internal/config" + core "github.com/vdaas/vald/internal/core/algorithm/faiss" + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/file" + "github.com/vdaas/vald/internal/log" + "github.com/vdaas/vald/internal/observability/trace" + "github.com/vdaas/vald/internal/safety" + "github.com/vdaas/vald/internal/strings" + "github.com/vdaas/vald/pkg/agent/core/faiss/model" + "github.com/vdaas/vald/pkg/agent/core/ngt/service/kvs" + "github.com/vdaas/vald/pkg/agent/core/ngt/service/vqueue" + "github.com/vdaas/vald/pkg/agent/internal/metadata" +) + +type ( + Faiss interface { + Start(ctx context.Context) <-chan error + Train(nb int, xb []float32) error + Insert(uuid string, xb []float32) error + InsertWithTime(uuid string, vec []float32, t int64) error + Update(uuid string, vec []float32) error + UpdateWithTime(uuid string, vec []float32, t int64) error + CreateIndex(ctx context.Context) error + SaveIndex(ctx context.Context) error + CreateAndSaveIndex(ctx context.Context) error + Search(k, nq uint32, xq []float32) ([]model.Distance, error) + Delete(uuid string) error + DeleteWithTime(uuid string, t int64) error + Exists(uuid string) (uint32, bool) + IsIndexing() bool + IsSaving() bool + NumberOfCreateIndexExecution() uint64 + NumberOfProactiveGCExecution() uint64 + Len() uint64 + InsertVQueueBufferLen() uint64 + DeleteVQueueBufferLen() uint64 + GetDimensionSize() int + GetTrainSize() int + Close(ctx context.Context) error + } + + faiss struct { + core core.Faiss + eg errgroup.Group + kvs kvs.BidiMap + fmu sync.Mutex + fmap map[string]int64 // failure map for index + vq vqueue.Queue + addVecs []float32 + addIds []int64 + isTrained bool + trainSize int + icnt uint64 + + // statuses + indexing atomic.Value + saving atomic.Value + cimu sync.Mutex // create index mutex + lastNocie uint64 // last number of create index execution this value prevent unnecessary saveindex + + // counters + nocie uint64 // number of create index execution + nogce uint64 // number of proactive GC execution + wfci uint64 // wait for create indexing + + // configurations + inMem bool // in-memory mode + dim int // dimension size + nlist int // the number of Voronoi cells + m int // number of subquantizers + alen int // auto indexing length + dur time.Duration // auto indexing check duration + sdur time.Duration // auto save index check duration + lim time.Duration // auto indexing time limit + minLit time.Duration // minimum load index timeout + maxLit time.Duration // maximum load index timeout + litFactor time.Duration // load index timeout factor + enableProactiveGC bool // if this value is true, agent component will purge GC memory more proactive + enableCopyOnWrite bool // if this value is true, agent component will write backup file using Copy on Write and saves old files to the old directory + path string // index path + smu sync.Mutex // save index lock + tmpPath atomic.Value // temporary index path for Copy on Write + oldPath string // old volume path + basePath string // index base directory for CoW + cowmu sync.Mutex // copy on write move lock + dcd bool // disable commit daemon + idelay time.Duration // initial delay duration + kvsdbConcurrency int // kvsdb concurrency + } +) + +const ( + kvsFileName = "faiss-meta.kvsdb" + kvsTimestampFileName = 
"faiss-timestamp.kvsdb" + noTimeStampFile = -1 + + oldIndexDirName = "backup" + originIndexDirName = "origin" + + // ref: https://github.com/facebookresearch/faiss/wiki/FAQ#can-i-ignore-warning-clustering-xxx-points-to-yyy-centroids + // ref: https://github.com/facebookresearch/faiss/blob/main/faiss/Clustering.cpp#L38 + minPointsPerCentroid int = 39 +) + +func New(cfg *config.Faiss, opts ...Option) (Faiss, error) { + var ( + f = &faiss{ + fmap: make(map[string]int64), + dim: cfg.Dimension, + nlist: cfg.Nlist, + m: cfg.M, + enableProactiveGC: cfg.EnableProactiveGC, + enableCopyOnWrite: cfg.EnableCopyOnWrite, + kvsdbConcurrency: cfg.KVSDB.Concurrency, + } + err error + ) + + for _, opt := range append(defaultOptions, opts...) { + if err := opt(f); err != nil { + return nil, errors.ErrOptionFailed(err, reflect.ValueOf(opt)) + } + } + + if len(f.path) == 0 { + f.inMem = true + } + + if f.enableCopyOnWrite && !f.inMem && len(f.path) != 0 { + sep := string(os.PathSeparator) + f.path, err = filepath.Abs(strings.ReplaceAll(f.path, sep+sep, sep)) + if err != nil { + log.Warn(err) + } + + f.basePath = f.path + f.oldPath = file.Join(f.basePath, oldIndexDirName) + f.path = file.Join(f.basePath, originIndexDirName) + err = file.MkdirAll(f.oldPath, fs.ModePerm) + if err != nil { + log.Warn(err) + } + err = file.MkdirAll(f.path, fs.ModePerm) + if err != nil { + log.Warn(err) + } + err = f.mktmp() + if err != nil { + return nil, err + } + } + + err = f.initFaiss( + core.WithDimension(cfg.Dimension), + core.WithNlist(cfg.Nlist), + core.WithM(cfg.M), + core.WithNbitsPerIdx(cfg.NbitsPerIdx), + core.WithMetricType(cfg.MetricType), + ) + if err != nil { + return nil, err + } + + if f.dur == 0 || f.alen == 0 { + f.dcd = true + } + + if f.vq == nil { + f.vq, err = vqueue.New() + if err != nil { + return nil, err + } + } + + f.indexing.Store(false) + f.saving.Store(false) + + return f, nil +} + +func (f *faiss) initFaiss(opts ...core.Option) error { + var err error + + if f.kvs == nil { + f.kvs = kvs.New(kvs.WithConcurrency(f.kvsdbConcurrency)) + } + + if f.inMem { + log.Debug("vald agent starts with in-memory mode") + f.core, err = core.New(opts...) + return err + } + + ctx := context.Background() + err = f.load(ctx, f.path, opts...) + var current uint64 + if err != nil { + if !f.enableCopyOnWrite { + log.Debug("failed to load vald index from %s\t error: %v", f.path, err) + if f.kvs == nil { + f.kvs = kvs.New(kvs.WithConcurrency(f.kvsdbConcurrency)) + } else if f.kvs.Len() > 0 { + f.kvs.Close() + f.kvs = kvs.New(kvs.WithConcurrency(f.kvsdbConcurrency)) + } + + if f.core != nil { + f.core.Close() + f.core = nil + } + f.core, err = core.New(append(opts, core.WithIndexPath(f.path))...) + return err + } + + if errors.Is(err, errors.ErrIndicesAreTooFewComparedToMetadata) && f.kvs != nil { + current = f.kvs.Len() + log.Warnf( + "load vald primary index success from %s\t error: %v\tbut index data are too few %d compared to metadata count now trying to load from old copied index data from %s and compare them", + f.path, + err, + current, + f.oldPath, + ) + } else { + log.Warnf("failed to load vald primary index from %s\t error: %v\ttrying to load from old copied index data from %s", f.path, err, f.oldPath) + } + } else { + return nil + } + + err = f.load(ctx, f.oldPath, opts...) 
+ if err == nil { + if current != 0 && f.kvs.Len() < current { + log.Warnf( + "load vald secondary index success from %s\t error: %v\tbut index data are too few %d compared to primary data now trying to load from primary index data again from %s and start up with them", + f.oldPath, + err, + f.kvs.Len(), + f.oldPath, + ) + + err = f.load(ctx, f.path, opts...) + if err == nil { + return nil + } + } else { + return nil + } + } + + log.Warnf("failed to load vald secondary index from %s and %s\t error: %v\ttrying to load from non-CoW index data from %s for backwards compatibility", f.path, f.oldPath, err, f.basePath) + err = f.load(ctx, f.basePath, opts...) + if err == nil { + file.CopyDir(ctx, f.basePath, f.path) + return nil + } + + tpath := f.tmpPath.Load().(string) + log.Warnf( + "failed to load vald backwards index from %s and %s and %s\t error: %v\tvald agent couldn't find any index from agent volume in %s trying to start as new index from %s", + f.path, + f.oldPath, + f.basePath, + err, + f.basePath, + tpath, + ) + + if f.core != nil { + f.core.Close() + f.core = nil + } + f.core, err = core.New(append(opts, core.WithIndexPath(tpath))...) + if err != nil { + return err + } + + if f.kvs == nil { + f.kvs = kvs.New(kvs.WithConcurrency(f.kvsdbConcurrency)) + } else if f.kvs.Len() > 0 { + f.kvs.Close() + f.kvs = kvs.New(kvs.WithConcurrency(f.kvsdbConcurrency)) + } + + return nil +} + +func (f *faiss) load(ctx context.Context, path string, opts ...core.Option) error { + exist, fi, err := file.ExistsWithDetail(path) + switch { + case !exist, fi == nil, fi != nil && fi.Size() == 0, err != nil && errors.Is(err, fs.ErrNotExist): + err = errors.Wrapf(errors.ErrIndexFileNotFound, "index file does not exists,\tpath: %s,\terr: %v", path, err) + return err + case err != nil && errors.Is(err, fs.ErrPermission): + if fi != nil { + err = errors.Wrap(errors.ErrFailedToOpenFile(err, path, 0, fi.Mode()), "invalid permission for loading index path") + } + return err + case exist && fi != nil && fi.IsDir(): + if fi.Mode().IsDir() && !strings.HasSuffix(path, string(os.PathSeparator)) { + path += string(os.PathSeparator) + } + files, err := filepath.Glob(file.Join(filepath.Dir(path), "*")) + if err != nil || len(files) == 0 { + err = errors.Wrapf(errors.ErrIndexFileNotFound, "index path exists but no file does not exists in the directory,\tpath: %s,\tfiles: %v\terr: %v", path, files, err) + return err + } + if strings.HasSuffix(path, string(os.PathSeparator)) { + path = strings.TrimSuffix(path, string(os.PathSeparator)) + } + } + + metadataPath := file.Join(path, metadata.AgentMetadataFileName) + log.Debugf("index path: %s exists, now starting to check metadata from %s", path, metadataPath) + exist, fi, err = file.ExistsWithDetail(metadataPath) + switch { + case !exist, fi == nil, fi != nil && fi.Size() == 0, err != nil && errors.Is(err, fs.ErrNotExist): + err = errors.Wrapf(errors.ErrIndexFileNotFound, "metadata file does not exists,\tpath: %s,\terr: %v", metadataPath, err) + return err + case err != nil && errors.Is(err, fs.ErrPermission): + if fi != nil { + err = errors.Wrap(errors.ErrFailedToOpenFile(err, metadataPath, 0, fi.Mode()), "invalid permission for loading metadata") + } + return err + } + + log.Debugf("index path: %s and metadata: %s exists, now starting to load metadata", path, metadataPath) + agentMetadata, err := metadata.Load(metadataPath) + if err != nil && errors.Is(err, fs.ErrNotExist) || agentMetadata == nil || agentMetadata.Faiss == nil || agentMetadata.Faiss.IndexCount == 0 { + err = 
errors.Wrapf(err, "cannot read metadata from path: %s\tmetadata: %s", path, agentMetadata) + return err + } + + kvsFilePath := file.Join(path, kvsFileName) + log.Debugf("index path: %s and metadata: %s exists and successfully load metadata, now starting to load kvs data from %s", path, metadataPath, kvsFilePath) + exist, fi, err = file.ExistsWithDetail(kvsFilePath) + switch { + case !exist, fi == nil, fi != nil && fi.Size() == 0, err != nil && errors.Is(err, fs.ErrNotExist): + err = errors.Wrapf(errors.ErrIndexFileNotFound, "kvsdb file does not exists,\tpath: %s,\terr: %v", kvsFilePath, err) + return err + case err != nil && errors.Is(err, fs.ErrPermission): + if fi != nil { + err = errors.ErrFailedToOpenFile(err, kvsFilePath, 0, fi.Mode()) + } + err = errors.Wrapf(err, "invalid permission for loading kvsdb file from %s", kvsFilePath) + return err + } + + kvsTimestampFilePath := file.Join(path, kvsTimestampFileName) + log.Debugf("now starting to load kvs timestamp data from %s", kvsTimestampFilePath) + exist, fi, err = file.ExistsWithDetail(kvsTimestampFilePath) + switch { + case !exist, fi == nil, fi != nil && fi.Size() == 0, err != nil && errors.Is(err, fs.ErrNotExist): + log.Warnf("timestamp kvsdb file does not exists,\tpath: %s,\terr: %v", kvsTimestampFilePath, err) + case err != nil && errors.Is(err, fs.ErrPermission): + if fi != nil { + err = errors.ErrFailedToOpenFile(err, kvsTimestampFilePath, 0, fi.Mode()) + } + log.Warnf("invalid permission for loading timestamp kvsdb file from %s", kvsTimestampFilePath) + } + + var timeout time.Duration + if agentMetadata != nil && agentMetadata.Faiss != nil { + log.Debugf("the backup index size is %d. starting to load...", agentMetadata.Faiss.IndexCount) + timeout = time.Duration( + math.Min( + math.Max( + float64(agentMetadata.Faiss.IndexCount)*float64(f.litFactor), + float64(f.minLit), + ), + float64(f.maxLit), + ), + ) + } else { + log.Debugf("cannot inspect the backup index size. starting to load default value.") + timeout = time.Duration(math.Min(float64(f.minLit), float64(f.maxLit))) + } + + log.Debugf("index path: %s and metadata: %s and kvsdb file: %s and timestamp kvsdb file: %s exists and successfully load metadata, now starting to load full index and kvs data in concurrent", path, metadataPath, kvsFilePath, kvsTimestampFilePath) + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + eg, _ := errgroup.New(ctx) + eg.Go(safety.RecoverFunc(func() (err error) { + if f.core != nil { + f.core.Close() + f.core = nil + } + f.core, err = core.Load(append(opts, core.WithIndexPath(path))...) + if err != nil { + err = errors.Wrapf(err, "failed to load faiss index from path: %s", path) + return err + } + return nil + })) + + eg.Go(safety.RecoverFunc(func() (err error) { + err = f.loadKVS(ctx, path, timeout) + if err != nil { + err = errors.Wrapf(err, "failed to load kvsdb data from path: %s, %s", kvsFilePath, kvsTimestampFilePath) + return err + } + if f.kvs != nil && float64(agentMetadata.Faiss.IndexCount/2) > float64(f.kvs.Len()) { + return errors.ErrIndicesAreTooFewComparedToMetadata + } + return nil + })) + + ech := make(chan error, 1) + // NOTE: when it exceeds the timeout while loading, + // it should exit this function and leave this goroutine running. 
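+	// The load itself runs on the service errgroup: eg.Wait() is wrapped in a
+	// goroutine that reports its result over ech, and the select below races
+	// that result against the timeout context. On DeadlineExceeded the metadata
+	// file is rewritten as invalid (IndexCount: 0) before returning
+	// ErrIndexLoadTimeout, so a later start does not trust a partially loaded index.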
+ f.eg.Go(safety.RecoverFunc(func() error { + defer close(ech) + ech <- safety.RecoverFunc(func() (err error) { + err = eg.Wait() + if err != nil { + log.Error(err) + return err + } + cancel() + return nil + })() + return nil + })) + + select { + case err := <-ech: + return err + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + log.Errorf("cannot load index backup data from %s within the timeout %s. the process is going to be killed.", path, timeout) + err := metadata.Store(metadataPath, + &metadata.Metadata{ + IsInvalid: true, + Faiss: &metadata.Faiss{ + IndexCount: 0, + }, + }, + ) + if err != nil { + return err + } + return errors.ErrIndexLoadTimeout + } + } + + return nil +} + +func (f *faiss) loadKVS(ctx context.Context, path string, timeout time.Duration) (err error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + eg, _ := errgroup.New(ctx) + + m := make(map[string]uint32) + mt := make(map[string]int64) + + eg.Go(safety.RecoverFunc(func() (err error) { + gob.Register(map[string]uint32{}) + var fi *os.File + fi, err = file.Open( + file.Join(path, kvsFileName), + os.O_RDONLY|os.O_SYNC, + fs.ModePerm, + ) + if err != nil { + return err + } + defer func() { + if fi != nil { + derr := fi.Close() + if derr != nil { + err = errors.Wrap(err, derr.Error()) + } + } + }() + err = gob.NewDecoder(fi).Decode(&m) + if err != nil { + log.Errorf("error decoding kvsdb file,\terr: %v", err) + return err + } + return nil + })) + + eg.Go(safety.RecoverFunc(func() (err error) { + gob.Register(map[string]int64{}) + var ft *os.File + ft, err = file.Open( + file.Join(path, kvsTimestampFileName), + os.O_RDONLY|os.O_SYNC, + fs.ModePerm, + ) + if err != nil { + log.Warnf("error opening timestamp kvsdb file,\terr: %v", err) + } + defer func() { + if ft != nil { + derr := ft.Close() + if derr != nil { + err = errors.Wrap(err, derr.Error()) + } + } + }() + err = gob.NewDecoder(ft).Decode(&mt) + if err != nil { + log.Warnf("error decoding timestamp kvsdb file,\terr: %v", err) + } + return nil + })) + + err = eg.Wait() + if err != nil { + return err + } + + if f.kvs == nil { + f.kvs = kvs.New(kvs.WithConcurrency(f.kvsdbConcurrency)) + } else if f.kvs.Len() > 0 { + f.kvs.Close() + f.kvs = kvs.New(kvs.WithConcurrency(f.kvsdbConcurrency)) + } + for k, id := range m { + if ts, ok := mt[k]; ok { + f.kvs.Set(k, id, ts) + } else { + // NOTE: SaveIndex do not write ngt-timestamp.kvsdb with timestamp 0. 
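+			// Keys found in the vector kvsdb but missing from the timestamp kvsdb
+			// are restored with timestamp 0 and also recorded in fmap, so the next
+			// SaveIndex writes them out to the invalid-entries kvsdb.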
+ f.kvs.Set(k, id, 0) + f.fmap[k] = int64(id) + } + } + for k := range mt { + if _, ok := m[k]; !ok { + f.fmap[k] = noTimeStampFile + } + } + + return nil +} + +func (f *faiss) mktmp() error { + if !f.enableCopyOnWrite { + return nil + } + + path, err := file.MkdirTemp(file.Join(os.TempDir(), "vald")) + if err != nil { + log.Warnf("failed to create temporary index file path directory %s:\terr: %v", path, err) + return err + } + + f.tmpPath.Store(path) + + return nil +} + +func (f *faiss) Start(ctx context.Context) <-chan error { + if f.dcd { + return nil + } + + ech := make(chan error, 2) + f.eg.Go(safety.RecoverFunc(func() (err error) { + defer close(ech) + if f.dur <= 0 { + f.dur = math.MaxInt64 + } + if f.sdur <= 0 { + f.sdur = math.MaxInt64 + } + if f.lim <= 0 { + f.lim = math.MaxInt64 + } + + if f.idelay > 0 { + timer := time.NewTimer(f.idelay) + select { + case <-ctx.Done(): + timer.Stop() + return ctx.Err() + case <-timer.C: + } + timer.Stop() + } + + tick := time.NewTicker(f.dur) + sTick := time.NewTicker(f.sdur) + limit := time.NewTicker(f.lim) + defer tick.Stop() + defer sTick.Stop() + defer limit.Stop() + for { + err = nil + select { + case <-ctx.Done(): + err = f.CreateIndex(ctx) + if err != nil && !errors.Is(err, errors.ErrUncommittedIndexNotFound) { + ech <- err + return errors.Wrap(ctx.Err(), err.Error()) + } + return ctx.Err() + case <-tick.C: + if f.vq.IVQLen() >= f.alen { + err = f.CreateIndex(ctx) + } + case <-limit.C: + err = f.CreateAndSaveIndex(ctx) + case <-sTick.C: + err = f.SaveIndex(ctx) + } + if err != nil && err != errors.ErrUncommittedIndexNotFound { + ech <- err + runtime.Gosched() + err = nil + } + } + })) + + return ech +} + +func (f *faiss) Train(nb int, xb []float32) error { + err := f.core.Train(nb, xb) + if err != nil { + log.Errorf("failed to faiss train", err) + return err + } + + return nil +} + +func (f *faiss) Insert(uuid string, vec []float32) error { + return f.insert(uuid, vec, time.Now().UnixNano(), true) +} + +func (f *faiss) InsertWithTime(uuid string, vec []float32, t int64) error { + if t <= 0 { + t = time.Now().UnixNano() + } + + return f.insert(uuid, vec, t, true) +} + +func (f *faiss) insert(uuid string, xb []float32, t int64, validation bool) error { + if len(uuid) == 0 { + err := errors.ErrUUIDNotFound(0) + return err + } + + if validation { + _, ok := f.Exists(uuid) + if ok { + return errors.ErrUUIDAlreadyExists(uuid) + } + } + + return f.vq.PushInsert(uuid, xb, t) +} + +func (f *faiss) Update(uuid string, vec []float32) error { + return f.update(uuid, vec, time.Now().UnixNano()) +} + +func (f *faiss) UpdateWithTime(uuid string, vec []float32, t int64) error { + if t <= 0 { + t = time.Now().UnixNano() + } + return f.update(uuid, vec, t) +} + +func (f *faiss) update(uuid string, vec []float32, t int64) (err error) { + if err = f.readyForUpdate(uuid, vec); err != nil { + return err + } + + err = f.delete(uuid, t, true) // `true` is to return NotFound error with non-existent ID + if err != nil { + return err + } + + t++ + return f.insert(uuid, vec, t, false) +} + +func (f *faiss) readyForUpdate(uuid string, vec []float32) (err error) { + if len(uuid) == 0 { + return errors.ErrUUIDNotFound(0) + } + + if len(vec) != f.GetDimensionSize() { + return errors.ErrInvalidDimensionSize(len(vec), f.GetDimensionSize()) + } + + // not impl GetObject() + + return nil +} + +func (f *faiss) CreateIndex(ctx context.Context) error { + ctx, span := trace.StartSpan(ctx, "vald/agent-faiss/service/Faiss.CreateIndex") + defer func() { + if span != nil { + 
span.End() + } + }() + + ic := f.vq.IVQLen() + f.vq.DVQLen() + (len(f.addVecs) / f.dim) + if ic == 0 { + return errors.ErrUncommittedIndexNotFound + } + + wf := atomic.AddUint64(&f.wfci, 1) + if wf > 1 { + atomic.AddUint64(&f.wfci, ^uint64(0)) + log.Debugf("concurrent create index waiting detected this request will be ignored, concurrent: %d", wf) + return nil + } + + err := func() error { + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() + // wait for not indexing & not saving + for f.IsIndexing() || f.IsSaving() { + runtime.Gosched() + select { + case <-ctx.Done(): + atomic.AddUint64(&f.wfci, ^uint64(0)) + return ctx.Err() + case <-ticker.C: + } + } + atomic.AddUint64(&f.wfci, ^uint64(0)) + return nil + }() + if err != nil { + return err + } + + f.cimu.Lock() + defer f.cimu.Unlock() + f.indexing.Store(true) + defer f.indexing.Store(false) + defer f.gc() + now := time.Now().UnixNano() + ic = f.vq.IVQLen() + f.vq.DVQLen() + (len(f.addVecs) / f.dim) + if ic == 0 { + return errors.ErrUncommittedIndexNotFound + } + + log.Infof("create index operation started, uncommitted indexes = %d", ic) + log.Debug("create index delete phase started") + f.vq.RangePopDelete(ctx, now, func(uuid string) bool { + log.Debugf("start delete operation for kvsdb id: %s", uuid) + oid, ok := f.kvs.Delete(uuid) + if !ok { + log.Warn(errors.ErrObjectIDNotFound(uuid)) + return true + } + log.Debugf("start remove operation for faiss index id: %s, oid: %d", uuid, oid) + ntotal, err := f.core.Remove(1, []int64{int64(oid)}) + if err != nil { + log.Errorf("failed to remove oid: %d from faiss index. error: %v", oid, err) + f.fmu.Lock() + f.fmap[uuid] = int64(oid) + f.fmu.Unlock() + } + log.Debugf("removed from faiss index and kvsdb id: %s, oid: %d, index size: %d", uuid, oid, ntotal) + return true + }) + log.Debug("create index delete phase finished") + + f.gc() + + log.Debug("create index insert phase started") + f.vq.RangePopInsert(ctx, now, func(uuid string, vector []float32, timestamp int64) bool { + log.Debugf("start stack operation for faiss index id: %s, icnt: %d", uuid, uint32(f.icnt)) + f.addVecs = append(f.addVecs, vector...) 
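+		// The vector is not handed to the Faiss core yet: it is staged in the flat
+		// addVecs buffer (dim float32 values per entry) with a matching int64 label
+		// in addIds. core.Train only runs once at least max(nlist, 2^m) *
+		// minPointsPerCentroid vectors are buffered, and core.Add runs after this
+		// loop once the index has been trained.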
+ f.addIds = append(f.addIds, int64(f.icnt)) + + log.Debugf("start insert operation for kvsdb id: %s, icnt: %d", uuid, uint32(f.icnt)) + f.kvs.Set(uuid, uint32(f.icnt), timestamp) + atomic.AddUint64(&f.icnt, 1) + + f.fmu.Lock() + _, ok := f.fmap[uuid] + if ok { + delete(f.fmap, uuid) + } + f.fmu.Unlock() + log.Debugf("finished to insert index and kvsdb id: %s, icnt: %d", uuid, uint32(f.icnt)) + return true + }) + + var max int + if f.nlist > int(math.Pow(2, float64(f.m))) { + max = f.nlist + } else { + max = int(math.Pow(2, float64(f.m))) + } + if !f.isTrained && len(f.addVecs)/f.dim >= max*minPointsPerCentroid { + log.Debug("faiss train phase started") + log.Debugf("max * minPointsPerCentroid: %d", max*minPointsPerCentroid) + err := f.core.Train(len(f.addVecs)/f.dim, f.addVecs) + if err != nil { + log.Errorf("failed to faiss train", err) + return err + } + f.isTrained = true + f.trainSize = len(f.addVecs) / f.dim + log.Debug("faiss train phase finished") + } + if f.isTrained && len(f.addVecs) > 0 { + log.Debug("faiss add phase started") + ntotal, err := f.core.Add(len(f.addVecs)/f.dim, f.addVecs, f.addIds) + if err != nil { + log.Errorf("failed to faiss add", err) + return err + } + f.addVecs = nil + f.addIds = nil + log.Debugf("is trained: %v, index size: %d", f.isTrained, ntotal) + log.Debug("faiss add phase finished") + } + log.Debug("create index insert phase finished") + + atomic.AddUint64(&f.nocie, 1) + log.Info("create index operation finished") + + return nil +} + +func (f *faiss) SaveIndex(ctx context.Context) error { + ctx, span := trace.StartSpan(ctx, "vald/agent-faiss/service/Faiss.SaveIndex") + defer func() { + if span != nil { + span.End() + } + }() + + if !f.inMem { + return f.saveIndex(ctx) + } + + return nil +} + +func (f *faiss) saveIndex(ctx context.Context) error { + nocie := atomic.LoadUint64(&f.nocie) + if atomic.LoadUint64(&f.lastNocie) == nocie || !f.isTrained { + return nil + } + atomic.SwapUint64(&f.lastNocie, nocie) + + err := func() error { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + // wait for not indexing & not saving + for f.IsIndexing() || f.IsSaving() { + runtime.Gosched() + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } + return nil + }() + if err != nil { + return err + } + + f.saving.Store(true) + defer f.gc() + defer f.saving.Store(false) + + // no cleanup invalid index + + eg, ectx := errgroup.New(ctx) + // we want to ensure the acutal kvs size between kvsdb and metadata, + // so we create this counter to count the actual kvs size instead of using kvs.Len() + var ( + kvsLen uint64 + path string + ) + + if f.enableCopyOnWrite { + path = f.tmpPath.Load().(string) + } else { + path = f.path + } + + f.smu.Lock() + defer f.smu.Unlock() + + eg.Go(safety.RecoverFunc(func() (err error) { + if f.kvs.Len() > 0 && path != "" { + m := make(map[string]uint32, f.Len()) + mt := make(map[string]int64, f.Len()) + var mu sync.Mutex + + f.kvs.Range(ectx, func(key string, id uint32, ts int64) bool { + mu.Lock() + m[key] = id + mt[key] = ts + mu.Unlock() + atomic.AddUint64(&kvsLen, 1) + return true + }) + + var fi *os.File + fi, err = file.Open( + file.Join(path, kvsFileName), + os.O_WRONLY|os.O_CREATE|os.O_TRUNC, + fs.ModePerm, + ) + if err != nil { + return err + } + defer func() { + if fi != nil { + derr := fi.Close() + if derr != nil { + err = errors.Wrap(err, derr.Error()) + } + } + }() + + gob.Register(map[string]uint32{}) + err = gob.NewEncoder(fi).Encode(&m) + if err != nil { + return err + } + + err = fi.Sync() + 
if err != nil { + return err + } + + m = make(map[string]uint32) + + var ft *os.File + ft, err = file.Open( + file.Join(path, kvsTimestampFileName), + os.O_WRONLY|os.O_CREATE|os.O_TRUNC, + fs.ModePerm, + ) + if err != nil { + return err + } + defer func() { + if ft != nil { + derr := ft.Close() + if derr != nil { + err = errors.Wrap(err, derr.Error()) + } + } + }() + + gob.Register(map[string]int64{}) + err = gob.NewEncoder(ft).Encode(&mt) + if err != nil { + return err + } + + err = ft.Sync() + if err != nil { + return err + } + + mt = make(map[string]int64) + } + + return nil + })) + + eg.Go(safety.RecoverFunc(func() (err error) { + f.fmu.Lock() + fl := len(f.fmap) + f.fmu.Unlock() + + if fl > 0 && path != "" { + var fi *os.File + fi, err = file.Open( + file.Join(path, "invalid-"+kvsFileName), + os.O_WRONLY|os.O_CREATE|os.O_TRUNC, + fs.ModePerm, + ) + if err != nil { + return err + } + defer func() { + if fi != nil { + derr := fi.Close() + if derr != nil { + err = errors.Wrap(err, derr.Error()) + } + } + }() + + gob.Register(map[string]int64{}) + f.fmu.Lock() + err = gob.NewEncoder(fi).Encode(&f.fmap) + f.fmu.Unlock() + if err != nil { + return err + } + + err = fi.Sync() + if err != nil { + return err + } + } + + return nil + })) + + eg.Go(safety.RecoverFunc(func() error { + return f.core.SaveIndexWithPath(path) + })) + + err = eg.Wait() + if err != nil { + return err + } + + err = metadata.Store( + file.Join(path, metadata.AgentMetadataFileName), + &metadata.Metadata{ + IsInvalid: false, + Faiss: &metadata.Faiss{ + IndexCount: kvsLen, + }, + }, + ) + if err != nil { + return err + } + + return f.moveAndSwitchSavedData(ctx) +} + +func (f *faiss) moveAndSwitchSavedData(ctx context.Context) error { + if !f.enableCopyOnWrite { + return nil + } + + var err error + f.cowmu.Lock() + defer f.cowmu.Unlock() + + err = file.MoveDir(ctx, f.path, f.oldPath) + if err != nil { + log.Warnf("failed to backup backup data from %s to %s error: %v", f.path, f.oldPath, err) + } + + path := f.tmpPath.Load().(string) + err = file.MoveDir(ctx, path, f.path) + if err != nil { + log.Warnf("failed to move temporary index data from %s to %s error: %v, trying to rollback secondary backup data from %s to %s", path, f.path, f.oldPath, f.path, err) + return file.MoveDir(ctx, f.oldPath, f.path) + } + defer log.Warnf("finished to copy index from %s => %s => %s", path, f.path, f.oldPath) + + return f.mktmp() +} + +func (f *faiss) CreateAndSaveIndex(ctx context.Context) error { + ctx, span := trace.StartSpan(ctx, "vald/agent-faiss/service/Faiss.CreateAndSaveIndex") + defer func() { + if span != nil { + span.End() + } + }() + + err := f.CreateIndex(ctx) + if err != nil && + !errors.Is(err, errors.ErrUncommittedIndexNotFound) && + !errors.Is(err, context.Canceled) && + !errors.Is(err, context.DeadlineExceeded) { + return err + } + + return f.SaveIndex(ctx) +} + +func (f *faiss) Search(k, nq uint32, xq []float32) ([]model.Distance, error) { + if f.IsIndexing() { + return nil, errors.ErrCreateIndexingIsInProgress + } + + sr, err := f.core.Search(int(k), int(nq), xq) + if err != nil { + if f.IsIndexing() { + return nil, errors.ErrCreateIndexingIsInProgress + } + + log.Errorf("cgo error detected: faiss api returned error %v", err) + return nil, err + } + + if len(sr) == 0 { + return nil, errors.ErrEmptySearchResult + } + + ds := make([]model.Distance, 0, len(sr)) + for _, d := range sr { + if err = d.Error; d.ID == 0 && err != nil { + log.Warnf("an error occurred while searching: %s", err) + continue + } + + key, _, ok := 
f.kvs.GetInverse(d.ID) + if ok { + ds = append(ds, model.Distance{ + ID: key, + Distance: d.Distance, + }) + } else { + log.Warn("not found", d.ID, d.Distance) + } + } + + return ds, nil +} + +func (f *faiss) Delete(uuid string) (err error) { + return f.delete(uuid, time.Now().UnixNano(), true) +} + +func (f *faiss) DeleteWithTime(uuid string, t int64) (err error) { + if t <= 0 { + t = time.Now().UnixNano() + } + + return f.delete(uuid, t, true) +} + +func (f *faiss) delete(uuid string, t int64, validation bool) error { + if len(uuid) == 0 { + return errors.ErrUUIDNotFound(0) + } + + if validation { + _, _, ok := f.kvs.Get(uuid) + if !ok && !f.vq.IVExists(uuid) { + return errors.ErrObjectIDNotFound(uuid) + } + } + + return f.vq.PushDelete(uuid, t) +} + +func (f *faiss) Exists(uuid string) (uint32, bool) { + var ( + oid uint32 + ok bool + ) + + ok = f.vq.IVExists(uuid) + if !ok { + oid, _, ok = f.kvs.Get(uuid) + if !ok { + log.Debugf("Exists\tuuid: %s's data not found in kvsdb and insert vqueue\terror: %v", uuid, errors.ErrObjectIDNotFound(uuid)) + return 0, false + } + if f.vq.DVExists(uuid) { + log.Debugf("Exists\tuuid: %s's data found in kvsdb and not found in insert vqueue, but delete vqueue data exists. the object will be delete soon\terror: %v", + uuid, errors.ErrObjectIDNotFound(uuid)) + return 0, false + } + } + + return oid, ok +} + +func (f *faiss) IsIndexing() bool { + i, ok := f.indexing.Load().(bool) + return i && ok +} + +func (f *faiss) IsSaving() bool { + s, ok := f.saving.Load().(bool) + return s && ok +} + +func (f *faiss) NumberOfCreateIndexExecution() uint64 { + return atomic.LoadUint64(&f.nocie) +} + +func (f *faiss) NumberOfProactiveGCExecution() uint64 { + return atomic.LoadUint64(&f.nogce) +} + +func (f *faiss) gc() { + if f.enableProactiveGC { + runtime.GC() + atomic.AddUint64(&f.nogce, 1) + } +} + +func (f *faiss) Len() uint64 { + return f.kvs.Len() +} + +func (f *faiss) InsertVQueueBufferLen() uint64 { + return uint64(f.vq.IVQLen()) +} + +func (f *faiss) DeleteVQueueBufferLen() uint64 { + return uint64(f.vq.DVQLen()) +} + +func (f *faiss) GetDimensionSize() int { + return f.dim +} + +func (f *faiss) GetTrainSize() int { + return f.trainSize +} + +func (f *faiss) Close(ctx context.Context) error { + err := f.kvs.Close() + if len(f.path) != 0 { + cerr := f.CreateIndex(ctx) + if cerr != nil && + !errors.Is(err, errors.ErrUncommittedIndexNotFound) && + !errors.Is(err, context.Canceled) && + !errors.Is(err, context.DeadlineExceeded) { + if err != nil { + err = errors.Wrap(cerr, err.Error()) + } else { + err = cerr + } + } + + serr := f.SaveIndex(ctx) + if serr != nil && + !errors.Is(err, errors.ErrUncommittedIndexNotFound) && + !errors.Is(err, context.Canceled) && + !errors.Is(err, context.DeadlineExceeded) { + if err != nil { + err = errors.Wrap(serr, err.Error()) + } else { + err = serr + } + } + } + + f.core.Close() + + return nil +} diff --git a/pkg/agent/core/faiss/service/option.go b/pkg/agent/core/faiss/service/option.go new file mode 100644 index 0000000000..1d271dffdc --- /dev/null +++ b/pkg/agent/core/faiss/service/option.go @@ -0,0 +1,271 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "math" + "math/big" + "os" + "time" + + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/file" + "github.com/vdaas/vald/internal/rand" + "github.com/vdaas/vald/internal/strings" + "github.com/vdaas/vald/internal/timeutil" +) + +// Option represent the functional option for faiss +type Option func(f *faiss) error + +var defaultOptions = []Option{ + WithErrGroup(errgroup.Get()), + WithAutoIndexCheckDuration("30m"), + WithAutoSaveIndexDuration("35m"), + WithAutoIndexDurationLimit("24h"), + WithAutoIndexLength(100), + WithInitialDelayMaxDuration("3m"), + WithMinLoadIndexTimeout("3m"), + WithMaxLoadIndexTimeout("10m"), + WithLoadIndexTimeoutFactor("1ms"), + WithProactiveGC(true), +} + +// WithErrGroup returns the functional option to set the error group. +func WithErrGroup(eg errgroup.Group) Option { + return func(f *faiss) error { + if eg != nil { + f.eg = eg + } + + return nil + } +} + +// WithEnableInMemoryMode returns the functional option to set the in memory mode flag. +func WithEnableInMemoryMode(enabled bool) Option { + return func(f *faiss) error { + f.inMem = enabled + + return nil + } +} + +// WithIndexPath returns the functional option to set the index path of the Faiss. +func WithIndexPath(path string) Option { + return func(f *faiss) error { + if path == "" { + return nil + } + f.path = file.Join(strings.TrimSuffix(path, string(os.PathSeparator))) + return nil + } +} + +// WithAutoIndexCheckDuration returns the functional option to set the index check duration. +func WithAutoIndexCheckDuration(dur string) Option { + return func(f *faiss) error { + if dur == "" { + return nil + } + + d, err := timeutil.Parse(dur) + if err != nil { + return err + } + + f.dur = d + + return nil + } +} + +// WithAutoSaveIndexDuration returns the functional option to set the auto save index duration. +func WithAutoSaveIndexDuration(dur string) Option { + return func(f *faiss) error { + if dur == "" { + return nil + } + + d, err := timeutil.Parse(dur) + if err != nil { + return err + } + + f.sdur = d + + return nil + } +} + +// WithAutoIndexDurationLimit returns the functional option to set the auto index duration limit. +func WithAutoIndexDurationLimit(dur string) Option { + return func(f *faiss) error { + if dur == "" { + return nil + } + + d, err := timeutil.Parse(dur) + if err != nil { + return err + } + + f.lim = d + + return nil + } +} + +// WithAutoIndexLength returns the functional option to set the auto index length. +func WithAutoIndexLength(l int) Option { + return func(f *faiss) error { + f.alen = l + + return nil + } +} + +const ( + defaultDurationLimit float64 = 1.1 + defaultRandDuration int64 = 1 +) + +var ( + bigMaxFloat64 = big.NewFloat(math.MaxFloat64) + bigMinFloat64 = big.NewFloat(math.SmallestNonzeroFloat64) + bigMaxInt64 = big.NewInt(math.MaxInt64) + bigMinInt64 = big.NewInt(math.MinInt64) +) + +// WithInitialDelayMaxDuration returns the functional option to set the initial delay duration. 
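+// The configured duration acts as an upper bound: a random delay up to that value
+// (at a resolution one order below the duration's magnitude) is chosen, which is
+// presumably meant to jitter startup so that multiple agents do not trigger their
+// first CreateIndex at exactly the same time.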
+func WithInitialDelayMaxDuration(dur string) Option { + return func(f *faiss) error { + if dur == "" { + return nil + } + + d, err := timeutil.Parse(dur) + if err != nil { + return err + } + + var dt time.Duration + switch { + case d <= time.Nanosecond: + return nil + case d <= time.Microsecond: + dt = time.Nanosecond + case d <= time.Millisecond: + dt = time.Microsecond + case d <= time.Second: + dt = time.Millisecond + default: + dt = time.Second + } + + dbs := math.Round(float64(d) / float64(dt)) + bdbs := big.NewFloat(dbs) + if dbs <= 0 || bigMaxFloat64.Cmp(bdbs) <= 0 || bigMinFloat64.Cmp(bdbs) >= 0 { + dbs = defaultDurationLimit + } + + rnd := int64(rand.LimitedUint32(uint64(dbs))) + brnd := big.NewInt(rnd) + if rnd <= 0 || bigMaxInt64.Cmp(brnd) <= 0 || bigMinInt64.Cmp(brnd) >= 0 { + rnd = defaultRandDuration + } + + delay := time.Duration(rnd) * dt + if delay <= 0 || delay >= math.MaxInt64 || delay <= math.MinInt64 { + return WithInitialDelayMaxDuration(dur)(f) + } + + f.idelay = delay + + return nil + } +} + +// WithMinLoadIndexTimeout returns the functional option to set the minimal load index timeout. +func WithMinLoadIndexTimeout(dur string) Option { + return func(f *faiss) error { + if dur == "" { + return nil + } + + d, err := timeutil.Parse(dur) + if err != nil { + return err + } + + f.minLit = d + + return nil + } +} + +// WithMaxLoadIndexTimeout returns the functional option to set the maximum load index timeout. +func WithMaxLoadIndexTimeout(dur string) Option { + return func(f *faiss) error { + if dur == "" { + return nil + } + + d, err := timeutil.Parse(dur) + if err != nil { + return err + } + + f.maxLit = d + + return nil + } +} + +// WithLoadIndexTimeoutFactor returns the functional option to set the factor of load index timeout. +func WithLoadIndexTimeoutFactor(dur string) Option { + return func(f *faiss) error { + if dur == "" { + return nil + } + + d, err := timeutil.Parse(dur) + if err != nil { + return err + } + + f.litFactor = d + + return nil + } +} + +// WithProactiveGC returns the functional option to set the proactive GC enable flag. +func WithProactiveGC(enabled bool) Option { + return func(f *faiss) error { + f.enableProactiveGC = enabled + return nil + } +} + +// WithCopyOnWrite returns the functional option to set the CoW enable flag. +func WithCopyOnWrite(enabled bool) Option { + return func(f *faiss) error { + f.enableCopyOnWrite = enabled + return nil + } +} diff --git a/pkg/agent/core/faiss/usecase/agentd.go b/pkg/agent/core/faiss/usecase/agentd.go new file mode 100644 index 0000000000..12dcc430ad --- /dev/null +++ b/pkg/agent/core/faiss/usecase/agentd.go @@ -0,0 +1,193 @@ +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package usecase + +import ( + "context" + + agent "github.com/vdaas/vald/apis/grpc/v1/agent/core" + vald "github.com/vdaas/vald/apis/grpc/v1/vald" + iconf "github.com/vdaas/vald/internal/config" + "github.com/vdaas/vald/internal/errgroup" + "github.com/vdaas/vald/internal/net/grpc" + "github.com/vdaas/vald/internal/observability" + faissmetrics "github.com/vdaas/vald/internal/observability/metrics/agent/core/faiss" + infometrics "github.com/vdaas/vald/internal/observability/metrics/info" + "github.com/vdaas/vald/internal/runner" + "github.com/vdaas/vald/internal/safety" + "github.com/vdaas/vald/internal/servers/server" + "github.com/vdaas/vald/internal/servers/starter" + "github.com/vdaas/vald/pkg/agent/core/faiss/config" + handler "github.com/vdaas/vald/pkg/agent/core/faiss/handler/grpc" + "github.com/vdaas/vald/pkg/agent/core/faiss/handler/rest" + "github.com/vdaas/vald/pkg/agent/core/faiss/router" + "github.com/vdaas/vald/pkg/agent/core/faiss/service" +) + +type run struct { + eg errgroup.Group + cfg *config.Data + faiss service.Faiss + server starter.Server + observability observability.Observability +} + +func New(cfg *config.Data) (r runner.Runner, err error) { + faiss, err := service.New( + cfg.Faiss, + service.WithErrGroup(errgroup.Get()), + service.WithEnableInMemoryMode(cfg.Faiss.EnableInMemoryMode), + service.WithIndexPath(cfg.Faiss.IndexPath), + service.WithAutoIndexCheckDuration(cfg.Faiss.AutoIndexCheckDuration), + service.WithAutoSaveIndexDuration(cfg.Faiss.AutoSaveIndexDuration), + service.WithAutoIndexDurationLimit(cfg.Faiss.AutoIndexDurationLimit), + service.WithAutoIndexLength(cfg.Faiss.AutoIndexLength), + service.WithInitialDelayMaxDuration(cfg.Faiss.InitialDelayMaxDuration), + service.WithMinLoadIndexTimeout(cfg.Faiss.MinLoadIndexTimeout), + service.WithMaxLoadIndexTimeout(cfg.Faiss.MaxLoadIndexTimeout), + service.WithLoadIndexTimeoutFactor(cfg.Faiss.LoadIndexTimeoutFactor), + service.WithProactiveGC(cfg.Faiss.EnableProactiveGC), + service.WithCopyOnWrite(cfg.Faiss.EnableCopyOnWrite), + ) + if err != nil { + return nil, err + } + + g, err := handler.New( + handler.WithFaiss(faiss), + handler.WithStreamConcurrency(cfg.Server.GetGRPCStreamConcurrency()), + ) + if err != nil { + return nil, err + } + + eg := errgroup.Get() + + grpcServerOptions := []server.Option{ + server.WithGRPCRegistFunc(func(srv *grpc.Server) { + agent.RegisterAgentServer(srv, g) + vald.RegisterValdServer(srv, g) + }), + server.WithPreStartFunc(func() error { + return nil + }), + server.WithPreStopFunction(func() error { + return nil + }), + } + + var obs observability.Observability + if cfg.Observability != nil && cfg.Observability.Enabled { + obs, err = observability.NewWithConfig( + cfg.Observability, + faissmetrics.New(faiss), + infometrics.New("agent_core_faiss_info", "Agent Faiss info", *cfg.Faiss), + ) + if err != nil { + return nil, err + } + } + + srv, err := starter.New( + starter.WithConfig(cfg.Server), + starter.WithREST(func(sc *iconf.Server) []server.Option { + return []server.Option{ + server.WithHTTPHandler( + router.New( + router.WithTimeout(sc.HTTP.HandlerTimeout), + router.WithErrGroup(eg), + router.WithHandler( + rest.New( + rest.WithAgent(g), + ), + ), + ), + ), + } + }), + starter.WithGRPC(func(sc *iconf.Server) []server.Option { + return grpcServerOptions + }), + ) + if err != nil { + return nil, err + } + + return &run{ + eg: eg, + faiss: faiss, + cfg: cfg, + server: srv, + observability: obs, + }, nil +} + +func (r *run) PreStart(ctx context.Context) error { + if 
r.observability != nil { + return r.observability.PreStart(ctx) + } + + return nil +} + +func (r *run) Start(ctx context.Context) (<-chan error, error) { + ech := make(chan error, 3) + var oech, nech, sech <-chan error + r.eg.Go(safety.RecoverFunc(func() (err error) { + defer close(ech) + if r.observability != nil { + oech = r.observability.Start(ctx) + } + nech = r.faiss.Start(ctx) + sech = r.server.ListenAndServe(ctx) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case err = <-oech: + case err = <-nech: + case err = <-sech: + } + if err != nil { + select { + case <-ctx.Done(): + return ctx.Err() + case ech <- err: + } + } + } + })) + + return ech, nil +} + +func (r *run) PreStop(ctx context.Context) error { + return nil +} + +func (r *run) Stop(ctx context.Context) error { + if r.observability != nil { + r.observability.Stop(ctx) + } + + return r.server.Shutdown(ctx) +} + +func (r *run) PostStop(ctx context.Context) error { + r.faiss.Close(ctx) + return nil +} diff --git a/pkg/agent/internal/metadata/metadata.go b/pkg/agent/internal/metadata/metadata.go index 4c0225f98d..a40e1d8926 100644 --- a/pkg/agent/internal/metadata/metadata.go +++ b/pkg/agent/internal/metadata/metadata.go @@ -32,14 +32,19 @@ const ( ) type Metadata struct { - IsInvalid bool `json:"is_invalid" yaml:"is_invalid"` - NGT *NGT `json:"ngt,omitempty" yaml:"ngt"` + IsInvalid bool `json:"is_invalid" yaml:"is_invalid"` + NGT *NGT `json:"ngt,omitempty" yaml:"ngt"` + Faiss *Faiss `json:"faiss,omitempty" yaml:"faiss"` } type NGT struct { IndexCount uint64 `json:"index_count" yaml:"index_count"` } +type Faiss struct { + IndexCount uint64 `json:"index_count" yaml:"index_count"` +} + func Load(path string) (meta *Metadata, err error) { var fi os.FileInfo exists, fi, err := file.ExistsWithDetail(path) diff --git a/tests/e2e/crud/crud_faiss_test.go b/tests/e2e/crud/crud_faiss_test.go new file mode 100644 index 0000000000..548738a61a --- /dev/null +++ b/tests/e2e/crud/crud_faiss_test.go @@ -0,0 +1,331 @@ +//go:build e2e + +// +// Copyright (C) 2019-2023 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +// package crud provides e2e tests using ann-benchmarks datasets +package crud + +import ( + "context" + "flag" + "fmt" + "os" + "testing" + "time" + + "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/file" + "github.com/vdaas/vald/internal/net/grpc/codes" + "github.com/vdaas/vald/internal/net/grpc/status" + "github.com/vdaas/vald/tests/e2e/hdf5" + "github.com/vdaas/vald/tests/e2e/kubernetes/client" + "github.com/vdaas/vald/tests/e2e/kubernetes/portforward" + "github.com/vdaas/vald/tests/e2e/operation" +) + +var ( + host string + port int + ds *hdf5.Dataset + + insertNum int + searchNum int + searchByIDNum int + getObjectNum int + updateNum int + upsertNum int + removeNum int + + insertFrom int + searchFrom int + searchByIDFrom int + getObjectFrom int + updateFrom int + upsertFrom int + removeFrom int + + waitAfterInsertDuration time.Duration + + kubeClient client.Client + namespace string + + forwarder *portforward.Portforward +) + +func init() { + testing.Init() + + flag.StringVar(&host, "host", "localhost", "hostname") + flag.IntVar(&port, "port", 8081, "gRPC port") + + flag.IntVar(&insertNum, "insert-num", 10000, "number of id-vector pairs used for insert") + flag.IntVar(&searchNum, "search-num", 10000, "number of id-vector pairs used for search") + flag.IntVar(&updateNum, "update-num", 10000, "number of id-vector pairs used for update") + flag.IntVar(&removeNum, "remove-num", 10000, "number of id-vector pairs used for remove") + + flag.IntVar(&insertFrom, "insert-from", 0, "first index of id-vector pairs used for insert") + flag.IntVar(&searchFrom, "search-from", 0, "first index of id-vector pairs used for search") + flag.IntVar(&updateFrom, "update-from", 0, "first index of id-vector pairs used for update") + flag.IntVar(&removeFrom, "remove-from", 0, "first index of id-vector pairs used for remove") + + datasetName := flag.String("dataset", "fashion-mnist-784-euclidean.hdf5", "dataset") + waitAfterInsert := flag.String("wait-after-insert", "3m", "wait duration after inserting vectors") + + pf := flag.Bool("portforward", false, "enable port forwarding") + pfPodName := flag.String("portforward-pod-name", "vald-gateway-0", "pod name (only for port forward)") + pfPodPort := flag.Int("portforward-pod-port", port, "pod gRPC port (only for port forward)") + + kubeConfig := flag.String("kubeconfig", file.Join(os.Getenv("HOME"), ".kube", "config"), "kubeconfig path") + flag.StringVar(&namespace, "namespace", "default", "namespace") + + flag.Parse() + + var err error + if *pf { + kubeClient, err = client.New(*kubeConfig) + if err != nil { + panic(err) + } + + forwarder = kubeClient.Portforward(namespace, *pfPodName, port, *pfPodPort) + + err = forwarder.Start() + if err != nil { + panic(err) + } + } + + fmt.Printf("loading dataset: %s ", *datasetName) + ds, err = hdf5.HDF5ToDataset(*datasetName) + if err != nil { + panic(err) + } + fmt.Println("loading finished") + + waitAfterInsertDuration, err = time.ParseDuration(*waitAfterInsert) + if err != nil { + panic(err) + } +} + +func teardown() { + if forwarder != nil { + forwarder.Close() + } +} + +func sleep(t *testing.T, dur time.Duration) { + t.Logf("%v sleep for %s.", time.Now(), dur) + time.Sleep(dur) + t.Logf("%v sleep finished.", time.Now()) +} + +func TestE2EInsertOnly(t *testing.T) { + t.Cleanup(teardown) + ctx := context.Background() + + op, err := operation.New(host, port) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + err = op.Insert(t, ctx, operation.Dataset{ + Train: 
ds.Train[insertFrom : insertFrom+insertNum], + }) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } +} + +func TestE2ESearchOnly(t *testing.T) { + t.Cleanup(teardown) + ctx := context.Background() + + op, err := operation.New(host, port) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + err = op.Search(t, ctx, operation.Dataset{ + Test: ds.Test[searchFrom : searchFrom+searchNum], + Neighbors: ds.Neighbors[searchFrom : searchFrom+searchNum], + }) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } +} + +func TestE2EUpdateOnly(t *testing.T) { + t.Cleanup(teardown) + ctx := context.Background() + + op, err := operation.New(host, port) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + err = op.UpdateWithParameters( + t, + ctx, + operation.Dataset{ + Train: ds.Train[updateFrom : updateFrom+updateNum], + }, + true, + 1, + func(t *testing.T, status int32, msg string) error { + t.Helper() + + if status != int32(codes.NotFound) { + return errors.Errorf("the returned status is not NotFound on Update #1: %s", err) + } + + t.Logf("received a NotFound error on #1: %s", msg) + + return nil + }, + func(t *testing.T, err error) error { + t.Helper() + + st, _, _ := status.ParseError(err, codes.Unknown, "") + if st.Code() != codes.NotFound { + return err + } + + return nil + }, + ) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } +} + +func TestE2ERemoveOnly(t *testing.T) { + t.Cleanup(teardown) + ctx := context.Background() + + op, err := operation.New(host, port) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + err = op.Remove(t, ctx, operation.Dataset{ + Train: ds.Train[removeFrom : removeFrom+removeNum], + }) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } +} + +func TestE2EInsertAndSearch(t *testing.T) { + t.Cleanup(teardown) + ctx := context.Background() + + op, err := operation.New(host, port) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + err = op.Insert(t, ctx, operation.Dataset{ + Train: ds.Train[insertFrom : insertFrom+insertNum], + }) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + sleep(t, waitAfterInsertDuration) + + err = op.Search(t, ctx, operation.Dataset{ + Test: ds.Test[searchFrom : searchFrom+searchNum], + Neighbors: ds.Neighbors[searchFrom : searchFrom+searchNum], + }) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } +} + +func TestE2EStandardCRUD(t *testing.T) { + t.Cleanup(teardown) + ctx := context.Background() + + op, err := operation.New(host, port) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + err = op.Insert(t, ctx, operation.Dataset{ + Train: ds.Train[insertFrom : insertFrom+insertNum], + }) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + sleep(t, waitAfterInsertDuration) + + err = op.Search(t, ctx, operation.Dataset{ + Test: ds.Test[searchFrom : searchFrom+searchNum], + Neighbors: ds.Neighbors[searchFrom : searchFrom+searchNum], + }) + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + err = op.Exists(t, ctx, "0") + if err != nil { + t.Fatalf("an error occurred: %s", err) + } + + err = op.UpdateWithParameters( + t, + ctx, + operation.Dataset{ + Train: ds.Train[updateFrom : updateFrom+updateNum], + }, + true, + 1, + func(t *testing.T, status int32, msg string) error { + t.Helper() + + if status != int32(codes.NotFound) { + return errors.Errorf("the returned status is not NotFound on Update #1: %s", err) + } + + t.Logf("received a NotFound error on #1: %s", msg) + + 
+			return nil
+		},
+		func(t *testing.T, err error) error {
+			t.Helper()
+
+			st, _, _ := status.ParseError(err, codes.Unknown, "")
+			if st.Code() != codes.NotFound {
+				return err
+			}
+
+			return nil
+		},
+	)
+	if err != nil {
+		t.Fatalf("an error occurred: %s", err)
+	}
+
+	err = op.Remove(t, ctx, operation.Dataset{
+		Train: ds.Train[removeFrom : removeFrom+removeNum],
+	})
+	if err != nil {
+		t.Fatalf("an error occurred: %s", err)
+	}
+}
diff --git a/versions/FAISS_VERSION b/versions/FAISS_VERSION
new file mode 100644
index 0000000000..661e7aeadf
--- /dev/null
+++ b/versions/FAISS_VERSION
@@ -0,0 +1 @@
+1.7.3