Skip to content

Commit

Permalink
Change many add_node+open_extensions to use add_op
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Oct 24, 2023
1 parent 5c2e287 commit 41c9c14
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 88 deletions.
12 changes: 6 additions & 6 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,18 @@ pub(crate) mod test {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
signature: signature.clone(),
}));
hugr.add_node_with_parent(
hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Input {
ops::Input {
types: signature.input,
}),
},
)
.unwrap();
hugr.add_node_with_parent(
hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Output {
ops::Output {
types: signature.output,
}),
},
)
.unwrap();
hugr
Expand Down
2 changes: 1 addition & 1 deletion src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
self.hugr_mut().add_op_before(sibling_node, case_op)?
} else {
self.add_child_node(NodeType::open_extensions(case_op))?
self.add_child_op(case_op)?
};

self.case_nodes[case] = Some(case_node);
Expand Down
140 changes: 64 additions & 76 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -730,11 +730,11 @@ mod test {
let root_node = NodeType::open_extensions(op);
let mut hugr = Hugr::new(root_node);

let input = NodeType::open_extensions(ops::Input::new(type_row![NAT, NAT]));
let output = NodeType::open_extensions(ops::Output::new(type_row![NAT]));
let input = ops::Input::new(type_row![NAT, NAT]);
let output = ops::Output::new(type_row![NAT]);

let input = hugr.add_node_with_parent(hugr.root(), input)?;
let output = hugr.add_node_with_parent(hugr.root(), output)?;
let input = hugr.add_op_with_parent(hugr.root(), input)?;
let output = hugr.add_op_with_parent(hugr.root(), output)?;

assert_matches!(hugr.get_io(hugr.root()), Some(_));

Expand All @@ -750,29 +750,29 @@ mod test {
let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&C));

let add_a = hugr.add_node_with_parent(
let add_a = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_a_sig,
}),
},
)?;
let add_b = hugr.add_node_with_parent(
let add_b = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_b_sig,
}),
},
)?;
let add_ab = hugr.add_node_with_parent(
let add_ab = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_ab_sig,
}),
},
)?;
let mult_c = hugr.add_node_with_parent(
let mult_c = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: mult_c_sig,
}),
},
)?;

hugr.connect(input, 0, add_a, 0)?;
Expand Down Expand Up @@ -906,29 +906,26 @@ mod test {
let [input, output] = hugr.get_io(hugr.root()).unwrap();
let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs);

let add_r = hugr.add_node_with_parent(
let add_r = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_r_sig,
}),
},
)?;

// Dangling thingy
let src_sig = FunctionType::new(type_row![], type_row![NAT])
.with_extension_delta(&ExtensionSet::new());

let src = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG { signature: src_sig }),
)?;
let src = hugr.add_op_with_parent(hugr.root(), ops::DFG { signature: src_sig })?;

let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]);
// Mult has open extension requirements, which we should solve to be "R"
let mult = hugr.add_node_with_parent(
let mult = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: mult_sig,
}),
},
)?;

hugr.connect(input, 0, add_r, 0)?;
Expand Down Expand Up @@ -988,18 +985,18 @@ mod test {
) -> Result<[Node; 3], Box<dyn Error>> {
let op: OpType = op.into();

let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?;
let input = hugr.add_node_with_parent(
let node = hugr.add_op_with_parent(parent, op)?;
let input = hugr.add_op_with_parent(
node,
NodeType::open_extensions(ops::Input {
ops::Input {
types: op_sig.input,
}),
},
)?;
let output = hugr.add_node_with_parent(
let output = hugr.add_op_with_parent(
node,
NodeType::open_extensions(ops::Output {
ops::Output {
types: op_sig.output,
}),
},
)?;
Ok([node, input, output])
}
Expand All @@ -1020,20 +1017,20 @@ mod test {
Into::<OpType>::into(op).signature(),
)?;

let lift1 = hugr.add_node_with_parent(
let lift1 = hugr.add_op_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: first_ext,
}),
},
)?;

let lift2 = hugr.add_node_with_parent(
let lift2 = hugr.add_op_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: second_ext,
}),
},
)?;

hugr.connect(case_in, 0, lift1, 0)?;
Expand Down Expand Up @@ -1098,17 +1095,17 @@ mod test {
}));

let root = hugr.root();
let input = hugr.add_node_with_parent(
let input = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::Input {
ops::Input {
types: type_row![NAT],
}),
},
)?;
let output = hugr.add_node_with_parent(
let output = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::Output {
ops::Output {
types: type_row![NAT],
}),
},
)?;

// Make identical dataflow nodes which add extension requirement "A" or "B"
Expand All @@ -1129,12 +1126,12 @@ mod test {
.unwrap();

let lift = hugr
.add_node_with_parent(
.add_op_with_parent(
node,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: ext,
}),
},
)
.unwrap();

Expand Down Expand Up @@ -1181,7 +1178,7 @@ mod test {

let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?;

let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?;
let dfg = hugr.add_op_with_parent(bb, op)?;

hugr.connect(bb_in, 0, dfg, 0)?;
hugr.connect(dfg, 0, bb_out, 0)?;
Expand Down Expand Up @@ -1213,23 +1210,20 @@ mod test {
extension_delta: entry_extensions,
};

let exit = hugr.add_node_with_parent(
let exit = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::BasicBlock::Exit {
ops::BasicBlock::Exit {
cfg_outputs: exit_types.into(),
}),
},
)?;

let entry = hugr.add_op_before(exit,dfb)?;
let entry_in = hugr.add_node_with_parent(
entry,
NodeType::open_extensions(ops::Input { types: inputs }),
)?;
let entry_out = hugr.add_node_with_parent(
let entry = hugr.add_op_before(exit, dfb)?;
let entry_in = hugr.add_op_with_parent(entry, ops::Input { types: inputs })?;
let entry_out = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(ops::Output {
ops::Output {
types: vec![entry_tuple_sum].into(),
}),
},
)?;

Ok(([entry, entry_in, entry_out], exit))
Expand Down Expand Up @@ -1280,12 +1274,12 @@ mod test {
type_row![NAT],
)?;

let mkpred = hugr.add_node_with_parent(
let mkpred = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
make_opaque(
A,
FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a),
)),
),
)?;

// Internal wiring for DFGs
Expand Down Expand Up @@ -1376,12 +1370,9 @@ mod test {
type_row![NAT],
)?;

let entry_mid = hugr.add_node_with_parent(
let entry_mid = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], twoway(NAT)),
)),
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], twoway(NAT))),
)?;

hugr.connect(entry_in, 0, entry_mid, 0)?;
Expand Down Expand Up @@ -1465,12 +1456,12 @@ mod test {
type_row![NAT],
)?;

let entry_dfg = hugr.add_node_with_parent(
let entry_dfg = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext),
)),
),
)?;

hugr.connect(entry_in, 0, entry_dfg, 0)?;
Expand Down Expand Up @@ -1546,12 +1537,9 @@ mod test {
type_row![NAT],
)?;

let entry_mid = hugr.add_node_with_parent(
let entry_mid = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], oneway(NAT)),
)),
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT))),
)?;

hugr.connect(entry_in, 0, entry_mid, 0)?;
Expand Down
8 changes: 4 additions & 4 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ impl From<HugrError> for PyErr {

#[cfg(test)]
mod test {
use super::{Hugr, HugrView, NodeType};
use super::{Hugr, HugrView};
use crate::builder::test::closed_dfg_root_hugr;
use crate::extension::ExtensionSet;
use crate::hugr::HugrMut;
Expand Down Expand Up @@ -645,12 +645,12 @@ mod test {
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r),
);
let [input, output] = hugr.get_io(hugr.root()).unwrap();
let lift = hugr.add_node_with_parent(
let lift = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![BIT],
new_extension: "R".try_into().unwrap(),
}),
},
)?;
hugr.connect(input, 0, lift, 0)?;
hugr.connect(lift, 0, output, 0)?;
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
}

fn add_op_before(&mut self, sibling: Node, op: impl Into<OpType>) -> Result<Node, HugrError> {
let node = self.as_mut().add_node(NodeType::open_extensions(op));
let node = self.as_mut().add_op(op);
self.as_mut()
.hierarchy
.insert_before(node.index, sibling.index)?;
Expand Down

0 comments on commit 41c9c14

Please sign in to comment.