Skip to content

Commit

Permalink
🔀 Merge pull request #52 from DVLab-NTU/qcir_adjoint
Browse files Browse the repository at this point in the history
Qcir adjoint
  • Loading branch information
JoshuaLau0220 authored Jan 28, 2024
2 parents ecb29e1 + 5ada35d commit d3865dc
Show file tree
Hide file tree
Showing 19 changed files with 76 additions and 50 deletions.
4 changes: 2 additions & 2 deletions src/argparse/arg_def.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ std::string type_string(DummyArgType const& /*unused*/) { return "dummy"; }
template <>
bool parse_from_string(bool& val, std::string_view token) {
using namespace std::string_view_literals;
if (dvlab::str::is_prefix_of(dvlab::str::tolower_string(token), "true")) {
if (dvlab::str::is_prefix_of(dvlab::str::tolower_string(token), "true") || token == "1"sv) {
val = true;
return true;
} else if (dvlab::str::is_prefix_of(dvlab::str::tolower_string(token), "false")) {
} else if (dvlab::str::is_prefix_of(dvlab::str::tolower_string(token), "false") || token == "0"sv) {
val = false;
return true;
}
Expand Down
9 changes: 4 additions & 5 deletions src/argparse/arg_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

#include "./arg_type.hpp"
#include "argparse/arg_def.hpp"
#include "util/ordered_hashset.hpp"

namespace dvlab::argparse {

Expand All @@ -29,7 +28,7 @@ class MutuallyExclusiveGroup {
MutExGroupImpl(ArgumentParser& parser)
: _parser{&parser} {}
ArgumentParser* _parser;
dvlab::utils::ordered_hashset<std::string, detail::heterogeneous_string_hash, std::equal_to<>> _arguments;
std::vector<std::string> _arg_names;
bool _required = false;
bool _parsed = false;
};
Expand All @@ -42,7 +41,7 @@ class MutuallyExclusiveGroup {
requires valid_argument_type<T>
ArgType<T>& add_argument(std::string_view name, std::convertible_to<std::string> auto... alias);

bool contains(std::string_view name) const { return _pimpl->_arguments.contains(name); }
bool contains(std::string_view name) const { return std::ranges::find(_pimpl->_arg_names, name) != _pimpl->_arg_names.end(); }
MutuallyExclusiveGroup required(bool is_req) {
_pimpl->_required = is_req;
return *this;
Expand All @@ -52,9 +51,9 @@ class MutuallyExclusiveGroup {
bool is_required() const { return _pimpl->_required; }
bool is_parsed() const { return _pimpl->_parsed; }

size_t size() const noexcept { return _pimpl->_arguments.size(); }
size_t size() const noexcept { return _pimpl->_arg_names.size(); }

auto const& get_arg_names() const { return _pimpl->_arguments; }
auto const& get_arg_names() const { return _pimpl->_arg_names; }

private:
std::shared_ptr<MutExGroupImpl> _pimpl;
Expand Down
2 changes: 1 addition & 1 deletion src/argparse/arg_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ template <typename T>
requires valid_argument_type<T>
ArgType<T>& MutuallyExclusiveGroup::add_argument(std::string_view name, std::convertible_to<std::string> auto... alias) {
ArgType<T>& return_ref = _pimpl->_parser->add_argument<T>(name, alias...);
_pimpl->_arguments.insert(return_ref._name);
_pimpl->_arg_names.emplace_back(name);
return return_ref;
}

Expand Down
2 changes: 0 additions & 2 deletions src/argparse/arg_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@

#include <filesystem>

#include "util/ordered_hashset.hpp"
#include "util/trie.hpp"

namespace dvlab::argparse {

static_assert(is_container_type<std::vector<int>> == true);
static_assert(is_container_type<std::vector<std::string>> == true);
static_assert(is_container_type<dvlab::utils::ordered_hashset<float>> == true);
static_assert(is_container_type<std::string> == false);
static_assert(is_container_type<std::array<int, 3>> == false);

Expand Down
6 changes: 1 addition & 5 deletions src/device/device_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ using dvlab::CmdExecResult;

namespace qsyn::device {

bool device_mgr_not_empty(DeviceMgr const& device_mgr) {
return dvlab::utils::expect(!device_mgr.empty(), "Device list is empty now. Please DTRead first.");
}

std::function<bool(size_t const&)> valid_device_id(qsyn::device::DeviceMgr const& device_mgr) {
return [&device_mgr](size_t const& id) {
if (device_mgr.is_id(id)) return true;
Expand Down Expand Up @@ -149,7 +145,7 @@ dvlab::Command device_print_cmd(qsyn::device::DeviceMgr& device_mgr) {
"print routing paths between q1 and q2");
},
[&device_mgr](ArgumentParser const& parser) {
if (!qsyn::device::device_mgr_not_empty(device_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(device_mgr)) return CmdExecResult::error;

if (parser.parsed("--edges")) {
device_mgr.get()->print_edges(parser.get<std::vector<size_t>>("--edges"));
Expand Down
2 changes: 0 additions & 2 deletions src/device/device_mgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ namespace qsyn::device {

using DeviceMgr = dvlab::utils::DataStructureManager<Device>;

bool device_mgr_not_empty(DeviceMgr const& device_mgr);

} // namespace qsyn::device

template <>
Expand Down
3 changes: 2 additions & 1 deletion src/duostra/duostra_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "duostra/duostra_def.hpp"
#include "qcir/qcir.hpp"
#include "qcir/qcir_cmd.hpp"
#include "util/data_structure_manager_common_cmd.hpp"
#include "util/text_format.hpp"

using namespace dvlab::argparse;
Expand Down Expand Up @@ -221,7 +222,7 @@ Command duostra_cmd(qcir::QCirMgr& qcir_mgr, device::DeviceMgr& device_mgr) {
},

[&](ArgumentParser const& parser) {
if (!qcir_mgr_not_empty(qcir_mgr) || !device_mgr_not_empty(device_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(qcir_mgr) || !dvlab::utils::mgr_has_data(device_mgr)) return CmdExecResult::error;
#ifdef __GNUC__
char const* const omp_wait_policy = std::getenv("OMP_WAIT_POLICY");

Expand Down
3 changes: 1 addition & 2 deletions src/qcir/optimizer/basic_optimization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ bool Optimizer::parse_gate(QCirGate* gate, bool do_swap, bool minimize_czs) {
* @param circuit
*/
QCir Optimizer::_build_from_storage(size_t n_qubits, bool reversed) {
QCir circuit;
circuit.add_qubits(n_qubits);
QCir circuit{n_qubits};

while (any_of(_gates.begin(), _gates.end(), [](auto& p_g) { return p_g.second.size(); })) {
dvlab::utils::ordered_hashset<size_t> available_id;
Expand Down
8 changes: 4 additions & 4 deletions src/qcir/optimizer/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ class Optimizer {

// basic optimization
struct BasicOptimizationConfig {
bool doSwap;
bool separateCorrection;
size_t maxIter;
bool printStatistics;
bool doSwap = true;
bool separateCorrection = false;
size_t maxIter = 1000;
bool printStatistics = false;
};
std::optional<QCir> basic_optimization(QCir const& qcir, BasicOptimizationConfig const& config);
QCir parse_forward(QCir const& qcir, bool do_minimize_czs, BasicOptimizationConfig const& config);
Expand Down
3 changes: 2 additions & 1 deletion src/qcir/optimizer/optimizer_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "../qcir_mgr.hpp"
#include "./optimizer.hpp"
#include "cli/cli.hpp"
#include "util/data_structure_manager_common_cmd.hpp"
#include "util/util.hpp"

using namespace dvlab::argparse;
Expand Down Expand Up @@ -51,7 +52,7 @@ Command qcir_optimize_cmd(QCirMgr& qcir_mgr) {
.help("Only perform trivial optimizations.");
},
[&](ArgumentParser const& parser) {
if (!qcir_mgr_not_empty(qcir_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return CmdExecResult::error;
Optimizer optimizer;
std::optional<QCir> result;
std::string procedure_str{};
Expand Down
3 changes: 1 addition & 2 deletions src/qcir/optimizer/trivial_optimization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ std::optional<QCir> Optimizer::trivial_optimization(QCir const& qcir) {
spdlog::info("Start trivial optimization");

reset(qcir);
QCir result;
QCir result{qcir.get_num_qubits()};
result.set_filename(qcir.get_filename());
result.add_procedures(qcir.get_procedures());
result.add_qubits(qcir.get_num_qubits());

auto const gate_list = qcir.get_topologically_ordered_gates();
for (auto gate : gate_list) {
Expand Down
5 changes: 5 additions & 0 deletions src/qcir/qcir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class QCir { // NOLINT(hicpp-special-member-functions, cppcoreguidelines-specia
public:
using QubitIdType = qsyn::QubitIdType;
QCir() {}
QCir(size_t n_qubits) {
add_qubits(n_qubits);
}
~QCir() = default;
QCir(QCir const& other);
QCir(QCir&& other) noexcept = default;
Expand Down Expand Up @@ -158,6 +161,8 @@ class QCir { // NOLINT(hicpp-special-member-functions, cppcoreguidelines-specia
void update_gate_time() const;
void print_zx_form_topological_order();

void adjoint();

// DFS functions
template <typename F>
void topological_traverse(F lambda) const {
Expand Down
20 changes: 20 additions & 0 deletions src/qcir/qcir_action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,24 @@ void QCir::reset() {
_global_dfs_counter = 1;
}

void QCir::adjoint() {
for (auto& g : _qgates) {
g->adjoint();
auto qubits = g->get_qubits();
for (auto& q : qubits) {
std::swap(q._prev, q._next);
}
g->set_qubits(qubits);
}

for (auto& q : _qubits) {
auto first = q->get_first();
auto last = q->get_last();
q->set_first(last);
q->set_last(first);
}

_dirty = true;
}

} // namespace qsyn::qcir
42 changes: 23 additions & 19 deletions src/qcir/qcir_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,6 @@ using dvlab::Command;

namespace qsyn::qcir {

bool qcir_mgr_not_empty(QCirMgr const& qcir_mgr) {
if (qcir_mgr.empty()) {
spdlog::error("QCir list is empty. Please create a QCir first!!");
spdlog::info("Use QCNew/QCBAdd to add a new QCir, or QCCRead to read a QCir from a file.");
return false;
}
return true;
}

std::function<bool(size_t const&)> valid_qcir_id(QCirMgr const& qcir_mgr) {
return [&](size_t const& id) {
if (qcir_mgr.is_id(id)) return true;
Expand All @@ -53,7 +44,7 @@ std::function<bool(size_t const&)> valid_qcir_id(QCirMgr const& qcir_mgr) {

std::function<bool(size_t const&)> valid_qcir_gate_id(QCirMgr const& qcir_mgr) {
return [&](size_t const& id) {
if (!qcir_mgr_not_empty(qcir_mgr)) return false;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return false;
if (qcir_mgr.get()->get_gate(id) != nullptr) return true;
spdlog::error("Gate ID {} does not exist!!", id);
return false;
Expand All @@ -62,7 +53,7 @@ std::function<bool(size_t const&)> valid_qcir_gate_id(QCirMgr const& qcir_mgr) {

std::function<bool(QubitIdType const&)> valid_qcir_qubit_id(QCirMgr const& qcir_mgr) {
return [&](QubitIdType const& id) {
if (!qcir_mgr_not_empty(qcir_mgr)) return false;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return false;
if (qcir_mgr.get()->get_qubit(id) != nullptr) return true;
spdlog::error("Qubit ID {} does not exist!!", id);
return false;
Expand All @@ -79,7 +70,7 @@ dvlab::Command qcir_compose_cmd(QCirMgr& qcir_mgr) {
.help("the ID of the circuit to compose with");
},
[&](ArgumentParser const& parser) {
if (!qcir_mgr_not_empty(qcir_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return CmdExecResult::error;
qcir_mgr.get()->compose(*qcir_mgr.find_by_id(parser.get<size_t>("id")));
return CmdExecResult::done;
}};
Expand All @@ -95,7 +86,7 @@ dvlab::Command qcir_tensor_product_cmd(QCirMgr& qcir_mgr) {
.help("the ID of the circuit to tensor with");
},
[&](ArgumentParser const& parser) {
if (!qcir_mgr_not_empty(qcir_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return CmdExecResult::error;
qcir_mgr.get()->tensor_product(*qcir_mgr.find_by_id(parser.get<size_t>("id")));
return CmdExecResult::done;
}};
Expand Down Expand Up @@ -214,7 +205,7 @@ dvlab::Command qcir_write_cmd(QCirMgr const& qcir_mgr) {
.help("the output format of the QCir. If not specified, the default format is automatically chosen based on the output file extension");
},
[&](ArgumentParser const& parser) {
if (!qcir_mgr_not_empty(qcir_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return CmdExecResult::error;

enum class OutputFormat { qasm,
latex_qcircuit };
Expand Down Expand Up @@ -273,7 +264,7 @@ Command qcir_draw_cmd(QCirMgr const& qcir_mgr) {
},
[&](ArgumentParser const& parser) {
namespace fs = std::filesystem;
if (!qcir_mgr_not_empty(qcir_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return CmdExecResult::error;

auto output_path = fs::path{parser.get<std::string>("output-path")};
auto scale = parser.get<float>("--scale");
Expand Down Expand Up @@ -321,7 +312,7 @@ dvlab::Command qcir_print_cmd(QCirMgr const& qcir_mgr) {
.help("print the circuit diagram. If `--verbose` is also specified, print the circuit diagram in the qiskit style");
},
[&](ArgumentParser const& parser) {
if (!qcir_mgr_not_empty(qcir_mgr)) {
if (!dvlab::utils::mgr_has_data(qcir_mgr)) {
return CmdExecResult::error;
}

Expand Down Expand Up @@ -433,7 +424,7 @@ dvlab::Command qcir_gate_add_cmd(QCirMgr& qcir_mgr) {
.help("the qubits on which the gate applies");
},
[=, &qcir_mgr](ArgumentParser const& parser) {
if (!qcir_mgr_not_empty(qcir_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return CmdExecResult::error;
bool const do_prepend = parser.parsed("--prepend");

auto type = parser.get<std::string>("type");
Expand Down Expand Up @@ -508,7 +499,7 @@ dvlab::Command qcir_gate_delete_cmd(QCirMgr& qcir_mgr) {
.help("the id to be removed");
},
[&](ArgumentParser const& parser) {
if (!qcir_mgr_not_empty(qcir_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return CmdExecResult::error;
qcir_mgr.get()->remove_gate(parser.get<size_t>("id"));
return CmdExecResult::done;
}};
Expand Down Expand Up @@ -562,7 +553,7 @@ dvlab::Command qcir_qubit_delete_cmd(QCirMgr& qcir_mgr) {
.help("the ID of the qubit to be removed");
},
[&](ArgumentParser const& parser) {
if (!qcir_mgr_not_empty(qcir_mgr)) return CmdExecResult::error;
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return CmdExecResult::error;
if (!qcir_mgr.get()->remove_qubit(parser.get<QubitIdType>("id")))
return CmdExecResult::error;
else
Expand All @@ -588,6 +579,18 @@ dvlab::Command qcir_qubit_cmd(QCirMgr& qcir_mgr) {
return cmd;
}

dvlab::Command qcir_adjoint_cmd(QCirMgr& qcir_mgr) {
return {"adjoint",
[](ArgumentParser& parser) {
parser.description("transform the QCir to its adjoint, i.e., reverse the order of gates and replace each gate with its adjoint version");
},
[&](ArgumentParser const& /*parser*/) {
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return CmdExecResult::error;
qcir_mgr.get()->adjoint();
return CmdExecResult::done;
}};
}

Command qcir_cmd(QCirMgr& qcir_mgr) {
auto cmd = dvlab::utils::mgr_root_cmd(qcir_mgr);

Expand All @@ -603,6 +606,7 @@ Command qcir_cmd(QCirMgr& qcir_mgr) {
cmd.add_subcommand(qcir_write_cmd(qcir_mgr));
cmd.add_subcommand(qcir_print_cmd(qcir_mgr));
cmd.add_subcommand(qcir_draw_cmd(qcir_mgr));
cmd.add_subcommand(qcir_adjoint_cmd(qcir_mgr));
cmd.add_subcommand(qcir_gate_cmd(qcir_mgr));
cmd.add_subcommand(qcir_qubit_cmd(qcir_mgr));
cmd.add_subcommand(qcir_optimize_cmd(qcir_mgr));
Expand Down
6 changes: 6 additions & 0 deletions src/qcir/qcir_gate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,10 @@ void QCirGate::_print_single_qubit_or_controlled_gate(std::string gtype, bool sh
fmt::println("Execute at t= {}", get_time());
}

void QCirGate::adjoint() {
if (!is_fixed_phase_gate(_rotation_category)) {
_phase = -_phase;
}
}

} // namespace qsyn::qcir
2 changes: 2 additions & 0 deletions src/qcir/qcir_gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class QCirGate {
bool is_cz() const { return _rotation_category == GateRotationCategory::pz && _phase == dvlab::Phase(1) && _qubits.size() == 2; }
bool is_swap() const { return _rotation_category == GateRotationCategory::swap; }

void adjoint();

private:
protected:
size_t _id;
Expand Down
2 changes: 0 additions & 2 deletions src/qcir/qcir_mgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ namespace qsyn::qcir {

using QCirMgr = dvlab::utils::DataStructureManager<QCir>;

bool qcir_mgr_not_empty(QCirMgr const& qcir_mgr);

} // namespace qsyn::qcir

template <>
Expand Down
2 changes: 1 addition & 1 deletion src/tensor/tensor_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Command tensor_print_cmd(TensorMgr& tensor_mgr) {
Command tensor_adjoint_cmd(TensorMgr& tensor_mgr) {
return {"adjoint",
[&](ArgumentParser& parser) {
parser.description("adjoint the specified tensor");
parser.description("transform the tensor to its adjoint");

parser.add_argument<size_t>("id")
.constraint(valid_tensor_id(tensor_mgr))
Expand Down
2 changes: 1 addition & 1 deletion src/zx/zx_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ Command zxgraph_assign_boundary_cmd(ZXGraphMgr& zxgraph_mgr) {
Command zxgraph_adjoint_cmd(ZXGraphMgr& zxgraph_mgr) {
return {"adjoint",
[](ArgumentParser& parser) {
parser.description("adjoint ZXGraph");
parser.description("transform the ZXGraph to its adjoint; i.e., swap the input/output vertices and replace each vertex with its adjoint");
},
[&](ArgumentParser const& /*parser*/) {
if (!dvlab::utils::mgr_has_data(zxgraph_mgr)) return CmdExecResult::error;
Expand Down

0 comments on commit d3865dc

Please sign in to comment.