diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index e9f65eed166a4..00e3335633ecd 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -283,7 +283,7 @@ inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataTyp } PrimExpr Buffer::vload(Array begin, DataType dtype) const { - // specially handle bool, stored asDataType::Int(8) + // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer of " << n->dtype; @@ -297,11 +297,11 @@ PrimExpr Buffer::vload(Array begin, DataType dtype) const { } Stmt Buffer::vstore(Array begin, PrimExpr value) const { - // specially handle bool, stored asDataType::Int(8) + // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); DataType dtype = value.dtype(); CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot load " << dtype << " from buffer of " << n->dtype; + << "Cannot store " << dtype << " to buffer of " << n->dtype; if (value.dtype() == DataType::Bool()) { return tir::Store(n->data, tir::Cast(DataType::Int(8), value), BufferOffset(n, begin, DataType::Int(8)), const_true());