Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
copy from add_matrix_ops
Browse files Browse the repository at this point in the history
Summary: Adds the operators MatrixAdd, MatrixSum, MatrixExp, and ElementwiseMult

Reviewed By: ericlippert

Differential Revision: D38794274

fbshipit-source-id: 0d4299f8f106edd3cee3f0bf1cf26d205bd42869
  • Loading branch information
Steffi Stumpos authored and facebook-github-bot committed Aug 19, 2022
1 parent b346237 commit df1e11d
Show file tree
Hide file tree
Showing 12 changed files with 954 additions and 3 deletions.
32 changes: 32 additions & 0 deletions src/beanmachine/ppl/compiler/bm_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,38 @@ def add_query(self, operator: BMGNode) -> bn.Query:
self.add_node(node)
return node

@memoize
def add_elementwise_multiplication(self, left: BMGNode, right: BMGNode) -> BMGNode:
if isinstance(left, ConstantNode) and isinstance(right, ConstantNode):
return self.add_constant(left.value * right.value)
node = bn.ElementwiseMultiplyNode(left, right)
self.add_node(node)
return node

@memoize
def add_matrix_addition(self, left: BMGNode, right: BMGNode) -> BMGNode:
if isinstance(left, ConstantNode) and isinstance(right, ConstantNode):
return self.add_constant(left.value + right.value)
node = bn.MatrixAddNode(left, right)
self.add_node(node)
return node

@memoize
def add_matrix_sum(self, matrix: BMGNode) -> BMGNode:
if isinstance(matrix, ConstantNode):
return self.add_constant(matrix.value.sum())
node = bn.MatrixSumNode(matrix)
self.add_node(node)
return node

@memoize
def add_matrix_exp(self, matrix: BMGNode) -> BMGNode:
if isinstance(matrix, ConstantNode):
return self.add_constant(matrix.value.exp())
node = bn.MatrixExpNode(matrix)
self.add_node(node)
return node

def add_exp_product(self, *inputs: BMGNode) -> bn.ExpProductFactorNode:
# Note that factors are NOT deduplicated; this method is not
# memoized. We need to be able to add multiple factors to the same
Expand Down
4 changes: 4 additions & 0 deletions src/beanmachine/ppl/compiler/bmg_node_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def dist_type(node: bn.DistributionNode) -> Tuple[dt, Any]:
bn.CholeskyNode: OperatorType.CHOLESKY,
bn.ColumnIndexNode: OperatorType.COLUMN_INDEX,
bn.ComplementNode: OperatorType.COMPLEMENT,
bn.ElementwiseMultiplyNode: OperatorType.ELEMENTWISE_MULTIPLY,
bn.ExpM1Node: OperatorType.EXPM1,
bn.ExpNode: OperatorType.EXP,
bn.IfThenElseNode: OperatorType.IF_THEN_ELSE,
Expand All @@ -78,8 +79,11 @@ def dist_type(node: bn.DistributionNode) -> Tuple[dt, Any]:
bn.LogisticNode: OperatorType.LOGISTIC,
bn.LogSumExpNode: OperatorType.LOGSUMEXP,
bn.LogSumExpVectorNode: OperatorType.LOGSUMEXP_VECTOR,
bn.MatrixAddNode: OperatorType.MATRIX_ADD,
bn.MatrixExpNode: OperatorType.MATRIX_EXP,
bn.MatrixMultiplicationNode: OperatorType.MATRIX_MULTIPLY,
bn.MatrixScaleNode: OperatorType.MATRIX_SCALE,
bn.MatrixSumNode: OperatorType.MATRIX_SUM,
bn.MultiplicationNode: OperatorType.MULTIPLY,
bn.NegateNode: OperatorType.NEGATE,
bn.PhiNode: OperatorType.PHI,
Expand Down
38 changes: 38 additions & 0 deletions src/beanmachine/ppl/compiler/bmg_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,14 @@ def __str__(self) -> str:
return f"({str(self.left)}<={str(self.right)})"


class ElementwiseMultiplyNode(BinaryOperatorNode):
def __init__(self, left: BMGNode, right: BMGNode):
BinaryOperatorNode.__init__(self, left, right)

def __str__(self) -> str:
return f"{self.left} * {self.right}"


class EqualNode(ComparisonNode):
def __init__(self, left: BMGNode, right: BMGNode):
ComparisonNode.__init__(self, left, right)
Expand Down Expand Up @@ -1035,6 +1043,17 @@ def __str__(self) -> str:
return "(" + str(self.left) + "*" + str(self.right) + ")"


class MatrixAddNode(BinaryOperatorNode):
"""This represents an exponentiation operation; it is generated when
a model contains calls to Tensor.exp or math.exp."""

def __init__(self, left: BMGNode, right: BMGNode):
BinaryOperatorNode.__init__(self, left, right)

def __str__(self) -> str:
return f"{self.left} + {self.right}"


class MatrixScaleNode(BinaryOperatorNode):
"""This represents a matrix scaling."""

Expand Down Expand Up @@ -1183,6 +1202,25 @@ def __str__(self) -> str:
return "Log1mexp(" + str(self.operand) + ")"


class MatrixExpNode(UnaryOperatorNode):
"""This represents an exponentiation operation; it is generated when
a model contains calls to Tensor.exp or math.exp."""

def __init__(self, operand: BMGNode):
UnaryOperatorNode.__init__(self, operand)

def __str__(self) -> str:
return "MatrixExp(" + str(self.operand) + ")"


class MatrixSumNode(UnaryOperatorNode):
def __init__(self, operand: BMGNode):
UnaryOperatorNode.__init__(self, operand)

def __str__(self) -> str:
return "MatrixSum(" + str(self.operand) + ")"


class TransposeNode(UnaryOperatorNode):
"""This represents a transpose operation; it is generated when
a model contains calls to transpose or Tensor.transpose"""
Expand Down
4 changes: 4 additions & 0 deletions src/beanmachine/ppl/compiler/bmg_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
# what their inputs are.

_known_requirements: Dict[type, List[bt.Requirement]] = {
bn.ElementwiseMultiplyNode: [bt.RealMatrix, bt.RealMatrix],
bn.Observation: [bt.any_requirement],
bn.Query: [bt.any_requirement],
# Distributions
Expand All @@ -53,6 +54,9 @@
bn.LogisticNode: [bt.Real],
bn.Log1mexpNode: [bt.NegativeReal],
bn.MatrixMultiplicationNode: [bt.any_real_matrix, bt.any_real_matrix],
bn.MatrixAddNode: [bt.RealMatrix, bt.RealMatrix],
bn.MatrixExpNode: [bt.any_real_matrix],
bn.MatrixSumNode: [bt.any_real_matrix],
bn.PhiNode: [bt.Real],
bn.ToIntNode: [bt.upper_bound(bt.Real)],
bn.ToNegativeRealNode: [bt.Real],
Expand Down
67 changes: 66 additions & 1 deletion src/beanmachine/ppl/compiler/fix_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
returned."""


from typing import Tuple

import beanmachine.ppl.compiler.bmg_nodes as bn
import beanmachine.ppl.compiler.bmg_types as bt
from beanmachine.ppl.compiler.bm_graph_builder import BMGraphBuilder
Expand Down Expand Up @@ -63,6 +65,8 @@ def _type_meets_requirement(self, t: bt.BMGLatticeType, r: bt.Requirement) -> bo
return True
if r is bt.any_real_matrix:
return _is_real_matrix(t)
if r == bt.RealMatrix:
return isinstance(t, bt.RealMatrix)
if isinstance(r, bt.UpperBound):
return bt.supremum(t, r.bound) == r.bound
if isinstance(r, bt.AlwaysMatrix):
Expand Down Expand Up @@ -114,6 +118,12 @@ def _meet_constant_requirement(
# Emit the value as the equivalent real matrix:
return self.bmg.add_real_matrix(node.value)

if requirement == bt.RealMatrix:
if isinstance(it, bt.RealMatrix):
return self.bmg.add_constant_of_matrix_type(node.value, it)
else:
return self.bmg.add_real_matrix(node.value)

if self._type_meets_requirement(it, bt.upper_bound(requirement)):
if requirement is bt.any_requirement:
# The lattice type of the constant might be Zero or One; in that case,
Expand Down Expand Up @@ -283,6 +293,45 @@ def _can_force_to_neg_real(
or requirement == bt.upper_bound(bt.NegativeReal)
) and node_type == bt.Real

def _meet_real_matrix_requirement_type(
self, node: bn.OperatorNode, node_dim: Tuple[int, int]
) -> bn.BMGNode:
if node_dim[0] == 1 and node_dim[1] == 1:
result = self.bmg.add_to_real(node)
else:
result = self.bmg.add_to_real_matrix(node)
return result

def _meet_real_matrix_requirement(
self,
node: bn.OperatorNode,
dim_req: Tuple[int, int],
node_dim: Tuple[int, int],
consumer: bn.BMGNode,
edge: str,
) -> bn.BMGNode:
result = None
node_is_scalar = node_dim[0] == 1 and node_dim[1] == 1
requires_scalar = dim_req[0] == 1 and dim_req[1] == 1
if requires_scalar and node_is_scalar:
result = self.bmg.add_to_real(node)
elif node_dim[0] == dim_req[0] and node_dim[1] == dim_req[1]:
result = self.bmg.add_to_real_matrix(node)

if result is None:
self.errors.add_error(
Violation(
node,
self._typer[node],
bt.RealMatrix(1, 1),
consumer,
edge,
self.bmg.execution_context.node_locations(consumer),
)
)
return node
return result

def _meet_operator_requirement(
self,
node: bn.OperatorNode,
Expand All @@ -299,7 +348,23 @@ def _meet_operator_requirement(
# meets an upper bound requirement, then the conversion we want exists.

node_type = self._typer[node]
if requirement is bt.any_real_matrix:
if isinstance(node_type, bt.BMGMatrixType):
rows = node_type.rows
columns = node_type.columns
else:
rows = 1
columns = 1
if isinstance(requirement, bt.RealMatrix):
return self._meet_real_matrix_requirement(
node,
dim_req=(requirement.rows, requirement.columns),
node_dim=(rows, columns),
consumer=consumer,
edge=edge,
)
elif requirement == bt.RealMatrix:
return self._meet_real_matrix_requirement_type(node, (rows, columns))
elif requirement is bt.any_real_matrix:
result = self.bmg.add_to_real_matrix(node)
elif self._type_meets_requirement(node_type, bt.upper_bound(requirement)):
# If we got here then the node did NOT meet the requirement,
Expand Down
3 changes: 2 additions & 1 deletion src/beanmachine/ppl/compiler/fix_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ def get_error(node: bn.BMGNode) -> Optional[BMGError]:
return None
for i in node.inputs:
t = typer[i]
if t == bt.Untypable or t == bt.Tensor:
is_tensor = t == bt.Tensor or isinstance(t, bt.BMGMatrixType)
if t == bt.Untypable or is_tensor:
return None
return UntypableNode(node, bmg.execution_context.node_locations(node))

Expand Down
8 changes: 8 additions & 0 deletions src/beanmachine/ppl/compiler/graph_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _val(node: bn.ConstantNode) -> str:
bn.ExpProductFactorNode: "ExpProduct",
bn.FlatNode: "Flat",
bn.FloorDivNode: "//",
bn.ElementwiseMultiplyNode: "ElementwiseMult",
bn.GammaNode: "Gamma",
bn.GreaterThanEqualNode: ">=",
bn.GreaterThanNode: ">",
Expand All @@ -88,8 +89,11 @@ def _val(node: bn.ConstantNode) -> str:
bn.LogSumExpVectorNode: "LogSumExp",
bn.LogAddExpNode: "LogAddExp",
bn.LShiftNode: "<<",
bn.MatrixAddNode: "MatrixAdd",
bn.MatrixExpNode: "MatrixExp",
bn.MatrixMultiplicationNode: "@",
bn.MatrixScaleNode: "MatrixScale",
bn.MatrixSumNode: "MatrixSum",
bn.ModNode: "%",
bn.MultiplicationNode: "*",
bn.NaturalNode: _val,
Expand Down Expand Up @@ -274,6 +278,7 @@ def _prefix_numbered(prefix: List[str]) -> Callable:
bn.ConstantTensorNode: _none,
bn.DirichletNode: ["concentration"],
bn.DivisionNode: _left_right,
bn.ElementwiseMultiplyNode: _left_right,
bn.EqualNode: _left_right,
bn.ExpM1Node: _operand,
bn.ExpNode: _operand,
Expand All @@ -300,8 +305,11 @@ def _prefix_numbered(prefix: List[str]) -> Callable:
bn.LogSumExpVectorNode: _operand,
bn.LogAddExpNode: _left_right,
bn.SwitchNode: _numbered_or_left_right,
bn.MatrixAddNode: _left_right,
bn.MatrixExpNode: _operand,
bn.MatrixMultiplicationNode: _left_right,
bn.MatrixScaleNode: _numbered_or_left_right,
bn.MatrixSumNode: _operand,
bn.MultiplicationNode: _numbered_or_left_right,
bn.NaturalNode: _none,
bn.NegateNode: _operand,
Expand Down
Loading

0 comments on commit df1e11d

Please sign in to comment.