diff --git a/embedder/llamacpp/llamacpp.go b/embedder/llamacpp/llamacpp.go
index c7d7868a..632a124e 100644
--- a/embedder/llamacpp/llamacpp.go
+++ b/embedder/llamacpp/llamacpp.go
@@ -2,10 +2,10 @@ package llamacppembedder
 
 import (
 	"context"
+	"encoding/json"
+	"errors"
 	"os"
 	"os/exec"
-	"strconv"
-	"strings"
 
 	"github.com/henomis/lingoose/embedder"
 )
@@ -16,6 +16,21 @@ type LlamaCppEmbedder struct {
 	modelPath string
 }
 
+// output is the top-level JSON shape that parseEmbeddings decodes
+// (presumably what the llama.cpp binary prints when run with
+// "--embd-output-format json"; confirm against the llama.cpp version in use).
+type output struct {
+	Object string `json:"object"`
+	Data   []data `json:"data"`
+}
+
+// data is one element of output.Data and holds a single embedding vector.
+type data struct {
+	Object    string    `json:"object"`
+	Index     int       `json:"index"`
+	Embedding []float64 `json:"embedding"`
+}
+
 func New() *LlamaCppEmbedder {
 	return &LlamaCppEmbedder{
 		llamacppPath: "./llama.cpp/embedding",
@@ -61,7 +76,7 @@ func (l *LlamaCppEmbedder) embed(ctx context.Context, text string) (embedder.Emb
 		return nil, err
 	}
 
-	llamacppArgs := []string{"-m", l.modelPath, "-p", text}
+	llamacppArgs := []string{"-m", l.modelPath, "--embd-output-format", "json", "-p", text}
 	llamacppArgs = append(llamacppArgs, l.llamacppArgs...)
 
 	//nolint:gosec
@@ -74,14 +89,23 @@ func (l *LlamaCppEmbedder) embed(ctx context.Context, text string) (embedder.Emb
 }
 
+// parseEmbeddings decodes str as a JSON document of the output shape and
+// returns the embedding of its only data entry. It returns an error when str
+// is not valid JSON for that shape or when the number of entries is not
+// exactly one.
 func parseEmbeddings(str string) (embedder.Embedding, error) {
-	strSlice := strings.Split(strings.TrimSpace(str), " ")
-	floatSlice := make([]float64, len(strSlice))
-	for i, s := range strSlice {
-		f, err := strconv.ParseFloat(s, 64)
-		if err != nil {
-			return nil, err
-		}
-		floatSlice[i] = f
+	var parsed output
+	if err := json.Unmarshal([]byte(str), &parsed); err != nil {
+		return nil, err
 	}
-	return floatSlice, nil
+
+	// Report the empty and the ambiguous case separately so the message
+	// matches what was actually decoded.
+	switch len(parsed.Data) {
+	case 0:
+		return nil, errors.New("no embeddings found")
+	case 1:
+		return parsed.Data[0].Embedding, nil
+	default:
+		return nil, errors.New("multiple embeddings found")
+	}
 }