Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Evade false positives for inout variable usage #493

Merged
merged 3 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,23 @@ class BBLinearityChecker(ast.NodeVisitor):

scope: Scope
stats: VariableStats[PlaceId]
func_inputs: dict[PlaceId, Variable]
globals: Globals

def check(
self, bb: "CheckedBB[Variable]", is_entry: bool, globals: Globals
self,
bb: "CheckedBB[Variable]",
is_entry: bool,
func_inputs: dict[PlaceId, Variable],
globals: Globals,
) -> Scope:
# Manufacture a scope that holds all places that are live at the start
# of this BB
input_scope = Scope()
for var in bb.sig.input_row:
for place in leaf_places(var):
input_scope.assign(place)
self.func_inputs = func_inputs
self.globals = globals

# Open up a new nested scope to check the BB contents. This way we can track
Expand Down Expand Up @@ -174,6 +180,20 @@ def visit_Assign(self, node: ast.Assign) -> None:
self.visit(node.value)
self._check_assign_targets(node.targets)

# Check that borrowed vars are not being shadowed. This would also be caught by
# the dataflow analysis later, however we can give nicer error messages here.
[target] = node.targets
for tgt in find_nodes(lambda n: isinstance(n, PlaceNode), target):
assert isinstance(tgt, PlaceNode)
if tgt.place.id in self.func_inputs:
entry_place = self.func_inputs[tgt.place.id]
if is_inout_var(entry_place):
raise GuppyError(
f"Assignment shadows borrowed argument `{entry_place}`. "
"Consider assigning to a different name.",
tgt.place.defined_at,
)

def _visit_call_args(self, func_ty: FunctionType, args: list[ast.expr]) -> None:
"""Helper function to check the arguments of a function call.

Expand Down Expand Up @@ -448,30 +468,14 @@ def check_cfg_linearity(
than just variables.
"""
bb_checker = BBLinearityChecker()
func_inputs: dict[PlaceId, Variable] = {v.id: v for v in cfg.entry_bb.sig.input_row}
scopes: dict[BB, Scope] = {
bb: bb_checker.check(bb, is_entry=bb == cfg.entry_bb, globals=globals)
bb: bb_checker.check(
bb, is_entry=bb == cfg.entry_bb, func_inputs=func_inputs, globals=globals
)
for bb in cfg.bbs
}

# Check that borrowed vars are not being shadowed. This would also be caught by
# the dataflow analysis below, however we can give nicer error messages here.
for bb, scope in scopes.items():
if bb == cfg.entry_bb:
# Arguments are assigned in the entry BB, so would yield a false positive
# in the check below. Shadowing in the entry BB will be caught by the check
# in `_check_assign_targets`.
continue
entry_scope = scopes[cfg.entry_bb]
for x, place in scope.vars.items():
if x in entry_scope:
entry_place = entry_scope[x]
if is_inout_var(entry_place):
raise GuppyError(
f"Assignment shadows borrowed argument `{entry_place}`. "
"Consider assigning to a different name.",
place.defined_at,
)

# Mark the borrowed variables as implicitly used in the exit BB
exit_scope = scopes[cfg.exit_bb]
for var in cfg.entry_bb.sig.input_row:
Expand Down
15 changes: 15 additions & 0 deletions tests/integration/test_inout.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,18 @@ def main() -> qubit:
return q

validate(module.compile())

def test_shadow_check(validate):
module = GuppyModule("test")

module.load(quantum, qubit)

@guppy.declare(module)
def foo(i: qubit) -> None: ...

@guppy(module)
def main(i: qubit) -> None:
if True:
foo(i)

validate(module.compile())
Loading