Skip to content

Commit

Permalink
Vector to XSMM (triton-lang#3)
Browse files Browse the repository at this point in the history
Implements lowering pass from vector to XSMM microkernels.
libxsmm is added as an external dependency together with general MLIR infrastructure for handling XSMM code generation and runtime execution.
The XSMM lowering is optional and can be enabled at the JIT step by setting the environment variable TRITON_CPU_XSMM=1.

libxsmm is built as a shared library and linked with XSMM-related libraries. These are also added to the Python infrastructure.
Additionally, general MLIR utilities are imported to allow analysis, code generation and microkernel execution.
Initially, a simple pattern mapping vector contraction to an XSMM kernel is added.
  • Loading branch information
adam-smnk authored and Devjiu committed Nov 13, 2024
1 parent 666b8ae commit b01f1eb
Show file tree
Hide file tree
Showing 23 changed files with 2,658 additions and 2 deletions.
2 changes: 2 additions & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "cpu/include/TritonCPUToLLVM/Passes.h"
#include "cpu/include/TritonCPUTransforms/Passes.h"
#include "cpu/include/TritonToTritonCPU/Passes.h"
#include "cpu/include/Xsmm/Passes.h"
#include "nvidia/include/NVGPUToLLVM/Passes.h"
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
Expand Down Expand Up @@ -74,6 +75,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::cpu::registerTritonCPUTransformsPasses();
mlir::triton::cpu::registerTritonCPUToLLVMPasses();
mlir::triton::cpu::registerTritonOpScalarizeExternalModels(registry);
mlir::triton::cpu::registerTritonCPUXsmmPasses();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
Expand Down
59 changes: 59 additions & 0 deletions cmake/xsmm.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Builds LIBXSMM from source as the shared library target `xsmm`.
#
# Inputs (environment variables, read once at configure time):
#   LIBXSMMROOT  - existing LIBXSMM location (e.g. `make PREFIX=/path/to/libxsmm`).
#   LIBXSMMFETCH - when set to a true value, fetch the pinned LIBXSMM archive
#                  even if LIBXSMMROOT is present.
#
# Results:
#   xsmm              - shared library target carrying its own usage requirements.
#   XSMM_INCLUDE_DIRS - directory containing the LIBXSMM public headers.
#   LIBXSMMROOT       - root of the LIBXSMM tree that was actually used.
set(LIBXSMMROOT $ENV{LIBXSMMROOT})
set(LIBXSMMFETCH $ENV{LIBXSMMFETCH})

if(LIBXSMMROOT AND NOT LIBXSMMFETCH)
  message(STATUS "Found LIBXSMM (${LIBXSMMROOT})")
  # Look for sources shipped next to the headers under include/libxsmm first.
  file(GLOB XSMM_SRCS LIST_DIRECTORIES false CONFIGURE_DEPENDS
    "${LIBXSMMROOT}/include/libxsmm/*.c"
  )
  list(REMOVE_ITEM XSMM_SRCS
    "${LIBXSMMROOT}/include/libxsmm/libxsmm_generator_gemm_driver.c"
  )
else()
  message(STATUS "Fetching LIBXSMM")
  include(FetchContent)

  # Pinned to an exact commit archive; the hash guards against a changed download.
  FetchContent_Declare(
    xsmm
    URL https://github.com/libxsmm/libxsmm/archive/a4bb2a90c161c3f64563846fecaf291eeaa1b1d9.tar.gz
    URL_HASH SHA256=4f9400bdb5361a829a1e7c635c904fc76328256df25ab071e1694fff593e5398
  )

  # Only download/unpack the archive; the sources are compiled by the target below.
  FetchContent_GetProperties(xsmm)
  if(NOT xsmm_POPULATED)
    FetchContent_Populate(xsmm)
  endif()

  set(LIBXSMMROOT ${xsmm_SOURCE_DIR})
endif()

# Fall back to the source-tree layout (also used for the fetched archive).
if(NOT XSMM_SRCS)
  file(GLOB XSMM_SRCS LIST_DIRECTORIES false CONFIGURE_DEPENDS
    "${LIBXSMMROOT}/src/*.c"
  )
  list(REMOVE_ITEM XSMM_SRCS
    "${LIBXSMMROOT}/src/libxsmm_generator_gemm_driver.c"
  )
endif()

# Fail early with a readable message instead of creating a library without sources.
if(NOT XSMM_SRCS)
  message(FATAL_ERROR
    "No LIBXSMM sources found under '${LIBXSMMROOT}' "
    "(looked in include/libxsmm/*.c and src/*.c). "
    "Check LIBXSMMROOT or set LIBXSMMFETCH=1 to download LIBXSMM."
  )
endif()

set(XSMM_INCLUDE_DIRS ${LIBXSMMROOT}/include)

add_mlir_library(xsmm SHARED ${XSMM_SRCS})
target_include_directories(xsmm PUBLIC
  $<BUILD_INTERFACE:${XSMM_INCLUDE_DIRS}>
  $<INSTALL_INTERFACE:include/xsmm>
)

# Scope the LIBXSMM configuration to the xsmm target and the targets linking it,
# instead of leaking it into every target of the including directory.
# PUBLIC keeps the flags visible to consumers that compile against the headers.
target_compile_definitions(xsmm PUBLIC
  LIBXSMM_DEFAULT_CONFIG
  __BLAS=0
)
target_compile_options(xsmm PUBLIC -U_DEBUG)

set_property(TARGET xsmm PROPERTY POSITION_INDEPENDENT_CODE ON) # -fPIC
set_property(TARGET xsmm PROPERTY COMPILE_WARNING_AS_ERROR ON)

set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
target_link_libraries(xsmm PUBLIC Threads::Threads)
target_link_libraries(xsmm PUBLIC ${CMAKE_DL_LIBS})

# libm and librt are linked only where they exist as separate libraries.
include(CheckLibraryExists)
check_library_exists(m sqrt "" XSMM_LIBM)
if(XSMM_LIBM)
  target_link_libraries(xsmm PUBLIC m)
endif()
check_library_exists(rt sched_yield "" XSMM_LIBRT)
if(XSMM_LIBRT)
  target_link_libraries(xsmm PUBLIC rt)
endif()
17 changes: 15 additions & 2 deletions third_party/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
# libxsmm
include(xsmm)
message (STATUS "LIBXSMM Include dir: ${XSMM_INCLUDE_DIRS}")

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
add_subdirectory(include)
add_subdirectory(lib)
if(TRITON_BUILD_PYTHON_MODULE)
add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms)
target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation)
add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms TritonCPUXsmm)
target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation MLIRMemRefTransforms)
endif()

add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp)
target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport)

# Shared runtime library built from XsmmRunnerUtils.cpp. Compiled kernels link
# against it (together with xsmm) when the XSMM lowering is enabled.
add_library(TritonCPUXsmmRuntime SHARED
  ${CMAKE_CURRENT_SOURCE_DIR}/lib/Xsmm/runtime/XsmmRunnerUtils.cpp
)
set_target_properties(TritonCPUXsmmRuntime PROPERTIES CXX_STANDARD 11)
# NOTE(review): presumably switches the MLIR runner-utils export macros to
# "export" mode for the symbols defined in this library - confirm.
target_compile_definitions(TritonCPUXsmmRuntime PRIVATE mlir_c_runner_utils_EXPORTS)
target_include_directories(TritonCPUXsmmRuntime PUBLIC $<BUILD_INTERFACE:${XSMM_INCLUDE_DIRS}>)
target_link_libraries(TritonCPUXsmmRuntime PRIVATE xsmm)

# Build and link sleef
set(SLEEF_BUILD_SHARED_LIBS ON CACHE BOOL "Build sleef shared lib" FORCE)
set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
Expand Down
8 changes: 8 additions & 0 deletions third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class CPUOptions:
enable_fp_fusion: bool = True
max_num_imprecise_acc_default: int = 0
enable_fast_math: bool = True
enable_xsmm: bool = False
vec_lib: Optional[str] = 'libsleef'
# TODO: Try to enable it.
sanitize_overflow: bool = False
Expand Down Expand Up @@ -96,6 +97,8 @@ def parse_options(self, opts) -> Any:
if "supported_fp8_dtypes" not in args:
supported_fp8_dtypes = set(CPUOptions.supported_fp8_dtypes)
args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
if "enable_xsmm" not in args:
args["enable_xsmm"] = os.getenv("TRITON_CPU_XSMM", "0") != "0"
return CPUOptions(**args)

def pack_metadata(self, metadata):
Expand Down Expand Up @@ -194,7 +197,10 @@ def make_llir(self, src, metadata, options):
# TritonCPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
if options.enable_xsmm:
cpu.passes.ttcpuir.add_convert_vector_to_xsmm(pm)
cpu.passes.ttcpuir.add_lower_vector_multi_dim(pm)
cpu.passes.ttcpuir.add_expand_strided_metadata(pm)
cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False)
cpu.passes.ttcpuir.add_lower_affine(pm)
passes.convert.add_scf_to_cf(pm)
Expand Down Expand Up @@ -259,6 +265,8 @@ def make_so(src, metadata, options):
Path(asm_path).write_text(src)
lib_dirs = cpu_driver.library_dirs
libs = ["gcc", "m", "TritonCPURuntime", "sleef"]
if options.enable_xsmm:
libs.extend(["xsmm", "TritonCPUXsmmRuntime"])
so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs)
with open(so, "rb") as f:
return f.read()
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/include/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ add_subdirectory(ScalarizePass)
add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonCPUTransforms)
add_subdirectory(TritonToTritonCPU)
add_subdirectory(Xsmm)
8 changes: 8 additions & 0 deletions third_party/cpu/include/Xsmm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Generate the pass declarations (Passes.h.inc) from Passes.td.
# LLVM_TARGET_DEFINITIONS selects the .td input for the mlir_tablegen() calls
# that follow it, so it must be set before each group.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUXsmm)
add_public_tablegen_target(TritonCPUXsmmPassIncGen)

# Generate the XSMM enum declarations (.h.inc) and definitions (.cpp.inc)
# from XsmmEnum.td.
set(LLVM_TARGET_DEFINITIONS XsmmEnum.td)
mlir_tablegen(XsmmEnum.h.inc -gen-enum-decls)
mlir_tablegen(XsmmEnum.cpp.inc -gen-enum-defs)
add_public_tablegen_target(TritonCPUXsmmAttrDefIncGen)
67 changes: 67 additions & 0 deletions third_party/cpu/include/Xsmm/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Declarations and registration hooks for the TritonCPU XSMM passes.
// The actual content is generated by TableGen from Xsmm/Passes.td.
#ifndef TritonCPUXsmm_CONVERSION_PASSES_H
#define TritonCPUXsmm_CONVERSION_PASSES_H

#include "mlir/Pass/Pass.h"

// Forward declarations of MLIR operation and dialect classes, so that this
// header does not have to include the corresponding dialect headers.
namespace mlir {

class ModuleOp;

namespace affine {
class AffineDialect;
} // namespace affine

namespace arith {
class ArithDialect;
} // namespace arith

namespace func {
class FuncOp;
class FuncDialect;
} // namespace func

namespace linalg {
class LinalgDialect;
} // namespace linalg

namespace LLVM {
class LLVMDialect;
} // namespace LLVM

namespace math {
class MathDialect;
} // namespace math

namespace memref {
class MemRefDialect;
} // namespace memref

namespace scf {
class SCFDialect;
} // namespace scf

namespace tensor {
class TensorDialect;
} // namespace tensor

namespace vector {
class VectorDialect;
} // namespace vector

} // namespace mlir

namespace mlir {
namespace triton {
namespace cpu {

// Pull in the generated pass declarations.
#define GEN_PASS_DECL
#include "cpu/include/Xsmm/Passes.h.inc"

// Pull in the generated registration functions, including
// registerTritonCPUXsmmPasses() used by bin/RegisterTritonDialects.h.
#define GEN_PASS_REGISTRATION
#include "cpu/include/Xsmm/Passes.h.inc"

} // namespace cpu
} // namespace triton
} // namespace mlir

#endif // TritonCPUXsmm_CONVERSION_PASSES_H
17 changes: 17 additions & 0 deletions third_party/cpu/include/Xsmm/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// TableGen definitions of the TritonCPU XSMM passes.
#ifndef TRITONCPU_XSMM_PASSES
#define TRITONCPU_XSMM_PASSES

include "mlir/Pass/PassBase.td"

// Module-level pass registered under the command-line name
// "triton-cpu-convert-vector-to-xsmm". The dependent dialects listed below are
// the ones the pass declares it may create operations from.
def ConvertVectorToXsmm : Pass<"triton-cpu-convert-vector-to-xsmm", "mlir::ModuleOp"> {
let summary = "Convert vector to xsmm";
let description = [{
Convert vector operations to XSMM operations.
}];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect",
"vector::VectorDialect",
"LLVM::LLVMDialect"];
}

#endif // TRITONCPU_XSMM_PASSES
18 changes: 18 additions & 0 deletions third_party/cpu/include/Xsmm/XsmmEnum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//===- XsmmEnum.h -----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// C++ entry point for the XSMM enums generated from XsmmEnum.td.
#ifndef TPP_DIALECT_XSMM_XSMMENUM_H
#define TPP_DIALECT_XSMM_XSMMENUM_H

#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"

// NOTE(review): XsmmEnum.h.inc is produced with -gen-enum-decls; confirm
// whether GET_ATTRDEF_CLASSES is actually consumed by that generated file.
#define GET_ATTRDEF_CLASSES
#include "cpu/include/Xsmm/XsmmEnum.h.inc"

#endif // TPP_DIALECT_XSMM_XSMMENUM_H
83 changes: 83 additions & 0 deletions third_party/cpu/include/Xsmm/XsmmEnum.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//===- XsmmEnum --------------------------------------------*- Tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Enum attributes describing XSMM kernel configuration. Each enum's
// description string names the libxsmm C enum it refers to.
// NOTE(review): the numeric case values presumably have to match those libxsmm
// enums - re-check them whenever the pinned libxsmm version changes.
//
// The include guard keeps the definitions from being emitted twice when this
// file is included from more than one .td file.
#ifndef TRITONCPU_XSMM_ENUM
#define TRITONCPU_XSMM_ENUM

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

// Element data type of a kernel operand.
def Xsmm_DataType: I64EnumAttr<
    "DataType", "see: libxsmm_datatype",
    [
      I64EnumAttrCase<"F32", 1, "f32">,
      I64EnumAttrCase<"BF16", 2, "bf16">
    ]>{
  let cppNamespace = "mlir::xsmm";
}

// Kind of element-wise binary operation.
def Xsmm_BinaryKind : I64EnumAttr<
    "BinaryKind", "see: libxsmm_meltw_binary_type",
    [
      I64EnumAttrCase<"NONE", 0, "none">,
      I64EnumAttrCase<"ADD", 1, "add">,
      I64EnumAttrCase<"MUL", 2, "mul">,
      I64EnumAttrCase<"SUB", 3, "sub">,
      I64EnumAttrCase<"DIV", 4, "div">
    ]> {
  let cppNamespace = "mlir::xsmm";
}

// Kind of element-wise unary operation.
def Xsmm_UnaryKind : I64EnumAttr<
    "UnaryKind", "see: libxsmm_meltw_unary_type",
    [
      I64EnumAttrCase<"NONE", 0, "none">,
      I64EnumAttrCase<"IDENTITY", 1, "identity">,
      I64EnumAttrCase<"ZERO", 2, "zero">,
      I64EnumAttrCase<"RELU", 5, "relu">,
      I64EnumAttrCase<"VNNI2", 28, "vnni_2">,
      I64EnumAttrCase<"TRANSPOSE", 29, "transpose">
    ]> {
  let cppNamespace = "mlir::xsmm";
}

// Flags for unary operations (broadcast behaviour of the input).
def Xsmm_UnaryFlags : I64EnumAttr<
    "UnaryFlags", "see: libxsmm_meltw_unary_flags",
    [
      I64EnumAttrCase<"NONE", 0, "none">,
      I64EnumAttrCase<"BCAST_ROW", 2, "bcast_row">,
      I64EnumAttrCase<"BCAST_COL", 4, "bcast_col">,
      I64EnumAttrCase<"BCAST_SCALAR", 8, "bcast_scalar">
    ]> {
  let cppNamespace = "mlir::xsmm";
}

// Flags for binary operations (broadcast behaviour per input operand).
def Xsmm_BinaryFlags : I64EnumAttr<
    "BinaryFlags", "see: libxsmm_meltw_binary_flags",
    [
      I64EnumAttrCase<"NONE", 0, "none">,
      I64EnumAttrCase<"BCAST_ROW_IN_0", 1, "bcast_row_in0">,
      I64EnumAttrCase<"BCAST_ROW_IN_1", 2, "bcast_row_in1">,
      I64EnumAttrCase<"BCAST_COL_IN_0", 4, "bcast_col_in0">,
      I64EnumAttrCase<"BCAST_COL_IN_1", 8, "bcast_col_in1">,
      I64EnumAttrCase<"BCAST_SCALAR_IN_0", 16, "bcast_scalar_in0">,
      I64EnumAttrCase<"BCAST_SCALAR_IN_1", 32, "bcast_scalar_in1">
    ]> {
  let cppNamespace = "mlir::xsmm";
}

// Flags for GEMM kernels.
def Xsmm_GemmFlags : I64EnumAttr<
    "GemmFlags", "see: libxsmm_gemm_flags",
    [
      I64EnumAttrCase<"NONE", 0, "none">,
      I64EnumAttrCase<"BETA_0", 4, "beta_0">,
      I64EnumAttrCase<"VNNI_A", 2048, "vnni_a">,
      I64EnumAttrCase<"VNNI_B", 4096, "vnni_b">,
      I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">,
      I64EnumAttrCase<"NO_RESET_TILECONFIG", 64, "no_reset_tileconfig">,
      I64EnumAttrCase<"NO_SETUP_TILECONFIG", 128, "no_setup_tileconfig">
    ]> {
  let cppNamespace = "mlir::xsmm";
}

#endif // TRITONCPU_XSMM_ENUM
1 change: 1 addition & 0 deletions third_party/cpu/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ add_subdirectory(Analysis)
add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonCPUTransforms)
add_subdirectory(TritonToTritonCPU)
add_subdirectory(Xsmm)
28 changes: 28 additions & 0 deletions third_party/cpu/lib/Xsmm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Library containing the vector-to-XSMM conversion pass and its helper utilities.
# DEPENDS lists the TableGen targets (and xsmm) that must be built first;
# LINK_LIBS lists what the library links against.
add_triton_library(TritonCPUXsmm
ConvertVectorToXsmm.cpp
VnniUtils.cpp
ValueUtils.cpp
XsmmEnum.cpp
XsmmUtils.cpp

DEPENDS
TritonCPUXsmmPassIncGen
TritonCPUXsmmAttrDefIncGen
xsmm

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRVectorDialect
MLIRMemRefDialect
MLIRFuncDialect
MLIRLLVMDialect
MLIRInferTypeOpInterface
MLIRLinalgUtils
xsmm
)

# Expose the LIBXSMM headers to this library and, being PUBLIC, to targets
# linking it within the build tree.
target_include_directories(TritonCPUXsmm
PUBLIC
$<BUILD_INTERFACE:${XSMM_INCLUDE_DIRS}>
)
Loading

0 comments on commit b01f1eb

Please sign in to comment.