From f0933660767b6aa52e6266d32a1b070a41f30093 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 15 Feb 2024 08:56:51 +0100 Subject: [PATCH] chore Add ollama vision (#167) --- examples/llm/ollama/multimodal/main.go | 28 ++++++++++++++++++ llm/ollama/api.go | 40 ++++++++++++++++++++++---- llm/ollama/formatter.go | 37 ++++++++++++++++++------ llm/ollama/ollama.go | 12 ++++++++ 4 files changed, 104 insertions(+), 13 deletions(-) create mode 100644 examples/llm/ollama/multimodal/main.go diff --git a/examples/llm/ollama/multimodal/main.go b/examples/llm/ollama/multimodal/main.go new file mode 100644 index 00000000..556bdf5a --- /dev/null +++ b/examples/llm/ollama/multimodal/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "context" + "fmt" + + "github.com/henomis/lingoose/llm/ollama" + "github.com/henomis/lingoose/thread" +) + +func main() { + ollamallm := ollama.New().WithModel("llava") + + t := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent("Can you describe the image?"), + ).AddContent( + thread.NewImageContentFromURL("https://upload.wikimedia.org/wikipedia/commons/thumb/3/34/Anser_anser_1_%28Piotr_Kuczynski%29.jpg/1280px-Anser_anser_1_%28Piotr_Kuczynski%29.jpg"), + ), + ) + + err := ollamallm.Generate(context.Background(), t) + if err != nil { + panic(err) + } + + fmt.Println(t) +} diff --git a/llm/ollama/api.go b/llm/ollama/api.go index 53dfde39..1ba0855e 100644 --- a/llm/ollama/api.go +++ b/llm/ollama/api.go @@ -2,8 +2,12 @@ package ollama import ( "bytes" + "encoding/base64" "encoding/json" "io" + "net/http" + "os" + "strings" "github.com/henomis/restclientgo" ) @@ -29,7 +33,7 @@ func (r *request) Encode() (io.Reader, error) { } func (r *request) ContentType() string { - return "application/json" + return jsonContentType } type response[T any] struct { @@ -40,6 +44,7 @@ type response[T any] struct { Message T `json:"message"` Done bool `json:"done"` streamCallbackFn restclientgo.StreamCallback + RawBody []byte `json:"-"` } type assistantMessage struct { @@ -55,7 +60,8 @@ func (r *response[T]) Decode(body io.Reader) error { return json.NewDecoder(body).Decode(r) } -func (r *response[T]) SetBody(_ io.Reader) error { +func (r *response[T]) SetBody(body io.Reader) error { + r.RawBody, _ = io.ReadAll(body) return nil } @@ -63,7 +69,7 @@ func (r *response[T]) AcceptContentType() string { if r.acceptContentType != "" { return r.acceptContentType } - return "application/json" + return jsonContentType } func (r *response[T]) SetStatusCode(code int) error { @@ -82,10 +88,34 @@ func (r *response[T]) StreamCallback() restclientgo.StreamCallback { } type message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Images []string `json:"images,omitempty"` } type options struct { Temperature float64 `json:"temperature"` } + +func getImageDataAsBase64(imageURL string) (string, error) { + var imageData []byte + var err error + + if strings.HasPrefix(imageURL, "http://") || strings.HasPrefix(imageURL, "https://") { + //nolint:gosec + resp, fetchErr := http.Get(imageURL) + if fetchErr != nil { + return "", fetchErr + } + defer resp.Body.Close() + + imageData, err = io.ReadAll(resp.Body) + } else { + imageData, err = os.ReadFile(imageURL) + } + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(imageData), nil +} diff --git a/llm/ollama/formatter.go b/llm/ollama/formatter.go index 5f4c01c5..86f4e91b 100644 --- a/llm/ollama/formatter.go +++ b/llm/ollama/formatter.go @@ -1,27 +1,48 @@ package ollama -import "github.com/henomis/lingoose/thread" +import ( + "github.com/henomis/lingoose/thread" +) func (o *Ollama) buildChatCompletionRequest(t *thread.Thread) *request { return &request{ Model: o.model, Messages: threadToChatMessages(t), + Options: options{ + Temperature: o.temperature, + }, } } +//nolint:gocognit func threadToChatMessages(t *thread.Thread) []message { - chatMessages := make([]message, len(t.Messages)) - for i, m := range t.Messages { - chatMessages[i] = message{ - Role: threadRoleToOllamaRole[m.Role], - } - + var chatMessages []message + for _, m := range t.Messages { switch m.Role { case thread.RoleUser, thread.RoleSystem, thread.RoleAssistant: for _, content := range m.Contents { + chatMessage := message{ + Role: threadRoleToOllamaRole[m.Role], + } + + contentData, ok := content.Data.(string) + if !ok { + continue + } + if content.Type == thread.ContentTypeText { - chatMessages[i].Content += content.Data.(string) + "\n" + chatMessage.Content = contentData + } else if content.Type == thread.ContentTypeImage { + imageData, err := getImageDataAsBase64(contentData) + if err != nil { + continue + } + chatMessage.Images = []string{imageData} + } else { + continue } + + chatMessages = append(chatMessages, chatMessage) } case thread.RoleTool: continue diff --git a/llm/ollama/ollama.go b/llm/ollama/ollama.go index 3fa0831e..933f1bd3 100644 --- a/llm/ollama/ollama.go +++ b/llm/ollama/ollama.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "github.com/henomis/lingoose/llm/cache" @@ -15,6 +16,7 @@ import ( const ( defaultModel = "llama2" ndjsonContentType = "application/x-ndjson" + jsonContentType = "application/json" defaultEndpoint = "http://localhost:11434/api" ) @@ -32,6 +34,7 @@ type StreamCallbackFn func(string) type Ollama struct { model string + temperature float64 restClient *restclientgo.RestClient streamCallbackFn StreamCallbackFn cache *cache.Cache @@ -64,6 +67,11 @@ func (o *Ollama) WithCache(cache *cache.Cache) *Ollama { return o } +func (o *Ollama) WithTemperature(temperature float64) *Ollama { + o.temperature = temperature + return o +} + func (o *Ollama) getCache(ctx context.Context, t *thread.Thread) (*cache.Result, error) { messages := t.UserQuery() cacheQuery := strings.Join(messages, "\n") @@ -193,6 +201,10 @@ func (o *Ollama) stream(ctx context.Context, t *thread.Thread, chatRequest *requ return fmt.Errorf("%w: %w", ErrOllamaChat, err) } + if resp.HTTPStatusCode >= http.StatusBadRequest { + return fmt.Errorf("%w: %s", ErrOllamaChat, resp.RawBody) + } + t.AddMessage(thread.NewAssistantMessage().AddContent( thread.NewTextContent(assistantMessage), ))