-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c635a7d
commit 9b8110e
Showing
5 changed files
with
143 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
const { Adapter } = require("./adapter"); | ||
// load Worker from worker_threads if environment is node, otherwise load Worker from window | ||
const { Worker } = (typeof window !== 'undefined') ? window : require('worker_threads'); | ||
|
||
class WorkerAdapter extends Adapter { | ||
// initiates a worker containing SmartRankerModel instance | ||
// rank function posts a message to the worker with the query and documents | ||
// and waits for the worker to return the ranked documents | ||
async init(){ | ||
if(!this.worker){ | ||
this.worker = new Worker("./worker_model.js"); | ||
// send config to worker | ||
this.worker.postMessage({ | ||
type: "config", | ||
config: this.main.config.worker_config | ||
}); | ||
// wait for worker to be ready | ||
await new Promise((resolve) => { | ||
this.worker.on('message', (data) => { | ||
if(data.type === "ready"){ | ||
resolve(); | ||
} | ||
}); | ||
}); | ||
} | ||
} | ||
async rank(query, documents){ | ||
// send query and documents to worker | ||
this.worker.postMessage({ | ||
type: "rank", | ||
query: query, | ||
documents: documents | ||
}); | ||
// wait for worker to return ranked documents | ||
const ranked_documents = await new Promise((resolve) => { | ||
this.worker.on('message', (data) => { | ||
if(data.type === "ranked_documents"){ | ||
resolve(data.ranked_documents); | ||
} | ||
}); | ||
}); | ||
return ranked_documents; | ||
} | ||
} | ||
|
||
exports.WorkerAdapter = WorkerAdapter; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
const test = require('ava'); | ||
const { WorkerAdapter } = require('./worker'); | ||
|
||
test.beforeEach(t => { | ||
t.context = { | ||
main: { | ||
config: { | ||
worker_config: { | ||
adapter: 'Transformers', | ||
model_key: 'Xenova/bge-reranker-base', | ||
quantized: true | ||
} | ||
} | ||
} | ||
}; | ||
}); | ||
|
||
test.serial('WorkerAdapter initializes worker and sends config', async t => { | ||
const workerAdapter = new WorkerAdapter(t.context.main); | ||
await workerAdapter.init(); | ||
|
||
// Mock the Worker class | ||
const worker = workerAdapter.worker; | ||
worker.postMessage = (message) => { | ||
t.is(message.type, 'config'); | ||
t.deepEqual(message.config, t.context.main.config.worker_config); | ||
}; | ||
|
||
|
||
await workerAdapter.rank('query', ['doc1', 'doc2']); | ||
t.pass(); | ||
}); | ||
|
||
test.serial('WorkerAdapter sends rank message and receives ranked documents', async t => { | ||
const workerAdapter = new WorkerAdapter(t.context.main); | ||
await workerAdapter.init(); | ||
|
||
// Mock the Worker class | ||
const worker = workerAdapter.worker; | ||
worker.postMessage = (message) => { | ||
if (message.type === 'rank') { | ||
t.is(message.query, 'query'); | ||
t.deepEqual(message.documents, ['doc1', 'doc2']); | ||
} | ||
}; | ||
|
||
|
||
const ranked_documents = await workerAdapter.rank('query', ['doc1', 'doc2']); | ||
t.deepEqual(ranked_documents, ['doc2', 'doc1']); | ||
}); | ||
|
||
test.serial('WorkerAdapter handles worker initialization only once', async t => { | ||
const workerAdapter = new WorkerAdapter(t.context.main); | ||
await workerAdapter.init(); | ||
|
||
// Mock the Worker class | ||
const worker = workerAdapter.worker; | ||
let init_count = 0; | ||
worker.postMessage = (message) => { | ||
if (message.type === 'config') { | ||
init_count++; | ||
} | ||
}; | ||
|
||
|
||
await workerAdapter.rank('query', ['doc1', 'doc2']); | ||
await workerAdapter.rank('query', ['doc1', 'doc2']); | ||
t.is(init_count, 1); | ||
}); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
const { SmartRankerModel } = require('./smart_ranker_model'); | ||
|
||
// get config from message | ||
let model; | ||
onmessage = async (e) => { | ||
if (e.type === "config") { | ||
model = new SmartRankerModel({}, e.config); | ||
await model.init(); | ||
postMessage({ | ||
type: "ready" | ||
}); | ||
} | ||
if (e.type === "rank") { | ||
const ranked_documents = await model.rank(e.query, e.documents); | ||
postMessage({ | ||
type: "ranked_documents", | ||
ranked_documents: ranked_documents | ||
}); | ||
} | ||
}; |