diff --git a/pytket/binders/passes.cpp b/pytket/binders/passes.cpp index d47181072f..6768b6d3a9 100644 --- a/pytket/binders/passes.cpp +++ b/pytket/binders/passes.cpp @@ -265,23 +265,35 @@ PYBIND11_MODULE(passes, m) { .def( "to_dict", [](const BasePass &base_pass) { - return py::object(base_pass.get_config()).cast(); + return py::cast(serialise(base_pass)); }, ":return: A JSON serializable dictionary representation of the Pass.") .def_static( "from_dict", - [](const py::dict &base_pass_dict) { - return json(base_pass_dict).get(); + [](const py::dict &base_pass_dict, + + std::map> + &custom_deserialisation) { + return deserialise(base_pass_dict, custom_deserialisation); }, "Construct a new Pass instance from a JSON serializable dictionary " - "representation.") + "representation. `custom_deserialisation` is a map between " + "`CustomPass` " + "label attributes and a Circuit to Circuit function matching the " + "`CustomPass` `transform` argument. This allows the construction of " + "some `CustomPass` from JSON. `CustomPass` without a matching entry " + "in " + "`custom_deserialisation` will be rejected.", + py::arg("base_pass_dict"), + py::arg("custom_deserialisation") = + std::map>{}) .def(py::pickle( [](py::object self) { // __getstate__ return py::make_tuple(self.attr("to_dict")()); }, [](const py::tuple &t) { // __setstate__ const json j = t[0].cast(); - return j.get(); + return deserialise(j); })); py::class_, BasePass>( m, "SequencePass", "A sequence of compilation passes.") @@ -296,9 +308,18 @@ PYBIND11_MODULE(passes, m) { "\n:return: a pass that applies the sequence", py::arg("pass_list"), py::arg("strict") = true) .def("__str__", [](const BasePass &) { return ""; }) + .def( + "to_dict", + [](const SequencePass &seq_pass) { + return py::cast( + serialise(std::make_shared(seq_pass))); + }, + ":return: A JSON serializable dictionary representation of the " + "SequencePass.") .def( "get_sequence", &SequencePass::get_sequence, ":return: The underlying sequence of passes."); + py::class_, BasePass>( m, "RepeatPass", "Repeat a pass until its `apply()` method returns False, or if " diff --git a/pytket/conanfile.py b/pytket/conanfile.py index 1d2f0718d5..094c0e3d89 100644 --- a/pytket/conanfile.py +++ b/pytket/conanfile.py @@ -38,7 +38,7 @@ def requirements(self): self.requires("pybind11_json/0.2.14") self.requires("symengine/0.12.0") self.requires("tkassert/0.3.4@tket/stable") - self.requires("tket/1.3.38@tket/stable") + self.requires("tket/1.3.39@tket/stable") self.requires("tklog/0.3.3@tket/stable") self.requires("tkrng/0.3.3@tket/stable") self.requires("tktokenswap/0.3.9@tket/stable") diff --git a/pytket/docs/changelog.rst b/pytket/docs/changelog.rst index 1fe3063f00..df54b92f4e 100644 --- a/pytket/docs/changelog.rst +++ b/pytket/docs/changelog.rst @@ -10,6 +10,8 @@ Features: in `ClExprOp`. * Use `ClExprOp` instead of `ClassicalExpBox` when deconstructing complex conditions. +* Add `custom_deserialisation` argument to `BasePass` and `SequencePass` + `from_dict` method to support construction of `CustomPass` from json. Fixes: diff --git a/pytket/pytket/_tket/passes.pyi b/pytket/pytket/_tket/passes.pyi index efc45514eb..650273f020 100644 --- a/pytket/pytket/_tket/passes.pyi +++ b/pytket/pytket/_tket/passes.pyi @@ -18,9 +18,9 @@ class BasePass: def _pybind11_conduit_v1_(*args, **kwargs): # type: ignore ... @staticmethod - def from_dict(arg0: dict) -> BasePass: + def from_dict(base_pass_dict: dict, custom_deserialisation: dict[str, typing.Callable[[pytket._tket.circuit.Circuit], pytket._tket.circuit.Circuit]] = {}) -> BasePass: """ - Construct a new Pass instance from a JSON serializable dictionary representation. + Construct a new Pass instance from a JSON serializable dictionary representation. `custom_deserialisation` is a map between `CustomPass` label attributes and a Circuit to Circuit function matching the `CustomPass` `transform` argument. This allows the construction of some `CustomPass` from JSON. `CustomPass` without a matching entry in `custom_deserialisation` will be rejected. """ def __getstate__(self) -> tuple: ... @@ -54,7 +54,7 @@ class BasePass: :param after_apply: Invoked after a pass is applied. The CompilationUnit and a summary of the pass configuration are passed into the callback. :return: True if pass modified the circuit, else False """ - def to_dict(self) -> dict: + def to_dict(self) -> typing.Any: """ :return: A JSON serializable dictionary representation of the Pass. """ @@ -227,6 +227,10 @@ class SequencePass(BasePass): """ :return: The underlying sequence of passes. """ + def to_dict(self) -> typing.Any: + """ + :return: A JSON serializable dictionary representation of the SequencePass. + """ def AASRouting(arc: pytket._tket.architecture.Architecture, **kwargs: Any) -> BasePass: """ Construct a pass to relabel :py:class:`Circuit` Qubits to :py:class:`Device` Nodes, and then use architecture-aware synthesis to route the circuit. In the steps of the pass the circuit will be converted to CX, Rz, H gateset. The limited connectivity of the :py:class:`Architecture` is used for the routing. The direction of the edges is ignored. The placement used is GraphPlacement. This pass can take a few parameters for the routing, described below: diff --git a/pytket/tests/passes_serialisation_test.py b/pytket/tests/passes_serialisation_test.py index 193ece5786..3c76194963 100644 --- a/pytket/tests/passes_serialisation_test.py +++ b/pytket/tests/passes_serialisation_test.py @@ -41,6 +41,7 @@ DefaultMappingPass, AASRouting, SquashCustom, + CustomPass, ) from pytket.mapping import ( LexiLabellingMethod, @@ -771,3 +772,15 @@ def no_CX(circ: Circuit) -> bool: rps.to_dict()["RepeatUntilSatisfiedPass"]["predicate"]["type"] == "UserDefinedPredicate" ) + + +def test_custom_deserialisation() -> None: + def t(c: Circuit) -> Circuit: + return Circuit(2).CX(0, 1) + + custom_pass_post = BasePass.from_dict( + CustomPass(t, label="test").to_dict(), {"test": t} + ) + c: Circuit = Circuit(3).H(0).H(1).H(2) + custom_pass_post.apply(c) + assert c == Circuit(2).CX(0, 1) diff --git a/tket/conanfile.py b/tket/conanfile.py index a258c5c5ee..147e674cb2 100644 --- a/tket/conanfile.py +++ b/tket/conanfile.py @@ -23,7 +23,7 @@ class TketConan(ConanFile): name = "tket" - version = "1.3.38" + version = "1.3.39" package_type = "library" license = "Apache 2" homepage = "https://github.com/CQCL/tket" diff --git a/tket/include/tket/Predicates/CompilerPass.hpp b/tket/include/tket/Predicates/CompilerPass.hpp index b40e6206e8..f04750f97f 100644 --- a/tket/include/tket/Predicates/CompilerPass.hpp +++ b/tket/include/tket/Predicates/CompilerPass.hpp @@ -33,8 +33,6 @@ typedef std::pair PassConditions; typedef std::function PassCallback; -JSON_DECL(PassPtr) - class IncompatibleCompilerPasses : public std::logic_error { public: explicit IncompatibleCompilerPasses(const std::type_index& typeid1) @@ -301,6 +299,14 @@ class RepeatUntilSatisfiedPass : public BasePass { PredicatePtr pred_; }; +nlohmann::json serialise(const BasePass& bp); +nlohmann::json serialise(const PassPtr& pp); +nlohmann::json serialise(const std::vector& pp); + +PassPtr deserialise( + const nlohmann::json& j, + const std::map>& + custom_deserialise = {}); // TODO: Repeat with a metric, repeat until a Predicate is satisfied... } // namespace tket diff --git a/tket/src/Predicates/CompilerPass.cpp b/tket/src/Predicates/CompilerPass.cpp index 43691b0faa..f0a5461e4b 100644 --- a/tket/src/Predicates/CompilerPass.cpp +++ b/tket/src/Predicates/CompilerPass.cpp @@ -251,7 +251,7 @@ std::string SequencePass::to_string() const { nlohmann::json SequencePass::get_config() const { nlohmann::json j; j["pass_class"] = "SequencePass"; - j["SequencePass"]["sequence"] = seq_; + j["SequencePass"]["sequence"] = serialise(seq_); return j; } @@ -270,7 +270,7 @@ std::string RepeatPass::to_string() const { nlohmann::json RepeatPass::get_config() const { nlohmann::json j; j["pass_class"] = "RepeatPass"; - j["RepeatPass"]["body"] = pass_; + j["RepeatPass"]["body"] = serialise(pass_); return j; } @@ -313,7 +313,7 @@ std::string RepeatWithMetricPass::to_string() const { nlohmann::json RepeatWithMetricPass::get_config() const { nlohmann::json j; j["pass_class"] = "RepeatWithMetricPass"; - j["RepeatWithMetricPass"]["body"] = pass_; + j["RepeatWithMetricPass"]["body"] = serialise(pass_); j["RepeatWithMetricPass"]["metric"] = "SERIALIZATION OF METRICS NOT YET IMPLEMENTED"; return j; @@ -347,15 +347,27 @@ std::string RepeatUntilSatisfiedPass::to_string() const { nlohmann::json RepeatUntilSatisfiedPass::get_config() const { nlohmann::json j; j["pass_class"] = "RepeatUntilSatisfiedPass"; - j["RepeatUntilSatisfiedPass"]["body"] = pass_; + j["RepeatUntilSatisfiedPass"]["body"] = serialise(pass_); j["RepeatUntilSatisfiedPass"]["predicate"] = pred_; return j; } -void to_json(nlohmann::json& j, const PassPtr& pp) { j = pp->get_config(); } +nlohmann::json serialise(const BasePass& bp) { return bp.get_config(); } +nlohmann::json serialise(const PassPtr& pp) { return pp->get_config(); } +nlohmann::json serialise(const std::vector& pp) { + nlohmann::json j = nlohmann::json::array(); + for (const auto& p : pp) { + j.push_back(serialise(p)); + } + return j; +} -void from_json(const nlohmann::json& j, PassPtr& pp) { +PassPtr deserialise( + const nlohmann::json& j, + const std::map>& + custom_deserialise) { std::string classname = j.at("pass_class").get(); + PassPtr pp; if (classname == "StandardPass") { const nlohmann::json& content = j.at("StandardPass"); std::string passname = content.at("name").get(); @@ -576,6 +588,17 @@ void from_json(const nlohmann::json& j, PassPtr& pp) { unsigned n = content.at("n").get(); bool only_zeros = content.at("only_zeros").get(); pp = RoundAngles(n, only_zeros); + } else if (passname == "CustomPass") { + std::string label = content.at("label").get(); + auto it = custom_deserialise.find(label); + if (it != custom_deserialise.end()) { + pp = CustomPass(it->second, label); + } else { + throw JsonError( + "Cannot deserialise CustomPass without passing a " + "custom_deserialisation map " + "with a key corresponding to the pass's label."); + } } else { throw JsonError("Cannot load StandardPass of unknown type"); } @@ -583,22 +606,24 @@ void from_json(const nlohmann::json& j, PassPtr& pp) { const nlohmann::json& content = j.at("SequencePass"); std::vector seq; for (const auto& j_entry : content.at("sequence")) { - seq.push_back(j_entry.get()); + seq.push_back(deserialise(j_entry, custom_deserialise)); } pp = std::make_shared(seq); } else if (classname == "RepeatPass") { const nlohmann::json& content = j.at("RepeatPass"); - pp = std::make_shared(content.at("body").get()); + pp = std::make_shared( + deserialise(content.at("body"), custom_deserialise)); } else if (classname == "RepeatWithMetricPass") { throw PassNotSerializable(classname); } else if (classname == "RepeatUntilSatisfiedPass") { const nlohmann::json& content = j.at("RepeatUntilSatisfiedPass"); - PassPtr body = content.at("body").get(); + PassPtr body = deserialise(content.at("body"), custom_deserialise); PredicatePtr pred = content.at("predicate").get(); pp = std::make_shared(body, pred); } else { throw JsonError("Cannot load PassPtr of unknown type."); } + return pp; } } // namespace tket diff --git a/tket/test/src/test_json.cpp b/tket/test/src/test_json.cpp index d2843b6a57..e0e2704709 100644 --- a/tket/test/src/test_json.cpp +++ b/tket/test/src/test_json.cpp @@ -901,12 +901,12 @@ SCENARIO("Test compiler pass serializations") { CompilationUnit cu{circ}; \ CompilationUnit copy = cu; \ PassPtr pp = pass; \ - nlohmann::json j_pp = pp; \ - PassPtr loaded = j_pp.get(); \ + nlohmann::json j_pp = serialise(pp); \ + PassPtr loaded = deserialise(j_pp); \ pp->apply(cu); \ loaded->apply(copy); \ REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); \ - nlohmann::json j_loaded = loaded; \ + nlohmann::json j_loaded = serialise(loaded); \ REQUIRE(j_pp == j_loaded); \ } COMPPASSJSONTEST(CommuteThroughMultis, CommuteThroughMultis()) @@ -986,14 +986,14 @@ SCENARIO("Test compiler pass serializations") { CompilationUnit copy = cu; PassPtr pp = gen_pauli_exponentials( Transforms::PauliSynthStrat::Sets, CXConfigType::Tree); - nlohmann::json j_pp = pp; - PassPtr loaded = j_pp.get(); + nlohmann::json j_pp = serialise(pp); + PassPtr loaded = deserialise(j_pp); pp->apply(cu); loaded->apply(copy); DecomposeBoxes()->apply(cu); DecomposeBoxes()->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); - nlohmann::json j_loaded = loaded; + nlohmann::json j_loaded = serialise(loaded); REQUIRE(j_pp == j_loaded); } GIVEN("RoutingPass") { @@ -1004,12 +1004,12 @@ SCENARIO("Test compiler pass serializations") { placement->apply(cu); CompilationUnit copy = cu; PassPtr pp = gen_routing_pass(arc, rcon); - nlohmann::json j_pp = pp; - PassPtr loaded = j_pp.get(); + nlohmann::json j_pp = serialise(pp); + PassPtr loaded = deserialise(j_pp); pp->apply(cu); loaded->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); - nlohmann::json j_loaded = loaded; + nlohmann::json j_loaded = serialise(loaded); REQUIRE(j_pp == j_loaded); } GIVEN("Routing with multiple routing methods") { @@ -1023,12 +1023,12 @@ SCENARIO("Test compiler pass serializations") { placement->apply(cu); CompilationUnit copy = cu; PassPtr pp = gen_routing_pass(arc, mrcon); - nlohmann::json j_pp = pp; - PassPtr loaded = j_pp.get(); + nlohmann::json j_pp = serialise(pp); + PassPtr loaded = deserialise(j_pp); pp->apply(cu); loaded->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); - nlohmann::json j_loaded = loaded; + nlohmann::json j_loaded = serialise(loaded); REQUIRE(j_pp == j_loaded); } GIVEN("FullMappingPass") { @@ -1049,7 +1049,7 @@ SCENARIO("Test compiler pass serializations") { } j_pp["StandardPass"]["routing_config"] = config_array; - PassPtr loaded = j_pp.get(); + PassPtr loaded = deserialise(j_pp); pp->apply(cu); loaded->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); @@ -1065,7 +1065,7 @@ SCENARIO("Test compiler pass serializations") { j_pp["StandardPass"]["name"] = "DefaultMappingPass"; j_pp["StandardPass"]["architecture"] = arc; j_pp["StandardPass"]["delay_measures"] = true; - PassPtr loaded = j_pp.get(); + PassPtr loaded = deserialise(j_pp); pp->apply(cu); loaded->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); @@ -1088,7 +1088,7 @@ SCENARIO("Test compiler pass serializations") { j_pp["StandardPass"]["routing_config"] = config_array; j_pp["StandardPass"]["directed"] = true; j_pp["StandardPass"]["delay_measures"] = false; - PassPtr loaded = j_pp.get(); + PassPtr loaded = deserialise(j_pp); pp->apply(cu); loaded->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); @@ -1106,7 +1106,7 @@ SCENARIO("Test compiler pass serializations") { j_pp["StandardPass"]["pauli_synth_strat"] = Transforms::PauliSynthStrat::Sets; j_pp["StandardPass"]["cx_config"] = CXConfigType::Star; - PassPtr loaded = j_pp.get(); + PassPtr loaded = deserialise(j_pp); pp->apply(cu); loaded->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); @@ -1124,7 +1124,7 @@ SCENARIO("Test compiler pass serializations") { j_pp["StandardPass"]["pauli_synth_strat"] = Transforms::PauliSynthStrat::Sets; j_pp["StandardPass"]["cx_config"] = CXConfigType::Star; - PassPtr loaded = j_pp.get(); + PassPtr loaded = deserialise(j_pp); pp->apply(cu); loaded->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); @@ -1143,7 +1143,7 @@ SCENARIO("Test compiler pass serializations") { j_pp["StandardPass"]["name"] = "ContextSimp"; j_pp["StandardPass"]["allow_classical"] = true; j_pp["StandardPass"]["x_circuit"] = CircPool::X(); - PassPtr loaded = j_pp.get(); + PassPtr loaded = deserialise(j_pp); pp->apply(cu); loaded->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); @@ -1158,12 +1158,12 @@ SCENARIO("Test compiler pass combinator serializations") { std::vector seq_vec = { gen_pauli_exponentials(), DecomposeBoxes(), gen_clifford_simp_pass()}; PassPtr seq = std::make_shared(seq_vec); - nlohmann::json j_seq = seq; - PassPtr loaded_seq = j_seq.get(); + nlohmann::json j_seq = serialise(seq); + PassPtr loaded_seq = deserialise(j_seq); seq->apply(cu); loaded_seq->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); - nlohmann::json j_loaded_seq = loaded_seq; + nlohmann::json j_loaded_seq = serialise(loaded_seq); REQUIRE(j_seq == j_loaded_seq); } GIVEN("A complex pass with multiple combinators") { @@ -1186,12 +1186,12 @@ SCENARIO("Test compiler pass combinator serializations") { PassPtr rep = std::make_shared(seq, gate_set); PassPtr comb = std::make_shared(std::vector{rep, RebaseTket()}); - nlohmann::json j_comb = comb; - PassPtr loaded_comb = j_comb.get(); + nlohmann::json j_comb = serialise(comb); + PassPtr loaded_comb = deserialise(j_comb); comb->apply(cu); loaded_comb->apply(copy); REQUIRE(cu.get_circ_ref() == copy.get_circ_ref()); - nlohmann::json j_loaded_comb = loaded_comb; + nlohmann::json j_loaded_comb = serialise(loaded_comb); REQUIRE(j_comb == j_loaded_comb); } }