diff --git a/vta/apps/gemm/CMakeLists.txt b/vta/apps/gemm/CMakeLists.txt new file mode 100644 index 000000000000..0e8128c9f22a --- /dev/null +++ b/vta/apps/gemm/CMakeLists.txt @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.2) +project(tsim C CXX) + +set(TVM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../) +set(VTA_DIR ${TVM_DIR}/vta) + +include_directories("${TVM_DIR}/include") +include_directories("${TVM_DIR}/3rdparty/dlpack/include") +include_directories("${TVM_DIR}/3rdparty/dmlc-core/include") +include_directories("${TVM_DIR}/vta/src/dpi") + +set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden") +set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11") + +if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}") +endif() + +file(GLOB TSIM_SW_SRC src/driver.cc) +list(APPEND TSIM_SW_SRC ${VTA_DIR}/src/vmem/virtual_memory.cc) +list(APPEND TSIM_SW_SRC ${VTA_DIR}/src/dpi/module.cc) + +add_library(sw SHARED ${TSIM_SW_SRC}) +target_include_directories(sw PRIVATE ${VTA_DIR}/include ${VTA_DIR}/src) + +if(APPLE) + set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") +endif(APPLE) diff --git a/vta/apps/gemm/Makefile b/vta/apps/gemm/Makefile new file mode 100644 index 000000000000..8ad1481cf7fc --- /dev/null +++ b/vta/apps/gemm/Makefile @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH) + +BUILD_NAME = build +build_dir = $(abspath .)/$(BUILD_NAME) + +default: chisel driver + python3 tests/python/chisel_accel.py serial + +serial: + python3 tests/python/chisel_accel.py serial + +parallel: + python3 tests/python/chisel_accel.py parallel + +driver: | $(build_dir) + cd $(build_dir) && cmake .. && make + +$(build_dir): + mkdir -p $@ + +chisel: + make -C hardware/chisel + +clean: + -rm -rf $(build_dir) + make -C hardware/chisel clean diff --git a/vta/apps/gemm/README.md b/vta/apps/gemm/README.md new file mode 100644 index 000000000000..fba5924e7a4f --- /dev/null +++ b/vta/apps/gemm/README.md @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + +VTA TSIM Application +====================== +Prior to this application, please take a look at `/vta/apps/tsim_example` for installation +This is an application that performs Bit Serial Multiplication for GEMM utilizing TSIM. + +**Bit Serial Multiplication for GEMM:** + +General Matrix Multiplications (GEMM), are mostly calculated by repeatly calculating the dot product for each pair of vectors. +The dot product is calculated by summing every product of the vector pair. +We approach this operation with slicing and shifting, like how basic multiplication works, each vector elements before we accumulate them. +We can sufficiently reduce the cycles required to perform a gemm given that the data bit width is small. This GEMM application uses TSIM for future accerlerator prototypes. + +* Test Chisel3 backend with bit serial GEMM + * Go to `/vta/apps/gemm` + * Run `make` + +* If you have already compiled chisel backend (i.e. ran `make`) + * Bit Serial test with another input set, run `make serial` + * Bit parallel test with another input set, run `make parallel` + +* Some steps for creating your own custom TSIM application + * Go to `/vta/apps/gemm` + * Create custom circuit within `./hardware/chisel/src/scala.main/accel/Compute.scala` + * Map the according Registers in `./hardware/chisel/src/scala.main/accel/RegFile.scala` + * Create your test script + * Map the registers in `./src/driver.cc` and link it with both `RegFile.scala` and the test script + * Understanding of `/vta/apps/tsim_example`, which performs add by one to a vector, is highly encouraged to create a more complex application + +* Some pointers + * Chisel3 tests in `/vta/apps/gemm/tests/python` + * Chisel3 accelerator backend `/vta/apps/gemm/hardware/chisel` + * Software C++ driver (backend) that handles the accelerator `/vta/apps/gemm/src/driver.cc` + * Software Python driver (frontend) that handles the accelerator `/vta/apps/gemm/python/accel` diff --git a/vta/apps/gemm/hardware/chisel/Makefile b/vta/apps/gemm/hardware/chisel/Makefile new file mode 100644 index 000000000000..4462b7a88477 --- /dev/null +++ b/vta/apps/gemm/hardware/chisel/Makefile @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +ifeq (, $(shell which verilator)) + $(error "No Verilator in $(PATH), consider doing apt-get install verilator") +endif + +# Change VERILATOR_INC_DIR if Verilator is installed on a different location +ifeq (, $(VERILATOR_INC_DIR)) + ifeq (, $(wildcard /usr/local/share/verilator/include/*)) + ifeq (, $(wildcard /usr/share/verilator/include/*)) + $(error "Verilator include directory is not set properly") + else + VERILATOR_INC_DIR := /usr/share/verilator/include + endif + else + VERILATOR_INC_DIR := /usr/local/share/verilator/include + endif +endif + +TOP = TestAccel +BUILD_NAME = build +USE_TRACE = 1 +LIBNAME = libhw + +vta_dir = $(abspath ../../../../) +tvm_dir = $(abspath ../../../../../) +build_dir = $(abspath .)/$(BUILD_NAME) +verilator_build_dir = $(build_dir)/verilator +chisel_build_dir = $(build_dir)/chisel + +verilator_opt = --cc +verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN +verilator_opt += +define+RANDOMIZE_REG_INIT +verilator_opt += +define+RANDOMIZE_MEM_INIT +verilator_opt += --x-assign unique +verilator_opt += --output-split 20000 +verilator_opt += --output-split-cfuncs 20000 +verilator_opt += --top-module ${TOP} +verilator_opt += -Mdir ${verilator_build_dir} +verilator_opt += -I$(chisel_build_dir) + +cxx_flags = -O2 -Wall -fPIC -shared +cxx_flags += -fvisibility=hidden -std=c++11 +cxx_flags += -DVL_TSIM_NAME=V$(TOP) +cxx_flags += -DVL_PRINTF=printf +cxx_flags += -DVL_USER_FINISH +cxx_flags += -DVM_COVERAGE=0 +cxx_flags += -DVM_SC=0 +cxx_flags += -Wno-sign-compare +cxx_flags += -include V$(TOP).h +cxx_flags += -I$(verilator_build_dir) +cxx_flags += -I$(VERILATOR_INC_DIR) +cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd +cxx_flags += -I$(vta_dir)/include +cxx_flags += -I$(tvm_dir)/include +cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include + +cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp +cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp +cxx_files += $(wildcard $(verilator_build_dir)/*.cpp) +cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc + +ifneq ($(USE_TRACE), 0) + verilator_opt += --trace + cxx_flags += -DVM_TRACE=1 + cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP).vcd + cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp +else + cxx_flags += -DVM_TRACE=0 +endif + +# The following is to be consistent with cmake +ifeq ($(shell uname), Darwin) + lib_path = $(build_dir)/$(LIBNAME).dylib +else + lib_path = $(build_dir)/$(LIBNAME).so +endif + +default: lib + +lib: $(lib_path) +$(lib_path): $(verilator_build_dir)/V$(TOP).cpp + g++ $(cxx_flags) $(cxx_files) -o $@ + +verilator: $(verilator_build_dir)/V$(TOP).cpp +$(verilator_build_dir)/V$(TOP).cpp: $(chisel_build_dir)/$(TOP).v + verilator $(verilator_opt) $< + +verilog: $(chisel_build_dir)/$(TOP).v +$(chisel_build_dir)/$(TOP).v: install_vta_package + sbt 'test:runMain test.Elaborate --target-dir $(chisel_build_dir) --top-name $(TOP)' + +install_vta_package: + cd $(vta_dir)/hardware/chisel && sbt publishLocal + +clean: + -rm -rf $(build_dir) target project/target project/project diff --git a/vta/apps/gemm/hardware/chisel/build.sbt b/vta/apps/gemm/hardware/chisel/build.sbt new file mode 100644 index 000000000000..a2afc0d9d362 --- /dev/null +++ b/vta/apps/gemm/hardware/chisel/build.sbt @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +name := "accel" +version := "0.1.0-SNAPSHOT" +organization := "edu.washington.cs" + +def scalacOptionsVersion(scalaVersion: String): Seq[String] = { + Seq() ++ { + // If we're building with Scala > 2.11, enable the compile option + // switch to support our anonymous Bundle definitions: + // https://github.com/scala/bug/issues/10047 + CrossVersion.partialVersion(scalaVersion) match { + case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq() + case _ => Seq( + "-Xsource:2.11", + "-language:reflectiveCalls", + "-language:implicitConversions", + "-deprecation", + "-Xlint", + "-Ywarn-unused", + ) + } + } +} + +def javacOptionsVersion(scalaVersion: String): Seq[String] = { + Seq() ++ { + // Scala 2.12 requires Java 8. We continue to generate + // Java 7 compatible code for Scala 2.11 + // for compatibility with old clients. + CrossVersion.partialVersion(scalaVersion) match { + case Some((2, scalaMajor: Long)) if scalaMajor < 12 => + Seq("-source", "1.7", "-target", "1.7") + case _ => + Seq("-source", "1.8", "-target", "1.8") + } + } +} + +scalaVersion := "2.11.12" + +resolvers ++= Seq( + Resolver.sonatypeRepo("snapshots"), + Resolver.sonatypeRepo("releases")) + +libraryDependencies ++= Seq( + "edu.berkeley.cs" %% "chisel3" % "3.1.7", + "edu.washington.cs" %% "vta" % "0.1.0-SNAPSHOT", +) + +scalacOptions ++= scalacOptionsVersion(scalaVersion.value) +javacOptions ++= javacOptionsVersion(scalaVersion.value) diff --git a/vta/apps/gemm/hardware/chisel/project/build.properties b/vta/apps/gemm/hardware/chisel/project/build.properties new file mode 100644 index 000000000000..7e2b74b51a4f --- /dev/null +++ b/vta/apps/gemm/hardware/chisel/project/build.properties @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +sbt.version = 1.1.1 diff --git a/vta/apps/gemm/hardware/chisel/project/plugins.sbt b/vta/apps/gemm/hardware/chisel/project/plugins.sbt new file mode 100644 index 000000000000..79ffb2245d52 --- /dev/null +++ b/vta/apps/gemm/hardware/chisel/project/plugins.sbt @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +logLevel := Level.Warn diff --git a/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Accel.scala b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Accel.scala new file mode 100644 index 000000000000..add07c320c1e --- /dev/null +++ b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Accel.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package accel + +import chisel3._ +import vta.dpi._ + +/** Add-by-one accelerator. + * + * ___________ ___________ + * | | | | + * | HostDPI | <--> | RegFile | <->| + * |_________| |_________| | + * | + * ___________ ___________ | + * | | | | | + * | MemDPI | <--> | Compute | <->| + * |_________| |_________| + * + */ +case class AccelConfig() { + val nCtrl = 1 + val nECnt = 1 + val nVals = 4 + val nPtrs = 3 + val regBits = 32 + val ptrBits = 2*regBits +} + +class Accel extends Module { + val io = IO(new Bundle { + val host = new VTAHostDPIClient + val mem = new VTAMemDPIMaster + }) + implicit val config = AccelConfig() + val rf = Module(new RegFile) + val ce = Module(new Compute) + rf.io.host <> io.host + io.mem <> ce.io.mem + ce.io.launch := rf.io.launch + rf.io.finish := ce.io.finish + rf.io.ecnt <> ce.io.ecnt + ce.io.vals <> rf.io.vals + ce.io.ptrs <> rf.io.ptrs +} diff --git a/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala new file mode 100644 index 000000000000..325fce1bf38a --- /dev/null +++ b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package accel + +import chisel3._ +import chisel3.util._ +import vta.dpi._ + +/** Compute + * + * Bit Slice GEMM: + * + * 1. Wait for launch to be asserted + * 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address + * 3. Wait for the value + * 4. Increment read-address for next value + * 5. Wait for sliced accumulator + * 6. Check if counter (cnt) is equal to length process, + otherwise goto step 2 + * 7. Check if reset slice accumulator + * 8. Wait for overall accumulator + * 8. Issue a write request for 8-byte value at out_baddr address + */ +class Compute(implicit config: AccelConfig) extends Module { + val io = IO(new Bundle { + val launch = Input(Bool()) + val finish = Output(Bool()) + val ecnt = Vec(config.nECnt, ValidIO(UInt(config.regBits.W))) + val vals = Input(Vec(config.nVals, UInt(config.regBits.W))) + val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W))) + val mem = new VTAMemDPIMaster + }) + val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7) + val state = RegInit(sIdle) + val shift = io.vals(0) + val length = io.vals(1) + val rstAccum = io.vals(2) + val startDot = io.vals(3) + val cycles = RegInit(0.U(config.regBits.W)) + val reg1 = Reg(chiselTypeOf(io.mem.rd.bits)) + val reg2 = Reg(chiselTypeOf(io.mem.rd.bits)) + val cnt = Reg(UInt(config.regBits.W)) + val raddr1 = Reg(UInt(config.ptrBits.W)) + val raddr2 = Reg(UInt(config.ptrBits.W)) + val waddr = Reg(UInt(config.ptrBits.W)) + + switch (state) { + is (sIdle) { + when (io.launch) { + state := sReadAReq + } + } + // Read + is (sReadAReq) { + state := sReadAData + } + is (sReadAData) { + when (io.mem.rd.valid) { + state := sReadBReq + } + } + is (sReadBReq) { + state := sReadBData + } + is (sReadBData) { + when (io.mem.rd.valid) { + state := sWriteReq + } + } + // Write + is (sWriteReq) { + state := sWriteData + } + is (sWriteData) { + when (cnt === (length - 1.U)) { + state := sIdle + } .otherwise { + state := sReadAReq + } + } + } + + val last = state === sWriteData && cnt === (length - 1.U) + + // cycle counter + when (state === sIdle) { + cycles := 0.U + } .otherwise { + cycles := cycles + 1.U + } + + io.ecnt(0).valid := last + io.ecnt(0).bits := cycles + + // calculate next address + when (state === sIdle) { + raddr1 := io.ptrs(0) + raddr2 := io.ptrs(1) + waddr := io.ptrs(2) + } .elsewhen (state === sWriteData) { // increment input array by 1-byte + raddr1 := raddr1 + 1.U + raddr2 := raddr2 + 1.U + waddr := waddr + } + + // create request + io.mem.req.valid := state === sReadAReq | state === sReadBReq | state === sWriteReq + io.mem.req.opcode := state === sWriteReq + io.mem.req.len := 0.U // one-word-per-request + io.mem.req.addr := Mux(state === sReadAReq | state === sReadBReq, Mux(state === sReadAReq, raddr1, raddr2), waddr) + + // read + when (state === sReadAData && io.mem.rd.valid) { + reg1 := io.mem.rd.bits(7, 0) + } + + when (state === sReadBData && io.mem.rd.valid) { + reg2 := io.mem.rd.bits(7, 0) + } + + io.mem.rd.ready := state === sReadAData | state === sReadBData + + + val sliceAccum = Module(new Accumulator(63)) + val overallAccum = Module(new Accumulator(64)) + + sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed + sliceAccum.io.in := reg1 * reg2 + sliceAccum.io.clear := startDot + overallAccum.io.clear := rstAccum + overallAccum.io.valid := last // last element has been processed + overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits + + // write + io.mem.wr.valid := overallAccum.io.ready + io.mem.wr.bits := overallAccum.io.sum + + + // count read/write + when (state === sIdle) { + cnt := 0.U + } .elsewhen (state === sWriteData) { + cnt := cnt + 1.U + } + + io.finish := overallAccum.io.ready // data has been added +} + + +class Accumulator(dataBits: Int = 8) extends Module { + val io = IO(new Bundle { + val clear = Input(Bool()) + val valid = Input(Bool()) + val ready = Output(Bool()) + val in = Input(UInt(dataBits.W)) + val sum = Output(UInt((dataBits).W)) + }) + + val reg = RegInit(0.U((dataBits).W)) + val ready = RegNext(io.valid) + when (io.clear) { + reg := 0.U + } .elsewhen (io.valid) { + reg := reg + io.in + } + io.ready := ready + io.sum := reg +} + diff --git a/vta/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala new file mode 100644 index 000000000000..6f0bdbb6b34c --- /dev/null +++ b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package accel + +import chisel3._ +import chisel3.util._ +import vta.dpi._ + +/** Register File. + * + * Six 32-bit register file. + * + * ------------------------------- + * Register description | addr + * -------------------------|----- + * Control status register | 0x00 + * Cycle counter | 0x04 + * Shift value | 0x08 + * Vector length | 0x0c + * Reset Accumulator | 0x10 + * Reset Dot Module | 0x14 + * Input1 pointer lsb | 0x18 + * Input1 pointer msb | 0x1c + * Input2 pointer lsb | 0x20 + * Input2 pointer msb | 0x24 + * Output pointer lsb | 0x28 + * Output pointer msb | 0x2c + * ------------------------------- + + * ------------------------------ + * Control status register | bit + * ------------------------------ + * Launch | 0 + * Finish | 1 + * ------------------------------ + */ +class RegFile(implicit config: AccelConfig) extends Module { + val io = IO(new Bundle { + val launch = Output(Bool()) + val finish = Input(Bool()) + val ecnt = Vec(config.nECnt, Flipped(ValidIO(UInt(config.regBits.W)))) + val vals = Output(Vec(config.nVals, UInt(config.regBits.W))) + val ptrs = Output(Vec(config.nPtrs, UInt(config.ptrBits.W))) + val host = new VTAHostDPIClient + }) + val sIdle :: sRead :: Nil = Enum(2) + val state = RegInit(sIdle) + + switch (state) { + is (sIdle) { + when (io.host.req.valid && !io.host.req.opcode) { + state := sRead + } + } + is (sRead) { + state := sIdle + } + } + + io.host.req.deq := state === sIdle & io.host.req.valid + + val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs) + val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))) + val addr = Seq.tabulate(nTotal)(_ * 4) + val reg_map = (addr zip reg) map { case (a, r) => a.U -> r } + val eo = config.nCtrl + val vo = eo + config.nECnt + val po = vo + config.nVals + + when (io.finish) { + reg(0) := "b_10".U + } .elsewhen (state === sIdle && io.host.req.valid && + io.host.req.opcode && addr(0).U === io.host.req.addr) { + reg(0) := io.host.req.value + } + + for (i <- 0 until config.nECnt) { + when (io.ecnt(i).valid) { + reg(eo + i) := io.ecnt(i).bits + } .elsewhen (state === sIdle && io.host.req.valid && + io.host.req.opcode && addr(eo + i).U === io.host.req.addr) { + reg(eo + i) := io.host.req.value + } + } + + for (i <- 0 until (config.nVals + (2*config.nPtrs))) { + when (state === sIdle && io.host.req.valid && + io.host.req.opcode && addr(vo + i).U === io.host.req.addr) { + reg(vo + i) := io.host.req.value + } + } + + val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))) + when (state === sIdle && io.host.req.valid && !io.host.req.opcode) { + rdata := MuxLookup(io.host.req.addr, 0.U, reg_map) + } + + io.host.resp.valid := state === sRead + io.host.resp.bits := rdata + + io.launch := reg(0)(0) + + for (i <- 0 until config.nVals) { + io.vals(i) := reg(vo + i) + } + + for (i <- 0 until config.nPtrs) { + io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i)) + } +} diff --git a/vta/apps/gemm/hardware/chisel/src/test/scala/dut/TestAccel.scala b/vta/apps/gemm/hardware/chisel/src/test/scala/dut/TestAccel.scala new file mode 100644 index 000000000000..d931620ec67d --- /dev/null +++ b/vta/apps/gemm/hardware/chisel/src/test/scala/dut/TestAccel.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package test + +import chisel3._ +import chisel3.experimental.MultiIOModule +import vta.dpi._ +import accel._ + +/** VTA simulation shell. + * + * Instantiate Host and Memory DPI modules. + * + */ +class VTASimShell extends MultiIOModule { + val host = IO(new VTAHostDPIMaster) + val mem = IO(new VTAMemDPIClient) + val sim_clock = IO(Input(Clock())) + val sim_wait = IO(Output(Bool())) + val mod_sim = Module(new VTASimDPI) + val mod_host = Module(new VTAHostDPI) + val mod_mem = Module(new VTAMemDPI) + mod_mem.io.clock := clock + mod_mem.io.reset := reset + mod_mem.io.dpi <> mem + mod_host.io.clock := clock + mod_host.io.reset := reset + host <> mod_host.io.dpi + mod_sim.io.clock := sim_clock + mod_sim.io.reset := reset + sim_wait := mod_sim.io.dpi_wait +} + +/** Test accelerator. + * + * Instantiate and connect the simulation-shell and the accelerator. + * + */ +class TestAccel extends MultiIOModule { + val sim_clock = IO(Input(Clock())) + val sim_wait = IO(Output(Bool())) + val sim_shell = Module(new VTASimShell) + val vta_accel = Module(new Accel) + sim_shell.sim_clock := sim_clock + sim_wait := sim_shell.sim_wait + sim_shell.mem <> vta_accel.io.mem + vta_accel.io.host <> sim_shell.host +} + +/** Generate TestAccel as top module */ +object Elaborate extends App { + chisel3.Driver.execute(args, () => new TestAccel) +} diff --git a/vta/apps/gemm/python/__init__.py b/vta/apps/gemm/python/__init__.py new file mode 100644 index 000000000000..784036f7d0ae --- /dev/null +++ b/vta/apps/gemm/python/__init__.py @@ -0,0 +1 @@ +from . import tsim diff --git a/vta/apps/gemm/python/tsim.py b/vta/apps/gemm/python/tsim.py new file mode 100644 index 000000000000..f5e56489aafd --- /dev/null +++ b/vta/apps/gemm/python/tsim.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import ctypes +import os.path as osp +from sys import platform + +def get_ext(): + """Return shared library extension""" + return ".dylib" if platform == "darwin" else ".so" + +def load_dll(dll): + """Load shared library + + Parameters + ------------ + dll : str + Path for shared library + + Returns + ------------ + The shared library + """ + try: + return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)] + except OSError: + return [] + +def load_sw(): + """Load all software shared libraries""" + cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__))) + sw_libname = "libsw" + get_ext() + sw_lib = osp.join(cur_path, "..", "build", sw_libname) + load_dll(sw_lib) + +def init(hw_backend): + """Init hardware and software shared library for accelerator + + Parameters + ------------ + hw_backend : str + Hardware backend can be verilog or chisel + + """ + cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__))) + hw_libname = "libhw" + get_ext() + if hw_backend in ("verilog", "chisel"): + hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname) + load_sw() + m = tvm.module.load(hw_lib, "vta-tsim") + f = tvm.get_global_func("tvm.vta.tsim.init") + f(m) + +def load_module(): + """Return driver function""" + load_sw() + return tvm.get_global_func("tvm.vta.driver") diff --git a/vta/apps/gemm/src/driver.cc b/vta/apps/gemm/src/driver.cc new file mode 100644 index 000000000000..8d380c323c9a --- /dev/null +++ b/vta/apps/gemm/src/driver.cc @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include "vmem/virtual_memory.h" + +namespace vta { +namespace driver { + +using vta::dpi::DPIModuleNode; +using tvm::runtime::Module; + +class DPILoader { + public: + ~DPILoader() { + dpi_->SimResume(); + dpi_->SimFinish(); + } + + void Init(Module module) { + mod_ = module; + dpi_ = this->Get(); + dpi_->SimLaunch(); + dpi_->SimWait(); + } + + DPIModuleNode* Get() { + return static_cast(mod_.operator->()); + } + + static DPILoader* Global() { + static DPILoader inst; + return &inst; + } + + // TVM module + Module mod_; + // DPI Module + DPIModuleNode* dpi_{nullptr}; +}; + +class Device { + public: + Device() { + loader_ = DPILoader::Global(); + } + + uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) { + uint32_t cycles; + uint32_t length = inp1->shape[0]; + size_t size1 = (inp1->dtype.bits >> 3) * length; + size_t size2 = (inp2->dtype.bits >> 3) * length; + size_t size3 = (64 >> 3); + inp1_ = this->MemAlloc(size1); + inp2_ = this->MemAlloc(size2); + out_ = this->MemAlloc(size3); + this->MemCopyFromHost(inp1_, inp1->data, size1); + this->MemCopyFromHost(inp2_, inp2->data, size2); + this->Init(); + this->Launch(length, shiftVal, reset); + cycles = this->WaitForCompletion(); + this->MemCopyToHost(out->data, out_, size3); + this->MemFree(inp1_); + this->MemFree(inp2_); + this->MemFree(out_); + return cycles; + } + + private: + void Init() { + dpi_ = loader_->Get(); + dpi_->SimResume(); + } + + void* MemAlloc(size_t size) { + void * addr = vta::vmem::VirtualMemoryManager::Global()->Alloc(size); + return reinterpret_cast(vta::vmem::VirtualMemoryManager::Global()->GetPhyAddr(addr)); + } + + void MemFree(void* buf) { + void * addr = vta::vmem::VirtualMemoryManager::Global()->GetAddr(reinterpret_cast(buf)); + vta::vmem::VirtualMemoryManager::Global()->Free(addr); + } + + vta_phy_addr_t MemGetPhyAddr(void* buf) { + return reinterpret_cast(reinterpret_cast(buf)); + } + + void MemCopyFromHost(void* dst, const void* src, size_t size) { + vta::vmem::VirtualMemoryManager::Global()->MemCopyFromHost(dst, src, size); + } + + void MemCopyToHost(void* dst, const void* src, size_t size) { + vta::vmem::VirtualMemoryManager::Global()->MemCopyToHost(dst, src, size); + } + + void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) { + dpi_->WriteReg(0x08, shiftVal); + dpi_->WriteReg(0x0c, length); // vector length + dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_)); + dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_)); + dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_)); + dpi_->WriteReg(0x00, 0x1); // launch + dpi_->WriteReg(0x00, 0x0); // launch + + if (reset == 1) { + dpi_->WriteReg(0x10, 0x1); // reset accum + dpi_->WriteReg(0x10, 0x0); // stop reset accum + } + dpi_->WriteReg(0x14, 0x1); // reset dot + dpi_->WriteReg(0x14, 0x0); // stop reset dot + } + + uint32_t WaitForCompletion() { + uint32_t i, val; + for (i = 0; i < wait_cycles_; i++) { + val = dpi_->ReadReg(0x00); + if (val == 2) break; // finish + } + val = dpi_->ReadReg(0x04); + dpi_->SimWait(); + return val; + } + + // wait cycles + uint32_t wait_cycles_{100000000}; + // DPI loader + DPILoader* loader_{nullptr}; + // DPI Module + DPIModuleNode* dpi_{nullptr}; + // input vm ptr + void* inp1_{nullptr}; + void* inp2_{nullptr}; + // output vm ptr + void* out_{nullptr}; +}; + +using tvm::runtime::TVMRetValue; +using tvm::runtime::TVMArgs; + +TVM_REGISTER_GLOBAL("tvm.vta.tsim.init") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Module m = args[0]; + DPILoader::Global()->Init(m); + }); + +TVM_REGISTER_GLOBAL("tvm.vta.driver") +.set_body([](TVMArgs args, TVMRetValue* rv) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[3]; + Device dev_; + uint32_t cycles = dev_.Run(A, B, static_cast(args[2]), C, static_cast(args[4])); + *rv = static_cast(cycles); + }); + +} // namespace driver +} // namespace vta diff --git a/vta/apps/gemm/tests/python/chisel_accel.py b/vta/apps/gemm/tests/python/chisel_accel.py new file mode 100644 index 000000000000..4aed5636b50e --- /dev/null +++ b/vta/apps/gemm/tests/python/chisel_accel.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import numpy as np +import tsim +import sys + +""" Vector Bit Slice and Pack Function +Parameters +---------- +A : Vector to be sliced and packed +slice_width : slice width + +Returnsi +--------- +C: 2d matrix where each cloumn (because of bit packing) represents each bit slice of A +""" +def slice(A, slice_width): + assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported" + dtype = type(A[0]) + row = 0 + # currently only supports uint + if dtype is np.uint8: row = 8 // slice_width + elif dtype is np.uint16: row = 16 // slice_width + elif dtype is np.uint32: row = 32 // slice_width + elif dtype is np.uint64: row = 64 // slice_width + else: raise ValueError("datatype " + str(dtype) + "currently not supported") + if (row >= 8): + dtype = 'uint' + str(row) + else: + dtype = 'uint8' + + C = np.zeros((row, len(A))).astype(dtype) # sliced and transform + + # create mask + slice_mask = 2**(slice_width)-1 + # slice and pack + for x in range(len(A)): + for y in range(row): + C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask) + return C + +""" Matrix Multiplication Function +Parameters +---------- +A : Matrix A +B: Matrix B +w_width : weight slice width +a_width : activation slice width + +Returns +--------- +C: result of A * B +""" +# A is a n*m matrix, B is a m*p matrix(not transposed yet) +def matrix_multiply(A, B, w_width, a_width): + assert A.shape[1] == B.shape[0], "can't perform multiplication" + BT = B.transpose() + cycles = 0 + C = np.zeros((A.shape[0], B.shape[1])).astype('uint64') + for i in range(A.shape[0]): + for j in range(B.shape[1]): + # C[i, j] = A[i].dot(BT[j]) + A_sliced = slice(A[i], w_width) + B_sliced = slice(BT[j], a_width) + + C[i, j] = compute(A_sliced, B_sliced, w_width, a_width) + test = test_accel(A_sliced, B_sliced, w_width, a_width) + cycles += test[1] + np.testing.assert_equal(C[i,j], A[i].astype('uint64').dot(BT[j])) + print("PASS SW serial & parallel") + + np.testing.assert_equal(test[0], C[i, j]) + print("PASS SW & HW bit serial") + + np.testing.assert_equal(test[0], A[i].astype('uint64').dot(BT[j])) + print("PASS SW bit parallel & HW bit parallel") + + print("result: ") + print(C) + print("ALL TESTS PASSED, cycles: " + str(cycles)) + return C + +""" Software Verification Function""" +# takes 2 matrix input (sliced and packed) +def compute(A, B, w_width, a_width): + assert A.shape[1] == B.shape[1], "sliced shape not match" + # reset hardware accumulator + accum = 0 + for x in range(A.shape[0]): + for y in range(B.shape[0]): + # hardware implementation + accum += np.uint64(A[x]).dot(np.uint64(B[y])) << np.uint64(x*w_width + y*a_width) + # get value from accumulator + return accum + +"""Testing Function for Dot Product""" +def test_accel(A, B, w_width, a_width): + assert A.shape[1] == B.shape[1], "sliced shape not match" + + dtype = A.dtype + ctx = tvm.cpu(0) + f = tsim.load_module() + + a_arr = [] + b_arr = [] + for i in range(A.shape[0]): + list_a = np.zeros(A.shape[1]).astype(dtype) + for j in range(A.shape[1]): + list_a[j] = A[i][j] + a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx)) + + for i in range(B.shape[0]): + list_b = np.zeros(B.shape[1]).astype(dtype) + for j in range(B.shape[1]): + list_b[j] = B[i][j] + b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx)) + + cycles = 0 + + accum = tvm.nd.array(np.array([0]).astype("uint64"), ctx) + for i in range(len(a_arr)): + for j in range(len(b_arr)): + shift = np.uint8(i*w_width + j*a_width) + if i == 0 and j == 0: + cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(1)) # reset accumulator + else: + cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(0)) # no reset + + return (accum.asnumpy()[0], cycles) + +""" Matrix Generator +Parameters +---------- +dtype : String, datatype generated (supports only uint) +w_width : weight bit slices(needs to be less than actual bit width) +a_width : activation bit slices(needs to be less than actual bit width) +""" +def top_test(dtype, w_width, a_width): + + rmax = np.random.randint(256) + # random matrix generation (dimension up to 8) + rrow = np.random.randint(7) + 1 + rclmn = np.random.randint(7) + 1 + rrow2 = np.random.randint(7) + 1 + A = np.random.randint(rmax, size=(rrow,rclmn)).astype(dtype) + B = np.random.randint(rmax, size=(rclmn,rrow2)).astype(dtype) + + print("A: ") + print(A) + print("\n") + print("B: ") + print(B) + print("\n") + matrix_multiply(A, B, w_width, a_width) + + +if __name__ == "__main__": + tsim.init("chisel") + for i in range(1): + # reg1 and reg2 bits in Compute.scala must be modified for slices greater than 8 bits + if sys.argv[1] == 'serial': + # generates a random uint8 GEMM with 2-bit(8/4) weight and 4-bit(8/2) activation + top_test("uint8",4, 2) + elif sys.argv[1] == 'parallel': + # generates a random uint8 GEMM with 8-bit weight and 8-bit activation (bit parallel) + top_test('uint8', 1, 1)