Skip to content

Commit

Permalink
Handle unranked input types in inferMostSpecificTypeComponents (ope…
Browse files Browse the repository at this point in the history
…nxla#1046)

This is to fix the case when all the `inputTypes` in
`inferMostSpecificTypeComponents` is unranked.
  • Loading branch information
sdasgup3 authored and GleasonK committed Feb 10, 2023
1 parent 44f3eb2 commit 7177166
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,16 @@ LogicalResult inferMostSpecificTypeComponents(
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto inferredTypeOrErr = inferMostSpecificType(location, inputTypes);
if (failed(inferredTypeOrErr)) return failure();
auto resultType = (*inferredTypeOrErr).cast<RankedTensorType>();
inferredReturnShapes.emplace_back(resultType.getShape(),
resultType.getElementType(),
resultType.getEncoding());

auto rankedResultType = (*inferredTypeOrErr).dyn_cast<RankedTensorType>();
if (!rankedResultType) {
inferredReturnShapes.emplace_back(*inferredTypeOrErr);
} else {
inferredReturnShapes.emplace_back(rankedResultType.getShape(),
rankedResultType.getElementType(),
rankedResultType.getEncoding());
}

return success();
}

Expand Down

0 comments on commit 7177166

Please sign in to comment.