Skip to content

Commit

Permalink
refactor: use a custom deserialized type for checking duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
orhun committed Jan 31, 2024
1 parent 870c990 commit c70586c
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 75 deletions.
41 changes: 30 additions & 11 deletions src/project/manifest/feature.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::{Activation, PyPiRequirement, SystemRequirements, Target, TargetSelector};
use crate::consts;
use crate::project::manifest::channel::{PrioritizedChannel, TomlPrioritizedChannelStrOrMap};
use crate::project::manifest::deserialize_dependencies;
use crate::project::manifest::target::Targets;
use crate::project::manifest::UniquePackageName;
use crate::project::SpecType;
use crate::task::Task;
use crate::utils::spanned::PixiSpanned;
Expand Down Expand Up @@ -230,15 +230,15 @@ impl<'de> Deserialize<'de> for Feature {

#[serde(default)]
#[serde_as(as = "IndexMap<_, PickFirst<(DisplayFromStr, _)>>")]
dependencies: IndexMap<String, NamelessMatchSpec>,
dependencies: IndexMap<UniquePackageName, NamelessMatchSpec>,

#[serde(default)]
#[serde_as(as = "Option<IndexMap<_, PickFirst<(DisplayFromStr, _)>>>")]
host_dependencies: Option<IndexMap<String, NamelessMatchSpec>>,
host_dependencies: Option<IndexMap<UniquePackageName, NamelessMatchSpec>>,

#[serde(default)]
#[serde_as(as = "Option<IndexMap<_, PickFirst<(DisplayFromStr, _)>>>")]
build_dependencies: Option<IndexMap<String, NamelessMatchSpec>>,
build_dependencies: Option<IndexMap<UniquePackageName, NamelessMatchSpec>>,

#[serde(default)]
pypi_dependencies: Option<IndexMap<rip::types::PackageName, PyPiRequirement>>,
Expand All @@ -253,13 +253,32 @@ impl<'de> Deserialize<'de> for Feature {
}

let inner = FeatureInner::deserialize(deserializer)?;

let dependencies = deserialize_dependencies(
inner.dependencies,
inner.host_dependencies,
inner.build_dependencies,
)
.map_err(serde::de::Error::custom)?;
let mut dependencies = HashMap::from_iter([(
SpecType::Run,
inner
.dependencies
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
)]);
if let Some(host_deps) = inner.host_dependencies {
dependencies.insert(
SpecType::Host,
host_deps
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
);
}
if let Some(build_deps) = inner.build_dependencies {
dependencies.insert(
SpecType::Build,
build_deps
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
);
}

let default_target = Target {
dependencies,
Expand Down
133 changes: 69 additions & 64 deletions src/project/manifest/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,9 @@ use itertools::Itertools;
pub use metadata::ProjectMetadata;
use miette::{miette, Diagnostic, IntoDiagnostic, LabeledSpan, NamedSource};
pub use python::PyPiRequirement;
use rattler_conda_types::{
InvalidPackageNameError, MatchSpec, NamelessMatchSpec, PackageName, Platform, Version,
};
use rattler_conda_types::{MatchSpec, NamelessMatchSpec, PackageName, Platform, Version};
use serde_with::{serde_as, DisplayFromStr, Map, PickFirst};
use std::collections::HashSet;
use std::hash::Hash;
use std::hash::{Hash, Hasher};
use std::{
collections::HashMap,
path::{Path, PathBuf},
Expand Down Expand Up @@ -745,13 +742,45 @@ impl ProjectManifest {
}
}

#[derive(Debug, Error, Diagnostic)]
pub(crate) enum ManifestDeserializeError {
#[error("duplicate dependency: {0}, please avoid using capitalized names for the dependencies as they are read as lowercase as well.")]
DuplicateDependency(String),
#[derive(Debug, Clone, Eq)]
pub(crate) struct UniquePackageName {
inner: PackageName,
}

impl PartialEq for UniquePackageName {
fn eq(&self, other: &Self) -> bool {
self.as_inner().eq(&other.as_inner())
}
}

#[error("invalid package name: `{0}`")]
InvalidPackageName(#[from] InvalidPackageNameError),
impl Hash for UniquePackageName {
fn hash<H: Hasher>(&self, state: &mut H) {
self.as_inner().as_normalized().hash(state);
}
}

impl UniquePackageName {
pub fn as_inner(&self) -> PackageName {
self.inner.clone()
}
}

impl<'de> Deserialize<'de> for UniquePackageName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let package_name = PackageName::deserialize(deserializer)?;
if package_name.as_source() != package_name.as_normalized() {
Err(serde::de::Error::custom(
"invalid dependency: please avoid using capitalized names for the dependencies",
))
} else {
Ok(Self {
inner: package_name,
})
}
}
}

impl<'de> Deserialize<'de> for ProjectManifest {
Expand Down Expand Up @@ -779,15 +808,15 @@ impl<'de> Deserialize<'de> for ProjectManifest {
// default_target: Target,
#[serde(default)]
#[serde_as(as = "IndexMap<_, PickFirst<(DisplayFromStr, _)>>")]
dependencies: IndexMap<String, NamelessMatchSpec>,
dependencies: IndexMap<UniquePackageName, NamelessMatchSpec>,

#[serde(default)]
#[serde_as(as = "Option<IndexMap<_, PickFirst<(DisplayFromStr, _)>>>")]
host_dependencies: Option<IndexMap<String, NamelessMatchSpec>>,
host_dependencies: Option<IndexMap<UniquePackageName, NamelessMatchSpec>>,

#[serde(default)]
#[serde_as(as = "Option<IndexMap<_, PickFirst<(DisplayFromStr, _)>>>")]
build_dependencies: Option<IndexMap<String, NamelessMatchSpec>>,
build_dependencies: Option<IndexMap<UniquePackageName, NamelessMatchSpec>>,

#[serde(default)]
pypi_dependencies: Option<IndexMap<rip::types::PackageName, PyPiRequirement>>,
Expand All @@ -811,13 +840,32 @@ impl<'de> Deserialize<'de> for ProjectManifest {
}

let toml_manifest = TomlProjectManifest::deserialize(deserializer)?;

let dependencies = deserialize_dependencies(
toml_manifest.dependencies,
toml_manifest.host_dependencies,
toml_manifest.build_dependencies,
)
.map_err(serde::de::Error::custom)?;
let mut dependencies = HashMap::from_iter([(
SpecType::Run,
toml_manifest
.dependencies
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
)]);
if let Some(host_deps) = toml_manifest.host_dependencies {
dependencies.insert(
SpecType::Host,
host_deps
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
);
}
if let Some(build_deps) = toml_manifest.build_dependencies {
dependencies.insert(
SpecType::Build,
build_deps
.into_iter()
.map(|(p, s)| (p.as_inner(), s))
.collect(),
);
}

let default_target = Target {
dependencies,
Expand Down Expand Up @@ -883,49 +931,6 @@ impl<'de> Deserialize<'de> for ProjectManifest {
}
}

/// Deserializes dependencies into a structured HashMap based on SpecType.
///
/// This function also checks for duplicate package names.
pub(crate) fn deserialize_dependencies(
run_dependencies: IndexMap<String, NamelessMatchSpec>,
host_dependencies: Option<IndexMap<String, NamelessMatchSpec>>,
build_dependencies: Option<IndexMap<String, NamelessMatchSpec>>,
) -> Result<HashMap<SpecType, IndexMap<PackageName, NamelessMatchSpec>>, ManifestDeserializeError> {
// map dependencies to a tuple type and handle package name errors
let map_dependencies =
|dependencies: IndexMap<String, NamelessMatchSpec>| -> Result<Vec<(_, _)>, ManifestDeserializeError> {
let mut result_vec = Vec::new();
for (p, s) in dependencies {
let package_name_result =
PackageName::from_str(&p).map_err(ManifestDeserializeError::InvalidPackageName)?;
result_vec.push((package_name_result, s));
}
Ok(result_vec)
};

// check for duplicates
let run_dependencies = map_dependencies(run_dependencies.clone())?;
let host_dependencies = map_dependencies(host_dependencies.clone().unwrap_or_default())?;
let build_dependencies = map_dependencies(build_dependencies.clone().unwrap_or_default())?;
let mut all_dependencies = run_dependencies.clone();
all_dependencies.extend(host_dependencies.clone());
all_dependencies.extend(build_dependencies.clone());
let mut dependency_map = HashSet::new();
for (package_name, _) in all_dependencies.iter() {
if !dependency_map.insert(package_name.as_normalized()) {
return Err(ManifestDeserializeError::DuplicateDependency(
package_name.as_normalized().to_string(),
));
}
}

Ok(HashMap::from_iter([
(SpecType::Run, run_dependencies.into_iter().collect()),
(SpecType::Host, host_dependencies.into_iter().collect()),
(SpecType::Build, build_dependencies.into_iter().collect()),
]))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit c70586c

Please sign in to comment.