From b01f1eb5c584824d1041efa5dc19cc332b9c2c97 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 11 Oct 2024 17:15:41 +0200 Subject: [PATCH] Vector to XSMM (#3) Implements lowering pass from vector to XSMM microkernels. libxsmm is added as an external dependency together with general MLIR infrastructure for handling XSMM code generation and runtime execution. The XSMM lowering is optional and can be enabled at JIT step by environment variable TRITON_CPU_XSMM=1 libxsmm is built as a shared library and linked with XSMM-related libraries. These are also added to the Python infrastructure. Additionally, general MLIR utilities are imported to allow analysis, code generation and microkernel execution. Initially, a simple pattern mapping vector contraction to an XSMM kernel is added. --- bin/RegisterTritonDialects.h | 2 + cmake/xsmm.cmake | 59 + third_party/cpu/CMakeLists.txt | 17 +- third_party/cpu/backend/compiler.py | 8 + third_party/cpu/include/CMakeLists.txt | 1 + third_party/cpu/include/Xsmm/CMakeLists.txt | 8 + third_party/cpu/include/Xsmm/Passes.h | 67 + third_party/cpu/include/Xsmm/Passes.td | 17 + third_party/cpu/include/Xsmm/XsmmEnum.h | 18 + third_party/cpu/include/Xsmm/XsmmEnum.td | 83 ++ third_party/cpu/lib/CMakeLists.txt | 1 + third_party/cpu/lib/Xsmm/CMakeLists.txt | 28 + .../cpu/lib/Xsmm/ConvertVectorToXsmm.cpp | 139 +++ third_party/cpu/lib/Xsmm/ValueUtils.cpp | 146 +++ third_party/cpu/lib/Xsmm/ValueUtils.h | 50 + third_party/cpu/lib/Xsmm/VnniUtils.cpp | 89 ++ third_party/cpu/lib/Xsmm/VnniUtils.h | 62 + third_party/cpu/lib/Xsmm/XsmmEnum.cpp | 15 + third_party/cpu/lib/Xsmm/XsmmUtils.cpp | 1084 +++++++++++++++++ third_party/cpu/lib/Xsmm/XsmmUtils.h | 157 +++ .../cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp | 516 ++++++++ .../cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h | 85 ++ third_party/cpu/triton_cpu.cc | 8 + 23 files changed, 2658 insertions(+), 2 deletions(-) create mode 100644 cmake/xsmm.cmake create mode 100644 third_party/cpu/include/Xsmm/CMakeLists.txt create mode 100644 third_party/cpu/include/Xsmm/Passes.h create mode 100644 third_party/cpu/include/Xsmm/Passes.td create mode 100644 third_party/cpu/include/Xsmm/XsmmEnum.h create mode 100644 third_party/cpu/include/Xsmm/XsmmEnum.td create mode 100644 third_party/cpu/lib/Xsmm/CMakeLists.txt create mode 100644 third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp create mode 100644 third_party/cpu/lib/Xsmm/ValueUtils.cpp create mode 100644 third_party/cpu/lib/Xsmm/ValueUtils.h create mode 100644 third_party/cpu/lib/Xsmm/VnniUtils.cpp create mode 100644 third_party/cpu/lib/Xsmm/VnniUtils.h create mode 100644 third_party/cpu/lib/Xsmm/XsmmEnum.cpp create mode 100644 third_party/cpu/lib/Xsmm/XsmmUtils.cpp create mode 100644 third_party/cpu/lib/Xsmm/XsmmUtils.h create mode 100644 third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp create mode 100644 third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index a4939a739528..c4393bc17b80 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -20,6 +20,7 @@ #include "cpu/include/TritonCPUToLLVM/Passes.h" #include "cpu/include/TritonCPUTransforms/Passes.h" #include "cpu/include/TritonToTritonCPU/Passes.h" +#include "cpu/include/Xsmm/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" @@ -74,6 +75,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { 
mlir::triton::cpu::registerTritonCPUTransformsPasses(); mlir::triton::cpu::registerTritonCPUToLLVMPasses(); mlir::triton::cpu::registerTritonOpScalarizeExternalModels(registry); + mlir::triton::cpu::registerTritonCPUXsmmPasses(); // TODO: register Triton & TritonGPU passes registry.insert + $ +) +add_definitions(-DLIBXSMM_DEFAULT_CONFIG -U_DEBUG -D__BLAS=0) + +set_property(TARGET xsmm PROPERTY POSITION_INDEPENDENT_CODE ON) # -fPIC +set_property(TARGET xsmm PROPERTY COMPILE_WARNING_AS_ERROR ON) + +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) +target_link_libraries(xsmm PUBLIC Threads::Threads) +target_link_libraries(xsmm PUBLIC ${CMAKE_DL_LIBS}) + +include(CheckLibraryExists) +check_library_exists(m sqrt "" XSMM_LIBM) +if(XSMM_LIBM) + target_link_libraries(xsmm PUBLIC m) +endif() +check_library_exists(rt sched_yield "" XSMM_LIBRT) +if(XSMM_LIBRT) + target_link_libraries(xsmm PUBLIC rt) +endif() +#target_link_libraries(xsmm PUBLIC c) diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 59d0f5c53d46..d58dfd545a02 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,15 +1,28 @@ +# libxsmm +include(xsmm) +message (STATUS "LIBXSMM Include dir: ${XSMM_INCLUDE_DIRS}") + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) - add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) - target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation) + add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms TritonCPUXsmm) + target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation MLIRMemRefTransforms) endif() add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport) +add_library(TritonCPUXsmmRuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/lib/Xsmm/runtime/XsmmRunnerUtils.cpp) +target_link_libraries(TritonCPUXsmmRuntime PRIVATE xsmm) +set_property(TARGET TritonCPUXsmmRuntime PROPERTY CXX_STANDARD 11) +target_compile_definitions(TritonCPUXsmmRuntime PRIVATE mlir_c_runner_utils_EXPORTS) +target_include_directories(TritonCPUXsmmRuntime + PUBLIC + $ +) + # Build and link sleef set(SLEEF_BUILD_SHARED_LIBS ON CACHE BOOL "Build sleef shared lib" FORCE) set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 6150defb6128..975001dfccd3 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -41,6 +41,7 @@ class CPUOptions: enable_fp_fusion: bool = True max_num_imprecise_acc_default: int = 0 enable_fast_math: bool = True + enable_xsmm: bool = False vec_lib: Optional[str] = 'libsleef' # TODO: Try to enable it. 
sanitize_overflow: bool = False @@ -96,6 +97,8 @@ def parse_options(self, opts) -> Any: if "supported_fp8_dtypes" not in args: supported_fp8_dtypes = set(CPUOptions.supported_fp8_dtypes) args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) + if "enable_xsmm" not in args: + args["enable_xsmm"] = os.getenv("TRITON_CPU_XSMM", "0") != "0" return CPUOptions(**args) def pack_metadata(self, metadata): @@ -194,7 +197,10 @@ def make_llir(self, src, metadata, options): # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + if options.enable_xsmm: + cpu.passes.ttcpuir.add_convert_vector_to_xsmm(pm) cpu.passes.ttcpuir.add_lower_vector_multi_dim(pm) + cpu.passes.ttcpuir.add_expand_strided_metadata(pm) cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False) cpu.passes.ttcpuir.add_lower_affine(pm) passes.convert.add_scf_to_cf(pm) @@ -259,6 +265,8 @@ def make_so(src, metadata, options): Path(asm_path).write_text(src) lib_dirs = cpu_driver.library_dirs libs = ["gcc", "m", "TritonCPURuntime", "sleef"] + if options.enable_xsmm: + libs.extend(["xsmm", "TritonCPUXsmmRuntime"]) so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs) with open(so, "rb") as f: return f.read() diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt index 30282a4736c1..fee03f896707 100644 --- a/third_party/cpu/include/CMakeLists.txt +++ b/third_party/cpu/include/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(ScalarizePass) add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonCPUTransforms) add_subdirectory(TritonToTritonCPU) +add_subdirectory(Xsmm) diff --git a/third_party/cpu/include/Xsmm/CMakeLists.txt b/third_party/cpu/include/Xsmm/CMakeLists.txt new file mode 100644 index 000000000000..ede68918f9a6 --- /dev/null +++ b/third_party/cpu/include/Xsmm/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUXsmm) +add_public_tablegen_target(TritonCPUXsmmPassIncGen) + +set(LLVM_TARGET_DEFINITIONS XsmmEnum.td) +mlir_tablegen(XsmmEnum.h.inc -gen-enum-decls) +mlir_tablegen(XsmmEnum.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonCPUXsmmAttrDefIncGen) diff --git a/third_party/cpu/include/Xsmm/Passes.h b/third_party/cpu/include/Xsmm/Passes.h new file mode 100644 index 000000000000..0e7b2d11ce5e --- /dev/null +++ b/third_party/cpu/include/Xsmm/Passes.h @@ -0,0 +1,67 @@ +#ifndef TritonCPUXsmm_CONVERSION_PASSES_H +#define TritonCPUXsmm_CONVERSION_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class ModuleOp; + +namespace affine { +class AffineDialect; +} // namespace affine + +namespace arith { +class ArithDialect; +} // namespace arith + +namespace func { +class FuncOp; +class FuncDialect; +} // namespace func + +namespace linalg { +class LinalgDialect; +} // namespace linalg + +namespace LLVM { +class LLVMDialect; +} // namespace LLVM + +namespace math { +class MathDialect; +} // namespace math + +namespace memref { +class MemRefDialect; +} // namespace memref + +namespace scf { +class SCFDialect; +} // namespace scf + +namespace tensor { +class TensorDialect; +} // namespace tensor + +namespace vector { +class VectorDialect; +} // namespace vector + +} // namespace mlir + +namespace mlir { +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/Xsmm/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "cpu/include/Xsmm/Passes.h.inc" + +} // namespace cpu +} // namespace triton +} // namespace mlir + 
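A usage sketch for the registration hook generated above. registerTritonCPUXsmmPasses() and the pass mnemonic triton-cpu-convert-vector-to-xsmm come from this patch; the standalone driver around them is an assumption for illustration only. In the patch itself the pass is scheduled from compiler.py when TRITON_CPU_XSMM=1.

#include "cpu/include/Xsmm/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Runs the vector-to-XSMM lowering on a module, mirroring what the Python
// pipeline does when TRITON_CPU_XSMM=1. Assumes the module's dialects are
// already loaded into its MLIRContext.
static mlir::LogicalResult lowerVectorToXsmm(mlir::ModuleOp module) {
  mlir::triton::cpu::registerTritonCPUXsmmPasses();
  mlir::PassManager pm(module.getContext());
  // Pass mnemonic as defined in Passes.td.
  if (mlir::failed(
          mlir::parsePassPipeline("triton-cpu-convert-vector-to-xsmm", pm)))
    return mlir::failure();
  return pm.run(module);
}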
+#endif diff --git a/third_party/cpu/include/Xsmm/Passes.td b/third_party/cpu/include/Xsmm/Passes.td new file mode 100644 index 000000000000..b6b84d25dfde --- /dev/null +++ b/third_party/cpu/include/Xsmm/Passes.td @@ -0,0 +1,17 @@ +#ifndef TRITONCPU_XSMM_PASSES +#define TRITONCPU_XSMM_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertVectorToXsmm : Pass<"triton-cpu-convert-vector-to-xsmm", "mlir::ModuleOp"> { + let summary = "Convert vector to xsmm"; + let description = [{ + Convert vector operations to XSMM operations. + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "vector::VectorDialect", + "LLVM::LLVMDialect"]; +} + +#endif diff --git a/third_party/cpu/include/Xsmm/XsmmEnum.h b/third_party/cpu/include/Xsmm/XsmmEnum.h new file mode 100644 index 000000000000..19bfad8b16ba --- /dev/null +++ b/third_party/cpu/include/Xsmm/XsmmEnum.h @@ -0,0 +1,18 @@ +//===- XsmmEnum.h -----------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_DIALECT_XSMM_XSMMENUM_H +#define TPP_DIALECT_XSMM_XSMMENUM_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/DialectImplementation.h" + +#define GET_ATTRDEF_CLASSES +#include "cpu/include/Xsmm/XsmmEnum.h.inc" + +#endif // TPP_DIALECT_XSMM_XSMMENUM_H diff --git a/third_party/cpu/include/Xsmm/XsmmEnum.td b/third_party/cpu/include/Xsmm/XsmmEnum.td new file mode 100644 index 000000000000..17da6cfbb8ef --- /dev/null +++ b/third_party/cpu/include/Xsmm/XsmmEnum.td @@ -0,0 +1,83 @@ +//===- XsmmEnum --------------------------------------------*- Tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" + +def Xsmm_DataType: I64EnumAttr< + "DataType", "see: libxsmm_datatype", + [ + I64EnumAttrCase<"F32", 1, "f32">, + I64EnumAttrCase<"BF16", 2, "bf16"> + ]>{ + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_BinaryKind : I64EnumAttr< + "BinaryKind", "see: libxsmm_meltw_binary_type", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"ADD", 1, "add">, + I64EnumAttrCase<"MUL", 2, "mul">, + I64EnumAttrCase<"SUB", 3, "sub">, + I64EnumAttrCase<"DIV", 4, "div"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_UnaryKind : I64EnumAttr< + "UnaryKind", "see: libxsmm_meltw_unary_type", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"IDENTITY", 1, "identity">, + I64EnumAttrCase<"ZERO", 2, "zero">, + I64EnumAttrCase<"RELU", 5, "relu">, + I64EnumAttrCase<"VNNI2", 28, "vnni_2">, + I64EnumAttrCase<"TRANSPOSE", 29, "transpose"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_UnaryFlags : I64EnumAttr< + "UnaryFlags", "see: libxsmm_meltw_unary_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BCAST_ROW", 2, "bcast_row">, + I64EnumAttrCase<"BCAST_COL", 4, "bcast_col">, + I64EnumAttrCase<"BCAST_SCALAR", 8, "bcast_scalar"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_BinaryFlags : I64EnumAttr< + "BinaryFlags", "see: libxsmm_meltw_binary_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BCAST_ROW_IN_0", 1, "bcast_row_in0">, + I64EnumAttrCase<"BCAST_ROW_IN_1", 2, "bcast_row_in1">, + I64EnumAttrCase<"BCAST_COL_IN_0", 4, "bcast_col_in0">, + I64EnumAttrCase<"BCAST_COL_IN_1", 8, "bcast_col_in1">, + I64EnumAttrCase<"BCAST_SCALAR_IN_0", 16, "bcast_scalar_in0">, + I64EnumAttrCase<"BCAST_SCALAR_IN_1", 32, "bcast_scalar_in1"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_GemmFlags : I64EnumAttr< + "GemmFlags", "see: libxsmm_gemm_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BETA_0", 4, "beta_0">, + I64EnumAttrCase<"VNNI_A", 2048, "vnni_a">, + I64EnumAttrCase<"VNNI_B", 4096, "vnni_b">, + I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">, + I64EnumAttrCase<"NO_RESET_TILECONFIG", 64, "no_reset_tileconfig">, + I64EnumAttrCase<"NO_SETUP_TILECONFIG", 128, "no_setup_tileconfig"> + ]> { + let cppNamespace = "mlir::xsmm"; +} diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt index fad51ab86ea9..df45d15fdac9 100644 --- a/third_party/cpu/lib/CMakeLists.txt +++ b/third_party/cpu/lib/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(Analysis) add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonCPUTransforms) add_subdirectory(TritonToTritonCPU) +add_subdirectory(Xsmm) diff --git a/third_party/cpu/lib/Xsmm/CMakeLists.txt b/third_party/cpu/lib/Xsmm/CMakeLists.txt new file mode 100644 index 000000000000..1f54d2f21d52 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/CMakeLists.txt @@ -0,0 +1,28 @@ +add_triton_library(TritonCPUXsmm + ConvertVectorToXsmm.cpp + VnniUtils.cpp + ValueUtils.cpp + XsmmEnum.cpp + XsmmUtils.cpp + + DEPENDS + TritonCPUXsmmPassIncGen + TritonCPUXsmmAttrDefIncGen + xsmm + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRVectorDialect + MLIRMemRefDialect + MLIRFuncDialect + MLIRLLVMDialect + MLIRInferTypeOpInterface + MLIRLinalgUtils + xsmm +) + +target_include_directories(TritonCPUXsmm + PUBLIC + $ +) diff --git a/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp 
b/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp new file mode 100644 index 000000000000..729b66d8dc9f --- /dev/null +++ b/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp @@ -0,0 +1,139 @@ +//===- ConvertVectorToXsmm.cpp ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "cpu/include/Xsmm/Passes.h" + +#include "ValueUtils.h" +#include "VnniUtils.h" +#include "XsmmUtils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +#include +#include + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::func; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTVECTORTOXSMM +#include "cpu/include/Xsmm/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +namespace { + +static Value getMemrefSource(PatternRewriter &rewriter, Operation *op, + TypedValue operand) { + Location loc = op->getLoc(); + MLIRContext *ctx = op->getContext(); + OpBuilder::InsertionGuard g(rewriter); + + if (auto readOp = + dyn_cast_or_null(operand.getDefiningOp())) { + VectorType vecTy = readOp.getVectorType(); + SmallVector strides(vecTy.getRank(), 1); + return rewriter.create( + loc, readOp.getSource(), getAsOpFoldResult(readOp.getIndices()), + getAsIndexOpFoldResult(ctx, vecTy.getShape()), + getAsIndexOpFoldResult(ctx, strides)); + } + + rewriter.setInsertionPoint(op); + + auto vecTy = dyn_cast(operand.getType()); + assert(vecTy && "Expect vector type operand"); + MemRefType memTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); + auto alloca = rewriter.create(loc, memTy); + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(memTy.getRank(), zeroIdx); + auto write = + rewriter.create(loc, operand, alloca, indices); + + return alloca; +} + +struct ContractToXsmm : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + Location loc = contractOp.getLoc(); + + TypedValue lhs = contractOp.getLhs(); + TypedValue rhs = contractOp.getRhs(); + TypedValue acc = contractOp.getAcc(); + + auto vecTy = dyn_cast(acc.getType()); + if (!vecTy) + return rewriter.notifyMatchFailure(contractOp, + "expects to accumulate on vector"); + + SmallVector flags; + Value lhsBuf = getMemrefSource(rewriter, contractOp, lhs); + Value rhsBuf = getMemrefSource(rewriter, contractOp, rhs); + Value accBuf = getMemrefSource(rewriter, contractOp, acc); + SmallVector inputs{lhsBuf, rhsBuf, accBuf}; + SmallVector outputs{nullptr}; + auto brgemmInfo = + xsmm::utils::isMappableToBrgemm(rewriter, contractOp, inputs, outputs, + contractOp.getIndexingMapsArray()); + if 
(failed(brgemmInfo)) + return rewriter.notifyMatchFailure(contractOp, "not mappable to XSMM"); + if (brgemmInfo->isVnni) + return rewriter.notifyMatchFailure(contractOp, "VNNI support NYI"); + + auto xsmmFuncs = xsmm::utils::buildBrgemmCalls( + rewriter, contractOp, ValueRange{lhsBuf, rhsBuf, accBuf}, *brgemmInfo, + flags); + + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(dyn_cast(accBuf.getType()).getRank(), + zeroIdx); + auto readOp = + rewriter.create(loc, vecTy, accBuf, indices); + + rewriter.replaceOp(contractOp, readOp); + + return success(); + } +}; + +struct ConvertVectorToXsmm + : public triton::cpu::impl::ConvertVectorToXsmmBase { + using ConvertVectorToXsmmBase::ConvertVectorToXsmmBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/third_party/cpu/lib/Xsmm/ValueUtils.cpp b/third_party/cpu/lib/Xsmm/ValueUtils.cpp new file mode 100644 index 000000000000..566665dbc7f0 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/ValueUtils.cpp @@ -0,0 +1,146 @@ +//===- ValueUtils.cpp --------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ValueUtils.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace utils { + +// Returns true if the value is a constant float or integer. +bool isValConstZero(Value val) { + return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()); +} + +// Returns true if the attribute represent "all zeros" +static bool isZeroAttr(Attribute attribute) { + return TypeSwitch(attribute) + .Case([](auto attr) { return attr.getValueAsDouble() == 0.0; }) + .Case([](auto attr) { return attr.getInt() == 0; }) + .Case([](auto attr) { + if (!attr.getElementType().isIntOrFloat()) + return false; + if (!attr.isSplat()) + return false; + auto splat = attr.template getSplatValue(); + return isZeroAttr(splat); + }) + .Default([](auto attr) { return false; }); +} + +// Prototypes +bool isZeroOp(Operation *); + +// Returns true if the value represents a zero filled tensor. 
+// Recurse into isZeroOp for defining ops if not immediately obvious +// Looks past linalg generic's argument (which don't have defining ops) +bool isZeroTensor(Value val) { + if (!val) + return false; + if (isValConstZero(val)) + return true; + + Operation *defOp = nullptr; + + // Block arguments don't have a defining op, but they do have an op arg + if (auto arg = dyn_cast(val)) { + // We need to find the argument to the linalg on the same order as this one + auto *linalgOp = arg.getParentRegion()->getParentOp(); + if (!isa(linalgOp)) + return false; + auto index = arg.getArgNumber(); + auto linalgArg = linalgOp->getOperand(index); + defOp = linalgArg.getDefiningOp(); + } else { + defOp = val.getDefiningOp(); + } + return isZeroOp(defOp); +} + +// Returns true if the operation represents a zero filled tensor +// Recurses into isZeroTensor for operands and isZeroAttr for attributes +bool isZeroOp(Operation *defOp) { + if (!defOp) + return false; + + return TypeSwitch(defOp) + .Case([&](auto op) { + // Dense attributes don't match APFloat.isZero() + auto attr = op.getValue(); + return isZeroAttr(attr); + }) + .Case([&](auto op) { + if (op.getInputs().size() != 1) + return false; + return isZeroTensor(op.getInputs()[0]); + }) + .Case( + [&](auto op) { return isZeroTensor(op.getSource()); }) + .Case([&](auto op) { + auto name = op.getName(); + auto module = defOp->getParentOfType(); + auto global = module.lookupSymbol(name); + auto attr = global.getInitialValueAttr(); + return isZeroAttr(attr); + }) + .Default([&](Operation *op) { return false; }); +} + +FailureOr> getStaticStrides(Value value) { + auto valueType = value.getType(); + if (!isa(valueType)) + return failure(); + auto memrefType = cast(valueType); + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + return failure(); + } + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return failure(); + } + return strides; +} + +std::pair getPtrAndOffset(OpBuilder &builder, Value operand, + Location loc) { + auto memrefType = dyn_cast(operand.getType()); + assert(memrefType && "Expect a memref value"); + MemRefType baseMemrefType = MemRefType::get({}, memrefType.getElementType()); + Type basePtrType = builder.getIndexType(); + Type offsetType = builder.getIndexType(); + SmallVector sizesTypes(memrefType.getRank(), offsetType); + SmallVector stridesTypes(memrefType.getRank(), offsetType); + auto meta = builder.create( + loc, baseMemrefType, offsetType, sizesTypes, stridesTypes, operand); + Value alignedPointerAsIndex = + builder.create(loc, basePtrType, + operand); + Value alignedPointerAsI64 = builder.create( + loc, builder.getIntegerType(64), alignedPointerAsIndex); + // TODO: non-POD will require an LLVMTypeConverter. + Value alignedPointer = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), + alignedPointerAsI64); + Value offset = meta.getOffset(); + return std::make_pair(alignedPointer, offset); +} + +} // namespace utils +} // namespace mlir diff --git a/third_party/cpu/lib/Xsmm/ValueUtils.h b/third_party/cpu/lib/Xsmm/ValueUtils.h new file mode 100644 index 000000000000..8cd50146d41c --- /dev/null +++ b/third_party/cpu/lib/Xsmm/ValueUtils.h @@ -0,0 +1,50 @@ +//===- ValueUtils.h - -------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_TRANSFORMS_UTILS_VALUEUTILS_H +#define TPP_TRANSFORMS_UTILS_VALUEUTILS_H + +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#include +#include + +using namespace mlir; + +namespace mlir { +class OpBuilder; +class Operation; +class Location; +namespace utils { + +// Returns true if the value is a constant float or integer. +bool isValConstZero(Value val); + +// Returns true if the op defining `val` represents a zero filled tensor. +bool isZeroTensor(Value val); + +// Returns true if the operation represents a zero filled tensor. +bool isZeroOp(Operation *); + +// Returns the strides of `val`. The method returns something usefull +// only if the `val` type is a strided memref and the strides are statically +// known. +FailureOr> getStaticStrides(Value val); + +// Return the offset and ptr for `val`. Assert if `val` +// is not a memref. +std::pair getPtrAndOffset(OpBuilder &builder, Value val, + Location loc); + +} // namespace utils +} // namespace mlir + +#endif // TPP_TRANSFORMS_UTILS_VALUEUTILS_H diff --git a/third_party/cpu/lib/Xsmm/VnniUtils.cpp b/third_party/cpu/lib/Xsmm/VnniUtils.cpp new file mode 100644 index 000000000000..6df29f993c60 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/VnniUtils.cpp @@ -0,0 +1,89 @@ +//===- VNNIUtils.cpp ---------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "VnniUtils.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" + +#include "libxsmm.h" + +namespace mlir { +namespace vnni { +namespace utils { + +std::optional getVnniBlockingFactor(Type type) { + auto elementType = getElementTypeOrSelf(type); + if (elementType.isBF16()) + return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16); + return std::nullopt; +} + +// Until we have a better way to express the VNNI layout (see: #563), it is up +// to the callee to specify the expected rank in the VNNI layout as the rank +// depends on the operations we are dealing with. +bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref) { + if (memref.getRank() != static_cast(expectedRank) || + !memref.getElementType().isBF16()) { + return false; + } + return memref.getShape().back() == vnni::utils::getVnniBlockingFactor(memref); +} + +bool isInVnniLayout(int64_t expectedRank, VectorType vector) { + if (vector.getRank() != expectedRank || !vector.getElementType().isBF16()) { + return false; + } + return vector.getShape().back() == vnni::utils::getVnniBlockingFactor(vector); +} + +// Until we have a better way to express the VNNI layout (see: #563), it is up +// to the callee to specify the expected rank in the VNNI layout as the rank +// depends on the operations we are dealing with. 
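To make the layout being checked here concrete: for bf16 the blocking factor is 2 (see getVnniBlockingFactor), so a K x N operand in VNNI form is reshaped to (K/2) x N x 2 and its trailing dimension must equal that factor. The helper below is hypothetical and not used by the pass; it only illustrates the index mapping, assuming bf16 values are handled as raw 16-bit storage and K is even.

#include <cstdint>
#include <vector>

// Packs a row-major K x N bf16 tile into VNNI-2 layout: element (k, n) of the
// flat tile lands at (k/2, n, k%2) in the packed (K/2) x N x 2 buffer, which
// is exactly the shape isInVnniLayout accepts.
std::vector<uint16_t> packVnni2(const std::vector<uint16_t> &flat, int64_t K,
                                int64_t N) {
  std::vector<uint16_t> packed(flat.size());
  for (int64_t k = 0; k < K; ++k)
    for (int64_t n = 0; n < N; ++n)
      packed[(k / 2) * (N * 2) + n * 2 + (k % 2)] = flat[k * N + n];
  return packed;
}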
+bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector) { + return isInVnniLayout((int64_t)expectedRank, vector); +} + +FailureOr isInVnniLayout(linalg::GenericOp linalgOp, + AffineMap map, int64_t blockingFactor) { + ArrayRef results = map.getResults(); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + + AffineExpr vnniDim = results.back(); + auto dimExpr = dyn_cast(vnniDim); + if (!dimExpr || iteratorTypes[dimExpr.getPosition()] != + mlir::utils::IteratorType::reduction) { + return failure(); + } + + for (auto result : results) { + auto blockeDim = dyn_cast(result); + if (!blockeDim) + continue; + if (blockeDim.getKind() != AffineExprKind::FloorDiv) + continue; + auto lhsDim = dyn_cast(blockeDim.getLHS()); + auto rhsCst = dyn_cast(blockeDim.getRHS()); + if (!lhsDim || !rhsCst) + continue; + if (iteratorTypes[lhsDim.getPosition()] != + mlir::utils::IteratorType::reduction) + continue; + if (rhsCst.getValue() != blockingFactor) + continue; + return lhsDim; + } + return failure(); +} + +} // namespace utils +} // namespace vnni +} // namespace mlir diff --git a/third_party/cpu/lib/Xsmm/VnniUtils.h b/third_party/cpu/lib/Xsmm/VnniUtils.h new file mode 100644 index 000000000000..e8517a5d23e1 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/VnniUtils.h @@ -0,0 +1,62 @@ +//===- VnniUtils.h -----------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_TRANSFORMS_UTILS_VNNIUTILS_H +#define TPP_TRANSFORMS_UTILS_VNNIUTILS_H + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Support/LogicalResult.h" + +#include +#include + +using namespace mlir; + +namespace mlir { +class Type; +class MemRefType; +class OpOperand; +class AffineDimExpr; +class AffineMap; + +namespace linalg { +class GenericOp; +} // namespace linalg + +namespace vnni { +namespace utils { + +enum class VnniOperandRank { + TRANSPOSE = 3, + GEMM = 3, + BRGEMM_INS = 4, + BRGEMM_OUTS = 3 +}; + +// Return the VNNI blocking factor: 2 for BF16 and 4 for BF8. +std::optional getVnniBlockingFactor(Type type); + +// Return true if the memref is in VNNI layout with rank `expectedRank`. +bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref); + +bool isInVnniLayout(int64_t expectedRank, VectorType vector); + +// Return true if the memref is in VNNI layout with rank `expectedRank`. +bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector); + +// Return the first AffineDimExpr in the map `affineMap` +// with a VNNI layout pattern (AffineDimExpr floordiv VNNI). +FailureOr isInVnniLayout(linalg::GenericOp linalgOp, + AffineMap affineMap, + int64_t blockingFactor); + +} // namespace utils +} // namespace vnni +} // namespace mlir + +#endif diff --git a/third_party/cpu/lib/Xsmm/XsmmEnum.cpp b/third_party/cpu/lib/Xsmm/XsmmEnum.cpp new file mode 100644 index 000000000000..85766e5272f0 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/XsmmEnum.cpp @@ -0,0 +1,15 @@ +//===- XsmmEnum.cpp - Xsmm dialect enum -------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "cpu/include/Xsmm/XsmmEnum.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::xsmm; + +#include "cpu/include/Xsmm/XsmmEnum.cpp.inc" diff --git a/third_party/cpu/lib/Xsmm/XsmmUtils.cpp b/third_party/cpu/lib/Xsmm/XsmmUtils.cpp new file mode 100644 index 000000000000..c5a60d2627de --- /dev/null +++ b/third_party/cpu/lib/Xsmm/XsmmUtils.cpp @@ -0,0 +1,1084 @@ +//===- XsmmUtils.cpp ---------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "XsmmUtils.h" +#include "ValueUtils.h" +#include "VnniUtils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Compiler.h" + +#include +#include + +#define DEBUG_TYPE "xsmm-utils" + +using namespace mlir; +using namespace mlir::linalg; + +namespace mlir { +namespace xsmm { +namespace utils { + +// Callable object to verify if `operand` has static shape. +struct HasStaticShape { + HasStaticShape() = default; + HasStaticShape(SmallVectorImpl *shape) : shape(shape){}; + + bool operator()(Value operand, Operation *op) const { + auto operandType = operand.getType(); + if (auto shapedType = dyn_cast_or_null(operandType)) { + if (!shapedType.hasStaticShape()) + return false; + if (shape) { + for (int64_t shapeOnDim : shapedType.getShape()) + shape->push_back(shapeOnDim); + } + } + return true; + } + SmallVectorImpl *shape = nullptr; +}; + +// Callable object to verify if `operand` has static strides. +// If `operand` is a tensor type or a scalar, return true. +struct HasStaticStrides { + HasStaticStrides() = default; + HasStaticStrides(SmallVector *strides) : strides(strides){}; + + bool operator()(Value operand, Operation *op) const { + auto operandType = operand.getType(); + SmallVector strides; + if (auto memRefType = dyn_cast_or_null(operandType)) { + int64_t offset; + if (failed(getStridesAndOffset(memRefType, strides, offset))) + return false; + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return false; + } + if (this->strides) + this->strides->append(strides.begin(), strides.end()); + } + return true; + } + SmallVectorImpl *strides = nullptr; +}; + +// Structural matcher. 
+static FailureOr +checkStructure(vector::ContractionOp contractOp, SmallVector &inputs, + SmallVector &outputs, ArrayRef indexingMap) { + if (!HasStaticShape()(inputs[0], inputs[0].getDefiningOp()) || + !HasStaticShape()(inputs[1], inputs[1].getDefiningOp()) || + !HasStaticShape()(inputs[2], inputs[2].getDefiningOp()) || + (outputs[0] != nullptr && + !HasStaticShape()(outputs[0], outputs[0].getDefiningOp())) || + !HasStaticStrides()(inputs[0], inputs[0].getDefiningOp()) || + !HasStaticStrides()(inputs[1], inputs[1].getDefiningOp()) || + !HasStaticStrides()(inputs[2], inputs[2].getDefiningOp()) || + (outputs[0] != nullptr && + !HasStaticStrides()(outputs[0], outputs[0].getDefiningOp()))) { + return failure(); + } + return inferContractionDims(indexingMap); +} + +// Return the position of `dim` in the codomain of `operand`. +std::optional getPosInCodomain(unsigned dim, Value operand, + vector::ContractionOp contractOp, + AffineMap map) { + return map.getResultPosition(getAffineDimExpr(dim, contractOp.getContext())); +} + +static SmallVector +createFlatListOfOperandStaticDims(vector::ContractionOp contractOp) { + SmallVector res; + for (OpOperand &opOperand : contractOp.getOperation()->getOpOperands()) + llvm::append_range( + res, dyn_cast(opOperand.get().getType()).getShape()); + return res; +} + +static SmallVector +computeStaticLoopSizes(vector::ContractionOp contractOp, + ArrayRef maps) { + AffineMap map = concatAffineMaps(maps); + unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); + SmallVector allShapeSizes = + createFlatListOfOperandStaticDims(contractOp); + SmallVector res(numDims, 0); + for (unsigned idx = 0; idx < numRes; ++idx) { + auto result = map.getResult(idx); + if (auto d = dyn_cast(result)) + res[d.getPosition()] = allShapeSizes[idx]; + } + return res; +} + +static FailureOr> +getVNNIStaticStrides(MemRefType valueType) { + SmallVector strides; + int64_t offset; + SmallVector shape; + for (size_t i = 0; i < valueType.getShape().size(); i++) { + shape.push_back(valueType.getShape()[i]); + } + auto temp = shape[shape.size() - 1]; + shape[shape.size() - 1] = shape[shape.size() - 2]; + shape[shape.size() - 2] = temp; + auto memrefType = MemRefType::get(shape, valueType.getElementType()); + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + return failure(); + } + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return failure(); + } + return strides; +} + +// Access matcher. 
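The access matcher below derives the leading dimensions and batch strides from the operands' static strides. As an illustrative reading (the shapes are hypothetical and the field order follows the BrgemmInfo aggregate initialization at the end of checkAccess): for an f32 contraction C[m, n] += A[m, k] * B[k, n] with indexing maps {(m, k), (k, n), (m, n)} over contiguous row-major buffers A: 64x32, B: 32x16 and C: 64x16, the static strides are A = [32, 1], B = [16, 1] and C = [16, 1]; the unit innermost strides satisfy the checks and the matcher yields:

// Values checkAccess computes for the example above (illustrative only).
xsmm::BrgemmInfo gemmExample{/*m=*/64, /*n=*/16, /*k=*/32,
                             /*batch=*/0,    // no batch-reduce dimension
                             /*lda=*/32,     // stride of m in A
                             /*ldb=*/16,     // stride of k in B
                             /*ldc=*/16,     // stride of m in C
                             /*strideA=*/1, /*strideB=*/1};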
+FailureOr +checkAccess(PatternRewriter &rewriter, vector::ContractionOp contractOp, + unsigned m, unsigned n, SmallVector kVector, + std::optional batchPos, SmallVector inputs, + ArrayRef indexingMap) { + Value operandA = inputs[0]; + Value operandB = inputs[1]; + Value operandC = inputs[2]; + + unsigned k; + if (*xsmm::utils::getPosInCodomain( + kVector[0], contractOp->getOpOperand(1).get(), contractOp, + contractOp.getIndexingMapsArray()[1]) < + *xsmm::utils::getPosInCodomain( + n, contractOp->getOpOperand(1).get(), contractOp, + contractOp.getIndexingMapsArray()[1]) || + kVector.size() == 1) { + k = kVector[0]; + } else if (kVector.size() > 1) { + k = kVector[1]; + } + + auto checkStridesAndGetLda = [&](unsigned minorDim, unsigned majorDim, + Value operand, AffineMap map, + int operandIndex) -> FailureOr { + auto minorDimPosInCodomain = + xsmm::utils::getPosInCodomain(minorDim, operand, contractOp, map); + auto majorDimPosInCodomain = + xsmm::utils::getPosInCodomain(majorDim, operand, contractOp, map); + if (!minorDimPosInCodomain || !majorDimPosInCodomain) { + return failure(); + } + auto dataType = xsmm::utils::getDataType(rewriter, operand.getType()); + FailureOr> stridesOnOperand; + if (dataType == + DataTypeAttr::get(contractOp.getContext(), xsmm::DataType::BF16) && + operandIndex == 1) { + stridesOnOperand = + getVNNIStaticStrides(dyn_cast(operand.getType())); + } else { + stridesOnOperand = mlir::utils::getStaticStrides(operand); + } + if (failed(stridesOnOperand) || + (dataType == + DataTypeAttr::get(contractOp.getContext(), xsmm::DataType::BF16) && + operandIndex == 0 && + (*stridesOnOperand)[*minorDimPosInCodomain] != 2) || + ((dataType != DataTypeAttr::get(contractOp.getContext(), + xsmm::DataType::BF16) && + (*stridesOnOperand)[*minorDimPosInCodomain] != 1))) { + return failure(); + } + if (dataType == + DataTypeAttr::get(contractOp.getContext(), xsmm::DataType::BF16) && + operandIndex == 1) { + return (*stridesOnOperand)[*majorDimPosInCodomain + 1]; + } else { + return (*stridesOnOperand)[*majorDimPosInCodomain]; + } + }; + // A(m, k) + auto lda = checkStridesAndGetLda(k, m, operandA, indexingMap[0], 0); + if (failed(lda)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to compute lda\n"); + return failure(); + } + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] Strides on " + "A: OK\n"); + + // B(k, n) + auto ldb = checkStridesAndGetLda(n, k, operandB, indexingMap[1], 1); + if (failed(ldb)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to compute ldb\n"); + + return failure(); + } + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] Strides on " + "B: OK\n"); + + // C(m, n) + auto ldc = checkStridesAndGetLda(n, m, operandC, indexingMap[2], 2); + if (failed(ldc)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to compute ldc\n"); + return failure(); + } + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] Strides on " + "C: OK\n"); + + int64_t strideA = 1; + int64_t strideB = 1; + if (batchPos) { + auto batchPosCodomainA = getPosInCodomain(batchPos.value(), operandA, + contractOp, indexingMap[0]); + auto stridesOnA = ::mlir::utils::getStaticStrides(operandA); + strideA = (*stridesOnA)[*batchPosCodomainA]; + + auto batchPosCodomainB = getPosInCodomain(batchPos.value(), operandB, + contractOp, indexingMap[1]); + auto stridesOnB = ::mlir::utils::getStaticStrides(operandB); + strideB = (*stridesOnB)[*batchPosCodomainB]; + } + + auto loops = computeStaticLoopSizes(contractOp, indexingMap); + int64_t batchVal = (batchPos) ? 
loops[batchPos.value()] : 0; + + auto loopsK = 1; + for (auto kItr : kVector) + loopsK *= loops[kItr]; + + xsmm::BrgemmInfo info{loops[m], loops[n], loopsK, batchVal, *lda, + *ldb, *ldc, strideA, strideB}; + return info; +} + +// Check if the given +// generic is mappable to a +// brgemm xsmm op. +// - It is a contraction, +// with: +// -- 1 m and 1 n and 2 k +// dimensions. +// -- m appears on the LHS +// and OUT but not in RHS. +// -- n appears on the RHS +// and OUT but not in LHS. +// -- k and k' appear on the +// RHS and LHS but not OUT. +// -- the stride of the +// minor dimension for A, k +// is 1. +// -- the stride of the +// minor dimension for B, n +// is 1. +// -- the stride of the +// minor dimension for C, n +// is 1. +FailureOr isMappableToBrgemm(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + SmallVector &inputs, + SmallVector &output, + ArrayRef indexingMap) { + auto contractionDims = + checkStructure(contractOp, inputs, output, indexingMap); + if (failed(contractionDims)) { + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBr" + "gemm] Failed " + "on " + "checkStructure" + "\n"); + return failure(); + } + unsigned m = contractionDims->m.back(); + unsigned n = contractionDims->n.back(); + SmallVector kVector; + std::optional batch; + if (contractionDims->k.size() >= 2) { + for (size_t i = 1; i < contractionDims->k.size(); i++) + kVector.push_back(contractionDims->k[i]); + } else { + for (size_t i = 0; i < contractionDims->k.size(); i++) + kVector.push_back(contractionDims->k[i]); + } + + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] Candidate " + "dims: " + << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] m: " + << m << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrge" + "mm] n: " + << n << "\n"); + if (batch) + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBr" + "gemm] batch: " + << batch << "\n"); + else + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBr" + "gemm] no batch " + "dim\n"); + auto retval = checkAccess(rewriter, contractOp, m, n, kVector, batch, inputs, + indexingMap); + if (failed(retval)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to check access\n"); + return failure(); + } + return retval; +} + +DataTypeAttr getDataType(RewriterBase &rewriter, Type type) { + auto elemType = getElementTypeOrSelf(type); + if (elemType.isBF16()) + return DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16); + return DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::F32); +} + +FailureOr getUnaryInfo(Value input, Value output, + UnaryFlags inputFlag) { + Type outputType = output.getType(); + + assert(isa(outputType)); + auto outputShapedType = cast(outputType); + if (outputShapedType.getRank() != 2 || !outputShapedType.hasStaticShape() || + !isa(outputShapedType.getElementType())) { + return failure(); + } + + UnaryInfo unaryInfo; + unaryInfo.m = outputShapedType.getShape()[0]; + unaryInfo.n = outputShapedType.getShape()[1]; + + int64_t ldi = 1; + if (ShapedType inputShapedType = dyn_cast(input.getType())) { + auto stridesOnInput = mlir::utils::getStaticStrides(input); + if (failed(stridesOnInput) || stridesOnInput->back() != 1 || + !inputShapedType.hasStaticShape()) { + return failure(); + } + + // If we are broascasting a row into cols, the leading + // dimension is 1, same for scalar broadcast. + if (inputFlag == UnaryFlags::BCAST_ROW || + inputFlag == UnaryFlags::BCAST_SCALAR) { + ldi = 1; + } + // If we are broascasting a col into rows, the leading + // dimension is the size of the tensor. 
+ else if (inputFlag == UnaryFlags::BCAST_COL) { + ldi = inputShapedType.getShape().back(); + } else { + ldi = stridesOnInput->front(); + } + } + auto stridesOnOutput = mlir::utils::getStaticStrides(output); + if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) + return failure(); + + unaryInfo.ldi = ldi; + unaryInfo.ldo = stridesOnOutput->front(); + return unaryInfo; +} + +FailureOr getBinaryInfo(Value lhs, BinaryFlags lhsFlag, Value rhs, + BinaryFlags rhsFlag, Value output) { + Type outputType = output.getType(); + + assert(isa(outputType)); + auto outputShapedType = cast(outputType); + if (outputShapedType.getRank() != 2 || !outputShapedType.hasStaticShape() || + !isa(outputShapedType.getElementType())) { + return failure(); + } + + BinaryInfo binaryInfo; + binaryInfo.m = outputShapedType.getShape()[0]; + binaryInfo.n = outputShapedType.getShape()[1]; + + int64_t ldiLhs = 1; + if (ShapedType lhsShapedType = dyn_cast(lhs.getType())) { + auto stridesOnLhs = mlir::utils::getStaticStrides(lhs); + if (failed(stridesOnLhs) || stridesOnLhs->back() != 1 || + !lhsShapedType.hasStaticShape()) { + return failure(); + } + + if (lhsFlag == BinaryFlags::BCAST_SCALAR_IN_0 || + lhsFlag == BinaryFlags::BCAST_ROW_IN_0) { + ldiLhs = 1; + } else if (lhsFlag == BinaryFlags::BCAST_COL_IN_0) { + ldiLhs = lhsShapedType.getShape().back(); + } else { + ldiLhs = stridesOnLhs->front(); + } + } + + int64_t ldiRhs = 1; + if (ShapedType rhsShapedType = dyn_cast(rhs.getType())) { + auto stridesOnRhs = mlir::utils::getStaticStrides(rhs); + if (failed(stridesOnRhs) || stridesOnRhs->back() != 1 || + !rhsShapedType.hasStaticShape()) { + return failure(); + } + + if (rhsFlag == BinaryFlags::BCAST_SCALAR_IN_1 || + rhsFlag == BinaryFlags::BCAST_ROW_IN_1) { + ldiRhs = 1; + } else if (rhsFlag == BinaryFlags::BCAST_COL_IN_1) { + ldiRhs = rhsShapedType.getShape().back(); + } else { + ldiRhs = stridesOnRhs->front(); + } + } + + binaryInfo.ldiLhs = ldiLhs; + binaryInfo.ldiRhs = ldiRhs; + + auto stridesOnOutput = mlir::utils::getStaticStrides(output); + if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) + return failure(); + binaryInfo.ldo = stridesOnOutput->front(); + return binaryInfo; +} + +// Examples: +// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. +// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. +// If lower=[a], higher=[a, a], [a] reshaped into [1, a]. +// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. +// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. +static void +computeBcastShapeInput(ArrayRef higherRankShape, + ArrayRef lowerRankShape, + SmallVectorImpl &reshapeOutputShape) { + // Initialize new shapes with [1] * higherRank. 
+ int64_t higherRank = higherRankShape.size(); + int64_t lowerRank = lowerRankShape.size(); + + reshapeOutputShape.assign(higherRank, 1); + + int64_t higherRankDim; + int64_t lowerRankDim; + + for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; + i--, j--) { + higherRankDim = higherRankShape[i]; + lowerRankDim = lowerRankShape[j]; + + if (lowerRankDim == 1 && higherRankDim > 1) + reshapeOutputShape[i] = 1; + else if ((lowerRankDim > 1 && higherRankDim == 1) || + (lowerRankDim == higherRankDim)) + reshapeOutputShape[i] = lowerRankDim; + else if (higherRankDim != lowerRankDim) + assert(false && "bCast semantics for identity op broken"); + } +} + +FailureOr getUnaryFlags(Type inputType, Type outputType) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(inputType) || + cast(inputType).getRank() == 0) { + return xsmm::UnaryFlags::BCAST_SCALAR; + } + + ArrayRef shapeOutput = cast(outputType).getShape(); + ArrayRef shapeInput = cast(inputType).getShape(); + assert(shapeOutput.size() >= shapeInput.size() && + "output rank must be >= input rank"); + SmallVector bShapeInput; + computeBcastShapeInput(shapeOutput, shapeInput, bShapeInput); + assert(shapeOutput.size() == bShapeInput.size()); + shapeInput = bShapeInput; + + // Same shape for input and output, no bcast. + if (shapeInput == shapeOutput) + return xsmm::UnaryFlags::NONE; + + // Input is a memref but it is all ones, bcast = scalar. + auto isOne = [](int64_t val) { return val == 1; }; + if (llvm::all_of(shapeInput, isOne)) + return xsmm::UnaryFlags::BCAST_SCALAR; + + if (shapeInput[1] == 1 && shapeOutput[1] > 1) + return xsmm::UnaryFlags::BCAST_ROW; + + if (shapeInput[0] == 1 && shapeOutput[0] > 1) + return xsmm::UnaryFlags::BCAST_COL; + + return failure(); +} + +FailureOr getBinFlags(ArrayRef shapeOutput, + ArrayRef shapeOperand, + OperandPos operandNumber) { + assert(shapeOutput.size() >= shapeOperand.size() && + "Output rank must be >= operand rank"); + SmallVector bOperandShape; + computeBcastShapeInput(shapeOutput, shapeOperand, bOperandShape); + assert(shapeOutput.size() == bOperandShape.size()); + assert(shapeOutput.size() == 2); + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto getBCastEnum = [](BCastType bCastType, + OperandPos operandPos) -> xsmm::BinaryFlags { + switch (bCastType) { + case BCastType::NONE: + return xsmm::BinaryFlags::NONE; + case BCastType::SCALAR: + if (operandPos == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + else + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + case BCastType::ROW: + if (operandPos == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_ROW_IN_0; + else + return xsmm::BinaryFlags::BCAST_ROW_IN_1; + case BCastType::COL: + if (operandPos == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_COL_IN_0; + else + return xsmm::BinaryFlags::BCAST_COL_IN_1; + } + assert(false && "unrechable"); + abort(); + }; + + if (bOperandShape == shapeOutput) + return getBCastEnum(BCastType::NONE, operandNumber); + + auto isOne = [](int64_t val) { return val == 1; }; + if (llvm::all_of(bOperandShape, isOne)) + return getBCastEnum(BCastType::SCALAR, operandNumber); + + if (bOperandShape[1] == 1 && shapeOutput[1] > 1) + return getBCastEnum(BCastType::ROW, operandNumber); + + if (bOperandShape[0] == 1 && shapeOutput[0] > 1) + return getBCastEnum(BCastType::COL, operandNumber); + + return failure(); +} + +FailureOr getBinaryFlags(Type operandType, Type outputType, + 
OperandPos operandNumber) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(operandType) || + cast(operandType).getRank() == 0) { + if (operandNumber == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + } + + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto shapeOutput = cast(outputType).getShape(); + auto shapeOperand = cast(operandType).getShape(); + return getBinFlags(shapeOutput, shapeOperand, operandNumber); +} + +FailureOr getBinaryFlagsVectorType(Type operandType, + Type outputType, + OperandPos operandNumber) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(operandType) || + cast(operandType).getRank() == 0) { + if (operandNumber == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + } + + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto shapeOutput = cast(outputType).getShape(); + auto shapeOperand = cast(operandType).getShape(); + return getBinFlags(shapeOutput, shapeOperand, operandNumber); +} + +FailureOr getLeadingDim(Type type, size_t pos) { + // Not shaped type, the leading dimension is the single scalar. + auto memref = dyn_cast(type); + if (!memref) + return 1; + // For 1d memref we cannot use the stride as leading dimension, but the + // leading dimension is the dimension itself. + if (memref.getRank() == 1) + return memref.getShape()[0]; + + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(memref, strides, offset))) + return failure(); + // fail if the strides are non-constant + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) + return failure(); + return strides[pos]; +} + +static bool isInnerMostDim(OpOperand *operand, unsigned minorDim, + vector::ContractionOp contractOp, DataTypeAttr dtype, + int operandNumber) { + auto shapedType = cast(operand->get().getType()); + int64_t rank = shapedType.getRank(); + if (dtype == + DataTypeAttr::get(contractOp.getContext(), xsmm::DataType::BF16) && + (operandNumber == 1 || operandNumber == 0)) { + return minorDim == rank - 2; + } + return minorDim == rank - 1; +} + +// Emit a transpose operation for `operand` by swapping `dim` with `newDim`. +// Emit a transpose operation for `operand` by swapping the dimensions at index +// `dim` with `newDim`. 
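A hypothetical example of the rewrite performed below: suppose operand B of the contraction is vector<16x32xf32> with indexing map (m, n, k) -> (n, k), so k is innermost where n is expected. Swapping the codomain positions of n (0) and k (1) gives the permutation [1, 0], and the emitted IR and map update are:

//   %t = vector.transpose %B, [1, 0]
//          : vector<16x32xf32> to vector<32x16xf32>
// B's indexing map is rewritten in place to (m, n, k) -> (k, n), so n becomes
// the innermost dimension as required by the leading-dimension checks.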
+static void emitTransposeOnOperand(RewriterBase &rewriter, + vector::ContractionOp contractOp, + Value operand, unsigned dim, unsigned newDim, + int operandNumber) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(contractOp); + + Location loc = contractOp.getLoc(); + auto operandType = cast(operand.getType()); + auto rank = operandType.getRank(); + SmallVector shape = llvm::to_vector(operandType.getShape()); + auto permutation = llvm::to_vector(llvm::seq(0, rank)); + std::swap(permutation[dim], permutation[newDim]); + assert(isPermutationVector(permutation)); + LLVM_DEBUG(llvm::interleaveComma( + permutation, llvm::dbgs() << "[emitTransposeOnOperand] Perm: ")); + LLVM_DEBUG(llvm::dbgs() << "\n"); + + applyPermutationToVector(shape, permutation); + auto vectorType = VectorType::get( + shape, cast(operand.getType()).getElementType()); + Value transposeResult = rewriter.create( + loc, vectorType, operand, permutation); + + SmallVector indexingMaps = contractOp.getIndexingMapsArray(); + AffineMap operandMap = indexingMaps[operandNumber]; + LLVM_DEBUG(llvm::dbgs() << "[emitTransposeOnOperand] Old map: " << operandMap + << "\n"); + SmallVector mapResults = llvm::to_vector(operandMap.getResults()); + applyPermutationToVector(mapResults, permutation); + AffineMap newMap = + AffineMap::get(operandMap.getNumDims(), operandMap.getNumSymbols(), + mapResults, contractOp.getContext()); + LLVM_DEBUG(llvm::dbgs() << "[emitTransposeOnOperand] New map: " << newMap + << "\n"); + indexingMaps[operandNumber] = newMap; + // TODO: We probably cannot update the result in place. + rewriter.modifyOpInPlace(contractOp, [&]() { + contractOp->setOperand(operandNumber, transposeResult); + contractOp.setIndexingMapsAttr( + ArrayAttr::get(contractOp.getContext(), + llvm::to_vector(llvm::map_range( + indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + })))); + }); +} + +FailureOr +makeMinorDimensionsInnerMost(RewriterBase &rewriter, + vector::ContractionOp contractOp, unsigned m, + unsigned n, unsigned k, DataTypeAttr type) { + OpOperand *operandA = &contractOp->getOpOperand(0); + OpOperand *operandB = &contractOp->getOpOperand(1); + OpOperand &operandC = contractOp->getOpOperand(2); + + // C(m,n) += A(m,k) * B(k,n) + // n is expected to be the innermost for C + // k is expected to be the innermost for A + // n is expected to be the innermost for B + auto minorKInCodomainOpA = xsmm::utils::getPosInCodomain( + k, operandA->get(), contractOp, contractOp.getIndexingMapsArray()[0]); + auto minorMInCodomainOpA = xsmm::utils::getPosInCodomain( + m, operandA->get(), contractOp, contractOp.getIndexingMapsArray()[0]); + if (!minorKInCodomainOpA || !minorMInCodomainOpA) { + LLVM_DEBUG( + llvm::dbgs() + << "[makeMinorDimensionsInnerMost] did not find minor dims for A\n"); + return failure(); + } + + auto minorNInCodomainOpB = xsmm::utils::getPosInCodomain( + n, operandB->get(), contractOp, contractOp.getIndexingMapsArray()[1]); + auto minorKInCodomainOpB = xsmm::utils::getPosInCodomain( + k, operandB->get(), contractOp, contractOp.getIndexingMapsArray()[1]); + if (!minorNInCodomainOpB || !minorKInCodomainOpB) { + LLVM_DEBUG( + llvm::dbgs() + << "[makeMinorDimensionsInnerMost] did not find minor dims for B\n"); + return failure(); + } + + auto minorNInCodomainOpC = xsmm::utils::getPosInCodomain( + n, operandC.get(), contractOp, contractOp.getIndexingMapsArray()[2]); + auto minorMInCodomainOpC = xsmm::utils::getPosInCodomain( + m, operandC.get(), contractOp, 
contractOp.getIndexingMapsArray()[2]); + if (!minorNInCodomainOpC || !minorMInCodomainOpC) { + LLVM_DEBUG( + llvm::dbgs() + << "[makeMinorDimensionsInnerMost] did not find minor dims for C\n"); + return failure(); + } + + if (!isInnerMostDim(&operandC, *minorNInCodomainOpC, contractOp, type, 2)) { + LLVM_DEBUG(llvm::dbgs() + << "[makeMinorDimensionsInnerMost] emit transpose for C\n"); + assert( + isInnerMostDim(&operandC, *minorMInCodomainOpC, contractOp, type, 2)); + if (isInnerMostDim(operandA, *minorKInCodomainOpA, contractOp, type, 0)) { + emitTransposeOnOperand(rewriter, contractOp, operandA->get(), + *minorKInCodomainOpA, *minorMInCodomainOpA, 0); + } + if (isInnerMostDim(operandB, *minorNInCodomainOpB, contractOp, type, 1)) { + emitTransposeOnOperand(rewriter, contractOp, operandB->get(), + *minorNInCodomainOpB, *minorKInCodomainOpB, 1); + } + // Avoid transpose on the output by swapping A and B. + OpOperand *operandA = &contractOp->getOpOperand(0); + OpOperand *operandB = &contractOp->getOpOperand(1); + SmallVector indexingMaps = contractOp.getIndexingMapsArray(); + std::swap(indexingMaps[0], indexingMaps[1]); + rewriter.modifyOpInPlace(contractOp, [&]() { + Value operandATmp = operandA->get(); + contractOp->setOperand(0, operandB->get()); + contractOp->setOperand(1, operandATmp); + contractOp.setIndexingMapsAttr( + ArrayAttr::get(contractOp.getContext(), + llvm::to_vector(llvm::map_range( + indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + })))); + }); + return contractOp; + } + + if (!isInnerMostDim(operandA, *minorKInCodomainOpA, contractOp, type, 0)) { + LLVM_DEBUG(llvm::dbgs() + << "[makeMinorDimensionsInnerMost] emit transpose for A\n"); + assert(isInnerMostDim(operandA, *minorMInCodomainOpA, contractOp, type, 0)); + emitTransposeOnOperand(rewriter, contractOp, operandA->get(), + *minorKInCodomainOpA, *minorMInCodomainOpA, 0); + } + if (!isInnerMostDim(operandB, *minorNInCodomainOpB, contractOp, type, 1)) { + LLVM_DEBUG(llvm::dbgs() + << "[makeMinorDimensionsInnerMost] emit transpose for B\n"); + assert(isInnerMostDim(operandB, *minorKInCodomainOpB, contractOp, type, 1)); + emitTransposeOnOperand(rewriter, contractOp, operandB->get(), + *minorKInCodomainOpB, *minorNInCodomainOpB, 1); + } + return contractOp; +} + +bool WithInputs(PatternRewriter &rewriter, Operation *op, + SmallVector> operations, + SmallVector &inputs, SmallVector &opChain) { + for (size_t i = 0; i < operations.size(); i++) { + auto input = op->getOperand(i); + if (!operations[i](input.getDefiningOp())) + return false; + if (input.getDefiningOp()->getOperand(0).getDefiningOp() != nullptr) { + if (!(isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()) || + isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()) || + isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()) || + isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()))) + return false; + } + inputs.push_back(input.getDefiningOp()->getOpOperand(0).get()); + opChain.push_back(input.getDefiningOp()); + } + return true; +} + +bool WithOutput(Operation *op, std::function operation, + SmallVector &output, SmallVector &opChain) { + // Check on the inner chain of operations in the right order. 
+ // Make sure all operands are used and chained + for (auto use : op->getResult(0).getUsers()) { + if (use != op && operation(use)) { + if (!isa(use->getOperand(1).getDefiningOp())) + return false; + output.push_back(use->getOpOperand(1).get()); + opChain.push_back(use); + return true; + } + } + return false; +} + +bool WithOps(Region *region, Operation *op, Operation *currentOp, + SmallVector> operations, + SmallVector &opChain) { + auto &block = region->front(); + + llvm::SmallSetVector chainedValues; + + auto start = block.begin(); + for (auto opItr = block.begin(); opItr != block.end(); opItr++) { + if (&*opItr != currentOp || !operations[0](&*opItr)) + continue; + start = opItr; + opChain.push_back(&*opItr); + break; + } + // Check on the inner chain of operations in the right order. + // Make sure all operands are used and chained + for (auto check : operations) { + Operation *innerOp = &*start; + // Must be right op in right order + if (start == block.end() || !check(innerOp)) + return false; + start++; + // At least one operand must come from args or a previous op + bool consumesValueFromChain = false; + if (chainedValues.empty()) { + consumesValueFromChain = true; + } else { + for (auto operand : innerOp->getOperands()) { + if (chainedValues.contains(operand)) { + chainedValues.remove(operand); + consumesValueFromChain = true; + } + } + } + + // Operation isn't in the chain + if (!consumesValueFromChain) + return false; + + for (auto ret : innerOp->getResults()) { + chainedValues.insert(ret); + } + } + return true; +} + +bool isTwoDTransposeOp(vector::TransposeOp transposeOp) { + if (!(dyn_cast(transposeOp.getOperand().getType()).getRank() == + 2 && + dyn_cast(transposeOp.getResult().getType()).getRank() == + 2) || + !(isa(transposeOp->getParentOp()) && + dyn_cast(transposeOp->getParentOp()).getRank() == 2)) + return false; + return true; +} + +// Extract the operands to be used in the function call. For each memref operand +// extract the aligned pointer and the offset. +SmallVector getOperands(OpBuilder &builder, Location loc, + ValueRange operands, IntegerAttr dataTypeAttr) { + SmallVector res; + IntegerType integer64 = IntegerType::get(builder.getContext(), 64); + res.push_back( + builder.create(loc, integer64, dataTypeAttr)); + + for (Value operand : operands) { + auto memrefType = dyn_cast(operand.getType()); + if (!memrefType) { + res.push_back(operand); + continue; + } + auto [ptr, offset] = ::mlir::utils::getPtrAndOffset(builder, operand, loc); + res.push_back(ptr); + res.push_back(offset); + } + return res; +} + +SmallVector extractInvokeOperandTypes(OpBuilder &builder, + ValueRange operands) { + SmallVector results; + // One extra operand for datatype + IntegerType integer64 = IntegerType::get(builder.getContext(), 64); + results.push_back(integer64); + for (Value operand : operands) { + Type operandType = operand.getType(); + if (auto memrefType = dyn_cast(operandType)) { + // TODO: non-POD will require an LLVMTypeConverter. + Type basePtrType = LLVM::LLVMPointerType::get(builder.getContext()); + results.push_back(basePtrType); + results.push_back(builder.getIndexType()); // offset + } else { + results.push_back(operand.getType()); + } + } + return results; +} + +int64_t getOredFlags(ArrayAttr flags) { + int64_t oredFlag = 0; + for (auto flag : flags) { + int64_t intAttr = dyn_cast(flag).getInt(); + // LIBXSMM is col-major, swap A and B flags. 
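// Editorial note: because the runtime computes the col-major equivalent
// (effectively C^T = B^T * A^T), operands A and B trade places, so a VNNI
// layout flag attached to A must be reported to LIBXSMM as applying to B and
// vice versa. The remapping below only touches VNNI_A/VNNI_B; all other flag
// bits are OR-ed through unchanged.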
+ if (auto gemmFlag = dyn_cast_or_null(flag)) { + if (gemmFlag.getValue() == GemmFlags::VNNI_A) + intAttr = static_cast(GemmFlags::VNNI_B); + if (gemmFlag.getValue() == GemmFlags::VNNI_B) + intAttr = static_cast(GemmFlags::VNNI_A); + } + oredFlag |= intAttr; + } + return oredFlag; +} + +func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc, + ArrayRef dispatchOperands, + ArrayRef dispatchOperandTypes, + ModuleOp module, FlatSymbolRefAttr fnName) { + auto libFnType = rewriter.getFunctionType( + dispatchOperandTypes, IntegerType::get(rewriter.getContext(), 64)); + + if (!module.lookupSymbol(fnName.getAttr())) { + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + func::FuncOp funcOp = + rewriter.create(loc, fnName.getValue(), libFnType); + funcOp.setPrivate(); + } + + func::CallOp call = rewriter.create( + loc, fnName.getValue(), IntegerType::get(rewriter.getContext(), 64), + dispatchOperands); + return call; +} + +func::CallOp buildInvokeCall(RewriterBase &rewriter, Location loc, + ModuleOp module, SmallVector operandRange, + StringRef invokeName, DataTypeAttr dtype) { + auto libFnType = rewriter.getFunctionType( + xsmm::utils::extractInvokeOperandTypes(rewriter, operandRange), {}); + FlatSymbolRefAttr fnName = + SymbolRefAttr::get(rewriter.getContext(), invokeName); + + if (!module.lookupSymbol(fnName)) { + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + func::FuncOp funcOp = + rewriter.create(loc, invokeName, libFnType); + funcOp.setPrivate(); + } + + func::CallOp call = rewriter.create( + loc, fnName, TypeRange(), + xsmm::utils::getOperands(rewriter, loc, operandRange, dtype)); + + return call; +} + +std::pair +buildBrgemmCalls(PatternRewriter &rewriter, Operation *op, ValueRange inputs, + xsmm::BrgemmInfo brgemmInfo, SmallVector flags) { + assert(inputs.size() == 3 && "Expects three inputs for BRGEMM call"); + auto m = brgemmInfo.m; + auto n = brgemmInfo.n; + auto k = brgemmInfo.k; + auto batch = brgemmInfo.batch; + int64_t lda = brgemmInfo.lda; + int64_t ldb = brgemmInfo.ldb; + int64_t ldc = brgemmInfo.ldc; + int64_t strideA = brgemmInfo.strideA; + int64_t strideB = brgemmInfo.strideB; + auto loc = op->getLoc(); + auto dtype = xsmm::utils::getDataType(rewriter, inputs[0].getType()); + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + SmallVector dispatchOperands; + SmallVector dispatchOperandTypes; + // Dispatch the data type. + dispatchOperands.push_back(rewriter.create( + loc, integer64, cast(dtype))); + dispatchOperandTypes.push_back(integer64); + + ArrayAttr brgemmFlags = rewriter.getArrayAttr(flags); + SmallVector invokeOperands; + std::string dispatchName = "xsmm_gemm_dispatch"; + std::string invokeName = "xsmm_gemm_invoke"; + + if (batch != 0) { + dispatchName = "xsmm_brgemm_dispatch"; + invokeName = "xsmm_brgemm_invoke"; + } + + auto dims = SmallVector{m, n, k, lda, ldb, ldc}; + if (batch != 0) { + dims.append({strideA, strideB}); + } + for (size_t idx = 0; idx < dims.size(); idx++) { + dispatchOperands.push_back(rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, dims[idx]))); + dispatchOperandTypes.push_back(integer64); + } + // Dispatch the flags. Pass to the library the already ored-flag to + // avoid changing the interface every time we add a new flag. 
Flags + // are assumed to be verified before (i.e., op verifier). + int64_t oredFlag = xsmm::utils::getOredFlags(brgemmFlags); + + dispatchOperands.push_back(rewriter.create( + loc, integer64, IntegerAttr::get(rewriter.getI64Type(), oredFlag))); + dispatchOperandTypes.push_back(integer64); + ModuleOp module = op->getParentOfType(); + auto dispatched = xsmm::utils::buildDispatchCall( + rewriter, loc, dispatchOperands, dispatchOperandTypes, module, + SymbolRefAttr::get(op->getContext(), dispatchName)); + SmallVector operandRange; + operandRange.push_back(dispatched.getResult(0)); + for (auto operand : inputs) { + operandRange.push_back(operand); + } + if (batch != 0) { + Value batchDim = rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, batch)); + operandRange.push_back(batchDim); + } + auto invokeCall = xsmm::utils::buildInvokeCall( + rewriter, loc, module, operandRange, invokeName, dtype); + return std::make_pair(&*dispatched, &*invokeCall); +} + +} // namespace utils +} // namespace xsmm +} // namespace mlir diff --git a/third_party/cpu/lib/Xsmm/XsmmUtils.h b/third_party/cpu/lib/Xsmm/XsmmUtils.h new file mode 100644 index 000000000000..5d19110f4ae4 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/XsmmUtils.h @@ -0,0 +1,157 @@ +//===- XsmmUtils.h - --------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_DIALECT_XSMM_XSMMUTILS_H +#define TPP_DIALECT_XSMM_XSMMUTILS_H + +#include "cpu/include/Xsmm/XsmmEnum.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include + +using namespace mlir; + +namespace mlir { +class Type; +class RewriterBase; +class ArrayAttr; +class Operation; +class ValueRange; +class Attribute; + +namespace xsmm { + +struct BrgemmInfo { + int64_t m; + int64_t n; + int64_t k; + int64_t batch; + + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t strideA; + int64_t strideB; + + bool isVnni = false; +}; + +template +std::function FuncType = + [](Operation *op) { return isa(op); }; + +class UnaryKindAttr; + +struct UnaryInfo { + unsigned m; + unsigned n; + + int64_t ldi; + int64_t ldo; +}; + +struct BinaryInfo { + unsigned m; + unsigned n; + + int64_t ldiLhs; + int64_t ldiRhs; + int64_t ldo; +}; + +namespace utils { + +DataTypeAttr getDataType(RewriterBase &rewriter, Type type); + +FailureOr getUnaryInfo(Value input, Value output, + UnaryFlags inputFlag); + +FailureOr getBinaryInfo(Value lhs, BinaryFlags lhsFlag, Value rhs, + BinaryFlags rhsFlag, Value output); + +// Compute the broadcasting flags for 'inputType' based 'outputType'. +// Rules for broadcasting follows Numpy-style, and are only allowed in +// 'inputType'. see: https://numpy.org/doc/stable/user/basics.broadcasting.html +FailureOr getUnaryFlags(Type inputType, Type outputType); + +// Compute the broadcasting flags for 'operandType' based on 'outputType'. 
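// Editorial note: a non-shaped or rank-0 operand yields BCAST_SCALAR_IN_0 (LHS)
// or BCAST_SCALAR_IN_1 (RHS); otherwise the operand shape is matched against
// the rank-2 output shape to pick between no broadcast, scalar, row, or column
// broadcast (see BCastType in getBinFlags).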
+enum class OperandPos { LHS = 0, RHS = 1 }; +FailureOr getBinFlags(ArrayRef shapeOutput, + ArrayRef shapeOperand, + OperandPos operandNumber); +FailureOr getBinaryFlags(Type operandType, Type outputType, + OperandPos operandNumber); + +FailureOr getBinaryFlagsVectorType(Type operandType, + Type outputType, + OperandPos operandNumber); + +FailureOr getLeadingDim(Type type, size_t pos = 0); + +int64_t getOredFlags(ArrayAttr flags); + +SmallVector extractInvokeOperandTypes(OpBuilder &builder, + ValueRange operands); +SmallVector getOperands(OpBuilder &builder, Location loc, + ValueRange operands, IntegerAttr dataTypeAttr); + +FailureOr isMappableToBrgemm(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + SmallVector &inputs, + SmallVector &output, + ArrayRef indexingMap); + +FailureOr +makeMinorDimensionsInnerMost(RewriterBase &rewriter, + vector::ContractionOp contractOp, unsigned m, + unsigned n, unsigned k, IntegerAttr type); +std::optional getPosInCodomain(unsigned dim, Value operand, + vector::ContractionOp contractOp, + AffineMap map); +FailureOr +checkAccess(PatternRewriter &rewriter, vector::ContractionOp contractOp, + unsigned m, unsigned n, SmallVector kVector, + std::optional batchPos, SmallVector inputs, + ArrayRef indexingMap); + +bool WithInputs(PatternRewriter &rewriter, Operation *op, + SmallVector> operations, + SmallVector &inputs, SmallVector &opChain); +bool WithOutput(Operation *op, std::function operation, + SmallVector &output, SmallVector &opChain); +bool WithOps(Region *region, Operation *op, Operation *currentOp, + SmallVector> operations, + SmallVector &opChain); + +bool isTwoDTransposeOp(vector::TransposeOp transposeOp); + +func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc, + ArrayRef dispatchOperands, + ArrayRef dispatchOperandTypes, + ModuleOp module, FlatSymbolRefAttr fnName); +func::CallOp buildInvokeCall(RewriterBase &rewriter, Location loc, + ModuleOp module, SmallVector operands, + StringRef invokeName, DataTypeAttr dtype); + +// Create a pair of XSMM dispatch and invoke (BR)GEMM calls. +std::pair +buildBrgemmCalls(PatternRewriter &rewriter, Operation *op, ValueRange inputs, + xsmm::BrgemmInfo brgemmInfo, SmallVector flags); + +} // namespace utils +} // namespace xsmm +} // namespace mlir + +#endif // TPP_DIALECT_XSMM_XSMMUTILS_H diff --git a/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp new file mode 100644 index 000000000000..62b784f1f154 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp @@ -0,0 +1,516 @@ +//===- CRunnerUtils.cpp - Utils for MLIR execution ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements basic functions to manipulate structured MLIR types at +// runtime. Entities in this file are meant to be retargetable, including on +// targets without a C++ runtime, and must be kept C compatible. +// +//===----------------------------------------------------------------------===// + +#include "XsmmRunnerUtils.h" +#include "libxsmm.h" // NOLINT [build/include_subdir] +#include "libxsmm_utils.h" + +// Helper function prototypes. 
+static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape, + FILE *outfile = stderr); +static void printXsmmStruct(const libxsmm_meltw_unary_shape &unaryShape, + FILE *outfile = stderr); +static void printXsmmStruct(const libxsmm_meltw_binary_shape &binaryShape, + FILE *outfile = stderr); +static void printXsmmStruct(const libxsmm_gemm_batch_reduce_config &brgemmShape, + FILE *outfile = stderr); + +static bool hasImplicitComputeDtypeUnary(const libxsmm_meltw_unary_type dtype) { + switch (dtype) { + // Zero + case LIBXSMM_MELTW_TYPE_UNARY_XOR: + // Copy + case LIBXSMM_MELTW_TYPE_UNARY_IDENTITY: + // Transpose + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT: + // VNNI2 + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI2: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI2_TO_VNNI2T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI2T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI2_PAD: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADM_MOD2: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADN_MOD2: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADNM_MOD2: + // VNNI4 + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI4_TO_VNNI4T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI4T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI4_PAD: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADM_MOD4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADN_MOD4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADNM_MOD4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI4_TO_NORM: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI4_TO_VNNI2: + return true; + default: + return false; + } +} + +namespace { + +void *get_base_ptr(const libxsmm_datatype dType, void *alignedPtr, + int64_t offset) { + if (dType == LIBXSMM_DATATYPE_F32) { + float *base_ptr = (float *)alignedPtr + offset; + return (void *)base_ptr; + } else if (dType == LIBXSMM_DATATYPE_BF16) { + bf16 *base_ptr = (bf16 *)alignedPtr + offset; + return (void *)base_ptr; + } + fprintf(stderr, "Unhandled data type in get_data_pointer_from_memref_desc:%d", + dType); + return nullptr; +} + +} // namespace + +extern "C" void xsmm_gemm_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offsetA, + void *alignedPtrB, int64_t offsetB, + void *alignedPtrC, int64_t offsetC) { + libxsmm_xmmfunction sgemm; + libxsmm_gemm_param gemm_param; + + // LIBXSMM col-major change A with B. + gemm_param.a.primary = get_base_ptr(dType, alignedPtrB, offsetB); + gemm_param.b.primary = get_base_ptr(dType, alignedPtrA, offsetA); + gemm_param.c.primary = get_base_ptr(dType, alignedPtrC, offsetC); + + sgemm.gemm = reinterpret_cast(addr); + sgemm.gemm(&gemm_param); +} + +extern "C" int64_t xsmm_gemm_dispatch(const libxsmm_datatype dtype, int64_t m, + int64_t n, int64_t k, int64_t lda, + int64_t ldb, int64_t ldc, + const libxsmm_gemm_flags flags) { + // std::cout << "lda: " << lda << "\n"; + // std::cout << "ldb: " << ldb << "\n"; + // std::cout << "ldc: " << ldc << "\n"; + + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "k: " << k << "\n"; + + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_flags = flags; + libxsmm_bitfield l_prefetch_flags = 0; + + // See: + // https://stackoverflow.com/questions/56043539/cublassgemm-row-major-multiplication + // LIBXSMM col-major change m with n. 
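// Editorial note: a row-major C(m,n) = A(m,k) * B(k,n) is computed as the
// col-major product C^T(n,m) = B^T(n,k) * A^T(k,m), so m/n and lda/ldb are
// swapped here at dispatch time while the A/B pointers are swapped at invoke
// time (see xsmm_gemm_invoke above); ldc stays with the output.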
+ l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb; + l_shape.ldb = lda; + l_shape.ldc = ldc; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + l_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype; + + auto sgemm = libxsmm_dispatch_gemm(l_shape, l_flags, l_prefetch_flags); + if (!sgemm) { + fprintf(stderr, "failed to generate matmul func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + printXsmmStruct(l_shape); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" int64_t +xsmm_unary_dispatch(const libxsmm_meltw_unary_type op_type, + const libxsmm_datatype dtype, int64_t m, int64_t n, + int64_t ldi, int64_t ldo, + const libxsmm_meltw_unary_flags unary_flags) { + // std::cout << "ldi: " << ldi << "\n"; + // std::cout << "ldo: " << ldo << "\n"; + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "type: " << type << "\n"; + // std::cout << "bcast_type: " << bcast_type << "\n"; + + libxsmm_meltw_unary_shape unary_shape; + // Row major to col major swap m with n. + unary_shape.m = static_cast(n); + unary_shape.n = static_cast(m); + unary_shape.in0_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + // Copy and Zero should remain in BF16 to avoid useless up/down casts + auto force_fp32 = (dtype == LIBXSMM_DATATYPE_BF16 && + !hasImplicitComputeDtypeUnary(op_type)); + unary_shape.comp_type = force_fp32 ? LIBXSMM_DATATYPE_F32 : dtype; + unary_shape.out_type = dtype; + unary_shape.ldi = static_cast(ldi); + unary_shape.ldo = static_cast(ldo); + + libxsmm_meltwfunction_unary kernel = + libxsmm_dispatch_meltw_unary(op_type, unary_shape, unary_flags); + if (!kernel) { + fprintf(stderr, "failed to generate unary func\n"); + fprintf(stderr, "op_type: %u\n", op_type); + fprintf(stderr, "flags: %u\n", unary_flags); + printXsmmStruct(unary_shape); + exit(-1); + } + + return reinterpret_cast(kernel); +} + +extern "C" int64_t +xsmm_binary_dispatch(const libxsmm_meltw_binary_type op_type, + const libxsmm_datatype dtype, int64_t m, int64_t n, + int64_t ldiLhs, int64_t ldiRhs, int64_t ldo, + const libxsmm_meltw_binary_flags flags) { + libxsmm_meltw_binary_shape binary_shape; + // Row major to col major swap m with n. + binary_shape.m = static_cast(n); + binary_shape.n = static_cast(m); + binary_shape.in0_type = dtype; + binary_shape.in1_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + binary_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? 
LIBXSMM_DATATYPE_F32 : dtype; + binary_shape.out_type = dtype; + binary_shape.ldi = static_cast(ldiLhs); + binary_shape.ldi2 = static_cast(ldiRhs); + binary_shape.ldo = static_cast(ldo); + + libxsmm_meltwfunction_binary kernel = + libxsmm_dispatch_meltw_binary(op_type, binary_shape, flags); + if (!kernel) { + fprintf(stderr, "failed to generate binary func\n"); + fprintf(stderr, "op_type: %u\n", op_type); + fprintf(stderr, "flags: %u\n", flags); + printXsmmStruct(binary_shape); + exit(-1); + } + + return reinterpret_cast(kernel); +} + +extern "C" int64_t xsmm_intel_amx_tile_config_dispatch( + const libxsmm_datatype dtype, int64_t m, int64_t n, int64_t k, int64_t lda, + int64_t ldb, int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags flags) { + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_cfg_flags = flags; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb; + l_shape.ldb = lda; + l_shape.ldc = ldc; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = dtype; + l_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype; + + auto sgemm = libxsmm_dispatch_tilecfg_gemm(l_shape, l_cfg_flags); + if (!sgemm) { + fprintf(stderr, "failed to generate tileconfig func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + fprintf(stderr, "flags: %u\n", flags); + printXsmmStruct(l_shape); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" void xsmm_unary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrIn, int64_t offsetIn, + void *alignedPtrOut, int64_t offsetOut) { + libxsmm_meltw_unary_param param; + + param.in.primary = get_base_ptr(dType, alignedPtrIn, offsetIn); + param.out.primary = get_base_ptr(dType, alignedPtrOut, offsetOut); + + libxsmm_meltwfunction_unary kernel = + reinterpret_cast(addr); + kernel(¶m); +} + +extern "C" void xsmm_binary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrLhs, int64_t offsetLhs, + void *alignedPtrRhs, int64_t offsetRhs, + void *alignedPtrOut, int64_t offsetOut) { + libxsmm_meltw_binary_param param; + + param.in0.primary = get_base_ptr(dType, alignedPtrLhs, offsetLhs); + param.in1.primary = get_base_ptr(dType, alignedPtrRhs, offsetRhs); + param.out.primary = get_base_ptr(dType, alignedPtrOut, offsetOut); + + libxsmm_meltwfunction_binary kernel = + reinterpret_cast(addr); + kernel(¶m); +} + +extern "C" void xsmm_unary_scalar_invoke(const libxsmm_datatype dType, + int64_t addr, float input, + void *alignedOut, int64_t offsetOut) { + libxsmm_meltwfunction_unary kernel = + reinterpret_cast(addr); + libxsmm_meltw_unary_param param; + + param.in.primary = (void *)&input; + param.out.primary = get_base_ptr(dType, alignedOut, offsetOut); + kernel(¶m); +} + +extern "C" void xsmm_brgemm_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offsetA, + void *alignedPtrB, int64_t offsetB, + void *alignedPtrC, int64_t offsetC, + int64_t numBatches) { + libxsmm_xmmfunction sgemm; + libxsmm_gemm_param gemm_param; + + unsigned long long numBatchesVar = numBatches; + gemm_param.op.tertiary = (void *)&numBatchesVar; + + // LIBXSMM col-major change A with B. 
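// Editorial note: the (aligned pointer, element offset) pairs received here
// mirror what the compiler-side getOperands() extracts from each memref
// operand; get_base_ptr advances the base by the offset in elements of the
// given data type before handing the addresses to the BRGEMM kernel, again
// with A and B swapped for col-major.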
+ gemm_param.a.primary = get_base_ptr(dType, alignedPtrB, offsetB); + gemm_param.b.primary = get_base_ptr(dType, alignedPtrA, offsetA); + gemm_param.c.primary = get_base_ptr(dType, alignedPtrC, offsetC); + + sgemm.gemm = reinterpret_cast(addr); + sgemm.gemm(&gemm_param); +} + +extern "C" int64_t xsmm_brgemm_dispatch(const libxsmm_datatype dtype, int64_t m, + int64_t n, int64_t k, int64_t lda, + int64_t ldb, int64_t ldc, + int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags flags) { + // std::cout << "lda: " << lda << "\n"; + // std::cout << "lbd: " << ldb << "\n"; + // std::cout << "ldc: " << ldc << "\n"; + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "k: " << k << "\n"; + + libxsmm_blasint lda_int = lda; + libxsmm_blasint ldb_int = ldb; + libxsmm_blasint ldc_int = ldc; + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_flags = flags; + libxsmm_bitfield l_prefetch_flags = 0; + libxsmm_gemm_batch_reduce_config l_brconfig; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb_int; + l_shape.ldb = lda_int; + l_shape.ldc = ldc_int; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + l_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype; + l_brconfig.br_type = LIBXSMM_GEMM_BATCH_REDUCE_STRIDE; + auto typeSize = dtype == LIBXSMM_DATATYPE_F32 ? sizeof(float) : sizeof(bf16); + l_brconfig.br_stride_a_hint = stride_b * typeSize; + l_brconfig.br_stride_b_hint = stride_a * typeSize; + l_brconfig.br_unroll_hint = 0; + + auto sgemm = + libxsmm_dispatch_brgemm(l_shape, l_flags, l_prefetch_flags, l_brconfig); + if (!sgemm) { + fprintf(stderr, "failed to generate brgemm func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + printXsmmStruct(l_shape); + printXsmmStruct(l_brconfig); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" void xsmm_fused_brgemm_invoke(const libxsmm_datatype dType, + int64_t addr, void *alignedPtrA, + int64_t offsetA, void *alignedPtrB, + int64_t offsetB, void *alignedPtrC, + int64_t offsetC, void *alignedPtrD, + int64_t offsetD, int64_t numBatches) { + libxsmm_xmmfunction sgemm; + libxsmm_gemm_ext_param gemm_param; + + unsigned long long numBatchesVar = numBatches; + gemm_param.op.tertiary = (void *)&numBatchesVar; + + // LIBXSMM col-major change A with B. 
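// Editorial note: presumably the extra D operand feeds the binary post-op
// configured via l_postops in xsmm_fused_brgemm_dispatch (e.g. a fused bias
// add); it is passed through unswapped since it applies on the output side.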
+ gemm_param.a.primary = get_base_ptr(dType, alignedPtrB, offsetB); + gemm_param.b.primary = get_base_ptr(dType, alignedPtrA, offsetA); + gemm_param.c.primary = get_base_ptr(dType, alignedPtrC, offsetC); + gemm_param.d.primary = get_base_ptr(dType, alignedPtrD, offsetD); + + sgemm.gemm_ext = reinterpret_cast(addr); + sgemm.gemm_ext(&gemm_param); +} + +extern "C" int64_t +xsmm_fused_brgemm_dispatch(const libxsmm_datatype data_type, int64_t m, + int64_t n, int64_t k, int64_t lda, int64_t ldb, + int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags gemm_flags, + const libxsmm_meltw_unary_flags unary_flags, + const libxsmm_meltw_unary_type unary_op_type, + const libxsmm_meltw_binary_flags binary_flags, + const libxsmm_meltw_binary_type binary_op_type) { + // std::cout << "lda: " << lda << "\n"; + // std::cout << "lbd: " << ldb << "\n"; + // std::cout << "ldc: " << ldc << "\n"; + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "k: " << k << "\n"; + + libxsmm_blasint lda_int = lda; + libxsmm_blasint ldb_int = ldb; + libxsmm_blasint ldc_int = ldc; + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_flags = gemm_flags; + libxsmm_bitfield l_prefetch_flags = 0; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb_int; + l_shape.ldb = lda_int; + l_shape.ldc = ldc_int; + l_shape.a_in_type = data_type; + l_shape.b_in_type = data_type; + l_shape.out_type = data_type; + // Retarget computation type from bf16 to f32 due to missing hardware support. + l_shape.comp_type = + data_type == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : data_type; + + libxsmm_gemm_batch_reduce_config l_brconfig; + l_brconfig.br_type = LIBXSMM_GEMM_BATCH_REDUCE_STRIDE; + auto typeSize = + data_type == LIBXSMM_DATATYPE_F32 ? 
sizeof(float) : sizeof(bf16); + l_brconfig.br_stride_a_hint = stride_b * typeSize; + l_brconfig.br_stride_b_hint = stride_a * typeSize; + l_brconfig.br_unroll_hint = 0; + + libxsmm_gemm_ext_unary_argops l_argops; + memset(&l_argops, 0, sizeof(libxsmm_gemm_ext_unary_argops)); + l_argops.cp_unary_flags = LIBXSMM_MELTW_FLAG_UNARY_NONE; + l_argops.ldcp = ldc; + l_argops.cp_unary_type = unary_op_type; + + libxsmm_gemm_ext_binary_postops l_postops; + memset(&l_postops, 0, sizeof(libxsmm_gemm_ext_binary_postops)); + l_postops.d_in_type = data_type; + + l_postops.d_binary_flags = binary_flags; + l_postops.d_binary_type = binary_op_type; + l_postops.ldd = ldc; + + auto sgemm = libxsmm_dispatch_brgemm_ext(l_shape, l_flags, l_prefetch_flags, + l_brconfig, l_argops, l_postops); + if (!sgemm) { + fprintf(stderr, "failed to generate fused brgemm func\n"); + fprintf(stderr, "data_type: %u\n", data_type); + printXsmmStruct(l_shape); + printXsmmStruct(l_brconfig); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *tileState, int64_t offset) { + libxsmm_xmmfunction cfg_tr; + + libxsmm_tilecfg_state *l_tilestate = + reinterpret_cast(tileState); + + cfg_tr.tilecfg = reinterpret_cast(addr); + cfg_tr.tilecfg(l_tilestate); +} + +static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape, + FILE *outfile) { + fprintf(outfile, "M: %d\n", gemmShape.m); + fprintf(outfile, "N: %d\n", gemmShape.n); + fprintf(outfile, "K: %d\n", gemmShape.k); + fprintf(outfile, "lda: %d\n", gemmShape.lda); + fprintf(outfile, "ldb: %d\n", gemmShape.ldb); + fprintf(outfile, "ldc: %d\n", gemmShape.ldc); + fprintf(outfile, "a_in_type: %d\n", gemmShape.a_in_type); + fprintf(outfile, "b_in_type: %d\n", gemmShape.b_in_type); + fprintf(outfile, "comp_type: %d\n", gemmShape.comp_type); + fprintf(outfile, "out_type: %d\n", gemmShape.out_type); +} + +static void printXsmmStruct(const libxsmm_meltw_unary_shape &unaryShape, + FILE *outfile) { + fprintf(outfile, "M: %d\n", unaryShape.m); + fprintf(outfile, "N: %d\n", unaryShape.n); + fprintf(outfile, "in0_type: %d\n", unaryShape.in0_type); + fprintf(outfile, "comp_type: %d\n", unaryShape.comp_type); + fprintf(outfile, "out_type: %d\n", unaryShape.out_type); + fprintf(outfile, "ldi: %d\n", unaryShape.ldi); + fprintf(outfile, "ldo: %d\n", unaryShape.ldo); +} + +static void printXsmmStruct(const libxsmm_meltw_binary_shape &binaryShape, + FILE *outfile) { + fprintf(outfile, "M: %d\n", binaryShape.m); + fprintf(outfile, "N: %d\n", binaryShape.n); + fprintf(outfile, "in0_type: %d\n", binaryShape.in0_type); + fprintf(outfile, "in1_type: %d\n", binaryShape.in1_type); + fprintf(outfile, "comp_type: %d\n", binaryShape.comp_type); + fprintf(outfile, "out_type: %d\n", binaryShape.out_type); + fprintf(outfile, "ldi: %d\n", binaryShape.ldi); + fprintf(outfile, "ldi2: %d\n", binaryShape.ldi2); + fprintf(outfile, "ldo: %d\n", binaryShape.ldo); +} + +static void +printXsmmStruct(const libxsmm_gemm_batch_reduce_config &brgemmConfig, + FILE *outfile) { + fprintf(outfile, "br_type: %d\n", brgemmConfig.br_type); + fprintf(outfile, "br_stride_a_hint: %d\n", brgemmConfig.br_stride_a_hint); + fprintf(outfile, "br_stride_b_hint: %d\n", brgemmConfig.br_stride_b_hint); + fprintf(outfile, "br_unroll_hint: %d\n", brgemmConfig.br_unroll_hint); +} diff --git a/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h new file mode 100644 index 
000000000000..3d5048188efc --- /dev/null +++ b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h @@ -0,0 +1,85 @@ +//===- CRunnerUtils.h - Utils for debugging MLIR execution ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares basic classes and functions to manipulate structured MLIR +// types at runtime. Entities in this file must be compliant with C++11 and be +// retargetable, including on targets without a C++ runtime. +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_EXECUTIONENGINE_CRUNNERUTILS_H +#define TPP_EXECUTIONENGINE_CRUNNERUTILS_H + +#include "libxsmm.h" +#include "mlir/ExecutionEngine/Float16bits.h" +#include "mlir/ExecutionEngine/RunnerUtils.h" + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t +xsmm_gemm_dispatch(const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, const libxsmm_gemm_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_unary_dispatch( + const libxsmm_meltw_unary_type, const libxsmm_datatype, int64_t, int64_t, + int64_t, int64_t, const libxsmm_meltw_unary_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_binary_dispatch( + const libxsmm_meltw_binary_type, const libxsmm_datatype, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_meltw_binary_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_brgemm_dispatch( + const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_gemm_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_fused_brgemm_dispatch( + const libxsmm_datatype data_type, int64_t m, int64_t n, int64_t k, + int64_t lda, int64_t ldb, int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags gemm_flags, + const libxsmm_meltw_unary_flags unary_flags, + const libxsmm_meltw_unary_type unary_op_type, + const libxsmm_meltw_binary_flags binary_flags, + const libxsmm_meltw_binary_type binary_op_type); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_intel_amx_tile_config_dispatch( + const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_gemm_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_gemm_invoke(const libxsmm_datatype dType, int64_t addr, void *alignedPtrA, + int64_t offsetA, void *alignedPtrB, int64_t offsetB, + void *alignedPtrC, int64_t offsetC); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_unary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrIn, int64_t offsetIn, void *alignedPtrOut, + int64_t offsetOut); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_unary_scalar_invoke(const libxsmm_datatype, int64_t addr, float scalar, + void *alignedPtrOut, int64_t offsetOut); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_binary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrLhs, int64_t offsetLhs, void *alignedPtrRhs, + int64_t offsetRhs, void *alignedPtrOut, int64_t offsetOut); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_brgemm_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offsetA, void *alignedPtrB, + int64_t offsetB, void *alignedPtrC, int64_t offsetC, + int64_t numBatches); + +extern "C" MLIR_RUNNERUTILS_EXPORT void 
xsmm_fused_brgemm_invoke( + const libxsmm_datatype dType, int64_t addr, void *alignedPtrA, + int64_t offsetA, void *alignedPtrB, int64_t offsetB, void *alignedPtrC, + int64_t offsetC, void *alignedPtrD, int64_t offsetD, int64_t numBatches); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offset); + +#endif // TPP_EXECUTIONENGINE_CRUNNERUTILS_H diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 9db53e1d0c71..a126731408ca 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -2,12 +2,14 @@ #include "TritonCPUToLLVM/Passes.h" #include "TritonCPUTransforms/Passes.h" #include "TritonToTritonCPU/Passes.h" +#include "Xsmm/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/Passes.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" @@ -154,6 +156,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_func_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::createConvertFuncToLLVMPass()); }); + m.def("add_convert_vector_to_xsmm", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertVectorToXsmm()); + }); + m.def("add_expand_strided_metadata", [](mlir::PassManager &pm) { + pm.addPass(mlir::memref::createExpandStridedMetadataPass()); + }); } void init_triton_cpu(py::module &&m) {
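For reference, here is a minimal editorial sketch (not part of the patch) of how the two bindings registered above could be combined into a pass pipeline on the C++ side. It uses only the pass-creation functions that appear in this patch plus standard MLIR PassManager APIs; the helper name runXsmmLowering and the exact include paths are assumptions for illustration.

#include "Xsmm/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Lower vector contractions to XSMM dispatch/invoke calls, then expand the
// strided-metadata ops introduced when extracting aligned pointers and offsets.
static mlir::LogicalResult runXsmmLowering(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::triton::cpu::createConvertVectorToXsmm());
  pm.addPass(mlir::memref::createExpandStridedMetadataPass());
  return pm.run(module);
}

This mirrors the order used by the Python bindings above: add_convert_vector_to_xsmm first, then add_expand_strided_metadata, guarded in the backend by the TRITON_CPU_XSMM=1 / enable_xsmm switch.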