Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Update tensorizer / detensorizer comments #1781

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 124 additions & 2 deletions src/beanmachine/ppl/compiler/devectorizer_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,99 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

#
# Tensorizing and detensorizing
#
# TODO: The fact that we use vectorizing and tensorizing as synonyms throughout all this code is
# unnecessarily confusing. Pick one and stick to it.
#
# There are a number of important mismatches between the BMG and PyTorch type systems
# that we need to address when converting the accumulated graph of a PyTorch model into
# a BMG graph data structure.
#
# * PyTorch values are tensors: rectangular arrays of arbitrary dimensionality. BMG
# values are either "scalars" -- single values -- or Eigen arrays == two-dimensional
# rectangles.
#
# * PyTorch uses the same functions regardless of dimensionality; tensor(10).log() and
# tensor([[[10, 20, 30], [40, 50, 60]]]).log() are both represented by .log(). BMG
# has distinct operators for LOG and MATRIX_LOG, MULTIPLY and ELEMENTWISE_MULTIPLY
# and so on.
#
# * PyTorch's type system describes *storage format*: a tensor can be a tensor of integers
# or a tensor of doubles. BMG's type system describes *data semantics*: floating point
# types are real, positive real, negative real and probability, discrete types are natural
# and bool.
#
# Moreover: we are gradually adding matrix-flavor operators to BMG, but they're not all
# there yet. It would be nice to be able to compile models that use multidimensional tensors
# even if doing so results in a larger graph size due to use of single-value operator nodes.
#
# To address these problems we perform two rewrites first, before all other graph rewrites:
# *tensorizing* (in tensorizer_transformer.py) and *detensorizing* (this module).
#
# * Tensorizing is the more straightforward operation. We identify graph nodes in the accumulated
# graph which correspond to array operators already implemented in BMG. In particular, we
# look for nodes representing elementwise multiplication, addition, division, log, exp, and so
# on, where the operands are multidimensional arrays. Those nodes are replaced with the appropriate
# matrix-aware operator node.
#
# * Detensorizing is the more complex rewrite because it attempts to implement "batch" operations
# by doing them on individual elements of a matrix, and then combining the results back into
# a matrix. The long-term goal is to render detensorizing unnecessary by having all the necessary
# matrix operators in BMG.
#
# Detensorizing uses a few basic techniques in combination. We'll go through an example here.
# Suppose we have a MatrixAdd -> HalfCauchy -> X, where X is a node that expects a matrix input
# but there is no matrix version of HalfCauchy. What do we do?
#
# * "Splitting" takes an indexable node and breaks it up into its individual scalar quantities.
#
# MatrixAdd 0 MatrixAdd 1
# | --split-> \ / \ /
# HalfCauchy [ index , index ]
# | |
# ~ HalfCauchy
# | |
# X ~
# |
# X
#
# In this example the input would be the matrix add and the replacement would be a list of index nodes.
#
# * "Scattering" takes the now-ill-formed graph produced by splitting and moves the list "down the graph".
#
# 0 MatrixAdd 1 0 MatrixAdd 1
# \ / \ / \ / \ /
# [ index , index ] --scatter--> index index
# | | |
# HalfCauchy HalfCauchy HalfCauchy
# | | |
# ~ [~ , ~]
# | |
# X X
#
# * Finally, "merging" turns a list of nodes into a tensor node:
#
# 0 MatrixAdd 1 0 MatrixAdd 1
# \ / \ / \ / \ /
# index index --merge--> index index
# | | | |
# HalfCauchy HalfCauchy HalfCauchy HalfCauchy
# | | | |
# [~ , ~] ~ ~
# | \ /
# X Tensor
# |
# X
#
# Now the graph is well-formed again, and we've solved the type system problem that there is
# not (yet) a "matrix half Cauchy", at the cost of having to run some tricky code and generating
# O(n) index and HalfCauchy nodes.
#
# The task of the devectorizer transformer is for each node to identify whether it currently needs
# to be split, scattered or merged in order to fix a problem.

import typing
from enum import Enum
from typing import Callable, Dict, List
Expand All @@ -25,7 +118,8 @@
from beanmachine.ppl.compiler.sizer import is_scalar, Size, Sizer, Unsized
from beanmachine.ppl.compiler.tensorizer_transformer import Tensorizer

# elements in this list operate over tensors (all parameters are tensors) but do not necessarily produce tensors
# These operator nodes take a single matrix input; they do not necessarily produce
# a matrix output.
_unary_tensor_ops = [
bn.LogSumExpVectorNode,
bn.MatrixComplementNode,
Expand All @@ -41,8 +135,10 @@
bn.CholeskyNode,
]

# These operator nodes take two matrix inputs.
_binary_tensor_ops = [bn.ElementwiseMultiplyNode, bn.MatrixAddNode]

# These nodes represent constant matrix values.
_tensor_constants = [
bn.ConstantProbabilityMatrixNode,
bn.ConstantBooleanMatrixNode,
Expand All @@ -54,12 +150,17 @@
bn.UntypedConstantNode,
]

# These distributions produce matrix-valued samples.
# TODO: Why is categorical on this list? Categorical has a matrix-valued *input* but
# produces a natural-valued *output*. This is likely an error; investigate further.
_tensor_valued_distributions = [
    bn.CategoricalNode,
    bn.DirichletNode,
    bn.LKJCholeskyNode,
]

# These are nodes which are *possibly* allowed to be the left-hand input of an
# indexing operation.
_indexable_node_types = [
bn.ColumnIndexNode,
bn.ConstantTensorNode,
Expand All @@ -79,16 +180,22 @@
bn.UntypedConstantNode,
]


class ElementType(Enum):
    """Shape requirement on the *input* of a node.

    For example, matrix scale requires that its first input be a
    matrix ("TENSOR") and its second a scalar.
    """

    TENSOR = 1
    SCALAR = 2
    ANY = 3


class DevectorizeTransformation(Enum):
    """Describes what needs to happen to a node during devectorization."""

    # The node needs to be rewritten.
    YES = 1
    # The node needs to be rewritten with a merge operation.
    YES_WITH_MERGE = 2
    # The node is fine as it is.
    NO = 3


Expand Down Expand Up @@ -182,6 +289,19 @@ def _parameter_to_type_torch_log_sum_exp(
return ElementType.SCALAR


# The devectorizer has two public APIs.
#
# assess_node determines if the devectorizer can operate on this node at all, which comes down
# to determining if we have information about the shape of the value in the original Python model
# or not. If we cannot determine the operation's shape then we cannot know how to devectorize it.
#
# transform_node says how to replace each node; it returns:
#
# * None, indicating that the node should be deleted
# * a node, giving the drop-in replacement for the given node
# * a list of nodes, which will later be merged or scattered.


class Devectorizer(NodeTransformer):
def __init__(self, cloner: Cloner, sizer: Sizer):
self.copy_context = CopyContext()
Expand Down Expand Up @@ -376,6 +496,7 @@ def _clone(self, node: bn.BMGNode) -> bn.BMGNode:
return n

def __split(self, node: bn.BMGNode) -> List[bn.BMGNode]:
# See comments at the top of this module describing the semantics of split.
size = self.sizer[node]
dim = len(size)
index_list = []
Expand Down Expand Up @@ -417,6 +538,7 @@ def __split(self, node: bn.BMGNode) -> List[bn.BMGNode]:
return index_list

def __scatter(self, node: bn.BMGNode) -> List[bn.BMGNode]:
# See comments at the top of the module describing the semantics of scatter.
parents = self.__get_clone_parents(node)
if isinstance(node, bn.SampleNode):
new_nodes = self.__flatten_parents(
Expand Down
52 changes: 51 additions & 1 deletion src/beanmachine/ppl/compiler/tensorizer_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

#
# Tensorizing and detensorizing
#
# See the comment at the top of devectorizer_transformer.py for a high-level description of
# what this class is for and how it works with the devectorizer.

import typing
from enum import Enum
from typing import Callable, List
Expand All @@ -20,10 +26,18 @@
from beanmachine.ppl.compiler.sizer import is_scalar, Sizer, Unsized


class ElementType(Enum):
    """Tensor-shape category of a node in the original PyTorch model.

    The tensorizing transformation does not need to know the *semantic*
    type of a node -- whether it is a bool, natural, probability,
    positive real, and so on. It only needs information about what the
    *tensor shape* was in the original PyTorch model.
    """

    # The node represents a multidimensional tensor that cannot be expressed in BMG.
    TENSOR = 1
    # The node represents a single value.
    SCALAR = 2
    # The node represents multiple values expressible as a BMG 2-d matrix.
    MATRIX = 3
    # The size could not be deduced from the original Python model.
    UNKNOWN = 4


Expand All @@ -32,6 +46,14 @@ def _always(node):


class Tensorizer(NodeTransformer):
# A node transformer exposes two operations to its caller:
# * assess_node takes a node and returns an assessment of whether it can be
# transformed.
# * transform_node takes a node and either returns a copy, or a new node to
# replace the given node.
#
# This transformer determines whether a node in the graph accumulated from the
# original Python model should be transformed into a matrix-aware BMG node.
def __init__(self, cloner: Cloner, sizer: Sizer):
self.cloner = cloner
self.sizer = sizer
Expand Down Expand Up @@ -76,6 +98,8 @@ def __init__(self, cloner: Cloner, sizer: Sizer):
def _tensorize_div(
self, node: bn.DivisionNode, new_inputs: List[bn.BMGNode]
) -> bn.BMGNode:
# If we have DIV(matrix, scalar) then we transform that into
# MATRIX_SCALE(matrix, DIV(1, scalar)).
assert len(node.inputs.inputs) == 2
tensor_input = new_inputs[0]
scalar_input = new_inputs[1]
Expand All @@ -90,8 +114,9 @@ def _tensorize_div(
def _tensorize_sum(
self, node: bn.SumNode, new_inputs: List[bn.BMGNode]
) -> bn.BMGNode:
# TODO: Ensure that we correctly insert any necessary broadcasting nodes
# in the requirements-fixing pass.
assert len(new_inputs) >= 1
# note that scalars can be broadcast
if any(
self._element_type(operand) == ElementType.MATRIX
for operand in node.inputs.inputs
Expand All @@ -105,6 +130,8 @@ def _tensorize_sum(
def _tensorize_multiply(
self, node: bn.MultiplicationNode, new_inputs: List[bn.BMGNode]
) -> bn.BMGNode:
# Note that this function handles *elementwise* multiplication of tensors, not
# matrix multiplication. There are three cases to consider.
if len(new_inputs) != 2:
raise ValueError(
"Cannot transform a mult into a tensor mult because there are not two operands"
Expand All @@ -115,12 +142,15 @@ def _tensorize_multiply(
raise ValueError(
f"cannot multiply an unsized quantity. Operands: {new_inputs[0]} and {new_inputs[1]}"
)
# Case one: MULT(matrix, matrix) --> ELEMENTWISEMULT(matrix, matrix)
# TODO: Ensure that the requirements fixing pass correctly inserts broadcast operators.
lhs_is_scalar = is_scalar(lhs_sz)
rhs_is_scalar = is_scalar(rhs_sz)
if not lhs_is_scalar and not rhs_is_scalar:
return self.cloner.bmg.add_elementwise_multiplication(
new_inputs[0], new_inputs[1]
)
# Case two: MULT(scalar, scalar) stays just that.
if lhs_is_scalar and not rhs_is_scalar:
scalar_parent_image = new_inputs[0]
tensor_parent_image = new_inputs[1]
Expand All @@ -131,6 +161,7 @@ def _tensorize_multiply(
assert is_scalar(rhs_sz)
else:
return self.cloner.bmg.add_multiplication(new_inputs[0], new_inputs[1])
# Case three: MULT(matrix, scalar) or MULT(scalar, matrix) --> MATRIX_SCALE(matrix, scalar)
return self.cloner.bmg.add_matrix_scale(
scalar_parent_image, tensor_parent_image
)
Expand All @@ -141,6 +172,8 @@ def _tensorize_unary_elementwise(
new_inputs: List[bn.BMGNode],
creator: Callable,
) -> bn.BMGNode:
# Unary operators such as exp, log, and so on, are straightforward. If the operand is
# a matrix, generate the matrix-aware node. Otherwise leave it alone.
assert len(new_inputs) == 1
if self._element_type(new_inputs[0]) == ElementType.MATRIX:
return creator(new_inputs[0])
Expand All @@ -153,6 +186,23 @@ def _tensorize_binary_elementwise(
new_inputs: List[bn.BMGNode],
creator: Callable,
) -> bn.BMGNode:
# TODO: This code is only called for addition nodes, so making it generalized
# to arbitrary binary operators is slightly misleading. Moreover, once we have
# broadcasting operations implemented correctly in the requirements fixing pass,
# this code is not quite right.
#
# The meaning of the code today is: create a MatrixAdd IFF *both* operands are
# multidimensional. This means that for the case where we have matrix + scalar,
# we do NOT generate a matrix add here; instead we keep it a regular add and the
# devectorizer tears the matrix apart, does scalar additions, and then puts it
# back together.
#
# Once we have broadcasting fixers in the requirements fixing pass, the correct
# behavior here will be to generate the matrix add if EITHER or BOTH operands are
# multidimensional. In the case where we have matrix + scalar, we do not want to
# devectorize the matrix, we want to matrix-fill the scalar, and now we have
# matrix + matrix.

assert len(new_inputs) == 2
if (
self._element_type(new_inputs[0]) == ElementType.MATRIX
Expand Down