From a85bac5864f7c107cf4c92a2d1545923fa1b3340 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 24 Jul 2018 14:01:58 -0700 Subject: [PATCH] [IR] support general type annotation. (#1480) --- include/tvm/ir.h | 10 ++++++++++ src/lang/buffer.cc | 4 ++-- vta/src/runtime.cc | 5 ++++- vta/src/sim/sim_driver.cc | 4 +++- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 8c398c9a55846..97e0b44d2fec3 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -514,6 +514,16 @@ using HalideIR::Internal::Shuffle; // ir functions using HalideIR::Internal::is_const_power_of_two_integer; +/*! + * \brief Create a type annotation expression + * \param dtype The data type + * \return Expr an expression with dtype. + */ +inline Expr TypeAnnotation(Type dtype) { + return ir::Call::make(dtype, + "type_annotation", {}, + ir::Call::PureIntrinsic); +} } // namespace ir } // namespace tvm diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 39566df45ae64..3f23c2d480bff 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -350,12 +350,12 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr } Expr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { - e_dtype = make_zero(self->dtype.with_lanes(content_lanes)); + e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); extent = extent / make_const(self->elem_offset.type(), content_lanes); elem_offset = self->elem_offset / make_const(self->elem_offset.type(), content_lanes); } else { - e_dtype = make_zero(self->dtype); + e_dtype = ir::TypeAnnotation(self->dtype); } Array acc_args{ e_dtype, self->data, elem_offset, diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index 4d45159a10e1c..ffa0096e1713c 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -18,9 +18,12 @@ #include #include - namespace vta { +// Avoid bad configurations. 
+static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8, + "VTA_UOP_WIDTH do not match VTAUop size"); + /*! \brief Enable coherent access between VTA and CPU. */ static const bool kBufferCoherent = true; diff --git a/vta/src/sim/sim_driver.cc b/vta/src/sim/sim_driver.cc index 9a953e7aeadb0..60645818757c5 100644 --- a/vta/src/sim/sim_driver.cc +++ b/vta/src/sim/sim_driver.cc @@ -245,12 +245,12 @@ class SRAM { CHECK_LE(sram_end, kMaxNumElem); memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_0); sram_ptr += xtotal * op->y_pad_0; + for (uint32_t y = 0; y < op->y_size; ++y) { memset(sram_ptr, 0, kElemBytes * op->x_pad_0); sram_ptr += op->x_pad_0; memcpy(sram_ptr, dram_ptr, kElemBytes * op->x_size); sram_ptr += op->x_size; - BitPacker src(sram_ptr); memset(sram_ptr, 0, kElemBytes * op->x_pad_1); sram_ptr += op->x_pad_1; dram_ptr += kElemBytes * op->x_stride; @@ -415,12 +415,14 @@ class Device { uint32_t acc_idx = uop_ptr->dst_idx; uint32_t inp_idx = uop_ptr->src_idx; uint32_t wgt_idx = uop_ptr->wgt_idx; + acc_idx += y * op->dst_factor_out + x * op->dst_factor_in; inp_idx += y * op->src_factor_out + x * op->src_factor_in; wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in; BitPacker acc(acc_.BeginPtr(acc_idx)); BitPacker inp(inp_.BeginPtr(inp_idx)); BitPacker wgt(wgt_.BeginPtr(wgt_idx)); + // gemm loop for (uint32_t i = 0; i < VTA_BATCH; ++i) { for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) {