Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Bump HUGR dependency #94

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ members = ["pyrs", "compile-matcher"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "abfaba6" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "5a97a635" }
portgraph = { version = "0.8", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
3 changes: 1 addition & 2 deletions src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::iter::FusedIterator;
use hugr::hugr::views::HierarchyView;
use hugr::ops::{OpTag, OpTrait};
use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers};
use portgraph::PortOffset;

use super::Circuit;

Expand Down Expand Up @@ -107,7 +106,7 @@ where
optype
.static_input()
// TODO query optype for this port once it is available in hugr.
.map(|_| PortOffset::new_incoming(sig.input.len()).into()),
.map(|_| Port::new_incoming(sig.input.len())),
)
.filter_map(|port| {
let (from, from_port) = self.circ.linked_ports(node, port).next()?;
Expand Down
10 changes: 6 additions & 4 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ pub(crate) fn wrap_json_op(op: &JsonOp) -> ExternalOp {
// .into()
let sig = op.signature();
let op = serde_yaml::to_value(op).unwrap();
let payload = TypeArg::Opaque(CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap());
let payload = TypeArg::Opaque {
arg: CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap(),
};
OpaqueOp::new(
TKET1_EXTENSION_ID,
JSON_OP_NAME,
Expand All @@ -100,17 +102,17 @@ pub(crate) fn try_unwrap_json_op(ext: &ExternalOp) -> Option<JsonOp> {
if ext.name() != format!("{TKET1_EXTENSION_ID}.{JSON_OP_NAME}") {
return None;
}
let Some(TypeArg::Opaque(op)) = ext.args().get(0) else {
let Some(TypeArg::Opaque { arg }) = ext.args().get(0) else {
// TODO: Throw an error? We should never get here if the name matches.
return None;
};
let op = serde_yaml::from_value(op.value.clone()).ok()?;
let op = serde_yaml::from_value(arg.value.clone()).ok()?;
Some(op)
}

/// Compute the signature of a json-encoded TKET1 operation.
fn json_op_signature(args: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [TypeArg::Opaque(arg)] = args else {
let [TypeArg::Opaque { arg }] = args else {
// This should have already been checked.
panic!("Wrong number of arguments");
};
Expand Down
6 changes: 4 additions & 2 deletions src/json/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,10 @@ impl JsonEncoder {
OpType::Const(const_op) => {
// New constant, register it if it can be interpreted as a parameter.
match const_op.value() {
Value::Prim(PrimValue::Extension((v,))) => {
if let Some(f) = v.downcast_ref::<ConstF64>() {
Value::Prim {
val: PrimValue::Extension { c: (val,) },
} => {
if let Some(f) = val.downcast_ref::<ConstF64>() {
f.to_string()
} else {
return false;
Expand Down
10 changes: 5 additions & 5 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ pub fn symbolic_constant_op(s: &str) -> OpType {
let l: LeafOp = EXTENSION
.instantiate_extension_op(
&SYM_OP_ID,
vec![TypeArg::Opaque(
CustomTypeArg::new(SYM_EXPR_T, value).unwrap(),
)],
vec![TypeArg::Opaque {
arg: CustomTypeArg::new(SYM_EXPR_T, value).unwrap(),
}],
)
.unwrap()
.into();
Expand All @@ -202,11 +202,11 @@ pub(crate) fn match_symb_const_op(op: &OpType) -> Option<&str> {
{
// TODO also check extension name

let Some(TypeArg::Opaque(s)) = e.args().get(0) else {
let Some(TypeArg::Opaque { arg }) = e.args().get(0) else {
panic!("should be an opaque type arg.")
};

let serde_yaml::Value::String(s) = &s.value else {
let serde_yaml::Value::String(s) = &arg.value else {
panic!("unexpected yaml value.")
};

Expand Down
22 changes: 8 additions & 14 deletions src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,8 @@ use std::{
};

use super::{CircuitPattern, PEdge, PNode};
use hugr::{
hugr::views::{
sibling::{
ConvexChecker, InvalidReplacement,
InvalidSubgraph::{self},
},
SiblingSubgraph,
},
ops::OpType,
Hugr, Node, Port,
};
use hugr::hugr::views::sibling_subgraph::{ConvexChecker, InvalidReplacement, InvalidSubgraph};
use hugr::{hugr::views::SiblingSubgraph, ops::OpType, Hugr, Node, Port};
use itertools::Itertools;
use portmatching::{
automaton::{LineBuilder, ScopeAutomaton},
Expand Down Expand Up @@ -82,7 +73,7 @@ pub struct PatternMatch<'a, C> {
pub(super) root: Node,
}

impl<'a, C: Circuit<'a>> PatternMatch<'a, C> {
impl<'a, C: Circuit<'a> + Clone> PatternMatch<'a, C> {
/// The matcher's pattern ID of the match.
pub fn pattern_id(&self) -> PatternID {
self.pattern
Expand Down Expand Up @@ -246,7 +237,10 @@ impl PatternMatcher {
}

/// Find all convex pattern matches in a circuit.
pub fn find_matches<'a, C: Circuit<'a>>(&self, circuit: &'a C) -> Vec<PatternMatch<'a, C>> {
pub fn find_matches<'a, C: Circuit<'a> + Clone>(
&self,
circuit: &'a C,
) -> Vec<PatternMatch<'a, C>> {
let mut checker = ConvexChecker::new(circuit);
circuit
.commands()
Expand All @@ -255,7 +249,7 @@ impl PatternMatcher {
}

/// Find all convex pattern matches in a circuit rooted at a given node.
fn find_rooted_matches<'a, C: Circuit<'a>>(
fn find_rooted_matches<'a, C: Circuit<'a> + Clone>(
&self,
circ: &'a C,
root: Node,
Expand Down
2 changes: 1 addition & 1 deletion src/portmatching/pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl PyPatternMatch {
///
/// Requires references to the circuit and pattern to resolve indices
/// into these objects.
pub fn try_from_rust<'circ, C: Circuit<'circ>>(
pub fn try_from_rust<'circ, C: Circuit<'circ> + Clone>(
m: PatternMatch<'circ, C>,
circ: &C,
matcher: &PatternMatcher,
Expand Down
9 changes: 3 additions & 6 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@ pub use ecc_rewriter::ECCRewriter;

use delegate::delegate;
use derive_more::{From, Into};
use hugr::hugr::views::sibling_subgraph::InvalidReplacement;
use hugr::{
hugr::{
hugrmut::HugrMut,
views::{sibling::InvalidReplacement, SiblingSubgraph},
Rewrite, SimpleReplacementError,
},
hugr::{hugrmut::HugrMut, views::SiblingSubgraph, Rewrite, SimpleReplacementError},
Hugr, HugrView, SimpleReplacement,
};

Expand Down Expand Up @@ -55,5 +52,5 @@ impl CircuitRewrite {
/// Generate rewrite rules for circuits.
pub trait Rewriter {
/// Get the rewrite rules for a circuit.
fn get_rewrites<'a, C: Circuit<'a>>(&'a self, circ: &'a C) -> Vec<CircuitRewrite>;
fn get_rewrites<'a, C: Circuit<'a> + Clone>(&'a self, circ: &'a C) -> Vec<CircuitRewrite>;
}
2 changes: 1 addition & 1 deletion src/rewrite/ecc_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl ECCRewriter {
}

impl Rewriter for ECCRewriter {
fn get_rewrites<'a, C: Circuit<'a>>(&'a self, circ: &'a C) -> Vec<CircuitRewrite> {
fn get_rewrites<'a, C: Circuit<'a> + Clone>(&'a self, circ: &'a C) -> Vec<CircuitRewrite> {
let matches = self.matcher.find_matches(circ);
matches
.into_iter()
Expand Down