Rewrite ReduceOp to support arbitrary reduce operations #1305
Conversation
```c++
// Create a new copy of the reduce block, and inline it
Block *currentBlock = rewriter.getBlock();
Region &parent = *currentBlock->getParent();
rewriter.cloneRegionBefore(reduceOp, &parent.front());
auto &newReduce = parent.front();
auto returnOp = dyn_cast<triton::GenericReduceReturnOp>(newReduce.getTerminator());
rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), {acc, cur});
acc = returnOp.getResult();
// Delete the terminator, which is no longer used
rewriter.eraseOp(returnOp);
```
This is the main change compared to `ReduceOpToLLVM.cpp`.
python/triton/language/semantic.py
```python
def prod(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:

    def make_mul(reduce_op):
        ir_scalar_ty = input.type.scalar.to_ir(builder)
        region = reduce_op.get_region(0)
        with insertion_guard(builder):
            block = builder.create_block_with_parent(region, [ir_scalar_ty] * 2)
            fmul = builder.create_fmul(block.arg(0), block.arg(1))
            builder.create_reduce_ret(fmul)

    return reduction(input, axis, make_mul, builder)
```
I've been using this for testing but the end goal would be to have the compiler build the inner function from a lambda, or something like that. I might need some help with that though.
Haha yeah, it's not entirely trivial. I think it means the `ASTVisitor` should be modified to create MLIR functions out of lambdas, and then the reduce op could merge in the basic block from this function.
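For what it's worth, a rough sketch of the kind of user-facing API this could enable, assuming a hypothetical `tl.reduce` that accepts a `@triton.jit` combine function (names and signature here are illustrative, not part of this PR):

```python
import triton
import triton.language as tl

@triton.jit
def _mul_combine(a, b):
    # Hypothetical combine function: the compiler would turn its body into
    # the reduce op's region, much like make_mul builds it by hand above.
    return a * b

@triton.jit
def prod_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Hypothetical API: pass the jitted combine function to the reduction.
    p = tl.reduce(x, 0, _mul_combine)
    tl.store(out_ptr, p)
```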
```tablegen
def TT_GenericReduceOp: TT_Op<"generic_reduce",
    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Reduction using generic combination algorithm";
  let arguments = (ins TT_Tensor:$operand, I32Attr:$axis);
```
@ptillet assuming I can get index reductions working, do you think it would be reasonable to replace `ReduceOp` entirely?
Yes, if index reductions can work, then I think we could replace `ReduceOp` with the new op. We'll have to do some heavier testing to make sure that the performance hasn't decreased.
python/triton/language/core.py
```python
axis = _constexpr_to_value(axis)
n = input.shape[axis]
index = arange(0, n, _builder=_builder)
new_shape = [constexpr(1)] * len(input.shape)
new_shape[axis] = constexpr(n)
index = view(index, new_shape, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)

values, indices = semantic.min_with_index(input, index, axis, _builder)
return indices
```
This is my strategy for argmin/argmax. Instead of special-casing it, I just lower it as a reduction over two tensors:
```mlir
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
%9:2 = "tt.generic_reduce"(%6, %8) ({
^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
  %15 = arith.cmpf olt, %arg4, %arg6 : f32
  %16 = arith.cmpf ogt, %arg4, %arg6 : f32
  %17 = arith.minsi %arg5, %arg7 : i32
  %18 = arith.select %16, %arg7, %17 : i32
  %19 = arith.select %15, %arg5, %18 : i32
  %20 = arith.minf %arg4, %arg6 : f32
  tt.generic_reduce.return %20, %19 : f32, i32
}) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
This has some really nice properties:
- The reduction code is the same whether you discard the min/max value or not.
- It generalizes perfectly to higher numbers of tensors, e.g. the three needed for `aten.var_mean`.
- argmin/argmax-specific logic is defined entirely at the Python level (see the sketch below).
- In my limited testing so far, it performs identically.
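As a rough illustration of that last point, the Python-level combine logic could look something like this, assuming a hypothetical multi-tensor `tl.reduce` that takes a `@triton.jit` combine function over (value, index) pairs (the API shown is illustrative, not what this PR currently exposes):

```python
import triton
import triton.language as tl

@triton.jit
def _argmin_combine(value1, index1, value2, index2):
    # Keep the smaller value; on ties, keep the smaller index.
    lt = value1 < value2
    gt = value1 > value2
    index_min = tl.minimum(index1, index2)
    index = tl.where(lt, index1, tl.where(gt, index2, index_min))
    value = tl.minimum(value1, value2)
    return value, index

@triton.jit
def argmin_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Hypothetical: reduce two tensors simultaneously with one combine function.
    _, idx = tl.reduce((x, offs), 0, _argmin_combine)
    tl.store(out_ptr, idx)
```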
```diff
@@ -80,6 +80,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   // Some ops from SCF are illegal
   addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
                scf::ReduceReturnOp>();
+  // We have custom versions of some arith operators
+  addIllegalOp<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>();
```
I did start running into edge cases in some of the dialect conversion code, where these ops were slipping through despite there being a conversion rule for them. It's possible that nested regions are handled differently by MLIR, not sure.
```c++
barrier();
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
  store(acc[i], writePtrs[i]);
}
```
The new changes here basically just turn the pattern `foo(acc); if (withIndex) foo(accIndex)` into equivalent for loops over all operands.
```python
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
block = _builder.create_block_with_parent(region, param_types)
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
```
I'm in two minds about whether this is hacky or elegant, but it works. I pass the `CodeGenerator` in via the `_generator` argument, much like the `_builder` argument, and then call `call_JitFunction`, which I factored out of `visit_Call`, to generate the appropriate function definition and call it.
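For context, a minimal sketch of how the core.py-level wrapper could thread both objects through to this code; the names and signature are hypothetical, not the exact code in this PR, and the snippet assumes core.py's module context (`builtin`, `semantic`, `_constexpr_to_value`):

```python
# Sketch of a core.py-level wrapper (illustrative only). The jitted
# combine_fn and the code generator are forwarded to the semantic layer,
# which uses _generator.call_JitFunction to build the reduce region body.
@builtin
def reduce(input, axis, combine_fn, _builder=None, _generator=None):
    axis = _constexpr_to_value(axis)
    return semantic.reduction(input, axis, combine_fn, _builder, _generator)
```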
This is cherry-picked from #1305 If you call a `JITFunction` twice in the same kernel, first with `int32` then with `uint32`, the second call will treat the unsigned value as signed. This passes through MLIR without error because MLIR uses the same types for both, but different operation calls will be generated so you may silently get the wrong result.
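For illustration, a kernel of roughly this shape (hypothetical, not taken from the PR) triggers the problem when `a_ptr` points at int32 data and `b_ptr` at uint32 data:

```python
import triton
import triton.language as tl

@triton.jit
def _halve(x):
    # Integer division lowers to different MLIR ops for signed vs. unsigned,
    # but int32 and uint32 share the same underlying MLIR i32 type.
    return x // 2

@triton.jit
def kernel(a_ptr, b_ptr, out_a_ptr, out_b_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs)  # int32 input
    b = tl.load(b_ptr + offs)  # uint32 input
    tl.store(out_a_ptr + offs, _halve(a))
    # If the JITFunction cache key ignores signedness, this call reuses the
    # int32 specialization and treats b's values as signed.
    tl.store(out_b_ptr + offs, _halve(b))
```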
Thanks for the PR. Things are busy right now, but we will review it next week!
(sorry, things have been busy and haven't had time to review this yet!)
@ptillet do you have any idea when you might have time to review this?
```c++
}

// TODO: This always takes layout from the first argument which
// is fine for argmin/argmax but may not be optimal generally
```
I think you limit all arguments of reduce to have the same encoding. So this is just fine?
if (t.getShape() != srcShape) { rop.emitError() << "shape mismatch"; }
The concern is that the first argument might be cheap to convert but the second argument slow to convert. In that case this will remove the cheap layout conversion and add a more expensive one.
Also, I don't think there's ever a case where shape mismatch can happen.
Benchmark-related changes were merged yesterday, so it's possible the tests got flaky. I'll investigate later today.
Thanks again for the PR @peterbell10. And thanks @Jokeren for the review.
A small oversight in #1305: since `view` can rearrange elements, it should be avoided here. Instead I use indexing with `None` to create new dimensions. Co-authored-by: Philippe Tillet <[email protected]>
…riton-lang#1305)

Fixes triton-lang#1285

This changes `tt.reduce` to replace `redOp` by a region containing arbitrary code. For example, `tl.sum` is now lowered as:

```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```

Support for index reductions at the MLIR level is also dropped in favor of simultaneous reductions over multiple tensors, which generalizes the code without loss of performance. So, for example, `argmin` gets lowered as:

```mlir
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
%9:2 = "tt.reduce"(%6, %8) ({
^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
  %14 = arith.cmpf olt, %arg4, %arg6 : f32
  %15 = arith.cmpf ogt, %arg4, %arg6 : f32
  %16 = arith.cmpi slt, %arg5, %arg7 : i32
  %17 = arith.select %16, %arg5, %arg7 : i32
  %18 = arith.select %15, %arg7, %17 : i32
  %19 = arith.select %14, %arg5, %18 : i32
  %20 = arith.cmpf olt, %arg4, %arg6 : f32
  %21 = arith.select %20, %arg4, %arg6 : f32
  tt.reduce.return %21, %19 : f32, i32
}) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
…ng#1305) No further changes were needed after triton-lang#1282. Fixes triton-lang#1264. Signed-off-by: Julian Oppermann <[email protected]>
Fixes #1285

This changes `tt.reduce` to replace `redOp` by a region containing arbitrary code. For example, `tl.sum` is now lowered as:
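```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```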
Support for index reductions at the MLIR level is also dropped in favor of simultaneous reductions over multiple tensors, which generalizes the code without loss of performance. So, for example, `argmin` gets lowered as:
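```mlir
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
%9:2 = "tt.reduce"(%6, %8) ({
^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
  %14 = arith.cmpf olt, %arg4, %arg6 : f32
  %15 = arith.cmpf ogt, %arg4, %arg6 : f32
  %16 = arith.cmpi slt, %arg5, %arg7 : i32
  %17 = arith.select %16, %arg5, %arg7 : i32
  %18 = arith.select %15, %arg7, %17 : i32
  %19 = arith.select %14, %arg5, %18 : i32
  %20 = arith.cmpf olt, %arg4, %arg6 : f32
  %21 = arith.select %20, %arg4, %arg6 : f32
  tt.reduce.return %21, %19 : f32, i32
}) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```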