diff --git a/src/lib.rs b/src/lib.rs index 897c7a6..d94b27a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,9 @@ use indexmap::IndexMap; use pep440_rs::{Version, VersionSpecifiers}; use pep508_rs::Requirement; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use toml::Table; /// The `[build-system]` section of a pyproject.toml as specified in PEP 517 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] @@ -18,11 +20,13 @@ pub struct BuildSystem { /// A pyproject.toml as specified in PEP 517 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(rename_all = "kebab-case")] -pub struct PyProjectToml { +pub struct PyProjectToml { /// Build-related data pub build_system: Option, /// Project metadata pub project: Option, + /// Tool section + pub tool: Option, } /// PEP 621 project metadata @@ -161,11 +165,32 @@ pub struct Contact { pub email: Option, } -impl PyProjectToml { +impl PyProjectToml { /// Parse `pyproject.toml` content pub fn new(content: &str) -> Result { toml::de::from_str(content) } + + pub fn try_from(value: PyProjectToml) -> Result { + Ok(Self { + build_system: value.build_system, + project: value.project, + tool: value.tool.map(Table::try_into).transpose()?, + }) + } +} + +impl PyProjectToml { + pub fn try_into<'de, T>(self) -> Result, toml::de::Error> + where + T: Deserialize<'de>, + { + Ok(PyProjectToml { + build_system: self.build_system, + project: self.project, + tool: self.tool.map(Table::try_into).transpose()?, + }) + } } #[cfg(test)] @@ -173,6 +198,7 @@ mod tests { use super::{License, LicenseFiles, PyProjectToml, ReadMe}; use pep440_rs::{Version, VersionSpecifiers}; use pep508_rs::Requirement; + use serde::Deserialize; use std::str::FromStr; #[test] @@ -228,7 +254,7 @@ spam-gui = "spam:main_gui" [project.entry-points."spam.magical"] tomatoes = "spam:main_tomatoes""#; - let project_toml = PyProjectToml::new(source).unwrap(); + let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap(); let build_system = &project_toml.build_system.unwrap(); assert_eq!( build_system.requires, @@ -285,7 +311,7 @@ build-backend = "maturin" name = "spam" license = "MIT OR BSD-3-Clause" "#; - let project_toml = PyProjectToml::new(source).unwrap(); + let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap(); let project = project_toml.project.as_ref().unwrap(); assert_eq!( project.license, @@ -310,7 +336,7 @@ license-files.paths = [ "setuptools/_vendor/LICENSE.BSD", ] "#; - let project_toml = PyProjectToml::new(source).unwrap(); + let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap(); let project = project_toml.project.as_ref().unwrap(); assert_eq!( @@ -345,7 +371,7 @@ license-files.globs = [ "setuptools/_vendor/LICENSE*", ] "#; - let project_toml = PyProjectToml::new(source).unwrap(); + let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap(); let project = project_toml.project.as_ref().unwrap(); assert_eq!( @@ -372,7 +398,7 @@ build-backend = "maturin" [project] name = "spam" "#; - let project_toml = PyProjectToml::new(source).unwrap(); + let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap(); let project = project_toml.project.as_ref().unwrap(); assert_eq!( @@ -396,7 +422,7 @@ build-backend = "maturin" name = "spam" readme = {text = "ReadMe!", content-type = "text/plain"} "#; - let project_toml = PyProjectToml::new(source).unwrap(); + let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap(); let project = project_toml.project.as_ref().unwrap(); assert_eq!( @@ -408,4 +434,61 @@ readme = {text = "ReadMe!", content-type = "text/plain"} }) ); } + + #[test] + fn test_parse_pyproject_toml_tool_section_as_table() { + let source = r#"[build-system] +requires = ["maturin"] +build-backend = "maturin" + +[tool.maturin] +bindings = "pyo3" +"#; + let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap(); + let tool = project_toml.tool.as_ref().unwrap(); + assert_eq!( + tool.get("maturin") + .unwrap() + .get("bindings") + .unwrap() + .as_str(), + Some("pyo3") + ); + } + + #[test] + fn test_parse_pyproject_toml_tool_section() { + #[derive(Deserialize)] + struct Maturin { + bindings: Option, + } + + #[derive(Deserialize)] + struct Tools { + maturin: Option, + } + + let source = r#"[build-system] +requires = ["maturin"] +build-backend = "maturin" + +[tool.maturin] +bindings = "pyo3" + +[tool.ruff] +line-length = 120 +"#; + + let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap(); + assert_eq!( + project_toml + .tool + .unwrap() + .maturin + .unwrap() + .bindings + .as_deref(), + Some("pyo3") + ); + } }