Skip to content

Commit

Permalink
[ENH] Add push based operators, centralized dispatch, hardcode query …
Browse files Browse the repository at this point in the history
…plan as state machine (chroma-core#1888)

## Description of changes
*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Change on_start to async
	 - Move on_start into the executor.
 - New functionality
	 - This PR adds operators to the query/compaction workers 
	 - Implements a PullLog operator
- Adds a HnswQueryOrchestrator, which can be thought of as a passive
state machine of a hardcoded query plan.
- Adds a dispatcher with worker threads for scheduling tasks. Worker
threads pull for tasks.
	 
Pending work I will address in following prs:

- [ ] Make the callers of dispatch unaware of wrap()
- [ ] PullLogs should poll until completion @Ishiihara is taking this
- [ ] Error handling
- [ ] Wrap the orchestrator
- [ ] Add a server struct and have it create the aforementioned
orchestrator wrapper and then push to it.
- [ ] Impl configurable 
	 
## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Mar 19, 2024
1 parent db4caba commit 93a659d
Show file tree
Hide file tree
Showing 20 changed files with 684 additions and 40 deletions.
5 changes: 3 additions & 2 deletions rust/worker/src/compactor/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ impl Scheduler {
}
}

#[async_trait]
impl Component for Scheduler {
fn on_start(&mut self, ctx: &ComponentContext<Self>) {
async fn on_start(&mut self, ctx: &ComponentContext<Self>) {
ctx.scheduler.schedule_interval(
ctx.sender.clone(),
ScheduleMessage {},
Expand Down Expand Up @@ -186,7 +187,7 @@ mod tests {
use std::time::Duration;
use uuid::Uuid;

#[derive(Clone)]
#[derive(Clone, Debug)]
pub(crate) struct TestSysDb {
collections: HashMap<Uuid, Collection>,
}
Expand Down
267 changes: 267 additions & 0 deletions rust/worker/src/execution/dispatcher.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
use super::{operator::TaskMessage, worker_thread::WorkerThread};
use crate::system::{Component, ComponentContext, Handler, Receiver, System};
use async_trait::async_trait;
use std::fmt::Debug;

/// The dispatcher is responsible for distributing tasks to worker threads.
/// It is a component that receives tasks and distributes them to worker threads.
/**
```plaintext
┌─────────────────────────────────────────┐
│ │
│ │
│ │
TaskMessage ───────────►├─────┐ Dispatcher │
│ ▼ │
│ ┌┬───────────────────────────────┐ │
│ │┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼┼│ │
└────┴──────────────┴─────────────────┴───┘
│ │
│ │
TaskRequestMessage │ │ TaskMessage
│ │
│ │
┌────────────────┐ ┌────────────────┐ ┌────────────────┐
│ │ │ │ │ │
│ │ │ │ │ │
│ │ │ │ │ │
│ Worker │ │ Worker │ │ Worker │
│ │ │ │ │ │
│ │ │ │ │ │
│ │ │ │ │ │
└────────────────┘ └────────────────┘ └────────────────┘
```
## Implementation notes
- The dispatcher has a queue of tasks that it distributes to worker threads
- A worker thread sends a TaskRequestMessage to the dispatcher when it is ready for a new task
- If no task is available for the worker thread, the dispatcher will place that worker's reciever
in a queue and send a task to the worker when it recieves another one
- The reason to introduce this abstraction is to allow us to control fairness and dynamically adjust
system utilization. It also makes mechanisms like pausing/stopping work easier.
It would have likely been more performant to use the Tokio MT runtime, but we chose to use
this abstraction to grant us flexibility. We can always switch to Tokio MT later if we need to,
or make this dispatcher much more performant through implementing memory-awareness, task-batches,
coarser work-stealing, and other optimizations.
*/
#[derive(Debug)]
struct Dispatcher {
task_queue: Vec<TaskMessage>,
waiters: Vec<TaskRequestMessage>,
n_worker_threads: usize,
}

impl Dispatcher {
/// Create a new dispatcher
/// # Parameters
/// - n_worker_threads: The number of worker threads to use
pub fn new(n_worker_threads: usize) -> Self {
Dispatcher {
task_queue: Vec::new(),
waiters: Vec::new(),
n_worker_threads,
}
}

/// Spawn worker threads
/// # Parameters
/// - system: The system to spawn the worker threads in
/// - self_receiver: The receiver to send tasks to the worker threads, this is a address back to the dispatcher
fn spawn_workers(
&self,
system: &mut System,
self_receiver: Box<dyn Receiver<TaskRequestMessage>>,
) {
for _ in 0..self.n_worker_threads {
let worker = WorkerThread::new(self_receiver.clone());
system.start_component(worker);
}
}

/// Enqueue a task to be processed
/// # Parameters
/// - task: The task to enqueue
async fn enqueue_task(&mut self, task: TaskMessage) {
// If a worker is waiting for a task, send it to the worker in FIFO order
// Otherwise, add it to the task queue
match self.waiters.pop() {
Some(channel) => match channel.reply_to.send(task).await {
Ok(_) => {}
Err(e) => {
println!("Error sending task to worker: {:?}", e);
}
},
None => {
self.task_queue.push(task);
}
}
}

/// Handle a work request from a worker thread
/// # Parameters
/// - worker: The request for work
/// If no work is available, the worker will be placed in a queue and a task will be sent to it
/// when one is available
async fn handle_work_request(&mut self, request: TaskRequestMessage) {
match self.task_queue.pop() {
Some(task) => match request.reply_to.send(task).await {
Ok(_) => {}
Err(e) => {
println!("Error sending task to worker: {:?}", e);
}
},
None => {
self.waiters.push(request);
}
}
}
}

/// A message that a worker thread sends to the dispatcher to request a task
/// # Members
/// - reply_to: The receiver to send the task to, this is the worker thread
#[derive(Debug)]
pub(super) struct TaskRequestMessage {
reply_to: Box<dyn Receiver<TaskMessage>>,
}

impl TaskRequestMessage {
/// Create a new TaskRequestMessage
/// # Parameters
/// - reply_to: The receiver to send the task to, this is the worker thread
/// that is requesting the task
pub(super) fn new(reply_to: Box<dyn Receiver<TaskMessage>>) -> Self {
TaskRequestMessage { reply_to }
}
}

// ============= Component implementation =============

#[async_trait]
impl Component for Dispatcher {
fn queue_size(&self) -> usize {
1000 // TODO: make configurable
}

async fn on_start(&mut self, ctx: &ComponentContext<Self>) {
self.spawn_workers(&mut ctx.system.clone(), ctx.sender.as_receiver());
}
}

#[async_trait]
impl Handler<TaskMessage> for Dispatcher {
async fn handle(&mut self, task: TaskMessage, _ctx: &ComponentContext<Dispatcher>) {
self.enqueue_task(task).await;
}
}

// Worker sends a request for task
#[async_trait]
impl Handler<TaskRequestMessage> for Dispatcher {
async fn handle(&mut self, message: TaskRequestMessage, _ctx: &ComponentContext<Dispatcher>) {
self.handle_work_request(message).await;
}
}

#[cfg(test)]
mod tests {
use std::{
env::current_dir,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};

use super::*;
use crate::{
execution::operator::{wrap, Operator},
system::System,
};

// Create a component that will schedule DISPATCH_COUNT invocations of the MockOperator
// on an interval of DISPATCH_FREQUENCY_MS.
// Each invocation will sleep for MOCK_OPERATOR_SLEEP_DURATION_MS to simulate work
// Use THREAD_COUNT worker threads
const MOCK_OPERATOR_SLEEP_DURATION_MS: u64 = 100;
const DISPATCH_FREQUENCY_MS: u64 = 5;
const DISPATCH_COUNT: usize = 50;
const THREAD_COUNT: usize = 4;

#[derive(Debug)]
struct MockOperator {}
#[async_trait]
impl Operator<f32, String> for MockOperator {
async fn run(&self, input: &f32) -> String {
// sleep to simulate work
tokio::time::sleep(tokio::time::Duration::from_millis(
MOCK_OPERATOR_SLEEP_DURATION_MS,
))
.await;
input.to_string()
}
}

#[derive(Debug)]
struct MockDispatchUser {
pub dispatcher: Box<dyn Receiver<TaskMessage>>,
counter: Arc<AtomicUsize>, // We expect to recieve DISPATCH_COUNT messages
}
#[async_trait]
impl Component for MockDispatchUser {
fn queue_size(&self) -> usize {
1000
}

async fn on_start(&mut self, ctx: &ComponentContext<Self>) {
// dispatch a new task every DISPATCH_FREQUENCY_MS for DISPATCH_COUNT times
let duration = std::time::Duration::from_millis(DISPATCH_FREQUENCY_MS);
ctx.scheduler.schedule_interval(
ctx.sender.clone(),
(),
duration,
Some(DISPATCH_COUNT),
ctx,
);
}
}
#[async_trait]
impl Handler<String> for MockDispatchUser {
async fn handle(&mut self, message: String, ctx: &ComponentContext<MockDispatchUser>) {
self.counter.fetch_add(1, Ordering::SeqCst);
let curr_count = self.counter.load(Ordering::SeqCst);
// Cancel self
if curr_count == DISPATCH_COUNT {
ctx.cancellation_token.cancel();
}
}
}

#[async_trait]
impl Handler<()> for MockDispatchUser {
async fn handle(&mut self, message: (), ctx: &ComponentContext<MockDispatchUser>) {
let task = wrap(Box::new(MockOperator {}), 42.0, ctx.sender.as_receiver());
let res = self.dispatcher.send(task).await;
}
}

#[tokio::test]
async fn test_dispatcher() {
let mut system = System::new();
let dispatcher = Dispatcher::new(THREAD_COUNT);
let dispatcher_handle = system.start_component(dispatcher);
let counter = Arc::new(AtomicUsize::new(0));
let dispatch_user = MockDispatchUser {
dispatcher: dispatcher_handle.receiver(),
counter: counter.clone(),
};
let mut dispatch_user_handle = system.start_component(dispatch_user);
// yield to allow the component to process the messages
tokio::task::yield_now().await;
// Join on the dispatch user, since it will kill itself after DISPATCH_COUNT messages
dispatch_user_handle.join().await;
// We should have received DISPATCH_COUNT messages
assert_eq!(counter.load(Ordering::SeqCst), DISPATCH_COUNT);
}
}
5 changes: 5 additions & 0 deletions rust/worker/src/execution/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod dispatcher;
mod operator;
mod operators;
mod orchestration;
mod worker_thread;
70 changes: 70 additions & 0 deletions rust/worker/src/execution/operator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use crate::system::Receiver;
use async_trait::async_trait;
use std::fmt::Debug;

/// An operator takes a generic input and returns a generic output.
/// It is a definition of a function.
#[async_trait]
pub(super) trait Operator<I, O>: Send + Sync + Debug
where
I: Send + Sync,
O: Send + Sync,
{
async fn run(&self, input: &I) -> O;
}

/// A task is a wrapper around an operator and its input.
/// It is a description of a function to be run.
#[derive(Debug)]
struct Task<Input, Output>
where
Input: Send + Sync + Debug,
Output: Send + Sync + Debug,
{
operator: Box<dyn Operator<Input, Output>>,
input: Input,
reply_channel: Box<dyn Receiver<Output>>,
}

/// A message type used by the dispatcher to send tasks to worker threads.
pub(super) type TaskMessage = Box<dyn TaskWrapper>;

/// A task wrapper is a trait that can be used to run a task. We use it to
/// erase the I, O types from the Task struct so that tasks.
#[async_trait]
pub(super) trait TaskWrapper: Send + Debug {
async fn run(&self);
}

/// Implement the TaskWrapper trait for every Task. This allows us to
/// erase the I, O types from the Task struct so that tasks can be
/// stored in a homogenous queue regardless of their input and output types.
#[async_trait]
impl<Input, Output> TaskWrapper for Task<Input, Output>
where
Input: Send + Sync + Debug,
Output: Send + Sync + Debug,
{
async fn run(&self) {
let output = self.operator.run(&self.input).await;
let res = self.reply_channel.send(output).await;
// TODO: if this errors, it means the caller was dropped
}
}

/// Wrap an operator and its input into a task message.
pub(super) fn wrap<Input, Output>(
operator: Box<dyn Operator<Input, Output>>,
input: Input,
reply_channel: Box<dyn Receiver<Output>>,
) -> TaskMessage
where
Input: Send + Sync + Debug + 'static,
Output: Send + Sync + Debug + 'static,
{
Box::new(Task {
operator,
input,
reply_channel,
})
}
1 change: 1 addition & 0 deletions rust/worker/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(super) mod pull_log;
Loading

0 comments on commit 93a659d

Please sign in to comment.