Skip to content

Commit

Permalink
Support prompt variants
Browse files Browse the repository at this point in the history
fixed #14485

Signed-off-by: Jonas Helming <[email protected]>
  • Loading branch information
JonasHelming committed Nov 20, 2024
1 parent 8f293ed commit e770ba6
Show file tree
Hide file tree
Showing 8 changed files with 262 additions and 28 deletions.
8 changes: 7 additions & 1 deletion packages/ai-chat/src/common/universal-chat-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ simple solutions.
`
};

export const universalTemplateVariant: PromptTemplate = {
id: 'universal-system-empty',
template: '',
variantOf: universalTemplate.id,
};

@injectable()
export class UniversalChatAgent extends AbstractStreamParsingChatAgent implements ChatAgent {
name: string;
Expand All @@ -96,7 +102,7 @@ export class UniversalChatAgent extends AbstractStreamParsingChatAgent implement
+ 'questions the user might ask. The universal agent currently does not have any context by default, i.e. it cannot '
+ 'access the current user context or the workspace.';
this.variables = [];
this.promptTemplates = [universalTemplate];
this.promptTemplates = [universalTemplate, universalTemplateVariant];
this.functions = [];
this.agentSpecificVariables = [];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,26 @@ export class AIAgentConfigurationWidget extends ReactWidget {
Enable Agent
</label>
</div>
<div className="settings-section-subcategory-title ai-settings-section-subcategory-title">
Prompt Templates
</div>
<div className='ai-templates'>
{agent.promptTemplates?.map(template =>
<TemplateRenderer
key={agent?.id + '.' + template.id}
agentId={agent.id}
template={template}
promptCustomizationService={this.promptCustomizationService} />)}
{agent.promptTemplates
?.filter(template => !template.variantOf)
.map(template => (
<div key={agent.id + '.' + template.id}>
<TemplateRenderer
key={agent?.id + '.' + template.id}
agentId={agent.id}
template={template}
promptService={this.promptService}
aiSettingsService={this.aiSettingsService}
promptCustomizationService={this.promptCustomizationService}
/>
</div>
))}
</div>

<div className='ai-lm-requirements'>
<LanguageModelRenderer
agent={agent}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,103 @@
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
// *****************************************************************************
import * as React from '@theia/core/shared/react';
import { PromptCustomizationService } from '../../common/prompt-service';
import { PromptTemplate } from '../../common';
import { PromptCustomizationService, PromptService } from '../../common/prompt-service';
import { AISettingsService, PromptTemplate } from '../../common';

export interface TemplateSettingProps {
const DEFAULT_VARIANT = 'default';

export interface TemplateRendererProps {
agentId: string;
template: PromptTemplate;
promptCustomizationService: PromptCustomizationService;
promptService: PromptService;
aiSettingsService: AISettingsService;
}

export const TemplateRenderer: React.FC<TemplateSettingProps> = ({ agentId, template, promptCustomizationService }) => {
export const TemplateRenderer: React.FC<TemplateRendererProps> = ({
agentId,
template,
promptCustomizationService,
promptService,
aiSettingsService,
}) => {
const [variantIds, setVariantIds] = React.useState<string[]>([]);
const [selectedVariant, setSelectedVariant] = React.useState<string>(DEFAULT_VARIANT);

React.useEffect(() => {
(async () => {
const variants = promptService.getVariantIds(template.id);
setVariantIds([DEFAULT_VARIANT, ...variants]);

const agentSettings = await aiSettingsService.getAgentSettings(agentId);
const currentVariant =
agentSettings?.selectedVariants?.[template.id] || DEFAULT_VARIANT;
setSelectedVariant(currentVariant);
})();
}, [template.id, promptService, aiSettingsService, agentId]);

const handleVariantChange = async (event: React.ChangeEvent<HTMLSelectElement>) => {
const newVariant = event.target.value;
setSelectedVariant(newVariant);

const agentSettings = await aiSettingsService.getAgentSettings(agentId);
const selectedVariants = agentSettings?.selectedVariants || {};

const updatedVariants = { ...selectedVariants };
if (newVariant === DEFAULT_VARIANT) {
delete updatedVariants[template.id];
} else {
updatedVariants[template.id] = newVariant;
}

await aiSettingsService.updateAgentSettings(agentId, {
selectedVariants: updatedVariants,
});
};

const openTemplate = React.useCallback(async () => {
promptCustomizationService.editTemplate(template.id, template.template);
}, [template, promptCustomizationService]);
const templateId = selectedVariant === DEFAULT_VARIANT ? template.id : selectedVariant;
const selectedTemplate = promptService.getRawPrompt(templateId);
promptCustomizationService.editTemplate(templateId, selectedTemplate?.template || '');
}, [selectedVariant, template.id, promptService, promptCustomizationService]);

const resetTemplate = React.useCallback(async () => {
promptCustomizationService.resetTemplate(template.id);
}, [promptCustomizationService, template]);

return <>
{template.id}
<button className='theia-button main' onClick={openTemplate}>Edit</button>
<button className='theia-button secondary' onClick={resetTemplate}>Reset</button>
</>;
const templateId = selectedVariant === DEFAULT_VARIANT ? template.id : selectedVariant;
promptCustomizationService.resetTemplate(templateId);
}, [selectedVariant, template.id, promptCustomizationService]);

return (
<div className="template-renderer">
<div className="settings-section-title template-header">
<strong>{template.id}</strong>
</div>
<div className="template-controls">
{variantIds.length > 1 && (
<>
<label htmlFor={`variant-selector-${template.id}`} className="template-select-label">
Variant:
</label>
<select
id={`variant-selector-${template.id}`}
className="theia-select template-variant-selector"
value={selectedVariant}
onChange={handleVariantChange}
>
{variantIds.map(variantId => (
<option key={variantId} value={variantId}>
{variantId}
</option>
))}
</select>
</>
)}
<button className="theia-button main" onClick={openTemplate}>
Edit
</button>
<button className="theia-button secondary" onClick={resetTemplate}>
Reset
</button>
</div>
</div>
);
};
38 changes: 33 additions & 5 deletions packages/ai-core/src/browser/style/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,42 @@
margin-left: var(--theia-ui-padding);
}

.theia-settings-container .settings-section-subcategory-title.ai-settings-section-subcategory-title {
padding-left: 0;
}

.ai-templates {
display: grid;
/** Display content in 3 columns */
grid-template-columns: 1fr auto auto;
/** add a 3px gap between rows */
row-gap: 3px;
display: flex;
flex-direction: column;
gap: 5px;
}

.template-renderer {
display: flex;
flex-direction: column;
padding: 10px;
}

.template-header {
margin-bottom: 8px;
}

.template-controls {
display: flex;
align-items: center;
gap: 10px;
}

.template-select-label {
margin-right: 5px;
}

.template-variant-selector {
min-width: 120px;
}



#ai-variable-configuration-container-widget,
#ai-agent-configuration-container-widget {
margin-top: 5px;
Expand Down
2 changes: 1 addition & 1 deletion packages/ai-core/src/common/agent-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ export class AgentServiceImpl implements AgentService {
registerAgent(agent: Agent): void {
this._agents.push(agent);
agent.promptTemplates.forEach(
template => this.promptService.storePrompt(template.id, template.template)
template => this.promptService.storePromptTemplate(template)
);
this.onDidChangeAgentsEmitter.fire();
}
Expand Down
56 changes: 56 additions & 0 deletions packages/ai-core/src/common/prompt-service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,60 @@ describe('PromptService', () => {
expect(prompt?.text).to.equal('Hi, John! {{!-- Another comment --}}');
});

it('should return all variant IDs of a given prompt', () => {
promptService.storePromptTemplate({ id: 'main', template: 'Main template' });

promptService.storePromptTemplate({
id: 'variant1',
template: 'Variant 1',
variantOf: 'main'
});
promptService.storePromptTemplate({
id: 'variant2',
template: 'Variant 2',
variantOf: 'main'
});
promptService.storePromptTemplate({
id: 'variant3',
template: 'Variant 3',
variantOf: 'main'
});

const variantIds = promptService.getVariantIds('main');
expect(variantIds).to.deep.equal(['variant1', 'variant2', 'variant3']);
});

it('should return an empty array if no variants exist for a given prompt', () => {
promptService.storePromptTemplate({ id: 'main', template: 'Main template' });

const variantIds = promptService.getVariantIds('main');
expect(variantIds).to.deep.equal([]);
});

it('should return an empty array if the main prompt ID does not exist', () => {
const variantIds = promptService.getVariantIds('nonExistent');
expect(variantIds).to.deep.equal([]);
});

it('should not influence prompts without variants when other prompts have variants', () => {
promptService.storePromptTemplate({ id: 'mainWithVariants', template: 'Main template with variants' });
promptService.storePromptTemplate({ id: 'mainWithoutVariants', template: 'Main template without variants' });

promptService.storePromptTemplate({
id: 'variant1',
template: 'Variant 1',
variantOf: 'mainWithVariants'
});
promptService.storePromptTemplate({
id: 'variant2',
template: 'Variant 2',
variantOf: 'mainWithVariants'
});

const variantsForMainWithVariants = promptService.getVariantIds('mainWithVariants');
const variantsForMainWithoutVariants = promptService.getVariantIds('mainWithoutVariants');

expect(variantsForMainWithVariants).to.deep.equal(['variant1', 'variant2']);
expect(variantsForMainWithoutVariants).to.deep.equal([]);
});
});
52 changes: 51 additions & 1 deletion packages/ai-core/src/common/prompt-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ import { ToolInvocationRegistry } from './tool-invocation-registry';
import { toolRequestToPromptText } from './language-model-util';
import { ToolRequest } from './language-model';
import { matchFunctionsRegEx, matchVariablesRegEx } from './prompt-service-util';
import { AISettingsService } from './settings-service';

export interface PromptTemplate {
id: string;
template: string;
/**
* (Optional) The ID of the main template for which this template is a variant.
* If present, this indicates that the current template represents an alternative version of the specified main template.
*/
variantOf?: string;
}

export interface PromptMap { [id: string]: PromptTemplate }
Expand Down Expand Up @@ -68,6 +74,11 @@ export interface PromptService {
* @param prompt the prompt template to store
*/
storePrompt(id: string, prompt: string): void;
/**
* Adds a {@link PromptTemplate} to the list of prompts.
* @param promptTemplate the prompt template to store
*/
storePromptTemplate(promptTemplate: PromptTemplate): void;
/**
* Removes a prompt from the list of prompts.
* @param id the id of the prompt
Expand All @@ -77,6 +88,20 @@ export interface PromptService {
* Return all known prompts as a {@link PromptMap map}.
*/
getAllPrompts(): PromptMap;
/**
* Retrieve all variant IDs of a given {@link PromptTemplate}.
* @param id the id of the main {@link PromptTemplate}
* @returns an array of string IDs representing the variants of the given template
*/
getVariantIds(id: string): string[];
/**
* Retrieve the currently selected variant ID for a given main prompt ID.
* If a variant is selected for the main prompt, it will be returned.
* Otherwise, the main prompt ID will be returned.
* @param id the id of the main prompt
* @returns the variant ID if one is selected, or the main prompt ID otherwise
*/
getVariantId(id: string): Promise<string>;
}

export interface CustomAgentDescription {
Expand Down Expand Up @@ -163,6 +188,9 @@ export interface PromptCustomizationService {

@injectable()
export class PromptServiceImpl implements PromptService {
@inject(AISettingsService) @optional()
protected readonly settingsService: AISettingsService | undefined;

@inject(PromptCustomizationService) @optional()
protected readonly customizationService: PromptCustomizationService | undefined;

Expand Down Expand Up @@ -203,8 +231,22 @@ export class PromptServiceImpl implements PromptService {
return commentRegex.test(template) ? template.replace(commentRegex, '').trimStart() : template;
}

async getVariantId(id: string): Promise<string> {
if (this.settingsService !== undefined) {
const agentSettingsMap = await this.settingsService.getSettings();

for (const agentSettings of Object.values(agentSettingsMap)) {
if (agentSettings.selectedVariants && agentSettings.selectedVariants[id]) {
return agentSettings.selectedVariants[id];
}
}
}
return id;
}

async getPrompt(id: string, args?: { [key: string]: unknown }): Promise<ResolvedPromptTemplate | undefined> {
const prompt = this.getUnresolvedPrompt(id);
const variantId = await this.getVariantId(id);
const prompt = this.getUnresolvedPrompt(variantId);
if (prompt === undefined) {
return undefined;
}
Expand Down Expand Up @@ -280,4 +322,12 @@ export class PromptServiceImpl implements PromptService {
removePrompt(id: string): void {
delete this._prompts[id];
}
getVariantIds(id: string): string[] {
return Object.values(this._prompts)
.filter(prompt => prompt.variantOf === id)
.map(variant => variant.id);
}
storePromptTemplate(promptTemplate: PromptTemplate): void {
this._prompts[promptTemplate.id] = promptTemplate;
}
}
5 changes: 5 additions & 0 deletions packages/ai-core/src/common/settings-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,9 @@ export type AISettings = Record<string, AgentSettings>;
export interface AgentSettings {
languageModelRequirements?: LanguageModelRequirement[];
enable?: boolean;
/**
* A mapping of main template IDs to their selected variant IDs.
* If a main template is not present in this mapping, it means the main template is used.
*/
selectedVariants?: Record<string, string>;
}

0 comments on commit e770ba6

Please sign in to comment.