[Relay][AlterOp] Improving support for broadcast layout alteration. #4040
Conversation
include/tvm/data_layout.h
```cpp
@@ -210,6 +210,20 @@ class Layout : public NodeRef {
    return ct;
  }

  /*! \return Concatenation of all primal axes */
  inline std::string get_primal_axes() const {
```
Suggest wrapping it as a Layout, so that later we can deal with Layout instead of std::string.
Shouldn't this be handled in CanonicalizeOps instead?
I thought of doing it in CanonicalizeOps (where bias_add is converted to add with expand_dims); that would definitely be simpler. But it will miss cases where we create a network with just add.
@anijain2305 Do you mean scenarios like adding a 1-D bias vector to an NCHW tensor? That doesn't follow broadcast semantics.
NCHW (currently working)

- Original: bias -> [C]
- After CanonicalizeOps: bias -> [C]
- After AlterOpLayout: bias -> [C]

This is good for now. Now, let's look at NHWC.

NHWC (does not work - happens in TFLite)

- Original: bias -> [C]
- After CanonicalizeOps: bias -> [C]
- After AlterOpLayout: bias -> [C]

The problem is that layout transform does not support going from C to NCHW16c. This PR inserts an expand_dims to take the bias to NHWC, and then calls layout transform, so this PR brings in the following transformation (see the sketch below).

Questions to ask
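To make the expansion step concrete, here is a minimal standalone sketch in plain C++. It is not the actual tvm::Layout API; the helper name ExpandToPrimal and the string-based layout representation are illustrative assumptions. The idea: expand the bias's short broadcast layout against the primal destination layout, mirroring what the inserted expand_dims achieves, so that a subsequent layout transform such as NHWC -> NCHW16c becomes expressible.

```cpp
#include <cassert>
#include <string>

// Prepend the destination's missing primal axes in front of the source axes,
// mirroring how expand_dims turns a bias [C] into a rank-4 NHWC tensor.
std::string ExpandToPrimal(const std::string& src, const std::string& dst) {
  std::string expanded;
  for (char dst_axis : dst) {
    if (src.find(dst_axis) == std::string::npos) expanded += dst_axis;
  }
  return expanded + src;
}

int main() {
  assert(ExpandToPrimal("C", "NHWC") == "NHWC");   // bias broadcast over NHWC
  assert(ExpandToPrimal("HW", "NCHW") == "NCHW");  // missing N, C prepended
  return 0;
}
```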
Force-pushed from e57b1f9 to 821c402.
@yzhliu @vinx13 @tqchen @zhiics @ZihengJiang Ping for review. I am done from my side.
@anijain2305 Thank you for adding this! I have the same issue when trying to internally transform conv2d to NCHWc layout for a TF model. One question: can we directly use BinaryBroadcastLayout for bias_add?
After discussion, I think this change is necessary. LGTM
include/tvm/data_layout.h
```cpp
 * \param dst_layout The dst layout to which current layout has to be expanded.
 * \return The expanded Layout.
 */
inline Layout ExpandLayout(const Layout& dst_layout) {
```
Suggest naming it ExpandPrimal.
include/tvm/data_layout.h
```cpp
// 2) Now, add the primal axes of the current layout.
for (auto src_axis : operator->()->axes) {
  if (LayoutAxis::Get(src_axis).IsPrimal()) {
    new_src_layout_str += src_axis->var->name_hint;
  }
}
```
Can we simply do new_src_layout_str += this->name()?
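For what it's worth, here is a small standalone sketch (not TVM code; the helper name is an illustrative assumption) of the difference between the two options: collecting only the primal axes versus appending the full layout name. For a purely primal layout the results coincide; for a layout like NCHW16c they diverge, which is presumably why the loop filters with IsPrimal().

```cpp
#include <cctype>
#include <iostream>
#include <string>

// Keep only the primal (upper-case) axes of a layout string.
std::string PrimalAxesOnly(const std::string& layout) {
  std::string out;
  for (char c : layout) {
    if (std::isupper(static_cast<unsigned char>(c))) out += c;
  }
  return out;
}

int main() {
  std::cout << PrimalAxesOnly("NCHW") << "\n";     // NCHW: same as the full name
  std::cout << PrimalAxesOnly("NCHW16c") << "\n";  // NCHW: the full name differs
  return 0;
}
```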
include/tvm/data_layout.h
```cpp
for (auto axis : operator->()->axes) {
  if (!LayoutAxis::Get(axis).IsPrimal()) {
    // 1) Find the corresponding dual axis
    auto dual_axis = std::toupper(LayoutAxis::Get(axis).name()[0]);
```
Use LayoutAxis::Get(axis).ToPrimal() or ToDual().
The function is private. Can't do that.
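For readers following along, the primal/dual mapping being discussed is just a case change between a subordinate axis (e.g. the c in NCHW16c) and its primal counterpart (C); a standalone sketch, with an illustrative helper name:

```cpp
#include <cassert>
#include <cctype>

// Map a subordinate axis such as 'c' to its primal axis 'C'; primal axes
// map to themselves. This is the mapping ToPrimal() encapsulates.
char ToPrimalAxis(char axis) {
  return static_cast<char>(std::toupper(static_cast<unsigned char>(axis)));
}

int main() {
  assert(ToPrimalAxis('c') == 'C');
  assert(ToPrimalAxis('H') == 'H');
  return 0;
}
```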
include/tvm/data_layout.h
```cpp
std::string new_src_layout_str = "";
for (auto dst_axis : dst_layout->axes) {
  if (LayoutAxis::Get(dst_axis).IsPrimal()) {
    if (this->IndexOf(LayoutAxis::Get(dst_axis)) == -1) {
```
Suggest using this->Contains.
include/tvm/data_layout.h
```cpp
bool is_shape_one = false;
if (auto* shape_int = shape_val.as<IntImm>()) {
  if (shape_int->value == 1) {
    new_small_layout += "1";
    is_shape_one = true;
  }
}

// 4) b) If shape is not 1, retain the factor.
if (!is_shape_one) {
  auto new_shape_val = FactorOf(LayoutAxis::Get(dual_axis));
  new_small_layout += std::to_string(new_shape_val);
}
```
It might be better to have another helper function, ReplaceFactor, implement the functionality above, and to move this AdjustSubordinateFactors to alter_op_layout, as it is not general enough for tvm::Layout.
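Along the lines of that suggestion, here is a minimal standalone sketch (string-based, not the actual tvm::Layout API; the signature and parsing approach are illustrative assumptions) of what a ReplaceFactor-style helper could do: rewrite the factor attached to one subordinate axis, e.g. NCHW16c -> NCHW1c.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Replace the numeric factor preceding the given subordinate axis,
// leaving all other axes and factors untouched.
std::string ReplaceFactor(const std::string& layout, char sub_axis,
                          int new_factor) {
  std::string out;
  for (size_t i = 0; i < layout.size();) {
    if (std::isdigit(static_cast<unsigned char>(layout[i]))) {
      size_t j = i;
      while (j < layout.size() &&
             std::isdigit(static_cast<unsigned char>(layout[j]))) ++j;
      // layout[j] is the subordinate axis this factor belongs to.
      if (j < layout.size() && layout[j] == sub_axis) {
        out += std::to_string(new_factor);
        out += sub_axis;
      } else {
        out += layout.substr(i, j - i + 1);  // keep factor and axis as-is
      }
      i = j + 1;
    } else {
      out += layout[i++];
    }
  }
  return out;
}

int main() {
  assert(ReplaceFactor("NCHW16c", 'c', 1) == "NCHW1c");
  assert(ReplaceFactor("NCHW4h16c", 'c', 8) == "NCHW4h8c");
  return 0;
}
```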
```cpp
auto dual_axis = LayoutAxis::Get(axis).ToPrimal().name()[0];

// 2) Find the index of this dual axis in old_layout
int old_axis = old_layout.IndexOf(LayoutAxis::Get(dual_axis));
```
Actually, you don't need to do .name() first and then ::Get; ToPrimal() already returns a LayoutAxis.
I tried that. But it does not work because some of the methods are private. I didn't want to change the class description.
* master: (21 commits)
  - [Fix][VM] Fix VM invoke with set_params (apache#4079)
  - [QNN] Refactor fixed point multiplication in requantize (apache#4073)
  - Fix match case in Python-side expr functor (apache#4037)
  - Hide symbols from dependent libraries if HIDE_PRIVATE_SYMBOLS is ON. (apache#4041)
  - Add gradient for log-softmax (apache#4069)
  - [DOC] Fix typos in tutorials (apache#4066)
  - dicrease the complexity of CalcDep from exponential to linear (apache#4053)
  - [Relay][AlterOp] Minor refactor. (apache#4064)
  - [Relay][AlterOp] Improving support for broadcast layout alteration. (apache#4040)
  - Add parses support for zeros_like tflite operator (apache#4042)
  - [Bugfix][TF] reset graph after getting tag of savedmodel (apache#4055)
  - [Relay][VM] Add more passes to VMCompiler (apache#4058)
  - [Relay][VM] Add autotvm context when compile (apache#4062)
  - [Bugfix] Fix target host for vm compiler (apache#4057)
  - [Relay][Training] Add gradient for Crossentropy (apache#3925)
  - [llvm] switch to use Align for llvm trunk (apache#4051)
  - [Relay][TopHub] Add switch to disable TopHub download (apache#4015)
  - [Relay][Op] Add instance norm op (apache#4004)
  - [QNN][Relay] Calling Dialect passes from inside Relay Build API. (apache#3971)
  - [RELAY/PASS] Fix the extent for the post_stmt in the loop partition (apache#3734)
  - ...
Improve broadcast layout support while altering layouts.
@yzhliu @vinx13 @tqchen @zhiics @ZihengJiang