diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3441ff90..5199e803 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -200,13 +200,13 @@ jobs: - name: Build and install run: python -m pip install --verbose -e . - - name: Test without stim + - name: Test without stim or rustworkx run: python -m pytest tests - - name: Add stim - run: python -m pip install stim + - name: Install stim and rustworkx + run: python -m pip install stim rustworkx - - name: Test with stim using coverage + - name: Test with stim and rustworkx using coverage run: python -m pytest tests --cov=./src/pymatching --cov-report term - name: flake8 diff --git a/setup.py b/setup.py index 7154a15f..c0dfcb9e 100644 --- a/setup.py +++ b/setup.py @@ -153,7 +153,7 @@ def build_extension(self, ext): 'console_scripts': ['pymatching=pymatching._cli_argv:cli_argv'], }, python_requires=">=3.7", - install_requires=['scipy', 'numpy', 'networkx', 'rustworkx', 'matplotlib'], + install_requires=['scipy', 'numpy', 'networkx', 'matplotlib'], # Needed on Windows to avoid the default `build` colliding with Bazel's `BUILD`. options={'build': {'build_base': 'python_build_stim'}}, ) diff --git a/src/pymatching/matching.py b/src/pymatching/matching.py index 77a05de1..e51093e1 100644 --- a/src/pymatching/matching.py +++ b/src/pymatching/matching.py @@ -16,14 +16,14 @@ import warnings import numpy as np -import networkx as nx -import rustworkx as rx import pymatching +import networkx as nx from scipy.sparse import csc_matrix, spmatrix import matplotlib.cbook if TYPE_CHECKING: import stim # pragma: no cover + import rustworkx as rx import pymatching._cpp_pymatching as _cpp_pm @@ -38,7 +38,7 @@ class Matching: """ def __init__(self, - graph: Union[csc_matrix, np.ndarray, rx.PyGraph, nx.Graph, List[ + graph: Union[csc_matrix, np.ndarray, "rx.PyGraph", nx.Graph, List[ List[int]], 'stim.DetectorErrorModel', spmatrix] = None, weights: Union[float, np.ndarray, List[float]] = None, error_probabilities: Union[float, np.ndarray, List[float]] = None, @@ -136,22 +136,36 @@ def __init__(self, if graph is None: return del kwargs["H"] + # Networkx graph if isinstance(graph, nx.Graph): self.load_from_networkx(graph) - elif isinstance(graph, rx.PyGraph): - self.load_from_rustworkx(graph) - elif type(graph).__name__ == "DetectorErrorModel": - self._load_from_detector_error_model(graph) - else: - try: - graph = csc_matrix(graph) - except TypeError: - raise TypeError("The type of the input graph is not recognised. `graph` must be " - "a scipy.sparse or numpy matrix, networkx or rustworkx graph, or " - "stim.DetectorErrorModel.") - self.load_from_check_matrix(graph, weights, error_probabilities, - repetitions, timelike_weights, measurement_error_probabilities, - **kwargs) + return + # Rustworkx PyGraph + try: + import rustworkx as rx + if isinstance(graph, rx.PyGraph): + self.load_from_rustworkx(graph) + return + except ImportError: + pass + # stim.DetectorErrorModel + try: + import stim + if isinstance(graph, stim.DetectorErrorModel): + self._load_from_detector_error_model(graph) + return + except ImportError: + pass + # scipy.csc_matrix + try: + graph = csc_matrix(graph) + except TypeError: + raise TypeError("The type of the input graph is not recognised. `graph` must be " + "a scipy.sparse or numpy matrix, networkx or rustworkx graph, or " + "stim.DetectorErrorModel.") + self.load_from_check_matrix(graph, weights, error_probabilities, + repetitions, timelike_weights, measurement_error_probabilities, + **kwargs) def add_noise(self) -> Union[Tuple[np.ndarray, np.ndarray], None]: """Add noise by flipping edges in the matching graph with @@ -210,16 +224,6 @@ def decode(self, (modulo 2) between the (noisy) measurement of stabiliser `i` in time step `j+1` and time step `j` (for the case where the matching graph is constructed from a check matrix with `repetitions>1`). - _legacy_num_neighbours: int - The `num_neighbours` argument available in PyMatching versions 0.x.x is not - available in PyMatching v2.0.0 or later, since it introduced an approximation - that is not relevant or required in the new version 2 implementation. - Providing num_neighbours as this second positional argument will raise an exception in a - future version of PyMatching. - _legacy_return_weight: bool - ``return_weight`` used to be available as this third positional argument, but should now - be set as a keyword argument. In a future version of PyMatching, it will only be possible - to provide `return_weight` as a keyword argument. return_weight : bool, optional If `return_weight==True`, the sum of the weights of the edges in the minimum weight perfect matching is also returned. By default False @@ -314,7 +318,7 @@ def decode(self, if _legacy_return_weight is not None: warnings.warn("The ``return_weights`` argument was provided as a positional argument, but in a future " "version of PyMatching, it will be required to provide ``return_weights`` as a keyword " - "argument.") + "argument.", DeprecationWarning, stacklevel=2) return_weight = _legacy_return_weight detection_events = self._syndrome_array_to_detection_events(z) correction, weight = self._matching_graph.decode(detection_events) @@ -1474,7 +1478,7 @@ def load_from_networkx(self, graph: nx.Graph, *, min_num_fault_ids: int = None) g.add_edge(u, v, fault_ids, weight, e_prob, merge_strategy="smallest-weight") self._matching_graph = g - def load_from_retworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = None) -> None: + def load_from_retworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None) -> None: r""" Load a matching graph from a retworkX graph. This method is deprecated since the retworkx package has been renamed to rustworkx. Please use ``pymatching.Matching.load_from_rustworkx`` instead. @@ -1483,7 +1487,7 @@ def load_from_retworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = None "renamed to `rustworkx`. Please use `pymatching.Matching.load_from_rustworkx` instead.", DeprecationWarning, stacklevel=2) self.load_from_rustworkx(graph=graph, min_num_fault_ids=min_num_fault_ids) - def load_from_rustworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = None) -> None: + def load_from_rustworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None) -> None: r""" Load a matching graph from a rustworkX graph @@ -1525,6 +1529,10 @@ def load_from_rustworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = Non >>> m """ + try: + import rustworkx as rx + except ImportError: + raise ImportError("rustworkx must be installed to use Matching.load_from_rustworkx") if not isinstance(graph, rx.PyGraph): raise TypeError("G must be a rustworkx graph") boundary = {i for i in graph.node_indices() if graph[i].get("is_boundary", False)} @@ -1587,7 +1595,7 @@ def to_networkx(self) -> nx.Graph: graph.nodes[num_nodes]['is_boundary'] = True return graph - def to_retworkx(self) -> rx.PyGraph: + def to_retworkx(self) -> "rx.PyGraph": """Deprecated, use ``pymatching.Matching.to_rustworkx`` instead (since the `retworkx` package has been renamed to `rustworkx`). This method just calls ``pymatching.Matching.to_rustworkx`` and returns a ``rustworkx.PyGraph``, which is now just the preferred name for ``retworkx.PyGraph``. Note that in the future, only the `rustworkx` package name will be supported, @@ -1597,7 +1605,7 @@ def to_retworkx(self) -> rx.PyGraph: "renamed to `rustworkx`. Please use `pymatching.Matching.to_rustworkx` instead.", DeprecationWarning, stacklevel=2) return self.to_rustworkx() - def to_rustworkx(self) -> rx.PyGraph: + def to_rustworkx(self) -> "rx.PyGraph": """Convert to rustworkx graph Returns a rustworkx graph object corresponding to the matching graph. Each edge payload is a ``dict`` with keys `fault_ids`, `weight` and `error_probability` and @@ -1609,6 +1617,11 @@ def to_rustworkx(self) -> rx.PyGraph: rustworkx.PyGraph rustworkx graph corresponding to the matching graph """ + try: + import rustworkx as rx + except ImportError: + raise ImportError("rustworkx must be installed to use Matching.to_rustworkx.") + graph = rx.PyGraph(multigraph=False) num_nodes = self.num_nodes has_virtual_boundary = False diff --git a/tests/matching/load_from_rustworkx_test.py b/tests/matching/load_from_rustworkx_test.py index 85afa946..3bd486d4 100644 --- a/tests/matching/load_from_rustworkx_test.py +++ b/tests/matching/load_from_rustworkx_test.py @@ -13,7 +13,6 @@ # limitations under the License. import numpy as np -import rustworkx as rx import pytest from pymatching import Matching @@ -21,6 +20,7 @@ def test_boundary_from_rustworkx(): + rx = pytest.importorskip("rustworkx") g = rx.PyGraph() g.add_nodes_from([{} for _ in range(5)]) g.add_edge(4, 0, dict(fault_ids=0)) @@ -38,6 +38,7 @@ def test_boundary_from_rustworkx(): def test_boundaries_from_rustworkx(): + rx = pytest.importorskip("rustworkx") g = rx.PyGraph() g.add_nodes_from([{} for _ in range(6)]) g.add_edge(0, 1, dict(fault_ids=0)) @@ -57,6 +58,7 @@ def test_boundaries_from_rustworkx(): def test_unweighted_stabiliser_graph_from_rustworkx(): + rx = pytest.importorskip("rustworkx") w = rx.PyGraph() w.add_nodes_from([{} for _ in range(6)]) w.add_edge(0, 1, dict(fault_ids=0, weight=7.0)) @@ -90,6 +92,7 @@ def test_unweighted_stabiliser_graph_from_rustworkx(): def test_mwpm_from_rustworkx(): + rx = pytest.importorskip("rustworkx") g = rx.PyGraph() g.add_nodes_from([{} for _ in range(3)]) g.add_edge(0, 1, dict(fault_ids=0)) @@ -122,6 +125,7 @@ def test_mwpm_from_rustworkx(): def test_matching_edges_from_rustworkx(): + rx = pytest.importorskip("rustworkx") g = rx.PyGraph() g.add_nodes_from([{} for _ in range(4)]) g.add_edge(0, 1, dict(fault_ids=0, weight=1.1, error_probability=0.1)) @@ -143,6 +147,7 @@ def test_matching_edges_from_rustworkx(): def test_qubit_id_accepted_via_rustworkx(): + rx = pytest.importorskip("rustworkx") g = rx.PyGraph() g.add_nodes_from([{} for _ in range(4)]) g.add_edge(0, 1, dict(qubit_id=0, weight=1.1, error_probability=0.1)) @@ -163,6 +168,7 @@ def test_qubit_id_accepted_via_rustworkx(): def test_load_from_rustworkx_raises_value_error_if_qubit_id_and_fault_ids_both_supplied(): + rx = pytest.importorskip("rustworkx") with pytest.raises(ValueError): g = rx.PyGraph() g.add_nodes_from([{} for _ in range(3)]) @@ -173,6 +179,7 @@ def test_load_from_rustworkx_raises_value_error_if_qubit_id_and_fault_ids_both_s def test_load_from_rustworkx_type_errors_raised(): + rx = pytest.importorskip("rustworkx") with pytest.raises(TypeError): m = Matching() m.load_from_rustworkx("A") @@ -188,3 +195,11 @@ def test_load_from_rustworkx_type_errors_raised(): g.add_edge(0, 1, dict(fault_ids=[[0], [2]])) m = Matching() m.load_from_rustworkx(g) + + +def test_load_from_rustworkx_without_rustworkx_raises_type_error(): + try: + import rustworkx # noqa + except ImportError: + with pytest.raises(TypeError): + Matching.load_from_rustworkx("test") diff --git a/tests/matching/output_graph_test.py b/tests/matching/output_graph_test.py index 33a30b95..98a60bad 100644 --- a/tests/matching/output_graph_test.py +++ b/tests/matching/output_graph_test.py @@ -13,7 +13,7 @@ # limitations under the License. import networkx as nx -import rustworkx as rx +import pytest from pymatching import Matching @@ -60,6 +60,7 @@ def test_matching_to_networkx(): def test_matching_to_rustworkx(): + rx = pytest.importorskip("rustworkx") g = rx.PyGraph() g.add_nodes_from([{} for _ in range(4)]) g.add_edge(0, 1, dict(fault_ids={0}, weight=1.1, error_probability=0.1)) diff --git a/tests/matching/properties_test.py b/tests/matching/properties_test.py index 715dd24a..32c9f14b 100644 --- a/tests/matching/properties_test.py +++ b/tests/matching/properties_test.py @@ -13,7 +13,7 @@ # limitations under the License. import networkx as nx -import rustworkx as rx +import pytest from pymatching.matching import Matching @@ -48,9 +48,13 @@ def test_set_min_num_fault_ids(): assert m.num_fault_ids == 4 assert m.decode([1, 1]).shape[0] == 4 + +def test_set_min_num_fault_ids_rustworkx(): + rx = pytest.importorskip("rustworkx") g = rx.PyGraph() g.add_nodes_from([{} for _ in range(2)]) g.add_edge(0, 1, dict(fault_ids=3)) + m = Matching() m.load_from_rustworkx(g) assert m.num_fault_ids == 4 assert m.decode([1, 1]).shape[0] == 4 diff --git a/tests/matching/retworkx_test.py b/tests/matching/retworkx_test.py index 06b1cc59..14e6a246 100644 --- a/tests/matching/retworkx_test.py +++ b/tests/matching/retworkx_test.py @@ -1,10 +1,10 @@ import pytest from pymatching import Matching -import rustworkx as rx def test_load_from_retworkx_deprecated(): + rx = pytest.importorskip("rustworkx") with pytest.deprecated_call(): g = rx.PyGraph() g.add_nodes_from([{} for _ in range(3)]) @@ -16,6 +16,7 @@ def test_load_from_retworkx_deprecated(): def test_to_retworkx_deprecated(): + _ = pytest.importorskip("rustworkx") with pytest.deprecated_call(): m = Matching() m.add_edge(0, 1, {0})