diff --git a/tests/passes/control/test_foreachblock.py b/tests/passes/control/test_foreachblock.py index a0578bcfb..408d80b68 100644 --- a/tests/passes/control/test_foreachblock.py +++ b/tests/passes/control/test_foreachblock.py @@ -17,7 +17,6 @@ from bqskit.passes import UnfoldPass from bqskit.passes.control.foreach import ForEachBlockPass from bqskit.passes.partitioning.quick import QuickPartitioner -from bqskit.passes.synthesis.qsearch import QSearchSynthesisPass from bqskit.passes.util.update import UpdateDataPass @@ -52,6 +51,18 @@ async def run(self, circuit: Circuit, data: PassData) -> None: circuit.append_gate(HGate(), 1) +class TestForEachPassDownSpecificSeeds(BasePass): + async def run(self, circuit: Circuit, data: PassData) -> None: + key = ForEachBlockPass.pass_down_block_specific_key_prefix + key += 'test_info' + assert key in data + assert data[key] == 'a' or data[key] == 'b' + if data[key] == 'a': + data['test_response'] = 'a' + else: + data['test_response'] = 'b' + + def never_replace(c: Circuit, o: Operation) -> bool: return False @@ -144,38 +155,28 @@ def test_pass_down_seeds(compiler: Compiler) -> None: circuit.append_gate(CNOTGate(), (1, 3)) circuit.append_gate(CNOTGate(), (2, 3)) - seed = Circuit(3) - seed.append_gate(CNOTGate(), (0, 2)) - seed.append_gate(CNOTGate(), (1, 2)) - # Manually set seed for blocks 0 and 1 - seeds = {0: [seed], 1: [seed]} + input_info = {0: 'a', 1: 'b'} partitioner = QuickPartitioner() - qsearch = QSearchSynthesisPass() - foreach = ForEachBlockPass(qsearch) - unfolder = UnfoldPass() + check_specific = TestForEachPassDownSpecificSeeds() + foreach = ForEachBlockPass(check_specific) - key = foreach.pass_down_block_specific_key_prefix + 'seed_circuits' - updater = UpdateDataPass(key, seeds) + key = foreach.pass_down_block_specific_key_prefix + 'test_info' + updater = UpdateDataPass(key, input_info) # For checking specific pass down data exists - workflow_1 = Workflow([partitioner, updater]) - # For checking specific pass down data can be used - workflow_2 = Workflow([partitioner, updater, foreach, unfolder]) + workflow = Workflow([partitioner, updater, foreach]) - # Check existence of block specific keys - partitioned, data = compiler.compile( - circuit, workflow_1, request_data=True, - ) - for i, block in enumerate(partitioned): - data[key][i] == seeds[i] - - # Check that block specific data is usable + # Check usability of block specific keys compiled, data = compiler.compile( - circuit, workflow_2, request_data=True, + circuit, workflow, request_data=True, ) - dist = compiled.get_unitary().get_distance_from(circuit.get_unitary()) - assert dist <= 1e-5 - assert key in data - assert data[key] == seeds + + block0_data = data['ForEachBlockPass_data'][0][0] + block1_data = data['ForEachBlockPass_data'][0][1] + response_key = 'test_response' + assert response_key in block0_data + assert response_key in block1_data + assert block0_data[response_key] == 'a' + assert block1_data[response_key] == 'b'