From e73f0c1fc834dbcbb66be108a283027c029809f1 Mon Sep 17 00:00:00 2001 From: PuQing Date: Wed, 8 Mar 2023 15:14:30 +0800 Subject: [PATCH 1/4] fix momentum dtype infer --- .../new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/infermeta/multiary.cc | 7 ++++++- paddle/phi/kernels/gpu/momentum_kernel.cu | 10 ++++++++-- paddle/phi/kernels/xpu/momentum_kernel.cc | 5 ++++- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index c11e48ec94c36..f0bb79b43bf53 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -94,7 +94,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "matrix_rank_tol", "merged_adam", "mode", - "momentum", "multiclass_nms3", "multinomial", "nanmedian", diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 9a4f233ce8f93..8607bea0136f0 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2108,10 +2108,15 @@ void MomentumInferMeta(const MetaTensor& param, auto param_dim = param.dims(); param_out->set_dims(param_dim); + auto MPType = (param.dtype() == phi::DataType::FLOAT16 || + param.dtype() == phi::DataType::BFLOAT16) + ? phi::DataType::FLOAT32 + : param.dtype(); velocity_out->set_dims(param_dim); - + velocity_out->set_dtype(MPType); if (master_param_out) { master_param_out->set_dims(param_dim); + master_param_out->set_dtype(MPType); } } diff --git a/paddle/phi/kernels/gpu/momentum_kernel.cu b/paddle/phi/kernels/gpu/momentum_kernel.cu index 5a4f5d33e6165..d544c242ded76 100644 --- a/paddle/phi/kernels/gpu/momentum_kernel.cu +++ b/paddle/phi/kernels/gpu/momentum_kernel.cu @@ -24,7 +24,10 @@ PD_REGISTER_KERNEL(momentum, phi::MomentumDenseKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(UNDEFINED); + kernel->OutputAt(2).SetDataType(UNDEFINED); +} PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, GPU, @@ -32,4 +35,7 @@ PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, phi::MomentumSparseKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(UNDEFINED); + kernel->OutputAt(2).SetDataType(UNDEFINED); +} diff --git a/paddle/phi/kernels/xpu/momentum_kernel.cc b/paddle/phi/kernels/xpu/momentum_kernel.cc index ad9cb2e6ef86e..c9842cb970a46 100644 --- a/paddle/phi/kernels/xpu/momentum_kernel.cc +++ b/paddle/phi/kernels/xpu/momentum_kernel.cc @@ -69,4 +69,7 @@ PD_REGISTER_KERNEL(momentum, ALL_LAYOUT, phi::MomentumDenseKernel, float, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(UNDEFINED); + kernel->OutputAt(2).SetDataType(UNDEFINED); +} From dcc593ee8af7fc39af69e7741e8d2c12745f480e Mon Sep 17 00:00:00 2001 From: PuQing Date: Wed, 15 Mar 2023 10:45:09 +0800 Subject: [PATCH 2/4] fix momentum datatype --- paddle/phi/kernels/cpu/momentum_kernel.cc | 10 ++++++++-- paddle/phi/kernels/gpu/momentum_kernel.cu | 8 ++++---- paddle/phi/kernels/xpu/momentum_kernel.cc | 4 ++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/cpu/momentum_kernel.cc b/paddle/phi/kernels/cpu/momentum_kernel.cc index 7a4ea9f19e5c2..a4b91d9d1f24b 100644 --- a/paddle/phi/kernels/cpu/momentum_kernel.cc +++ b/paddle/phi/kernels/cpu/momentum_kernel.cc @@ -19,11 +19,17 @@ #include "paddle/phi/kernels/impl/momentum_kernel_impl.h" PD_REGISTER_KERNEL( - momentum, CPU, ALL_LAYOUT, phi::MomentumDenseKernel, float, double) {} + momentum, CPU, ALL_LAYOUT, phi::MomentumDenseKernel, float, double) { + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); +} PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, CPU, ALL_LAYOUT, phi::MomentumSparseKernel, float, - double) {} + double) { + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/gpu/momentum_kernel.cu b/paddle/phi/kernels/gpu/momentum_kernel.cu index d544c242ded76..6d2b51dff64cb 100644 --- a/paddle/phi/kernels/gpu/momentum_kernel.cu +++ b/paddle/phi/kernels/gpu/momentum_kernel.cu @@ -25,8 +25,8 @@ PD_REGISTER_KERNEL(momentum, float, double, phi::dtype::float16) { - kernel->OutputAt(1).SetDataType(UNDEFINED); - kernel->OutputAt(2).SetDataType(UNDEFINED); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); } PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, @@ -36,6 +36,6 @@ PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, float, double, phi::dtype::float16) { - kernel->OutputAt(1).SetDataType(UNDEFINED); - kernel->OutputAt(2).SetDataType(UNDEFINED); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); } diff --git a/paddle/phi/kernels/xpu/momentum_kernel.cc b/paddle/phi/kernels/xpu/momentum_kernel.cc index c9842cb970a46..207bfef37f947 100644 --- a/paddle/phi/kernels/xpu/momentum_kernel.cc +++ b/paddle/phi/kernels/xpu/momentum_kernel.cc @@ -70,6 +70,6 @@ PD_REGISTER_KERNEL(momentum, phi::MomentumDenseKernel, float, phi::dtype::float16) { - kernel->OutputAt(1).SetDataType(UNDEFINED); - kernel->OutputAt(2).SetDataType(UNDEFINED); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); } From acb7bf2916cbd8a3c4d498120346d245e6b0909d Mon Sep 17 00:00:00 2001 From: PuQing Date: Sun, 19 Mar 2023 16:38:53 +0800 Subject: [PATCH 3/4] fix on cpu --- paddle/phi/kernels/cpu/momentum_kernel.cc | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/cpu/momentum_kernel.cc b/paddle/phi/kernels/cpu/momentum_kernel.cc index a4b91d9d1f24b..7a4ea9f19e5c2 100644 --- a/paddle/phi/kernels/cpu/momentum_kernel.cc +++ b/paddle/phi/kernels/cpu/momentum_kernel.cc @@ -19,17 +19,11 @@ #include "paddle/phi/kernels/impl/momentum_kernel_impl.h" PD_REGISTER_KERNEL( - momentum, CPU, ALL_LAYOUT, phi::MomentumDenseKernel, float, double) { - kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); - kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); -} + momentum, CPU, ALL_LAYOUT, phi::MomentumDenseKernel, float, double) {} PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, CPU, ALL_LAYOUT, phi::MomentumSparseKernel, float, - double) { - kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); - kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); -} + double) {} From a1f5ff8d8e1529efd3220ad9c49ba862b14fa30f Mon Sep 17 00:00:00 2001 From: PuQing Date: Mon, 20 Mar 2023 15:51:07 +0800 Subject: [PATCH 4/4] add momentum --- .../fluid/framework/new_executor/interpreter/interpreter_util.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 5f8aea485743c..6859d878d6281 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -65,6 +65,7 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "less_equal", "less_than", "merged_adam", + "momentum", "multiclass_nms3", "nanmedian", "sync_batch_norm_grad",