Handle OWNED type host storage in ttnn runtime. (#495)
jnie-TT authored Aug 27, 2024
1 parent 9277218 commit 1bcf83d
Showing 1 changed file with 59 additions and 35 deletions.
94 changes: 59 additions & 35 deletions runtime/lib/ttnn/program.cpp
@@ -35,12 +35,22 @@ ttnn::Tensor untilize(ttnn::Tensor const &input) {

namespace tt::runtime::ttnn {

+ static bool isOnHost(const ::ttnn::Tensor &tensor) {
+ // Currently only supports borrowed or owned host storage
+ return tensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED or
+ tensor.storage_type() == ::tt::tt_metal::StorageType::OWNED;
+ }
+
+ static bool isOnDevice(const ::ttnn::Tensor &tensor) {
+ // Currently only supports single device storage
+ return tensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE;
+ }

static ::ttnn::Tensor convertDataType(const ::ttnn::Tensor &input,
const ::ttnn::DataType &targetDataType) {
- const ::ttnn::StorageType storageType = input.storage_type();
- if (storageType == ::tt::tt_metal::StorageType::BORROWED) {
+ if (isOnHost(input)) {
return ::ttnn::to_dtype(input, targetDataType);
- } else if (storageType == ::tt::tt_metal::StorageType::DEVICE) {
+ } else if (isOnDevice(input)) {
if (input.get_layout() != ::ttnn::TILE_LAYOUT) {
// typecast op requires tilized tensor
::ttnn::Tensor converted =
@@ -88,11 +98,11 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
std::list<::ttnn::Tensor> &tensorPool) {
const ::ttnn::Tensor &inputTensor = *liveTensors.at(op->in0()->global_id());
- assert(inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED or
- inputTensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE);
+ TT_FATAL(isOnHost(inputTensor) or isOnDevice(inputTensor),
+ "Unsupported storage type {}", inputTensor.storage_type());
const ::tt::target::Dim2d *targetTileShape =
op->out()->desc()->layout()->memory_desc()->tile_shape();

TT_FATAL(utils::isValidTileShape(targetTileShape),
"Invalid tile shape ({}, {})", targetTileShape->x(),
targetTileShape->y());
@@ -110,24 +120,29 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
case ::tt::target::MemorySpace::System:
case ::tt::target::MemorySpace::SystemMMIO: {
::ttnn::Tensor result;
- if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) {
+ if (isOnHost(inputTensor)) {
result =
updateLayoutAndDataType(inputTensor, targetDataTypeTTNN, false, true);
- } else if (inputTensor.storage_type() ==
- ::tt::tt_metal::StorageType::DEVICE) {
+ } else if (isOnDevice(inputTensor)) {
result = updateLayoutAndDataType(inputTensor.cpu(), targetDataTypeTTNN,
false, true);
}
- ::ttnn::Tensor &outputTensor = *liveTensors.at(op->out()->global_id());
- void *src = ::tt::tt_metal::get_raw_host_data_ptr(result);
- void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor);
- std::uint32_t size = result.volume() * result.element_size();
- std::memcpy(dst, src, size);
+ // copy the output to the output tensor if it exists
+ if (liveTensors.contains(op->out()->global_id())) {
+ ::ttnn::Tensor &outputTensor = *liveTensors.at(op->out()->global_id());
+ void *src = ::tt::tt_metal::get_raw_host_data_ptr(result);
+ void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor);
+ std::uint32_t size = result.volume() * result.element_size();
+ std::memcpy(dst, src, size);
+ } else {
+ tensorPool.push_back(result);
+ liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
+ }
break;
}
case ::tt::target::MemorySpace::DeviceDRAM: {
::tt::tt_metal::MemoryConfig memConfig = ::ttnn::DRAM_MEMORY_CONFIG;
- if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) {
+ if (isOnHost(inputTensor)) {
::ttnn::Tensor result = inputTensor;
bool shouldTilize = true;
// device tilize requires BFLOAT16, if not then tilize on host
@@ -140,8 +155,7 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
false);
tensorPool.push_back(result);
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
- } else if (inputTensor.storage_type() ==
- ::tt::tt_metal::StorageType::DEVICE) {
+ } else if (isOnDevice(inputTensor)) {
::ttnn::Tensor result = updateLayoutAndDataType(
inputTensor, targetDataTypeTTNN, false, false);
result = ::ttnn::to_memory_config(result, memConfig, std::nullopt);
@@ -154,7 +168,7 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
// But will need its own code path when we add support for sharding
case ::tt::target::MemorySpace::DeviceL1: {
::tt::tt_metal::MemoryConfig memConfig = ::ttnn::L1_MEMORY_CONFIG;
- if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) {
+ if (isOnHost(inputTensor)) {
::ttnn::Tensor result = inputTensor;
bool shouldTilize = true;
// device tilize requires BFLOAT16, if not then tilize on host
@@ -167,8 +181,7 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
false);
tensorPool.push_back(result);
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
- } else if (inputTensor.storage_type() ==
- ::tt::tt_metal::StorageType::DEVICE) {
+ } else if (isOnDevice(inputTensor)) {
::ttnn::Tensor result = updateLayoutAndDataType(
inputTensor, targetDataTypeTTNN, false, false);
result = ::ttnn::to_memory_config(result, memConfig, std::nullopt);
@@ -206,39 +219,44 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device,
switch (op->type()) {
/* Eltwise Binary */
case ::tt::target::ttnn::EltwiseOpType::Add: {
- assert(op->ins()->size() == 2 && "Unsupported number of inputs");
+ TT_FATAL(op->ins()->size() == 2, "Expected 2 inputs, got {}",
+ op->ins()->size());
auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::add(lhs, rhs));
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::Multiply: {
- assert(op->ins()->size() == 2 && "Unsupported number of inputs");
+ TT_FATAL(op->ins()->size() == 2, "Expected 2 inputs, got {}",
+ op->ins()->size());
auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::multiply(lhs, rhs));
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::Subtract: {
- assert(op->ins()->size() == 2 && "Unsupported number of inputs");
+ TT_FATAL(op->ins()->size() == 2, "Expected 2 inputs, got {}",
+ op->ins()->size());
auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::subtract(lhs, rhs));
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::GreaterEqual: {
- assert(op->ins()->size() == 2 && "Unsupported number of inputs");
+ TT_FATAL(op->ins()->size() == 2, "Expected 2 inputs, got {}",
+ op->ins()->size());
::ttnn::Tensor &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
::ttnn::Tensor &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::ge(lhs, rhs));
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::Div: {
- assert(op->ins()->size() == 2 && "Unsupported number of inputs");
+ TT_FATAL(op->ins()->size() == 2, "Expected 2 inputs, got {}",
+ op->ins()->size());
::ttnn::Tensor &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
::ttnn::Tensor &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::divide(lhs, rhs));
@@ -247,28 +265,32 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device,
}
/* Eltwise Unary */
case ::tt::target::ttnn::EltwiseOpType::Relu: {
- assert(op->ins()->size() == 1 && "Unsupported number of inputs");
+ TT_FATAL(op->ins()->size() == 1, "Expected 1 input, got {}",
+ op->ins()->size());
::ttnn::Tensor &in = *liveTensors.at(op->ins()->Get(0)->global_id());
tensorPool.push_back(::ttnn::relu(in));
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::Sqrt: {
- assert(op->ins()->size() == 1 && "Unsupported number of inputs");
+ TT_FATAL(op->ins()->size() == 1, "Expected 1 input, got {}",
+ op->ins()->size());
::ttnn::Tensor &in = *liveTensors.at(op->ins()->Get(0)->global_id());
tensorPool.push_back(::ttnn::sqrt(in));
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::Sigmoid: {
- assert(op->ins()->size() == 1 && "Unsupported number of inputs");
+ TT_FATAL(op->ins()->size() == 1, "Expected 1 input, got {}",
+ op->ins()->size());
::ttnn::Tensor &in = *liveTensors.at(op->ins()->Get(0)->global_id());
tensorPool.push_back(::ttnn::sigmoid(in));
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::Reciprocal: {
- assert(op->ins()->size() == 1 && "Unsupported number of inputs");
+ TT_FATAL(op->ins()->size() == 1, "Expected 1 input, got {}",
+ op->ins()->size());
::ttnn::Tensor &in = *liveTensors.at(op->ins()->Get(0)->global_id());
tensorPool.push_back(::ttnn::reciprocal(in));
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
@@ -451,22 +473,24 @@ void runProgram(::ttnn::Device &device,
std::list<::ttnn::Tensor> tensorPool;

int inputIndex = 0;
- assert(program->inputs()->size() == inputs.size() &&
- "Mismatch between program inputs and input tensors");
+ TT_FATAL(program->inputs()->size() == inputs.size(),
+ "Program expects {} inputs, found {} in input tensors vector",
+ program->inputs()->size(), inputs.size());
bool is_nop = handleNopProgram(program, inputs, outputs);
for (::tt::target::TensorRef const *input : *program->inputs()) {
auto [iter, inserted] =
liveTensors.try_emplace(input->global_id(), inputs[inputIndex++]);
- assert(inserted && "Duplicate input tensor");
+ TT_FATAL(inserted, "Duplicate input tensor");
}

int outputIndex = 0;
- assert(program->outputs()->size() == outputs.size() &&
- "Mismatch between program outputs and output tensors");
+ TT_FATAL(program->outputs()->size() == outputs.size(),
+ "Program expects {} outputs, found {} in output tensors vector",
+ program->outputs()->size(), outputs.size());
for (::tt::target::TensorRef const *output : *program->outputs()) {
auto [iter, inserted] =
liveTensors.try_emplace(output->global_id(), outputs[outputIndex++]);
- assert(is_nop || inserted && "Duplicate output tensor");
+ TT_FATAL(is_nop || inserted, "Duplicate output tensor");
}

for (::tt::target::ttnn::Operation const *op : *program->operations()) {
