From bb2bb455469815d08fc8a3e6d3de4ea51611ad99 Mon Sep 17 00:00:00 2001
From: "jijoong.moon" <jijoong.moon@samsung.com>
Date: Sat, 11 May 2024 14:00:04 +0900
Subject: [PATCH] [ Context ] Add loss scale in Context & use it in MSE loss

This PR adds a loss scale parameter to RunLayerContext and uses it to
scale the MSE loss derivative.

. Add Loss Scale Parameter in RunLayerContext Constructor
. Add applyLossScale func to scale the returned derivative in Loss Layer
. Change MSE Loss Layer to apply the loss scale to the returned derivative

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
---
 nntrainer/graph/network_graph.cpp            |  6 ++--
 nntrainer/layers/layer_context.cpp           |  3 +-
 nntrainer/layers/layer_context.h             | 36 +++++++++++++++++---
 nntrainer/layers/layer_node.cpp              |  9 ++---
 nntrainer/layers/layer_node.h                |  3 +-
 nntrainer/layers/loss/loss_layer.cpp         |  7 ++++
 nntrainer/layers/loss/loss_layer.h           |  7 ++++
 nntrainer/layers/loss/mse_loss_layer.cpp     | 13 ++++++-
 nntrainer/layers/time_dist.cpp               | 16 ++++-----
 nntrainer/models/model_common_properties.h   |  2 +-
 nntrainer/tensor/weight.cpp                  |  2 +-
 test/unittest/layers/layers_golden_tests.cpp |  2 +-
 test/unittest/layers/unittest_layer_node.cpp |  4 +--
 13 files changed, 84 insertions(+), 26 deletions(-)

diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp
index b7f4d1cffd..297cd3e881 100644
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -880,7 +880,8 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
                            lnode->getTrainable(), shared_weight_names),
     inputs, outputs,
     tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(),
-                                   lnode->getTrainable(), shared_tensor_names));
+                                   lnode->getTrainable(), shared_tensor_names),
+    init_context.getLossScale());
 
   return outputs;
 }
@@ -1028,7 +1029,8 @@ NetworkGraph::refinalizeContext(const std::shared_ptr<LayerNode> &lnode,
     // TODO: update weights spec for trainable based on layer trainable prop
     weights, inputs, outputs,
     tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(),
-                                   lnode->getTrainable(), shared_tensor_names));
+                                   lnode->getTrainable(), shared_tensor_names),
+    init_context.getLossScale());
 
   return outputs;
 }
diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index f0856c1dbb..fbbc9ecaff 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -126,13 +126,14 @@ const std::vector<VarGradSpecV2> &InitLayerContext::getOutSpecs() const {
 }
 
 RunLayerContext::RunLayerContext(const std::string &name, bool trainable,
-                                 float l, bool in_place_,
+                                 float l, bool in_place_, float loss_scale_,
                                  const std::vector<Weight *> &w,
                                  const std::vector<Var_Grad *> &in,
                                  const std::vector<Var_Grad *> &out,
                                  const std::vector<Var_Grad *> &t) :
   loss(l),
   in_place(in_place_),
+  loss_scale(loss_scale_),
   weights(w),
   inputs(in),
   outputs(out),
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index e2f428aa2c..09bccc2c73 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -63,7 +63,7 @@ class InitLayerContext {
                    const float max_norm = 0.0,
                    std::array<std::string, 3> tensor_type_ = {"NCHW", "FP32", "FP32"},
-                   const float loss_scale = 0.0);
+                   const float loss_scale = 1.0);
   /**
    * @brief get Tensor Format of Layer
    *
@@ -348,6 +348,14 @@ class InitLayerContext {
    */
  bool executeInPlace() const { return in_place; }
 
+  /**
+   * @brief get the initial value of loss_scale. This value is set in
+   * RunLayerContext and updated during training.
+   *
+   * @return loss_scale
+   */
+  float getLossScale() const { return loss_scale; }
+
 private:
   std::vector<TensorDim> input_dim; /**< Input dimensions for the layer */
   bool in_place; /**< if the layer is expected to run in-place */
@@ -385,7 +393,7 @@ class RunLayerContext {
    * @brief Construct a new Run Layer Context object
    *
    */
-  RunLayerContext() : loss(0.0), in_place(false) {}
+  RunLayerContext() : loss(0.0), in_place(false), loss_scale(1.0) {}
 
   /**
    * @brief Construct a new Run Layer Context object
   *
@@ -396,6 +404,17 @@ class RunLayerContext {
     std::get<props::Name>(props).set(name);
   }
 
+  /**
+   * @brief Construct a new Run Layer Context object
+   *
+   */
+  RunLayerContext(const std::string &name, bool in_place_, float loss_scale_) :
+    RunLayerContext() {
+    in_place = in_place_;
+    std::get<props::Name>(props).set(name);
+    loss_scale = loss_scale_;
+  }
+
   /**
    * @brief Construct a new Run Layer Context object
    *
@@ -403,13 +422,15 @@ class RunLayerContext {
    * @param trainable if the layer is trainable
    * @param l loss of the layer
    * @param in_place_ execution in-place of the layer
+   * @param loss_scale_ loss scale of the layer
    * @param w weights of the layer
    * @param in inputs of the layer
    * @param out outputs of the layer
    * @param t extra tensors of the layer
    */
   RunLayerContext(const std::string &name, bool trainable, float l,
-                  bool in_place_, const std::vector<Weight *> &w,
+                  bool in_place_, float loss_scale_,
+                  const std::vector<Weight *> &w,
                   const std::vector<Var_Grad *> &in,
                   const std::vector<Var_Grad *> &out,
                   const std::vector<Var_Grad *> &t);
@@ -883,10 +904,17 @@ class RunLayerContext {
    */
   ml::train::LayerComputeEngine getComputeEngine() { return compute_engine; }
 
+  /**
+   * @brief get loss scale
+   * @return loss scale
+   */
+  float getLossScale() { return loss_scale; }
+
 private:
   std::tuple<props::Name, props::Trainable> props; /**< props of the layer */
   float loss;       /**< loss of the layer */
-  bool in_place;    /**< if the layer is expected to run in-place */
+  bool in_place;    /**< if the layer is expected to run in-place */
+  float loss_scale; /**< loss_scale of the layer */
 
   std::vector<Weight *> weights;  /**< weights of the layer */
   std::vector<Var_Grad *> inputs; /**< inputs of the layer */
diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp
index 8b18d80762..f41752a4d8 100644
--- a/nntrainer/layers/layer_node.cpp
+++ b/nntrainer/layers/layer_node.cpp
@@ -599,7 +599,7 @@ InitLayerContext LayerNode::finalize(const std::vector<TensorDim> &input_dims,
   const auto &scope = getSharedFrom().empty() ? getName() : getSharedFrom();
   float max_norm = 0.0;
-  float loss_scale = 0.0;
+  float loss_scale = 1.0;
 
   if (!std::get<props::ClipGradByGlobalNorm>(*layer_node_props).empty())
     max_norm = std::get<props::ClipGradByGlobalNorm>(*layer_node_props).get();
@@ -864,10 +864,11 @@ float LayerNode::getLoss() const { return *loss; }
 void LayerNode::configureRunContext(const std::vector<Weight *> &weights,
                                     const std::vector<Var_Grad *> &inputs,
                                     const std::vector<Var_Grad *> &outputs,
-                                    const std::vector<Var_Grad *> &tensors) {
+                                    const std::vector<Var_Grad *> &tensors,
+                                    float loss_scale) {
   run_context = std::make_unique<RunLayerContext>(
-    getName(), getTrainable(), 0.0f, executeInPlace() != InPlace::NONE, weights,
-    inputs, outputs, tensors);
+    getName(), getTrainable(), 0.0f, executeInPlace() != InPlace::NONE,
+    loss_scale, weights, inputs, outputs, tensors);
 }
 
 /**
diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h
index 7dfb1bd1a0..3fd2d55b97 100644
--- a/nntrainer/layers/layer_node.h
+++ b/nntrainer/layers/layer_node.h
@@ -820,7 +820,8 @@ class LayerNode final : public ml::train::Layer, public GraphNode {
   void configureRunContext(const std::vector<Weight *> &weights,
                            const std::vector<Var_Grad *> &inputs,
                            const std::vector<Var_Grad *> &outputs,
-                           const std::vector<Var_Grad *> &tensors);
+                           const std::vector<Var_Grad *> &tensors,
+                           float loss_scale);
 
   /**
    * @brief Preset modes for printing summary for the layer
diff --git a/nntrainer/layers/loss/loss_layer.cpp b/nntrainer/layers/loss/loss_layer.cpp
index 40f74717f8..8f422fe379 100644
--- a/nntrainer/layers/loss/loss_layer.cpp
+++ b/nntrainer/layers/loss/loss_layer.cpp
@@ -36,6 +36,13 @@ void LossLayer::updateLoss(RunLayerContext &context, const Tensor &l) {
   context.setLoss(loss_sum / (float)l.batch());
 }
 
+void LossLayer::applyLossScale(RunLayerContext &context, Tensor &ret_deriv) {
+
+  float loss_scale = context.getLossScale();
+  if (loss_scale != 1.0)
+    ret_deriv.multiply_i(loss_scale);
+}
+
 /**
  * @copydoc Layer::setProperty(const std::vector<std::string> &values)
  */
diff --git a/nntrainer/layers/loss/loss_layer.h b/nntrainer/layers/loss/loss_layer.h
index 00b520f6e6..581e9477a8 100644
--- a/nntrainer/layers/loss/loss_layer.h
+++ b/nntrainer/layers/loss/loss_layer.h
@@ -60,6 +60,13 @@ class LossLayer : public Layer {
    */
   void updateLoss(RunLayerContext &context, const Tensor &l);
 
+  /**
+   * @brief apply the loss scale to the returned derivative
+   * @param context run context holding the loss scale
+   * @param ret_deriv derivative tensor to scale in place
+   */
+  void applyLossScale(RunLayerContext &context, Tensor &ret_deriv);
+
   Tensor
     l; /**< loss tensor to store intermediate value to calculate loss value */
 };
diff --git a/nntrainer/layers/loss/mse_loss_layer.cpp b/nntrainer/layers/loss/mse_loss_layer.cpp
index ec9bc9b844..3aed8125e0 100644
--- a/nntrainer/layers/loss/mse_loss_layer.cpp
+++ b/nntrainer/layers/loss/mse_loss_layer.cpp
@@ -50,8 +50,17 @@ void MSELossLayer::forwarding(RunLayerContext &context, bool training) {
 }
 
 void MSELossLayer::calcDerivative(RunLayerContext &context) {
+  Tensor empty_tensor;
   Tensor &ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
-  const Tensor &y2 = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  const Tensor &y2_ = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  Tensor &y2 = empty_tensor;
+
+  if (ret_derivative.getDataType() == ml::train::TensorDim::DataType::FP32)
+    y2 = y2_;
+
+  if (y2.empty())
+    y2 = y2_.clone(ret_derivative.getDataType());
+
   Tensor &y = context.getInput(SINGLE_INOUT_IDX);
 
   y.subtract(y2, ret_derivative);
@@ -60,6 +69,8 @@ void MSELossLayer::calcDerivative(RunLayerContext &context) {
     throw std::runtime_error(
       "[MSELossLayer::calcDerivative] Error when calculating loss");
   }
+
+  LossLayer::applyLossScale(context, ret_derivative);
 }
 
 } // namespace nntrainer
diff --git a/nntrainer/layers/time_dist.cpp b/nntrainer/layers/time_dist.cpp
index 80451416df..779010065a 100644
--- a/nntrainer/layers/time_dist.cpp
+++ b/nntrainer/layers/time_dist.cpp
@@ -256,8 +256,8 @@ void TimeDistLayer::forwarding(RunLayerContext &context, bool training) {
 
   RunLayerContext dist_context(context.getName(), context.getTrainable(),
                                context.getLoss(), context.executeInPlace(),
-                               getWeightsForContext(), {&in_var}, {&out_var},
-                               getTensorsForContext());
+                               context.getLossScale(), getWeightsForContext(),
+                               {&in_var}, {&out_var}, getTensorsForContext());
 
   dist_layer->forwarding(dist_context, training);
 }
@@ -303,8 +303,8 @@ void TimeDistLayer::calcDerivative(RunLayerContext &context) {
 
   RunLayerContext dist_context(context.getName(), context.getTrainable(),
                                context.getLoss(), context.executeInPlace(),
-                               getWeightsForContext(), {&in_var}, {&out_var},
-                               getTensorsForContext());
+                               context.getLossScale(), getWeightsForContext(),
+                               {&in_var}, {&out_var}, getTensorsForContext());
 
   dist_layer->calcDerivative(dist_context);
 }
@@ -354,8 +354,8 @@ void TimeDistLayer::calcGradient(RunLayerContext &context) {
 
   RunLayerContext dist_context(context.getName(), context.getTrainable(),
                                context.getLoss(), context.executeInPlace(),
-                               getWeightsForContext(), {&in_var}, {&out_var},
-                               getTensorsForContext());
+                               context.getLossScale(), getWeightsForContext(),
+                               {&in_var}, {&out_var}, getTensorsForContext());
 
   dist_layer->calcGradient(dist_context);
 }
@@ -396,8 +396,8 @@ void TimeDistLayer::setBatch(RunLayerContext &context, unsigned int batch) {
 
   RunLayerContext dist_context(context.getName(), context.getTrainable(),
                                context.getLoss(), context.executeInPlace(),
-                               getWeightsForContext(), {&in_var}, {&out_var},
-                               getTensorsForContext());
+                               context.getLossScale(), getWeightsForContext(),
+                               {&in_var}, {&out_var}, getTensorsForContext());
 
   dist_layer->setBatch(dist_context, batch);
 
diff --git a/nntrainer/models/model_common_properties.h b/nntrainer/models/model_common_properties.h
index 3776afefca..3435d18e96 100644
--- a/nntrainer/models/model_common_properties.h
+++ b/nntrainer/models/model_common_properties.h
@@ -217,7 +217,7 @@ class ModelTensorDataType final : public EnumProperty<ModelTensorDataTypeInfo> {
  */
 class LossScale : public Property<float> {
 public:
-  LossScale(float value = 0.0f);
+  LossScale(float value = 1.0f);
   static constexpr const char *key = "loss_scale"; /**< unique key to access */
   using prop_tag = float_prop_tag;                 /**< property type */
 };
diff --git a/nntrainer/tensor/weight.cpp b/nntrainer/tensor/weight.cpp
index d8db5ba094..df262f50d9 100644
--- a/nntrainer/tensor/weight.cpp
+++ b/nntrainer/tensor/weight.cpp
@@ -99,7 +99,7 @@ Weight::Weight(const Tensor &v, const Tensor &g, const Tensor &v32,
   decay(0.0f),
   clip_by_global_norm(0.0f),
   output_axis(output_axis_),
-  loss_scale(0.0),
+  loss_scale(1.0),
   var32(std::make_shared<Tensor>(n + ":fp32")) {
 
   if (!g.empty() && isMixedPrecision()) {
diff --git a/test/unittest/layers/layers_golden_tests.cpp b/test/unittest/layers/layers_golden_tests.cpp
index 64400e6ecd..c71d653c05 100644
--- a/test/unittest/layers/layers_golden_tests.cpp
+++ b/test/unittest/layers/layers_golden_tests.cpp
@@ -156,7 +156,7 @@ static RunLayerContext prepareRunContext(const TensorPacks &packs) {
   };
 
   auto rc =
-    RunLayerContext("golden", true, 0.0f, false, create_view(weights),
+    RunLayerContext("golden", true, 0.0f, false, 1.0, create_view(weights),
                     create_view(ins), create_view(outs), create_view(tensors));
 
   auto num_outputs = rc.getNumOutputs();
diff --git a/test/unittest/layers/unittest_layer_node.cpp b/test/unittest/layers/unittest_layer_node.cpp
index 3b41f02f30..37287f7ce5 100644
--- a/test/unittest/layers/unittest_layer_node.cpp
+++ b/test/unittest/layers/unittest_layer_node.cpp
@@ -131,7 +131,7 @@ TEST(nntrainer_LayerNode, finalize_05_n) {
                nntrainer::createLayerNode(nntrainer::IdentityLayer::type));
   EXPECT_NO_THROW(lnode->setProperty({"input_shape=1:1:1", "name=abc"}));
   EXPECT_NO_THROW(lnode->finalize());
-  EXPECT_NO_THROW(lnode->configureRunContext({}, {&input}, {}, {}));
+  EXPECT_NO_THROW(lnode->configureRunContext({}, {&input}, {}, {}, 1.0));
   EXPECT_THROW(lnode->finalize(), std::runtime_error);
 }
 
@@ -298,7 +298,7 @@ TEST(nntrainer_LayerNode, setWeights_02_n) {
   EXPECT_NO_THROW(lnode =
                     nntrainer::createLayerNode(nntrainer::IdentityLayer::type));
   EXPECT_NO_THROW(lnode->setProperty({"input_shape=1:1:1", "name=abc"}));
-  EXPECT_NO_THROW(lnode->configureRunContext({&weight}, {&input}, {}, {}));
+  EXPECT_NO_THROW(lnode->configureRunContext({&weight}, {&input}, {}, {}, 1.0));
   EXPECT_THROW(lnode->setWeights(new_weights), std::runtime_error);
 }
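
Reviewer note (illustration only, not part of the patch): the self-contained C++
sketch below mirrors the flow this change enables. Tensor here is a hypothetical
stand-in struct, and mse_calc_derivative/unscale_gradient are made-up helper
names, not nntrainer API. It shows why the MSE loss layer multiplies its
returned derivative by loss_scale (so small FP16 gradients do not underflow
during backprop), and why gradients must later be divided by the same factor
before the optimizer applies them; that unscaling step is outside this patch.

#include <cstddef>
#include <iostream>
#include <vector>

struct Tensor {              // stand-in for nntrainer::Tensor
  std::vector<float> data;
};

// mirrors MSELossLayer::calcDerivative followed by LossLayer::applyLossScale:
// ret_deriv = (y - label), then scaled in place when loss_scale != 1.0
void mse_calc_derivative(const Tensor &y, const Tensor &label,
                         float loss_scale, Tensor &ret_deriv) {
  ret_deriv.data.resize(y.data.size());
  for (std::size_t i = 0; i < y.data.size(); ++i)
    ret_deriv.data[i] = y.data[i] - label.data[i]; // y.subtract(y2, ret)
  if (loss_scale != 1.0f)                          // same guard as the patch
    for (float &v : ret_deriv.data)
      v *= loss_scale;                             // multiply_i(loss_scale)
}

// counterpart step before the optimizer update: undo the scale
void unscale_gradient(Tensor &grad, float loss_scale) {
  for (float &v : grad.data)
    v /= loss_scale;
}

int main() {
  Tensor y{{0.5f, 0.25f}}, label{{0.0f, 0.25f}}, deriv;
  mse_calc_derivative(y, label, 128.0f, deriv); // loss_scale = 128
  unscale_gradient(deriv, 128.0f);              // recovers the true gradient
  std::cout << deriv.data[0] << ' ' << deriv.data[1] << '\n'; // prints: 0.5 0
  return 0;
}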