Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement embeddings observer #195

Merged
merged 7 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions embedder/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"github.com/henomis/cohere-go/model"
"github.com/henomis/cohere-go/request"
"github.com/henomis/cohere-go/response"

"github.com/henomis/lingoose/embedder"
embobserver "github.com/henomis/lingoose/embedder/observer"
"github.com/henomis/lingoose/observer"
)

type EmbedderModel = model.EmbedModel
Expand Down Expand Up @@ -36,14 +39,18 @@ var EmbedderModelsSize = map[EmbedderModel]int{
}

type Embedder struct {
model EmbedderModel
client *coherego.Client
model EmbedderModel
client *coherego.Client
name string
observer embobserver.EmbeddingObserver
observerTraceID string
}

func New() *Embedder {
return &Embedder{
client: coherego.New(os.Getenv("COHERE_API_KEY")),
model: defaultEmbedderModel,
name: "cohere",
}
}

Expand All @@ -61,6 +68,44 @@ func (e *Embedder) WithModel(model EmbedderModel) *Embedder {

// Embed returns the embeddings for the given texts
func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
var observerEmbedding *observer.Embedding
var err error

if e.observer != nil {
observerEmbedding, err = embobserver.StartObserveEmbedding(
e.observer,
e.name,
string(e.model),
nil,
e.observerTraceID,
observer.ContextValueParentID(ctx),
texts,
)
if err != nil {
return nil, err
}
}

embeddings, err := e.embed(ctx, texts)
if err != nil {
return nil, err
}

if e.observer != nil {
err = embobserver.StopObserveEmbedding(
e.observer,
observerEmbedding,
embeddings,
)
if err != nil {
return nil, err
}
}

return embeddings, nil
}

func (e *Embedder) embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
resp := &response.Embed{}
err := e.client.Embed(
ctx,
Expand Down
57 changes: 53 additions & 4 deletions embedder/huggingface/huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,29 @@ import (
"os"

"github.com/henomis/lingoose/embedder"
embobserver "github.com/henomis/lingoose/embedder/observer"
"github.com/henomis/lingoose/observer"
)

const (
hfDefaultEmbedderModel = "sentence-transformers/all-MiniLM-L6-v2"
)

type HuggingFaceEmbedder struct {
token string
model string
httpClient *http.Client
token string
model string
httpClient *http.Client
name string
observer embobserver.EmbeddingObserver
observerTraceID string
}

func New() *HuggingFaceEmbedder {
return &HuggingFaceEmbedder{
token: os.Getenv("HUGGING_FACE_HUB_TOKEN"),
model: hfDefaultEmbedderModel,
httpClient: http.DefaultClient,
name: "huggingface",
}
}

Expand All @@ -38,6 +44,15 @@ func (h *HuggingFaceEmbedder) WithModel(model string) *HuggingFaceEmbedder {
return h
}

func (h *HuggingFaceEmbedder) WithObserver(
observer embobserver.EmbeddingObserver,
traceID string,
) *HuggingFaceEmbedder {
h.observer = observer
h.observerTraceID = traceID
return h
}

// WithHTTPClient sets the http client to use for the LLM
func (h *HuggingFaceEmbedder) WithHTTPClient(httpClient *http.Client) *HuggingFaceEmbedder {
h.httpClient = httpClient
Expand All @@ -46,5 +61,39 @@ func (h *HuggingFaceEmbedder) WithHTTPClient(httpClient *http.Client) *HuggingFa

// Embed returns the embeddings for the given texts
func (h *HuggingFaceEmbedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
return h.featureExtraction(ctx, texts)
var observerEmbedding *observer.Embedding
var err error

if h.observer != nil {
observerEmbedding, err = embobserver.StartObserveEmbedding(
h.observer,
h.name,
h.model,
nil,
h.observerTraceID,
observer.ContextValueParentID(ctx),
texts,
)
if err != nil {
return nil, err
}
}

embeddings, err := h.featureExtraction(ctx, texts)
if err != nil {
return nil, err
}

if h.observer != nil {
err = embobserver.StopObserveEmbedding(
h.observer,
observerEmbedding,
embeddings,
)
if err != nil {
return nil, err
}
}

return embeddings, nil
}
52 changes: 47 additions & 5 deletions embedder/nomic/nomic.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"net/http"
"os"

"github.com/henomis/lingoose/embedder"
"github.com/henomis/restclientgo"

"github.com/henomis/lingoose/embedder"
embobserver "github.com/henomis/lingoose/embedder/observer"
"github.com/henomis/lingoose/observer"
)

const (
Expand All @@ -15,9 +18,12 @@ const (
)

type Embedder struct {
taskType TaskType
model Model
restClient *restclientgo.RestClient
taskType TaskType
model Model
restClient *restclientgo.RestClient
name string
observer embobserver.EmbeddingObserver
observerTraceID string
}

func New() *Embedder {
Expand All @@ -31,6 +37,7 @@ func New() *Embedder {
},
),
model: defaultModel,
name: "nomic",
}
}

Expand All @@ -54,10 +61,34 @@ func (e *Embedder) WithModel(model Model) *Embedder {
return e
}

func (e *Embedder) WithObserver(observer embobserver.EmbeddingObserver, traceID string) *Embedder {
e.observer = observer
e.observerTraceID = traceID
return e
}

// Embed returns the embeddings for the given texts
func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
var observerEmbedding *observer.Embedding
var err error

if e.observer != nil {
observerEmbedding, err = embobserver.StartObserveEmbedding(
e.observer,
e.name,
string(e.model),
nil,
e.observerTraceID,
observer.ContextValueParentID(ctx),
texts,
)
if err != nil {
return nil, err
}
}

var resp response
err := e.restClient.Post(
err = e.restClient.Post(
ctx,
&request{
Texts: texts,
Expand All @@ -70,5 +101,16 @@ func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedd
return nil, err
}

if e.observer != nil {
err = embobserver.StopObserveEmbedding(
e.observer,
observerEmbedding,
resp.Embeddings,
)
if err != nil {
return nil, err
}
}

return resp.Embeddings, nil
}
49 changes: 49 additions & 0 deletions embedder/observer/observer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package observer

import (
"fmt"

"github.com/henomis/lingoose/embedder"
obs "github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/types"
)

type EmbeddingObserver interface {
Embedding(*obs.Embedding) (*obs.Embedding, error)
EmbeddingEnd(*obs.Embedding) (*obs.Embedding, error)
}

func StartObserveEmbedding(
o EmbeddingObserver,
name string,
modelName string,
ModelParameters types.M,
traceID string,
parentID string,
texts []string,
) (*obs.Embedding, error) {
embedding, err := o.Embedding(
&obs.Embedding{
TraceID: traceID,
ParentID: parentID,
Name: fmt.Sprintf("embedding-%s", name),
Model: modelName,
ModelParameters: ModelParameters,
Input: texts,
},
)
if err != nil {
return nil, err
}
return embedding, nil
}

func StopObserveEmbedding(
o EmbeddingObserver,
embedding *obs.Embedding,
embeddings []embedder.Embedding,
) error {
embedding.Output = embeddings
_, err := o.EmbeddingEnd(embedding)
return err
}
52 changes: 47 additions & 5 deletions embedder/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package ollamaembedder
import (
"context"

"github.com/henomis/lingoose/embedder"
"github.com/henomis/restclientgo"

"github.com/henomis/lingoose/embedder"
embobserver "github.com/henomis/lingoose/embedder/observer"
"github.com/henomis/lingoose/observer"
)

const (
Expand All @@ -13,14 +16,18 @@ const (
)

type Embedder struct {
model string
restClient *restclientgo.RestClient
model string
restClient *restclientgo.RestClient
name string
observer embobserver.EmbeddingObserver
observerTraceID string
}

func New() *Embedder {
return &Embedder{
restClient: restclientgo.New(defaultEndpoint),
model: defaultModel,
name: "ollama",
}
}

Expand All @@ -34,15 +41,50 @@ func (e *Embedder) WithModel(model string) *Embedder {
return e
}

func (e *Embedder) WithObserver(observer embobserver.EmbeddingObserver, traceID string) *Embedder {
e.observer = observer
e.observerTraceID = traceID
return e
}

// Embed returns the embeddings for the given texts
func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
var observerEmbedding *observer.Embedding
var err error

if e.observer != nil {
observerEmbedding, err = embobserver.StartObserveEmbedding(
e.observer,
e.name,
e.model,
nil,
e.observerTraceID,
observer.ContextValueParentID(ctx),
texts,
)
if err != nil {
return nil, err
}
}

embeddings := make([]embedder.Embedding, len(texts))
for i, text := range texts {
embedding, err := e.embed(ctx, text)
embedding, errEmbedd := e.embed(ctx, text)
if errEmbedd != nil {
return nil, errEmbedd
}
embeddings[i] = embedding
}

if e.observer != nil {
err = embobserver.StopObserveEmbedding(
e.observer,
observerEmbedding,
embeddings,
)
if err != nil {
return nil, err
}
embeddings[i] = embedding
}

return embeddings, nil
Expand Down
Loading
Loading