wip: dynamic fn map

leo91000 committed Nov 28, 2022
1 parent db5ec81 commit 8464613
Showing 9 changed files with 270 additions and 21 deletions.
4 changes: 2 additions & 2 deletions .gitignore
@@ -1,3 +1,3 @@
 .env
-/target
-/Cargo.lock
+target
+Cargo.lock
84 changes: 84 additions & 0 deletions <
@@ -0,0 +1,84 @@
use std::collections::HashMap;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;

use serde::Deserialize;

pub mod context;
mod db;
pub mod errors;
pub mod migrate;
mod migrations;
mod utils;

#[derive(Clone)]
pub struct WorkerContext {
    pool: sqlx::PgPool,
}

type WorkerFn =
    Box<dyn Fn(WorkerContext, &str) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send>>>;

pub struct Worker {
    concurrency: usize,
    poll_interval: u32,
    jobs: HashMap<String, WorkerFn>,
}

impl Worker {
    pub fn builder() -> WorkerBuilder {
        WorkerBuilder {
            concurrency: None,
            poll_interval: None,
            jobs: None,
        }
    }
}

#[derive(Default)]
pub struct WorkerBuilder {
    concurrency: Option<usize>,
    poll_interval: Option<u32>,
    jobs: Option<HashMap<String, WorkerFn>>,
}

impl WorkerBuilder {
    pub fn build(self) -> Worker {
        Worker {
            concurrency: self.concurrency.unwrap_or_else(num_cpus::get),
            poll_interval: self.poll_interval.unwrap_or(1000),
            jobs: self.jobs.unwrap_or_else(|| HashMap::new()),
        }
    }

    pub fn concurrency(&mut self, value: usize) -> &mut Self {
        self.concurrency = Some(value);
        self
    }

    pub fn poll_interval(&mut self, value: u32) -> &mut Self {
        self.poll_interval = Some(value);
        self
    }

    pub fn jobs<T, E, Fut, F>(&mut self, identifier: &str, job: F) -> &mut Self
    where
        T: for<'de> Deserialize<'de>,
        E: Debug,
        Fut: Future<Output = Result<(), E>>,
        F: Fn(WorkerContext, T) -> Fut,
    {
        let worker_fn = |ctx, payload| -> Result<(), String> {
            Box::pin(async move {
                let p = serde_json::from_str(payload).map_err(|e| format!("{:?}", e))?;
                job(ctx, p).await.map_err(|e| format!("{:?}", e))?;
                Ok(())
            })
        };

        let job_map = self.jobs.unwrap_or_else(|| HashMap::new());
        job_map.insert(identifier.to_string(), Box::new(worker_fn));
        self
    }
}
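
The jobs method above is the heart of the "dynamic fn map" experiment: each strongly typed async handler is wrapped in a closure with the uniform WorkerFn signature, so handlers for different payload types can share one HashMap and be invoked with a raw JSON string. Below is a self-contained sketch of that type-erasure pattern; it is not part of the commit, and the erase helper, BoxedJob alias, Hello payload and "hello" identifier are all invented for illustration.

// Sketch of the type-erasure pattern behind WorkerFn (illustrative, not from the commit).
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;

use serde::Deserialize;

type BoxedJob = Box<dyn Fn(&str) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send>>>;

fn erase<T, Fut, F>(handler: F) -> BoxedJob
where
    T: for<'de> Deserialize<'de>,
    Fut: Future<Output = Result<(), String>> + Send + 'static,
    F: Fn(T) -> Fut + 'static,
{
    Box::new(move |raw: &str| {
        // Deserialize before building the future so it does not borrow `raw`.
        let parsed = serde_json::from_str::<T>(raw).map_err(|e| format!("{:?}", e));
        let fut: Pin<Box<dyn Future<Output = Result<(), String>> + Send>> = match parsed {
            Ok(payload) => Box::pin(handler(payload)),
            Err(e) => Box::pin(async move { Err::<(), String>(e) }),
        };
        fut
    })
}

#[derive(Deserialize)]
struct Hello {
    name: String,
}

fn main() {
    let mut jobs: HashMap<String, BoxedJob> = HashMap::new();
    jobs.insert(
        "hello".to_string(),
        erase(|p: Hello| async move {
            println!("hello {}", p.name);
            Ok::<(), String>(())
        }),
    );

    // A worker loop would look up the handler by identifier and await the returned future.
    let job = &jobs["hello"];
    futures::executor::block_on(job(r#"{"name":"leo"}"#)).unwrap();
}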
6 changes: 6 additions & 0 deletions Cargo.toml
@@ -11,9 +11,15 @@ runtime-async-std-rustls = ["sqlx/runtime-async-std-rustls"]
runtime-async-std-native-tls = ["sqlx/runtime-async-std-native-tls"]

[dependencies]
anyhow = "1.0.66"
async-trait = "0.1.58"
cargo-insta = "1.21.1"
chrono = { version = "0.4.23", features = ["serde"] }
futures = "0.3.25"
getset = "0.1.2"
num_cpus = "1.14.0"
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.89"
sqlx = { version = "0.6.2", features = ["postgres", "json"] }
thiserror = "1.0.37"
tracing = "0.1.37"
4 changes: 4 additions & 0 deletions rust-toolchain.toml
@@ -0,0 +1,4 @@
[toolchain]
profile = "default"
components = ["rustfmt", "clippy"]
channel = "nightly"
24 changes: 24 additions & 0 deletions src/db/job.rs
@@ -0,0 +1,24 @@
use chrono::prelude::*;
use getset::Getters;
use sqlx::FromRow;

#[derive(FromRow, Getters)]
#[getset(get = "pub")]
pub struct Job {
    id: i32,
    job_queue_id: i32,
    task_id: i32,
    payload: Vec<u8>,
    priority: i32,
    run_at: DateTime<Utc>,
    max_attempts: i16,
    last_error: String,
    created_at: DateTime<Utc>,
    updated_at: DateTime<Utc>,
    key: String,
    locked_at: DateTime<Utc>,
    locked_by: String,
    revision: i32,
    flags: serde_json::Value,
    is_available: bool,
}
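
Hypothetical read path for this row type, not part of the commit: with FromRow derived, a Job can be mapped straight off a query. The jobs table name is an assumption, and decoding the DateTime<Utc> columns would also require sqlx's chrono feature, which the Cargo.toml in this commit does not yet enable.

// Illustrative only; the table name and query are assumptions.
async fn fetch_job(pool: &sqlx::PgPool, id: i32) -> sqlx::Result<Job> {
    sqlx::query_as::<_, Job>("select * from jobs where id = $1")
        .bind(id)
        .fetch_one(pool)
        .await
}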
1 change: 1 addition & 0 deletions src/db/mod.rs
@@ -0,0 +1 @@
mod job;
52 changes: 52 additions & 0 deletions src/job.rs
@@ -0,0 +1,52 @@
use std::{fmt::Debug, future::ready, marker::PhantomData, pin::Pin};

use async_trait::async_trait;
use futures::{future::LocalBoxFuture, Future, FutureExt, TryFutureExt};
use getset::Getters;
use serde::Deserialize;

#[derive(Getters)]
#[getset(get = "pub")]
pub struct JobCtx {
    pool: sqlx::PgPool,
}

pub struct Payload<T: for<'de> Deserialize<'de>>(T);

impl<T: for<'de> Deserialize<'de>> Payload<T> {
    fn from_str(s: &str) -> serde_json::Result<Self> {
        Ok(Self(serde_json::from_str(s)?))
    }
}

pub trait JobHandler<O, E: Debug + From<serde_json::Error>, Fut: Future<Output = Result<O, E>>> {
    fn handler(&self, ctx: JobCtx, payload: &str) -> Fut;
}

struct JobFn<T, O, E, Fut, F>
where
    E: Debug + From<serde_json::Error>,
    Fut: Future<Output = Result<O, E>>,
    T: for<'de> Deserialize<'de>,
    F: Fn(JobCtx, Payload<T>) -> Fut,
{
    job_fn: F,
    t: PhantomData<T>,
    o: PhantomData<O>,
    e: PhantomData<E>,
    fut: PhantomData<Fut>,
}

impl<'a, T, O, E, Fut2, F> JobHandler<O, E, LocalBoxFuture<'a, Result<O, E>>> for JobFn<T, O, E, Fut2, F>
where
    E: Debug + From<serde_json::Error>,
    T: for<'de> Deserialize<'de>,
    Fut2: Future<Output = Result<O, E>>,
    F: Fn(JobCtx, Payload<T>) -> Fut2,
{
    fn handler(&self, ctx: JobCtx, payload: &str) -> LocalBoxFuture<'a, Result<O, E>> {
        ready(Payload::from_str(payload).map_err(|e| E::from(e)))
            .and_then(|de_payload| (self.job_fn)(ctx, de_payload))
            .boxed_local()
    }
}
94 changes: 85 additions & 9 deletions src/lib.rs
@@ -1,20 +1,96 @@
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+
+use futures::FutureExt;
+use serde::Deserialize;
+
 pub mod context;
+mod db;
 pub mod errors;
 pub mod migrate;
 mod migrations;
 mod utils;

-pub fn add(left: usize, right: usize) -> usize {
-    left + right
+#[derive(Clone)]
+pub struct WorkerContext {
+    pool: sqlx::PgPool,
 }
+
+type WorkerFn =
+    Box<dyn Fn(WorkerContext, &str) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send>>>;
+
+pub struct Worker {
+    concurrency: usize,
+    poll_interval: u32,
+    jobs: HashMap<String, WorkerFn>,
+}

-#[cfg(test)]
-mod tests {
-    use super::*;
+impl Worker {
+    pub fn builder() -> WorkerBuilder {
+        WorkerBuilder {
+            concurrency: None,
+            poll_interval: None,
+            jobs: None,
+        }
+    }
+}
+
+#[derive(Default)]
+pub struct WorkerBuilder {
+    concurrency: Option<usize>,
+    poll_interval: Option<u32>,
+    jobs: Option<HashMap<String, WorkerFn>>,
+}
+
+impl WorkerBuilder {
+    pub fn build(self) -> Worker {
+        Worker {
+            concurrency: self.concurrency.unwrap_or_else(num_cpus::get),
+            poll_interval: self.poll_interval.unwrap_or(1000),
+            jobs: self.jobs.unwrap_or_else(|| HashMap::new()),
+        }
+    }
+
+    pub fn concurrency(&mut self, value: usize) -> &mut Self {
+        self.concurrency = Some(value);
+        self
+    }
+
+    pub fn poll_interval(&mut self, value: u32) -> &mut Self {
+        self.poll_interval = Some(value);
+        self
+    }
+
+    pub fn jobs<T, E, Fut, F>(&mut self, identifier: &str, job_fn: F) -> &mut Self
+    where
+        T: for<'de> Deserialize<'de> + Send,
+        E: Debug,
+        Fut: Future<Output = Result<(), E>> + Send,
+        F: Fn(WorkerContext, T) -> Fut + Send + Sync + Clone + 'static,
+    {
+        let worker_fn = |ctx, payload| {
+            async {
+                let de_payload = serde_json::from_str(payload).cloned();
+
+                match de_payload {
+                    Err(e) => Err(format!("{:?}", e)),
+                    Ok(p) => {
+                        let job_result = job_fn.clone()(ctx, p).await;
+                        match job_result {
+                            Err(e) => Err(format!("{:?}", e)),
+                            Ok(v) => Ok(v),
+                        }
+                    }
+                }
+            }
+            .boxed()
+        };

-    #[test]
-    fn it_works() {
-        let result = add(2, 2);
-        assert_eq!(result, 4);
+        let job_map = self.jobs.unwrap_or_else(|| HashMap::new());
+        job_map.insert(identifier.to_string(), Box::new(worker_fn));
+        self
     }
 }
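
A hypothetical sketch of how this builder is meant to be used once the WIP jobs registration compiles; the send_email identifier and EmailPayload type are invented for illustration.

// Illustrative usage of the builder API (assumes the crate compiles).
#[derive(serde::Deserialize)]
struct EmailPayload {
    to: String,
}

fn build_worker() -> Worker {
    let mut builder = Worker::builder();
    builder
        .concurrency(4)
        .poll_interval(500)
        .jobs("send_email", |_ctx: WorkerContext, p: EmailPayload| async move {
            println!("sending email to {}", p.to);
            Ok::<(), String>(())
        });
    // build takes self, so it is called on the owned builder rather than chained
    // off the &mut Self returned by the setters.
    builder.build()
}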
22 changes: 12 additions & 10 deletions src/utils.rs
@@ -1,19 +1,21 @@
-use sqlx::{query, Executor, Postgres};
+use sqlx::{query_as, Executor, FromRow, Postgres};

 use crate::errors::Result;

+#[derive(FromRow)]
+struct EscapeIdentifierRow {
+    escaped_identifier: String,
+}
+
 pub async fn escape_identifier<'e, E: Executor<'e, Database = Postgres>>(
     executor: E,
     identifier: &str,
 ) -> Result<String> {
-    let escaped_identifier = query!(
-        "select format('%I', $1::text) as escaped_identifier",
-        identifier
-    )
-    .fetch_one(executor)
-    .await?
-    .escaped_identifier
-    .unwrap();
+    let result: EscapeIdentifierRow =
+        query_as("select format('%I', $1::text) as escaped_identifier")
+            .bind(identifier)
+            .fetch_one(executor)
+            .await?;

-    Ok(escaped_identifier)
+    Ok(result.escaped_identifier)
 }
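
Illustrative call site, not part of the commit: identifiers cannot be bound as query parameters, so they are escaped with format('%I', ...) before being spliced into dynamically built SQL. The create_queue_schema helper and the schema-creation statement below are invented.

// Hypothetical caller; relies on the crate's Result converting from sqlx::Error,
// as escape_identifier itself does above.
async fn create_queue_schema(pool: &sqlx::PgPool, schema: &str) -> Result<()> {
    let escaped = escape_identifier(pool, schema).await?;
    sqlx::query(&format!("create schema if not exists {}", escaped))
        .execute(pool)
        .await?;
    Ok(())
}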
