Skip to content

Commit

Permalink
[SOT][dynamic shape] Adapt some InferMeta for dynamic shape (#65517)
Browse files Browse the repository at this point in the history

Co-authored-by: SigureMo <[email protected]>
Co-authored-by: Winters Montagne <[email protected]>
  • Loading branch information
3 people authored Jul 2, 2024
1 parent 888f213 commit f6ae216
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 34 deletions.
56 changes: 28 additions & 28 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ void BCELossInferMeta(const MetaTensor& input,

bool check = true;
if ((!config.is_runtime) &&
(common::product(input_dims) <= 0 || common::product(label_dims) <= 0)) {
(contain_unknown_dim(input_dims) || contain_unknown_dim(label_dims))) {
check = false;
}

Expand Down Expand Up @@ -644,34 +644,34 @@ void ConvInferMeta(const MetaTensor& input,
? filter_dims[filter_dims.size() - 1]
: filter_dims[1];

PADDLE_ENFORCE_EQ(
input_channels,
filter_channels * groups,
phi::errors::InvalidArgument(
"The number of input's channels should be equal to filter's channels "
"* groups for Op(Conv). But received: the input's channels is %d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d, the data_format is %s. "
"The error may come from wrong data_format setting.",
input_channels,
in_dims,
filter_channels,
filter_dims,
groups,
data_format));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups,
0,
phi::errors::InvalidArgument(
"The number of output's channels (filter's first dimension) of "
"Op(Conv) should be divided by groups. But received: "
"the output channels is %d, the filter's shape is [%s], "
"the groups is %d.",
filter_dims[0],
filter_dims,
groups));

if (config.is_runtime) {
PADDLE_ENFORCE_EQ(
input_channels,
filter_channels * groups,
phi::errors::InvalidArgument(
"The number of input's channels should be equal to filter's "
"channels "
"* groups for Op(Conv). But received: the input's channels is %d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d, the data_format is %s. "
"The error may come from wrong data_format setting.",
input_channels,
in_dims,
filter_channels,
filter_dims,
groups,
data_format));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups,
0,
phi::errors::InvalidArgument(
"The number of output's channels (filter's first dimension) of "
"Op(Conv) should be divided by groups. But received: "
"the output channels is %d, the filter's shape is [%s], "
"the groups is %d.",
filter_dims[0],
filter_dims,
groups));
PADDLE_ENFORCE_GT(
filter_dims[0],
0,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2492,7 +2492,7 @@ void BNActXPUInferMeta(const MetaTensor& x,

bool check = true;
if ((!config.is_runtime) &&
(common::product(scale_dim) <= 0 || common::product(bias_dim) <= 0)) {
(contain_unknown_dim(scale_dim) || contain_unknown_dim(bias_dim))) {
check = false;
}

Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -912,9 +912,10 @@ void BatchNormInferMeta(const MetaTensor& x,
}

bool check = true;

if (!scale || !bias ||
((!config.is_runtime) && (common::product(scale.dims()) <= 0 ||
common::product(bias.dims()) <= 0))) {
((!config.is_runtime) && (contain_unknown_dim(scale.dims()) ||
contain_unknown_dim(bias.dims()) || C == -1))) {
check = false;
}

Expand Down Expand Up @@ -4947,7 +4948,7 @@ void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,

bool check = true;
if ((!config.is_runtime) &&
(common::product(x_dims) <= 0 || common::product(labels_dims) <= 0)) {
(contain_unknown_dim(x_dims) || contain_unknown_dim(labels_dims))) {
check = false;
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ void InstanceNormInferMeta(const MetaTensor& x,
"of scale is [%d]",
scale_dim,
scale_dim.size()));
bool check = !((!config.is_runtime) && (common::product(scale_dim) <= 0));
bool check = config.is_runtime || !contain_unknown_dim(scale_dim);
if (check) {
PADDLE_ENFORCE_EQ(scale_dim[0],
C,
Expand All @@ -615,7 +615,7 @@ void InstanceNormInferMeta(const MetaTensor& x,
"of bias is [%d]",
bias_dim,
bias_dim.size()));
bool check = !((!config.is_runtime) && (common::product(bias_dim) <= 0));
bool check = config.is_runtime || !contain_unknown_dim(bias_dim);
if (check) {
PADDLE_ENFORCE_EQ(bias_dim[0],
C,
Expand Down
74 changes: 74 additions & 0 deletions test/dygraph_to_static/test_dynamic_shape_infermeta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from typing import Any, Callable, Sequence

import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_pir_only,
)

import paddle
from paddle.static.input import InputSpec


class TestDynamicShapeInfermeta(Dy2StTestBase):
    # Shared driver: trace ``fn`` into a static graph using dynamic-shape
    # input specs, then check numerical parity against eager-mode execution.
    def check_dynamic_shape(
        self,
        fn: Callable[..., Any],
        inputs: Sequence[paddle.Tensor],
        input_specs: list[InputSpec],
    ):
        converted = paddle.jit.to_static(
            fn,
            input_spec=input_specs,
            full_graph=True,
        )
        # Run the static version first (mirrors original evaluation order,
        # which matters for layers carrying running statistics).
        static_out = converted(*inputs)
        eager_out = fn(*inputs)
        np.testing.assert_allclose(static_out, eager_out, rtol=1e-05)

    @test_pir_only
    @test_ast_only
    def test_conv2d(self):
        layer = paddle.nn.Conv2D(3, 3, 3)
        spec = InputSpec(shape=[None, None, None, None], dtype='float32')
        self.check_dynamic_shape(
            layer, [paddle.randn([1, 3, 32, 32])], [spec]
        )

    @test_pir_only
    @test_ast_only
    def test_bn(self):
        layer = paddle.nn.BatchNorm2D(3)
        spec = InputSpec(shape=[None, None, None, None], dtype='float32')
        self.check_dynamic_shape(
            layer, [paddle.randn([1, 3, 32, 32])], [spec]
        )

    @test_pir_only
    @test_ast_only
    def test_depthwise_conv2d(self):
        layer = paddle.nn.Conv2D(3, 3, 3, groups=3)
        spec = InputSpec(shape=[None, None, None, None], dtype='float32')
        self.check_dynamic_shape(
            layer, [paddle.randn([1, 3, 32, 32])], [spec]
        )


# Allow running this test file directly via the standard unittest runner.
if __name__ == '__main__':
    unittest.main()

0 comments on commit f6ae216

Please sign in to comment.