Skip to content

Commit

Permalink
♻️ adjust DD addition logic
Browse files Browse the repository at this point in the history
Signed-off-by: burgholzer <[email protected]>
  • Loading branch information
burgholzer committed Sep 14, 2023
1 parent c3bbfc9 commit a74e4b1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
16 changes: 13 additions & 3 deletions include/dd/NoiseFunctionality.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,12 @@ template <class Config> class DeterministicNoiseFunctionality {
complexProb.r->value = probability;
if (!e[0].w.exactlyZero()) {
const auto w = package->cn.mulCached(complexProb, e[3].w);
const auto tmp = package->add2(e[0], {e[3].p, w});
const auto var =
static_cast<Qubit>(std::max({e[0].p != nullptr ? e[0].p->v : 0,
e[1].p != nullptr ? e[1].p->v : 0,
e[2].p != nullptr ? e[2].p->v : 0,
e[3].p != nullptr ? e[3].p->v : 0}));
const auto tmp = package->add2(e[0], {e[3].p, w}, var);
package->cn.returnToCache(w);
package->cn.returnToCache(e[0].w);
e[0] = tmp;
Expand Down Expand Up @@ -464,6 +469,11 @@ template <class Config> class DeterministicNoiseFunctionality {
Complex complexProb = package->cn.getCached();
complexProb.i->value = 0;

const auto var = static_cast<Qubit>(std::max(
{e[0].p != nullptr ? e[0].p->v : 0, e[1].p != nullptr ? e[1].p->v : 0,
e[2].p != nullptr ? e[2].p->v : 0,
e[3].p != nullptr ? e[3].p->v : 0}));

qc::DensityMatrixDD oldE0Edge{e[0].p, package->cn.getCached(e[0].w)};

// e[0] = 0.5*((2-p)*e[0] + p*e[3])
Expand All @@ -488,7 +498,7 @@ template <class Config> class DeterministicNoiseFunctionality {

// e[0] = helperEdge[0] + helperEdge[1]
package->cn.returnToCache(e[0].w);
e[0] = package->add2(helperEdge[0], helperEdge[1]);
e[0] = package->add2(helperEdge[0], helperEdge[1], var);
package->cn.returnToCache(helperEdge[0].w);
package->cn.returnToCache(helperEdge[1].w);
}
Expand Down Expand Up @@ -535,7 +545,7 @@ template <class Config> class DeterministicNoiseFunctionality {
}

package->cn.returnToCache(e[3].w);
e[3] = package->add2(helperEdge[0], helperEdge[1]);
e[3] = package->add2(helperEdge[0], helperEdge[1], var);
package->cn.returnToCache(helperEdge[0].w);
package->cn.returnToCache(helperEdge[1].w);
}
Expand Down
45 changes: 28 additions & 17 deletions include/dd/Package.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1576,7 +1576,17 @@ template <class Config> class Package {
template <class Edge> Edge add(const Edge& x, const Edge& y) {
[[maybe_unused]] const auto before = cn.cacheCount();

auto result = add2(x, y);
Qubit var{};
if (!x.isTerminal()) {
assert(x.p != nullptr);
var = x.p->v;
}
if (!y.isTerminal() && (y.p->v) > var) {
assert(y.p != nullptr);
var = y.p->v;
}

auto result = add2(x, y, var);
result.w = cn.lookup(result.w, true);

[[maybe_unused]] const auto after = cn.cacheCount();
Expand All @@ -1586,7 +1596,7 @@ template <class Config> class Package {
}

template <class Node>
Edge<Node> add2(const Edge<Node>& x, const Edge<Node>& y) {
Edge<Node> add2(const Edge<Node>& x, const Edge<Node>& y, const Qubit var) {
if (x.w.exactlyZero()) {
if (y.w.exactlyZero()) {
return Edge<Node>::zero;
Expand Down Expand Up @@ -1615,15 +1625,16 @@ template <class Config> class Package {
return {r->p, cn.getCached(r->w)};
}

const Qubit w = (x.isTerminal() || (!y.isTerminal() && y.p->v > x.p->v))
? y.p->v
: x.p->v;

constexpr std::size_t n = std::tuple_size_v<decltype(x.p->e)>;
std::array<Edge<Node>, n> edge{};
for (std::size_t i = 0U; i < n; i++) {
if constexpr (std::is_same_v<Node, vNode>) {
assert(!x.isTerminal() && "x must not be terminal");
assert(!y.isTerminal() && "y must not be terminal");
assert(x.p->v == y.p->v && "x and y must be at the same level");
}
Edge<Node> e1{};
if (!x.isTerminal() && x.p->v == w) {
if (!x.isTerminal()) {
e1 = x.p->e[i];

if (!e1.w.exactlyZero()) {
Expand All @@ -1636,7 +1647,7 @@ template <class Config> class Package {
}
}
Edge<Node> e2{};
if (!y.isTerminal() && y.p->v == w) {
if (!y.isTerminal()) {
e2 = y.p->e[i];

if (!e2.w.exactlyZero()) {
Expand All @@ -1651,22 +1662,22 @@ template <class Config> class Package {

if constexpr (std::is_same_v<Node, dNode>) {
dEdge::applyDmChangesToEdges(e1, e2);
edge[i] = add2(e1, e2);
edge[i] = add2(e1, e2, var - 1);
dEdge::revertDmChangesToEdges(e1, e2);
} else {
edge[i] = add2(e1, e2);
edge[i] = add2(e1, e2, var - 1);
}

if (!x.isTerminal() && x.p->v == w) {
if (!x.isTerminal() && x.p->v == var) {
cn.returnToCache(e1.w);
}

if (!y.isTerminal() && y.p->v == w) {
if (!y.isTerminal() && y.p->v == var) {
cn.returnToCache(e2.w);
}
}

auto e = makeDDNode(w, edge, true);
auto e = makeDDNode(var, edge, true);

computeTable.insert({x.p, x.w}, {y.p, y.w}, {e.p, e.w});
return e;
Expand Down Expand Up @@ -1917,7 +1928,7 @@ template <class Config> class Package {
} else if (!m.w.exactlyZero()) {
dEdge::applyDmChangesToEdges(edge[idx], m);
const auto w = edge[idx].w;
edge[idx] = add2(edge[idx], m);
edge[idx] = add2(edge[idx], m, v);
dEdge::revertDmChangesToEdges(edge[idx], e2);
cn.returnToCache(w);
cn.returnToCache(m.w);
Expand All @@ -1931,7 +1942,7 @@ template <class Config> class Package {
edge[idx] = m;
} else if (!m.w.exactlyZero()) {
const auto w = edge[idx].w;
edge[idx] = add2(edge[idx], m);
edge[idx] = add2(edge[idx], m, v);
cn.returnToCache(w);
cn.returnToCache(m.w);
}
Expand Down Expand Up @@ -2262,11 +2273,11 @@ template <class Config> class Package {
auto r = mEdge::zero;

const auto t0 = trace(a.p->e[0], eliminate, elims);
r = add2(r, t0);
r = add2(r, t0, v - 1);
auto r1 = r;

const auto t1 = trace(a.p->e[3], eliminate, elims);
r = add2(r, t1);
r = add2(r, t1, v - 1);
auto r2 = r;

if (r.w.exactlyOne()) {
Expand Down

0 comments on commit a74e4b1

Please sign in to comment.