-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Dy2St][PIR] Hold backward program in GradNode #63694
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
std::shared_ptr<::pir::Program> program = | ||
reinterpret_cast<std::shared_ptr<::pir::Program>&>(vh[0]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
能拿出来吗?
reinterpret_cast<std::shared_ptr<::pir::Program>&>(vh[0]); | ||
// TODO(gouzil): 试一下pybind11能不能使用智能指针作为参数 | ||
// pir::IrMapping mapper; | ||
attrs[key] = program.get(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么要转回裸指针?
void** vh = inst->simple_layout ? inst->simple_value_holder | ||
: &inst->nonsimple.values_and_holders[0]; | ||
|
||
::pybind11::handle(obj).inc_ref(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在这 inc,在哪 dec?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
其余 LGTM,完美
|
||
// Clear out and middles to avoid hold memory until backward finish. | ||
out.clear(); | ||
middles.clear(); | ||
VLOG(1) << "out and middles clear end"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个忘清了?
paddle/fluid/framework/type_defs.h
Outdated
@@ -40,6 +41,7 @@ class InferShapeContext; | |||
class InferVarTypeContext; | |||
class VarDesc; | |||
class BlockDesc; | |||
class ProgramDesc; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个有必要嘛?下面只添加了 pir::Program
,为啥这里要前向声明老 IR ProgramDesc
?
@@ -998,6 +1010,7 @@ void ConstructAttrMapForRunProgram( | |||
attr_end)); | |||
|
|||
PyObject* obj = nullptr; | |||
attrs["testkey"] = std::string("testvalue"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
忘清了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5378583
to
4bcbcc6
Compare
…IRProgram # Conflicts: # python/paddle/jit/pir_translated_layer.py
@@ -23,7 +23,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER)) | |||
cc_library(init_env_utils SRCS init_env_utils.cc) | |||
target_compile_definitions(init_env_utils PUBLIC PADDLE_DLL_EXPORT) | |||
|
|||
paddle_test(test_comp_eager SRCS test_eager_prim.cc DEPS init_env_utils) | |||
paddle_test(test_comp_eager SRCS test_eager_prim.cc init_env_utils.cc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#56691 看样子是为了减小单测体积才这么搞的,这样改是不是又变回去了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,但是它们之间目前存在重复依赖的问题,会导致windows LNK2005错误,所以就先直接不拆分了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
重复依赖是指?如果这样的话,24 行是不是不需要了?还是说后续有计划优化这里呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
重复依赖可以看13c89cf
,以及它的PR-CI-Windows-OPENBLAS,大概就是它俩都依赖了phi
导致重复依赖了。也参考了一些:paddle_test 的文档还是没能拆出来。最好的情况当然是拆分它,缩小体积,后续优化吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
最好的情况当然是拆分它,缩小体积,后续优化吧。
ok
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -23,7 +23,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER)) | |||
cc_library(init_env_utils SRCS init_env_utils.cc) | |||
target_compile_definitions(init_env_utils PUBLIC PADDLE_DLL_EXPORT) | |||
|
|||
paddle_test(test_comp_eager SRCS test_eager_prim.cc DEPS init_env_utils) | |||
paddle_test(test_comp_eager SRCS test_eager_prim.cc init_env_utils.cc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
最好的情况当然是拆分它,缩小体积,后续优化吧。
ok
Co-authored-by: xiongkun <[email protected]> Co-authored-by: Nyakku Shigure <[email protected]>
Co-authored-by: xiongkun <[email protected]> Co-authored-by: Nyakku Shigure <[email protected]>
Co-authored-by: xiongkun <[email protected]> Co-authored-by: Nyakku Shigure <[email protected]>
add int4_1 int4_2 FLAGS_logging_pir_py_code (PaddlePaddle#63981) * FLAGS_logging_pir_py_code * FLAGS_logging_pir_py_code_dir --------- Co-authored-by: jiahy0825 <[email protected]> [Cleanup] Remove Flake8 config in `.editorconfig` (PaddlePaddle#64027) 【PIR Dist Op Reg No.19】 reg pull_box_sparse (PaddlePaddle#62982) * fix * fix * fix * fix * fix * fix * add test * add * fix * fix * add out * fix * codestyle * fix * fix backward * merge [Dy2St][PIR] Hold backward program in GradNode (PaddlePaddle#63694) Co-authored-by: xiongkun <[email protected]> Co-authored-by: Nyakku Shigure <[email protected]> split test.cmake: add new test_cases.cmake (PaddlePaddle#64007) [PIR] Support sparse_slice and sparse_sum in pt (PaddlePaddle#64009) * support sparse_slice and sparse_sum in pt * support sparse_slice and sparse_sum in pt * support sparse_slice and sparse_sum in pt option for WITH_CPP_TEST (PaddlePaddle#63896) * option for WITH_CPP_TEST * fix * Fix * Fix [PIR] Fix `attributes_num` of `SliceArrayOp` (PaddlePaddle#64013) [Dy2St] Use `full_graph=True` outside dy2st uts (part1) (PaddlePaddle#64058) [Dy2St] Use `full_graph=True` outside dy2st uts (part2) (PaddlePaddle#64059) fix typo (PaddlePaddle#64060) Co-authored-by: jiahy0825 <[email protected]> update (PaddlePaddle#64042) Replace paddle/fluid/platform/device/gpu/gpu_dnn.h (PaddlePaddle#63819) * Fix * Fix * Fix Clean lookup_table_v2_op.h lookup_table_v2_op.cu (PaddlePaddle#64020) * Fix * ci refine GetTensorListFromArgs (PaddlePaddle#64045) Revert "【Hackathon 6th Fundable Projects 3 No.60】Remove fluid operator chunk_…" (PaddlePaddle#64050) This reverts commit 88b1a6e. [Prim][PIR] support floor_divide op forward in prim pir (PaddlePaddle#64023) * floor-div-dev * update test [CINN] Reconstruct shape_analysis (PaddlePaddle#63790) * reconstruct shape_analysis * fix input value shape infer * fix merge bugs * fix concat and gather op InferSymbolicShape * fix merge bug * fix value_to_shape_or_data hash error and add some checks * fix set shape for null value * fix group op lazy infer * add IsStaticShape check * fix merge bug * support static dim check and set for VectorType * change auto to detail type [XPU] fix bugs in processing of attention_mask and fix_seed_offset on XPU (PaddlePaddle#64003) * [XPU] fix segmentfault caused by setting fix_seed_offset on XPU * cast attention_mask to float32 when necessary fix merge bug (PaddlePaddle#64069) 【Fix PIR Unittest No.125、147、481】Fix some 0D uts in PIR mode (part1) (PaddlePaddle#64064) [Prim][VJP]support autogen to remove unused composite in .yaml (PaddlePaddle#64054) * support autogen to remove unused composite in .yaml * fix bug [PIR] Fix typo `set_pit_tests_properties` -> `set_pir_tests_properties` (PaddlePaddle#64063) [Dy2St] Use `full_graph=True` outside dy2st uts (part3) (PaddlePaddle#64066) [PIR save/load] Open more tests for paddle.save and paddle.load (PaddlePaddle#64044) * open more tests for paddle.save and paddle.load * fix API Improvement for paddle.nn.functional.group_norm and paddle.nn.GroupNorm (PaddlePaddle#63881) * update group_norm * update trt plugin * update trt plugin * fix trt plugin * fix trt plugin * fix test * fix test * fix ci windows inference * update kernel function names and add v2 test * fix * fix fp16 test Revert "【Hackathon 6th Fundable Projects 3 No.81】Remove fluid operators ctc_a…" (PaddlePaddle#64049) This reverts commit 2134ead. Clean paddle/fluid/operators/fused/attention_layer_norm.h (PaddlePaddle#64051) * Fix * Fix Replace operators::math to phi::math in fluid/operators (PaddlePaddle#63854) [CINN]Clean usless loop_reorder_aligment tactic (PaddlePaddle#63998) * [CINN]Clean usless loop_reorder_aligment tactic * fix source 【Hackathon 6th Fundable Projects 3 No.396】fluid operator yolo_box_head (PaddlePaddle#63783) * Fix * Fix * Fix * Fix * Fix 【Hackathon 6th Fundable Projects 3 No.240】fluid operator moe (PaddlePaddle#63929) 【Hackathon 6th Fundable Projects 3 No.82】fluid operator cudnn_lstm (PaddlePaddle#63936) * Fix * Fix * Fix * Fix [CINN] Remove useless log (PaddlePaddle#64052) [pir_save_load] add pir for test_jit_save_load.py (PaddlePaddle#63958) * add jit load.train * modify backward program lost * modify * combine eval and train * modify 8 case of jit.save.load * modify jit_save_load case * rename jit_save_load * change name all * modify timeout * modify new case * modify TestJitSaveLoadMultiMethods * modify cpu tensor no holder bug Flashattention support qkvpacked and varlen (PaddlePaddle#63289) * Flashattention support qkvpacked and varlen * fix codestyle * fix codestyle * FlashAttention kvReduceGQA Performance Optimization * Fix problem with windows * code clean * update third_party/flashattn * update errormsg and docs * update api * update doc * update doctest * update doc, test=document_fix * update doc, test=document_fix * Update python/paddle/nn/functional/flash_attention.py Co-authored-by: zachary sun <[email protected]> * Update python/paddle/nn/functional/flash_attention.py Co-authored-by: zachary sun <[email protected]> * update doc --------- Co-authored-by: zachary sun <[email protected]> 【PIR Dist Op Reg No.20】 reg global_gather (PaddlePaddle#63867) * reg global_gather * reg global_gather * reg_global_gather * fix * fix * fix * fix conflict * fix conflict * Update ops_api_gen.py * Update ops_api_gen.py Fix backward program kwargs error when process inplace value (PaddlePaddle#63939) 【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True fix (PaddlePaddle#63880) * support kwargs for recompute when use_reentrant == True * recover third party merge main lint delete printf change flash attn version
Co-authored-by: xiongkun <[email protected]> Co-authored-by: Nyakku Shigure <[email protected]>
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
将
forward_program
和backward_program
存入 attrs 保证反向 program 的 shared_ptr 能够 +1 不会在跑完正向后被释放NOTE: 需要使用 pybind11 的 2.12 版本
相关链接: