From b6de417c9481ca859aa5b06258955d08de007b5b Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Mon, 30 Sep 2024 19:40:21 -0400 Subject: [PATCH] Use `serde-untagged` to improve some untagged enum error messages (#7822) ## Summary This is related to https://github.com/astral-sh/uv/issues/7817, but doesn't close it. --- Cargo.lock | 34 ++++++++++++++- Cargo.toml | 1 + crates/pypi-types/Cargo.toml | 1 + crates/pypi-types/src/simple_json.rs | 43 ++++++++++++------- crates/uv-configuration/Cargo.toml | 1 + crates/uv-configuration/src/trusted_host.rs | 40 ++++++++--------- .../src/metadata/requires_dist.rs | 29 +++++++++++-- crates/uv-workspace/Cargo.toml | 1 + crates/uv-workspace/src/pyproject.rs | 17 ++++++-- 9 files changed, 123 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cc78efaecc45..2a1e4f329d56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1077,6 +1077,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "erased-serde" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24e2389d65ab4fab27dc2a5de7b191e1f6617d1f1c8855c0dc569c94a4cbb18d" +dependencies = [ + "serde", + "typeid", +] + [[package]] name = "errno" version = "0.3.9" @@ -2717,7 +2727,7 @@ dependencies = [ "indoc", "libc", "memoffset 0.9.1", - "parking_lot 0.11.2", + "parking_lot 0.12.3", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -2796,6 +2806,7 @@ dependencies = [ "regex", "rkyv", "serde", + "serde-untagged", "thiserror", "toml", "toml_edit", @@ -3518,6 +3529,17 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-untagged" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2676ba99bd82f75cae5cbd2c8eda6fa0b8760f18978ea840e980dd5567b5c5b6" +dependencies = [ + "erased-serde", + "serde", + "typeid", +] + [[package]] name = "serde_derive" version = "1.0.210" @@ -4256,6 +4278,12 @@ version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0609f771ad9c6155384897e1df4d948e692667cc0588548b68eb44d052b27633" +[[package]] +name = "typeid" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e" + [[package]] name = "typenum" version = "1.17.0" @@ -4698,6 +4726,7 @@ dependencies = [ "rustc-hash", "schemars", "serde", + "serde-untagged", "serde_json", "thiserror", "tracing", @@ -5329,6 +5358,7 @@ dependencies = [ "same-file", "schemars", "serde", + "serde-untagged", "tempfile", "thiserror", "tokio", @@ -5551,7 +5581,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index bc942220fe5f..d2a1e54b054b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -141,6 +141,7 @@ same-file = { version = "1.0.6" } schemars = { version = "0.8.21", features = ["url"] } seahash = { version = "4.1.0" } serde = { version = "1.0.210", features = ["derive"] } +serde-untagged = { version = "0.1.6" } serde_json = { version = "1.0.128" } sha2 = { version = "0.10.8" } smallvec = { version = "1.13.2" } diff --git a/crates/pypi-types/Cargo.toml b/crates/pypi-types/Cargo.toml index 3f27358d2320..f23b50bfc053 100644 --- a/crates/pypi-types/Cargo.toml +++ b/crates/pypi-types/Cargo.toml @@ -27,6 +27,7 @@ mailparse = { workspace = true } regex = { workspace = true } rkyv = { workspace = true } serde = { workspace = true } +serde-untagged = { workspace = true } thiserror = { workspace = true } toml = { workspace = true } toml_edit = { workspace = true } diff --git a/crates/pypi-types/src/simple_json.rs b/crates/pypi-types/src/simple_json.rs index d32f1ac64ad4..134736c438e8 100644 --- a/crates/pypi-types/src/simple_json.rs +++ b/crates/pypi-types/src/simple_json.rs @@ -1,9 +1,8 @@ use std::str::FromStr; use jiff::Timestamp; -use serde::{Deserialize, Deserializer, Serialize}; - use pep440_rs::{VersionSpecifiers, VersionSpecifiersParseError}; +use serde::{Deserialize, Deserializer, Serialize}; use crate::lenient_requirement::LenientVersionSpecifiers; @@ -71,13 +70,24 @@ where )) } -#[derive(Debug, Clone, Deserialize)] -#[serde(untagged)] +#[derive(Debug, Clone)] pub enum CoreMetadata { Bool(bool), Hashes(Hashes), } +impl<'de> Deserialize<'de> for CoreMetadata { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + serde_untagged::UntaggedEnumVisitor::new() + .bool(|bool| Ok(CoreMetadata::Bool(bool))) + .map(|map| map.deserialize().map(CoreMetadata::Hashes)) + .deserialize(deserializer) + } +} + impl CoreMetadata { pub fn is_available(&self) -> bool { match self { @@ -87,24 +97,25 @@ impl CoreMetadata { } } -#[derive( - Debug, - Clone, - PartialEq, - Eq, - Hash, - Deserialize, - rkyv::Archive, - rkyv::Deserialize, - rkyv::Serialize, -)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)] #[rkyv(derive(Debug))] -#[serde(untagged)] pub enum Yanked { Bool(bool), Reason(String), } +impl<'de> Deserialize<'de> for Yanked { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + serde_untagged::UntaggedEnumVisitor::new() + .bool(|bool| Ok(Yanked::Bool(bool))) + .string(|string| Ok(Yanked::Reason(string.to_owned()))) + .deserialize(deserializer) + } +} + impl Yanked { pub fn is_yanked(&self) -> bool { match self { diff --git a/crates/uv-configuration/Cargo.toml b/crates/uv-configuration/Cargo.toml index 2d91036620c1..2f01a615da25 100644 --- a/crates/uv-configuration/Cargo.toml +++ b/crates/uv-configuration/Cargo.toml @@ -28,6 +28,7 @@ fs-err = { workspace = true } rustc-hash = { workspace = true } schemars = { workspace = true, optional = true } serde = { workspace = true } +serde-untagged = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } diff --git a/crates/uv-configuration/src/trusted_host.rs b/crates/uv-configuration/src/trusted_host.rs index ea4f529caba4..4fa1bce6788a 100644 --- a/crates/uv-configuration/src/trusted_host.rs +++ b/crates/uv-configuration/src/trusted_host.rs @@ -1,5 +1,5 @@ +use serde::{Deserialize, Deserializer}; use std::str::FromStr; - use url::Url; /// A trusted host, which could be a host or a host-port pair. @@ -33,28 +33,28 @@ impl TrustedHost { } } -#[derive(serde::Deserialize)] -#[serde(untagged)] -enum TrustHostWire { - String(String), - Struct { - scheme: Option, - host: String, - port: Option, - }, -} - -impl<'de> serde::de::Deserialize<'de> for TrustedHost { - fn deserialize(deserializer: D) -> Result +impl<'de> Deserialize<'de> for TrustedHost { + fn deserialize(deserializer: D) -> Result where - D: serde::de::Deserializer<'de>, + D: Deserializer<'de>, { - let helper = TrustHostWire::deserialize(deserializer)?; - - match helper { - TrustHostWire::String(s) => TrustedHost::from_str(&s).map_err(serde::de::Error::custom), - TrustHostWire::Struct { scheme, host, port } => Ok(TrustedHost { scheme, host, port }), + #[derive(Deserialize)] + struct Inner { + scheme: Option, + host: String, + port: Option, } + + serde_untagged::UntaggedEnumVisitor::new() + .string(|string| TrustedHost::from_str(string).map_err(serde::de::Error::custom)) + .map(|map| { + map.deserialize::().map(|inner| TrustedHost { + scheme: inner.scheme, + host: inner.host, + port: inner.port, + }) + }) + .deserialize(deserializer) } } diff --git a/crates/uv-distribution/src/metadata/requires_dist.rs b/crates/uv-distribution/src/metadata/requires_dist.rs index 928e99ab8e8e..f47c058eb3b1 100644 --- a/crates/uv-distribution/src/metadata/requires_dist.rs +++ b/crates/uv-distribution/src/metadata/requires_dist.rs @@ -226,6 +226,29 @@ mod test { "###); } + #[tokio::test] + async fn wrong_type() { + let input = indoc! {r#" + [project] + name = "foo" + version = "0.0.0" + dependencies = [ + "tqdm", + ] + [tool.uv.sources] + tqdm = true + "#}; + + assert_snapshot!(format_err(input).await, @r###" + error: TOML parse error at line 8, column 8 + | + 8 | tqdm = true + | ^^^^ + invalid type: boolean `true`, expected an array or map + + "###); + } + #[tokio::test] async fn too_many_git_specs() { let input = indoc! {r#" @@ -264,7 +287,7 @@ mod test { | 8 | tqdm = { git = "https://github.com/tqdm/tqdm", ref = "baaaaaab" } | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - data did not match any variant of untagged enum SourcesWire + data did not match any variant of untagged enum Source "###); } @@ -288,7 +311,7 @@ mod test { | 8 | tqdm = { path = "tqdm", index = "torch" } | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - data did not match any variant of untagged enum SourcesWire + data did not match any variant of untagged enum Source "###); } @@ -348,7 +371,7 @@ mod test { | 8 | tqdm = { url = "§invalid#+#*Ä" } | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ - data did not match any variant of untagged enum SourcesWire + data did not match any variant of untagged enum Source "###); } diff --git a/crates/uv-workspace/Cargo.toml b/crates/uv-workspace/Cargo.toml index b04eb185cd9a..ba6c1d0a47e2 100644 --- a/crates/uv-workspace/Cargo.toml +++ b/crates/uv-workspace/Cargo.toml @@ -32,6 +32,7 @@ rustc-hash = { workspace = true } same-file = { workspace = true } schemars = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } +serde-untagged = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } toml = { workspace = true } diff --git a/crates/uv-workspace/src/pyproject.rs b/crates/uv-workspace/src/pyproject.rs index f873b6451345..57582d93f43d 100644 --- a/crates/uv-workspace/src/pyproject.rs +++ b/crates/uv-workspace/src/pyproject.rs @@ -444,15 +444,26 @@ impl IntoIterator for Sources { } } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -#[serde(rename_all = "kebab-case", untagged)] -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema), schemars(untagged))] #[allow(clippy::large_enum_variant)] enum SourcesWire { One(Source), Many(Vec), } +impl<'de> serde::de::Deserialize<'de> for SourcesWire { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + serde_untagged::UntaggedEnumVisitor::new() + .map(|map| map.deserialize().map(SourcesWire::One)) + .seq(|seq| seq.deserialize().map(SourcesWire::Many)) + .deserialize(deserializer) + } +} + impl TryFrom for Sources { type Error = SourceError;