From 83ee96eff3216072c08703ee49b17635b83a6142 Mon Sep 17 00:00:00 2001 From: Donghak PARK Date: Thu, 23 May 2024 15:58:45 +0900 Subject: [PATCH] [Mixed] Support MSELoss - Mixed Precision enable mixed precision in MSE Loss **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Donghak PARK --- Applications/Custom/mae_loss.cpp | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/Applications/Custom/mae_loss.cpp b/Applications/Custom/mae_loss.cpp index 092f762cf6..6bc77796a8 100644 --- a/Applications/Custom/mae_loss.cpp +++ b/Applications/Custom/mae_loss.cpp @@ -24,7 +24,18 @@ static constexpr size_t SINGLE_INOUT_IDX = 0; void MaeLossLayer::forwarding(nntrainer::RunLayerContext &context, bool training) { - nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX); + nntrainer::Tensor empty_tensor; + + nntrainer::Tensor &predicted = + context.getInput(SINGLE_INOUT_IDX).getDataType() == + ml::train::TensorDim::DataType::FP32 + ? context.getInput(SINGLE_INOUT_IDX) + : empty_tensor; + + if (predicted.empty()) + predicted = context.getInput(SINGLE_INOUT_IDX) + .clone(ml::train::TensorDim::DataType::FP32); + nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX); if (!context.executeInPlace()) @@ -32,7 +43,18 @@ void MaeLossLayer::forwarding(nntrainer::RunLayerContext &context, } void MaeLossLayer::calcDerivative(nntrainer::RunLayerContext &context) { - nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX); + nntrainer::Tensor empty_tensor; + + nntrainer::Tensor &predicted = + context.getInput(SINGLE_INOUT_IDX).getDataType() == + ml::train::TensorDim::DataType::FP32 + ? context.getInput(SINGLE_INOUT_IDX) + : empty_tensor; + + if (predicted.empty()) + predicted = context.getInput(SINGLE_INOUT_IDX) + .clone(ml::train::TensorDim::DataType::FP32); + nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX); nntrainer::Tensor &deriv = context.getOutgoingDerivative(SINGLE_INOUT_IDX);