Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle Inference ]use python to generate cutlass code #50603

Merged
merged 31 commits into from
Mar 13, 2023

Conversation

zhoutianzi666
Copy link
Contributor

@zhoutianzi666 zhoutianzi666 commented Feb 17, 2023

PR types

Others

PR changes

Others

Describe

  • 用python脚本生成CUTLASS conv代码

@paddle-bot
Copy link

paddle-bot bot commented Feb 17, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added contributor External developers status: proposed labels Feb 17, 2023
@MARD1NO
Copy link
Contributor

MARD1NO commented Feb 24, 2023

我觉得有个小问题,虽然kernel是生成的,但我觉得你这个PR应该包含生成的cu文件。

假设我开发另外一个功能,我重新cmake,cmake走到你生成kernel的逻辑,会生成 conv2d_bias_act.cu,那此时我git commit还需要避免把这个cu文件提交上去。

@zhoutianzi666
Copy link
Contributor Author

zhoutianzi666 commented Feb 28, 2023

我觉得有个小问题,虽然kernel是生成的,但我觉得你这个PR应该包含生成的cu文件。

假设我开发另外一个功能,我重新cmake,cmake走到你生成kernel的逻辑,会生成 conv2d_bias_act.cu,那此时我git commit还需要避免把这个cu文件提交上去。

感谢review!已经添加到了.gitignore了!

Copy link
Contributor

@zhangjun zhangjun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该PR完成功能:将之前合入代码进行了模版化生成替换

Comment on lines 70 to 96
class CbrAct(enum.Enum):
Identity = 1
Relu = 2
Silu = 3


ActCutlassTag = {
CbrAct.Identity: 'cutlass::epilogue::thread::Identity',
CbrAct.Silu: 'cutlass::epilogue::thread::SiLu',
CbrAct.Relu: 'cutlass::epilogue::thread::ReLu',
}

# some global variables used, now we only support these residual blocks
EpiResBlocks = [
(CbrAct.Silu, "cutlass::plus", CbrAct.Identity),
(CbrAct.Identity, "cutlass::plus", CbrAct.Relu),
]

UnderScoreName = {
EpiResBlocks[0]: "conv2d_bias_silu_add",
EpiResBlocks[1]: "conv2d_bias_add_relu",
}

CamelName = {
EpiResBlocks[0]: "Conv2dBiasSiluAdd",
EpiResBlocks[1]: "Conv2dBiasAddRelu",
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里枚举定义,是不是能单独拿出来共用

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里枚举定义,是不是能单独拿出来共用

这里是为了在生成函数代码时候,用来生成函数名字的。不同的后处理对应不同的函数名字。

@zhangjun
Copy link
Contributor

zhangjun commented Mar 3, 2023

.gitignore 加上

paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.cu
paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.cu

Comment on lines 112 to 114
execute_process(
COMMAND ${sh_cmd} ${sh_arg0}
WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成add_custom_target 形式

add_custom_target(
eager_python_c_codegen
COMMAND
"${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py"
"--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path}"
"--output_path=${tmp_python_c_output_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_output_path}
${python_c_output_path}
VERBATIM)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成add_custom_target 形式

add_custom_target(
eager_python_c_codegen
COMMAND
"${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py"
"--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path}"
"--output_path=${tmp_python_c_output_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_output_path}
${python_c_output_path}
VERBATIM)

由于kernel_declare的函数原因,必须在cmake时候产生文件,所以目前只能用execute_process(来生成文件

@zhoutianzi666
Copy link
Contributor Author

.gitignore 加上

paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.cu
paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.cu

done!

.gitignore Outdated
@@ -96,4 +96,6 @@ paddle/fluid/prim/api/generated/prim_api/*
paddle/fluid/framework/__init__.py
paddle/phi/api/profiler/__init__.py
python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py
paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.cu
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

生成的文件加入.gitignore


# this is used for leaky_relu, this activation need a fuse_alpha parameter

cba_kernel_alpha = cba_kernel_no_alpha.replace(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有些激活函数。例如leaky_relu需要参数alpha,因此定义了cba_kernel_alpha供使用!

"epi_part": "${epi_func}< ${element_c}, ${epilogue_vector_length}, ${element_accum}, ${element_epilogue}>",
}

cba_kernel_no_alpha = (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分是传递一些参数给conv kernel使用的代码

@@ -62,12 +62,14 @@ __global__ void naive_conv2d_kernel(const half *input,
int dilation_w,
int oh,
int ow,
int groups,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个baseline函数支持了group conv,为以后的cutlass group conv和depthwise conv做支持!

Copy link
Contributor

@qingqing01 qingqing01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1, 有单元测试不?
2, 后续需要文档, 每增加一个新的模板生成代码,需要修改那些文件。 看起来C++、Python都需要修改

).replace(
"typename ImplicitG", "float alpha = params.alpha; typename ImplicitG"
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need more comments in this file, in orde to easy to maintain and update for others

@zhoutianzi666
Copy link
Contributor Author

1, 有单元测试不?
2, 后续需要文档, 每增加一个新的模板生成代码,需要修改那些文件。 看起来C++、Python都需d

1, 有单元测试不?
2, 后续需要文档, 每增加一个新的模板生成代码,需要修改那些文件。 看起来C++、Python都需要修改

单元测试在python/paddle/fluid/tests/unittests/ir/inference/test_cutlass_conv2d_fusion_op.py中。
几乎不需要改动C++文件,C++的代码仅仅是作为baseline来验证cutlass 各种kernel的正确性。

qingqing01
qingqing01 previously approved these changes Mar 9, 2023
Copy link
Contributor

@qingqing01 qingqing01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续此类代码,需要更详细注释,写清楚使用限制等。

@zhoutianzi666
Copy link
Contributor Author

后续此类代码,需要更详细注释,写清楚使用限制等。

ok

Copy link
Contributor

@zhangjun zhangjun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

后续TODO:
group>1 情况;
padding_algorithm非EXPLICIT支持;

@zhangjun zhangjun merged commit 4e9e23c into PaddlePaddle:develop Mar 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants