Skip to content

Commit

Permalink
Remove ReduceOp::build() to sync with MLIR-HLO (#392)
Browse files Browse the repository at this point in the history
fix #391
Please see the detailed link in the related bug.
  • Loading branch information
Xin Zhou authored Oct 26, 2022
1 parent 8d51fe2 commit 5e41ad2
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 39 deletions.
35 changes: 0 additions & 35 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2737,41 +2737,6 @@ LogicalResult ReducePrecisionOp::verify() {
// ReduceOp
//===----------------------------------------------------------------------===//

// Returns the result type after reducing operand of the given type across the
// specified dimensions.
static TensorType getReduceResultType(Type operandTy,
DenseIntElementsAttr dimensions,
Builder* builder) {
Type elementTy = getElementTypeOrSelf(operandTy);

auto rankedTy = operandTy.dyn_cast<RankedTensorType>();
if (!rankedTy) return UnrankedTensorType::get(elementTy);

int64_t rank = rankedTy.getRank();
llvm::SmallVector<bool, 4> dimsMask(rank, false);
for (int64_t dim : dimensions.getValues<int64_t>()) dimsMask[dim] = true;

SmallVector<int64_t, 4> shape;
for (int64_t i = 0; i < rank; ++i) {
if (!dimsMask[i]) shape.push_back(rankedTy.getDimSize(i));
}

return RankedTensorType::get(shape, elementTy);
}

void ReduceOp::build(OpBuilder& builder, OperationState& state,
ValueRange operands, ValueRange initValues,
DenseIntElementsAttr dimensions) {
SmallVector<Type, 1> resultTy;
resultTy.reserve(operands.size());

for (Value operand : operands) {
resultTy.push_back(
getReduceResultType(operand.getType(), dimensions, &builder));
}
build(builder, state, resultTy, operands, initValues, dimensions);
}

bool hasSameOperandAndResultTypes(Operation& op) {
Type expected;
if (op.getNumResults() != 0) expected = op.getResult(0).getType();
Expand Down
4 changes: 0 additions & 4 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1342,10 +1342,6 @@ def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [

let results = (outs Variadic<HLO_Tensor>);

let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$init_values,
"DenseIntElementsAttr":$dimensions)>];

let hasCustomAssemblyFormat = 1;

// TODO(hinsu): Verify that the attached body arguments and results are
Expand Down

0 comments on commit 5e41ad2

Please sign in to comment.