Skip to content

Commit

Permalink
[CODEGEN][METAL] Fix unaligned vector load (#14332)
Browse files Browse the repository at this point in the history
This PR fixes the implementation of unaligned vector load.
Previously vector construction was printed as (float2)(v0, v1).
This will cause problem as C have comma expression, and (v0, v1) will be evaluated as v1.
The final result will become float2(v1, v1). The bug affects all codegen that uses
the default implementation, such as metal. We added a testcase on metal to cover this case.

Also updated codegen opencl to keep the old style as that is the convention opencl follows.
  • Loading branch information
tqchen authored Mar 20, 2023
1 parent 56d0e3b commit fc2a9e5
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 20 deletions.
15 changes: 9 additions & 6 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -802,15 +802,16 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
}

void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
// constraint of current logic
ICHECK_EQ(op->base.dtype(), DataType::Int(32));
os << "((int" << op->lanes << ")(";
// NOTE: C have comma expression so cannot use (int2)(v0, v1)
// instead should use int2(v0, v1)
PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != op->lanes - 1) os << ", ";
}
os << "))";
os << ")";
}

void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
Expand Down Expand Up @@ -999,9 +1000,11 @@ void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
}

if (i == 0) {
os << "((";
// NOTE: C have comma expression so cannot use (float2)(v0, v1)
// instead should use float2(v0, v1)
os << "(";
PrintType(t, os);
os << ")(";
os << "(";
}
os << value;
if (i != t.lanes() - 1) {
Expand Down
17 changes: 5 additions & 12 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,17 +299,6 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N
os << ')';
}

void CodeGenMetal::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
}
os << ')';
}

void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (op->op.same_as(builtin::reinterpret())) {
// generate as_type<TYPE>(ARG)
Expand Down Expand Up @@ -369,7 +358,11 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
code << fsource;
}

return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str());
std::string code_str = code.str();
if (const auto* f = Registry::Get("tvm_callback_metal_postproc")) {
code_str = (*f)(code_str).operator std::string();
}
return MetalModuleCreate(code_str, fmt, ExtractFuncInfo(mod), source.str());
}

TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);
Expand Down
4 changes: 2 additions & 2 deletions src/target/source/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ class CodeGenMetal final : public CodeGenC {
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)

// reuse parent's function.
using CodeGenC::PrintType;

Expand Down
37 changes: 37 additions & 0 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,31 @@ void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr
stream << ");\n";
}

void CodeGenOpenCL::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
std::ostream& os) { // NOLINT(*)
ICHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (i != 0) {
os << "|";
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return;
}
if (i == 0) {
// NOTE: opencl print things as (float2)(v0, v1)
os << "((";
PrintType(t, os);
os << ")(";
}
os << value;
if (i != t.lanes() - 1) {
os << ",";
} else {
os << "))";
}
return;
}

void CodeGenOpenCL::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
Expand Down Expand Up @@ -490,6 +515,18 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { //
os << "))";
}

void CodeGenOpenCL::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
os << "((";
PrintType(op->dtype, os);
os << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != op->lanes - 1) os << ", ";
}
os << "))";
}

void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) {
if (op->value < 0) {
Expand Down
3 changes: 3 additions & 0 deletions src/target/source/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class CodeGenOpenCL final : public CodeGenC {
std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final;
void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
const std::string& value) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
std::ostream& os) final; // NOLINT(*)
// the address of load/store
void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base,
std::ostream& os); // NOLINT(*)
Expand All @@ -62,6 +64,7 @@ class CodeGenOpenCL final : public CodeGenC {
// overload visitor
void VisitStmt_(const AllocateNode* op) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,31 @@ def check_inf_nan(dev, n, value, dtype):
check_inf_nan(dev, 1, float("nan"), "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_unaligned_vectorize():
@tvm.script.ir_module
class IRModule:
@T.prim_func
def main(A: T.Buffer((2, 3), "float32"), B: T.Buffer((6,), "float32")):
T.func_attr({"global_symbol": "main"})
for i0_1 in T.thread_binding(3, thread="threadIdx.x"):
for i0_0 in T.vectorized(2):
with T.block("block"):
vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1)
B[vi0] = A[vi0 // 3, vi0 % 3]

target = "metal"
dev = tvm.metal()

a = (np.arange(6).reshape(2, 3)).astype("float32")
a_nd = tvm.nd.array(a, dev)
b_nd = tvm.nd.empty((6,), "float32", dev)
f = tvm.build(IRModule, target=target)
f(a_nd, b_nd)
np.testing.assert_allclose(b_nd.numpy(), a.reshape(6), atol=1e-5, rtol=1e-5)


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_metal_erf():
Expand Down

0 comments on commit fc2a9e5

Please sign in to comment.