Skip to content

Commit

Permalink
Add cuda target check to dense tensorcore schedule. (#5376)
Browse files Browse the repository at this point in the history
  • Loading branch information
jwfromm authored Apr 19, 2020
1 parent a2d6fe6 commit 4672dc7
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,16 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
name="dense_large_batch.cuda",
plevel=5)
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
name="dense_tensorcore.cuda",
plevel=20)
if target.target_name == "cuda":
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
name="dense_tensorcore.cuda",
plevel=20)
if target.target_name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_cublas),
Expand Down

0 comments on commit 4672dc7

Please sign in to comment.