Skip to content

Commit

Permalink
[red-knot] Introduce a new ClassLiteralType struct (astral-sh#14108)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood authored Nov 5, 2024
1 parent abafeb4 commit eead549
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 143 deletions.
127 changes: 85 additions & 42 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,13 @@ fn symbol<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Symbol<'db>
/// so the cost of hashing the names is likely to be more expensive than it's worth.
#[salsa::tracked(return_ref)]
fn module_type_symbols<'db>(db: &'db dyn Db) -> smallvec::SmallVec<[ast::name::Name; 8]> {
let Some(module_type) = KnownClass::ModuleType
.to_class(db)
.into_class_literal_type()
else {
let Some(module_type) = KnownClass::ModuleType.to_class(db).into_class_literal() else {
// The most likely way we get here is if a user specified a `--custom-typeshed-dir`
// without a `types.pyi` stub in the `stdlib/` directory
return smallvec::SmallVec::default();
};

let module_type_scope = module_type.body_scope(db);
let module_type_scope = module_type.class.body_scope(db);
let module_type_symbol_table = symbol_table(db, module_type_scope);

// `__dict__` and `__init__` are very special members that can be accessed as attributes
Expand Down Expand Up @@ -303,13 +300,11 @@ fn declarations_ty<'db>(
}
}

/// Unique ID for a type.
/// Representation of a type: a set of possible values at runtime.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Type<'db> {
/// The dynamic type: a statically-unknown set of values
/// The dynamic type: a statically unknown set of values
Any,
/// The empty set of values
Never,
/// Unknown type (either no annotation, or some kind of type error).
/// Equivalent to Any, or possibly to object in strict mode
Unknown,
Expand All @@ -322,12 +317,14 @@ pub enum Type<'db> {
/// output to be unknown. An output should only be `Todo` if fixing all `Todo` inputs to be not
/// `Todo` would change the output type.
Todo,
/// The empty set of values
Never,
/// A specific function object
FunctionLiteral(FunctionType<'db>),
/// A specific module object
ModuleLiteral(File),
/// A specific class object
ClassLiteral(ClassType<'db>),
ClassLiteral(ClassLiteralType<'db>),
/// The set of Python objects with the given class in their __class__'s method resolution order
Instance(InstanceType<'db>),
/// The set of objects in any of the types in the union
Expand All @@ -338,7 +335,7 @@ pub enum Type<'db> {
IntLiteral(i64),
/// A boolean literal, either `True` or `False`.
BooleanLiteral(bool),
/// A string literal
/// A string literal whose value is known
StringLiteral(StringLiteralType<'db>),
/// A string known to originate only from literal values, but whose value is not known (unlike
/// `StringLiteral` above).
Expand All @@ -362,24 +359,24 @@ impl<'db> Type<'db> {
matches!(self, Type::Todo)
}

pub const fn into_class_literal_type(self) -> Option<ClassType<'db>> {
pub const fn into_class_literal(self) -> Option<ClassLiteralType<'db>> {
match self {
Type::ClassLiteral(class_type) => Some(class_type),
_ => None,
}
}

#[track_caller]
pub fn expect_class_literal(self) -> ClassType<'db> {
self.into_class_literal_type()
pub fn expect_class_literal(self) -> ClassLiteralType<'db> {
self.into_class_literal()
.expect("Expected a Type::ClassLiteral variant")
}

pub const fn is_class_literal(&self) -> bool {
matches!(self, Type::ClassLiteral(..))
}

pub const fn into_module_literal_type(self) -> Option<File> {
pub const fn into_module_literal(self) -> Option<File> {
match self {
Type::ModuleLiteral(file) => Some(file),
_ => None,
Expand All @@ -388,7 +385,7 @@ impl<'db> Type<'db> {

#[track_caller]
pub fn expect_module_literal(self) -> File {
self.into_module_literal_type()
self.into_module_literal()
.expect("Expected a Type::ModuleLiteral variant")
}

Expand All @@ -397,7 +394,7 @@ impl<'db> Type<'db> {
IntersectionBuilder::new(db).add_negative(*self).build()
}

pub const fn into_union_type(self) -> Option<UnionType<'db>> {
pub const fn into_union(self) -> Option<UnionType<'db>> {
match self {
Type::Union(union_type) => Some(union_type),
_ => None,
Expand All @@ -406,15 +403,14 @@ impl<'db> Type<'db> {

#[track_caller]
pub fn expect_union(self) -> UnionType<'db> {
self.into_union_type()
.expect("Expected a Type::Union variant")
self.into_union().expect("Expected a Type::Union variant")
}

pub const fn is_union(&self) -> bool {
matches!(self, Type::Union(..))
}

pub const fn into_intersection_type(self) -> Option<IntersectionType<'db>> {
pub const fn into_intersection(self) -> Option<IntersectionType<'db>> {
match self {
Type::Intersection(intersection_type) => Some(intersection_type),
_ => None,
Expand All @@ -423,11 +419,11 @@ impl<'db> Type<'db> {

#[track_caller]
pub fn expect_intersection(self) -> IntersectionType<'db> {
self.into_intersection_type()
self.into_intersection()
.expect("Expected a Type::Intersection variant")
}

pub const fn into_function_literal_type(self) -> Option<FunctionType<'db>> {
pub const fn into_function_literal(self) -> Option<FunctionType<'db>> {
match self {
Type::FunctionLiteral(function_type) => Some(function_type),
_ => None,
Expand All @@ -436,15 +432,15 @@ impl<'db> Type<'db> {

#[track_caller]
pub fn expect_function_literal(self) -> FunctionType<'db> {
self.into_function_literal_type()
self.into_function_literal()
.expect("Expected a Type::FunctionLiteral variant")
}

pub const fn is_function_literal(&self) -> bool {
matches!(self, Type::FunctionLiteral(..))
}

pub const fn into_int_literal_type(self) -> Option<i64> {
pub const fn into_int_literal(self) -> Option<i64> {
match self {
Type::IntLiteral(value) => Some(value),
_ => None,
Expand All @@ -453,7 +449,7 @@ impl<'db> Type<'db> {

#[track_caller]
pub fn expect_int_literal(self) -> i64 {
self.into_int_literal_type()
self.into_int_literal()
.expect("Expected a Type::IntLiteral variant")
}

Expand Down Expand Up @@ -962,7 +958,7 @@ impl<'db> Type<'db> {
global_lookup
}
}
Type::ClassLiteral(class) => class.class_member(db, name),
Type::ClassLiteral(class_ty) => class_ty.member(db, name),
Type::Instance(_) => {
// TODO MRO? get_own_instance_member, get_instance_member
Type::Todo.into()
Expand Down Expand Up @@ -1109,7 +1105,7 @@ impl<'db> Type<'db> {
}

// TODO annotated return type on `__new__` or metaclass `__call__`
Type::ClassLiteral(class) => {
Type::ClassLiteral(ClassLiteralType { class }) => {
CallOutcome::callable(match class.known(db) {
// If the class is the builtin-bool class (for example `bool(1)`), we try to
// return the specific truthiness value of the input arg, `Literal[True]` for
Expand All @@ -1118,7 +1114,7 @@ impl<'db> Type<'db> {
.first()
.map(|arg| arg.bool(db).into_type(db))
.unwrap_or(Type::BooleanLiteral(false)),
_ => class.to_instance(),
_ => Type::anonymous_instance(class),
})
}

Expand Down Expand Up @@ -1236,7 +1232,7 @@ impl<'db> Type<'db> {
Type::Todo => Type::Todo,
Type::Unknown => Type::Unknown,
Type::Never => Type::Never,
Type::ClassLiteral(class) => Type::Instance(InstanceType::anonymous(*class)),
Type::ClassLiteral(ClassLiteralType { class }) => Type::anonymous_instance(*class),
Type::Union(union) => union.map(db, |element| element.to_instance(db)),
// TODO: we can probably do better here: --Alex
Type::Intersection(_) => Type::Todo,
Expand All @@ -1255,6 +1251,10 @@ impl<'db> Type<'db> {
}
}
pub fn anonymous_instance(class: Class<'db>) -> Self {
Self::Instance(InstanceType::anonymous(class))
}
/// The type `NoneType` / `None`
pub fn none(db: &'db dyn Db) -> Type<'db> {
KnownClass::NoneType.to_instance(db)
Expand All @@ -1266,7 +1266,11 @@ impl<'db> Type<'db> {
pub fn to_meta_type(&self, db: &'db dyn Db) -> Type<'db> {
match self {
Type::Never => Type::Never,
Type::Instance(InstanceType { class, .. }) => Type::ClassLiteral(*class),
// TODO: not really correct -- the meta-type of an `InstanceType { class: T }` should be `type[T]`
// (<https://docs.python.org/3/library/typing.html#the-type-of-class-objects>)
Type::Instance(InstanceType { class, .. }) => {
Type::ClassLiteral(ClassLiteralType { class: *class })
}
Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)),
Type::BooleanLiteral(_) => KnownClass::Bool.to_class(db),
Type::BytesLiteral(_) => KnownClass::Bytes.to_class(db),
Expand Down Expand Up @@ -1911,8 +1915,12 @@ pub enum KnownFunction {
IsInstance,
}

/// Representation of a runtime class object.
///
/// Does not in itself represent a type,
/// but is used as the inner data for several structs that *do* represent types.
#[salsa::interned]
pub struct ClassType<'db> {
pub struct Class<'db> {
/// Name of the class at definition
#[return_ref]
pub name: ast::name::Name,
Expand All @@ -1923,11 +1931,7 @@ pub struct ClassType<'db> {
}

#[salsa::tracked]
impl<'db> ClassType<'db> {
pub fn to_instance(self) -> Type<'db> {
Type::Instance(InstanceType::anonymous(self))
}

impl<'db> Class<'db> {
/// Return `true` if this class represents `known_class`
pub fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool {
self.known(db) == Some(known_class)
Expand Down Expand Up @@ -2007,7 +2011,7 @@ impl<'db> ClassType<'db> {
///
/// If the MRO could not be accurately resolved, this method falls back to iterating
/// over an MRO that has the class directly inheriting from `Unknown`. Use
/// [`ClassType::try_mro`] if you need to distinguish between the success and failure
/// [`Class::try_mro`] if you need to distinguish between the success and failure
/// cases rather than simply iterating over the inferred resolution order for the class.
///
/// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order
Expand All @@ -2016,7 +2020,7 @@ impl<'db> ClassType<'db> {
}

/// Return `true` if `other` is present in this class's MRO.
pub fn is_subclass_of(self, db: &'db dyn Db, other: ClassType) -> bool {
pub fn is_subclass_of(self, db: &'db dyn Db, other: Class) -> bool {
// `is_subclass_of` is checking the subtype relation, in which gradual types do not
// participate, so we should not return `True` if we find `Any/Unknown` in the MRO.
self.iter_mro(db).contains(&ClassBase::Class(other))
Expand Down Expand Up @@ -2056,26 +2060,54 @@ impl<'db> ClassType<'db> {
/// Returns the inferred type of the class member named `name`.
///
/// Returns [`Symbol::Unbound`] if `name` cannot be found in this class's scope
/// directly. Use [`ClassType::class_member`] if you require a method that will
/// directly. Use [`Class::class_member`] if you require a method that will
/// traverse through the MRO until it finds the member.
pub(crate) fn own_class_member(self, db: &'db dyn Db, name: &str) -> Symbol<'db> {
let scope = self.body_scope(db);
symbol(db, scope, name)
}
}

/// A singleton type representing a single class object at runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClassLiteralType<'db> {
class: Class<'db>,
}

impl<'db> ClassLiteralType<'db> {
fn member(self, db: &'db dyn Db, name: &str) -> Symbol<'db> {
self.class.class_member(db, name)
}
}

impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
fn from(value: ClassLiteralType<'db>) -> Self {
Self::ClassLiteral(value)
}
}

/// A type representing the set of runtime objects which are instances of a certain class.
///
/// Some specific instances of some types need to be treated specially by the type system:
/// for example, various special forms are instances of `typing._SpecialForm`,
/// but need to be handled differently in annotations. These special instances are marked as such
/// using the `known` field on this struct.
///
/// Note that, for example, `InstanceType { class: typing._SpecialForm, known: None }`
/// is a supertype of `InstanceType { class: typing._SpecialForm, known: KnownInstance::Literal }`.
/// The two types are not disjoint.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct InstanceType<'db> {
class: ClassType<'db>,
class: Class<'db>,
known: Option<KnownInstance>,
}

impl<'db> InstanceType<'db> {
pub fn anonymous(class: ClassType<'db>) -> Self {
pub fn anonymous(class: Class<'db>) -> Self {
Self { class, known: None }
}

pub fn known(class: ClassType<'db>, known: KnownInstance) -> Self {
pub fn known(class: Class<'db>, known: KnownInstance) -> Self {
Self {
class,
known: Some(known),
Expand All @@ -2085,6 +2117,17 @@ impl<'db> InstanceType<'db> {
pub fn is_known(&self, known_instance: KnownInstance) -> bool {
self.known == Some(known_instance)
}

/// Return `true` if members of this type are instances of the class `class` at runtime.
pub fn is_instance_of(self, db: &'db dyn Db, class: Class<'db>) -> bool {
self.class.is_subclass_of(db, class)
}
}

impl<'db> From<InstanceType<'db>> for Type<'db> {
fn from(value: InstanceType<'db>) -> Self {
Self::Instance(value)
}
}

#[salsa::interned]
Expand Down
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/src/types/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;

use crate::types::Type;
use crate::types::{ClassLiteralType, Type};
use crate::Db;

#[derive(Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -209,7 +209,7 @@ impl<'db> TypeCheckDiagnosticsBuilder<'db> {
assigned_ty: Type<'db>,
) {
match declared_ty {
Type::ClassLiteral(class) => {
Type::ClassLiteral(ClassLiteralType { class }) => {
self.add(node, "invalid-assignment", format_args!(
"Implicit shadowing of class `{}`; annotate to make it explicit if this is intentional",
class.name(self.db)));
Expand Down
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/src/types/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ruff_db::display::FormatterJoinExtension;
use ruff_python_ast::str::Quote;
use ruff_python_literal::escape::AsciiEscape;

use crate::types::{InstanceType, IntersectionType, KnownClass, Type, UnionType};
use crate::types::{ClassLiteralType, InstanceType, IntersectionType, KnownClass, Type, UnionType};
use crate::Db;
use rustc_hash::FxHashMap;

Expand Down Expand Up @@ -76,7 +76,7 @@ impl Display for DisplayRepresentation<'_> {
write!(f, "<module '{:?}'>", file.path(self.db))
}
// TODO functions and classes should display using a fully qualified name
Type::ClassLiteral(class) => f.write_str(class.name(self.db)),
Type::ClassLiteral(ClassLiteralType { class }) => f.write_str(class.name(self.db)),
Type::Instance(InstanceType { class, known }) => f.write_str(match known {
Some(super::KnownInstance::Literal) => "Literal",
_ => class.name(self.db),
Expand Down
Loading

0 comments on commit eead549

Please sign in to comment.