Skip to content

Commit

Permalink
[fix]add hip target (apache#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cunxiao2002 authored Oct 16, 2024
1 parent 04abf1e commit b257cd7
Show file tree
Hide file tree
Showing 14 changed files with 2,620 additions and 1 deletion.
4 changes: 3 additions & 1 deletion cmake/modules/ROCM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ if(USE_ROCM)
message(STATUS "Build with ROCM support")
tvm_file_glob(GLOB RUNTIME_ROCM_SRCS src/runtime/rocm/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_ROCM_SRCS})
list(APPEND COMPILER_SRCS src/target/opt/build_rocm_on.cc)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPHCC_LIBRARY})
if (ROCM_HSA_LIBRARY)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HSA_LIBRARY})
Expand Down Expand Up @@ -69,5 +70,6 @@ if(USE_ROCM)
endif(USE_THRUST)

else(USE_ROCM)
list(APPEND COMPILER_SRCS src/target/opt/build_rocm_off.cc)
#list(APPEND COMPILER_SRCS src/target/opt/build_rocm_off.cc)
list(APPEND COMPILER_SRCS src/target/opt/build_rocm_nort.cc)
endif(USE_ROCM)
46 changes: 46 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,52 @@ TVM_DLL const Op& tvm_store_matrix_sync();
*/
TVM_DLL const Op& ptx_mma();

/*!
* \brief tvm intrinsic for amd matrix core mfma instructions.
*
* void tvm_mfma(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index);
*/
TVM_DLL const Op& tvm_mfma();

/*!
* \brief tvm intrinsic for storing the result of AMD MFMA into a destination pointer.
*
* There is no real instruction that does that, but we want to hide details of
* complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g.
* LowerWarpMemory) like cuda ptx backend does.
*
* void tvm_mfma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var
* dst_stride);
*/
TVM_DLL const Op& tvm_mfma_store();

/*!
* \brief tvm intrinsic for amd rdna matrix core instructions.
*
* void tvm_rdna_wmma(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index);
*/
TVM_DLL const Op& tvm_rdna_wmma();

/*!
* \brief tvm intrinsic for storing the result of AMD RDNA WMMA into a destination pointer.
*
* There is no real instruction that does that, but we want to hide details of
* complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g.
* LowerWarpMemory) like cuda ptx backend does.
*
* void tvm_rdna_wmma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var
* dst_stride);
*/
TVM_DLL const Op& tvm_rdna_wmma_store();

/*!
* \brief tvm intrinsic for ptx predicate load with 32-bit data type.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ class Device(ctypes.Structure):
"metal": kDLMetal,
"vpi": kDLVPI,
"rocm": kDLROCM,
"hip": kDLROCM,
"ext_dev": kDLExtDev,
"hexagon": kDLHexagon,
"webgpu": kDLWebGPU,
Expand Down
111 changes: 111 additions & 0 deletions python/tvm/contrib/hipcc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Utility to invoke hipcc compiler in the system"""
from __future__ import absolute_import as _abs

import subprocess
import os
import warnings

import tvm._ffi
from tvm.target import Target

from . import utils
from .._ffi.base import py_str
from .rocm import get_rocm_arch, find_rocm_path


def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None):
"""Compile HIP code with hipcc.
Parameters
----------
code : str
The HIP code.
target_format : str
The target format of hipcc compiler.
arch : str
The AMD GPU architecture.
options : str or list of str
The additional options.
path_target : str, optional
Output file.
Return
------
hsaco : bytearray
The bytearray of the hsaco
"""
if arch is None:
rocm_path = find_rocm_path()
arch = get_rocm_arch(rocm_path)

temp = utils.tempdir()
if target_format not in ["hsaco"]:
raise ValueError("target_format must be hsaco")
temp_code = temp.relpath("my_kernel.cc")
temp_target = temp.relpath("my_kernel.%s" % target_format)

with open(temp_code, "w") as out_file:
out_file.write(code)

file_target = path_target if path_target else temp_target
cmd = ["hipcc"]
cmd += ["-O3", '-c']
if isinstance(arch, str):
cmd += [f"--offload-arch={arch}"]
if target_format == "hsaco":
cmd += ["--genco"]
if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str")

cmd += ["-o", file_target]
cmd += [temp_code]
print(f"cmd: {cmd}")
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)

(out, _) = proc.communicate()

if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)

with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data


@tvm._ffi.register_func("tvm_callback_hip_compile")
def tvm_callback_hip_compile(code):
"""use hipcc to generate fatbin code for better optimization"""
hsaco = compile_hip(code, target_format="hsaco")
return hsaco
8 changes: 8 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,10 @@ def wrapped(*args, **kwargs):
vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)
tvm_mfma = _dtype_forward(_tir_op.tvm_mfma)
tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store)
tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma)
tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store)


broadcast = Broadcast
Expand Down Expand Up @@ -2170,6 +2174,10 @@ def wrapped(*args, **kwargs):
"vectorlow",
"vectorhigh",
"vectorcombine",
"tvm_mfma",
"tvm_mfma_store",
"tvm_rdna_wmma",
"tvm_rdna_wmma_store",
"assume",
"undef",
"tvm_call_packed",
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
ptx_wait_barrier,
create_barriers,
)
from .op import tvm_mfma, tvm_mfma_store, tvm_rdna_wmma, tvm_rdna_wmma_store
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
Expand Down
Loading

0 comments on commit b257cd7

Please sign in to comment.