Skip to content

Commit

Permalink
[VTA][Chisel] scale dram base address in hardware instead of runtime (a…
Browse files Browse the repository at this point in the history
…pache#3772)

* [VTA][Chisel] scale dram base address in hardware instead of runtime

* remove trailing spaces
  • Loading branch information
vegaluisjose authored and wweic committed Aug 16, 2019
1 parent adbf584 commit 5544a21
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 22 deletions.
5 changes: 3 additions & 2 deletions vta/hardware/chisel/src/main/scala/core/LoadUop.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,12 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
}

// read-from-dram
val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
when (state === sIdle) {
when (offsetIsEven) {
raddr := io.baddr + dec.dram_offset
raddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))
} .otherwise {
raddr := io.baddr + dec.dram_offset - uopBytes.U
raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))) - uopBytes.U
}
} .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) {
raddr := raddr + xmax_bytes
Expand Down
2 changes: 1 addition & 1 deletion vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
val strideFactor = tp.tensorLength * tp.tensorWidth

val dec = io.inst.asTypeOf(new MemDecode)
val dataCtrl = Module(new TensorDataCtrl(sizeFactor, strideFactor))
val dataCtrl = Module(new TensorDataCtrl(tensorType, sizeFactor, strideFactor))
val dataCtrlDone = RegInit(false.B)
val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor))
val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor))
Expand Down
6 changes: 4 additions & 2 deletions vta/hardware/chisel/src/main/scala/core/TensorStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,11 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)

// write-to-dram
val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
val elemBytes = (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).outBits) / 8
when (state === sIdle) {
waddr_cur := io.baddr + dec.dram_offset
waddr_nxt := io.baddr + dec.dram_offset
waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
} .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
waddr_cur := waddr_cur + xmax_bytes
} .elsewhen (stride) {
Expand Down
16 changes: 13 additions & 3 deletions vta/hardware/chisel/src/main/scala/core/TensorUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul
}

/** TensorDataCtrl. Data controller for TensorLoad. */
class TensorDataCtrl(sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
Expand Down Expand Up @@ -281,9 +281,19 @@ class TensorDataCtrl(sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Par
ycnt := ycnt + 1.U
}

val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
val elemBytes =
if (tensorType == "inp") {
(p(CoreKey).batch * p(CoreKey).blockIn * p(CoreKey).inpBits) / 8
} else if (tensorType == "wgt") {
(p(CoreKey).blockOut * p(CoreKey).blockIn * p(CoreKey).wgtBits) / 8
} else {
(p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).accBits) / 8
}

when (io.start) {
caddr := io.baddr + dec.dram_offset
baddr := io.baddr + dec.dram_offset
caddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
baddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
} .elsewhen (io.yupdate) {
when (split) {
caddr := caddr + xmax_bytes
Expand Down
4 changes: 2 additions & 2 deletions vta/src/device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down
12 changes: 0 additions & 12 deletions vta/src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,11 +431,7 @@ class UopQueue : public BaseQueue<VTAUop> {
insn->memory_type = VTA_MEM_ID_UOP;
insn->sram_base = sram_begin_;
// Update cache idx to physical address map
#ifdef USE_TSIM
insn->dram_base = fpga_buff_phy_ + offset;
#else
insn->dram_base = (fpga_buff_phy_ + offset) / kElemBytes;
#endif
insn->y_size = 1;
insn->x_size = (sram_end_ - sram_begin_);
insn->x_stride = (sram_end_ - sram_begin_);
Expand Down Expand Up @@ -1011,11 +1007,7 @@ class CommandQueue {
insn->memory_type = dst_memory_type;
insn->sram_base = dst_sram_index;
DataBuffer* src = DataBuffer::FromHandle(src_dram_addr);
#ifdef USE_TSIM
insn->dram_base = (uint32_t) src->phy_addr() + src_elem_offset*GetElemBytes(dst_memory_type);
#else
insn->dram_base = src->phy_addr() / GetElemBytes(dst_memory_type) + src_elem_offset;
#endif
insn->y_size = y_size;
insn->x_size = x_size;
insn->x_stride = x_stride;
Expand All @@ -1038,11 +1030,7 @@ class CommandQueue {
insn->memory_type = src_memory_type;
insn->sram_base = src_sram_index;
DataBuffer* dst = DataBuffer::FromHandle(dst_dram_addr);
#ifdef USE_TSIM
insn->dram_base = (uint32_t) dst->phy_addr() + dst_elem_offset*GetElemBytes(src_memory_type);
#else
insn->dram_base = dst->phy_addr() / GetElemBytes(src_memory_type) + dst_elem_offset;
#endif
insn->y_size = y_size;
insn->x_size = x_size;
insn->x_stride = x_stride;
Expand Down

0 comments on commit 5544a21

Please sign in to comment.