diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index a4939a739528..c4393bc17b80 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -20,6 +20,7 @@ #include "cpu/include/TritonCPUToLLVM/Passes.h" #include "cpu/include/TritonCPUTransforms/Passes.h" #include "cpu/include/TritonToTritonCPU/Passes.h" +#include "cpu/include/Xsmm/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" @@ -74,6 +75,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::cpu::registerTritonCPUTransformsPasses(); mlir::triton::cpu::registerTritonCPUToLLVMPasses(); mlir::triton::cpu::registerTritonOpScalarizeExternalModels(registry); + mlir::triton::cpu::registerTritonCPUXsmmPasses(); // TODO: register Triton & TritonGPU passes registry.insert + $ +) +add_definitions(-DLIBXSMM_DEFAULT_CONFIG -U_DEBUG -D__BLAS=0) + +set_property(TARGET xsmm PROPERTY POSITION_INDEPENDENT_CODE ON) # -fPIC +set_property(TARGET xsmm PROPERTY COMPILE_WARNING_AS_ERROR ON) + +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) +target_link_libraries(xsmm PUBLIC Threads::Threads) +target_link_libraries(xsmm PUBLIC ${CMAKE_DL_LIBS}) + +include(CheckLibraryExists) +check_library_exists(m sqrt "" XSMM_LIBM) +if(XSMM_LIBM) + target_link_libraries(xsmm PUBLIC m) +endif() +check_library_exists(rt sched_yield "" XSMM_LIBRT) +if(XSMM_LIBRT) + target_link_libraries(xsmm PUBLIC rt) +endif() +#target_link_libraries(xsmm PUBLIC c) diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 59d0f5c53d46..d58dfd545a02 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,15 +1,28 @@ +# libxsmm +include(xsmm) +message (STATUS "LIBXSMM Include dir: ${XSMM_INCLUDE_DIRS}") + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) - add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) - target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation) + add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms TritonCPUXsmm) + target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation MLIRMemRefTransforms) endif() add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport) +add_library(TritonCPUXsmmRuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/lib/Xsmm/runtime/XsmmRunnerUtils.cpp) +target_link_libraries(TritonCPUXsmmRuntime PRIVATE xsmm) +set_property(TARGET TritonCPUXsmmRuntime PROPERTY CXX_STANDARD 11) +target_compile_definitions(TritonCPUXsmmRuntime PRIVATE mlir_c_runner_utils_EXPORTS) +target_include_directories(TritonCPUXsmmRuntime + PUBLIC + $ +) + # Build and link sleef set(SLEEF_BUILD_SHARED_LIBS ON CACHE BOOL "Build sleef shared lib" FORCE) set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 6150defb6128..975001dfccd3 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -41,6 +41,7 @@ class 
CPUOptions: enable_fp_fusion: bool = True max_num_imprecise_acc_default: int = 0 enable_fast_math: bool = True + enable_xsmm: bool = False vec_lib: Optional[str] = 'libsleef' # TODO: Try to enable it. sanitize_overflow: bool = False @@ -96,6 +97,8 @@ def parse_options(self, opts) -> Any: if "supported_fp8_dtypes" not in args: supported_fp8_dtypes = set(CPUOptions.supported_fp8_dtypes) args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) + if "enable_xsmm" not in args: + args["enable_xsmm"] = os.getenv("TRITON_CPU_XSMM", "0") != "0" return CPUOptions(**args) def pack_metadata(self, metadata): @@ -194,7 +197,10 @@ def make_llir(self, src, metadata, options): # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + if options.enable_xsmm: + cpu.passes.ttcpuir.add_convert_vector_to_xsmm(pm) cpu.passes.ttcpuir.add_lower_vector_multi_dim(pm) + cpu.passes.ttcpuir.add_expand_strided_metadata(pm) cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False) cpu.passes.ttcpuir.add_lower_affine(pm) passes.convert.add_scf_to_cf(pm) @@ -259,6 +265,8 @@ def make_so(src, metadata, options): Path(asm_path).write_text(src) lib_dirs = cpu_driver.library_dirs libs = ["gcc", "m", "TritonCPURuntime", "sleef"] + if options.enable_xsmm: + libs.extend(["xsmm", "TritonCPUXsmmRuntime"]) so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs) with open(so, "rb") as f: return f.read() diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt index 30282a4736c1..fee03f896707 100644 --- a/third_party/cpu/include/CMakeLists.txt +++ b/third_party/cpu/include/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(ScalarizePass) add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonCPUTransforms) add_subdirectory(TritonToTritonCPU) +add_subdirectory(Xsmm) diff --git a/third_party/cpu/include/Xsmm/CMakeLists.txt b/third_party/cpu/include/Xsmm/CMakeLists.txt new file mode 100644 index 000000000000..ede68918f9a6 --- /dev/null +++ b/third_party/cpu/include/Xsmm/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUXsmm) +add_public_tablegen_target(TritonCPUXsmmPassIncGen) + +set(LLVM_TARGET_DEFINITIONS XsmmEnum.td) +mlir_tablegen(XsmmEnum.h.inc -gen-enum-decls) +mlir_tablegen(XsmmEnum.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonCPUXsmmAttrDefIncGen) diff --git a/third_party/cpu/include/Xsmm/Passes.h b/third_party/cpu/include/Xsmm/Passes.h new file mode 100644 index 000000000000..0e7b2d11ce5e --- /dev/null +++ b/third_party/cpu/include/Xsmm/Passes.h @@ -0,0 +1,67 @@ +#ifndef TritonCPUXsmm_CONVERSION_PASSES_H +#define TritonCPUXsmm_CONVERSION_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class ModuleOp; + +namespace affine { +class AffineDialect; +} // namespace affine + +namespace arith { +class ArithDialect; +} // namespace arith + +namespace func { +class FuncOp; +class FuncDialect; +} // namespace func + +namespace linalg { +class LinalgDialect; +} // namespace linalg + +namespace LLVM { +class LLVMDialect; +} // namespace LLVM + +namespace math { +class MathDialect; +} // namespace math + +namespace memref { +class MemRefDialect; +} // namespace memref + +namespace scf { +class SCFDialect; +} // namespace scf + +namespace tensor { +class TensorDialect; +} // namespace tensor + +namespace vector { +class VectorDialect; +} // namespace vector + +} // namespace mlir + +namespace mlir { +namespace triton { +namespace cpu { + 
+#define GEN_PASS_DECL +#include "cpu/include/Xsmm/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "cpu/include/Xsmm/Passes.h.inc" + +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/Xsmm/Passes.td b/third_party/cpu/include/Xsmm/Passes.td new file mode 100644 index 000000000000..b6b84d25dfde --- /dev/null +++ b/third_party/cpu/include/Xsmm/Passes.td @@ -0,0 +1,17 @@ +#ifndef TRITONCPU_XSMM_PASSES +#define TRITONCPU_XSMM_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertVectorToXsmm : Pass<"triton-cpu-convert-vector-to-xsmm", "mlir::ModuleOp"> { + let summary = "Convert vector to xsmm"; + let description = [{ + Convert vector operations to XSMM operations. + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "vector::VectorDialect", + "LLVM::LLVMDialect"]; +} + +#endif diff --git a/third_party/cpu/include/Xsmm/XsmmEnum.h b/third_party/cpu/include/Xsmm/XsmmEnum.h new file mode 100644 index 000000000000..19bfad8b16ba --- /dev/null +++ b/third_party/cpu/include/Xsmm/XsmmEnum.h @@ -0,0 +1,18 @@ +//===- XsmmEnum.h -----------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_DIALECT_XSMM_XSMMENUM_H +#define TPP_DIALECT_XSMM_XSMMENUM_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/DialectImplementation.h" + +#define GET_ATTRDEF_CLASSES +#include "cpu/include/Xsmm/XsmmEnum.h.inc" + +#endif // TPP_DIALECT_XSMM_XSMMENUM_H diff --git a/third_party/cpu/include/Xsmm/XsmmEnum.td b/third_party/cpu/include/Xsmm/XsmmEnum.td new file mode 100644 index 000000000000..17da6cfbb8ef --- /dev/null +++ b/third_party/cpu/include/Xsmm/XsmmEnum.td @@ -0,0 +1,83 @@ +//===- XsmmEnum --------------------------------------------*- Tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" + +def Xsmm_DataType: I64EnumAttr< + "DataType", "see: libxsmm_datatype", + [ + I64EnumAttrCase<"F32", 1, "f32">, + I64EnumAttrCase<"BF16", 2, "bf16"> + ]>{ + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_BinaryKind : I64EnumAttr< + "BinaryKind", "see: libxsmm_meltw_binary_type", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"ADD", 1, "add">, + I64EnumAttrCase<"MUL", 2, "mul">, + I64EnumAttrCase<"SUB", 3, "sub">, + I64EnumAttrCase<"DIV", 4, "div"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_UnaryKind : I64EnumAttr< + "UnaryKind", "see: libxsmm_meltw_unary_type", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"IDENTITY", 1, "identity">, + I64EnumAttrCase<"ZERO", 2, "zero">, + I64EnumAttrCase<"RELU", 5, "relu">, + I64EnumAttrCase<"VNNI2", 28, "vnni_2">, + I64EnumAttrCase<"TRANSPOSE", 29, "transpose"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_UnaryFlags : I64EnumAttr< + "UnaryFlags", "see: libxsmm_meltw_unary_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BCAST_ROW", 2, "bcast_row">, + I64EnumAttrCase<"BCAST_COL", 4, "bcast_col">, + I64EnumAttrCase<"BCAST_SCALAR", 8, "bcast_scalar"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_BinaryFlags : I64EnumAttr< + "BinaryFlags", "see: libxsmm_meltw_binary_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BCAST_ROW_IN_0", 1, "bcast_row_in0">, + I64EnumAttrCase<"BCAST_ROW_IN_1", 2, "bcast_row_in1">, + I64EnumAttrCase<"BCAST_COL_IN_0", 4, "bcast_col_in0">, + I64EnumAttrCase<"BCAST_COL_IN_1", 8, "bcast_col_in1">, + I64EnumAttrCase<"BCAST_SCALAR_IN_0", 16, "bcast_scalar_in0">, + I64EnumAttrCase<"BCAST_SCALAR_IN_1", 32, "bcast_scalar_in1"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_GemmFlags : I64EnumAttr< + "GemmFlags", "see: libxsmm_gemm_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BETA_0", 4, "beta_0">, + I64EnumAttrCase<"VNNI_A", 2048, "vnni_a">, + I64EnumAttrCase<"VNNI_B", 4096, "vnni_b">, + I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">, + I64EnumAttrCase<"NO_RESET_TILECONFIG", 64, "no_reset_tileconfig">, + I64EnumAttrCase<"NO_SETUP_TILECONFIG", 128, "no_setup_tileconfig"> + ]> { + let cppNamespace = "mlir::xsmm"; +} diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt index fad51ab86ea9..df45d15fdac9 100644 --- a/third_party/cpu/lib/CMakeLists.txt +++ b/third_party/cpu/lib/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(Analysis) add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonCPUTransforms) add_subdirectory(TritonToTritonCPU) +add_subdirectory(Xsmm) diff --git a/third_party/cpu/lib/Xsmm/CMakeLists.txt b/third_party/cpu/lib/Xsmm/CMakeLists.txt new file mode 100644 index 000000000000..1f54d2f21d52 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/CMakeLists.txt @@ -0,0 +1,28 @@ +add_triton_library(TritonCPUXsmm + ConvertVectorToXsmm.cpp + VnniUtils.cpp + ValueUtils.cpp + XsmmEnum.cpp + XsmmUtils.cpp + + DEPENDS + TritonCPUXsmmPassIncGen + TritonCPUXsmmAttrDefIncGen + xsmm + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRVectorDialect + MLIRMemRefDialect + MLIRFuncDialect + MLIRLLVMDialect + MLIRInferTypeOpInterface + MLIRLinalgUtils + xsmm +) + +target_include_directories(TritonCPUXsmm + PUBLIC + $ +) diff --git a/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp 
b/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp
new file mode 100644
index 000000000000..729b66d8dc9f
--- /dev/null
+++ b/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp
@@ -0,0 +1,139 @@
+//===- ConvertVectorToXsmm.cpp ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "cpu/include/Xsmm/Passes.h"
+
+#include "ValueUtils.h"
+#include "VnniUtils.h"
+#include "XsmmUtils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#include
+#include
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::func;
+using namespace mlir::triton;
+using namespace mlir::triton::cpu;
+
+namespace mlir {
+namespace triton {
+namespace cpu {
+#define GEN_PASS_DEF_CONVERTVECTORTOXSMM
+#include "cpu/include/Xsmm/Passes.h.inc"
+} // namespace cpu
+} // namespace triton
+} // namespace mlir
+
+namespace {
+
+static Value getMemrefSource(PatternRewriter &rewriter, Operation *op,
+                             TypedValue<VectorType> operand) {
+  Location loc = op->getLoc();
+  MLIRContext *ctx = op->getContext();
+  OpBuilder::InsertionGuard g(rewriter);
+
+  if (auto readOp =
+          dyn_cast_or_null<vector::TransferReadOp>(operand.getDefiningOp())) {
+    VectorType vecTy = readOp.getVectorType();
+    SmallVector<int64_t> strides(vecTy.getRank(), 1);
+    return rewriter.create<memref::SubViewOp>(
+        loc, readOp.getSource(), getAsOpFoldResult(readOp.getIndices()),
+        getAsIndexOpFoldResult(ctx, vecTy.getShape()),
+        getAsIndexOpFoldResult(ctx, strides));
+  }
+
+  rewriter.setInsertionPoint(op);
+
+  auto vecTy = dyn_cast<VectorType>(operand.getType());
+  assert(vecTy && "Expect vector type operand");
+  MemRefType memTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType());
+  auto alloca = rewriter.create<memref::AllocaOp>(loc, memTy);
+  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> indices(memTy.getRank(), zeroIdx);
+  auto write =
+      rewriter.create<vector::TransferWriteOp>(loc, operand, alloca, indices);
+
+  return alloca;
+}
+
+struct ContractToXsmm : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = contractOp.getLoc();
+
+    TypedValue<VectorType> lhs = contractOp.getLhs();
+    TypedValue<VectorType> rhs = contractOp.getRhs();
+    TypedValue<VectorType> acc = contractOp.getAcc();
+
+    auto vecTy = dyn_cast<VectorType>(acc.getType());
+    if (!vecTy)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "expects to accumulate on vector");
+
+    SmallVector<Attribute> flags;
+    Value lhsBuf = getMemrefSource(rewriter, contractOp, lhs);
+    Value rhsBuf = getMemrefSource(rewriter, contractOp, rhs);
+    Value accBuf = getMemrefSource(rewriter, contractOp, acc);
+    SmallVector<Value> inputs{lhsBuf, rhsBuf, accBuf};
+    SmallVector<Value> outputs{nullptr};
+    auto brgemmInfo =
+        xsmm::utils::isMappableToBrgemm(rewriter, contractOp, inputs, outputs,
+                                        contractOp.getIndexingMapsArray());
+    if (failed(brgemmInfo))
+      return rewriter.notifyMatchFailure(contractOp, "not mappable to XSMM");
+    if (brgemmInfo->isVnni)
+      return rewriter.notifyMatchFailure(contractOp, "VNNI support NYI");
+
+    auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(
+        rewriter, contractOp, ValueRange{lhsBuf, rhsBuf, accBuf}, *brgemmInfo,
+        flags);
+
+    Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> indices(dyn_cast<MemRefType>(accBuf.getType()).getRank(),
+                               zeroIdx);
+    auto readOp =
+        rewriter.create<vector::TransferReadOp>(loc, vecTy, accBuf, indices);
+
+    rewriter.replaceOp(contractOp, readOp);
+
+    return success();
+  }
+};
+
+struct ConvertVectorToXsmm
+    : public triton::cpu::impl::ConvertVectorToXsmmBase<ConvertVectorToXsmm> {
+  using ConvertVectorToXsmmBase::ConvertVectorToXsmmBase;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<ContractToXsmm>(context);
+    if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                                  std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/third_party/cpu/lib/Xsmm/ValueUtils.cpp b/third_party/cpu/lib/Xsmm/ValueUtils.cpp
new file mode 100644
index 000000000000..566665dbc7f0
--- /dev/null
+++ b/third_party/cpu/lib/Xsmm/ValueUtils.cpp
@@ -0,0 +1,146 @@
+//===- ValueUtils.cpp --------------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "ValueUtils.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace utils {
+
+// Returns true if the value is a constant float or integer zero.
+bool isValConstZero(Value val) {
+  return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero());
+}
+
+// Returns true if the attribute represents "all zeros".
+static bool isZeroAttr(Attribute attribute) {
+  return TypeSwitch<Attribute, bool>(attribute)
+      .Case<FloatAttr>([](auto attr) { return attr.getValueAsDouble() == 0.0; })
+      .Case<IntegerAttr>([](auto attr) { return attr.getInt() == 0; })
+      .Case<DenseElementsAttr>([](auto attr) {
+        if (!attr.getElementType().isIntOrFloat())
+          return false;
+        if (!attr.isSplat())
+          return false;
+        auto splat = attr.template getSplatValue<Attribute>();
+        return isZeroAttr(splat);
+      })
+      .Default([](auto attr) { return false; });
+}
+
+// Prototypes
+bool isZeroOp(Operation *);
+
+// Returns true if the value represents a zero filled tensor.
+// Recurse into isZeroOp for defining ops if not immediately obvious +// Looks past linalg generic's argument (which don't have defining ops) +bool isZeroTensor(Value val) { + if (!val) + return false; + if (isValConstZero(val)) + return true; + + Operation *defOp = nullptr; + + // Block arguments don't have a defining op, but they do have an op arg + if (auto arg = dyn_cast(val)) { + // We need to find the argument to the linalg on the same order as this one + auto *linalgOp = arg.getParentRegion()->getParentOp(); + if (!isa(linalgOp)) + return false; + auto index = arg.getArgNumber(); + auto linalgArg = linalgOp->getOperand(index); + defOp = linalgArg.getDefiningOp(); + } else { + defOp = val.getDefiningOp(); + } + return isZeroOp(defOp); +} + +// Returns true if the operation represents a zero filled tensor +// Recurses into isZeroTensor for operands and isZeroAttr for attributes +bool isZeroOp(Operation *defOp) { + if (!defOp) + return false; + + return TypeSwitch(defOp) + .Case([&](auto op) { + // Dense attributes don't match APFloat.isZero() + auto attr = op.getValue(); + return isZeroAttr(attr); + }) + .Case([&](auto op) { + if (op.getInputs().size() != 1) + return false; + return isZeroTensor(op.getInputs()[0]); + }) + .Case( + [&](auto op) { return isZeroTensor(op.getSource()); }) + .Case([&](auto op) { + auto name = op.getName(); + auto module = defOp->getParentOfType(); + auto global = module.lookupSymbol(name); + auto attr = global.getInitialValueAttr(); + return isZeroAttr(attr); + }) + .Default([&](Operation *op) { return false; }); +} + +FailureOr> getStaticStrides(Value value) { + auto valueType = value.getType(); + if (!isa(valueType)) + return failure(); + auto memrefType = cast(valueType); + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + return failure(); + } + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return failure(); + } + return strides; +} + +std::pair getPtrAndOffset(OpBuilder &builder, Value operand, + Location loc) { + auto memrefType = dyn_cast(operand.getType()); + assert(memrefType && "Expect a memref value"); + MemRefType baseMemrefType = MemRefType::get({}, memrefType.getElementType()); + Type basePtrType = builder.getIndexType(); + Type offsetType = builder.getIndexType(); + SmallVector sizesTypes(memrefType.getRank(), offsetType); + SmallVector stridesTypes(memrefType.getRank(), offsetType); + auto meta = builder.create( + loc, baseMemrefType, offsetType, sizesTypes, stridesTypes, operand); + Value alignedPointerAsIndex = + builder.create(loc, basePtrType, + operand); + Value alignedPointerAsI64 = builder.create( + loc, builder.getIntegerType(64), alignedPointerAsIndex); + // TODO: non-POD will require an LLVMTypeConverter. + Value alignedPointer = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), + alignedPointerAsI64); + Value offset = meta.getOffset(); + return std::make_pair(alignedPointer, offset); +} + +} // namespace utils +} // namespace mlir diff --git a/third_party/cpu/lib/Xsmm/ValueUtils.h b/third_party/cpu/lib/Xsmm/ValueUtils.h new file mode 100644 index 000000000000..8cd50146d41c --- /dev/null +++ b/third_party/cpu/lib/Xsmm/ValueUtils.h @@ -0,0 +1,50 @@ +//===- ValueUtils.h - -------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TPP_TRANSFORMS_UTILS_VALUEUTILS_H
+#define TPP_TRANSFORMS_UTILS_VALUEUTILS_H
+
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+#include
+#include
+
+using namespace mlir;
+
+namespace mlir {
+class OpBuilder;
+class Operation;
+class Location;
+namespace utils {
+
+// Returns true if the value is a constant float or integer zero.
+bool isValConstZero(Value val);
+
+// Returns true if the op defining `val` represents a zero filled tensor.
+bool isZeroTensor(Value val);
+
+// Returns true if the operation represents a zero filled tensor.
+bool isZeroOp(Operation *);
+
+// Returns the strides of `val`. The method returns something useful
+// only if the `val` type is a strided memref and the strides are statically
+// known.
+FailureOr<SmallVector<int64_t>> getStaticStrides(Value val);
+
+// Returns the offset and ptr for `val`. Asserts if `val`
+// is not a memref.
+std::pair<Value, Value> getPtrAndOffset(OpBuilder &builder, Value val,
+                                        Location loc);
+
+} // namespace utils
+} // namespace mlir
+
+#endif // TPP_TRANSFORMS_UTILS_VALUEUTILS_H
diff --git a/third_party/cpu/lib/Xsmm/VnniUtils.cpp b/third_party/cpu/lib/Xsmm/VnniUtils.cpp
new file mode 100644
index 000000000000..6df29f993c60
--- /dev/null
+++ b/third_party/cpu/lib/Xsmm/VnniUtils.cpp
@@ -0,0 +1,89 @@
+//===- VnniUtils.cpp ---------------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "VnniUtils.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+
+#include "libxsmm.h"
+
+namespace mlir {
+namespace vnni {
+namespace utils {
+
+std::optional<int64_t> getVnniBlockingFactor(Type type) {
+  auto elementType = getElementTypeOrSelf(type);
+  if (elementType.isBF16())
+    return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16);
+  return std::nullopt;
+}
+
+// Until we have a better way to express the VNNI layout (see: #563), it is up
+// to the caller to specify the expected rank in the VNNI layout, as the rank
+// depends on the operations we are dealing with.
+bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref) {
+  if (memref.getRank() != static_cast<int64_t>(expectedRank) ||
+      !memref.getElementType().isBF16()) {
+    return false;
+  }
+  return memref.getShape().back() == vnni::utils::getVnniBlockingFactor(memref);
+}
+
+bool isInVnniLayout(int64_t expectedRank, VectorType vector) {
+  if (vector.getRank() != expectedRank || !vector.getElementType().isBF16()) {
+    return false;
+  }
+  return vector.getShape().back() == vnni::utils::getVnniBlockingFactor(vector);
+}
+
+// Until we have a better way to express the VNNI layout (see: #563), it is up
+// to the caller to specify the expected rank in the VNNI layout, as the rank
+// depends on the operations we are dealing with.
+bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector) { + return isInVnniLayout((int64_t)expectedRank, vector); +} + +FailureOr isInVnniLayout(linalg::GenericOp linalgOp, + AffineMap map, int64_t blockingFactor) { + ArrayRef results = map.getResults(); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + + AffineExpr vnniDim = results.back(); + auto dimExpr = dyn_cast(vnniDim); + if (!dimExpr || iteratorTypes[dimExpr.getPosition()] != + mlir::utils::IteratorType::reduction) { + return failure(); + } + + for (auto result : results) { + auto blockeDim = dyn_cast(result); + if (!blockeDim) + continue; + if (blockeDim.getKind() != AffineExprKind::FloorDiv) + continue; + auto lhsDim = dyn_cast(blockeDim.getLHS()); + auto rhsCst = dyn_cast(blockeDim.getRHS()); + if (!lhsDim || !rhsCst) + continue; + if (iteratorTypes[lhsDim.getPosition()] != + mlir::utils::IteratorType::reduction) + continue; + if (rhsCst.getValue() != blockingFactor) + continue; + return lhsDim; + } + return failure(); +} + +} // namespace utils +} // namespace vnni +} // namespace mlir diff --git a/third_party/cpu/lib/Xsmm/VnniUtils.h b/third_party/cpu/lib/Xsmm/VnniUtils.h new file mode 100644 index 000000000000..e8517a5d23e1 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/VnniUtils.h @@ -0,0 +1,62 @@ +//===- VnniUtils.h -----------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_TRANSFORMS_UTILS_VNNIUTILS_H +#define TPP_TRANSFORMS_UTILS_VNNIUTILS_H + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Support/LogicalResult.h" + +#include +#include + +using namespace mlir; + +namespace mlir { +class Type; +class MemRefType; +class OpOperand; +class AffineDimExpr; +class AffineMap; + +namespace linalg { +class GenericOp; +} // namespace linalg + +namespace vnni { +namespace utils { + +enum class VnniOperandRank { + TRANSPOSE = 3, + GEMM = 3, + BRGEMM_INS = 4, + BRGEMM_OUTS = 3 +}; + +// Return the VNNI blocking factor: 2 for BF16 and 4 for BF8. +std::optional getVnniBlockingFactor(Type type); + +// Return true if the memref is in VNNI layout with rank `expectedRank`. +bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref); + +bool isInVnniLayout(int64_t expectedRank, VectorType vector); + +// Return true if the memref is in VNNI layout with rank `expectedRank`. +bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector); + +// Return the first AffineDimExpr in the map `affineMap` +// with a VNNI layout pattern (AffineDimExpr floordiv VNNI). +FailureOr isInVnniLayout(linalg::GenericOp linalgOp, + AffineMap affineMap, + int64_t blockingFactor); + +} // namespace utils +} // namespace vnni +} // namespace mlir + +#endif diff --git a/third_party/cpu/lib/Xsmm/XsmmEnum.cpp b/third_party/cpu/lib/Xsmm/XsmmEnum.cpp new file mode 100644 index 000000000000..85766e5272f0 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/XsmmEnum.cpp @@ -0,0 +1,15 @@ +//===- XsmmEnum.cpp - Xsmm dialect enum -------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "cpu/include/Xsmm/XsmmEnum.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::xsmm; + +#include "cpu/include/Xsmm/XsmmEnum.cpp.inc" diff --git a/third_party/cpu/lib/Xsmm/XsmmUtils.cpp b/third_party/cpu/lib/Xsmm/XsmmUtils.cpp new file mode 100644 index 000000000000..c5a60d2627de --- /dev/null +++ b/third_party/cpu/lib/Xsmm/XsmmUtils.cpp @@ -0,0 +1,1084 @@ +//===- XsmmUtils.cpp ---------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "XsmmUtils.h" +#include "ValueUtils.h" +#include "VnniUtils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Compiler.h" + +#include +#include + +#define DEBUG_TYPE "xsmm-utils" + +using namespace mlir; +using namespace mlir::linalg; + +namespace mlir { +namespace xsmm { +namespace utils { + +// Callable object to verify if `operand` has static shape. +struct HasStaticShape { + HasStaticShape() = default; + HasStaticShape(SmallVectorImpl *shape) : shape(shape){}; + + bool operator()(Value operand, Operation *op) const { + auto operandType = operand.getType(); + if (auto shapedType = dyn_cast_or_null(operandType)) { + if (!shapedType.hasStaticShape()) + return false; + if (shape) { + for (int64_t shapeOnDim : shapedType.getShape()) + shape->push_back(shapeOnDim); + } + } + return true; + } + SmallVectorImpl *shape = nullptr; +}; + +// Callable object to verify if `operand` has static strides. +// If `operand` is a tensor type or a scalar, return true. +struct HasStaticStrides { + HasStaticStrides() = default; + HasStaticStrides(SmallVector *strides) : strides(strides){}; + + bool operator()(Value operand, Operation *op) const { + auto operandType = operand.getType(); + SmallVector strides; + if (auto memRefType = dyn_cast_or_null(operandType)) { + int64_t offset; + if (failed(getStridesAndOffset(memRefType, strides, offset))) + return false; + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return false; + } + if (this->strides) + this->strides->append(strides.begin(), strides.end()); + } + return true; + } + SmallVectorImpl *strides = nullptr; +}; + +// Structural matcher. 
+static FailureOr +checkStructure(vector::ContractionOp contractOp, SmallVector &inputs, + SmallVector &outputs, ArrayRef indexingMap) { + if (!HasStaticShape()(inputs[0], inputs[0].getDefiningOp()) || + !HasStaticShape()(inputs[1], inputs[1].getDefiningOp()) || + !HasStaticShape()(inputs[2], inputs[2].getDefiningOp()) || + (outputs[0] != nullptr && + !HasStaticShape()(outputs[0], outputs[0].getDefiningOp())) || + !HasStaticStrides()(inputs[0], inputs[0].getDefiningOp()) || + !HasStaticStrides()(inputs[1], inputs[1].getDefiningOp()) || + !HasStaticStrides()(inputs[2], inputs[2].getDefiningOp()) || + (outputs[0] != nullptr && + !HasStaticStrides()(outputs[0], outputs[0].getDefiningOp()))) { + return failure(); + } + return inferContractionDims(indexingMap); +} + +// Return the position of `dim` in the codomain of `operand`. +std::optional getPosInCodomain(unsigned dim, Value operand, + vector::ContractionOp contractOp, + AffineMap map) { + return map.getResultPosition(getAffineDimExpr(dim, contractOp.getContext())); +} + +static SmallVector +createFlatListOfOperandStaticDims(vector::ContractionOp contractOp) { + SmallVector res; + for (OpOperand &opOperand : contractOp.getOperation()->getOpOperands()) + llvm::append_range( + res, dyn_cast(opOperand.get().getType()).getShape()); + return res; +} + +static SmallVector +computeStaticLoopSizes(vector::ContractionOp contractOp, + ArrayRef maps) { + AffineMap map = concatAffineMaps(maps); + unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); + SmallVector allShapeSizes = + createFlatListOfOperandStaticDims(contractOp); + SmallVector res(numDims, 0); + for (unsigned idx = 0; idx < numRes; ++idx) { + auto result = map.getResult(idx); + if (auto d = dyn_cast(result)) + res[d.getPosition()] = allShapeSizes[idx]; + } + return res; +} + +static FailureOr> +getVNNIStaticStrides(MemRefType valueType) { + SmallVector strides; + int64_t offset; + SmallVector shape; + for (size_t i = 0; i < valueType.getShape().size(); i++) { + shape.push_back(valueType.getShape()[i]); + } + auto temp = shape[shape.size() - 1]; + shape[shape.size() - 1] = shape[shape.size() - 2]; + shape[shape.size() - 2] = temp; + auto memrefType = MemRefType::get(shape, valueType.getElementType()); + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + return failure(); + } + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return failure(); + } + return strides; +} + +// Access matcher. 
+FailureOr +checkAccess(PatternRewriter &rewriter, vector::ContractionOp contractOp, + unsigned m, unsigned n, SmallVector kVector, + std::optional batchPos, SmallVector inputs, + ArrayRef indexingMap) { + Value operandA = inputs[0]; + Value operandB = inputs[1]; + Value operandC = inputs[2]; + + unsigned k; + if (*xsmm::utils::getPosInCodomain( + kVector[0], contractOp->getOpOperand(1).get(), contractOp, + contractOp.getIndexingMapsArray()[1]) < + *xsmm::utils::getPosInCodomain( + n, contractOp->getOpOperand(1).get(), contractOp, + contractOp.getIndexingMapsArray()[1]) || + kVector.size() == 1) { + k = kVector[0]; + } else if (kVector.size() > 1) { + k = kVector[1]; + } + + auto checkStridesAndGetLda = [&](unsigned minorDim, unsigned majorDim, + Value operand, AffineMap map, + int operandIndex) -> FailureOr { + auto minorDimPosInCodomain = + xsmm::utils::getPosInCodomain(minorDim, operand, contractOp, map); + auto majorDimPosInCodomain = + xsmm::utils::getPosInCodomain(majorDim, operand, contractOp, map); + if (!minorDimPosInCodomain || !majorDimPosInCodomain) { + return failure(); + } + auto dataType = xsmm::utils::getDataType(rewriter, operand.getType()); + FailureOr> stridesOnOperand; + if (dataType == + DataTypeAttr::get(contractOp.getContext(), xsmm::DataType::BF16) && + operandIndex == 1) { + stridesOnOperand = + getVNNIStaticStrides(dyn_cast(operand.getType())); + } else { + stridesOnOperand = mlir::utils::getStaticStrides(operand); + } + if (failed(stridesOnOperand) || + (dataType == + DataTypeAttr::get(contractOp.getContext(), xsmm::DataType::BF16) && + operandIndex == 0 && + (*stridesOnOperand)[*minorDimPosInCodomain] != 2) || + ((dataType != DataTypeAttr::get(contractOp.getContext(), + xsmm::DataType::BF16) && + (*stridesOnOperand)[*minorDimPosInCodomain] != 1))) { + return failure(); + } + if (dataType == + DataTypeAttr::get(contractOp.getContext(), xsmm::DataType::BF16) && + operandIndex == 1) { + return (*stridesOnOperand)[*majorDimPosInCodomain + 1]; + } else { + return (*stridesOnOperand)[*majorDimPosInCodomain]; + } + }; + // A(m, k) + auto lda = checkStridesAndGetLda(k, m, operandA, indexingMap[0], 0); + if (failed(lda)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to compute lda\n"); + return failure(); + } + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] Strides on " + "A: OK\n"); + + // B(k, n) + auto ldb = checkStridesAndGetLda(n, k, operandB, indexingMap[1], 1); + if (failed(ldb)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to compute ldb\n"); + + return failure(); + } + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] Strides on " + "B: OK\n"); + + // C(m, n) + auto ldc = checkStridesAndGetLda(n, m, operandC, indexingMap[2], 2); + if (failed(ldc)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to compute ldc\n"); + return failure(); + } + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] Strides on " + "C: OK\n"); + + int64_t strideA = 1; + int64_t strideB = 1; + if (batchPos) { + auto batchPosCodomainA = getPosInCodomain(batchPos.value(), operandA, + contractOp, indexingMap[0]); + auto stridesOnA = ::mlir::utils::getStaticStrides(operandA); + strideA = (*stridesOnA)[*batchPosCodomainA]; + + auto batchPosCodomainB = getPosInCodomain(batchPos.value(), operandB, + contractOp, indexingMap[1]); + auto stridesOnB = ::mlir::utils::getStaticStrides(operandB); + strideB = (*stridesOnB)[*batchPosCodomainB]; + } + + auto loops = computeStaticLoopSizes(contractOp, indexingMap); + int64_t batchVal = (batchPos) ? 
loops[batchPos.value()] : 0; + + auto loopsK = 1; + for (auto kItr : kVector) + loopsK *= loops[kItr]; + + xsmm::BrgemmInfo info{loops[m], loops[n], loopsK, batchVal, *lda, + *ldb, *ldc, strideA, strideB}; + return info; +} + +// Check if the given +// generic is mappable to a +// brgemm xsmm op. +// - It is a contraction, +// with: +// -- 1 m and 1 n and 2 k +// dimensions. +// -- m appears on the LHS +// and OUT but not in RHS. +// -- n appears on the RHS +// and OUT but not in LHS. +// -- k and k' appear on the +// RHS and LHS but not OUT. +// -- the stride of the +// minor dimension for A, k +// is 1. +// -- the stride of the +// minor dimension for B, n +// is 1. +// -- the stride of the +// minor dimension for C, n +// is 1. +FailureOr isMappableToBrgemm(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + SmallVector &inputs, + SmallVector &output, + ArrayRef indexingMap) { + auto contractionDims = + checkStructure(contractOp, inputs, output, indexingMap); + if (failed(contractionDims)) { + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBr" + "gemm] Failed " + "on " + "checkStructure" + "\n"); + return failure(); + } + unsigned m = contractionDims->m.back(); + unsigned n = contractionDims->n.back(); + SmallVector kVector; + std::optional batch; + if (contractionDims->k.size() >= 2) { + for (size_t i = 1; i < contractionDims->k.size(); i++) + kVector.push_back(contractionDims->k[i]); + } else { + for (size_t i = 0; i < contractionDims->k.size(); i++) + kVector.push_back(contractionDims->k[i]); + } + + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] Candidate " + "dims: " + << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] m: " + << m << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] n: " + << n << "\n"); + if (batch) + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBr" + "gemm] batch: " + << batch << "\n"); + else + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBr" + "gemm] no batch " + "dim\n"); + auto retval = checkAccess(rewriter, contractOp, m, n, kVector, batch, inputs, + indexingMap); + if (failed(retval)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to check access\n"); + return failure(); + } + return retval; +} + +DataTypeAttr getDataType(RewriterBase &rewriter, Type type) { + auto elemType = getElementTypeOrSelf(type); + if (elemType.isBF16()) + return DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16); + return DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::F32); +} + +FailureOr getUnaryInfo(Value input, Value output, + UnaryFlags inputFlag) { + Type outputType = output.getType(); + + assert(isa(outputType)); + auto outputShapedType = cast(outputType); + if (outputShapedType.getRank() != 2 || !outputShapedType.hasStaticShape() || + !isa(outputShapedType.getElementType())) { + return failure(); + } + + UnaryInfo unaryInfo; + unaryInfo.m = outputShapedType.getShape()[0]; + unaryInfo.n = outputShapedType.getShape()[1]; + + int64_t ldi = 1; + if (ShapedType inputShapedType = dyn_cast(input.getType())) { + auto stridesOnInput = mlir::utils::getStaticStrides(input); + if (failed(stridesOnInput) || stridesOnInput->back() != 1 || + !inputShapedType.hasStaticShape()) { + return failure(); + } + + // If we are broascasting a row into cols, the leading + // dimension is 1, same for scalar broadcast. + if (inputFlag == UnaryFlags::BCAST_ROW || + inputFlag == UnaryFlags::BCAST_SCALAR) { + ldi = 1; + } + // If we are broascasting a col into rows, the leading + // dimension is the size of the tensor. 
+ else if (inputFlag == UnaryFlags::BCAST_COL) { + ldi = inputShapedType.getShape().back(); + } else { + ldi = stridesOnInput->front(); + } + } + auto stridesOnOutput = mlir::utils::getStaticStrides(output); + if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) + return failure(); + + unaryInfo.ldi = ldi; + unaryInfo.ldo = stridesOnOutput->front(); + return unaryInfo; +} + +FailureOr getBinaryInfo(Value lhs, BinaryFlags lhsFlag, Value rhs, + BinaryFlags rhsFlag, Value output) { + Type outputType = output.getType(); + + assert(isa(outputType)); + auto outputShapedType = cast(outputType); + if (outputShapedType.getRank() != 2 || !outputShapedType.hasStaticShape() || + !isa(outputShapedType.getElementType())) { + return failure(); + } + + BinaryInfo binaryInfo; + binaryInfo.m = outputShapedType.getShape()[0]; + binaryInfo.n = outputShapedType.getShape()[1]; + + int64_t ldiLhs = 1; + if (ShapedType lhsShapedType = dyn_cast(lhs.getType())) { + auto stridesOnLhs = mlir::utils::getStaticStrides(lhs); + if (failed(stridesOnLhs) || stridesOnLhs->back() != 1 || + !lhsShapedType.hasStaticShape()) { + return failure(); + } + + if (lhsFlag == BinaryFlags::BCAST_SCALAR_IN_0 || + lhsFlag == BinaryFlags::BCAST_ROW_IN_0) { + ldiLhs = 1; + } else if (lhsFlag == BinaryFlags::BCAST_COL_IN_0) { + ldiLhs = lhsShapedType.getShape().back(); + } else { + ldiLhs = stridesOnLhs->front(); + } + } + + int64_t ldiRhs = 1; + if (ShapedType rhsShapedType = dyn_cast(rhs.getType())) { + auto stridesOnRhs = mlir::utils::getStaticStrides(rhs); + if (failed(stridesOnRhs) || stridesOnRhs->back() != 1 || + !rhsShapedType.hasStaticShape()) { + return failure(); + } + + if (rhsFlag == BinaryFlags::BCAST_SCALAR_IN_1 || + rhsFlag == BinaryFlags::BCAST_ROW_IN_1) { + ldiRhs = 1; + } else if (rhsFlag == BinaryFlags::BCAST_COL_IN_1) { + ldiRhs = rhsShapedType.getShape().back(); + } else { + ldiRhs = stridesOnRhs->front(); + } + } + + binaryInfo.ldiLhs = ldiLhs; + binaryInfo.ldiRhs = ldiRhs; + + auto stridesOnOutput = mlir::utils::getStaticStrides(output); + if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) + return failure(); + binaryInfo.ldo = stridesOnOutput->front(); + return binaryInfo; +} + +// Examples: +// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. +// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. +// If lower=[a], higher=[a, a], [a] reshaped into [1, a]. +// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. +// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. +static void +computeBcastShapeInput(ArrayRef higherRankShape, + ArrayRef lowerRankShape, + SmallVectorImpl &reshapeOutputShape) { + // Initialize new shapes with [1] * higherRank. 
+ int64_t higherRank = higherRankShape.size(); + int64_t lowerRank = lowerRankShape.size(); + + reshapeOutputShape.assign(higherRank, 1); + + int64_t higherRankDim; + int64_t lowerRankDim; + + for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; + i--, j--) { + higherRankDim = higherRankShape[i]; + lowerRankDim = lowerRankShape[j]; + + if (lowerRankDim == 1 && higherRankDim > 1) + reshapeOutputShape[i] = 1; + else if ((lowerRankDim > 1 && higherRankDim == 1) || + (lowerRankDim == higherRankDim)) + reshapeOutputShape[i] = lowerRankDim; + else if (higherRankDim != lowerRankDim) + assert(false && "bCast semantics for identity op broken"); + } +} + +FailureOr getUnaryFlags(Type inputType, Type outputType) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(inputType) || + cast(inputType).getRank() == 0) { + return xsmm::UnaryFlags::BCAST_SCALAR; + } + + ArrayRef shapeOutput = cast(outputType).getShape(); + ArrayRef shapeInput = cast(inputType).getShape(); + assert(shapeOutput.size() >= shapeInput.size() && + "output rank must be >= input rank"); + SmallVector bShapeInput; + computeBcastShapeInput(shapeOutput, shapeInput, bShapeInput); + assert(shapeOutput.size() == bShapeInput.size()); + shapeInput = bShapeInput; + + // Same shape for input and output, no bcast. + if (shapeInput == shapeOutput) + return xsmm::UnaryFlags::NONE; + + // Input is a memref but it is all ones, bcast = scalar. + auto isOne = [](int64_t val) { return val == 1; }; + if (llvm::all_of(shapeInput, isOne)) + return xsmm::UnaryFlags::BCAST_SCALAR; + + if (shapeInput[1] == 1 && shapeOutput[1] > 1) + return xsmm::UnaryFlags::BCAST_ROW; + + if (shapeInput[0] == 1 && shapeOutput[0] > 1) + return xsmm::UnaryFlags::BCAST_COL; + + return failure(); +} + +FailureOr getBinFlags(ArrayRef shapeOutput, + ArrayRef shapeOperand, + OperandPos operandNumber) { + assert(shapeOutput.size() >= shapeOperand.size() && + "Output rank must be >= operand rank"); + SmallVector bOperandShape; + computeBcastShapeInput(shapeOutput, shapeOperand, bOperandShape); + assert(shapeOutput.size() == bOperandShape.size()); + assert(shapeOutput.size() == 2); + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto getBCastEnum = [](BCastType bCastType, + OperandPos operandPos) -> xsmm::BinaryFlags { + switch (bCastType) { + case BCastType::NONE: + return xsmm::BinaryFlags::NONE; + case BCastType::SCALAR: + if (operandPos == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + else + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + case BCastType::ROW: + if (operandPos == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_ROW_IN_0; + else + return xsmm::BinaryFlags::BCAST_ROW_IN_1; + case BCastType::COL: + if (operandPos == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_COL_IN_0; + else + return xsmm::BinaryFlags::BCAST_COL_IN_1; + } + assert(false && "unrechable"); + abort(); + }; + + if (bOperandShape == shapeOutput) + return getBCastEnum(BCastType::NONE, operandNumber); + + auto isOne = [](int64_t val) { return val == 1; }; + if (llvm::all_of(bOperandShape, isOne)) + return getBCastEnum(BCastType::SCALAR, operandNumber); + + if (bOperandShape[1] == 1 && shapeOutput[1] > 1) + return getBCastEnum(BCastType::ROW, operandNumber); + + if (bOperandShape[0] == 1 && shapeOutput[0] > 1) + return getBCastEnum(BCastType::COL, operandNumber); + + return failure(); +} + +FailureOr getBinaryFlags(Type operandType, Type outputType, + 
OperandPos operandNumber) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(operandType) || + cast(operandType).getRank() == 0) { + if (operandNumber == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + } + + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto shapeOutput = cast(outputType).getShape(); + auto shapeOperand = cast(operandType).getShape(); + return getBinFlags(shapeOutput, shapeOperand, operandNumber); +} + +FailureOr getBinaryFlagsVectorType(Type operandType, + Type outputType, + OperandPos operandNumber) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(operandType) || + cast(operandType).getRank() == 0) { + if (operandNumber == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + } + + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto shapeOutput = cast(outputType).getShape(); + auto shapeOperand = cast(operandType).getShape(); + return getBinFlags(shapeOutput, shapeOperand, operandNumber); +} + +FailureOr getLeadingDim(Type type, size_t pos) { + // Not shaped type, the leading dimension is the single scalar. + auto memref = dyn_cast(type); + if (!memref) + return 1; + // For 1d memref we cannot use the stride as leading dimension, but the + // leading dimension is the dimension itself. + if (memref.getRank() == 1) + return memref.getShape()[0]; + + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(memref, strides, offset))) + return failure(); + // fail if the strides are non-constant + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) + return failure(); + return strides[pos]; +} + +static bool isInnerMostDim(OpOperand *operand, unsigned minorDim, + vector::ContractionOp contractOp, DataTypeAttr dtype, + int operandNumber) { + auto shapedType = cast(operand->get().getType()); + int64_t rank = shapedType.getRank(); + if (dtype == + DataTypeAttr::get(contractOp.getContext(), xsmm::DataType::BF16) && + (operandNumber == 1 || operandNumber == 0)) { + return minorDim == rank - 2; + } + return minorDim == rank - 1; +} + +// Emit a transpose operation for `operand` by swapping `dim` with `newDim`. +// Emit a transpose operation for `operand` by swapping the dimensions at index +// `dim` with `newDim`. 
+static void emitTransposeOnOperand(RewriterBase &rewriter, + vector::ContractionOp contractOp, + Value operand, unsigned dim, unsigned newDim, + int operandNumber) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(contractOp); + + Location loc = contractOp.getLoc(); + auto operandType = cast(operand.getType()); + auto rank = operandType.getRank(); + SmallVector shape = llvm::to_vector(operandType.getShape()); + auto permutation = llvm::to_vector(llvm::seq(0, rank)); + std::swap(permutation[dim], permutation[newDim]); + assert(isPermutationVector(permutation)); + LLVM_DEBUG(llvm::interleaveComma( + permutation, llvm::dbgs() << "[emitTransposeOnOperand] Perm: ")); + LLVM_DEBUG(llvm::dbgs() << "\n"); + + applyPermutationToVector(shape, permutation); + auto vectorType = VectorType::get( + shape, cast(operand.getType()).getElementType()); + Value transposeResult = rewriter.create( + loc, vectorType, operand, permutation); + + SmallVector indexingMaps = contractOp.getIndexingMapsArray(); + AffineMap operandMap = indexingMaps[operandNumber]; + LLVM_DEBUG(llvm::dbgs() << "[emitTransposeOnOperand] Old map: " << operandMap + << "\n"); + SmallVector mapResults = llvm::to_vector(operandMap.getResults()); + applyPermutationToVector(mapResults, permutation); + AffineMap newMap = + AffineMap::get(operandMap.getNumDims(), operandMap.getNumSymbols(), + mapResults, contractOp.getContext()); + LLVM_DEBUG(llvm::dbgs() << "[emitTransposeOnOperand] New map: " << newMap + << "\n"); + indexingMaps[operandNumber] = newMap; + // TODO: We probably cannot update the result in place. + rewriter.modifyOpInPlace(contractOp, [&]() { + contractOp->setOperand(operandNumber, transposeResult); + contractOp.setIndexingMapsAttr( + ArrayAttr::get(contractOp.getContext(), + llvm::to_vector(llvm::map_range( + indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + })))); + }); +} + +FailureOr +makeMinorDimensionsInnerMost(RewriterBase &rewriter, + vector::ContractionOp contractOp, unsigned m, + unsigned n, unsigned k, DataTypeAttr type) { + OpOperand *operandA = &contractOp->getOpOperand(0); + OpOperand *operandB = &contractOp->getOpOperand(1); + OpOperand &operandC = contractOp->getOpOperand(2); + + // C(m,n) += A(m,k) * B(k,n) + // n is expected to be the innermost for C + // k is expected to be the innermost for A + // n is expected to be the innermost for B + auto minorKInCodomainOpA = xsmm::utils::getPosInCodomain( + k, operandA->get(), contractOp, contractOp.getIndexingMapsArray()[0]); + auto minorMInCodomainOpA = xsmm::utils::getPosInCodomain( + m, operandA->get(), contractOp, contractOp.getIndexingMapsArray()[0]); + if (!minorKInCodomainOpA || !minorMInCodomainOpA) { + LLVM_DEBUG( + llvm::dbgs() + << "[makeMinorDimensionsInnerMost] did not find minor dims for A\n"); + return failure(); + } + + auto minorNInCodomainOpB = xsmm::utils::getPosInCodomain( + n, operandB->get(), contractOp, contractOp.getIndexingMapsArray()[1]); + auto minorKInCodomainOpB = xsmm::utils::getPosInCodomain( + k, operandB->get(), contractOp, contractOp.getIndexingMapsArray()[1]); + if (!minorNInCodomainOpB || !minorKInCodomainOpB) { + LLVM_DEBUG( + llvm::dbgs() + << "[makeMinorDimensionsInnerMost] did not find minor dims for B\n"); + return failure(); + } + + auto minorNInCodomainOpC = xsmm::utils::getPosInCodomain( + n, operandC.get(), contractOp, contractOp.getIndexingMapsArray()[2]); + auto minorMInCodomainOpC = xsmm::utils::getPosInCodomain( + m, operandC.get(), contractOp, 
contractOp.getIndexingMapsArray()[2]); + if (!minorNInCodomainOpC || !minorMInCodomainOpC) { + LLVM_DEBUG( + llvm::dbgs() + << "[makeMinorDimensionsInnerMost] did not find minor dims for C\n"); + return failure(); + } + + if (!isInnerMostDim(&operandC, *minorNInCodomainOpC, contractOp, type, 2)) { + LLVM_DEBUG(llvm::dbgs() + << "[makeMinorDimensionsInnerMost] emit transpose for C\n"); + assert( + isInnerMostDim(&operandC, *minorMInCodomainOpC, contractOp, type, 2)); + if (isInnerMostDim(operandA, *minorKInCodomainOpA, contractOp, type, 0)) { + emitTransposeOnOperand(rewriter, contractOp, operandA->get(), + *minorKInCodomainOpA, *minorMInCodomainOpA, 0); + } + if (isInnerMostDim(operandB, *minorNInCodomainOpB, contractOp, type, 1)) { + emitTransposeOnOperand(rewriter, contractOp, operandB->get(), + *minorNInCodomainOpB, *minorKInCodomainOpB, 1); + } + // Avoid transpose on the output by swapping A and B. + OpOperand *operandA = &contractOp->getOpOperand(0); + OpOperand *operandB = &contractOp->getOpOperand(1); + SmallVector indexingMaps = contractOp.getIndexingMapsArray(); + std::swap(indexingMaps[0], indexingMaps[1]); + rewriter.modifyOpInPlace(contractOp, [&]() { + Value operandATmp = operandA->get(); + contractOp->setOperand(0, operandB->get()); + contractOp->setOperand(1, operandATmp); + contractOp.setIndexingMapsAttr( + ArrayAttr::get(contractOp.getContext(), + llvm::to_vector(llvm::map_range( + indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + })))); + }); + return contractOp; + } + + if (!isInnerMostDim(operandA, *minorKInCodomainOpA, contractOp, type, 0)) { + LLVM_DEBUG(llvm::dbgs() + << "[makeMinorDimensionsInnerMost] emit transpose for A\n"); + assert(isInnerMostDim(operandA, *minorMInCodomainOpA, contractOp, type, 0)); + emitTransposeOnOperand(rewriter, contractOp, operandA->get(), + *minorKInCodomainOpA, *minorMInCodomainOpA, 0); + } + if (!isInnerMostDim(operandB, *minorNInCodomainOpB, contractOp, type, 1)) { + LLVM_DEBUG(llvm::dbgs() + << "[makeMinorDimensionsInnerMost] emit transpose for B\n"); + assert(isInnerMostDim(operandB, *minorKInCodomainOpB, contractOp, type, 1)); + emitTransposeOnOperand(rewriter, contractOp, operandB->get(), + *minorKInCodomainOpB, *minorNInCodomainOpB, 1); + } + return contractOp; +} + +bool WithInputs(PatternRewriter &rewriter, Operation *op, + SmallVector> operations, + SmallVector &inputs, SmallVector &opChain) { + for (size_t i = 0; i < operations.size(); i++) { + auto input = op->getOperand(i); + if (!operations[i](input.getDefiningOp())) + return false; + if (input.getDefiningOp()->getOperand(0).getDefiningOp() != nullptr) { + if (!(isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()) || + isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()) || + isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()) || + isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()))) + return false; + } + inputs.push_back(input.getDefiningOp()->getOpOperand(0).get()); + opChain.push_back(input.getDefiningOp()); + } + return true; +} + +bool WithOutput(Operation *op, std::function operation, + SmallVector &output, SmallVector &opChain) { + // Check on the inner chain of operations in the right order. 
+ // Make sure all operands are used and chained + for (auto use : op->getResult(0).getUsers()) { + if (use != op && operation(use)) { + if (!isa(use->getOperand(1).getDefiningOp())) + return false; + output.push_back(use->getOpOperand(1).get()); + opChain.push_back(use); + return true; + } + } + return false; +} + +bool WithOps(Region *region, Operation *op, Operation *currentOp, + SmallVector> operations, + SmallVector &opChain) { + auto &block = region->front(); + + llvm::SmallSetVector chainedValues; + + auto start = block.begin(); + for (auto opItr = block.begin(); opItr != block.end(); opItr++) { + if (&*opItr != currentOp || !operations[0](&*opItr)) + continue; + start = opItr; + opChain.push_back(&*opItr); + break; + } + // Check on the inner chain of operations in the right order. + // Make sure all operands are used and chained + for (auto check : operations) { + Operation *innerOp = &*start; + // Must be right op in right order + if (start == block.end() || !check(innerOp)) + return false; + start++; + // At least one operand must come from args or a previous op + bool consumesValueFromChain = false; + if (chainedValues.empty()) { + consumesValueFromChain = true; + } else { + for (auto operand : innerOp->getOperands()) { + if (chainedValues.contains(operand)) { + chainedValues.remove(operand); + consumesValueFromChain = true; + } + } + } + + // Operation isn't in the chain + if (!consumesValueFromChain) + return false; + + for (auto ret : innerOp->getResults()) { + chainedValues.insert(ret); + } + } + return true; +} + +bool isTwoDTransposeOp(vector::TransposeOp transposeOp) { + if (!(dyn_cast(transposeOp.getOperand().getType()).getRank() == + 2 && + dyn_cast(transposeOp.getResult().getType()).getRank() == + 2) || + !(isa(transposeOp->getParentOp()) && + dyn_cast(transposeOp->getParentOp()).getRank() == 2)) + return false; + return true; +} + +// Extract the operands to be used in the function call. For each memref operand +// extract the aligned pointer and the offset. +SmallVector getOperands(OpBuilder &builder, Location loc, + ValueRange operands, IntegerAttr dataTypeAttr) { + SmallVector res; + IntegerType integer64 = IntegerType::get(builder.getContext(), 64); + res.push_back( + builder.create(loc, integer64, dataTypeAttr)); + + for (Value operand : operands) { + auto memrefType = dyn_cast(operand.getType()); + if (!memrefType) { + res.push_back(operand); + continue; + } + auto [ptr, offset] = ::mlir::utils::getPtrAndOffset(builder, operand, loc); + res.push_back(ptr); + res.push_back(offset); + } + return res; +} + +SmallVector extractInvokeOperandTypes(OpBuilder &builder, + ValueRange operands) { + SmallVector results; + // One extra operand for datatype + IntegerType integer64 = IntegerType::get(builder.getContext(), 64); + results.push_back(integer64); + for (Value operand : operands) { + Type operandType = operand.getType(); + if (auto memrefType = dyn_cast(operandType)) { + // TODO: non-POD will require an LLVMTypeConverter. + Type basePtrType = LLVM::LLVMPointerType::get(builder.getContext()); + results.push_back(basePtrType); + results.push_back(builder.getIndexType()); // offset + } else { + results.push_back(operand.getType()); + } + } + return results; +} + +int64_t getOredFlags(ArrayAttr flags) { + int64_t oredFlag = 0; + for (auto flag : flags) { + int64_t intAttr = dyn_cast(flag).getInt(); + // LIBXSMM is col-major, swap A and B flags. 
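+    // For example, a VNNI_A hint on the row-major contraction must be reported
+    // to LIBXSMM as VNNI_B (and vice versa), because the col-major kernel sees
+    // the two inputs with their roles exchanged.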
+ if (auto gemmFlag = dyn_cast_or_null(flag)) { + if (gemmFlag.getValue() == GemmFlags::VNNI_A) + intAttr = static_cast(GemmFlags::VNNI_B); + if (gemmFlag.getValue() == GemmFlags::VNNI_B) + intAttr = static_cast(GemmFlags::VNNI_A); + } + oredFlag |= intAttr; + } + return oredFlag; +} + +func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc, + ArrayRef dispatchOperands, + ArrayRef dispatchOperandTypes, + ModuleOp module, FlatSymbolRefAttr fnName) { + auto libFnType = rewriter.getFunctionType( + dispatchOperandTypes, IntegerType::get(rewriter.getContext(), 64)); + + if (!module.lookupSymbol(fnName.getAttr())) { + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + func::FuncOp funcOp = + rewriter.create(loc, fnName.getValue(), libFnType); + funcOp.setPrivate(); + } + + func::CallOp call = rewriter.create( + loc, fnName.getValue(), IntegerType::get(rewriter.getContext(), 64), + dispatchOperands); + return call; +} + +func::CallOp buildInvokeCall(RewriterBase &rewriter, Location loc, + ModuleOp module, SmallVector operandRange, + StringRef invokeName, DataTypeAttr dtype) { + auto libFnType = rewriter.getFunctionType( + xsmm::utils::extractInvokeOperandTypes(rewriter, operandRange), {}); + FlatSymbolRefAttr fnName = + SymbolRefAttr::get(rewriter.getContext(), invokeName); + + if (!module.lookupSymbol(fnName)) { + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + func::FuncOp funcOp = + rewriter.create(loc, invokeName, libFnType); + funcOp.setPrivate(); + } + + func::CallOp call = rewriter.create( + loc, fnName, TypeRange(), + xsmm::utils::getOperands(rewriter, loc, operandRange, dtype)); + + return call; +} + +std::pair +buildBrgemmCalls(PatternRewriter &rewriter, Operation *op, ValueRange inputs, + xsmm::BrgemmInfo brgemmInfo, SmallVector flags) { + assert(inputs.size() == 3 && "Expects three inputs for BRGEMM call"); + auto m = brgemmInfo.m; + auto n = brgemmInfo.n; + auto k = brgemmInfo.k; + auto batch = brgemmInfo.batch; + int64_t lda = brgemmInfo.lda; + int64_t ldb = brgemmInfo.ldb; + int64_t ldc = brgemmInfo.ldc; + int64_t strideA = brgemmInfo.strideA; + int64_t strideB = brgemmInfo.strideB; + auto loc = op->getLoc(); + auto dtype = xsmm::utils::getDataType(rewriter, inputs[0].getType()); + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + SmallVector dispatchOperands; + SmallVector dispatchOperandTypes; + // Dispatch the data type. + dispatchOperands.push_back(rewriter.create( + loc, integer64, cast(dtype))); + dispatchOperandTypes.push_back(integer64); + + ArrayAttr brgemmFlags = rewriter.getArrayAttr(flags); + SmallVector invokeOperands; + std::string dispatchName = "xsmm_gemm_dispatch"; + std::string invokeName = "xsmm_gemm_invoke"; + + if (batch != 0) { + dispatchName = "xsmm_brgemm_dispatch"; + invokeName = "xsmm_brgemm_invoke"; + } + + auto dims = SmallVector{m, n, k, lda, ldb, ldc}; + if (batch != 0) { + dims.append({strideA, strideB}); + } + for (size_t idx = 0; idx < dims.size(); idx++) { + dispatchOperands.push_back(rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, dims[idx]))); + dispatchOperandTypes.push_back(integer64); + } + // Dispatch the flags. Pass to the library the already ored-flag to + // avoid changing the interface every time we add a new flag. 
Flags + // are assumed to be verified before (i.e., op verifier). + int64_t oredFlag = xsmm::utils::getOredFlags(brgemmFlags); + + dispatchOperands.push_back(rewriter.create( + loc, integer64, IntegerAttr::get(rewriter.getI64Type(), oredFlag))); + dispatchOperandTypes.push_back(integer64); + ModuleOp module = op->getParentOfType(); + auto dispatched = xsmm::utils::buildDispatchCall( + rewriter, loc, dispatchOperands, dispatchOperandTypes, module, + SymbolRefAttr::get(op->getContext(), dispatchName)); + SmallVector operandRange; + operandRange.push_back(dispatched.getResult(0)); + for (auto operand : inputs) { + operandRange.push_back(operand); + } + if (batch != 0) { + Value batchDim = rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, batch)); + operandRange.push_back(batchDim); + } + auto invokeCall = xsmm::utils::buildInvokeCall( + rewriter, loc, module, operandRange, invokeName, dtype); + return std::make_pair(&*dispatched, &*invokeCall); +} + +} // namespace utils +} // namespace xsmm +} // namespace mlir diff --git a/third_party/cpu/lib/Xsmm/XsmmUtils.h b/third_party/cpu/lib/Xsmm/XsmmUtils.h new file mode 100644 index 000000000000..5d19110f4ae4 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/XsmmUtils.h @@ -0,0 +1,157 @@ +//===- XsmmUtils.h - --------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_DIALECT_XSMM_XSMMUTILS_H +#define TPP_DIALECT_XSMM_XSMMUTILS_H + +#include "cpu/include/Xsmm/XsmmEnum.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include + +using namespace mlir; + +namespace mlir { +class Type; +class RewriterBase; +class ArrayAttr; +class Operation; +class ValueRange; +class Attribute; + +namespace xsmm { + +struct BrgemmInfo { + int64_t m; + int64_t n; + int64_t k; + int64_t batch; + + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t strideA; + int64_t strideB; + + bool isVnni = false; +}; + +template +std::function FuncType = + [](Operation *op) { return isa(op); }; + +class UnaryKindAttr; + +struct UnaryInfo { + unsigned m; + unsigned n; + + int64_t ldi; + int64_t ldo; +}; + +struct BinaryInfo { + unsigned m; + unsigned n; + + int64_t ldiLhs; + int64_t ldiRhs; + int64_t ldo; +}; + +namespace utils { + +DataTypeAttr getDataType(RewriterBase &rewriter, Type type); + +FailureOr getUnaryInfo(Value input, Value output, + UnaryFlags inputFlag); + +FailureOr getBinaryInfo(Value lhs, BinaryFlags lhsFlag, Value rhs, + BinaryFlags rhsFlag, Value output); + +// Compute the broadcasting flags for 'inputType' based 'outputType'. +// Rules for broadcasting follows Numpy-style, and are only allowed in +// 'inputType'. see: https://numpy.org/doc/stable/user/basics.broadcasting.html +FailureOr getUnaryFlags(Type inputType, Type outputType); + +// Compute the broadcasting flags for 'operandType' based on 'outputType'. 
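+// For example, a 1xN or Mx1 operand combined with an MxN output yields a
+// broadcast flag for that operand, while an exact shape match yields no flag.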
+enum class OperandPos { LHS = 0, RHS = 1 }; +FailureOr getBinFlags(ArrayRef shapeOutput, + ArrayRef shapeOperand, + OperandPos operandNumber); +FailureOr getBinaryFlags(Type operandType, Type outputType, + OperandPos operandNumber); + +FailureOr getBinaryFlagsVectorType(Type operandType, + Type outputType, + OperandPos operandNumber); + +FailureOr getLeadingDim(Type type, size_t pos = 0); + +int64_t getOredFlags(ArrayAttr flags); + +SmallVector extractInvokeOperandTypes(OpBuilder &builder, + ValueRange operands); +SmallVector getOperands(OpBuilder &builder, Location loc, + ValueRange operands, IntegerAttr dataTypeAttr); + +FailureOr isMappableToBrgemm(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + SmallVector &inputs, + SmallVector &output, + ArrayRef indexingMap); + +FailureOr +makeMinorDimensionsInnerMost(RewriterBase &rewriter, + vector::ContractionOp contractOp, unsigned m, + unsigned n, unsigned k, IntegerAttr type); +std::optional getPosInCodomain(unsigned dim, Value operand, + vector::ContractionOp contractOp, + AffineMap map); +FailureOr +checkAccess(PatternRewriter &rewriter, vector::ContractionOp contractOp, + unsigned m, unsigned n, SmallVector kVector, + std::optional batchPos, SmallVector inputs, + ArrayRef indexingMap); + +bool WithInputs(PatternRewriter &rewriter, Operation *op, + SmallVector> operations, + SmallVector &inputs, SmallVector &opChain); +bool WithOutput(Operation *op, std::function operation, + SmallVector &output, SmallVector &opChain); +bool WithOps(Region *region, Operation *op, Operation *currentOp, + SmallVector> operations, + SmallVector &opChain); + +bool isTwoDTransposeOp(vector::TransposeOp transposeOp); + +func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc, + ArrayRef dispatchOperands, + ArrayRef dispatchOperandTypes, + ModuleOp module, FlatSymbolRefAttr fnName); +func::CallOp buildInvokeCall(RewriterBase &rewriter, Location loc, + ModuleOp module, SmallVector operands, + StringRef invokeName, DataTypeAttr dtype); + +// Create a pair of XSMM dispatch and invoke (BR)GEMM calls. +std::pair +buildBrgemmCalls(PatternRewriter &rewriter, Operation *op, ValueRange inputs, + xsmm::BrgemmInfo brgemmInfo, SmallVector flags); + +} // namespace utils +} // namespace xsmm +} // namespace mlir + +#endif // TPP_DIALECT_XSMM_XSMMUTILS_H diff --git a/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp new file mode 100644 index 000000000000..62b784f1f154 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp @@ -0,0 +1,516 @@ +//===- CRunnerUtils.cpp - Utils for MLIR execution ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements basic functions to manipulate structured MLIR types at +// runtime. Entities in this file are meant to be retargetable, including on +// targets without a C++ runtime, and must be kept C compatible. +// +//===----------------------------------------------------------------------===// + +#include "XsmmRunnerUtils.h" +#include "libxsmm.h" // NOLINT [build/include_subdir] +#include "libxsmm_utils.h" + +// Helper function prototypes. 
+static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape, + FILE *outfile = stderr); +static void printXsmmStruct(const libxsmm_meltw_unary_shape &unaryShape, + FILE *outfile = stderr); +static void printXsmmStruct(const libxsmm_meltw_binary_shape &binaryShape, + FILE *outfile = stderr); +static void printXsmmStruct(const libxsmm_gemm_batch_reduce_config &brgemmShape, + FILE *outfile = stderr); + +static bool hasImplicitComputeDtypeUnary(const libxsmm_meltw_unary_type dtype) { + switch (dtype) { + // Zero + case LIBXSMM_MELTW_TYPE_UNARY_XOR: + // Copy + case LIBXSMM_MELTW_TYPE_UNARY_IDENTITY: + // Transpose + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT: + // VNNI2 + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI2: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI2_TO_VNNI2T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI2T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI2_PAD: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADM_MOD2: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADN_MOD2: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADNM_MOD2: + // VNNI4 + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI4_TO_VNNI4T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI4T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI4_PAD: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADM_MOD4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADN_MOD4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADNM_MOD4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI4_TO_NORM: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI4_TO_VNNI2: + return true; + default: + return false; + } +} + +namespace { + +void *get_base_ptr(const libxsmm_datatype dType, void *alignedPtr, + int64_t offset) { + if (dType == LIBXSMM_DATATYPE_F32) { + float *base_ptr = (float *)alignedPtr + offset; + return (void *)base_ptr; + } else if (dType == LIBXSMM_DATATYPE_BF16) { + bf16 *base_ptr = (bf16 *)alignedPtr + offset; + return (void *)base_ptr; + } + fprintf(stderr, "Unhandled data type in get_data_pointer_from_memref_desc:%d", + dType); + return nullptr; +} + +} // namespace + +extern "C" void xsmm_gemm_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offsetA, + void *alignedPtrB, int64_t offsetB, + void *alignedPtrC, int64_t offsetC) { + libxsmm_xmmfunction sgemm; + libxsmm_gemm_param gemm_param; + + // LIBXSMM col-major change A with B. + gemm_param.a.primary = get_base_ptr(dType, alignedPtrB, offsetB); + gemm_param.b.primary = get_base_ptr(dType, alignedPtrA, offsetA); + gemm_param.c.primary = get_base_ptr(dType, alignedPtrC, offsetC); + + sgemm.gemm = reinterpret_cast(addr); + sgemm.gemm(&gemm_param); +} + +extern "C" int64_t xsmm_gemm_dispatch(const libxsmm_datatype dtype, int64_t m, + int64_t n, int64_t k, int64_t lda, + int64_t ldb, int64_t ldc, + const libxsmm_gemm_flags flags) { + // std::cout << "lda: " << lda << "\n"; + // std::cout << "ldb: " << ldb << "\n"; + // std::cout << "ldc: " << ldc << "\n"; + + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "k: " << k << "\n"; + + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_flags = flags; + libxsmm_bitfield l_prefetch_flags = 0; + + // See: + // https://stackoverflow.com/questions/56043539/cublassgemm-row-major-multiplication + // LIBXSMM col-major change m with n. 
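+  // Row-major C(MxN) = A(MxK) * B(KxN) is computed as the col-major product
+  // C^T(NxM) = B^T(NxK) * A^T(KxM), so m/n and lda/ldb are exchanged below
+  // while ldc carries over as the leading dimension of C^T.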
+ l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb; + l_shape.ldb = lda; + l_shape.ldc = ldc; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + l_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype; + + auto sgemm = libxsmm_dispatch_gemm(l_shape, l_flags, l_prefetch_flags); + if (!sgemm) { + fprintf(stderr, "failed to generate matmul func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + printXsmmStruct(l_shape); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" int64_t +xsmm_unary_dispatch(const libxsmm_meltw_unary_type op_type, + const libxsmm_datatype dtype, int64_t m, int64_t n, + int64_t ldi, int64_t ldo, + const libxsmm_meltw_unary_flags unary_flags) { + // std::cout << "ldi: " << ldi << "\n"; + // std::cout << "ldo: " << ldo << "\n"; + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "type: " << type << "\n"; + // std::cout << "bcast_type: " << bcast_type << "\n"; + + libxsmm_meltw_unary_shape unary_shape; + // Row major to col major swap m with n. + unary_shape.m = static_cast(n); + unary_shape.n = static_cast(m); + unary_shape.in0_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + // Copy and Zero should remain in BF16 to avoid useless up/down casts + auto force_fp32 = (dtype == LIBXSMM_DATATYPE_BF16 && + !hasImplicitComputeDtypeUnary(op_type)); + unary_shape.comp_type = force_fp32 ? LIBXSMM_DATATYPE_F32 : dtype; + unary_shape.out_type = dtype; + unary_shape.ldi = static_cast(ldi); + unary_shape.ldo = static_cast(ldo); + + libxsmm_meltwfunction_unary kernel = + libxsmm_dispatch_meltw_unary(op_type, unary_shape, unary_flags); + if (!kernel) { + fprintf(stderr, "failed to generate unary func\n"); + fprintf(stderr, "op_type: %u\n", op_type); + fprintf(stderr, "flags: %u\n", unary_flags); + printXsmmStruct(unary_shape); + exit(-1); + } + + return reinterpret_cast(kernel); +} + +extern "C" int64_t +xsmm_binary_dispatch(const libxsmm_meltw_binary_type op_type, + const libxsmm_datatype dtype, int64_t m, int64_t n, + int64_t ldiLhs, int64_t ldiRhs, int64_t ldo, + const libxsmm_meltw_binary_flags flags) { + libxsmm_meltw_binary_shape binary_shape; + // Row major to col major swap m with n. + binary_shape.m = static_cast(n); + binary_shape.n = static_cast(m); + binary_shape.in0_type = dtype; + binary_shape.in1_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + binary_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? 
LIBXSMM_DATATYPE_F32 : dtype; + binary_shape.out_type = dtype; + binary_shape.ldi = static_cast(ldiLhs); + binary_shape.ldi2 = static_cast(ldiRhs); + binary_shape.ldo = static_cast(ldo); + + libxsmm_meltwfunction_binary kernel = + libxsmm_dispatch_meltw_binary(op_type, binary_shape, flags); + if (!kernel) { + fprintf(stderr, "failed to generate binary func\n"); + fprintf(stderr, "op_type: %u\n", op_type); + fprintf(stderr, "flags: %u\n", flags); + printXsmmStruct(binary_shape); + exit(-1); + } + + return reinterpret_cast(kernel); +} + +extern "C" int64_t xsmm_intel_amx_tile_config_dispatch( + const libxsmm_datatype dtype, int64_t m, int64_t n, int64_t k, int64_t lda, + int64_t ldb, int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags flags) { + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_cfg_flags = flags; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb; + l_shape.ldb = lda; + l_shape.ldc = ldc; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = dtype; + l_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype; + + auto sgemm = libxsmm_dispatch_tilecfg_gemm(l_shape, l_cfg_flags); + if (!sgemm) { + fprintf(stderr, "failed to generate tileconfig func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + fprintf(stderr, "flags: %u\n", flags); + printXsmmStruct(l_shape); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" void xsmm_unary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrIn, int64_t offsetIn, + void *alignedPtrOut, int64_t offsetOut) { + libxsmm_meltw_unary_param param; + + param.in.primary = get_base_ptr(dType, alignedPtrIn, offsetIn); + param.out.primary = get_base_ptr(dType, alignedPtrOut, offsetOut); + + libxsmm_meltwfunction_unary kernel = + reinterpret_cast(addr); + kernel(¶m); +} + +extern "C" void xsmm_binary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrLhs, int64_t offsetLhs, + void *alignedPtrRhs, int64_t offsetRhs, + void *alignedPtrOut, int64_t offsetOut) { + libxsmm_meltw_binary_param param; + + param.in0.primary = get_base_ptr(dType, alignedPtrLhs, offsetLhs); + param.in1.primary = get_base_ptr(dType, alignedPtrRhs, offsetRhs); + param.out.primary = get_base_ptr(dType, alignedPtrOut, offsetOut); + + libxsmm_meltwfunction_binary kernel = + reinterpret_cast(addr); + kernel(¶m); +} + +extern "C" void xsmm_unary_scalar_invoke(const libxsmm_datatype dType, + int64_t addr, float input, + void *alignedOut, int64_t offsetOut) { + libxsmm_meltwfunction_unary kernel = + reinterpret_cast(addr); + libxsmm_meltw_unary_param param; + + param.in.primary = (void *)&input; + param.out.primary = get_base_ptr(dType, alignedOut, offsetOut); + kernel(¶m); +} + +extern "C" void xsmm_brgemm_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offsetA, + void *alignedPtrB, int64_t offsetB, + void *alignedPtrC, int64_t offsetC, + int64_t numBatches) { + libxsmm_xmmfunction sgemm; + libxsmm_gemm_param gemm_param; + + unsigned long long numBatchesVar = numBatches; + gemm_param.op.tertiary = (void *)&numBatchesVar; + + // LIBXSMM col-major change A with B. 
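+  // Swapping the A and B pointers here mirrors the lda/ldb and stride-hint
+  // swap done at dispatch time, so no data is physically transposed.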
+ gemm_param.a.primary = get_base_ptr(dType, alignedPtrB, offsetB); + gemm_param.b.primary = get_base_ptr(dType, alignedPtrA, offsetA); + gemm_param.c.primary = get_base_ptr(dType, alignedPtrC, offsetC); + + sgemm.gemm = reinterpret_cast(addr); + sgemm.gemm(&gemm_param); +} + +extern "C" int64_t xsmm_brgemm_dispatch(const libxsmm_datatype dtype, int64_t m, + int64_t n, int64_t k, int64_t lda, + int64_t ldb, int64_t ldc, + int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags flags) { + // std::cout << "lda: " << lda << "\n"; + // std::cout << "lbd: " << ldb << "\n"; + // std::cout << "ldc: " << ldc << "\n"; + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "k: " << k << "\n"; + + libxsmm_blasint lda_int = lda; + libxsmm_blasint ldb_int = ldb; + libxsmm_blasint ldc_int = ldc; + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_flags = flags; + libxsmm_bitfield l_prefetch_flags = 0; + libxsmm_gemm_batch_reduce_config l_brconfig; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb_int; + l_shape.ldb = lda_int; + l_shape.ldc = ldc_int; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + l_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype; + l_brconfig.br_type = LIBXSMM_GEMM_BATCH_REDUCE_STRIDE; + auto typeSize = dtype == LIBXSMM_DATATYPE_F32 ? sizeof(float) : sizeof(bf16); + l_brconfig.br_stride_a_hint = stride_b * typeSize; + l_brconfig.br_stride_b_hint = stride_a * typeSize; + l_brconfig.br_unroll_hint = 0; + + auto sgemm = + libxsmm_dispatch_brgemm(l_shape, l_flags, l_prefetch_flags, l_brconfig); + if (!sgemm) { + fprintf(stderr, "failed to generate brgemm func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + printXsmmStruct(l_shape); + printXsmmStruct(l_brconfig); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" void xsmm_fused_brgemm_invoke(const libxsmm_datatype dType, + int64_t addr, void *alignedPtrA, + int64_t offsetA, void *alignedPtrB, + int64_t offsetB, void *alignedPtrC, + int64_t offsetC, void *alignedPtrD, + int64_t offsetD, int64_t numBatches) { + libxsmm_xmmfunction sgemm; + libxsmm_gemm_ext_param gemm_param; + + unsigned long long numBatchesVar = numBatches; + gemm_param.op.tertiary = (void *)&numBatchesVar; + + // LIBXSMM col-major change A with B. 
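+  // The extra D operand is the input of the fused binary post-op configured
+  // via l_postops at dispatch time (its leading dimension ldd is set to ldc).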
+ gemm_param.a.primary = get_base_ptr(dType, alignedPtrB, offsetB); + gemm_param.b.primary = get_base_ptr(dType, alignedPtrA, offsetA); + gemm_param.c.primary = get_base_ptr(dType, alignedPtrC, offsetC); + gemm_param.d.primary = get_base_ptr(dType, alignedPtrD, offsetD); + + sgemm.gemm_ext = reinterpret_cast(addr); + sgemm.gemm_ext(&gemm_param); +} + +extern "C" int64_t +xsmm_fused_brgemm_dispatch(const libxsmm_datatype data_type, int64_t m, + int64_t n, int64_t k, int64_t lda, int64_t ldb, + int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags gemm_flags, + const libxsmm_meltw_unary_flags unary_flags, + const libxsmm_meltw_unary_type unary_op_type, + const libxsmm_meltw_binary_flags binary_flags, + const libxsmm_meltw_binary_type binary_op_type) { + // std::cout << "lda: " << lda << "\n"; + // std::cout << "lbd: " << ldb << "\n"; + // std::cout << "ldc: " << ldc << "\n"; + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "k: " << k << "\n"; + + libxsmm_blasint lda_int = lda; + libxsmm_blasint ldb_int = ldb; + libxsmm_blasint ldc_int = ldc; + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_flags = gemm_flags; + libxsmm_bitfield l_prefetch_flags = 0; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb_int; + l_shape.ldb = lda_int; + l_shape.ldc = ldc_int; + l_shape.a_in_type = data_type; + l_shape.b_in_type = data_type; + l_shape.out_type = data_type; + // Retarget computation type from bf16 to f32 due to missing hardware support. + l_shape.comp_type = + data_type == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : data_type; + + libxsmm_gemm_batch_reduce_config l_brconfig; + l_brconfig.br_type = LIBXSMM_GEMM_BATCH_REDUCE_STRIDE; + auto typeSize = + data_type == LIBXSMM_DATATYPE_F32 ? 
sizeof(float) : sizeof(bf16); + l_brconfig.br_stride_a_hint = stride_b * typeSize; + l_brconfig.br_stride_b_hint = stride_a * typeSize; + l_brconfig.br_unroll_hint = 0; + + libxsmm_gemm_ext_unary_argops l_argops; + memset(&l_argops, 0, sizeof(libxsmm_gemm_ext_unary_argops)); + l_argops.cp_unary_flags = LIBXSMM_MELTW_FLAG_UNARY_NONE; + l_argops.ldcp = ldc; + l_argops.cp_unary_type = unary_op_type; + + libxsmm_gemm_ext_binary_postops l_postops; + memset(&l_postops, 0, sizeof(libxsmm_gemm_ext_binary_postops)); + l_postops.d_in_type = data_type; + + l_postops.d_binary_flags = binary_flags; + l_postops.d_binary_type = binary_op_type; + l_postops.ldd = ldc; + + auto sgemm = libxsmm_dispatch_brgemm_ext(l_shape, l_flags, l_prefetch_flags, + l_brconfig, l_argops, l_postops); + if (!sgemm) { + fprintf(stderr, "failed to generate fused brgemm func\n"); + fprintf(stderr, "data_type: %u\n", data_type); + printXsmmStruct(l_shape); + printXsmmStruct(l_brconfig); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *tileState, int64_t offset) { + libxsmm_xmmfunction cfg_tr; + + libxsmm_tilecfg_state *l_tilestate = + reinterpret_cast(tileState); + + cfg_tr.tilecfg = reinterpret_cast(addr); + cfg_tr.tilecfg(l_tilestate); +} + +static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape, + FILE *outfile) { + fprintf(outfile, "M: %d\n", gemmShape.m); + fprintf(outfile, "N: %d\n", gemmShape.n); + fprintf(outfile, "K: %d\n", gemmShape.k); + fprintf(outfile, "lda: %d\n", gemmShape.lda); + fprintf(outfile, "ldb: %d\n", gemmShape.ldb); + fprintf(outfile, "ldc: %d\n", gemmShape.ldc); + fprintf(outfile, "a_in_type: %d\n", gemmShape.a_in_type); + fprintf(outfile, "b_in_type: %d\n", gemmShape.b_in_type); + fprintf(outfile, "comp_type: %d\n", gemmShape.comp_type); + fprintf(outfile, "out_type: %d\n", gemmShape.out_type); +} + +static void printXsmmStruct(const libxsmm_meltw_unary_shape &unaryShape, + FILE *outfile) { + fprintf(outfile, "M: %d\n", unaryShape.m); + fprintf(outfile, "N: %d\n", unaryShape.n); + fprintf(outfile, "in0_type: %d\n", unaryShape.in0_type); + fprintf(outfile, "comp_type: %d\n", unaryShape.comp_type); + fprintf(outfile, "out_type: %d\n", unaryShape.out_type); + fprintf(outfile, "ldi: %d\n", unaryShape.ldi); + fprintf(outfile, "ldo: %d\n", unaryShape.ldo); +} + +static void printXsmmStruct(const libxsmm_meltw_binary_shape &binaryShape, + FILE *outfile) { + fprintf(outfile, "M: %d\n", binaryShape.m); + fprintf(outfile, "N: %d\n", binaryShape.n); + fprintf(outfile, "in0_type: %d\n", binaryShape.in0_type); + fprintf(outfile, "in1_type: %d\n", binaryShape.in1_type); + fprintf(outfile, "comp_type: %d\n", binaryShape.comp_type); + fprintf(outfile, "out_type: %d\n", binaryShape.out_type); + fprintf(outfile, "ldi: %d\n", binaryShape.ldi); + fprintf(outfile, "ldi2: %d\n", binaryShape.ldi2); + fprintf(outfile, "ldo: %d\n", binaryShape.ldo); +} + +static void +printXsmmStruct(const libxsmm_gemm_batch_reduce_config &brgemmConfig, + FILE *outfile) { + fprintf(outfile, "br_type: %d\n", brgemmConfig.br_type); + fprintf(outfile, "br_stride_a_hint: %d\n", brgemmConfig.br_stride_a_hint); + fprintf(outfile, "br_stride_b_hint: %d\n", brgemmConfig.br_stride_b_hint); + fprintf(outfile, "br_unroll_hint: %d\n", brgemmConfig.br_unroll_hint); +} diff --git a/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h new file mode 100644 index 
000000000000..3d5048188efc --- /dev/null +++ b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h @@ -0,0 +1,85 @@ +//===- CRunnerUtils.h - Utils for debugging MLIR execution ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares basic classes and functions to manipulate structured MLIR +// types at runtime. Entities in this file must be compliant with C++11 and be +// retargetable, including on targets without a C++ runtime. +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_EXECUTIONENGINE_CRUNNERUTILS_H +#define TPP_EXECUTIONENGINE_CRUNNERUTILS_H + +#include "libxsmm.h" +#include "mlir/ExecutionEngine/Float16bits.h" +#include "mlir/ExecutionEngine/RunnerUtils.h" + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t +xsmm_gemm_dispatch(const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, const libxsmm_gemm_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_unary_dispatch( + const libxsmm_meltw_unary_type, const libxsmm_datatype, int64_t, int64_t, + int64_t, int64_t, const libxsmm_meltw_unary_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_binary_dispatch( + const libxsmm_meltw_binary_type, const libxsmm_datatype, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_meltw_binary_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_brgemm_dispatch( + const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_gemm_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_fused_brgemm_dispatch( + const libxsmm_datatype data_type, int64_t m, int64_t n, int64_t k, + int64_t lda, int64_t ldb, int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags gemm_flags, + const libxsmm_meltw_unary_flags unary_flags, + const libxsmm_meltw_unary_type unary_op_type, + const libxsmm_meltw_binary_flags binary_flags, + const libxsmm_meltw_binary_type binary_op_type); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_intel_amx_tile_config_dispatch( + const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_gemm_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_gemm_invoke(const libxsmm_datatype dType, int64_t addr, void *alignedPtrA, + int64_t offsetA, void *alignedPtrB, int64_t offsetB, + void *alignedPtrC, int64_t offsetC); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_unary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrIn, int64_t offsetIn, void *alignedPtrOut, + int64_t offsetOut); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_unary_scalar_invoke(const libxsmm_datatype, int64_t addr, float scalar, + void *alignedPtrOut, int64_t offsetOut); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_binary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrLhs, int64_t offsetLhs, void *alignedPtrRhs, + int64_t offsetRhs, void *alignedPtrOut, int64_t offsetOut); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_brgemm_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offsetA, void *alignedPtrB, + int64_t offsetB, void *alignedPtrC, int64_t offsetC, + int64_t numBatches); + +extern "C" MLIR_RUNNERUTILS_EXPORT void 
xsmm_fused_brgemm_invoke( + const libxsmm_datatype dType, int64_t addr, void *alignedPtrA, + int64_t offsetA, void *alignedPtrB, int64_t offsetB, void *alignedPtrC, + int64_t offsetC, void *alignedPtrD, int64_t offsetD, int64_t numBatches); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offset); + +#endif // TPP_EXECUTIONENGINE_CRUNNERUTILS_H diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 9db53e1d0c71..a126731408ca 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -2,12 +2,14 @@ #include "TritonCPUToLLVM/Passes.h" #include "TritonCPUTransforms/Passes.h" #include "TritonToTritonCPU/Passes.h" +#include "Xsmm/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/Passes.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" @@ -154,6 +156,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_func_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::createConvertFuncToLLVMPass()); }); + m.def("add_convert_vector_to_xsmm", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertVectorToXsmm()); + }); + m.def("add_expand_strided_metadata", [](mlir::PassManager &pm) { + pm.addPass(mlir::memref::createExpandStridedMetadataPass()); + }); } void init_triton_cpu(py::module &&m) {
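For reference, the runtime contract implemented above can be exercised directly: each *_dispatch call JIT-compiles a kernel and returns its address as an int64_t, and the matching *_invoke call casts that address back and runs the kernel. Below is a minimal host-side sketch for a plain f32 GEMM; the shapes and the LIBXSMM_GEMM_FLAG_NONE value are illustrative assumptions, not taken from this patch.

#include <cstdint>
#include "XsmmRunnerUtils.h" // declares xsmm_gemm_dispatch / xsmm_gemm_invoke

int main() {
  // C(64x32) += A(64x128) * B(128x32), all row-major f32.
  static float A[64 * 128], B[128 * 32], C[64 * 32];
  // Dispatch builds (or caches) the kernel and hands back its address.
  int64_t handle = xsmm_gemm_dispatch(LIBXSMM_DATATYPE_F32,
                                      /*m=*/64, /*n=*/32, /*k=*/128,
                                      /*lda=*/128, /*ldb=*/32, /*ldc=*/32,
                                      LIBXSMM_GEMM_FLAG_NONE);
  // Invoke runs the kernel; the row-major/col-major A-B swap happens inside
  // the runtime wrapper, so the caller passes operands in their natural order.
  xsmm_gemm_invoke(LIBXSMM_DATATYPE_F32, handle,
                   /*alignedPtrA=*/A, /*offsetA=*/0,
                   /*alignedPtrB=*/B, /*offsetB=*/0,
                   /*alignedPtrC=*/C, /*offsetC=*/0);
  return 0;
}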