From f158384c88787d1e436b634657dcfc12d531d68e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 1 Jul 2024 11:23:50 +0100 Subject: [PATCH] fix(hugr-py): more ruff lints + fix some typos (#1246) import ruff.toml from guppylang - turn on more lints - mostly auto fixes, mostly import sorting add a config file for `typos` tool https://github.com/crate-ci/typos?tab=readme-ov-file - we could consider adding this to CI/pre-commit but don't want to do a big repo-wide change right now - I used cli typos tool with `cargo install typos-cli` coverage is not happy because most of the fixes are in TYPE_CHECKING import blocks or unraised exceptions --- _typos.toml | 4 + hugr-py/src/hugr/cfg.py | 15 ++-- hugr-py/src/hugr/cond_loop.py | 15 ++-- hugr-py/src/hugr/dfg.py | 22 +++--- hugr-py/src/hugr/exceptions.py | 10 ++- hugr-py/src/hugr/function.py | 8 +- hugr-py/src/hugr/hugr.py | 24 +++--- hugr-py/src/hugr/node_port.py | 23 +++--- hugr-py/src/hugr/ops.py | 19 +++-- hugr-py/src/hugr/serialization/__init__.py | 3 - hugr-py/src/hugr/serialization/ops.py | 46 ++++++----- hugr-py/src/hugr/serialization/serial_hugr.py | 11 ++- .../src/hugr/serialization/testing_hugr.py | 12 ++- hugr-py/src/hugr/serialization/tys.py | 52 +++++++------ hugr-py/src/hugr/tys.py | 4 +- hugr-py/src/hugr/utils.py | 5 +- hugr-py/src/hugr/val.py | 4 +- hugr-py/tests/conftest.py | 33 +++++--- hugr-py/tests/serialization/test_basic.py | 2 +- hugr-py/tests/test_cfg.py | 7 +- hugr-py/tests/test_cond_loop.py | 12 +-- hugr-py/tests/test_hugr_build.py | 13 ++-- hugr-py/tests/test_version.py | 7 +- ruff.toml | 78 +++++++++++++++++++ scripts/generate_schema.py | 13 ++-- 25 files changed, 297 insertions(+), 145 deletions(-) create mode 100644 _typos.toml create mode 100644 ruff.toml diff --git a/_typos.toml b/_typos.toml new file mode 100644 index 000000000..d5d8f06c2 --- /dev/null +++ b/_typos.toml @@ -0,0 +1,4 @@ +[default.extend-identifiers] +bck = "bck" # BiMap uses abbreviation +ser_it = "ser_it" +SerCollection = "SerCollection" diff --git a/hugr-py/src/hugr/cfg.py b/hugr-py/src/hugr/cfg.py index 0cec794b1..f10256acd 100644 --- a/hugr-py/src/hugr/cfg.py +++ b/hugr-py/src/hugr/cfg.py @@ -1,15 +1,18 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING import hugr.ops as ops +import hugr.val as val from .dfg import _DfBase -from .exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit +from .exceptions import MismatchedExit, NoSiblingAncestor, NotInSameCfg from .hugr import Hugr, ParentBuilder -from .node_port import Node, Wire, ToNode -from .tys import TypeRow, Type -import hugr.val as val + +if TYPE_CHECKING: + from .node_port import Node, ToNode, Wire + from .tys import Type, TypeRow class Block(_DfBase[ops.DataflowBlock]): @@ -27,13 +30,13 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: src_parent = self.hugr[src.node].parent try: super()._wire_up_port(node, offset, p) - except NoSiblingAncestor: + except NoSiblingAncestor as e: # note this just checks if there is a common CFG ancestor # it does not check for valid dominance between basic blocks # that is deferred to full HUGR validation. while cfg_node != src_parent: if src_parent is None or src_parent == self.hugr.root: - raise NotInSameCfg(src.node.idx, node.idx) + raise NotInSameCfg(src.node.idx, node.idx) from e src_parent = self.hugr[src_parent].parent self.hugr.add_link(src, node.inp(offset)) diff --git a/hugr-py/src/hugr/cond_loop.py b/hugr-py/src/hugr/cond_loop.py index a1ac42830..6cc1e5758 100644 --- a/hugr-py/src/hugr/cond_loop.py +++ b/hugr-py/src/hugr/cond_loop.py @@ -1,14 +1,16 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING import hugr.ops as ops from .dfg import _DfBase from .hugr import Hugr, ParentBuilder -from .node_port import Node, Wire, ToNode -from .tys import Sum, TypeRow +if TYPE_CHECKING: + from .node_port import Node, ToNode, Wire + from .tys import Sum, TypeRow class Case(_DfBase[ops.Case]): @@ -35,7 +37,8 @@ def __init__(self, case: Case) -> None: def _parent_conditional(self) -> Conditional: if self._parent_cond is None: - raise ConditionalError("If must have a parent conditional.") + msg = "If must have a parent conditional." + raise ConditionalError(msg) return self._parent_cond @@ -84,11 +87,13 @@ def _update_outputs(self, outputs: TypeRow) -> None: self.parent_op._outputs = outputs else: if outputs != self.parent_op._outputs: - raise ConditionalError("Mismatched case outputs.") + msg = "Mismatched case outputs." + raise ConditionalError(msg) def add_case(self, case_id: int) -> Case: if case_id not in self.cases: - raise ConditionalError(f"Case {case_id} out of possible range.") + msg = f"Case {case_id} out of possible range." + raise ConditionalError(msg) input_types = self.parent_op.nth_inputs(case_id) new_case = Case.new_nested( ops.Case(input_types), diff --git a/hugr-py/src/hugr/dfg.py b/hugr-py/src/hugr/dfg.py index 24e2b0647..7c8b3a41c 100644 --- a/hugr-py/src/hugr/dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -3,8 +3,6 @@ from dataclasses import dataclass, replace from typing import ( TYPE_CHECKING, - Iterable, - Sequence, TypeVar, ) @@ -13,23 +11,25 @@ import hugr.ops as ops import hugr.val as val from hugr.tys import ( + ExtensionSet, + FunctionKind, + FunctionType, + PolyFuncType, Type, + TypeArg, TypeRow, get_first_sum, - FunctionType, - TypeArg, - FunctionKind, - PolyFuncType, - ExtensionSet, ) from .exceptions import NoSiblingAncestor from .hugr import Hugr, ParentBuilder -from .node_port import Node, OutPort, Wire, ToNode if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from .cfg import Cfg from .cond_loop import Conditional, If, TailLoop + from .node_port import Node, OutPort, ToNode, Wire DP = TypeVar("DP", bound=ops.DfParentOp) @@ -210,7 +210,8 @@ def _fn_sig(self, func: ToNode) -> PolyFuncType: case FunctionKind(sig): signature = sig case _: - raise ValueError("Expected 'func' to be a function") + msg = "Expected 'func' to be a function" + raise ValueError(msg) return signature def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow: @@ -223,7 +224,8 @@ def _get_dataflow_type(self, wire: Wire) -> Type: port = wire.out_port() ty = self.hugr.port_type(port) if ty is None: - raise ValueError(f"Port {port} is not a dataflow port.") + msg = f"Port {port} is not a dataflow port." + raise ValueError(msg) return ty def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: diff --git a/hugr-py/src/hugr/exceptions.py b/hugr-py/src/hugr/exceptions.py index e5c211d44..bc9245029 100644 --- a/hugr-py/src/hugr/exceptions.py +++ b/hugr-py/src/hugr/exceptions.py @@ -8,7 +8,10 @@ class NoSiblingAncestor(Exception): @property def msg(self): - return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up." + return ( + f"Source {self.src} has no sibling ancestor of target {self.tgt}," + " so cannot wire up." + ) @dataclass @@ -18,7 +21,10 @@ class NotInSameCfg(Exception): @property def msg(self): - return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up." + return ( + f"Source {self.src} is not in the same CFG as target {self.tgt}," + " so cannot wire up." + ) @dataclass diff --git a/hugr-py/src/hugr/function.py b/hugr-py/src/hugr/function.py index 2e698c5d1..ecd461bb7 100644 --- a/hugr-py/src/hugr/function.py +++ b/hugr-py/src/hugr/function.py @@ -1,14 +1,18 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING import hugr.ops as ops import hugr.val as val from .dfg import _DfBase -from hugr.node_port import Node from .hugr import Hugr -from .tys import TypeRow, TypeParam, PolyFuncType, Type, TypeBound + +if TYPE_CHECKING: + from hugr.node_port import Node + + from .tys import PolyFuncType, Type, TypeBound, TypeParam, TypeRow @dataclass diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index 68bc8c91b..628ee2b43 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -1,28 +1,28 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from dataclasses import dataclass, field, replace from typing import ( + TYPE_CHECKING, Generic, - Iterable, Protocol, TypeVar, cast, overload, - Type as PyType, ) - -from hugr.ops import Op, DataflowOp, Const, Call -from hugr.tys import Type, Kind, ValueKind -from hugr.val import Value -from hugr.node_port import Direction, InPort, OutPort, ToNode, Node, _SubPort +from hugr.node_port import Direction, InPort, Node, OutPort, ToNode, _SubPort +from hugr.ops import Call, Const, DataflowOp, Op from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr +from hugr.tys import Kind, Type, ValueKind from hugr.utils import BiMap from .exceptions import ParentBeforeChild +if TYPE_CHECKING: + from hugr.val import Value + @dataclass() class NodeData: @@ -88,7 +88,7 @@ def __iter__(self): def __len__(self) -> int: return self.num_nodes() - def _get_typed_op(self, node: ToNode, cl: PyType[OpVar2]) -> OpVar2: + def _get_typed_op(self, node: ToNode, cl: type[OpVar2]) -> OpVar2: op = self[node].op assert isinstance(op, cl) return op @@ -241,11 +241,11 @@ def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]] return self._node_links(node, self._links.bck) def num_incoming(self, node: Node) -> int: - # connecetd links + # connected links return sum(1 for _ in self.incoming_links(node)) def num_outgoing(self, node: ToNode) -> int: - # connecetd links + # connected links return sum(1 for _ in self.outgoing_links(node)) # TODO: num_links and _linked_ports @@ -274,7 +274,7 @@ def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, No mapping[node_data.parent] if node_data.parent else parent ) except KeyError as e: - raise ParentBeforeChild() from e + raise ParentBeforeChild from e mapping[Node(idx)] = self.add_node(node_data.op, node_parent) for src, dst in hugr._links.items(): diff --git a/hugr-py/src/hugr/node_port.py b/hugr-py/src/hugr/node_port.py index 23e10291a..5ec3c716b 100644 --- a/hugr-py/src/hugr/node_port.py +++ b/hugr-py/src/hugr/node_port.py @@ -3,15 +3,19 @@ from dataclasses import dataclass, field, replace from enum import Enum from typing import ( + TYPE_CHECKING, ClassVar, - Iterator, + Generic, Protocol, - overload, TypeVar, - Generic, + overload, ) + from typing_extensions import Self +if TYPE_CHECKING: + from collections.abc import Iterator + class Direction(Enum): INCOMING = 0 @@ -57,7 +61,7 @@ def __getitem__( ) -> OutPort | Iterator[OutPort]: return self.to_node()._index(index) - def out_port(self) -> "OutPort": + def out_port(self) -> OutPort: return OutPort(self.to_node(), 0) def inp(self, offset: int) -> InPort: @@ -83,17 +87,16 @@ def _index( ) -> OutPort | Iterator[OutPort]: match index: case int(index): - if self._num_out_ports is not None: - if index >= self._num_out_ports: - raise IndexError("Index out of range") + if self._num_out_ports is not None and index >= self._num_out_ports: + msg = "Index out of range" + raise IndexError(msg) return self.out(index) case slice(): start = index.start or 0 stop = index.stop or self._num_out_ports if stop is None: - raise ValueError( - "Stop must be specified when number of outputs unknown" - ) + msg = "Stop must be specified when number of outputs unknown" + raise ValueError(msg) step = index.step or 1 return (self[i] for i in range(start, stop, step)) case tuple(xs): diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index f5f9a50b5..cdd316965 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -1,13 +1,18 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Protocol, Sequence, runtime_checkable, TypeVar -from hugr.serialization.ops import BaseOp +from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable + import hugr.serialization.ops as sops -from hugr.utils import ser_it import hugr.tys as tys -from hugr.node_port import Node, InPort, OutPort, Wire import hugr.val as val +from hugr.node_port import InPort, Node, OutPort, Wire +from hugr.utils import ser_it + +if TYPE_CHECKING: + from collections.abc import Sequence + + from hugr.serialization.ops import BaseOp @dataclass @@ -616,11 +621,13 @@ def _fn_instantiation( else: # TODO substitute type args into signature to get instantiation if instantiation is None: - raise NoConcreteFunc("Missing instantiation for polymorphic function.") + msg = "Missing instantiation for polymorphic function." + raise NoConcreteFunc(msg) type_args = type_args or [] if len(signature.params) != len(type_args): - raise NoConcreteFunc("Mismatched number of type arguments.") + msg = "Mismatched number of type arguments." + raise NoConcreteFunc(msg) return instantiation, list(type_args) diff --git a/hugr-py/src/hugr/serialization/__init__.py b/hugr-py/src/hugr/serialization/__init__.py index ad1db81e8..e69de29bb 100644 --- a/hugr-py/src/hugr/serialization/__init__.py +++ b/hugr-py/src/hugr/serialization/__init__.py @@ -1,3 +0,0 @@ -from .serial_hugr import SerialHugr - -__all__ = ["SerialHugr"] diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 380c0aa7d..15c409700 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -1,27 +1,32 @@ from __future__ import annotations + import inspect import sys from abc import ABC, abstractmethod from typing import Any, Literal -from pydantic import Field, RootModel, ConfigDict +from pydantic import ConfigDict, Field, RootModel + +from hugr.utils import deser_it from . import tys as stys from .tys import ( + ConfiguredBaseModel, ExtensionId, ExtensionSet, FunctionType, PolyFuncType, - Type, - TypeRow, SumType, + Type, TypeBound, - ConfiguredBaseModel, + TypeRow, +) +from .tys import ( classes as tys_classes, +) +from .tys import ( model_rebuild as tys_model_rebuild, ) -from hugr.utils import deser_it - NodeID = int @@ -128,10 +133,10 @@ class TupleValue(BaseValue): """A constant tuple value.""" v: Literal["Tuple"] = Field(default="Tuple", title="ValueTag") - vs: list["Value"] + vs: list[Value] def deserialize(self) -> val.Value: - return val.Tuple(*deser_it((v.root for v in self.vs))) + return val.Tuple(*deser_it(v.root for v in self.vs)) class SumValue(BaseValue): @@ -143,7 +148,7 @@ class SumValue(BaseValue): v: Literal["Sum"] = Field(default="Sum", title="ValueTag") tag: int typ: SumType - vs: list["Value"] + vs: list[Value] model_config = ConfigDict( json_schema_extra={ "description": ( @@ -155,7 +160,7 @@ class SumValue(BaseValue): def deserialize(self) -> val.Value: return val.Sum( - self.tag, self.typ.deserialize(), deser_it((v.root for v in self.vs)) + self.tag, self.typ.deserialize(), deser_it(v.root for v in self.vs) ) @@ -224,7 +229,8 @@ def deserialize(self) -> ops.DataflowBlock: model_config = ConfigDict( json_schema_extra={ - "description": "A CFG basic block node. The signature is that of the internal Dataflow graph.", + "description": "A CFG basic block node." + " The signature is that of the internal Dataflow graph.", } ) @@ -239,7 +245,8 @@ class ExitBlock(BaseOp): model_config = ConfigDict( json_schema_extra={ # Needed to avoid random '\n's in the pydantic description - "description": "The single exit node of the CFG, has no children, stores the types of the CFG node output.", + "description": "The single exit node of the CFG, has no children," + " stores the types of the CFG node output.", } ) @@ -362,7 +369,7 @@ def deserialize(self) -> ops.LoadFunc: (f_ty,) = signature.output assert isinstance( f_ty, tys.FunctionType - ), "Expected single funciton type output" + ), "Expected single function type output" return ops.LoadFunc( self.func_sig.deserialize(), f_ty, @@ -517,8 +524,8 @@ def deserialize(self) -> ops.Custom: # Needed to avoid random '\n's in the pydantic description json_schema_extra={ "description": ( - "A user-defined operation that can be downcasted by the extensions that " - "define it." + "A user-defined operation that can be downcasted by the extensions that" + " define it." ) } ) @@ -682,7 +689,8 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): tys_model_rebuild(dict(classes)) -# needed to avoid circular imports -from hugr import ops # noqa: E402 -from hugr import val # noqa: E402 -from hugr import tys # noqa: E402 +from hugr import ( # noqa: E402 # needed to avoid circular imports + ops, + tys, + val, +) diff --git a/hugr-py/src/hugr/serialization/serial_hugr.py b/hugr-py/src/hugr/serialization/serial_hugr.py index 49bfbd2f7..619eaca90 100644 --- a/hugr-py/src/hugr/serialization/serial_hugr.py +++ b/hugr-py/src/hugr/serialization/serial_hugr.py @@ -1,11 +1,13 @@ from typing import Any, Literal -from pydantic import Field, ConfigDict +from pydantic import ConfigDict, Field -from .ops import NodeID, OpType, classes as ops_classes -from .tys import model_rebuild, ConfiguredBaseModel import hugr +from .ops import NodeID, OpType +from .ops import classes as ops_classes +from .tys import ConfiguredBaseModel, model_rebuild + Port = tuple[NodeID, int | None] # (node, offset) Edge = tuple[Port, Port] @@ -37,7 +39,8 @@ def get_version(cls) -> str: return cls(nodes=[], edges=[]).version @classmethod - def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs): + def _pydantic_rebuild(cls, config: ConfigDict | None = None, **kwargs): + config = config or ConfigDict() my_classes = dict(ops_classes) my_classes[cls.__name__] = cls model_rebuild(my_classes, config=config, **kwargs) diff --git a/hugr-py/src/hugr/serialization/testing_hugr.py b/hugr-py/src/hugr/serialization/testing_hugr.py index 32bf2b95f..c36bc4e59 100644 --- a/hugr-py/src/hugr/serialization/testing_hugr.py +++ b/hugr-py/src/hugr/serialization/testing_hugr.py @@ -1,7 +1,10 @@ -from pydantic import ConfigDict from typing import Literal -from .tys import Type, SumType, PolyFuncType, ConfiguredBaseModel, model_rebuild -from .ops import Value, OpType, OpDef, classes as ops_classes + +from pydantic import ConfigDict + +from .ops import OpDef, OpType, Value +from .ops import classes as ops_classes +from .tys import ConfiguredBaseModel, PolyFuncType, SumType, Type, model_rebuild class TestingHugr(ConfiguredBaseModel): @@ -22,7 +25,8 @@ def get_version(cls) -> str: return cls().version @classmethod - def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs): + def _pydantic_rebuild(cls, config: ConfigDict | None = None, **kwargs): + config = config or ConfigDict() my_classes = dict(ops_classes) my_classes[cls.__name__] = cls model_rebuild(my_classes, config=config, **kwargs) diff --git a/hugr-py/src/hugr/serialization/tys.py b/hugr-py/src/hugr/serialization/tys.py index 053deefcd..d245c9380 100644 --- a/hugr-py/src/hugr/serialization/tys.py +++ b/hugr-py/src/hugr/serialization/tys.py @@ -1,24 +1,28 @@ from __future__ import annotations -from abc import ABC, abstractmethod import inspect import sys +from abc import ABC, abstractmethod from enum import Enum -from typing import Annotated, Any, Literal, Union, Mapping +from typing import TYPE_CHECKING, Annotated, Any, Literal -from hugr.utils import deser_it from pydantic import ( BaseModel, + ConfigDict, Field, RootModel, ValidationError, ValidationInfo, ValidatorFunctionWrapHandler, WrapValidator, - ConfigDict, ) from pydantic_core import PydanticCustomError +from hugr.utils import deser_it + +if TYPE_CHECKING: + from collections.abc import Mapping + def _json_custom_error_validator( value: Any, handler: ValidatorFunctionWrapHandler, _info: ValidationInfo @@ -36,8 +40,9 @@ def _json_custom_error_validator( try: return handler(value) except ValidationError as err: + msg = "invalid_json" raise PydanticCustomError( - "invalid_json", + msg, "Input is not valid json", ) from err @@ -68,7 +73,7 @@ def deserialize(self) -> tys.TypeParam: ... class TypeTypeParam(BaseTypeParam): tp: Literal["Type"] = "Type" - b: "TypeBound" + b: TypeBound def deserialize(self) -> tys.TypeTypeParam: return tys.TypeTypeParam(bound=self.b) @@ -84,7 +89,7 @@ def deserialize(self) -> tys.BoundedNatParam: class OpaqueParam(BaseTypeParam): tp: Literal["Opaque"] = "Opaque" - ty: "Opaque" + ty: Opaque def deserialize(self) -> tys.OpaqueParam: return tys.OpaqueParam(ty=self.ty.deserialize()) @@ -92,7 +97,7 @@ def deserialize(self) -> tys.OpaqueParam: class ListParam(BaseTypeParam): tp: Literal["List"] = "List" - param: "TypeParam" + param: TypeParam def deserialize(self) -> tys.ListParam: return tys.ListParam(param=self.param.deserialize()) @@ -100,7 +105,7 @@ def deserialize(self) -> tys.ListParam: class TupleParam(BaseTypeParam): tp: Literal["Tuple"] = "Tuple" - params: list["TypeParam"] + params: list[TypeParam] def deserialize(self) -> tys.TupleParam: return tys.TupleParam(params=deser_it(self.params)) @@ -144,7 +149,7 @@ def deserialize(self) -> tys.TypeArg: ... class TypeTypeArg(BaseTypeArg): tya: Literal["Type"] = "Type" - ty: "Type" + ty: Type def deserialize(self) -> tys.TypeTypeArg: return tys.TypeTypeArg(ty=self.ty.deserialize()) @@ -160,7 +165,7 @@ def deserialize(self) -> tys.BoundedNatArg: class OpaqueArg(BaseTypeArg): tya: Literal["Opaque"] = "Opaque" - typ: "Opaque" + typ: Opaque value: Any def deserialize(self) -> tys.OpaqueArg: @@ -169,7 +174,7 @@ def deserialize(self) -> tys.OpaqueArg: class SequenceArg(BaseTypeArg): tya: Literal["Sequence"] = "Sequence" - elems: list["TypeArg"] + elems: list[TypeArg] def deserialize(self) -> tys.SequenceArg: return tys.SequenceArg(elems=deser_it(self.elems)) @@ -222,7 +227,7 @@ def deserialize(self) -> tys.Type: ... class MultiContainer(BaseType): - ty: "Type" + ty: Type class Array(MultiContainer): @@ -251,14 +256,14 @@ class GeneralSum(BaseType): t: Literal["Sum"] = "Sum" s: Literal["General"] = "General" - rows: list["TypeRow"] + rows: list[TypeRow] def deserialize(self) -> tys.Sum: return tys.Sum(variant_rows=[[t.deserialize() for t in r] for r in self.rows]) class SumType(RootModel): - root: Annotated[Union[UnitSum, GeneralSum], Field(discriminator="s")] + root: Annotated[UnitSum | GeneralSum, Field(discriminator="s")] # This seems to be required for nested discriminated unions to work @property @@ -281,7 +286,7 @@ class Variable(BaseType): t: Literal["V"] = "V" i: int - b: "TypeBound" + b: TypeBound def deserialize(self) -> tys.Variable: return tys.Variable(idx=self.i, bound=self.b) @@ -293,7 +298,7 @@ class RowVar(BaseType): t: Literal["R"] = "R" i: int - b: "TypeBound" + b: TypeBound def deserialize(self) -> tys.RowVariable: return tys.RowVariable(idx=self.i, bound=self.b) @@ -314,13 +319,13 @@ class FunctionType(BaseType): t: Literal["G"] = "G" - input: "TypeRow" # Value inputs of the function. - output: "TypeRow" # Value outputs of the function. + input: TypeRow # Value inputs of the function. + output: TypeRow # Value outputs of the function. # The extension requirements which are added by the operation extension_reqs: ExtensionSet = Field(default_factory=ExtensionSet) @classmethod - def empty(cls) -> "FunctionType": + def empty(cls) -> FunctionType: return FunctionType(input=[], output=[], extension_reqs=[]) def deserialize(self) -> tys.FunctionType: @@ -354,7 +359,7 @@ class PolyFuncType(BaseType): body: FunctionType @classmethod - def empty(cls) -> "PolyFuncType": + def empty(cls) -> PolyFuncType: return PolyFuncType(params=[], body=FunctionType.empty()) def deserialize(self) -> tys.PolyFuncType: @@ -380,7 +385,7 @@ class TypeBound(Enum): Any = "A" @staticmethod - def join(*bs: "TypeBound") -> "TypeBound": + def join(*bs: TypeBound) -> TypeBound: """Computes the least upper bound for a sequence of bounds.""" res = TypeBound.Eq for b in bs: @@ -475,9 +480,10 @@ def deserialize(self) -> tys.Type: def model_rebuild( classes: Mapping[str, type], - config: ConfigDict = ConfigDict(), + config: ConfigDict | None = None, **kwargs, ): + config = config or ConfigDict() for c in classes.values(): if issubclass(c, ConfiguredBaseModel): c.update_model_config(config) diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index cfbb7c294..4a83d422a 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -1,8 +1,10 @@ from __future__ import annotations + from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + import hugr.serialization.tys as stys from hugr.utils import ser_it -from typing import Any, Protocol, runtime_checkable ExtensionId = stys.ExtensionId ExtensionSet = stys.ExtensionSet diff --git a/hugr-py/src/hugr/utils.py b/hugr-py/src/hugr/utils.py index dd4bdae1a..0bf383df8 100644 --- a/hugr-py/src/hugr/utils.py +++ b/hugr-py/src/hugr/utils.py @@ -1,7 +1,6 @@ -from collections.abc import Hashable, ItemsView, MutableMapping +from collections.abc import Hashable, ItemsView, Iterable, MutableMapping from dataclasses import dataclass, field -from typing import Generic, Iterable, Protocol, TypeVar - +from typing import Generic, Protocol, TypeVar L = TypeVar("L", bound=Hashable) R = TypeVar("R", bound=Hashable) diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index d69056068..b08b7c23e 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -1,6 +1,8 @@ from __future__ import annotations + from dataclasses import dataclass, field -from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + import hugr.serialization.ops as sops import hugr.serialization.tys as stys import hugr.tys as tys diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 220aa5037..2edae34bf 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -1,17 +1,20 @@ from __future__ import annotations -from dataclasses import dataclass, field -import subprocess +import json import os import pathlib -from hugr.node_port import Wire +import subprocess +from dataclasses import dataclass, field +from typing import TYPE_CHECKING -from hugr.hugr import Hugr -from hugr.ops import Custom, Command -from hugr.serialization import SerialHugr import hugr.tys as tys import hugr.val as val -import json +from hugr.hugr import Hugr +from hugr.ops import Command, Custom +from hugr.serialization.serial_hugr import SerialHugr + +if TYPE_CHECKING: + from hugr.node_port import Wire def int_t(width: int) -> tys.Opaque: @@ -39,12 +42,15 @@ class LogicOps(Custom): extension: tys.ExtensionId = "logic" +_NotSig = tys.FunctionType.endo([tys.Bool]) + + # TODO get from YAML @dataclass class NotDef(LogicOps): num_out: int = 1 op_name: str = "Not" - signature: tys.FunctionType = tys.FunctionType.endo([tys.Bool]) + signature: tys.FunctionType = _NotSig def __call__(self, a: Wire) -> Command: return super().__call__(a) @@ -58,11 +64,14 @@ class QuantumOps(Custom): extension: tys.ExtensionId = "tket2.quantum" +_OneQbSig = tys.FunctionType.endo([tys.Qubit]) + + @dataclass class OneQbGate(QuantumOps): op_name: str num_out: int = 1 - signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit]) + signature: tys.FunctionType = _OneQbSig def __call__(self, q: Wire) -> Command: return super().__call__(q) @@ -70,12 +79,14 @@ def __call__(self, q: Wire) -> Command: H = OneQbGate("H") +_MeasSig = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) + @dataclass class MeasureDef(QuantumOps): op_name: str = "Measure" num_out: int = 2 - signature: tys.FunctionType = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) + signature: tys.FunctionType = _MeasSig def __call__(self, q: Wire) -> Command: return super().__call__(q) @@ -115,7 +126,7 @@ def validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): if mermaid: cmd.append("--mermaid") serial = h.to_serial().to_json() - subprocess.run(cmd, check=True, input=serial.encode()) + subprocess.run(cmd, check=True, input=serial.encode()) # noqa: S603 if roundtrip: h2 = Hugr.from_serial(SerialHugr.load_json(json.loads(serial))) diff --git a/hugr-py/tests/serialization/test_basic.py b/hugr-py/tests/serialization/test_basic.py index 5c3b41ace..1479888eb 100644 --- a/hugr-py/tests/serialization/test_basic.py +++ b/hugr-py/tests/serialization/test_basic.py @@ -1,4 +1,4 @@ -from hugr.serialization import SerialHugr +from hugr.serialization.serial_hugr import SerialHugr def test_empty(): diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index ddd3b9c32..e2c02554c 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -1,9 +1,10 @@ -from hugr.cfg import Cfg +import hugr.ops as ops import hugr.tys as tys import hugr.val as val +from hugr.cfg import Cfg from hugr.dfg import Dfg -import hugr.ops as ops -from .conftest import validate, INT_T, DivMod, IntVal + +from .conftest import INT_T, DivMod, IntVal, validate def build_basic_cfg(cfg: Cfg) -> None: diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index 75791317b..919480160 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -1,10 +1,12 @@ -from hugr.cond_loop import Conditional, ConditionalError, TailLoop -from hugr.dfg import Dfg -import hugr.tys as tys +import pytest + import hugr.ops as ops +import hugr.tys as tys import hugr.val as val -import pytest -from .conftest import INT_T, validate, IntVal, H, Measure +from hugr.cond_loop import Conditional, ConditionalError, TailLoop +from hugr.dfg import Dfg + +from .conftest import INT_T, H, IntVal, Measure, validate SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]]) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index b6058a293..9a1075276 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -1,16 +1,17 @@ from __future__ import annotations -from hugr.node_port import Node, _SubPort -from hugr.hugr import Hugr -from hugr.dfg import Dfg, _ancestral_sibling -from hugr.ops import NoConcreteFunc +import pytest + import hugr.ops as ops import hugr.tys as tys import hugr.val as val +from hugr.dfg import Dfg, _ancestral_sibling from hugr.function import Module -import pytest +from hugr.hugr import Hugr +from hugr.node_port import Node, _SubPort +from hugr.ops import NoConcreteFunc -from .conftest import Not, INT_T, IntVal, validate, DivMod +from .conftest import INT_T, DivMod, IntVal, Not, validate def test_stable_indices(): diff --git a/hugr-py/tests/test_version.py b/hugr-py/tests/test_version.py index ac9e154d6..5ed309b90 100644 --- a/hugr-py/tests/test_version.py +++ b/hugr-py/tests/test_version.py @@ -1,6 +1,8 @@ # from https://github.com/python-poetry/poetry/issues/144#issuecomment-877835259 -import toml # type: ignore[import-untyped] from pathlib import Path + +import toml # type: ignore[import-untyped] + import hugr @@ -8,7 +10,8 @@ def test_versions_are_in_sync(): """Checks if the pyproject.toml and package.__init__.py __version__ are in sync.""" path = Path(__file__).resolve().parents[1] / "pyproject.toml" - pyproject = toml.loads(open(str(path)).read()) + with Path.open(path, "r") as f: + pyproject = toml.loads(f.read()) pyproject_version = pyproject["tool"]["poetry"]["version"] package_init_version = hugr.__version__ diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..47501f2c4 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,78 @@ +# See https://docs.astral.sh/ruff/rules/ +target-version = "py310" + +line-length = 88 + +exclude = ["tests/error"] + +[lint] + +select = [ + "F", # pyflakes + "E", # pycodestyle Errors + "W", # pycodestyle Warnings + + # "A", # flake8-builtins + # "ANN", # flake8-annotations + # "ARG", # flake8-unused-arguments + "B", # flake8-Bugbear + "BLE", # flake8-blind-except + "C4", # flake8-comprehensions + # "C90", # mccabe + # "COM", # flake8-commas + # "CPY", # flake8-copyright + # "D", # pydocstyle + "EM", # flake8-errmsg + # "ERA", # eradicate + "EXE", # flake8-executable + "FA", # flake8-future-annotations + # "FBT", # flake8-boolean-trap + # "FIX", # flake8-fixme + "FLY", # flynt + # "FURB", # refurb + "G", # flake8-logging-format + "I", # isort + "ICN", # flake8-import-conventions + "INP", # flake8-no-pep420 + "INT", # flake8-gettext + # "ISC", # flake8-implicit-str-concat + # "LOG", # flake8-logging + # "N", # pep8-Naming + "NPY", # NumPy-specific + "PERF", # Perflint + "PGH", # pygrep-hooks + "PIE", # flake8-pie + # "PL", # pylint + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "PYI", # flake8-pyi + "Q", # flake8-quotes + # "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # Ruff-specific + "S", # flake8-bandit (Security) + "SIM", # flake8-simplify + # "SLF", # flake8-self + "SLOT", # flake8-slots + "T10", # flake8-debugger + "T20", # flake8-print + "TCH", # flake8-type-checking + # "TD", # flake8-todos + "TID", # flake8-tidy-imports + "TRY", # tryceratops + "UP", # pyupgrade + "YTT", # flake8-2020 +] + + +ignore = [ + "S101", # Use of `assert` detected + "TRY003", # Avoid specifying long messages outside the exception class +] + +[lint.per-file-ignores] +"hugr-pr/tests/**" = ["D"] +"scripts/generate_schema.py"= ["T201", "EXE001"] + +[lint.pydocstyle] +convention = "google" diff --git a/scripts/generate_schema.py b/scripts/generate_schema.py index 882f0d7e6..78d66fb19 100644 --- a/scripts/generate_schema.py +++ b/scripts/generate_schema.py @@ -1,28 +1,29 @@ #!/usr/bin/env python """Dumps the json schema for `hugr.serialization.SerialHugr` to a file. -The schema is written to a file named `hugr_schema_v#.json` in the specified output directory. -If no output directory is specified, the schema is written to the current working directory. +The schema is written to a file named `hugr_schema_v#.json` +in the specified output directory. +If no output directory is specified, +the schema is written to the current working directory. usage: python generate_schema.py [] """ import json import sys -from typing import Type, Optional from pathlib import Path from pydantic import ConfigDict -from hugr.serialization import SerialHugr +from hugr.serialization.serial_hugr import SerialHugr from hugr.serialization.testing_hugr import TestingHugr def write_schema( out_dir: Path, name_prefix: str, - schema: Type[SerialHugr] | Type[TestingHugr], - config: Optional[ConfigDict] = None, + schema: type[SerialHugr] | type[TestingHugr], + config: ConfigDict | None = None, **kwargs, ): version = schema.get_version()