-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port type inference for 6 ops from StableHLO to MHLO
Ops: 1) AfterAllOp: openxla/stablehlo#708. 2) CreateTokenOp: openxla/stablehlo#711. 3) DynamicUpdateSliceOp: openxla/stablehlo#686 and openxla/stablehlo#757. 4) OptimizationBarrierOp: openxla/stablehlo#575. 5) OutfeedOp: openxla/stablehlo#713. 6) SendOp: openxla/stablehlo#580. This PR prepares for migration from producing MHLO to producing StableHLO by aligning type inference between dialects, so that switching from one to another doesn't need changes to calls to Python builders. PiperOrigin-RevId: 495404149
- Loading branch information
1 parent f1647b0 · commit 4567fff
Showing 1 changed file with 117 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,118 @@ | ||
diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h | ||
--- stablehlo/stablehlo/dialect/Base.h | ||
+++ stablehlo/stablehlo/dialect/Base.h | ||
@@ -45,9 +45,12 @@ | ||
|
||
// TODO(zhouxin) change to a better name as it's used by both of size and bound | ||
// Check if the dimension size is dynamic. | ||
-// TODO(zhouxin) add isStaticDimSize() as well. | ||
inline static bool isDynamicDimSize(int64_t val) { | ||
return ShapedType::isDynamic(val); | ||
+} | ||
+ | ||
+inline static bool isStaticDimSize(int64_t val) { | ||
+ return !isDynamicDimSize(val); | ||
} | ||
|
||
// Returns true if the given types are the same for the purposes of HLO type | ||
diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp | ||
--- stablehlo/stablehlo/dialect/TypeInference.cpp | ||
+++ stablehlo/stablehlo/dialect/TypeInference.cpp | ||
@@ -960,18 +960,19 @@ | ||
auto updateType = update.getType().cast<ShapedType>(); | ||
|
||
// (C3) | ||
- int64_t operandRank = operandType.getRank(); | ||
- int64_t updateRank = updateType.getRank(); | ||
- if (updateRank != operandRank) | ||
+ if (updateType.hasRank() && operandType.hasRank() && | ||
+ updateType.getRank() != operandType.getRank()) | ||
return emitOptionalError( | ||
- location, "update rank does not match operand rank: ", updateRank, | ||
- " vs ", operandRank, "."); | ||
+ location, | ||
+ "update rank does not match operand rank: ", updateType.getRank(), | ||
+ " vs ", operandType.getRank(), "."); | ||
|
||
// (C4) | ||
- if ((int64_t)startIndices.size() != operandRank) | ||
+ if (operandType.hasRank() && | ||
+ (int64_t)startIndices.size() != operandType.getRank()) | ||
return emitOptionalError( | ||
location, "expects number of start_indices to match operand rank: ", | ||
- startIndices.size(), " vs ", operandRank, "."); | ||
+ startIndices.size(), " vs ", operandType.getRank(), "."); | ||
|
||
// (C5) | ||
if (!startIndices.empty()) { | ||
@@ -989,17 +990,31 @@ | ||
} | ||
|
||
// (C6) | ||
- for (auto [index, dims] : llvm::enumerate( | ||
- llvm::zip(operandType.getShape(), updateType.getShape()))) { | ||
- auto [operandDim, updateDim] = dims; | ||
- if (updateDim < 0 || updateDim > operandDim) | ||
- return emitOptionalError(location, "expects size at dimension ", index, | ||
- " of update to be in range [0, ", operandDim, | ||
- "]. Got: ", updateDim, "."); | ||
- } | ||
- | ||
- inferredReturnShapes.emplace_back(operandType.getShape(), | ||
- operandType.getElementType()); | ||
+ if (operandType.hasRank() && updateType.hasRank()) | ||
+ for (auto [index, dims] : llvm::enumerate( | ||
+ llvm::zip(operandType.getShape(), updateType.getShape()))) { | ||
+ auto [operandDim, updateDim] = dims; | ||
+ if (hlo::isDynamicDimSize(updateDim)) continue; | ||
+ if (hlo::isStaticDimSize(operandDim)) { | ||
+ if (updateDim < 0 || updateDim > operandDim) | ||
+ return emitOptionalError(location, "expects size at dimension ", | ||
+ index, " of update to be in range [0, ", | ||
+ operandDim, "]. Got: ", updateDim, "."); | ||
+ } else { | ||
+ if (updateDim < 0) | ||
+ return emitOptionalError( | ||
+ location, "expects size at dimension ", index, | ||
+ " of update to be non-negative. Got: ", updateDim, "."); | ||
+ } | ||
+ } | ||
+ | ||
+ // (C1) | ||
+ if (operandType.hasRank()) { | ||
+ inferredReturnShapes.emplace_back(operandType.getShape(), | ||
+ operandType.getElementType()); | ||
+ } else { | ||
+ inferredReturnShapes.emplace_back(operandType.getElementType()); | ||
+ } | ||
return success(); | ||
} | ||
|
||
@@ -1302,6 +1317,13 @@ | ||
LogicalResult inferSelectAndScatterOp( | ||
Value operand, SmallVectorImpl<Type>& inferredReturnTypes) { | ||
inferredReturnTypes.push_back(operand.getType()); | ||
+ return success(); | ||
+} | ||
+ | ||
+LogicalResult inferSendOp(Dialect* dialect, Optional<Location> location, | ||
+ SmallVectorImpl<Type>& inferredReturnTypes) { | ||
+ auto hloDialect = cast<HloDialectInterface>(dialect); | ||
+ inferredReturnTypes.push_back(hloDialect->createTokenType()); | ||
return success(); | ||
} | ||
|
||
diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.h b/stablehlo/stablehlo/dialect/TypeInference.h | ||
--- stablehlo/stablehlo/dialect/TypeInference.h | ||
+++ stablehlo/stablehlo/dialect/TypeInference.h | ||
@@ -197,6 +197,9 @@ | ||
LogicalResult inferSelectAndScatterOp( | ||
Value operand, SmallVectorImpl<Type>& inferredReturnTypes); | ||
|
||
+LogicalResult inferSendOp(Dialect* dialect, Optional<Location> location, | ||
+ SmallVectorImpl<Type>& inferredReturnTypes); | ||
+ | ||
LogicalResult inferSliceOp(Optional<Location> location, Value operand, | ||
DenseIntElementsAttr startIndices, | ||
DenseIntElementsAttr limitIndices, | ||
|