Skip to content

Commit

Permalink
RPC cancellation support (#174)
Browse files Browse the repository at this point in the history
Fixes #168
  • Loading branch information
cretz authored Oct 21, 2024
1 parent cde31d9 commit 67741fc
Show file tree
Hide file tree
Showing 26 changed files with 1,109 additions and 1,276 deletions.
24 changes: 16 additions & 8 deletions temporalio/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 7 additions & 8 deletions temporalio/Rakefile
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,15 @@ namespace :proto do
# Calls #{class_name}.#{method.name} API call.
#
# @param request [#{method.input_type.msgclass}] API request.
# @param rpc_retry [Boolean] Whether to implicitly retry known retryable errors.
# @param rpc_metadata [Hash<String, String>, nil] Headers to include on the RPC call.
# @param rpc_timeout [Float, nil] Number of seconds before timeout.
# @param rpc_options [RPCOptions, nil] Advanced RPC options.
# @return [#{method.output_type.msgclass}] API response.
def #{rpc}(request, rpc_retry: false, rpc_metadata: nil, rpc_timeout: nil)
def #{rpc}(request, rpc_options: nil)
invoke_rpc(
rpc: '#{rpc}',
request_class: #{method.input_type.msgclass},
response_class: #{method.output_type.msgclass},
request:,
rpc_retry:,
rpc_metadata:,
rpc_timeout:
rpc_options:
)
end
TEXT
Expand Down Expand Up @@ -236,7 +232,10 @@ namespace :proto do
# Camel case to snake case
rpc = method.name.gsub(/([A-Z])/, '_\1').downcase.delete_prefix('_')
file.puts <<-TEXT
def #{rpc}: (untyped request, ?rpc_retry: bool, ?rpc_metadata: Hash[String, String]?, ?rpc_timeout: Float?) -> untyped
def #{rpc}: (
untyped request,
?rpc_options: RPCOptions?
) -> untyped
TEXT
end

Expand Down
1 change: 1 addition & 0 deletions temporalio/ext/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ temporal-sdk-core = { version = "0.1.0", path = "./sdk-core/core", features = ["
temporal-sdk-core-api = { version = "0.1.0", path = "./sdk-core/core-api" }
temporal-sdk-core-protos = { version = "0.1.0", path = "./sdk-core/sdk-core-protos" }
tokio = "1.26"
tokio-util = "0.7"
tonic = "0.12"
tracing = "0.1"
url = "2.2"
103 changes: 73 additions & 30 deletions temporalio/ext/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ pub fn init(ruby: &Ruby) -> Result<(), Error> {
inner_class.define_method("code", method!(RpcFailure::code, 0))?;
inner_class.define_method("message", method!(RpcFailure::message, 0))?;
inner_class.define_method("details", method!(RpcFailure::details, 0))?;

let inner_class = class.define_class("CancellationToken", class::object())?;
inner_class.define_singleton_method("new", function!(CancellationToken::new, 0))?;
inner_class.define_method("cancel", method!(CancellationToken::cancel, 0))?;
Ok(())
}

Expand All @@ -58,16 +62,17 @@ pub struct Client {
#[macro_export]
macro_rules! rpc_call {
($client:ident, $callback:ident, $call:ident, $trait:tt, $call_name:ident) => {{
let cancel_token = $call.cancel_token.clone();
if $call.retry {
let mut core_client = $client.core.clone();
let req = $call.into_request()?;
$crate::client::rpc_resp($client, $callback, async move {
$crate::client::rpc_resp($client, $callback, cancel_token, async move {
$trait::$call_name(&mut core_client, req).await
})
} else {
let mut core_client = $client.core.clone().into_inner();
let req = $call.into_request()?;
$crate::client::rpc_resp($client, $callback, async move {
$crate::client::rpc_resp($client, $callback, cancel_token, async move {
$trait::$call_name(&mut core_client, req).await
})
}
Expand Down Expand Up @@ -176,39 +181,43 @@ impl Client {

pub fn async_invoke_rpc(&self, args: &[Value]) -> Result<(), Error> {
let args = scan_args::scan_args::<(), (), (), (), _, ()>(args)?;
let (service, rpc, request, retry, metadata, timeout, queue) = scan_args::get_kwargs::<
_,
(
u8,
String,
RString,
bool,
Option<HashMap<String, String>>,
Option<f64>,
Value,
),
(),
(),
>(
args.keywords,
&[
id!("service"),
id!("rpc"),
id!("request"),
id!("rpc_retry"),
id!("rpc_metadata"),
id!("rpc_timeout"),
id!("queue"),
],
&[],
)?
.required;
let (service, rpc, request, retry, metadata, timeout, cancel_token, queue) =
scan_args::get_kwargs::<
_,
(
u8,
String,
RString,
bool,
Option<HashMap<String, String>>,
Option<f64>,
Option<&CancellationToken>,
Value,
),
(),
(),
>(
args.keywords,
&[
id!("service"),
id!("rpc"),
id!("request"),
id!("rpc_retry"),
id!("rpc_metadata"),
id!("rpc_timeout"),
id!("rpc_cancellation_token"),
id!("queue"),
],
&[],
)?
.required;
let call = RpcCall {
rpc,
request: unsafe { request.as_slice() },
retry,
metadata,
timeout,
cancel_token: cancel_token.map(|c| c.token.clone()),
_not_send_sync: PhantomData,
};
let callback = AsyncCallback::from_queue(queue);
Expand Down Expand Up @@ -249,6 +258,7 @@ pub(crate) struct RpcCall<'a> {
pub retry: bool,
pub metadata: Option<HashMap<String, String>>,
pub timeout: Option<f64>,
pub cancel_token: Option<tokio_util::sync::CancellationToken>,

// This RPC call contains an unsafe reference to Ruby bytes that does not
// outlive the call, so we prevent it from being sent to another thread.
Expand Down Expand Up @@ -280,14 +290,25 @@ impl RpcCall<'_> {
pub(crate) fn rpc_resp<P>(
client: &Client,
callback: AsyncCallback,
cancel_token: Option<tokio_util::sync::CancellationToken>,
fut: impl Future<Output = Result<tonic::Response<P>, tonic::Status>> + Send + 'static,
) -> Result<(), Error>
where
P: prost::Message,
P: Default,
{
client.runtime_handle.spawn(
async move { fut.await.map(|msg| msg.get_ref().encode_to_vec()) },
async move {
let res = if let Some(cancel_token) = cancel_token {
tokio::select! {
_ = cancel_token.cancelled() => Err(tonic::Status::new(tonic::Code::Cancelled, "<__user_canceled__>")),
v = fut => v,
}
} else {
fut.await
};
res.map(|msg| msg.get_ref().encode_to_vec())
},
move |_, result| {
match result {
// TODO(cretz): Any reasonable way to prevent byte copy that is just going to get decoded into proto
Expand All @@ -299,3 +320,25 @@ where
);
Ok(())
}

#[derive(DataTypeFunctions, TypedData)]
#[magnus(
class = "Temporalio::Internal::Bridge::Client::CancellationToken",
free_immediately
)]
pub struct CancellationToken {
pub(crate) token: tokio_util::sync::CancellationToken,
}

impl CancellationToken {
pub fn new() -> Result<Self, Error> {
Ok(Self {
token: tokio_util::sync::CancellationToken::new(),
})
}

pub fn cancel(&self) -> Result<(), Error> {
self.token.cancel();
Ok(())
}
}
Loading

0 comments on commit 67741fc

Please sign in to comment.