Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] resolve class members #11256

Merged
merged 4 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub(crate) struct Scope {
name: Name,
kind: ScopeKind,
child_scopes: Vec<ScopeId>,
// symbol IDs, hashed by symbol name
/// symbol IDs, hashed by symbol name
symbols_by_name: Map<SymbolId, ()>,
}

Expand Down Expand Up @@ -107,6 +107,7 @@ bitflags! {
pub(crate) struct Symbol {
name: Name,
flags: SymbolFlags,
scope_id: ScopeId,
// kind: Kind,
}

Expand Down Expand Up @@ -141,7 +142,7 @@ pub(crate) enum Definition {
// the small amount of information we need from the AST.
Import(ImportDefinition),
ImportFrom(ImportFromDefinition),
ClassDef(TypedNodeKey<ast::StmtClassDef>),
ClassDef(ClassDefinition),
FunctionDef(TypedNodeKey<ast::StmtFunctionDef>),
Assignment(TypedNodeKey<ast::StmtAssign>),
AnnotatedAssignment(TypedNodeKey<ast::StmtAnnAssign>),
Expand Down Expand Up @@ -174,6 +175,12 @@ impl ImportFromDefinition {
}
}

#[derive(Clone, Debug)]
pub(crate) struct ClassDefinition {
pub(crate) node_key: TypedNodeKey<ast::StmtClassDef>,
pub(crate) scope_id: ScopeId,
}

#[derive(Debug, Clone)]
pub enum Dependency {
Module(ModuleName),
Expand Down Expand Up @@ -332,7 +339,11 @@ impl SymbolTable {
*entry.key()
}
RawEntryMut::Vacant(entry) => {
let id = self.symbols_by_id.push(Symbol { name, flags });
let id = self.symbols_by_id.push(Symbol {
name,
flags,
scope_id,
});
entry.insert_with_hasher(hash, id, (), |_| hash);
id
}
Expand Down Expand Up @@ -482,8 +493,8 @@ impl SymbolTableBuilder {
&mut self,
name: &str,
params: &Option<Box<ast::TypeParams>>,
nested: impl FnOnce(&mut Self),
) {
nested: impl FnOnce(&mut Self) -> ScopeId,
) -> ScopeId {
if let Some(type_params) = params {
self.push_scope(self.cur_scope(), name, ScopeKind::Annotation);
for type_param in &type_params.type_params {
Expand All @@ -495,10 +506,11 @@ impl SymbolTableBuilder {
self.add_or_update_symbol(name, SymbolFlags::IS_DEFINED);
}
}
nested(self);
let scope_id = nested(self);
if params.is_some() {
self.pop_scope();
}
scope_id
}
}

Expand All @@ -525,21 +537,28 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
// TODO need to capture more definition statements here
match stmt {
ast::Stmt::ClassDef(node) => {
let def = Definition::ClassDef(TypedNodeKey::from_node(node));
self.add_or_update_symbol_with_def(&node.name, def);
self.with_type_params(&node.name, &node.type_params, |builder| {
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Class);
let scope_id = self.with_type_params(&node.name, &node.type_params, |builder| {
let scope_id =
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Class);
carljm marked this conversation as resolved.
Show resolved Hide resolved
ast::visitor::preorder::walk_stmt(builder, stmt);
builder.pop_scope();
scope_id
});
let def = Definition::ClassDef(ClassDefinition {
node_key: TypedNodeKey::from_node(node),
scope_id,
});
self.add_or_update_symbol_with_def(&node.name, def);
}
ast::Stmt::FunctionDef(node) => {
let def = Definition::FunctionDef(TypedNodeKey::from_node(node));
self.add_or_update_symbol_with_def(&node.name, def);
self.with_type_params(&node.name, &node.type_params, |builder| {
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Function);
let scope_id =
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Function);
ast::visitor::preorder::walk_stmt(builder, stmt);
builder.pop_scope();
scope_id
});
}
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
Expand Down
36 changes: 26 additions & 10 deletions crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(dead_code)]
use crate::ast_ids::NodeKey;
use crate::files::FileId;
use crate::symbols::SymbolId;
use crate::symbols::{ScopeId, SymbolId};
use crate::{FxDashMap, FxIndexSet, Name};
use ruff_index::{newtype_index, IndexVec};
use rustc_hash::FxHashMap;
Expand Down Expand Up @@ -123,8 +123,15 @@ impl TypeStore {
self.add_or_get_module(file_id).add_function(name)
}

fn add_class(&self, file_id: FileId, name: &str, bases: Vec<Type>) -> ClassTypeId {
self.add_or_get_module(file_id).add_class(name, bases)
fn add_class(
&self,
file_id: FileId,
name: &str,
scope_id: ScopeId,
bases: Vec<Type>,
) -> ClassTypeId {
self.add_or_get_module(file_id)
.add_class(name, scope_id, bases)
}

fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
Expand Down Expand Up @@ -316,9 +323,11 @@ impl ModuleTypeStore {
}
}

fn add_class(&mut self, name: &str, bases: Vec<Type>) -> ClassTypeId {
fn add_class(&mut self, name: &str, scope_id: ScopeId, bases: Vec<Type>) -> ClassTypeId {
let class_id = self.classes.push(ClassType {
name: Name::new(name),
file_id: self.file_id,
scope_id,
// TODO: if no bases are given, that should imply [object]
bases,
});
Expand Down Expand Up @@ -403,7 +412,13 @@ impl std::fmt::Display for DisplayType<'_> {

#[derive(Debug)]
pub(crate) struct ClassType {
/// Name of the class at definition
name: Name,
/// FileId in which the class was defined
carljm marked this conversation as resolved.
Show resolved Hide resolved
pub(crate) file_id: FileId,
carljm marked this conversation as resolved.
Show resolved Hide resolved
/// ScopeId of the class body
pub(crate) scope_id: ScopeId,
/// Types of all class bases
bases: Vec<Type>,
}

Expand Down Expand Up @@ -489,6 +504,7 @@ impl IntersectionType {
#[cfg(test)]
mod tests {
use crate::files::Files;
use crate::symbols::SymbolTable;
use crate::types::{Type, TypeStore};
use crate::FxIndexSet;
use std::path::Path;
Expand All @@ -498,7 +514,7 @@ mod tests {
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let id = store.add_class(file_id, "C", Vec::new());
let id = store.add_class(file_id, "C", SymbolTable::root_scope_id(), Vec::new());
assert_eq!(store.get_class(id).name(), "C");
let inst = Type::Instance(id);
assert_eq!(format!("{}", inst.display(&store)), "C");
Expand All @@ -520,8 +536,8 @@ mod tests {
let mut store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", Vec::new());
let c2 = store.add_class(file_id, "C2", Vec::new());
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
let c2 = store.add_class(file_id, "C2", SymbolTable::root_scope_id(), Vec::new());
let elems = vec![Type::Instance(c1), Type::Instance(c2)];
let id = store.add_union(file_id, &elems);
assert_eq!(
Expand All @@ -537,9 +553,9 @@ mod tests {
let mut store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", Vec::new());
let c2 = store.add_class(file_id, "C2", Vec::new());
let c3 = store.add_class(file_id, "C3", Vec::new());
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
let c2 = store.add_class(file_id, "C2", SymbolTable::root_scope_id(), Vec::new());
let c3 = store.add_class(file_id, "C3", SymbolTable::root_scope_id(), Vec::new());
let pos = vec![Type::Instance(c1), Type::Instance(c2)];
let neg = vec![Type::Instance(c3)];
let id = store.add_intersection(file_id, &pos, &neg);
Expand Down
65 changes: 60 additions & 5 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use ruff_python_ast::AstNode;

use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
use crate::module::ModuleName;
use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
use crate::types::Type;
use crate::FileId;
use crate::symbols::{ClassDefinition, Definition, ImportFromDefinition, SymbolId};
use crate::types::{ClassType, ClassTypeId, Type};
use crate::{FileId, Name};
use ruff_python_ast as ast;

// FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`.
Expand Down Expand Up @@ -51,7 +51,7 @@ where
Type::Unknown
}
}
Definition::ClassDef(node_key) => {
Definition::ClassDef(ClassDefinition { node_key, scope_id }) => {
if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
ty
} else {
Expand All @@ -65,7 +65,8 @@ where
bases.push(infer_expr_type(db, file_id, base)?);
}

let ty = Type::Class(type_store.add_class(file_id, &node.name.id, bases));
let ty =
Type::Class(type_store.add_class(file_id, &node.name.id, *scope_id, bases));
type_store.cache_node_type(file_id, *node_key.erased(), ty);
ty
}
Expand Down Expand Up @@ -119,12 +120,30 @@ where
}
}

fn get_class_member<Db>(db: &Db, class_id: ClassTypeId, name: &Name) -> QueryResult<Option<Type>>
carljm marked this conversation as resolved.
Show resolved Hide resolved
carljm marked this conversation as resolved.
Show resolved Hide resolved
where
Db: SemanticDb + HasJar<SemanticJar>,
{
let jar = db.jar()?;
let ClassType {
scope_id, file_id, ..
} = *jar.type_store.get_class(class_id);
let table = db.symbol_table(file_id)?;
if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) {
Ok(Some(infer_symbol_type(db, file_id, symbol_id)?))
} else {
Ok(None)
}
}

#[cfg(test)]
mod tests {
use super::get_class_member;
use crate::db::tests::TestDb;
use crate::db::{HasJar, SemanticDb, SemanticJar};
use crate::module::{ModuleName, ModuleSearchPath, ModuleSearchPathKind};
use crate::types::Type;
use crate::Name;

// TODO with virtual filesystem we shouldn't have to write files to disk for these
// tests
Expand Down Expand Up @@ -214,4 +233,40 @@ mod tests {

Ok(())
}

#[test]
fn resolve_method() -> anyhow::Result<()> {
let case = create_test()?;
let db = &case.db;

let path = case.src.path().join("mod.py");
std::fs::write(path, "class C:\n def f(self): pass")?;
let file = db
.resolve_module(ModuleName::new("mod"))?
.expect("module should be found")
.path(db)?
.file();
let syms = db.symbol_table(file)?;
let sym = syms
.root_symbol_id_by_name("C")
.expect("C symbol should be found");

let ty = db.infer_symbol_type(file, sym)?;

let Type::Class(class_id) = ty else {
panic!("C is not a Class");
};

let member_ty = get_class_member(db, class_id, &Name::new("f")).expect("C.f to resolve");

let Some(Type::Function(func_id)) = member_ty else {
panic!("C.f is not a Function");
};

let jar = HasJar::<SemanticJar>::jar(db)?;
let function = jar.type_store.get_function(func_id);
assert_eq!(function.name(), "f");

Ok(())
}
}
Loading