diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index f10b39eb3e..69fd0118d1 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -59,6 +59,16 @@ jobs: # unit test: tokio - run: cargo test --manifest-path sqlx-core/Cargo.toml --no-default-features --features 'chrono uuid postgres mysql tls runtime-tokio' + # integration test: sqlite + async-std + - run: cargo test --no-default-features --features 'runtime-async-std sqlite macros uuid chrono tls' + env: + DATABASE_URL: "sqlite::memory:" + + # integration test: sqlite + tokio + - run: cargo test --no-default-features --features 'runtime-tokio sqlite macros uuid chrono tls' + env: + DATABASE_URL: "sqlite::memory:" + # Rust ------------------------------------------------ - name: Prepare build directory for cache diff --git a/Cargo.lock b/Cargo.lock index e0a60ef17b..b875ff040e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -864,6 +864,17 @@ version = "0.2.67" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb147597cdf94ed43ab7a9038716637d2d1bf2bc571da995d0028dec06bd3018" +[[package]] +name = "libsqlite3-sys" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266eb8c361198e8d1f682bc974e5d9e2ae90049fb1943890904d11dad7d4a77d" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "lock_api" version = "0.3.3" @@ -1649,6 +1660,7 @@ dependencies = [ "generic-array", "hex", "hmac", + "libsqlite3-sys", "log", "matches", "md-5", @@ -1691,6 +1703,15 @@ dependencies = [ "url", ] +[[package]] +name = "sqlx-example-listen-postgres" +version = "0.1.0" +dependencies = [ + "async-std", + "futures 0.3.4", + "sqlx 0.2.6", +] + [[package]] name = "sqlx-example-realworld-postgres" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 5f14e354af..736f49cd5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "sqlx-core", "sqlx-macros", "sqlx-test", + "examples/listen-postgres", "examples/realworld-postgres", "examples/todos-postgres", ] @@ -42,6 +43,7 @@ runtime-tokio = [ "sqlx-core/runtime-tokio", "sqlx-macros/runtime-tokio" ] # database postgres = [ "sqlx-core/postgres", "sqlx-macros/postgres" ] mysql = [ "sqlx-core/mysql", "sqlx-macros/mysql" ] +sqlite = [ "sqlx-core/sqlite", "sqlx-macros/sqlite" ] # types chrono = [ "sqlx-core/chrono", "sqlx-macros/chrono" ] @@ -70,6 +72,18 @@ required-features = [ "postgres", "macros" ] name = "mysql-macros" required-features = [ "mysql", "macros" ] +[[test]] +name = "sqlite" +required-features = [ "sqlite" ] + +[[test]] +name = "sqlite-raw" +required-features = [ "sqlite" ] + +[[test]] +name = "sqlite-types" +required-features = [ "sqlite" ] + [[test]] name = "mysql" required-features = [ "mysql" ] diff --git a/examples/listen-postgres/Cargo.toml b/examples/listen-postgres/Cargo.toml new file mode 100644 index 0000000000..dd12360f59 --- /dev/null +++ b/examples/listen-postgres/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "sqlx-example-listen-postgres" +version = "0.1.0" +edition = "2018" +workspace = "../.." + +[dependencies] +async-std = { version = "1.4.0", features = [ "attributes", "unstable" ] } +sqlx = { path = "../..", features = [ "postgres", "tls" ] } +futures = "0.3.1" diff --git a/examples/listen-postgres/README.md b/examples/listen-postgres/README.md new file mode 100644 index 0000000000..7a0c39a76b --- /dev/null +++ b/examples/listen-postgres/README.md @@ -0,0 +1,18 @@ +Postgres LISTEN/NOTIFY +====================== + +## Usage + +Declare the database URL. This example does not include any reading or writing of data. + +``` +export DATABASE_URL="postgres://postgres@localhost/postgres" +``` + +Run. + +``` +cargo run +``` + +The example program should connect to the database, and create a LISTEN loop on a predefined set of channels. A NOTIFY task will be spawned which will connect to the same database and will emit notifications on a 5 second interval. diff --git a/examples/listen-postgres/src/main.rs b/examples/listen-postgres/src/main.rs new file mode 100644 index 0000000000..1f902e15d2 --- /dev/null +++ b/examples/listen-postgres/src/main.rs @@ -0,0 +1,68 @@ +use async_std::stream; +use futures::StreamExt; +use futures::TryStreamExt; +use sqlx::postgres::PgListener; +use sqlx::{Executor, PgPool}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; + +#[async_std::main] +async fn main() -> Result<(), Box> { + println!("Building PG pool."); + let conn_str = + std::env::var("DATABASE_URL").expect("Env var DATABASE_URL is required for this example."); + let pool = sqlx::PgPool::new(&conn_str).await?; + + let mut listener = PgListener::new(&conn_str).await?; + + // let notify_pool = pool.clone(); + let _t = async_std::task::spawn(async move { + stream::interval(Duration::from_secs(2)) + .for_each(|_| notify(&pool)) + .await + }); + + println!("Starting LISTEN loop."); + + listener.listen_all(&["chan0", "chan1", "chan2"]).await?; + + let mut counter = 0usize; + loop { + let notification = listener.recv().await?; + println!("[from recv]: {:?}", notification); + + counter += 1; + if counter >= 3 { + break; + } + } + + // Prove that we are buffering messages by waiting for 6 seconds + listener.execute("SELECT pg_sleep(6)").await?; + + let mut stream = listener.into_stream(); + while let Some(notification) = stream.try_next().await? { + println!("[from stream]: {:?}", notification); + } + + Ok(()) +} + +async fn notify(mut pool: &PgPool) { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let res = pool + .execute(&*format!( + r#" +NOTIFY "chan0", '{{"payload": {}}}'; +NOTIFY "chan1", '{{"payload": {}}}'; +NOTIFY "chan2", '{{"payload": {}}}'; + "#, + COUNTER.fetch_add(1, Ordering::SeqCst), + COUNTER.fetch_add(1, Ordering::SeqCst), + COUNTER.fetch_add(1, Ordering::SeqCst) + )) + .await; + + println!("[from notify]: {:?}", res); +} diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index cd36977b0a..083b15100e 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -17,13 +17,14 @@ default = [ "runtime-async-std" ] unstable = [] postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac" ] mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ] +sqlite = [ "libsqlite3-sys" ] tls = [ "async-native-tls" ] runtime-async-std = [ "async-native-tls/runtime-async-std", "async-std" ] runtime-tokio = [ "async-native-tls/runtime-tokio", "tokio" ] [dependencies] async-native-tls = { version = "0.3.2", default-features = false, optional = true } -async-std = { version = "1.5.0", optional = true } +async-std = { version = "1.5.0", features = [ "unstable" ], optional = true } async-stream = { version = "0.2.1", default-features = false } base64 = { version = "0.11.0", default-features = false, optional = true, features = [ "std" ] } bitflags = { version = "1.2.1", default-features = false } @@ -50,5 +51,12 @@ tokio = { version = "0.2.13", default-features = false, features = [ "dns", "fs" url = { version = "2.1.1", default-features = false } uuid = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] } +# +[dependencies.libsqlite3-sys] +version = "0.17.1" +optional = true +default-features = false +features = [ "pkg-config", "vcpkg", "bundled" ] + [dev-dependencies] -matches = "0.1.8" +matches = "0.1.8" \ No newline at end of file diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index aad5b42eea..33f3b9df90 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -4,6 +4,7 @@ use futures_core::future::BoxFuture; use crate::executor::Executor; use crate::pool::{Pool, PoolConnection}; +use crate::transaction::Transaction; use crate::url::Url; /// Represents a single database connection rather than a pool of database connections. @@ -15,6 +16,16 @@ where Self: Send + 'static, Self: Executor, { + /// Starts a transaction. + /// + /// Returns [`Transaction`](struct.Transaction.html). + fn begin(self) -> BoxFuture<'static, crate::Result>> + where + Self: Sized, + { + Box::pin(Transaction::new(0, self)) + } + /// Close this database connection. fn close(self) -> BoxFuture<'static, crate::Result<()>>; @@ -31,63 +42,61 @@ pub trait Connect: Connection { Self: Sized; } -mod internal { - #[allow(dead_code)] - pub enum MaybeOwnedConnection<'c, C> - where - C: super::Connect, - { - Borrowed(&'c mut C), - Owned(super::PoolConnection), - } - - #[allow(dead_code)] - pub enum ConnectionSource<'c, C> - where - C: super::Connect, - { - Connection(MaybeOwnedConnection<'c, C>), - Pool(super::Pool), - } +#[allow(dead_code)] +pub(crate) enum ConnectionSource<'c, C> +where + C: Connect, +{ + ConnectionRef(&'c mut C), + Connection(C), + PoolConnection(Pool, PoolConnection), + Pool(Pool), } -pub(crate) use self::internal::{ConnectionSource, MaybeOwnedConnection}; - impl<'c, C> ConnectionSource<'c, C> where C: Connect, { #[allow(dead_code)] - pub(crate) async fn resolve_by_ref(&mut self) -> crate::Result<&'_ mut C> { + pub(crate) async fn resolve(&mut self) -> crate::Result<&'_ mut C> { if let ConnectionSource::Pool(pool) = self { - *self = - ConnectionSource::Connection(MaybeOwnedConnection::Owned(pool.acquire().await?)); + let conn = pool.acquire().await?; + + *self = ConnectionSource::PoolConnection(pool.clone(), conn); } Ok(match self { - ConnectionSource::Connection(conn) => match conn { - MaybeOwnedConnection::Borrowed(conn) => &mut *conn, - MaybeOwnedConnection::Owned(ref mut conn) => conn, - }, + ConnectionSource::ConnectionRef(conn) => conn, + ConnectionSource::PoolConnection(_, ref mut conn) => conn, + ConnectionSource::Connection(ref mut conn) => conn, ConnectionSource::Pool(_) => unreachable!(), }) } } -impl<'c, C> From<&'c mut C> for MaybeOwnedConnection<'c, C> +impl<'c, C> From for ConnectionSource<'c, C> +where + C: Connect, +{ + fn from(connection: C) -> Self { + ConnectionSource::Connection(connection) + } +} + +impl<'c, C> From> for ConnectionSource<'c, C> where C: Connect, { - fn from(conn: &'c mut C) -> Self { - MaybeOwnedConnection::Borrowed(conn) + fn from(connection: PoolConnection) -> Self { + ConnectionSource::PoolConnection(Pool(connection.pool.clone()), connection) } } -impl<'c, C> From> for MaybeOwnedConnection<'c, C> +impl<'c, C> From> for ConnectionSource<'c, C> where C: Connect, { - fn from(conn: PoolConnection) -> Self { - MaybeOwnedConnection::Owned(conn) + fn from(pool: Pool) -> Self { + ConnectionSource::Pool(pool) } } diff --git a/sqlx-core/src/cursor.rs b/sqlx-core/src/cursor.rs index caf690fd20..cccd1b2291 100644 --- a/sqlx-core/src/cursor.rs +++ b/sqlx-core/src/cursor.rs @@ -1,6 +1,5 @@ use futures_core::future::BoxFuture; -use crate::connection::{Connect, MaybeOwnedConnection}; use crate::database::{Database, HasRow}; use crate::executor::Execute; use crate::pool::Pool; @@ -19,20 +18,21 @@ where { type Database: Database; - #[doc(hidden)] fn from_pool(pool: &Pool<::Connection>, query: E) -> Self where Self: Sized, E: Execute<'q, Self::Database>; - #[doc(hidden)] - fn from_connection(conn: C, query: E) -> Self + fn from_connection( + connection: &'c mut ::Connection, + query: E, + ) -> Self where Self: Sized, - ::Connection: Connect, - C: Into::Connection>>, E: Execute<'q, Self::Database>; /// Fetch the next row in the result. Returns `None` if there are no more rows. - fn next(&mut self) -> BoxFuture::Row>>>; + fn next<'cur>( + &'cur mut self, + ) -> BoxFuture<'cur, crate::Result>::Row>>>; } diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index 79c6744584..70dea80061 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -28,6 +28,8 @@ where /// The Rust type of table identifiers for this database. type TableId: Display + Clone; + + type RawBuffer; } pub trait HasRawValue<'c> { diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index 3f856dfda6..1fc2f10469 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -21,9 +21,9 @@ where DB: Database + ?Sized, { /// Writes the value of `self` into `buf` in the expected format for the database. - fn encode(&self, buf: &mut Vec); + fn encode(&self, buf: &mut DB::RawBuffer); - fn encode_nullable(&self, buf: &mut Vec) -> IsNull { + fn encode_nullable(&self, buf: &mut DB::RawBuffer) -> IsNull { self.encode(buf); IsNull::No @@ -40,11 +40,11 @@ where T: Type, T: Encode, { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut DB::RawBuffer) { (*self).encode(buf) } - fn encode_nullable(&self, buf: &mut Vec) -> IsNull { + fn encode_nullable(&self, buf: &mut DB::RawBuffer) -> IsNull { (*self).encode_nullable(buf) } @@ -59,12 +59,12 @@ where T: Type, T: Encode, { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut DB::RawBuffer) { // Forward to [encode_nullable] and ignore the result let _ = self.encode_nullable(buf); } - fn encode_nullable(&self, buf: &mut Vec) -> IsNull { + fn encode_nullable(&self, buf: &mut DB::RawBuffer) -> IsNull { if let Some(self_) = self { self_.encode(buf); diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 62cd437fd1..843edd48c5 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -177,6 +177,11 @@ pub trait DatabaseError: Display + Debug + Send + Sync { /// The primary, human-readable error message. fn message(&self) -> &str; + /// The (SQLSTATE) code for the error. + fn code(&self) -> Option<&str> { + None + } + fn details(&self) -> Option<&str> { None } @@ -223,19 +228,6 @@ macro_rules! tls_err { #[allow(unused_macros)] macro_rules! impl_fmt_error { ($err:ty) => { - impl std::fmt::Debug for $err { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("DatabaseError") - .field("message", &self.message()) - .field("details", &self.details()) - .field("hint", &self.hint()) - .field("table_name", &self.table_name()) - .field("column_name", &self.column_name()) - .field("constraint_name", &self.constraint_name()) - .finish() - } - } - impl std::fmt::Display for $err { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.pad(self.message()) diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index b487f7d9a4..bdc06db3db 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -24,10 +24,17 @@ where /// discarding any potential result rows. /// /// Returns the number of rows affected, or 0 if not applicable. - fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result> + fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( + &'c mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> where E: Execute<'q, Self::Database>; + /// Executes a query for its result. + /// + /// Returns a [`Cursor`] that can be used to iterate through the [`Row`]s + /// of the result. fn fetch<'e, 'q, E>(&'e mut self, query: E) -> >::Cursor where E: Execute<'q, Self::Database>; @@ -45,14 +52,12 @@ where E: Execute<'q, Self::Database>; } -pub trait RefExecutor<'c> { +// HACK: Generic Associated Types (GATs) will enable us to rework how the Executor bound is done +// in Query to remove the need for this. +pub trait RefExecutor<'e> { type Database: Database; - /// Executes a query for its result. - /// - /// Returns a [`Cursor`] that can be used to iterate through the [`Row`]s - /// of the result. - fn fetch_by_ref<'q, E>(self, query: E) -> >::Cursor + fn fetch_by_ref<'q, E>(self, query: E) -> >::Cursor where E: Execute<'q, Self::Database>; } @@ -87,7 +92,10 @@ where { type Database = T::Database; - fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result> + fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( + &'c mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> where E: Execute<'q, Self::Database>, { diff --git a/sqlx-core/src/io/buf.rs b/sqlx-core/src/io/buf.rs index 96316b6511..2c16e30b3e 100644 --- a/sqlx-core/src/io/buf.rs +++ b/sqlx-core/src/io/buf.rs @@ -2,7 +2,7 @@ use byteorder::ByteOrder; use memchr::memchr; use std::{io, slice, str}; -pub trait Buf { +pub trait Buf<'a> { fn advance(&mut self, cnt: usize); fn get_uint(&mut self, n: usize) -> io::Result; @@ -25,14 +25,14 @@ pub trait Buf { fn get_u64(&mut self) -> io::Result; - fn get_str(&mut self, len: usize) -> io::Result<&str>; + fn get_str(&mut self, len: usize) -> io::Result<&'a str>; - fn get_str_nul(&mut self) -> io::Result<&str>; + fn get_str_nul(&mut self) -> io::Result<&'a str>; - fn get_bytes(&mut self, len: usize) -> io::Result<&[u8]>; + fn get_bytes(&mut self, len: usize) -> io::Result<&'a [u8]>; } -impl<'a> Buf for &'a [u8] { +impl<'a> Buf<'a> for &'a [u8] { fn advance(&mut self, cnt: usize) { *self = &self[cnt..]; } @@ -107,19 +107,19 @@ impl<'a> Buf for &'a [u8] { Ok(val) } - fn get_str(&mut self, len: usize) -> io::Result<&str> { + fn get_str(&mut self, len: usize) -> io::Result<&'a str> { str::from_utf8(self.get_bytes(len)?) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) } - fn get_str_nul(&mut self) -> io::Result<&str> { + fn get_str_nul(&mut self) -> io::Result<&'a str> { let len = memchr(b'\0', &*self).ok_or(io::ErrorKind::InvalidData)?; let s = &self.get_str(len + 1)?[..len]; Ok(s) } - fn get_bytes(&mut self, len: usize) -> io::Result<&[u8]> { + fn get_bytes(&mut self, len: usize) -> io::Result<&'a [u8]> { let buf = &self[..len]; self.advance(len); diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index aeae57b6c3..c33d5fabad 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -45,6 +45,7 @@ where } } + #[cfg(feature = "postgres")] #[inline] pub fn buffer<'c>(&'c self) -> &'c [u8] { &self.rbuf[self.rbuf_rindex..] diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 381a0ce787..405242f197 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -1,6 +1,10 @@ //! Core of SQLx, the rust SQL toolkit. Not intended to be used directly. -#![forbid(unsafe_code)] +// When compiling with support for SQLite we must allow some unsafe code in order to +// interface with the inherently unsafe C module. This unsafe code is contained +// to the sqlite module. +#![cfg_attr(feature = "sqlite", deny(unsafe_code))] +#![cfg_attr(not(feature = "sqlite"), forbid(unsafe_code))] #![recursion_limit = "512"] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -11,6 +15,8 @@ pub mod error; #[macro_use] mod io; +mod maybe_owned; + pub mod connection; pub mod cursor; pub mod database; @@ -48,6 +54,10 @@ pub mod mysql; #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] pub mod postgres; +#[cfg(feature = "sqlite")] +#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] +pub mod sqlite; + pub use error::{Error, Result}; // Named Lifetimes: diff --git a/sqlx-core/src/maybe_owned.rs b/sqlx-core/src/maybe_owned.rs new file mode 100644 index 0000000000..a743f951e7 --- /dev/null +++ b/sqlx-core/src/maybe_owned.rs @@ -0,0 +1,52 @@ +use core::borrow::{Borrow, BorrowMut}; +use core::ops::{Deref, DerefMut}; + +pub(crate) enum MaybeOwned { + #[allow(dead_code)] + Borrowed(B), + + #[allow(dead_code)] + Owned(O), +} + +impl MaybeOwned { + #[allow(dead_code)] + pub(crate) fn resolve<'a, 'b: 'a>(&'a mut self, collection: &'b mut Vec) -> &'a mut O { + match self { + MaybeOwned::Owned(ref mut val) => val, + MaybeOwned::Borrowed(index) => &mut collection[*index], + } + } +} + +impl<'a, O, B> From<&'a mut B> for MaybeOwned { + fn from(val: &'a mut B) -> Self { + MaybeOwned::Borrowed(val) + } +} + +impl Deref for MaybeOwned +where + O: Borrow, +{ + type Target = B; + + fn deref(&self) -> &Self::Target { + match self { + MaybeOwned::Borrowed(val) => val, + MaybeOwned::Owned(ref val) => val.borrow(), + } + } +} + +impl DerefMut for MaybeOwned +where + O: BorrowMut, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + MaybeOwned::Borrowed(val) => val, + MaybeOwned::Owned(ref mut val) => val.borrow_mut(), + } + } +} diff --git a/sqlx-core/src/mysql/connection.rs b/sqlx-core/src/mysql/connection.rs index d9dcb4da1a..24dc3c4f22 100644 --- a/sqlx-core/src/mysql/connection.rs +++ b/sqlx-core/src/mysql/connection.rs @@ -8,7 +8,7 @@ use sha1::Sha1; use crate::connection::{Connect, Connection}; use crate::executor::Executor; use crate::mysql::protocol::{ - AuthPlugin, AuthSwitch, Capabilities, ComPing, Decode, Handshake, HandshakeResponse, + AuthPlugin, AuthSwitch, Capabilities, ComPing, Handshake, HandshakeResponse, }; use crate::mysql::stream::MySqlStream; use crate::mysql::util::xor_eq; @@ -149,7 +149,7 @@ async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> { // Read a [Handshake] packet. When connecting to the database server, this is immediately // received from the database server. - let handshake = Handshake::decode(stream.receive().await?)?; + let handshake = Handshake::read(stream.receive().await?)?; let mut auth_plugin = handshake.auth_plugin; let mut auth_plugin_data = handshake.auth_plugin_data; @@ -202,7 +202,7 @@ async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> { // AUTH_SWITCH 0xFE => { - let auth = AuthSwitch::decode(packet)?; + let auth = AuthSwitch::read(packet)?; auth_plugin = auth.auth_plugin; auth_plugin_data = auth.auth_plugin_data; diff --git a/sqlx-core/src/mysql/cursor.rs b/sqlx-core/src/mysql/cursor.rs index 936f38d93e..49572931cb 100644 --- a/sqlx-core/src/mysql/cursor.rs +++ b/sqlx-core/src/mysql/cursor.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use futures_core::future::BoxFuture; -use crate::connection::{ConnectionSource, MaybeOwnedConnection}; +use crate::connection::ConnectionSource; use crate::cursor::Cursor; use crate::executor::Execute; -use crate::mysql::protocol::{ColumnCount, ColumnDefinition, Decode, Row, Status, TypeId}; +use crate::mysql::protocol::{ColumnCount, ColumnDefinition, Row, Status, TypeId}; use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow}; use crate::pool::Pool; @@ -21,7 +21,6 @@ pub struct MySqlCursor<'c, 'q> { impl<'c, 'q> Cursor<'c, 'q> for MySqlCursor<'c, 'q> { type Database = MySql; - #[doc(hidden)] fn from_pool(pool: &Pool, query: E) -> Self where Self: Sized, @@ -36,15 +35,13 @@ impl<'c, 'q> Cursor<'c, 'q> for MySqlCursor<'c, 'q> { } } - #[doc(hidden)] - fn from_connection(conn: C, query: E) -> Self + fn from_connection(conn: &'c mut MySqlConnection, query: E) -> Self where Self: Sized, - C: Into>, E: Execute<'q, MySql>, { Self { - source: ConnectionSource::Connection(conn.into()), + source: ConnectionSource::ConnectionRef(conn), column_names: Arc::default(), column_types: Vec::new(), binary: true, @@ -60,7 +57,7 @@ impl<'c, 'q> Cursor<'c, 'q> for MySqlCursor<'c, 'q> { async fn next<'a, 'c: 'a, 'q: 'a>( cursor: &'a mut MySqlCursor<'c, 'q>, ) -> crate::Result>> { - let mut conn = cursor.source.resolve_by_ref().await?; + let mut conn = cursor.source.resolve().await?; // The first time [next] is called we need to actually execute our // contained query. We guard against this happening on _all_ next calls @@ -109,7 +106,7 @@ async fn next<'a, 'c: 'a, 'q: 'a>( // At the start of the results we expect to see a // COLUMN_COUNT followed by N COLUMN_DEF - let cc = ColumnCount::decode(conn.stream.packet())?; + let cc = ColumnCount::read(conn.stream.packet())?; // We use these definitions to get the actual column types that is critical // in parsing the rows coming back soon @@ -120,7 +117,7 @@ async fn next<'a, 'c: 'a, 'q: 'a>( let mut column_names = HashMap::with_capacity(cc.columns as usize); for i in 0..cc.columns { - let column = ColumnDefinition::decode(conn.stream.receive().await?)?; + let column = ColumnDefinition::read(conn.stream.receive().await?)?; cursor.column_types.push(column.type_id); @@ -142,15 +139,12 @@ async fn next<'a, 'c: 'a, 'q: 'a>( conn.stream.packet(), &cursor.column_types, &mut conn.current_row_values, - // TODO: Text mode cursor.binary, )?; let row = MySqlRow { row, columns: Arc::clone(&cursor.column_names), - // TODO: Text mode - binary: cursor.binary, }; return Ok(Some(row)); diff --git a/sqlx-core/src/mysql/database.rs b/sqlx-core/src/mysql/database.rs index 62a228985c..e2233bbf24 100644 --- a/sqlx-core/src/mysql/database.rs +++ b/sqlx-core/src/mysql/database.rs @@ -11,6 +11,8 @@ impl Database for MySql { type TypeInfo = super::MySqlTypeInfo; type TableId = Box; + + type RawBuffer = Vec; } impl<'c> HasRow<'c> for MySql { diff --git a/sqlx-core/src/mysql/error.rs b/sqlx-core/src/mysql/error.rs index 8ff1cb5d2b..3d9aa44fa0 100644 --- a/sqlx-core/src/mysql/error.rs +++ b/sqlx-core/src/mysql/error.rs @@ -1,12 +1,23 @@ +use std::fmt::{self, Display}; + use crate::error::DatabaseError; use crate::mysql::protocol::ErrPacket; +#[derive(Debug)] pub struct MySqlError(pub(super) ErrPacket); +impl Display for MySqlError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad(self.message()) + } +} + impl DatabaseError for MySqlError { fn message(&self) -> &str { &*self.0.error_message } -} -impl_fmt_error!(MySqlError); + fn code(&self) -> Option<&str> { + self.0.sql_state.as_deref() + } +} diff --git a/sqlx-core/src/mysql/executor.rs b/sqlx-core/src/mysql/executor.rs index 5a9a26f1c5..77fd2ef428 100644 --- a/sqlx-core/src/mysql/executor.rs +++ b/sqlx-core/src/mysql/executor.rs @@ -4,8 +4,8 @@ use crate::cursor::Cursor; use crate::describe::{Column, Describe}; use crate::executor::{Execute, Executor, RefExecutor}; use crate::mysql::protocol::{ - self, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare, ComStmtPrepareOk, Decode, - FieldFlags, Status, + self, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare, ComStmtPrepareOk, FieldFlags, + Status, }; use crate::mysql::{MySql, MySqlArguments, MySqlCursor, MySqlTypeInfo}; @@ -49,12 +49,12 @@ impl super::MySqlConnection { return self.stream.handle_err(); } - ComStmtPrepareOk::decode(packet) + ComStmtPrepareOk::read(packet) } async fn drop_column_defs(&mut self, count: usize) -> crate::Result<()> { for _ in 0..count { - let _column = ColumnDefinition::decode(self.stream.receive().await?)?; + let _column = ColumnDefinition::read(self.stream.receive().await?)?; } if count > 0 { @@ -168,7 +168,7 @@ impl super::MySqlConnection { let mut result_columns = Vec::with_capacity(stmt.columns as usize); for _ in 0..stmt.params { - let param = ColumnDefinition::decode(self.stream.receive().await?)?; + let param = ColumnDefinition::read(self.stream.receive().await?)?; param_types.push(MySqlTypeInfo::from_column_def(¶m)); } @@ -177,7 +177,7 @@ impl super::MySqlConnection { } for _ in 0..stmt.columns { - let column = ColumnDefinition::decode(self.stream.receive().await?)?; + let column = ColumnDefinition::read(self.stream.receive().await?)?; result_columns.push(Column:: { type_info: MySqlTypeInfo::from_column_def(&column), @@ -202,7 +202,10 @@ impl super::MySqlConnection { impl Executor for super::MySqlConnection { type Database = MySql; - fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result> + fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( + &'c mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> where E: Execute<'q, Self::Database>, { diff --git a/sqlx-core/src/mysql/protocol/auth_switch.rs b/sqlx-core/src/mysql/protocol/auth_switch.rs index 8c7315beee..23e7dcfa7d 100644 --- a/sqlx-core/src/mysql/protocol/auth_switch.rs +++ b/sqlx-core/src/mysql/protocol/auth_switch.rs @@ -1,17 +1,15 @@ -use byteorder::LittleEndian; - use crate::io::Buf; -use crate::mysql::protocol::{AuthPlugin, Capabilities, Decode, Status}; +use crate::mysql::protocol::AuthPlugin; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html #[derive(Debug)] -pub struct AuthSwitch { - pub auth_plugin: AuthPlugin, - pub auth_plugin_data: Box<[u8]>, +pub(crate) struct AuthSwitch { + pub(crate) auth_plugin: AuthPlugin, + pub(crate) auth_plugin_data: Box<[u8]>, } -impl Decode for AuthSwitch { - fn decode(mut buf: &[u8]) -> crate::Result +impl AuthSwitch { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result where Self: Sized, { diff --git a/sqlx-core/src/mysql/protocol/column_count.rs b/sqlx-core/src/mysql/protocol/column_count.rs index 85bf08915a..3ed537d497 100644 --- a/sqlx-core/src/mysql/protocol/column_count.rs +++ b/sqlx-core/src/mysql/protocol/column_count.rs @@ -1,16 +1,14 @@ use byteorder::LittleEndian; -use crate::io::Buf; use crate::mysql::io::BufExt; -use crate::mysql::protocol::Decode; #[derive(Debug)] pub struct ColumnCount { pub columns: u64, } -impl Decode for ColumnCount { - fn decode(mut buf: &[u8]) -> crate::Result { +impl ColumnCount { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { let columns = buf.get_uint_lenenc::()?.unwrap_or(0); Ok(Self { columns }) diff --git a/sqlx-core/src/mysql/protocol/column_def.rs b/sqlx-core/src/mysql/protocol/column_def.rs index b3fb6f7cbd..e20a9e414e 100644 --- a/sqlx-core/src/mysql/protocol/column_def.rs +++ b/sqlx-core/src/mysql/protocol/column_def.rs @@ -2,7 +2,7 @@ use byteorder::LittleEndian; use crate::io::Buf; use crate::mysql::io::BufExt; -use crate::mysql::protocol::{Decode, FieldFlags, TypeId}; +use crate::mysql::protocol::{FieldFlags, TypeId}; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html // https://mariadb.com/kb/en/resultset/#column-definition-packet @@ -33,8 +33,8 @@ impl ColumnDefinition { } } -impl Decode for ColumnDefinition { - fn decode(mut buf: &[u8]) -> crate::Result { +impl ColumnDefinition { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { // catalog : string let catalog = buf.get_str_lenenc::()?; diff --git a/sqlx-core/src/mysql/protocol/com_ping.rs b/sqlx-core/src/mysql/protocol/com_ping.rs index 8ebfed8767..a90ce62bad 100644 --- a/sqlx-core/src/mysql/protocol/com_ping.rs +++ b/sqlx-core/src/mysql/protocol/com_ping.rs @@ -1,7 +1,4 @@ -use byteorder::LittleEndian; - use crate::io::BufMut; -use crate::mysql::io::BufMutExt; use crate::mysql::protocol::{Capabilities, Encode}; // https://dev.mysql.com/doc/internals/en/com-ping.html diff --git a/sqlx-core/src/mysql/protocol/com_query.rs b/sqlx-core/src/mysql/protocol/com_query.rs index 512caa7106..0a8a20557c 100644 --- a/sqlx-core/src/mysql/protocol/com_query.rs +++ b/sqlx-core/src/mysql/protocol/com_query.rs @@ -1,7 +1,4 @@ -use byteorder::LittleEndian; - use crate::io::BufMut; -use crate::mysql::io::BufMutExt; use crate::mysql::protocol::{Capabilities, Encode}; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query.html diff --git a/sqlx-core/src/mysql/protocol/com_set_option.rs b/sqlx-core/src/mysql/protocol/com_set_option.rs deleted file mode 100644 index 6ce7545cbc..0000000000 --- a/sqlx-core/src/mysql/protocol/com_set_option.rs +++ /dev/null @@ -1,29 +0,0 @@ -use byteorder::LittleEndian; - -use crate::io::BufMut; -use crate::mysql::io::BufMutExt; -use crate::mysql::protocol::{Capabilities, Encode}; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a53f60000da139fc7d547db96635a2c02 -#[derive(Debug, Copy, Clone)] -#[repr(u16)] -pub enum SetOption { - MultiStatementsOn = 0x00, - MultiStatementsOff = 0x01, -} - -// https://dev.mysql.com/doc/internals/en/com-set-option.html -#[derive(Debug)] -pub struct ComSetOption { - pub option: SetOption, -} - -impl Encode for ComSetOption { - fn encode(&self, buf: &mut Vec, _: Capabilities) { - // COM_SET_OPTION : int<1> - buf.put_u8(0x1a); - - // option : int<2> - buf.put_u16::(self.option as u16); - } -} diff --git a/sqlx-core/src/mysql/protocol/com_stmt_execute.rs b/sqlx-core/src/mysql/protocol/com_stmt_execute.rs index 591a7e9acd..3c6d36af02 100644 --- a/sqlx-core/src/mysql/protocol/com_stmt_execute.rs +++ b/sqlx-core/src/mysql/protocol/com_stmt_execute.rs @@ -1,7 +1,6 @@ use byteorder::LittleEndian; use crate::io::BufMut; -use crate::mysql::io::BufMutExt; use crate::mysql::protocol::{Capabilities, Encode}; use crate::mysql::types::MySqlTypeInfo; @@ -27,7 +26,7 @@ pub struct ComStmtExecute<'a> { } impl Encode for ComStmtExecute<'_> { - fn encode(&self, buf: &mut Vec, capabilities: Capabilities) { + fn encode(&self, buf: &mut Vec, _: Capabilities) { // COM_STMT_EXECUTE : int<1> buf.put_u8(0x17); diff --git a/sqlx-core/src/mysql/protocol/com_stmt_prepare.rs b/sqlx-core/src/mysql/protocol/com_stmt_prepare.rs index 4bb95e1dcb..7437133700 100644 --- a/sqlx-core/src/mysql/protocol/com_stmt_prepare.rs +++ b/sqlx-core/src/mysql/protocol/com_stmt_prepare.rs @@ -1,7 +1,4 @@ -use byteorder::LittleEndian; - use crate::io::BufMut; -use crate::mysql::io::BufMutExt; use crate::mysql::protocol::{Capabilities, Encode}; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html diff --git a/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs b/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs index d2814582c8..ae34d2b6f0 100644 --- a/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs +++ b/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs @@ -1,27 +1,25 @@ use byteorder::LittleEndian; use crate::io::Buf; -use crate::mysql::io::BufExt; -use crate::mysql::protocol::Decode; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response_ok #[derive(Debug)] -pub struct ComStmtPrepareOk { - pub statement_id: u32, +pub(crate) struct ComStmtPrepareOk { + pub(crate) statement_id: u32, /// Number of columns in the returned result set (or 0 if statement /// does not return result set). - pub columns: u16, + pub(crate) columns: u16, /// Number of prepared statement parameters ('?' placeholders). - pub params: u16, + pub(crate) params: u16, /// Number of warnings. - pub warnings: u16, + pub(crate) warnings: u16, } -impl Decode for ComStmtPrepareOk { - fn decode(mut buf: &[u8]) -> crate::Result { +impl ComStmtPrepareOk { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { let header = buf.get_u8()?; if header != 0x00 { diff --git a/sqlx-core/src/mysql/protocol/decode.rs b/sqlx-core/src/mysql/protocol/decode.rs deleted file mode 100644 index 5a8dd601b1..0000000000 --- a/sqlx-core/src/mysql/protocol/decode.rs +++ /dev/null @@ -1,7 +0,0 @@ -use std::io; - -pub trait Decode { - fn decode(buf: &[u8]) -> crate::Result - where - Self: Sized; -} diff --git a/sqlx-core/src/mysql/protocol/encode.rs b/sqlx-core/src/mysql/protocol/encode.rs deleted file mode 100644 index 154a577c91..0000000000 --- a/sqlx-core/src/mysql/protocol/encode.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::io::BufMut; -use crate::mysql::protocol::Capabilities; - -pub trait Encode { - fn encode(&self, buf: &mut Vec, capabilities: Capabilities); -} - -impl Encode for &'_ [u8] { - fn encode(&self, buf: &mut Vec, _: Capabilities) { - buf.put_bytes(self); - } -} diff --git a/sqlx-core/src/mysql/protocol/eof.rs b/sqlx-core/src/mysql/protocol/eof.rs index 093b9f6dfc..f8e80031a4 100644 --- a/sqlx-core/src/mysql/protocol/eof.rs +++ b/sqlx-core/src/mysql/protocol/eof.rs @@ -1,8 +1,7 @@ use byteorder::LittleEndian; use crate::io::Buf; -use crate::mysql::io::BufExt; -use crate::mysql::protocol::{Capabilities, Decode, Status}; +use crate::mysql::protocol::Status; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_eof_packet.html // https://mariadb.com/kb/en/eof_packet/ @@ -12,8 +11,8 @@ pub struct EofPacket { pub status: Status, } -impl Decode for EofPacket { - fn decode(mut buf: &[u8]) -> crate::Result +impl EofPacket { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result where Self: Sized, { diff --git a/sqlx-core/src/mysql/protocol/err.rs b/sqlx-core/src/mysql/protocol/err.rs index 965d54134b..595560e743 100644 --- a/sqlx-core/src/mysql/protocol/err.rs +++ b/sqlx-core/src/mysql/protocol/err.rs @@ -1,8 +1,7 @@ use byteorder::LittleEndian; use crate::io::Buf; -use crate::mysql::io::BufExt; -use crate::mysql::protocol::{Capabilities, Decode, Status}; +use crate::mysql::protocol::Capabilities; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html // https://mariadb.com/kb/en/err_packet/ @@ -14,7 +13,7 @@ pub struct ErrPacket { } impl ErrPacket { - pub(crate) fn decode(mut buf: &[u8], capabilities: Capabilities) -> crate::Result + pub(crate) fn read(mut buf: &[u8], capabilities: Capabilities) -> crate::Result where Self: Sized, { @@ -50,7 +49,7 @@ impl ErrPacket { #[cfg(test)] mod tests { - use super::{Capabilities, Decode, ErrPacket, Status}; + use super::{Capabilities, ErrPacket}; const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order"; @@ -58,7 +57,7 @@ mod tests { #[test] fn it_decodes_packets_out_of_order() { - let mut p = ErrPacket::decode(ERR_PACKETS_OUT_OF_ORDER, Capabilities::PROTOCOL_41).unwrap(); + let p = ErrPacket::read(ERR_PACKETS_OUT_OF_ORDER, Capabilities::PROTOCOL_41).unwrap(); assert_eq!(&*p.error_message, "Got packets out of order"); assert_eq!(p.error_code, 1156); @@ -67,7 +66,7 @@ mod tests { #[test] fn it_decodes_ok_handshake() { - let mut p = ErrPacket::decode(ERR_HANDSHAKE_UNKNOWN_DB, Capabilities::PROTOCOL_41).unwrap(); + let p = ErrPacket::read(ERR_HANDSHAKE_UNKNOWN_DB, Capabilities::PROTOCOL_41).unwrap(); assert_eq!(p.error_code, 1049); assert_eq!(p.sql_state.as_deref(), Some("42000")); diff --git a/sqlx-core/src/mysql/protocol/handshake.rs b/sqlx-core/src/mysql/protocol/handshake.rs index 9bfe3046e9..89364d38f2 100644 --- a/sqlx-core/src/mysql/protocol/handshake.rs +++ b/sqlx-core/src/mysql/protocol/handshake.rs @@ -1,24 +1,24 @@ use byteorder::LittleEndian; use crate::io::Buf; -use crate::mysql::protocol::{AuthPlugin, Capabilities, Decode, Status}; +use crate::mysql::protocol::{AuthPlugin, Capabilities, Status}; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_v10.html // https://mariadb.com/kb/en/connection/#initial-handshake-packet #[derive(Debug)] -pub struct Handshake { - pub protocol_version: u8, - pub server_version: Box, - pub connection_id: u32, - pub server_capabilities: Capabilities, - pub server_default_collation: u8, - pub status: Status, - pub auth_plugin: AuthPlugin, - pub auth_plugin_data: Box<[u8]>, +pub(crate) struct Handshake { + pub(crate) protocol_version: u8, + pub(crate) server_version: Box, + pub(crate) connection_id: u32, + pub(crate) server_capabilities: Capabilities, + pub(crate) server_default_collation: u8, + pub(crate) status: Status, + pub(crate) auth_plugin: AuthPlugin, + pub(crate) auth_plugin_data: Box<[u8]>, } -impl Decode for Handshake { - fn decode(mut buf: &[u8]) -> crate::Result +impl Handshake { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result where Self: Sized, { @@ -68,7 +68,7 @@ impl Decode for Handshake { } else { // capability_flags_3 : int<4> let capabilities_3 = buf.get_u32::()?; - capabilities |= Capabilities::from_bits_truncate((capabilities_2 as u64) << 32); + capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32); } if capabilities.contains(Capabilities::SECURE_CONNECTION) { @@ -102,15 +102,15 @@ impl Decode for Handshake { #[cfg(test)] mod tests { - use super::{AuthPlugin, Capabilities, Decode, Handshake, Status}; + use super::{AuthPlugin, Capabilities, Handshake, Status}; use matches::assert_matches; const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\", capabilities: Capabilities); +} + +impl Encode for &'_ [u8] { + fn encode(&self, buf: &mut Vec, _: Capabilities) { + use crate::io::BufMut; + + buf.put_bytes(self); + } +} diff --git a/sqlx-core/src/mysql/protocol/ok.rs b/sqlx-core/src/mysql/protocol/ok.rs index cd7986aa10..0639d09dd7 100644 --- a/sqlx-core/src/mysql/protocol/ok.rs +++ b/sqlx-core/src/mysql/protocol/ok.rs @@ -2,21 +2,21 @@ use byteorder::LittleEndian; use crate::io::Buf; use crate::mysql::io::BufExt; -use crate::mysql::protocol::{Capabilities, Decode, Status}; +use crate::mysql::protocol::Status; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_ok_packet.html // https://mariadb.com/kb/en/ok_packet/ #[derive(Debug)] -pub struct OkPacket { - pub affected_rows: u64, - pub last_insert_id: u64, - pub status: Status, - pub warnings: u16, - pub info: Box, +pub(crate) struct OkPacket { + pub(crate) affected_rows: u64, + pub(crate) last_insert_id: u64, + pub(crate) status: Status, + pub(crate) warnings: u16, + pub(crate) info: Box, } -impl Decode for OkPacket { - fn decode(mut buf: &[u8]) -> crate::Result +impl OkPacket { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result where Self: Sized, { @@ -46,13 +46,13 @@ impl Decode for OkPacket { #[cfg(test)] mod tests { - use super::{Capabilities, Decode, OkPacket, Status}; + use super::{OkPacket, Status}; const OK_HANDSHAKE: &[u8] = b"\x00\x00\x00\x02@\x00\x00"; #[test] fn it_decodes_ok_handshake() { - let mut p = OkPacket::decode(OK_HANDSHAKE).unwrap(); + let p = OkPacket::read(OK_HANDSHAKE).unwrap(); assert_eq!(p.affected_rows, 0); assert_eq!(p.last_insert_id, 0); diff --git a/sqlx-core/src/mysql/protocol/row.rs b/sqlx-core/src/mysql/protocol/row.rs index 2c275a2297..384db51c8c 100644 --- a/sqlx-core/src/mysql/protocol/row.rs +++ b/sqlx-core/src/mysql/protocol/row.rs @@ -3,21 +3,20 @@ use std::ops::Range; use byteorder::{ByteOrder, LittleEndian}; use crate::io::Buf; -use crate::mysql::io::BufExt; -use crate::mysql::protocol::{Decode, TypeId}; +use crate::mysql::protocol::TypeId; -pub struct Row<'c> { +pub(crate) struct Row<'c> { buffer: &'c [u8], values: &'c [Option>], - binary: bool, + pub(crate) binary: bool, } impl<'c> Row<'c> { - pub fn len(&self) -> usize { + pub(crate) fn len(&self) -> usize { self.values.len() } - pub fn get(&self, index: usize) -> Option<&'c [u8]> { + pub(crate) fn get(&self, index: usize) -> Option<&'c [u8]> { let range = self.values[index].as_ref()?; Some(&self.buffer[(range.start as usize)..(range.end as usize)]) @@ -54,13 +53,13 @@ fn get_lenenc(buf: &[u8]) -> (usize, Option) { } impl<'c> Row<'c> { - pub fn read( + pub(crate) fn read( mut buf: &'c [u8], columns: &[TypeId], values: &'c mut Vec>>, binary: bool, ) -> crate::Result { - let mut buffer = &*buf; + let buffer = &*buf; values.clear(); values.reserve(columns.len()); @@ -68,7 +67,7 @@ impl<'c> Row<'c> { if !binary { let mut index = 0; - for column_idx in 0..columns.len() { + for _ in 0..columns.len() { let (len_size, size) = get_lenenc(&buf[index..]); if let Some(size) = size { @@ -77,7 +76,7 @@ impl<'c> Row<'c> { values.push(None); } - index += (len_size + size.unwrap_or_default()); + index += len_size + size.unwrap_or_default(); } return Ok(Self { @@ -111,16 +110,16 @@ impl<'c> Row<'c> { if is_null { values.push(None); } else { - let size = match columns[column_idx] { - TypeId::TINY_INT => 1, - TypeId::SMALL_INT => 2, - TypeId::INT | TypeId::FLOAT => 4, - TypeId::BIG_INT | TypeId::DOUBLE => 8, + let (offset, size) = match columns[column_idx] { + TypeId::TINY_INT => (0, 1), + TypeId::SMALL_INT => (0, 2), + TypeId::INT | TypeId::FLOAT => (0, 4), + TypeId::BIG_INT | TypeId::DOUBLE => (0, 8), - TypeId::DATE => 5, - TypeId::TIME => 1 + buffer[index] as usize, + TypeId::DATE => (0, 5), + TypeId::TIME => (0, 1 + buffer[index] as usize), - TypeId::TIMESTAMP | TypeId::DATETIME => 1 + buffer[index] as usize, + TypeId::TIMESTAMP | TypeId::DATETIME => (0, 1 + buffer[index] as usize), TypeId::TINY_BLOB | TypeId::MEDIUM_BLOB @@ -130,7 +129,7 @@ impl<'c> Row<'c> { | TypeId::VAR_CHAR => { let (len_size, len) = get_lenenc(&buffer[index..]); - len_size + len.unwrap_or_default() + (len_size, len.unwrap_or_default()) } id => { @@ -138,8 +137,8 @@ impl<'c> Row<'c> { } }; - values.push(Some(index..(index + size))); - index += size; + values.push(Some((index + offset)..(index + offset + size))); + index += size + offset; } } diff --git a/sqlx-core/src/mysql/protocol/ssl_request.rs b/sqlx-core/src/mysql/protocol/ssl_request.rs index fea5ef8e33..bbf781c657 100644 --- a/sqlx-core/src/mysql/protocol/ssl_request.rs +++ b/sqlx-core/src/mysql/protocol/ssl_request.rs @@ -1,8 +1,7 @@ use byteorder::LittleEndian; use crate::io::BufMut; -use crate::mysql::io::BufMutExt; -use crate::mysql::protocol::{AuthPlugin, Capabilities, Encode}; +use crate::mysql::protocol::{Capabilities, Encode}; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest diff --git a/sqlx-core/src/mysql/protocol/type.rs b/sqlx-core/src/mysql/protocol/type.rs index b242ed02b0..38ecb453ed 100644 --- a/sqlx-core/src/mysql/protocol/type.rs +++ b/sqlx-core/src/mysql/protocol/type.rs @@ -49,7 +49,7 @@ type_id_consts! { pub const SMALL_INT: TypeId = TypeId(2); pub const INT: TypeId = TypeId(3); pub const BIG_INT: TypeId = TypeId(8); - pub const MEDIUM_INT: TypeId = TypeId(9); + // pub const MEDIUM_INT: TypeId = TypeId(9); // Numeric: FLOAT, DOUBLE pub const FLOAT: TypeId = TypeId(4); diff --git a/sqlx-core/src/mysql/row.rs b/sqlx-core/src/mysql/row.rs index 72338c60fb..398c44c70e 100644 --- a/sqlx-core/src/mysql/row.rs +++ b/sqlx-core/src/mysql/row.rs @@ -28,7 +28,6 @@ impl<'c> TryFrom>> for MySqlValue<'c> { pub struct MySqlRow<'c> { pub(super) row: protocol::Row<'c>, pub(super) columns: Arc, u16>>, - pub(super) binary: bool, } impl<'c> Row<'c> for MySqlRow<'c> { @@ -38,14 +37,15 @@ impl<'c> Row<'c> for MySqlRow<'c> { self.row.len() } - fn try_get_raw<'r, I>(&'r self, index: I) -> crate::Result>> + fn try_get_raw<'r, I>(&'r self, index: I) -> crate::Result>> where + 'c: 'r, I: ColumnIndex, { let index = index.resolve(self)?; Ok(self.row.get(index).map(|buf| { - if self.binary { + if self.row.binary { MySqlValue::Binary(buf) } else { MySqlValue::Text(buf) diff --git a/sqlx-core/src/mysql/stream.rs b/sqlx-core/src/mysql/stream.rs index d5ad18bcca..4f85d74b4d 100644 --- a/sqlx-core/src/mysql/stream.rs +++ b/sqlx-core/src/mysql/stream.rs @@ -3,7 +3,7 @@ use std::net::Shutdown; use byteorder::{ByteOrder, LittleEndian}; use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream}; -use crate::mysql::protocol::{Capabilities, Decode, Encode, EofPacket, ErrPacket, OkPacket}; +use crate::mysql::protocol::{Capabilities, Encode, EofPacket, ErrPacket, OkPacket}; use crate::mysql::MySqlError; use crate::url::Url; @@ -163,7 +163,7 @@ impl MySqlStream { impl MySqlStream { pub(crate) async fn maybe_receive_eof(&mut self) -> crate::Result<()> { if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) { - let _eof = EofPacket::decode(self.receive().await?)?; + let _eof = EofPacket::read(self.receive().await?)?; } Ok(()) @@ -171,7 +171,7 @@ impl MySqlStream { pub(crate) fn maybe_handle_eof(&mut self) -> crate::Result> { if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) && self.packet()[0] == 0xFE { - Ok(Some(EofPacket::decode(self.packet())?)) + Ok(Some(EofPacket::read(self.packet())?)) } else { Ok(None) } @@ -182,10 +182,10 @@ impl MySqlStream { } pub(crate) fn handle_err(&mut self) -> crate::Result { - Err(MySqlError(ErrPacket::decode(self.packet(), self.capabilities)?).into()) + Err(MySqlError(ErrPacket::read(self.packet(), self.capabilities)?).into()) } pub(crate) fn handle_ok(&mut self) -> crate::Result { - OkPacket::decode(self.packet()) + OkPacket::read(self.packet()) } } diff --git a/sqlx-core/src/mysql/types/bytes.rs b/sqlx-core/src/mysql/types/bytes.rs index 685dd0aa96..2504ec9af9 100644 --- a/sqlx-core/src/mysql/types/bytes.rs +++ b/sqlx-core/src/mysql/types/bytes.rs @@ -2,7 +2,7 @@ use byteorder::LittleEndian; use crate::decode::Decode; use crate::encode::Encode; -use crate::mysql::io::{BufExt, BufMutExt}; +use crate::mysql::io::BufMutExt; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::{MySql, MySqlValue}; @@ -41,16 +41,7 @@ impl Encode for Vec { impl<'de> Decode<'de, MySql> for Vec { fn decode(value: Option>) -> crate::Result { match value.try_into()? { - MySqlValue::Binary(mut buf) => { - let len = buf - .get_uint_lenenc::() - .map_err(crate::Error::decode)? - .unwrap_or_default(); - - Ok((&buf[..(len as usize)]).to_vec()) - } - - MySqlValue::Text(s) => Ok(s.to_vec()), + MySqlValue::Binary(buf) | MySqlValue::Text(buf) => Ok(buf.to_vec()), } } } @@ -58,16 +49,7 @@ impl<'de> Decode<'de, MySql> for Vec { impl<'de> Decode<'de, MySql> for &'de [u8] { fn decode(value: Option>) -> crate::Result { match value.try_into()? { - MySqlValue::Binary(mut buf) => { - let len = buf - .get_uint_lenenc::() - .map_err(crate::Error::decode)? - .unwrap_or_default(); - - Ok(&buf[..(len as usize)]) - } - - MySqlValue::Text(s) => Ok(s), + MySqlValue::Binary(buf) | MySqlValue::Text(buf) => Ok(buf), } } } diff --git a/sqlx-core/src/mysql/types/str.rs b/sqlx-core/src/mysql/types/str.rs index e644b67111..cc79379a0a 100644 --- a/sqlx-core/src/mysql/types/str.rs +++ b/sqlx-core/src/mysql/types/str.rs @@ -4,7 +4,7 @@ use byteorder::LittleEndian; use crate::decode::Decode; use crate::encode::Encode; -use crate::mysql::io::{BufExt, BufMutExt}; +use crate::mysql::io::BufMutExt; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::{MySql, MySqlValue}; @@ -44,16 +44,9 @@ impl Encode for String { impl<'de> Decode<'de, MySql> for &'de str { fn decode(value: Option>) -> crate::Result { match value.try_into()? { - MySqlValue::Binary(mut buf) => { - let len = buf - .get_uint_lenenc::() - .map_err(crate::Error::decode)? - .unwrap_or_default(); - - from_utf8(&buf[..(len as usize)]).map_err(crate::Error::decode) + MySqlValue::Binary(buf) | MySqlValue::Text(buf) => { + from_utf8(buf).map_err(crate::Error::decode) } - - MySqlValue::Text(s) => from_utf8(s).map_err(crate::Error::decode), } } } diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index a45a9ca7cc..2d64ec9e31 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::{Borrow, BorrowMut}; use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::time::Instant; @@ -14,7 +15,7 @@ where C: Connect, { live: Option>, - pool: Arc>, + pub(crate) pool: Arc>, } pub(super) struct Live { @@ -35,6 +36,24 @@ pub(super) struct Floating<'p, C> { const DEREF_ERR: &str = "(bug) connection already released to pool"; +impl Borrow for PoolConnection +where + C: Connect, +{ + fn borrow(&self) -> &C { + &*self + } +} + +impl BorrowMut for PoolConnection +where + C: Connect, +{ + fn borrow_mut(&mut self) -> &mut C { + &mut *self + } +} + impl Deref for PoolConnection where C: Connect, diff --git a/sqlx-core/src/pool/executor.rs b/sqlx-core/src/pool/executor.rs index 705953afd3..5904eeb1f8 100644 --- a/sqlx-core/src/pool/executor.rs +++ b/sqlx-core/src/pool/executor.rs @@ -17,7 +17,10 @@ where { type Database = DB; - fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result> + fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( + &'c mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> where E: Execute<'q, Self::Database>, { @@ -65,7 +68,10 @@ where { type Database = C::Database; - fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result> + fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( + &'c mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> where E: Execute<'q, Self::Database>, { diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index d3e31aee5a..0dc679d542 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -19,7 +19,7 @@ use crate::{ use super::connection::{Floating, Idle, Live}; use super::Options; -pub(super) struct SharedPool { +pub(crate) struct SharedPool { url: String, idle_conns: ArrayQueue>, waiters: SegQueue, @@ -220,13 +220,14 @@ where // successfully established connection Ok(Ok(raw)) => Ok(Some(Floating::new_live(raw, guard))), - // IO error while connecting, this should definitely be logged - // and we should attempt to retry - Ok(Err(crate::Error::Io(e))) => { - log::warn!("error establishing a connection: {}", e); + // an IO error while connecting is assumed to be the system starting up + Ok(Err(crate::Error::Io(_))) => Ok(None), - Ok(None) - } + // TODO: Handle other database "boot period"s + + // [postgres] the database system is starting up + // TODO: Make this check actually check if this is postgres + Ok(Err(crate::Error::Database(error))) if error.code() == Some("57P03") => Ok(None), // Any other error while connection should immediately // terminate and bubble the error up diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index d8d35f41f5..1761d60e9a 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -21,7 +21,7 @@ pub use self::connection::PoolConnection; pub use self::options::Builder; /// A pool of database connections. -pub struct Pool(Arc>); +pub struct Pool(pub(crate) Arc>); impl Pool where diff --git a/sqlx-core/src/postgres/connection.rs b/sqlx-core/src/postgres/connection.rs index 791937d502..b5c5c03694 100644 --- a/sqlx-core/src/postgres/connection.rs +++ b/sqlx-core/src/postgres/connection.rs @@ -9,8 +9,8 @@ use futures_util::TryFutureExt; use crate::connection::{Connect, Connection}; use crate::executor::Executor; use crate::postgres::protocol::{ - Authentication, AuthenticationMd5, AuthenticationSasl, Message, PasswordMessage, - StartupMessage, StatementId, Terminate, TypeFormat, + Authentication, AuthenticationMd5, AuthenticationSasl, BackendKeyData, Message, + PasswordMessage, StartupMessage, StatementId, Terminate, TypeFormat, }; use crate::postgres::stream::PgStream; use crate::postgres::{sasl, tls}; @@ -88,10 +88,17 @@ pub struct PgConnection { // Work buffer for the value ranges of the current row // This is used as the backing memory for each Row's value indexes pub(super) current_row_values: Vec>>, + + // TODO: Find a use for these values. Perhaps in a debug impl of PgConnection? + #[allow(dead_code)] + process_id: u32, + + #[allow(dead_code)] + secret_key: u32, } // https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3 -async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<()> { +async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result { // Defaults to postgres@.../postgres let username = url.username().unwrap_or("postgres"); let database = url.database().unwrap_or("postgres"); @@ -118,8 +125,13 @@ async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<()> { stream.write(StartupMessage { params }); stream.flush().await?; + let mut key_data = BackendKeyData { + process_id: 0, + secret_key: 0, + }; + loop { - match stream.read().await? { + match stream.receive().await? { Message::Authentication => match Authentication::read(stream.buffer())? { Authentication::Ok => { // do nothing. no password is needed to continue. @@ -193,7 +205,7 @@ async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<()> { Message::BackendKeyData => { // do nothing. we do not care about the server values here. - // todo: we should care and store these on the connection + key_data = BackendKeyData::read(stream.buffer())?; } Message::ParameterStatus => { @@ -212,7 +224,7 @@ async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<()> { } } - Ok(()) + Ok(key_data) } // https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.10 @@ -230,7 +242,7 @@ impl PgConnection { let mut stream = PgStream::new(&url).await?; tls::request_if_needed(&mut stream, &url).await?; - startup(&mut stream, &url).await?; + let key_data = startup(&mut stream, &url).await?; Ok(Self { stream, @@ -240,6 +252,8 @@ impl PgConnection { cache_statement: HashMap::new(), cache_statement_columns: HashMap::new(), cache_statement_formats: HashMap::new(), + process_id: key_data.process_id, + secret_key: key_data.secret_key, }) } } diff --git a/sqlx-core/src/postgres/cursor.rs b/sqlx-core/src/postgres/cursor.rs index 08f9906c97..fcfa444c3a 100644 --- a/sqlx-core/src/postgres/cursor.rs +++ b/sqlx-core/src/postgres/cursor.rs @@ -3,11 +3,13 @@ use std::sync::Arc; use futures_core::future::BoxFuture; -use crate::connection::{ConnectionSource, MaybeOwnedConnection}; +use crate::connection::ConnectionSource; use crate::cursor::Cursor; use crate::executor::Execute; use crate::pool::Pool; -use crate::postgres::protocol::{DataRow, Message, RowDescription, StatementId, TypeFormat}; +use crate::postgres::protocol::{ + DataRow, Message, ReadyForQuery, RowDescription, StatementId, TypeFormat, +}; use crate::postgres::{PgArguments, PgConnection, PgRow, Postgres}; pub struct PgCursor<'c, 'q> { @@ -20,7 +22,6 @@ pub struct PgCursor<'c, 'q> { impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> { type Database = Postgres; - #[doc(hidden)] fn from_pool(pool: &Pool, query: E) -> Self where Self: Sized, @@ -34,15 +35,13 @@ impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> { } } - #[doc(hidden)] - fn from_connection(conn: C, query: E) -> Self + fn from_connection(conn: &'c mut PgConnection, query: E) -> Self where Self: Sized, - C: Into>, E: Execute<'q, Postgres>, { Self { - source: ConnectionSource::Connection(conn.into()), + source: ConnectionSource::ConnectionRef(conn), columns: Arc::default(), formats: Arc::new([] as [TypeFormat; 0]), query: Some(query.into_parts()), @@ -78,7 +77,7 @@ async fn expect_desc( conn: &mut PgConnection, ) -> crate::Result<(HashMap, usize>, Vec)> { let description: Option<_> = loop { - match conn.stream.read().await? { + match conn.stream.receive().await? { Message::ParseComplete | Message::BindComplete => {} Message::RowDescription => { @@ -126,7 +125,7 @@ async fn get_or_describe( async fn next<'a, 'c: 'a, 'q: 'a>( cursor: &'a mut PgCursor<'c, 'q>, ) -> crate::Result>> { - let mut conn = cursor.source.resolve_by_ref().await?; + let mut conn = cursor.source.resolve().await?; // The first time [next] is called we need to actually execute our // contained query. We guard against this happening on _all_ next calls @@ -149,7 +148,7 @@ async fn next<'a, 'c: 'a, 'q: 'a>( } loop { - match conn.stream.read().await? { + match conn.stream.receive().await? { // Indicates that a phase of the extended query flow has completed // We as SQLx don't generally care as long as it is happening Message::ParseComplete | Message::BindComplete => {} @@ -159,6 +158,9 @@ async fn next<'a, 'c: 'a, 'q: 'a>( // Indicates that all queries have finished executing Message::ReadyForQuery => { + // TODO: How should we handle an ERROR status form ReadyForQuery + let _ready = ReadyForQuery::read(conn.stream.buffer())?; + conn.is_ready = true; break; } diff --git a/sqlx-core/src/postgres/database.rs b/sqlx-core/src/postgres/database.rs index 4709b9557e..399b2f6753 100644 --- a/sqlx-core/src/postgres/database.rs +++ b/sqlx-core/src/postgres/database.rs @@ -12,10 +12,11 @@ impl Database for Postgres { type TypeInfo = super::PgTypeInfo; type TableId = u32; + + type RawBuffer = Vec; } impl<'a> HasRow<'a> for Postgres { - // TODO: Can we drop the `type Database = _` type Database = Postgres; type Row = super::PgRow<'a>; diff --git a/sqlx-core/src/postgres/error.rs b/sqlx-core/src/postgres/error.rs index 537d0d2d1c..546ead3ea9 100644 --- a/sqlx-core/src/postgres/error.rs +++ b/sqlx-core/src/postgres/error.rs @@ -1,6 +1,9 @@ +use std::fmt::{self, Display}; + use crate::error::DatabaseError; use crate::postgres::protocol::Response; +#[derive(Debug)] pub struct PgError(pub(super) Response); impl DatabaseError for PgError { @@ -8,6 +11,10 @@ impl DatabaseError for PgError { &self.0.message } + fn code(&self) -> Option<&str> { + Some(&self.0.code) + } + fn details(&self) -> Option<&str> { self.0.detail.as_ref().map(|s| &**s) } @@ -29,4 +36,8 @@ impl DatabaseError for PgError { } } -impl_fmt_error!(PgError); +impl Display for PgError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad(self.message()) + } +} diff --git a/sqlx-core/src/postgres/executor.rs b/sqlx-core/src/postgres/executor.rs index c2d3fc608a..f17965ef7c 100644 --- a/sqlx-core/src/postgres/executor.rs +++ b/sqlx-core/src/postgres/executor.rs @@ -9,8 +9,8 @@ use crate::cursor::Cursor; use crate::describe::{Column, Describe}; use crate::executor::{Execute, Executor, RefExecutor}; use crate::postgres::protocol::{ - self, CommandComplete, Field, Message, ParameterDescription, RowDescription, StatementId, - TypeFormat, TypeId, + self, CommandComplete, Field, Message, ParameterDescription, ReadyForQuery, RowDescription, + StatementId, TypeFormat, TypeId, }; use crate::postgres::types::SharedStr; use crate::postgres::{PgArguments, PgConnection, PgCursor, PgRow, PgTypeInfo, Postgres}; @@ -73,7 +73,7 @@ impl PgConnection { if !self.is_ready { loop { - if let Message::ReadyForQuery = self.stream.read().await? { + if let Message::ReadyForQuery = self.stream.receive().await? { // we are now ready to go self.is_ready = true; break; @@ -136,7 +136,7 @@ impl PgConnection { Ok(statement) } - async fn describe<'e, 'q: 'e>( + async fn do_describe<'e, 'q: 'e>( &'e mut self, query: &'q str, ) -> crate::Result> { @@ -150,7 +150,7 @@ impl PgConnection { self.stream.flush().await?; let params = loop { - match self.stream.read().await? { + match self.stream.receive().await? { Message::ParseComplete => {} Message::ParameterDescription => { @@ -167,7 +167,7 @@ impl PgConnection { }; }; - let result = match self.stream.read().await? { + let result = match self.stream.receive().await? { Message::NoData => None, Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?), @@ -329,7 +329,7 @@ impl PgConnection { let mut rows = 0; loop { - match self.stream.read().await? { + match self.stream.receive().await? { Message::ParseComplete | Message::BindComplete | Message::NoData @@ -346,6 +346,9 @@ impl PgConnection { } Message::ReadyForQuery => { + // TODO: How should we handle an ERROR status form ReadyForQuery + let _ready = ReadyForQuery::read(self.stream.buffer())?; + self.is_ready = true; break; } @@ -365,7 +368,10 @@ impl PgConnection { impl Executor for super::PgConnection { type Database = Postgres; - fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result> + fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( + &'c mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> where E: Execute<'q, Self::Database>, { @@ -391,7 +397,7 @@ impl Executor for super::PgConnection { where E: Execute<'q, Self::Database>, { - Box::pin(async move { self.describe(query.into_parts().0).await }) + Box::pin(async move { self.do_describe(query.into_parts().0).await }) } } diff --git a/sqlx-core/src/postgres/listen.rs b/sqlx-core/src/postgres/listen.rs new file mode 100644 index 0000000000..4c163a5fa8 --- /dev/null +++ b/sqlx-core/src/postgres/listen.rs @@ -0,0 +1,305 @@ +use std::collections::HashSet; +use std::fmt::{self, Debug}; +use std::io; + +use async_stream::try_stream; +use futures_channel::mpsc; +use futures_core::future::BoxFuture; +use futures_core::stream::Stream; + +use crate::describe::Describe; +use crate::executor::{Execute, Executor, RefExecutor}; +use crate::pool::{Pool, PoolConnection}; +use crate::postgres::protocol::{Message, NotificationResponse}; +use crate::postgres::{PgConnection, PgCursor, Postgres}; + +/// A stream of asynchronous notifications from Postgres. +/// +/// This listener will auto-reconnect. If the active +/// connection being used ever dies, this listener will detect that event, create a +/// new connection, will re-subscribe to all of the originally specified channels, and will resume +/// operations as normal. +pub struct PgListener { + pool: Pool, + connection: Option>, + buffer_rx: mpsc::UnboundedReceiver>, + buffer_tx: Option>>, + channels: Vec, +} + +/// An asynchronous notification from Postgres. +pub struct PgNotification<'c>(NotificationResponse<'c>); + +impl PgListener { + pub async fn new(url: &str) -> crate::Result { + // Create a pool of 1 without timeouts (as they don't apply here) + // We only use the pool to handle re-connections + let pool = Pool::::builder() + .max_size(1) + .max_lifetime(None) + .idle_timeout(None) + .build(url) + .await?; + + Self::from_pool(&pool).await + } + + pub async fn from_pool(pool: &Pool) -> crate::Result { + // Pull out an initial connection + let mut connection = pool.acquire().await?; + + // Setup a notification buffer + let (sender, receiver) = mpsc::unbounded(); + connection.stream.notifications = Some(sender); + + Ok(Self { + pool: pool.clone(), + connection: Some(connection), + buffer_rx: receiver, + buffer_tx: None, + channels: Vec::new(), + }) + } + + /// Starts listening for notifications on a channel. + pub async fn listen(&mut self, channel: &str) -> crate::Result<()> { + self.connection() + .execute(&*format!("LISTEN {}", ident(channel))) + .await?; + + self.channels.push(channel.to_owned()); + + Ok(()) + } + + /// Starts listening for notifications on all channels. + pub async fn listen_all( + &mut self, + channels: impl IntoIterator, + ) -> crate::Result<()> { + let beg = self.channels.len(); + self.channels.extend(channels.into_iter().map(|s| s.into())); + + self.connection + .as_mut() + .unwrap() + .execute(&*build_listen_all_query(&self.channels[beg..])) + .await?; + + Ok(()) + } + + /// Stops listening for notifications on a channel. + pub async fn unlisten(&mut self, channel: &str) -> crate::Result<()> { + self.connection() + .execute(&*format!("UNLISTEN {}", ident(channel))) + .await?; + + if let Some(pos) = self.channels.iter().position(|s| s == channel) { + self.channels.remove(pos); + } + + Ok(()) + } + + /// Stops listening for notifications on all channels. + pub async fn unlisten_all(&mut self) -> crate::Result<()> { + self.connection().execute("UNLISTEN *").await?; + + self.channels.clear(); + + Ok(()) + } + + #[inline] + async fn connect_if_needed(&mut self) -> crate::Result<()> { + if let None = self.connection { + let mut connection = self.pool.acquire().await?; + connection.stream.notifications = self.buffer_tx.take(); + + connection + .execute(&*build_listen_all_query(&self.channels)) + .await?; + + self.connection = Some(connection); + } + + Ok(()) + } + + #[inline] + fn connection(&mut self) -> &mut PgConnection { + self.connection.as_mut().unwrap() + } + + /// Receives the next notification available from any of the subscribed channels. + pub async fn recv(&mut self) -> crate::Result> { + // Flush the buffer first, if anything + // This would only fill up if this listener is used as a connection + if let Ok(Some(notification)) = self.buffer_rx.try_next() { + return Ok(PgNotification(notification)); + } + + loop { + // Ensure we have an active connection to work with. + self.connect_if_needed().await?; + + match self.connection().stream.read().await { + // We've received an async notification, return it. + Ok(Message::NotificationResponse) => { + let notification = + NotificationResponse::read(self.connection().stream.buffer())?; + + return Ok(PgNotification(notification)); + } + + // Mark the connection as ready for another query + Ok(Message::ReadyForQuery) => { + self.connection().is_ready = true; + } + + // Ignore unexpected messages + Ok(_) => {} + + // The connection is dead, ensure that it is dropped, + // update self state, and loop to try again. + Err(crate::Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + self.buffer_tx = self.connection().stream.notifications.take(); + self.connection = None; + } + + // Forward other errors + Err(error) => { + return Err(error); + } + } + } + } + + /// Consume this listener, returning a `Stream` of notifications. + pub fn into_stream( + mut self, + ) -> impl Stream>> + Unpin { + Box::pin(try_stream! { + loop { + let notification = self.recv().await?; + yield notification.into_owned(); + } + }) + } +} + +impl Executor for PgListener { + type Database = Postgres; + + fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( + &'c mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> + where + E: Execute<'q, Self::Database>, + { + self.connection().execute(query) + } + + fn fetch<'q, E>(&mut self, query: E) -> PgCursor<'_, 'q> + where + E: Execute<'q, Self::Database>, + { + self.connection().fetch(query) + } + + fn describe<'e, 'q, E: 'e>( + &'e mut self, + query: E, + ) -> BoxFuture<'e, crate::Result>> + where + E: Execute<'q, Self::Database>, + { + self.connection().describe(query) + } +} + +impl<'c> RefExecutor<'c> for &'c mut PgListener { + type Database = Postgres; + + fn fetch_by_ref<'q, E>(self, query: E) -> PgCursor<'c, 'q> + where + E: Execute<'q, Self::Database>, + { + self.connection().fetch_by_ref(query) + } +} + +impl PgNotification<'_> { + /// The process ID of the notifying backend process. + #[inline] + pub fn process_id(&self) -> u32 { + self.0.process_id + } + + /// The channel that the notify has been raised on. This can be thought + /// of as the message topic. + #[inline] + pub fn channel(&self) -> &str { + self.0.channel.as_ref() + } + + /// The payload of the notification. An empty payload is received as an + /// empty string. + #[inline] + pub fn payload(&self) -> &str { + self.0.payload.as_ref() + } + + fn into_owned(self) -> PgNotification<'static> { + PgNotification(self.0.into_owned()) + } +} + +impl Debug for PgNotification<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PgNotification") + .field("process_id", &self.process_id()) + .field("channel", &self.channel()) + .field("payload", &self.payload()) + .finish() + } +} + +fn ident(mut name: &str) -> String { + // If the input string contains a NUL byte, we should truncate the + // identifier. + if let Some(index) = name.find('\0') { + name = &name[..index]; + } + + // Any double quotes must be escaped + name.replace('"', "\"\"") +} + +fn build_listen_all_query(channels: impl IntoIterator>) -> String { + channels.into_iter().fold(String::new(), |mut acc, chan| { + acc.push_str(r#"LISTEN ""#); + acc.push_str(&ident(chan.as_ref())); + acc.push_str(r#"";"#); + acc + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_listen_all_query_with_single_channel() { + let output = build_listen_all_query(&["test"]); + assert_eq!(output.as_str(), r#"LISTEN "test";"#); + } + + #[test] + fn build_listen_all_query_with_multiple_channels() { + let output = build_listen_all_query(&["channel.0", "channel.1"]); + assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#); + } +} diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 7ae97c9e20..f1dd59974e 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -5,6 +5,7 @@ pub use connection::PgConnection; pub use cursor::PgCursor; pub use database::Postgres; pub use error::PgError; +pub use listen::{PgListener, PgNotification}; pub use row::{PgRow, PgValue}; pub use types::PgTypeInfo; @@ -14,6 +15,7 @@ mod cursor; mod database; mod error; mod executor; +mod listen; mod protocol; mod row; mod sasl; diff --git a/sqlx-core/src/postgres/protocol/authentication.rs b/sqlx-core/src/postgres/protocol/authentication.rs index 9916378d55..0c51489ca1 100644 --- a/sqlx-core/src/postgres/protocol/authentication.rs +++ b/sqlx-core/src/postgres/protocol/authentication.rs @@ -3,7 +3,7 @@ use byteorder::NetworkEndian; use std::str; #[derive(Debug)] -pub enum Authentication { +pub(crate) enum Authentication { /// The authentication exchange is successfully completed. Ok, @@ -54,7 +54,7 @@ pub enum Authentication { } impl Authentication { - pub fn read(mut buf: &[u8]) -> crate::Result { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { Ok(match buf.get_u32::()? { 0 => Authentication::Ok, 2 => Authentication::KerberosV5, @@ -76,12 +76,12 @@ impl Authentication { } #[derive(Debug)] -pub struct AuthenticationMd5 { - pub salt: [u8; 4], +pub(crate) struct AuthenticationMd5 { + pub(crate) salt: [u8; 4], } impl AuthenticationMd5 { - pub fn read(buf: &[u8]) -> crate::Result { + pub(crate) fn read(buf: &[u8]) -> crate::Result { let mut salt = [0_u8; 4]; salt.copy_from_slice(buf); @@ -90,12 +90,12 @@ impl AuthenticationMd5 { } #[derive(Debug)] -pub struct AuthenticationSasl { - pub mechanisms: Box<[Box]>, +pub(crate) struct AuthenticationSasl { + pub(crate) mechanisms: Box<[Box]>, } impl AuthenticationSasl { - pub fn read(mut buf: &[u8]) -> crate::Result { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { let mut mechanisms = Vec::new(); while buf[0] != 0 { @@ -109,15 +109,15 @@ impl AuthenticationSasl { } #[derive(Debug)] -pub struct AuthenticationSaslContinue { - pub salt: Vec, - pub iter_count: u32, - pub nonce: Vec, - pub data: String, +pub(crate) struct AuthenticationSaslContinue { + pub(crate) salt: Vec, + pub(crate) iter_count: u32, + pub(crate) nonce: Vec, + pub(crate) data: String, } impl AuthenticationSaslContinue { - pub fn read(buf: &[u8]) -> crate::Result { + pub(crate) fn read(buf: &[u8]) -> crate::Result { let mut salt: Vec = Vec::new(); let mut nonce: Vec = Vec::new(); let mut iter_count: u32 = 0; diff --git a/sqlx-core/src/postgres/protocol/backend_key_data.rs b/sqlx-core/src/postgres/protocol/backend_key_data.rs index 5c086efbce..69394a7743 100644 --- a/sqlx-core/src/postgres/protocol/backend_key_data.rs +++ b/sqlx-core/src/postgres/protocol/backend_key_data.rs @@ -1,4 +1,3 @@ -use super::Decode; use crate::io::Buf; use byteorder::NetworkEndian; @@ -11,8 +10,8 @@ pub struct BackendKeyData { pub secret_key: u32, } -impl Decode for BackendKeyData { - fn decode(mut buf: &[u8]) -> crate::Result { +impl BackendKeyData { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { let process_id = buf.get_u32::()?; let secret_key = buf.get_u32::()?; @@ -25,13 +24,13 @@ impl Decode for BackendKeyData { #[cfg(test)] mod tests { - use super::{BackendKeyData, Decode}; + use super::BackendKeyData; const BACKEND_KEY_DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; #[test] fn it_decodes_backend_key_data() { - let message = BackendKeyData::decode(BACKEND_KEY_DATA).unwrap(); + let message = BackendKeyData::read(BACKEND_KEY_DATA).unwrap(); assert_eq!(message.process_id, 10182); assert_eq!(message.secret_key, 2303903019); diff --git a/sqlx-core/src/postgres/protocol/bind.rs b/sqlx-core/src/postgres/protocol/bind.rs index 260da9f053..6624654b62 100644 --- a/sqlx-core/src/postgres/protocol/bind.rs +++ b/sqlx-core/src/postgres/protocol/bind.rs @@ -1,24 +1,24 @@ -use super::Encode; +use super::Write; use crate::io::BufMut; use crate::postgres::protocol::{StatementId, TypeFormat}; use byteorder::{ByteOrder, NetworkEndian}; -pub struct Bind<'a> { +pub(crate) struct Bind<'a> { /// The name of the destination portal (an empty string selects the unnamed portal). - pub portal: &'a str, + pub(crate) portal: &'a str, /// The id of the source prepared statement (0 selects the unnamed statement). - pub statement: StatementId, + pub(crate) statement: StatementId, /// The parameter format codes. Each must presently be zero (text) or one (binary). /// /// There can be zero to indicate that there are no parameters or that the parameters all use the /// default format (text); or one, in which case the specified format code is applied to all /// parameters; or it can equal the actual number of parameters. - pub formats: &'a [TypeFormat], + pub(crate) formats: &'a [TypeFormat], - pub values_len: i16, - pub values: &'a [u8], + pub(crate) values_len: i16, + pub(crate) values: &'a [u8], /// The result-column format codes. Each must presently be zero (text) or one (binary). /// @@ -26,11 +26,11 @@ pub struct Bind<'a> { /// result columns should all use the default format (text); or one, in which /// case the specified format code is applied to all result columns (if any); /// or it can equal the actual number of result columns of the query. - pub result_formats: &'a [TypeFormat], + pub(crate) result_formats: &'a [TypeFormat], } -impl Encode for Bind<'_> { - fn encode(&self, buf: &mut Vec) { +impl Write for Bind<'_> { + fn write(&self, buf: &mut Vec) { buf.push(b'B'); let pos = buf.len(); @@ -38,7 +38,7 @@ impl Encode for Bind<'_> { buf.put_str_nul(self.portal); - self.statement.encode(buf); + self.statement.write(buf); buf.put_i16::(self.formats.len() as i16); diff --git a/sqlx-core/src/postgres/protocol/cancel_request.rs b/sqlx-core/src/postgres/protocol/cancel_request.rs deleted file mode 100644 index 47d85c8fce..0000000000 --- a/sqlx-core/src/postgres/protocol/cancel_request.rs +++ /dev/null @@ -1,24 +0,0 @@ -use super::Encode; -use crate::io::BufMut; -use byteorder::NetworkEndian; - -/// Sent instead of [`StartupMessage`] with a new connection to cancel a running query on an existing -/// connection. -/// -/// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.9 -pub struct CancelRequest { - /// The process ID of the target database. - pub process_id: i32, - - /// The secret key for the target database. - pub secret_key: i32, -} - -impl Encode for CancelRequest { - fn encode(&self, buf: &mut Vec) { - buf.put_i32::(16); // message length - buf.put_i32::(8087_7102); // constant for cancel request - buf.put_i32::(self.process_id); - buf.put_i32::(self.secret_key); - } -} diff --git a/sqlx-core/src/postgres/protocol/command_complete.rs b/sqlx-core/src/postgres/protocol/command_complete.rs index 4cc7ab215b..443ec2df8a 100644 --- a/sqlx-core/src/postgres/protocol/command_complete.rs +++ b/sqlx-core/src/postgres/protocol/command_complete.rs @@ -1,8 +1,8 @@ use crate::io::Buf; #[derive(Debug)] -pub struct CommandComplete { - pub affected_rows: u64, +pub(crate) struct CommandComplete { + pub(crate) affected_rows: u64, } impl CommandComplete { diff --git a/sqlx-core/src/postgres/protocol/data_row.rs b/sqlx-core/src/postgres/protocol/data_row.rs index 53f505b5f0..b1dbae66e3 100644 --- a/sqlx-core/src/postgres/protocol/data_row.rs +++ b/sqlx-core/src/postgres/protocol/data_row.rs @@ -2,23 +2,18 @@ use crate::io::Buf; use byteorder::NetworkEndian; use std::ops::Range; -pub struct DataRow<'c> { +pub(crate) struct DataRow<'c> { len: u16, buffer: &'c [u8], values: &'c [Option>], } impl<'c> DataRow<'c> { - pub fn len(&self) -> usize { + pub(crate) fn len(&self) -> usize { self.len as usize } - pub fn get( - &self, - // buffer: &'c [u8], - // values: &[Option>], - index: usize, - ) -> Option<&'c [u8]> { + pub(crate) fn get(&self, index: usize) -> Option<&'c [u8]> { let range = self.values[index].as_ref()?; Some(&self.buffer[(range.start as usize)..(range.end as usize)]) diff --git a/sqlx-core/src/postgres/protocol/decode.rs b/sqlx-core/src/postgres/protocol/decode.rs deleted file mode 100644 index 236a05d2bd..0000000000 --- a/sqlx-core/src/postgres/protocol/decode.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub trait Decode { - fn decode(buf: &[u8]) -> crate::Result - where - Self: Sized; -} diff --git a/sqlx-core/src/postgres/protocol/describe.rs b/sqlx-core/src/postgres/protocol/describe.rs index 18c1a8109e..5c85d68491 100644 --- a/sqlx-core/src/postgres/protocol/describe.rs +++ b/sqlx-core/src/postgres/protocol/describe.rs @@ -1,5 +1,5 @@ use crate::io::BufMut; -use crate::postgres::protocol::{Encode, StatementId}; +use crate::postgres::protocol::{StatementId, Write}; use byteorder::{ByteOrder, NetworkEndian}; pub enum Describe<'a> { @@ -7,8 +7,8 @@ pub enum Describe<'a> { Portal(&'a str), } -impl Encode for Describe<'_> { - fn encode(&self, buf: &mut Vec) { +impl Write for Describe<'_> { + fn write(&self, buf: &mut Vec) { buf.push(b'D'); let pos = buf.len(); @@ -17,7 +17,7 @@ impl Encode for Describe<'_> { match self { Describe::Statement(id) => { buf.push(b'S'); - id.encode(buf); + id.write(buf); } Describe::Portal(name) => { @@ -34,25 +34,25 @@ impl Encode for Describe<'_> { #[cfg(test)] mod test { - use super::{Describe, Encode}; + use super::{Describe, Write}; use crate::postgres::protocol::StatementId; #[test] - fn it_encodes_describe_portal() { + fn it_writes_describe_portal() { let mut buf = Vec::new(); let m = Describe::Portal("__sqlx_p_1"); - m.encode(&mut buf); + m.write(&mut buf); assert_eq!(buf, b"D\0\0\0\x10P__sqlx_p_1\0"); } #[test] - fn it_encodes_describe_statement() { + fn it_writes_describe_statement() { let mut buf = Vec::new(); let m = Describe::Statement(StatementId(1)); - m.encode(&mut buf); + m.write(&mut buf); assert_eq!(buf, b"D\x00\x00\x00\x18S__sqlx_statement_1\x00"); } diff --git a/sqlx-core/src/postgres/protocol/encode.rs b/sqlx-core/src/postgres/protocol/encode.rs deleted file mode 100644 index 05dcdeedb0..0000000000 --- a/sqlx-core/src/postgres/protocol/encode.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub trait Encode { - fn encode(&self, buf: &mut Vec); -} diff --git a/sqlx-core/src/postgres/protocol/execute.rs b/sqlx-core/src/postgres/protocol/execute.rs index 481fdaa0f4..95f1cbf37a 100644 --- a/sqlx-core/src/postgres/protocol/execute.rs +++ b/sqlx-core/src/postgres/protocol/execute.rs @@ -1,5 +1,5 @@ use crate::io::BufMut; -use crate::postgres::protocol::Encode; +use crate::postgres::protocol::Write; use byteorder::NetworkEndian; pub struct Execute<'a> { @@ -11,8 +11,8 @@ pub struct Execute<'a> { pub limit: i32, } -impl Encode for Execute<'_> { - fn encode(&self, buf: &mut Vec) { +impl Write for Execute<'_> { + fn write(&self, buf: &mut Vec) { buf.push(b'E'); // len + nul + len(string) + limit diff --git a/sqlx-core/src/postgres/protocol/flush.rs b/sqlx-core/src/postgres/protocol/flush.rs deleted file mode 100644 index c47d6bd773..0000000000 --- a/sqlx-core/src/postgres/protocol/flush.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::io::BufMut; -use crate::postgres::protocol::Encode; -use byteorder::NetworkEndian; - -pub struct Flush; - -impl Encode for Flush { - fn encode(&self, buf: &mut Vec) { - buf.push(b'H'); - buf.put_i32::(4); - } -} diff --git a/sqlx-core/src/postgres/protocol/mod.rs b/sqlx-core/src/postgres/protocol/mod.rs index 332e6de092..4626e9ca96 100644 --- a/sqlx-core/src/postgres/protocol/mod.rs +++ b/sqlx-core/src/postgres/protocol/mod.rs @@ -1,5 +1,6 @@ //! Low level Postgres protocol. Defines the encoding and decoding of the messages communicated //! to and from the database server. +#![allow(unused)] mod type_format; mod type_id; @@ -9,11 +10,8 @@ pub use type_id::TypeId; // REQUESTS mod bind; -mod cancel_request; mod describe; -mod encode; mod execute; -mod flush; mod parse; mod password_message; mod query; @@ -25,48 +23,46 @@ mod statement; mod sync; mod terminate; -pub use bind::Bind; -pub use cancel_request::CancelRequest; -pub use describe::Describe; -pub use encode::Encode; -pub use execute::Execute; -pub use flush::Flush; -pub use parse::Parse; -pub use password_message::PasswordMessage; -pub use query::Query; -pub use sasl::{hi, SaslInitialResponse, SaslResponse}; -pub use ssl_request::SslRequest; -pub use startup_message::StartupMessage; -pub use statement::StatementId; -pub use sync::Sync; -pub use terminate::Terminate; +pub(crate) use bind::Bind; +pub(crate) use describe::Describe; +pub(crate) use execute::Execute; +pub(crate) use parse::Parse; +pub(crate) use password_message::PasswordMessage; +pub(crate) use query::Query; +pub(crate) use sasl::{hi, SaslInitialResponse, SaslResponse}; +#[cfg_attr(not(feature = "tls"), allow(unused_imports, dead_code))] +pub(crate) use ssl_request::SslRequest; +pub(crate) use startup_message::StartupMessage; +pub(crate) use statement::StatementId; +pub(crate) use sync::Sync; +pub(crate) use terminate::Terminate; // RESPONSES mod authentication; mod backend_key_data; mod command_complete; mod data_row; -mod decode; mod notification_response; mod parameter_description; -mod parameter_status; mod ready_for_query; mod response; mod row_description; mod message; -pub use authentication::{ +pub(crate) use authentication::{ Authentication, AuthenticationMd5, AuthenticationSasl, AuthenticationSaslContinue, }; -pub use backend_key_data::BackendKeyData; -pub use command_complete::CommandComplete; -pub use data_row::DataRow; -pub use decode::Decode; -pub use message::Message; -pub use notification_response::NotificationResponse; -pub use parameter_description::ParameterDescription; -pub use parameter_status::ParameterStatus; -pub use ready_for_query::ReadyForQuery; -pub use response::{Response, Severity}; -pub use row_description::{Field, RowDescription}; +pub(crate) use backend_key_data::BackendKeyData; +pub(crate) use command_complete::CommandComplete; +pub(crate) use data_row::DataRow; +pub(crate) use message::Message; +pub(crate) use notification_response::NotificationResponse; +pub(crate) use parameter_description::ParameterDescription; +pub(crate) use ready_for_query::ReadyForQuery; +pub(crate) use response::Response; +pub(crate) use row_description::{Field, RowDescription}; + +pub(crate) trait Write { + fn write(&self, buf: &mut Vec); +} diff --git a/sqlx-core/src/postgres/protocol/notification_response.rs b/sqlx-core/src/postgres/protocol/notification_response.rs index 55a0482b24..f4ba9d83f4 100644 --- a/sqlx-core/src/postgres/protocol/notification_response.rs +++ b/sqlx-core/src/postgres/protocol/notification_response.rs @@ -1,40 +1,48 @@ use crate::io::Buf; -use crate::postgres::protocol::Decode; use byteorder::NetworkEndian; +use std::borrow::Cow; #[derive(Debug)] -pub struct NotificationResponse { - pub pid: u32, - pub channel_name: String, - pub message: String, +pub(crate) struct NotificationResponse<'c> { + pub(crate) process_id: u32, + pub(crate) channel: Cow<'c, str>, + pub(crate) payload: Cow<'c, str>, } -impl Decode for NotificationResponse { - fn decode(mut buf: &[u8]) -> crate::Result { - let pid = buf.get_u32::()?; - let channel_name = buf.get_str_nul()?.to_owned(); - let message = buf.get_str_nul()?.to_owned(); +impl<'c> NotificationResponse<'c> { + pub(crate) fn read(mut buf: &'c [u8]) -> crate::Result { + let process_id = buf.get_u32::()?; + let channel = buf.get_str_nul()?; + let payload = buf.get_str_nul()?; Ok(Self { - pid, - channel_name, - message, + process_id, + channel: Cow::Borrowed(channel), + payload: Cow::Borrowed(payload), }) } + + pub(crate) fn into_owned(self) -> NotificationResponse<'static> { + NotificationResponse { + process_id: self.process_id, + channel: Cow::Owned(self.channel.into_owned()), + payload: Cow::Owned(self.payload.into_owned()), + } + } } #[cfg(test)] mod tests { - use super::{Decode, NotificationResponse}; + use super::NotificationResponse; const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0"; #[test] fn it_decodes_notification_response() { - let message = NotificationResponse::decode(NOTIFICATION_RESPONSE).unwrap(); + let message = NotificationResponse::read(NOTIFICATION_RESPONSE).unwrap(); - assert_eq!(message.pid, 0x34201002); - assert_eq!(message.channel_name, "TEST-CHANNEL"); - assert_eq!(message.message, "THIS IS A TEST"); + assert_eq!(message.process_id, 0x34201002); + assert_eq!(&*message.channel, "TEST-CHANNEL"); + assert_eq!(&*message.payload, "THIS IS A TEST"); } } diff --git a/sqlx-core/src/postgres/protocol/parameter_status.rs b/sqlx-core/src/postgres/protocol/parameter_status.rs deleted file mode 100644 index 23f714b351..0000000000 --- a/sqlx-core/src/postgres/protocol/parameter_status.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::io::Buf; -use crate::postgres::protocol::Decode; - -#[derive(Debug)] -pub struct ParameterStatus { - pub name: Box, - pub value: Box, -} - -impl Decode for ParameterStatus { - fn decode(mut buf: &[u8]) -> crate::Result { - let name = buf.get_str_nul()?.into(); - let value = buf.get_str_nul()?.into(); - - Ok(Self { name, value }) - } -} - -#[cfg(test)] -mod tests { - use super::{Decode, ParameterStatus}; - - const PARAM_STATUS: &[u8] = b"session_authorization\0postgres\0"; - - #[test] - fn it_decodes_param_status() { - let message = ParameterStatus::decode(PARAM_STATUS).unwrap(); - - assert_eq!(&*message.name, "session_authorization"); - assert_eq!(&*message.value, "postgres"); - } -} diff --git a/sqlx-core/src/postgres/protocol/parse.rs b/sqlx-core/src/postgres/protocol/parse.rs index e7beb7724f..bc0cbc001d 100644 --- a/sqlx-core/src/postgres/protocol/parse.rs +++ b/sqlx-core/src/postgres/protocol/parse.rs @@ -1,5 +1,5 @@ use crate::io::BufMut; -use crate::postgres::protocol::{Encode, StatementId}; +use crate::postgres::protocol::{StatementId, Write}; use byteorder::{ByteOrder, NetworkEndian}; pub struct Parse<'a> { @@ -8,14 +8,14 @@ pub struct Parse<'a> { pub param_types: &'a [u32], } -impl Encode for Parse<'_> { - fn encode(&self, buf: &mut Vec) { +impl Write for Parse<'_> { + fn write(&self, buf: &mut Vec) { buf.push(b'P'); let pos = buf.len(); buf.put_i32::(0); // skip over len - self.statement.encode(buf); + self.statement.write(buf); buf.put_str_nul(self.query); diff --git a/sqlx-core/src/postgres/protocol/password_message.rs b/sqlx-core/src/postgres/protocol/password_message.rs index 4830fcc08e..84486af016 100644 --- a/sqlx-core/src/postgres/protocol/password_message.rs +++ b/sqlx-core/src/postgres/protocol/password_message.rs @@ -1,9 +1,9 @@ use crate::io::BufMut; -use crate::postgres::protocol::Encode; +use crate::postgres::protocol::Write; use byteorder::NetworkEndian; use md5::{Digest, Md5}; -pub enum PasswordMessage<'a> { +pub(crate) enum PasswordMessage<'a> { ClearText(&'a str), Md5 { @@ -13,8 +13,8 @@ pub enum PasswordMessage<'a> { }, } -impl Encode for PasswordMessage<'_> { - fn encode(&self, buf: &mut Vec) { +impl Write for PasswordMessage<'_> { + fn write(&self, buf: &mut Vec) { buf.push(b'p'); match self { @@ -54,23 +54,23 @@ impl Encode for PasswordMessage<'_> { #[cfg(test)] mod tests { - use super::{Encode, PasswordMessage}; + use super::{PasswordMessage, Write}; const PASSWORD_CLEAR: &[u8] = b"p\0\0\0\rpassword\0"; const PASSWORD_MD5: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0"; #[test] - fn it_encodes_password_clear() { + fn it_writes_password_clear() { let mut buf = Vec::new(); let m = PasswordMessage::ClearText("password"); - m.encode(&mut buf); + m.write(&mut buf); assert_eq!(buf, PASSWORD_CLEAR); } #[test] - fn it_encodes_password_md5() { + fn it_writes_password_md5() { let mut buf = Vec::new(); let m = PasswordMessage::Md5 { password: "password", @@ -78,7 +78,7 @@ mod tests { salt: [147, 24, 57, 152], }; - m.encode(&mut buf); + m.write(&mut buf); assert_eq!(buf, PASSWORD_MD5); } diff --git a/sqlx-core/src/postgres/protocol/query.rs b/sqlx-core/src/postgres/protocol/query.rs index 596db5dead..f86a726454 100644 --- a/sqlx-core/src/postgres/protocol/query.rs +++ b/sqlx-core/src/postgres/protocol/query.rs @@ -1,11 +1,11 @@ use crate::io::BufMut; -use crate::postgres::protocol::Encode; +use crate::postgres::protocol::Write; use byteorder::NetworkEndian; pub struct Query<'a>(pub &'a str); -impl Encode for Query<'_> { - fn encode(&self, buf: &mut Vec) { +impl Write for Query<'_> { + fn write(&self, buf: &mut Vec) { buf.push(b'Q'); // len + query + nul @@ -17,16 +17,16 @@ impl Encode for Query<'_> { #[cfg(test)] mod tests { - use super::{Encode, Query}; + use super::{Query, Write}; const QUERY_SELECT_1: &[u8] = b"Q\0\0\0\rSELECT 1\0"; #[test] - fn it_encodes_query() { + fn it_writes_query() { let mut buf = Vec::new(); let m = Query("SELECT 1"); - m.encode(&mut buf); + m.write(&mut buf); assert_eq!(buf, QUERY_SELECT_1); } diff --git a/sqlx-core/src/postgres/protocol/ready_for_query.rs b/sqlx-core/src/postgres/protocol/ready_for_query.rs index 05666895ba..1b526b48ff 100644 --- a/sqlx-core/src/postgres/protocol/ready_for_query.rs +++ b/sqlx-core/src/postgres/protocol/ready_for_query.rs @@ -1,5 +1,3 @@ -use crate::postgres::protocol::Decode; - #[derive(Debug)] #[repr(u8)] pub enum TransactionStatus { @@ -19,8 +17,8 @@ pub struct ReadyForQuery { status: TransactionStatus, } -impl Decode for ReadyForQuery { - fn decode(buf: &[u8]) -> crate::Result { +impl ReadyForQuery { + pub(crate) fn read(buf: &[u8]) -> crate::Result { Ok(Self { status: match buf[0] { b'I' => TransactionStatus::Idle, @@ -41,14 +39,14 @@ impl Decode for ReadyForQuery { #[cfg(test)] mod tests { - use super::{Decode, ReadyForQuery, TransactionStatus}; + use super::{ReadyForQuery, TransactionStatus}; use matches::assert_matches; const READY_FOR_QUERY: &[u8] = b"E"; #[test] fn it_decodes_ready_for_query() { - let message = ReadyForQuery::decode(READY_FOR_QUERY).unwrap(); + let message = ReadyForQuery::read(READY_FOR_QUERY).unwrap(); assert_matches!(message.status, TransactionStatus::Error); } diff --git a/sqlx-core/src/postgres/protocol/response.rs b/sqlx-core/src/postgres/protocol/response.rs index 75559296f7..8e208e5ea3 100644 --- a/sqlx-core/src/postgres/protocol/response.rs +++ b/sqlx-core/src/postgres/protocol/response.rs @@ -2,7 +2,7 @@ use crate::io::Buf; use std::str::{self, FromStr}; #[derive(Debug, Copy, Clone)] -pub enum Severity { +pub(crate) enum Severity { Panic, Fatal, Error, @@ -14,7 +14,7 @@ pub enum Severity { } impl Severity { - pub fn is_error(self) -> bool { + pub(crate) fn is_error(self) -> bool { match self { Severity::Panic | Severity::Fatal | Severity::Error => true, _ => false, @@ -44,28 +44,28 @@ impl FromStr for Severity { } #[derive(Debug)] -pub struct Response { - pub severity: Severity, - pub code: Box, - pub message: Box, - pub detail: Option>, - pub hint: Option>, - pub position: Option, - pub internal_position: Option, - pub internal_query: Option>, - pub where_: Option>, - pub schema: Option>, - pub table: Option>, - pub column: Option>, - pub data_type: Option>, - pub constraint: Option>, - pub file: Option>, - pub line: Option, - pub routine: Option>, +pub(crate) struct Response { + pub(crate) severity: Severity, + pub(crate) code: Box, + pub(crate) message: Box, + pub(crate) detail: Option>, + pub(crate) hint: Option>, + pub(crate) position: Option, + pub(crate) internal_position: Option, + pub(crate) internal_query: Option>, + pub(crate) where_: Option>, + pub(crate) schema: Option>, + pub(crate) table: Option>, + pub(crate) column: Option>, + pub(crate) data_type: Option>, + pub(crate) constraint: Option>, + pub(crate) file: Option>, + pub(crate) line: Option, + pub(crate) routine: Option>, } impl Response { - pub fn read(mut buf: &[u8]) -> crate::Result { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { let mut code = None::>; let mut message = None::>; let mut severity = None::>; diff --git a/sqlx-core/src/postgres/protocol/row_description.rs b/sqlx-core/src/postgres/protocol/row_description.rs index 2e85782c13..9837c38bce 100644 --- a/sqlx-core/src/postgres/protocol/row_description.rs +++ b/sqlx-core/src/postgres/protocol/row_description.rs @@ -3,19 +3,19 @@ use crate::postgres::protocol::{TypeFormat, TypeId}; use byteorder::NetworkEndian; #[derive(Debug)] -pub struct RowDescription { - pub fields: Box<[Field]>, +pub(crate) struct RowDescription { + pub(crate) fields: Box<[Field]>, } #[derive(Debug)] -pub struct Field { - pub name: Option>, - pub table_id: Option, - pub column_id: i16, - pub type_id: TypeId, - pub type_size: i16, - pub type_mod: i32, - pub type_format: TypeFormat, +pub(crate) struct Field { + pub(crate) name: Option>, + pub(crate) table_id: Option, + pub(crate) column_id: i16, + pub(crate) type_id: TypeId, + pub(crate) type_size: i16, + pub(crate) type_mod: i32, + pub(crate) type_format: TypeFormat, } impl RowDescription { diff --git a/sqlx-core/src/postgres/protocol/sasl.rs b/sqlx-core/src/postgres/protocol/sasl.rs index 80de189901..3079034ae6 100644 --- a/sqlx-core/src/postgres/protocol/sasl.rs +++ b/sqlx-core/src/postgres/protocol/sasl.rs @@ -1,14 +1,14 @@ use crate::io::BufMut; -use crate::postgres::protocol::Encode; +use crate::postgres::protocol::Write; use crate::Result; use byteorder::NetworkEndian; use hmac::{Hmac, Mac}; use sha2::Sha256; -pub struct SaslInitialResponse<'a>(pub &'a str); +pub(crate) struct SaslInitialResponse<'a>(pub(crate) &'a str); -impl<'a> Encode for SaslInitialResponse<'a> { - fn encode(&self, buf: &mut Vec) { +impl<'a> Write for SaslInitialResponse<'a> { + fn write(&self, buf: &mut Vec) { let len = self.0.as_bytes().len() as u32; buf.push(b'p'); buf.put_u32::(4u32 + len + 14u32 + 4u32); @@ -18,10 +18,10 @@ impl<'a> Encode for SaslInitialResponse<'a> { } } -pub struct SaslResponse<'a>(pub &'a str); +pub(crate) struct SaslResponse<'a>(pub(crate) &'a str); -impl<'a> Encode for SaslResponse<'a> { - fn encode(&self, buf: &mut Vec) { +impl<'a> Write for SaslResponse<'a> { + fn write(&self, buf: &mut Vec) { buf.push(b'p'); buf.put_u32::(4u32 + self.0.as_bytes().len() as u32); buf.extend_from_slice(self.0.as_bytes()); @@ -29,7 +29,7 @@ impl<'a> Encode for SaslResponse<'a> { } // Hi(str, salt, i): -pub fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32]> { +pub(crate) fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32]> { let mut mac = Hmac::::new_varkey(s.as_bytes()) .map_err(|_| protocol_err!("HMAC can take key of any size"))?; diff --git a/sqlx-core/src/postgres/protocol/ssl_request.rs b/sqlx-core/src/postgres/protocol/ssl_request.rs index 7c9c74c54f..1bc947f03a 100644 --- a/sqlx-core/src/postgres/protocol/ssl_request.rs +++ b/sqlx-core/src/postgres/protocol/ssl_request.rs @@ -1,13 +1,13 @@ use byteorder::NetworkEndian; use crate::io::BufMut; -use crate::postgres::protocol::Encode; +use crate::postgres::protocol::Write; #[derive(Debug)] pub struct SslRequest; -impl Encode for SslRequest { - fn encode(&self, buf: &mut Vec) { +impl Write for SslRequest { + fn write(&self, buf: &mut Vec) { // packet length: 8 bytes including self buf.put_u32::(8); // 1234 in high 16 bits, 5679 in low 16 @@ -18,7 +18,7 @@ impl Encode for SslRequest { #[test] fn test_ssl_request() { let mut buf = Vec::new(); - SslRequest.encode(&mut buf); + SslRequest.write(&mut buf); assert_eq!(&buf, b"\x00\x00\x00\x08\x04\xd2\x16/"); } diff --git a/sqlx-core/src/postgres/protocol/startup_message.rs b/sqlx-core/src/postgres/protocol/startup_message.rs index 4299efb1fe..e282f6665f 100644 --- a/sqlx-core/src/postgres/protocol/startup_message.rs +++ b/sqlx-core/src/postgres/protocol/startup_message.rs @@ -1,13 +1,13 @@ use crate::io::BufMut; -use crate::postgres::protocol::Encode; +use crate::postgres::protocol::Write; use byteorder::{BigEndian, ByteOrder, NetworkEndian}; pub struct StartupMessage<'a> { pub params: &'a [(&'a str, &'a str)], } -impl Encode for StartupMessage<'_> { - fn encode(&self, buf: &mut Vec) { +impl Write for StartupMessage<'_> { + fn write(&self, buf: &mut Vec) { let pos = buf.len(); buf.put_i32::(0); // skip over len @@ -29,7 +29,7 @@ impl Encode for StartupMessage<'_> { #[cfg(test)] mod tests { - use super::{Encode, StartupMessage}; + use super::{StartupMessage, Write}; const STARTUP_MESSAGE: &[u8] = b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0"; @@ -40,7 +40,7 @@ mod tests { params: &[("user", "postgres"), ("database", "postgres")], }; - m.encode(&mut buf); + m.write(&mut buf); assert_eq!(buf, STARTUP_MESSAGE); } diff --git a/sqlx-core/src/postgres/protocol/statement.rs b/sqlx-core/src/postgres/protocol/statement.rs index be0cb25e90..0f4825a63d 100644 --- a/sqlx-core/src/postgres/protocol/statement.rs +++ b/sqlx-core/src/postgres/protocol/statement.rs @@ -1,13 +1,13 @@ -use std::io::Write; +use std::io::Write as _; use crate::io::BufMut; -use crate::postgres::protocol::Encode; +use crate::postgres::protocol::Write; #[derive(Copy, Clone, PartialOrd, PartialEq, Eq, Hash)] pub struct StatementId(pub u32); -impl Encode for StatementId { - fn encode(&self, buf: &mut Vec) { +impl Write for StatementId { + fn write(&self, buf: &mut Vec) { if self.0 != 0 { let _ = write!(buf, "__sqlx_statement_{}\0", self.0); } else { diff --git a/sqlx-core/src/postgres/protocol/sync.rs b/sqlx-core/src/postgres/protocol/sync.rs index ffb4e04c1e..1cd366ccfc 100644 --- a/sqlx-core/src/postgres/protocol/sync.rs +++ b/sqlx-core/src/postgres/protocol/sync.rs @@ -1,12 +1,12 @@ use crate::io::BufMut; -use crate::postgres::protocol::Encode; +use crate::postgres::protocol::Write; use byteorder::NetworkEndian; pub struct Sync; -impl Encode for Sync { +impl Write for Sync { #[inline] - fn encode(&self, buf: &mut Vec) { + fn write(&self, buf: &mut Vec) { buf.push(b'S'); buf.put_i32::(4); } diff --git a/sqlx-core/src/postgres/protocol/terminate.rs b/sqlx-core/src/postgres/protocol/terminate.rs index 6c5b8e2c46..2422826aea 100644 --- a/sqlx-core/src/postgres/protocol/terminate.rs +++ b/sqlx-core/src/postgres/protocol/terminate.rs @@ -1,12 +1,12 @@ use crate::io::BufMut; -use crate::postgres::protocol::Encode; +use crate::postgres::protocol::Write; use byteorder::NetworkEndian; pub struct Terminate; -impl Encode for Terminate { +impl Write for Terminate { #[inline] - fn encode(&self, buf: &mut Vec) { + fn write(&self, buf: &mut Vec) { buf.push(b'X'); buf.put_i32::(4); } diff --git a/sqlx-core/src/postgres/row.rs b/sqlx-core/src/postgres/row.rs index 37c801caaa..9158b2a7a8 100644 --- a/sqlx-core/src/postgres/row.rs +++ b/sqlx-core/src/postgres/row.rs @@ -41,8 +41,9 @@ impl<'c> Row<'c> for PgRow<'c> { self.data.len() } - fn try_get_raw<'r, I>(&'r self, index: I) -> crate::Result>> + fn try_get_raw<'r, I>(&'r self, index: I) -> crate::Result>> where + 'c: 'r, I: ColumnIndex, { let index = index.resolve(self)?; diff --git a/sqlx-core/src/postgres/sasl.rs b/sqlx-core/src/postgres/sasl.rs index 79881b56b8..553b06c68a 100644 --- a/sqlx-core/src/postgres/sasl.rs +++ b/sqlx-core/src/postgres/sasl.rs @@ -65,7 +65,7 @@ pub(super) async fn authenticate>( stream.write(SaslInitialResponse(&client_first_message)); stream.flush().await?; - let server_first_message = stream.read().await?; + let server_first_message = stream.receive().await?; if let Message::Authentication = server_first_message { let auth = Authentication::read(stream.buffer())?; @@ -140,7 +140,7 @@ pub(super) async fn authenticate>( stream.write(SaslResponse(&client_final_message)); stream.flush().await?; - let _server_final_response = stream.read().await?; + let _server_final_response = stream.receive().await?; // todo: assert that this was SaslFinal? Ok(()) diff --git a/sqlx-core/src/postgres/stream.rs b/sqlx-core/src/postgres/stream.rs index 0b19ad025d..17f29edce1 100644 --- a/sqlx-core/src/postgres/stream.rs +++ b/sqlx-core/src/postgres/stream.rs @@ -2,14 +2,17 @@ use std::convert::TryInto; use std::net::Shutdown; use byteorder::NetworkEndian; +use futures_channel::mpsc::UnboundedSender; use crate::io::{Buf, BufStream, MaybeTlsStream}; -use crate::postgres::protocol::{Encode, Message, Response}; +use crate::postgres::protocol::{Message, NotificationResponse, Response, Write}; use crate::postgres::PgError; use crate::url::Url; +use futures_util::SinkExt; pub struct PgStream { pub(super) stream: BufStream, + pub(super) notifications: Option>>, // Most recently received message // Is referenced by our buffered stream @@ -22,6 +25,7 @@ impl PgStream { let stream = MaybeTlsStream::connect(&url, 5432).await?; Ok(Self { + notifications: None, stream: BufStream::new(stream), message: (Message::ReadyForQuery, 0), }) @@ -34,9 +38,9 @@ impl PgStream { #[inline] pub(super) fn write(&mut self, message: M) where - M: Encode, + M: Write, { - message.encode(self.stream.buffer_mut()); + message.write(self.stream.buffer_mut()); } #[inline] @@ -45,30 +49,36 @@ impl PgStream { } pub(super) async fn read(&mut self) -> crate::Result { - loop { - // https://www.postgresql.org/docs/12/protocol-overview.html#PROTOCOL-MESSAGE-CONCEPTS + // https://www.postgresql.org/docs/12/protocol-overview.html#PROTOCOL-MESSAGE-CONCEPTS - // All communication is through a stream of messages. The first byte of a message - // identifies the message type, and the next four bytes specify the length of the rest of - // the message (this length count includes itself, but not the message-type byte). + // All communication is through a stream of messages. The first byte of a message + // identifies the message type, and the next four bytes specify the length of the rest of + // the message (this length count includes itself, but not the message-type byte). - if self.message.1 > 0 { - // If there is any data in our read buffer we need to make sure we flush that - // so reading will return the *next* message - self.stream.consume(self.message.1 as usize); - } + if self.message.1 > 0 { + // If there is any data in our read buffer we need to make sure we flush that + // so reading will return the *next* message + self.stream.consume(self.message.1 as usize); + } - let mut header = self.stream.peek(4 + 1).await?; + let mut header = self.stream.peek(4 + 1).await?; - let type_ = header.get_u8()?.try_into()?; - let length = header.get_u32::()? - 4; + let type_ = header.get_u8()?.try_into()?; + let length = header.get_u32::()? - 4; - self.message = (type_, length); - self.stream.consume(4 + 1); + self.message = (type_, length); + self.stream.consume(4 + 1); - // Wait until there is enough data in the stream. We then return without actually - // inspecting the data. This is then looked at later through the [buffer] function - let _ = self.stream.peek(length as usize).await?; + // Wait until there is enough data in the stream. We then return without actually + // inspecting the data. This is then looked at later through the [buffer] function + let _ = self.stream.peek(length as usize).await?; + + Ok(type_) + } + + pub(super) async fn receive(&mut self) -> crate::Result { + loop { + let type_ = self.read().await?; match type_ { Message::ErrorResponse | Message::NoticeResponse => { @@ -84,10 +94,19 @@ impl PgStream { continue; } - _ => { - return Ok(type_); + Message::NotificationResponse => { + if let Some(buffer) = &mut self.notifications { + let notification = NotificationResponse::read(self.stream.buffer())?; + + let _ = buffer.send(notification.into_owned()).await; + continue; + } } + + _ => {} } + + return Ok(type_); } } diff --git a/sqlx-core/src/row.rs b/sqlx-core/src/row.rs index 651d27113e..da3707dc41 100644 --- a/sqlx-core/src/row.rs +++ b/sqlx-core/src/row.rs @@ -26,27 +26,30 @@ pub trait Row<'c>: Unpin + Send { fn get<'r, T, I>(&'r self, index: I) -> T where + 'c: 'r, T: Type, I: ColumnIndex, - T: Decode<'c, Self::Database>, + T: Decode<'r, Self::Database>, { self.try_get::(index).unwrap() } fn try_get<'r, T, I>(&'r self, index: I) -> crate::Result where + 'c: 'r, T: Type, I: ColumnIndex, - T: Decode<'c, Self::Database>, + T: Decode<'r, Self::Database>, { Ok(Decode::decode(self.try_get_raw(index)?)?) } fn try_get_raw<'r, I>( - &self, + &'r self, index: I, - ) -> crate::Result<>::RawValue> + ) -> crate::Result<>::RawValue> where + 'c: 'r, I: ColumnIndex; } @@ -68,7 +71,7 @@ macro_rules! impl_from_row_for_tuple { impl<'c, $($T,)+> crate::row::FromRow<'c, $r<'c>> for ($($T,)+) where $($T: crate::types::Type<$db>,)+ - $($T: crate::decode::Decode<'c, $db>,)+ + $($T: for<'r> crate::decode::Decode<'r, $db>,)+ { #[inline] fn from_row(row: $r<'c>) -> crate::Result { diff --git a/sqlx-core/src/sqlite/arguments.rs b/sqlx-core/src/sqlite/arguments.rs new file mode 100644 index 0000000000..28099fefe5 --- /dev/null +++ b/sqlx-core/src/sqlite/arguments.rs @@ -0,0 +1,133 @@ +use core::ffi::c_void; +use core::mem; + +use std::os::raw::c_int; + +use libsqlite3_sys::{ + sqlite3_bind_blob, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64, + sqlite3_bind_null, sqlite3_bind_text, SQLITE_OK, SQLITE_TRANSIENT, +}; + +use crate::arguments::Arguments; +use crate::encode::Encode; +use crate::sqlite::statement::Statement; +use crate::sqlite::Sqlite; +use crate::sqlite::SqliteError; +use crate::types::Type; + +#[derive(Debug, Clone)] +pub enum SqliteArgumentValue { + Null, + + // TODO: Take by reference to remove the allocation + Text(String), + + // TODO: Take by reference to remove the allocation + Blob(Vec), + + Double(f64), + + Int(i32), + + Int64(i64), +} + +#[derive(Default)] +pub struct SqliteArguments { + index: usize, + values: Vec, +} + +impl SqliteArguments { + pub(crate) fn next(&mut self) -> Option { + if self.index >= self.values.len() { + return None; + } + + let mut value = SqliteArgumentValue::Null; + mem::swap(&mut value, &mut self.values[self.index]); + + self.index += 1; + Some(value) + } +} + +impl Arguments for SqliteArguments { + type Database = Sqlite; + + fn reserve(&mut self, len: usize, _size_hint: usize) { + self.values.reserve(len); + } + + fn add(&mut self, value: T) + where + T: Encode + Type, + { + value.encode(&mut self.values); + } +} + +impl SqliteArgumentValue { + pub(super) fn bind(&self, statement: &mut Statement, index: usize) -> crate::Result<()> { + // TODO: Handle error of trying to bind too many parameters here + let index = index as c_int; + + // https://sqlite.org/c3ref/bind_blob.html + #[allow(unsafe_code)] + let status: c_int = match self { + SqliteArgumentValue::Blob(value) => { + // TODO: Handle bytes that are too large + let bytes = value.as_slice(); + let bytes_ptr = bytes.as_ptr() as *const c_void; + let bytes_len = bytes.len() as i32; + + unsafe { + sqlite3_bind_blob( + statement.handle(), + index, + bytes_ptr, + bytes_len, + SQLITE_TRANSIENT(), + ) + } + } + + SqliteArgumentValue::Text(value) => { + // TODO: Handle text that is too large + let bytes = value.as_bytes(); + let bytes_ptr = bytes.as_ptr() as *const i8; + let bytes_len = bytes.len() as i32; + + unsafe { + sqlite3_bind_text( + statement.handle(), + index, + bytes_ptr, + bytes_len, + SQLITE_TRANSIENT(), + ) + } + } + + SqliteArgumentValue::Double(value) => unsafe { + sqlite3_bind_double(statement.handle(), index, *value) + }, + + SqliteArgumentValue::Int(value) => unsafe { + sqlite3_bind_int(statement.handle(), index, *value) + }, + + SqliteArgumentValue::Int64(value) => unsafe { + sqlite3_bind_int64(statement.handle(), index, *value) + }, + + SqliteArgumentValue::Null => unsafe { sqlite3_bind_null(statement.handle(), index) }, + }; + + if status != SQLITE_OK { + return Err(SqliteError::from_connection(statement.connection.0.as_ptr()).into()); + } + + Ok(()) + } +} diff --git a/sqlx-core/src/sqlite/connection.rs b/sqlx-core/src/sqlite/connection.rs new file mode 100644 index 0000000000..f6812be3a4 --- /dev/null +++ b/sqlx-core/src/sqlite/connection.rs @@ -0,0 +1,158 @@ +use core::ptr::{null, null_mut, NonNull}; + +use std::collections::HashMap; +use std::convert::TryInto; +use std::ffi::CString; + +use futures_core::future::BoxFuture; +use futures_util::future; +use libsqlite3_sys::{ + sqlite3, sqlite3_close, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK, + SQLITE_OPEN_CREATE, SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE, +}; + +use crate::connection::{Connect, Connection}; +use crate::executor::Executor; +use crate::sqlite::statement::Statement; +use crate::sqlite::worker::Worker; +use crate::sqlite::SqliteError; +use crate::url::Url; + +/// Thin wrapper around [sqlite3] to impl `Send`. +#[derive(Clone, Copy)] +pub(super) struct SqliteConnectionHandle(pub(super) NonNull); + +/// A connection to a [SQLite][super::Sqlite] database. +pub struct SqliteConnection { + pub(super) handle: SqliteConnectionHandle, + pub(super) worker: Worker, + // Storage of the most recently prepared, non-persistent statement + pub(super) statement: Option, + // Storage of persistent statements + pub(super) statements: Vec, + pub(super) statement_by_query: HashMap, +} + +// A SQLite3 handle is safe to send between threads, provided not more than +// one is accessing it at the same time. This is upheld as long as [SQLITE_CONFIG_MULTITHREAD] is +// enabled and [SQLITE_THREADSAFE] was enabled when sqlite was compiled. We refuse to work +// if these conditions are not upheld. + +// + +// + +#[allow(unsafe_code)] +unsafe impl Send for SqliteConnectionHandle {} + +async fn establish(url: crate::Result) -> crate::Result { + let mut worker = Worker::new(); + + let url = url?; + let url = url + .as_str() + .trim_start_matches("sqlite:") + .trim_start_matches("//"); + + // By default, we connect to an in-memory database. + // TODO: Handle the error when there are internal NULs in the database URL + let filename = CString::new(url).unwrap(); + + let handle = worker + .run(move || -> crate::Result { + let mut handle = null_mut(); + + // [SQLITE_OPEN_NOMUTEX] will instruct [sqlite3_open_v2] to return an error if it + // cannot satisfy our wish for a thread-safe, lock-free connection object + let flags = SQLITE_OPEN_READWRITE + | SQLITE_OPEN_CREATE + | SQLITE_OPEN_NOMUTEX + | SQLITE_OPEN_SHAREDCACHE; + + // + #[allow(unsafe_code)] + let status = unsafe { sqlite3_open_v2(filename.as_ptr(), &mut handle, flags, null()) }; + + if status != SQLITE_OK { + return Err(SqliteError::from_connection(handle).into()); + } + + // Enable extended result codes + // https://www.sqlite.org/c3ref/extended_result_codes.html + #[allow(unsafe_code)] + unsafe { + sqlite3_extended_result_codes(handle, 1); + } + + Ok(SqliteConnectionHandle(NonNull::new(handle).unwrap())) + }) + .await?; + + Ok(SqliteConnection { + worker, + handle, + statement: None, + statements: Vec::with_capacity(10), + statement_by_query: HashMap::with_capacity(10), + }) +} + +impl SqliteConnection { + #[inline] + pub(super) fn handle(&mut self) -> *mut sqlite3 { + self.handle.0.as_ptr() + } +} + +impl Connect for SqliteConnection { + fn connect(url: T) -> BoxFuture<'static, crate::Result> + where + T: TryInto, + Self: Sized, + { + let url = url.try_into(); + + Box::pin(async move { + let mut conn = establish(url).await?; + + // https://www.sqlite.org/wal.html + + // language=SQLite + conn.execute( + r#" +PRAGMA journal_mode = WAL; +PRAGMA synchronous = NORMAL; + "#, + ) + .await?; + + Ok(conn) + }) + } +} + +impl Connection for SqliteConnection { + fn close(self) -> BoxFuture<'static, crate::Result<()>> { + // All necessary behavior is handled on drop + Box::pin(future::ok(())) + } + + fn ping(&mut self) -> BoxFuture> { + // For SQLite connections, PING does effectively nothing + Box::pin(future::ok(())) + } +} + +impl Drop for SqliteConnection { + fn drop(&mut self) { + // Drop all statements first + self.statements.clear(); + + // Next close the statement + // https://sqlite.org/c3ref/close.html + #[allow(unsafe_code)] + unsafe { + let _ = sqlite3_close(self.handle()); + } + } +} diff --git a/sqlx-core/src/sqlite/cursor.rs b/sqlx-core/src/sqlite/cursor.rs new file mode 100644 index 0000000000..2fed34004b --- /dev/null +++ b/sqlx-core/src/sqlite/cursor.rs @@ -0,0 +1,95 @@ +use futures_core::future::BoxFuture; + +use crate::connection::ConnectionSource; +use crate::cursor::Cursor; +use crate::executor::Execute; +use crate::pool::Pool; +use crate::sqlite::statement::Step; +use crate::sqlite::{Sqlite, SqliteArguments, SqliteConnection, SqliteRow}; + +pub struct SqliteCursor<'c, 'q> { + pub(super) source: ConnectionSource<'c, SqliteConnection>, + query: &'q str, + arguments: Option, + pub(super) statement: Option>, +} + +impl<'c, 'q> Cursor<'c, 'q> for SqliteCursor<'c, 'q> { + type Database = Sqlite; + + fn from_pool(pool: &Pool, query: E) -> Self + where + Self: Sized, + E: Execute<'q, Sqlite>, + { + let (query, arguments) = query.into_parts(); + + Self { + source: ConnectionSource::Pool(pool.clone()), + statement: None, + query, + arguments, + } + } + + fn from_connection(conn: &'c mut SqliteConnection, query: E) -> Self + where + Self: Sized, + E: Execute<'q, Sqlite>, + { + let (query, arguments) = query.into_parts(); + + Self { + source: ConnectionSource::ConnectionRef(conn), + statement: None, + query, + arguments, + } + } + + fn next(&mut self) -> BoxFuture>>> { + Box::pin(next(self)) + } +} + +async fn next<'a, 'c: 'a, 'q: 'a>( + cursor: &'a mut SqliteCursor<'c, 'q>, +) -> crate::Result>> { + let conn = cursor.source.resolve().await?; + + loop { + if cursor.statement.is_none() { + let key = conn.prepare(&mut cursor.query, cursor.arguments.is_some())?; + + if let Some(arguments) = &mut cursor.arguments { + conn.statement_mut(key).bind(arguments)?; + } + + cursor.statement = Some(key); + } + + let key = cursor.statement.unwrap(); + let statement = conn.statement_mut(key); + + let step = statement.step().await?; + + match step { + Step::Row => { + return Ok(Some(SqliteRow { + values: statement.data_count(), + statement: key, + connection: conn, + })); + } + + Step::Done if cursor.query.is_empty() => { + return Ok(None); + } + + Step::Done => { + cursor.statement = None; + // continue + } + } + } +} diff --git a/sqlx-core/src/sqlite/database.rs b/sqlx-core/src/sqlite/database.rs new file mode 100644 index 0000000000..efd40ca1dd --- /dev/null +++ b/sqlx-core/src/sqlite/database.rs @@ -0,0 +1,32 @@ +use crate::database::{Database, HasCursor, HasRawValue, HasRow}; + +/// **Sqlite** database driver. +pub struct Sqlite; + +impl Database for Sqlite { + type Connection = super::SqliteConnection; + + type Arguments = super::SqliteArguments; + + type TypeInfo = super::SqliteTypeInfo; + + type TableId = String; + + type RawBuffer = Vec; +} + +impl<'c> HasRow<'c> for Sqlite { + type Database = Sqlite; + + type Row = super::SqliteRow<'c>; +} + +impl<'c, 'q> HasCursor<'c, 'q> for Sqlite { + type Database = Sqlite; + + type Cursor = super::SqliteCursor<'c, 'q>; +} + +impl<'c> HasRawValue<'c> for Sqlite { + type RawValue = super::SqliteValue<'c>; +} diff --git a/sqlx-core/src/sqlite/error.rs b/sqlx-core/src/sqlite/error.rs new file mode 100644 index 0000000000..1b52dfff27 --- /dev/null +++ b/sqlx-core/src/sqlite/error.rs @@ -0,0 +1,51 @@ +use crate::error::DatabaseError; +use bitflags::_core::str::from_utf8_unchecked; +use libsqlite3_sys::{sqlite3, sqlite3_errmsg, sqlite3_extended_errcode}; +use std::ffi::CStr; +use std::fmt::{self, Display}; +use std::os::raw::c_int; + +#[derive(Debug)] +pub struct SqliteError { + code: String, + message: String, +} + +// Error Codes And Messages +// https://www.sqlite.org/c3ref/errcode.html + +impl SqliteError { + pub(super) fn from_connection(conn: *mut sqlite3) -> Self { + #[allow(unsafe_code)] + let code: c_int = unsafe { sqlite3_extended_errcode(conn) }; + + #[allow(unsafe_code)] + let message = unsafe { + let err = sqlite3_errmsg(conn); + debug_assert!(!err.is_null()); + + from_utf8_unchecked(CStr::from_ptr(err).to_bytes()) + }; + + Self { + code: code.to_string(), + message: message.to_owned(), + } + } +} + +impl Display for SqliteError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad(self.message()) + } +} + +impl DatabaseError for SqliteError { + fn message(&self) -> &str { + &self.message + } + + fn code(&self) -> Option<&str> { + Some(&self.code) + } +} diff --git a/sqlx-core/src/sqlite/executor.rs b/sqlx-core/src/sqlite/executor.rs new file mode 100644 index 0000000000..ad69e20cf9 --- /dev/null +++ b/sqlx-core/src/sqlite/executor.rs @@ -0,0 +1,198 @@ +use futures_core::future::BoxFuture; + +use libsqlite3_sys::sqlite3_changes; + +use crate::cursor::Cursor; +use crate::describe::{Column, Describe}; +use crate::executor::{Execute, Executor, RefExecutor}; +use crate::sqlite::cursor::SqliteCursor; +use crate::sqlite::statement::{Statement, Step}; +use crate::sqlite::types::SqliteType; +use crate::sqlite::{Sqlite, SqliteConnection, SqliteTypeInfo}; + +impl SqliteConnection { + pub(super) fn prepare( + &mut self, + query: &mut &str, + persistent: bool, + ) -> crate::Result> { + // TODO: Revisit statement caching and allow cache expiration by using a + // generational index + + if !persistent { + // A non-persistent query will be immediately prepared and returned, + // regardless of the current state of the cache + self.statement = Some(Statement::new(self, query, false)?); + return Ok(None); + } + + if let Some(key) = self.statement_by_query.get(&**query) { + let statement = &mut self.statements[*key]; + + // Adjust the passed in query string as if [string3_prepare] + // did the tail parsing + *query = &query[statement.tail..]; + + // As this statement has very likely been used before, we reset + // it to clear the bindings and its program state + statement.reset(); + + return Ok(Some(*key)); + } + + // Prepare a new statement object; ensuring to tell SQLite that this will be stored + // for a "long" time and re-used multiple times + + let query_key = query.to_owned(); + let statement = Statement::new(self, query, true)?; + + let key = self.statements.len(); + + self.statement_by_query.insert(query_key, key); + self.statements.push(statement); + + Ok(Some(key)) + } + + // This is used for [affected_rows] in the public API. + fn changes(&mut self) -> u64 { + // Returns the number of rows modified, inserted or deleted by the most recently + // completed INSERT, UPDATE or DELETE statement. + + // https://www.sqlite.org/c3ref/changes.html + #[allow(unsafe_code)] + let changes = unsafe { sqlite3_changes(self.handle()) }; + changes as u64 + } + + #[inline] + pub(super) fn statement(&self, key: Option) -> &Statement { + match key { + Some(key) => &self.statements[key], + None => self.statement.as_ref().unwrap(), + } + } + + #[inline] + pub(super) fn statement_mut(&mut self, key: Option) -> &mut Statement { + match key { + Some(key) => &mut self.statements[key], + None => self.statement.as_mut().unwrap(), + } + } +} + +impl Executor for SqliteConnection { + type Database = Sqlite; + + fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( + &'c mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> + where + E: Execute<'q, Self::Database>, + { + let (mut query, mut arguments) = query.into_parts(); + + Box::pin(async move { + loop { + let key = self.prepare(&mut query, arguments.is_some())?; + let statement = self.statement_mut(key); + + if let Some(arguments) = &mut arguments { + statement.bind(arguments)?; + } + + while let Step::Row = statement.step().await? { + // We only care about the rows modified; ignore + } + + if query.is_empty() { + break; + } + } + + Ok(self.changes()) + }) + } + + fn fetch<'q, E>(&mut self, query: E) -> SqliteCursor<'_, 'q> + where + E: Execute<'q, Self::Database>, + { + SqliteCursor::from_connection(self, query) + } + + fn describe<'e, 'q, E: 'e>( + &'e mut self, + query: E, + ) -> BoxFuture<'e, crate::Result>> + where + E: Execute<'q, Self::Database>, + { + Box::pin(async move { + let (mut query, _) = query.into_parts(); + let key = self.prepare(&mut query, false)?; + let statement = self.statement_mut(key); + + // First let's attempt to describe what we can about parameter types + // Which happens to just be the count, heh + let num_params = statement.params(); + let params = vec![ + SqliteTypeInfo { + r#type: SqliteType::Null, + affinity: None, + }; + num_params + ] + .into_boxed_slice(); + + // Next, collect (return) column types and names + let num_columns = statement.column_count(); + let mut columns = Vec::with_capacity(num_columns); + for i in 0..num_columns { + let name = statement.column_name(i).to_owned(); + let decl = statement.column_decltype(i); + + let r#type = match decl { + None => SqliteType::Null, + Some(decl) => match &*decl.to_ascii_lowercase() { + "bool" | "boolean" => SqliteType::Boolean, + "clob" | "text" => SqliteType::Text, + "blob" => SqliteType::Blob, + "real" | "double" | "double precision" | "float" => SqliteType::Float, + _ if decl.contains("int") => SqliteType::Integer, + _ if decl.contains("char") => SqliteType::Text, + _ => SqliteType::Null, + }, + }; + + columns.push(Column { + name: Some(name.into()), + non_null: None, + table_id: None, + type_info: SqliteTypeInfo { + r#type, + affinity: None, + }, + }) + } + + Ok(Describe { + param_types: params, + result_columns: columns.into_boxed_slice(), + }) + }) + } +} + +impl<'e> RefExecutor<'e> for &'e mut SqliteConnection { + type Database = Sqlite; + + fn fetch_by_ref<'q, E>(self, query: E) -> SqliteCursor<'e, 'q> + where + E: Execute<'q, Self::Database>, + { + SqliteCursor::from_connection(self, query) + } +} diff --git a/sqlx-core/src/sqlite/mod.rs b/sqlx-core/src/sqlite/mod.rs new file mode 100644 index 0000000000..21221c4f2a --- /dev/null +++ b/sqlx-core/src/sqlite/mod.rs @@ -0,0 +1,27 @@ +mod arguments; +mod connection; +mod cursor; +mod database; +mod error; +mod executor; +mod row; +mod statement; +mod types; +mod value; +mod worker; + +pub use arguments::{SqliteArgumentValue, SqliteArguments}; +pub use connection::SqliteConnection; +pub use cursor::SqliteCursor; +pub use database::Sqlite; +pub use error::SqliteError; +pub use row::SqliteRow; +pub use types::SqliteTypeInfo; +pub use value::SqliteValue; + +/// An alias for [`Pool`][crate::Pool], specialized for **Sqlite**. +pub type SqlitePool = crate::pool::Pool; + +make_query_as!(SqliteQueryAs, Sqlite, SqliteRow); +impl_map_row_for_row!(Sqlite, SqliteRow); +impl_from_row_for_tuples!(Sqlite, SqliteRow); diff --git a/sqlx-core/src/sqlite/row.rs b/sqlx-core/src/sqlite/row.rs new file mode 100644 index 0000000000..b5ef73f4b5 --- /dev/null +++ b/sqlx-core/src/sqlite/row.rs @@ -0,0 +1,59 @@ +use crate::database::HasRow; +use crate::row::{ColumnIndex, Row}; +use crate::sqlite::statement::Statement; +use crate::sqlite::value::SqliteValue; +use crate::sqlite::{Sqlite, SqliteConnection}; + +pub struct SqliteRow<'c> { + pub(super) values: usize, + pub(super) statement: Option, + pub(super) connection: &'c mut SqliteConnection, +} + +impl<'c> SqliteRow<'c> { + fn statement(&'c self) -> &'c Statement { + self.connection.statement(self.statement) + } +} + +impl<'c> Row<'c> for SqliteRow<'c> { + type Database = Sqlite; + + #[inline] + fn len(&self) -> usize { + self.values + } + + fn try_get_raw<'r, I>(&'r self, index: I) -> crate::Result> + where + 'c: 'r, + I: ColumnIndex, + { + let index = index.resolve(self)?; + let value = SqliteValue::new(self.statement(), index); + + Ok(value) + } +} + +impl ColumnIndex for usize { + fn resolve(self, row: &::Row) -> crate::Result { + let len = Row::len(row); + + if self >= len { + return Err(crate::Error::ColumnIndexOutOfBounds { len, index: self }); + } + + Ok(self) + } +} + +impl ColumnIndex for &'_ str { + fn resolve(self, row: &::Row) -> crate::Result { + row.statement() + .columns + .get(self) + .ok_or_else(|| crate::Error::ColumnNotFound((*self).into())) + .map(|&index| index as usize) + } +} diff --git a/sqlx-core/src/sqlite/statement.rs b/sqlx-core/src/sqlite/statement.rs new file mode 100644 index 0000000000..8dfc144ee3 --- /dev/null +++ b/sqlx-core/src/sqlite/statement.rs @@ -0,0 +1,236 @@ +use core::ptr::{null, null_mut, NonNull}; + +use std::collections::HashMap; +use std::ffi::CStr; +use std::os::raw::c_int; + +use libsqlite3_sys::{ + sqlite3_bind_parameter_count, sqlite3_clear_bindings, sqlite3_column_count, + sqlite3_column_decltype, sqlite3_column_name, sqlite3_data_count, sqlite3_finalize, + sqlite3_prepare_v3, sqlite3_reset, sqlite3_step, sqlite3_stmt, SQLITE_DONE, SQLITE_OK, + SQLITE_PREPARE_NO_VTAB, SQLITE_PREPARE_PERSISTENT, SQLITE_ROW, +}; + +use crate::sqlite::connection::SqliteConnectionHandle; +use crate::sqlite::worker::Worker; +use crate::sqlite::SqliteError; +use crate::sqlite::{SqliteArguments, SqliteConnection}; + +/// Return values from [SqliteStatement::step]. +pub(super) enum Step { + /// The statement has finished executing successfully. + Done, + + /// Another row of output is available. + Row, +} + +/// Thin wrapper around [sqlite3_stmt] to impl `Send`. +#[derive(Clone, Copy)] +pub(super) struct SqliteStatementHandle(NonNull); + +/// Represents a _single_ SQL statement that has been compiled into binary +/// form and is ready to be evaluated. +/// +/// The statement is finalized ( `sqlite3_finalize` ) on drop. +pub(super) struct Statement { + handle: SqliteStatementHandle, + pub(super) connection: SqliteConnectionHandle, + pub(super) worker: Worker, + pub(super) tail: usize, + pub(super) columns: HashMap, +} + +// SQLite3 statement objects are safe to send between threads, but *not* safe +// for general-purpose concurrent access between threads. See more notes +// on [SqliteConnectionHandle]. + +#[allow(unsafe_code)] +unsafe impl Send for SqliteStatementHandle {} + +impl Statement { + pub(super) fn new( + conn: &mut SqliteConnection, + query: &mut &str, + persistent: bool, + ) -> crate::Result { + // TODO: Error on queries that are too large + let query_ptr = query.as_bytes().as_ptr() as *const i8; + let query_len = query.len() as i32; + let mut statement_handle: *mut sqlite3_stmt = null_mut(); + let mut flags = SQLITE_PREPARE_NO_VTAB; + let mut tail: *const i8 = null(); + + if persistent { + // SQLITE_PREPARE_PERSISTENT + // The SQLITE_PREPARE_PERSISTENT flag is a hint to the query + // planner that the prepared statement will be retained for a long time + // and probably reused many times. + flags |= SQLITE_PREPARE_PERSISTENT; + } + + // + #[allow(unsafe_code)] + let status = unsafe { + sqlite3_prepare_v3( + conn.handle(), + query_ptr, + query_len, + flags as u32, + &mut statement_handle, + &mut tail, + ) + }; + + if status != SQLITE_OK { + return Err(SqliteError::from_connection(conn.handle()).into()); + } + + // If pzTail is not NULL then *pzTail is made to point to the first byte + // past the end of the first SQL statement in zSql. + let tail = (tail as usize) - (query_ptr as usize); + *query = &query[tail..].trim(); + + let mut self_ = Self { + worker: conn.worker.clone(), + connection: conn.handle, + handle: SqliteStatementHandle(NonNull::new(statement_handle).unwrap()), + columns: HashMap::new(), + tail, + }; + + // Prepare a column hash map for use in pulling values from a column by name + let count = self_.column_count(); + self_.columns.reserve(count); + + for i in 0..count { + let name = self_.column_name(i).to_owned(); + self_.columns.insert(name, i); + } + + Ok(self_) + } + + /// Returns a pointer to the raw C pointer backing this statement. + #[inline] + #[allow(unsafe_code)] + pub(super) unsafe fn handle(&self) -> *mut sqlite3_stmt { + self.handle.0.as_ptr() + } + + pub(super) fn data_count(&mut self) -> usize { + // https://sqlite.org/c3ref/data_count.html + + // The sqlite3_data_count(P) interface returns the number of columns + // in the current row of the result set. + + // The value is correct only if there was a recent call to + // sqlite3_step that returned SQLITE_ROW. + + #[allow(unsafe_code)] + let count: c_int = unsafe { sqlite3_data_count(self.handle()) }; + count as usize + } + + pub(super) fn column_count(&mut self) -> usize { + // https://sqlite.org/c3ref/column_count.html + #[allow(unsafe_code)] + let count = unsafe { sqlite3_column_count(self.handle()) }; + count as usize + } + + pub(super) fn column_name(&mut self, index: usize) -> &str { + // https://sqlite.org/c3ref/column_name.html + #[allow(unsafe_code)] + let name = unsafe { + let ptr = sqlite3_column_name(self.handle(), index as c_int); + debug_assert!(!ptr.is_null()); + + CStr::from_ptr(ptr) + }; + + name.to_str().unwrap() + } + + pub(super) fn column_decltype(&mut self, index: usize) -> Option<&str> { + // https://sqlite.org/c3ref/column_name.html + #[allow(unsafe_code)] + let name = unsafe { + let ptr = sqlite3_column_decltype(self.handle(), index as c_int); + + if ptr.is_null() { + None + } else { + Some(CStr::from_ptr(ptr)) + } + }; + + name.map(|s| s.to_str().unwrap()) + } + + pub(super) fn params(&mut self) -> usize { + // https://www.hwaci.com/sw/sqlite/c3ref/bind_parameter_count.html + #[allow(unsafe_code)] + let num = unsafe { sqlite3_bind_parameter_count(self.handle()) }; + num as usize + } + + pub(super) fn bind(&mut self, arguments: &mut SqliteArguments) -> crate::Result<()> { + for index in 0..self.params() { + if let Some(value) = arguments.next() { + value.bind(self, index + 1)?; + } else { + break; + } + } + + Ok(()) + } + + pub(super) fn reset(&mut self) { + // https://sqlite.org/c3ref/reset.html + // https://sqlite.org/c3ref/clear_bindings.html + + // the status value of reset is ignored because it merely propagates + // the status of the most recently invoked step function + + #[allow(unsafe_code)] + let _ = unsafe { sqlite3_reset(self.handle()) }; + + #[allow(unsafe_code)] + let _ = unsafe { sqlite3_clear_bindings(self.handle()) }; + } + + pub(super) async fn step(&mut self) -> crate::Result { + // https://sqlite.org/c3ref/step.html + + let handle = self.handle; + + #[allow(unsafe_code)] + let status = unsafe { + self.worker + .run(move || sqlite3_step(handle.0.as_ptr())) + .await + }; + + match status { + SQLITE_DONE => Ok(Step::Done), + + SQLITE_ROW => Ok(Step::Row), + + _ => { + return Err(SqliteError::from_connection(self.connection.0.as_ptr()).into()); + } + } + } +} + +impl Drop for Statement { + fn drop(&mut self) { + // https://sqlite.org/c3ref/finalize.html + #[allow(unsafe_code)] + unsafe { + let _ = sqlite3_finalize(self.handle()); + } + } +} diff --git a/sqlx-core/src/sqlite/types/bool.rs b/sqlx-core/src/sqlite/types/bool.rs new file mode 100644 index 0000000000..7b4e4c2f53 --- /dev/null +++ b/sqlx-core/src/sqlite/types/bool.rs @@ -0,0 +1,23 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::sqlite::types::{SqliteType, SqliteTypeAffinity}; +use crate::sqlite::{Sqlite, SqliteArgumentValue, SqliteTypeInfo, SqliteValue}; +use crate::types::Type; + +impl Type for bool { + fn type_info() -> SqliteTypeInfo { + SqliteTypeInfo::new(SqliteType::Boolean, SqliteTypeAffinity::Numeric) + } +} + +impl Encode for bool { + fn encode(&self, values: &mut Vec) { + values.push(SqliteArgumentValue::Int((*self).into())); + } +} + +impl<'a> Decode<'a, Sqlite> for bool { + fn decode(value: SqliteValue<'a>) -> crate::Result { + Ok(value.int() != 0) + } +} diff --git a/sqlx-core/src/sqlite/types/bytes.rs b/sqlx-core/src/sqlite/types/bytes.rs new file mode 100644 index 0000000000..9766349eb5 --- /dev/null +++ b/sqlx-core/src/sqlite/types/bytes.rs @@ -0,0 +1,42 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::sqlite::types::{SqliteType, SqliteTypeAffinity}; +use crate::sqlite::{Sqlite, SqliteArgumentValue, SqliteTypeInfo, SqliteValue}; +use crate::types::Type; + +impl Type for [u8] { + fn type_info() -> SqliteTypeInfo { + SqliteTypeInfo::new(SqliteType::Blob, SqliteTypeAffinity::Blob) + } +} + +impl Type for Vec { + fn type_info() -> SqliteTypeInfo { + <[u8] as Type>::type_info() + } +} + +impl Encode for [u8] { + fn encode(&self, values: &mut Vec) { + // TODO: look into a way to remove this allocation + values.push(SqliteArgumentValue::Blob(self.to_owned())); + } +} + +impl Encode for Vec { + fn encode(&self, values: &mut Vec) { + <[u8] as Encode>::encode(self, values) + } +} + +impl<'de> Decode<'de, Sqlite> for &'de [u8] { + fn decode(value: SqliteValue<'de>) -> crate::Result<&'de [u8]> { + Ok(value.blob()) + } +} + +impl<'de> Decode<'de, Sqlite> for Vec { + fn decode(value: SqliteValue<'de>) -> crate::Result> { + <&[u8] as Decode>::decode(value).map(ToOwned::to_owned) + } +} diff --git a/sqlx-core/src/sqlite/types/float.rs b/sqlx-core/src/sqlite/types/float.rs new file mode 100644 index 0000000000..1e37ece232 --- /dev/null +++ b/sqlx-core/src/sqlite/types/float.rs @@ -0,0 +1,41 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::sqlite::types::{SqliteType, SqliteTypeAffinity}; +use crate::sqlite::{Sqlite, SqliteArgumentValue, SqliteTypeInfo, SqliteValue}; +use crate::types::Type; + +impl Type for f32 { + fn type_info() -> SqliteTypeInfo { + SqliteTypeInfo::new(SqliteType::Float, SqliteTypeAffinity::Real) + } +} + +impl Encode for f32 { + fn encode(&self, values: &mut Vec) { + values.push(SqliteArgumentValue::Double((*self).into())); + } +} + +impl<'a> Decode<'a, Sqlite> for f32 { + fn decode(value: SqliteValue<'a>) -> crate::Result { + Ok(value.double() as f32) + } +} + +impl Type for f64 { + fn type_info() -> SqliteTypeInfo { + SqliteTypeInfo::new(SqliteType::Float, SqliteTypeAffinity::Real) + } +} + +impl Encode for f64 { + fn encode(&self, values: &mut Vec) { + values.push(SqliteArgumentValue::Double((*self).into())); + } +} + +impl<'a> Decode<'a, Sqlite> for f64 { + fn decode(value: SqliteValue<'a>) -> crate::Result { + Ok(value.double()) + } +} diff --git a/sqlx-core/src/sqlite/types/int.rs b/sqlx-core/src/sqlite/types/int.rs new file mode 100644 index 0000000000..8f8d6be9f2 --- /dev/null +++ b/sqlx-core/src/sqlite/types/int.rs @@ -0,0 +1,41 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::sqlite::types::{SqliteType, SqliteTypeAffinity}; +use crate::sqlite::{Sqlite, SqliteArgumentValue, SqliteTypeInfo, SqliteValue}; +use crate::types::Type; + +impl Type for i32 { + fn type_info() -> SqliteTypeInfo { + SqliteTypeInfo::new(SqliteType::Integer, SqliteTypeAffinity::Integer) + } +} + +impl Encode for i32 { + fn encode(&self, values: &mut Vec) { + values.push(SqliteArgumentValue::Int((*self).into())); + } +} + +impl<'a> Decode<'a, Sqlite> for i32 { + fn decode(value: SqliteValue<'a>) -> crate::Result { + Ok(value.int()) + } +} + +impl Type for i64 { + fn type_info() -> SqliteTypeInfo { + SqliteTypeInfo::new(SqliteType::Integer, SqliteTypeAffinity::Integer) + } +} + +impl Encode for i64 { + fn encode(&self, values: &mut Vec) { + values.push(SqliteArgumentValue::Int64((*self).into())); + } +} + +impl<'a> Decode<'a, Sqlite> for i64 { + fn decode(value: SqliteValue<'a>) -> crate::Result { + Ok(value.int64()) + } +} diff --git a/sqlx-core/src/sqlite/types/mod.rs b/sqlx-core/src/sqlite/types/mod.rs new file mode 100644 index 0000000000..dc1eb3d84e --- /dev/null +++ b/sqlx-core/src/sqlite/types/mod.rs @@ -0,0 +1,81 @@ +use std::fmt::{self, Display}; + +use crate::decode::Decode; +use crate::sqlite::value::SqliteValue; +use crate::sqlite::Sqlite; +use crate::types::TypeInfo; + +mod bool; +mod bytes; +mod float; +mod int; +mod str; + +// https://www.sqlite.org/c3ref/c_blob.html +#[derive(Debug, PartialEq, Clone, Copy)] +pub(crate) enum SqliteType { + Integer = 1, + Float = 2, + Text = 3, + Blob = 4, + Null = 5, + + // Non-standard extensions + Boolean, +} + +// https://www.sqlite.org/datatype3.html#type_affinity +#[derive(Debug, PartialEq, Clone, Copy)] +pub(crate) enum SqliteTypeAffinity { + Text, + Numeric, + Integer, + Real, + Blob, +} + +#[derive(Debug, Clone)] +pub struct SqliteTypeInfo { + pub(crate) r#type: SqliteType, + pub(crate) affinity: Option, +} + +impl SqliteTypeInfo { + fn new(r#type: SqliteType, affinity: SqliteTypeAffinity) -> Self { + Self { + r#type, + affinity: Some(affinity), + } + } +} + +impl Display for SqliteTypeInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self.r#type { + SqliteType::Null => "NULL", + SqliteType::Text => "TEXT", + SqliteType::Boolean => "BOOLEAN", + SqliteType::Integer => "INTEGER", + SqliteType::Float => "DOUBLE", + SqliteType::Blob => "BLOB", + }) + } +} + +impl TypeInfo for SqliteTypeInfo { + fn compatible(&self, other: &Self) -> bool { + self.affinity == other.affinity + } +} + +impl<'de, T> Decode<'de, Sqlite> for Option +where + T: Decode<'de, Sqlite>, +{ + fn decode(value: SqliteValue<'de>) -> crate::Result { + match value.r#type() { + SqliteType::Null => Ok(None), + _ => >::decode(value).map(Some), + } + } +} diff --git a/sqlx-core/src/sqlite/types/str.rs b/sqlx-core/src/sqlite/types/str.rs new file mode 100644 index 0000000000..e22efd1755 --- /dev/null +++ b/sqlx-core/src/sqlite/types/str.rs @@ -0,0 +1,42 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::sqlite::types::{SqliteType, SqliteTypeAffinity}; +use crate::sqlite::{Sqlite, SqliteArgumentValue, SqliteTypeInfo, SqliteValue}; +use crate::types::Type; + +impl Type for str { + fn type_info() -> SqliteTypeInfo { + SqliteTypeInfo::new(SqliteType::Text, SqliteTypeAffinity::Text) + } +} + +impl Type for String { + fn type_info() -> SqliteTypeInfo { + >::type_info() + } +} + +impl Encode for str { + fn encode(&self, values: &mut Vec) { + // TODO: look into a way to remove this allocation + values.push(SqliteArgumentValue::Text(self.to_owned())); + } +} + +impl Encode for String { + fn encode(&self, values: &mut Vec) { + >::encode(self, values) + } +} + +impl<'de> Decode<'de, Sqlite> for &'de str { + fn decode(value: SqliteValue<'de>) -> crate::Result<&'de str> { + Ok(value.text()) + } +} + +impl<'de> Decode<'de, Sqlite> for String { + fn decode(value: SqliteValue<'de>) -> crate::Result { + <&str as Decode>::decode(value).map(ToOwned::to_owned) + } +} diff --git a/sqlx-core/src/sqlite/value.rs b/sqlx-core/src/sqlite/value.rs new file mode 100644 index 0000000000..227332c3ba --- /dev/null +++ b/sqlx-core/src/sqlite/value.rs @@ -0,0 +1,101 @@ +use core::slice; + +use std::ffi::CStr; +use std::str::from_utf8_unchecked; + +use libsqlite3_sys::{ + sqlite3_column_blob, sqlite3_column_bytes, sqlite3_column_double, sqlite3_column_int, + sqlite3_column_int64, sqlite3_column_text, sqlite3_column_type, SQLITE_BLOB, SQLITE_FLOAT, + SQLITE_INTEGER, SQLITE_NULL, SQLITE_TEXT, +}; + +use crate::sqlite::statement::Statement; +use crate::sqlite::types::SqliteType; + +pub struct SqliteValue<'c> { + index: usize, + statement: &'c Statement, +} + +impl<'c> SqliteValue<'c> { + #[inline] + pub(super) fn new(statement: &'c Statement, index: usize) -> Self { + Self { statement, index } + } +} + +// https://www.sqlite.org/c3ref/column_blob.html +// https://www.sqlite.org/capi3ref.html#sqlite3_column_blob + +// These routines return information about a single column of the current result row of a query. + +impl<'c> SqliteValue<'c> { + /// Returns the initial data type of the result column. + pub(super) fn r#type(&self) -> SqliteType { + #[allow(unsafe_code)] + let type_code = unsafe { sqlite3_column_type(self.statement.handle(), self.index as i32) }; + + match type_code { + SQLITE_INTEGER => SqliteType::Integer, + SQLITE_FLOAT => SqliteType::Float, + SQLITE_BLOB => SqliteType::Blob, + SQLITE_NULL => SqliteType::Null, + SQLITE_TEXT => SqliteType::Text, + + _ => unreachable!(), + } + } + + /// Returns the 32-bit INTEGER result. + pub(super) fn int(&self) -> i32 { + #[allow(unsafe_code)] + unsafe { + sqlite3_column_int(self.statement.handle(), self.index as i32) + } + } + + /// Returns the 64-bit INTEGER result. + pub(super) fn int64(&self) -> i64 { + #[allow(unsafe_code)] + unsafe { + sqlite3_column_int64(self.statement.handle(), self.index as i32) + } + } + + /// Returns the 64-bit, REAL result. + pub(super) fn double(&self) -> f64 { + #[allow(unsafe_code)] + unsafe { + sqlite3_column_double(self.statement.handle(), self.index as i32) + } + } + + /// Returns the UTF-8 TEXT result. + pub(super) fn text(&self) -> &'c str { + #[allow(unsafe_code)] + unsafe { + let ptr = sqlite3_column_text(self.statement.handle(), self.index as i32) as *const i8; + + debug_assert!(!ptr.is_null()); + + from_utf8_unchecked(CStr::from_ptr(ptr).to_bytes()) + } + } + + /// Returns the BLOB result. + pub(super) fn blob(&self) -> &'c [u8] { + let index = self.index as i32; + + #[allow(unsafe_code)] + let ptr = unsafe { sqlite3_column_blob(self.statement.handle(), index) }; + + // Returns the size of the BLOB result in bytes. + #[allow(unsafe_code)] + let len = unsafe { sqlite3_column_bytes(self.statement.handle(), index) }; + + #[allow(unsafe_code)] + unsafe { + slice::from_raw_parts(ptr as *const u8, len as usize) + } + } +} diff --git a/sqlx-core/src/sqlite/worker.rs b/sqlx-core/src/sqlite/worker.rs new file mode 100644 index 0000000000..20a12a48b2 --- /dev/null +++ b/sqlx-core/src/sqlite/worker.rs @@ -0,0 +1,67 @@ +use crossbeam_queue::ArrayQueue; +use futures_channel::oneshot; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread::{park, spawn, JoinHandle}; + +// After tinkering with this, I believe the safest solution is to spin up a discrete thread per +// SQLite connection and perform all I/O operations for SQLite on _that_ thread. To this effect +// we have a worker struct that is a thin message passing API to run messages on the worker thread. + +#[derive(Clone)] +pub(crate) struct Worker { + running: Arc, + queue: Arc>>, + handle: Arc>, +} + +impl Worker { + pub(crate) fn new() -> Self { + let queue: Arc>> = Arc::new(ArrayQueue::new(1)); + let running = Arc::new(AtomicBool::new(true)); + + Self { + handle: Arc::new(spawn({ + let queue = queue.clone(); + let running = running.clone(); + + move || { + while running.load(Ordering::SeqCst) { + if let Ok(message) = queue.pop() { + (message)(); + } + + park(); + } + } + })), + queue, + running, + } + } + + pub(crate) async fn run(&mut self, f: F) -> R + where + F: Send + 'static, + R: Send + 'static, + F: FnOnce() -> R, + { + let (sender, receiver) = oneshot::channel::(); + + let _ = self.queue.push(Box::new(move || { + let _ = sender.send(f()); + })); + + self.handle.thread().unpark(); + + receiver.await.unwrap() + } +} + +impl Drop for Worker { + fn drop(&mut self) { + if Arc::strong_count(&self.handle) == 1 { + self.running.store(false, Ordering::SeqCst); + } + } +} diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 85b00efd0f..d98dc40ecb 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -102,7 +102,10 @@ where { type Database = T::Database; - fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result> + fn execute<'e, 'q: 'e, 't: 'e, E: 'e>( + &'t mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> where E: Execute<'q, Self::Database>, { @@ -127,14 +130,14 @@ where } } -impl<'c, DB, T> RefExecutor<'c> for &'c mut Transaction +impl<'e, DB, T> RefExecutor<'e> for &'e mut Transaction where DB: Database, T: Connection, { type Database = DB; - fn fetch_by_ref<'q, E>(self, query: E) -> >::Cursor + fn fetch_by_ref<'q, E>(self, query: E) -> >::Cursor where E: Execute<'q, Self::Database>, { diff --git a/sqlx-core/src/url.rs b/sqlx-core/src/url.rs index 0a85d8fb0c..cc4a2d6daa 100644 --- a/sqlx-core/src/url.rs +++ b/sqlx-core/src/url.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use std::convert::{TryFrom, TryInto}; +#[derive(Debug)] pub struct Url(url::Url); impl TryFrom for Url { @@ -28,6 +29,11 @@ impl<'s> TryFrom<&'s String> for Url { } impl Url { + #[allow(dead_code)] + pub(crate) fn as_str(&self) -> &str { + self.0.as_str() + } + pub fn host(&self) -> &str { let host = self.0.host_str(); diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 01a9bda6cb..231e4bed8f 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -24,6 +24,7 @@ runtime-tokio = [ "sqlx/runtime-tokio", "tokio", "lazy_static" ] # database mysql = [ "sqlx/mysql" ] postgres = [ "sqlx/postgres" ] +sqlite = [ "sqlx/sqlite" ] # type chrono = [ "sqlx/chrono" ] diff --git a/sqlx-macros/src/database/mod.rs b/sqlx-macros/src/database/mod.rs index 4a2eb77e60..87f3f030af 100644 --- a/sqlx-macros/src/database/mod.rs +++ b/sqlx-macros/src/database/mod.rs @@ -84,3 +84,6 @@ mod postgres; #[cfg(feature = "mysql")] mod mysql; + +#[cfg(feature = "sqlite")] +mod sqlite; diff --git a/sqlx-macros/src/database/sqlite.rs b/sqlx-macros/src/database/sqlite.rs new file mode 100644 index 0000000000..d4278e7bd8 --- /dev/null +++ b/sqlx-macros/src/database/sqlite.rs @@ -0,0 +1,13 @@ +impl_database_ext! { + sqlx::sqlite::Sqlite { + i32, + i64, + f32, + f64, + String, + Vec, + }, + ParamChecking::Weak, + feature-types: _info => None, + row = sqlx::sqlite::SqliteRow +} diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index 1d4a7172f3..b4e73a0732 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -25,10 +25,10 @@ pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut std::vec::Vec) { + fn encode(&self, buf: &mut ::RawBuffer) { sqlx::encode::Encode::encode(&self.0, buf) } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + fn encode_nullable(&self, buf: &mut ::RawBuffer) -> sqlx::encode::IsNull { sqlx::encode::Encode::encode_nullable(&self.0, buf) } fn size_hint(&self) -> usize { diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index c0272d9452..0bed307f31 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -64,6 +64,20 @@ macro_rules! async_macro ( let db_url = Url::parse(&dotenv::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set")?)?; match db_url.scheme() { + #[cfg(feature = "sqlite")] + "sqlite" => { + let $db = sqlx::sqlite::SqliteConnection::connect(db_url.as_str()) + .await + .map_err(|e| format!("failed to connect to database: {}", e))?; + + $expr.await + } + #[cfg(not(feature = "sqlite"))] + "sqlite" => Err(format!( + "DATABASE_URL {} has the scheme of a SQLite database but the `sqlite` \ + feature of sqlx was not enabled", + db_url + ).into()), #[cfg(feature = "postgres")] "postgresql" | "postgres" => { let $db = sqlx::postgres::PgConnection::connect(db_url.as_str()) diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index d3bbad6998..a13aa6d213 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -90,6 +90,13 @@ macro_rules! MySql_query_for_test_prepared_type { }; } +#[macro_export] +macro_rules! Sqlite_query_for_test_prepared_type { + () => { + "SELECT {} is ?, ? as _1" + }; +} + #[macro_export] macro_rules! Postgres_query_for_test_prepared_type { () => { diff --git a/src/lib.rs b/src/lib.rs index 1a31ba6ee3..d1dc6aaad5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,13 +9,13 @@ compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must pub use sqlx_core::arguments; pub use sqlx_core::connection::{Connect, Connection}; pub use sqlx_core::cursor::Cursor; -pub use sqlx_core::database::{Database, HasCursor, HasRawValue, HasRow}; +pub use sqlx_core::database::{self, Database}; pub use sqlx_core::describe; -pub use sqlx_core::executor::{Execute, Executor}; +pub use sqlx_core::executor::{Execute, Executor, RefExecutor}; pub use sqlx_core::pool::{self, Pool}; pub use sqlx_core::query::{self, query, Query}; pub use sqlx_core::query_as::{query_as, QueryAs}; -pub use sqlx_core::row::{FromRow, Row}; +pub use sqlx_core::row::{self, FromRow, Row}; pub use sqlx_core::transaction::Transaction; #[doc(inline)] @@ -32,6 +32,10 @@ pub use sqlx_core::mysql::{self, MySql, MySqlConnection, MySqlPool}; #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] pub use sqlx_core::postgres::{self, PgConnection, PgPool, Postgres}; +#[cfg(feature = "sqlite")] +#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] +pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool}; + #[cfg(feature = "macros")] #[doc(hidden)] pub extern crate sqlx_macros; @@ -75,4 +79,7 @@ pub mod prelude { #[cfg(feature = "mysql")] pub use super::mysql::MySqlQueryAs; + + #[cfg(feature = "sqlite")] + pub use super::sqlite::SqliteQueryAs; } diff --git a/tests/derives.rs b/tests/derives.rs index facdbea90a..d5120031c3 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -5,30 +5,33 @@ use sqlx::encode::Encode; struct Foo(i32); #[test] -#[cfg(feature = "mysql")] -fn encode_mysql() { - encode_with_db::(); +#[cfg(feature = "postgres")] +fn encode_with_postgres() { + use sqlx_core::postgres::Postgres; + + let example = Foo(0x1122_3344); + + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(&example, &mut encoded); + Encode::::encode(&example.0, &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); } #[test] -#[cfg(feature = "postgres")] -fn encode_postgres() { - encode_with_db::(); -} +#[cfg(feature = "mysql")] +fn encode_with_mysql() { + use sqlx_core::mysql::MySql; -#[allow(dead_code)] -fn encode_with_db() -where - Foo: Encode, - i32: Encode, -{ let example = Foo(0x1122_3344); let mut encoded = Vec::new(); let mut encoded_orig = Vec::new(); - Encode::::encode(&example, &mut encoded); - Encode::::encode(&example.0, &mut encoded_orig); + Encode::::encode(&example, &mut encoded); + Encode::::encode(&example.0, &mut encoded_orig); assert_eq!(encoded, encoded_orig); } diff --git a/tests/postgres.rs b/tests/postgres.rs index f7847aa96a..2eb1c92278 100644 --- a/tests/postgres.rs +++ b/tests/postgres.rs @@ -1,6 +1,6 @@ use futures::TryStreamExt; -use sqlx::postgres::{PgPool, PgRow}; -use sqlx::{postgres::PgConnection, Connect, Executor, Row}; +use sqlx::postgres::{PgPool, PgQueryAs, PgRow}; +use sqlx::{postgres::PgConnection, Connect, Connection, Executor, Row}; use std::time::Duration; #[cfg_attr(feature = "runtime-async-std", async_std::test)] @@ -86,6 +86,72 @@ async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> { Ok(()) } +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn it_can_work_with_transactions() -> anyhow::Result<()> { + let mut conn = connect().await?; + + conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_1922 (id INTEGER PRIMARY KEY)") + .await?; + + conn.execute("TRUNCATE _sqlx_users_1922").await?; + + // begin .. rollback + + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") + .bind(10_i32) + .execute(&mut tx) + .await?; + + conn = tx.rollback().await?; + + let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 0); + + // begin .. commit + + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") + .bind(10_i32) + .execute(&mut tx) + .await?; + + conn = tx.commit().await?; + + let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 1); + + // begin .. (drop) + + { + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") + .bind(20_i32) + .execute(&mut tx) + .await?; + } + + conn = connect().await?; + + let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 1); + + Ok(()) +} + // run with `cargo test --features postgres -- --ignored --nocapture pool_smoke_test` #[ignore] #[cfg_attr(feature = "runtime-async-std", async_std::test)] diff --git a/tests/sqlite-raw.rs b/tests/sqlite-raw.rs new file mode 100644 index 0000000000..04c3e10725 --- /dev/null +++ b/tests/sqlite-raw.rs @@ -0,0 +1,52 @@ +//! Tests for the raw (unprepared) query API for Sqlite. + +use sqlx::{Cursor, Executor, Row, Sqlite}; +use sqlx_test::new; + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_select_expression() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut cursor = conn.fetch("SELECT 5"); + let row = cursor.next().await?.unwrap(); + + assert!(5i32 == row.try_get::(0)?); + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_multi_read_write() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut cursor = conn.fetch( + " +CREATE TABLE IF NOT EXISTS _sqlx_test ( + id INT PRIMARY KEY, + text TEXT NOT NULL +); + +SELECT 'Hello World' as _1; + +INSERT INTO _sqlx_test (text) VALUES ('this is a test'); + +SELECT id, text FROM _sqlx_test; + ", + ); + + let row = cursor.next().await?.unwrap(); + + assert!("Hello World" == row.try_get::<&str, _>("_1")?); + + let row = cursor.next().await?.unwrap(); + + let id: i64 = row.try_get("id")?; + let text: &str = row.try_get("text")?; + + assert_eq!(0, id); + assert_eq!("this is a test", text); + + Ok(()) +} diff --git a/tests/sqlite-types.rs b/tests/sqlite-types.rs new file mode 100644 index 0000000000..5f04d68e92 --- /dev/null +++ b/tests/sqlite-types.rs @@ -0,0 +1,46 @@ +use sqlx::Sqlite; +use sqlx_test::test_type; + +test_type!(null( + Sqlite, + Option, + "NULL" == None:: +)); + +test_type!(bool(Sqlite, bool, "FALSE" == false, "TRUE" == true)); + +test_type!(i32(Sqlite, i32, "94101" == 94101_i32)); + +test_type!(i64(Sqlite, i64, "9358295312" == 9358295312_i64)); + +// NOTE: This behavior can be surprising. Floating-point parameters are widening to double which can +// result in strange rounding. +test_type!(f32( + Sqlite, + f32, + "3.1410000324249268" == 3.141f32 as f64 as f32 +)); + +test_type!(f64( + Sqlite, + f64, + "939399419.1225182" == 939399419.1225182_f64 +)); + +test_type!(string( + Sqlite, + String, + "'this is foo'" == "this is foo", + "''" == "" +)); + +test_type!(bytes( + Sqlite, + Vec, + "X'DEADBEEF'" + == vec![0xDE_u8, 0xAD, 0xBE, 0xEF], + "X''" + == Vec::::new(), + "X'0000000052'" + == vec![0_u8, 0, 0, 0, 0x52] +)); diff --git a/tests/sqlite.rs b/tests/sqlite.rs new file mode 100644 index 0000000000..48fd696e62 --- /dev/null +++ b/tests/sqlite.rs @@ -0,0 +1,145 @@ +use futures::TryStreamExt; +use sqlx::{sqlite::SqliteQueryAs, Connect, Connection, Executor, Sqlite, SqliteConnection}; +use sqlx_test::new; + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn it_connects() -> anyhow::Result<()> { + Ok(new::().await?.ping().await?) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn it_fails_to_connect() -> anyhow::Result<()> { + // empty connection string + assert!(SqliteConnection::connect("").await.is_err()); + assert!( + SqliteConnection::connect("sqlite:///please_do_not_run_sqlx_tests_as_root") + .await + .is_err() + ); + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn it_fails_to_parse() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.execute("SEELCT 1").await; + + assert!(res.is_err()); + + let err = res.unwrap_err().to_string(); + + assert_eq!("near \"SEELCT\": syntax error", err); + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn it_executes() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let _ = conn + .execute( + r#" +CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY) + "#, + ) + .await?; + + for index in 1..=10_i32 { + let cnt = sqlx::query("INSERT INTO users (id) VALUES (?)") + .bind(index) + .execute(&mut conn) + .await?; + + assert_eq!(cnt, 1); + } + + let sum: i32 = sqlx::query_as("SELECT id FROM users") + .fetch(&mut conn) + .try_fold(0_i32, |acc, (x,): (i32,)| async move { Ok(acc + x) }) + .await?; + + assert_eq!(sum, 55); + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn it_can_execute_multiple_statements() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let affected = conn + .execute( + r#" +CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, other INTEGER); +INSERT INTO users DEFAULT VALUES; + "#, + ) + .await?; + + assert_eq!(affected, 1); + + for index in 2..5_i32 { + let (id, other): (i32, i32) = sqlx::query_as( + r#" +INSERT INTO users (other) VALUES (?); +SELECT id, other FROM users WHERE id = last_insert_rowid(); + "#, + ) + .bind(index) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, index); + assert_eq!(other, index); + } + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn it_describes() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let _ = conn + .execute( + r#" +CREATE TEMPORARY TABLE describe_test ( + _1 int primary key, + _2 text not null, + _3 blob, + _4 boolean, + _5 float, + _6 varchar(255), + _7 double, + _8 bigint +) + "#, + ) + .await?; + + let describe = conn + .describe("select nt.*, false from describe_test nt") + .await?; + + assert_eq!(describe.result_columns[0].type_info.to_string(), "INTEGER"); + assert_eq!(describe.result_columns[1].type_info.to_string(), "TEXT"); + assert_eq!(describe.result_columns[2].type_info.to_string(), "BLOB"); + assert_eq!(describe.result_columns[3].type_info.to_string(), "BOOLEAN"); + assert_eq!(describe.result_columns[4].type_info.to_string(), "DOUBLE"); + assert_eq!(describe.result_columns[5].type_info.to_string(), "TEXT"); + assert_eq!(describe.result_columns[6].type_info.to_string(), "DOUBLE"); + assert_eq!(describe.result_columns[7].type_info.to_string(), "INTEGER"); + + // Expressions can not be described + assert_eq!(describe.result_columns[8].type_info.to_string(), "NULL"); + + Ok(()) +}