
Commit

AIRSpecializeChannelWrapAndStride, AIRRtToNPU: Refactor wrap-and-stride canonicalizer (#797)

* Have air-specialize-channel-wrap-and-stride and airrt-to-npu passes use the same util methods

* Clean up commented-out code

* Enable specializing both symbol and dim inputs of affine.apply into for loops
erwei-xilinx authored Nov 28, 2024
1 parent 3aed973 commit c1218d2
Showing 7 changed files with 384 additions and 411 deletions.
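
The core idea behind the refactored canonicalizer is to fold a perfect loop nest around a DMA operation into extra wrap (size) and stride dimensions on the DMA descriptor itself. The sketch below only illustrates that index arithmetic in plain C++; the copyBlock helper and all constants are invented for illustration and are not part of this commit.

#include <cstdint>

// Hypothetical element-wise copy helper, standing in for one DMA transfer.
static void copyBlock(const int32_t *src, int32_t *dst, int64_t n) {
  for (int64_t k = 0; k < n; ++k)
    dst[k] = src[k];
}

// Before folding: 4 transfers of 32 elements, each starting 128 elements
// further into the source buffer.
void loopOfTransfers(const int32_t *src, int32_t *dst) {
  for (int64_t i = 0; i < 4; ++i)
    copyBlock(src + i * 128, dst + i * 32, 32);
}

// After folding: one access pattern with sizes [4, 32] and strides [128, 1],
// which is what lands in the airrt.dma_memcpy_nd wrap-and-stride fields.
void foldedTransfer(const int32_t *src, int32_t *dst) {
  const int64_t sizes[2] = {4, 32}, strides[2] = {128, 1};
  for (int64_t i = 0; i < sizes[0]; ++i)
    for (int64_t j = 0; j < sizes[1]; ++j)
      dst[i * sizes[1] + j] = src[i * strides[0] + j * strides[1]];
}
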
21 changes: 17 additions & 4 deletions mlir/include/air/Util/Util.h
@@ -225,10 +225,9 @@ std::optional<int> getOffsetDimFromMemrefDim(int dimOnMemref,

// Evaluate the affine expression of an affine map on sparse vectors of
// constant integers.
std::optional<int64_t>
evaluateConstantsInMap(AffineMap map,
SmallVector<std::optional<int64_t>> const_inputs,
MLIRContext *ctx);
std::optional<int64_t> evaluateConstantsInMap(
AffineMap map, SmallVector<std::optional<int64_t>> symbolInputs,
SmallVector<std::optional<int64_t>> dimInputs, MLIRContext *ctx);

// Extend the lookupOrDefault method to operate on a vector of values.
Value lookupOrDefaultRange(Value v, IRMapping &remap);
@@ -239,6 +238,20 @@ SmallVector<Value> lookupOrDefaultRange(OperandRange vec, IRMapping &remap);
// Extend isPure method to operate on air.execute.
bool isPure(Operation *op);

// Return whether the given block contains exactly N impure ops, not counting
// async wait ops (such as air.wait_all).
bool hasNImpureOps(Block *block, unsigned N);

// Return whether the given block contains exactly N ops, not counting the
// block's terminator.
bool hasNElements(Block *block, unsigned N);

// Clone backward slices of a list of values.
SmallVector<Operation *> cloneDefiningOpsInRegion(OpBuilder builder,
Region *region,
SmallVectorImpl<Value> &opers,
IRMapping &remap);

} // namespace air
} // namespace xilinx
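
A minimal sketch of how the new Util.h helpers might be called from a pass, assuming only the declarations shown above; the affine map and the constant values are invented for illustration.

#include "air/Util/Util.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;
namespace air = xilinx::air;

// Evaluate affine_map<(d0)[s0] -> (d0 + s0 * 4)> at d0 = 2, s0 = 3.
std::optional<int64_t> evalExample(MLIRContext *ctx) {
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr s0 = getAffineSymbolExpr(0, ctx);
  AffineMap map =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1, d0 + s0 * 4);
  SmallVector<std::optional<int64_t>> symbolInputs = {3};
  SmallVector<std::optional<int64_t>> dimInputs = {2};
  // Expected to yield 14 when every referenced input is a constant, and
  // std::nullopt when a required input is missing.
  return air::evaluateConstantsInMap(map, symbolInputs, dimInputs, ctx);
}

// The wrap-and-stride folding in the pass below only fires when the loop body
// holds a single impure op (async wait ops such as air.wait_all are ignored).
bool bodyQualifies(Block *body) { return air::hasNImpureOps(body, 1); }
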

312 changes: 140 additions & 172 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
@@ -770,190 +770,155 @@ void enforceAIE2WrapLimit(ModuleOp module) {
tileIllegalWrapDim(memcpy_op);
}

LogicalResult
specializeAffineForInAIRRtDmaWrapAndStride(OpBuilder builder,
affine::AffineForOp for_op) {
auto loc = for_op->getLoc();
auto ctx = for_op->getContext();

// Declaration of constants
auto i64Ty = builder.getI64Type();
auto i64_zero =
builder.create<arith::ConstantOp>(loc, i64Ty, IntegerAttr::get(i64Ty, 0));
auto i64_one =
builder.create<arith::ConstantOp>(loc, i64Ty, IntegerAttr::get(i64Ty, 1));

// Check if the loop is the outermost loop in a perfect loop nest
auto hasNElements = [](Block *block, unsigned N) {
auto op_ptr = block->begin();
for (unsigned i = 0; i < N; i++)
op_ptr = std::next(op_ptr);
return op_ptr != block->end() && &*op_ptr == &block->back();
};
if (auto parent_for = dyn_cast<affine::AffineForOp>(for_op->getParentOp()))
if (hasNElements(parent_for.getBody(), 1))
struct AIRSpecializeAIRRtDmaWrapAndStrideInAffineFor
: public OpRewritePattern<affine::AffineForOp> {
using OpRewritePattern<affine::AffineForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(affine::AffineForOp for_op,
PatternRewriter &rewriter) const override {
auto loc = for_op->getLoc();
auto ctx = for_op->getContext();

// Declaration of constants
auto i64Ty = rewriter.getI64Type();
auto i64_zero = rewriter.create<arith::ConstantOp>(
loc, i64Ty, IntegerAttr::get(i64Ty, 0));
auto i64_one = rewriter.create<arith::ConstantOp>(
loc, i64Ty, IntegerAttr::get(i64Ty, 1));

if (!air::hasNImpureOps(for_op.getBody(), 1))
return failure();

// Check if the loop nest contains exactly one memcpy op
SmallVector<airrt::DmaMemcpyNdOp> memcpy_ops;
for_op.getBody()->walk(
[&](airrt::DmaMemcpyNdOp putget) { memcpy_ops.push_back(putget); });
if (memcpy_ops.size() != 1)
return failure();
// Check if the loop contains exactly one memcpy op
if (llvm::range_size(for_op.getBody()->getOps<airrt::DmaMemcpyNdOp>()) != 1)
return failure();
airrt::DmaMemcpyNdOp memcpy_op =
*(for_op.getBody()->getOps<airrt::DmaMemcpyNdOp>().begin());

// Fold for loops into memcpy op's wrap and stride fields
auto memref = memcpy_op->getOperand(3);
auto memref_shape = xilinx::air::getTensorShape(memref.getType());
auto oper_begin = memcpy_op.getOperands().begin();
SmallVector<Value> offsets(oper_begin + 4, oper_begin + 8);
SmallVector<Value> wraps(oper_begin + 8, oper_begin + 12);
SmallVector<Value> strides(oper_begin + 12, oper_begin + 15);
// Stride field implicit last element one
strides.push_back(i64_one);

(void)air::canonicalizeWrapAndStrideList(
rewriter, offsets, wraps, strides,
air::getTensorVolume(memref.getType()));

// If empty offsets/sizes/strides, then populate the lists with default
// values.
if (offsets.empty() && wraps.empty() && strides.empty())
air::populateDefaultWrapsAndStrides(rewriter, memref, offsets, wraps,
strides);

auto res = air::foldForLoopNestAsExtendedSizesAndStrides(
rewriter, for_op.getOperation(), memcpy_op.getOperation(), offsets,
wraps, strides, memref);
if (res.failed())
return failure();

// Fold for loops into channel op's wrap and stride fields
auto memref = memcpy_ops[0]->getOperand(3);
auto memref_shape = xilinx::air::getTensorShape(memref.getType());
auto oper_begin = memcpy_ops[0].getOperands().begin();
SmallVector<Value> offsets(oper_begin + 4, oper_begin + 8);
SmallVector<Value> wraps(oper_begin + 8, oper_begin + 12);
SmallVector<Value> strides(oper_begin + 12, oper_begin + 15);
// Stride field implicit last element one
strides.push_back(i64_one);

// Canonicalize wraps and strides
(void)air::canonicalizeWrapAndStrideList(
builder, offsets, wraps, strides, air::getTensorVolume(memref.getType()));

// If empty offsets/sizes/strides, then populate the lists with default
// values.
if (offsets.empty() && wraps.empty() && strides.empty()) {
auto memref_shape = air::getTensorShape(memref.getType());
int current_stride = air::getTensorVolume(memref.getType());
for (unsigned i = 0; i < memref_shape.size(); i++) {
offsets.push_back(builder.create<arith::ConstantIndexOp>(loc, 0));
wraps.push_back(
builder.create<arith::ConstantIndexOp>(loc, memref_shape[i]));
current_stride /= memref_shape[i];
strides.push_back(
builder.create<arith::ConstantIndexOp>(loc, current_stride));
}
}
auto res = xilinx::air::foldForLoopNestAsExtendedSizesAndStrides(
builder, for_op.getOperation(), memcpy_ops[0].getOperation(), offsets,
wraps, strides, memcpy_ops[0]->getOperand(3));
if (res.failed())
return failure();
if (offsets.size() > 4 || wraps.size() > 4 || strides.size() > 4)
return failure();

if (offsets.size() > 4 || wraps.size() > 4 || strides.size() > 4)
return failure();
(void)air::canonicalizeWrapAndStrideList(
rewriter, offsets, wraps, strides,
air::getTensorVolume(memref.getType()));

// Stride field implicit last element one
strides.pop_back();
while (offsets.size() < 4) {
offsets.insert(offsets.begin(), i64_zero);
}
while (wraps.size() < 4) {
wraps.insert(wraps.begin(), i64_one);
}
while (strides.size() < 3) {
strides.insert(strides.begin(), i64_one);
}
// Stride field implicit last element one
strides.pop_back();
while (offsets.size() < 4) {
offsets.insert(offsets.begin(), i64_zero);
}
while (wraps.size() < 4) {
wraps.insert(wraps.begin(), i64_one);
}
while (strides.size() < 3) {
strides.insert(strides.begin(), i64_one);
}

// Stride = 0 means repeat that dimension. If highest dimension (dim 0) is not
// used, then move the repeat dimension to dim 0, which is the only dim with
// repeat capability. Else, NYI. Fall back to unrolling BDs.
for (unsigned i = 1; i < strides.size(); i++) {
if (mlir::getConstantIntValue(wraps[i]) &&
mlir::getConstantIntValue(strides[i])) {
if (*mlir::getConstantIntValue(wraps[i]) > 1 &&
!*mlir::getConstantIntValue(strides[i])) {
// This is a repeat dimension.
if (mlir::getConstantIntValue(wraps[0]) &&
*mlir::getConstantIntValue(wraps[0]) == 1) {
// Move the repeat dimension i to dimension 0.
auto tmp = wraps[0];
wraps[0] = wraps[i];
wraps[i] = tmp;
tmp = strides[0];
strides[0] = strides[i];
strides[i] = tmp;
} else
return failure();
// Stride = 0 means repeat that dimension. If highest dimension (dim 0) is
// not used, then move the repeat dimension to dim 0, which is the only dim
// with repeat capability. Else, fall back to unrolling BDs.
unsigned activeDimsInBetween = 0;
for (unsigned i = 1; i < strides.size(); i++) {
auto constWrap = mlir::getConstantIntValue(wraps[i]);
auto constStride = mlir::getConstantIntValue(strides[i]);
if (!constWrap)
continue;
if (!constStride)
continue;
if (*constWrap <= 1)
continue; // Inactive dimension. Continue.
if (*constStride) {
// Found active dimension after dim 0. Any subsequent repeat dimension
// shall not bump to dim 0 anymore.
activeDimsInBetween++;
continue;
}
// This is a repeat dimension.
if (mlir::getConstantIntValue(wraps[0]) &&
*mlir::getConstantIntValue(wraps[0]) == 1 && !activeDimsInBetween) {
// Dimension 0 is available. Move the repeat dimension i to dimension 0.
auto tmp = wraps[0];
wraps[0] = wraps[i];
wraps[i] = tmp;
tmp = strides[0];
strides[0] = strides[i];
strides[i] = tmp;
} else {
(void)loopUnrollFull(for_op);
return success();
}
}
}
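// Illustration (not part of this diff): with wraps = [1, 4, 32, 32] and
// strides = [64, 0, 32] (plus the implicit trailing stride of 1), dim 1 has
// wrap > 1 and stride 0, i.e. it is a repeat dimension. Because wraps[0] == 1
// and no active dimension sits in between, the swap above moves the repeat to
// dim 0, the only dimension an AIE2 BD can repeat, giving
// wraps = [4, 1, 32, 32] and strides = [0, 64, 32]. If dim 0 were already
// active, the loop would instead be fully unrolled.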

// Create new airrt.dma_memcpy_nd
SmallVector<Type, 1> tys;
if (memcpy_ops[0]->getNumResults())
tys.push_back(airrt::EventType::get(ctx));

SmallVector<Value, 16> opers;
auto old_opers = memcpy_ops[0]->getOperands();
opers.insert(opers.end(), old_opers.begin(), old_opers.begin() + 4);
opers[1] =
builder.create<arith::ConstantOp>(loc, i64Ty, IntegerAttr::get(i64Ty, 0));
opers[2] =
builder.create<arith::ConstantOp>(loc, i64Ty, IntegerAttr::get(i64Ty, 0));
opers.insert(opers.end(), offsets.begin(), offsets.end());
opers.insert(opers.end(), wraps.begin(), wraps.end());
opers.insert(opers.end(), strides.begin(), strides.end());

// index_cast
IRMapping indexOperMap;
for (unsigned i = 0; i < opers.size(); i++) {
if (opers[i].getDefiningOp() &&
isa<arith::ConstantIndexOp>(opers[i].getDefiningOp())) {
opers[i] =
builder.clone(*opers[i].getDefiningOp(), indexOperMap)->getResult(0);
opers[i] = builder.create<arith::IndexCastOp>(
loc, IntegerType::get(ctx, 64), opers[i]);
} else if (opers[i].getDefiningOp() &&
isa<arith::IndexCastOp>(opers[i].getDefiningOp())) {
auto castOp = dyn_cast<arith::IndexCastOp>(opers[i].getDefiningOp());
if (castOp.getOperand().getDefiningOp() &&
isa<arith::ConstantOp>(castOp.getOperand().getDefiningOp()))
builder.clone(*castOp.getOperand().getDefiningOp(), indexOperMap);
opers[i] = builder.clone(*castOp, indexOperMap)->getResult(0);
// Create new airrt.dma_memcpy_nd
SmallVector<Type, 1> tys;
if (memcpy_op->getNumResults())
tys.push_back(airrt::EventType::get(ctx));

SmallVector<Value, 16> opers;
auto old_opers = memcpy_op->getOperands();
opers.insert(opers.end(), old_opers.begin(), old_opers.begin() + 4);
opers[1] = rewriter.create<arith::ConstantOp>(loc, i64Ty,
IntegerAttr::get(i64Ty, 0));
opers[2] = rewriter.create<arith::ConstantOp>(loc, i64Ty,
IntegerAttr::get(i64Ty, 0));
opers.insert(opers.end(), offsets.begin(), offsets.end());
opers.insert(opers.end(), wraps.begin(), wraps.end());
opers.insert(opers.end(), strides.begin(), strides.end());

// Hoist const ops; create index_cast ops.
IRMapping remap;
for (unsigned i = 0; i < opers.size(); i++) {
auto defOp = opers[i].getDefiningOp<arith::ConstantOp>();
if (!defOp)
continue;
if (opers[i].getType() == memcpy_op->getOperandTypes()[i])
continue;
opers[i] = rewriter.clone(*defOp, remap)->getResult(0);
opers[i] = getValueOrCreateCastToIndexLike(
rewriter, loc, memcpy_op->getOperandTypes()[i], opers[i]);
}
}
auto new_dma = builder.create<airrt::DmaMemcpyNdOp>(loc, tys, opers);
new_dma->setAttrs(memcpy_ops[0]->getDiscardableAttrDictionary());

return success();
}
// Hoist any pure ops that the new dma op depends on.
(void)air::cloneDefiningOpsInRegion(rewriter, &for_op.getRegion(), opers,
remap);

void specializeAffineForInAIRRtDmaWrapAndStride(ModuleOp module) {
SmallVector<func::FuncOp> funcOps;
module.walk([&](func::FuncOp f) { funcOps.push_back(f); });
llvm::SmallSet<Operation *, 1> erased;
SmallVector<affine::AffineForOp> unroll_outer_dim;
auto specialzeAllAffineFors =
[&](SmallVector<func::FuncOp> funcOps,
llvm::SmallSet<Operation *, 1> &erased,
SmallVector<affine::AffineForOp> &unroll_outer_dim) {
for (auto f : funcOps) {
for (auto for_op : f.getOps<affine::AffineForOp>()) {
OpBuilder builder(for_op);
if (specializeAffineForInAIRRtDmaWrapAndStride(builder, for_op)
.succeeded())
erased.insert(for_op);
else {
// Wait list to be unrolled one outer dimension, and then try
// specializing the wraps and strides again.
unroll_outer_dim.push_back(for_op);
}
}
}
};
specialzeAllAffineFors(funcOps, erased, unroll_outer_dim);
for (auto o : erased)
o->erase();
erased.clear();
// In AIE2 BD, there is one single dimension capable of repeating. If
// unroll_outer_dim isn't empty, then unroll the existing dimension in the
// repeat dim and repopulate that dimension with a true repeat dimension.
for (auto o : unroll_outer_dim) {
int64_t tripCount = llvm::divideCeilSigned(o.getConstantUpperBound() -
o.getConstantLowerBound(),
o.getStepAsInt());
(void)loopUnrollByFactor(o, tripCount);
auto new_dma = rewriter.create<airrt::DmaMemcpyNdOp>(
loc, tys, air::lookupOrDefaultRange(opers, remap));
new_dma->setAttrs(memcpy_op->getDiscardableAttrDictionary());

rewriter.eraseOp(for_op.getOperation());

return success();
}
specialzeAllAffineFors(funcOps, erased, unroll_outer_dim);
for (auto o : erased)
o->erase();
}

private:
};

struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
void runOnOperation() override {
@@ -985,7 +950,10 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
(void)applyPatternsAndFoldGreedily(module, std::move(canoPatterns_0));

// Specialize affine for loop nest into wraps and strides
specializeAffineForInAIRRtDmaWrapAndStride(module);
RewritePatternSet loopFoldPattern(ctx);
loopFoldPattern.add<AIRSpecializeAIRRtDmaWrapAndStrideInAffineFor>(ctx);
air::populateAIRLoopIndexCanonicalizationPatterns(loopFoldPattern);
(void)applyPatternsAndFoldGreedily(module, std::move(loopFoldPattern));
unrollAffineFors(module);

// Simplify arith ops (from airrt)
(Diffs for the remaining 5 changed files did not load.)
