From 7f9ad32906e909c552025063c062d8b79d43325a Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Mon, 9 Sep 2024 16:29:16 +0200 Subject: [PATCH] feat: Skip checking of redefined functions (#457) Closes #431. * Make `Globals.defs` mutable * Add `GuppyModule.unregister` function to remove definitions from a module * Redefining of classes doesn't work because of #456 --- guppylang/checker/core.py | 12 +---- guppylang/definition/struct.py | 4 +- guppylang/module.py | 31 +++++++++-- tests/integration/test_redefinition.py | 74 ++++++++++++++++++++++++-- 4 files changed, 99 insertions(+), 22 deletions(-) diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 60ec1c4d..142d6873 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -1,7 +1,7 @@ import ast import copy import itertools -from collections.abc import Iterable, Iterator, Mapping +from collections.abc import Iterable, Iterator from dataclasses import dataclass, replace from functools import cached_property from typing import ( @@ -204,7 +204,7 @@ class Globals: user names to definition id and instance implementation id. """ - defs: Mapping[DefId, Definition] + defs: dict[DefId, Definition] names: dict[str, DefId] impls: dict[DefId, dict[str, DefId]] @@ -270,14 +270,6 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None return defn return None - def update_defs(self, defs: Mapping[DefId, Definition]) -> "Globals": - """Returns a new `Globals` instance with updated definitions. - - This method is needed since in-place definition updates are impossible as the - definition map is immutable. - """ - return Globals({**self.defs, **defs}, self.names, self.impls, self.python_scope) - def __or__(self, other: "Globals") -> "Globals": impls = { def_id: self.impls.get(def_id, {}) | other.impls.get(def_id, {}) diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index a1fb1a97..370a5c3b 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -2,7 +2,7 @@ import inspect import textwrap from collections.abc import Sequence -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Any from hugr import Wire, ops @@ -314,6 +314,6 @@ def check_instantiate( **globals.defs, defn.id: DummyStructDef(defn.id, defn.name, defn.defined_at), } - dummy_globals = globals.update_defs(dummy_defs) + dummy_globals = replace(globals, defs=globals.defs | dummy_defs) for field in defn.fields: type_from_ast(field.type_ast, dummy_globals, {}) diff --git a/guppylang/module.py b/guppylang/module.py index fb55db62..028a688b 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -144,7 +144,7 @@ def load( def_id: all_checked_defs[def_id] for def_id in all_globals.impls[def_id].values() } - self._imported_globals |= Globals(defs, names, impls, {}) + self._imported_globals |= Globals(dict(defs), names, impls, {}) self._imported_checked_defs |= defs # We also need to include transitively imported checked definitions so we can @@ -184,6 +184,10 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None: if self._instance_func_buffer is not None and not isinstance(defn, TypeDef): self._instance_func_buffer[defn.name] = defn else: + # If this overrides an already defined name, we need to purge the old + # definition to avoid checking it later + if self.contains(defn.name): + self.unregister(self._globals[defn.name]) if isinstance(defn, TypeDef | ParamDef): self._raw_type_defs[defn.id] = defn else: @@ -193,6 +197,7 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None: self._globals.impls[instance.id][defn.name] = defn.id else: self._globals.names[defn.name] = defn.id + self._globals.defs[defn.id] = defn def register_func_def( self, f: PyFunc, instance: TypeDef | None = None @@ -217,6 +222,22 @@ def _register_buffered_instance_funcs(self, instance: TypeDef) -> None: for defn in buffer.values(): self.register_def(defn, instance) + def unregister(self, defn: Definition) -> None: + """Removes a definition from this module. + + Also removes all methods when unregistering a type. + """ + self._checked = False + self._compiled = False + self._compiled_hugr = None + self._raw_defs.pop(defn.id, None) + self._raw_type_defs.pop(defn.id, None) + self._globals.defs.pop(defn.id, None) + self._globals.names.pop(defn.name, None) + if impls := self._globals.impls.pop(defn.id, None): + for impl in impls.values(): + self.unregister(self._globals[impl]) + @property def checked(self) -> bool: return self._checked @@ -230,12 +251,12 @@ def _check_defs( raw_defs: Mapping[DefId, RawDef], globals: Globals ) -> dict[DefId, CheckedDef]: """Helper method to parse and check raw definitions.""" - raw_globals = globals | Globals(raw_defs, {}, {}, {}) + raw_globals = globals | Globals(dict(raw_defs), {}, {}, {}) parsed = { def_id: defn.parse(raw_globals) if isinstance(defn, ParsableDef) else defn for def_id, defn in raw_defs.items() } - parsed_globals = globals | Globals(parsed, {}, {}, {}) + parsed_globals = globals | Globals(dict(parsed), {}, {}, {}) return { def_id: ( defn.check(parsed_globals) if isinstance(defn, CheckableDef) else defn @@ -265,7 +286,7 @@ def check(self) -> None: type_defs = self._check_defs( self._raw_type_defs, self._imported_globals | self._globals ) - self._globals = self._globals.update_defs(type_defs) + self._globals.defs.update(type_defs) # Collect auto-generated methods generated: dict[DefId, RawDef] = {} @@ -280,7 +301,7 @@ def check(self) -> None: other_defs = self._check_defs( self._raw_defs | generated, self._imported_globals | self._globals ) - self._globals = self._globals.update_defs(other_defs) + self._globals.defs.update(other_defs) self._checked_defs = type_defs | other_defs self._checked = True diff --git a/tests/integration/test_redefinition.py b/tests/integration/test_redefinition.py index 7defc40b..b1545e8a 100644 --- a/tests/integration/test_redefinition.py +++ b/tests/integration/test_redefinition.py @@ -1,19 +1,83 @@ +import pytest + from guppylang.decorator import guppy from guppylang.module import GuppyModule -import guppylang.prelude.quantum as quantum - -def test_redefinition(validate): +def test_func_redefinition(validate): module = GuppyModule("test") - module.load_all(quantum) @guppy(module) def test() -> bool: - return True + return 5 # Type error on purpose @guppy(module) def test() -> bool: # noqa: F811 return False validate(module.compile()) + + +def test_method_redefinition(validate): + module = GuppyModule("test") + + @guppy.struct(module) + class Test: + x: int + + @guppy(module) + def foo(self: "Test") -> int: + return 1.0 # Type error on purpose + + @guppy(module) + def foo(self: "Test") -> int: + return 1 # Type error on purpose + + validate(module.compile()) + + +@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/456") +def test_struct_redefinition(validate): + module = GuppyModule("test") + + @guppy.struct(module) + class Test: + x: "blah" # Non-existing type + + @guppy.struct(module) + class Test: + y: int + + @guppy(module) + def main(x: int) -> Test: + return Test(x) + + validate(module.compile()) + + +@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/456") +def test_struct_method_redefinition(validate): + module = GuppyModule("test") + + @guppy.struct(module) + class Test: + x: int + + @guppy(module) + def foo(self: "Test") -> int: + return 1.0 # Type error on purpose + + @guppy.struct(module) + class Test: + y: int + + @guppy(module) + def bar(self: "Test") -> int: + return self.y + + @guppy(module) + def main(x: int) -> int: + return Test(x).bar() + + validate(module.compile()) +