From 40e06d2db65cba7911eb3a05bd5dce4417fb3cb5 Mon Sep 17 00:00:00 2001 From: Haouari haitam Kouider <57036855+haouarihk@users.noreply.github.com> Date: Thu, 14 Dec 2023 16:43:38 +0100 Subject: [PATCH 1/5] adding candidateCount + fixing stopSequences not being set in the request --- .../langchain-google-genai/src/chat_models.ts | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/libs/langchain-google-genai/src/chat_models.ts b/libs/langchain-google-genai/src/chat_models.ts index 201fb38dd106..b9bc8c8e8b9f 100644 --- a/libs/langchain-google-genai/src/chat_models.ts +++ b/libs/langchain-google-genai/src/chat_models.ts @@ -86,6 +86,13 @@ export interface GoogleGenerativeAIChatInput extends BaseChatModelParams { */ stopSequences?: string[]; + /** + * Same as variable N but for google. + * + * Note: stopSequences is only supported for Gemini models + */ + candidateCount?: string[]; + /** * A list of unique `SafetySetting` instances for blocking unsafe content. The API will block * any prompts and responses that fail to meet the thresholds set by these settings. If there @@ -157,6 +164,8 @@ export class ChatGoogleGenerativeAI stopSequences: string[] = []; + candidateCount: number = 1; + safetySettings?: SafetySetting[]; apiKey?: string; @@ -198,6 +207,17 @@ export class ChatGoogleGenerativeAI throw new Error("`topK` must be a positive integer"); } + this.candidateCount = fields?.candidateCount ?? this.candidateCount; + if (this.candidateCount && this.candidateCount <= 0) { + throw new Error("`candidateCount` must be above 1."); + } + + this.stopSequences = fields?.stopSequences ?? this.stopSequences; + if(this.stopSequences && typeof this.stopSequences == "string") + this.stopSequences = [this.stopSequences]; + + + this.apiKey = fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY"); if (!this.apiKey) { throw new Error( @@ -224,7 +244,7 @@ export class ChatGoogleGenerativeAI model: this.modelName, safetySettings: this.safetySettings as SafetySetting[], generationConfig: { - candidateCount: 1, + candidateCount: this.candidateCount, stopSequences: this.stopSequences, maxOutputTokens: this.maxOutputTokens, temperature: this.temperature, From d521bbc24e773859ef4f0a78dd880d89822cab99 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Thu, 14 Dec 2023 09:19:21 -0800 Subject: [PATCH 2/5] Update chat_models.ts --- .../langchain-google-genai/src/chat_models.ts | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/libs/langchain-google-genai/src/chat_models.ts b/libs/langchain-google-genai/src/chat_models.ts index b9bc8c8e8b9f..b79b35238c11 100644 --- a/libs/langchain-google-genai/src/chat_models.ts +++ b/libs/langchain-google-genai/src/chat_models.ts @@ -86,13 +86,6 @@ export interface GoogleGenerativeAIChatInput extends BaseChatModelParams { */ stopSequences?: string[]; - /** - * Same as variable N but for google. - * - * Note: stopSequences is only supported for Gemini models - */ - candidateCount?: string[]; - /** * A list of unique `SafetySetting` instances for blocking unsafe content. The API will block * any prompts and responses that fail to meet the thresholds set by these settings. If there @@ -207,16 +200,10 @@ export class ChatGoogleGenerativeAI throw new Error("`topK` must be a positive integer"); } - this.candidateCount = fields?.candidateCount ?? this.candidateCount; - if (this.candidateCount && this.candidateCount <= 0) { - throw new Error("`candidateCount` must be above 1."); - } - this.stopSequences = fields?.stopSequences ?? this.stopSequences; - if(this.stopSequences && typeof this.stopSequences == "string") + if (this.stopSequences && typeof this.stopSequences == "string") { this.stopSequences = [this.stopSequences]; - - + } this.apiKey = fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY"); if (!this.apiKey) { @@ -244,7 +231,7 @@ export class ChatGoogleGenerativeAI model: this.modelName, safetySettings: this.safetySettings as SafetySetting[], generationConfig: { - candidateCount: this.candidateCount, + candidateCount: 1, stopSequences: this.stopSequences, maxOutputTokens: this.maxOutputTokens, temperature: this.temperature, From ee545d6f32a63677f5fc39be6441f412b7253764 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Thu, 14 Dec 2023 09:37:47 -0800 Subject: [PATCH 3/5] Small Google fixes --- libs/langchain-google-genai/src/chat_models.ts | 5 ----- .../src/tests/chat_models.int.test.ts | 12 ++++++++++++ libs/langchain-google-genai/src/utils.ts | 4 +++- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/libs/langchain-google-genai/src/chat_models.ts b/libs/langchain-google-genai/src/chat_models.ts index b79b35238c11..0dd0e8570089 100644 --- a/libs/langchain-google-genai/src/chat_models.ts +++ b/libs/langchain-google-genai/src/chat_models.ts @@ -157,8 +157,6 @@ export class ChatGoogleGenerativeAI stopSequences: string[] = []; - candidateCount: number = 1; - safetySettings?: SafetySetting[]; apiKey?: string; @@ -201,9 +199,6 @@ export class ChatGoogleGenerativeAI } this.stopSequences = fields?.stopSequences ?? this.stopSequences; - if (this.stopSequences && typeof this.stopSequences == "string") { - this.stopSequences = [this.stopSequences]; - } this.apiKey = fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY"); if (!this.apiKey) { diff --git a/libs/langchain-google-genai/src/tests/chat_models.int.test.ts b/libs/langchain-google-genai/src/tests/chat_models.int.test.ts index 656973425a92..dad7a8397218 100644 --- a/libs/langchain-google-genai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-genai/src/tests/chat_models.int.test.ts @@ -21,6 +21,18 @@ test("Test Google AI generation", async () => { expect(res).toBeTruthy(); }); +test("Test Google AI generation with a stop sequence", async () => { + const model = new ChatGoogleGenerativeAI({ + stopSequences: ["two", "2"], + }); + const res = await model.invoke([ + ["human", `What are the first three positive whole numbers?`], + ]); + console.log(JSON.stringify(res, null, 2)); + expect(res).toBeTruthy(); + expect(res.additional_kwargs.finishReason).toBe("STOP"); +}); + test("Test Google AI generation with a system message", async () => { const model = new ChatGoogleGenerativeAI({}); const res = await model.generate([ diff --git a/libs/langchain-google-genai/src/utils.ts b/libs/langchain-google-genai/src/utils.ts index c16cd392538e..d8d7c227cbef 100644 --- a/libs/langchain-google-genai/src/utils.ts +++ b/libs/langchain-google-genai/src/utils.ts @@ -177,7 +177,7 @@ export function mapGenerateContentResultToChatResult( message: new AIMessage({ content: text, name: content === null ? undefined : content.role, - additional_kwargs: {}, + additional_kwargs: generationInfo, }), generationInfo, }; @@ -202,6 +202,8 @@ export function convertResponseContentToChatGenerationChunk( message: new AIMessageChunk({ content: text, name: content === null ? undefined : content.role, + // Each chunk can have unique "generationInfo", and merging strategy is unclear, + // so leave blank for now. additional_kwargs: {}, }), generationInfo, From 1046438694686e71ff5af7d6f2915bf0ad8e545d Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Thu, 14 Dec 2023 09:46:22 -0800 Subject: [PATCH 4/5] Fix test --- libs/langchain-google-genai/src/tests/chat_models.int.test.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/langchain-google-genai/src/tests/chat_models.int.test.ts b/libs/langchain-google-genai/src/tests/chat_models.int.test.ts index dad7a8397218..112d3a03b270 100644 --- a/libs/langchain-google-genai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-genai/src/tests/chat_models.int.test.ts @@ -31,6 +31,8 @@ test("Test Google AI generation with a stop sequence", async () => { console.log(JSON.stringify(res, null, 2)); expect(res).toBeTruthy(); expect(res.additional_kwargs.finishReason).toBe("STOP"); + expect(res.content).not.toContain("2") + expect(res.content).not.toContain("two") }); test("Test Google AI generation with a system message", async () => { From 13c7d8a0c3eb15386f5d814501fd794ef04a431b Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Thu, 14 Dec 2023 09:47:59 -0800 Subject: [PATCH 5/5] Format --- libs/langchain-google-genai/src/tests/chat_models.int.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain-google-genai/src/tests/chat_models.int.test.ts b/libs/langchain-google-genai/src/tests/chat_models.int.test.ts index 112d3a03b270..46dffb92e05f 100644 --- a/libs/langchain-google-genai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-genai/src/tests/chat_models.int.test.ts @@ -31,8 +31,8 @@ test("Test Google AI generation with a stop sequence", async () => { console.log(JSON.stringify(res, null, 2)); expect(res).toBeTruthy(); expect(res.additional_kwargs.finishReason).toBe("STOP"); - expect(res.content).not.toContain("2") - expect(res.content).not.toContain("two") + expect(res.content).not.toContain("2"); + expect(res.content).not.toContain("two"); }); test("Test Google AI generation with a system message", async () => {