
Improve torch elemwise operator
Ch0ronomato authored and ricardoV94 committed Nov 19, 2024
1 parent 0ba554b commit 6de3151
Showing 2 changed files with 48 additions and 3 deletions.
18 changes: 15 additions & 3 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -11,9 +11,21 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
-    def elemwise_fn(*inputs):
-        Elemwise._check_runtime_broadcast(node, inputs)
-        return base_fn(*inputs)
+    if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
+        # torch has a native ufunc for this scalar op, so it can
+        # handle the broadcasting itself; we let it.
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            return base_fn(*inputs)
+    else:
+
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            broadcast_inputs = torch.broadcast_tensors(*inputs)
+            ufunc = base_fn
+            for _ in range(broadcast_inputs[0].dim()):
+                ufunc = torch.vmap(ufunc)
+            return ufunc(*broadcast_inputs)
 
     return elemwise_fn
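For context (an editor's illustration, not part of the diff): a scalar op's nfunc_spec names the NumPy ufunc it mirrors, so the hasattr(torch, ...) check asks whether torch ships a same-named function that already broadcasts natively. A minimal sketch, assuming pt.add builds an Elemwise node around the Add scalar op:

    import torch
    import pytensor.tensor as pt

    # Add's scalar op advertises its NumPy ufunc name via nfunc_spec
    scalar_op = pt.add(pt.scalar("a"), pt.scalar("b")).owner.op.scalar_op
    print(scalar_op.nfunc_spec[0])                  # "add"
    print(hasattr(torch, scalar_op.nfunc_spec[0]))  # True: torch.add broadcasts itself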

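The fallback branch is equivalent in spirit to the following standalone sketch (scalar_add, a, and b are illustrative names, not from the commit): broadcast every input to a common shape, then wrap the scalar function in one torch.vmap per dimension so it is applied elementwise.

    import torch

    def scalar_add(x, y):
        # written as if x and y were single elements (0-d tensors)
        return x + y

    a = torch.randn(2, 3)
    b = torch.randn(3)
    inputs = torch.broadcast_tensors(a, b)  # both become shape (2, 3)

    fn = scalar_add
    for _ in range(inputs[0].dim()):  # one vmap per broadcast dimension
        fn = torch.vmap(fn)

    out = fn(*inputs)  # the innermost calls see 0-d tensors
    assert torch.allclose(out, a + b)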
33 changes: 33 additions & 0 deletions tests/link/pytorch/test_elemwise.py
@@ -1,10 +1,13 @@
 import numpy as np
 import pytest
 
+import pytensor
 import pytensor.tensor as pt
 import pytensor.tensor.math as ptm
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scalar.basic import ScalarOp, get_scalar_type
+from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, tensor3, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
@@ -150,3 +153,33 @@ def test_cast():
         fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
     )
     assert res.dtype == torch.int32
+
+
+def test_vmap_elemwise():
+    from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+
+    class TestOp(ScalarOp):
+        def __init__(self):
+            super().__init__(
+                output_types_preference=lambda *_: [get_scalar_type("float32")]
+            )
+            self.call_shapes = []
+            self.nin = 1
+
+        def perform(self, *_):
+            raise RuntimeError("In perform")
+
+    @pytorch_funcify.register(TestOp)
+    def relu(op, node, **kwargs):
+        def relu(row):
+            op.call_shapes.append(row.size())
+            return torch.max(torch.zeros_like(row), row)
+
+        return relu
+
+    x = matrix("x", shape=(2, 3))
+    op = TestOp()
+    f = pytensor.function([x], Elemwise(op)(x), mode="PYTORCH")
+    vals = torch.zeros(2, 3).normal_()
+    np.testing.assert_allclose(f(vals), torch.relu(vals))
+    assert op.call_shapes == [torch.Size([])], op.call_shapes
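A note on the final assertion (a reading of the test, not text from the commit): torch.vmap does not invoke the wrapped function once per element; it traces it a single time with batched tensors whose visible shape is the per-element shape, which is why call_shapes holds exactly one torch.Size([]). A minimal probe, independent of PyTensor:

    import torch

    seen = []

    def probe(x):
        seen.append(x.shape)  # logical per-element shape; batch dims are hidden
        return x

    torch.vmap(torch.vmap(probe))(torch.zeros(2, 3))
    assert seen == [torch.Size([])]  # one traced call on a 0-d view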
