Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LayoutRewriter #10118

Merged
merged 5 commits into from
Feb 3, 2022
Merged

Fix LayoutRewriter #10118

merged 5 commits into from
Feb 3, 2022

Conversation

lazycal
Copy link
Contributor

@lazycal lazycal commented Jan 31, 2022

Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.

Fix #10109
Originally during alter layout pass each node (LayoutAlternatedExprNode) is tied with only one specific layout stored in old_layout (or new_layout after altering layout

Layout old_layout;
Layout new_layout;
). This layout can be determined by either one of its consumer(s) or producer. Previously they are implicitly assumed to agree with each other, but this imposes restrictions on the graphs, and some reasonable graph will also violate this assumption (see #10109 for a concrete example). Actually this assumption is more than necessary for this pass. All we need is to the "delta" information of how each tensor's layout is transformed, and we don't really need to know what layouts they are before or after change.

This PR removes this assumption. Now the specific layouts stored in old_layout and new_layout do not matter, instead, they only serve to convey that the tensor changed by the transformation old_layout->new_layout. With the new change, the alter layout pass can be understood as follows (at least IMO):

  1. Visit each node in Post DFS order.
  2. For each node, call the specific alter layout function. Its InferCorrectLayout function will tell us the "expected" input layouts from a consumer's perspective (e.g., conv_nchw will say 1st input is NCHW and 2nd is OIHW), on both the graphs before and after rewrite, denoted by old_in2 and new_in2. When they differ, it means that we need to apply the transformation old_in2->new_in2 on the original inputs to work properly (https://github.com/lazycal/tvm/blob/42364531dea0fe72e1ffc80aba89ca04e50b2a67/src/relay/transforms/transform_layout.h#L384).
  3. Note that the previous transformation is assuming original input, thus prior to that we need to transform the inputs back to its layout (https://github.com/lazycal/tvm/blob/42364531dea0fe72e1ffc80aba89ca04e50b2a67/src/relay/transforms/transform_layout.h#L383).
  4. Apparently the two transformations can be fused by aligning the layout letters, e.g., layout_transform(CN->NC) + layout_transform(NW->NW8w) is equivalent to layout_transform(CN->NC8c). The previous assumption assumes that they are already aligned, thus only one transformation was inserted (https://github.com/lazycal/tvm/blob/42364531dea0fe72e1ffc80aba89ca04e50b2a67/src/relay/transforms/transform_layout.h#L381). I kept this case as well.

The only thing that I didn't fully check is InferCorrectLayouts. I am not fully aware of its formal semantic, but I assumed that they should return old_in2 and new_in2 that are transformable. Some ops may not conform to this in order to workaround the previous restriction. E.g., I found that the dense_op has this specifal hack (https://github.com/apache/tvm/compare/main...lazycal:fix-layout-pass?expand=1#diff-b1f7105acbdb593d30dc3f1506f8f226d8663164bf0e46702f8b050b056604f6L213-L216). So I just removed it.

Misc

Another potential bug I found is that the AlterOpLayout pass seems to be overloaded. People often use it to drop a subgraph to replace a certain node (e.g.,

kernel_IHWO = relay.transpose(kernel_expr, axes=(1, 2, 3, 0))
kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel // oc_bn, oc_bn))
kernel_OHWoI = relay.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
kernel_OHWoIi = relay.reshape(
kernel_OHWoI, (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn)
)
kernel_OHWoIie = relay.reshape(
kernel_OHWoIi,
(out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn // n_elems, n_elems),
)
kernel_OIHWioe = relay.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
), while the current code logic in transform_layout.h assumes only replacing a node. For example,
transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
,
this transformation is performed on the args to the call object, but I believe in the subgraph case it's should be changed to the args to the subgraph.

@masahi masahi self-assigned this Jan 31, 2022
@masahi
Copy link
Member

masahi commented Feb 1, 2022

@lazycal Thank you very much for working on this. I need to review our current layout rewrite implementation to understand this change, so please wait for a few days or longer until I merge this (but I will, definitely).

// since packing is always done on the "C" axis.
if (new_in_layouts.size() > 0 && new_in_layouts[0].defined() && new_in_layouts[0].ndim() == 2) {
return InferCorrectLayoutOutput({new_in_layouts[0], params->weight_layout}, {"NC"}, attrs);
}
return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need DenseInferCorrectLayout and DensePackInferCorrectLayout now? I added these functions to workaround alter layout issues, but that might not be necessary anymore. Can you try remove them and see what happens?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I skimmed through your change in 66ac470, and I think it is probably still needed. Mainly because when FTVMAlterOpLayout is defined but FInterCorrectLayout is not, the current code logic in

// If there is no FInferCorrectLayout for the type, then we just assume the layout is correct.
static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
if (Op::HasAttrMap("FTVMAlterOpLayout")) {
static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
if (ref_call->op.as<OpNode>()) {
Op op = Downcast<Op>(ref_call->op);
if (falter_layout.count(op) && !finfer_layout.count(op)) {
return memorizer->CallWithNewLayouts(ref_call, normal_new_args);
}
}
}

will assume this OP accepts any layout. In the case of Dense, it only accepts 2D data tensor, and when the producer for data tensor changes its layout, we need an additional layout transform to convert it back, which is not handled in L310-L320.

@masahi
Copy link
Member

masahi commented Feb 1, 2022

I really hate new_ vs old_ and _in vs _in2 naming convention in the existing code, it's impossible to understand. More than welcome to clean them up :)

arg_item = memorizer.Transform(arg_item, new_in, new_in2);
} else {
if (old_in.defined()) arg_item = memorizer.Transform(arg_item, new_in, old_in);
arg_item = memorizer.Transform(arg_item, old_in2, new_in2);
Copy link
Member

@masahi masahi Feb 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this code path, old_in != old_in2. So after the transform at L383, how is it possible that we can apply another transform with old_in2 as the src layout?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I thought that old_in and old_in2 should be isomorphic, i.e., having the same structure (rank and subcoordinate factors, etc.) and only differing in how each axis is named (e.g., NW vs NC), given that they are describing the same tensor's layout. In this case, the transform can be applied. A concrete example: new_in=NW8w, old_in=NW, old_in2=NC, new_in2=NC16c, we will apply NW8w->NW and NC->NC16c, which is valid since layout_transform will work as long as the layout structure match the tensor shape. The net outcome is equivalent to a single transform NC8c->NC16c.

However, I just hit a bizare case in BroadcastInferLayout that does not give isomorphic layouts:

} else {
// Support scenarios where original operands were of type [N, H, W, C] and [C]. In this case,
// while transforming the layout, we expand dims to make C go to NHWC, and then use the
// modified layout of the first operator to call the layout transform. E.g.
// a in NCHWC16c
// b in C
// b = expand_dims(b) from C -> NHWC
// b = layout_transform(b) from NHWC -> NCHW16c
// add(a, b)
layouts.Set(small_idx, ret);
}

This code path may expand the tensor's layout and assign old_in2 to something with larger rank. For example, if the op is a+b, and originally a's layout is NCHW and b is W, then its consumer (the broadcast node) will infer old_in2=NCHW for b. Now W and NCHW are not really isomorphic... I'm working on a fix, but this does sound pretty werid for me when you say a tensor with rank 1 is inferred with a layout with rank 4....

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI many of us are aware of the messy situation our layout convert pass is in. I believe an entire rewrite is desired at some point. I'd love to have your thoughts on this in the discuss forum etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for having me in your discussion. I agree on an entire rewrite. Ping me in discuss forum when you some day decide to do it :-), and I'd love to participate.

@lazycal
Copy link
Contributor Author

lazycal commented Feb 1, 2022

@lazycal Thank you very much for working on this. I need to review our current layout rewrite implementation to understand this change, so please wait for a few days or longer until I merge this (but I will, definitely).

Thanks! I am also reviewing the InferCorrectLayout functions to see if there is anything broken. The broadcast one is one example I just hit.

I really hate new_ vs old_ and _in vs _in2 naming convention in the existing code, it's impossible to understand. More than welcome to clean them up :)

Same here. I have 2 proposals:

  • replace _in with _in_input and _in2 with _in_infer.
  • Or more explicit, _in with _in_producer and _in2 with _in_consumer.

Not sure which one would you prefer? Or any other ideas?

@masahi
Copy link
Member

masahi commented Feb 1, 2022

Same here. I have 2 proposals:

* replace `_in` with `_in_input` and `_in2` with `_in_infer`.

* Or more explicit, `_in` with `_in_producer` and `_in2` with `_in_consumer`.

Not sure which one would you prefer? Or any other ideas?

I prefer the second one. I thought infer-ed or not is for old vs new distinction? I think I prefer before vs after than old vs new, but a name like before_in_producer doesn't sound good... Actually I don't understand why we need 2 x 2 combination of layouts.

@lazycal
Copy link
Contributor Author

lazycal commented Feb 1, 2022

So the 2x2 is actually encoding 2 transforms: one is the producer's change, and the other is the consumer's (expected) change. For a weird example, before the change: (A) conv_NCHW -> (B) conv_NCHW
after the change: (A') conv_NCHW2c -> (B') conv_NCHW4c.
Suppose we are rewriting Node B. Then the producer's change will be NCHW->NCHW2c, and the consumer's expected change will be NCHW->NCHW4c. There are 4 layouts, hence 2x2 combinations.

@masahi masahi merged commit e53cbe4 into apache:main Feb 3, 2022
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
* Fix layout pass

* add unit test

* fix lint

* fix lint

* fix lint
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug] AlterOpLayout failed on Conv->Transpose->Conv
2 participants