Skip to content

Commit

Permalink
[PASS] Add GPU IR verifier (apache#1296)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored and tqchen committed Jun 23, 2018
1 parent f216b25 commit 531bb7c
Show file tree
Hide file tree
Showing 12 changed files with 410 additions and 3 deletions.
24 changes: 24 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,30 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
*/
bool VerifyMemory(LoweredFunc func, int device_type);


/*!
* \brief Verify the correctness of a GPU code.
* It checks whether the amount of memory usage or the number of threads
* in a block exceeds the given limits.
* \param stmt The statement to be checked
* \param constraints The dict that specifies the constraints to check.
* Possible keys are
*
* "max_local_memory_per_block": Total amount of local memory per block (in bytes).
* "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
* "max_thread_per_block": Maximum number of threads per block.
* "max_thread_x": Maximum length of threadIdx.x.
* "max_thread_y": Maximum length of threadIdx.y.
* "max_thread_z": Maximum length of threadIdx.z.
*
* If a key is missing from this argument, the pass does not check that item.
* \return Whether the statement is valid GPU code under the given constraints
*
*/
bool VerifyGPUCode(Stmt stmt,
Map<std::string, Expr> constraints);


} // namespace ir
} // namespace tvm

Expand Down
3 changes: 2 additions & 1 deletion include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ enum DeviceAttrKind : int {
kComputeVersion = 4,
kDeviceName = 5,
kMaxClockRate = 6,
kMultiProcessorCount = 7
kMultiProcessorCount = 7,
kMaxThreadDimensions = 8
};

/*! \brief Number of bytes each allocation must align to */
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import

import ctypes
import json
import numpy as np
from .base import _LIB, check_call
from .. import _api_internal
Expand Down Expand Up @@ -178,6 +179,18 @@ def multi_processor_count(self):
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 7)

@property
def max_thread_dimensions(self):
"""Return the maximum size of each thread axis.

Returns
-------
dims: List of int
The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
"""
# Attribute kind 8 is kMaxThreadDimensions; the device API encodes the
# three limits as a JSON list string, which is decoded here.
# NOTE(review): backends that return without setting a value presumably
# yield None here, which json.loads would reject -- confirm.
return json.loads(_api_internal._GetDeviceAttr(
self.device_type, self.device_id, 8))

def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,6 @@ REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode);
} // namespace ir
} // namespace tvm
166 changes: 166 additions & 0 deletions src/pass/verify_gpu_code.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*!
* Copyright (c) 2018 by Contributors
* \file verify_gpu_code.cc
* \brief Verify the correctness of a GPU IR.
* It will check whether the amount of memory usage or the number of threads
* in a block exceeds the limit
*/

#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>

namespace tvm {
namespace ir {

class GPUCodeVerifier : public IRVisitor {
public:
bool Verify(tvm::Stmt stmt,
int64_t max_local_memory_per_block,
int64_t max_shared_memory_per_block,
int64_t max_thread_per_block,
int64_t max_thread_x,
int64_t max_thread_y,
int64_t max_thread_z) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_thread_per_block_ = static_cast<size_t>(max_thread_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z);

Reset_();

this->Visit(stmt);

return valid_;
}

void Visit_(const ProducerConsumer *op) {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
}

if (op->is_producer) {
nest_level_++;
IRVisitor::Visit_(op);
nest_level_--;
} else {
IRVisitor::Visit_(op);
}

if (nest_level_ == 0) {
// exit a kernel, check the validity
valid_ &= thread_per_block_ <= max_thread_per_block_;

valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
}
}

void Visit_(const Allocate *op) {
IRVisitor::Visit_(op);
// visit an allocation of a buffer in shared memory, record its size
if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
size_t size = static_cast<size_t>(op->constant_allocation_size());
local_memory_per_block_ += size * op->type.bytes();
} else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) {
size_t size = static_cast<size_t>(op->constant_allocation_size());
shared_memory_per_block_ += size * op->type.bytes();
}
}

void Visit_(const AttrStmt *op) {
if (op->attr_key == attr::storage_scope) {
if (op->value.as<StringImm>()->value == "local") {
visited_local_buffers_.insert(op->node.as<tvm::Variable>());
} else if (op->value.as<StringImm>()->value == "shared") {
visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
}
} else if (op->attr_key == attr::thread_extent) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
const auto *extent = op->value.as<IntImm>();
CHECK(extent);

// record the number of threads in a block
std::string name = var.get()->name_hint;
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
if (!visited_threads_.count(name)) {
visited_threads_.insert(name);
size_t length = static_cast<size_t>(extent->value);
thread_per_block_ *= length;

if (name == "threadIdx.x") {
valid_ &= length <= max_thread_x_;
} else if (name == "threadIdx.y") {
valid_ &= length <= max_thread_y_;
} else if (name == "threadIdx.z") {
valid_ &= length <= max_thread_z_;
}
}
}
}
IRVisitor::Visit_(op);
}

private:
int nest_level_{0};

std::unordered_set<const tvm::Variable *> visited_local_buffers_;
std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
std::unordered_set<std::string> visited_threads_;

size_t local_memory_per_block_;
size_t shared_memory_per_block_;
size_t thread_per_block_;

size_t max_local_memory_per_block_;
size_t max_shared_memory_per_block_;
size_t max_thread_per_block_;
size_t max_thread_x_, max_thread_y_, max_thread_z_;

bool valid_{true};

void Reset_() {
visited_local_buffers_.clear();
visited_shared_buffers_.clear();
local_memory_per_block_ = 0;
shared_memory_per_block_ = 0;

visited_threads_.clear();
thread_per_block_ = 1;
}
};

/*!
 * \brief Verify that a GPU statement respects the given resource constraints.
 * \param stmt The statement to be checked.
 * \param constraints Map from constraint name to its integer limit.
 *  Recognized keys are "max_local_memory_per_block",
 *  "max_shared_memory_per_block", "max_thread_per_block", "max_thread_x",
 *  "max_thread_y" and "max_thread_z". A missing key defaults to INT64_MAX,
 *  i.e. that item is effectively not checked.
 * \return Whether the statement satisfies every supplied constraint.
 */
bool VerifyGPUCode(Stmt stmt,
                   Map<std::string, Expr> constraints) {
  GPUCodeVerifier verifier;

  // Read one integer constraint, falling back to `def` when the key is absent.
  // A value that is present but is not an integer constant is reported
  // explicitly rather than dereferenced as a null pointer.
  auto get_int = [&constraints](const std::string &key, int64_t def) -> int64_t {
    auto iter = constraints.find(key);
    if (iter == constraints.end()) {
      return def;
    }
    const auto *imm = ((*iter).second).as<IntImm>();
    CHECK(imm != nullptr)
        << "Constraint \"" << key << "\" must be an integer constant";
    return imm->value;
  };

  int64_t max_local_memory_per_block = get_int("max_local_memory_per_block", INT64_MAX);
  int64_t max_shared_memory_per_block = get_int("max_shared_memory_per_block", INT64_MAX);
  int64_t max_thread_per_block = get_int("max_thread_per_block", INT64_MAX);
  int64_t max_thread_x = get_int("max_thread_x", INT64_MAX);
  int64_t max_thread_y = get_int("max_thread_y", INT64_MAX);
  int64_t max_thread_z = get_int("max_thread_z", INT64_MAX);

  return verifier.Verify(stmt,
                         max_local_memory_per_block,
                         max_shared_memory_per_block,
                         max_thread_per_block,
                         max_thread_x,
                         max_thread_y,
                         max_thread_z);
}

} // namespace ir
} // namespace tvm
18 changes: 17 additions & 1 deletion src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
*/
#include <tvm/runtime/device_api.h>

#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <cuda_runtime.h>
#include <tvm/container.h>
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include "./cuda_common.h"

namespace tvm {
Expand Down Expand Up @@ -70,6 +72,20 @@ class CUDADeviceAPI final : public DeviceAPI {
&value, cudaDevAttrMultiProcessorCount, ctx.device_id));
break;
}
case kMaxThreadDimensions: {
int dims[3];
CUDA_CALL(cudaDeviceGetAttribute(
&dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));
CUDA_CALL(cudaDeviceGetAttribute(
&dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));
CUDA_CALL(cudaDeviceGetAttribute(
&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));

std::stringstream ss; // use json string to return multiple int values;
ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
return;
}
}
*rv = value;
}
Expand Down
1 change: 1 addition & 0 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
case kExist: break;
}
}
Expand Down
15 changes: 14 additions & 1 deletion src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
*/
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
#include <tvm/container.h>
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include "./opencl_common.h"

namespace tvm {
Expand All @@ -30,6 +33,7 @@ void OpenCLWorkspace::GetAttr(
CHECK_LT(index, devices.size())
<< "Invalid device id " << index;
switch (kind) {
case kExist: break;
case kMaxThreadsPerBlock: {
size_t value;
OPENCL_CALL(clGetDeviceInfo(
Expand Down Expand Up @@ -80,7 +84,16 @@ void OpenCLWorkspace::GetAttr(
*rv = static_cast<int32_t>(value);
break;
}
case kExist: break;
case kMaxThreadDimensions: {
size_t dims[3];
OPENCL_CALL(clGetDeviceInfo(
devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr));

std::stringstream ss; // use json string to return multiple int values;
ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
break;
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/runtime/opengl/opengl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ void OpenGLWorkspace::GetAttr(
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
}
}

Expand Down
1 change: 1 addition & 0 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
}
*rv = value;
}
Expand Down
1 change: 1 addition & 0 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ void VulkanWorkspace::GetAttr(
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kExist: break;
case kMaxThreadDimensions: break;
}
}

Expand Down
Loading

0 comments on commit 531bb7c

Please sign in to comment.