Skip to content

Commit

Permalink
Update vm build. (apache#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuchenJin authored and yongwww committed Aug 14, 2022
1 parent 93cb308 commit f16f3a5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 40 deletions.
29 changes: 20 additions & 9 deletions python/tvm/relax/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,26 +138,22 @@ def __getitem__(self, key: str) -> PackedFunc:
return self.module[key]


def build(mod: tvm.IRModule,
target: tvm.target.Target,
target_host: tvm.target.Target) -> Tuple[Executable, Module]:
def build(mod: tvm.IRModule, target: tvm.target.Target) -> Tuple[Executable, Module]:
"""
Build an IRModule to VM executable.
Parameters
----------
mod: IRModule
The IR module.
The input IRModule to be built.
target : tvm.target.Target
A build target.
A build target which can have optional host side compilation target.
target_host : tvm.target.Target
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
Expand All @@ -167,6 +163,20 @@ def build(mod: tvm.IRModule,
An executable that can be loaded by virtual machine.
lib: tvm.runtime.Module
A runtime module that contains generated code.
Example
-------
.. code-block:: python
class InputModule:
@R.function
def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]):
z = R.add(x, y)
return z
mod = InputModule
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
"""
passes = [relax.transform.ToNonDataflow()]
passes.append(relax.transform.CallDPSRewrite())
Expand All @@ -178,10 +188,11 @@ def build(mod: tvm.IRModule,
# split primfunc and relax function
rx_mod, tir_mod = _split_tir_relax(new_mod)

lib = tvm.build(tir_mod, target, target_host)
lib = tvm.build(tir_mod, target)
ex = _ffi_api.VMCodeGen(rx_mod)
return ex, lib


def _split_tir_relax(mod: tvm.IRModule) -> Tuple[tvm.IRModule, tvm.IRModule]:
rx_mod = IRModule({})
tir_mod = IRModule({})
Expand Down
51 changes: 20 additions & 31 deletions tests/python/relax/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,8 @@ def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]):
return y

mod = TestVMCompileStage0
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
inp1 = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
inp2 = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
Expand Down Expand Up @@ -283,10 +282,8 @@ def foo(x: Tensor[_, "float32"]) -> Shape:
return gv3

mod = TestVMCompileStage1
code = R.parser.astext(mod)
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

shape = (32, 16)
Expand All @@ -305,9 +302,8 @@ def foo(x: Tensor[_, "float32"]) -> Shape:
return (n * 2, m * 3)

mod = TestVMCompileStage2
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

shape = (32, 16)
Expand All @@ -328,9 +324,8 @@ def foo(x: Tensor[(32, 16), "float32"]) -> Tensor:
return y

mod = TestVMCompileStage3
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

shape = (32, 16)
Expand All @@ -352,9 +347,8 @@ def foo(x: Tensor[_, "float32"]) -> Tensor:

mod = TestVMCompileE2E

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

shape = (32, 16)
Expand Down Expand Up @@ -390,9 +384,8 @@ def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor:

mod = TestVMCompileE2E2

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
Expand All @@ -415,9 +408,8 @@ def test_vm_emit_te_extern():

mod = bb.get()

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
Expand All @@ -444,9 +436,8 @@ def te_func(A, B):

mod = bb.get()

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)

vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
inp = tvm.nd.array(np.random.rand(1, ).astype(np.float32))
Expand All @@ -471,9 +462,8 @@ def te_func(A):

mod = bb.get()

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)

vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
shape = (9, )
Expand Down Expand Up @@ -503,9 +493,8 @@ def te_func(A, B):

mod = bb.get()

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)

vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
shape1 = (5, )
Expand Down

0 comments on commit f16f3a5

Please sign in to comment.