Allow quantized tensor as the input for mhlo.reshape. This also includes changing the folding logic to allow constant folding on quantized tensors.

There are basically two changes proposed here. The first, in the `reshape` method, is a fix that keeps folders from crashing on quantized types, so that translations like MHLO -> (X) can run smoothly for quantized types. The second, in `materializeConstant`, allows constant folding for quantized types.

We note that these fixes are not ideal; there is a [ticket](openxla/stablehlo#1691) to explore better solutions.
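As a minimal before/after sketch of the folding this enables (adapted from the per-tensor test added below; not part of the commit itself):

// Before canonicalization: a quantized constant feeding a reshape.
%cst = mhlo.constant() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform<i8:f32, 2.000000e+00:16>>
%0 = "mhlo.reshape"(%cst) : (tensor<2x2x!quant.uniform<i8:f32, 2.000000e+00:16>>) -> tensor<4x!quant.uniform<i8:f32, 2.000000e+00:16>>

// After: the reshape folds away into a single reshaped constant.
%0 = mhlo.constant() {value = dense<[1, 2, 3, 4]> : tensor<4xi8>} : () -> tensor<4x!quant.uniform<i8:f32, 2.000000e+00:16>>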

PiperOrigin-RevId: 570133154
tensorflower-gardener authored and TensorFlow MLIR Team committed Oct 2, 2023
1 parent 01a41c8 commit 29f9c05
Showing 3 changed files with 67 additions and 2 deletions.
26 changes: 25 additions & 1 deletion mhlo/IR/hlo_ops.cc
@@ -259,6 +259,16 @@ DenseElementsAttr reshape(DenseElementsAttr attr, ShapedType newType) {
auto splatValue = attr.getValues<bool>()[0];
return DenseElementsAttr::get(newType, {splatValue});
}
// Bypass the element type check for quantized tensor. For quantized tensors,
// we only require storage type and shape match the attribute type and shape.
if (auto quantElemTy =
newType.getElementType().dyn_cast<quant::QuantizedType>()) {
// Only shape and storage type information is needed to reshape the
// attribute.
auto quantShapedType =
RankedTensorType::get(newType.getShape(), quantElemTy.getStorageType());
return attr.reshape(quantShapedType);
}
return attr.reshape(newType);
}

@@ -7143,7 +7153,21 @@ Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
// HLO dialect constants only support ElementsAttr unlike standard dialect
// constant which supports all attributes.
if (!elementsAttr) return nullptr;
// HLO dialect constants require the type of value and result to match.
auto resultShapedType = type.dyn_cast<ShapedType>();
auto attrShapedType = elementsAttr.getType().dyn_cast<ShapedType>();
if (resultShapedType && attrShapedType) {
if (auto quantElemTy = resultShapedType.getElementType()
.dyn_cast<quant::QuantizedType>()) {
// Attribute type and shape should match storage type and shape for
// quantized tensors.
if ((attrShapedType.getElementType() != quantElemTy.getStorageType()) ||
(attrShapedType.getShape() != resultShapedType.getShape()))
return nullptr;
}
return builder.create<mhlo::ConstantOp>(loc, type, elementsAttr);
}
// HLO dialect constants require the type of value and result to match for
// non-quantized tensors.
if (type != elementsAttr.getType()) return nullptr;

return builder.create<mhlo::ConstantOp>(loc, type, elementsAttr);
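For context on the check above, a minimal sketch of the IR it permits (adapted from the tests in this commit): the constant's backing attribute uses the i8 storage type and the attribute shape matches the result shape, while the result element type is quantized, so an exact attribute/result type equality test would wrongly reject it.

%cst = mhlo.constant() {value = dense<[1, 2, 3, 4]> : tensor<4xi8>} : () -> tensor<4x!quant.uniform<i8:f32, 2.000000e+00:16>>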
2 changes: 1 addition & 1 deletion mhlo/IR/hlo_ops.td
@@ -2718,7 +2718,7 @@ def MHLO_MapOp: MHLO_ShapedInterfaceOp<"map",
}

def MHLO_ReshapeOp: MHLO_Op<"reshape",
[Pure, SameOperandsAndResultElementType]> {
[Pure, HLO_CompatibleOperandsAndResultElementType]> {
let summary = "Reshape operation";
let description = [{
Performs reshape of `operand` tensor to a `result` tensor.
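For context on the trait swap above: SameOperandsAndResultElementType requires identical operand and result element types, which would reject a reshape whose per-axis quantized dimension moves along with the shape. A minimal sketch of an op that the relaxed trait accepts (mirroring the per-axis test below):

// The quantized dimension moves from axis 2 to axis 1; the element types
// differ but share the same storage type and quantization parameters.
%0 = "mhlo.reshape"(%arg) : (tensor<2x2x2x!quant.uniform<i8:f32:2, {2.000000e+00:16,3.000000e+00:32}>>) -> tensor<4x2x!quant.uniform<i8:f32:1, {2.000000e+00:16,3.000000e+00:32}>>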
41 changes: 41 additions & 0 deletions tests/Dialect/mhlo/canonicalize/reshape.mlir
@@ -147,3 +147,44 @@ func.func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
// CHECK-NEXT: return [[RES]]
func.return %4 : tensor<1x2x4x3xi32>
}

// -----

// CHECK-LABEL: func @fold_per_tensor_quantized_tensor
func.func @fold_per_tensor_quantized_tensor() -> tensor<4x2x!quant.uniform<i8:f32, 2.000000e+0:16>> {
%cst = mhlo.constant() {value = dense<[[[1, 2],[3, 4]],[[5, 6],[7, 8]]]> : tensor<2x2x2xi8>} : () -> tensor<2x2x2x!quant.uniform<i8:f32, 2.000000e+0:16>>
%0 = "mhlo.reshape"(%cst) : (tensor<2x2x2x!quant.uniform<i8:f32, 2.000000e+0:16>>) -> tensor<4x2x!quant.uniform<i8:f32, 2.000000e+0:16>>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant()
// CHECK-SAME: [1, 2], [3, 4], [5, 6], [7, 8]
// CHECK-SAME: () -> tensor<4x2x!quant.uniform<i8:f32, 2.000000e+00:16>>
// CHECK-NEXT: return [[CST]]
func.return %0 : tensor<4x2x!quant.uniform<i8:f32, 2.000000e+0:16>>
}

// -----

// CHECK-LABEL: func @fold_per_axis_quantized_tensor
func.func @fold_per_axis_quantized_tensor() -> tensor<4x2x!quant.uniform<i8:f32:1, {2.000000e+0:16,3.000000e+0:32}>> {
%cst = mhlo.constant() {value = dense<[[[1, 2],[3, 4]],[[5, 6],[7, 8]]]> : tensor<2x2x2xi8>} : () -> tensor<2x2x2x!quant.uniform<i8:f32:2, {2.000000e+0:16,3.000000e+0:32}>>
%0 = "mhlo.reshape"(%cst) : (tensor<2x2x2x!quant.uniform<i8:f32:2, {2.000000e+0:16,3.000000e+0:32}>>) -> tensor<4x2x!quant.uniform<i8:f32:1, {2.000000e+0:16,3.000000e+0:32}>>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant()
// CHECK-SAME: [1, 2], [3, 4], [5, 6], [7, 8]
// CHECK-SAME: () -> tensor<4x2x!quant.uniform<i8:f32:1, {2.000000e+00:16,3.000000e+00:32}>>
// CHECK-NEXT: return [[CST]]
func.return %0 : tensor<4x2x!quant.uniform<i8:f32:1, {2.000000e+0:16,3.000000e+0:32}>>
}

// -----

// CHECK-LABEL: func @non_const_many_chained_reshapes_quantized
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @non_const_many_chained_reshapes_quantized(%arg : tensor<2x3x4x!quant.uniform<i8:f32, 2.000000e+0:16>>) -> tensor<1x2x4x3x!quant.uniform<i8:f32, 2.000000e+0:16>> {
%0 = "mhlo.reshape"(%arg) : (tensor<2x3x4x!quant.uniform<i8:f32, 2.000000e+0:16>>) -> tensor<4x3x2x!quant.uniform<i8:f32, 2.000000e+0:16>>
%1 = "mhlo.reshape"(%0) : (tensor<4x3x2x!quant.uniform<i8:f32, 2.000000e+0:16>>) -> tensor<12x2x!quant.uniform<i8:f32, 2.000000e+0:16>>
%2 = "mhlo.reshape"(%1) : (tensor<12x2x!quant.uniform<i8:f32, 2.000000e+0:16>>) -> tensor<2x12x!quant.uniform<i8:f32, 2.000000e+0:16>>
%3 = "mhlo.reshape"(%2) : (tensor<2x12x!quant.uniform<i8:f32, 2.000000e+0:16>>) -> tensor<24x!quant.uniform<i8:f32, 2.000000e+0:16>>
%4 = "mhlo.reshape"(%3) : (tensor<24x!quant.uniform<i8:f32, 2.000000e+0:16>>) -> tensor<1x2x4x3x!quant.uniform<i8:f32, 2.000000e+0:16>>
// CHECK-NEXT: [[RES:%.+]] = mhlo.reshape [[ARG]] : (tensor<2x3x4x!quant.uniform<i8:f32, 2.000000e+00:16>>) -> tensor<1x2x4x3x!quant.uniform<i8:f32, 2.000000e+00:16>>
// CHECK-NEXT: return [[RES]]
func.return %4 : tensor<1x2x4x3x!quant.uniform<i8:f32, 2.000000e+0:16>>
}
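These tests follow the usual lit/FileCheck pattern; a RUN line along these lines (the exact pass-pipeline spelling in the repository may differ) exercises them:

// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s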
