Skip to content

Commit

Permalink
Fix scatter with unsigned indices.
Browse files Browse the repository at this point in the history
See ScatterTest.OutOfBoundsUnsignedIndex in
compiler/xla/tests:scatter_test_cpu.

PiperOrigin-RevId: 500767330
  • Loading branch information
jreiffers authored and TensorFlow MLIR Team committed Jan 9, 2023
1 parent aa7b6e7 commit f493f16
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
11 changes: 11 additions & 0 deletions tests/lower_index_cast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,14 @@ func.func @f(%arg0 : tensor<42x?xi32>) -> tensor<42x?xindex> {
%0 = arith.index_cast %arg0 : tensor<42x?xi32> to tensor<42x?xindex>
func.return %0 : tensor<42x?xindex>
}

// -----

// CHECK-LABEL: func @index_castui
func.func @index_castui(%arg0 : tensor<10xi32>) -> tensor<10xindex> {
// CHECK: tensor.generate {
// CHECK: %[[C:.*]] = arith.index_castui
// CHECK: tensor.yield %[[C]] : index
%0 = arith.index_castui %arg0 : tensor<10xi32> to tensor<10xindex>
func.return %0 : tensor<10xindex>
}
17 changes: 9 additions & 8 deletions transforms/lower_index_cast_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ namespace mlir {
namespace {

// index_cast is not defined on tensors, so lower it to a tensor.generate.
struct IndexCastConverter : public OpRewritePattern<arith::IndexCastOp> {
template <typename T>
struct IndexCastConverter : public OpRewritePattern<T> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::IndexCastOp op,
PatternRewriter &rewriter) const final {
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
if (!resultTy) return failure();

SmallVector<Value> dynamicExtents =
Expand All @@ -48,8 +48,7 @@ struct IndexCastConverter : public OpRewritePattern<arith::IndexCastOp> {
op, resultTy, dynamicExtents,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value extent = b.create<tensor::ExtractOp>(loc, op.getIn(), args);
Value cast = b.create<arith::IndexCastOp>(
loc, resultTy.getElementType(), extent);
Value cast = b.create<T>(loc, resultTy.getElementType(), extent);
b.create<tensor::YieldOp>(loc, cast);
});
return success();
Expand All @@ -60,7 +59,9 @@ struct LowerIndexCastPass
: public impl::LowerIndexCastPassBase<LowerIndexCastPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<IndexCastConverter>(patterns.getContext());
patterns.add<IndexCastConverter<arith::IndexCastOp>,
IndexCastConverter<arith::IndexCastUIOp>>(
patterns.getContext());
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
Expand Down

0 comments on commit f493f16

Please sign in to comment.