forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RUNTIME] refactor driver (triton-lang#1515)
Improved separation between different backends
- Loading branch information
Showing
14 changed files
with
428 additions
and
367 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
graft src | ||
graft triton/third_party | ||
graft triton/runtime/backends/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#include "cuda.h" | ||
#define PY_SSIZE_T_CLEAN | ||
#include <Python.h> | ||
|
||
// Convert a failed CUDA driver call into a pending Python RuntimeError.
//
// code: result of a CUDA driver API call; CUDA_SUCCESS is a no-op.
// file, line: call-site location supplied by CUDA_CHECK; accepted for the
//             macro's benefit but not included in the message.
//
// The message has the form "Triton Error [CUDA]: <driver error string>".
static inline void gpuAssert(CUresult code, const char *file, int line) {
  (void)file;
  (void)line;
  if (code == CUDA_SUCCESS) {
    return;
  }

  // cuGetErrorString reports failure and leaves the string pointer NULL for
  // codes it does not recognize, so fall back to a fixed description instead
  // of formatting a NULL pointer.
  const char *str = NULL;
  if (cuGetErrorString(code, &str) != CUDA_SUCCESS || str == NULL) {
    str = "unknown error";
  }

  // snprintf bounds the write and always NUL-terminates, unlike strcat.
  char err[1024] = {0};
  snprintf(err, sizeof err, "Triton Error [CUDA]: %s", str);
  PyErr_SetString(PyExc_RuntimeError, err);
}
|
||
// Evaluate a CUDA driver call and bail out of the enclosing function on error.
//
// gpuAssert turns a non-success result into a Python RuntimeError; if any
// Python exception is pending afterwards, the enclosing function returns NULL.
// Only usable inside functions that return a pointer (e.g. PyObject *).
//
// Wrapped in do { ... } while (0) so that `CUDA_CHECK(x);` behaves as a single
// statement, including inside unbraced if/else bodies.
#define CUDA_CHECK(ans)                                                        \
  do {                                                                         \
    gpuAssert((ans), __FILE__, __LINE__);                                      \
    if (PyErr_Occurred())                                                      \
      return NULL;                                                             \
  } while (0)
|
||
// Python entry point: get_device_properties(device_id) -> dict
//
// Queries the CUDA driver for a handful of attributes of the device with the
// given ordinal and returns them as a dict with the integer-valued keys
// "max_shared_mem", "multiprocessor_count", "sm_clock_rate", "mem_clock_rate"
// and "mem_bus_width".
//
// Returns NULL with a Python exception set if argument parsing or any driver
// call fails.
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  int device_id;
  if (!PyArg_ParseTuple(args, "i", &device_id))
    return NULL;

  // Get device handle. The result must be checked: on failure `device` would
  // otherwise be used uninitialized by the attribute queries below.
  CUdevice device;
  CUDA_CHECK(cuDeviceGet(&device, device_id));

  // Device properties reported back to Python.
  int max_shared_mem = 0;
  int multiprocessor_count = 0;
  int sm_clock_rate = 0;
  int mem_clock_rate = 0;
  int mem_bus_width = 0;
  CUDA_CHECK(cuDeviceGetAttribute(
      &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  CUDA_CHECK(cuDeviceGetAttribute(
      &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
  CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
                                  CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
  CUDA_CHECK(cuDeviceGetAttribute(
      &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
  CUDA_CHECK(cuDeviceGetAttribute(
      &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));

  return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
                       max_shared_mem, "multiprocessor_count",
                       multiprocessor_count, "sm_clock_rate", sm_clock_rate,
                       "mem_clock_rate", mem_clock_rate, "mem_bus_width",
                       mem_bus_width);
}
|
||
static PyObject *loadBinary(PyObject *self, PyObject *args) { | ||
const char *name; | ||
const char *data; | ||
Py_ssize_t data_size; | ||
int shared; | ||
int device; | ||
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, | ||
&device)) { | ||
return NULL; | ||
} | ||
CUfunction fun; | ||
CUmodule mod; | ||
int32_t n_regs = 0; | ||
int32_t n_spills = 0; | ||
// create driver handles | ||
CUDA_CHECK(cuModuleLoadData(&mod, data)); | ||
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name)); | ||
// get allocated registers and spilled registers from the function | ||
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); | ||
CUDA_CHECK( | ||
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); | ||
n_spills /= 4; | ||
// set dynamic shared memory if necessary | ||
int shared_optin; | ||
CUDA_CHECK(cuDeviceGetAttribute( | ||
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, | ||
device)); | ||
if (shared > 49152 && shared_optin > 49152) { | ||
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); | ||
int shared_total, shared_static; | ||
CUDA_CHECK(cuDeviceGetAttribute( | ||
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, | ||
device)); | ||
CUDA_CHECK(cuFuncGetAttribute(&shared_static, | ||
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); | ||
CUDA_CHECK( | ||
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, | ||
shared_optin - shared_static)); | ||
} | ||
|
||
if (PyErr_Occurred()) { | ||
return NULL; | ||
} | ||
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, | ||
n_spills); | ||
} | ||
|
||
static PyMethodDef ModuleMethods[] = { | ||
{"load_binary", loadBinary, METH_VARARGS, | ||
"Load provided cubin into CUDA driver"}, | ||
{"get_device_properties", getDeviceProperties, METH_VARARGS, | ||
"Get the properties for a given device"}, | ||
{NULL, NULL, 0, NULL} // sentinel | ||
}; | ||
|
||
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", | ||
NULL, // documentation | ||
-1, // size | ||
ModuleMethods}; | ||
|
||
PyMODINIT_FUNC PyInit_cuda_utils(void) { | ||
PyObject *m = PyModule_Create(&ModuleDef); | ||
if (m == NULL) { | ||
return NULL; | ||
} | ||
PyModule_AddFunctions(m, ModuleMethods); | ||
return m; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#define __HIP_PLATFORM_AMD__ | ||
#include <hip/hip_runtime.h> | ||
#define PY_SSIZE_T_CLEAN | ||
#include <Python.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
|
||
// Convert a failed HIP call into a pending Python RuntimeError.
//
// code: result of a HIP API call; HIP_SUCCESS is a no-op.
// file, line: call-site location supplied by HIP_CHECK; accepted for the
//             macro's benefit but not included in the message.
//
// The message has the form
//   "Triton Error [HIP]: Code: <numeric code>, Message: <error string>".
static inline void gpuAssert(hipError_t code, const char *file, int line) {
  (void)file;
  (void)line;
  if (code == HIP_SUCCESS) {
    return;
  }

  const char *str = hipGetErrorString(code);
  if (str == NULL) {
    str = "unknown error";
  }

  // The enum is passed as an explicit int so the argument matches %d.
  char err[1024] = {0};
  snprintf(err, sizeof err, "Triton Error [HIP]: Code: %d, Message: %s",
           (int)code, str);
  PyErr_SetString(PyExc_RuntimeError, err);
}
|
||
// Evaluate a HIP call and bail out of the enclosing function on error.
//
// gpuAssert turns a non-success result into a Python RuntimeError; if any
// Python exception is pending afterwards, the enclosing function returns NULL.
// Only usable inside functions that return a pointer (e.g. PyObject *).
//
// Wrapped in do { ... } while (0) so that `HIP_CHECK(x);` behaves as a single
// statement, including inside unbraced if/else bodies.
#define HIP_CHECK(ans)                                                         \
  do {                                                                         \
    gpuAssert((ans), __FILE__, __LINE__);                                      \
    if (PyErr_Occurred())                                                      \
      return NULL;                                                             \
  } while (0)
|
||
// Python entry point: get_device_properties(device_id) -> dict
//
// Fetches hipDeviceProp_t for the given device and returns a dict with the
// integer-valued keys "max_shared_mem", "multiprocessor_count",
// "sm_clock_rate", "mem_clock_rate" and "mem_bus_width".
//
// Returns NULL with a Python exception set if argument parsing or the HIP
// query fails.
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  int device_id;
  if (!PyArg_ParseTuple(args, "i", &device_id))
    return NULL;

  hipDeviceProp_t props;
  HIP_CHECK(hipGetDeviceProperties(&props, device_id));

  // Every "i" slot of Py_BuildValue reads an int from the variadic arguments.
  // sharedMemPerBlock is a size_t, so it is converted explicitly; the other
  // casts are no-ops kept for uniformity.
  return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
                       (int)props.sharedMemPerBlock, "multiprocessor_count",
                       (int)props.multiProcessorCount, "sm_clock_rate",
                       (int)props.clockRate, "mem_clock_rate",
                       (int)props.memoryClockRate, "mem_bus_width",
                       (int)props.memoryBusWidth);
}
|
||
// Python entry point:
//   load_binary(name, data, shared, device) -> (module, function, n_regs,
//                                               n_spills)
//
// name:   kernel symbol to look up in the loaded module.
// data:   used as the filesystem path of the HSACO file to load.
// shared, device: parsed but not used by this implementation.
//
// The module and function handles are returned as unsigned 64-bit integers;
// n_regs and n_spills are always reported as 0.
//
// Returns NULL with a Python exception set on failure: OSError for file
// problems, MemoryError if the read buffer cannot be allocated, RuntimeError
// (via gpuAssert) if HIP rejects the module or the symbol lookup.
static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const char *name;
  const char *data;
  Py_ssize_t data_size;
  int shared;
  int device;
  if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
                        &device)) {
    return NULL;
  }

  // Open HSACO file
  FILE *hsaco_file = fopen(data, "rb");
  if (hsaco_file == NULL) {
    // Returning NULL without an exception would surface as a SystemError.
    return PyErr_SetFromErrnoWithFilename(PyExc_OSError, data);
  }

  // Determine the file size by seeking to the end and back.
  long hsaco_file_size = -1;
  if (fseek(hsaco_file, 0L, SEEK_END) == 0) {
    hsaco_file_size = ftell(hsaco_file);
  }
  if (hsaco_file_size < 0 || fseek(hsaco_file, 0L, SEEK_SET) != 0) {
    // Record errno before fclose has a chance to overwrite it.
    PyErr_SetFromErrnoWithFilename(PyExc_OSError, data);
    fclose(hsaco_file);
    return NULL;
  }

  // Read HSACO file into buffer
  const size_t hsaco_size = (size_t)hsaco_file_size;
  // malloc(0) may legitimately return NULL, so always request at least a byte.
  unsigned char *hsaco = malloc(hsaco_size > 0 ? hsaco_size : 1);
  if (hsaco == NULL) {
    fclose(hsaco_file);
    return PyErr_NoMemory();
  }
  const size_t n_read = fread(hsaco, 1, hsaco_size, hsaco_file);
  fclose(hsaco_file);
  if (n_read != hsaco_size) {
    free(hsaco);
    PyErr_Format(PyExc_OSError, "failed to read HSACO file '%s'", data);
    return NULL;
  }

  // set HIP options
  enum { LOG_BUF_SIZE = 8192 };
  char err_log[LOG_BUF_SIZE];
  char info_log[LOG_BUF_SIZE];
  hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
                        hipJitOptionErrorLogBuffer,
                        hipJitOptionInfoLogBufferSizeBytes,
                        hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
  void *optval[] = {(void *)(uintptr_t)LOG_BUF_SIZE, (void *)err_log,
                    (void *)(uintptr_t)LOG_BUF_SIZE, (void *)info_log,
                    (void *)(uintptr_t)1};
  const unsigned int n_opts = (unsigned int)(sizeof opt / sizeof opt[0]);

  // launch HIP Binary
  hipModule_t mod;
  hipFunction_t fun;
  hipError_t status = hipModuleLoadDataEx(&mod, hsaco, n_opts, opt, optval);
  if (status == HIP_SUCCESS) {
    status = hipModuleGetFunction(&fun, mod, name);
    if (status != HIP_SUCCESS) {
      // The module is unusable without its kernel; do not leak it.
      hipModuleUnload(mod);
    }
  }
  free(hsaco);
  // Raises RuntimeError and returns NULL if either HIP call failed.
  HIP_CHECK(status);

  // get allocated registers and spilled registers from the function
  // (not queried on this backend; both are reported as 0)
  int n_regs = 0;
  int n_spills = 0;
  // "K" consumes an unsigned long long, so convert the handles explicitly.
  return Py_BuildValue("(KKii)", (unsigned long long)(uintptr_t)mod,
                       (unsigned long long)(uintptr_t)fun, n_regs, n_spills);
}
|
||
static PyMethodDef ModuleMethods[] = { | ||
{"load_binary", loadBinary, METH_VARARGS, | ||
"Load provided hsaco into HIP driver"}, | ||
{"get_device_properties", getDeviceProperties, METH_VARARGS, | ||
"Get the properties for a given device"}, | ||
{NULL, NULL, 0, NULL} // sentinel | ||
}; | ||
|
||
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils", | ||
NULL, // documentation | ||
-1, // size | ||
ModuleMethods}; | ||
|
||
PyMODINIT_FUNC PyInit_hip_utils(void) { | ||
PyObject *m = PyModule_Create(&ModuleDef); | ||
if (m == NULL) { | ||
return NULL; | ||
} | ||
PyModule_AddFunctions(m, ModuleMethods); | ||
return m; | ||
} |
Oops, something went wrong.