Skip to content

Commit

Permalink
Migrate the stablehlo reduction family of operations, with promotab…
Browse files Browse the repository at this point in the history
…le semantics [RFC](openxla/stablehlo#1664), to MHLO.

The CL does two things:

1. The upstream change openxla/stablehlo#1869 in StableHLO updates various APIs related to shape inference. MHLO shape inference functions use those APIs. The CL fixes the invocations of those APIs in the MHLO codebase.

2. There exist canonicalization passes, such as `group-reduction-dimensions` and `hlo-canonicalize-reduction`, which create reduce operations using builder methods that call type inference of the reduce op with an empty reduction region ([example](https://github.com/openxla/xla/blob/a91877b9c9aa1edf307c5927782111b1a81cd81d/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc#L228)). This is problematic because, with the [change](openxla/stablehlo#1869), the type inference of the reduce op now depends on the reduction body. The CL updates all the call sites of the problematic builder (the one which calls type inference with an empty reduction block) to invoke a new custom builder method introduced for the mhlo::Reduce operation.

Note that at the moment we do not need a similar custom builder for the other reduction-based operations (like scatter, reduce_scatter, all_reduce, select_and_scatter, reduce_window), as they are presently created using a builder version that takes the result type as an input and hence does not call inference from within.

Also, the CL adds verification tests for the operations with promotable semantics.

PiperOrigin-RevId: 597407271
  • Loading branch information
sdasgup3 authored and copybara-github committed Jan 12, 2024
1 parent 79aba08 commit 955f760
Show file tree
Hide file tree
Showing 9 changed files with 332 additions and 290 deletions.
353 changes: 71 additions & 282 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,6 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
}

def MHLO_AllReduceOp : MHLO_Op<"all_reduce", [
SameOperandsAndResultElementType,
SingleBlockImplicitTerminator<"ReturnOp">,
InferTensorType
]> {
Expand Down Expand Up @@ -1549,8 +1548,7 @@ def MHLO_AllReduceOp : MHLO_Op<"all_reduce", [
let hasCustomHLOConverter = 1;
}

def MHLO_ReduceScatterOp : MHLO_Op<"reduce_scatter",
[SameOperandsAndResultElementType]> {
def MHLO_ReduceScatterOp : MHLO_Op<"reduce_scatter", []> {
let summary = "ReduceScatter operation";
let description = [{
Within each process group in the process grid, performs reduction, using
Expand Down Expand Up @@ -1696,6 +1694,12 @@ def MHLO_ReduceOp: MHLO_ShapedInterfaceOp<"reduce", [
// compatible with reduce op's operands.
let regions = (region SizedRegion<1>:$body);

// Builder
let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$init_values,
"DenseIntElementsAttr":$dimensions, "TypeRange":$element_types)>,
];

// TODO(b/129422361): ReduceOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down Expand Up @@ -224,8 +225,11 @@ LogicalResult tryLowerTo1DOr2DReduction(
int64_t reductionDim = leadingReduction ? 0 : 1;
auto reductionDimAttr = rewriter.getI64VectorAttr({reductionDim});
Value initVal = op.getInitValues().front();
auto reductionOp =
rewriter.create<ReduceOp>(loc, intermResult, initVal, reductionDimAttr);
SmallVector<Type> elementTypes{llvm::map_range(
op.getBody().front().getTerminator()->getOperands(),
[](Value v) { return v.getType().cast<ShapedType>().getElementType(); })};
auto reductionOp = rewriter.create<ReduceOp>(loc, intermResult, initVal,
reductionDimAttr, elementTypes);
rewriter.inlineRegionBefore(op.getBody(), reductionOp.getBody(),
reductionOp.getBody().begin());
intermResult = reductionOp->getResults().front();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -219,8 +220,12 @@ struct HloCanonicalizeReductionPass
elemTy),
operand, newOperandShape));
}
auto newOp =
b.create<ReduceOp>(loc, newOperands, op.getInitValues(), attr);
SmallVector<Type> elementTypes{llvm::map_range(
op.getBody().front().getTerminator()->getOperands(), [](Value v) {
return v.getType().cast<ShapedType>().getElementType();
})};
auto newOp = b.create<ReduceOp>(loc, newOperands, op.getInitValues(),
attr, elementTypes);
newOp.getBody().takeBody(op.getBody());

SmallVector<Value, 4> newResults;
Expand Down
70 changes: 69 additions & 1 deletion xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,41 @@ func.func @all_reduce_tuple(%arg0: tensor<10xf32>, %arg1: tensor<f32>) -> tensor

// -----

// CHECK-LABEL: func @all_reduce_with_promotable_types
func.func @all_reduce_with_promotable_types(%operand: tensor<f32>) -> tensor<f64> {

%result = "mhlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
"mhlo.return"(%0) : (tensor<f64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #mhlo.channel_handle<handle = 0, type = 0>
} : (tensor<f32>) -> tensor<f64>

func.return %result : tensor<f64>
}

// -----

// CHECK-LABEL: func @all_reduce_with_promotable_quantized_types
// Positive case with quantized element types: the operand uses
// !quant.uniform<i8:f32, 2.000000e+00:15>, while the reduction block and the
// op result use !quant.uniform<i32:f32, 2.000000e+00:15>. Only the i8 vs. i32
// part of the quantized type differs between operand and result.
func.func @all_reduce_with_promotable_quantized_types(%operand: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>)
-> tensor<!quant.uniform<i32:f32, 2.000000e+00:15>> {

%result = "mhlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %arg1: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
%0 = mhlo.add %arg0, %arg1 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
"mhlo.return"(%0) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #mhlo.channel_handle<handle = 0, type = 0>
} : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>

func.return %result : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
}

// -----

func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> {
// expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}}
// expected-error@+1 {{Reduction-region must take 2 parameters, but takes 3 parameter(s)}}
Expand Down Expand Up @@ -200,7 +235,8 @@ func.func @all_reduce_invalid_return_type(%operand: tensor<10xf32>) -> tensor<10
// -----

func.func @all_reduce_invalid_return_type(%operand: tensor<10xf32>) -> tensor<10xi32> {
// expected-error@+1 {{'mhlo.all_reduce' op requires the same element type for all operands and results}}
// expected-error@+2 {{'mhlo.all_reduce' op inferred type(s) 'tensor<10xf32>' are incompatible with return type(s) of operation 'tensor<10xi32>'}}
// expected-error@+1 {{'mhlo.all_reduce' op failed to infer returned types}}
%0 = "mhlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%max = mhlo.maximum %arg0, %arg1 : tensor<f32>
Expand Down Expand Up @@ -309,6 +345,38 @@ func.func @reduce_scatter_dynamic(%data: tensor<?x?xf32>) -> tensor<?x?xf32> {

// -----

// CHECK-LABEL: func @reduce_scatter_with_promotable_types
// Positive case: the operand is tensor<4x16xf32>, while the reduction block
// works on tensor<f64> and the op result is tensor<4x4xf64>. The scatter
// dimension is 1 and the single replica group lists four replicas.
func.func @reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> {
%0 = "mhlo.reduce_scatter"(%data) ({
^bb0(%arg2: tensor<f64>, %arg3: tensor<f64>):
%1 = mhlo.add %arg2, %arg3 : tensor<f64>
"mhlo.return"(%1) : (tensor<f64>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64,
channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64>
func.return %0 : tensor<4x4xf64>
}

// -----

// CHECK-LABEL: func @reduce_scatter_with_promotable_quantized_types
// Positive case with quantized element types: the operand element type is
// !quant.uniform<i8:f32, 2.000000e+00:15>, while the reduction block and the
// op result use !quant.uniform<i32:f32, 2.000000e+00:15>. Shapes and
// attributes mirror the non-quantized reduce_scatter promotion test.
func.func @reduce_scatter_with_promotable_quantized_types(
%data: tensor<4x16x!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>> {
%0 = "mhlo.reduce_scatter"(%data) ({
^bb0(%arg2: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %arg3: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
%1 = mhlo.add %arg2, %arg3 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
"mhlo.return"(%1) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64,
channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
use_global_device_ids} : (tensor<4x16x!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>>
func.return %0 : tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>>
}

// -----

func.func @reduce_scatter_c2(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
// expected-error@+1 {{expects scatter_dimension >= 0}}
%0 = "mhlo.reduce_scatter"(%data) ({
Expand Down
29 changes: 29 additions & 0 deletions xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,35 @@ func.func @reduce_mix_rank_and_unranked(%arg0: tensor<4x4xf32>, %arg1: tensor<*x
func.return %0#0, %0#1 : tensor<4xf32>, tensor<*xf32>
}

// -----

// CHECK-LABEL: func @reduce_with_promotable_types
// Positive case: the input (tensor<4x4xf32>) and the init value (tensor<f32>)
// are f32, while the reducer block works on tensor<f64> and the op result is
// tensor<4xf64>. The reduction is across dimension 0.
func.func @reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor<f32>)
-> (tensor<4xf64>) {
%0 = "mhlo.reduce"(%arg0, %arg1) ({

^bb0(%arg2: tensor<f64>, %arg3: tensor<f64> ):
%1 = "mhlo.add"(%arg2, %arg3) : (tensor<f64>, tensor<f64>) -> tensor<f64>
"mhlo.return"(%1) : (tensor<f64>) -> ()

}) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<4xf64>

func.return %0: tensor<4xf64>
}

// -----

// CHECK-LABEL: func @reduce_with_promotable_quantized_types
// Positive case with quantized element types, written in the custom
// (pretty-printed) mhlo.reduce form: the input and init value use
// !quant.uniform<i8:f32, 2.000000e+00:15>, while the reducer and the op result
// use !quant.uniform<i32:f32, 2.000000e+00:15>.
func.func @reduce_with_promotable_quantized_types(%arg0: tensor<4x4x!quant.uniform<i8:f32, 2.000000e+00:15>>,
%arg1: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<4x!quant.uniform<i32:f32, 2.000000e+00:15>> {
%0 = mhlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform<i8:f32, 2.000000e+00:15>>, tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<4x!quant.uniform<i32:f32, 2.000000e+00:15>>
reducer(%arg2: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %arg3: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) {
%1 = mhlo.add %arg2, %arg3 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
mhlo.return %1 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
}
return %0 : tensor<4x!quant.uniform<i32:f32, 2.000000e+00:15>>
}

// Next, we have the invalid testcases.

// -----
Expand Down
40 changes: 40 additions & 0 deletions xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,46 @@ func.func @reduce_window_with_non_scalar_block_arg2(%arg0: tensor<4x2xf32>,

// -----

// CHECK-LABEL: func @reduce_window_with_promotable_types
// Positive case with two inputs, both tensor<4x2xf32> with tensor<f32> init
// values. The block arguments and result for the first input are f64
// (%a0, %b0 -> tensor<2x2xf64>), while those for the second input stay f32
// (%a1, %b1 -> tensor<2x2xf32>), so promoted and non-promoted results are
// mixed in a single op.
func.func @reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>,
%arg1: tensor<4x2xf32>, %init0: tensor<f32>, %init1: tensor<f32>) ->
(tensor<2x2xf64>, tensor<2x2xf32>) {
%0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
^bb0(%a0: tensor<f64>, %a1: tensor<f32>, %b0: tensor<f64>,
%b1: tensor<f32>):
%2 = mhlo.add %a0, %b0 : tensor<f64>
%3 = mhlo.add %a1, %b1 : tensor<f32>
"mhlo.return"(%2,%3) : (tensor<f64>, tensor<f32>) -> ()
})
{ padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
window_dimensions = dense<[5, 1]> : tensor<2xi64>,
window_strides = dense<[3, 1]> : tensor<2xi64> }
: (tensor<4x2xf32>, tensor<4x2xf32>, tensor<f32>, tensor<f32>) ->
(tensor<2x2xf64>, tensor<2x2xf32>)
func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @reduce_window_with_promotable_quantized_types
// Positive case with a single quantized input: the input and init value use
// !quant.uniform<i8:f32, 2.000000e+00:15>, while the reduction block and the
// op result use !quant.uniform<i32:f32, 2.000000e+00:15>. Window attributes
// match the non-quantized reduce_window promotion test.
func.func @reduce_window_with_promotable_quantized_types(%arg0: tensor<4x2x!quant.uniform<i8:f32, 2.000000e+00:15>>,
%init0: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> (tensor<2x2x!quant.uniform<i32:f32, 2.000000e+00:15>>) {

%0 = "mhlo.reduce_window"(%arg0, %init0) ({
^bb0(%a0: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %b0: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
%1 = mhlo.add %a0, %b0 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
"mhlo.return"(%1) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
})
{ padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
window_dimensions = dense<[5, 1]> : tensor<2xi64>,
window_strides = dense<[3, 1]> : tensor<2xi64>
}
: (tensor<4x2x!quant.uniform<i8:f32, 2.000000e+00:15>>, tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> (tensor<2x2x!quant.uniform<i32:f32, 2.000000e+00:15>>)
func.return %0 : tensor<2x2x!quant.uniform<i32:f32, 2.000000e+00:15>>
}

// -----

func.func @reduce_window_invalid_inputs(%arg0: tensor<4x2xf32>,
%arg1: tensor<4x3xi32>, %init0: tensor<f32>, %init1: tensor<i32>) ->
(tensor<2x2xf32>, tensor<2x2xi32>) {
Expand Down
48 changes: 48 additions & 0 deletions xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,54 @@ func.func @scatter_with_unranked_inputs(%input_tensor: tensor<*xf32>,

// -----

// CHECK: func @scatter_with_promotable_types
// Positive case: the input tensor and the updates are f32, while the
// update-computation block works on tensor<f64> and the op result is
// tensor<200x100x300xf64>. The scatter indices stay tensor<10x2xi32>.
func.func @scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>,
%scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) ->
tensor<200x100x300xf64> {
%0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
^bb0(%lhs: tensor<f64>, %rhs: tensor<f64>):
%add = mhlo.add %lhs, %rhs : tensor<f64>
"mhlo.return"(%add) : (tensor<f64>) -> ()
}) {
scatter_dimension_numbers = #mhlo.scatter<
update_window_dims = [1],
inserted_window_dims = [0, 1],
scatter_dims_to_operand_dims = [0, 1],
index_vector_dim = 1
>,
indices_are_sorted = true,
unique_indices = true
} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) ->
tensor<200x100x300xf64>
func.return %0 : tensor<200x100x300xf64>
}

// -----

// CHECK: func @scatter_with_promotable_quantized_types
// Positive case with quantized element types: the input tensor and the
// updates use !quant.uniform<i8:f32, 2.000000e+00:15>, while the
// update-computation block and the op result use
// !quant.uniform<i32:f32, 2.000000e+00:15>. Shapes and scatter attributes
// mirror the non-quantized scatter promotion test.
func.func @scatter_with_promotable_quantized_types(%input_tensor: tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>>,
%scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>> {
%0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
^bb0(%lhs: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %rhs: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
%add = mhlo.add %lhs, %rhs : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
"mhlo.return"(%add) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
}) {
scatter_dimension_numbers = #mhlo.scatter<
update_window_dims = [1],
inserted_window_dims = [0, 1],
scatter_dims_to_operand_dims = [0, 1],
index_vector_dim = 1
>,
indices_are_sorted = true,
unique_indices = true
} : (tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>>, tensor<10x2xi32>,
tensor<10x300x!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>>
func.return %0 : tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>>
}
// -----

func.func @invalid_scatter(%input_tensor: tensor<200x100x300xf32>,
%scatter_indices: tensor<10x2xf32>, %updates: tensor<10x300xf32>) ->
tensor<200x100x300xf32> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,61 @@ func.func @select_and_scatter(
func.return %1 : tensor<10x24x24x64xf32>
}


// CHECK: func @select_and_scatter_with_promotable_types
// Positive case: the operand, source, and init value are f32. The first
// region (a GE compare returning tensor<i1>) still works on tensor<f32>,
// while the second region (an add) works on tensor<f64> and the op result is
// tensor<10x24x24x64xf64>. The function returns nothing, so %1 is only
// created, never used.
func.func @select_and_scatter_with_promotable_types(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
%2 = mhlo.add %arg3, %arg4 : tensor<f64>
"mhlo.return"(%2) : (tensor<f64>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf64>
func.return
}


// CHECK: func @select_and_scatter_with_promotable_quantized_types
// Positive case with quantized element types: the operand, source, and init
// value use !quant.uniform<i8:f32, 2.000000e+00:15>. The first region (a
// TOTALORDER GE compare returning tensor<i1>) works on the i8-based type,
// while the second region (an add) and the op result use
// !quant.uniform<i32:f32, 2.000000e+00:15>. Unlike the non-quantized
// select_and_scatter promotion test, no padding attribute is given here.
func.func @select_and_scatter_with_promotable_quantized_types(
%arg0: tensor<10x24x24x64x!quant.uniform<i8:f32, 2.000000e+00:15>>,
%arg1: tensor<10x12x12x64x!quant.uniform<i8:f32, 2.000000e+00:15>>,
%arg2 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
tensor<10x24x24x64x!quant.uniform<i32:f32, 2.000000e+00:15>> {

%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({
^bb0(%arg3: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %arg4: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>):
%2 = "mhlo.compare"(%arg3, %arg4) {
compare_type = #mhlo<comparison_type TOTALORDER>,
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %arg4: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
%2 = mhlo.add %arg3, %arg4 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
"mhlo.return"(%2) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64x!quant.uniform<i8:f32, 2.000000e+00:15>>,
tensor<10x12x12x64x!quant.uniform<i8:f32, 2.000000e+00:15>>,
tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
tensor<10x24x24x64x!quant.uniform<i32:f32, 2.000000e+00:15>>
func.return %1 : tensor<10x24x24x64x!quant.uniform<i32:f32, 2.000000e+00:15>>
}

// CHECK: func @select_and_scatter_with_unranked_dims
func.func @select_and_scatter_with_unranked_dims(
%arg0: tensor<4x5x1x1xbf16>,
%arg1: tensor<2x2x1x1xbf16>,
Expand Down

0 comments on commit 955f760

Please sign in to comment.