Commit e1d71b3

[Unity] Add dlight.gpu.Fallback in DispatchSortScan, add argsort, topk, and cumprod (#16351)

yongwww authored Jan 10, 2024
1 parent 298ad2c commit e1d71b3
Showing 21 changed files with 1,178 additions and 274 deletions.
52 changes: 0 additions & 52 deletions include/tvm/relax/attrs/sort.h

This file was deleted.

99 changes: 99 additions & 0 deletions include/tvm/relax/attrs/sorting.h
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/sorting.h
* \brief Attributes for sorting operators.
*/
#ifndef TVM_RELAX_ATTRS_SORTING_H_
#define TVM_RELAX_ATTRS_SORTING_H_

#include <tvm/relax/expr.h>
#include <tvm/tir/index_map.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in sort operator */
struct SortAttrs : public tvm::AttrsNode<SortAttrs> {
int axis;
bool descending;

TVM_DECLARE_ATTRS(SortAttrs, "relax.attrs.SortAttrs") {
TVM_ATTR_FIELD(axis).set_default(-1).describe(
"Axis along which the sort is computed."
"The default the last axis is used.");
TVM_ATTR_FIELD(descending)
.set_default(false)
.describe(
"Whether to sort in descending order."
"If it is not specified, it defaults to the ascending order.");
}
}; // struct SortAttrs

/*! \brief Attributes used in argsort operator */
struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
int axis;
bool descending;
DataType dtype;

TVM_DECLARE_ATTRS(ArgsortAttrs, "relax.attrs.ArgsortAttrs") {
TVM_ATTR_FIELD(axis).set_default(-1).describe(
"Axis along which the argsort is computed."
"The default the last axis is used.");
TVM_ATTR_FIELD(descending)
.set_default(false)
.describe(
"Whether to argsort in descending order."
"If it is not specified, it defaults to the ascending order.");
TVM_ATTR_FIELD(dtype)
.set_default(NullValue<DataType>())
.describe("DType of the output indices.");
}
}; // struct ArgsortAttrs

/*! \brief Attributes used in topk operator */
struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
int k;
int axis;
bool largest;
String ret_type;
DataType dtype;

TVM_DECLARE_ATTRS(TopKAttrs, "relax.attrs.TopKAttrs") {
TVM_ATTR_FIELD(k).describe("Number of top elements to select");
TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor.");
TVM_ATTR_FIELD(ret_type).set_default("both").describe(
"The return type [both, values, indices]."
"both - return both top k data and indices."
"values - return top k data only."
"indices - return top k indices only.");
TVM_ATTR_FIELD(largest).set_default(true).describe(
"Whether to return largest or smallest elements."
"By default, return the largest k elements.");
TVM_ATTR_FIELD(dtype)
.set_default(NullValue<DataType>())
.describe("Data type of the output indices.");
}
}; // struct TopKAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_SORTING_H_
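
For reference, here is a minimal Python sketch of emitting the three operators these attrs back. The keyword names mirror the attrs fields above; the exact Python signatures are otherwise an assumption.

# Sketch only: keyword names assumed to mirror SortAttrs/ArgsortAttrs/TopKAttrs.
from tvm import relax
from tvm.relax.op import argsort, sort, topk

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((4, 8), "float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        s = bb.emit(sort(x, axis=-1, descending=False))
        i = bb.emit(argsort(x, axis=-1, descending=False, dtype="int32"))
        # ret_type="both" packs (values, indices) into a tuple.
        t = bb.emit(topk(x, k=2, axis=-1, ret_type="both", largest=True))
        out = bb.emit_output(relax.Tuple([s, i, t]))
    bb.emit_func_output(out)
mod = bb.get()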
20 changes: 12 additions & 8 deletions include/tvm/relax/attrs/statistical.h
@@ -42,20 +42,24 @@ struct StatisticalAttrs : public tvm::AttrsNode<StatisticalAttrs> {
}
}; // struct StatisticalAttrs

/*! \brief Attributes used in cumsum operators */
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
/*! \brief Attributes used in scan operators like cumsum, cumprod */
struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
Optional<Integer> axis;
DataType dtype;
Bool exclusive = Bool(false);

TVM_DECLARE_ATTRS(CumsumAttrs, "relax.attrs.CumsumAttrs") {
TVM_DECLARE_ATTRS(ScanopAttrs, "relax.attrs.ScanopAttrs") {
TVM_ATTR_FIELD(axis).describe(
"Axis along which the cumulative sum is computed."
"The default (None) is to compute the cumsum over the flattened array.");
"The axis along which to perform the scan computation."
"The default (None) is to compute over the flattened array.");
TVM_ATTR_FIELD(dtype).describe(
"Type of the returned array and of the accumulator in which the elements are summed."
"If dtype is not specified, it defaults to the dtype of data.");
"The output data type."
"If dtype is not specified, it defaults to the dtype of input data.");
TVM_ATTR_FIELD(exclusive)
.describe("The first element is not included")
.set_default(Bool(false));
}
}; // struct CumsumAttrs
}; // struct ScanopAttrs

} // namespace relax
} // namespace tvm
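
The new exclusive flag selects which prefix each output element covers. Below is a NumPy illustration of the semantics; this is an interpretation of the field description above, not TVM API.

# Inclusive vs. exclusive scan, per the "first element is not included" note.
import numpy as np

data = np.array([1, 2, 3, 4])
inclusive = np.cumsum(data)                        # [1, 3, 6, 10]
# Exclusive: output[i] covers data[:i], so the scan starts at the identity
# (0 for sum, 1 for product) and the last input element is dropped.
exclusive = np.concatenate(([0], inclusive[:-1]))  # [0, 1, 3, 6]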
124 changes: 92 additions & 32 deletions python/tvm/relax/backend/dispatch_sort_scan.py
@@ -17,13 +17,13 @@
# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
"""Dispatch sort and scan operators to platform dependent implementation."""

from tvm import topi
from tvm import topi, dlight, relax
from tvm.ir import Op
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext, module_pass
from tvm.target import Target
from tvm.contrib.thrust import can_use_thrust
from tvm.relax import Expr, Function, Call, PyExprMutator, expr_functor, TensorStructInfo
from tvm.relax import PyExprMutator, expr_functor


@expr_functor.mutator
@@ -36,13 +36,17 @@ class SortScanDispatcher(PyExprMutator):
def __init__(self, mod):
super().__init__(mod)

def _get_target(self, expr: Expr) -> Target:
sinfo = expr.struct_info
def _get_target(self, sinfo: relax.StructInfo) -> Target:
# Get target information from TensorStructInfo
if isinstance(sinfo, TensorStructInfo):
if isinstance(sinfo, relax.TensorStructInfo):
vdevice = sinfo.vdevice
if vdevice is not None:
return vdevice.target
elif isinstance(sinfo, relax.TupleStructInfo):
for f in sinfo.fields:
tgt = self._get_target(f)
if tgt != Target.current():
return tgt
# Return the target in current context
target = Target.current()
if target is None:
Expand All @@ -52,38 +56,94 @@ def _get_target(self, expr: Expr) -> Target:
)
return target

def visit_call_(self, call: Call) -> Expr:
def _apply_dlight_gpu_fallback(self, target: Target, tir_call: relax.Call) -> None:
# Apply dlight.gpu.Fallback() on GPU
gvar = tir_call.args[0]
assert isinstance(gvar, relax.GlobalVar)
scan_prim_func = self.builder_.get()[gvar]
sch = dlight.base.transform._apply_rules(
scan_prim_func,
target,
[
dlight.gpu.Fallback(),
],
False,
)
if sch is not None:
assert len(sch) == 1
self.builder_.update_func(gvar, sch[0].mod["main"].with_attr("tir.is_scheduled", 1))

def visit_call_(self, call: relax.Call) -> relax.Expr:
if not isinstance(call.op, Op):
return super().visit_call_(call)

if call.op.name == "relax.sort":
tgt = self._get_target(call)
tgt = self._get_target(call.struct_info)
te_func = topi.sort
with tgt:
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
return self.builder_.call_te(
topi.cuda.sort_thrust,
call.args[0],
call.attrs.axis,
not call.attrs.descending,
)
return self.builder_.call_te(
topi.cuda.sort if tgt.kind.name == "cuda" else topi.sort,
call.args[0],
call.attrs.axis,
not call.attrs.descending,
)

if call.op.name == "relax.cumsum":
tgt = self._get_target(call)
axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
te_func = topi.cuda.sort_thrust
elif tgt.kind.name == "cuda":
te_func = topi.cuda.sort
return self.builder_.call_te(
te_func,
call.args[0],
call.attrs.axis,
not call.attrs.descending,
)
if call.op.name == "relax.argsort":
tgt = self._get_target(call.struct_info)
te_func = topi.argsort
with tgt:
return self.builder_.call_te(
topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum,
call.args[0],
axis,
call.attrs.dtype,
)

if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.argsort_thrust
elif tgt.kind.name == "cuda":
te_func = topi.cuda.argsort
return self.builder_.call_te(
te_func,
call.args[0],
axis=call.attrs.axis,
is_ascend=not call.attrs.descending,
dtype=call.attrs.dtype,
)
if call.op.name == "relax.topk":
tgt = self._get_target(call.struct_info)
te_func = topi.topk
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.topk_thrust
elif tgt.kind.name == "cuda":
te_func = topi.cuda.topk
tir_call = self.builder_.call_te(
te_func,
call.args[0],
axis=call.attrs.axis,
ret_type=call.attrs.ret_type,
is_ascend=not call.attrs.largest,
dtype=call.attrs.dtype,
)
if tgt.kind.name != "cuda":
return tir_call
# apply dlight gpu fallback
self._apply_dlight_gpu_fallback(tgt, tir_call)
return tir_call
if call.op.name in ("relax.cumprod", "relax.cumsum"):
tgt = self._get_target(call.struct_info)
axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum
if call.op.name == "relax.cumprod":
te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else topi.cumprod
tir_call = self.builder_.call_te(
te_func,
call.args[0],
axis,
call.attrs.dtype,
call.attrs.exclusive,
)
if tgt.kind.name != "cuda":
return tir_call
# apply dlight gpu fallback
self._apply_dlight_gpu_fallback(tgt, tir_call)
return tir_call
return super().visit_call_(call)


Expand All @@ -96,7 +156,7 @@ class DispatchSortScan:
def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
sort_scan_dispater = SortScanDispatcher(mod)
for gv, func in mod.functions_items():
if isinstance(func, Function):
if isinstance(func, relax.Function):
func = sort_scan_dispater.visit_expr(func)
sort_scan_dispater.builder_.update_func(gv, func)
return sort_scan_dispater.builder_.get()
return sort_scan_dispater.builder_.finalize()
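
To see the dispatcher end to end, here is a minimal sketch of running the pass, assuming DispatchSortScan is exported from tvm.relax.backend as in this diff. Under a CUDA target, the cumsum below is rewritten to a call_tir of topi.cuda.cumsum and the generated PrimFunc is scheduled by dlight.gpu.Fallback.

import tvm
from tvm import relax
from tvm.relax.backend import DispatchSortScan  # assumed export location

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((8,), "float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        out = bb.emit_output(relax.op.cumsum(x, axis=0, dtype="float32"))
    bb.emit_func_output(out)
mod = bb.get()

# _get_target finds the target via Target.current() when no vdevice is set.
with tvm.target.Target("cuda"):
    mod = DispatchSortScan()(mod)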
4 changes: 2 additions & 2 deletions python/tvm/relax/op/__init__.py
@@ -99,8 +99,8 @@
from .qdq import quantize, dequantize
from .search import argmax, argmin, where
from .set import unique
from .sort import sort
from .statistical import cumsum, max, mean, min, prod, std, sum, variance
from .sorting import sort, argsort, topk
from .statistical import cumsum, cumprod, max, mean, min, prod, std, sum, variance
from .ternary import ewise_fma
from .unary import (
abs,
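
With the new exports in place, cumprod can be emitted just like cumsum. A one-line sketch, assuming its signature is symmetric to cumsum's:

from tvm import relax
from tvm.relax.op import cumprod

x = relax.Var("x", relax.TensorStructInfo((4, 8), "float32"))
y = cumprod(x, axis=1, dtype="float32", exclusive=False)  # signature assumed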
16 changes: 13 additions & 3 deletions python/tvm/relax/op/op_attrs.py
@@ -119,6 +119,11 @@ class SortAttrs(Attrs):
"""Attributes for sort operator"""


@tvm._ffi.register_object("relax.attrs.ArgsortAttrs")
class ArgsortAttrs(Attrs):
"""Attributes for argsort operator"""


@tvm._ffi.register_object("relax.attrs.SplitAttrs")
class SplitAttrs(Attrs):
"""Attributes used in split operator"""
@@ -154,9 +159,14 @@ class TileAttrs(Attrs):
"""Attributes for tile operator"""


@tvm._ffi.register_object("relax.attrs.CumsumAttrs")
class CumsumAttrs(Attrs):
"""Attributes for cumsum operator"""
@tvm._ffi.register_object("relax.attrs.ScanopAttrs")
class ScanopAttrs(Attrs):
"""Attributes for scan operators"""


@tvm._ffi.register_object("relax.attrs.TopKAttrs")
class TopKAttrs(Attrs):
"""Attributes for topk operators"""


@tvm._ffi.register_object("relax.attrs.EinsumAttrs")
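
These registrations are what expose the C++ attrs on Relax call nodes. A small sketch of inspecting TopKAttrs from Python; the field access mirrors the C++ declarations and is otherwise an assumption:

from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((4, 8), "float32"))
call = relax.op.topk(x, k=3)
# Attrs fields surface as attributes of call.attrs.
assert int(call.attrs.k) == 3
assert call.attrs.ret_type == "both"  # default per TopKAttrs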
(Diff content for the remaining changed files is not shown.)
