Skip to content

Commit

Permalink
Add observer to all the llms (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored May 1, 2024
1 parent d08a3ca commit 863ce5c
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 49 deletions.
60 changes: 58 additions & 2 deletions llm/antropic/antropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@ import (
"os"
"strings"

"github.com/henomis/restclientgo"

"github.com/henomis/lingoose/llm/cache"
llmobserver "github.com/henomis/lingoose/llm/observer"
"github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/thread"
"github.com/henomis/restclientgo"
"github.com/henomis/lingoose/types"
)

const (
Expand Down Expand Up @@ -48,6 +52,9 @@ type Antropic struct {
apiVersion string
apiKey string
maxTokens int
name string
observer llmobserver.LLMObserver
observerTraceID string
}

func New() *Antropic {
Expand All @@ -65,6 +72,7 @@ func New() *Antropic {
apiVersion: defaultAPIVersion,
apiKey: apiKey,
maxTokens: defaultMaxTokens,
name: "anthropic",
}
}

Expand Down Expand Up @@ -93,6 +101,12 @@ func (o *Antropic) WithMaxTokens(maxTokens int) *Antropic {
return o
}

// WithObserver attaches an LLM observer and the trace ID under which this
// client's generations are recorded. It returns the receiver for chaining.
func (o *Antropic) WithObserver(observer llmobserver.LLMObserver, traceID string) *Antropic {
	o.observerTraceID = traceID
	o.observer = observer
	return o
}

func (o *Antropic) getCache(ctx context.Context, t *thread.Thread) (*cache.Result, error) {
messages := t.UserQuery()
cacheQuery := strings.Join(messages, "\n")
Expand Down Expand Up @@ -151,16 +165,31 @@ func (o *Antropic) Generate(ctx context.Context, t *thread.Thread) error {

chatRequest := o.buildChatCompletionRequest(t)

var span *observer.Span
var generation *observer.Generation
if o.observer != nil {
span, generation, err = o.startObserveGeneration(t)
if err != nil {
return fmt.Errorf("%w: %w", ErrAnthropicChat, err)
}
}

if o.streamCallbackFn != nil {
err = o.stream(ctx, t, chatRequest)
} else {
err = o.generate(ctx, t, chatRequest)
}

if err != nil {
return err
}

if o.observer != nil {
err = o.stopObserveGeneration(span, generation, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrAnthropicChat, err)
}
}

if o.cache != nil {
err = o.setCache(ctx, t, cacheResult)
if err != nil {
Expand Down Expand Up @@ -249,3 +278,30 @@ func (o *Antropic) stream(ctx context.Context, t *thread.Thread, chatRequest *re

return nil
}

// startObserveGeneration opens an observer span and a nested generation for
// the given thread, tagged with this client's name, model and sampling
// parameters. The returned span and generation are later closed by
// stopObserveGeneration.
func (o *Antropic) startObserveGeneration(t *thread.Thread) (*observer.Span, *observer.Generation, error) {
	modelParameters := types.M{
		"maxTokens":   o.maxTokens,
		"temperature": o.temperature,
	}
	return llmobserver.SartObserveGeneration(o.observer, o.name, o.model, modelParameters, o.observerTraceID, t)
}

// stopObserveGeneration closes the span and generation opened by
// startObserveGeneration, recording the thread's final message as output.
func (o *Antropic) stopObserveGeneration(
	span *observer.Span,
	generation *observer.Generation,
	t *thread.Thread,
) error {
	return llmobserver.StopObserveGeneration(o.observer, span, generation, t)
}
58 changes: 57 additions & 1 deletion llm/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ import (
"github.com/henomis/cohere-go/model"
"github.com/henomis/cohere-go/request"
"github.com/henomis/cohere-go/response"

"github.com/henomis/lingoose/legacy/chat"
"github.com/henomis/lingoose/llm/cache"
llmobserver "github.com/henomis/lingoose/llm/observer"
"github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/thread"
"github.com/henomis/lingoose/types"
)

var (
Expand Down Expand Up @@ -46,6 +50,9 @@ type Cohere struct {
stop []string
cache *cache.Cache
streamCallbackFn StreamCallbackFn
name string
observer llmobserver.LLMObserver
observerTraceID string
}

func (c *Cohere) WithCache(cache *cache.Cache) *Cohere {
Expand All @@ -64,6 +71,7 @@ func New() *Cohere {
model: DefaultModel,
temperature: DefaultTemperature,
maxTokens: DefaultMaxTokens,
name: "cohere",
}
}

Expand Down Expand Up @@ -108,6 +116,12 @@ func (c *Cohere) WithStream(callbackFn StreamCallbackFn) *Cohere {
return c
}

// WithObserver attaches an LLM observer and the trace ID under which this
// client's generations are recorded. It returns the receiver for chaining.
func (c *Cohere) WithObserver(observer llmobserver.LLMObserver, traceID string) *Cohere {
	c.observerTraceID = traceID
	c.observer = observer
	return c
}

// Completion returns the completion for the given prompt
func (c *Cohere) Completion(ctx context.Context, prompt string) (string, error) {
resp := &response.Generate{}
Expand Down Expand Up @@ -205,16 +219,31 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error {

chatRequest := c.buildChatCompletionRequest(t)

var span *observer.Span
var generation *observer.Generation
if c.observer != nil {
span, generation, err = c.startObserveGeneration(t)
if err != nil {
return fmt.Errorf("%w: %w", ErrCohereChat, err)
}
}

if c.streamCallbackFn != nil {
err = c.stream(ctx, t, chatRequest)
} else {
err = c.generate(ctx, t, chatRequest)
}

if err != nil {
return err
}

if c.observer != nil {
err = c.stopObserveGeneration(span, generation, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrCohereChat, err)
}
}

if c.cache != nil {
err = c.setCache(ctx, t, cacheResult)
if err != nil {
Expand Down Expand Up @@ -269,3 +298,30 @@ func (c *Cohere) stream(ctx context.Context, t *thread.Thread, chatRequest *requ

return nil
}

// startObserveGeneration opens an observer span and a nested generation for
// the given thread, tagged with this client's name, model and sampling
// parameters. The returned span and generation are later closed by
// stopObserveGeneration.
func (c *Cohere) startObserveGeneration(t *thread.Thread) (*observer.Span, *observer.Generation, error) {
	modelParameters := types.M{
		"maxTokens":   c.maxTokens,
		"temperature": c.temperature,
	}
	return llmobserver.SartObserveGeneration(c.observer, c.name, string(c.model), modelParameters, c.observerTraceID, t)
}

// stopObserveGeneration closes the span and generation opened by
// startObserveGeneration, recording the thread's final message as output.
func (c *Cohere) stopObserveGeneration(
	span *observer.Span,
	generation *observer.Generation,
	t *thread.Thread,
) error {
	return llmobserver.StopObserveGeneration(c.observer, span, generation, t)
}
1 change: 1 addition & 0 deletions llm/groq/groq.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func New() *Groq {
customClient := goopenai.NewClientWithConfig(customConfig)

openaillm := openai.New().WithClient(customClient)
openaillm.Name = "groq"
return &Groq{
OpenAI: openaillm,
}
Expand Down
1 change: 1 addition & 0 deletions llm/localai/localai.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func New(endpoint string) *LocalAI {
customClient := goopenai.NewClientWithConfig(customConfig)

openaillm := openai.New().WithClient(customClient)
openaillm.Name = "localai"
return &LocalAI{
OpenAI: openaillm,
}
Expand Down
66 changes: 66 additions & 0 deletions llm/observer/observer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package observer

import (
"fmt"

obs "github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/thread"
"github.com/henomis/lingoose/types"
)

// LLMObserver is the minimal tracing surface an LLM implementation needs:
// it can open and close a span, and open and close a generation nested
// within that span.
type LLMObserver interface {
	// Span starts the given span and returns it (with its ID populated by
	// the backing observer).
	Span(*obs.Span) (*obs.Span, error)
	// SpanEnd marks the given span as finished.
	SpanEnd(*obs.Span) (*obs.Span, error)
	// Generation starts the given generation and returns it.
	Generation(*obs.Generation) (*obs.Generation, error)
	// GenerationEnd marks the given generation as finished.
	GenerationEnd(*obs.Generation) (*obs.Generation, error)
}

// StartObserveGeneration opens an observer span named after the LLM and,
// nested under it, a generation carrying the model name, its parameters and
// the thread's messages as input. Both are returned so the caller can close
// them later with StopObserveGeneration. An error from either observer call
// is returned as-is.
func StartObserveGeneration(
	o LLMObserver,
	name string,
	modelName string,
	modelParameters types.M,
	traceID string,
	t *thread.Thread,
) (*obs.Span, *obs.Generation, error) {
	span, err := o.Span(
		&obs.Span{
			TraceID: traceID,
			Name:    name,
		},
	)
	if err != nil {
		return nil, nil, err
	}

	generation, err := o.Generation(
		&obs.Generation{
			TraceID:  traceID,
			ParentID: span.ID,
			// e.g. "anthropic-claude-3" — ties the generation to both the
			// provider and the concrete model.
			Name:            fmt.Sprintf("%s-%s", name, modelName),
			Model:           modelName,
			ModelParameters: modelParameters,
			Input:           t.Messages,
		},
	)
	if err != nil {
		return nil, nil, err
	}
	return span, generation, nil
}

// SartObserveGeneration is a misspelled alias kept so existing callers keep
// compiling.
//
// Deprecated: use StartObserveGeneration instead.
func SartObserveGeneration(
	o LLMObserver,
	name string,
	modelName string,
	modelParameters types.M,
	traceID string,
	t *thread.Thread,
) (*obs.Span, *obs.Generation, error) {
	return StartObserveGeneration(o, name, modelName, modelParameters, traceID, t)
}

// StopObserveGeneration ends the given span and then, after recording the
// thread's last message as the generation's output, ends the generation as
// well. The first observer error encountered is returned.
func StopObserveGeneration(
	o LLMObserver,
	span *obs.Span,
	generation *obs.Generation,
	t *thread.Thread,
) error {
	if _, err := o.SpanEnd(span); err != nil {
		return err
	}

	generation.Output = t.LastMessage()
	_, err := o.GenerationEnd(generation)
	return err
}
Loading

0 comments on commit 863ce5c

Please sign in to comment.