From f1dd4184e53a8932267de3eb95559b905938e9d9 Mon Sep 17 00:00:00 2001 From: Kevin Kragenbrink Date: Tue, 12 Nov 2024 22:19:55 -0800 Subject: [PATCH] feat(settings): add temp, topp, topk, etc. values --- CHANGELOG.md | 1 + src/ai/provider/deepinfra.js | 20 +++++++--- src/ai/provider/openai.js | 16 +++++--- src/conversation/store.js | 8 ++-- src/event/emitter.js | 2 +- src/settings/settings.js | 6 +++ src/settings/settings.json | 59 +++++++++++++++++++++++++++++ src/ui/settings.js | 18 +++++++++ templates/applications/Settings.hbs | 25 ++++++++++++ 9 files changed, 137 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a7866a3..7514e67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/nivthefox/foundryvtt-aide) ### Added - Document embeddings will be updated when another user updates a document. +- New settings are now available to control the LLM's creativity and consistency. ### Fixed - Resolved an issue where models could be set prematurely. (#1) diff --git a/src/ai/provider/deepinfra.js b/src/ai/provider/deepinfra.js index 2893497..e409171 100644 --- a/src/ai/provider/deepinfra.js +++ b/src/ai/provider/deepinfra.js @@ -12,7 +12,19 @@ export class DeepInfra { * @param {AIProviderSettings} config */ constructor(config) { - this.#apiKey = config.apiKey; + delete config.provider; + + if (config.apiKey) { + this.#apiKey = config.apiKey; + delete config.apiKey; + } + + if (config.baseURL) { + this.#baseUrl = config.baseURL; + delete config.baseURL; + } + + this.config = config; } /** @@ -108,11 +120,7 @@ export class DeepInfra { body: JSON.stringify({ model: model, messages: query, - temperature: 0.7, - top_p: 0.9, - top_k: 0, - presence_penalty: 0.0, - frequency_penalty: 0.0, + ...this.config, stream }) }); diff --git a/src/ai/provider/openai.js b/src/ai/provider/openai.js index 299e2fa..a3e0341 100644 --- a/src/ai/provider/openai.js +++ b/src/ai/provider/openai.js @@ -12,11 +12,19 @@ export class OpenAI { * @param {AIProviderSettings} config */ constructor(config) { - this.#apiKey = config.apiKey; + delete config.provider; + + if (config.apiKey) { + this.#apiKey = config.apiKey; + delete config.apiKey; + } if (config.baseURL) { this.#baseUrl = config.baseURL; + delete config.baseURL; } + + this.config = config; } /** @@ -88,11 +96,7 @@ export class OpenAI { body: JSON.stringify({ model, messages: query, - temperature: 0.7, - top_p: 0.9, - top_k: 0, - presence_penalty: 0.0, - frequency_penalty: 0.0, + ...this.config, stream }) }); diff --git a/src/conversation/store.js b/src/conversation/store.js index 939efab..a2a0b64 100644 --- a/src/conversation/store.js +++ b/src/conversation/store.js @@ -61,11 +61,9 @@ export class Store { await this.#createDirectoryIfMissing(); await this.#fetchConversations(); - this.#emitter.on('conversation.create', (id) => - this.#conversations.set(id, {id, loaded: false})); - this.#emitter.on('conversation.delete', (id) => - this.#conversations.delete(id)); - this.#emitter.on('conversation.update', (id) => { + this.#emitter.on('conversation.create', id => this.#conversations.set(id, {id, loaded: false})); + this.#emitter.on('conversation.delete', id => this.#conversations.delete(id)); + this.#emitter.on('conversation.update', id => { const {userId} = this.#conversations.get(id); this.#conversations.set(id, {id, loaded: false}); this.get(userId, id); diff --git a/src/event/emitter.js b/src/event/emitter.js index 3c77e59..113f875 100644 --- a/src/event/emitter.js +++ b/src/event/emitter.js @@ -21,6 +21,6 @@ export class Emitter { this.#logger.debug('Received event %s', event.name); listener(...event.args); } - }) + }); } } diff --git a/src/settings/settings.js b/src/settings/settings.js index 2990639..2d43787 100644 --- a/src/settings/settings.js +++ b/src/settings/settings.js @@ -37,6 +37,12 @@ export class Settings { provider: this.#context.game.settings.get(this.#module, 'ChatProvider'), apiKey: this.#context.game.settings.get(this.#module, 'ChatAPIKey'), baseURL: this.#context.game.settings.get(this.#module, 'ChatBaseURL'), + temperature: this.#context.game.settings.get(this.#module, 'ChatTemperature'), + maxTokens: this.#context.game.settings.get(this.#module, 'ChatMaxTokens'), + topP: this.#context.game.settings.get(this.#module, 'ChatTopP'), + topK: this.#context.game.settings.get(this.#module, 'ChatTopK'), + frequencyPenalty: this.#context.game.settings.get(this.#module, 'ChatFrequencyPenalty'), + presencePenalty: this.#context.game.settings.get(this.#module, 'ChatPresencePenalty'), }, embedding: { provider: this.#context.game.settings.get(this.#module, 'EmbeddingProvider'), diff --git a/src/settings/settings.json b/src/settings/settings.json index 3944c90..66d126a 100644 --- a/src/settings/settings.json +++ b/src/settings/settings.json @@ -24,6 +24,65 @@ "default": "", "scope": "client" }, + "ChatTemperature": { + "type": "Number", + "default": 0.7, + "scope": "client", + "range": { + "min": 0, + "max": 2, + "step": 0.05 + } + }, + "ChatTopP": { + "type": "Number", + "default": 0.9, + "scope": "client", + "range": { + "min": 0, + "max": 1, + "step": 0.05 + } + }, + "ChatTopK": { + "type": "Number", + "default": 0, + "scope": "client", + "range": { + "min": 0, + "max": 32, + "step": 1 + } + }, + "ChatPresencePenalty": { + "type": "Number", + "default": 0, + "scope": "client", + "range": { + "min": 0, + "max": 2, + "step": 0.05 + } + }, + "ChatFrequencyPenalty": { + "type": "Number", + "default": 0, + "scope": "client", + "range": { + "min": 0, + "max": 2, + "step": 0.05 + } + }, + "ChatMaxTokens": { + "type": "Number", + "default": 150, + "scope": "client", + "range": { + "min": 1, + "step": 1 + } + }, "EmbeddingProvider": { "type": "String", "default": "", diff --git a/src/ui/settings.js b/src/ui/settings.js index e6fbec9..7c1bf8b 100644 --- a/src/ui/settings.js +++ b/src/ui/settings.js @@ -27,9 +27,21 @@ export class ChatSettings extends HandlebarsApplicationMixin(ApplicationV2) { const chatModel = html.querySelector('.chat-model').value; const embeddingModel = html.querySelector('.embedding-model').value; + const temperature = html.querySelector('.temperature').value; + const maxTokens = html.querySelector('.max-tokens').value; + const topP = html.querySelector('.top-p').value; + const topK = html.querySelector('.top-k').value; + const frequencyPenalty = html.querySelector('.frequency-penalty').value; + const presencePenalty = html.querySelector('.presence-penalty').value; game.settings.set('aide', 'ChatModel', chatModel); game.settings.set('aide', 'EmbeddingModel', embeddingModel); + game.settings.set('aide', 'ChatTemperature', temperature); + game.settings.set('aide', 'ChatMaxTokens', maxTokens); + game.settings.set('aide', 'ChatTopP', topP); + game.settings.set('aide', 'ChatTopK', topK); + game.settings.set('aide', 'ChatFrequencyPenalty', frequencyPenalty); + game.settings.set('aide', 'ChatPresencePenalty', presencePenalty); this.close(); } @@ -46,6 +58,12 @@ export class ChatSettings extends HandlebarsApplicationMixin(ApplicationV2) { acc[model] = model; return acc; }, {}), + temperature: game.settings.get('aide', 'ChatTemperature'), + maxTokens: game.settings.get('aide', 'ChatMaxTokens'), + topP: game.settings.get('aide', 'ChatTopP'), + topK: game.settings.get('aide', 'ChatTopK'), + frequencyPenalty: game.settings.get('aide', 'ChatFrequencyPenalty'), + presencePenalty: game.settings.get('aide', 'ChatPresencePenalty'), embeddingModel: game.settings.get('aide', 'EmbeddingModel'), embeddingModels: (await aide.embeddingClient.getEmbeddingModels()) .reduce((acc, model) => { diff --git a/templates/applications/Settings.hbs b/templates/applications/Settings.hbs index 77ba684..3c4a46e 100644 --- a/templates/applications/Settings.hbs +++ b/templates/applications/Settings.hbs @@ -1,3 +1,4 @@ +
+ + + + + +