Skip to content

Commit

Permalink
feat: Array comprehension
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Nov 4, 2024
1 parent e6890ee commit 37a7118
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
24 changes: 21 additions & 3 deletions guppylang/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from typing_extensions import Self

from guppylang.ast_util import AstNode, name_nodes_in_ast
from guppylang.nodes import DesugaredListComp, NestedFunctionDef, PyExpr
from guppylang.nodes import (
DesugaredArrayComp,
DesugaredGenerator,
DesugaredGeneratorExpr,
DesugaredListComp,
NestedFunctionDef,
PyExpr,
)

if TYPE_CHECKING:
from guppylang.cfg.cfg import BaseCFG
Expand Down Expand Up @@ -144,19 +151,30 @@ def _handle_assign_target(self, lhs: ast.expr, node: ast.stmt) -> None:
self.visit(value)

def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:
self._handle_comprehension(node.generators, node.elt)

def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> None:
self._handle_comprehension([node.generator], node.elt)

def visit_DesugaredGeneratorExpr(self, node: DesugaredGeneratorExpr) -> None:
self._handle_comprehension(node.generators, node.elt)

def _handle_comprehension(
self, generators: list[DesugaredGenerator], elt: ast.expr
) -> None:
# Names bound in the comprehension are only available inside, so we shouldn't
# update `self.stats` with assignments
inner_visitor = VariableVisitor(self.bb)
inner_stats = inner_visitor.stats

# The generators are evaluated left to right
for gen in node.generators:
for gen in generators:
inner_visitor.visit(gen.iter_assign)
inner_visitor.visit(gen.hasnext_assign)
inner_visitor.visit(gen.next_assign)
for cond in gen.ifs:
inner_visitor.visit(cond)
inner_visitor.visit(node.elt)
inner_visitor.visit(elt)

self.stats.used |= {
x: n for x, n in inner_stats.used.items() if x not in self.stats.assigned
Expand Down
19 changes: 15 additions & 4 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from guppylang.experimental import check_lists_enabled
from guppylang.nodes import (
DesugaredGenerator,
DesugaredGeneratorExpr,
DesugaredListComp,
IterEnd,
IterHasNext,
Expand Down Expand Up @@ -304,8 +305,18 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name:
# The final value is stored in the temporary variable
return make_var(tmp, node)

def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
def visit_ListComp(self, node: ast.ListComp) -> DesugaredListComp:
check_lists_enabled(node)
generators, elt = self._build_comprehension(node.generators, node.elt, node)
return with_loc(node, DesugaredListComp(elt=elt, generators=generators))

def visit_GeneratorExp(self, node: ast.GeneratorExp) -> DesugaredGeneratorExpr:
generators, elt = self._build_comprehension(node.generators, node.elt, node)
return with_loc(node, DesugaredGeneratorExpr(elt=elt, generators=generators))

def _build_comprehension(
self, generators: list[ast.comprehension], elt: ast.expr, node: ast.AST
) -> tuple[list[DesugaredGenerator], ast.expr]:
# Check for illegal expressions
illegals = find_nodes(is_illegal_in_list_comp, node)
if illegals:
Expand All @@ -316,7 +327,7 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
# Desugar into statements that create the iterator, check for a next element,
# get the next element, and finalise the iterator.
gens = []
for g in node.generators:
for g in generators:
if g.is_async:
raise GuppyError("Async generators are not supported", g)
g.iter = self.visit(g.iter)
Expand All @@ -339,8 +350,8 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
)
gens.append(desugared)

node.elt = self.visit(node.elt)
return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens))
elt = self.visit(elt)
return gens, elt

def visit_Call(self, node: ast.Call) -> ast.AST:
return is_py_expression(node) or self.generic_visit(node)
Expand Down
28 changes: 28 additions & 0 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ class DesugaredGenerator(ast.expr):
)


class DesugaredGeneratorExpr(ast.expr):
"""A desugared generator expression."""

elt: ast.expr
generators: list[DesugaredGenerator]

_fields = (
"elt",
"generators",
)


class DesugaredListComp(ast.expr):
"""A desugared list comprehension."""

Expand All @@ -211,6 +223,22 @@ class DesugaredListComp(ast.expr):
)


class DesugaredArrayComp(ast.expr):
"""A desugared array comprehension."""

elt: ast.expr
generator: DesugaredGenerator
length: int
elt_ty: Type

_fields = (
"elt",
"generator",
"length",
"elt_ty",
)


class PyExpr(ast.expr):
"""A compile-time evaluated `py(...)` expression."""

Expand Down

0 comments on commit 37a7118

Please sign in to comment.