Skip to content

Commit

Permalink
🚸 Support for adding tests to NALAC (cda-tum#629)
Browse files Browse the repository at this point in the history
## Description

This PR contains modifications that were necessary for the tests of the
Neutral Atom Logical Array Compiler, see cda-tum/mqt-qmap#470.

## Checklist:

- [x] The pull request only contains commits that are related to it.
- [x] I have added appropriate tests and documentation.
- [x] I have made sure that all CI jobs on GitHub pass.
- [x] The pull request introduces no new warnings and follows the
project's style guidelines.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: burgholzer <[email protected]>
  • Loading branch information
3 people authored Jun 7, 2024
1 parent 7eb0ab7 commit 0e4ff9e
Show file tree
Hide file tree
Showing 16 changed files with 113 additions and 72 deletions.
7 changes: 7 additions & 0 deletions include/mqt-core/Permutation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ class Permutation : public std::map<Qubit, Qubit> {
}
return t;
}

[[nodiscard]] auto apply(const Qubit qubit) const -> Qubit {
if (empty()) {
return qubit;
}
return at(qubit);
}
};
} // namespace qc

Expand Down
1 change: 0 additions & 1 deletion include/mqt-core/na/NAComputation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class NAComputation {
}
auto clear(const bool clearInitialPositions = true) -> void {
operations.clear();
initialPositions.clear();
if (clearInitialPositions) {
initialPositions.clear();
}
Expand Down
9 changes: 9 additions & 0 deletions include/mqt-core/na/NADefinitions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,12 @@ template <> struct std::hash<na::FullOpType> {
return qc::combineHash(h1, h2);
}
};

/// Hash function for Point, e.g., for use in unordered_map
template <> struct std::hash<na::Point> {
std::size_t operator()(const na::Point& p) const noexcept {
const std::size_t h1 = std::hash<decltype(p.x)>{}(p.x);
const std::size_t h2 = std::hash<decltype(p.y)>{}(p.y);
return qc::combineHash(h1, h2);
}
};
1 change: 1 addition & 0 deletions include/mqt-core/na/operations/NAGlobalOperation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class NAGlobalOperation : public NAOperation {
[[nodiscard]] auto getParams() const -> const std::vector<qc::fp>& {
return params;
}
[[nodiscard]] auto getType() const -> FullOpType { return type; }
[[nodiscard]] auto isGlobalOperation() const -> bool override { return true; }
[[nodiscard]] auto toString() const -> std::string override;
[[nodiscard]] auto clone() const -> std::unique_ptr<NAOperation> override {
Expand Down
3 changes: 2 additions & 1 deletion include/mqt-core/operations/CompoundOperation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class CompoundOperation final : public Operation {

std::vector<std::unique_ptr<Operation>>& getOps() { return ops; }

[[nodiscard]] std::set<Qubit> getUsedQubits() const override;
[[nodiscard]] auto getUsedQubitsPermuted(const Permutation& perm) const
-> std::set<Qubit> override;

[[nodiscard]] auto commutesAtQubit(const Operation& other,
const Qubit& qubit) const -> bool override;
Expand Down
13 changes: 0 additions & 13 deletions include/mqt-core/operations/NonUnitaryOperation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,6 @@ class NonUnitaryOperation final : public Operation {
std::vector<Bit>& getClassics() { return classics; }
[[nodiscard]] std::size_t getNclassics() const { return classics.size(); }

[[nodiscard]] std::set<Qubit> getUsedQubits() const override {
const auto& opTargets = getTargets();
return {opTargets.begin(), opTargets.end()};
}

[[nodiscard]] const Controls& getControls() const override {
throw QFRException("Cannot get controls from non-unitary operation.");
}

[[nodiscard]] Controls& getControls() override {
throw QFRException("Cannot get controls from non-unitary operation.");
}

void addDepthContribution(std::vector<std::size_t>& depths) const override;

void addControl(const Control /*c*/) override {
Expand Down
13 changes: 4 additions & 9 deletions include/mqt-core/operations/Operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,10 @@ class Operation {
[[nodiscard]] const std::string& getName() const { return name; }
[[nodiscard]] virtual OpType getType() const { return type; }

[[nodiscard]] virtual std::set<Qubit> getUsedQubits() const {
const auto& opTargets = getTargets();
const auto& opControls = getControls();
std::set<Qubit> usedQubits = {opTargets.begin(), opTargets.end()};
for (const auto& control : opControls) {
usedQubits.insert(control.qubit);
}
return usedQubits;
}
[[nodiscard]] virtual auto
getUsedQubitsPermuted(const Permutation& perm) const -> std::set<Qubit>;

[[nodiscard]] auto getUsedQubits() const -> std::set<Qubit>;

[[nodiscard]] std::unique_ptr<Operation> getInverted() const {
auto op = clone();
Expand Down
11 changes: 7 additions & 4 deletions src/na/NAComputation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ auto NAComputation::toString() const -> std::string {
std::stringstream ss;
ss << "init at ";
for (const auto& p : initialPositions) {
ss << "(" << p->x << ", " << p->y << ")"
<< ", ";
ss << *p << ", ";
}
if (ss.tellp() == 8) {
ss.seekp(-1, std::ios_base::end);
} else {
ss.seekp(-2, std::ios_base::end);
}
ss.seekp(-2, std::ios_base::end);
ss << ";\n";
for (const auto& op : operations) {
ss << op->toString();
ss << *op;
}
return ss.str();
}
Expand Down
10 changes: 7 additions & 3 deletions src/na/operations/NALocalOperation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ auto NALocalOperation::toString() const -> std::string {
ss << ")";
}
ss << " at ";
for (const auto& p : positions) {
ss << *p << ", ";
if (positions.empty()) {
ss.seekp(-1, std::ios_base::end);
} else {
for (const auto& p : positions) {
ss << *p << ", ";
}
ss.seekp(-2, std::ios_base::end);
}
ss.seekp(-2, std::ios_base::end);
ss << ";\n";
return ss.str();
}
Expand Down
5 changes: 3 additions & 2 deletions src/operations/CompoundOperation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,11 @@ void CompoundOperation::dumpOpenQASM(std::ostream& of,
}
}

std::set<Qubit> CompoundOperation::getUsedQubits() const {
auto CompoundOperation::getUsedQubitsPermuted(const Permutation& perm) const
-> std::set<Qubit> {
std::set<Qubit> usedQubits{};
for (const auto& op : ops) {
usedQubits.merge(op->getUsedQubits());
usedQubits.merge(op->getUsedQubitsPermuted(perm));
}
return usedQubits;
}
Expand Down
73 changes: 36 additions & 37 deletions src/operations/Operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,51 +121,35 @@ bool Operation::equals(const Operation& op, const Permutation& perm1,
return false;
}

// check controls
if (nc1 != 0U) {
Controls controls1{};
if (perm1.empty()) {
controls1 = getControls();
} else {
for (const auto& control : getControls()) {
controls1.emplace(perm1.at(control.qubit), control.type);
}
if (isDiagonalGate()) {
// check pos. controls and targets together
const auto& usedQubits1 = getUsedQubitsPermuted(perm1);
const auto& usedQubits2 = op.getUsedQubitsPermuted(perm2);
if (usedQubits1 != usedQubits2) {
return false;
}

Controls controls2{};
if (perm2.empty()) {
controls2 = op.getControls();
} else {
for (const auto& control : op.getControls()) {
controls2.emplace(perm2.at(control.qubit), control.type);
std::set<Qubit> negControls1{};
for (const auto& control : getControls()) {
if (control.type == Control::Type::Neg) {
negControls1.emplace(perm1.apply(control.qubit));
}
}

if (controls1 != controls2) {
return false;
}
}

// check targets
std::set<Qubit> targets1{};
if (perm1.empty()) {
targets1 = {getTargets().begin(), getTargets().end()};
} else {
for (const auto& target : getTargets()) {
targets1.emplace(perm1.at(target));
std::set<Qubit> negControls2{};
for (const auto& control : op.getControls()) {
if (control.type == Control::Type::Neg) {
negControls2.emplace(perm2.apply(control.qubit));
}
}
return negControls1 == negControls2;
}

std::set<Qubit> targets2{};
if (perm2.empty()) {
targets2 = {op.getTargets().begin(), op.getTargets().end()};
} else {
for (const auto& target : op.getTargets()) {
targets2.emplace(perm2.at(target));
}
// check controls
if (nc1 != 0U &&
perm1.apply(getControls()) != perm2.apply(op.getControls())) {
return false;
}

return targets1 == targets2;
return perm1.apply(getTargets()) == perm2.apply(op.getTargets());
}

void Operation::addDepthContribution(std::vector<std::size_t>& depths) const {
Expand Down Expand Up @@ -198,4 +182,19 @@ auto Operation::isInverseOf(const Operation& other) const -> bool {
return operator==(*other.getInverted());
}

auto Operation::getUsedQubitsPermuted(const qc::Permutation& perm) const
-> std::set<Qubit> {
std::set<Qubit> usedQubits;
for (const auto& target : getTargets()) {
usedQubits.emplace(perm.apply(target));
}
for (const auto& control : getControls()) {
usedQubits.emplace(perm.apply(control.qubit));
}
return usedQubits;
}

auto Operation::getUsedQubits() const -> std::set<Qubit> {
return getUsedQubitsPermuted({});
}
} // namespace qc
3 changes: 2 additions & 1 deletion test/datastructures/test_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <algorithm>
#include <gtest/gtest.h>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <vector>

Expand Down Expand Up @@ -39,7 +40,7 @@ TEST(Layer, ExecutableSet1) {
EXPECT_EQ(layer.getExecutableSet().size(), 1); // layer (1)
std::shared_ptr<Layer::DAGVertex> v = *(layer.getExecutableSet()).begin();
v->execute();
EXPECT_ANY_THROW(v->execute());
EXPECT_THROW(v->execute(), std::logic_error);
EXPECT_EQ(layer.getExecutableSet().size(), 3); // layer (2)
v = *(layer.getExecutableSet()).begin();
v->execute();
Expand Down
7 changes: 7 additions & 0 deletions test/na/test_nacomputation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,11 @@ TEST(NAComputation, General) {
"move (0, 1), (1, 1) to (4, 1), (5, 1);\n"
"store (4, 1), (5, 1) to (4, 0), (5, 0);\n");
}

TEST(NAComputation, EmptyPrint) {
const NAComputation qc;
std::stringstream ss;
ss << qc;
EXPECT_EQ(ss.str(), "init at;\n");
}
} // namespace na
9 changes: 9 additions & 0 deletions test/na/test_naoperation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <gtest/gtest.h>
#include <memory>
#include <sstream>
#include <vector>

namespace na {
Expand Down Expand Up @@ -49,4 +50,12 @@ TEST(NAOperation, LocalOperation) {
EXPECT_ANY_THROW(
NALocalOperation(FullOpType{qc::RY, 1}, std::make_shared<Point>(0, 0)));
}

TEST(NAOperation, EmptyPrint) {
const NALocalOperation op(FullOpType{qc::RY, 0}, std::vector{qc::PI_2},
std::vector<std::shared_ptr<Point>>{});
std::stringstream ss;
ss << op;
EXPECT_EQ(ss.str(), "ry(1.5708) at;\n");
}
} // namespace na
19 changes: 19 additions & 0 deletions test/test_operation.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "Definitions.hpp"
#include "Permutation.hpp"
#include "operations/AodOperation.hpp"
#include "operations/CompoundOperation.hpp"
#include "operations/Expression.hpp"
Expand Down Expand Up @@ -139,6 +140,24 @@ TEST(Operation, IsDiagonalGate) {
EXPECT_TRUE(op2.isDiagonalGate());
}

TEST(Operation, Equality) {
const qc::StandardOperation op1(0, qc::Z);
const qc::StandardOperation op2(1, 0, qc::Z);
const qc::StandardOperation op3(0, 1, qc::Z);
const qc::StandardOperation op4({0, qc::Control::Type::Neg}, 1, qc::Z);
EXPECT_FALSE(op1 == op2);
EXPECT_TRUE(op2 == op3);
EXPECT_TRUE(op3 == op2);
EXPECT_FALSE(op2 == op4);

EXPECT_TRUE(op2.equals(op3, qc::Permutation{{{0, 0}, {1, 2}}},
qc::Permutation{{{0, 2}, {1, 0}}}));
EXPECT_FALSE(
op2.equals(op3, qc::Permutation{{{0, 0}, {1, 2}}}, qc::Permutation{}));
EXPECT_FALSE(op2.equals(op4, qc::Permutation{{{0, 0}, {1, 2}}},
qc::Permutation{{{0, 2}, {1, 0}}}));
}

TEST(StandardOperation, Move) {
const qc::StandardOperation moveOp({0, 1}, qc::OpType::Move);
EXPECT_EQ(moveOp.getTargets().size(), 2);
Expand Down
1 change: 0 additions & 1 deletion test/unittests/test_qfr_functionality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1706,7 +1706,6 @@ TEST_F(QFRFunctionality, addControlClassicControlledOperation) {
TEST_F(QFRFunctionality, addControlNonUnitaryOperation) {
auto op = NonUnitaryOperation(0U, Measure);

EXPECT_THROW(static_cast<void>(op.getControls()), QFRException);
EXPECT_THROW(op.addControl(1), QFRException);
EXPECT_THROW(op.removeControl(1), QFRException);
EXPECT_THROW(op.clearControls(), QFRException);
Expand Down

0 comments on commit 0e4ff9e

Please sign in to comment.