Skip to content

Commit

Permalink
Rewrite takes any HugrMut (TransactionalRewrite copies twice, ick, ne…
Browse files Browse the repository at this point in the history
…eds test)
  • Loading branch information
acl-cqc committed Jul 26, 2023
1 parent b6ef735 commit b6b1cbd
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
30 changes: 20 additions & 10 deletions src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
pub mod outline_cfg;
pub mod simple_replace;
use std::mem;

use crate::Hugr;
use crate::{Hugr, HugrView};
pub use simple_replace::{SimpleReplacement, SimpleReplacementError};

use super::HugrMut;

/// An operation that can be applied to mutate a Hugr
pub trait Rewrite {
/// The type of Error with which this Rewrite may fail
Expand All @@ -19,7 +20,7 @@ pub trait Rewrite {
/// Checks whether the rewrite would succeed on the specified Hugr.
/// If this call succeeds, [self.apply] should also succeed on the same `h`
/// If this calls fails, [self.apply] would fail with the same error.
fn verify(&self, h: &Hugr) -> Result<(), Self::Error>;
fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>;

/// Mutate the specified Hugr, or fail with an error.
/// If [self.unchanged_on_failure] is true, then `h` must be unchanged if Err is returned.
Expand All @@ -28,7 +29,7 @@ pub trait Rewrite {
/// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is,
/// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())`
/// being preferred.
fn apply(self, h: &mut Hugr) -> Result<(), Self::Error>;
fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error>;
}

/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure)
Expand All @@ -42,20 +43,29 @@ impl<R: Rewrite> Rewrite for Transactional<R> {
type Error = R::Error;
const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &Hugr) -> Result<(), Self::Error> {
fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
self.underlying.verify(h)
}

fn apply(self, h: &mut Hugr) -> Result<(), Self::Error> {
fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> {
if R::UNCHANGED_ON_FAILURE {
return self.underlying.apply(h);
}
let backup = h.clone();
// Try to backup just the contents of this HugrMut.
let mut backup = Hugr::new(h.root_type().clone());
backup.insert_from_view(backup.root(), h).unwrap();
let r = self.underlying.apply(h);
fn first_child(h: &impl HugrView) -> Option<crate::Node> {
h.children(h.root()).next()
}
if r.is_err() {
// drop the old h, it was undefined
let _ = mem::replace(h, backup);
// Try to restore backup.
h.replace_op(h.root(), backup.root_type().clone());
while let Some(child) = first_child(h) {
h.remove_node(child).unwrap();
}
h.insert_from_view(h.root(), &backup).unwrap();
}
r
Ok(())
}
}
15 changes: 10 additions & 5 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
use crate::hugr::rewrite::Rewrite;
use crate::hugr::{HugrMut, HugrView};
use crate::ops::{BasicBlock, ConstValue, OpTag, OpTrait, OpType};
use crate::{type_row, Hugr, Node};
use crate::{type_row, Node};

/// Moves part of a Control-flow Sibling Graph into a new CFG-node
/// that is the only child of a new Basic Block in the original CSG.
Expand All @@ -24,7 +24,10 @@ impl OutlineCfg {
}
}

fn compute_entry_exit_outside(&self, h: &Hugr) -> Result<(Node, Node, Node), OutlineCfgError> {
fn compute_entry_exit_outside(
&self,
h: &impl HugrView,
) -> Result<(Node, Node, Node), OutlineCfgError> {
let cfg_n = match self
.blocks
.iter()
Expand Down Expand Up @@ -81,11 +84,11 @@ impl OutlineCfg {
impl Rewrite for OutlineCfg {
type Error = OutlineCfgError;
const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, h: &Hugr) -> Result<(), OutlineCfgError> {
fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> {
self.compute_entry_exit_outside(h)?;
Ok(())
}
fn apply(self, h: &mut Hugr) -> Result<(), OutlineCfgError> {
fn apply(self, h: &mut impl HugrMut) -> Result<(), OutlineCfgError> {
let (entry, exit, outside) = self.compute_entry_exit_outside(h)?;
// 1. Compute signature
// These panic()s only happen if the Hugr would not have passed validate()
Expand Down Expand Up @@ -124,12 +127,13 @@ impl Rewrite for OutlineCfg {
.children(new_block)
.filter(|n| h.get_optype(*n).tag() == OpTag::Cfg)
.exactly_one()
.ok() // HugrMut::Children is not Debug
.unwrap();
let inner_exit = h.children(cfg_node).next().unwrap();

// 4. Entry edges. Change any edges into entry_block from outside, to target new_block
let preds: Vec<_> = h
.linked_ports(entry, h.node_inputs(entry).exactly_one().unwrap())
.linked_ports(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
.collect();
for (pred, br) in preds {
if !self.blocks.contains(&pred) {
Expand Down Expand Up @@ -164,6 +168,7 @@ impl Rewrite for OutlineCfg {
t == outside
})
.exactly_one()
.ok() // NodePorts does not implement Debug
.unwrap();
h.disconnect(exit, exit_port).unwrap();
h.connect(exit, exit_port.index(), inner_exit, 0).unwrap();
Expand Down
6 changes: 4 additions & 2 deletions src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ impl Rewrite for SimpleReplacement {
type Error = SimpleReplacementError;
const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, _h: &Hugr) -> Result<(), SimpleReplacementError> {
fn verify(&self, _h: &impl HugrView) -> Result<(), SimpleReplacementError> {
unimplemented!()
}

fn apply(self, h: &mut Hugr) -> Result<(), SimpleReplacementError> {
fn apply(self, h: &mut impl HugrMut) -> Result<(), SimpleReplacementError> {
// 1. Check the parent node exists and is a DFG node.
if h.get_optype(self.parent).tag() != OpTag::Dfg {
return Err(SimpleReplacementError::InvalidParentNode());
Expand Down Expand Up @@ -123,6 +123,7 @@ impl Rewrite for SimpleReplacement {
let (rem_inp_pred_node, rem_inp_pred_port) = h
.linked_ports(*rem_inp_node, *rem_inp_port)
.exactly_one()
.ok() // PortLinks does not implement Debug
.unwrap();
h.disconnect(*rem_inp_node, *rem_inp_port).unwrap();
let new_inp_node = index_map.get(rep_inp_node).unwrap();
Expand Down Expand Up @@ -164,6 +165,7 @@ impl Rewrite for SimpleReplacement {
let (rem_inp_pred_node, rem_inp_pred_port) = h
.linked_ports(*rem_inp_node, *rem_inp_port)
.exactly_one()
.ok() // PortLinks does not implement Debug
.unwrap();
h.disconnect(*rem_inp_node, *rem_inp_port).unwrap();
h.disconnect(*rem_out_node, *rem_out_port).unwrap();
Expand Down

0 comments on commit b6b1cbd

Please sign in to comment.