Update stablehlo.constant to TTIR conversion (#1068)
* Convert integer scalars of different bit widths explicitly.
* Convert boolean values to bfloat16, since TTNN does not support boolean data.
* Convert 64-bit integers to 32-bit integers, since TTNN does not support 64-bit
  integers. (A rough sketch of this element-type mapping follows below.)
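As a rough, hedged sketch of the element-type mapping these bullets describe (assuming only the behaviour stated above; mapElementType is a hypothetical helper name, not code from this commit):

// Hypothetical sketch of the element-type mapping described in the commit
// message; the actual logic lives in StablehloTypeConverter (see the diff below).
#include "mlir/IR/BuiltinTypes.h"

static mlir::Type mapElementType(mlir::Type elementType) {
  // Booleans (i1) are not supported by TTNN, so model them as bfloat16.
  if (elementType.isInteger(1))
    return mlir::BFloat16Type::get(elementType.getContext());
  // 64-bit integers are not supported by TTNN, so narrow them to 32 bits.
  if (mlir::isa<mlir::IntegerType>(elementType) &&
      elementType.getIntOrFloatBitWidth() == 64)
    return mlir::IntegerType::get(elementType.getContext(), 32);
  // All other element types pass through unchanged.
  return elementType;
}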
mmanzoorTT authored Nov 4, 2024
1 parent 25191b6 commit 6100428
Showing 4 changed files with 252 additions and 32 deletions.
4 changes: 4 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
@@ -51,6 +51,10 @@ class StablehloTypeConverter : public TypeConverter {
if (type.getElementTypeBitWidth() == 1) {
elementType = BFloat16Type::get(elementType.getContext());
changed = true;
} else if (type.getElementTypeBitWidth() == 64 &&
isa<IntegerType>(type.getElementType())) {
elementType = IntegerType::get(elementType.getContext(), 32);
changed = true;
}
// Create shape of 1-D tensor in case of scalar input.
if (shape.size() == 0) {
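The hunk above is cut off by the diff view right where the scalar case is handled. As a hedged, standalone sketch of that idea (assuming the converter simply promotes a rank-0 shape to a single-element shape; promoteScalarShape is a hypothetical helper, not code from this commit):

// Hypothetical illustration of promoting a scalar (rank-0) tensor type to a
// 1-D tensor type, as the truncated comment above describes.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

static mlir::RankedTensorType promoteScalarShape(mlir::RankedTensorType type) {
  llvm::SmallVector<int64_t> shape(type.getShape().begin(),
                                   type.getShape().end());
  if (shape.empty())
    shape.push_back(1); // e.g. tensor<f32> becomes tensor<1xf32>
  return mlir::RankedTensorType::get(shape, type.getElementType());
}

Combined with the element-type change above, a scalar tensor<i64> constant ends up as tensor<1xi32>, which is what the updated constant_op.mlir tests below check for.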
98 changes: 79 additions & 19 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -2,31 +2,30 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <algorithm>
#include <vector>

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

#include "mlir/Dialect/Traits.h"
#include <llvm/ADT/APFloat.h>
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/ValueRange.h>
#include <mlir/Support/LogicalResult.h>

#include <stablehlo/dialect/StablehloOps.h>

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

using namespace mlir;
using namespace mlir::tt;

@@ -315,12 +314,7 @@ class StableHLOToTTIRConstantOpConversionPattern
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));

// Scalar tensors are not supported by TTIR so we have to convert them to
// 1-D tensors.
mlir::ElementsAttr valueAttr =
srcOp.getValue().getShapedType().getShape().empty()
? convertTo1DTensor(srcOp.getValue())
: srcOp.getValue();
mlir::ElementsAttr valueAttr = getValueAttr(srcOp.getValue());

rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(srcOp, outputType,
valueAttr);
@@ -338,13 +332,50 @@ class StableHLOToTTIRConstantOpConversionPattern
return success();
}

mlir::ElementsAttr convertTo1DTensor(mlir::ElementsAttr valueAttr) const {
// Rebuild the value of the constant op for the following cases:
// 1. Scalar values: TTNN does not support scalar types, so they are
//    converted to 1-D tensors.
// 2. Boolean tensors: TTNN does not support boolean data, so they are
//    converted to bfloat16 tensors.
// 3. 64-bit integer tensors: TTNN does not support 64-bit integers, so they
//    are converted to 32-bit tensors.
mlir::ElementsAttr getValueAttr(mlir::ElementsAttr valueAttr) const {
Type elementType = valueAttr.getElementType();
size_t bitWidth = elementType.getIntOrFloatBitWidth();
bool isTensor = !valueAttr.getShapedType().getShape().empty();
bool isIntTensor = isTensor && isa<IntegerType>(elementType) &&
bitWidth != 1 && bitWidth != 64;
bool isFloatTensor = isTensor && isa<FloatType>(elementType);

if (isTensor && (isIntTensor || isFloatTensor)) {
return valueAttr;
}

mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(
getTypeConverter()->convertType(valueAttr.getShapedType()));
if (valueAttr.getElementType().isInteger()) {
return mlir::DenseElementsAttr::get<int>(valueType,
valueAttr.getSplatValue<int>());
} else {
if (isa<IntegerType>(elementType)) {
switch (bitWidth) {
case 1: {
return rebuildValueAttr<bool>(valueAttr, 1);
}
case 8: {
return rebuildValueAttr<int8_t>(valueAttr, 8);
}
case 16: {
return rebuildValueAttr<int16_t>(valueAttr, 16);
}
case 32: {
return rebuildValueAttr<int32_t>(valueAttr, 32);
}
case 64: {
return rebuildValueAttr<int64_t>(valueAttr, 32);
}
default: {
assert(false && "Unsupported integer type.");
}
}
}
if (isa<FloatType>(elementType)) {
// For float values, LLVM has a bug where not all float types can be
// iterated over in a DenseElementsAttr, so we have to use a
// different constructor.
@@ -353,6 +384,35 @@ class StableHLOToTTIRConstantOpConversionPattern
valueAttr.getValues<mlir::APFloat>().end());
return mlir::DenseElementsAttr::get(valueType, floatValues);
}
assert(false && "Unsupported data type.");
}

// Extract the values (using the given ElementType) and create a new data
// structure. This is used to convert scalars (of type boolean, int8, int16,
// int32, and int64) and tensors (of type boolean and int64).
template <typename ElementType>
mlir::ElementsAttr rebuildValueAttr(mlir::ElementsAttr valueAttr,
size_t bitWidth) const {
mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(
getTypeConverter()->convertType(valueAttr.getShapedType()));

// Represent boolean values as bfloat16.
if (bitWidth == 1) {
std::vector<mlir::APFloat> booleanValue = {};
for (ElementType value : valueAttr.getValues<ElementType>()) {
mlir::APFloat input(mlir::APFloat::BFloat(), value);
booleanValue.emplace_back(input);
}
return mlir::DenseElementsAttr::get(valueType, booleanValue);
}

// Represent the remaining integer types as APInt values.
std::vector<mlir::APInt> IntegerValue = {};
for (ElementType value : valueAttr.getValues<ElementType>()) {
mlir::APInt input(bitWidth, value);
IntegerValue.emplace_back(input);
}
return mlir::DenseElementsAttr::get(valueType, IntegerValue);
}
};

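To make the rebuild step concrete, here is a hedged, standalone illustration of the same idea for the 64-bit integer case. It is a sketch for reference only; narrowI64ToI32 is a hypothetical helper and not part of this commit:

// Hypothetical sketch: rebuild a 64-bit integer DenseElementsAttr as a 32-bit
// one, mirroring what rebuildValueAttr<int64_t>(valueAttr, 32) does above.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APInt.h"
#include <cstdint>
#include <vector>

static mlir::DenseElementsAttr narrowI64ToI32(mlir::DenseElementsAttr attr) {
  mlir::ShapedType oldType = attr.getType();
  auto newType = mlir::RankedTensorType::get(
      oldType.getShape(), mlir::IntegerType::get(attr.getContext(), 32));
  std::vector<mlir::APInt> values;
  for (int64_t v : attr.getValues<int64_t>())
    // Assumes each value fits in 32 bits.
    values.emplace_back(/*numBits=*/32, v, /*isSigned=*/true);
  return mlir::DenseElementsAttr::get(newType, values);
}

For a splat such as dense<3> : tensor<64xi64>, this yields dense<3> : tensor<64xi32>, matching the CHECK lines in the updated tests below.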
4 changes: 2 additions & 2 deletions test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir
@@ -15,7 +15,7 @@ module @jit_concat attributes {} {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2xi64>, tensor<1x2xi64>, tensor<4x2xi64>) -> tensor<4x2xi64>
// CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2xi32>, tensor<1x2xi32>, tensor<4x2xi32>) -> tensor<4x2xi32>
return %0 : tensor<4x2xi64>
}

@@ -42,7 +42,7 @@ module @jit_concat attributes {} {
dimension = 1 : i64
} : (tensor<256x512xi64>, tensor<256x256xi64>) -> tensor<256x768xi64>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<256x512xi64>, tensor<256x256xi64>, tensor<256x768xi64>) -> tensor<256x768xi64>
// CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<256x512xi32>, tensor<256x256xi32>, tensor<256x768xi32>) -> tensor<256x768xi32>
return %0 : tensor<256x768xi64>
}

178 changes: 167 additions & 11 deletions test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
@@ -1,31 +1,187 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_constant attributes {} {
func.func public @test_splat() -> tensor<64xf32> {
%0 = stablehlo.constant dense<0.3> : tensor<64xf32>
func.func public @test_boolean_scalar() -> tensor<i1> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16>
%0 = stablehlo.constant dense<true> : tensor<i1>
// CHECK: return %{{[0-9]+}} : tensor<1xbf16>
return %0 : tensor<i1>
}

func.func public @test_boolean_splat() -> tensor<64xi1> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64xbf16>}> : () -> tensor<64xbf16>
%0 = stablehlo.constant dense<true> : tensor<64xi1>
// CHECK: return %{{[0-9]+}} : tensor<64xbf16>
return %0 : tensor<64xi1>
}

func.func public @test_boolean_multiple() -> tensor<2x2xi1> {
// The ugly regex after `dense` is necessary because double opening square
// brackets indicate a substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[1.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00]]> : tensor<2x2xbf16>}> : () -> tensor<2x2xbf16>
%0 = stablehlo.constant dense<[[true, false], [false, true]]> : tensor<2x2xi1>
// CHECK: return %{{[0-9]+}} : tensor<2x2xbf16>
return %0 : tensor<2x2xi1>
}

func.func public @test_bfloat16_scalar() -> tensor<bf16> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16>
%0 = stablehlo.constant dense<3.0> : tensor<bf16>
// CHECK: return %{{[0-9]+}} : tensor<1xbf16>
return %0 : tensor<bf16>
}

func.func public @test_bfloat16_splat() -> tensor<64xbf16> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<64xbf16>}> : () -> tensor<64xbf16>
%0 = stablehlo.constant dense<3.0> : tensor<64xbf16>
// CHECK: return %{{[0-9]+}} : tensor<64xbf16>
return %0 : tensor<64xbf16>
}

func.func public @test_bfloat16_multiple() -> tensor<2x2xbf16> {
// The ugly regex after `dense` is necessary because double opening square
// brackets indicate a substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xbf16>}> : () -> tensor<2x2xbf16>
%0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xbf16>
// CHECK: return %{{[0-9]+}} : tensor<2x2xbf16>
return %0 : tensor<2x2xbf16>
}

func.func public @test_float16_scalar() -> tensor<f16> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
%0 = stablehlo.constant dense<3.0> : tensor<f16>
// CHECK: return %{{[0-9]+}} : tensor<1xf16>
return %0 : tensor<f16>
}

func.func public @test_float16_splat() -> tensor<64xf16> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<64xf16>}> : () -> tensor<64xf16>
%0 = stablehlo.constant dense<3.0> : tensor<64xf16>
// CHECK: return %{{[0-9]+}} : tensor<64xf16>
return %0 : tensor<64xf16>
}

func.func public @test_float16_multiple() -> tensor<2x2xf16> {
// The ugly regex after `dense` is necessary because double opening square
// brackets indicate a substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf16>}> : () -> tensor<2x2xf16>
%0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf16>
// CHECK: return %{{[0-9]+}} : tensor<2x2xf16>
return %0 : tensor<2x2xf16>
}

func.func public @test_float_scalar() -> tensor<f32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = stablehlo.constant dense<0.3> : tensor<f32>
// CHECK: return %{{[0-9]+}} : tensor<1xf32>
return %0 : tensor<f32>
}

func.func public @test_float_splat() -> tensor<64xf32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<64xf32>}> : () -> tensor<64xf32>
%0 = stablehlo.constant dense<0.3> : tensor<64xf32>
// CHECK: return %{{[0-9]+}} : tensor<64xf32>
return %0 : tensor<64xf32>
}

func.func public @test_multiple() -> tensor<2x2xf32> {
%0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
func.func public @test_float_multiple() -> tensor<2x2xf32> {
// The ugly regex after `dense` is necessary because double opening square
// brackets indicate a substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
%0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
// CHECK: return %{{[0-9]+}} : tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}

func.func public @test_scalar_int() -> tensor<i32> {
%0 = stablehlo.constant dense<3> : tensor<i32>
func.func public @test_int8_scalar() -> tensor<i8> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = stablehlo.constant dense<3> : tensor<i8>
// CHECK: return %{{[0-9]+}} : tensor<1xi8>
return %0 : tensor<i8>
}

func.func public @test_int8_splat() -> tensor<64xi8> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi8>}> : () -> tensor<64xi8>
%0 = stablehlo.constant dense<3> : tensor<64xi8>
// CHECK: return %{{[0-9]+}} : tensor<64xi8>
return %0 : tensor<64xi8>
}

func.func public @test_int8_multiple() -> tensor<2x2xi8> {
// The ugly regex after `dense` is necessary because double opening square
// brackets indicate a substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi8>}> : () -> tensor<2x2xi8>
%0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi8>
// CHECK: return %{{[0-9]+}} : tensor<2x2xi8>
return %0 : tensor<2x2xi8>
}

func.func public @test_int16_scalar() -> tensor<i16> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi16>}> : () -> tensor<1xi16>
%0 = stablehlo.constant dense<3> : tensor<i16>
// CHECK: return %{{[0-9]+}} : tensor<1xi16>
return %0 : tensor<i16>
}

func.func public @test_int16_splat() -> tensor<64xi16> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi16>}> : () -> tensor<64xi16>
%0 = stablehlo.constant dense<3> : tensor<64xi16>
// CHECK: return %{{[0-9]+}} : tensor<64xi16>
return %0 : tensor<64xi16>
}

func.func public @test_int16_multiple() -> tensor<2x2xi16> {
// The ugly regex after `dense` is necessary because double opening square
// brackets indicate a substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi16>}> : () -> tensor<2x2xi16>
%0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi16>
// CHECK: return %{{[0-9]+}} : tensor<2x2xi16>
return %0 : tensor<2x2xi16>
}

func.func public @test_int32_scalar() -> tensor<i32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>
%0 = stablehlo.constant dense<3> : tensor<i32>
// CHECK: return %{{[0-9]+}} : tensor<1xi32>
return %0 : tensor<i32>
}

func.func public @test_int32_splat() -> tensor<64xi32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi32>}> : () -> tensor<64xi32>
%0 = stablehlo.constant dense<3> : tensor<64xi32>
// CHECK: return %{{[0-9]+}} : tensor<64xi32>
return %0 : tensor<64xi32>
}

func.func public @test_int32_multiple() -> tensor<2x2xi32> {
// The ugly regex after `dense` is necessary because double opening square
// brackets indicate a substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
%0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
// CHECK: return %{{[0-9]+}} : tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}

func.func public @test_int64_scalar() -> tensor<i64> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>
%0 = stablehlo.constant dense<3> : tensor<i64>
// CHECK: return %{{[0-9]+}} : tensor<1xi32>
return %0 : tensor<i64>
}

func.func public @test_scalar_float() -> tensor<f32> {
%0 = stablehlo.constant dense<0.3> : tensor<f32>
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
return %0 : tensor<f32>
// CHECK: return %{{[0-9]+}} : tensor<1xf32>
func.func public @test_int64_splat() -> tensor<64xi64> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi32>}> : () -> tensor<64xi32>
%0 = stablehlo.constant dense<3> : tensor<64xi64>
// CHECK: return %{{[0-9]+}} : tensor<64xi32>
return %0 : tensor<64xi64>
}

func.func public @test_int64_multiple() -> tensor<2x2xi64> {
// The ugly regex after `dense` is necessary because double opening square
// brackets indicate a substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
%0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>
// CHECK: return %{{[0-9]+}} : tensor<2x2xi32>
return %0 : tensor<2x2xi64>
}
}
