Skip to content

Commit

Permalink
Refactor ssistant observer (#206)
Browse files Browse the repository at this point in the history
* refactor assistant observer

* fix: linting
  • Loading branch information
henomis authored Jun 2, 2024
1 parent 40f3055 commit a86c0e3
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 28 deletions.
3 changes: 0 additions & 3 deletions assistant/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package assistant

import (
"context"
"fmt"
"strings"

obs "github.com/henomis/lingoose/observer"
Expand Down Expand Up @@ -103,8 +102,6 @@ func (a *Assistant) Run(ctx context.Context) error {
if a.thread.LastMessage().Role != thread.RoleTool {
break
}

fmt.Println(a.Thread())
}

err = a.stopObserveSpan(ctx, spanAssistant)
Expand Down
16 changes: 15 additions & 1 deletion examples/assistant/agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,25 @@ import (

"github.com/henomis/lingoose/assistant"
"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/observer/langfuse"
"github.com/henomis/lingoose/thread"

pythontool "github.com/henomis/lingoose/tool/python"
serpapitool "github.com/henomis/lingoose/tool/serpapi"
)

func main() {
ctx := context.Background()

o := langfuse.New(ctx)
trace, err := o.Trace(&observer.Trace{Name: "state of the union"})
if err != nil {
panic(err)
}

ctx = observer.ContextWithObserverInstance(ctx, o)
ctx = observer.ContextWithTraceID(ctx, trace.ID)

auto := "auto"
a := assistant.New(
Expand All @@ -36,12 +48,14 @@ func main() {
),
).WithMaxIterations(10)

err := a.Run(context.Background())
err = a.Run(ctx)
if err != nil {
panic(err)
}

fmt.Println("----")
fmt.Println(a.Thread())
fmt.Println("----")

o.Flush(ctx)
}
2 changes: 2 additions & 0 deletions examples/observer/assistant/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,6 @@ func main() {
fmt.Println("----")
fmt.Println(a.Thread())
fmt.Println("----")

o.Flush(ctx)
}
14 changes: 8 additions & 6 deletions examples/observer/langfuse/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,14 @@ func main() {
panic(err)
}

generation.Output = &thread.Message{
Role: thread.RoleAssistant,
Contents: []*thread.Content{
{
Type: thread.ContentTypeText,
Data: "The Q3 OKRs contain goals for multiple teams...",
generation.Output = []*thread.Message{
{
Role: thread.RoleAssistant,
Contents: []*thread.Content{
{
Type: thread.ContentTypeText,
Data: "The Q3 OKRs contain goals for multiple teams...",
},
},
},
}
Expand Down
6 changes: 3 additions & 3 deletions llm/antropic/antropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (o *Antropic) Generate(ctx context.Context, t *thread.Thread) error {
return err
}

err = o.stopObserveGeneration(ctx, generation, t)
err = o.stopObserveGeneration(ctx, generation, []*thread.Message{t.LastMessage()})
if err != nil {
return fmt.Errorf("%w: %w", ErrAnthropicChat, err)
}
Expand Down Expand Up @@ -281,11 +281,11 @@ func (o *Antropic) startObserveGeneration(ctx context.Context, t *thread.Thread)
func (o *Antropic) stopObserveGeneration(
ctx context.Context,
generation *observer.Generation,
t *thread.Thread,
messagges []*thread.Message,
) error {
return llmobserver.StopObserveGeneration(
ctx,
generation,
t,
messagges,
)
}
6 changes: 3 additions & 3 deletions llm/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error {
return err
}

err = c.stopObserveGeneration(ctx, generation, t)
err = c.stopObserveGeneration(ctx, generation, []*thread.Message{t.LastMessage()})
if err != nil {
return fmt.Errorf("%w: %w", ErrCohereChat, err)
}
Expand Down Expand Up @@ -309,11 +309,11 @@ func (c *Cohere) startObserveGeneration(ctx context.Context, t *thread.Thread) (
func (c *Cohere) stopObserveGeneration(
ctx context.Context,
generation *observer.Generation,
t *thread.Thread,
messages []*thread.Message,
) error {
return llmobserver.StopObserveGeneration(
ctx,
generation,
t,
messages,
)
}
4 changes: 2 additions & 2 deletions llm/observer/observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ func StartObserveGeneration(
func StopObserveGeneration(
ctx context.Context,
generation *observer.Generation,
t *thread.Thread,
messages []*thread.Message,
) error {
o, ok := observer.ContextValueObserverInstance(ctx).(LLMObserver)
if o == nil || !ok {
// No observer instance in context
return nil
}

generation.Output = t.LastMessage()
generation.Output = messages
_, err := o.GenerationEnd(generation)
return err
}
6 changes: 3 additions & 3 deletions llm/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (o *Ollama) Generate(ctx context.Context, t *thread.Thread) error {
return err
}

err = o.stopObserveGeneration(ctx, generation, t)
err = o.stopObserveGeneration(ctx, generation, []*thread.Message{t.LastMessage()})
if err != nil {
return fmt.Errorf("%w: %w", ErrOllamaChat, err)
}
Expand Down Expand Up @@ -248,11 +248,11 @@ func (o *Ollama) startObserveGeneration(ctx context.Context, t *thread.Thread) (
func (o *Ollama) stopObserveGeneration(
ctx context.Context,
generation *observer.Generation,
t *thread.Thread,
messages []*thread.Message,
) error {
return llmobserver.StopObserveGeneration(
ctx,
generation,
t,
messages,
)
}
8 changes: 5 additions & 3 deletions llm/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ func (o *OpenAI) Generate(ctx context.Context, t *thread.Thread) error {
return fmt.Errorf("%w: %w", ErrOpenAIChat, err)
}

nMessageBeforeGeneration := len(t.Messages)

if o.streamCallbackFn != nil {
err = o.stream(ctx, t, chatCompletionRequest)
} else {
Expand All @@ -211,7 +213,7 @@ func (o *OpenAI) Generate(ctx context.Context, t *thread.Thread) error {
return err
}

err = o.stopObserveGeneration(ctx, generation, t)
err = o.stopObserveGeneration(ctx, generation, t.Messages[nMessageBeforeGeneration:])
if err != nil {
return fmt.Errorf("%w: %w", ErrOpenAIChat, err)
}
Expand Down Expand Up @@ -454,11 +456,11 @@ func (o *OpenAI) startObserveGeneration(ctx context.Context, t *thread.Thread) (
func (o *OpenAI) stopObserveGeneration(
ctx context.Context,
generation *observer.Generation,
t *thread.Thread,
messages []*thread.Message,
) error {
return llmobserver.StopObserveGeneration(
ctx,
generation,
t,
messages,
)
}
61 changes: 58 additions & 3 deletions observer/langfuse/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,77 @@ func threadMessagesToLangfuseMSlice(messages []*thread.Message) []model.M {
return mSlice
}

func threadOutputMessagesToLangfuseOutput(messages []*thread.Message) any {
if len(messages) == 1 &&
messages[0].Role == thread.RoleAssistant &&
len(messages[0].Contents) == 1 &&
messages[0].Contents[0].Type == thread.ContentTypeText {
return threadMessageToLangfuseM(messages[0])
}

toolCalls := model.M{}
toolMessages := []*thread.Message{}

for _, message := range messages {
if message.Role == thread.RoleAssistant &&
message.Contents[0].Type == thread.ContentTypeToolCall {
toolCalls = threadMessageToLangfuseM(message)
} else if message.Role == thread.RoleTool &&
message.Contents[0].Type == thread.ContentTypeToolResponse {
toolMessages = append(toolMessages, message)
}
}

return append([]model.M{toolCalls}, threadMessagesToLangfuseMSlice(toolMessages)...)
}

func threadMessageToLangfuseM(message *thread.Message) model.M {
if message == nil {
return nil
}

messageContent := ""
role := message.Role
if message.Role == thread.RoleTool {
data := message.Contents[0].AsToolResponseData()
m := model.M{
"type": message.Contents[0].Type,
"id": data.ID,
"name": data.Name,
"results": data.Result,
}

return model.M{
"role": role,
"content": m,
}
}

messageContent := ""
m := make([]model.M, 0)
for _, content := range message.Contents {
if content.Type == thread.ContentTypeText {
messageContent += content.AsString()
} else if content.Type == thread.ContentTypeToolCall {
for _, data := range content.AsToolCallData() {
m = append(m, model.M{
"type": content.Type,
"id": data.ID,
"name": data.Name,
"arguments": data.Arguments,
})
}
}
}
return model.M{
output := model.M{
"role": role,
"content": messageContent,
}

if len(m) > 0 {
output["content"] = m
}

return output
}

func observerGenerationToLangfuseGeneration(g *observer.Generation) *model.Generation {
Expand All @@ -81,7 +136,7 @@ func observerGenerationToLangfuseGeneration(g *observer.Generation) *model.Gener
Model: g.Model,
ModelParameters: g.ModelParameters,
Input: threadMessagesToLangfuseMSlice(g.Input),
Output: threadMessageToLangfuseM(g.Output),
Output: threadOutputMessagesToLangfuseOutput(g.Output),
Metadata: g.Metadata,
}
}
Expand Down
2 changes: 1 addition & 1 deletion observer/observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type Generation struct {
Model string
ModelParameters types.M
Input []*thread.Message
Output *thread.Message
Output []*thread.Message
Metadata types.M
}

Expand Down

0 comments on commit a86c0e3

Please sign in to comment.