diff --git a/examples/rag/main.go b/examples/rag/simple/main.go similarity index 100% rename from examples/rag/main.go rename to examples/rag/simple/main.go diff --git a/examples/rag/subdocument/main.go b/examples/rag/subdocument/main.go new file mode 100644 index 00000000..ad41628f --- /dev/null +++ b/examples/rag/subdocument/main.go @@ -0,0 +1,54 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/henomis/lingoose/assistant" + openaiembedder "github.com/henomis/lingoose/embedder/openai" + "github.com/henomis/lingoose/index" + "github.com/henomis/lingoose/index/vectordb/jsondb" + "github.com/henomis/lingoose/llm/openai" + "github.com/henomis/lingoose/rag" + "github.com/henomis/lingoose/thread" +) + +// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt + +func main() { + r := rag.NewSubDocumentRAG( + index.New( + jsondb.New().WithPersist("db.json"), + openaiembedder.New(openaiembedder.AdaEmbeddingV2), + ), + openai.New(), + ).WithTopK(3) + + _, err := os.Stat("db.json") + if os.IsNotExist(err) { + err = r.AddSources(context.Background(), "state_of_the_union.txt") + if err != nil { + panic(err) + } + } + + a := assistant.New( + openai.New().WithTemperature(0), + ).WithRAG(r).WithThread( + thread.New().AddMessages( + thread.NewUserMessage().AddContent( + thread.NewTextContent("what is the purpose of NATO?"), + ), + ), + ) + + err = a.Run(context.Background()) + if err != nil { + panic(err) + } + + fmt.Println("----") + fmt.Println(a.Thread()) + fmt.Println("----") +} diff --git a/rag/rag.go b/rag/rag.go index c8f3d529..a4dd8eda 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -35,11 +35,6 @@ type RAG struct { loaders map[*regexp.Regexp]Loader // this map a regexp as string to a loader } -type Fusion struct { - RAG - llm LLM -} - func New(index *index.Index) *RAG { rag := &RAG{ index: index, diff --git a/rag/rag_fusion.go b/rag/rag_fusion.go index e8363c82..2c55eace 100644 --- a/rag/rag_fusion.go +++ b/rag/rag_fusion.go @@ -17,6 +17,11 @@ var ragFusionPrompts = []string{ "OUTPUT (4 queries):", } +type Fusion struct { + RAG + llm LLM +} + func NewFusion(index *index.Index, llm LLM) *Fusion { return &Fusion{ RAG: *New(index), diff --git a/rag/sub_document.go b/rag/sub_document.go new file mode 100644 index 00000000..b4e0d0f4 --- /dev/null +++ b/rag/sub_document.go @@ -0,0 +1,121 @@ +package rag + +import ( + "context" + "regexp" + + "github.com/henomis/lingoose/document" + "github.com/henomis/lingoose/index" + "github.com/henomis/lingoose/textsplitter" + "github.com/henomis/lingoose/thread" + "github.com/henomis/lingoose/types" +) + +const ( + defaultSubDocumentRAGChunkSize = 8192 + defaultSubDocumentRAGChunkOverlap = 0 + defaultSubDocumentRAGChildChunkSize = 512 +) + +type SubDocumentRAG struct { + RAG + childChunkSize uint + llm LLM +} + +//nolint:lll +var SubDocumentRAGSummarizePrompt = "Please give a concise summary of the context in 1-2 sentences.\n\nContext: {{.context}}" + +func NewSubDocumentRAG(index *index.Index, llm LLM) *SubDocumentRAG { + return &SubDocumentRAG{ + RAG: *New(index). + WithChunkSize(defaultSubDocumentRAGChunkSize). + WithChunkOverlap(defaultSubDocumentRAGChunkOverlap), + childChunkSize: defaultSubDocumentRAGChildChunkSize, + llm: llm, + } +} + +func (r *SubDocumentRAG) WithChunkSize(chunkSize uint) *SubDocumentRAG { + r.chunkSize = chunkSize + return r +} + +func (r *SubDocumentRAG) WithChildChunkSize(childChunkSize uint) *SubDocumentRAG { + r.childChunkSize = childChunkSize + return r +} + +func (r *SubDocumentRAG) WithChunkOverlap(chunkOverlap uint) *SubDocumentRAG { + r.chunkOverlap = chunkOverlap + return r +} + +func (r *SubDocumentRAG) WithTopK(topK uint) *SubDocumentRAG { + r.topK = topK + return r +} + +func (r *SubDocumentRAG) WithLoader(sourceRegexp *regexp.Regexp, loader Loader) *SubDocumentRAG { + r.loaders[sourceRegexp] = loader + return r +} + +func (r *SubDocumentRAG) AddSources(ctx context.Context, sources ...string) error { + for _, source := range sources { + documents, err := r.addSource(ctx, source) + if err != nil { + return err + } + + subDocuments, err := r.generateSubDocuments(ctx, documents) + if err != nil { + return err + } + + err = r.index.LoadFromDocuments(ctx, subDocuments) + if err != nil { + return err + } + } + + return nil +} + +func (r *SubDocumentRAG) generateSubDocuments( + ctx context.Context, + documents []document.Document, +) ([]document.Document, error) { + var subDocuments []document.Document + + for _, doc := range documents { + t := thread.New().AddMessages( + thread.NewUserMessage().AddContent( + thread.NewTextContent(SubDocumentRAGSummarizePrompt).Format( + types.M{ + "context": doc.Content, + }, + ), + ), + ) + + err := r.llm.Generate(ctx, t) + if err != nil { + return nil, err + } + summary := t.LastMessage().Contents[0].AsString() + + subChunks := textsplitter.NewRecursiveCharacterTextSplitter( + int(r.childChunkSize), + 0, + ).SplitDocuments([]document.Document{doc}) + + for i := range subChunks { + subChunks[i].Content = summary + "\n" + subChunks[i].Content + } + + subDocuments = append(subDocuments, subChunks...) + } + + return subDocuments, nil +}