Skip to content

Commit

Permalink
Allow parsing the tool table
Browse files Browse the repository at this point in the history
  • Loading branch information
bschoenmaeckers committed Mar 13, 2024
1 parent ff19f44 commit 7d46cb4
Showing 1 changed file with 60 additions and 8 deletions.
68 changes: 60 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use indexmap::IndexMap;
use pep440_rs::{Version, VersionSpecifiers};
use pep508_rs::Requirement;
use serde::{Deserialize, Serialize};
use serde::de::DeserializeOwned;
use toml::Table;

/// The `[build-system]` section of a pyproject.toml as specified in PEP 517
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
Expand All @@ -18,11 +20,13 @@ pub struct BuildSystem {
/// A pyproject.toml as specified in PEP 517
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub struct PyProjectToml {
pub struct PyProjectToml<T = Table> {
/// Build-related data
pub build_system: Option<BuildSystem>,
/// Project metadata
pub project: Option<Project>,
/// Tool section
pub tool: Option<T>,
}

/// PEP 621 project metadata
Expand Down Expand Up @@ -161,7 +165,7 @@ pub struct Contact {
pub email: Option<String>,
}

impl PyProjectToml {
impl<T: DeserializeOwned> PyProjectToml<T> {
/// Parse `pyproject.toml` content
pub fn new(content: &str) -> Result<Self, toml::de::Error> {
toml::de::from_str(content)
Expand All @@ -174,6 +178,7 @@ mod tests {
use pep440_rs::{Version, VersionSpecifiers};
use pep508_rs::Requirement;
use std::str::FromStr;
use serde::{Deserialize};

#[test]
fn test_parse_pyproject_toml() {
Expand Down Expand Up @@ -228,7 +233,7 @@ spam-gui = "spam:main_gui"
[project.entry-points."spam.magical"]
tomatoes = "spam:main_tomatoes""#;
let project_toml = PyProjectToml::new(source).unwrap();
let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap();
let build_system = &project_toml.build_system.unwrap();
assert_eq!(
build_system.requires,
Expand Down Expand Up @@ -285,7 +290,7 @@ build-backend = "maturin"
name = "spam"
license = "MIT OR BSD-3-Clause"
"#;
let project_toml = PyProjectToml::new(source).unwrap();
let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap();
let project = project_toml.project.as_ref().unwrap();
assert_eq!(
project.license,
Expand All @@ -310,7 +315,7 @@ license-files.paths = [
"setuptools/_vendor/LICENSE.BSD",
]
"#;
let project_toml = PyProjectToml::new(source).unwrap();
let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap();
let project = project_toml.project.as_ref().unwrap();

assert_eq!(
Expand Down Expand Up @@ -345,7 +350,7 @@ license-files.globs = [
"setuptools/_vendor/LICENSE*",
]
"#;
let project_toml = PyProjectToml::new(source).unwrap();
let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap();
let project = project_toml.project.as_ref().unwrap();

assert_eq!(
Expand All @@ -372,7 +377,7 @@ build-backend = "maturin"
[project]
name = "spam"
"#;
let project_toml = PyProjectToml::new(source).unwrap();
let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap();
let project = project_toml.project.as_ref().unwrap();

assert_eq!(
Expand All @@ -396,7 +401,7 @@ build-backend = "maturin"
name = "spam"
readme = {text = "ReadMe!", content-type = "text/plain"}
"#;
let project_toml = PyProjectToml::new(source).unwrap();
let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap();
let project = project_toml.project.as_ref().unwrap();

assert_eq!(
Expand All @@ -408,4 +413,51 @@ readme = {text = "ReadMe!", content-type = "text/plain"}
})
);
}

#[test]
fn test_parse_pyproject_toml_tool_section_as_table() {
let source = r#"[build-system]
requires = ["maturin"]
build-backend = "maturin"
[tool.maturin]
bindings = "pyo3"
"#;
let project_toml: PyProjectToml = PyProjectToml::new(source).unwrap();
let tool = project_toml.tool.as_ref().unwrap();
assert_eq!(
tool.get("maturin").unwrap().get("bindings").unwrap().as_str(),
Some("pyo3")
);
}

#[test]
fn test_parse_pyproject_toml_tool_section() {
#[derive(Deserialize)]
struct Maturin {
bindings: Option<String>,
}

#[derive(Deserialize)]
struct Tools {
maturin: Option<Maturin>,
}

let source = r#"[build-system]
requires = ["maturin"]
build-backend = "maturin"
[tool.maturin]
bindings = "pyo3"
[tool.ruff]
line-length = 120
"#;

let project_toml: PyProjectToml<Tools> = PyProjectToml::new(source).unwrap();
assert_eq!(
project_toml.tool.unwrap().maturin.unwrap().bindings.as_deref(),
Some("pyo3")
);
}
}

0 comments on commit 7d46cb4

Please sign in to comment.