Skip to content

Commit

Permalink
Some more fixes, cleanup, refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
zahimoud committed Apr 28, 2023
1 parent a4ee545 commit 0171697
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 31 deletions.
4 changes: 3 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ struct ConvertLayoutOpConversion
unsigned dim = sliceLayout.getDim();
auto parentEncoding = sliceLayout.getParent();
auto parentSizePerThread = getSizePerThread(parentEncoding);
unsigned stride = parentSizePerThread[dim];
unsigned stride = 1;
if (getOrder(parentEncoding)[0] == dim)
stride = parentSizePerThread[dim];
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
parentEncoding);
Expand Down
30 changes: 1 addition & 29 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mma, type);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result = emitIndicesForSliceLayout(loc, b, slice, type);
result = emitIndicesForDistributedLayout(loc, b, slice, type);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
Expand Down Expand Up @@ -913,34 +913,6 @@ class ConvertTritonGPUOpToLLVMPatternBase {
return resultOffsets;
}

SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
RankedTensorType type) const {
auto parentEncoding = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentEncoding);

unsigned dim = sliceLayout.getDim();
// step 1, delinearize threadId to get the base index
auto multiDimBase =
emitBaseIndexForLayout(loc, rewriter, sliceLayout, type);
// step 2, get offset of each element
auto offset = emitOffsetForSliceLayout(sliceLayout, type);
// step 3, add offset to base, and reorder the sequence of indices to
// guarantee that elems in the same sizePerThread are adjacent in order
auto shape = type.getShape();
unsigned rank = shape.size();
unsigned elemsPerThread = offset.size();
SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
SmallVector<Value>(rank));
for (unsigned n = 0; n < elemsPerThread; ++n)
for (unsigned k = 0; k < rank; ++k)
multiDimIdx[n][k] = add(multiDimBase[k], i32_val(offset[n][k]));
return multiDimIdx;
}

protected:
TritonGPUToLLVMTypeConverter *converter;
const Allocation *allocation;
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
// TODO: maybe should not be supported
auto sizePerThread = getSizePerThread(sliceLayout.getParent());
sizePerThread.erase(sizePerThread.begin() + sliceLayout.getDim());
return sizePerThread;
Expand Down

0 comments on commit 0171697

Please sign in to comment.