Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: New type representation with parameters #174

Merged
merged 11 commits into from
Mar 19, 2024
12 changes: 6 additions & 6 deletions guppylang/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast

if TYPE_CHECKING:
from guppylang.gtypes import GuppyType
from guppylang.tys.ty import Type

AstNode = (
ast.AST
Expand Down Expand Up @@ -286,24 +286,24 @@ def with_loc(loc: ast.AST, node: A) -> A:
return node


def with_type(ty: "GuppyType", node: A) -> A:
def with_type(ty: "Type", node: A) -> A:
"""Annotates an AST node with a type."""
node.type = ty # type: ignore[attr-defined]
return node


def get_type_opt(node: AstNode) -> Optional["GuppyType"]:
def get_type_opt(node: AstNode) -> Optional["Type"]:
"""Tries to retrieve a type annotation from an AST node."""
from guppylang.gtypes import GuppyType
from guppylang.tys.ty import Type, TypeBase

try:
ty = node.type # type: ignore[union-attr]
return ty if isinstance(ty, GuppyType) else None
return cast(Type, ty) if isinstance(ty, TypeBase) else None
except AttributeError:
return None


def get_type(node: AstNode) -> "GuppyType":
def get_type(node: AstNode) -> "Type":
"""Retrieve a type annotation from an AST node.

Fails if the node is not annotated.
Expand Down
4 changes: 2 additions & 2 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from guppylang.cfg.cfg import CFG
from guppylang.checker.core import Globals
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.gtypes import NoneType
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
Expand All @@ -26,6 +25,7 @@
NestedFunctionDef,
PyExpr,
)
from guppylang.tys.ty import NoneType

# In order to build expressions, need an endless stream of unique temporary variables
# to store intermediate results
Expand Down Expand Up @@ -213,7 +213,7 @@ def visit_FunctionDef(
from guppylang.checker.func_checker import check_signature

func_ty = check_signature(node, self.globals)
returns_none = isinstance(func_ty.returns, NoneType)
returns_none = isinstance(func_ty.output, NoneType)
cfg = CFGBuilder().build(node.body, returns_none, self.globals)

new_node = NestedFunctionDef(
Expand Down
12 changes: 6 additions & 6 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from guppylang.checker.expr_checker import ExprSynthesizer, to_bool
from guppylang.checker.stmt_checker import StmtChecker
from guppylang.error import GuppyError
from guppylang.gtypes import GuppyType
from guppylang.tys.ty import Type

VarRow = Sequence[Variable]

Expand Down Expand Up @@ -44,17 +44,17 @@ class CheckedBB(BB):


class CheckedCFG(BaseCFG[CheckedBB]):
input_tys: list[GuppyType]
output_ty: GuppyType
input_tys: list[Type]
output_ty: Type

def __init__(self, input_tys: list[GuppyType], output_ty: GuppyType) -> None:
def __init__(self, input_tys: list[Type], output_ty: Type) -> None:
super().__init__([])
self.input_tys = input_tys
self.output_ty = output_ty


def check_cfg(
cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals
cfg: CFG, inputs: VarRow, return_ty: Type, globals: Globals
) -> CheckedCFG:
"""Type checks a control-flow graph.

Expand Down Expand Up @@ -121,7 +121,7 @@ def check_bb(
bb: BB,
checked_cfg: CheckedCFG,
inputs: VarRow,
return_ty: GuppyType,
return_ty: Type,
globals: Globals,
) -> CheckedBB:
cfg = bb.containing_cfg
Expand Down
80 changes: 53 additions & 27 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,29 @@
from dataclasses import dataclass
from typing import Any, NamedTuple

from typing_extensions import assert_never

from guppylang.ast_util import AstNode, name_nodes_in_ast
from guppylang.gtypes import (
BoolType,
from guppylang.tys.definition import (
TypeDef,
bool_type_def,
callable_type_def,
linst_type_def,
list_type_def,
none_type_def,
tuple_type_def,
)
from guppylang.tys.param import Parameter
from guppylang.tys.subst import Subst
from guppylang.tys.ty import (
BoundTypeVar,
ExistentialTypeVar,
FunctionType,
GuppyType,
LinstType,
ListType,
NoneType,
Subst,
OpaqueType,
SumType,
TupleType,
Type,
)


Expand All @@ -25,7 +37,7 @@ class Variable:
"""Class holding data associated with a variable."""

name: str
ty: GuppyType
ty: Type
defined_at: AstNode | None
used: AstNode | None

Expand All @@ -38,14 +50,14 @@ class CallableVariable(ABC, Variable):

@abstractmethod
def check_call(
self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context"
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: "Context"
) -> tuple[ast.expr, Subst]:
"""Checks the return type of a function call against a given type."""

@abstractmethod
def synthesize_call(
self, args: list[ast.expr], node: AstNode, ctx: "Context"
) -> tuple[ast.expr, GuppyType]:
) -> tuple[ast.expr, Type]:
"""Synthesizes the return type of a function call."""


Expand All @@ -68,30 +80,44 @@ class Globals(NamedTuple):
"""

values: dict[str, Variable]
types: dict[str, type[GuppyType]]
type_vars: dict[str, TypeVarDecl]
type_defs: dict[str, TypeDef]
param_vars: dict[str, Parameter]
python_scope: PyScope

@staticmethod
def default() -> "Globals":
"""Generates a `Globals` instance that is populated with all core types"""
tys: dict[str, type[GuppyType]] = {
FunctionType.name: FunctionType,
TupleType.name: TupleType,
SumType.name: SumType,
NoneType.name: NoneType,
BoolType.name: BoolType,
ListType.name: ListType,
LinstType.name: LinstType,
type_defs = {
"Callable": callable_type_def,
"tuple": tuple_type_def,
"None": none_type_def,
"bool": bool_type_def,
"list": list_type_def,
"linst": linst_type_def,
}
return Globals({}, tys, {}, {})
return Globals({}, type_defs, {}, {})

def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None:
def get_instance_func(self, ty: Type, name: str) -> CallableVariable | None:
"""Looks up an instance function with a given name for a type.

Returns `None` if the name doesn't exist or isn't a function.
"""
qualname = qualified_name(ty.__class__, name)
defn: TypeDef
match ty:
case BoundTypeVar() | ExistentialTypeVar() | SumType():
return None
case FunctionType():
defn = callable_type_def
case OpaqueType() as ty:
defn = ty.defn
case TupleType():
defn = tuple_type_def
case NoneType():
defn = none_type_def
case _:
assert_never(ty)

qualname = qualified_name(defn.name, name)
if qualname in self.values:
val = self.values[qualname]
if isinstance(val, CallableVariable):
Expand All @@ -101,15 +127,15 @@ def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None
def __or__(self, other: "Globals") -> "Globals":
return Globals(
self.values | other.values,
self.types | other.types,
self.type_vars | other.type_vars,
self.type_defs | other.type_defs,
self.param_vars | other.param_vars,
self.python_scope | other.python_scope,
)

def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034
self.values.update(other.values)
self.types.update(other.types)
self.type_vars.update(other.type_vars)
self.type_defs.update(other.type_defs)
self.param_vars.update(other.param_vars)
return self


Expand Down Expand Up @@ -203,7 +229,7 @@ def __contains__(self, key: object) -> bool:
return super().__contains__(key)


def qualified_name(ty: type[GuppyType] | str, name: str) -> str:
def qualified_name(ty: TypeDef | str, name: str) -> str:
"""Returns a qualified name for an instance function on a type."""
ty_name = ty if isinstance(ty, str) else ty.name
return f"{ty_name}.{name}"
Loading
Loading