Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer bounds for Reduce op #737

Merged
merged 4 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Type getExpressedTypeOrSelf(Type type) {
auto quantType = type.dyn_cast<quant::QuantizedType>();
return quantType ? quantType.getExpressedType() : type;
}
} // namespace

LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2) {
if (failed(verifyCompatibleShape(type1, type2))) return failure();
Expand Down Expand Up @@ -61,7 +62,6 @@ LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2) {
}
return success();
}
} // namespace

bool isCompatibleForHloTypeInference(Type tp1, Type tp2) {
// Dynamism: We don't require shapes to be the same, we only require them
Expand Down
4 changes: 4 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ inline static bool isStaticDimSize(int64_t val) {
return !isDynamicDimSize(val);
}

// Verifies that the two types have compatible shape with bounds but allows
// different element types.
LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2);

// Returns true if the given types are the same for the purposes of HLO type
// inference, accounting for special properties of quantization and sparsity.
bool isCompatibleForHloTypeInference(Type tp1, Type tp2);
Expand Down
32 changes: 22 additions & 10 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ LogicalResult verifyReplicaGroups(Optional<Location> location,
LogicalResult verifyReduceOpInputsAndInferShape(
Optional<Location> location, SmallVector<TensorType> inputArgTypes,
SmallVector<TensorType> initValueTypes, DenseIntElementsAttr dimensions,
SmallVector<int64_t>& newDimensions) {
SmallVector<int64_t>& newDimensions, Attribute& encoding) {
// Check for unranked tensors in input operands.
uint64_t numInputs = inputArgTypes.size();
int64_t rankedInputIdx = -1;
Expand Down Expand Up @@ -420,13 +420,21 @@ LogicalResult verifyReduceOpInputsAndInferShape(
}

if (!allInputsUnranked) {
for (int inputIdx = 0; inputIdx < inputArgTypes[rankedInputIdx].getRank();
++inputIdx) {
auto rankedInput = inputArgTypes[rankedInputIdx].cast<RankedTensorType>();

ArrayRef<int64_t> inputBounds = encodingToBounds(rankedInput.getEncoding());
SmallVector<int64_t> newBounds;
for (int inputIdx = 0; inputIdx < rankedInput.getRank(); ++inputIdx) {
if (!dimensionsToReduceSet.count(inputIdx)) {
newDimensions.push_back(
inputArgTypes[rankedInputIdx].getDimSize(inputIdx));
newDimensions.push_back(rankedInput.getDimSize(inputIdx));
if (!inputBounds.empty()) {
newBounds.push_back(inputBounds[inputIdx]);
}
}
}
if (!inputBounds.empty()) {
encoding = boundsToEncoding(rankedInput.getEncoding(), newBounds);
}
}
return success();
}
Expand Down Expand Up @@ -1251,15 +1259,17 @@ LogicalResult inferReduceOp(
[](Type t) -> TensorType { return t.cast<TensorType>(); })};

SmallVector<int64_t> newDimensions;
if (failed(verifyReduceOpInputsAndInferShape(
location, inputArgTypes, initValueTypes, dimensions, newDimensions)))
Attribute encoding;
if (failed(verifyReduceOpInputsAndInferShape(location, inputArgTypes,
initValueTypes, dimensions,
newDimensions, encoding)))
return failure();

for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
TensorType inputType = inputArgTypes[inputIdx];
Type elementType = inputType.getElementType();
if (inputType.hasRank())
inferredReturnShapes.emplace_back(newDimensions, elementType);
inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
else
inferredReturnShapes.emplace_back(elementType);
}
Expand Down Expand Up @@ -1582,8 +1592,10 @@ LogicalResult verifyReduceOp(Optional<Location> location, ValueRange inputs,

// P1. & P2.
SmallVector<int64_t> newDimensions;
if (failed(verifyReduceOpInputsAndInferShape(
location, inputArgTypes, initValueTypes, dimensions, newDimensions)))
Attribute encoding;
if (failed(verifyReduceOpInputsAndInferShape(location, inputArgTypes,
initValueTypes, dimensions,
newDimensions, encoding)))
return failure();

// P3.
Expand Down
60 changes: 51 additions & 9 deletions stablehlo/tests/infer_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -600,21 +600,63 @@ func.func @complex_sparsity(%arg0: tensor<10x10xf32, #CSR>, %arg1: tensor<10x10x
// -----

// CHECK-LABEL: func @reduce
func.func @reduce(%arg0: tensor<4x4xf32>, %arg1 : tensor<4xf32>)
-> (tensor<4xindex>) {
func.func @reduce(%arg0: tensor<7x5xf32>, %arg1 : tensor<5xf32>)
-> (tensor<5xindex>) {
%0 = "stablehlo.reduce"(%arg0, %arg1) ({

^bb0(%arg2: tensor<4xf32>, %arg3: tensor<4xf32> ):
%1 = "stablehlo.add"(%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
"stablehlo.return"(%1) : (tensor<4xf32>) -> ()
^bb0(%arg2: tensor<5xf32>, %arg3: tensor<5xf32> ):
%1 = "stablehlo.add"(%arg2, %arg3) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
"stablehlo.return"(%1) : (tensor<5xf32>) -> ()

}) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
}) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32>

// CHECK: %1 = "hlo_test_infer.return_type_components"(%0) {dims0 = "[4]", element_type0 = f32} : (tensor<4xf32>) -> tensor<4xindex>
// CHECK: {dims0 = "[5]", element_type0 = f32}
%2 = "hlo_test_infer.get_return_type_components"(%0)
: (tensor<4xf32>) -> tensor<4xindex>
: (tensor<5xf32>) -> tensor<5xindex>

func.return %2: tensor<4xindex>
func.return %2: tensor<5xindex>
}

// -----

// CHECK-LABEL: func @reduce_with_bounds
// Checks that reduce shape inference propagates dimension bounds from the
// input encoding: reducing dimension 0 of a tensor with bounds [3, 7, ?]
// drops that dimension's bound, leaving bounds [7, ?] on the result.
func.func @reduce_with_bounds(%arg0: tensor<?x?x5xf32, #stablehlo.type_extensions<bounds = [3, 7, ?]>>, %arg1 : tensor<5xf32>)
    -> (tensor<*xindex>) {
  %0 = "stablehlo.reduce"(%arg0, %arg1) ({

  ^bb0(%arg2: tensor<5xf32>, %arg3: tensor<5xf32> ):
    %1 = "stablehlo.add"(%arg2, %arg3) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
    "stablehlo.return"(%1) : (tensor<5xf32>) -> ()

  }) {dimensions = dense<[0]> : tensor<1xi64>}
      : (tensor<?x?x5xf32, #stablehlo.type_extensions<bounds = [3, 7, ?]>>, tensor<5xf32>)
          -> tensor<?x5xf32, #stablehlo.type_extensions<bounds = [7, ?]>>

  // The inferred return type must carry the surviving bounds [7, ?].
  // CHECK: types0 = tensor<?x5xf32, #stablehlo.type_extensions<bounds = [7, ?]>>
  %2 = "hlo_test_infer.get_return_types"(%0)
      : (tensor<?x5xf32, #stablehlo.type_extensions<bounds = [7, ?]>>) -> tensor<*xindex>

  func.return %2: tensor<*xindex>
}

// -----

// CHECK-LABEL: func @unranked_reduce
func.func @unranked_reduce(%arg0: tensor<*xf32>, %arg1 : tensor<f32>)
// (review thread: zhouxin913 marked this conversation as resolved — Show resolved / Hide resolved)
-> (tensor<*xindex>) {
%0 = "stablehlo.reduce"(%arg0, %arg1) ({

^bb0(%arg2: tensor<f32>, %arg3: tensor<f32> ):
%1 = "stablehlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()

}) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>

// CHECK: {element_type0 = f32}
%2 = "hlo_test_infer.get_return_type_components"(%0)
: (tensor<*xf32>) -> tensor<*xindex>

func.return %2: tensor<*xindex>
}

// -----
Expand Down