Skip to content

Commit

Permalink
[DRIVER] Add simulator, unify testcase to unittest (apache#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jul 12, 2018
1 parent 38fa9fa commit 226b929
Show file tree
Hide file tree
Showing 19 changed files with 1,217 additions and 541 deletions.
20 changes: 19 additions & 1 deletion vta/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS)
endif

UNAME_S := $(shell uname -s)

ifeq ($(UNAME_S), Darwin)
SHARED_LIBRARY_SUFFIX := dylib
WHOLE_ARCH= -all_load
NO_WHOLE_ARCH= -noall_load
LDFLAGS += -undefined dynamic_lookup
else
SHARED_LIBRARY_SUFFIX := so
WHOLE_ARCH= --whole-archive
NO_WHOLE_ARCH= --no-whole-archive
endif


all: lib/libvta.so lib/libvta_runtime.so

Expand All @@ -53,6 +66,10 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET)
LDFLAGS += -l:libdma.so
endif

ifeq ($(TARGET), sim)
VTA_LIB_SRC += $(wildcard src/sim/*.cc)
endif

VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))

build/%.o: src/%.cc
Expand All @@ -71,7 +88,7 @@ lib/libvta_runtime.so: build/runtime.o
lint: pylint cpplint

cpplint:
python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests
python nnvm/dmlc-core/scripts/lint.py vta cpp include src

pylint:
pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
Expand All @@ -86,3 +103,4 @@ clean:
-include build/*.d
-include build/*/*.d
-include build/*/*/*.d
-include build/*/*/*/*.d
2 changes: 1 addition & 1 deletion vta/include/vta/driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void VTAMemFree(void* buf);
* \param buf Pointer to memory region allocated with VTAMemAlloc.
* \return The physical address of the memory region.
*/
vta_phy_addr_t VTAGetMemPhysAddr(void* buf);
vta_phy_addr_t VTAMemGetPhyAddr(void* buf);

/*!
* \brief Flushes the region of memory out of the CPU cache to DRAM.
Expand Down
4 changes: 2 additions & 2 deletions vta/include/vta/hw_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,8 @@ typedef struct {
uint64_t alu_opcode : VTA_ALU_OPCODE_BIT_WIDTH;
/*! \brief Use immediate is true */
uint64_t use_imm : 1;
/*! \brief Immediate value */
uint64_t imm : VTA_ALUOP_IMM_BIT_WIDTH;
/*! \brief Immediate value: allow negative value */
int64_t imm : VTA_ALUOP_IMM_BIT_WIDTH;
} VTAAluInsn;

/*! \brief VTA ALU instruction converter */
Expand Down
2 changes: 1 addition & 1 deletion vta/include/vta/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ void VTAUopPush(uint32_t mode,
uint32_t wgt_index,
uint32_t opcode,
uint32_t use_imm,
uint32_t imm_val);
int32_t imm_val);

/*!
* \brief Mark start of a micro op loop.
Expand Down
3 changes: 1 addition & 2 deletions vta/make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ADD_LDFLAGS=
ADD_CFLAGS=

# the hardware target
TARGET = VTA_PYNQ_TARGET
TARGET = pynq

#---------------------
# VTA hardware parameters
Expand Down Expand Up @@ -89,7 +89,6 @@ VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )

# Update ADD_CFLAGS
ADD_CFLAGS += \
-D$(TARGET) \
-DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
-DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
-DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \
Expand Down
1 change: 1 addition & 0 deletions vta/python/vta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import arm_conv2d, vta_conv2d
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga

from . import graph
except ImportError:
pass
22 changes: 18 additions & 4 deletions vta/python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class Environment(object):
"""
current = None
cfg_keys = [
"target",
"TARGET",
"LOG_INP_WIDTH",
"LOG_WGT_WIDTH",
"LOG_ACC_WIDTH",
Expand Down Expand Up @@ -204,9 +204,19 @@ def gemm(self):

@property
def gevm(self):
"""GEMM intrinsic"""
"""GEVM intrinsic"""
return self.dev.gevm

@property
def target_host(self):
"""The target host"""
if self.TARGET == "pynq":
return "llvm -target=armv7-none-linux-gnueabihf"
elif self.TARGET == "sim":
return "llvm"
else:
raise ValueError("Unknown target %s" % self.TARGET)


def get_env():
"""Get the current VTA Environment.
Expand Down Expand Up @@ -278,6 +288,7 @@ def _init_env():

for k in Environment.cfg_keys:
keys.add("VTA_" + k)
keys.add("TARGET")

if not os.path.isfile(filename):
raise RuntimeError(
Expand All @@ -290,8 +301,11 @@ def _init_env():
for k in keys:
if k +" =" in line:
val = line.split("=")[1].strip()
cfg[k[4:]] = int(val)
cfg["target"] = "pynq"
if k.startswith("VTA_"):
k = k[4:]
cfg[k] = int(val)
else:
cfg[k] = val
return Environment(cfg)

Environment.current = _init_env()
16 changes: 12 additions & 4 deletions vta/python/vta/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def _visit(op):
if not fail[0]:
begin = tvm.call_extern(
"int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
end = tvm.call_extern(
"int32", "VTAUopLoopEnd", stmt.extent, *gemm_offsets)
end = tvm.call_extern("int32", "VTAUopLoopEnd")
return [begin, ret, end]
raise ValueError("Failed to fold the GEMM instructions..")

Expand Down Expand Up @@ -683,8 +682,14 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
else:
raise RuntimeError(
"Function call not recognized %s" % (loop_body.value.name))
elif isinstance(loop_body.value, tvm.expr.Load):
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value
rhs = tvm.const(0)
else:
raise RuntimeError("Expression not recognized %s" % (type(loop_body.value)))
raise RuntimeError(
"Expression not recognized %s, %s, %s" % (
type(loop_body.value), str(loop_body.value), str(stmt)))

# Derive array index coefficients
dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices)
Expand Down Expand Up @@ -772,7 +777,9 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
irb = tvm.ir_builder.create()
for idx, extent in enumerate(extents):
irb.emit(tvm.call_extern(
"int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx]))
"int32", "VTAUopLoopBegin",
extent, dst_coeff[idx], src_coeff[idx], 0))
use_imm = int(use_imm)
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
1, 0,
Expand Down Expand Up @@ -804,5 +811,6 @@ def debug_print(stmt):
stmt : Stmt
The
"""
# pylint: disable=superfluous-parens
print(stmt)
return stmt
3 changes: 1 addition & 2 deletions vta/python/vta/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def reconfig_runtime(remote):
"VTA_LOG_WGT_BUFF_SIZE",
"VTA_LOG_ACC_BUFF_SIZE",
"VTA_LOG_OUT_BUFF_SIZE"]

cflags = ["-DVTA_%s_TARGET" % env.target.upper()]
cflags = []
for k in keys:
cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
Expand Down
3 changes: 3 additions & 0 deletions vta/python/vta/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Testing utilities, this namespace is not imported by default."""

from . util import run
51 changes: 51 additions & 0 deletions vta/python/vta/testing/simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Utilities to start simulator."""
import os
import ctypes
import json
import tvm

def _load_lib():
"""Load local library, assuming they are simulator."""
# pylint: disable=unused-variable
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")),
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
]
runtime_dll = []
if not all(os.path.exists(f) for f in dll_path):
return []
try:
for fname in dll_path:
runtime_dll.append(ctypes.CDLL(fname, ctypes.RTLD_GLOBAL))
return runtime_dll
except OSError:
return []


def enabled():
"""Check if simulator is enabled."""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
return f is not None


def clear_stats():
"""Clear profiler statistics"""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
if f:
f()


def stats():
"""Clear profiler statistics
Returns
-------
stats : dict
Current profiler statistics
"""
x = tvm.get_global_func("vta.simulator.profiler_status")()
return json.loads(x)


LIBS = _load_lib()
30 changes: 30 additions & 0 deletions vta/python/vta/testing/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Test Utilities"""
from __future__ import absolute_import as _abs

import os
from tvm.contrib import rpc
from ..environment import get_env
from . import simulator


def run(run_func):
"""Run test function on all available env.
Parameters
----------
run_func : function(env, remote)
"""
env = get_env()
# run on simulator
if simulator.enabled():
env.TARGET = "sim"
run_func(env, rpc.LocalSession())

# Run on PYNQ if env variable exists
pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
if pynq_host:
env.TARGET = "pynq"
port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
port = int(port)
remote = rpc.connect(pynq_host, port)
run_func(env, remote)
2 changes: 1 addition & 1 deletion vta/src/data_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct DataBuffer {
assert(data != nullptr);
DataBuffer* buffer = new DataBuffer();
buffer->data_ = data;
buffer->phy_addr_ = VTAGetMemPhysAddr(data);
buffer->phy_addr_ = VTAMemGetPhyAddr(data);
return buffer;
}
/*!
Expand Down
4 changes: 2 additions & 2 deletions vta/src/pynq/pynq_driver.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2018 by Contributors
* \file vta_pynq_driver.c
* \file pynq_driver.c
* \brief VTA driver for Pynq board.
*/

Expand All @@ -17,7 +17,7 @@ void VTAMemFree(void* buf) {
cma_free(buf);
}

vta_phy_addr_t VTAGetMemPhysAddr(void* buf) {
vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
return cma_get_phy_addr(buf);
}

Expand Down
25 changes: 11 additions & 14 deletions vta/src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <vta/driver.h>
#include <vta/hw_spec.h>
#include <vta/runtime.h>
#include <dmlc/logging.h>

#include <cassert>
#include <vector>
Expand Down Expand Up @@ -113,7 +114,7 @@ class UopKernel {
uint32_t wgt_index,
uint32_t opcode,
uint32_t use_imm,
uint32_t imm_val) {
int32_t imm_val) {
// The loop nest structure
VerifyDep(dst_index);
VTAUop op;
Expand Down Expand Up @@ -166,7 +167,7 @@ class UopKernel {
uint32_t opcode_{0xFFFFFFFF};
uint32_t reset_out_{0xFFFFFFFF};
bool use_imm_{false};
uint16_t imm_val_{0};
int16_t imm_val_{0};

private:
// Verify that we don't write to the same acc_mem index two cycles in a row
Expand Down Expand Up @@ -195,10 +196,6 @@ class UopKernel {

/*!
* \brief Base class of all queues to send and recv serial data.
* \param kElemBytes Element unit bytes.
* \param kMaxBytes Maximum number of bytes.
* \param kCoherent Whether we have coherent access to the buffer.
* \param kAlwaysCache Wether we should use cached memory.
*/
class BaseQueue {
public:
Expand Down Expand Up @@ -227,7 +224,7 @@ class BaseQueue {
dram_buffer_ = static_cast<char*>(VTAMemAlloc(
max_bytes, coherent || always_cache_));
assert(dram_buffer_ != nullptr);
dram_phy_addr_ = VTAGetMemPhysAddr(dram_buffer_);
dram_phy_addr_ = VTAMemGetPhyAddr(dram_buffer_);
}
/*!
* \brief Reset the pointer of the buffer.
Expand Down Expand Up @@ -597,14 +594,14 @@ class InsnQueue : public BaseQueue {
}
// Print instruction field information
if (c.mem.opcode == VTA_OPCODE_LOAD) {
printf("LOAD ");
if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n");
if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n");
if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n");
if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n");
printf("LOAD ");
if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n");
if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n");
if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n");
if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n");
}
if (c.mem.opcode == VTA_OPCODE_STORE) {
printf("STORE\n");
printf("STORE:\n");
}
printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
static_cast<int>(c.mem.pop_prev_dep),
Expand Down Expand Up @@ -1210,7 +1207,7 @@ void VTAUopPush(uint32_t mode,
uint32_t wgt_index,
uint32_t opcode,
uint32_t use_imm,
uint32_t imm_val) {
int32_t imm_val) {
vta::CommandQueue::ThreadLocal()->record_kernel()
->Push(mode, reset_out, dst_index, src_index,
wgt_index, opcode, use_imm, imm_val);
Expand Down
Loading

0 comments on commit 226b929

Please sign in to comment.