Skip to content

Commit

Permalink
fix: ensure PG connection is established before using it
Browse files Browse the repository at this point in the history
Fixes #1940.
  • Loading branch information
crepererum committed Jul 27, 2022
1 parent 78a0a59 commit 528ba22
Showing 1 changed file with 40 additions and 21 deletions.
61 changes: 40 additions & 21 deletions sqlx-core/src/postgres/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use either::Either;
use futures_channel::mpsc;
use futures_core::future::BoxFuture;
use futures_core::stream::{BoxStream, Stream};
use futures_util::{FutureExt, StreamExt, TryStreamExt};

use crate::describe::Describe;
use crate::error::Error;
Expand Down Expand Up @@ -96,6 +97,7 @@ impl PgListener {
/// The channel name is quoted here to ensure case sensitivity.
pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
self.connection()
.await?
.execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
.await?;

Expand All @@ -112,21 +114,22 @@ impl PgListener {
let beg = self.channels.len();
self.channels.extend(channels.into_iter().map(|s| s.into()));

self.connection
.as_mut()
.unwrap()
.execute(&*build_listen_all_query(&self.channels[beg..]))
.await?;
let query = build_listen_all_query(&self.channels[beg..]);
self.connection().await?.execute(&*query).await?;

Ok(())
}

/// Stops listening for notifications on a channel.
/// The channel name is quoted here to ensure case sensitivity.
pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
self.connection()
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
.await?;
// use RAW connection and do NOT re-connect automatically, since this is not required for
// UNLISTEN (we've disconnected anyways)
if let Some(connection) = self.connection.as_mut() {
connection
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
.await?;
}

if let Some(pos) = self.channels.iter().position(|s| s == channel) {
self.channels.remove(pos);
Expand All @@ -137,7 +140,11 @@ impl PgListener {

/// Stops listening for notifications on all channels.
pub async fn unlisten_all(&mut self) -> Result<(), Error> {
self.connection().execute("UNLISTEN *").await?;
// use RAW connection and do NOT re-connect automatically, since this is not required for
// UNLISTEN (we've disconnected anyways)
if let Some(connection) = self.connection.as_mut() {
connection.execute("UNLISTEN *").await?;
}

self.channels.clear();

Expand All @@ -161,8 +168,11 @@ impl PgListener {
}

#[inline]
fn connection(&mut self) -> &mut PgConnection {
self.connection.as_mut().unwrap()
async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
// Ensure we have an active connection to work with.
self.connect_if_needed().await?;

Ok(self.connection.as_mut().unwrap())
}

/// Receives the next notification available from any of the subscribed channels.
Expand Down Expand Up @@ -237,10 +247,7 @@ impl PgListener {
let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());

loop {
// Ensure we have an active connection to work with.
self.connect_if_needed().await?;

let next_message = self.connection().stream.recv_unchecked();
let next_message = self.connection().await?.stream.recv_unchecked();

let res = if let Some(ref mut close_event) = close_event {
// cancels the wait and returns `Err(PoolClosed)` if the pool is closed
Expand All @@ -256,7 +263,7 @@ impl PgListener {
// The connection is dead, ensure that it is dropped,
// update self state, and loop to try again.
Err(Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
self.buffer_tx = self.connection().stream.notifications.take();
self.buffer_tx = self.connection().await?.stream.notifications.take();
self.connection = None;

// lost connection
Expand All @@ -277,7 +284,7 @@ impl PgListener {

// Mark the connection as ready for another query
MessageFormat::ReadyForQuery => {
self.connection().pending_ready_for_query_count -= 1;
self.connection().await?.pending_ready_for_query_count -= 1;
}

// Ignore unexpected messages
Expand Down Expand Up @@ -336,7 +343,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.connection().fetch_many(query)
futures_util::stream::once(async move {
// need some basic type annotation to help the compiler a bit
let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
res
})
.try_flatten()
.boxed()
}

fn fetch_optional<'e, 'q: 'e, E: 'q>(
Expand All @@ -347,7 +360,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.connection().fetch_optional(query)
async move { self.connection().await?.fetch_optional(query).await }.boxed()
}

fn prepare_with<'e, 'q: 'e>(
Expand All @@ -358,7 +371,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
where
'c: 'e,
{
self.connection().prepare_with(query, parameters)
async move {
self.connection()
.await?
.prepare_with(query, parameters)
.await
}
.boxed()
}

#[doc(hidden)]
Expand All @@ -369,7 +388,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
where
'c: 'e,
{
self.connection().describe(query)
async move { self.connection().await?.describe(query).await }.boxed()
}
}

Expand Down

0 comments on commit 528ba22

Please sign in to comment.