Quick patches to make it work after rebasing (triton-lang#3)
minjang authored and Devjiu committed Aug 13, 2024
1 parent 100c775 commit dfe2ad7
Showing 6 changed files with 13 additions and 11 deletions.
lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp (2 additions & 2 deletions)
@@ -39,7 +39,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {

     for (size_t i = 0; i < op.getNumOperands(); i++) {
       auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter);
-      if (op.getOperand(i).getType().dyn_cast<RankedTensorType>()) {
+      if (dyn_cast<RankedTensorType>(op.getOperand(i).getType())) {
         llvm_unreachable("Not implemented for tensor types");
       }

@@ -61,7 +61,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
   std::string getFormatSubstr(Value value, bool hex = false,
                               std::optional<int> width = std::nullopt) const {
     Type type = value.getType();
-    if (type.isa<LLVM::LLVMPointerType>()) {
+    if (isa<LLVM::LLVMPointerType>(type)) {
       return "%p";
     }
     // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the
lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp (1 addition & 1 deletion)
@@ -24,7 +24,7 @@ Type TritonCPUToLLVMTypeConverter::convertTritonPointerType(
     triton::PointerType type) {
   auto ctx = type.getContext();
   auto pointeeType = type.getPointeeType();
-  if (pointeeType.isa<RankedTensorType>()) {
+  if (isa<RankedTensorType>(pointeeType)) {
     llvm_unreachable("Not implemented");
   }
   return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace());
lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp (3 additions & 3 deletions)
@@ -26,7 +26,7 @@ TritonCPUTypeConverter::TritonCPUTypeConverter(MLIRContext *context)
   addConversion([this](triton::PointerType ptrType) -> triton::PointerType {
     // Check whether tensor pointer `tt.ptr<tensor<>>`
     auto pointeeTensorType =
-        ptrType.getPointeeType().dyn_cast<RankedTensorType>();
+        dyn_cast<RankedTensorType>(ptrType.getPointeeType());
     if (pointeeTensorType == nullptr)
       return ptrType;

@@ -99,9 +99,9 @@ TritonCPUConversionTarget::TritonCPUConversionTarget(
   // We have requirements for the data layouts
   addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
     Attribute aEncoding =
-        dotOp.getA().getType().cast<RankedTensorType>().getEncoding();
+        cast<RankedTensorType>(dotOp.getA().getType()).getEncoding();
     Attribute bEncoding =
-        dotOp.getB().getType().cast<RankedTensorType>().getEncoding();
+        cast<RankedTensorType>(dotOp.getB().getType()).getEncoding();
     // TODO:
     return false;
   });
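The three C++ files above are one mechanical change: MLIR deprecated the member-function casts (`value.isa<T>()`, `value.cast<T>()`, `value.dyn_cast<T>()`) in favor of the free-function forms (`isa<T>(value)`, `cast<T>(value)`, `dyn_cast<T>(value)`), and the rebase picked up a toolchain where the old spelling no longer builds cleanly.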
python/triton/runtime/jit.py (1 addition & 1 deletion)
@@ -633,7 +633,7 @@ def run(self, *args, grid, warmup, **kwargs):

         # The CPU launcher will provide the grid ids directly to the kernel.
         # Note that this design is interim and subject to change.
-        if target[0] == 'cpu':
+        if target.backend == 'cpu':
            signature["__grid0"] = 'i32'
            signature["__grid1"] = 'i32'
            signature["__grid2"] = 'i32'
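Same root cause as the backend changes below: `get_current_target()` used to return the plain tuple `("cpu", 0, 0)`, so the backend name was `target[0]`; it now returns a `GPUTarget`, so the name is read as `target.backend`.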
third_party/cpu/backend/compiler.py (3 additions & 3 deletions)
@@ -7,7 +7,7 @@
 from typing import Any

 from triton._C.libtriton import cpu, ir, llvm, passes
-from triton.backends.compiler import BaseBackend
+from triton.backends.compiler import BaseBackend, GPUTarget


 @dataclass(frozen=True)
@@ -35,8 +35,8 @@ def hash(self):
 class CPUBackend(BaseBackend):

     @staticmethod
-    def supports_target(target: tuple):
-        return target[0] == "cpu"
+    def supports_target(target: GPUTarget):
+        return target.backend == "cpu"

     def __init__(self, target: tuple) -> None:
         super().__init__(target)
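For reference, `GPUTarget` is the small frozen dataclass in `triton.backends.compiler` that replaced the ad-hoc `(backend, arch, warp_size)` tuple. A minimal sketch of its shape, assuming the upstream Triton 3.x definition (field names beyond `backend` are recalled from upstream and should be treated as an assumption):

    from dataclasses import dataclass
    from typing import Union

    @dataclass(frozen=True)
    class GPUTarget:
        # Backend name, e.g. "cuda", "hip", or (as in this commit) "cpu".
        backend: str
        # Architecture: an int CUDA capability (e.g. 90) or an AMD arch
        # string (e.g. "gfx942"); zero here because it is meaningless on CPU.
        arch: Union[int, str]
        # Warp/wavefront width; also zero for the CPU backend.
        warp_size: int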
third_party/cpu/backend/driver.py (3 additions & 1 deletion)
@@ -1,3 +1,4 @@
+from triton.backends.compiler import GPUTarget
 from triton.backends.driver import CPUDriverBase

 # ------------------------
@@ -60,7 +61,8 @@ def __init__(self):

     def get_current_target(self):
         # Capability and warp size are zeros for CPU.
-        return ("cpu", 0, 0)
+        # TODO: GPUTarget naming isn't obviously good.
+        return GPUTarget("cpu", 0, 0)

     @staticmethod
     def is_active():
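Taken together, the driver now hands the rest of the stack a structured target rather than a bare tuple. A quick usage sketch of what the new checks in `jit.py` and `compiler.py` observe:

    from triton.backends.compiler import GPUTarget

    target = GPUTarget("cpu", 0, 0)  # backend, arch, warp_size

    # Old code indexed the tuple:  target[0] == "cpu"
    # New code reads the field:
    assert target.backend == "cpu"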
