forked from triton-lang/triton-cpu
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implements lowering pass from vector to XSMM microkernels. libxsmm is added as an external dependency together with general MLIR infrastructure for handling XSMM code generation and runtime execution. The XSMM lowering is optional and can be enabled at JIT step by environment variable TRITON_CPU_XSMM=1 libxsmm is built as a shared library and linked with XSMM-related libraries. These are also added to the Python infrastructure. Additionally, general MLIR utilities are imported to allow analysis, code generation and microkernel execution. Initially, a simple pattern mapping vector contraction to an XSMM kernel is added.
- Loading branch information
Showing
23 changed files
with
2,658 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Use LIBXSMM (make PREFIX=/path/to/libxsmm) given by LIBXSMMROOT | ||
set(LIBXSMMROOT $ENV{LIBXSMMROOT}) | ||
# Fetch LIBXSMM (even if LIBXSMMROOT is present) | ||
set(LIBXSMMFETCH $ENV{LIBXSMMFETCH}) | ||
|
||
if(LIBXSMMROOT AND NOT LIBXSMMFETCH) | ||
message(STATUS "Found LIBXSMM (${LIBXSMMROOT})") | ||
file(GLOB XSMM_SRCS LIST_DIRECTORIES false CONFIGURE_DEPENDS ${LIBXSMMROOT}/include/libxsmm/*.c) | ||
list(REMOVE_ITEM XSMM_SRCS ${LIBXSMMROOT}/include/libxsmm/libxsmm_generator_gemm_driver.c) | ||
else() | ||
message(STATUS "Fetching LIBXSMM") | ||
include(FetchContent) | ||
|
||
FetchContent_Declare( | ||
xsmm | ||
URL https://github.com/libxsmm/libxsmm/archive/a4bb2a90c161c3f64563846fecaf291eeaa1b1d9.tar.gz | ||
URL_HASH SHA256=4f9400bdb5361a829a1e7c635c904fc76328256df25ab071e1694fff593e5398 | ||
) | ||
|
||
FetchContent_GetProperties(xsmm) | ||
if(NOT xsmm_POPULATED) | ||
FetchContent_Populate(xsmm) | ||
endif() | ||
|
||
set(LIBXSMMROOT ${xsmm_SOURCE_DIR}) | ||
endif() | ||
|
||
if(NOT XSMM_SRCS) | ||
file(GLOB XSMM_SRCS LIST_DIRECTORIES false CONFIGURE_DEPENDS ${LIBXSMMROOT}/src/*.c) | ||
list(REMOVE_ITEM XSMM_SRCS ${LIBXSMMROOT}/src/libxsmm_generator_gemm_driver.c) | ||
endif() | ||
|
||
set(XSMM_INCLUDE_DIRS ${LIBXSMMROOT}/include) | ||
|
||
add_mlir_library(xsmm SHARED ${XSMM_SRCS}) | ||
target_include_directories(xsmm PUBLIC | ||
$<BUILD_INTERFACE:${XSMM_INCLUDE_DIRS}> | ||
$<INSTALL_INTERFACE:include/xsmm> | ||
) | ||
add_definitions(-DLIBXSMM_DEFAULT_CONFIG -U_DEBUG -D__BLAS=0) | ||
|
||
set_property(TARGET xsmm PROPERTY POSITION_INDEPENDENT_CODE ON) # -fPIC | ||
set_property(TARGET xsmm PROPERTY COMPILE_WARNING_AS_ERROR ON) | ||
|
||
set(THREADS_PREFER_PTHREAD_FLAG ON) | ||
find_package(Threads REQUIRED) | ||
target_link_libraries(xsmm PUBLIC Threads::Threads) | ||
target_link_libraries(xsmm PUBLIC ${CMAKE_DL_LIBS}) | ||
|
||
include(CheckLibraryExists) | ||
check_library_exists(m sqrt "" XSMM_LIBM) | ||
if(XSMM_LIBM) | ||
target_link_libraries(xsmm PUBLIC m) | ||
endif() | ||
check_library_exists(rt sched_yield "" XSMM_LIBRT) | ||
if(XSMM_LIBRT) | ||
target_link_libraries(xsmm PUBLIC rt) | ||
endif() | ||
#target_link_libraries(xsmm PUBLIC c) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUXsmm) | ||
add_public_tablegen_target(TritonCPUXsmmPassIncGen) | ||
|
||
set(LLVM_TARGET_DEFINITIONS XsmmEnum.td) | ||
mlir_tablegen(XsmmEnum.h.inc -gen-enum-decls) | ||
mlir_tablegen(XsmmEnum.cpp.inc -gen-enum-defs) | ||
add_public_tablegen_target(TritonCPUXsmmAttrDefIncGen) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#ifndef TritonCPUXsmm_CONVERSION_PASSES_H | ||
#define TritonCPUXsmm_CONVERSION_PASSES_H | ||
|
||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir { | ||
|
||
class ModuleOp; | ||
|
||
namespace affine { | ||
class AffineDialect; | ||
} // namespace affine | ||
|
||
namespace arith { | ||
class ArithDialect; | ||
} // namespace arith | ||
|
||
namespace func { | ||
class FuncOp; | ||
class FuncDialect; | ||
} // namespace func | ||
|
||
namespace linalg { | ||
class LinalgDialect; | ||
} // namespace linalg | ||
|
||
namespace LLVM { | ||
class LLVMDialect; | ||
} // namespace LLVM | ||
|
||
namespace math { | ||
class MathDialect; | ||
} // namespace math | ||
|
||
namespace memref { | ||
class MemRefDialect; | ||
} // namespace memref | ||
|
||
namespace scf { | ||
class SCFDialect; | ||
} // namespace scf | ||
|
||
namespace tensor { | ||
class TensorDialect; | ||
} // namespace tensor | ||
|
||
namespace vector { | ||
class VectorDialect; | ||
} // namespace vector | ||
|
||
} // namespace mlir | ||
|
||
namespace mlir { | ||
namespace triton { | ||
namespace cpu { | ||
|
||
#define GEN_PASS_DECL | ||
#include "cpu/include/Xsmm/Passes.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "cpu/include/Xsmm/Passes.h.inc" | ||
|
||
} // namespace cpu | ||
} // namespace triton | ||
} // namespace mlir | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef TRITONCPU_XSMM_PASSES | ||
#define TRITONCPU_XSMM_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def ConvertVectorToXsmm : Pass<"triton-cpu-convert-vector-to-xsmm", "mlir::ModuleOp"> { | ||
let summary = "Convert vector to xsmm"; | ||
let description = [{ | ||
Convert vector operations to XSMM operations. | ||
}]; | ||
let dependentDialects = ["func::FuncDialect", | ||
"memref::MemRefDialect", | ||
"vector::VectorDialect", | ||
"LLVM::LLVMDialect"]; | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
//===- XsmmEnum.h -----------------------------------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef TPP_DIALECT_XSMM_XSMMENUM_H | ||
#define TPP_DIALECT_XSMM_XSMMENUM_H | ||
|
||
#include "mlir/IR/Attributes.h" | ||
#include "mlir/IR/DialectImplementation.h" | ||
|
||
#define GET_ATTRDEF_CLASSES | ||
#include "cpu/include/Xsmm/XsmmEnum.h.inc" | ||
|
||
#endif // TPP_DIALECT_XSMM_XSMMENUM_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
//===- XsmmEnum --------------------------------------------*- Tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
include "mlir/IR/AttrTypeBase.td" | ||
include "mlir/IR/EnumAttr.td" | ||
|
||
def Xsmm_DataType: I64EnumAttr< | ||
"DataType", "see: libxsmm_datatype", | ||
[ | ||
I64EnumAttrCase<"F32", 1, "f32">, | ||
I64EnumAttrCase<"BF16", 2, "bf16"> | ||
]>{ | ||
let cppNamespace = "mlir::xsmm"; | ||
} | ||
|
||
def Xsmm_BinaryKind : I64EnumAttr< | ||
"BinaryKind", "see: libxsmm_meltw_binary_type", | ||
[ | ||
I64EnumAttrCase<"NONE", 0, "none">, | ||
I64EnumAttrCase<"ADD", 1, "add">, | ||
I64EnumAttrCase<"MUL", 2, "mul">, | ||
I64EnumAttrCase<"SUB", 3, "sub">, | ||
I64EnumAttrCase<"DIV", 4, "div"> | ||
]> { | ||
let cppNamespace = "mlir::xsmm"; | ||
} | ||
|
||
def Xsmm_UnaryKind : I64EnumAttr< | ||
"UnaryKind", "see: libxsmm_meltw_unary_type", | ||
[ | ||
I64EnumAttrCase<"NONE", 0, "none">, | ||
I64EnumAttrCase<"IDENTITY", 1, "identity">, | ||
I64EnumAttrCase<"ZERO", 2, "zero">, | ||
I64EnumAttrCase<"RELU", 5, "relu">, | ||
I64EnumAttrCase<"VNNI2", 28, "vnni_2">, | ||
I64EnumAttrCase<"TRANSPOSE", 29, "transpose"> | ||
]> { | ||
let cppNamespace = "mlir::xsmm"; | ||
} | ||
|
||
def Xsmm_UnaryFlags : I64EnumAttr< | ||
"UnaryFlags", "see: libxsmm_meltw_unary_flags", | ||
[ | ||
I64EnumAttrCase<"NONE", 0, "none">, | ||
I64EnumAttrCase<"BCAST_ROW", 2, "bcast_row">, | ||
I64EnumAttrCase<"BCAST_COL", 4, "bcast_col">, | ||
I64EnumAttrCase<"BCAST_SCALAR", 8, "bcast_scalar"> | ||
]> { | ||
let cppNamespace = "mlir::xsmm"; | ||
} | ||
|
||
def Xsmm_BinaryFlags : I64EnumAttr< | ||
"BinaryFlags", "see: libxsmm_meltw_binary_flags", | ||
[ | ||
I64EnumAttrCase<"NONE", 0, "none">, | ||
I64EnumAttrCase<"BCAST_ROW_IN_0", 1, "bcast_row_in0">, | ||
I64EnumAttrCase<"BCAST_ROW_IN_1", 2, "bcast_row_in1">, | ||
I64EnumAttrCase<"BCAST_COL_IN_0", 4, "bcast_col_in0">, | ||
I64EnumAttrCase<"BCAST_COL_IN_1", 8, "bcast_col_in1">, | ||
I64EnumAttrCase<"BCAST_SCALAR_IN_0", 16, "bcast_scalar_in0">, | ||
I64EnumAttrCase<"BCAST_SCALAR_IN_1", 32, "bcast_scalar_in1"> | ||
]> { | ||
let cppNamespace = "mlir::xsmm"; | ||
} | ||
|
||
def Xsmm_GemmFlags : I64EnumAttr< | ||
"GemmFlags", "see: libxsmm_gemm_flags", | ||
[ | ||
I64EnumAttrCase<"NONE", 0, "none">, | ||
I64EnumAttrCase<"BETA_0", 4, "beta_0">, | ||
I64EnumAttrCase<"VNNI_A", 2048, "vnni_a">, | ||
I64EnumAttrCase<"VNNI_B", 4096, "vnni_b">, | ||
I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">, | ||
I64EnumAttrCase<"NO_RESET_TILECONFIG", 64, "no_reset_tileconfig">, | ||
I64EnumAttrCase<"NO_SETUP_TILECONFIG", 128, "no_setup_tileconfig"> | ||
]> { | ||
let cppNamespace = "mlir::xsmm"; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
add_triton_library(TritonCPUXsmm | ||
ConvertVectorToXsmm.cpp | ||
VnniUtils.cpp | ||
ValueUtils.cpp | ||
XsmmEnum.cpp | ||
XsmmUtils.cpp | ||
|
||
DEPENDS | ||
TritonCPUXsmmPassIncGen | ||
TritonCPUXsmmAttrDefIncGen | ||
xsmm | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRPass | ||
MLIRVectorDialect | ||
MLIRMemRefDialect | ||
MLIRFuncDialect | ||
MLIRLLVMDialect | ||
MLIRInferTypeOpInterface | ||
MLIRLinalgUtils | ||
xsmm | ||
) | ||
|
||
target_include_directories(TritonCPUXsmm | ||
PUBLIC | ||
$<BUILD_INTERFACE:${XSMM_INCLUDE_DIRS}> | ||
) |
Oops, something went wrong.