Skip to content

Commit

Permalink
style: add refurb and apply its changes, upgrade wasmtime (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
qartik authored Jan 24, 2024
1 parent 427c7e0 commit b3f8f34
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 24 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ repos:
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

- repo: https://github.com/dosisod/refurb
rev: v1.28.0
hooks:
- id: refurb

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.8.0'
hooks:
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ log_date_format = "%Y-%m-%d %H:%M:%S"

[tool.setuptools_scm]
version_scheme = "python-simplified-semver"

[tool.refurb]
python_version = "3.10"
3 changes: 1 addition & 2 deletions pytket/phir/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def pytket_to_phir(
machine = None

logger.debug("Sharding input circuit...")
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()

if machine:
# Only print message if a machine object is passed
Expand Down
2 changes: 1 addition & 1 deletion pytket/phir/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def main() -> None:
parser.add_argument(
"--version",
action="version",
version=f'{version("pytket-phir")}',
version=str(version("pytket-phir")),
)
args = parser.parse_args()

Expand Down
4 changes: 2 additions & 2 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ def make_comment_text(command: tk.Command, op: tk.Op) -> str:
case tk.WASMOp():
args, returns = extract_wasm_args_and_returns(command, op)
return f"WASM function={op.func_name} args={args} returns={returns}"
case _:
return str(command)

return str(command)


def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | int]]:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ pytket==1.24.0
ruff==0.1.14
setuptools_scm==8.0.4
sphinx==7.2.6
wasmtime==15.0.0
wasmtime==16.0.0
wheel==0.42.0
24 changes: 8 additions & 16 deletions tests/test_sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
class TestSharder:
def test_shard_hashing(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.baby)
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()

shard_set = set(shards)
assert len(shard_set) == 3
Expand Down Expand Up @@ -47,8 +46,7 @@ def test_should_op_create_shard(self) -> None:

def test_with_baby_circuit(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.baby)
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()

assert len(shards) == 3

Expand All @@ -70,8 +68,7 @@ def test_with_baby_circuit(self) -> None:

def test_rollup_behavior(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.baby_with_rollup)
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()

assert len(shards) == 5

Expand Down Expand Up @@ -106,8 +103,7 @@ def test_rollup_behavior(self) -> None:

def test_simple_conditional(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.simple_cond)
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()

assert len(shards) == 4

Expand Down Expand Up @@ -153,8 +149,7 @@ def test_simple_conditional(self) -> None:

def test_complex_barriers(self) -> None: # noqa: PLR0915
circuit = get_qasm_as_circuit(QasmFile.barrier_complex)
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()

assert len(shards) == 7

Expand Down Expand Up @@ -238,8 +233,7 @@ def test_complex_barriers(self) -> None: # noqa: PLR0915

def test_classical_hazards(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.classical_hazards)
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()

assert len(shards) == 5

Expand Down Expand Up @@ -286,8 +280,7 @@ def test_classical_hazards(self) -> None:

def test_with_big_gate(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.big_gate)
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()

assert len(shards) == 2

Expand All @@ -310,8 +303,7 @@ def test_with_big_gate(self) -> None:

def test_classical_ordering_breaking_circuit(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.classical_ordering)
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()

assert len(shards) == 4

Expand Down
3 changes: 1 addition & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def get_phir_json(qasmfile: QasmFile, *, rebase: bool) -> "JsonDict":
circuit = rebase_to_qtm_machine(circuit, qtm_machine.value, 0)
machine = QTM_MACHINES_MAP.get(qtm_machine)
assert machine
sharder = Sharder(circuit)
shards = sharder.shard()
shards = Sharder(circuit).shard()
placed = place_and_route(shards, machine)
return json.loads(genphir_parallel(placed, machine)) # type: ignore[misc, no-any-return]

Expand Down

0 comments on commit b3f8f34

Please sign in to comment.