Skip to content

Commit

Permalink
Test the ecc io encoding roundtrip
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Jul 22, 2024
1 parent 70eacd1 commit e89aa9c
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions tket2/src/rewrite/ecc_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl ECCRewriter {
#[cfg(feature = "binary-eccs")]
pub fn save_binary_io<W: io::Write>(
&self,
writer: &mut W,
writer: W,
) -> Result<(), RewriterSerialisationError> {
let mut encoder = zstd::Encoder::new(writer, 9)?;
rmp_serde::encode::write(&mut encoder, &self)?;
Expand All @@ -134,7 +134,7 @@ impl ECCRewriter {
///
/// Loads streams as created by [`ECCRewriter::save_binary_io`].
#[cfg(feature = "binary-eccs")]
pub fn load_binary_io<R: io::Read>(reader: &mut R) -> Result<Self, RewriterSerialisationError> {
pub fn load_binary_io<R: io::Read>(reader: R) -> Result<Self, RewriterSerialisationError> {
let data = zstd::decode_all(reader)?;
Ok(rmp_serde::decode::from_slice(&data)?)
}
Expand Down Expand Up @@ -389,4 +389,23 @@ mod tests {
let cx_cx = cx_cx();
assert_eq!(rewriter.get_rewrites(&cx_cx).len(), 1);
}

#[test]
#[cfg(feature = "binary-eccs")]
fn ecc_file_roundtrip() {
let ecc = EqCircClass::new(h_h(), vec![empty(), cx_cx()]);
let rewriter = ECCRewriter::from_eccs([ecc]);

let mut data: Vec<u8> = Vec::new();
rewriter.save_binary_io(&mut data).unwrap();
let loaded_rewriter = ECCRewriter::load_binary_io(data.as_slice()).unwrap();

assert_eq!(
rewriter.matcher.n_patterns(),
loaded_rewriter.matcher.n_patterns()
);
assert_eq!(rewriter.targets, loaded_rewriter.targets);
assert_eq!(rewriter.rewrite_rules, loaded_rewriter.rewrite_rules);
assert_eq!(rewriter.empty_wires, loaded_rewriter.empty_wires);
}
}

0 comments on commit e89aa9c

Please sign in to comment.