Skip to content

Commit

Permalink
Implement Blockwise in PyTorch backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Ch0ronomato authored and ricardoV94 committed Nov 5, 2024
1 parent 6f4219a commit 13cb8e8
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 0 deletions.
1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.subtensor
import pytensor.link.pytorch.dispatch.blockwise
# isort: on
32 changes: 32 additions & 0 deletions pytensor/link/pytorch/dispatch/blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
import torch.compiler

from pytensor.graph import FunctionGraph
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.blockwise import Blockwise


@pytorch_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
batched_dims = op.batch_ndim(node)
core_node = op._create_dummy_core_node(node.inputs)
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1)

for _ in range(batched_dims):
inner_func = torch.vmap(inner_func)

@torch.compiler.disable(recursive=False)
def batcher(*inputs):
op._check_runtime_broadcast(node, inputs)
# broadcast on batched_dims
all_batched_dims = tuple(t.shape[:batched_dims] for t in inputs)
batched_shape = torch.broadcast_shapes(*all_batched_dims)
broadcast_inputs = [
torch.broadcast_to(i, batched_shape + i.shape[batched_dims:])
for i in inputs
]
res = inner_func(*broadcast_inputs)
return res

return batcher
53 changes: 53 additions & 0 deletions tests/link/pytorch/test_blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import pytest

import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.blockwise import Blockwise


torch = pytest.importorskip("torch")
basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")


class TestOp(Op):
gufunc_signature = "(m,n),(n,p)->(m,p)"

def __init__(self, final_shape):
super().__init__()
self.final_shape = final_shape
self.call_shapes = []

def make_node(self, *args):
return Apply(self, list(args), [pt.matrix("_", shape=self.final_shape)])

def perform(self, *_):
raise RuntimeError("In perform")


@basic.pytorch_funcify.register(TestOp)
def evaluate_test_op(op, **_):
@torch.compiler.disable(recursive=False)
def func(a, b):
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
return a @ b

return func


def test_blockwise_broadcast():
_x = np.random.rand(5, 1, 2, 3)
_y = np.random.rand(3, 3, 2)

x = pt.tensor4("x", shape=(5, 1, 2, 3))
y = pt.tensor3("y", shape=(3, 3, 2))
op = TestOp((2, 2))
z = Blockwise(op)(x, y)

f = pytensor.function([x, y], z, mode="PYTORCH")
res = f(_x, _y)
assert tuple(res.shape) == (5, 3, 2, 2)
np.testing.assert_allclose(res, _x @ _y)
assert op.call_shapes == [(2, 3), (3, 2)]

0 comments on commit 13cb8e8

Please sign in to comment.