Skip to content

Commit

Permalink
feat: Skip checking of redefined functions (#457)
Browse files Browse the repository at this point in the history
Closes #431.

* Make `Globals.defs` mutable
* Add `GuppyModule.unregister` function to remove definitions from a
module
* Redefining of classes doesn't work because of #456
  • Loading branch information
mark-koch authored Sep 9, 2024
1 parent 9d35a78 commit 7f9ad32
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 22 deletions.
12 changes: 2 additions & 10 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
import copy
import itertools
from collections.abc import Iterable, Iterator, Mapping
from collections.abc import Iterable, Iterator
from dataclasses import dataclass, replace
from functools import cached_property
from typing import (
Expand Down Expand Up @@ -204,7 +204,7 @@ class Globals:
user names to definition id and instance implementation id.
"""

defs: Mapping[DefId, Definition]
defs: dict[DefId, Definition]

names: dict[str, DefId]
impls: dict[DefId, dict[str, DefId]]
Expand Down Expand Up @@ -270,14 +270,6 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
return defn
return None

def update_defs(self, defs: Mapping[DefId, Definition]) -> "Globals":
"""Returns a new `Globals` instance with updated definitions.
This method is needed since in-place definition updates are impossible as the
definition map is immutable.
"""
return Globals({**self.defs, **defs}, self.names, self.impls, self.python_scope)

def __or__(self, other: "Globals") -> "Globals":
impls = {
def_id: self.impls.get(def_id, {}) | other.impls.get(def_id, {})
Expand Down
4 changes: 2 additions & 2 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import textwrap
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Any

from hugr import Wire, ops
Expand Down Expand Up @@ -314,6 +314,6 @@ def check_instantiate(
**globals.defs,
defn.id: DummyStructDef(defn.id, defn.name, defn.defined_at),
}
dummy_globals = globals.update_defs(dummy_defs)
dummy_globals = replace(globals, defs=globals.defs | dummy_defs)
for field in defn.fields:
type_from_ast(field.type_ast, dummy_globals, {})
31 changes: 26 additions & 5 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def load(
def_id: all_checked_defs[def_id]
for def_id in all_globals.impls[def_id].values()
}
self._imported_globals |= Globals(defs, names, impls, {})
self._imported_globals |= Globals(dict(defs), names, impls, {})
self._imported_checked_defs |= defs

# We also need to include transitively imported checked definitions so we can
Expand Down Expand Up @@ -184,6 +184,10 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None:
if self._instance_func_buffer is not None and not isinstance(defn, TypeDef):
self._instance_func_buffer[defn.name] = defn
else:
# If this overrides an already defined name, we need to purge the old
# definition to avoid checking it later
if self.contains(defn.name):
self.unregister(self._globals[defn.name])
if isinstance(defn, TypeDef | ParamDef):
self._raw_type_defs[defn.id] = defn
else:
Expand All @@ -193,6 +197,7 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None:
self._globals.impls[instance.id][defn.name] = defn.id
else:
self._globals.names[defn.name] = defn.id
self._globals.defs[defn.id] = defn

def register_func_def(
self, f: PyFunc, instance: TypeDef | None = None
Expand All @@ -217,6 +222,22 @@ def _register_buffered_instance_funcs(self, instance: TypeDef) -> None:
for defn in buffer.values():
self.register_def(defn, instance)

def unregister(self, defn: Definition) -> None:
"""Removes a definition from this module.
Also removes all methods when unregistering a type.
"""
self._checked = False
self._compiled = False
self._compiled_hugr = None
self._raw_defs.pop(defn.id, None)
self._raw_type_defs.pop(defn.id, None)
self._globals.defs.pop(defn.id, None)
self._globals.names.pop(defn.name, None)
if impls := self._globals.impls.pop(defn.id, None):
for impl in impls.values():
self.unregister(self._globals[impl])

@property
def checked(self) -> bool:
return self._checked
Expand All @@ -230,12 +251,12 @@ def _check_defs(
raw_defs: Mapping[DefId, RawDef], globals: Globals
) -> dict[DefId, CheckedDef]:
"""Helper method to parse and check raw definitions."""
raw_globals = globals | Globals(raw_defs, {}, {}, {})
raw_globals = globals | Globals(dict(raw_defs), {}, {}, {})
parsed = {
def_id: defn.parse(raw_globals) if isinstance(defn, ParsableDef) else defn
for def_id, defn in raw_defs.items()
}
parsed_globals = globals | Globals(parsed, {}, {}, {})
parsed_globals = globals | Globals(dict(parsed), {}, {}, {})
return {
def_id: (
defn.check(parsed_globals) if isinstance(defn, CheckableDef) else defn
Expand Down Expand Up @@ -265,7 +286,7 @@ def check(self) -> None:
type_defs = self._check_defs(
self._raw_type_defs, self._imported_globals | self._globals
)
self._globals = self._globals.update_defs(type_defs)
self._globals.defs.update(type_defs)

# Collect auto-generated methods
generated: dict[DefId, RawDef] = {}
Expand All @@ -280,7 +301,7 @@ def check(self) -> None:
other_defs = self._check_defs(
self._raw_defs | generated, self._imported_globals | self._globals
)
self._globals = self._globals.update_defs(other_defs)
self._globals.defs.update(other_defs)
self._checked_defs = type_defs | other_defs
self._checked = True

Expand Down
74 changes: 69 additions & 5 deletions tests/integration/test_redefinition.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,83 @@
import pytest

from guppylang.decorator import guppy
from guppylang.module import GuppyModule

import guppylang.prelude.quantum as quantum


def test_redefinition(validate):
def test_func_redefinition(validate):
module = GuppyModule("test")
module.load_all(quantum)

@guppy(module)
def test() -> bool:
return True
return 5 # Type error on purpose

@guppy(module)
def test() -> bool: # noqa: F811
return False

validate(module.compile())


def test_method_redefinition(validate):
module = GuppyModule("test")

@guppy.struct(module)
class Test:
x: int

@guppy(module)
def foo(self: "Test") -> int:
return 1.0 # Type error on purpose

@guppy(module)
def foo(self: "Test") -> int:
return 1 # Type error on purpose

validate(module.compile())


@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/456")
def test_struct_redefinition(validate):
module = GuppyModule("test")

@guppy.struct(module)
class Test:
x: "blah" # Non-existing type

@guppy.struct(module)
class Test:
y: int

@guppy(module)
def main(x: int) -> Test:
return Test(x)

validate(module.compile())


@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/456")
def test_struct_method_redefinition(validate):
module = GuppyModule("test")

@guppy.struct(module)
class Test:
x: int

@guppy(module)
def foo(self: "Test") -> int:
return 1.0 # Type error on purpose

@guppy.struct(module)
class Test:
y: int

@guppy(module)
def bar(self: "Test") -> int:
return self.y

@guppy(module)
def main(x: int) -> int:
return Test(x).bar()

validate(module.compile())

0 comments on commit 7f9ad32

Please sign in to comment.