From 84646136cbad0613999ec8f6caf885e55638880d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Coletta?= Date: Mon, 28 Nov 2022 18:52:10 +0100 Subject: [PATCH] wip: dynamic fn map --- .gitignore | 4 +- < | 84 ++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 6 +++ rust-toolchain.toml | 4 ++ src/db/job.rs | 24 ++++++++++++ src/db/mod.rs | 1 + src/job.rs | 52 +++++++++++++++++++++++++ src/lib.rs | 94 ++++++++++++++++++++++++++++++++++++++++----- src/utils.rs | 22 ++++++----- 9 files changed, 270 insertions(+), 21 deletions(-) create mode 100644 < create mode 100644 rust-toolchain.toml create mode 100644 src/db/job.rs create mode 100644 src/db/mod.rs create mode 100644 src/job.rs diff --git a/.gitignore b/.gitignore index 48d5723..c5def5a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ .env -/target -/Cargo.lock +target +Cargo.lock diff --git a/< b/< new file mode 100644 index 0000000..808dc7f --- /dev/null +++ b/< @@ -0,0 +1,84 @@ +use std::collections::HashMap; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; + +use serde::Deserialize; + +pub mod context; +mod db; +pub mod errors; +pub mod migrate; +mod migrations; +mod utils; + +#[derive(Clone)] +pub struct WorkerContext { + pool: sqlx::PgPool, +} + +type WorkerFn = + Box Pin> + Send>>>; + +pub struct Worker { + concurrency: usize, + poll_interval: u32, + jobs: HashMap, +} + +impl Worker { + pub fn builder() -> WorkerBuilder { + WorkerBuilder { + concurrency: None, + poll_interval: None, + jobs: None, + } + } +} + +#[derive(Default)] +pub struct WorkerBuilder { + concurrency: Option, + poll_interval: Option, + jobs: Option>, +} + +impl WorkerBuilder { + pub fn build(self) -> Worker { + Worker { + concurrency: self.concurrency.unwrap_or_else(num_cpus::get), + poll_interval: self.poll_interval.unwrap_or(1000), + jobs: self.jobs.unwrap_or_else(|| HashMap::new()), + } + } + + pub fn concurrency(&mut self, value: usize) -> &mut Self { + self.concurrency = Some(value); + self + } + + pub fn poll_interval(&mut self, value: u32) -> &mut Self { + self.poll_interval = Some(value); + self + } + + pub fn jobs(&mut self, identifier: &str, job: F) -> &mut Self + where + T: for<'de> Deserialize<'de>, + E: Debug, + Fut: Future>, + F: Fn(WorkerContext, T) -> Fut, + { + let worker_fn = |ctx, payload| -> Result<(), String> { + Box::pin(async move { + let p = serde_json::from_str(payload).map_err(|e| format!("{:?}", e))?; + job(ctx, p).await.map_err(|e| format!("{:?}", e))?; + Ok(()) + }) + }; + + let job_map = self.jobs.unwrap_or_else(|| HashMap::new()); + job_map.insert(identifier.to_string(), Box::new(worker_fn)); + self + } +} diff --git a/Cargo.toml b/Cargo.toml index 46d6319..de09b22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,15 @@ runtime-async-std-rustls = ["sqlx/runtime-async-std-rustls"] runtime-async-std-native-tls = ["sqlx/runtime-async-std-native-tls"] [dependencies] +anyhow = "1.0.66" async-trait = "0.1.58" +cargo-insta = "1.21.1" +chrono = { version = "0.4.23", features = ["serde"] } futures = "0.3.25" getset = "0.1.2" +num_cpus = "1.14.0" +serde = { version = "1.0.147", features = ["derive"] } +serde_json = "1.0.89" sqlx = { version = "0.6.2", features = ["postgres", "json"] } thiserror = "1.0.37" tracing = "0.1.37" diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..f7f18e6 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +profile = "default" +components = ["rustfmt", "clippy"] +channel = "nightly" diff --git a/src/db/job.rs b/src/db/job.rs new file mode 100644 index 0000000..63d6c74 --- /dev/null +++ b/src/db/job.rs @@ -0,0 +1,24 @@ +use chrono::prelude::*; +use getset::Getters; +use sqlx::FromRow; + +#[derive(FromRow, Getters)] +#[getset(get = "pub")] +pub struct Job { + id: i32, + job_queue_id: i32, + task_id: i32, + payload: Vec, + priority: i32, + run_at: DateTime, + max_attempts: i16, + last_error: String, + created_at: DateTime, + updated_at: DateTime, + key: String, + locked_at: DateTime, + locked_by: String, + revision: i32, + flags: serde_json::Value, + is_available: bool, +} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..5eb5348 --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1 @@ +mod job; diff --git a/src/job.rs b/src/job.rs new file mode 100644 index 0000000..df0d374 --- /dev/null +++ b/src/job.rs @@ -0,0 +1,52 @@ +use std::{fmt::Debug, future::ready, marker::PhantomData, pin::Pin}; + +use async_trait::async_trait; +use futures::{future::LocalBoxFuture, Future, FutureExt, TryFutureExt}; +use getset::Getters; +use serde::Deserialize; + +#[derive(Getters)] +#[getset(get = "pub")] +pub struct JobCtx { + pool: sqlx::PgPool, +} + +pub struct Payload Deserialize<'de>>(T); + +impl Deserialize<'de>> Payload { + fn from_str(s: &str) -> serde_json::Result { + Ok(Self(serde_json::from_str(s)?)) + } +} + +pub trait JobHandler, Fut: Future>> { + fn handler(&self, ctx: JobCtx, payload: &str) -> Fut; +} + +struct JobFn +where + E: Debug + From, + Fut: Future>, + T: for<'de> Deserialize<'de>, + F: Fn(JobCtx, Payload) -> Fut, +{ + job_fn: F, + t: PhantomData, + o: PhantomData, + e: PhantomData, + fut: PhantomData, +} + +impl<'a, T, O, E, Fut2, F> JobHandler>> for JobFn +where + E: Debug + From, + T: for<'de> Deserialize<'de>, + Fut2: Future>, + F: Fn(JobCtx, Payload) -> Fut2, +{ + fn handler(&self, ctx: JobCtx, payload: &str) -> LocalBoxFuture<'a, Result> { + ready(Payload::from_str(payload).map_err(|e| E::from(e))) + .and_then(|de_payload| (self.job_fn)(ctx, de_payload)) + .boxed_local() + } +} diff --git a/src/lib.rs b/src/lib.rs index 7c5899f..c7ecad1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,20 +1,96 @@ +use std::collections::HashMap; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use futures::FutureExt; +use serde::Deserialize; + pub mod context; +mod db; pub mod errors; pub mod migrate; mod migrations; mod utils; -pub fn add(left: usize, right: usize) -> usize { - left + right +#[derive(Clone)] +pub struct WorkerContext { + pool: sqlx::PgPool, +} + +type WorkerFn = + Box Pin> + Send>>>; + +pub struct Worker { + concurrency: usize, + poll_interval: u32, + jobs: HashMap, } -#[cfg(test)] -mod tests { - use super::*; +impl Worker { + pub fn builder() -> WorkerBuilder { + WorkerBuilder { + concurrency: None, + poll_interval: None, + jobs: None, + } + } +} + +#[derive(Default)] +pub struct WorkerBuilder { + concurrency: Option, + poll_interval: Option, + jobs: Option>, +} + +impl WorkerBuilder { + pub fn build(self) -> Worker { + Worker { + concurrency: self.concurrency.unwrap_or_else(num_cpus::get), + poll_interval: self.poll_interval.unwrap_or(1000), + jobs: self.jobs.unwrap_or_else(|| HashMap::new()), + } + } + + pub fn concurrency(&mut self, value: usize) -> &mut Self { + self.concurrency = Some(value); + self + } + + pub fn poll_interval(&mut self, value: u32) -> &mut Self { + self.poll_interval = Some(value); + self + } + + pub fn jobs(&mut self, identifier: &str, job_fn: F) -> &mut Self + where + T: for<'de> Deserialize<'de> + Send, + E: Debug, + Fut: Future> + Send, + F: Fn(WorkerContext, T) -> Fut + Send + Sync + Clone + 'static, + { + let worker_fn = |ctx, payload| { + async { + let de_payload = serde_json::from_str(payload).cloned(); + + match de_payload { + Err(e) => Err(format!("{:?}", e)), + Ok(p) => { + let job_result = job_fn.clone()(ctx, p).await; + match job_result { + Err(e) => Err(format!("{:?}", e)), + Ok(v) => Ok(v), + } + } + } + } + .boxed() + }; - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); + let job_map = self.jobs.unwrap_or_else(|| HashMap::new()); + job_map.insert(identifier.to_string(), Box::new(worker_fn)); + self } } diff --git a/src/utils.rs b/src/utils.rs index 934be98..266eea8 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,19 +1,21 @@ -use sqlx::{query, Executor, Postgres}; +use sqlx::{query_as, Executor, FromRow, Postgres}; use crate::errors::Result; +#[derive(FromRow)] +struct EscapeIdentifierRow { + escaped_identifier: String, +} + pub async fn escape_identifier<'e, E: Executor<'e, Database = Postgres>>( executor: E, identifier: &str, ) -> Result { - let escaped_identifier = query!( - "select format('%I', $1::text) as escaped_identifier", - identifier - ) - .fetch_one(executor) - .await? - .escaped_identifier - .unwrap(); + let result: EscapeIdentifierRow = + query_as("select format('%I', $1::text) as escaped_identifier") + .bind(identifier) + .fetch_one(executor) + .await?; - Ok(escaped_identifier) + Ok(result.escaped_identifier) }