[RFC] Buffer Layout Padding #77
Conversation
This RFC introduces a method to specify padding to be applied as part of a buffer layout transformation, to be used when the desired layout does not evenly tile the buffer being transformed, and simplifications that can be performed based on these padded buffers. The motivating examples are primarily in the "Implementation options" section, which goes through several desired usages of the buffer padding, and how they can be automatically derived using the TIR primitives/transformations described in earlier sections.
Thanks @Lunderberg. As we start to handle more general workloads, it is indeed very helpful to have primitives that handle imperfect buffers via padding (and sometimes predication). Bringing in primitives along the direction of layout transform/padding would be a very positive step in that direction.
My main comment is on how we can achieve that goal. Specifically, it would be really nice if we could express these improvements as transformations on the IR and avoid changes to the IR where possible. Let us discuss whether we can explore alternative options that do not involve changes to the IR (in this case buffer semantics) and build on a sequence of primitives/annotations instead.
## TIR Changes

### Buffer Annotation of Padding Predicate/Constraint Pairs
I agree with the need to introduce primitives that handle padding, along with the transformation.
The main question is whether or not we should introduce an IR semantics change to enable such a use case. On one hand, introducing new IR semantics certainly makes TIR itself more expressive, but it also brings additional complexity for every primitive/pass that handles the IR. Special primitives may also be needed to handle the newly introduced semantics.
It would be great for us to explore such capabilities (of layout padding) without introducing additional complexity in the IR.
Back to our goal of introducing padding: it should be possible to have an explicit buffer transformation stage that copies the data into the target padded buffer (with predication), then run the computation on the padded values.
There are certainly some tradeoffs here, but decoupling the padding behavior as a separate stage of IR computation should allow us to reuse more primitives without having to specialize them for BufferConstraint.
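For concreteness, here is a minimal sketch of the decoupled approach described above, written in the same TVMScript style used elsewhere in this thread; the shapes, the 14-to-16 padding, and the doubling compute are made up purely for illustration:

```python
from tvm.script import tir as T

@T.prim_func
def padded_elementwise(A: T.Buffer[(14,), "int32"], B: T.Buffer[(14,), "int32"]):
    # Stage 1: explicit transformation stage that copies A into a padded
    # buffer, writing a known pad value under a predicate.
    A_pad = T.alloc_buffer([16], "int32")
    for i in T.serial(16):
        A_pad[i] = T.if_then_else(i < 14, A[i], 0, dtype="int32")

    # Stage 2: the computation runs on the padded buffer with no conditionals.
    B_pad = T.alloc_buffer([16], "int32")
    for i in T.serial(16):
        B_pad[i] = 2 * A_pad[i]

    # Stage 3: crop back to the original extent.
    for i in T.serial(14):
        B[i] = B_pad[i]
```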
+1. I agree that we are solving a challenging and important problem, which needs additional data structures and transformations. But as @tqchen mentioned, the IR data structure needs to be stable. Once it changes, we may need lots of effort to review all existing primitives and transformations for the small changes.
I agree we can enhance the IR semantics when "necessary", if we have no other way to go. Before that, let's think about it carefully to find alternate paths.
i agree it's helpful to think through alternatives here. could we consider some example transformations we may want to make (e.g. eliding or moving the operations which write to the padding), or pattern-matching on such operations and reducing them to hardware intrinsics (e.g. perhaps there is a way to tell the hardware how much padding to include when the value is always constant and a particular operation is in use)?
on the one hand, modeling the padding computation explicitly in TIR is a more logical reuse of existing TIR. on the other hand, it may be more expensive to match this and the compiler may be slower.
i'm not necessarily in favor of any one solution, but i think this is the sort of thing we should discuss to try and inform that decision.
@tqchen @Hzfengsy Thank you, and I definitely agree on minimizing the number of IR changes being made. (Also, phew, this ended up being a longer reply than I had expected, which probably means that whatever the result of this thread, the "Rationale and Alternatives" section should be updated.)
@areusch The example transformations are largely present in the "Implementation Options" section. The goal of that section was to describe different example transformations that we'd like to be able to make, and to ensure that they could be made using the functionality introduced earlier in the RFC. It wasn't until this morning that I realized that there should also be links in the other direction, pointing from the proposed IR changes to the motivating use case.
Below is the general rationale, with high-level implementations.
Starting by listing out the desired properties of an implementation.
- No changes to existing TIR data structures
- No additional meaning to existing TIR data structures
- Simplifications can use constraints from multiple buffers
- No ordering required between `transform_layout` and fuse/split/reorder. (Conditional statements that are removed using the buffer constraints are typically introduced by loop rewrites.)
- Can be used to describe out-of-bounds access (e.g. texture memory clamping on a GPU) that returns a default value.
- Only allocate memory when required or requested
Implementations considered:

- A. All buffer transformations introduce a new stage
  - Pro: No coordination required between different operators.
  - Con: Any producer/consumer interactions must be recognized by operator fusion/device planning.
  - Con: Cannot apply to primfunc inputs/outputs. (e.g. To de-duplicate operators that differ only by underlying layout, such as `topi.nn.conv2d_hwcn`, `topi.nn.conv2d_nchw`, `topi.nn.conv2d_NCHWc`, etc.)
  - Con: May introduce unnecessary data copies, if the constraint required by the consumer is already met.

- B. Perform one `transform_layout` at a time. For each one, simplify using provided constraints, do not store constraints afterward.
  - Con: Main downside is that it could only use the constraints of a single buffer at a time. This wouldn't be able to express simplifications that rely on the padding in multiple buffers (e.g. an elementwise operator).
  - Con: Requires loop rewriting to be done either inside `transform_layout` or prior to calling `transform_layout`.
  - Con: Can't be applied to use cases outside of layout transformations (e.g. texture memory clamping on a GPU), where simplifications could benefit from assumed constraints.

- C. Perform all `transform_layout` in a single function call, passing all layout transforms and padding constraints.
  - Pro: Simplifications may use constraints of all buffers being transformed.
  - Con: Requires changing the calling convention for layout transformations.
  - Con: Requires loop rewriting to be done either inside `transform_layout` or prior to calling `transform_layout`.
  - Con: Can't be applied to use cases outside of layout transformations (e.g. texture memory clamping on a GPU), where simplifications could benefit from assumed constraints.

- D. Express buffer constraints using the existing `AssertStmt`.

  In pseudocode, each consumer would have roughly the loop nest below. However, this would still need some way of indicating that the constraint should be removed when lowering, and should not produce any runtime assertions.

  ```python
  for indices in T.grid(*transform_shape):
      if padding_predicate(indices):
          T.Assert(buf[indices] == pad_value(indices))
  ```

  - Pro: No change to TIR data structures
  - Pro: No change required to the calling convention for layout transformations.
  - Pro: Simplifications may use constraints of all buffers being transformed.
  - Pro: Can be applied to use cases outside of layout transformations (e.g. texture memory clamping on a GPU), where simplifications could benefit from assumed constraints.
  - Pro: No ordering between loop/layout transforms, because the constraints can be determined from the TIR.
  - Con: Additional meaning attached to existing TIR data structures.
  - Con: Can only describe a fixed number of assertions, wouldn't be able to express a default value for all out-of-bounds reads.

- E. Express buffer constraints as a field in `PrimFuncNode::attrs`
  - Con: Passes that replace `Buffer` objects must be aware of this attribute, in order to update the `Buffer` object stored in it.

- F. Express buffer constraints as a new member variable in `Buffer`
  - Con: Changes to TIR data structures
  - Pro: No change required to the calling convention for layout transformations.
  - Pro: Simplifications may use constraints of all buffers being transformed.
  - Pro: Can be applied to use cases outside of layout transformations (e.g. texture memory clamping on a GPU), where simplifications could benefit from assumed constraints.
  - Pro: Can rewrite the loop structure later, using existing constraints.
Where the goals and implementations in the table below are:

- Goal 1: No changes to existing TIR data structures
- Goal 2: No additional meaning to existing TIR data structures
- Goal 3: Simplifications can use constraints from multiple buffers
- Goal 4: No ordering required between `transform_layout` and fuse/split/reorder.
- Goal 5: Can be used to describe out-of-bounds access (e.g. texture memory clamping on a GPU) that returns a default value.
- Goal 6: Only allocate memory when required or requested

- Impl A: All buffer transformations introduce a new stage
- Impl B: Perform one `transform_layout` at a time. For each one, simplify using provided constraints, do not store constraints afterward.
- Impl C: Perform all `transform_layout` in a single function call, passing all layout transforms and padding constraints.
- Impl D: Express buffer constraints using the existing `AssertStmt`
- Impl E: Express buffer constraints as a field in `PrimFuncNode::attrs`
- Impl F: Express buffer constraints as a new member variable in `Buffer`
| | Goal 1 | Goal 2 | Goal 3 | Goal 4 | Goal 5 | Goal 6 |
|---|---|---|---|---|---|---|
| Impl A | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ |
| Impl B | ✔️ | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| Impl C | ✔️ | ✔️ | ✔️ | ❌ | ❌ | ✔️ |
| Impl D | ✔️ | ❌ | ✔️ | ✔️ | ❌ | ✔️ |
| Impl E | ✔️ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ |
| Impl F | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
The implementations that would satisfy the largest number of the desired goals would be adding the member variable `BufferNode::constraints`, or adding a field to `PrimFuncNode::attrs` that holds the constraints. Between the two, I lean toward having it as an explicit member variable, so that incorrect usage appears as a compilation error when compiling TVM, but would find either implementation acceptable.
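To make the two options concrete, the record that either the attribute or the member variable would need to carry is essentially the predicate/value pair quoted from the RFC later in this thread; a rough Python illustration of that pairing (the class and field names here are only for illustration, not the proposed C++ API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class BufferConstraintSketch:
    """Illustrative stand-in for a per-buffer padding constraint.

    predicate: an expression over the buffer indices that selects the
        padding region (e.g. "i >= 14" for a 14-element buffer padded to 16).
    value: the known contents of that region, or None when the padding
        holds arbitrary/undef values.
    """
    predicate: str
    value: Optional[str] = None

# Example: the padding of a 16-element physical buffer backing a
# 14-element logical buffer is known to be zero.
pad_constraint = BufferConstraintSketch(predicate="i >= 14", value="0")
```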
Thanks @Lunderberg for the dissected discussion, this is helpful.
Besides Goal 1, there is one additional implied goal:

- Goal 7: The composability of primitives with buffers that carry layout constraints. Specifically, how composable are the existing and new primitives (such as split/reorder/tensorization, reduction factorization) with the buffer layout constraints.

When building abstractions, we are actually trying to strike a balance between two things: simplicity/composability and the set of things we can support.

- It is quite natural that a more complicated impl would hit more marks initially.
- On the other hand, there is always a consideration of added complexity and how composable our additions are with existing constructs.

In our case, we are facing an N * M problem, where N is the number of primitives and M is the number of possible IR variations (like layout constraints) we introduce to the IR. An additional field in the IR effectively means we either have to (a) introduce specific code paths to handle layout constraints, or (b) generalize all relevant primitives to take that into account. The N * M problem will grow as N and M increase.
To manage our complexity, our current rationale is to keep M as stable as possible and grow N in ways that compose with each other.
It is also useful to come back to the high-level goal beyond these goals for a single function. Our high-level goal is to enable effective end-to-end models under a good native layout (which involves padding and layout transformation). It would actually be really nice to have an example at the e2e level to show how the set of transformations affects our optimizations.
Among the existing goals listed, Goal 6 is certainly a very important one. Goal 3 is primarily an implementation difference in terms of different ways of building pattern matching. Goal 4 is not necessarily a need, as many optimizations actually benefit from reduced complexity (e.g. tensorization in physical memory).
Goal 6 is an important one that indeed touches the high-level (e2e) goal itself. Along that direction, a smart variant of Impl A (by interacting with the graph) would actually enable a simpler realization of Goal 6 (which is important) by lifting the transformations of inputs/outputs out, and then canceling them out between operators, while preserving the information.
I think in terms of the general idea of reusability, A0 and A1 are not that different. Notably, A0 brings the reusability in the sense of the utility simplifications wrt folding etc., while A1 brings the simplifications in the form of TIR-level rewrites.
Note that the compute defs in A0 can remain the same; the main difference is how the preferable schedules are derived, as informed by the hardware target, as they are scheduled.
From a pure impl pov, considering the e2e goal, having the padding sorted out at the graph level actually still simplifies the scheduling layer.
Note that A0 still does not preclude us from doing constraint matching; one can view that as inserting a pre-proc stage that "must simplify". At the high level it is only a representation difference (of putting it in the buffer decl vs stages).
Under certain scenarios we could indeed consider putting some constraints at the interface level. That would need some more thought on semantics, how they would interact with the graph, and structural complexity (whether padding is something that is worth the IR complexity).
For most e2e goals, perhaps handling padding at the graph level is not a bad way to reduce that part of the complexity.
I was trying to work through what the transformation hoisting would look like at the graph level, and I think it runs into the same requirement of tracking the buffer constraints. The derivation uses a toy model, consisting of the two sequential 1-d convolutions from this section. Layout transformations are introduced in read/write caches, and those layout transformations are hoisted into separate functions, so that they can be manipulated at the graph level. The adjacent transformations are then fused and simplified.
(The full derivation is in this gist, as including it here made the comment unreadably long.)
The key result from the derivation is that, for a padded transformation `f` that maps from the logical layout to the physical layout, `f_inv(f(X))` can be simplified to `X` at the graph level, but `f(f_inv(X))` cannot. Instead of simplifying into a memcpy, the transformations result in the following:
```python
@T.prim_func
def fused_inv_transform_Y_transform_Y(
    Y_write_cache: T.Buffer[(3, 8), "float32"],
    Y_read_cache: T.Buffer[(3, 8), "float32"],
):
    for io, ii in T.grid(3, 8):
        i = 8 * io + ii - 2
        if 0 <= i < 18:
            Y_read_cache[io, ii] = Y_write_cache[io, ii]
        else:
            Y_read_cache[io, ii] = 0.0
```
This expression could be simplified by using the original constraint on `Y_write_cache` provided in the layout transformation, but reconstructing that constraint couldn't be done by local analysis of any single function.
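For illustration, if the constraint that `Y_write_cache`'s padding holds zeros were still visible to the simplifier, both branches would write the same value and the fused function should collapse into a plain copy; a sketch of that expected result (not taken from the gist):

```python
@T.prim_func
def fused_inv_transform_Y_transform_Y(
    Y_write_cache: T.Buffer[(3, 8), "float32"],
    Y_read_cache: T.Buffer[(3, 8), "float32"],
):
    # With the padding of Y_write_cache known to be zero, the else branch
    # above writes the same value as the then branch, so the conditional
    # folds away and only the copy remains.
    for io, ii in T.grid(3, 8):
        Y_read_cache[io, ii] = Y_write_cache[io, ii]
```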
@Lunderberg That's true. If the reduction dimension is padded, we will need to insert a hint in the graph to assert that it was previously padded by 0. From the graph rewriting pov, we can also see this as a transformation done at the graph level (it doesn't rely on arithmetic simplifications).
Example:

```
X: R.Tensor[16]
F: R.Const[3]
Y: R.Tensor[18] = conv1d(X, F, pad=2)
Z: R.Tensor[20] = conv1d(Y, F, pad=2)
```

Inserting padding and crop:

```
X: R.Tensor[16]
F: R.Const[3]
X_pad = pad(X, before=2, after=6)
Y = conv1d(X_pad, F, pad=0)
assert(Y[18:] == 0)
Y_crop = crop(Y[0:18])
Y_crop_pad = pad(Y_crop, before=2, after=4)
Z = conv1d(Y_crop_pad, F, pad=0)
Z_crop = crop(Z[0:20])
```

Then we can propagate the padding information and combine.
Of course, for the sake of discussion the example is limited to two convolutions. With multiple (N) back-to-back contractions with padded transformations, handling this at the graph level can require similar non-local information/hints across the sequence of operators, out to Nth order.
Thanks for the discussions so far :) I think we all agree that having an additional information variant in the buffer interface would result in powerful expressiveness, just as, in a high-level language, having `while` loops in addition to `for` loops results in more power.
As stated in the last comment, we could indeed consider putting some constraints at the interface level. That would need some more thought on semantics, how they would interact with the graph, and structural complexity (whether padding is something that is worth the IR complexity).
On the other hand, I would still encourage us to think about whether such complexity is worthwhile only for padding. Handling padding as a non-local back-to-back transformation is relatively easy and still achieves the goal.
corresponding to the transformation padding. The `value` field is
defined by the user input in `pad_value`

### New TIR Op, `tir::builtin::arbitrary`
Could we name it "undef"? Do we need to specify the behavior when the arbitrary value is involved in computations, like LLVM undef and poison values?
I'd been avoiding "undefined", as that could cause confusion with the C++ notion of "undefined behavior". Where any existence of undefined behavior in C++ causes the entire program to have undefined behavior, use of `T.arbitrary` would only propagate to any expression that uses `T.arbitrary` as an input.
That said, the LLVM `undef` maps quite well to the concept I was thinking of, so I agree with the name change. As far as I can tell, the main difference between the proposed `tir::builtin::undef` and LLVM's `undef` is that LLVM's is tracked down to individual bits, whereas `tir::builtin::undef` would propagate to entire values.
I've changed the name, added a link in the Prior Art section for LLVM's `undef`, and updated this section to define the computations that use `undef`.
### New Primitive - Reorder Loops According to Buffer

By default in S-TIR, `transform_layout` modifies the underlying layout
of a buffer, but does not re-order loops that iterate over the buffer.
For S-TIR, we actually have three sorts of "layouts":

- the loop iter layout, controlled by `split`/`fuse`/`reorder`
- the block iter binding
- the buffer layout

I can imagine that one may want to simultaneously transform all of them, or combinations of the three. For example, apache/tvm#11485 transforms both the loop and the block binding.
What if we create uniformly designed primitives or uniform guides for all useful combinations, and then create dedicated APIs like `sequential_buffer_access` on top of the uniform interfaces?
Thank you, and there was a similar comment from @Hzfengsy. I've updated this section to instead propose a utility function to generate the appropriate split/fuse/reorder, rather than being itself a primitive. Taking another look at apache/tvm#11485 and the block iter bindings, I think the utility might be as simple as applying `transform_block_layout` with a mapping defined based on the block iter bindings of the spatial dimensions.
The main uniform usage coming to mind would be applying the transformation to all three layouts uniformly. Though, that would only be well-defined if all three are already uniform, so that wouldn't catch cases where it changes from scatter to gather or vice versa. I'll think it over a bit more, thank you!
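As a sketch of how small that utility could be, assuming the existing `transform_block_layout` primitive and a block whose spatial iters correspond to the transformed buffer's indices (the helper name and the index map below are hypothetical):

```python
import tvm

def reorder_loops_to_match_buffer(sch: tvm.tir.Schedule, block, index_map):
    # Apply the same index map to the block's iteration space that was
    # applied to the buffer, so the loop nest visits the transformed
    # layout sequentially.  This is a thin wrapper over an existing primitive.
    sch.transform_block_layout(block, index_map)

# Hypothetical usage on a block with spatial iters (i, j):
# sch = tvm.tir.Schedule(mod)
# block = sch.get_block("compute")
# reorder_loops_to_match_buffer(sch, block, lambda i, j: (i // 4, j, i % 4))
```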
# Original function
@T.prim_func
def func(A: T.Buffer[(16,), "int32"]):
    with T.block('compute'):
It is not a typical S-TIR func, which is expected to be:

```python
@T.prim_func
def func(A: T.Buffer[(16,), "int32"]):
    for i in T.serial(16):
        with T.block('compute'):
            vi = T.axis.S(16, i)
            A[vi] = vi
```

Here are the mistakes that appear in most cases in this RFC:

- Usually loops are outside the block, and one block only contains a single stmt.
- The `if-else` branch is represented by `T.where` if it's inside a block.
- Please explicitly write block vars, since they are important during scheduling.
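For reference, a small sketch of an `if`-guarded statement rewritten to follow these conventions, with the loop outside the block, an explicit block var, and the guard expressed through `T.where` (the extent and predicate are illustrative):

```python
from tvm.script import tir as T

@T.prim_func
def func(A: T.Buffer[(16,), "int32"]):
    for i in T.serial(16):
        with T.block("compute"):
            vi = T.axis.spatial(16, i)
            # The conditional lives in the block predicate, not in an if/else.
            T.where(i < 14)
            A[vi] = vi
```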
Thank you, and I can update the examples accordingly, though it will be a bit before I have the availability to do so. (I had been primarily focusing on the algebraic manipulations during the initial drafting.)
This transformation is similar to what can be done using
split/fuse/reorder, but has two key differences. First, it presents a
simpler user experience, as a transformed buffer can be accessed
I agree that `sequential_buffer_access` is beneficial. But I'd like to make it a sugar rather than a new primitive. i.e., when users call `sch.sequential_buffer_access`, it will implicitly call a set of `split` and `reorder`, and the schedule trace will be:

```python
sch.split(...)
sch.reorder(...)
```

The reasons are:

- A `Primitive` is expected to be irreplaceable. If one transformation can be represented by a set of existing transformations, we won't create a new primitive (but sugar is fine).
- Reusing the primitives split/fuse/reorder may reduce the maintenance cost. We don't need to fix the same bug (if it exists) twice.
Thank you, and I had been going back and forth on that while drafting. I had been considering the primitives as defining not just the extent of a search space, but also how many steps it would take for an optimizer to identify a transformation in the search space.
I've updated this section from an independent schedule primitive to a utility function.
Following @wrongtest's suggestion at apache#77 (comment).
Following suggestion from @Hzfengsy at apache#77 (comment)
this is on the agenda for tomorrow's community meeting. Perhaps we could discuss in higher bandwidth there?
A[i] = 0.0
```

### New Transform - Hoist Expression
@wrongtest I remember you also proposed imperative loop partitioning in https://discuss.tvm.apache.org/t/introducing-ty-nnp-backend-with-end2end-tensorir-integration/11807. Could you comment on how this (and the other related utilities/primitives in this RFC) relates to the one you proposed?
They are different alternatives, or can be combined on certain workloads. There is a discussion on a performance issue with matmul when we just change the dimension 128 -> 127: https://discuss.tvm.apache.org/t/te-vectorize-do-we-have-plan-to-support-vectorize-for-non-divisible-split/12469
I think it might be a good working example. Below is what the user gets with the loop split `j`: 127 -> (4, 32)

```python
for i in range(127):
    for k in range(127):
        for j.outer in range(4):
            for j.inner in T.vectorized(32):
                if T.likely(j.outer * 32 + j.inner < 127, dtype="bool"):
                    C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
```

The issue is that a complex condition has to be introduced to maintain the program semantics, which hurts performance; generally we cannot vectorize a program with control flow.
Now I understand we have different alternatives to handle this:

- Loop partition

  We can already annotate the loop var with a hint using non-imperative loop partition:

  ```python
  for j.outer in range(4, annotations={"pragma_loop_partition_hint": 1}):
  ```

  After the `LoopPartition` pass (and simplify) it becomes:

  ```python
  for i in range(127):
      for k in range(127):
          # j.outer in [0, 3)
          for j.outer in range(3):
              for j.inner in T.vectorized(32):
                  # condition is const true, optimized out
                  C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
          # j.outer in [3, 4), optimized out
          for j.inner in T.vectorized(31):
              # condition becomes j.inner < 31, hoisted with loop
              C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
  ```

  Then the condition branch gets eliminated on the different loop parts, which becomes more friendly to performance optimizations like vectorization. For "imperative" partition, it just proposes that we can partition at the schedule phase when one wants to schedule the different parts differently, such as giving them different vectorization widths.

- Loop padding

  With the current RFC, I understand we can pad `C` and `B`'s innermost dimension to 128 and drop the condition directly somehow. Then it directly becomes (IIUC, we may also insert some "arbitrary" value filling code on the edges and optimize it out then?):

  ```python
  for i in range(127):
      for k in range(127):
          for j.outer in range(4):
              for j.inner in T.vectorized(32):
                  C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
  ```

In this particular case, I believe padding is the better choice, since we can get very neat code with minimal over-computation. And we can also utilize the padding trick for the different loop parts in alternative (1).
> (IIUC, we may also insert some "arbitrary" value filling code on edges and optimize them out then?)

Yup, the loop that writes `T.undef()` into the padding values would be present as an intermediate. This allows `RemoveNoOp` to be much more general, since it only needs to look for two sequential writes to the same indices to conclude that the first is a no-op. As a result, a matching `else_case` would be a no-op, and therefore safe to insert without impacting the final result.
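A small before/after sketch of that reasoning, using the running 14-element example and the proposed `T.undef()` (extents and names are illustrative):

```python
# Before: the padding region is first filled with T.undef(), then a later
# else_case writes a concrete value to the same indices.
for io, ii in T.grid(4, 4):
    if 4 * io + ii >= 14:
        AC[io, ii] = T.undef()            # first write: arbitrary padding
for io, ii in T.grid(4, 4):
    if 4 * io + ii < 14:
        AC[io, ii] = A[4 * io + ii]
    else:
        AC[io, ii] = 0                    # second write to the same indices

# After a generalized RemoveNoOp: the first loop only produces values that
# are overwritten before being read, so it can be dropped entirely.
for io, ii in T.grid(4, 4):
    if 4 * io + ii < 14:
        AC[io, ii] = A[4 * io + ii]
    else:
        AC[io, ii] = 0
```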
we discussed this at the June 6 community meeting. a significant chunk of the meeting was spent presenting the RFC, and we had about 15 minutes of discussion at the end. i think there is more to be discussed here. if we'd like to discuss in high-bandwidth, we can bring this back up at future community meetings. here are notes:
# Reference-level explanation
[reference-level-explanation]: #reference-level-explanation

## TIR Changes
@tqchen replied yesterday:

> ...
> Under certain scenarios we could indeed consider put some constraints at the interface level. That would need some more thoughts on semantics, how would they interact with graph, and structural complexity (whether padding is something that worth the IR complexity).
> ...

@tqchen Our hope is to have these thoughts and discussions in this RFC, and we welcome your and others' analysis on the semantics and the specific complexity it would introduce.
Thanks for all the great discussions! It is exciting that we will have a more powerful ability to handle things like padding and imperfect tiles. Since our team relies on the S-TIR code path, we are extremely interested in the story on S-TIR. I would appreciate it if we could have some details on S-TIR padding. I would like to use a [127, 127, 127] matmul to depict my questions :)

```python
@T.prim_func
def matmul(A: T.Buffer[(127, 127), "float32"], B: T.Buffer[(127, 127), "float32"], C: T.Buffer[(127, 127), "float32"]):
    for i, j, k in T.grid(127, 127, 127):
        with T.block("compute"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] += A[vi, vk] * B[vk, vj]
```

In the current S-TIR state, we can construct the padded loop and buffer using existing primitives with the "split and then fuse" trick:

```python
s = tvm.tir.Schedule(matmul)
blk = s.get_block("compute")
i, j, k = s.get_loops(blk)
s.fuse(*s.split(i, factors=[4, 32]))
s.fuse(*s.split(j, factors=[4, 32]))
s.fuse(*s.split(k, factors=[4, 32]))
s.transform_layout(blk, "A", lambda i, k: ((i // 32) * 32 + i % 32, (k // 32) * 32 + k % 32))
s.transform_layout(blk, "B", lambda k, j: ((k // 32) * 32 + k % 32, (j // 32) * 32 + j % 32))
s.transform_layout(blk, "C", lambda i, j: ((i // 32) * 32 + i % 32, (j // 32) * 32 + j % 32))
```

We will get (if simplified):

```python
@T.prim_func
def func(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]):
    for i_0_i_1_fused, j_0_j_1_fused, k_0_k_1_fused in T.grid(128, 128, 128):
        with T.block("compute"):
            vi = T.axis.spatial(127, i_0_i_1_fused)
            vj = T.axis.spatial(127, j_0_j_1_fused)
            vk = T.axis.reduce(127, k_0_k_1_fused)
            T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)
            T.reads(A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```

Then the only thing left is the condition for padding:
  PrimExpr predicate;
  Optional<PrimExpr> value
};
```
The proposed construction seems to have a very interesting relation with the S-TIR block design!

- For `Block`
  - With `T.axis.*` and `T.predicate`, it specifies the mapping from loop space to computation instance space (block iter space).
  - With `T.reads`, `T.writes`, it specifies the mapping from block iter space to buffer access space.
- For `BufferConstraint`
  - It additionally specifies the buffer access space and the behavior outside the iter space, before padding.
Thanks for the discussion. To provide more context, the A0 approach we discussed is the TIR-Relax layout rewriting tlc-pack/relax#162 (the general idea is to lift such transformations in TIR scheduling into the graph, and then cancel out redundant intermediate transformations by either proving that fusing the pair of post-compute and pre-compute transformations produces an identity TIR function, or using high-level operator semantics). I think this is very similar to the graph-level solution mentioned by @wrongtest.
Thanks for sharing the contextual pointers for the community @vinx13. Agreed that the approaches discussed are both valid. I would actually like to argue the stronger point that they are complementary, and only appear to be contrary because we are considering too narrow a scope. It can be helpful to share an overview of common handlings of layout transformations in ML compilers. Most of my arguments against A0 (local cancellation in the graph only) being sufficient stem from prior experience in graph layout optimizations. The evolution of the graph approaches for optimizing layouts that I've seen followed the below trajectory:

Local-only cancellation tends to fail in models which still have simple data flow, but more variety in the sequence of operators, each with different valid implementations. Consider,

In this case

Then apply the method discussed in A0 and do local cancellation. The above method works well for models with relatively simple data flow, but for models with more branching the method has limitations. A simple consideration is sinking a transform through an operation with multiple inputs. The process of doing so requires materialization of the inverse transform on the other operands. For the sake of simplicity consider matrix multiplication:

For inference graphs I've seen this approach work well. But the approach is still greedy, and suboptimal choices can occur. For training graphs this approach works less well due to the data flow complexity involved with branching from the forward to the backward graph and the optimizer's in-place updates of weights. I omit a specific example in this case for brevity, but encourage a review of the graphs from @t-vi's application of TVM to PyTorch training for BERT and the long chains of transpose and reshapes that occur within the forward and backward multi-head attention layers [4].

An example implementation I have seen included layout sources (e.g. operators like conv2d on an NPU with distinct layout constraints) and layout sinks (e.g. operations which involve data movement by DMA engines or in-memory compute which allow zero-cost data layout rearrangement during store). A constraint solver in this case flows layout constraints from sources toward sinks that can absorb aggregated/merged layout transform constraints.

Coming back to the present discussion, I believe our design should be focused on ensuring that one or more of the non-local approaches discussed above in 2-4 are achievable. Any of these cases requires the following components:

- C0) The ability to track constraints on a buffer.
- C1) The ability to roundtrip between an IR representation and the compact producer/consumer constraint representations.
- C2) The ability to merge/fold constraints - flowing is just merging a constraint with an unconstrained one.

Even for the pure local (back-to-back) case discussed in A0, components C1 and C2 are helpful, with the caveat that the inferred constraints from the IR only exist within the local context of a single producer-consumer pair in a pass. Thus both A0 and A1 can benefit from these components, and the delta that exists between A0 and A1 is clearer:
Thanks @csullivan for providing the overview. I agree that the non-local approaches 2-4 are necessary. From the examples in this RFC I can also see how the components C0-C2 can be used to support these non-local approaches. C0 + C1 allow specifying the constraints during scheduling and propagating them back to the graph. Besides them, I would also like to mention another component:

It seems to me that C0, C1, C3 are actually choices of implementation, as there are multiple ways, each requiring a combination of them, to achieve the goal of constraint flowing.

Back to the discussion of this RFC, I think the main comments about the proposed methods are the IR changes required (which may have greater impact on the existing TIR and scheduling), and the complexity involved in using the new schedule primitives to reach the final desired state. From my understanding, the intention of these new primitives is to allow arithmetic simplification to perform graph-rewriting-like over-computation. If this can be achieved as a graph-level rewriting rule (perhaps simpler, as it doesn't need arithmetic manipulations), personally I think that would still be preferred for better maintainability. Also, I'd like to mention that modeling such rewriting in the graph doesn't necessarily tie the TIR operator to a specific graph IR implementation. As we are moving to S-TIR scheduling, it is easy to apply some preprocessing steps to derive the PrimFunc in a specific layout from a standard one.

Finally, I would like to encourage us to focus on the e2e goals. It seems the current approaches, whether implemented as A0 or A1 at the graph level, should suffice for the use cases in the inference graph. Though the training graph is probably not an immediate need, if we would like to consider its use cases, having some concrete examples with desired results can probably guide us to make a better decision.
Adding some additional discussion with @csullivan. We agree that:

Right now we have some healthy discussions about ways to encode layout and padding decisions. Some of my thoughts:

Introducing changes to TIR would need some additional thought and deserves extra consideration, due to the N*M complexity (where N is the TIR possibilities and M is the number of primitives to be supported) that needs to be handled in the implementation (by backend implementers and primitive implementers).

Right now it is possible to do non-local constraint rewriting/flowing as part of a graph pass. Note that while E1 is indeed less "compact" on one hand, we can use it to reconstruct the desirable compact data structure (something like BufferConstraint that represents the layout mapping) that we can use to flow the decisions across the graph nodes during the pass. Note that initially such a data structure does not need to live beyond the life of a pass, because it can be reconstructed at any time from the other representation. This makes the decision on such a data structure less critical, but it can still be used to solve the problems we face.

Starting from the graph level allows us to capture learnings, then use some e2e goals to make an informed decision on a TIR-level change later if needed.
This was part of the design consideration, to minimize the impact of the proposed changes to primitives, lowering transformations, and backends.
I definitely agree that graph-level transforms are where the layouts and constraints should be decided. The
I'm still a bit confused with this approach, specifically how one would avoid having a separate compute definition for each workload on a new target (Initially brought up by @csullivan here.) In my mind, if I'm going to compose a layout transformation stage, it would need to be followed by a compute stage that takes a transformed layout as input. So rather than having a single conv2d that can be generalized over layouts, each transformed layout would still need to have a compute stage for it.
How would this be represented while optimizing the performance of a subgraph? My concern would be how to express the non-local constraints while keeping a small search space for optimization.
Is there an existing annotation to indicate that a stage should be removed entirely during lowering? That might be an effective way to allow more general usage by annotating a stage that can be assumed to have been performed prior to the subgraph. This would be a way to express the second option of an extra transformation stage, while still providing enough information to remove the transformation stage during lowering.
Indeed, it is important to avoid having a separate compute definition for each workload on a new target. In this particular case, all computation definitions would start with the original layout. Then there is a "schedule transformation" like transform_layout which will generate the new stage as part of the scheduling process. The particular stage can be marked, and it contains effectively the same information as BufferConstraint, except that it does not introduce new data structures. During global layout reflowing, such information can be used to guide the reflowing to reconstruct a data structure like BufferConstraint.

Ideally we should not introduce an annotation to indicate that a stage should be removed, as that breaks the interface of the code itself (ideally the computation should remain the same). However, we can hint to the compiler that this particular stage is a layout transformation that should be lifted and resolved through the global constraint reflowing. Additionally, such an annotation can be used to guide benchmarking, such that the overall tuning should only look at the non-rewriting part (and we can leverage the transform block to generate input examples correctly). As a high-level summary, the main message is to allow enough info in the TIR (as part of the transform block) such that we can reconstruct a BufferConstraint-like data structure. This also helps in cases where there are other graph-level layout rewritings (e.g. transpose) that can be fused with those additional transformation stages.
Thank you, and that is roughly how I'm seeing it as well: everything starts with the base compute definition and is modified from there. If I understand correctly, the main differences are below.

So long as the constraints can be statically searched for, this approach makes sense to me. I would be more concerned about adding additional semantics to existing nodes, such as an `AttrStmt` node, since it then requires passes to be aware not only of the existence of the constraint, but also that it must be reconstructed from the existing data structure. This approach would make it much more difficult for a static analysis tool to identify locations where the constraints must be updated.

As a way to potentially find a way forward, what if we start by implementing pad values only for buffers that are allocated internally to a function? This would be allowed behavior under both Option A and Option B, and would help determine how difficult reconstruction of the constraints would be from the transformation block without any additional annotation. This could help motivate whether additional annotations are necessary, regardless of whether they are stored alongside the Buffer itself or in a separate attribute/annotation.
It doesn't add additional semantics; the computation semantics stay the same, it is a hint to the graph compiler. Here is an example using
My apologies, I had meant the semantics of a node from the perspective of a TIR transformation, not the semantics from the perspective of the computation being described. For a TIR transformation, if an object is replaced, whatever attributes describe that object must be updated to refer to the new object. So if constraints are added to the block annotation, I had been thinking of that as a change to the semantics of the
Indeed, if a buffer is used in an annotation value, that will change the semantics of a node; however, there are different ways to represent this, as long as it can be reconstructed later. For example, we may introduce an explicit cache stage to add the padding, and mark this block for later processing.
Wouldn't that require a "remove entirely" annotation that was suggested against here? I could see how we could mark a transformation to be hoisted out later, but when some simplifications require the constraint to be expressed in the producer, and others in the consumer, exposing it to both
Writing out some of my thoughts, to see if there's a way to express the constraints while only using existing TIR features. The main goals would be as follows.
Next, working through various options for how the constraints could be stored. In the examples below, sketching out how these would apply to the element-wise operation which starts as below.

```python
@T.prim_func
def func(A: T.Buffer[(14,), "int32"], B: T.Buffer[14, "int32"]):
    for i in T.serial(14):
        B[i] = 2 * A[i]
```
@tqchen may clarify. I think it's suggesting marking and lifting the stage to the graph and doing global flowing, instead of removing it (though from the perspective of the subgraph (PrimFunc) it is removed from the PrimFunc).
Added some examples to build on top of @Lunderberg's example.

### Transformation

The main differences between annotation and special handling are:

### Step 0: Produce temp stages with annotation

The transformation produces temporary buffers (AC and BC), where the relation between those data and A, B is recorded in two blocks (preproc and postproc). Note that these additional annotations are hints for the compiler to perform future optimizations (e.g. to lift them out or cancel them). Our eventual goal could be to directly reason about those properties from the code, but annotations provide a first shortcut.

```python
@T.prim_func
def grow(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    AC = T.alloc_buffer([4, 4], "int32")
    BC = T.alloc_buffer([4, 4], "int32")

    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)

    for i, j in T.grid(4, 4):
        BC[i, j] = 2 * AC[i, j]

    for io, ii in T.grid(14):
        with T.block():
            # hint that this is a cropping operation,
            # where we know that the remaining part in B is 0
            # Additionally, the remaining uncovered values
            # are assumed to be 0, if not provided then no assumptions are made
            T.block_attr("postproc", ["crop", 0])
            B[io, ii] = BC[4 * io + ii]


@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for i in T.grid(14):
        B[i] = A[i] + 1


@R.func
def main(A: T.Tensor[14, "int32"]):
    lv0 = call_tir(grow, [A], (14))
    # an intermediate stage to show non-local reflowing
    lv1 = call_tir(addone, [lv0], (14))
    lv2 = call_tir(grow, [lv1], (14))
    ...
```

Note that the special crop annotation comes with an

### Step 1: Reconstruct constraint at TIR-Graph level

By looking at the primfunc, we know that there is a desire to split out the preproc and postproc stages to the graph. Although it is totally fine for the compiler to choose not to do so (and it is still a valid program), let us say we choose to lift them out:

```python
@T.prim_func
def grow_packed(AC: T.Buffer[[4, 4], "int32"], BC: T.Buffer[[4, 4], "int32"]):
    for i, j in T.grid(4, 4):
        BC[i, j] = 2 * AC[i, j]


@T.prim_func
def pad(A: T.Buffer[14, "int32"], AC: T.Buffer[[14, 14], "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)


@T.prim_func
def crop_with_pad_assume(BC: T.Buffer[[4, 4], "int32"], B: T.Buffer[14, "int32"]):
    # Note that this crop carries a pad assertion (of the other values of BC)
    for io, ii in T.grid(14):
        with T.block():
            T.block_attr("postproc", ["crop", 0])
            B[io, ii] = BC[4 * io + ii]


@R.func
def main(A: T.Tensor[14, "int32"]):
    lv0 = call_tir(pad, (4, 4), A)
    lv1 = call_tir(grow, [lv0], (4, 4))
    # These are two things that we want to use for global format reflowing
    lv2 = call_tir(crop_with_pad_assume, [lv1], (14))
    lv3 = call_tir(addone, [lv2], (14))
    lv4 = call_tir(pad, [lv2], (4, 4))
    lv4 = call_tir(grow, [lv3], (4, 4))
    lv5 = call_tir(crop_with_pad_assume, [lv4], (14))
```

### Step 2: Global Reflowing of layouts

Now as a last step, let us say we will do global reflowing.

### Discussion

There are a few key properties that are really desirable here:

Talking about "constraints", it is also useful to talk about categories of them; roughly, we can divide them into three categories.

All three types of constraints can be helpful. In our particular case,
I like this breakdown, and agree. In this categorization, what I've been calling "constraints" would be "assumptions". Double-checking in

For usage of assumptions, I think the key would be to insert an assumption whenever the information that could otherwise prove it is hoisted out of the PrimFunc. That would provide non-local information that could be used by the PrimFunc to allow local simplifications.
I don't think we can make this strong of a statement, as it would also forbid fusing operators together or hoisting a stage out of a PrimFunc. In both cases, the signature of the resulting PrimFunc may be different than it was before. This shows up in the example, as the interface of

As a slightly less general statement, I would say that transformations of a PrimFunc in isolation may not change the PrimFunc's interface. So an optimization search to improve the performance of a single subgraph may not change the layout of its own arguments, nor may it change assumptions of what is present in the padding, as those would change its interface. However, a graph-level transform would be allowed to fuse subgraphs, to hoist stages out of a PrimFunc, to alter the layout of a PrimFunc's input, or to alter the assumptions provided about the inputs. In general, a PrimFunc's interface could only be changed when calls into the PrimFunc are also modified to remain compatible.

Is there a better term than "scheduling primitive" to describe layout transformations that impact input/output buffers? I think the difference is between context-independent transformations that may be performed on a PrimFunc without changing, as opposed to context-dependent transformations that may only be performed as part of a graph-level transformation.
Would this handle cases where there are multiple different options for how an operator could be implemented? Otherwise, I'm not sure how this would handle cases where multiple different sets of layouts/constraints could be inferred from different TIR-level schedules of the same operator. As examples, the drop-down has 6 different implementations of

```python
# Implementation 1, no preproc/postproc are present.
#
# No hoistable layout transformations.  Could be fused with a layout
# transformation, but doesn't otherwise provide any constraints.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for i in T.serial(14):
        with T.block("compute"):
            B[i] = A[i] + 1


# Implementation 2, pad input/output, but never access the padding of
# either input or output.
#
# In back-propagation of constraints, the T.undef() that is cropped
# from BC could be narrowed to a known value provided from the
# successor.  However, AC's padding is never written to, so could
# propagate T.undef() back to preceding function.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            if 4 * io + ii < 14:
                AC[io, ii] = A[4 * io + ii]

    for i in T.serial(14):
        with T.block("compute"):
            BC[i // 4, i % 4] = AC[i // 4, i % 4] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", T.undef()])
            B[i] = BC[i // 4, i % 4]


# Implementation 3, pad input with known value, but never access
# padding of output.
#
# In back-propagation of constraints, the T.undef() that is cropped
# from BC could be narrowed to a known value provided from the
# successor.  AC's padding is written to, so this would propagate
# `PadMapping(predicate, pad_value=0)` to the previous operator.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)

    for i in T.serial(14):
        with T.block("compute"):
            BC[i // 4, i % 4] = AC[i // 4, i % 4] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", T.undef()])
            B[i] = BC[i // 4, i % 4]


# Implementation 4, pad input with arbitrary value, provide no
# guarantees in output.
#
# In back-propagation of constraints, the T.undef() that is cropped
# from BC could be narrowed to a known value provided from the
# successor.  AC's padding is written to, so this would propagate
# `PadMapping(predicate, pad_value=BC_pad_value - 1)` to the
# previous operator.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], T.undef())

    for io, ii in T.grid(4, 4):
        with T.block("compute"):
            BC[io, ii] = AC[io, ii] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", T.undef()])
            B[i] = BC[i // 4, i % 4]


# Implementation 5, pad input with known value, analysis of TIR
# successfully propagates pad value through to provide assumption when
# cropping.
#
# In back-propagation of constraints, the output assumption is fixed.
# Unless the operator following addone has included the constraint 1
# as the required value in its padding, the crop/pad pair wouldn't be
# able to be removed.  AC's padding is written to, and would propagate
# `PadMapping(predicate, pad_value=0)` to the previous operator.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)

    for io, ii in T.grid(4, 4):
        with T.block("compute"):
            BC[io, ii] = AC[io, ii] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", 1])
            B[i] = BC[i // 4, i % 4]


# Implementation 6, pad input with known value, analysis of TIR can't
# successfully propagate pad value through to the output.
#
# In back-propagation of constraints, the output assumption is fixed.
# Since we don't provide an assumption of what will be returned, the
# graph-level pair of `crop(T.undef())` followed by `pad(x)` could
# only be canceled out if `x` is `T.undef()`.  AC's padding is written
# to, and would propagate `PadMapping(predicate, pad_value=0)` to
# the previous operator.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)

    for io, ii in T.grid(4, 4):
        with T.block("compute"):
            BC[io, ii] = AC[io, ii] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", T.undef()])
            B[i] = BC[i // 4, i % 4]
```

I think the main change is that the temporary stages with annotation will need to allow multiple possibilities, rather than a single definitive layout. These options could then be searched at the graph level to decide on the appropriate layout. After that is decided, the temporary stage could be selected and the transformations hoisted.
Completely agreed. I think this is true at both the TIR and graph levels: allowing assumptions means ensuring that the assumption isn't changed after it is used for simplifications. The advantage of writing the assumptions at the graph level is that specific pairs of functions (such as

I think the main rules that would need to be followed when handling assumptions would be the following three.

The restriction against changing a PrimFunc's interface falls out directly from rule #1. Since an assumption that restricts the values of an input cannot be proven, these assumptions may not be modified.
These make sense, and agreed that the TIR->global feedback is important for enabling the layout reflow. Going back through the discussion, I think we're converging on agreement on what features are required, and the main questions remaining are how to best provide annotations for non-local information, and how best to express layout transformations while scheduling. I've made some updates to the text of the RFC, based on the discussions here, primarily to remove the proposed changes to TIR data structures. This follows your comment from a few days ago, which brought up
|
```python
@T.prim_func
def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
    for i,j in T.grid(4,4):
        if 4*i + j < 14:
            A[i] = 0.0
```
Which condition should be kept? How do we decide?
If the merging of complementary conditionals is valid, then which condition is kept doesn't matter for correctness. For two conditions `A` and `B`, if `A` implies `!B` and `!A` implies `B`, then `A` and `!B` are different functional forms of the same expression.
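
As a quick sanity check (plain Python, purely illustrative), the two concrete conditions in this example are exact complements over the 4x4 iteration space:

```python
# Brute-force check that `4*i + j < 14` and `i==3 and j>=2` are
# complementary conditions over the 4x4 iteration space.
for i in range(4):
    for j in range(4):
        a = 4 * i + j < 14        # condition of the first branch
        b = i == 3 and j >= 2     # condition of the second branch
        assert a == (not b)       # exactly one of the two holds
```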
That said, I'd probably keep the first conditional, as it allows for the simplification to be viewed as a specific case of a more general transformation. Given a conditional that is followed by another statement outside the conditional, it is valid to move the statement inside the conditional, placed at the end of both the `then_case` and `else_case`. If the statement being moved is itself a conditional, then it may be simplified. In this case, the intermediate step would look as follows.
```python
@T.prim_func
def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
    for i,j in T.grid(4,4):
        if 4*i + j < 14:
            A[i] = 0.0
            if i==3 and j>=2:
                B[i] = 2.0
            else:
                B[i] = 3.0
        else:
            A[i] = 1.0
            if i==3 and j>=2:
                B[i] = 2.0
            else:
                B[i] = 3.0
```
I wouldn't want to generate the intermediate state in all cases, because it may not always lead to useful simplifications, which is why it would only be applied in the special cases of identical conditions and complementary conditions.
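
For reference, a sketch of what the intermediate state simplifies to in this example, given that `i==3 and j>=2` is always false when `4*i + j < 14` and always true otherwise:

```python
@T.prim_func
def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
    for i,j in T.grid(4,4):
        if 4*i + j < 14:
            A[i] = 0.0
            B[i] = 3.0
        else:
            A[i] = 1.0
            B[i] = 2.0
```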
### New Transform - Reduce Loop Extents |
I don't think this is necessary. We could simply reuse loop partitioning, and break off pieces of the nest that will never execute.
That's a good point, and would avoid having the special case pass.
Looking at the implementation of `LoopPartition` for edge cases, and making some notes to myself on the steps required (see the sketch after this list).

- Apply a `T.likely` annotation when checking if it's a pad value (e.g. `if !T.likely(4*io + ii < 14)`), since `LoopPartition` uses this to identify partitionable conditions.
- Maintain the `T.likely` annotation when hoisting part of a conditional (e.g. when hoisting `io==3` out of `!T.likely(io==3 and ii >= 2)`).
- Look into relaxing the restriction against partitioning a constant loop, currently only allowed by a pass config. If we're generating loops that we know should be partitioned, it would be strange to also require the user to opt-in. I don't know the history of this restriction, so this would require some investigation. (Perhaps the additional partition could be allowed only if the loop is a serial loop and all but one of the partitions are no-ops.)
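
A minimal sketch of the first step, assuming the pad condition is annotated at the point where the padding loop is generated (the exact TVMScript spelling of `T.likely` in this position is my assumption):

```python
# Hypothetical sketch: the padding condition is wrapped in T.likely so
# that LoopPartition can recognize it and split off the io == 3
# iterations that touch the padded elements.
for io, ii in T.grid(4, 4):
    if T.likely(4 * io + ii < 14):
        AC[io, ii] = A[4 * io + ii]   # in-bounds elements
    else:
        AC[io, ii] = 0                # pad value, only for 4*io + ii >= 14
```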
```python
for i, j in T.grid(4, 4):
    if i == 0 and j < 2:
        A[i, j] = 0.0
```
This would look something like
```python
for i, j in T.grid(0..1, 0..2):  # I'm making the ranges more verbose for clarity
    A[i, j] = 0.0
for i, j in T.grid(0..1, 2..4):
    pass  # j < 2 is false
for i, j in T.grid(1..4, 0..2):
    pass  # i == 0 is false
for i, j in T.grid(1..4, 2..4):
    pass  # i == 0 is false, j < 2 is false
```
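
Assuming the no-op nests are then dropped (e.g. by a no-op elimination pass such as `RemoveNoOp`), only the first nest would remain, now with reduced extents:

```python
for i, j in T.grid(1, 2):   # equivalent to T.grid(0..1, 0..2)
    A[i, j] = 0.0
```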
```python
# sched.remove_branching_through_overcompute(block='compute')
```
Does this only apply to outputs? I think we should have a per-buffer directive that indicates that out-of-bounds access is allowed. The only thing in question is how to determine/specify that out-of-bounds reads from inputs are ok. The user can add padding of -INF to the inputs of maxpool, but how does the maxpool compute know that it can use the out-of-bounds values?
Whether to actually utilize this should probably be left to the compiler. Auto-scheduling should not be a replacement for compiler optimizations.
> Does this only apply to outputs?
This specific reasoning, of identifying the overcompute as a no-op by virtue of being overwritten later, is specific to outputs. In general, overcompute may only be introduced if the new statements can be shown to be no-ops.
- A write that is overwritten without being read is a no-op. (Used to introduce overcompute of outputs.)
- A write to a location that is deallocated without being read is a no-op. (Used to introduce overcompute of local caches.)
- A write of the same value that is already at the write location is a no-op. (Used to introduce overcompute based on known facts about the input buffer, such as `Output[0] = Output[0] + Input[i]` being a no-op if `Input[i]` is known to be zero.)
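
As a small illustration of the second rule, assuming the `AC`/`BC`/`B` buffers from the earlier addone examples, this is what lets the compute stage run over the full padded extent without a branch:

```python
# BC is a local cache; its padded elements (4*io + ii >= 14) are written
# but never read before BC is freed, so computing them unconditionally
# is a no-op and the compute loop needs no branch.
for io, ii in T.grid(4, 4):
    BC[io, ii] = AC[io, ii] + 1       # includes the padded elements
for i in T.serial(14):
    B[i] = BC[i // 4, i % 4]          # reads only the 14 valid elements
```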
> I think we should have a per-buffer directive that indicates that out-of-bounds access is allowed.
Agreed. This is specified using the `pad_value` argument in a transformation, and is exposed to local analysis either through the proposed `T.assume` intrinsic for input buffers, or through `BufferStore` for local caches. That way, the out-of-bounds access within an existing statement can be used to know that other out-of-bounds accesses are safe to use. There's an example here, where `T.assume(4*io+ii < 14 or A[io,ii] == T.undef())` is used to know that it is safe to insert reads to indices for which `4*io + ii >= 14`. If no `pad_value` is specified, then there is no previous read/write from those locations, and so a later transformation may not introduce a new read/write.
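
A rough sketch of how such an assumption might appear for the padded input of the earlier addone example, with a known pad value of 0 (the exact TVMScript spelling is my assumption):

```python
@T.prim_func
def addone(A: T.Buffer[(4, 4), "int32"], B: T.Buffer[(4, 4), "int32"]):
    for io, ii in T.grid(4, 4):
        # Either the index maps to one of the 14 valid elements, or the
        # element is known to hold the pad value 0.
        T.evaluate(T.assume(4 * io + ii < 14 or A[io, ii] == 0))
    for io, ii in T.grid(4, 4):
        with T.block("compute"):
            B[io, ii] = A[io, ii] + 1
```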
> The user can add padding of -INF to the inputs of maxpool, but how does the maxpool compute know that it can use the out-of-bounds values?
This would be the role of the `T.assume` intrinsic. It wouldn't have any effect on its own, and would be removed as part of lowering, but would expose information to the function that could be used as part of optimizations. In this case, the statement `T.assume(!indices_are_padding or buf[indices] == -INF)` could let maxpool know that those values can be used.
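
Concretely, for a 1-D maxpool input padded from 14 to 16 elements this might look something like the following (buffer name, sizes, and the use of `T.min_value` as a stand-in for -INF are all assumptions on my part):

```python
# Padded elements of the maxpool input are assumed to hold the identity
# element of max (negative infinity), so reading them cannot change the
# result of the reduction.
for i in T.serial(16):
    T.evaluate(T.assume(i < 14 or A[i] == T.min_value("float32")))
```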
> Whether to actually utilize this should probably be left to the compiler. Auto-scheduling should not be a replacement for compiler optimizations.
For transforms that could be used either in auto-scheduling or when compiling, I had been seeing compiler optimizations as the preferred place to implement transforms that are universally beneficial, and auto-scheduling as the preferred place to implement functionality that is conditionally beneficial. In this case, because the overcompute may be very large for some pathological cases, I think it is better exposed as a scheduling decision, as the cost of overcompute may not always be worth the benefit of avoiding a branch.
Thanks @Lunderberg for the update, I think we are moving towards a positive direction of overall IR design. Some additional feedback:

**Keep Schedule Decisions Local to PrimFunc then Compose**

On schedule primitives, to be pragmatic, it would be helpful to have some of the cross-PrimFunc re-flowing done in two steps.

In general it is helpful to first keep schedule decisions local, e.g. introducing a caching stage (AC, BC in the example), then compose with another reflowing pass to bring the decision to consumers/producers. This is mainly to reduce the overall complexity in implementing such transformations, and also makes things more modular.
**Use IfThenElse expression for Padding**

While it is possible to express padding with one loop that copies data and another loop that writes the padded value, it is harder to schedule the resulting blocks since there is more than one producer. Having a single loop that uses an `if_then_else` expression is easier to schedule:

```python
for io, ii in T.grid(4, 4):
    with T.block():
        T.block_attr("preproc", "pad")
        AC[io, ii] = if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
```

**Propagate Padding Decisions Holistically in a PrimFunc**

Some of the complications of duplicated conditions (and their simplification) come from the fact that we do the layout transform of the output and the input separately, each introducing its own conditions which then need to be simplified. It might be helpful to do a global transformation, usually driven from the output, then "backprop" the implication of that decision to the input. Doing such a transformation in a single shot will likely alleviate the need to generate extra conditions and then simplify them.
My goal with the latest update wasn't to require global decisions, but to make local changes only, which could be used in different contexts. For the auto-scheduler, since the context requires maintaining the same PrimFunc interface, local optimization would be restricted to transformations of the caching stage. For stand-alone usage, such as preparing a single PrimFunc for a unit test, the context allows the interface to change. That way, the restrictions to the transformations are imposed by the level of abstraction that requires them.
I definitely agree that this makes the later analysis/rewrites easier. I had maintained them as two separate loops both to minimize the extent of changes being made in any one scheduling change, and to maintain the current behavior of the existing transformation. I see four main options on how the loop nests could be handled; the current proposed version would be option 4, but I think I'd prefer option 2 in order to reduce the number of follow-up simplifications required.
At the TIR level, I suppose I'm unclear on what "'backprop' the implication of that decision to the input" would mean, since changing the layout of one buffer doesn't strictly require changing the layout of other buffers. Intuitively, I can picture how it would apply to some operators (e.g. perform analogous transformations on the inputs to element-wise functions) and how those could be identified (e.g. track which indices are used for access of each buffer, and identify corresponding shapes from the indices), but I'm unclear as to how a similar intuition would be applied to more complicated functions. (I'm also not sure if this would require a similarly difficult sequence of proofs as the proposed transforms, just with the goal of proving a preferred layout rather than proving a possible simplification.) We could allow the user to specify transformations of all buffers simultaneously, but this wouldn't really solve the problem, as the simplifications made would still need to be based on the information provided.

At the graph level, I don't think a single direction of constraint propagation is sufficient. Backward propagation, starting with the output values returned to the user, could track which indices contribute to that final output, which could be exposed to producers. Forward propagation, starting with the input values provided by the user, could track which indices of intermediate buffers contain known values, which could be exposed to consumers.
Following up on this, I think we are in broad-stroke agreement that we can achieve our goals with block/fn attributes in the IR as well as builtin assume. As a result, my original blocker for the RFC has been resolved. It would still be great to work together to flesh out the details of the schedule primitives and how they interact with the rest of TIR scheduling, but I think they can be done separately and we don't need to nail down the details of the primitives now. The schedule primitives can be done relatively independently as long as we agree on the principle that:
We can explore possible options as long as the IR spec remains stable; if there is a need to update the IR itself or the meaning of an attribute, we can come back and discuss again.
cc @Hzfengsy @wrongtest-intellif, it would be great if you could also take a follow-up look.
Thanks everyone for the very fruitful discussions! We indeed have a good path forward and are aligned on the principles that, for the end-to-end optimization, we will maintain function-interface invariance and achieve graph-level layout optimization via a combination of local decisions, reconstruction with assumptions, and rewriting based on the result of graph-level analysis and planning. I would ask that we move this discussion into a final comment period, as we would like to soon open a tracking issue for the items described in the RFC.
Thank you very much for the comments, suggestions, and discussion, and I'm quite happy with how the design evolved over the course of the discussions!
Thanks everyone for the discussions. We have agreed on the design principles and will continue to explore scheduling options. Let's keep the RFC open for final comments until the end of this week.
Implementation of API in `tvm.tir.schedule` for layout transformations with padding, as part of #12261, item "Insert pad value into generated TIR, using `tir::if_then_else`, `builtin::assume`, and `builtin::undef`". Following the RFC discussion in apache/tvm-rfcs#77 (comment) and apache/tvm-rfcs#77 (comment), this commit preferentially rewrites the loops that surround a padded transformation where possible, in order to express padding in terms of `tir::if_then_else`.