Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Sep 26, 2021
1 parent 1f38801 commit a0c3cf1
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 53 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class LinkedParam : public ObjectRef {
*
* \code
* @T.prim_func
* def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: ty.int32) -> None:
* def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
* A = T.match_buffer(a, (m, n), "float32")
* B = T.match_buffer(b, (m, n), "float32")
*
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
.. code-block:: python
@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: ty.int32, n: ty.int32) -> None:
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
A = T.match_buffer(a, (m, n), "float32")
B = T.match_buffer(b, (m, n), "float32")
Expand Down
2 changes: 1 addition & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) {
for (Type field : node->fields) {
fields.push_back(Print(field));
}
return Doc::Text("ty.Tuple[") << Doc::Concat(fields) << "]";
return Doc::Text("T.Tuple[") << Doc::Concat(fields) << "]";
}
}

Expand Down
75 changes: 39 additions & 36 deletions tests/python/unittest/test_meta_schedule_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import pytest

from tvm import tir, script
from tvm import script
from tvm._ffi import register_func
from tvm.meta_schedule.builder import (
BuilderInput,
Expand All @@ -32,55 +32,58 @@
PyBuilder,
)
from tvm.runtime import Module
from tvm.script import ty
from tvm.script import tir as T
from tvm.target import Target


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring


@script.tir
@script.ir_module
class MatmulModule:
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument
tir.func_attr({"global_symbol": "matmul", "tir.noalias": True})
A = tir.match_buffer(a, (1024, 1024), "float32")
B = tir.match_buffer(b, (1024, 1024), "float32")
C = tir.match_buffer(c, (1024, 1024), "float32")
with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
with tir.init():
T.func_attr({"global_symbol": "matmul", "tir.noalias": True})
A = T.match_buffer(a, (1024, 1024), "float32")
B = T.match_buffer(b, (1024, 1024), "float32")
C = T.match_buffer(c, (1024, 1024), "float32")
with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@script.tir
@script.ir_module
class MatmulReluModule:
@T.prim_func
def matmul_relu( # pylint: disable=no-self-argument
a: T.handle, b: T.handle, d: T.handle
) -> None:
tir.func_attr({"global_symbol": "matmul_relu", "tir.noalias": True})
A = tir.match_buffer(a, (1024, 1024), "float32")
B = tir.match_buffer(b, (1024, 1024), "float32")
D = tir.match_buffer(d, (1024, 1024), "float32")
C = tir.alloc_buffer((1024, 1024), "float32")
with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
with tir.init():
T.func_attr({"global_symbol": "matmul_relu", "tir.noalias": True})
A = T.match_buffer(a, (1024, 1024), "float32")
B = T.match_buffer(b, (1024, 1024), "float32")
D = T.match_buffer(d, (1024, 1024), "float32")
C = T.alloc_buffer((1024, 1024), "float32")
with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
with tir.block([1024, 1024], "relu") as [vi, vj]:
D[vi, vj] = tir.max(C[vi, vj], 0.0)
with T.block([1024, 1024], "relu") as [vi, vj]:
D[vi, vj] = T.max(C[vi, vj], 0.0)


@script.tir
@script.ir_module
class BatchMatmulModule:
@T.prim_func
def batch_matmul( # pylint: disable=no-self-argument
a: T.handle, b: T.handle, c: T.handle
) -> None:
tir.func_attr({"global_symbol": "batch_matmul", "tir.noalias": True})
A = tir.match_buffer(a, [16, 128, 128])
B = tir.match_buffer(b, [16, 128, 128])
C = tir.match_buffer(c, [16, 128, 128])
with tir.block([16, 128, 128, tir.reduce_axis(0, 128)], "update") as [vn, vi, vj, vk]:
with tir.init():
T.func_attr({"global_symbol": "batch_matmul", "tir.noalias": True})
A = T.match_buffer(a, [16, 128, 128])
B = T.match_buffer(b, [16, 128, 128])
C = T.match_buffer(c, [16, 128, 128])
with T.block([16, 128, 128, T.reduce_axis(0, 128)], "update") as [vn, vi, vj, vk]:
with T.init():
C[vn, vi, vj] = 0.0
C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]

Expand All @@ -101,7 +104,7 @@ def _check_build_results(builder_results: List[BuilderResult]):

def test_meta_schedule_single_build():
"""Test meta schedule builder for a single build"""
mod = MatmulModule()
mod = MatmulModule
builder = LocalBuilder()
builder_inputs = [BuilderInput(mod, Target("llvm"))]
builder_results = builder.build(builder_inputs)
Expand All @@ -113,9 +116,9 @@ def test_meta_schedule_multiple_build():
"""Test meta schedule builder for multiple builds"""
builder = LocalBuilder()
builder_inputs = [
BuilderInput(MatmulModule(), Target("llvm")),
BuilderInput(MatmulReluModule(), Target("llvm")),
BuilderInput(BatchMatmulModule(), Target("llvm")),
BuilderInput(MatmulModule, Target("llvm")),
BuilderInput(MatmulReluModule, Target("llvm")),
BuilderInput(BatchMatmulModule, Target("llvm")),
]
builder_results = builder.build(builder_inputs)
assert len(builder_results) == len(builder_inputs)
Expand All @@ -134,9 +137,9 @@ def build( # pylint: disable=no-self-use

builder = TestBuilder()
builder_inputs = [
BuilderInput(MatmulModule(), Target("llvm")),
BuilderInput(MatmulReluModule(), Target("llvm")),
BuilderInput(BatchMatmulModule(), Target("llvm")),
BuilderInput(MatmulModule, Target("llvm")),
BuilderInput(MatmulReluModule, Target("llvm")),
BuilderInput(BatchMatmulModule, Target("llvm")),
]
builder_results = builder.build(builder_inputs)
assert len(builder_results) == len(builder_inputs)
Expand All @@ -156,7 +159,7 @@ def test_build(mod: Module, target: Target) -> None: # pylint: disable=unused-v
raise ValueError("Builder intended Test Error (build func).")

builder = LocalBuilder(f_build="meta_schedule.builder.test_build", initializer=initializer)
builder_inputs = [BuilderInput(MatmulModule(), Target("llvm"))]
builder_inputs = [BuilderInput(MatmulModule, Target("llvm"))]
builder_results = builder.build(builder_inputs)
assert len(builder_results) == len(builder_inputs)
for result in builder_results:
Expand All @@ -175,7 +178,7 @@ def test_build(mod: Module) -> str: # pylint: disable=unused-variable
raise ValueError("Builder intended Test Error (export func).")

builder = LocalBuilder(f_export="meta_schedule.builder.test_export", initializer=initializer)
builder_inputs = [BuilderInput(MatmulModule(), Target("llvm"))]
builder_inputs = [BuilderInput(MatmulModule, Target("llvm"))]
builder_results = builder.build(builder_inputs)
assert len(builder_results) == len(builder_inputs)
for result in builder_results:
Expand All @@ -198,7 +201,7 @@ def timeout_build(mod, target): # pylint: disable=unused-argument, unused-varia
f_build="meta_schedule.builder.test_time_out",
initializer=initializer,
)
builder_inputs = [BuilderInput(MatmulModule(), Target("llvm"))]
builder_inputs = [BuilderInput(MatmulModule, Target("llvm"))]
builder_results = builder.build(builder_inputs)
assert len(builder_results) == len(builder_inputs)
for result in builder_results:
Expand Down
26 changes: 13 additions & 13 deletions tests/python/unittest/test_meta_schedule_space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,25 @@
import pytest

import tvm
from tvm import tir
from tvm.script import ty
from tvm.script import tir as T

from tvm.tir.schedule import Schedule, Trace
from tvm.tir.schedule import Schedule
from tvm.meta_schedule.space_generator import ScheduleFn, SpaceGeneratorUnion


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
# fmt: off

@tvm.script.tir
@tvm.script.ir_module
class Matmul:
def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
tir.func_attr({"global_symbol": "main"})
A = tir.match_buffer(a, (1024, 1024), "float32")
B = tir.match_buffer(b, (1024, 1024), "float32")
C = tir.match_buffer(c, (1024, 1024), "float32")
with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
with tir.init():
@T.prim_func
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
T.func_attr({"global_symbol": "main"})
A = T.match_buffer(a, (1024, 1024), "float32")
B = T.match_buffer(b, (1024, 1024), "float32")
C = T.match_buffer(c, (1024, 1024), "float32")
with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

Expand All @@ -66,7 +66,7 @@ def _check_correct(schedule: Schedule):


def test_meta_schedule_space_generator_schedule_fn():
mod = Matmul()
mod = Matmul
space_generator = ScheduleFn(sch_fn=schedule_matmul)
design_spaces = space_generator.generate_design_space(mod)
assert len(design_spaces) == 1
Expand All @@ -75,7 +75,7 @@ def test_meta_schedule_space_generator_schedule_fn():


def test_meta_schedule_design_space_generator_union():
mod = Matmul()
mod = Matmul
space_generator = ScheduleFn(sch_fn=schedule_matmul)
space_generator_union = SpaceGeneratorUnion([space_generator, space_generator])
design_spaces = space_generator_union.generate_design_space(mod)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_tune_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s


def test_tune_context_create():
mod = Matmul()
mod = Matmul
context = TuneContext(mod=mod, target=Target("llvm"), task_name="Test Task")
assert context.num_threads > 0
assert context.rand_state != -1
Expand Down

0 comments on commit a0c3cf1

Please sign in to comment.