diff --git a/assistant/assistant.go b/assistant/assistant.go index 0dd0e976..84639324 100644 --- a/assistant/assistant.go +++ b/assistant/assistant.go @@ -8,10 +8,19 @@ import ( "github.com/henomis/lingoose/types" ) +type Parameters struct { + AssistantName string + AssistantIdentity string + AssistantScope string + CompanyName string + CompanyDescription string +} + type Assistant struct { - llm LLM - rag RAG - thread *thread.Thread + llm LLM + rag RAG + thread *thread.Thread + parameters Parameters } type LLM interface { @@ -26,6 +35,13 @@ func New(llm LLM) *Assistant { assistant := &Assistant{ llm: llm, thread: thread.New(), + parameters: Parameters{ + AssistantName: defaultAssistantName, + AssistantIdentity: defaultAssistantIdentity, + AssistantScope: defaultAssistantScope, + CompanyName: defaultCompanyName, + CompanyDescription: defaultCompanyDescription, + }, } return assistant @@ -41,6 +57,11 @@ func (a *Assistant) WithRAG(rag RAG) *Assistant { return a } +func (a *Assistant) WithParameters(parameters Parameters) *Assistant { + a.parameters = parameters + return a +} + func (a *Assistant) Run(ctx context.Context) error { if a.thread == nil { return nil @@ -71,29 +92,33 @@ func (a *Assistant) generateRAGMessage(ctx context.Context) error { return nil } - query := "" - for _, content := range lastMessage.Contents { - if content.Type == thread.ContentTypeText { - query += content.Data.(string) + "\n" - } else { - continue - } - } + query := strings.Join(a.thread.UserQuery(), "\n") + a.thread.Messages = a.thread.Messages[:len(a.thread.Messages)-1] searchResults, err := a.rag.Retrieve(ctx, query) if err != nil { return err } - context := strings.Join(searchResults, "\n\n") - - a.thread.AddMessage(thread.NewUserMessage().AddContent( + a.thread.AddMessage(thread.NewSystemMessage().AddContent( + thread.NewTextContent( + systemRAGPrompt, + ).Format( + types.M{ + "assistantName": a.parameters.AssistantName, + "assistantIdentity": a.parameters.AssistantIdentity, + "assistantScope": a.parameters.AssistantScope, + "companyName": a.parameters.CompanyName, + "companyDescription": a.parameters.CompanyDescription, + }, + ), + )).AddMessage(thread.NewUserMessage().AddContent( thread.NewTextContent( baseRAGPrompt, ).Format( types.M{ "question": query, - "context": context, + "results": searchResults, }, ), )) diff --git a/assistant/prompt.go b/assistant/prompt.go index 71df12a7..18a4f8fe 100644 --- a/assistant/prompt.go +++ b/assistant/prompt.go @@ -2,7 +2,13 @@ package assistant const ( //nolint:lll - baseRAGPrompt = `You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. - Question: {{.question}} - Context: {{.context}}` + baseRAGPrompt = "Use the following pieces of retrieved context to answer the question.\n\nQuestion: {{.question}}\nContext:\n{{range .results}}{{.}}\n\n{{end}}" + //nolint:lll + systemRAGPrompt = "You name is {{.assistantName}}, and you are {{.assistantIdentity}} {{if ne .companyName \"\" }}at {{.companyName}}{{end}}{{if ne .companyDescription \"\" }}, {{.companyDescription}}{{end}}. Your task is to assist humans {{.assistantScope}}." + + defaultAssistantName = "AI assistant" + defaultAssistantIdentity = "a helpful and polite assistant" + defaultAssistantScope = "with their questions" + defaultCompanyName = "" + defaultCompanyDescription = "" ) diff --git a/examples/assistant/main.go b/examples/assistant/main.go index 77d6c61e..c36c198e 100644 --- a/examples/assistant/main.go +++ b/examples/assistant/main.go @@ -17,12 +17,11 @@ import ( // download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt func main() { - r := rag.NewFusion( + r := rag.New( index.New( jsondb.New().WithPersist("db.json"), openaiembedder.New(openaiembedder.AdaEmbeddingV2), ), - openai.New().WithTemperature(0), ).WithTopK(3) _, err := os.Stat("db.json") @@ -35,6 +34,14 @@ func main() { a := assistant.New( openai.New().WithTemperature(0), + ).WithParameters( + assistant.Parameters{ + AssistantName: "AI Pirate Assistant", + AssistantIdentity: "a pirate and helpful assistant", + AssistantScope: "with their questions replying as a pirate", + CompanyName: "Lingoose", + CompanyDescription: "a pirate company that provides AI assistants to help humans with their questions", + }, ).WithRAG(r).WithThread( thread.New().AddMessages( thread.NewUserMessage().AddContent( diff --git a/llm/cohere/formatter.go b/llm/cohere/formatter.go index 67b3360d..f39bd16f 100644 --- a/llm/cohere/formatter.go +++ b/llm/cohere/formatter.go @@ -58,24 +58,3 @@ func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage) { return message, history } - -// chatMessages := make([]message, len(t.Messages)) -// for i, m := range t.Messages { -// chatMessages[i] = message{ -// Role: threadRoleToOpenAIRole[m.Role], -// } - -// switch m.Role { -// case thread.RoleUser, thread.RoleSystem, thread.RoleAssistant: -// for _, content := range m.Contents { -// if content.Type == thread.ContentTypeText { -// chatMessages[i].Content += content.Data.(string) + "\n" -// } -// } -// case thread.RoleTool: -// continue -// } -// } - -// return chatMessages -// }