Skip to content

Commit

Permalink
[CODEGEN] Fix code generation bugs for cuda & Improve verification pa…
Browse files Browse the repository at this point in the history
…ss for cuda
  • Loading branch information
merrymercy committed Jul 13, 2020
1 parent 9f7745e commit 993c6fe
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -629,12 +629,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
this->PrintExpr(op->args[0], os);
os << " == NULL)";
} else if (op->op.same_as(builtin::reinterpret())) {
// generate (*( TYPE *)(&(ARG)))
int ssa_scope = BeginScope();
std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
os << "(*(";
this->PrintType(op->dtype, os);
os << " *)(&(";
this->PrintExpr(op->args[0], os);
os << ")))";
os << " *)(&(" << rhs << ")))";
EndScope(ssa_scope);
} else if (op->op.same_as(builtin::isnan())) {
os << "(";
this->PrintExpr(op->args[0], os);
Expand Down Expand Up @@ -720,14 +720,15 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
} else {
CHECK(is_one(op->predicate)) << "Predicated store is not supported";
arith::PVar<PrimExpr> base;

// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();

if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
} else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();

// store elements separately
std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype());
std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
Expand All @@ -754,8 +755,8 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
PrintVecElemLoad(value, op->value.dtype(), i, stream);
stream << ";\n";
}
EndScope(vec_scope);
}
EndScope(vec_scope);
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ class GPUCodeVerifier : public StmtExprVisitor {
ExprVisitor::VisitExpr_(op);
}

/*!
 * \brief Check the vector width of a store statement, then continue visiting.
 *
 * When the store index is a RampNode with more than one lane, the product
 * lanes * bytes of the index dtype is compared against max_vector_bytes_
 * and the result is folded into valid_. The statement is always forwarded
 * to StmtVisitor::VisitStmt_ afterwards.
 *
 * \param op The store statement being visited.
 */
void VisitStmt_(const StoreNode* op) {
  // Known limitation: a store whose index expression was not simplified
  // into a RampNode is not checked here.
  // NOTE(review): the byte count is derived from the index dtype, not from
  // the dtype of the stored value -- confirm this is the intended quantity.
  const auto& index_dtype = op->index->dtype;
  const bool is_vector_ramp = op->index->IsInstance<RampNode>() && index_dtype.lanes() > 1;
  if (is_vector_ramp) {
    const auto vector_bytes = static_cast<size_t>(index_dtype.lanes() * index_dtype.bytes());
    valid_ &= vector_bytes <= max_vector_bytes_;
  }
  StmtVisitor::VisitStmt_(op);
}

private:
int nest_level_{0};

Expand Down

0 comments on commit 993c6fe

Please sign in to comment.