From 0f5446778a539009b97554f1dd4d3e24aa62ed3d Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Sun, 10 Mar 2024 23:59:11 +0200 Subject: [PATCH] feat: Clone with Embedding Function --- .gitignore | 1 + cmd/collection.go | 19 +++++++++++++------ cmd/collection_test.go | 32 ++++++++++++++++++++++++++++---- cmd/utils.go | 25 +++++++++++++++++++++++++ docs/docs/index.md | 2 +- go.mod | 1 + go.sum | 2 ++ 7 files changed, 71 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index bc99e7d..8f0b79b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ *.test *.out go.work +.env diff --git a/cmd/collection.go b/cmd/collection.go index 2d1e9b4..d1b3123 100644 --- a/cmd/collection.go +++ b/cmd/collection.go @@ -318,7 +318,6 @@ func cloneCollection(cmd *cobra.Command, args []string) error { cmd.Printf("invalid meta: %v\n", err) return err } - fmt.Printf("metadatasVar: %v\n", *metadatasVar) var metadatasVal = make(map[string]interface{}) for k, v := range sourceCollection.Metadata { if k == types.HNSWSpace || k == types.HNSWM || k == types.HNSWConstructionEF || k == types.HNSWSearchEF || k == types.HNSWBatchSize || k == types.HNSWSyncThreshold || k == types.HNSWNumThreads || k == types.HNSWResizeFactor { @@ -327,11 +326,8 @@ func cloneCollection(cmd *cobra.Command, args []string) error { metadatasVal[k] = v } - fmt.Printf("metadatasVal: %v\n", metadatasVal) - if cmd.Flag("meta").Changed { for _, meta := range *metadatasVar { - fmt.Printf("meta: %v\n", meta) kvPair := strings.Split(meta, "=") if len(kvPair) != 2 { cmd.Printf("invalid metadata format: %v. should be key=value.", meta) @@ -356,6 +352,14 @@ func cloneCollection(cmd *cobra.Command, args []string) error { } else { collectionOptions = append(collectionOptions, collection.WithHNSWDistanceFunction(df)) } + var hasEf = false + if efVal, err := embeddingFunctionForString(cmd.Flags().GetString("embedding-function")); err != nil { + cmd.Printf("invalid embedding-function: %v\n", err) + return err + } else if efVal != nil { + hasEf = true + collectionOptions = append(collectionOptions, collection.WithEmbeddingFunction(efVal)) + } if mVal != nil { collectionOptions = append(collectionOptions, collection.WithHNSWM(int32(*mVal))) @@ -407,8 +411,10 @@ func cloneCollection(cmd *cobra.Command, args []string) error { cmd.Printf("%v\n", err) return err } - - _embeddings := result.Embeddings + var _embeddings []*types.Embedding + if !hasEf { + _embeddings = result.Embeddings + } _, err = targetCollection.Add(context.TODO(), _embeddings, result.Metadatas, result.Documents, result.Ids) if err != nil { // TODO not great to exit on first error but for now that will do. Consider rollback? cmd.Printf("%v\n", err) @@ -560,6 +566,7 @@ func init() { CloneCollectionCommand.Flags().IntP("sync-threshold", "k", 1000, "hnsw:sync_threshold - The number of elements added to the HNSW index before the index is synced to disk.") CloneCollectionCommand.Flags().IntP("threads", "n", -1, "hnsw:threads - The number of threads to use during index construction and searches. Defaults to the number of logical cores on the machine.") CloneCollectionCommand.Flags().Float32P("resize-factor", "r", 1.2, "hnsw:resize_factor - This parameter is used by HNSW's hierarchical layers during insertion..") + CloneCollectionCommand.Flags().StringP("embedding-function", "e", "", "The name of the embedding function to use for the target collection") CloneCollectionCommand.Flags().StringSliceVarP(&metaSlice, "meta", "a", []string{}, "Defines a single key-value attribute (KVP) to added to collection metadata.") rootCmd.AddCommand(CloneCollectionCommand) } diff --git a/cmd/collection_test.go b/cmd/collection_test.go index 5e4ba5d..f8bb838 100644 --- a/cmd/collection_test.go +++ b/cmd/collection_test.go @@ -9,6 +9,7 @@ import ( "strconv" "testing" + "github.com/joho/godotenv" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -106,10 +107,6 @@ func addDummyRecordsToCollection(t *testing.T, client *chroma.Client, collection } func TestCreateCollectionCommand(t *testing.T) { command := rootCmd - err := os.Setenv("TEST", "1") - if err != nil { - return - } t.Run("Create Collection basic", func(t *testing.T) { client := setup() @@ -966,6 +963,33 @@ func TestCloneCollectionCommand(t *testing.T) { require.Contains(t, output, targetCollectionName) require.Contains(t, output, "10") }) + + t.Run("Clone Collection with openai ef", func(t *testing.T) { + _ = godotenv.Load("../.env") + if os.Getenv("OPENAI_API_KEY") == "" { + t.Skip("Skipping test as OPENAI_API_KEY is not set") + } + resetCloneCommandFlags() + client := setup() + defer tearDown(client) + var sourceCollectionName = "my-new-collection" + strconv.Itoa(rand.Int()) + var targetCollectionName = "my-new-collection-copy" + strconv.Itoa(rand.Int()) + helperCreateCollection(t, client, sourceCollectionName) + addDummyRecordsToCollection(t, client, sourceCollectionName, 10) + buf := new(bytes.Buffer) + command.SetOut(buf) + command.SetErr(buf) + command.SetArgs([]string{"clone", sourceCollectionName, targetCollectionName, "-e", "openai"}) + _, err := command.ExecuteC() + require.NoError(t, err) + output := buf.String() + assertCollectionExists(t, client, sourceCollectionName) + assertCollectionExists(t, client, targetCollectionName) + require.Contains(t, output, "successfully cloned") + require.Contains(t, output, sourceCollectionName) + require.Contains(t, output, targetCollectionName) + require.Contains(t, output, "10") + }) } func TestCloneCollectionCommandWithSource(t *testing.T) { command := rootCmd diff --git a/cmd/utils.go b/cmd/utils.go index 57cf8d7..6b3c4f0 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -3,13 +3,38 @@ package cmd import ( "context" "fmt" + "os" "chroma/utils" "github.com/spf13/viper" "github.com/amikos-tech/chroma-go" + "github.com/amikos-tech/chroma-go/cohere" + "github.com/amikos-tech/chroma-go/hf" + "github.com/amikos-tech/chroma-go/openai" + "github.com/amikos-tech/chroma-go/types" ) +func embeddingFunctionForString(embedder string, err error) (types.EmbeddingFunction, error) { + if err != nil { + return nil, err + } + if embedder == "" { + return nil, nil // TODO this is rather a hack + } + switch embedder { + case "openai": + return openai.NewOpenAIEmbeddingFunction(os.Getenv("OPENAI_API_KEY")) + case "cohere": + return cohere.NewCohereEmbeddingFunction(os.Getenv("COHERE_API_KEY")), nil + case "hf": + return hf.NewHuggingFaceEmbeddingFunction(os.Getenv("HF_API_KEY"), os.Getenv("HF_MODEL")), nil + case "hash": // dummy embedding function + return types.NewConsistentHashEmbeddingFunction(), nil + default: + return nil, fmt.Errorf("embedding function not found") + } +} func getClient(serverAlias string) (*chroma.Client, error) { var serverConfig map[string]interface{} var err error diff --git a/docs/docs/index.md b/docs/docs/index.md index fbf8ec5..64573ac 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -111,7 +111,7 @@ chroma clone \ -k/--sync-threshold \ -n/--threads \ -r/--resize-factor \ - -e/--ensure + --embedding-function/-e ``` All flags are optional and applied to the target collection. diff --git a/go.mod b/go.mod index 9b972e4..db61be1 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/amikos-tech/chroma-go v0.0.1 github.com/charmbracelet/huh v0.3.0 github.com/go-playground/validator/v10 v10.19.0 + github.com/joho/godotenv v1.5.1 github.com/mitchellh/go-homedir v1.1.0 github.com/spf13/cobra v1.8.0 github.com/spf13/viper v1.18.2 diff --git a/go.sum b/go.sum index 43ca65b..a2f794d 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=