Skip to content

Commit

Permalink
Refactor by adding UserGraph::to_matching_or_search_graph_helper
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarhiggott committed Feb 1, 2024
1 parent 5d90773 commit 9710f5a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 41 deletions.
48 changes: 8 additions & 40 deletions src/pymatching/sparse_blossom/driver/user_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,31 +246,15 @@ double pm::UserGraph::max_abs_weight() {
pm::MatchingGraph pm::UserGraph::to_matching_graph(pm::weight_int num_distinct_weights) {
pm::MatchingGraph matching_graph(nodes.size(), _num_observables);

// Use vectors to store boundary edges initially before adding them to matching_graph, so
// that parallel boundary edges with negative edge weights can be handled correctly
std::vector<bool> has_boundary_edge(nodes.size(), false);
std::vector<pm::signed_weight_int> boundary_edge_weights(nodes.size());
std::vector<std::vector<size_t>> boundary_edge_observables(nodes.size());

double normalising_constant = iter_discretized_edges(
double normalising_constant = to_matching_or_search_graph_helper(
num_distinct_weights,
[&](size_t u, size_t v, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
matching_graph.add_edge(u, v, weight, observables);
},
[&](size_t u, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
// For parallel boundary edges, keep the boundary edge with the smaller weight
if (!has_boundary_edge[u] || boundary_edge_weights[u] > weight){
boundary_edge_weights[u] = weight;
boundary_edge_observables[u] = observables;
has_boundary_edge[u] = true;
}
});

// Now add boundary edges to matching_graph
for (size_t i = 0; i < has_boundary_edge.size(); i++) {
if (has_boundary_edge[i])
matching_graph.add_boundary_edge(i, boundary_edge_weights[i], boundary_edge_observables[i]);
}
matching_graph.add_boundary_edge(u, weight, observables);
}
);

matching_graph.normalising_constant = normalising_constant;
if (boundary_nodes.size() > 0) {
Expand All @@ -286,31 +270,15 @@ pm::SearchGraph pm::UserGraph::to_search_graph(pm::weight_int num_distinct_weigh
/// Identical to to_matching_graph but for constructing a pm::SearchGraph
pm::SearchGraph search_graph(nodes.size());

// Use vectors to store boundary edges initially before adding them to search_graph, so
// that parallel boundary edges with negative edge weights can be handled correctly
std::vector<bool> has_boundary_edge(nodes.size(), false);
std::vector<pm::signed_weight_int> boundary_edge_weights(nodes.size());
std::vector<std::vector<size_t>> boundary_edge_observables(nodes.size());

double normalising_constant = iter_discretized_edges(
to_matching_or_search_graph_helper(
num_distinct_weights,
[&](size_t u, size_t v, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
search_graph.add_edge(u, v, weight, observables);
},
[&](size_t u, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
// For parallel boundary edges, keep the boundary edge with the smaller weight
if (!has_boundary_edge[u] || boundary_edge_weights[u] > weight){
boundary_edge_weights[u] = weight;
boundary_edge_observables[u] = observables;
has_boundary_edge[u] = true;
}
});

// Now add boundary edges to search_graph
for (size_t i = 0; i < has_boundary_edge.size(); i++) {
if (has_boundary_edge[i])
search_graph.add_boundary_edge(i, boundary_edge_weights[i], boundary_edge_observables[i]);
}
search_graph.add_boundary_edge(u, weight, observables);
}
);
return search_graph;
}

Expand Down
38 changes: 37 additions & 1 deletion src/pymatching/sparse_blossom/driver/user_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ class UserGraph {
pm::weight_int num_distinct_weights,
const EdgeCallable& edge_func,
const BoundaryEdgeCallable& boundary_edge_func);
template <typename EdgeCallable, typename BoundaryEdgeCallable>
double to_matching_or_search_graph_helper(
pm::weight_int num_distinct_weights,
const EdgeCallable& edge_func,
const BoundaryEdgeCallable& boundary_edge_func);
pm::MatchingGraph to_matching_graph(pm::weight_int num_distinct_weights);
pm::SearchGraph to_search_graph(pm::weight_int num_distinct_weights);
pm::Mwpm to_mwpm(pm::weight_int num_distinct_weights, bool ensure_search_graph_included);
Expand All @@ -120,7 +125,6 @@ inline double UserGraph::iter_discretized_edges(
pm::weight_int num_distinct_weights,
const EdgeCallable& edge_func,
const BoundaryEdgeCallable& boundary_edge_func) {
pm::MatchingGraph matching_graph(nodes.size(), _num_observables);
double normalising_constant = get_edge_weight_normalising_constant(num_distinct_weights);

for (auto& e : edges) {
Expand All @@ -141,6 +145,38 @@ inline double UserGraph::iter_discretized_edges(
return normalising_constant * 2;
}

template <typename EdgeCallable, typename BoundaryEdgeCallable>
inline double UserGraph::to_matching_or_search_graph_helper(
pm::weight_int num_distinct_weights,
const EdgeCallable& edge_func,
const BoundaryEdgeCallable& boundary_edge_func) {

// Use vectors to store boundary edges initially before adding them to the graph, so
// that parallel boundary edges with negative edge weights can be handled correctly
std::vector<bool> has_boundary_edge(nodes.size(), false);
std::vector<pm::signed_weight_int> boundary_edge_weights(nodes.size());
std::vector<std::vector<size_t>> boundary_edge_observables(nodes.size());

double normalising_constant = iter_discretized_edges(
num_distinct_weights,
edge_func,
[&](size_t u, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
// For parallel boundary edges, keep the boundary edge with the smaller weight
if (!has_boundary_edge[u] || boundary_edge_weights[u] > weight){
boundary_edge_weights[u] = weight;
boundary_edge_observables[u] = observables;
has_boundary_edge[u] = true;
}
});

// Now add boundary edges to the graph
for (size_t i = 0; i < has_boundary_edge.size(); i++) {
if (has_boundary_edge[i])
boundary_edge_func(i, boundary_edge_weights[i], boundary_edge_observables[i]);
}
return normalising_constant;
}

UserGraph detector_error_model_to_user_graph(const stim::DetectorErrorModel& detector_error_model);

} // namespace pm
Expand Down

0 comments on commit 9710f5a

Please sign in to comment.