diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py
index 86089cc921..20c98094c1 100644
--- a/tests/link/pytorch/test_elemwise.py
+++ b/tests/link/pytorch/test_elemwise.py
@@ -1,10 +1,13 @@
 import numpy as np
 import pytest
 
+import pytensor
 import pytensor.tensor as pt
 import pytensor.tensor.math as ptm
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scalar.basic import ScalarOp, get_scalar_type
+from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, tensor3, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
@@ -150,3 +153,33 @@ def test_cast():
         fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
     )
     assert res.dtype == torch.int32
+
+
+def test_vmap_elemwise():
+    from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+
+    class TestOp(ScalarOp):
+        def __init__(self):
+            super().__init__(
+                output_types_preference=lambda *_: [get_scalar_type("float32")]
+            )
+            self.call_shapes = []
+            self.nin = 1
+
+        def perform(self, *_):
+            raise RuntimeError("In perform")
+
+    @pytorch_funcify.register(TestOp)
+    def relu(op, node, **kwargs):
+        def relu(row):
+            op.call_shapes.append(row.size())
+            return torch.max(torch.zeros_like(row), row)
+
+        return relu
+
+    x = matrix("x", shape=(2, 3))
+    op = TestOp()
+    f = pytensor.function([x], Elemwise(op)(x), mode="PYTORCH")
+    vals = torch.zeros(2, 3).normal_()
+    np.testing.assert_allclose(f(vals), torch.relu(vals))
+    assert op.call_shapes == [torch.Size([])], op.call_shapes