Skip to content

Commit

Permalink
[red-knot] Preliminary support for type aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
sharkdp committed Nov 20, 2024
1 parent 942d6ee commit b9d931c
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 30 deletions.
22 changes: 22 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,28 @@ where
},
);
}
ast::Stmt::TypeAlias(type_alias) => {
let symbol = self.add_symbol(
type_alias
.name
.as_name_expr()
.expect("type alias name is a name expr") // TODO: does the parser guarantee this?
.id
.clone(),
);
self.add_definition(symbol, type_alias);
self.visit_expr(&type_alias.name);

self.with_type_params(
NodeWithScopeRef::TypeAliasTypeParameters(type_alias),
type_alias.type_params.as_ref(),
|builder| {
builder.push_scope(NodeWithScopeRef::TypeAlias(type_alias));
builder.visit_expr(&type_alias.value);
builder.pop_scope()
},
);
}
ast::Stmt::Import(node) => {
for alias in &node.names {
let symbol_name = if let Some(asname) = &alias.asname {
Expand Down
19 changes: 19 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
For(ForStmtDefinitionNodeRef<'a>),
Function(&'a ast::StmtFunctionDef),
Class(&'a ast::StmtClassDef),
TypeAlias(&'a ast::StmtTypeAlias),
NamedExpression(&'a ast::ExprNamed),
Assignment(AssignmentDefinitionNodeRef<'a>),
AnnotatedAssignment(&'a ast::StmtAnnAssign),
Expand All @@ -109,6 +110,12 @@ impl<'a> From<&'a ast::StmtClassDef> for DefinitionNodeRef<'a> {
}
}

impl<'a> From<&'a ast::StmtTypeAlias> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtTypeAlias) -> Self {
Self::TypeAlias(node)
}
}

impl<'a> From<&'a ast::ExprNamed> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::ExprNamed) -> Self {
Self::NamedExpression(node)
Expand Down Expand Up @@ -265,6 +272,9 @@ impl<'db> DefinitionNodeRef<'db> {
DefinitionNodeRef::Class(class) => {
DefinitionKind::Class(AstNodeRef::new(parsed, class))
}
DefinitionNodeRef::TypeAlias(type_alias) => {
DefinitionKind::TypeAlias(AstNodeRef::new(parsed, type_alias))
}
DefinitionNodeRef::NamedExpression(named) => {
DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named))
}
Expand Down Expand Up @@ -358,6 +368,7 @@ impl<'db> DefinitionNodeRef<'db> {
}
Self::Function(node) => node.into(),
Self::Class(node) => node.into(),
Self::TypeAlias(node) => node.into(),
Self::NamedExpression(node) => node.into(),
Self::Assignment(AssignmentDefinitionNodeRef {
value: _,
Expand Down Expand Up @@ -434,6 +445,7 @@ pub enum DefinitionKind<'db> {
ImportFrom(ImportFromDefinitionKind),
Function(AstNodeRef<ast::StmtFunctionDef>),
Class(AstNodeRef<ast::StmtClassDef>),
TypeAlias(AstNodeRef<ast::StmtTypeAlias>),
NamedExpression(AstNodeRef<ast::ExprNamed>),
Assignment(AssignmentDefinitionKind<'db>),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
Expand All @@ -456,6 +468,7 @@ impl DefinitionKind<'_> {
// functions, classes, and imports always bind, and we consider them declarations
DefinitionKind::Function(_)
| DefinitionKind::Class(_)
| DefinitionKind::TypeAlias(_)
| DefinitionKind::Import(_)
| DefinitionKind::ImportFrom(_)
| DefinitionKind::TypeVar(_)
Expand Down Expand Up @@ -682,6 +695,12 @@ impl From<&ast::StmtClassDef> for DefinitionNodeKey {
}
}

impl From<&ast::StmtTypeAlias> for DefinitionNodeKey {
fn from(node: &ast::StmtTypeAlias) -> Self {
Self(NodeKey::from_node(node))
}
}

impl From<&ast::ExprName> for DefinitionNodeKey {
fn from(node: &ast::ExprName) -> Self {
Self(NodeKey::from_node(node))
Expand Down
37 changes: 36 additions & 1 deletion crates/red_knot_python_semantic/src/semantic_index/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ impl<'db> ScopeId<'db> {
NodeWithScopeKind::ClassTypeParameters(_)
| NodeWithScopeKind::FunctionTypeParameters(_)
| NodeWithScopeKind::Function(_)
| NodeWithScopeKind::TypeAlias(_)
| NodeWithScopeKind::ListComprehension(_)
| NodeWithScopeKind::SetComprehension(_)
| NodeWithScopeKind::DictComprehension(_)
Expand All @@ -144,6 +145,12 @@ impl<'db> ScopeId<'db> {
}
NodeWithScopeKind::Function(function)
| NodeWithScopeKind::FunctionTypeParameters(function) => function.name.as_str(),
NodeWithScopeKind::TypeAlias(type_alias)
| NodeWithScopeKind::TypeAliasTypeParameters(type_alias) => type_alias
.name
.as_name_expr()
.map(|name| name.id.as_str())
.unwrap_or("<type alias>"),
NodeWithScopeKind::Lambda(_) => "<lambda>",
NodeWithScopeKind::ListComprehension(_) => "<listcomp>",
NodeWithScopeKind::SetComprehension(_) => "<setcomp>",
Expand Down Expand Up @@ -201,6 +208,7 @@ pub enum ScopeKind {
Class,
Function,
Comprehension,
TypeAlias,
}

impl ScopeKind {
Expand Down Expand Up @@ -326,6 +334,8 @@ pub(crate) enum NodeWithScopeRef<'a> {
Lambda(&'a ast::ExprLambda),
FunctionTypeParameters(&'a ast::StmtFunctionDef),
ClassTypeParameters(&'a ast::StmtClassDef),
TypeAlias(&'a ast::StmtTypeAlias),
TypeAliasTypeParameters(&'a ast::StmtTypeAlias),
ListComprehension(&'a ast::ExprListComp),
SetComprehension(&'a ast::ExprSetComp),
DictComprehension(&'a ast::ExprDictComp),
Expand All @@ -347,6 +357,12 @@ impl NodeWithScopeRef<'_> {
NodeWithScopeRef::Function(function) => {
NodeWithScopeKind::Function(AstNodeRef::new(module, function))
}
NodeWithScopeRef::TypeAlias(type_alias) => {
NodeWithScopeKind::TypeAlias(AstNodeRef::new(module, type_alias))
}
NodeWithScopeRef::TypeAliasTypeParameters(type_alias) => {
NodeWithScopeKind::TypeAliasTypeParameters(AstNodeRef::new(module, type_alias))
}
NodeWithScopeRef::Lambda(lambda) => {
NodeWithScopeKind::Lambda(AstNodeRef::new(module, lambda))
}
Expand Down Expand Up @@ -387,6 +403,12 @@ impl NodeWithScopeRef<'_> {
NodeWithScopeRef::ClassTypeParameters(class) => {
NodeWithScopeKey::ClassTypeParameters(NodeKey::from_node(class))
}
NodeWithScopeRef::TypeAlias(type_alias) => {
NodeWithScopeKey::TypeAliasTypeParameters(NodeKey::from_node(type_alias))
}
NodeWithScopeRef::TypeAliasTypeParameters(type_alias) => {
NodeWithScopeKey::TypeAliasTypeParameters(NodeKey::from_node(type_alias))
}
NodeWithScopeRef::ListComprehension(comprehension) => {
NodeWithScopeKey::ListComprehension(NodeKey::from_node(comprehension))
}
Expand All @@ -411,6 +433,8 @@ pub enum NodeWithScopeKind {
ClassTypeParameters(AstNodeRef<ast::StmtClassDef>),
Function(AstNodeRef<ast::StmtFunctionDef>),
FunctionTypeParameters(AstNodeRef<ast::StmtFunctionDef>),
TypeAliasTypeParameters(AstNodeRef<ast::StmtTypeAlias>),
TypeAlias(AstNodeRef<ast::StmtTypeAlias>),
Lambda(AstNodeRef<ast::ExprLambda>),
ListComprehension(AstNodeRef<ast::ExprListComp>),
SetComprehension(AstNodeRef<ast::ExprSetComp>),
Expand All @@ -424,8 +448,11 @@ impl NodeWithScopeKind {
Self::Module => ScopeKind::Module,
Self::Class(_) => ScopeKind::Class,
Self::Function(_) => ScopeKind::Function,
Self::TypeAlias(_) => ScopeKind::TypeAlias,
Self::Lambda(_) => ScopeKind::Function,
Self::FunctionTypeParameters(_) | Self::ClassTypeParameters(_) => ScopeKind::Annotation,
Self::FunctionTypeParameters(_)
| Self::ClassTypeParameters(_)
| Self::TypeAliasTypeParameters(_) => ScopeKind::Annotation,
Self::ListComprehension(_)
| Self::SetComprehension(_)
| Self::DictComprehension(_)
Expand All @@ -446,6 +473,13 @@ impl NodeWithScopeKind {
_ => panic!("expected function"),
}
}

pub fn expect_type_alias(&self) -> &ast::StmtTypeAlias {
match self {
Self::TypeAlias(type_alias) => type_alias.node(),
_ => panic!("expected type alias"),
}
}
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
Expand All @@ -455,6 +489,7 @@ pub(crate) enum NodeWithScopeKey {
ClassTypeParameters(NodeKey),
Function(NodeKey),
FunctionTypeParameters(NodeKey),
TypeAliasTypeParameters(NodeKey),
Lambda(NodeKey),
ListComprehension(NodeKey),
SetComprehension(NodeKey),
Expand Down
31 changes: 31 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ pub enum Type<'db> {
Instance(InstanceType<'db>),
/// A single Python object that requires special treatment in the type system
KnownInstance(KnownInstanceType<'db>),
/// A type alias, with a name and corresponding type
TypeAlias(TypeAliasType<'db>),
/// The set of objects in any of the types in the union
Union(UnionType<'db>),
/// The set of objects in all of the types in the intersection
Expand Down Expand Up @@ -722,6 +724,9 @@ impl<'db> Type<'db> {
/// wrong `false` answers in some cases.
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
match (self, other) {
(Type::TypeAlias(alias), other) | (other, Type::TypeAlias(alias)) => {
alias.value_type(db).is_disjoint_from(db, other)
}
(Type::Never, _) | (_, Type::Never) => true,

(Type::Any, _) | (_, Type::Any) => false,
Expand Down Expand Up @@ -928,6 +933,7 @@ impl<'db> Type<'db> {
/// for more complicated types that are actually singletons.
pub(crate) fn is_singleton(self, db: &'db dyn Db) -> bool {
match self {
Type::TypeAlias(alias) => alias.value_type(db).is_singleton(db),
Type::Any
| Type::Never
| Type::Unknown
Expand Down Expand Up @@ -985,6 +991,7 @@ impl<'db> Type<'db> {
/// Return true if this type is non-empty and all inhabitants of this type compare equal.
pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool {
match self {
Type::TypeAlias(alias) => alias.value_type(db).is_single_valued(db),
Type::FunctionLiteral(..)
| Type::ModuleLiteral(..)
| Type::ClassLiteral(..)
Expand Down Expand Up @@ -1049,6 +1056,7 @@ impl<'db> Type<'db> {
#[must_use]
pub(crate) fn member(&self, db: &'db dyn Db, name: &str) -> Symbol<'db> {
match self {
Type::TypeAlias(alias) => alias.value_type(db).member(db, name),
Type::Any => Type::Any.into(),
Type::Never => {
// TODO: attribute lookup on Never type
Expand Down Expand Up @@ -1188,6 +1196,7 @@ impl<'db> Type<'db> {
/// when `bool(x)` is called on an object `x`.
fn bool(&self, db: &'db dyn Db) -> Truthiness {
match self {
Type::TypeAlias(alias) => alias.value_type(db).bool(db),
Type::Any | Type::Todo | Type::Never | Type::Unknown => Truthiness::Ambiguous,
Type::FunctionLiteral(_) => Truthiness::AlwaysTrue,
Type::ModuleLiteral(_) => Truthiness::AlwaysTrue,
Expand Down Expand Up @@ -1439,6 +1448,7 @@ impl<'db> Type<'db> {
#[must_use]
pub fn to_instance(&self, db: &'db dyn Db) -> Type<'db> {
match self {
Type::TypeAlias(alias) => alias.value_type(db).to_instance(db),
Type::Any => Type::Any,
Type::Todo => Type::Todo,
Type::Unknown => Type::Unknown,
Expand Down Expand Up @@ -1478,6 +1488,7 @@ impl<'db> Type<'db> {
Type::Unknown => Type::Unknown,
// TODO map this to a new `Type::TypeVar` variant
Type::KnownInstance(KnownInstanceType::TypeVar(_)) => *self,
Type::TypeAlias(alias) => alias.value_type(db).in_type_expression(db),
_ => Type::Todo,
}
}
Expand Down Expand Up @@ -1525,6 +1536,7 @@ impl<'db> Type<'db> {
#[must_use]
pub fn to_meta_type(&self, db: &'db dyn Db) -> Type<'db> {
match self {
Type::TypeAlias(alias) => alias.value_type(db).to_meta_type(db),
Type::Never => Type::Never,
Type::Instance(InstanceType { class }) => {
Type::SubclassOf(SubclassOfType { class: *class })
Expand Down Expand Up @@ -2713,6 +2725,25 @@ impl<'db> Class<'db> {
}
}

#[salsa::interned]
pub struct TypeAliasType<'db> {
#[return_ref]
pub name: ast::name::Name,

rhs_scope: ScopeId<'db>,
}

impl TypeAliasType<'_> {
pub fn value_type(self, db: &dyn Db) -> Type {
let scope = self.rhs_scope(db);

let type_alias_stmt_node = scope.node(db).expect_type_alias();
let definition = semantic_index(db, scope.file(db)).definition(type_alias_stmt_node);

definition_expression_ty(db, definition, &type_alias_stmt_node.value)
}
}

/// Either the explicit `metaclass=` keyword of the class, or the inferred metaclass of one of its base classes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct MetaclassCandidate<'db> {
Expand Down
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/types/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ impl Display for DisplayRepresentation<'_> {
}
Type::KnownInstance(known_instance) => f.write_str(known_instance.repr(self.db)),
Type::FunctionLiteral(function) => f.write_str(function.name(self.db)),
Type::TypeAlias(alias) => f.write_str(alias.name(self.db)),
Type::Union(union) => union.display(self.db).fmt(f),
Type::Intersection(intersection) => intersection.display(self.db).fmt(f),
Type::IntLiteral(n) => n.fmt(f),
Expand Down
Loading

0 comments on commit b9d931c

Please sign in to comment.