Skip to content

Commit

Permalink
refactor: use DeserializeSeed implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
orhun committed Feb 16, 2024
1 parent 9ecda6c commit ea724bf
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 94 deletions.
44 changes: 11 additions & 33 deletions src/project/manifest/feature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{Activation, PyPiRequirement, SystemRequirements, Target, TargetSelec
use crate::consts;
use crate::project::manifest::channel::{PrioritizedChannel, TomlPrioritizedChannelStrOrMap};
use crate::project::manifest::target::Targets;
use crate::project::manifest::UniquePackageName;
use crate::project::manifest::{deserialize_opt_package_map, deserialize_package_map};
use crate::project::SpecType;
use crate::task::{Task, TaskName};
use crate::utils::spanned::PixiSpanned;
Expand All @@ -11,7 +11,7 @@ use itertools::Either;
use rattler_conda_types::{NamelessMatchSpec, PackageName, Platform};
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize};
use serde_with::{serde_as, DisplayFromStr, PickFirst};
use serde_with::serde_as;
use std::borrow::{Borrow, Cow};
use std::collections::HashMap;
use std::fmt;
Expand Down Expand Up @@ -233,17 +233,14 @@ impl<'de> Deserialize<'de> for Feature {
#[serde(default)]
target: IndexMap<PixiSpanned<TargetSelector>, Target>,

#[serde(default)]
#[serde_as(as = "IndexMap<_, PickFirst<(DisplayFromStr, _)>>")]
dependencies: IndexMap<UniquePackageName, NamelessMatchSpec>,
#[serde(default, deserialize_with = "deserialize_package_map")]
dependencies: IndexMap<PackageName, NamelessMatchSpec>,

#[serde(default)]
#[serde_as(as = "Option<IndexMap<_, PickFirst<(DisplayFromStr, _)>>>")]
host_dependencies: Option<IndexMap<UniquePackageName, NamelessMatchSpec>>,
#[serde(default, deserialize_with = "deserialize_opt_package_map")]
host_dependencies: Option<IndexMap<PackageName, NamelessMatchSpec>>,

#[serde(default)]
#[serde_as(as = "Option<IndexMap<_, PickFirst<(DisplayFromStr, _)>>>")]
build_dependencies: Option<IndexMap<UniquePackageName, NamelessMatchSpec>>,
#[serde(default, deserialize_with = "deserialize_opt_package_map")]
build_dependencies: Option<IndexMap<PackageName, NamelessMatchSpec>>,

#[serde(default)]
pypi_dependencies: Option<IndexMap<rip::types::PackageName, PyPiRequirement>>,
Expand All @@ -258,31 +255,12 @@ impl<'de> Deserialize<'de> for Feature {
}

let inner = FeatureInner::deserialize(deserializer)?;
let mut dependencies = HashMap::from_iter([(
SpecType::Run,
inner
.dependencies
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
)]);
let mut dependencies = HashMap::from_iter([(SpecType::Run, inner.dependencies)]);
if let Some(host_deps) = inner.host_dependencies {
dependencies.insert(
SpecType::Host,
host_deps
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
);
dependencies.insert(SpecType::Host, host_deps);
}
if let Some(build_deps) = inner.build_dependencies {
dependencies.insert(
SpecType::Build,
build_deps
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
);
dependencies.insert(SpecType::Build, build_deps);
}

let default_target = Target {
Expand Down
149 changes: 88 additions & 61 deletions src/project/manifest/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ pub use metadata::ProjectMetadata;
use miette::{miette, Diagnostic, IntoDiagnostic, LabeledSpan, NamedSource};
pub use python::PyPiRequirement;
use rattler_conda_types::{MatchSpec, NamelessMatchSpec, PackageName, Platform, Version};
use serde::de::{DeserializeSeed, MapAccess, Visitor};
use serde::{Deserialize, Deserializer};
use serde_with::{serde_as, DisplayFromStr, PickFirst};
use std::hash::{Hash, Hasher};
use serde_with::serde_as;
use std::fmt;
use std::hash::Hash;
use std::marker::PhantomData;
use std::{
collections::HashMap,
path::{Path, PathBuf},
Expand Down Expand Up @@ -885,45 +888,91 @@ impl ProjectManifest {
}
}

#[derive(Debug, Clone, Eq)]
pub(crate) struct UniquePackageName {
inner: PackageName,
}
struct PackageMap<'a>(&'a IndexMap<PackageName, NamelessMatchSpec>);

impl PartialEq for UniquePackageName {
fn eq(&self, other: &Self) -> bool {
self.as_inner().eq(&other.as_inner())
}
}
impl<'de, 'a> DeserializeSeed<'de> for PackageMap<'a> {
type Value = PackageName;

impl Hash for UniquePackageName {
fn hash<H: Hasher>(&self, state: &mut H) {
self.as_inner().as_normalized().hash(state);
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
let package_name = PackageName::deserialize(deserializer)?;
match self.0.get_key_value(&package_name) {
Some((package_name, _)) => {
Err(serde::de::Error::custom(
format!(
"duplicate dependency: {} (please avoid using capitalized names for the dependencies)", package_name.as_source())
))
}
None => Ok(package_name),
}
}
}

impl UniquePackageName {
pub fn as_inner(&self) -> PackageName {
self.inner.clone()
}
}
struct NamelessMatchSpecWrapper {}

impl<'de> Deserialize<'de> for UniquePackageName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
impl<'de, 'a> DeserializeSeed<'de> for &'a NamelessMatchSpecWrapper {
type Value = NamelessMatchSpec;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'de>,
D: Deserializer<'de>,
{
let package_name = PackageName::deserialize(deserializer)?;
if package_name.as_source() != package_name.as_normalized() {
Err(serde::de::Error::custom(
"invalid dependency: please avoid using capitalized names for the dependencies",
))
} else {
Ok(Self {
inner: package_name,
serde_untagged::UntaggedEnumVisitor::new()
.string(|str| NamelessMatchSpec::from_str(str).map_err(serde::de::Error::custom))
.map(|map| {
NamelessMatchSpec::deserialize(serde::de::value::MapAccessDeserializer::new(map))
})
.expecting("either a map or a string")
.deserialize(deserializer)
}
}

pub(crate) fn deserialize_package_map<'de, D>(
deserializer: D,
) -> Result<IndexMap<PackageName, NamelessMatchSpec>, D::Error>
where
D: Deserializer<'de>,
{
struct PackageMapVisitor(PhantomData<()>);

impl<'de> Visitor<'de> for PackageMapVisitor {
type Value = IndexMap<PackageName, NamelessMatchSpec>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "a map")
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut result = IndexMap::new();
let match_spec = NamelessMatchSpecWrapper {};
while let Some((package_name, match_spec)) = map
.next_entry_seed::<PackageMap, &NamelessMatchSpecWrapper>(
PackageMap(&result),
&match_spec,
)?
{
result.insert(package_name, match_spec);
}

Ok(result)
}
}
let visitor = PackageMapVisitor(PhantomData);
deserializer.deserialize_seq(visitor)
}

pub(crate) fn deserialize_opt_package_map<'de, D>(
deserializer: D,
) -> Result<Option<IndexMap<PackageName, NamelessMatchSpec>>, D::Error>
where
D: Deserializer<'de>,
{
Option::deserialize(deserializer)
}

impl<'de> Deserialize<'de> for ProjectManifest {
Expand All @@ -949,17 +998,14 @@ impl<'de> Deserialize<'de> for ProjectManifest {
//
// #[serde(flatten)]
// default_target: Target,
#[serde(default)]
#[serde_as(as = "IndexMap<_, PickFirst<(DisplayFromStr, _)>>")]
dependencies: IndexMap<UniquePackageName, NamelessMatchSpec>,
#[serde(default, deserialize_with = "deserialize_package_map")]
dependencies: IndexMap<PackageName, NamelessMatchSpec>,

#[serde(default)]
#[serde_as(as = "Option<IndexMap<_, PickFirst<(DisplayFromStr, _)>>>")]
host_dependencies: Option<IndexMap<UniquePackageName, NamelessMatchSpec>>,
#[serde(default, deserialize_with = "deserialize_opt_package_map")]
host_dependencies: Option<IndexMap<PackageName, NamelessMatchSpec>>,

#[serde(default)]
#[serde_as(as = "Option<IndexMap<_, PickFirst<(DisplayFromStr, _)>>>")]
build_dependencies: Option<IndexMap<UniquePackageName, NamelessMatchSpec>>,
#[serde(default, deserialize_with = "deserialize_opt_package_map")]
build_dependencies: Option<IndexMap<PackageName, NamelessMatchSpec>>,

#[serde(default)]
pypi_dependencies: Option<IndexMap<rip::types::PackageName, PyPiRequirement>>,
Expand All @@ -982,31 +1028,12 @@ impl<'de> Deserialize<'de> for ProjectManifest {
}

let toml_manifest = TomlProjectManifest::deserialize(deserializer)?;
let mut dependencies = HashMap::from_iter([(
SpecType::Run,
toml_manifest
.dependencies
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
)]);
let mut dependencies = HashMap::from_iter([(SpecType::Run, toml_manifest.dependencies)]);
if let Some(host_deps) = toml_manifest.host_dependencies {
dependencies.insert(
SpecType::Host,
host_deps
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
);
dependencies.insert(SpecType::Host, host_deps);
}
if let Some(build_deps) = toml_manifest.build_dependencies {
dependencies.insert(
SpecType::Build,
build_deps
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
);
dependencies.insert(SpecType::Build, build_deps);
}

let default_target = Target {
Expand Down

0 comments on commit ea724bf

Please sign in to comment.