Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] support general type annotation. #1480

Merged
merged 1 commit into from
Jul 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,16 @@ using HalideIR::Internal::Shuffle;
// ir functions
using HalideIR::Internal::is_const_power_of_two_integer;

/*!
* \brief Create a type annotation expression
* \param dtype The data type
* \return Expr a expression with dtype.
*/
inline Expr TypeAnnotation(Type dtype) {
return ir::Call::make(dtype,
"type_annotation", {},
ir::Call::PureIntrinsic);
}
} // namespace ir
} // namespace tvm

Expand Down
4 changes: 2 additions & 2 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,12 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
}
Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
extent = extent / make_const(self->elem_offset.type(), content_lanes);
elem_offset = self->elem_offset / make_const(self->elem_offset.type(),
content_lanes);
} else {
e_dtype = make_zero(self->dtype);
e_dtype = ir::TypeAnnotation(self->dtype);
}
Array<Expr> acc_args{
e_dtype, self->data, elem_offset,
Expand Down
5 changes: 4 additions & 1 deletion vta/src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
#include <memory>
#include <atomic>


namespace vta {

// Avoid bad configurations.
static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8,
"VTA_UOP_WIDTH do not match VTAUop size");

/*! \brief Enable coherent access between VTA and CPU. */
static const bool kBufferCoherent = true;

Expand Down
4 changes: 3 additions & 1 deletion vta/src/sim/sim_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,12 @@ class SRAM {
CHECK_LE(sram_end, kMaxNumElem);
memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_0);
sram_ptr += xtotal * op->y_pad_0;

for (uint32_t y = 0; y < op->y_size; ++y) {
memset(sram_ptr, 0, kElemBytes * op->x_pad_0);
sram_ptr += op->x_pad_0;
memcpy(sram_ptr, dram_ptr, kElemBytes * op->x_size);
sram_ptr += op->x_size;
BitPacker<kBits> src(sram_ptr);
memset(sram_ptr, 0, kElemBytes * op->x_pad_1);
sram_ptr += op->x_pad_1;
dram_ptr += kElemBytes * op->x_stride;
Expand Down Expand Up @@ -415,12 +415,14 @@ class Device {
uint32_t acc_idx = uop_ptr->dst_idx;
uint32_t inp_idx = uop_ptr->src_idx;
uint32_t wgt_idx = uop_ptr->wgt_idx;

acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
inp_idx += y * op->src_factor_out + x * op->src_factor_in;
wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in;
BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
BitPacker<VTA_INP_WIDTH> inp(inp_.BeginPtr(inp_idx));
BitPacker<VTA_WGT_WIDTH> wgt(wgt_.BeginPtr(wgt_idx));

// gemm loop
for (uint32_t i = 0; i < VTA_BATCH; ++i) {
for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) {
Expand Down