Skip to content

Commit

Permalink
fix: torch frontend max pooling to support optional batch dim (ivy-ll…
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong authored Mar 6, 2024
1 parent 514fedc commit c1f98b1
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions ivy/functional/frontends/torch/nn/functional/pooling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,13 @@ def max_pool1d(
stride = kernel_size
if not isinstance(padding, int):
padding = [(pad, pad) for pad in padding]
return ivy.max_pool1d(
if input.ndim == 2:
without_batch_dim = True
input = ivy.expand_dims(input, axis=0)
else:
without_batch_dim = False

ret = ivy.max_pool1d(
input,
kernel_size,
stride,
Expand All @@ -246,6 +252,9 @@ def max_pool1d(
dilation=dilation,
ceil_mode=ceil_mode,
)
if without_batch_dim:
ret = ret[0]
return ret


@with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch")
Expand All @@ -263,7 +272,13 @@ def max_pool2d(
stride = kernel_size
if not isinstance(padding, int):
padding = [(pad, pad) for pad in padding]
return ivy.max_pool2d(
if input.ndim == 3:
without_batch_dim = True
input = ivy.expand_dims(input, axis=0)
else:
without_batch_dim = False

ret = ivy.max_pool2d(
input,
kernel_size,
stride,
Expand All @@ -272,6 +287,9 @@ def max_pool2d(
dilation=dilation,
ceil_mode=ceil_mode,
)
if without_batch_dim:
ret = ret[0]
return ret


@with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch")
Expand Down Expand Up @@ -306,5 +324,4 @@ def max_pool3d(
)
if without_batch_dim:
ret = ret[0]

return ret

0 comments on commit c1f98b1

Please sign in to comment.