[RUNTIME] refactor driver (triton-lang#1515)
Improved separation between different backends
ptillet authored Apr 13, 2023
1 parent 7584e04 commit c5359f0
Showing 14 changed files with 428 additions and 367 deletions.
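
The refactor replaces the per-backend helpers (get_cuda_utils / get_hip_utils) with a single driver object exported from triton.runtime, so call sites no longer branch on the backend themselves. A minimal sketch of the usage pattern visible in the hunks below; the attribute names (backend, CUDA, HIP, utils) are taken from the changed call sites, everything else is illustrative:

# Illustrative sketch, not part of the commit: how the refactored call sites
# below consume the unified driver object instead of backend-specific helpers.
from triton.runtime import driver

device = 0  # hypothetical device index
props = driver.utils.get_device_properties(device)
max_shared = props["max_shared_mem"]

# Backend-specific decisions become lookups keyed on driver.backend.
binary_key = {driver.CUDA: "cubin", driver.HIP: "hsaco_path"}[driver.backend]
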
1 change: 1 addition & 0 deletions python/MANIFEST.in
@@ -1,2 +1,3 @@
graft src
graft triton/third_party
graft triton/runtime/backends/
1 change: 0 additions & 1 deletion python/setup.py
@@ -227,7 +227,6 @@ def build_extension(self, ext):
"triton/ops",
"triton/ops/blocksparse",
"triton/runtime",
"triton/runtime/driver",
"triton/tools",
],
install_requires=[
23 changes: 9 additions & 14 deletions python/triton/compiler/compiler.py
@@ -15,10 +15,10 @@

import triton
import triton._C.libtriton.triton as _triton
from ..runtime import driver
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager
from ..runtime.driver import get_cuda_utils, get_hip_utils
from ..tools.disasm import extract
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
@@ -519,24 +519,19 @@ def __init__(self, fn, so_path, metadata, asm):
self.metadata = metadata
self.cu_module = None
self.cu_function = None
self.is_hip = "amdgcn" in asm

def _init_handles(self):
if self.cu_module is not None:
return
device = triton.runtime.jit.get_current_device()
if self.is_hip:
hip_utils = get_hip_utils()
max_shared = hip_utils.get_device_properties(device)["max_shared_mem"]
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = hip_utils.load_binary(self.metadata["name"], self.asm["hsaco_path"], self.shared, device)
else:
cuda_utils = get_cuda_utils()
max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
bin_path = {
driver.HIP: "hsaco_path",
driver.CUDA: "cubin"
}[driver.backend]
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = driver.utils.load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)

self.n_spills = n_spills
self.n_regs = n_regs
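
The call sites in this hunk assume a driver singleton exported by the runtime package; the new python/triton/runtime/driver.py that provides it is among the 14 changed files but is not shown in this excerpt. A hypothetical sketch of the minimal interface those call sites need; only the attribute names (backend, CUDA, HIP, utils, libdevice_path) come from the diff, the structure is an assumption:

# Hypothetical sketch, not the driver.py from this commit: the smallest
# surface that satisfies the call sites in compiler.py and language/math.py.
class _Driver:
    CUDA = "cuda"  # backend tags; the concrete values are assumptions
    HIP = "hip"

    def __init__(self, backend, utils_module, libdevice_path):
        self.backend = backend                 # which tag applies on this machine
        self.utils = utils_module              # compiled cuda_utils or hip_utils extension
        self.libdevice_path = libdevice_path   # bitcode library consulted by language/math.py

# e.g. driver = _Driver(_Driver.CUDA, cuda_utils, "<path to libdevice.10.bc>")
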
9 changes: 2 additions & 7 deletions python/triton/language/math.py
@@ -1,14 +1,9 @@
import os

import torch

from ..runtime import driver
from . import core, extern

if torch.version.hip is not None:
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda2gcn.bc")
else:
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc")
LIBDEVICE_PATH = os.getenv("TRITON_LIBDEVICE_PATH", LOCAL_PATH)
LIBDEVICE_PATH = os.getenv("TRITON_LIBDEVICE_PATH", driver.libdevice_path)


@extern.extern
15 changes: 5 additions & 10 deletions python/triton/ops/matmul_perf_model.py
@@ -4,24 +4,22 @@

import triton
import triton._C.libtriton.triton as _triton
from triton.runtime.driver.cuda import get_cuda_utils
from triton.runtime import driver
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops


def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
cuda_utils = get_cuda_utils()
num_subcores = cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
return tflops


def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
cuda_utils = get_cuda_utils()
num_subcores = cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
return tflops

@@ -62,8 +60,7 @@ def estimate_matmul_time(
compute_ms = total_ops / tput

# time to load data
cuda_utils = get_cuda_utils()
num_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"]
num_sm = driver.utils.get_device_properties(device)["multiprocessor_count"]
active_cta_ratio = min(1, num_ctas / num_sm)
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
@@ -114,9 +111,7 @@ def early_config_prune(configs, named_args):
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages

# TODO: move to `cuda_utils` submodule
cuda_utils = get_cuda_utils()
max_shared_memory = cuda_utils.get_device_properties(device)["max_shared_mem"]
max_shared_memory = driver.utils.get_device_properties(device)["max_shared_mem"]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory <= max_shared_memory:
pruned_configs.append(config)
2 changes: 1 addition & 1 deletion python/triton/runtime/__init__.py
@@ -1,6 +1,6 @@
from . import driver
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
heuristics)
from .driver import driver
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
version_key)

124 changes: 124 additions & 0 deletions python/triton/runtime/backends/cuda.c
@@ -0,0 +1,124 @@
#include "cuda.h"
#define PY_SSIZE_T_CLEAN
#include <Python.h>

static inline void gpuAssert(CUresult code, const char *file, int line) {
if (code != CUDA_SUCCESS) {
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyErr_SetString(PyExc_RuntimeError, err);
}
}

#define CUDA_CHECK(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
if (PyErr_Occurred()) \
return NULL; \
}

static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int device_id;
if (!PyArg_ParseTuple(args, "i", &device_id))
return NULL;
// Get device handle
CUdevice device;
cuDeviceGet(&device, device_id);

// create a struct to hold device properties
int max_shared_mem;
int multiprocessor_count;
int sm_clock_rate;
int mem_clock_rate;
int mem_bus_width;
CUDA_CHECK(cuDeviceGetAttribute(
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));

return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
max_shared_mem, "multiprocessor_count",
multiprocessor_count, "sm_clock_rate", sm_clock_rate,
"mem_clock_rate", mem_clock_rate, "mem_bus_width",
mem_bus_width);
}

static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name;
const char *data;
Py_ssize_t data_size;
int shared;
int device;
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
&device)) {
return NULL;
}
CUfunction fun;
CUmodule mod;
int32_t n_regs = 0;
int32_t n_spills = 0;
// create driver handles
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
// get allocated registers and spilled registers from the function
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
CUDA_CHECK(cuDeviceGetAttribute(
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
if (shared > 49152 && shared_optin > 49152) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total, shared_static;
CUDA_CHECK(cuDeviceGetAttribute(
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device));
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}

if (PyErr_Occurred()) {
return NULL;
}
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
n_spills);
}

static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS,
"Load provided cubin into CUDA driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS,
"Get the properties for a given device"},
{NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
NULL, // documentation
-1, // size
ModuleMethods};

PyMODINIT_FUNC PyInit_cuda_utils(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if (m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}
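
If this extension is built and importable (the build glue lives in other changed files not shown here), its two entry points are used from Python exactly as in the compiler.py hunk above. A usage sketch under those assumptions, with a kernel already compiled to a cubin:

# Usage sketch under stated assumptions; mirrors the calls made in compiler.py.
import cuda_utils  # module name from PyInit_cuda_utils above; its availability here is assumed

device = 0                                 # hypothetical device index
shared = 48 * 1024                         # dynamic shared memory the kernel needs (example value)
cubin = open("kernel.cubin", "rb").read()  # hypothetical path to a compiled kernel

props = cuda_utils.get_device_properties(device)
if shared > props["max_shared_mem"]:
    raise RuntimeError("kernel needs more shared memory than the device offers")

# load_binary returns (module handle, function handle, register count, spill count).
mod, func, n_regs, n_spills = cuda_utils.load_binary("kernel_name", cubin, shared, device)
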
120 changes: 120 additions & 0 deletions python/triton/runtime/backends/hip.c
@@ -0,0 +1,120 @@
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdio.h>
#include <stdlib.h>

static inline void gpuAssert(hipError_t code, const char *file, int line) {
{
if (code != HIP_SUCCESS) {
{
const char *prefix = "Triton Error [HIP]: ";
const char *str = hipGetErrorString(code);
char err[1024] = {0};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str);
PyErr_SetString(PyExc_RuntimeError, err);
}
}
}
}

#define HIP_CHECK(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
if (PyErr_Occurred()) \
return NULL; \
}

static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int device_id;
if (!PyArg_ParseTuple(args, "i", &device_id))
return NULL;

hipDeviceProp_t props;
HIP_CHECK(hipGetDeviceProperties(&props, device_id));

// create a struct to hold device properties
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
props.sharedMemPerBlock, "multiprocessor_count",
props.multiProcessorCount, "sm_clock_rate",
props.clockRate, "mem_clock_rate", props.memoryClockRate,
"mem_bus_width", props.memoryBusWidth);
}

static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name;
const char *data;
Py_ssize_t data_size;
int shared;
int device;
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
&device)) {
return NULL;
}

// Open HSACO file
FILE *hsaco_file;
if ((hsaco_file = fopen(data, "rb")) == NULL) {
return NULL;
}

// Read HSCAO file into Buffer
fseek(hsaco_file, 0L, SEEK_END);
size_t hsaco_file_size = ftell(hsaco_file);
unsigned char *hsaco =
(unsigned char *)malloc(hsaco_file_size * sizeof(unsigned char));
rewind(hsaco_file);
fread(hsaco, sizeof(unsigned char), hsaco_file_size, hsaco_file);
fclose(hsaco_file);

// set HIP options
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
(void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};

// launch HIP Binary
hipModule_t mod;
hipFunction_t fun;
hipModuleLoadDataEx(&mod, hsaco, 5, opt, optval);
hipModuleGetFunction(&fun, mod, name);
free(hsaco);

// get allocated registers and spilled registers from the function
int n_regs = 0;
int n_spills = 0;
if (PyErr_Occurred()) {
return NULL;
}
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
n_spills);
}

static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS,
"Load provided hsaco into HIP driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS,
"Get the properties for a given device"},
{NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
NULL, // documentation
-1, // size
ModuleMethods};

PyMODINIT_FUNC PyInit_hip_utils(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if (m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}
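
Note the asymmetry between the two backends, which is what the bin_path lookup in the compiler.py hunk above accounts for: cuda_utils.load_binary receives the cubin contents directly (asm["cubin"]) and passes them to cuModuleLoadData, while hip_utils.load_binary receives a file path (asm["hsaco_path"]), reads the HSACO from disk with fopen/fread, and hands the buffer to hipModuleLoadDataEx. The HIP version also returns n_regs and n_spills as zero rather than querying the function attributes.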