Skip to content

Commit

Permalink
[TOSA] Fix Tensor.hacked_twin to support diff size indexes
Browse files Browse the repository at this point in the history
- Broadcasts index list

Signed-off-by: Suraj Sudhir <[email protected]>
  • Loading branch information
sjarus committed Jul 18, 2024
1 parent f0ce1e9 commit 50f0589
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 6 deletions.
123 changes: 118 additions & 5 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3797,13 +3797,126 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
}

// Right now only support multiple indexes with same shape
// TODO for different shape multiple indexes, add broadcast_to for small
// shape
auto getRankExtendedShape =
[](SmallVector<int64_t> inputShape,
SmallVector<int64_t> maxRank1DimShape) -> SmallVector<int64_t> {
SmallVector<int64_t> rankExtendedShape(maxRank1DimShape);
auto inputRank = inputShape.size();
auto maxRank = maxRank1DimShape.size();
auto startIdx = maxRank - inputRank;
for (size_t i = startIdx; i < maxRank; i++) {
rankExtendedShape[i] = inputShape[i - startIdx];
}
return rankExtendedShape;
};

bool hasDiffShapedIndexes = false;
for (auto indexShapeOneDim : indexesShape) {
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: Only support multi indexes with same shape");
hasDiffShapedIndexes = true;
break;
}
}

if (hasDiffShapedIndexes) {
int64_t maxRank = 1;
for (auto idxRank : indexesRank) {
if (idxRank > maxRank)
maxRank = idxRank;
}
// Tensor shape of max rank, each dim being 1
SmallVector<int64_t> maxRank1DimShape;
for (int i = 0; i < maxRank; i++)
maxRank1DimShape.push_back(1);
// Tensor shape of max rank, each dim being the max dim.
SmallVector<int64_t> maxRankMaxDimShape(maxRank1DimShape);

auto updateMaxRankMaxDimShape =
[&](SmallVector<int64_t> broadcastedShape) -> LogicalResult {
for (size_t i = 0; i < maxRankMaxDimShape.size(); i++) {
// check for malformed index tensors
if (broadcastedShape[i] != 1 && maxRankMaxDimShape[i] != 1 &&
maxRankMaxDimShape[i] != broadcastedShape[i]) {
return failure();
}
if (broadcastedShape[i] > maxRankMaxDimShape[i])
maxRankMaxDimShape[i] = broadcastedShape[i];
}
return success();
};

for (size_t i = 0; i < indexesRank.size(); i++) {
// Reshape all index tensors to same maxRank
auto idxRank = indexesRank[i];
auto unreshapedIdxTensor = indicesTfConcatTensors[i];
SmallVector<int64_t> broadcastedShape =
getRankExtendedShape(indexesShape[i], maxRank1DimShape);

if (idxRank < maxRank) {
auto idxType =
dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
// indicesTfConcatTensors has a trailing [1] dim for the final concat.
auto broadcastedShapeTf(broadcastedShape);
broadcastedShapeTf.push_back(1);
auto reshapeOutputTy = RankedTensorType::get(
broadcastedShapeTf, idxType.getElementType());
// Update the tensor array with the max rank-extended form
indicesTfConcatTensors[i] = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), reshapeOutputTy, unreshapedIdxTensor,
rewriter.getDenseI64ArrayAttr(broadcastedShapeTf));
}

// Construct the max rank broadcasted form of all index tensors with
// each index tensor.
if (updateMaxRankMaxDimShape(broadcastedShape).failed()) {
return rewriter.notifyMatchFailure(
op, "Malformed index tensors that have mismatched dim shapes");
}

// Every index now has the same rank but not yet same shape until
// tosa.tile below.
indexesShape[i] = broadcastedShape;
indexesRank[i] = maxRank;
}

auto getTileOpShape = [&](SmallVector<int64_t> indexShape,
SmallVector<int64_t> &tileOpShape) -> bool {
bool needsTiling = false;
for (size_t i = 0; i < indexShape.size(); i++) {
if (1 == indexShape[i]) {
tileOpShape.push_back(maxRankMaxDimShape[i]);
needsTiling = true;
} else {
tileOpShape.push_back(1);
}
}
return needsTiling;
};

// Use tosa.tile to broadcast in multiple dims so all index tensors have
// the same shape. This materializes new tensors.
for (size_t i = 0; i < indexesRank.size(); i++) {
SmallVector<int64_t> tileOpShape;
bool needsTiling = getTileOpShape(indexesShape[i], tileOpShape);

if (needsTiling) {
auto idxType =
dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
// indicesTfConcatTensors has a trailing [1] dim for the final concat.
auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
maxRankMaxDimShapeTf.push_back(1);
auto tileOpShapeTf(tileOpShape);
tileOpShapeTf.push_back(1);
auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
idxType.getElementType());
auto reshapedIdxTensor = indicesTfConcatTensors[i];
indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
op->getLoc(), tileOutputTy, reshapedIdxTensor,
rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
}

// Every index tensor now has the same rank and shape
indexesShape[i] = maxRankMaxDimShape;
}
}

Expand Down
5 changes: 4 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# this is added to check the torch.onnx.export -> import_onnx -> torch path
"DeformConv2D_basic",
"ReduceAnyDimFloatModule_basic",
"UnfoldModule_basic",
}

LINALG_CRASHING_SET = {
Expand Down Expand Up @@ -1981,6 +1982,8 @@
"TorchPrimLoopForLikeTensorArgModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
}

MAKE_FX_TOSA_PASS_SET = (
Expand Down Expand Up @@ -2740,6 +2743,7 @@
"ReduceAnyFloatModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"UnfoldModule_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
Expand Down Expand Up @@ -3177,7 +3181,6 @@
"IndexSelectWholeTensorModule_basic",
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
Expand Down
24 changes: 24 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5646,3 +5646,27 @@ def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils):
module.forward(
torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
)


# ==============================================================================


class UnfoldModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.unfold = torch.nn.Unfold(kernel_size=(2, 3))

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, input):
return self.unfold(input)


@register_test_case(module_factory=lambda: UnfoldModule())
def UnfoldModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 3, 4))

0 comments on commit 50f0589

Please sign in to comment.