Commit e66e462

fix(providers): get models working properly
nivthefox committed Nov 6, 2024
1 parent ca907a3 commit e66e462
Showing 11 changed files with 161 additions and 87 deletions.
20 changes: 9 additions & 11 deletions src/ai/client.js
@@ -60,12 +60,11 @@ export class Client {
/**
* Create a new AI client
*
- * @param {string} provider
- * @param {AIProviderConfig} configuration
+ * @param {AIProviderSettings} settings
* @returns {Client}
*/
- static create(provider, configuration) {
- const implementation = Client.#createImplementation(provider, configuration);
+ static create(settings) {
+ const implementation = Client.#createImplementation(settings);
return new Client(implementation);
}

@@ -112,12 +111,11 @@ export class Client {

/**
* @private
- * @param {string} provider
- * @param {AIProviderConfig} configuration
+ * @param {AIProviderSettings} settings
* @returns {AIProvider}
*/
- static #createImplementation(provider, configuration) {
- switch (provider.toLowerCase()) {
+ static #createImplementation(settings) {
+ switch (settings.provider.toLowerCase()) {
case 'anthropic':
throw new Error('Unsupported provider: Anthropic');
// disabled temporarily as Anthropic does not support embeddings
@@ -126,11 +124,11 @@ export class Client {
// providers
// return new Anthropic(configuration);
case 'deepinfra':
- return new DeepInfra(configuration);
+ return new DeepInfra(settings);
case 'openai':
- return new OpenAI(configuration);
+ return new OpenAI(settings);
default:
- throw new Error(`Unsupported provider: ${provider}`);
+ throw new Error(`Unsupported provider: ${settings.provider}`);
}
}
}
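
Note: the factory now takes a single settings object instead of a separate provider string and configuration. A minimal sketch of the new call site, assuming an AIProviderSettings object whose fields beyond provider and apiKey are not shown in this diff:

    // Hypothetical usage of the reworked factory; only `provider` and `apiKey`
    // are confirmed by this change, other fields are assumptions.
    const chatClient = Client.create({
        provider: 'deepinfra',
        apiKey: 'YOUR_API_KEY'
    });
    const chatModels = await chatClient.getChatModels();
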
2 changes: 1 addition & 1 deletion src/ai/client.test.js
@@ -82,7 +82,7 @@ export default function ClientTest(quench) {

describe('factory creation', () => {
it('creates client with provider configuration', () => {
- assert.throws(() => Client.create('invalid', {}), /Unsupported provider/);
+ assert.throws(() => Client.create({provider: 'invalid'}), /Unsupported provider/);
// Additional provider creation tests would go here once we have real providers
});
});
2 changes: 1 addition & 1 deletion src/ai/provider/anthropic.js
@@ -9,7 +9,7 @@ export class Anthropic {
#embeddingModels = null;

/**
- * @param {AIProviderConfig} config
+ * @param {AIProviderSettings} config
*/
constructor(config) {
if (!config.apiKey) {
76 changes: 47 additions & 29 deletions src/ai/provider/deepinfra.js
@@ -4,17 +4,14 @@
*/
export class DeepInfra {
#apiKey = null;
- #baseUrl = 'https://api.deepinfra.com/v1/inference';
+ #baseUrl = 'https://api.deepinfra.com/v1/openai';
#chatModels = null;
#embeddingModels = null;

/**
- * @param {AIProviderConfig} config
+ * @param {AIProviderSettings} config
*/
constructor(config) {
- if (!config.apiKey) {
- throw new Error('DeepInfra API key is required');
- }
this.#apiKey = config.apiKey;
}

@@ -37,37 +34,58 @@ export class DeepInfra {
}

const data = await response.json();
- this.#chatModels = data.models
- .filter(model => model.type === 'text-generation')
- .map(model => model.id);

+ this.#chatModels = data.data
+ .filter(model => !model.id.toLowerCase().includes('embed'))
+ .map(model => model.id)
+ .sort((a, b) => a.localeCompare(b));
return this.#chatModels;
}

/**
* @returns {Promise<string[]>}
*/
async getEmbeddingModels() {
- if (this.#embeddingModels !== null) {
- return this.#embeddingModels;
- }
-
- const response = await fetch(`${this.#baseUrl}/models`, {
- headers: {
- 'Authorization': `Bearer ${this.#apiKey}`
- }
- });
-
- if (!response.ok) {
- throw new Error(`DeepInfra API error: ${response.status}`);
- }
-
- const data = await response.json();
- this.#embeddingModels = data.models
- .filter(model => model.type === 'embedding')
- .map(model => model.id);
-
- return this.#embeddingModels;
+ // todo: temporarily hardcoded since there's no way to get the embedding
+ // models from the API
+ return [
+ 'BAAI/bge-base-en-v1.5',
+ 'BAAI/bge-large-en-v1.5',
+ 'BAAI/bge-m3',
+ 'intfloat/e5-base-v2',
+ 'intfloat/e5-large-v2',
+ 'intfloat/multilingual-e5-large',
+ 'sentence-transformers/all-MiniLM-L12-v2',
+ 'sentence-transformers/all-MiniLM-L6-v2',
+ 'sentence-transformers/all-mpnet-base-v2',
+ 'sentence-transformers/clip-ViT-B-32',
+ 'sentence-transformers/clip-ViT-B-32-multilingual-v1',
+ 'sentence-transformers/multi-qa-mpnet-base-dot-v1',
+ 'sentence-transformers/paraphrase-MiniLM-L6-v2',
+ 'shibing624/text2vec-base-chinese',
+ 'thenlper/gte-base',
+ 'thenlper/gte-large',
+ ].sort((a, b) => a.localeCompare(b));
+
+ // if (this.#embeddingModels !== null) {
+ // return this.#embeddingModels;
+ // }
+ //
+ // const response = await fetch(`${this.#baseUrl}/models`, {
+ // headers: {
+ // 'Authorization': `Bearer ${this.#apiKey}`
+ // }
+ // });
+ //
+ // if (!response.ok) {
+ // throw new Error(`DeepInfra API error: ${response.status}`);
+ // }
+ //
+ // const data = await response.json();
+ // this.#embeddingModels = data.data
+ // .filter(model => model.id.toLowerCase().includes('embed'))
+ // .map(model => model.id);
+ //
+ // return this.#embeddingModels;
}

/**
Expand Down
34 changes: 15 additions & 19 deletions src/ai/provider/deepinfra.test.js
@@ -15,12 +15,6 @@ export default function DeepInfraProviderTest({describe, it, assert, beforeEach,
globalThis.fetch = originalFetch;
});

- describe('initialization', () => {
- it('requires API key', () => {
- assert.throws(() => new DeepInfra({}), /API key is required/);
- });
- });

describe('model listing', () => {
let fetchCount = 0;

@@ -31,10 +25,10 @@ export default function DeepInfraProviderTest({describe, it, assert, beforeEach,
return {
ok: true,
json: async () => ({
- models: [
- { id: 'model1', type: 'text-generation' },
- { id: 'model2', type: 'embedding' },
- { id: 'model3', type: 'other' }
+ data: [
+ {id: 'text-chat-model1'},
+ {id: 'text-chat-model2'},
+ {id: 'text-embed-model3'},
]
})
};
@@ -44,18 +38,20 @@ export default function DeepInfraProviderTest({describe, it, assert, beforeEach,
it('caches chat models', async () => {
const models1 = await provider.getChatModels();
const models2 = await provider.getChatModels();
- assert.deepEqual(models1, ['model1']);
- assert.deepEqual(models2, ['model1']);
+ assert.deepEqual(models1, ['text-chat-model1', 'text-chat-model2']);
+ assert.deepEqual(models2, ['text-chat-model1', 'text-chat-model2']);
assert.equal(fetchCount, 1, 'Should only fetch once');
});

- it('caches embedding models', async () => {
- const models1 = await provider.getEmbeddingModels();
- const models2 = await provider.getEmbeddingModels();
- assert.deepEqual(models1, ['model2']);
- assert.deepEqual(models2, ['model2']);
- assert.equal(fetchCount, 1, 'Should only fetch once');
- });
+ // todo: disabled since there's no way to get the embedding models from the API
+ // they are temporarily hardcoded in the provider
+ // it('caches embedding models', async () => {
+ // const models1 = await provider.getEmbeddingModels();
+ // const models2 = await provider.getEmbeddingModels();
+ // assert.deepEqual(models1, ['text-embed-model3']);
+ // assert.deepEqual(models2, ['text-embed-model3']);
+ // assert.equal(fetchCount, 1, 'Should only fetch once');
+ // });
});

describe('generation', () => {
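
Note: these tests stub globalThis.fetch and restore it afterwards, returning an OpenAI-style {data: [...]} payload so the provider's filtering and caching can be asserted without network access. The pattern, sketched outside any particular test runner:

    // Minimal sketch of the fetch-stubbing pattern used by these tests.
    const originalFetch = globalThis.fetch;
    let fetchCount = 0;
    globalThis.fetch = async () => {
        fetchCount++;
        return {
            ok: true,
            json: async () => ({data: [{id: 'text-chat-model1'}, {id: 'text-embed-model3'}]})
        };
    };
    // ... exercise the provider, then assert fetchCount === 1 for cached calls ...
    globalThis.fetch = originalFetch;
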
9 changes: 3 additions & 6 deletions src/ai/provider/openai.js
@@ -9,12 +9,9 @@ export class OpenAI {
#embeddingModels = null;

/**
- * @param {AIProviderConfig} config
+ * @param {AIProviderSettings} config
*/
constructor(config) {
- if (!config.apiKey) {
- throw new Error('OpenAI API key is required');
- }
this.#apiKey = config.apiKey;

if (config.baseURL) {
@@ -42,7 +39,7 @@ export class OpenAI {

const data = await response.json();
this.#chatModels = data.data
- .filter(model => model.id.startsWith('gpt-'))
+ .filter(model => !model.id.toLowerCase().includes('embed'))
.map(model => model.id);

return this.#chatModels;
@@ -68,7 +65,7 @@ export class OpenAI {

const data = await response.json();
this.#embeddingModels = data.data
- .filter(model => model.id.startsWith('text-embedding-'))
+ .filter(model => model.id.toLowerCase().includes('embed'))
.map(model => model.id);

return this.#embeddingModels;
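
Note: the chat and embedding filters now key off the model id substring rather than the gpt- and text-embedding- prefixes. A quick sketch of the effect on an illustrative /v1/models payload (ids below are examples, not taken from this diff):

    // Illustrative only; shows how the substring filter splits a model list.
    const ids = ['gpt-4o', 'gpt-3.5-turbo', 'text-embedding-3-small', 'dall-e-3'];
    const chat = ids.filter(id => !id.toLowerCase().includes('embed'));
    // => ['gpt-4o', 'gpt-3.5-turbo', 'dall-e-3'] (ids without "embed" all pass, including non-chat models)
    const embedding = ids.filter(id => id.toLowerCase().includes('embed'));
    // => ['text-embedding-3-small']
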
5 changes: 0 additions & 5 deletions src/ai/provider/openai.test.js
@@ -16,10 +16,6 @@ export default function OpenAIProviderTest({describe, it, assert, beforeEach, af
});

describe('initialization', () => {
- it('requires API key', () => {
- assert.throws(() => new OpenAI({}), /OpenAI API key is required/);
- });

it('allows custom base URL', async () => {
let capturedUrl;
globalThis.fetch = async (url, options) => {
@@ -57,7 +53,6 @@ export default function OpenAIProviderTest({describe, it, assert, beforeEach, af
{ id: 'gpt-3.5-turbo' },
{ id: 'text-embedding-3-small' },
{ id: 'text-embedding-3-large' },
- { id: 'dall-e-3' }
]
})
};
34 changes: 28 additions & 6 deletions src/app/app.js
@@ -1,7 +1,11 @@
import {Logger, LogLevels} from './logger';
- import {Store as ConversationStore} from '../conversation/store';
+ import {DocumentManager} from '../document/manager';
import {Settings} from '../settings/settings';
+ import {Store as ConversationStore} from '../conversation/store';
+ import {VectorStore} from '../document/vector_store';

import {renderChatWithAIButton} from '../ui/sidebar';
+ import {Client} from '../ai/client';

export class App {
id;
@@ -17,20 +21,38 @@ export class App {
this.version = version;

this.logger = Logger.getLogger(this.name, LogLevels.Debug);

this.settings = new Settings(ctx, this.id);
- this.conversationStore = new ConversationStore(ctx);

- ctx.Hooks.once('setup', () => this.setup());
- ctx.Hooks.once('ready', () => this.ready());
+ ctx.Hooks.once('setup', () => this.setup(ctx, id));
+ ctx.Hooks.once('ready', () => this.ready(ctx, id));
ctx.Hooks.on('renderSidebarTab', (app, html) => renderChatWithAIButton(app, html));
}

async ready() {
this.logger.debug('Version %s Ready', this.version);
}

- async setup() {
+ async setup(ctx, id) {
this.settings.registerSettings();

+ // Initialize Clients and Stores
+ const providerSettings = this.settings.getProviderSettings();
+ this.chatClient = Client.create(providerSettings.chat);
+ this.embeddingClient = Client.create(providerSettings.embedding);
+ this.conversationStore = new ConversationStore(ctx);
+ this.vectorStore = new VectorStore(ctx);
+
+ // Initialize Document Manager
+ const managerSettings = this.settings.getDocumentManagerSettings();
+ this.documentManager = new DocumentManager(ctx, managerSettings, this.embeddingClient, this.vectorStore);
+ // todo: only rebuild if necessary
+ await this.documentManager.rebuildVectorStore();
+
+ // Register model choices
+ const chatModels = await this.chatClient.getChatModels();
+ this.settings.setChoices('ChatModel', chatModels);
+
+ const embeddingModels = await this.embeddingClient.getEmbeddingModels();
+ this.settings.setChoices('EmbeddingModel', embeddingModels);
}
}
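
Note: setup() assumes Settings.getProviderSettings() returns separate chat and embedding blocks, each of which can be handed to Client.create(). A sketch of the shape this code implies, inferred from usage and not confirmed by the diff:

    // Hypothetical return value of this.settings.getProviderSettings();
    // only `provider` and `apiKey` are implied by the provider classes.
    const providerSettings = {
        chat: {provider: 'openai', apiKey: 'YOUR_API_KEY'},
        embedding: {provider: 'deepinfra', apiKey: 'YOUR_API_KEY'}
    };
    const chatClient = Client.create(providerSettings.chat);
    const embeddingClient = Client.create(providerSettings.embedding);
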
