[Auto Parallel] Add conv2d and pool flops #48084

Merged
merged 10 commits, Jan 5, 2023
26 changes: 26 additions & 0 deletions python/paddle/fluid/tests/unittests/test_profiler.py
@@ -350,6 +350,32 @@ def test_flops(self):
)
== 144
)
self.assertTrue(
flops(
'pool',
{'X': [[12, 12]]},
{},
)
== 12 * 12
)
self.assertTrue(
flops(
'conv2d',
{
'Bias': [],
'Filter': [[3, 3, 2, 2]],
'Input': [[8, 3, 4, 4]],
'ResidualData': [],
},
{
'dilations': [1, 1],
'groups': 1,
'paddings': [1, 1],
'strides': [1, 1],
},
)
== 14400
)


if __name__ == '__main__':
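For reference, the expected value in the conv2d assertion can be reproduced by hand from the formula implemented in flops.py below; a minimal sketch of the arithmetic, using only the shapes and attrs from the test:

# Input [8, 3, 4, 4], Filter [3, 3, 2, 2]; padding 1, stride 1, dilation 1, groups 1.
out_dim = (4 + 2 * 1 - (1 * (2 - 1) + 1)) // 1 + 1  # = 5 per spatial dim
macs_per_position = (2 * 2) * 3 * (3 // 1)  # prod(kernel) * C_in * (C_out / groups) = 36
active_elements = 8 * (5 * 5)  # batch_size * numel(output) = 200
flops = 2 * macs_per_position * active_elements  # = 14400; Bias is empty, so no bias term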
91 changes: 90 additions & 1 deletion python/paddle/utils/flops.py
@@ -69,6 +69,85 @@ def _c_embedding_flops(input_shapes, attrs):
return 0


@register_flops("conv2d")
def _conv2d_flops(input_shapes, attrs):
"""FLOPs computation for conv2d op.
For conv2d(input,filter):
active_elements = batch_size * numel(output)
conv_flops = 2 * macs_per_position_conv * active_elements
bias_flops = out_channels * active_elements
equation: flops = conv_flops + bias_flops
"""

bias = (
input_shapes.get('Bias')[0]
if len(input_shapes.get('Bias')) > 0
else None
)
input = input_shapes.get('Input')[0]
weight = input_shapes.get('Filter')[0]

padding = attrs.get('paddings')
stride = attrs.get('strides')
dilation = attrs.get('dilations')
groups = attrs.get('groups')

batch_size = input[0]
in_channels = input[1]
out_channels = weight[0]
kernel_dims = list(weight[2:])
input_dims = list(input[2:])
length = len(input_dims)

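    # Broadcast scalar padding/stride/dilation attrs to one value per spatial dim.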
    paddings = padding if isinstance(padding, list) else [padding] * length
    strides = stride if isinstance(stride, list) else [stride] * length
    dilations = dilation if isinstance(dilation, list) else [dilation] * length

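    # Standard conv output size per spatial dim: (in + 2*pad - (dilation*(k-1) + 1)) // stride + 1.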
output_dims = []
for idx, input_dim in enumerate(input_dims):
output_dim = (
input_dim
+ 2 * paddings[idx]
- (dilations[idx] * (kernel_dims[idx] - 1) + 1)
) // strides[idx] + 1
output_dims.append(output_dim)
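    # Multiply-accumulates per output position: prod(kernel) * in_channels * (out_channels // groups).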
filters_per_channel = out_channels // groups
macs_conv_per_position = (
prod(kernel_dims) * in_channels * filters_per_channel
)
active_elements = batch_size * prod(output_dims)
overall_conv_macs = macs_conv_per_position * active_elements
overall_conv_flops = 2 * overall_conv_macs

overall_bias_flops = 0

if bias is not None:
overall_bias_flops = out_channels * active_elements

return overall_conv_flops + overall_bias_flops


@register_flops("dropout")
def _dropout_flops(input_shapes, attrs):
"""FLOPs computation for dropout op.
@@ -195,7 +274,7 @@ def _matmul_v2_flops(input_shapes, attrs):
shape_of_other = [odim1, odim2 ... odim(n-m) ... odim_m_1, dim_m] length:m
suppose n > m and dim_n = odim_m_1:
shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1))...dim_n_1, dim_m]
-   equation: flops = 2 * numel(output) * dim_n
+   equation: flops = 2 * numel(outputs) * dim_n
"""
x_shape = input_shapes.get('X')[0]
y_shape = input_shapes.get('Y')[0]
@@ -281,3 +360,13 @@ def _transpose2_flops(input_shapes, attrs):
equation: flops = 0
"""
return 0


@register_flops("pool")
def _pool_flops(input_shapes, attrs):
"""FLOPs computation for pool op.
For pool(input):
equation: flops = (numel)total number of elements in the input tensor.
"""
input = input_shapes.get('X')[0]
return prod(input)
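For an end-to-end check, both handlers can be driven through the flops() entry point in this module, which is what the unit test above calls; a minimal sketch (the scalar attrs here are an assumption chosen to exercise the broadcast branch in _conv2d_flops, and should give the same count as the list form):

from paddle.utils.flops import flops

# Pool FLOPs are simply the element count of the input tensor.
assert flops('pool', {'X': [[12, 12]]}, {}) == 144

# Same conv2d case as the unit test, but with scalar attrs: the handler
# broadcasts padding/stride/dilation per spatial dim, so the result is
# unchanged (2 * 36 MACs per position * 200 output elements = 14400).
shapes = {
    'Bias': [],
    'Filter': [[3, 3, 2, 2]],
    'Input': [[8, 3, 4, 4]],
    'ResidualData': [],
}
attrs = {'dilations': 1, 'groups': 1, 'paddings': 1, 'strides': 1}
assert flops('conv2d', shapes, attrs) == 14400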