Skip to content

Commit

Permalink
[IR] support general type annotation. (apache#1480)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and sergei-mironov committed Aug 8, 2018
1 parent c730111 commit a85bac5
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
10 changes: 10 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,16 @@ using HalideIR::Internal::Shuffle;
// ir functions
using HalideIR::Internal::is_const_power_of_two_integer;

/*!
* \brief Create a type annotation expression
* \param dtype The data type
* \return Expr a expression with dtype.
*/
inline Expr TypeAnnotation(Type dtype) {
return ir::Call::make(dtype,
"type_annotation", {},
ir::Call::PureIntrinsic);
}
} // namespace ir
} // namespace tvm

Expand Down
4 changes: 2 additions & 2 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,12 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
}
Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
extent = extent / make_const(self->elem_offset.type(), content_lanes);
elem_offset = self->elem_offset / make_const(self->elem_offset.type(),
content_lanes);
} else {
e_dtype = make_zero(self->dtype);
e_dtype = ir::TypeAnnotation(self->dtype);
}
Array<Expr> acc_args{
e_dtype, self->data, elem_offset,
Expand Down
5 changes: 4 additions & 1 deletion vta/src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
#include <memory>
#include <atomic>


namespace vta {

// Avoid bad configurations.
static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8,
"VTA_UOP_WIDTH do not match VTAUop size");

/*! \brief Enable coherent access between VTA and CPU. */
static const bool kBufferCoherent = true;

Expand Down
4 changes: 3 additions & 1 deletion vta/src/sim/sim_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,12 @@ class SRAM {
CHECK_LE(sram_end, kMaxNumElem);
memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_0);
sram_ptr += xtotal * op->y_pad_0;

for (uint32_t y = 0; y < op->y_size; ++y) {
memset(sram_ptr, 0, kElemBytes * op->x_pad_0);
sram_ptr += op->x_pad_0;
memcpy(sram_ptr, dram_ptr, kElemBytes * op->x_size);
sram_ptr += op->x_size;
BitPacker<kBits> src(sram_ptr);
memset(sram_ptr, 0, kElemBytes * op->x_pad_1);
sram_ptr += op->x_pad_1;
dram_ptr += kElemBytes * op->x_stride;
Expand Down Expand Up @@ -415,12 +415,14 @@ class Device {
uint32_t acc_idx = uop_ptr->dst_idx;
uint32_t inp_idx = uop_ptr->src_idx;
uint32_t wgt_idx = uop_ptr->wgt_idx;

acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
inp_idx += y * op->src_factor_out + x * op->src_factor_in;
wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in;
BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
BitPacker<VTA_INP_WIDTH> inp(inp_.BeginPtr(inp_idx));
BitPacker<VTA_WGT_WIDTH> wgt(wgt_.BeginPtr(wgt_idx));

// gemm loop
for (uint32_t i = 0; i < VTA_BATCH; ++i) {
for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) {
Expand Down

0 comments on commit a85bac5

Please sign in to comment.