fix(adapters): gcp vertexai llm adapter parameters (#194)
Signed-off-by: Akihiko Kuroda <[email protected]>
akihikokuroda authored Nov 26, 2024
Parent: a06ef6f · Commit: 54819bf
Showing 3 changed files with 49 additions and 14 deletions.
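The change threads a typed `parameters` bag (the SDK's `BaseModelParams`) from both VertexAI adapter constructors through to `createModel`, replacing the untyped `Record<string, any>` field that, per the old code, was accepted on the input but never forwarded to the model. A minimal usage sketch, assuming the published package name matches the repository (bee-agent-framework) and using placeholder project, location, model id, and generation values:

import { VertexAIChatLLM } from "bee-agent-framework/adapters/vertexai/chat";

// Placeholder project, location, and model id; substitute your own.
const llm = new VertexAIChatLLM({
  modelId: "gemini-1.5-flash-001",
  project: "my-gcp-project",
  location: "us-central1",
  // After this commit, `parameters` is typed as BaseModelParams and is
  // actually passed through to getGenerativeModel() instead of being ignored.
  parameters: {
    generationConfig: { temperature: 0.2, maxOutputTokens: 512 },
  },
});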
23 changes: 18 additions & 5 deletions src/adapters/vertexai/chat.ts

@@ -26,7 +26,7 @@ import {
 import { shallowCopy } from "@/serializer/utils.js";
 import type { GetRunContext } from "@/context.js";
 import { Emitter } from "@/emitter/emitter.js";
-import { VertexAI } from "@google-cloud/vertexai";
+import { VertexAI, BaseModelParams as Params } from "@google-cloud/vertexai";
 import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
 import { BaseMessage, Role } from "@/llms/primitives/message.js";
 import { signalRace } from "@/internals/helpers/promise.js";

@@ -72,7 +72,7 @@ export interface VertexAIChatLLMInput {
   client?: VertexAI;
   executionOptions?: ExecutionOptions;
   cache?: LLMCache<VertexAIChatLLMOutput>;
-  parameters?: Record<string, any>;
+  parameters?: Params;
 }
 
 export class VertexAIChatLLM extends ChatLLM<VertexAIChatLLMOutput> {

@@ -82,9 +82,11 @@ export class VertexAIChatLLM extends ChatLLM<VertexAIChatLLMOutput> {
   });
 
   protected client: VertexAI;
+  protected parameters?: Params;
 
   constructor(protected readonly input: VertexAIChatLLMInput) {
     super(input.modelId, input.executionOptions, input.cache);
+    this.parameters = input.parameters;
     this.client = new VertexAI({ project: input.project, location: input.location });
   }
 
@@ -112,7 +114,12 @@ export class VertexAIChatLLM extends ChatLLM<VertexAIChatLLMOutput> {
     options: GenerateOptions,
     run: GetRunContext<this>,
   ): Promise<VertexAIChatLLMOutput> {
-    const generativeModel = createModel(this.client, this.modelId, options.guided?.json);
+    const generativeModel = createModel(
+      this.client,
+      this.modelId,
+      options.guided?.json,
+      this.parameters,
+    );
     const response = await signalRace(
       () =>
         generativeModel.generateContent({

@@ -132,10 +139,15 @@ export class VertexAIChatLLM extends ChatLLM<VertexAIChatLLMOutput> {
     options: GenerateOptions | undefined,
     run: GetRunContext<this>,
   ): AsyncStream<VertexAIChatLLMOutput, void> {
-    const generativeModel = createModel(this.client, this.modelId, options?.guided?.json);
+    const generativeModel = createModel(
+      this.client,
+      this.modelId,
+      options?.guided?.json,
+      this.parameters,
+    );
     const chat = generativeModel.startChat();
     const response = await chat.sendMessageStream(input.map((msg) => msg.text));
-    for await (const chunk of await response.stream) {
+    for await (const chunk of response.stream) {
       if (options?.signal?.aborted) {
         break;
       }

@@ -153,6 +165,7 @@ export class VertexAIChatLLM extends ChatLLM<VertexAIChatLLMOutput> {
       ...super.createSnapshot(),
       input: shallowCopy(this.input),
       client: this.client,
+      parameters: this.parameters,
     };
   }
 }
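A side fix in both adapters: the result of `sendMessageStream()` / `generateContentStream()` already exposes `stream` as an async iterable, so the extra `await` in `for await (const chunk of await response.stream)` was redundant and is dropped. Roughly the shape involved, as a simplified sketch rather than the SDK's exact declarations:

// Simplified sketch of the SDK's streaming result shape.
interface StreamResultSketch<T> {
  stream: AsyncGenerator<T>; // iterate directly; no extra `await` needed
  response: Promise<T>;      // aggregated final response, resolved when the stream ends
}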
23 changes: 18 additions & 5 deletions src/adapters/vertexai/llm.ts

@@ -28,7 +28,7 @@ import {
 import { shallowCopy } from "@/serializer/utils.js";
 import type { GetRunContext } from "@/context.js";
 import { Emitter } from "@/emitter/emitter.js";
-import { VertexAI } from "@google-cloud/vertexai";
+import { VertexAI, BaseModelParams as Params } from "@google-cloud/vertexai";
 import { Role } from "@/llms/primitives/message.js";
 import { signalRace } from "@/internals/helpers/promise.js";
 import { processContentResponse, getTokenCount, registerVertexAI, createModel } from "./utils.js";

@@ -74,7 +74,7 @@ export interface VertexAILLMInput {
   client?: VertexAI;
   executionOptions?: ExecutionOptions;
   cache?: LLMCache<VertexAILLMOutput>;
-  parameters?: Record<string, any>;
+  parameters?: Params;
 }
 
 export class VertexAILLM extends LLM<VertexAILLMOutput, GenerateOptions> {

@@ -84,9 +84,11 @@ export class VertexAILLM extends LLM<VertexAILLMOutput, GenerateOptions> {
   });
 
   protected client: VertexAI;
+  protected parameters?: Params;
 
   constructor(protected readonly input: VertexAILLMInput) {
     super(input.modelId, input.executionOptions, input.cache);
+    this.parameters = input.parameters;
     this.client =
       input.client ?? new VertexAI({ project: input.project, location: input.location });
   }

@@ -115,7 +117,12 @@ export class VertexAILLM extends LLM<VertexAILLMOutput, GenerateOptions> {
     options: GenerateOptions,
     run: GetRunContext<this>,
   ): Promise<VertexAILLMOutput> {
-    const generativeModel = createModel(this.client, this.modelId, options.guided?.json);
+    const generativeModel = createModel(
+      this.client,
+      this.modelId,
+      options.guided?.json,
+      this.parameters,
+    );
     const responses = await signalRace(() => generativeModel.generateContent(input), run.signal);
     const result: VertexAILLMChunk = {
       text: processContentResponse(responses.response),

@@ -129,9 +136,14 @@ export class VertexAILLM extends LLM<VertexAILLMOutput, GenerateOptions> {
     options: GenerateOptions | undefined,
     run: GetRunContext<this>,
   ): AsyncStream<VertexAILLMOutput, void> {
-    const generativeModel = createModel(this.client, this.modelId, options?.guided?.json);
+    const generativeModel = createModel(
+      this.client,
+      this.modelId,
+      options?.guided?.json,
+      this.parameters,
+    );
     const response = await generativeModel.generateContentStream(input);
-    for await (const chunk of await response.stream) {
+    for await (const chunk of response.stream) {
       if (options?.signal?.aborted) {
         break;
       }

@@ -149,6 +161,7 @@ export class VertexAILLM extends LLM<VertexAILLMOutput, GenerateOptions> {
       ...super.createSnapshot(),
       input: shallowCopy(this.input),
       client: this.client,
+      parameters: this.parameters,
     };
   }
 
17 changes: 13 additions & 4 deletions src/adapters/vertexai/utils.ts

@@ -16,7 +16,12 @@
 
 import { isString } from "remeda";
 import { Serializer } from "@/serializer/serializer.js";
-import { VertexAI, GenerativeModel, ModelParams } from "@google-cloud/vertexai";
+import {
+  VertexAI,
+  GenerativeModel,
+  ModelParams,
+  BaseModelParams as Params,
+} from "@google-cloud/vertexai";
 import { getPropStrict } from "@/internals/helpers/object.js";
 import { GenerateContentResponse } from "@google-cloud/vertexai";
 
@@ -50,12 +55,16 @@ export function createModel(
   client: VertexAI,
   modelId: string,
   schema?: string | Record<string, any>,
+  params?: Params,
 ): GenerativeModel {
-  const modelParams: ModelParams = { model: modelId };
+  const modelParams: ModelParams = { model: modelId, ...params };
   if (schema) {
     const schemaJson = isString(schema) ? JSON.parse(schema) : schema;
-    const generationConfig = { responseSchema: schemaJson, responseMimeType: "application/json" };
-    modelParams.generationConfig = generationConfig;
+    modelParams.generationConfig = {
+      ...modelParams.generationConfig,
+      responseSchema: schemaJson,
+      responseMimeType: "application/json",
+    };
   }
   return client.getGenerativeModel(modelParams);
 }
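Note the merge order in `createModel`: caller-supplied `params` are spread into `modelParams` first, and the guided-JSON settings are then spread on top of any `generationConfig` the caller provided, so user options like `temperature` survive while `responseSchema` and `responseMimeType` win on conflict. A small self-contained sketch of that behavior, with hypothetical values standing in for `BaseModelParams`:

// Hypothetical caller parameters standing in for BaseModelParams.
const params = {
  generationConfig: { temperature: 0.2, responseMimeType: "text/plain" } as Record<string, unknown>,
};

// Step 1: spread caller params into the model params.
const modelParams: { model: string; generationConfig?: Record<string, unknown> } = {
  model: "gemini-1.5-pro",
  ...params,
};

// Step 2: layer guided-JSON settings over the caller's generationConfig.
modelParams.generationConfig = {
  ...modelParams.generationConfig,
  responseSchema: { type: "object" },
  responseMimeType: "application/json",
};

console.log(modelParams.generationConfig);
// { temperature: 0.2, responseMimeType: "application/json", responseSchema: { type: "object" } }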
