diff --git a/Cargo.lock b/Cargo.lock index 561972b8..2532d40f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1046,7 +1046,7 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quil-py" -version = "0.5.6" +version = "0.5.8" dependencies = [ "ndarray", "numpy", @@ -1059,7 +1059,7 @@ dependencies = [ [[package]] name = "quil-rs" -version = "0.21.5" +version = "0.21.7" dependencies = [ "approx", "clap", diff --git a/quil-rs/src/program/memory.rs b/quil-rs/src/program/memory.rs index eb95828c..1caec4c2 100644 --- a/quil-rs/src/program/memory.rs +++ b/quil-rs/src/program/memory.rs @@ -17,10 +17,10 @@ use std::collections::HashSet; use crate::expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression}; use crate::instruction::{ Arithmetic, ArithmeticOperand, BinaryLogic, BinaryOperand, Capture, CircuitDefinition, - Comparison, ComparisonOperand, Delay, Exchange, Gate, GateDefinition, GateSpecification, - Instruction, JumpUnless, JumpWhen, Load, MeasureCalibrationDefinition, Measurement, - MemoryReference, Move, Pulse, RawCapture, SetPhase, SetScale, Sharing, ShiftPhase, Store, - UnaryLogic, Vector, WaveformInvocation, + Comparison, ComparisonOperand, Convert, Delay, Exchange, Gate, GateDefinition, + GateSpecification, Instruction, JumpUnless, JumpWhen, Load, MeasureCalibrationDefinition, + Measurement, MemoryReference, Move, Pulse, RawCapture, SetFrequency, SetPhase, SetScale, + Sharing, ShiftFrequency, ShiftPhase, Store, UnaryLogic, Vector, WaveformInvocation, }; #[derive(Clone, Debug, Hash, PartialEq)] @@ -96,6 +96,14 @@ impl Instruction { /// Return all memory accesses by the instruction - in expressions, captures, and memory manipulation pub fn get_memory_accesses(&self) -> MemoryAccesses { match self { + Instruction::Convert(Convert { + source, + destination, + }) => MemoryAccesses { + reads: set_from_memory_references![[source]], + writes: set_from_memory_references![[destination]], + ..Default::default() + }, Instruction::Comparison(Comparison { operands, .. }) => { let mut reads = HashSet::from([operands.1.name.clone()]); let writes = HashSet::from([operands.0.name.clone()]); @@ -183,6 +191,7 @@ impl Instruction { ..Default::default() }, Instruction::Exchange(Exchange { left, right }) => MemoryAccesses { + reads: set_from_memory_references![[left, right]], writes: set_from_memory_references![[left, right]], ..Default::default() }, @@ -250,17 +259,27 @@ impl Instruction { reads: set_from_memory_references!(expr.get_memory_references()), ..Default::default() }, + Instruction::SetFrequency(SetFrequency { frequency, .. }) + | Instruction::ShiftFrequency(ShiftFrequency { frequency, .. }) => MemoryAccesses { + reads: set_from_memory_references!(frequency.get_memory_references()), + ..Default::default() + }, Instruction::Store(Store { destination, - offset: _, + offset, source, - }) => MemoryAccesses { - reads: set_from_optional_memory_reference!(source.get_memory_reference()), - writes: set_from_reference_vec![vec![destination]], - ..Default::default() - }, - Instruction::Convert(_) - | Instruction::Declaration(_) + }) => { + let mut reads = vec![&offset.name]; + if let Some(source) = source.get_memory_reference() { + reads.push(&source.name); + } + MemoryAccesses { + reads: set_from_reference_vec![reads], + writes: set_from_reference_vec![vec![destination]], + ..Default::default() + } + } + Instruction::Declaration(_) | Instruction::Fence(_) | Instruction::FrameDefinition(_) | Instruction::Halt @@ -271,8 +290,6 @@ impl Instruction { | Instruction::Nop | Instruction::Pragma(_) | Instruction::Reset(_) - | Instruction::SetFrequency(_) - | Instruction::ShiftFrequency(_) | Instruction::SwapPhases(_) | Instruction::WaveformDefinition(_) => Default::default(), } @@ -321,3 +338,113 @@ impl WaveformInvocation { .collect() } } + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use crate::expression::Expression; + use crate::instruction::{ + ArithmeticOperand, Convert, Exchange, FrameIdentifier, Instruction, MemoryReference, Qubit, + SetFrequency, ShiftFrequency, Store, + }; + use crate::program::MemoryAccesses; + use std::collections::HashSet; + + #[rstest] + #[case( + Instruction::Store(Store { + destination: "destination".to_string(), + offset: MemoryReference { + name: "offset".to_string(), + index: Default::default() + }, + source: ArithmeticOperand::MemoryReference(MemoryReference { + name: "source".to_string(), + index: Default::default() + }), + }), + MemoryAccesses { + captures: HashSet::new(), + reads: ["source", "offset"].iter().cloned().map(String::from).collect(), + writes: ["destination"].iter().cloned().map(String::from).collect(), + } + )] + #[case( + Instruction::Convert(Convert { + destination: MemoryReference { + name: "destination".to_string(), + index: Default::default() + }, + source: MemoryReference { + name: "source".to_string(), + index: Default::default() + }, + }), + MemoryAccesses { + captures: HashSet::new(), + reads: ["source"].iter().cloned().map(String::from).collect(), + writes: ["destination"].iter().cloned().map(String::from).collect(), + } + )] + #[case( + Instruction::Exchange(Exchange { + left: MemoryReference { + name: "left".to_string(), + index: Default::default() + }, + right: MemoryReference { + name: "right".to_string(), + index: Default::default() + }, + }), + MemoryAccesses { + captures: HashSet::new(), + reads: ["left", "right"].iter().cloned().map(String::from).collect(), + writes: ["left", "right"].iter().cloned().map(String::from).collect(), + } + )] + #[case( + Instruction::SetFrequency(SetFrequency { + frequency: Expression::Address(MemoryReference { + name: "frequency".to_string(), + index: Default::default() + }), + frame: FrameIdentifier { + name: "frame".to_string(), + qubits: vec![Qubit::Fixed(0)] + } + }), + MemoryAccesses { + captures: HashSet::new(), + reads: ["frequency"].iter().cloned().map(String::from).collect(), + writes: HashSet::new(), + } + )] + #[case( + Instruction::ShiftFrequency(ShiftFrequency { + frequency: Expression::Address(MemoryReference { + name: "frequency".to_string(), + index: Default::default() + }), + frame: FrameIdentifier { + name: "frame".to_string(), + qubits: vec![Qubit::Fixed(0)] + } + }), + MemoryAccesses { + captures: HashSet::new(), + reads: ["frequency"].iter().cloned().map(String::from).collect(), + writes: HashSet::new(), + } + )] + fn test_instruction_accesses( + #[case] instruction: Instruction, + #[case] expected: MemoryAccesses, + ) { + let memory_accesses = instruction.get_memory_accesses(); + assert_eq!(memory_accesses.captures, expected.captures); + assert_eq!(memory_accesses.reads, expected.reads); + assert_eq!(memory_accesses.writes, expected.writes); + } +}