Skip to content

Commit

Permalink
[topi] fix strategy for sparse dense cuda (apache#5782)
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon authored and Trevor Morris committed Jun 18, 2020
1 parent 29ed4e9 commit 4d7d1eb
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 15 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def compute_sparse_dense(attrs, inputs, out_type):
"""Compute definition of sparse_dense"""
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]

reg.register_schedule("nn.sparse_dense", strategy.schedule_sparse_dense)
reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy)
reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,19 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
plevel=15)
return strategy


@sparse_dense_strategy.register(["cuda", "gpu"])
def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
"""sparse dense cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_dense(topi.cuda.sparse_dense),
wrap_topi_schedule(topi.cuda.schedule_sparse_dense),
name="sparse_dense.cuda",
plevel=10)
return strategy


@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
Expand Down
22 changes: 16 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,12 +599,22 @@ def batch_matmul_strategy(attrs, inputs, out_type, target):
name="batch_matmul.generic")
return strategy

# sparse_dense
@generic_func
def schedule_sparse_dense(attrs, outs, target):
"""schedule sparse_dense"""
with target:
return topi.generic.schedule_sparse_dense(outs)
# sparse dense
def wrap_compute_sparse_dense(topi_compute):
"""wrap sparse dense topi compute"""
def _compute_sparse_dense(attrs, inputs, out_type):
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]
return _compute_sparse_dense

@override_native_generic_func("sparse_dense_strategy")
def sparse_dense_strategy(attrs, inputs, out_type, target):
"""sparse dense generic strategy"""
logger.warning("sparse dense is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense),
wrap_topi_schedule(topi.generic.schedule_sparse_dense),
name="sparse_dense.generic")
return strategy

# sparse_transpose
@generic_func
Expand Down
15 changes: 10 additions & 5 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,16 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
plevel=15)
return strategy

@schedule_sparse_dense.register("cpu")
def schedule_sparse_dense_cpu(attrs, outs, target):
"""schedule sparse_dense for x86"""
with target:
return topi.x86.schedule_sparse_dense(outs)
@sparse_dense_strategy.register("cpu")
def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
"""sparse dense x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense),
wrap_topi_schedule(topi.x86.schedule_sparse_dense),
name="sparse_dense.x86",
plevel=10)
return strategy


@roi_align_strategy.register("cpu")
def roi_align_strategy_cpu(attrs, inputs, out_type, target):
Expand Down
1 change: 0 additions & 1 deletion topi/python/topi/cuda/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def schedule_sparse_dense(cfg, outs):
"""Create schedule for sparse dense"""
# pylint:disable=invalid-name
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "sparse_dense_bsrmm":
y_bsrmm = op.input_tensors[0]
Expand Down
2 changes: 0 additions & 2 deletions topi/python/topi/x86/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@
from ..util import traverse_inline, get_const_int
from .util import get_fp32_len


def schedule_sparse_dense(outs):
"""Create schedule for sparse dense"""
s = te.create_schedule([x.op for x in outs])

def _callback(op):
simd_width = get_fp32_len()
if op.tag == "sparse_dense_csrmm" and op != outs[0].op:
Expand Down

0 comments on commit 4d7d1eb

Please sign in to comment.