Skip to content

Commit

Permalink
chore: add ollama vision (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored Feb 15, 2024
1 parent 598dcf0 commit f093366
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 13 deletions.
28 changes: 28 additions & 0 deletions examples/llm/ollama/multimodal/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package main

import (
"context"
"fmt"

"github.com/henomis/lingoose/llm/ollama"
"github.com/henomis/lingoose/thread"
)

func main() {
ollamallm := ollama.New().WithModel("llava")

t := thread.New().AddMessage(
thread.NewUserMessage().AddContent(
thread.NewTextContent("Can you describe the image?"),
).AddContent(
thread.NewImageContentFromURL("https://upload.wikimedia.org/wikipedia/commons/thumb/3/34/Anser_anser_1_%28Piotr_Kuczynski%29.jpg/1280px-Anser_anser_1_%28Piotr_Kuczynski%29.jpg"),
),
)

err := ollamallm.Generate(context.Background(), t)
if err != nil {
panic(err)
}

fmt.Println(t)
}
40 changes: 35 additions & 5 deletions llm/ollama/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package ollama

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"

	"github.com/henomis/restclientgo"
)
Expand All @@ -29,7 +33,7 @@ func (r *request) Encode() (io.Reader, error) {
}

// ContentType reports the MIME type of the encoded request payload.
// The Ollama chat API always accepts JSON.
func (r *request) ContentType() string {
	return jsonContentType
}

type response[T any] struct {
Expand All @@ -40,6 +44,7 @@ type response[T any] struct {
Message T `json:"message"`
Done bool `json:"done"`
streamCallbackFn restclientgo.StreamCallback
RawBody []byte `json:"-"`
}

type assistantMessage struct {
Expand All @@ -55,15 +60,16 @@ func (r *response[T]) Decode(body io.Reader) error {
return json.NewDecoder(body).Decode(r)
}

func (r *response[T]) SetBody(_ io.Reader) error {
func (r *response[T]) SetBody(body io.Reader) error {
r.RawBody, _ = io.ReadAll(body)
return nil
}

// AcceptContentType returns the value for the Accept request header:
// the explicitly configured content type when present, otherwise JSON.
func (r *response[T]) AcceptContentType() string {
	accept := r.acceptContentType
	if accept == "" {
		accept = jsonContentType
	}
	return accept
}

func (r *response[T]) SetStatusCode(code int) error {
Expand All @@ -82,10 +88,34 @@ func (r *response[T]) StreamCallback() restclientgo.StreamCallback {
}

// message is a single chat turn in the Ollama API wire format.
type message struct {
	// Role is the speaker of the turn, e.g. "system", "user" or "assistant".
	Role string `json:"role"`
	// Content is the plain-text body; omitted when the turn carries only images.
	Content string `json:"content,omitempty"`
	// Images holds base64-encoded image payloads for multimodal models.
	Images []string `json:"images,omitempty"`
}

// options mirrors the Ollama request "options" object used to tune
// model sampling behaviour.
type options struct {
	// Temperature controls sampling randomness; higher values yield more
	// varied output.
	Temperature float64 `json:"temperature"`
}

func getImageDataAsBase64(imageURL string) (string, error) {
var imageData []byte
var err error

if strings.HasPrefix(imageURL, "http://") || strings.HasPrefix(imageURL, "https://") {
//nolint:gosec
resp, fetchErr := http.Get(imageURL)
if fetchErr != nil {
return "", fetchErr
}
defer resp.Body.Close()

imageData, err = io.ReadAll(resp.Body)
} else {
imageData, err = os.ReadFile(imageURL)
}
if err != nil {
return "", err
}

return base64.StdEncoding.EncodeToString(imageData), nil
}
37 changes: 29 additions & 8 deletions llm/ollama/formatter.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,48 @@
package ollama

import "github.com/henomis/lingoose/thread"
import (
"github.com/henomis/lingoose/thread"
)

// buildChatCompletionRequest translates a lingoose thread into the
// Ollama chat request payload, carrying over the configured model and
// sampling temperature.
func (o *Ollama) buildChatCompletionRequest(t *thread.Thread) *request {
	req := &request{
		Model:    o.model,
		Messages: threadToChatMessages(t),
	}
	req.Options = options{Temperature: o.temperature}
	return req
}

//nolint:gocognit
func threadToChatMessages(t *thread.Thread) []message {
chatMessages := make([]message, len(t.Messages))
for i, m := range t.Messages {
chatMessages[i] = message{
Role: threadRoleToOllamaRole[m.Role],
}

var chatMessages []message
for _, m := range t.Messages {
switch m.Role {
case thread.RoleUser, thread.RoleSystem, thread.RoleAssistant:
for _, content := range m.Contents {
chatMessage := message{
Role: threadRoleToOllamaRole[m.Role],
}

contentData, ok := content.Data.(string)
if !ok {
continue
}

if content.Type == thread.ContentTypeText {
chatMessages[i].Content += content.Data.(string) + "\n"
chatMessage.Content = contentData
} else if content.Type == thread.ContentTypeImage {
imageData, err := getImageDataAsBase64(contentData)
if err != nil {
continue
}
chatMessage.Images = []string{imageData}
} else {
continue
}

chatMessages = append(chatMessages, chatMessage)
}
case thread.RoleTool:
continue
Expand Down
12 changes: 12 additions & 0 deletions llm/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"

"github.com/henomis/lingoose/llm/cache"
Expand All @@ -15,6 +16,7 @@ import (
const (
	// defaultModel is used when the caller does not select a model explicitly.
	defaultModel = "llama2"
	// ndjsonContentType is the MIME type of Ollama's streaming
	// (newline-delimited JSON) responses.
	ndjsonContentType = "application/x-ndjson"
	// jsonContentType is the MIME type for non-streaming requests/responses.
	jsonContentType = "application/json"
	// defaultEndpoint is the base URL of a locally running Ollama HTTP API.
	defaultEndpoint = "http://localhost:11434/api"
)

Expand All @@ -32,6 +34,7 @@ type StreamCallbackFn func(string)

type Ollama struct {
model string
temperature float64
restClient *restclientgo.RestClient
streamCallbackFn StreamCallbackFn
cache *cache.Cache
Expand Down Expand Up @@ -64,6 +67,11 @@ func (o *Ollama) WithCache(cache *cache.Cache) *Ollama {
return o
}

// WithTemperature sets the sampling temperature forwarded to Ollama in
// the request options and returns the receiver for chaining.
func (o *Ollama) WithTemperature(temperature float64) *Ollama {
	o.temperature = temperature
	return o
}

func (o *Ollama) getCache(ctx context.Context, t *thread.Thread) (*cache.Result, error) {
messages := t.UserQuery()
cacheQuery := strings.Join(messages, "\n")
Expand Down Expand Up @@ -193,6 +201,10 @@ func (o *Ollama) stream(ctx context.Context, t *thread.Thread, chatRequest *requ
return fmt.Errorf("%w: %w", ErrOllamaChat, err)
}

if resp.HTTPStatusCode >= http.StatusBadRequest {
return fmt.Errorf("%w: %s", ErrOllamaChat, resp.RawBody)
}

t.AddMessage(thread.NewAssistantMessage().AddContent(
thread.NewTextContent(assistantMessage),
))
Expand Down

0 comments on commit f093366

Please sign in to comment.