[MHLO] Init MHLO view-like op patterns
See RFC: llvm#999
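In short: torch.aten.view and torch.aten.reshape are lowered to mhlo.reshape when no shape computation is needed (rank-0 input or empty target shape) and to chlo.dynamic_reshape otherwise, with the requested dimension sizes materialized as a shape tensor. A minimal before/after sketch, condensed from the tests added in this commit:

%shape = torch.prim.ListConstruct %int-1, %int224 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %shape : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,224],f32>
// ...lowers (with TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32=ON) to:
%5 = tensor.from_elements %3, %4 : tensor<2xi32>
%6 = "chlo.dynamic_reshape"(%0, %5) : (tensor<?x?x?x?xf32>, tensor<2xi32>) -> tensor<?x224xf32>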

Co-authored-by: Bairen Yi [email protected]
Co-authored-by: Jiawei Wu [email protected]
Co-authored-by: Tianyou Guo [email protected]
Co-authored-by: Xu Yan [email protected]
Co-authored-by: Ziheng Jiang [email protected]
Tanyo Kwok committed Jul 22, 2022
1 parent a02dbb2 commit 6432021
Showing 6 changed files with 257 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -39,6 +39,11 @@ endmacro()
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
+  option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+    "Enable truncating dimension sizes from i64 to i32 (unsafe)" OFF)
+  if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
+    add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
+  endif()
endif()

torch_mlir_add_llvm_external_project(
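To opt in to the i32 truncation, pass the new option at configure time. A minimal sketch (the generator and MHLO flag are shown for context; your other torch-mlir configure options will differ):

cmake -GNinja \
  -DTORCH_MLIR_ENABLE_MHLO=ON \
  -DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32=ON \
  <other torch-mlir configure options...>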
4 changes: 3 additions & 1 deletion lib/Conversion/TorchToMhlo/CMakeLists.txt
@@ -1,13 +1,14 @@
add_mlir_conversion_library(TorchMLIRTorchToMhlo
TorchToMhlo.cpp
BasicOp.cpp
-  SliceLikeOps.cpp
+  ViewLikeOps.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo

DEPENDS
MhloDialect
+  ChloDialect
TorchMLIRConversionPassIncGen

LINK_COMPONENTS
@@ -17,6 +18,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
MLIRIR
MLIRPass
MhloDialect
+  ChloDialect
TorchMLIRTorchDialect
)

2 changes: 1 addition & 1 deletion lib/Conversion/TorchToMhlo/PopulatePatterns.h
@@ -19,7 +19,7 @@ namespace torch_to_mhlo {
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
-void populateSliceLikeOpPatternsAndLegality(TypeConverter &typeConverter,
+void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);

2 changes: 1 addition & 1 deletion lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
@@ -51,7 +51,7 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {

torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
target);
-    torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(typeConverter, patterns,
+    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
target);

if (failed(applyPartialConversion(getOperation(), target,
132 changes: 132 additions & 0 deletions lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
@@ -0,0 +1,132 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;


namespace {

// This defines a template to construct ops whose legalizations are
// specialized.
template <typename AtenOpT>
class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

LogicalResult matchAndRewrite(
AtenOpT op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto rankType =
adaptor.self().getType().template dyn_cast<RankedTensorType>();
if (!rankType)
return op.emitError("Only ranked tensor types are currently supported");

SmallVector<Value, 4> dimSizes;
if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) {
return op.emitError("Dims size must be a list of Scalar");
}

auto loc = op.getLoc();
auto newRank = dimSizes.size();
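// Degenerate cases (an empty target shape or a rank-0 input) produce a
// statically known result type, so lower directly to a static mhlo.reshape.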
if (newRank == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self());
return success();
}

// Unwrap each !torch.int dimension size into a builtin i64 value.
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
});

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
// 64-bit arithmetic is much slower than 32-bit on some devices, such as
// NVIDIA GPUs. Truncating from i64 to i32 is acceptable here because
// dimension sizes are unlikely to exceed the i32 range (4 GiB).
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
// dimSize: cast i64 -> i32
dSize = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
});
#endif

// Materialize the target shape as a 1-D shape tensor feeding the dynamic reshape.
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
rewriter.replaceOpWithNewOp<chlo::DynamicReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self(),
mhloShape);
return success();
}

bool getAtenViewOpSizes(
AtenOpT op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter,
SmallVector<Value, 4>& dimSizes) const;
};

template <>
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
AtenViewOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter,
SmallVector<Value, 4>& dimSizes) const {
return getListConstructElements(adaptor.size(), dimSizes);
}

template <>
bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
AtenReshapeOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter,
SmallVector<Value, 4>& dimSizes) const {
return getListConstructElements(adaptor.shape(), dimSizes);
}

} // namespace

void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();

#define INSERT_VIEW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN
}
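
The registration macro keeps further view-like ops cheap to add: each needs a getAtenViewOpSizes specialization plus one INSERT_VIEW_OP_PATTERN line. A hypothetical sketch for aten._unsafe_view, assuming its adaptor exposes size() like AtenViewOp's (not part of this commit):

template <>
bool ConvertAtenViewOp<Aten_UnsafeViewOp>::getAtenViewOpSizes(
    Aten_UnsafeViewOp op,
    OpAdaptor adaptor,
    ConversionPatternRewriter& rewriter,
    SmallVector<Value, 4>& dimSizes) const {
  // Hypothetical: mirrors the AtenViewOp specialization above.
  return getListConstructElements(adaptor.size(), dimSizes);
}
// ...and inside populateViewLikeOpPatternsAndLegality:
//   INSERT_VIEW_OP_PATTERN(Aten_UnsafeViewOp);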
115 changes: 115 additions & 0 deletions test/Conversion/TorchToMhlo/view_like.mlir
@@ -0,0 +1,115 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.view$view_like(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
// CHECK: %[[INT224:.*]] = torch.constant.int 224
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]224 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1
// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]]
// CHECK: %[[T4:.*]] = arith.trunci %[[T2]] : i64 to i32
// CHECK: %[[T5:.*]] = arith.trunci %[[T3]] : i64 to i32
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T4]], %[[T5]] : tensor<2xi32>
// CHECK: %[[T7:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T6]]) : (tensor<?x?x?x?xf32>, tensor<2xi32>) -> tensor<?x224xf32>
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
// CHECK: return %[[T8]] : !torch.vtensor<[?,224],f32>
func.func @torch.aten.view$view_like(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
%int-1 = torch.constant.int -1
%int224 = torch.constant.int 224
%0 = torch.prim.ListConstruct %int-1, %int224 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,224],f32>
return %1 : !torch.vtensor<[?,224],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.reshape$view_like(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?,?],f32> -> tensor<?x?x?x?x?xf32>
// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
// CHECK: %[[INT120:.*]] = torch.constant.int 120
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[INT64:.*]] = torch.constant.int 64
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]120, %[[INT]]4, %[[INT]]64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1
// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT120]]
// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[INT4]]
// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]]
// CHECK: %[[T6:.*]] = arith.trunci %[[T2]] : i64 to i32
// CHECK: %[[T7:.*]] = arith.trunci %[[T3]] : i64 to i32
// CHECK: %[[T8:.*]] = arith.trunci %[[T4]] : i64 to i32
// CHECK: %[[T9:.*]] = arith.trunci %[[T5]] : i64 to i32
// CHECK: %[[T10:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]], %[[T9]] : tensor<4xi32>
// CHECK: %[[T11:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor<?x?x?x?x?xf32>, tensor<4xi32>) -> tensor<?x120x4x64xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[?,120,4,64],f32>
func.func @torch.aten.reshape$view_like(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
%int-1 = torch.constant.int -1
%int120 = torch.constant.int 120
%int4 = torch.constant.int 4
%int64 = torch.constant.int 64
%0 = torch.prim.ListConstruct %int-1, %int120, %int4, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.reshape %arg0, %0 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,120,4,64],f32>
return %1 : !torch.vtensor<[?,120,4,64],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.view.minus1$view_like(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INT]]-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]]
// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]
// CHECK: %[[T6:.*]] = torch_c.to_i64 %[[INT]]-1
// CHECK: %[[T7:.*]] = arith.trunci %[[T4]] : i64 to i32
// CHECK: %[[T8:.*]] = arith.trunci %[[T5]] : i64 to i32
// CHECK: %[[T9:.*]] = arith.trunci %[[T6]] : i64 to i32
// CHECK: %[[T10:.*]] = tensor.from_elements %[[T7]], %[[T8]], %[[T9]] : tensor<3xi32>
// CHECK: %[[T11:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor<2x3x?x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<2x3x?xf32> -> !torch.vtensor<[2,3,?],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[2,3,?],f32>
func.func @torch.aten.view.minus1$view_like(%arg0: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
%int-1 = torch.constant.int -1
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[2,3,?,?],f32>, !torch.list<int> -> !torch.vtensor<[2,3,?],f32>
return %3 : !torch.vtensor<[2,3,?],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.view.to_rank1$view_like(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<f32>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[1],f32>
func.func @torch.aten.view.to_rank1$view_like(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1],f32>
return %1 : !torch.vtensor<[1],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.view.to_rank0$view_like(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[],f32>
func.func @torch.aten.view.to_rank0$view_like(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
return %1 : !torch.vtensor<[],f32>
}
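
The RUN line above makes this file a standard lit test; a hedged invocation sketch (the llvm-lit path depends on your build layout):

build/bin/llvm-lit -v test/Conversion/TorchToMhlo/view_like.mlir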
