Fix LayoutRewriter #10118
Conversation
@lazycal Thank you very much for working on this. I need to review our current layout rewrite implementation to understand this change, so please wait for a few days or longer until I merge this (but I will, definitely).
```c++
// since packing is always done on the "C" axis.
if (new_in_layouts.size() > 0 && new_in_layouts[0].defined() && new_in_layouts[0].ndim() == 2) {
  return InferCorrectLayoutOutput({new_in_layouts[0], params->weight_layout}, {"NC"}, attrs);
}
return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
}
```
Do we still need `DenseInferCorrectLayout` and `DensePackInferCorrectLayout` now? I added these functions to work around alter layout issues, but that might not be necessary anymore. Can you try removing them and see what happens?
I skimmed through your change in 66ac470, and I think it is probably still needed, mainly because when FTVMAlterOpLayout is defined but FInferCorrectLayout is not, the current code logic in
tvm/src/relay/transforms/transform_layout.h
Lines 310 to 320 in 6a274af
```c++
// If there is no FInferCorrectLayout for the type, then we just assume the layout is correct.
static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
if (Op::HasAttrMap("FTVMAlterOpLayout")) {
  static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
  if (ref_call->op.as<OpNode>()) {
    Op op = Downcast<Op>(ref_call->op);
    if (falter_layout.count(op) && !finfer_layout.count(op)) {
      return memorizer->CallWithNewLayouts(ref_call, normal_new_args);
    }
  }
}
```
will assume this op accepts any layout. In the case of Dense, it only accepts a 2D `data` tensor, and when the producer of the `data` tensor changes its layout, we need an additional layout transform to convert it back, which is not handled in L310-L320.
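As a concrete illustration, here is a minimal hypothetical Relay sketch (shapes and layouts are made up for illustration) of the transform-back that the L310-L320 fallback would not insert on its own:

```python
# A sketch (hypothetical shapes/layouts): a 2D-only op like dense needs an
# explicit layout_transform back when its producer's layout was rewritten.
from tvm import relay

# Suppose the data producer was rewritten to emit its result packed as NC16c.
x_packed = relay.var("x", shape=(8, 4, 16))  # logical (8, 64) tensor in NC16c

# nn.dense only accepts a rank-2 data tensor, so a transform back to the
# original 2D layout must be inserted before the call:
x_2d = relay.layout_transform(x_packed, "NC16c", "NC")  # -> shape (8, 64)
w = relay.var("w", shape=(32, 64))
y = relay.nn.dense(x_2d, w)

print(relay.Function([x_packed, w], y))
```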
I really hate …
```c++
  arg_item = memorizer.Transform(arg_item, new_in, new_in2);
} else {
  if (old_in.defined()) arg_item = memorizer.Transform(arg_item, new_in, old_in);
  arg_item = memorizer.Transform(arg_item, old_in2, new_in2);
```
In this code path, `old_in != old_in2`. So after the transform at L383, how is it possible that we can apply another transform with `old_in2` as the src layout?
Good catch. I thought that `old_in` and `old_in2` should be isomorphic, i.e., having the same structure (rank, subordinate factors, etc.) and only differing in how each axis is named (e.g., `NW` vs `NC`), given that they describe the same tensor's layout. In that case, the transform can be applied. A concrete example: `new_in=NW8w`, `old_in=NW`, `old_in2=NC`, `new_in2=NC16c`; we will apply `NW8w->NW` and then `NC->NC16c`, which is valid since layout_transform works as long as the layout structure matches the tensor shape. The net outcome is equivalent to a single transform `NC8c->NC16c`.
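For reference, this equivalence can be checked numerically. A minimal sketch (shapes chosen so the packing factors divide evenly; the executor choice is arbitrary):

```python
# Numerical check of the isomorphism argument above.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2, 8))  # a packed tensor, viewable as NW8w or NC8c
data = np.arange(2 * 2 * 8, dtype="float32").reshape(2, 2, 8)

def run(expr):
    mod = tvm.IRModule.from_expr(relay.Function([x], expr))
    return relay.create_executor("graph", mod=mod).evaluate()(data).numpy()

# Two transforms: NW8w -> NW (undo the producer's change), then NC -> NC16c
# (apply the consumer's expected change).
two_step = relay.layout_transform(relay.layout_transform(x, "NW8w", "NW"), "NC", "NC16c")
# The claimed equivalent single transform.
one_step = relay.layout_transform(x, "NC8c", "NC16c")

assert np.array_equal(run(two_step), run(one_step))
```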
However, I just hit a bizarre case in BroadcastInferLayout that does not give isomorphic layouts:
tvm/src/relay/transforms/infer_layout_utils.h
Lines 224 to 234 in 6a274af
```c++
} else {
  // Support scenarios where original operands were of type [N, H, W, C] and [C]. In this case,
  // while transforming the layout, we expand dims to make C go to NHWC, and then use the
  // modified layout of the first operator to call the layout transform. E.g.
  //   a in NCHWC16c
  //   b in C
  //   b = expand_dims(b) from C -> NHWC
  //   b = layout_transform(b) from NHWC -> NCHW16c
  //   add(a, b)
  layouts.Set(small_idx, ret);
}
```
This code path may expand the tensor's layout and assign `old_in2` to something with a larger rank. For example, if the op is `a+b`, and originally `a`'s layout is `NCHW` and `b`'s is `W`, then its consumer (the broadcast node) will infer `old_in2=NCHW` for `b`. Now `W` and `NCHW` are not really isomorphic... I'm working on a fix, but it does sound pretty weird to me that a tensor with rank 1 is inferred to have a layout with rank 4....
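In Relay terms, the scenario that code path describes looks roughly like this (a hedged sketch mirroring the comment's NHWC -> NCHW16c example; shapes are illustrative):

```python
# A rank-1 operand is expanded to full rank and then packed, so the broadcast
# op effectively reports a rank-4 layout for a rank-1 tensor.
from tvm import relay

a = relay.var("a", shape=(1, 1, 8, 8, 16))  # NCHW16c (logical NCHW = 1x16x8x8)
b = relay.var("b", shape=(16,))             # layout "C"

b_nhwc = relay.expand_dims(b, axis=0, num_newaxis=3)          # C -> NHWC, (1, 1, 1, 16)
b_packed = relay.layout_transform(b_nhwc, "NHWC", "NCHW16c")  # -> (1, 1, 1, 1, 16)
out = relay.add(a, b_packed)

print(relay.Function([a, b], out))
```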
FYI, many of us are aware of the messy situation our layout convert pass is in. I believe an entire rewrite is desirable at some point. I'd love to hear your thoughts on this in the discuss forum etc.
Thanks for having me in your discussion. I agree on an entire rewrite. Ping me in the discuss forum when you decide to do it someday :-), and I'd love to participate.
Thanks! I am also reviewing the InferCorrectLayout functions to see if there is anything broken. The broadcast one is one example I just hit.
Same here. I have 2 proposals:
Not sure which one you would prefer? Or any other ideas?
I prefer the second one. I thought …
So the 2x2 is actually encoding two transforms: one is the producer's change, and the other is the consumer's (expected) change. For a weird example, before the change: …
* Fix layout pass
* add unit test
* fix lint
* fix lint
* fix lint
Fix #10109
Originally, during the alter layout pass, each node (`LayoutAlternatedExprNode`) is tied to only one specific layout, stored in `old_layout` (or `new_layout` after altering layout):

tvm/src/relay/transforms/transform_layout.h
Lines 175 to 176 in 6a274af

This PR removes this assumption. Now the specific layouts stored in `old_layout` and `new_layout` do not matter; instead, they only serve to convey that the tensor is changed by the transformation `old_layout` -> `new_layout`. With the new change, the alter layout pass can be understood as follows (at least IMO):

* The `InferCorrectLayout` function tells us the "expected" input layouts from a consumer's perspective (e.g., conv_nchw will say the 1st input is NCHW and the 2nd is OIHW), on both the graphs before and after rewrite, denoted by `old_in2` and `new_in2`. When they differ, it means that we need to apply the transformation `old_in2` -> `new_in2` on the original inputs to work properly (https://github.com/lazycal/tvm/blob/42364531dea0fe72e1ffc80aba89ca04e50b2a67/src/relay/transforms/transform_layout.h#L384).
* `layout_transform(CN->NC) + layout_transform(NW->NW8w)` is equivalent to `layout_transform(CN->NC8c)`. The previous assumption was that they are already aligned, so only one transformation was inserted (https://github.com/lazycal/tvm/blob/42364531dea0fe72e1ffc80aba89ca04e50b2a67/src/relay/transforms/transform_layout.h#L381). I kept this case as well; see the condensed sketch below.

The only thing that I didn't fully check is the `InferCorrectLayout`s. I am not fully aware of their formal semantics, but I assumed that they should return `old_in2` and `new_in2` that are transformable. Some ops may not conform to this in order to work around the previous restriction. E.g., I found that the dense op has this special hack (https://github.com/apache/tvm/compare/main...lazycal:fix-layout-pass?expand=1#diff-b1f7105acbdb593d30dc3f1506f8f226d8663164bf0e46702f8b050b056604f6L213-L216). So I just removed it.
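Putting the two cases together, here is a condensed Python paraphrase of the insertion rule (a sketch: the helper name and the guard condition are my own reading of the code and the description above, not the actual C++):

```python
# Condensed paraphrase of the transform-insertion rule; the real logic is C++
# in src/relay/transforms/transform_layout.h (around L381-L384).
#   old_in / new_in   : the arg's layout before/after its producer was rewritten
#   old_in2 / new_in2 : the layout this consumer expected before/after rewriting
def rewrite_arg(arg, old_in, new_in, old_in2, new_in2, transform):
    if old_in == old_in2:
        # Already aligned (the pre-PR assumption): one transform suffices.
        return transform(arg, new_in, new_in2)
    # Otherwise undo the producer's change first, then apply the consumer's
    # expected change, e.g. NW8w->NW followed by NC->NC16c (== NC8c->NC16c).
    if old_in is not None:
        arg = transform(arg, new_in, old_in)
    return transform(arg, old_in2, new_in2)

# Tiny demo with a recording stub in place of an actual layout_transform:
log = []
stub = lambda a, src, dst: (log.append(f"{src}->{dst}"), a)[1]
rewrite_arg("x", "NW", "NW8w", "NC", "NC16c", stub)
print(log)  # ['NW8w->NW', 'NC->NC16c']
```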
Misc

Another potential bug I found is that the AlterOpLayout pass seems to be overloaded: people often use it to drop in a subgraph that replaces a certain node (e.g.,

tvm/python/tvm/topi/x86/conv2d_alter_op.py
Lines 171 to 181 in d1dafbd

), while `transform_layout.h` assumes only a single node is replaced. For example, at

tvm/src/relay/transforms/transform_layout.h
Line 388 in 6a274af

the transformation is performed on the args of the call object, but I believe in the subgraph case it should be applied to the args of the subgraph.
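To make the overload concrete, here is a hedged sketch of an FTVMAlterOpLayout hook registered from Python that returns a small subgraph instead of a plain call; the op choice and the `stop_fusion` annotation are purely illustrative:

```python
# Illustrative only: the returned expression is a subgraph (annotation + call),
# so the returned call's args are no longer the original args, which is the
# case the per-arg bookkeeping in transform_layout.h does not anticipate.
from tvm import relay

@relay.op.register_alter_op_layout("nn.dense", level=11)
def _alter_dense(attrs, inputs, tinfos, out_type):
    data, weight = inputs
    weight = relay.annotation.stop_fusion(weight)  # stand-in subgraph node
    return relay.nn.dense(data, weight, out_dtype=attrs.out_dtype)
```

(Returning None from such a hook leaves the call unchanged; returning any other expression substitutes it wholesale, which is how the subgraph usage arises.)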