Skip to content

Commit

Permalink
[red-knot] Preliminary support for type aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
sharkdp committed Nov 15, 2024
1 parent 9f3235a commit f56e867
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 13 deletions.
21 changes: 21 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,27 @@ where
},
);
}
ast::Stmt::TypeAlias(type_alias) => {
let symbol = self.add_symbol(
type_alias
.name
.as_name_expr()
.expect("type alias name is a name expr")
.id
.clone(),
);
self.add_definition(symbol, type_alias);
self.visit_expr(&type_alias.name);

self.with_type_params(
NodeWithScopeRef::TypeAliasTypeParameters(type_alias),
type_alias.type_params.as_ref(),
|builder| {
builder.visit_expr(&type_alias.value);
builder.current_scope()
},
);
}
ast::Stmt::Import(node) => {
for alias in &node.names {
let symbol_name = if let Some(asname) = &alias.asname {
Expand Down
19 changes: 19 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
For(ForStmtDefinitionNodeRef<'a>),
Function(&'a ast::StmtFunctionDef),
Class(&'a ast::StmtClassDef),
TypeAlias(&'a ast::StmtTypeAlias),
NamedExpression(&'a ast::ExprNamed),
Assignment(AssignmentDefinitionNodeRef<'a>),
AnnotatedAssignment(&'a ast::StmtAnnAssign),
Expand All @@ -109,6 +110,12 @@ impl<'a> From<&'a ast::StmtClassDef> for DefinitionNodeRef<'a> {
}
}

impl<'a> From<&'a ast::StmtTypeAlias> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtTypeAlias) -> Self {
Self::TypeAlias(node)
}
}

impl<'a> From<&'a ast::ExprNamed> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::ExprNamed) -> Self {
Self::NamedExpression(node)
Expand Down Expand Up @@ -265,6 +272,9 @@ impl<'db> DefinitionNodeRef<'db> {
DefinitionNodeRef::Class(class) => {
DefinitionKind::Class(AstNodeRef::new(parsed, class))
}
DefinitionNodeRef::TypeAlias(type_alias) => {
DefinitionKind::TypeAlias(AstNodeRef::new(parsed, type_alias))
}
DefinitionNodeRef::NamedExpression(named) => {
DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named))
}
Expand Down Expand Up @@ -358,6 +368,7 @@ impl<'db> DefinitionNodeRef<'db> {
}
Self::Function(node) => node.into(),
Self::Class(node) => node.into(),
Self::TypeAlias(node) => node.into(),
Self::NamedExpression(node) => node.into(),
Self::Assignment(AssignmentDefinitionNodeRef {
value: _,
Expand Down Expand Up @@ -434,6 +445,7 @@ pub enum DefinitionKind<'db> {
ImportFrom(ImportFromDefinitionKind),
Function(AstNodeRef<ast::StmtFunctionDef>),
Class(AstNodeRef<ast::StmtClassDef>),
TypeAlias(AstNodeRef<ast::StmtTypeAlias>),
NamedExpression(AstNodeRef<ast::ExprNamed>),
Assignment(AssignmentDefinitionKind<'db>),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
Expand All @@ -456,6 +468,7 @@ impl DefinitionKind<'_> {
// functions, classes, and imports always bind, and we consider them declarations
DefinitionKind::Function(_)
| DefinitionKind::Class(_)
| DefinitionKind::TypeAlias(_)
| DefinitionKind::Import(_)
| DefinitionKind::ImportFrom(_)
| DefinitionKind::TypeVar(_)
Expand Down Expand Up @@ -682,6 +695,12 @@ impl From<&ast::StmtClassDef> for DefinitionNodeKey {
}
}

impl From<&ast::StmtTypeAlias> for DefinitionNodeKey {
fn from(node: &ast::StmtTypeAlias) -> Self {
Self(NodeKey::from_node(node))
}
}

impl From<&ast::ExprName> for DefinitionNodeKey {
fn from(node: &ast::ExprName) -> Self {
Self(NodeKey::from_node(node))
Expand Down
19 changes: 18 additions & 1 deletion crates/red_knot_python_semantic/src/semantic_index/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ impl<'db> ScopeId<'db> {
NodeWithScopeKind::Class(class) | NodeWithScopeKind::ClassTypeParameters(class) => {
class.name.as_str()
}
NodeWithScopeKind::TypeAliasTypeParameters(type_alias) => type_alias
.name
.as_name_expr()
.map(|name| name.id.as_str())
.unwrap_or("<type alias>"),
NodeWithScopeKind::Function(function)
| NodeWithScopeKind::FunctionTypeParameters(function) => function.name.as_str(),
NodeWithScopeKind::Lambda(_) => "<lambda>",
Expand Down Expand Up @@ -201,6 +206,7 @@ pub enum ScopeKind {
Class,
Function,
Comprehension,
TypeAlias,
}

impl ScopeKind {
Expand Down Expand Up @@ -326,6 +332,7 @@ pub(crate) enum NodeWithScopeRef<'a> {
Lambda(&'a ast::ExprLambda),
FunctionTypeParameters(&'a ast::StmtFunctionDef),
ClassTypeParameters(&'a ast::StmtClassDef),
TypeAliasTypeParameters(&'a ast::StmtTypeAlias),
ListComprehension(&'a ast::ExprListComp),
SetComprehension(&'a ast::ExprSetComp),
DictComprehension(&'a ast::ExprDictComp),
Expand All @@ -347,6 +354,9 @@ impl NodeWithScopeRef<'_> {
NodeWithScopeRef::Function(function) => {
NodeWithScopeKind::Function(AstNodeRef::new(module, function))
}
NodeWithScopeRef::TypeAliasTypeParameters(type_alias) => {
NodeWithScopeKind::TypeAliasTypeParameters(AstNodeRef::new(module, type_alias))
}
NodeWithScopeRef::Lambda(lambda) => {
NodeWithScopeKind::Lambda(AstNodeRef::new(module, lambda))
}
Expand Down Expand Up @@ -387,6 +397,9 @@ impl NodeWithScopeRef<'_> {
NodeWithScopeRef::ClassTypeParameters(class) => {
NodeWithScopeKey::ClassTypeParameters(NodeKey::from_node(class))
}
NodeWithScopeRef::TypeAliasTypeParameters(type_alias) => {
NodeWithScopeKey::TypeAliasTypeParameters(NodeKey::from_node(type_alias))
}
NodeWithScopeRef::ListComprehension(comprehension) => {
NodeWithScopeKey::ListComprehension(NodeKey::from_node(comprehension))
}
Expand All @@ -411,6 +424,7 @@ pub enum NodeWithScopeKind {
ClassTypeParameters(AstNodeRef<ast::StmtClassDef>),
Function(AstNodeRef<ast::StmtFunctionDef>),
FunctionTypeParameters(AstNodeRef<ast::StmtFunctionDef>),
TypeAliasTypeParameters(AstNodeRef<ast::StmtTypeAlias>),
Lambda(AstNodeRef<ast::ExprLambda>),
ListComprehension(AstNodeRef<ast::ExprListComp>),
SetComprehension(AstNodeRef<ast::ExprSetComp>),
Expand All @@ -425,7 +439,9 @@ impl NodeWithScopeKind {
Self::Class(_) => ScopeKind::Class,
Self::Function(_) => ScopeKind::Function,
Self::Lambda(_) => ScopeKind::Function,
Self::FunctionTypeParameters(_) | Self::ClassTypeParameters(_) => ScopeKind::Annotation,
Self::FunctionTypeParameters(_)
| Self::ClassTypeParameters(_)
| Self::TypeAliasTypeParameters(_) => ScopeKind::Annotation,
Self::ListComprehension(_)
| Self::SetComprehension(_)
| Self::DictComprehension(_)
Expand Down Expand Up @@ -455,6 +471,7 @@ pub(crate) enum NodeWithScopeKey {
ClassTypeParameters(NodeKey),
Function(NodeKey),
FunctionTypeParameters(NodeKey),
TypeAliasTypeParameters(NodeKey),
Lambda(NodeKey),
ListComprehension(NodeKey),
SetComprehension(NodeKey),
Expand Down
65 changes: 53 additions & 12 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,9 @@ impl<'db> TypeInferenceBuilder<'db> {
NodeWithScopeKind::FunctionTypeParameters(function) => {
self.infer_function_type_params(function.node());
}
NodeWithScopeKind::TypeAliasTypeParameters(type_alias) => {
self.infer_type_alias_type_params(type_alias.node());
}
NodeWithScopeKind::ListComprehension(comprehension) => {
self.infer_list_comprehension_expression_scope(comprehension.node());
}
Expand Down Expand Up @@ -605,6 +608,9 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_function_definition(function.node(), definition);
}
DefinitionKind::Class(class) => self.infer_class_definition(class.node(), definition),
DefinitionKind::TypeAlias(type_alias) => {
self.infer_type_alias_definition(type_alias.node(), definition);
}
DefinitionKind::Import(import) => {
self.infer_import_definition(import.node(), definition);
}
Expand Down Expand Up @@ -847,6 +853,17 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_parameters(&function.parameters);
}

fn infer_type_alias_type_params(&mut self, type_alias: &ast::StmtTypeAlias) {
let _span = tracing::trace_span!("infer_type_alias_type_params").entered();
let type_params = type_alias
.type_params
.as_ref()
.expect("type alias type params scope without type params");

self.infer_type_parameters(type_params);
self.infer_annotation_expression(&type_alias.value, DeferredExpressionState::None);
}

fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) {
self.infer_body(&function.body);
}
Expand Down Expand Up @@ -893,8 +910,10 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

fn infer_definition(&mut self, node: impl Into<DefinitionNodeKey>) {
fn infer_definition(&mut self, node: impl Into<DefinitionNodeKey> + std::fmt::Debug) {
let _span = tracing::trace_span!("infer_definition", node=?node).entered();
let definition = self.index.definition(node);
tracing::trace!(definition=?definition);
let result = infer_definition_types(self.db, definition);
self.extend(result);
}
Expand Down Expand Up @@ -1107,6 +1126,31 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

fn infer_type_alias_definition(
&mut self,
type_alias: &ast::StmtTypeAlias,
definition: Definition<'db>,
) {
let type_alias_ty = Type::Todo;
let _span =
tracing::trace_span!("infer_type_alias_definition", type_alias=?type_alias).entered();

self.infer_expression(&type_alias.name);
if let Some(ref type_params) = type_alias.type_params {
self.infer_type_parameters(type_params);
}

self.infer_annotation_expression(&type_alias.value, DeferredExpressionState::Deferred);

// TODO: _with_binding?
self.add_declaration_with_binding(
type_alias.into(),
definition,
type_alias_ty,
type_alias_ty,
);
}

fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) {
let ast::StmtIf {
range: _,
Expand Down Expand Up @@ -1406,6 +1450,7 @@ impl<'db> TypeInferenceBuilder<'db> {
node: &ast::TypeParamTypeVar,
definition: Definition<'db>,
) {
let _span = tracing::trace_span!("infer_typevar_definition", node=?node).entered();
let ast::TypeParamTypeVar {
range: _,
name,
Expand Down Expand Up @@ -1439,13 +1484,15 @@ impl<'db> TypeInferenceBuilder<'db> {
)),
None => None,
};
tracing::trace!(bound_or_constraint=?bound_or_constraint);
let default_ty = self.infer_optional_type_expression(default.as_deref());
let ty = Type::KnownInstance(KnownInstanceType::TypeVar(TypeVarInstance::new(
self.db,
name.id.clone(),
bound_or_constraint,
default_ty,
)));
tracing::trace!(ty=?ty);
self.add_declaration_with_binding(node.into(), definition, ty, ty);
}

Expand Down Expand Up @@ -1828,17 +1875,9 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_augmented_op(assignment, target_type, value_type)
}

fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {
let ast::StmtTypeAlias {
range: _,
name,
type_params: _,
value,
} = type_alias_statement;
self.infer_expression(value);
self.infer_expression(name);

// TODO: properly handle generic type aliases, which need their own annotation scope
fn infer_type_alias_statement(&mut self, node: &ast::StmtTypeAlias) {
let _span = tracing::trace_span!("infer_type_alias_statement", node=?node).entered();
self.infer_definition(node);
}

fn infer_for_statement(&mut self, for_statement: &ast::StmtFor) {
Expand Down Expand Up @@ -2779,6 +2818,7 @@ impl<'db> TypeInferenceBuilder<'db> {

/// Infer the type of a [`ast::ExprName`] expression, assuming a load context.
fn infer_name_load(&mut self, name: &ast::ExprName) -> Type<'db> {
let _span = tracing::trace_span!("infer_name_load", name = ?name).entered();
let ast::ExprName {
range: _,
id,
Expand Down Expand Up @@ -4114,6 +4154,7 @@ impl<'db> TypeInferenceBuilder<'db> {
annotation: &ast::Expr,
deferred_state: DeferredExpressionState,
) -> Type<'db> {
let _span = tracing::trace_span!("infer_annotation_expression", ?annotation).entered();
let previous_deferred_state = std::mem::replace(&mut self.deferred_state, deferred_state);
let annotation_ty = self.infer_annotation_expression_impl(annotation);
self.deferred_state = previous_deferred_state;
Expand Down
2 changes: 2 additions & 0 deletions crates/red_knot_workspace/tests/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ fn get_workspace_root() -> anyhow::Result<SystemPathBuf> {
#[test]
#[allow(clippy::print_stdout)]
fn corpus_no_panic() -> anyhow::Result<()> {
let _logging = ruff_db::testing::setup_logging();

let root = SystemPathBuf::from("/src");

let system = TestSystem::default();
Expand Down

0 comments on commit f56e867

Please sign in to comment.