Skip to content

Commit

Permalink
work on workers
Browse files Browse the repository at this point in the history
  • Loading branch information
brianpetro committed May 20, 2024
1 parent c635a7d commit 9b8110e
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 1 deletion.
1 change: 1 addition & 0 deletions smart-ranker-model/adapters/api.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ApiAdapter extends Adapter {
console.log(error);
return null;
}
async rank(query, documents){ /* OVERRIDE */ }
}
exports.ApiAdapter = ApiAdapter;

46 changes: 46 additions & 0 deletions smart-ranker-model/adapters/worker.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
const { Adapter } = require("./adapter");
// load Worker from worker_threads if environment is node, otherwise load Worker from window
const { Worker } = (typeof window !== 'undefined') ? window : require('worker_threads');

class WorkerAdapter extends Adapter {
// initiates a worker containing SmartRankerModel instance
// rank function posts a message to the worker with the query and documents
// and waits for the worker to return the ranked documents
async init(){
if(!this.worker){
this.worker = new Worker("./worker_model.js");
// send config to worker
this.worker.postMessage({
type: "config",
config: this.main.config.worker_config
});
// wait for worker to be ready
await new Promise((resolve) => {
this.worker.on('message', (data) => {
if(data.type === "ready"){
resolve();
}
});
});
}
}
async rank(query, documents){
// send query and documents to worker
this.worker.postMessage({
type: "rank",
query: query,
documents: documents
});
// wait for worker to return ranked documents
const ranked_documents = await new Promise((resolve) => {
this.worker.on('message', (data) => {
if(data.type === "ranked_documents"){
resolve(data.ranked_documents);
}
});
});
return ranked_documents;
}
}

exports.WorkerAdapter = WorkerAdapter;
70 changes: 70 additions & 0 deletions smart-ranker-model/adapters/worker.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
const test = require('ava');
const { WorkerAdapter } = require('./worker');

test.beforeEach(t => {
t.context = {
main: {
config: {
worker_config: {
adapter: 'Transformers',
model_key: 'Xenova/bge-reranker-base',
quantized: true
}
}
}
};
});

test.serial('WorkerAdapter initializes worker and sends config', async t => {
const workerAdapter = new WorkerAdapter(t.context.main);
await workerAdapter.init();

// Mock the Worker class
const worker = workerAdapter.worker;
worker.postMessage = (message) => {
t.is(message.type, 'config');
t.deepEqual(message.config, t.context.main.config.worker_config);
};


await workerAdapter.rank('query', ['doc1', 'doc2']);
t.pass();
});

test.serial('WorkerAdapter sends rank message and receives ranked documents', async t => {
const workerAdapter = new WorkerAdapter(t.context.main);
await workerAdapter.init();

// Mock the Worker class
const worker = workerAdapter.worker;
worker.postMessage = (message) => {
if (message.type === 'rank') {
t.is(message.query, 'query');
t.deepEqual(message.documents, ['doc1', 'doc2']);
}
};


const ranked_documents = await workerAdapter.rank('query', ['doc1', 'doc2']);
t.deepEqual(ranked_documents, ['doc2', 'doc1']);
});

test.serial('WorkerAdapter handles worker initialization only once', async t => {
const workerAdapter = new WorkerAdapter(t.context.main);
await workerAdapter.init();

// Mock the Worker class
const worker = workerAdapter.worker;
let init_count = 0;
worker.postMessage = (message) => {
if (message.type === 'config') {
init_count++;
}
};


await workerAdapter.rank('query', ['doc1', 'doc2']);
await workerAdapter.rank('query', ['doc1', 'doc2']);
t.is(init_count, 1);
});

7 changes: 6 additions & 1 deletion smart-ranker-model/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"description": "Universal Re-Rankers Interface",
"main": "smart_ranker_model.js",
"scripts": {
"test": "npx ava --verbose test.js"
"test": "npx ava --verbose"
},
"keywords": [
"embeddings",
Expand All @@ -29,5 +29,10 @@
},
"dependencies": {
"@xenova/transformers": "latest"
},
"ava": {
"files": [
"**/*.test.js"
]
}
}
20 changes: 20 additions & 0 deletions smart-ranker-model/worker_model.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
const { SmartRankerModel } = require('./smart_ranker_model');

// get config from message
let model;
onmessage = async (e) => {
if (e.type === "config") {
model = new SmartRankerModel({}, e.config);
await model.init();
postMessage({
type: "ready"
});
}
if (e.type === "rank") {
const ranked_documents = await model.rank(e.query, e.documents);
postMessage({
type: "ranked_documents",
ranked_documents: ranked_documents
});
}
};

0 comments on commit 9b8110e

Please sign in to comment.