From d5cfee898a86795270dca0b3dc13f9d6b2fbd223 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 31 Mar 2020 11:11:13 +0200 Subject: [PATCH] rocm: fix dense_rocblas in strategy, topi --- python/tvm/relay/op/strategy/rocm.py | 2 +- topi/python/topi/rocm/dense.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 0486f71b526c..6cda346e5068 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -129,7 +129,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target): assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." strategy.add_implementation( wrap_compute_dense(topi.rocm.dense_rocblas), - wrap_topi_schedule(topi.rocm.dense_rocblas), + wrap_topi_schedule(topi.rocm.schedule_dense_rocblas), name="dense_rocblas.rocm", plevel=15) return strategy diff --git a/topi/python/topi/rocm/dense.py b/topi/python/topi/rocm/dense.py index 097120da88d6..989cc2aed7c3 100644 --- a/topi/python/topi/rocm/dense.py +++ b/topi/python/topi/rocm/dense.py @@ -123,6 +123,8 @@ def dense_rocblas(cfg, data, weight, bias=None, out_dtype=None): output : tvm.te.Tensor 2-D with shape [batch, out_dim] """ + if out_dtype is None: + out_dtype = data.dtype assert out_dtype == data.dtype, "Mixed precision not supported." matmul = rocblas.matmul(data, weight, False, True) batch, in_dim = data.shape