Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiconnection fixes #52

Merged
merged 6 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ dynamic-schema = ["diesel-dynamic-schema"]
gst = []

[patch.crates-io]
diesel = { git = "https://github.com/weiznich/diesel", rev = "e632a7ca4fa12b76d7638392aeaff7522f57adef" }
diesel = { git = "https://github.com/weiznich/diesel", rev = "548e0d73a0f2c207f1abf03f8f6741f5a9b1cb0b" }
10 changes: 10 additions & 0 deletions src/oracle/connection/bind_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ impl<'a> BindCollector<'a, Oracle> for OracleBindCollector<'a> {

Ok(())
}

fn push_null_value(
&mut self,
metadata: <Oracle as diesel::sql_types::TypeMetadata>::TypeMetadata,
) -> diesel::prelude::QueryResult<()> {
let len = self.binds.len();
self.binds
.push((format!("in{}", len), BindValue::NotSet(metadata.tpe)));
Ok(())
}
}

impl<'a, T> From<T> for BindValue<'a>
Expand Down
282 changes: 180 additions & 102 deletions src/oracle/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use super::backend::Oracle;
use super::query_builder::OciQueryBuilder;
use super::OciDataType;
use crate::oracle::connection::stmt_iter::RowIter;
use diesel::connection::Instrumentation;
use diesel::connection::InstrumentationEvent;
use diesel::connection::{Connection, SimpleConnection, TransactionManager};
use diesel::connection::{LoadConnection, MultiConnectionHelper};
use diesel::deserialize::FromSql;
Expand Down Expand Up @@ -146,6 +148,7 @@ mod transaction;
pub struct OciConnection {
raw: oracle::Connection,
transaction_manager: OCITransactionManager,
instrumentation: Option<Box<dyn Instrumentation>>,
}

struct ErrorHelper(oracle::Error);
Expand Down Expand Up @@ -238,8 +241,23 @@ unsafe impl Send for OciConnection {}

impl SimpleConnection for OciConnection {
fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
self.raw.execute(query, &[]).map_err(ErrorHelper::from)?;
Ok(())
self.instrumentation
.on_connection_event(InstrumentationEvent::start_query(
&diesel::connection::StrQueryHelper::new(query),
));
let r = self
.raw
.execute(query, &[])
.map_err(ErrorHelper::from)
.map_err(Into::into)
.map(|_| ());
self.instrumentation
.on_connection_event(InstrumentationEvent::finish_query(
&diesel::connection::StrQueryHelper::new(query),
r.as_ref().err(),
));

r
}
}

Expand All @@ -251,54 +269,20 @@ impl Connection for OciConnection {
/// should be a valid connection string for a given backend. See the
/// documentation for the specific backend for specifics.
fn establish(database_url: &str) -> ConnectionResult<Self> {
let url = url::Url::parse(database_url)
.map_err(|_| ConnectionError::InvalidConnectionUrl("Invalid url".into()))?;
if url.scheme() != "oracle" {
return Err(ConnectionError::InvalidConnectionUrl(format!(
"Got a unsupported url scheme: {}",
url.scheme()
)));
}
let user = url.username();

if user.is_empty() {
return Err(ConnectionError::InvalidConnectionUrl(
"Username not set".into(),
));
}
let user = match percent_encoding::percent_decode_str(url.username()).decode_utf8() {
Ok(username) => username,
Err(_e) => {
return Err(ConnectionError::InvalidConnectionUrl(
"Username could not be percent decoded".into(),
))
}
};
let password = url
.password()
.ok_or_else(|| ConnectionError::InvalidConnectionUrl("Password not set".into()))?;

let host = url
.host_str()
.ok_or_else(|| ConnectionError::InvalidConnectionUrl("Hostname not set".into()))?;
let port = url.port();
let path = url.path();

let mut url = host.to_owned();
if let Some(port) = port {
write!(url, ":{}", port).expect("Write to string does not fail");
}
url += path;

let mut raw = oracle::Connection::connect(user, password, url)
.map_err(ErrorHelper::from)
.map_err(|e| ConnectionError::CouldntSetupConfiguration(e.into()))?;

raw.set_autocommit(true);
let mut instrumentation = diesel::connection::get_default_instrumentation();
instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
database_url,
));
let raw = Self::inner_establish(database_url);
instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
database_url,
raw.as_ref().err(),
));

Ok(Self {
raw,
raw: raw?,
transaction_manager: OCITransactionManager::new(),
instrumentation,
})
}

Expand All @@ -307,35 +291,17 @@ impl Connection for OciConnection {
where
T: QueryFragment<Self::Backend> + QueryId,
{
let mut qb = OciQueryBuilder::default();

source.to_sql(&mut qb, &Oracle)?;

let conn = &self.raw;
let sql = qb.finish();
let mut stmt = conn.statement(&sql);
if !source.is_safe_to_cache_prepared(&Oracle)? {
stmt.exclude_from_cache();
}
let mut stmt = stmt.build().map_err(ErrorHelper::from)?;
let mut bind_collector = OracleBindCollector::default();

source.collect_binds(&mut bind_collector, &mut (), &Oracle)?;
let binds = bind_collector
.binds
.iter()
.map(|(n, b)| -> (&str, &dyn oracle::sql_type::ToSql) {
(n as &str, std::ops::Deref::deref(b))
})
.collect::<Vec<_>>();

if stmt.is_query() {
stmt.query_named(&binds).map_err(ErrorHelper::from)?;
} else {
stmt.execute_named(&binds).map_err(ErrorHelper::from)?;
}

Ok(stmt.row_count().map_err(ErrorHelper::from)? as usize)
self.instrumentation
.on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
source,
)));
let res = self.inner_executing_returning_count(source);
self.instrumentation
.on_connection_event(InstrumentationEvent::finish_query(
&diesel::debug_query(source),
res.as_ref().err(),
));
res
}

fn transaction_state(
Expand All @@ -357,6 +323,14 @@ impl Connection for OciConnection {
self.transaction_manager.is_test_transaction = true;
Ok(())
}

fn instrumentation(&mut self) -> &mut dyn diesel::connection::Instrumentation {
&mut self.instrumentation
}

fn set_instrumentation(&mut self, instrumentation: impl diesel::connection::Instrumentation) {
self.instrumentation = Some(Box::new(instrumentation));
}
}

impl LoadConnection for OciConnection {
Expand All @@ -370,8 +344,11 @@ impl LoadConnection for OciConnection {
Self::Backend: QueryMetadata<T::SqlType>,
{
let query = source.as_query();

self.with_prepared_statement(query, |mut stmt, bind_collector| {
self.instrumentation
.on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
&query,
)));
let res = self.with_prepared_statement(&query, |mut stmt, bind_collector| {
if stmt.is_query() {
let binds = bind_collector
.binds
Expand All @@ -394,7 +371,13 @@ impl LoadConnection for OciConnection {
} else {
unreachable!()
}
})
});
self.instrumentation
.on_connection_event(InstrumentationEvent::finish_query(
&diesel::debug_query(&query),
res.as_ref().err(),
));
res
}
}

Expand All @@ -421,7 +404,7 @@ where
impl OciConnection {
fn with_prepared_statement<'conn, 'query, T, R>(
&'conn mut self,
query: T,
query: &T,
callback: impl FnOnce(oracle::Statement<'conn>, OracleBindCollector) -> QueryResult<R>,
) -> Result<R, Error>
where
Expand Down Expand Up @@ -632,36 +615,131 @@ impl OciConnection {
});

if let Some(first_record) = record_iter.next() {
let mut qb = OciQueryBuilder::default();
first_record.to_sql(&mut qb, &Oracle)?;
let query_string = qb.finish();
let mut batch = self
.raw
.batch(&query_string, record_count)
.build()
.map_err(ErrorHelper::from)?;
self.instrumentation
.on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
&first_record,
)));
let res = self.inner_batch_insert(&first_record, record_count, record_iter);
self.instrumentation
.on_connection_event(InstrumentationEvent::finish_query(
&diesel::debug_query(&first_record),
res.as_ref().err(),
));
res
} else {
Ok(0)
}
}

fn inner_batch_insert<Q>(
&mut self,
first_record: &Q,
record_count: usize,
record_iter: impl Iterator<Item = Q>,
) -> Result<usize, Error>
where
Q: QueryFragment<Oracle>,
{
let mut qb = OciQueryBuilder::default();
first_record.to_sql(&mut qb, &Oracle)?;
let query_string = qb.finish();
let mut batch = self
.raw
.batch(&query_string, record_count)
.build()
.map_err(ErrorHelper::from)?;

bind_params_to_batch(first_record, &mut batch)?;
for record in record_iter {
bind_params_to_batch(&record, &mut batch)?;
}
batch.execute().map_err(ErrorHelper::from)?;
Ok(record_count)
}

bind_params_to_batch(first_record, &mut batch)?;
for record in record_iter {
bind_params_to_batch(record, &mut batch)?;
fn inner_establish(database_url: &str) -> Result<oracle::Connection, ConnectionError> {
let url = url::Url::parse(database_url)
.map_err(|_| ConnectionError::InvalidConnectionUrl("Invalid url".into()))?;
if url.scheme() != "oracle" {
return Err(ConnectionError::InvalidConnectionUrl(format!(
"Got a unsupported url scheme: {}",
url.scheme()
)));
}
let user = url.username();
if user.is_empty() {
return Err(ConnectionError::InvalidConnectionUrl(
"Username not set".into(),
));
}
let user = match percent_encoding::percent_decode_str(url.username()).decode_utf8() {
Ok(username) => username,
Err(_e) => {
return Err(ConnectionError::InvalidConnectionUrl(
"Username could not be percent decoded".into(),
))
}
batch.execute().map_err(ErrorHelper::from)?;
Ok(record_count)
};
let password = url
.password()
.ok_or_else(|| ConnectionError::InvalidConnectionUrl("Password not set".into()))?;
let host = url
.host_str()
.ok_or_else(|| ConnectionError::InvalidConnectionUrl("Hostname not set".into()))?;
let port = url.port();
let path = url.path();
let mut url = host.to_owned();
if let Some(port) = port {
write!(url, ":{}", port).expect("Write to string does not fail");
}
url += path;
let mut raw = oracle::Connection::connect(user, password, url)
.map_err(ErrorHelper::from)
.map_err(|e| ConnectionError::CouldntSetupConfiguration(e.into()))?;
raw.set_autocommit(true);
Ok(raw)
}

fn inner_executing_returning_count<T>(&mut self, source: &T) -> Result<usize, Error>
where
T: QueryFragment<Oracle> + QueryId,
{
let mut qb = OciQueryBuilder::default();

source.to_sql(&mut qb, &Oracle)?;

let conn = &self.raw;
let sql = qb.finish();
let mut stmt = conn.statement(&sql);
if !source.is_safe_to_cache_prepared(&Oracle)? {
stmt.exclude_from_cache();
}
let mut stmt = stmt.build().map_err(ErrorHelper::from)?;
let mut bind_collector = OracleBindCollector::default();

source.collect_binds(&mut bind_collector, &mut (), &Oracle)?;
let binds = bind_collector
.binds
.iter()
.map(|(n, b)| -> (&str, &dyn oracle::sql_type::ToSql) {
(n as &str, std::ops::Deref::deref(b))
})
.collect::<Vec<_>>();

if stmt.is_query() {
stmt.query_named(&binds).map_err(ErrorHelper::from)?;
} else {
Ok(0)
stmt.execute_named(&binds).map_err(ErrorHelper::from)?;
}

Ok(stmt.row_count().map_err(ErrorHelper::from)? as usize)
}
}

fn bind_params_to_batch<'a, T, V, Op>(
record: InsertStatement<T, &'a ValuesClause<V, T>, Op>,
fn bind_params_to_batch(
record: &impl QueryFragment<Oracle>,
batch: &mut oracle::Batch,
) -> Result<(), Error>
where
T: Table + 'a,
V: 'a,
InsertStatement<T, &'a ValuesClause<V, T>, Op>: QueryFragment<Oracle>,
{
) -> Result<(), Error> {
let mut bind_collector = OracleBindCollector::default();
record.collect_binds(&mut bind_collector, &mut (), &Oracle)?;
let binds = bind_collector
Expand Down
Loading
Loading