Skip to content

Commit

Permalink
[CODEGEN] Fix code generation bugs for cuda & Improve verification pa…
Browse files Browse the repository at this point in the history
…ss for cuda
  • Loading branch information
merrymercy committed Jul 13, 2020
1 parent 9f7745e commit a877f3d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
19 changes: 10 additions & 9 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -629,12 +629,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
this->PrintExpr(op->args[0], os);
os << " == NULL)";
} else if (op->op.same_as(builtin::reinterpret())) {
// generate (*( TYPE *)(&(ARG)))
int ssa_scope = BeginScope();
std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
os << "(*(";
this->PrintType(op->dtype, os);
os << " *)(&(";
this->PrintExpr(op->args[0], os);
os << ")))";
os << " *)(&(" << rhs << ")))";
EndScope(ssa_scope);
} else if (op->op.same_as(builtin::isnan())) {
os << "(";
this->PrintExpr(op->args[0], os);
Expand Down Expand Up @@ -720,14 +720,15 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
} else {
CHECK(is_one(op->predicate)) << "Predicated store is not supported";
arith::PVar<PrimExpr> base;

// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();

if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
} else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();

// store elements seperately
std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype());
std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
Expand All @@ -754,8 +755,8 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
PrintVecElemLoad(value, op->value.dtype(), i, stream);
stream << ";\n";
}
EndScope(vec_scope);
}
EndScope(vec_scope);
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ class GPUCodeVerifier : public StmtExprVisitor {
ExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const StoreNode* op) {
// Currently not able to check out: If the index expression failed
// to be simplified to a RampNode
if (op->index->IsInstance<RampNode>()) {
if (op->index->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes())
<= max_vector_bytes_;
}
}
StmtVisitor::VisitStmt_(op);
}

private:
int nest_level_{0};

Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_target_codegen_c_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_reinterpret():
nn = 1024
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name='A', dtype="int32")
B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", A(*i)), name='B')
B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", 2 + A(*i)), name='B')
s = te.create_schedule(B.op)

def check_c():
Expand All @@ -114,7 +114,7 @@ def check_c():
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fadd(a, b)
tvm.testing.assert_allclose(
b.asnumpy(), a.asnumpy().view('float32'))
b.asnumpy(), (2 + a.asnumpy()).view('float32'))
check_c()


Expand Down

0 comments on commit a877f3d

Please sign in to comment.