From 543661f3f1419b4a5cb83b063aa024cd07eeac28 Mon Sep 17 00:00:00 2001 From: Brian Joseph Petro Date: Sun, 15 Dec 2024 10:31:20 -0500 Subject: [PATCH] Refactor SmartDirectory, SmartEntities, and SmartEntity classes to use async for embedding calculations (enables architecture to support external data storage via adapters) - Updated methods in SmartDirectory to be asynchronous, enhancing the handling of nearest and furthest source results. - Refactored SmartEntities and SmartEntity to ensure all nearest and connection-related methods are async, improving consistency across the codebase. - Enhanced documentation for async methods, providing clearer guidance on their usage and expected behavior. - Adjusted tests to accommodate the new async method signatures, ensuring proper functionality and reliability. --- smart-directories/components/directory.js | 4 +- smart-directories/smart_directory.js | 8 +-- smart-entities/README.md | 2 +- smart-entities/adapters/default.js | 39 ++++++++++----- smart-entities/smart_entities.js | 57 ++++++++++++---------- smart-entities/smart_entity.js | 11 +++-- smart-entities/test/env.test.js | 2 +- smart-entities/test/smart_entities.test.js | 4 +- smart-sources/smart_source.js | 7 +-- 9 files changed, 78 insertions(+), 56 deletions(-) diff --git a/smart-directories/components/directory.js b/smart-directories/components/directory.js index c12aed18..818f055c 100644 --- a/smart-directories/components/directory.js +++ b/smart-directories/components/directory.js @@ -68,8 +68,8 @@ async function render_content(directory, sources_container, subdirs_container, o subdirs_container.innerHTML = ''; const results = directory.settings.sort_nearest - ? directory.nearest_sources_results - : directory.furthest_sources_results; + ? 
await directory.get_nearest_sources_results() + : await directory.get_furthest_sources_results(); const result_frags = await render_results.call(this, results, opts); sources_container.appendChild(result_frags); } diff --git a/smart-directories/smart_directory.js b/smart-directories/smart_directory.js index 94aecacc..5e2357f4 100644 --- a/smart-directories/smart_directory.js +++ b/smart-directories/smart_directory.js @@ -95,7 +95,7 @@ export class SmartDirectory extends SmartEntity { ); } - get nearest_sources_results() { + async get_nearest_sources_results() { if(!this.median_vec) { console.log(`no median vec for directory: ${this.data.path}`); return []; @@ -103,10 +103,10 @@ export class SmartDirectory extends SmartEntity { const filter = { key_starts_with: this.data.path } - const results = this.env.smart_sources.nearest(this.median_vec, filter); + const results = await this.env.smart_sources.nearest(this.median_vec, filter); return results.sort(sort_by_score_descending); } - get furthest_sources_results() { + async get_furthest_sources_results() { if(!this.median_vec) { console.log(`no median vec for directory: ${this.data.path}`); return []; @@ -114,7 +114,7 @@ export class SmartDirectory extends SmartEntity { const filter = { key_starts_with: this.data.path } - const results = this.env.smart_sources.furthest(this.median_vec, filter); + const results = await this.env.smart_sources.furthest(this.median_vec, filter); return results.sort(sort_by_score_ascending); } diff --git a/smart-entities/README.md b/smart-entities/README.md index 545b6614..a0971b1d 100644 --- a/smart-entities/README.md +++ b/smart-entities/README.md @@ -58,7 +58,7 @@ const entity = new SmartEntity(environment, { await smartEntities.create_or_update(entity); // Find nearest neighbors -const nearestNeighbors = smartEntities.nearest(entity.vec); +const nearestNeighbors = await smartEntities.nearest(entity.vec); // Perform a lookup based on hypotheticals const results = await 
smartEntities.lookup({ diff --git a/smart-entities/adapters/default.js b/smart-entities/adapters/default.js index b129e6d7..6c43b144 100644 --- a/smart-entities/adapters/default.js +++ b/smart-entities/adapters/default.js @@ -99,6 +99,7 @@ export class DefaultEntitiesVectorAdapter extends EntitiesVectorAdapter { * @returns {Promise} */ async process_embed_queue() { + const embed_queue = this.collection.embed_queue; // Reset stats as in SmartEntities this._reset_embed_queue_stats(); @@ -113,7 +114,7 @@ export class DefaultEntitiesVectorAdapter extends EntitiesVectorAdapter { } const datetime_start = new Date(); - if (!this.collection.embed_queue.length) { + if (!embed_queue.length) { return console.log(`Smart Connections: No items in ${this.collection.collection_key} embed queue`); } @@ -121,7 +122,7 @@ export class DefaultEntitiesVectorAdapter extends EntitiesVectorAdapter { console.log(`Processing ${this.collection.collection_key} embed queue: ${embed_queue.length} items`); // Process in batches according to embed_model.batch_size - for (let i = 0; i < this.collection.embed_queue.length; i += this.collection.embed_model.batch_size) { + for (let i = 0; i < embed_queue.length; i += this.collection.embed_model.batch_size) { if (this.collection.is_queue_halted) { this.collection.is_queue_halted = false; // reset halt after break break; @@ -254,6 +255,10 @@ export class DefaultEntitiesVectorAdapter extends EntitiesVectorAdapter { this.total_tokens = 0; this.total_time = 0; } + + get notices() { + return this.collection.notices; + } } @@ -264,13 +269,16 @@ export class DefaultEntitiesVectorAdapter extends EntitiesVectorAdapter { * In-memory adapter for a single entity. Stores and retrieves vectors from item.data. */ export class DefaultEntityVectorAdapter extends EntityVectorAdapter { + get data() { + return this.item.data; + } /** * Retrieve the current vector embedding for this entity. * @async * @returns {Promise} The entity's vector or undefined if not set. 
*/ async get_vec() { - return this.item.data?.embeddings?.[this.item.embed_model_key]?.vec; + return this.vec; } /** @@ -280,13 +288,7 @@ export class DefaultEntityVectorAdapter extends EntityVectorAdapter { * @returns {Promise} */ async set_vec(vec) { - if (!this.item.data.embeddings) { - this.item.data.embeddings = {}; - } - if (!this.item.data.embeddings[this.item.embed_model_key]) { - this.item.data.embeddings[this.item.embed_model_key] = {}; - } - this.item.data.embeddings[this.item.embed_model_key].vec = vec; + this.vec = vec; } /** @@ -299,4 +301,19 @@ export class DefaultEntityVectorAdapter extends EntityVectorAdapter { delete this.item.data.embeddings[this.item.embed_model_key].vec; } } -} + + // adds synchronous get/set for vec + get vec() { + return this.item.data?.embeddings?.[this.item.embed_model_key]?.vec; + } + set vec(vec){ + if (!this.item.data.embeddings) { + this.item.data.embeddings = {}; + } + if (!this.item.data.embeddings[this.item.embed_model_key]) { + this.item.data.embeddings[this.item.embed_model_key] = {}; + } + this.item.data.embeddings[this.item.embed_model_key].vec = vec; + } + +} \ No newline at end of file diff --git a/smart-entities/smart_entities.js b/smart-entities/smart_entities.js index c7d7db62..5a9c6e42 100644 --- a/smart-entities/smart_entities.js +++ b/smart-entities/smart_entities.js @@ -30,6 +30,8 @@ export class SmartEntities extends Collection { /** @type {string|null} */ this.model_instance_id = null; + /** @type {Array} */ + this._embed_queue = []; } /** @@ -137,11 +139,12 @@ export class SmartEntities extends Collection { /** * Finds the nearest entities to a given entity. + * @async * @param {Object} entity - The reference entity. * @param {Object} [filter={}] - Optional filters to apply. * @returns {Promise>} An array of result objects with score and item. 
*/ - nearest_to(entity, filter = {}) { return this.nearest(entity.vec, filter); } + async nearest_to(entity, filter = {}) { return await this.nearest(entity.vec, filter); } /** * Finds the nearest entities to a vector using the default adapter. @@ -150,9 +153,9 @@ export class SmartEntities extends Collection { * @param {Object} [filter={}] - Optional filters to apply. * @returns {Promise>} An array of result objects with score and item. */ - nearest(vec, filter = {}) { - if (!vec) return console.log("no vec"); - return this.entities_vector_adapter.nearest(vec, filter); + async nearest(vec, filter = {}) { + if (!vec) return console.warn("nearest: no vec"); + return await this.entities_vector_adapter.nearest(vec, filter); } /** @@ -162,9 +165,9 @@ export class SmartEntities extends Collection { * @param {Object} [filter={}] - Optional filters to apply. * @returns {Promise>} An array of result objects with score and item. */ - furthest(vec, filter = {}) { - if (!vec) return console.log("no vec"); - return this.entities_vector_adapter.furthest(vec, filter); + async furthest(vec, filter = {}) { + if (!vec) return console.warn("furthest: no vec"); + return await this.entities_vector_adapter.furthest(vec, filter); } /** @@ -264,26 +267,26 @@ export class SmartEntities extends Collection { ...(this.env.chats?.current?.scope || {}), ...(params.filter || {}), }; - const results = hyp_vecs - .reduce((acc, embedding, i) => { - const results = this.nearest(embedding.vec, filter); - results.forEach(result => { - if (!acc[result.item.path] || result.score > acc[result.item.path].score) { - acc[result.item.path] = { - key: result.item.key, - score: result.score, - item: result.item, - entity: result.item, // DEPRECATED: use item instead - hypothetical_i: i, - }; - } else { - // DEPRECATED: Handling when last score added to entity is not top score - result.score = acc[result.item.path].score; - } - }); - return acc; - }, {}) - ; + const results = await hyp_vecs.reduce(async 
(acc_promise, embedding, i) => { + const acc = await acc_promise; + const results = await this.nearest(embedding.vec, filter); + results.forEach(result => { + if (!acc[result.item.path] || result.score > acc[result.item.path].score) { + acc[result.item.path] = { + key: result.item.key, + score: result.score, + item: result.item, + entity: result.item, // DEPRECATED: use item instead + hypothetical_i: i, + }; + } else { + // DEPRECATED: Handling when last score added to entity is not top score + result.score = acc[result.item.path].score; + } + }); + return acc; + }, Promise.resolve({})); + const top_k = Object.values(results) .sort(sort_by_score) .slice(0, limit) diff --git a/smart-entities/smart_entity.js b/smart-entities/smart_entity.js index e55ed058..f4febd26 100644 --- a/smart-entities/smart_entity.js +++ b/smart-entities/smart_entity.js @@ -84,7 +84,7 @@ export class SmartEntity extends CollectionItem { * @param {Object} [filter={}] - Optional filters to apply. * @returns {Array<{item:Object, score:number}>} An array of result objects with score and item. */ - nearest(filter = {}) { return this.collection.nearest_to(this, filter); } + async nearest(filter = {}) { return await this.collection.nearest_to(this, filter); } /** * Prepares the input for embedding. @@ -112,10 +112,11 @@ export class SmartEntity extends CollectionItem { /** * Finds connections relevant to this entity based on provided parameters. + * @async * @param {Object} [params={}] - Parameters for finding connections. * @returns {Array<{item:Object, score:number}>} An array of result objects with score and item. 
*/ - find_connections(params = {}) { + async find_connections(params = {}) { const filter_opts = this.prepare_find_connections_filter_opts(params); const limit = params.filter?.limit || params.limit // DEPRECATED: for backwards compatibility @@ -124,7 +125,7 @@ export class SmartEntity extends CollectionItem { const cache_key = this.key + JSON.stringify(params); // no objects/instances in cache key if (!this.env.connections_cache) this.env.connections_cache = {}; if (!this.env.connections_cache[cache_key]) { - const connections = this.nearest(filter_opts) + const connections = (await this.nearest(filter_opts)) .sort(sort_by_score) .slice(0, limit); this.connections_to_cache(cache_key, connections); @@ -239,14 +240,14 @@ export class SmartEntity extends CollectionItem { * @readonly * @returns {Array|undefined} The vector or undefined if not set. */ - get vec() { return this.entity_adapter.get_vec(); } + get vec() { return this.entity_adapter.vec; } /** * Sets the vector representation in the entity adapter. * @param {Array} vec - The vector to set. 
*/ set vec(vec) { - this.entity_adapter.set_vec(vec); + this.entity_adapter.vec = vec; this._queue_embed = false; this._embed_input = null; this.queue_save(); diff --git a/smart-entities/test/env.test.js b/smart-entities/test/env.test.js index 836d039b..5dd026da 100644 --- a/smart-entities/test/env.test.js +++ b/smart-entities/test/env.test.js @@ -80,7 +80,7 @@ test('SmartEntity methods work correctly', async t => { t.truthy(entity.nearest); t.truthy(entity.find_connections); - const connections = entity.find_connections(); + const connections = await entity.find_connections(); t.true(Array.isArray(connections)); }); diff --git a/smart-entities/test/smart_entities.test.js b/smart-entities/test/smart_entities.test.js index 0f3d76e8..49037427 100644 --- a/smart-entities/test/smart_entities.test.js +++ b/smart-entities/test/smart_entities.test.js @@ -23,7 +23,7 @@ test.serial('SmartEntities nearest', async t => { const { env } = t.context; await Promise.all(Object.values(test_data).map(entity_data => env.smart_entities.create_or_update(entity_data))); - const nearest = env.smart_entities.nearest([0.1, 0.2, 0.3]); + const nearest = await env.smart_entities.nearest([0.1, 0.2, 0.3]); t.is(nearest.length, 3, 'Should return all entities'); t.is(nearest[0].path, 'test1', 'Nearest entity should be entity1'); }); @@ -41,7 +41,7 @@ test.serial('SmartEntity find_connections', async t => { await Promise.all(Object.values(test_data).map(entity_data => env.smart_entities.create_or_update(entity_data))); const entity = env.smart_entities.get('test1'); - const connections = entity.find_connections(); + const connections = await entity.find_connections(); t.is(connections.length, 2, 'Should return 2 connections'); t.is(connections[0].path, 'test2', 'First connection should be entity2'); }); diff --git a/smart-sources/smart_source.js b/smart-sources/smart_source.js index 598ea657..5ecdaf9c 100644 --- a/smart-sources/smart_source.js +++ b/smart-sources/smart_source.js @@ -78,15 +78,16 
@@ export class SmartSource extends SmartEntity { /** * Finds connections relevant to this SmartSource based on provided parameters. + * @async * @param {Object} [params={}] - Parameters for finding connections. * @param {boolean} [params.exclude_source_connections=false] - Whether to exclude source connections. * @param {boolean} [params.exclude_blocks_from_source_connections=false] - Whether to exclude block connections from source connections. * @returns {Array} An array of relevant SmartSource entities. */ - find_connections(params={}) { + async find_connections(params={}) { let connections; if(this.block_collection.settings.embed_blocks && params.exclude_source_connections) connections = []; - else connections = super.find_connections(params); + else connections = await super.find_connections(params); const filter_opts = this.prepare_find_connections_filter_opts(params); const limit = params.filter?.limit || params.limit // DEPRECATED: for backwards compatibility @@ -99,7 +100,7 @@ export class SmartSource extends SmartEntity { const cache_key = this.key + JSON.stringify(params) + "_blocks"; if(!this.env.connections_cache) this.env.connections_cache = {}; if(!this.env.connections_cache[cache_key]){ - const nearest = this.env.smart_blocks.nearest(this.vec, filter_opts) + const nearest = (await this.env.smart_blocks.nearest(this.vec, filter_opts)) .sort(sort_by_score) .slice(0, limit) ;