Skip to content

Commit

Permalink
rocm: fix dense_rocblas in strategy, topi (apache#5191)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored and Trevor Morris committed Apr 1, 2020
1 parent c3a64c1 commit 759d901
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense_rocblas),
wrap_topi_schedule(topi.rocm.dense_rocblas),
wrap_topi_schedule(topi.rocm.schedule_dense_rocblas),
name="dense_rocblas.rocm",
plevel=15)
return strategy
2 changes: 2 additions & 0 deletions topi/python/topi/rocm/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def dense_rocblas(cfg, data, weight, bias=None, out_dtype=None):
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
if out_dtype is None:
out_dtype = data.dtype
assert out_dtype == data.dtype, "Mixed precision not supported."
matmul = rocblas.matmul(data, weight, False, True)
batch, in_dim = data.shape
Expand Down

0 comments on commit 759d901

Please sign in to comment.