-
Notifications
You must be signed in to change notification settings - Fork 3.5k
/
verify_gpu_code.cc
332 lines (291 loc) · 12.1 KB
/
verify_gpu_code.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file verify_gpu_code.cc
* \brief Verify the correctness of a GPU IR.
* It checks whether the memory usage, the number of threads in a block,
* vector widths, or the number of launched kernels exceed the configured limits.
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include "../../runtime/thread_storage_scope.h"
#include "../transforms/ir_utils.h"
namespace tvm {
namespace tir {
class GPUCodeVerifier : public StmtExprVisitor {
public:
std::vector<String> Verify(Stmt stmt, int64_t max_local_memory_per_block,
int64_t max_shared_memory_per_block, int64_t max_threads_per_block,
int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z,
int64_t max_vthread, int64_t max_vector_bytes, int64_t max_kernels) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z);
max_vthread_ = static_cast<size_t>(max_vthread);
max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);
max_kernels_ = static_cast<size_t>(max_kernels);
Reset_();
// TODO(jcf94): Add support of detecting CUDA Misaligned Address error
this->VisitStmt(stmt);
return errors_;
}
void VisitStmt_(const AllocateNode* op) final {
StmtVisitor::VisitStmt_(op);
auto scope = GetPtrStorageScope(op->buffer_var);
runtime::StorageScope storage_scope = runtime::StorageScope::Create(scope);
// visit an allocation of a buffer in shared memory, record its size
if (storage_scope.rank == runtime::StorageRank::kLocal) {
size_t size = static_cast<size_t>(op->ConstantAllocationSize());
local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
} else if (storage_scope.rank == runtime::StorageRank::kShared) {
size_t size = static_cast<size_t>(op->ConstantAllocationSize());
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
if (op->dtype.lanes() > 1) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
<< op->dtype.bytes() << ") for dtype " << op->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
}
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
kernels_launched_++;
}
Var var = op->node.as<IterVarNode>()->var;
const auto* extent = op->value.as<IntImmNode>();
ICHECK(extent);
std::string name = var.get()->name_hint;
// record the number of threads in a block
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" ||
name == "vthread") {
size_t length = static_cast<size_t>(extent->value);
if (!visited_threads_.count(name)) {
visited_threads_.insert(name);
thread_per_block_ *= length;
auto err = [this](std::string id, size_t ext, size_t m) {
if (ext > m) {
std::stringstream s;
s << "Extent of " << id << " (" << ext << ") is greater than maximum allowed (" << m
<< ");";
errors_.push_back(s.str());
}
};
if (name == "threadIdx.x") {
err("threadIdx.x", length, max_thread_x_);
thread_x_extent_ = length;
} else if (name == "threadIdx.y") {
err("threadIdx.y", length, max_thread_y_);
thread_y_extent_ = length;
} else if (name == "threadIdx.z") {
err("threadIdx.z", length, max_thread_z_);
thread_z_extent_ = length;
} else if (name == "vthread") {
err("vthread", length, max_vthread_);
}
} else {
// the thread should be bound to axes with the same length
auto err = [this, name](std::string id, size_t ext, size_t m) {
if (name == id && ext != m) {
std::stringstream s;
s << "Extent of " << id << " (" << ext << ") does not match the bound " << m;
errors_.push_back(s.str());
}
};
err("threadIdx.x", length, thread_x_extent_);
err("threadIdx.y", length, thread_y_extent_);
err("threadIdx.z", length, thread_z_extent_);
}
}
nest_level_++;
StmtVisitor::VisitStmt_(op);
nest_level_--;
if (nest_level_ == 0) {
// exit a kernel, check the validity
auto err = [this](std::string id, size_t num, size_t m) {
if (num > m) {
std::stringstream s;
s << "Used " << id << " (" << num << ") is greater than the allowed maximum (" << m
<< ")";
errors_.push_back(s.str());
}
};
err("threads per block", thread_per_block_, max_threads_per_block_);
err("local memory per block", local_memory_per_block_, max_local_memory_per_block_);
err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_);
if (kernels_launched_ > max_kernels_) {
std::stringstream s;
s << "Number of launched kernels (" << kernels_launched_
<< ") is greater than the allowed maximum (" << max_kernels_ << ")";
errors_.push_back(s.str());
}
}
} else {
StmtVisitor::VisitStmt_(op);
}
}
void VisitStmt_(const ForNode* op) {
if (op->loop_var->name_hint == "vthread.s") {
const auto* extent = op->extent.as<IntImmNode>();
ICHECK(extent);
size_t num_vthread = static_cast<size_t>(extent->value);
if (num_vthread > max_vthread_) {
std::stringstream s;
s << "Number of vthreads (" << num_vthread << ") is greater than the allowed maximum ("
<< max_vthread_ << ")";
errors_.push_back(s.str());
}
}
StmtVisitor::VisitStmt_(op);
}
void VisitExpr_(const LoadNode* op) final {
LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
}
void VisitStmt_(const StoreNode* op) final {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}
void VisitExpr_(const BufferLoadNode* op) {
if (op->dtype.lanes() > 1) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
<< op->dtype.bytes() << ") for dtype " << op->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
ExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) {
if (op->value->dtype.lanes() > 1) {
if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->value->dtype.lanes() << ") times number of bytes ("
<< op->value->dtype.bytes() << ") for dtype " << op->value->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
StmtVisitor::VisitStmt_(op);
}
private:
int nest_level_{0};
std::unordered_set<std::string> visited_threads_;
size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;
size_t local_memory_per_block_;
size_t shared_memory_per_block_;
size_t thread_per_block_;
size_t kernels_launched_{0};
size_t max_local_memory_per_block_;
size_t max_shared_memory_per_block_;
size_t max_threads_per_block_;
size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_;
size_t max_vector_bytes_;
size_t max_kernels_;
std::vector<String> errors_;
void Reset_() {
local_memory_per_block_ = 0;
shared_memory_per_block_ = 0;
visited_threads_.clear();
thread_per_block_ = 1;
}
};
std::vector<String> VerifyGPUCode_(const PrimFunc& func, Map<String, PrimExpr> constraints) {
GPUCodeVerifier verifier;
int64_t max_local_memory_per_block = INT64_MAX;
int64_t max_shared_memory_per_block = INT64_MAX;
int64_t max_threads_per_block = INT64_MAX;
int64_t max_thread_x = INT64_MAX;
int64_t max_thread_y = INT64_MAX;
int64_t max_thread_z = INT64_MAX;
int64_t max_vthread = INT64_MAX;
int64_t max_vector_bytes = INT64_MAX;
int64_t max_kernels = INT64_MAX;
for (auto iter : constraints) {
const IntImmNode* val = iter.second.as<IntImmNode>();
if (iter.first == "max_local_memory_per_block") {
max_local_memory_per_block = val->value;
} else if (iter.first == "max_shared_memory_per_block") {
max_shared_memory_per_block = val->value;
} else if (iter.first == "max_threads_per_block") {
max_threads_per_block = val->value;
} else if (iter.first == "max_thread_x") {
max_thread_x = val->value;
} else if (iter.first == "max_thread_y") {
max_thread_y = val->value;
} else if (iter.first == "max_thread_z") {
max_thread_z = val->value;
} else if (iter.first == "max_vthread") {
max_vthread = val->value;
} else if (iter.first == "max_vector_bytes") {
max_vector_bytes = val->value;
} else if (iter.first == "max_kernels") {
max_kernels = val->value;
} else {
LOG(FATAL) << "Invalid check item: " << iter.first;
}
}
return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
max_vthread, max_vector_bytes, max_kernels);
}
bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
auto errs = VerifyGPUCode_(func, constraints);
return errs.size() == 0;
}
TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
namespace transform {

/*!
 * \brief Build a module pass that verifies every PrimFunc in the module against
 *        the given GPU limits.
 *
 * Non-PrimFunc entries are skipped. The first PrimFunc with any violation aborts
 * via LOG(FATAL) with all of its violation messages and the offending function.
 * The module is returned unchanged otherwise.
 *
 * \param constraints Limit name -> integer value, forwarded to VerifyGPUCode_.
 * \return The verification pass ("tir.VerifyGPUCode", opt level 0).
 */
Pass VerifyGPUCode(Map<String, PrimExpr> constraints) {
  auto pass_func = [=](IRModule mod, PassContext ctx) {
    for (auto entry : mod->functions) {
      const auto* prim_func_node = entry.second.as<PrimFuncNode>();
      if (prim_func_node == nullptr) {
        continue;
      }
      PrimFunc prim_func = GetRef<PrimFunc>(prim_func_node);
      std::vector<String> violations = VerifyGPUCode_(prim_func, constraints);
      if (violations.empty()) {
        continue;
      }
      std::stringstream report;
      for (const auto& violation : violations) {
        report << " " << violation << '\n';
      }
      LOG(FATAL) << "RuntimeError: GPU constraint(s) violated:\n"
                 << report.str() << " In function\n"
                 << prim_func;
    }
    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {});
}

TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode);

}  // namespace transform
} // namespace tir
} // namespace tvm