Skip to content

Commit

Permalink
🔀 Merge pull request #101 from DVLab-NTU/refactor/qcir_flexible_gate_…
Browse files Browse the repository at this point in the history
…type

Refactor/qcir flexible gate type
  • Loading branch information
JoshuaLau0220 authored Mar 17, 2024
2 parents 65f7329 + 7cb7468 commit 3641a47
Show file tree
Hide file tree
Showing 26 changed files with 313 additions and 798 deletions.
146 changes: 83 additions & 63 deletions src/convert/qcir_to_tableau.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,6 @@ namespace {

#include <algorithm>

Pauli to_pauli(qcir::GateRotationCategory const& category) {
switch (category) {
case qcir::GateRotationCategory::rz:
case qcir::GateRotationCategory::pz:
return Pauli::z;
case qcir::GateRotationCategory::rx:
case qcir::GateRotationCategory::px:
return Pauli::x;
case qcir::GateRotationCategory::ry:
case qcir::GateRotationCategory::py:
return Pauli::y;
default:
DVLAB_UNREACHABLE("Invalid rotation category");
}
}

bool is_r_type_rotation(qcir::GateRotationCategory const& category) {
return (
category == qcir::GateRotationCategory::rz ||
category == qcir::GateRotationCategory::rx ||
category == qcir::GateRotationCategory::ry);
}

template <std::ranges::range R, typename T>
bool contains(R const& r, T const& value) {
return std::ranges::find(r, value) != r.end();
Expand Down Expand Up @@ -104,14 +81,25 @@ std::vector<size_t> get_qubit_idx_vec(QubitIdList const& qubits) {
return ret;
}

void implement_mcrz(Tableau& tableau, QubitIdList const& qubits, dvlab::Phase const& phase) {
void implement_mcr(Tableau& tableau, QubitIdList const& qubits, dvlab::Phase const& ph, Pauli pauli) {
if (std::holds_alternative<StabilizerTableau>(tableau.back())) {
tableau.push_back(std::vector<PauliRotation>{});
}

dvlab::Phase const phase =
ph *
dvlab::Rational(1, static_cast<int>(std::pow(2, gsl::narrow<double>(qubits.size()) - 1)));

auto const targ = gsl::narrow<size_t>(qubits.back());
// convert rotation plane first
if (pauli == Pauli::x) {
tableau.h(targ);
} else if (pauli == Pauli::y) {
tableau.v(targ);
}
// guaranteed to be a vector of PauliRotation
auto& last_rotation_group = std::get<std::vector<PauliRotation>>(tableau.back());

auto const targ = gsl::narrow<size_t>(qubits.back());
for (auto const comb_size : std::views::iota(0ul, qubits.size())) {
bool const is_neg = comb_size % 2;
auto qubit_idx_vec = get_qubit_idx_vec(qubits);
Expand All @@ -126,12 +114,30 @@ void implement_mcrz(Tableau& tableau, QubitIdList const& qubits, dvlab::Phase co
last_rotation_group.push_back(PauliRotation(pauli_range.begin(), pauli_range.end(), is_neg ? -phase : phase));
} while (next_combination(qubit_idx_vec, comb_size));
}
// restore rotation plane
if (pauli == Pauli::x) {
tableau.h(targ);
} else if (pauli == Pauli::y) {
tableau.vdg(targ);
}
}

void implement_mcpz(Tableau& tableau, QubitIdList const& qubits, dvlab::Phase const& phase) {
void implement_mcp(Tableau& tableau, QubitIdList const& qubits, dvlab::Phase const& ph, Pauli pauli) {
if (std::holds_alternative<StabilizerTableau>(tableau.back())) {
tableau.push_back(std::vector<PauliRotation>{});
}

dvlab::Phase const phase =
ph *
dvlab::Rational(1, static_cast<int>(std::pow(2, gsl::narrow<double>(qubits.size()) - 1)));

auto const targ = gsl::narrow<size_t>(qubits.back());
// convert rotation plane first
if (pauli == Pauli::x) {
tableau.h(targ);
} else if (pauli == Pauli::y) {
tableau.v(targ);
}
// guaranteed to be a vector of PauliRotation
auto& last_rotation_group = std::get<std::vector<PauliRotation>>(tableau.back());

Expand All @@ -148,29 +154,6 @@ void implement_mcpz(Tableau& tableau, QubitIdList const& qubits, dvlab::Phase co
last_rotation_group.push_back(PauliRotation(pauli_range.begin(), pauli_range.end(), is_neg ? -phase : phase));
} while (next_combination(qubit_idx_vec, comb_size));
}
}

void implement_rotation_gate(Tableau& tableau, qcir::GateRotationCategory category, dvlab::Phase const& ph, QubitIdList const& qubits) {
auto const pauli = to_pauli(category);

auto const targ = gsl::narrow<size_t>(qubits.back());
// convert rotation plane first
if (pauli == Pauli::x) {
tableau.h(targ);
} else if (pauli == Pauli::y) {
tableau.v(targ);
}

dvlab::Phase const phase =
ph *
dvlab::Rational(1, static_cast<int>(std::pow(2, gsl::narrow<double>(qubits.size()) - 1)));
// implement rotation in Z plane
if (is_r_type_rotation(category)) {
implement_mcrz(tableau, qubits, phase);
} else {
implement_mcpz(tableau, qubits, phase);
}

// restore rotation plane
if (pauli == Pauli::x) {
tableau.h(targ);
Expand Down Expand Up @@ -215,7 +198,7 @@ bool append_to_tableau(qcir::PZGate const& op, experimental::Tableau& tableau, Q
} else if (op.get_phase() == dvlab::Phase(-1, 2)) {
tableau.sdg(qubits[0]);
} else {
experimental::implement_rotation_gate(tableau, qcir::GateRotationCategory::pz, op.get_phase(), qubits);
experimental::implement_mcp(tableau, qubits, op.get_phase(), experimental::Pauli::z);
}

return true;
Expand All @@ -230,7 +213,7 @@ bool append_to_tableau(qcir::PXGate const& op, experimental::Tableau& tableau, Q
} else if (op.get_phase() == dvlab::Phase(-1, 2)) {
tableau.vdg(qubits[0]);
} else {
experimental::implement_rotation_gate(tableau, qcir::GateRotationCategory::px, op.get_phase(), qubits);
experimental::implement_mcp(tableau, qubits, op.get_phase(), experimental::Pauli::x);
}

return true;
Expand All @@ -249,7 +232,7 @@ bool append_to_tableau(qcir::PYGate const& op, experimental::Tableau& tableau, Q
tableau.vdg(qubits[0]);
tableau.s(qubits[0]);
} else {
experimental::implement_rotation_gate(tableau, qcir::GateRotationCategory::py, op.get_phase(), qubits);
experimental::implement_mcp(tableau, qubits, op.get_phase(), experimental::Pauli::y);
}

return true;
Expand All @@ -264,7 +247,7 @@ bool append_to_tableau(qcir::RZGate const& op, experimental::Tableau& tableau, Q
} else if (op.get_phase() == dvlab::Phase(-1, 2)) {
tableau.sdg(qubits[0]);
} else {
experimental::implement_rotation_gate(tableau, qcir::GateRotationCategory::rz, op.get_phase(), qubits);
experimental::implement_mcr(tableau, qubits, op.get_phase(), experimental::Pauli::z);
}
return true;
}
Expand All @@ -278,7 +261,7 @@ bool append_to_tableau(qcir::RXGate const& op, experimental::Tableau& tableau, Q
} else if (op.get_phase() == dvlab::Phase(-1, 2)) {
tableau.vdg(qubits[0]);
} else {
experimental::implement_rotation_gate(tableau, qcir::GateRotationCategory::rx, op.get_phase(), qubits);
experimental::implement_mcr(tableau, qubits, op.get_phase(), experimental::Pauli::x);
}
return true;
}
Expand All @@ -296,21 +279,58 @@ bool append_to_tableau(qcir::RYGate const& op, experimental::Tableau& tableau, Q
tableau.vdg(qubits[0]);
tableau.s(qubits[0]);
} else {
experimental::implement_rotation_gate(tableau, qcir::GateRotationCategory::ry, op.get_phase(), qubits);
experimental::implement_mcr(tableau, qubits, op.get_phase(), experimental::Pauli::y);
}
return true;
}

template <>
bool append_to_tableau(qcir::LegacyGateType const& op, experimental::Tableau& tableau, QubitIdList const& qubits) {
if (op.get_type() == "cx") {
tableau.cx(qubits[0], qubits[1]);
} else if (op.get_type() == "cz") {
tableau.cz(qubits[0], qubits[1]);
} else {
experimental::implement_rotation_gate(tableau, op.get_rotation_category(), op.get_phase(), qubits);
bool append_to_tableau(qcir::ControlGate const& op, experimental::Tableau& tableau, QubitIdList const& qubits) {
if (auto target_op = op.get_target_operation().get_underlying_if<qcir::PXGate>()) {
if (op.get_num_qubits() == 2 && target_op->get_phase() == dvlab::Phase(1)) {
tableau.cx(qubits[0], qubits[1]);
} else {
experimental::implement_mcp(tableau, qubits, target_op->get_phase(), experimental::Pauli::x);
}
return true;
}
return true;

if (auto target_op = op.get_target_operation().get_underlying_if<qcir::PYGate>()) {
if (op.get_num_qubits() == 2 && target_op->get_phase() == dvlab::Phase(1)) {
tableau.sdg(qubits[1]);
tableau.cx(qubits[0], qubits[1]);
tableau.s(qubits[1]);
} else {
experimental::implement_mcp(tableau, qubits, target_op->get_phase(), experimental::Pauli::y);
}
return true;
}

if (auto target_op = op.get_target_operation().get_underlying_if<qcir::PZGate>()) {
if (op.get_num_qubits() == 2 && target_op->get_phase() == dvlab::Phase(1)) {
tableau.cz(qubits[0], qubits[1]);
} else {
experimental::implement_mcp(tableau, qubits, target_op->get_phase(), experimental::Pauli::z);
}
return true;
}

if (auto target_op = op.get_target_operation().get_underlying_if<qcir::RXGate>()) {
experimental::implement_mcr(tableau, qubits, target_op->get_phase(), experimental::Pauli::x);
return true;
}

if (auto target_op = op.get_target_operation().get_underlying_if<qcir::RYGate>()) {
experimental::implement_mcr(tableau, qubits, target_op->get_phase(), experimental::Pauli::y);
return true;
}

if (auto target_op = op.get_target_operation().get_underlying_if<qcir::RZGate>()) {
experimental::implement_mcr(tableau, qubits, target_op->get_phase(), experimental::Pauli::z);
return true;
}

return false;
}

namespace experimental {
Expand Down
22 changes: 5 additions & 17 deletions src/convert/qcir_to_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,11 @@ std::optional<QTensor<double>> to_tensor(RYGate const& op) {
}

template <>
std::optional<QTensor<double>> to_tensor(LegacyGateType const& op) {
switch (op.get_rotation_category()) {
case GateRotationCategory::pz:
return QTensor<double>::control(QTensor<double>::pzgate(op.get_phase()), op.get_num_qubits() - 1);
case GateRotationCategory::rz:
return QTensor<double>::control(QTensor<double>::rzgate(op.get_phase()), op.get_num_qubits() - 1);
case GateRotationCategory::px:
return QTensor<double>::control(QTensor<double>::pxgate(op.get_phase()), op.get_num_qubits() - 1);
case GateRotationCategory::rx:
return QTensor<double>::control(QTensor<double>::rxgate(op.get_phase()), op.get_num_qubits() - 1);
case GateRotationCategory::py:
return QTensor<double>::control(QTensor<double>::pygate(op.get_phase()), op.get_num_qubits() - 1);
case GateRotationCategory::ry:
return QTensor<double>::control(QTensor<double>::rygate(op.get_phase()), op.get_num_qubits() - 1);

default:
return std::nullopt;
std::optional<QTensor<double>> to_tensor(ControlGate const& op) {
if (auto target_tensor = to_tensor(op.get_target_operation())) {
return QTensor<double>::control(*target_tensor, op.get_num_qubits() - op.get_target_operation().get_num_qubits());
} else {
return std::nullopt;
}
}

Expand Down
62 changes: 35 additions & 27 deletions src/convert/qcir_to_zxgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace qsyn {

using zx::ZXVertex, zx::ZXGraph, zx::VertexType, zx::EdgeType;

using qcir::GateRotationCategory, qcir::QCir;
using qcir::QCir;

namespace {

Expand Down Expand Up @@ -375,33 +375,41 @@ std::optional<ZXGraph> to_zxgraph(qcir::RYGate const& op) {
}

template <>
std::optional<ZXGraph> to_zxgraph(qcir::LegacyGateType const& op) {
assert(op.get_num_qubits() != 1);
switch (op.get_rotation_category()) {
case GateRotationCategory::rz:
return create_mcr_zx_form(op.get_num_qubits(), op.get_phase(), RotationAxis::z);
case GateRotationCategory::rx:
return create_mcr_zx_form(op.get_num_qubits(), op.get_phase(), RotationAxis::x);
case GateRotationCategory::ry:
return create_mcr_zx_form(op.get_num_qubits(), op.get_phase(), RotationAxis::y);

case GateRotationCategory::pz:
if (op.get_num_qubits() == 2 && op.get_phase() == Phase(1)) {
return create_cz_zx_form();
} else {
return create_mcp_zx_form(op.get_num_qubits(), op.get_phase(), RotationAxis::z);
}
case GateRotationCategory::px:
if (op.get_num_qubits() == 2 && op.get_phase() == Phase(1)) {
return create_cx_zx_form();
} else {
return create_mcp_zx_form(op.get_num_qubits(), op.get_phase(), RotationAxis::x);
}
case GateRotationCategory::py:
return create_mcp_zx_form(op.get_num_qubits(), op.get_phase(), RotationAxis::y);
default:
return std::nullopt;
std::optional<ZXGraph> to_zxgraph(qcir::ControlGate const& op) {
auto const& target_op = op.get_target_operation();

if (auto const px = target_op.get_underlying_if<qcir::PXGate>()) {
if (op.get_num_qubits() == 2 && px->get_phase() == Phase(1)) {
return create_cx_zx_form();
}
return create_mcp_zx_form(op.get_num_qubits(), px->get_phase(), RotationAxis::x);
}

if (auto const py = target_op.get_underlying_if<qcir::PYGate>()) {
return create_mcp_zx_form(op.get_num_qubits(), py->get_phase(), RotationAxis::y);
}

if (auto const pz = target_op.get_underlying_if<qcir::PZGate>()) {
if (op.get_num_qubits() == 2 && pz->get_phase() == Phase(1)) {
return create_cz_zx_form();
}

return create_mcp_zx_form(op.get_num_qubits(), pz->get_phase(), RotationAxis::z);
}

if (auto const rx = target_op.get_underlying_if<qcir::RXGate>()) {
return create_mcr_zx_form(op.get_num_qubits(), rx->get_phase(), RotationAxis::x);
}

if (auto const ry = target_op.get_underlying_if<qcir::RYGate>()) {
return create_mcr_zx_form(op.get_num_qubits(), ry->get_phase(), RotationAxis::y);
}

if (auto const rz = target_op.get_underlying_if<qcir::RZGate>()) {
return create_mcr_zx_form(op.get_num_qubits(), rz->get_phase(), RotationAxis::z);
}

return std::nullopt;
}

std::optional<ZXGraph> to_zxgraph(qcir::QCirGate const& gate) {
Expand Down
4 changes: 2 additions & 2 deletions src/convert/tableau_to_qcir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void add_clifford_gate(qcir::QCir& qcir, CliffordOperator const& op) {
qcir.append(qcir::SGate(), {gsl::narrow<QubitIdType>(qubits[0])});
break;
case COT::cx:
qcir.add_gate("cx", {gsl::narrow<QubitIdType>(qubits[0]), gsl::narrow<QubitIdType>(qubits[1])}, {}, true);
qcir.append(qcir::CXGate(), {gsl::narrow<QubitIdType>(qubits[0]), gsl::narrow<QubitIdType>(qubits[1])});
break;
case COT::sdg:
qcir.append(qcir::SdgGate(), {gsl::narrow<QubitIdType>(qubits[0])});
Expand All @@ -52,7 +52,7 @@ void add_clifford_gate(qcir::QCir& qcir, CliffordOperator const& op) {
qcir.append(qcir::ZGate(), {gsl::narrow<QubitIdType>(qubits[0])});
break;
case COT::cz:
qcir.add_gate("cz", {gsl::narrow<QubitIdType>(qubits[0]), gsl::narrow<QubitIdType>(qubits[1])}, {}, true);
qcir.append(qcir::CZGate(), {gsl::narrow<QubitIdType>(qubits[0]), gsl::narrow<QubitIdType>(qubits[1])});
break;
case COT::swap:
qcir.append(qcir::SwapGate(), {gsl::narrow<QubitIdType>(qubits[0]), gsl::narrow<QubitIdType>(qubits[1])});
Expand Down
12 changes: 5 additions & 7 deletions src/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,13 +515,11 @@ bool Device::_parse_gate_set(std::string const& gate_set_str) {
_topology->add_gate_type(op->get_repr().substr(0, op->get_repr().find_first_of('(')));
return std::make_optional(op->get_type());
}
auto gate_type = str_to_gate_type(str);
if (!gate_type.has_value()) {
spdlog::error("unsupported gate type \"{}\"!!", str);
return std::nullopt;
};
_topology->add_gate_type(gate_type_to_str(*gate_type));
return std::make_optional(gate_type_to_str(*gate_type));
if (auto op = qcir::str_to_operation(str, {dvlab::Phase()}); op.has_value()) {
_topology->add_gate_type(op->get_repr().substr(0, op->get_repr().find_first_of('(')));
return std::make_optional(op->get_type());
}
return std::nullopt;
});

return std::ranges::all_of(gate_set_view, [](auto const& gate_type) { return gate_type.has_value(); });
Expand Down
8 changes: 4 additions & 4 deletions src/duostra/duostra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ void Duostra::build_circuit_by_result() {
if (qubits[1] != max_qubit_id) {
qu.emplace_back(qubits[1]);
}
if (operation.get_operation() == SwapGate{}) {
if (operation.get_operation().is<SwapGate>()) {
// NOTE - Decompose SWAP into three CX
QubitIdList qu_reverse;
qu_reverse.emplace_back(qubits[1]);
qu_reverse.emplace_back(qubits[0]);
_physical_circuit->append(LegacyGateType(std::make_tuple(GateRotationCategory::px, 2, Phase(1))), qu);
_physical_circuit->append(LegacyGateType(std::make_tuple(GateRotationCategory::px, 2, Phase(1))), qu_reverse);
_physical_circuit->append(LegacyGateType(std::make_tuple(GateRotationCategory::px, 2, Phase(1))), qu);
_physical_circuit->append(CXGate(), qu);
_physical_circuit->append(CXGate(), qu_reverse);
_physical_circuit->append(CXGate(), qu);
} else {
_physical_circuit->append(operation.get_operation(), qu);
}
Expand Down
Loading

0 comments on commit 3641a47

Please sign in to comment.