Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VTA] [Chisel] support for different inp/wgt bits, rewrote DotProduct for clarity #3605

Merged
merged 22 commits into from
Jul 26, 2019
Merged
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 38 additions & 35 deletions vta/hardware/chisel/src/main/scala/core/TensorGemm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ class MAC(dataBits: Int = 8, cBits: Int = 16, outBits: Int = 17) extends Module
val rB = RegNext(io.b)
val rC = RegNext(io.c)
mult := rA * rB
add := rC + mult
add := rC +& mult
io.y := add
}

/** Pipelined adder */
class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module {
class PipeAdder(dataBits: Int = 8, outBits: Int = 17) extends Module {
require (outBits >= dataBits)
val io = IO(new Bundle {
val a = Input(SInt(dataBits.W))
Expand All @@ -56,58 +56,61 @@ class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module {
val add = Wire(SInt(outBits.W))
val rA = RegNext(io.a)
val rB = RegNext(io.b)
add := rA + rB
add := rA +& rB
io.y := add
}

/** Pipelined DotProduct based on MAC and Adder */
class DotProduct(dataBits: Int = 8, size: Int = 16) extends Module {
/** Pipelined DotProduct based on MAC and PipeAdder */
class DotProduct(inpBits: Int = 8, wgtBits: Int = 8, size: Int = 16) extends Module {
BenjaminTu marked this conversation as resolved.
Show resolved Hide resolved
val errMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
require(size >= 4 && isPow2(size), errMsg)
val b = dataBits * 2
val b = inpBits + wgtBits
val dataBits = Math.max(inpBits, wgtBits)
val outBits = b + log2Ceil(size) + 1
val io = IO(new Bundle {
val a = Input(Vec(size, SInt(dataBits.W)))
val b = Input(Vec(size, SInt(dataBits.W)))
val a = Input(Vec(size, SInt(inpBits.W)))
val b = Input(Vec(size, SInt(wgtBits.W)))
val y = Output(SInt(outBits.W))
})
val p = log2Ceil(size/2)
val s = Seq.tabulate(log2Ceil(size))(i => pow(2, p - i).toInt)
val da = Seq.tabulate(s(0))(i => RegNext(io.a(s(0) + i)))
val db = Seq.tabulate(s(0))(i => RegNext(io.b(s(0) + i)))
val m = Seq.tabulate(2)(i =>
Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b + i, outBits = b + i + 1)))
)
val s = Seq.tabulate(log2Ceil(size+1))(i => pow(2, log2Ceil(size) - i).toInt) // # of total layers
val p = log2Ceil(size/2)+1 // # of adder layers
val m = Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b, outBits = b + 1))) // # of total vector pairs
val a = Seq.tabulate(p)(i =>
Seq.fill(s(i + 1))(Module(new Adder(dataBits = b + i + 2, outBits = b + i + 3)))
)
Seq.fill(s(i + 1))(Module(new PipeAdder(dataBits = b + i + 1, outBits = b + i + 2)))
) // # adders within each layer

for (i <- 0 until log2Ceil(size)) {
for (j <- 0 until s(i)) {
// Vector MACs
for (i <- 0 until s(0)) {
m(i).io.a := io.a(i)
m(i).io.b := io.b(i)
m(i).io.c := 0.S
}

// PipeAdder Reduction
for (i <- 0 until p) {
for (j <- 0 until s(i+1)) {
if (i == 0) {
m(i)(j).io.a := io.a(j)
m(i)(j).io.b := io.b(j)
m(i)(j).io.c := 0.S
m(i + 1)(j).io.a := da(j)
m(i + 1)(j).io.b := db(j)
m(i + 1)(j).io.c := m(i)(j).io.y
} else if (i == 1) {
a(i - 1)(j).io.a := m(i)(2*j).io.y
BenjaminTu marked this conversation as resolved.
Show resolved Hide resolved
a(i - 1)(j).io.b := m(i)(2*j + 1).io.y
// First layer of PipeAdders
a(i)(j).io.a := m(2*j).io.y
a(i)(j).io.b := m(2*j + 1).io.y
} else {
a(i - 1)(j).io.a := a(i - 2)(2*j).io.y
a(i - 1)(j).io.b := a(i - 2)(2*j + 1).io.y
a(i)(j).io.a := a(i - 1)(2*j).io.y
a(i)(j).io.b := a(i - 1)(2*j + 1).io.y
}
}
}

// last adder
io.y := a(p-1)(0).io.y
}

/** Perform matric-vector-multiplication based on DotProduct */
BenjaminTu marked this conversation as resolved.
Show resolved Hide resolved
class MatrixVectorCore(implicit p: Parameters) extends Module {
class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
val accBits = p(CoreKey).accBits
val size = p(CoreKey).blockOut
val dataBits = p(CoreKey).inpBits
val inpBits = p(CoreKey).inpBits
val wgtBits = p(CoreKey).wgtBits
val outBits = p(CoreKey).outBits
val io = IO(new Bundle{
val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
val inp = new TensorMasterData(tensorType = "inp")
Expand All @@ -116,7 +119,7 @@ class MatrixVectorCore(implicit p: Parameters) extends Module {
val acc_o = new TensorClientData(tensorType = "acc")
val out = new TensorClientData(tensorType = "out")
})
val dot = Seq.fill(size)(Module(new DotProduct(dataBits, size)))
val dot = Seq.fill(size)(Module(new DotProduct(inpBits, wgtBits, size)))
BenjaminTu marked this conversation as resolved.
Show resolved Hide resolved
val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
val add = Seq.fill(size)(Wire(SInt(accBits.W)))
val vld = Wire(Vec(size, Bool()))
Expand All @@ -139,7 +142,7 @@ class MatrixVectorCore(implicit p: Parameters) extends Module {

/** TensorGemm.
*
* This unit instantiate the MatrixVectorCore and go over the
* This unit instantiate the MatrixVectorMultiplication and go over the
* micro-ops (uops) which are used to read inputs, weights and biases,
* and writes results back to the acc and out scratchpads.
*
Expand All @@ -159,7 +162,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
})
val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
val state = RegInit(sIdle)
val mvc = Module(new MatrixVectorCore)
val mvc = Module(new MatrixVectorMultiplication)
val dec = io.inst.asTypeOf(new GemmDecode)
val uop_idx = Reg(chiselTypeOf(dec.uop_end))
val uop_end = dec.uop_end
Expand Down