diff --git a/src/adapters/vertexai/chat.ts b/src/adapters/vertexai/chat.ts index 77bbf5c5..711aed5d 100644 --- a/src/adapters/vertexai/chat.ts +++ b/src/adapters/vertexai/chat.ts @@ -26,7 +26,7 @@ import { import { shallowCopy } from "@/serializer/utils.js"; import type { GetRunContext } from "@/context.js"; import { Emitter } from "@/emitter/emitter.js"; -import { VertexAI } from "@google-cloud/vertexai"; +import { VertexAI, BaseModelParams as Params } from "@google-cloud/vertexai"; import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js"; import { BaseMessage, Role } from "@/llms/primitives/message.js"; import { signalRace } from "@/internals/helpers/promise.js"; @@ -72,7 +72,7 @@ export interface VertexAIChatLLMInput { client?: VertexAI; executionOptions?: ExecutionOptions; cache?: LLMCache; - parameters?: Record; + parameters?: Params; } export class VertexAIChatLLM extends ChatLLM { @@ -82,9 +82,11 @@ export class VertexAIChatLLM extends ChatLLM { }); protected client: VertexAI; + protected parameters?: Params; constructor(protected readonly input: VertexAIChatLLMInput) { super(input.modelId, input.executionOptions, input.cache); + this.parameters = input.parameters; this.client = new VertexAI({ project: input.project, location: input.location }); } @@ -112,7 +114,12 @@ export class VertexAIChatLLM extends ChatLLM { options: GenerateOptions, run: GetRunContext, ): Promise { - const generativeModel = createModel(this.client, this.modelId, options.guided?.json); + const generativeModel = createModel( + this.client, + this.modelId, + options.guided?.json, + this.parameters, + ); const response = await signalRace( () => generativeModel.generateContent({ @@ -132,10 +139,15 @@ export class VertexAIChatLLM extends ChatLLM { options: GenerateOptions | undefined, run: GetRunContext, ): AsyncStream { - const generativeModel = createModel(this.client, this.modelId, options?.guided?.json); + const generativeModel = createModel( + this.client, + this.modelId, + options?.guided?.json, + this.parameters, + ); const chat = generativeModel.startChat(); const response = await chat.sendMessageStream(input.map((msg) => msg.text)); - for await (const chunk of await response.stream) { + for await (const chunk of response.stream) { if (options?.signal?.aborted) { break; } @@ -153,6 +165,7 @@ export class VertexAIChatLLM extends ChatLLM { ...super.createSnapshot(), input: shallowCopy(this.input), client: this.client, + parameters: this.parameters, }; } } diff --git a/src/adapters/vertexai/llm.ts b/src/adapters/vertexai/llm.ts index a2a44ca8..7536fa90 100644 --- a/src/adapters/vertexai/llm.ts +++ b/src/adapters/vertexai/llm.ts @@ -28,7 +28,7 @@ import { import { shallowCopy } from "@/serializer/utils.js"; import type { GetRunContext } from "@/context.js"; import { Emitter } from "@/emitter/emitter.js"; -import { VertexAI } from "@google-cloud/vertexai"; +import { VertexAI, BaseModelParams as Params } from "@google-cloud/vertexai"; import { Role } from "@/llms/primitives/message.js"; import { signalRace } from "@/internals/helpers/promise.js"; import { processContentResponse, getTokenCount, registerVertexAI, createModel } from "./utils.js"; @@ -74,7 +74,7 @@ export interface VertexAILLMInput { client?: VertexAI; executionOptions?: ExecutionOptions; cache?: LLMCache; - parameters?: Record; + parameters?: Params; } export class VertexAILLM extends LLM { @@ -84,9 +84,11 @@ export class VertexAILLM extends LLM { }); protected client: VertexAI; + protected parameters?: Params; constructor(protected readonly input: VertexAILLMInput) { super(input.modelId, input.executionOptions, input.cache); + this.parameters = input.parameters; this.client = input.client ?? new VertexAI({ project: input.project, location: input.location }); } @@ -115,7 +117,12 @@ export class VertexAILLM extends LLM { options: GenerateOptions, run: GetRunContext, ): Promise { - const generativeModel = createModel(this.client, this.modelId, options.guided?.json); + const generativeModel = createModel( + this.client, + this.modelId, + options.guided?.json, + this.parameters, + ); const responses = await signalRace(() => generativeModel.generateContent(input), run.signal); const result: VertexAILLMChunk = { text: processContentResponse(responses.response), @@ -129,9 +136,14 @@ export class VertexAILLM extends LLM { options: GenerateOptions | undefined, run: GetRunContext, ): AsyncStream { - const generativeModel = createModel(this.client, this.modelId, options?.guided?.json); + const generativeModel = createModel( + this.client, + this.modelId, + options?.guided?.json, + this.parameters, + ); const response = await generativeModel.generateContentStream(input); - for await (const chunk of await response.stream) { + for await (const chunk of response.stream) { if (options?.signal?.aborted) { break; } @@ -149,6 +161,7 @@ export class VertexAILLM extends LLM { ...super.createSnapshot(), input: shallowCopy(this.input), client: this.client, + parameters: this.parameters, }; } diff --git a/src/adapters/vertexai/utils.ts b/src/adapters/vertexai/utils.ts index 7f482485..4ebbca5a 100644 --- a/src/adapters/vertexai/utils.ts +++ b/src/adapters/vertexai/utils.ts @@ -16,7 +16,12 @@ import { isString } from "remeda"; import { Serializer } from "@/serializer/serializer.js"; -import { VertexAI, GenerativeModel, ModelParams } from "@google-cloud/vertexai"; +import { + VertexAI, + GenerativeModel, + ModelParams, + BaseModelParams as Params, +} from "@google-cloud/vertexai"; import { getPropStrict } from "@/internals/helpers/object.js"; import { GenerateContentResponse } from "@google-cloud/vertexai"; @@ -50,12 +55,16 @@ export function createModel( client: VertexAI, modelId: string, schema?: string | Record, + params?: Params, ): GenerativeModel { - const modelParams: ModelParams = { model: modelId }; + const modelParams: ModelParams = { model: modelId, ...params }; if (schema) { const schemaJson = isString(schema) ? JSON.parse(schema) : schema; - const generationConfig = { responseSchema: schemaJson, responseMimeType: "application/json" }; - modelParams.generationConfig = generationConfig; + modelParams.generationConfig = { + ...modelParams.generationConfig, + responseSchema: schemaJson, + responseMimeType: "application/json", + }; } return client.getGenerativeModel(modelParams); }