diff --git a/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala index 325fce1bf38a..6bfe3e054121 100644 --- a/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala +++ b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala @@ -22,21 +22,31 @@ package accel import chisel3._ import chisel3.util._ import vta.dpi._ +import vta.core._ +import vta.util.config._ +import vta.shell._ +class TestConfig extends Config(new CoreConfig ++ new PynqConfig) /** Compute * * Bit Slice GEMM: * * 1. Wait for launch to be asserted - * 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address + * 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix) * 3. Wait for the value * 4. Increment read-address for next value - * 5. Wait for sliced accumulator - * 6. Check if counter (cnt) is equal to length process, - otherwise goto step 2 - * 7. Check if reset slice accumulator - * 8. Wait for overall accumulator - * 8. Issue a write request for 8-byte value at out_baddr address + * 5. Repeat until all inp1 data have been read + + * 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector) + * 7. Wait for the value + * 8. Increment read-address for next value + * 9. Repeat until all inp2 data have been read + + * 10. Wait for output to be calculated + * 11. Issue a write request for 8-byte value at out_baddr address + * 12. Increment write-address for next value to write + * 13. Check if counter (cntout) is equal to length to asser finish, + otherwise go to step 11 */ class Compute(implicit config: AccelConfig) extends Module { val io = IO(new Bundle { @@ -47,19 +57,24 @@ class Compute(implicit config: AccelConfig) extends Module { val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W))) val mem = new VTAMemDPIMaster }) - val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7) + implicit val p: Parameters = new TestConfig + val sIdle :: sReadAReq :: sReadAData :: sReadADone ::sReadBReq :: sReadBData :: sReadBDone :: sInpDone ::sWait:: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum(12) val state = RegInit(sIdle) val shift = io.vals(0) val length = io.vals(1) val rstAccum = io.vals(2) val startDot = io.vals(3) val cycles = RegInit(0.U(config.regBits.W)) - val reg1 = Reg(chiselTypeOf(io.mem.rd.bits)) - val reg2 = Reg(chiselTypeOf(io.mem.rd.bits)) - val cnt = Reg(UInt(config.regBits.W)) + val mvc = Module(new MatrixVectorMultiplication) + val reg1 = Reg(chiselTypeOf(mvc.io.wgt.data.bits)) + val reg2 = Reg(chiselTypeOf(mvc.io.inp.data.bits)) + val cntwgt = Reg(UInt(config.regBits.W)) + val cntinp = Reg(UInt(config.regBits.W)) + val cntout = Reg(UInt(config.regBits.W)) val raddr1 = Reg(UInt(config.ptrBits.W)) val raddr2 = Reg(UInt(config.ptrBits.W)) val waddr = Reg(UInt(config.ptrBits.W)) + val accum = Module(new Accmulator(size = p(CoreKey).blockOut, accBits = p(CoreKey).accBits)) switch (state) { is (sIdle) { @@ -73,7 +88,14 @@ class Compute(implicit config: AccelConfig) extends Module { } is (sReadAData) { when (io.mem.rd.valid) { + state := sReadADone + } + } + is (sReadADone) { + when (cntwgt === (length * length) - 1.U) { state := sReadBReq + } .otherwise { + state := sReadAReq } } is (sReadBReq) { @@ -81,6 +103,23 @@ class Compute(implicit config: AccelConfig) extends Module { } is (sReadBData) { when (io.mem.rd.valid) { + state := sReadBDone + } + } + is (sReadBDone) { + when (cntinp === length-1.U) { + state := sInpDone + } .otherwise { + state := sReadBReq + } + } + // Both input is processed + is (sInpDone) { + state := sWait + } + // Wait for computation + is (sWait) { + when (accum.io.ready) { state := sWriteReq } } @@ -89,15 +128,18 @@ class Compute(implicit config: AccelConfig) extends Module { state := sWriteData } is (sWriteData) { - when (cnt === (length - 1.U)) { + state := sWriteDone + } + is (sWriteDone) { + when (cntout === (length - 1.U)) { state := sIdle } .otherwise { - state := sReadAReq + state := sWriteReq } } } - val last = state === sWriteData && cnt === (length - 1.U) + val last = state === sWriteDone && cntout === (length - 1.U) // cycle counter when (state === sIdle) { @@ -114,10 +156,12 @@ class Compute(implicit config: AccelConfig) extends Module { raddr1 := io.ptrs(0) raddr2 := io.ptrs(1) waddr := io.ptrs(2) - } .elsewhen (state === sWriteData) { // increment input array by 1-byte + } .elsewhen (state === sReadADone) { // increment input array by 1-byte raddr1 := raddr1 + 1.U + } .elsewhen (state === sReadBDone) { // increment input array by 1-byte raddr2 := raddr2 + 1.U - waddr := waddr + } .elsewhen (state === sWriteDone) { + waddr := waddr + 4.U // writing 4 bytes } // create request @@ -128,59 +172,70 @@ class Compute(implicit config: AccelConfig) extends Module { // read when (state === sReadAData && io.mem.rd.valid) { - reg1 := io.mem.rd.bits(7, 0) + reg1(cntwgt/length)(cntwgt%length) := io.mem.rd.bits(7, 0) } when (state === sReadBData && io.mem.rd.valid) { - reg2 := io.mem.rd.bits(7, 0) + reg2(0)(cntinp) := io.mem.rd.bits(7, 0) } io.mem.rd.ready := state === sReadAData | state === sReadBData + mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed + mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed + + mvc.io.wgt.data.bits <> reg1 + mvc.io.inp.data.bits <> reg2 + // Modify when shift operation is supported + mvc.io.reset := false.B + mvc.io.acc_i.data.valid := true.B + for (i <- 0 until p(CoreKey).blockOut) { + mvc.io.acc_i.data.bits(0)(i) := 0.U + } - - val sliceAccum = Module(new Accumulator(63)) - val overallAccum = Module(new Accumulator(64)) - - sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed - sliceAccum.io.in := reg1 * reg2 - sliceAccum.io.clear := startDot - overallAccum.io.clear := rstAccum - overallAccum.io.valid := last // last element has been processed - overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits + accum.io.in := mvc.io.acc_o.data.bits + accum.io.shift := shift + accum.io.clear := rstAccum + accum.io.valid := mvc.io.acc_o.data.valid // write - io.mem.wr.valid := overallAccum.io.ready - io.mem.wr.bits := overallAccum.io.sum - + io.mem.wr.valid := state === sWriteData + io.mem.wr.bits := accum.io.sum(cntout) // count read/write when (state === sIdle) { - cnt := 0.U - } .elsewhen (state === sWriteData) { - cnt := cnt + 1.U + cntwgt := 0.U + cntinp := 0.U + cntout := 0.U + } .elsewhen (state === sReadADone) { + cntwgt := cntwgt + 1.U + } .elsewhen (state === sReadBDone) { + cntinp := cntinp + 1.U + } .elsewhen (state === sWriteDone) { + cntout := cntout + 1.U } - io.finish := overallAccum.io.ready // data has been added + io.finish := last // data has been added } - - -class Accumulator(dataBits: Int = 8) extends Module { +// Shift operation until supported in MVM +class Accmulator(size: Int = 16, accBits: Int = 32) extends Module { val io = IO(new Bundle { val clear = Input(Bool()) val valid = Input(Bool()) val ready = Output(Bool()) - val in = Input(UInt(dataBits.W)) - val sum = Output(UInt((dataBits).W)) + val in = Input(Vec(1, Vec(size, (UInt(accBits.W))))) + val shift = Input(UInt(8.W)) + val sum = Output(Vec(size, (UInt(accBits.W)))) }) + val reg = RegInit(VecInit(Seq.fill(size)(0.U(accBits.W)))) - val reg = RegInit(0.U((dataBits).W)) - val ready = RegNext(io.valid) - when (io.clear) { - reg := 0.U - } .elsewhen (io.valid) { - reg := reg + io.in - } - io.ready := ready - io.sum := reg + for (i <- 0 until size) { + when (io.clear) { + reg(i) := 0.U + } .elsewhen(io.valid) { + reg(i) := reg(i) + (io.in(0)(i) << io.shift) + } + } + io.ready := RegNext(io.valid) + io.sum := reg } diff --git a/vta/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala index 6f0bdbb6b34c..10c40b5c2e72 100644 --- a/vta/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala +++ b/vta/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala @@ -35,13 +35,9 @@ import vta.dpi._ * Shift value | 0x08 * Vector length | 0x0c * Reset Accumulator | 0x10 - * Reset Dot Module | 0x14 - * Input1 pointer lsb | 0x18 - * Input1 pointer msb | 0x1c - * Input2 pointer lsb | 0x20 - * Input2 pointer msb | 0x24 - * Output pointer lsb | 0x28 - * Output pointer msb | 0x2c + * Input1 pointer | 0x18 + * Input2 pointer | 0x20 + * Output pointer | 0x28 * ------------------------------- * ------------------------------ diff --git a/vta/apps/gemm/src/driver.cc b/vta/apps/gemm/src/driver.cc index 8d380c323c9a..24b998edd211 100644 --- a/vta/apps/gemm/src/driver.cc +++ b/vta/apps/gemm/src/driver.cc @@ -66,10 +66,12 @@ class Device { uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) { uint32_t cycles; - uint32_t length = inp1->shape[0]; - size_t size1 = (inp1->dtype.bits >> 3) * length; + uint32_t length = inp2->shape[0]; + // 1 matrix 1 vector input + size_t size1 = (inp1->dtype.bits >> 3) * length * length; size_t size2 = (inp2->dtype.bits >> 3) * length; - size_t size3 = (64 >> 3); + // 1 vector output + size_t size3 = (32 >> 3) * length; inp1_ = this->MemAlloc(size1); inp2_ = this->MemAlloc(size2); out_ = this->MemAlloc(size3); @@ -115,19 +117,17 @@ class Device { void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) { dpi_->WriteReg(0x08, shiftVal); - dpi_->WriteReg(0x0c, length); // vector length + dpi_->WriteReg(0x0c, length); // tensor size dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_)); dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_)); dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_)); dpi_->WriteReg(0x00, 0x1); // launch - dpi_->WriteReg(0x00, 0x0); // launch + dpi_->WriteReg(0x00, 0x0); if (reset == 1) { - dpi_->WriteReg(0x10, 0x1); // reset accum - dpi_->WriteReg(0x10, 0x0); // stop reset accum + dpi_->WriteReg(0x10, 0x1); // reset accumulator + dpi_->WriteReg(0x10, 0x0); } - dpi_->WriteReg(0x14, 0x1); // reset dot - dpi_->WriteReg(0x14, 0x0); // stop reset dot } uint32_t WaitForCompletion() { diff --git a/vta/apps/gemm/tests/python/chisel_accel.py b/vta/apps/gemm/tests/python/chisel_accel.py index 4aed5636b50e..4666661f9bc9 100644 --- a/vta/apps/gemm/tests/python/chisel_accel.py +++ b/vta/apps/gemm/tests/python/chisel_accel.py @@ -26,7 +26,7 @@ A : Vector to be sliced and packed slice_width : slice width -Returnsi +Returns --------- C: 2d matrix where each cloumn (because of bit packing) represents each bit slice of A """ @@ -39,7 +39,7 @@ def slice(A, slice_width): elif dtype is np.uint16: row = 16 // slice_width elif dtype is np.uint32: row = 32 // slice_width elif dtype is np.uint64: row = 64 // slice_width - else: raise ValueError("datatype " + str(dtype) + "currently not supported") + else: raise ValueError("datatype currently not supported") if (row >= 8): dtype = 'uint' + str(row) else: @@ -55,64 +55,88 @@ def slice(A, slice_width): C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask) return C +def slice_mat(A, slice_width): + assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported" + dtype = type(A[0][0]) + row = 0 + # currently only supports uint + if dtype is np.uint8: row = 8 // slice_width + elif dtype is np.uint16: row = 16 // slice_width + elif dtype is np.uint32: row = 32 // slice_width + elif dtype is np.uint64: row = 64 // slice_width + else: raise ValueError("datatype currently not supported") + if (row >= 8): + dtype = 'uint' + str(row) + else: + dtype = 'uint8' + + # 3d array (bits, row, clmn) + C = np.zeros((row, A.shape[0], A.shape[1])).astype(dtype) # sliced and transform + + # create mask + slice_mask = 2**(slice_width)-1 + # slice and pack + for z in range(A.shape[0]): + C[:, z, :] = slice(A[z], slice_width) + return C + """ Matrix Multiplication Function Parameters ---------- A : Matrix A B: Matrix B -w_width : weight slice width -a_width : activation slice width +i_width : weight slice width +w_width : activation slice width Returns --------- C: result of A * B """ # A is a n*m matrix, B is a m*p matrix(not transposed yet) -def matrix_multiply(A, B, w_width, a_width): +def matrix_multiply(A, B, i_width, w_width): assert A.shape[1] == B.shape[0], "can't perform multiplication" BT = B.transpose() cycles = 0 + B_sliced = slice_mat(BT, w_width) C = np.zeros((A.shape[0], B.shape[1])).astype('uint64') for i in range(A.shape[0]): - for j in range(B.shape[1]): - # C[i, j] = A[i].dot(BT[j]) - A_sliced = slice(A[i], w_width) - B_sliced = slice(BT[j], a_width) - - C[i, j] = compute(A_sliced, B_sliced, w_width, a_width) - test = test_accel(A_sliced, B_sliced, w_width, a_width) - cycles += test[1] - np.testing.assert_equal(C[i,j], A[i].astype('uint64').dot(BT[j])) - print("PASS SW serial & parallel") - - np.testing.assert_equal(test[0], C[i, j]) - print("PASS SW & HW bit serial") - - np.testing.assert_equal(test[0], A[i].astype('uint64').dot(BT[j])) - print("PASS SW bit parallel & HW bit parallel") - + A_sliced = slice(A[i], i_width) + test = test_accel(A_sliced, B_sliced, i_width, w_width) + C[i] = test[0] + cycles += test[1] + np.testing.assert_array_equal(C[i], compute(A_sliced, B_sliced, i_width, w_width)) + print("PASS row " + str(i)) + + np.testing.assert_array_equal(C, np.matmul(A.astype('uint64'),B)) print("result: ") print(C) - print("ALL TESTS PASSED, cycles: " + str(cycles)) + print("TEST PASSED, cycles: " + str(cycles)) return C -""" Software Verification Function""" -# takes 2 matrix input (sliced and packed) -def compute(A, B, w_width, a_width): +""" Software Verification Function +Parameter Dimesions +--------- +A (bits, y) and B (bits, y, x) (transposed) + +Takes 1 vector and 1 matrix input (sliced and packed) + +Returns +--------- +Resulting vector +""" +def compute(A, B, i_width, w_width): assert A.shape[1] == B.shape[1], "sliced shape not match" # reset hardware accumulator - accum = 0 + accum = np.zeros(A.shape[1]) for x in range(A.shape[0]): for y in range(B.shape[0]): - # hardware implementation - accum += np.uint64(A[x]).dot(np.uint64(B[y])) << np.uint64(x*w_width + y*a_width) + accum += np.matmul(A[x].astype('uint64'), B[y].transpose()) << np.uint64(x*i_width + y*w_width) # get value from accumulator return accum -"""Testing Function for Dot Product""" -def test_accel(A, B, w_width, a_width): - assert A.shape[1] == B.shape[1], "sliced shape not match" - +"""Testing Function for Matrix Vector Multiplication""" +def test_accel(A, B, i_width, w_width): + assert A.shape[1] == B.shape[2], "sliced shape not match" dtype = A.dtype ctx = tvm.cpu(0) f = tsim.load_module() @@ -126,57 +150,54 @@ def test_accel(A, B, w_width, a_width): a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx)) for i in range(B.shape[0]): - list_b = np.zeros(B.shape[1]).astype(dtype) - for j in range(B.shape[1]): - list_b[j] = B[i][j] + # transpose + list_b = np.zeros((B.shape[2], B.shape[1])).astype(dtype) + for j in range(B.shape[2]): + for k in range(B.shape[1]): + list_b[j][k] = B[i][j][k] b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx)) cycles = 0 - - accum = tvm.nd.array(np.array([0]).astype("uint64"), ctx) + accum = tvm.nd.array(np.zeros(A.shape[1]).astype("uint32"), ctx) for i in range(len(a_arr)): for j in range(len(b_arr)): - shift = np.uint8(i*w_width + j*a_width) + shift = np.uint8(i*i_width + j*w_width) if i == 0 and j == 0: - cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(1)) # reset accumulator + cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(1)) # reset accumulator else: - cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(0)) # no reset + cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(0)) # no reset - return (accum.asnumpy()[0], cycles) + return (accum.asnumpy(), cycles) """ Matrix Generator Parameters ---------- dtype : String, datatype generated (supports only uint) -w_width : weight bit slices(needs to be less than actual bit width) -a_width : activation bit slices(needs to be less than actual bit width) +i_width : weight bit slices(needs to be less than actual bit width) +w_width : activation bit slices(needs to be less than actual bit width) """ -def top_test(dtype, w_width, a_width): - - rmax = np.random.randint(256) - # random matrix generation (dimension up to 8) - rrow = np.random.randint(7) + 1 - rclmn = np.random.randint(7) + 1 - rrow2 = np.random.randint(7) + 1 - A = np.random.randint(rmax, size=(rrow,rclmn)).astype(dtype) - B = np.random.randint(rmax, size=(rclmn,rrow2)).astype(dtype) +def top_test(dtype, i_width, w_width): - print("A: ") - print(A) - print("\n") - print("B: ") - print(B) - print("\n") - matrix_multiply(A, B, w_width, a_width) + # only supports positive values (up to 2**(bits-1)) + rmax = 127 + # (m,16) * (16,16) GEMM + rrow = np.random.randint(7) + 1 + clmn = 16 + A = np.random.randint(rmax, size=(rrow,clmn)).astype(dtype) + B = np.random.randint(rmax, size=(clmn,clmn)).astype(dtype) + print("A: " + str(A)) + print("B: " + str(B)) + # perform GEMM + matrix_multiply(A, B, i_width, w_width) if __name__ == "__main__": tsim.init("chisel") for i in range(1): - # reg1 and reg2 bits in Compute.scala must be modified for slices greater than 8 bits + # reg1 and reg2 bits in hardware/chisel/src/main/Compute.scala must be modified for slices greater than 8 bits if sys.argv[1] == 'serial': - # generates a random uint8 GEMM with 2-bit(8/4) weight and 4-bit(8/2) activation - top_test("uint8",4, 2) + # generates a random uint8 GEMM with 2-bit(8/4) input and 4-bit(8/2) weight + top_test("uint8", 4, 2) elif sys.argv[1] == 'parallel': - # generates a random uint8 GEMM with 8-bit weight and 8-bit activation (bit parallel) - top_test('uint8', 1, 1) + # generates a random uint8 GEMM with 8-bit input and 8-bit weight (bit parallel) + top_test('uint8', 8, 8)