Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Skip checking of redefined functions #457

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
import copy
import itertools
from collections.abc import Iterable, Iterator, Mapping
from collections.abc import Iterable, Iterator
from dataclasses import dataclass, replace
from functools import cached_property
from typing import (
Expand Down Expand Up @@ -204,7 +204,7 @@ class Globals:
user names to definition id and instance implementation id.
"""

defs: Mapping[DefId, Definition]
defs: dict[DefId, Definition]

names: dict[str, DefId]
impls: dict[DefId, dict[str, DefId]]
Expand Down Expand Up @@ -270,14 +270,6 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
return defn
return None

def update_defs(self, defs: Mapping[DefId, Definition]) -> "Globals":
"""Returns a new `Globals` instance with updated definitions.

This method is needed since in-place definition updates are impossible as the
definition map is immutable.
"""
return Globals({**self.defs, **defs}, self.names, self.impls, self.python_scope)

def __or__(self, other: "Globals") -> "Globals":
impls = {
def_id: self.impls.get(def_id, {}) | other.impls.get(def_id, {})
Expand Down
4 changes: 2 additions & 2 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import textwrap
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Any

from hugr import Wire, ops
Expand Down Expand Up @@ -314,6 +314,6 @@ def check_instantiate(
**globals.defs,
defn.id: DummyStructDef(defn.id, defn.name, defn.defined_at),
}
dummy_globals = globals.update_defs(dummy_defs)
dummy_globals = replace(globals, defs=globals.defs | dummy_defs)
for field in defn.fields:
type_from_ast(field.type_ast, dummy_globals, {})
31 changes: 26 additions & 5 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def load(
def_id: all_checked_defs[def_id]
for def_id in all_globals.impls[def_id].values()
}
self._imported_globals |= Globals(defs, names, impls, {})
self._imported_globals |= Globals(dict(defs), names, impls, {})
self._imported_checked_defs |= defs

# We also need to include transitively imported checked definitions so we can
Expand Down Expand Up @@ -182,6 +182,10 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None:
if self._instance_func_buffer is not None and not isinstance(defn, TypeDef):
self._instance_func_buffer[defn.name] = defn
else:
# If this overrides an already defined name, we need to purge the old
# definition to avoid checking it later
if self.contains(defn.name):
self.unregister(self._globals[defn.name])
if isinstance(defn, TypeDef | ParamDef):
self._raw_type_defs[defn.id] = defn
else:
Expand All @@ -191,6 +195,7 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None:
self._globals.impls[instance.id][defn.name] = defn.id
else:
self._globals.names[defn.name] = defn.id
self._globals.defs[defn.id] = defn

def register_func_def(
self, f: PyFunc, instance: TypeDef | None = None
Expand All @@ -215,6 +220,22 @@ def _register_buffered_instance_funcs(self, instance: TypeDef) -> None:
for defn in buffer.values():
self.register_def(defn, instance)

def unregister(self, defn: Definition) -> None:
"""Removes a definition from this module.

Also removes all methods when unregistering a type.
"""
self._checked = False
self._compiled = False
self._compiled_hugr = None
self._raw_defs.pop(defn.id, None)
self._raw_type_defs.pop(defn.id, None)
self._globals.defs.pop(defn.id, None)
self._globals.names.pop(defn.name, None)
if impls := self._globals.impls.pop(defn.id, None):
for impl in impls.values():
self.unregister(self._globals[impl])

@property
def checked(self) -> bool:
return self._checked
Expand All @@ -228,12 +249,12 @@ def _check_defs(
raw_defs: Mapping[DefId, RawDef], globals: Globals
) -> dict[DefId, CheckedDef]:
"""Helper method to parse and check raw definitions."""
raw_globals = globals | Globals(raw_defs, {}, {}, {})
raw_globals = globals | Globals(dict(raw_defs), {}, {}, {})
parsed = {
def_id: defn.parse(raw_globals) if isinstance(defn, ParsableDef) else defn
for def_id, defn in raw_defs.items()
}
parsed_globals = globals | Globals(parsed, {}, {}, {})
parsed_globals = globals | Globals(dict(parsed), {}, {}, {})
return {
def_id: (
defn.check(parsed_globals) if isinstance(defn, CheckableDef) else defn
Expand Down Expand Up @@ -263,7 +284,7 @@ def check(self) -> None:
type_defs = self._check_defs(
self._raw_type_defs, self._imported_globals | self._globals
)
self._globals = self._globals.update_defs(type_defs)
self._globals.defs.update(type_defs)

# Collect auto-generated methods
generated: dict[DefId, RawDef] = {}
Expand All @@ -278,7 +299,7 @@ def check(self) -> None:
other_defs = self._check_defs(
self._raw_defs | generated, self._imported_globals | self._globals
)
self._globals = self._globals.update_defs(other_defs)
self._globals.defs.update(other_defs)
self._checked_defs = type_defs | other_defs
self._checked = True

Expand Down
74 changes: 69 additions & 5 deletions tests/integration/test_redefinition.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,83 @@
import pytest

from guppylang.decorator import guppy
from guppylang.module import GuppyModule

import guppylang.prelude.quantum as quantum


def test_redefinition(validate):
def test_func_redefinition(validate):
module = GuppyModule("test")
module.load_all(quantum)

@guppy(module)
def test() -> bool:
return True
return 5 # Type error on purpose

@guppy(module)
def test() -> bool: # noqa: F811
return False

validate(module.compile())


def test_method_redefinition(validate):
module = GuppyModule("test")

@guppy.struct(module)
class Test:
x: int

@guppy(module)
def foo(self: "Test") -> int:
return 1.0 # Type error on purpose

@guppy(module)
def foo(self: "Test") -> int:
return 1 # Type error on purpose

validate(module.compile())


@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/456")
def test_struct_redefinition(validate):
module = GuppyModule("test")

@guppy.struct(module)
class Test:
x: "blah" # Non-existing type

@guppy.struct(module)
class Test:
y: int

@guppy(module)
def main(x: int) -> Test:
return Test(x)

validate(module.compile())


@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/456")
def test_struct_method_redefinition(validate):
module = GuppyModule("test")

@guppy.struct(module)
class Test:
x: int

@guppy(module)
def foo(self: "Test") -> int:
return 1.0 # Type error on purpose

@guppy.struct(module)
class Test:
y: int

@guppy(module)
def bar(self: "Test") -> int:
return self.y

@guppy(module)
def main(x: int) -> int:
return Test(x).bar()

validate(module.compile())

Loading