Skip to content

Commit

Permalink
⚡ properly handle identities in DD multiplication
Browse files Browse the repository at this point in the history
Signed-off-by: burgholzer <[email protected]>
  • Loading branch information
burgholzer committed Sep 14, 2023
1 parent 52813cc commit 69a98f9
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 45 deletions.
1 change: 1 addition & 0 deletions include/dd/Edge.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ template <class Node> struct Edge {
[[nodiscard]] bool isTerminal() const;
[[nodiscard]] bool isZeroTerminal() const;
[[nodiscard]] bool isOneTerminal() const;
[[nodiscard]] bool isIdentity() const;

// Functions only related to density matrices
[[maybe_unused]] static void setDensityConjugateTrue(Edge& e);
Expand Down
4 changes: 4 additions & 0 deletions include/dd/Node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ struct dNode { // NOLINT(readability-identifier-naming)
}
static constexpr dNode* getTerminal() noexcept { return nullptr; }

[[nodiscard]] static constexpr bool isIdentity(const dNode* p) noexcept {
return p == nullptr;
}

[[nodiscard]] [[maybe_unused]] static inline bool
tempDensityMatrixFlagsEqual(const std::uint8_t a,
const std::uint8_t b) noexcept {
Expand Down
56 changes: 11 additions & 45 deletions include/dd/Package.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1839,8 +1839,15 @@ template <class Config> class Package {
return ResultEdge::zero;
}

if (x.isTerminal() && y.isTerminal()) {
return ResultEdge::terminal(cn.mulCached(x.w, y.w));
if (x.isIdentity()) {
return {y.p, cn.mulCached(x.w, y.w)};
}

if constexpr (std::is_same_v<RightOperandNode, mNode> ||
std::is_same_v<RightOperandNode, dNode>) {
if (y.isIdentity()) {
return {x.p, cn.mulCached(x.w, y.w)};
}
}

auto xCopy = LEdge{x.p, Complex::one};
Expand All @@ -1865,48 +1872,6 @@ template <class Config> class Package {
}

constexpr std::size_t n = std::tuple_size_v<decltype(y.p->e)>;
ResultEdge e{};
if constexpr (std::is_same_v<RightOperandNode, mCachedEdge>) {
// This branch is only taken for matrices
if (x.p->v == var && x.p->v == y.p->v) {
if (x.p->isIdentity()) {
if constexpr (n == NEDGE) {
// additionally check if y is the identity in case of matrix
// multiplication
if (y.p->isIdentity()) {
e = makeIdent(start, var);
} else {
e = yCopy;
}
} else {
e = yCopy;
}
computeTable.insert(xCopy, yCopy, {e.p, e.w});
e.w = cn.mulCached(x.w, y.w);
if (e.w.approximatelyZero()) {
cn.returnToCache(e.w);
return ResultEdge::zero;
}
return e;
}

if constexpr (n == NEDGE) {
// additionally check if y is the identity in case of matrix
// multiplication
if (y.p->isIdentity()) {
e = xCopy;
computeTable.insert(xCopy, yCopy, {e.p, e.w});
e.w = cn.mulCached(x.w, y.w);

if (e.w.approximatelyZero()) {
cn.returnToCache(e.w);
return ResultEdge::zero;
}
return e;
}
}
}
}

constexpr std::size_t rows = RADIX;
constexpr std::size_t cols = n == NEDGE ? RADIX : 1U;
Expand Down Expand Up @@ -1982,7 +1947,8 @@ template <class Config> class Package {
}
}
}
e = makeDDNode(var, edge, true, generateDensityMatrix);

auto e = makeDDNode(var, edge, true, generateDensityMatrix);
computeTable.insert(xCopy, yCopy, {e.p, e.w});

if (!e.w.exactlyZero()) {
Expand Down
7 changes: 7 additions & 0 deletions src/dd/Edge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ template <class Node> bool Edge<Node>::isOneTerminal() const {
return isTerminal() && w == Complex::one;
}

template <class Node> bool Edge<Node>::isIdentity() const {
if constexpr (std::is_same_v<Node, mNode> || std::is_same_v<Node, dNode>) {
return Node::isIdentity(p);
}
return false;
}

template <typename Node>
CachedEdge<Node>::CachedEdge(Node* n, const Complex& c) : p(n) {
w.r = RealNumber::val(c.r);
Expand Down

0 comments on commit 69a98f9

Please sign in to comment.