diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 7c3c8309e1150..153089204576c 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -629,12 +629,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) this->PrintExpr(op->args[0], os); os << " == NULL)"; } else if (op->op.same_as(builtin::reinterpret())) { - // generate (*( TYPE *)(&(ARG))) + int ssa_scope = BeginScope(); + std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype); os << "(*("; this->PrintType(op->dtype, os); - os << " *)(&("; - this->PrintExpr(op->args[0], os); - os << ")))"; + os << " *)(&(" << rhs << ")))"; + EndScope(ssa_scope); } else if (op->op.same_as(builtin::isnan())) { os << "("; this->PrintExpr(op->args[0], os); @@ -720,14 +720,15 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { } else { CHECK(is_one(op->predicate)) << "Predicated store is not supported"; arith::PVar base; + + // The assignment below introduces side-effect, and the resulting value cannot + // be reused across multiple expression, thus a new scope is needed + int vec_scope = BeginScope(); + if (arith::ramp(base, 1, t.lanes()).Match(op->index)) { std::string value = this->PrintExpr(op->value); this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value); } else { - // The assignment below introduces side-effect, and the resulting value cannot - // be reused across multiple expression, thus a new scope is needed - int vec_scope = BeginScope(); - // store elements seperately std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype()); std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype()); @@ -754,8 +755,8 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { PrintVecElemLoad(value, op->value.dtype(), i, stream); stream << ";\n"; } - EndScope(vec_scope); } + EndScope(vec_scope); } } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 9477e044fc336..4d4cfcfb3640c 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -145,6 +145,18 @@ class GPUCodeVerifier : public StmtExprVisitor { ExprVisitor::VisitExpr_(op); } + void VisitStmt_(const StoreNode* op) { + // Currently not able to check out: If the index expression failed + // to be simplified to a RampNode + if (op->index->IsInstance()) { + if (op->index->dtype.lanes() > 1) { + valid_ &= static_cast(op->index->dtype.lanes() * op->index->dtype.bytes()) + <= max_vector_bytes_; + } + } + StmtVisitor::VisitStmt_(op); + } + private: int nest_level_{0};