
[PASS] Add GPU IR verifier #1296

Merged
merged 10 commits into apache:master from the cuda-verifier branch on Jun 23, 2018

Conversation

merrymercy (Member)

Add a pass to check whether a CUDA IR is valid.

merrymercy changed the title from [PASS] Add cuda verifier to [PASS] Add CUDA IR verifier on Jun 17, 2018
merrymercy (Member, Author)

The CI failed with:
/workspace/src/arithmetic/modular.cc:168:1: fatal error: error writing to /tmp/cc30MweO.s: No space left on device

 * \return valid Whether it is a valid cuda ir
 *
 */
bool VerifyCuda(Stmt stmt,

Review comment (Member): Is this CUDA-specific, or should we call it VerifyGPUCode?

tqchen added the status: need update label on Jun 18, 2018
tqchen (Member) commented Jun 20, 2018

@merrymercy please act on the comments and fix the CI error.


class GPUCodeVerifier : public IRVisitor {
 public:
  bool verify(tvm::Stmt stmt, int max_shared_memory_per_block, int max_thread_per_block) {

Review comment (Member): Use CamelCase for functions.

if (shared_buffers_.count(op->buffer_var.get()) != 0) {
  int64_t size = op->type.bytes();
  for (auto dim : op->extents) {
    size *= dim.as<IntImm>()->value;

Review comment (Member): Use op->constant_allocation_size() instead.
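
For reference, the manual loop above computes the allocation's byte size as the element size times the product of the extents. A small Python illustration of that arithmetic (the dtype and shape are made up for the example):

import numpy as np

# Byte size of a constant-shaped allocation: element size times each extent,
# mirroring the manual loop in the snippet above.
def alloc_bytes(dtype, extents):
    size = np.dtype(dtype).itemsize
    for dim in extents:
        size *= dim
    return size

# A float32 shared buffer of shape (32, 33) uses 32 * 33 * 4 = 4224 bytes.
assert alloc_bytes("float32", (32, 33)) == 4224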

// record the number of threads in a block
std::string name = var.get()->name_hint;
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
  if (visited_threads_.find(name) == visited_threads_.end()) {

Review comment (Member): Prefer !visited_threads_.count(name) over find() == end().

size_t max_shared_memory_per_block_;
size_t max_thread_per_block_;

bool valid{true};

Review comment (Member): Name the member valid_ (with a trailing underscore).


bool valid{true};

void reset_() {

Review comment (Member): Rename reset_() to Reset().


Review comment (Member): If this visitor is only used once, reset is not necessary.


merrymercy (Member, Author) replied Jun 21, 2018: Reset is needed because there might be several GPU kernels in one Stmt.

}

if (op->is_producer) {
  nest_level_++;

Review comment (Member): Prefer pre-increment: ++nest_level_.

tqchen (Member) commented Jun 20, 2018

@eqy can you also do a round of code review?

eqy (Contributor) commented Jun 20, 2018

@tqchen @merrymercy If we also save the number of threads per dimension (x, y, z), perhaps we can also use this to catch CL_INVALID_WORK_ITEM_SIZE (e.g., in https://www.khronos.org/registry/OpenCL/sdk/2.0/docs/man/xhtml/clEnqueueNDRangeKernel.html) for OpenCL kernels. But I am not sure whether this can happen in CUDA land, so we could also use a separate verifier pass there.

EDIT: It does seem that CUDA devices can have a similar limit, which can be read with deviceQuery, e.g., Max dimension size of a thread block (x, y, z): (1024, 1024, 64).
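
A minimal sketch of the per-dimension check described above (not code from this PR); the axis names mirror the IR thread tags, while the default limits and the helper name are illustrative assumptions:

# Hypothetical check: compare a kernel's thread extents per dimension against
# device limits such as CUDA's (1024, 1024, 64) reported by deviceQuery.
def check_thread_dims(extents, limits=(1024, 1024, 64)):
    # extents: e.g. {"threadIdx.x": 256, "threadIdx.y": 8, "threadIdx.z": 1}
    axes = ("threadIdx.x", "threadIdx.y", "threadIdx.z")
    for axis, limit in zip(axes, limits):
        if extents.get(axis, 1) > limit:
            return False  # would hit CL_INVALID_WORK_ITEM_SIZE or a CUDA launch failure
    return True

assert check_thread_dims({"threadIdx.x": 1024, "threadIdx.y": 1, "threadIdx.z": 1})
assert not check_thread_dims({"threadIdx.x": 1, "threadIdx.y": 1, "threadIdx.z": 128})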

tqchen (Member) commented Jun 20, 2018

It would be a good idea to pass in as many constraints as possible and allow defaults for unspecified constraints. One possible way to do so is to allow passing in a Map<str, value> so we don't have a lot of positional arguments.
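
A rough sketch of the dict-of-constraints calling convention suggested here; the function name, keys, and defaults below are assumptions for illustration, not the exact API this PR exposes:

# Hypothetical calling convention: unspecified keys fall back to "no constraint".
DEFAULT_CONSTRAINTS = {
    "max_shared_memory_per_block": None,  # None means unconstrained
    "max_threads_per_block": None,
    "max_thread_x": None,
    "max_thread_y": None,
    "max_thread_z": None,
}

def verify_gpu_code(stmt, constraints):
    # Merge caller-supplied limits over the defaults; a real pass would then
    # walk `stmt` and check every non-None limit.
    unknown = set(constraints) - set(DEFAULT_CONSTRAINTS)
    if unknown:
        raise ValueError("unknown constraint(s): %s" % sorted(unknown))
    return {**DEFAULT_CONSTRAINTS, **constraints}

# Callers only name the limits they care about:
# verify_gpu_code(stmt, {"max_shared_memory_per_block": 48 * 1024,
#                        "max_threads_per_block": 1024})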


std::unordered_set<const tvm::Variable *> shared_buffers_;
std::unordered_set<std::string> visited_threads_;
size_t shared_memory_per_block_;

Review comment (Contributor): local_memory_per_block_ is also needed.

merrymercy changed the title from [PASS] Add CUDA IR verifier to [PASS] Add GPU IR verifier on Jun 21, 2018
merrymercy (Member, Author)

Ready for review.

eqy (Contributor) commented Jun 22, 2018

Is the plan to skip checking threadIdx/work-item dimensions in this round?

tqchen (Member) commented Jun 22, 2018

Please fix the compiler warning: http://mode-gpu.cs.washington.edu:8080/blue/organizations/jenkins/dmlc%2Ftvm/detail/PR-1296/11/pipeline

Currently we treat compiler warnings as errors, so the build won't pass if there is a warning.

merrymercy (Member, Author) commented Jun 23, 2018

@eqy your comment is addressed.
@tqchen no space left on device on CI again.

"""Test gpu code verifier"""
import tvm

global valid

Review comment (Member): Always avoid using global variables; to carry state, you can use a closure that captures a list instead.
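
A small sketch of the closure-over-list pattern suggested above (illustrative only; not the test's actual code):

# Instead of `global valid`, capture a one-element list in a closure and
# mutate it from inside the callback.
def make_recorder():
    result = [None]  # mutable cell captured by the closure

    def record(is_valid):
        result[0] = is_valid  # no `global` statement needed

    return result, record

result, record = make_recorder()
record(True)  # e.g. called from the lowering callback that runs the verifier
assert result[0] is True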

tqchen (Member) commented Jun 23, 2018

The test error likely indicates there is some problem with the current PR when importing the runtime-only DLL.

tqchen merged commit e66996a into apache:master on Jun 23, 2018
tqchen added the status: accepted label and removed the status: need update label on Jun 23, 2018
tqchen pushed a commit to tqchen/tvm that referenced this pull request on Jul 6, 2018
merrymercy deleted the cuda-verifier branch on July 10, 2018
mnuyens pushed a commit to mnuyens/tvm that referenced this pull request on Jul 10, 2018
sergei-mironov pushed a commit to sergei-mironov/tvm that referenced this pull request on Aug 8, 2018
4 participants