Skip to content

Commit

Permalink
Merge pull request #296 from mamba-org/feat/migrate-to-strict-version
Browse files Browse the repository at this point in the history
New `StrictVersion` type for VersionSpec ranges.
  • Loading branch information
tdejager authored Aug 28, 2023
2 parents 58baa21 + cfda176 commit 743ec3b
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 159 deletions.
11 changes: 11 additions & 0 deletions crates/rattler_conda_types/src/match_spec/matcher.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use serde::{Serialize, Serializer};
use std::hash::{Hash, Hasher};
use std::{
fmt::{Display, Formatter},
str::FromStr,
Expand All @@ -19,6 +20,16 @@ pub enum StringMatcher {
Regex(regex::Regex),
}

impl Hash for StringMatcher {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
StringMatcher::Exact(s) => s.hash(state),
StringMatcher::Glob(pattern) => pattern.hash(state),
StringMatcher::Regex(regex) => regex.as_str().hash(state),
}
}
}

impl PartialEq for StringMatcher {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
Expand Down
32 changes: 31 additions & 1 deletion crates/rattler_conda_types/src/match_spec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use rattler_digest::{serde::SerializableHash, Md5Hash, Sha256Hash};
use serde::Serialize;
use serde_with::{serde_as, skip_serializing_none};
use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

pub mod matcher;
pub mod parse;
Expand Down Expand Up @@ -111,7 +112,7 @@ use matcher::StringMatcher;
/// Alternatively, an exact spec is given by `*[sha256=01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b]`.
#[skip_serializing_none]
#[serde_as]
#[derive(Debug, Default, Clone, Serialize, Eq, PartialEq)]
#[derive(Debug, Default, Clone, Serialize, Eq, PartialEq, Hash)]
pub struct MatchSpec {
/// The name of the package
pub name: Option<PackageName>,
Expand Down Expand Up @@ -351,6 +352,7 @@ mod tests {
use rattler_digest::{parse_digest_from_hex, Md5, Sha256};

use crate::{MatchSpec, NamelessMatchSpec, PackageName, PackageRecord, Version};
use std::hash::{Hash, Hasher};

#[test]
fn test_matchspec_format_eq() {
Expand All @@ -370,6 +372,34 @@ mod tests {
assert_eq!(spec, rebuild_spec)
}

#[test]
fn test_hash_match() {
let spec1 = MatchSpec::from_str("tensorflow 2.6.*").unwrap();
let spec2 = MatchSpec::from_str("tensorflow 2.6.*").unwrap();
assert_eq!(spec1, spec2);

let mut hasher = std::collections::hash_map::DefaultHasher::new();
let hash1 = spec1.hash(&mut hasher);
let hash2 = spec2.hash(&mut hasher);

assert_eq!(hash1, hash2);
}

#[test]
fn test_hash_no_match() {
let spec1 = MatchSpec::from_str("tensorflow 2.6.0.*").unwrap();
let spec2 = MatchSpec::from_str("tensorflow 2.6.*").unwrap();
assert_ne!(spec1, spec2);

let mut hasher = std::collections::hash_map::DefaultHasher::new();
spec1.hash(&mut hasher);
let hash1 = hasher.finish();
spec2.hash(&mut hasher);
let hash2 = hasher.finish();

assert_ne!(hash1, hash2);
}

#[test]
fn test_digest_match() {
let record = PackageRecord {
Expand Down
60 changes: 59 additions & 1 deletion crates/rattler_conda_types/src/version/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,46 @@ impl<'v> SegmentIter<'v> {
}
}

/// Version that only has equality when it is exactly the same
/// e.g for [`Version`] 1.0.0 == 1.0 while in [`StrictVersion`]
/// this is not equal. Useful in ranges where we are talking
/// about equality over version ranges instead of specific
/// version instances
#[derive(Clone, PartialOrd, Ord, Eq, Debug)]
pub struct StrictVersion(pub Version);

impl PartialEq for StrictVersion {
fn eq(&self, other: &Self) -> bool {
// StrictVersion is only equal if the number
// of components are the same
// and the components are the same
self.0.components.len() == other.0.components.len() && self.0 == other.0
}
}

impl Display for StrictVersion {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}

impl Hash for StrictVersion {
fn hash<H: Hasher>(&self, state: &mut H) {
fn hash_segments<'i, I: Iterator<Item = SegmentIter<'i>>, H: Hasher>(
state: &mut H,
segments: I,
) {
for segment in segments {
segment.components().rev().for_each(|c| c.hash(state));
}
}

self.0.epoch().hash(state);
hash_segments(state, self.0.segments());
hash_segments(state, self.0.local_segments());
}
}

#[cfg(test)]
mod test {
use std::cmp::Ordering;
Expand All @@ -958,6 +998,8 @@ mod test {

use rand::seq::SliceRandom;

use crate::version::StrictVersion;

use super::Version;

// Tests are inspired by: https://github.com/conda/conda/blob/33a142c16530fcdada6c377486f1c1a385738a96/tests/models/test_version.py
Expand Down Expand Up @@ -1143,6 +1185,22 @@ mod test {
assert_eq!(random_versions, parsed_versions);
}

#[test]
fn strict_version_test() {
let v_1_0 = StrictVersion::from_str("1.0.0").unwrap();
// Should be equal to itself
assert_eq!(v_1_0, v_1_0);
let v_1_0_0 = StrictVersion::from_str("1.0").unwrap();
// Strict version should not discard zero's
assert_ne!(v_1_0, v_1_0_0);
// Ordering should stay the same as version
assert_eq!(v_1_0.cmp(&v_1_0_0), Ordering::Equal);

// Hashing should consider v_1_0 and v_1_0_0 as unequal
assert_eq!(get_hash(&v_1_0), get_hash(&v_1_0));
assert_ne!(get_hash(&v_1_0), get_hash(&v_1_0_0));
}

#[test]
fn bump() {
assert_eq!(
Expand All @@ -1166,7 +1224,7 @@ mod test {
.starts_with(&Version::from_str("1.2").unwrap()));
}

fn get_hash(spec: &Version) -> u64 {
fn get_hash(spec: &impl Hash) -> u64 {
let mut s = DefaultHasher::new();
spec.hash(&mut s);
s.finish()
Expand Down
10 changes: 9 additions & 1 deletion crates/rattler_conda_types/src/version/parse.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Component, Version};
use super::{Component, StrictVersion, Version};
use crate::version::flags::Flags;
use crate::version::segment::Segment;
use crate::version::{ComponentVec, SegmentVec};
Expand Down Expand Up @@ -437,6 +437,14 @@ impl FromStr for Version {
}
}

impl FromStr for StrictVersion {
type Err = ParseVersionError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(StrictVersion(Version::from_str(s)?))
}
}

#[cfg(test)]
mod test {
use super::Version;
Expand Down
67 changes: 37 additions & 30 deletions crates/rattler_conda_types/src/version_spec/constraint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::ParseConstraintError;
use super::VersionOperator;
use super::RangeOperator;
use crate::version_spec::parse::constraint_parser;
use crate::version_spec::{EqualityOperator, StrictRangeOperator};
use crate::Version;
use std::str::FromStr;

Expand All @@ -12,7 +13,13 @@ pub(crate) enum Constraint {
Any,

/// Version comparison (e.g `>1.2.3`)
Comparison(VersionOperator, Version),
Comparison(RangeOperator, Version),

/// Strict comparison (e.g `~=1.2.3`)
StrictComparison(StrictRangeOperator, Version),

/// Exact Version
Exact(EqualityOperator, Version),
}

/// Returns true if the specified character is the first character of a version constraint.
Expand All @@ -37,7 +44,7 @@ impl FromStr for Constraint {
mod test {
use super::Constraint;
use crate::version_spec::constraint::ParseConstraintError;
use crate::version_spec::VersionOperator;
use crate::version_spec::{EqualityOperator, RangeOperator, StrictRangeOperator};
use crate::Version;
use std::str::FromStr;

Expand Down Expand Up @@ -91,63 +98,63 @@ mod test {
assert_eq!(
Constraint::from_str(">1.2.3"),
Ok(Constraint::Comparison(
VersionOperator::Greater,
RangeOperator::Greater,
Version::from_str("1.2.3").unwrap()
))
);
assert_eq!(
Constraint::from_str("<1.2.3"),
Ok(Constraint::Comparison(
VersionOperator::Less,
RangeOperator::Less,
Version::from_str("1.2.3").unwrap()
))
);
assert_eq!(
Constraint::from_str("=1.2.3"),
Ok(Constraint::Comparison(
VersionOperator::StartsWith,
Ok(Constraint::StrictComparison(
StrictRangeOperator::StartsWith,
Version::from_str("1.2.3").unwrap()
))
);
assert_eq!(
Constraint::from_str("==1.2.3"),
Ok(Constraint::Comparison(
VersionOperator::Equals,
Ok(Constraint::Exact(
EqualityOperator::Equals,
Version::from_str("1.2.3").unwrap()
))
);
assert_eq!(
Constraint::from_str("!=1.2.3"),
Ok(Constraint::Comparison(
VersionOperator::NotEquals,
Ok(Constraint::Exact(
EqualityOperator::NotEquals,
Version::from_str("1.2.3").unwrap()
))
);
assert_eq!(
Constraint::from_str("~=1.2.3"),
Ok(Constraint::Comparison(
VersionOperator::Compatible,
Ok(Constraint::StrictComparison(
StrictRangeOperator::Compatible,
Version::from_str("1.2.3").unwrap()
))
);
assert_eq!(
Constraint::from_str(">=1.2.3"),
Ok(Constraint::Comparison(
VersionOperator::GreaterEquals,
RangeOperator::GreaterEquals,
Version::from_str("1.2.3").unwrap()
))
);
assert_eq!(
Constraint::from_str("<=1.2.3"),
Ok(Constraint::Comparison(
VersionOperator::LessEquals,
RangeOperator::LessEquals,
Version::from_str("1.2.3").unwrap()
))
);
assert_eq!(
Constraint::from_str(">=1!1.2"),
Ok(Constraint::Comparison(
VersionOperator::GreaterEquals,
RangeOperator::GreaterEquals,
Version::from_str("1!1.2").unwrap()
))
);
Expand All @@ -157,50 +164,50 @@ mod test {
fn test_glob_op() {
assert_eq!(
Constraint::from_str("=1.2.*"),
Ok(Constraint::Comparison(
VersionOperator::StartsWith,
Ok(Constraint::StrictComparison(
StrictRangeOperator::StartsWith,
Version::from_str("1.2").unwrap()
))
);
assert_eq!(
Constraint::from_str("!=1.2.*"),
Ok(Constraint::Comparison(
VersionOperator::NotStartsWith,
Ok(Constraint::StrictComparison(
StrictRangeOperator::NotStartsWith,
Version::from_str("1.2").unwrap()
))
);
assert_eq!(
Constraint::from_str(">=1.2.*"),
Ok(Constraint::Comparison(
VersionOperator::GreaterEquals,
RangeOperator::GreaterEquals,
Version::from_str("1.2").unwrap()
))
);
assert_eq!(
Constraint::from_str("==1.2.*"),
Ok(Constraint::Comparison(
VersionOperator::Equals,
Ok(Constraint::Exact(
EqualityOperator::Equals,
Version::from_str("1.2").unwrap()
))
);
assert_eq!(
Constraint::from_str(">1.2.*"),
Ok(Constraint::Comparison(
VersionOperator::GreaterEquals,
RangeOperator::GreaterEquals,
Version::from_str("1.2").unwrap()
))
);
assert_eq!(
Constraint::from_str("<=1.2.*"),
Ok(Constraint::Comparison(
VersionOperator::LessEquals,
RangeOperator::LessEquals,
Version::from_str("1.2").unwrap()
))
);
assert_eq!(
Constraint::from_str("<1.2.*"),
Ok(Constraint::Comparison(
VersionOperator::Less,
RangeOperator::Less,
Version::from_str("1.2").unwrap()
))
);
Expand All @@ -210,8 +217,8 @@ mod test {
fn test_starts_with() {
assert_eq!(
Constraint::from_str("1.2.*"),
Ok(Constraint::Comparison(
VersionOperator::StartsWith,
Ok(Constraint::StrictComparison(
StrictRangeOperator::StartsWith,
Version::from_str("1.2").unwrap()
))
);
Expand All @@ -225,8 +232,8 @@ mod test {
fn test_exact() {
assert_eq!(
Constraint::from_str("1.2.3"),
Ok(Constraint::Comparison(
VersionOperator::Equals,
Ok(Constraint::Exact(
EqualityOperator::Equals,
Version::from_str("1.2.3").unwrap()
))
);
Expand Down
Loading

0 comments on commit 743ec3b

Please sign in to comment.