diff --git a/include/dd/Edge.hpp b/include/dd/Edge.hpp index 6cb791746..95b4912fc 100644 --- a/include/dd/Edge.hpp +++ b/include/dd/Edge.hpp @@ -32,6 +32,7 @@ template struct Edge { [[nodiscard]] bool isTerminal() const; [[nodiscard]] bool isZeroTerminal() const; [[nodiscard]] bool isOneTerminal() const; + [[nodiscard]] bool isIdentity() const; // Functions only related to density matrices [[maybe_unused]] static void setDensityConjugateTrue(Edge& e); diff --git a/include/dd/Node.hpp b/include/dd/Node.hpp index ed23acd1b..793ae312a 100644 --- a/include/dd/Node.hpp +++ b/include/dd/Node.hpp @@ -105,6 +105,10 @@ struct dNode { // NOLINT(readability-identifier-naming) } static constexpr dNode* getTerminal() noexcept { return nullptr; } + [[nodiscard]] static constexpr bool isIdentity(const dNode* p) noexcept { + return p == nullptr; + } + [[nodiscard]] [[maybe_unused]] static inline bool tempDensityMatrixFlagsEqual(const std::uint8_t a, const std::uint8_t b) noexcept { diff --git a/include/dd/Package.hpp b/include/dd/Package.hpp index 1ee38fca3..9b1f97beb 100644 --- a/include/dd/Package.hpp +++ b/include/dd/Package.hpp @@ -1839,8 +1839,15 @@ template class Package { return ResultEdge::zero; } - if (x.isTerminal() && y.isTerminal()) { - return ResultEdge::terminal(cn.mulCached(x.w, y.w)); + if (x.isIdentity()) { + return {y.p, cn.mulCached(x.w, y.w)}; + } + + if constexpr (std::is_same_v || + std::is_same_v) { + if (y.isIdentity()) { + return {x.p, cn.mulCached(x.w, y.w)}; + } } auto xCopy = LEdge{x.p, Complex::one}; @@ -1865,48 +1872,6 @@ template class Package { } constexpr std::size_t n = std::tuple_size_ve)>; - ResultEdge e{}; - if constexpr (std::is_same_v) { - // This branch is only taken for matrices - if (x.p->v == var && x.p->v == y.p->v) { - if (x.p->isIdentity()) { - if constexpr (n == NEDGE) { - // additionally check if y is the identity in case of matrix - // multiplication - if (y.p->isIdentity()) { - e = makeIdent(start, var); - } else { - e = yCopy; - } - } else { - e = yCopy; - } - computeTable.insert(xCopy, yCopy, {e.p, e.w}); - e.w = cn.mulCached(x.w, y.w); - if (e.w.approximatelyZero()) { - cn.returnToCache(e.w); - return ResultEdge::zero; - } - return e; - } - - if constexpr (n == NEDGE) { - // additionally check if y is the identity in case of matrix - // multiplication - if (y.p->isIdentity()) { - e = xCopy; - computeTable.insert(xCopy, yCopy, {e.p, e.w}); - e.w = cn.mulCached(x.w, y.w); - - if (e.w.approximatelyZero()) { - cn.returnToCache(e.w); - return ResultEdge::zero; - } - return e; - } - } - } - } constexpr std::size_t rows = RADIX; constexpr std::size_t cols = n == NEDGE ? RADIX : 1U; @@ -1982,7 +1947,8 @@ template class Package { } } } - e = makeDDNode(var, edge, true, generateDensityMatrix); + + auto e = makeDDNode(var, edge, true, generateDensityMatrix); computeTable.insert(xCopy, yCopy, {e.p, e.w}); if (!e.w.exactlyZero()) { diff --git a/src/dd/Edge.cpp b/src/dd/Edge.cpp index 9ca45bdd6..5c21fdd0f 100644 --- a/src/dd/Edge.cpp +++ b/src/dd/Edge.cpp @@ -22,6 +22,13 @@ template bool Edge::isOneTerminal() const { return isTerminal() && w == Complex::one; } +template bool Edge::isIdentity() const { + if constexpr (std::is_same_v || std::is_same_v) { + return Node::isIdentity(p); + } + return false; +} + template CachedEdge::CachedEdge(Node* n, const Complex& c) : p(n) { w.r = RealNumber::val(c.r);