diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py index f73531e..99799ef 100644 --- a/pytket/phir/sharding/shard.py +++ b/pytket/phir/sharding/shard.py @@ -13,6 +13,9 @@ class Shard: we actually do placement of qubits """ + # The unique identifier of the shard + ID: int = field(default_factory=count().__next__, init=False) + # The "schedulable" command of the shard primary_command: Command @@ -20,35 +23,32 @@ class Shard: # as a map of bit-handle (unitID) -> list[Command] sub_commands: dict[UnitID, list[Command]] - # A set of the identifiers of other shards this particular shard depends upon - depends_upon: set[int] - # All qubits used by the primary and sub commands - qubits_used: set[Qubit] = field(init=False) + qubits_used: set[Qubit] # = field(init=False) # Set of all classical bits written to by the primary and sub commands - bits_written: set[Bit] = field(init=False) + bits_written: set[Bit] # = field(init=False) # Set of all classical bits read by the primary and sub commands - bits_read: set[Bit] = field(init=False) + bits_read: set[Bit] # = field(init=False) - # The unique identifier of the shard - ID: int = field(default_factory=count().__next__, init=False) + # A set of the identifiers of other shards this particular shard depends upon + depends_upon: set[int] - def __post_init__(self) -> None: - self.qubits_used = set(self.primary_command.qubits) - self.bits_written = set(self.primary_command.bits) - self.bits_read = set() + # def __post_init__(self) -> None: + # self.qubits_used = set(self.primary_command.qubits) + # self.bits_written = set(self.primary_command.bits) + # self.bits_read = set() - all_sub_commands: list[Command] = [] - for sub_commands in self.sub_commands.values(): - all_sub_commands.extend(sub_commands) + # all_sub_commands: list[Command] = [] + # for sub_commands in self.sub_commands.values(): + # all_sub_commands.extend(sub_commands) - for sub_command in all_sub_commands: - self.bits_written.update(sub_command.bits) - self.bits_read.update( - set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501 - ) + # for sub_command in all_sub_commands: + # self.bits_written.update(sub_command.bits) + # self.bits_read.update( + # set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501 + # ) def pretty_print(self) -> str: output = io.StringIO() diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 1455495..a7b7317 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -1,10 +1,19 @@ from pytket.circuit import Circuit, Command, Op, OpType -from pytket.unit_id import UnitID +from pytket.unit_id import Bit, UnitID from .shard import Shard NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM] +SHARD_TRIGGER_OP_TYPES = [ + OpType.Measure, + OpType.Reset, + OpType.Barrier, + OpType.SetBits, + OpType.ClassicalExpBox, # some classical operations are rolled up into a box + OpType.RangePredicate, +] + class Sharder: """ @@ -46,7 +55,9 @@ def _process_command(self, command: Command) -> None: raise NotImplementedError(msg) if self.should_op_create_shard(command.op): - print(f"Building shard for command: {command}") + print( + f"Building shard for command: {command} args:{command.args} bits:{command.bits}", + ) self._build_shard(command) else: self._add_pending_sub_command(command) @@ -63,6 +74,22 @@ def _build_shard(self, command: Command) -> None: ): sub_commands[key] = self._pending_commands.pop(key) + all_commands = [command] + for sub_command in sub_commands.values(): + all_commands.extend(sub_command) + + qubits_used = set(command.qubits) + + bits_written = set(command.bits) + + bits_read = set() + + for sub_command in all_commands: + bits_written.update(sub_command.bits) + bits_read.update( + set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501 + ) + # Handle dependency calculations depends_upon: set[int] = set() for shard in self._shards: @@ -72,9 +99,14 @@ def _build_shard(self, command: Command) -> None: depends_upon.add(shard.ID) # Check classical dependencies, which depend on writing and reading # hazards: RAW, WAW, WAR - # TODO: Do it! + elif not shard.bits_written.isdisjoint(bits_written): + depends_upon.add(shard.ID) + elif not shard.bits_read.isdisjoint(bits_written): + depends_upon.add(shard.ID) - shard = Shard(command, sub_commands, depends_upon) + shard = Shard( + command, sub_commands, qubits_used, bits_written, bits_read, depends_upon, + ) self._shards.append(shard) print("Appended shard:", shard) @@ -101,15 +133,20 @@ def _add_pending_sub_command(self, command: Command) -> None: if key not in self._pending_commands: self._pending_commands[key] = [] self._pending_commands[key].append(command) - print(f"Adding pending command {command}") + print( + f"Adding pending command {command} args: {command.args} bits: {command.bits}", + ) @staticmethod def should_op_create_shard(op: Op) -> bool: """ Returns `True` if the operation is one that should result in shard creation. This includes non-gate operations like measure/reset as well as 2-qubit gates. - TODO: This is almost certainly inadequate right now """ - return op.type in (OpType.Measure, OpType.Reset, OpType.Barrier) or ( - op.is_gate() and op.n_qubits > 1 + return ( + op.type in (SHARD_TRIGGER_OP_TYPES) + or ( + op.type == OpType.Conditional and op.op.type in (SHARD_TRIGGER_OP_TYPES) + ) + or (op.is_gate() and op.n_qubits > 1) ) diff --git a/tests/data/qasm/cond_classical.qasm b/tests/data/qasm/cond_classical.qasm new file mode 100644 index 0000000..2b3cb60 --- /dev/null +++ b/tests/data/qasm/cond_classical.qasm @@ -0,0 +1,24 @@ +OPENQASM 2.0; +include "hqslib1_dev.inc"; +qreg q[1]; +creg a[10]; +creg b[10]; +creg c[4]; +// classical assignment of registers +a[0] = 1; +a = 3; +// classical bitwise functions +a = 1; +b = 3; +c = a ^ b; // XOR +// evaluating a beyond creg == int +a = 1; +b = 2; +if(a[0]==1) x q[0]; +if(a!=1) x q[0]; +if(a>1) x q[0]; +if(a<1) x q[0]; +if(a>=1) x q[0]; +if(a<=1) x q[0]; +if (a==10) b=1; +measure q[0] -> c[0]; diff --git a/tests/data/qasm/simple_cond.qasm b/tests/data/qasm/simple_cond.qasm index d590bb6..30a1b95 100644 --- a/tests/data/qasm/simple_cond.qasm +++ b/tests/data/qasm/simple_cond.qasm @@ -3,9 +3,11 @@ include "hqslib1.inc"; qreg q[1]; creg c[1]; +creg z[1]; h q; measure q->c; reset q; if (c==1) h q; +if (c==1) z=3; measure q->c; diff --git a/tests/sample_data.py b/tests/sample_data.py index 4e7d3c5..a6b15db 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -11,6 +11,7 @@ class QasmFiles(Enum): baby = 4 baby_with_rollup = 5 simple_cond = 6 + cond_classical = 7 def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 82011d0..57ee256 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -13,6 +13,7 @@ def test_should_op_create_shard(self) -> None: Op.create(OpType.Reset), # type: ignore # noqa: PGH003 Op.create(OpType.CX), # type: ignore # noqa: PGH003 Op.create(OpType.Barrier), # type: ignore # noqa: PGH003 + # Op.create(OpType.SetBits, [3, 1]), ] expected_false: list[Op] = [ Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003 @@ -86,21 +87,82 @@ def test_simple_conditional(self) -> None: sharder = Sharder(circuit) shards = sharder.shard() - assert len(shards) == 3 + assert len(shards) == 4 + # shard 0: h q; measure q->c; assert shards[0].primary_command.op.type == OpType.Measure assert len(shards[0].sub_commands.items()) == 1 s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items())) assert s0_qubit == circuit.qubits[0] assert s0_sub_cmds[0].op.type == OpType.H + assert shards[0].depends_upon == set() + # shard 1: reset q; assert shards[1].primary_command.op.type == OpType.Reset assert len(shards[1].sub_commands.items()) == 0 + assert shards[1].depends_upon == {shards[0].ID} - assert shards[2].primary_command.op.type == OpType.Measure - assert len(shards[2].sub_commands.items()) == 1 - s2_qubit, s2_sub_cmds = next(iter(shards[2].sub_commands.items())) + # shard 2: if (c==1) z=3; + assert shards[2].primary_command.op.type == OpType.Conditional + assert cast(Conditional, shards[2].primary_command).op.op.type == OpType.SetBits + assert len(shards[2].sub_commands.keys()) == 0 + assert shards[2].depends_upon == {shards[0].ID} + + # shard 3: if (c==1) h q; measure q->c; + assert shards[3].primary_command.op.type == OpType.Measure + assert len(shards[3].sub_commands.items()) == 1 + s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items())) assert s2_qubit == circuit.qubits[0] assert s2_sub_cmds[0].op.type == OpType.Conditional assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H assert s2_sub_cmds[0].qubits == [circuit.qubits[0]] + + def test_classical_with_conditionals(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.cond_classical) + sharder = Sharder(circuit) + shards = sharder.shard() + + circuit.get_commands() + + # assert len(shards) == 10 # TODO: fix with correct value + + # shard 0: a[0] = 1; + assert shards[0].primary_command.op.type == OpType.SetBits + assert len(shards[0].sub_commands.keys()) == 0 + assert shards[0].depends_upon == set() + assert shards[0].qubits_used == set() + assert shards[0].bits_read == set() + assert shards[0].bits_written == {circuit.bits[0]} + + # shard 1: a = 3; + assert shards[1].primary_command.op.type == OpType.SetBits + assert len(shards[1].sub_commands.keys()) == 0 + assert shards[1].depends_upon == {shards[0].ID} # WAW for shard 0 + assert shards[1].qubits_used == set() + assert shards[1].bits_read == set() + assert len(shards[1].bits_written) == 10 # TODO: Check for a[0-9] + + # shard 2: a = 1; + assert shards[2].primary_command.op.type == OpType.SetBits + assert len(shards[2].sub_commands.keys()) == 0 + assert shards[2].depends_upon == { + shards[0].ID, + shards[1].ID, + } # WAW for shard 0, 1 + assert shards[2].qubits_used == set() + assert shards[2].bits_read == set() + assert len(shards[2].bits_written) == 10 # TODO: Check for a[0-9] + + # shard 3: b = 3; + assert shards[3].primary_command.op.type == OpType.SetBits + assert len(shards[3].sub_commands.keys()) == 0 + assert shards[3].depends_upon == set() + assert shards[3].qubits_used == set() + assert shards[3].bits_read == set() + assert len(shards[3].bits_written) == 10 # TODO: Check for b[0-9] + + # shard 4: c = a ^ b; // XOR + assert shards[4].primary_command.op.type == OpType.ClassicalExpBox + assert len(shards[3].sub_commands.keys()) == 0 + assert shards[3].depends_upon == set() + assert shards[4].qubits_used == set()