Skip to content

Commit

Permalink
🐛 fix(ollama): suppport vision for LLaVA models (lobehub#1791)
Browse files Browse the repository at this point in the history
Co-authored-by: elonwu <[email protected]>
  • Loading branch information
wuyuan1992 and elonwu authored Mar 29, 2024
1 parent 3fbafca commit e2d3de6
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/config/modelProviders/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,23 +135,23 @@ const Ollama: ModelProviderCard = {
hidden: true,
id: 'llava',
tokens: 4000,
vision: false,
vision: true,
},
{
displayName: 'LLaVA 13B',
functionCall: false,
hidden: true,
id: 'llava:13b',
tokens: 4000,
vision: false,
vision: true,
},
{
displayName: 'LLaVA 34B',
functionCall: false,
hidden: true,
id: 'llava:34b',
tokens: 4000,
vision: false,
vision: true,
},
],
id: 'ollama',
Expand Down
47 changes: 46 additions & 1 deletion src/libs/agent-runtime/ollama/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { ChatStreamCallbacks } from '@/libs/agent-runtime';
import { ChatStreamCallbacks, OpenAIChatMessage } from '@/libs/agent-runtime';

import * as debugStreamModule from '../utils/debugStream';
import { LobeOllamaAI } from './index';
Expand Down Expand Up @@ -317,4 +317,49 @@ describe('LobeOllamaAI', () => {
});
});
});

describe('private method', () => {
describe('convertContentToOllamaMessage', () => {
it('should format message array content of UserMessageContentPart to match ollama api', () => {
const message: OpenAIChatMessage = {
role: 'user',
content: [
{
type: 'text',
text: 'Hello',
},
{
type: 'image_url',
image_url: {
detail: 'auto',
url: '...',
},
},
],
};

const ollamaMessage = instance['convertContentToOllamaMessage'](message);

expect(ollamaMessage).toEqual({
role: 'user',
content: 'Hello',
images: ['iVBO...'],
});
});

it('should not affect string type message content', () => {
const message: OpenAIChatMessage = {
role: 'user',
content: 'Hello',
};

const ollamaMessage = instance['convertContentToOllamaMessage'](message);

expect(ollamaMessage).toEqual({
role: 'user',
content: 'Hello',
});
});
});
});
});
40 changes: 40 additions & 0 deletions src/libs/agent-runtime/ollama/index.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import { OpenAIStream, StreamingTextResponse } from 'ai';
import OpenAI, { ClientOptions } from 'openai';

import { OllamaChatMessage, OpenAIChatMessage } from '@/libs/agent-runtime';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { desensitizeUrl } from '../utils/desensitizeUrl';
import { handleOpenAIError } from '../utils/handleOpenAIError';
import { parseDataUri } from '../utils/uriParser';

const DEFAULT_BASE_URL = 'http://127.0.0.1:11434/v1';

Expand All @@ -25,6 +28,8 @@ export class LobeOllamaAI implements LobeRuntimeAI {

async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
try {
payload.messages = this.buildOllamaMessages(payload.messages);

const response = await this.client.chat.completions.create(
payload as unknown as OpenAI.ChatCompletionCreateParamsStreaming,
);
Expand Down Expand Up @@ -73,6 +78,41 @@ export class LobeOllamaAI implements LobeRuntimeAI {
});
}
}

private buildOllamaMessages(messages: OpenAIChatMessage[]) {
return messages.map((message) => this.convertContentToOllamaMessage(message));
}

private convertContentToOllamaMessage = (message: OpenAIChatMessage) => {
if (typeof message.content === 'string') {
return message;
}

const ollamaMessage: OllamaChatMessage = {
content: '',
role: message.role,
};

for (const content of message.content) {
switch (content.type) {
case 'text': {
// keep latest text input
ollamaMessage.content = content.text;
break;
}
case 'image_url': {
const { base64 } = parseDataUri(content.image_url.url);
if (base64) {
ollamaMessage.images ??= [];
ollamaMessage.images.push(base64);
}
break;
}
}
}

return ollamaMessage;
};
}

export default LobeOllamaAI;
7 changes: 7 additions & 0 deletions src/libs/agent-runtime/types/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,10 @@ export interface ChatCompletionTool {
}

export type ChatStreamCallbacks = OpenAIStreamCallbacks;

export interface OllamaChatMessage extends OpenAIChatMessage {
/**
* @description images for ollama vision models (https://ollama.com/blog/vision-models)
*/
images?: string[];
}

0 comments on commit e2d3de6

Please sign in to comment.