Skip to content

Commit

Permalink
feat: Clone with Embedding Function
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Mar 10, 2024
1 parent b7c28f4 commit 0f54467
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
*.test
*.out
go.work
.env
19 changes: 13 additions & 6 deletions cmd/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ func cloneCollection(cmd *cobra.Command, args []string) error {
cmd.Printf("invalid meta: %v\n", err)
return err
}
fmt.Printf("metadatasVar: %v\n", *metadatasVar)
var metadatasVal = make(map[string]interface{})
for k, v := range sourceCollection.Metadata {
if k == types.HNSWSpace || k == types.HNSWM || k == types.HNSWConstructionEF || k == types.HNSWSearchEF || k == types.HNSWBatchSize || k == types.HNSWSyncThreshold || k == types.HNSWNumThreads || k == types.HNSWResizeFactor {
Expand All @@ -327,11 +326,8 @@ func cloneCollection(cmd *cobra.Command, args []string) error {
metadatasVal[k] = v
}

fmt.Printf("metadatasVal: %v\n", metadatasVal)

if cmd.Flag("meta").Changed {
for _, meta := range *metadatasVar {
fmt.Printf("meta: %v\n", meta)
kvPair := strings.Split(meta, "=")
if len(kvPair) != 2 {
cmd.Printf("invalid metadata format: %v. should be key=value.", meta)
Expand All @@ -356,6 +352,14 @@ func cloneCollection(cmd *cobra.Command, args []string) error {
} else {
collectionOptions = append(collectionOptions, collection.WithHNSWDistanceFunction(df))
}
var hasEf = false
if efVal, err := embeddingFunctionForString(cmd.Flags().GetString("embedding-function")); err != nil {
cmd.Printf("invalid embedding-function: %v\n", err)
return err
} else if efVal != nil {
hasEf = true
collectionOptions = append(collectionOptions, collection.WithEmbeddingFunction(efVal))
}

if mVal != nil {
collectionOptions = append(collectionOptions, collection.WithHNSWM(int32(*mVal)))
Expand Down Expand Up @@ -407,8 +411,10 @@ func cloneCollection(cmd *cobra.Command, args []string) error {
cmd.Printf("%v\n", err)
return err
}

_embeddings := result.Embeddings
var _embeddings []*types.Embedding
if !hasEf {
_embeddings = result.Embeddings
}
_, err = targetCollection.Add(context.TODO(), _embeddings, result.Metadatas, result.Documents, result.Ids)
if err != nil { // TODO not great to exit on first error but for now that will do. Consider rollback?
cmd.Printf("%v\n", err)
Expand Down Expand Up @@ -560,6 +566,7 @@ func init() {
CloneCollectionCommand.Flags().IntP("sync-threshold", "k", 1000, "hnsw:sync_threshold - The number of elements added to the HNSW index before the index is synced to disk.")
CloneCollectionCommand.Flags().IntP("threads", "n", -1, "hnsw:threads - The number of threads to use during index construction and searches. Defaults to the number of logical cores on the machine.")
CloneCollectionCommand.Flags().Float32P("resize-factor", "r", 1.2, "hnsw:resize_factor - This parameter is used by HNSW's hierarchical layers during insertion..")
CloneCollectionCommand.Flags().StringP("embedding-function", "e", "", "The name of the embedding function to use for the target collection")
CloneCollectionCommand.Flags().StringSliceVarP(&metaSlice, "meta", "a", []string{}, "Defines a single key-value attribute (KVP) to added to collection metadata.")
rootCmd.AddCommand(CloneCollectionCommand)
}
32 changes: 28 additions & 4 deletions cmd/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strconv"
"testing"

"github.com/joho/godotenv"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -106,10 +107,6 @@ func addDummyRecordsToCollection(t *testing.T, client *chroma.Client, collection
}
func TestCreateCollectionCommand(t *testing.T) {
command := rootCmd
err := os.Setenv("TEST", "1")
if err != nil {
return
}

t.Run("Create Collection basic", func(t *testing.T) {
client := setup()
Expand Down Expand Up @@ -966,6 +963,33 @@ func TestCloneCollectionCommand(t *testing.T) {
require.Contains(t, output, targetCollectionName)
require.Contains(t, output, "10")
})

t.Run("Clone Collection with openai ef", func(t *testing.T) {
_ = godotenv.Load("../.env")
if os.Getenv("OPENAI_API_KEY") == "" {
t.Skip("Skipping test as OPENAI_API_KEY is not set")
}
resetCloneCommandFlags()
client := setup()
defer tearDown(client)
var sourceCollectionName = "my-new-collection" + strconv.Itoa(rand.Int())
var targetCollectionName = "my-new-collection-copy" + strconv.Itoa(rand.Int())
helperCreateCollection(t, client, sourceCollectionName)
addDummyRecordsToCollection(t, client, sourceCollectionName, 10)
buf := new(bytes.Buffer)
command.SetOut(buf)
command.SetErr(buf)
command.SetArgs([]string{"clone", sourceCollectionName, targetCollectionName, "-e", "openai"})
_, err := command.ExecuteC()
require.NoError(t, err)
output := buf.String()
assertCollectionExists(t, client, sourceCollectionName)
assertCollectionExists(t, client, targetCollectionName)
require.Contains(t, output, "successfully cloned")
require.Contains(t, output, sourceCollectionName)
require.Contains(t, output, targetCollectionName)
require.Contains(t, output, "10")
})
}
func TestCloneCollectionCommandWithSource(t *testing.T) {
command := rootCmd
Expand Down
25 changes: 25 additions & 0 deletions cmd/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,38 @@ package cmd
import (
"context"
"fmt"
"os"

"chroma/utils"
"github.com/spf13/viper"

"github.com/amikos-tech/chroma-go"
"github.com/amikos-tech/chroma-go/cohere"
"github.com/amikos-tech/chroma-go/hf"
"github.com/amikos-tech/chroma-go/openai"
"github.com/amikos-tech/chroma-go/types"
)

// embeddingFunctionForString resolves the embedding function identified by
// the short name embedder.
//
// The (embedder, err) parameter pair mirrors the result of a flag lookup such
// as cmd.Flags().GetString("embedding-function"), so the call can be nested
// directly; a non-nil err is returned unchanged and embedder is ignored.
//
// Recognised names and the environment variables they read:
//   - "openai": OPENAI_API_KEY
//   - "cohere": COHERE_API_KEY
//   - "hf":     HF_API_KEY and HF_MODEL
//   - "hash":   none (dummy embedding function)
//
// An empty name returns (nil, nil), meaning "no embedding function requested".
// Any other name returns an error that includes the rejected name.
func embeddingFunctionForString(embedder string, err error) (types.EmbeddingFunction, error) {
	if err != nil {
		return nil, err
	}
	switch embedder {
	case "":
		return nil, nil // TODO this is rather a hack
	case "openai":
		// Check the error explicitly and return a literal nil on failure, so the
		// caller never receives a non-nil interface wrapping a nil value
		// (presumably the constructor returns a concrete pointer type — confirm).
		ef, efErr := openai.NewOpenAIEmbeddingFunction(os.Getenv("OPENAI_API_KEY"))
		if efErr != nil {
			return nil, efErr
		}
		return ef, nil
	case "cohere":
		return cohere.NewCohereEmbeddingFunction(os.Getenv("COHERE_API_KEY")), nil
	case "hf":
		return hf.NewHuggingFaceEmbeddingFunction(os.Getenv("HF_API_KEY"), os.Getenv("HF_MODEL")), nil
	case "hash": // dummy embedding function
		return types.NewConsistentHashEmbeddingFunction(), nil
	default:
		// Name the offending value so the user can see what was rejected.
		return nil, fmt.Errorf("embedding function not found: %q", embedder)
	}
}
func getClient(serverAlias string) (*chroma.Client, error) {

Check failure on line 38 in cmd/utils.go

View workflow job for this annotation

GitHub Actions / lint

undefined: chroma (typecheck)
var serverConfig map[string]interface{}
var err error
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ chroma clone <collection-name> <target-collection>\
-k/--sync-threshold <hnsw:sync_threshold> \
-n/--threads <hnsw:threads> \
-r/--resize-factor <hnsw:resize_factor> \
-e/--ensure <create_if_not_exist>
-e/--embedding-function <embedding-function>
```

All flags are optional and applied to the target collection.
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/amikos-tech/chroma-go v0.0.1
github.com/charmbracelet/huh v0.3.0
github.com/go-playground/validator/v10 v10.19.0
github.com/joho/godotenv v1.5.1
github.com/mitchellh/go-homedir v1.1.0
github.com/spf13/cobra v1.8.0
github.com/spf13/viper v1.18.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
Expand Down

0 comments on commit 0f54467

Please sign in to comment.