forked from chroma-core/chroma
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Add push based operators, centralized dispatch, hardcode query …
…plan as state machine (chroma-core#1888) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Change on_start to async - Move on_start into the executor. - New functionality - This PR adds operators to the query/compaction workers - Implements a PullLog operator - Adds a HnswQueryOrchestrator, which can be thought of as a passive state machine of a hardcoded query plan. - Adds a dispatcher with worker threads for scheduling tasks. Worker threads pull for tasks. Pending work I will address in following prs: - [ ] Make the callers of dispatch unaware of wrap() - [ ] PullLogs should poll until completion @Ishiihara is taking this - [ ] Error handling - [ ] Wrap the orchestrator - [ ] Add a server struct and have it create the aforementioned orchestrator wrapper and then push to it. - [ ] Impl configurable ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes None
- Loading branch information
Showing
20 changed files
with
684 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
use super::{operator::TaskMessage, worker_thread::WorkerThread};
use crate::system::{Component, ComponentContext, Handler, Receiver, System};
use async_trait::async_trait;
use std::collections::VecDeque;
use std::fmt::Debug;
|
||
/// The dispatcher is responsible for distributing tasks to worker threads. | ||
/// It is a component that receives tasks and distributes them to worker threads. | ||
/** | ||
```plaintext | ||
┌─────────────────────────────────────────┐ | ||
│ │ | ||
│ │ | ||
│ │ | ||
TaskMessage ───────────►├─────┐ Dispatcher │ | ||
│ ▼ │ | ||
│ ┌┬───────────────────────────────┐ │ | ||
│ │┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼│ │ | ||
└────┴──────────────┴─────────────────┴───┘ | ||
▲ | ||
│ │ | ||
│ │ | ||
TaskRequestMessage │ │ TaskMessage | ||
│ │ | ||
│ │ | ||
▼ | ||
┌────────────────┐ ┌────────────────┐ ┌────────────────┐ | ||
│ │ │ │ │ │ | ||
│ │ │ │ │ │ | ||
│ │ │ │ │ │ | ||
│ Worker │ │ Worker │ │ Worker │ | ||
│ │ │ │ │ │ | ||
│ │ │ │ │ │ | ||
│ │ │ │ │ │ | ||
└────────────────┘ └────────────────┘ └────────────────┘ | ||
``` | ||
## Implementation notes | ||
- The dispatcher has a queue of tasks that it distributes to worker threads | ||
- A worker thread sends a TaskRequestMessage to the dispatcher when it is ready for a new task | ||
- If no task is available for the worker thread, the dispatcher will place that worker's receiver | ||
in a queue and send a task to the worker when it receives another one | ||
- The reason to introduce this abstraction is to allow us to control fairness and dynamically adjust | ||
system utilization. It also makes mechanisms like pausing/stopping work easier. | ||
It would have likely been more performant to use the Tokio MT runtime, but we chose to use | ||
this abstraction to grant us flexibility. We can always switch to Tokio MT later if we need to, | ||
or make this dispatcher much more performant through implementing memory-awareness, task-batches, | ||
coarser work-stealing, and other optimizations. | ||
*/ | ||
#[derive(Debug)] | ||
struct Dispatcher { | ||
task_queue: Vec<TaskMessage>, | ||
waiters: Vec<TaskRequestMessage>, | ||
n_worker_threads: usize, | ||
} | ||
|
||
impl Dispatcher { | ||
/// Create a new dispatcher | ||
/// # Parameters | ||
/// - n_worker_threads: The number of worker threads to use | ||
pub fn new(n_worker_threads: usize) -> Self { | ||
Dispatcher { | ||
task_queue: Vec::new(), | ||
waiters: Vec::new(), | ||
n_worker_threads, | ||
} | ||
} | ||
|
||
/// Spawn worker threads | ||
/// # Parameters | ||
/// - system: The system to spawn the worker threads in | ||
/// - self_receiver: The receiver to send tasks to the worker threads, this is a address back to the dispatcher | ||
fn spawn_workers( | ||
&self, | ||
system: &mut System, | ||
self_receiver: Box<dyn Receiver<TaskRequestMessage>>, | ||
) { | ||
for _ in 0..self.n_worker_threads { | ||
let worker = WorkerThread::new(self_receiver.clone()); | ||
system.start_component(worker); | ||
} | ||
} | ||
|
||
/// Enqueue a task to be processed | ||
/// # Parameters | ||
/// - task: The task to enqueue | ||
async fn enqueue_task(&mut self, task: TaskMessage) { | ||
// If a worker is waiting for a task, send it to the worker in FIFO order | ||
// Otherwise, add it to the task queue | ||
match self.waiters.pop() { | ||
Some(channel) => match channel.reply_to.send(task).await { | ||
Ok(_) => {} | ||
Err(e) => { | ||
println!("Error sending task to worker: {:?}", e); | ||
} | ||
}, | ||
None => { | ||
self.task_queue.push(task); | ||
} | ||
} | ||
} | ||
|
||
/// Handle a work request from a worker thread | ||
/// # Parameters | ||
/// - worker: The request for work | ||
/// If no work is available, the worker will be placed in a queue and a task will be sent to it | ||
/// when one is available | ||
async fn handle_work_request(&mut self, request: TaskRequestMessage) { | ||
match self.task_queue.pop() { | ||
Some(task) => match request.reply_to.send(task).await { | ||
Ok(_) => {} | ||
Err(e) => { | ||
println!("Error sending task to worker: {:?}", e); | ||
} | ||
}, | ||
None => { | ||
self.waiters.push(request); | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// A message that a worker thread sends to the dispatcher to request a task | ||
/// # Members | ||
/// - reply_to: The receiver to send the task to, this is the worker thread | ||
#[derive(Debug)] | ||
pub(super) struct TaskRequestMessage { | ||
reply_to: Box<dyn Receiver<TaskMessage>>, | ||
} | ||
|
||
impl TaskRequestMessage { | ||
/// Create a new TaskRequestMessage | ||
/// # Parameters | ||
/// - reply_to: The receiver to send the task to, this is the worker thread | ||
/// that is requesting the task | ||
pub(super) fn new(reply_to: Box<dyn Receiver<TaskMessage>>) -> Self { | ||
TaskRequestMessage { reply_to } | ||
} | ||
} | ||
|
||
// ============= Component implementation ============= | ||
|
||
#[async_trait] | ||
impl Component for Dispatcher { | ||
fn queue_size(&self) -> usize { | ||
1000 // TODO: make configurable | ||
} | ||
|
||
async fn on_start(&mut self, ctx: &ComponentContext<Self>) { | ||
self.spawn_workers(&mut ctx.system.clone(), ctx.sender.as_receiver()); | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl Handler<TaskMessage> for Dispatcher { | ||
async fn handle(&mut self, task: TaskMessage, _ctx: &ComponentContext<Dispatcher>) { | ||
self.enqueue_task(task).await; | ||
} | ||
} | ||
|
||
// Worker sends a request for task | ||
#[async_trait] | ||
impl Handler<TaskRequestMessage> for Dispatcher { | ||
async fn handle(&mut self, message: TaskRequestMessage, _ctx: &ComponentContext<Dispatcher>) { | ||
self.handle_work_request(message).await; | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use super::*;
    use crate::{
        execution::operator::{wrap, Operator},
        system::System,
    };

    // Create a component that will schedule DISPATCH_COUNT invocations of the MockOperator
    // on an interval of DISPATCH_FREQUENCY_MS.
    // Each invocation will sleep for MOCK_OPERATOR_SLEEP_DURATION_MS to simulate work
    // Use THREAD_COUNT worker threads
    const MOCK_OPERATOR_SLEEP_DURATION_MS: u64 = 100;
    const DISPATCH_FREQUENCY_MS: u64 = 5;
    const DISPATCH_COUNT: usize = 50;
    const THREAD_COUNT: usize = 4;

    /// Operator that sleeps to simulate work and then returns its input as a string.
    #[derive(Debug)]
    struct MockOperator {}
    #[async_trait]
    impl Operator<f32, String> for MockOperator {
        async fn run(&self, input: &f32) -> String {
            // sleep to simulate work
            tokio::time::sleep(tokio::time::Duration::from_millis(
                MOCK_OPERATOR_SLEEP_DURATION_MS,
            ))
            .await;
            input.to_string()
        }
    }

    /// Component that pushes tasks to the dispatcher and counts the results it gets back.
    #[derive(Debug)]
    struct MockDispatchUser {
        pub dispatcher: Box<dyn Receiver<TaskMessage>>,
        counter: Arc<AtomicUsize>, // We expect to receive DISPATCH_COUNT messages
    }
    #[async_trait]
    impl Component for MockDispatchUser {
        fn queue_size(&self) -> usize {
            1000
        }

        async fn on_start(&mut self, ctx: &ComponentContext<Self>) {
            // dispatch a new task every DISPATCH_FREQUENCY_MS for DISPATCH_COUNT times
            let duration = std::time::Duration::from_millis(DISPATCH_FREQUENCY_MS);
            ctx.scheduler.schedule_interval(
                ctx.sender.clone(),
                (),
                duration,
                Some(DISPATCH_COUNT),
                ctx,
            );
        }
    }

    // Receives the output of a finished MockOperator task
    #[async_trait]
    impl Handler<String> for MockDispatchUser {
        async fn handle(&mut self, _message: String, ctx: &ComponentContext<MockDispatchUser>) {
            // fetch_add returns the previous value, so +1 is the count including this message
            let curr_count = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            // Cancel self
            if curr_count == DISPATCH_COUNT {
                ctx.cancellation_token.cancel();
            }
        }
    }

    // Fired by the scheduler; each tick dispatches one task
    #[async_trait]
    impl Handler<()> for MockDispatchUser {
        async fn handle(&mut self, _message: (), ctx: &ComponentContext<MockDispatchUser>) {
            let task = wrap(Box::new(MockOperator {}), 42.0, ctx.sender.as_receiver());
            // A lost task means the counter can never reach DISPATCH_COUNT, so fail loudly
            // instead of silently dropping the error.
            if let Err(e) = self.dispatcher.send(task).await {
                panic!("Failed to send task to dispatcher: {:?}", e);
            }
        }
    }

    #[tokio::test]
    async fn test_dispatcher() {
        let mut system = System::new();
        let dispatcher = Dispatcher::new(THREAD_COUNT);
        let dispatcher_handle = system.start_component(dispatcher);
        let counter = Arc::new(AtomicUsize::new(0));
        let dispatch_user = MockDispatchUser {
            dispatcher: dispatcher_handle.receiver(),
            counter: counter.clone(),
        };
        let mut dispatch_user_handle = system.start_component(dispatch_user);
        // yield to allow the component to process the messages
        tokio::task::yield_now().await;
        // Join on the dispatch user, since it will kill itself after DISPATCH_COUNT messages
        dispatch_user_handle.join().await;
        // We should have received DISPATCH_COUNT messages
        assert_eq!(counter.load(Ordering::SeqCst), DISPATCH_COUNT);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mod dispatcher; | ||
mod operator; | ||
mod operators; | ||
mod orchestration; | ||
mod worker_thread; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
use crate::system::Receiver; | ||
use async_trait::async_trait; | ||
use std::fmt::Debug; | ||
|
||
/// An operator takes a generic input and returns a generic output. | ||
/// It is a definition of a function. | ||
#[async_trait] | ||
pub(super) trait Operator<I, O>: Send + Sync + Debug | ||
where | ||
I: Send + Sync, | ||
O: Send + Sync, | ||
{ | ||
async fn run(&self, input: &I) -> O; | ||
} | ||
|
||
/// A task is a wrapper around an operator and its input. | ||
/// It is a description of a function to be run. | ||
#[derive(Debug)] | ||
struct Task<Input, Output> | ||
where | ||
Input: Send + Sync + Debug, | ||
Output: Send + Sync + Debug, | ||
{ | ||
operator: Box<dyn Operator<Input, Output>>, | ||
input: Input, | ||
reply_channel: Box<dyn Receiver<Output>>, | ||
} | ||
|
||
/// A message type used by the dispatcher to send tasks to worker threads. | ||
pub(super) type TaskMessage = Box<dyn TaskWrapper>; | ||
|
||
/// A task wrapper is a trait that can be used to run a task. We use it to | ||
/// erase the I, O types from the Task struct so that tasks. | ||
#[async_trait] | ||
pub(super) trait TaskWrapper: Send + Debug { | ||
async fn run(&self); | ||
} | ||
|
||
/// Implement the TaskWrapper trait for every Task. This allows us to | ||
/// erase the I, O types from the Task struct so that tasks can be | ||
/// stored in a homogenous queue regardless of their input and output types. | ||
#[async_trait] | ||
impl<Input, Output> TaskWrapper for Task<Input, Output> | ||
where | ||
Input: Send + Sync + Debug, | ||
Output: Send + Sync + Debug, | ||
{ | ||
async fn run(&self) { | ||
let output = self.operator.run(&self.input).await; | ||
let res = self.reply_channel.send(output).await; | ||
// TODO: if this errors, it means the caller was dropped | ||
} | ||
} | ||
|
||
/// Wrap an operator and its input into a task message. | ||
pub(super) fn wrap<Input, Output>( | ||
operator: Box<dyn Operator<Input, Output>>, | ||
input: Input, | ||
reply_channel: Box<dyn Receiver<Output>>, | ||
) -> TaskMessage | ||
where | ||
Input: Send + Sync + Debug + 'static, | ||
Output: Send + Sync + Debug + 'static, | ||
{ | ||
Box::new(Task { | ||
operator, | ||
input, | ||
reply_channel, | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pub(super) mod pull_log; |
Oops, something went wrong.