Skip to content

Commit

Permalink
Allow auto discovery per schema
Browse files Browse the repository at this point in the history
  • Loading branch information
nyurik committed Dec 9, 2022
1 parent ac02930 commit 45e61d1
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 149 deletions.
6 changes: 3 additions & 3 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use actix_web::dev::Server;
use clap::Parser;
use log::{error, info, warn};
use martin::config::{read_config, Config, ConfigBuilder};
use martin::pg::config::{PgArgs, PgConfigBuilder};
use martin::pg::config::{PgArgs, PgConfig};
use martin::pg::configurator::resolve_pg_data;
use martin::source::IdResolver;
use martin::srv::config::{SrvArgs, SrvConfigBuilder};
Expand Down Expand Up @@ -49,7 +49,7 @@ impl From<Args> for ConfigBuilder {

ConfigBuilder {
srv: SrvConfigBuilder::from(args.srv),
pg: PgConfigBuilder::from((args.pg, args.connection)),
pg: PgConfig::from((args.pg, args.connection)),
unrecognized: HashMap::new(),
}
}
Expand Down Expand Up @@ -89,7 +89,7 @@ async fn start(args: Args) -> io::Result<Server> {
info!("Saving config to {file_name}, use --config to load it");
File::create(file_name)?.write_all(yaml.as_bytes())?;
}
} else if config.pg.discover_functions || config.pg.discover_tables {
} else if config.pg.run_autodiscovery {
info!("Martin has been configured with automatic settings.");
info!("Use --save-config to save or print Martin configuration.");
}
Expand Down
23 changes: 9 additions & 14 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::io_error;
use crate::pg::config::{PgConfig, PgConfigBuilder};
use crate::pg::config::PgConfig;
use crate::srv::config::{SrvConfig, SrvConfigBuilder};
use log::warn;
use serde::{Deserialize, Serialize};
Expand All @@ -22,7 +22,7 @@ pub struct ConfigBuilder {
#[serde(flatten)]
pub srv: SrvConfigBuilder,
#[serde(flatten)]
pub pg: PgConfigBuilder,
pub pg: PgConfig,
#[serde(flatten)]
pub unrecognized: HashMap<String, Value>,
}
Expand Down Expand Up @@ -124,16 +124,10 @@ mod tests {
worker_processes: 8,
},
pg: PgConfig {
connection_string: "postgres://postgres@localhost:5432/db".to_string(),
#[cfg(feature = "ssl")]
ca_root_file: None,
#[cfg(feature = "ssl")]
danger_accept_invalid_certs: false,
connection_string: Some("postgres://postgres@localhost:5432/db".to_string()),
default_srid: Some(4326),
pool_size: 20,
discover_functions: false,
discover_tables: false,
tables: HashMap::from([(
pool_size: Some(20),
tables: Some(HashMap::from([(
"table_source".to_string(),
TableInfo {
schema: "public".to_string(),
Expand All @@ -150,8 +144,8 @@ mod tests {
properties: HashMap::from([("gid".to_string(), "int4".to_string())]),
..Default::default()
},
)]),
functions: HashMap::from([(
)])),
functions: Some(HashMap::from([(
"function_zxy_query".to_string(),
FunctionInfo::new_extended(
"public".to_string(),
Expand All @@ -160,7 +154,8 @@ mod tests {
30,
Bounds::MAX,
),
)]),
)])),
..Default::default()
},
};
assert_eq!(config, expected);
Expand Down
77 changes: 28 additions & 49 deletions src/pg/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::config::{report_unrecognized_config, set_option};
use crate::pg::utils::create_tilejson;
use crate::utils::InfoMap;
use crate::utils::{InfoMap, Schemas};
use serde::{Deserialize, Serialize};
use serde_yaml::Value;
use std::collections::HashMap;
Expand Down Expand Up @@ -181,50 +181,43 @@ impl PgInfo for FunctionInfo {
pub type TableInfoSources = InfoMap<TableInfo>;
pub type FuncInfoSources = InfoMap<FunctionInfo>;

#[derive(Clone, Debug, Serialize, PartialEq)]
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct PgConfig {
pub connection_string: String,
pub connection_string: Option<String>,
#[cfg(feature = "ssl")]
#[serde(skip_serializing_if = "Option::is_none")]
pub ca_root_file: Option<String>,
#[cfg(feature = "ssl")]
#[serde(skip_serializing_if = "Clone::clone")]
pub danger_accept_invalid_certs: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub default_srid: Option<i32>,
pub pool_size: u32,
#[serde(skip_serializing)]
pub discover_functions: bool,
#[serde(skip_serializing)]
pub discover_tables: bool,
pub tables: TableInfoSources,
pub functions: FuncInfoSources,
}

#[derive(Debug, Default, PartialEq, Deserialize)]
pub struct PgConfigBuilder {
pub connection_string: Option<String>,
#[cfg(feature = "ssl")]
pub ca_root_file: Option<String>,
#[cfg(feature = "ssl")]
pub danger_accept_invalid_certs: Option<bool>,
pub default_srid: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pool_size: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub auto_tables: Option<Schemas>,
#[serde(skip_serializing_if = "Option::is_none")]
pub auto_functions: Option<Schemas>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tables: Option<TableInfoSources>,
#[serde(skip_serializing_if = "Option::is_none")]
pub functions: Option<FuncInfoSources>,
#[serde(skip)]
pub run_autodiscovery: bool,
}

impl PgConfigBuilder {
impl PgConfig {
pub fn merge(&mut self, other: Self) -> &mut Self {
set_option(&mut self.connection_string, other.connection_string);
#[cfg(feature = "ssl")]
set_option(&mut self.ca_root_file, other.ca_root_file);
#[cfg(feature = "ssl")]
set_option(
&mut self.danger_accept_invalid_certs,
other.danger_accept_invalid_certs,
);
{
set_option(&mut self.ca_root_file, other.ca_root_file);
self.danger_accept_invalid_certs |= other.danger_accept_invalid_certs;
}
set_option(&mut self.default_srid, other.default_srid);
set_option(&mut self.pool_size, other.pool_size);
set_option(&mut self.auto_tables, other.auto_tables);
set_option(&mut self.auto_functions, other.auto_functions);
set_option(&mut self.tables, other.tables);
set_option(&mut self.functions, other.functions);
self
Expand All @@ -249,24 +242,16 @@ impl PgConfigBuilder {
)
})?;
Ok(PgConfig {
connection_string,
#[cfg(feature = "ssl")]
ca_root_file: self.ca_root_file,
#[cfg(feature = "ssl")]
danger_accept_invalid_certs: self.danger_accept_invalid_certs.unwrap_or_default(),
default_srid: self.default_srid,
pool_size: self.pool_size.unwrap_or(POOL_SIZE_DEFAULT),
discover_functions: self.tables.is_none() && self.functions.is_none(),
discover_tables: self.tables.is_none() && self.functions.is_none(),
tables: self.tables.unwrap_or_default(),
functions: self.functions.unwrap_or_default(),
connection_string: Some(connection_string),
run_autodiscovery: self.tables.is_none() && self.functions.is_none(),
..self
})
}
}

impl From<(PgArgs, Option<String>)> for PgConfigBuilder {
impl From<(PgArgs, Option<String>)> for PgConfig {
fn from((args, connection): (PgArgs, Option<String>)) -> Self {
PgConfigBuilder {
PgConfig {
connection_string: connection.or_else(|| {
env::var_os("DATABASE_URL").and_then(|connection| connection.into_string().ok())
}),
Expand All @@ -275,13 +260,8 @@ impl From<(PgArgs, Option<String>)> for PgConfigBuilder {
env::var_os("CA_ROOT_FILE").and_then(|connection| connection.into_string().ok())
}),
#[cfg(feature = "ssl")]
danger_accept_invalid_certs: if args.danger_accept_invalid_certs
|| env::var_os("DANGER_ACCEPT_INVALID_CERTS").is_some()
{
Some(true)
} else {
None
},
danger_accept_invalid_certs: args.danger_accept_invalid_certs
|| env::var_os("DANGER_ACCEPT_INVALID_CERTS").is_some(),
default_srid: args.default_srid.or_else(|| {
env::var_os("DEFAULT_SRID").and_then(|srid| {
srid.into_string()
Expand All @@ -290,8 +270,7 @@ impl From<(PgArgs, Option<String>)> for PgConfigBuilder {
})
}),
pool_size: args.pool_size,
tables: None,
functions: None,
..Default::default()
}
}
}
76 changes: 38 additions & 38 deletions src/pg/configurator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::pg::pool::Pool;
use crate::pg::table_source::{calc_srid, get_table_sources, merge_table_info, table_to_query};
use crate::source::IdResolver;
use crate::srv::server::Sources;
use crate::utils::{find_info, InfoMap};
use crate::utils::{find_info, normalize_key, InfoMap, Schemas};
use futures::future::{join_all, try_join};
use itertools::Itertools;
use log::{debug, error, info, warn};
Expand All @@ -27,8 +27,8 @@ pub async fn resolve_pg_data(
Ok((
tables,
PgConfig {
tables: tbl_info,
functions: func_info,
tables: Some(tbl_info),
functions: Some(func_info),
..config
},
pg.pool,
Expand All @@ -38,8 +38,8 @@ pub async fn resolve_pg_data(
struct PgBuilder {
pool: Pool,
default_srid: Option<i32>,
discover_functions: bool,
discover_tables: bool,
auto_functions: Schemas,
auto_tables: Schemas,
id_resolver: IdResolver,
tables: TableInfoSources,
functions: FuncInfoSources,
Expand All @@ -48,19 +48,20 @@ struct PgBuilder {
impl PgBuilder {
async fn new(config: &PgConfig, id_resolver: IdResolver) -> io::Result<Self> {
let pool = Pool::new(config).await?;
let auto = config.run_autodiscovery;
Ok(Self {
pool,
default_srid: config.default_srid,
discover_functions: config.discover_functions,
discover_tables: config.discover_tables,
auto_functions: config.auto_functions.clone().unwrap_or(Schemas::Bool(auto)),
auto_tables: config.auto_tables.clone().unwrap_or(Schemas::Bool(auto)),
id_resolver,
tables: config.tables.clone(),
functions: config.functions.clone(),
tables: config.tables.clone().unwrap_or_default(),
functions: config.functions.clone().unwrap_or_default(),
})
}

pub async fn instantiate_tables(&self) -> Result<(Sources, TableInfoSources), io::Error> {
let all_tables = get_table_sources(&self.pool).await?;
let mut all_tables = get_table_sources(&self.pool).await?;

// Match configured sources with the discovered ones and add them to the pending list.
let mut used = HashSet::<(&str, &str, &str)>::new();
Expand Down Expand Up @@ -90,20 +91,20 @@ impl PgBuilder {
pending.push(table_to_query(id2, cfg_inf, self.pool.clone()));
}

if self.discover_tables {
// Sort the discovered sources by schema, table and geometry column to ensure a consistent behavior
for (schema, tables) in all_tables.into_iter().sorted_by(by_key) {
for (table, geoms) in tables.into_iter().sorted_by(by_key) {
for (geom, mut src_inf) in geoms.into_iter().sorted_by(by_key) {
if used.contains(&(schema.as_str(), table.as_str(), geom.as_str())) {
continue;
}
let id2 = self.resolve_id(table.clone(), &src_inf);
let Some(srid) = calc_srid(&src_inf.format_id(), &id2, src_inf.srid,0, self.default_srid) else {continue};
src_inf.srid = srid;
info!("Discovered source {id2} from {}", summary(&src_inf));
pending.push(table_to_query(id2, src_inf, self.pool.clone()));
// Sort the discovered sources by schema, table and geometry column to ensure a consistent behavior
for schema in self.auto_tables.get(|| all_tables.keys()) {
let Some(schema2) = normalize_key(&all_tables, &schema, "schema", "") else { continue };
let tables = all_tables.remove(&schema2).unwrap();
for (table, geoms) in tables.into_iter().sorted_by(by_key) {
for (geom, mut src_inf) in geoms.into_iter().sorted_by(by_key) {
if used.contains(&(schema.as_str(), table.as_str(), geom.as_str())) {
continue;
}
let id2 = self.resolve_id(table.clone(), &src_inf);
let Some(srid) = calc_srid(&src_inf.format_id(), &id2, src_inf.srid,0, self.default_srid) else {continue};
src_inf.srid = srid;
info!("Discovered source {id2} from {}", summary(&src_inf));
pending.push(table_to_query(id2, src_inf, self.pool.clone()));
}
}
}
Expand All @@ -129,14 +130,13 @@ impl PgBuilder {
}

pub async fn instantiate_functions(&self) -> Result<(Sources, FuncInfoSources), io::Error> {
let all_functions = get_function_sources(&self.pool).await?;

let mut all_funcs = get_function_sources(&self.pool).await?;
let mut res: Sources = HashMap::new();
let mut info_map = FuncInfoSources::new();
let mut used = HashSet::<(&str, &str)>::new();

for (id, cfg_inf) in &self.functions {
let Some(schemas) = find_info(&all_functions, &cfg_inf.schema, "schema", id) else { continue };
let Some(schemas) = find_info(&all_funcs, &cfg_inf.schema, "schema", id) else { continue };
if schemas.is_empty() {
warn!("No functions found in schema {}. Only functions like (z,x,y) -> bytea and similar are considered. See README.md", cfg_inf.schema);
continue;
Expand All @@ -155,19 +155,19 @@ impl PgBuilder {
info_map.insert(id2, cfg_inf.clone());
}

if self.discover_functions {
// Sort the discovered sources by schema and function name to ensure a consistent behavior
for (schema, funcs) in all_functions.into_iter().sorted_by(by_key) {
for (name, (pg_sql, src_inf)) in funcs.into_iter().sorted_by(by_key) {
if used.contains(&(schema.as_str(), name.as_str())) {
continue;
}
let id2 = self.resolve_id(name.clone(), &src_inf);
self.add_func_src(&mut res, id2.clone(), &src_inf, pg_sql.clone());
info!("Discovered source {id2} from function {}", pg_sql.signature);
debug!("{}", pg_sql.query);
info_map.insert(id2, src_inf);
// Sort the discovered sources by schema and function name to ensure a consistent behavior
for schema in self.auto_functions.get(|| all_funcs.keys()) {
let Some(schema2) = normalize_key(&all_funcs, &schema, "schema", "") else { continue };
let funcs = all_funcs.remove(&schema2).unwrap();
for (name, (pg_sql, src_inf)) in funcs.into_iter().sorted_by(by_key) {
if used.contains(&(schema.as_str(), name.as_str())) {
continue;
}
let id2 = self.resolve_id(name.clone(), &src_inf);
self.add_func_src(&mut res, id2.clone(), &src_inf, pg_sql.clone());
info!("Discovered source {id2} from function {}", pg_sql.signature);
debug!("{}", pg_sql.query);
info_map.insert(id2, src_inf);
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/pg/pool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::pg::config::PgConfig;
use crate::pg::config::{PgConfig, POOL_SIZE_DEFAULT};
use crate::pg::utils::io_error;
use bb8::PooledConnection;
use bb8_postgres::{tokio_postgres as pg, PostgresConnectionManager};
Expand Down Expand Up @@ -35,7 +35,7 @@ pub struct Pool {

impl Pool {
pub async fn new(config: &PgConfig) -> io::Result<Self> {
let conn_str = config.connection_string.as_str();
let conn_str = config.connection_string.as_ref().unwrap().as_str();
info!("Connecting to {conn_str}");
let pg_cfg = pg::config::Config::from_str(conn_str)
.map_err(|e| io_error!(e, "Can't parse connection string {conn_str}"))?;
Expand Down Expand Up @@ -66,7 +66,7 @@ impl Pool {
};

let pool = InternalPool::builder()
.max_size(config.pool_size)
.max_size(config.pool_size.unwrap_or(POOL_SIZE_DEFAULT))
.build(manager)
.await
.map_err(|e| io_error!(e, "Can't build connection pool"))?;
Expand Down
Loading

0 comments on commit 45e61d1

Please sign in to comment.