Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement assistant and rag observer span #196

Merged
merged 4 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 58 additions & 6 deletions assistant/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"strings"

obs "github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/thread"
"github.com/henomis/lingoose/types"
)
Expand All @@ -16,11 +17,18 @@ type Parameters struct {
CompanyDescription string
}

type observer interface {
Span(s *obs.Span) (*obs.Span, error)
SpanEnd(s *obs.Span) (*obs.Span, error)
}

type Assistant struct {
llm LLM
rag RAG
thread *thread.Thread
parameters Parameters
llm LLM
rag RAG
thread *thread.Thread
parameters Parameters
observer observer
observerTraceID string
}

type LLM interface {
Expand Down Expand Up @@ -62,19 +70,47 @@ func (a *Assistant) WithParameters(parameters Parameters) *Assistant {
return a
}

func (a *Assistant) WithObserver(observer observer, traceID string) *Assistant {
a.observer = observer
a.observerTraceID = traceID
return a
}

func (a *Assistant) Run(ctx context.Context) error {
if a.thread == nil {
return nil
}

var err error
var spanAssistant *obs.Span
if a.observer != nil {
spanAssistant, err = a.startObserveSpan(ctx, "assistant")
if err != nil {
return err
}
ctx = obs.ContextWithParentID(ctx, spanAssistant.ID)
}

if a.rag != nil {
err := a.generateRAGMessage(ctx)
errGenerate := a.generateRAGMessage(ctx)
if errGenerate != nil {
return errGenerate
}
}

err = a.llm.Generate(ctx, a.thread)
if err != nil {
return err
}

if a.observer != nil {
err = a.stopObserveSpan(spanAssistant)
if err != nil {
return err
}
}

return a.llm.Generate(ctx, a.thread)
return nil
}

func (a *Assistant) RunWithThread(ctx context.Context, thread *thread.Thread) error {
Expand Down Expand Up @@ -125,3 +161,19 @@ func (a *Assistant) generateRAGMessage(ctx context.Context) error {

return nil
}

func (a *Assistant) startObserveSpan(ctx context.Context, name string) (*obs.Span, error) {
return a.observer.Span(
&obs.Span{
TraceID: a.observerTraceID,
ParentID: obs.ContextValueParentID(ctx),
Name: name,
Input: a.parameters,
},
)
}

func (a *Assistant) stopObserveSpan(span *obs.Span) error {
_, err := a.observer.SpanEnd(span)
return err
}
71 changes: 71 additions & 0 deletions examples/observer/assistant/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package main

import (
"context"
"fmt"
"os"

"github.com/henomis/lingoose/assistant"
openaiembedder "github.com/henomis/lingoose/embedder/openai"
"github.com/henomis/lingoose/index"
"github.com/henomis/lingoose/index/vectordb/jsondb"
"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/observer/langfuse"
"github.com/henomis/lingoose/rag"
"github.com/henomis/lingoose/thread"
)

// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt

func main() {
ctx := context.Background()

o := langfuse.New(ctx)
trace, err := o.Trace(&observer.Trace{Name: "state of the union"})
if err != nil {
panic(err)
}

r := rag.New(
index.New(
jsondb.New().WithPersist("db.json"),
openaiembedder.New(openaiembedder.AdaEmbeddingV2).WithObserver(o, trace.ID),
),
).WithTopK(3).WithObserver(o, trace.ID)

_, err = os.Stat("db.json")
if os.IsNotExist(err) {
err = r.AddSources(ctx, "state_of_the_union.txt")
if err != nil {
panic(err)
}
}

a := assistant.New(
openai.New().WithTemperature(0).WithObserver(o, trace.ID),
).WithParameters(
assistant.Parameters{
AssistantName: "AI Pirate Assistant",
AssistantIdentity: "a pirate and helpful assistant",
AssistantScope: "with their questions replying as a pirate",
CompanyName: "Lingoose",
CompanyDescription: "a pirate company that provides AI assistants to help humans with their questions",
},
).WithRAG(r).WithThread(
thread.New().AddMessages(
thread.NewUserMessage().AddContent(
thread.NewTextContent("what is the purpose of NATO?"),
),
),
).WithObserver(o, trace.ID)

err = a.Run(ctx)
if err != nil {
panic(err)
}

fmt.Println("----")
fmt.Println(a.Thread())
fmt.Println("----")
}
4 changes: 4 additions & 0 deletions observer/langfuse/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ func langfuseSpanToObserverSpan(s *model.Span) *observer.Span {
TraceID: s.TraceID,
Name: s.Name,
ParentID: s.ParentObservationID,
Input: s.Input,
Output: s.Output,
}
}

Expand All @@ -35,6 +37,8 @@ func observerSpanToLangfuseSpan(s *observer.Span) *model.Span {
TraceID: s.TraceID,
Name: s.Name,
ParentObservationID: s.ParentID,
Input: s.Input,
Output: s.Output,
}
}

Expand Down
2 changes: 2 additions & 0 deletions observer/observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Span struct {
ParentID string
TraceID string
Name string
Input any
Output any
}

type Generation struct {
Expand Down
135 changes: 126 additions & 9 deletions rag/rag.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"github.com/henomis/lingoose/index"
"github.com/henomis/lingoose/index/option"
"github.com/henomis/lingoose/loader"
obs "github.com/henomis/lingoose/observer"
"github.com/henomis/lingoose/textsplitter"
"github.com/henomis/lingoose/thread"
"github.com/henomis/lingoose/types"
)

const (
Expand All @@ -27,12 +29,19 @@ type Loader interface {
LoadFromSource(context.Context, string) ([]document.Document, error)
}

type observer interface {
Span(s *obs.Span) (*obs.Span, error)
SpanEnd(s *obs.Span) (*obs.Span, error)
}

type RAG struct {
index *index.Index
chunkSize uint
chunkOverlap uint
topK uint
loaders map[*regexp.Regexp]Loader // this map a regexp as string to a loader
index *index.Index
chunkSize uint
chunkOverlap uint
topK uint
loaders map[*regexp.Regexp]Loader // this map a regexp as string to a loader
observer observer
observerTraceID string
}

func New(index *index.Index) *RAG {
Expand Down Expand Up @@ -75,14 +84,44 @@ func (r *RAG) WithLoader(sourceRegexp *regexp.Regexp, loader Loader) *RAG {
return r
}

func (r *RAG) WithObserver(observer observer, traceID string) *RAG {
r.observer = observer
r.observerTraceID = traceID
return r
}

func (r *RAG) AddSources(ctx context.Context, sources ...string) error {
for _, source := range sources {
documents, err := r.addSource(ctx, source)
var err error
var span *obs.Span
if r.observer != nil {
span, err = r.startObserveSpan(
ctx,
"RAG AddSources",
types.M{
"chunkSize": r.chunkSize,
"chunkOverlap": r.chunkOverlap,
},
)
if err != nil {
return err
}
ctx = obs.ContextWithParentID(ctx, span.ID)
}

for _, source := range sources {
documents, errAddSource := r.addSource(ctx, source)
if errAddSource != nil {
return errAddSource
}

err = r.index.LoadFromDocuments(ctx, documents)
errAddSource = r.index.LoadFromDocuments(ctx, documents)
if errAddSource != nil {
return errAddSource
}
}

if r.observer != nil {
err = r.stopObserveSpan(span)
if err != nil {
return err
}
Expand All @@ -92,10 +131,72 @@ func (r *RAG) AddSources(ctx context.Context, sources ...string) error {
}

func (r *RAG) AddDocuments(ctx context.Context, documents ...document.Document) error {
return r.index.LoadFromDocuments(ctx, documents)
var err error
var span *obs.Span
if r.observer != nil {
span, err = r.startObserveSpan(
ctx,
"RAG AddDocument",
types.M{
"chunkSize": r.chunkSize,
"chunkOverlap": r.chunkOverlap,
},
)
if err != nil {
return err
}
ctx = obs.ContextWithParentID(ctx, span.ID)
}

err = r.index.LoadFromDocuments(ctx, documents)
if err != nil {
return err
}

if r.observer != nil {
err = r.stopObserveSpan(span)
if err != nil {
return err
}
}

return nil
}

func (r *RAG) Retrieve(ctx context.Context, query string) ([]string, error) {
var err error
var span *obs.Span
if r.observer != nil {
span, err = r.startObserveSpan(
ctx,
"RAG Retrieve",
types.M{
"query": query,
"topK": r.topK,
},
)
if err != nil {
return nil, err
}
ctx = obs.ContextWithParentID(ctx, span.ID)
}

texts, err := r.retrieve(ctx, query)
if err != nil {
return nil, err
}

if r.observer != nil {
err = r.stopObserveSpan(span)
if err != nil {
return nil, err
}
}

return texts, nil
}

func (r *RAG) retrieve(ctx context.Context, query string) ([]string, error) {
results, err := r.index.Query(ctx, query, option.WithTopK(int(r.topK)))
var resultsAsString []string
for _, result := range results {
Expand Down Expand Up @@ -127,3 +228,19 @@ func (r *RAG) addSource(ctx context.Context, source string) ([]document.Document
int(r.chunkOverlap),
).SplitDocuments(documents), nil
}

func (r *RAG) startObserveSpan(ctx context.Context, name string, input any) (*obs.Span, error) {
return r.observer.Span(
&obs.Span{
TraceID: r.observerTraceID,
ParentID: obs.ContextValueParentID(ctx),
Name: name,
Input: input,
},
)
}

func (r *RAG) stopObserveSpan(span *obs.Span) error {
_, err := r.observer.SpanEnd(span)
return err
}
Loading
Loading