[frontend] allow var_mean to be implemented in one pass #1285
peterbell10 changed the title from "Add tl.welford" to "Add tl.welford to allow var_mean to be implemented in one pass" on Mar 6, 2023
I feel like such a function would not belong in the […]
ptillet changed the title from "Add tl.welford to allow var_mean to be implemented in one pass" to "[frontend] allow var_mean to be implemented in one pass" on Mar 6, 2023
ptillet pushed a commit that referenced this issue on Apr 13, 2023:
…1305)

Fixes #1285

This changes `tt.reduce` to replace `redOp` with a region containing arbitrary code. For example, `tl.sum` is now lowered as:

```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```

Support for index reductions at the MLIR level is also dropped in favor of simultaneous reductions over multiple tensors, which generalizes the code without loss of performance. So, for example, `argmin` is lowered as:

```mlir
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
%9:2 = "tt.reduce"(%6, %8) ({
^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
  %14 = arith.cmpf olt, %arg4, %arg6 : f32
  %15 = arith.cmpf ogt, %arg4, %arg6 : f32
  %16 = arith.cmpi slt, %arg5, %arg7 : i32
  %17 = arith.select %16, %arg5, %arg7 : i32
  %18 = arith.select %15, %arg7, %17 : i32
  %19 = arith.select %14, %arg5, %18 : i32
  %20 = arith.cmpf olt, %arg4, %arg6 : f32
  %21 = arith.select %20, %arg4, %arg6 : f32
  tt.reduce.return %21, %19 : f32, i32
}) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
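At the Python level, this generic reduction surfaces as `tl.reduce` taking a JIT-compiled scalar combine function. A minimal sketch of a sum written that way, assuming the `tl.reduce(input, axis, combine_fn)` signature that Triton releases including this change expose (the kernel and helper names here are illustrative):

```python
import triton
import triton.language as tl


@triton.jit
def add_combine(a, b):
    # Scalar combine step; this is what gets inlined into the tt.reduce region above.
    return a + b


@triton.jit
def row_sum_kernel(x_ptr, out_ptr, N, BLOCK_N: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(x_ptr + row * N + cols, mask=mask, other=0.0)
    # Equivalent to tl.sum(x, axis=0), but with a user-supplied combiner.
    total = tl.reduce(x, 0, add_combine)
    tl.store(out_ptr + row, total)
```

Because the combiner is ordinary JIT code, the same mechanism covers multi-tensor reductions such as `argmin` without dedicated ops.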
pingzhuu pushed a commit to siliconflow/triton that referenced this issue on Apr 2, 2024 (the same change, fixing triton-lang#1285; commit message identical to the one above).
Currently PyTorch Inductor is forced to implement `torch.var_mean` as two passes over the input data, which causes a slowdown for batch norm. To allow single-pass computation we need a new reduction operator, `tl.welford(mean, m2, count)`, which implements the combination step of the parallel Welford algorithm.

A more general solution might be to instead add a `tl.reduce` which takes a function acting on scalars, so users can write their own reductions without needing to change the Triton language.
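With such a `tl.reduce` in hand, the single-pass `var_mean` this issue asks for can be written entirely in user code via the parallel Welford combination step (Chan et al.). A hedged sketch, assuming `tl.reduce` accepts a tuple of tensors plus a combine function returning a tuple, as later Triton versions allow; `welford_combine` and `var_mean_kernel` are illustrative names, not an existing API:

```python
import triton
import triton.language as tl


@triton.jit
def welford_combine(mean_a, m2_a, w_a, mean_b, m2_b, w_b):
    # Parallel Welford combination of two partial aggregates (mean, M2, weight).
    w = w_a + w_b
    delta = mean_b - mean_a
    # Guard against merging two empty (weight-0) aggregates.
    frac = tl.where(w == 0.0, 0.0, w_b / w)
    mean = mean_a + delta * frac
    m2 = m2_a + m2_b + delta * delta * w_a * frac
    return mean, m2, w


@triton.jit
def var_mean_kernel(x_ptr, mean_ptr, var_ptr, N, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # Each lane starts as a trivial aggregate: mean = x, M2 = 0, weight = 1 (0 if masked).
    w = tl.where(mask, 1.0, 0.0)
    m2 = tl.zeros([BLOCK], dtype=tl.float32)
    mean, m2, w = tl.reduce((x, m2, w), 0, welford_combine)
    tl.store(mean_ptr, mean)
    # Population variance; use m2 / (w - 1) for the unbiased estimator.
    tl.store(var_ptr, m2 / w)
```

This computes the mean and variance in a single pass over the data, which is exactly what Inductor needs to emit for batch norm.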