Skip to content

Commit

Permalink
Port type inference for 6 ops from StableHLO to MHLO
Browse files Browse the repository at this point in the history
Ops:
  1) AfterAllOp: openxla/stablehlo#708.
  2) CreateTokenOp: openxla/stablehlo#711.
  3) DynamicUpdateSliceOp: openxla/stablehlo#686 and openxla/stablehlo#757.
  4) OptimizationBarrierOp: openxla/stablehlo#575.
  5) OutfeedOp: openxla/stablehlo#713.
  6) SendOp: openxla/stablehlo#580.

This PR prepares for migration from producing MHLO to producing StableHLO by
aligning type inference between dialects, so that switching from one to another
doesn't need changes to calls to Python builders.

PiperOrigin-RevId: 495404149
  • Loading branch information
Eugene Burmako authored and copybara-github committed Dec 14, 2022
1 parent f1647b0 commit 4567fff
Showing 1 changed file with 117 additions and 0 deletions.
117 changes: 117 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
@@ -1 +1,118 @@
diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h
--- stablehlo/stablehlo/dialect/Base.h
+++ stablehlo/stablehlo/dialect/Base.h
@@ -45,9 +45,12 @@

// TODO(zhouxin) change to a better name as it's used by both of size and bound
// Check if the dimension size is dynamic.
-// TODO(zhouxin) add isStaticDimSize() as well.
inline static bool isDynamicDimSize(int64_t val) {
return ShapedType::isDynamic(val);
+}
+
+inline static bool isStaticDimSize(int64_t val) {
+ return !isDynamicDimSize(val);
}

// Returns true if the given types are the same for the purposes of HLO type
diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp
--- stablehlo/stablehlo/dialect/TypeInference.cpp
+++ stablehlo/stablehlo/dialect/TypeInference.cpp
@@ -960,18 +960,19 @@
auto updateType = update.getType().cast<ShapedType>();

// (C3)
- int64_t operandRank = operandType.getRank();
- int64_t updateRank = updateType.getRank();
- if (updateRank != operandRank)
+ if (updateType.hasRank() && operandType.hasRank() &&
+ updateType.getRank() != operandType.getRank())
return emitOptionalError(
- location, "update rank does not match operand rank: ", updateRank,
- " vs ", operandRank, ".");
+ location,
+ "update rank does not match operand rank: ", updateType.getRank(),
+ " vs ", operandType.getRank(), ".");

// (C4)
- if ((int64_t)startIndices.size() != operandRank)
+ if (operandType.hasRank() &&
+ (int64_t)startIndices.size() != operandType.getRank())
return emitOptionalError(
location, "expects number of start_indices to match operand rank: ",
- startIndices.size(), " vs ", operandRank, ".");
+ startIndices.size(), " vs ", operandType.getRank(), ".");

// (C5)
if (!startIndices.empty()) {
@@ -989,17 +990,31 @@
}

// (C6)
- for (auto [index, dims] : llvm::enumerate(
- llvm::zip(operandType.getShape(), updateType.getShape()))) {
- auto [operandDim, updateDim] = dims;
- if (updateDim < 0 || updateDim > operandDim)
- return emitOptionalError(location, "expects size at dimension ", index,
- " of update to be in range [0, ", operandDim,
- "]. Got: ", updateDim, ".");
- }
-
- inferredReturnShapes.emplace_back(operandType.getShape(),
- operandType.getElementType());
+ if (operandType.hasRank() && updateType.hasRank())
+ for (auto [index, dims] : llvm::enumerate(
+ llvm::zip(operandType.getShape(), updateType.getShape()))) {
+ auto [operandDim, updateDim] = dims;
+ if (hlo::isDynamicDimSize(updateDim)) continue;
+ if (hlo::isStaticDimSize(operandDim)) {
+ if (updateDim < 0 || updateDim > operandDim)
+ return emitOptionalError(location, "expects size at dimension ",
+ index, " of update to be in range [0, ",
+ operandDim, "]. Got: ", updateDim, ".");
+ } else {
+ if (updateDim < 0)
+ return emitOptionalError(
+ location, "expects size at dimension ", index,
+ " of update to be non-negative. Got: ", updateDim, ".");
+ }
+ }
+
+ // (C1)
+ if (operandType.hasRank()) {
+ inferredReturnShapes.emplace_back(operandType.getShape(),
+ operandType.getElementType());
+ } else {
+ inferredReturnShapes.emplace_back(operandType.getElementType());
+ }
return success();
}

@@ -1302,6 +1317,13 @@
LogicalResult inferSelectAndScatterOp(
Value operand, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(operand.getType());
+ return success();
+}
+
+LogicalResult inferSendOp(Dialect* dialect, Optional<Location> location,
+ SmallVectorImpl<Type>& inferredReturnTypes) {
+ auto hloDialect = cast<HloDialectInterface>(dialect);
+ inferredReturnTypes.push_back(hloDialect->createTokenType());
return success();
}

diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.h b/stablehlo/stablehlo/dialect/TypeInference.h
--- stablehlo/stablehlo/dialect/TypeInference.h
+++ stablehlo/stablehlo/dialect/TypeInference.h
@@ -197,6 +197,9 @@
LogicalResult inferSelectAndScatterOp(
Value operand, SmallVectorImpl<Type>& inferredReturnTypes);

+LogicalResult inferSendOp(Dialect* dialect, Optional<Location> location,
+ SmallVectorImpl<Type>& inferredReturnTypes);
+
LogicalResult inferSliceOp(Optional<Location> location, Value operand,
DenseIntElementsAttr startIndices,
DenseIntElementsAttr limitIndices,

0 comments on commit 4567fff

Please sign in to comment.