Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Mar 5, 2022
1 parent db96065 commit 5e5cc79
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/python/unittest/test_meta_schedule_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ def main(a: T.handle, b: T.handle) -> None: # type: ignore
# pylint: enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument


def _has_torch():
import importlib.util # pylint: disable=unused-import,import-outside-toplevel

spec = importlib.util.find_spec("torch")
return spec is not None


requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")


def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule):
(task,) = tasks
assert isinstance(task, ExtractedTask)
Expand All @@ -62,6 +72,7 @@ def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule):
tvm.ir.assert_structural_equal(tir_mod, MockModule)


@requires_torch
def test_meta_schedule_integration_task_extraction_query():
mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
env = TaskExtraction()
Expand All @@ -87,6 +98,7 @@ def test_meta_schedule_integration_multiple_current():
...


@requires_torch
def test_meta_schedule_integration_query_inside_with_scope():
mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
env = TaskExtraction()
Expand All @@ -100,6 +112,7 @@ def test_meta_schedule_integration_query_inside_with_scope():
_check_mock_task(env.tasks, mod)


@requires_torch
def test_meta_schedule_integration_extract_from_resnet():
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
Expand Down Expand Up @@ -135,6 +148,7 @@ def test_meta_schedule_integration_extract_from_resnet():
assert t.task_name in expected_task_names, t.task_name


@requires_torch
def test_meta_schedule_integration_apply_history_best():
@derived_object
class DummyDatabase(PyDatabase):
Expand Down Expand Up @@ -183,6 +197,7 @@ def print_results(self) -> None:
TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
)
mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule])
mod = IRModule({"main": mod})
assert tvm.ir.structural_equal(mod, workload.mod)


Expand Down

0 comments on commit 5e5cc79

Please sign in to comment.