Skip to content

Commit

Permalink
Implement assistant and rag observer span (#196)
Browse files Browse the repository at this point in the history
* chore: add rag and assistant observer implementation

* chore: add example

* chore: implement observer for rag fusion e rag subdocument

* fix
  • Loading branch information
henomis authored May 14, 2024
1 parent fae73dc commit e0a9028
Show file tree
Hide file tree
Showing 7 changed files with 328 additions and 21 deletions.
64 changes: 58 additions & 6 deletions assistant/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"strings"

obs "github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/thread"
"github.com/henomis/lingoose/types"
)
Expand All @@ -16,11 +17,18 @@ type Parameters struct {
CompanyDescription string
}

type observer interface {
Span(s *obs.Span) (*obs.Span, error)
SpanEnd(s *obs.Span) (*obs.Span, error)
}

type Assistant struct {
llm LLM
rag RAG
thread *thread.Thread
parameters Parameters
llm LLM
rag RAG
thread *thread.Thread
parameters Parameters
observer observer
observerTraceID string
}

type LLM interface {
Expand Down Expand Up @@ -62,19 +70,47 @@ func (a *Assistant) WithParameters(parameters Parameters) *Assistant {
return a
}

func (a *Assistant) WithObserver(observer observer, traceID string) *Assistant {
a.observer = observer
a.observerTraceID = traceID
return a
}

func (a *Assistant) Run(ctx context.Context) error {
if a.thread == nil {
return nil
}

var err error
var spanAssistant *obs.Span
if a.observer != nil {
spanAssistant, err = a.startObserveSpan(ctx, "assistant")
if err != nil {
return err
}
ctx = obs.ContextWithParentID(ctx, spanAssistant.ID)
}

if a.rag != nil {
err := a.generateRAGMessage(ctx)
errGenerate := a.generateRAGMessage(ctx)
if errGenerate != nil {
return errGenerate
}
}

err = a.llm.Generate(ctx, a.thread)
if err != nil {
return err
}

if a.observer != nil {
err = a.stopObserveSpan(spanAssistant)
if err != nil {
return err
}
}

return a.llm.Generate(ctx, a.thread)
return nil
}

func (a *Assistant) RunWithThread(ctx context.Context, thread *thread.Thread) error {
Expand Down Expand Up @@ -125,3 +161,19 @@ func (a *Assistant) generateRAGMessage(ctx context.Context) error {

return nil
}

func (a *Assistant) startObserveSpan(ctx context.Context, name string) (*obs.Span, error) {
return a.observer.Span(
&obs.Span{
TraceID: a.observerTraceID,
ParentID: obs.ContextValueParentID(ctx),
Name: name,
Input: a.parameters,
},
)
}

func (a *Assistant) stopObserveSpan(span *obs.Span) error {
_, err := a.observer.SpanEnd(span)
return err
}
71 changes: 71 additions & 0 deletions examples/observer/assistant/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package main

import (
"context"
"fmt"
"os"

"github.com/henomis/lingoose/assistant"
openaiembedder "github.com/henomis/lingoose/embedder/openai"
"github.com/henomis/lingoose/index"
"github.com/henomis/lingoose/index/vectordb/jsondb"
"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/observer/langfuse"
"github.com/henomis/lingoose/rag"
"github.com/henomis/lingoose/thread"
)

// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt

func main() {
ctx := context.Background()

o := langfuse.New(ctx)
trace, err := o.Trace(&observer.Trace{Name: "state of the union"})
if err != nil {
panic(err)
}

r := rag.New(
index.New(
jsondb.New().WithPersist("db.json"),
openaiembedder.New(openaiembedder.AdaEmbeddingV2).WithObserver(o, trace.ID),
),
).WithTopK(3).WithObserver(o, trace.ID)

_, err = os.Stat("db.json")
if os.IsNotExist(err) {
err = r.AddSources(ctx, "state_of_the_union.txt")
if err != nil {
panic(err)
}
}

a := assistant.New(
openai.New().WithTemperature(0).WithObserver(o, trace.ID),
).WithParameters(
assistant.Parameters{
AssistantName: "AI Pirate Assistant",
AssistantIdentity: "a pirate and helpful assistant",
AssistantScope: "with their questions replying as a pirate",
CompanyName: "Lingoose",
CompanyDescription: "a pirate company that provides AI assistants to help humans with their questions",
},
).WithRAG(r).WithThread(
thread.New().AddMessages(
thread.NewUserMessage().AddContent(
thread.NewTextContent("what is the purpose of NATO?"),
),
),
).WithObserver(o, trace.ID)

err = a.Run(ctx)
if err != nil {
panic(err)
}

fmt.Println("----")
fmt.Println(a.Thread())
fmt.Println("----")
}
4 changes: 4 additions & 0 deletions observer/langfuse/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ func langfuseSpanToObserverSpan(s *model.Span) *observer.Span {
TraceID: s.TraceID,
Name: s.Name,
ParentID: s.ParentObservationID,
Input: s.Input,
Output: s.Output,
}
}

Expand All @@ -35,6 +37,8 @@ func observerSpanToLangfuseSpan(s *observer.Span) *model.Span {
TraceID: s.TraceID,
Name: s.Name,
ParentObservationID: s.ParentID,
Input: s.Input,
Output: s.Output,
}
}

Expand Down
2 changes: 2 additions & 0 deletions observer/observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Span struct {
ParentID string
TraceID string
Name string
Input any
Output any
}

type Generation struct {
Expand Down
135 changes: 126 additions & 9 deletions rag/rag.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"github.com/henomis/lingoose/index"
"github.com/henomis/lingoose/index/option"
"github.com/henomis/lingoose/loader"
obs "github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/textsplitter"
"github.com/henomis/lingoose/thread"
"github.com/henomis/lingoose/types"
)

const (
Expand All @@ -27,12 +29,19 @@ type Loader interface {
LoadFromSource(context.Context, string) ([]document.Document, error)
}

type observer interface {
Span(s *obs.Span) (*obs.Span, error)
SpanEnd(s *obs.Span) (*obs.Span, error)
}

type RAG struct {
index *index.Index
chunkSize uint
chunkOverlap uint
topK uint
loaders map[*regexp.Regexp]Loader // this map a regexp as string to a loader
index *index.Index
chunkSize uint
chunkOverlap uint
topK uint
loaders map[*regexp.Regexp]Loader // this map a regexp as string to a loader
observer observer
observerTraceID string
}

func New(index *index.Index) *RAG {
Expand Down Expand Up @@ -75,14 +84,44 @@ func (r *RAG) WithLoader(sourceRegexp *regexp.Regexp, loader Loader) *RAG {
return r
}

func (r *RAG) WithObserver(observer observer, traceID string) *RAG {
r.observer = observer
r.observerTraceID = traceID
return r
}

func (r *RAG) AddSources(ctx context.Context, sources ...string) error {
for _, source := range sources {
documents, err := r.addSource(ctx, source)
var err error
var span *obs.Span
if r.observer != nil {
span, err = r.startObserveSpan(
ctx,
"RAG AddSources",
types.M{
"chunkSize": r.chunkSize,
"chunkOverlap": r.chunkOverlap,
},
)
if err != nil {
return err
}
ctx = obs.ContextWithParentID(ctx, span.ID)
}

for _, source := range sources {
documents, errAddSource := r.addSource(ctx, source)
if errAddSource != nil {
return errAddSource
}

err = r.index.LoadFromDocuments(ctx, documents)
errAddSource = r.index.LoadFromDocuments(ctx, documents)
if errAddSource != nil {
return errAddSource
}
}

if r.observer != nil {
err = r.stopObserveSpan(span)
if err != nil {
return err
}
Expand All @@ -92,10 +131,72 @@ func (r *RAG) AddSources(ctx context.Context, sources ...string) error {
}

func (r *RAG) AddDocuments(ctx context.Context, documents ...document.Document) error {
return r.index.LoadFromDocuments(ctx, documents)
var err error
var span *obs.Span
if r.observer != nil {
span, err = r.startObserveSpan(
ctx,
"RAG AddDocument",
types.M{
"chunkSize": r.chunkSize,
"chunkOverlap": r.chunkOverlap,
},
)
if err != nil {
return err
}
ctx = obs.ContextWithParentID(ctx, span.ID)
}

err = r.index.LoadFromDocuments(ctx, documents)
if err != nil {
return err
}

if r.observer != nil {
err = r.stopObserveSpan(span)
if err != nil {
return err
}
}

return nil
}

func (r *RAG) Retrieve(ctx context.Context, query string) ([]string, error) {
var err error
var span *obs.Span
if r.observer != nil {
span, err = r.startObserveSpan(
ctx,
"RAG Retrieve",
types.M{
"query": query,
"topK": r.topK,
},
)
if err != nil {
return nil, err
}
ctx = obs.ContextWithParentID(ctx, span.ID)
}

texts, err := r.retrieve(ctx, query)
if err != nil {
return nil, err
}

if r.observer != nil {
err = r.stopObserveSpan(span)
if err != nil {
return nil, err
}
}

return texts, nil
}

func (r *RAG) retrieve(ctx context.Context, query string) ([]string, error) {
results, err := r.index.Query(ctx, query, option.WithTopK(int(r.topK)))
var resultsAsString []string
for _, result := range results {
Expand Down Expand Up @@ -127,3 +228,19 @@ func (r *RAG) addSource(ctx context.Context, source string) ([]document.Document
int(r.chunkOverlap),
).SplitDocuments(documents), nil
}

func (r *RAG) startObserveSpan(ctx context.Context, name string, input any) (*obs.Span, error) {
return r.observer.Span(
&obs.Span{
TraceID: r.observerTraceID,
ParentID: obs.ContextValueParentID(ctx),
Name: name,
Input: input,
},
)
}

func (r *RAG) stopObserveSpan(span *obs.Span) error {
_, err := r.observer.SpanEnd(span)
return err
}
Loading

0 comments on commit e0a9028

Please sign in to comment.