From 895e306d2cfc5f5fc0c87a3fe6c1e72a56433201 Mon Sep 17 00:00:00 2001 From: gagik Date: Tue, 3 Dec 2024 21:48:02 +0100 Subject: [PATCH] adjust the tests --- src/participant/prompts/promptBase.ts | 8 ++++---- src/test/suite/participant/participant.test.ts | 15 +++++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index 440cbd193..58d36c92f 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -101,9 +101,9 @@ export abstract class PromptBase { return undefined; } - protected getUserPrompt( - request: PromptArgsBase['request'] - ): Promise { + protected getUserPrompt({ + request, + }: PromptArgs): Promise { return Promise.resolve({ prompt: request.prompt, hasSampleDocs: false, @@ -121,7 +121,7 @@ export abstract class PromptBase { this.getAssistantPrompt(args) ); - const { prompt, hasSampleDocs } = await this.getUserPrompt(request); + const { prompt, hasSampleDocs } = await this.getUserPrompt(args); // eslint-disable-next-line new-cap const userPrompt = vscode.LanguageModelChatMessage.User(prompt); diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index cf98ba10f..f36b9db4c 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -2038,7 +2038,7 @@ Schema: suite('prompt builders', function () { suite('prompt history', function () { test('gets filtered once history goes over maxInputTokens', async function () { - const expectedMaxMessages = 8; + const expectedMaxMessages = 10; const mockedMessages = Array.from( { length: 20 }, @@ -2065,11 +2065,18 @@ Schema: connectionNames: [], }); - // Should include the limit and the initial generic prompt and the newly sent request - expect(messages.length + 2).equals(expectedMaxMessages); + expect(messages.length).equals(expectedMaxMessages); + + // Should consist of the assistant prompt (1 token), 8 history messages (8 tokens), + // and the new request (1 token) expect( messages.slice(1).map((message) => getMessageContent(message)) - ).deep.equal([...mockedMessages, chatRequestMock.prompt]); + ).deep.equal([ + ...mockedMessages.slice( + mockedMessages.length - (expectedMaxMessages - 2) + ), + chatRequestMock.prompt, + ]); }); });