diff --git a/examples/redis-mq-example/src/main.rs b/examples/redis-mq-example/src/main.rs index 638cdf5a..ec8b8de8 100644 --- a/examples/redis-mq-example/src/main.rs +++ b/examples/redis-mq-example/src/main.rs @@ -81,10 +81,7 @@ impl MessageQueue for RedisMq { async fn enqueue(&mut self, message: Message) -> Result<(), Self::Error> { let bytes = self .codec - .encode(&RedisJob { - ctx: Default::default(), - job: message, - }) + .encode(&RedisJob::new(message, Default::default())) .unwrap(); self.conn.send_message("email", bytes, None).await?; Ok(()) diff --git a/packages/apalis-core/src/lib.rs b/packages/apalis-core/src/lib.rs index 5511ceef..1ab6244c 100644 --- a/packages/apalis-core/src/lib.rs +++ b/packages/apalis-core/src/lib.rs @@ -22,6 +22,8 @@ #![cfg_attr(docsrs, feature(doc_cfg))] //! # apalis-core //! Utilities for building job and message processing tools. +use std::sync::Arc; + use futures::Stream; use poller::Poller; use worker::WorkerId; @@ -93,6 +95,10 @@ pub trait Codec { fn decode(&self, compact: &Compact) -> Result; } +/// A boxed codec +pub type BoxCodec = + Arc + Sync + Send + 'static>>; + /// Sleep utilities #[cfg(feature = "sleep")] pub async fn sleep(duration: std::time::Duration) { diff --git a/packages/apalis-redis/src/storage.rs b/packages/apalis-redis/src/storage.rs index ea845518..008c7f18 100644 --- a/packages/apalis-redis/src/storage.rs +++ b/packages/apalis-redis/src/storage.rs @@ -95,9 +95,51 @@ struct RedisScript { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct RedisJob { /// The job context - pub ctx: Context, + ctx: Context, /// The inner job - pub job: J, + job: J, +} + +impl RedisJob { + /// Creates a new RedisJob. + pub fn new(job: J, ctx: Context) -> Self { + RedisJob { ctx, job } + } + + /// Gets a reference to the context. + pub fn ctx(&self) -> &Context { + &self.ctx + } + + /// Gets a mutable reference to the context. + pub fn ctx_mut(&mut self) -> &mut Context { + &mut self.ctx + } + + /// Sets the context. + pub fn set_ctx(&mut self, ctx: Context) { + self.ctx = ctx; + } + + /// Gets a reference to the job. + pub fn job(&self) -> &J { + &self.job + } + + /// Gets a mutable reference to the job. + pub fn job_mut(&mut self) -> &mut J { + &mut self.job + } + + /// Sets the job. + pub fn set_job(&mut self, job: J) { + self.job = job; + } + + /// Combines context and job into a tuple. + pub fn into_tuple(self) -> (Context, J) { + (self.ctx, self.job) + } } impl From> for Request { diff --git a/packages/apalis-sql/src/from_row.rs b/packages/apalis-sql/src/from_row.rs index b14c675e..00d748f7 100644 --- a/packages/apalis-sql/src/from_row.rs +++ b/packages/apalis-sql/src/from_row.rs @@ -6,8 +6,50 @@ use crate::context::SqlContext; /// Wrapper for [Request] #[derive(Debug, Clone)] pub struct SqlRequest { - pub(crate) req: T, - pub(crate) context: SqlContext, + req: T, + context: SqlContext, +} + +impl SqlRequest { + /// Creates a new SqlRequest. + pub fn new(req: T, context: SqlContext) -> Self { + SqlRequest { req, context } + } + + /// Gets a reference to the request. + pub fn req(&self) -> &T { + &self.req + } + + /// Gets a mutable reference to the request. + pub fn req_mut(&mut self) -> &mut T { + &mut self.req + } + + /// Sets the request. + pub fn set_req(&mut self, req: T) { + self.req = req; + } + + /// Gets a reference to the context. + pub fn context(&self) -> &SqlContext { + &self.context + } + + /// Gets a mutable reference to the context. + pub fn context_mut(&mut self) -> &mut SqlContext { + &mut self.context + } + + /// Sets the context. + pub fn set_context(&mut self, context: SqlContext) { + self.context = context; + } + + /// Combines request and context into a tuple. + pub fn into_tuple(self) -> (T, SqlContext) { + (self.req, self.context) + } } impl From> for Request { diff --git a/packages/apalis-sql/src/mysql.rs b/packages/apalis-sql/src/mysql.rs index 2abaa17b..7a2be327 100644 --- a/packages/apalis-sql/src/mysql.rs +++ b/packages/apalis-sql/src/mysql.rs @@ -10,7 +10,7 @@ use apalis_core::storage::Storage; use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; use apalis_core::worker::WorkerId; -use apalis_core::{Backend, Codec}; +use apalis_core::{Backend, BoxCodec}; use async_stream::try_stream; use futures::{Stream, StreamExt, TryStreamExt}; use log::error; @@ -38,7 +38,7 @@ pub struct MysqlStorage { job_type: PhantomData, controller: Controller, config: Config, - codec: Arc + Sync + Send + 'static>>, + codec: BoxCodec, ack_notify: Notify<(WorkerId, TaskId)>, } @@ -109,6 +109,11 @@ impl MysqlStorage { pub fn pool(&self) -> &Pool { &self.pool } + + /// Expose the codec + pub fn codec(&self) -> &BoxCodec { + &self.codec + } } impl MysqlStorage { @@ -159,13 +164,18 @@ impl MysqlStorage { let jobs: Vec> = query.fetch_all(&pool).await?; for job in jobs { - yield Some(Into::into(SqlRequest { - context: job.context, - req: self.codec.decode(&job.req).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))? - })).map(|mut req: Request| { + yield { + let (req, ctx) = job.into_tuple(); + let req = self + .codec + .decode(&req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))) + .unwrap(); + let req = SqlRequest::new(req, ctx); + let mut req: Request = req.into(); req.insert(Namespace(config.namespace.clone())); - req - }) + Some(req) + } } } } @@ -261,15 +271,17 @@ where .await?; match res { None => Ok(None), - Some(c) => Ok(Some( - SqlRequest { - context: c.context, - req: self.codec.decode(&c.req).map_err(|e| { - sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)) - })?, - } - .into(), - )), + Some(job) => Ok(Some({ + let (req, ctx) = job.into_tuple(); + let req = self + .codec + .decode(&req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; + let req = SqlRequest::new(req, ctx); + let mut req: Request = req.into(); + req.insert(Namespace(self.config.namespace.clone())); + req + })), } } diff --git a/packages/apalis-sql/src/postgres.rs b/packages/apalis-sql/src/postgres.rs index 4dc73c02..53f3e32c 100644 --- a/packages/apalis-sql/src/postgres.rs +++ b/packages/apalis-sql/src/postgres.rs @@ -52,7 +52,7 @@ use apalis_core::storage::Storage; use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; use apalis_core::worker::WorkerId; -use apalis_core::{Backend, Codec}; +use apalis_core::{Backend, BoxCodec}; use futures::channel::mpsc; use futures::StreamExt; use futures::{select, stream, SinkExt}; @@ -79,14 +79,7 @@ use crate::from_row::SqlRequest; pub struct PostgresStorage { pool: PgPool, job_type: PhantomData, - codec: Arc< - Box< - dyn Codec - + Sync - + Send - + 'static, - >, - >, + codec: BoxCodec, config: Config, controller: Controller, ack_notify: Notify>, @@ -259,6 +252,16 @@ impl PostgresStorage { pub fn pool(&self) -> &Pool { &self.pool } + + /// Expose the config + pub fn config(&self) -> &Config { + &self.config + } + + /// Expose the codec + pub fn codec(&self) -> &BoxCodec { + &self.codec + } } /// A listener that listens to Postgres notifications @@ -323,7 +326,6 @@ impl PgListen { impl PostgresStorage { async fn fetch_next(&mut self, worker_id: &WorkerId) -> Result>, sqlx::Error> { let config = &self.config; - let codec = &self.codec; let job_type = &config.namespace; let fetch_query = "Select * from apalis.get_jobs($1, $2, $3);"; let jobs: Vec> = sqlx::query_as(fetch_query) @@ -339,15 +341,15 @@ impl PostgresStorage { let jobs: Vec<_> = jobs .into_iter() .map(|job| { - let req = SqlRequest { - context: job.context, - req: codec - .decode(&job.req) - .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))) - .unwrap(), - }; + let (req, ctx) = job.into_tuple(); + let req = self + .codec + .decode(&req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))) + .unwrap(); + let req = SqlRequest::new(req, ctx); let mut req: Request = req.into(); - req.insert(Namespace(config.namespace.clone())); + req.insert(Namespace(self.config.namespace.clone())); req }) .collect(); @@ -445,17 +447,21 @@ where .bind(job_id.to_string()) .fetch_optional(&self.pool) .await?; + match res { None => Ok(None), - Some(c) => Ok(Some( - SqlRequest { - context: c.context, - req: self.codec.decode(&c.req).map_err(|e| { - sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)) - })?, - } - .into(), - )), + Some(job) => Ok(Some({ + let (req, ctx) = job.into_tuple(); + let req = self + .codec + .decode(&req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))) + .unwrap(); + let req = SqlRequest::new(req, ctx); + let mut req: Request = req.into(); + req.insert(Namespace(self.config.namespace.clone())); + req + })), } } diff --git a/packages/apalis-sql/src/sqlite.rs b/packages/apalis-sql/src/sqlite.rs index 3e31a64d..6a3e25dd 100644 --- a/packages/apalis-sql/src/sqlite.rs +++ b/packages/apalis-sql/src/sqlite.rs @@ -12,7 +12,7 @@ use apalis_core::storage::Storage; use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; use apalis_core::worker::WorkerId; -use apalis_core::{Backend, Codec}; +use apalis_core::{Backend, BoxCodec}; use async_stream::try_stream; use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use serde::{de::DeserializeOwned, Serialize}; @@ -31,7 +31,7 @@ pub use sqlx::sqlite::SqlitePool; /// The code used to encode Sqlite jobs. /// /// Currently uses JSON -pub type SqliteCodec = Arc + Sync + Send + 'static>>; +pub type SqliteCodec = BoxCodec; /// Represents a [Storage] that persists to Sqlite // #[derive(Debug)] @@ -200,22 +200,19 @@ impl SqliteStorage { let res = fetch_next(&pool, &worker_id, id.0, &config).await?; yield match res { None => None::>, - Some(c) => Some( - SqlRequest { - context: c.context, - req: codec.decode(&c.req).map_err(|e| { - sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)) - })?, - } - .into(), - ).map(|mut req: Request| { + Some(job) => { + let (req, ctx) = job.into_tuple(); + let req = codec + .decode(&req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))) + .unwrap(); + let req = SqlRequest::new(req, ctx); + let mut req: Request = req.into(); req.insert(Namespace(config.namespace.clone())); - req - }), + Some(req) + } } - - .map(Into::into); - } + }; } } } @@ -280,15 +277,18 @@ where .await?; match res { None => Ok(None), - Some(c) => Ok(Some( - SqlRequest { - context: c.context, - req: self.codec.decode(&c.req).map_err(|e| { - sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)) - })?, - } - .into(), - )), + Some(job) => Ok(Some({ + let (req, ctx) = job.into_tuple(); + let req = self + .codec + .decode(&req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))) + .unwrap(); + let req = SqlRequest::new(req, ctx); + let mut req: Request = req.into(); + req.insert(Namespace(self.config.namespace.clone())); + req + })), } }