feat(participant): filter message history when it goes over maxInputTokens VSCODE-653 (#894)
gagik authored Dec 9, 2024
1 parent 7879cf9 commit 07ebbd0
Showing 5 changed files with 214 additions and 73 deletions.
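The heart of the change: before sending chat history to the language model, the participant now walks the history from newest to oldest, counts tokens per message, and stops once the model's remaining input-token budget is spent. A minimal standalone sketch of that policy (the countTokens/maxInputTokens shape mirrors VS Code's LanguageModelChat API; the helper and its types here are hypothetical, not the extension's code):

// Trimmed-down illustration of the filtering policy in this commit.
interface ChatModelLike {
  maxInputTokens: number;
  countTokens(text: string): Promise<number>;
}

async function filterHistoryToBudget(
  history: string[], // messages, oldest to newest
  model: ChatModelLike,
  tokenLimit: number // budget left after the assistant and user prompts
): Promise<string[]> {
  const kept: string[] = [];
  let usedTokens = 0;

  // Newest messages are the most relevant, so walk the history in
  // reverse and stop as soon as the budget would be exceeded.
  for (let i = history.length - 1; i >= 0; i--) {
    usedTokens += await model.countTokens(history[i]);
    if (usedTokens > tokenLimit) {
      break;
    }
    kept.push(history[i]);
  }

  // Restore chronological order before handing the messages to the model.
  return kept.reverse();
}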
2 changes: 1 addition & 1 deletion src/participant/participant.ts
@@ -1577,7 +1577,7 @@ export default class ParticipantController {
       log.info('Docs chatbot created for chatId', chatId);
     }
 
-    const history = PromptHistory.getFilteredHistoryForDocs({
+    const history = await PromptHistory.getFilteredHistoryForDocs({
       connectionNames: this._getConnectionNames(),
       context: context,
     });
86 changes: 62 additions & 24 deletions src/participant/prompts/promptBase.ts
@@ -5,6 +5,7 @@ import type {
   ParticipantPromptProperties,
 } from '../../telemetry/telemetryService';
 import { PromptHistory } from './promptHistory';
+import { getCopilotModel } from '../model';
 import type { ParticipantCommandType } from '../participantTypes';
 
 export interface PromptArgsBase {
@@ -94,34 +95,76 @@ export function isContentEmpty(
   return true;
 }
 
-export abstract class PromptBase<TArgs extends PromptArgsBase> {
-  protected abstract getAssistantPrompt(args: TArgs): string;
+export abstract class PromptBase<PromptArgs extends PromptArgsBase> {
+  protected abstract getAssistantPrompt(args: PromptArgs): string;
 
   protected get internalPurposeForTelemetry(): InternalPromptPurpose {
     return undefined;
   }
 
-  protected getUserPrompt(args: TArgs): Promise<UserPromptResponse> {
+  protected getUserPrompt({
+    request,
+  }: PromptArgs): Promise<UserPromptResponse> {
     return Promise.resolve({
-      prompt: args.request.prompt,
+      prompt: request.prompt,
       hasSampleDocs: false,
     });
   }
 
-  async buildMessages(args: TArgs): Promise<ModelInput> {
-    let historyMessages = PromptHistory.getFilteredHistory({
-      history: args.context?.history,
-      ...args,
+  private async _countRemainingTokens({
+    model,
+    assistantPrompt,
+    requestPrompt,
+  }: {
+    model: vscode.LanguageModelChat | undefined;
+    assistantPrompt: vscode.LanguageModelChatMessage;
+    requestPrompt: string;
+  }): Promise<number | undefined> {
+    if (model) {
+      const [assistantPromptTokens, userPromptTokens] = await Promise.all([
+        model.countTokens(assistantPrompt),
+        model.countTokens(requestPrompt),
+      ]);
+      return model.maxInputTokens - (assistantPromptTokens + userPromptTokens);
+    }
+    return undefined;
+  }
+
+  async buildMessages(args: PromptArgs): Promise<ModelInput> {
+    const { context, request, databaseName, collectionName, connectionNames } =
+      args;
+
+    const model = await getCopilotModel();
+
+    // eslint-disable-next-line new-cap
+    const assistantPrompt = vscode.LanguageModelChatMessage.Assistant(
+      this.getAssistantPrompt(args)
+    );
+
+    const tokenLimit = await this._countRemainingTokens({
+      model,
+      assistantPrompt,
+      requestPrompt: request.prompt,
+    });
+
+    let historyMessages = await PromptHistory.getFilteredHistory({
+      history: context?.history,
+      model,
+      tokenLimit,
+      namespaceIsKnown:
+        databaseName !== undefined && collectionName !== undefined,
+      connectionNames,
     });
 
     // If the current user's prompt is a connection name and the last
     // message was to connect, we want to use the last
     // message they sent before the connection name as their prompt.
-    if (args.connectionNames?.includes(args.request.prompt)) {
-      const history = args.context?.history;
+    if (connectionNames?.includes(request.prompt)) {
+      const history = context?.history;
       if (!history) {
         return {
           messages: [],
-          stats: this.getStats([], args, false),
+          stats: this.getStats([], { request, context }, false),
         };
       }
       const previousResponse = history[
@@ -132,13 +175,11 @@ export abstract class PromptBase<TArgs extends PromptArgsBase> {
       // Go through the history in reverse order to find the last user message.
       for (let i = history.length - 1; i >= 0; i--) {
         if (history[i] instanceof vscode.ChatRequestTurn) {
+          request.prompt = (history[i] as vscode.ChatRequestTurn).prompt;
           // Rewrite the arguments so that the prompt is the last user message from history
           args = {
             ...args,
-            request: {
-              ...args.request,
-              prompt: (history[i] as vscode.ChatRequestTurn).prompt,
-            },
+            request,
           };
 
           // Remove the item from the history messages array.
@@ -150,23 +191,20 @@ export abstract class PromptBase<TArgs extends PromptArgsBase> {
       }
 
     const { prompt, hasSampleDocs } = await this.getUserPrompt(args);
-    const messages = [
-      // eslint-disable-next-line new-cap
-      vscode.LanguageModelChatMessage.Assistant(this.getAssistantPrompt(args)),
-      ...historyMessages,
-      // eslint-disable-next-line new-cap
-      vscode.LanguageModelChatMessage.User(prompt),
-    ];
+    // eslint-disable-next-line new-cap
+    const userPrompt = vscode.LanguageModelChatMessage.User(prompt);
+
+    const messages = [assistantPrompt, ...historyMessages, userPrompt];
 
     return {
       messages,
-      stats: this.getStats(messages, args, hasSampleDocs),
+      stats: this.getStats(messages, { request, context }, hasSampleDocs),
     };
   }
 
   protected getStats(
     messages: vscode.LanguageModelChatMessage[],
-    { request, context }: TArgs,
+    { request, context }: Pick<PromptArgsBase, 'request' | 'context'>,
     hasSampleDocs: boolean
   ): ParticipantPromptProperties {
     return {
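For context, _countRemainingTokens above derives the history budget by subtracting the fixed parts of the request (the assistant prompt plus the user's current message) from the model's input window. A worked illustration with made-up numbers:

// Hypothetical figures showing the budget arithmetic used above.
const maxInputTokens = 8192; // the model's input window
const assistantPromptTokens = 900; // instructions sent as the assistant message
const userPromptTokens = 150; // the user's current prompt

// Whatever is left over can be spent on history messages.
const tokenLimit = maxInputTokens - (assistantPromptTokens + userPromptTokens);
console.log(tokenLimit); // 7142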
37 changes: 23 additions & 14 deletions src/participant/prompts/promptHistory.ts
@@ -106,26 +106,28 @@ export class PromptHistory {
   /** When passing the history to the model we only want contextual messages
     to be passed. This function parses through the history and returns
     the messages that are valuable to keep. */
-  static getFilteredHistory({
+  static async getFilteredHistory({
+    model,
+    tokenLimit,
     connectionNames,
     history,
-    databaseName,
-    collectionName,
+    namespaceIsKnown,
   }: {
+    model?: vscode.LanguageModelChat | undefined;
+    tokenLimit?: number;
     connectionNames?: string[]; // Used to scrape the connecting messages from the history.
     history?: vscode.ChatContext['history'];
-    databaseName?: string;
-    collectionName?: string;
-  }): vscode.LanguageModelChatMessage[] {
+    namespaceIsKnown: boolean;
+  }): Promise<vscode.LanguageModelChatMessage[]> {
     const messages: vscode.LanguageModelChatMessage[] = [];
 
     if (!history) {
       return [];
     }
 
-    const namespaceIsKnown =
-      databaseName !== undefined && collectionName !== undefined;
-    for (let i = 0; i < history.length; i++) {
+    let totalUsedTokens = 0;
+
+    for (let i = history.length - 1; i >= 0; i--) {
       const currentTurn = history[i];
 
       let addedMessage: vscode.LanguageModelChatMessage | undefined;
@@ -147,16 +149,23 @@ export class PromptHistory {
         });
       }
       if (addedMessage) {
+        if (tokenLimit) {
+          totalUsedTokens += (await model?.countTokens(addedMessage)) || 0;
+          if (totalUsedTokens > tokenLimit) {
+            break;
+          }
+        }
+
         messages.push(addedMessage);
       }
     }
 
-    return messages;
+    return messages.reverse();
   }
 
   /** The docs chatbot keeps its own history, so we need to include
    * history only since the last docs message. */
-  static getFilteredHistoryForDocs({
+  static async getFilteredHistoryForDocs({
     connectionNames,
     context,
     databaseName,
@@ -166,7 +175,7 @@ export class PromptHistory {
     context?: vscode.ChatContext;
     databaseName?: string;
     collectionName?: string;
-  }): vscode.LanguageModelChatMessage[] {
+  }): Promise<vscode.LanguageModelChatMessage[]> {
     if (!context) {
       return [];
     }
@@ -192,8 +201,8 @@ export class PromptHistory {
     return this.getFilteredHistory({
       connectionNames,
       history: historySinceLastDocs.reverse(),
-      databaseName,
-      collectionName,
+      namespaceIsKnown:
+        databaseName !== undefined && collectionName !== undefined,
     });
   }
 }
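Note the ordering in getFilteredHistory: messages are collected newest-first, so when the budget runs out it is the oldest turns that get dropped, and the final messages.reverse() restores chronological order. A toy walkthrough of that logic, with invented token costs:

// Toy run of the reverse-then-reverse ordering; the costs are made up.
const history = ['turn A', 'turn B', 'turn C', 'turn D']; // oldest to newest
const costs = [40, 30, 20, 10]; // pretend token count of each turn
const tokenLimit = 45;

const kept: string[] = [];
let used = 0;
for (let i = history.length - 1; i >= 0; i--) {
  used += costs[i];
  if (used > tokenLimit) {
    break; // 'turn B' would bring the total to 60 > 45, so stop here
  }
  kept.push(history[i]); // ends up as ['turn D', 'turn C']
}
console.log(kept.reverse()); // ['turn C', 'turn D']: the oldest turns are gone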
8 changes: 5 additions & 3 deletions src/participant/sampleDocuments.ts
@@ -59,9 +59,11 @@ export async function getStringifiedSampleDocuments({
 
   const stringifiedDocuments = toJSString(additionToPrompt);
 
-  // TODO: model.countTokens will sometimes return undefined - at least in tests. We should investigate why.
-  promptInputTokens =
-    (await model.countTokens(prompt + stringifiedDocuments)) || 0;
+  // Re-evaluate promptInputTokens with less documents if necessary.
+  if (promptInputTokens > model.maxInputTokens) {
+    promptInputTokens =
+      (await model.countTokens(prompt + stringifiedDocuments)) || 0;
+  }
 
   // Add sample documents to the prompt only when it fits in the context window.
   if (promptInputTokens <= model.maxInputTokens) {
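The sampleDocuments.ts change saves a countTokens round-trip: the count is now re-evaluated only when the first estimate already exceeded the window, which is when the surrounding code falls back to fewer sample documents. A self-contained sketch of that guard (the model stub, trimming step, and document strings are all assumptions for illustration):

// Sketch of the "recount only when over budget" guard.
async function demo(): Promise<void> {
  const model = {
    maxInputTokens: 10,
    // Crude stand-in for countTokens: roughly one token per four characters.
    countTokens: async (text: string): Promise<number> =>
      Math.ceil(text.length / 4),
  };

  const prompt = 'Generate a query for: ';
  const documents = ['{ "a": 1 }', '{ "b": 2 }', '{ "c": 3 }'];

  let stringified = documents.join('\n');
  let inputTokens = await model.countTokens(prompt + stringified); // 14 > 10

  if (inputTokens > model.maxInputTokens) {
    // Over budget: fall back to a single sample document, and only
    // then pay for a second countTokens call.
    stringified = documents.slice(0, 1).join('\n');
    inputTokens = (await model.countTokens(prompt + stringified)) || 0; // 8
  }

  console.log(inputTokens <= model.maxInputTokens); // true: fits the window
}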