Skip to content

Commit

Permalink
Normalize vectors on embedding creation instead of querying
Browse files Browse the repository at this point in the history
- Normalizes only once instead of each time
- Embedding creation takes time anyway, while query should be as fast as possible
  • Loading branch information
philippgille committed Mar 16, 2024
1 parent 05c4f76 commit 579fd46
Show file tree
Hide file tree
Showing 13 changed files with 105 additions and 27 deletions.
6 changes: 5 additions & 1 deletion collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,17 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
m[k] = v
}

// Create embedding if they don't exist
// Create embedding if they don't exist, otherwise normalize if necessary
if len(doc.Embedding) == 0 {
embedding, err := c.embed(ctx, doc.Content)
if err != nil {
return fmt.Errorf("couldn't create embedding of document: %w", err)
}
doc.Embedding = embedding
} else {
if !isNormalized(doc.Embedding) {
doc.Embedding = normalizeVector(doc.Embedding)
}
}

c.documentsLock.Lock()
Expand Down
11 changes: 6 additions & 5 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestCollection_Add(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestCollection_Add_Error(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -160,7 +160,7 @@ func TestCollection_AddConcurrently(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -256,7 +256,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -313,8 +313,9 @@ func TestCollection_Count(t *testing.T) {
db := NewDB()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return []float32{-0.1, 0.1, 0.2}, nil
return vectors, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import (
// EmbeddingFunc is a function that creates embeddings for a given text.
// chromem-go will use OpenAI`s "text-embedding-3-small" model by default,
// but you can provide your own function, using any model you like.
// The function must return a *normalized* vector, i.e. the length of the vector
// must be 1. OpenAI's and Mistral's embedding models do this by default. Some
// others like Nomic's "nomic-embed-text-v1.5" don't.
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)

// DB is the chromem-go database. It holds collections, which hold documents.
Expand Down
12 changes: 6 additions & 6 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func TestDB_CreateCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -81,7 +81,7 @@ func TestDB_ListCollections(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -147,7 +147,7 @@ func TestDB_GetCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -299,7 +299,7 @@ func TestDB_DeleteCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -331,7 +331,7 @@ func TestDB_Reset(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down
2 changes: 1 addition & 1 deletion document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func TestDocument_New(t *testing.T) {
ctx := context.Background()
id := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
content := "hello world"
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
Expand Down
12 changes: 8 additions & 4 deletions embed_compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@ const (
// NewEmbeddingFuncMistral returns a function that creates embeddings for a text
// using the Mistral API.
func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
// Mistral embeddings are normalized, see section "Distance Measures" on
// https://docs.mistral.ai/guides/embeddings/.
normalized := true

// The Mistral API docs don't mention the `encoding_format` as optional,
// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral)
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized)
}

const baseURLJina = "https://api.jina.ai/v1"
Expand All @@ -28,7 +32,7 @@ const (
// NewEmbeddingFuncJina returns a function that creates embeddings for a text
// using the Jina API.
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil)
}

const baseURLMixedbread = "https://api.mixedbread.ai"
Expand All @@ -49,7 +53,7 @@ const (
// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
// using the mixedbread.ai API.
func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil)
}

const baseURLLocalAI = "http://localhost:8080/v1"
Expand All @@ -64,5 +68,5 @@ const baseURLLocalAI = "http://localhost:8080/v1"
// But other embedding models are supported as well. See the LocalAI documentation
// for details.
func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model)
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
}
18 changes: 17 additions & 1 deletion embed_ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"sync"
)

// TODO: Turn into const and use as default, but allow user to pass custom URL
Expand All @@ -28,6 +29,9 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
// and it might have to be a long timeout, depending on the text length.
client := &http.Client{}

var checkedNormalized bool
checkNormalized := sync.Once{}

return func(ctx context.Context, text string) ([]float32, error) {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
Expand Down Expand Up @@ -74,6 +78,18 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
return nil, errors.New("no embeddings found in the response")
}

return embeddingResponse.Embedding, nil
v := embeddingResponse.Embedding
checkNormalized.Do(func() {
if isNormalized(v) {
checkedNormalized = true
} else {
checkedNormalized = false
}
})
if !checkedNormalized {
v = normalizeVector(v)
}

return v, nil
}
}
2 changes: 1 addition & 1 deletion embed_ollama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestNewEmbeddingFuncOllama(t *testing.T) {
if err != nil {
t.Fatal("unexpected error:", err)
}
wantRes := []float32{-0.1, 0.1, 0.2}
wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`

// Mock server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
35 changes: 32 additions & 3 deletions embed_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net/http"
"os"
"sync"
)

const BaseURLOpenAI = "https://api.openai.com/v1"
Expand Down Expand Up @@ -39,7 +40,9 @@ func NewEmbeddingFuncDefault() EmbeddingFunc {
// NewEmbeddingFuncOpenAI returns a function that creates embeddings for a text
// using the OpenAI API.
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model))
// OpenAI embeddings are normalized
normalized := true
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
}

// NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
Expand All @@ -48,12 +51,20 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding
// - LitLLM: https://github.com/BerriAI/litellm
// - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md
// - etc.
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc {
//
// The `normalized` parameter indicates whether the vectors returned by the embedding
// model are already normalized, as is the case for OpenAI's and Mistral's models.
// The flag is optional. If it's nil, it will be autodetected on the first request
// (which bears a small risk that the vector just happens to have a length of 1).
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc {
// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
// and it might have to be a long timeout, depending on the text length.
client := &http.Client{}

var checkedNormalized bool
checkNormalized := sync.Once{}

return func(ctx context.Context, text string) ([]float32, error) {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
Expand Down Expand Up @@ -101,6 +112,24 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc {
return nil, errors.New("no embeddings found in the response")
}

return embeddingResponse.Data[0].Embedding, nil
v := embeddingResponse.Data[0].Embedding
if normalized != nil {
if *normalized {
return v, nil
}
return normalizeVector(v), nil
}
checkNormalized.Do(func() {
if isNormalized(v) {
checkedNormalized = true
} else {
checkedNormalized = false
}
})
if !checkedNormalized {
v = normalizeVector(v)
}

return v, nil
}
}
4 changes: 2 additions & 2 deletions embed_openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
if err != nil {
t.Fatal("unexpected error:", err)
}
wantRes := []float32{-0.1, 0.1, 0.2}
wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`

// Mock server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -75,7 +75,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
defer ts.Close()
baseURL := ts.URL + baseURLSuffix

f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model)
f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil)
res, err := f(context.Background(), input)
if err != nil {
t.Fatal("expected nil, got", err)
Expand Down
2 changes: 1 addition & 1 deletion persistence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestPersistence(t *testing.T) {
}
obj := s{
Foo: "test",
Bar: []float32{-0.1, 0.1, 0.2},
Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}`
}

persist(tempDir, obj)
Expand Down
3 changes: 2 additions & 1 deletion query.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu
return
}

sim, err := cosineSimilarity(queryVectors, doc.Embedding)
// As the vectors are normalized, the dot product is the cosine similarity.
sim, err := dotProduct(queryVectors, doc.Embedding)
if err != nil {
setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
return
Expand Down
22 changes: 21 additions & 1 deletion vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chromem

import (
"errors"
"fmt"
"math"
)

Expand All @@ -20,11 +21,30 @@ func cosineSimilarity(a, b []float32) (float32, error) {
if !isNormalized(a) || !isNormalized(b) {
a, b = normalizeVector(a), normalizeVector(b)
}
dotProduct, err := dotProduct(a, b)
if err != nil {
return 0, fmt.Errorf("couldn't calculate dot product: %w", err)
}

// Vectors are already normalized, so no need to divide by magnitudes

return dotProduct, nil
}

// dotProduct calculates the dot product between two vectors.
// It's the same as cosine similarity for normalized vectors.
// The resulting value represents the similarity, so a higher value means the
// vectors are more similar.
func dotProduct(a, b []float32) (float32, error) {
// The vectors must have the same length
if len(a) != len(b) {
return 0, errors.New("vectors must have the same length")
}

var dotProduct float32
for i := range a {
dotProduct += a[i] * b[i]
}
// Vectors are already normalized, so no need to divide by magnitudes

return dotProduct, nil
}
Expand Down

0 comments on commit 579fd46

Please sign in to comment.