From 531bb7c422c621c6dd726ca34cefd1fceb053f7b Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 24 Jun 2018 04:21:54 +0800 Subject: [PATCH] [PASS] Add GPU IR verifier (#1296) --- include/tvm/ir_pass.h | 24 +++ include/tvm/runtime/device_api.h | 3 +- python/tvm/_ffi/runtime_ctypes.py | 13 ++ src/api/api_pass.cc | 1 + src/pass/verify_gpu_code.cc | 166 +++++++++++++++++ src/runtime/cuda/cuda_device_api.cc | 18 +- src/runtime/metal/metal_device_api.mm | 1 + src/runtime/opencl/opencl_device_api.cc | 15 +- src/runtime/opengl/opengl_device_api.cc | 1 + src/runtime/rocm/rocm_device_api.cc | 1 + src/runtime/vulkan/vulkan_device_api.cc | 1 + .../unittest/test_pass_verify_gpu_code.py | 169 ++++++++++++++++++ 12 files changed, 410 insertions(+), 3 deletions(-) create mode 100644 src/pass/verify_gpu_code.cc create mode 100644 tests/python/unittest/test_pass_verify_gpu_code.py diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 4be1643429ea..f6c6334c88b6 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -477,6 +477,30 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); */ bool VerifyMemory(LoweredFunc func, int device_type); + +/*! + * \brief Verify the correctness of a GPU code + * It will check the whether the amount of memory usage or the number of threads + * in a block exceeds the limit + * \param stmt The statement to be checked + * \param constraints The dict to specify constraints to check. + * Possible keys are + * + * "max_local_memory_per_block": Total amount of local memory per block (in bytes). + * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). + * "max_thread_per_block": Maximum number of threads per block. + * "max_thread_x": Maximum length of threadIdx.x. + * "max_thread_y": Maximum length of threadIdx.y. + * "max_thread_z": Maximum length of threadIdx.z. + * + * If one key is missing in this argument, the pass won't check for that item. + * \return valid Whether it is a valid GPU code + * + */ +bool VerifyGPUCode(Stmt stmt, + Map constraints); + + } // namespace ir } // namespace tvm diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 82128fc58d47..3458c143e662 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -23,7 +23,8 @@ enum DeviceAttrKind : int { kComputeVersion = 4, kDeviceName = 5, kMaxClockRate = 6, - kMultiProcessorCount = 7 + kMultiProcessorCount = 7, + kMaxThreadDimensions = 8 }; /*! \brief Number of bytes each allocation must align to */ diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 3fc020c8781b..992afa4990c3 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -3,6 +3,7 @@ from __future__ import absolute_import import ctypes +import json import numpy as np from .base import _LIB, check_call from .. import _api_internal @@ -178,6 +179,18 @@ def multi_processor_count(self): return _api_internal._GetDeviceAttr( self.device_type, self.device_id, 7) + @property + def max_thread_dimensions(self): + """Return the maximum size of each thread axis + + Returns + ------- + dims: List of int + The maximum length of threadIdx.x, threadIdx.y, threadIdx.z + """ + return json.loads(_api_internal._GetDeviceAttr( + self.device_type, self.device_id, 8)) + def sync(self): """Synchronize until jobs finished at the context.""" check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None)) diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 64c922559229..adfe198ebf54 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -131,5 +131,6 @@ REGISTER_PASS2(LowerIntrin); REGISTER_PASS1(LowerTVMBuiltin); REGISTER_PASS1(CombineContextCall); REGISTER_PASS2(VerifyMemory); +REGISTER_PASS2(VerifyGPUCode); } // namespace ir } // namespace tvm diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc new file mode 100644 index 000000000000..aa1762221d8d --- /dev/null +++ b/src/pass/verify_gpu_code.cc @@ -0,0 +1,166 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file verify_gpu_code.cc + * \brief Verify the correctness of a GPU IR. + * It will check the whether the amount of memory usage or the number of threads + * in a block exceeds the limit + */ + +#include +#include +#include + +namespace tvm { +namespace ir { + +class GPUCodeVerifier : public IRVisitor { + public: + bool Verify(tvm::Stmt stmt, + int64_t max_local_memory_per_block, + int64_t max_shared_memory_per_block, + int64_t max_thread_per_block, + int64_t max_thread_x, + int64_t max_thread_y, + int64_t max_thread_z) { + max_local_memory_per_block_ = static_cast(max_local_memory_per_block); + max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); + max_thread_per_block_ = static_cast(max_thread_per_block); + max_thread_x_ = static_cast(max_thread_x); + max_thread_y_ = static_cast(max_thread_y); + max_thread_z_ = static_cast(max_thread_z); + + Reset_(); + + this->Visit(stmt); + + return valid_; + } + + void Visit_(const ProducerConsumer *op) { + if (nest_level_ == 0) { + // enter a new kernel, reset statistics + Reset_(); + } + + if (op->is_producer) { + nest_level_++; + IRVisitor::Visit_(op); + nest_level_--; + } else { + IRVisitor::Visit_(op); + } + + if (nest_level_ == 0) { + // exit a kernel, check the validity + valid_ &= thread_per_block_ <= max_thread_per_block_; + + valid_ &= local_memory_per_block_ <= max_local_memory_per_block_; + valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_; + } + } + + void Visit_(const Allocate *op) { + IRVisitor::Visit_(op); + // visit an allocation of a buffer in shared memory, record its size + if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { + size_t size = static_cast(op->constant_allocation_size()); + local_memory_per_block_ += size * op->type.bytes(); + } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) { + size_t size = static_cast(op->constant_allocation_size()); + shared_memory_per_block_ += size * op->type.bytes(); + } + } + + void Visit_(const AttrStmt *op) { + if (op->attr_key == attr::storage_scope) { + if (op->value.as()->value == "local") { + visited_local_buffers_.insert(op->node.as()); + } else if (op->value.as()->value == "shared") { + visited_shared_buffers_.insert(op->node.as()); + } + } else if (op->attr_key == attr::thread_extent) { + VarExpr var = op->node.as()->var; + const auto *extent = op->value.as(); + CHECK(extent); + + // record the number of threads in a block + std::string name = var.get()->name_hint; + if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") { + if (!visited_threads_.count(name)) { + visited_threads_.insert(name); + size_t length = static_cast(extent->value); + thread_per_block_ *= length; + + if (name == "threadIdx.x") { + valid_ &= length <= max_thread_x_; + } else if (name == "threadIdx.y") { + valid_ &= length <= max_thread_y_; + } else if (name == "threadIdx.z") { + valid_ &= length <= max_thread_z_; + } + } + } + } + IRVisitor::Visit_(op); + } + + private: + int nest_level_{0}; + + std::unordered_set visited_local_buffers_; + std::unordered_set visited_shared_buffers_; + std::unordered_set visited_threads_; + + size_t local_memory_per_block_; + size_t shared_memory_per_block_; + size_t thread_per_block_; + + size_t max_local_memory_per_block_; + size_t max_shared_memory_per_block_; + size_t max_thread_per_block_; + size_t max_thread_x_, max_thread_y_, max_thread_z_; + + bool valid_{true}; + + void Reset_() { + visited_local_buffers_.clear(); + visited_shared_buffers_.clear(); + local_memory_per_block_ = 0; + shared_memory_per_block_ = 0; + + visited_threads_.clear(); + thread_per_block_ = 1; + } +}; + +bool VerifyGPUCode(Stmt stmt, + Map constraints) { + GPUCodeVerifier verifier; + + auto get_int = [&constraints](std::string key, int64_t def) { + auto iter = constraints.find(key); + if (iter != constraints.end()) { + return ((*iter).second).as()->value; + } else { + return def; + } + }; + + int64_t max_local_memory_per_block = get_int("max_local_memory_per_block", INT64_MAX); + int64_t max_shared_memory_per_block = get_int("max_shared_memory_per_block", INT64_MAX); + int64_t max_thread_per_block = get_int("max_thread_per_block", INT64_MAX); + int64_t max_thread_x = get_int("max_thread_x", INT64_MAX); + int64_t max_thread_y = get_int("max_thread_y", INT64_MAX); + int64_t max_thread_z = get_int("max_thread_z", INT64_MAX); + + return verifier.Verify(stmt, + max_local_memory_per_block, + max_shared_memory_per_block, + max_thread_per_block, + max_thread_x, + max_thread_y, + max_thread_z); +} + +} // namespace ir +} // namespace tvm diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 9f1ea15aeeae..4573f97ddd8a 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -5,10 +5,12 @@ */ #include -#include #include #include #include +#include +#include +#include #include "./cuda_common.h" namespace tvm { @@ -70,6 +72,20 @@ class CUDADeviceAPI final : public DeviceAPI { &value, cudaDevAttrMultiProcessorCount, ctx.device_id)); break; } + case kMaxThreadDimensions: { + int dims[3]; + CUDA_CALL(cudaDeviceGetAttribute( + &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute( + &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute( + &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id)); + + std::stringstream ss; // use json string to return multiple int values; + ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; + *rv = ss.str(); + return; + } } *rv = value; } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index fae1f0ded88f..47c2899cea71 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -42,6 +42,7 @@ case kDeviceName: return; case kMaxClockRate: return; case kMultiProcessorCount: return; + case kMaxThreadDimensions: return; case kExist: break; } } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index b30b058d3acc..8c43ec252a79 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -4,6 +4,9 @@ */ #include #include +#include +#include +#include #include "./opencl_common.h" namespace tvm { @@ -30,6 +33,7 @@ void OpenCLWorkspace::GetAttr( CHECK_LT(index, devices.size()) << "Invalid device id " << index; switch (kind) { + case kExist: break; case kMaxThreadsPerBlock: { size_t value; OPENCL_CALL(clGetDeviceInfo( @@ -80,7 +84,16 @@ void OpenCLWorkspace::GetAttr( *rv = static_cast(value); break; } - case kExist: break; + case kMaxThreadDimensions: { + size_t dims[3]; + OPENCL_CALL(clGetDeviceInfo( + devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr)); + + std::stringstream ss; // use json string to return multiple int values; + ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; + *rv = ss.str(); + break; + } } } diff --git a/src/runtime/opengl/opengl_device_api.cc b/src/runtime/opengl/opengl_device_api.cc index 4357e610f478..3a21ed6e6d07 100644 --- a/src/runtime/opengl/opengl_device_api.cc +++ b/src/runtime/opengl/opengl_device_api.cc @@ -97,6 +97,7 @@ void OpenGLWorkspace::GetAttr( case kDeviceName: return; case kMaxClockRate: return; case kMultiProcessorCount: return; + case kMaxThreadDimensions: return; } } diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 3438e8edeaec..6aff5e56c715 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -52,6 +52,7 @@ class ROCMDeviceAPI final : public DeviceAPI { case kDeviceName: return; case kMaxClockRate: return; case kMultiProcessorCount: return; + case kMaxThreadDimensions: return; } *rv = value; } diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 891b39456558..d0f7189d982a 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -73,6 +73,7 @@ void VulkanWorkspace::GetAttr( case kMaxClockRate: return; case kMultiProcessorCount: return; case kExist: break; + case kMaxThreadDimensions: break; } } diff --git a/tests/python/unittest/test_pass_verify_gpu_code.py b/tests/python/unittest/test_pass_verify_gpu_code.py new file mode 100644 index 000000000000..3059cbef0ed6 --- /dev/null +++ b/tests/python/unittest/test_pass_verify_gpu_code.py @@ -0,0 +1,169 @@ +"""Test gpu code verifier""" +import tvm + +def get_verify_pass(valid, **kwargs): + def verify_pass(stmt): + valid[0] = tvm.ir_pass.VerifyGPUCode(stmt, kwargs) + return stmt + return verify_pass + +def test_shared_memory(): + N = 1024 + M = 128 + + A = tvm.placeholder((N,), name='A', dtype='float32') + B = tvm.compute((N, ), lambda i: A[i], name='B') + + s = tvm.create_schedule([B.op]) + AA = s.cache_read(A, "shared", [B]) + o, i = s[B].split(s[B].op.axis[0], M) + s[AA].compute_at(s[B], o) + s[B].bind(o, tvm.thread_axis("blockIdx.x")) + s[B].bind(i, tvm.thread_axis("threadIdx.x")) + + # shared memory usage: M * 4B + # thread usage: M + + for target in ['opencl', 'cuda']: + if not tvm.context(target).exist: + continue + valid = [None] + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=4 * M - 1, + max_thread_per_block=M))]}): + tvm.build(s, [A, B], target) + assert not valid[0] + + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=4 * M, + max_thread_per_block=M))]}): + tvm.build(s, [A, B], target) + assert valid[0] + +def test_local_memory(): + N = 1024 + M = 128 + + A = tvm.placeholder((N,), name='A', dtype='float32') + B = tvm.compute((N, ), lambda i: A[i], name='B') + + s = tvm.create_schedule([B.op]) + AA = s.cache_read(A, "local", [B]) + o, i = s[B].split(s[B].op.axis[0], M) + s[AA].compute_at(s[B], o) + s[B].bind(o, tvm.thread_axis("blockIdx.x")) + + # local memory usage: M * 4B + # thread usage: M + + for target in ['opencl', 'cuda']: + if not tvm.context(target).exist: + continue + + valid = [None] + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_local_memory_per_block=4 * M - 1, + max_thread_per_block=1))]}): + tvm.build(s, [A, B], target) + assert not valid[0] + + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_local_memory_per_block=4 * M, + max_thread_per_block=1))]}): + tvm.build(s, [A, B], target) + assert valid[0] + +def test_num_thread(): + N = 1024 + M = 128 + + A = tvm.placeholder((N,), name='A', dtype='float32') + B = tvm.compute((N, ), lambda i: A[i], name='B') + + s = tvm.create_schedule([B.op]) + o, i = s[B].split(s[B].op.axis[0], M) + + s[B].bind(o, tvm.thread_axis('threadIdx.x')) + s[B].bind(i, tvm.thread_axis("threadIdx.y")) + + # shared memory usage: 0 + # thread usage: N + + for target in ['opencl', 'cuda']: + if not tvm.context(target).exist: + continue + + valid = [None] + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=0, + max_thread_per_block=N - 1))]}): + tvm.build(s, [A, B], target) + assert not valid[0] + + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=0, + max_thread_per_block=N))]}): + tvm.build(s, [A, B], target) + assert valid[0] + + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=0, + max_thread_per_block=N, + max_thread_y=M-1))]}): + tvm.build(s, [A, B], target) + assert not valid[0] + + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=0, + max_thread_per_block=N, + max_thread_y=M))]}): + tvm.build(s, [A, B], target) + assert valid[0] + +def test_multiple_kernels(): + N = 1024 + + A = tvm.placeholder((N, N), name='A') + B = tvm.compute((N, N), lambda i, j: A[i, j]) + C = tvm.compute((N, N), lambda i, j: B[i, j]) + + s = tvm.create_schedule([C.op]) + + s[C].bind(s[C].op.axis[1], tvm.thread_axis("threadIdx.x")) + s[B].bind(s[B].op.axis[1], tvm.thread_axis("threadIdx.x")) + + # shared memory usage: 0 + # thread usage: N + + for target in ['opencl', 'cuda']: + if not tvm.context(target).exist: + continue + + valid = [None] + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=0, + max_thread_per_block=N - 1))]}): + tvm.build(s, [A, C], target) + assert not valid[0] + + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=0, + max_thread_per_block=N))]}): + tvm.build(s, [A, C], target) + assert valid[0] + +if __name__ == "__main__": + test_local_memory() + test_shared_memory() + test_num_thread() + test_multiple_kernels()