diff --git a/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md b/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md index c2f8f2f1a844ca..efc619132cbba5 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md +++ b/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md @@ -202,11 +202,7 @@ reveal_type(A() + B()) # revealed: MyString # N.B. Still a subtype of `A`, even though `A` does not appear directly in the class's `__bases__` class C(B): ... -# TODO: we currently only understand direct subclasses as subtypes of the superclass. -# We need to iterate through the full MRO rather than just the class's bases; -# if we do, we'll understand `C` as a subtype of `A`, and correctly understand this as being -# `MyString` rather than `str` -reveal_type(A() + C()) # revealed: str +reveal_type(A() + C()) # revealed: MyString ``` ## Reflected precedence 2 diff --git a/crates/red_knot_python_semantic/resources/mdtest/mro.md b/crates/red_knot_python_semantic/resources/mdtest/mro.md new file mode 100644 index 00000000000000..0bfd69be1ffe8f --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/mro.md @@ -0,0 +1,397 @@ +# Method Resolution Order tests + +Tests that assert that we can infer the correct type for a class's `__mro__` attribute. + +This attribute is rarely accessed directly at runtime. However, it's extremely important for *us* to +know the precise possible values of a class's Method Resolution Order, or we won't be able to infer +the correct type of attributes accessed from instances. + +For documentation on method resolution orders, see: + +- +- + +## No bases + +```py +class C: + pass + +reveal_type(C.__mro__) # revealed: tuple[Literal[C], Literal[object]] +``` + +## The special case: `object` itself + +```py +reveal_type(object.__mro__) # revealed: tuple[Literal[object]] +``` + +## Explicit inheritance from `object` + +```py +class C(object): + pass + +reveal_type(C.__mro__) # revealed: tuple[Literal[C], Literal[object]] +``` + +## Explicit inheritance from non-`object` single base + +```py +class A: + pass + +class B(A): + pass + +reveal_type(B.__mro__) # revealed: tuple[Literal[B], Literal[A], Literal[object]] +``` + +## Linearization of multiple bases + +```py +class A: + pass + +class B: + pass + +class C(A, B): + pass + +reveal_type(C.__mro__) # revealed: tuple[Literal[C], Literal[A], Literal[B], Literal[object]] +``` + +## Complex diamond inheritance (1) + +This is "ex_2" from + +```py +class O: + pass + +class X(O): + pass + +class Y(O): + pass + +class A(X, Y): + pass + +class B(Y, X): + pass + +reveal_type(A.__mro__) # revealed: tuple[Literal[A], Literal[X], Literal[Y], Literal[O], Literal[object]] +reveal_type(B.__mro__) # revealed: tuple[Literal[B], Literal[Y], Literal[X], Literal[O], Literal[object]] +``` + +## Complex diamond inheritance (2) + +This is "ex_5" from + +```py +class O: + pass + +class F(O): + pass + +class E(O): + pass + +class D(O): + pass + +class C(D, F): + pass + +class B(D, E): + pass + +class A(B, C): + pass + +# revealed: tuple[Literal[C], Literal[D], Literal[F], Literal[O], Literal[object]] +reveal_type(C.__mro__) +# revealed: tuple[Literal[B], Literal[D], Literal[E], Literal[O], Literal[object]] +reveal_type(B.__mro__) +# revealed: tuple[Literal[A], Literal[B], Literal[C], Literal[D], Literal[E], Literal[F], Literal[O], Literal[object]] +reveal_type(A.__mro__) +``` + +## Complex diamond inheritance (3) + +This is "ex_6" from + +```py +class O: + pass + +class F(O): + pass + +class E(O): + pass + +class D(O): + pass + +class C(D, F): + pass + +class B(E, D): + pass + +class A(B, C): + pass + +# revealed: tuple[Literal[C], Literal[D], Literal[F], Literal[O], Literal[object]] +reveal_type(C.__mro__) +# revealed: tuple[Literal[B], Literal[E], Literal[D], Literal[O], Literal[object]] +reveal_type(B.__mro__) +# revealed: tuple[Literal[A], Literal[B], Literal[E], Literal[C], Literal[D], Literal[F], Literal[O], Literal[object]] +reveal_type(A.__mro__) +``` + +## Complex diamond inheritance (4) + +This is "ex_9" from + +```py +class O: + pass + +class A(O): + pass + +class B(O): + pass + +class C(O): + pass + +class D(O): + pass + +class E(O): + pass + +class K1(A, B, C): + pass + +class K2(D, B, E): + pass + +class K3(D, A): + pass + +class Z(K1, K2, K3): + pass + +# revealed: tuple[Literal[K1], Literal[A], Literal[B], Literal[C], Literal[O], Literal[object]] +reveal_type(K1.__mro__) +# revealed: tuple[Literal[K2], Literal[D], Literal[B], Literal[E], Literal[O], Literal[object]] +reveal_type(K2.__mro__) +# revealed: tuple[Literal[K3], Literal[D], Literal[A], Literal[O], Literal[object]] +reveal_type(K3.__mro__) +# revealed: tuple[Literal[Z], Literal[K1], Literal[K2], Literal[K3], Literal[D], Literal[A], Literal[B], Literal[C], Literal[E], Literal[O], Literal[object]] +reveal_type(Z.__mro__) +``` + +## Inheritance from `Unknown` + +```py +from does_not_exist import DoesNotExist # error: [unresolved-import] + +class A(DoesNotExist): + pass + +class B: + pass + +class C: + pass + +class D(A, B, C): + pass + +class E(B, C): + pass + +class F(E, A): + pass + +reveal_type(A.__mro__) # revealed: tuple[Literal[A], Unknown, Literal[object]] +reveal_type(D.__mro__) # revealed: tuple[Literal[D], Literal[A], Unknown, Literal[B], Literal[C], Literal[object]] +reveal_type(E.__mro__) # revealed: tuple[Literal[E], Literal[B], Literal[C], Literal[object]] +reveal_type(F.__mro__) # revealed: tuple[Literal[F], Literal[E], Literal[B], Literal[C], Literal[A], Unknown, Literal[object]] +``` + +## `__bases__` lists that cause errors at runtime + +If the class's `__bases__` cause an exception to be raised at runtime and therefore the class +creation to fail, we infer the class's `__mro__` as being `[, Unknown, object]`: + +```py +# error: [inconsistent-mro] "Cannot create a consistent method resolution order (MRO) for class `Foo` with bases list `[, ]`" +class Foo(object, int): + pass + +reveal_type(Foo.__mro__) # revealed: tuple[Literal[Foo], Unknown, Literal[object]] + +class Bar(Foo): + pass + +reveal_type(Bar.__mro__) # revealed: tuple[Literal[Bar], Literal[Foo], Unknown, Literal[object]] + +# This is the `TypeError` at the bottom of "ex_2" +# in the examples at + +class O: + pass + +class X(O): + pass + +class Y(O): + pass + +class A(X, Y): + pass + +class B(Y, X): + pass + +reveal_type(A.__mro__) # revealed: tuple[Literal[A], Literal[X], Literal[Y], Literal[O], Literal[object]] +reveal_type(B.__mro__) # revealed: tuple[Literal[B], Literal[Y], Literal[X], Literal[O], Literal[object]] + +# error: [inconsistent-mro] "Cannot create a consistent method resolution order (MRO) for class `Z` with bases list `[, ]`" +class Z(A, B): + pass + +reveal_type(Z.__mro__) # revealed: tuple[Literal[Z], Unknown, Literal[object]] + +class AA(Z): + pass + +reveal_type(AA.__mro__) # revealed: tuple[Literal[AA], Literal[Z], Unknown, Literal[object]] +``` + +## `__bases__` includes a `Union` + +We don't support union types in a class's bases; a base must resolve to a single `ClassLiteralType`. +If we find a union type in a class's bases, we infer the class's `__mro__` as being +`[, Unknown, object]`, the same as for MROs that cause errors at runtime. + +```py +def returns_bool() -> bool: + return True + +class A: + pass + +class B: + pass + +if returns_bool(): + x = A +else: + x = B + +reveal_type(x) # revealed: Literal[A, B] + +# error: [invalid-base] "Invalid class base with type `Literal[A, B]` (all bases must be a class, `Any`, `Unknown` or `Todo`)" +class Foo(x): + pass + +reveal_type(Foo.__mro__) # revealed: tuple[Literal[Foo], Unknown, Literal[object]] +``` + +## `__bases__` includes multiple `Union`s + +```py +def returns_bool() -> bool: + return True + +class A: + pass + +class B: + pass + +class C: + pass + +class D: + pass + +if returns_bool(): + x = A +else: + x = B + +if returns_bool(): + y = C +else: + y = D + +reveal_type(x) # revealed: Literal[A, B] +reveal_type(y) # revealed: Literal[C, D] + +# error: [invalid-base] "Invalid class base with type `Literal[A, B]` (all bases must be a class, `Any`, `Unknown` or `Todo`)" +# error: [invalid-base] "Invalid class base with type `Literal[C, D]` (all bases must be a class, `Any`, `Unknown` or `Todo`)" +class Foo(x, y): + pass + +reveal_type(Foo.__mro__) # revealed: tuple[Literal[Foo], Unknown, Literal[object]] +``` + +## `__bases__` lists that cause errors... now with `Union`s + +```py +def returns_bool() -> bool: + return True + +class O: + pass + +class X(O): + pass + +class Y(O): + pass + +if bool(): + foo = Y +else: + foo = object + +# error: [invalid-base] "Invalid class base with type `Literal[Y, object]` (all bases must be a class, `Any`, `Unknown` or `Todo`)" +class PossibleError(foo, X): + pass + +reveal_type(PossibleError.__mro__) # revealed: tuple[Literal[PossibleError], Unknown, Literal[object]] + +class A(X, Y): + pass + +reveal_type(A.__mro__) # revealed: tuple[Literal[A], Literal[X], Literal[Y], Literal[O], Literal[object]] + +if returns_bool(): + class B(X, Y): + pass + +else: + class B(Y, X): + pass + +# revealed: tuple[Literal[B], Literal[X], Literal[Y], Literal[O], Literal[object]] | tuple[Literal[B], Literal[Y], Literal[X], Literal[O], Literal[object]] +reveal_type(B.__mro__) + +# error: [invalid-base] "Invalid class base with type `Literal[B, B]` (all bases must be a class, `Any`, `Unknown` or `Todo`)" +class Z(A, B): + pass + +reveal_type(Z.__mro__) # revealed: tuple[Literal[Z], Unknown, Literal[object]] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/scopes/moduletype_attrs.md b/crates/red_knot_python_semantic/resources/mdtest/scopes/moduletype_attrs.md index b468ede95481a1..a69aa018b34d5c 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/scopes/moduletype_attrs.md +++ b/crates/red_knot_python_semantic/resources/mdtest/scopes/moduletype_attrs.md @@ -59,11 +59,10 @@ reveal_type(typing.__name__) # revealed: str reveal_type(typing.__init__) # revealed: Literal[__init__] # These come from `builtins.object`, not `types.ModuleType`: -# TODO: we don't currently understand `types.ModuleType` as inheriting from `object`; -# these should not reveal `Unknown`: -reveal_type(typing.__eq__) # revealed: Unknown -reveal_type(typing.__class__) # revealed: Unknown -reveal_type(typing.__module__) # revealed: Unknown +reveal_type(typing.__eq__) # revealed: Literal[__eq__] + +# TODO: understand properties +reveal_type(typing.__class__) # revealed: Literal[__class__] # TODO: needs support for attribute access on instances, properties and generics; # should be `dict[str, Any]` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 481531e460539e..fb0fa666c09485 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,5 +1,7 @@ +use mro::{ClassBase, Mro, MroError}; use ruff_db::files::File; use ruff_python_ast as ast; +use std::borrow::Cow; use crate::module_resolver::file_to_module; use crate::semantic_index::ast_ids::HasScopedAstId; @@ -13,7 +15,7 @@ use crate::stdlib::{builtins_symbol, types_symbol, typeshed_symbol, typing_exten use crate::symbol::{Boundness, Symbol}; use crate::types::diagnostic::TypeCheckDiagnosticsBuilder; use crate::types::narrow::narrowing_constraint; -use crate::{Db, FxOrderSet, HasTy, Module, SemanticModel}; +use crate::{Db, FxOrderSet, Module}; pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder}; pub use self::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics}; @@ -26,6 +28,7 @@ mod builder; mod diagnostic; mod display; mod infer; +mod mro; mod narrow; pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics { @@ -827,6 +830,14 @@ impl<'db> Type<'db> { /// as accessed from instances of the `Bar` class. #[must_use] pub(crate) fn member(&self, db: &'db dyn Db, name: &str) -> Symbol<'db> { + if name == "__mro__" { + if let Some(mro) = Mro::of_ty(db, *self) { + let mro_element_tys: Box<_> = mro.iter().copied().map(Type::from).collect(); + let mro_ty = Type::Tuple(TupleType::new(db, mro_element_tys)); + return Symbol::Type(mro_ty, Boundness::Bound); + } + } + match self { Type::Any => Type::Any.into(), Type::Never => { @@ -1788,6 +1799,7 @@ pub struct ClassType<'db> { known: Option, } +#[salsa::tracked] impl<'db> ClassType<'db> { pub fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool { match self.known(db) { @@ -1804,36 +1816,30 @@ impl<'db> ClassType<'db> { }) } - /// Return an iterator over the types of this class's bases. - /// - /// # Panics: - /// If `definition` is not a `DefinitionKind::Class`. - pub fn bases(&self, db: &'db dyn Db) -> impl Iterator> { - let definition = self.definition(db); - let DefinitionKind::Class(class_stmt_node) = definition.kind(db) else { - panic!("Class type definition must have DefinitionKind::Class"); - }; - class_stmt_node - .bases() - .iter() - .map(move |base_expr: &ast::Expr| { - if class_stmt_node.type_params.is_some() { - // when we have a specialized scope, we'll look up the inference - // within that scope - let model: SemanticModel<'db> = SemanticModel::new(db, definition.file(db)); - base_expr.ty(&model) - } else { - // Otherwise, we can do the lookup based on the definition scope - definition_expression_ty(db, definition, base_expr) - } - }) + /// Return the original [`ast::StmtClassDef`] node associated with this class + fn node(self, db: &'db dyn Db) -> &'db ast::StmtClassDef { + match self.definition(db).kind(db) { + DefinitionKind::Class(class_stmt_node) => class_stmt_node, + _ => panic!("Class type definition must have DefinitionKind::Class"), + } + } + + #[salsa::tracked(return_ref)] + fn try_mro(self, db: &'db dyn Db) -> Result, MroError<'db>> { + Mro::of_class(db, self) + } + + fn mro(self, db: &'db dyn Db) -> Cow<'db, Mro<'db>> { + self.try_mro(db) + .as_ref() + .map_or_else(|_| Cow::Owned(Mro::from_error(db, self)), Cow::Borrowed) } pub fn is_subclass_of(self, db: &'db dyn Db, other: ClassType) -> bool { // TODO: we need to iterate over the *MRO* here, not the bases (other == self) - || self.bases(db).any(|base| match base { - Type::ClassLiteral(base_class) => base_class == other, + || self.mro(db).iter().any(|base| match base { + ClassBase::Class(base_class) => *base_class == other, // `is_subclass_of` is checking the subtype relation, in which gradual types do not // participate, so we should not return `True` if we find `Any/Unknown` in the // bases. @@ -1860,10 +1866,17 @@ impl<'db> ClassType<'db> { } pub(crate) fn inherited_class_member(self, db: &'db dyn Db, name: &str) -> Symbol<'db> { - for base in self.bases(db) { - let member = base.member(db, name); - if !member.is_unbound() { - return member; + for superclass in self.mro(db).iter().skip(1) { + match superclass { + ClassBase::Any | ClassBase::Unknown | ClassBase::Todo => { + return Type::from(*superclass).member(db, name) + } + ClassBase::Class(class) => { + let member = class.own_class_member(db, name); + if !member.is_unbound() { + return member; + } + } } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 567e615a0b0092..5c16375b1cfe7f 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -61,6 +61,8 @@ use crate::types::{ use crate::util::subscript::{PyIndex, PySlice}; use crate::Db; +use super::mro::MroError; + /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// scope. @@ -413,13 +415,55 @@ impl<'db> TypeInferenceBuilder<'db> { for definition in self.types.declarations.keys() { if infer_definition_types(self.db, *definition).has_deferred { let deferred = infer_deferred_types(self.db, *definition); - deferred_expression_types.extend(deferred.expressions.iter()); + deferred_expression_types.extend(&deferred.expressions); } } - self.types - .expressions - .extend(deferred_expression_types.iter()); + self.types.expressions.extend(deferred_expression_types); + } + + self.check_class_mros(); + } + + /// Iterate over all class definitions to check that Python will be able to create + /// a consistent "[method resolution order]" for each class at runtime. If not, + /// issue a diagnostic. + /// + /// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order + fn check_class_mros(&mut self) { + let declarations = std::mem::take(&mut self.types.declarations); + let class_definitions = declarations + .values() + .filter_map(|ty| ty.into_class_literal_type()); + + for class in class_definitions { + match class.try_mro(self.db) { + Ok(_) => continue, + Err(MroError::InvalidBases(bases)) => { + for (index, base_ty) in bases { + let base_node = &class.node(self.db).bases()[*index]; + self.diagnostics.add( + base_node.into(), + "invalid-base", + format_args!( + "Invalid class base with type `{}` (all bases must be a class, `Any`, `Unknown` or `Todo`)", + base_ty.display(self.db) + ) + ); + } + }, + Err(MroError::UnresolvableMro(bases)) => self.diagnostics.add( + class.node(self.db).into(), + "inconsistent-mro", + format_args!( + "Cannot create a consistent method resolution order (MRO) for class `{}` with bases list `[{}]`", + class.name(self.db), + bases.iter().map(|base|base.display(self.db)).join(", ") + ) + ) + } } + + self.types.declarations = declarations; } fn infer_region_definition(&mut self, definition: Definition<'db>) { @@ -4196,7 +4240,8 @@ mod tests { use crate::semantic_index::symbol::FileScopeId; use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map}; use crate::types::{ - check_types, global_symbol, infer_definition_types, symbol, TypeCheckDiagnostics, + check_types, definition_expression_ty, global_symbol, infer_definition_types, symbol, + TypeCheckDiagnostics, }; use crate::{HasTy, ProgramSettings, SemanticModel}; use ruff_db::files::{system_path_to_file, File}; @@ -4205,7 +4250,7 @@ mod tests { use ruff_db::testing::assert_function_query_was_not_run; use ruff_python_ast::name::Name; - use super::TypeInferenceBuilder; + use super::*; fn setup_db() -> TestDb { let db = TestDb::new(); @@ -4338,8 +4383,15 @@ mod tests { let class = ty.expect_class_literal(); let base_names: Vec<_> = class - .bases(&db) - .map(|base_ty| format!("{}", base_ty.display(&db))) + .node(&db) + .bases() + .iter() + .map(|base| { + format!( + "{}", + definition_expression_ty(&db, class.definition(&db), base).display(&db) + ) + }) .collect(); assert_eq!(base_names, vec!["Literal[Base]"]); @@ -4575,13 +4627,13 @@ mod tests { let a = system_path_to_file(&db, "src/a.py").expect("file to exist"); let c_ty = global_symbol(&db, a, "C").expect_type(); let c_class = c_ty.expect_class_literal(); - let mut c_bases = c_class.bases(&db); - let b_ty = c_bases.next().unwrap(); - let b_class = b_ty.expect_class_literal(); + let c_mro = c_class.mro(&db); + let b_ty = c_mro[1]; + let b_class = b_ty.expect_class(); assert_eq!(b_class.name(&db), "B"); - let mut b_bases = b_class.bases(&db); - let a_ty = b_bases.next().unwrap(); - let a_class = a_ty.expect_class_literal(); + let b_mro = b_class.mro(&db); + let a_ty = b_mro[1]; + let a_class = a_ty.expect_class(); assert_eq!(a_class.name(&db), "A"); Ok(()) @@ -4730,15 +4782,8 @@ mod tests { db.write_file("/src/a.pyi", "class C(object): pass")?; let file = system_path_to_file(&db, "/src/a.pyi").unwrap(); let ty = global_symbol(&db, file, "C").expect_type(); - - let base = ty - .expect_class_literal() - .bases(&db) - .next() - .expect("there should be at least one base"); - - assert_eq!(base.display(&db).to_string(), "Literal[object]"); - + let base = ty.expect_class_literal().mro(&db)[1]; + assert_eq!(base.display(&db).to_string(), ""); Ok(()) } diff --git a/crates/red_knot_python_semantic/src/types/mro.rs b/crates/red_knot_python_semantic/src/types/mro.rs new file mode 100644 index 00000000000000..a0373fb10a764a --- /dev/null +++ b/crates/red_knot_python_semantic/src/types/mro.rs @@ -0,0 +1,275 @@ +use std::borrow::Cow; +use std::collections::VecDeque; +use std::ops::Deref; + +use ruff_python_ast as ast; + +use super::{definition_expression_ty, ClassType, KnownClass, Type}; +use crate::semantic_index::definition::Definition; +use crate::{Db, HasTy, SemanticModel}; + +/// A single possible method resolution order of a given class. +/// +/// See [`ClassType::mro_possibilities`] for more details. +#[derive(PartialEq, Eq, Default, Hash, Clone, Debug)] +pub(super) struct Mro<'db>(Box<[ClassBase<'db>]>); + +impl<'db> Mro<'db> { + /// In the event that a possible list of bases would (or could) lead to a + /// `TypeError` being raised at runtime due to an unresolvable MRO, we + /// infer the class as being `[, Unknown, object]`. + /// This seems most likely to reduce the possibility of cascading errors + /// elsewhere. + /// + /// (We emit a diagnostic warning about the runtime `TypeError` in + /// [`super::infer::TypeInferenceBuilder::infer_region_scope`].) + pub(super) fn from_error(db: &'db dyn Db, class: ClassType<'db>) -> Self { + Self::from([ + ClassBase::Class(class), + ClassBase::Unknown, + ClassBase::object(db), + ]) + } + + pub(super) fn of_class(db: &'db dyn Db, class: ClassType<'db>) -> Result> { + let class_stmt_node = class.node(db); + + match class_stmt_node.bases() { + [] if class.is_known(db, KnownClass::Object) => { + Ok(Self::from([ClassBase::Class(class)])) + } + [] => Ok(Self::from([ClassBase::Class(class), ClassBase::object(db)])), + [single_base] => { + ClassBase::try_from_node(db, single_base, class_stmt_node, class.definition(db)) + .map(|base| { + std::iter::once(ClassBase::Class(class)) + .chain(Mro::of_base(db, base).iter().copied()) + .collect() + }) + .map_err(|base_ty| MroError::InvalidBases(Box::from([(0, base_ty)]))) + } + multiple_bases => { + let definition = class.definition(db); + let mut valid_bases = vec![]; + let mut invalid_bases = vec![]; + + for (i, base_node) in multiple_bases.iter().enumerate() { + match ClassBase::try_from_node(db, base_node, class_stmt_node, definition) { + Ok(valid_base) => valid_bases.push(valid_base), + Err(invalid_base) => invalid_bases.push((i, invalid_base)), + } + } + + if !invalid_bases.is_empty() { + return Err(MroError::InvalidBases(invalid_bases.into_boxed_slice())); + } + + let mut seqs = vec![VecDeque::from([ClassBase::Class(class)])]; + for base in &valid_bases { + seqs.push(Mro::of_base(db, *base).iter().copied().collect()); + } + seqs.push(valid_bases.iter().copied().collect()); + + c3_merge(seqs) + .ok_or_else(|| MroError::UnresolvableMro(valid_bases.into_boxed_slice())) + } + } + } + + pub(super) fn of_ty(db: &'db dyn Db, ty: Type<'db>) -> Option> { + ClassBase::try_from_ty(ty).map(|as_base| Self::of_base(db, as_base)) + } + + fn of_base(db: &'db dyn Db, base: ClassBase<'db>) -> Cow<'db, Self> { + match base { + ClassBase::Any => Cow::Owned(Mro::from([ClassBase::Any, ClassBase::object(db)])), + ClassBase::Unknown => { + Cow::Owned(Mro::from([ClassBase::Unknown, ClassBase::object(db)])) + } + ClassBase::Todo => Cow::Owned(Mro::from([ClassBase::Todo, ClassBase::object(db)])), + ClassBase::Class(class) => class.mro(db), + } + } +} + +impl<'db, const N: usize> From<[ClassBase<'db>; N]> for Mro<'db> { + fn from(value: [ClassBase<'db>; N]) -> Self { + Self(Box::from(value)) + } +} + +impl<'db> From>> for Mro<'db> { + fn from(value: Vec>) -> Self { + Self(value.into_boxed_slice()) + } +} + +impl<'db> Deref for Mro<'db> { + type Target = [ClassBase<'db>]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'db> FromIterator> for Mro<'db> { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl<'a, 'db> IntoIterator for &'a Mro<'db> { + type IntoIter = std::slice::Iter<'a, ClassBase<'db>>; + type Item = &'a ClassBase<'db>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(super) enum MroError<'db> { + InvalidBases(Box<[(usize, Type<'db>)]>), + UnresolvableMro(Box<[ClassBase<'db>]>), +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub(super) enum ClassBase<'db> { + Any, + Unknown, + Todo, + Class(ClassType<'db>), +} + +impl<'db> ClassBase<'db> { + fn object(db: &'db dyn Db) -> Self { + KnownClass::Object + .to_class(db) + .into_class_literal_type() + .map_or(Self::Unknown, Self::Class) + } + + fn try_from_node( + db: &'db dyn Db, + base_node: &'db ast::Expr, + class_stmt_node: &'db ast::StmtClassDef, + definition: Definition<'db>, + ) -> Result> { + let base_ty = if class_stmt_node.type_params.is_some() { + // when we have a specialized scope, we'll look up the inference + // within that scope + let model = SemanticModel::new(db, definition.file(db)); + base_node.ty(&model) + } else { + // Otherwise, we can do the lookup based on the definition scope + definition_expression_ty(db, definition, base_node) + }; + + Self::try_from_ty(base_ty).ok_or(base_ty) + } + + fn try_from_ty(ty: Type<'db>) -> Option { + match ty { + Type::Any => Some(Self::Any), + Type::Unknown => Some(Self::Unknown), + Type::Todo => Some(Self::Todo), + Type::ClassLiteral(class) => Some(Self::Class(class)), + Type::Union(_) => None, // TODO -- forces consideration of multiple possible MROs? + Type::Intersection(_) => None, // TODO -- probably incorrect? + Type::Instance(_) => None, // TODO -- handle `__mro_entries__`? + Type::Never + | Type::None + | Type::BooleanLiteral(_) + | Type::FunctionLiteral(_) + | Type::BytesLiteral(_) + | Type::IntLiteral(_) + | Type::StringLiteral(_) + | Type::LiteralString + | Type::Tuple(_) + | Type::SliceLiteral(_) + | Type::ModuleLiteral(_) => None, + } + } + + pub(super) fn display(self, db: &'db dyn Db) -> impl std::fmt::Display + 'db { + struct Display<'db> { + base: ClassBase<'db>, + db: &'db dyn Db, + } + + impl std::fmt::Display for Display<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.base { + ClassBase::Any => f.write_str("Any"), + ClassBase::Todo => f.write_str("Todo"), + ClassBase::Unknown => f.write_str("Unknown"), + ClassBase::Class(class) => write!(f, "", class.name(self.db)), + } + } + } + + Display { base: self, db } + } + + #[cfg(test)] + #[track_caller] + pub(super) fn expect_class(self) -> ClassType<'db> { + match self { + ClassBase::Class(class) => class, + _ => panic!("Expected a `ClassBase::Class()` variant"), + } + } +} + +impl<'db> From> for Type<'db> { + fn from(value: ClassBase<'db>) -> Self { + match value { + ClassBase::Any => Type::Any, + ClassBase::Todo => Type::Todo, + ClassBase::Unknown => Type::Unknown, + ClassBase::Class(class) => Type::ClassLiteral(class), + } + } +} + +/// Implementation of the [C3-merge algorithm] for calculating a Python class's +/// [method resolution order]. +/// +/// [C3-merge algorithm]: https://docs.python.org/3/howto/mro.html#python-2-3-mro +/// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order +fn c3_merge(mut sequences: Vec>) -> Option { + // Most MROs aren't that long... + let mut mro = Vec::with_capacity(8); + + loop { + sequences.retain(|sequence| !sequence.is_empty()); + + if sequences.is_empty() { + return Some(Mro::from(mro)); + } + + // If the candidate exists "deeper down" in the inheritance hierarchy, + // we should refrain from adding it to the MRO for now. Add the first candidate + // for which this does not hold true. If this holds true for all candidates, + // return `None`; it will be impossible to find a consistent MRO for the class + // with the given bases. + let mro_entry = sequences.iter().find_map(|outer_sequence| { + let candidate = outer_sequence[0]; + + let not_head = sequences + .iter() + .all(|sequence| sequence.iter().skip(1).all(|base| base != &candidate)); + + not_head.then_some(candidate) + })?; + + mro.push(mro_entry); + + // Make sure we don't try to add the candidate to the MRO twice: + for sequence in &mut sequences { + if sequence[0] == mro_entry { + sequence.pop_front(); + } + } + } +}