Skip to content

Commit

Permalink
add cohere chat stream
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis committed Jan 28, 2024
1 parent 3fc9ee5 commit 12caa7f
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 42 deletions.
20 changes: 14 additions & 6 deletions embedder/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,23 @@ type EmbedderModel = model.EmbedModel
const (
defaultEmbedderModel EmbedderModel = model.EmbedModelEnglishV20

EmbedderModelEnglishV20 EmbedderModel = model.EmbedModelEnglishV20
EmbedderModelEnglishLightV20 EmbedderModel = model.EmbedModelEnglishLightV20
EmbedderModelMultilingualV20 EmbedderModel = model.EmbedModelMultilingualV20
EmbedderModelEnglishV20 EmbedderModel = model.EmbedModelEnglishV20
EmbedderModelEnglishLightV20 EmbedderModel = model.EmbedModelEnglishLightV20
EmbedderModelMultilingualV20 EmbedderModel = model.EmbedModelMultilingualV20
EmbedderModelEnglishV30 EmbedderModel = model.EmbedModelEnglishV30
EmbedderModelEnglishLightV30 EmbedderModel = model.EmbedModelEnglishLightV30
EmbedderModelMultilingualV30 EmbedderModel = model.EmbedModelMultilingualV30
EmbedderModelMultilingualLightV30 EmbedderModel = model.EmbedModelMultilingualLightV30
)

// EmbedderModelsSize maps each embedding model to the dimensionality of the
// vectors it produces.
// NOTE: the diff view had merged old and new map entries, duplicating the
// V2.0 keys (a compile error in Go); this keeps each key exactly once with
// the post-change values.
var EmbedderModelsSize = map[EmbedderModel]int{
	EmbedderModelEnglishV20:           4096,
	EmbedderModelEnglishLightV20:      1024,
	EmbedderModelMultilingualV20:      768,
	EmbedderModelEnglishV30:           1024,
	EmbedderModelEnglishLightV30:      384,
	EmbedderModelMultilingualV30:      1024,
	EmbedderModelMultilingualLightV30: 384,
}

type Embedder struct {
Expand Down
29 changes: 29 additions & 0 deletions examples/thread/cohere/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package main

import (
"context"
"fmt"

"github.com/henomis/lingoose/llm/cohere"
"github.com/henomis/lingoose/thread"
)

// Example: stream a Cohere chat completion for a single-message thread,
// printing chunks as they arrive and the full thread at the end.
func main() {
	t := thread.New()
	t.AddMessage(thread.NewUserMessage().AddContent(
		thread.NewTextContent("Hello"),
	))

	llm := cohere.New().
		WithMaxTokens(1000).
		WithTemperature(0).
		WithStream(func(chunk string) {
			fmt.Print(chunk)
		})

	if err := llm.Generate(context.Background(), t); err != nil {
		panic(err)
	}

	fmt.Println(t.String())
}
2 changes: 1 addition & 1 deletion examples/thread/main.go → examples/thread/ollama/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func main() {
))

err := ollama.New().WithEndpoint("http://localhost:11434/api").WithModel("llama2").
WithStream(true, func(s string) {
WithStream(func(s string) {
fmt.Print(s)
}).Generate(context.Background(), t)
if err != nil {
Expand Down
103 changes: 77 additions & 26 deletions llm/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ import (
"github.com/henomis/lingoose/thread"
)

type Model model.Model
var (
ErrCohereChat = fmt.Errorf("cohere chat error")
)

type Model = model.Model

const (
ModelCommand Model = Model(model.ModelCommand)
ModelCommandNightly Model = Model(model.ModelCommandNightly)
ModelCommandLight Model = Model(model.ModelCommandLight)
ModelCommandLightNightly Model = Model(model.ModelCommandLightNightly)
ModelCommand Model = model.ModelCommand
ModelCommandNightly Model = model.ModelCommandNightly
ModelCommandLight Model = model.ModelCommandLight
ModelCommandLightNightly Model = model.ModelCommandLightNightly
)

const (
Expand All @@ -31,14 +35,17 @@ const (
DefaultModel = ModelCommand
)

// StreamCallbackFn receives each text chunk as it is produced when
// streaming is enabled via WithStream.
type StreamCallbackFn func(string)

// Cohere is a client for Cohere's completion and chat endpoints, configured
// via the WithXxx builder methods.
// NOTE: the diff view had merged the old and new field lists, duplicating
// every field (a compile error); this keeps the single post-change struct.
type Cohere struct {
	client           *coherego.Client
	model            Model
	temperature      float64
	maxTokens        int
	verbose          bool
	stop             []string
	cache            *cache.Cache
	streamCallbackFn StreamCallbackFn // non-nil enables streaming in Generate
}

func (c *Cohere) WithCache(cache *cache.Cache) *Cohere {
Expand Down Expand Up @@ -96,6 +103,11 @@ func (c *Cohere) WithStop(stop []string) *Cohere {
return c
}

// WithStream enables streaming generation: fn is invoked with each text
// chunk as it arrives. Passing nil reverts to blocking generation.
func (c *Cohere) WithStream(fn StreamCallbackFn) *Cohere {
	c.streamCallbackFn = fn
	return c
}

// Completion returns the completion for the given prompt
func (c *Cohere) Completion(ctx context.Context, prompt string) (string, error) {
resp := &response.Generate{}
Expand Down Expand Up @@ -187,34 +199,73 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error {
if err == nil {
return nil
} else if !errors.Is(err, cache.ErrCacheMiss) {
return err
return fmt.Errorf("%w: %w", ErrCohereChat, err)
}
}

completionQuery := ""
for _, message := range t.Messages {
for _, content := range message.Contents {
if content.Type == thread.ContentTypeText {
completionQuery += content.Data.(string) + "\n"
}
}
chatRequest := c.buildChatCompletionRequest(t)

if c.streamCallbackFn != nil {
err = c.stream(ctx, t, chatRequest)
} else {
err = c.generate(ctx, t, chatRequest)
}

completionResponse, err := c.Completion(ctx, completionQuery)
if err != nil {
return err
}

t.AddMessage(thread.NewAssistantMessage().AddContent(
thread.NewTextContent(completionResponse),
))

if c.cache != nil {
err = c.setCache(ctx, t, cacheResult)
if err != nil {
return err
return fmt.Errorf("%w: %w", ErrCohereChat, err)
}
}

return nil
}

// generate performs a single blocking chat request and appends the
// assistant's reply to the thread. Errors are wrapped in ErrCohereChat.
func (c *Cohere) generate(ctx context.Context, t *thread.Thread, chatRequest *request.Chat) error {
	// Named chatResponse (not "response") so the imported response package
	// is not shadowed inside this function.
	var chatResponse response.Chat

	err := c.client.Chat(
		ctx,
		chatRequest,
		&chatResponse,
	)
	if err != nil {
		return fmt.Errorf("%w: %w", ErrCohereChat, err)
	}

	t.AddMessage(thread.NewAssistantMessage().AddContent(
		thread.NewTextContent(chatResponse.Text),
	))

	return nil
}

// stream performs a streaming chat request, forwarding every non-empty text
// chunk to the configured callback while accumulating the full reply, which
// is appended to the thread once the stream ends. Errors are wrapped in
// ErrCohereChat.
func (c *Cohere) stream(ctx context.Context, t *thread.Thread, chatRequest *request.Chat) error {
	var fullText string
	streamResponse := &response.Chat{}

	onChunk := func(r *response.Chat) {
		if r.Text == "" {
			return
		}
		c.streamCallbackFn(r.Text)
		fullText += r.Text
	}

	if err := c.client.ChatStream(ctx, chatRequest, streamResponse, onChunk); err != nil {
		return fmt.Errorf("%w: %w", ErrCohereChat, err)
	}

	t.AddMessage(thread.NewAssistantMessage().AddContent(
		thread.NewTextContent(fullText),
	))

	return nil
}
81 changes: 81 additions & 0 deletions llm/cohere/formatter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package cohere

import (
"github.com/henomis/cohere-go/model"
"github.com/henomis/cohere-go/request"
"github.com/henomis/lingoose/thread"
)

// threadRoleToCohereRole maps lingoose thread roles to Cohere chat roles.
// NOTE(review): both system and assistant map to the chatbot role —
// presumably because this Cohere chat API has no dedicated system role;
// confirm against the cohere-go model package.
var threadRoleToCohereRole = map[thread.Role]model.ChatMessageRole{
	thread.RoleSystem:    model.ChatMessageRoleChatbot,
	thread.RoleUser:      model.ChatMessageRoleUser,
	thread.RoleAssistant: model.ChatMessageRoleChatbot,
}

// buildChatCompletionRequest converts a lingoose thread into a Cohere chat
// request: the trailing user message becomes the prompt, everything else
// becomes the chat history.
func (c *Cohere) buildChatCompletionRequest(t *thread.Thread) *request.Chat {
	lastUserMessage, chatHistory := threadToChatMessages(t)

	return &request.Chat{
		Model:       c.model,
		ChatHistory: chatHistory,
		Message:     lastUserMessage,
	}
}

// threadToChatMessages flattens a thread for Cohere's chat API, which takes
// the latest user message separately from the history. It returns that
// message (empty if the thread does not end with a user turn) and the
// remaining history. Tool messages are dropped: they have no Cohere role.
func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage) {
	var history []model.ChatMessage
	var message string

	for _, m := range t.Messages {
		chatMessage := model.ChatMessage{
			Role: threadRoleToCohereRole[m.Role],
		}

		switch m.Role {
		case thread.RoleUser, thread.RoleSystem, thread.RoleAssistant:
			// Concatenate all text contents of the message, one per line.
			for _, content := range m.Contents {
				if content.Type == thread.ContentTypeText {
					chatMessage.Message += content.Data.(string) + "\n"
				}
			}
		case thread.RoleTool:
			// Skip the append below: tool output is not sent to Cohere.
			continue
		}

		history = append(history, chatMessage)
	}

	// When the thread ends with a user turn, pop it off the history and send
	// it as the prompt instead. Guard against an empty thread, where
	// LastMessage yields nothing to dereference.
	lastMessage := t.LastMessage()
	if lastMessage != nil && lastMessage.Role == thread.RoleUser {
		for _, content := range lastMessage.Contents {
			if content.Type == thread.ContentTypeText {
				message += content.Data.(string) + "\n"
			}
		}

		// Safe: a trailing user message is never skipped above, so history
		// is non-empty here.
		history = history[:len(history)-1]
	}

	return message, history
}

// chatMessages := make([]message, len(t.Messages))
// for i, m := range t.Messages {
// chatMessages[i] = message{
// Role: threadRoleToOpenAIRole[m.Role],
// }

// switch m.Role {
// case thread.RoleUser, thread.RoleSystem, thread.RoleAssistant:
// for _, content := range m.Contents {
// if content.Type == thread.ContentTypeText {
// chatMessages[i].Content += content.Data.(string) + "\n"
// }
// }
// case thread.RoleTool:
// continue
// }
// }

// return chatMessages
// }
2 changes: 1 addition & 1 deletion llm/ollama/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func threadToChatMessages(t *thread.Thread) []message {
chatMessages := make([]message, len(t.Messages))
for i, m := range t.Messages {
chatMessages[i] = message{
Role: threadRoleToOpenAIRole[m.Role],
Role: threadRoleToOllamaRole[m.Role],
}

switch m.Role {
Expand Down
11 changes: 3 additions & 8 deletions llm/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ var (
ErrOllamaChat = fmt.Errorf("ollama chat error")
)

var threadRoleToOpenAIRole = map[thread.Role]string{
var threadRoleToOllamaRole = map[thread.Role]string{
thread.RoleSystem: "system",
thread.RoleUser: "user",
thread.RoleAssistant: "assistant",
Expand Down Expand Up @@ -54,13 +54,8 @@ func (o *Ollama) WithModel(model string) *Ollama {
return o
}

func (o *Ollama) WithStream(enable bool, callbackFn StreamCallbackFn) *Ollama {
if !enable {
o.streamCallbackFn = nil
} else {
o.streamCallbackFn = callbackFn
}

// WithStream sets the callback invoked with each streamed text chunk.
// Passing nil disables streaming.
func (o *Ollama) WithStream(fn StreamCallbackFn) *Ollama {
	o.streamCallbackFn = fn
	return o
}

Expand Down

0 comments on commit 12caa7f

Please sign in to comment.