Skip to content

Commit

Permalink
Add page_size and num_buffers attributes to CBDesc (#370)
Browse files Browse the repository at this point in the history
These attributes are needed by ttnn generic and direct to metal
  • Loading branch information
nsmithtt authored Aug 13, 2024
1 parent 6c37767 commit 6b58bde
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 7 deletions.
22 changes: 20 additions & 2 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,28 @@ class TTKernel_Type<string name, string typeMnemonic, list<Trait> traits = []>
def TTKernel_CB : TTKernel_Type<"CB", "cb"> {
let summary = "TTKernel cb";
let description = "Circular buffer type in TTKernel dialect";
let parameters = (ins "uint64_t":$address, "uint64_t":$port, "MemRefType":$memref);
let assemblyFormat = "`<` $address`,` $port`,` $memref `>`";
let parameters = (ins "uint64_t":$address,
"uint64_t":$port,
"MemRefType":$memref,
"uint64_t":$page_size,
"uint64_t":$num_buffers);
let assemblyFormat = "`<` $address`,` $port`,` $memref`,` $page_size`,` $num_buffers `>`";

let extraClassDeclaration = [{
static CBType get(::mlir::MLIRContext *context,
uint64_t address,
uint64_t port,
MemRefType memref,
uint64_t numBuffers = 1) {
uint64_t pageSize = 0;
if (::mlir::isa<::mlir::tt::TileType>(memref.getElementType())) {
pageSize = ::mlir::cast<::mlir::tt::TileType>(memref.getElementType()).getSizeBytes();
} else {
pageSize = memref.getShape().back() * (memref.getElementType().getIntOrFloatBitWidth() / 8);
}
return CBType::get(context, address, port, memref, pageSize, numBuffers);
}

::llvm::ArrayRef<int64_t> getShape() const {
return getMemref().getShape();
}
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ table TensorDesc {
table CBDesc {
port: uint32;
memory_desc: MemoryDesc;
page_size: uint64;
num_buffers: uint64;
}

Expand All @@ -87,7 +88,7 @@ table TensorRef {

table CBRef {
global_id: uint32;
associated_tensor_global_id: uint32;
tensor_ref: TensorRef;
address: uint64;
desc: CBDesc;
}
Expand Down
17 changes: 17 additions & 0 deletions lib/Dialect/TTMetal/Transforms/SerializeToBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ ::tt::target::Dim2dRange toFlatbuffer(CoreRangeAttr coreRange) {
::tt::target::Dim2d(size[0], size[1]));
}

::flatbuffers::Offset<::tt::target::CBDesc>
cbTypeToFlatbuffer(FlatbufferObjectCache &cache, ttkernel::CBType cbType) {
auto memref = cache.getOrCreate(cbType.getMemref(), memrefAttrToFlatbuffer);
return ::tt::target::CreateCBDesc(*cache.fbb, cbType.getPort(), memref,
cbType.getPageSize(),
cbType.getNumBuffers());
}

class TTMetalSerializeToBinary
: public impl::TTMetalSerializeToBinaryBase<TTMetalSerializeToBinary> {
public:
Expand Down Expand Up @@ -163,6 +171,15 @@ class TTMetalSerializeToBinary
toFlatbuffer(mlir::cast<CoreRangeAttr>(
dispatchOp.getCoreRanges()[region.getRegionNumber()]))};
std::vector<::flatbuffers::Offset<::tt::target::CBRef>> cbs;
for (auto arg : region.getArguments()) {
assert(arg.getArgNumber() < operands.size());
auto cbType = mlir::cast<ttkernel::CBType>(arg.getType());
auto cbDesc = cache.getOrCreate(cbType, cbTypeToFlatbuffer);
auto tensorRef = operands[arg.getArgNumber()];
cbs.push_back(
::tt::target::CreateCBRef(fbb, cache.global_id++, tensorRef,
cbType.getAddress(), cbDesc));
}
kernels.push_back(::tt::target::metal::CreateKernelDescDirect(
fbb, ::tt::target::metal::Kernel::KernelSource,
::tt::target::metal::CreateKernelSourceDirect(
Expand Down
8 changes: 4 additions & 4 deletions runtime/lib/ttmetal/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ static ::tt::tt_metal::CircularBufferConfig createCircularBufferConfig(
cbRef->desc()->memory_desc()->size() * cbRef->desc()->num_buffers();
::tt::DataFormat dataFormat =
toDataFormat(cbRef->desc()->memory_desc()->data_type());
assert(cbRef->associated_tensor_global_id());
assert(cbRef->tensor_ref());
assert(cbRef->tensor_ref()->address() == cbRef->address());
return CircularBufferConfig(totalSize, {{cbRef->desc()->port(), dataFormat}},
*buffers.at(cbRef->associated_tensor_global_id()))
.set_page_size(cbRef->desc()->port(),
cbRef->desc()->memory_desc()->size());
*buffers.at(cbRef->tensor_ref()->global_id()))
.set_page_size(cbRef->desc()->port(), cbRef->desc()->page_size());
}

void CQExecutor::execute(
Expand Down

0 comments on commit 6b58bde

Please sign in to comment.