Skip to content

Commit

Permalink
fix: update wasm functions to accept WasmModuleHandler (#1613)
Browse files Browse the repository at this point in the history
  • Loading branch information
qartik authored Oct 10, 2024
1 parent 8dcc370 commit 4c3d1d1
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 8 deletions.
8 changes: 4 additions & 4 deletions pytket/pytket/_tket/circuit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1965,14 +1965,14 @@ class Circuit:
:param qubit_2: index of target qubit 2
:return: the new :py:class:`Circuit`
"""
def add_wasm(self, funcname: str, filehandler: pytket.wasm.wasm.WasmFileHandler, list_i: typing.Sequence[int], list_o: typing.Sequence[int], args: typing.Union[typing.Sequence[int], typing.Sequence[pytket._tket.unit_id.Bit]], args_wasm: typing.Optional[typing.Sequence[int]] = None, **kwargs: typing.Any) -> Circuit:
def add_wasm(self, funcname: str, filehandler: pytket.wasm.wasm.WasmModuleHandler, list_i: typing.Sequence[int], list_o: typing.Sequence[int], args: typing.Union[typing.Sequence[int], typing.Sequence[pytket._tket.unit_id.Bit]], args_wasm: typing.Optional[typing.Sequence[int]] = None, **kwargs: typing.Any) -> Circuit:
"""
Add a classical function call from a wasm file to the circuit.
:param funcname: name of the function that is called
:param filehandler: wasm file handler to identify the wasm file
:param filehandler: wasm file or module handler to identify the wasm module
:param list_i: list of the number of bits in the input variables
Expand All @@ -1988,14 +1988,14 @@ class Circuit:
:return: the new :py:class:`Circuit`
"""
def add_wasm_to_reg(self, funcname: str, filehandler: pytket.wasm.wasm.WasmFileHandler, list_i: typing.Sequence[pytket._tket.unit_id.BitRegister], list_o: typing.Sequence[pytket._tket.unit_id.BitRegister], args_wasm: typing.Optional[typing.Sequence[int]] = None, **kwargs: typing.Any) -> Circuit:
def add_wasm_to_reg(self, funcname: str, filehandler: pytket.wasm.wasm.WasmModuleHandler, list_i: typing.Sequence[pytket._tket.unit_id.BitRegister], list_o: typing.Sequence[pytket._tket.unit_id.BitRegister], args_wasm: typing.Optional[typing.Sequence[int]] = None, **kwargs: typing.Any) -> Circuit:
"""
Add a classical function call from a wasm file to the circuit.
:param funcname: name of the function that is called
:param filehandler: wasm file handler to identify the wasm file
:param filehandler: wasm file or module handler to identify the wasm module
:param list_i: list of the classical registers assigned to
the input variables of the function call
Expand Down
8 changes: 4 additions & 4 deletions pytket/pytket/circuit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
def overload_add_wasm(
self: Circuit,
funcname: str,
filehandler: wasm.WasmFileHandler,
filehandler: wasm.WasmModuleHandler,
list_i: Sequence[int],
list_o: Sequence[int],
args: Union[Sequence[int], Sequence[Bit]],
Expand All @@ -72,7 +72,7 @@ def overload_add_wasm(
) -> Circuit:
"""Add a classical function call from a wasm file to the circuit.
\n\n:param funcname: name of the function that is called
\n:param filehandler: wasm file handler to identify the wasm file
\n:param filehandler: wasm file or module handler to identify the wasm module
\n:param list_i: list of the number of bits in the input variables
\n:param list_o: list of the number of bits in the output variables
\n:param args: vector of circuit bits the wasm op should be added to
Expand Down Expand Up @@ -113,15 +113,15 @@ def overload_add_wasm(
def overload_add_wasm_to_reg(
self: Circuit,
funcname: str,
filehandler: wasm.WasmFileHandler,
filehandler: wasm.WasmModuleHandler,
list_i: Sequence[BitRegister],
list_o: Sequence[BitRegister],
args_wasm: Optional[Sequence[int]] = None,
**kwargs: Any,
) -> Circuit:
"""Add a classical function call from a wasm file to the circuit.
\n\n:param funcname: name of the function that is called
\n:param filehandler: wasm file handler to identify the wasm file
\n:param filehandler: wasm file or module handler to identify the wasm module
\n:param list_i: list of the classical registers assigned to
the input variables of the function call
\n:param list_o: list of the classical registers assigned to
Expand Down
1 change: 1 addition & 0 deletions pytket/pytket/wasm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@

from .wasm import (
WasmFileHandler,
WasmModuleHandler,
)
25 changes: 25 additions & 0 deletions pytket/tests/classical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,18 @@ def test_wasm_3() -> None:
assert c.depth() == 1


def test_wasm_3_bytes() -> None:
with open("testfile.wasm", "rb") as f:
bytecode = f.read()
c = Circuit(0, 6)

w = wasm.WasmModuleHandler(bytecode)

c.add_wasm("add_one", w, [1], [1], [Bit(0), Bit(1)])

assert c.depth() == 1


def test_wasm_4() -> None:
w = wasm.WasmFileHandler("testfile.wasm")

Expand Down Expand Up @@ -412,6 +424,19 @@ def test_wasm_function_check_6() -> None:
assert c.depth() == 1


def test_wasm_function_check_6_bytes() -> None:
with open("testfile.wasm", "rb") as f:
bytecode = f.read()

w = wasm.WasmModuleHandler(bytecode)
c = Circuit(20, 20)
c0 = c.add_c_register("c0", 32)
c1 = c.add_c_register("c1", 4)

c.add_wasm_to_reg("add_one", w, [c0], [c1])
assert c.depth() == 1


def test_wasm_function_check_7() -> None:
w = wasm.WasmFileHandler("testfile.wasm", int_size=32)
c = Circuit(20, 20)
Expand Down

0 comments on commit 4c3d1d1

Please sign in to comment.