diff --git a/src/config.rs b/src/config.rs index 215fd7e..bbd3039 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,7 @@ -use std::{collections::HashSet, path::PathBuf}; +use std::{ + collections::HashSet, + path::{Path, PathBuf}, +}; use anyhow::{Error, Result}; use serde::Deserialize; @@ -90,6 +93,7 @@ impl TryFrom for Config { bc.auth.insert(s); } Element::Fork => bc.fork = true, + Element::Include(_) => { /* no-op */ } Element::KeepUmask => bc.keep_umask = true, Element::Limit => { warn!("warning: busd does not implement ``"); @@ -142,6 +146,11 @@ impl Config { let doc: Document = quick_xml::de::from_str(s)?; Self::try_from(doc) } + + pub fn read_file(file_path: impl AsRef) -> Result { + let doc = Document::read_file(file_path)?; + Self::try_from(doc) + } } #[derive(Clone, Debug, Deserialize, PartialEq)] diff --git a/src/config/xml.rs b/src/config/xml.rs index 83f62e1..ddd25cb 100644 --- a/src/config/xml.rs +++ b/src/config/xml.rs @@ -1,6 +1,13 @@ -use std::path::PathBuf; +use std::{ + env::current_dir, + fs::read_to_string, + path::{Path, PathBuf}, + str::FromStr, +}; +use anyhow::{Error, Result}; use serde::Deserialize; +use tracing::{error, warn}; use super::{BusType, MessageType}; @@ -14,6 +21,92 @@ use super::{BusType, MessageType}; pub struct Document { #[serde(rename = "$value", default)] pub busconfig: Vec, + file_path: Option, +} +impl FromStr for Document { + type Err = Error; + + fn from_str(s: &str) -> Result { + quick_xml::de::from_str(s).map_err(Error::msg) + } +} +impl Document { + pub fn read_file(file_path: impl AsRef) -> Result { + let text = read_to_string(file_path.as_ref())?; + let mut doc = Document::from_str(&text)?; + doc.file_path = Some(file_path.as_ref().to_path_buf()); + doc.resolve_includes() + } + + fn resolve_includes(self) -> Result { + // TODO: implement protection against circular `` references + let base_path = self.base_path()?; + let Document { + busconfig, + file_path, + } = self; + + let mut doc = Document { + busconfig: vec![], + file_path: None, + }; + + for el in busconfig { + match el { + Element::Include(include) => { + let ignore_missing = include.ignore_missing == IncludeOption::Yes; + let file_path = match resolve_include_path(&base_path, &include.file_path) { + Ok(ok) => ok, + Err(err) => { + let msg = format!( + "'{}' should be a valid file path", + include.file_path.display() + ); + if ignore_missing { + warn!(msg); + continue; + } + error!(msg); + return Err(err); + } + }; + let mut included = match Document::read_file(&file_path) { + Ok(ok) => ok, + Err(err) => { + let msg = format!( + "'{}' should contain valid XML", + include.file_path.display() + ); + if ignore_missing { + warn!(msg); + continue; + } + error!(msg); + return Err(err); + } + }; + doc.busconfig.append(&mut included.busconfig); + } + _ => doc.busconfig.push(el), + } + } + + doc.file_path = file_path; + Ok(doc) + } + + fn base_path(&self) -> Result { + match &self.file_path { + Some(some) => Ok(some + .parent() + .ok_or_else(|| Error::msg("`` path should contain a file name"))? + .to_path_buf()), + None => { + warn!("ad-hoc document with unknown file path, using current working directory"); + current_dir().map_err(Error::msg) + } + } + } } #[derive(Clone, Debug, Deserialize, PartialEq)] @@ -22,9 +115,9 @@ pub enum Element { AllowAnonymous, Auth(String), Fork, + Include(IncludeElement), + // TODO: support `` KeepUmask, - // TODO: support `` TODO: support `` Listen(String), Limit, Pidfile(PathBuf), @@ -43,6 +136,29 @@ pub enum Element { User(String), } +#[derive(Clone, Debug, Default, Deserialize, PartialEq)] +pub struct IncludeElement { + #[serde(default, rename = "@ignore_missing")] + ignore_missing: IncludeOption, + + // TODO: implement SELinux + #[serde(default, rename = "@if_selinux_enabled")] + if_selinux_enable: IncludeOption, + #[serde(default, rename = "@selinux_root_relative")] + selinux_root_relative: IncludeOption, + + #[serde(rename = "$value")] + file_path: PathBuf, +} + +#[derive(Clone, Debug, Default, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum IncludeOption { + #[default] + No, + Yes, +} + #[derive(Clone, Debug, Deserialize, PartialEq)] #[serde(rename_all = "snake_case")] pub enum PolicyContext { @@ -137,3 +253,19 @@ pub struct TypeElement { #[serde(rename = "$text")] pub r#type: BusType, } + +fn resolve_include_path( + base_path: impl AsRef, + include_path: impl AsRef, +) -> Result { + let p = include_path.as_ref(); + if p.is_absolute() { + return p.canonicalize().map_err(Error::msg); + } + + base_path + .as_ref() + .join(p) + .canonicalize() + .map_err(Error::msg) +} diff --git a/tests/config.rs b/tests/config.rs new file mode 100644 index 0000000..4fccf41 --- /dev/null +++ b/tests/config.rs @@ -0,0 +1,51 @@ +use std::collections::HashSet; + +use busd::config::{Access, Config, Name, Operation, OwnOperation, Policy}; + +#[test] +fn config_read_file_with_includes_ok() { + let got = Config::read_file("./tests/fixture.conf") + .expect("should read and parse ./tests/fixture.conf"); + + assert_eq!( + got, + Config { + auth: HashSet::from_iter(vec![String::from("ANONYMOUS"), String::from("EXTERNAL"),]), + listen: HashSet::from_iter(vec![ + String::from("unix:path=/tmp/foo"), + String::from("tcp:host=localhost,port=1234"), + ]), + policies: vec![ + Policy::DefaultContext(vec![ + ( + Access::Allow, + Operation::Own(OwnOperation { + own: Some(Name::Any) + }) + ), + ( + Access::Deny, + Operation::Own(OwnOperation { + own: Some(Name::Any) + }) + ), + ]), + Policy::MandatoryContext(vec![ + ( + Access::Deny, + Operation::Own(OwnOperation { + own: Some(Name::Any) + }) + ), + ( + Access::Allow, + Operation::Own(OwnOperation { + own: Some(Name::Any) + }) + ), + ],), + ], + ..Default::default() + } + ); +} diff --git a/tests/fixture.conf b/tests/fixture.conf new file mode 100644 index 0000000..494b4d0 --- /dev/null +++ b/tests/fixture.conf @@ -0,0 +1,12 @@ + + + ANONYMOUS + unix:path=/tmp/foo + + + + + ./fixture_included.conf + ./fixture_missing.conf + diff --git a/tests/fixture_included.conf b/tests/fixture_included.conf new file mode 100644 index 0000000..89ed48b --- /dev/null +++ b/tests/fixture_included.conf @@ -0,0 +1,10 @@ + + + EXTERNAL + tcp:host=localhost,port=1234 + + + + +