Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Allow auto discovery per schema #510

Merged
merged 17 commits into from
Dec 10, 2022
6 changes: 3 additions & 3 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use actix_web::dev::Server;
use clap::Parser;
use log::{error, info, warn};
use martin::config::{read_config, Config, ConfigBuilder};
use martin::pg::config::{PgArgs, PgConfigBuilder};
use martin::pg::config::{PgArgs, PgConfig};
use martin::pg::configurator::resolve_pg_data;
use martin::source::IdResolver;
use martin::srv::config::{SrvArgs, SrvConfigBuilder};
Expand Down Expand Up @@ -49,7 +49,7 @@ impl From<Args> for ConfigBuilder {

ConfigBuilder {
srv: SrvConfigBuilder::from(args.srv),
pg: PgConfigBuilder::from((args.pg, args.connection)),
pg: PgConfig::from((args.pg, args.connection)),
unrecognized: HashMap::new(),
}
}
Expand Down Expand Up @@ -89,7 +89,7 @@ async fn start(args: Args) -> io::Result<Server> {
info!("Saving config to {file_name}, use --config to load it");
File::create(file_name)?.write_all(yaml.as_bytes())?;
}
} else if config.pg.discover_functions || config.pg.discover_tables {
} else if config.pg.run_autodiscovery {
info!("Martin has been configured with automatic settings.");
info!("Use --save-config to save or print Martin configuration.");
}
Expand Down
23 changes: 9 additions & 14 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::io_error;
use crate::pg::config::{PgConfig, PgConfigBuilder};
use crate::pg::config::PgConfig;
use crate::srv::config::{SrvConfig, SrvConfigBuilder};
use log::warn;
use serde::{Deserialize, Serialize};
Expand All @@ -22,7 +22,7 @@ pub struct ConfigBuilder {
#[serde(flatten)]
pub srv: SrvConfigBuilder,
#[serde(flatten)]
pub pg: PgConfigBuilder,
pub pg: PgConfig,
#[serde(flatten)]
pub unrecognized: HashMap<String, Value>,
}
Expand Down Expand Up @@ -124,16 +124,10 @@ mod tests {
worker_processes: 8,
},
pg: PgConfig {
connection_string: "postgres://postgres@localhost:5432/db".to_string(),
#[cfg(feature = "ssl")]
ca_root_file: None,
#[cfg(feature = "ssl")]
danger_accept_invalid_certs: false,
connection_string: Some("postgres://postgres@localhost:5432/db".to_string()),
default_srid: Some(4326),
pool_size: 20,
discover_functions: false,
discover_tables: false,
tables: HashMap::from([(
pool_size: Some(20),
tables: Some(HashMap::from([(
"table_source".to_string(),
TableInfo {
schema: "public".to_string(),
Expand All @@ -150,8 +144,8 @@ mod tests {
properties: HashMap::from([("gid".to_string(), "int4".to_string())]),
..Default::default()
},
)]),
functions: HashMap::from([(
)])),
functions: Some(HashMap::from([(
"function_zxy_query".to_string(),
FunctionInfo::new_extended(
"public".to_string(),
Expand All @@ -160,7 +154,8 @@ mod tests {
30,
Bounds::MAX,
),
)]),
)])),
..Default::default()
},
};
assert_eq!(config, expected);
Expand Down
77 changes: 28 additions & 49 deletions src/pg/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::config::{report_unrecognized_config, set_option};
use crate::pg::utils::create_tilejson;
use crate::utils::InfoMap;
use crate::utils::{InfoMap, Schemas};
use serde::{Deserialize, Serialize};
use serde_yaml::Value;
use std::collections::HashMap;
Expand Down Expand Up @@ -181,50 +181,43 @@ impl PgInfo for FunctionInfo {
pub type TableInfoSources = InfoMap<TableInfo>;
pub type FuncInfoSources = InfoMap<FunctionInfo>;

#[derive(Clone, Debug, Serialize, PartialEq)]
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct PgConfig {
pub connection_string: String,
pub connection_string: Option<String>,
#[cfg(feature = "ssl")]
#[serde(skip_serializing_if = "Option::is_none")]
pub ca_root_file: Option<String>,
#[cfg(feature = "ssl")]
#[serde(skip_serializing_if = "Clone::clone")]
pub danger_accept_invalid_certs: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub default_srid: Option<i32>,
pub pool_size: u32,
#[serde(skip_serializing)]
pub discover_functions: bool,
#[serde(skip_serializing)]
pub discover_tables: bool,
pub tables: TableInfoSources,
pub functions: FuncInfoSources,
}

#[derive(Debug, Default, PartialEq, Deserialize)]
pub struct PgConfigBuilder {
pub connection_string: Option<String>,
#[cfg(feature = "ssl")]
pub ca_root_file: Option<String>,
#[cfg(feature = "ssl")]
pub danger_accept_invalid_certs: Option<bool>,
pub default_srid: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pool_size: Option<u32>,
#[serde(skip)]
pub auto_tables: Option<Schemas>,
#[serde(skip)]
pub auto_functions: Option<Schemas>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tables: Option<TableInfoSources>,
#[serde(skip_serializing_if = "Option::is_none")]
pub functions: Option<FuncInfoSources>,
#[serde(skip)]
pub run_autodiscovery: bool,
}

impl PgConfigBuilder {
impl PgConfig {
pub fn merge(&mut self, other: Self) -> &mut Self {
set_option(&mut self.connection_string, other.connection_string);
#[cfg(feature = "ssl")]
set_option(&mut self.ca_root_file, other.ca_root_file);
#[cfg(feature = "ssl")]
set_option(
&mut self.danger_accept_invalid_certs,
other.danger_accept_invalid_certs,
);
{
set_option(&mut self.ca_root_file, other.ca_root_file);
self.danger_accept_invalid_certs |= other.danger_accept_invalid_certs;
}
set_option(&mut self.default_srid, other.default_srid);
set_option(&mut self.pool_size, other.pool_size);
set_option(&mut self.auto_tables, other.auto_tables);
set_option(&mut self.auto_functions, other.auto_functions);
set_option(&mut self.tables, other.tables);
set_option(&mut self.functions, other.functions);
self
Expand All @@ -249,24 +242,16 @@ impl PgConfigBuilder {
)
})?;
Ok(PgConfig {
connection_string,
#[cfg(feature = "ssl")]
ca_root_file: self.ca_root_file,
#[cfg(feature = "ssl")]
danger_accept_invalid_certs: self.danger_accept_invalid_certs.unwrap_or_default(),
default_srid: self.default_srid,
pool_size: self.pool_size.unwrap_or(POOL_SIZE_DEFAULT),
discover_functions: self.tables.is_none() && self.functions.is_none(),
discover_tables: self.tables.is_none() && self.functions.is_none(),
tables: self.tables.unwrap_or_default(),
functions: self.functions.unwrap_or_default(),
connection_string: Some(connection_string),
run_autodiscovery: self.tables.is_none() && self.functions.is_none(),
..self
})
}
}

impl From<(PgArgs, Option<String>)> for PgConfigBuilder {
impl From<(PgArgs, Option<String>)> for PgConfig {
fn from((args, connection): (PgArgs, Option<String>)) -> Self {
PgConfigBuilder {
PgConfig {
connection_string: connection.or_else(|| {
env::var_os("DATABASE_URL").and_then(|connection| connection.into_string().ok())
}),
Expand All @@ -275,13 +260,8 @@ impl From<(PgArgs, Option<String>)> for PgConfigBuilder {
env::var_os("CA_ROOT_FILE").and_then(|connection| connection.into_string().ok())
}),
#[cfg(feature = "ssl")]
danger_accept_invalid_certs: if args.danger_accept_invalid_certs
|| env::var_os("DANGER_ACCEPT_INVALID_CERTS").is_some()
{
Some(true)
} else {
None
},
danger_accept_invalid_certs: args.danger_accept_invalid_certs
|| env::var_os("DANGER_ACCEPT_INVALID_CERTS").is_some(),
default_srid: args.default_srid.or_else(|| {
env::var_os("DEFAULT_SRID").and_then(|srid| {
srid.into_string()
Expand All @@ -290,8 +270,7 @@ impl From<(PgArgs, Option<String>)> for PgConfigBuilder {
})
}),
pool_size: args.pool_size,
tables: None,
functions: None,
..Default::default()
}
}
}
76 changes: 38 additions & 38 deletions src/pg/configurator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::pg::pool::Pool;
use crate::pg::table_source::{calc_srid, get_table_sources, merge_table_info, table_to_query};
use crate::source::IdResolver;
use crate::srv::server::Sources;
use crate::utils::{find_info, InfoMap};
use crate::utils::{find_info, normalize_key, InfoMap, Schemas};
use futures::future::{join_all, try_join};
use itertools::Itertools;
use log::{debug, error, info, warn};
Expand All @@ -27,8 +27,8 @@ pub async fn resolve_pg_data(
Ok((
tables,
PgConfig {
tables: tbl_info,
functions: func_info,
tables: Some(tbl_info),
functions: Some(func_info),
..config
},
pg.pool,
Expand All @@ -38,8 +38,8 @@ pub async fn resolve_pg_data(
struct PgBuilder {
pool: Pool,
default_srid: Option<i32>,
discover_functions: bool,
discover_tables: bool,
auto_functions: Schemas,
auto_tables: Schemas,
id_resolver: IdResolver,
tables: TableInfoSources,
functions: FuncInfoSources,
Expand All @@ -48,19 +48,20 @@ struct PgBuilder {
impl PgBuilder {
async fn new(config: &PgConfig, id_resolver: IdResolver) -> io::Result<Self> {
let pool = Pool::new(config).await?;
let auto = config.run_autodiscovery;
Ok(Self {
pool,
default_srid: config.default_srid,
discover_functions: config.discover_functions,
discover_tables: config.discover_tables,
auto_functions: config.auto_functions.clone().unwrap_or(Schemas::Bool(auto)),
auto_tables: config.auto_tables.clone().unwrap_or(Schemas::Bool(auto)),
id_resolver,
tables: config.tables.clone(),
functions: config.functions.clone(),
tables: config.tables.clone().unwrap_or_default(),
functions: config.functions.clone().unwrap_or_default(),
})
}

pub async fn instantiate_tables(&self) -> Result<(Sources, TableInfoSources), io::Error> {
let all_tables = get_table_sources(&self.pool).await?;
let mut all_tables = get_table_sources(&self.pool).await?;

// Match configured sources with the discovered ones and add them to the pending list.
let mut used = HashSet::<(&str, &str, &str)>::new();
Expand Down Expand Up @@ -90,20 +91,20 @@ impl PgBuilder {
pending.push(table_to_query(id2, cfg_inf, self.pool.clone()));
}

if self.discover_tables {
// Sort the discovered sources by schema, table and geometry column to ensure a consistent behavior
for (schema, tables) in all_tables.into_iter().sorted_by(by_key) {
for (table, geoms) in tables.into_iter().sorted_by(by_key) {
for (geom, mut src_inf) in geoms.into_iter().sorted_by(by_key) {
if used.contains(&(schema.as_str(), table.as_str(), geom.as_str())) {
continue;
}
let id2 = self.resolve_id(table.clone(), &src_inf);
let Some(srid) = calc_srid(&src_inf.format_id(), &id2, src_inf.srid,0, self.default_srid) else {continue};
src_inf.srid = srid;
info!("Discovered source {id2} from {}", summary(&src_inf));
pending.push(table_to_query(id2, src_inf, self.pool.clone()));
// Sort the discovered sources by schema, table and geometry column to ensure a consistent behavior
for schema in self.auto_tables.get(|| all_tables.keys()) {
let Some(schema2) = normalize_key(&all_tables, &schema, "schema", "") else { continue };
let tables = all_tables.remove(&schema2).unwrap();
for (table, geoms) in tables.into_iter().sorted_by(by_key) {
for (geom, mut src_inf) in geoms.into_iter().sorted_by(by_key) {
if used.contains(&(schema.as_str(), table.as_str(), geom.as_str())) {
continue;
}
let id2 = self.resolve_id(table.clone(), &src_inf);
let Some(srid) = calc_srid(&src_inf.format_id(), &id2, src_inf.srid,0, self.default_srid) else {continue};
src_inf.srid = srid;
info!("Discovered source {id2} from {}", summary(&src_inf));
pending.push(table_to_query(id2, src_inf, self.pool.clone()));
}
}
}
Expand All @@ -129,14 +130,13 @@ impl PgBuilder {
}

pub async fn instantiate_functions(&self) -> Result<(Sources, FuncInfoSources), io::Error> {
let all_functions = get_function_sources(&self.pool).await?;

let mut all_funcs = get_function_sources(&self.pool).await?;
let mut res: Sources = HashMap::new();
let mut info_map = FuncInfoSources::new();
let mut used = HashSet::<(&str, &str)>::new();

for (id, cfg_inf) in &self.functions {
let Some(schemas) = find_info(&all_functions, &cfg_inf.schema, "schema", id) else { continue };
let Some(schemas) = find_info(&all_funcs, &cfg_inf.schema, "schema", id) else { continue };
if schemas.is_empty() {
warn!("No functions found in schema {}. Only functions like (z,x,y) -> bytea and similar are considered. See README.md", cfg_inf.schema);
continue;
Expand All @@ -155,19 +155,19 @@ impl PgBuilder {
info_map.insert(id2, cfg_inf.clone());
}

if self.discover_functions {
// Sort the discovered sources by schema and function name to ensure a consistent behavior
for (schema, funcs) in all_functions.into_iter().sorted_by(by_key) {
for (name, (pg_sql, src_inf)) in funcs.into_iter().sorted_by(by_key) {
if used.contains(&(schema.as_str(), name.as_str())) {
continue;
}
let id2 = self.resolve_id(name.clone(), &src_inf);
self.add_func_src(&mut res, id2.clone(), &src_inf, pg_sql.clone());
info!("Discovered source {id2} from function {}", pg_sql.signature);
debug!("{}", pg_sql.query);
info_map.insert(id2, src_inf);
// Sort the discovered sources by schema and function name to ensure a consistent behavior
for schema in self.auto_functions.get(|| all_funcs.keys()) {
let Some(schema2) = normalize_key(&all_funcs, &schema, "schema", "") else { continue };
let funcs = all_funcs.remove(&schema2).unwrap();
for (name, (pg_sql, src_inf)) in funcs.into_iter().sorted_by(by_key) {
if used.contains(&(schema.as_str(), name.as_str())) {
continue;
}
let id2 = self.resolve_id(name.clone(), &src_inf);
self.add_func_src(&mut res, id2.clone(), &src_inf, pg_sql.clone());
info!("Discovered source {id2} from function {}", pg_sql.signature);
debug!("{}", pg_sql.query);
info_map.insert(id2, src_inf);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/pg/function_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub async fn get_function_sources(pool: &Pool) -> Result<SqlFuncInfoMapMap, io::
assert_eq!(t, &["bytea", "text"]);
}
(None, None) => {}
_ => panic!("Invalid output record names or types"),
_ => panic!("Invalid output record names or types: {output_record_names:?} {output_record_types:?}"),
}
assert!(output_type == "bytea" || output_type == "record");

Expand Down
6 changes: 3 additions & 3 deletions src/pg/pool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::pg::config::PgConfig;
use crate::pg::config::{PgConfig, POOL_SIZE_DEFAULT};
use crate::pg::utils::io_error;
use bb8::PooledConnection;
use bb8_postgres::{tokio_postgres as pg, PostgresConnectionManager};
Expand Down Expand Up @@ -35,7 +35,7 @@ pub struct Pool {

impl Pool {
pub async fn new(config: &PgConfig) -> io::Result<Self> {
let conn_str = config.connection_string.as_str();
let conn_str = config.connection_string.as_ref().unwrap().as_str();
info!("Connecting to {conn_str}");
let pg_cfg = pg::config::Config::from_str(conn_str)
.map_err(|e| io_error!(e, "Can't parse connection string {conn_str}"))?;
Expand Down Expand Up @@ -66,7 +66,7 @@ impl Pool {
};

let pool = InternalPool::builder()
.max_size(config.pool_size)
.max_size(config.pool_size.unwrap_or(POOL_SIZE_DEFAULT))
.build(manager)
.await
.map_err(|e| io_error!(e, "Can't build connection pool"))?;
Expand Down
Loading