[LANG] Support for Tuple Inputs of Reducer and ComputeOp #175
Conversation
include/tvm/expr.h
Outdated
@@ -43,6 +43,16 @@ using Halide::Internal::is_no_op;
using Halide::likely;
using Halide::likely_if_innermost;

/*! \brief whether two arrays have the same content */
template<typename T>
bool IsSame(const Array<T>& a, const Array<T>& b) {
Rename this to SameContent.
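As a hedged sketch of what the renamed helper could look like (not the TVM implementation; std::vector stands in for tvm::Array), the comparison is just element-wise equality after a size check:

```cpp
#include <cstddef>
#include <vector>

// Illustrative sketch of a SameContent helper, using std::vector as a
// stand-in for tvm::Array: two arrays have the same content when they
// have the same length and compare equal element by element.
template <typename T>
bool SameContent(const std::vector<T>& a, const std::vector<T>& b) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < a.size(); ++i) {
    if (!(a[i] == b[i])) return false;
  }
  return true;
}
```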
src/op/compute_op.cc
Outdated
n->reduce_axis = n->body.as<ir::Reduce>()->axis;
if (n->body[0]->is_type<ir::Reduce>()) {
  // batch reduction should have the same axis
  n->reduce_axis = n->body[0].as<ir::Reduce>()->axis;
Add verification check here
Raise an error if that is not true.
src/op/compute_op.cc
Outdated
  return ret;
}

Operation ComputeOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  Expr new_body = op::ReplaceTensor(this->body, rmap);
  if (!new_body.same_as(this->body)) {
  Array<Expr> new_body = ReplaceTensor(this->body, rmap);
Try to do it simply in a for loop; check the way we do it in the IR mutator.
Maybe we can add an Array<NodeRef> UpdateArray(Array<NodeRef> arr, std::function<NodeRef(NodeRef)> fupdate) function?
src/op/compute_op.cc
Outdated
@@ -298,15 +332,26 @@ Stmt ComputeOpNode::BuildProvide(
  CHECK_EQ(stage->op.operator->(), this);

  if (IsCrossThreadReduction(this, stage)) {
    LOG(INFO) << stage;
    // specially handle cross thread reduction.
Remove this
src/pass/ir_util.h
Outdated
@@ -12,6 +12,23 @@
namespace tvm {
namespace ir {

template<typename T>
inline Array<T> UpdateArray(Array<T> arr, std::function<T(T)> fupdate) {
Simply use F fupdate; make fupdate a template argument instead of std::function.
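A hedged sketch of the suggested shape, with the updater as a template parameter F so the compiler can inline it (std::vector stands in for tvm::Array; the copy-on-write behavior of the real Array is only approximated here):

```cpp
#include <cstddef>
#include <vector>

// Sketch of an UpdateArray helper with the updater as a template
// parameter F instead of std::function. Each element is passed through
// fupdate; if nothing changed, the original array is returned so the
// caller can detect "no update" via equality.
template <typename T, typename F>
std::vector<T> UpdateArray(const std::vector<T>& arr, F fupdate) {
  std::vector<T> out;
  out.reserve(arr.size());
  bool changed = false;
  for (const T& x : arr) {
    T y = fupdate(x);
    if (!(y == x)) changed = true;
    out.push_back(y);
  }
  // Return the input unchanged when no element was updated.
  return changed ? out : arr;
}
```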
src/pass/ir_mutator.cc
Outdated
} else {
  return Array<Expr>(new_arr);
}
std::function<Expr(Expr)> fupdate = [m] (Expr e) { return m->Mutate(e); };
Use auto fupdate. std::function and lambda are different: a lambda is more specialized and can trigger inlining.
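To illustrate the point (a generic sketch, not TVM code): a lambda kept in `auto` retains its unique closure type, so a function template can call it directly and inline it, while std::function erases the type and routes the call through an indirection.

```cpp
#include <functional>

// Type-erased path: the call goes through std::function's indirection.
inline int apply_erased(const std::function<int(int)>& f, int x) {
  return f(x);
}

// Template path: F is the lambda's own closure type, so the call to
// operator() is direct and can be inlined by the compiler.
template <typename F>
inline int apply_direct(F f, int x) {
  return f(x);
}
```

Both paths compute the same result; the difference is purely in what the optimizer can see.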
src/op/compute_op.cc
Outdated
if (!is_one(reduce->condition)) {
  *provide = IfThenElse::make(reduce->condition, *provide);
for (size_t i = 0; i < size; ++i) {
  provides->at(i) = IfThenElse::make(reduce->condition, provides->at(i));
can we have one common condition for all the bodies?
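The idea behind this comment can be sketched with a toy statement type (hypothetical names, not the TVM IR API): instead of wrapping each body in its own IfThenElse with the same condition, wrap the whole block once.

```cpp
#include <string>
#include <vector>

// Toy stand-ins for IR statements; `repr` is just a printable form.
struct Stmt { std::string repr; };

inline Stmt IfThenElse(const std::string& cond, const Stmt& body) {
  return Stmt{"if(" + cond + "){" + body.repr + "}"};
}

inline Stmt Block(const std::vector<Stmt>& stmts) {
  std::string r;
  for (const auto& s : stmts) r += s.repr;
  return Stmt{r};
}

// One common guard over all bodies, rather than one guard per body.
inline Stmt GuardAll(const std::string& cond, const std::vector<Stmt>& stmts) {
  return IfThenElse(cond, Block(stmts));
}
```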
size_t size = self->body.size();
CHECK_GT(size, 0);
std::vector<const Reduce*> reduces(size);
for (size_t i = 0; i < size; ++i) {
if we assume common reduce, vector is not necessary?
reduces is used for type
src/pass/ir_util.h
Outdated
* \brief update array with a unary function
* \param arr array
* \param fupdate a unary function
* \return if update happens, return the new array, else return the
add \tparam to document the template argument
src/pass/lower_thread_allreduce.cc
Outdated
}

std::unordered_set<const Variable*> reduce_set;
for (size_t i = 2; i < call->args.size(); ++i) {
for (size_t i = 2+2*size; i < call->args.size(); ++i) {
space between operators
src/schedule/graph.cc
Outdated
@@ -321,11 +323,14 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
    }
  }
} else if (op.as<ComputeOpNode>()) {
  std::unordered_map<const Node*, TensorDimKey> vmap;
  std::unordered_map<const Node*, std::vector<TensorDimKey>> vmap;
For compatibility with older compilers.
reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
reduce_stage->relations = Array<IterVarRelation>();
return factor_tensor;
return factor_tensors[0];
Do we need to return an array of Exprs?
include/tvm/ir.h
Outdated
Array<IterVar> rdom,
Expr condition = const_true());
Expr condition = const_true(),
int value_index = 0);
remove the default value, to be safe
never mind, forget this comment
python/tvm/api.py
Outdated
where = convert(True)
if size == 1:
    return _make.Reduce(combiner, expr, axis, where, 0)
return [_make.Reduce(combiner, expr, axis, where, i)
change to tuple
python/tvm/schedule.py
Outdated
@@ -193,10 +193,13 @@ def rfactor(self, tensor, axis):

Returns
-------
tfactor : Tensor
tfactor : Tensor or Array<Tensor>
list of Tensor
src/lang/ir.cc
Outdated
@@ -79,11 +88,12 @@ Expr Reduce::make(CommReducer combiner, Expr source,
for (size_t i = 0; i < axis.size(); ++i) {
  CHECK(axis[i].defined());
}
n->type = source.type();
n->type = source[value_index].type();
If an argument is passed by value, use std::move to save a call to the copy constructor.
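The pass-by-value-then-move idiom the comment refers to looks like this (a minimal sketch with illustrative names, not the Reduce node itself):

```cpp
#include <string>
#include <utility>

// When a constructor takes its argument by value, moving it into the
// member reuses the caller's copy instead of making a second one.
struct Node {
  std::string name;
  explicit Node(std::string n) : name(std::move(n)) {}  // move, not copy
};
```

The caller pays one copy (or none, for temporaries); without std::move the member initialization would copy again.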
src/op/compute_op.cc
Outdated
  return outputs;
}

bool CheckReduce(const ir::Reduce* a, const ir::Reduce* b) {
Rename this to ReduceEqual.
Array<Expr> freduce_args;
freduce_args.push_back(reduce->source);
freduce_args.push_back(make_const(UInt(32), size));
Let us update the comment in the intrinsic definition to clarify the new convention.
src/op/compute_op.cc
Outdated
body = AttrStmt::make(
    res_handle, attr::storage_scope, StringImm::make("local"), body);
Stmt body = Block::make(reduce_body, assign_body);
for (int idx = size - 1; idx >= 0; --idx) {
Keep the reverse iteration style consistent:
for (i = size; i != 0; --i)
This avoids certain cases where reverse iteration breaks when the index is unsigned.
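The reason for the suggested style: with a size_t index, `i >= 0` is always true, so `--i` wraps around past zero instead of terminating. Counting down with `i != 0` and indexing with `i - 1` is safe (a generic sketch, not the compute_op code):

```cpp
#include <cstddef>
#include <vector>

// Unsigned-safe reverse traversal: the loop condition is `i != 0`
// rather than `i >= 0`, which would never be false for size_t.
inline std::vector<int> reversed(const std::vector<int>& v) {
  std::vector<int> out;
  out.reserve(v.size());
  for (size_t i = v.size(); i != 0; --i) {
    out.push_back(v[i - 1]);  // visit elements from back to front
  }
  return out;
}
```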
src/pass/lower_thread_allreduce.cc
Outdated
Type type,
Var shared_buf,
const std::vector<Type>& types,
Array<Var> shared_bufs,
if we don't copy the value, pass by const ref
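The distinction the comment draws, in a minimal sketch (illustrative names; std::vector stands in for the container types in the signature): a read-only container parameter should be a const reference so no copy is made at the call site.

```cpp
#include <cstddef>
#include <vector>

// Read-only access: taking the argument by const reference avoids
// copying the outer vector and all of its elements.
inline size_t total_size(const std::vector<std::vector<int>>& bufs) {
  size_t n = 0;
  for (const auto& b : bufs) n += b.size();  // no copies made here
  return n;
}
```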
python/tvm/schedule.py
Outdated
    The created factored tensor.
    """
    return _api_internal._ScheduleRFactor(self, tensor, axis)
    factored = _api_internal._ScheduleRFactor(self, tensor, axis)
    if len(factored) == 1:
return factored[0] if len(factored) == 1 else factored
* fix for composed symbol * fix * clean up * fix exception type
…e#14523) (apache#175) This PR enhances CanProve to handle symbolic bounds. Such analysis is essential to eliminate predicates in dynamic shape workloads. We also update the int set analysis singlepoint check to avoid recursion and improve the overall analysis speed. Added CanProveSinglePoint to serve previous stronger checks. The new CanProve comes with an additional strength argument that can only be used in a top-level setting with stronger analysis. Added a comment for future implementation efficiency. Testcases are added to cover the cases. Co-authored-by: Tianqi Chen <[email protected]>
…ache#175) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 3 to 4.1.7. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](actions/download-artifact@v3...v4.1.7) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
#173