
Fix stride default value None in torch.nn.functional.avg_pool #4984

Merged
merged 5 commits into from
Mar 7, 2020
11 changes: 7 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
@@ -458,7 +458,10 @@ def _impl(inputs, input_types):
         data = inputs[0]

         pool_size = _infer_shape(inputs[1])
-        strides = _infer_shape(inputs[2])
+        if inputs[2]:
+            strides = _infer_shape(inputs[2])
+        else:
+            strides = pool_size
Member commented:
is this correct? we should use the default strides

Contributor Author replied:

https://github.com/pytorch/pytorch/blob/17a5c677963dc3ecb7ff505585ed15eadaaf74ef/torch/nn/functional.py#L269-L270
According to the linked documentation for avg_pool2d, the default value of stride is kernel_size, which corresponds to pool_size here.

Member replied:

ok

         padding = _infer_shape(inputs[3])

         ceil_mode = int(inputs[4])
@@ -918,7 +921,7 @@ def _get_constant(node):

 def _get_operator_nodes(nodes):
     """ Returns torch IR nodes that need conversion to Relay """
-    ops = {}
+    ops = []
masahi (Member) commented on Mar 4, 2020:

You should not have this change in this PR. Fetch and rebase against master.

     # Traverse nodes and add to graph
     for node in nodes:
         if node.outputsSize() > 1:
@@ -927,7 +930,7 @@ def _get_operator_nodes(nodes):
         node_name = _get_output_name(node)

         if node.kind() != "prim::GetAttr":
-            ops[node_name] = node
+            ops.append((node_name, node))

     return ops
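The list-of-pairs form introduced here (as opposed to the dict it replaces) preserves traversal order and tolerates duplicate output names. A minimal self-contained sketch, using a hypothetical stand-in class (`FakeNode` is not part of the Torch IR) for the node objects:

```python
class FakeNode:
    """Hypothetical stand-in for a Torch IR node with a kind() method."""
    def __init__(self, kind):
        self._kind = kind

    def kind(self):
        return self._kind


def get_operator_nodes(named_nodes):
    # Collect (name, node) pairs in traversal order, skipping
    # prim::GetAttr nodes, as in _get_operator_nodes above.
    ops = []
    for name, node in named_nodes:
        if node.kind() != "prim::GetAttr":
            ops.append((name, node))
    return ops
```

Downstream code then iterates `for node_name, op_node in ops:` instead of `ops.items()`, which is the corresponding change in parse_operators below.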

@@ -1015,7 +1018,7 @@ def parse_params(graph, state_dict):

 def parse_operators(operators, outputs, output_index_map, ret_name):
     """ Convert each Torch IR operators to Relay equivalent """
-    for node_name, op_node in operators.items():
+    for node_name, op_node in operators:
         operator = op_node.kind()
         inputs = _get_op_inputs(op_node, outputs, output_index_map)
