Skip to content

Commit

Permalink
[MetaSchedule] Fix tensorcore winograd task extraction (#13625)
Browse files Browse the repository at this point in the history
* [MetaSchedule] Fix tensorcore winograd task extraction

* add test

* fixed target
  • Loading branch information
masahi authored Dec 16, 2022
1 parent 7674ea8 commit 37f6aa0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
)
if (
target.kind.name == "cuda"
and not is_auto_scheduler_enabled()
and not is_meta_schedule_enabled()
and nvcc.have_tensorcore(target=target)
and (
(N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ def test_meta_schedule_integration_extract_from_resnet():
assert t.task_name in expected_task_names, t.task_name


@requires_torch
def test_task_extraction_winograd_tensorcore():
mod, params, _ = get_network(name="resnet_50", input_shape=[16, 3, 224, 224])
seq = tvm.transform.Sequential(
[
relay.transform.ToMixedPrecision("float16"),
relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "HWIO"]}),
]
)
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)

target = tvm.target.Target("nvidia/geforce-rtx-3070")
extracted_tasks = ms.relay_integration.extract_tasks(mod, target=target, params=params)

assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4


@requires_torch
def test_task_extraction_anchor_block():
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
Expand Down

0 comments on commit 37f6aa0

Please sign in to comment.