[Relay] Add new IR pass CombineParallelDense (apache#3862)
* Refactor to create abstract ParallelOpCombiner

* First draft of CombineParallelDense

* Begin to work on tests

* Test

* Refactor to move out more common code

* Clean up

* Fix

* Remove statics

* fix wording

* Start to add combine_parallel_op_batch

* Resolve PR comments

* Resolve PR comments

* dummy change to retrigger CI

* Change special case from bias_add to add

* Revert special case change

* Ignore units check

* dummy change to retrigger CI

* dummy change to re-trigger CI

* Improve docs

* Update docs

* Update docs
soiferj authored and wweic committed Oct 1, 2019
1 parent 7e41927 commit 3a7b889
Showing 12 changed files with 1,174 additions and 194 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/relay/transform.rst
@@ -46,6 +46,8 @@ tvm.relay.transform

.. autofunction:: tvm.relay.transform.CombineParallelConv2D

.. autofunction:: tvm.relay.transform.CombineParallelDense

.. autofunction:: tvm.relay.transform.AlterOpLayout

.. autofunction:: tvm.relay.transform.Legalize
11 changes: 11 additions & 0 deletions include/tvm/relay/transform.h
@@ -482,6 +482,17 @@ TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);
*/
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);

/*!
* \brief Combine parallel dense ops into a single batch_matmul op when the
* number of parallel branches of the dense operator is not less than
* `min_num_branches`.
*
* \param min_num_branches The minimum number of branches.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);

/*!
* \brief Backward fold axis scaling into weights of conv/dense operators.
*
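The brief above keys only on the branch count, so the easiest way to see the threshold in action is to run the pass twice with different `min_num_branches` values. The following is an editorial sketch, not part of this commit, assuming the Python API of the TVM version this change targets (`relay.Module`, `relay.transform.CombineParallelDense`, `Module.astext`):

# Editorial sketch: threshold behavior of CombineParallelDense.
import tvm
from tvm import relay

def two_parallel_dense():
    # Two dense branches of identical shape sharing the same input.
    data = relay.var("data", shape=(1, 16))
    w0 = relay.var("w0", shape=(16, 16))
    w1 = relay.var("w1", shape=(16, 16))
    branches = relay.Tuple([relay.nn.dense(data, w0), relay.nn.dense(data, w1)])
    return relay.Module.from_expr(relay.Function([data, w0, w1], branches))

# Two branches fall below the default threshold of 3, so the graph is untouched.
unchanged = relay.transform.CombineParallelDense()(two_parallel_dense())
assert "batch_matmul" not in unchanged.astext()

# Lowering the threshold to 2 combines both dense ops into one nn.batch_matmul.
combined = relay.transform.CombineParallelDense(min_num_branches=2)(two_parallel_dense())
assert "batch_matmul" in combined.astext()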
30 changes: 30 additions & 0 deletions python/tvm/relay/transform.py
@@ -138,6 +138,7 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
"CombineParallelDense": 4
}
fallback_device : int, str, or tvm.TVMContext, optional
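Since the table above assigns CombineParallelDense a default opt_level of 4, the standard pipeline skips it at the usual opt_level of 3 unless it is requested explicitly. A hedged sketch of enabling it at build time, assuming `relay.build_config` accepts `required_pass` as it does for other optional passes in this TVM version:

import tvm
from tvm import relay

# Illustrative module with three parallel dense branches (the default threshold).
data = relay.var("data", shape=(1, 8))
weights = [relay.var("w%d" % i, shape=(8, 8)) for i in range(3)]
out = relay.Tuple([relay.nn.dense(data, w) for w in weights])
mod = relay.Module.from_expr(relay.Function([data] + weights, out))

# Either raise opt_level to 4 or list the pass in required_pass so that the
# PassContext enables it inside the standard relay.build pipeline.
with relay.build_config(opt_level=3, required_pass=["CombineParallelDense"]):
    graph_json, lib, params = relay.build(mod, target="llvm")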
@@ -400,6 +401,35 @@ def CombineParallelConv2D(min_num_branches=3):
return _transform.CombineParallelConv2D(min_num_branches)


def CombineParallelDense(min_num_branches=3):
    """Combine multiple dense operators into one. For example:

                 data
            /           \
     dense (2,2)     dense (2,2)
         |               |
    elemwise/bcast (2,2)  elemwise/bcast (2,2)

    Would become:

                 data
                  |
    batch_matmul+elemwise/bcast (2,2,2)

    Parameters
    ----------
    min_num_branches : int
        The minimum number of required parallel branches for performing this
        optimization.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that combines parallel dense operators.
    """
    return _transform.CombineParallelDense(min_num_branches)


def AlterOpLayout():
"""Alternate the layouts of operators or replace primitive operators with
other expressions.
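The docstring above sketches the rewrite on (2,2) shapes; here is a concrete editorial example (not part of the commit) mirroring that diagram, again assuming the `relay.Module` API of this TVM version:

import tvm
from tvm import relay

data = relay.var("data", shape=(2, 2))
w0, w1 = relay.var("w0", shape=(2, 2)), relay.var("w1", shape=(2, 2))
b0, b1 = relay.var("b0", shape=(2, 2)), relay.var("b1", shape=(2, 2))

# Two parallel dense (2,2) branches, each followed by a broadcast add.
y0 = relay.add(relay.nn.dense(data, w0), b0)
y1 = relay.add(relay.nn.dense(data, w1), b1)
func = relay.Function([data, w0, w1, b0, b1], relay.Tuple([y0, y1]))

mod = relay.Module.from_expr(func)
mod = relay.transform.CombineParallelDense(min_num_branches=2)(mod)
# Expected: the two dense ops become a single nn.batch_matmul; the adds are
# combined as well when their operands can be stacked along a new batch axis.
print(mod.astext())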
1 change: 1 addition & 0 deletions src/relay/backend/build_module.cc
@@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode {
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
