Skip to content

Commit

Permalink
Refactor observer (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored May 14, 2024
1 parent bd7180f commit 5727aae
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 57 deletions.
20 changes: 17 additions & 3 deletions examples/observer/langfuse/openai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,26 @@ import (
)

func main() {
o := langfuse.New(context.Background())
ctx := context.Background()

o := langfuse.New(ctx)
trace, err := o.Trace(&observer.Trace{Name: "Who are you"})
if err != nil {
panic(err)
}

span, err := o.Span(
&observer.Span{
TraceID: trace.ID,
Name: "SPAN",
},
)
if err != nil {
panic(err)
}

ctx = observer.ContextWithParentID(ctx, span.ID)

openaillm := openai.New().WithObserver(o, trace.ID)

t := thread.New().AddMessage(
Expand All @@ -24,10 +38,10 @@ func main() {
),
)

err = openaillm.Generate(context.Background(), t)
err = openaillm.Generate(ctx, t)
if err != nil {
panic(err)
}

o.Flush(context.Background())
o.Flush(ctx)
}
12 changes: 5 additions & 7 deletions llm/antropic/antropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,9 @@ func (o *Antropic) Generate(ctx context.Context, t *thread.Thread) error {

chatRequest := o.buildChatCompletionRequest(t)

var span *observer.Span
var generation *observer.Generation
if o.observer != nil {
span, generation, err = o.startObserveGeneration(t)
generation, err = o.startObserveGeneration(ctx, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrAnthropicChat, err)
}
Expand All @@ -184,7 +183,7 @@ func (o *Antropic) Generate(ctx context.Context, t *thread.Thread) error {
}

if o.observer != nil {
err = o.stopObserveGeneration(span, generation, t)
err = o.stopObserveGeneration(generation, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrAnthropicChat, err)
}
Expand Down Expand Up @@ -279,8 +278,8 @@ func (o *Antropic) stream(ctx context.Context, t *thread.Thread, chatRequest *re
return nil
}

func (o *Antropic) startObserveGeneration(t *thread.Thread) (*observer.Span, *observer.Generation, error) {
return llmobserver.SartObserveGeneration(
func (o *Antropic) startObserveGeneration(ctx context.Context, t *thread.Thread) (*observer.Generation, error) {
return llmobserver.StartObserveGeneration(
o.observer,
o.name,
o.model,
Expand All @@ -289,18 +288,17 @@ func (o *Antropic) startObserveGeneration(t *thread.Thread) (*observer.Span, *ob
"temperature": o.temperature,
},
o.observerTraceID,
observer.ContextValueParentID(ctx),
t,
)
}

func (o *Antropic) stopObserveGeneration(
span *observer.Span,
generation *observer.Generation,
t *thread.Thread,
) error {
return llmobserver.StopObserveGeneration(
o.observer,
span,
generation,
t,
)
Expand Down
12 changes: 5 additions & 7 deletions llm/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,9 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error {

chatRequest := c.buildChatCompletionRequest(t)

var span *observer.Span
var generation *observer.Generation
if c.observer != nil {
span, generation, err = c.startObserveGeneration(t)
generation, err = c.startObserveGeneration(ctx, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrCohereChat, err)
}
Expand All @@ -238,7 +237,7 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error {
}

if c.observer != nil {
err = c.stopObserveGeneration(span, generation, t)
err = c.stopObserveGeneration(generation, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrCohereChat, err)
}
Expand Down Expand Up @@ -299,8 +298,8 @@ func (c *Cohere) stream(ctx context.Context, t *thread.Thread, chatRequest *requ
return nil
}

func (c *Cohere) startObserveGeneration(t *thread.Thread) (*observer.Span, *observer.Generation, error) {
return llmobserver.SartObserveGeneration(
func (c *Cohere) startObserveGeneration(ctx context.Context, t *thread.Thread) (*observer.Generation, error) {
return llmobserver.StartObserveGeneration(
c.observer,
c.name,
string(c.model),
Expand All @@ -309,18 +308,17 @@ func (c *Cohere) startObserveGeneration(t *thread.Thread) (*observer.Span, *obse
"temperature": c.temperature,
},
c.observerTraceID,
observer.ContextValueParentID(ctx),
t,
)
}

func (c *Cohere) stopObserveGeneration(
span *observer.Span,
generation *observer.Generation,
t *thread.Thread,
) error {
return llmobserver.StopObserveGeneration(
c.observer,
span,
generation,
t,
)
Expand Down
33 changes: 8 additions & 25 deletions llm/observer/observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,58 +9,41 @@ import (
)

type LLMObserver interface {
Span(*obs.Span) (*obs.Span, error)
SpanEnd(*obs.Span) (*obs.Span, error)
Generation(*obs.Generation) (*obs.Generation, error)
GenerationEnd(*obs.Generation) (*obs.Generation, error)
}

func SartObserveGeneration(
func StartObserveGeneration(
o LLMObserver,
name string,
modelName string,
ModelParameters types.M,
traceID string,
parentID string,
t *thread.Thread,
) (*obs.Span, *obs.Generation, error) {
span, err := o.Span(
&obs.Span{
TraceID: traceID,
Name: name,
},
)
if err != nil {
return nil, nil, err
}

) (*obs.Generation, error) {
generation, err := o.Generation(
&obs.Generation{
TraceID: traceID,
ParentID: span.ID,
Name: fmt.Sprintf("%s-%s", name, modelName),
ParentID: parentID,
Name: fmt.Sprintf("llm-%s", name),
Model: modelName,
ModelParameters: ModelParameters,
Input: t.Messages,
},
)
if err != nil {
return nil, nil, err
return nil, err
}
return span, generation, nil
return generation, nil
}

func StopObserveGeneration(
o LLMObserver,
span *obs.Span,
generation *obs.Generation,
t *thread.Thread,
) error {
_, err := o.SpanEnd(span)
if err != nil {
return err
}

generation.Output = t.LastMessage()
_, err = o.GenerationEnd(generation)
_, err := o.GenerationEnd(generation)
return err
}
12 changes: 5 additions & 7 deletions llm/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,9 @@ func (o *Ollama) Generate(ctx context.Context, t *thread.Thread) error {

chatRequest := o.buildChatCompletionRequest(t)

var span *observer.Span
var generation *observer.Generation
if o.observer != nil {
span, generation, err = o.startObserveGeneration(t)
generation, err = o.startObserveGeneration(ctx, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrOllamaChat, err)
}
Expand All @@ -163,7 +162,7 @@ func (o *Ollama) Generate(ctx context.Context, t *thread.Thread) error {
}

if o.observer != nil {
err = o.stopObserveGeneration(span, generation, t)
err = o.stopObserveGeneration(generation, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrOllamaChat, err)
}
Expand Down Expand Up @@ -245,8 +244,8 @@ func (o *Ollama) stream(ctx context.Context, t *thread.Thread, chatRequest *requ
return nil
}

func (o *Ollama) startObserveGeneration(t *thread.Thread) (*observer.Span, *observer.Generation, error) {
return llmobserver.SartObserveGeneration(
func (o *Ollama) startObserveGeneration(ctx context.Context, t *thread.Thread) (*observer.Generation, error) {
return llmobserver.StartObserveGeneration(
o.observer,
o.name,
o.model,
Expand All @@ -256,18 +255,17 @@ func (o *Ollama) startObserveGeneration(t *thread.Thread) (*observer.Span, *obse
"temperature": o.temperature,
},
o.observerTraceID,
observer.ContextValueParentID(ctx),
t,
)
}

func (o *Ollama) stopObserveGeneration(
span *observer.Span,
generation *observer.Generation,
t *thread.Thread,
) error {
return llmobserver.StopObserveGeneration(
o.observer,
span,
generation,
t,
)
Expand Down
13 changes: 5 additions & 8 deletions llm/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,9 @@ func (o *OpenAI) Generate(ctx context.Context, t *thread.Thread) error {
chatCompletionRequest.ToolChoice = o.getChatCompletionRequestToolChoice()
}

var span *observer.Span
var generation *observer.Generation

if o.observer != nil {
span, generation, err = o.startObserveGeneration(t)
generation, err = o.startObserveGeneration(ctx, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrOpenAIChat, err)
}
Expand All @@ -219,7 +217,7 @@ func (o *OpenAI) Generate(ctx context.Context, t *thread.Thread) error {
}

if o.observer != nil {
err = o.stopObserveGeneration(span, generation, t)
err = o.stopObserveGeneration(generation, t)
if err != nil {
return fmt.Errorf("%w: %w", ErrOpenAIChat, err)
}
Expand Down Expand Up @@ -439,8 +437,8 @@ func (o *OpenAI) callTools(toolCalls []openai.ToolCall) []*thread.Message {
return messages
}

func (o *OpenAI) startObserveGeneration(t *thread.Thread) (*observer.Span, *observer.Generation, error) {
return llmobserver.SartObserveGeneration(
func (o *OpenAI) startObserveGeneration(ctx context.Context, t *thread.Thread) (*observer.Generation, error) {
return llmobserver.StartObserveGeneration(
o.observer,
o.Name,
string(o.model),
Expand All @@ -449,18 +447,17 @@ func (o *OpenAI) startObserveGeneration(t *thread.Thread) (*observer.Span, *obse
"temperature": o.temperature,
},
o.observerTraceID,
observer.ContextValueParentID(ctx),
t,
)
}

func (o *OpenAI) stopObserveGeneration(
span *observer.Span,
generation *observer.Generation,
t *thread.Thread,
) error {
return llmobserver.StopObserveGeneration(
o.observer,
span,
generation,
t,
)
Expand Down
25 changes: 25 additions & 0 deletions observer/observer.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
package observer

import (
"context"

"github.com/henomis/lingoose/thread"
"github.com/henomis/lingoose/types"
)

type ContextKey string

const (
ContextKeyParentID ContextKey = "observerParentID"
ContextKeyGeneration ContextKey = "observerGeneration"
)

type Trace struct {
ID string
Name string
Expand Down Expand Up @@ -43,3 +52,19 @@ type Score struct {
Name string
Value float64
}

func ContextValueParentID(ctx context.Context) string {
parentID, ok := ctx.Value(ContextKeyParentID).(string)
if !ok {
return ""
}
return parentID
}

func ContextWithParentID(ctx context.Context, parentID string) context.Context {
return context.WithValue(ctx, ContextKeyParentID, parentID)
}

func ContextWithouthParentID(ctx context.Context) context.Context {
return context.WithValue(ctx, ContextKeyParentID, "")
}

0 comments on commit 5727aae

Please sign in to comment.