Skip to content

Commit

Permalink
[red-knot] Make the VERSIONS parser use ModuleName as its key type
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Jun 21, 2024
1 parent d3b495a commit 6e1b980
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 41 deletions.
2 changes: 1 addition & 1 deletion crates/red_knot_module_resolver/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::Db;
/// A module name, e.g. `foo.bar`.
///
/// Always normalized to the absolute form (never a relative module name, i.e., never `.foo`).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub struct ModuleName(smol_str::SmolStr);

impl ModuleName {
Expand Down
87 changes: 47 additions & 40 deletions crates/red_knot_module_resolver/src/typeshed/versions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ use std::ops::{RangeFrom, RangeInclusive};
use std::str::FromStr;

use rustc_hash::FxHashMap;
use smol_str::SmolStr;

use ruff_python_stdlib::identifiers::is_identifier;
use crate::module::ModuleName;

#[derive(Debug, PartialEq, Eq)]
pub struct TypeshedVersionsParseError {
Expand Down Expand Up @@ -82,7 +81,7 @@ impl fmt::Display for TypeshedVersionsParseErrorKind {
}

#[derive(Debug, PartialEq, Eq)]
pub struct TypeshedVersions(FxHashMap<SmolStr, PyVersionRange>);
pub struct TypeshedVersions(FxHashMap<ModuleName, PyVersionRange>);

impl TypeshedVersions {
pub fn len(&self) -> usize {
Expand All @@ -93,24 +92,22 @@ impl TypeshedVersions {
self.0.is_empty()
}

pub fn contains_module(&self, module_name: impl Into<SmolStr>) -> bool {
self.0.contains_key(&module_name.into())
pub fn contains_module(&self, module_name: &ModuleName) -> bool {
self.0.contains_key(module_name)
}

pub fn module_exists_on_version(
&self,
module: impl Into<SmolStr>,
module: ModuleName,
version: impl Into<PyVersion>,
) -> bool {
let version = version.into();
let mut module: Option<SmolStr> = Some(module.into());
let mut module: Option<ModuleName> = Some(module);
while let Some(module_to_try) = module {
if let Some(range) = self.0.get(&module_to_try) {
return range.contains(version);
}
module = module_to_try
.rsplit_once('.')
.map(|(parent, _)| SmolStr::new(parent));
module = module_to_try.parent();
}
false
}
Expand Down Expand Up @@ -149,15 +146,14 @@ impl FromStr for TypeshedVersions {
});
};

let module_name = SmolStr::new(module_name);
if !module_name.split('.').all(is_identifier) {
let Some(module_name) = ModuleName::new(module_name) else {
return Err(TypeshedVersionsParseError {
line_number,
reason: TypeshedVersionsParseErrorKind::InvalidModuleName(
module_name.to_string(),
),
});
}
};

match PyVersionRange::from_str(rest) {
Ok(version) => map.insert(module_name, version),
Expand All @@ -176,7 +172,7 @@ impl FromStr for TypeshedVersions {

impl fmt::Display for TypeshedVersions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let sorted_items: BTreeMap<&SmolStr, &PyVersionRange> = self.0.iter().collect();
let sorted_items: BTreeMap<&ModuleName, &PyVersionRange> = self.0.iter().collect();
for (module_name, range) in sorted_items {
writeln!(f, "{module_name}: {range}")?;
}
Expand Down Expand Up @@ -331,16 +327,22 @@ mod tests {
assert!(versions.len() > 100);
assert!(versions.len() < 1000);

assert!(versions.contains_module("asyncio"));
assert!(versions.module_exists_on_version("asyncio", SupportedPyVersion::Py310));
let asyncio = ModuleName::new_static("asyncio").unwrap();
let asyncio_staggered = ModuleName::new_static("asyncio.staggered").unwrap();
let audioop = ModuleName::new_static("audioop").unwrap();

assert!(versions.contains_module("asyncio.staggered"));
assert!(versions.module_exists_on_version("asyncio.staggered", SupportedPyVersion::Py38));
assert!(!versions.module_exists_on_version("asyncio.staggered", SupportedPyVersion::Py37));
assert!(versions.contains_module(&asyncio));
assert!(versions.module_exists_on_version(asyncio, SupportedPyVersion::Py310));

assert!(versions.contains_module(&asyncio_staggered));
assert!(
versions.module_exists_on_version(asyncio_staggered.clone(), SupportedPyVersion::Py38)
);
assert!(!versions.module_exists_on_version(asyncio_staggered, SupportedPyVersion::Py37));

assert!(versions.contains_module("audioop"));
assert!(versions.module_exists_on_version("audioop", SupportedPyVersion::Py312));
assert!(!versions.module_exists_on_version("audioop", SupportedPyVersion::Py313));
assert!(versions.contains_module(&audioop));
assert!(versions.module_exists_on_version(audioop.clone(), SupportedPyVersion::Py312));
assert!(!versions.module_exists_on_version(audioop, SupportedPyVersion::Py313));
}

#[test]
Expand Down Expand Up @@ -368,24 +370,29 @@ foo: 3.8- # trailing comment
"###
);

assert!(parsed_versions.contains_module("foo"));
assert!(!parsed_versions.module_exists_on_version("foo", SupportedPyVersion::Py37));
assert!(parsed_versions.module_exists_on_version("foo", SupportedPyVersion::Py38));
assert!(parsed_versions.module_exists_on_version("foo", SupportedPyVersion::Py311));

assert!(parsed_versions.contains_module("bar"));
assert!(parsed_versions.module_exists_on_version("bar", SupportedPyVersion::Py37));
assert!(parsed_versions.module_exists_on_version("bar", SupportedPyVersion::Py310));
assert!(!parsed_versions.module_exists_on_version("bar", SupportedPyVersion::Py311));

assert!(parsed_versions.contains_module("bar.baz"));
assert!(parsed_versions.module_exists_on_version("bar.baz", SupportedPyVersion::Py37));
assert!(parsed_versions.module_exists_on_version("bar.baz", SupportedPyVersion::Py39));
assert!(!parsed_versions.module_exists_on_version("bar.baz", SupportedPyVersion::Py310));

assert!(!parsed_versions.contains_module("spam"));
assert!(!parsed_versions.module_exists_on_version("spam", SupportedPyVersion::Py37));
assert!(!parsed_versions.module_exists_on_version("spam", SupportedPyVersion::Py313));
let foo = ModuleName::new_static("foo").unwrap();
let bar = ModuleName::new_static("bar").unwrap();
let bar_baz = ModuleName::new_static("bar.baz").unwrap();
let spam = ModuleName::new_static("spam").unwrap();

assert!(parsed_versions.contains_module(&foo));
assert!(!parsed_versions.module_exists_on_version(foo.clone(), SupportedPyVersion::Py37));
assert!(parsed_versions.module_exists_on_version(foo.clone(), SupportedPyVersion::Py38));
assert!(parsed_versions.module_exists_on_version(foo, SupportedPyVersion::Py311));

assert!(parsed_versions.contains_module(&bar));
assert!(parsed_versions.module_exists_on_version(bar.clone(), SupportedPyVersion::Py37));
assert!(parsed_versions.module_exists_on_version(bar.clone(), SupportedPyVersion::Py310));
assert!(!parsed_versions.module_exists_on_version(bar, SupportedPyVersion::Py311));

assert!(parsed_versions.contains_module(&bar_baz));
assert!(parsed_versions.module_exists_on_version(bar_baz.clone(), SupportedPyVersion::Py37));
assert!(parsed_versions.module_exists_on_version(bar_baz.clone(), SupportedPyVersion::Py39));
assert!(!parsed_versions.module_exists_on_version(bar_baz, SupportedPyVersion::Py310));

assert!(!parsed_versions.contains_module(&spam));
assert!(!parsed_versions.module_exists_on_version(spam.clone(), SupportedPyVersion::Py37));
assert!(!parsed_versions.module_exists_on_version(spam, SupportedPyVersion::Py313));
}

#[test]
Expand Down

0 comments on commit 6e1b980

Please sign in to comment.