Skip to content

Commit

Permalink
fix: expose the missing apis (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
geofmureithi authored Jul 10, 2024
1 parent 8ae48dc commit f1daab3
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 77 deletions.
5 changes: 1 addition & 4 deletions examples/redis-mq-example/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ impl<Message: Send + 'static> MessageQueue<Message> for RedisMq<Message> {
async fn enqueue(&mut self, message: Message) -> Result<(), Self::Error> {
let bytes = self
.codec
.encode(&RedisJob {
ctx: Default::default(),
job: message,
})
.encode(&RedisJob::new(message, Default::default()))
.unwrap();
self.conn.send_message("email", bytes, None).await?;
Ok(())
Expand Down
6 changes: 6 additions & 0 deletions packages/apalis-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#![cfg_attr(docsrs, feature(doc_cfg))]
//! # apalis-core
//! Utilities for building job and message processing tools.
use std::sync::Arc;

use futures::Stream;
use poller::Poller;
use worker::WorkerId;
Expand Down Expand Up @@ -93,6 +95,10 @@ pub trait Codec<T, Compact> {
fn decode(&self, compact: &Compact) -> Result<T, Self::Error>;
}

/// A boxed codec
pub type BoxCodec<T, Compact, Error = error::Error> =
Arc<Box<dyn Codec<T, Compact, Error = Error> + Sync + Send + 'static>>;

/// Sleep utilities
#[cfg(feature = "sleep")]
pub async fn sleep(duration: std::time::Duration) {
Expand Down
46 changes: 44 additions & 2 deletions packages/apalis-redis/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,51 @@ struct RedisScript {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RedisJob<J> {
/// The job context
pub ctx: Context,
ctx: Context,
/// The inner job
pub job: J,
job: J,
}

impl<J> RedisJob<J> {
/// Creates a new RedisJob.
pub fn new(job: J, ctx: Context) -> Self {
RedisJob { ctx, job }
}

/// Gets a reference to the context.
pub fn ctx(&self) -> &Context {
&self.ctx
}

/// Gets a mutable reference to the context.
pub fn ctx_mut(&mut self) -> &mut Context {
&mut self.ctx
}

/// Sets the context.
pub fn set_ctx(&mut self, ctx: Context) {
self.ctx = ctx;
}

/// Gets a reference to the job.
pub fn job(&self) -> &J {
&self.job
}

/// Gets a mutable reference to the job.
pub fn job_mut(&mut self) -> &mut J {
&mut self.job
}

/// Sets the job.
pub fn set_job(&mut self, job: J) {
self.job = job;
}

/// Combines context and job into a tuple.
pub fn into_tuple(self) -> (Context, J) {
(self.ctx, self.job)
}
}

impl<T> From<RedisJob<T>> for Request<T> {
Expand Down
46 changes: 44 additions & 2 deletions packages/apalis-sql/src/from_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,50 @@ use crate::context::SqlContext;
/// Wrapper for [Request]
#[derive(Debug, Clone)]
pub struct SqlRequest<T> {
pub(crate) req: T,
pub(crate) context: SqlContext,
req: T,
context: SqlContext,
}

impl<T> SqlRequest<T> {
/// Creates a new SqlRequest.
pub fn new(req: T, context: SqlContext) -> Self {
SqlRequest { req, context }
}

/// Gets a reference to the request.
pub fn req(&self) -> &T {
&self.req
}

/// Gets a mutable reference to the request.
pub fn req_mut(&mut self) -> &mut T {
&mut self.req
}

/// Sets the request.
pub fn set_req(&mut self, req: T) {
self.req = req;
}

/// Gets a reference to the context.
pub fn context(&self) -> &SqlContext {
&self.context
}

/// Gets a mutable reference to the context.
pub fn context_mut(&mut self) -> &mut SqlContext {
&mut self.context
}

/// Sets the context.
pub fn set_context(&mut self, context: SqlContext) {
self.context = context;
}

/// Combines request and context into a tuple.
pub fn into_tuple(self) -> (T, SqlContext) {
(self.req, self.context)
}
}

impl<T> From<SqlRequest<T>> for Request<T> {
Expand Down
46 changes: 29 additions & 17 deletions packages/apalis-sql/src/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::WorkerId;
use apalis_core::{Backend, Codec};
use apalis_core::{Backend, BoxCodec};
use async_stream::try_stream;
use futures::{Stream, StreamExt, TryStreamExt};
use log::error;
Expand Down Expand Up @@ -38,7 +38,7 @@ pub struct MysqlStorage<T> {
job_type: PhantomData<T>,
controller: Controller,
config: Config,
codec: Arc<Box<dyn Codec<T, serde_json::Value, Error = Error> + Sync + Send + 'static>>,
codec: BoxCodec<T, serde_json::Value>,
ack_notify: Notify<(WorkerId, TaskId)>,
}

Expand Down Expand Up @@ -109,6 +109,11 @@ impl<T: Serialize + DeserializeOwned> MysqlStorage<T> {
pub fn pool(&self) -> &Pool<MySql> {
&self.pool
}

/// Expose the codec
pub fn codec(&self) -> &BoxCodec<T, serde_json::Value> {
&self.codec
}
}

impl<T: DeserializeOwned + Send + Unpin + Sync + 'static> MysqlStorage<T> {
Expand Down Expand Up @@ -159,13 +164,18 @@ impl<T: DeserializeOwned + Send + Unpin + Sync + 'static> MysqlStorage<T> {
let jobs: Vec<SqlRequest<Value>> = query.fetch_all(&pool).await?;

for job in jobs {
yield Some(Into::into(SqlRequest {
context: job.context,
req: self.codec.decode(&job.req).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?
})).map(|mut req: Request<T>| {
yield {
let (req, ctx) = job.into_tuple();
let req = self
.codec
.decode(&req)
.map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))
.unwrap();
let req = SqlRequest::new(req, ctx);
let mut req: Request<T> = req.into();
req.insert(Namespace(config.namespace.clone()));
req
})
Some(req)
}
}
}
}
Expand Down Expand Up @@ -261,15 +271,17 @@ where
.await?;
match res {
None => Ok(None),
Some(c) => Ok(Some(
SqlRequest {
context: c.context,
req: self.codec.decode(&c.req).map_err(|e| {
sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))
})?,
}
.into(),
)),
Some(job) => Ok(Some({
let (req, ctx) = job.into_tuple();
let req = self
.codec
.decode(&req)
.map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
let req = SqlRequest::new(req, ctx);
let mut req: Request<T> = req.into();
req.insert(Namespace(self.config.namespace.clone()));
req
})),
}
}

Expand Down
60 changes: 33 additions & 27 deletions packages/apalis-sql/src/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::WorkerId;
use apalis_core::{Backend, Codec};
use apalis_core::{Backend, BoxCodec};
use futures::channel::mpsc;
use futures::StreamExt;
use futures::{select, stream, SinkExt};
Expand All @@ -79,14 +79,7 @@ use crate::from_row::SqlRequest;
pub struct PostgresStorage<T> {
pool: PgPool,
job_type: PhantomData<T>,
codec: Arc<
Box<
dyn Codec<T, serde_json::Value, Error = apalis_core::error::Error>
+ Sync
+ Send
+ 'static,
>,
>,
codec: BoxCodec<T, serde_json::Value>,
config: Config,
controller: Controller,
ack_notify: Notify<AckResponse<TaskId>>,
Expand Down Expand Up @@ -259,6 +252,16 @@ impl<T: Serialize + DeserializeOwned> PostgresStorage<T> {
pub fn pool(&self) -> &Pool<Postgres> {
&self.pool
}

/// Expose the config
pub fn config(&self) -> &Config {
&self.config
}

/// Expose the codec
pub fn codec(&self) -> &BoxCodec<T, serde_json::Value> {
&self.codec
}
}

/// A listener that listens to Postgres notifications
Expand Down Expand Up @@ -323,7 +326,6 @@ impl PgListen {
impl<T: DeserializeOwned + Send + Unpin + 'static> PostgresStorage<T> {
async fn fetch_next(&mut self, worker_id: &WorkerId) -> Result<Vec<Request<T>>, sqlx::Error> {
let config = &self.config;
let codec = &self.codec;
let job_type = &config.namespace;
let fetch_query = "Select * from apalis.get_jobs($1, $2, $3);";
let jobs: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
Expand All @@ -339,15 +341,15 @@ impl<T: DeserializeOwned + Send + Unpin + 'static> PostgresStorage<T> {
let jobs: Vec<_> = jobs
.into_iter()
.map(|job| {
let req = SqlRequest {
context: job.context,
req: codec
.decode(&job.req)
.map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))
.unwrap(),
};
let (req, ctx) = job.into_tuple();
let req = self
.codec
.decode(&req)
.map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))
.unwrap();
let req = SqlRequest::new(req, ctx);
let mut req: Request<T> = req.into();
req.insert(Namespace(config.namespace.clone()));
req.insert(Namespace(self.config.namespace.clone()));
req
})
.collect();
Expand Down Expand Up @@ -445,17 +447,21 @@ where
.bind(job_id.to_string())
.fetch_optional(&self.pool)
.await?;

match res {
None => Ok(None),
Some(c) => Ok(Some(
SqlRequest {
context: c.context,
req: self.codec.decode(&c.req).map_err(|e| {
sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))
})?,
}
.into(),
)),
Some(job) => Ok(Some({
let (req, ctx) = job.into_tuple();
let req = self
.codec
.decode(&req)
.map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))
.unwrap();
let req = SqlRequest::new(req, ctx);
let mut req: Request<T> = req.into();
req.insert(Namespace(self.config.namespace.clone()));
req
})),
}
}

Expand Down
Loading

0 comments on commit f1daab3

Please sign in to comment.