diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 1224caf88e668..cd0f7c7075d67 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -19,9 +19,9 @@ #include "paddle/fluid/inference/api/paddle_analysis_config.h" #include "paddle/fluid/inference/api/paddle_pass_builder.h" #include "paddle/fluid/inference/utils/table_printer.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/utils/string/split.h" #ifdef PADDLE_WITH_TENSORRT @@ -624,10 +624,11 @@ void AnalysisConfig::EnableMkldnnQuantizer() { void AnalysisConfig::EnableMkldnnBfloat16() { #ifdef PADDLE_WITH_MKLDNN - if (platform::MayIUse(platform::cpu_isa_t::avx512_core)) { + if (phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx512_core)) { use_mkldnn_bfloat16_ = true; LOG(INFO) << "Hardware support for BFLOAT16" - << (platform::MayIUse(platform::cpu_isa_t::avx512_bf16) + << (phi::backends::cpu::MayIUse( + phi::backends::cpu::cpu_isa_t::avx512_bf16) ? " is enabled" : " is disabled. Simulation will be used"); } else { diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc index c75f1e4a569c0..627b6fba02313 100644 --- a/paddle/fluid/inference/api/analysis_predictor_tester.cc +++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc @@ -29,7 +29,7 @@ #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/tests/api/tester_helper.h" #include "paddle/fluid/inference/utils/io_utils.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" DEFINE_string(dirname, "", "dirname to tests."); @@ -327,7 +327,7 @@ TEST(AnalysisPredictor, bf16_gpu_pass_strategy) { config.EnableUseGpu(100, 0); config.EnableMkldnnBfloat16(); #ifdef PADDLE_WITH_MKLDNN - if (platform::MayIUse(platform::cpu_isa_t::avx512_core)) + if (phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx512_core)) ASSERT_EQ(config.mkldnn_bfloat16_enabled(), true); else ASSERT_EQ(config.mkldnn_bfloat16_enabled(), false); diff --git a/paddle/fluid/inference/api/onnxruntime_predictor_tester.cc b/paddle/fluid/inference/api/onnxruntime_predictor_tester.cc index ff8528c085009..deb9f11486e89 100644 --- a/paddle/fluid/inference/api/onnxruntime_predictor_tester.cc +++ b/paddle/fluid/inference/api/onnxruntime_predictor_tester.cc @@ -27,7 +27,7 @@ #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/tests/api/tester_helper.h" #include "paddle/fluid/inference/utils/io_utils.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" DEFINE_string(dirname, "", "dirname to tests."); diff --git a/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc index 267fb17ee6baa..c92d2ebf278f8 100644 --- a/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/inference/api/paddle_analysis_config.h" #include "paddle/fluid/inference/tests/api/tester_helper.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" DEFINE_bool(enable_mkldnn, true, "Enable MKLDNN"); @@ -47,7 +47,7 @@ TEST(Analyzer_bfloat16_image_classification, bfloat16) { std::vector> input_slots_all; SetInputs(&input_slots_all); if (FLAGS_enable_mkldnn && FLAGS_enable_bf16 && - platform::MayIUse(platform::cpu_isa_t::avx512_bf16)) { + phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx512_bf16)) { b_cfg.EnableMkldnnBfloat16(); } else { FLAGS_enable_bf16 = false; diff --git a/paddle/fluid/memory/allocation/buddy_allocator.h b/paddle/fluid/memory/allocation/buddy_allocator.h index 5e39e21c9664f..5a89a6d485461 100644 --- a/paddle/fluid/memory/allocation/buddy_allocator.h +++ b/paddle/fluid/memory/allocation/buddy_allocator.h @@ -27,9 +27,9 @@ limitations under the License. */ #include "paddle/fluid/memory/allocation/memory_block.h" #include "paddle/fluid/memory/allocation/system_allocator.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace memory { diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc index d1a3b77e7720b..4bcfdb1aaf424 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc @@ -78,8 +78,8 @@ BuddyAllocator *GetCPUBuddyAllocator() { std::call_once(init_flag, []() { a = new detail::BuddyAllocator( std::unique_ptr(new detail::CPUAllocator), - platform::CpuMinChunkSize(), - platform::CpuMaxChunkSize()); + phi::backends::cpu::CpuMinChunkSize(), + phi::backends::cpu::CpuMaxChunkSize()); }); return a; @@ -290,8 +290,8 @@ BuddyAllocator *GetNPUPinnedBuddyAllocator() { std::call_once(init_flag, []() { ba = new BuddyAllocator(std::unique_ptr( new detail::NPUPinnedAllocator), - platform::NPUPinnedMinChunkSize(), - platform::NPUPinnedMaxChunkSize()); + phi::backends::cpu::NPUPinnedMinChunkSize(), + phi::backends::cpu::NPUPinnedMaxChunkSize()); }); return ba; @@ -562,8 +562,8 @@ BuddyAllocator *GetCUDAPinnedBuddyAllocator() { std::call_once(init_flag, []() { ba = new BuddyAllocator(std::unique_ptr( new detail::CUDAPinnedAllocator), - platform::CUDAPinnedMinChunkSize(), - platform::CUDAPinnedMaxChunkSize()); + phi::backends::cpu::CUDAPinnedMinChunkSize(), + phi::backends::cpu::CUDAPinnedMaxChunkSize()); }); return ba; diff --git a/paddle/fluid/memory/allocation/system_allocator.cc b/paddle/fluid/memory/allocation/system_allocator.cc index 15cd2f6d1f371..033b49dbc1268 100644 --- a/paddle/fluid/memory/allocation/system_allocator.cc +++ b/paddle/fluid/memory/allocation/system_allocator.cc @@ -28,10 +28,10 @@ limitations under the License. */ #endif #include "gflags/gflags.h" #include "paddle/fluid/memory/allocation/allocator.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #ifdef PADDLE_WITH_MLU #include "paddle/fluid/platform/device/mlu/mlu_info.h" #endif @@ -206,7 +206,7 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) { // of host pinned allocation. Allocates too much would reduce // the amount of memory available to the underlying system for paging. size_t usable = - paddle::platform::CUDAPinnedMaxAllocSize() - cuda_pinnd_alloc_size_; + phi::backends::cpu::CUDAPinnedMaxAllocSize() - cuda_pinnd_alloc_size_; if (size > usable) { LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0 @@ -362,7 +362,7 @@ void* NPUPinnedAllocator::Alloc(size_t* index, size_t size) { if (size <= 0) return nullptr; size_t usable = - paddle::platform::NPUPinnedMaxAllocSize() - npu_pinnd_alloc_size_; + phi::backends::cpu::NPUPinnedMaxAllocSize() - npu_pinnd_alloc_size_; if (size > usable) { LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0 diff --git a/paddle/fluid/memory/pinned_memory_test.cu b/paddle/fluid/memory/pinned_memory_test.cu index 259222754e8f8..1260065319029 100644 --- a/paddle/fluid/memory/pinned_memory_test.cu +++ b/paddle/fluid/memory/pinned_memory_test.cu @@ -18,9 +18,9 @@ limitations under the License. */ #include "paddle/fluid/memory/allocation/memory_block.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/backends/cpu/cpu_info.h" // This unit test is an example comparing the performance between using pinned // memory and not. In general, using pinned memory will be faster. diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 9dff9a05d73ad..330b13ab8b290 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/cpu_vec.h" #include "paddle/phi/kernels/funcs/fc_functor.h" @@ -315,10 +315,10 @@ use lstm_x_t as input and compute as standard LSTM. template inline void bias_relu(const int n, const T* x, const T* bias, T* y) { if (bias) { - phi::funcs::vec_add_bias(n, *bias, x, y); - phi::funcs::vec_relu(n, y, y); + phi::funcs::vec_add_bias(n, *bias, x, y); + phi::funcs::vec_relu(n, y, y); } else { - phi::funcs::vec_relu(n, x, y); + phi::funcs::vec_relu(n, x, y); } } @@ -329,8 +329,9 @@ inline void vec_softmax(const int n, const T* x, T* y) { for (int i = 1; i < n; ++i) { scalar = scalar < x[i] ? x[i] : scalar; } - phi::funcs::vec_add_bias(n, -scalar, x, y); // sub - phi::funcs::vec_exp(n, y, y); // exp + phi::funcs::vec_add_bias( + n, -scalar, x, y); // sub + phi::funcs::vec_exp(n, y, y); // exp // sum scalar = T(0); for (int i = 0; i < n; ++i) { @@ -393,13 +394,13 @@ class AttentionLSTMKernel : public framework::OpKernel { auto& act_gate_str = ctx.Attr("gate_activation"); auto& act_cell_str = ctx.Attr("cell_activation"); auto& act_cand_str = ctx.Attr("candidate_activation"); - if (platform::MayIUse(platform::avx)) { - phi::funcs::VecActivations act_functor; + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { + phi::funcs::VecActivations act_functor; act_gate = act_functor(act_gate_str); act_cell = act_functor(act_cell_str); act_cand = act_functor(act_cand_str); } else { - phi::funcs::VecActivations act_functor; + phi::funcs::VecActivations act_functor; act_gate = act_functor(act_gate_str); act_cell = act_functor(act_cell_str); act_cand = act_functor(act_cand_str); diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu index 40a0cb196f3bb..f17a9dfe3c232 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h" -#include "paddle/fluid/operators/math/softmax_impl.h" + +#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/fluid/string/string_helper.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/cross_entropy.h" +#include "paddle/phi/kernels/funcs/softmax_impl.h" namespace paddle { namespace operators { @@ -129,15 +131,15 @@ struct CSoftmaxWithCrossEntropyFunctor { softmax_2d.ShareDataWith(*softmax).Resize({N, D}); loss_2d.ShareDataWith(*loss).Resize({N, 1}); - auto eigen_logits = math::EigenMatrix::From(logits_2d); - auto eigen_softmax = math::EigenMatrix::From(softmax_2d); + auto eigen_logits = phi::funcs::EigenMatrix::From(logits_2d); + auto eigen_softmax = phi::funcs::EigenMatrix::From(softmax_2d); // step 1, obtain logit_max phi::DenseTensor logits_max; logits_max = ctx.AllocateTmpTensor({N, 1}, dev_ctx); void* logits_max_buff = logits_max.mutable_data(place); - auto eigen_logits_max = math::EigenMatrix::From(logits_max); + auto eigen_logits_max = phi::funcs::EigenMatrix::From(logits_max); Eigen::DSizes along_axis(1); eigen_logits_max.device(*dev_ctx.eigen_device()) = eigen_logits.maximum(along_axis); @@ -158,7 +160,7 @@ struct CSoftmaxWithCrossEntropyFunctor { eigen_softmax.device(*dev_ctx.eigen_device()) = (eigen_logits - eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class)) - .unaryExpr(math::ValueClip()); + .unaryExpr(phi::funcs::ValueClip()); // step 3, obtain predict target phi::DenseTensor predicted_logits; @@ -217,7 +219,8 @@ struct CSoftmaxWithCrossEntropyFunctor { sum_exp_logits = ctx.AllocateTmpTensor({N, 1}, dev_ctx); void* sum_exp_logits_buff = sum_exp_logits.mutable_data(place); - auto eigen_sum_exp_logits = math::EigenMatrix::From(sum_exp_logits); + auto eigen_sum_exp_logits = + phi::funcs::EigenMatrix::From(sum_exp_logits); eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) = eigen_softmax.sum(along_axis); @@ -231,8 +234,9 @@ struct CSoftmaxWithCrossEntropyFunctor { comm->comm(), stream)); - auto eigen_loss = math::EigenMatrix::From(loss_2d); - auto eigen_predicted_logits = math::EigenMatrix::From(predicted_logits); + auto eigen_loss = phi::funcs::EigenMatrix::From(loss_2d); + auto eigen_predicted_logits = + phi::funcs::EigenMatrix::From(predicted_logits); eigen_loss.device(*dev_ctx.eigen_device()) = (eigen_sum_exp_logits.log().unaryExpr(phi::funcs::TolerableValue()) - @@ -281,14 +285,14 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { softmax_2d.ShareDataWith(*softmax).Resize({N, D}); loss_2d.ShareDataWith(*loss).Resize({N, 1}); - auto eigen_logits = math::EigenMatrix::From(logits_2d); - auto eigen_softmax = math::EigenMatrix::From(softmax_2d); + auto eigen_logits = phi::funcs::EigenMatrix::From(logits_2d); + auto eigen_softmax = phi::funcs::EigenMatrix::From(softmax_2d); // step 1, obtain logit_max phi::DenseTensor logits_max; logits_max = ctx.AllocateTmpTensor({N, 1}, dev_ctx); - auto eigen_logits_max = math::EigenMatrix::From(logits_max); + auto eigen_logits_max = phi::funcs::EigenMatrix::From(logits_max); Eigen::DSizes along_axis(1); eigen_logits_max.device(*dev_ctx.eigen_device()) = eigen_logits.maximum(along_axis); @@ -304,7 +308,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { eigen_softmax.device(*dev_ctx.eigen_device()) = (eigen_logits - eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class)) - .unaryExpr(math::ValueClip()); + .unaryExpr(phi::funcs::ValueClip()); // step 3, obtain predict target phi::DenseTensor predicted_logits; @@ -357,7 +361,8 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { sum_exp_logits = ctx.AllocateTmpTensor({N, 1}, dev_ctx); void* sum_exp_logits_buff = sum_exp_logits.mutable_data(place); - auto eigen_sum_exp_logits = math::EigenMatrix::From(sum_exp_logits); + auto eigen_sum_exp_logits = + phi::funcs::EigenMatrix::From(sum_exp_logits); eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) = eigen_softmax.sum(along_axis); @@ -366,8 +371,9 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { opts.reduce_op = distributed::ReduceOp::SUM; pg->AllReduce(in_out, in_out, opts)->Synchronize(); - auto eigen_loss = math::EigenMatrix::From(loss_2d); - auto eigen_predicted_logits = math::EigenMatrix::From(predicted_logits); + auto eigen_loss = phi::funcs::EigenMatrix::From(loss_2d); + auto eigen_predicted_logits = + phi::funcs::EigenMatrix::From(predicted_logits); eigen_loss.device(*dev_ctx.eigen_device()) = (eigen_sum_exp_logits.log().unaryExpr(phi::funcs::TolerableValue()) - diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h index f3a438e729bb1..d7b20224b1448 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h @@ -22,9 +22,9 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/kernels/funcs/cross_entropy.h" +#include "paddle/phi/kernels/funcs/softmax.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 5aa1b7ed4f1dd..7f233eba88333 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/kernels/elementwise_kernel.h" namespace paddle { diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 885f3412a4e06..11b9044bc5472 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/cpu_vec.h" #include "paddle/phi/kernels/funcs/sequence2batch.h" @@ -278,13 +278,13 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { auto& act_gate_str = ctx.Attr("gate_activation"); \ auto& act_cell_str = ctx.Attr("cell_activation"); \ auto& act_cand_str = ctx.Attr("candidate_activation"); \ - if (platform::MayIUse(platform::avx)) { \ - phi::funcs::VecActivations act_functor; \ + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { \ + phi::funcs::VecActivations act_functor; \ act_gate = act_functor(act_gate_str); \ act_cell = act_functor(act_cell_str); \ act_cand = act_functor(act_cand_str); \ } else { \ - phi::funcs::VecActivations act_functor; \ + phi::funcs::VecActivations act_functor; \ act_gate = act_functor(act_gate_str); \ act_cell = act_functor(act_cell_str); \ act_cand = act_functor(act_cand_str); \ diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index df4cbba1dec15..dd5b3c0073f3c 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/cpu_vec.h" #include "paddle/phi/kernels/funcs/fc_functor.h" @@ -225,11 +225,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { std::function fc_act; auto& fc_act_str = ctx.Attr("fc_activation"); - if (platform::MayIUse(platform::avx)) { - phi::funcs::VecActivations act_functor; + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { + phi::funcs::VecActivations act_functor; fc_act = act_functor(fc_act_str); } else { - phi::funcs::VecActivations act_functor; + phi::funcs::VecActivations act_functor; fc_act = act_functor(fc_act_str); } diff --git a/paddle/fluid/operators/jit/CMakeLists.txt b/paddle/fluid/operators/jit/CMakeLists.txt index a6f10e5fbdab7..8aa7eea6708a6 100644 --- a/paddle/fluid/operators/jit/CMakeLists.txt +++ b/paddle/fluid/operators/jit/CMakeLists.txt @@ -9,7 +9,7 @@ file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n") file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n") -set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place xxhash) +set(JIT_KERNEL_DEPS device_context cblas gflags enforce place xxhash) file( GLOB jit_kernel_cc_srcs diff --git a/paddle/fluid/operators/jit/gen/act.cc b/paddle/fluid/operators/jit/gen/act.cc index 5a73e3c56d511..cd88ec94b3a09 100644 --- a/paddle/fluid/operators/jit/gen/act.cc +++ b/paddle/fluid/operators/jit/gen/act.cc @@ -15,7 +15,7 @@ #include "paddle/fluid/operators/jit/gen/act.h" #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -98,27 +98,27 @@ DECLARE_ACT_CREATOR(VTanh); // TODO(TJ): tuning use me bool VReluCreator::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx); } bool VSquareCreator::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx); } bool VIdentityCreator::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx); } bool VExpCreator::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx) && d < 32; + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && d < 32; } bool VSigmoidCreator::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx); } bool VTanhCreator::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx); } size_t VReluCreator::CodeSize(const int& d) const { diff --git a/paddle/fluid/operators/jit/gen/act.h b/paddle/fluid/operators/jit/gen/act.h index 8eaf75fa5ee29..f1007e92c7258 100644 --- a/paddle/fluid/operators/jit/gen/act.h +++ b/paddle/fluid/operators/jit/gen/act.h @@ -84,8 +84,8 @@ class VActFunc : public JitCode { // compute EXP with ymm, xmm template - void exp_jmm(JMM& dst, - JMM& src, + void exp_jmm(JMM& dst, // NOLINT + JMM& src, // NOLINT int src_idx = 11, int fx_idx = 12, // NOLINT int fy_idx = 13, @@ -144,10 +144,11 @@ class VActFunc : public JitCode { vcvttps2dq(ymm_int, jmm_fx); mov(reg_ptr_global, reinterpret_cast(exp_int_0x7f)); vmovdqa(jmm_tmp, ptr[reg_ptr_global]); - if (MayIUse(avx2) || std::is_same::value) { + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx2) || + std::is_same::value) { vpaddd(ymm_int, ymm_int, jmm_tmp); vpslld(ymm_int, ymm_int, 23); - } else if (MayIUse(avx)) { + } else if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { xmm_t xtmp1 = xmm_t(ymm_int.getIdx()); xmm_t xtmp2 = xmm_t(jmm_tmp.getIdx()); reg64_t reg_ptr_tmp = reg_ptr_global; @@ -174,8 +175,8 @@ class VActFunc : public JitCode { // compute SIGMOID with ymm, xmm template - void sigmoid_jmm(JMM& dst, - JMM& src, + void sigmoid_jmm(JMM& dst, // NOLINT + JMM& src, // NOLINT int src_idx = 11, // NOLINT int fx_idx = 12, int fy_idx = 13, @@ -203,8 +204,8 @@ class VActFunc : public JitCode { // compute TANH with ymm, xmm template - void tanh_jmm(JMM& dst, - JMM& src, + void tanh_jmm(JMM& dst, // NOLINT + JMM& src, // NOLINT int src_idx = 11, // NOLINT int fx_idx = 12, int fy_idx = 13, diff --git a/paddle/fluid/operators/jit/gen/adam.cc b/paddle/fluid/operators/jit/gen/adam.cc index 38ef6772f01ad..17b3b07720927 100644 --- a/paddle/fluid/operators/jit/gen/adam.cc +++ b/paddle/fluid/operators/jit/gen/adam.cc @@ -17,7 +17,7 @@ #include // offsetof #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -132,7 +132,7 @@ void AdamJitCode::genCode() { class AdamCreator : public JitCodeCreator { public: bool CanBeUsed(const adam_attr_t& attr) const override { - return platform::MayIUse(platform::avx512f); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f); } size_t CodeSize(const adam_attr_t& attr) const override { return 96 + 32 * 8; diff --git a/paddle/fluid/operators/jit/gen/adamw.cc b/paddle/fluid/operators/jit/gen/adamw.cc index b470143fb7d8d..1cd6637d8d301 100644 --- a/paddle/fluid/operators/jit/gen/adamw.cc +++ b/paddle/fluid/operators/jit/gen/adamw.cc @@ -17,7 +17,7 @@ #include // offsetof #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -147,7 +147,7 @@ void AdamWJitCode::genCode() { class AdamWCreator : public JitCodeCreator { public: bool CanBeUsed(const int& attr) const override { - return platform::MayIUse(platform::avx512f); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f); } size_t CodeSize(const int& attr) const override { return 96 + 32 * 8; } std::unique_ptr CreateJitCode(const int& attr) const override { diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc index 7c37bb9b05128..b15fe3d5c0463 100644 --- a/paddle/fluid/operators/jit/gen/blas.cc +++ b/paddle/fluid/operators/jit/gen/blas.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/operators/jit/macro.h" #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -145,7 +145,7 @@ void NCHW16CMulNCJitCode::genCode() { class NCHW16CMulNCCreator : public JitCodeCreator { public: bool CanBeUsed(const int& attr) const override { - return platform::MayIUse(platform::avx512f); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f); } size_t CodeSize(const int& d) const override { return 256 * 1024; } std::unique_ptr CreateJitCode(const int& attr) const override { @@ -157,7 +157,8 @@ class NCHW16CMulNCCreator : public JitCodeCreator { class name##Creator : public JitCodeCreator { \ public: \ bool CanBeUsed(const int& attr) const override { \ - return platform::MayIUse(platform::avx) && attr <= 1024; \ + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && \ + attr <= 1024; \ } \ size_t CodeSize(const int& d) const override { \ return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \ diff --git a/paddle/fluid/operators/jit/gen/embseqpool.cc b/paddle/fluid/operators/jit/gen/embseqpool.cc index 2a5617a078196..0fcfe6addda4e 100644 --- a/paddle/fluid/operators/jit/gen/embseqpool.cc +++ b/paddle/fluid/operators/jit/gen/embseqpool.cc @@ -18,7 +18,7 @@ #include "paddle/fluid/operators/jit/macro.h" #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -123,7 +123,7 @@ void EmbSeqPoolJitCode::genCode() { class EmbSeqPoolCreator : public JitCodeCreator { public: bool CanBeUsed(const emb_seq_pool_attr_t& attr) const override { - return platform::MayIUse(platform::avx) && + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && attr.table_width % YMM_FLOAT_BLOCK == 0; } size_t CodeSize(const emb_seq_pool_attr_t& attr) const override { diff --git a/paddle/fluid/operators/jit/gen/gru.cc b/paddle/fluid/operators/jit/gen/gru.cc index f21ad5aa9144f..ae860ddf06ab5 100644 --- a/paddle/fluid/operators/jit/gen/gru.cc +++ b/paddle/fluid/operators/jit/gen/gru.cc @@ -18,7 +18,7 @@ #include "paddle/fluid/operators/jit/macro.h" #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -85,20 +85,21 @@ void GRUJitCode::genCode() { ret(); } -#define DECLARE_GRU_CREATOR(name) \ - class name##Creator : public JitCodeCreator { \ - public: \ - /* TODO(TJ): enable more */ \ - bool CanBeUsed(const gru_attr_t& attr) const override { \ - return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ - } \ - size_t CodeSize(const gru_attr_t& attr) const override { \ - return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \ - } \ - std::unique_ptr CreateJitCode( \ - const gru_attr_t& attr) const override { \ - return make_unique(attr, CodeSize(attr)); \ - } \ +#define DECLARE_GRU_CREATOR(name) \ + class name##Creator : public JitCodeCreator { \ + public: \ + /* TODO(TJ): enable more */ \ + bool CanBeUsed(const gru_attr_t& attr) const override { \ + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && \ + attr.d % 8 == 0; \ + } \ + size_t CodeSize(const gru_attr_t& attr) const override { \ + return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \ + } \ + std::unique_ptr CreateJitCode( \ + const gru_attr_t& attr) const override { \ + return make_unique(attr, CodeSize(attr)); \ + } \ } DECLARE_GRU_CREATOR(GRUH1); diff --git a/paddle/fluid/operators/jit/gen/hopv.cc b/paddle/fluid/operators/jit/gen/hopv.cc index 7449a20a87707..bf7d7189b4b1c 100644 --- a/paddle/fluid/operators/jit/gen/hopv.cc +++ b/paddle/fluid/operators/jit/gen/hopv.cc @@ -15,7 +15,7 @@ #include "paddle/fluid/operators/jit/gen/hopv.h" #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -78,7 +78,7 @@ void HOPVJitCode::genCode() { class name##Creator : public JitCodeCreator { \ public: \ bool CanBeUsed(const int& attr) const override { \ - return platform::MayIUse(platform::avx); \ + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx); \ } \ size_t CodeSize(const int& d) const override { \ return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \ diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h index d71497275daa4..ffbf4bac026bc 100644 --- a/paddle/fluid/operators/jit/gen/jitcode.h +++ b/paddle/fluid/operators/jit/gen/jitcode.h @@ -18,7 +18,7 @@ #include #include "paddle/fluid/operators/jit/gen_base.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #define XBYAK_USE_MMAP_ALLOCATOR #include "xbyak/xbyak.h" @@ -92,7 +92,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator { for (int i = 0; i < num_g_abi_regs; ++i) { push(Xbyak::Reg64(g_abi_regs[i])); } - if (platform::MayIUse(platform::avx512f)) { + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f)) { mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); } } diff --git a/paddle/fluid/operators/jit/gen/lstm.cc b/paddle/fluid/operators/jit/gen/lstm.cc index 7417a205faff5..ff1b39e551e1b 100644 --- a/paddle/fluid/operators/jit/gen/lstm.cc +++ b/paddle/fluid/operators/jit/gen/lstm.cc @@ -18,7 +18,7 @@ #include "paddle/fluid/operators/jit/macro.h" #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -113,20 +113,21 @@ void LSTMJitCode::genCode() { } } -#define DECLARE_LSTM_CREATOR(name) \ - class name##Creator : public JitCodeCreator { \ - public: \ - /* TODO(TJ): enable more */ \ - bool CanBeUsed(const lstm_attr_t& attr) const override { \ - return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ - } \ - size_t CodeSize(const lstm_attr_t& attr) const override { \ - return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; \ - } \ - std::unique_ptr CreateJitCode( \ - const lstm_attr_t& attr) const override { \ - return make_unique(attr, CodeSize(attr)); \ - } \ +#define DECLARE_LSTM_CREATOR(name) \ + class name##Creator : public JitCodeCreator { \ + public: \ + /* TODO(TJ): enable more */ \ + bool CanBeUsed(const lstm_attr_t& attr) const override { \ + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && \ + attr.d % 8 == 0; \ + } \ + size_t CodeSize(const lstm_attr_t& attr) const override { \ + return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; \ + } \ + std::unique_ptr CreateJitCode( \ + const lstm_attr_t& attr) const override { \ + return make_unique(attr, CodeSize(attr)); \ + } \ } DECLARE_LSTM_CREATOR(LSTMCtHt); diff --git a/paddle/fluid/operators/jit/gen/matmul.cc b/paddle/fluid/operators/jit/gen/matmul.cc index b039fcead24e5..14601a78344f0 100644 --- a/paddle/fluid/operators/jit/gen/matmul.cc +++ b/paddle/fluid/operators/jit/gen/matmul.cc @@ -17,7 +17,7 @@ #include // offsetof #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -110,12 +110,13 @@ void MatMulJitCode::genCode() { class MatMulCreator : public JitCodeCreator { public: bool CanBeUsed(const matmul_attr_t& attr) const override { - return attr.m == 1 && platform::MayIUse(platform::avx512f) && + return attr.m == 1 && + phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f) && attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512; } size_t CodeSize(const matmul_attr_t& attr) const override { int block = YMM_FLOAT_BLOCK; - if (platform::MayIUse(platform::avx512f)) { + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f)) { block = ZMM_FLOAT_BLOCK; } return 96 + 4 * attr.k * (attr.n / block + 1) * 8; diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc index 836472db599dc..bf1f5d465fdfd 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.cc +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -59,7 +59,7 @@ void SeqPoolJitCode::genCode() { class SeqPoolCreator : public JitCodeCreator { public: bool CanBeUsed(const seq_pool_attr_t& attr) const override { - return platform::MayIUse(platform::avx); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx); } size_t CodeSize(const seq_pool_attr_t& attr) const override { return 96 + ((attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */) * diff --git a/paddle/fluid/operators/jit/gen/sgd.cc b/paddle/fluid/operators/jit/gen/sgd.cc index 82e4051246105..0a72c9d4fda92 100644 --- a/paddle/fluid/operators/jit/gen/sgd.cc +++ b/paddle/fluid/operators/jit/gen/sgd.cc @@ -17,7 +17,7 @@ #include // offsetof #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -109,7 +109,7 @@ void SgdJitCode::genCode() { class SgdCreator : public JitCodeCreator { public: bool CanBeUsed(const sgd_attr_t& attr) const override { - return platform::MayIUse(platform::avx) && + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && attr.grad_width % YMM_FLOAT_BLOCK == 0; } size_t CodeSize(const sgd_attr_t& attr) const override { return 96 + 32 * 8; } diff --git a/paddle/fluid/operators/jit/gen/vbroadcast.cc b/paddle/fluid/operators/jit/gen/vbroadcast.cc index 75728562f8bd5..1e1b7225f9851 100644 --- a/paddle/fluid/operators/jit/gen/vbroadcast.cc +++ b/paddle/fluid/operators/jit/gen/vbroadcast.cc @@ -15,7 +15,7 @@ #include "paddle/fluid/operators/jit/gen/vbroadcast.h" #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -69,7 +69,8 @@ void VBroadcastJitCode::genCode() { class VBroadcastCreator : public JitCodeCreator { public: bool CanBeUsed(const int64_t& w) const override { - return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0; + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && + w % YMM_FLOAT_BLOCK == 0; } size_t CodeSize(const int64_t& w) const override { return 96 + (w / YMM_FLOAT_BLOCK) * 16 * 8; diff --git a/paddle/fluid/operators/jit/gen_base.cc b/paddle/fluid/operators/jit/gen_base.cc index 031e07b13e2bb..83b42c62d6637 100644 --- a/paddle/fluid/operators/jit/gen_base.cc +++ b/paddle/fluid/operators/jit/gen_base.cc @@ -17,8 +17,8 @@ #include #include "paddle/fluid/memory/allocation/cpu_allocator.h" // for posix_memalign -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #ifndef _WIN32 #define posix_memalign_free free @@ -66,7 +66,7 @@ void GenBase::operator delete(void* ptr) { posix_memalign_free(ptr); } std::vector packed_groups(int n, int k, int* block_out, int* rest_out) { int block; int max_num_regs; - if (platform::MayIUse(platform::avx512f)) { + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f)) { block = ZMM_FLOAT_BLOCK; max_num_regs = 32; } else { diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc index 6c830be5eb6ec..b3ad2f9b06f5b 100644 --- a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc +++ b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -172,7 +172,7 @@ bool CRFDecodingKernel::CanBeUsed(const int& d) const { #else constexpr int block = YMM_FLOAT_BLOCK; #endif - return platform::MayIUse(platform::avx) && d >= block; + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && d >= block; } } // namespace intrinsic diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc index 49189c9545587..5393398ea87d4 100644 --- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc +++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -179,7 +179,8 @@ void LayerNorm(float* x, } bool LayerNormKernel::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK; + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && + d >= YMM_FLOAT_BLOCK; } } // namespace intrinsic diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 9960be46db67b..8e7ca028c501f 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -16,8 +16,8 @@ #include "paddle/fluid/operators/jit/refer/refer.h" #include "paddle/fluid/operators/jit/registry.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/dynload/mklml.h" +#include "paddle/phi/backends/cpu/cpu_info.h" namespace paddle { namespace operators { @@ -188,17 +188,17 @@ void StrideASum(const double* x, double* res, int n, int stride) { // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 template <> bool VMulKernel::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx512f) && d > 512; + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f) && d > 512; } template <> bool VAddKernel::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx) && d > 512; + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && d > 512; } template <> bool VScalKernel::CanBeUsed(const int& d) const { - return platform::MayIUse(platform::avx512f) && d > 512; + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f) && d > 512; } template <> @@ -274,7 +274,7 @@ bool SgdKernel::CanBeUsed(const sgd_attr_t& attr) const { template <> bool MatMulKernel::CanBeUsed(const matmul_attr_t& attr) const { - return platform::MayIUse(platform::avx); + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx); } template <> @@ -285,7 +285,7 @@ bool MatMulKernel::CanBeUsed(const matmul_attr_t& attr) const { template <> bool SoftmaxKernel::CanBeUsed(const int& d) const { // tuned on avx2 - return platform::MayIUse(platform::avx) && d < 60; + return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && d < 60; } #define AWALYS_USE_ME_WITH_DOUBLE(func) \ diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 6149f6ad9f7f4..f5e381ea14c71 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -19,8 +19,8 @@ limitations under the License. */ #include "glog/logging.h" #include "gtest/gtest.h" #include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/backends/cpu/cpu_info.h" DEFINE_double(acc, 1e-5, "Test accuracy threshold."); @@ -437,7 +437,7 @@ void TestKernelNCHW16CMulNC() { EXPECT_TRUE(tgt != nullptr); if (std::is_same::value && - paddle::platform::MayIUse(paddle::platform::avx512f)) { + phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f)) { EXPECT_TRUE(jitcode != nullptr); } for (int ni = 0; ni < n; ni++) { @@ -1393,7 +1393,7 @@ TEST(JITKernel_helper, pack_weights) { } int block = 0; std::vector groups; - if (paddle::platform::MayIUse(paddle::platform::avx512f)) { + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f)) { block = ZMM_FLOAT_BLOCK; groups.push_back(30); } else { diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 3b06722ddfbe0..3d5c7bfb4e7b7 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -32,7 +32,6 @@ math_library(maxouting) math_library(sequence_padding) math_library(sequence_pooling DEPS math_function jit_kernel_helper) math_library(sequence_scale) -math_library(softmax DEPS math_function jit_kernel_helper) if(WITH_ASCEND_CL) math_library(beam_search DEPS math_function beam_search_npu) elseif(WITH_XPU) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h deleted file mode 100644 index 40ed412f7c507..0000000000000 --- a/paddle/fluid/operators/math/cpu_vec.h +++ /dev/null @@ -1,664 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include - -#include "paddle/fluid/platform/cpu_info.h" -#include "paddle/fluid/platform/enforce.h" - -#ifdef PADDLE_WITH_MKLML -#include "paddle/fluid/platform/dynload/mklml.h" -#endif - -namespace paddle { -namespace operators { -namespace math { - -#define SIGMOID_THRESHOLD_MIN -40.0 -#define SIGMOID_THRESHOLD_MAX 13.0 - -#define YMM_FLOAT_BLOCK 8 -#define AVX_DOUBLE_BLOCK 4 -#define YMM_FLOAT_BLOCK 8 -#define AVX2_DOUBLE_BLOCK 4 -#define ZMM_FLOAT_BLOCK 16 -#define AVX512_DOUBLE_BLOCK 8 - -template -inline void vec_exp(const int n, const T* x, T* y) { - for (int i = 0; i < n; ++i) { - y[i] = std::exp(x[i]); - } -} - -template -inline void vec_scal(const int n, const T a, T* x) { - for (int i = 0; i < n; ++i) { - x[i] = a * x[i]; - } -} - -#ifdef PADDLE_WITH_MKLML -template <> -inline void vec_exp(const int n, const float* x, float* y) { - constexpr int small_enough = 128; - if (n < small_enough) { - for (int i = 0; i < n; ++i) { - y[i] = std::exp(x[i]); - } - } else { - platform::dynload::vsExp(n, x, y); - } -} - -template <> -inline void vec_exp(const int n, const double* x, double* y) { - platform::dynload::vdExp(n, x, y); -} - -template <> -inline void vec_scal(const int n, const float a, float* x) { - platform::dynload::cblas_sscal(n, a, x, 1); -} - -template <> -inline void vec_scal(const int n, const double a, double* x) { - platform::dynload::cblas_dscal(n, a, x, 1); -} -#endif - -// MKL scal only support inplace, choose this if src and dst are not equal -template -inline void vec_scal(const int n, const T a, const T* x, T* y) { - for (int i = 0; i < n; ++i) { - y[i] = a * x[i]; - } -} - -template <> -inline void vec_scal(const int n, - const float a, - const float* x, - float* y) { -#ifdef __AVX__ - constexpr int block = YMM_FLOAT_BLOCK; - if (n < block) { - vec_scal(n, a, x, y); - return; - } - const int rest = n % block; - const int end = n - rest; - int i = 0; - __m256 scalar = _mm256_set1_ps(a); - __m256 tmp; -#define MOVE_ONE_STEP \ - tmp = _mm256_loadu_ps(x + i); \ - tmp = _mm256_mul_ps(tmp, scalar); \ - _mm256_storeu_ps(y + i, tmp) - for (i = 0; i < end; i += block) { - MOVE_ONE_STEP; - } -#undef MOVE_ONE_STEP - if (rest == 0) { - return; - } - // can not continue move step if src and dst are inplace - for (i = n - rest; i < n; ++i) { - y[i] = a * x[i]; - } -#else - vec_scal(n, a, x, y); -#endif -} - -template <> -inline void vec_scal(const int n, - const float a, - const float* x, - float* y) { - vec_scal(n, a, x, y); -} - -template <> -inline void vec_scal(const int n, - const float a, - const float* x, - float* y) { - // TODO(TJ): enable me - vec_scal(n, a, x, y); -} - -template -inline void vec_sum(const size_t n, const T* x, T* s) { - s[0] = x[0]; - for (size_t i = 1; i < n; ++i) { - s[0] += x[i]; - } -} - -template <> -inline void vec_sum(const size_t n, - const float* x, - float* s) { -#ifdef __AVX__ - constexpr unsigned int block = YMM_FLOAT_BLOCK; - if (n < block) { - vec_sum(n, x, s); - return; - } - - unsigned int i, end; - i = end = 0; - s[0] = 0.f; - - end = n & ~(block - 1); - __m256 tmp = _mm256_setzero_ps(); - for (i = 0; i < end; i += block) { - tmp = _mm256_add_ps(tmp, _mm256_loadu_ps(x + i)); - } - - __m256 hsum = _mm256_hadd_ps(tmp, tmp); - hsum = _mm256_add_ps(hsum, _mm256_permute2f128_ps(hsum, hsum, 0x1)); - _mm_store_ss( - s, - _mm_hadd_ps(_mm256_castps256_ps128(hsum), _mm256_castps256_ps128(hsum))); - - for (; i < n; i++) { - s[0] += x[i]; - } -#else - vec_sum(n, x, s); -#endif -} - -template -inline void vec_mul(const size_t n, const T* x, const T* y, T* z) { - for (size_t i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } -} - -template <> -inline void vec_mul(const size_t n, - const float* x, - const float* y, - float* z) { -#ifdef __AVX__ - constexpr unsigned int block = YMM_FLOAT_BLOCK; - if (n < block) { - vec_mul(n, x, y, z); - return; - } - - unsigned int i = 0, end = 0; - end = n & ~(block - 1); - for (i = 0; i < end; i += block) { - _mm256_storeu_ps( - z + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), _mm256_loadu_ps(y + i))); - } - - for (; i < n; i++) { - z[i] = x[i] * y[i]; - } -#else - vec_mul(n, x, y, z); -#endif -} - -template -inline void vec_mul_reduce(const size_t n, const T* x, const T* y, T* z) { - z[0] = x[0] * y[0]; - for (size_t i = 1; i < n; ++i) { - z[0] += x[i] * y[i]; - } -} - -template <> -inline void vec_mul_reduce(const size_t n, - const float* x, - const float* y, - float* z) { -#ifdef __AVX__ - constexpr unsigned int block = YMM_FLOAT_BLOCK; - if (n < block) { - vec_mul_reduce(n, x, y, z); - return; - } - - unsigned int i = 0, end = 0; - z[0] = 0.f; - - end = n & ~(block - 1); - __m256 tmp = _mm256_setzero_ps(); - for (i = 0; i < end; i += block) { - tmp = _mm256_add_ps( - tmp, _mm256_mul_ps(_mm256_loadu_ps(x + i), _mm256_loadu_ps(y + i))); - } - - __m256 hsum = _mm256_hadd_ps(tmp, tmp); - hsum = _mm256_add_ps(hsum, _mm256_permute2f128_ps(hsum, hsum, 0x1)); - _mm_store_ss( - z, - _mm_hadd_ps(_mm256_castps256_ps128(hsum), _mm256_castps256_ps128(hsum))); - - for (; i < n; i++) { - z[0] += x[i] * y[i]; - } -#else - vec_mul_reduce(n, x, y, z); -#endif -} - -template -inline void vec_bias_sub(const int n, const T a, const T* x, T* y) { - for (int i = 0; i < n; ++i) { - y[i] = a - x[i]; - } -} - -template <> -inline void vec_bias_sub(const int n, - const float a, - const float* x, - float* y) { -#ifdef __AVX__ - constexpr int block = YMM_FLOAT_BLOCK; - if (n < block) { - vec_bias_sub(n, a, x, y); - return; - } - const int rest = n % block; - const int end = n - rest; - int i = 0; - __m256 bias = _mm256_set1_ps(a); - __m256 tmp; -#define MOVE_ONE_STEP \ - tmp = _mm256_loadu_ps(x + i); \ - tmp = _mm256_sub_ps(bias, tmp); \ - _mm256_storeu_ps(y + i, tmp) - for (i = 0; i < end; i += block) { - MOVE_ONE_STEP; - } -#undef MOVE_ONE_STEP - if (rest == 0) { - return; - } - // can not continue move step if src and dst are inplace - for (i = n - rest; i < n; ++i) { - y[i] = a - x[i]; - } -#else - vec_bias_sub(n, a, x, y); -#endif -} - -template <> -inline void vec_bias_sub(const int n, - const float a, - const float* x, - float* y) { - vec_bias_sub(n, a, x, y); -} - -template <> -inline void vec_bias_sub(const int n, - const float a, - const float* x, - float* y) { - // TODO(TJ): enable me - vec_bias_sub(n, a, x, y); -} - -// out = x*y + (1-x)*z -template -inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) { - for (int i = 0; i < n; ++i) { - out[i] = x[i] * y[i] + (static_cast(1) - x[i]) * z[i]; - } -} - -template <> -inline void vec_cross( - const int n, const float* x, const float* y, const float* z, float* out) { -#ifdef __AVX__ - constexpr int block = YMM_FLOAT_BLOCK; - if (n < block) { - vec_cross(n, x, y, z, out); - return; - } - const int rest = n % block; - const int end = n - rest; - int i = 0; - __m256 bias = _mm256_set1_ps(1.f); - __m256 tmpx, tmpy, tmpz; - for (i = 0; i < end; i += block) { - tmpx = _mm256_loadu_ps(x + i); - tmpy = _mm256_loadu_ps(y + i); - tmpz = _mm256_loadu_ps(z + i); - tmpy = _mm256_mul_ps(tmpx, tmpy); - tmpx = _mm256_sub_ps(bias, tmpx); - tmpz = _mm256_mul_ps(tmpx, tmpz); - tmpz = _mm256_add_ps(tmpy, tmpz); - _mm256_storeu_ps(out + i, tmpz); - } - if (rest == 0) { - return; - } - // can not continue move step if src and dst are inplace - for (i = n - rest; i < n; ++i) { - out[i] = x[i] * y[i] + (1.f - x[i]) * z[i]; - } -#else - vec_cross(n, x, y, z, out); -#endif -} - -template <> -inline void vec_cross( - const int n, const float* x, const float* y, const float* z, float* out) { - vec_cross(n, x, y, z, out); -} - -template <> -inline void vec_cross( - const int n, const float* x, const float* y, const float* z, float* out) { - // TODO(TJ): enable me - vec_cross(n, x, y, z, out); -} - -template -inline void vec_clip(const size_t n, const T a, const T* x, T* y) { - for (size_t i = 0; i < n; ++i) { - y[i] = x[i] < a ? a : x[i]; - } -} - -template <> -inline void vec_clip(const size_t n, - const float a, - const float* x, - float* y) { -#ifdef __AVX__ - constexpr unsigned int block = YMM_FLOAT_BLOCK; - if (n < block) { - vec_clip(n, a, x, y); - return; - } - - unsigned int i = 0, end = 0; - end = n & ~(block - 1); - __m256 threshold = _mm256_set1_ps(a); - - for (i = 0; i < end; i += block) { - _mm256_storeu_ps(y + i, _mm256_max_ps(_mm256_loadu_ps(x + i), threshold)); - } - - for (; i < n; i++) { - y[i] = x[i] < a ? a : x[i]; - } -#else - vec_clip(n, a, x, y); -#endif -} - -template -inline void vec_add_bias(const int n, const T a, const T* x, T* y) { - for (int i = 0; i < n; ++i) { - y[i] = x[i] + a; - } -} - -template <> -inline void vec_add_bias(const int n, - const float a, - const float* x, - float* y) { -#ifdef __AVX__ - constexpr int block = YMM_FLOAT_BLOCK; - if (n < block) { - vec_add_bias(n, a, x, y); - return; - } - const int rest = n % block; - const int end = n - rest; - int i = 0; - __m256 bias = _mm256_set1_ps(a); - __m256 tmp; -#define MOVE_ONE_STEP \ - tmp = _mm256_loadu_ps(x + i); \ - tmp = _mm256_add_ps(tmp, bias); \ - _mm256_storeu_ps(y + i, tmp) - for (i = 0; i < end; i += block) { - MOVE_ONE_STEP; - } -#undef MOVE_ONE_STEP - if (rest == 0) { - return; - } - // can not continue move step if src and dst are inplace - for (i = n - rest; i < n; ++i) { - y[i] = x[i] + a; - } -#else - vec_add_bias(n, a, x, y); -#endif -} - -template <> -inline void vec_add_bias(const int n, - const float a, - const float* x, - float* y) { - vec_add_bias(n, a, x, y); -} - -template <> -inline void vec_add_bias(const int n, - const float a, - const float* x, - float* y) { - // TODO(TJ): enable me - vec_add_bias(n, a, x, y); -} - -template -inline void vec_identity(const int n, const T* x, T* y) { - // do nothing - return; -} - -template -inline void vec_sigmoid(const int n, const T* x, T* y) { - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - for (int i = 0; i < n; ++i) { - y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); - y[i] = static_cast(0) - y[i]; - } - vec_exp(n, y, y); - for (int i = 0; i < n; ++i) { - y[i] = static_cast(1) / (static_cast(1) + y[i]); - } -} - -template <> -inline void vec_sigmoid(const int n, - const float* x, - float* y) { -#ifdef __AVX__ - constexpr int block = YMM_FLOAT_BLOCK; - if (n < block) { - vec_sigmoid(n, x, y); - return; - } - const int rest = n % block; - const int end = n - rest; - int i = 0; - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); - __m256 zeros = _mm256_setzero_ps(); - __m256 tmp; -#define MOVE_ONE_STEP \ - tmp = _mm256_loadu_ps(x + i); \ - tmp = _mm256_max_ps(tmp, min); \ - tmp = _mm256_min_ps(tmp, max); \ - tmp = _mm256_sub_ps(zeros, tmp); \ - _mm256_storeu_ps(y + i, tmp) - for (i = 0; i < end; i += block) { - MOVE_ONE_STEP; - } -#undef MOVE_ONE_STEP - if (rest != 0) { - // can not continue move step since the src and dst address could be equal - const float xmin = SIGMOID_THRESHOLD_MIN; - const float xmax = SIGMOID_THRESHOLD_MAX; - for (i = n - rest; i < n; ++i) { - y[i] = 0.f - ((x[i] < xmin) ? xmin : ((x[i] > xmax) ? xmax : x[i])); - } - } - - vec_exp(n, y, y); - - __m256 ones = _mm256_set1_ps(1.0f); -#define MOVE_ONE_STEP \ - tmp = _mm256_loadu_ps(y + i); \ - tmp = _mm256_add_ps(ones, tmp); \ - tmp = _mm256_div_ps(ones, tmp); \ - _mm256_storeu_ps(y + i, tmp) - for (i = 0; i < end; i += block) { - MOVE_ONE_STEP; - } -#undef MOVE_ONE_STEP - if (rest == 0) { - return; - } - // can not continue move step - for (i = n - rest; i < n; ++i) { - y[i] = 1.f / (1.f + y[i]); - } -#else - vec_sigmoid(n, x, y); -#endif -} - -template <> -inline void vec_sigmoid(const int n, - const float* x, - float* y) { - vec_sigmoid(n, x, y); -} - -template <> -inline void vec_sigmoid(const int n, - const float* x, - float* y) { - // TODO(TJ): enable me - vec_sigmoid(n, x, y); -} - -template -inline void vec_tanh(const int n, const T* x, T* y) { - vec_scal(n, static_cast(2), x, y); - vec_sigmoid(n, y, y); - vec_scal(n, static_cast(2), y); - vec_add_bias(n, static_cast(-1), y, y); -} - -// TODO(TJ): make relu clip -template -inline void vec_relu(const int n, const T* x, T* y) { - for (int i = 0; i < n; ++i) { - y[i] = x[i] > 0 ? x[i] : 0; - } -} - -template <> -inline void vec_relu(const int n, - const float* x, - float* y) { -#ifdef __AVX__ - constexpr int block = YMM_FLOAT_BLOCK; - if (n < block * 4) { - vec_relu(n, x, y); - return; - } - - const int rest = n % block; - const int end = n - rest; - int i = 0; - __m256 zeros = _mm256_setzero_ps(); - __m256 tmp; -#define MOVE_ONE_STEP \ - tmp = _mm256_loadu_ps(x + i); \ - tmp = _mm256_max_ps(tmp, zeros); \ - _mm256_storeu_ps(y + i, tmp) - for (i = 0; i < end; i += block) { - MOVE_ONE_STEP; - } - if (rest == 0) { - return; - } - i = n - block; - MOVE_ONE_STEP; -#undef MOVE_ONE_STEP - -#else - vec_relu(n, x, y); -#endif -} - -template <> -inline void vec_relu(const int n, - const float* x, - float* y) { - vec_relu(n, x, y); -} - -template <> -inline void vec_relu(const int n, - const float* x, - float* y) { - // TODO(TJ): enable me - vec_relu(n, x, y); -} - -// TODO(TJ): optimize double of sigmoid, tanh and relu if necessary - -template -class VecActivations { - public: - std::function operator()( - const std::string& type) { - if (type == "sigmoid") { - return vec_sigmoid; - } else if (type == "relu") { - return vec_relu; - } else if (type == "tanh") { - return vec_tanh; - } else if (type == "identity" || type == "") { - return vec_identity; - } - PADDLE_THROW(platform::errors::InvalidArgument( - "Expected type should be one of sigmod, relu, tanh, identity. But got " - "not support type: %s.", - type)); - } -}; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/sample_logits_op.cu b/paddle/fluid/operators/sample_logits_op.cu index 7d61088dd9fd6..a24cb99b6ea33 100644 --- a/paddle/fluid/operators/sample_logits_op.cu +++ b/paddle/fluid/operators/sample_logits_op.cu @@ -21,9 +21,9 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/math/sample_prob.h" -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/operators/sample_logits_op.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sample_logits_op.h b/paddle/fluid/operators/sample_logits_op.h index fe53a12e5ed71..a8413a4988d83 100644 --- a/paddle/fluid/operators/sample_logits_op.h +++ b/paddle/fluid/operators/sample_logits_op.h @@ -21,8 +21,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/math/sample_prob.h" -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc index 54d5b5c45b590..4cd521fdb81bf 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" namespace paddle { namespace operators { @@ -61,7 +61,7 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel { phi::make_ddim({1UL, end_pos - start_pos}); x_i.Resize(dims_i); out_i.Resize(dims_i); - math::SoftmaxCUDNNFunctor()( + phi::funcs::SoftmaxCUDNNFunctor()( ctx.template device_context(), &x_i, &out_i); } } @@ -95,7 +95,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel { out_i.Resize(dims_i); out_grad_i.Resize(dims_i); x_grad_i.Resize(dims_i); - math::SoftmaxGradCUDNNFunctor()( + phi::funcs::SoftmaxGradCUDNNFunctor()( ctx.template device_context(), &out_i, &out_grad_i, diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc index 6a51198e75460..fb35f80ec2ad1 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc @@ -16,10 +16,10 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/cross_entropy.h" +#include "paddle/phi/kernels/funcs/softmax.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 1dc762b9e1854..b25b56d5387e3 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -56,18 +56,10 @@ cc_test( SRCS enforce_test.cc DEPS enforce) -set(CPU_INFO_DEPS gflags glog enforce) -if(WITH_XBYAK) - list(APPEND CPU_INFO_DEPS xbyak) -endif() -cc_library( - cpu_info - SRCS cpu_info.cc - DEPS ${CPU_INFO_DEPS}) cc_test( cpu_info_test SRCS cpu_info_test.cc - DEPS cpu_info) + DEPS phi_backends) cc_library( os_info SRCS os_info.cc @@ -194,7 +186,6 @@ cc_library( phi_place eigen3 cpu_helper - cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h deleted file mode 100644 index b1220e615da00..0000000000000 --- a/paddle/fluid/platform/cpu_info.h +++ /dev/null @@ -1,92 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#ifdef _WIN32 -#if defined(__AVX2__) -#include // avx2 -#elif defined(__AVX__) -#include // avx -#endif // AVX -#else // WIN32 -#ifdef __AVX__ -#include -#endif -#endif // WIN32 - -#if defined(_WIN32) -#define ALIGN32_BEG __declspec(align(32)) -#define ALIGN32_END -#else -#define ALIGN32_BEG -#define ALIGN32_END __attribute__((aligned(32))) -#endif // _WIN32 - -#ifndef PADDLE_WITH_XBYAK -#ifdef _WIN32 -#define cpuid(reg, x) __cpuidex(reg, x, 0) -#else -#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_ARM) && \ - !defined(PADDLE_WITH_SW) && !defined(PADDLE_WITH_MIPS) -#include -inline void cpuid(int reg[4], int x) { - __cpuid_count(x, 0, reg[0], reg[1], reg[2], reg[3]); -} -#endif -#endif -#endif - -#include "paddle/phi/backends/cpu/cpu_info.h" - -namespace paddle { -namespace platform { - -size_t CpuTotalPhysicalMemory(); - -//! Get the maximum allocation size for a machine. -size_t CpuMaxAllocSize(); - -//! Get the maximum allocation size for a machine. -size_t CUDAPinnedMaxAllocSize(); - -using phi::backends::cpu::CpuMinChunkSize; - -//! Get the maximum chunk size for buddy allocator. -size_t CpuMaxChunkSize(); - -//! Get the minimum chunk size for buddy allocator. -size_t CUDAPinnedMinChunkSize(); - -//! Get the maximum chunk size for buddy allocator. -size_t CUDAPinnedMaxChunkSize(); - -//! Get the maximum allocation size for a machine. -size_t NPUPinnedMaxAllocSize(); - -//! Get the minimum chunk size for buddy allocator. -size_t NPUPinnedMinChunkSize(); - -//! Get the maximum chunk size for buddy allocator. -size_t NPUPinnedMaxChunkSize(); - -using namespace phi::backends::cpu; // NOLINT - -// May I use some instruction -bool MayIUse(const cpu_isa_t cpu_isa); - -} // namespace platform -} // namespace paddle diff --git a/paddle/fluid/platform/cpu_info_test.cc b/paddle/fluid/platform/cpu_info_test.cc index d3d0506e8bb2a..e9e45c0292baf 100644 --- a/paddle/fluid/platform/cpu_info_test.cc +++ b/paddle/fluid/platform/cpu_info_test.cc @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include @@ -23,7 +23,8 @@ DECLARE_double(fraction_of_cpu_memory_to_use); TEST(CpuMemoryUsage, Print) { std::stringstream ss; - size_t memory_size = paddle::platform::CpuMaxAllocSize() / 1024 / 1024 / 1024; + size_t memory_size = + phi::backends::cpu::CpuMaxAllocSize() / 1024 / 1024 / 1024; float use_percent = FLAGS_fraction_of_cpu_memory_to_use * 100; std::cout << paddle::string::Sprintf("\n%.2f %% of CPU Memory Usage: %d GB\n", diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 9045b4b54cc51..9d6b1cf91b109 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -16,9 +16,9 @@ limitations under the License. */ #include #include "paddle/fluid/platform/cpu_helper.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h" #include "paddle/fluid/string/split.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #endif diff --git a/paddle/fluid/platform/profiler/CMakeLists.txt b/paddle/fluid/platform/profiler/CMakeLists.txt index 3fbf09d8902a6..c72fa593ce4a2 100644 --- a/paddle/fluid/platform/profiler/CMakeLists.txt +++ b/paddle/fluid/platform/profiler/CMakeLists.txt @@ -29,7 +29,7 @@ cc_library( cc_library( cpu_utilization SRCS cpu_utilization.cc - DEPS cpu_info os_info enforce glog) + DEPS phi_backends os_info enforce glog) cc_library( new_profiler SRCS profiler.cc diff --git a/paddle/fluid/pybind/parallel_executor.cc b/paddle/fluid/pybind/parallel_executor.cc index e0f6ae9b0b7e4..d0aea4e76df52 100644 --- a/paddle/fluid/pybind/parallel_executor.cc +++ b/paddle/fluid/pybind/parallel_executor.cc @@ -72,7 +72,6 @@ limitations under the License. */ #include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/platform/cpu_helper.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h" @@ -89,6 +88,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/eager.h" #include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/io.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/lod_utils.h" #include "paddle/utils/none.h" diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index 309fad4273dc2..9201908e50c00 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -72,7 +72,6 @@ limitations under the License. */ #include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/platform/cpu_helper.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h" @@ -89,6 +88,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/eager.h" #include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/io.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/lod_utils.h" #include "paddle/utils/none.h" diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 4bbf5a33bd137..f699b92e5045b 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -75,7 +75,6 @@ limitations under the License. */ #include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/platform/cpu_helper.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h" @@ -94,6 +93,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/io.h" #include "paddle/fluid/pybind/jit.h" #include "paddle/fluid/pybind/xpu_streams_py.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/lod_utils.h" #include "paddle/utils/none.h" @@ -327,7 +327,7 @@ bool SupportsBfloat16() { #ifndef PADDLE_WITH_MKLDNN return false; #else - if (platform::MayIUse(platform::cpu_isa_t::avx512_core)) + if (phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx512_core)) return true; else return false; @@ -338,7 +338,7 @@ bool SupportsBfloat16FastPerformance() { #ifndef PADDLE_WITH_MKLDNN return false; #else - if (platform::MayIUse(platform::cpu_isa_t::avx512_bf16)) + if (phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx512_bf16)) return true; else return false; @@ -349,8 +349,8 @@ bool SupportsInt8() { #ifndef PADDLE_WITH_MKLDNN return false; #else - return (platform::MayIUse(platform::cpu_isa_t::avx2) || - platform::MayIUse(platform::cpu_isa_t::avx512f)); + return (phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx2) || + phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx512f)); #endif } @@ -358,7 +358,8 @@ bool SupportsVNNI() { #ifndef PADDLE_WITH_MKLDNN return false; #else - return platform::MayIUse(platform::cpu_isa_t::avx512_core_vnni); + return phi::backends::cpu::MayIUse( + phi::backends::cpu::cpu_isa_t::avx512_core_vnni); #endif } @@ -615,7 +616,7 @@ PYBIND11_MODULE(libpaddle, m) { BindJit(&m); // Not used, just make sure cpu_info.cc is linked. - paddle::platform::CpuTotalPhysicalMemory(); + phi::backends::cpu::CpuTotalPhysicalMemory(); paddle::memory::allocation::UseAllocatorStrategyGFlag(); diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc index f8193eb1a4c6a..8739b32965b0d 100644 --- a/paddle/fluid/pybind/tensor.cc +++ b/paddle/fluid/pybind/tensor.cc @@ -72,7 +72,6 @@ limitations under the License. */ #include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/platform/cpu_helper.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h" @@ -89,6 +88,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/eager.h" #include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/io.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/lod_utils.h" #include "paddle/utils/none.h" diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt index ddb7adaa92b83..c35bd2bc456f3 100644 --- a/paddle/phi/backends/CMakeLists.txt +++ b/paddle/phi/backends/CMakeLists.txt @@ -1,8 +1,11 @@ add_subdirectory(dynload) add_subdirectory(gpu) -set(BACKENDS_SRCS all_context.cc cpu/cpu_context.cc) +set(BACKENDS_SRCS all_context.cc cpu/cpu_context.cc cpu/cpu_info.cc) set(BACKENDS_DEPS enforce place flags eigen3 phi_device_context) +if(WITH_XBYAK) + list(APPEND BACKENDS_DEPS xbyak) +endif() if(WITH_GPU OR WITH_ROCM) list(APPEND BACKENDS_SRCS gpu/gpu_context.cc gpu/gpu_info.cc diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/phi/backends/cpu/cpu_info.cc similarity index 94% rename from paddle/fluid/platform/cpu_info.cc rename to paddle/phi/backends/cpu/cpu_info.cc index 7edc322a90f5a..74a8df11920d6 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/phi/backends/cpu/cpu_info.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,11 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/platform/cpu_info.h" - -#ifdef PADDLE_WITH_XBYAK -#include "xbyak/xbyak_util.h" -#endif +#include "paddle/phi/backends/cpu/cpu_info.h" #ifdef __APPLE__ #include @@ -30,6 +26,10 @@ limitations under the License. */ #include #endif // _WIN32 +#ifdef PADDLE_WITH_XBYAK +#include "xbyak/xbyak_util.h" +#endif + #include #include "paddle/phi/core/flags.h" @@ -47,8 +47,9 @@ PADDLE_DEFINE_EXPORTED_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory."); -namespace paddle { -namespace platform { +namespace phi { +namespace backends { +namespace cpu { size_t CpuTotalPhysicalMemory() { #ifdef __APPLE__ @@ -87,6 +88,11 @@ size_t CpuMaxChunkSize() { static_cast(FLAGS_initial_cpu_memory_in_mb * 1 << 20)); } +size_t CpuMinChunkSize() { + // Allow to allocate the minimum chunk size is 4 KB. + return 1 << 12; +} + size_t CUDAPinnedMaxAllocSize() { // For distributed systems, it requires configuring and limiting // the fraction of memory to use. @@ -206,5 +212,6 @@ bool MayIUse(const cpu_isa_t cpu_isa) { } #endif -} // namespace platform -} // namespace paddle +} // namespace cpu +} // namespace backends +} // namespace phi diff --git a/paddle/phi/backends/cpu/cpu_info.h b/paddle/phi/backends/cpu/cpu_info.h index 12db2c7d09d39..92a159c9b0aac 100644 --- a/paddle/phi/backends/cpu/cpu_info.h +++ b/paddle/phi/backends/cpu/cpu_info.h @@ -36,15 +36,52 @@ #define ALIGN32_END __attribute__((aligned(32))) #endif // _WIN32 +#ifndef PADDLE_WITH_XBYAK +#ifdef _WIN32 +#define cpuid(reg, x) __cpuidex(reg, x, 0) +#else +#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_ARM) && \ + !defined(PADDLE_WITH_SW) && !defined(PADDLE_WITH_MIPS) +#include +inline void cpuid(int reg[4], int x) { + __cpuid_count(x, 0, reg[0], reg[1], reg[2], reg[3]); +} +#endif +#endif +#endif + namespace phi { namespace backends { namespace cpu { +size_t CpuTotalPhysicalMemory(); + +//! Get the maximum allocation size for a machine. +size_t CpuMaxAllocSize(); + +//! Get the maximum allocation size for a machine. +size_t CUDAPinnedMaxAllocSize(); + //! Get the minimum chunk size for buddy allocator. -inline size_t CpuMinChunkSize() { - // Allow to allocate the minimum chunk size is 4 KB. - return 1 << 12; -} +size_t CpuMinChunkSize(); + +//! Get the maximum chunk size for buddy allocator. +size_t CpuMaxChunkSize(); + +//! Get the minimum chunk size for buddy allocator. +size_t CUDAPinnedMinChunkSize(); + +//! Get the maximum chunk size for buddy allocator. +size_t CUDAPinnedMaxChunkSize(); + +//! Get the maximum allocation size for a machine. +size_t NPUPinnedMaxAllocSize(); + +//! Get the minimum chunk size for buddy allocator. +size_t NPUPinnedMinChunkSize(); + +//! Get the maximum chunk size for buddy allocator. +size_t NPUPinnedMaxChunkSize(); typedef enum { isa_any, @@ -59,6 +96,8 @@ typedef enum { avx512_bf16, } cpu_isa_t; // Instruction set architecture +// May I use some instruction +bool MayIUse(const cpu_isa_t cpu_isa); } // namespace cpu } // namespace backends } // namespace phi diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index d429d4a8dad2c..efef150b56acb 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -19,6 +19,7 @@ math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function) math_library(cross_entropy) math_library(im2col) math_library(vol2col) +math_library(softmax DEPS math_function) cc_library( phi_data_layout_transform diff --git a/paddle/phi/kernels/funcs/eigen/common.h b/paddle/phi/kernels/funcs/eigen/common.h index d34427df0e499..bfdbdf41ee089 100644 --- a/paddle/phi/kernels/funcs/eigen/common.h +++ b/paddle/phi/kernels/funcs/eigen/common.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/eigen/extensions.h" #include "unsupported/Eigen/CXX11/Tensor" namespace phi { diff --git a/paddle/fluid/operators/math/softmax.cc b/paddle/phi/kernels/funcs/softmax.cc similarity index 79% rename from paddle/fluid/operators/math/softmax.cc rename to paddle/phi/kernels/funcs/softmax.cc index 216658b3d7085..2d8dffc3aec6d 100644 --- a/paddle/fluid/operators/math/softmax.cc +++ b/paddle/phi/kernels/funcs/softmax.cc @@ -12,20 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/phi/kernels/funcs/softmax.h" -#include "paddle/fluid/operators/math/softmax_impl.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/kernels/funcs/softmax_impl.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template class SoftmaxFunctor; template class SoftmaxFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/phi/kernels/funcs/softmax.cu similarity index 50% rename from paddle/fluid/operators/math/softmax.cu rename to paddle/phi/kernels/funcs/softmax.cu index b7a9b9a19c970..2ca97cd4ac205 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/phi/kernels/funcs/softmax.cu @@ -13,20 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/operators/math/softmax.h" -#include "paddle/fluid/operators/math/softmax_impl.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" +#include "paddle/phi/kernels/funcs/softmax_impl.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { -using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; -using DataLayout = platform::DataLayout; +using ScopedTensorDescriptor = phi::backends::gpu::ScopedTensorDescriptor; +using DataLayout = phi::backends::gpu::DataLayout; template -using CudnnDataType = platform::CudnnDataType; +using CudnnDataType = phi::backends::gpu::CudnnDataType; template void SoftmaxCUDNNFunctor::operator()( @@ -51,31 +50,31 @@ void SoftmaxCUDNNFunctor::operator()( xDesc.descriptor(layout, cudnn_tensor_dims); miopenTensorDescriptor_t cudnn_y_desc = xDesc.descriptor(layout, cudnn_tensor_dims); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( - context.cudnn_handle(), - CudnnDataType::kOne(), - cudnn_x_desc, - X->data(), - CudnnDataType::kZero(), - cudnn_y_desc, - Y->mutable_data(context.GetPlace()), - MIOPEN_SOFTMAX_ACCURATE, - MIOPEN_SOFTMAX_MODE_INSTANCE)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenSoftmaxForward_V2(context.cudnn_handle(), + CudnnDataType::kOne(), + cudnn_x_desc, + X->data(), + CudnnDataType::kZero(), + cudnn_y_desc, + context.template Alloc(Y), + MIOPEN_SOFTMAX_ACCURATE, + MIOPEN_SOFTMAX_MODE_INSTANCE)); #else cudnnTensorDescriptor_t cudnn_x_desc = xDesc.descriptor(layout, cudnn_tensor_dims); cudnnTensorDescriptor_t cudnn_y_desc = xDesc.descriptor(layout, cudnn_tensor_dims); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward( - context.cudnn_handle(), - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_INSTANCE, - CudnnDataType::kOne(), - cudnn_x_desc, - X->data(), - CudnnDataType::kZero(), - cudnn_y_desc, - Y->mutable_data(context.GetPlace()))); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnSoftmaxForward(context.cudnn_handle(), + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_INSTANCE, + CudnnDataType::kOne(), + cudnn_x_desc, + X->data(), + CudnnDataType::kZero(), + cudnn_y_desc, + context.template Alloc(Y))); #endif } @@ -106,18 +105,18 @@ void SoftmaxGradCUDNNFunctor::operator()( dxDesc.descriptor(layout, cudnn_tensor_dims); miopenTensorDescriptor_t cudnn_ygrad_desc = dyDesc.descriptor(layout, cudnn_tensor_dims); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( - context.cudnn_handle(), - CudnnDataType::kOne(), - cudnn_y_desc, - Y->data(), - cudnn_ygrad_desc, - YGrad->data(), - CudnnDataType::kZero(), - cudnn_xgrad_desc, - XGrad->mutable_data(context.GetPlace()), - MIOPEN_SOFTMAX_ACCURATE, - MIOPEN_SOFTMAX_MODE_INSTANCE)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenSoftmaxBackward_V2(context.cudnn_handle(), + CudnnDataType::kOne(), + cudnn_y_desc, + Y->data(), + cudnn_ygrad_desc, + YGrad->data(), + CudnnDataType::kZero(), + cudnn_xgrad_desc, + context.template Alloc(XGrad), + MIOPEN_SOFTMAX_ACCURATE, + MIOPEN_SOFTMAX_MODE_INSTANCE)); #else cudnnTensorDescriptor_t cudnn_y_desc = yDesc.descriptor(layout, cudnn_tensor_dims); @@ -125,28 +124,28 @@ void SoftmaxGradCUDNNFunctor::operator()( dxDesc.descriptor(layout, cudnn_tensor_dims); cudnnTensorDescriptor_t cudnn_ygrad_desc = dyDesc.descriptor(layout, cudnn_tensor_dims); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxBackward( - context.cudnn_handle(), - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_INSTANCE, - CudnnDataType::kOne(), - cudnn_y_desc, - Y->data(), - cudnn_ygrad_desc, - YGrad->data(), - CudnnDataType::kZero(), - cudnn_xgrad_desc, - XGrad->mutable_data(context.GetPlace()))); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnSoftmaxBackward(context.cudnn_handle(), + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_INSTANCE, + CudnnDataType::kOne(), + cudnn_y_desc, + Y->data(), + cudnn_ygrad_desc, + YGrad->data(), + CudnnDataType::kZero(), + cudnn_xgrad_desc, + context.template Alloc(XGrad))); #endif } template class SoftmaxCUDNNFunctor; -template class SoftmaxCUDNNFunctor; +template class SoftmaxCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; -template class SoftmaxGradCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; #if CUDNN_VERSION_MIN(8, 1, 0) -template class SoftmaxCUDNNFunctor; -template class SoftmaxGradCUDNNFunctor; +template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; #endif // MIOPEN do not support double @@ -155,15 +154,14 @@ template class SoftmaxCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; #endif -template class SoftmaxFunctor; -template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; -template class SoftmaxGradFunctor; -template class SoftmaxGradFunctor; +template class SoftmaxGradFunctor; +template class SoftmaxGradFunctor; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/softmax.h b/paddle/phi/kernels/funcs/softmax.h similarity index 91% rename from paddle/fluid/operators/math/softmax.h rename to paddle/phi/kernels/funcs/softmax.h index 9d25309d146a8..80805eb6d76f6 100644 --- a/paddle/fluid/operators/math/softmax.h +++ b/paddle/phi/kernels/funcs/softmax.h @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/fluid/framework/tensor.h" +#include "paddle/phi/core/dense_tensor.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template class SoftmaxFunctor { @@ -58,6 +57,5 @@ class SoftmaxGradCUDNNFunctor { #endif -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/phi/kernels/funcs/softmax_impl.h similarity index 86% rename from paddle/fluid/operators/math/softmax_impl.h rename to paddle/phi/kernels/funcs/softmax_impl.h index 3ce7374e4d39f..330ac331b6b8e 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/phi/kernels/funcs/softmax_impl.h @@ -15,24 +15,22 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/operators/math/cpu_vec.h" -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/cpu_info.h" -#include "paddle/fluid/platform/float16.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/cpu_vec.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template -using EigenMatrix = framework::EigenMatrix; +using EigenMatrix = phi::EigenMatrix; template struct ValueClip { @@ -104,7 +102,7 @@ class SoftmaxEigen { }; template -class SoftmaxEigen { +class SoftmaxEigen { public: void operator()(const DeviceContext& context, const int axis_dim, @@ -114,8 +112,8 @@ class SoftmaxEigen { constexpr int kClassDim = 1; constexpr int kAxisDim = 1; - auto logits = EigenMatrix::From(*X); - auto softmax = EigenMatrix::From(*Y); + auto logits = EigenMatrix::From(*X); + auto softmax = EigenMatrix::From(*Y); const int batch_size = logits.dimension(kBatchDim); const int num_classes = logits.dimension(kClassDim); @@ -139,7 +137,7 @@ class SoftmaxEigen { (logits - logits.maximum(along_axis) .reshape(batch_by_one) .broadcast(one_by_class)) - .unaryExpr(ValueClip()); + .unaryExpr(ValueClip()); } else { // axis != -1, class dimension split into (axis, remain), max and sum // should be calculated along axis dimension @@ -149,7 +147,7 @@ class SoftmaxEigen { .reshape(batch_one_remain) .broadcast(one_axis_one) .reshape(batch_classes)) - .unaryExpr(ValueClip()); + .unaryExpr(ValueClip()); } softmax.device(*context.eigen_device()) = softmax.exp(); @@ -162,7 +160,7 @@ class SoftmaxEigen { }; template -class SoftmaxEigen { +class SoftmaxEigen { public: void operator()(const DeviceContext& context, const int axis_dim, @@ -172,8 +170,8 @@ class SoftmaxEigen { constexpr int kClassDim = 1; constexpr int kAxisDim = 1; - auto logits = EigenMatrix::From(*X); - auto softmax = EigenMatrix::From(*Y); + auto logits = EigenMatrix::From(*X); + auto softmax = EigenMatrix::From(*Y); const int batch_size = logits.dimension(kBatchDim); const int num_classes = logits.dimension(kClassDim); @@ -197,7 +195,7 @@ class SoftmaxEigen { (logits - logits.maximum(along_axis) .reshape(batch_by_one) .broadcast(one_by_class)) - .unaryExpr(ValueClip()); + .unaryExpr(ValueClip()); } else { // axis != -1, class dimension split into (axis, remain), max and sum // should be calculated along axis dimension @@ -207,7 +205,7 @@ class SoftmaxEigen { .reshape(batch_one_remain) .broadcast(one_axis_one) .reshape(batch_classes)) - .unaryExpr(ValueClip()); + .unaryExpr(ValueClip()); } softmax.device(*context.eigen_device()) = softmax.exp(); @@ -247,21 +245,24 @@ class SoftmaxFunctor> { const int batch_size = in_dims[kBatchDim]; const int num_remain = num_classes / axis_dim; - if (num_remain == 1 && platform::MayIUse(platform::avx)) { + if (num_remain == 1 && + phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { const T* in_data = X->data(); T* out_data = Y->data(); for (int bs = 0; bs < batch_size; ++bs) { T max_val = *std::max_element(in_data, in_data + num_classes); max_val *= static_cast(-1); - vec_add_bias(num_classes, max_val, in_data, out_data); - vec_clip( + vec_add_bias( + num_classes, max_val, in_data, out_data); + vec_clip( num_classes, static_cast(-64), out_data, out_data); vec_exp(num_classes, out_data, out_data); T sum = 0; - vec_sum(num_classes, out_data, &sum); + vec_sum(num_classes, out_data, &sum); sum = static_cast(1) / sum; - vec_scal(num_classes, sum, out_data, out_data); + vec_scal( + num_classes, sum, out_data, out_data); in_data += num_classes; out_data += num_classes; @@ -308,16 +309,16 @@ class SoftmaxGradEigen { }; template -class SoftmaxGradEigen { +class SoftmaxGradEigen { public: void operator()(const DeviceContext& context, const int axis_dim, const phi::DenseTensor* y, const phi::DenseTensor* y_grad, phi::DenseTensor* x_grad) { - auto softmax = EigenMatrix::From(*y); - auto softmax_grad = EigenMatrix::From(*y_grad); - auto logits_grad = EigenMatrix::From(*x_grad); + auto softmax = EigenMatrix::From(*y); + auto softmax_grad = EigenMatrix::From(*y_grad); + auto logits_grad = EigenMatrix::From(*x_grad); constexpr int kBatchDim = 0; constexpr int kClassDim = 1; @@ -342,16 +343,16 @@ class SoftmaxGradEigen { }; template -class SoftmaxGradEigen { +class SoftmaxGradEigen { public: void operator()(const DeviceContext& context, const int axis_dim, const phi::DenseTensor* y, const phi::DenseTensor* y_grad, phi::DenseTensor* x_grad) { - auto softmax = EigenMatrix::From(*y); - auto softmax_grad = EigenMatrix::From(*y_grad); - auto logits_grad = EigenMatrix::From(*x_grad); + auto softmax = EigenMatrix::From(*y); + auto softmax_grad = EigenMatrix::From(*y_grad); + auto logits_grad = EigenMatrix::From(*x_grad); constexpr int kBatchDim = 0; constexpr int kClassDim = 1; @@ -400,17 +401,20 @@ class SoftmaxGradFunctor> { const int batch_size = out_dims[kBatchDim]; const int num_remain = num_classes / axis_dim; - if (num_remain == 1 && platform::MayIUse(platform::avx)) { + if (num_remain == 1 && + phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { const T* out_data = y->data(); const T* out_grad = y_grad->data(); T* in_grad = x_grad->data(); for (int bs = 0; bs < batch_size; ++bs) { T scalar; - vec_mul_reduce( + vec_mul_reduce( num_classes, out_grad, out_data, &scalar); scalar *= static_cast(-1); - vec_add_bias(num_classes, scalar, out_grad, in_grad); - vec_mul(num_classes, out_data, in_grad, in_grad); + vec_add_bias( + num_classes, scalar, out_grad, in_grad); + vec_mul( + num_classes, out_data, in_grad, in_grad); out_data += num_classes; out_grad += num_classes; in_grad += num_classes; @@ -422,6 +426,5 @@ class SoftmaxGradFunctor> { } }; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu index 934b0fe152bd8..2e1a485229b06 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu @@ -22,7 +22,6 @@ limitations under the License. */ namespace cub = hipcub; #endif -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/common/amp_type_traits.h" @@ -32,6 +31,7 @@ namespace cub = hipcub; #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu index 94d2d7a744c21..b274da7743b8f 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu @@ -22,7 +22,6 @@ limitations under the License. */ namespace cub = hipcub; #endif -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/common/amp_type_traits.h" @@ -33,6 +32,7 @@ namespace cub = hipcub; #include "paddle/phi/kernels/funcs/cross_entropy.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" namespace phi { @@ -1386,7 +1386,7 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, labels_2d.Resize({n, label.numel() / n}); DenseTensor loss_2d(*loss); loss_2d.Resize({n, 1}); - paddle::operators::math::SoftmaxCUDNNFunctor()( + phi::funcs::SoftmaxCUDNNFunctor()( dev_ctx, &logits_2d, &softmax_2d); phi::funcs::CrossEntropyFunctor()(dev_ctx, &loss_2d, diff --git a/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h b/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h index 96ae00366e913..f74d91990ad3b 100644 --- a/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h @@ -14,11 +14,11 @@ #pragma once -#include "paddle/fluid/operators/math/softmax.h" -#include "paddle/fluid/operators/math/softmax_impl.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" +#include "paddle/phi/kernels/funcs/softmax_impl.h" namespace phi { @@ -50,7 +50,7 @@ void GumbelSoftmaxGradKernel(const Context& ctx, dx_2d.Resize({size_to_axis, size_from_axis}); out_2d.Resize({size_to_axis, size_from_axis}); dout_2d.Resize({size_to_axis, size_from_axis}); - paddle::operators::math::SoftmaxGradFunctor()( + phi::funcs::SoftmaxGradFunctor()( ctx, axis_dim, &out_2d, &dout_2d, &dx_2d); } diff --git a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h index 26dd121be2db6..c2229b50deee1 100644 --- a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h +++ b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h @@ -16,12 +16,12 @@ #include -#include "paddle/fluid/operators/math/softmax.h" -#include "paddle/fluid/operators/math/softmax_impl.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" +#include "paddle/phi/kernels/funcs/softmax_impl.h" namespace phi { @@ -87,8 +87,7 @@ void GumbelSoftmaxKernelHelper(const Context& ctx, size_to_axis, size_from_axis, temperature); - paddle::operators::math::SoftmaxFunctor()( - ctx, axis_dim, &x_noise_2d, &out_2d); + phi::funcs::SoftmaxFunctor()(ctx, axis_dim, &x_noise_2d, &out_2d); if (hard) { OneHotGenerator::Transform(ctx, x, out, axis); diff --git a/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h b/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h index ef869195caf28..859a418dd4c3e 100644 --- a/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" #include "paddle/phi/kernels/softmax_grad_kernel.h" namespace phi { @@ -50,7 +50,7 @@ void SoftmaxGradKernel(const Context& dev_ctx, Out_2d.ShareDataWith(out).Resize({n, d}); dOut_2d.ShareDataWith(out_grad).Resize({n, d}); - paddle::operators::math::SoftmaxGradFunctor()( + phi::funcs::SoftmaxGradFunctor()( dev_ctx, axis_dim, &Out_2d, &dOut_2d, &dX_2d); } diff --git a/paddle/phi/kernels/impl/softmax_kernel_impl.h b/paddle/phi/kernels/impl/softmax_kernel_impl.h index 4114e1105191a..d5601104e18c4 100644 --- a/paddle/phi/kernels/impl/softmax_kernel_impl.h +++ b/paddle/phi/kernels/impl/softmax_kernel_impl.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" #include "paddle/phi/kernels/softmax_kernel.h" namespace phi { @@ -47,8 +47,7 @@ void SoftmaxKernel(const Context& dev_ctx, DenseTensor X_2d, Out_2d; X_2d.ShareDataWith(x).Resize({n, d}); Out_2d.ShareDataWith(*out).Resize({n, d}); - paddle::operators::math::SoftmaxFunctor()( - dev_ctx, axis_dim, &X_2d, &Out_2d); + phi::funcs::SoftmaxFunctor()(dev_ctx, axis_dim, &X_2d, &Out_2d); } } // namespace phi diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index ed79b26d1c291..1425ea0361cfe 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -22,7 +22,7 @@ endif() cc_test( test_cpu_vec SRCS test_cpu_vec.cc - DEPS blas cpu_info) + DEPS blas phi_backends) # For String Kernels cc_test(