Skip to content

Commit

Permalink
Merge pull request apache#47 from wjj19950828/paddle_frontend
Browse files Browse the repository at this point in the history
Fixed pool2d bug
  • Loading branch information
jiangjiajun authored Sep 23, 2021
2 parents bb76e0a + 4bdbb38 commit 6b45a40
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
8 changes: 7 additions & 1 deletion python/tvm/relay/frontend/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,7 @@ def convert_pool2d(g, op, block):
ksize = [1, 1]

input_x = g.get_node(op.input("X")[0])
input_shape = infer_shape(input_x)

op_map = {
"avg": "avg_pool2d",
Expand All @@ -1372,7 +1373,7 @@ def convert_pool2d(g, op, block):
if padding_algorithm == "VALID":
paddings = [0, 0]
elif padding_algorithm == "SAME":
in_h, in_w = infer_shape(input_x)[2:]
in_h, in_w = input_shape[2:]
pad_h = _get_pad_size(in_h, ksize[0], strides[0])
pad_w = _get_pad_size(in_w, ksize[1], strides[1])
paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]]
Expand All @@ -1385,6 +1386,11 @@ def convert_pool2d(g, op, block):
msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."'
raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))

if input_shape[2] < ksize[0]:
ksize[0] = input_shape[2]
if input_shape[3] < ksize[1]:
ksize[1] = input_shape[3]

if not adaptive:
out = getattr(_op.nn, op_map[pooling_type])(
input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode
Expand Down
11 changes: 10 additions & 1 deletion tests/python/frontend/paddlepaddle/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,13 @@ def pool2d2(inputs):

@paddle.jit.to_static
def pool2d3(inputs):
output = nn.functional.max_pool2d(
inputs, kernel_size=2, stride=2, padding=0
)
return output

@paddle.jit.to_static
def pool2d4(inputs):
output, max_indices = nn.functional.max_pool2d(
inputs, kernel_size=2, stride=2, padding=0, return_mask=True
)
Expand All @@ -1466,8 +1473,10 @@ def pool2d3(inputs):
input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1)
verify_model(pool2d1, input_data, input_shape=[[-1, 2, 32, 32]])
verify_model(pool2d2, input_data=input_data)
input_data1 = paddle.uniform(shape=[1, 2, 1, 50], dtype="float32", min=-1, max=1)
verify_model(pool2d3, input_data=input_data1)
# need op max_pool2d_with_index
verify_model(pool2d3, input_data=input_data)
verify_model(pool2d4, input_data=input_data)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 6b45a40

Please sign in to comment.