Skip to content

Commit

Permalink
refactor(sqlite): make background thread responsible for all FFI calls
Browse files Browse the repository at this point in the history
  • Loading branch information
abonander committed Dec 29, 2021
1 parent b3091b0 commit 6943ac0
Show file tree
Hide file tree
Showing 25 changed files with 1,404 additions and 997 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ paste = "1.0.1"
serde = { version = "1.0.111", features = ["derive"] }
serde_json = "1.0.53"
url = "2.1.1"

rand = "0.8.4"
rand_xoshiro = "0.6.0"
hex = "0.4"
#
# Any
#
Expand Down
5 changes: 4 additions & 1 deletion sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ mysql = [
"rand",
"rsa",
]
sqlite = ["libsqlite3-sys"]
sqlite = ["libsqlite3-sys", "futures-executor", "flume"]
mssql = ["uuid", "encoding_rs", "regex"]
any = []

Expand Down Expand Up @@ -122,6 +122,9 @@ futures-channel = { version = "0.3.5", default-features = false, features = ["si
futures-core = { version = "0.3.5", default-features = false }
futures-intrusive = "0.4.0"
futures-util = { version = "0.3.5", default-features = false, features = ["alloc", "sink"] }
# used by the SQLite worker thread to block on the async mutex that locks the database handle
futures-executor = { version = "0.3.17", optional = true }
flume = { version = "0.10.9", optional = true, default-features = false, features = ["async"] }
generic-array = { version = "0.14.4", default-features = false, optional = true }
hex = "0.4.2"
hmac = { version = "0.11.0", default-features = false, optional = true }
Expand Down
25 changes: 25 additions & 0 deletions sqlx-core/src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
mod statement_cache;

pub(crate) use statement_cache::StatementCache;
use std::fmt::{Debug, Formatter};
use std::ops::{Deref, DerefMut};

/// A wrapper for `Fn`s that provides a debug impl that just says "Function"
pub(crate) struct DebugFn<F: ?Sized>(pub F);

impl<F: ?Sized> Deref for DebugFn<F> {
type Target = F;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<F: ?Sized> DerefMut for DebugFn<F> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl<F: ?Sized> Debug for DebugFn<F> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Function").finish()
}
}
2 changes: 1 addition & 1 deletion sqlx-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub enum Error {
Database(#[source] Box<dyn DatabaseError>),

/// Error communicating with the database backend.
#[error("error communicating with the server: {0}")]
#[error("error communicating with database: {0}")]
Io(#[from] io::Error),

/// Error occurred while attempting to establish a TLS connection.
Expand Down
27 changes: 25 additions & 2 deletions sqlx-core/src/sqlite/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ impl<'q> SqliteArguments<'q> {
self.values.push(SqliteArgumentValue::Null);
}
}

pub(crate) fn into_static(self) -> SqliteArguments<'static> {
SqliteArguments {
values: self
.values
.into_iter()
.map(SqliteArgumentValue::into_static)
.collect(),
}
}
}

impl<'q> Arguments<'q> for SqliteArguments<'q> {
Expand All @@ -49,7 +59,7 @@ impl<'q> Arguments<'q> for SqliteArguments<'q> {
}

impl SqliteArguments<'_> {
pub(super) fn bind(&self, handle: &StatementHandle, offset: usize) -> Result<usize, Error> {
pub(super) fn bind(&self, handle: &mut StatementHandle, offset: usize) -> Result<usize, Error> {
let mut arg_i = offset;
// for handle in &statement.handles {

Expand Down Expand Up @@ -95,7 +105,20 @@ impl SqliteArguments<'_> {
}

impl SqliteArgumentValue<'_> {
fn bind(&self, handle: &StatementHandle, i: usize) -> Result<(), Error> {
fn into_static(self) -> SqliteArgumentValue<'static> {
use SqliteArgumentValue::*;

match self {
Null => Null,
Text(text) => Text(text.into_owned().into()),
Blob(blob) => Blob(blob.into_owned().into()),
Int(v) => Int(v),
Int64(v) => Int64(v),
Double(v) => Double(v),
}
}

fn bind(&self, handle: &mut StatementHandle, i: usize) -> Result<(), Error> {
use SqliteArgumentValue::*;

let status = match self {
Expand Down
127 changes: 98 additions & 29 deletions sqlx-core/src/sqlite/connection/collation.rs
Original file line number Diff line number Diff line change
@@ -1,55 +1,95 @@
use std::cmp::Ordering;
use std::ffi::CString;
use std::fmt::{self, Debug, Formatter};
use std::os::raw::{c_int, c_void};
use std::slice;
use std::str::from_utf8_unchecked;
use std::sync::Arc;

use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8};

use crate::error::Error;
use crate::sqlite::connection::handle::ConnectionHandle;
use crate::sqlite::SqliteError;

unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
drop(Box::from_raw(p as *mut T));
}

pub(crate) fn create_collation<F>(
handle: &ConnectionHandle,
name: &str,
compare: F,
) -> Result<(), Error>
where
F: Fn(&str, &str) -> Ordering + Send + Sync + 'static,
{
unsafe extern "C" fn call_boxed_closure<C>(
#[derive(Clone)]
pub struct Collation {
name: Arc<str>,
collate: Arc<dyn Fn(&str, &str) -> Ordering + Send + Sync + 'static>,
// SAFETY: these must match the concrete type of `collate`
call: unsafe extern "C" fn(
arg1: *mut c_void,
arg2: c_int,
arg3: *const c_void,
arg4: c_int,
arg5: *const c_void,
) -> c_int
) -> c_int,
free: unsafe extern "C" fn(*mut c_void),
}

impl Collation {
pub fn new<N, F>(name: N, collate: F) -> Self
where
C: Fn(&str, &str) -> Ordering,
N: Into<Arc<str>>,
F: Fn(&str, &str) -> Ordering + Send + Sync + 'static,
{
let boxed_f: *mut C = arg1 as *mut C;
debug_assert!(!boxed_f.is_null());
let s1 = {
let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize);
from_utf8_unchecked(c_slice)
};
let s2 = {
let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize);
from_utf8_unchecked(c_slice)
unsafe extern "C" fn drop_arc_value<T>(p: *mut c_void) {
drop(Arc::from_raw(p as *mut T));
}

Collation {
name: name.into(),
collate: Arc::new(collate),
call: call_boxed_closure::<F>,
free: drop_arc_value::<F>,
}
}

pub(crate) fn create(&self, handle: &mut ConnectionHandle) -> Result<(), Error> {
let raw_f = Arc::into_raw(Arc::clone(&self.collate));
let c_name = CString::new(&*self.name)
.map_err(|_| err_protocol!("invalid collation name: {:?}", self.name))?;
let flags = SQLITE_UTF8;
let r = unsafe {
sqlite3_create_collation_v2(
handle.as_ptr(),
c_name.as_ptr(),
flags,
raw_f as *mut c_void,
Some(self.call),
Some(self.free),
)
};
let t = (*boxed_f)(s1, s2);

match t {
Ordering::Less => -1,
Ordering::Equal => 0,
Ordering::Greater => 1,
if r == SQLITE_OK {
Ok(())
} else {
// The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails.
drop(unsafe { Arc::from_raw(raw_f) });
Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))))
}
}
}

impl Debug for Collation {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Collation")
.field("name", &self.name)
.finish_non_exhaustive()
}
}

pub(crate) fn create_collation<F>(
handle: &mut ConnectionHandle,
name: &str,
compare: F,
) -> Result<(), Error>
where
F: Fn(&str, &str) -> Ordering + Send + Sync + 'static,
{
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
drop(Box::from_raw(p as *mut T));
}

let boxed_f: *mut F = Box::into_raw(Box::new(compare));
let c_name =
Expand All @@ -74,3 +114,32 @@ where
Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))))
}
}

unsafe extern "C" fn call_boxed_closure<C>(
data: *mut c_void,
left_len: c_int,
left_ptr: *const c_void,
right_len: c_int,
right_ptr: *const c_void,
) -> c_int
where
C: Fn(&str, &str) -> Ordering,
{
let boxed_f: *mut C = data as *mut C;
debug_assert!(!boxed_f.is_null());
let s1 = {
let c_slice = slice::from_raw_parts(left_ptr as *const u8, left_len as usize);
from_utf8_unchecked(c_slice)
};
let s2 = {
let c_slice = slice::from_raw_parts(right_ptr as *const u8, right_len as usize);
from_utf8_unchecked(c_slice)
};
let t = (*boxed_f)(s1, s2);

match t {
Ordering::Less => -1,
Ordering::Equal => 0,
Ordering::Greater => 1,
}
}
Loading

0 comments on commit 6943ac0

Please sign in to comment.