[DISCUSS] Layout transformation in TIR graph #162
cc @junrushao @vinx13 |
It doesn't exist yet. It will use arith analysis or use some annotation as a hint. Simple layout conversions like reordering / packing axes should be well supported because they are basically affine transformations. |
Thanks @vinx13 for the response! I have a few follow-up questions too :) Given the discussion, it seems most people are in agreement that we need to flow/merge layout constraints through other ops. This means that given the following graph:

(pre_layout_convert) -> matmul -> (post_layout_convert) -> add -> (pre_layout_convert) -> matmul -> (post_layout_convert)

we should be able to flow the layout constraint through add and rewrite it to:

(pre_layout_convert) -> matmul -> add -> (post_layout_convert) -> (pre_layout_convert) -> matmul -> (post_layout_convert)

I think such transformations would eventually allow us to fold the adjacent (post_layout_convert) -> (pre_layout_convert) pair away.
|
Thank you for the great proposal, @vinx13! I also have a question.
If so, do we have some sort of guarantee that the performance of P1 and P2 match each other? Otherwise, we may need to be smart about picking the right path in different situations. |
Thanks for the discussion. |
Ah interesting. If I am interpreting your comment correctly, the computation part for P1 & P2 is equivalent (can be made equivalent?) modulo post/pre layout conversions. |
@psrivas2 exactly, it can be made equivalent |
Many many thanks for the great RFC and discussions everyone! I wanted to initiate a discussion around the hoisting and splitting step involved when padding is present. When splitting / hoisting out padding and cropping transformations, in order to preserve the ability to simplify between split producer and consumer primfuncs, in some cases we should expect to leave behind assumptions to preserve the local information that becomes non-local after splitting. Let us take the following TIR as an example. Prior to hoisting:

```python
@ir_module
class BeforeHoist:
    @R.func
    def main():
        R.call_tir(func)

    @T.prim_func
    def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
        AC = T.alloc_buffer([4, 4], "int32")
        for io, ii in T.grid(4, 4):
            with T.block():
                T.block_attr("preproc", "pad")
                AC[io, ii] = if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
        BC = T.alloc_buffer([4, 4], "int32")
        for i, j in T.grid(4, 4):
            BC[i, j] = 2 * AC[i, j]
        for i in T.serial(14):
            with T.block():
                T.block_attr("postproc", ["crop", 0])
                B[i] = BC[i // 4, i % 4]
```

Above we have a padding step that allows the inner compute statement to operate on a padded space. When splitting, we can leave behind the now non-local assumptions that could help simplifications of the inner compute. In this case the inner compute is already simplified, but the example can still help for discussion. Consider hoisting the padding loopnest out from the rest of `func`:
```python
@ir_module
class AfterHoistOfPadStage:
    @R.func
    def main():
        R.call_tir(pad)
        R.call_tir(func)

    @T.prim_func
    def pad(A: T.Buffer[14, "int32"], AC: T.Buffer([4, 4], "int32")):
        for io, ii in T.grid(4, 4):
            with T.block():
                T.block_attr("preproc", "pad")
                AC[io, ii] = if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)

    @T.prim_func
    def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
        for io, ii in T.grid(4, 4):
            T.assume(4 * io + ii < 14 or AC[io, ii] == 0)
        BC = T.alloc_buffer([4, 4], "int32")
        for i, j in T.grid(4, 4):
            BC[i, j] = 2 * AC[i, j]
        for i in T.serial(14):
            with T.block():
                T.block_attr("postproc", ["crop", 0])
                B[i] = BC[i // 4, i % 4]
```

Here the choice was made to leave behind the assumption in the consumer `func`.

O1.
O2.
O3.
Given that these options constitute a degree of freedom, for the purpose of staging the efforts, we could consider focusing the layout planner's initial efforts around leaving the assumptions proposed in O1, as they will be the most immediately applicable for the case of convolution and other contraction-based operations where it is desirable to operate in a padded and block-transformed data space. Please let us know your thoughts on this aspect of hoisting for the layout planner. cc @Lunderberg who is fully paged in on this topic. Thanks !! |
See #277 |
As we start to work on specific hardware, many operators will expect a specific kind of layout for both data and weights, while logically the layouts start out as simple ones. This thread discusses an example of how to handle layout transformation in a Relax-TIR setting.
The general idea is to lift layout transformations into the graph and cancel out pairs of pre-compute and post-compute transformations.
The same principle can also be applied to other cases like adding padding before layout transformation.
Example problem of layout transformation
Consider the following example program; the data layouts of A and C are in the normal (logical) layout.
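The original example program is not preserved in this thread; below is a minimal sketch in the same pseudo-TVMScript style used elsewhere in the discussion. The shapes (128x128), weight names (B, W) and the `call_tir` signature are illustrative assumptions, not the exact code from the post.

```python
@ir_module
class Example:
    @R.func
    def main(A, B, W):
        # A and C are in the normal (row-major) logical layout;
        # B and W are weights whose layout we leave unchanged.
        lv0 = R.call_tir(matmul, (A, B), (128, 128), dtype="float32")
        C = R.call_tir(matmul, (lv0, W), (128, 128), dtype="float32")
        return C

    @T.prim_func
    def matmul(A: T.Buffer[(128, 128), "float32"],
               B: T.Buffer[(128, 128), "float32"],
               C: T.Buffer[(128, 128), "float32"]):
        for i, j, k in T.grid(128, 128, 128):
            with T.block("matmul"):
                with T.init():
                    C[i, j] = T.float32(0)
                C[i, j] = C[i, j] + A[i, k] * B[k, j]
```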
Assume that due to hardware or other restrictions, we need to convert the layout to a different setting. Say the new layout is represented by the following mapping.
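The concrete mapping from the original post is not shown here; as an illustration, a typical blocked layout can be expressed as an index map of this form (the 16x16 tiling factor is an assumption):

```python
# Map a logical (i, j) index into a blocked physical layout; for a
# 128x128 tensor the physical shape becomes (8, 8, 16, 16).
index_map = lambda i, j: (i // 16, j // 16, i % 16, j % 16)
```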
For simplicity let us assume that we do not change the layout of B (and W), but the same principle applies. Given the layout requirement, the first step is for the TIR function to express it through a transformation.
Step 0: PrimFunc transformation
In the first step, the PrimFunc is transformed into a program with three stages: a pre-compute layout conversion, the compute under the physical layout, and a post-compute layout conversion.
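The transformed PrimFunc from the original post is not reproduced here; the following is a rough sketch of the three-stage structure under the illustrative blocked mapping above. The function name, buffer names, and block attributes are assumptions.

```python
@T.prim_func
def matmul_after_step0(A: T.Buffer[(128, 128), "float32"],
                       B: T.Buffer[(128, 128), "float32"],
                       C: T.Buffer[(128, 128), "float32"]):
    AC = T.alloc_buffer([8, 8, 16, 16], "float32")  # A in the physical layout
    CC = T.alloc_buffer([8, 8, 16, 16], "float32")  # C in the physical layout
    # Stage 1: pre-compute layout conversion (logical -> physical).
    for i, j in T.grid(128, 128):
        with T.block():
            T.block_attr("preproc", "layout_convert")
            AC[i // 16, j // 16, i % 16, j % 16] = A[i, j]
    # Stage 2: compute entirely under the physical layout.
    for io, jo, ii, ji, k in T.grid(8, 8, 16, 16, 128):
        with T.block("matmul"):
            with T.init():
                CC[io, jo, ii, ji] = T.float32(0)
            CC[io, jo, ii, ji] = CC[io, jo, ii, ji] + AC[io, k // 16, ii, k % 16] * B[k, jo * 16 + ji]
    # Stage 3: post-compute layout conversion (physical -> logical).
    for i, j in T.grid(128, 128):
        with T.block():
            T.block_attr("postproc", "layout_convert")
            C[i, j] = CC[i // 16, j // 16, i % 16, j % 16]
```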
Step 1: Lift Layout Convert into Graph
If we stop at step 0, the additional layout conversion brings extra cost and is sometimes infeasible if the memory does not support the layout natively. In this second step, we lift the layout conversions into the graph.

The resulting program is sketched below. Note that the layout conversions get lifted into the graph part, and `matmul_physical` now runs completely under the desired (physical) layout.
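The lifted program from the original post is not preserved; the following is a rough sketch of the graph after lifting, with the conversion PrimFuncs named `pre_layout_convert` / `post_layout_convert` for illustration (shapes and the `call_tir` signature are assumptions):

```python
@ir_module
class AfterStep1:
    @R.func
    def main(A, B, W):
        # Each matmul is wrapped by its own pair of layout conversions,
        # now lifted out of the PrimFunc into the graph.
        lv0 = R.call_tir(pre_layout_convert, (A,), (8, 8, 16, 16), dtype="float32")
        lv1 = R.call_tir(matmul_physical, (lv0, B), (8, 8, 16, 16), dtype="float32")
        lv2 = R.call_tir(post_layout_convert, (lv1,), (128, 128), dtype="float32")
        lv3 = R.call_tir(pre_layout_convert, (lv2,), (8, 8, 16, 16), dtype="float32")
        lv4 = R.call_tir(matmul_physical, (lv3, W), (8, 8, 16, 16), dtype="float32")
        C = R.call_tir(post_layout_convert, (lv4,), (128, 128), dtype="float32")
        return C
```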
Step 2: Fold Layout conversion
The above step still leaves many layout conversions in the graph code. In this step, we run folding to fold the layout conversions. Note that in the above code segment, `pre_layout_conversion` and `post_layout_conversion` cancel each other out and form an identity (this can be verified by TIR analysis). So we can run folding, and the final code becomes as follows:
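The folded result from the original post is not shown; sketching what it would look like for the two-matmul example above (same illustrative names):

```python
@ir_module
class AfterStep2:
    @R.func
    def main(A, B, W):
        # The adjacent post/pre conversion pair between the two matmuls
        # forms an identity and has been folded away; only the boundary
        # conversions remain.
        lv0 = R.call_tir(pre_layout_convert, (A,), (8, 8, 16, 16), dtype="float32")
        lv1 = R.call_tir(matmul_physical, (lv0, B), (8, 8, 16, 16), dtype="float32")
        lv2 = R.call_tir(matmul_physical, (lv1, W), (8, 8, 16, 16), dtype="float32")
        C = R.call_tir(post_layout_convert, (lv2,), (128, 128), dtype="float32")
        return C
```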
Importantly, if we have a long sequence of chained matmuls, the final code will become:
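Concretely, for a chain of N matmuls the folded graph would look roughly like this (illustrative pseudo-graph, not runnable code):

```python
# A -> pre_layout_convert -> matmul_physical(W0) -> matmul_physical(W1)
#   -> ... -> matmul_physical(W_{N-1}) -> post_layout_convert -> C
#
# All intermediate values stay in the physical layout; only one
# conversion remains at each end of the chain.
```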
Discussion and Remarks
There are several advantages of this method.
The layout handling is useful for several use cases:
The same principle applies to padding of axes. Here are some additional examples.
Example Problem of layout padding
Step 0: PrimFunc transformation
Suppose the hardware requires the input to be padded to a multiple of 128; this can be expressed in the PrimFunc:
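The original PrimFunc is not reproduced here; below is a minimal sketch of how the padding requirement could be expressed, in the same pseudo-TVMScript style. The 100x100 logical shape, the function name, and the use of a matmul as the compute are assumptions.

```python
@T.prim_func
def matmul_padded(A: T.Buffer[(100, 100), "float32"],
                  W: T.Buffer[(100, 100), "float32"],
                  C: T.Buffer[(100, 100), "float32"]):
    AC = T.alloc_buffer([128, 128], "float32")
    WC = T.alloc_buffer([128, 128], "float32")
    CC = T.alloc_buffer([128, 128], "float32")
    # Pad both operands up to a multiple of 128; 0 is a semantically
    # safe padding value for matmul's reduction dimension.
    for i, j in T.grid(128, 128):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[i, j] = T.if_then_else(i < 100 and j < 100, A[i, j], T.float32(0))
    for i, j in T.grid(128, 128):
        with T.block():
            T.block_attr("preproc", "pad")
            WC[i, j] = T.if_then_else(i < 100 and j < 100, W[i, j], T.float32(0))
    # Compute on the padded space.
    for i, j, k in T.grid(128, 128, 128):
        with T.block("matmul"):
            with T.init():
                CC[i, j] = T.float32(0)
            CC[i, j] = CC[i, j] + AC[i, k] * WC[k, j]
    # Crop the result back to the logical shape.
    for i, j in T.grid(100, 100):
        with T.block():
            T.block_attr("postproc", ["crop", 0])
            C[i, j] = CC[i, j]
```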
Step 1: Lift padding and cropping to the graph
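The lifted program from the original post is not preserved; the following rough sketch shows the graph after the pad and crop stages are lifted. The variable names lv0...lv6 echo the folding discussion below (lv3/lv4 being the crop/pad pair, w0/w1 the weights), but the details are assumptions.

```python
@ir_module
class AfterLift:
    @R.func
    def main(x, w0, w1):
        lv0 = R.call_tir(pad, (x,), (128, 128), dtype="float32")
        lv1 = R.call_tir(pad, (w0,), (128, 128), dtype="float32")
        lv2 = R.call_tir(matmul_padded_compute, (lv0, lv1), (128, 128), dtype="float32")
        lv3 = R.call_tir(crop, (lv2,), (100, 100), dtype="float32")
        lv4 = R.call_tir(pad, (lv3,), (128, 128), dtype="float32")
        lv5 = R.call_tir(pad, (w1,), (128, 128), dtype="float32")
        lv6 = R.call_tir(matmul_padded_compute, (lv4, lv5), (128, 128), dtype="float32")
        out = R.call_tir(crop, (lv6,), (100, 100), dtype="float32")
        return out
```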
Step 2: Fold layout conversion
lv3 and lv4 are a pair of `crop` and `pad` with the same shape before padding. If the padding value in lv4 is the same as the value before cropping in lv3, or the padding value is `T.undef()`, they can be cancelled out. The conversion of `w0, w1` can also be folded at compile time; the result will be as follows.

This enforces that each transformation step maintains the original semantics in the TIR and Relax graph. If we are allowed to output the final result in physical shape, with some undefined values in the padding region, the final `post_layout_convert` can also be eliminated.

Padding on reduction dimensions
When padding is introduced on the reduction dimensions, there are requirements on the padding value so that `crop` and `pad` can be cancelled. Although the padding value should ensure semantic correctness (e.g. `pad_value = 0.0` for `conv`, `-inf` for `max_pool`), the values in the padding region of the output can still be arbitrary, as they will be cropped out anyway. In this case, we will need to insert hints about the padding value (which is usually an operator-specific property; for example, applying a conv filter to a padding region of value 0.0 produces an output of 0.0).

Here is an example of conv1d with `padding=2` on both sides of the input. Suppose the conv1d physically requires the input to be a multiple of 8 with the padding already explicitly inserted; after inserting padding and crop:
The assertion `Y[18:] == 0` is needed as a hint so that the next padding can be simplified. In the general case of multiple convolutions, this may also require non-local transformations, such as propagating the padding to the beginning.

cc @YuchenJin @Hzfengsy @jinhongyii @sunggg @junrushao1994 @tqchen