From dfe2ad7f96fb0e92ae0942795ea1ac74d5be48f8 Mon Sep 17 00:00:00 2001
From: Minjang Kim
Date: Wed, 15 May 2024 22:19:36 -0700
Subject: [PATCH] Quick patches to make it work after rebasing (#3)

---
 lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp          | 4 ++--
 lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp          | 2 +-
 lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp  | 6 +++---
 python/triton/runtime/jit.py                              | 2 +-
 third_party/cpu/backend/compiler.py                       | 6 +++---
 third_party/cpu/backend/driver.py                         | 4 +++-
 6 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp
index 96a1c5d1619f..b424cf8e37b7 100644
--- a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp
+++ b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp
@@ -39,7 +39,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
     for (size_t i = 0; i < op.getNumOperands(); i++) {
       auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter);
 
-      if (op.getOperand(i).getType().dyn_cast<RankedTensorType>()) {
+      if (dyn_cast<RankedTensorType>(op.getOperand(i).getType())) {
        llvm_unreachable("Not implemented for tensor types");
       }
 
@@ -61,7 +61,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
   std::string getFormatSubstr(Value value, bool hex = false,
                               std::optional<int> width = std::nullopt) const {
     Type type = value.getType();
-    if (type.isa<LLVM::LLVMPointerType>()) {
+    if (isa<LLVM::LLVMPointerType>(type)) {
       return "%p";
     }
     // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the
diff --git a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp
index e8ca0810c195..72ef796fdabb 100644
--- a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp
+++ b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp
@@ -24,7 +24,7 @@ Type TritonCPUToLLVMTypeConverter::convertTritonPointerType(
     triton::PointerType type) {
   auto ctx = type.getContext();
   auto pointeeType = type.getPointeeType();
-  if (pointeeType.isa<RankedTensorType>()) {
+  if (isa<RankedTensorType>(pointeeType)) {
     llvm_unreachable("Not implemented");
   }
   return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace());
diff --git a/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp b/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp
index dabc2a27a87b..97948404bdbf 100644
--- a/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp
+++ b/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp
@@ -26,7 +26,7 @@ TritonCPUTypeConverter::TritonCPUTypeConverter(MLIRContext *context)
   addConversion([this](triton::PointerType ptrType) -> triton::PointerType {
     // Check whether tensor pointer `tt.ptr<tensor<>>`
     auto pointeeTensorType =
-        ptrType.getPointeeType().dyn_cast<RankedTensorType>();
+        dyn_cast<RankedTensorType>(ptrType.getPointeeType());
     if (pointeeTensorType == nullptr)
       return ptrType;
 
@@ -99,9 +99,9 @@ TritonCPUConversionTarget::TritonCPUConversionTarget(
   // We have requirements for the data layouts
   addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
     Attribute aEncoding =
-        dotOp.getA().getType().cast<RankedTensorType>().getEncoding();
+        cast<RankedTensorType>(dotOp.getA().getType()).getEncoding();
     Attribute bEncoding =
-        dotOp.getB().getType().cast<RankedTensorType>().getEncoding();
+        cast<RankedTensorType>(dotOp.getB().getType()).getEncoding();
     // TODO:
     return false;
   });
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index e1c802f70925..263315402a65 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -633,7 +633,7 @@ def run(self, *args, grid, warmup, **kwargs):
 
             # The CPU launcher will provide the grid ids directly to the kernel.
             # Note that this design is interim and subject to change.
-            if target[0] == 'cpu':
+            if target.backend == 'cpu':
                 signature["__grid0"] = 'i32'
                 signature["__grid1"] = 'i32'
                 signature["__grid2"] = 'i32'
diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py
index 84564cabef0c..3c293cdf468f 100644
--- a/third_party/cpu/backend/compiler.py
+++ b/third_party/cpu/backend/compiler.py
@@ -7,7 +7,7 @@
 from typing import Any
 
 from triton._C.libtriton import cpu, ir, llvm, passes
-from triton.backends.compiler import BaseBackend
+from triton.backends.compiler import BaseBackend, GPUTarget
 
 
 @dataclass(frozen=True)
@@ -35,8 +35,8 @@ def hash(self):
 class CPUBackend(BaseBackend):
 
     @staticmethod
-    def supports_target(target: tuple):
-        return target[0] == "cpu"
+    def supports_target(target: GPUTarget):
+        return target.backend == "cpu"
 
     def __init__(self, target: tuple) -> None:
         super().__init__(target)
diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py
index a6cf99f742b2..3f3816a99b9f 100644
--- a/third_party/cpu/backend/driver.py
+++ b/third_party/cpu/backend/driver.py
@@ -1,3 +1,4 @@
+from triton.backends.compiler import GPUTarget
 from triton.backends.driver import CPUDriverBase
 
 # ------------------------
@@ -60,7 +61,8 @@ def __init__(self):
 
     def get_current_target(self):
         # Capability and warp size are zeros for CPU.
-        return ("cpu", 0, 0)
+        # TODO: GPUTarget naming isn't obviously good.
+        return GPUTarget("cpu", 0, 0)
 
     @staticmethod
     def is_active():