Skip to content

Commit

Permalink
Add Ollama embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
philippgille committed Mar 2, 2024
1 parent cabf8cc commit 77245c1
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions embed_ollama.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package chromem

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
)

const baseURLOllama = "http://localhost:11434/api"

type ollamaResponse struct {
Embedding []float32 `json:"embedding"`
}

// NewEmbeddingFuncOllama returns a function that creates embeddings for a document
// using Ollama's embedding API. You can pass any model that Ollama supports and
// that supports embeddings. A good one as of 2024-03-02 is "nomic-embed-text".
// See https://ollama.com/library/nomic-embed-text
func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
// and it might have to be a long timeout, depending on the document size.
client := &http.Client{}

return func(ctx context.Context, document string) ([]float32, error) {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
"model": model,
"prompt": document,
})
if err != nil {
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
}

// Create the request. Creating it with context is important for a timeout
// to be possible, because the client is configured without a timeout.
req, err := http.NewRequestWithContext(ctx, "POST", baseURLOllama+"/embeddings", bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("couldn't create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

// Send the request.
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("couldn't send request: %w", err)
}
defer resp.Body.Close()

// Check the response status.
if resp.StatusCode != http.StatusOK {
return nil, errors.New("error response from the embedding API: " + resp.Status)
}

// Read and decode the response body.
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("couldn't read response body: %w", err)
}
var embeddingResponse ollamaResponse
err = json.Unmarshal(body, &embeddingResponse)
if err != nil {
return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
}

// Check if the response contains embeddings.
if len(embeddingResponse.Embedding) == 0 {
return nil, errors.New("no embeddings found in the response")
}

return embeddingResponse.Embedding, nil
}
}

0 comments on commit 77245c1

Please sign in to comment.