Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(macos): prevent task object been released while executing async command #1285

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/fix-macos-async-command-panic.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"wry": patch
---

On macOS, fix an issue that could cause a panic when running an async command.
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,6 @@ pub enum Error {
#[cfg(target_os = "android")]
#[error(transparent)]
CrossBeamRecvError(#[from] crossbeam_channel::RecvError),
#[error("Custom protocol task is invalid.")]
CustomProtocolTaskInvalid,
}
99 changes: 61 additions & 38 deletions src/wkwebview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use core_graphics::{
use objc::{
declare::ClassDecl,
runtime::{Class, Object, Sel, BOOL},
Message,
};
use objc_id::Id;

Expand Down Expand Up @@ -193,7 +194,7 @@ impl InnerWebView {
}

// Task handler for custom protocol
extern "C" fn start_task(this: &Object, _: Sel, _webview: id, task: id) {
extern "C" fn start_task(this: &Object, _: Sel, _webview: id, task: *mut Object) {
unsafe {
#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
Expand Down Expand Up @@ -274,53 +275,75 @@ impl InnerWebView {
// send response
match http_request.body(sent_form_body) {
Ok(final_request) => {
let () = msg_send![task, retain];

let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> = Box::new(
move |sent_response| {
let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
// default to 200
let wanted_status_code = sent_response.status().as_u16() as i32;
// default to HTTP/1.1
let wanted_version = format!("{:#?}", sent_response.version());

let dictionary: id = msg_send![class!(NSMutableDictionary), alloc];
let headers: id = msg_send![dictionary, initWithCapacity:1];
if let Some(mime) = wanted_mime {
let () = msg_send![headers, setObject:NSString::new(mime.to_str().unwrap()) forKey: NSString::new(CONTENT_TYPE.as_str())];
fn check_webview_id_valid(webview_id: u32) -> crate::Result<()> {
match WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
true => Ok(()),
false => Err(crate::Error::CustomProtocolTaskInvalid),
}
}
let () = msg_send![headers, setObject:NSString::new(&content.len().to_string()) forKey: NSString::new(CONTENT_LENGTH.as_str())];

// add headers
for (name, value) in sent_response.headers().iter() {
let header_key = name.as_str();
if let Ok(value) = value.to_str() {
let () = msg_send![headers, setObject:NSString::new(value) forKey: NSString::new(header_key)];
unsafe fn response(
task: id,
webview_id: u32,
url: id, /* NSURL */
sent_response: HttpResponse<Cow<'_, [u8]>>,
) -> crate::Result<()> {
let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
// default to 200
let wanted_status_code = sent_response.status().as_u16() as i32;
// default to HTTP/1.1
let wanted_version = format!("{:#?}", sent_response.version());

let dictionary: id = msg_send![class!(NSMutableDictionary), alloc];
let headers: id = msg_send![dictionary, initWithCapacity:1];
if let Some(mime) = wanted_mime {
let () = msg_send![headers, setObject:NSString::new(mime.to_str().unwrap()) forKey: NSString::new(CONTENT_TYPE.as_str())];
}
let () = msg_send![headers, setObject:NSString::new(&content.len().to_string()) forKey: NSString::new(CONTENT_LENGTH.as_str())];

// add headers
for (name, value) in sent_response.headers().iter() {
let header_key = name.as_str();
if let Ok(value) = value.to_str() {
let () = msg_send![headers, setObject:NSString::new(value) forKey: NSString::new(header_key)];
}
}
}

let urlresponse: id = msg_send![class!(NSHTTPURLResponse), alloc];
let response: id = msg_send![urlresponse, initWithURL:url statusCode: wanted_status_code HTTPVersion:NSString::new(&wanted_version) headerFields:headers];
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return;
}
let () = msg_send![task, didReceiveResponse: response];
let urlresponse: id = msg_send![class!(NSHTTPURLResponse), alloc];
let response: id = msg_send![urlresponse, initWithURL:url statusCode: wanted_status_code HTTPVersion:NSString::new(&wanted_version) headerFields:headers];

// Send data
let bytes = content.as_ptr() as *mut c_void;
let data: id = msg_send![class!(NSData), alloc];
let data: id = msg_send![data, initWithBytesNoCopy:bytes length:content.len() freeWhenDone: if content.len() == 0 { NO } else { YES }];
check_webview_id_valid(webview_id)?;
(*task)
.send_message::<(id,), ()>(sel!(didReceiveResponse:), (response,))
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return;
}
let () = msg_send![task, didReceiveData: data];
// Send data
let bytes = content.as_ptr() as *mut c_void;
let data: id = msg_send![class!(NSData), alloc];
let data: id = msg_send![data, initWithBytesNoCopy:bytes length:content.len() freeWhenDone: if content.len() == 0 { NO } else { YES }];

check_webview_id_valid(webview_id)?;
(*task)
.send_message::<(id,), ()>(sel!(didReceiveData:), (data,))
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

// Finish
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return;
// Finish
check_webview_id_valid(webview_id)?;
(*task)
.send_message::<(), ()>(sel!(didFinish), ())
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

Ok(())
}
let () = msg_send![task, didFinish];

let _ = response(task, webview_id, url, sent_response);
let () = msg_send![task, release];
},
);

Expand Down
Loading