Skip to content

Commit

Permalink
Remove rustworkx as a dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarhiggott committed Feb 18, 2024
1 parent b802cc5 commit 5369b80
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 41 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,13 @@ jobs:
- name: Build and install
run: python -m pip install --verbose -e .

- name: Test without stim
- name: Test without stim or rustworkx
run: python -m pytest tests

- name: Add stim
run: python -m pip install stim
- name: Install stim and rustworkx
run: python -m pip install stim rustworkx

- name: Test with stim using coverage
- name: Test with stim and rustworkx using coverage
run: python -m pytest tests --cov=./src/pymatching --cov-report term

- name: flake8
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def build_extension(self, ext):
'console_scripts': ['pymatching=pymatching._cli_argv:cli_argv'],
},
python_requires=">=3.7",
install_requires=['scipy', 'numpy', 'networkx', 'rustworkx', 'matplotlib'],
install_requires=['scipy', 'numpy', 'networkx', 'matplotlib'],
# Needed on Windows to avoid the default `build` colliding with Bazel's `BUILD`.
options={'build': {'build_base': 'python_build_stim'}},
)
77 changes: 45 additions & 32 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import warnings

import numpy as np
import networkx as nx
import rustworkx as rx
import pymatching
import networkx as nx
from scipy.sparse import csc_matrix, spmatrix
import matplotlib.cbook

if TYPE_CHECKING:
import stim # pragma: no cover
import rustworkx as rx

import pymatching._cpp_pymatching as _cpp_pm

Expand All @@ -38,7 +38,7 @@ class Matching:
"""

def __init__(self,
graph: Union[csc_matrix, np.ndarray, rx.PyGraph, nx.Graph, List[
graph: Union[csc_matrix, np.ndarray, "rx.PyGraph", nx.Graph, List[
List[int]], 'stim.DetectorErrorModel', spmatrix] = None,
weights: Union[float, np.ndarray, List[float]] = None,
error_probabilities: Union[float, np.ndarray, List[float]] = None,
Expand Down Expand Up @@ -136,22 +136,36 @@ def __init__(self,
if graph is None:
return
del kwargs["H"]
# Networkx graph
if isinstance(graph, nx.Graph):
self.load_from_networkx(graph)
elif isinstance(graph, rx.PyGraph):
self.load_from_rustworkx(graph)
elif type(graph).__name__ == "DetectorErrorModel":
self._load_from_detector_error_model(graph)
else:
try:
graph = csc_matrix(graph)
except TypeError:
raise TypeError("The type of the input graph is not recognised. `graph` must be "
"a scipy.sparse or numpy matrix, networkx or rustworkx graph, or "
"stim.DetectorErrorModel.")
self.load_from_check_matrix(graph, weights, error_probabilities,
repetitions, timelike_weights, measurement_error_probabilities,
**kwargs)
return
# Rustworkx PyGraph
try:
import rustworkx as rx
if isinstance(graph, rx.PyGraph):
self.load_from_rustworkx(graph)
return
except ImportError:
pass
# stim.DetectorErrorModel
try:
import stim
if isinstance(graph, stim.DetectorErrorModel):
self._load_from_detector_error_model(graph)
return
except ImportError:
pass
# scipy.csc_matrix
try:
graph = csc_matrix(graph)
except TypeError:
raise TypeError("The type of the input graph is not recognised. `graph` must be "
"a scipy.sparse or numpy matrix, networkx or rustworkx graph, or "
"stim.DetectorErrorModel.")
self.load_from_check_matrix(graph, weights, error_probabilities,
repetitions, timelike_weights, measurement_error_probabilities,
**kwargs)

def add_noise(self) -> Union[Tuple[np.ndarray, np.ndarray], None]:
"""Add noise by flipping edges in the matching graph with
Expand Down Expand Up @@ -210,16 +224,6 @@ def decode(self,
(modulo 2) between the (noisy) measurement of stabiliser `i` in time
step `j+1` and time step `j` (for the case where the matching graph is
constructed from a check matrix with `repetitions>1`).
_legacy_num_neighbours: int
The `num_neighbours` argument available in PyMatching versions 0.x.x is not
available in PyMatching v2.0.0 or later, since it introduced an approximation
that is not relevant or required in the new version 2 implementation.
Providing num_neighbours as this second positional argument will raise an exception in a
future version of PyMatching.
_legacy_return_weight: bool
``return_weight`` used to be available as this third positional argument, but should now
be set as a keyword argument. In a future version of PyMatching, it will only be possible
to provide `return_weight` as a keyword argument.
return_weight : bool, optional
If `return_weight==True`, the sum of the weights of the edges in the
minimum weight perfect matching is also returned. By default False
Expand Down Expand Up @@ -314,7 +318,7 @@ def decode(self,
if _legacy_return_weight is not None:
warnings.warn("The ``return_weights`` argument was provided as a positional argument, but in a future "
"version of PyMatching, it will be required to provide ``return_weights`` as a keyword "
"argument.")
"argument.", DeprecationWarning, stacklevel=2)
return_weight = _legacy_return_weight
detection_events = self._syndrome_array_to_detection_events(z)
correction, weight = self._matching_graph.decode(detection_events)
Expand Down Expand Up @@ -1474,7 +1478,7 @@ def load_from_networkx(self, graph: nx.Graph, *, min_num_fault_ids: int = None)
g.add_edge(u, v, fault_ids, weight, e_prob, merge_strategy="smallest-weight")
self._matching_graph = g

def load_from_retworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = None) -> None:
def load_from_retworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None) -> None:
r"""
Load a matching graph from a retworkX graph. This method is deprecated since the retworkx package has been
renamed to rustworkx. Please use ``pymatching.Matching.load_from_rustworkx`` instead.
Expand All @@ -1483,7 +1487,7 @@ def load_from_retworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = None
"renamed to `rustworkx`. Please use `pymatching.Matching.load_from_rustworkx` instead.", DeprecationWarning, stacklevel=2)
self.load_from_rustworkx(graph=graph, min_num_fault_ids=min_num_fault_ids)

def load_from_rustworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = None) -> None:
def load_from_rustworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None) -> None:
r"""
Load a matching graph from a rustworkX graph
Expand Down Expand Up @@ -1525,6 +1529,10 @@ def load_from_rustworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = Non
>>> m
<pymatching.Matching object with 1 detector, 2 boundary nodes, and 2 edges>
"""
try:
import rustworkx as rx
except ImportError:
raise ImportError("rustworkx must be installed to use Matching.load_from_rustworkx")
if not isinstance(graph, rx.PyGraph):
raise TypeError("G must be a rustworkx graph")
boundary = {i for i in graph.node_indices() if graph[i].get("is_boundary", False)}
Expand Down Expand Up @@ -1587,7 +1595,7 @@ def to_networkx(self) -> nx.Graph:
graph.nodes[num_nodes]['is_boundary'] = True
return graph

def to_retworkx(self) -> rx.PyGraph:
def to_retworkx(self) -> "rx.PyGraph":
"""Deprecated, use ``pymatching.Matching.to_rustworkx`` instead (since the `retworkx` package has been renamed to `rustworkx`).
This method just calls ``pymatching.Matching.to_rustworkx`` and returns a ``rustworkx.PyGraph``, which is now just the preferred name for
``retworkx.PyGraph``. Note that in the future, only the `rustworkx` package name will be supported,
Expand All @@ -1597,7 +1605,7 @@ def to_retworkx(self) -> rx.PyGraph:
"renamed to `rustworkx`. Please use `pymatching.Matching.to_rustworkx` instead.", DeprecationWarning, stacklevel=2)
return self.to_rustworkx()

def to_rustworkx(self) -> rx.PyGraph:
def to_rustworkx(self) -> "rx.PyGraph":
"""Convert to rustworkx graph
Returns a rustworkx graph object corresponding to the matching graph. Each edge
payload is a ``dict`` with keys `fault_ids`, `weight` and `error_probability` and
Expand All @@ -1609,6 +1617,11 @@ def to_rustworkx(self) -> rx.PyGraph:
rustworkx.PyGraph
rustworkx graph corresponding to the matching graph
"""
try:
import rustworkx as rx
except ImportError:
raise ImportError("rustworkx must be installed to use Matching.to_rustworkx.")

graph = rx.PyGraph(multigraph=False)
num_nodes = self.num_nodes
has_virtual_boundary = False
Expand Down
17 changes: 16 additions & 1 deletion tests/matching/load_from_rustworkx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.

import numpy as np
import rustworkx as rx
import pytest

from pymatching import Matching
from pymatching._cpp_pymatching import MatchingGraph


def test_boundary_from_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(5)])
g.add_edge(4, 0, dict(fault_ids=0))
Expand All @@ -38,6 +38,7 @@ def test_boundary_from_rustworkx():


def test_boundaries_from_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(6)])
g.add_edge(0, 1, dict(fault_ids=0))
Expand All @@ -57,6 +58,7 @@ def test_boundaries_from_rustworkx():


def test_unweighted_stabiliser_graph_from_rustworkx():
rx = pytest.importorskip("rustworkx")
w = rx.PyGraph()
w.add_nodes_from([{} for _ in range(6)])
w.add_edge(0, 1, dict(fault_ids=0, weight=7.0))
Expand Down Expand Up @@ -90,6 +92,7 @@ def test_unweighted_stabiliser_graph_from_rustworkx():


def test_mwpm_from_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(3)])
g.add_edge(0, 1, dict(fault_ids=0))
Expand Down Expand Up @@ -122,6 +125,7 @@ def test_mwpm_from_rustworkx():


def test_matching_edges_from_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(4)])
g.add_edge(0, 1, dict(fault_ids=0, weight=1.1, error_probability=0.1))
Expand All @@ -143,6 +147,7 @@ def test_matching_edges_from_rustworkx():


def test_qubit_id_accepted_via_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(4)])
g.add_edge(0, 1, dict(qubit_id=0, weight=1.1, error_probability=0.1))
Expand All @@ -163,6 +168,7 @@ def test_qubit_id_accepted_via_rustworkx():


def test_load_from_rustworkx_raises_value_error_if_qubit_id_and_fault_ids_both_supplied():
rx = pytest.importorskip("rustworkx")
with pytest.raises(ValueError):
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(3)])
Expand All @@ -173,6 +179,7 @@ def test_load_from_rustworkx_raises_value_error_if_qubit_id_and_fault_ids_both_s


def test_load_from_rustworkx_type_errors_raised():
rx = pytest.importorskip("rustworkx")
with pytest.raises(TypeError):
m = Matching()
m.load_from_rustworkx("A")
Expand All @@ -188,3 +195,11 @@ def test_load_from_rustworkx_type_errors_raised():
g.add_edge(0, 1, dict(fault_ids=[[0], [2]]))
m = Matching()
m.load_from_rustworkx(g)


def test_load_from_rustworkx_without_rustworkx_raises_type_error():
try:
import rustworkx # noqa
except ImportError:
with pytest.raises(TypeError):
Matching.load_from_rustworkx("test")
3 changes: 2 additions & 1 deletion tests/matching/output_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import networkx as nx
import rustworkx as rx
import pytest

from pymatching import Matching

Expand Down Expand Up @@ -60,6 +60,7 @@ def test_matching_to_networkx():


def test_matching_to_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(4)])
g.add_edge(0, 1, dict(fault_ids={0}, weight=1.1, error_probability=0.1))
Expand Down
6 changes: 5 additions & 1 deletion tests/matching/properties_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import networkx as nx
import rustworkx as rx
import pytest

from pymatching.matching import Matching

Expand Down Expand Up @@ -48,9 +48,13 @@ def test_set_min_num_fault_ids():
assert m.num_fault_ids == 4
assert m.decode([1, 1]).shape[0] == 4


def test_set_min_num_fault_ids_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(2)])
g.add_edge(0, 1, dict(fault_ids=3))
m = Matching()
m.load_from_rustworkx(g)
assert m.num_fault_ids == 4
assert m.decode([1, 1]).shape[0] == 4
Expand Down
3 changes: 2 additions & 1 deletion tests/matching/retworkx_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest

from pymatching import Matching
import rustworkx as rx


def test_load_from_retworkx_deprecated():
rx = pytest.importorskip("rustworkx")
with pytest.deprecated_call():
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(3)])
Expand All @@ -16,6 +16,7 @@ def test_load_from_retworkx_deprecated():


def test_to_retworkx_deprecated():
_ = pytest.importorskip("rustworkx")
with pytest.deprecated_call():
m = Matching()
m.add_edge(0, 1, {0})
Expand Down

0 comments on commit 5369b80

Please sign in to comment.