From 226b9290c3e86698ba9af2d9fd77d197591d2cac Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 12 Apr 2018 22:12:43 -0700 Subject: [PATCH] [DRIVER] Add simulator, unify testcase to unittest (#25) --- vta/Makefile | 20 +- vta/include/vta/driver.h | 2 +- vta/include/vta/hw_spec.h | 4 +- vta/include/vta/runtime.h | 2 +- vta/make/config.mk | 3 +- vta/python/vta/__init__.py | 1 + vta/python/vta/environment.py | 22 +- vta/python/vta/ir_pass.py | 16 +- vta/python/vta/rpc_client.py | 3 +- vta/python/vta/testing/__init__.py | 3 + vta/python/vta/testing/simulator.py | 51 ++ vta/python/vta/testing/util.py | 30 ++ vta/src/data_buffer.h | 2 +- vta/src/pynq/pynq_driver.cc | 4 +- vta/src/runtime.cc | 25 +- vta/src/sim/sim_driver.cc | 581 +++++++++++++++++++++ vta/src/tvm/vta_device_api.cc | 3 - vta/tests/python/pynq/test_vta_insn.py | 504 ------------------ vta/tests/python/unittest/test_vta_insn.py | 482 +++++++++++++++++ 19 files changed, 1217 insertions(+), 541 deletions(-) create mode 100644 vta/python/vta/testing/__init__.py create mode 100644 vta/python/vta/testing/simulator.py create mode 100644 vta/python/vta/testing/util.py create mode 100644 vta/src/sim/sim_driver.cc delete mode 100644 vta/tests/python/pynq/test_vta_insn.py create mode 100644 vta/tests/python/unittest/test_vta_insn.py diff --git a/vta/Makefile b/vta/Makefile index 069f6e01c3ea..6bfa82dc2e10 100644 --- a/vta/Makefile +++ b/vta/Makefile @@ -40,6 +40,19 @@ ifneq ($(ADD_LDFLAGS), NONE) LDFLAGS += $(ADD_LDFLAGS) endif +UNAME_S := $(shell uname -s) + +ifeq ($(UNAME_S), Darwin) + SHARED_LIBRARY_SUFFIX := dylib + WHOLE_ARCH= -all_load + NO_WHOLE_ARCH= -noall_load + LDFLAGS += -undefined dynamic_lookup +else + SHARED_LIBRARY_SUFFIX := so + WHOLE_ARCH= --whole-archive + NO_WHOLE_ARCH= --no-whole-archive +endif + all: lib/libvta.so lib/libvta_runtime.so @@ -53,6 +66,10 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET) LDFLAGS += -l:libdma.so endif +ifeq ($(TARGET), sim) + VTA_LIB_SRC += $(wildcard src/sim/*.cc) +endif + VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC)) build/%.o: src/%.cc @@ -71,7 +88,7 @@ lib/libvta_runtime.so: build/runtime.o lint: pylint cpplint cpplint: - python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests + python nnvm/dmlc-core/scripts/lint.py vta cpp include src pylint: pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc @@ -86,3 +103,4 @@ clean: -include build/*.d -include build/*/*.d -include build/*/*/*.d +-include build/*/*/*/*.d diff --git a/vta/include/vta/driver.h b/vta/include/vta/driver.h index 8a29fc47aa84..58f778806208 100644 --- a/vta/include/vta/driver.h +++ b/vta/include/vta/driver.h @@ -77,7 +77,7 @@ void VTAMemFree(void* buf); * \param buf Pointer to memory region allocated with VTAMemAlloc. * \return The physical address of the memory region. */ -vta_phy_addr_t VTAGetMemPhysAddr(void* buf); +vta_phy_addr_t VTAMemGetPhyAddr(void* buf); /*! * \brief Flushes the region of memory out of the CPU cache to DRAM. diff --git a/vta/include/vta/hw_spec.h b/vta/include/vta/hw_spec.h index 7eae322a0aca..9d62d4e7d57e 100644 --- a/vta/include/vta/hw_spec.h +++ b/vta/include/vta/hw_spec.h @@ -519,8 +519,8 @@ typedef struct { uint64_t alu_opcode : VTA_ALU_OPCODE_BIT_WIDTH; /*! \brief Use immediate is true */ uint64_t use_imm : 1; - /*! \brief Immediate value */ - uint64_t imm : VTA_ALUOP_IMM_BIT_WIDTH; + /*! \brief Immediate value: allow negative value */ + int64_t imm : VTA_ALUOP_IMM_BIT_WIDTH; } VTAAluInsn; /*! \brief VTA ALU instruction converter */ diff --git a/vta/include/vta/runtime.h b/vta/include/vta/runtime.h index c9373846d903..479540129226 100644 --- a/vta/include/vta/runtime.h +++ b/vta/include/vta/runtime.h @@ -196,7 +196,7 @@ void VTAUopPush(uint32_t mode, uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, - uint32_t imm_val); + int32_t imm_val); /*! * \brief Mark start of a micro op loop. diff --git a/vta/make/config.mk b/vta/make/config.mk index 9f611896a8ba..e329dcf987b8 100644 --- a/vta/make/config.mk +++ b/vta/make/config.mk @@ -27,7 +27,7 @@ ADD_LDFLAGS= ADD_CFLAGS= # the hardware target -TARGET = VTA_PYNQ_TARGET +TARGET = pynq #--------------------- # VTA hardware parameters @@ -89,7 +89,6 @@ VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" ) # Update ADD_CFLAGS ADD_CFLAGS += \ - -D$(TARGET) \ -DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \ -DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \ -DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \ diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py index 275c15b227a3..693a4124f40b 100644 --- a/vta/python/vta/__init__.py +++ b/vta/python/vta/__init__.py @@ -8,6 +8,7 @@ from . import arm_conv2d, vta_conv2d from .build_module import build_config, lower, build from .rpc_client import reconfig_runtime, program_fpga + from . import graph except ImportError: pass diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index b0d7d170fd7d..8ff2bbce2787 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -89,7 +89,7 @@ class Environment(object): """ current = None cfg_keys = [ - "target", + "TARGET", "LOG_INP_WIDTH", "LOG_WGT_WIDTH", "LOG_ACC_WIDTH", @@ -204,9 +204,19 @@ def gemm(self): @property def gevm(self): - """GEMM intrinsic""" + """GEVM intrinsic""" return self.dev.gevm + @property + def target_host(self): + """The target host""" + if self.TARGET == "pynq": + return "llvm -target=armv7-none-linux-gnueabihf" + elif self.TARGET == "sim": + return "llvm" + else: + raise ValueError("Unknown target %s" % self.TARGET) + def get_env(): """Get the current VTA Environment. @@ -278,6 +288,7 @@ def _init_env(): for k in Environment.cfg_keys: keys.add("VTA_" + k) + keys.add("TARGET") if not os.path.isfile(filename): raise RuntimeError( @@ -290,8 +301,11 @@ def _init_env(): for k in keys: if k +" =" in line: val = line.split("=")[1].strip() - cfg[k[4:]] = int(val) - cfg["target"] = "pynq" + if k.startswith("VTA_"): + k = k[4:] + cfg[k] = int(val) + else: + cfg[k] = val return Environment(cfg) Environment.current = _init_env() diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index f85f1760e72b..9310d46dce88 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -78,8 +78,7 @@ def _visit(op): if not fail[0]: begin = tvm.call_extern( "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets) - end = tvm.call_extern( - "int32", "VTAUopLoopEnd", stmt.extent, *gemm_offsets) + end = tvm.call_extern("int32", "VTAUopLoopEnd") return [begin, ret, end] raise ValueError("Failed to fold the GEMM instructions..") @@ -683,8 +682,14 @@ def _flatten_loop(src_coeff, dst_coeff, extents): else: raise RuntimeError( "Function call not recognized %s" % (loop_body.value.name)) + elif isinstance(loop_body.value, tvm.expr.Load): + alu_opcode = env.dev.ALU_OPCODE_SHR + lhs = loop_body.value + rhs = tvm.const(0) else: - raise RuntimeError("Expression not recognized %s" % (type(loop_body.value))) + raise RuntimeError( + "Expression not recognized %s, %s, %s" % ( + type(loop_body.value), str(loop_body.value), str(stmt))) # Derive array index coefficients dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices) @@ -772,7 +777,9 @@ def _flatten_loop(src_coeff, dst_coeff, extents): irb = tvm.ir_builder.create() for idx, extent in enumerate(extents): irb.emit(tvm.call_extern( - "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx])) + "int32", "VTAUopLoopBegin", + extent, dst_coeff[idx], src_coeff[idx], 0)) + use_imm = int(use_imm) irb.emit(tvm.call_extern( "int32", "VTAUopPush", 1, 0, @@ -804,5 +811,6 @@ def debug_print(stmt): stmt : Stmt The """ + # pylint: disable=superfluous-parens print(stmt) return stmt diff --git a/vta/python/vta/rpc_client.py b/vta/python/vta/rpc_client.py index a6355a592279..f0a02d7cc045 100644 --- a/vta/python/vta/rpc_client.py +++ b/vta/python/vta/rpc_client.py @@ -24,8 +24,7 @@ def reconfig_runtime(remote): "VTA_LOG_WGT_BUFF_SIZE", "VTA_LOG_ACC_BUFF_SIZE", "VTA_LOG_OUT_BUFF_SIZE"] - - cflags = ["-DVTA_%s_TARGET" % env.target.upper()] + cflags = [] for k in keys: cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))] freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime") diff --git a/vta/python/vta/testing/__init__.py b/vta/python/vta/testing/__init__.py new file mode 100644 index 000000000000..513fa1e99a52 --- /dev/null +++ b/vta/python/vta/testing/__init__.py @@ -0,0 +1,3 @@ +"""Testing utilities, this namespace is not imported by default.""" + +from . util import run diff --git a/vta/python/vta/testing/simulator.py b/vta/python/vta/testing/simulator.py new file mode 100644 index 000000000000..bb436a1853a8 --- /dev/null +++ b/vta/python/vta/testing/simulator.py @@ -0,0 +1,51 @@ +"""Utilities to start simulator.""" +import os +import ctypes +import json +import tvm + +def _load_lib(): + """Load local library, assuming they are simulator.""" + # pylint: disable=unused-variable + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + dll_path = [ + os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")), + os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so")) + ] + runtime_dll = [] + if not all(os.path.exists(f) for f in dll_path): + return [] + try: + for fname in dll_path: + runtime_dll.append(ctypes.CDLL(fname, ctypes.RTLD_GLOBAL)) + return runtime_dll + except OSError: + return [] + + +def enabled(): + """Check if simulator is enabled.""" + f = tvm.get_global_func("vta.simulator.profiler_clear", True) + return f is not None + + +def clear_stats(): + """Clear profiler statistics""" + f = tvm.get_global_func("vta.simulator.profiler_clear", True) + if f: + f() + + +def stats(): + """Clear profiler statistics + + Returns + ------- + stats : dict + Current profiler statistics + """ + x = tvm.get_global_func("vta.simulator.profiler_status")() + return json.loads(x) + + +LIBS = _load_lib() diff --git a/vta/python/vta/testing/util.py b/vta/python/vta/testing/util.py new file mode 100644 index 000000000000..bbf6417a167e --- /dev/null +++ b/vta/python/vta/testing/util.py @@ -0,0 +1,30 @@ +"""Test Utilities""" +from __future__ import absolute_import as _abs + +import os +from tvm.contrib import rpc +from ..environment import get_env +from . import simulator + + +def run(run_func): + """Run test function on all available env. + + Parameters + ---------- + run_func : function(env, remote) + """ + env = get_env() + # run on simulator + if simulator.enabled(): + env.TARGET = "sim" + run_func(env, rpc.LocalSession()) + + # Run on PYNQ if env variable exists + pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None) + if pynq_host: + env.TARGET = "pynq" + port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091") + port = int(port) + remote = rpc.connect(pynq_host, port) + run_func(env, remote) diff --git a/vta/src/data_buffer.h b/vta/src/data_buffer.h index aed92c49e795..fba46dc07efa 100644 --- a/vta/src/data_buffer.h +++ b/vta/src/data_buffer.h @@ -57,7 +57,7 @@ struct DataBuffer { assert(data != nullptr); DataBuffer* buffer = new DataBuffer(); buffer->data_ = data; - buffer->phy_addr_ = VTAGetMemPhysAddr(data); + buffer->phy_addr_ = VTAMemGetPhyAddr(data); return buffer; } /*! diff --git a/vta/src/pynq/pynq_driver.cc b/vta/src/pynq/pynq_driver.cc index 0330450db285..e2630b14acde 100644 --- a/vta/src/pynq/pynq_driver.cc +++ b/vta/src/pynq/pynq_driver.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2018 by Contributors - * \file vta_pynq_driver.c + * \file pynq_driver.c * \brief VTA driver for Pynq board. */ @@ -17,7 +17,7 @@ void VTAMemFree(void* buf) { cma_free(buf); } -vta_phy_addr_t VTAGetMemPhysAddr(void* buf) { +vta_phy_addr_t VTAMemGetPhyAddr(void* buf) { return cma_get_phy_addr(buf); } diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index a8819323fc11..da5109c141ee 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -113,7 +114,7 @@ class UopKernel { uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, - uint32_t imm_val) { + int32_t imm_val) { // The loop nest structure VerifyDep(dst_index); VTAUop op; @@ -166,7 +167,7 @@ class UopKernel { uint32_t opcode_{0xFFFFFFFF}; uint32_t reset_out_{0xFFFFFFFF}; bool use_imm_{false}; - uint16_t imm_val_{0}; + int16_t imm_val_{0}; private: // Verify that we don't write to the same acc_mem index two cycles in a row @@ -195,10 +196,6 @@ class UopKernel { /*! * \brief Base class of all queues to send and recv serial data. - * \param kElemBytes Element unit bytes. - * \param kMaxBytes Maximum number of bytes. - * \param kCoherent Whether we have coherent access to the buffer. - * \param kAlwaysCache Wether we should use cached memory. */ class BaseQueue { public: @@ -227,7 +224,7 @@ class BaseQueue { dram_buffer_ = static_cast(VTAMemAlloc( max_bytes, coherent || always_cache_)); assert(dram_buffer_ != nullptr); - dram_phy_addr_ = VTAGetMemPhysAddr(dram_buffer_); + dram_phy_addr_ = VTAMemGetPhyAddr(dram_buffer_); } /*! * \brief Reset the pointer of the buffer. @@ -597,14 +594,14 @@ class InsnQueue : public BaseQueue { } // Print instruction field information if (c.mem.opcode == VTA_OPCODE_LOAD) { - printf("LOAD "); - if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n"); - if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n"); - if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n"); - if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n"); + printf("LOAD "); + if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n"); + if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n"); + if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n"); + if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n"); } if (c.mem.opcode == VTA_OPCODE_STORE) { - printf("STORE\n"); + printf("STORE:\n"); } printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", static_cast(c.mem.pop_prev_dep), @@ -1210,7 +1207,7 @@ void VTAUopPush(uint32_t mode, uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, - uint32_t imm_val) { + int32_t imm_val) { vta::CommandQueue::ThreadLocal()->record_kernel() ->Push(mode, reset_out, dst_index, src_index, wgt_index, opcode, use_imm, imm_val); diff --git a/vta/src/sim/sim_driver.cc b/vta/src/sim/sim_driver.cc new file mode 100644 index 000000000000..a88ab2466a25 --- /dev/null +++ b/vta/src/sim/sim_driver.cc @@ -0,0 +1,581 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file sim_driver.cc + * \brief VTA driver for simulated backend. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace vta { +namespace sim { + +/*! + * \brief Helper class to pack and unpack bits + * Applies truncation when pack to low level bits. + * + * \tparam bits The number of bits in integer. + * \note This implementation relies on little endian. + */ +template +class BitPacker { + public: + explicit BitPacker(void* data) { + data_ = static_cast(data); + } + + uint32_t GetUnsigned(uint32_t index) const { + if (bits == 32) { + return data_[index]; + } else if (bits == 16) { + return reinterpret_cast(data_)[index]; + } else if (bits == 8) { + return reinterpret_cast(data_)[index]; + } else { + uint32_t offset = index / kNumPackElem; + uint32_t shift = index % kNumPackElem; + return (data_[offset] >> shift) & kMask; + } + } + + int32_t GetSigned(uint32_t index) const { + if (bits == 32) { + return reinterpret_cast(data_)[index]; + } else if (bits == 16) { + return reinterpret_cast(data_)[index]; + } else if (bits == 8) { + return reinterpret_cast(data_)[index]; + } else { + uint32_t offset = index / kNumPackElem; + uint32_t shift = (index % kNumPackElem) * bits; + int32_t uvalue = static_cast( + (data_[offset] >> shift) & kMask); + int kleft = 32 - bits; + return (uvalue << kleft) >> kleft; + } + } + + void SetUnsigned(uint32_t index, uint32_t value) { + if (bits == 32) { + data_[index] = value; + } else if (bits == 16) { + reinterpret_cast(data_)[index] = value; + } else if (bits == 8) { + reinterpret_cast(data_)[index] = value; + } else { + uint32_t offset = index / kNumPackElem; + uint32_t shift = (index % kNumPackElem) * bits; + data_[offset] &= (~(kMask << shift)); + data_[offset] |= (value & kMask) << shift; + } + } + + void SetSigned(uint32_t index, int32_t value) { + if (bits == 32) { + reinterpret_cast(data_)[index] = value; + } else if (bits == 16) { + reinterpret_cast(data_)[index] = value; + } else if (bits == 8) { + reinterpret_cast(data_)[index] = value; + } else { + uint32_t offset = index / kNumPackElem; + uint32_t shift = (index % kNumPackElem) * bits; + data_[offset] &= (~(kMask << shift)); + data_[offset] |= static_cast(value & kMask) << shift; + } + } + + private: + uint32_t* data_; + static constexpr uint32_t kNumPackElem = 32 / bits; + static constexpr uint32_t kMask = (1U << (bits >= 32U ? 31U : bits)) - 1U; +}; + +/*! + * \brief DRAM memory manager + * Implements simple paging to allow physical address translation. + */ +class DRAM { + public: + /*! + * \brief Get virtual address given physical address. + * \param phy_addr The simulator phyiscal address. + * \return The true virtual address; + */ + void* GetAddr(uint64_t phy_addr) { + std::lock_guard lock(mutex_); + uint64_t loc = (phy_addr >> kPageBits) - 1; + CHECK_LT(loc, ptable_.size()); + Page* p = ptable_[loc]; + CHECK(p != nullptr); + size_t offset = (loc - p->ptable_begin) << kPageBits; + offset += phy_addr & (kPageSize - 1); + return reinterpret_cast(p->data) + offset; + } + /*! + * \brief Get physical address + * \param buf The virtual address. + * \return The true physical address; + */ + vta_phy_addr_t GetPhyAddr(void* buf) { + std::lock_guard lock(mutex_); + auto it = pmap_.find(buf); + CHECK(it != pmap_.end()); + Page* p = it->second.get(); + return (p->ptable_begin + 1) << kPageBits; + } + /*! + * \brief Allocate memory from manager + * \param size The size of memory + * \return The virtual address + */ + void* Alloc(size_t size) { + std::lock_guard lock(mutex_); + size_t npage = (size + kPageSize - 1) / kPageSize; + auto it = free_map_.lower_bound(npage); + if (it != free_map_.end()) { + Page* p = it->second; + free_map_.erase(it); + return p->data; + } + size_t start = ptable_.size(); + std::unique_ptr p(new Page(start, npage)); + // insert page entry + ptable_.resize(start + npage, p.get()); + void* data = p->data; + pmap_[data] = std::move(p); + return data; + } + /*! + * \brief Free the memory. + * \param size The size of memory + * \return The virtual address + */ + void Free(void* data) { + std::lock_guard lock(mutex_); + auto it = pmap_.find(data); + CHECK(it != pmap_.end()); + Page* p = it->second.get(); + free_map_.insert(std::make_pair(p->num_pages, p)); + } + + static DRAM* Global() { + static DRAM inst; + return &inst; + } + + + private: + // The bits in page table + static constexpr vta_phy_addr_t kPageBits = 16; + // page size, also the maximum allocable size 16 K + static constexpr vta_phy_addr_t kPageSize = 1 << kPageBits; + /*! \brief A page in the DRAM */ + struct Page { + /*! \brief Data Type */ + using DType = typename std::aligned_storage::type; + /*! \brief Start location in page table */ + size_t ptable_begin; + /*! \brief The total number of pages */ + size_t num_pages; + /*! \brief Data */ + DType* data{nullptr}; + // construct a new page + explicit Page(size_t ptable_begin, size_t num_pages) + : ptable_begin(ptable_begin), num_pages(num_pages) { + data = new DType[num_pages]; + } + ~Page() { + delete [] data; + } + }; + // Internal lock + std::mutex mutex_; + // Physical address -> page + std::vector ptable_; + // virtual addres -> page + std::unordered_map > pmap_; + // Free map + std::multimap free_map_; +}; + +/*! + * \brief Register file. + * \tparam kBits Number of bits of one value. + * \tparam kLane Number of lanes in one element. + * \tparam kMaxNumElem Maximum number of element. + */ +template +class SRAM { + public: + /*! \brief Bytes of single vector element */ + static const int kElemBytes = (kBits * kLane + 7) / 8; + /*! \brief content data type */ + using DType = typename std::aligned_storage::type; + SRAM() { + data_ = new DType[kMaxNumElem]; + } + ~SRAM() { + delete [] data_; + } + // Get the i-th index + void* BeginPtr(uint32_t index) { + CHECK_LT(index, kMaxNumElem); + return &(data_[index]); + } + // Execute the load instruction on this SRAM + void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) { + load_counter[0] += (op->x_size * op->y_size) * kElemBytes; + DType* sram_ptr = data_ + op->sram_base; + uint8_t* dram_ptr = static_cast(dram->GetAddr( + op->dram_base * kElemBytes)); + uint64_t xtotal = op->x_size + op->x_pad_0 + op->x_pad_1; + uint32_t ytotal = op->y_size + op->y_pad_0 + op->y_pad_1; + uint64_t sram_end = op->sram_base + xtotal * ytotal; + CHECK_LE(sram_end, kMaxNumElem); + memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_0); + sram_ptr += xtotal * op->y_pad_0; + for (uint32_t y = 0; y < op->y_size; ++y) { + memset(sram_ptr, 0, kElemBytes * op->x_pad_0); + sram_ptr += op->x_pad_0; + memcpy(sram_ptr, dram_ptr, kElemBytes * op->x_size); + sram_ptr += op->x_size; + BitPacker src(sram_ptr); + memset(sram_ptr, 0, kElemBytes * op->x_pad_1); + sram_ptr += op->x_pad_1; + dram_ptr += kElemBytes * op->x_stride; + } + memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_1); + } + // Execute the store instruction on this SRAM apply trucation. + // This relies on the elements is 32 bits + template + void TruncStore(const VTAMemInsn* op, DRAM* dram) { + CHECK_EQ(op->x_pad_0, 0); + CHECK_EQ(op->x_pad_1, 0); + CHECK_EQ(op->y_pad_0, 0); + CHECK_EQ(op->y_pad_1, 0); + int target_width = (target_bits * kLane + 7) / 8; + BitPacker src(data_ + op->sram_base); + BitPacker dst(dram->GetAddr(op->dram_base * target_width)); + for (uint32_t y = 0; y < op->y_size; ++y) { + for (uint32_t x = 0; x < op->x_size; ++x) { + uint32_t sram_base = y * op->x_size + x; + uint32_t dram_base = y * op->x_stride + x; + for (int i = 0; i < kLane; ++i) { + dst.SetSigned(dram_base * kLane + i, + src.GetSigned(sram_base * kLane +i)); + } + } + } + } + + private: + /*! \brief internal data content */ + DType* data_; +}; + + +/*! + * \brief Memory information of special memory region. + * Use MemoryInfo as its container type + */ +class Profiler { + public: + /*! \brief The memory load statistics */ + uint64_t inp_load_nbytes{0}; + /*! \brief The memory load statistics */ + uint64_t wgt_load_nbytes{0}; + /*! \brief The ACC memory load statistics */ + uint64_t acc_load_nbytes{0}; + /*! \brief The ACC memory load statistics */ + uint64_t uop_load_nbytes{0}; + /*! \brief The ACC memory load statistics */ + uint64_t out_store_nbytes{0}; + /*! \brief instr counter for gemm */ + uint64_t gemm_counter{0}; + /*! \brief instr counter for ALU ops */ + uint64_t alu_counter{0}; + /*! \brief clear the profiler */ + void Clear() { + inp_load_nbytes = 0; + wgt_load_nbytes = 0; + acc_load_nbytes = 0; + uop_load_nbytes = 0; + out_store_nbytes = 0; + gemm_counter = 0; + alu_counter = 0; + } + + std::string AsJSON() { + std::ostringstream os; + os << "{\n" + << " \"inp_load_nbytes\":" << inp_load_nbytes << ",\n" + << " \"wgt_load_nbytes\":" << wgt_load_nbytes << ",\n" + << " \"acc_load_nbytes\":" << acc_load_nbytes << ",\n" + << " \"uop_load_nbytes\":" << uop_load_nbytes << ",\n" + << " \"out_store_nbytes\":" << out_store_nbytes << ",\n" + << " \"gemm_counter\":" << gemm_counter << ",\n" + << " \"alu_counter\":" << alu_counter << "\n" + <<"}\n"; + return os.str(); + } + + static Profiler* ThreadLocal() { + static thread_local Profiler inst; + return &inst; + } +}; + + +// Simulate device +// TODO(tqchen,thierry): queue based event driven simulation. +class Device { + public: + Device() { + prof_ = Profiler::ThreadLocal(); + dram_ = DRAM::Global(); + } + + int Run(vta_phy_addr_t insn_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles) { + VTAGenericInsn* insn = static_cast( + dram_->GetAddr(insn_phy_addr)); + finish_counter_ = 0; + for (uint32_t i = 0; i < insn_count; ++i) { + this->Run(insn + i); + } + return 0; + } + + private: + void Run(const VTAGenericInsn* insn) { + const VTAMemInsn* mem = reinterpret_cast(insn); + const VTAGemInsn* gem = reinterpret_cast(insn); + const VTAAluInsn* alu = reinterpret_cast(insn); + switch (mem->opcode) { + case VTA_OPCODE_LOAD: RunLoad(mem); break; + case VTA_OPCODE_STORE: RunStore(mem); break; + case VTA_OPCODE_GEMM: RunGEMM(gem); break; + case VTA_OPCODE_ALU: RunALU(alu); break; + case VTA_OPCODE_FINISH: ++finish_counter_; break; + default: { + LOG(FATAL) << "Unknown op_code" << mem->opcode; + } + } + } + + void RunLoad(const VTAMemInsn* op) { + if (op->x_size == 0) return; + if (op->memory_type == VTA_MEM_ID_INP) { + inp_.Load(op, dram_, &(prof_->inp_load_nbytes)); + } else if (op->memory_type == VTA_MEM_ID_WGT) { + wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes)); + } else if (op->memory_type == VTA_MEM_ID_ACC) { + acc_.Load(op, dram_, &(prof_->acc_load_nbytes)); + } else if (op->memory_type == VTA_MEM_ID_UOP) { + uop_.Load(op, dram_, &(prof_->uop_load_nbytes)); + } else { + LOG(FATAL) << "Unknown memory_type=" << op->memory_type; + } + } + + void RunStore(const VTAMemInsn* op) { + if (op->memory_type == VTA_MEM_ID_ACC || + op->memory_type == VTA_MEM_ID_UOP) { + prof_->out_store_nbytes += ( + op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8); + acc_.TruncStore(op, dram_); + } else { + LOG(FATAL) << "Store do not support memory_type=" + << op->memory_type; + } + } + + void RunGEMM(const VTAGemInsn* op) { + if (!op->reset_reg) { + prof_->gemm_counter += op->iter_out * op->iter_in; + for (uint32_t y = 0; y < op->iter_out; ++y) { + for (uint32_t x = 0; x < op->iter_in; ++x) { + for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) { + VTAUop* uop_ptr = static_cast(uop_.BeginPtr(uindex)); + // Read in memory indices + uint32_t acc_idx = uop_ptr->dst_idx; + uint32_t inp_idx = uop_ptr->src_idx; + uint32_t wgt_idx = uop_ptr->wgt_idx; + acc_idx += y * op->dst_factor_out + x * op->dst_factor_in; + inp_idx += y * op->src_factor_out + x * op->src_factor_in; + wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in; + BitPacker acc(acc_.BeginPtr(acc_idx)); + BitPacker inp(inp_.BeginPtr(inp_idx)); + BitPacker wgt(wgt_.BeginPtr(wgt_idx)); + // gemm loop + for (uint32_t i = 0; i < VTA_BATCH; ++i) { + for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) { + uint32_t acc_offset = i * VTA_BLOCK_OUT + j; + int32_t sum = acc.GetSigned(acc_offset); + for (uint32_t k = 0; k < VTA_BLOCK_IN; ++k) { + sum += + inp.GetSigned(i * VTA_BLOCK_IN + k) * + wgt.GetSigned(j * VTA_BLOCK_IN + k); + } + acc.SetSigned(acc_offset, sum); + } + } + } + } + } + } else { + // reset + for (uint32_t y = 0; y < op->iter_out; ++y) { + for (uint32_t x = 0; x < op->iter_in; ++x) { + for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) { + VTAUop* uop_ptr = static_cast(uop_.BeginPtr(uindex)); + uint32_t acc_idx = uop_ptr->dst_idx; + acc_idx += y * op->dst_factor_out + x * op->dst_factor_in; + BitPacker acc(acc_.BeginPtr(acc_idx)); + for (uint32_t i = 0; i < VTA_BATCH * VTA_BLOCK_OUT; ++i) { + acc.SetSigned(i, 0); + } + } + } + } + } + } + + void RunALU(const VTAAluInsn* op) { + prof_->alu_counter += op->iter_out * op->iter_in; + if (op->use_imm) { + RunALU_(op); + } else { + RunALU_(op); + } + } + + template + void RunALU_(const VTAAluInsn* op) { + switch (op->alu_opcode) { + case VTA_ALU_OPCODE_ADD: { + return RunALULoop(op, [](int32_t x, int32_t y) { + return x + y; + }); + } + case VTA_ALU_OPCODE_MAX: { + return RunALULoop(op, [](int32_t x, int32_t y) { + return std::max(x, y); + }); + } + case VTA_ALU_OPCODE_MIN: { + return RunALULoop(op, [](int32_t x, int32_t y) { + return std::min(x, y); + }); + } + case VTA_ALU_OPCODE_SHR: { + return RunALULoop(op, [](int32_t x, int32_t y) { + if (y >= 0) { + return x >> y; + } else { + return x << (-y); + } + }); + } + default: { + LOG(FATAL) << "Unknown ALU code " << op->alu_opcode; + } + } + } + + template + void RunALULoop(const VTAAluInsn* op, F func) { + for (int y = 0; y < op->iter_out; ++y) { + for (int x = 0; x < op->iter_in; ++x) { + for (int k = op->uop_bgn; k < op->uop_end; ++k) { + // Read micro op + VTAUop* uop_ptr = static_cast(uop_.BeginPtr(k)); + uint32_t dst_index = uop_ptr->dst_idx; + uint32_t src_index = uop_ptr->src_idx; + dst_index += y * op->dst_factor_out + x * op->dst_factor_in; + src_index += y * op->src_factor_out + x * op->src_factor_in; + BitPacker dst(acc_.BeginPtr(dst_index)); + BitPacker src(acc_.BeginPtr(src_index)); + for (int k = 0; k < VTA_BLOCK_OUT; ++k) { + if (use_imm) { + dst.SetSigned(k, func(dst.GetSigned(k), op->imm)); + } else { + dst.SetSigned(k, func(dst.GetSigned(k), src.GetSigned(k))); + } + } + } + } + } + } + // the finish counter + int finish_counter_{0}; + // Prof_ + Profiler* prof_; + // The DRAM interface + DRAM* dram_; + // The SRAM + SRAM inp_; + SRAM wgt_; + SRAM acc_; + SRAM uop_; +}; + +using tvm::runtime::TVMRetValue; +using tvm::runtime::TVMArgs; + +TVM_REGISTER_GLOBAL("vta.simulator.profiler_clear") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Profiler::ThreadLocal()->Clear(); + }); +TVM_REGISTER_GLOBAL("vta.simulator.profiler_status") +.set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = Profiler::ThreadLocal()->AsJSON(); + }); +} // namespace sim +} // namespace vta + +void* VTAMemAlloc(size_t size, int cached) { + return vta::sim::DRAM::Global()->Alloc(size); +} + +void VTAMemFree(void* buf) { + vta::sim::DRAM::Global()->Free(buf); +} + +vta_phy_addr_t VTAMemGetPhyAddr(void* buf) { + return vta::sim::DRAM::Global()->GetPhyAddr(buf); +} + +void VTAFlushCache(vta_phy_addr_t buf, int size) { +} + +void VTAInvalidateCache(vta_phy_addr_t buf, int size) { +} + +VTADeviceHandle VTADeviceAlloc() { + return new vta::sim::Device(); +} + +void VTADeviceFree(VTADeviceHandle handle) { + delete static_cast(handle); +} + +int VTADeviceRun(VTADeviceHandle handle, + vta_phy_addr_t insn_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles) { + return static_cast(handle)->Run( + insn_phy_addr, insn_count, wait_cycles); +} + +void VTAProgram(const char* bitstream) { +} diff --git a/vta/src/tvm/vta_device_api.cc b/vta/src/tvm/vta_device_api.cc index 450b23b05fee..e4671d8a0207 100644 --- a/vta/src/tvm/vta_device_api.cc +++ b/vta/src/tvm/vta_device_api.cc @@ -67,9 +67,6 @@ class VTADeviceAPI final : public DeviceAPI { std::make_shared(); return inst; } - - private: - void* runtime_dll_{nullptr}; }; struct VTAWorkspacePool : public WorkspacePool { diff --git a/vta/tests/python/pynq/test_vta_insn.py b/vta/tests/python/pynq/test_vta_insn.py deleted file mode 100644 index 14baede4e44a..000000000000 --- a/vta/tests/python/pynq/test_vta_insn.py +++ /dev/null @@ -1,504 +0,0 @@ -"""Unit test VTA's instructions """ -import tvm -import vta -import mxnet as mx -import numpy as np -import topi -from tvm.contrib import rpc, util - -host = "pynq" -port = 9091 -target = "llvm -target=armv7-none-linux-gnueabihf" -do_verify = True -print_ir = False - -def test_save_load_out(): - env = vta.get_env() - """Test save/store output command""" - n = 4 - x = tvm.placeholder( - (n, n, env.BATCH, env.BLOCK_OUT), - name="x", - dtype=env.acc_dtype) - x_buf = tvm.compute( - (n, n, env.BATCH, env.BLOCK_OUT), - lambda *i: x(*i), - "x_buf") - # insert no-op that won't be optimized away - y_buf = tvm.compute( - (n, n, env.BATCH, env.BLOCK_OUT), - lambda *i: x_buf(*i)>>0, - "y_buf") - y = tvm.compute( - (n, n, env.BATCH, env.BLOCK_OUT), - lambda *i: y_buf(*i).astype(env.inp_dtype), - "y") - # schedule - s = tvm.create_schedule(y.op) - s[x_buf].set_scope(env.acc_scope) - s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy) - s[y_buf].set_scope(env.acc_scope) - s[y_buf].pragma(y_buf.op.axis[0], env.alu) - s[y].pragma(y.op.axis[0], env.dma_copy) - - def verify(): - # build - with vta.build_config(env.DEBUG_DUMP_INSN): - m = vta.build(s, [x, y], "ext_dev", target) - temp = util.tempdir() - remote = rpc.connect(host, port) - m.save(temp.relpath("load_act.o")) - remote.upload(temp.relpath("load_act.o")) - f = remote.load_module("load_act.o") - # verify - ctx = remote.ext_dev(0) - x_np = np.random.randint( - 1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype(x.dtype) - y_np = x_np.astype(y.dtype) - x_nd = tvm.nd.array(x_np, ctx) - y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) - f(x_nd, y_nd) - np.testing.assert_equal(y_np, y_nd.asnumpy()) - print("\tFinished verification...") - if do_verify: - verify() - - -def test_padded_load(): - """Test padded load.""" - env = vta.get_env() - # declare - n = 21 - m = 20 - pad_before = [0, 1, 0, 0] - pad_after = [1, 3, 0, 0] - x = tvm.placeholder( - (n, m, env.BATCH, env.BLOCK_OUT), - name="x", - dtype=env.acc_dtype) - x_buf = topi.nn.pad(x, pad_before, pad_after, name="y") - # insert no-op that won't be optimized away - y_buf = tvm.compute((n + pad_before[0] + pad_after[0], - m + pad_before[1] + pad_after[1], - env.BATCH, - env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf") - y = tvm.compute((n + pad_before[0] + pad_after[0], - m + pad_before[1] + pad_after[1], - env.BATCH, - env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y") - # schedule - s = tvm.create_schedule(y.op) - s[x_buf].set_scope(env.acc_scope) - s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy) - s[y_buf].set_scope(env.acc_scope) - s[y_buf].pragma(y_buf.op.axis[0], env.alu) - s[y].pragma(y.op.axis[0], env.dma_copy) - - def verify(): - # build - with vta.build_config(env.DEBUG_DUMP_INSN): - mod = vta.build(s, [x, y], "ext_dev", target) - temp = util.tempdir() - remote = rpc.connect(host, port) - mod.save(temp.relpath("padded_load.o")) - remote.upload(temp.relpath("padded_load.o")) - f = remote.load_module("padded_load.o") - # verify - ctx = remote.ext_dev(0) - x_np = np.random.randint(1, 2, size=( - n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype) - y_np = np.zeros((n + pad_before[0] + pad_after[0], - m + pad_before[1] + pad_after[1], - env.BATCH, - env.BLOCK_OUT)).astype(y.dtype) - y_np[pad_before[0]:pad_before[0] + n, - pad_before[1]:pad_before[1] + m, - :] = x_np - x_nd = tvm.nd.array(x_np, ctx) - y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) - f(x_nd, y_nd) - np.testing.assert_equal(y_np, y_nd.asnumpy()) - print("\tFinished verification...") - - if print_ir: - print(vta.lower(s, [y, x], simple_mode=True)) - - -def test_gemm(): - """Test GEMM.""" - env = vta.get_env() - # declare - o = 4 - n = 4 - m = 4 - x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype) - w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype) - x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf") - w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf") - ko = tvm.reduce_axis((0, n), name="ko") - ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki") - y_gem = tvm.compute( - (o, m, env.BATCH, env.BLOCK_OUT), - lambda bo, co, bi, ci: - tvm.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) * - w_buf[co, ko, ci, ki].astype(env.acc_dtype), - axis=[ko, ki]), - name="y_gem") - y_shf = tvm.compute( - (o, m, env.BATCH, env.BLOCK_OUT), - lambda *i: y_gem(*i)>>8, - name="y_shf") - y_max = tvm.compute( - (o, m, env.BATCH, env.BLOCK_OUT), - lambda *i: tvm.max(y_shf(*i), 0), - "y_max") #relu - y_min = tvm.compute( - (o, m, env.BATCH, env.BLOCK_OUT), - lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1), - "y_min") #relu - y = tvm.compute( - (o, m, env.BATCH, env.BLOCK_OUT), - lambda *i: y_min(*i).astype(env.inp_dtype), - name="y") - - def verify(s): - mod = vta.build(s, [x, w, y], "ext_dev", target) - temp = util.tempdir() - remote = rpc.connect(host, port) - mod.save(temp.relpath("gemm.o")) - remote.upload(temp.relpath("gemm.o")) - f = remote.load_module("gemm.o") - # verify - ctx = remote.ext_dev(0) - x_np = np.random.randint( - -128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype) - w_np = np.random.randint( - -128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype) - y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype) - x_nd = tvm.nd.array(x_np, ctx) - w_nd = tvm.nd.array(w_np, ctx) - y_nd = tvm.nd.array(y_np, ctx) - y_np = y_np.astype(env.acc_dtype) - for b in range(o): - for i in range(m): - for j in range(n): - y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype), - w_np[i,j].T.astype(env.acc_dtype)) - y_np = np.right_shift(y_np, 8) - y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype) - f(x_nd, w_nd, y_nd) - np.testing.assert_equal(y_np, y_nd.asnumpy()) - print("\tFinished verification...") - - def test_schedule1(): - # default schedule with no smt - s = tvm.create_schedule(y.op) - # set the scope of the SRAM buffers - s[x_buf].set_scope(env.SCOPE_INP) - s[w_buf].set_scope(env.SCOPE_WGT) - s[y_gem].set_scope(env.acc_scope) - s[y_shf].set_scope(env.acc_scope) - s[y_max].set_scope(env.acc_scope) - s[y_min].set_scope(env.acc_scope) - # set pragmas for DMA transfer and ALU ops - s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy) - s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy) - s[y_shf].pragma(s[y_shf].op.axis[0], env.alu) - s[y_max].pragma(s[y_max].op.axis[0], env.alu) - s[y_min].pragma(s[y_min].op.axis[0], env.alu) - s[y].pragma(s[y].op.axis[0], env.dma_copy) - # tensorization - s[y_gem].reorder( - ko, - s[y_gem].op.axis[0], - s[y_gem].op.axis[1], - s[y_gem].op.axis[2], - s[y_gem].op.axis[3], - ki) - s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM) - if print_ir: - print(vta.lower(s, [x, w, y], simple_mode=True)) - if do_verify: - with vta.build_config(env.DEBUG_DUMP_INSN): - verify(s) - - def test_smt(): - # test smt schedule - s = tvm.create_schedule(y.op) - s[x_buf].set_scope(env.SCOPE_INP) - s[w_buf].set_scope(env.SCOPE_WGT) - s[y_gem].set_scope(env.acc_scope) - s[y_shf].set_scope(env.acc_scope) - s[y_max].set_scope(env.acc_scope) - s[y_min].set_scope(env.acc_scope) - abo, aco, abi, aci = s[y].op.axis - abo1, abo2 = s[y].split(abo, nparts=2) - s[y].bind(abo1, tvm.thread_axis("cthread")) - s[y_gem].compute_at(s[y], abo1) - s[y_shf].compute_at(s[y], abo1) - s[y_max].compute_at(s[y], abo1) - s[y_min].compute_at(s[y], abo1) - s[y_gem].reorder( - ko, - s[y_gem].op.axis[0], - s[y_gem].op.axis[1], - s[y_gem].op.axis[2], - s[y_gem].op.axis[3], - ki) - s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM) - s[y_shf].pragma(s[y_shf].op.axis[0], env.alu) - s[y_max].pragma(s[y_max].op.axis[0], env.alu) - s[y_min].pragma(s[y_min].op.axis[0], env.alu) - s[x_buf].compute_at(s[y_gem], ko) - s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy) - s[w_buf].compute_at(s[y_gem], ko) - s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy) - s[y].pragma(abo2, env.dma_copy) - if print_ir: - print(vta.lower(s, [x, y, w], simple_mode=True)) - if do_verify: - with vta.build_config(env.DEBUG_DUMP_INSN): - verify(s) - - test_schedule1() - test_smt() - -def test_alu(tvm_op, np_op=None, use_imm=False): - """Test ALU""" - env = vta.get_env() - m = 8 - n = 8 - imm = np.random.randint(1,5) - # compute - a = tvm.placeholder( - (m, n, env.BATCH, env.BLOCK_OUT), - name="a", - dtype=env.acc_dtype) - a_buf = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: a(*i), - "a_buf") #DRAM->SRAM - if use_imm: - res_buf = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: tvm_op(a_buf(*i), imm), - "res_buf") #compute - else: - b = tvm.placeholder( - (m, n, env.BATCH, env.BLOCK_OUT), - name="b", - dtype=env.acc_dtype) - b_buf = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: b(*i), - "b_buf") #DRAM->SRAM - res_buf = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: tvm_op(a_buf(*i), b_buf(*i)), - "res_buf") #compute - res = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: res_buf(*i).astype(env.inp_dtype), - "res") #SRAM->DRAM - # schedule - s = tvm.create_schedule(res.op) - s[a_buf].set_scope(env.acc_scope) # SRAM - s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM - s[res_buf].set_scope(env.acc_scope) # SRAM - s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute - s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM - if use_imm: - if print_ir: - print(vta.lower(s, [a, res], simple_mode=True)) - else: - s[b_buf].set_scope(env.acc_scope) # SRAM - s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM - if print_ir: - print(vta.lower(s, [a, b, res], simple_mode=True)) - - def verify(): - # build - with vta.build_config(): - if use_imm: - mod = vta.build(s, [a, res], "ext_dev", target) - else: - mod = vta.build(s, [a, b, res], "ext_dev", target) - temp = util.tempdir() - remote = rpc.connect(host, port) - mod.save(temp.relpath("load_act.o")) - remote.upload(temp.relpath("load_act.o")) - f = remote.load_module("load_act.o") - # verify - ctx = remote.ext_dev(0) - a_np = np.random.randint( - -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype) - if use_imm: - res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm) - else: - b_np = np.random.randint( - -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype) - res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np) - res_np = res_np.astype(res.dtype) - a_nd = tvm.nd.array(a_np, ctx) - res_nd = tvm.nd.array( - np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) - if use_imm: - f(a_nd, res_nd) - else: - b_nd = tvm.nd.array(b_np, ctx) - f(a_nd, b_nd, res_nd) - np.testing.assert_equal(res_np, res_nd.asnumpy()) - print("\tFinished verification...") - - if do_verify: - verify() - -def test_relu(): - """Test RELU on ALU""" - env = vta.get_env() - m = 8 - n = 8 - # compute - a = tvm.placeholder( - (m, n, env.BATCH, env.BLOCK_OUT), - name="a", - dtype=env.acc_dtype) - a_buf = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: a(*i), - "a_buf") # DRAM->SRAM - max_buf = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: tvm.max(a_buf(*i), 0), - "res_buf") # relu - min_buf = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1), - "max_buf") # relu - res = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: min_buf(*i).astype(env.inp_dtype), - "min_buf") # SRAM->DRAM - # schedule - s = tvm.create_schedule(res.op) - s[a_buf].set_scope(env.acc_scope) # SRAM - s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM - s[max_buf].set_scope(env.acc_scope) # SRAM - s[min_buf].set_scope(env.acc_scope) # SRAM - s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute - s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute - s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM - if print_ir: - print(vta.lower(s, [a, res], simple_mode=True)) - - def verify(): - # build - with vta.build_config(env.DEBUG_DUMP_INSN): - mod = vta.build(s, [a, res], "ext_dev", target) - temp = util.tempdir() - remote = rpc.connect(host, port) - mod.save(temp.relpath("load_act.o")) - remote.upload(temp.relpath("load_act.o")) - f = remote.load_module("load_act.o") - # verify - ctx = remote.ext_dev(0) - a_np = np.random.randint( - -256, 256, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype) - res_np = np.clip(a_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype) - a_nd = tvm.nd.array(a_np, ctx) - res_nd = tvm.nd.array( - np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) - f(a_nd, res_nd) - np.testing.assert_equal(res_np, res_nd.asnumpy()) - print("\tFinished verification...") - - if do_verify: - verify() - -def test_shift_and_scale(): - """Test shift and scale on ALU""" - env = vta.get_env() - m = 8 - n = 8 - imm_shift = np.random.randint(-10,10) - imm_scale = np.random.randint(1,5) - # compute - a = tvm.placeholder( - (m, n, env.BATCH, env.BLOCK_OUT), - name="a", dtype=env.acc_dtype) - a_buf = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: a(*i), - "a_buf") # DRAM->SRAM - res_shift = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: a_buf(*i)+imm_shift, - "res_shift") # compute - res_scale = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: res_shift(*i)>>imm_scale, - "res_scale") # compute - res = tvm.compute( - (m, n, env.BATCH, env.BLOCK_OUT), - lambda *i: res_scale(*i).astype(env.inp_dtype), - "res") # SRAM->DRAM - # schedule - s = tvm.create_schedule(res.op) - s[a_buf].set_scope(env.acc_scope) # SRAM - s[res_shift].set_scope(env.acc_scope) # SRAM - s[res_scale].set_scope(env.acc_scope) # SRAM - s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM - s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute - s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute - s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM - - if print_ir: - print(vta.lower(s, [a, res], simple_mode=True)) - - def verify(): - # build - mod = vta.build(s, [a, res], "ext_dev", target) - temp = util.tempdir() - remote = rpc.connect(host, port) - mod.save(temp.relpath("load_act.o")) - remote.upload(temp.relpath("load_act.o")) - f = remote.load_module("load_act.o") - # verify - ctx = remote.ext_dev(0) - a_np = np.random.randint( - -10, 10, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype) - res_np = np.right_shift((a_np + imm_shift), imm_scale) - res_np = res_np.astype(res.dtype) - a_nd = tvm.nd.array(a_np, ctx) - res_nd = tvm.nd.array( - np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) - f(a_nd, res_nd) - np.testing.assert_equal(res_np, res_nd.asnumpy()) - print("\tFinished verification...") - - if do_verify: - verify() - -if __name__ == "__main__": - print("Load/store test") - test_save_load_out() - print("Padded load test") - test_padded_load() - # print("GEMM test") - # test_gemm() - print("Max immediate") - test_alu(tvm.max, np.maximum, use_imm=True) - print("Max") - test_alu(tvm.max, np.maximum) - print("Add immediate") - test_alu(lambda x, y: x + y, use_imm=True) - print("Add") - test_alu(lambda x, y: x + y) - print("Shift right immediate") - test_alu(lambda x, y: x >> y, np.right_shift, use_imm=True) - print("Shift left immediate") - test_alu(lambda x, y: x << y, np.left_shift, use_imm=True) - print("Relu") - test_relu() - # print("Shift and scale") - # test_shift_and_scale() diff --git a/vta/tests/python/unittest/test_vta_insn.py b/vta/tests/python/unittest/test_vta_insn.py new file mode 100644 index 000000000000..339d8d31e238 --- /dev/null +++ b/vta/tests/python/unittest/test_vta_insn.py @@ -0,0 +1,482 @@ +"""Unit test VTA's instructions """ +import tvm +import numpy as np +import topi +from tvm.contrib import rpc, util + +import vta +import vta.testing +from vta.testing import simulator + + +def test_save_load_out(): + """Test save/store output command""" + def _run(env, remote): + n = 6 + x = tvm.placeholder( + (n, n, env.BATCH, env.BLOCK_OUT), + name="x", + dtype=env.acc_dtype) + x_buf = tvm.compute( + (n, n, env.BATCH, env.BLOCK_OUT), + lambda *i: x(*i), "x_buf") + # insert no-op that won't be optimized away + y_buf = tvm.compute( + (n, n, env.BATCH, env.BLOCK_OUT), + lambda *i: x_buf(*i)>>0, "y_buf") + y = tvm.compute( + (n, n, env.BATCH, env.BLOCK_OUT), + lambda *i: y_buf(*i).astype(env.inp_dtype), "y") + # schedule + s = tvm.create_schedule(y.op) + s[x_buf].set_scope(env.acc_scope) + s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy) + s[y_buf].set_scope(env.acc_scope) + s[y_buf].pragma(y_buf.op.axis[0], env.alu) + s[y].pragma(y.op.axis[0], env.dma_copy) + + # verification + with vta.build_config(): + m = vta.build(s, [x, y], "ext_dev", env.target_host) + + if not remote: + return + temp = util.tempdir() + m.save(temp.relpath("load_act.o")) + remote.upload(temp.relpath("load_act.o")) + f = remote.load_module("load_act.o") + # verify + ctx = remote.ext_dev(0) + x_np = np.random.randint( + 1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype(x.dtype) + y_np = x_np.astype(y.dtype) + x_nd = tvm.nd.array(x_np, ctx) + y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) + f(x_nd, y_nd) + np.testing.assert_equal(y_np, y_nd.asnumpy()) + + vta.testing.run(_run) + + +def test_padded_load(): + """Test padded load.""" + def _run(env, remote): + # declare + n = 21 + m = 20 + pad_before = [0, 1, 0, 0] + pad_after = [1, 3, 0, 0] + x = tvm.placeholder( + (n, m, env.BATCH, env.BLOCK_OUT), + name="x", + dtype=env.acc_dtype) + x_buf = topi.nn.pad(x, pad_before, pad_after, name="y") + # insert no-op that won't be optimized away + y_buf = tvm.compute((n + pad_before[0] + pad_after[0], + m + pad_before[1] + pad_after[1], + env.BATCH, + env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf") + y = tvm.compute((n + pad_before[0] + pad_after[0], + m + pad_before[1] + pad_after[1], + env.BATCH, + env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y") + # schedule + s = tvm.create_schedule(y.op) + s[x_buf].set_scope(env.acc_scope) + s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy) + s[y_buf].set_scope(env.acc_scope) + s[y_buf].pragma(y_buf.op.axis[0], env.alu) + s[y].pragma(y.op.axis[0], env.dma_copy) + # build + with vta.build_config(): + mod = vta.build(s, [x, y], "ext_dev", env.target_host) + + if not remote: + return + temp = util.tempdir() + mod.save(temp.relpath("padded_load.o")) + remote.upload(temp.relpath("padded_load.o")) + f = remote.load_module("padded_load.o") + # verify + ctx = remote.ext_dev(0) + x_np = np.random.randint(1, 2, size=( + n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype) + y_np = np.zeros((n + pad_before[0] + pad_after[0], + m + pad_before[1] + pad_after[1], + env.BATCH, + env.BLOCK_OUT)).astype(y.dtype) + y_np[pad_before[0]:pad_before[0] + n, + pad_before[1]:pad_before[1] + m, + :] = x_np + x_nd = tvm.nd.array(x_np, ctx) + y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) + f(x_nd, y_nd) + np.testing.assert_equal(y_np, y_nd.asnumpy()) + + vta.testing.run(_run) + + +def test_gemm(): + """Test GEMM.""" + def _run(env, remote): + # declare + o = 4 + n = 1 + m = 4 + x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype) + w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype) + x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf") + w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf") + ko = tvm.reduce_axis((0, n), name="ko") + ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki") + y_gem = tvm.compute( + (o, m, env.BATCH, env.BLOCK_OUT), + lambda bo, co, bi, ci: + tvm.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) * + w_buf[co, ko, ci, ki].astype(env.acc_dtype), + axis=[ko, ki]), + name="y_gem") + y_shf = tvm.compute( + (o, m, env.BATCH, env.BLOCK_OUT), + lambda *i: y_gem(*i)>>8, + name="y_shf") + y_max = tvm.compute( + (o, m, env.BATCH, env.BLOCK_OUT), + lambda *i: tvm.max(y_shf(*i), 0), + "y_max") #relu + y_min = tvm.compute( + (o, m, env.BATCH, env.BLOCK_OUT), + lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1), + "y_min") #relu + y = tvm.compute( + (o, m, env.BATCH, env.BLOCK_OUT), + lambda *i: y_min(*i).astype(env.inp_dtype), + name="y") + + if not remote: + return + + def verify(s): + mod = vta.build(s, [x, w, y], "ext_dev", env.target_host) + temp = util.tempdir() + mod.save(temp.relpath("gemm.o")) + remote.upload(temp.relpath("gemm.o")) + f = remote.load_module("gemm.o") + # verify + ctx = remote.ext_dev(0) + x_np = np.random.randint( + -128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype) + w_np = np.random.randint( + -128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype) + y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype) + x_nd = tvm.nd.array(x_np, ctx) + w_nd = tvm.nd.array(w_np, ctx) + y_nd = tvm.nd.array(y_np, ctx) + y_np = y_np.astype(env.acc_dtype) + for b in range(o): + for i in range(m): + for j in range(n): + y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype), + w_np[i,j].T.astype(env.acc_dtype)) + y_np = np.right_shift(y_np, 8) + y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype) + + if env.TARGET == "sim": + simulator.clear_stats() + f(x_nd, w_nd, y_nd) + print(simulator.stats()) + else: + f(x_nd, w_nd, y_nd) + + np.testing.assert_equal(y_np, y_nd.asnumpy()) + + def test_schedule1(): + # default schedule with no smt + s = tvm.create_schedule(y.op) + # set the scope of the SRAM buffers + s[x_buf].set_scope(env.inp_scope) + s[w_buf].set_scope(env.wgt_scope) + s[y_gem].set_scope(env.acc_scope) + s[y_shf].set_scope(env.acc_scope) + s[y_max].set_scope(env.acc_scope) + s[y_min].set_scope(env.acc_scope) + # set pragmas for DMA transfer and ALU ops + s[x_buf].compute_at(s[y_gem], ko) + s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy) + s[w_buf].compute_at(s[y_gem], ko) + s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy) + s[y_shf].pragma(s[y_shf].op.axis[0], env.alu) + s[y_max].pragma(s[y_max].op.axis[0], env.alu) + s[y_min].pragma(s[y_min].op.axis[0], env.alu) + s[y].pragma(s[y].op.axis[0], env.dma_copy) + # tensorization + s[y_gem].reorder( + ko, + s[y_gem].op.axis[0], + s[y_gem].op.axis[1], + s[y_gem].op.axis[2], + s[y_gem].op.axis[3], + ki) + s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm) + verify(s) + + def test_smt(): + # test smt schedule + s = tvm.create_schedule(y.op) + s[x_buf].set_scope(env.inp_scope) + s[w_buf].set_scope(env.wgt_scope) + s[y_gem].set_scope(env.acc_scope) + s[y_shf].set_scope(env.acc_scope) + s[y_max].set_scope(env.acc_scope) + s[y_min].set_scope(env.acc_scope) + abo, aco, abi, aci = s[y].op.axis + abo1, abo2 = s[y].split(abo, nparts=2) + s[y].bind(abo1, tvm.thread_axis("cthread")) + s[y_gem].compute_at(s[y], abo1) + s[y_shf].compute_at(s[y], abo1) + s[y_max].compute_at(s[y], abo1) + s[y_min].compute_at(s[y], abo1) + s[y_gem].reorder( + ko, + s[y_gem].op.axis[0], + s[y_gem].op.axis[1], + s[y_gem].op.axis[2], + s[y_gem].op.axis[3], + ki) + s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm) + s[y_shf].pragma(s[y_shf].op.axis[0], env.alu) + s[y_max].pragma(s[y_max].op.axis[0], env.alu) + s[y_min].pragma(s[y_min].op.axis[0], env.alu) + s[x_buf].compute_at(s[y_gem], ko) + s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy) + s[w_buf].compute_at(s[y_gem], ko) + s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy) + s[y].pragma(abo2, env.dma_copy) + verify(s) + + test_schedule1() + test_smt() + vta.testing.run(_run) + + +def test_alu(): + def _run(env, remote): + def check_alu(tvm_op, np_op=None, use_imm=False): + """Test ALU""" + m = 8 + n = 8 + imm = np.random.randint(1,5) + # compute + a = tvm.placeholder( + (m, n, env.BATCH, env.BLOCK_OUT), + name="a", + dtype=env.acc_dtype) + a_buf = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: a(*i), + "a_buf") #DRAM->SRAM + if use_imm: + res_buf = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: tvm_op(a_buf(*i), imm), + "res_buf") #compute + else: + b = tvm.placeholder( + (m, n, env.BATCH, env.BLOCK_OUT), + name="b", + dtype=env.acc_dtype) + b_buf = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: b(*i), + "b_buf") #DRAM->SRAM + res_buf = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: tvm_op(a_buf(*i), b_buf(*i)), + "res_buf") #compute5B + res = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: res_buf(*i).astype(env.inp_dtype), + "res") #SRAM->DRAM + # schedule + s = tvm.create_schedule(res.op) + s[a_buf].set_scope(env.acc_scope) # SRAM + s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM + s[res_buf].set_scope(env.acc_scope) # SRAM + s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute + s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM + if not use_imm: + s[b_buf].set_scope(env.acc_scope) # SRAM + s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM + + if not remote: + return + + # build + with vta.build_config(): + if use_imm: + mod = vta.build(s, [a, res], "ext_dev", env.target_host) + else: + mod = vta.build(s, [a, b, res], "ext_dev", env.target_host) + temp = util.tempdir() + mod.save(temp.relpath("load_act.o")) + remote.upload(temp.relpath("load_act.o")) + f = remote.load_module("load_act.o") + # verify + ctx = remote.ext_dev(0) + a_np = np.random.randint( + -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype) + if use_imm: + res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm) + else: + b_np = np.random.randint( + -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype) + res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np) + res_np = res_np.astype(res.dtype) + a_nd = tvm.nd.array(a_np, ctx) + res_nd = tvm.nd.array( + np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) + if use_imm: + f(a_nd, res_nd) + else: + b_nd = tvm.nd.array(b_np, ctx) + f(a_nd, b_nd, res_nd) + np.testing.assert_equal(res_np, res_nd.asnumpy()) + + check_alu(lambda x, y: x << y, np.left_shift, use_imm=True) + check_alu(tvm.max, np.maximum, use_imm=True) + check_alu(tvm.max, np.maximum) + check_alu(lambda x, y: x + y, use_imm=True) + check_alu(lambda x, y: x + y) + check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True) + + vta.testing.run(_run) + + +def test_relu(): + """Test RELU on ALU""" + def _run(env, remote): + m = 8 + n = 10 + # compute + a = tvm.placeholder( + (m, n, env.BATCH, env.BLOCK_OUT), + name="a", + dtype=env.acc_dtype) + a_buf = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: a(*i), + "a_buf") # DRAM->SRAM + max_buf = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: tvm.max(a_buf(*i), 0), + "res_buf") # relu + min_buf = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1), + "max_buf") # relu + res = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: min_buf(*i).astype(env.inp_dtype), + "min_buf") # SRAM->DRAM + # schedule + s = tvm.create_schedule(res.op) + s[a_buf].set_scope(env.acc_scope) # SRAM + s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM + s[max_buf].set_scope(env.acc_scope) # SRAM + s[min_buf].set_scope(env.acc_scope) # SRAM + s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute + s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute + s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM + # build + with vta.build_config(): + mod = vta.build(s, [a, res], "ext_dev", env.target_host) + if not remote: + return + temp = util.tempdir() + mod.save(temp.relpath("load_act.o")) + remote.upload(temp.relpath("load_act.o")) + f = remote.load_module("load_act.o") + # verify + ctx = remote.ext_dev(0) + a_np = np.random.randint( + -256, 256, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype) + res_np = np.clip(a_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype) + a_nd = tvm.nd.array(a_np, ctx) + res_nd = tvm.nd.array( + np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) + f(a_nd, res_nd) + np.testing.assert_equal(res_np, res_nd.asnumpy()) + + vta.testing.run(_run) + + +def test_shift_and_scale(): + """Test shift and scale on ALU""" + def _run(env, remote): + m = 2 + n = 8 + imm_shift = np.random.randint(0,8) + imm_scale = np.random.randint(1,5) + # compute + a = tvm.placeholder( + (m, n, env.BATCH, env.BLOCK_OUT), + name="a", dtype=env.acc_dtype) + a_buf = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: a(*i), + "a_buf") # DRAM->SRAM + res_shift = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: a_buf(*i)+imm_shift, + "res_shift") # compute + res_scale = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: res_shift(*i)>>imm_scale, + "res_scale") # compute + res = tvm.compute( + (m, n, env.BATCH, env.BLOCK_OUT), + lambda *i: res_scale(*i).astype(env.inp_dtype), + "res") # SRAM->DRAM + # schedule + s = tvm.create_schedule(res.op) + s[a_buf].set_scope(env.acc_scope) # SRAM + s[res_shift].set_scope(env.acc_scope) # SRAM + s[res_scale].set_scope(env.acc_scope) # SRAM + s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM + s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute + s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute + s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM + # build + mod = vta.build(s, [a, res], "ext_dev", env.target_host) + if not remote: + return + temp = util.tempdir() + mod.save(temp.relpath("load_act.o")) + remote.upload(temp.relpath("load_act.o")) + f = remote.load_module("load_act.o") + # verify + ctx = remote.ext_dev(0) + a_np = np.random.randint( + -10, 10, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype) + res_np = np.right_shift((a_np + imm_shift), imm_scale) + res_np = res_np.astype(res.dtype) + a_nd = tvm.nd.array(a_np, ctx) + res_nd = tvm.nd.array( + np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) + f(a_nd, res_nd) + np.testing.assert_equal(res_np, res_nd.asnumpy()) + + vta.testing.run(_run) + +if __name__ == "__main__": + print("Load/store test") + test_save_load_out() + print("Padded load test") + #test_padded_load() + print("GEMM test") + test_gemm() + test_alu() + print("ALU test") + test_relu() + print("Shift and scale") + test_shift_and_scale()