Skip to content

Commit

Permalink
Address feedback: I
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Jan 18, 2024
1 parent 3aee4ac commit aafde87
Showing 1 changed file with 32 additions and 39 deletions.
71 changes: 32 additions & 39 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,20 +568,17 @@ LogicalResult verifyReduceOpInputsAndInferShape(

// Returns the types of the terminator arguments of the input mlir::Block
// 'block'.
LogicalResult getAccumulatorTypes(
std::optional<Location> loc, Region& region,
SmallVectorImpl<ShapedType>& accumulatorSubShapes) {
FailureOr<SmallVector<ShapedType>> getAccumulatorTypes(
std::optional<Location> loc, Region& region) {
if (region.empty()) {
return emitOptionalError(
loc, "Expects non-empty reduction block for type inference");
}

Block& block = region.front();
for (Value retOperand : block.getTerminator()->getOperands()) {
auto shapedTy = retOperand.getType().cast<ShapedType>();
accumulatorSubShapes.push_back(shapedTy);
}
return success();
return llvm::to_vector(
llvm::map_range(block.getTerminator()->getOperands(),
[&](Value v) { return v.getType().cast<ShapedType>(); }));
}

LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
Expand Down Expand Up @@ -1510,12 +1507,12 @@ LogicalResult inferAllReduceOp(
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
// all_reduce_c6, all_reduce_c7
SmallVector<ShapedType> accumulatorTypes;
if (failed(getAccumulatorTypes(location, computation, accumulatorTypes)))
return failure();
auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation);
if (failed(accumulatorTypesOrErr)) return failure();
for (size_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
inferredReturnShapes.emplace_back(getSameShapeTensorType(
inputArgTensorTypes[inputIdx], accumulatorTypes[0].getElementType()));
inferredReturnShapes.emplace_back(
getSameShapeTensorType(inputArgTensorTypes[inputIdx],
(*accumulatorTypesOrErr)[0].getElementType()));
}

return success();
Expand Down Expand Up @@ -2596,12 +2593,11 @@ LogicalResult inferReduceOp(
location, inputArgTensorTypes, dimensions, newDimensions, encoding)))
return failure();
// reduce_c3, reduce_c7, reduce_c8
SmallVector<ShapedType> accumulatorTypes;
if (failed(getAccumulatorTypes(location, body, accumulatorTypes)))
return failure();
auto accumulatorTypesOrErr = getAccumulatorTypes(location, body);
if (failed(accumulatorTypesOrErr)) return failure();
for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
ShapedType inputType = inputArgTensorTypes[inputIdx];
Type elementType = accumulatorTypes[inputIdx].getElementType();
Type elementType = (*accumulatorTypesOrErr)[inputIdx].getElementType();
if (inputType.hasRank())
inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
else
Expand Down Expand Up @@ -2634,24 +2630,24 @@ LogicalResult inferReduceWindowOp(
return failure();

// reduce_window_c1, reduce_window_c14...reduce_window_c16
SmallVector<ShapedType> accumulatorTypes;
if (failed(getAccumulatorTypes(location, body, accumulatorTypes)))
return failure();
auto accumulatorTypesOrErr = getAccumulatorTypes(location, body);
if (failed(accumulatorTypesOrErr)) return failure();
for (size_t i = 0; i < inputTypes.size(); ++i) {
auto inputRankedType = inputs[i].getType().dyn_cast<RankedTensorType>();
if (!inputRankedType) {
inferredReturnShapes.emplace_back(accumulatorTypes[i].getElementType());
inferredReturnShapes.emplace_back(
(*accumulatorTypesOrErr)[i].getElementType());
} else {
auto resultShape =
inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow);
auto inputBounds = encodingToBounds(inputRankedType.getEncoding());
if (inputBounds.empty()) {
inferredReturnShapes.emplace_back(resultShape,
accumulatorTypes[i].getElementType());
inferredReturnShapes.emplace_back(
resultShape, (*accumulatorTypesOrErr)[i].getElementType());
} else {
auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow);
inferredReturnShapes.emplace_back(
resultShape, accumulatorTypes[i].getElementType(),
resultShape, (*accumulatorTypesOrErr)[i].getElementType(),
boundsToEncoding(inputRankedType.getEncoding(), resultBounds));
}
}
Expand Down Expand Up @@ -2719,14 +2715,12 @@ LogicalResult inferScatterOp(std::optional<Location> location,
ValueRange inputs, Region& updateComputation,
SmallVectorImpl<Type>& inferredReturnTypes) {
// scatter_c16, scatter_c17
SmallVector<ShapedType> accumulatorTypes;
if (failed(
getAccumulatorTypes(location, updateComputation, accumulatorTypes)))
return failure();
auto accumulatorTypesOrErr = getAccumulatorTypes(location, updateComputation);
if (failed(accumulatorTypesOrErr)) return failure();
for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
auto inputShapedTy = inputs[inputIdx].getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
inputShapedTy, accumulatorTypes[inputIdx].getElementType()));
inputShapedTy, (*accumulatorTypesOrErr)[inputIdx].getElementType()));
}
return success();
}
Expand Down Expand Up @@ -2760,12 +2754,11 @@ LogicalResult inferSelectAndScatterOp(
std::optional<Location> location, Value operand, Region& scatter,
SmallVectorImpl<Type>& inferredReturnTypes) {
// select_and_scatter_c11, select_and_scatter_c12
SmallVector<ShapedType> accumulatorTypes;
if (failed(getAccumulatorTypes(location, scatter, accumulatorTypes)))
return failure();
auto accumulatorTypesOrErr = getAccumulatorTypes(location, scatter);
if (failed(accumulatorTypesOrErr)) return failure();
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
operandShapedTy, accumulatorTypes[0].getElementType()));
operandShapedTy, (*accumulatorTypesOrErr)[0].getElementType()));
return success();
}

Expand Down Expand Up @@ -3892,13 +3885,13 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,
}

// reduce_scatter_c9
SmallVector<ShapedType> accumulatorTypes;
if (failed(getAccumulatorTypes(location, computation, accumulatorTypes)))
return failure();
if (resultType.getElementType() != accumulatorTypes[0].getElementType()) {
auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation);
if (failed(accumulatorTypesOrErr)) return failure();
if (resultType.getElementType() !=
(*accumulatorTypesOrErr)[0].getElementType()) {
return emitOptionalError(location, "result element-type is expected to be ",
accumulatorTypes[0].getElementType(), ", but got ",
resultType.getElementType());
(*accumulatorTypesOrErr)[0].getElementType(),
", but got ", resultType.getElementType());
}

return success();
Expand Down

0 comments on commit aafde87

Please sign in to comment.