diff --git a/src/bin/main.rs b/src/bin/main.rs index 649865824..aa97e6512 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -2,7 +2,7 @@ use actix_web::dev::Server; use clap::Parser; use log::{error, info, warn}; use martin::config::{read_config, Config, ConfigBuilder}; -use martin::pg::config::{PgArgs, PgConfigBuilder}; +use martin::pg::config::{PgArgs, PgConfig}; use martin::pg::configurator::resolve_pg_data; use martin::source::IdResolver; use martin::srv::config::{SrvArgs, SrvConfigBuilder}; @@ -49,7 +49,7 @@ impl From for ConfigBuilder { ConfigBuilder { srv: SrvConfigBuilder::from(args.srv), - pg: PgConfigBuilder::from((args.pg, args.connection)), + pg: PgConfig::from((args.pg, args.connection)), unrecognized: HashMap::new(), } } @@ -89,7 +89,7 @@ async fn start(args: Args) -> io::Result { info!("Saving config to {file_name}, use --config to load it"); File::create(file_name)?.write_all(yaml.as_bytes())?; } - } else if config.pg.discover_functions || config.pg.discover_tables { + } else if config.pg.run_autodiscovery { info!("Martin has been configured with automatic settings."); info!("Use --save-config to save or print Martin configuration."); } diff --git a/src/config.rs b/src/config.rs index 8b88a0981..db4bfb114 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,5 @@ use crate::io_error; -use crate::pg::config::{PgConfig, PgConfigBuilder}; +use crate::pg::config::PgConfig; use crate::srv::config::{SrvConfig, SrvConfigBuilder}; use log::warn; use serde::{Deserialize, Serialize}; @@ -22,7 +22,7 @@ pub struct ConfigBuilder { #[serde(flatten)] pub srv: SrvConfigBuilder, #[serde(flatten)] - pub pg: PgConfigBuilder, + pub pg: PgConfig, #[serde(flatten)] pub unrecognized: HashMap, } @@ -124,16 +124,10 @@ mod tests { worker_processes: 8, }, pg: PgConfig { - connection_string: "postgres://postgres@localhost:5432/db".to_string(), - #[cfg(feature = "ssl")] - ca_root_file: None, - #[cfg(feature = "ssl")] - danger_accept_invalid_certs: false, + 
connection_string: Some("postgres://postgres@localhost:5432/db".to_string()), default_srid: Some(4326), - pool_size: 20, - discover_functions: false, - discover_tables: false, - tables: HashMap::from([( + pool_size: Some(20), + tables: Some(HashMap::from([( "table_source".to_string(), TableInfo { schema: "public".to_string(), @@ -150,8 +144,8 @@ mod tests { properties: HashMap::from([("gid".to_string(), "int4".to_string())]), ..Default::default() }, - )]), - functions: HashMap::from([( + )])), + functions: Some(HashMap::from([( "function_zxy_query".to_string(), FunctionInfo::new_extended( "public".to_string(), @@ -160,7 +154,8 @@ mod tests { 30, Bounds::MAX, ), - )]), + )])), + ..Default::default() }, }; assert_eq!(config, expected); diff --git a/src/pg/config.rs b/src/pg/config.rs index 1995d32a6..aa23b32d9 100644 --- a/src/pg/config.rs +++ b/src/pg/config.rs @@ -1,6 +1,6 @@ use crate::config::{report_unrecognized_config, set_option}; use crate::pg::utils::create_tilejson; -use crate::utils::InfoMap; +use crate::utils::{InfoMap, Schemas}; use serde::{Deserialize, Serialize}; use serde_yaml::Value; use std::collections::HashMap; @@ -181,50 +181,43 @@ impl PgInfo for FunctionInfo { pub type TableInfoSources = InfoMap; pub type FuncInfoSources = InfoMap; -#[derive(Clone, Debug, Serialize, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub struct PgConfig { - pub connection_string: String, + pub connection_string: Option, #[cfg(feature = "ssl")] #[serde(skip_serializing_if = "Option::is_none")] pub ca_root_file: Option, #[cfg(feature = "ssl")] + #[serde(default, skip_serializing_if = "std::ops::Not::not")] pub danger_accept_invalid_certs: bool, #[serde(skip_serializing_if = "Option::is_none")] pub default_srid: Option, - pub pool_size: u32, - #[serde(skip_serializing)] - pub discover_functions: bool, - #[serde(skip_serializing)] - pub discover_tables: bool, - pub tables: TableInfoSources, - pub functions: FuncInfoSources, -} - -#[derive(Debug, 
Default, PartialEq, Deserialize)] -pub struct PgConfigBuilder { - pub connection_string: Option, - #[cfg(feature = "ssl")] - pub ca_root_file: Option, - #[cfg(feature = "ssl")] - pub danger_accept_invalid_certs: Option, - pub default_srid: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub pool_size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub auto_tables: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub auto_functions: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub tables: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub functions: Option, + #[serde(skip)] + pub run_autodiscovery: bool, } -impl PgConfigBuilder { +impl PgConfig { pub fn merge(&mut self, other: Self) -> &mut Self { set_option(&mut self.connection_string, other.connection_string); #[cfg(feature = "ssl")] - set_option(&mut self.ca_root_file, other.ca_root_file); - #[cfg(feature = "ssl")] - set_option( - &mut self.danger_accept_invalid_certs, - other.danger_accept_invalid_certs, - ); + { + set_option(&mut self.ca_root_file, other.ca_root_file); + self.danger_accept_invalid_certs |= other.danger_accept_invalid_certs; + } set_option(&mut self.default_srid, other.default_srid); set_option(&mut self.pool_size, other.pool_size); + set_option(&mut self.auto_tables, other.auto_tables); + set_option(&mut self.auto_functions, other.auto_functions); set_option(&mut self.tables, other.tables); set_option(&mut self.functions, other.functions); self @@ -249,24 +242,16 @@ impl PgConfigBuilder { ) })?; Ok(PgConfig { - connection_string, - #[cfg(feature = "ssl")] - ca_root_file: self.ca_root_file, - #[cfg(feature = "ssl")] - danger_accept_invalid_certs: self.danger_accept_invalid_certs.unwrap_or_default(), - default_srid: self.default_srid, - pool_size: self.pool_size.unwrap_or(POOL_SIZE_DEFAULT), - discover_functions: self.tables.is_none() && self.functions.is_none(), - discover_tables: self.tables.is_none() && 
self.functions.is_none(), - tables: self.tables.unwrap_or_default(), - functions: self.functions.unwrap_or_default(), + connection_string: Some(connection_string), + run_autodiscovery: self.tables.is_none() && self.functions.is_none(), + ..self }) } } -impl From<(PgArgs, Option)> for PgConfigBuilder { +impl From<(PgArgs, Option)> for PgConfig { fn from((args, connection): (PgArgs, Option)) -> Self { - PgConfigBuilder { + PgConfig { connection_string: connection.or_else(|| { env::var_os("DATABASE_URL").and_then(|connection| connection.into_string().ok()) }), @@ -275,13 +260,8 @@ impl From<(PgArgs, Option)> for PgConfigBuilder { env::var_os("CA_ROOT_FILE").and_then(|connection| connection.into_string().ok()) }), #[cfg(feature = "ssl")] - danger_accept_invalid_certs: if args.danger_accept_invalid_certs - || env::var_os("DANGER_ACCEPT_INVALID_CERTS").is_some() - { - Some(true) - } else { - None - }, + danger_accept_invalid_certs: args.danger_accept_invalid_certs + || env::var_os("DANGER_ACCEPT_INVALID_CERTS").is_some(), default_srid: args.default_srid.or_else(|| { env::var_os("DEFAULT_SRID").and_then(|srid| { srid.into_string() @@ -290,8 +270,7 @@ impl From<(PgArgs, Option)> for PgConfigBuilder { }) }), pool_size: args.pool_size, - tables: None, - functions: None, + ..Default::default() } } } diff --git a/src/pg/configurator.rs b/src/pg/configurator.rs index e731b5be4..ba732eadd 100755 --- a/src/pg/configurator.rs +++ b/src/pg/configurator.rs @@ -7,7 +7,7 @@ use crate::pg::pool::Pool; use crate::pg::table_source::{calc_srid, get_table_sources, merge_table_info, table_to_query}; use crate::source::IdResolver; use crate::srv::server::Sources; -use crate::utils::{find_info, InfoMap}; +use crate::utils::{find_info, normalize_key, InfoMap, Schemas}; use futures::future::{join_all, try_join}; use itertools::Itertools; use log::{debug, error, info, warn}; @@ -27,8 +27,8 @@ pub async fn resolve_pg_data( Ok(( tables, PgConfig { - tables: tbl_info, - functions: func_info, + 
tables: Some(tbl_info), + functions: Some(func_info), ..config }, pg.pool, @@ -38,8 +38,8 @@ pub async fn resolve_pg_data( struct PgBuilder { pool: Pool, default_srid: Option, - discover_functions: bool, - discover_tables: bool, + auto_functions: Schemas, + auto_tables: Schemas, id_resolver: IdResolver, tables: TableInfoSources, functions: FuncInfoSources, @@ -48,19 +48,20 @@ struct PgBuilder { impl PgBuilder { async fn new(config: &PgConfig, id_resolver: IdResolver) -> io::Result { let pool = Pool::new(config).await?; + let auto = config.run_autodiscovery; Ok(Self { pool, default_srid: config.default_srid, - discover_functions: config.discover_functions, - discover_tables: config.discover_tables, + auto_functions: config.auto_functions.clone().unwrap_or(Schemas::Bool(auto)), + auto_tables: config.auto_tables.clone().unwrap_or(Schemas::Bool(auto)), id_resolver, - tables: config.tables.clone(), - functions: config.functions.clone(), + tables: config.tables.clone().unwrap_or_default(), + functions: config.functions.clone().unwrap_or_default(), }) } pub async fn instantiate_tables(&self) -> Result<(Sources, TableInfoSources), io::Error> { - let all_tables = get_table_sources(&self.pool).await?; + let mut all_tables = get_table_sources(&self.pool).await?; // Match configured sources with the discovered ones and add them to the pending list. 
let mut used = HashSet::<(&str, &str, &str)>::new(); @@ -90,20 +91,20 @@ impl PgBuilder { pending.push(table_to_query(id2, cfg_inf, self.pool.clone())); } - if self.discover_tables { - // Sort the discovered sources by schema, table and geometry column to ensure a consistent behavior - for (schema, tables) in all_tables.into_iter().sorted_by(by_key) { - for (table, geoms) in tables.into_iter().sorted_by(by_key) { - for (geom, mut src_inf) in geoms.into_iter().sorted_by(by_key) { - if used.contains(&(schema.as_str(), table.as_str(), geom.as_str())) { - continue; - } - let id2 = self.resolve_id(table.clone(), &src_inf); - let Some(srid) = calc_srid(&src_inf.format_id(), &id2, src_inf.srid,0, self.default_srid) else {continue}; - src_inf.srid = srid; - info!("Discovered source {id2} from {}", summary(&src_inf)); - pending.push(table_to_query(id2, src_inf, self.pool.clone())); + // Sort the discovered sources by schema, table and geometry column to ensure a consistent behavior + for schema in self.auto_tables.get(|| all_tables.keys()) { + let Some(schema2) = normalize_key(&all_tables, &schema, "schema", "") else { continue }; + let tables = all_tables.remove(&schema2).unwrap(); + for (table, geoms) in tables.into_iter().sorted_by(by_key) { + for (geom, mut src_inf) in geoms.into_iter().sorted_by(by_key) { + if used.contains(&(schema.as_str(), table.as_str(), geom.as_str())) { + continue; } + let id2 = self.resolve_id(table.clone(), &src_inf); + let Some(srid) = calc_srid(&src_inf.format_id(), &id2, src_inf.srid,0, self.default_srid) else {continue}; + src_inf.srid = srid; + info!("Discovered source {id2} from {}", summary(&src_inf)); + pending.push(table_to_query(id2, src_inf, self.pool.clone())); } } } @@ -129,14 +130,13 @@ impl PgBuilder { } pub async fn instantiate_functions(&self) -> Result<(Sources, FuncInfoSources), io::Error> { - let all_functions = get_function_sources(&self.pool).await?; - + let mut all_funcs = get_function_sources(&self.pool).await?; let mut 
res: Sources = HashMap::new(); let mut info_map = FuncInfoSources::new(); let mut used = HashSet::<(&str, &str)>::new(); for (id, cfg_inf) in &self.functions { - let Some(schemas) = find_info(&all_functions, &cfg_inf.schema, "schema", id) else { continue }; + let Some(schemas) = find_info(&all_funcs, &cfg_inf.schema, "schema", id) else { continue }; if schemas.is_empty() { warn!("No functions found in schema {}. Only functions like (z,x,y) -> bytea and similar are considered. See README.md", cfg_inf.schema); continue; @@ -155,19 +155,19 @@ impl PgBuilder { info_map.insert(id2, cfg_inf.clone()); } - if self.discover_functions { - // Sort the discovered sources by schema and function name to ensure a consistent behavior - for (schema, funcs) in all_functions.into_iter().sorted_by(by_key) { - for (name, (pg_sql, src_inf)) in funcs.into_iter().sorted_by(by_key) { - if used.contains(&(schema.as_str(), name.as_str())) { - continue; - } - let id2 = self.resolve_id(name.clone(), &src_inf); - self.add_func_src(&mut res, id2.clone(), &src_inf, pg_sql.clone()); - info!("Discovered source {id2} from function {}", pg_sql.signature); - debug!("{}", pg_sql.query); - info_map.insert(id2, src_inf); + // Sort the discovered sources by schema and function name to ensure a consistent behavior + for schema in self.auto_functions.get(|| all_funcs.keys()) { + let Some(schema2) = normalize_key(&all_funcs, &schema, "schema", "") else { continue }; + let funcs = all_funcs.remove(&schema2).unwrap(); + for (name, (pg_sql, src_inf)) in funcs.into_iter().sorted_by(by_key) { + if used.contains(&(schema.as_str(), name.as_str())) { + continue; } + let id2 = self.resolve_id(name.clone(), &src_inf); + self.add_func_src(&mut res, id2.clone(), &src_inf, pg_sql.clone()); + info!("Discovered source {id2} from function {}", pg_sql.signature); + debug!("{}", pg_sql.query); + info_map.insert(id2, src_inf); } } diff --git a/src/pg/pool.rs b/src/pg/pool.rs index 7b12087a4..68e1e3812 100755 --- 
a/src/pg/pool.rs +++ b/src/pg/pool.rs @@ -1,4 +1,4 @@ -use crate::pg::config::PgConfig; +use crate::pg::config::{PgConfig, POOL_SIZE_DEFAULT}; use crate::pg::utils::io_error; use bb8::PooledConnection; use bb8_postgres::{tokio_postgres as pg, PostgresConnectionManager}; @@ -35,7 +35,7 @@ pub struct Pool { impl Pool { pub async fn new(config: &PgConfig) -> io::Result { - let conn_str = config.connection_string.as_str(); + let conn_str = config.connection_string.as_ref().unwrap().as_str(); info!("Connecting to {conn_str}"); let pg_cfg = pg::config::Config::from_str(conn_str) .map_err(|e| io_error!(e, "Can't parse connection string {conn_str}"))?; @@ -66,7 +66,7 @@ impl Pool { }; let pool = InternalPool::builder() - .max_size(config.pool_size) + .max_size(config.pool_size.unwrap_or(POOL_SIZE_DEFAULT)) .build(manager) .await .map_err(|e| io_error!(e, "Can't build connection pool"))?; diff --git a/src/utils.rs b/src/utils.rs index 750b8f4e7..fc6c5b676 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,6 @@ +use itertools::Itertools; use log::{error, info, warn}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; pub type InfoMap = HashMap; @@ -56,3 +58,33 @@ pub fn find_info_kv<'a, T>( None } } + +/// A list of schemas to include in the discovery process, or a boolean to +/// indicate whether to run discovery at all. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Schemas { + Bool(bool), + List(Vec), +} + +impl Schemas { + /// Returns a list of schemas to include in the discovery process. + /// If self is a true, returns a list of all schemas produced by the callback. 
+ pub fn get<'a, I, F>(&self, keys: F) -> Vec + where + I: Iterator, + F: FnOnce() -> I, + { + match self { + Schemas::List(lst) => lst.clone(), + Schemas::Bool(all) => { + if *all { + keys().sorted().map(String::to_string).collect() + } else { + Vec::new() + } + } + } + } +} diff --git a/tests/function_source_test.rs b/tests/function_source_test.rs index dbd2798e9..ff8a2396d 100644 --- a/tests/function_source_test.rs +++ b/tests/function_source_test.rs @@ -1,7 +1,9 @@ use ctor::ctor; +use itertools::Itertools; use log::info; use martin::pg::function_source::get_function_sources; use martin::source::Xyz; +use martin::utils::Schemas; #[path = "utils.rs"] mod utils; @@ -58,3 +60,15 @@ async fn function_source_tile() { assert!(!tile.is_empty()); } + +#[actix_rt::test] +async fn function_source_schemas() { + let mut cfg = mock_empty_config().await; + cfg.auto_functions = Some(Schemas::List(vec!["MixedCase".to_owned()])); + cfg.auto_tables = Some(Schemas::Bool(false)); + let sources = mock_sources(cfg).await.0; + assert_eq!( + sources.keys().sorted().collect::>(), + vec!["function_Mixed_Name"], + ); +} diff --git a/tests/server_test.rs b/tests/server_test.rs index f7a59ff3e..77764332b 100644 --- a/tests/server_test.rs +++ b/tests/server_test.rs @@ -17,7 +17,7 @@ fn init() { macro_rules! 
create_app { ($sources:expr) => {{ - let sources = $sources.await.0; + let sources = mock_sources($sources.await).await.0; let state = crate::utils::mock_app_data(sources).await; ::actix_web::test::init_service( ::actix_web::App::new() @@ -34,7 +34,7 @@ fn test_get(path: &str) -> Request { #[actix_rt::test] async fn get_catalog_ok() { - let app = create_app!(mock_unconfigured()); + let app = create_app!(mock_empty_config()); let req = test_get("/catalog"); let response = call_service(&app, req).await; @@ -62,7 +62,7 @@ async fn get_table_source_ok() { srid: 3857, ..table }; - let app = create_app!(mock_sources( + let app = create_app!(mock_config( None, Some(vec![("table_source", table_source), ("bad_srid", bad_srid)]), None @@ -119,7 +119,7 @@ async fn get_table_source_multiple_geom_tile_ok() { async fn get_table_source_tile_minmax_zoom_ok() { let mut tables = mock_table_config_map(); - let app = create_app!(mock_sources( + let cfg = mock_config( None, Some(vec![ ( @@ -146,8 +146,9 @@ async fn get_table_source_tile_minmax_zoom_ok() { }, ), ]), - None - )); + None, + ); + let app = create_app!(cfg); // zoom = 0 (nothing) let req = test_get("/points1/0/0/0"); @@ -212,7 +213,7 @@ async fn get_table_source_tile_minmax_zoom_ok() { #[actix_rt::test] async fn get_function_tiles() { - let app = create_app!(mock_unconfigured()); + let app = create_app!(mock_empty_config()); let req = test_get("/function_zoom_xy/6/38/20"); assert!(call_service(&app, req).await.status().is_success()); @@ -277,7 +278,7 @@ async fn get_composite_source_tile_minmax_zoom_ok() { ..tables.remove("points2").unwrap() }; let tables = vec![("points1", points1), ("points2", points2)]; - let app = create_app!(mock_sources(None, Some(tables), None)); + let app = create_app!(mock_config(None, Some(tables), None)); // zoom = 0 (nothing) let req = test_get("/points1,points2/0/0/0"); @@ -317,7 +318,7 @@ async fn get_composite_source_tile_minmax_zoom_ok() { #[actix_rt::test] async fn get_function_source_ok() 
{ - let app = create_app!(mock_unconfigured()); + let app = create_app!(mock_empty_config()); let req = test_get("/non_existent"); let response = call_service(&app, req).await; @@ -364,7 +365,7 @@ async fn get_function_source_ok() { #[actix_rt::test] async fn get_function_source_tile_ok() { - let app = create_app!(mock_unconfigured()); + let app = create_app!(mock_empty_config()); let req = test_get("/function_zxy_query/0/0/0"); let response = call_service(&app, req).await; @@ -386,7 +387,7 @@ async fn get_function_source_tile_minmax_zoom_ok() { ("function_source1", function_source1), ("function_source2", function_source2), ]; - let app = create_app!(mock_sources(Some(funcs), None, None)); + let app = create_app!(mock_config(Some(funcs), None, None)); // zoom = 0 (function_source1) let req = test_get("/function_source1/0/0/0"); @@ -431,7 +432,7 @@ async fn get_function_source_tile_minmax_zoom_ok() { #[actix_rt::test] async fn get_function_source_query_params_ok() { - let app = create_app!(mock_unconfigured()); + let app = create_app!(mock_empty_config()); let req = test_get("/function_zxy_query_test/0/0/0"); let response = call_service(&app, req).await; @@ -445,7 +446,7 @@ async fn get_function_source_query_params_ok() { #[actix_rt::test] async fn get_health_returns_ok() { - let app = create_app!(mock_unconfigured()); + let app = create_app!(mock_empty_config()); let req = test_get("/health"); let response = call_service(&app, req).await; @@ -485,7 +486,7 @@ async fn tables_feature_id() { ("id_and_prop", id_and_prop), ("prop_only", prop_only), ]; - let mock = mock_sources(None, Some(tables.clone()), None).await; + let mock = mock_sources(mock_config(None, Some(tables.clone()), None).await).await; let src = table(&mock, "no_id"); assert_eq!(src.id_column, None); @@ -510,7 +511,7 @@ async fn tables_feature_id() { // -------------------------------------------- - let app = create_app!(mock_sources(None, Some(tables.clone()), None)); + let app = 
create_app!(mock_config(None, Some(tables.clone()), None)); for (name, _) in tables.iter() { let req = test_get(format!("/{name}/0/0/0").as_str()); let response = call_service(&app, req).await; diff --git a/tests/table_source_test.rs b/tests/table_source_test.rs index d596eef23..ff3a80da3 100644 --- a/tests/table_source_test.rs +++ b/tests/table_source_test.rs @@ -1,6 +1,8 @@ use ctor::ctor; +use itertools::Itertools; use log::info; use martin::source::Xyz; +use martin::utils::Schemas; use std::collections::HashMap; #[path = "utils.rs"] @@ -92,3 +94,15 @@ async fn tables_multiple_geom_ok() { let source = table(&mock, "table_source_multiple_geom.1"); assert_eq!(source.geometry_column, "geom2"); } + +#[actix_rt::test] +async fn table_source_schemas() { + let mut cfg = mock_empty_config().await; + cfg.auto_functions = Some(Schemas::Bool(false)); + cfg.auto_tables = Some(Schemas::List(vec!["MixedCase".to_owned()])); + let sources = mock_sources(cfg).await.0; + assert_eq!( + sources.keys().sorted().collect::>(), + vec!["MixPoints"], + ); +} diff --git a/tests/utils.rs b/tests/utils.rs index 73e46ce21..674e7bec4 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -2,8 +2,7 @@ use actix_web::web::Data; use log::info; -use martin::pg::config::{FunctionInfo, PgConfigBuilder}; -use martin::pg::config::{PgConfig, TableInfo}; +use martin::pg::config::{FunctionInfo, PgConfig, TableInfo}; use martin::pg::configurator::resolve_pg_data; use martin::pg::pool::Pool; use martin::source::{IdResolver, Source}; @@ -27,14 +26,9 @@ pub async fn mock_config( ) -> PgConfig { let connection_string: String = env::var("DATABASE_URL").unwrap(); info!("Connecting to {connection_string}"); - let config = PgConfigBuilder { + let config = PgConfig { connection_string: Some(connection_string), - #[cfg(feature = "ssl")] - ca_root_file: None, - #[cfg(feature = "ssl")] - danger_accept_invalid_certs: None, default_srid, - pool_size: None, tables: tables.map(|s| { s.iter() .map(|v| (v.0.to_string(), 
v.1.clone())) @@ -45,24 +39,25 @@ pub async fn mock_config( .map(|v| (v.0.to_string(), v.1.clone())) .collect::>() }), + ..Default::default() }; config.finalize().expect("Unable to finalize config") } +#[allow(dead_code)] +pub async fn mock_empty_config() -> PgConfig { + mock_config(None, None, None).await +} + #[allow(dead_code)] pub async fn mock_pool() -> Pool { - let res = Pool::new(&mock_config(None, None, None).await).await; + let res = Pool::new(&mock_empty_config().await).await; res.expect("Failed to create pool") } #[allow(dead_code)] -pub async fn mock_sources( - functions: Option>, - tables: Option>, - default_srid: Option, -) -> MockSource { - let cfg = mock_config(functions, tables, default_srid).await; - let res = resolve_pg_data(cfg, IdResolver::default()).await; +pub async fn mock_sources(config: PgConfig) -> MockSource { + let res = resolve_pg_data(config, IdResolver::default()).await; let res = res.expect("Failed to resolve pg data"); (res.0, res.1) } @@ -74,27 +69,22 @@ pub async fn mock_app_data(sources: Sources) -> Data { #[allow(dead_code)] pub async fn mock_unconfigured() -> MockSource { - mock_sources(None, None, None).await + mock_sources(mock_empty_config().await).await } #[allow(dead_code)] pub async fn mock_unconfigured_srid(default_srid: Option) -> MockSource { - mock_sources(None, None, default_srid).await -} - -#[allow(dead_code)] -pub async fn mock_configured() -> MockSource { - mock_sources(mock_func_config(), mock_table_config(), None).await + mock_sources(mock_config(None, None, default_srid).await).await } #[allow(dead_code)] pub async fn mock_configured_funcs() -> MockSource { - mock_sources(mock_func_config(), None, None).await + mock_sources(mock_config(mock_func_config(), None, None).await).await } #[allow(dead_code)] -pub async fn mock_configured_tables(default_srid: Option) -> MockSource { - mock_sources(None, mock_table_config(), default_srid).await +pub async fn mock_configured_tables(default_srid: Option) -> PgConfig { + 
mock_config(None, mock_table_config(), default_srid).await } pub fn mock_func_config() -> Option> { @@ -298,7 +288,7 @@ pub fn props(props: &[(&'static str, &'static str)]) -> HashMap #[allow(dead_code)] pub fn table<'a>(mock: &'a MockSource, name: &str) -> &'a TableInfo { let (_, PgConfig { tables, .. }) = mock; - tables.get(name).unwrap() + tables.as_ref().map(|v| v.get(name).unwrap()).unwrap() } #[allow(dead_code)]