From 0becc4a7d11245e5c0056468b2ce27dc9a5ace38 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Sun, 5 May 2024 06:58:58 +0000
Subject: [PATCH] Add int4 quantize kernel
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

add int4_1 int4_2

FLAGS_logging_pir_py_code (#63981)
* FLAGS_logging_pir_py_code
* FLAGS_logging_pir_py_code_dir
---------
Co-authored-by: jiahy0825

[Cleanup] Remove Flake8 config in `.editorconfig` (#64027)

【PIR Dist Op Reg No.19】 reg pull_box_sparse (#62982)
* fix
* fix
* fix
* fix
* fix
* fix
* add test
* add
* fix
* fix
* add out
* fix
* codestyle
* fix
* fix backward
* merge

[Dy2St][PIR] Hold backward program in GradNode (#63694)
Co-authored-by: xiongkun
Co-authored-by: Nyakku Shigure

split test.cmake: add new test_cases.cmake (#64007)

[PIR] Support sparse_slice and sparse_sum in pt (#64009)
* support sparse_slice and sparse_sum in pt
* support sparse_slice and sparse_sum in pt
* support sparse_slice and sparse_sum in pt

option for WITH_CPP_TEST (#63896)
* option for WITH_CPP_TEST
* fix
* Fix
* Fix

[PIR] Fix `attributes_num` of `SliceArrayOp` (#64013)

[Dy2St] Use `full_graph=True` outside dy2st uts (part1) (#64058)

[Dy2St] Use `full_graph=True` outside dy2st uts (part2) (#64059)

fix typo (#64060)
Co-authored-by: jiahy0825

update (#64042)

Replace paddle/fluid/platform/device/gpu/gpu_dnn.h (#63819)
* Fix
* Fix
* Fix

Clean lookup_table_v2_op.h lookup_table_v2_op.cu (#64020)
* Fix
* ci

refine GetTensorListFromArgs (#64045)

Revert "【Hackathon 6th Fundable Projects 3 No.60】Remove fluid operator chunk_…" (#64050)
This reverts commit 88b1a6ed30a6d66a0226a4152429981511659908.

[Prim][PIR] support floor_divide op forward in prim pir (#64023)
* floor-div-dev
* update test

[CINN] Reconstruct shape_analysis (#63790)
* reconstruct shape_analysis
* fix input value shape infer
* fix merge bugs
* fix concat and gather op InferSymbolicShape
* fix merge bug
* fix value_to_shape_or_data hash error and add some checks
* fix set shape for null value
* fix group op lazy infer
* add IsStaticShape check
* fix merge bug
* support static dim check and set for VectorType
* change auto to detail type

[XPU] fix bugs in processing of attention_mask and fix_seed_offset on XPU (#64003)
* [XPU] fix segmentfault caused by setting fix_seed_offset on XPU
* cast attention_mask to float32 when necessary

fix merge bug (#64069)

【Fix PIR Unittest No.125、147、481】Fix some 0D uts in PIR mode (part1) (#64064)

[Prim][VJP]support autogen to remove unused composite in .yaml (#64054)
* support autogen to remove unused composite in .yaml
* fix bug

[PIR] Fix typo `set_pit_tests_properties` -> `set_pir_tests_properties` (#64063)

[Dy2St] Use `full_graph=True` outside dy2st uts (part3) (#64066)

[PIR save/load] Open more tests for paddle.save and paddle.load (#64044)
* open more tests for paddle.save and paddle.load
* fix

API Improvement for paddle.nn.functional.group_norm and paddle.nn.GroupNorm (#63881)
* update group_norm
* update trt plugin
* update trt plugin
* fix trt plugin
* fix trt plugin
* fix test
* fix test
* fix ci windows inference
* update kernel function names and add v2 test
* fix
* fix fp16 test

Revert "【Hackathon 6th Fundable Projects 3 No.81】Remove fluid operators ctc_a…" (#64049)
This reverts commit 2134ead2500320b0101169a4809fb9f50c76cc77.
Clean paddle/fluid/operators/fused/attention_layer_norm.h (#64051) * Fix * Fix Replace operators::math to phi::math in fluid/operators (#63854) [CINN]Clean usless loop_reorder_aligment tactic (#63998) * [CINN]Clean usless loop_reorder_aligment tactic * fix source 【Hackathon 6th Fundable Projects 3 No.396】fluid operator yolo_box_head (#63783) * Fix * Fix * Fix * Fix * Fix 【Hackathon 6th Fundable Projects 3 No.240】fluid operator moe (#63929) 【Hackathon 6th Fundable Projects 3 No.82】fluid operator cudnn_lstm (#63936) * Fix * Fix * Fix * Fix [CINN] Remove useless log (#64052) [pir_save_load] add pir for test_jit_save_load.py (#63958) * add jit load.train * modify backward program lost * modify * combine eval and train * modify 8 case of jit.save.load * modify jit_save_load case * rename jit_save_load * change name all * modify timeout * modify new case * modify TestJitSaveLoadMultiMethods * modify cpu tensor no holder bug Flashattention support qkvpacked and varlen (#63289) * Flashattention support qkvpacked and varlen * fix codestyle * fix codestyle * FlashAttention kvReduceGQA Performance Optimization * Fix problem with windows * code clean * update third_party/flashattn * update errormsg and docs * update api * update doc * update doctest * update doc, test=document_fix * update doc, test=document_fix * Update python/paddle/nn/functional/flash_attention.py Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> * Update python/paddle/nn/functional/flash_attention.py Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> * update doc --------- Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> 【PIR Dist Op Reg No.20】 reg global_gather (#63867) * reg global_gather * reg global_gather * reg_global_gather * fix * fix * fix * fix conflict * fix conflict * Update ops_api_gen.py * Update ops_api_gen.py Fix backward program kwargs error when process inplace value (#63939) 【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True fix (#63880) * support kwargs for recompute when use_reentrant == True * recover third party merge main lint delete printf change flash attn version --- .editorconfig | 5 +- CMakeLists.txt | 1 + .../hlir/dialect/operator/ir/manual_op.cc | 24 +- .../cinn/hlir/dialect/operator/ir/manual_op.h | 8 +- .../operator/transforms/add_cinn_pass.cc | 8 +- .../transforms/add_store_in_fusion_op_pass.cc | 8 +- .../transforms/check_infer_symbolic_pass.cc | 5 +- .../transforms/cinn_group_cluster_pass.cc | 16 +- .../transforms/dynamic_reshape_pass.cc | 22 +- .../transforms/fold_manipulation_ops_pass.cc | 14 +- ...e_shape_ops_into_generate_shape_op_pass.cc | 6 +- .../convert_dynamic_to_static_dim_pass.cc | 14 - .../convert_static_dim_to_dynamic_pass.cc | 23 +- .../group_merge/simplify_dim_expr_pass.cc | 22 +- .../group_merge/single_op_fallback_to_phi.cc | 14 +- .../transforms/insert_broadcast_pass.cc | 5 +- .../lowering_pass/collect_sym_expr.cc | 21 +- .../lower_cinn_fusion_op_pass.cc | 12 +- .../operator/transforms/pd_to_cinn_pass.cc | 3 - .../transforms/pir_to_py_code_converter.cc | 56 +- .../transforms/pir_to_py_code_converter.h | 3 +- .../transforms/replace_dynamic_expand_pass.cc | 17 +- paddle/cinn/hlir/framework/pir/fusion_info.cc | 5 - paddle/cinn/hlir/framework/pir/utils.cc | 45 -- .../dy_shape_group_scheduler.cc | 1 - .../ir/group_schedule/tactic/CMakeLists.txt | 1 - .../tactic/loop_reorder_alignment_tactic.cc | 132 --- .../tactic/loop_reorder_alignment_tactic.h | 26 - paddle/common/flags.cc 
| 4 + .../fluid/distributed/collective/reducer.cc | 4 +- paddle/fluid/distributed/collective/reducer.h | 2 +- .../distributed/index_dataset/index_sampler.h | 9 +- .../eager/to_static/run_program_op_func.h | 11 +- .../eager/to_static/run_program_op_node.h | 70 +- paddle/fluid/eager/utils.cc | 12 + paddle/fluid/eager/utils.h | 2 + paddle/fluid/framework/CMakeLists.txt | 1 - paddle/fluid/framework/op_desc.cc | 4 + paddle/fluid/framework/type_defs.cc | 3 +- paddle/fluid/framework/type_defs.h | 4 +- paddle/fluid/imperative/amp_utils.h | 6 + paddle/fluid/imperative/reducer.cc | 9 +- .../tensorrt/plugin/group_norm_op_plugin.cu | 128 ++- .../tensorrt/plugin/group_norm_op_plugin.h | 4 +- .../plugin/preln_groupnorm_act_op_plugin.cu | 125 +-- .../plugin/preln_groupnorm_act_op_plugin.h | 4 +- .../plugin/skip_groupnorm_act_op_plugin.cu | 151 ++-- .../plugin/skip_groupnorm_act_op_plugin.h | 4 +- .../ir_adaptor/translator/op_compat_gen.py | 3 + .../ir_adaptor/translator/op_translator.cc | 5 + paddle/fluid/operators/CMakeLists.txt | 2 +- .../fluid/operators/array_to_lod_tensor_op.cc | 4 +- paddle/fluid/operators/chunk_eval_op.cc | 202 +++++ paddle/fluid/operators/chunk_eval_op.h | 358 +++++++++ .../operators/collective/c_concat_op.cu.cc | 4 +- paddle/fluid/operators/ctc_align_op.cc | 133 ++++ paddle/fluid/operators/ctc_align_op.cu | 171 ++++ paddle/fluid/operators/ctc_align_op.h | 119 +++ paddle/fluid/operators/cudnn_lstm_op.cc | 285 ------- paddle/fluid/operators/cudnn_rnn_cache.h | 2 +- .../fluid/operators/detection/bbox_util.cu.h | 2 +- .../detection/collect_fpn_proposals_op.cu | 2 +- paddle/fluid/operators/fused/CMakeLists.txt | 2 - .../operators/fused/attention_layer_norm.h | 113 --- .../fused/cudnn_bn_stats_finalize.cu.h | 9 +- .../operators/fused/cudnn_norm_conv.cu.h | 9 +- .../fused/cudnn_scale_bias_add_relu.cu.h | 11 +- paddle/fluid/operators/fused/fmha_ref.h | 750 ------------------ .../fused/fused_multi_transformer_int8_op.cu | 10 +- .../fused/fused_multi_transformer_op.cu | 6 +- .../fused/fused_multi_transformer_op.cu.h | 33 +- .../fluid/operators/fused/resnet_unit_op.cu | 4 +- .../fused/xpu_fused_common_function.h | 225 ------ .../fluid/operators/fused/yolo_box_head_op.cc | 50 -- .../operators/grid_sampler_cudnn_op.cu.cc | 2 +- .../fluid/operators/lod_tensor_to_array_op.cc | 4 +- paddle/fluid/operators/lookup_table_v2_op.cu | 254 ------ paddle/fluid/operators/lookup_table_v2_op.h | 285 ------- paddle/fluid/operators/math/CMakeLists.txt | 2 - paddle/fluid/operators/math/prelu.h | 2 +- paddle/fluid/operators/math/sample_prob.cc | 21 - paddle/fluid/operators/math/sample_prob.cu | 206 ----- paddle/fluid/operators/math/sample_prob.h | 125 --- paddle/fluid/operators/math/sampler.cc | 99 --- paddle/fluid/operators/math/sampler.h | 135 ---- paddle/fluid/operators/miopen_rnn_cache.h | 2 +- paddle/fluid/operators/moe_op.cc | 64 -- paddle/fluid/operators/nce_op.h | 32 +- .../operators/ops_signature/cudnn_lstm_sig.cc | 59 -- .../sequence_ops/sequence_softmax_op.cc | 2 +- paddle/fluid/operators/tdm_child_op.cc | 1 - paddle/fluid/operators/unbind_op.h | 2 +- paddle/fluid/operators/unique_op.h | 4 +- .../decomp_interface_gen_op_list.py | 2 + .../op_generator/infer_symbolic_shape_gen.py | 4 +- .../fluid/pir/dialect/op_generator/op_gen.py | 2 +- .../op_generator/op_infermeta_func_gen.py | 5 +- .../pir/dialect/op_generator/ops_api_gen.py | 5 + .../infer_symbolic_shape/binary_infer_sym.cc | 116 +-- .../infer_symbolic_shape/cinn_op_infer_sym.cc | 117 ++- .../infer_symbolic_shape/cinn_op_infer_sym.h | 1 + 
.../element_wise_binary.cc | 38 +- .../infer_symbolic_shape/infer_sym_utils.cc | 20 +- .../infer_symbolic_shape/infer_sym_utils.h | 6 +- .../multiary_infer_sym.cc | 221 +++--- .../infer_symbolic_shape/nullary_infer_sym.cc | 91 ++- .../same_operands_result.cc | 27 +- .../infer_symbolic_shape/unary_infer_sym.cc | 305 ++++--- .../dialect/operator/ir/control_flow_op.cc | 72 +- .../pir/dialect/operator/ir/control_flow_op.h | 6 +- .../dialect/operator/ir/manual_onednn_op.cc | 2 +- .../dialect/operator/ir/manual_onednn_op.h | 2 +- .../pir/dialect/operator/ir/manual_op.cc | 35 +- .../fluid/pir/dialect/operator/ir/manual_op.h | 10 +- .../pir/dialect/operator/ir/op_dialect.cc | 36 +- paddle/fluid/pir/dialect/operator/ir/ops.yaml | 20 + .../pir/dialect/operator/ir/ops_backward.yaml | 12 + .../fluid/pir/dialect/operator/utils/utils.cc | 2 + .../pir/transforms/sub_graph_detector.cc | 31 +- paddle/fluid/primitive/codegen/gen.py | 3 - .../primitive/codegen/templates/common.j2 | 8 +- .../rule/vjp/generated/generated_vjp.cc.j2 | 10 +- paddle/fluid/primitive/composite/composite.h | 8 + paddle/fluid/pybind/eager_functions.cc | 9 +- paddle/fluid/pybind/eager_utils.cc | 52 +- paddle/fluid/pybind/eager_utils.h | 3 +- paddle/fluid/pybind/op_function_common.cc | 19 +- paddle/fluid/pybind/pir.cc | 97 ++- paddle/fluid/pybind/tensor_py.h | 4 +- paddle/phi/api/yaml/backward.yaml | 48 ++ paddle/phi/api/yaml/op_compat.yaml | 67 ++ paddle/phi/api/yaml/ops.yaml | 55 ++ paddle/phi/infermeta/backward.cc | 29 + paddle/phi/infermeta/backward.h | 12 + paddle/phi/infermeta/binary.cc | 36 + paddle/phi/infermeta/binary.h | 13 + paddle/phi/infermeta/fusion.cc | 1 + paddle/phi/infermeta/ternary.cc | 54 ++ paddle/phi/infermeta/ternary.h | 13 + .../kernels/funcs/weight_dequant_functor.h | 11 +- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 615 ++++++++++++-- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 243 +++++- paddle/phi/kernels/gpu/group_norm_kernel.cu | 248 +++--- .../kernels/gpu/weight_dequantize_kernel.cu | 9 +- .../phi/kernels/gpu/weight_quantize_kernel.cu | 38 +- .../kernels/gpu/yolo_box_head_kernel.cu} | 86 +- paddle/phi/kernels/group_norm_kernel.h | 34 +- .../impl/weight_quantize_kernel_gpu_impl.h | 221 +++++- .../phi/kernels/xpu/flash_attn_grad_kernel.cc | 66 +- paddle/phi/kernels/xpu/flash_attn_kernel.cc | 76 +- .../infer_symbolic_shape.h | 10 +- .../transforms/shape_optimization_pass.h | 2 +- .../dialect/shape/utils/shape_analysis.h | 79 +- .../infer_symbolic_shape.cc | 4 +- .../transforms/shape_optimization_pass.cc | 64 +- .../src/dialect/shape/utils/shape_analysis.cc | 380 ++++++--- paddle/scripts/paddle_build.sh | 18 +- .../distributed/fleet/recompute/recompute.py | 41 +- python/paddle/framework/io.py | 8 +- python/paddle/framework/io_utils.py | 34 + .../incubate/nn/layer/fused_transformer.py | 6 +- python/paddle/jit/api.py | 124 ++- .../jit/dy2static/pir_partial_program.py | 24 +- python/paddle/jit/pir_translated_layer.py | 101 ++- python/paddle/nn/functional/__init__.py | 4 + .../paddle/nn/functional/flash_attention.py | 280 +++++++ python/paddle/nn/functional/norm.py | 6 +- python/paddle/nn/layer/norm.py | 8 +- python/paddle/optimizer/adam.py | 9 +- python/paddle/pir_utils.py | 67 ++ python/paddle/static/input.py | 2 +- python/paddle/static/io_utils.py | 3 +- python/paddle/static/pir_io.py | 16 +- python/paddle/tensor/attribute.py | 6 +- test/CMakeLists.txt | 4 +- test/auto_parallel/CMakeLists.txt | 2 +- .../fleet/test_dygraph_recompute_for_eager.py | 133 ++-- test/contrib/test_d2s_amp_controlflow.py | 6 +- 
test/cpp/fluid/math/concat_test.cc | 10 +- test/cpp/inference/CMakeLists.txt | 1 + test/cpp/inference/api/CMakeLists.txt | 34 +- test/cpp/inference/test.cmake | 29 - test/cpp/inference/test_cases.cmake | 34 + test/cpp/prim/CMakeLists.txt | 2 +- test/custom_op/test_custom_relu_model.py | 4 +- test/custom_op/test_inference_inplace.py | 1 + test/deprecated/CMakeLists.txt | 4 +- .../custom_runtime/test_custom_cpu_plugin.py | 4 +- .../test_custom_cpu_to_static.py | 4 +- test/deprecated/distribution/CMakeLists.txt | 2 +- test/deprecated/fft/CMakeLists.txt | 2 +- test/deprecated/ir/pir/test_standalone_pir.py | 4 +- .../ir/pir/translator/CMakeLists.txt | 1 + .../test_global_gather_translator.py | 61 ++ .../test_global_scatter_translator.py | 33 +- test/deprecated/legacy_test/CMakeLists.txt | 2 +- test/deprecated/legacy_test/test_apply.py | 4 +- .../test_elementwise_floordiv_op.py | 4 +- .../legacy_test/test_instance_norm_op.py | 2 +- .../legacy_test/test_instance_norm_op_v2.py | 38 +- test/deprecated/legacy_test/test_jit_layer.py | 11 +- .../test_paddle_save_load_binary.py | 31 +- .../legacy_test/test_run_program_op.py | 2 +- ...est_save_inference_model_conditional_op.py | 9 +- .../legacy_test/test_sparse_slice_op.py | 7 + .../legacy_test/test_sparse_sum_op.py | 2 + ...t_zero_dim_sundry_static_api_deprecated.py | 158 ++++ .../test_composite_batch_norm.py | 2 +- .../test_composite_layer_norm.py | 2 +- .../composite_ops/test_composite_softmax.py | 2 +- .../prim/vjp/static/test_comp_add_grad.py | 4 +- .../vjp/static/test_comp_add_tanh_grad.py | 4 +- .../prim/vjp/static/test_comp_cast_grad.py | 4 +- .../prim/vjp/static/test_comp_div_grad.py | 4 +- .../prim/vjp/static/test_comp_gather_grad.py | 4 +- .../prim/vjp/static/test_comp_reshape_grad.py | 4 +- .../prim/vjp/static/test_comp_sqrt_grad.py | 4 +- .../prim/vjp/static/test_comp_sub_grad.py | 4 +- .../prim/vjp/static/test_comp_tanh_grad.py | 4 +- .../vjp/static/test_comp_transpose_grad.py | 4 +- .../prim/process/test_check_inputs.py | 2 +- test/deprecated/rnn/test_rnn_nets.py | 4 +- .../tokenizer/test_faster_tokenizer_op.py | 1 + test/distribution/CMakeLists.txt | 2 +- test/dygraph_to_static/test_no_gradient.py | 3 +- test/dygraph_to_static/test_pylayer.py | 2 +- test/fft/CMakeLists.txt | 2 +- test/ipu/test_dy2static_fp16_ipu.py | 2 +- test/ipu/test_dy2static_ipu.py | 2 +- test/ipu/test_print_op_ipu.py | 2 +- .../inference/test_inference_predictor_run.py | 1 + .../test_save_optimized_model_pass.py | 4 +- .../inference/test_trt_inference_fp16_io.py | 4 +- .../inference/test_trt_inference_predictor.py | 10 +- .../test_xpu_convert_mixed_precision.py | 4 +- .../test_decomp_inference_predictor_run.py | 1 + test/ir/pir/test_subgraph_exporter.py | 2 +- .../test_pull_box_sparse_translator.py | 51 ++ test/ir/test_convert_to_mixed_precision.py | 4 +- test/ir/test_inference_datatype.py | 1 + test/legacy_test/CMakeLists.txt | 5 +- test/legacy_test/test_chunk_eval_op.py | 282 +++++++ test/legacy_test/test_ctc_align.py | 232 ++++++ test/legacy_test/test_dropout_op.py | 20 +- test/legacy_test/test_flash_attention.py | 419 ++++++++++ test/legacy_test/test_group_norm_op.py | 365 ++++++++- test/legacy_test/test_group_norm_op_v2.py | 411 +++++++++- ...e_load.py => test_jit_save_load_rename.py} | 534 +++++++++---- test/legacy_test/test_onnx_export.py | 2 +- test/legacy_test/test_paddle_save_load.py | 94 ++- test/legacy_test/test_stack_op.py | 2 +- test/legacy_test/test_strided_slice_op.py | 4 +- test/legacy_test/test_tensor_register_hook.py | 3 +- 
test/legacy_test/test_translated_layer.py | 3 +- .../test_zero_dim_sundry_dygraph_api.py | 2 +- .../test_zero_dim_sundry_static_api_part1.py | 92 +-- .../test_zero_dim_sundry_static_api_part3.py | 29 +- test/mkldnn/CMakeLists.txt | 2 +- test/prim/process/test_prim_amp.py | 12 +- test/quantization/CMakeLists.txt | 4 +- test/sot/test_model_switch_training.py | 2 +- test/sot/test_segment_linear.py | 2 +- test/sot/test_sot_cost_model.py | 2 +- third_party/flashattn | 2 +- 264 files changed, 7793 insertions(+), 5088 deletions(-) delete mode 100644 paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc delete mode 100644 paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h create mode 100644 paddle/fluid/operators/chunk_eval_op.cc create mode 100644 paddle/fluid/operators/chunk_eval_op.h create mode 100644 paddle/fluid/operators/ctc_align_op.cc create mode 100644 paddle/fluid/operators/ctc_align_op.cu create mode 100644 paddle/fluid/operators/ctc_align_op.h delete mode 100644 paddle/fluid/operators/cudnn_lstm_op.cc delete mode 100644 paddle/fluid/operators/fused/attention_layer_norm.h delete mode 100644 paddle/fluid/operators/fused/fmha_ref.h delete mode 100644 paddle/fluid/operators/fused/xpu_fused_common_function.h delete mode 100644 paddle/fluid/operators/fused/yolo_box_head_op.cc delete mode 100644 paddle/fluid/operators/lookup_table_v2_op.cu delete mode 100644 paddle/fluid/operators/lookup_table_v2_op.h delete mode 100644 paddle/fluid/operators/math/sample_prob.cc delete mode 100644 paddle/fluid/operators/math/sample_prob.cu delete mode 100644 paddle/fluid/operators/math/sample_prob.h delete mode 100644 paddle/fluid/operators/math/sampler.cc delete mode 100644 paddle/fluid/operators/math/sampler.h delete mode 100644 paddle/fluid/operators/moe_op.cc delete mode 100644 paddle/fluid/operators/ops_signature/cudnn_lstm_sig.cc rename paddle/{fluid/operators/fused/yolo_box_head_op.cu => phi/kernels/gpu/yolo_box_head_kernel.cu} (57%) create mode 100644 test/cpp/inference/test_cases.cmake create mode 100644 test/deprecated/ir/pir/translator/test_global_gather_translator.py create mode 100644 test/deprecated/legacy_test/test_zero_dim_sundry_static_api_deprecated.py create mode 100644 test/ir/pir/translator/test_pull_box_sparse_translator.py create mode 100644 test/legacy_test/test_chunk_eval_op.py create mode 100644 test/legacy_test/test_ctc_align.py rename test/legacy_test/{test_jit_save_load.py => test_jit_save_load_rename.py} (85%) rename test/{deprecated => }/legacy_test/test_zero_dim_sundry_dygraph_api.py (99%) rename test/{deprecated => }/legacy_test/test_zero_dim_sundry_static_api_part1.py (88%) rename test/{deprecated => }/legacy_test/test_zero_dim_sundry_static_api_part3.py (94%) diff --git a/.editorconfig b/.editorconfig index 7c31ee5239836d..b7b333f7b76864 100644 --- a/.editorconfig +++ b/.editorconfig @@ -15,15 +15,12 @@ insert_final_newline = true [*.{c,cc,cxx,cpp,cu,cuh,h,hpp,hxx,kps}] indent_size = 2 -[*.{py,java,r}] +[*.{py,pyi,java,r,toml}] indent_size = 4 [Dockerfile.*] indent_size = 4 -[.flake8] -indent_size = 4 - [*.go] indent_style = tab indent_size = 4 diff --git a/CMakeLists.txt b/CMakeLists.txt index d3e1e3fa3ea5d6..0aa41a26d700e2 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,6 +68,7 @@ option(WITH_PIP_CUDA_LIBRARIES "Paddle uses the CUDA library provided by NVIDIA" OFF) option(WITH_NIGHTLY_BUILD "Compile nightly paddle whl package of the develop branch" OFF) +option(WITH_CPP_TEST "Compile PaddlePaddle skip cpp test" ON) 
find_package(Git REQUIRED) # config GIT_URL with github mirrors to speed up dependent repos clone diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index 2cecc5bd052bc5..ec9ac943c7cce1 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -116,14 +116,14 @@ void GroupOp::Print(pir::IrPrinter& printer) { } bool GroupOp::InferSymbolicShape( - ::pir::ShapeConstraintIRAnalysis* shape_analysis) { - ::pir::InferSymExprForBlock(*block(), shape_analysis); + ::pir::InferSymbolicShapeContext* infer_context) { + ::pir::InferSymExprForBlock(*block(), infer_context); for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) { auto inner_yield_value = block()->back().operand_source(rst_idx); const auto& shape = - shape_analysis->GetShapeOrDataForValue(inner_yield_value); - shape_analysis->SetShapeOrDataForValue(result(rst_idx), shape); + infer_context->GetShapeOrDataForValue(inner_yield_value); + infer_context->SetShapeOrDataForValue(result(rst_idx), shape); } if (VLOG_IS_ON(4)) { @@ -204,16 +204,16 @@ void YieldStoreOp::Build(pir::Builder& builder, void YieldStoreOp::VerifySig() {} bool YieldStoreOp::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis* shape_analysis) { - shape_analysis->SetShapeOrDataForValue( - result(0), shape_analysis->GetShapeOrDataForValue(operand_source(0))); + pir::InferSymbolicShapeContext* infer_context) { + infer_context->SetShapeOrDataForValue( + result(0), infer_context->GetShapeOrDataForValue(operand_source(0))); return true; } bool ConcatOp::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::InferSymbolicShapeContext* infer_context) { VLOG(4) << "Infer symbolic shape for cinn_op.concat"; - return ConcatOpInferSymbolicShape(this->operation(), shape_analysis); + return ConcatOpInferSymbolicShape(this->operation(), infer_context); } void ConcatOp::Build(pir::Builder& builder, // NOLINT @@ -476,7 +476,7 @@ GenerateShapeOp::ConvertAttributeToSymbolBindings( } bool GenerateShapeOp::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::InferSymbolicShapeContext* infer_context) { const auto attr_dim_exprs = [&] { std::vector dim_exprs{}; pir::Attribute dim_expr_attr = this->attributes().at("output_dim_exprs"); @@ -505,7 +505,7 @@ bool GenerateShapeOp::InferSymbolicShape( }(); auto DimExprs4InputDim = [&](int input_idx) -> const symbol::ShapeOrDataDimExprs& { - return shape_analysis->GetShapeOrDataForValue( + return infer_context->GetShapeOrDataForValue( this->operand_source(input_idx)); }; auto DimExprs4SymbolName = @@ -527,7 +527,7 @@ bool GenerateShapeOp::InferSymbolicShape( symbol::ShapeOrDataDimExprs shape_or_data_dim_exprs{ symbol::TensorShapeOrDataDimExprs(shape, substituted_dim_exprs)}; - shape_analysis->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs); + infer_context->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs); return true; } diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 34c53ed2ebe6ba..396f9929ecb35d 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -53,7 +53,7 @@ class IR_API GroupOp pir::Block *block() const; std::vector GetOperators() const; - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); void VerifySig(); 
void Print(pir::IrPrinter &printer); // NOLINT @@ -102,7 +102,7 @@ class IR_API YieldStoreOp void VerifySig(); - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class IR_API ConcatOp @@ -123,7 +123,7 @@ class IR_API ConcatOp void VerifySig() const {} - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class IR_API SplitOp : public pir::Op { @@ -177,7 +177,7 @@ class IR_API GenerateShapeOp pir::Value out() { return result(0); } - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); static pir::Attribute ConvertSymbolBindingsToAttribute( pir::Builder &builder, const SymbolBindings &symbol_bindings); // NOLINT diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 8f76fd92d7084d..d653c5a9affb4e 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -220,14 +220,10 @@ void ApplyCinnPass(::pir::Program* program, ApplyPdToCinnPass(program, CreatePassManager); ApplyCinnPreprocessPass(program, CreatePassManager); ApplyBuildGroupOpPass(program, CreatePassManager); - LOG(INFO) << "====[pir-to-py-code group-ops begin]===" << std::endl - << PirToPyCodeConverter().Convert(*program); - LOG(INFO) << "====[pir-to-py-code group-ops end]==="; + PirToPyCodeConverter().SaveIfFlagEnabled("group_op_programs", *program); ApplyGroupOpPass(program, CreatePassManager); ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager); - LOG(INFO) << "====[pir-to-py-code fusion-ops begin]===" << std::endl - << PirToPyCodeConverter().Convert(*program); - LOG(INFO) << "====[pir-to-py-code fusion-ops end]==="; + PirToPyCodeConverter().SaveIfFlagEnabled("fusion_op_programs", *program); LOG(INFO) << "FusionOp count before lowering : *****[ " << GetOpCount(program->module_op()) << " ]*****"; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc index d66943dfc8bf93..7e4bf74065fbb8 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc @@ -45,11 +45,9 @@ class AddYieldStoreInFusionOpPattern auto orignal_base = op->operand_source(i); op->operand(i).set_source(store_op.result(0)); - if (shape_analysis.HasShapeOrDataForValue(orignal_base)) { - shape_analysis.SetShapeOrDataForValue( - store_op.result(0), - shape_analysis.GetShapeOrDataForValue(orignal_base)); - } + shape_analysis.SetShapeOrDataForValue( + store_op.result(0), + shape_analysis.GetShapeOrDataForValue(orignal_base)); } return true; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc index b4ac8265646ef0..89775c658e2fa1 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc @@ -144,7 +144,10 @@ class BlockDimExprsAsserter { PADDLE_THROW(phi::errors::Unimplemented( op->name() + " DOES NOT have InferSymbolicShapeInterface!")); } else { - bool infer_result = 
interface.InferSymbolicShape(shape_analysis.get()); + // TODO(Hongqing-work): delete this after the shape analysis reconstruct + // is done. + bool infer_result = interface.InferSymbolicShape( + shape_analysis->GetInferSymbolicShapeContext()); PADDLE_ENFORCE_EQ(infer_result, true, ::common::errors::PreconditionNotMet( diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc index 0ab46ce44f4f1c..c3bf60c601b7d5 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc @@ -182,11 +182,9 @@ ::pir::GroupOpsVec CloneOps( pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); for (size_t i = 0; i < op->num_results(); ++i) { - if (shape_analysis.HasShapeOrDataForValue(op->result(i))) { - shape_analysis.SetShapeOrDataForValue( - new_op->result(i), - shape_analysis.GetShapeOrDataForValue(op->result(i))); - } + shape_analysis.SetShapeOrDataForValue( + new_op->result(i), + shape_analysis.GetShapeOrDataForValue(op->result(i))); } vec_new_op_list.push_back(new_op); @@ -357,11 +355,9 @@ class CinnGroupClusterPattern // update ir mapping for (size_t i = 0; i < output_values.size(); ++i) { ir_mapping.Add(output_values[i], new_group_op->result(i)); - if (shape_analysis.HasShapeOrDataForValue(output_values[i])) { - shape_analysis.SetShapeOrDataForValue( - new_group_op->result(i), - shape_analysis.GetShapeOrDataForValue(output_values[i])); - } + shape_analysis.SetShapeOrDataForValue( + new_group_op->result(i), + shape_analysis.GetShapeOrDataForValue(output_values[i])); } for (size_t i = 0; i < output_values.size(); ++i) { auto find_it = all_output_values.find(output_values[i]); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc index 2bebdf4c2149fb..b45323d45ddcb3 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc @@ -33,18 +33,16 @@ bool ReplaceOpWithReshapeOp(pir::Operation* op, std::vector shape = phi::vectorize( output.type().dyn_cast().dims()); - if (shape_analysis->HasShapeOrDataForValue(op->result(0))) { - const auto& shape_info = - shape_analysis->GetShapeOrDataForValue(op->result(0)).shape(); - int temp_dim = -1; - - for (size_t i = 0; i < shape_info.size(); ++i) { - if (shape_info[i].isa()) { - shape[i] = shape_info[i].Get(); - } else { - shape[i] = temp_dim; - temp_dim = 1; - } + const auto& shape_info = + shape_analysis->GetShapeOrDataForValue(op->result(0)).shape(); + int temp_dim = -1; + + for (size_t i = 0; i < shape_info.size(); ++i) { + if (shape_info[i].isa()) { + shape[i] = shape_info[i].Get(); + } else { + shape[i] = temp_dim; + temp_dim = 1; } } return shape; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.cc index 7d0a3d64246c3d..9b314724167e86 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.cc @@ -53,15 +53,11 @@ bool RemoveOp(pir::Operation* op, pir::PatternRewriter* rewriter) { if (has_dynamic_shape) { auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); - if 
(shape_analysis.HasShapeOrDataForValue(input) && - shape_analysis.HasShapeOrDataForValue(output)) { - auto input_sym_shape = - GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input)); - auto output_sym_shape = - GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output)); - return input_sym_shape == output_sym_shape; - } - return false; + auto input_sym_shape = + GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input)); + auto output_sym_shape = + GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output)); + return input_sym_shape == output_sym_shape; } return GetDims(input) == GetDims(output); }; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 11361d34300ef6..240604ae68934a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -214,7 +214,10 @@ void InferSymbolicShapeForSubgraph( auto infer_symbolic_shape_interface = op->dyn_cast(); if (infer_symbolic_shape_interface) { - infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis); + // TODO(Hongqing-work): delete this after the shape analysis reconstruct + // is done. + infer_symbolic_shape_interface.InferSymbolicShape( + shape_analysis->GetInferSymbolicShapeContext()); } else { PADDLE_THROW(phi::errors::Unimplemented( op->name() + " DOES NOT have InferSymbolicShapeInterface!")); @@ -348,7 +351,6 @@ bool ReplaceShapeOpsToGenerateShape( auto ShapeOrDataDimExprs4Value = [&shape_analysis]( pir::Value value) -> const symbol::ShapeOrDataDimExprs& { - CHECK(shape_analysis->HasShapeOrDataForValue(value)); return shape_analysis->GetShapeOrDataForValue(value); }; std::optional opt_generated_shape = diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc index 72219287fe3e33..f69a4f91153862 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc @@ -104,9 +104,6 @@ class DynamicToStaticConverter { } bool Convert() { - if (!IsSymbolFullyInfered()) { - return false; - } bool updated = false; VisitEachValue(fusion_op_, [&](pir::Value value) { updated |= UpdateValueShape(value); @@ -116,16 +113,6 @@ class DynamicToStaticConverter { } private: - bool IsSymbolFullyInfered() { - bool is_infered = true; - VisitEachValue(fusion_op_, [&](pir::Value value) { - if (!shape_analysis_->HasShapeOrDataForValue(value)) { - is_infered = false; - } - }); - return is_infered; - } - DimExpr4SymbolName InitDimExpr4SymbolName() { const auto* map = GetGlobalDynamicToStaticDimMap(); CHECK(map->has_value()); @@ -178,7 +165,6 @@ class DynamicToStaticConverter { bool UpdateValueShape(pir::Value value) { bool update = false; - CHECK(shape_analysis_->HasShapeOrDataForValue(value)); const auto& origin_shape = GetOriginValueShape(value); const auto& target_shape = GetTargetValueShape(value); PADDLE_ENFORCE_EQ( diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc index 
e20cab270cdd34..3cf1741e47d173 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc @@ -150,7 +150,6 @@ struct StaticDimToDynamicConverter { &pir::ShapeAnalysisManager::Instance().Get( this->fusion_op->GetParentProgram()); ForEachValue([&](pir::Value value) { - CHECK(shape_analysis->HasShapeOrDataForValue(value)); const auto& origin_shape = GetOriginValueShape(value); const auto& target_shape = GetTargetValueShape( shape_analysis->GetShapeOrDataForValue(value).shape()); @@ -369,26 +368,8 @@ struct StaticDimToDynamicConverter { pir::Value value, int64_t constant, const std::string& symbol) { - if (shape_analysis->HasShapeOrDataForValue(value)) { - const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape(); - return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol); - } else { - auto& dims = value.type().dyn_cast<::pir::DenseTensorType>().dims(); - const auto& int_dims = ::common::vectorize(dims); - std::vector old{}; - for (int dim : int_dims) { - old.emplace_back(static_cast(dim)); - } - const auto& opt_exprs = - ConvertShapeOrDataDimExprs(Converter, old, constant, symbol); - if (opt_exprs.has_value()) { - return opt_exprs.value(); - } else { - return symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(old)}; - } - } - PADDLE_THROW(phi::errors::Fatal("Dead code")); + const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape(); + return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol); } template diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc index 5d3baeb21f92ae..2f43c7239ec67f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc @@ -101,20 +101,14 @@ void SimplifyDimExpr(pir::Operation* module_op) { VisitEachOp(module_op, [&](pir::Operation& op) { VisitEachValue(op, [&](pir::Value value) { - if (!shape_analysis->HasShapeOrDataForValue(value)) { - VLOG(4) << "SimplifyDimExpr: shape_analysis can't find ShapeOrData for " - "value of the op:" - << op.name(); - } else { - const symbol::ShapeOrDataDimExprs& shape_or_data = - shape_analysis->GetShapeOrDataForValue(value); - VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data; - symbol::ShapeOrDataDimExprs simplified_shape_or_data = - SimplifyShapeOrData(shape_or_data); - VLOG(8) << op.name() - << " simplified_shape_or_data: " << simplified_shape_or_data; - shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data); - } + const symbol::ShapeOrDataDimExprs& shape_or_data = + shape_analysis->GetShapeOrDataForValue(value); + VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data; + symbol::ShapeOrDataDimExprs simplified_shape_or_data = + SimplifyShapeOrData(shape_or_data); + VLOG(8) << op.name() + << " simplified_shape_or_data: " << simplified_shape_or_data; + shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data); }); if (op.num_results() > 0) { pir::shape::SetShapeAttrForOp( diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc index f859c09400c165..85a6a9c0677a03 100644 --- 
a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc @@ -64,15 +64,9 @@ class FusionOpPattern : public pir::OpRewritePattern { for (size_t i = 0; i < fusion_op.num_results(); ++i) { rewriter.ReplaceAllUsesWith(fusion_op.result(i), paddle_op.value()->result(i)); - if (shape_analysis.HasShapeOrDataForValue(fusion_op.result(i))) { - shape_analysis.SetShapeOrDataForValue( - paddle_op.value()->result(i), - shape_analysis.GetShapeOrDataForValue(fusion_op.result(i))); - } else { - LOG(WARNING) << "No shape_data for " - << fusion_op.result(i).defining_op()->name() << "_result_" - << i << ", this may cause error in dynamic shape"; - } + shape_analysis.SetShapeOrDataForValue( + paddle_op.value()->result(i), + shape_analysis.GetShapeOrDataForValue(fusion_op.result(i))); } rewriter.EraseOp(fusion_op); @@ -129,7 +123,7 @@ class FusionOpPattern : public pir::OpRewritePattern { pir::PatternRewriter& rewriter) const { // NOLINT auto it = op_handler_map().find(op->name()); if (it == op_handler_map().end()) { - LOG(WARNING) << "No fallback handler for op: " << op->name(); + VLOG(4) << "No fallback handler for op: " << op->name(); return std::nullopt; } return (this->*(it->second))(op, rewriter); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc index 83d3cdce2173aa..56a2aa07d70969 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc @@ -46,7 +46,10 @@ pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter, for (auto* op : std::vector{x_shape_op, y_shape_op, shape_broadcast_op}) { auto infer_symbolic_shape_interface = op->dyn_cast(); - infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis); + // TODO(Hongqing-work): delete this after the shape analysis reconstruct is + // done. 
+ infer_symbolic_shape_interface.InferSymbolicShape( + shape_analysis->GetInferSymbolicShapeContext()); } return shape_broadcast_op->result(0); } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc index 7526ad1ab63090..80a7d819f86127 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc @@ -84,9 +84,6 @@ CollectSubstituteDimExprMap( std::unordered_set base_dim_expr_set; VisitEachInputValue(group, [&](::pir::Value value) { - if (!shape_analysis.HasShapeOrDataForValue(value)) { - return; - } auto& shape_or_data = shape_analysis.GetShapeOrDataForValue(value); VisitEachDimExpr(shape_or_data, [&](const symbol::DimExpr& dim_expr) { if (IsComplicatedDimExpr(dim_expr) && @@ -146,11 +143,11 @@ symbol::ShapeOrDataDimExprs TrySubstitute( } void InferSymbolicShapeForOperation( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { auto infer_symbolic_shape_interface = op->dyn_cast(); if (infer_symbolic_shape_interface) { - infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis); + infer_symbolic_shape_interface.InferSymbolicShape(infer_context); } else { PADDLE_THROW(phi::errors::Unimplemented( op->name() + " DOES NOT have InferSymbolicShapeInterface!")); @@ -164,8 +161,7 @@ GetGroupValue2Shape(const OpLoweringGroupPtr& group, for (auto op : group->ops()) { for (size_t i = 0; i < op->num_operands(); ++i) { auto operand = op->operand_source(i); - if (operand && value2shape.find(operand) == value2shape.end() && - shape_analysis.HasShapeOrDataForValue(operand)) { + if (operand && value2shape.find(operand) == value2shape.end()) { VLOG(6) << "Add value_to_shape_or_data_exprs for " << operand.impl(); value2shape.insert( {operand, shape_analysis.GetShapeOrDataForValue(operand)}); @@ -173,8 +169,7 @@ GetGroupValue2Shape(const OpLoweringGroupPtr& group, } for (size_t i = 0; i < op->num_results(); ++i) { auto result = op->result(i); - if (result && value2shape.find(result) == value2shape.end() && - shape_analysis.HasShapeOrDataForValue(result)) { + if (result && value2shape.find(result) == value2shape.end()) { VLOG(6) << "Add value_to_shape_or_data_exprs for " << result.impl(); value2shape.insert( {result, shape_analysis.GetShapeOrDataForValue(result)}); @@ -212,11 +207,13 @@ CreateGroupShapeOrDataExprs( // process the result values of each op. for (auto* op : group->ops()) { - InferSymbolicShapeForOperation(op, &local_shape_analysis); + // TODO(Hongqing-work): delete this after the shape analysis reconstruct is + // done. 
+ InferSymbolicShapeForOperation( + op, local_shape_analysis.GetInferSymbolicShapeContext()); for (size_t i = 0; i < op->num_results(); ++i) { auto result = op->result(i); - if (result && !value2shape.count(result) && - local_shape_analysis.HasShapeOrDataForValue(result)) { + if (result && !value2shape.count(result)) { VLOG(6) << "Add value_to_shape_or_data_exprs for " << result.impl(); value2shape.insert( {result, local_shape_analysis.GetShapeOrDataForValue(result)}); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc index 326b2126758ed0..a32fd54571977f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc @@ -89,15 +89,9 @@ class FusionOpPattern : public pir::OpRewritePattern { for (size_t i = 0; i < fusion_op.num_results(); ++i) { rewriter.ReplaceAllUsesWith(fusion_op.result(i), compiled_op->result(i)); - if (shape_analysis.HasShapeOrDataForValue(fusion_op.result(i))) { - shape_analysis.SetShapeOrDataForValue( - compiled_op->result(i), - shape_analysis.GetShapeOrDataForValue(fusion_op.result(i))); - } else { - LOG(WARNING) << "No shape_data for " - << fusion_op.result(i).defining_op()->name() << "_result_" - << i; - } + shape_analysis.SetShapeOrDataForValue( + compiled_op->result(i), + shape_analysis.GetShapeOrDataForValue(fusion_op.result(i))); } rewriter.EraseOp(fusion_op); return true; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index f11d66e1911f8c..2bddee42493594 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -569,9 +569,6 @@ class SplitOpPattern : public pir::OpRewritePattern { pir::PatternRewriter &rewriter) const { // NOLINT const int axis = GetAxis(split); const std::vector §ions = GetSections(split); - for (auto section : sections) { - VLOG(0) << " " << section; - } const int index = slice->attribute<::pir::Int32Attribute>("index").data(); int64_t start = std::accumulate(sections.begin(), sections.begin() + index, 0); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc index 09493545652690..32f6d67d75d468 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc @@ -15,6 +15,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h" #include #include +#include #include #include #include @@ -23,6 +24,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h" #include "paddle/common/adt_type_id.h" #include "paddle/common/ddim.h" +#include "paddle/common/flags.h" #include "paddle/common/overloaded.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" @@ -31,6 +33,7 @@ #include "paddle/pir/include/core/program.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" #include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" +COMMON_DECLARE_string(logging_pir_py_code_dir); namespace cinn::dialect::ir { @@ -79,8 +82,27 @@ struct OpPyCode { constexpr int 
kDefaultIndentSize = 2; +namespace { + +int64_t GetAutoIncrementalId() { + static std::atomic seq_no(0); + return seq_no++; +} + +} // namespace + struct PirToPyCodeConverterHelper { - PirToPyCodeConverterHelper() : indent_size_(kDefaultIndentSize) {} + explicit PirToPyCodeConverterHelper(const pir::Program* program) + : program_(program), + indent_size_(kDefaultIndentSize), + seq_no_(GetAutoIncrementalId()) {} + + std::string Convert() { return Convert(*program_); } + + private: + const pir::Program* program_; + const int indent_size_; + int64_t seq_no_; std::string Convert(const pir::Program& program) { auto istrings = ConvertMethodsToPyClass(program.module_op(), [&]() { @@ -92,7 +114,6 @@ struct PirToPyCodeConverterHelper { return ConvertIStringsToString(istrings); } - private: IStrings DefineInit(const pir::ModuleOp& module) { IStrings def_init; def_init.push_back(IString("def __init__(self):")); @@ -562,7 +583,7 @@ struct PirToPyCodeConverterHelper { std::string operator()(AdtTypeId<::pir::VectorType>) { std::stringstream ss; - const auto& name = ::pir::DenseTensorType::name(); + const auto& name = ::pir::VectorType::name(); const auto& vec_type = type.dyn_cast<::pir::VectorType>(); ss << "self." << name << "("; for (int i = 0; i < vec_type.size(); ++i) { @@ -783,15 +804,15 @@ struct PirToPyCodeConverterHelper { IStrings ret; { std::stringstream ss; - ss << "class " << GetPyClassName(module) << ":"; + ss << "class " << GetPyClassName() << ":"; ret.push_back(IString(ss.str())); } PushBackIndented(&ret, GetBody()); return ret; } - std::string GetPyClassName(const pir::ModuleOp& module) { - return std::string("Program"); + std::string GetPyClassName() { + return std::string("PirProgram_") + std::to_string(seq_no_); } std::string ConvertIStringsToString(const IStrings& istrings) { @@ -819,14 +840,29 @@ struct PirToPyCodeConverterHelper { ret->push_back(Indent(istring)); } } - - const int indent_size_; }; } // namespace -std::string PirToPyCodeConverter::Convert(const pir::Program& program) const { - return PirToPyCodeConverterHelper().Convert(program); +void PirToPyCodeConverter::SaveIfFlagEnabled( + const std::string& tag, const pir::Program& program) const { + if (FLAGS_logging_pir_py_code_dir == "") return; + const std::string file_path = + FLAGS_logging_pir_py_code_dir + "/" + tag + ".py"; + const std::string content = PirToPyCodeConverterHelper(&program).Convert(); + static std::mutex mutex; + std::unique_lock lock(mutex); + static std::unordered_map once_flags; + std::call_once(once_flags[file_path], [&] { + std::ofstream ofs; + ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc); + ofs.close(); + }); + std::ofstream ofs; + ofs.open(file_path.c_str(), std::ios::out | std::ios::app); + if (!ofs.is_open()) return; + ofs << content << std::endl; + ofs.close(); } } // namespace cinn::dialect::ir diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h index e6c22badd4c85a..bbb36acd526a6d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h @@ -30,7 +30,8 @@ class PirToPyCodeConverter { PirToPyCodeConverter(const PirToPyCodeConverter&) = delete; PirToPyCodeConverter(PirToPyCodeConverter&&) = delete; - std::string Convert(const pir::Program& program) const; + void SaveIfFlagEnabled(const std::string& tag, + const pir::Program& program) const; }; } // namespace 
cinn::dialect::ir diff --git a/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc index 3690a91eb4d370..3adf8cc6110ec8 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc @@ -52,15 +52,12 @@ class DynamicExpandOpPattern const auto& GetOutputShapeByDimExpr = [&]() -> std::vector { std::vector out_shape(out_rank, -1); - if (shape_analysis.HasShapeOrDataForValue(op->result(0))) { - VLOG(3) << "found shape dialect"; - auto shape_info = - shape_analysis.GetShapeOrDataForValue(op->result(0)).shape(); - - for (size_t i = 0; i < shape_info.size(); ++i) { - if (shape_info[i].isa()) { - out_shape[i] = shape_info[i].Get(); - } + auto shape_info = + shape_analysis.GetShapeOrDataForValue(op->result(0)).shape(); + + for (size_t i = 0; i < shape_info.size(); ++i) { + if (shape_info[i].isa()) { + out_shape[i] = shape_info[i].Get(); } } return out_shape; @@ -74,8 +71,6 @@ class DynamicExpandOpPattern auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); - CHECK(shape_analysis.HasShapeOrDataForValue(op.result(0))) - << "Can't find DimExpr for output of reshape in shape_analysis."; shape_analysis.SetShapeOrDataForValue( broadcast->result(0), shape_analysis.GetShapeOrDataForValue(op.result(0))); diff --git a/paddle/cinn/hlir/framework/pir/fusion_info.cc b/paddle/cinn/hlir/framework/pir/fusion_info.cc index 58ce8febfa3302..db9f111b0ef07c 100644 --- a/paddle/cinn/hlir/framework/pir/fusion_info.cc +++ b/paddle/cinn/hlir/framework/pir/fusion_info.cc @@ -164,11 +164,6 @@ void FusionInfo::ParseInputDimExprs(const OpLoweringGroup& group) { [&](const ::pir::Value& value) -> bool { auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get(group.GetParentProgram()); - if (!shape_analysis.HasShapeOrDataForValue(value)) { - VLOG(4) << "FusionInfo: input value doesn't have shape or data, skip it." 
- << value.impl(); - return false; - } input_dim_exprs_.push_back(shape_analysis.GetShapeOrDataForValue(value)); return true; }; diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc index b08ab2e2c94f57..d041748f7e960c 100644 --- a/paddle/cinn/hlir/framework/pir/utils.cc +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -287,51 +287,6 @@ bool IsSmallNumelOp(const ::pir::Operation& op) { return (0 <= max_value_numel && max_value_numel < 32); } -bool IsShapeComputeOp(const ::pir::Operation& op) { - auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get( - op.GetParent()->parent_program()); - if (op.num_operands() == 0) { - return false; - } - bool all_input_has_shape_data = true; - for (uint32_t i = 0; i < op.num_operands(); ++i) { - if (shape_analysis.HasShapeOrDataForValue(op.operand_source(i))) { - const auto& shape_expr = - shape_analysis.GetShapeOrDataForValue(op.operand_source(i)); - if (shape_expr.isa() && - shape_expr.data()) { // has shape data - continue; - } - } - all_input_has_shape_data = false; - break; - } - - for (uint32_t i = 0; i < op.num_results(); ++i) { - if (shape_analysis.HasShapeOrDataForValue(op.result(i))) { - const auto& shape_expr = - shape_analysis.GetShapeOrDataForValue(op.result(i)); - if (shape_expr.isa() && - shape_expr.data()) { // has shape data - continue; - } - } - all_input_has_shape_data = false; - break; - } - - return all_input_has_shape_data; -} - -// TODO(zyfncg): This function is a temporary solution, we need to remove it in -// the future. -bool IsTempDenySpecialOp(const ::pir::Operation& op) { - if (op.name() == "cinn_op.generate_shape") { - return false; - } - return IsShapeComputeOp(op); -} - // Mainly used for pd_to_cinn_pass and reused in IsSupportInCinn function. 
bool IsDeniedInCinn(const ::pir::Operation& op) { if (FLAGS_disable_dyshape_in_train && HaveUnkDim(op)) { diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc index 262922a7ef7b97..6bd7513da39d71 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc @@ -18,7 +18,6 @@ #include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h" -#include "paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h" diff --git a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt index b6a2f067606468..f92d2caa966c2d 100644 --- a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt +++ b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt @@ -6,5 +6,4 @@ gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc) gather_srcs(cinnapi_src SRCS optimize_reduction_tactic.cc) gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc) gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc) -gather_srcs(cinnapi_src SRCS loop_reorder_alignment_tactic.cc) gather_srcs(cinnapi_src SRCS tile_first_general_tactic.cc) diff --git a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc deleted file mode 100644 index 8bf8a98cce2514..00000000000000 --- a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) 2024 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h" -#include -#include -#include "paddle/cinn/ir/ir.h" - -namespace cinn { -namespace ir { - -class LoopReorderAlignmentTactic final : public ScheduleTactic { - public: - void Init(ScheduleContext* context) override; - - void Apply(ir::IRSchedule* sch, const std::string& block_id) override; - - std::string TacticName() const override { - return "LoopReorderAlignmentTactic"; - } - - private: - bool NeedReorderLoops(); - - std::vector GetNewOrder(); - - void UpdateBaseRank(ir::IRSchedule* sch, const std::string& block_id); - - void DoReorder(ir::IRSchedule* sch, const std::string& block_id); - - private: - ScheduleContext* context_; - size_t base_rank_; - bool need_reorder_loops_; - std::vector new_order_; -}; - -void LoopReorderAlignmentTactic::Init(ScheduleContext* context) { - context_ = context; - base_rank_ = 0; - need_reorder_loops_ = NeedReorderLoops(); - new_order_ = GetNewOrder(); -} - -void LoopReorderAlignmentTactic::Apply(ir::IRSchedule* sch, - const std::string& block_id) { - if (!ir::IsReduceInitTensorName(block_id)) { - UpdateBaseRank(sch, block_id); - } - - if (need_reorder_loops_ && !ir::IsReduceInitTensorName(block_id)) { - DoReorder(sch, block_id); - } -} - -void LoopReorderAlignmentTactic::UpdateBaseRank(ir::IRSchedule* sch, - const std::string& block_id) { - auto loops = sch->GetLoops(block_id); - if (base_rank_ == 0) { - base_rank_ = loops.size(); - } else { - if (base_rank_ != loops.size()) { - throw std::runtime_error("loops rank not same "); - } - } -} - -bool LoopReorderAlignmentTactic::NeedReorderLoops() { - const auto HasReduceAxis = [&]() { - return context_->config.base_info->reduce_axis.size() > 0; - }; - if (!HasReduceAxis()) { - return false; - } - - const auto HasNonLastDimReduce = [&]() { - std::vector vec_reduce_axis = - context_->config.base_info->reduce_axis; - std::sort(vec_reduce_axis.begin(), vec_reduce_axis.end()); - return vec_reduce_axis.front() != - context_->config.base_info->data_rank - vec_reduce_axis.size(); - }; - - return HasNonLastDimReduce(); -} - -std::vector LoopReorderAlignmentTactic::GetNewOrder() { - std::set reduce_set(context_->config.base_info->reduce_axis.begin(), - context_->config.base_info->reduce_axis.end()); - - std::vector new_order; - for (int32_t i = 0; i < context_->config.base_info->data_rank; ++i) { - if (!reduce_set.count(i)) { - new_order.push_back(i); - } - } - for (auto axis : context_->config.base_info->reduce_axis) { - new_order.push_back(axis); - } - - return new_order; -} - -void LoopReorderAlignmentTactic::DoReorder(ir::IRSchedule* sch, - const std::string& block_id) { - const auto IsReduceBlock = [&](const std::string& block_id) { - return context_->config.base_info->reduce_tensor_names.count(block_id) > 0; - }; - if (IsReduceBlock(block_id)) { - return; - } - - sch->Reorder(block_id, new_order_); -} - -std::unique_ptr CreateLoopReorderAlignmentTactic() { - return std::make_unique(); -} - -} // namespace ir -} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h b/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h deleted file mode 100644 index ee4864a5ecf926..00000000000000 --- a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2024 CINN Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h" - -namespace cinn { -namespace ir { - -std::unique_ptr CreateLoopReorderAlignmentTactic(); - -} // namespace ir -} // namespace cinn diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index ac4e8e7f0696de..5de24f5dbfb20f 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -1435,6 +1435,10 @@ PHI_DEFINE_EXPORTED_bool(enable_pir_with_pt_in_dy2st, true, "Enable new IR in executor"); +PHI_DEFINE_EXPORTED_string(logging_pir_py_code_dir, + "", + "the logging directory to save pir py code"); + /** * Using PIR API in Python * Name: enable_pir_api diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index a49dc15199d8b5..729cf467ea7675 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -162,7 +162,7 @@ struct ConcatTensorsForAllReduce { void operator()(const DeviceContext &context, const std::vector &dense_tensors_, Tensor *p_dense_contents) { - operators::math::ConcatFunctor concat_functor_; + phi::funcs::ConcatFunctor concat_functor_; concat_functor_( context, dense_tensors_, @@ -191,7 +191,7 @@ struct SplitTensorsForAllReduce { shape_refer.emplace_back(&tensor); } - operators::math::SplitFunctor split_functor_; + phi::funcs::SplitFunctor split_functor_; split_functor_(context, *in, shape_refer, 0, &outs); } }; diff --git a/paddle/fluid/distributed/collective/reducer.h b/paddle/fluid/distributed/collective/reducer.h index c16b194ac9c073..661675c449117e 100644 --- a/paddle/fluid/distributed/collective/reducer.h +++ b/paddle/fluid/distributed/collective/reducer.h @@ -22,12 +22,12 @@ #include "paddle/fluid/eager/api/utils/hook_utils.h" #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/utils.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/phi/api/include/api.h" #include "paddle/phi/api/include/fused_api.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/common/data_type.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/utils/string/string_helper.h" diff --git a/paddle/fluid/distributed/index_dataset/index_sampler.h b/paddle/fluid/distributed/index_dataset/index_sampler.h index e8fbf39ce9341b..f32cd62445d40c 100644 --- a/paddle/fluid/distributed/index_dataset/index_sampler.h +++ b/paddle/fluid/distributed/index_dataset/index_sampler.h @@ -18,8 +18,8 @@ #include "paddle/fluid/distributed/index_dataset/index_wrapper.h" #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/kernels/funcs/math/sampler.h" namespace paddle { namespace distributed { @@ 
-107,9 +107,8 @@ class LayerWiseSampler : public IndexSampler { while (layer_index >= start_sample_layer_) { auto layer_codes = tree_->GetLayerCodes(layer_index); layer_ids_.push_back(tree_->GetNodes(layer_codes)); - auto sampler_temp = - std::make_shared( - layer_ids_[idx].size() - 1, seed_); + auto sampler_temp = std::make_shared( + layer_ids_[idx].size() - 1, seed_); sampler_vec_.push_back(sampler_temp); layer_index--; idx++; @@ -131,7 +130,7 @@ class LayerWiseSampler : public IndexSampler { std::shared_ptr tree_{nullptr}; int seed_{0}; int start_sample_layer_{1}; - std::vector> sampler_vec_; + std::vector> sampler_vec_; std::vector> layer_ids_; }; diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h index b6bdb28380736e..c6c24ae47a7d24 100644 --- a/paddle/fluid/eager/to_static/run_program_op_func.h +++ b/paddle/fluid/eager/to_static/run_program_op_func.h @@ -296,16 +296,7 @@ inline void pir_run_program_ad_func( grad_node->SetStepScope(step_scope); // just for set useable. - // Set Grad out rank as same as fwd input and set stop gradient to bwd - // NOTE(@xiongkun): Not every tensor in x(list of tensor) is required - // gradient. for example: x[1] is not used for output, the x[1] is ignored. - - std::vector x_require_grad; - for (size_t i = 0; i < x.size(); ++i) { - x_require_grad.push_back(&x[i]); - } - - grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0); + grad_node->SetGradOutMeta(x, /*slot id*/ 0); grad_node->SetGradOutMeta(params, /*slot id*/ 1); // TODO(@xiongkun): rewrite by new ir representation. diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 75d812bf66e5e2..853a0c445797c9 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -467,21 +467,16 @@ inline void PirRunProgramAPI( auto param_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp")); - auto *forward_global_block = - PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block")); - auto *backward_global_block = - PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block")); - - auto *forward_program = - forward_global_block->GetParentOp()->GetParentProgram(); + std::shared_ptr<::pir::Program> forward_program = PADDLE_GET_CONST( + std::shared_ptr<::pir::Program>, attrs.at("forward_program")); + std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST( + std::shared_ptr<::pir::Program>, attrs.at("backward_program")); if (FLAGS_print_ir) { std::ostringstream print_stream; print_stream << "ForwardProgram is :\n"; forward_program->Print(print_stream); if (!is_test) { - auto *backward_program = - backward_global_block->GetParentOp()->GetParentProgram(); print_stream << "BackwardProgram is:\n"; backward_program->Print(print_stream); } else { @@ -509,12 +504,12 @@ inline void PirRunProgramAPI( << program_id; // Step 1. share input_vars & parameters into scope details::ShareTensorsIntoScopeByValue( - forward_global_block, x, input_values, global_inner_scope); + forward_program->block(), x, input_values, global_inner_scope); details::ShareTensorsIntoScopeByValue( - forward_global_block, params, param_values, global_inner_scope); + forward_program->block(), params, param_values, global_inner_scope); // Step 2. 
create new interpretercore auto passed_kernel_program = - paddle::framework::ApplyIrPass(forward_program, place); + paddle::framework::ApplyIrPass(forward_program.get(), place); if (FLAGS_print_ir) { std::ostringstream print_stream; print_stream << "LoweredProgram( AfterPass ) is :\n"; @@ -535,22 +530,22 @@ inline void PirRunProgramAPI( // update interpretercore skip_gc_var auto skip_names = details::GetNameFromValue( - forward_global_block, middle_values, false, true); + forward_program->block(), middle_values, false, true); auto skip_names_set = std::set(skip_names.begin(), skip_names.end()); auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("no_need_buffers")); auto no_need_buffer_names = details::GetNameFromValue( - forward_global_block, no_need_buffer_values, false, true); + forward_program->block(), no_need_buffer_values, false, true); for (auto &name : no_need_buffer_names) { VLOG(4) << "Find no need buffer vars with name:" << name; skip_names_set.erase(name); } skip_names = details::GetNameFromValue( - forward_global_block, output_values, false, true); + forward_program->block(), output_values, false, true); skip_names_set.insert(skip_names.begin(), skip_names.end()); skip_names = details::GetNameFromValue( - forward_global_block, input_values, true, false); + forward_program->block(), input_values, true, false); skip_names_set.insert(skip_names.begin(), skip_names.end()); details::print_collection(skip_names_set); interpreter_core->SetSkipGcVars(skip_names_set); @@ -576,9 +571,9 @@ inline void PirRunProgramAPI( interpreter_core = cached_value.core_; // Step 2. update scope for cache interpretercore details::ShareTensorsIntoScopeByValue( - forward_global_block, x, input_values, global_inner_scope); + forward_program->block(), x, input_values, global_inner_scope); details::ShareTensorsIntoScopeByValue( - forward_global_block, params, param_values, global_inner_scope); + forward_program->block(), params, param_values, global_inner_scope); // TODO(xiongkun): new ir how to build scope. // if (interpreter_core->GetVariableScope()->GetMutableScope() != // global_inner_scope) { @@ -589,7 +584,7 @@ inline void PirRunProgramAPI( } // interpretercore run - if (!forward_global_block->empty()) { + if (!forward_program->block()->empty()) { paddle::platform::RecordEvent record_event( "interpreter_core_run", paddle::platform::TracerEventType::UserDefined, @@ -602,7 +597,7 @@ inline void PirRunProgramAPI( "fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1); // Get Output, and Middle Outputs details::ShareTensorsFromScopeByValue( - forward_global_block, out, output_values, global_inner_scope); + forward_program->block(), out, output_values, global_inner_scope); VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); @@ -1041,10 +1036,8 @@ inline void PirRunProgramGradAPI( VLOG(4) << "global_inner_scope:" << global_inner_scope; - auto *backward_global_block = - PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block")); - auto *backward_program = - backward_global_block->GetParentOp()->GetParentProgram(); + std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST( + std::shared_ptr<::pir::Program>, attrs.at("backward_program")); auto output_grad_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g")); @@ -1064,8 +1057,10 @@ inline void PirRunProgramGradAPI( details::Trans2ContiguousTensorsInplace(out_grad); // share x, param, middles, output_grads, out into scope. 
- details::ShareTensorsIntoScopeByValue( - backward_global_block, out_grad, output_grad_values, global_inner_scope); + details::ShareTensorsIntoScopeByValue(backward_program->block(), + out_grad, + output_grad_values, + global_inner_scope); auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = @@ -1082,7 +1077,7 @@ inline void PirRunProgramGradAPI( VLOG(2) << "No interpretercore cache, so create a new interpretercore"; // Step 1. share input_vars & parameters into scope auto passed_kernel_program = - paddle::framework::ApplyIrPass(backward_program, place); + paddle::framework::ApplyIrPass(backward_program.get(), place); const auto &new_block = passed_kernel_program->block(); passed_kernel_program = paddle::framework::ApplyRemoveShadowFeedPass( @@ -1124,10 +1119,10 @@ inline void PirRunProgramGradAPI( // get all eager gc vars std::set skip_eager_delete_vars; auto skip_names = details::GetNameFromValue( - backward_global_block, x_grad_values, false, true); + backward_program->block(), x_grad_values, false, true); skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end()); skip_names = details::GetNameFromValue( - backward_global_block, p_grad_values, false, true); + backward_program->block(), p_grad_values, false, true); skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end()); interpreter_core->SetSkipGcVars(skip_eager_delete_vars); cache.UpdateSkipEagerDeleteVars(program_id, @@ -1160,7 +1155,7 @@ inline void PirRunProgramGradAPI( } } - if (!backward_global_block->empty()) { + if (!backward_program->block()->empty()) { paddle::platform::RecordEvent record_event( "interpreter_core_run", paddle::platform::TracerEventType::UserDefined, @@ -1175,9 +1170,11 @@ inline void PirRunProgramGradAPI( "fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1); // Step 4. 
get outputs details::ShareTensorsFromScopeByValue( - backward_global_block, x_grad, x_grad_values, global_inner_scope); - details::ShareTensorsFromScopeByValue( - backward_global_block, params_grad, p_grad_values, global_inner_scope); + backward_program->block(), x_grad, x_grad_values, global_inner_scope); + details::ShareTensorsFromScopeByValue(backward_program->block(), + params_grad, + p_grad_values, + global_inner_scope); VLOG(4) << "after backward gc all vars"; global_inner_scope->SetCanReused(true); details::GcScope(global_inner_scope); @@ -1316,8 +1313,7 @@ class GradNodeRunProgram : public egr::GradNodeBase { if (x[i].is_dense_tensor()) { x_grad->emplace_back(std::make_shared()); } else if (x[i].is_selected_rows()) { - auto selected_row = std::make_shared(); - x_grad->emplace_back(selected_row); + x_grad->emplace_back(std::make_shared()); } x_grad->back().set_name(x_grad_names[i]); } @@ -1446,6 +1442,10 @@ class PirGradNodeRunProgram : public egr::GradNodeBase { VLOG(3) << "End Eager Backward Node: PirGradNodeRunProgram"; *executed_ = true; + egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&x_grad, + this->OutputMeta()[0]); + egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(¶ms_grad, + this->OutputMeta()[1]); return {x_grad, params_grad}; } diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc index 1659430d6216fc..4fa6480372739f 100644 --- a/paddle/fluid/eager/utils.cc +++ b/paddle/fluid/eager/utils.cc @@ -118,6 +118,18 @@ std::vector EagerUtils::nullable_autograd_meta( return metas; } +std::vector EagerUtils::nullable_autograd_meta( + const paddle::optional>& targets) { + std::vector metas; + if (targets.get_ptr() != nullptr) { + metas.reserve(targets.get_ptr()->size()); + for (const paddle::Tensor& t : (*(targets.get_ptr()))) { + metas.emplace_back(nullable_autograd_meta(t)); + } + } + return metas; +} + std::vector EagerUtils::nullable_autograd_meta( const std::vector& targets) { std::vector metas; diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index 147f7377508a7b..aa9c972d7fa200 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -148,6 +148,8 @@ class TEST_API EagerUtils { const paddle::optional& target); static std::vector nullable_autograd_meta( const std::vector& targets); + static std::vector nullable_autograd_meta( + const paddle::optional>& targets); static std::vector nullable_autograd_meta( const std::vector& targets); static AutogradMeta* unsafe_autograd_meta(const paddle::Tensor& target); diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 62459827d3c390..83da397e8a7cc4 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -671,7 +671,6 @@ if(WITH_DISTRIBUTE) glog index_sampler index_wrapper - sampler index_dataset_proto lod_rank_table framework_io diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 32c520711d978f..c131c939097052 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/phi/common/complex.h" #include "paddle/pir/include/core/block.h" +#include "paddle/pir/include/core/program.h" #include "paddle/pir/include/core/value.h" #include "paddle/utils/blank.h" @@ -977,6 +978,9 @@ struct SetAttrDescVisitor { void operator()(const std::vector &v) const { // just do nothing. 
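The op_desc.cc hunk here gives SetAttrDescVisitor a do-nothing overload for std::shared_ptr<::pir::Program>, matching the Attribute variant extension in the type_defs change that follows; the no-op body suggests program attributes are carried only at runtime and never serialized into the legacy OpDesc proto (that rationale is inferred, not stated). A self-contained sketch of the pattern, using std::variant in place of paddle::variant and a stand-in Program type:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <variant>
    #include <vector>

    struct Program {};  // stand-in for ::pir::Program

    using Attribute =
        std::variant<int, std::string, std::vector<int>, std::shared_ptr<Program>>;

    struct SerializeVisitor {
      void operator()(int v) const { std::cout << "int: " << v << "\n"; }
      void operator()(const std::string& v) const { std::cout << "str: " << v << "\n"; }
      void operator()(const std::vector<int>& v) const {
        std::cout << "ints: " << v.size() << "\n";
      }
      // Runtime-only attribute: intentionally write nothing.
      void operator()(const std::shared_ptr<Program>&) const {}
    };

    int main() {
      std::vector<Attribute> attrs{1, std::string("name"),
                                   std::make_shared<Program>()};
      for (const auto& a : attrs) std::visit(SerializeVisitor{}, a);
    }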
} + void operator()(const std::shared_ptr &v) const { + // just do nothing. + } void operator()(const std::vector &v) const { std::vector var_names; for (auto var : v) { diff --git a/paddle/fluid/framework/type_defs.cc b/paddle/fluid/framework/type_defs.cc index d8a6546ea718d6..6d350f1fe1e6c6 100644 --- a/paddle/fluid/framework/type_defs.cc +++ b/paddle/fluid/framework/type_defs.cc @@ -39,7 +39,8 @@ template class variant, ::pir::Block*, - std::vector<::pir::Value>>; + std::vector<::pir::Value>, + std::shared_ptr<::pir::Program>>; } // namespace paddle REGISTER_LOG_SIMPLY_STR(paddle::framework::AttributeMap); REGISTER_LOG_SIMPLY_STR(paddle::framework::Attribute); diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 61f133ceb082a8..919da606015551 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -27,6 +27,7 @@ limitations under the License. */ #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/pir/include/core/block.h" +#include "paddle/pir/include/core/program.h" #include "paddle/pir/include/core/value.h" #include "paddle/utils/blank.h" #include "paddle/utils/small_vector.h" @@ -67,7 +68,8 @@ using Attribute = paddle::variant, ::pir::Block*, - std::vector<::pir::Value>>; + std::vector<::pir::Value>, + std::shared_ptr<::pir::Program>>; using AttributeMap = std::unordered_map; using OpCreator = diff --git a/paddle/fluid/imperative/amp_utils.h b/paddle/fluid/imperative/amp_utils.h index 3b961e5960c816..14d3e64409f1c2 100644 --- a/paddle/fluid/imperative/amp_utils.h +++ b/paddle/fluid/imperative/amp_utils.h @@ -301,6 +301,12 @@ inline T AmpAutoCast(const std::string& input_name, input_name == "Ln1Scale" || input_name == "Ln1Bias") { return input; } + if (input_name == "ln_scale" || input_name == "ln_bias" || + input_name == "ln_scale_2" || input_name == "ln_bias_2" || + input_name == "ln1_scale" || input_name == "ln1_bias" || + input_name == "ln2_scale" || input_name == "ln2_bias") { + return input; + } } } if (NeedCast(input, dst_dtype)) { diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 526935a5182be6..3d6f38aac1ecea 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/parallel_context.h" -#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/strided_memcpy.h" #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/device/xpu/enforce_xpu.h" @@ -74,7 +74,7 @@ static void ConcatTensorsForAllReduce( const DeviceContext &context, const std::vector &dense_tensors_, framework::Variable *p_dense_contents) { - operators::math::ConcatFunctor concat_functor_; + phi::funcs::ConcatFunctor concat_functor_; concat_functor_(context, dense_tensors_, 0, @@ -102,7 +102,7 @@ static void SplitTensorsForAllReduce( phi::funcs::StridedMemcpyWithAxis0( context, *in, shape_refer, &outs); } else { - operators::math::SplitFunctor split_functor_; + phi::funcs::SplitFunctor split_functor_; split_functor_(context, *in, shape_refer, 0, &outs); } } @@ -179,8 +179,7 @@ void SplitTensorsForAllReduce( outs.emplace_back(&tensor); shape_refer.emplace_back(&tensor); } - operators::math::SplitFunctor - split_functor_; + phi::funcs::SplitFunctor split_functor_; split_functor_(context, 
*in, shape_refer, 0, &outs); } diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu index c9e56f1d63823d..3d07632cb61aea 100644 --- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu @@ -74,7 +74,7 @@ static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { template __global__ void groupNormNCHW32SumKernelQDQ( - const GroupNormNHWCParams<__half> params) { + const GroupNormNDHWCParams<__half> params) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; @@ -90,9 +90,9 @@ __global__ void groupNormNCHW32SumKernelQDQ( int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t dhwBegin = blockIdx.y * params.dhwPerBlock; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw); // The sums. float sum = 0.F; @@ -102,13 +102,13 @@ __global__ void groupNormNCHW32SumKernelQDQ( // nchw32 layout // batch offset + channel offset - int nc_offset = static_cast(ni) * params.hwc + - ci / 32 * params.hw * 32 + ci % 32; + int nc_offset = static_cast(ni) * params.dhwc + + ci / 32 * params.dhw * 32 + ci % 32; // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The offset. - int64_t offset = nc_offset + static_cast(hwi) * 32; + int64_t offset = nc_offset + static_cast(dhwi) * 32; // Fetch two channels per thread. __half2 h2(0, 0); @@ -166,14 +166,14 @@ __global__ void groupNormNCHW32SumKernelQDQ( atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); } -void groupNormNCHW32SumQDQ(const GroupNormNHWCParams<__half> ¶ms, +void groupNormNCHW32SumQDQ(const GroupNormNDHWCParams<__half> ¶ms, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. grid.x = divUp(params.c, params.cPerBlock); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = divUp(params.dhw, params.dhwPerBlock); // The number of instances. grid.z = params.n; @@ -198,7 +198,7 @@ void groupNormNCHW32SumQDQ(const GroupNormNHWCParams<__half> ¶ms, template __global__ void groupNormNCHW32ScaleKernelQDQ( - const GroupNormNHWCParams<__half> params) { + const GroupNormNDHWCParams<__half> params) { // The instance in the batch. int32_t ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). @@ -226,25 +226,25 @@ __global__ void groupNormNCHW32ScaleKernelQDQ( } // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.invDHWC; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sumSq * params.invDHWC - (mean * mean); // Compute the inverse of the stddev. float invStdDev = rsqrtf(var + params.eps); // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t dhwBegin = blockIdx.y * params.dhwPerBlock; // The last activation loaded by that block. 
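For readers of the QDQ kernels above: in the NC/32HW32 layout the channels are packed in groups of 32, so the element offset that the sum kernel assembles from nc_offset plus dhwi * 32 corresponds to the index math below (a small illustrative helper, not plugin code):

    #include <cstdint>

    // Linear offset of element (n, c, p) in NC/32HW32 layout, where p indexes
    // the flattened d*h*w spatial positions and C is the total channel count.
    int64_t Chw32Offset(int64_t n, int64_t c, int64_t p, int64_t dhw, int64_t C) {
      return n * dhw * C + (c / 32) * dhw * 32 + p * 32 + (c % 32);
    }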
- int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw); // nchw32 layout - int c_offset = ci / 32 * params.hw * 32 + ci % 32; + int c_offset = ci / 32 * params.dhw * 32 + ci % 32; // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The src/dst offset. - int64_t offset = static_cast(ni) * params.hwc + c_offset + - static_cast(hwi) * 32; + int64_t offset = static_cast(ni) * params.dhwc + c_offset + + static_cast(dhwi) * 32; // Fetch two channels per thread. __half2 h2(0, 0); @@ -290,14 +290,14 @@ __global__ void groupNormNCHW32ScaleKernelQDQ( } } -void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams<__half> ¶ms, +void groupNormNCHW32ScaleQDQ(const GroupNormNDHWCParams<__half> ¶ms, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. grid.x = divUp(params.c, params.cPerBlock); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = divUp(params.dhw, params.dhwPerBlock); // The number of instances. grid.z = params.n; @@ -642,7 +642,7 @@ int GroupNormPluginDynamic::enqueue( DataLayout::kNCHW); } else if (input_desc[0].format == nvinfer1::PluginFormat::kHWC8) { int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; + int32_t maxBlocksPerDHW = 1024; switch (input_desc[0].dims.d[1]) { case 960: case 1920: @@ -661,6 +661,25 @@ int GroupNormPluginDynamic::enqueue( if (cPerBlock > input_desc[0].dims.d[1]) { cPerBlock = 8; } + auto d_dim = input_desc[0].dims.nbDims; + params_.n = input_desc[0].dims.d[0]; + if (d_dim == 3) { + params_.c = input_desc[0].dims.d[1]; + params_.d = 1; + params_.h = 1; + params_.w = input_desc[0].dims.d[2]; + } else if (d_dim == 4) { + params_.c = input_desc[0].dims.d[1]; + params_.d = 1; + params_.h = input_desc[0].dims.d[2]; + params_.w = input_desc[0].dims.d[3]; + } else { + // d_dim == 5 + params_.c = input_desc[0].dims.d[1]; + params_.d = input_desc[0].dims.d[2]; + params_.h = input_desc[0].dims.d[3]; + params_.w = input_desc[0].dims.d[4]; + } params_.withSilu = with_silu_; params_.dst = static_cast(outputs[0]); @@ -669,18 +688,19 @@ int GroupNormPluginDynamic::enqueue( params_.beta = reinterpret_cast(bias_gpu_); params_.redBuffer = static_cast(workspace); params_.var_data = nullptr; - params_.n = input_desc[0].dims.d[0]; - params_.h = input_desc[0].dims.d[2]; - params_.w = input_desc[0].dims.d[3]; - params_.c = input_desc[0].dims.d[1]; + // params_.n = input_desc[0].dims.d[0]; + // params_.h = input_desc[0].dims.d[2]; + // params_.w = input_desc[0].dims.d[3]; + // params_.c = input_desc[0].dims.d[1]; params_.groups = groups_; - params_.hw = params_.h * params_.w; - const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW); - params_.hwPerBlock = divUp(params_.hw, blocksPerHW); + params_.dhw = params_.d * params_.h * params_.w; + const int32_t blocksPerDHW = findMaxDivisor(params_.dhw, maxBlocksPerDHW); + params_.dhwPerBlock = divUp(params_.dhw, blocksPerDHW); params_.cPerBlock = cPerBlock; params_.cPerGroup = params_.c / params_.groups; - params_.hwc = params_.hw * params_.c; - params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); + params_.dhwc = params_.dhw * params_.c; + params_.invDHWC = + 1.F / static_cast(params_.dhw * params_.cPerGroup); params_.groupsPerBlock = cPerBlock / params_.cPerGroup; params_.eps = eps_; params_.var_data = 
nullptr; @@ -690,10 +710,10 @@ int GroupNormPluginDynamic::enqueue( 2 * sizeof(float) * params_.n * groups_, stream); - phi::groupNormNHWCSum nhwc_sum; - nhwc_sum(¶ms_, stream); - phi::groupNormNHWCScale nhwc_scale; - nhwc_scale(params_, stream); + phi::groupNormNDHWCSum ndhwc_sum; + ndhwc_sum(¶ms_, stream); + phi::groupNormNDHWCScale ndhwc_scale; + ndhwc_scale(params_, stream); } else { PADDLE_THROW(platform::errors::Fatal( "The Groupnorm TRT Plugin's only support nchw or nhwc8 input")); @@ -704,7 +724,7 @@ int GroupNormPluginDynamic::enqueue( if (input_desc[0].format == nvinfer1::PluginFormat::kCHW32) { int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; + int32_t maxBlocksPerDHW = 1024; switch (input_desc[0].dims.d[1]) { case 960: case 1920: @@ -723,6 +743,25 @@ int GroupNormPluginDynamic::enqueue( if (cPerBlock > input_desc[0].dims.d[1]) { cPerBlock = 8; } + auto d_dim = input_desc[0].dims.nbDims; + params_.n = input_desc[0].dims.d[0]; + if (d_dim == 3) { + params_.c = input_desc[0].dims.d[1]; + params_.d = 1; + params_.h = 1; + params_.w = input_desc[0].dims.d[2]; + } else if (d_dim == 4) { + params_.c = input_desc[0].dims.d[1]; + params_.d = 1; + params_.h = input_desc[0].dims.d[2]; + params_.w = input_desc[0].dims.d[3]; + } else { + // d_dim == 5 + params_.c = input_desc[0].dims.d[1]; + params_.d = input_desc[0].dims.d[2]; + params_.h = input_desc[0].dims.d[3]; + params_.w = input_desc[0].dims.d[4]; + } params_.withSilu = with_silu_; params_.dst = static_cast(outputs[0]); params_.srcX = static_cast(inputs[0]); @@ -730,18 +769,19 @@ int GroupNormPluginDynamic::enqueue( params_.gamma = scale_gpu_; params_.beta = bias_gpu_; params_.redBuffer = static_cast(workspace); - params_.n = input_desc[0].dims.d[0]; - params_.h = input_desc[0].dims.d[2]; - params_.w = input_desc[0].dims.d[3]; - params_.c = input_desc[0].dims.d[1]; + // params_.n = input_desc[0].dims.d[0]; + // params_.h = input_desc[0].dims.d[2]; + // params_.w = input_desc[0].dims.d[3]; + // params_.c = input_desc[0].dims.d[1]; params_.groups = groups_; - params_.hw = params_.h * params_.w; - const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW); - params_.hwPerBlock = divUp(params_.hw, blocksPerHW); + params_.dhw = params_.d * params_.h * params_.w; + const int32_t blocksPerDHW = findMaxDivisor(params_.dhw, maxBlocksPerDHW); + params_.dhwPerBlock = divUp(params_.dhw, blocksPerDHW); params_.cPerBlock = cPerBlock; params_.cPerGroup = params_.c / params_.groups; - params_.hwc = params_.hw * params_.c; - params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); + params_.dhwc = params_.dhw * params_.c; + params_.invDHWC = + 1.F / static_cast(params_.dhw * params_.cPerGroup); params_.groupsPerBlock = cPerBlock / params_.cPerGroup; CHECK_EQ(cPerBlock % params_.cPerGroup, 0); CHECK_EQ(params_.cPerGroup % 2, 0); diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h index e76d802f853653..879fd42de50155 100644 --- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h @@ -29,7 +29,7 @@ namespace inference { namespace tensorrt { namespace plugin { -using phi::GroupNormNHWCParams; +using phi::GroupNormNDHWCParams; class GroupNormPlugin : public PluginTensorRT { public: size_t getSerializationSize() const TRT_NOEXCEPT override { @@ -289,7 +289,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT { float eps_; std::vector mean_shape_; 
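The dimension handling added to enqueue() above generalizes the plugin from HW to DHW: a 3-D/4-D/5-D input is unpacked so that d and h default to 1, which makes the old 4-D NHWC path the d == 1 special case, and the per-block partitioning is then derived from dhw = d*h*w. A hedged, self-contained sketch of that derivation (divUp/maxDivisor are simplified reference versions, not the plugin's implementations):

    #include <cassert>
    #include <cstdint>

    struct DhwParams {
      int32_t n, c, d, h, w;
      int32_t dhw, dhwPerBlock;
      float invDHWC;
    };

    static int32_t divUp(int32_t a, int32_t b) { return (a + b - 1) / b; }

    // Largest divisor of n that does not exceed maxAllowed.
    static int32_t maxDivisor(int32_t n, int32_t maxAllowed) {
      for (int32_t d = maxAllowed; d > 1; --d) {
        if (n % d == 0) return d;
      }
      return 1;
    }

    DhwParams MakeParams(int nbDims, const int64_t dims[], int32_t cPerGroup,
                         int32_t maxBlocksPerDHW = 1024) {
      assert(nbDims >= 3 && nbDims <= 5);
      DhwParams p{};
      p.n = static_cast<int32_t>(dims[0]);
      p.c = static_cast<int32_t>(dims[1]);
      p.d = nbDims == 5 ? static_cast<int32_t>(dims[2]) : 1;          // depth only for 5-D
      p.h = nbDims >= 4 ? static_cast<int32_t>(dims[nbDims - 2]) : 1; // height only for 4-D/5-D
      p.w = static_cast<int32_t>(dims[nbDims - 1]);
      p.dhw = p.d * p.h * p.w;
      p.dhwPerBlock = divUp(p.dhw, maxDivisor(p.dhw, maxBlocksPerDHW));
      p.invDHWC = 1.f / static_cast<float>(p.dhw * cPerGroup);
      return p;
    }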
std::vector variance_shape_; - GroupNormNHWCParams params_; + GroupNormNDHWCParams params_; bool with_silu_; bool with_fp16_; bool with_int8_; diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu index 7ccf5d8a8a1bc7..24ff83a8909fd2 100644 --- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu @@ -120,8 +120,8 @@ struct GroupSumsOp { }; template -__global__ void prelnGroupNormNHWCSumKernel( - GroupNormNHWCParams<__half> params) { +__global__ void prelnGroupNormNDHWCSumKernel( + GroupNormNDHWCParams<__half> params) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; @@ -137,19 +137,19 @@ __global__ void prelnGroupNormNHWCSumKernel( int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t dhwBegin = blockIdx.y * params.dhwPerBlock; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw); // The sums. float sum = 0.F; float sumSq = 0.F; // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The offset. - int64_t offset = static_cast(ni) * params.hwc + - static_cast(hwi) * params.c + ci; + int64_t offset = static_cast(ni) * params.dhwc + + static_cast(dhwi) * params.c + ci; // Fetch two channels per thread. __half2 h2(0, 0); if (ci < params.c) { @@ -213,30 +213,30 @@ __global__ void prelnGroupNormNHWCSumKernel( atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); } -void prelnGroupNormNHWCSum(GroupNormNHWCParams<__half> const ¶ms, - cudaStream_t stream) { +void prelnGroupNormNDHWCSum(GroupNormNDHWCParams<__half> const ¶ms, + cudaStream_t stream) { // Make sure the values are as we expect. PADDLE_ENFORCE_EQ(params.c % params.cPerBlock, 0, platform::errors::InvalidArgument( - "The groupNormNHWCSum of prelnGroupnormAct Plugin got " + "The groupNormNDHWCSum of prelnGroupnormAct Plugin got " "wrong parameters" "params.c %% params.cPerBlock should be 0, but get %d.", params.c % params.cPerBlock)); PADDLE_ENFORCE_EQ( - params.hw % params.hwPerBlock, + params.dhw % params.dhwPerBlock, 0, platform::errors::InvalidArgument( - "The groupNormNHWCSum of prelnGroupnormAct Plugin got wrong " + "The groupNormNDHWCSum of prelnGroupnormAct Plugin got wrong " "parameters" - "params.hw %% params.hwPerBlock should be 0, but get %d.", - params.hw % params.hwPerBlock)); + "params.dhw %% params.dhwPerBlock should be 0, but get %d.", + params.dhw % params.dhwPerBlock)); // Make sure a group does not span multiple blocks. PADDLE_ENFORCE_EQ( params.cPerBlock % params.cPerGroup, 0, platform::errors::InvalidArgument( - "The groupNormNHWCSum of prelnGroupnormAct Plugin got wrong " + "The groupNormNDHWCSum of prelnGroupnormAct Plugin got wrong " "parameters" "params.cPerBlock %% params.cPerGroup should be 0, but get %d.", params.cPerBlock % params.cPerGroup)); @@ -245,36 +245,36 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams<__half> const ¶ms, // The number of blocks to compute all the channels. grid.x = params.c / params.cPerBlock; // The number of blocks to compute all the activations in a given instance. 
- grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = divUp(params.dhw, params.dhwPerBlock); // The number of instances. grid.z = params.n; switch (params.cPerBlock) { case 320: - prelnGroupNormNHWCSumKernel<160><<>>(params); + prelnGroupNormNDHWCSumKernel<160><<>>(params); break; case 480: - prelnGroupNormNHWCSumKernel<256><<>>(params); + prelnGroupNormNDHWCSumKernel<256><<>>(params); break; case 256: - prelnGroupNormNHWCSumKernel<128><<>>(params); + prelnGroupNormNDHWCSumKernel<128><<>>(params); break; case 128: - prelnGroupNormNHWCSumKernel<64><<>>(params); + prelnGroupNormNDHWCSumKernel<64><<>>(params); break; case 8: - prelnGroupNormNHWCSumKernel<4><<>>(params); + prelnGroupNormNDHWCSumKernel<4><<>>(params); break; default: PADDLE_THROW(platform::errors::Fatal( - "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin " + "The function groupNormNDHWCSum of prelnGroupnormAct TRT Plugin " "encounter error")); } } template -__global__ void prelnGroupNormNHWCScaleKernel( - GroupNormNHWCParams<__half> params) { +__global__ void prelnGroupNormNDHWCScaleKernel( + GroupNormNDHWCParams<__half> params) { // The instance in the batch. int32_t ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). @@ -299,21 +299,21 @@ __global__ void prelnGroupNormNHWCScaleKernel( } // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.invDHWC; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sumSq * params.invDHWC - (mean * mean); // Compute the inverse of the stddev. float invStdDev = rsqrtf(var + params.eps); // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t dhwBegin = blockIdx.y * params.dhwPerBlock; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw); // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; + int64_t offset = (int64_t)ni * params.dhwc + dhwi * params.c + ci; // Fetch two channels per thread. __half2 h2(0, 0); @@ -345,14 +345,14 @@ __global__ void prelnGroupNormNHWCScaleKernel( } } -void prelnGroupNormNHWCScale(GroupNormNHWCParams<__half> const ¶ms, - cudaStream_t stream) { +void prelnGroupNormNDHWCScale(GroupNormNDHWCParams<__half> const ¶ms, + cudaStream_t stream) { // Make sure the dimensions are aligned with what we expect. PADDLE_ENFORCE_EQ( params.c % params.cPerBlock, 0, platform::errors::InvalidArgument( - "The groupNormNHWCScale of prelnGroupnormAct Plugin got " + "The groupNormNDHWCScale of prelnGroupnormAct Plugin got " "wrong parameters" "params.c %% params.cPerBlock should be 0, but get %d.", params.c % params.cPerBlock)); @@ -361,7 +361,7 @@ void prelnGroupNormNHWCScale(GroupNormNHWCParams<__half> const ¶ms, params.cPerBlock % params.cPerGroup, 0, platform::errors::InvalidArgument( - "The groupNormNHWCScale of prelnGroupnormAct Plugin got wrong " + "The groupNormNDHWCScale of prelnGroupnormAct Plugin got wrong " "parameters" "params.cPerBlock %% params.cPerGroup should be 0, but get %d.", params.cPerBlock % params.cPerGroup)); @@ -370,29 +370,29 @@ void prelnGroupNormNHWCScale(GroupNormNHWCParams<__half> const ¶ms, // The number of blocks to compute all the channels. 
grid.x = params.c / params.cPerBlock; // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = divUp(params.dhw, params.dhwPerBlock); // The number of instances. grid.z = params.n; switch (params.cPerBlock) { case 320: - prelnGroupNormNHWCScaleKernel<160><<>>(params); + prelnGroupNormNDHWCScaleKernel<160><<>>(params); break; case 480: - prelnGroupNormNHWCScaleKernel<256><<>>(params); + prelnGroupNormNDHWCScaleKernel<256><<>>(params); break; case 256: - prelnGroupNormNHWCScaleKernel<128><<>>(params); + prelnGroupNormNDHWCScaleKernel<128><<>>(params); break; case 128: - prelnGroupNormNHWCScaleKernel<64><<>>(params); + prelnGroupNormNDHWCScaleKernel<64><<>>(params); break; case 8: - prelnGroupNormNHWCScaleKernel<4><<>>(params); + prelnGroupNormNDHWCScaleKernel<4><<>>(params); break; default: PADDLE_THROW(platform::errors::Fatal( - "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin " + "The function groupNormNDHWCSum of prelnGroupnormAct TRT Plugin " "encounter error")); } } @@ -413,7 +413,7 @@ int PrelnGroupnormActPluginDynamic::enqueue( VLOG(1) << "TRT Plugin DataType selected. prelnGroupnormAct-->fp16"; int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; + int32_t maxBlocksPerDHW = 1024; switch (input_desc[0].dims.d[1]) { case 960: @@ -433,6 +433,25 @@ int PrelnGroupnormActPluginDynamic::enqueue( if (cPerBlock > input_desc[0].dims.d[1]) { cPerBlock = 8; } + auto d_dim = input_desc[0].dims.nbDims; + params_.n = input_desc[0].dims.d[0]; + if (d_dim == 3) { + params_.c = input_desc[0].dims.d[1]; + params_.d = 1; + params_.h = 1; + params_.w = input_desc[0].dims.d[2]; + } else if (d_dim == 4) { + params_.c = input_desc[0].dims.d[1]; + params_.d = 1; + params_.h = input_desc[0].dims.d[2]; + params_.w = input_desc[0].dims.d[3]; + } else { + // d_dim == 5 + params_.c = input_desc[0].dims.d[1]; + params_.d = input_desc[0].dims.d[2]; + params_.h = input_desc[0].dims.d[3]; + params_.w = input_desc[0].dims.d[4]; + } params_.withSilu = with_silu_; params_.dst = static_cast(outputs[1]); params_.eleOut = static_cast(outputs[0]); @@ -441,24 +460,24 @@ int PrelnGroupnormActPluginDynamic::enqueue( params_.gamma = scale_gpu_.get(); params_.beta = bias_gpu_.get(); params_.redBuffer = static_cast(workspace); - params_.n = input_desc[0].dims.d[0]; - params_.h = input_desc[0].dims.d[2]; - params_.w = input_desc[0].dims.d[3]; - params_.c = input_desc[0].dims.d[1]; + // params_.n = input_desc[0].dims.d[0]; + // params_.h = input_desc[0].dims.d[2]; + // params_.w = input_desc[0].dims.d[3]; + // params_.c = input_desc[0].dims.d[1]; params_.groups = groups_; - params_.hw = params_.h * params_.w; - const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW); - params_.hwPerBlock = divUp(params_.hw, blocksPerHW); + params_.dhw = params_.d * params_.h * params_.w; + const int32_t blocksPerDHW = findMaxDivisor(params_.dhw, maxBlocksPerDHW); + params_.dhwPerBlock = divUp(params_.dhw, blocksPerDHW); params_.cPerBlock = cPerBlock; params_.cPerGroup = params_.c / params_.groups; - params_.hwc = params_.hw * params_.c; - params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); + params_.dhwc = params_.dhw * params_.c; + params_.invDHWC = 1.F / static_cast(params_.dhw * params_.cPerGroup); params_.groupsPerBlock = cPerBlock / params_.cPerGroup; params_.eps = eps_; cudaMemsetAsync(params_.redBuffer, 0, ws_, stream); - prelnGroupNormNHWCSum(params_, stream); - prelnGroupNormNHWCScale(params_, stream); + 
prelnGroupNormNDHWCSum(params_, stream); + prelnGroupNormNDHWCScale(params_, stream); } else { // input not fp16 diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h index 2d5dde91901035..7119d5c8a710e7 100644 --- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h @@ -28,7 +28,7 @@ namespace paddle { namespace inference { namespace tensorrt { namespace plugin { -using phi::GroupNormNHWCParams; +using phi::GroupNormNDHWCParams; class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT { public: PrelnGroupnormActPluginDynamic(const float* scale, @@ -174,7 +174,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT { std::vector bias_; std::shared_ptr scale_gpu_; std::shared_ptr bias_gpu_; - GroupNormNHWCParams<__half> params_; + GroupNormNDHWCParams<__half> params_; int groups_; float eps_; bool with_silu_; diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu index 95c408fa859251..1722d720d5daf3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu @@ -131,7 +131,8 @@ struct GroupSumsOp { }; template -__global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams<__half> params) { +__global__ void skipGroupNormNDHWCSumKernel( + GroupNormNDHWCParams<__half> params) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; @@ -147,19 +148,19 @@ __global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams<__half> params) { int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t dhwBegin = blockIdx.y * params.dhwPerBlock; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw); // The sums. float sum = 0.F; float sumSq = 0.F; // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The offset. - int64_t offset = static_cast(ni) * params.hwc + - static_cast(hwi) * params.c + ci; + int64_t offset = static_cast(ni) * params.dhwc + + static_cast(dhwi) * params.c + ci; // Fetch two channels per thread. __half2 h2(0, 0); if (ci < params.c) { @@ -224,29 +225,31 @@ __global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams<__half> params) { atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); } -void skipGroupNormNHWCSum(GroupNormNHWCParams<__half> const ¶ms, - cudaStream_t stream) { +void skipGroupNormNDHWCSum(GroupNormNDHWCParams<__half> const ¶ms, + cudaStream_t stream) { // Make sure the values are as we expect. 
+ PADDLE_ENFORCE_EQ(params.c % params.cPerBlock, + 0, + platform::errors::InvalidArgument( + "The groupNormNDHWCSum of SkipGroupnormAct Plugin got " + "wrong parameters" + "params.c %% params.cPerBlock should be 0, but get %d.", + params.c % params.cPerBlock)); PADDLE_ENFORCE_EQ( - params.c % params.cPerBlock, - 0, - platform::errors::InvalidArgument( - "The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong parameters" - "params.c %% params.cPerBlock should be 0, but get %d.", - params.c % params.cPerBlock)); - PADDLE_ENFORCE_EQ( - params.hw % params.hwPerBlock, + params.dhw % params.dhwPerBlock, 0, platform::errors::InvalidArgument( - "The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong parameters" - "params.hw %% params.hwPerBlock should be 0, but get %d.", - params.hw % params.hwPerBlock)); + "The groupNormNDHWCSum of SkipGroupnormAct Plugin got wrong " + "parameters" + "params.dhw %% params.dhwPerBlock should be 0, but get %d.", + params.dhw % params.dhwPerBlock)); // Make sure a group does not span multiple blocks. PADDLE_ENFORCE_EQ( params.cPerBlock % params.cPerGroup, 0, platform::errors::InvalidArgument( - "The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong parameters" + "The groupNormNDHWCSum of SkipGroupnormAct Plugin got wrong " + "parameters" "params.cPerBlock %% params.cPerGroup should be 0, but get %d.", params.cPerBlock % params.cPerGroup)); dim3 grid; @@ -254,36 +257,36 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams<__half> const ¶ms, // The number of blocks to compute all the channels. grid.x = params.c / params.cPerBlock; // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = divUp(params.dhw, params.dhwPerBlock); // The number of instances. grid.z = params.n; switch (params.cPerBlock) { case 320: - skipGroupNormNHWCSumKernel<160><<>>(params); + skipGroupNormNDHWCSumKernel<160><<>>(params); break; case 480: - skipGroupNormNHWCSumKernel<256><<>>(params); + skipGroupNormNDHWCSumKernel<256><<>>(params); break; case 256: - skipGroupNormNHWCSumKernel<128><<>>(params); + skipGroupNormNDHWCSumKernel<128><<>>(params); break; case 128: - skipGroupNormNHWCSumKernel<64><<>>(params); + skipGroupNormNDHWCSumKernel<64><<>>(params); break; case 8: - skipGroupNormNHWCSumKernel<4><<>>(params); + skipGroupNormNDHWCSumKernel<4><<>>(params); break; default: PADDLE_THROW(platform::errors::Fatal( - "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin " + "The function groupNormNDHWCSum of SkipGroupnormAct TRT Plugin " "encounter error")); } } template -__global__ void skipGroupNormNHWCScaleKernel( - GroupNormNHWCParams<__half> params) { +__global__ void skipGroupNormNDHWCScaleKernel( + GroupNormNDHWCParams<__half> params) { // The instance in the batch. int32_t ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). @@ -308,21 +311,21 @@ __global__ void skipGroupNormNHWCScaleKernel( } // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.invDHWC; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sumSq * params.invDHWC - (mean * mean); // Compute the inverse of the stddev. float invStdDev = rsqrtf(var + params.eps); // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t dhwBegin = blockIdx.y * params.dhwPerBlock; // The last activation loaded by that block. 
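The sum/scale kernel pairs in these plugins implement ordinary group normalization: pass one accumulates a per-(instance, group) sum and sum of squares, pass two derives mean = sum * invDHWC and var = sumSq * invDHWC - mean^2 and applies the per-channel affine, with an optional SiLU epilogue. A plain host-side reference of that arithmetic, for illustration only (NDHWC layout assumed; not the plugin code):

    #include <cmath>
    #include <cstdint>
    #include <vector>

    void GroupNormReference(const std::vector<float>& x,  // n*dhw*c, NDHWC
                            std::vector<float>& y, int n, int dhw, int c,
                            int groups, float eps,
                            const std::vector<float>& gamma,
                            const std::vector<float>& beta, bool with_silu) {
      const int cPerGroup = c / groups;
      const float invDHWC = 1.f / static_cast<float>(dhw * cPerGroup);
      for (int ni = 0; ni < n; ++ni) {
        for (int g = 0; g < groups; ++g) {
          // Pass 1: accumulate sum and sum of squares over the group.
          float sum = 0.f, sumSq = 0.f;
          for (int i = 0; i < dhw; ++i) {
            for (int cc = 0; cc < cPerGroup; ++cc) {
              float v = x[(ni * dhw + i) * static_cast<int64_t>(c) +
                          g * cPerGroup + cc];
              sum += v;
              sumSq += v * v;
            }
          }
          // Pass 2: normalize using mean/var derived from the two sums.
          float mean = sum * invDHWC;
          float var = sumSq * invDHWC - mean * mean;
          float invStdDev = 1.f / std::sqrt(var + eps);
          for (int i = 0; i < dhw; ++i) {
            for (int cc = 0; cc < cPerGroup; ++cc) {
              int ci = g * cPerGroup + cc;
              int64_t off = (ni * static_cast<int64_t>(dhw) + i) * c + ci;
              float v = (x[off] - mean) * invStdDev * gamma[ci] + beta[ci];
              if (with_silu) v = v / (1.f + std::exp(-v));  // SiLU: v * sigmoid(v)
              y[off] = v;
            }
          }
        }
      }
    }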
- int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw); // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; + int64_t offset = (int64_t)ni * params.dhwc + dhwi * params.c + ci; // Fetch two channels per thread. __half2 h2(0, 0); @@ -354,22 +357,23 @@ __global__ void skipGroupNormNHWCScaleKernel( } } -void skipGroupNormNHWCScale(GroupNormNHWCParams<__half> const ¶ms, - cudaStream_t stream) { +void skipGroupNormNDHWCScale(GroupNormNDHWCParams<__half> const ¶ms, + cudaStream_t stream) { // Make sure the dimensions are aligned with what we expect. - PADDLE_ENFORCE_EQ(params.c % params.cPerBlock, - 0, - platform::errors::InvalidArgument( - "The groupNormNHWCScale of SkipGroupnormAct Plugin got " - "wrong parameters" - "params.c %% params.cPerBlock should be 0, but get %d.", - params.c % params.cPerBlock)); + PADDLE_ENFORCE_EQ( + params.c % params.cPerBlock, + 0, + platform::errors::InvalidArgument( + "The groupNormNDHWCScale of SkipGroupnormAct Plugin got " + "wrong parameters" + "params.c %% params.cPerBlock should be 0, but get %d.", + params.c % params.cPerBlock)); // Make sure a group does not span multiple blocks. PADDLE_ENFORCE_EQ( params.cPerBlock % params.cPerGroup, 0, platform::errors::InvalidArgument( - "The groupNormNHWCScale of SkipGroupnormAct Plugin got wrong " + "The groupNormNDHWCScale of SkipGroupnormAct Plugin got wrong " "parameters" "params.cPerBlock %% params.cPerGroup should be 0, but get %d.", params.cPerBlock % params.cPerGroup)); @@ -378,29 +382,29 @@ void skipGroupNormNHWCScale(GroupNormNHWCParams<__half> const ¶ms, // The number of blocks to compute all the channels. grid.x = params.c / params.cPerBlock; // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = divUp(params.dhw, params.dhwPerBlock); // The number of instances. grid.z = params.n; switch (params.cPerBlock) { case 320: - skipGroupNormNHWCScaleKernel<160><<>>(params); + skipGroupNormNDHWCScaleKernel<160><<>>(params); break; case 480: - skipGroupNormNHWCScaleKernel<256><<>>(params); + skipGroupNormNDHWCScaleKernel<256><<>>(params); break; case 256: - skipGroupNormNHWCScaleKernel<128><<>>(params); + skipGroupNormNDHWCScaleKernel<128><<>>(params); break; case 128: - skipGroupNormNHWCScaleKernel<64><<>>(params); + skipGroupNormNDHWCScaleKernel<64><<>>(params); break; case 8: - skipGroupNormNHWCScaleKernel<4><<>>(params); + skipGroupNormNDHWCScaleKernel<4><<>>(params); break; default: PADDLE_THROW(platform::errors::Fatal( - "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin " + "The function groupNormNDHWCSum of SkipGroupnormAct TRT Plugin " "encounter error")); } } @@ -420,7 +424,7 @@ int SkipGroupnormActPluginDynamic::enqueue( } else if (input_type == nvinfer1::DataType::kHALF) { VLOG(1) << "TRT Plugin DataType selected. 
SkipGroupnormAct-->fp16"; int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; + int32_t maxBlocksPerDHW = 1024; switch (input_desc[0].dims.d[1]) { case 960: @@ -440,6 +444,25 @@ int SkipGroupnormActPluginDynamic::enqueue( if (cPerBlock > input_desc[0].dims.d[1]) { cPerBlock = 8; } + auto d_dim = input_desc[0].dims.nbDims; + params_.n = input_desc[0].dims.d[0]; + if (d_dim == 3) { + params_.c = input_desc[0].dims.d[1]; + params_.d = 1; + params_.h = 1; + params_.w = input_desc[0].dims.d[2]; + } else if (d_dim == 4) { + params_.c = input_desc[0].dims.d[1]; + params_.d = 1; + params_.h = input_desc[0].dims.d[2]; + params_.w = input_desc[0].dims.d[3]; + } else { + // d_dim == 5 + params_.c = input_desc[0].dims.d[1]; + params_.d = input_desc[0].dims.d[2]; + params_.h = input_desc[0].dims.d[3]; + params_.w = input_desc[0].dims.d[4]; + } params_.withSilu = true; params_.dst = static_cast(outputs[0]); params_.srcX = static_cast(inputs[0]); @@ -447,24 +470,24 @@ int SkipGroupnormActPluginDynamic::enqueue( params_.gamma = scale_gpu_.get(); params_.beta = bias_gpu_.get(); params_.redBuffer = static_cast(workspace); - params_.n = input_desc[0].dims.d[0]; - params_.h = input_desc[0].dims.d[2]; - params_.w = input_desc[0].dims.d[3]; - params_.c = input_desc[0].dims.d[1]; + // params_.n = input_desc[0].dims.d[0]; + // params_.h = input_desc[0].dims.d[2]; + // params_.w = input_desc[0].dims.d[3]; + // params_.c = input_desc[0].dims.d[1]; params_.groups = groups_; - params_.hw = params_.h * params_.w; - const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW); - params_.hwPerBlock = divUp(params_.hw, blocksPerHW); + params_.dhw = params_.d * params_.h * params_.w; + const int32_t blocksPerDHW = findMaxDivisor(params_.dhw, maxBlocksPerDHW); + params_.dhwPerBlock = divUp(params_.dhw, blocksPerDHW); params_.cPerBlock = cPerBlock; params_.cPerGroup = params_.c / params_.groups; - params_.hwc = params_.hw * params_.c; - params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); + params_.dhwc = params_.dhw * params_.c; + params_.invDHWC = 1.F / static_cast(params_.dhw * params_.cPerGroup); params_.groupsPerBlock = cPerBlock / params_.cPerGroup; params_.eps = eps_; cudaMemsetAsync(params_.redBuffer, 0, ws_, stream); - skipGroupNormNHWCSum(params_, stream); - skipGroupNormNHWCScale(params_, stream); + skipGroupNormNDHWCSum(params_, stream); + skipGroupNormNDHWCScale(params_, stream); } else { // input not fp16 diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h index 1260bbb8e2917f..b8fdb2c5ffc507 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h @@ -28,7 +28,7 @@ namespace paddle { namespace inference { namespace tensorrt { namespace plugin { -using phi::GroupNormNHWCParams; +using phi::GroupNormNDHWCParams; class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT { public: SkipGroupnormActPluginDynamic(const float* scale, @@ -169,7 +169,7 @@ class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT { std::vector bias_; std::shared_ptr scale_gpu_; std::shared_ptr bias_gpu_; - GroupNormNHWCParams<__half> params_; + GroupNormNDHWCParams<__half> params_; int groups_; float eps_; bool with_fp16_; diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 89e3f7b7583f64..2e6f870d5449d5 100644 --- 
a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -170,6 +170,9 @@ def insert_new_mutable_attributes( op_arg_name_mappings['push_sparse_v2'].update( {"out_grad_in": "Out@GRAD", "out_grad_out": "Out@GRAD"} ) + op_arg_name_mappings['push_box_sparse'].update( + {"out_grad_in": "Out@GRAD", "out_grad_out": "Out@GRAD"} + ) op_arg_name_mappings['push_gpups_sparse'].update( {"out_grad": "Out@GRAD", "out_grad_grad": "Out@GRAD"} ) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index a6d29fe6288bc8..adc1f38d5117d4 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -305,6 +305,11 @@ pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx, std::map> inputs = op_desc.Inputs(); std::vector input_types; for (const auto& pair : inputs) { + if (op_desc.Type() == "sparse_sum" || op_desc.Type() == "sparse_slice") { + if (pair.first != "x") { + continue; + } + } VarDesc* var_desc = op_desc.Block()->FindVarRecursive(pair.second[0]); PADDLE_ENFORCE_NE( var_desc, diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index a498b2aca31963..43d33643713420 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -134,7 +134,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} phi common) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_utils lod_tensor unpooling lod_rank_table context_project executor static_prim_api) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc static_prim_api static_utils static_global_utils prim_utils) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} cos_sim_functor concat_and_split sampler sample_prob tree2col) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} cos_sim_functor concat_and_split tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} beam_search) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper ps_gpu_wrapper) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions) diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc index fae4ecbf9eb2b3..9c21dfbb1d327f 100644 --- a/paddle/fluid/operators/array_to_lod_tensor_op.cc +++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc @@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device_context.h" @@ -77,7 +77,7 @@ struct ArrayToLoDFunctor { template template void ArrayToLoDFunctorImpl::apply() { - math::ConcatFunctor func; + phi::funcs::ConcatFunctor func; func(*dev_ctx_, prev_functor_->in, 0, prev_functor_->out); } diff --git a/paddle/fluid/operators/chunk_eval_op.cc b/paddle/fluid/operators/chunk_eval_op.cc new file mode 100644 index 00000000000000..1d2ebec27334cf --- /dev/null +++ b/paddle/fluid/operators/chunk_eval_op.cc @@ -0,0 +1,202 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/chunk_eval_op.h" + +#include +#include + +namespace paddle { +namespace operators { + +class ChunkEvalOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK( + ctx->HasInput("Inference"), "Input", "Inference", "chunk_eval"); + OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "chunk_eval"); + + OP_INOUT_CHECK( + ctx->HasOutput("Precision"), "Output", "Precision", "chunk_eval"); + OP_INOUT_CHECK(ctx->HasOutput("Recall"), "Output", "Recall", "chunk_eval"); + OP_INOUT_CHECK( + ctx->HasOutput("F1-Score"), "Output", "F1-Score", "chunk_eval"); + OP_INOUT_CHECK(ctx->HasOutput("NumInferChunks"), + "Output", + "NumInferChunks", + "chunk_eval"); + OP_INOUT_CHECK(ctx->HasOutput("NumLabelChunks"), + "Output", + "NumLabelChunks", + "chunk_eval"); + OP_INOUT_CHECK(ctx->HasOutput("NumCorrectChunks"), + "Output", + "NumCorrectChunks", + "chunk_eval"); + + auto inference_dim = ctx->GetInputDim("Inference"); + auto label_dim = ctx->GetInputDim("Label"); + + PADDLE_ENFORCE_EQ( + inference_dim, + label_dim, + phi::errors::InvalidArgument( + "Input(Inference)'s shape must be the same as Input(Label)'s " + "shape, but received [%s] (Inference) vs [%s] (Label).", + inference_dim, + label_dim)); + + bool use_padding = ctx->HasInput("SeqLength"); + if (use_padding) { + PADDLE_ENFORCE_EQ( + (inference_dim.size() == 3 && inference_dim[2] == 1) || + inference_dim.size() == 2, + true, + phi::errors::InvalidArgument( + "when Input(SeqLength) is provided, Input(Inference) " + "should be of dim 3 (batch_size, bucket, 1) or dim 2 " + "(batch_size, bucket), but received [%s].", + inference_dim)); + auto seq_length_dim = ctx->GetInputDim("SeqLength"); + PADDLE_ENFORCE_LE(seq_length_dim.size(), + 2, + phi::errors::InvalidArgument( + "Input(SeqLength)'s rank should not be greater " + "than 2, but received %d.", + seq_length_dim.size())); + } + + ctx->SetOutputDim("Precision", {1}); + ctx->SetOutputDim("Recall", {1}); + ctx->SetOutputDim("F1-Score", {1}); + ctx->SetOutputDim("NumInferChunks", {1}); + ctx->SetOutputDim("NumLabelChunks", {1}); + ctx->SetOutputDim("NumCorrectChunks", {1}); + } + + protected: + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); + } +}; + +class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Inference", + "(Tensor, default: Tensor). " + "Predictions from the network."); + AddInput("Label", + "(Tensor, default: Tensor). The true tag sequences."); + AddInput("SeqLength", + "(Tensor, default: Tensor). The length of each sequence, " + "used when Inference and Label are Tensor type .") + .AsDispensable(); + AddOutput("Precision", + "(float). The evaluated precision (called positive predictive " + "value) of chunks on the given mini-batch."); + AddOutput("Recall", + "(float). 
The evaluated recall (true positive rate or " + "sensitivity) of chunks on the given mini-batch."); + AddOutput("F1-Score", + "(float). The evaluated F1-Score on the given mini-batch."); + AddOutput("NumInferChunks", + "(int64_t). The number of chunks in Inference on the given " + "mini-batch."); + AddOutput( + "NumLabelChunks", + "(int64_t). The number of chunks in Label on the given mini-batch."); + AddOutput( + "NumCorrectChunks", + "(int64_t). The number of chunks both in Inference and Label on the " + "given mini-batch."); + AddAttr("num_chunk_types", + "The number of chunk type. See the description for details."); + AddAttr("chunk_scheme", + "The labeling scheme indicating " + "how to encode the chunks. Must be IOB, IOE, IOBES or " + "plain. See the description" + "for details.") + .SetDefault("IOB"); + AddAttr>("excluded_chunk_types", + "A list including chunk type ids " + "indicating chunk types that are not counted. " + "See the description for details.") + .SetDefault(std::vector{}); + AddComment(R"DOC( +For some basics of chunking, please refer to +'Chunking with Support Vector Machines '. + +ChunkEvalOp computes the precision, recall, and F1-score of chunk detection, +and supports IOB, IOE, IOBES and IO (also known as plain) tagging schemes. +Here is a NER example of labeling for these tagging schemes: + + Li Ming works at Agricultural Bank of China in Beijing. + IO I-PER I-PER O O I-ORG I-ORG I-ORG I-ORG O I-LOC + IOB B-PER I-PER O O B-ORG I-ORG I-ORG I-ORG O B-LOC + IOE I-PER E-PER O O I-ORG I-ORG I-ORG E-ORG O E-LOC + IOBES B-PER E-PER O O I-ORG I-ORG I-ORG E-ORG O S-LOC + +There are three chunk types(named entity types) including PER(person), ORG(organization) +and LOC(LOCATION), and we can see that the labels have the form -. + +Since the calculations actually use label ids rather than labels, extra attention +should be paid when mapping labels to ids to make CheckEvalOp work. The key point +is that the listed equations are satisfied by ids. + + tag_type = label % num_tag_type + chunk_type = label / num_tag_type + +where `num_tag_type` is the num of tag types in the tagging scheme, `num_chunk_type` +is the num of chunk types, and `tag_type` get its value from the following table. + + Scheme Begin Inside End Single + plain 0 - - - + IOB 0 1 - - + IOE - 0 1 - + IOBES 0 1 2 3 + +Still use NER as example, assuming the tagging scheme is IOB while chunk types are ORG, +PER and LOC. To satisfy the above equations, the label map can be like this: + + B-ORG 0 + I-ORG 1 + B-PER 2 + I-PER 3 + B-LOC 4 + I-LOC 5 + O 6 + +It's not hard to verify the equations noting that the num of chunk types +is 3 and the num of tag types in IOB scheme is 2. For example, the label +id of I-LOC is 5, the tag type id of I-LOC is 1, and the chunk type id of +I-LOC is 2, which consistent with the results from the equations. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(chunk_eval, + ops::ChunkEvalOp, + ops::ChunkEvalOpMaker); + +PD_REGISTER_STRUCT_KERNEL( + chunk_eval, CPU, ALL_LAYOUT, ops::ChunkEvalKernel, float) {} diff --git a/paddle/fluid/operators/chunk_eval_op.h b/paddle/fluid/operators/chunk_eval_op.h new file mode 100644 index 00000000000000..4b146176a43bc8 --- /dev/null +++ b/paddle/fluid/operators/chunk_eval_op.h @@ -0,0 +1,358 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. 
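The label mapping spelled out in the ChunkEvalOp comment above reduces to one modulo/division pair per label id. A small illustrative decode helper (not part of the patch), checked against the I-LOC example from the comment:

#include <cstdio>

// Decode a label id into (tag_type, chunk_type), as described in the
// chunk_eval documentation: tag = label % num_tag_types,
// chunk type = label / num_tag_types.
void DecodeLabel(int label, int num_tag_types, int* tag_type, int* chunk_type) {
  *tag_type = label % num_tag_types;
  *chunk_type = label / num_tag_types;
}

int main() {
  // IOB scheme: tag 0 = Begin, tag 1 = Inside; chunk types ORG=0, PER=1, LOC=2.
  int tag = 0, type = 0;
  DecodeLabel(/*I-LOC*/ 5, /*num_tag_types=*/2, &tag, &type);
  std::printf("I-LOC -> tag %d, chunk type %d\n", tag, type);  // tag 1, type 2
  return 0;
}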
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class ChunkEvalKernel : public framework::OpKernel { + public: + struct Segment { + int begin; + int end; + int type; + bool operator==(const Segment& y) const { + return begin == y.begin && end == y.end && type == y.type; + } + }; + + void GetSegments(const int64_t* label, + int length, + std::vector* segments, + int num_chunk_types, + int num_tag_types, + int other_chunk_type, + int tag_begin, + int tag_inside, + int tag_end, + int tag_single) const { + segments->clear(); + segments->reserve(length); + int chunk_start = 0; + bool in_chunk = false; + int tag = -1; + int type = other_chunk_type; + for (int i = 0; i < length; ++i) { + int prev_tag = tag; + int prev_type = type; + PADDLE_ENFORCE_LE( + label[i], + num_chunk_types * num_tag_types, + phi::errors::InvalidArgument( + "The value of Input(Label) should be less than the number of " + "chunk types times the number of tag types, but received %d " + "(Label) vs %d (chunk types) * %d (tag types).", + label[i], + num_chunk_types, + num_tag_types)); + tag = label[i] % num_tag_types; + type = label[i] / num_tag_types; + if (in_chunk && ChunkEnd(prev_tag, + prev_type, + tag, + type, + other_chunk_type, + tag_begin, + tag_inside, + tag_end, + tag_single)) { + Segment segment{ + chunk_start, // begin + i - 1, // end + prev_type, + }; + segments->push_back(segment); + in_chunk = false; + } + if (ChunkBegin(prev_tag, + prev_type, + tag, + type, + other_chunk_type, + tag_begin, + tag_inside, + tag_end, + tag_single)) { + chunk_start = i; + in_chunk = true; + } + } + if (in_chunk) { + Segment segment{ + chunk_start, // begin + length - 1, // end + type, + }; + segments->push_back(segment); + } + } + + bool ChunkEnd(int prev_tag, + int prev_type, + int tag, + int type, + int other_chunk_type, + int tag_begin, + int tag_inside, + int tag_end, + int tag_single) const { + if (prev_type == other_chunk_type) return false; + if (type == other_chunk_type) return true; + if (type != prev_type) return true; + if (prev_tag == tag_begin) return tag == tag_begin || tag == tag_single; + if (prev_tag == tag_inside) return tag == tag_begin || tag == tag_single; + if (prev_tag == tag_end) return true; + if (prev_tag == tag_single) return true; + return false; + } + + bool ChunkBegin(int prev_tag, + int prev_type, + int tag, + int type, + int other_chunk_type, + int tag_begin, + int tag_inside, + int tag_end, + int tag_single) const { + if (prev_type == other_chunk_type) return type != other_chunk_type; + if (type == other_chunk_type) return false; + if (type != prev_type) return true; + if (tag == tag_begin) return true; + if (tag == tag_inside) return prev_tag == tag_end || prev_tag == tag_single; + if (tag == tag_end) return prev_tag == tag_end || prev_tag == tag_single; + if (tag == tag_single) return true; + return false; + } + + void 
Compute(const framework::ExecutionContext& context) const override { + // initialize to parse configurations + int num_chunk_types, num_tag_types; + int other_chunk_type; + int tag_begin, tag_inside, tag_end, tag_single; + std::vector label_segments; + std::vector output_segments; + std::set excluded_chunk_types; + + if (context.Attr("chunk_scheme") == "IOB") { + num_tag_types = 2; + tag_begin = 0; + tag_inside = 1; + tag_end = -1; + tag_single = -1; + } else if (context.Attr("chunk_scheme") == "IOE") { + num_tag_types = 2; + tag_begin = -1; + tag_inside = 0; + tag_end = 1; + tag_single = -1; + } else if (context.Attr("chunk_scheme") == "IOBES") { + num_tag_types = 4; + tag_begin = 0; + tag_inside = 1; + tag_end = 2; + tag_single = 3; + } else if (context.Attr("chunk_scheme") == "plain") { + num_tag_types = 1; + tag_begin = -1; + tag_inside = -1; + tag_end = -1; + tag_single = -1; + } else { + PADDLE_THROW(phi::errors::InvalidArgument("Unknown chunk scheme.")); + } + other_chunk_type = num_chunk_types = context.Attr("num_chunk_types"); + excluded_chunk_types.insert( + context.Attr>("excluded_chunk_types").begin(), + context.Attr>("excluded_chunk_types").end()); + + auto* inference = context.Input("Inference"); + auto place = inference->place(); + auto* label = context.Input("Label"); + auto* precision = context.Output("Precision"); + auto* recall = context.Output("Recall"); + auto* f1 = context.Output("F1-Score"); + auto* num_infer_chunks = context.Output("NumInferChunks"); + auto* num_label_chunks = context.Output("NumLabelChunks"); + auto* num_correct_chunks = + context.Output("NumCorrectChunks"); + + const int64_t* inference_data = inference->data(); + const int64_t* label_data = label->data(); + T* precision_data = precision->mutable_data(place); + T* recall_data = recall->mutable_data(place); + T* f1_data = f1->mutable_data(place); + int64_t* num_infer_chunks_data = + num_infer_chunks->mutable_data(place); + int64_t* num_label_chunks_data = + num_label_chunks->mutable_data(place); + int64_t* num_correct_chunks_data = + num_correct_chunks->mutable_data(place); + *num_infer_chunks_data = 0; + *num_label_chunks_data = 0; + *num_correct_chunks_data = 0; + + auto lod = label->lod(); + bool use_padding = lod.empty(); + int num_sequences = 0; + + if (use_padding) { + auto dim1 = inference->dims()[1]; + auto* seq_length_t = context.Input("SeqLength"); + auto* seq_length_data = seq_length_t->data(); + num_sequences = seq_length_t->dims()[0]; + + for (int i = 0; i < num_sequences; ++i) { + int seq_length = seq_length_data[i]; + EvalOneSeq(inference_data + i * dim1, + label_data + i * dim1, + seq_length, + &output_segments, + &label_segments, + num_infer_chunks_data, + num_label_chunks_data, + num_correct_chunks_data, + num_chunk_types, + num_tag_types, + other_chunk_type, + tag_begin, + tag_inside, + tag_end, + tag_single, + excluded_chunk_types); + } + } else { + PADDLE_ENFORCE_EQ( + lod.size(), + 1UL, + phi::errors::InvalidArgument( + "Only support one level LoD sequence now, but received %d.", + lod.size())); + PADDLE_ENFORCE_EQ( + lod, + inference->lod(), + phi::errors::InvalidArgument( + "Input(Inference) and Input(Label) of Op(chunk_eval) should have " + "same LoD information.")); + num_sequences = lod[0].size() - 1; + + for (int i = 0; i < num_sequences; ++i) { + int seq_length = lod[0][i + 1] - lod[0][i]; + EvalOneSeq(inference_data + lod[0][i], + label_data + lod[0][i], + seq_length, + &output_segments, + &label_segments, + num_infer_chunks_data, + num_label_chunks_data, + 
num_correct_chunks_data, + num_chunk_types, + num_tag_types, + other_chunk_type, + tag_begin, + tag_inside, + tag_end, + tag_single, + excluded_chunk_types); + } + } + + *precision_data = !(*num_infer_chunks_data) + ? 0 + : static_cast(*num_correct_chunks_data) / + (*num_infer_chunks_data); + *recall_data = !(*num_label_chunks_data) + ? 0 + : static_cast(*num_correct_chunks_data) / + (*num_label_chunks_data); + *f1_data = !(*num_correct_chunks_data) + ? 0 + : 2 * (*precision_data) * (*recall_data) / + ((*precision_data) + (*recall_data)); + } + + void EvalOneSeq(const int64_t* output, + const int64_t* label, + int length, + std::vector* output_segments, + std::vector* label_segments, + int64_t* num_output_segments, + int64_t* num_label_segments, + int64_t* num_correct, + int num_chunk_types, + int num_tag_types, + int other_chunk_type, + int tag_begin, + int tag_inside, + int tag_end, + int tag_single, + const std::set& excluded_chunk_types) const { + GetSegments(output, + length, + output_segments, + num_chunk_types, + num_tag_types, + other_chunk_type, + tag_begin, + tag_inside, + tag_end, + tag_single); + GetSegments(label, + length, + label_segments, + num_chunk_types, + num_tag_types, + other_chunk_type, + tag_begin, + tag_inside, + tag_end, + tag_single); + size_t i = 0, j = 0; + while (i < output_segments->size() && j < label_segments->size()) { + if (output_segments->at(i) == label_segments->at(j) && + excluded_chunk_types.count(output_segments->at(i).type) != 1) { + ++(*num_correct); + } + if (output_segments->at(i).end < label_segments->at(j).end) { + ++i; + } else if (output_segments->at(i).end > label_segments->at(j).end) { + ++j; + } else { + ++i; + ++j; + } + } + for (auto& segment : (*label_segments)) { + if (excluded_chunk_types.count(segment.type) != 1) { + ++(*num_label_segments); + } + } + for (auto& segment : (*output_segments)) { + if (excluded_chunk_types.count(segment.type) != 1) { + ++(*num_output_segments); + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc index 7211c0f295d01f..22610a8fb1f15d 100644 --- a/paddle/fluid/operators/collective/c_concat_op.cu.cc +++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc @@ -17,8 +17,8 @@ limitations under the License. */ #include #include "paddle/phi/core/distributed/comm_context_manager.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/common/flags.h" @@ -151,7 +151,7 @@ class CConcatOpCUDAKernel : public framework::OpKernel { offset += rows_per_tensor; } - math::ConcatFunctor functor; + phi::funcs::ConcatFunctor functor; out->mutable_data(out_dims, place); auto& dev_ctx2 = ctx.template device_context(); functor(dev_ctx2, inputs, axis, out); diff --git a/paddle/fluid/operators/ctc_align_op.cc b/paddle/fluid/operators/ctc_align_op.cc new file mode 100644 index 00000000000000..a40ba846102935 --- /dev/null +++ b/paddle/fluid/operators/ctc_align_op.cc @@ -0,0 +1,133 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
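The chunk_eval kernel above accumulates three counts (inferred, labeled, and correct chunks) and derives the metrics at the end of Compute with guards against empty denominators. A stand-alone restatement of that arithmetic, using plain doubles instead of the output tensors from the patch:

#include <cstdint>

struct ChunkMetrics { double precision, recall, f1; };

// Same arithmetic as the tail of ChunkEvalKernel::Compute: zero counts
// yield zero metrics instead of a division by zero.
ChunkMetrics ComputeChunkMetrics(std::int64_t num_infer, std::int64_t num_label,
                                 std::int64_t num_correct) {
  ChunkMetrics m{0.0, 0.0, 0.0};
  if (num_infer > 0) m.precision = static_cast<double>(num_correct) / num_infer;
  if (num_label > 0) m.recall = static_cast<double>(num_correct) / num_label;
  if (num_correct > 0) m.f1 = 2.0 * m.precision * m.recall / (m.precision + m.recall);
  return m;
}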
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/ctc_align_op.h" + +namespace paddle { +namespace operators { + +class CTCAlignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ctc_align"); + OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "ctc_align"); + + auto input_dims = ctx->GetInputDim("Input"); + + // TODO(wanghaoshuang): it is tricky to set the wrong dimension here. + ctx->SetOutputDim("Output", input_dims); + if (ctx->HasInput("InputLength")) { + ctx->SetOutputDim("OutputLength", {input_dims[0], 1}); + } + } + + protected: + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); + } +}; + +class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", + "2-D Tensor or LodTensor with shape " + "[Lp, 1], where Lp is the sum of all input sequences' length."); + AddInput("InputLength", + "2-D Tensor with shape [batch_size, 1], " + " When Input is padding mode, InputLength is length of every " + "sequence in Input.") + .AsDispensable(); + AddOutput("Output", "(Tensor, default: Tensor), The align result."); + AddOutput("OutputLength", + "2-D Tensor with shape [batch_size, 1], " + "When Input is padding mode, OutputLength is length of every " + "sequence in Output.") + .AsDispensable(); + AddAttr("blank", + "(int, default: 0), the blank label set in Connectionist " + "Temporal Classification (CTC) op.") + .SetDefault(0); + AddAttr("merge_repeated", + "(bool, default: true), whether to " + "merge repeated elements between two blanks. ") + .SetDefault(true); + // add attr padding number for tensor input + AddAttr("padding_value", + "(int, default: 0), padding number " + "use to padding tensor. ") + .SetDefault(0); + AddComment(R"DOC( +CTCAlign op is used to merge repeated elements between two blanks +and then delete all blanks in sequence. 
+ +Given: + Input.data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, + 6, 0, 0, 7, 7, 7, 0] + Input.dims = {18, 1} + Input.LoD = [[0, 11, 18]] + +And: + blank = 0 + merge_repeated = True + +Then: + Output.data = [1, 2, 4, 4, 5, 6, + 6, 7] + Output.dims = {8, 1} + Output.LoD = [[0, 6, 8]] +or Given: + Input.data = [[0, 1, 2, 2, 0, 4], + [0, 4, 5, 0, 6, 0], + [0, 7, 7, 7, 0, 0]] + InputLength.data = [[6], + [5], + [4]], + Input.dims = {3, 6}, + Input.Lod = [] +And: + blank = 0 + merge_repeated = True + padding_value = 0 + +Then: + Output.data = [[1, 2, 4, 0, 0, 0], + [4, 5, 6, 0, 0, 0], + [7, 0, 0, 0, 0, 0]], + OutputLength.data = [[3], + [3], + [1]], + Output.dims = {3, 6}, + Output.Lod = [] +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + ctc_align, + ops::CTCAlignOp, + ops::CTCAlignOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +PD_REGISTER_STRUCT_KERNEL( + ctc_align, CPU, ALL_LAYOUT, ops::CTCAlignKernel, int, int64_t) {} diff --git a/paddle/fluid/operators/ctc_align_op.cu b/paddle/fluid/operators/ctc_align_op.cu new file mode 100644 index 00000000000000..76466ed12ab88f --- /dev/null +++ b/paddle/fluid/operators/ctc_align_op.cu @@ -0,0 +1,171 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
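The ctc_align DOC block above boils down to one rule per sequence: drop blanks, and optionally collapse a run of identical tokens that is not separated by a blank. A minimal CPU sketch of that rule (an illustrative helper, not the kernel from the patch), which reproduces the first worked example above:

#include <cstdint>
#include <vector>

// Collapse repeats (when merge_repeated is true) and remove blank tokens,
// mirroring the behaviour described in the ctc_align DOC string.
std::vector<std::int64_t> CtcAlignOneSeq(const std::vector<std::int64_t>& tokens,
                                         std::int64_t blank, bool merge_repeated) {
  std::vector<std::int64_t> out;
  std::int64_t prev = -1;
  for (std::int64_t tok : tokens) {
    if (tok != blank && !(merge_repeated && tok == prev)) {
      out.push_back(tok);
    }
    prev = tok;
  }
  return out;
}

// Example from the comment: [0,1,2,2,0,4,0,4,5,0,6] with blank=0 and
// merge_repeated=true becomes [1,2,4,4,5,6].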
*/ + +#include +#include +#include + +#include + +#include "paddle/fluid/operators/ctc_align_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void MergeAndDelCudaKernel(const int64_t num_token, + const T* tokens, + const size_t num_seq, + size_t* lod0, + const int blank, + const int merge_repeated, + size_t* out_lod0, + T* output) { + int output_idx = 0; + out_lod0[0] = 0; + + for (int i = 0; i < num_seq; ++i) { + T pre_token = -1; + for (int j = lod0[i]; j < lod0[i + 1]; ++j) { + if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) { + output[output_idx] = tokens[j]; + ++output_idx; + } + pre_token = tokens[j]; + } + out_lod0[i + 1] = output_idx; + } +} + +template +__global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token, + const T* tokens, + const T* tokens_length, + const int blank, + const int merge_repeated, + const int padding_value, + const int64_t batch_size, + T* output, + T* output_length) { + int ind = blockIdx.x * blockDim.x + threadIdx.x; + if (ind >= batch_size) return; + int output_idx = ind * num_token; + T prev_token = -1; + for (int i = ind * num_token; i < ind * num_token + tokens_length[ind]; i++) { + if ((unsigned)tokens[i] != blank && + !(merge_repeated && tokens[i] == prev_token)) { + output[output_idx] = tokens[i]; + ++output_idx; + } + prev_token = tokens[i]; + } + output_length[ind] = output_idx - ind * num_token; + for (int i = output_idx; i < ind * num_token + num_token; i++) { + output[i] = padding_value; + } +} + +template +class CTCAlignOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), + true, + phi::errors::InvalidArgument( + "CTCAlign operator CUDA kernel must use CUDAPlace " + "rather than CPUPlace.")); + auto* input = ctx.Input("Input"); + auto* output = ctx.Output("Output"); + const int blank = ctx.Attr("blank"); + const int merge_repeated = + static_cast(ctx.Attr("merge_repeated")); + const T* tokens = input->data(); + auto stream = ctx.cuda_device_context().stream(); + + // tensor input which has no lod + if (input->lod().empty()) { + const int padding_value = ctx.Attr("padding_value"); + auto input_dims = input->dims(); + T* output_data = output->mutable_data({input_dims[0], input_dims[1]}, + ctx.GetPlace()); + auto* input_length = ctx.Input("InputLength"); + const T* input_length_data = input_length->data(); + auto* output_length = ctx.Output("OutputLength"); + T* output_length_data = + output_length->mutable_data({input_dims[0], 1}, ctx.GetPlace()); + PaddingMergeAndDelCudaKernel + <<<32, (input_dims[0] + 32 - 1) / 32, 0, stream>>>( + input_dims[1], + tokens, + input_length_data, + blank, + merge_repeated, + padding_value, + input_dims[0], + output_data, + output_length_data); + } else { + const size_t level = 0; + auto input_lod = framework::ToAbsOffset(input->lod()); + + const int64_t num_tokens = input->dims()[0]; + const size_t num_seq = input_lod[level].size() - 1; + + // prepare a lod to record lod information while merging elements + thrust::device_vector dev_out_lod0(input_lod[level].size()); + size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data()); + + // merge elements and delete blank + T* output_data = output->mutable_data({num_tokens, 1}, ctx.GetPlace()); + + phi::MixVector mixv_input_lod(&input_lod[level]); + MergeAndDelCudaKernel + <<<1, 1, 0, stream>>>(num_tokens, + tokens, + num_seq, + 
mixv_input_lod.CUDAMutableData(ctx.GetPlace()), + blank, + merge_repeated, + dev_out_lod0_ptr, + output_data); + mixv_input_lod.CopyToCPU(); + + // set output lod + std::vector host_out_lod0(dev_out_lod0.begin(), + dev_out_lod0.end()); + framework::LoD out_lod; + out_lod.push_back(host_out_lod0); + output->set_lod(out_lod); + + // resize output dims + output->Resize({static_cast(host_out_lod0.back()), 1}); + + if (host_out_lod0.back() == 0) { + output->Resize({1, 1}); + output->mutable_data(ctx.GetPlace()); + phi::funcs::SetConstant set_constant; + set_constant( + ctx.template device_context(), output, -1); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +PD_REGISTER_STRUCT_KERNEL( + ctc_align, GPU, ALL_LAYOUT, ops::CTCAlignOpCUDAKernel, int, int64_t) {} diff --git a/paddle/fluid/operators/ctc_align_op.h b/paddle/fluid/operators/ctc_align_op.h new file mode 100644 index 00000000000000..9ebfa7196ecc56 --- /dev/null +++ b/paddle/fluid/operators/ctc_align_op.h @@ -0,0 +1,119 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { + +template +class CTCAlignKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* output = ctx.Output("Output"); + size_t blank = static_cast(ctx.Attr("blank")); + bool merge_repeated = ctx.Attr("merge_repeated"); + T* output_data = output->mutable_data(ctx.GetPlace()); + auto input_dims = common::vectorize(input->dims()); + const T* input_data = input->data(); + + // support tensor input, no lod information + if (input->lod().empty()) { + size_t padding_value = + static_cast(ctx.Attr("padding_value")); + auto* input_length = ctx.Input("InputLength"); + const T* input_length_data = input_length->data(); + + auto* output_length = ctx.Output("OutputLength"); + T* output_length_data = output_length->mutable_data(ctx.GetPlace()); + + for (size_t batch_id = 0; batch_id < (unsigned)input_dims[0]; + batch_id++) { + T prev_token = -1; + size_t output_idx = 0; + for (size_t i = 0; i < (unsigned)input_length_data[batch_id]; i++) { + size_t input_ind = batch_id * input_dims[1] + i; + if ((unsigned)input_data[input_ind] != blank && + !(merge_repeated && input_data[input_ind] == prev_token)) { + output_data[batch_id * input_dims[1] + output_idx] = + input_data[input_ind]; + ++output_idx; + } + prev_token = input_data[input_ind]; + } + output_length_data[batch_id] = output_idx; + for (size_t j = output_idx; j < (unsigned)input_dims[1]; j++) + output_data[batch_id * input_dims[1] + j] = padding_value; + } + } else { + const size_t level = 0; + auto input_lod = framework::ToAbsOffset(input->lod()); + + // check input dims and lod + PADDLE_ENFORCE_EQ( + input_dims[0], + 
static_cast(input_lod[level].back()), + phi::errors::InvalidArgument( + "The first dimension %d of CTCAlign operator Input(Input) should " + "be equal to " + "the sum of all sequences' lengths %d.", + input_dims[0], + static_cast(input_lod[level].back()))); + + const size_t num_sequences = input_lod[level].size() - 1; + + // merge repeated tokens and delete blank + size_t output_idx = 0; + std::vector output_lod0(1, 0); + for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + T prev_token = -1; + for (size_t i = input_lod[level][seq_idx]; + i < input_lod[level][seq_idx + 1]; + ++i) { + if ((unsigned)input_data[i] != blank && + !(merge_repeated && input_data[i] == prev_token)) { + output_data[output_idx] = input_data[i]; + ++output_idx; + } + prev_token = input_data[i]; + } + output_lod0.push_back(output_idx); + } + + // set output lod + framework::LoD output_lod; + output_lod.push_back(output_lod0); + output->set_lod(output_lod); + // resize output dims + output->Resize({static_cast(output_lod0.back()), 1}); + // for empty sequence + if (output_lod0.back() == 0) { + output->Resize({1, 1}); + output_data = output->mutable_data(ctx.GetPlace()); + output_data[0] = -1; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc deleted file mode 100644 index a082dbbcb8bcb5..00000000000000 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ /dev/null @@ -1,285 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/core/infermeta_utils.h" - -#include "paddle/phi/infermeta/multiary.h" - -namespace paddle { -namespace operators { - -class CudnnLSTMOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context().GetPlace()); - } -}; - -class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "Input", - "(Tensor) RNN input tensor, which support variable-time length input " - "sequence." - "The shape of the Tensor MUST be ( seq_len * batch_size * input_size)" - "seq_len is the total time step in this mini-batch (CAN be change in " - "different batch)" - "batch_size is the instance number of this batch" - "input_size is the hidden size of the input." - "input_size and the hidden_size in the next may not be same"); - AddInput("InitH", - "(Tensor) the initial hidden state of the LSTM" - "input. 
This is a tensor with shape (num_layers x batch_size x " - "hidden_size)" - "and When is_bidirec is True, the shape will be (num_layers*2 x " - "batch_size x hidden_size)"); - AddInput("InitC", - "(Tensor) the initial cell state of the LSTm " - "input. This is a tensor with shape (num_layers x batch_size x " - "hidden_size)" - "and When is_bidirec is True, the shape will be (num_layers*2 x " - "batch_size x hidden_size)"); - AddInput("W", - "(Tensor) the learnable hidden-hidden weights." - " The shape is (N), where N is total weight size of the LSTM. " - " cudnn concatenate all the weight to one Tensor") - .AsDispensable(); - AddInput("WeightList", - "(vector), stores weight and bias data when the weight " - "use the list format. ") - .AsDispensable() - .AsDuplicable(); - AddInput("SequenceLength", - "(Tensor) When the input data is padding, " - "set this parameter. This parameter represents " - "the variable sequence lengths in a batch. " - "The size of the vector has to equal the batch_size.") - .AsDispensable(); - AddOutput("Reserve", - "(Tensor, a temporary output Tensor to store the reserve_data " - "of cudnn kernel.") - .AsIntermediate(); - AddOutput("StateOut", - "Share memory with State. " - "Store the global drop state when training"); - AddOutput("Out", - "(Tensor) the hidden state of LSTM operator. " - "The shape is ( seq_len x batch_size x hidden_size) if " - "is_bidirec is False" - "and When is_bidirec is True, the shape will be ( seq_len x " - "batch_size x hidden_size * 2) "); - AddOutput("LastH", - "(Tensor) the hidden state of the last step. " - "The shape is ( num_layers x batch_size x hidden_size) if " - "is_bidirec is False" - "and When is_bidirec is True, the shape will be (num_layers*2 x " - "batch_size x hidden_size)"); - AddOutput("LastC", - "(Tensor) the cell state of the last step" - "The shape is ( num_layers x batch_size x hidden_size) if " - "is_bidirec is False" - "and When is_bidirect is True, the shape will be (num_layers*2 x " - "batch_size x hidden_size*2)"); - AddAttr( - "dropout_prob", - "dropout prob of the dropout op" - "the dropout ONLY work between lstm layers, not between time steps" - "There is no dropout work on the Out tensor") - .SetDefault(0.0); - AddAttr("is_bidirec", - "is_bidirec" - "if it is bidirectional rnn" - "The will affect the shape of the Out, LastH, and LastC") - .SetDefault(false); - AddAttr("input_size", "input size ot the Input Tensor").SetDefault(10); - AddAttr("hidden_size", "hidden size of the LSTM").SetDefault(100); - AddAttr("num_layers", "the total layer number of the LSTM") - .SetDefault(1); - AddAttr("is_test", "True if in test phase.").SetDefault(false); - AddAttr("seed", "seed to used if fix_seed is True").SetDefault(0); - AddComment(R"DOC( -CUDNN LSTM implementation - -A four-gate Long Short-Term Memory network with no peephole connections. -In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1, -the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations: - -$$ i_t = sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$ - -$$ f_t = sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) $$ - -$$ o_t = sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) $$ - -$$ \\tilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) $$ - -$$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$ - -$$ h_t = o_t \\odot tanh(c_t) $$ - -- W terms denote weight matrices (e.g. 
$W_{ix}$ is the matrix - of weights from the input gate to the input) -- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector). -- sigmoid is the logistic sigmoid function. -- $i, f, o$ and $c$ are the input gate, forget gate, output gate, - and cell activation vectors, respectively, all of which have the same size as - the cell output activation vector $h$. -- The $\odot$ is the element-wise product of the vectors. -- `tanh` is the activation functions. -- $\tilde{c_t}$ is also called candidate hidden state, - which is computed based on the current input and the previous hidden state. - -Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication, -X represents a matrix multiplication - - -)DOC"); - } -}; - -class CudnnLSTMGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad"); - OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad"); - OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad"); - - auto SetOutGradDim = [&ctx](const std::string& name) { - auto g_name = framework::GradVarName(name); - if (ctx->HasOutput(g_name)) { - ctx->SetOutputDim(g_name, ctx->GetInputDim(name)); - } - }; - - SetOutGradDim("Input"); - if (ctx->HasInputs("WeightList")) { - ctx->SetOutputsDim(framework::GradVarName("WeightList"), - ctx->GetInputsDim("WeightList")); - } - SetOutGradDim("InitH"); - SetOutGradDim("InitC"); - } - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context().GetPlace()); - } -}; - -template -class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("cudnn_lstm_grad"); - op->SetInput("Input", this->Input("Input")); - op->SetInput("InitH", this->Input("InitH")); - op->SetInput("InitC", this->Input("InitC")); - if (this->HasInput("WeightList")) { - op->SetInput("WeightList", this->Input("WeightList")); - } - if (this->HasInput("SequenceLength")) { - op->SetInput("SequenceLength", this->Input("SequenceLength")); - } - op->SetInput("Reserve", this->Output("Reserve")); - op->SetInput("StateOut", this->Output("StateOut")); - op->SetInput("Out", this->Output("Out")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC")); - op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH")); - - if (this->HasInput("WeightList")) { - op->SetOutput(framework::GradVarName("WeightList"), - this->InputGrad("WeightList", false)); - } - - op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); - op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH")); - op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC")); - op->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -DECLARE_INFER_SHAPE_FUNCTOR(cudnn_lstm, - CudnnLSTMInferShapeFunctor, - PD_INFER_META(phi::CudnnLSTMInferMeta)); - -namespace ops = paddle::operators; -REGISTER_OPERATOR(cudnn_lstm, - ops::CudnnLSTMOp, - ops::CudnnLSTMOpMaker, 
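The gate equations in the cudnn_lstm comment being removed here are the standard four-gate LSTM without peephole connections. As a reading aid only, a scalar per-unit sketch of one time step; the pre-activation sums W_{gx} x_t + W_{gh} h_{t-1} + bx_g + bh_g are assumed to be computed already:

#include <cmath>

struct LstmGateInputs {
  // Pre-activations for the input, forget, output and cell gates.
  float i_pre, f_pre, o_pre, c_pre;
};

struct LstmState { float h, c; };

static float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// One LSTM step for a single hidden unit, following the comment's equations:
// c_t = f * c_{t-1} + i * tanh(c_pre), h_t = o * tanh(c_t).
LstmState LstmStep(const LstmGateInputs& g, const LstmState& prev) {
  float i = Sigmoid(g.i_pre);
  float f = Sigmoid(g.f_pre);
  float o = Sigmoid(g.o_pre);
  float c_tilde = std::tanh(g.c_pre);
  LstmState next;
  next.c = f * prev.c + i * c_tilde;
  next.h = o * std::tanh(next.c);
  return next;
}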
- ops::CudnnLSTMGradOpMaker, - ops::CudnnLSTMGradOpMaker, - CudnnLSTMInferShapeFunctor); - -REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp); - -// TODO(Shixiaowei02) Add ModifyInput support -REGISTER_OP_VERSION(cudnn_lstm) - .AddCheckpoint( - R"ROC( - Upgrade cudnn_lstm add new inputs [WeightList, SequenceLength], modify the input [W] to dispensable, delete the input [Cache]. - Upgrade cudnn_lstm add new outputs [StateOut, Reserve, LastC, LastH], delete output [last_c, last_h]. - Upgrade cudnn_lstm modify the attr [seed] default value to 0, delete the attr [max_len].)ROC", - paddle::framework::compatible::OpVersionDesc() - .NewInput( - "WeightList", - "The WeightList stores weight and bias data. WeightList is " - "dispensable.") - .NewInput("SequenceLength", - "When the input data is padding, set this parameter. " - "SequenceLength is dispensable.") - .ModifyInput("W", - "The new LSTM use WeightList instead of W. The W " - "concatenate all the weight to one Tensor.") - .DeleteInput("Cache", - "The new LSTM use the Reserve Output to store the " - "data of dropout.") - .NewOutput("StateOut", "Store the global drop state when training") - .NewOutput("Reserve", - "A temporary output Tensor to store the reserve_data") - .DeleteOutput( - "last_c", - "Modify the name of the output from 'last_c' to 'LastC'.") - .NewOutput("LastC", "The cell state of the last step.") - .DeleteOutput( - "last_h", - "Modify the name of the output from 'last_h' to 'LastH'.") - .NewOutput("LastH", "The hidden state of the last step.") - .ModifyAttr("seed", - "Set the default value of seed from '-1' to '0'.", - 0) - .DeleteAttr("max_len", - "The length of Inputs is achieved form the input data " - "which is difficult to know the information in " - "advance.")); diff --git a/paddle/fluid/operators/cudnn_rnn_cache.h b/paddle/fluid/operators/cudnn_rnn_cache.h index 7dd81d230bd1d2..e51d558a36c618 100644 --- a/paddle/fluid/operators/cudnn_rnn_cache.h +++ b/paddle/fluid/operators/cudnn_rnn_cache.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/detection/bbox_util.cu.h b/paddle/fluid/operators/detection/bbox_util.cu.h index adb60a8a8d0642..abd34c3c2025a2 100644 --- a/paddle/fluid/operators/detection/bbox_util.cu.h +++ b/paddle/fluid/operators/detection/bbox_util.cu.h @@ -23,8 +23,8 @@ limitations under the License. 
*/ #include namespace cub = hipcub; #endif -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu index b2bbd9c82095c8..65cb7d3043d18d 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu @@ -23,10 +23,10 @@ namespace cub = hipcub; #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/mixed_vector.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/strided_memcpy.h" diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 713ad1931ce236..e556949fa0f0e0 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -6,7 +6,6 @@ endif() register_operators( EXCLUDES fused_bn_activation_op - yolo_box_head_op yolo_box_post_op fusion_group_op fusion_lstm_op @@ -39,7 +38,6 @@ if(WITH_GPU OR WITH_ROCM) endif() # HIP not support cudnnTransformTensor # HIP not support cudnnConvolutionBiasActivationForward - op_library(yolo_box_head_op) op_library(yolo_box_post_op) op_library(fused_gate_attention_op) # fusion_group diff --git a/paddle/fluid/operators/fused/attention_layer_norm.h b/paddle/fluid/operators/fused/attention_layer_norm.h deleted file mode 100644 index 92cbc37059eb14..00000000000000 --- a/paddle/fluid/operators/fused/attention_layer_norm.h +++ /dev/null @@ -1,113 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
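Several hunks in this patch (array_to_lod_tensor_op.cc, c_concat_op.cu.cc, collect_fpn_proposals_op.cu) replace the fluid-side math::ConcatFunctor with the functor from paddle/phi/kernels/funcs/concat_and_split_functor.h while keeping the call shape. A sketch of the post-migration pattern, assuming the phi::funcs::ConcatFunctor<Context, T> interface implied by those hunks:

#include <vector>

#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"

// Sketch of the post-migration call pattern; dev_ctx, inputs, axis and out
// are assumed to be prepared by the surrounding kernel.
template <typename Context, typename T>
void ConcatWithPhi(const Context& dev_ctx,
                   const std::vector<phi::DenseTensor>& inputs,
                   int axis,
                   phi::DenseTensor* out) {
  phi::funcs::ConcatFunctor<Context, T> functor;  // was: math::ConcatFunctor
  functor(dev_ctx, inputs, axis, out);
}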
*/ - -#pragma once - -#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h" - -namespace paddle { -namespace operators { - -// NOTE: T must be the same as OutType in ComputeBackward -template -class AttnLayerNorm { - public: - AttnLayerNorm(const phi::GPUContext& dev_ctx, - float epsilon, - int64_t batch_size, - int64_t feature_size) - : dev_ctx_(dev_ctx), - epsilon_(epsilon), - batch_size_(batch_size), - feature_size_(feature_size) {} - - ~AttnLayerNorm() {} - - void ComputeForward(const InType* x_data, - const phi::funcs::LayerNormParamType* scale_data, - const phi::funcs::LayerNormParamType* bias_data, - OutType* y_data, - phi::funcs::LayerNormParamType* mean_data, - phi::funcs::LayerNormParamType* var_data, - const float* dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, - const float quant_in_scale = 1.0, - const int quant_round_type = 1, - const float quant_max_bound = 127.0, - const float quant_min_bound = -127.0) { - auto stream = dev_ctx_.stream(); - - switch (phi::funcs::GetDesiredBlockDim(feature_size_)) { - FIXED_BLOCK_DIM_CASE( - phi::funcs::LayerNormForward, - kBlockDim, - false, - InType, - OutType> - <<>>(x_data, - scale_data, - bias_data, - y_data, - mean_data, - var_data, - epsilon_, - feature_size_, - dequant_out_scale_data, - quant_out_scale_offset, - quant_in_scale, - quant_round_type, - quant_max_bound, - quant_min_bound)); - default: - PADDLE_THROW( - phi::errors::InvalidArgument("Feature_size must be larger than 1")); - break; - } - } - - void ComputeBackward(const T* x_data, - const T* d_y_data, - const phi::funcs::LayerNormParamType* scale_data, - const phi::funcs::LayerNormParamType* mean_data, - const phi::funcs::LayerNormParamType* var_data, - T* d_x_data, - phi::funcs::LayerNormParamType* d_scale_data, - phi::funcs::LayerNormParamType* d_bias_data) { - phi::funcs::LayerNormBackward>( - x_data, - d_y_data, - scale_data, - mean_data, - var_data, - d_x_data, - d_scale_data, - d_bias_data, - epsilon_, - batch_size_, - feature_size_, - dev_ctx_); - } - - private: - const phi::GPUContext& dev_ctx_; - - int64_t batch_size_; - int64_t feature_size_; - - float epsilon_; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h index ba13879b5a8dea..7ebad5c07bf22e 100644 --- a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h +++ b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h @@ -15,7 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { @@ -23,15 +23,16 @@ namespace operators { namespace dynload = phi::dynload; template using BatchNormParamType = - typename platform::CudnnDataType::BatchNormParamType; + typename phi::backends::gpu::CudnnDataType::BatchNormParamType; #if CUDNN_VERSION >= 8000 template struct BNStatsFinalizeArgs { BNStatsFinalizeArgs() { - dtype = platform::CudnnDataType::type; - param_dtype = platform::CudnnDataType>::type; + dtype = phi::backends::gpu::CudnnDataType::type; + param_dtype = + phi::backends::gpu::CudnnDataType>::type; format = CUDNN_TENSOR_NHWC; } diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h index b8f88e602b8517..ecfe4dad538432 100644 --- a/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h +++ b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h @@ -15,14 +15,15 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { namespace dynload = phi::dynload; template -using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; +using ScalingParamType = + typename phi::backends::gpu::CudnnDataType::ScalingParamType; #if CUDNN_VERSION >= 8000 @@ -31,9 +32,9 @@ static size_t RoundUp(int64_t a, int64_t b) { return (a + b - 1) / b * b; } template struct NormConvolutionArgs { NormConvolutionArgs() { - dtype = platform::CudnnDataType::type; + dtype = phi::backends::gpu::CudnnDataType::type; format = CUDNN_TENSOR_NHWC; - compute_type = platform::CudnnDataType::type; + compute_type = phi::backends::gpu::CudnnDataType::type; } void Set(const phi::GPUContext &ctx, diff --git a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h index 25a1c963a7f28d..768845476a428d 100644 --- a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h +++ b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h @@ -15,24 +15,25 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { template -using CudnnDataType = platform::CudnnDataType; +using CudnnDataType = phi::backends::gpu::CudnnDataType; namespace dynload = phi::dynload; template using BatchNormParamType = - typename platform::CudnnDataType::BatchNormParamType; + typename phi::backends::gpu::CudnnDataType::BatchNormParamType; #if CUDNN_VERSION >= 8000 template struct ScaleBiasAddReluArgs { ScaleBiasAddReluArgs() { - dtype = platform::CudnnDataType::type; - param_dtype = platform::CudnnDataType>::type; + dtype = phi::backends::gpu::CudnnDataType::type; + param_dtype = + phi::backends::gpu::CudnnDataType>::type; format = CUDNN_TENSOR_NHWC; } diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h deleted file mode 100644 index 2a43eea07535ab..00000000000000 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ /dev/null @@ -1,750 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
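The fused-op headers in this part of the patch all apply the same substitution: the fluid platform::CudnnDataType trait becomes phi::backends::gpu::CudnnDataType, with its dependent typedefs carried over unchanged. Written out with explicit template parameters (a sketch of the pattern, not an excerpt from the patch):

#include "paddle/phi/backends/gpu/gpu_dnn.h"

// Alias pattern assumed from the cudnn_* fused-op hunks above.
template <typename T>
using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;

template <typename T>
using BatchNormParamType =
    typename phi::backends::gpu::CudnnDataType<T>::BatchNormParamType;

template <typename T>
using ScalingParamType =
    typename phi::backends::gpu::CudnnDataType<T>::ScalingParamType;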
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/broadcast_function.h" -#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" -#include "paddle/phi/kernels/funcs/dropout_impl.cu.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" -#include "paddle/phi/kernels/funcs/elementwise_functor.h" -#include "paddle/phi/kernels/funcs/functors.h" -#include "paddle/phi/kernels/funcs/transpose_function.cu.h" -#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" - -namespace paddle { -namespace operators { - -class AttnDropoutParam { - public: - AttnDropoutParam() { - is_test_ = false; - dropout_implementation_ = "downgrade_in_infer"; - dropout_prob_ = 0.5; - is_upscale_in_train_ = false; - is_fix_seed_ = false; - seed_val_ = 0; - seed_ = nullptr; - } - AttnDropoutParam(bool is_test, - const std::string dropout_implementation, - float dropout_prob, - bool is_upscale_in_train, - bool is_fix_seed, - int seed_val, - const phi::DenseTensor* seed) { - is_test_ = is_test; - dropout_implementation_ = dropout_implementation; - dropout_prob_ = dropout_prob; - is_upscale_in_train_ = is_upscale_in_train; - is_fix_seed_ = is_fix_seed; - seed_val_ = seed_val; - seed_ = seed; - } - bool is_test_; - std::string dropout_implementation_; - float dropout_prob_; - bool is_upscale_in_train_; - bool is_fix_seed_; - int seed_val_; - const phi::DenseTensor* seed_; -}; - -template -__global__ void TransposeRemovingPadding(const T* input_data, - T* output_data, - const int batch_size, - const int num_head, - const int seq_len, - const int head_dim, - const int token_num, - const int elem_cnt, - const int* padding_offset) { - // transpose and remove padding - // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, - // head_dim] - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - const int dim_embed = num_head * head_dim; - using LoadT = phi::AlignedVector; - LoadT src_vec; - - for (int32_t linear_index = idx * VecSize, - step = gridDim.x * blockDim.x * VecSize; - linear_index < elem_cnt; - linear_index += step) { - const int token_idx = linear_index / dim_embed; - const int ori_token_idx = - token_idx + (padding_offset == nullptr ? 
0 : padding_offset[token_idx]); - const int ori_batch_id = ori_token_idx / seq_len; - const int ori_seq_id = ori_token_idx % seq_len; - const int ori_head_id = (linear_index % dim_embed) / head_dim; - const int ori_head_lane = (linear_index % dim_embed) % head_dim; - const int ori_idx = ori_batch_id * num_head * seq_len * head_dim + - ori_head_id * seq_len * head_dim + - ori_seq_id * head_dim + ori_head_lane; - phi::Load(&input_data[ori_idx], &src_vec); - phi::Store(src_vec, &output_data[linear_index]); - } -} - -template -void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx, - const T* input_data, - T* output_data, - const int batch_size, - const int num_head, - const int seq_len, - const int head_dim, - const int token_num, - const int* padding_offset) { - // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, - // head_dim] - constexpr int VEC_16B = 16; - const int elem_cnt = token_num * num_head * head_dim; - constexpr int PackSize = VEC_16B / sizeof(T); - PADDLE_ENFORCE_EQ( - head_dim % PackSize, - 0, - phi::errors::PreconditionNotMet( - "dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize)); - const int32_t pack_num = elem_cnt / PackSize; - const int32_t block_size = 128; - int32_t grid_size = (pack_num + block_size - 1) / block_size; - TransposeRemovingPadding - <<>>(input_data, - output_data, - batch_size, - num_head, - seq_len, - head_dim, - token_num, - elem_cnt, - padding_offset); -} - -template -class FMHARef { - public: - FMHARef(const phi::GPUContext& dev_ctx, - int64_t batch_size, - int64_t seq_len, - int64_t num_head, - int64_t head_dim, - AttnDropoutParam param) - : dev_ctx_(dev_ctx), - batch_size_(batch_size), - seq_len_(seq_len), - num_head_(num_head), - head_dim_(head_dim), - dropout_param_(param) {} - - ~FMHARef() {} - - void ComputeForward(const phi::DenseTensor& qkv_input_tensor, - const phi::DenseTensor* cache_kv_tensor, - const phi::DenseTensor* src_mask_tensor, - phi::DenseTensor* transpose_2_out_tensor, - phi::DenseTensor* cache_kv_out_tensor, - phi::DenseTensor* qk_out_tensor, - phi::DenseTensor* src_mask_out_tensor, - phi::DenseTensor* softmax_out_tensor, - phi::DenseTensor* dropout_mask_out_tensor, - phi::DenseTensor* dropout_out_tensor, - phi::DenseTensor* qktv_out_tensor, - phi::DenseTensor* fmha_out_tensor) { - // input shape: [bs, seq_len, 3, num_head, head_dim] - // transpose with perm [2, 0, 3, 1, 4], - // output_shape: [3, bs, num_head, seq_len, head_dim] - std::vector perm_1 = {2, 0, 3, 1, 4}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor); - T* qkv_data = transpose_2_out_tensor->data(); - T* qk_out_data = qk_out_tensor->data(); - T* qktv_out_data = qktv_out_tensor->data(); - T* softmax_out_data = softmax_out_tensor->data(); - T* fmha_out_data = fmha_out_tensor->data(); - - auto out_seq_len = seq_len_; - if (cache_kv_tensor) { - // kv [2, bs, num_head, seq_len, head_dim] - auto kv_tensor = transpose_2_out_tensor->Slice(1, 3); - phi::funcs::ConcatFunctor concat; - // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] - concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor); - out_seq_len = cache_kv_out_tensor->dims()[3]; - } - - int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - T* q_ptr = qkv_data; - T* k_ptr = nullptr; - T* v_ptr = nullptr; - - if (cache_kv_tensor) { - int64_t k_size = cache_kv_out_tensor->numel() / 2; - k_ptr = cache_kv_out_tensor->data(); - v_ptr = k_ptr + k_size; - } else { - int64_t k_size = 
q_size; - k_ptr = q_ptr + q_size; - v_ptr = k_ptr + k_size; - } - - { - // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for - // float16 calculation, INF may appear in QK^T if we do not scale before. - float alpha = 1.0 / sqrt(head_dim_); - auto q_tensor = transpose_2_out_tensor->Slice(0, 1); - auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {&q_tensor}; - std::vector outs = {&q_tensor}; - phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); - } - - // q*k^t, batched_gemm - CBLAS_TRANSPOSE transA = CblasNoTrans; - CBLAS_TRANSPOSE transB = CblasTrans; - auto blas = phi::funcs::GetBlas(dev_ctx_); - int gemm_batch_size = batch_size_ * num_head_; - int gemm_m = seq_len_; - int gemm_n = out_seq_len; - int gemm_k = head_dim_; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - int64_t stride_a = gemm_m * gemm_k; - int64_t stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - q_ptr, - k_ptr, - beta, - qk_out_data, - gemm_batch_size, - stride_a, - stride_b); - int softmax_axis = -1; - if (src_mask_tensor != nullptr) { - if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) { - LaunchFusedSoftmaxMaskKernel(qk_out_data, - src_mask_tensor->data(), - softmax_out_data, - batch_size_, - num_head_, - seq_len_, - dev_ctx_.stream()); - } else { - std::vector ins; - std::vector outs; - ins.emplace_back(qk_out_tensor); - ins.emplace_back(src_mask_tensor); - outs.emplace_back(src_mask_out_tensor); - int elewise_add_axis = -1; - phi::funcs::BroadcastKernel(dev_ctx_, - ins, - &outs, - phi::funcs::AddFunctor(), - elewise_add_axis); - - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); - } - } else { - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor); - } - - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = out_seq_len; - alpha = static_cast(1.0); - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - - if (dropout_param_.dropout_prob_) { - phi::funcs::DropoutFwGPUKernelDriver( - static_cast(dev_ctx_), - dropout_param_.is_test_, - dropout_param_.dropout_prob_, - dropout_param_.is_upscale_in_train_, - dropout_param_.is_fix_seed_, - dropout_param_.seed_val_, - static_cast(*softmax_out_tensor), - dropout_param_.seed_, - dropout_mask_out_tensor, - dropout_out_tensor, - false); - T* dropout_out_data = dropout_out_tensor->data(); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - dropout_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } else { - // softmax_out * v, batched_gemm - // output shape: [batch_size, num_heads, seq_len, head_dim] - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - softmax_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } - // transpose: [0, 2, 1, 3] - // output shape: [batch_size, seq_len, num_heads, head_dim] - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); - } - - void ComputeForwardWithoutTranspose( - const phi::DenseTensor* cache_kv_tensor, - const phi::DenseTensor* src_mask_tensor, - const phi::DenseTensor* padding_offset_tensor, - phi::DenseTensor* q_transpose_out_tensor, - phi::DenseTensor* kv_transpose_out_tensor, - phi::DenseTensor* cache_kv_out_tensor, - phi::DenseTensor* qk_out_tensor, - phi::DenseTensor* 
src_mask_out_tensor, - phi::DenseTensor* softmax_out_tensor, - phi::DenseTensor* dropout_mask_out_tensor, - phi::DenseTensor* dropout_out_tensor, - phi::DenseTensor* qktv_out_tensor, - phi::DenseTensor* fmha_out_tensor, - const int token_num) { - // input shape: [bs, seq_len, 3, num_head, head_dim] - // transpose with perm [2, 0, 3, 1, 4], - // output_shape: [3, bs, num_head, seq_len, head_dim] - T* qk_out_data = qk_out_tensor->data(); - T* qktv_out_data = qktv_out_tensor->data(); - T* softmax_out_data = softmax_out_tensor->data(); - T* dropout_out_data = dropout_out_tensor->data(); - T* fmha_out_data = fmha_out_tensor->data(); - - auto out_seq_len = seq_len_; - if (cache_kv_tensor) { - // kv [2, bs, num_head, seq_len, head_dim] - phi::funcs::ConcatFunctor concat; - // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] - concat(dev_ctx_, - {*cache_kv_tensor, *kv_transpose_out_tensor}, - 3, - cache_kv_out_tensor); - out_seq_len = cache_kv_out_tensor->dims()[3]; - } - - int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - T* q_ptr = q_transpose_out_tensor->data(); - T* k_ptr = nullptr; - T* v_ptr = nullptr; - - if (cache_kv_tensor) { - int64_t k_size = cache_kv_out_tensor->numel() / 2; - k_ptr = cache_kv_out_tensor->data(); - v_ptr = k_ptr + k_size; - } else { - int64_t k_size = q_size; - k_ptr = kv_transpose_out_tensor->data(); - v_ptr = k_ptr + k_size; - } - - { - // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for - // float16 calculation, INF may appear in QK^T if we do not scale before. - float alpha = 1.0 / sqrt(head_dim_); - auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {q_transpose_out_tensor}; - std::vector outs = {q_transpose_out_tensor}; - phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); - } - - // q*k^t, batched_gemm - CBLAS_TRANSPOSE transA = CblasNoTrans; - CBLAS_TRANSPOSE transB = CblasTrans; - auto blas = phi::funcs::GetBlas(dev_ctx_); - int gemm_batch_size = batch_size_ * num_head_; - int gemm_m = seq_len_; - int gemm_n = out_seq_len; - int gemm_k = head_dim_; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - int64_t stride_a = gemm_m * gemm_k; - int64_t stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - q_ptr, - k_ptr, - beta, - qk_out_data, - gemm_batch_size, - stride_a, - stride_b); - int softmax_axis = -1; - if (src_mask_tensor != nullptr) { - if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) { - LaunchFusedSoftmaxMaskKernel(qk_out_data, - src_mask_tensor->data(), - softmax_out_data, - batch_size_, - num_head_, - seq_len_, - dev_ctx_.stream()); - } else { - std::vector ins; - std::vector outs; - ins.emplace_back(qk_out_tensor); - ins.emplace_back(src_mask_tensor); - outs.emplace_back(src_mask_out_tensor); - int elewise_add_axis = -1; - phi::funcs::BroadcastKernel(dev_ctx_, - ins, - &outs, - phi::funcs::AddFunctor(), - elewise_add_axis); - - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); - } - } else { - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor); - } - - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = out_seq_len; - alpha = static_cast(1.0); - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - - if (dropout_param_.dropout_prob_) { - phi::funcs::DropoutFwGPUKernelDriver( - static_cast(dev_ctx_), - dropout_param_.is_test_, - dropout_param_.dropout_prob_, - 
dropout_param_.is_upscale_in_train_, - dropout_param_.is_fix_seed_, - dropout_param_.seed_val_, - static_cast(*softmax_out_tensor), - dropout_param_.seed_, - dropout_mask_out_tensor, - dropout_out_tensor, - false); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - dropout_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } else { - // softmax_out * v, batched_gemm - // output shape: [batch_size, num_heads, seq_len, head_dim] - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - softmax_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } - // transpose: [0, 2, 1, 3] - // output shape: [batch_size, seq_len, num_heads, head_dim] - if (!padding_offset_tensor) { - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); - } else { - InvokeTransposeRemovePadding(dev_ctx_, - qktv_out_data, - fmha_out_data, - batch_size_, - num_head_, - seq_len_, - head_dim_, - token_num, - padding_offset_tensor->data()); - } - } - - void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor, - const phi::DenseTensor* src_mask_tensor, - const phi::DenseTensor& softmax_out_tensor, - const phi::DenseTensor& dropout_mask_out_tensor, - const phi::DenseTensor& dropout_out_tensor, - const phi::DenseTensor& qk_out_tensor, - const phi::DenseTensor& src_mask_out_tensor, - const phi::DenseTensor& fmha_out_grad_tensor, - phi::DenseTensor* qktv_out_grad_tensor, - phi::DenseTensor* dropout_out_grad_tensor, - phi::DenseTensor* softmax_out_grad_tensor, - phi::DenseTensor* src_mask_out_grad_tensor, - phi::DenseTensor* qk_out_grad_tensor, - phi::DenseTensor* transpose_2_out_grad_tensor, - phi::DenseTensor* src_mask_grad_tensor, - phi::DenseTensor* qkv_input_grad_tensor) { - auto blas = phi::funcs::GetBlas(dev_ctx_); - int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - int k_size = q_size; - int softmax_axis = -1; - - T* qkv_grad_data = transpose_2_out_grad_tensor->data(); - T* q_grad_ptr = qkv_grad_data; - T* k_grad_ptr = q_grad_ptr + q_size; - T* v_grad_ptr = k_grad_ptr + k_size; - const T* qkv_data = transpose_2_out_tensor.data(); - const T* q_ptr = qkv_data; - const T* k_ptr = q_ptr + q_size; - const T* v_ptr = k_ptr + k_size; - - const T* softmax_out_data = softmax_out_tensor.data(); - T* softmax_out_grad_data = softmax_out_grad_tensor->data(); - T* qktv_out_grad_data = qktv_out_grad_tensor->data(); - - // transpose bw - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, fmha_out_grad_tensor, perm_3, qktv_out_grad_tensor); - - // recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) = - // qktv_out_data(out) - CBLAS_TRANSPOSE transA = CblasTrans; - CBLAS_TRANSPOSE transB = CblasNoTrans; - int gemm_batch_size = batch_size_ * num_head_; - int gemm_m = seq_len_; - int gemm_n = head_dim_; - int gemm_k = seq_len_; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - int64_t stride_a = gemm_m * gemm_k; - int64_t stride_b = gemm_k * gemm_n; - // bw: dy = x^t * dout - if (dropout_param_.dropout_prob_) { - const T* dropout_out_data = dropout_out_tensor.data(); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - dropout_out_data, - qktv_out_grad_data, - beta, - v_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - } else { - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - softmax_out_data, - 
qktv_out_grad_data, - beta, - v_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - } - // bw: dx = dout * y^t - transA = CblasNoTrans; - transB = CblasTrans; - gemm_m = seq_len_; - gemm_n = seq_len_; - gemm_k = head_dim_; - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - if (dropout_param_.dropout_prob_) { - T* dropout_out_grad_data = dropout_out_grad_tensor->data(); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qktv_out_grad_data, - v_ptr, - beta, - dropout_out_grad_data, - gemm_batch_size, - stride_a, - stride_b); - } else { - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qktv_out_grad_data, - v_ptr, - beta, - softmax_out_grad_data, - gemm_batch_size, - stride_a, - stride_b); - } - // dropout bw - if (dropout_param_.dropout_prob_) { - phi::funcs::DropoutGradGPUKernelDriver( - static_cast(dev_ctx_), - false, - dropout_param_.dropout_prob_, - dropout_param_.is_upscale_in_train_, - static_cast(*dropout_out_grad_tensor), - dropout_mask_out_tensor, - softmax_out_grad_tensor, - false); - } - - if (src_mask_tensor != nullptr) { - phi::SoftmaxBackwardCUDAKernelDriver(dev_ctx_, - softmax_out_tensor, - *softmax_out_grad_tensor, - softmax_axis, - src_mask_out_grad_tensor); - // recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out + - // src_mask - // Special case when dy is not needed and dx doesn't reduce - if (qk_out_grad_tensor != nullptr && src_mask_grad_tensor == nullptr && - qk_out_tensor.dims() == src_mask_out_tensor.dims()) { - VLOG(4) << "Special case when dy is not needed and dx doesn't " - "reduce"; - framework::TensorCopy(*src_mask_out_grad_tensor, - dev_ctx_.GetPlace(), - dev_ctx_, - qk_out_grad_tensor); - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "Only used for the backward elementwise_add op when" - "dy is not needed and dx is not reduce")); - return; - } - - } else { - phi::SoftmaxBackwardCUDAKernelDriver(dev_ctx_, - softmax_out_tensor, - *softmax_out_grad_tensor, - softmax_axis, - qk_out_grad_tensor); - } - - T* qk_out_grad_data = qk_out_grad_tensor->data(); - // NOTE(wangxi): For we scale Q with 1/sqrt(Dh) in forward, so we set - // alpha = 1.0 in backward. 
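// A minimal sketch of that scaling bookkeeping (illustrative, assuming the
// in-place Q scaling done in the forward pass):
//   forward stored q_ptr = Q / sqrt(Dh), so S = q_ptr * K^T already carries the scale;
//   dK = dS^T * q_ptr          -> alpha = 1.0         (scale is baked into q_ptr)
//   dQ = (dS * K) / sqrt(Dh)   -> alpha = 1/sqrt(Dh)  (the scale must be re-applied once)
// e.g. with head_dim_ = 64 the dQ GEMM further below uses alpha = 0.125.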
- alpha = static_cast(1.0); - // recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out - // bw: dy (seq_len * head_dim) = (dout)^t * x - transA = CblasTrans; - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = seq_len_; - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qk_out_grad_data, - q_ptr, - beta, - k_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - // dx (seq_len * head_dim) = dout * y - alpha = static_cast(1.0 / sqrt(head_dim_)); - transA = CblasNoTrans; - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = seq_len_; - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qk_out_grad_data, - k_ptr, - beta, - q_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - - // transpose bw - std::vector perm_1 = {1, 3, 0, 2, 4}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *transpose_2_out_grad_tensor, perm_1, qkv_input_grad_tensor); - } - - private: - const phi::GPUContext& dev_ctx_; - - int64_t batch_size_; - int64_t seq_len_; - int64_t num_head_; - int64_t head_dim_; - - AttnDropoutParam dropout_param_; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index b696a183170c33..11614d70165d3a 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -61,8 +61,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { auto ln_scales = ctx.MultiInput("LnScale"); auto ln_biases = ctx.MultiInput("LnBias"); - auto ln_compute = - AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); + auto ln_compute = phi::fusion::AttnLayerNorm( + dev_ctx, epsilon, bsz_seq, dim_embed); phi::DenseTensor ln_mean, ln_var; ln_mean.Resize({{bsz_seq}}); auto *ln_mean_data = @@ -93,10 +93,10 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); // 3. fmha - AttnDropoutParam attn_param( + phi::fusion::AttnDropoutParam attn_param( true, "upscale_in_train", 0.0, true, true, 0, nullptr); - auto fmha_compute = - FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto fmha_compute = phi::fusion::FMHARef( + dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); auto *src_mask = ctx.Input("SrcMask"); auto cache_kvs = ctx.MultiInput("CacheKV"); auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 75a4c7b275a8a5..b3718dfe1f7d51 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -125,7 +125,8 @@ void FusedMultiTransformerKernel( auto *padding_offset_data = encoder_remove_padding ? padding_offset_tensor.data() : nullptr; - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + auto ln_compute = + phi::fusion::AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); phi::DenseTensor ln_mean, ln_var; ln_mean.Resize({token_num}); auto *ln_mean_data = @@ -800,7 +801,8 @@ void FusedMultiTransformerKernel( // 1. 
layer norm - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + auto ln_compute = + phi::fusion::AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); phi::DenseTensor ln_mean, ln_var; ln_mean.Resize({token_num}); auto *ln_mean_data = diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index 4bf467e9caf8fa..0a57fb9e873414 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -27,17 +27,16 @@ limitations under the License. */ #include "paddle/common/flags.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/fused/attention_layer_norm.h" -#include "paddle/fluid/operators/fused/fmha_ref.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/fusion/gpu/attn_gemm.h" +#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/process_group.h" @@ -711,13 +710,13 @@ struct Qk_dot { } }; -template +template inline __device__ float block_sum(float *red_smem, float sum) { - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; + int warp = threadIdx.x / WARP_SIZE_T; + int lane = threadIdx.x % WARP_SIZE_T; #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + for (int mask = WARP_SIZE_T / 2; mask >= 1; mask /= 2) { sum += __shfl_xor_sync(uint32_t(-1), sum, mask); } @@ -789,8 +788,8 @@ __global__ void masked_multihead_attention_kernel( static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - constexpr int WARP_SIZE = 32; - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + constexpr int WARP_SIZE_TMP = 32; + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE_TMP; extern __shared__ char smem_[]; @@ -824,7 +823,7 @@ __global__ void masked_multihead_attention_kernel( constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); // Use block reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE_TMP, ""); constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; // cache_k, [B, num_head, head_dim / x, max_seq_len, x] @@ -944,16 +943,16 @@ __global__ void masked_multihead_attention_kernel( qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { + if (QK_VECS_PER_WARP <= WARP_SIZE_TMP) { #pragma unroll for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); } } } - if (QK_VECS_PER_WARP > WARP_SIZE) { + if (QK_VECS_PER_WARP > WARP_SIZE_TMP) { constexpr int WARPS_PER_RED = - (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + (QK_VECS_PER_WARP + WARP_SIZE_TMP - 1) / WARP_SIZE_TMP; qk = block_sum(&red_smem[WARPS_PER_RED], qk); } if (tid == 0) { @@ -994,7 +993,7 @@ __global__ void 
masked_multihead_attention_kernel( } constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + constexpr int K_PER_WARP = WARP_SIZE_TMP / THREADS_PER_KEY; T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; @@ -1031,12 +1030,12 @@ __global__ void masked_multihead_attention_kernel( } #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + for (int mask = WARP_SIZE_TMP / 2; mask >= THREADS_PER_KEY; mask /= 2) { qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } - const int warp = tid / WARP_SIZE; - const int lane = tid % WARP_SIZE; + const int warp = tid / WARP_SIZE_TMP; + const int lane = tid % WARP_SIZE_TMP; if (lane == 0) { red_smem[warp] = qk_max; diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cu b/paddle/fluid/operators/fused/resnet_unit_op.cu index 2955fd3b453b4d..f715bda6906951 100644 --- a/paddle/fluid/operators/fused/resnet_unit_op.cu +++ b/paddle/fluid/operators/fused/resnet_unit_op.cu @@ -31,7 +31,7 @@ class ResNetUnitKernel : public framework::OpKernel { platform::is_gpu_place(ctx.GetPlace()), true, phi::errors::PreconditionNotMet("It must use CUDAPlace.")); - PADDLE_ENFORCE_EQ(platform::CudnnDataType::type, + PADDLE_ENFORCE_EQ(phi::backends::gpu::CudnnDataType::type, CUDNN_DATA_HALF, phi::errors::Unavailable( "ResNetUnitOp only supports float16 for now.")); @@ -231,7 +231,7 @@ class ResNetUnitGradKernel : public framework::OpKernel { platform::is_gpu_place(ctx.GetPlace()), true, phi::errors::PreconditionNotMet("It must use CUDAPlace.")); - PADDLE_ENFORCE_EQ(platform::CudnnDataType::type, + PADDLE_ENFORCE_EQ(phi::backends::gpu::CudnnDataType::type, CUDNN_DATA_HALF, phi::errors::Unavailable( "ResNetUnitOp only supports float16 for now.")); diff --git a/paddle/fluid/operators/fused/xpu_fused_common_function.h b/paddle/fluid/operators/fused/xpu_fused_common_function.h deleted file mode 100644 index 63a22838e8c35e..00000000000000 --- a/paddle/fluid/operators/fused/xpu_fused_common_function.h +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
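// NOTE: a minimal sketch of the warp butterfly that the renamed block_sum /
// qk_max reductions above rely on (warp_sum is an illustrative name, not part
// of this patch; a 32-lane warp is assumed):
//
//   __device__ float warp_sum(float v) {
//     #pragma unroll
//     for (int mask = 32 / 2; mask >= 1; mask /= 2)
//       v += __shfl_xor_sync(0xffffffffu, v, mask);  // exchange with lane id ^ mask
//     return v;  // after log2(32) = 5 steps every lane holds the full warp sum
//   }
//
// block_sum then stages one partial sum per warp in shared memory (red_smem)
// and reduces those partials the same way, giving the block-wide sum in two rounds.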
- -#pragma once - -#ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/device/device_wrapper.h" - -namespace paddle { -namespace operators { - -struct XPUDropoutParam { - float dropout_prob; - bool is_upscale_in_train; - bool is_test; - bool fix_seed; - const phi::DenseTensor *tensor_seed; - int seed_val; - - XPUDropoutParam() { - fix_seed = false; - is_test = false; - is_upscale_in_train = false; - dropout_prob = 0.5; - tensor_seed = nullptr; - seed_val = 0; - } - - XPUDropoutParam(const framework::ExecutionContext &context, - const int dropout_index) { - std::string pre_fix = "dropout"; - std::string str_index = std::to_string(dropout_index); - if (dropout_index > 0) { - pre_fix = pre_fix + str_index + "_"; - } else { - pre_fix = pre_fix + "_"; - } - dropout_prob = context.Attr(pre_fix + "rate"); - auto &dropout_implementation = - context.Attr(pre_fix + "implementation"); - is_upscale_in_train = (dropout_implementation == "upscale_in_train"); - is_test = context.Attr("is_test"); - fix_seed = context.Attr(pre_fix + "fix_seed"); - - std::string str_seed = "Dropout"; - if (dropout_index > 0) { - str_seed = str_seed + str_index + "Seed"; - } else { - str_seed = str_seed + "Seed"; - } - - tensor_seed = context.HasInput(str_seed) - ? context.Input(str_seed) - : nullptr; - if (tensor_seed) { - seed_val = *(tensor_seed->data()); - } else { - seed_val = fix_seed ? context.Attr(pre_fix + "seed") : 0; - } - } - - void initXPUDropoutParam(float dropout_prob_, - bool is_upscale_in_train_, - bool is_test_, - bool fix_seed_, - const phi::DenseTensor *tensor_seed, - int seed_val_) { - dropout_prob = dropout_prob_; - is_upscale_in_train = is_upscale_in_train_; - is_test = is_test_; - fix_seed = fix_seed_; - if (tensor_seed) { - seed_val = *(tensor_seed->data()); - } else { - seed_val = fix_seed ? seed_val_ : 0; - } - } - - void initXPUDropoutParam(const framework::ExecutionContext &context, - int dropout_index) { - std::string pre_fix = "dropout"; - std::string str_index = std::to_string(dropout_index); - if (dropout_index > 0) { - pre_fix = pre_fix + str_index + "_"; - } else { - pre_fix = pre_fix + "_"; - } - dropout_prob = context.Attr(pre_fix + "rate"); - auto &dropout_implementation = - context.Attr(pre_fix + "implementation"); - is_upscale_in_train = (dropout_implementation == "upscale_in_train"); - is_test = context.Attr("is_test"); - fix_seed = context.Attr(pre_fix + "fix_seed"); - std::string str_seed = "Dropout"; - if (dropout_index > 0) { - str_seed = str_seed + str_index + "Seed"; - } else { - str_seed = str_seed + "Seed"; - } - tensor_seed = context.HasInput(str_seed) - ? context.Input(str_seed) - : nullptr; - - if (tensor_seed) { - seed_val = *(tensor_seed->data()); - } else { - seed_val = fix_seed ? 
context.Attr(pre_fix + "seed") : 0; - } - } -}; - -/****************** - * check is l3 - *******************/ - -static bool is_in_l3(const void *addr) { - int64_t addr_int = (int64_t)addr; - int addr_int_high = addr_int >> 32; - return (addr_int_high == 0); -} - -/************************* - * dropout - *************************/ - -template -void Dropout(xpu::Context *xpu_ctx, - const T *x, - T *mask, - T *y, - const XPUDropoutParam ¶m, - int len) { - using XPUType = typename XPUTypeTrait::Type; - int r = XPU_SUCCESS; - if (param.dropout_prob == 0.0f) { - r = xpu::copy(xpu_ctx, - reinterpret_cast(x), - reinterpret_cast(y), - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - return; - } - if (!param.is_test) { - if (param.dropout_prob == 1.0f) { - r = xpu::constant( - xpu_ctx, reinterpret_cast(y), len, XPUType(0)); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); - r = xpu::constant( - xpu_ctx, reinterpret_cast(mask), len, XPUType(0)); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); - } else { - r = xpu::dropout(xpu_ctx, - reinterpret_cast(x), - reinterpret_cast(y), - reinterpret_cast(mask), - param.seed_val, - len, - param.is_upscale_in_train, - param.dropout_prob); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout"); - } - } else { - float scale = (param.is_upscale_in_train) - ? (1.0) - : (static_cast(1.0f - param.dropout_prob)); - r = xpu::scale(xpu_ctx, - reinterpret_cast(x), - reinterpret_cast(y), - len, - false, - scale, - 0.0f); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); - } -} - -template -void DropoutGrad(xpu::Context *xpu_ctx, - const T *dy, - const T *mask, - T *dx, - const XPUDropoutParam ¶m, - int len) { - using XPUType = typename XPUTypeTrait::Type; - if (param.dropout_prob == 0.0f) { - int r = xpu::copy(xpu_ctx, - reinterpret_cast(dy), - reinterpret_cast(dx), - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - return; - } - if (!param.is_upscale_in_train) { - int r = xpu::mul(xpu_ctx, - reinterpret_cast(dy), - reinterpret_cast(mask), - reinterpret_cast(dx), - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul"); - } else { - int r = xpu::dropout_grad(xpu_ctx, - reinterpret_cast(mask), - reinterpret_cast(dy), - reinterpret_cast(dx), - param.dropout_prob, - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad"); - } -} - -} // namespace operators -} // namespace paddle -#endif diff --git a/paddle/fluid/operators/fused/yolo_box_head_op.cc b/paddle/fluid/operators/fused/yolo_box_head_op.cc deleted file mode 100644 index 9a4e7b56434efc..00000000000000 --- a/paddle/fluid/operators/fused/yolo_box_head_op.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class YoloBoxHeadOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "yolo_box_head"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "yolo_box_head"); - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - } -}; - -class YoloBoxHeadOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "The input tensor"); - AddAttr>("anchors", - "The anchor width and height, " - "it will be parsed pair by pair."); - AddAttr("class_num", "The number of classes to predict."); - AddOutput("Out", "The output tensor"); - AddComment(R"DOC( - yolo_box_head Operator. - )DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(yolo_box_head, ops::YoloBoxHeadOp, ops::YoloBoxHeadOpMaker); diff --git a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc index 8831a40440d676..728e6007c1c2c4 100644 --- a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc +++ b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc @@ -16,7 +16,7 @@ limitations under the License. */ // HIP not support cudnnSpatialTfGridGeneratorForward #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace phi { class DenseTensor; diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index 42f6a4786fb25b..ff9197f40f8d76 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/core/lod_utils.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" namespace paddle { namespace framework { @@ -88,7 +88,7 @@ struct LoDTensorToArrayFunctor { template template void LoDTensorToArrayFunctorImpl::apply() { - math::SplitFunctor func; + phi::funcs::SplitFunctor func; func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, diff --git a/paddle/fluid/operators/lookup_table_v2_op.cu b/paddle/fluid/operators/lookup_table_v2_op.cu deleted file mode 100644 index 8628965251ee75..00000000000000 --- a/paddle/fluid/operators/lookup_table_v2_op.cu +++ /dev/null @@ -1,254 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/lookup_table_v2_op.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/backends/gpu/gpu_primitives.h" -#include "paddle/phi/common/float16.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" - -namespace paddle { -namespace operators { - -template -__global__ void LookupTableV2(T *output, - const T *table, - const IdT *ids, - const int64_t N, - const int64_t K, - const int64_t D, - const int64_t padding_idx) { - int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * gridDim.x; - - while (idy < K) { - auto id = static_cast(ids[idy]); - T *out = output + idy * D; - const T *tab = table + id * D; - for (int i = idx; i < D; i += blockDim.x) { - if (PaddingFlag) { - if (id == padding_idx) - out[i] = static_cast(0); - else - out[i] = tab[i]; - } else { - out[i] = tab[i]; - } - } - idy += blockDim.y * gridDim.x; - } -} - -template -__global__ void LookupTableV2Grad(T *table, - const T *output, - const IdT *ids, - const int64_t N, - const int64_t K, - const int64_t D) { - int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * gridDim.x; - - while (idy < K) { - auto id = static_cast(ids[idy]); - const T *out = output + idy * D; - T *tab = table + id * D; -#ifdef PADDLE_WITH_CUDA - phi::VectorizedAtomicAddPerBlock(D, idx, blockDim.x, out, tab); -#else - for (int i = idx; i < D; i += blockDim.x) { - phi::CudaAtomicAdd(&tab[i], out[i]); - } -#endif - idy += blockDim.y * gridDim.x; - } -} - -template -struct LookupTableV2CUDAFunctor { - LookupTableV2CUDAFunctor(const framework::ExecutionContext &context, - const phi::DenseTensor *ids_t) - : context_(context), ids_t_(ids_t) {} - - template - void apply() { - auto *table_t = context_.Input("W"); - auto *output_t = context_.Output("Out"); - int64_t padding_idx = context_.Attr("padding_idx"); - - size_t N = table_t->dims()[0]; - size_t D = table_t->dims()[1]; - size_t K = ids_t_->numel(); - - const int gridx = 2 * context_.cuda_device_context().GetSMCount(); - dim3 threads(256, 4); - dim3 grids(gridx, 1); - - const auto *table = table_t->template data(); - const auto *ids = ids_t_->template data(); - auto *output = output_t->template mutable_data(context_.GetPlace()); - auto stream = context_.cuda_device_context().stream(); - - if (padding_idx == -1) { - LookupTableV2<<>>( - output, table, ids, N, K, D, padding_idx); - } else { - LookupTableV2<<>>( - output, table, ids, N, K, D, padding_idx); - } - } - - private: - const framework::ExecutionContext &context_; - const phi::DenseTensor *ids_t_; -}; - -template -class LookupTableV2CUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - const auto *ids_t = context.Input("Ids"); - LookupTableV2CUDAFunctor functor(context, ids_t); - framework::VisitIntDataType(framework::TransToProtoVarType(ids_t->dtype()), - functor); - } -}; - -template -__global__ void InputTypeConvert(const InT *in_ids, - const int64_t K, - OutT *out_ids) { - for (int i = 0; i < K; i++) { - out_ids[i] = static_cast(in_ids[i]); - } -} - -template -struct LookupTableV2GradCUDAFunctor { - LookupTableV2GradCUDAFunctor(const framework::ExecutionContext &context, - const phi::DenseTensor *ids_t) - : context_(context), ids_t_(ids_t) {} - - template - void apply() { - auto &dev_ctx = context_.template device_context(); - bool is_sparse = context_.Attr("is_sparse"); - - // Since paddings are not trainable and fixed in forward, the gradient of - // paddings makes no sense and we don't deal with it in 
backward. - if (is_sparse) { - auto *table = context_.Input("W"); - auto *d_output = - context_.Input(framework::GradVarName("Out")); - auto *d_table = - context_.Output(framework::GradVarName("W")); - - const auto *ids_data = ids_t_->template data(); - int64_t ids_num = ids_t_->numel(); - dim3 threads(128, 8); - dim3 grids(8, 1); - auto stream = dev_ctx.stream(); - phi::Vector new_rows; - new_rows.resize(ids_num); - auto gpu_place = context_.GetPlace(); - - phi::MixVector mixv_new_rows(&new_rows); - if (!std::is_same::value) { - InputTypeConvert<<>>( - ids_data, ids_num, mixv_new_rows.MutableData(gpu_place)); - } else { - memory::Copy(gpu_place, - mixv_new_rows.CUDAMutableData(gpu_place), - gpu_place, - ids_data, - ids_num * sizeof(int64_t), - stream); - } - - mixv_new_rows.CopyToCPU(); - d_table->set_rows(new_rows); - - auto *d_table_value = d_table->mutable_value(); - d_table_value->Resize({ids_num, table->dims()[1]}); - d_table_value->template mutable_data(gpu_place); - - auto *d_table_data = d_table_value->template data(); - auto *d_output_data = d_output->template data(); - auto d_output_dims = d_output->dims(); - auto d_output_dims_2d = - common::flatten_to_2d(d_output_dims, d_output_dims.size() - 1); - PADDLE_ENFORCE_EQ(d_table_value->dims(), - d_output_dims_2d, - phi::errors::InvalidArgument( - "ShapeError: The shape of lookup_table@Grad and " - "output@Grad should be same. " - "But received lookup_table@Grad's shape = [%s], " - "output@Grad's shape = [%s].", - d_table_value->dims(), - d_output_dims_2d)); - memory::Copy(gpu_place, - d_table_data, - gpu_place, - d_output_data, - d_output->numel() * sizeof(T), - stream); - - } else { - auto d_output_t = - context_.Input(framework::GradVarName("Out")); - auto d_table_t = - context_.Output(framework::GradVarName("W")); - - int N = d_table_t->dims()[0]; - int D = d_table_t->dims()[1]; - int K = ids_t_->numel(); - - const T *d_output = d_output_t->template data(); - const auto *ids = ids_t_->template data(); - T *d_table = d_table_t->mutable_data(context_.GetPlace()); - -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - hipMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx.stream())); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx.stream())); -#endif - - const int gridx = 2 * dev_ctx.GetSMCount(); - dim3 threads(128, 8); - dim3 grids(gridx, 1); - LookupTableV2Grad<<>>( - d_table, d_output, ids, N, K, D); - } - } - - private: - const framework::ExecutionContext &context_; - const phi::DenseTensor *ids_t_; -}; - -template -class LookupTableV2GradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - const auto *ids_t = context.Input("Ids"); - LookupTableV2GradCUDAFunctor functor(context, ids_t); - framework::VisitIntDataType(framework::TransToProtoVarType(ids_t->dtype()), - functor); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/lookup_table_v2_op.h b/paddle/fluid/operators/lookup_table_v2_op.h deleted file mode 100644 index 8e3ce198e060bf..00000000000000 --- a/paddle/fluid/operators/lookup_table_v2_op.h +++ /dev/null @@ -1,285 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/selected_rows_utils.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" - -namespace paddle { -namespace operators { - -using SelectedRows = phi::SelectedRows; -using DDim = framework::DDim; - -constexpr int64_t kNoPadding = -1; - -template -static std::vector CopyIdsToVector(const phi::DenseTensor &ids) { - auto numel = ids.numel(); - const auto *src = ids.data(); - std::vector ret(numel); - if (std::is_same::value) { - std::memcpy(ret.data(), src, numel * sizeof(InT)); - } else { - for (decltype(numel) i = 0; i < numel; ++i) { - ret[i] = src[i]; - } - } - return ret; -} - -template -struct LookupTableV2CPUFunctor { - LookupTableV2CPUFunctor(const framework::ExecutionContext &context, - const phi::DenseTensor *ids_t) - : context_(context), ids_t_(ids_t) {} - - template - void apply() { - auto *output_t = context_.Output("Out"); // float tensor - auto *table_var = context_.InputVar("W"); - - int64_t padding_idx = context_.Attr("padding_idx"); - - auto ids = CopyIdsToVector(*ids_t_); - auto ids_numel = static_cast(ids.size()); - - if (table_var->template IsType()) { - const auto &table_t = table_var->template Get(); - int64_t row_number = table_t.dims()[0]; - int64_t row_width = table_t.dims()[1]; - - auto *table = table_t.template data(); - auto *output = output_t->template mutable_data(context_.GetPlace()); - - for (int64_t i = 0; i < ids_numel; ++i) { - if (padding_idx != kNoPadding && ids[i] == padding_idx) { - memset(output + i * row_width, 0, row_width * sizeof(T)); - } else { - PADDLE_ENFORCE_LT( - ids[i], - row_number, - phi::errors::InvalidArgument( - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - row_number, - ids[i])); - PADDLE_ENFORCE_GE( - ids[i], - 0, - phi::errors::InvalidArgument( - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - row_number, - ids[i])); - memcpy(output + i * row_width, - table + ids[i] * row_width, - row_width * sizeof(T)); - } - } - } else if (table_var->template IsType()) { - const auto &table_t = table_var->template Get(); - int64_t row_width = table_t.value().dims()[1]; - const auto *table = table_t.value().template data(); - auto *output = output_t->template mutable_data(context_.GetPlace()); - auto input_data_type = - framework::TransToProtoVarType(table_t.value().dtype()); - - for (int64_t i = 0; i < ids_numel; ++i) { - if (padding_idx != kNoPadding && ids[i] == padding_idx) { - memset(output + i * row_width, 0, row_width * sizeof(T)); - } else { - PADDLE_ENFORCE_GE( - ids[i], - 0, - phi::errors::InvalidArgument( - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0. But received %ld", - ids[i])); - auto id_index = table_t.Index(ids[i]); - PADDLE_ENFORCE_GE( - id_index, - 0, - phi::errors::InvalidArgument( - "the input key should be exists. 
But received %d.", - id_index)); - - if (input_data_type == framework::proto::VarType::BF16) { - memcpy(output + i * row_width, - table + id_index * row_width, - row_width * sizeof(T)); - } else { - auto &dev_ctx = context_.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - blas.VCOPY(row_width, - table + id_index * row_width, - output + i * row_width); - } - } - } - } - } - - private: - const framework::ExecutionContext &context_; - const phi::DenseTensor *ids_t_; -}; - -template -class LookupTableV2Kernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - const auto *ids = context.Input("Ids"); - LookupTableV2CPUFunctor functor(context, ids); - framework::VisitIntDataType(framework::TransToProtoVarType(ids->dtype()), - functor); - } -}; - -template -struct LookupTableV2GradCPUFunctor { - LookupTableV2GradCPUFunctor(const framework::ExecutionContext &context, - const phi::DenseTensor *ids_t) - : context_(context), ids_t_(ids_t) {} - - template - void apply() { - auto *table_var = context_.InputVar("W"); - DDim table_dim; - if (table_var->template IsType()) { - table_dim = context_.Input("W")->dims(); - } else if (table_var->template IsType()) { - auto *table_t = context_.Input("W"); - table_dim = table_t->value().dims(); - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "The parameter W of a LookupTableV2 " - "must be either phi::DenseTensor or SelectedRows")); - } - - int64_t padding_idx = context_.Attr("padding_idx"); - bool is_sparse = context_.Attr("is_sparse"); - - auto ids = CopyIdsToVector(*ids_t_); - auto ids_num = static_cast(ids.size()); - - // Since paddings are not trainable and fixed in forward, the gradient of - // paddings makes no sense and we don't deal with it in backward. - if (is_sparse) { - auto *d_output = - context_.Input(framework::GradVarName("Out")); - auto *d_table = - context_.Output(framework::GradVarName("W")); - - d_table->set_rows(ids); - - auto *d_table_value = d_table->mutable_value(); - d_table_value->Resize({ids_num, table_dim[1]}); - - d_table_value->template mutable_data(context_.GetPlace()); - - d_table->set_height(table_dim[0]); - - auto *d_output_data = d_output->template data(); - auto *d_table_data = d_table_value->template data(); - - auto d_output_dims = d_output->dims(); - auto d_output_dims_2d = - common::flatten_to_2d(d_output_dims, d_output_dims.size() - 1); - PADDLE_ENFORCE_EQ(d_table_value->dims(), - d_output_dims_2d, - phi::errors::InvalidArgument( - "ShapeError: The shape of lookup_table@Grad and " - "output@Grad should be same. " - "But received lookup_table@Grad's shape = [%s], " - "output@Grad's shape = [%s].", - d_table_value->dims(), - d_output_dims_2d)); - memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); - - } else { - auto *d_output = - context_.Input(framework::GradVarName("Out")); - auto *d_table = - context_.Output(framework::GradVarName("W")); - auto *ids_data = ids.data(); - - int64_t N = table_dim[0]; - int64_t D = table_dim[1]; - - auto *d_output_data = d_output->template data(); - auto *d_table_data = - d_table->template mutable_data(context_.GetPlace()); - - memset(d_table_data, 0, d_table->numel() * sizeof(T)); - - for (int64_t i = 0; i < ids_num; ++i) { - if (padding_idx != kNoPadding && ids_data[i] == padding_idx) { - // the gradient of padding_idx should be 0, already done by memset, so - // do nothing. 
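// A small worked example of the accumulation below (values are illustrative):
// with ids = {3, 1, 3} and D = 2, row 3 of d_table receives
// d_output[0][:] + d_output[2][:] while row 1 receives d_output[1][:];
// repeated ids therefore accumulate into the same row, which is why the
// inner loop uses += on d_table_data rather than a plain copy.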
- } else { - PADDLE_ENFORCE_LT( - ids_data[i], - N, - phi::errors::InvalidArgument( - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - N, - ids_data[i])); - PADDLE_ENFORCE_GE( - ids_data[i], - 0, - phi::errors::InvalidArgument( - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - N, - ids_data[i])); - for (int j = 0; j < D; ++j) { - d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j]; - } - } - } - } - } - - private: - const framework::ExecutionContext &context_; - const phi::DenseTensor *ids_t_; -}; - -template -class LookupTableV2GradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - const auto *ids = context.Input("Ids"); - LookupTableV2GradCPUFunctor functor(context, ids); - framework::VisitIntDataType(framework::TransToProtoVarType(ids->dtype()), - functor); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 0e0423bd64ff45..e7545b8fd4f2df 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -10,8 +10,6 @@ math_library(concat_and_split DEPS phi common) math_library(context_project DEPS phi common) math_library(cos_sim_functor) math_library(depthwise_conv) -math_library(sample_prob) -math_library(sampler DEPS phi common) if(WITH_XPU) math_library(beam_search DEPS phi common beam_search_xpu) diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h index 00ff1fbcbc38db..d809c71f437426 100644 --- a/paddle/fluid/operators/math/prelu.h +++ b/paddle/fluid/operators/math/prelu.h @@ -15,8 +15,8 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/math/sample_prob.cc b/paddle/fluid/operators/math/sample_prob.cc deleted file mode 100644 index 18321cf9b9ece6..00000000000000 --- a/paddle/fluid/operators/math/sample_prob.cc +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/sample_prob.h" - -namespace paddle { -namespace operators { -namespace math {} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sample_prob.cu b/paddle/fluid/operators/math/sample_prob.cu deleted file mode 100644 index 1d70b402104f58..00000000000000 --- a/paddle/fluid/operators/math/sample_prob.cu +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include - -#include -#include - -#include "paddle/common/ddim.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/sample_prob.h" -#include "paddle/fluid/operators/math/sampler.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { -namespace math { - -template -__device__ T gpu_adjust_prob(const T prob, - const int num_samples, - const int num_tries) { - if (num_samples == num_tries) { - return prob * num_samples; - } else { - return -expm1(num_tries * log1p(-prob)); - } -} - -class GPULogUniformSampler { - public: - __device__ int64_t Sample(float random, - const int range, - const float log_range) const; - __device__ float Probability(int64_t value, const float log_range) const; -}; - -__device__ int64_t GPULogUniformSampler::Sample(float random, - const int range, - const float log_range) const { - // Got Log Uniform distribution from uniform distribution by - // inverse_transform_sampling method - const int64_t value = static_cast(exp(random * log_range)) - 1; - // Mathematically, value should be <= range_, but might not be due to some - // floating point roundoff, so we mod by range_. 
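// A hedged numeric illustration of this inverse transform (example numbers,
// not taken from the patch): with range = 9999, log_range = log(10000) ~ 9.21,
// a uniform draw random = 0.5 gives value = exp(0.5 * 9.21) - 1 ~ 99, i.e. half
// of all draws land on the first ~100 ids, so small ids are sampled far more
// often -- exactly the log-uniform bias this sampler is meant to produce.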
- return value % range; -} - -__device__ float GPULogUniformSampler::Probability( - int64_t value, const float log_range) const { - // Given f(x) = 1/[(x+1) * log_range_] - // The value's probability is integral of f(x) from value to (value + 1) - return (log((value + 2.0) / (value + 1.0))) / log_range; -} - -template -__global__ void SamplingCondidate(const size_t n, - const int num_tries, - const int range, - const float log_range, - const int num_true, - const std::size_t num_samples, - const int64_t* label_data, - int64_t* samples_data, - T* probabilities_data) { - const int num_sampled_classes = num_true + num_samples; - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int step_size = 0; - GPULogUniformSampler sampler; - - for (; idx < n; idx += blockDim.x * gridDim.x) { - int col_idx = idx % num_sampled_classes; - int row_idx = idx / num_sampled_classes; - if (col_idx < num_true) { - samples_data[idx] = label_data[row_idx * num_true + col_idx]; - } else { - samples_data[idx] = samples_data[col_idx]; - } - probabilities_data[idx] = sampler.Probability(samples_data[idx], log_range); - probabilities_data[idx] = - gpu_adjust_prob(probabilities_data[idx], num_samples, num_tries); - } -} - -template -int UniqSampler(const Sampler& sampler, - const std::size_t num_samples, - int64_t* samples_data) { - // sample num_samles unique samples for an example, note that they are not - // all negative samples - std::unordered_set tmp_samples; - tmp_samples.clear(); - int num_tries = 0; - int j = 0; - while (j < num_samples) { - ++num_tries; - auto v = sampler.Sample(); - auto insert_ok = tmp_samples.insert(v).second; - if (!insert_ok) { - continue; - } - samples_data[j] = v; - ++j; - } - return num_tries; -} - -template -void GPUSampleWithProb::operator()(const phi::GPUContext& context, - const int seed, - const int dict_size, - const bool uniq, - const std::size_t num_samples, - const phi::DenseTensor* L, - phi::DenseTensor* S, - phi::DenseTensor* P) { - // UNDERSTAND: dimension issues - const auto lbl_dim = L->dims(); - const int batch_size = lbl_dim[0]; - const int num_true = lbl_dim[1]; - const int num_sampled_classes = num_true + num_samples; - framework::DDim ret_dim{batch_size, num_sampled_classes}; - - // UNDERSTAND: raw data view - const int64_t* label_data = L->data(); - int64_t* samples_data = S->data(); - T* probabilities_data = P->data(); - - int s_size = num_samples; - framework::DDim s_dim{s_size}; - phi::DenseTensor s; - int64_t* s_data = s.mutable_data(s_dim, platform::CPUPlace()); - - math::LogUniformSampler sampler(dict_size, seed); - - int range = dict_size; - float log_range = log(range + 1); - - int num_tries = UniqSampler(sampler, num_samples, s_data); - VLOG(1) << "num_tries: " << num_tries; - -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(samples_data + num_true, - s_data, - sizeof(int64_t) * num_samples, - hipMemcpyHostToDevice)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(samples_data + num_true, - s_data, - sizeof(int64_t) * num_samples, - cudaMemcpyHostToDevice)); -#endif - - int threads = 512; - const size_t size = batch_size * num_sampled_classes; - int grid = (batch_size * num_sampled_classes + threads - 1) / threads; -#ifdef PADDLE_WITH_HIP - hipLaunchKernelGGL(HIP_KERNEL_NAME(SamplingCondidate), - dim3(grid), - dim3(threads), - 0, - context.stream(), - size, - num_tries, - range, - log_range, - num_true, - num_samples, - label_data, - samples_data, - probabilities_data); -#else - SamplingCondidate - <<>>(size, - num_tries, - range, - 
log_range, - num_true, - num_samples, - label_data, - samples_data, - probabilities_data); -#endif -} - -template class GPUSampleWithProb; -template class GPUSampleWithProb; -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sample_prob.h b/paddle/fluid/operators/math/sample_prob.h deleted file mode 100644 index f30ada2f1f3c52..00000000000000 --- a/paddle/fluid/operators/math/sample_prob.h +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include - -#include "paddle/common/ddim.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/sampler.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" - -namespace paddle { -namespace operators { -namespace math { - -/* UNDERSTAND: utility function to adjust probability for unique sampling, -return whatever as it is if not using unique samping */ -template -static T adjust_prob(const T prob, const int num_samples, const int num_tries) { - if (num_samples == num_tries) { - return prob * num_samples; - } else { - return -expm1(num_tries * log1p(-prob)); - } -} - -template -class SampleWithProb { - public: - void operator()(const DeviceContext& context, - const Sampler& sampler, - const std::size_t num_samples, - const phi::DenseTensor* L, - phi::DenseTensor* S, - phi::DenseTensor* P) { - // UNDERSTAND: dimension issues - const auto& lbl_dim = L->dims(); - const int batch_size = lbl_dim[0]; - const int num_true = lbl_dim[1]; - const int num_sampled_classes = num_true + num_samples; - framework::DDim ret_dim{batch_size, num_sampled_classes}; - - // UNDERSTAND: raw data view - const int64_t* label_data = L->data(); - int64_t* samples_data = - S->mutable_data(ret_dim, context.GetPlace()); - T* probabilities_data = P->mutable_data(ret_dim, context.GetPlace()); - - // temp sets for unique sampling - std::unordered_set tmp_samples; - int j = 0; // column index - // add true labels, not that efficient - while (j < num_true) { - for (int i = 0; i < batch_size; ++i) { - auto samples_index = i * num_sampled_classes + j; - auto v = label_data[i * num_true + j]; - samples_data[samples_index] = v; - probabilities_data[samples_index] = sampler.Probability(v); - } - ++j; - } - - // sample num_samles unique samples for an example, note that they are not - // all negative samples - tmp_samples.clear(); - int num_tries = 0; - while (j < num_sampled_classes) { - ++num_tries; - auto v = sampler.Sample(); - auto insert_ok = tmp_samples.insert(v).second; - if (!insert_ok) { - continue; - } - auto p = sampler.Probability(v); - for (int i = 0; i < batch_size; ++i) { - auto samples_index = i * num_sampled_classes + j; - samples_data[samples_index] = v; - probabilities_data[samples_index] = p; - } - ++j; - } - - // compute Q(y|x), because of unique sampling, probabilities need to be - // adjusted - for (int k = 0; k < 
num_sampled_classes; ++k) { - for (int i = 0; i < batch_size; ++i) { - auto samples_index = i * num_sampled_classes + k; - probabilities_data[samples_index] = adjust_prob( - probabilities_data[samples_index], num_samples, num_tries); - } - } - } -}; - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -class GPUSampleWithProb { - public: - void operator()(const phi::GPUContext& context, - const int seed, - const int dict_size, - const bool uniq, - const std::size_t num_samples, - const phi::DenseTensor* L, - phi::DenseTensor* S, - phi::DenseTensor* P); -}; -#endif -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sampler.cc b/paddle/fluid/operators/math/sampler.cc deleted file mode 100644 index 0ea4336e92ec0b..00000000000000 --- a/paddle/fluid/operators/math/sampler.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/sampler.h" - -#include - -#include "paddle/phi/core/generator.h" - -namespace paddle { -namespace operators { -namespace math { - -Sampler::~Sampler() = default; - -UniformSampler::UniformSampler(int64_t range, unsigned int seed) - : Sampler(range, seed), inv_range_(1.0f / (range + 1)) { // NOLINT - random_engine_ = phi::GetCPURandomEngine(seed_); - dist_ = std::make_shared>(0, range); -} - -int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); } - -float UniformSampler::Probability(int64_t value) const { return inv_range_; } - -LogUniformSampler::LogUniformSampler(int64_t range, unsigned int seed) - : Sampler(range, seed), log_range_(log(range + 1)) { // NOLINT - random_engine_ = phi::GetCPURandomEngine(seed_); - dist_ = std::make_shared>(0, 1); -} - -int64_t LogUniformSampler::Sample() const { - // Got Log Uniform distribution from uniform distribution by - // inverse_transform_sampling method - // More details: - // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/ - auto cur_random = (*dist_)(*random_engine_); - const int64_t value = static_cast(exp(cur_random * log_range_)) - 1; - // Mathematically, value should be <= range_, but might not be due to some - // floating point roundoff, so we mod by range_. 
- return value % range_; -} - -float LogUniformSampler::Probability(int64_t value) const { - // Given f(x) = 1/[(x+1) * log_range_] - // The value's probability is integral of f(x) from value to (value + 1) - // More details: - // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler - return (log((value + 2.0) / (value + 1.0))) / log_range_; // NOLINT -} - -CustomSampler::CustomSampler(int64_t range, - const float *probabilities, - const int *alias, - const float *alias_probabilities, - unsigned int seed) - : Sampler(range, seed) { - random_engine_ = phi::GetCPURandomEngine(seed_); - real_dist_ = std::make_shared>(0, 1); - int_dist_ = std::make_shared>(0, range); - - alias_probs_ = alias_probabilities; - probs_ = probabilities; - alias_ = alias; -} - -int64_t CustomSampler::Sample() const { - auto index = (*int_dist_)(*random_engine_); - auto p = (*real_dist_)(*random_engine_); - if (p > alias_probs_[index]) { - int alias = alias_[index]; - - if (alias == exceptional_val) { - LOG(WARNING) << "WARNING: CustomSampler get alias " << exceptional_val; - return index; - } - - return alias; - } else { - return index; - } -} - -float CustomSampler::Probability(int64_t value) const { return probs_[value]; } - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sampler.h b/paddle/fluid/operators/math/sampler.h deleted file mode 100644 index e14e1ca572cab7..00000000000000 --- a/paddle/fluid/operators/math/sampler.h +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace math { - -// TODO(wanghaoshuang): Support for GPU - -/** - * Sample integers from [0, range). - */ -class Sampler { - public: - explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) { - PADDLE_ENFORCE_GT( - range, - 0, - phi::errors::InvalidArgument( - "Range should be greater than 0, but received %d.", range)); - if (seed == 0) { - std::random_device r; - seed_ = r(); - } else { - seed_ = seed; - } - } - - virtual ~Sampler(); - - // Sample a single value - virtual int64_t Sample() const = 0; - - // The probability that a single call to Sample() returns the given value. - virtual float Probability(int64_t value) const = 0; - - int64_t range() { return range_; } - - protected: - const int64_t range_; - unsigned int seed_; -}; - -/** - * Sample integers from [0, range). - * And the distribution function is: - * P(x) = 1 / range - */ -class UniformSampler : public Sampler { - public: - explicit UniformSampler(int64_t range, unsigned int seed = 0UL); - - ~UniformSampler() override {} - - int64_t Sample() const override; - - float Probability(int64_t value) const override; - - private: - const float inv_range_; - std::shared_ptr random_engine_; - std::shared_ptr> dist_; -}; - -/** - * Sample integers from [0, range). 
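For reference, the math implemented by the log-uniform sampler and adjust_prob above can be sketched in a few lines of standalone C++ (the function names and the main() driver below are illustrative only, not part of this patch): inverse-transform sampling from the log-uniform distribution, its per-value probability, and the expected-count correction applied when drawing until num_samples unique values are collected.

#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>

int64_t LogUniformSample(std::mt19937 &gen, int64_t range) {
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  const double u = dist(gen);
  // Inverse-transform sampling: floor(exp(u * log(range + 1))) - 1
  // follows a log-uniform law on [0, range).
  const int64_t value =
      static_cast<int64_t>(std::exp(u * std::log(range + 1.0))) - 1;
  // Guard against floating-point round-off, as the original code does.
  return value % range;
}

double LogUniformProbability(int64_t value, int64_t range) {
  // Integral of f(x) = 1 / ((x + 1) * log(range + 1)) over [value, value + 1).
  return std::log((value + 2.0) / (value + 1.0)) / std::log(range + 1.0);
}

double AdjustProb(double prob, int num_samples, int num_tries) {
  // Expected count of a class after num_tries draws made while collecting
  // num_samples unique samples (same formula as adjust_prob above).
  return num_samples == num_tries
             ? prob * num_samples
             : -std::expm1(num_tries * std::log1p(-prob));
}

int main() {
  std::mt19937 gen(42);
  const int64_t range = 1000;
  const int64_t v = LogUniformSample(gen, range);
  const double p = LogUniformProbability(v, range);
  std::cout << v << " " << p << " " << AdjustProb(p, 5, 7) << "\n";
  return 0;
}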
- * And the distribution function is: - * P(x) = (1/ln(range+1)) * ln(1 + 1/(x + 1)) - */ -class LogUniformSampler : public Sampler { - public: - explicit LogUniformSampler(int64_t range, unsigned int seed = 0UL); - - ~LogUniformSampler() override {} - - int64_t Sample() const override; - - float Probability(int64_t value) const override; - - private: - const float log_range_; - std::shared_ptr random_engine_; - std::shared_ptr> dist_; -}; - -/** - * Sample integers from [0, range) from custom distribution. - */ -class CustomSampler : public Sampler { - public: - explicit CustomSampler(int64_t range, - const float* probabilities, - const int* alias, - const float* alias_probabilities, - unsigned int seed = 0UL); - - ~CustomSampler() override {} - - int64_t Sample() const override; - - float Probability(int64_t value) const override; - - private: - const float* alias_probs_; - const int* alias_; - const float* probs_; - const int exceptional_val = -1; - std::shared_ptr random_engine_; - std::shared_ptr> real_dist_; - std::shared_ptr> int_dist_; -}; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/miopen_rnn_cache.h b/paddle/fluid/operators/miopen_rnn_cache.h index dd79f22e7cac77..31f185025e277a 100644 --- a/paddle/fluid/operators/miopen_rnn_cache.h +++ b/paddle/fluid/operators/miopen_rnn_cache.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/moe_op.cc b/paddle/fluid/operators/moe_op.cc deleted file mode 100644 index 186ac1fc434a7f..00000000000000 --- a/paddle/fluid/operators/moe_op.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
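CustomSampler above draws from an arbitrary distribution with Walker's alias method. A minimal standalone sketch, assuming the alias and alias_probs tables are precomputed by the caller exactly as CustomSampler expects them (the helper name and the table values here are hypothetical):

#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int64_t AliasSample(std::mt19937 &gen,
                    const std::vector<float> &alias_probs,
                    const std::vector<int> &alias) {
  std::uniform_int_distribution<int64_t> pick(
      0, static_cast<int64_t>(alias_probs.size()) - 1);
  std::uniform_real_distribution<float> coin(0.0f, 1.0f);
  const int64_t idx = pick(gen);
  // Keep the picked column with probability alias_probs[idx];
  // otherwise fall through to its alias.
  if (coin(gen) > alias_probs[idx]) {
    return alias[idx];
  }
  return idx;
}

int main() {
  std::mt19937 gen(0);
  // Hypothetical 3-class tables; a real caller builds them from the
  // target distribution before constructing the sampler.
  const std::vector<float> alias_probs = {1.0f, 0.6f, 0.4f};
  const std::vector<int> alias = {-1, 0, 0};
  std::cout << AliasSample(gen, alias_probs, alias) << "\n";
  return 0;
}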
*/ - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/binary.h" - -namespace paddle { -namespace operators { - -class MoeOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return phi::KernelKey(data_type, ctx.GetPlace()); - } -}; - -class MoeOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The source input tensor of Moe op."); - AddInput("Gate", "(Tensor), The gating input tensor of Moe op."); - AddInput("Bmm0", "(Tensor), The bmm0 input tensor of Moe op."); - AddInput("Bias0", "(Tensor), The eltwise0 input tensor of Moe op."); - AddInput("Bmm1", "(Tensor), The bmm1 input tensor of Moe op."); - AddInput("Bias1", "(Tensor), The eltwise1 input tensor of Moe op."); - AddOutput("Out", "(Tensor), The output tensor of Moe op."); - AddAttr( - "act_type", - R"DOC(activation type, currently only support `gelu`, `relu`. Default value is: `gelu`. )DOC") - .SetDefault("gelu"); - AddComment( - R"DOC(FusedEcMoe kernel. For more details you can refer to `FusedEcMoE` python documents. )DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(moe, - MoeInferShapeFunctor, - PD_INFER_META(phi::MoeInferMeta)); -REGISTER_OPERATOR(moe, ops::MoeOp, ops::MoeOpMaker, MoeInferShapeFunctor); diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 5ad76785276dad..19eb81cb3d2b73 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -24,15 +24,15 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows_utils.h" -#include "paddle/fluid/operators/math/sampler.h" #include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math/sampler.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace operators { using SelectedRows = phi::SelectedRows; -using Sampler = math::Sampler; +using Sampler = phi::math::Sampler; using DDim = framework::DDim; template { Sampler *sampler; switch (sampler_type) { case 0: { - sampler = new math::UniformSampler(num_total_classes - 1, seed); + sampler = new phi::math::UniformSampler(num_total_classes - 1, seed); break; } case 1: { - sampler = new math::LogUniformSampler(num_total_classes - 1, seed); + sampler = new phi::math::LogUniformSampler(num_total_classes - 1, seed); break; } case 2: { @@ -136,11 +136,11 @@ class NCEKernel : public framework::OpKernel { const float *probs_data = dist_probs->data(); const int *alias_data = dist_alias->data(); const float *alias_probs_data = dist_alias_probs->data(); - sampler = new math::CustomSampler(num_total_classes - 1, - probs_data, - alias_data, - alias_probs_data, - seed); + sampler = new phi::math::CustomSampler(num_total_classes - 1, + probs_data, + alias_data, + alias_probs_data, + seed); break; } default: { @@ -274,11 +274,11 @@ class NCEGradKernel : public framework::OpKernel { Sampler *sampler; switch (sampler_type) { case 0: { - sampler = new math::UniformSampler(num_total_classes - 1, seed); + sampler = new phi::math::UniformSampler(num_total_classes - 1, seed); break; } case 1: { - sampler = new math::LogUniformSampler(num_total_classes - 1, seed); + sampler = new phi::math::LogUniformSampler(num_total_classes - 1, seed); break; } case 2: { @@ -322,11 +322,11 @@ class NCEGradKernel : public framework::OpKernel { const float *probs_data = dist_probs->data(); const int *alias_data = dist_alias->data(); const float *alias_probs_data = dist_alias_probs->data(); - sampler = new math::CustomSampler(num_total_classes - 1, - probs_data, - alias_data, - alias_probs_data, - seed); + sampler = new phi::math::CustomSampler(num_total_classes - 1, + probs_data, + alias_data, + alias_probs_data, + seed); break; } default: { diff --git a/paddle/fluid/operators/ops_signature/cudnn_lstm_sig.cc b/paddle/fluid/operators/ops_signature/cudnn_lstm_sig.cc deleted file mode 100644 index 83e61b396ee537..00000000000000 --- a/paddle/fluid/operators/ops_signature/cudnn_lstm_sig.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature CudnnLSTMOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "cudnn_lstm", - {"Input", "InitH", "InitC", "W", "WeightList", "SequenceLength"}, - {"dropout_prob", - "is_bidirec", - "hidden_size", - "num_layers", - "is_test", - "seed"}, - {"Out", "LastH", "LastC", "Reserve", "StateOut"}); -} - -KernelSignature CudnnLSTMGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "cudnn_lstm_grad", - {"Input", - "InitH", - "InitC", - "WeightList", - "SequenceLength", - "Out", - "Reserve", - "StateOut", - "Out@GRAD", - "LastH@GRAD", - "LastC@GRAD"}, - {"dropout_prob", - "is_bidirec", - "hidden_size", - "num_layers", - "is_test", - "seed"}, - {"Input@GRAD", "InitH@GRAD", "InitC@GRAD", "WeightList@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(cudnn_lstm, phi::CudnnLSTMOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(cudnn_lstm_grad, - phi::CudnnLSTMGradOpArgumentMapping); diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 5fbbd49a885210..1d58154d36064b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" #endif namespace paddle { diff --git a/paddle/fluid/operators/tdm_child_op.cc b/paddle/fluid/operators/tdm_child_op.cc index 41bcae86c551bd..6e3804fcb0a923 100644 --- a/paddle/fluid/operators/tdm_child_op.cc +++ b/paddle/fluid/operators/tdm_child_op.cc @@ -17,7 +17,6 @@ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { diff --git a/paddle/fluid/operators/unbind_op.h b/paddle/fluid/operators/unbind_op.h index ea2c6d4ee2bb8c..dad3e2ed9001bd 100644 --- a/paddle/fluid/operators/unbind_op.h +++ b/paddle/fluid/operators/unbind_op.h @@ -20,8 +20,8 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace paddle { diff --git a/paddle/fluid/operators/unique_op.h b/paddle/fluid/operators/unique_op.h index 47bd4674c9a299..0bced76407b7e8 100644 --- a/paddle/fluid/operators/unique_op.h +++ b/paddle/fluid/operators/unique_op.h @@ -22,7 +22,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -304,7 +304,7 @@ static void UniqueDim(const framework::ExecutionContext& context, indices_vec.erase(indices_vec.begin() + input_unbind.size(), indices_vec.end()); - math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; phi::DenseTensor out_trans; std::vector out_trans_dims_vec = in_trans_dims_vec; out_trans_dims_vec[0] = input_unbind.size(); diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index 5ae7c3152e0fb2..2348eec77b7126 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -29,6 +29,7 @@ "elu", "embedding", "flatten", + "floor_divide", "full_like", "gelu", "hardswish", @@ -67,6 +68,7 @@ "elu", "embedding", "flatten", + "floor_divide", "full_like", "gelu", "hardswish", diff --git a/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py b/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py index ff2094a3df0093..c7fe5090b06900 100644 --- a/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py @@ -13,9 +13,9 @@ # limitations under the License. OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """ -bool {op_name}::InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis) {{ +bool {op_name}::InferSymbolicShape(pir::InferSymbolicShapeContext* infer_context) {{ VLOG(4) << "Infer symbolic shape for op: {op_name}"; - return {op_name}InferSymbolicShape(this->operation(), shape_analysis); + return {op_name}InferSymbolicShape(this->operation(), infer_context); }} """ diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 148f3e2ac5855a..a792a920328f18 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -174,7 +174,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ """ infer_symbolic_shape_template = """ - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); """ # ===================================== diff --git a/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py b/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py index 31078476b23e23..caa5a4387f63ea 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py @@ -281,9 +281,12 @@ def GenBuildOutputsPart2( }} """ + # In cudnn_lstm operator, the output weight_list_grad requires the use of optional input weight_list, + # so "pir::VectorType {name}" outside the "if" block. 
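The hoisting described in the comment above can be illustrated with ordinary C++ scoping (std:: types stand in for the pir types and the variable names are hypothetical): the declaration is placed before the if so the value stays visible afterwards even when the optional input is absent.

#include <iostream>
#include <optional>
#include <vector>

int main() {
  // Optional input: may be absent, as WeightList can be for cudnn_lstm.
  std::optional<std::vector<int>> weight_list_in;  // empty in this run
  // Declared before the if, mirroring the hoisted `pir::VectorType {name};`,
  // so later code can still reference it when the optional input is missing.
  std::vector<int> weight_list;
  if (weight_list_in.has_value()) {
    weight_list = *weight_list_in;
  }
  std::cout << "size = " << weight_list.size() << "\n";  // 0 when absent
  return 0;
}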
CREATE_OPTIONAL_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_ir_tensor_{name}; + pir::VectorType {name}; if ({name}_.impl() != nullptr) {{ - pir::VectorType {name} = {name}_.type().dyn_cast(); + {name} = {name}_.type().dyn_cast(); for (size_t i=0; i < static_cast({name}.size()); i++) {{ if({name}[i].isa()) {{ auto {name}_type = {name}[i].dyn_cast(); diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 2647b579f2bc73..4036443ea206fc 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -206,6 +206,11 @@ 'push_dense', 'limit_by_capacity', 'global_scatter', + 'global_gather', + 'pull_box_sparse', + 'pull_box_sparse_', + 'push_box_sparse', + 'push_box_sparse_', ] diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index ad92776f70f996..31a99593f375db 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -64,8 +64,8 @@ inline void UpdatePaddingAndDilation( } // namespace namespace paddle::dialect { -bool Conv2dOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool Conv2dOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { const std::vector strides = paddle::dialect::details::GetVectorAttr(op, "strides"); @@ -84,9 +84,9 @@ bool Conv2dOpInferSymbolicShape( .AsString(); const auto in_s_or_d = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto filter_s_or_d = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); @@ -144,22 +144,22 @@ bool Conv2dOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(out_s_or_d)}; }(); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } -bool Conv3dOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return Conv2dOpInferSymbolicShape(op, shape_analysis); +bool Conv3dOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + return Conv2dOpInferSymbolicShape(op, infer_context); } bool EmbeddingOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto weight_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); const std::vector &x_dims = [&] { std::vector dims; if (x_shape_or_data.data().has_value()) { @@ -189,20 +189,20 @@ bool EmbeddingOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(out_dims)}; }(); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool 
SparseWeightEmbeddingOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { PADDLE_THROW(phi::errors::Unimplemented( op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } bool ExpandAsOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { std::vector target_shape = paddle::dialect::details::GetVectorAttr(op, "target_shape"); const std::vector &output_dims = [&] { @@ -214,18 +214,18 @@ bool ExpandAsOpInferSymbolicShape( return output_dims; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs(output_dims)); return true; } -bool GatherOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool GatherOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { const auto &input_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &index_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); const auto &numel = [&] { symbol::DimExpr numel{1}; @@ -247,7 +247,7 @@ bool GatherOpInferSymbolicShape( "in GatherOpInferSymbolicShape: The number of operands should be " "3 when the axis is not set.")); const auto &axis_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); + infer_context->GetShapeOrDataForValue(op->operand_source(2)); axis = static_cast(axis_shape_or_data.data().value()[0].Get()); } @@ -294,17 +294,17 @@ bool GatherOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue(res, shape_data); return true; } bool GatherNdOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &index_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); const std::vector &x_sym_shape = x_shape_or_data.shape(); const std::vector &index_sym_shape = @@ -337,17 +337,17 @@ bool GatherNdOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(result_sym_dims)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue(res, shape_data); return true; } bool KronOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); const auto &y_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)).shape(); + infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); const int rank_x = x_shape_or_data.size(); const int rank_y = 
y_shape_or_data.size(); const int rank = (rank_x > rank_y) ? rank_x : rank_y; @@ -366,32 +366,32 @@ bool KronOpInferSymbolicShape(pir::Operation *op, symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(dim_out)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue(res, shape_data); return true; } bool MaskedSelectOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const std::vector &out_dims = [&] { std::vector out_dims; symbol::DimExpr out_shape = - shape_analysis->GetNextSymName(); // unknown until runtime + infer_context->GetNextSymName(); // unknown until runtime out_dims.push_back(out_shape); return out_dims; }(); // TODO(fty1777): Add constrains between the shapes of x and mask - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs{out_dims}); return true; } -bool MatmulOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool MatmulOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { // x_dims can't be const or ref here, in case to be broadcasted std::vector x_dims = [&] { std::vector dims; const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); if (x_shape_or_data.data().has_value()) { dims = x_shape_or_data.data().value(); } else { @@ -404,7 +404,7 @@ bool MatmulOpInferSymbolicShape( std::vector y_dims = [&] { std::vector dims; const auto y_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); if (y_shape_or_data.data().has_value()) { dims = y_shape_or_data.data().value(); } else { @@ -445,7 +445,7 @@ bool MatmulOpInferSymbolicShape( symbol::DimExprBuilder builder; for (size_t i = 0; i < ndims_x - 2; ++i) { out_dims.emplace_back(builder.Broadcast(x_dims[i], y_dims[i])); - shape_analysis->AddBroadcastableCstr(x_dims[i], y_dims[i]); + infer_context->AddBroadcastableCstr(x_dims[i], y_dims[i]); } } @@ -462,54 +462,54 @@ bool MatmulOpInferSymbolicShape( out_dims.emplace_back(out_N); } - shape_analysis->SetShapeOrDataForValue(op->result(0), - ShapeOrData{TensorExprs(out_dims)}); + infer_context->SetShapeOrDataForValue(op->result(0), + ShapeOrData{TensorExprs(out_dims)}); if ((ndims_x == ndims_y) && ndims_x >= 2) { if (transpose_x_attr == false && transpose_y_attr == false) { - shape_analysis->AddEqualCstr(x_dims[ndims_x - 1], y_dims[ndims_x - 2]); + infer_context->AddEqualCstr(x_dims[ndims_x - 1], y_dims[ndims_x - 2]); } else if (transpose_x_attr == false && transpose_y_attr == true) { - shape_analysis->AddEqualCstr(x_dims[ndims_x - 1], y_dims[ndims_x - 1]); + infer_context->AddEqualCstr(x_dims[ndims_x - 1], y_dims[ndims_x - 1]); } else if (transpose_x_attr == true && transpose_y_attr == false) { - shape_analysis->AddEqualCstr(x_dims[ndims_x - 2], y_dims[ndims_x - 2]); + infer_context->AddEqualCstr(x_dims[ndims_x - 2], y_dims[ndims_x - 2]); } else { - shape_analysis->AddEqualCstr(x_dims[ndims_x - 2], y_dims[ndims_x - 1]); + infer_context->AddEqualCstr(x_dims[ndims_x - 2], y_dims[ndims_x - 1]); } for (size_t i = 0; i < ndims_x - 2; ++i) { - shape_analysis->AddEqualCstr(x_dims[i], y_dims[i]); + infer_context->AddEqualCstr(x_dims[i], 
y_dims[i]); } } return true; } bool SearchsortedOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { // The shape of output is the same as input `values` (op->operand_source(1)) const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); // TODO(fty1777): Add constrains between the shapes of `sorted_sequence` and // `values` - shape_analysis->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); + infer_context->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); return true; } bool IscloseOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { // The shape of output is the same as input `values` (op->operand_source(1)) const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - shape_analysis->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); return true; } bool TakeAlongAxisOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { // input const auto &arr_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &indices_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); const auto &attributes = op->attributes(); int axis = attributes.at("axis").dyn_cast().data(); @@ -539,16 +539,16 @@ bool TakeAlongAxisOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue(res, shape_data); return true; } bool TopPSamplingOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto &x_dims = [op, shape_analysis] { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_dims = [op, infer_context] { const auto &shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); if (shape_or_data.data().has_value()) { return shape_or_data.data().value(); } else { @@ -559,7 +559,7 @@ bool TopPSamplingOpInferSymbolicShape( // all the result have the same shape for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { const std::vector out_dims{x_dims[0], 1}; - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(rst_idx), symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(out_dims)}); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc index 5c7f21a5984f39..0aec58d385311b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc +++ 
b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc @@ -19,7 +19,7 @@ namespace cinn::dialect { bool BroadcastOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const std::vector &shape = paddle::dialect::details::GetVectorAttr(op, "out_shape"); @@ -33,21 +33,21 @@ bool BroadcastOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } -bool ConcatOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool ConcatOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { const auto input_values = op->operands_source(); const auto input_size = input_values.size(); - if (shape_analysis->GetShapeOrDataForValue(input_values[0]) + if (infer_context->GetShapeOrDataForValue(input_values[0]) .data() .has_value()) { std::vector out_data; for (const auto &value : input_values) { - const auto &shape_or_data = shape_analysis->GetShapeOrDataForValue(value); + const auto &shape_or_data = infer_context->GetShapeOrDataForValue(value); for (size_t i = 0; i < shape_or_data.data().value().size(); ++i) { out_data.emplace_back(shape_or_data.data().value()[i]); } @@ -57,7 +57,7 @@ bool ConcatOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(shape, out_data)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue(res, shape_data); return true; } @@ -65,21 +65,21 @@ bool ConcatOpInferSymbolicShape( const auto &GetOutDimExprs = [&]() -> std::vector { std::vector out_dims = - shape_analysis->GetShapeOrDataForValue(input_values[0]).shape(); + infer_context->GetShapeOrDataForValue(input_values[0]).shape(); size_t rank = out_dims.size(); axis = axis >= 0 ? axis : std::max(int64_t(0), int64_t(axis + rank)); for (size_t i = 1; i < input_size; ++i) { const auto &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(input_values[i]); + infer_context->GetShapeOrDataForValue(input_values[i]); out_dims[axis] = out_dims[axis] + operand_shape_or_data.shape()[axis]; } for (size_t i = 0; i < rank; ++i) { if (i == static_cast(axis)) continue; paddle::dialect::details::BuildCstrEqForTensorListAlongAxis( - shape_analysis, input_values, i); + infer_context, input_values, i); } return out_dims; @@ -88,41 +88,41 @@ bool ConcatOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(GetOutDimExprs())}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool ReduceInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { bool keep_dim = GetBoolAttr(op, "keep_dim"); auto axis = paddle::dialect::details::GetVectorAttr(op, "dim"); bool reduce_all = axis.size() == 0 ? 
true : false; return paddle::dialect::details::ReduceInferDim( - op, shape_analysis, axis, keep_dim, reduce_all); + op, infer_context, axis, keep_dim, reduce_all); } bool ReduceMaxOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return ReduceInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return ReduceInferSymbolicShape(op, infer_context); } bool ReduceMinOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return ReduceInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return ReduceInferSymbolicShape(op, infer_context); } bool ReduceProdOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return ReduceInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return ReduceInferSymbolicShape(op, infer_context); } bool ReduceSumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return ReduceInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return ReduceInferSymbolicShape(op, infer_context); } bool ReshapeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { std::vector shape = paddle::dialect::details::GetVectorAttr(op, "shape"); @@ -159,7 +159,7 @@ bool ReshapeOpInferSymbolicShape( }(); const symbol::ShapeOrDataDimExprs &x_dim_expr = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &original_shape = x_dim_expr.shape(); @@ -191,13 +191,13 @@ bool ReshapeOpInferSymbolicShape( return symbol::TensorShapeOrDataDimExprs(out_dims); }(); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool SliceOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const std::vector starts_raw = paddle::dialect::details::GetVectorAttr(op, "starts"); const std::vector ends_raw = @@ -212,10 +212,10 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, const ExprVec starts = paddle::dialect::details::VecInt642Expr(starts_raw); const ExprVec ends = paddle::dialect::details::VecInt642Expr(ends_raw); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), paddle::dialect::slice_utils::SliceRawInferSymbolicShape( - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)), + infer_context->GetShapeOrDataForValue(op->operand_source(0)), starts, ends, axes_raw, @@ -225,4 +225,67 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, return true; } +bool GatherOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + const auto &input_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &index_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + + const auto &numel = [&] { + symbol::DimExpr numel{1}; + for (const auto &dim_expr : index_shape_or_data.shape()) { + numel = numel * dim_expr; + } + return numel; + }(); + + const std::vector &input_sym_shape = + 
input_shape_or_data.data().has_value() + ? input_shape_or_data.data().value() + : input_shape_or_data.shape(); + + const std::vector &index_sym_shape = + index_shape_or_data.data().has_value() + ? index_shape_or_data.data().value() + : index_shape_or_data.shape(); + + int axis = op->attributes().at("axis").dyn_cast().data(); + if (axis < 0) axis += input_sym_shape.size(); + + const auto &out_sym_shape = [&] { + std::vector out_sym_shape; + + if (index_sym_shape.size() == 0) { + if (input_sym_shape.size() == 1) { + out_sym_shape.push_back(symbol::DimExpr{0}); + } else { + for (int i = 0; i < axis; ++i) { + out_sym_shape.push_back(input_sym_shape[i]); + } + for (size_t i = axis + 1; i < input_sym_shape.size(); ++i) { + out_sym_shape.push_back(input_sym_shape[i]); + } + } + } else { + for (int i = 0; i < axis; ++i) { + out_sym_shape.push_back(input_sym_shape[i]); + } + out_sym_shape.push_back(numel); + for (size_t i = axis + 1; i < input_sym_shape.size(); ++i) { + out_sym_shape.push_back(input_sym_shape[i]); + } + } + return out_sym_shape; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; + + pir::Value res = op->result(0); + infer_context->SetShapeOrDataForValue(res, shape_data); + + return true; +} + } // namespace cinn::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h index b3cc2232a1f91c..30c4b4ffcf9691 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h @@ -24,4 +24,5 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceProd) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceSum) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Slice) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gather) } // namespace cinn::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.cc index e220d06f990204..d0cf6c4db4cdd1 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.cc @@ -17,16 +17,16 @@ bool InferSymbolicShapeElementWiseBinary( pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis, + pir::InferSymbolicShapeContext *infer_context, const std::function &DataComputeFunc = nullptr) { const auto &x_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); std::vector shape_0 = x_shape.shape(); const auto &y_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); std::vector shape_1 = y_shape.shape(); int diff = shape_0.size() - shape_1.size(); @@ -52,7 +52,7 @@ bool InferSymbolicShapeElementWiseBinary( shapes.emplace_back(shape_0[i]); } else { shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i])); - shape_analysis->AddBroadcastableCstr(shape_0[i], shape_1[i]); + infer_context->AddBroadcastableCstr(shape_0[i], shape_1[i]); } } return shapes; @@ -85,52 +85,52 @@ bool InferSymbolicShapeElementWiseBinary( } symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(shapes, out_data)}; - 
shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); } else { symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(shapes)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); } return true; } -#define OP_ELEMENT_WISE_BINARY(name) \ - bool name##OpInferSymbolicShape( \ - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { \ - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); \ +#define OP_ELEMENT_WISE_BINARY(name) \ + bool name##OpInferSymbolicShape( \ + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { \ + return InferSymbolicShapeElementWiseBinary(op, infer_context); \ } namespace paddle::dialect { bool AddOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { return InferSymbolicShapeElementWiseBinary( op, - shape_analysis, + infer_context, [](const symbol::DimExpr &x, const symbol::DimExpr &y) { return x + y; }); } -bool DivideOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool DivideOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { return InferSymbolicShapeElementWiseBinary( op, - shape_analysis, + infer_context, [](const symbol::DimExpr &x, const symbol::DimExpr &y) { return x / y; }); } bool MultiplyOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { return InferSymbolicShapeElementWiseBinary( op, - shape_analysis, + infer_context, [](const symbol::DimExpr &x, const symbol::DimExpr &y) { return x * y; }); } bool SubtractOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { return InferSymbolicShapeElementWiseBinary( op, - shape_analysis, + infer_context, [](const symbol::DimExpr &x, const symbol::DimExpr &y) { return x - y; }); } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.cc index 1026005ab7fc85..ef92d86111c674 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.cc @@ -38,7 +38,7 @@ ExprVec VecInt642Expr(const std::vector &int_vec) { } bool ReduceInferDim(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis, + pir::InferSymbolicShapeContext *infer_context, const std::vector &axis, bool keep_dim, bool reduce_all) { @@ -67,7 +67,7 @@ bool ReduceInferDim(pir::Operation *op, reduce_all = reduce_all || full_dim || empty_dim; const symbol::ShapeOrDataDimExprs &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(x); + infer_context->GetShapeOrDataForValue(x); std::vector input_shapes; if (x_shape_or_data.data() == std::nullopt || x_shape_or_data.data()->size() == 0) { @@ -95,28 +95,28 @@ bool ReduceInferDim(pir::Operation *op, symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(shapes)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } void 
BuildCstrEqForTensorListAlongAxis( - pir::ShapeConstraintIRAnalysis *shape_analysis, + pir::InferSymbolicShapeContext *infer_context, const symbol::TensorListShapeOrDataDimExprs &shape_data_list, int axis) { for (size_t i = 1; i < shape_data_list.size(); ++i) { - shape_analysis->AddEqualCstr(shape_data_list[0].shape()[axis], - shape_data_list[i].shape()[axis]); + infer_context->AddEqualCstr(shape_data_list[0].shape()[axis], + shape_data_list[i].shape()[axis]); } } void BuildCstrEqForTensorListAlongAxis( - pir::ShapeConstraintIRAnalysis *shape_analysis, + pir::InferSymbolicShapeContext *infer_context, const std::vector &values, int axis) { for (size_t i = 1; i < values.size(); ++i) { - shape_analysis->AddEqualCstr( - shape_analysis->GetShapeOrDataForValue(values[0]).shape()[axis], - shape_analysis->GetShapeOrDataForValue(values[i]).shape()[axis]); + infer_context->AddEqualCstr( + infer_context->GetShapeOrDataForValue(values[0]).shape()[axis], + infer_context->GetShapeOrDataForValue(values[i]).shape()[axis]); } } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h index 42164c3c212549..c6e348140981f5 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h @@ -123,18 +123,18 @@ std::optional> VecExpr2Int64(const ExprVec &expr_vec); ExprVec VecInt642Expr(const std::vector &int_vec); bool ReduceInferDim(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis, + pir::InferSymbolicShapeContext *infer_context, const std::vector &axis, bool keep_dim, bool reduce_all); void BuildCstrEqForTensorListAlongAxis( - pir::ShapeConstraintIRAnalysis *shape_analysis, + pir::InferSymbolicShapeContext *infer_context, const symbol::TensorListShapeOrDataDimExprs &shape_data_list, int axis); void BuildCstrEqForTensorListAlongAxis( - pir::ShapeConstraintIRAnalysis *shape_analysis, + pir::InferSymbolicShapeContext *infer_context, const std::vector &values, int axis); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 43c758e942527c..4b5f146b66684a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -22,9 +22,9 @@ namespace paddle::dialect { bool BicubicInterpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const symbol::ShapeOrDataDimExprs &x = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &attributes = op->attributes(); @@ -56,12 +56,12 @@ bool BicubicInterpOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(dim_out)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue(res, shape_data); return true; } symbol::DimExpr out_w_tmp{0}; - const auto &next_sym = shape_analysis->GetNextSymName(); + const auto &next_sym = infer_context->GetNextSymName(); out_w_tmp = symbol::DimExpr(next_sym); std::vector dim_out; @@ -75,7 +75,7 @@ bool BicubicInterpOpInferSymbolicShape( 
       symbol::TensorShapeOrDataDimExprs(dim_out)};
     pir::Value res = op->result(0);
-    shape_analysis->SetShapeOrDataForValue(res, shape_data);
+    infer_context->SetShapeOrDataForValue(res, shape_data);
     return true;
   } else if (x.shape().size() == 4) {
     // shape check for 2D interpolate for input tensor shape NCHW
@@ -97,13 +97,13 @@ bool BicubicInterpOpInferSymbolicShape(
       symbol::ShapeOrDataDimExprs shape_data{
           symbol::TensorShapeOrDataDimExprs(dim_out)};
       pir::Value res = op->result(0);
-      shape_analysis->SetShapeOrDataForValue(res, shape_data);
+      infer_context->SetShapeOrDataForValue(res, shape_data);
       return true;
     }
     symbol::DimExpr out_h_tmp{0};
     symbol::DimExpr out_w_tmp{0};
-    const auto &next_sym = shape_analysis->GetNextSymName();
+    const auto &next_sym = infer_context->GetNextSymName();
     out_h_tmp = symbol::DimExpr(next_sym);
     out_w_tmp = symbol::DimExpr(next_sym);

@@ -118,7 +118,7 @@ bool BicubicInterpOpInferSymbolicShape(
         symbol::TensorShapeOrDataDimExprs(dim_out)};
     pir::Value res = op->result(0);
-    shape_analysis->SetShapeOrDataForValue(res, shape_data);
+    infer_context->SetShapeOrDataForValue(res, shape_data);
     return true;
   } else if (x.shape().size() == 5) {
     // shape check for 3D interpolate for input tensor shape NCDHW
@@ -143,14 +143,14 @@ bool BicubicInterpOpInferSymbolicShape(
           symbol::TensorShapeOrDataDimExprs(dim_out)};
       pir::Value res = op->result(0);
-      shape_analysis->SetShapeOrDataForValue(res, shape_data);
+      infer_context->SetShapeOrDataForValue(res, shape_data);
       return true;
     }
     symbol::DimExpr out_d_tmp{0};
     symbol::DimExpr out_h_tmp{0};
     symbol::DimExpr out_w_tmp{0};
-    const auto &next_sym = shape_analysis->GetNextSymName();
+    const auto &next_sym = infer_context->GetNextSymName();
     out_d_tmp = symbol::DimExpr(next_sym);
     out_h_tmp = symbol::DimExpr(next_sym);
     out_w_tmp = symbol::DimExpr(next_sym);
@@ -167,7 +167,7 @@ bool BicubicInterpOpInferSymbolicShape(
         symbol::TensorShapeOrDataDimExprs(dim_out)};
     pir::Value res = op->result(0);
-    shape_analysis->SetShapeOrDataForValue(res, shape_data);
+    infer_context->SetShapeOrDataForValue(res, shape_data);
     return true;
   } else {
@@ -178,40 +178,77 @@ bool BicubicInterpOpInferSymbolicShape(
 }

 bool BilinearInterpOpInferSymbolicShape(
-    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
-  return BicubicInterpOpInferSymbolicShape(op, shape_analysis);
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  return BicubicInterpOpInferSymbolicShape(op, infer_context);
 }

-bool ConcatOpInferSymbolicShape(
-    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
+bool ConcatOpInferSymbolicShape(pir::Operation *op,
+                                pir::InferSymbolicShapeContext *infer_context) {
   pir::Value operand_source = op->operand_source(0);
   const auto &shape_data_list =
-      shape_analysis->GetShapeOrDataForValue(operand_source)
+      infer_context->GetShapeOrDataForValue(operand_source)
           .dyn_cast<symbol::TensorListShapeOrDataDimExprs>();

-  CHECK(op->operand_source(1).defining_op()->isa<paddle::dialect::FullOp>());
-
-  int64_t axis = op->operand_source(1)
-                     .defining_op()
-                     .attributes()
-                     .at("value")
-                     .dyn_cast<paddle::dialect::ScalarAttribute>()
-                     .data()
-                     .to<int64_t>();
   size_t rank = shape_data_list[0].shape().size();
+
+  int64_t axis = 0;
+
+  auto SetShapeOrDataForAxis = [&](int axis_value) {
+    std::vector<symbol::DimExpr> data{axis_value};
+    symbol::TensorShapeOrDataDimExprs shape_or_data(
+        std::vector<symbol::DimExpr>{}, data);
+    infer_context->SetShapeOrDataForValue(op->operand_source(1), shape_or_data);
+  };
+
+  if (infer_context->HasShapeOrDataForValue(op->operand_source(1)) &&
+      (infer_context->GetShapeOrDataForValue(op->operand_source(1)))
+          .data()
+          .has_value()) {
+    const auto &axis_shape_or_data =
+        infer_context->GetShapeOrDataForValue(op->operand_source(1));
+    axis =
+        static_cast<int64_t>(axis_shape_or_data.data().value()[0].Get<int64_t>());
+  } else {
+    if (op->operand_source(1).defining_op() &&
+        op->operand_source(1).defining_op()->isa<paddle::dialect::FullOp>()) {
+      axis = op->operand_source(1)
+                 .defining_op()
+                 .attributes()
+                 .at("value")
+                 .dyn_cast<paddle::dialect::ScalarAttribute>()
+                 .data()
+                 .to<int64_t>();
+      SetShapeOrDataForAxis(axis);
+    } else {
+      pir::Value res = op->result(0);
+      infer_context->SetStaticShapeForValue(res);
+      // update axis value
+      auto res_shape = infer_context->GetShapeOrDataForValue(res);
+      for (size_t i = 0; i < rank; ++i) {
+        auto res_shape_dim = res_shape.shape()[i];
+        auto shape_data_dim = shape_data_list[0].shape()[i];
+        if (!res_shape_dim.isa<int64_t>()) break;
+        if (!shape_data_dim.isa<int64_t>()) break;
+        if (res_shape_dim.Get<int64_t>() > shape_data_dim.Get<int64_t>()) {
+          SetShapeOrDataForAxis(i);
+        }
+      }
+      return true;
+    }
+  }
   axis = axis >= 0 ? axis : std::max(int64_t(0), int64_t(axis + rank));

   if (shape_data_list[0].data().has_value()) {
     if (rank == 1) {
       const auto &s_or_d =
-          shape_analysis->GetShapeOrDataForValue(operand_source);
+          infer_context->GetShapeOrDataForValue(operand_source);
       ExprVec data = details::GetExprVecFromData(s_or_d);

       const std::vector<symbol::DimExpr> shape{std::int64_t(data.size())};
       symbol::ShapeOrDataDimExprs shape_data{
           symbol::TensorShapeOrDataDimExprs(shape, data)};
       pir::Value res = op->result(0);
-      shape_analysis->SetShapeOrDataForValue(res, shape_data);
+      infer_context->SetShapeOrDataForValue(res, shape_data);

       return true;
     } else {
@@ -228,7 +265,7 @@ bool ConcatOpInferSymbolicShape(
       symbol::ShapeOrDataDimExprs shape_data{
           symbol::TensorShapeOrDataDimExprs(shape, data)};
       pir::Value res = op->result(0);
-      shape_analysis->SetShapeOrDataForValue(res, shape_data);
+      infer_context->SetShapeOrDataForValue(res, shape_data);

       return true;
     }
@@ -238,7 +275,7 @@ bool ConcatOpInferSymbolicShape(
   for (size_t i = 0; i < rank; ++i) {
     if (i != static_cast<size_t>(axis)) {
       details::BuildCstrEqForTensorListAlongAxis(
-          shape_analysis, shape_data_list, i);
+          infer_context, shape_data_list, i);
       continue;
     }
     for (size_t j = 1; j < shape_data_list.size(); ++j) {
@@ -252,37 +289,37 @@ bool ConcatOpInferSymbolicShape(
       symbol::TensorShapeOrDataDimExprs(out_dims)};

   pir::Value res = op->result(0);
-  shape_analysis->SetShapeOrDataForValue(res, shape_data);
+  infer_context->SetShapeOrDataForValue(res, shape_data);

   return true;
 }

 bool FullWithTensorOpInferSymbolicShape(
-    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
   pir::Value operand_source = op->operand_source(1);
   const symbol::ShapeOrDataDimExprs &operand_shape_or_data =
-      shape_analysis->GetShapeOrDataForValue(operand_source);
+      infer_context->GetShapeOrDataForValue(operand_source);

   const auto &out_shape = operand_shape_or_data.data().has_value() ?
operand_shape_or_data.data().value() : operand_shape_or_data.shape(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs(out_shape)); return true; } bool FlashAttnOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &q = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); const symbol::ShapeOrDataDimExprs &k = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); const symbol::ShapeOrDataDimExprs &v = - shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); + infer_context->GetShapeOrDataForValue(op->operand_source(2)); PADDLE_ENFORCE_EQ(q.shape().size(), 4, @@ -290,23 +327,23 @@ bool FlashAttnOpInferSymbolicShape( "flash_attn receive input with dim " "[batch_size, seq_len, num_heads, head_dim]")); - shape_analysis->AddEqualCstr(q.shape()[0], k.shape()[0]); - shape_analysis->AddEqualCstr(q.shape()[0], v.shape()[0]); - shape_analysis->AddEqualCstr(k.shape()[1], v.shape()[1]); + infer_context->AddEqualCstr(q.shape()[0], k.shape()[0]); + infer_context->AddEqualCstr(q.shape()[0], v.shape()[0]); + infer_context->AddEqualCstr(k.shape()[1], v.shape()[1]); if (op->operand_source(4)) { const symbol::ShapeOrDataDimExprs &attn_mask = - shape_analysis->GetShapeOrDataForValue(op->operand_source(4)); - shape_analysis->AddEqualCstr(attn_mask.shape()[0], q.shape()[0]); - shape_analysis->AddEqualCstr(attn_mask.shape()[2], q.shape()[1]); - shape_analysis->AddEqualCstr(attn_mask.shape()[3], k.shape()[1]); + infer_context->GetShapeOrDataForValue(op->operand_source(4)); + infer_context->AddEqualCstr(attn_mask.shape()[0], q.shape()[0]); + infer_context->AddEqualCstr(attn_mask.shape()[2], q.shape()[1]); + infer_context->AddEqualCstr(attn_mask.shape()[3], k.shape()[1]); } std::vector out_shape = q.shape(); out_shape.back() = v.shape().back(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs(out_shape)); // GPU has round for seqlen, but XPU has not. 
Here we align with the GPU @@ -325,47 +362,47 @@ bool FlashAttnOpInferSymbolicShape( num_heads_expr, seqlen_q_rounded_expr, seqlen_k_rounded_expr}; - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(1), symbol::TensorShapeOrDataDimExprs(softmax_shape)); } if (op->result(2)) { std::vector softmax_lse_shape{ batch_size_expr, num_heads_expr, seqlen_q_rounded_expr}; - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(2), symbol::TensorShapeOrDataDimExprs(softmax_lse_shape)); } if (op->result(3)) { std::vector seed_offset_shape{symbol::DimExpr{2}}; - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(3), symbol::TensorShapeOrDataDimExprs(out_shape)); } return true; } bool GroupNormOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const symbol::ShapeOrDataDimExprs &x_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); - shape_analysis->SetShapeOrDataForValue(op->result(0), x_shape); + infer_context->SetShapeOrDataForValue(op->result(0), x_shape); const symbol::DimExpr &batch_size = x_shape.shape()[0]; int groups = op->attribute("groups").data(); symbol::TensorShapeOrDataDimExprs mean_shape( std::vector{batch_size, groups}); if (op->result(1)) { - shape_analysis->SetShapeOrDataForValue(op->result(1), mean_shape); + infer_context->SetShapeOrDataForValue(op->result(1), mean_shape); } if (op->result(2)) { - shape_analysis->SetShapeOrDataForValue(op->result(2), mean_shape); + infer_context->SetShapeOrDataForValue(op->result(2), mean_shape); } return true; } bool LinspaceOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &num_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); + infer_context->GetShapeOrDataForValue(op->operand_source(2)); const auto step = [&] { symbol::DimExpr expr; if (num_shape_or_data.data().has_value()) { @@ -380,33 +417,33 @@ bool LinspaceOpInferSymbolicShape( return symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(out_dims)}; }(); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool LinearInterpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return BicubicInterpOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return BicubicInterpOpInferSymbolicShape(op, infer_context); } bool LogspaceOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return LinspaceOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return LinspaceOpInferSymbolicShape(op, infer_context); } bool NearestInterpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return BicubicInterpOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return BicubicInterpOpInferSymbolicShape(op, infer_context); } bool MemoryEfficientAttentionOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis 
*shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &q_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); const auto &k_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)).shape(); + infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); const auto &v_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(2)).shape(); + infer_context->GetShapeOrDataForValue(op->operand_source(2)).shape(); PADDLE_ENFORCE_EQ( q_shape.size(), 4, @@ -441,15 +478,15 @@ bool MemoryEfficientAttentionOpInferSymbolicShape( const auto &value_num_head = v_shape[2]; const auto &value_head_size = v_shape[3]; - shape_analysis->AddEqualCstr(query_batch_size, key_batch_size); - shape_analysis->AddEqualCstr(key_batch_size, value_batch_size); + infer_context->AddEqualCstr(query_batch_size, key_batch_size); + infer_context->AddEqualCstr(key_batch_size, value_batch_size); - shape_analysis->AddEqualCstr(query_num_head, key_num_head); - shape_analysis->AddEqualCstr(key_num_head, value_num_head); + infer_context->AddEqualCstr(query_num_head, key_num_head); + infer_context->AddEqualCstr(key_num_head, value_num_head); - shape_analysis->AddEqualCstr(query_head_size, key_head_size); + infer_context->AddEqualCstr(query_head_size, key_head_size); - shape_analysis->AddEqualCstr(key_seq_length, value_seq_length); + infer_context->AddEqualCstr(key_seq_length, value_seq_length); const std::vector out_dims{ query_batch_size, query_seq_length, query_num_head, value_head_size}; @@ -457,20 +494,20 @@ bool MemoryEfficientAttentionOpInferSymbolicShape( query_batch_size}; const std::vector seed_and_offset_dims{2}; - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs(out_dims)); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(1), symbol::TensorShapeOrDataDimExprs(logsumexp_dims)); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(2), symbol::TensorShapeOrDataDimExprs(seed_and_offset_dims)); return true; } bool MeshgridOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const symbol::TensorListShapeOrDataDimExprs &shape_data_list = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)) + infer_context->GetShapeOrDataForValue(op->operand_source(0)) .dyn_cast(); const symbol::ShapeOrDataDimExprs sym_shape_dim_exprs = [&] { @@ -495,19 +532,19 @@ bool MeshgridOpInferSymbolicShape( }(); pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, sym_shape_dim_exprs); + infer_context->SetShapeOrDataForValue(res, sym_shape_dim_exprs); return true; } bool StackOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const auto &attributes = op->attributes(); int axis = attributes.at("axis").dyn_cast().data(); const symbol::TensorListShapeOrDataDimExprs &shape_data_list = - shape_analysis->GetShapeOrDataForValue(operand_source) + infer_context->GetShapeOrDataForValue(operand_source) .dyn_cast(); int rank = shape_data_list[0].shape().size(); @@ -528,7 +565,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op, } else { for 
(int i = 0; i < rank; ++i) { details::BuildCstrEqForTensorListAlongAxis( - shape_analysis, shape_data_list, i); + infer_context, shape_data_list, i); } shape_dim_exprs.insert(shape_dim_exprs.begin() + axis, static_cast(shape_data_list.size())); @@ -539,39 +576,39 @@ bool StackOpInferSymbolicShape(pir::Operation *op, }(); pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue(res, shape_data); return true; } bool TrilinearInterpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return BicubicInterpOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return BicubicInterpOpInferSymbolicShape(op, infer_context); } bool WhereOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - shape_analysis->SetShapeOrDataForValue( + pir::InferSymbolicShapeContext *infer_context) { + infer_context->SetShapeOrDataForValue( op->result(0), - shape_analysis->GetShapeOrDataForValue(op->operand_source(0))); + infer_context->GetShapeOrDataForValue(op->operand_source(0))); const std::vector &operands = {op->operand_source(0), op->operand_source(1)}; - size_t rank = shape_analysis->GetShapeOrDataForValue(op->operand_source(0)) + size_t rank = infer_context->GetShapeOrDataForValue(op->operand_source(0)) .shape() .size(); for (size_t i = 0; i < rank; ++i) { paddle::dialect::details::BuildCstrEqForTensorListAlongAxis( - shape_analysis, operands, i); + infer_context, operands, i); } return true; } -bool Where_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return WhereOpInferSymbolicShape(op, shape_analysis); +bool Where_OpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + return WhereOpInferSymbolicShape(op, infer_context); } } // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc index 069c646fc60edb..b15c72ea4db013 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc @@ -17,14 +17,14 @@ namespace paddle::dialect { -bool ArangeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool ArangeOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { const auto &start_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &end_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); const auto &step_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); + infer_context->GetShapeOrDataForValue(op->operand_source(2)); const symbol::ShapeOrDataDimExprs &shape_data = [&] { if (!start_shape_or_data.data().has_value() || @@ -32,7 +32,7 @@ bool ArangeOpInferSymbolicShape( !step_shape_or_data.data().has_value()) { return symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(std::vector{ - symbol::DimExpr(shape_analysis->GetNextSymName())})}; + symbol::DimExpr(infer_context->GetNextSymName())})}; } const auto &start = 
start_shape_or_data.data()->at(0); const auto &end = end_shape_or_data.data()->at(0); @@ -45,13 +45,13 @@ bool ArangeOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(out_dims)}; }(); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool AssignValueOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const std::vector shape = paddle::dialect::details::GetVectorAttr(op, "shape"); std::vector sym_dims; @@ -77,23 +77,23 @@ bool AssignValueOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(sym_dims, data)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(sym_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool AssignValue_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return AssignValueOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return AssignValueOpInferSymbolicShape(op, infer_context); } bool DataOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const auto &attributes = op->attributes(); pir::Attribute attr = attributes.at("shape"); @@ -104,7 +104,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op, for (auto dim : dims) { symbol::DimExpr dim_expr; if (dim == pir::ShapedTypeInterface::kDynamic) { - symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName()); + symbol::DimExpr symbolic_dim_expr(infer_context->GetNextSymName()); dim_expr = symbolic_dim_expr; } else { symbol::DimExpr numeric_dim_expr(dim); @@ -131,7 +131,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op, const auto &shape_or_data = [&]() { if (IsOneNumel(op->result(0)) && IsIntType(op->result(0))) { std::vector data{ - symbol::DimExpr(shape_analysis->GetNextSymName())}; + symbol::DimExpr(infer_context->GetNextSymName())}; return symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(sym_dims, data)}; } else { @@ -140,13 +140,13 @@ bool DataOpInferSymbolicShape(pir::Operation *op, } }(); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_or_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_or_data); return true; } bool EmptyOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const auto &shape_gen_op = op->operand_source(0).defining_op(); if (shape_gen_op->isa()) { std::vector shape = details::GetVectorAttr( @@ -159,13 +159,13 @@ bool EmptyOpInferSymbolicShape(pir::Operation *op, symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(sym_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } else { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + 
infer_context->GetShapeOrDataForValue(operand_source); PADDLE_ENFORCE_EQ( operand_shape_or_data.data().has_value(), true, @@ -173,7 +173,7 @@ bool EmptyOpInferSymbolicShape(pir::Operation *op, "The data of input dim_expr shape is null. When input of empty op " "is a tensor, the data of input dim_expr shape must have value.")); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs{ operand_shape_or_data.data().value()}); @@ -182,19 +182,19 @@ bool EmptyOpInferSymbolicShape(pir::Operation *op, } bool FeedOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const common::DDim &result_dims = op->result(0).type().dyn_cast().dims(); std::vector out_dims; for (int i = 0; i < result_dims.size(); i++) { if (result_dims[i] == -1) { - out_dims.emplace_back(shape_analysis->GetNextSymName()); + out_dims.emplace_back(infer_context->GetNextSymName()); } else { out_dims.emplace_back(result_dims[i]); } } - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); @@ -202,7 +202,7 @@ bool FeedOpInferSymbolicShape(pir::Operation *op, } bool FullOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const auto &attributes = op->attributes(); const std::vector shape = [&] { @@ -243,13 +243,13 @@ bool FullOpInferSymbolicShape(pir::Operation *op, } }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs(shape_data)); return true; } bool FullIntArrayOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &attributes = op->attributes(); pir::Attribute attr_value = attributes.at("value"); const auto &vec = attr_value.dyn_cast().AsVector(); @@ -269,12 +269,12 @@ bool FullIntArrayOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(shape, data)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue(res, shape_data); return true; } bool GaussianOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &shape_gen_op = op->operand_source(0).defining_op(); if (shape_gen_op->isa()) { @@ -288,19 +288,30 @@ bool GaussianOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(sym_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Currently shape must comes from FullIntArrayOp in GaussianOp's " - "InferSymbolicShape.")); + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + infer_context->GetShapeOrDataForValue(operand_source); + PADDLE_ENFORCE_EQ( + operand_shape_or_data.data().has_value(), + true, + common::errors::InvalidArgument( + "The data of input dim_expr shape is null. 
When input of empty op " + "is a tensor, the data of input dim_expr shape must have value.")); + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::TensorShapeOrDataDimExprs{ + operand_shape_or_data.data().value()}); return true; } } bool RandintOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &shape_gen_op = op->operand_source(0).defining_op(); if (shape_gen_op->isa()) { @@ -314,7 +325,7 @@ bool RandintOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(sym_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } else { @@ -326,7 +337,7 @@ bool RandintOpInferSymbolicShape( } bool TrilIndicesOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &attributes = op->attributes(); int rows = attributes.at("rows").dyn_cast().data(); int cols = attributes.at("cols").dyn_cast().data(); @@ -353,11 +364,11 @@ bool TrilIndicesOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool TriuIndicesOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &attributes = op->attributes(); int row = attributes.at("row").dyn_cast().data(); int col = attributes.at("col").dyn_cast().data(); @@ -384,12 +395,12 @@ bool TriuIndicesOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool UniformOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return GaussianOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return GaussianOpInferSymbolicShape(op, infer_context); } } // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc index 090fd7f1825ff4..6ef6d01d543578 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc @@ -14,14 +14,14 @@ #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h" -#define OP_SAME_OPERANDS_AND_RESULT(name) \ - bool name##OpInferSymbolicShape( \ - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { \ - const symbol::ShapeOrDataDimExprs &operand_shape_or_data = \ - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); \ - shape_analysis->SetShapeOrDataForValue(op->result(0), \ - operand_shape_or_data); \ - return true; \ +#define OP_SAME_OPERANDS_AND_RESULT(name) \ + bool name##OpInferSymbolicShape( \ + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { \ + 
const symbol::ShapeOrDataDimExprs &operand_shape_or_data = \ + infer_context->GetShapeOrDataForValue(op->operand_source(0)); \ + infer_context->SetShapeOrDataForValue(op->result(0), \ + operand_shape_or_data); \ + return true; \ } namespace paddle::dialect { @@ -138,17 +138,17 @@ OP_SAME_OPERANDS_AND_RESULT(Sigmoid) OP_SAME_OPERANDS_AND_RESULT(Sigmoid_) bool ScaleOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); std::vector shape(operand_shape_or_data.shape()); if (operand_shape_or_data.data()) { const std::vector data = [&] { const symbol::DimExpr scale = [&]() -> symbol::DimExpr { if (op->num_operands() == 2) { - return shape_analysis->GetShapeOrDataForValue(op->operand_source(1)) + return infer_context->GetShapeOrDataForValue(op->operand_source(1)) .data() ->at(0); } @@ -164,11 +164,10 @@ bool ScaleOpInferSymbolicShape(pir::Operation *op, return data; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs(shape, data)); } else { - shape_analysis->SetShapeOrDataForValue(op->result(0), - operand_shape_or_data); + infer_context->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); } return true; diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 1adc5767862474..5efddf897a6d03 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -18,16 +18,16 @@ namespace paddle::dialect { -bool ArgmaxOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool ArgmaxOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { bool flatten = GetBoolAttr(op, "flatten"); bool keepdims = GetBoolAttr(op, "keepdims"); const auto &input_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &axis_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); int axis = static_cast(axis_shape_or_data.data().value()[0].Get()); @@ -65,20 +65,20 @@ bool ArgmaxOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } -bool ArgminOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return ArgmaxOpInferSymbolicShape(op, shape_analysis); +bool ArgminOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + return ArgmaxOpInferSymbolicShape(op, infer_context); } bool AsComplexOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs 
&operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); const std::vector out_dims = [&] { std::vector out_dims = operand_shape_or_data.shape(); @@ -89,14 +89,14 @@ bool AsComplexOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } -bool AsRealOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool AsRealOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); const std::vector out_dims = [&] { std::vector out_dims = operand_shape_or_data.shape(); @@ -107,42 +107,42 @@ bool AsRealOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } -bool CummaxOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool CummaxOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); - shape_analysis->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); - shape_analysis->SetShapeOrDataForValue(op->result(1), operand_shape_or_data); + infer_context->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); + infer_context->SetShapeOrDataForValue(op->result(1), operand_shape_or_data); return true; } -bool CumminOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return CummaxOpInferSymbolicShape(op, shape_analysis); +bool CumminOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + return CummaxOpInferSymbolicShape(op, infer_context); } bool CumprodOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); - shape_analysis->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); + infer_context->GetShapeOrDataForValue(operand_source); + infer_context->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); return true; } bool Cumprod_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return CumprodOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return CumprodOpInferSymbolicShape(op, infer_context); } -bool CumsumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool CumsumOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = 
op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); bool flatten = GetBoolAttr(op, "flatten"); if (flatten) { @@ -154,24 +154,23 @@ bool CumsumOpInferSymbolicShape( const std::vector out_dims = {product}; symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); } else { - shape_analysis->SetShapeOrDataForValue(op->result(0), - operand_shape_or_data); + infer_context->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); } return true; } bool Cumsum_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return CumsumOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return CumsumOpInferSymbolicShape(op, infer_context); } bool DiagEmbedOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); const auto &attributes = op->attributes(); int dim1 = attributes.at("dim1").dyn_cast().data(); int dim2 = attributes.at("dim2").dyn_cast().data(); @@ -193,14 +192,14 @@ bool DiagEmbedOpInferSymbolicShape( }(); symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool DiagonalOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); const auto &attributes = op->attributes(); int axis1 = attributes.at("axis1").dyn_cast().data(); int axis2 = attributes.at("axis2").dyn_cast().data(); @@ -228,7 +227,7 @@ bool DiagonalOpInferSymbolicShape( ? builder.Min(axis1_size, axis2_size - offset_sym) : zero; } else { - res_shape = shape_analysis->GetNextSymName(); + res_shape = infer_context->GetNextSymName(); } } else { if (axis1_size.isa()) { @@ -236,29 +235,29 @@ bool DiagonalOpInferSymbolicShape( ? 
builder.Min(axis1_size + offset_sym, axis2_size) : zero; } else { - res_shape = shape_analysis->GetNextSymName(); + res_shape = infer_context->GetNextSymName(); } } out_dims.push_back(res_shape); symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } -bool EinsumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool EinsumOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { PADDLE_THROW(phi::errors::Unimplemented( op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } bool KthvalueOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); const auto &attributes = op->attributes(); int axis = attributes.at("axis").dyn_cast().data(); bool keepdim = GetBoolAttr(op, "keepdim"); @@ -278,27 +277,27 @@ bool KthvalueOpInferSymbolicShape( } symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - shape_analysis->SetShapeOrDataForValue(op->result(1), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(1), shape_data); return true; } bool LogcumsumexpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { // same as CumsumOpInferSymbolicShape - return CumsumOpInferSymbolicShape(op, shape_analysis); + return CumsumOpInferSymbolicShape(op, infer_context); } bool LogsumexpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { bool keepdim = GetBoolAttr(op, "keepdim"); std::vector axis = details::GetVectorAttr(op, "axis"); bool reduce_all = axis.size() == 0 ? true : false; - return details::ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all); + return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all); } bool MaxOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { bool keepdim = GetBoolAttr(op, "keepdim"); const std::vector axis = [&] { @@ -320,18 +319,18 @@ bool MaxOpInferSymbolicShape(pir::Operation *op, bool reduce_all = axis.size() == 0 ? 
true : false; - return details::ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all); + return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all); } bool MinOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return MaxOpInferSymbolicShape(op, shape_analysis); + pir::InferSymbolicShapeContext *infer_context) { + return MaxOpInferSymbolicShape(op, infer_context); } bool NonzeroOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_shape = x_shape_or_data.shape(); int rank = x_shape.size(); @@ -341,21 +340,21 @@ bool NonzeroOpInferSymbolicShape( phi::errors::InvalidArgument( "Input(x) should have number of dimension at least 1.")); - std::string sym_name = shape_analysis->GetNextSymName(); + std::string sym_name = infer_context->GetNextSymName(); std::vector out_shape{symbol::DimExpr{sym_name}, symbol::DimExpr{rank}}; symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_shape)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } bool PadOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { // input(0): Tensor x const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); PADDLE_ENFORCE_EQ(x_shape_or_data.data().has_value(), false, phi::errors::InvalidArgument( @@ -385,16 +384,16 @@ bool PadOpInferSymbolicShape(pir::Operation *op, return out_dims; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs(out_dims)); return true; } bool Pad3dOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); PADDLE_ENFORCE_EQ(x_shape.size(), 5, common::errors::InvalidArgument( @@ -402,7 +401,7 @@ bool Pad3dOpInferSymbolicShape(pir::Operation *op, "5, but received %d. 
", x_shape.size())); const auto &paddings_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); if (!paddings_shape.data().has_value()) { std::stringstream ss; ss << paddings_shape; @@ -437,14 +436,14 @@ bool Pad3dOpInferSymbolicShape(pir::Operation *op, return out_dims; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs(out_dims)); return true; } bool ProdOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { bool keepdim = GetBoolAttr(op, "keep_dim"); bool reduce_all = GetBoolAttr(op, "reduce_all"); @@ -453,7 +452,7 @@ bool ProdOpInferSymbolicShape(pir::Operation *op, std::vector axis = details::GetVectorAttr( axis_gen_op->dyn_cast(), "value"); return details::ReduceInferDim( - op, shape_analysis, axis, keepdim, reduce_all); + op, infer_context, axis, keepdim, reduce_all); } else { // TODO(lanxianghit): deal with other source: pir::VectorType, // paddle::dialect::DenseTensorType @@ -466,10 +465,10 @@ bool ProdOpInferSymbolicShape(pir::Operation *op, } bool RepeatInterleaveOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); const auto &attributes = op->attributes(); int repeats = attributes.at("repeats").dyn_cast().data(); @@ -501,7 +500,7 @@ bool RepeatInterleaveOpInferSymbolicShape( return out_sym_shape; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(out_sym_shape)}); @@ -523,11 +522,11 @@ symbol::ShapeOrDataDimExprs CreateShapeOrDataForXShape( } bool ReshapeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const symbol::ShapeOrDataDimExprs &x_dim_expr = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const symbol::ShapeOrDataDimExprs &shape_dim_expr = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); const auto &GetProduct = [&](const auto &dim_exprs, const auto &Filter) { symbol::DimExpr product{1}; @@ -555,7 +554,7 @@ bool ReshapeOpInferSymbolicShape( const std::vector out_dims = [&] { const auto &original_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); const auto &numel = GetProduct(original_shape, [](const auto &) { return true; }); @@ -587,50 +586,50 @@ bool ReshapeOpInferSymbolicShape( return symbol::TensorShapeOrDataDimExprs(out_dims); }(); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue( op->result(1), CreateShapeOrDataForXShape( - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)))); + 
infer_context->GetShapeOrDataForValue(op->operand_source(0)))); return true; } bool Reshape_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return ReshapeOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return ReshapeOpInferSymbolicShape(op, infer_context); } bool ShapeOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &out_data = operand_shape_or_data.shape(); const std::vector shape{std::int64_t(out_data.size())}; symbol::ShapeOrDataDimExprs shape_or_data{ symbol::TensorShapeOrDataDimExprs(shape, out_data)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_or_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_or_data); return true; } bool ShapeSrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return ShapeOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return ShapeOpInferSymbolicShape(op, infer_context); } bool SliceOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); pir::Value operand_starts = op->operand_source(1); pir::Value operand_ends = op->operand_source(2); pir::Value res = op->result(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); const symbol::ShapeOrDataDimExprs &starts_shape_data = - shape_analysis->GetShapeOrDataForValue(operand_starts); + infer_context->GetShapeOrDataForValue(operand_starts); const symbol::ShapeOrDataDimExprs &ends_shape_data = - shape_analysis->GetShapeOrDataForValue(operand_ends); + infer_context->GetShapeOrDataForValue(operand_ends); std::vector axes_vec = details::GetVectorAttr(op, "axes"); @@ -643,7 +642,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, const std::vector decrease_axis = details::GetVectorAttr(op, "decrease_axis"); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( res, slice_utils::SliceRawInferSymbolicShape(operand_shape_or_data, starts, @@ -656,10 +655,10 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, } bool SplitOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { // input const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); PADDLE_ENFORCE_EQ(x_shape_or_data.data().has_value(), false, phi::errors::InvalidArgument( @@ -683,7 +682,7 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, // sections const std::vector §ions_sym = [&] { const auto §ions_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); std::vector sections_sym; if (sections_shape_or_data.data().has_value()) { sections_sym = sections_shape_or_data.data().value(); @@ -723,7 +722,7 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, 
const bool &all_sections_sym_not_minus_one = All(sections_sym, IsNotMinusOne); if (all_sections_sym_not_minus_one) { - shape_analysis->AddEqualCstr(x_dims_sym[axis], sum_exclude_minus_one); + infer_context->AddEqualCstr(x_dims_sym[axis], sum_exclude_minus_one); } symbol::TensorListShapeOrDataDimExprs shape_data_list; @@ -747,14 +746,14 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, return shape_data_list; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{output_shape_data_list}); return true; } bool SplitWithNumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { int64_t axis = op->operand_source(1) .defining_op() .attributes() @@ -765,7 +764,7 @@ bool SplitWithNumOpInferSymbolicShape( const auto &attributes = op->attributes(); int num = attributes.at("num").dyn_cast().data(); const auto &x_s_or_d = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); int rank = x_s_or_d.shape().size(); axis = axis < 0 ? axis + rank : axis; @@ -783,13 +782,13 @@ bool SplitWithNumOpInferSymbolicShape( }(); symbol::TensorListShapeOrDataDimExprs outs_s_d(num, out_s_d); - shape_analysis->SetShapeOrDataForValue(op->result(0), - symbol::ShapeOrDataDimExprs{outs_s_d}); + infer_context->SetShapeOrDataForValue(op->result(0), + symbol::ShapeOrDataDimExprs{outs_s_d}); return true; } bool SumOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { bool keepdim = GetBoolAttr(op, "keepdim"); bool reduce_all = false; @@ -801,7 +800,7 @@ bool SumOpInferSymbolicShape(pir::Operation *op, reduce_all = true; } return details::ReduceInferDim( - op, shape_analysis, axis, keepdim, reduce_all); + op, infer_context, axis, keepdim, reduce_all); } else { // TODO(lanxianghit): deal with other source: pir::VectorType, // paddle::dialect::DenseTensorType @@ -814,13 +813,13 @@ bool SumOpInferSymbolicShape(pir::Operation *op, } bool TileOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_x = op->operand_source(0); symbol::ShapeOrDataDimExprs x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_x); + infer_context->GetShapeOrDataForValue(operand_x); pir::Value operand_repeat_times = op->operand_source(1); symbol::ShapeOrDataDimExprs repeat_times_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_repeat_times); + infer_context->GetShapeOrDataForValue(operand_repeat_times); std::vector x_dimexpr; if (x_shape_or_data.data().has_value()) { @@ -858,17 +857,17 @@ bool TileOpInferSymbolicShape(pir::Operation *op, symbol::TensorShapeOrDataDimExprs(out_shape)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue(res, shape_data); return true; } bool TopkOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { symbol::ShapeOrDataDimExprs x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); symbol::ShapeOrDataDimExprs k_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + 
infer_context->GetShapeOrDataForValue(op->operand_source(1)); const auto &attributes = op->attributes(); int axis = attributes.at("axis").dyn_cast().data(); const std::vector &in_dims_sym = [&] { @@ -901,28 +900,28 @@ bool TopkOpInferSymbolicShape(pir::Operation *op, symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - shape_analysis->SetShapeOrDataForValue(op->result(1), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(1), shape_data); return true; } bool TransposeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { std::vector perm = op->attributes().at("perm").dyn_cast().AsVector(); if (perm.size() == 1) { // perm must be [0], which means nothing to do with input, just copy the // info from input - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), - shape_analysis->GetShapeOrDataForValue(op->operand_source(0))); + infer_context->GetShapeOrDataForValue(op->operand_source(0))); return true; } const std::vector &x_dims = [&] { std::vector dims; const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); if (x_shape_or_data.data().has_value()) { dims = x_shape_or_data.data().value(); } else { @@ -958,19 +957,19 @@ bool TransposeOpInferSymbolicShape( out_dims[i] = x_dims[formatted_axis[i]]; } - shape_analysis->SetShapeOrDataForValue(op->result(0), - ShapeOrData{TensorExprs(out_dims)}); + infer_context->SetShapeOrDataForValue(op->result(0), + ShapeOrData{TensorExprs(out_dims)}); return true; } bool Transpose_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return TransposeOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return TransposeOpInferSymbolicShape(op, infer_context); } bool SqueezeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { PADDLE_ENFORCE_EQ( op->num_operands(), 2, @@ -980,9 +979,9 @@ bool SqueezeOpInferSymbolicShape( op->num_operands())); auto x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); auto axes_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); std::vector in_dims_sym; if (x_shape_or_data.data().has_value()) { @@ -1056,22 +1055,22 @@ bool SqueezeOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(output_shape_sym)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue( op->result(1), CreateShapeOrDataForXShape(x_shape_or_data)); return true; } bool Squeeze_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SqueezeOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return SqueezeOpInferSymbolicShape(op, infer_context); } -bool 
UnbindOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool UnbindOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { // input const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); PADDLE_ENFORCE_EQ( x_shape_or_data.data().has_value(), false, @@ -1106,16 +1105,16 @@ bool UnbindOpInferSymbolicShape( return shape_data_list; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{output_shape_data_list}); return true; } -bool UniqueOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool UniqueOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); PADDLE_ENFORCE_EQ( x_shape_or_data.data().has_value(), false, @@ -1128,7 +1127,7 @@ bool UniqueOpInferSymbolicShape( paddle::dialect::details::GetVectorAttr(op, "axis"); symbol::DimExpr unique_dim_sym = - shape_analysis->GetNextSymName(); // unknown until runtime + infer_context->GetNextSymName(); // unknown until runtime const std::vector &counts_dims = [&] { std::vector out_dims; @@ -1166,22 +1165,22 @@ bool UniqueOpInferSymbolicShape( return inverse_dims; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs{out_dims}); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(1), symbol::TensorShapeOrDataDimExprs{index_dims}); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(2), symbol::TensorShapeOrDataDimExprs{inverse_dims}); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(3), symbol::TensorShapeOrDataDimExprs{counts_dims}); return true; } bool UniqueConsecutiveOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); PADDLE_ENFORCE_EQ( x_shape_or_data.data().has_value(), false, @@ -1194,7 +1193,7 @@ bool UniqueConsecutiveOpInferSymbolicShape( paddle::dialect::details::GetVectorAttr(op, "axis"); symbol::DimExpr unique_dim_sym = - shape_analysis->GetNextSymName(); // unknown until runtime + infer_context->GetNextSymName(); // unknown until runtime const std::vector &counts_dims = [&] { std::vector out_dims; @@ -1230,18 +1229,18 @@ bool UniqueConsecutiveOpInferSymbolicShape( return inverse_dims; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs{out_dims}); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(1), symbol::TensorShapeOrDataDimExprs{inverse_dims}); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(2), symbol::TensorShapeOrDataDimExprs{counts_dims}); return true; } bool UnsqueezeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, 
pir::InferSymbolicShapeContext *infer_context) { PADDLE_ENFORCE_EQ( op->num_operands(), 2, @@ -1251,9 +1250,9 @@ bool UnsqueezeOpInferSymbolicShape( op->num_operands())); auto x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(0)); auto axes_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); std::vector x_sym_shape; if (x_shape_or_data.data().has_value()) { @@ -1311,15 +1310,15 @@ bool UnsqueezeOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(result_sym_dims)}; pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue(res, shape_data); + infer_context->SetShapeOrDataForValue( op->result(1), CreateShapeOrDataForXShape(x_shape_or_data)); return true; } bool Unsqueeze_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return UnsqueezeOpInferSymbolicShape(op, shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return UnsqueezeOpInferSymbolicShape(op, infer_context); } } // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 437aac78c2367f..13e0d29103190d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -309,18 +309,18 @@ std::vector> IfOp::Vjp( return res; } -bool IfOp::InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool IfOp::InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context) { // infer true block - pir::InferSymExprForBlock(true_block(), shape_analysis); + pir::InferSymExprForBlock(true_block(), infer_context); // infer false block - pir::InferSymExprForBlock(false_block(), shape_analysis); + pir::InferSymExprForBlock(false_block(), infer_context); auto GetSymExprForBlockResult = - [shape_analysis](const pir::Operation &op, - uint32_t idx) -> const std::vector & { + [infer_context](const pir::Operation &op, + uint32_t idx) -> const std::vector & { const auto &shape_or_data = - shape_analysis->GetShapeOrDataForValue(op.operand_source(idx)); + infer_context->GetShapeOrDataForValue(op.operand_source(idx)); if (shape_or_data.data().has_value()) { return shape_or_data.data().value(); } else { @@ -359,12 +359,12 @@ bool IfOp::InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis) { false_dims.size())); for (size_t i = 0; i < true_dims.size(); i++) { if (true_dims[i] != false_dims[i]) { - out_dims[i] = symbol::DimExpr{shape_analysis->GetNextSymName()}; + out_dims[i] = symbol::DimExpr{infer_context->GetNextSymName()}; } } } - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( result(rst_idx), symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(out_dims)}); @@ -715,7 +715,7 @@ std::vector> WhileOp::Vjp( } bool WhileOp::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { for (auto &value : block_args()) { std::vector sym_dims; const std::vector &dims = @@ -724,7 +724,7 @@ bool WhileOp::InferSymbolicShape( for (auto dim : dims) { symbol::DimExpr dim_expr; if (dim == pir::ShapedTypeInterface::kDynamic) { - symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName()); 
+ symbol::DimExpr symbolic_dim_expr(infer_context->GetNextSymName()); dim_expr = symbolic_dim_expr; } else { symbol::DimExpr numeric_dim_expr(dim); @@ -734,7 +734,7 @@ bool WhileOp::InferSymbolicShape( } symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(sym_dims)}; - shape_analysis->SetShapeOrDataForValue(value, shape_data); + infer_context->SetShapeOrDataForValue(value, shape_data); } // add GreaterThanOne constraint @@ -745,28 +745,28 @@ bool WhileOp::InferSymbolicShape( "The num_operands-1 and body_args.size is not equal")); for (size_t i = 0; i < body_args.size(); ++i) { const auto &input_i = - shape_analysis->GetShapeOrDataForValue(operand_source(i + 1)).shape(); + infer_context->GetShapeOrDataForValue(operand_source(i + 1)).shape(); const auto &args_i = - shape_analysis->GetShapeOrDataForValue(body_args[i]).shape(); + infer_context->GetShapeOrDataForValue(body_args[i]).shape(); if (input_i.size() != args_i.size()) { // there is a trick, so the size may vary. continue; } for (size_t j = 0; j < input_i.size(); ++j) { - if (shape_analysis->IsGreatThanOne(input_i[j])) { - shape_analysis->AddGreatThanOneCstr(args_i[j]); + if (infer_context->IsGreatThanOne(input_i[j])) { + infer_context->AddGreatThanOneCstr(args_i[j]); } } } - pir::InferSymExprForBlock(body(), shape_analysis); + pir::InferSymExprForBlock(body(), infer_context); // add constraints for args for (size_t i = 0; i < body_args.size(); ++i) { const auto &input_arg_shape = - shape_analysis->GetShapeOrDataForValue(body_args[i]).shape(); + infer_context->GetShapeOrDataForValue(body_args[i]).shape(); const auto &yield_value_shape = - shape_analysis + infer_context ->GetShapeOrDataForValue(body().back().operand_source(i + 1)) .shape(); PADDLE_ENFORCE_EQ(input_arg_shape.size(), @@ -780,21 +780,21 @@ bool WhileOp::InferSymbolicShape( input_arg_shape.size(), yield_value_shape.size())); const auto &original_input_shape = - shape_analysis->GetShapeOrDataForValue(operand_source(i + 1)).shape(); + infer_context->GetShapeOrDataForValue(operand_source(i + 1)).shape(); for (size_t j = 0; j < input_arg_shape.size(); ++j) { if (input_arg_shape[j].isa()) { continue; } if (input_arg_shape[j] == yield_value_shape[j]) { // Dim isn't changed in while - shape_analysis->AddEqualCstr(original_input_shape[j], - input_arg_shape[j]); + infer_context->AddEqualCstr(original_input_shape[j], + input_arg_shape[j]); continue; } if (original_input_shape.size() == yield_value_shape.size() && original_input_shape[j] == yield_value_shape[j]) { - shape_analysis->AddEqualCstr(original_input_shape[j], - input_arg_shape[j]); + infer_context->AddEqualCstr(original_input_shape[j], + input_arg_shape[j]); continue; } } @@ -802,9 +802,9 @@ bool WhileOp::InferSymbolicShape( const auto &last_op = body().back(); for (size_t i = 1; i < last_op.operands_source().size(); ++i) { - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( result(i - 1), - shape_analysis->GetShapeOrDataForValue(last_op.operand_source(i))); + infer_context->GetShapeOrDataForValue(last_op.operand_source(i))); } PADDLE_ENFORCE_EQ(body_args.size(), @@ -813,18 +813,18 @@ bool WhileOp::InferSymbolicShape( "The body_args.size and num_results is not equal")); for (size_t i = 0; i < num_results(); ++i) { const auto &input_i = - shape_analysis->GetShapeOrDataForValue(operand_source(i + 1)).shape(); + infer_context->GetShapeOrDataForValue(operand_source(i + 1)).shape(); const auto &output_i = - shape_analysis->GetShapeOrDataForValue(result(i)).shape(); + 
infer_context->GetShapeOrDataForValue(result(i)).shape(); const auto &args_i = - shape_analysis->GetShapeOrDataForValue(body_args[i]).shape(); + infer_context->GetShapeOrDataForValue(body_args[i]).shape(); if (input_i.size() != args_i.size()) { // there is a trick, so the size may vary. continue; } for (size_t j = 0; j < output_i.size(); j++) { - if (shape_analysis->IsEqual(output_i[j], args_i[j])) { - shape_analysis->AddEqualCstr(output_i[j], input_i[j]); + if (infer_context->IsEqual(output_i[j], args_i[j])) { + infer_context->AddEqualCstr(output_i[j], input_i[j]); } } } @@ -1118,10 +1118,10 @@ void SelectInputOp::VerifySig() { } bool SelectInputOp::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { auto GetSymExprForValue = - [shape_analysis](pir::Value val) -> const std::vector & { - const auto &shape_or_data = shape_analysis->GetShapeOrDataForValue(val); + [infer_context](pir::Value val) -> const std::vector & { + const auto &shape_or_data = infer_context->GetShapeOrDataForValue(val); if (shape_or_data.data().has_value()) { return shape_or_data.data().value(); } else { @@ -1134,7 +1134,7 @@ bool SelectInputOp::InferSymbolicShape( // for compatibility, we just return second_shape. if (input1_dims.size() != input2_dims.size()) { - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( result(0), symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(input2_dims)}); @@ -1148,12 +1148,12 @@ bool SelectInputOp::InferSymbolicShape( if (input2_dims.size() != 0) { for (size_t i = 0; i < input1_dims.size(); i++) { if (input1_dims[i] != input2_dims[i]) { - out_dims[i] = symbol::DimExpr{shape_analysis->GetNextSymName()}; + out_dims[i] = symbol::DimExpr{infer_context->GetNextSymName()}; } } } - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index 9f32413743ce96..eabed1f546bfba 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -57,7 +57,7 @@ class IfOp : public pir::Op { const std::vector> &out_grads, const std::vector> &stop_gradients); - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class PyLayerOp : public pir::Op { @@ -122,7 +122,7 @@ class WhileOp const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; struct TuplePushOpVjpInterfaceModel : public VjpInterface::Concept { @@ -205,7 +205,7 @@ class SelectInputOp void VerifySig(); pir::Value mask() { return operand_source(0); } pir::Value out() { return result(0); } - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class SelectOutputOp : public pir::Op { diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc index 17d9a1dadc903e..469954f27afb56 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc +++ 
b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc @@ -351,7 +351,7 @@ phi::DataType ExpandOp::GetKernelTypeForVar( } bool ExpandOp::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::InferSymbolicShapeContext* infer_context) { VLOG(4) << "Infer symbolic shape for op: ExpandOp"; PADDLE_THROW(phi::errors::Unimplemented( " ExpandOp's InferSymbolicShape interface is NOT implemented now.")); diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h index 58f15f5582e65e..1aa51788b97b7a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h @@ -75,7 +75,7 @@ class ExpandOp : public pir::OpGetShapeOrDataForValue(x()); + pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(x()); const auto &expand_shape_shape_or_data = - shape_analysis->GetShapeOrDataForValue(shape()); + infer_context->GetShapeOrDataForValue(shape()); const std::vector &x_dims = x_shape_or_data.shape(); @@ -3202,7 +3202,7 @@ bool ExpandOp::InferSymbolicShape( } } - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( out(), symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(out_shape)}); @@ -3578,10 +3578,10 @@ phi::DataType IncrementOp::GetKernelTypeForVar( } bool IncrementOp::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(x()); - shape_analysis->SetShapeOrDataForValue(out(), operand_shape_or_data); + infer_context->GetShapeOrDataForValue(x()); + infer_context->SetShapeOrDataForValue(out(), operand_shape_or_data); return true; } @@ -3783,10 +3783,10 @@ phi::DataType Increment_Op::GetKernelTypeForVar( } bool Increment_Op::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(x()); - shape_analysis->SetShapeOrDataForValue(out(), operand_shape_or_data); + infer_context->GetShapeOrDataForValue(x()); + infer_context->SetShapeOrDataForValue(out(), operand_shape_or_data); return true; } @@ -4109,18 +4109,11 @@ std::vector ComputeBroadcastShape( } bool ShapeBroadcastOp::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::InferSymbolicShapeContext *infer_context) { pir::Value x = operand_source(0); pir::Value y = operand_source(1); - - PADDLE_ENFORCE_GT(shape_analysis->HasShapeOrDataForValue(x), - 0, - phi::errors::InvalidArgument("Value x does not exist.")); - PADDLE_ENFORCE_GT(shape_analysis->HasShapeOrDataForValue(y), - 0, - phi::errors::InvalidArgument("Value y does not exist.")); - const auto &x_data_shape = shape_analysis->GetShapeOrDataForValue(x); - const auto &y_data_shape = shape_analysis->GetShapeOrDataForValue(y); + const auto &x_data_shape = infer_context->GetShapeOrDataForValue(x); + const auto &y_data_shape = infer_context->GetShapeOrDataForValue(y); PADDLE_ENFORCE_EQ(x_data_shape.data().has_value(), true, phi::errors::InvalidArgument( @@ -4141,7 +4134,7 @@ bool ShapeBroadcastOp::InferSymbolicShape( symbol::ShapeOrDataDimExprs output_data_shape{ symbol::TensorShapeOrDataDimExprs(shape, output_data)}; - shape_analysis->SetShapeOrDataForValue(res, output_data_shape); + 
infer_context->SetShapeOrDataForValue(res, output_data_shape); return true; } diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 05e149a1efd2ea..deb7c09c06ff88 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -394,7 +394,7 @@ class SliceArrayOp using Op::Op; static const char *name() { return "pd_op.slice_array"; } static constexpr const char **attributes_name = nullptr; - static constexpr uint32_t attributes_num = 2; + static constexpr uint32_t attributes_num = 0; static OpInfoTuple GetOpInfo(); void VerifySig(); @@ -558,7 +558,7 @@ class ExpandOp : public pir::Op> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class IncrementOp @@ -604,7 +604,7 @@ class IncrementOp const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class Increment_Op @@ -651,7 +651,7 @@ class Increment_Op const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class AssignOut_Op @@ -760,7 +760,7 @@ class IR_API ShapeBroadcastOp const std::vector &input_values, pir::AttributeMap *p_attributes); - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class ArrayPopOp : public pir::Opnum_operands(); ++i) { @@ -53,14 +53,14 @@ struct CombineOpInferSymbolicShapeInterfaceModel "DenseTensorType.")); shape_data_list.emplace_back( - shape_analysis->GetShapeOrDataForValue(op->operand_source(i)) + infer_context->GetShapeOrDataForValue(op->operand_source(i)) .dyn_cast()); } return shape_data_list; }(); symbol::ShapeOrDataDimExprs shape_data{shape_data_list}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); return true; } @@ -71,7 +71,7 @@ struct CombineOpInferSymbolicShapeInterfaceModel struct ConstantOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { PADDLE_ENFORCE_NOT_NULL( op->result(0).type().dyn_cast(), phi::errors::InvalidArgument( @@ -88,7 +88,7 @@ struct ConstantOpInferSymbolicShapeInterfaceModel return dims; }(); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(out_dims)}); @@ -103,7 +103,7 @@ struct ConstantOpInferSymbolicShapeInterfaceModel struct ParameterOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { pir::Value res0 = op->result(0); std::vector dims = @@ -114,7 +114,7 @@ struct 
ParameterOpInferSymbolicShapeInterfaceModel for (int64_t dim : dims) { symbol::DimExpr dim_expr; if (dim == -1) { - symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); + symbol::DimExpr res_dim_expr(infer_context->GetNextSymName()); dim_expr = res_dim_expr; } else { symbol::DimExpr res_dim_expr(dim); @@ -126,7 +126,7 @@ struct ParameterOpInferSymbolicShapeInterfaceModel symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(sym_shape)}; - shape_analysis->SetShapeOrDataForValue(res0, shape_data); + infer_context->SetShapeOrDataForValue(res0, shape_data); return true; } @@ -138,7 +138,7 @@ struct ParameterOpInferSymbolicShapeInterfaceModel struct SetParameterOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { return true; } @@ -149,10 +149,10 @@ struct SetParameterOpInferSymbolicShapeInterfaceModel struct ShadowOutputOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { pir::Value operand_source = op->operand_source(0); auto input_shapeordata = - shape_analysis->GetShapeOrDataForValue(operand_source); + infer_context->GetShapeOrDataForValue(operand_source); symbol::ShapeOrDataDimExprs shape_data = input_shapeordata; pir::shape::SetShapeAttrForOp(op, shape_data); @@ -167,15 +167,15 @@ struct ShadowOutputOpInferSymbolicShapeInterfaceModel struct SliceOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { const auto index = op->attributes().at("index").dyn_cast().data(); const auto output_value = (op->operand(0).type().dyn_cast())[index] .dyn_cast(); - shape_analysis->SetShapeOrDataForValue( - op->result(0), shape_analysis->GetShapeOrDataForValue(output_value)); + infer_context->SetShapeOrDataForValue( + op->result(0), infer_context->GetShapeOrDataForValue(output_value)); return true; } @@ -187,9 +187,9 @@ struct SliceOpInferSymbolicShapeInterfaceModel struct SplitOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { const auto& shape_data_list = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)) + infer_context->GetShapeOrDataForValue(op->operand_source(0)) .dyn_cast(); for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { @@ -199,7 +199,7 @@ struct SplitOpInferSymbolicShapeInterfaceModel paddle::platform::errors::InvalidArgument( "Currently InferSymbolicShape of SplitOp only support " "input without value.")); - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(rst_idx), symbol::ShapeOrDataDimExprs{shape_data_list[rst_idx]}); } @@ -213,7 +213,7 @@ struct SplitOpInferSymbolicShapeInterfaceModel struct YieldOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( - pir::Operation* op, 
pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { // Since YieldOp has no output, just return true return true; } diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 7b24d6c4fec9a0..926363d002af12 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -897,6 +897,16 @@ kernel: func: get_tensor_from_selected_rows {selected_rows -> dense} +- op : global_gather + args : (Tensor x, Tensor local_count, Tensor global_count, int ring_id = 0, bool use_calc_stream=false) + output : Tensor(out) + infer_meta: + func : GlobalGatherInferMeta + kernel : + func : global_gather + data_type: x + backward : global_gather_grad + - op : global_scatter args : (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) output : Tensor(out) @@ -905,6 +915,7 @@ kernel : func : global_scatter data_type : x + backward : global_scatter_grad - op : greater_equal args : (Tensor x, Tensor y) @@ -1351,6 +1362,15 @@ func : prune_gate_by_capacity data_type : gate_idx +- op : pull_box_sparse + args : (Tensor w, Tensor[] ids, bool is_sparse = false, bool is_distributed = false, int size = 1) + output : Tensor[](out){ids.size()} + infer_meta : + func : PullBoxSparseInferMeta + kernel : + func : pull_box_sparse + data_type : ids + - op : pull_gpups_sparse args : (Tensor w, Tensor[] ids, int[] size={}, bool is_sparse=false, bool is_distributed=false) output : Tensor[](out){ids.size()} diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index f407cf00c504f7..1b3b24153b34a6 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -659,6 +659,18 @@ func : prod_grad composite: prod_grad(x, out, out_grad, dims, keep_dim, reduce_all, x_grad) +- backward_op : push_box_sparse + forward : pull_box_sparse (Tensor w, Tensor[] ids, bool is_sparse = false, bool is_distributed = false, int size = 1) -> Tensor[](out){ids.size()} + args : (Tensor[] ids, Tensor[] out_grad_in, bool is_sparse = false, bool is_distributed = false, int size = 1) + output : Tensor[](out_grad_out){out_grad_in.size()} + infer_meta : + func : UnchangedMultiInferMeta + param : [out_grad_in] + kernel : + func : push_box_sparse + data_type : out_grad_in + inplace : (out_grad_in -> out_grad_out) + - backward_op : rank_attention_grad forward : rank_attention (Tensor x, Tensor rank_offset, Tensor rank_param, int max_rank = 3, int max_size = 0) -> Tensor(input_help), Tensor(out), Tensor(ins_rank) args : (Tensor x, Tensor rank_offset, Tensor rank_param, Tensor input_help, Tensor ins_rank, Tensor out_grad, int max_rank = 3, int max_size = 0) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 551acf25fa16b5..711bd2f4e5f18a 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -105,6 +105,8 @@ const std::unordered_set LegacyOpList = { CReduceMinOp::name(), CReduceProdOp::name(), CScatterOp::name(), + PullBoxSparseOp::name(), + PushBoxSparseOp::name(), PushSparseV2Op::name(), PartialSendOp::name(), PartialRecvOp::name()}; diff --git a/paddle/fluid/pir/transforms/sub_graph_detector.cc b/paddle/fluid/pir/transforms/sub_graph_detector.cc index ebc0f1e9f9d115..1491c235bf03e8 100644 --- 
a/paddle/fluid/pir/transforms/sub_graph_detector.cc +++ b/paddle/fluid/pir/transforms/sub_graph_detector.cc @@ -26,6 +26,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/core/builder.h" #include "paddle/pir/include/core/builtin_op.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h" @@ -484,34 +485,6 @@ std::vector AnalysisOutputs( return outputs; } -std::vector AnalysisExternalInputs(Operation* op) { // NOLINT - if (!op->isa()) { - return op->operands_source(); - } - // Get all ops in group - const auto all_ops = [&]() -> decltype(auto) { - const auto all_ops = op->dyn_cast().GetOperators(); - return std::unordered_set(all_ops.begin(), all_ops.end()); - }(); - std::unordered_set value_set; - const auto& IsOutsideInput = [&](const pir::Value& value) -> bool { - const bool is_outside = - value && value.defining_op() && !all_ops.count(value.defining_op()); - const bool has_visited = value_set.count(value); - if (!has_visited) value_set.insert(value); - return is_outside && !has_visited; - }; - - std::vector<::pir::Value> inputs; - // count all op's input Value - for (auto inner_op : all_ops) { - for (auto& value : inner_op->operands_source()) { - if (IsOutsideInput(value)) inputs.push_back(value); - } - } - return inputs; -} - namespace { pir::Operation* FindInsertPoint(const GroupOpsVec& group_ops, @@ -576,7 +549,7 @@ std::unordered_set GetUpstreamOpsAfterPosition( } return false; }; - std::vector op_inputs = AnalysisExternalInputs(op); + std::vector op_inputs = pir::GetUsedExternalValue(*op); for (auto value : op_inputs) { if (!value || !value.defining_op()) continue; pir::Operation* defining_op = value.defining_op(); diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 09c7b0c8729f4b..636b18a75aeab0 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -119,9 +119,6 @@ 'relu_grad', 'sigmoid_grad', 'silu_grad', - 'exp_grad', - 'log_grad', - 'abs_double_grad', 'softmax_grad', 'sqrt_grad', ] # custom vjp list of composite op diff --git a/paddle/fluid/primitive/codegen/templates/common.j2 b/paddle/fluid/primitive/codegen/templates/common.j2 index 5f7148017ab23b..b29401133db03d 100644 --- a/paddle/fluid/primitive/codegen/templates/common.j2 +++ b/paddle/fluid/primitive/codegen/templates/common.j2 @@ -33,10 +33,10 @@ template {%- endmacro -%} -{%- macro args(inputs, attrs) -%} {#- Arguments are variable pass into method -#} - {{sequence('', '', ', ', inputs)}} - {%- if inputs|length>0 and attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#} - {{sequence('', '', ', ', attrs)}} +{%- macro args(arg1, arg2) -%} {#- Arguments are variable pass into method -#} + {{sequence('', '', ', ', arg1)}} + {%- if arg1|length>0 and arg2|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between arg1 and arg2 -#} + {{sequence('', '', ', ', arg2)}} {%- endmacro -%} diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 index a2ac7b1ed64cd2..0f6f5f83d33aa7 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 @@ -27,7 
+27,7 @@ std::vector> vjp_res; for (auto arg: stop_gradients) { vjp_res.push_back(std::vector(arg.size())); } - {% if 'composite' in api and api.name in vjp_comp_white_list %} + {% if api.name in vjp_comp_white_list %} std::string op_name = "{{api.name}}"; auto need_skip = paddle::prim::StaticCompositeContext::Instance().CheckSkipCompOps(op_name); if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled() && !need_skip) { @@ -115,7 +115,13 @@ for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) { {% endif %} {% endfor %} {{get_mutable_attribute(api.attrs, api.name)}} -details::{{api.composite.func_name}}({{api.composite.func_args}}); + +{%- set args_names=[] -%} +{%- for i in api.inputs -%} {%- do args_names.append(i.name) -%} {%- endfor -%} +{%- for i in api.attrs -%} {%- do args_names.append(i.name) -%} {%- endfor %} +{%- set outputs_names=[] -%} +{%- for i in api.outputs -%} {%- do outputs_names.append(i.name) -%} {%- endfor -%} +details::{{api.name}}({{common.args(args_names, outputs_names)}}); {% endmacro %} {%- set api_map = {} -%} diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 101354241e03bc..7151127804712c 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -729,6 +729,14 @@ Tensor full_like_decomp(const Tensor& x, } } +template +Tensor floor_divide_decomp(const Tensor& x, const Tensor& y) { + auto x_cast = cast(x, DataType::INT64); + auto y_cast = cast(y, DataType::INT64); + auto res = x_cast / y_cast; + return cast(res, x.dtype()); +} + template std::tuple dropout_decomp( const Tensor& x, diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 66ffa2ba23d124..009abb8d749ce9 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -288,13 +288,13 @@ PyObject* eager_api_get_grads_types(PyObject* self, EAGER_TRY auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0); - std::vector ret; + std::vector ret; for (auto& tensor : tensor_list) { VLOG(6) << "Get grad for tensor: " << tensor.name(); auto meta = egr::EagerUtils::nullable_autograd_meta(tensor); if (!meta || meta->StopGradient()) { - ret.emplace_back(-1); + ret.emplace_back(phi::DataType::UNDEFINED); continue; } @@ -304,11 +304,10 @@ PyObject* eager_api_get_grads_types(PyObject* self, (tensor.dtype() == phi::DataType::FLOAT32 || tensor.dtype() == phi::DataType::FLOAT16 || tensor.dtype() == phi::DataType::BFLOAT16)) { - ret.emplace_back( - paddle::framework::TransToProtoVarType(tensor.dtype())); + ret.emplace_back(tensor.dtype()); } } else { - ret.emplace_back(-1); + ret.emplace_back(phi::DataType::UNDEFINED); } } diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 48f01681969495..afd22b0f387cb9 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1083,6 +1083,14 @@ PyObject* ToPyObject(const phi::DataType& dtype) { return obj.ptr(); } +PyObject* ToPyObject(const std::vector& dtypes) { + PyObject* result = PyList_New((Py_ssize_t)dtypes.size()); + for (size_t i = 0; i < dtypes.size(); i++) { + PyList_SET_ITEM(result, static_cast(i), ToPyObject(dtypes[i])); + } + return result; +} + PyObject* ToPyObject(const pir::Value& value) { auto obj = ::pybind11::cast(value); obj.inc_ref(); @@ -1389,8 +1397,17 @@ std::vector GetTensorListFromArgs( arg_idx)); } for (Py_ssize_t i = 0; i < len; i++) { + PyObject* tensor_obj = 
PyList_GetItem(list, i); + PADDLE_ENFORCE_EQ( + PyObject_TypeCheck(tensor_obj, p_tensor_type), + true, + platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of Tensors", + op_type, + arg_name, + arg_idx)); paddle::Tensor& tensor = - reinterpret_cast(PyList_GetItem(list, i))->tensor; + reinterpret_cast(tensor_obj)->tensor; if (local_mesh) { ConvertToDistTensor(&tensor, local_mesh); } else { @@ -1422,8 +1439,17 @@ std::vector GetTensorListFromArgs( arg_idx)); } for (Py_ssize_t i = 0; i < len; i++) { + PyObject* tensor_obj = PyTuple_GetItem(list, i); + PADDLE_ENFORCE_EQ( + PyObject_TypeCheck(tensor_obj, p_tensor_type), + true, + platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of Tensors", + op_type, + arg_name, + arg_idx)); paddle::Tensor& tensor = - reinterpret_cast(PyTuple_GetItem(list, i))->tensor; + reinterpret_cast(tensor_obj)->tensor; if (local_mesh) { ConvertToDistTensor(&tensor, local_mesh); } else { @@ -1495,8 +1521,17 @@ paddle::optional> GetOptionalTensorListFromArgs( arg_idx)); } for (Py_ssize_t i = 0; i < len; i++) { + PyObject* tensor_obj = PyList_GetItem(list, i); + PADDLE_ENFORCE_EQ( + PyObject_TypeCheck(tensor_obj, p_tensor_type), + true, + platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of Tensors", + op_type, + arg_name, + arg_idx)); paddle::Tensor& tensor = - reinterpret_cast(PyList_GetItem(list, i))->tensor; + reinterpret_cast(tensor_obj)->tensor; if (local_mesh) { ConvertToDistTensor(&tensor, local_mesh); } else { @@ -1528,8 +1563,17 @@ paddle::optional> GetOptionalTensorListFromArgs( arg_idx)); } for (Py_ssize_t i = 0; i < len; i++) { + PyObject* tensor_obj = PyTuple_GetItem(list, i); + PADDLE_ENFORCE_EQ( + PyObject_TypeCheck(tensor_obj, p_tensor_type), + true, + platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of Tensors", + op_type, + arg_name, + arg_idx)); paddle::Tensor& tensor = - reinterpret_cast(PyTuple_GetItem(list, i))->tensor; + reinterpret_cast(tensor_obj)->tensor; if (local_mesh) { ConvertToDistTensor(&tensor, local_mesh); } else { diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index e56741aa90776a..42126ee163a6dc 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -148,7 +148,8 @@ PyObject* ToPyObject(const phi::distributed::Placements& value); PyObject* ToPyObject(const phi::SelectedRows* value); PyObject* ToPyObject(const paddle::framework::proto::VarType::Type& dtype); PyObject* ToPyObject(const paddle::framework::proto::VarType& type); -PyObject* ToPyObject(const phi::DataType& type); +PyObject* ToPyObject(const phi::DataType& dtype); +PyObject* ToPyObject(const std::vector& dtypes); PyObject* ToPyObject(const void* value); PyObject* ToPyObject(const std::unordered_map& value); PyObject* ToPyObject( diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index f8f1424ded2432..cc484f74ab22f4 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -858,6 +858,17 @@ void CastPyArg2AttrIRBlock(PyObject* obj, attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]); } +void CastPyArg2AttrIRProgram(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, + const std::string& op_type, + ssize_t arg_pos) { + VLOG(1) << "After Process pir::Program*"; + const std::shared_ptr<::pir::Program> program = + 
::py::handle(obj).cast>(); + attrs[key] = program; +} + void CastPyArg2AttrValues(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, @@ -1020,11 +1031,11 @@ void ConstructAttrMapForRunProgram( if (std::set({"cuda_graph_capture_mode"}).count(key)) { CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos); - } else if (std::set({"global_block", - "forward_global_block", - "backward_global_block"}) - .count(key)) { + } else if (std::set({"global_block"}).count(key)) { CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos); + } else if (std::set({"forward_program", "backward_program"}) + .count(key)) { + CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos); } else if (std::set({"is_test", "use_interpretorcore"}) .count(key)) { CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos); diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 85ce4abcda94d0..18febc9ff6754f 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -121,6 +121,16 @@ PyTypeObject *g_ir_value_pytype = nullptr; void BindOpsAPI(pybind11::module *module); +pir::Value FakeValue() { + // create a fake value to simplify `ForwardBackwardSplit`. + return pir::Value(nullptr); +} + +bool IsFakeValue(const pir::Value &value) { + // create a fake value to simplify `ForwardBackwardSplit`. + return value.impl() == nullptr || !value.type(); +} + inline int64_t GetProgramInt64Attr(const std::shared_ptr &program, const std::string &attr_name, int64_t default_value = 0) { @@ -195,6 +205,51 @@ Value GetOutputValueByName(const Program &program, const std::string &name) { return value; } +void SetValueName(Value value, const std::string name) { + pir::Operation *define_op = value.defining_op(); + if (define_op->isa()) { + define_op->set_attribute( + "parameter_name", + pir::StrAttribute::get(pir::IrContext::Instance(), name)); + } else if (define_op->isa()) { + define_op->set_attribute( + "name", pir::StrAttribute::get(pir::IrContext::Instance(), name)); + } else if (auto block_arg = value.dyn_cast()) { + PADDLE_THROW( + phi::errors::InvalidArgument("Can Not set name for BlockArgument! 
")); + } else if (value.first_use()) { + auto nextOp = value.first_use().owner(); + if (nextOp->isa<::pir::ShadowOutputOp>()) { + nextOp->set_attribute( + "output_name", + pir::StrAttribute::get(pir::IrContext::Instance(), name)); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Currently, we can only set name of Value which is " + "shadowoutput ")); + } + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Currently, we can only set name of Value that " + "is persistable")); + } +} + +bool HasValueName(const Value &value) { + if (IsFakeValue(value)) { + return false; + } + if (value.defining_op()->isa<::pir::ParameterOp>() || + value.defining_op()->isa() || + value.isa() || + (value.first_use() && + (value.first_use().owner()->isa<::pir::ShadowOutputOp>()))) { + return true; + } else { + return false; + } +} + std::string GetValueName(Value value) { if (auto param_op = value.defining_op<::pir::ParameterOp>()) { return param_op.param_name(); @@ -287,7 +342,7 @@ void BindProgram(py::module *m) { )DOC"); program .def(py::init([]() { - return std::make_unique(pir::IrContext::Instance()); + return std::make_shared(pir::IrContext::Instance()); })) .def("__str__", [](const std::shared_ptr &self) { @@ -376,7 +431,7 @@ void BindProgram(py::module *m) { for (auto op : self->block()->ops()) { for (auto var : op->results()) { auto is_persistable = - var.attribute("persistable"); + var.attribute(kAttrIsPersistable); if (is_persistable && is_persistable.data()) { if (var.defining_op()->isa<::pir::ParameterOp>()) { std::string var_name = GetValueName(var); @@ -968,21 +1023,12 @@ void BindValue(py::module *m) { return ss.str(); } }) - .def_property_readonly("name", - [](Value self) { return GetValueName(self); }) - .def_property_readonly( - "has_name", - [](Value self) { - if (self.defining_op()->isa<::pir::ParameterOp>() || - self.defining_op()->isa() || - self.isa() || - (self.first_use() && - self.first_use().owner()->isa<::pir::ShadowOutputOp>())) { - return true; - } else { - return false; - } - }) + .def_property( + "name", + [](Value self) { return GetValueName(self); }, + [](Value self, const std::string &name) { SetValueName(self, name); }) + .def_property_readonly("has_name", + [](Value self) { return HasValueName(self); }) .def_property( "shape", [](Value self) { return phi::vectorize(GetValueDims(self)); }, @@ -1476,16 +1522,6 @@ using SplitedProgram = std::vector>; using SplitedAttribute = std::map>; using SplitedResult = std::pair; -pir::Value FakeValue() { - // create a fake value to simplify `ForwardBackwardSplit`. - return pir::Value(nullptr); -} - -bool IsFakeValue(const pir::Value &value) { - // create a fake value to simplify `ForwardBackwardSplit`. 
- return value.impl() == nullptr || !value.type(); -} - static auto GetNoNeedBufferValue( const ::pir::Block *whole_block, std::vector range, @@ -1594,10 +1630,12 @@ int AppendShadowOutputs(Program *forward_program, std::string name_prefix) { int counter = 0; std::unordered_set added_value; - for (const auto &value : outputs) { if (!added_value.count(value) || IsFakeValue(value)) { std::string shadow_output_name = name_prefix + std::to_string(counter); + if (HasValueName(value)) { + shadow_output_name = GetValueName(value); + } AppendShadowOutput( forward_program, value, shadow_output_name, start_point + counter); counter += 1; @@ -1727,6 +1765,9 @@ SplitedResult SplitForwardBackward( } std::string shadow_output_name = std::string("output_") + std::to_string(counter); + if (HasValueName(v)) { + shadow_output_name = GetValueName(v); + } auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name()); pir::AttributeMap attribute_map = { {"output_name", pir::StrAttribute::get(ctx, shadow_output_name)}, diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index c93588f73d6f3b..ba096252689e05 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -31,10 +31,10 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/pybind/complex.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/strided_memcpy.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -721,7 +721,7 @@ void _concatCompute(const std::vector &ins, output_offset += in_stride[axis]; } } else { - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(ctx, ins, static_cast(axis), out); } } diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index e94aff346e0a86..5e195cd0f4f7ab 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -519,6 +519,18 @@ func : cross_grad data_type : out_grad +- backward_op : cudnn_lstm_grad + forward: cudnn_lstm (Tensor x, Tensor init_h, Tensor init_c, Tensor w, Tensor[] weight_list, Tensor sequence_length, float dropout_prob = 0.0, bool is_bidirec = false, int hidden_size = 100, int num_layers = 1, bool is_test = false, int seed = 0) -> Tensor (out), Tensor (last_h), Tensor (last_c), Tensor (reserve), Tensor (state_out) + args: (Tensor x, Tensor init_h, Tensor init_c, Tensor[] weight_list, Tensor sequence_length, Tensor out, Tensor reserve, Tensor state_out, Tensor out_grad, Tensor last_h_grad, Tensor last_c_grad, float dropout_prob = 0.0, bool is_bidirec = false, int hidden_size = 100, int num_layers = 1, bool is_test = false, int seed = 0) + output: Tensor (x_grad), Tensor (init_h_grad), Tensor (init_c_grad), Tensor[](weight_list_grad){weight_list.size()} + infer_meta: + func: CudnnLSTMGradInferMeta + param : [x, init_h, init_c, weight_list] + kernel: + func: cudnn_lstm_grad + data_type : out_grad + optional: weight_list, sequence_length, weight_list_grad + - backward_op : cummax_grad forward : cummax(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices) args : (Tensor x, Tensor 
indices, Tensor out_grad, int axis, DataType dtype) @@ -870,6 +882,18 @@ func : flash_attn_grad data_type: q +- backward_op : flash_attn_qkvpacked_grad + forward : flash_attn_qkvpacked (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor qkv, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false) + optional : attn_mask + output : Tensor(qkv_grad) + infer_meta : + func : FlashAttnQKVPackedGradInferMeta + param : [qkv] + kernel : + func : flash_attn_qkvpacked_grad + data_type: qkv + - backward_op : flash_attn_unpadded_grad forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false) @@ -882,6 +906,18 @@ func : flash_attn_unpadded_grad data_type: q +- backward_op : flash_attn_varlen_qkvpacked_grad + forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true) + optional : attn_mask + output : Tensor(qkv_grad) + infer_meta : + func : FlashAttnQKVPackedGradInferMeta + param : [qkv] + kernel : + func : flash_attn_varlen_qkvpacked_grad + data_type: qkv + - backward_op : flash_attn_with_sparse_mask_grad forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0) @@ -1053,6 +1089,18 @@ func : gelu_grad composite: gelu_grad(x, out_grad, approximate, x_grad) +- backward_op : global_gather_grad + forward : global_gather(Tensor x, Tensor local_count, Tensor global_count, int ring_id = 0, bool use_calc_stream=false) -> Tensor(out) + args : (Tensor out_grad, Tensor local_count, Tensor global_count, int ring_id = 0, bool use_calc_stream=false) + output : Tensor(x_grad) + invoke : global_scatter(out_grad, local_count, global_count, ring_id, 
use_calc_stream) + +- backward_op : global_scatter_grad + forward : global_scatter(Tensor x, Tensor local_count, Tensor global_count, int ring_id = 0, bool use_calc_stream=false) -> Tensor(out) + args : (Tensor out_grad, Tensor local_count, Tensor global_count, int ring_id = 0, bool use_calc_stream=false) + output : Tensor(x_grad) + invoke : global_gather(out_grad, local_count, global_count, ring_id, use_calc_stream) + - backward_op : grid_sample_grad forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out) args : (Tensor x, Tensor grid, Tensor out_grad, str mode, str padding_mode, bool align_corners) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index bab135c5f6b444..301835ed5bc7a6 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1671,6 +1671,13 @@ attrs : {pre_nms_top_n : pre_nms_topN, post_nms_top_n : post_nms_topN} +- op : global_gather + backward : global_gather_grad + inputs : + x : X + outputs: + out : Out + - op : global_scatter inputs : {x : X} @@ -2691,6 +2698,16 @@ outputs : out : Out +- op : pull_box_sparse + inputs : + { w : W, ids: Ids} + outputs : + out : Out + attrs : + sparse : is_sparse + extra : + attrs : [bool is_sparse = false, bool is_distributed = false, int size = 1] + - op : pull_gpups_sparse backward : push_gpups_sparse inputs : @@ -2706,6 +2723,14 @@ extra : attrs : [int embedding_dim = 11, int table_id = 0, str accessor_class = "", str ctr_label_name = "", int padding_id = 0, bool scale_sparse_grad = true, 'str[] input_names = {}', bool is_distributed = true] +- op : push_box_sparse + inputs : + ids: Ids + outputs : + out : Out + attrs : + sparse : is_sparse + - op : push_dense inputs : ids : Ids @@ -3232,6 +3257,23 @@ attrs: data_format: data_layout +- op : sparse_slice + int_array : + starts : + data_type : int + tensor_name : StartsTensor + tensors_name : StartsTensorList + ends : + data_type : int + tensor_name : EndsTensor + tensors_name : EndsTensorList + +- op : sparse_sum + scalar : + axis : + data_type : int + tensor_name : AxisTensor + - op : sparse_sync_batch_norm attrs: data_format: data_layout @@ -3666,6 +3708,12 @@ outputs : {boxes : Boxes, scores : Scores} +- op : yolo_box_head + inputs : + {x : X} + outputs : + {out : Out} + - op : yolo_loss (yolov3_loss) backward: yolo_loss_grad (yolov3_loss_grad) inputs : @@ -3784,6 +3832,14 @@ outputs: {out: Out} +- op: cudnn_lstm + backward: cudnn_lstm_grad + inputs: + {x: Input, init_h: InitH, init_c: InitC, w: W, weight_list: WeightList, sequence_length: SequenceLength} + outputs: + {reserve: Reserve, state_out: StateOut, out: Out, last_h: LastH, last_c: LastC} + drop_empty_grad : [weight_list_grad] + - op: decayed_adagrad inputs: {param : Param, grad : Grad, moment : Moment, learning_rate : LearningRate} @@ -3950,6 +4006,17 @@ outputs : out : Out +- op: moe + inputs: + x: X + gate: Gate + bmm0: Bmm0 + bias0: Bias0 + bmm1: Bmm1 + bias1: Bias1 + outputs: + out: Out + - op: nce backward: nce_grad inputs: diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 98da34dd2d442e..a59f50b7a8ac64 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -708,6 +708,18 @@ data_type : input backward : cross_entropy_with_softmax_grad +- op : cudnn_lstm + args: (Tensor x, Tensor init_h, Tensor init_c, Tensor w, Tensor[] weight_list, Tensor sequence_length, float dropout_prob = 0.0, bool is_bidirec = false, int hidden_size = 100, int num_layers 
= 1, bool is_test = false, int seed = 0) + output: Tensor (out), Tensor (last_h), Tensor (last_c), Tensor (reserve), Tensor (state_out) + infer_meta: + func: CudnnLSTMInferMeta + kernel: + func: cudnn_lstm + data_type: x + optional: w, weight_list, sequence_length + intermediate: reserve + backward: cudnn_lstm_grad + - op : cummax args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64) output : Tensor(out), Tensor(indices) @@ -1097,6 +1109,18 @@ backward : flash_attn_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : flash_attn_qkvpacked + args : (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") + output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + optional : fixed_seed_offset, attn_mask + infer_meta : + func : FlashAttnQKVPackedInferMeta + param : [qkv] + kernel : + func : flash_attn_qkvpacked + data_type : qkv + backward : flash_attn_qkvpacked_grad + - op : flash_attn_unpadded args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) @@ -1110,6 +1134,19 @@ intermediate : softmax_lse, seed_offset backward : flash_attn_unpadded_grad +- op : flash_attn_varlen_qkvpacked + args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) + output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + optional : fixed_seed_offset , attn_mask + infer_meta : + func : FlashAttnQKVPackedInferMeta + param : [qkv] + kernel : + func : flash_attn_varlen_qkvpacked + data_type : qkv + intermediate : softmax_lse, seed_offset + backward : flash_attn_varlen_qkvpacked_grad + - op : flash_attn_with_sparse_mask args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) @@ -3225,6 +3262,15 @@ func : yolo_box data_type : x +- op : yolo_box_head + args : (Tensor x, int[] anchors, int class_num) + output : Tensor(out) + infer_meta : + func : YoloBoxHeadInferMeta + kernel : + func : yolo_box_head + data_type : x + - op : yolo_loss args : (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, int[] anchors={}, int[] anchor_mask={}, int class_num =1 , float ignore_thresh=0.7, int downsample_ratio=32, bool use_label_smooth=true, float scale_x_y=1.0) output : Tensor(loss), Tensor(objectness_mask), Tensor(gt_match_mask) @@ -3236,3 +3282,12 @@ optional : gt_score intermediate : objectness_mask, gt_match_mask backward : yolo_loss_grad + +- op: moe + args: (Tensor x, Tensor gate, Tensor bmm0, Tensor bias0, Tensor bmm1, Tensor bias1, + str act_type = "gelu") + output: Tensor (out) + infer_meta: + func: MoeInferMeta + kernel: + func: moe diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 
261b99512a0ffe..c7574910504cd7 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -244,6 +244,12 @@ void FlashAttnGradInferMeta(const MetaTensor& q, } } +void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv) { + if (dqkv) { + dqkv->share_meta(qkv); + } +} + void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset, const MetaTensor& out_grad, MetaTensor* x_grad, @@ -320,6 +326,29 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, logits_grad->set_dtype(softmax.dtype()); } +void CudnnLSTMGradInferMeta( + const MetaTensor& x, + const MetaTensor& init_h, + const MetaTensor& init_c, + const paddle::optional>& weight_list, + MetaTensor* x_grad, + MetaTensor* init_h_grad, + MetaTensor* init_c_grad, + std::vector weight_list_grad) { + if (x_grad) { + x_grad->share_meta(x); + } + if (init_h_grad) { + init_h_grad->share_meta(init_h); + } + if (init_c_grad) { + init_c_grad->share_meta(init_c); + } + if (!weight_list_grad.empty()) { + UnchangedMultiInferMeta(weight_list.get(), weight_list_grad); + } +} + void DeformableConvGradInferMeta(const MetaTensor& x, const MetaTensor& offset, const MetaTensor& filter, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 88aea8f18181b6..63912e98d50f3b 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -141,6 +141,16 @@ void CSoftmaxWithCrossEntropyGradInferMeta(const MetaTensor& softmax, MetaTensor* logits_grad, MetaConfig config = MetaConfig()); +void CudnnLSTMGradInferMeta( + const MetaTensor& x, + const MetaTensor& init_h, + const MetaTensor& init_c, + const paddle::optional>& weight_list, + MetaTensor* x_grad, + MetaTensor* init_h_grad, + MetaTensor* init_c_grad, + std::vector weight_list_grad); + void DeformableConvGradInferMeta(const MetaTensor& x, const MetaTensor& offset, const MetaTensor& filter, @@ -197,6 +207,8 @@ void FlashAttnGradInferMeta(const MetaTensor& q, MetaTensor* dk, MetaTensor* dv); +void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dq); + void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset, const MetaTensor& out_grad, MetaTensor* x_grad, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 5212a6fe872224..0bcd80f79c989e 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2998,6 +2998,33 @@ void PruneGateByCapacityInferMeta(const MetaTensor& gate_idx, new_gate_idx->set_dtype(gate_idx.dtype()); } +void PullBoxSparseInferMeta(const MetaTensor& w, + const std::vector& ids, + bool is_sparse, + bool is_distributed, + int size, + std::vector out) { + auto hidden_size = static_cast(size); + const size_t n_ids = ids.size(); + for (size_t i = 0; i < n_ids; ++i) { + MetaTensor* output = out[i]; + auto ids_dims = ids[i]->dims(); + int ids_rank = ids_dims.size(); + PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], + 1UL, + phi::errors::InvalidArgument( + "Shape error in %lu id, the last dimension of the " + "'Ids' tensor must be 1.", + i)); + auto out_dim = + common::vectorize(common::slice_ddim(ids_dims, 0, ids_rank - 1)); + out_dim.push_back(hidden_size); + output->set_dims(common::make_ddim(out_dim)); + output->share_lod(*ids[i]); + output->set_dtype(w.dtype()); + } +} + void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x, const MetaTensor& repeats, int dim, @@ -3532,6 +3559,15 @@ void YoloBoxInferMeta(const MetaTensor& x, scores->set_dims(common::make_ddim(dim_scores)); } +void 
YoloBoxHeadInferMeta(const MetaTensor& x, + const std::vector& anchors UNUSED, + int class_num UNUSED, + MetaTensor* out, + MetaConfig config) { + out->set_dims(x.dims()); + out->set_dtype(x.dtype()); +} + void ValueCompareInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index bd8517f73898e0..e12b450407d162 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -479,6 +479,13 @@ void PReluInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void PullBoxSparseInferMeta(const MetaTensor& w, + const std::vector& ids, + bool is_sparse, + bool is_distributed, + int size, + std::vector out); + void PullGpupsSparseInferMeta(const MetaTensor& w, const std::vector& ids, const std::vector& size, @@ -614,6 +621,12 @@ void YoloBoxInferMeta(const MetaTensor& x, MetaTensor* scores, MetaConfig config = MetaConfig()); +void YoloBoxHeadInferMeta(const MetaTensor& x, + const std::vector& anchors, + int class_num, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void ValueCompareInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out, diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 590473bd2094ed..1853c6c395f1ed 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -1307,6 +1307,7 @@ void FusedFeedForwardInferMeta(const MetaTensor& x, ln2_variance->set_dims(mean_dim); } out->share_lod(x); + out->set_dtype(x.dtype()); } static bool IsUnaryCompound(const std::vector& functor_list) { diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index fc43c09105bd60..df7cf7f754c6ea 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -17,7 +17,10 @@ limitations under the License. 
*/ #include "glog/logging.h" #include "paddle/common/ddim.h" +#include "paddle/common/errors.h" #include "paddle/common/layout.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/impl/box_coder.h" @@ -433,6 +436,31 @@ void FlashAttnInferMeta(const MetaTensor& q, seed_offset->set_dims({2}); } } +void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv, + MetaTensor* out, + MetaTensor* softmax, + MetaTensor* softmax_lse, + MetaTensor* seed_offset) { + const auto& qkvdims = qkv.dims(); + PADDLE_ENFORCE(qkvdims.size() == 4 || qkvdims.size() == 5, + phi::errors::InvalidArgument( + "qkv dims must be 4(unpadded) or 5(padded batch)")); + // qkv [total_*,nheads/nheads_k+2,nheads_k,headdim] + auto out_dims = DDim({qkvdims[0], (qkvdims[1] - 2) * qkvdims[2], qkvdims[3]}); + if (qkvdims.size() == 5) { + // qkv [batchsize,seqlen,nheads/nheads_k+2,nheads_k,headdim] + out_dims = + DDim{qkvdims[0], qkvdims[1], (qkvdims[2] - 2) * qkvdims[3], qkvdims[4]}; + } + out->set_dims(out_dims); + out->set_dtype(qkv.dtype()); + out->set_layout(qkv.layout()); + softmax->set_dtype(qkv.dtype()); + softmax_lse->set_dtype(qkv.dtype()); + if (seed_offset) { + seed_offset->set_dtype(phi::DataType::INT64); + } +} void ArangeTensorInferMeta(const MetaTensor& start, const MetaTensor& end, @@ -564,6 +592,32 @@ void InstanceNormInferMeta(const MetaTensor& x, } } +void GlobalGatherInferMeta(const MetaTensor& x, + const MetaTensor& local_count, + const MetaTensor& global_count, + int ring_id, + bool use_calc_stream, + MetaTensor* out) { + PADDLE_ENFORCE_GE( + ring_id, + 0, + phi::errors::InvalidArgument( + "The ring_id (%d) for global gather op must be non-negative.", + ring_id)); + auto input_dims = x.dims(); + auto ndim_input = input_dims.size(); + // dim check + PADDLE_ENFORCE_EQ( + ndim_input, + 2, + phi::errors::InvalidArgument("The input tensor's dimension must be 2. 
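Aside for reviewers: a minimal standalone sketch of the output-shape rule that FlashAttnQKVPackedInferMeta above encodes. The helper name and the plain std::vector (instead of DDim) are hypothetical; the packed qkv carries nheads/nheads_k + 2 head groups, so the query output recovers nheads = (groups - 2) * nheads_k.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone version of the FlashAttnQKVPackedInferMeta shape rule.
// unpadded qkv: [total, nheads/nheads_k + 2, nheads_k, headdim]
// padded qkv:   [batch, seqlen, nheads/nheads_k + 2, nheads_k, headdim]
std::vector<int64_t> QKVPackedOutDims(const std::vector<int64_t>& qkv) {
  assert(qkv.size() == 4 || qkv.size() == 5);
  if (qkv.size() == 4) {
    // (nheads/nheads_k) * nheads_k == nheads query heads in the output.
    return {qkv[0], (qkv[1] - 2) * qkv[2], qkv[3]};
  }
  return {qkv[0], qkv[1], (qkv[2] - 2) * qkv[3], qkv[4]};
}

int main() {
  // 1024 tokens, 8 query heads grouped over 2 kv heads, headdim 128:
  // qkv = [1024, 8/2 + 2, 2, 128] -> out = [1024, 8, 128]
  assert(QKVPackedOutDims({1024, 6, 2, 128})[1] == 8);
  return 0;
}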
" + "But received input's dimension = %d.", + ndim_input)); + phi::DDim out_dims = common::make_ddim({-1, -1}); + out->set_dims(out_dims); + out->set_dtype(x.dtype()); +} + void GlobalScatterInferMeta(const MetaTensor& x, const MetaTensor& local_count, const MetaTensor& global_count, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index a4e429fdd277d6..cd21060ad0a582 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -115,6 +115,12 @@ void FlashAttnInferMeta(const MetaTensor& q, MetaTensor* softmax_lse, MetaTensor* seed_offset); +void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv, + MetaTensor* out, + MetaTensor* softmax, + MetaTensor* softmax_lse, + MetaTensor* seed_offset); + void InstanceNormInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, @@ -124,6 +130,13 @@ void InstanceNormInferMeta(const MetaTensor& x, MetaTensor* saved_variance, MetaConfig config = MetaConfig()); +void GlobalGatherInferMeta(const MetaTensor& x, + const MetaTensor& local_count, + const MetaTensor& global_count, + int ring_id, + bool use_calc_stream, + MetaTensor* out); + void GlobalScatterInferMeta(const MetaTensor& x, const MetaTensor& local_count, const MetaTensor& global_count, diff --git a/paddle/phi/kernels/funcs/weight_dequant_functor.h b/paddle/phi/kernels/funcs/weight_dequant_functor.h index 4eed94de7bf4dc..8d7cc93d7b1181 100644 --- a/paddle/phi/kernels/funcs/weight_dequant_functor.h +++ b/paddle/phi/kernels/funcs/weight_dequant_functor.h @@ -187,10 +187,10 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight, int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32; int tile_id = blockIdx.x * blockDim.x / 32 + warp_id; - // Every two rows of the original weights are interleaved into a row with - // stride of 64, so if each thread processes 16 elements(for int8, we can use - // ldg.128 to load weights), then every group of four adjacent threads will - // alternately process two different row weights for example every 128 + // Every 4 rows of the original weights are interleaved into a row with + // stride of 32, so if each thread processes 16 elements(for int8, we can use + // ldg.128 to load weights), then every group of two adjacent threads will + // alternately process four different row weights for example every 128 // consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave // layout, the first 64 are from [64*i, 64*(i+1)-1] of row 2N before // interleaving, and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1 @@ -366,7 +366,6 @@ void WeightDequantize(const Context& dev_ctx, dim3 block(512); dim3 grid(n / 32); auto stream = dev_ctx.stream(); - if (algo == "weight_only_int8" && group_size == -1) { int8_weight_only_dequant<<>>( reinterpret_cast(x.data()), @@ -383,6 +382,7 @@ void WeightDequantize(const Context& dev_ctx, k, group_size); } else if (algo == "weight_only_int4" && group_size == -1) { + k *= 2; grid.x /= 2; int4_weight_only_dequant<<>>( reinterpret_cast(x.data()), @@ -391,6 +391,7 @@ void WeightDequantize(const Context& dev_ctx, n, k); } else if (algo == "weight_only_int4" && group_size > 0) { + k *= 2; grid.x /= 2; int4_weight_only_dequant<<>>( reinterpret_cast(x.data()), diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 4f93288edaf14c..1e919c122bf033 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -13,12 +13,16 @@ // 
limitations under the License. #include "paddle/phi/kernels/flash_attn_grad_kernel.h" +#include #include "glog/logging.h" // For VLOG() +#include "paddle/common/enforce.h" #include "paddle/common/flags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" @@ -31,26 +35,205 @@ int get_num_split() { return FLAGS_cudnn_deterministic ? 1 : 0; } +template +static __global__ void SumStridedKV(const T* src, + T* dst, + const uint64_t sRowDim1, + const uint64_t sRowDim2, + const uint64_t sRowDim3, + const uint64_t sColDim, + const uint64_t sRowStride1, + const uint64_t sRowStride2, + const uint64_t sColStride, + const uint64_t dRowStride1, + const uint64_t dRowStride2) { + // SrcShape [seqlen, num_heads_k, num_heads/num_heads_k, headdim] + // AxisName [row1 , row2 , col , row3 ] + // LoopMap [blockx, thready , serialreduce , threadx] + // Ensure blockDim.x == 32 && blockDim.z == 1 + // Ensure sRowStride3 == dRowStride3 == 1 (headdim dim is contiguous) + using IndexType = uint64_t; + constexpr IndexType BlockDimX = 32; + const IndexType SRow1Begin = blockIdx.x * sRowStride1; + const IndexType SRow1End = sRowDim1 * sRowStride1; + const IndexType SRow1Stride = gridDim.x * sRowStride1; + + const IndexType SRow2Begin = threadIdx.y * sRowStride2; + const IndexType SRow2End = sRowDim2 * sRowStride2; + const IndexType SRow2Stride = blockDim.y * sRowStride2; + + // const IndexType SRow3Begin = threadIdx.x * sRowStride3; + // const IndexType SRow3End = sRowDim3 * sRowStride3; + // const IndexType SRow3Stride = BlockDimX * sRowStride3; + + constexpr IndexType SColBegin = 0; + const IndexType SColEnd = sColDim * sColStride; + const IndexType SColStride = sColStride; + + const IndexType DRow1Begin = blockIdx.x * dRowStride1; + const IndexType DRow1Stride = gridDim.x * dRowStride1; + + const IndexType DRow2Begin = threadIdx.y * dRowStride2; + const IndexType DRow2Stride = dRowStride2; + + // const IndexType DRow3Begin = threadIdx.x * dRowStride3; + // const IndexType DRow3Stride = blockDim.x * dRowStride3; + + for (auto row1 = SRow1Begin, drow1 = DRow1Begin; row1 < SRow1End; + row1 += SRow1Stride, drow1 += DRow1Stride) { + for (auto row2 = SRow2Begin, drow2 = DRow2Begin; row2 < SRow2End; + row2 += SRow2Stride, drow2 += DRow2Stride) { + const auto i1 = row1 + row2 + threadIdx.x; + const auto di1 = drow1 + drow2 + threadIdx.x; + T v[HeaddimDiv32]; +#pragma unroll + for (auto i = IndexType(0); i < HeaddimDiv32; i++) { + v[i] = T{0}; + } + for (auto col = SColBegin; col < SColEnd; col += SColStride) { + const auto i2 = i1 + col; +#pragma unroll + for (auto i = IndexType(0); i < HeaddimDiv32; i++) { + v[i] += src[i2 + i * BlockDimX]; + } + } +#pragma unroll + for (auto i = IndexType(0); i < HeaddimDiv32; i++) { + dst[di1 + i * BlockDimX] = v[i]; + } + } + } +} + +template +static auto selectSumkernel(int64_t headdim) { + PADDLE_ENFORCE_LE(headdim, + 256, + phi::errors::InvalidArgument( + "FlashAttention only support headdim <= 256")); + PADDLE_ENFORCE_EQ(headdim % 32, + 0, + phi::errors::InvalidArgument( + "FlashAttention only support headdim %% 32 == 0")); + PADDLE_ENFORCE_NE( + headdim, 0, phi::errors::InvalidArgument("Headdim can't be zero")); +#define CASEN(n) \ + case n: \ + return 
SumStridedKV; + switch (headdim / 32) { + CASEN(1); + CASEN(2); + CASEN(3); + CASEN(4); + CASEN(5); + CASEN(6); + CASEN(7); + CASEN(8); + } + PADDLE_FATAL("Unreachable in selectSumKernel"); +#undef CASEN +} + template -void FlashAttnUnpaddedGradKernel(const Context& ctx, - const DenseTensor& q, - const DenseTensor& k, - const DenseTensor& v, - const DenseTensor& cu_seqlens_q, - const DenseTensor& cu_seqlens_k, - const DenseTensor& out, - const DenseTensor& softmax_lse, - const DenseTensor& seed_offset, - const paddle::optional& attn_mask, - const DenseTensor& dout, - int64_t max_seqlen_q, - int64_t max_seqlen_k, - float scale, - float dropout, - bool causal, - DenseTensor* dq, - DenseTensor* dk, - DenseTensor* dv) { +static void kvReduceForGQA(const Context& ctx, + const DenseTensor& dk_tmp, + DenseTensor* dk) { + PADDLE_ENFORCE_EQ( + dk->strides()[2], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[3], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); + const int64_t reduceDimSize = dk_tmp.dims()[2]; + const size_t blockNum = + std::min((static_cast(dk_tmp.dims()[0] + 31) / 32), + static_cast(1024l)); + const dim3 threadNum{32, 4, 1}; + auto sumkernel = selectSumkernel(dk_tmp.dims()[3]); + sumkernel<<>>( + reinterpret_cast(dk_tmp.data()), + reinterpret_cast(dk->data()), + dk_tmp.dims()[0], + dk_tmp.dims()[1], + dk_tmp.dims()[3], + dk_tmp.dims()[2], + dk_tmp.strides()[0], + dk_tmp.strides()[1], + // dk_tmp.strides()[3], + dk_tmp.strides()[2], + dk->strides()[0], + dk->strides()[1] + // dk->strides()[2] + ); +} +template +static void kvReduceBatchedForGQA(const Context& ctx, + const DenseTensor& dk_tmp, + DenseTensor* dk) { + PADDLE_ENFORCE_EQ( + dk->strides()[3], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[4], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk->strides()[0], + dk->strides()[1] * dk->dims()[1], + phi::errors::InvalidArgument("batchsize dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[0], + dk_tmp.strides()[1] * dk_tmp.dims()[1], + phi::errors::InvalidArgument("batchsize dimention must be contiguous")); + const int64_t reduceDimSize = dk_tmp.dims()[3]; + const size_t blockNum = std::min( + (static_cast(dk_tmp.dims()[0] * dk_tmp.dims()[1] + 31) / 32), + static_cast(1024l)); + const dim3 threadNum{32, 4, 1}; + auto sumkernel = selectSumkernel(dk_tmp.dims()[4]); + // here implicitly flat [batch,seqlen], and require batch dim to be contiguous + sumkernel<<>>( + reinterpret_cast(dk_tmp.data()), + reinterpret_cast(dk->data()), + dk_tmp.dims()[0] * dk_tmp.dims()[1], + dk_tmp.dims()[2], + dk_tmp.dims()[4], + dk_tmp.dims()[3], + dk_tmp.strides()[1], + dk_tmp.strides()[2], + // dk_tmp.strides()[4], + dk_tmp.strides()[3], + dk->strides()[1], + dk->strides()[2] + // dk->strides()[3] + ); +} + +template +void FlashAttnUnpaddedGradBaseKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv, + bool varlen_padded) { #ifdef 
PADDLE_WITH_FLASHATTN // q,k,v [total_*, num_heads, head_dim] auto dims = q.dims(); @@ -64,37 +247,30 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, bool is_mha = (num_heads == num_heads_k); - void* dq_ptr = nullptr; - void* dk_ptr = nullptr; - void* dv_ptr = nullptr; - + DenseTensor* kdq = dq; DenseTensor dq_tmp; - if (dq) { - dq_ptr = ctx.template Alloc(dq); - } else { + if (!dq) { dq_tmp.Resize(dims); - dq_ptr = ctx.template Alloc(&dq_tmp); + ctx.template Alloc(&dq_tmp); + kdq = &dq_tmp; } std::initializer_list dk_dv_shape = { total_k, num_heads_k, num_heads / num_heads_k, head_size}; + DenseTensor *kdk = dk, *kdv = dv; DenseTensor dk_tmp; - if (dk && is_mha) { - ctx.template Alloc(dk); - dk_ptr = dk->data(); - } else { + if (!dk || !is_mha) { dk_tmp.Resize(dk_dv_shape); - dk_ptr = ctx.template Alloc(&dk_tmp); + ctx.template Alloc(&dk_tmp); + kdk = &dk_tmp; } DenseTensor dv_tmp; - if (dv && is_mha) { - ctx.template Alloc(dv); - dv_ptr = dv->data(); - } else { + if (!dv || !is_mha) { dv_tmp.Resize(dk_dv_shape); - dv_ptr = ctx.template Alloc(&dv_tmp); + ctx.template Alloc(&dv_tmp); + kdv = &dv_tmp; } const cudaStream_t stream = ctx.stream(); @@ -139,9 +315,9 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, cu_seqlens_q.data(), cu_seqlens_k.data(), params.rng_state.data(), - dq_ptr, - dk_ptr, - dv_ptr, + kdq->data(), + kdk->data(), + kdv->data(), params.dq_accum.data(), params.batch_size, params.max_seqlen_q, @@ -162,20 +338,209 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, params.seed, params.offset, params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, - params.attn_mask_tensor ? params.mask_dims.data() : nullptr); + params.attn_mask_tensor ? params.mask_dims.data() : nullptr, + q.strides()[0], + k.strides()[0], + v.strides()[0], + q.strides()[1], + k.strides()[1], + v.strides()[1], + out.strides()[0], + out.strides()[1], + max_seqlen_q * q.strides()[0], + max_seqlen_k * k.strides()[0], + max_seqlen_k * v.strides()[0], + max_seqlen_q * out.strides()[0], + kdq->strides()[0], + kdk->strides()[0], + kdv->strides()[0], + kdq->strides()[1], + kdk->strides()[kdk->strides().size() - 2], + kdv->strides()[kdv->strides().size() - 2], + dout.strides()[0], + dout.strides()[1], + max_seqlen_q * kdq->strides()[0], + max_seqlen_k * kdk->strides()[0], + max_seqlen_k * kdv->strides()[0], + max_seqlen_q * dout.strides()[0], + varlen_padded); CheckFlashAttnStatus(succ); if (!is_mha) { if (dk) { - phi::SumKernel(ctx, dk_tmp, {2}, dk->type(), false, dk); + if (dk->meta().is_contiguous()) + phi::SumKernel(ctx, dk_tmp, {2}, dk->type(), false, dk); + else + kvReduceForGQA(ctx, dk_tmp, dk); } if (dv) { - phi::SumKernel(ctx, dv_tmp, {2}, dv->type(), false, dv); + if (dv->meta().is_contiguous()) + phi::SumKernel(ctx, dv_tmp, {2}, dv->type(), false, dv); + else + kvReduceForGQA(ctx, dv_tmp, dv); } } #else RaiseNotSupportedError(); #endif } + +template +void FlashAttnUnpaddedGradKernel(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv) { +#ifdef PADDLE_WITH_FLASHATTN + if (dq) { + ctx.template Alloc(dq); + } + if (dk) { + ctx.template Alloc(dk); + } + if (dv) { + 
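Aside for reviewers: the kvReduceForGQA / SumStridedKV path above sums the dk/dv gradients over the nheads/nheads_k group axis when the output view is strided and phi::SumKernel cannot be used directly. A CPU reference of that reduction, assuming contiguous layouts for brevity (hypothetical helper, not part of the patch; the real kernel works on strided views):

#include <cassert>
#include <cstdint>
#include <vector>

// dk_tmp of shape [seqlen, num_heads_k, num_heads/num_heads_k, headdim] is
// summed over the group axis into dk of shape [seqlen, num_heads_k, headdim].
void SumOverGroupAxisCPU(const std::vector<float>& src, std::vector<float>* dst,
                         int64_t seqlen, int64_t heads_k, int64_t group,
                         int64_t headdim) {
  dst->assign(seqlen * heads_k * headdim, 0.f);
  for (int64_t s = 0; s < seqlen; ++s)
    for (int64_t h = 0; h < heads_k; ++h)
      for (int64_t g = 0; g < group; ++g)
        for (int64_t d = 0; d < headdim; ++d)
          (*dst)[(s * heads_k + h) * headdim + d] +=
              src[((s * heads_k + h) * group + g) * headdim + d];
}

int main() {
  std::vector<float> src = {1.f, 2.f};  // [seqlen=1, heads_k=1, group=2, headdim=1]
  std::vector<float> dst;
  SumOverGroupAxisCPU(src, &dst, 1, 1, 2, 1);
  assert(dst.size() == 1 && dst[0] == 3.f);
  return 0;
}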
ctx.template Alloc(dv); + } + FlashAttnUnpaddedGradBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + out, + softmax_lse, + seed_offset, + attn_mask, + dout, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + dq, + dk, + dv, + false /*varlen_padded*/); +#else + RaiseNotSupportedError(); +#endif +} + +static void sliceFlattenView(const DenseTensor& in, + DenseTensor* out, + int axis, + int64_t offset, + int64_t sliceLength) { + PADDLE_ENFORCE_LT( + axis, + in.dims().size(), + phi::errors::InvalidArgument("sliceView receive axis out of bound")); + std::array dimArr; + std::array strideArr; + auto id = dimArr.begin(), is = strideArr.begin(); + for (int i = 0; i < in.dims().size(); i++) { + if (i == axis) continue; + if (i == axis + 1) + *id = in.dims()[i] * sliceLength; + else + *id = in.dims()[i]; + *is = in.strides()[i]; + id++; + is++; + } + *out = DenseTensor{ + in.Holder(), + DenseTensorMeta{in.dtype(), + DDim{dimArr.data(), in.dims().size() - 1}, + DDim(strideArr.data(), in.dims().size() - 1)}}; + out->set_offset(in.offset() + + offset * in.strides()[axis] * SizeOf(out->dtype())); +} +template +struct ZeroFunctor { + __device__ __forceinline__ OutT operator()() const { + return static_cast(0); + } +}; +template +void FlashAttnVarlenQKVPackedGradKernel( + const Context& ctx, + const DenseTensor& qkv, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool varlen_padded, + DenseTensor* dqkv) { +#ifdef PADDLE_WITH_FLASHATTN + // q,k,v [total_*, num_heads, head_dim] + const auto head_groupnum = qkv.dims()[1]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, &q, 1, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 1, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 1, head_groupnum - 1, 1); + // DenseTensor dqkv_tmp; + if (!dqkv) { + return; + // dqkv is the only output. 
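Aside for reviewers: sliceFlattenView above builds zero-copy q/k/v views out of the packed qkv by dropping the sliced axis and folding it into the following one. A simplified standalone sketch of that dim/stride bookkeeping (hypothetical names, plain vectors instead of DenseTensor/DDim, offsets counted in elements rather than bytes):

#include <cassert>
#include <cstdint>
#include <vector>

struct View {
  std::vector<int64_t> dims, strides;
  int64_t offset = 0;
};

// Take sliceLength entries of `axis` starting at `offset` and fold that axis
// into the next one, so the view has one dimension fewer than the input.
View SliceFlatten(const std::vector<int64_t>& dims,
                  const std::vector<int64_t>& strides, int axis,
                  int64_t offset, int64_t sliceLength) {
  View v;
  v.offset = offset * strides[axis];
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    if (i == axis) continue;
    v.dims.push_back(i == axis + 1 ? dims[i] * sliceLength : dims[i]);
    v.strides.push_back(strides[i]);
  }
  return v;
}

int main() {
  // Unpadded qkv [total, g + 2, nheads_k, headdim] with g = nheads/nheads_k = 4.
  std::vector<int64_t> dims{1024, 6, 2, 128};
  std::vector<int64_t> strides{6 * 2 * 128, 2 * 128, 128, 1};
  View q = SliceFlatten(dims, strides, 1, 0, 4);  // -> [1024, 8, 128]
  View k = SliceFlatten(dims, strides, 1, 4, 1);  // -> [1024, 2, 128]
  assert(q.dims[1] == 8 && k.dims[1] == 2 && k.offset == 4 * 2 * 128);
  return 0;
}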
No need to compute if no dqkv + // dqkv_tmp.Resize(qkv.dims()); + // dqkv = &dqkv_tmp; + } + ctx.template Alloc(dqkv); + { + std::vector inputs{}; + std::vector outputs{dqkv}; + phi::funcs::ElementwiseKernel(ctx, inputs, &outputs, ZeroFunctor()); + } + DenseTensor dq, dk, dv; + sliceFlattenView(*dqkv, &dq, 1, 0, head_groupnum - 2); + sliceFlattenView(*dqkv, &dk, 1, head_groupnum - 2, 1); + sliceFlattenView(*dqkv, &dv, 1, head_groupnum - 1, 1); + FlashAttnUnpaddedGradBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + out, + softmax_lse, + seed_offset, + attn_mask, + dout, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + &dq, + &dk, + &dv, + varlen_padded); +#else + RaiseNotSupportedError(); +#endif +} template void FlashAttnGradBaseKernel( const Context& ctx, @@ -208,36 +573,29 @@ void FlashAttnGradBaseKernel( bool is_mha = (num_heads == num_heads_k); - void* dq_ptr = nullptr; - void* dk_ptr = nullptr; - void* dv_ptr = nullptr; - + std::initializer_list dk_dv_shape = { + batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}; + DenseTensor* kdq = dq; DenseTensor dq_tmp; - if (dq) { - dq_ptr = ctx.template Alloc(dq); - } else { + if (!dq) { dq_tmp.Resize(dims); - dq_ptr = ctx.template Alloc(&dq_tmp); + ctx.template Alloc(&dq_tmp); + kdq = &dq_tmp; } + DenseTensor *kdk = dk, *kdv = dv; DenseTensor dk_tmp; - std::initializer_list dk_dv_shape = { - batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}; - if (dk && is_mha) { - ctx.template Alloc(dk); - dk_ptr = dk->data(); - } else { + if (!dk || !is_mha) { dk_tmp.Resize(dk_dv_shape); - dk_ptr = ctx.template Alloc(&dk_tmp); + ctx.template Alloc(&dk_tmp); + kdk = &dk_tmp; } DenseTensor dv_tmp; - if (dv && is_mha) { - ctx.template Alloc(dv); - dv_ptr = dv->data(); - } else { + if (!dv || !is_mha) { dv_tmp.Resize(dk_dv_shape); - dv_ptr = ctx.template Alloc(&dv_tmp); + ctx.template Alloc(&dv_tmp); + kdv = &dv_tmp; } const cudaStream_t stream = ctx.stream(); @@ -291,9 +649,9 @@ void FlashAttnGradBaseKernel( params.softmax_d.data(), softmax_lse.data(), params.rng_state.data(), - dq_ptr, - dk_ptr, - dv_ptr, + kdq->data(), + kdk->data(), + kdv->data(), params.dq_accum.data(), params.batch_size, params.max_seqlen_q, @@ -321,14 +679,45 @@ void FlashAttnGradBaseKernel( params.attn_mask_start_row_indices_tensor ? 
params.attn_mask_start_row_indices_dims.data() : nullptr, - params.attn_mask_start_row); + params.attn_mask_start_row, + q.strides()[1], + k.strides()[1], + v.strides()[1], + q.strides()[2], + k.strides()[2], + v.strides()[2], + out.strides()[1], + out.strides()[2], + q.strides()[0], + k.strides()[0], + v.strides()[0], + out.strides()[0], + kdq->strides()[1], + kdk->strides()[1], + kdv->strides()[1], + kdq->strides()[2], + kdk->strides()[kdk->strides().size() - 2], + kdv->strides()[kdv->strides().size() - 2], + dout.strides()[1], + dout.strides()[2], + kdq->strides()[0], + kdk->strides()[0], + kdv->strides()[0], + dout.strides()[0]); CheckFlashAttnStatus(succ); if (!is_mha) { if (dk) { - phi::SumKernel(ctx, dk_tmp, {3}, dk->type(), false, dk); + if (dk->meta().is_contiguous()) + phi::SumKernel(ctx, dk_tmp, {3}, dk->type(), false, dk); + else + kvReduceBatchedForGQA(ctx, dk_tmp, dk); } + if (dv) { - phi::SumKernel(ctx, dv_tmp, {3}, dv->type(), false, dv); + if (dv->meta().is_contiguous()) + phi::SumKernel(ctx, dv_tmp, {3}, dv->type(), false, dv); + else + kvReduceBatchedForGQA(ctx, dv_tmp, dv); } } #else @@ -351,6 +740,15 @@ void FlashAttnGradKernel(const Context& ctx, DenseTensor* dq, DenseTensor* dk, DenseTensor* dv) { + if (dq) { + ctx.template Alloc(dq); + } + if (dk) { + ctx.template Alloc(dk); + } + if (dv) { + ctx.template Alloc(dv); + } FlashAttnGradBaseKernel(ctx, q, k, @@ -369,6 +767,58 @@ void FlashAttnGradKernel(const Context& ctx, dv); } +template +void FlashAttnQKVPackedGradKernel( + const Context& ctx, + const DenseTensor& qkv, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + float dropout, + bool causal, + DenseTensor* dqkv) { +#ifdef PADDLE_WITH_FLASHATTN + // qkv [batchsize, seqlen, nheads/nheads_k+2, nheads_k, head_dim] + const auto head_groupnum = qkv.dims()[2]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, &q, 2, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 2, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 2, head_groupnum - 1, 1); + // DenseTensor dqkv_tmp; + if (!dqkv) { + return; + // dqkv is the only output. 
No need to compute if no dqkv + // dqkv_tmp.Resize(qkv.dims()); + // dqkv = &dqkv_tmp; + } + ctx.template Alloc(dqkv); + DenseTensor dq, dk, dv; + sliceFlattenView(*dqkv, &dq, 2, 0, head_groupnum - 2); + sliceFlattenView(*dqkv, &dk, 2, head_groupnum - 2, 1); + sliceFlattenView(*dqkv, &dv, 2, head_groupnum - 1, 1); + FlashAttnGradBaseKernel(ctx, + q, + k, + v, + out, + softmax_lse, + seed_offset, + attn_mask, + paddle::none, + dout, + dropout, + causal, + 0, + &dq, + &dk, + &dv); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnWithSparseGradKernel( const Context& ctx, @@ -386,6 +836,15 @@ void FlashAttnWithSparseGradKernel( DenseTensor* dq, DenseTensor* dk, DenseTensor* dv) { + if (dq) { + ctx.template Alloc(dq); + } + if (dk) { + ctx.template Alloc(dk); + } + if (dv) { + ctx.template Alloc(dv); + } FlashAttnGradBaseKernel(ctx, q, k, @@ -414,6 +873,15 @@ PD_REGISTER_KERNEL(flash_attn_unpadded_grad, kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset } +PD_REGISTER_KERNEL(flash_attn_varlen_qkvpacked_grad, + GPU, + ALL_LAYOUT, + phi::FlashAttnVarlenQKVPackedGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset +} + PD_REGISTER_KERNEL(flash_attn_grad, GPU, ALL_LAYOUT, @@ -423,6 +891,15 @@ PD_REGISTER_KERNEL(flash_attn_grad, kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset } +PD_REGISTER_KERNEL(flash_attn_qkvpacked_grad, + GPU, + ALL_LAYOUT, + phi::FlashAttnQKVPackedGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset +} + PD_REGISTER_KERNEL(flash_attn_with_sparse_mask_grad, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 7eb2d342feb792..64eb8450bcac62 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -14,17 +14,26 @@ #include "paddle/phi/kernels/flash_attn_kernel.h" +#include #include "glog/logging.h" // For VLOG() #include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" namespace phi { +template +struct ZeroFunctor { + __device__ __forceinline__ OutT operator()() const { + return static_cast(0); + } +}; template -void FlashAttnUnpaddedKernel( +void FlashAttnUnpaddedBaseKernel( const Context& ctx, const DenseTensor& q, const DenseTensor& k, @@ -44,10 +53,16 @@ void FlashAttnUnpaddedKernel( DenseTensor* out, DenseTensor* softmax, DenseTensor* softmax_lse, - DenseTensor* seed_offset) { + DenseTensor* seed_offset, + bool varlen_padded) { #ifdef PADDLE_WITH_FLASHATTN ctx.template Alloc(out); + if (varlen_padded) { + std::vector inputs{}; + std::vector outputs{out}; + phi::funcs::ElementwiseKernel(ctx, inputs, &outputs, ZeroFunctor()); + } cudaStream_t stream = ctx.stream(); // q, k, v [total_q/k/v, num_heads, head_dim] @@ -120,13 +135,158 @@ void FlashAttnUnpaddedKernel( params.seed, params.offset, params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, - params.attn_mask_tensor ? params.mask_dims.data() : nullptr); + params.attn_mask_tensor ? 
params.mask_dims.data() : nullptr, + q.strides()[0], + k.strides()[0], + v.strides()[0], + q.strides()[1], + k.strides()[1], + v.strides()[1], + out->strides()[0], + out->strides()[1], + max_seqlen_q * q.strides()[0], + max_seqlen_k * k.strides()[0], + max_seqlen_k * v.strides()[0], + max_seqlen_q * out->strides()[0], + varlen_padded); CheckFlashAttnStatus(succ); #else RaiseNotSupportedError(); #endif } +template +void FlashAttnUnpaddedKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + FlashAttnUnpaddedBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + attn_mask, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + is_test, + rng_name, + out, + softmax, + softmax_lse, + seed_offset, + false /*varlen_padded*/); +#else + RaiseNotSupportedError(); +#endif +} + +static void sliceFlattenView(const DenseTensor& in, + DenseTensor* out, + int axis, + int64_t offset, + int64_t sliceLength) { + PADDLE_ENFORCE_LT( + axis, + in.dims().size(), + phi::errors::InvalidArgument("sliceView receive axis out of bound")); + std::array dimArr; + std::array strideArr; + auto id = dimArr.begin(), is = strideArr.begin(); + for (int i = 0; i < in.dims().size(); i++) { + if (i == axis) continue; + if (i == axis + 1) + *id = in.dims()[i] * sliceLength; + else + *id = in.dims()[i]; + *is = in.strides()[i]; + id++; + is++; + } + *out = DenseTensor{ + in.Holder(), + DenseTensorMeta{in.dtype(), + DDim{dimArr.data(), in.dims().size() - 1}, + DDim(strideArr.data(), in.dims().size() - 1)}}; + out->set_offset(in.offset() + + offset * in.strides()[axis] * SizeOf(out->dtype())); +} +template +void FlashAttnVarlenQKVPackedKernel( + const Context& ctx, + const DenseTensor& qkv, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + bool varlen_padded, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + const auto head_groupnum = qkv.dims()[1]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, &q, 1, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 1, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 1, head_groupnum - 1, 1); + FlashAttnUnpaddedBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + attn_mask, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + is_test, + rng_name, + out, + softmax, + softmax_lse, + seed_offset, + varlen_padded); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnBaseKernel( const Context& ctx, @@ -239,7 +399,19 @@ void FlashAttnBaseKernel( params.attn_mask_start_row_indices_tensor ? 
params.attn_mask_start_row_indices_dims.data() : nullptr, - params.attn_mask_start_row); + params.attn_mask_start_row, + q.strides()[1], + k.strides()[1], + v.strides()[1], + q.strides()[2], + k.strides()[2], + v.strides()[2], + out->strides()[1], + out->strides()[2], + q.strides()[0], + k.strides()[0], + v.strides()[0], + out->strides()[0]); CheckFlashAttnStatus(succ); #else RaiseNotSupportedError(); @@ -281,6 +453,49 @@ void FlashAttnKernel(const Context& ctx, seed_offset); } +template +void FlashAttnQKVPackedKernel( + const Context& ctx, + const DenseTensor& qkv, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + const auto head_groupnum = qkv.dims()[2]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, &q, 2, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 2, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 2, head_groupnum - 1, 1); + FlashAttnBaseKernel(ctx, + q, + k, + v, + fixed_seed_offset, + attn_mask, + paddle::none, + dropout, + causal, + return_softmax, + is_test, + rng_name, + 0, + out, + softmax, + softmax_lse, + seed_offset); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnWithSparseMaskKernel( const Context& ctx, @@ -330,6 +545,16 @@ PD_REGISTER_KERNEL(flash_attn_unpadded, phi::Backend::ALL_BACKEND); // fixed_seed_offset } +PD_REGISTER_KERNEL(flash_attn_varlen_qkvpacked, + GPU, + ALL_LAYOUT, + phi::FlashAttnVarlenQKVPackedKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(3).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} + PD_REGISTER_KERNEL(flash_attn, GPU, ALL_LAYOUT, @@ -340,6 +565,16 @@ PD_REGISTER_KERNEL(flash_attn, phi::Backend::ALL_BACKEND); // fixed_seed_offset } +PD_REGISTER_KERNEL(flash_attn_qkvpacked, + GPU, + ALL_LAYOUT, + phi::FlashAttnQKVPackedKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(1).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} + PD_REGISTER_KERNEL(flash_attn_with_sparse_mask, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/group_norm_kernel.cu b/paddle/phi/kernels/gpu/group_norm_kernel.cu index 301701c61d34ea..4835b643efcc76 100644 --- a/paddle/phi/kernels/gpu/group_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/group_norm_kernel.cu @@ -154,8 +154,8 @@ inline __device__ void UpdateSum( #endif template -__global__ void groupNormNHWCSumSingerChannelKernel( - const GroupNormNHWCParams params) { +__global__ void groupNormNDHWCSumSingerChannelKernel( + const GroupNormNDHWCParams params) { // The instance in the batch. __shared__ float2 smem[THREADS_PER_BLOCK]; int32_t ni = blockIdx.z; @@ -164,18 +164,18 @@ __global__ void groupNormNHWCSumSingerChannelKernel( return; } // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t dhwBegin = blockIdx.y * params.dhwPerBlock; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw); // The sums. float sum = 0.F; float sumSq = 0.F; - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The offset. 
- int64_t offset = static_cast(ni) * params.hwc + - static_cast(hwi) * params.c + ci; + int64_t offset = static_cast(ni) * params.dhwc + + static_cast(dhwi) * params.c + ci; float src_data = *reinterpret_cast(¶ms.srcX[offset]); UpdateSum(¶ms.srcX[offset], &sum, &sumSq); } @@ -187,12 +187,12 @@ __global__ void groupNormNHWCSumSingerChannelKernel( float2 sums = smem[threadIdx.x]; atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + ci], - sums.x * params.invHWC); + sums.x * params.invDHWC); atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + ci], sums.y); } template -__global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams params) { +__global__ void groupNormNDHWCSumKernel(const GroupNormNDHWCParams params) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage tempStorage; @@ -210,18 +210,18 @@ __global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams params) { return; } // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t dhwBegin = blockIdx.y * params.dhwPerBlock; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw); // The sums. float sum = 0.F; float sumSq = 0.F; - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The offset. - int64_t offset = static_cast(ni) * params.hwc + - static_cast(hwi) * params.c + ci; + int64_t offset = static_cast(ni) * params.dhwc + + static_cast(dhwi) * params.c + ci; float src_data = *reinterpret_cast(¶ms.srcX[offset]); UpdateSum(¶ms.srcX[offset], &sum, &sumSq); } @@ -249,108 +249,108 @@ __global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams params) { params.cPerBlock - THREADS_PER_CHANNEL) { float2 sums = smem[gi]; atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], - sums.x * params.invHWC); + sums.x * params.invDHWC); atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); } } template -void groupNormNHWCSum::operator()(GroupNormNHWCParams* params, - gpuStream_t stream) { +void groupNormNDHWCSum::operator()(GroupNormNDHWCParams* params, + gpuStream_t stream) { dim3 grid; grid.x = divUp(params->c, params->cPerBlock); - grid.y = divUp(params->hw, params->hwPerBlock); + grid.y = divUp(params->dhw, params->dhwPerBlock); grid.z = params->n; if (params->cPerGroup % 2 == 0) { switch (params->cPerBlock) { case 512: case 480: - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); break; case 320: - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); break; case 256: - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); break; case 128: - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); break; default: grid.x = divUp(params->c, 128); params->cPerBlock = 128; - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); } } else { if (params->cPerGroup != 1) { switch (params->cPerBlock) { case 512: - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); break; case 480: - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); break; case 320: - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); break; case 256: - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); break; case 128: - 
groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); break; default: grid.x = divUp(params->c, 128); params->cPerBlock = 128; - groupNormNHWCSumKernel<<>>(*params); + groupNormNDHWCSumKernel<<>>(*params); } } else { switch (params->cPerBlock) { case 512: - groupNormNHWCSumSingerChannelKernel + groupNormNDHWCSumSingerChannelKernel <<>>(*params); break; case 480: - groupNormNHWCSumSingerChannelKernel + groupNormNDHWCSumSingerChannelKernel <<>>(*params); break; case 320: - groupNormNHWCSumSingerChannelKernel + groupNormNDHWCSumSingerChannelKernel <<>>(*params); break; case 256: - groupNormNHWCSumSingerChannelKernel + groupNormNDHWCSumSingerChannelKernel <<>>(*params); break; case 128: - groupNormNHWCSumSingerChannelKernel + groupNormNDHWCSumSingerChannelKernel <<>>(*params); break; default: grid.x = divUp(params->c, 128); params->cPerBlock = 128; - groupNormNHWCSumSingerChannelKernel + groupNormNDHWCSumSingerChannelKernel <<>>(*params); } } } } -template class groupNormNHWCSum; +template class groupNormNDHWCSum; template -inline __device__ void GroupNormCompute(int32_t hwBegin, - int32_t hwEnd, +inline __device__ void GroupNormCompute(int32_t dhwBegin, + int32_t dhwEnd, int32_t ci, - const GroupNormNHWCParams& params, + const GroupNormNDHWCParams& params, float mean, float invStdDev) { float gamma = phi::__2float(*(reinterpret_cast(params.gamma) + ci)); float beta = phi::__2float(*(reinterpret_cast(params.beta) + ci)); - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The src/dst offset. - int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci; + int64_t offset = (int64_t)blockIdx.z * params.dhwc + dhwi * params.c + ci; const float src_data = phi::__2float(params.srcX[offset]); // Normalize the channels. float dst_data = (src_data - mean) * invStdDev; @@ -369,10 +369,10 @@ inline __device__ void GroupNormCompute(int32_t hwBegin, template <> inline __device__ void GroupNormCompute( - int32_t hwBegin, - int32_t hwEnd, + int32_t dhwBegin, + int32_t dhwEnd, int32_t ci, - const GroupNormNHWCParams& params, + const GroupNormNDHWCParams& params, float mean, float invStdDev) { float2 gammaF2, betaF2; @@ -382,9 +382,9 @@ inline __device__ void GroupNormCompute( reinterpret_cast(params.beta) + ci)); // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The src/dst offset. - int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci; + int64_t offset = (int64_t)blockIdx.z * params.dhwc + dhwi * params.c + ci; // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(¶ms.srcX[offset]); @@ -412,10 +412,10 @@ inline __device__ void GroupNormCompute( template <> inline __device__ void GroupNormCompute<__half, 2>( - int32_t hwBegin, - int32_t hwEnd, + int32_t dhwBegin, + int32_t dhwEnd, int32_t ci, - const GroupNormNHWCParams<__half>& params, + const GroupNormNDHWCParams<__half>& params, float mean, float invStdDev) { float2 gammaF2, betaF2; @@ -425,9 +425,9 @@ inline __device__ void GroupNormCompute<__half, 2>( reinterpret_cast(params.beta) + ci)); // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The src/dst offset. 
- int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci; + int64_t offset = (int64_t)blockIdx.z * params.dhwc + dhwi * params.c + ci; // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(¶ms.srcX[offset]); @@ -456,10 +456,10 @@ inline __device__ void GroupNormCompute<__half, 2>( #ifdef PADDLE_CUDA_BF16 template <> inline __device__ void GroupNormCompute( - int32_t hwBegin, - int32_t hwEnd, + int32_t dhwBegin, + int32_t dhwEnd, int32_t ci, - const GroupNormNHWCParams& params, + const GroupNormNDHWCParams& params, float mean, float invStdDev) { float2 gammaF2, betaF2; @@ -469,9 +469,9 @@ inline __device__ void GroupNormCompute( reinterpret_cast<__nv_bfloat16 const*>(params.beta) + ci)); // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) { // The src/dst offset. - int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci; + int64_t offset = (int64_t)blockIdx.z * params.dhwc + dhwi * params.c + ci; // Fetch two channels per thread. __nv_bfloat162 h2 = @@ -501,7 +501,8 @@ inline __device__ void GroupNormCompute( #endif template -__global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) { +__global__ void groupNormNDHWCScaleKernel( + const GroupNormNDHWCParams params) { // The instance in the batch. int32_t ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). @@ -521,7 +522,7 @@ __global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) { float sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sumSq * params.invDHWC - (mean * mean); if (params.var_data != nullptr) { params.var_data[ni * params.groups + gi] = var; @@ -530,22 +531,22 @@ __global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) { float invStdDev = rsqrtf(var + params.eps); // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t dhwBegin = blockIdx.y * params.dhwPerBlock; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw); GroupNormCompute( - hwBegin, hwEnd, ci, params, mean, invStdDev); + dhwBegin, dhwEnd, ci, params, mean, invStdDev); } template -void groupNormNHWCScale::operator()(const GroupNormNHWCParams& params, - gpuStream_t stream) { +void groupNormNDHWCScale::operator()(const GroupNormNDHWCParams& params, + gpuStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. grid.x = divUp(params.c, params.cPerBlock); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = divUp(params.dhw, params.dhwPerBlock); // The number of instances. 
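Aside for reviewers: the scale kernel above derives the per-group variance from the accumulated moments, var = E[x^2] - mean^2, with invDHWC = 1 / (dhw * cPerGroup). A CPU sketch of that two-pass computation for a single group (illustrative only, not part of the patch):

#include <cassert>
#include <cmath>
#include <vector>

int main() {
  // One group's activations (dhw * cPerGroup elements in the real kernel).
  std::vector<float> group = {1.f, 2.f, 3.f, 4.f};
  const float invDHWC = 1.f / group.size();
  float sum = 0.f, sumSq = 0.f;
  for (float v : group) { sum += v; sumSq += v * v; }
  const float mean = sum * invDHWC;                 // first moment
  const float var = sumSq * invDHWC - mean * mean;  // E[x^2] - mean^2
  const float eps = 1e-5f;
  const float invStdDev = 1.f / std::sqrt(var + eps);
  assert(std::fabs(var - 1.25f) < 1e-6f);
  (void)invStdDev;  // used by the scale kernel to normalise each activation
  return 0;
}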
grid.z = params.n; @@ -553,59 +554,59 @@ void groupNormNHWCScale::operator()(const GroupNormNHWCParams& params, switch (params.cPerBlock) { case 512: case 480: - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); break; case 320: - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); break; case 256: - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); break; default: grid.x = divUp(params.c, 128); - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); } } else { switch (params.cPerBlock) { case 512: - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); break; case 480: - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); break; case 320: - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); break; case 256: - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); break; default: grid.x = divUp(params.c, 128); - groupNormNHWCScaleKernel<<>>(params); + groupNormNDHWCScaleKernel<<>>(params); } } } -template class groupNormNHWCScale; +template class groupNormNDHWCScale; template -void GroupNormNHWCKernel(const Context& dev_ctx, - const DenseTensor& x, - const paddle::optional& scale, - const paddle::optional& bias, - float epsilon, - int groups, - const std::string& data_layout_str, - DenseTensor* y, - DenseTensor* mean, - DenseTensor* var) { +void GroupNormNDHWCKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var) { using AccT = typename phi::dtype::MPTypeTrait::Type; - GroupNormNHWCParams params_; + GroupNormNDHWCParams params_; params_.withSilu = false; const auto x_dims = x.dims(); @@ -618,10 +619,25 @@ void GroupNormNHWCKernel(const Context& dev_ctx, if (scale_ptr) scale_data = scale_ptr->data(); const T* bias_data = nullptr; if (bias_ptr) bias_data = bias_ptr->data(); + const auto d_dim = x_dims.size(); params_.n = x_dims[0]; - params_.c = x_dims[3]; - params_.h = x_dims[1]; - params_.w = x_dims[2]; + if (d_dim == 3) { + params_.c = x_dims[2]; + params_.d = 1; + params_.h = 1; + params_.w = x_dims[1]; + } else if (d_dim == 4) { + params_.c = x_dims[3]; + params_.d = 1; + params_.h = x_dims[1]; + params_.w = x_dims[2]; + } else { + // d_dim == 5 + params_.c = x_dims[4]; + params_.d = x_dims[1]; + params_.h = x_dims[2]; + params_.w = x_dims[3]; + } dev_ctx.template Alloc(mean); dev_ctx.template Alloc(var); @@ -630,7 +646,7 @@ void GroupNormNHWCKernel(const Context& dev_ctx, params_.var_data = var_data; int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; + int32_t maxBlocksPerDHW = 1024; switch (params_.c) { case 2048: case 1024: @@ -660,12 +676,12 @@ void GroupNormNHWCKernel(const Context& dev_ctx, params_.gamma = scale_data; params_.beta = bias_data; - params_.hw = params_.h * params_.w; - const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW); - params_.hwPerBlock = divUp(params_.hw, blocksPerHW); + params_.dhw = params_.d * params_.h * params_.w; + const int32_t blocksPerDHW = findMaxDivisor(params_.dhw, maxBlocksPerDHW); + params_.dhwPerBlock = divUp(params_.dhw, 
blocksPerDHW); params_.cPerBlock = cPerBlock; - params_.hwc = params_.hw * params_.c; - params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); + params_.dhwc = params_.dhw * params_.c; + params_.invDHWC = 1.F / static_cast(params_.dhw * params_.cPerGroup); params_.eps = epsilon; auto stream = dev_ctx.stream(); DenseTensor redBuffer; @@ -677,10 +693,10 @@ void GroupNormNHWCKernel(const Context& dev_ctx, #else cudaMemset(params_.redBuffer, 0, buffer_sizes * sizeof(float)); #endif - groupNormNHWCSum nhwc_sum; - nhwc_sum(¶ms_, stream); - groupNormNHWCScale nhwc_scale; - nhwc_scale(params_, stream); + groupNormNDHWCSum ndhwc_sum; + ndhwc_sum(¶ms_, stream); + groupNormNDHWCScale ndhwc_scale; + ndhwc_scale(params_, stream); #ifdef PADDLE_WITH_HIP phi::backends::gpu::GpuMemcpyAsync(mean_data, params_.redBuffer, @@ -1011,22 +1027,7 @@ void GroupNormKernel(const Context& dev_ctx, DenseTensor* var) { using std::is_same; if (is_same::value && data_layout_str == "NHWC") { - GroupNormNHWCKernel(dev_ctx, - x, - scale, - bias, - epsilon, - groups, - data_layout_str, - y, - mean, - var); - return; - } - -#ifdef PADDLE_CUDA_BF16 - if (is_same::value && data_layout_str == "NHWC") { - GroupNormNHWCKernel(dev_ctx, + GroupNormNDHWCKernel(dev_ctx, x, scale, bias, @@ -1038,6 +1039,21 @@ void GroupNormKernel(const Context& dev_ctx, var); return; } + +#ifdef PADDLE_CUDA_BF16 + if (is_same::value && data_layout_str == "NHWC") { + GroupNormNDHWCKernel(dev_ctx, + x, + scale, + bias, + epsilon, + groups, + data_layout_str, + y, + mean, + var); + return; + } #endif GroupNormGeneralCaseKernel( diff --git a/paddle/phi/kernels/gpu/weight_dequantize_kernel.cu b/paddle/phi/kernels/gpu/weight_dequantize_kernel.cu index 77e71b950ddfae..94a85ee8cf183d 100644 --- a/paddle/phi/kernels/gpu/weight_dequantize_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_dequantize_kernel.cu @@ -33,9 +33,16 @@ void WeightDequantizeKernel(const Context& dev_ctx, DenseTensor* out) { #if defined(PADDLE_WITH_CUTLASS) auto out_dims = out->dims(); + if (algo == "weight_only_int4") { + out->Resize({out_dims[1], out_dims[0] * 2}); + } dev_ctx.template Alloc(out); WeightDequantize(dev_ctx, x, scale, algo, true, group_size, out); - out->Resize({{out_dims[1], out_dims[0]}}); + if (algo == "weight_only_int4") { + out->Resize({out_dims[1], out_dims[0] * 2}); + } else { + out->Resize({{out_dims[1], out_dims[0]}}); + } auto out_tmp = Transpose(dev_ctx, *out, {1, 0}); out->ShareDataWith(out_tmp); #else diff --git a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu index 51b4786155a923..49f0e49725e40d 100644 --- a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu @@ -37,11 +37,8 @@ void WeightQuantizeKernel(const Context& dev_ctx, "Currently, group_size only support -1(per-channel), 64 or 128.")); DenseTensor quanted_x; - dev_ctx.template Alloc(out); size_t m = x.dims()[0]; size_t n = x.dims()[1]; - quanted_x.Resize({static_cast(m), static_cast(n)}); - dev_ctx.template Alloc(&quanted_x); std::vector weight_shape{static_cast(x.dims()[0]), static_cast(x.dims()[1])}; PADDLE_ENFORCE_EQ( @@ -51,31 +48,54 @@ void WeightQuantizeKernel(const Context& dev_ctx, "Currently, arch only support 70, 75, 80, 86.")); if (algo == "llm.int8") { + quanted_x.Resize({static_cast(m), static_cast(n)}); + dev_ctx.template Alloc(&quanted_x); dev_ctx.template Alloc(scale); + dev_ctx.template Alloc(out); std::vector axis = {1, 0}; funcs::Transpose trans; 
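Aside for reviewers, on the NDHWC generalisation above: GroupNormNDHWCKernel now accepts channels-last inputs of rank 3, 4 or 5 and folds them into a single (n, d, h, w, c) description, with absent spatial dims set to 1. A standalone sketch of that mapping (hypothetical helper, mirroring the rank dispatch in the kernel):

#include <cassert>
#include <cstdint>
#include <vector>

struct NDHWC { int64_t n, d, h, w, c; };

NDHWC ToNDHWC(const std::vector<int64_t>& x_dims) {
  assert(x_dims.size() >= 3 && x_dims.size() <= 5);
  if (x_dims.size() == 3)  // NWC
    return {x_dims[0], 1, 1, x_dims[1], x_dims[2]};
  if (x_dims.size() == 4)  // NHWC
    return {x_dims[0], 1, x_dims[1], x_dims[2], x_dims[3]};
  return {x_dims[0], x_dims[1], x_dims[2], x_dims[3], x_dims[4]};  // NDHWC
}

int main() {
  NDHWC p = ToNDHWC({2, 4, 8, 8, 32});  // video-like 5-D input
  assert(p.d * p.h * p.w == 256);       // dhw replaces the old hw product
  return 0;
}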
weight_quant_gpu(dev_ctx, x.data(), quanted_x.data(), scale->data(), - weight_shape); + weight_shape, + algo); trans(dev_ctx, quanted_x, out, axis); } else if (algo == "weight_only_int8") { + quanted_x.Resize({static_cast(m), static_cast(n)}); + dev_ctx.template Alloc(&quanted_x); dev_ctx.template Alloc(scale); + dev_ctx.template Alloc(out); weight_quant_gpu(dev_ctx, x.data(), quanted_x.data(), scale->data(), - weight_shape); + weight_shape, + algo); weight_permute_gpu(dev_ctx, quanted_x.data(), out->data(), weight_shape, - arch); + arch, + algo); } else if (algo == "weight_only_int4") { - PADDLE_FATAL( - "Weight quant gpu kernel currently don't support weight_only_int4 " - "algo, please use cpu version."); + quanted_x.Resize({static_cast(m / 2), static_cast(n)}); + dev_ctx.template Alloc(&quanted_x); + dev_ctx.template Alloc(scale); + out->Resize({static_cast(n), static_cast(m / 2)}); + dev_ctx.template Alloc(out); + weight_quant_gpu(dev_ctx, + x.data(), + quanted_x.data(), + scale->data(), + weight_shape, + algo); + weight_permute_gpu(dev_ctx, + quanted_x.data(), + out->data(), + weight_shape, + arch, + algo); } else { PADDLE_FATAL( "The algo must be in ['weight_only_int8', 'weight_only_int4', " diff --git a/paddle/fluid/operators/fused/yolo_box_head_op.cu b/paddle/phi/kernels/gpu/yolo_box_head_kernel.cu similarity index 57% rename from paddle/fluid/operators/fused/yolo_box_head_op.cu rename to paddle/phi/kernels/gpu/yolo_box_head_kernel.cu index abb7b5aeaae00f..a4821e6534463d 100644 --- a/paddle/fluid/operators/fused/yolo_box_head_op.cu +++ b/paddle/phi/kernels/gpu/yolo_box_head_kernel.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/yolo_box_util.h" -namespace paddle { -namespace operators { +namespace phi { template inline __device__ T SigmoidGPU(const T& x) { @@ -63,45 +65,37 @@ __global__ void YoloBoxHeadCudaKernel(const T* input, } } -template -class YoloBoxHeadKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - auto anchors = context.Attr>("anchors"); - auto class_num = context.Attr("class_num"); - auto& device_ctx = context.template device_context(); - auto x_dims = x->dims(); - const int batch_size = x_dims[0]; - const int h = x_dims[2]; - const int w = x_dims[3]; - const int grid_size_x = w; - const int grid_size_y = h; - const int anchors_num = anchors.size() / 2; - const T* input_data = x->data(); - T* output_data = device_ctx.Alloc(out, out->numel() * sizeof(T)); - auto stream = device_ctx.stream(); - const int volume = x_dims[1] * h * w; - dim3 block(16, 16, 4); - dim3 grid((grid_size_x / block.x) + 1, - (grid_size_y / block.y) + 1, - (anchors_num / block.z) + 1); - for (int n = 0; n < batch_size; n++) { - YoloBoxHeadCudaKernel<<>>( - input_data + n * volume, - output_data + n * volume, - grid_size_x, - grid_size_y, - class_num, - anchors_num); - } +template +void YoloBoxHeadKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& anchors, + int class_num, + DenseTensor* out) { + auto x_dims = x.dims(); + const int batch_size = x_dims[0]; + const int h = x_dims[2]; + const int w = x_dims[3]; + const int grid_size_x = w; + const int grid_size_y = h; + const int anchors_num = anchors.size() / 2; + const T* input_data = x.data(); + T* output_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + auto stream = dev_ctx.stream(); + const int volume = x_dims[1] * h * w; + dim3 block(16, 16, 4); + dim3 grid((grid_size_x / block.x) + 1, + (grid_size_y / block.y) + 1, + (anchors_num / block.z) + 1); + for (int n = 0; n < batch_size; n++) { + YoloBoxHeadCudaKernel<<>>(input_data + n * volume, + output_data + n * volume, + grid_size_x, + grid_size_y, + class_num, + anchors_num); } -}; - -} // namespace operators -} // namespace paddle +} +} // namespace phi -namespace ops = paddle::operators; -PD_REGISTER_STRUCT_KERNEL( - yolo_box_head, GPU, ALL_LAYOUT, ops::YoloBoxHeadKernel, float) {} +PD_REGISTER_KERNEL( + yolo_box_head, GPU, ALL_LAYOUT, phi::YoloBoxHeadKernel, float) {} diff --git a/paddle/phi/kernels/group_norm_kernel.h b/paddle/phi/kernels/group_norm_kernel.h index 9acdeca0e67478..3dc10df6a11094 100644 --- a/paddle/phi/kernels/group_norm_kernel.h +++ b/paddle/phi/kernels/group_norm_kernel.h @@ -58,14 +58,14 @@ class GroupNormDirectCUDAFunctor { #endif template -struct GroupNormNHWCParams { - // The output buffer. Layout NHWC. +struct GroupNormNDHWCParams { + // The output buffer. Layout NDHWC. T* dst; - // The output buffer. Layout NHWC. + // The output buffer. Layout NDHWC. T* eleOut; - // The input buffer. Layout NHWC. + // The input buffer. Layout NDHWC. T const* srcX; - // The input buffer. Layout NHWC. 
+ // The input buffer. Layout NDHWC. T const* srcY; // The gamma scaling factor. void const* gamma; @@ -79,8 +79,8 @@ struct GroupNormNHWCParams { // The number of instances in the batch. int32_t n; - // The height and width of each activation map. - int32_t h, w; + // The depth, height and width of each activation map. + int32_t d, h, w; // The number of channels. int32_t c; // The number of groups. @@ -90,36 +90,36 @@ struct GroupNormNHWCParams { // Precomputed values and parameters to control the execution of the kernels. - // The number of activations per instance (h * w) and the number of + // The number of activations per instance (d * h * w) and the number of // activations per block. - int32_t hw, hwPerBlock; + int32_t dhw, dhwPerBlock; // The number of channels per group and blocks per activation in the C // dimension. int32_t cPerBlock, cPerGroup; // The precomputed stride between instances. - int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; + int32_t dhwc; + // The inverse of dhwc in floats (to compute mean/var). + float invDHWC; // The precomputed number of groups per block. int32_t groupsPerBlock; // epsilon, Constant for numerical stability float eps; - // for NCHW32 int8 use + // for NCDHW32 int8 use float dqScaleIn; float inv_qScale; }; template -class groupNormNHWCSum { +class groupNormNDHWCSum { public: - void operator()(GroupNormNHWCParams* params, const gpuStream_t stream); + void operator()(GroupNormNDHWCParams* params, const gpuStream_t stream); }; template -class groupNormNHWCScale { +class groupNormNDHWCScale { public: - void operator()(const GroupNormNHWCParams& params, + void operator()(const GroupNormNDHWCParams& params, const gpuStream_t stream); }; diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h index 05d0e47b314555..6a98dfb526fbb7 100644 --- a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h +++ b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h @@ -48,6 +48,89 @@ __global__ void weight_permute_kernel_wint8(const int8_t* input_data_dev, } } +// from +// 0 1 2 3 4 5 6 7... +// to +// 0 8 16 24 1 9 17 25... +__global__ void weight_permute_kernel_wint4(const int8_t* input_data_dev, + int8_t* output_data_dev, + int numel, + int total_k, + int total_n) { + for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + linear_idx < numel; + linear_idx += blockDim.x * gridDim.x) { + int k_id = linear_idx / total_n; + int n_id = linear_idx % total_n; + constexpr int k_permute_const = 8; + int k_mod_8 = k_id % 8; + int temp_k_expr_1 = k_mod_8 - k_mod_8 / 4 * 4; + int temp_k_expr_2 = k_mod_8 / 4; + // we need int4 index like + // 0 8 16 24 1 9 17 25 2 10 18 26 3 11 19 27 + // 4 12 20 28 5 13 21 29 6 14 22 30 7 15 23 31 + // we can change it to + // 0 1 16 17 8 9 24 25 2 3 18 19 10 11 26 27 + // 4 5 20 21 12 13 28 29 6 7 22 23 14 15 30 31 + // 2 int4 pack to a int8 + // 0 8 4 12 1 9 5 13 2 10 6 14 3 11 7 15 + // find index of above list + // 0 4 8 12 2 6 10 14 1 5 9 13 3 7 11 15 + // we know int8 index is + // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + // change it to + // 0 2 4 6 1 3 5 7 8 10 12 14 9 11 13 15 + // % 8 * 2 + // 0 4 8 12 2 6 10 14 0 4 8 12 2 6 10 14 + // add 1 for 0 4 8 12 2 6 10 14 [0 4 8 12 2 6 10 14] + // we get 0 4 8 12 2 6 10 14 1 5 9 13 3 7 11 15 + // it change ori to 0 8 4 12... 
+ // finally we do some bitwise operation to change int4index + int permute_kk = (temp_k_expr_1 + temp_k_expr_2 + + (temp_k_expr_2 + 1) % 2 * k_mod_8 * 2 / 2 + + temp_k_expr_1 * temp_k_expr_2) % + 8 * 2 + + (k_id % 16) / 8 + k_id / 16 * 16; + int permute_index = permute_kk % 32 + permute_kk / 32 * 128 + + 32 * (n_id % 4) + total_k * 4 * (n_id / 4); + int8_t shift_quant_weight = input_data_dev[linear_idx]; + output_data_dev[permute_index] = + *reinterpret_cast(&shift_quant_weight); + } + constexpr int value_per_interval_thread = 4; + constexpr int pack_size = 2; + for (int linear_idx = + blockIdx.x * blockDim.x + threadIdx.x * value_per_interval_thread; + linear_idx < numel; + linear_idx += blockDim.x * gridDim.x * 4) { + for (int pack = 0; pack < pack_size; ++pack) { + uint8_t interval_weight_0 = static_cast( + static_cast(output_data_dev[linear_idx + pack])); + uint8_t interval_weight_1 = static_cast( + static_cast(output_data_dev[linear_idx + pack + 2])); + + uint8_t interval_weight_0_l = + static_cast(((interval_weight_0 & 0x0F) + 8) & 0x0F); + uint8_t interval_weight_0_r = + static_cast(((interval_weight_0 >> 4) + 8) & 0x0F); + uint8_t interval_weight_1_l = + static_cast(((interval_weight_1 & 0x0F) + 8) & 0x0F); + uint8_t interval_weight_1_r = + static_cast(((interval_weight_1 >> 4) + 8) & 0x0F); + + uint8_t new_interval_weight_0 = + interval_weight_0_l | (interval_weight_1_l << 4); + uint8_t new_interval_weight_1 = + interval_weight_0_r | (interval_weight_1_r << 4); + + output_data_dev[linear_idx + pack] = + *reinterpret_cast(&new_interval_weight_0); + output_data_dev[linear_idx + pack + 2] = + *reinterpret_cast(&new_interval_weight_1); + } + } +} + /* For SM70 volta arch, weightonly int8 dequantize invoked in load global memory. So it only need interleave in K-dimension @@ -85,12 +168,51 @@ __global__ void weight_interleave_add_bias_kernel_wint8( } } +/* +For SM70 volta arch, weightonly int8 dequantize invoked in load global memory. 
+So it only need interleave in K-dimension +K_index: 0 1 2 3 4 5 6 7 -> 0 2 4 6 1 3 5 7 +*/ +__global__ void weight_interleave_add_bias_kernel_wint4( + const int8_t* input_data_dev, + int8_t* output_data_dev, + int numel, + int total_k, + int total_n) { + for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + linear_idx < numel; + linear_idx += blockDim.x * gridDim.x) { + int k_id = linear_idx / total_n; + int n_id = linear_idx % total_n; + constexpr int n_interleaved_factor = 8; + int n_interleave_group_id = n_id / n_interleaved_factor; + int n_interleave_id = n_id % n_interleaved_factor; + + int n_interleave_offset = n_interleave_id / 4; + n_interleave_id = (n_interleave_id % 4) * 2 + n_interleave_offset; + const int new_n_id = + n_interleave_group_id * n_interleaved_factor + n_interleave_id; + const int interleave_idx = k_id * total_n + new_n_id; + + uint8_t offseted_weight = 0; + uint8_t shift_quant_weight = + static_cast(static_cast(input_data_dev[linear_idx])); + uint8_t shift_quant_weight_low = ((shift_quant_weight & 0x0F) + 8) & 0x0F; + uint8_t shift_quant_weight_high = + ((shift_quant_weight >> 4 & 0x0F) + 8) & 0x0F; + offseted_weight = shift_quant_weight_low | (shift_quant_weight_high << 4); + output_data_dev[interleave_idx] = + *reinterpret_cast(&offseted_weight); + } +} + template void weight_permute_gpu(const GPUContext& dev_ctx, int8_t* input_data, int8_t* output_data, const std::vector& shape, - const int32_t arch) { + const int32_t arch, + const std::string& algo) { auto total_k = shape[0]; auto total_n = shape[1]; auto numel = total_k * total_n; @@ -98,11 +220,25 @@ void weight_permute_gpu(const GPUContext& dev_ctx, int grid_size = gpu_config.GetGridSize(); int block_size = gpu_config.GetBlockSize(); if ((arch == 80) || (arch == 86) || (arch == 75)) { - weight_permute_kernel_wint8<<>>( - input_data, output_data, numel, total_k, total_n); + if (algo == "weight_only_int4") { + total_k /= 2; + numel /= 2; + weight_permute_kernel_wint4<<>>( + input_data, output_data, numel, total_k, total_n); + } else { + weight_permute_kernel_wint8<<>>( + input_data, output_data, numel, total_k, total_n); + } } else if (arch == 70) { - weight_interleave_add_bias_kernel_wint8<<>>( - input_data, output_data, numel, total_k, total_n); + if (algo == "weight_only_int4") { + total_k /= 2; + numel /= 2; + weight_permute_kernel_wint4<<>>( + input_data, output_data, numel, total_k, total_n); + } else { + weight_permute_kernel_wint8<<>>( + input_data, output_data, numel, total_k, total_n); + } } } @@ -161,12 +297,76 @@ __global__ void per_channel_quant_gpu(const T* weight_data, } } } + +template +__global__ void per_channel_quant_gpu_int4(const T* weight_data, + int8_t* quanted_weight_data, + ScaleT* scale_data, + int total_k, + int total_vec_n) { + int n = blockIdx.x * blockDim.x + threadIdx.x; + if (n < total_vec_n) { + const int4* vec_weight_data_ptr = + reinterpret_cast(weight_data); + int2* vec_quanted_weight_data = + reinterpret_cast(quanted_weight_data); + phi::AlignedVector abs_max; +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + abs_max[i] = static_cast(0.0f); + } +#pragma unroll + for (int k = 0; k < total_k; ++k) { + int linear_index = k * total_vec_n + n; + phi::AlignedVector weight; + *reinterpret_cast(&weight) = vec_weight_data_ptr[linear_index]; +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + abs_max[i] = fmaxf((abs_max[i]), fabsf((weight[i]))); + } + } + phi::AlignedVector scale; +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + scale[i] = 
static_cast(abs_max[i] / static_cast(7.0f)); + } + *reinterpret_cast(scale_data + VectorSize * n) = + *reinterpret_cast(&scale); + + for (int k = 0; k < total_k / 2; ++k) { + phi::AlignedVector quanted_weight; + for (int i = 0; i < VectorSize; ++i) { + quanted_weight[i] = 0; + } + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + int linear_index = (k * 2 + packed_idx) * total_vec_n + n; + phi::AlignedVector weight; + *reinterpret_cast(&weight) = + *reinterpret_cast(vec_weight_data_ptr + linear_index); +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + float scaled_weight = + (static_cast(weight[i]) / static_cast(abs_max[i])) * + static_cast(7.0); + int8_t clipped_weight = static_cast( + lroundf(fmaxf(-7.0f, fminf(7.0f, scaled_weight)))); + quanted_weight[i] |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + int linear_index_new = k * total_vec_n + n; + *reinterpret_cast(vec_quanted_weight_data + linear_index_new) = + *reinterpret_cast(&quanted_weight); + } + } +} + template void weight_quant_gpu(const GPUContext& dev_ctx, const T* weight_data, int8_t* quanted_weight_data, ScaleT* scale_data, - const std::vector& shape) { + const std::vector& shape, + const std::string& algo) { int total_k = shape[0]; int total_n = shape[1]; int numel = total_k * total_n; @@ -183,8 +383,13 @@ void weight_quant_gpu(const GPUContext& dev_ctx, int vec_total_n = total_n / kVectorSize; int kGridSize = max((vec_total_n + kBlockSize - 1) / kBlockSize, static_cast(1)); - per_channel_quant_gpu<<>>( - weight_data, quanted_weight_data, scale_data, total_k, vec_total_n); + if (algo == "weight_only_int4") { + per_channel_quant_gpu_int4<<>>( + weight_data, quanted_weight_data, scale_data, total_k, vec_total_n); + } else { + per_channel_quant_gpu<<>>( + weight_data, quanted_weight_data, scale_data, total_k, vec_total_n); + } } } // namespace phi diff --git a/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc b/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc index 0dd3c137898680..82d44ac0aad429 100644 --- a/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc @@ -69,9 +69,28 @@ void FlashAttnGradKernel(const Context& ctx, const XPUType* out_data = reinterpret_cast(out.data()); const float* softmax_lse_data = softmax_lse.data(); const XPUType* dout_data = reinterpret_cast(dout.data()); + + xpu::ctx_guard RAII_GUARD(ctx.x_context()); const float* bias_data = nullptr; if (attn_mask.get_ptr() != nullptr) { - bias_data = attn_mask->data(); + if (attn_mask->dtype() == phi::DataType::FLOAT32) { + bias_data = attn_mask->data(); + } else if (attn_mask->dtype() == phi::DataType::FLOAT16 || + attn_mask->dtype() == phi::DataType::BFLOAT16) { + float* bias_tmp = RAII_GUARD.alloc_l3_or_gm(attn_mask->numel()); + int r = xpu::cast( + ctx.x_context(), + reinterpret_cast(attn_mask->data()), + bias_tmp, + attn_mask->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); + bias_data = bias_tmp; + } else { + errors::Unimplemented( + "Unsupported dtype for attention_mask in xpu flash attention, only " + "float32, float16 and " + "bfloat16 are supported."); + } } // output XPUType* dq_data = reinterpret_cast(dq->data()); @@ -92,6 +111,7 @@ void FlashAttnGradKernel(const Context& ctx, // get seed offset const int64_t* seed_offset_data = seed_offset.data(); + // template // int mha_varlen_bwd(xdnn::Context* ctx, const T* dout, const T* q, const T* // k, const T* v, const T* out, const TACCUM* softmax_lse, T* dq, T* dk, T* @@ -106,28 +126,28 @@ void FlashAttnGradKernel(const 
Context& ctx, // dv_maxptr = nullptr, const float* do_maxptr = nullptr); int r = baidu::xpu::xfa::mha_varlen_bwd( ctx.x_context(), - dout_data, // dout - q_data, // q - k_data, // k - v_data, // v - out_data, // out - softmax_lse_data, // softmax_lse - dq_data, // dq - dk_data, // dk - dv_data, // dv - qlod, // lod_seqlens_q - kvlod, // lod_seqlens_k - seqlen_q, // max_seqlen_q - seqlen_k, // max_seqlen_k - num_heads, // head_num - num_heads_k, // head_num_k - head_size, // head_dim - 1.0f / std::sqrt(head_size), // softmax_scale - dropout, // p_dropout - static_cast(seed_offset_data[0]), // seed - causal, // is_causal - nullptr, // attn_mask - bias_data // bias + dout_data, // dout + q_data, // q + k_data, // k + v_data, // v + out_data, // out + softmax_lse_data, // softmax_lse + dq_data, // dq + dk_data, // dk + dv_data, // dv + qlod, // lod_seqlens_q + kvlod, // lod_seqlens_k + seqlen_q, // max_seqlen_q + seqlen_k, // max_seqlen_k + num_heads, // head_num + num_heads_k, // head_num_k + head_size, // head_dim + 1.0f / std::sqrt(head_size), // softmax_scale + dropout, // p_dropout + static_cast(seed_offset_data[0]), // seed + causal, // is_causal + nullptr, // attn_mask + bias_data // bias ); PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd"); #else diff --git a/paddle/phi/kernels/xpu/flash_attn_kernel.cc b/paddle/phi/kernels/xpu/flash_attn_kernel.cc index bdfab918db027c..0e4da3483290dc 100644 --- a/paddle/phi/kernels/xpu/flash_attn_kernel.cc +++ b/paddle/phi/kernels/xpu/flash_attn_kernel.cc @@ -14,7 +14,7 @@ #include "paddle/phi/kernels/flash_attn_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" -#include "paddle/phi/core/enforce.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #ifdef PADDLE_WITH_XPU_XHPC @@ -239,10 +239,18 @@ void FlashAttnKernel(const Context& ctx, seed_offset->Resize({2}); int64_t* seed_offset_data = ctx.template HostAlloc(seed_offset); if (fixed_seed_offset.get_ptr()) { - const int64_t* fixed_seed_offset_data = - fixed_seed_offset.get_ptr()->data(); - seed_offset_data[0] = fixed_seed_offset_data[0]; - seed_offset_data[1] = fixed_seed_offset_data[1]; + if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) { + memory_utils::Copy(phi::CPUPlace(), + seed_offset_data, + fixed_seed_offset->place(), + fixed_seed_offset->data(), + sizeof(int64_t) * 2); + } else { + const int64_t* fixed_seed_offset_data = + fixed_seed_offset->data(); + seed_offset_data[0] = fixed_seed_offset_data[0]; + seed_offset_data[1] = fixed_seed_offset_data[1]; + } } else { std::pair seed_offset_pair; uint64_t inc = batch_size * num_heads * 32; @@ -263,11 +271,29 @@ void FlashAttnKernel(const Context& ctx, const XPUType* k_data = reinterpret_cast(k.data()); const XPUType* v_data = reinterpret_cast(v.data()); XPUType* out_data = reinterpret_cast(out->data()); - float* softmax_lse_data = softmax_lse->data(); + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + float* softmax_lse_data = softmax_lse->data(); const float* bias_data = nullptr; if (attn_mask.get_ptr() != nullptr) { - bias_data = attn_mask->data(); + if (attn_mask->dtype() == phi::DataType::FLOAT32) { + bias_data = attn_mask->data(); + } else if (attn_mask->dtype() == phi::DataType::FLOAT16 || + attn_mask->dtype() == phi::DataType::BFLOAT16) { + float* bias_tmp = RAII_GUARD.alloc_l3_or_gm(attn_mask->numel()); + int r = xpu::cast( + ctx.x_context(), + reinterpret_cast(attn_mask->data()), + bias_tmp, + attn_mask->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); + bias_data = 
bias_tmp; + } else { + errors::Unimplemented( + "Unsupported dtype for attention_mask in xpu flash attention, only " + "float32, float16 and " + "bfloat16 are supported."); + } } // template int // mha_varlen_fwd(xdnn::Context* ctx, const T* q, const T* k, const T* v, T* @@ -281,24 +307,24 @@ void FlashAttnKernel(const Context& ctx, // nullptr); int r = baidu::xpu::xfa::mha_varlen_fwd( ctx.x_context(), - q_data, // q - k_data, // k - v_data, // v - out_data, // out - softmax_lse_data, // softmax_lse - qlod, // lod_seqlens_q - kvlod, // lod_seqlens_k - seqlen_q, // max_seqlen_q - seqlen_k, // max_seqlen_k - num_heads, // head_num - num_heads_k, // head_num_k - head_size, // head_dim - 1.0f / std::sqrt(head_size), // softmax_scale - dropout, // p_dropout - static_cast(seed_offset_data[0]), // seed - causal, // is_causal - nullptr, // attn_mask - bias_data // bias + q_data, // q + k_data, // k + v_data, // v + out_data, // out + softmax_lse_data, // softmax_lse + qlod, // lod_seqlens_q + kvlod, // lod_seqlens_k + seqlen_q, // max_seqlen_q + seqlen_k, // max_seqlen_k + num_heads, // head_num + num_heads_k, // head_num_k + head_size, // head_dim + 1.0f / std::sqrt(head_size), // softmax_scale + dropout, // p_dropout + static_cast(seed_offset_data[0]), // seed + causal, // is_causal + nullptr, // attn_mask + bias_data // bias ); PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd"); #else diff --git a/paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h b/paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h index 6b42909ab6fa6e..bd4b16bbc75fdd 100644 --- a/paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h +++ b/paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h @@ -31,17 +31,17 @@ class InferSymbolicShapeInterface /// Defined these methods with the interface. 
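Stepping back to the XPU flash-attention changes earlier in this patch: when attention_mask arrives as float16 or bfloat16 it is widened to float32 into a scratch buffer before the fused kernel runs, because the bias argument passed to mha_varlen_fwd/bwd is float-only. A standalone C++ sketch of that widening for IEEE binary16 bits follows; it is illustrative only (the real code runs xpu::cast on device memory, and bfloat16 uses a different bit layout), and HalfBitsToFloat / WidenMaskToFloat are invented names:

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Decode IEEE-754 binary16 stored as raw uint16_t bits into a float.
    float HalfBitsToFloat(uint16_t h) {
      const uint32_t sign = (h >> 15) & 0x1;
      const uint32_t exp = (h >> 10) & 0x1F;
      const uint32_t mant = h & 0x3FF;
      float value;
      if (exp == 0) {
        value = std::ldexp(static_cast<float>(mant), -24);  // subnormals
      } else if (exp == 31) {
        value = mant ? NAN : INFINITY;  // NaN / infinity
      } else {
        value = std::ldexp(static_cast<float>(mant + 1024),
                           static_cast<int>(exp) - 25);  // normal numbers
      }
      return sign ? -value : value;
    }

    // Widen a half-precision mask buffer so a float-only kernel can consume it.
    std::vector<float> WidenMaskToFloat(const std::vector<uint16_t>& half_bits) {
      std::vector<float> out(half_bits.size());
      for (size_t i = 0; i < out.size(); ++i) out[i] = HalfBitsToFloat(half_bits[i]);
      return out;
    }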
struct Concept { explicit Concept(bool (*infer_symbolic_shapes)( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis)) + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context)) : infer_symbolic_shapes(infer_symbolic_shapes) {} bool (*infer_symbolic_shapes)( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context); }; template struct Model : public Concept { static inline bool InferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return op->dyn_cast().InferSymbolicShape(shape_analysis); + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return op->dyn_cast().InferSymbolicShape(infer_context); } Model() : Concept(InferSymbolicShape) {} @@ -51,7 +51,7 @@ class InferSymbolicShapeInterface InferSymbolicShapeInterface(pir::Operation *op, Concept *impl) : pir::OpInterfaceBase(op), impl_(impl) {} - bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); private: Concept *impl_; diff --git a/paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h b/paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h index 5050ea727e678f..5e7f2c1142b982 100644 --- a/paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h +++ b/paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h @@ -26,7 +26,7 @@ class Pass; IR_API std::unique_ptr CreateShapeOptimizationPass(); void InferSymExprForBlock(const Block &block, - ShapeConstraintIRAnalysis *shape_analysis); + InferSymbolicShapeContext *infer_context); } // namespace pir diff --git a/paddle/pir/include/dialect/shape/utils/shape_analysis.h b/paddle/pir/include/dialect/shape/utils/shape_analysis.h index 30fb0021b177ac..f51c2d51422dbd 100644 --- a/paddle/pir/include/dialect/shape/utils/shape_analysis.h +++ b/paddle/pir/include/dialect/shape/utils/shape_analysis.h @@ -27,23 +27,17 @@ namespace pir { -// The implementation is based on shape constraint ir. 
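The InferSymbolicShapeInterface above is the usual Concept/Model type-erasure idiom: Concept stores a plain function pointer, Model binds that pointer to a concrete op's member function, and the interface object merely forwards. A toy, self-contained C++ rendering of the same idiom (all names here are invented for illustration, not pir's real types):

    #include <iostream>

    struct ToyContext {
      int next_symbol = 0;
    };

    // "Concept": the type-erased entry point, stored as a function pointer.
    struct Concept {
      explicit Concept(bool (*infer)(void* op, ToyContext* ctx)) : infer(infer) {}
      bool (*infer)(void* op, ToyContext* ctx);
    };

    // "Model": binds the erased entry point to one concrete op type.
    template <typename ConcreteOp>
    struct Model : Concept {
      static bool Infer(void* op, ToyContext* ctx) {
        return static_cast<ConcreteOp*>(op)->InferSymbolicShape(ctx);
      }
      Model() : Concept(Infer) {}
    };

    // The interface wrapper only forwards to whichever Model was registered.
    struct ToyInferShapeInterface {
      void* op;
      Concept* impl;
      bool InferSymbolicShape(ToyContext* ctx) { return impl->infer(op, ctx); }
    };

    struct ReluOp {
      bool InferSymbolicShape(ToyContext* ctx) {
        std::cout << "relu result gets symbol S" << ctx->next_symbol++ << "\n";
        return true;
      }
    };

    int main() {
      ReluOp relu;
      Model<ReluOp> model;
      ToyInferShapeInterface iface{&relu, &model};
      ToyContext ctx;
      return iface.InferSymbolicShape(&ctx) ? 0 : 1;
    }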
-class IR_API ShapeConstraintIRAnalysis final - : public std::enable_shared_from_this { +class IR_API InferSymbolicShapeContext { public: - ShapeConstraintIRAnalysis() = default; - ShapeConstraintIRAnalysis(const ShapeConstraintIRAnalysis&) = delete; - ShapeConstraintIRAnalysis(ShapeConstraintIRAnalysis&&) = delete; - void Init(); const std::string GetNextSymName(); bool HasShapeOrDataForValue(Value val) const; - void InferShapeOrDataForValue(Value val); + const symbol::ShapeOrDataDimExprs& GetShapeOrDataForValue(Value val) const; - const symbol::ShapeOrDataDimExprs& GetShapeOrDataForValue(Value val); + void SetStaticShapeForValue(Value val); void SetShapeOrDataForValue(Value val, const symbol::ShapeOrDataDimExprs& shape_or_data); @@ -59,6 +53,47 @@ class IR_API ShapeConstraintIRAnalysis final void AddBroadcastableCstr(const symbol::DimExpr& lhs, const symbol::DimExpr& rhs); + bool IsBroadcastable(const symbol::DimExpr& lhs, + const symbol::DimExpr& rhs) const; + + void PrintShapeOrDatas() const; + + private: + void SubstituteDimExpr(const symbol::DimExpr& origin, + const symbol::DimExpr& substituted); + + private: + int64_t next_sym_idx_ = 0; + + std::unordered_map + value_id_to_shape_or_data_; + + symbol::ConstraintsManager constraints_manager_; + + using DimExprSubstitutionPattern = + std::unordered_map; + DimExprSubstitutionPattern substitution_pattern_; +}; + +class IR_API ShapeConstraintIRAnalysis final + : public std::enable_shared_from_this { + public: + ShapeConstraintIRAnalysis() = default; + ShapeConstraintIRAnalysis(const ShapeConstraintIRAnalysis&) = delete; + ShapeConstraintIRAnalysis(ShapeConstraintIRAnalysis&&) = delete; + void Init(); + + const std::string GetNextSymName(); + + const symbol::ShapeOrDataDimExprs& GetShapeOrDataForValue(Value val); + + void SetShapeOrDataForValue(Value val, + const symbol::ShapeOrDataDimExprs& shape_or_data); + + bool IsEqual(const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) const; + + bool IsGreatThanOne(const symbol::DimExpr& dim_expr) const; + bool IsBroadcastable(const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) const; @@ -93,23 +128,19 @@ class IR_API ShapeConstraintIRAnalysis final symbol::DimExpr GetProductDimExpr(Value lhs, const std::vector& lhs_dim_idxs); - private: - void SubstituteDimExpr(const symbol::DimExpr& origin, - const symbol::DimExpr& substituted); + // TODO(hongqing-work): make it a private component only for infer friend + // class + InferSymbolicShapeContext* GetInferSymbolicShapeContext() { + return &context_; + } private: - ModuleOp m_; - - int64_t next_sym_idx_ = 0; - - std::unordered_map - value_to_shape_or_data_; + void SetStaticShapeForValue(Value val); - symbol::ConstraintsManager constraints_manager_; + void InferShapeOrDataForValue(Value val); - using DimExprSubstitutionPattern = - std::unordered_map; - DimExprSubstitutionPattern substitution_pattern_; + private: + InferSymbolicShapeContext context_; }; class IR_API ShapeAnalysisManager { @@ -129,6 +160,8 @@ class IR_API ShapeAnalysisManager { #define OP_DECLARE_INFER_SYMBOLIC_SHAPE(name) \ bool name##OpInferSymbolicShape( \ - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis); + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context); + +bool IsStaticShape(const Value& value); } // namespace pir diff --git a/paddle/pir/src/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.cc b/paddle/pir/src/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.cc index dbe25e171e725f..f1c44e945f60c2 100644 
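The net effect of the header reshuffle above is a facade: ShapeConstraintIRAnalysis keeps its public query API but now owns an InferSymbolicShapeContext and forwards to it, and per-op InferSymbolicShape code only ever sees the context. A compressed toy version of that ownership and forwarding (again, invented names, not the real classes):

    #include <cstdint>
    #include <string>
    #include <unordered_map>

    // Owns the symbol table; this is all an op's InferSymbolicShape gets to see.
    class ToyInferContext {
     public:
      bool Has(uint64_t value_id) const { return table_.count(value_id) > 0; }
      const std::string& Get(uint64_t value_id) const { return table_.at(value_id); }
      void Set(uint64_t value_id, std::string shape) { table_[value_id] = std::move(shape); }

     private:
      std::unordered_map<uint64_t, std::string> table_;  // value id -> shape text
    };

    // Keeps the old public API, but every call is forwarded to the context.
    class ToyShapeAnalysis {
     public:
      const std::string& GetShapeOrDataForValue(uint64_t value_id) {
        if (!context_.Has(value_id)) {
          context_.Set(value_id, "[S0]");  // stand-in for on-demand inference
        }
        return context_.Get(value_id);
      }
      ToyInferContext* GetInferSymbolicShapeContext() { return &context_; }

     private:
      ToyInferContext context_;
    };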
--- a/paddle/pir/src/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.cc +++ b/paddle/pir/src/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.cc @@ -26,8 +26,8 @@ namespace pir { bool InferSymbolicShapeInterface::InferSymbolicShape( - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return impl_->infer_symbolic_shapes(operation(), shape_analysis); + pir::InferSymbolicShapeContext *infer_context) { + return impl_->infer_symbolic_shapes(operation(), infer_context); } } // namespace pir diff --git a/paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc index a1ca2ba4a54165..60cb05ec41d789 100644 --- a/paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc +++ b/paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc @@ -120,17 +120,16 @@ void PrintOpInfo(pir::Operation* op) { } } -void DebugPrintOpInfo( - pir::Operation* op, - pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { +void DebugPrintOpInfo(pir::Operation* op, + pir::InferSymbolicShapeContext* infer_context = nullptr) { std::ostringstream print_stream; for (uint32_t i = 0; i < op->num_results(); ++i) { const auto& res = op->result(i); print_stream << "\tresult(" << res.dyn_cast().index() << ") " << "ShapeOrData: {"; - if (shape_analysis != nullptr) { - auto shape_data = shape_analysis->GetShapeOrDataForValue(res); + if (infer_context != nullptr) { + auto shape_data = infer_context->GetShapeOrDataForValue(res); if (shape_data.isa()) continue; print_stream << "shape: ["; @@ -167,7 +166,7 @@ void DebugPrintOpInfo( void CheckInferSymWithInferMeta( pir::Operation* op, - pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { + pir::InferSymbolicShapeContext* infer_context = nullptr) { for (uint32_t i = 0; i < op->num_results(); ++i) { const auto& res = op->result(i); std::ostringstream print_stream; @@ -179,7 +178,7 @@ void CheckInferSymWithInferMeta( const std::vector& infer_meta_shape = common::vectorize(res.type().dyn_cast().dims()); const std::vector& infer_sym_shape = - shape_analysis->GetShapeOrDataForValue(res).shape(); + infer_context->GetShapeOrDataForValue(res).shape(); // Check rank. 
if (infer_meta_shape.size() != infer_sym_shape.size()) { @@ -231,9 +230,10 @@ void InferSymExprForAllValues(ModuleOp module_op) { ShapeConstraintIRAnalysis& shape_analysis = ShapeAnalysisManager::Instance().Get(module_op.program()); shape_analysis.Init(); + auto infer_context = shape_analysis.GetInferSymbolicShapeContext(); for (uint32_t i = 0; i < module_op->num_regions(); i++) { for (auto& block : module_op->region(i)) { - InferSymExprForBlock(block, &shape_analysis); + InferSymExprForBlock(block, infer_context); } } } @@ -272,16 +272,8 @@ class ShapeOptimizationPass : public pir::Pass { } // namespace -static inline bool IsStaticShape(const Value& value) { - const auto& value_type = value.type(); - if (!value || !value_type || !value_type.isa()) { - return false; - } - return !::common::contain_unknown_dim( - value_type.dyn_cast().dims()); -} - -symbol::ShapeOrDataDimExprs CreateShapeOrDataByDDim(const pir::DDim& dims) { +symbol::TensorShapeOrDataDimExprs CreateShapeOrDataByDDim( + const pir::DDim& dims) { std::vector dim_exprs; for (int i = 0; i < dims.size(); ++i) { dim_exprs.emplace_back(dims.at(i)); @@ -290,14 +282,14 @@ symbol::ShapeOrDataDimExprs CreateShapeOrDataByDDim(const pir::DDim& dims) { } void InferSymExprForBlock(const Block& block, - ShapeConstraintIRAnalysis* shape_analysis) { + InferSymbolicShapeContext* infer_context) { for (auto& op : block) { auto infer_symbolic_shape_interface = op.dyn_cast(); if (infer_symbolic_shape_interface) { PrintOpInfo(&op); PADDLE_ENFORCE_EQ( - infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis), + infer_symbolic_shape_interface.InferSymbolicShape(infer_context), true, "InferSymbolicShape for %s failed.", op.name()); @@ -306,7 +298,7 @@ void InferSymExprForBlock(const Block& block, // TODO(lanxianghit): deal with the ops which have more than 1 // ACTUAL results pir::shape::SetShapeAttrForOp( - &op, shape_analysis->GetShapeOrDataForValue(op.result(0))); + &op, infer_context->GetShapeOrDataForValue(op.result(0))); } } else { const bool all_outs_static_dims = [&] { @@ -324,18 +316,36 @@ void InferSymExprForBlock(const Block& block, if (all_outs_static_dims) { for (uint32_t i = 0; i < op.num_results(); ++i) { - shape_analysis->SetShapeOrDataForValue( - op.result(i), - CreateShapeOrDataByDDim( - op.result(i).type().dyn_cast().dims())); + const Type& value_type = op.result(i).type(); + if (value_type.isa()) { + infer_context->SetShapeOrDataForValue( + op.result(i), + CreateShapeOrDataByDDim( + value_type.dyn_cast().dims())); + continue; + } + if (value_type.isa()) { + const std::vector& vec_data = + value_type.dyn_cast().data(); + symbol::TensorListShapeOrDataDimExprs shape_data_list; + for (unsigned i = 0; i < vec_data.size(); ++i) { + CHECK(vec_data[i].isa()); + const DenseTensorType& type_info = + vec_data[i].dyn_cast(); + shape_data_list.emplace_back( + CreateShapeOrDataByDDim(type_info.dims())); + } + infer_context->SetShapeOrDataForValue(op.result(i), + shape_data_list); + } } } else { PADDLE_THROW(phi::errors::Unimplemented( op.name() + " DOES NOT have InferSymbolicShapeInterface!")); } } - DebugPrintOpInfo(&op, shape_analysis); - CheckInferSymWithInferMeta(&op, shape_analysis); + DebugPrintOpInfo(&op, infer_context); + CheckInferSymWithInferMeta(&op, infer_context); } } diff --git a/paddle/pir/src/dialect/shape/utils/shape_analysis.cc b/paddle/pir/src/dialect/shape/utils/shape_analysis.cc index a573cfbf1c87da..2c71f258fbd7b4 100644 --- a/paddle/pir/src/dialect/shape/utils/shape_analysis.cc +++ 
b/paddle/pir/src/dialect/shape/utils/shape_analysis.cc @@ -16,8 +16,10 @@ #include #include "paddle/common/bfs_walker.h" #include "paddle/common/topo_walker.h" +#include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h" #include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" +#include "paddle/pir/src/core/value_impl.h" namespace pir { @@ -29,8 +31,8 @@ static std::string GetValueId(Value val) { std::to_string(val_idx); } -void ShapeConstraintIRAnalysis::Init() { - value_to_shape_or_data_.clear(); +void InferSymbolicShapeContext::Init() { + value_id_to_shape_or_data_.clear(); next_sym_idx_ = 0; constraints_manager_.SetEqualCallbackFunc( [&](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) { @@ -38,22 +40,216 @@ void ShapeConstraintIRAnalysis::Init() { }); } -const std::string ShapeConstraintIRAnalysis::GetNextSymName() { +const std::string InferSymbolicShapeContext::GetNextSymName() { return "S" + std::to_string(next_sym_idx_++); } -bool ShapeConstraintIRAnalysis::HasShapeOrDataForValue(Value val) const { - return value_to_shape_or_data_.count(val) > 0; +bool InferSymbolicShapeContext::HasShapeOrDataForValue(Value val) const { + if (!val) { + return false; + } + return value_id_to_shape_or_data_.count(val.impl()->id()) > 0; +} + +const symbol::ShapeOrDataDimExprs& +InferSymbolicShapeContext::GetShapeOrDataForValue(Value val) const { + // TODO(Hongqing-work): define a default empty ShapeOrDataDimExprs + if (!val) { + static symbol::ShapeOrDataDimExprs empty{ + symbol::TensorShapeOrDataDimExprs{}}; + return empty; + } + if (!HasShapeOrDataForValue(val)) { + PADDLE_THROW(phi::errors::Fatal( + "Fail to GetShapeOrDataForValue on InferSymbolicShape!")); + } + + return value_id_to_shape_or_data_.at(val.impl()->id()); +} + +void InferSymbolicShapeContext::SetStaticShapeForValue(Value val) { + const auto& value_type = val.type(); + if (!val || !value_type) { + PADDLE_THROW( + phi::errors::Fatal("Set static shape for null value is FOBBIDEN!")); + } + if (!IsStaticShape(val)) { + LOG(WARNING) << "Risk on SetStaticShapeForValue for contain_unknown_dim"; + } + const auto& GetStaticShapeForDenseTensorType = + [&](DenseTensorType type_info) -> symbol::TensorShapeOrDataDimExprs { + std::vector static_shape; + for (int i = 0; i < type_info.dims().size(); ++i) { + int dim = type_info.dims()[i]; + if (dim > 0) { + static_shape.emplace_back(symbol::DimExpr{dim}); + } else { + static_shape.emplace_back(GetNextSymName()); + } + } + return symbol::TensorShapeOrDataDimExprs(static_shape); + }; + + if (value_type.isa()) { + const DenseTensorType& type_info = value_type.dyn_cast(); + SetShapeOrDataForValue(val, GetStaticShapeForDenseTensorType(type_info)); + return; + } + if (value_type.isa()) { + const std::vector& vec_data = + value_type.dyn_cast().data(); + symbol::TensorListShapeOrDataDimExprs shape_data_list; + for (unsigned i = 0; i < vec_data.size(); ++i) { + if (!vec_data[i].isa()) { + PADDLE_THROW(phi::errors::Fatal( + "Set static shape ONLY SUPPORT inner type DenseTensorType!")); + } else { + const DenseTensorType& type_info = + vec_data[i].dyn_cast(); + shape_data_list.emplace_back( + GetStaticShapeForDenseTensorType(type_info)); + } + } + SetShapeOrDataForValue(val, shape_data_list); + return; + } + PADDLE_THROW(phi::errors::Fatal( + "Set static shape ONLY SUPPORT DenseTensorType and VectorType!")); +} + +void InferSymbolicShapeContext::SetShapeOrDataForValue( + Value val, const 
symbol::ShapeOrDataDimExprs& shape_or_data) { + const symbol::ShapeOrDataDimExprs& substituted_shape_or_data = + symbol::SubstituteShapeOrData(shape_or_data, substitution_pattern_); + if (!val) { + LOG(WARNING) << "Set shape or data for null value"; + return; + } + auto iter = value_id_to_shape_or_data_.find(val.impl()->id()); + if (iter == value_id_to_shape_or_data_.end()) { + value_id_to_shape_or_data_.emplace(val.impl()->id(), + substituted_shape_or_data); + } else { + iter->second = substituted_shape_or_data; + } +} + +void InferSymbolicShapeContext::AddEqualCstr(const symbol::DimExpr& lhs, + const symbol::DimExpr& rhs) { + constraints_manager_.AddEqCstr(lhs, rhs); +} + +bool InferSymbolicShapeContext::IsEqual(const symbol::DimExpr& lhs, + const symbol::DimExpr& rhs) const { + return constraints_manager_.IsEqual(lhs, rhs); +} + +void InferSymbolicShapeContext::AddGreatThanOneCstr( + const symbol::DimExpr& dim_expr) { + constraints_manager_.AddGTOneCstr(dim_expr); +} + +bool InferSymbolicShapeContext::IsGreatThanOne( + const symbol::DimExpr& dim_expr) const { + return constraints_manager_.IsGTOne(dim_expr); +} + +void InferSymbolicShapeContext::AddBroadcastableCstr( + const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) { + constraints_manager_.AddBroadcastableCstr(lhs, rhs); +} + +bool InferSymbolicShapeContext::IsBroadcastable( + const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) const { + return constraints_manager_.IsBroadcastable(lhs, rhs); +} + +namespace { + +bool CanSubstituteInShapeAnalysis(const symbol::DimExpr& lhs, + const symbol::DimExpr& rhs) { + auto CanSubstitutePredictor = symbol::Overloaded{ + [](std::int64_t lhs, const auto& rhs) { return true; }, + [](const std::string& lhs, const std::string& rhs) { return true; }, + [](const std::string& lhs, + const symbol::Broadcast& rhs) { return true; }, + [](const auto& lhs, const auto& rhs) { return false; }}; + return std::visit(CanSubstitutePredictor, lhs.variant(), rhs.variant()) || + std::visit(CanSubstitutePredictor, rhs.variant(), lhs.variant()); +} + +} // namespace + +void InferSymbolicShapeContext::SubstituteDimExpr( + const symbol::DimExpr& origin, const symbol::DimExpr& substituted) { + if (!CanSubstituteInShapeAnalysis(origin, substituted)) return; + + substitution_pattern_[origin] = substituted; + for (auto it = substitution_pattern_.begin(); + it != substitution_pattern_.end(); + it++) { + if (it->second == origin) it->second = substituted; + } + + for (auto it = value_id_to_shape_or_data_.begin(); + it != value_id_to_shape_or_data_.end(); + it++) { + const symbol::ShapeOrDataDimExprs& substituted_shape_or_data = + symbol::SubstituteShapeOrData(it->second, substitution_pattern_); + it->second = substituted_shape_or_data; + } +} + +void InferSymbolicShapeContext::PrintShapeOrDatas() const { + LOG(INFO) << "shape analysis : @" << this + << " value_id_to_shape_or_data_ size : " + << value_id_to_shape_or_data_.size(); + LOG(INFO) << "----------- ShapeOrData for Values ------------"; + for (const auto& [value_id, shape_or_data] : value_id_to_shape_or_data_) { + LOG(INFO) << value_id << " : " << shape_or_data; + } +} + +void ShapeConstraintIRAnalysis::Init() { context_.Init(); } + +const std::string ShapeConstraintIRAnalysis::GetNextSymName() { + return context_.GetNextSymName(); +} + +void ShapeConstraintIRAnalysis::SetStaticShapeForValue(Value val) { + context_.SetStaticShapeForValue(val); } void ShapeConstraintIRAnalysis::InferShapeOrDataForValue(Value val) { std::unordered_set subgraph_ops; std::vector 
start_ops; + const auto& GetRealOperandSource = [&](Operation* op) -> std::vector { + if (op->num_regions() == 0) { + return op->operands_source(); + } else { + std::vector ret; + for (uint32_t i = 0; i < op->num_regions(); i++) { + for (auto& block : op->region(i)) { + for (auto& sub_op : block) { + for (auto& operand : sub_op.operands_source()) { + ret.emplace_back(operand); + } + } + } + } + return ret; + } + }; + const auto& VisitNotInferedInputOp = [&](Operation* op, const std::function& Visit) { - for (auto& operand : op->operands_source()) { - if (operand.impl() && !HasShapeOrDataForValue(operand)) { - Visit(operand.defining_op()); + for (auto& operand : GetRealOperandSource(op)) { + if (operand.impl() && !context_.HasShapeOrDataForValue(operand)) { + if (!operand.defining_op()) { + SetStaticShapeForValue(operand); + } else { + Visit(operand.defining_op()); + } } } }; @@ -62,9 +258,13 @@ void ShapeConstraintIRAnalysis::InferShapeOrDataForValue(Value val) { build_subgraph_walker(val.defining_op(), [&](Operation* op) { subgraph_ops.insert(op); bool has_prev_op = false; - for (auto& operand : op->operands_source()) { - if (operand.impl() && !HasShapeOrDataForValue(operand)) { - has_prev_op = true; + for (auto& operand : GetRealOperandSource(op)) { + if (operand.impl() && !context_.HasShapeOrDataForValue(operand)) { + if (!operand.defining_op()) { + SetStaticShapeForValue(operand); + } else { + has_prev_op = true; + } } } if (!has_prev_op) { @@ -74,7 +274,7 @@ void ShapeConstraintIRAnalysis::InferShapeOrDataForValue(Value val) { const auto& VisitSubgraphInputOp = [&](Operation* op, const std::function& Visit) { - for (auto& operand : op->operands_source()) { + for (auto& operand : GetRealOperandSource(op)) { if (operand.impl() && subgraph_ops.count(operand.defining_op())) { Visit(operand.defining_op()); } @@ -86,8 +286,13 @@ void ShapeConstraintIRAnalysis::InferShapeOrDataForValue(Value val) { for (auto iter = op->result(i).use_begin(); iter != op->result(i).use_end(); ++iter) { - if (subgraph_ops.count(iter->owner())) { - Visit(iter->owner()); + auto parent_op = iter->owner(); + while (parent_op) { + if (subgraph_ops.count(parent_op)) { + Visit(parent_op); + break; + } + parent_op = parent_op->GetParentOp(); } } } @@ -99,17 +304,26 @@ void ShapeConstraintIRAnalysis::InferShapeOrDataForValue(Value val) { auto infer_symbolic_shape_interface = op->dyn_cast(); if (infer_symbolic_shape_interface) { - infer_symbolic_shape_interface.InferSymbolicShape(this); + infer_symbolic_shape_interface.InferSymbolicShape(&context_); for (auto& result_value : op->results()) { - if (result_value && (!HasShapeOrDataForValue(result_value))) { + if (result_value && (!context_.HasShapeOrDataForValue(result_value))) { PADDLE_THROW(phi::errors::Fatal(op->name() + " HAS ERROR on InferSymbolicShape!")); } } } else { - PADDLE_THROW(phi::errors::Unimplemented( - val.defining_op()->name() + - " DOES NOT have InferSymbolicShapeInterface!")); + // TODO(Hongqing-work): throw it after the shape analysis reconstruct + // is done. 
+ // PADDLE_THROW(phi::errors::Unimplemented( + // val.defining_op()->name() + + // " DOES NOT have InferSymbolicShapeInterface!")); + LOG(WARNING) << op->name() + << " DOES NOT have InferSymbolicShapeInterface!"; + for (auto& result_value : op->results()) { + if (result_value && (!context_.HasShapeOrDataForValue(result_value))) { + SetStaticShapeForValue(result_value); + } + } } }); } @@ -122,66 +336,42 @@ ShapeConstraintIRAnalysis::GetShapeOrDataForValue(Value val) { symbol::TensorShapeOrDataDimExprs{}}; return empty; } - if (!HasShapeOrDataForValue(val)) { + if (!context_.HasShapeOrDataForValue(val)) { // backtrack to infer shape from defining op - InferShapeOrDataForValue(val); + if (!val.defining_op()) { + SetStaticShapeForValue(val); + } else { + VLOG(3) << "InferShapeOrDataForValue, defining_op: " + << val.defining_op()->name(); + InferShapeOrDataForValue(val); + } } - return value_to_shape_or_data_.at(val); + return context_.GetShapeOrDataForValue(val); } void ShapeConstraintIRAnalysis::SetShapeOrDataForValue( Value val, const symbol::ShapeOrDataDimExprs& shape_or_data) { - const symbol::ShapeOrDataDimExprs& substituted_shape_or_data = - symbol::SubstituteShapeOrData(shape_or_data, substitution_pattern_); - auto iter = value_to_shape_or_data_.find(val); - if (iter == value_to_shape_or_data_.end()) { - value_to_shape_or_data_.emplace(val, substituted_shape_or_data); - } else { - iter->second = substituted_shape_or_data; - } -} - -void ShapeConstraintIRAnalysis::AddEqualCstr(const symbol::DimExpr& lhs, - const symbol::DimExpr& rhs) { - constraints_manager_.AddEqCstr(lhs, rhs); + context_.SetShapeOrDataForValue(val, shape_or_data); } bool ShapeConstraintIRAnalysis::IsEqual(const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) const { - return constraints_manager_.IsEqual(lhs, rhs); -} - -void ShapeConstraintIRAnalysis::AddGreatThanOneCstr( - const symbol::DimExpr& dim_expr) { - constraints_manager_.AddGTOneCstr(dim_expr); + return context_.IsEqual(lhs, rhs); } bool ShapeConstraintIRAnalysis::IsGreatThanOne( const symbol::DimExpr& dim_expr) const { - return constraints_manager_.IsGTOne(dim_expr); -} - -void ShapeConstraintIRAnalysis::AddBroadcastableCstr( - const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) { - constraints_manager_.AddBroadcastableCstr(lhs, rhs); + return context_.IsGreatThanOne(dim_expr); } bool ShapeConstraintIRAnalysis::IsBroadcastable( const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) const { - return constraints_manager_.IsBroadcastable(lhs, rhs); + return context_.IsBroadcastable(lhs, rhs); } void ShapeConstraintIRAnalysis::PrintShapeOrDatas() const { - LOG(INFO) << "shape analysis : @" << this - << " value_to_shape_or_data_ size : " - << value_to_shape_or_data_.size(); - LOG(INFO) << "----------- ShapeOrData for Values ------------"; - for (const auto& [value, shape_or_data] : value_to_shape_or_data_) { - if (value) { - LOG(INFO) << GetValueId(value) << " : " << shape_or_data; - } - } + context_.PrintShapeOrDatas(); } // Currently, we only support TensorShapeOrDataDimExprs but not @@ -189,10 +379,6 @@ void ShapeConstraintIRAnalysis::PrintShapeOrDatas() const { bool ShapeConstraintIRAnalysis::IsShapeEqual(Value lhs, Value rhs) { if (lhs == rhs) return true; - if (!HasShapeOrDataForValue(lhs) || !HasShapeOrDataForValue(rhs)) { - return false; - } - auto lhs_type = lhs.type().dyn_cast(); auto rhs_type = rhs.type().dyn_cast(); @@ -245,11 +431,6 @@ bool ShapeConstraintIRAnalysis::IsProductEqual( return lhs_product == rhs_product; } - // For dynamic shape 
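To summarize the query path introduced above: GetShapeOrDataForValue is now a lazy, memoized lookup; on a cache miss it either seeds a static shape (for values with no defining op) or runs inference over the producing subgraph and then reads the cache. A toy sketch of just the memoization, with the subgraph walkers elided:

    #include <cstdint>
    #include <functional>
    #include <string>
    #include <unordered_map>

    struct ToyValue {
      uint64_t id;
      int defining_op;  // negative: no defining op (e.g. a block argument)
    };

    class ToyShapeCache {
     public:
      // Lazy, memoized lookup: infer (or seed a static shape) only on a miss.
      const std::string& GetShapeOrData(
          const ToyValue& v,
          const std::function<void(int)>& InferProducer,
          const std::function<std::string(uint64_t)>& StaticShapeOf) {
        if (cache_.count(v.id) == 0) {
          if (v.defining_op < 0) {
            cache_[v.id] = StaticShapeOf(v.id);
          } else {
            InferProducer(v.defining_op);  // expected to call Set() for v.id
          }
        }
        return cache_.at(v.id);  // throws if the producer never filled the entry
      }
      void Set(uint64_t id, std::string shape) { cache_[id] = std::move(shape); }

     private:
      std::unordered_map<uint64_t, std::string> cache_;
    };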
- if (!HasShapeOrDataForValue(lhs) || !HasShapeOrDataForValue(rhs)) { - return false; - } - auto lhs_shape_data = GetShapeOrDataForValue(lhs); auto rhs_shape_data = GetShapeOrDataForValue(rhs); @@ -338,49 +519,13 @@ symbol::DimExpr ShapeConstraintIRAnalysis::GetProductDimExpr( return symbol::SimplifyDimExpr(product); } -namespace { - -bool CanSubstituteInShapeAnalysis(const symbol::DimExpr& lhs, - const symbol::DimExpr& rhs) { - auto CanSubstitutePredictor = symbol::Overloaded{ - [](std::int64_t lhs, const auto& rhs) { return true; }, - [](const std::string& lhs, const std::string& rhs) { return true; }, - [](const std::string& lhs, - const symbol::Broadcast& rhs) { return true; }, - [](const auto& lhs, const auto& rhs) { return false; }}; - return std::visit(CanSubstitutePredictor, lhs.variant(), rhs.variant()) || - std::visit(CanSubstitutePredictor, rhs.variant(), lhs.variant()); -} - -} // namespace - -void ShapeConstraintIRAnalysis::SubstituteDimExpr( - const symbol::DimExpr& origin, const symbol::DimExpr& substituted) { - if (!CanSubstituteInShapeAnalysis(origin, substituted)) return; - - substitution_pattern_[origin] = substituted; - for (auto it = substitution_pattern_.begin(); - it != substitution_pattern_.end(); - it++) { - if (it->second == origin) it->second = substituted; - } - - for (auto it = value_to_shape_or_data_.begin(); - it != value_to_shape_or_data_.end(); - it++) { - const symbol::ShapeOrDataDimExprs& substituted_shape_or_data = - symbol::SubstituteShapeOrData(it->second, substitution_pattern_); - SetShapeOrDataForValue(it->first, substituted_shape_or_data); - } -} - pir::PrintHooks ShapeConstraintIRAnalysis::PrintHook() { pir::PrintHooks print_hook; print_hook.op_print_hook = [&](Operation* op, IrPrinter& printer) { printer.IrPrinter::PrintOperation(op); printer.os << " { "; for (uint32_t i = 0; i < op->num_results(); ++i) { - if (this->HasShapeOrDataForValue(op->result(i))) { + if (context_.HasShapeOrDataForValue(op->result(i))) { printer.os << "(" << this->GetShapeOrDataForValue(op->result(i)) << ")"; } else { printer.os << "()"; @@ -413,4 +558,33 @@ ShapeConstraintIRAnalysis& ShapeAnalysisManager::Get(pir::Program* program) { return *it->second; } +bool IsStaticShape(const Value& value) { + const auto& value_type = value.type(); + if (!value || !value_type) { + return false; + } + if (value_type.isa()) { + return !::common::contain_unknown_dim( + value_type.dyn_cast().dims()); + } + if (value_type.isa()) { + bool is_static = true; + auto vec_data = value_type.dyn_cast().data(); + for (unsigned i = 0; i < vec_data.size(); ++i) { + if (!vec_data[i].isa()) { + is_static = false; + break; + } else { + is_static = !::common::contain_unknown_dim( + vec_data[i].dyn_cast().dims()); + if (!is_static) { + break; + } + } + } + return is_static; + } + return false; +} + } // namespace pir diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 5fd5ae50206d08..1993281f158ab8 100644 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -3847,12 +3847,14 @@ function run_setup(){ INFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo} fi - pip install -U PyGithub - python ${PADDLE_ROOT}/tools/check_only_change_python_files.py - if [ -f "${PADDLE_ROOT}/build/only_change_python_file.txt" ];then - export WITH_CPP_TEST=OFF - else - export WITH_CPP_TEST=ON + if [ -z "${WITH_CPP_TEST}" ] && [ "${WITH_TESTING}" == "ON" ];then + pip install PyGithub + python 
${PADDLE_ROOT}/tools/check_only_change_python_files.py + if [ -f "${PADDLE_ROOT}/build/only_change_python_file.txt" ];then + export WITH_CPP_TEST=OFF + else + export WITH_CPP_TEST=ON + fi fi distibuted_flag=${WITH_DISTRIBUTE:-OFF} gloo_flag=${distibuted_flag} @@ -3872,13 +3874,13 @@ function run_setup(){ echo "if you use setup.py to compile,please export envs as following in /paddle ..." cat << EOF ======================================== - export CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release} WITH_GPU=${WITH_GPU:-OFF} WITH_SHARED_PHI=${WITH_SHARED_PHI:-OFF} WITH_TENSORRT=${WITH_TENSORRT:-ON} WITH_ROCM=${WITH_ROCM:-OFF} WITH_CINN=${WITH_CINN:-OFF} WITH_DISTRIBUTE=${distibuted_flag} WITH_MKL=${WITH_MKL:-ON} WITH_AVX=${WITH_AVX:-OFF} CUDA_ARCH_NAME=${CUDA_ARCH_NAME:-All} NEW_RELEASE_PYPI=${NEW_RELEASE_PYPI:-OFF} NEW_RELEASE_ALL=${NEW_RELEASE_ALL:-OFF} NEW_RELEASE_JIT=${NEW_RELEASE_JIT:-OFF} WITH_PYTHON=${WITH_PYTHON:-ON} CUDNN_ROOT=/usr/ WITH_TESTING=${WITH_TESTING:-ON} WITH_COVERAGE=${WITH_COVERAGE:-OFF} WITH_INCREMENTAL_COVERAGE=${WITH_INCREMENTAL_COVERAGE:-OFF} CMAKE_MODULE_PATH=/opt/rocm/hip/cmake CMAKE_EXPORT_COMPILE_COMMANDS=ON WITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} INFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR} PY_VERSION=${PY_VERSION:-3.8} CMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} WITH_PSCORE=${pscore_flag} WITH_PSLIB=${pslib_flag} WITH_GLOO=${gloo_flag} WITH_XPU=${WITH_XPU:-OFF} WITH_IPU=${WITH_IPU:-OFF} XPU_SDK_ROOT=${XPU_SDK_ROOT:-""} WITH_XPU_BKCL=${WITH_XPU_BKCL:-OFF} WITH_ARM=${WITH_ARM:-OFF} WITH_STRIP=${WITH_STRIP:-ON} ON_INFER=${ON_INFER:-OFF} WITH_HETERPS=${WITH_HETERPS:-OFF} WITH_GPU_GRAPH=${WITH_GPU_GRAPH:-OFF} CUDA_ARCH_BIN=${CUDA_ARCH_BIN} WITH_RECORD_BUILDTIME=${WITH_RECORD_BUILDTIME:-OFF} WITH_UNITY_BUILD=${WITH_UNITY_BUILD:-OFF} WITH_ONNXRUNTIME=${WITH_ONNXRUNTIME:-OFF} WITH_CUDNN_FRONTEND=${WITH_CUDNN_FRONTEND:-OFF} + export CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release} WITH_GPU=${WITH_GPU:-OFF} WITH_SHARED_PHI=${WITH_SHARED_PHI:-OFF} WITH_TENSORRT=${WITH_TENSORRT:-ON} WITH_ROCM=${WITH_ROCM:-OFF} WITH_CINN=${WITH_CINN:-OFF} WITH_DISTRIBUTE=${distibuted_flag} WITH_MKL=${WITH_MKL:-ON} WITH_AVX=${WITH_AVX:-OFF} CUDA_ARCH_NAME=${CUDA_ARCH_NAME:-All} NEW_RELEASE_PYPI=${NEW_RELEASE_PYPI:-OFF} NEW_RELEASE_ALL=${NEW_RELEASE_ALL:-OFF} NEW_RELEASE_JIT=${NEW_RELEASE_JIT:-OFF} WITH_PYTHON=${WITH_PYTHON:-ON} CUDNN_ROOT=/usr/ WITH_TESTING=${WITH_TESTING:-ON} WITH_COVERAGE=${WITH_COVERAGE:-OFF} WITH_INCREMENTAL_COVERAGE=${WITH_INCREMENTAL_COVERAGE:-OFF} CMAKE_MODULE_PATH=/opt/rocm/hip/cmake CMAKE_EXPORT_COMPILE_COMMANDS=ON WITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} INFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR} PY_VERSION=${PY_VERSION:-3.8} CMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} WITH_PSCORE=${pscore_flag} WITH_PSLIB=${pslib_flag} WITH_GLOO=${gloo_flag} WITH_XPU=${WITH_XPU:-OFF} WITH_IPU=${WITH_IPU:-OFF} XPU_SDK_ROOT=${XPU_SDK_ROOT:-""} WITH_XPU_BKCL=${WITH_XPU_BKCL:-OFF} WITH_ARM=${WITH_ARM:-OFF} WITH_STRIP=${WITH_STRIP:-ON} ON_INFER=${ON_INFER:-OFF} WITH_HETERPS=${WITH_HETERPS:-OFF} WITH_GPU_GRAPH=${WITH_GPU_GRAPH:-OFF} CUDA_ARCH_BIN=${CUDA_ARCH_BIN} WITH_RECORD_BUILDTIME=${WITH_RECORD_BUILDTIME:-OFF} WITH_UNITY_BUILD=${WITH_UNITY_BUILD:-OFF} WITH_ONNXRUNTIME=${WITH_ONNXRUNTIME:-OFF} WITH_CUDNN_FRONTEND=${WITH_CUDNN_FRONTEND:-OFF} -DWITH_CPP_TEST=${WITH_CPP_TEST:-OFF} ======================================== EOF echo "if you use cmake to compile,please Configuring cmake in /paddle/build ..." 
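To make the WITH_CPP_TEST change above concrete: only when WITH_TESTING is ON and WITH_CPP_TEST is not already set does the script run check_only_change_python_files.py; if that tool writes build/only_change_python_file.txt (the change appears to touch Python files only), C++ tests are disabled with WITH_CPP_TEST=OFF, otherwise they stay ON. A value exported before invoking the script is now respected instead of being overwritten.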
cat < 0: + func_name = parsing_names[1] + else: + continue + else: + continue + + var_dict.update( + _load_pir_persistable_vars( + model_path, programs[func_name], file_name + ) + ) return var_dict @@ -290,6 +328,7 @@ def _run_dygraph(instance, input, program_holder): input_tensors.append(tensor) persistable_tensors = [] + for var_name in program_holder.persistable_names: dy_var_name = instance._persistable_var_name_dict[var_name] if dy_var_name in instance._parameters: @@ -558,3 +597,47 @@ def train(self): def eval(self): self._is_test = True self.training = False + + def _get_program_holder(self, method_name='forward'): + program_holder = self._program_holder_dict.get(method_name, None) + if program_holder is None: + raise ValueError( + "The method `%s` does not exist in loaded PirTranslatedLayer." + % method_name + ) + return program_holder + + def _input_spec(self, method_name='forward'): + # 1. get program holder + program_holder = self._get_program_holder(method_name) + + # 2. build input spec by input desc + input_spec = [] + for var in program_holder.input_vars: + spec = paddle.static.InputSpec( + shape=var.shape, + dtype=var.dtype, + name=var.name, + ) + input_spec.append(spec) + + return input_spec + + def _output_spec(self, method_name='forward'): + # 1. get program holder + program_holder = self._get_program_holder(method_name) + + # 2. build output spec by output desc + output_spec = [] + for var in program_holder.output_vars: + # NOTE(chenweihang): InputSpec describes a tensor, not just input. + # Maybe the name is not good enough. Here we use InputSpec to + # construct the description of Output tensor + spec = paddle.static.InputSpec( + shape=var.shape, + dtype=var.dtype, + name=var.name, + ) + output_spec.append(spec) + + return output_spec diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index a9290887533764..7722ffb437389b 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -87,6 +87,8 @@ ) from .flash_attention import ( flash_attention_with_sparse_mask, + flash_attn_qkvpacked, + flash_attn_varlen_qkvpacked, scaled_dot_product_attention, sdp_kernel, # noqa: F401 ) @@ -279,5 +281,7 @@ 'gaussian_nll_loss', 'scaled_dot_product_attention', 'flash_attention_with_sparse_mask', + 'flash_attn_qkvpacked', + 'flash_attn_varlen_qkvpacked', 'group_norm', ] diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index e82684c32981de..84c7882a7151da 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -300,6 +300,158 @@ def flash_attention( ) +def flash_attn_qkvpacked( + qkv, + dropout=0.0, + causal=False, + return_softmax=False, + *, + fixed_seed_offset=None, + rng_name="", + training=True, + name=None, +): + r""" + The equation is: + + .. math:: + + result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V + + where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module. + The dimensions of the three parameters are the same. + ``d`` represents the size of the last dimension of the three parameters. + + Warning: + This API only supports inputs with dtype float16 and bfloat16. + Don't call this API if flash_attn is not supported. + + Args: + qkv(Tensor): The query/key/value packed tensor in the Attention module. + 5-D tensor with shape: + [batchsize, seqlen , num_heads/num_heads_k + 2, num_heads_k, head_dim]. + The dtype can be float16 or bfloat16. 
+ dropout(float): The dropout ratio. + causal(bool): Whether enable causal mode. + return_softmax(bool): Whether to return softmax. + fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask. + training(bool): Whether it is in the training phase. + rng_name(str): The name to select Generator. + name(str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to + :ref:`api_guide_Name`. + + Returns: + - out(Tensor). The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. + - softmax(Tensor). The softmax tensor. None if return_softmax is False. + + Examples: + .. code-block:: python + + >>> # doctest: +SKIP('flash_attn need A100 compile') + >>> import paddle + + >>> paddle.seed(2023) + >>> q = paddle.rand((1, 128, 2, 16)) + >>> qkv = paddle.stack([q, q, q], axis=2) + >>> output = paddle.nn.functional.flash_attn_qkvpacked(qkv, 0.9, False, False) + >>> print(output) + (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[[0.34992966, 0.34456208, 0.45826620, ..., 0.39883569, + 0.42132431, 0.39157745], + [0.76687670, 0.65837246, 0.69117945, ..., 0.82817286, + 0.76690865, 0.71485823]], + ..., + [[0.71662450, 0.57275224, 0.57053083, ..., 0.48108247, + 0.53336465, 0.54540104], + [0.59137970, 0.51350880, 0.50449550, ..., 0.38860250, + 0.40526697, 0.60541755]]]]), None) + >>> # doctest: -SKIP + + """ + head_dim = qkv.shape[-1] + sdp_func_name = _select_sdp(head_dim) + + if sdp_func_name == "flash_attn": + if in_dynamic_or_pir_mode(): + ( + result_attention, + result_softmax, + _, + _, + ) = _C_ops.flash_attn_qkvpacked( + qkv, + fixed_seed_offset, + None, + dropout, + causal, + return_softmax, + not training, + rng_name, + ) + return result_attention, result_softmax if return_softmax else None + + helper = LayerHelper('flash_attn_qkvpacked', **locals()) + dtype = helper.input_dtype(input_param_name='qkv') + out = helper.create_variable_for_type_inference(dtype) + softmax = helper.create_variable_for_type_inference(dtype) + softmax_lse = helper.create_variable_for_type_inference(paddle.float32) + seed_offset = helper.create_variable_for_type_inference(paddle.int64) + inputs = { + 'qkv': qkv, + 'fixed_seed_offset': fixed_seed_offset, + } + outputs = { + 'out': out, + 'softmax': softmax, + 'softmax_lse': softmax_lse, + 'seed_offset': seed_offset, + } + helper.append_op( + type='flash_attn_qkvpacked', + inputs=inputs, + outputs=outputs, + attrs={ + 'dropout': dropout, + 'causal': causal, + 'return_softmax': return_softmax, + 'is_test': not training, + 'rng_name': rng_name, + }, + ) + return out, softmax if return_softmax else None + else: + # don't call qkvpacked if not using flash_attn + query = qkv[:, :, :-2].reshape([0, 0, -1, qkv.shape[-1]]) + key = qkv[:, :, -2] + value = qkv[:, :, -1] + if sdp_func_name == "mem_efficient": + from paddle.incubate.nn.memory_efficient_attention import ( + memory_efficient_attention, + ) + + output = memory_efficient_attention( + query, + key, + value, + attn_bias=None, + p=dropout, + scale=None, + training=training, + ) + return output, None + else: + return _math_attention( + query, + key, + value, + dropout_rate=dropout, + causal=causal, + return_softmax=return_softmax, + training=training, + ) + + def flash_attn_unpadded( query, key, @@ -439,6 +591,134 @@ def flash_attn_unpadded( return out, softmax if return_softmax else None +def flash_attn_varlen_qkvpacked( + 
qkv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + scale, + dropout=0.0, + causal=False, + return_softmax=False, + fixed_seed_offset=None, + rng_name="", + varlen_padded=True, + training=True, + name=None, +): + r""" + The equation is: + + .. math:: + + result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V + + where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module. + The dimensions of the three parameters are the same. + ``d`` represents the size of the last dimension of the three parameters. + + Warning: + This API only supports inputs with dtype float16 and bfloat16. + + Args: + qkv(Tensor): The padded query/key/value packed tensor in the Attention module. The padding part won't be computed + 4-D tensor with shape: + [total_seq_len, num_heads/num_heads_k + 2, num_heads_k, head_dim]. + The dtype can be float16 or bfloat16. + cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch, + used to index query. + cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch, + used to index key and value. + max_seqlen_q(int): Maximum sequence length of query in the batch. Note it's the padding length, not the max actual seqlen + max_seqlen_k(int): Maximum sequence length of key/value in the batch. + scale(float): The scaling of QK^T before applying softmax. + dropout(float): The dropout ratio. + causal(bool): Whether enable causal mode. + return_softmax(bool): Whether to return softmax. + fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask. + rng_name(str): The name to select Generator. + training(bool): Whether it is in the training phase. + name(str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to + :ref:`api_guide_Name`. + + Returns: + - out(Tensor). The attention tensor. The tensor is padded by zeros. 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. + - softmax(Tensor). The softmax tensor. None if return_softmax is False. + + Examples: + .. 
code-block:: python + + >>> # doctest: +SKIP('flash_attn need A100 compile') + >>> import paddle + >>> paddle.seed(2023) + >>> q = paddle.rand((2, 128, 8, 16), dtype='float16') + >>> cu = paddle.arange(0, 384, 128, dtype='int32') + >>> qq = paddle.reshape(q, [256, 8, 16]) + >>> qkv = paddle.stack([qq, qq, qq], axis=2) + >>> output = paddle.nn.functional.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False) + >>> # doctest: -SKIP + + """ + if in_dynamic_mode(): + ( + result_attention, + result_softmax, + ) = _C_ops.flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + None, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + not training, + rng_name, + varlen_padded, + ) + return result_attention, result_softmax if return_softmax else None + + helper = LayerHelper('flash_attn_varlen_qkvpacked', **locals()) + dtype = helper.input_dtype(input_param_name='qkv') + out = helper.create_variable_for_type_inference(dtype) + softmax = helper.create_variable_for_type_inference(dtype) + softmax_lse = helper.create_variable_for_type_inference(paddle.float32) + seed_offset = helper.create_variable_for_type_inference(paddle.int64) + inputs = { + 'qkv': qkv, + 'cu_seqlens_q': cu_seqlens_q, + 'cu_seqlens_k': cu_seqlens_k, + 'fixed_seed_offset': fixed_seed_offset, + } + outputs = { + 'out': out, + 'softmax': softmax, + 'softmax_lse': softmax_lse, + 'seed_offset': seed_offset, + } + helper.append_op( + type='flash_attn_varlen_qkvpacked', + inputs=inputs, + outputs=outputs, + attrs={ + 'max_seqlen_q': max_seqlen_q, + 'max_seqlen_k': max_seqlen_k, + 'scale': scale, + 'dropout': dropout, + 'causal': causal, + 'return_softmax': return_softmax, + 'is_test': not training, + 'rng_name': rng_name, + }, + ) + return out, softmax if return_softmax else None + + def scaled_dot_product_attention( query, key, diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 82a071064e3be5..8112371a351d02 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -661,7 +661,7 @@ def group_norm( Default: None. bias(Tensor, optional): The bias Tensor of group_norm, with shape: attr:`[num_channels]`. Default: None. - data_format(str, optional): Specify the input data format. Only NCHW is supported. Default: NCHW. + data_format(str, optional): Specify the input data format. Support "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW". name(str, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Returns: @@ -702,9 +702,11 @@ def group_norm( [[-1.34163547, -0.44721183], [ 0.44721183, 1.34163547]]]]) """ - if data_format not in ['NCHW', 'NHWC']: + if data_format not in ['NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']: raise ValueError("unsupported data layout:" + data_format) + data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC' + if in_dynamic_or_pir_mode(): return _C_ops.group_norm( x, diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 1b71fb426f5e01..bb436fe798fdaa 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -426,11 +426,11 @@ class GroupNorm(Layer): division by zero. Default: 1e-05. weight_attr(ParamAttr|bool, optional): The parameter attribute for the learnable scale :math:`g`. If it is set to False, no scale will be added to the output units. - If it is set to None, the bias is initialized one. Default: None. 
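A minimal sketch (the helper name is illustrative, not part of the patch) of the layout handling that paddle.nn.functional.group_norm gains above: every supported channels-first layout is canonicalized to 'NCHW' and every channels-last layout to 'NHWC' before the kernel is called.

    def _canonical_group_norm_layout(data_format):
        # Mirrors the check and canonicalization added to group_norm above.
        if data_format not in ['NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']:
            raise ValueError("unsupported data layout:" + data_format)
        return 'NCHW' if data_format[1] == 'C' else 'NHWC'

    assert _canonical_group_norm_layout('NCL') == 'NCHW'
    assert _canonical_group_norm_layout('NDHWC') == 'NHWC'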
+ If it is set to None, the scale is initialized one. Default: None. bias_attr(ParamAttr|bool, optional): The parameter attribute for the learnable bias :math:`b`. If it is set to False, no bias will be added to the output units. If it is set to None, the bias is initialized zero. Default: None. - data_format(str, optional): Specify the input data format. Only NCHW is supported. Default: NCHW. + data_format(str, optional): Specify the input data format. Support "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW". name(str, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: @@ -493,8 +493,10 @@ def __init__( self._epsilon = epsilon self._num_channels = num_channels self._num_groups = num_groups - if data_format not in ['NCHW', 'NHWC']: + if data_format not in ['NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']: raise ValueError("unsupported data layout:" + data_format) + + data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC' self._data_format = data_format param_shape = [self._num_channels] diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 0d51987835cab5..75ed3094a29588 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -32,8 +32,6 @@ __all__ = [] -GRAD_TYPES = [int(paddle.float32), int(paddle.float16), int(paddle.bfloat16)] - class Adam(Optimizer): r""" @@ -570,13 +568,16 @@ def _append_optimize_multi_tensor_op( params = [pair[0] for pair in parameters_and_grads] grads_types = core.eager.get_grads_types(params) for index, tp in enumerate(grads_types): - if tp == GRAD_TYPES[0]: + if tp == core.DataType.FLOAT32: grad_dict['FP32_LODTensor'].append( parameters_and_grads[index][1] ) lr = self._create_param_lr(parameters_and_grads[index]) lr_dict['FP32_LODTensor'].append(lr) - elif tp == GRAD_TYPES[1] or tp == GRAD_TYPES[2]: + elif ( + tp == core.DataType.FLOAT16 + or tp == core.DataType.BFLOAT16 + ): grad_dict['FP16_LODTensor'].append( parameters_and_grads[index][1] ) diff --git a/python/paddle/pir_utils.py b/python/paddle/pir_utils.py index d2a93f2b8e5576..2c4b8bf28d67d7 100644 --- a/python/paddle/pir_utils.py +++ b/python/paddle/pir_utils.py @@ -127,6 +127,61 @@ def __exit__(self, exc_type, exc_val, exc_tb): _switch_to_pir_() +class DygraphPirGuard: + def __enter__(self): + self.old_flag = paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ + "FLAGS_enable_pir_api" + ] + if not self.old_flag: + paddle.framework.set_flags({"FLAGS_enable_pir_api": True}) + paddle.base.framework.global_var._use_pir_api_ = True + bind_datatype() + self._switch_to_pir() + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.old_flag: + paddle.framework.set_flags({"FLAGS_enable_pir_api": False}) + paddle.base.framework.global_var._use_pir_api_ = False + bind_vartype() + self._switch_to_old_ir() + + def _switch_to_pir(self): + if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ + "FLAGS_enable_pir_api" + ]: + _switch_to_pir_() + + def _switch_to_old_ir(self): + if not paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ + "FLAGS_enable_pir_api" + ]: + _switch_to_old_ir_() + else: + raise RuntimeError( + "IrGuard._switch_to_old_ir only work when paddle.framework.in_pir_mode() is false, \ + please set FLAGS_enable_pir_api = false" + ) + + +class DygraphOldIrGuard: + def __enter__(self): + self.old_flag = paddle.base.framework.get_flags("FLAGS_enable_pir_api")[ + "FLAGS_enable_pir_api" + ] + if self.old_flag: + 
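For context, the multi-tensor Adam path above now compares gradient types against core.DataType values instead of the removed GRAD_TYPES list. Conceptually it buckets gradients as in the sketch below, which uses only public dtypes rather than the internal get_grads_types helper.

    import paddle

    def bucket_grads_by_dtype(grads):
        # FP32 gradients form one bucket; FP16 and BF16 share the other,
        # matching the FP32_LODTensor / FP16_LODTensor split in Adam.
        fp32, low_precision = [], []
        for g in grads:
            if g.dtype == paddle.float32:
                fp32.append(g)
            elif g.dtype in (paddle.float16, paddle.bfloat16):
                low_precision.append(g)
        return fp32, low_precision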
paddle.framework.set_flags({"FLAGS_enable_pir_api": False}) + paddle.base.framework.global_var._use_pir_api_ = False + bind_vartype() + _switch_to_old_ir_() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.old_flag: + paddle.framework.set_flags({"FLAGS_enable_pir_api": True}) + paddle.base.framework.global_var._use_pir_api_ = True + bind_datatype() + _switch_to_pir_() + + def test_with_pir_api(func): @wraps(func) def impl(*args, **kwargs): @@ -145,3 +200,15 @@ def impl(*args, **kwargs): func(*args, **kwargs) return impl + + +def test_with_dygraph_pir(func): + @wraps(func) + def impl(*args, **kwargs): + with DygraphOldIrGuard(): + func(*args, **kwargs) + + with DygraphPirGuard(): + func(*args, **kwargs) + + return impl diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index 4cc2d1b9187459..d2c98d859a4e2f 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -251,7 +251,7 @@ def from_tensor(cls, tensor, name=None): InputSpec(shape=(2, 2), dtype=paddle.float32, name=x, stop_gradient=False) """ - if isinstance(tensor, (Variable, core.eager.Tensor)): + if isinstance(tensor, (Variable, core.eager.Tensor, paddle.pir.Value)): return cls(tensor.shape, tensor.dtype, name or tensor.name) else: raise ValueError( diff --git a/python/paddle/static/io_utils.py b/python/paddle/static/io_utils.py index 946d978c8a867e..5c77b032e7c194 100644 --- a/python/paddle/static/io_utils.py +++ b/python/paddle/static/io_utils.py @@ -21,7 +21,6 @@ from paddle.base import ( CompiledProgram, Variable, - default_main_program, ) @@ -66,7 +65,7 @@ def _get_valid_program(program=None): return default main program if program is None. """ if program is None: - program = default_main_program() + program = paddle.static.default_main_program() elif isinstance(program, CompiledProgram): program = program._program if program is None: diff --git a/python/paddle/static/pir_io.py b/python/paddle/static/pir_io.py index 68bc0d83be1ff3..45ba1491a7f9c0 100644 --- a/python/paddle/static/pir_io.py +++ b/python/paddle/static/pir_io.py @@ -36,7 +36,6 @@ from paddle.base.executor import Executor, global_scope from paddle.base.framework import ( dygraph_not_support, - process_type_promotion, static_only, ) from paddle.base.log_helper import get_logger @@ -215,7 +214,7 @@ def normalize_pir_program(program, feed_vars, fetch_vars, **kwargs): uniq_fetch_vars = [] for var in fetch_vars: if var.dtype != paddle.bool: - var_ = paddle.scale(fetch_vars[0], 1.0) + var_ = paddle.scale(var, 1.0) uniq_fetch_vars.append(var_) fetch_vars = uniq_fetch_vars @@ -652,12 +651,6 @@ def save_pir_inference_model( _check_vars('fetch_vars', fetch_vars) program = _get_valid_program(kwargs.get('program', None)) - - # do type promotion - program = process_type_promotion(program) - - clip_extra = kwargs.get('clip_extra', True) - # serialize and save program program = normalize_pir_program( program, @@ -665,7 +658,12 @@ def save_pir_inference_model( fetch_vars, skip_prune_program=kwargs.get('skip_prune_program', False), ) - paddle.core.serialize_pir_program(program, model_path, 1, True, False, True) + + readable = kwargs.get('readable', False) + trainable = kwargs.get('trainable', True) + paddle.core.serialize_pir_program( + program, model_path, 1, True, readable, trainable + ) # serialize and save params save_dirname = os.path.dirname(params_path) diff --git a/python/paddle/tensor/attribute.py b/python/paddle/tensor/attribute.py index 9bd7f3c16c95dc..25a7819b5b2432 100644 --- a/python/paddle/tensor/attribute.py 
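A hypothetical usage sketch of the test_with_dygraph_pir decorator added to paddle/pir_utils.py above: the wrapped test body runs twice, once under the legacy-IR guard and once under the dygraph PIR guard, so a single test method covers both execution paths (the test class below is illustrative).

    import unittest

    import paddle
    from paddle.pir_utils import test_with_dygraph_pir

    class AddCase(unittest.TestCase):
        @test_with_dygraph_pir
        def test_add(self):
            x = paddle.ones([2, 2])
            self.assertEqual((x + x).shape, [2, 2])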
+++ b/python/paddle/tensor/attribute.py @@ -20,7 +20,7 @@ from paddle import _C_ops from ..base.data_feeder import check_type, check_variable_and_dtype -from ..base.framework import in_dynamic_or_pir_mode, in_pir_mode +from ..base.framework import in_dynamic_or_pir_mode, use_pir_api from ..common_ops_import import Variable from ..framework import LayerHelper, core from .creation import _complex_to_real_dtype, assign @@ -250,7 +250,7 @@ def is_integer(x): dtype = x.dtype is_int_dtype = False - if not in_pir_mode(): + if not use_pir_api(): is_int_dtype = ( dtype == core.VarDesc.VarType.UINT8 or dtype == core.VarDesc.VarType.INT8 @@ -260,7 +260,7 @@ def is_integer(x): ) else: is_int_dtype = ( - dtype == core.DataType.INT8 + dtype == core.DataType.UINT8 or dtype == core.DataType.INT8 or dtype == core.DataType.INT16 or dtype == core.DataType.INT32 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0250acb89dccc3..c7add4988b09d2 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -91,7 +91,7 @@ function(bash_test_modules TARGET_NAME) endif() endfunction() -function(set_pit_tests_properties) +function(set_pir_tests_properties) file(STRINGS "${CMAKE_SOURCE_DIR}/test/white_list/pir_op_test_white_list" PIR_OP_TESTS) foreach(IR_OP_TEST ${PIR_OP_TESTS}) @@ -299,6 +299,6 @@ if(${len} GREATER_EQUAL 1) add_dependencies(build_tests ${test_names}) endif() -set_pit_tests_properties() +set_pir_tests_properties() add_subdirectory(deprecated) diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 34e5b5f651ce75..1cc785ba549b41 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -308,4 +308,4 @@ endif() py_test_modules(test_job_schedule_profiler_range MODULES test_job_schedule_profiler_range) -set_pit_tests_properties() +set_pir_tests_properties() diff --git a/test/collective/fleet/test_dygraph_recompute_for_eager.py b/test/collective/fleet/test_dygraph_recompute_for_eager.py index 288f69c03d9332..790d47b6b59487 100644 --- a/test/collective/fleet/test_dygraph_recompute_for_eager.py +++ b/test/collective/fleet/test_dygraph_recompute_for_eager.py @@ -75,6 +75,7 @@ def __init__( use_raw_recompute=False, recompute_kwargs={}, raise_value_error=False, + recompute_use_kwargs_as_inputs=False, ): super().__init__() self.recompute_blocks = recompute_blocks @@ -115,6 +116,7 @@ def __init__( self.runfunc2, self.runfunc3, self.runfunc4 ), ] + self.recompute_use_kwargs_as_inputs = recompute_use_kwargs_as_inputs def forward(self, inputs): if self.use_fleet_sq and not self.use_raw_recompute: @@ -135,9 +137,14 @@ def forward(self, inputs): ) for i in range(len(self.layers)): if i in self.recompute_blocks: - inputs = recompute( - self.layers[i], inputs, pos, **recompute_kwargs - ) + if self.recompute_use_kwargs_as_inputs: + inputs = recompute( + self.layers[i], pos=pos, x=inputs, **recompute_kwargs + ) + else: + inputs = recompute( + self.layers[i], inputs, pos, **recompute_kwargs + ) else: inputs = self.layers[i](inputs, pos) @@ -153,6 +160,7 @@ def run_model( segments=1, enable_autocast=False, pure_fp16=False, + recompute_use_kwargs_as_inputs=False, ): gen = paddle.seed(10) gen.manual_seed(10) @@ -168,6 +176,7 @@ def run_model( segments=segments, recompute_kwargs=recompute_kwargs, raise_value_error=raise_value_error, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, ) if pure_fp16: @@ -208,7 +217,12 @@ def run_model( class TestRecompute(unittest.TestCase): - def test_base_case(self, enable_autocast=False, 
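The recompute_use_kwargs_as_inputs switch introduced below exercises passing tensors to recompute as keyword arguments, which this patch supports when use_reentrant is True. A standalone sketch, assuming a GPU build (the block function here is illustrative only):

    import paddle
    from paddle.distributed.fleet.utils import recompute

    def block(x, pos):
        return paddle.matmul(x, pos).tanh()

    x = paddle.randn([4, 4])
    pos = paddle.randn([4, 4])
    x.stop_gradient = False
    # tensors are forwarded as keyword arguments instead of positionally
    out = recompute(block, x=x, pos=pos, use_reentrant=True)
    out.sum().backward()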
pure_fp16=False): + def test_base_case( + self, + enable_autocast=False, + pure_fp16=False, + recompute_use_kwargs_as_inputs=False, + ): def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): self.assertEqual(loss_ref, loss) self.assertEqual(param_ref, param) @@ -231,6 +245,7 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): enable_autocast=enable_autocast, pure_fp16=pure_fp16, recompute_kwargs={"use_reentrant": flag}, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, ) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) @@ -240,6 +255,7 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): enable_autocast=enable_autocast, pure_fp16=pure_fp16, recompute_kwargs={"use_reentrant": flag}, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, ) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) @@ -249,6 +265,7 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): enable_autocast=enable_autocast, pure_fp16=pure_fp16, recompute_kwargs={"use_reentrant": flag}, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, ) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) @@ -258,6 +275,7 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): enable_autocast=enable_autocast, pure_fp16=pure_fp16, recompute_kwargs={"use_reentrant": flag}, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, ) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) @@ -268,6 +286,7 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): enable_autocast=enable_autocast, pure_fp16=pure_fp16, recompute_kwargs={"use_reentrant": flag}, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, ) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) @@ -291,23 +310,34 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): def test_fc_net_with_dropout(self): self.test_base_case() + self.test_base_case(recompute_use_kwargs_as_inputs=True) def test_fc_net_without_restore_rng(self): for flag in [True, False]: - loss_ref, param_ref, grad_ref = run_model( - recompute_block=[2], - recompute_kwargs={ - "preserve_rng_state": False, - "use_reentrant": flag, - }, - enable_autocast=True, - ) + for recompute_use_kwargs_as_inputs in [True, False]: + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[2], + recompute_kwargs={ + "preserve_rng_state": False, + "use_reentrant": flag, + }, + enable_autocast=True, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, + ) def test_fc_net_with_amp(self): self.test_base_case(enable_autocast=True) + self.test_base_case( + enable_autocast=True, recompute_use_kwargs_as_inputs=True + ) def test_fc_net_with_fp16(self): self.test_base_case(enable_autocast=True, pure_fp16=True) + self.test_base_case( + enable_autocast=True, + pure_fp16=True, + recompute_use_kwargs_as_inputs=True, + ) def test_recompute_kwargs(self): paddle.set_device("gpu") @@ -315,7 +345,7 @@ def test_recompute_kwargs(self): pos.stop_gradient = False kwargs = {"pos": pos, "use_reentrant": True} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model( recompute_block=[2], recompute_kwargs=kwargs, @@ -328,46 +358,57 @@ def test_recompute_kwargs(self): ) def test_recompute_inputs_with_param(self): - pos = paddle.randn(shape=[10, 10], dtype="float32") - new_pos = EagerParamBase( - shape=pos.shape, 
dtype=pos.dtype, name=pos.name - ) - pos._share_buffer_to(new_pos) - new_pos.stop_gradient = False + for flag in [True, False]: + for recompute_use_kwargs_as_inputs in [True, False]: + pos = paddle.randn(shape=[10, 10], dtype="float32") + new_pos = EagerParamBase( + shape=pos.shape, dtype=pos.dtype, name=pos.name + ) + pos._share_buffer_to(new_pos) + new_pos.stop_gradient = False - loss, param, grad = run_model( - recompute_block=[], recompute_kwargs={"pos": new_pos} - ) + loss, param, grad = run_model( + recompute_block=[2, 4], + recompute_kwargs={"pos": new_pos, "use_reentrant": flag}, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, + ) - loss_ref, param_ref, grad_ref = run_model( - recompute_block=[1, 2, 3], recompute_kwargs={"pos": new_pos} - ) + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[1, 2, 3], + recompute_kwargs={"pos": new_pos, "use_reentrant": flag}, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, + ) - self.assertEqual(loss_ref, loss) - self.assertEqual(param_ref, param) - self.assertEqual(grad_ref, grad) + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) def test_recompute_inputs_with_tuple(self): - pos = paddle.randn(shape=[10, 10], dtype="float32") - new_pos = EagerParamBase( - shape=pos.shape, dtype=pos.dtype, name=pos.name - ) - pos._share_buffer_to(new_pos) - pos.stop_gradient = False - new_pos.stop_gradient = False - - loss, param, grad = run_model( - recompute_block=[2, 4], recompute_kwargs={"pos": (pos,)} - ) + for flag in [True, False]: + for recompute_use_kwargs_as_inputs in [True, False]: + pos = paddle.randn(shape=[10, 10], dtype="float32") + new_pos = EagerParamBase( + shape=pos.shape, dtype=pos.dtype, name=pos.name + ) + pos._share_buffer_to(new_pos) + pos.stop_gradient = False + new_pos.stop_gradient = False + + loss, param, grad = run_model( + recompute_block=[2, 4], + recompute_kwargs={"pos": (pos,), "use_reentrant": flag}, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, + ) - loss_ref, param_ref, grad_ref = run_model( - recompute_block=[1, 2, 3], - recompute_kwargs={"pos": (new_pos,)}, - ) + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[1, 2, 3], + recompute_kwargs={"pos": (new_pos,), "use_reentrant": flag}, + recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs, + ) - self.assertEqual(loss_ref, loss) - self.assertEqual(param_ref, param) - self.assertEqual(grad_ref, grad) + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) if __name__ == '__main__': diff --git a/test/contrib/test_d2s_amp_controlflow.py b/test/contrib/test_d2s_amp_controlflow.py index 6e2b965491d066..ebe877d7d8b84e 100644 --- a/test/contrib/test_d2s_amp_controlflow.py +++ b/test/contrib/test_d2s_amp_controlflow.py @@ -63,7 +63,7 @@ def forward(self): class TestD2SAmpWithControlFlowOp(unittest.TestCase): def test_cond_op(self): model = Net_Cond() - model = paddle.jit.to_static(model) + model = paddle.jit.to_static(model, full_graph=True) model = paddle.amp.decorate( models=model, level='O2', save_dtype="float32" ) @@ -72,7 +72,7 @@ def test_cond_op(self): def test_while_op(self): model = Net_While() - model = paddle.jit.to_static(model) + model = paddle.jit.to_static(model, full_graph=True) model = paddle.amp.decorate( models=model, level='O2', save_dtype="float32" ) @@ -81,7 +81,7 @@ def test_while_op(self): def test_sub_block_fp32_op(self): model = Net_Sub_Block_FP32() - model = 
paddle.jit.to_static(model) + model = paddle.jit.to_static(model, full_graph=True) model = paddle.amp.decorate( models=model, level='O2', save_dtype="float32" ) diff --git a/test/cpp/fluid/math/concat_test.cc b/test/cpp/fluid/math/concat_test.cc index 080a659ecdbbc6..b93c7c9a4870bd 100644 --- a/test/cpp/fluid/math/concat_test.cc +++ b/test/cpp/fluid/math/concat_test.cc @@ -15,9 +15,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" /** * case 1: @@ -77,7 +77,7 @@ void ConcatCase1(DeviceContext* context) { input.push_back(input_a); input.push_back(input_b); - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(*context, input, 0, &out); // check the dim of input_a, input_b @@ -182,7 +182,7 @@ void ConcatCase2(DeviceContext* context) { input.push_back(input_a); input.push_back(input_b); - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(*context, input, 1, &out); // check the dim of input_a, input_b @@ -291,7 +291,7 @@ void ConcatCase3(DeviceContext* context) { input.push_back(input_a); input.push_back(input_b); - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(*context, input, 2, &out); // check the dim of input_a, input_b @@ -402,7 +402,7 @@ void ConcatCase4(DeviceContext* context) { input.push_back(input_a); input.push_back(input_b); - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(*context, input, 1, &out); context->Wait(); diff --git a/test/cpp/inference/CMakeLists.txt b/test/cpp/inference/CMakeLists.txt index 4b7dcf2c0d342a..6d8bf40b0110c9 100644 --- a/test/cpp/inference/CMakeLists.txt +++ b/test/cpp/inference/CMakeLists.txt @@ -1,6 +1,7 @@ add_definitions(-DPADDLE_DLL_EXPORT) if(WITH_TESTING) include(test.cmake) # some generic cmake function for inference + include(test_cases.cmake) endif() add_subdirectory(analysis) diff --git a/test/cpp/inference/api/CMakeLists.txt b/test/cpp/inference/api/CMakeLists.txt index 722e5720744723..008bd9d7da354f 100644 --- a/test/cpp/inference/api/CMakeLists.txt +++ b/test/cpp/inference/api/CMakeLists.txt @@ -612,16 +612,32 @@ if(WITH_TESTING AND WITH_INFERENCE_API_TEST) # build test binary to be used in subsequent tests inference_analysis_api_test_build(${LEXICAL_TEST_APP} ${LEXICAL_TEST_APP_SRC}) - + # run lexcial analysis test + inference_analysis_api_lexical_test_run( + test_analyzer_lexical_gru ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} + ${GRU_DATA_PATH}) + # run bfloat16 lexical analysis test + inference_analysis_api_lexical_bfloat16_test_run( + test_analyzer_lexical_gru_bfloat16 ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} + ${GRU_DATA_PATH}) + # run post-training quantization lexical analysis test + inference_analysis_api_lexical_int8_test_run( + test_analyzer_lexical_gru_int8 + ${LEXICAL_TEST_APP} + ${GRU_MODEL_PATH} + ${GRU_DATA_PATH} + true # enable_int8_ptq + false # enable_int8_qat + false) # fuse_multi_gru # run post-training quantization lexical analysis test with multi_gru fuse - # inference_analysis_api_lexical_int8_test_run( - # test_analyzer_lexical_gru_int8_multi_gru - # ${LEXICAL_TEST_APP} - # ${GRU_MODEL_PATH} - # ${GRU_DATA_PATH} - # 
true # enable_int8_ptq - # false # enable_int8_qat - # true) # fuse_multi_gru + inference_analysis_api_lexical_int8_test_run( + test_analyzer_lexical_gru_int8_multi_gru + ${LEXICAL_TEST_APP} + ${GRU_MODEL_PATH} + ${GRU_DATA_PATH} + true # enable_int8_ptq + false # enable_int8_qat + true) # fuse_multi_gru # run qat gru test set(QAT_GRU_MODEL_ARCHIVE "GRU_quant_acc.tar.gz") diff --git a/test/cpp/inference/test.cmake b/test/cpp/inference/test.cmake index d394c47f68a055..03839871580d70 100644 --- a/test/cpp/inference/test.cmake +++ b/test/cpp/inference/test.cmake @@ -73,35 +73,6 @@ function(inference_download_and_uncompress_without_verify INSTALL_DIR URL INSTALL_COMMAND "") endfunction() -set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec") -if(NOT EXISTS ${WORD2VEC_INSTALL_DIR}/word2vec.inference.model.tar.gz) - inference_download_and_uncompress_without_verify( - ${WORD2VEC_INSTALL_DIR} ${INFERENCE_URL} "word2vec.inference.model.tar.gz") -endif() -set(WORD2VEC_MODEL_DIR "${WORD2VEC_INSTALL_DIR}/word2vec.inference.model") - -set(IMG_CLS_RESNET_INSTALL_DIR - "${INFERENCE_DEMO_INSTALL_DIR}/image_classification_resnet") -if(NOT EXISTS - ${IMG_CLS_RESNET_INSTALL_DIR}/image_classification_resnet.inference.model.tgz -) - inference_download_and_uncompress_without_verify( - ${IMG_CLS_RESNET_INSTALL_DIR} ${INFERENCE_URL} - "image_classification_resnet.inference.model.tgz") -endif() -set(IMG_CLS_RESNET_MODEL_DIR - "${IMG_CLS_RESNET_INSTALL_DIR}/image_classification_resnet.inference.model") - -if(WITH_ONNXRUNTIME) - set(MOBILENETV2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/MobileNetV2") - if(NOT EXISTS ${MOBILENETV2_INSTALL_DIR}/MobileNetV2.inference.model.tar.gz) - inference_download_and_uncompress_without_verify( - ${MOBILENETV2_INSTALL_DIR} ${INFERENCE_URL} - "MobileNetV2.inference.model.tar.gz") - endif() - set(MOBILENETV2_MODEL_DIR "${MOBILENETV2_INSTALL_DIR}/MobileNetV2") -endif() - function(inference_base_test_build TARGET) set(options "") set(oneValueArgs "") diff --git a/test/cpp/inference/test_cases.cmake b/test/cpp/inference/test_cases.cmake new file mode 100644 index 00000000000000..9efc7a14e190a0 --- /dev/null +++ b/test/cpp/inference/test_cases.cmake @@ -0,0 +1,34 @@ +include(test.cmake) # some generic cmake function for inference + +set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec") + +if(NOT EXISTS ${WORD2VEC_INSTALL_DIR}/word2vec.inference.model.tar.gz) + inference_download_and_uncompress_without_verify( + ${WORD2VEC_INSTALL_DIR} ${INFERENCE_URL} "word2vec.inference.model.tar.gz") +endif() + +set(WORD2VEC_MODEL_DIR "${WORD2VEC_INSTALL_DIR}/word2vec.inference.model") + +set(IMG_CLS_RESNET_INSTALL_DIR + "${INFERENCE_DEMO_INSTALL_DIR}/image_classification_resnet") + +if(NOT EXISTS + ${IMG_CLS_RESNET_INSTALL_DIR}/image_classification_resnet.inference.model.tgz +) + inference_download_and_uncompress_without_verify( + ${IMG_CLS_RESNET_INSTALL_DIR} ${INFERENCE_URL} + "image_classification_resnet.inference.model.tgz") +endif() + +set(IMG_CLS_RESNET_MODEL_DIR + "${IMG_CLS_RESNET_INSTALL_DIR}/image_classification_resnet.inference.model") + +if(WITH_ONNXRUNTIME) + set(MOBILENETV2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/MobileNetV2") + if(NOT EXISTS ${MOBILENETV2_INSTALL_DIR}/MobileNetV2.inference.model.tar.gz) + inference_download_and_uncompress_without_verify( + ${MOBILENETV2_INSTALL_DIR} ${INFERENCE_URL} + "MobileNetV2.inference.model.tar.gz") + endif() + set(MOBILENETV2_MODEL_DIR "${MOBILENETV2_INSTALL_DIR}/MobileNetV2") +endif() diff --git 
a/test/cpp/prim/CMakeLists.txt b/test/cpp/prim/CMakeLists.txt index 7f5b3af0525889..51b5bb70a6e225 100644 --- a/test/cpp/prim/CMakeLists.txt +++ b/test/cpp/prim/CMakeLists.txt @@ -23,7 +23,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER)) cc_library(init_env_utils SRCS init_env_utils.cc) target_compile_definitions(init_env_utils PUBLIC PADDLE_DLL_EXPORT) - paddle_test(test_comp_eager SRCS test_eager_prim.cc DEPS init_env_utils) + paddle_test(test_comp_eager SRCS test_eager_prim.cc init_env_utils.cc) endif() # skip win32 since wget is not installed by default on windows machine. diff --git a/test/custom_op/test_custom_relu_model.py b/test/custom_op/test_custom_relu_model.py index a972831a2738d6..bc2693f5e23e25 100644 --- a/test/custom_op/test_custom_relu_model.py +++ b/test/custom_op/test_custom_relu_model.py @@ -140,7 +140,9 @@ def train_model(self, use_custom_op=False, dy2stat=False): net = Net(self.in_dim, self.out_dim, use_custom_op) if dy2stat: - net = paddle.jit.to_static(net, input_spec=[self.x_spec]) + net = paddle.jit.to_static( + net, input_spec=[self.x_spec], full_graph=True + ) mse_loss = paddle.nn.MSELoss() sgd = paddle.optimizer.SGD( learning_rate=0.1, parameters=net.parameters() diff --git a/test/custom_op/test_inference_inplace.py b/test/custom_op/test_inference_inplace.py index 64219d8e148d00..d23a2eeb970850 100644 --- a/test/custom_op/test_inference_inplace.py +++ b/test/custom_op/test_inference_inplace.py @@ -72,6 +72,7 @@ def setUp(self): shape=[None, 4], dtype='float32', name='x' ), ], + full_graph=True, ) paddle.jit.save( model, diff --git a/test/deprecated/CMakeLists.txt b/test/deprecated/CMakeLists.txt index 02f1a575411e4e..ffaf747a547d08 100644 --- a/test/deprecated/CMakeLists.txt +++ b/test/deprecated/CMakeLists.txt @@ -91,7 +91,7 @@ function(bash_test_modules TARGET_NAME) endif() endfunction() -function(set_pit_tests_properties) +function(set_pir_tests_properties) file(STRINGS "${CMAKE_SOURCE_DIR}/test/white_list/pir_op_test_white_list" PIR_OP_TESTS) foreach(IR_OP_TEST ${PIR_OP_TESTS}) @@ -164,4 +164,4 @@ if(WITH_TESTING) endif() -set_pit_tests_properties() +set_pir_tests_properties() diff --git a/test/deprecated/custom_runtime/test_custom_cpu_plugin.py b/test/deprecated/custom_runtime/test_custom_cpu_plugin.py index b92df8def9dd30..0159c59d164297 100755 --- a/test/deprecated/custom_runtime/test_custom_cpu_plugin.py +++ b/test/deprecated/custom_runtime/test_custom_cpu_plugin.py @@ -276,7 +276,9 @@ def _test_custom_device_mix_precision(self): self.temp_dir = tempfile.TemporaryDirectory() model = resnet50(True) net = to_static( - model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')] + model, + input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')], + full_graph=True, ) paddle.jit.save( net, os.path.join(self.temp_dir.name, 'resnet50/inference') diff --git a/test/deprecated/custom_runtime/test_custom_cpu_to_static.py b/test/deprecated/custom_runtime/test_custom_cpu_to_static.py index a9e863cf5d61f9..630e5b79783b37 100644 --- a/test/deprecated/custom_runtime/test_custom_cpu_to_static.py +++ b/test/deprecated/custom_runtime/test_custom_cpu_to_static.py @@ -160,7 +160,9 @@ def forward(self, x): # convert to static model build_strategy = paddle.static.BuildStrategy() - mnist = paddle.jit.to_static(model, build_strategy=build_strategy) + mnist = paddle.jit.to_static( + model, build_strategy=build_strategy, full_graph=True + ) # data loader transform = paddle.vision.transforms.Compose( diff --git a/test/deprecated/distribution/CMakeLists.txt 
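Many tests in this patch, including the ones just above, now pass full_graph=True explicitly when calling paddle.jit.to_static; a minimal sketch of the pattern (the TinyNet layer is illustrative only):

    import paddle

    class TinyNet(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc = paddle.nn.Linear(4, 4)

        def forward(self, x):
            return paddle.nn.functional.relu(self.fc(x))

    # full_graph=True requests whole-graph (AST) conversion rather than SOT
    net = paddle.jit.to_static(TinyNet(), full_graph=True)
    out = net(paddle.randn([2, 4]))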
b/test/deprecated/distribution/CMakeLists.txt index 27449f890fb3f3..21e9932d463b68 100644 --- a/test/deprecated/distribution/CMakeLists.txt +++ b/test/deprecated/distribution/CMakeLists.txt @@ -11,4 +11,4 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach() -set_pit_tests_properties() +set_pir_tests_properties() diff --git a/test/deprecated/fft/CMakeLists.txt b/test/deprecated/fft/CMakeLists.txt index 2839c2ea7231fc..a31ec8e1f21370 100644 --- a/test/deprecated/fft/CMakeLists.txt +++ b/test/deprecated/fft/CMakeLists.txt @@ -8,4 +8,4 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach() -set_pit_tests_properties() +set_pir_tests_properties() diff --git a/test/deprecated/ir/pir/test_standalone_pir.py b/test/deprecated/ir/pir/test_standalone_pir.py index 866a6fe105cc56..01a8ea95881c2f 100644 --- a/test/deprecated/ir/pir/test_standalone_pir.py +++ b/test/deprecated/ir/pir/test_standalone_pir.py @@ -185,7 +185,7 @@ def test_with_pir(self): build_strategy = paddle.static.BuildStrategy() build_strategy.enable_inplace = False - @paddle.jit.to_static(build_strategy=build_strategy) + @paddle.jit.to_static(build_strategy=build_strategy, full_graph=True) def func(x, y): return x * y @@ -210,7 +210,7 @@ def test_with_pir(self): build_strategy = paddle.static.BuildStrategy() build_strategy.enable_inplace = False - @paddle.jit.to_static(build_strategy=build_strategy) + @paddle.jit.to_static(build_strategy=build_strategy, full_graph=True) def func(x, y): x = x.reshape([-1, 2, 2]) y = y.reshape([-1, 2, 2]) diff --git a/test/deprecated/ir/pir/translator/CMakeLists.txt b/test/deprecated/ir/pir/translator/CMakeLists.txt index 5616092f35d5fb..a175ee61e45276 100644 --- a/test/deprecated/ir/pir/translator/CMakeLists.txt +++ b/test/deprecated/ir/pir/translator/CMakeLists.txt @@ -33,6 +33,7 @@ list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_pull_sparse_v2_translator) list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_random_routing_translator) list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_limit_by_capacity_translator) list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_global_scatter_translator) +list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_global_gather_translator) if(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_INTERP_CASES ${DISTRIBUTED_OP_TRANSLATOR_TEST}) diff --git a/test/deprecated/ir/pir/translator/test_global_gather_translator.py b/test/deprecated/ir/pir/translator/test_global_gather_translator.py new file mode 100644 index 00000000000000..cbd883aaf6500c --- /dev/null +++ b/test/deprecated/ir/pir/translator/test_global_gather_translator.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import test_op_translator + +import paddle +from paddle.base.layer_helper import LayerHelper + + +class TestGlobalGatherOpTranslator( + test_op_translator.TestOpWithBackwardTranslator +): + def append_op(self): + self.forward_op_type = "global_gather" + self.backward_op_type = "global_scatter" + x = paddle.ones( + shape=( + 1, + 1, + ), + dtype='int64', + ) + local_count = paddle.ones(shape=(1,), dtype='int64') + global_count = paddle.ones(shape=(1,), dtype='int64') + x.stop_gradient = False + local_count.stop_gradient = False + global_count.stop_gradient = False + out = paddle.ones(shape=(1,), dtype='int64') + attrs = {'ring_id': 0, 'use_calc_stream': False} + helper = LayerHelper(self.forward_op_type) + helper.append_op( + type=self.forward_op_type, + inputs={ + "X": x, + 'local_count': local_count, + 'global_count': global_count, + }, + outputs={"Out": out}, + attrs=attrs, + ) + return out + + def test_translator(self): + self.check() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/deprecated/ir/pir/translator/test_global_scatter_translator.py b/test/deprecated/ir/pir/translator/test_global_scatter_translator.py index c9dcfed3e5acc4..fb349a30b95e20 100644 --- a/test/deprecated/ir/pir/translator/test_global_scatter_translator.py +++ b/test/deprecated/ir/pir/translator/test_global_scatter_translator.py @@ -20,27 +20,38 @@ from paddle.base.layer_helper import LayerHelper -class TestDistributedLookupTableOpTranslator( - test_op_translator.TestOpTranslator +class TestGlobalScatterOpTranslator( + test_op_translator.TestOpWithBackwardTranslator ): def append_op(self): - self.op_type = "global_scatter" - x = paddle.ones(shape=(4, 8), dtype='float32') - local_count = paddle.to_tensor([0, 1], dtype='int64') - global_count = paddle.to_tensor([0, 1], dtype='int64') - out = paddle.ones(shape=(2, 8), dtype='float32') + self.forward_op_type = "global_scatter" + self.backward_op_type = "global_gather" + x = paddle.ones( + shape=( + 1, + 1, + ), + dtype='int64', + ) + local_count = paddle.ones(shape=(1,), dtype='int64') + global_count = paddle.ones(shape=(1,), dtype='int64') + x.stop_gradient = False + local_count.stop_gradient = False + global_count.stop_gradient = False + out = paddle.ones(shape=(1,), dtype='int64') attrs = {'ring_id': 0, 'use_calc_stream': False} - helper = LayerHelper(self.op_type) + helper = LayerHelper(self.forward_op_type) helper.append_op( - type=self.op_type, + type=self.forward_op_type, inputs={ "X": x, - "local_count": local_count, - "global_count": global_count, + 'local_count': local_count, + 'global_count': global_count, }, outputs={"Out": out}, attrs=attrs, ) + return out def test_translator(self): self.check() diff --git a/test/deprecated/legacy_test/CMakeLists.txt b/test/deprecated/legacy_test/CMakeLists.txt index a42dcc42f8939a..4968e979a137ad 100644 --- a/test/deprecated/legacy_test/CMakeLists.txt +++ b/test/deprecated/legacy_test/CMakeLists.txt @@ -888,7 +888,7 @@ py_test_modules(test_stride MODULES test_stride ENVS FLAGS_use_stride_kernel=true) set_tests_properties(test_linalg_matrix_exp PROPERTIES TIMEOUT 120) -set_pit_tests_properties() +set_pir_tests_properties() set_tests_properties(test_fractional_max_pool2d_op PROPERTIES TIMEOUT 120) set_tests_properties(test_fractional_max_pool3d_op PROPERTIES TIMEOUT 120) diff --git a/test/deprecated/legacy_test/test_apply.py b/test/deprecated/legacy_test/test_apply.py index 2c11bd26e932cd..6c16ceb5b96f09 100644 --- a/test/deprecated/legacy_test/test_apply.py +++ 
b/test/deprecated/legacy_test/test_apply.py @@ -87,11 +87,11 @@ def fn(x, func): return y with paddle.jit.api.sot_mode_guard(False): - jit_g = paddle.jit.to_static(fn) + jit_g = paddle.jit.to_static(fn, full_graph=True) out_legacy_ir = jit_g(self.x, self.function) with paddle.pir_utils.IrGuard(): paddle.disable_static() - jit_g = paddle.jit.to_static(fn) + jit_g = paddle.jit.to_static(fn, full_graph=True) out_pir = jit_g(self.x, self.function) np.testing.assert_allclose( self.function(self.x).numpy(), out_legacy_ir.numpy(), rtol=1e-05 diff --git a/test/deprecated/legacy_test/test_elementwise_floordiv_op.py b/test/deprecated/legacy_test/test_elementwise_floordiv_op.py index ccfab0b9adf56a..e49f5687b1c9e5 100644 --- a/test/deprecated/legacy_test/test_elementwise_floordiv_op.py +++ b/test/deprecated/legacy_test/test_elementwise_floordiv_op.py @@ -29,7 +29,9 @@ def init_kernel_type(self): def setUp(self): self.op_type = "elementwise_floordiv" + self.prim_op_type = "comp" self.python_api = paddle.floor_divide + self.public_python_api = paddle.floor_divide self.dtype = np.int32 self.axis = -1 self.init_dtype() @@ -45,7 +47,7 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(check_pir=True, check_prim_pir=True) def init_input_output(self): self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype) diff --git a/test/deprecated/legacy_test/test_instance_norm_op.py b/test/deprecated/legacy_test/test_instance_norm_op.py index ad9b12ed14bb97..2e9f9855d10332 100644 --- a/test/deprecated/legacy_test/test_instance_norm_op.py +++ b/test/deprecated/legacy_test/test_instance_norm_op.py @@ -196,7 +196,7 @@ def forward(self, x): def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=False) + return paddle.jit.to_static(net, build_strategy=False, full_graph=True) places = [paddle.CPUPlace()] diff --git a/test/deprecated/legacy_test/test_instance_norm_op_v2.py b/test/deprecated/legacy_test/test_instance_norm_op_v2.py index 95ca0c653a2925..8b600bbb0589f6 100644 --- a/test/deprecated/legacy_test/test_instance_norm_op_v2.py +++ b/test/deprecated/legacy_test/test_instance_norm_op_v2.py @@ -223,9 +223,9 @@ def test_check_output(self): atol=self.atol, check_prim=self.check_prim, check_pir=True, - check_prim_pir=False - if os.getenv("FLAGS_enable_pir_in_executor") - else True, + check_prim_pir=( + False if os.getenv("FLAGS_enable_pir_in_executor") else True + ), ) def test_check_grad(self): @@ -234,9 +234,9 @@ def test_check_grad(self): 'Y', check_prim=self.check_prim, check_pir=True, - check_prim_pir=False - if os.getenv("FLAGS_enable_pir_in_executor") - else True, + check_prim_pir=( + False if os.getenv("FLAGS_enable_pir_in_executor") else True + ), ) def init_dtype(self): @@ -284,9 +284,9 @@ def test_check_output(self): atol=self.atol, check_prim=self.check_prim, check_pir=True, - check_prim_pir=False - if os.getenv("FLAGS_enable_pir_in_executor") - else True, + check_prim_pir=( + False if os.getenv("FLAGS_enable_pir_in_executor") else True + ), ) def test_check_grad(self): @@ -298,9 +298,9 @@ def test_check_grad(self): max_relative_error=self.max_relative_error, check_prim=self.check_prim, check_pir=True, - check_prim_pir=False - if os.getenv("FLAGS_enable_pir_in_executor") - else True, + check_prim_pir=( + False if os.getenv("FLAGS_enable_pir_in_executor") else True + ), ) @@ -364,9 +364,9 @@ def 
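The elementwise_floordiv test above is extended to also check the composite (prim) decomposition under PIR (prim_op_type = "comp", check_prim_pir=True). The public API being decomposed is paddle.floor_divide; a quick sanity sketch with non-negative operands:

    import numpy as np
    import paddle

    x = paddle.to_tensor([7, 9, 10], dtype='int32')
    y = paddle.to_tensor([3, 2, 5], dtype='int32')
    out = paddle.floor_divide(x, y)
    np.testing.assert_array_equal(out.numpy(), np.array([2, 4, 2], dtype='int32'))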
test_check_output(self): place, check_prim=self.check_prim, check_pir=True, - check_prim_pir=False - if os.getenv("FLAGS_enable_pir_in_executor") - else True, + check_prim_pir=( + False if os.getenv("FLAGS_enable_pir_in_executor") else True + ), ) def test_check_grad(self): @@ -378,9 +378,9 @@ def test_check_grad(self): user_defined_grads=self.user_defined_grads, check_prim=self.check_prim, check_pir=True, - check_prim_pir=False - if os.getenv("FLAGS_enable_pir_in_executor") - else True, + check_prim_pir=( + False if os.getenv("FLAGS_enable_pir_in_executor") else True + ), ) @@ -402,7 +402,7 @@ def forward(self, x): def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=False) + return paddle.jit.to_static(net, build_strategy=False, full_graph=True) class TestPrimForwardAndBackward(unittest.TestCase): diff --git a/test/deprecated/legacy_test/test_jit_layer.py b/test/deprecated/legacy_test/test_jit_layer.py index 2289840da8cc0f..88932577e410b1 100644 --- a/test/deprecated/legacy_test/test_jit_layer.py +++ b/test/deprecated/legacy_test/test_jit_layer.py @@ -37,7 +37,9 @@ def __init__(self): self.fc2 = paddle.nn.Linear(4, 4) self._bias = 0.4 - @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')]) + @paddle.jit.to_static( + input_spec=[InputSpec([None, 4], dtype='float32')], full_graph=True + ) def forward(self, x): out = self.fc1(x) out = self.fc2(out) @@ -45,7 +47,9 @@ def forward(self, x): out = paddle.mean(out) return out - @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')]) + @paddle.jit.to_static( + input_spec=[InputSpec([None, 4], dtype='float32')], full_graph=True + ) def infer(self, input): out = self.fc2(input) out = out + self._bias @@ -85,7 +89,8 @@ def __init__(self): self.linear = paddle.nn.Linear(80, 80) @paddle.jit.to_static( - input_spec=[InputSpec(shape=[None, 80], dtype='float32')] + input_spec=[InputSpec(shape=[None, 80], dtype='float32')], + full_graph=True, ) def forward(self, x): out = self.linear(x) diff --git a/test/deprecated/legacy_test/test_paddle_save_load_binary.py b/test/deprecated/legacy_test/test_paddle_save_load_binary.py index 22b62e082cc94a..2cd10d5ab73b30 100644 --- a/test/deprecated/legacy_test/test_paddle_save_load_binary.py +++ b/test/deprecated/legacy_test/test_paddle_save_load_binary.py @@ -24,6 +24,8 @@ import paddle from paddle import base from paddle.base import framework +from paddle.framework.io_utils import get_value, is_pir_fetch_var, set_value +from paddle.pir_utils import test_with_pir_api IMAGE_SIZE = 784 @@ -42,6 +44,8 @@ def set_zero(self, prog, place, scope=None): scope = base.global_scope() for var in prog.list_vars(): if isinstance(var, framework.Parameter) or var.persistable: + if is_pir_fetch_var(var): + continue ten = scope.find_var(var.name).get_tensor() if ten is not None: ten.set(np.zeros_like(np.array(ten)), place) @@ -55,7 +59,7 @@ def predicate(var): vars = filter(predicate, program.list_vars()) for var in vars: paddle.save( - var.get_value(), + get_value(var), os.path.join(dirname, var.name), use_binary_format=True, ) @@ -68,8 +72,9 @@ def predicate(var): for var in var_list: var_load = paddle.load(os.path.join(dirname, var.name)) # set var_load to scope - var.set_value(var_load) + set_value(var, var_load) + @test_with_pir_api def test_replace_save_load_vars(self): paddle.enable_static() with new_program_scope(): @@ -91,6 +96,8 @@ def test_replace_save_load_vars(self): 
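A sketch of the PIR-compatible access pattern the save/load test below switches to: parameter values are read and written through paddle.framework.io_utils.get_value / set_value instead of var.get_value() / var.set_value(), so the same code works for legacy variables and PIR values (the zero_out helper is illustrative only):

    import numpy as np

    from paddle.framework.io_utils import get_value, set_value

    def zero_out(var):
        # read the current value, then overwrite it with zeros of the same shape
        tensor = np.array(get_value(var))
        set_value(var, np.zeros_like(tensor))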
base_map = {} for var in prog.list_vars(): if isinstance(var, framework.Parameter) or var.persistable: + if is_pir_fetch_var(var): + continue t = np.array( base.global_scope().find_var(var.name).get_tensor() ) @@ -112,7 +119,7 @@ def test_replace_save_load_vars(self): ) for var in prog.list_vars(): - if var.persistable: + if var.persistable and not is_pir_fetch_var(var): new_t = np.array( base.global_scope().find_var(var.name).get_tensor() ) @@ -129,7 +136,7 @@ def test_replace_save_load_vars(self): self.set_zero(prog, place) self.replace_load_vars(prog, path_vars2) for var in prog.list_vars(): - if var.persistable: + if var.persistable and not is_pir_fetch_var(var): new_t = np.array( base.global_scope().find_var(var.name).get_tensor() ) @@ -137,6 +144,7 @@ def test_replace_save_load_vars(self): np.testing.assert_array_equal(new_t, base_t) + @test_with_pir_api def test_save_load_lod_tensor(self): paddle.enable_static() OUTPUT_NUM = 32 @@ -149,7 +157,7 @@ def test_save_load_lod_tensor(self): OUTPUT_NUM, name='fc_vars', ) - prog = base.default_main_program() + prog = paddle.static.default_main_program() place = ( base.CPUPlace() if not paddle.base.core.is_compiled_with_cuda() @@ -167,15 +175,15 @@ def test_save_load_lod_tensor(self): IMAGE_SIZE, OUTPUT_NUM, ]: - tensor = var.get_value() + tensor = get_value(var) paddle.save( tensor, dirname + 'fc_vars.w_0', use_binary_format=True ) break - origin = np.array(var.get_value()) - var.set_value(np.zeros_like(origin)) - is_zeros = np.array(var.get_value()) + origin = np.array(get_value(var)) + set_value(var, np.zeros_like(origin)) + is_zeros = np.array(get_value(var)) loaded_tensor = paddle.load(dirname + 'fc_vars.w_0') self.assertTrue(isinstance(loaded_tensor, base.core.LoDTensor)) @@ -234,6 +242,7 @@ def test_save_load_lod_tensor(self): with self.assertRaises(NotImplementedError): paddle.framework.io._load_lod_tensor(1) + @test_with_pir_api def test_save_load_selected_rows(self): paddle.enable_static() place = ( @@ -299,3 +308,7 @@ def test_save_load_selected_rows(self): paddle.framework.io._save_selected_rows(selected_rows, 1) with self.assertRaises(NotImplementedError): paddle.framework.io._load_selected_rows(1) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/deprecated/legacy_test/test_run_program_op.py b/test/deprecated/legacy_test/test_run_program_op.py index 74af64b7adb25a..9defedc052f735 100644 --- a/test/deprecated/legacy_test/test_run_program_op.py +++ b/test/deprecated/legacy_test/test_run_program_op.py @@ -486,7 +486,7 @@ def train(self, to_static): net = Net() if to_static: - net = paddle.jit.to_static(net) + net = paddle.jit.to_static(net, full_graph=True) sgd = paddle.optimizer.SGD(0.01, parameters=net.parameters()) for i in range(self.iter): diff --git a/test/deprecated/legacy_test/test_save_inference_model_conditional_op.py b/test/deprecated/legacy_test/test_save_inference_model_conditional_op.py index 19466e3cdc9f43..bec0bc539c9a51 100644 --- a/test/deprecated/legacy_test/test_save_inference_model_conditional_op.py +++ b/test/deprecated/legacy_test/test_save_inference_model_conditional_op.py @@ -86,6 +86,7 @@ def test_while_op(self): input_spec=[ paddle.static.InputSpec(shape=[1, 3, 8, 8], dtype='float32') ], + full_graph=True, ) root_path = tempfile.TemporaryDirectory() model_file = os.path.join(root_path.name, "while_net") @@ -111,7 +112,9 @@ def test_for_op(self): paddle.disable_static() net = ForNet() net = paddle.jit.to_static( - net, input_spec=[paddle.static.InputSpec(shape=[1], dtype='int32')] + net, 
+ input_spec=[paddle.static.InputSpec(shape=[1], dtype='int32')], + full_graph=True, ) root_path = tempfile.TemporaryDirectory() model_file = os.path.join(root_path.name, "for_net") @@ -137,7 +140,9 @@ def test_if_op(self): paddle.disable_static() net = IfElseNet() net = paddle.jit.to_static( - net, input_spec=[paddle.static.InputSpec(shape=[1], dtype='int32')] + net, + input_spec=[paddle.static.InputSpec(shape=[1], dtype='int32')], + full_graph=True, ) root_path = tempfile.TemporaryDirectory() model_file = os.path.join(root_path.name, "if_net") diff --git a/test/deprecated/legacy_test/test_sparse_slice_op.py b/test/deprecated/legacy_test/test_sparse_slice_op.py index 483720b0663a26..714a55d24f21c9 100644 --- a/test/deprecated/legacy_test/test_sparse_slice_op.py +++ b/test/deprecated/legacy_test/test_sparse_slice_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from utils import compare_legacy_with_pt import paddle @@ -206,26 +207,32 @@ def check_result_with_list(self, x, axes, starts, ends, format='coo'): if format == 'coo': self._check_result_coo(np_x, axes, starts, ends) + @compare_legacy_with_pt def test_coo_5d(self): for item in data_5d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_4d(self): for item in data_4d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_3d(self): for item in data_3d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_2d(self): for item in data_2d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_1d(self): x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0] self.check_result_with_list(x, [0], [3], [5], format='coo') + @compare_legacy_with_pt def test_coo_1d_zero(self): x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0] self.check_result_with_list(x, [0], [-3], [-1], format='coo') diff --git a/test/deprecated/legacy_test/test_sparse_sum_op.py b/test/deprecated/legacy_test/test_sparse_sum_op.py index 3690341c51dc0d..8d245508b3d3ef 100644 --- a/test/deprecated/legacy_test/test_sparse_sum_op.py +++ b/test/deprecated/legacy_test/test_sparse_sum_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from utils import compare_legacy_with_pt import paddle @@ -172,6 +173,7 @@ def check_result_coo(self, x_shape, dims, keepdim, dtype=None): ) paddle.disable_static() + @compare_legacy_with_pt def test_sum(self): # 1d self.check_result_coo([5], None, False) diff --git a/test/deprecated/legacy_test/test_zero_dim_sundry_static_api_deprecated.py b/test/deprecated/legacy_test/test_zero_dim_sundry_static_api_deprecated.py new file mode 100644 index 00000000000000..cac15ad77b7b40 --- /dev/null +++ b/test/deprecated/legacy_test/test_zero_dim_sundry_static_api_deprecated.py @@ -0,0 +1,158 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Note: +# 0D Tensor indicates that the tensor's dimension is 0 +# 0D Tensor's shape is always [], numel is 1 +# which can be created by paddle.rand([]) + +import unittest + +import numpy as np +from decorator_helper import prog_scope + +import paddle + +# Use to test zero-dim of Sundry API, which is unique and can not be classified +# with others. It can be implemented here flexibly. + + +class TestSundryAPIStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.exe = paddle.static.Executor() + + def assertShapeEqual(self, out, target_tuple): + if not paddle.framework.in_pir_mode(): + out_shape = list(out.shape) + else: + out_shape = out.shape + self.assertEqual(out_shape, target_tuple) + + @prog_scope() + def test_create_global_var(self): + zero_dim_var = paddle.static.create_global_var( + shape=[], value=0.5, dtype='float32' + ) + self.assertEqual(zero_dim_var.shape, ()) + prog = paddle.static.default_startup_program() + res = self.exe.run(prog, fetch_list=[zero_dim_var]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[0], 0.5) + + @prog_scope() + def test_setitem(self): + # NOTE(zoooo0820): __setitem__ has gradient problem in static graph. + # To solve this, we may not support __setitem__ in static graph. + # These unit tests will delete soon. + + # case1: all axis have a scalar indice + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + x.stop_gradient = False + out = x * 2 + out = paddle.static.setitem(out, (1, 2, 3, 4), 10) + paddle.static.append_backward(out.sum()) + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out, x.grad_name]) + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(res[0][1, 2, 3, 4], np.array(10)) + self.assertEqual(res[1].shape, (2, 3, 4, 5)) + x_grad_expected = np.ones((2, 3, 4, 5)) * 2 + x_grad_expected[1, 2, 3, 4] = 0 + np.testing.assert_allclose(res[1], x_grad_expected) + + # case2: 0-D Tensor indice in some axis + # NOTE(zoooo0820): Now, int/slice with 0-D Tensor will still be + # treated as combined indexing, which is not support backward. + # There should have more test cases such as out[1, indice, :] = 0.5 when this + # problem is fixed. 
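As the note at the top of this new test file says, a 0-D tensor has shape [] and exactly one element. A quick dygraph illustration of that convention (separate from the static-graph cases tested here):

    import paddle

    x = paddle.full([], 0.5)         # create a 0-D tensor
    assert x.shape == []             # its shape is the empty list
    assert x.numel().item() == 1     # but it still holds exactly one element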
+        x = paddle.randn((2, 3, 4, 5))
+        x.stop_gradient = False
+        indice = paddle.full([], 1, dtype='int32')
+        out = x * 1
+        out = paddle.static.setitem(out, (indice, indice), 0.5)
+        paddle.static.append_backward(out.sum())
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out, x.grad_name])
+
+        self.assertEqual(out.shape, x.shape)
+        np.testing.assert_allclose(res[0][1, 1], np.ones((4, 5)) * 0.5)
+        x_grad_expected = np.ones((2, 3, 4, 5))
+        x_grad_expected[1, 1] = 0
+        np.testing.assert_allclose(res[1], x_grad_expected)
+
+        # case3: 0-D Tensor index in some axis, value is a Tensor
+        # and there is broadcast
+        x = paddle.randn((2, 3, 4, 5))
+        x.stop_gradient = False
+        v = paddle.ones((4, 5), dtype='float32') * 5
+        v.stop_gradient = False
+        indice = paddle.full([], 1, dtype='int32')
+        out = x * 1
+        out = paddle.static.setitem(out, indice, v)
+        paddle.static.append_backward(out.sum())
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out, x.grad_name, v.grad_name])
+
+        self.assertEqual(out.shape, x.shape)
+        np.testing.assert_allclose(res[0][1], np.ones((3, 4, 5)) * 5)
+        x_grad_expected = np.ones((2, 3, 4, 5))
+        x_grad_expected[1] = 0
+        np.testing.assert_allclose(res[1], x_grad_expected)
+
+    @prog_scope()
+    def test_static_auc(self):
+        x = paddle.full(shape=[3, 2], fill_value=0.25)
+        y = paddle.full(shape=[3], fill_value=1, dtype="int64")
+        out = paddle.static.auc(input=x, label=y)[0]
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[out],
+        )
+
+        self.assertEqual(res[0].shape, ())
+
+    @prog_scope()
+    def test_static_nn_prelu(self):
+        x1 = paddle.full([], 1.0, 'float32')
+        x1.stop_gradient = False
+        out1 = paddle.static.nn.prelu(x1, 'all')
+        grad_list = paddle.static.append_backward(
+            out1.sum(), parameter_list=[x1, out1]
+        )
+        (_, x1_grad), (_, out1_grad) = grad_list
+
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(
+            prog,
+            fetch_list=[
+                out1,
+                x1_grad,
+                out1_grad,
+            ],
+        )
+
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, ())
+        self.assertEqual(res[2].shape, ())
+        np.testing.assert_allclose(res[0], np.array(1))
+        np.testing.assert_allclose(res[1], np.array(1))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/deprecated/prim/composite_ops/test_composite_batch_norm.py b/test/deprecated/prim/composite_ops/test_composite_batch_norm.py
index c90f9c4b9c91c1..0da795957f9f4a 100644
--- a/test/deprecated/prim/composite_ops/test_composite_batch_norm.py
+++ b/test/deprecated/prim/composite_ops/test_composite_batch_norm.py
@@ -375,7 +375,7 @@ def test_forward(self):
 def apply_to_static(net, use_cinn):
     build_strategy = paddle.static.BuildStrategy()
     build_strategy.build_cinn_pass = use_cinn
-    return paddle.jit.to_static(net, build_strategy=False)
+    return paddle.jit.to_static(net, build_strategy=False, full_graph=True)


 class PrimeNet(paddle.nn.Layer):
diff --git a/test/deprecated/prim/composite_ops/test_composite_layer_norm.py b/test/deprecated/prim/composite_ops/test_composite_layer_norm.py
index 88be2af37551e8..cb2bb5a75d3a94 100644
--- a/test/deprecated/prim/composite_ops/test_composite_layer_norm.py
+++ b/test/deprecated/prim/composite_ops/test_composite_layer_norm.py
@@ -272,7 +272,7 @@ def test_forward(self):
 def apply_to_static(net, use_cinn):
     build_strategy = paddle.static.BuildStrategy()
     build_strategy.build_cinn_pass = use_cinn
-    return paddle.jit.to_static(net,
build_strategy=False) + return paddle.jit.to_static(net, build_strategy=False, full_graph=True) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/composite_ops/test_composite_softmax.py b/test/deprecated/prim/composite_ops/test_composite_softmax.py index 7f66453fe37e50..0f1d91af6bc515 100644 --- a/test/deprecated/prim/composite_ops/test_composite_softmax.py +++ b/test/deprecated/prim/composite_ops/test_composite_softmax.py @@ -129,7 +129,7 @@ def test_forward(self): def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=False) + return paddle.jit.to_static(net, build_strategy=False, full_graph=True) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_add_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_add_grad.py index 211564f2cd7b00..271db4020b4b3b 100644 --- a/test/deprecated/prim/prim/vjp/static/test_comp_add_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_add_grad.py @@ -24,7 +24,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_add_tanh_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_add_tanh_grad.py index 6320395e9ee4e3..cf2eb4f495e51e 100644 --- a/test/deprecated/prim/prim/vjp/static/test_comp_add_tanh_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_add_tanh_grad.py @@ -24,7 +24,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_cast_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_cast_grad.py index ce0bcbd7895f2e..a997f9a87d4081 100644 --- a/test/deprecated/prim/prim/vjp/static/test_comp_cast_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_cast_grad.py @@ -25,7 +25,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_div_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_div_grad.py index f0f73d20024387..e57e9446b1285f 100644 --- a/test/deprecated/prim/prim/vjp/static/test_comp_div_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_div_grad.py @@ -24,7 +24,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_gather_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_gather_grad.py index c4550dd13cc027..7912f2d3a798e5 100644 --- 
a/test/deprecated/prim/prim/vjp/static/test_comp_gather_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_gather_grad.py @@ -26,7 +26,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_reshape_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_reshape_grad.py index 7577c29b251cda..0d7b3d363d266b 100644 --- a/test/deprecated/prim/prim/vjp/static/test_comp_reshape_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_reshape_grad.py @@ -24,7 +24,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_sqrt_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_sqrt_grad.py index 09789ef6602ca7..3c4511adf63b61 100644 --- a/test/deprecated/prim/prim/vjp/static/test_comp_sqrt_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_sqrt_grad.py @@ -29,7 +29,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_sub_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_sub_grad.py index 49b1e33e3c0492..e7ee9379aaf33d 100644 --- a/test/deprecated/prim/prim/vjp/static/test_comp_sub_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_sub_grad.py @@ -24,7 +24,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_tanh_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_tanh_grad.py index 43edf96c3aff0d..15a88c9930569c 100644 --- a/test/deprecated/prim/prim/vjp/static/test_comp_tanh_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_tanh_grad.py @@ -29,7 +29,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/prim/vjp/static/test_comp_transpose_grad.py b/test/deprecated/prim/prim/vjp/static/test_comp_transpose_grad.py index 2f7cb85e3145d6..a1bcd5d1390485 100644 --- a/test/deprecated/prim/prim/vjp/static/test_comp_transpose_grad.py +++ b/test/deprecated/prim/prim/vjp/static/test_comp_transpose_grad.py @@ -24,7 +24,9 @@ def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, 
build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) class PrimeNet(paddle.nn.Layer): diff --git a/test/deprecated/prim/process/test_check_inputs.py b/test/deprecated/prim/process/test_check_inputs.py index b844f52ea81d80..53df7988ab1bee 100644 --- a/test/deprecated/prim/process/test_check_inputs.py +++ b/test/deprecated/prim/process/test_check_inputs.py @@ -32,7 +32,7 @@ def test_non_tensor_input(self): core._set_prim_all_enabled(True) np_data = np.random.random([3, 4]).astype("float32") tensor_data = paddle.to_tensor(np_data) - net = paddle.jit.to_static(fn) + net = paddle.jit.to_static(fn, full_graph=True) _ = net(tensor_data, shape=[2, 3, 4]).numpy() core._set_prim_all_enabled(False) diff --git a/test/deprecated/rnn/test_rnn_nets.py b/test/deprecated/rnn/test_rnn_nets.py index cdb3843cd22109..2596a24677c46f 100644 --- a/test/deprecated/rnn/test_rnn_nets.py +++ b/test/deprecated/rnn/test_rnn_nets.py @@ -374,7 +374,9 @@ def forward(self, input): rnn.train() rnn = paddle.jit.to_static( - rnn, [paddle.static.InputSpec(shape=[None, None, 16], dtype=x.dtype)] + rnn, + [paddle.static.InputSpec(shape=[None, None, 16], dtype=x.dtype)], + full_graph=True, ) temp_dir = tempfile.TemporaryDirectory() save_dirname = os.path.join(temp_dir.name, "./inference/%s_infer" % mode) diff --git a/test/deprecated/tokenizer/test_faster_tokenizer_op.py b/test/deprecated/tokenizer/test_faster_tokenizer_op.py index c5b09962380824..d35459dfa541fd 100755 --- a/test/deprecated/tokenizer/test_faster_tokenizer_op.py +++ b/test/deprecated/tokenizer/test_faster_tokenizer_op.py @@ -408,6 +408,7 @@ def test_inference(self): shape=[None], dtype=core.VarDesc.VarType.STRINGS ), # texts ], + full_graph=True, ) # Save in static graph model. 
paddle.jit.save(static_model, self.inference_path) diff --git a/test/distribution/CMakeLists.txt b/test/distribution/CMakeLists.txt index 2839c2ea7231fc..a31ec8e1f21370 100644 --- a/test/distribution/CMakeLists.txt +++ b/test/distribution/CMakeLists.txt @@ -8,4 +8,4 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach() -set_pit_tests_properties() +set_pir_tests_properties() diff --git a/test/dygraph_to_static/test_no_gradient.py b/test/dygraph_to_static/test_no_gradient.py index 1bd3a02f54ede5..84f7b032c2f4a2 100644 --- a/test/dygraph_to_static/test_no_gradient.py +++ b/test/dygraph_to_static/test_no_gradient.py @@ -15,7 +15,7 @@ import unittest import numpy -from dygraph_to_static_utils import Dy2StTestBase +from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir import paddle @@ -33,6 +33,7 @@ def main_func(x, index): class TestNoGradientCase(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_no_gradient(self): paddle.disable_static() x = paddle.randn([10, 3]) diff --git a/test/dygraph_to_static/test_pylayer.py b/test/dygraph_to_static/test_pylayer.py index bf09ba3db8a8cd..9724eb2749e40e 100644 --- a/test/dygraph_to_static/test_pylayer.py +++ b/test/dygraph_to_static/test_pylayer.py @@ -28,7 +28,7 @@ import unittest import numpy as np -from test_jit_save_load import train +from test_jit_save_load_rename import train import paddle from paddle.autograd.py_layer import PyLayer diff --git a/test/fft/CMakeLists.txt b/test/fft/CMakeLists.txt index 2839c2ea7231fc..a31ec8e1f21370 100644 --- a/test/fft/CMakeLists.txt +++ b/test/fft/CMakeLists.txt @@ -8,4 +8,4 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach() -set_pit_tests_properties() +set_pir_tests_properties() diff --git a/test/ipu/test_dy2static_fp16_ipu.py b/test/ipu/test_dy2static_fp16_ipu.py index bc3e5342ef47d9..0d630735ac4dee 100644 --- a/test/ipu/test_dy2static_fp16_ipu.py +++ b/test/ipu/test_dy2static_fp16_ipu.py @@ -64,7 +64,7 @@ def _test(self, use_ipu=False): ), paddle.static.InputSpec(name="target", shape=[32], dtype="int64"), ] - model = paddle.jit.to_static(model, input_spec=specs) + model = paddle.jit.to_static(model, input_spec=specs, full_graph=True) optim = paddle.optimizer.Adam( learning_rate=0.01, parameters=model.parameters() ) diff --git a/test/ipu/test_dy2static_ipu.py b/test/ipu/test_dy2static_ipu.py index eaca14de6a3981..9e4d007f2b8084 100644 --- a/test/ipu/test_dy2static_ipu.py +++ b/test/ipu/test_dy2static_ipu.py @@ -42,7 +42,7 @@ def __init__( self.use_reduction = use_reduction self.use_identity_loss = use_identity_loss - @to_static() + @to_static(full_graph=True) def forward(self, x, target=None): x = self.conv(x) x = paddle.flatten(x, 1, -1) diff --git a/test/ipu/test_print_op_ipu.py b/test/ipu/test_print_op_ipu.py index 442077009fc486..6b0ff2ce337cc6 100644 --- a/test/ipu/test_print_op_ipu.py +++ b/test/ipu/test_print_op_ipu.py @@ -113,7 +113,7 @@ def __init__(self): in_channels=3, out_channels=1, kernel_size=2, stride=1 ) - @to_static() + @to_static(full_graph=True) def forward(self, x, target=None): x = self.conv(x) print(x) diff --git a/test/ir/inference/test_inference_predictor_run.py b/test/ir/inference/test_inference_predictor_run.py index 21b095d7974426..a901402cddb000 100644 --- a/test/ir/inference/test_inference_predictor_run.py +++ b/test/ir/inference/test_inference_predictor_run.py @@ -51,6 +51,7 @@ def setUp(self): shape=[None, 4], dtype='float32', name='input1' ), ], + full_graph=True, ) 
paddle.jit.save( model, diff --git a/test/ir/inference/test_save_optimized_model_pass.py b/test/ir/inference/test_save_optimized_model_pass.py index 68e2a87302b642..e07d6e6dd931e2 100644 --- a/test/ir/inference/test_save_optimized_model_pass.py +++ b/test/ir/inference/test_save_optimized_model_pass.py @@ -30,7 +30,9 @@ def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() net = alexnet(True) model = to_static( - net, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')] + net, + input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')], + full_graph=True, ) paddle.jit.save( model, os.path.join(self.temp_dir.name, 'alexnet/inference') diff --git a/test/ir/inference/test_trt_inference_fp16_io.py b/test/ir/inference/test_trt_inference_fp16_io.py index 31cccac681b618..4f46e5f393e86c 100644 --- a/test/ir/inference/test_trt_inference_fp16_io.py +++ b/test/ir/inference/test_trt_inference_fp16_io.py @@ -30,7 +30,9 @@ def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() net = alexnet(True) model = to_static( - net, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')] + net, + input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')], + full_graph=True, ) paddle.jit.save( model, os.path.join(self.temp_dir.name, 'alexnet/inference') diff --git a/test/ir/inference/test_trt_inference_predictor.py b/test/ir/inference/test_trt_inference_predictor.py index e334e5eabfd74e..fa29b7c86432dd 100644 --- a/test/ir/inference/test_trt_inference_predictor.py +++ b/test/ir/inference/test_trt_inference_predictor.py @@ -132,9 +132,9 @@ def load(self, config_arg, inputs=None, outpus=None): max_batch_size=max_batch_size, min_subgraph_size=self.args.subgraph_size, use_static=False, - use_calib_mode=False - if self.args.precision == 'int8' - else False, + use_calib_mode=( + False if self.args.precision == 'int8' else False + ), ) if self.args.enable_dynamic_shape: if os.path.exists(shape_range_file): @@ -387,7 +387,9 @@ def SaveInferenceModel(self): ) ] - static_model = paddle.jit.to_static(net, input_spec=input_spec) + static_model = paddle.jit.to_static( + net, input_spec=input_spec, full_graph=True + ) paddle.jit.save(static_model, self.path) def testInferencePredictor(self): diff --git a/test/ir/inference/test_xpu_convert_mixed_precision.py b/test/ir/inference/test_xpu_convert_mixed_precision.py index cce33ca3bc9dc7..27a3faf97a6645 100644 --- a/test/ir/inference/test_xpu_convert_mixed_precision.py +++ b/test/ir/inference/test_xpu_convert_mixed_precision.py @@ -32,7 +32,9 @@ def test(self): self.temp_dir = tempfile.TemporaryDirectory() model = resnet50(True) net = to_static( - model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')] + model, + input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')], + full_graph=True, ) paddle.jit.save( net, os.path.join(self.temp_dir.name, 'resnet50/inference') diff --git a/test/ir/pir/cinn/symbolic/test_decomp_inference_predictor_run.py b/test/ir/pir/cinn/symbolic/test_decomp_inference_predictor_run.py index 517cd7083288a9..67bf2b9795b942 100644 --- a/test/ir/pir/cinn/symbolic/test_decomp_inference_predictor_run.py +++ b/test/ir/pir/cinn/symbolic/test_decomp_inference_predictor_run.py @@ -57,6 +57,7 @@ def setUp(self): shape=self.shape, dtype='float32', name='input1' ), ], + full_graph=True, ) paddle.jit.save( model, diff --git a/test/ir/pir/test_subgraph_exporter.py b/test/ir/pir/test_subgraph_exporter.py index 9926d2ee6af5aa..32cfda850b7edd 100644 --- a/test/ir/pir/test_subgraph_exporter.py +++ 
b/test/ir/pir/test_subgraph_exporter.py @@ -97,7 +97,7 @@ def forward(self, x): class TestSaveFwdBwdProg(unittest.TestCase): def setUp(self): - self.net = paddle.jit.to_static(Net()) + self.net = paddle.jit.to_static(Net(), full_graph=True) self.root_dir = os.path.join(get_saving_dir(), "wrapper") self.clean() diff --git a/test/ir/pir/translator/test_pull_box_sparse_translator.py b/test/ir/pir/translator/test_pull_box_sparse_translator.py new file mode 100644 index 00000000000000..f691892adc4f41 --- /dev/null +++ b/test/ir/pir/translator/test_pull_box_sparse_translator.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import test_op_translator + +import paddle +from paddle.base.layer_helper import LayerHelper + + +class TestPullBoxSparseOpTranslator( + test_op_translator.TestOpWithBackwardTranslator +): + def append_op(self): + self.forward_op_type = "pull_box_sparse" + self.backward_op_type = "push_box_sparse" + ids = paddle.ones(shape=(1, 1), dtype='float32') + w = paddle.ones(shape=(1, 1), dtype='float32') + out = paddle.ones(shape=(1, 1), dtype='float32') + attrs = { + 'is_sparse': False, + 'is_distributed': False, + 'size': 1, + } + forward_helper = LayerHelper(self.forward_op_type) + forward_helper.append_op( + type=self.forward_op_type, + inputs={"W": w, "Ids": [ids]}, + outputs={"Out": [out]}, + attrs=attrs, + ) + return out + + def test_translator(self): + self.check() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/ir/test_convert_to_mixed_precision.py b/test/ir/test_convert_to_mixed_precision.py index b3e22853845858..fa5d5ea8256548 100644 --- a/test/ir/test_convert_to_mixed_precision.py +++ b/test/ir/test_convert_to_mixed_precision.py @@ -36,7 +36,9 @@ def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() model = resnet50(True) net = to_static( - model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')] + model, + input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')], + full_graph=True, ) paddle.jit.save( net, os.path.join(self.temp_dir.name, 'resnet50/inference') diff --git a/test/ir/test_inference_datatype.py b/test/ir/test_inference_datatype.py index df8551497490c1..2440b56c2eceb8 100644 --- a/test/ir/test_inference_datatype.py +++ b/test/ir/test_inference_datatype.py @@ -49,6 +49,7 @@ def setUp(self): input_spec=[ paddle.static.InputSpec(shape=[None, 4], dtype='float64') ], + full_graph=True, ) paddle.jit.save( model, diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 82929dd6446341..a81bceabc45a95 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -492,7 +492,7 @@ foreach(TEST_OP ${TEST_OPS_WITH_GC}) endforeach() # Switch some dy2st UT to eager mode -set(TEST_EAGER_OPS test_jit_save_load test_translated_layer) +set(TEST_EAGER_OPS test_jit_save_load_rename test_translated_layer) foreach(TEST_OP ${TEST_EAGER_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) 
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS FLAGS_enable_eager_mode=1) @@ -855,6 +855,7 @@ set_tests_properties(test_vision_models PROPERTIES TIMEOUT 120) set_tests_properties(test_dataset_uci_housing PROPERTIES TIMEOUT 120) set_tests_properties(test_dataset_imdb PROPERTIES TIMEOUT 300) set_tests_properties(test_callback_wandb PROPERTIES TIMEOUT 60) +set_tests_properties(test_jit_save_load_rename PROPERTIES TIMEOUT 100) if(WITH_COVERAGE) set_tests_properties(test_hapi_hub PROPERTIES TIMEOUT 300) endif() @@ -1054,4 +1055,4 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) PROPERTIES ENVIRONMENT "FLAGS_new_executor_micro_batching=False") endif() -set_pit_tests_properties() +set_pir_tests_properties() diff --git a/test/legacy_test/test_chunk_eval_op.py b/test/legacy_test/test_chunk_eval_op.py new file mode 100644 index 00000000000000..b9db50079b4b3d --- /dev/null +++ b/test/legacy_test/test_chunk_eval_op.py @@ -0,0 +1,282 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest + + +class Segment: + def __init__(self, chunk_type, start_idx, end_idx): + self.chunk_type = chunk_type + self.start_idx = start_idx + self.end_idx = end_idx + + def __str__(self): + return f'(Segment: {self.chunk_type}, {self.start_idx}, {self.end_idx})' + + __repr__ = __str__ + + +class TestChunkEvalOp(OpTest): + num_sequences = 5 + batch_size = 50 + + def parse_scheme(self): + if self.scheme == 'IOB': + self.num_tag_types = 2 + elif self.scheme == 'IOE': + self.num_tag_types = 2 + + def fill_with_chunks(self, data, chunks): + for chunk in chunks: + if self.scheme == 'IOB': + data[chunk.start_idx] = chunk.chunk_type * self.num_tag_types + data[ + chunk.start_idx + 1 : chunk.end_idx + ] = chunk.chunk_type * self.num_tag_types + ( + self.num_tag_types - 1 + ) + data[chunk.end_idx] = ( + chunk.chunk_type * self.num_tag_types + + (self.num_tag_types - 1) + if chunk.start_idx < chunk.end_idx + else data[chunk.start_idx] + ) + elif self.scheme == 'IOE': + data[chunk.start_idx : chunk.end_idx] = ( + chunk.chunk_type * self.num_tag_types + ) + data[chunk.end_idx] = chunk.chunk_type * self.num_tag_types + ( + self.num_tag_types - 1 + ) + + def rand_chunks(self, starts, num_chunks): + if num_chunks < 0: + num_chunks = np.random.randint(starts[-1]) + chunks = [] + # generate chunk beginnings + chunk_begins = sorted( + np.random.choice(list(range(starts[-1])), num_chunks, replace=False) + ) + seq_chunk_begins = [] + begin_idx = 0 + # divide chunks into sequences + for i in range(len(starts) - 1): + tmp_chunk_begins = [] + while ( + begin_idx < len(chunk_begins) + and chunk_begins[begin_idx] < starts[i + 1] + ): + tmp_chunk_begins.append(chunk_begins[begin_idx]) + begin_idx += 1 + seq_chunk_begins.append(tmp_chunk_begins) + # generate chunk ends + chunk_ends = [] + for i in range(len(seq_chunk_begins)): + for j in range(len(seq_chunk_begins[i])): + low = seq_chunk_begins[i][j] + high = ( + 
seq_chunk_begins[i][j + 1] + if j < len(seq_chunk_begins[i]) - 1 + else starts[i + 1] + ) + chunk_ends.append(np.random.randint(low, high)) + # generate chunks + for chunk_pos in zip(chunk_begins, chunk_ends): + chunk_type = np.random.randint(self.num_chunk_types) + chunks.append(Segment(chunk_type, *chunk_pos)) + return chunks + + def gen_chunks(self, infer, label, starts): + chunks = self.rand_chunks( + starts, + self.num_infer_chunks + + self.num_label_chunks + - self.num_correct_chunks, + ) + correct_chunks = np.random.choice( + list(range(len(chunks))), self.num_correct_chunks, replace=False + ) + infer_chunks = np.random.choice( + [x for x in range(len(chunks)) if x not in correct_chunks], + self.num_infer_chunks - self.num_correct_chunks, + replace=False, + ) + infer_chunks = sorted(correct_chunks.tolist() + infer_chunks.tolist()) + label_chunks = np.random.choice( + [x for x in range(len(chunks)) if x not in infer_chunks], + self.num_label_chunks - self.num_correct_chunks, + replace=False, + ) + label_chunks = sorted(correct_chunks.tolist() + label_chunks.tolist()) + self.fill_with_chunks(infer, [chunks[idx] for idx in infer_chunks]) + self.fill_with_chunks(label, [chunks[idx] for idx in label_chunks]) + # exclude types in excluded_chunk_types + if len(self.excluded_chunk_types) > 0: + for idx in correct_chunks: + if chunks[idx].chunk_type in self.excluded_chunk_types: + self.num_correct_chunks -= 1 + for idx in infer_chunks: + if chunks[idx].chunk_type in self.excluded_chunk_types: + self.num_infer_chunks -= 1 + for idx in label_chunks: + if chunks[idx].chunk_type in self.excluded_chunk_types: + self.num_label_chunks -= 1 + return ( + self.num_correct_chunks, + self.num_infer_chunks, + self.num_label_chunks, + ) + + def set_confs(self): + # Use the IOB scheme and labels with 2 chunk types + self.scheme = 'IOB' + self.num_chunk_types = 2 + self.excluded_chunk_types = [] + self.other_chunk_type = self.num_chunk_types + self.attrs = { + 'num_chunk_types': self.num_chunk_types, + 'chunk_scheme': self.scheme, + 'excluded_chunk_types': self.excluded_chunk_types, + } + self.parse_scheme() + ( + self.num_correct_chunks, + self.num_infer_chunks, + self.num_label_chunks, + ) = (4, 5, 9) + + def set_data(self): + infer = np.zeros((self.batch_size,)).astype('int64') + infer.fill(self.num_chunk_types * self.num_tag_types) + label = np.copy(infer) + starts = np.random.choice( + list(range(1, self.batch_size)), + self.num_sequences - 1, + replace=False, + ).tolist() + starts.extend([0, self.batch_size]) + starts = sorted(starts) + ( + self.num_correct_chunks, + self.num_infer_chunks, + self.num_label_chunks, + ) = self.gen_chunks(infer, label, starts) + lod = [] + for i in range(len(starts) - 1): + lod.append(starts[i + 1] - starts[i]) + self.set_input(infer, label, lod) + precision = ( + float(self.num_correct_chunks) / self.num_infer_chunks + if self.num_infer_chunks + else 0 + ) + recall = ( + float(self.num_correct_chunks) / self.num_label_chunks + if self.num_label_chunks + else 0 + ) + f1 = ( + float(2 * precision * recall) / (precision + recall) + if self.num_correct_chunks + else 0 + ) + self.outputs = { + 'Precision': np.asarray([precision], dtype='float32'), + 'Recall': np.asarray([recall], dtype='float32'), + 'F1-Score': np.asarray([f1], dtype='float32'), + 'NumInferChunks': np.asarray( + [self.num_infer_chunks], dtype='int64' + ), + 'NumLabelChunks': np.asarray( + [self.num_label_chunks], dtype='int64' + ), + 'NumCorrectChunks': np.asarray( + [self.num_correct_chunks], dtype='int64' 
+ ), + } + + def set_input(self, infer, label, lod): + self.inputs = {'Inference': (infer, [lod]), 'Label': (label, [lod])} + + def setUp(self): + self.op_type = 'chunk_eval' + self.set_confs() + self.set_data() + + def test_check_output(self): + # NODE(yjjiang11): This op will be deprecated. + self.check_output(check_dygraph=False) + + +class TestChunkEvalOpWithExclude(TestChunkEvalOp): + def set_confs(self): + # Use the IOE scheme and labels with 3 chunk types + self.scheme = 'IOE' + self.num_chunk_types = 3 + self.excluded_chunk_types = [1] + self.other_chunk_type = self.num_chunk_types + self.attrs = { + 'num_chunk_types': self.num_chunk_types, + 'chunk_scheme': self.scheme, + 'excluded_chunk_types': self.excluded_chunk_types, + } + self.parse_scheme() + ( + self.num_correct_chunks, + self.num_infer_chunks, + self.num_label_chunks, + ) = (15, 18, 20) + + +class TestChunkEvalOpWithTensorInput(TestChunkEvalOp): + def set_input(self, infer, label, lod): + max_len = np.max(lod) + pad_infer = [] + pad_label = [] + start = 0 + for i in range(len(lod)): + end = lod[i] + start + pad_infer.append( + np.pad( + infer[start:end], + (0, max_len - lod[i]), + 'constant', + constant_values=(-1,), + ) + ) + pad_label.append( + np.pad( + label[start:end], + (0, max_len - lod[i]), + 'constant', + constant_values=(-1,), + ) + ) + start = end + + pad_infer = np.expand_dims(np.array(pad_infer, dtype='int64'), 2) + pad_label = np.expand_dims(np.array(pad_label, dtype='int64'), 2) + lod = np.array(lod, dtype='int64') + self.inputs = { + 'Inference': pad_infer, + 'Label': pad_label, + 'SeqLength': lod, + } + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_ctc_align.py b/test/legacy_test/test_ctc_align.py new file mode 100644 index 00000000000000..699b176518be18 --- /dev/null +++ b/test/legacy_test/test_ctc_align.py @@ -0,0 +1,232 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from op_test import OpTest + +import paddle + + +def CTCAlign(input, lod, blank, merge_repeated, padding=0, input_length=None): + if input_length is None: + lod0 = lod[0] + result = [] + cur_offset = 0 + for i in range(len(lod0)): + prev_token = -1 + for j in range(cur_offset, cur_offset + lod0[i]): + token = input[j][0] + if (token != blank) and not ( + merge_repeated and token == prev_token + ): + result.append(token) + prev_token = token + cur_offset += lod0[i] + result = np.array(result).reshape([len(result), 1]).astype("int32") + if len(result) == 0: + result = np.array([[-1]]) + return result + else: + result = [[] for i in range(len(input))] + output_length = [] + for i in range(len(input)): + prev_token = -1 + for j in range(input_length[i][0]): + token = input[i][j] + if (token != blank) and not ( + merge_repeated and token == prev_token + ): + result[i].append(token) + prev_token = token + start = len(result[i]) + output_length.append([start]) + for j in range(start, len(input[i])): + result[i].append(padding) + result = ( + np.array(result) + .reshape([len(input), len(input[0])]) + .astype("int32") + ) + output_length = ( + np.array(output_length).reshape([len(input), 1]).astype("int32") + ) + + return result, output_length + + +class TestCTCAlignOp(OpTest): + def config(self): + self.op_type = "ctc_align" + self.input_lod = [[11, 7]] + self.blank = 0 + self.merge_repeated = False + self.input = ( + np.array([0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]) + .reshape([18, 1]) + .astype("int32") + ) + + def setUp(self): + self.config() + output = CTCAlign( + self.input, self.input_lod, self.blank, self.merge_repeated + ) + + self.inputs = { + "Input": (self.input, self.input_lod), + } + self.outputs = {"Output": output} + self.attrs = { + "blank": self.blank, + "merge_repeated": self.merge_repeated, + } + + def test_check_output(self): + # NODE(yjjiang11): This op will be deprecated. + self.check_output(check_dygraph=False) + + +class TestCTCAlignOpCase1(TestCTCAlignOp): + def config(self): + self.op_type = "ctc_align" + self.input_lod = [[11, 8]] + self.blank = 0 + self.merge_repeated = True + self.input = ( + np.array([0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0, 0]) + .reshape([19, 1]) + .astype("int32") + ) + + +class TestCTCAlignOpCase2(TestCTCAlignOp): + def config(self): + self.op_type = "ctc_align" + self.input_lod = [[4]] + self.blank = 0 + self.merge_repeated = True + self.input = np.array([0, 0, 0, 0]).reshape([4, 1]).astype("int32") + + +class TestCTCAlignPaddingOp(OpTest): + def config(self): + self.op_type = "ctc_align" + self.input_lod = [] + self.blank = 0 + self.padding_value = 0 + self.merge_repeated = True + self.input = ( + np.array( + [ + [0, 2, 4, 4, 0, 6, 3, 6, 6, 0, 0], + [1, 1, 3, 0, 0, 4, 5, 6, 0, 0, 0], + ] + ) + .reshape([2, 11]) + .astype("int32") + ) + self.input_length = np.array([[9], [8]]).reshape([2, 1]).astype("int32") + + def setUp(self): + self.config() + output, output_length = CTCAlign( + self.input, + self.input_lod, + self.blank, + self.merge_repeated, + self.padding_value, + self.input_length, + ) + self.inputs = { + "Input": (self.input, self.input_lod), + "InputLength": self.input_length, + } + self.outputs = {"Output": output, "OutputLength": output_length} + self.attrs = { + "blank": self.blank, + "merge_repeated": self.merge_repeated, + "padding_value": self.padding_value, + } + + def test_check_output(self): + # NODE(yjjiang11): This op will be deprecated. 
+ self.check_output(check_dygraph=False) + + +class TestCTCAlignOpCase3(TestCTCAlignPaddingOp): + def config(self): + self.op_type = "ctc_align" + self.blank = 0 + self.input_lod = [] + self.merge_repeated = True + self.padding_value = 0 + self.input = ( + np.array( + [[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0], [0, 7, 7, 7, 0, 0]] + ) + .reshape([3, 6]) + .astype("int32") + ) + self.input_length = ( + np.array([[6], [5], [4]]).reshape([3, 1]).astype("int32") + ) + + +class TestCTCAlignOpCase4(TestCTCAlignPaddingOp): + ''' + # test tensor input which has attr input padding_value + ''' + + def config(self): + self.op_type = "ctc_align" + self.blank = 0 + self.input_lod = [] + self.merge_repeated = False + self.padding_value = 0 + self.input = ( + np.array( + [[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0], [0, 7, 7, 7, 0, 0]] + ) + .reshape([3, 6]) + .astype("int32") + ) + self.input_length = ( + np.array([[6], [5], [4]]).reshape([3, 1]).astype("int32") + ) + + +class TestCTCAlignOpCase5(TestCTCAlignPaddingOp): + def config(self): + self.op_type = "ctc_align" + self.blank = 0 + self.input_lod = [] + self.merge_repeated = False + self.padding_value = 1 + self.input = ( + np.array( + [[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0], [0, 7, 1, 7, 0, 0]] + ) + .reshape([3, 6]) + .astype("int32") + ) + self.input_length = ( + np.array([[6], [5], [4]]).reshape([3, 1]).astype("int32") + ) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/test/legacy_test/test_dropout_op.py b/test/legacy_test/test_dropout_op.py index 7d7f8e596ebe75..fcd41414f03d6f 100644 --- a/test/legacy_test/test_dropout_op.py +++ b/test/legacy_test/test_dropout_op.py @@ -1537,7 +1537,9 @@ def forward( def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) @param.parameterized_class( @@ -1753,9 +1755,11 @@ def test_static_comp(self): input_ = paddle.static.data( 'x', shape=self.x.shape, - dtype=self.x.dtype - if self.dtype != "bfloat16" - else "float32", + dtype=( + self.x.dtype + if self.dtype != "bfloat16" + else "float32" + ), ) input_.stop_gradient = False y = paddle.assign(input_) @@ -2103,9 +2107,11 @@ def test_static_comp(self): input_ = paddle.static.data( 'x', shape=self.x.shape, - dtype=self.x.dtype - if self.dtype != "bfloat16" - else "float32", + dtype=( + self.x.dtype + if self.dtype != "bfloat16" + else "float32" + ), ) input_.stop_gradient = False output = paddle.nn.functional.dropout( diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index a5eba6148db816..c410b743a78379 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -26,7 +26,9 @@ from paddle.nn.functional.flash_attention import ( flash_attention, flash_attention_with_sparse_mask, + flash_attn_qkvpacked, flash_attn_unpadded, + flash_attn_varlen_qkvpacked, scaled_dot_product_attention, ) from paddle.pir_utils import test_with_pir_api @@ -956,5 +958,422 @@ def setUp(self): self.causal = True +class TestFlashAttentionVarlenQKVPackedGQA(TestFlashAttentionGQA): + def gen_unpadded_data(self, dtype): + seq_len_q = np.random.randint( + low=1, high=self.seq_len, size=[self.batch_size] + ) + seq_len_k = seq_len_q + cu_seqlen_q = paddle.to_tensor( + [0] + np.cumsum(seq_len_q).tolist(), dtype=paddle.int32 + ) + cu_seqlen_k = cu_seqlen_q 
+ + qs, ks, vs = [], [], [] + for i in range(self.batch_size): + tmp_q = ( + paddle.randn( + [seq_len_q[i] * self.num_head * self.head_dim], dtype=dtype + ) + / 1e2 + ) + tmp_k = ( + paddle.randn( + [ + seq_len_k[i] + * self.num_head + * self.head_dim + // self.num_group + ], + dtype=dtype, + ) + / 1e2 + ) + tmp_v = ( + paddle.randn( + [ + seq_len_k[i] + * self.num_head + * self.head_dim + // self.num_group + ], + dtype=dtype, + ) + / 1e2 + ) + qs.append(tmp_q) + ks.append(tmp_k) + vs.append(tmp_v) + + q = paddle.concat(qs, axis=0).reshape( + [-1, self.num_head, self.head_dim] + ) + k = paddle.concat(ks, axis=0).reshape( + [-1, self.num_head // self.num_group, self.head_dim] + ) + v = paddle.concat(vs, axis=0).reshape( + [-1, self.num_head // self.num_group, self.head_dim] + ) + return q, k, v, cu_seqlen_q, cu_seqlen_k + + def calc_qkvpackedfa( + self, q, k, v, cu_seqlen_q, cu_seqlen_k, out_grad, causal, varlen_padded + ): + q, k, v = self.clone_tensor([q, k, v]) + scale = self.head_dim ** (-0.5) + if varlen_padded: + tq = q.reshape( + [ + self.batch_size * self.seq_len, + self.num_group, + self.num_head // self.num_group, + self.head_dim, + ] + ) + tk = k.reshape( + [ + self.batch_size * self.seq_len, + self.num_head // self.num_group, + self.head_dim, + ] + ) + tv = v.reshape( + [ + self.batch_size * self.seq_len, + self.num_head // self.num_group, + self.head_dim, + ] + ) + kv = paddle.stack([tk, tv], axis=1) + qkv = paddle.concat([tq, kv], axis=1) + out = flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q=cu_seqlen_q, + cu_seqlens_k=cu_seqlen_k, + max_seqlen_q=self.seq_len, + max_seqlen_k=self.seq_len, + scale=scale, + causal=causal, + varlen_padded=varlen_padded, + ) + out_grad = out_grad.reshape(out[0].shape) + else: + tq = q.reshape( + [ + 0, + self.num_group, + self.num_head // self.num_group, + self.head_dim, + ] + ) + kv = paddle.stack([k, v], axis=1) + qkv = paddle.concat([tq, kv], axis=1) + out = flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q=cu_seqlen_q, + cu_seqlens_k=cu_seqlen_k, + max_seqlen_q=self.seq_len, + max_seqlen_k=self.seq_len, + scale=scale, + causal=causal, + varlen_padded=varlen_padded, + ) + out = out[0] + grads = paddle.grad(outputs=out, inputs=qkv, grad_outputs=out_grad) + qkvgrad = grads[0] + out = out.reshape(q.shape) + qgrad = qkvgrad[:, :-2].reshape(q.shape) + kgrad = qkvgrad[:, -2].reshape(k.shape) + vgrad = qkvgrad[:, -1].reshape(v.shape) + if varlen_padded: + out = self.unpad(out, cu_seqlen_q) + qgrad = self.unpad(qgrad, cu_seqlen_q) + kgrad = self.unpad(kgrad, cu_seqlen_k) + vgrad = self.unpad(vgrad, cu_seqlen_k) + return self.convert_dtype([out, qgrad, kgrad, vgrad]) + + def test_main(self): + for causal in [False, True]: + for varlen_padded in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, True) + if varlen_padded: + q_pad, _ = self.pad(q, cu_seqlen_q, self.seq_len) + k_pad, _ = self.pad(k, cu_seqlen_k, self.seq_len) + v_pad, _ = self.pad(v, cu_seqlen_k, self.seq_len) + out_grad_pad, _ = self.pad( + out_grad, cu_seqlen_q, self.seq_len + ) + else: + q_pad = q + k_pad = k + v_pad = v + out_grad_pad = out_grad + fa_out = self.calc_qkvpackedfa( + q_pad, + k_pad, + v_pad, + cu_seqlen_q, + cu_seqlen_k, + out_grad_pad, + causal, + varlen_padded, + ) + # if varlen_padded: + # cu_seqlen_q = None + # cu_seqlen_k = None + raw_out = self.calc_raw_attn( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + True, + ) + assert len(fa_out) == len(raw_out) + for t1, t2 in zip(fa_out, 
raw_out): + np.testing.assert_allclose(t1, t2, atol=1e-2, rtol=1e-2) + + +class TestFlashAttentionVarlenQKVPackedGQA2( + TestFlashAttentionVarlenQKVPackedGQA +): + def setUp(self): + self.batch_size = 2 + self.num_head = 16 + self.seq_len = 2048 + self.head_dim = 128 + self.num_group = 4 + self.dtype = 'bfloat16' + + +class TestFlashAttentionVarlenQKVPacked(TestFlashAttentionVarlenQKVPackedGQA): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPackedGQA(TestFlashAttentionGQA): + def calc_qkvpackedfa(self, q, k, v, out_grad, causal): + # q, k, v = self.clone_tensor([q, k, v]) + tq = q.reshape( + [ + self.batch_size, + self.seq_len, + self.num_group, + self.num_head // self.num_group, + self.head_dim, + ], + ) + kv = paddle.stack([k, v], axis=2) + qkv = paddle.concat([tq, kv], axis=2) + (qkv,) = self.clone_tensor([qkv]) + out = flash_attn_qkvpacked(qkv, causal=causal) + out = out[0] + out.backward(out_grad) + qkvgrad = qkv.grad + qgrad = qkvgrad[:, :, :-2].reshape(q.shape) + kgrad = qkvgrad[:, :, -2].reshape(k.shape) + vgrad = qkvgrad[:, :, -1].reshape(v.shape) + return self.convert_dtype([out, qgrad, kgrad, vgrad]) + + def test_main(self): + for causal in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, False) + fa_out = self.calc_qkvpackedfa(q, k, v, out_grad, causal) + raw_out = self.calc_raw_attn( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + False, + ) + assert len(fa_out) == len(raw_out) + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_allclose(t1, t2, atol=1e-2, rtol=1e-2) + + +class TestFlashAttentionQKVPackedGQA2(TestFlashAttentionQKVPackedGQA): + def setUp(self): + self.batch_size = 2 + self.num_head = 16 + self.seq_len = 2048 + self.head_dim = 128 + self.num_group = 4 + self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPacked(TestFlashAttentionQKVPackedGQA): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + +class TestFlashAttentionVarlenQKVPackedGQADeter( + TestFlashAttentionVarlenQKVPackedGQA +): + def test_main(self): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + for causal in [False, True]: + for varlen_padded in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, True) + if varlen_padded: + q_pad, _ = self.pad(q, cu_seqlen_q, self.seq_len) + k_pad, _ = self.pad(k, cu_seqlen_k, self.seq_len) + v_pad, _ = self.pad(v, cu_seqlen_k, self.seq_len) + out_grad_pad, _ = self.pad( + out_grad, cu_seqlen_q, self.seq_len + ) + else: + q_pad = q + k_pad = k + v_pad = v + out_grad_pad = out_grad + fa_out = self.calc_qkvpackedfa( + q_pad, + k_pad, + v_pad, + cu_seqlen_q, + cu_seqlen_k, + out_grad_pad, + causal, + varlen_padded, + ) + # cu_seqlen_q = None + # cu_seqlen_k = None + raw_out = self.calc_fa( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + True, + ) + assert len(fa_out) == len(raw_out) + i = 0 + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_array_equal( + t1, + t2, + err_msg=f"Tensor{i} causal={causal} varlen_padded={varlen_padded}", + ) + i += 1 + paddle.set_flags({'FLAGS_cudnn_deterministic': 0}) + + +# can't bit-match dk,dv now when num_group more than 2, since the sum kernel is different and sum sequence not defined +# class 
TestFlashAttentionVarlenQKVPackedGQADeter2( +# TestFlashAttentionVarlenQKVPackedGQADeter +# ): +# def setUp(self): +# self.batch_size = 2 +# self.num_head = 16 +# self.seq_len = 2048 +# self.head_dim = 128 +# self.num_group = 4 +# self.dtype = 'bfloat16' + + +class TestFlashAttentionVarlenQKVPackedDeter( + TestFlashAttentionVarlenQKVPackedGQADeter +): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPackedGQADeter(TestFlashAttentionQKVPackedGQA): + def test_main(self): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + for causal in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, False) + fa_out = self.calc_qkvpackedfa(q, k, v, out_grad, causal) + raw_out = self.calc_fa( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + False, + ) + assert len(fa_out) == len(raw_out) + i = 0 + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_array_equal( + t1, t2, err_msg=f"Tensor{i} error, causal={causal}" + ) + i += 1 + paddle.set_flags({'FLAGS_cudnn_deterministic': 0}) + + +# can't bit-match dk,dv now when num_group more than 2, since the sum kernel is different and sum sequence not defined +# class TestFlashAttentionQKVPackedDeter2(TestFlashAttentionQKVPackedGQADeter): +# def setUp(self): +# self.batch_size = 2 +# self.num_head = 16 +# self.seq_len = 2048 +# self.head_dim = 128 +# self.num_group = 4 +# self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPackedDeter(TestFlashAttentionQKVPackedGQADeter): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_group_norm_op.py b/test/legacy_test/test_group_norm_op.py index 4e8b18f13f47ba..f097df3b0b99cf 100644 --- a/test/legacy_test/test_group_norm_op.py +++ b/test/legacy_test/test_group_norm_op.py @@ -28,25 +28,58 @@ from utils import static_guard import paddle +import paddle.nn.functional as F from paddle import base from paddle.base import core def group_norm_naive(x, scale, bias, epsilon, groups, data_layout): - if data_layout == "NHWC": - x = np.transpose(x, (0, 3, 1, 2)) # NHWC => NCHW - N, C, H, W = x.shape - G = groups - x = x.reshape((N * G, -1)) - mean = np.mean(x, axis=1, keepdims=True) - var = np.var(x, axis=1, keepdims=True) - output = (x - mean) / np.sqrt(var + epsilon) - output = output.reshape((N, C, H, W)) * scale.reshape( - (-1, 1, 1) - ) + bias.reshape((-1, 1, 1)) - if data_layout == "NHWC": - output = np.transpose(output, (0, 2, 3, 1)) # NCHW => NHWC - return output, mean.reshape((N, G)), var.reshape((N, G)) + dim = x.ndim + if dim == 3: + if data_layout == "NHWC": + x = np.transpose(x, (0, 2, 1)) # NLC => NCL + N, C, L = x.shape + G = groups + x = x.reshape((N * G, -1)) + mean = np.mean(x, axis=1, keepdims=True) + var = np.var(x, axis=1, keepdims=True) + output = (x - mean) / np.sqrt(var + epsilon) + output = output.reshape((N, C, L)) * scale.reshape( + (-1, 1) + ) + bias.reshape((-1, 1)) + if data_layout == "NHWC": + output = np.transpose(output, (0, 2, 1)) # NCL => NLC + return output, mean.reshape((N, G)), var.reshape((N, G)) + elif dim == 4: + if data_layout == "NHWC": + x = np.transpose(x, (0, 3, 1, 2)) # NHWC => NCHW + N, C, H, W = x.shape + G = groups + x = x.reshape((N * G, -1)) + mean = np.mean(x, axis=1, keepdims=True) + var 
= np.var(x, axis=1, keepdims=True) + output = (x - mean) / np.sqrt(var + epsilon) + output = output.reshape((N, C, H, W)) * scale.reshape( + (-1, 1, 1) + ) + bias.reshape((-1, 1, 1)) + if data_layout == "NHWC": + output = np.transpose(output, (0, 2, 3, 1)) # NCHW => NHWC + return output, mean.reshape((N, G)), var.reshape((N, G)) + else: + if data_layout == "NHWC": + x = np.transpose(x, (0, 4, 1, 2, 3)) # NDHWC => NCDHW + N, C, D, H, W = x.shape + G = groups + x = x.reshape((N * G, -1)) + mean = np.mean(x, axis=1, keepdims=True) + var = np.var(x, axis=1, keepdims=True) + output = (x - mean) / np.sqrt(var + epsilon) + output = output.reshape((N, C, D, H, W)) * scale.reshape( + (-1, 1, 1, 1) + ) + bias.reshape((-1, 1, 1, 1)) + if data_layout == "NHWC": + output = np.transpose(output, (0, 2, 3, 4, 1)) # NCDHW => NDHWC + return output, mean.reshape((N, G)), var.reshape((N, G)) class TestGroupNormOpError(unittest.TestCase): @@ -93,11 +126,15 @@ def setUp(self): self.shape = (2, 100, 3, 5) self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"} self.compare_between_place = False + self.channel_last = False self.init_test_case() + self.data_format = 'NHWC' if self.channel_last else 'NCHW' input = np.random.random(self.shape).astype(self.dtype) - if self.data_format == "NHWC": - input = np.transpose(input, (0, 2, 3, 1)) + if self.channel_last: + shape = list(self.shape) + shape.insert(len(shape), shape.pop(1)) + input = input.reshape(shape) scale = np.random.random([self.shape[1]]).astype(self.dtype) bias = np.random.random([self.shape[1]]).astype(self.dtype) output, mean, var = group_norm_naive( @@ -267,11 +304,15 @@ def setUp(self): self.shape = (2, 100, 3, 5) self.attrs = {'epsilon': 1e-5, 'groups': 10, 'data_layout': "NCHW"} self.compare_between_place = False + self.channel_last = False self.init_test_case() + self.data_format = 'NHWC' if self.channel_last else 'NCHW' input = np.random.random(self.shape).astype(np.float32) - if self.data_format == "NHWC": - input = np.transpose(input, (0, 2, 3, 1)) + if self.channel_last: + shape = list(self.shape) + shape.insert(len(shape), shape.pop(1)) + input = input.reshape(shape) scale = np.random.random([self.shape[1]]).astype(np.float32) bias = np.random.random([self.shape[1]]).astype(np.float32) output, mean, var = group_norm_naive( @@ -318,7 +359,7 @@ def test_check_grad(self): self.rev_comp_atol = 1e-2 self.rev_comp_rtol = 1e-2 # prim bf16 has diff in windows - if sys.platform == "win32" or self.data_format == "NHWC": + if sys.platform == "win32" or self.channel_last: self.rev_comp_atol = 5e-2 self.rev_comp_rtol = 5e-2 place = core.CUDAPlace(0) @@ -339,17 +380,61 @@ def init_test_case(self): self.attrs['groups'] = 1 +class TestGroupNormOp1_with_NCL(TestGroupNormOp): + def init_test_case(self): + self.shape = (2, 100, 3) + self.data_format = "NCHW" + self.attrs['groups'] = 1 + + +class TestGroupNormOp1_with_NCDHW(TestGroupNormOp): + def init_test_case(self): + self.shape = (2, 100, 3, 2, 2) + self.data_format = "NCDHW" + self.attrs['groups'] = 1 + + class TestGroupNormFP16Op1(TestGroupNormFP16OP): def init_test_case(self): self.attrs['groups'] = 1 self.dtype = np.float16 +class TestGroupNormFP16Op1_with_NCL(TestGroupNormFP16OP): + def init_test_case(self): + self.shape = (2, 100, 3) + self.data_format = "NCL" + self.attrs['groups'] = 1 + self.dtype = np.float16 + + +class TestGroupNormFP16Op1_with_NCDHW(TestGroupNormFP16OP): + def init_test_case(self): + self.shape = (2, 100, 3, 2, 2) + self.data_format = "NCDHW" + self.attrs['groups'] = 1 
+ self.dtype = np.float16 + + class TestGroupNormBF16Op1(TestGroupNormBF16Op): def init_test_case(self): self.attrs['groups'] = 1 +class TestGroupNormBF16Op1_with_NCL(TestGroupNormBF16Op): + def init_test_case(self): + self.shape = (2, 100, 3) + self.data_format = "NCL" + self.attrs['groups'] = 1 + + +class TestGroupNormBF16Op1_with_NCDHW(TestGroupNormBF16Op): + def init_test_case(self): + self.shape = (2, 100, 3, 2, 2) + self.data_format = "NCDHW" + self.attrs['groups'] = 1 + + class TestGroupNormOp2(TestGroupNormOp): def init_test_case(self): self.attrs['groups'] = 4 @@ -400,12 +485,30 @@ class TestGroupNormOp1_With_NHWC(TestGroupNormOp): def init_test_case(self): self.attrs['groups'] = 2 self.data_format = "NHWC" + self.channel_last = True + + +class TestGroupNormOp1_With_NLC(TestGroupNormOp): + def init_test_case(self): + self.shape = (2, 100, 3) + self.attrs['groups'] = 2 + self.data_format = "NLC" + self.channel_last = True + + +class TestGroupNormOp1_With_NDHWC(TestGroupNormOp): + def init_test_case(self): + self.shape = (2, 100, 3, 2, 2) + self.attrs['groups'] = 2 + self.data_format = "NDHWC" + self.channel_last = True class TestGroupNormOp2_With_NHWC(TestGroupNormOp): def init_test_case(self): self.attrs['groups'] = 4 self.data_format = "NHWC" + self.channel_last = True class TestGroupNormFP16Op_With_NHWC(TestGroupNormFP16OP): @@ -416,6 +519,7 @@ def init_test_case(self): self.attrs['epsilon'] = 0.5 self.shape = (1, 100, 4, 4) self.dtype = np.float16 + self.channel_last = True def test_check_output(self): rtol = 2e-3 @@ -430,6 +534,45 @@ def test_check_output(self): check_pir=True, ) + def test_check_grad(self): + if self.compare_between_place: + return + + check_prim_grad = False + self.rev_comp_atol = 1e-2 + self.rev_comp_rtol = 1e-2 + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + ['X', 'Scale', 'Bias'], + 'Y', + check_pir=True, + check_prim_pir=check_prim_grad, + max_relative_error=0.03, + ) + + +class TestGroupNormFP16Op_With_NLC(TestGroupNormFP16Op_With_NHWC): + def init_test_case(self): + self.no_need_check_inplace = True + self.attrs['groups'] = 2 + self.data_format = "NLC" + self.attrs['epsilon'] = 0.5 + self.shape = (1, 100, 10) + self.dtype = np.float16 + self.channel_last = True + + +class TestGroupNormFP16Op_With_NDHWC(TestGroupNormFP16Op_With_NHWC): + def init_test_case(self): + self.no_need_check_inplace = True + self.attrs['groups'] = 10 + self.data_format = "NDHWC" + self.attrs['epsilon'] = 0.5 + self.shape = (1, 100, 4, 3, 2) + self.dtype = np.float16 + self.channel_last = True + class TestGroupNormBF16Op_With_NHWC(TestGroupNormBF16Op): def setUp(self): @@ -449,20 +592,14 @@ def setUp(self): } self.compare_between_place = False self.init_test_case() + self.data_format = 'NCHW' if self.data_format[1] == 'C' else 'NHWC' input = ( - np.sin( - np.arange( - self.shape[0] - * self.shape[1] - * self.shape[2] - * self.shape[3] - ) - ) + np.sin(np.arange(np.prod(self.shape))) .reshape(self.shape) .astype(np.float32) ) - scale = np.ones(self.shape[3]).astype(np.float32) - bias = np.sin(np.arange(self.shape[3])).astype(np.float32) + scale = np.ones(self.shape[-1]).astype(np.float32) + bias = np.sin(np.arange(self.shape[-1])).astype(np.float32) output, mean, var = group_norm_naive( input, scale, @@ -490,11 +627,46 @@ def test_check_output(self): ) +class TestGroupNormBF16Op_With_NLC(TestGroupNormBF16Op_With_NHWC): + def init_test_case(self): + self.shape = (1, 3, 512) + self.data_format = "NLC" + + +class 
TestGroupNormBF16Op_With_NDHWC(TestGroupNormBF16Op_With_NHWC): + def init_test_case(self): + self.shape = (1, 3, 2, 2, 512) + self.data_format = "NDHWC" + + def test_check_grad(self): + if self.compare_between_place: + return + + check_prim_grad = False + + self.rev_comp_atol = 1e-2 + self.rev_comp_rtol = 1e-2 + # prim bf16 has diff in windows + if sys.platform == "win32" or self.channel_last: + self.rev_comp_atol = 5e-2 + self.rev_comp_rtol = 5e-2 + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + ['X', 'Scale', 'Bias'], + 'Y', + check_pir=True, + check_prim_pir=check_prim_grad, + max_relative_error=0.03, + ) + + class TestGroupNormOpBigEps1_With_NHWC(TestGroupNormOp): def init_test_case(self): self.attrs['groups'] = 1 self.attrs['epsilon'] = 0.5 self.data_format = "NHWC" + self.channel_last = True class TestGroupNormOpBigEps2_With_NHWC(TestGroupNormOp): @@ -502,12 +674,14 @@ def init_test_case(self): self.attrs['groups'] = 4 self.attrs['epsilon'] = 0.5 self.data_format = "NHWC" + self.channel_last = True class TestGroupNormOpBigEps3_With_NHWC(TestGroupNormOp): def init_test_case(self): self.attrs['epsilon'] = 0.5 self.data_format = "NHWC" + self.channel_last = True @skip_check_grad_ci( @@ -520,6 +694,7 @@ def init_test_case(self): self.attrs['groups'] = 8 self.data_format = "NHWC" self.compare_between_place = True + self.channel_last = True class TestGroupNormAPI_With_NHWC(unittest.TestCase): @@ -571,6 +746,134 @@ def test_case1(self): np.testing.assert_allclose(results[1], expect_res2[0], rtol=1e-05) +class TestGroupNormFunctionalAPI_With_NLC(unittest.TestCase): + def test_case1(self): + places = [paddle.CPUPlace()] + if base.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static(place) + data1_np = np.random.random((2, 3, 4)).astype("float64") + data2_np = np.random.random((2, 4, 3)).astype("float64") + data1 = paddle.to_tensor(data1_np) + data2 = paddle.to_tensor(data2_np) + scale = paddle.to_tensor([1, 1, 1, 1], dtype="float64") + bias = paddle.to_tensor([0, 0, 0, 0], dtype="float64") + out1 = F.group_norm( + data1, num_groups=2, weight=scale, bias=bias, data_format="NLC" + ) + out2 = F.group_norm( + data2, num_groups=2, weight=scale, bias=bias, data_format="NCL" + ) + + expect_res1 = group_norm_naive( + data1_np, + scale, + bias, + epsilon=1e-5, + groups=2, + data_layout="NHWC", + ) + expect_res2 = group_norm_naive( + data2_np, + scale, + bias, + epsilon=1e-5, + groups=2, + data_layout="NCHW", + ) + np.testing.assert_allclose(out1.numpy(), expect_res1[0], rtol=1e-05) + np.testing.assert_allclose(out2.numpy(), expect_res2[0], rtol=1e-05) + + +class TestGroupNormFunctionalAPI_With_NHWC(unittest.TestCase): + def test_case1(self): + places = [paddle.CPUPlace()] + if base.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static(place) + data1_np = np.random.random((2, 3, 2, 4)).astype("float64") + data2_np = np.random.random((2, 4, 3, 2)).astype("float64") + data1 = paddle.to_tensor(data1_np) + data2 = paddle.to_tensor(data2_np) + scale = paddle.to_tensor([1, 1, 1, 1], dtype="float64") + bias = paddle.to_tensor([0, 0, 0, 0], dtype="float64") + out1 = F.group_norm( + data1, num_groups=2, weight=scale, bias=bias, data_format="NHWC" + ) + out2 = F.group_norm( + data2, num_groups=2, weight=scale, bias=bias, data_format="NCHW" + ) + + expect_res1 = group_norm_naive( + data1_np, + scale, + bias, + epsilon=1e-5, + groups=2, + data_layout="NHWC", + ) + expect_res2 = 
group_norm_naive( + data2_np, + scale, + bias, + epsilon=1e-5, + groups=2, + data_layout="NCHW", + ) + np.testing.assert_allclose(out1.numpy(), expect_res1[0], rtol=1e-05) + np.testing.assert_allclose(out2.numpy(), expect_res2[0], rtol=1e-05) + + +class TestGroupNormFunctionalAPI_With_NDHWC(unittest.TestCase): + def test_case1(self): + places = [paddle.CPUPlace()] + if base.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static(place) + data1_np = np.random.random((2, 3, 2, 2, 4)).astype("float64") + data2_np = np.random.random((2, 4, 3, 2, 2)).astype("float64") + data1 = paddle.to_tensor(data1_np) + data2 = paddle.to_tensor(data2_np) + scale = paddle.to_tensor([1, 1, 1, 1], dtype="float64") + bias = paddle.to_tensor([0, 0, 0, 0], dtype="float64") + out1 = F.group_norm( + data1, + num_groups=2, + weight=scale, + bias=bias, + data_format="NDHWC", + ) + out2 = F.group_norm( + data2, + num_groups=2, + weight=scale, + bias=bias, + data_format="NCDHW", + ) + + expect_res1 = group_norm_naive( + data1_np, + scale, + bias, + epsilon=1e-5, + groups=2, + data_layout="NHWC", + ) + expect_res2 = group_norm_naive( + data2_np, + scale, + bias, + epsilon=1e-5, + groups=2, + data_layout="NCHW", + ) + np.testing.assert_allclose(out1.numpy(), expect_res1[0], rtol=1e-05) + np.testing.assert_allclose(out2.numpy(), expect_res2[0], rtol=1e-05) + + class TestGroupNormException(unittest.TestCase): # data_layout is not NHWC or NCHW def test_exception(self): @@ -690,7 +993,9 @@ def forward(self, x): def apply_to_static(net, use_cinn): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) + return paddle.jit.to_static( + net, build_strategy=build_strategy, full_graph=True + ) # The original GroupNorm cannot support NHWC format diff --git a/test/legacy_test/test_group_norm_op_v2.py b/test/legacy_test/test_group_norm_op_v2.py index 94232e7f70a00e..618e740072e958 100644 --- a/test/legacy_test/test_group_norm_op_v2.py +++ b/test/legacy_test/test_group_norm_op_v2.py @@ -21,9 +21,15 @@ from paddle.base import core -def group_norm_naive_for_general_dimension(x, scale, bias, epsilon, groups): +def group_norm_naive_for_general_dimension( + x, scale, bias, epsilon, groups, channel_last=False +): # original version group norm only support 4-D tensor # this function generalizes to support differnt dimensions tensor (>= 2-D) + if channel_last: + shape = list(range(x.ndim)) + shape.insert(1, shape.pop(-1)) + x = x.transpose(shape) input_shape = x.shape N, C = x.shape[0], x.shape[1] G = groups @@ -32,8 +38,12 @@ def group_norm_naive_for_general_dimension(x, scale, bias, epsilon, groups): var = np.var(x, axis=1, keepdims=True) output = (x - mean) / np.sqrt(var + epsilon) output = output.reshape(input_shape) * scale.reshape( - (-1, 1, 1) - ) + bias.reshape((-1, 1, 1)) + [-1] + [1] * (x.ndim - 2) + ) + bias.reshape([-1] + [1] * (x.ndim - 2)) + if channel_last: + shape = list(range(output.ndim)) + shape.insert(len(shape), shape.pop(1)) + output = output.transpose(shape) return output @@ -73,6 +83,176 @@ def test_numerical_accuracy(self): self.assertTrue(np.allclose(result2, expect_res2, atol=1e-5)) +class TestGroupNormAPIV2_With_NCL(unittest.TestCase): + def test_numerical_accuracy(self): + paddle.disable_static() + shape = (2, 6, 4) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + 
places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6 + ) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2 + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NCL' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NCL' + ) + data_pd = paddle.to_tensor(data) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose(result1, expect_res1, atol=1e-5) + np.testing.assert_allclose(result2, expect_res2, atol=1e-5) + + +class TestGroupNormAPIV2_With_NCDHW(unittest.TestCase): + def test_numerical_accuracy(self): + paddle.disable_static() + shape = (2, 6, 4, 2, 2) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6 + ) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2 + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NCDHW' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NCDHW' + ) + data_pd = paddle.to_tensor(data) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose(result1, expect_res1, atol=1e-5) + np.testing.assert_allclose(result2, expect_res2, atol=1e-5) + + +class TestGroupNormAPIV2_With_NLC(unittest.TestCase): + def test_numerical_accuracy(self): + paddle.disable_static() + shape = (2, 4, 6) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6, channel_last=True + ) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2, channel_last=True + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NLC' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NLC' + ) + data_pd = paddle.to_tensor(data) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose(result1, expect_res1, atol=1e-5) + np.testing.assert_allclose(result2, expect_res2, atol=1e-5) + + +class TestGroupNormAPIV2_With_NHWC(unittest.TestCase): + def test_numerical_accuracy(self): + paddle.disable_static() + shape = (2, 4, 2, 6) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + 
expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6, channel_last=True + ) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2, channel_last=True + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NHWC' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NHWC' + ) + data_pd = paddle.to_tensor(data) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose(result1, expect_res1, atol=1e-5) + np.testing.assert_allclose(result2, expect_res2, atol=1e-5) + + +class TestGroupNormAPIV2_With_NDHWC(unittest.TestCase): + def test_numerical_accuracy(self): + paddle.disable_static() + shape = (2, 4, 2, 2, 6) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6, channel_last=True + ) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2, channel_last=True + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NDHWC' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NDHWC' + ) + data_pd = paddle.to_tensor(data) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose(result1, expect_res1, atol=1e-5) + np.testing.assert_allclose(result2, expect_res2, atol=1e-5) + + class TestGroupNormAPIV2_With_General_Dimensions_fp16(unittest.TestCase): def test_numerical_accuracy(self): # fp16 only supported in cuda @@ -121,6 +301,231 @@ def test_numerical_accuracy(self): ) +class TestGroupNormAPIV2_With_NCL_fp16(unittest.TestCase): + def test_numerical_accuracy(self): + if not core.is_compiled_with_cuda(): + return + paddle.disable_static() + shape = (2, 6, 4) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6 + ) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2 + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NCL' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NCL' + ) + paddle.assign(paddle.cast(gn1.weight, 'float16'), gn1.weight) + paddle.assign(paddle.cast(gn1.bias, 'float16'), gn1.bias) + paddle.assign(paddle.cast(gn2.weight, 'float16'), gn2.weight) + paddle.assign(paddle.cast(gn2.bias, 'float16'), gn2.bias) + + data_pd = paddle.to_tensor(data.astype('float16')) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose( + result1, expect_res1, rtol=1e-2, atol=1e-3 + ) + np.testing.assert_allclose( + result2, expect_res2, rtol=1e-2, atol=1e-3 + ) + + +class TestGroupNormAPIV2_With_NCDHW_fp16(unittest.TestCase): + def test_numerical_accuracy(self): + if not 
core.is_compiled_with_cuda(): + return + paddle.disable_static() + shape = (2, 6, 4, 2, 2) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6 + ) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2 + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NCDHW' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NCDHW' + ) + paddle.assign(paddle.cast(gn1.weight, 'float16'), gn1.weight) + paddle.assign(paddle.cast(gn1.bias, 'float16'), gn1.bias) + paddle.assign(paddle.cast(gn2.weight, 'float16'), gn2.weight) + paddle.assign(paddle.cast(gn2.bias, 'float16'), gn2.bias) + + data_pd = paddle.to_tensor(data.astype('float16')) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose( + result1, expect_res1, rtol=1e-2, atol=1e-2 + ) + np.testing.assert_allclose( + result2, expect_res2, rtol=1e-2, atol=1e-2 + ) + + +class TestGroupNormAPIV2_With_NLC_fp16(unittest.TestCase): + def test_numerical_accuracy(self): + if not core.is_compiled_with_cuda(): + return + paddle.disable_static() + shape = (2, 4, 6) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6, channel_last=True + ) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2, channel_last=True + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NLC' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NLC' + ) + paddle.assign(paddle.cast(gn1.weight, 'float16'), gn1.weight) + paddle.assign(paddle.cast(gn1.bias, 'float16'), gn1.bias) + paddle.assign(paddle.cast(gn2.weight, 'float16'), gn2.weight) + paddle.assign(paddle.cast(gn2.bias, 'float16'), gn2.bias) + + data_pd = paddle.to_tensor(data.astype('float16')) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose( + result1, expect_res1, rtol=1e-2, atol=1e-3 + ) + np.testing.assert_allclose( + result2, expect_res2, rtol=1e-2, atol=1e-3 + ) + + +class TestGroupNormAPIV2_With_NHWC_fp16(unittest.TestCase): + def test_numerical_accuracy(self): + if not core.is_compiled_with_cuda(): + return + paddle.disable_static() + shape = (2, 4, 2, 6) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6, channel_last=True + ) + expect_res2 = 
group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2, channel_last=True + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NHWC' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NHWC' + ) + paddle.assign(paddle.cast(gn1.weight, 'float16'), gn1.weight) + paddle.assign(paddle.cast(gn1.bias, 'float16'), gn1.bias) + paddle.assign(paddle.cast(gn2.weight, 'float16'), gn2.weight) + paddle.assign(paddle.cast(gn2.bias, 'float16'), gn2.bias) + + data_pd = paddle.to_tensor(data.astype('float16')) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose( + result1, expect_res1, rtol=1e-2, atol=1e-3 + ) + np.testing.assert_allclose( + result2, expect_res2, rtol=1e-2, atol=1e-3 + ) + + +class TestGroupNormAPIV2_With_NDHWC_fp16(unittest.TestCase): + def test_numerical_accuracy(self): + if not core.is_compiled_with_cuda(): + return + paddle.disable_static() + shape = (2, 4, 2, 2, 6) + np.random.seed(10) + places = [base.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(base.CUDAPlace(0)) + + for place in places: + paddle.disable_static(place) + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6, channel_last=True + ) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2, channel_last=True + ) + + gn1 = paddle.nn.GroupNorm( + num_channels=6, num_groups=6, data_format='NDHWC' + ) + gn2 = paddle.nn.GroupNorm( + num_channels=6, num_groups=2, data_format='NDHWC' + ) + paddle.assign(paddle.cast(gn1.weight, 'float16'), gn1.weight) + paddle.assign(paddle.cast(gn1.bias, 'float16'), gn1.bias) + paddle.assign(paddle.cast(gn2.weight, 'float16'), gn2.weight) + paddle.assign(paddle.cast(gn2.bias, 'float16'), gn2.bias) + + data_pd = paddle.to_tensor(data.astype('float16')) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + np.testing.assert_allclose( + result1, expect_res1, rtol=1e-2, atol=1e-2 + ) + np.testing.assert_allclose( + result2, expect_res2, rtol=1e-2, atol=1e-2 + ) + + class TestGroupNormDimException(unittest.TestCase): def test_exception(self): def test_empty_input_static_API(): diff --git a/test/legacy_test/test_jit_save_load.py b/test/legacy_test/test_jit_save_load_rename.py similarity index 85% rename from test/legacy_test/test_jit_save_load.py rename to test/legacy_test/test_jit_save_load_rename.py index c9bb373424e772..aa0d0290131fbe 100644 --- a/test/legacy_test/test_jit_save_load.py +++ b/test/legacy_test/test_jit_save_load_rename.py @@ -28,6 +28,7 @@ from paddle.jit.api import to_static from paddle.jit.translated_layer import INFER_PARAMS_INFO_SUFFIX from paddle.nn import Linear +from paddle.pir_utils import test_with_dygraph_pir from paddle.static import InputSpec BATCH_SIZE = 32 @@ -67,7 +68,10 @@ def __init__(self, in_size, out_size): super().__init__() self._linear = Linear(in_size, out_size) - @to_static(input_spec=[InputSpec(shape=[None, 784], dtype='float32')]) + @to_static( + input_spec=[InputSpec(shape=[None, 784], dtype='float32')], + full_graph=True, + ) def forward(self, x): return self._linear(x) @@ -86,12 +90,6 @@ def __init__(self, in_size, out_size): super().__init__() self._linear = Linear(in_size, out_size) - @to_static( - input_spec=[ - InputSpec(shape=[None, 784], 
dtype='float32', name="image"), - InputSpec(shape=[None, 1], dtype='int64', name="label"), - ] - ) def forward(self, x, label): out = self._linear(x) loss = paddle.nn.functional.cross_entropy( @@ -106,12 +104,6 @@ def __init__(self, in_size, out_size): super().__init__() self._linear = Linear(in_size, out_size) - @to_static( - input_spec=[ - InputSpec(shape=[None, 784], dtype='float32', name="image"), - InputSpec(shape=[None, 1], dtype='int64', name="label"), - ] - ) def forward(self, x, label): out = self._linear(x) loss = paddle.nn.functional.cross_entropy( @@ -126,12 +118,6 @@ def __init__(self, in_size, out_size): super().__init__() self._linear = Linear(in_size, out_size) - @to_static( - input_spec=[ - InputSpec(shape=[None, 784], dtype='float32', name="image"), - InputSpec(shape=[None, 1], dtype='int64', name="label"), - ] - ) def forward(self, x, label): out = self._linear(x) return out @@ -156,12 +142,6 @@ def __init__(self, in_size, out_size): self._linear1 = Linear(in_size, out_size) self._linear2 = Linear(in_size, out_size) - @to_static( - input_spec=[ - InputSpec([None, 8], dtype='float32'), - InputSpec([None, 8], dtype='float32'), - ] - ) def forward(self, x, y): x_out = self._linear1(x) y_out = self._linear2(y) @@ -175,12 +155,6 @@ def __init__(self, in_size, out_size): self._linear1 = Linear(in_size, out_size) self._linear2 = Linear(in_size, out_size) - @to_static( - input_spec=( - InputSpec([None, 8], dtype='float32'), - InputSpec([None, 8], dtype='float32'), - ) - ) def forward(self, x, y): x_out = self._linear1(x) y_out = self._linear2(y) @@ -238,12 +212,6 @@ def __init__(self, in_size, out_size): super().__init__() self._linear = Linear(in_size, out_size) - @paddle.jit.to_static( - input_spec=[ - {'img': InputSpec(shape=[None, 8], dtype='float32', name='img')}, - {'label': InputSpec(shape=[None, 1], dtype='int64', name='label')}, - ] - ) def forward(self, img, label): out = self._linear(img['img']) # not return loss to avoid prune output @@ -286,15 +254,12 @@ def __init__(self, in_size, out_size): self._linear_1 = Linear(in_size, out_size) self._scale = paddle.to_tensor([9.9]) - @paddle.jit.to_static def forward(self, x): return self._linear_0(x) - @paddle.jit.to_static def forward_no_param(self, x): return x - @paddle.jit.to_static def forward_general(self, x): return self._linear_0(x) + self._linear_1(x) * self._scale @@ -428,14 +393,16 @@ def train_and_save_model(self, model_path=None): self.assertEqual(orig_input_types, new_input_types) return layer + @test_with_dygraph_pir def test_save_load(self): # train and save model train_layer = self.train_and_save_model() # load model loaded_layer = paddle.jit.load(self.model_path) self.load_and_inference(train_layer, loaded_layer) - self.load_dygraph_state_dict(train_layer) self.load_and_finetune(train_layer, loaded_layer) + if not paddle.framework.use_pir_api(): + self.load_dygraph_state_dict(train_layer) def load_and_inference(self, train_layer, infer_layer): train_layer.eval() @@ -479,6 +446,7 @@ def test_load_dygraph_no_path(self): with self.assertRaises(ValueError): model_dict = paddle.load(model_path) + @test_with_dygraph_pir def test_jit_load_no_path(self): path = os.path.join( self.temp_dir.name, "test_jit_save_load.no_path/model_path" @@ -496,6 +464,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_nest_output(self): x = paddle.to_tensor(np.random.random((4, 8)).astype('float32')) @@ -737,11 +706,28 @@ def dfs(obj1, obj2): class 
TestSaveLoadWithDictInput(unittest.TestCase): + @test_with_dygraph_pir def test_dict_input(self): # NOTE: This net cannot be executed, it is just # a special case for exporting models in model validation # We DO NOT recommend this writing way of Layer net = LinearNetWithDictInput(8, 8) + net = paddle.jit.to_static( + net, + input_spec=[ + { + 'img': InputSpec( + shape=[None, 8], dtype=paddle.float32, name='img' + ) + }, + { + 'label': InputSpec( + shape=[None, 1], dtype=paddle.int64, name='label' + ) + }, + ], + full_graph=True, + ) # net.forward.concrete_program.inputs: # (<__main__.LinearNetWithDictInput object at 0x7f2655298a98>, # {'img': var img : base.VarType.LOD_TENSOR.shape(-1, 8).astype(VarType.FP32)}, @@ -756,7 +742,11 @@ def test_dict_input(self): layer=net, path=path, input_spec=[ - {'img': InputSpec(shape=[None, 8], dtype='float32', name='img')} + { + 'img': InputSpec( + shape=[None, 8], dtype=paddle.float32, name='img' + ) + } ], ) @@ -767,10 +757,12 @@ def test_dict_input(self): # loaded_net._input_spec(): # [InputSpec(shape=(-1, 8), dtype=VarType.FP32, name=img)] self.assertEqual(len(loaded_net._input_spec()), 1) + self.assertEqual(len(loaded_net._output_spec()), 1) temp_dir.cleanup() class TestSaveLoadWithDictInputNoPrune(unittest.TestCase): + @test_with_dygraph_pir def test_dict_input(self): net = LinearNetWithDictInputNoPrune(8, 8) temp_dir = tempfile.TemporaryDirectory() @@ -811,11 +803,14 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_with_input_spec(self): net = LinearNetReturnLoss(8, 8) # set x.shape = [None, 8] net.forward = to_static( - net.forward, input_spec=[InputSpec([None, 8], name='x')] + net.forward, + input_spec=[InputSpec([None, 8], name='x')], + full_graph=True, ) model_path = os.path.join( @@ -824,7 +819,10 @@ def test_with_input_spec(self): # check inputs and outputs self.assertTrue(len(net.forward.inputs) == 1) input_x = net.forward.inputs[0] - self.assertTrue(input_x.shape == (-1, 8)) + if paddle.framework.use_pir_api(): + self.assertTrue(input_x.shape == [-1, 8]) + else: + self.assertTrue(input_x.shape == (-1, 8)) self.assertTrue(input_x.name == 'x') # 1. prune loss @@ -836,8 +834,17 @@ def test_with_input_spec(self): x = paddle.to_tensor(np.random.random((4, 8)).astype('float32')) pred = infer_layer(x) + @test_with_dygraph_pir def test_multi_in_out(self): net = LinearNetMultiInput(8, 8) + net = paddle.jit.to_static( + net, + input_spec=[ + InputSpec([None, 8], dtype='float32'), + InputSpec([None, 8], dtype='float32'), + ], + full_graph=True, + ) model_path = os.path.join( self.temp_dir.name, "multi_inout.output_spec1/model" @@ -846,8 +853,12 @@ def test_multi_in_out(self): self.assertTrue(len(net.forward.inputs) == 2) input_x = net.forward.inputs[0] input_y = net.forward.inputs[1] - self.assertTrue(input_x.shape == (-1, 8)) - self.assertTrue(input_y.shape == (-1, 8)) + if paddle.framework.use_pir_api(): + self.assertTrue(input_x.shape == [-1, 8]) + self.assertTrue(input_y.shape == [-1, 8]) + else: + self.assertTrue(input_x.shape == (-1, 8)) + self.assertTrue(input_y.shape == (-1, 8)) # 2. prune loss output_spec = net.forward.outputs[:2] @@ -874,9 +885,17 @@ def test_multi_in_out(self): # 4. 
assert pred_x == pred_xx np.testing.assert_allclose(pred_x.numpy(), pred_xx.numpy(), rtol=1e-05) + @test_with_dygraph_pir def test_multi_in_out1(self): net = LinearNetMultiInput1(8, 8) - + net = paddle.jit.to_static( + net, + input_spec=( + InputSpec([None, 8], dtype='float32'), + InputSpec([None, 8], dtype='float32'), + ), + full_graph=True, + ) model_path = os.path.join( self.temp_dir.name, "multi_inout1.output_spec1/model" ) @@ -884,8 +903,12 @@ def test_multi_in_out1(self): self.assertTrue(len(net.forward.inputs) == 2) input_x = net.forward.inputs[0] input_y = net.forward.inputs[1] - self.assertTrue(input_x.shape == (-1, 8)) - self.assertTrue(input_y.shape == (-1, 8)) + if paddle.framework.use_pir_api(): + self.assertTrue(input_x.shape == [-1, 8]) + self.assertTrue(input_y.shape == [-1, 8]) + else: + self.assertTrue(input_x.shape == (-1, 8)) + self.assertTrue(input_y.shape == (-1, 8)) # 2. prune loss output_spec = net.forward.outputs[:2] @@ -931,8 +954,14 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_output_spec(self): train_layer = LinearNetReturnLoss(8, 8) + train_layer.forward = to_static( + train_layer.forward, + input_spec=[InputSpec([None, 8], name='x')], + full_graph=True, + ) adam = paddle.optimizer.Adam( learning_rate=0.1, parameters=train_layer.parameters() ) @@ -946,7 +975,7 @@ def test_output_spec(self): model_path = os.path.join( self.temp_dir.name, "save_load_config.output_spec" ) - output_spec = [out] + output_spec = train_layer.forward.outputs[:1] paddle.jit.save( layer=train_layer, path=model_path, @@ -961,22 +990,27 @@ def test_output_spec(self): train_layer(x)[0].numpy(), infer_layer(x).numpy() ) + @test_with_dygraph_pir def test_save_no_support_config_error(self): layer = LinearNet(784, 1) path = os.path.join(self.temp_dir.name, "no_support_config_test") with self.assertRaises(ValueError): paddle.jit.save(layer=layer, path=path, model_filename="") + @test_with_dygraph_pir def test_load_empty_model_filename_error(self): path = os.path.join(self.temp_dir.name, "error_model_filename_test") + with self.assertRaises(ValueError): paddle.jit.load(path, model_filename="") + @test_with_dygraph_pir def test_load_empty_params_filename_error(self): path = os.path.join(self.temp_dir.name, "error_params_filename_test") with self.assertRaises(ValueError): paddle.jit.load(path, params_filename="") + @test_with_dygraph_pir def test_load_with_no_support_config(self): path = os.path.join(self.temp_dir.name, "no_support_config_test") with self.assertRaises(ValueError): @@ -1001,6 +1035,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def train_and_save_orig_model(self): layer = LinearNet(self.linear_size, self.linear_size) example_inputs, layer, _ = train(layer, self.linear_size, 1) @@ -1008,6 +1043,7 @@ def train_and_save_orig_model(self): layer=layer, path=self.model_path, input_spec=example_inputs ) + @test_with_dygraph_pir def test_load_model_retransform_inference(self): multi_loaded_layer = MultiLoadingLinearNet( self.linear_size, self.model_path @@ -1037,6 +1073,11 @@ def tearDown(self): def train_and_save(self): train_layer = LinearNetReturnHidden(8, 8) + train_layer = to_static( + train_layer, + input_spec=[InputSpec([None, 8], name='x')], + full_graph=True, + ) adam = paddle.optimizer.Adam( learning_rate=0.1, parameters=train_layer.parameters() ) @@ -1047,7 +1088,7 @@ def train_and_save(self): adam.minimize(loss) train_layer.clear_gradients() - output_spec = [hidden] + output_spec = 
train_layer.forward.outputs[:1] paddle.jit.save( layer=train_layer, path=self.model_path, @@ -1057,6 +1098,7 @@ def train_and_save(self): return train_layer + @test_with_dygraph_pir def test_load_pruned_model(self): train_layer = self.train_and_save() train_layer.eval() @@ -1068,6 +1110,8 @@ def test_load_pruned_model(self): train_layer(x)[0].numpy(), infer_layer(x).numpy() ) + # pir has no need to save extra var info, param always saved with program, + # and trainable info saved in program's op attr def test_load_var_not_in_extra_var_info(self): self.train_and_save() @@ -1120,6 +1164,7 @@ def verify_inference_correctness( err_msg=f'Result diff when load and inference:\nlayer result:\n{pred}\nloaded layer result:\n{loaded_pred}', ) + @test_with_dygraph_pir def test_no_prune_to_static_after_train(self): layer = LinearNet(784, 1) @@ -1132,6 +1177,7 @@ def test_no_prune_to_static_after_train(self): self.verify_inference_correctness(layer, model_path) + @test_with_dygraph_pir def test_no_prune_to_static_no_train(self): layer = LinearNetWithInputSpec(784, 1) @@ -1142,6 +1188,7 @@ def test_no_prune_to_static_no_train(self): self.verify_inference_correctness(layer, model_path) + @test_with_dygraph_pir def test_no_prune_no_to_static_after_train(self): layer = LinearNetNotDeclarative(784, 1) @@ -1158,6 +1205,7 @@ def test_no_prune_no_to_static_after_train(self): self.verify_inference_correctness(layer, model_path) + @test_with_dygraph_pir def test_no_prune_no_to_static_after_train_with_examples(self): layer = LinearNetNotDeclarative(784, 1) @@ -1171,6 +1219,7 @@ def test_no_prune_no_to_static_after_train_with_examples(self): self.verify_inference_correctness(layer, model_path) + @test_with_dygraph_pir def test_no_prune_no_to_static_no_train(self): layer = LinearNetNotDeclarative(784, 1) @@ -1185,9 +1234,17 @@ def test_no_prune_no_to_static_no_train(self): self.verify_inference_correctness(layer, model_path) + @test_with_dygraph_pir def test_prune_to_static_after_train(self): layer = LinerNetWithLabel(784, 1) - + layer = paddle.jit.to_static( + layer, + input_spec=[ + InputSpec(shape=[None, 784], dtype='float32', name="image"), + InputSpec(shape=[None, 1], dtype='int64', name="label"), + ], + full_graph=True, + ) out = train_with_label(layer) model_path = os.path.join( @@ -1198,9 +1255,8 @@ def test_prune_to_static_after_train(self): model_path, input_spec=[ InputSpec(shape=[None, 784], dtype='float32', name="image"), - True, ], - output_spec=[out], + output_spec=layer.forward.outputs[:1], input_names_after_prune=["image"], ) @@ -1208,9 +1264,17 @@ def test_prune_to_static_after_train(self): layer, model_path, with_label_and_loss=True ) + @test_with_dygraph_pir def test_prune_to_static_no_train(self): layer = LinerNetWithLabel(784, 1) - + layer = paddle.jit.to_static( + layer, + input_spec=[ + InputSpec(shape=[None, 784], dtype='float32', name="image"), + InputSpec(shape=[None, 1], dtype='int64', name="label"), + ], + full_graph=True, + ) model_path = os.path.join( self.temp_dir.name, "test_prune_to_static_no_train/model" ) @@ -1222,7 +1286,6 @@ def test_prune_to_static_no_train(self): model_path, input_spec=[ InputSpec(shape=[None, 784], dtype='float32', name="image"), - True, ], output_spec=output_spec, input_names_after_prune=["image"], @@ -1232,9 +1295,17 @@ def test_prune_to_static_no_train(self): layer, model_path, with_label_and_loss=True ) + @test_with_dygraph_pir def test_prune_input_to_static_no_train(self): layer = LinerNetWithPruneInput(784, 1) - + layer = paddle.jit.to_static( + layer, 
+ input_spec=[ + InputSpec(shape=[None, 784], dtype='float32', name="image"), + InputSpec(shape=[None, 1], dtype='int64', name="label"), + ], + full_graph=True, + ) model_path = os.path.join( self.temp_dir.name, "test_prune_input_to_static_no_train/model" ) @@ -1248,9 +1319,17 @@ def test_prune_input_to_static_no_train(self): self.verify_inference_correctness(layer, model_path, with_label=True) + @test_with_dygraph_pir def test_prune_useless_input_to_static_no_train(self): layer = LinerNetWithUselessInput(784, 1) - + layer = paddle.jit.to_static( + layer, + input_spec=[ + InputSpec(shape=[None, 784], dtype='float32', name="image"), + InputSpec(shape=[None, 1], dtype='int64', name="label"), + ], + full_graph=True, + ) model_path = os.path.join( self.temp_dir.name, "test_prune_useless_input_to_static_no_train/model", @@ -1265,6 +1344,7 @@ def test_prune_useless_input_to_static_no_train(self): self.verify_inference_correctness(layer, model_path, with_label=True) + @test_with_dygraph_pir def test_no_prune_input_spec_name_warning(self): layer = LinearNetWithInputSpec(784, 1) @@ -1288,6 +1368,7 @@ def test_no_prune_input_spec_name_warning(self): self.verify_inference_correctness(layer, model_path) + @test_with_dygraph_pir def test_not_prune_output_spec_name_warning(self): layer = LinearNet(784, 1) @@ -1301,6 +1382,7 @@ def test_not_prune_output_spec_name_warning(self): self.verify_inference_correctness(layer, model_path) + @test_with_dygraph_pir def test_prune_input_spec_name_error(self): layer = LinerNetWithLabel(784, 1) @@ -1324,9 +1406,17 @@ def test_prune_input_spec_name_error(self): ], ) + @test_with_dygraph_pir def test_prune_output_spec_name_error(self): layer = LinerNetWithLabel(784, 1) - + layer = paddle.jit.to_static( + layer, + input_spec=[ + InputSpec(shape=[None, 784], dtype='float32', name="image"), + InputSpec(shape=[None, 1], dtype='int64', name="label"), + ], + full_graph=True, + ) train_with_label(layer) model_path = os.path.join( @@ -1358,6 +1448,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_save_load_empty_layer(self): layer = EmptyLayer() x = paddle.to_tensor(np.random.random(10).astype('float32')) @@ -1380,6 +1471,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_save_load_no_param_layer(self): layer = NoParamLayer() x = paddle.to_tensor(np.random.random(5).astype('float32')) @@ -1400,17 +1492,31 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_jit_save_load_inference(self): model_path_inference = os.path.join( self.temp_dir.name, "jit_save_load_multi_methods/model" ) IMAGE_SIZE = 224 layer = LinearNetWithMultiStaticFunc(IMAGE_SIZE, 10) + layer = paddle.jit.to_static( + layer, + full_graph=True, + ) + layer.forward_no_param = paddle.jit.to_static( + layer.forward_no_param, + full_graph=True, + ) + layer.forward_general = paddle.jit.to_static( + layer.forward_general, + full_graph=True, + ) inps = paddle.randn([1, IMAGE_SIZE]) result_origin = {} for func in dir(layer): if func.startswith('forward'): result_origin[func] = getattr(layer, func, None)(inps) + paddle.jit.save(layer, model_path_inference) load_net = paddle.jit.load(model_path_inference) for func, result in result_origin.items(): @@ -1421,16 +1527,30 @@ def test_jit_save_load_inference(self): < 1e-5 ) + @test_with_dygraph_pir def test_jit_save_load_multi_methods_inputspec(self): model_path = os.path.join( self.temp_dir.name, 'jit_save_load_multi_methods/model' ) 
layer = LinearNetWithMultiStaticFunc(784, 1) + layer = paddle.jit.to_static( + layer, + full_graph=True, + ) + layer.forward_no_param = paddle.jit.to_static( + layer.forward_no_param, + full_graph=True, + ) + layer.forward_general = paddle.jit.to_static( + layer.forward_general, + full_graph=True, + ) with self.assertRaises(ValueError): paddle.jit.save( layer, model_path, input_spec=[InputSpec(shape=[None, 784])] ) + @test_with_dygraph_pir def test_parse_name(self): model_path_inference = os.path.join( self.temp_dir.name, "jit_save_load_parse_name/model" @@ -1456,7 +1576,6 @@ def __init__(self, in_size, out_size): self._linear_2 = Linear(self.hidden, out_size) self._scale = paddle.to_tensor([9.9]) - @paddle.jit.to_static def forward(self, x): y = self._linear_0(x) # Multiple blocks @@ -1467,88 +1586,74 @@ def forward(self, x): return self._linear_2(y) -class Net(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.fc1 = paddle.nn.Linear(4, 4) - self.fc2 = paddle.nn.Linear(4, 4) - self.bias = 0.4 - self.flag = paddle.ones([2], dtype="int32") - - @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')]) - def log_softmax(self, input): - return paddle.nn.functional.log_softmax(input, axis=-1) - - @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')]) - def forward(self, x): - out = self.fc1(x) - out = paddle.nn.functional.relu(out) - out = paddle.mean(out) - return out - - @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')]) - def infer(self, input): - out = self.fc2(input) - out = out + self.bias - out = paddle.mean(out) - return out - - # For extra Python float - @paddle.jit.to_static(property=True) - def fbias(self): - return self.bias + 1 - - @paddle.jit.to_static(property=True) - def down_sampling(self): - return 4 +class TestJitSaveCombineProperty(unittest.TestCase): + def setUp(self): + # enable dygraph mode + paddle.disable_static() + self.temp_dir = tempfile.TemporaryDirectory() - @paddle.jit.to_static(property=True) - def fstr(self): - return "save str property" + def tearDown(self): + self.temp_dir.cleanup() - @paddle.jit.to_static(property=True) - def ints(self): - return [10, 20] + @test_with_dygraph_pir + def test_jit_save_combine_property(self): + class Net(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fc1 = paddle.nn.Linear(4, 4) + self.fc2 = paddle.nn.Linear(4, 4) + self.bias = 0.4 + self.flag = paddle.ones([2], dtype="int32") - @paddle.jit.to_static(property=True) - def floats(self): - return [1.1, 2.2] + @paddle.jit.to_static( + input_spec=[InputSpec([None, 4], dtype='float32')] + ) + def log_softmax(self, input): + return paddle.nn.functional.log_softmax(input, axis=-1) - @paddle.jit.to_static(property=True) - def strs(self): - return ["hello", "world"] + @paddle.jit.to_static( + input_spec=[InputSpec([None, 4], dtype='float32')] + ) + def forward(self, x): + out = self.fc1(x) + out = paddle.nn.functional.relu(out) + out = paddle.mean(out) + return out + @paddle.jit.to_static( + input_spec=[InputSpec([None, 4], dtype='float32')] + ) + def infer(self, input): + out = self.fc2(input) + out = out + self.bias + out = paddle.mean(out) + return out -class NetTensor(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.fc1 = paddle.nn.Linear(4, 4) - self.fc2 = paddle.nn.Linear(4, 4) - self.bias = 0.4 - self.flag = paddle.ones([2], dtype="int32") + # For extra Python float + @paddle.jit.to_static(property=True, full_graph=True) + def fbias(self): + return self.bias + 1 - 
@paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')]) - def forward(self, x): - out = self.fc1(x) - out = paddle.nn.functional.relu(out) - out = paddle.mean(out) - return out + @paddle.jit.to_static(property=True, full_graph=True) + def down_sampling(self): + return 4 - @paddle.jit.to_static(property=True) - def fflag(self): - return True + @paddle.jit.to_static(property=True, full_graph=True) + def fstr(self): + return "save str property" + @paddle.jit.to_static(property=True, full_graph=True) + def ints(self): + return [10, 20] -class TestJitSaveCombineProperty(unittest.TestCase): - def setUp(self): - # enable dygraph mode - paddle.disable_static() - self.temp_dir = tempfile.TemporaryDirectory() + @paddle.jit.to_static(property=True, full_graph=True) + def floats(self): + return [1.1, 2.2] - def tearDown(self): - self.temp_dir.cleanup() + @paddle.jit.to_static(property=True, full_graph=True) + def strs(self): + return ["hello", "world"] - def test_jit_save_combine_property(self): model_path = os.path.join( self.temp_dir.name, "test_jit_save_combine/model" ) @@ -1558,50 +1663,41 @@ def test_jit_save_combine_property(self): # save paddle.jit.save(net, model_path, combine_params=True) + @test_with_dygraph_pir def test_jit_save_tensor_property(self): + class NetTensor(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fc1 = paddle.nn.Linear(4, 4) + self.fc2 = paddle.nn.Linear(4, 4) + self.bias = 0.4 + self.flag = paddle.ones([2], dtype="int32") + + def forward(self, x): + out = self.fc1(x) + out = paddle.nn.functional.relu(out) + out = paddle.mean(out) + return out + + @paddle.jit.to_static(property=True, full_graph=True) + def fflag(self): + return True + model_path = os.path.join( self.temp_dir.name, "test_jit_save_combine/model" ) # Use new namespace with unique_name.guard(): net = NetTensor() + net = paddle.jit.to_static( + net, + input_spec=[InputSpec([None, 4], dtype='float32')], + full_graph=True, + ) paddle.jit.save(net, model_path, combine_params=True) -class LayerLoadFinetune(paddle.nn.Layer): - def __init__(self, in_size, out_size, load_path): - super().__init__() - # Test duplicate name - self._linear_0 = Linear(in_size, in_size) - self._linear_1_0 = Linear(out_size, in_size) - self._linear_1_1 = Linear(out_size, in_size) - self._linear_2 = Linear(out_size, out_size) - self._scale = paddle.to_tensor([9.9]) - - # Load multiple times - self._load_l1 = paddle.jit.load(load_path) - self._load_l2 = paddle.jit.load(load_path) - - @paddle.jit.to_static - def forward(self, x): - y = self._linear_0(x) - y = self._load_l1(y) - # Multiple blocks - if paddle.shape(x)[0] == 1: - y = self._linear_1_0(y) - y = self._load_l1(y) - else: - y += self._linear_1_1(x + self._scale) - y = self._load_l2(y) - y = self._linear_1_0(y) - y = self._load_l1(y) - y = self._linear_1_0(y) - # Use the same layer multiple times. 
- y = self._load_l1(y) - return y - - class TestJitSaveLoadSaveWithoutRunning(unittest.TestCase): def setUp(self): # enable dygraph mode @@ -1621,6 +1717,7 @@ def test_save_load_finetune_load(self): # Use new namespace with unique_name.guard(): layer_save = LayerSaved(IMAGE_SIZE, IMAGE_SIZE) + layer_save = paddle.jit.to_static(layer_save, full_graph=True) # save paddle.jit.save( layer_save, @@ -1654,6 +1751,39 @@ def test_save_load_finetune_load(self): self.assertTrue(float((result_01 - result_11).abs().max()) < 1e-5) +class LayerLoadFinetune(paddle.nn.Layer): + def __init__(self, in_size, out_size, load_path): + super().__init__() + # Test duplicate name + self._linear_0 = Linear(in_size, in_size) + self._linear_1_0 = Linear(out_size, in_size) + self._linear_1_1 = Linear(out_size, in_size) + self._linear_2 = Linear(out_size, out_size) + self._scale = paddle.to_tensor([9.9]) + + # Load multiple times + self._load_l1 = paddle.jit.load(load_path) + self._load_l2 = paddle.jit.load(load_path) + + def forward(self, x): + y = self._linear_0(x) + y = self._load_l1(y) + # Multiple blocks + if paddle.shape(x)[0] == 1: + y = self._linear_1_0(y) + y = self._load_l1(y) + else: + y += self._linear_1_1(x + self._scale) + y = self._load_l2(y) + y = self._linear_1_0(y) + y = self._load_l1(y) + y = self._linear_1_0(y) + # Use the same layer multiple times. + y = self._load_l1(y) + return y + + +''' class TestJitSaveLoadFinetuneLoad(unittest.TestCase): def setUp(self): # enable dygraph mode @@ -1663,6 +1793,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + #@test_with_dygraph_pir def test_save_load_finetune_load(self): model_path = os.path.join( self.temp_dir.name, "test_jit_save_load_finetune_load/model" @@ -1673,12 +1804,14 @@ def test_save_load_finetune_load(self): # Use new namespace with unique_name.guard(): layer_save = LayerSaved(IMAGE_SIZE, IMAGE_SIZE) + layer_save = paddle.jit.to_static(layer_save, full_graph=True) layer_save(inps0) # save paddle.jit.save(layer_save, model_path) # load with unique_name.guard(): layer_load = LayerLoadFinetune(IMAGE_SIZE, IMAGE_SIZE, model_path) + layer_load = paddle.jit.to_static(layer_load, full_graph=True) # train train(layer_load, input_size=IMAGE_SIZE) result_00 = layer_load(inps0) @@ -1692,6 +1825,7 @@ def test_save_load_finetune_load(self): self.assertTrue(float((result_00 - result_10).abs().max()) < 1e-5) self.assertTrue(float((result_01 - result_11).abs().max()) < 1e-5) +''' # NOTE(weixin): When there are multiple test functions in an @@ -1707,6 +1841,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_jit_save_load_static_function(self): @paddle.jit.to_static def fun(inputs): @@ -1733,6 +1868,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_jit_save_load_function_input_spec(self): @paddle.jit.to_static( input_spec=[ @@ -1762,6 +1898,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_jit_save_load_function_function(self): def fun(inputs): return paddle.tanh(inputs) @@ -1793,6 +1930,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_jit_save_load_function(self): class LinearNet(paddle.nn.Layer): def __init__(self): @@ -1832,6 +1970,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_jit_save_load_function(self): class LinearNet(paddle.nn.Layer): def __init__(self): @@ -1872,6 +2011,7 @@ def 
setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_jit_save_load_function(self): class LinearNet(paddle.nn.Layer): def __init__(self): @@ -1922,6 +2062,7 @@ def verify_inference_correctness(self, layer, path): err_msg=f'Result diff when load and inference:\nlayer result:\n{pred}\nloaded layer result:\n{loaded_pred}', ) + @test_with_dygraph_pir def test_jit_save_data_parallel_with_inputspec(self): layer = LinearNetNotDeclarative(784, 1) layer = paddle.DataParallel(layer) @@ -1934,6 +2075,7 @@ def test_jit_save_data_parallel_with_inputspec(self): self.verify_inference_correctness(layer, path) + @test_with_dygraph_pir def test_jit_save_data_parallel_with_to_static(self): layer = LinearNetWithInputSpec(784, 1) layer = paddle.DataParallel(layer) @@ -1947,16 +2089,8 @@ def test_jit_save_data_parallel_with_to_static(self): class InputSepcLayer(paddle.nn.Layer): - ''' - A layer with InputSpec to test InputSpec compatibility - ''' - - @paddle.jit.to_static( - input_spec=[ - InputSpec(shape=[None, 8], dtype='float32', name='x'), - InputSpec(shape=[None, 1], dtype='float64', name='y'), - ] - ) + # A layer with InputSpec to test InputSpec compatibility + def forward(self, x, y): return x, y @@ -1980,11 +2114,18 @@ def _assert_input_spec_layer_return(self, expect_layer, test_layer): expected_result[1].numpy(), test_result[1].numpy() ) - def test_jit_save_compatible_input_sepc(self): + @test_with_dygraph_pir + def test_jit_save_no_input_sepc(self): layer = InputSepcLayer() - save_dir = os.path.join( - self.temp_dir.name, "jit_save_compatible_input_spec" + layer = paddle.jit.to_static( + layer, + input_spec=[ + InputSpec(shape=[None, 8], dtype='float32', name='x'), + InputSpec(shape=[None, 1], dtype='float64', name='y'), + ], + full_graph=True, ) + save_dir = os.path.join(self.temp_dir.name, "jit_save_no_input_spec") path = save_dir + "/model" paddle.jit.save(layer=layer, path=path) @@ -1992,6 +2133,21 @@ def test_jit_save_compatible_input_sepc(self): self._assert_input_spec_layer_return(layer, no_input_spec_layer) shutil.rmtree(save_dir) + @test_with_dygraph_pir + def test_jit_save_same_input_sepc(self): + layer = InputSepcLayer() + layer = paddle.jit.to_static( + layer, + input_spec=[ + InputSpec(shape=[None, 8], dtype='float32', name='x'), + InputSpec(shape=[None, 1], dtype='float64', name='y'), + ], + full_graph=True, + ) + + save_dir = os.path.join(self.temp_dir.name, "jit_save_same_input_spec") + path = save_dir + "/model" + paddle.jit.save( layer=layer, path=path, @@ -2004,6 +2160,22 @@ def test_jit_save_compatible_input_sepc(self): self._assert_input_spec_layer_return(layer, same_input_spec_layer) shutil.rmtree(save_dir) + @test_with_dygraph_pir + def test_jit_save_compatible_input_sepc(self): + layer = InputSepcLayer() + layer = paddle.jit.to_static( + layer, + input_spec=[ + InputSpec(shape=[None, 8], dtype='float32', name='x'), + InputSpec(shape=[None, 1], dtype='float64', name='y'), + ], + full_graph=True, + ) + + save_dir = os.path.join( + self.temp_dir.name, "jit_save_compatible_input_spec" + ) + path = save_dir + "/model" paddle.jit.save( layer=layer, path=path, @@ -2016,8 +2188,17 @@ def test_jit_save_compatible_input_sepc(self): self._assert_input_spec_layer_return(layer, compatible_input_spec_layer) shutil.rmtree(save_dir) + @test_with_dygraph_pir def test_jit_save_incompatible_input_sepc(self): layer = InputSepcLayer() + layer = paddle.jit.to_static( + layer, + input_spec=[ + InputSpec(shape=[None, 8], dtype='float32', name='x'), + 
InputSpec(shape=[None, 1], dtype='float64', name='y'), + ], + full_graph=True, + ) save_dir = os.path.join( self.temp_dir.name, "jit_save_compatible_input_spec" ) @@ -2074,6 +2255,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_with_dygraph_pir def test_jit_not_save_forward(self): layer = NotJitForward() diff --git a/test/legacy_test/test_onnx_export.py b/test/legacy_test/test_onnx_export.py index 26e9ca381757ff..8a91fa9705ca8b 100644 --- a/test/legacy_test/test_onnx_export.py +++ b/test/legacy_test/test_onnx_export.py @@ -60,7 +60,7 @@ def test_prune_graph(self): model = Logic() self.x = paddle.to_tensor(np.array([1])) self.y = paddle.to_tensor(np.array([-1])) - paddle.jit.to_static(model) + paddle.jit.to_static(model, full_graph=True) out = model(self.x, self.y, z=True) paddle.onnx.export( model, diff --git a/test/legacy_test/test_paddle_save_load.py b/test/legacy_test/test_paddle_save_load.py index 8afb3f3699f6e0..34614fb85b6876 100644 --- a/test/legacy_test/test_paddle_save_load.py +++ b/test/legacy_test/test_paddle_save_load.py @@ -25,7 +25,7 @@ from paddle import base, nn from paddle.base import framework from paddle.framework import in_pir_mode -from paddle.framework.io_utils import set_value +from paddle.framework.io_utils import get_value, is_pir_fetch_var, set_value from paddle.optimizer import Adam from paddle.optimizer.lr import LRScheduler from paddle.pir_utils import test_with_pir_api @@ -174,10 +174,7 @@ def set_zero(self, prog, place, scope=None): scope = base.global_scope() for var in prog.list_vars(): if isinstance(var, framework.Parameter) or var.persistable: - if ( - in_pir_mode() - and var.get_defining_op().name() == "pd_op.fetch" - ): + if is_pir_fetch_var(var): continue ten = scope.find_var(var.name).get_tensor() if ten is not None: @@ -236,10 +233,7 @@ def test_replace_static_save_load(self): base_map = {} for var in prog.list_vars(): if isinstance(var, framework.Parameter) or var.persistable: - if ( - in_pir_mode() - and var.get_defining_op().name() == "pd_op.fetch" - ): + if is_pir_fetch_var(var): continue t = np.array( base.global_scope().find_var(var.name).get_tensor() @@ -254,10 +248,7 @@ def test_replace_static_save_load(self): paddle.static.load(prog, path) for var in prog.list_vars(): if isinstance(var, framework.Parameter) or var.persistable: - if ( - in_pir_mode() - and var.get_defining_op().name() == "pd_op.fetch" - ): + if is_pir_fetch_var(var): continue new_t = np.array( base.global_scope().find_var(var.name).get_tensor() @@ -267,14 +258,10 @@ def test_replace_static_save_load(self): # legacy paddle.base.save, paddle.load paddle.static.save(prog, path) self.set_zero(prog, place) - # paddle.static.load(prog, path) self.replace_static_load(prog, path) for var in prog.list_vars(): if isinstance(var, framework.Parameter) or var.persistable: - if ( - in_pir_mode() - and var.get_defining_op().name() == "pd_op.fetch" - ): + if is_pir_fetch_var(var): continue new_t = np.array( base.global_scope().find_var(var.name).get_tensor() @@ -285,19 +272,21 @@ def test_replace_static_save_load(self): path_vars = 'test_replace_save_load_return_tensor_static/model' for var in prog.list_vars(): if var.persistable: - if ( - in_pir_mode() - and var.get_defining_op().name() == "pd_op.fetch" - ): + if is_pir_fetch_var(var): continue tensor = base.global_scope().find_var(var.name).get_tensor() paddle.save( tensor, os.path.join(self.temp_dir.name, path_vars, var.name), ) + + # Pir value currently does not have .set_value() and .get_value() + # Instead, 
use new functions to replace them + with self.assertRaises(TypeError): + get_value(var, 'base.global_scope()') + # Pir get_value() currently does not raise ValueError + # Maybe fix it later if not in_pir_mode(): - with self.assertRaises(TypeError): - var.get_value('base.global_scope()') with self.assertRaises(ValueError): x.get_value() with self.assertRaises(TypeError): @@ -313,10 +302,7 @@ def test_replace_static_save_load(self): self.set_zero(prog, place) for var in prog.list_vars(): if var.persistable: - if ( - in_pir_mode() - and var.get_defining_op().name() == "pd_op.fetch" - ): + if is_pir_fetch_var(var): continue tensor = paddle.load( os.path.join(self.temp_dir.name, path_vars, var.name), @@ -381,6 +367,7 @@ def get_lr(self): print(load_dict_np[k]) np.testing.assert_array_equal(v.numpy(), load_dict_np[k]) + @test_with_pir_api def test_single_pickle_var_dygraph(self): # enable dygraph mode paddle.disable_static() @@ -411,6 +398,7 @@ def test_single_pickle_var_dygraph(self): np.testing.assert_array_equal(tensor.numpy(), np_static) np.testing.assert_array_equal(tensor.numpy(), np.array(lod_static)) + @test_with_pir_api def test_single_pickle_var_static(self): # enable static graph mode paddle.enable_static() @@ -431,7 +419,7 @@ def test_single_pickle_var_static(self): prog = paddle.static.default_main_program() for var in prog.list_vars(): if list(var.shape) == [IMAGE_SIZE, 128]: - tensor = var.get_value() + tensor = get_value(var) break scope = base.global_scope() origin_tensor = np.array(tensor) @@ -444,11 +432,11 @@ def test_single_pickle_var_static(self): lod_static = paddle.load(path) np_static = paddle.load(path, return_numpy=True) # set_tensor(np.ndarray) - var.set_value(np_static, scope) + set_value(var, np_static, scope) np.testing.assert_array_equal(origin_tensor, np.array(tensor)) # set_tensor(LoDTensor) self.set_zero(prog, place, scope) - var.set_value(lod_static, scope) + set_value(var, lod_static, scope) np.testing.assert_array_equal(origin_tensor, np.array(tensor)) # enable dygraph mode paddle.disable_static() @@ -491,6 +479,7 @@ def test_dygraph_save_static_load(self): tensor.numpy(), np.array(state_dict_param[tensor.name]) ) + @test_with_pir_api def test_save_load_complex_object_dygraph_save(self): paddle.disable_static() layer = paddle.nn.Linear(3, 4) @@ -667,6 +656,7 @@ def test_save_load_complex_object_dygraph_save(self): np.testing.assert_array_equal(load_array4[0], obj4[0]) + @test_with_pir_api def test_save_load_complex_object_static_save(self): paddle.enable_static() with new_program_scope(): @@ -686,7 +676,7 @@ def test_save_load_complex_object_static_save(self): exe = paddle.static.Executor(place) exe.run(paddle.static.default_startup_program()) - state_dict = prog.state_dict() + state_dict = prog.state_dict('all', base.global_scope()) keys = list(state_dict.keys()) obj1 = [ state_dict[keys[0]], @@ -1025,7 +1015,9 @@ def check_load_state_dict(self, orig_dict, load_dict): ) np.testing.assert_array_equal(value.numpy(), load_value) + @test_with_pir_api def test_save_load(self): + paddle.disable_static() layer, opt = self.build_and_train_model() # save @@ -1201,6 +1193,44 @@ def test_save_load_program(self): self.assertTrue(origin_startup == load_startup) temp_dir.cleanup() + def test_save_load_program_pir(self): + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + temp_dir = tempfile.TemporaryDirectory() + with new_program_scope(): + layer = LinearNet() + data = paddle.static.data( + name='x_static_save', + shape=(None, IMAGE_SIZE), + dtype='float32', + 
) + y_static = layer(data) + main_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + path1 = os.path.join( + temp_dir.name, + "test_paddle_save_load_program/main_program.pdmodel", + ) + path2 = os.path.join( + temp_dir.name, + "test_paddle_save_load_program/startup_program.pdmodel", + ) + paddle.save(main_program, path1) + paddle.save(startup_program, path2) + + with new_program_scope(): + load_main = paddle.load(path1) + load_startup = paddle.load(path2) + self.assertTrue( + len(main_program.global_block().ops) + == len(load_main.global_block().ops) + ) + self.assertTrue( + len(startup_program.global_block().ops) + == len(load_startup.global_block().ops) + ) + temp_dir.cleanup() + class TestSaveLoadLayer(unittest.TestCase): def test_save_load_layer(self): diff --git a/test/legacy_test/test_stack_op.py b/test/legacy_test/test_stack_op.py index 0b11f9e4236e92..d6f8eb3b695a21 100644 --- a/test/legacy_test/test_stack_op.py +++ b/test/legacy_test/test_stack_op.py @@ -399,7 +399,7 @@ def setUp(self): def test_list_single_tensor(self): expect = paddle.stack(self.x) paddle.base.core._set_prim_all_enabled(True) - st_model = paddle.jit.to_static(paddle.stack) + st_model = paddle.jit.to_static(paddle.stack, full_graph=True) actual = st_model(self.x) np.testing.assert_allclose(expect, actual) paddle.enable_static() diff --git a/test/legacy_test/test_strided_slice_op.py b/test/legacy_test/test_strided_slice_op.py index 316665afc693c7..7e4594c069c7f4 100644 --- a/test/legacy_test/test_strided_slice_op.py +++ b/test/legacy_test/test_strided_slice_op.py @@ -765,7 +765,7 @@ def create_case(self, net): self.is_grads_equal_zeros(grads_zeros) - func = paddle.jit.to_static(net.forward) + func = paddle.jit.to_static(net.forward, full_graph=True) l2 = func(inps2) s2 = l2.numpy() l2.sum().backward() @@ -807,7 +807,7 @@ def forward(self, inps): return array1 + array2 * array2 net = Simple() - func = paddle.jit.to_static(net.forward) + func = paddle.jit.to_static(net.forward, full_graph=True) inps1 = paddle.to_tensor( np.random.randn(2, 10), diff --git a/test/legacy_test/test_tensor_register_hook.py b/test/legacy_test/test_tensor_register_hook.py index c7826c983adcd7..fd53f2033acf5b 100644 --- a/test/legacy_test/test_tensor_register_hook.py +++ b/test/legacy_test/test_tensor_register_hook.py @@ -520,7 +520,8 @@ def test_register_hook_in_static_mode(self): def test_register_hook_in_dy2static_mode(self): net = SimpleNetForStatic(self.in_size, self.out_size) jit_net = paddle.jit.to_static( - net, input_spec=[paddle.static.InputSpec([None, self.in_size])] + net, + input_spec=[paddle.static.InputSpec([None, self.in_size])], ) data = np.random.uniform(size=[self.batch_size, self.in_size]).astype( diff --git a/test/legacy_test/test_translated_layer.py b/test/legacy_test/test_translated_layer.py index 8d8a9d919f3669..bb89aeff4ffea1 100644 --- a/test/legacy_test/test_translated_layer.py +++ b/test/legacy_test/test_translated_layer.py @@ -57,7 +57,8 @@ def __init__(self): paddle.static.InputSpec( shape=[None, IMAGE_SIZE], dtype='float32', name='x' ) - ] + ], + full_graph=True, ) def forward(self, x): return self._linear(x) diff --git a/test/deprecated/legacy_test/test_zero_dim_sundry_dygraph_api.py b/test/legacy_test/test_zero_dim_sundry_dygraph_api.py similarity index 99% rename from test/deprecated/legacy_test/test_zero_dim_sundry_dygraph_api.py rename to test/legacy_test/test_zero_dim_sundry_dygraph_api.py index 00f32fe8744133..ac22d5ae52ff41 100644 --- 
a/test/deprecated/legacy_test/test_zero_dim_sundry_dygraph_api.py +++ b/test/legacy_test/test_zero_dim_sundry_dygraph_api.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/test/deprecated/legacy_test/test_zero_dim_sundry_static_api_part1.py b/test/legacy_test/test_zero_dim_sundry_static_api_part1.py similarity index 88% rename from test/deprecated/legacy_test/test_zero_dim_sundry_static_api_part1.py rename to test/legacy_test/test_zero_dim_sundry_static_api_part1.py index 22386fc5022ed5..aafc4afc32c680 100644 --- a/test/deprecated/legacy_test/test_zero_dim_sundry_static_api_part1.py +++ b/test/legacy_test/test_zero_dim_sundry_static_api_part1.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -146,17 +146,6 @@ def test_create_parameter(self): ) self.assertEqual(zero_dim_param_res.shape, ()) - @prog_scope() - def test_create_global_var(self): - zero_dim_var = paddle.static.create_global_var( - shape=[], value=0.5, dtype='float32' - ) - self.assertEqual(zero_dim_var.shape, ()) - prog = paddle.static.default_startup_program() - res = self.exe.run(prog, fetch_list=[zero_dim_var]) - self.assertEqual(res[0].shape, ()) - self.assertEqual(res[0], 0.5) - @test_with_pir_api @prog_scope() def test_getitem(self): @@ -212,67 +201,6 @@ def test_getitem(self): self.assertEqual(res[1].shape, (1, 4)) np.testing.assert_allclose(res[1], np.ones((1, 4))) - @prog_scope() - def test_setitem(self): - # NOTE(zoooo0820): __setitem__ has gradient problem in static graph. - # To solve this, we may not support __setitem__ in static graph. - # These unit tests will delete soon. - - # case1: all axis have a scalar indice - x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) - x.stop_gradient = False - out = x * 2 - out = paddle.static.setitem(out, (1, 2, 3, 4), 10) - paddle.static.append_backward(out.sum()) - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[out, x.grad_name]) - - self.assertEqual(out.shape, x.shape) - np.testing.assert_allclose(res[0][1, 2, 3, 4], np.array(10)) - self.assertEqual(res[1].shape, (2, 3, 4, 5)) - x_grad_expected = np.ones((2, 3, 4, 5)) * 2 - x_grad_expected[1, 2, 3, 4] = 0 - np.testing.assert_allclose(res[1], x_grad_expected) - - # case2: 0-D Tensor indice in some axis - # NOTE(zoooo0820): Now, int/slice with 0-D Tensor will still be - # treated as combined indexing, which is not support backward. - # There should have more test cases such as out[1, indice, :] = 0.5 when this - # problem is fixed. 
- x = paddle.randn((2, 3, 4, 5)) - x.stop_gradient = False - indice = paddle.full([], 1, dtype='int32') - out = x * 1 - out = paddle.static.setitem(out, (indice, indice), 0.5) - paddle.static.append_backward(out.sum()) - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[out, x.grad_name]) - - self.assertEqual(out.shape, x.shape) - np.testing.assert_allclose(res[0][1, 1], np.ones((4, 5)) * 0.5) - x_grad_expected = np.ones((2, 3, 4, 5)) - x_grad_expected[1, 1] = 0 - np.testing.assert_allclose(res[1], x_grad_expected) - - # case3:0-D Tensor indice in some axis, value is a Tensor - # and there is broadcast - x = paddle.randn((2, 3, 4, 5)) - x.stop_gradient = False - v = paddle.ones((4, 5), dtype='float32') * 5 - v.stop_gradient = False - indice = paddle.full([], 1, dtype='int32') - out = x * 1 - out = paddle.static.setitem(out, indice, v) - paddle.static.append_backward(out.sum()) - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[out, x.grad_name, v.grad_name]) - - self.assertEqual(out.shape, x.shape) - np.testing.assert_allclose(res[0][1], np.ones((3, 4, 5)) * 5) - x_grad_expected = np.ones((2, 3, 4, 5)) - x_grad_expected[1] = 0 - np.testing.assert_allclose(res[1], x_grad_expected) - @test_with_pir_api @prog_scope() def test_expand(self): @@ -650,9 +578,7 @@ def test_as_complex(self): out = paddle.as_complex(x) self.assertShapeEqual( x, - [ - 2, - ], + [2], ) self.assertShapeEqual(out, []) grad_list = paddle.static.append_backward( @@ -858,20 +784,6 @@ def test_static_accuracy(self): self.assertEqual(res[0].shape, ()) - @prog_scope() - def test_static_auc(self): - x = paddle.full(shape=[3, 2], fill_value=0.25) - y = paddle.full(shape=[3], fill_value=1, dtype="int64") - out = paddle.static.auc(input=x, label=y)[0] - - prog = paddle.static.default_main_program() - res = self.exe.run( - prog, - fetch_list=[out], - ) - - self.assertEqual(res[0].shape, ()) - @test_with_pir_api @prog_scope() def test_std(self): diff --git a/test/deprecated/legacy_test/test_zero_dim_sundry_static_api_part3.py b/test/legacy_test/test_zero_dim_sundry_static_api_part3.py similarity index 94% rename from test/deprecated/legacy_test/test_zero_dim_sundry_static_api_part3.py rename to test/legacy_test/test_zero_dim_sundry_static_api_part3.py index 7ae165540b887d..a2370f228f2abc 100644 --- a/test/deprecated/legacy_test/test_zero_dim_sundry_static_api_part3.py +++ b/test/legacy_test/test_zero_dim_sundry_static_api_part3.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -428,33 +428,6 @@ def test_prelu(self): self.assertEqual(res[4].shape, ()) self.assertEqual(res[5].shape, ()) - @prog_scope() - def test_static_nn_prelu(self): - x1 = paddle.full([], 1.0, 'float32') - x1.stop_gradient = False - out1 = paddle.static.nn.prelu(x1, 'all') - grad_list = paddle.static.append_backward( - out1.sum(), parameter_list=[x1, out1] - ) - (_, x1_grad), (_, out1_grad) = grad_list - - prog = paddle.static.default_main_program() - self.exe.run(paddle.static.default_startup_program()) - res = self.exe.run( - prog, - fetch_list=[ - out1, - x1_grad, - out1_grad, - ], - ) - - self.assertEqual(res[0].shape, ()) - self.assertEqual(res[1].shape, ()) - self.assertEqual(res[2].shape, ()) - np.testing.assert_allclose(res[0], np.array(1)) - np.testing.assert_allclose(res[1], np.array(1)) - @test_with_pir_api @prog_scope() def test_while_loop(self): diff --git a/test/mkldnn/CMakeLists.txt b/test/mkldnn/CMakeLists.txt index 4dcf8d7ff2ca47..74b6b107b573ef 100644 --- a/test/mkldnn/CMakeLists.txt +++ b/test/mkldnn/CMakeLists.txt @@ -32,4 +32,4 @@ if(WITH_ONEDNN AND NOT WIN32) endif() # set_tests_properties(test_flags_mkldnn_ops_on_off PROPERTIES TIMEOUT 120) -set_pit_tests_properties() +set_pir_tests_properties() diff --git a/test/prim/process/test_prim_amp.py b/test/prim/process/test_prim_amp.py index 8a632c13a4e07b..65b53735c6b349 100644 --- a/test/prim/process/test_prim_amp.py +++ b/test/prim/process/test_prim_amp.py @@ -57,7 +57,9 @@ def train(self, use_prim): ) if use_prim: - net = paddle.jit.to_static(net, build_strategy=False) + net = paddle.jit.to_static( + net, build_strategy=False, full_graph=True + ) with paddle.amp.auto_cast(level='O1'): out = net(self.x) loss = paddle.mean(out) @@ -82,13 +84,17 @@ def test_amp_O1_infer(self): net = PrimeNet() core._set_prim_all_enabled(False) net.eval() - static_net = paddle.jit.to_static(net, build_strategy=False) + static_net = paddle.jit.to_static( + net, build_strategy=False, full_graph=True + ) res = static_net(self.x) # set prim all enabled core._set_prim_all_enabled(True) net.eval() - static_net = paddle.jit.to_static(net, build_strategy=False) + static_net = paddle.jit.to_static( + net, build_strategy=False, full_graph=True + ) with paddle.amp.auto_cast(level='O1'): res_amp = static_net(self.x) diff --git a/test/quantization/CMakeLists.txt b/test/quantization/CMakeLists.txt index f9fa50f3068061..e18f8c0a38096c 100644 --- a/test/quantization/CMakeLists.txt +++ b/test/quantization/CMakeLists.txt @@ -1,6 +1,4 @@ -if(ON_INFER) - include(../cpp/inference/test.cmake) -endif() +include(../cpp/inference/test.cmake) file( GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" diff --git a/test/sot/test_model_switch_training.py b/test/sot/test_model_switch_training.py index 50f215600da1ef..b887a3213a2a30 100644 --- a/test/sot/test_model_switch_training.py +++ b/test/sot/test_model_switch_training.py @@ -59,7 +59,7 @@ def get_dygraph_out(self, input): def get_static_out(self, input): paddle.seed(self.seed) self.compile_cache.clear() - static_net = paddle.jit.to_static(self.net) + static_net = paddle.jit.to_static(self.net, full_graph=False) static_net.eval() eval_result = static_net(input) self.check_mode(is_train=False) diff --git a/test/sot/test_segment_linear.py b/test/sot/test_segment_linear.py index ca58be5b5b3bb5..24fb9e6b5221ec 100644 --- a/test/sot/test_segment_linear.py +++ b/test/sot/test_segment_linear.py @@ -62,7 +62,7 @@ def test_simple(self): x = paddle.randn((1, 8, 8)) net = SimpleNet() net = paddle.jit.to_static( - net + net, 
full_graph=False ) # dont make effect. we need fetch sot PR in paddle. loss = net(x) loss = loss.sum() diff --git a/test/sot/test_sot_cost_model.py b/test/sot/test_sot_cost_model.py index a3acec5942005e..eed690a1e77815 100644 --- a/test/sot/test_sot_cost_model.py +++ b/test/sot/test_sot_cost_model.py @@ -102,7 +102,7 @@ def test_sot_fast_with_single_graph(self): def test_net(self): x = paddle.rand([10]) net = Net() - net = paddle.jit.to_static(net, enable_fallback=True) + net = paddle.jit.to_static(net, full_graph=False) for i in range(30): x = net(x) diff --git a/third_party/flashattn b/third_party/flashattn index d98d8a36cc9b88..22b604199d911d 160000 --- a/third_party/flashattn +++ b/third_party/flashattn @@ -1 +1 @@ -Subproject commit d98d8a36cc9b884a1f405d187a0c41caeb5144c6 +Subproject commit 22b604199d911d4e155fe9e54124148c7a290263