Skip to content

Commit

Permalink
Code redability improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
smit-hinsu committed Dec 10, 2022
1 parent 3b9f4ea commit 707ff46
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,24 +397,21 @@ LogicalResult verifyReduceOpInputsAndInferShape(
}
}

SmallVector<int64_t> newBounds;
if (!allInputsUnranked) {
RankedTensorType firstRankedInput =
inputArgTypes[rankedInputIdx].cast<RankedTensorType>();
ArrayRef<int64_t> inputBounds =
encodingToBounds(firstRankedInput.getEncoding());
for (int inputIdx = 0; inputIdx < inputArgTypes[rankedInputIdx].getRank();
++inputIdx) {
auto rankedInput = inputArgTypes[rankedInputIdx].cast<RankedTensorType>();

ArrayRef<int64_t> inputBounds = encodingToBounds(rankedInput.getEncoding());
SmallVector<int64_t> newBounds;
for (int inputIdx = 0; inputIdx < rankedInput.getRank(); ++inputIdx) {
if (!dimensionsToReduceSet.count(inputIdx)) {
newDimensions.push_back(
inputArgTypes[rankedInputIdx].getDimSize(inputIdx));
newDimensions.push_back(rankedInput.getDimSize(inputIdx));
if (!inputBounds.empty()) {
newBounds.push_back(inputBounds[inputIdx]);
}
}
}
if (!inputBounds.empty()) {
encoding = boundsToEncoding(firstRankedInput.getEncoding(), newBounds);
encoding = boundsToEncoding(rankedInput.getEncoding(), newBounds);
}
}
return success();
Expand Down Expand Up @@ -1228,11 +1225,10 @@ LogicalResult inferReduceOp(
for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
TensorType inputType = inputArgTypes[inputIdx];
Type elementType = inputType.getElementType();
if (inputType.hasRank()) {
if (inputType.hasRank())
inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
} else {
else
inferredReturnShapes.emplace_back(elementType);
}
}

return success();
Expand Down

0 comments on commit 707ff46

Please sign in to comment.