Skip to content

Commit

Permalink
fix: Support hugr packages, fix the notebooks (#622)
Browse files Browse the repository at this point in the history
This is a small bundle of changes aimed at fixing the example notebooks.

- Use packages when building hugrs. This includes parsing guppy packages
from rust, as well as returning packages from the circuit builder.
  - This required adding exports for the tket2 extensions from python.
- We temporarily add a py dependency on `hugr-cli`, since packages are
currently defined there. See CQCL/hugr#1530
  
The `guppy -> pytket` pipeline is still broken, as guppy now generates
many more function calls and nested control flow primitives, which we do
not currently support.

drive-by: Make `crate::serialize::guppy::find_function` public.

Closes #621.
  • Loading branch information
aborgna-q authored Sep 30, 2024
1 parent 5ec193d commit 1cf9dcb
Show file tree
Hide file tree
Showing 12 changed files with 3,120 additions and 367 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions tket2-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ derive_more = { workspace = true }
itertools = { workspace = true }
portmatching = { workspace = true }
strum = { workspace = true }
# Required to acces the `Package` type.
# Remove once https://github.com/CQCL/hugr/issues/1530 is fixed.
hugr-cli = { workspace = true }

[dev-dependencies]
rstest = { workspace = true }
Expand Down
3,204 changes: 2,909 additions & 295 deletions tket2-py/examples/1-Getting-Started.ipynb

Large diffs are not rendered by default.

109 changes: 67 additions & 42 deletions tket2-py/examples/2-Rewriting-Circuits.ipynb

Large diffs are not rendered by default.

32 changes: 22 additions & 10 deletions tket2-py/examples/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
"""Some utility functions for the example notebooks."""

from typing import TYPE_CHECKING, Any
from hugr import Hugr
from tket2.passes import lower_to_pytket
from tket2.circuit import Tk2Circuit
from guppylang.definition.function import RawFunctionDef # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401

if TYPE_CHECKING:
try:
from guppylang.definition.function import RawFunctionDef # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401
except ImportError:
RawFunctionDef = Any

def setup_jupyter_rendering():
"""Set up hugr rendering for Jupyter notebooks."""

# We need to define this helper function for now. It will be included in guppy in the future.
def _repr_hugr(
h: Hugr, include=None, exclude=None, **kwargs
) -> dict[str, bytes | str]:
return h.render_dot()._repr_mimebundle_(include, exclude, **kwargs)

def _repr_tk2circ(
circ: Tk2Circuit, include=None, exclude=None, **kwargs
) -> dict[str, bytes | str]:
h = Hugr.load_json(circ.to_hugr_json())
return _repr_hugr(h, include, exclude, **kwargs)

setattr(Hugr, "_repr_mimebundle_", _repr_hugr)
setattr(Tk2Circuit, "_repr_mimebundle_", _repr_tk2circ)


# TODO: Should this be part of the guppy API? Or tket2?
def guppy_to_circuit(func_def: RawFunctionDef) -> Tk2Circuit:
"""Convert a Guppy function definition to a `Tk2Circuit`."""
module = func_def.id.module
assert module is not None, "Function definition must belong to a module"

hugr = module.compile()
assert hugr is not None, "Module must be compilable"
pkg = module.compile()

json = hugr.to_raw().to_json()
json = pkg.to_json()
circ = Tk2Circuit.from_guppy_json(json, func_def.name)

return lower_to_pytket(circ)
18 changes: 16 additions & 2 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
//! Rust-backed representation of circuits

use std::borrow::{Borrow, Cow};
use std::mem;

use hugr::builder::{CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::QB_T;
use hugr::extension::{ExtensionRegistry, EMPTY_REG};
use hugr::ops::handle::NodeHandle;
use hugr::ops::{ExtensionOp, NamedOp, OpType};
use hugr::types::Type;
use hugr_cli::Package;
use itertools::Itertools;
use pyo3::exceptions::{PyAttributeError, PyValueError};
use pyo3::types::{PyAnyMethods, PyModule, PyString, PyTypeMethods};
Expand Down Expand Up @@ -94,9 +97,20 @@ impl Tk2Circuit {
/// Decode a HUGR json string to a circuit.
#[staticmethod]
pub fn from_hugr_json(json: &str) -> PyResult<Self> {
let hugr: Hugr = serde_json::from_str(json)
let pkg: Package = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit { circ: hugr.into() })
let mut reg = REGISTRY.clone();
let mut hugrs = pkg.validate(&mut reg).map_err(|e| {
PyErr::new::<PyAttributeError, _>(format!("Invalid encoded circuit: {e}"))
})?;
if hugrs.len() != 1 {
return Err(PyValueError::new_err(
"Invalid HUGR json: Package must contain exactly one hugr.",
));
}
Ok(Tk2Circuit {
circ: mem::take(&mut hugrs[0]).into(),
})
}

/// Load a function from a compiled guppy module, encoded as a json string.
Expand Down
39 changes: 33 additions & 6 deletions tket2-py/tket2/circuit/build.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations
from hugr.hugr import Hugr
from typing import Iterable

from hugr import tys, ops
from hugr.ext import Package, Extension
from hugr.ops import ComWire, Command
from hugr.std.float import FLOAT_T
from hugr.build.tracked_dfg import TrackedDfg
Expand All @@ -17,10 +19,35 @@ class CircBuild(TrackedDfg):
def with_nqb(cls, n_qb: int) -> CircBuild:
return cls(*[tys.Qubit] * n_qb, track_inputs=True)

def finish(self) -> Tk2Circuit:
def finish_package(
self, other_extensions: Iterable[Extension] | None = None
) -> Package:
"""Finish building the package by setting all the qubits as the output
and wrap it in a hugr package with the required extensions.
Args:
other_extensions: Other extensions to include in the package.
Returns:
The finished package.
"""
import tket2.extensions as ext

extensions = [
ext.rotation(),
ext.futures(),
ext.hseries(),
ext.quantum(),
ext.result(),
*(other_extensions or []),
]

return Package(modules=[self.hugr], extensions=extensions)

def finish(self, other_extensions: list[Extension] | None = None) -> Tk2Circuit:
"""Finish building the circuit by setting all the qubits as the output
and validate."""
return load_hugr(self.hugr)

return load_hugr_pkg(self.finish_package(other_extensions))


def from_coms(*args: Command) -> Tk2Circuit:
Expand All @@ -40,8 +67,8 @@ def from_coms(*args: Command) -> Tk2Circuit:
return build.finish()


def load_hugr(h: Hugr) -> Tk2Circuit:
return Tk2Circuit.from_hugr_json(h.to_json())
def load_hugr_pkg(package: Package) -> Tk2Circuit:
return Tk2Circuit.from_hugr_json(package.to_json())


def load_custom(serialized: bytes) -> ops.Custom:
Expand All @@ -61,7 +88,7 @@ def id_circ(n_qb: int) -> Tk2Circuit:

@dataclass(frozen=True)
class QuantumOps(ops.Custom):
extension: tys.ExtensionId = "quantum.tket2"
extension: tys.ExtensionId = "tket2.quantum"


_OneQbSig = tys.FunctionType.endo([tys.Qubit])
Expand Down
42 changes: 42 additions & 0 deletions tket2-py/tket2/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Hugr extension definitions for tket2 circuits."""
# This will be moved to a separate python library soon.

import pkgutil
import functools

from hugr._serialization.extension import Extension as PdExtension
from hugr.ext import Extension


@functools.cache
def rotation() -> Extension:
return load_extension("tket2.rotation")


@functools.cache
def futures() -> Extension:
return load_extension("tket2.futures")


@functools.cache
def hseries() -> Extension:
return load_extension("tket2.hseries")


@functools.cache
def quantum() -> Extension:
return load_extension("tket2.quantum")


@functools.cache
def result() -> Extension:
return load_extension("tket2.result")


def load_extension(name: str) -> Extension:
replacement = name.replace(".", "/")
json_str = pkgutil.get_data(__name__, f"_json_defs/{replacement}.json")
assert json_str is not None, f"Could not load json for extension {name}"
# TODO: Replace with `Extension.from_json` once that is implemented
# https://github.com/CQCL/hugr/issues/1523
return PdExtension.model_validate_json(json_str).deserialize()
4 changes: 4 additions & 0 deletions tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ bytemuck = { workspace = true }
crossbeam-channel = { workspace = true }
tracing = { workspace = true }
zstd = { workspace = true, optional = true }
# Required to acces the `Package` type.
# Remove once https://github.com/CQCL/hugr/issues/1530 is fixed.
hugr-cli = { workspace = true }


[dev-dependencies]
rstest = { workspace = true }
Expand Down
13 changes: 5 additions & 8 deletions tket2/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

use crate::serialize::pytket::OpaqueTk1Op;
use crate::Tk2Op;
use hugr::extension::prelude::PRELUDE;
use hugr::extension::simple_op::MakeOpDef;
use hugr::extension::{
CustomSignatureFunc, ExtensionId, ExtensionRegistry, SignatureError, Version,
};
use hugr::hugr::IdentList;
use hugr::std_extensions::arithmetic::float_types::EXTENSION as FLOAT_TYPES;
use hugr::std_extensions::STD_REG;
use hugr::types::type_param::{TypeArg, TypeParam};
use hugr::types::{CustomType, PolyFuncType, PolyFuncTypeRV};
use hugr::Extension;
Expand Down Expand Up @@ -57,15 +56,13 @@ pub static ref TKET1_EXTENSION: Extension = {
res
};

/// Extension registry including the prelude, TKET1 and Tk2Ops extensions.
pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
/// Extension registry including the prelude, std, TKET1, and Tk2Ops extensions.
pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new(
STD_REG.iter().map(|(_, e)| e.to_owned()).chain([
TKET1_EXTENSION.to_owned(),
PRELUDE.to_owned(),
TKET2_EXTENSION.to_owned(),
FLOAT_TYPES.to_owned(),
rotation::ROTATION_EXTENSION.to_owned()
]).unwrap();

])).unwrap();

}

Expand Down
2 changes: 1 addition & 1 deletion tket2/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ pub(crate) mod test {
#[test]
fn tk2op_properties() {
for op in Tk2Op::iter() {
// The exposed name should start with "quantum.tket2."
// The exposed name should start with "tket2.quantum."
assert!(op.exposed_name().starts_with(&EXTENSION_ID.to_string()));

let ext_op = op.into_extension_op();
Expand Down
19 changes: 16 additions & 3 deletions tket2/src/serialize/guppy.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
//! Load pre-compiled guppy functions.

use std::path::Path;
use std::{fs, io};
use std::{fs, io, mem};

use hugr::ops::{NamedOp, OpTag, OpTrait, OpType};
use hugr::{Hugr, HugrView};
use hugr_cli::Package;
use itertools::Itertools;
use thiserror::Error;

use crate::extension::REGISTRY;
use crate::{Circuit, CircuitError};

/// Loads a pre-compiled guppy file.
Expand All @@ -31,7 +33,12 @@ pub fn load_guppy_json_reader(
reader: impl io::Read,
function: &str,
) -> Result<Circuit, CircuitLoadError> {
let hugr: Hugr = serde_json::from_reader(reader)?;
let pkg: Package = serde_json::from_reader(reader)?;
let mut hugrs = pkg.validate(&mut REGISTRY.clone())?;
if hugrs.len() != 1 {
return Err(CircuitLoadError::InvalidNumHugrs(hugrs.len()));
}
let hugr = mem::take(&mut hugrs[0]);
find_function(hugr, function)
}

Expand All @@ -48,7 +55,7 @@ pub fn load_guppy_json_reader(
/// - If the root of the HUGR is not a module operation.
/// - If the function is not found in the module.
/// - If the function has control flow primitives.
fn find_function(hugr: Hugr, function_name: &str) -> Result<Circuit, CircuitLoadError> {
pub fn find_function(hugr: Hugr, function_name: &str) -> Result<Circuit, CircuitLoadError> {
// Find the root module.
let module = hugr.root();
if !OpTag::ModuleRoot.is_superset(hugr.get_optype(module).tag()) {
Expand Down Expand Up @@ -139,4 +146,10 @@ pub enum CircuitLoadError {
/// Error loading the circuit.
#[error("Error loading the circuit: {0}")]
CircuitLoadError(#[from] CircuitError),
/// Error validating the loaded circuit.
#[error("{0}")]
ValError(#[from] hugr_cli::validate::ValError),
/// The encoded HUGR package must have a single HUGR.
#[error("The encoded HUGR package must have a single HUGR, but it has {0} HUGRs.")]
InvalidNumHugrs(usize),
}

0 comments on commit 1cf9dcb

Please sign in to comment.