feat(adapters): add ollama structured output and version retrieval (#237)
Signed-off-by: Tomas Dvorak <[email protected]>
Tomas2D authored Dec 9, 2024
1 parent 55847e5 commit 821364e
Showing 7 changed files with 191 additions and 44 deletions.
4 changes: 2 additions & 2 deletions package.json
@@ -205,7 +205,7 @@
"@zilliz/milvus2-sdk-node": "^2.4.9",
"google-auth-library": "*",
"groq-sdk": "^0.7.0",
"ollama": "^0.5.8",
"ollama": "^0.5.11",
"openai": "^4.67.3",
"openai-chat-tokens": "^0.2.8",
"sequelize": "^6.37.3"
@@ -305,7 +305,7 @@
"langchain": "~0.3.6",
"linkinator": "^6.1.2",
"lint-staged": "^15.2.10",
"ollama": "^0.5.10",
"ollama": "^0.5.11",
"openai": "^4.76.0",
"openai-chat-tokens": "^0.2.8",
"openapi-fetch": "^0.13.3",
50 changes: 32 additions & 18 deletions src/adapters/ollama/chat.ts
@@ -27,13 +27,18 @@ import { shallowCopy } from "@/serializer/utils.js";
import { ChatLLM, ChatLLMGenerateEvents, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage } from "@/llms/primitives/message.js";
import { Emitter } from "@/emitter/emitter.js";
import { ChatResponse, Ollama as Client, Options as Parameters } from "ollama";
import { ChatRequest, ChatResponse, Config, Ollama as Client, Options as Parameters } from "ollama";
import { signalRace } from "@/internals/helpers/promise.js";
import { GetRunContext } from "@/context.js";
import { Cache } from "@/cache/decoratorCache.js";
import { customMerge } from "@/internals/helpers/object.js";
import { customMerge, getPropStrict } from "@/internals/helpers/object.js";
import { safeSum } from "@/internals/helpers/number.js";
import { extractModelMeta, registerClient } from "@/adapters/ollama/shared.js";
import {
extractModelMeta,
registerClient,
retrieveFormat,
retrieveVersion,
} from "@/adapters/ollama/shared.js";
import { getEnv } from "@/internals/env.js";

export class OllamaChatLLMOutput extends ChatLLMOutput {
@@ -161,22 +166,22 @@ export class OllamaChatLLM extends ChatLLM<OllamaChatLLMOutput> {
};
}

@Cache()
async version() {
const config = getPropStrict(this.client, "config") as Config;
return retrieveVersion(config.host, config.fetch);
}

protected async _generate(
input: BaseMessage[],
options: GenerateOptions,
run: GetRunContext<typeof this>,
): Promise<OllamaChatLLMOutput> {
const response = await signalRace(
() =>
async () =>
this.client.chat({
model: this.modelId,
...(await this.prepareParameters(input, options)),
stream: false,
messages: input.map((msg) => ({
role: msg.role,
content: msg.text,
})),
options: this.parameters,
format: options.guided?.json ? "json" : undefined,
}),
run.signal,
() => this.client.abort(),
@@ -191,14 +196,8 @@ export class OllamaChatLLM extends ChatLLM<OllamaChatLLMOutput> {
run: GetRunContext<typeof this>,
): AsyncStream<OllamaChatLLMOutput> {
for await (const chunk of await this.client.chat({
model: this.modelId,
...(await this.prepareParameters(input, options)),
stream: true,
messages: input.map((msg) => ({
role: msg.role,
content: msg.text,
})),
options: this.parameters,
format: options.guided?.json ? "json" : undefined,
})) {
if (run.signal.aborted) {
break;
@@ -208,6 +207,21 @@ export class OllamaChatLLM extends ChatLLM<OllamaChatLLMOutput> {
run.signal.throwIfAborted();
}

protected async prepareParameters(
input: BaseMessage[],
overrides?: GenerateOptions,
): Promise<ChatRequest> {
return {
model: this.modelId,
messages: input.map((msg) => ({
role: msg.role,
content: msg.text,
})),
options: this.parameters,
format: retrieveFormat(await this.version(), overrides?.guided),
};
}

createSnapshot() {
return {
...super.createSnapshot(),
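For orientation, here is a minimal sketch of how the refactored chat adapter's structured output is driven from calling code. It mirrors the e2e test added further below; the host, model id, prompt, and schema are illustrative placeholders, not values from this commit.

```ts
import { OllamaChatLLM } from "@/adapters/ollama/chat.js";
import { BaseMessage } from "@/llms/primitives/message.js";
import { Ollama } from "ollama";

// Illustrative setup: host and model id are placeholders.
const llm = new OllamaChatLLM({
  modelId: "llama3.1",
  client: new Ollama({ host: "http://127.0.0.1:11434" }),
});

// With guided.json set, prepareParameters() delegates to retrieveFormat(), which
// passes the schema through on Ollama >= 0.5.0 and falls back to plain "json" mode otherwise.
const response = await llm.generate(
  [BaseMessage.of({ role: "user", text: "Generate a valid JSON object." })],
  {
    guided: {
      json: {
        type: "object",
        properties: { answer: { type: "string" } },
        required: ["answer"],
      },
    },
  },
);
console.info(response.getTextContent());
```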
50 changes: 36 additions & 14 deletions src/adapters/ollama/llm.ts
@@ -27,14 +27,25 @@ import {
LLMOutputError,
StreamGenerateOptions,
} from "@/llms/base.js";
import { GenerateResponse, Ollama as Client, Options as Parameters } from "ollama";
import {
Config,
GenerateRequest,
GenerateResponse,
Ollama as Client,
Options as Parameters,
} from "ollama";
import { GetRunContext } from "@/context.js";
import { Cache } from "@/cache/decoratorCache.js";
import { safeSum } from "@/internals/helpers/number.js";
import { shallowCopy } from "@/serializer/utils.js";
import { signalRace } from "@/internals/helpers/promise.js";
import { customMerge } from "@/internals/helpers/object.js";
import { extractModelMeta, registerClient } from "@/adapters/ollama/shared.js";
import { customMerge, getPropStrict } from "@/internals/helpers/object.js";
import {
extractModelMeta,
registerClient,
retrieveFormat,
retrieveVersion,
} from "@/adapters/ollama/shared.js";
import { getEnv } from "@/internals/env.js";

interface Input {
@@ -131,14 +142,10 @@ export class OllamaLLM extends LLM<OllamaLLMOutput> {
run: GetRunContext<typeof this>,
): Promise<OllamaLLMOutput> {
const response = await signalRace(
() =>
async () =>
this.client.generate({
model: this.modelId,
...(await this.prepareParameters(input, options)),
stream: false,
raw: true,
prompt: input,
options: this.parameters,
format: options.guided?.json ? "json" : undefined,
}),
run.signal,
() => this.client.abort(),
@@ -153,12 +160,8 @@ export class OllamaLLM extends LLM<OllamaLLMOutput> {
run: GetRunContext<typeof this>,
): AsyncStream<OllamaLLMOutput, void> {
for await (const chunk of await this.client.generate({
model: this.modelId,
...(await this.prepareParameters(input, options)),
stream: true,
raw: true,
prompt: input,
options: this.parameters,
format: options.guided?.json ? "json" : undefined,
})) {
if (run.signal.aborted) {
break;
@@ -168,6 +171,12 @@ export class OllamaLLM extends LLM<OllamaLLMOutput> {
run.signal.throwIfAborted();
}

@Cache()
async version() {
const config = getPropStrict(this.client, "config") as Config;
return retrieveVersion(config.host, config.fetch);
}

async meta(): Promise<LLMMeta> {
const model = await this.client.show({
model: this.modelId,
@@ -182,6 +191,19 @@ export class OllamaLLM extends LLM<OllamaLLMOutput> {
};
}

protected async prepareParameters(
input: LLMInput,
overrides?: GenerateOptions,
): Promise<GenerateRequest> {
return {
model: this.modelId,
prompt: input,
raw: true,
options: this.parameters,
format: retrieveFormat(await this.version(), overrides?.guided),
};
}

createSnapshot() {
return {
...super.createSnapshot(),
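The completion-style adapter gains the same cached version() lookup and prepareParameters() path. A hedged sketch of how it might be exercised follows; the constructor shape and the getTextContent() accessor are assumed to mirror the chat adapter, and the host, model id, prompt, and schema are placeholders.

```ts
import { OllamaLLM } from "@/adapters/ollama/llm.js";
import { Ollama } from "ollama";

// Assumed constructor shape (mirrors OllamaChatLLM); values are placeholders.
const llm = new OllamaLLM({
  modelId: "llama3.1",
  client: new Ollama({ host: "http://127.0.0.1:11434" }),
});

// version() is decorated with @Cache(), so /api/version is only fetched once per instance.
const version = await llm.version();
console.info(`Ollama server reports version ${version}`);

// Guided JSON flows through prepareParameters() -> retrieveFormat(), exactly as in the chat adapter.
const output = await llm.generate("Return a JSON object with a single boolean key `ok`.", {
  guided: { json: { type: "object", properties: { ok: { type: "boolean" } } } },
});
console.info(output.getTextContent());
```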
32 changes: 31 additions & 1 deletion src/adapters/ollama/shared.ts
@@ -17,7 +17,9 @@
import { Serializer } from "@/serializer/serializer.js";
import { Config, Ollama as Client, ShowResponse } from "ollama";
import { getPropStrict } from "@/internals/helpers/object.js";
import { LLMMeta } from "@/llms/base.js";
import { GuidedOptions, LLMMeta } from "@/llms/base.js";
import { Comparator, compareVersion } from "@/internals/helpers/string.js";
import { isString } from "remeda";

export function registerClient() {
Serializer.register(Client, {
@@ -34,6 +36,34 @@ export function registerClient() {
});
}

export async function retrieveVersion(
baseUrl: string,
client: typeof fetch = fetch,
): Promise<string> {
const url = new URL("/api/version", baseUrl);
const response = await client(url);
if (!response.ok) {
throw new Error(`Could not retrieve Ollama API version.`);
}
const data = await response.json();
return data.version;
}

export function retrieveFormat(
version: string | number,
guided?: GuidedOptions,
): string | object | undefined {
if (!guided?.json) {
return undefined;
}

if (compareVersion(String(version), Comparator.GTE, "0.5.0")) {
return isString(guided.json) ? JSON.parse(guided.json) : guided.json;
} else {
return "json";
}
}

export function extractModelMeta(response: ShowResponse): LLMMeta {
const tokenLimit = Object.entries(response.model_info)
.find(([k]) => k.includes("context_length") || k.includes("max_sequence_length"))
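To make the new helper's version gate concrete, here is a small illustration of what retrieveFormat resolves to for a few inputs; the values follow directly from the implementation above, and the schema literal is only an example.

```ts
import { retrieveFormat } from "@/adapters/ollama/shared.js";

const schema = { type: "object", properties: { a: { const: "a" } }, required: ["a"] };

// No guided options: no format constraint is sent to the server.
retrieveFormat("0.5.1"); // undefined

// Server >= 0.5.0: the JSON schema is passed through as an object
// (string input is parsed first), enabling schema-constrained decoding.
retrieveFormat("0.5.1", { json: schema }); // { type: "object", ... }
retrieveFormat("0.5.1", { json: JSON.stringify(schema) }); // { type: "object", ... }

// Older servers only understand the generic "json" mode.
retrieveFormat("0.4.6", { json: schema }); // "json"
```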
50 changes: 47 additions & 3 deletions tests/e2e/adapters/ollama/chat.test.ts
@@ -17,16 +17,19 @@
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { OllamaChatLLM } from "@/adapters/ollama/chat.js";
import { Ollama } from "ollama";
import { toJsonSchema } from "@/internals/helpers/schema.js";
import { z } from "zod";
import { Comparator, compareVersion } from "@/internals/helpers/string.js";

const host = process.env.OLLAMA_HOST;

describe.runIf(Boolean(host))("Ollama Chat LLM", () => {
const createChatLLM = () => {
const createChatLLM = (maxTokens?: number) => {
return new OllamaChatLLM({
modelId: "llama3.1",
parameters: {
temperature: 0,
num_predict: 5,
num_predict: maxTokens,
},
client: new Ollama({
host,
@@ -41,7 +44,7 @@ describe.runIf(Boolean(host))("Ollama Chat LLM", () => {
text: `You are a helpful and respectful and honest assistant. Your name is Bee.`,
}),
];
const llm = createChatLLM();
const llm = createChatLLM(5);
const response = await llm.generate([
...conversation,
BaseMessage.of({
@@ -51,4 +54,45 @@ describe.runIf(Boolean(host))("Ollama Chat LLM", () => {
]);
expect(response.getTextContent()).includes("Bee");
});

it("Leverages structured output", async () => {
const llm = createChatLLM();
const version = await llm.version();

if (compareVersion(version, Comparator.LT, "0.5.0")) {
// eslint-disable-next-line no-console
console.warn(`Structured output is not available in the current version (${version})`);
return;
}

const response = await llm.generate(
[
BaseMessage.of({
role: "user",
text: "Generate a valid JSON object.",
}),
],
{
stream: false,
guided: {
json: toJsonSchema(
z
.object({
a: z.literal("a"),
b: z.literal("b"),
c: z.literal("c"),
})
.strict(),
),
},
},
);
expect(response.getTextContent()).toMatchInlineSnapshot(`"{"a": "a", "b": "b", "c": "c"}"`);
});

it("Retrieves version", async () => {
const llm = createChatLLM();
const version = await llm.version();
expect(version).toBeDefined();
});
});
37 changes: 37 additions & 0 deletions tests/e2e/adapters/ollama/test.ts
@@ -0,0 +1,37 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { OllamaChatLLM } from "@/adapters/ollama/chat.js";
import { Ollama } from "ollama";

const host = process.env.OLLAMA_HOST;

describe.runIf(Boolean(host))("Ollama LLM", () => {
const createLLM = () => {
return new OllamaChatLLM({
modelId: "llama3.1",
client: new Ollama({
host,
}),
});
};

it("Retrieves version", async () => {
const llm = createLLM();
const version = await llm.version();
expect(version).toBeDefined();
});
});
12 changes: 6 additions & 6 deletions yarn.lock
@@ -4845,7 +4845,7 @@ __metadata:
mathjs: "npm:^14.0.0"
mustache: "npm:^4.2.0"
object-hash: "npm:^3.0.0"
ollama: "npm:^0.5.10"
ollama: "npm:^0.5.11"
openai: "npm:^4.76.0"
openai-chat-tokens: "npm:^0.2.8"
openapi-fetch: "npm:^0.13.3"
@@ -4893,7 +4893,7 @@ __metadata:
"@zilliz/milvus2-sdk-node": ^2.4.9
google-auth-library: "*"
groq-sdk: ^0.7.0
ollama: ^0.5.8
ollama: ^0.5.11
openai: ^4.67.3
openai-chat-tokens: ^0.2.8
sequelize: ^6.37.3
@@ -10911,12 +10911,12 @@ __metadata:
languageName: node
linkType: hard

"ollama@npm:^0.5.10":
version: 0.5.10
resolution: "ollama@npm:0.5.10"
"ollama@npm:^0.5.11":
version: 0.5.11
resolution: "ollama@npm:0.5.11"
dependencies:
whatwg-fetch: "npm:^3.6.20"
checksum: 10c0/424b92ef640fb562c590bd8bf394c0cdab3eb93a1b70a58bcb98090957375e05130f41d94578124c0fc79b68448db351e3439e4dc3582ccdb7ae13343951c527
checksum: 10c0/9f8bb6715144fac2d423121f29bf7697e3c2132c6696574e2f2de63de8dfa95ac3ed435f3abf35cece6ef07c309065cbf722cead1bee1eda3541b095745f64bf
languageName: node
linkType: hard

