Implement vectorize utility
ricardoV94 committed Sep 5, 2023
1 parent 8df22d7 commit d6b8777
Showing 11 changed files with 111 additions and 55 deletions.
2 changes: 1 addition & 1 deletion pytensor/graph/__init__.py
@@ -9,7 +9,7 @@
     clone,
     ancestors,
 )
-from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.graph.replace import clone_replace, graph_replace, vectorize
 from pytensor.graph.op import Op
 from pytensor.graph.type import Type
 from pytensor.graph.fg import FunctionGraph
67 changes: 65 additions & 2 deletions pytensor/graph/replace.py
@@ -1,8 +1,9 @@
-from functools import partial
-from typing import Iterable, Optional, Sequence, Union, cast, overload
+from functools import partial, singledispatch
+from typing import Iterable, Mapping, Optional, Sequence, Union, cast, overload
 
 from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import Op
 
 
 ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
@@ -198,3 +199,65 @@ def toposort_key(
         return list(fg.outputs)
     else:
         return fg.outputs[0]
+
+
+@singledispatch
+def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
+    # Default implementation is provided in pytensor.tensor.blockwise
+    raise NotImplementedError
+
+
+def vectorize_node(node: Apply, *batched_inputs) -> Apply:
+    """Return a vectorized version of the node with the new batched inputs."""
+    op = node.op
+    return _vectorize_node(op, node, *batched_inputs)
+
+
+def vectorize(
+    outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable]
+) -> Sequence[Variable]:
+    """Vectorize an output graph, given a mapping from old variables to their expanded counterparts.
+
+    Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import pytensor
+        import pytensor.tensor as pt
+        from pytensor.graph import vectorize
+
+        # Original graph
+        x = pt.vector("x")
+        y = pt.exp(x) / pt.sum(pt.exp(x))
+
+        # Vectorized graph
+        new_x = pt.matrix("new_x")
+        [new_y] = vectorize([y], {x: new_x})
+
+        fn = pytensor.function([new_x], new_y)
+        fn([[0, 1, 2], [2, 1, 0]])
+        # array([[0.09003057, 0.24472847, 0.66524096],
+        #        [0.66524096, 0.24472847, 0.09003057]])
+
+    """
+    # Avoid circular import
+    inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys())
+    new_inputs = [vectorize.get(inp, inp) for inp in inputs]
+
+    def transform(var):
+        if var in inputs:
+            return new_inputs[inputs.index(var)]
+
+        node = var.owner
+        batched_inputs = [transform(inp) for inp in node.inputs]
+        batched_node = vectorize_node(node, *batched_inputs)
+        batched_var = batched_node.outputs[var.owner.outputs.index(var)]
+
+        return batched_var
+
+    # TODO: MergeOptimization or node caching?
+    return [transform(out) for out in outputs]
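
Since `_vectorize_node` is a `singledispatch` function keyed on the `Op` type, per-`Op` rules can be registered with `_vectorize_node.register`; the generic fallback is registered in `pytensor.tensor.blockwise` (below). A minimal sketch of the dispatch in action on an `Elemwise` node; the graph is illustrative, assuming this commit is installed:

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_node

# Build a small graph and take the Apply node behind pt.exp(x).
x = pt.vector("x")
node = pt.exp(x).owner

# Re-create the node with a batched (matrix) input. Elemwise already
# broadcasts leading dimensions, so its registered rule reuses the
# same Op instead of wrapping it in Blockwise.
new_x = pt.matrix("new_x")
batched_node = vectorize_node(node, new_x)
print(batched_node.outputs[0].type.ndim)  # 2: one extra batch dimension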
3 changes: 2 additions & 1 deletion pytensor/scalar/loop.py
@@ -2,7 +2,8 @@
 from typing import Optional, Sequence, Tuple
 
 from pytensor.compile import rebuild_collect_shared
-from pytensor.graph import Constant, FunctionGraph, Variable, clone
+from pytensor.graph.basic import Constant, Variable, clone
+from pytensor.graph.fg import FunctionGraph
 from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar
 
 
58 changes: 14 additions & 44 deletions pytensor/tensor/blockwise.py
@@ -1,5 +1,4 @@
 import re
-from functools import singledispatch
 from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
 
 import numpy as np
@@ -9,6 +8,7 @@
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node, vectorize
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
@@ -72,8 +72,8 @@ def operand_sig(operand: Variable, prefix: str) -> str:
     return f"{inputs_sig}->{outputs_sig}"
 
 
-@singledispatch
-def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
+@_vectorize_node.register(Op)
+def vectorize_node_fallback(op: Op, node: Apply, *batched_inputs) -> Apply:
     if hasattr(op, "gufunc_signature"):
         signature = op.gufunc_signature
     else:
@@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
     return cast(Apply, Blockwise(op, signature=signature).make_node(*batched_inputs))
 
 
-def vectorize_node(node: Apply, *batched_inputs) -> Apply:
-    """Returns vectorized version of node with new batched inputs."""
-    op = node.op
-    return _vectorize_node(op, node, *batched_inputs)
-
-
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.
@@ -279,42 +273,18 @@ def as_core(t, core_t):
 
         core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
 
-        batch_ndims = self._batch_ndim_from_outputs(outputs)
-
-        def transform(var):
-            # From a graph of ScalarOps, make a graph of Broadcast ops.
-            if isinstance(var.type, (NullType, DisconnectedType)):
-                return var
-            if var in core_inputs:
-                return inputs[core_inputs.index(var)]
-            if var in core_outputs:
-                return outputs[core_outputs.index(var)]
-            if var in core_ograds:
-                return ograds[core_ograds.index(var)]
-
-            node = var.owner
-
-            # The gradient contains a constant, which may be responsible for broadcasting
-            if node is None:
-                if batch_ndims:
-                    var = shape_padleft(var, batch_ndims)
-                return var
-
-            batched_inputs = [transform(inp) for inp in node.inputs]
-            batched_node = vectorize_node(node, *batched_inputs)
-            batched_var = batched_node.outputs[var.owner.outputs.index(var)]
-
-            return batched_var
-
-        ret = []
-        for core_igrad, ipt in zip(core_igrads, inputs):
-            # Undefined gradient
-            if core_igrad is None:
-                ret.append(None)
-            else:
-                ret.append(transform(core_igrad))
+        igrads = vectorize(
+            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
+            vectorize=dict(
+                zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
+            ),
+        )
 
-        return ret
+        igrads_iter = iter(igrads)
+        return [
+            None if core_igrad is None else next(igrads_iter)
+            for core_igrad in core_igrads
+        ]
 
     def L_op(self, inputs, outs, ograds):
         from pytensor.tensor.math import sum as pt_sum
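
Two things happen in this file: the generic fallback is now registered against `Op` through the shared dispatcher, and `Blockwise.L_op` rebuilds its batched gradient graph with the new `vectorize` utility instead of the hand-rolled `transform` closure. A hedged sketch of the fallback path, assuming `MatrixInverse` defines a `gufunc_signature` (as the tests below suggest):

import pytensor.tensor as pt
from pytensor.graph import vectorize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import matrix_inverse

# Core graph: invert a single matrix.
x = pt.matrix("x")
y = matrix_inverse(x)

# Expand x with a leading batch dimension. MatrixInverse has no
# specialized _vectorize_node registration, so vectorize_node_fallback
# wraps it in Blockwise over the batched input.
new_x = pt.tensor3("new_x")
[new_y] = vectorize([y], {x: new_x})
assert isinstance(new_y.owner.op, Blockwise)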
3 changes: 2 additions & 1 deletion pytensor/tensor/elemwise.py
@@ -8,6 +8,7 @@
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.null_type import NullType
+from pytensor.graph.replace import _vectorize_node
 from pytensor.graph.utils import MethodNotDefined
 from pytensor.link.c.basic import failure_code
 from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
@@ -22,7 +23,7 @@
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
-from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
+from pytensor.tensor.blockwise import vectorize_not_needed
 from pytensor.tensor.type import (
     TensorType,
     continuous_dtypes,
2 changes: 1 addition & 1 deletion pytensor/tensor/random/op.py
@@ -7,6 +7,7 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor.basic import (
@@ -17,7 +18,6 @@
     get_vector_length,
     infer_static_shape,
 )
-from pytensor.tensor.blockwise import _vectorize_node
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from pytensor.tensor.random.utils import (
     broadcast_params,
3 changes: 2 additions & 1 deletion pytensor/tensor/rewriting/blockwise.py
@@ -1,7 +1,8 @@
 from pytensor.compile.mode import optdb
 from pytensor.graph import node_rewriter
+from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
-from pytensor.tensor.blockwise import Blockwise, vectorize_node
+from pytensor.tensor.blockwise import Blockwise
 
 
 @node_rewriter([Blockwise])
21 changes: 20 additions & 1 deletion tests/graph/test_replace.py
@@ -1,10 +1,11 @@
 import numpy as np
 import pytest
+import scipy.special
 
 import pytensor.tensor as pt
 from pytensor import config, function, shared
 from pytensor.graph.basic import graph_inputs
-from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.graph.replace import clone_replace, graph_replace, vectorize
 from pytensor.tensor import dvector, fvector, vector
 from tests import unittest_tools as utt
 from tests.graph.utils import MyOp, MyVariable
@@ -223,3 +224,21 @@ def test_graph_replace_disconnected(self):
         assert oc[0] is o
         with pytest.raises(ValueError, match="Some replacements were not used"):
             oc = graph_replace([o], {fake: x.clone()}, strict=True)
+
+
+class TestVectorize:
+    # TODO: Add tests with multiple outputs, constants, and other singleton types
+
+    def test_basic(self):
+        x = pt.vector("x")
+        y = pt.exp(x) / pt.sum(pt.exp(x))
+
+        new_x = pt.matrix("new_x")
+        [new_y] = vectorize([y], {x: new_x})
+
+        fn = function([new_x], new_y)
+        test_new_y = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX)
+        np.testing.assert_allclose(
+            fn(test_new_y),
+            scipy.special.softmax(test_new_y, axis=-1),
+        )
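
A hedged sketch of how the first TODO case (a graph containing a constant) might be covered, reusing this module's imports; the expected behavior leans on `Elemwise` broadcasting and is an assumption, not something this commit tests:

    def test_constant(self):
        x = pt.vector("x")
        shift = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
        y = x + shift  # `shift` enters the graph as a TensorConstant

        new_x = pt.matrix("new_x")
        [new_y] = vectorize([y], {x: new_x})

        fn = function([new_x], new_y)
        test_x = np.zeros((2, 3), dtype=config.floatX)
        # The constant should broadcast against the new batch dimension.
        np.testing.assert_allclose(fn(test_x), test_x + shift)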
2 changes: 1 addition & 1 deletion tests/tensor/random/test_op.py
@@ -4,8 +4,8 @@
 import pytensor.tensor as at
 from pytensor import config, function
 from pytensor.gradient import NullTypeGradError, grad
+from pytensor.graph.replace import vectorize_node
 from pytensor.raise_op import Assert
-from pytensor.tensor.blockwise import vectorize_node
 from pytensor.tensor.math import eq
 from pytensor.tensor.random import normal
 from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
3 changes: 2 additions & 1 deletion tests/tensor/test_blockwise.py
@@ -8,8 +8,9 @@
 from pytensor import config
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
+from pytensor.graph.replace import vectorize_node
 from pytensor.tensor import tensor
-from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node
+from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
 from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.slinalg import Cholesky, Solve
2 changes: 1 addition & 1 deletion tests/tensor/test_elemwise.py
@@ -13,11 +13,11 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.replace import vectorize_node
 from pytensor.link.basic import PerformLinker
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import second
-from pytensor.tensor.blockwise import vectorize_node
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import Any, Sum
 from pytensor.tensor.math import all as pt_all
