From 9ddf9622dccf0270683ae594e2aa1443431988d3 Mon Sep 17 00:00:00 2001 From: Eric Lippert Date: Tue, 25 Oct 2022 10:16:46 -0700 Subject: [PATCH] Update tensorizer / detensorizer comments Summary: The tensorizing / detensorizing code is some of the more complicated code in the compiler and it could benefit from some explanatory comments to demystify it to the new reader. Reviewed By: AishwaryaSivaraman Differential Revision: D40571996 fbshipit-source-id: 547a794f6478d0350e5152b97a7baf6617420d83 --- .../ppl/compiler/devectorizer_transformer.py | 126 +++++++++++++++++- .../ppl/compiler/tensorizer_transformer.py | 52 +++++++- 2 files changed, 175 insertions(+), 3 deletions(-) diff --git a/src/beanmachine/ppl/compiler/devectorizer_transformer.py b/src/beanmachine/ppl/compiler/devectorizer_transformer.py index 5015dc7e31..d8bf080824 100644 --- a/src/beanmachine/ppl/compiler/devectorizer_transformer.py +++ b/src/beanmachine/ppl/compiler/devectorizer_transformer.py @@ -3,6 +3,99 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +# +# Tensorizing and detensorizing +# +# TODO: The fact that we use vectorizing and tensorizing as synonyms throughout all this code is +# unnecessarily confusing. Pick one and stick to it. +# +# There are a number of important mismatches between the BMG and PyTorch type systems +# that we need to address when converting the accumulated graph of a PyTorch model into +# a BMG graph data structure. +# +# * PyTorch values are tensors: rectangular arrays of arbitrary dimensionality. BMG +# values are either "scalars" -- single values -- or Eigen arrays == two-dimensional +# rectangles. +# +# * PyTorch uses the same functions regardless of dimensionality; tensor(10).log() and +# tensor([[[10, 20, 30], [40, 50, 60]]]).log() are both represented by .log(). BMG +# has distinct operators for LOG and MATRIX_LOG, MULTIPLY and ELEMENTWISE_MULTIPLY +# and so on. 
+#
+# * PyTorch's type system describes *storage format*: a tensor can be a tensor of integers
+# or a tensor of doubles. BMG's type system describes *data semantics*: floating point
+# types are real, positive real, negative real and probability, discrete types are natural
+# and bool.
+#
+# Moreover: we are gradually adding matrix-flavor operators to BMG, but they're not all
+# there yet. It would be nice to be able to compile models that use multidimensional tensors
+# even if doing so results in a larger graph size due to use of single-value operator nodes.
+#
+# To address these problems we perform two rewrites first, before all other graph rewrites:
+# *tensorizing* (in tensorizer_transformer.py) and *detensorizing* (this module).
+#
+# * Tensorizing is the more straightforward operation. We identify graph nodes in the accumulated
+# graph which correspond to array operators already implemented in BMG. In particular, we
+# look for nodes representing elementwise multiplication, addition, division, log, exp, and so
+# on, where the operands are multidimensional arrays. Those nodes are replaced with the appropriate
+# matrix-aware operator node.
+#
+# * Detensorizing is the more complex rewrite because it attempts to implement "batch" operations
+# by doing them on individual elements of a matrix, and then combining the results back into
+# a matrix. The long-term goal is to render detensorizing unnecessary by having all the necessary
+# matrix operators in BMG.
+#
+# Detensorizing uses a few basic techniques in combination. We'll go through an example here.
+# Suppose we have a MatrixAdd -> HalfCauchy -> X, where X is a node that expects a matrix input
+# but there is no matrix version of HalfCauchy. What do we do?
+#
+# * "Splitting" takes an indexable node and breaks it up into its individual scalar quantities. 
+# +# MatrixAdd 0 MatrixAdd 1 +# | --split-> \ / \ / +# HalfCauchy [ index , index ] +# | | +# ~ HalfCauchy +# | | +# X ~ +# | +# X +# +# In this example the input would be the matrix add and the replacement would be a list of index nodes. +# +# * "Scattering" takes the now-ill-formed graph produced by splitting and moves the list "down the graph". +# +# 0 MatrixAdd 1 0 MatrixAdd 1 +# \ / \ / \ / \ / +# [ index , index ] --scatter--> index index +# | | | +# HalfCauchy HalfCauchy HalfCauchy +# | | | +# ~ [~ , ~] +# | | +# X X +# +# * Finally, "merging" turns a list of nodes into a tensor node: +# +# 0 MatrixAdd 1 0 MatrixAdd 1 +# \ / \ / \ / \ / +# index index --merge--> index index +# | | | | +# HalfCauchy HalfCauchy HalfCauchy HalfCauchy +# | | | | +# [~ , ~] ~ ~ +# | \ / +# X Tensor +# | +# X +# +# Now the graph is well-formed again, and we've solved the type system problem that there is +# not (yet) a "matrix half Cauchy", at the cost of having to run some tricky code and generating +# O(n) index and HalfCauchy nodes. +# +# The task of the devectorizer transformer is for each node to identify whether it currently needs +# to be split, scattered or merged in order to fix a problem. + import typing from enum import Enum from typing import Callable, Dict, List @@ -25,7 +118,8 @@ from beanmachine.ppl.compiler.sizer import is_scalar, Size, Sizer, Unsized from beanmachine.ppl.compiler.tensorizer_transformer import Tensorizer -# elements in this list operate over tensors (all parameters are tensors) but do not necessarily produce tensors +# These operator nodes take a single matrix input; they do not necessarily produce +# a matrix output. _unary_tensor_ops = [ bn.LogSumExpVectorNode, bn.MatrixComplementNode, @@ -41,8 +135,10 @@ bn.CholeskyNode, ] +# These operator nodes take two matrix inputs. _binary_tensor_ops = [bn.ElementwiseMultiplyNode, bn.MatrixAddNode] +# These nodes represent constant matrix values. 
_tensor_constants = [
     bn.ConstantProbabilityMatrixNode,
     bn.ConstantBooleanMatrixNode,
@@ -54,12 +150,17 @@
     bn.UntypedConstantNode,
 ]
 
+# These distributions produce matrix-valued samples.
+# TODO: Why is categorical on this list? Categorical has a matrix-valued *input* but
+# produces a natural-valued *output*. This is likely an error; investigate further.
 _tensor_valued_distributions = [
     bn.CategoricalNode,
     bn.DirichletNode,
     bn.LKJCholeskyNode,
 ]
 
+# These are nodes which are *possibly* allowed to be the left-hand input of an
+# indexing operation.
 _indexable_node_types = [
     bn.ColumnIndexNode,
     bn.ConstantTensorNode,
@@ -79,16 +180,22 @@
     bn.UntypedConstantNode,
 ]
 
-
+# This is used to describe the requirements on the *input* of a node; for example,
+# matrix scale requires that its first input be a matrix ("TENSOR") and its second
+# a scalar.
 class ElementType(Enum):
     TENSOR = 1
     SCALAR = 2
     ANY = 3
 
 
+# This describes what needs to happen to a node.
 class DevectorizeTransformation(Enum):
+    # The node needs to be rewritten.
     YES = 1
+    # The node needs to be rewritten with a merge operation.
     YES_WITH_MERGE = 2
+    # The node is fine as it is.
     NO = 3
 
 
@@ -182,6 +289,19 @@ def _parameter_to_type_torch_log_sum_exp(
     return ElementType.SCALAR
 
 
+# The devectorizer has two public APIs.
+#
+# assess_node determines if the devectorizer can operate on this node at all, which comes down
+# to determining if we have information about the shape of the value in the original Python model
+# or not. If we cannot determine the operation's shape then we cannot know how to devectorize it.
+#
+# transform_node says how to replace each node; it returns:
+#
+# * None, indicating that the node should be deleted
+# * a node, giving the drop-in replacement for the given node
+# * a list of nodes, which will later be merged or scattered. 
+ + class Devectorizer(NodeTransformer): def __init__(self, cloner: Cloner, sizer: Sizer): self.copy_context = CopyContext() @@ -376,6 +496,7 @@ def _clone(self, node: bn.BMGNode) -> bn.BMGNode: return n def __split(self, node: bn.BMGNode) -> List[bn.BMGNode]: + # See comments at the top of this module describing the semantics of split. size = self.sizer[node] dim = len(size) index_list = [] @@ -417,6 +538,7 @@ def __split(self, node: bn.BMGNode) -> List[bn.BMGNode]: return index_list def __scatter(self, node: bn.BMGNode) -> List[bn.BMGNode]: + # See comments at the top of the module describing the semantics of scatter. parents = self.__get_clone_parents(node) if isinstance(node, bn.SampleNode): new_nodes = self.__flatten_parents( diff --git a/src/beanmachine/ppl/compiler/tensorizer_transformer.py b/src/beanmachine/ppl/compiler/tensorizer_transformer.py index 8da592639f..1d646e23cc 100644 --- a/src/beanmachine/ppl/compiler/tensorizer_transformer.py +++ b/src/beanmachine/ppl/compiler/tensorizer_transformer.py @@ -3,6 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +# +# Tensorizing and detensorizing +# +# See the comment at the top of devectorizer_transformer.py for a high-level description of +# what this class is for and how it works with the devectorizer. + import typing from enum import Enum from typing import Callable, List @@ -20,10 +26,18 @@ from beanmachine.ppl.compiler.sizer import is_scalar, Sizer, Unsized +# The tensorizing transformation does not need to know the *semantic* type of a node; +# that is, whether it is a bool, natural, probability, positive real, and so on. But +# we do need information about what the *tensor shape* was in the original PyTorch +# model. class ElementType(Enum): + # The node represents a multidimensional tensor that cannot be expressed in BMG. TENSOR = 1 + # The node represents a single value. 
SCALAR = 2
+    # The node represents multiple values that can be expressed in a BMG 2-d matrix.
     MATRIX = 3
+    # We were unable to deduce the size in the original Python model.
     UNKNOWN = 4
 
 
@@ -32,6 +46,14 @@ def _always(node):
 
 class Tensorizer(NodeTransformer):
+    # A node transformer exposes two operations to its caller:
+    # * assess_node takes a node and returns an assessment of whether it can be
+    # transformed.
+    # * transform_node takes a node and either returns a copy, or a new node to
+    # replace the given node.
+    #
+    # This transformer determines whether a node in the graph accumulated from the
+    # original Python model should be transformed into a matrix-aware BMG node.
     def __init__(self, cloner: Cloner, sizer: Sizer):
         self.cloner = cloner
         self.sizer = sizer
@@ -76,6 +98,8 @@ def __init__(self, cloner: Cloner, sizer: Sizer):
     def _tensorize_div(
         self, node: bn.DivisionNode, new_inputs: List[bn.BMGNode]
     ) -> bn.BMGNode:
+        # If we have DIV(matrix, scalar) then we transform that into
+        # MATRIX_SCALE(matrix, DIV(1, scalar)).
         assert len(node.inputs.inputs) == 2
         tensor_input = new_inputs[0]
         scalar_input = new_inputs[1]
@@ -90,8 +114,9 @@ def _tensorize_div(
     def _tensorize_sum(
         self, node: bn.SumNode, new_inputs: List[bn.BMGNode]
     ) -> bn.BMGNode:
+        # TODO: Ensure that we correctly insert any necessary broadcasting nodes
+        # in the requirements-fixing pass.
         assert len(new_inputs) >= 1
-        # note that scalars can be broadcast
         if any(
             self._element_type(operand) == ElementType.MATRIX
             for operand in node.inputs.inputs
@@ -105,6 +130,8 @@ def _tensorize_sum(
     def _tensorize_multiply(
         self, node: bn.MultiplicationNode, new_inputs: List[bn.BMGNode]
     ) -> bn.BMGNode:
+        # Note that this function handles *elementwise* multiplication of tensors, not
+        # matrix multiplication. There are three cases to consider. 
if len(new_inputs) != 2: raise ValueError( "Cannot transform a mult into a tensor mult because there are not two operands" @@ -115,12 +142,15 @@ def _tensorize_multiply( raise ValueError( f"cannot multiply an unsized quantity. Operands: {new_inputs[0]} and {new_inputs[1]}" ) + # Case one: MULT(matrix, matrix) --> ELEMENTWISEMULT(matrix, matrix) + # TODO: Ensure that the requirements fixing pass correctly inserts broadcast operators. lhs_is_scalar = is_scalar(lhs_sz) rhs_is_scalar = is_scalar(rhs_sz) if not lhs_is_scalar and not rhs_is_scalar: return self.cloner.bmg.add_elementwise_multiplication( new_inputs[0], new_inputs[1] ) + # Case two: MULT(scalar, scalar) stays just that. if lhs_is_scalar and not rhs_is_scalar: scalar_parent_image = new_inputs[0] tensor_parent_image = new_inputs[1] @@ -131,6 +161,7 @@ def _tensorize_multiply( assert is_scalar(rhs_sz) else: return self.cloner.bmg.add_multiplication(new_inputs[0], new_inputs[1]) + # Case three: MULT(matrix, scalar) or MULT(scalar, matrix) --> MATRIX_SCALE(matrix, scalar) return self.cloner.bmg.add_matrix_scale( scalar_parent_image, tensor_parent_image ) @@ -141,6 +172,8 @@ def _tensorize_unary_elementwise( new_inputs: List[bn.BMGNode], creator: Callable, ) -> bn.BMGNode: + # Unary operators such as exp, log, and so on, are straightforward. If the operand is + # a matrix, generate the matrix-aware node. Otherwise leave it alone. assert len(new_inputs) == 1 if self._element_type(new_inputs[0]) == ElementType.MATRIX: return creator(new_inputs[0]) @@ -153,6 +186,23 @@ def _tensorize_binary_elementwise( new_inputs: List[bn.BMGNode], creator: Callable, ) -> bn.BMGNode: + # TODO: This code is only called for addition nodes, so making it generalized + # to arbitrary binary operators is slightly misleading. Moreover, once we have + # broadcasting operations implemented correctly in the requirements fixing pass, + # this code is not quite right. 
+ # + # The meaning of the code today is: create a MatrixAdd IFF *both* operands are + # multidimensional. This means that for the case where we have matrix + scalar, + # we do NOT generate a matrix add here; instead we keep it a regular add and the + # devectorizer tears the matrix apart, does scalar additions, and then puts it + # back together. + # + # Once we have broadcasting fixers in the requirements fixing pass, the correct + # behavior here will be to generate the matrix add if EITHER or BOTH operands are + # multidimensional. In the case where we have matrix + scalar, we do not want to + # devectorize the matrix, we want to matrix-fill the scalar, and now we have + # matrix + matrix. + assert len(new_inputs) == 2 if ( self._element_type(new_inputs[0]) == ElementType.MATRIX