Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove rustworkx as a dependency #90

Merged
merged 3 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,13 @@ jobs:
- name: Build and install
run: python -m pip install --verbose -e .

- name: Test without stim
- name: Test without stim or rustworkx
run: python -m pytest tests

- name: Add stim
run: python -m pip install stim
- name: Install stim and rustworkx
run: python -m pip install stim rustworkx

- name: Test with stim using coverage
- name: Test with stim and rustworkx using coverage
run: python -m pytest tests --cov=./src/pymatching --cov-report term

- name: flake8
Expand Down Expand Up @@ -244,7 +244,7 @@ jobs:
with:
python-version: '3.10'
- name: Add requirements
run: python -m pip install --upgrade cmake>=3.12 ninja pytest flake8 pytest-cov stim
run: python -m pip install --upgrade cmake>=3.12 ninja pytest flake8 pytest-cov stim rustworkx
- name: Build and install
run: pip install --verbose -e .
- name: Run tests and collect coverage
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
cov.xml
coverage.xml
.coverage
*.ipynb_checkpoints
.idea
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def build_extension(self, ext):
'console_scripts': ['pymatching=pymatching._cli_argv:cli_argv'],
},
python_requires=">=3.7",
install_requires=['scipy', 'numpy', 'networkx', 'rustworkx', 'matplotlib'],
install_requires=['scipy', 'numpy', 'networkx', 'matplotlib'],
# Needed on Windows to avoid the default `build` colliding with Bazel's `BUILD`.
options={'build': {'build_base': 'python_build_stim'}},
)
77 changes: 45 additions & 32 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import warnings

import numpy as np
import networkx as nx
import rustworkx as rx
import pymatching
import networkx as nx
from scipy.sparse import csc_matrix, spmatrix
import matplotlib.cbook

if TYPE_CHECKING:
import stim # pragma: no cover
import rustworkx as rx # pragma: no cover

import pymatching._cpp_pymatching as _cpp_pm

Expand All @@ -38,7 +38,7 @@ class Matching:
"""

def __init__(self,
graph: Union[csc_matrix, np.ndarray, rx.PyGraph, nx.Graph, List[
graph: Union[csc_matrix, np.ndarray, "rx.PyGraph", nx.Graph, List[
List[int]], 'stim.DetectorErrorModel', spmatrix] = None,
weights: Union[float, np.ndarray, List[float]] = None,
error_probabilities: Union[float, np.ndarray, List[float]] = None,
Expand Down Expand Up @@ -136,22 +136,36 @@ def __init__(self,
if graph is None:
return
del kwargs["H"]
# Networkx graph
if isinstance(graph, nx.Graph):
self.load_from_networkx(graph)
elif isinstance(graph, rx.PyGraph):
self.load_from_rustworkx(graph)
elif type(graph).__name__ == "DetectorErrorModel":
self._load_from_detector_error_model(graph)
else:
try:
graph = csc_matrix(graph)
except TypeError:
raise TypeError("The type of the input graph is not recognised. `graph` must be "
"a scipy.sparse or numpy matrix, networkx or rustworkx graph, or "
"stim.DetectorErrorModel.")
self.load_from_check_matrix(graph, weights, error_probabilities,
repetitions, timelike_weights, measurement_error_probabilities,
**kwargs)
return
# Rustworkx PyGraph
try:
import rustworkx as rx
if isinstance(graph, rx.PyGraph):
self.load_from_rustworkx(graph)
return
except ImportError: # pragma no cover
pass
# stim.DetectorErrorModel
try:
import stim
if isinstance(graph, stim.DetectorErrorModel):
self._load_from_detector_error_model(graph)
return
except ImportError: # pragma no cover
pass
# scipy.csc_matrix
try:
graph = csc_matrix(graph)
except TypeError:
raise TypeError("The type of the input graph is not recognised. `graph` must be "
"a scipy.sparse or numpy matrix, networkx or rustworkx graph, or "
"stim.DetectorErrorModel.")
self.load_from_check_matrix(graph, weights, error_probabilities,
repetitions, timelike_weights, measurement_error_probabilities,
**kwargs)

def add_noise(self) -> Union[Tuple[np.ndarray, np.ndarray], None]:
"""Add noise by flipping edges in the matching graph with
Expand Down Expand Up @@ -210,16 +224,6 @@ def decode(self,
(modulo 2) between the (noisy) measurement of stabiliser `i` in time
step `j+1` and time step `j` (for the case where the matching graph is
constructed from a check matrix with `repetitions>1`).
_legacy_num_neighbours: int
The `num_neighbours` argument available in PyMatching versions 0.x.x is not
available in PyMatching v2.0.0 or later, since it introduced an approximation
that is not relevant or required in the new version 2 implementation.
Providing num_neighbours as this second positional argument will raise an exception in a
future version of PyMatching.
_legacy_return_weight: bool
``return_weight`` used to be available as this third positional argument, but should now
be set as a keyword argument. In a future version of PyMatching, it will only be possible
to provide `return_weight` as a keyword argument.
return_weight : bool, optional
If `return_weight==True`, the sum of the weights of the edges in the
minimum weight perfect matching is also returned. By default False
Expand Down Expand Up @@ -314,7 +318,7 @@ def decode(self,
if _legacy_return_weight is not None:
warnings.warn("The ``return_weights`` argument was provided as a positional argument, but in a future "
"version of PyMatching, it will be required to provide ``return_weights`` as a keyword "
"argument.")
"argument.", DeprecationWarning, stacklevel=2)
return_weight = _legacy_return_weight
detection_events = self._syndrome_array_to_detection_events(z)
correction, weight = self._matching_graph.decode(detection_events)
Expand Down Expand Up @@ -1474,7 +1478,7 @@ def load_from_networkx(self, graph: nx.Graph, *, min_num_fault_ids: int = None)
g.add_edge(u, v, fault_ids, weight, e_prob, merge_strategy="smallest-weight")
self._matching_graph = g

def load_from_retworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = None) -> None:
def load_from_retworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None) -> None:
r"""
Load a matching graph from a retworkX graph. This method is deprecated since the retworkx package has been
renamed to rustworkx. Please use ``pymatching.Matching.load_from_rustworkx`` instead.
Expand All @@ -1483,7 +1487,7 @@ def load_from_retworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = None
"renamed to `rustworkx`. Please use `pymatching.Matching.load_from_rustworkx` instead.", DeprecationWarning, stacklevel=2)
self.load_from_rustworkx(graph=graph, min_num_fault_ids=min_num_fault_ids)

def load_from_rustworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = None) -> None:
def load_from_rustworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None) -> None:
r"""
Load a matching graph from a rustworkX graph

Expand Down Expand Up @@ -1525,6 +1529,10 @@ def load_from_rustworkx(self, graph: rx.PyGraph, *, min_num_fault_ids: int = Non
>>> m
<pymatching.Matching object with 1 detector, 2 boundary nodes, and 2 edges>
"""
try:
import rustworkx as rx
except ImportError: # pragma no cover
raise ImportError("rustworkx must be installed to use Matching.load_from_rustworkx")
if not isinstance(graph, rx.PyGraph):
raise TypeError("G must be a rustworkx graph")
boundary = {i for i in graph.node_indices() if graph[i].get("is_boundary", False)}
Expand Down Expand Up @@ -1587,7 +1595,7 @@ def to_networkx(self) -> nx.Graph:
graph.nodes[num_nodes]['is_boundary'] = True
return graph

def to_retworkx(self) -> rx.PyGraph:
def to_retworkx(self) -> "rx.PyGraph":
"""Deprecated, use ``pymatching.Matching.to_rustworkx`` instead (since the `retworkx` package has been renamed to `rustworkx`).
This method just calls ``pymatching.Matching.to_rustworkx`` and returns a ``rustworkx.PyGraph``, which is now just the preferred name for
``retworkx.PyGraph``. Note that in the future, only the `rustworkx` package name will be supported,
Expand All @@ -1597,7 +1605,7 @@ def to_retworkx(self) -> rx.PyGraph:
"renamed to `rustworkx`. Please use `pymatching.Matching.to_rustworkx` instead.", DeprecationWarning, stacklevel=2)
return self.to_rustworkx()

def to_rustworkx(self) -> rx.PyGraph:
def to_rustworkx(self) -> "rx.PyGraph":
"""Convert to rustworkx graph
Returns a rustworkx graph object corresponding to the matching graph. Each edge
payload is a ``dict`` with keys `fault_ids`, `weight` and `error_probability` and
Expand All @@ -1609,6 +1617,11 @@ def to_rustworkx(self) -> rx.PyGraph:
rustworkx.PyGraph
rustworkx graph corresponding to the matching graph
"""
try:
import rustworkx as rx
except ImportError: # pragma no cover
raise ImportError("rustworkx must be installed to use Matching.to_rustworkx.")

graph = rx.PyGraph(multigraph=False)
num_nodes = self.num_nodes
has_virtual_boundary = False
Expand Down
1 change: 1 addition & 0 deletions tests/matching/docstrings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

def test_matching_docstrings():
pytest.importorskip("stim")
pytest.importorskip("rustworkx")
doctest.testmod(pymatching.matching, raise_on_error=True)


Expand Down
17 changes: 16 additions & 1 deletion tests/matching/load_from_rustworkx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.

import numpy as np
import rustworkx as rx
import pytest

from pymatching import Matching
from pymatching._cpp_pymatching import MatchingGraph


def test_boundary_from_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(5)])
g.add_edge(4, 0, dict(fault_ids=0))
Expand All @@ -38,6 +38,7 @@ def test_boundary_from_rustworkx():


def test_boundaries_from_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(6)])
g.add_edge(0, 1, dict(fault_ids=0))
Expand All @@ -57,6 +58,7 @@ def test_boundaries_from_rustworkx():


def test_unweighted_stabiliser_graph_from_rustworkx():
rx = pytest.importorskip("rustworkx")
w = rx.PyGraph()
w.add_nodes_from([{} for _ in range(6)])
w.add_edge(0, 1, dict(fault_ids=0, weight=7.0))
Expand Down Expand Up @@ -90,6 +92,7 @@ def test_unweighted_stabiliser_graph_from_rustworkx():


def test_mwpm_from_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(3)])
g.add_edge(0, 1, dict(fault_ids=0))
Expand Down Expand Up @@ -122,6 +125,7 @@ def test_mwpm_from_rustworkx():


def test_matching_edges_from_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(4)])
g.add_edge(0, 1, dict(fault_ids=0, weight=1.1, error_probability=0.1))
Expand All @@ -143,6 +147,7 @@ def test_matching_edges_from_rustworkx():


def test_qubit_id_accepted_via_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(4)])
g.add_edge(0, 1, dict(qubit_id=0, weight=1.1, error_probability=0.1))
Expand All @@ -163,6 +168,7 @@ def test_qubit_id_accepted_via_rustworkx():


def test_load_from_rustworkx_raises_value_error_if_qubit_id_and_fault_ids_both_supplied():
rx = pytest.importorskip("rustworkx")
with pytest.raises(ValueError):
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(3)])
Expand All @@ -173,6 +179,7 @@ def test_load_from_rustworkx_raises_value_error_if_qubit_id_and_fault_ids_both_s


def test_load_from_rustworkx_type_errors_raised():
rx = pytest.importorskip("rustworkx")
with pytest.raises(TypeError):
m = Matching()
m.load_from_rustworkx("A")
Expand All @@ -188,3 +195,11 @@ def test_load_from_rustworkx_type_errors_raised():
g.add_edge(0, 1, dict(fault_ids=[[0], [2]]))
m = Matching()
m.load_from_rustworkx(g)


def test_load_from_rustworkx_without_rustworkx_raises_type_error():
try:
import rustworkx # noqa
except ImportError:
with pytest.raises(TypeError):
Matching.load_from_rustworkx("test")
3 changes: 2 additions & 1 deletion tests/matching/output_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import networkx as nx
import rustworkx as rx
import pytest

from pymatching import Matching

Expand Down Expand Up @@ -60,6 +60,7 @@ def test_matching_to_networkx():


def test_matching_to_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(4)])
g.add_edge(0, 1, dict(fault_ids={0}, weight=1.1, error_probability=0.1))
Expand Down
6 changes: 5 additions & 1 deletion tests/matching/properties_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import networkx as nx
import rustworkx as rx
import pytest

from pymatching.matching import Matching

Expand Down Expand Up @@ -48,9 +48,13 @@ def test_set_min_num_fault_ids():
assert m.num_fault_ids == 4
assert m.decode([1, 1]).shape[0] == 4


def test_set_min_num_fault_ids_rustworkx():
rx = pytest.importorskip("rustworkx")
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(2)])
g.add_edge(0, 1, dict(fault_ids=3))
m = Matching()
m.load_from_rustworkx(g)
assert m.num_fault_ids == 4
assert m.decode([1, 1]).shape[0] == 4
Expand Down
3 changes: 2 additions & 1 deletion tests/matching/retworkx_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest

from pymatching import Matching
import rustworkx as rx


def test_load_from_retworkx_deprecated():
rx = pytest.importorskip("rustworkx")
with pytest.deprecated_call():
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(3)])
Expand All @@ -16,6 +16,7 @@ def test_load_from_retworkx_deprecated():


def test_to_retworkx_deprecated():
_ = pytest.importorskip("rustworkx")
with pytest.deprecated_call():
m = Matching()
m.add_edge(0, 1, {0})
Expand Down
Loading