Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[VTA][SIM] Allow debug mode in simulator to skip execution (#6)
Browse files Browse the repository at this point in the history
tqchen authored and tmoreau89 committed Dec 1, 2018
1 parent 9b32883 commit 35f8b96
Showing 4 changed files with 54 additions and 88 deletions.
13 changes: 13 additions & 0 deletions vta/python/vta/testing/simulator.py
Original file line number Diff line number Diff line change
@@ -27,6 +27,19 @@ def clear_stats():
if f:
f()

# debug flag to skip execution.
DEBUG_SKIP_EXEC = 1

def debug_mode(flag):
"""Set debug mode
Paramaters
----------
flag : int
The debug flag, 0 means clear all flags.
"""
tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag)


def stats():
"""Clear profiler statistics
81 changes: 0 additions & 81 deletions vta/python/vta/top/arm_conv2d.py
Original file line number Diff line number Diff line change
@@ -5,87 +5,6 @@
from topi.nn import conv2d, conv2d_alter_layout
from topi import generic

_WORKLOADS = [
# resnet 18
Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),

# mobilenet float32
Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),

# mobilenet int8
Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
]

_SCHEDULES = [
# float32 imagenet
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 7, 4, 2, 4, True),
SpatialPack(1, 4, 8, 4, 1, True),
SpatialPack(1, 4, 4, 1, 16, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(1, 7, 4, 3, 8, True),
SpatialPack(1, 2, 8, 1, 8, True),
SpatialPack(2, 1, 16, 1, 4, True),
SpatialPack(1, 7, 4, 1, 1, True),
Im2ColPack(7, 4, 1, 16, True),
Im2ColPack(7, 4, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),

# float32 mobilenet
SpatialPack(2, 2, 4, 28, 1, True),
SpatialPack(1, 4, 8, 14, 1, False),
SpatialPack(1, 2, 16, 8, 1, True),
SpatialPack(1, 4, 8, 8, 8, True),
SpatialPack(2, 2, 8, 1, 1, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(2, 2, 8, 1, 4, False),
SpatialPack(2, 2, 8, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
Im2ColPack(7, 4, 1, 4, True),

# int8 mobilenet
SpatialPack(2, 2, 4, 28, 1, True),
SpatialPack(1, 4, 8, 14, 1, False),
SpatialPack(1, 2, 16, 8, 1, True),
SpatialPack(1, 4, 8, 8, 8, True),
SpatialPack(2, 2, 8, 1, 1, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(2, 2, 8, 1, 4, False),
SpatialPack(2, 2, 8, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
Im2ColPack(7, 4, 1, 4, True),
]

@conv2d.register(["vtacpu", "vta"])
def compute(*args, **kwargs):
38 changes: 32 additions & 6 deletions vta/src/sim/sim_driver.cc
Original file line number Diff line number Diff line change
@@ -16,6 +16,11 @@
namespace vta {
namespace sim {

/*! \brief debug flag for skipping computation */
enum DebugFlagMask {
kSkipExec = 1
};

/*!
* \brief Helper class to pack and unpack bits
* Applies truncation when pack to low level bits.
@@ -234,8 +239,12 @@ class SRAM {
return &(data_[index]);
}
// Execute the load instruction on this SRAM
void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) {
void Load(const VTAMemInsn* op,
DRAM* dram,
uint64_t* load_counter,
bool skip_exec) {
load_counter[0] += (op->x_size * op->y_size) * kElemBytes;
if (skip_exec) return;
DType* sram_ptr = data_ + op->sram_base;
uint8_t* dram_ptr = static_cast<uint8_t*>(dram->GetAddr(
op->dram_base * kElemBytes));
@@ -306,6 +315,8 @@ class Profiler {
uint64_t gemm_counter{0};
/*! \brief instr counter for ALU ops */
uint64_t alu_counter{0};
/*! \brief set debug mode */
int64_t debug_flag{0};
/*! \brief clear the profiler */
void Clear() {
inp_load_nbytes = 0;
@@ -316,6 +327,10 @@ class Profiler {
gemm_counter = 0;
alu_counter = 0;
}
/*! \return Whether we should skip execution. */
bool SkipExec() const {
return (debug_flag & DebugFlagMask::kSkipExec) != 0;
}

std::string AsJSON() {
std::ostringstream os;
@@ -379,13 +394,15 @@ class Device {
void RunLoad(const VTAMemInsn* op) {
if (op->x_size == 0) return;
if (op->memory_type == VTA_MEM_ID_INP) {
inp_.Load(op, dram_, &(prof_->inp_load_nbytes));
inp_.Load(op, dram_, &(prof_->inp_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_WGT) {
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes));
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_ACC) {
acc_.Load(op, dram_, &(prof_->acc_load_nbytes));
acc_.Load(op, dram_, &(prof_->acc_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_UOP) {
uop_.Load(op, dram_, &(prof_->uop_load_nbytes));
// always load in uop, since uop is stateful
// subsequent non-debug mode exec can depend on it.
uop_.Load(op, dram_, &(prof_->uop_load_nbytes), false);
} else {
LOG(FATAL) << "Unknown memory_type=" << op->memory_type;
}
@@ -397,7 +414,9 @@ class Device {
op->memory_type == VTA_MEM_ID_UOP) {
prof_->out_store_nbytes += (
op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8);
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
if (!prof_->SkipExec()) {
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
}
} else {
LOG(FATAL) << "Store do not support memory_type="
<< op->memory_type;
@@ -407,6 +426,7 @@ class Device {
void RunGEMM(const VTAGemInsn* op) {
if (!op->reset_reg) {
prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
if (prof_->SkipExec()) return;
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
@@ -440,6 +460,7 @@ class Device {
}
}
} else {
if (prof_->SkipExec()) return;
// reset
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
@@ -506,6 +527,7 @@ class Device {
template<bool use_imm, typename F>
void RunALULoop(const VTAAluInsn* op, F func) {
prof_->alu_counter += op->iter_out * op->iter_in * op->uop_end - op->uop_bgn;
if (prof_->SkipExec()) return;
for (int y = 0; y < op->iter_out; ++y) {
for (int x = 0; x < op->iter_in; ++x) {
for (int k = op->uop_bgn; k < op->uop_end; ++k) {
@@ -548,6 +570,10 @@ TVM_REGISTER_GLOBAL("vta.simulator.profiler_clear")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::ThreadLocal()->Clear();
});
TVM_REGISTER_GLOBAL("vta.simulator.profiler_debug_mode")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::ThreadLocal()->debug_flag = args[0];
});
TVM_REGISTER_GLOBAL("vta.simulator.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Profiler::ThreadLocal()->AsJSON();
10 changes: 9 additions & 1 deletion vta/tests/python/unittest/test_vta_insn.py
Original file line number Diff line number Diff line change
@@ -183,8 +183,16 @@ def verify(s):

if env.TARGET == "sim":
simulator.clear_stats()
simulator.debug_mode(simulator.DEBUG_SKIP_EXEC)
f(x_nd, w_nd, y_nd)
print(simulator.stats())
stat1 = simulator.stats()
simulator.clear_stats()
simulator.debug_mode(0)
f(x_nd, w_nd, y_nd)
stat2 = simulator.stats()
for k, v in stat1.items():
if k != "uop_load_nbytes":
assert stat1[k] == stat2[k]
else:
f(x_nd, w_nd, y_nd)

0 comments on commit 35f8b96

Please sign in to comment.