Skip to content

Commit

Permalink
[Mixed] Support MSELoss - Mixed Precision
Browse files Browse the repository at this point in the history
enable mixed precision in MSE Loss

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK <[email protected]>
  • Loading branch information
DonghakPark committed May 23, 2024
1 parent 3fe9a1e commit 83ee96e
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions Applications/Custom/mae_loss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,37 @@ static constexpr size_t SINGLE_INOUT_IDX = 0;

void MaeLossLayer::forwarding(nntrainer::RunLayerContext &context,
bool training) {
nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX);
nntrainer::Tensor empty_tensor;

nntrainer::Tensor &predicted =
context.getInput(SINGLE_INOUT_IDX).getDataType() ==
ml::train::TensorDim::DataType::FP32
? context.getInput(SINGLE_INOUT_IDX)
: empty_tensor;

if (predicted.empty())
predicted = context.getInput(SINGLE_INOUT_IDX)
.clone(ml::train::TensorDim::DataType::FP32);

nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX);

if (!context.executeInPlace())
output.fill(predicted);
}

void MaeLossLayer::calcDerivative(nntrainer::RunLayerContext &context) {
nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX);
nntrainer::Tensor empty_tensor;

nntrainer::Tensor &predicted =
context.getInput(SINGLE_INOUT_IDX).getDataType() ==
ml::train::TensorDim::DataType::FP32
? context.getInput(SINGLE_INOUT_IDX)
: empty_tensor;

if (predicted.empty())
predicted = context.getInput(SINGLE_INOUT_IDX)
.clone(ml::train::TensorDim::DataType::FP32);

nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX);

nntrainer::Tensor &deriv = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
Expand Down

0 comments on commit 83ee96e

Please sign in to comment.