Skip to content

Commit

Permalink
Implement embeddings observer (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored May 14, 2024
1 parent 5727aae commit fae73dc
Show file tree
Hide file tree
Showing 10 changed files with 389 additions and 23 deletions.
49 changes: 47 additions & 2 deletions embedder/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"github.com/henomis/cohere-go/model"
"github.com/henomis/cohere-go/request"
"github.com/henomis/cohere-go/response"

"github.com/henomis/lingoose/embedder"
embobserver "github.com/henomis/lingoose/embedder/observer"
"github.com/henomis/lingoose/observer"
)

type EmbedderModel = model.EmbedModel
Expand Down Expand Up @@ -36,14 +39,18 @@ var EmbedderModelsSize = map[EmbedderModel]int{
}

type Embedder struct {
model EmbedderModel
client *coherego.Client
model EmbedderModel
client *coherego.Client
name string
observer embobserver.EmbeddingObserver
observerTraceID string
}

func New() *Embedder {
return &Embedder{
client: coherego.New(os.Getenv("COHERE_API_KEY")),
model: defaultEmbedderModel,
name: "cohere",
}
}

Expand All @@ -61,6 +68,44 @@ func (e *Embedder) WithModel(model EmbedderModel) *Embedder {

// Embed returns the embeddings for the given texts
func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
var observerEmbedding *observer.Embedding
var err error

if e.observer != nil {
observerEmbedding, err = embobserver.StartObserveEmbedding(
e.observer,
e.name,
string(e.model),
nil,
e.observerTraceID,
observer.ContextValueParentID(ctx),
texts,
)
if err != nil {
return nil, err
}
}

embeddings, err := e.embed(ctx, texts)
if err != nil {
return nil, err
}

if e.observer != nil {
err = embobserver.StopObserveEmbedding(
e.observer,
observerEmbedding,
embeddings,
)
if err != nil {
return nil, err
}
}

return embeddings, nil
}

func (e *Embedder) embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
resp := &response.Embed{}
err := e.client.Embed(
ctx,
Expand Down
57 changes: 53 additions & 4 deletions embedder/huggingface/huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,29 @@ import (
"os"

"github.com/henomis/lingoose/embedder"
embobserver "github.com/henomis/lingoose/embedder/observer"
"github.com/henomis/lingoose/observer"
)

const (
hfDefaultEmbedderModel = "sentence-transformers/all-MiniLM-L6-v2"
)

type HuggingFaceEmbedder struct {
token string
model string
httpClient *http.Client
token string
model string
httpClient *http.Client
name string
observer embobserver.EmbeddingObserver
observerTraceID string
}

func New() *HuggingFaceEmbedder {
return &HuggingFaceEmbedder{
token: os.Getenv("HUGGING_FACE_HUB_TOKEN"),
model: hfDefaultEmbedderModel,
httpClient: http.DefaultClient,
name: "huggingface",
}
}

Expand All @@ -38,6 +44,15 @@ func (h *HuggingFaceEmbedder) WithModel(model string) *HuggingFaceEmbedder {
return h
}

func (h *HuggingFaceEmbedder) WithObserver(
observer embobserver.EmbeddingObserver,
traceID string,
) *HuggingFaceEmbedder {
h.observer = observer
h.observerTraceID = traceID
return h
}

// WithHTTPClient sets the http client to use for the LLM
func (h *HuggingFaceEmbedder) WithHTTPClient(httpClient *http.Client) *HuggingFaceEmbedder {
h.httpClient = httpClient
Expand All @@ -46,5 +61,39 @@ func (h *HuggingFaceEmbedder) WithHTTPClient(httpClient *http.Client) *HuggingFa

// Embed returns the embeddings for the given texts
func (h *HuggingFaceEmbedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
return h.featureExtraction(ctx, texts)
var observerEmbedding *observer.Embedding
var err error

if h.observer != nil {
observerEmbedding, err = embobserver.StartObserveEmbedding(
h.observer,
h.name,
h.model,
nil,
h.observerTraceID,
observer.ContextValueParentID(ctx),
texts,
)
if err != nil {
return nil, err
}
}

embeddings, err := h.featureExtraction(ctx, texts)
if err != nil {
return nil, err
}

if h.observer != nil {
err = embobserver.StopObserveEmbedding(
h.observer,
observerEmbedding,
embeddings,
)
if err != nil {
return nil, err
}
}

return embeddings, nil
}
52 changes: 47 additions & 5 deletions embedder/nomic/nomic.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"net/http"
"os"

"github.com/henomis/lingoose/embedder"
"github.com/henomis/restclientgo"

"github.com/henomis/lingoose/embedder"
embobserver "github.com/henomis/lingoose/embedder/observer"
"github.com/henomis/lingoose/observer"
)

const (
Expand All @@ -15,9 +18,12 @@ const (
)

type Embedder struct {
taskType TaskType
model Model
restClient *restclientgo.RestClient
taskType TaskType
model Model
restClient *restclientgo.RestClient
name string
observer embobserver.EmbeddingObserver
observerTraceID string
}

func New() *Embedder {
Expand All @@ -31,6 +37,7 @@ func New() *Embedder {
},
),
model: defaultModel,
name: "nomic",
}
}

Expand All @@ -54,10 +61,34 @@ func (e *Embedder) WithModel(model Model) *Embedder {
return e
}

func (e *Embedder) WithObserver(observer embobserver.EmbeddingObserver, traceID string) *Embedder {
e.observer = observer
e.observerTraceID = traceID
return e
}

// Embed returns the embeddings for the given texts
func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
var observerEmbedding *observer.Embedding
var err error

if e.observer != nil {
observerEmbedding, err = embobserver.StartObserveEmbedding(
e.observer,
e.name,
string(e.model),
nil,
e.observerTraceID,
observer.ContextValueParentID(ctx),
texts,
)
if err != nil {
return nil, err
}
}

var resp response
err := e.restClient.Post(
err = e.restClient.Post(
ctx,
&request{
Texts: texts,
Expand All @@ -70,5 +101,16 @@ func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedd
return nil, err
}

if e.observer != nil {
err = embobserver.StopObserveEmbedding(
e.observer,
observerEmbedding,
resp.Embeddings,
)
if err != nil {
return nil, err
}
}

return resp.Embeddings, nil
}
49 changes: 49 additions & 0 deletions embedder/observer/observer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package observer

import (
"fmt"

"github.com/henomis/lingoose/embedder"
obs "github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/types"
)

type EmbeddingObserver interface {
Embedding(*obs.Embedding) (*obs.Embedding, error)
EmbeddingEnd(*obs.Embedding) (*obs.Embedding, error)
}

func StartObserveEmbedding(
o EmbeddingObserver,
name string,
modelName string,
ModelParameters types.M,
traceID string,
parentID string,
texts []string,
) (*obs.Embedding, error) {
embedding, err := o.Embedding(
&obs.Embedding{
TraceID: traceID,
ParentID: parentID,
Name: fmt.Sprintf("embedding-%s", name),
Model: modelName,
ModelParameters: ModelParameters,
Input: texts,
},
)
if err != nil {
return nil, err
}
return embedding, nil
}

func StopObserveEmbedding(
o EmbeddingObserver,
embedding *obs.Embedding,
embeddings []embedder.Embedding,
) error {
embedding.Output = embeddings
_, err := o.EmbeddingEnd(embedding)
return err
}
52 changes: 47 additions & 5 deletions embedder/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package ollamaembedder
import (
"context"

"github.com/henomis/lingoose/embedder"
"github.com/henomis/restclientgo"

"github.com/henomis/lingoose/embedder"
embobserver "github.com/henomis/lingoose/embedder/observer"
"github.com/henomis/lingoose/observer"
)

const (
Expand All @@ -13,14 +16,18 @@ const (
)

type Embedder struct {
model string
restClient *restclientgo.RestClient
model string
restClient *restclientgo.RestClient
name string
observer embobserver.EmbeddingObserver
observerTraceID string
}

func New() *Embedder {
return &Embedder{
restClient: restclientgo.New(defaultEndpoint),
model: defaultModel,
name: "ollama",
}
}

Expand All @@ -34,15 +41,50 @@ func (e *Embedder) WithModel(model string) *Embedder {
return e
}

func (e *Embedder) WithObserver(observer embobserver.EmbeddingObserver, traceID string) *Embedder {
e.observer = observer
e.observerTraceID = traceID
return e
}

// Embed returns the embeddings for the given texts
func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
var observerEmbedding *observer.Embedding
var err error

if e.observer != nil {
observerEmbedding, err = embobserver.StartObserveEmbedding(
e.observer,
e.name,
e.model,
nil,
e.observerTraceID,
observer.ContextValueParentID(ctx),
texts,
)
if err != nil {
return nil, err
}
}

embeddings := make([]embedder.Embedding, len(texts))
for i, text := range texts {
embedding, err := e.embed(ctx, text)
embedding, errEmbedd := e.embed(ctx, text)
if errEmbedd != nil {
return nil, errEmbedd
}
embeddings[i] = embedding
}

if e.observer != nil {
err = embobserver.StopObserveEmbedding(
e.observer,
observerEmbedding,
embeddings,
)
if err != nil {
return nil, err
}
embeddings[i] = embedding
}

return embeddings, nil
Expand Down
Loading

0 comments on commit fae73dc

Please sign in to comment.