Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(participant): filter message history when it goes over maxInputTokens VSCODE-653 #894

Merged
merged 9 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/participant/participant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1511,7 +1511,7 @@ export default class ParticipantController {
log.info('Docs chatbot created for chatId', chatId);
}

const history = PromptHistory.getFilteredHistoryForDocs({
const history = await PromptHistory.getFilteredHistoryForDocs({
connectionNames: this._getConnectionNames(),
context: context,
});
Expand Down
82 changes: 61 additions & 21 deletions src/participant/prompts/promptBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type {
ParticipantPromptProperties,
} from '../../telemetry/telemetryService';
import { PromptHistory } from './promptHistory';
import { getCopilotModel } from '../model';

export interface PromptArgsBase {
request: {
Expand Down Expand Up @@ -93,34 +94,76 @@ export function isContentEmpty(
return true;
}

export abstract class PromptBase<TArgs extends PromptArgsBase> {
protected abstract getAssistantPrompt(args: TArgs): string;
export abstract class PromptBase<PromptArgs extends PromptArgsBase> {
protected abstract getAssistantPrompt(args: PromptArgs): string;

protected get internalPurposeForTelemetry(): InternalPromptPurpose {
return undefined;
}

protected getUserPrompt(args: TArgs): Promise<UserPromptResponse> {
protected getUserPrompt({
request,
}: PromptArgs): Promise<UserPromptResponse> {
return Promise.resolve({
prompt: args.request.prompt,
prompt: request.prompt,
hasSampleDocs: false,
});
}

async buildMessages(args: TArgs): Promise<ModelInput> {
let historyMessages = PromptHistory.getFilteredHistory({
history: args.context?.history,
...args,
private async _countRemainingTokens({
model,
assistantPrompt,
requestPrompt,
}: {
model: vscode.LanguageModelChat | undefined;
assistantPrompt: vscode.LanguageModelChatMessage;
requestPrompt: string;
}): Promise<number | undefined> {
if (model) {
const [assistantPromptTokens, userPromptTokens] = await Promise.all([
model.countTokens(assistantPrompt),
model.countTokens(requestPrompt),
]);
return model.maxInputTokens - (assistantPromptTokens + userPromptTokens);
}
return undefined;
}

async buildMessages(args: PromptArgs): Promise<ModelInput> {
const { context, request, databaseName, collectionName, connectionNames } =
args;

const model = await getCopilotModel();

// eslint-disable-next-line new-cap
const assistantPrompt = vscode.LanguageModelChatMessage.Assistant(
this.getAssistantPrompt(args)
);

const tokenLimit = await this._countRemainingTokens({
model,
assistantPrompt,
requestPrompt: request.prompt,
});

let historyMessages = await PromptHistory.getFilteredHistory({
history: context?.history,
model,
tokenLimit,
namespaceIsKnown:
databaseName !== undefined && collectionName !== undefined,
connectionNames,
});

// If the current user's prompt is a connection name, and the last
// message was to connect. We want to use the last
// message they sent before the connection name as their prompt.
if (args.connectionNames?.includes(args.request.prompt)) {
const history = args.context?.history;
if (connectionNames?.includes(request.prompt)) {
gagik marked this conversation as resolved.
Show resolved Hide resolved
const history = context?.history;
if (!history) {
return {
messages: [],
stats: this.getStats([], args, false),
stats: this.getStats([], { request, context }, false),
};
}
const previousResponse = history[
Expand All @@ -135,7 +178,7 @@ export abstract class PromptBase<TArgs extends PromptArgsBase> {
args = {
...args,
request: {
...args.request,
...request,
prompt: (history[i] as vscode.ChatRequestTurn).prompt,
},
};
Expand All @@ -149,23 +192,20 @@ export abstract class PromptBase<TArgs extends PromptArgsBase> {
}

const { prompt, hasSampleDocs } = await this.getUserPrompt(args);
const messages = [
// eslint-disable-next-line new-cap
vscode.LanguageModelChatMessage.Assistant(this.getAssistantPrompt(args)),
...historyMessages,
// eslint-disable-next-line new-cap
vscode.LanguageModelChatMessage.User(prompt),
];
// eslint-disable-next-line new-cap
const userPrompt = vscode.LanguageModelChatMessage.User(prompt);

const messages = [assistantPrompt, ...historyMessages, userPrompt];

return {
messages,
stats: this.getStats(messages, args, hasSampleDocs),
stats: this.getStats(messages, { request, context }, hasSampleDocs),
};
}

protected getStats(
messages: vscode.LanguageModelChatMessage[],
{ request, context }: TArgs,
{ request, context }: Pick<PromptArgsBase, 'request' | 'context'>,
hasSampleDocs: boolean
): ParticipantPromptProperties {
return {
Expand Down
37 changes: 23 additions & 14 deletions src/participant/prompts/promptHistory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,26 +105,28 @@ export class PromptHistory {
/** When passing the history to the model we only want contextual messages
to be passed. This function parses through the history and returns
the messages that are valuable to keep. */
static getFilteredHistory({
static async getFilteredHistory({
model,
tokenLimit,
connectionNames,
history,
databaseName,
collectionName,
namespaceIsKnown,
}: {
model?: vscode.LanguageModelChat | undefined;
tokenLimit?: number;
connectionNames?: string[]; // Used to scrape the connecting messages from the history.
history?: vscode.ChatContext['history'];
databaseName?: string;
collectionName?: string;
}): vscode.LanguageModelChatMessage[] {
namespaceIsKnown: boolean;
}): Promise<vscode.LanguageModelChatMessage[]> {
const messages: vscode.LanguageModelChatMessage[] = [];

if (!history) {
return [];
}

const namespaceIsKnown =
databaseName !== undefined && collectionName !== undefined;
for (let i = 0; i < history.length; i++) {
let totalUsedTokens = 0;

for (let i = history.length - 1; i >= 0; i--) {
const currentTurn = history[i];

let addedMessage: vscode.LanguageModelChatMessage | undefined;
Expand All @@ -146,16 +148,23 @@ export class PromptHistory {
});
}
if (addedMessage) {
if (tokenLimit) {
totalUsedTokens += (await model?.countTokens(addedMessage)) || 0;
if (totalUsedTokens > tokenLimit) {
break;
}
}

messages.push(addedMessage);
}
}

return messages;
return messages.reverse();
}

/** The docs chatbot keeps its own history so we avoid any
* we need to include history only since last docs message. */
static getFilteredHistoryForDocs({
static async getFilteredHistoryForDocs({
connectionNames,
context,
databaseName,
Expand All @@ -165,7 +174,7 @@ export class PromptHistory {
context?: vscode.ChatContext;
databaseName?: string;
collectionName?: string;
}): vscode.LanguageModelChatMessage[] {
}): Promise<vscode.LanguageModelChatMessage[]> {
if (!context) {
return [];
}
Expand All @@ -191,8 +200,8 @@ export class PromptHistory {
return this.getFilteredHistory({
connectionNames,
history: historySinceLastDocs.reverse(),
databaseName,
collectionName,
namespaceIsKnown:
databaseName !== undefined && collectionName !== undefined,
});
}
}
8 changes: 5 additions & 3 deletions src/participant/sampleDocuments.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ export async function getStringifiedSampleDocuments({

const stringifiedDocuments = toJSString(additionToPrompt);

// TODO: model.countTokens will sometimes return undefined - at least in tests. We should investigate why.
gagik marked this conversation as resolved.
Show resolved Hide resolved
promptInputTokens =
(await model.countTokens(prompt + stringifiedDocuments)) || 0;
// Re-evaluate promptInputTokens with less documents if necessary.
if (promptInputTokens > model.maxInputTokens) {
promptInputTokens =
(await model.countTokens(prompt + stringifiedDocuments)) || 0;
}

// Add sample documents to the prompt only when it fits in the context window.
if (promptInputTokens <= model.maxInputTokens) {
Expand Down
Loading
Loading