diff --git a/src/db/mod.rs b/src/db/mod.rs index e3679598b..59071e0da 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -7,6 +7,9 @@ pub use self::file::add_path_into_database; pub use self::migrate::migrate; pub use self::pool::{Pool, PoolError}; +#[cfg(test)] +pub(crate) use self::pool::PoolConnection; + mod add_package; pub mod blacklist; mod delete_crate; diff --git a/src/db/pool.rs b/src/db/pool.rs index 79710ae95..92eb88011 100644 --- a/src/db/pool.rs +++ b/src/db/pool.rs @@ -1,22 +1,30 @@ use crate::Config; use postgres::Connection; -use std::marker::PhantomData; +use r2d2_postgres::PostgresConnectionManager; -#[cfg(test)] -use std::sync::{Arc, Mutex, MutexGuard}; +pub(crate) type PoolConnection = r2d2::PooledConnection; + +const DEFAULT_SCHEMA: &str = "public"; #[derive(Debug, Clone)] -pub enum Pool { - R2D2(r2d2::Pool), - #[cfg(test)] - Simple(Arc>), +pub struct Pool { + pool: r2d2::Pool, } impl Pool { pub fn new(config: &Config) -> Result { + Self::new_inner(config, DEFAULT_SCHEMA) + } + + #[cfg(test)] + pub(crate) fn new_with_schema(config: &Config, schema: &str) -> Result { + Self::new_inner(config, schema) + } + + fn new_inner(config: &Config, schema: &str) -> Result { crate::web::metrics::MAX_DB_CONNECTIONS.set(config.max_pool_size as i64); - let manager = r2d2_postgres::PostgresConnectionManager::new( + let manager = PostgresConnectionManager::new( config.database_url.as_str(), r2d2_postgres::TlsMode::None, ) @@ -25,73 +33,54 @@ impl Pool { let pool = r2d2::Pool::builder() .max_size(config.max_pool_size) .min_idle(Some(config.min_pool_idle)) + .connection_customizer(Box::new(SetSchema::new(schema))) .build(manager) .map_err(PoolError::PoolCreationFailed)?; - Ok(Pool::R2D2(pool)) - } - - #[cfg(test)] - pub(crate) fn new_simple(conn: Arc>) -> Self { - Pool::Simple(conn) + Ok(Pool { pool }) } - pub fn get(&self) -> Result, PoolError> { - match self { - Self::R2D2(r2d2) => match r2d2.get() { - Ok(conn) => Ok(DerefConnection::Connection(conn, PhantomData)), - Err(err) => { - crate::web::metrics::FAILED_DB_CONNECTIONS.inc(); - Err(PoolError::ConnectionError(err)) - } - }, - - #[cfg(test)] - Self::Simple(mutex) => Ok(DerefConnection::Guard( - mutex.lock().expect("failed to lock the connection"), - )), + pub fn get(&self) -> Result { + match self.pool.get() { + Ok(conn) => Ok(conn), + Err(err) => { + crate::web::metrics::FAILED_DB_CONNECTIONS.inc(); + Err(PoolError::ConnectionError(err)) + } } } pub(crate) fn used_connections(&self) -> u32 { - match self { - Self::R2D2(conn) => conn.state().connections - conn.state().idle_connections, - - #[cfg(test)] - Self::Simple(..) => 0, - } + self.pool.state().connections - self.pool.state().idle_connections } pub(crate) fn idle_connections(&self) -> u32 { - match self { - Self::R2D2(conn) => conn.state().idle_connections, - - #[cfg(test)] - Self::Simple(..) => 0, - } + self.pool.state().idle_connections } } -pub enum DerefConnection<'a> { - Connection( - r2d2::PooledConnection, - PhantomData<&'a ()>, - ), - - #[cfg(test)] - Guard(MutexGuard<'a, Connection>), +#[derive(Debug)] +struct SetSchema { + schema: String, } -impl<'a> std::ops::Deref for DerefConnection<'a> { - type Target = Connection; - - fn deref(&self) -> &Connection { - match self { - Self::Connection(conn, ..) => conn, +impl SetSchema { + fn new(schema: &str) -> Self { + Self { + schema: schema.into(), + } + } +} - #[cfg(test)] - Self::Guard(guard) => &guard, +impl r2d2::CustomizeConnection for SetSchema { + fn on_acquire(&self, conn: &mut Connection) -> Result<(), postgres::Error> { + if self.schema != DEFAULT_SCHEMA { + conn.execute( + &format!("SET search_path TO {}, {};", self.schema, DEFAULT_SCHEMA), + &[], + )?; } + Ok(()) } } diff --git a/src/test/mod.rs b/src/test/mod.rs index 21e9152aa..f7173b8bf 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,5 +1,6 @@ mod fakes; +use crate::db::{Pool, PoolConnection}; use crate::storage::s3::TestS3; use crate::web::Server; use crate::Config; @@ -11,10 +12,7 @@ use reqwest::{ blocking::{Client, RequestBuilder}, Method, }; -use std::{ - panic, - sync::{Arc, Mutex, MutexGuard}, -}; +use std::{panic, sync::Arc}; pub(crate) fn wrapper(f: impl FnOnce(&TestEnvironment) -> Result<(), Error>) { let _ = dotenv::dotenv(); @@ -123,7 +121,13 @@ impl TestEnvironment { } fn base_config(&self) -> Config { - Config::from_env().expect("failed to get base config") + let mut config = Config::from_env().expect("failed to get base config"); + + // Use less connections for each test compared to production. + config.max_pool_size = 2; + config.min_pool_idle = 0; + + config } pub(crate) fn override_config(&self, f: impl FnOnce(&mut Config)) { @@ -157,7 +161,7 @@ impl TestEnvironment { } pub(crate) struct TestDatabase { - conn: Arc>, + pool: Pool, schema: String, } @@ -178,13 +182,15 @@ impl TestDatabase { crate::db::migrate(None, &conn)?; Ok(TestDatabase { - conn: Arc::new(Mutex::new(conn)), + pool: Pool::new_with_schema(config, &schema)?, schema, }) } - pub(crate) fn conn(&self) -> MutexGuard { - self.conn.lock().expect("failed to lock the connection") + pub(crate) fn conn(&self) -> PoolConnection { + self.pool + .get() + .expect("failed to get a connection out of the pool") } pub(crate) fn fake_release(&self) -> fakes::FakeRelease { @@ -212,8 +218,8 @@ pub(crate) struct TestFrontend { impl TestFrontend { fn new(db: &TestDatabase, config: Arc) -> Self { Self { - server: Server::start_test(db.conn.clone(), config) - .expect("failed to start the server"), + server: Server::start(Some("127.0.0.1:0"), false, db.pool.clone(), config) + .expect("failed to start the web server"), client: Client::new(), } } diff --git a/src/web/mod.rs b/src/web/mod.rs index 56c515fec..2c3921cb5 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -79,9 +79,6 @@ use std::net::SocketAddr; use std::sync::Arc; use std::{env, fmt, path::PathBuf, time::Duration}; -#[cfg(test)] -use std::sync::Mutex; - /// Duration of static files for staticfile and DatabaseFileHandler (in seconds) const STATIC_FILE_CACHE_DURATION: u64 = 60 * 60 * 24 * 30 * 12; // 12 months const STYLE_CSS: &str = include_str!(concat!(env!("OUT_DIR"), "/style.css")); @@ -397,21 +394,6 @@ impl Server { Ok(server) } - #[cfg(test)] - pub(crate) fn start_test( - conn: Arc>, - config: Arc, - ) -> Result { - let templates = TemplateData::new(&conn.lock().unwrap())?; - - Ok(Self::start_inner( - "127.0.0.1:0", - Pool::new_simple(conn.clone()), - config, - Arc::new(templates), - )) - } - fn start_inner( addr: &str, pool: Pool,