[AMD] Support more cache modifiers for CDNA (#4369)
- Provided the required arguments to the store operation
- Added a test case to test_core.py::test_store_cache_modifier
- Skipped non-gfx94 architectures (including gfx11) in the cache-modifier load/store tests

The current mapping is as follows (a usage sketch follows the lists):

Loads:
* ca (default) - cache at all levels
* cg - non-temporal (nt)

Stores:
* wb (default) - cache at all levels
* cg - plain cached store, same as the default
* cs - non-temporal (nt)
* wt - volatile (sc0 sc1)
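
A minimal usage sketch (not part of this commit; copy_kernel and the tensors
are illustrative names) showing how a Triton kernel selects these modifiers:

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(dst, src, BLOCK: tl.constexpr):
    offsets = tl.arange(0, BLOCK)
    # '.cg' lowers to a non-temporal (nt) global_load on gfx94
    x = tl.load(src + offsets, cache_modifier='.cg')
    # '.wt' lowers to a volatile store carrying sc0 sc1 on gfx94
    tl.store(dst + offsets, x, cache_modifier='.wt')

src = torch.arange(0, 128, device='cuda', dtype=torch.float32)
dst = torch.empty_like(src)
copy_kernel[(1,)](dst, src, BLOCK=128)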

Signed-off-by: Ilya Veselov <[email protected]>
joviliast authored Sep 4, 2024
1 parent 51e4b9e commit 7480ef5
Showing 5 changed files with 165 additions and 80 deletions.
118 changes: 70 additions & 48 deletions python/test/unit/language/test_core.py
@@ -3797,27 +3797,30 @@ def _kernel(dst, src, CACHE: tl.constexpr):

pgm = _kernel[(1, )](dst, src, CACHE=cache)

if not is_cuda():
if is_hip():
amdgcn = pgm.asm['amdgcn']
cache_modifier_str = 'nt' if 'gfx94' in get_arch() else 'glc'
global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
if cache == '':
assert cache_modifier_str not in global_load_line[0]
if cache == '.cg':
assert cache_modifier_str in global_load_line[0]
return
if is_hip():
target_arch = get_arch()
# TODO: support testing for remaining architectures
if 'gfx94' not in target_arch:
return
amdgcn = pgm.asm['amdgcn']
cg_cache_modifier_str = 'nt'
global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
if cache == '' or cache == '.ca':
assert cg_cache_modifier_str not in global_load_line[0]
if cache == '.cg':
assert cg_cache_modifier_str in global_load_line[0]

ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
if is_cuda():
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx


@pytest.mark.interpreter
@@ -3912,35 +3915,54 @@ def _kernel(dst, src, CACHE: tl.constexpr):
x = tl.load(src + offsets)
tl.store(dst + offsets, x, cache_modifier=CACHE)

if not is_cuda():
return
pgm = _kernel[(1, )](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if cache == '':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' not in ptx
if cache == '.wb':
assert 'st.global.wb' in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' not in ptx
if cache == '.cg':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' not in ptx
if cache == '.cs':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' in ptx
assert 'st.global.wt' not in ptx
if cache == '.wt':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' in ptx

if is_hip():
target_arch = get_arch()
# TODO: support testing for remaining architectures
if 'gfx94' not in target_arch:
return
amdgcn = pgm.asm['amdgcn']
cs_cache_modifier_str = 'nt'
wt_cache_modifier_str = 'sc0 sc1'
global_store_line = [line for line in amdgcn.splitlines() if "global_store" in line]
if cache == '' or cache == '.cg':
assert cs_cache_modifier_str not in global_store_line[0]
assert wt_cache_modifier_str not in global_store_line[0]
if cache == '.cs':
assert cs_cache_modifier_str in global_store_line[0]
assert wt_cache_modifier_str not in global_store_line[0]
if cache == '.wt':
assert cs_cache_modifier_str not in global_store_line[0]
assert wt_cache_modifier_str in global_store_line[0]

if is_cuda():
ptx = pgm.asm['ptx']
if cache == '':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' not in ptx
if cache == '.wb':
assert 'st.global.wb' in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' not in ptx
if cache == '.cg':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' not in ptx
if cache == '.cs':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' in ptx
assert 'st.global.wt' not in ptx
if cache == '.wt':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' in ptx


@pytest.mark.interpreter
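
For context, a standalone sketch (not part of the commit) of the inspection
these tests perform, reusing the illustrative copy_kernel, dst, and src from
the earlier sketch:

# Compile the kernel, then scan the AMDGCN assembly for the cache modifiers,
# mirroring what test_load_cache_modifier / test_store_cache_modifier assert.
pgm = copy_kernel[(1,)](dst, src, BLOCK=128)
amdgcn = pgm.asm['amdgcn']  # the compiled program exposes per-target assembly
load_lines = [line for line in amdgcn.splitlines() if 'global_load' in line]
store_lines = [line for line in amdgcn.splitlines() if 'global_store' in line]
assert 'nt' in load_lines[0]        # '.cg' load is non-temporal on gfx94
assert 'sc0 sc1' in store_lines[0]  # '.wt' store sets sc0/sc1 on gfx94
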
66 changes: 48 additions & 18 deletions third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
@@ -25,10 +25,8 @@ class CallOpConversion : public mlir::RewritePattern {
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto callOp = cast<LLVM::CallOp>(op);
if (isPredicatedLoadNT(callOp)) {
return convertPredicatedLoad(callOp, rewriter, /*nt=*/true);
} else if (isPredicatedLoad(callOp)) {
return convertPredicatedLoad(callOp, rewriter, /*nt=*/false);
if (isPredicatedLoad(callOp)) {
return convertPredicatedLoad(callOp, rewriter);
} else if (isPredicatedStore(callOp)) {
return convertPredicatedStore(callOp, rewriter);
} else if (isWrappedLLVMIntrinsic(callOp)) {
@@ -40,18 +38,37 @@ class CallOpConversion : public mlir::RewritePattern {

private:
bool isPredicatedLoad(LLVM::CallOp callOp) const {
return callOp.getCallee().value().find(mlir::LLVM::AMD::Predicated_Load) !=
llvm::StringRef::npos;
return callOp.getCallee().value().contains(mlir::LLVM::AMD::predicatedLoad);
}

bool isPredicatedLoadNT(LLVM::CallOp callOp) const {
return callOp.getCallee().value().find(
mlir::LLVM::AMD::Predicated_Load_NT) != llvm::StringRef::npos;
bool isPredicatedLoadCA(LLVM::CallOp callOp) const {
return callOp.getCallee().value().contains(
mlir::LLVM::AMD::predicatedLoadCA);
}

bool isPredicatedLoadCG(LLVM::CallOp callOp) const {
return callOp.getCallee().value().contains(
mlir::LLVM::AMD::predicatedLoadCG);
}

bool isPredicatedStore(LLVM::CallOp callOp) const {
return callOp.getCallee().value().find(mlir::LLVM::AMD::Predicated_Store) !=
llvm::StringRef::npos;
return callOp.getCallee().value().contains(
mlir::LLVM::AMD::predicatedStore);
}

bool isPredicatedStoreCS(LLVM::CallOp callOp) const {
return callOp.getCallee().value().contains(
mlir::LLVM::AMD::predicatedStoreCS);
}

bool isPredicatedStoreCG(LLVM::CallOp callOp) const {
return callOp.getCallee().value().contains(
mlir::LLVM::AMD::predicatedStoreCG);
}

bool isPredicatedStoreWT(LLVM::CallOp callOp) const {
return callOp.getCallee().value().contains(
mlir::LLVM::AMD::predicatedStoreWT);
}

bool isWrappedLLVMIntrinsic(LLVM::CallOp callOp) const {
@@ -79,16 +96,24 @@ class CallOpConversion : public mlir::RewritePattern {
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, pred, trueBlock, afterStore);
rewriter.setInsertionPointToStart(trueBlock);
auto storeOp = rewriter.create<LLVM::StoreOp>(loc, val, ptr);
/*
              | volatile | non-temporal | gfx94 instruction
LLVM::StoreOp |    0     |      0       | (cg) global store
              |    0     |      1       | (cs) global store nt
              |    1     |     0/1      | (wt) global store sc0 sc1
*/
bool volatileFlag = isPredicatedStoreWT(callOp);
bool nonTmpFlag = isPredicatedStoreCS(callOp);
auto storeOp = rewriter.create<LLVM::StoreOp>(
    loc, val, ptr, /*alignment=*/0, volatileFlag, nonTmpFlag);
rewriter.create<LLVM::BrOp>(loc, afterStore);
rewriter.setInsertionPointToStart(afterStore);
rewriter.eraseOp(callOp);
return mlir::success();
}

LogicalResult convertPredicatedLoad(LLVM::CallOp callOp,
mlir::PatternRewriter &rewriter,
bool nt) const {
mlir::PatternRewriter &rewriter) const {
auto operands = callOp.getOperands();
auto result = callOp.getResult();

@@ -108,10 +133,15 @@ class CallOpConversion : public mlir::RewritePattern {
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, pred, trueBlock, falseBlock);
rewriter.setInsertionPointToStart(trueBlock);
auto loadOp = nt ? rewriter.create<LLVM::LoadOp>(
loc, elemTy, ptr, /*alignment=*/0,
/*isVolatile=*/false, /*isNonTemporal=*/true)
: rewriter.create<LLVM::LoadOp>(loc, elemTy, ptr);
/*
             | volatile | non-temporal | gfx94 instruction
LLVM::LoadOp |    0     |      0       | (ca) global load
             |   0/1    |      1       | (cg) global load nt
*/
bool volatileFlag = false;
bool nonTmpFlag = isPredicatedLoadCG(callOp);
auto loadOp = rewriter.create<LLVM::LoadOp>(
    loc, elemTy, ptr, /*alignment=*/0, volatileFlag, nonTmpFlag);
rewriter.create<LLVM::BrOp>(loc, loadOp->getResult(0), afterLoad);
rewriter.setInsertionPointToStart(falseBlock);
rewriter.create<LLVM::BrOp>(loc, falseVal, afterLoad);
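
To summarize the lowering above, an illustrative Python table (an assumption
based on the comment tables in this file, not code from the commit) of how
each cache modifier sets the LLVM load/store flags on gfx94:

# Keys are Triton cache modifiers; values are the (volatile, non_temporal)
# flags passed to LLVM::LoadOp / LLVM::StoreOp plus the expected instruction.
LOAD_LOWERING = {
    '.ca': (False, False, 'global_load'),       # default: cache at all levels
    '.cg': (False, True,  'global_load nt'),    # non-temporal load
}
STORE_LOWERING = {
    '.wb': (False, False, 'global_store'),          # default write-back
    '.cg': (False, False, 'global_store'),          # treated like the default
    '.cs': (False, True,  'global_store nt'),       # non-temporal store
    '.wt': (True,  False, 'global_store sc0 sc1'),  # volatile sets sc0/sc1
}
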
8 changes: 5 additions & 3 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -189,6 +189,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
std::max(8u, valueElemTy.getIntOrFloatBitWidth());
const int numVecs = numElems / vec;

auto cacheMod = op.getCache();
SmallVector<Value> loadedVals;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is GEP with constant offset
@@ -224,8 +225,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
falseVal = v;
}

bool nt = op.getCache() == triton::CacheModifier::CG;
auto loadVal = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, nt);
auto loadVal =
llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, cacheMod);
for (size_t ii = 0; ii < vec; ++ii) {
Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec);
@@ -293,6 +294,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNBits = dtsize * 8;

auto cacheMod = op.getCache();
const int numVecs = elemsPerThread / vec;
for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
// TODO: optimization when ptr is AddPtr with constant offset
@@ -329,7 +331,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
llWord = bitcast(llWord, valArgTy);
Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;
auto address = ptrElems[vecStart + wordIdx * wordNElems];
llStore(rewriter, loc, address, llWord, maskVal);
llStore(rewriter, loc, address, llWord, maskVal, cacheMod);
}
}
rewriter.eraseOp(op);
37 changes: 31 additions & 6 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
@@ -156,12 +156,23 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
}

Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
Value pred, Value falseVal, bool nt) {
Value pred, Value falseVal, triton::CacheModifier cm) {
Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal}));
auto parent = ptr.getParentRegion()->getParentOfType<LLVM::LLVMFuncOp>();
auto funcNameRaw = nt ? mlir::LLVM::AMD::Predicated_Load_NT
: mlir::LLVM::AMD::Predicated_Load;
auto funcName = mangleFunc(funcNameRaw, funcType);
auto getLoadNameRaw = [](triton::CacheModifier cm) {
switch (cm) {
case triton::CacheModifier::CA:
return predicatedLoadCA;
case triton::CacheModifier::CG:
return predicatedLoadCG;
default:
// Do not fail at compile time for an unsupported modifier;
// just apply the default configuration.
return predicatedLoad;
}
};

auto funcName = mangleFunc(getLoadNameRaw(cm), funcType);

LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, parent, funcName, funcType);
@@ -173,11 +184,25 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
}

void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
Value pred) {
Value pred, triton::CacheModifier cm) {
auto ctx = ptr.getContext();
Type funcType = getFunctionType(void_ty(ctx), ValueRange({ptr, val, pred}));
auto parent = ptr.getParentRegion()->getParentOfType<LLVM::LLVMFuncOp>();
auto funcName = mangleFunc(mlir::LLVM::AMD::Predicated_Store, funcType);
auto getStoreNameRaw = [](triton::CacheModifier cm) {
switch (cm) {
case triton::CacheModifier::WT:
return predicatedStoreWT;
case triton::CacheModifier::CG:
return predicatedStoreCG;
case triton::CacheModifier::CS:
return predicatedStoreCS;
default:
// Do not fail at compile time for an unsupported modifier;
// just apply the default configuration.
return predicatedStore;
}
};
auto funcName = mangleFunc(getStoreNameRaw(cm), funcType);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, parent, funcName, funcType);
rewriter.create<LLVM::CallOp>(loc, funcOp, ValueRange({ptr, val, pred}));
16 changes: 11 additions & 5 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
@@ -10,9 +10,13 @@
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
namespace mlir::LLVM::AMD {

const char Predicated_Load[] = "__predicated_load";
const char Predicated_Load_NT[] = "__predicated_load_NT";
const char Predicated_Store[] = "__predicated_store";
const char predicatedLoad[] = "__predicated_load";
const char predicatedLoadCA[] = "__predicated_load_CA";
const char predicatedLoadCG[] = "__predicated_load_CG";
const char predicatedStore[] = "__predicated_store";
const char predicatedStoreCG[] = "__predicated_store_CG";
const char predicatedStoreCS[] = "__predicated_store_CS";
const char predicatedStoreWT[] = "__predicated_store_WT";

Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i);
Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i);
@@ -25,11 +29,13 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
// Loads from shared or global memory with predication.
// `falseVal` supplies the result for elements whose predicate is false.
Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
Value pred, Value falseVal, bool nt = false);
Value pred, Value falseVal,
triton::CacheModifier cm = triton::CacheModifier::NONE);

// Stores to shared or global memory with predication.
void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
Value pred);
Value pred,
triton::CacheModifier cm = triton::CacheModifier::NONE);
} // namespace mlir::LLVM::AMD

#endif
