Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Refactoring and minor bug fixing in requirements fixer #1760

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions src/beanmachine/ppl/compiler/bmg_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,6 @@
# what their inputs are.

_known_requirements: Dict[type, List[bt.Requirement]] = {
# TODO: This is wrong in several ways.
# First, RealMatrix does not meet the contract of a requirement;
# in particular, it cannot be printed out by the requirement diagnostic
# in gen_to_dot.
# Second, it is too strict; the requirement on matrix add is actually
# that the two operands be any double matrix (real, neg real,
# pos real or probability).
# Third, this requirement is too weak; we are missing the requirement
# that the operands have the same element type and shape.
bn.ElementwiseMultiplyNode: [bt.RealMatrix, bt.RealMatrix],
bn.Observation: [bt.any_requirement],
bn.Query: [bt.any_requirement],
# Distributions
Expand Down Expand Up @@ -75,6 +65,7 @@
# don't check them.
bn.LogisticNode: [bt.Real],
bn.Log1mexpNode: [bt.NegativeReal],
# TODO: Check the dimensions. Consider broadcasting if possible.
bn.MatrixMultiplicationNode: [bt.any_real_matrix, bt.any_real_matrix],
bn.MatrixExpNode: [bt.any_real_matrix],
bn.MatrixLogNode: [bt.any_pos_real_matrix],
Expand Down Expand Up @@ -115,6 +106,7 @@ def __init__(self, typer: LatticeTyper) -> None:
bn.ChoiceNode: self._requirements_choice,
bn.ColumnIndexNode: self._requirements_column_index,
bn.ComplementNode: self._same_as_output,
bn.ElementwiseMultiplyNode: self._requirements_elementwise_mult,
bn.ExpM1Node: self._same_as_output,
bn.ExpNode: self._requirements_exp_neg,
bn.IfThenElseNode: self._requirements_if,
Expand All @@ -124,7 +116,7 @@ def __init__(self, typer: LatticeTyper) -> None:
bn.LogSumExpVectorNode: self._requirements_logsumexp_vector,
# TODO: bn.MatrixMultiplyNode: self._requirements_matrix_multiply,
# see comment above
bn.MatrixComplementNode: self._requrirements_matrix_complement,
bn.MatrixComplementNode: self._requirements_matrix_complement,
bn.MatrixAddNode: self._requirements_matrix_add,
bn.MatrixScaleNode: self._requirements_matrix_scale,
bn.MultiplicationNode: self._requirements_multiplication,
Expand Down Expand Up @@ -428,7 +420,7 @@ def _requirements_multiplication(
assert it in {bt.Probability, bt.PositiveReal, bt.Real}
return [it] * len(node.inputs) # pyre-ignore

def _requrirements_matrix_complement(
def _requirements_matrix_complement(
self, node: bn.MatrixComplementNode
) -> List[bt.Requirement]:
it = self.typer[node]
Expand All @@ -446,6 +438,14 @@ def _requrirements_matrix_complement(
req = [bt.SimplexMatrix]
return req

def _requirements_elementwise_mult(
self, node: bn.ElementwiseMultiplyNode
) -> List[bt.Requirement]:
# Elementwise multiply requires that both operands be the same as the output type.
it = self.typer[node]
assert isinstance(it, bt.BMGMatrixType)
return [it, it]

def _requirements_matrix_add(self, node: bn.MatrixAddNode) -> List[bt.Requirement]:
# Matrix add requires that both operands be the same as the output type.
it = self.typer[node]
Expand Down
254 changes: 151 additions & 103 deletions src/beanmachine/ppl/compiler/fix_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
returned."""


from typing import Tuple
from typing import Optional

import beanmachine.ppl.compiler.bmg_nodes as bn
import beanmachine.ppl.compiler.bmg_types as bt
Expand Down Expand Up @@ -160,9 +160,7 @@ def _meet_constant_requirement(
result = self.bmg.add_constant_of_matrix_type(node.value, required_type)
else:
result = self.bmg.add_constant_of_type(node.value, required_type)
assert self._node_meets_requirement(
result, requirement
), f"{str(result)} {str(requirement)} {str(required_type)} {str(self._typer[result])} {str(self._type_meets_requirement(self._typer[result], requirement))}"
assert self._node_meets_requirement(result, requirement)
return result

# We cannot convert this node to any type that meets the requirement.
Expand Down Expand Up @@ -320,127 +318,177 @@ def _can_force_to_neg_real(
or requirement == bt.upper_bound(bt.NegativeReal)
) and node_type == bt.Real

def _meet_real_matrix_requirement_type(
self, node: bn.OperatorNode, node_dim: Tuple[int, int]
) -> bn.BMGNode:
if node_dim[0] == 1 and node_dim[1] == 1:
result = self.bmg.add_to_real(node)
else:
result = self.bmg.add_to_real_matrix(node)
def _try_to_meet_any_real_matrix_requirement(
self,
node: bn.OperatorNode,
requirement: bt.Requirement,
) -> Optional[bn.BMGNode]:

assert not self._node_meets_requirement(node, requirement)

# Is the requirement that we have a real-valued matrix, but we haven't got
# a real-valued matrix? Every value can be converted to a real-valued matrix,
# so just insert the conversion node.

if requirement is not bt.any_real_matrix:
return None

result = self.bmg.add_to_real_matrix(node)
assert self._node_meets_requirement(result, requirement)
return result

def _meet_real_matrix_requirement(
def _try_to_meet_any_pos_real_matrix_requirement(
self,
node: bn.OperatorNode,
dim_req: Tuple[int, int],
node_dim: Tuple[int, int],
requirement: bt.Requirement,
) -> Optional[bn.BMGNode]:

assert not self._node_meets_requirement(node, requirement)

# Is the requirement that we have a pos-real-valued matrix? Anything that
# is not known to be negative can be a positive real matrix.

if requirement is not bt.any_pos_real_matrix:
return None

node_type = self._typer[node]
if isinstance(node_type, bt.NegativeRealMatrix):
return None

result = self.bmg.add_to_positive_real_matrix(node)
assert self._node_meets_requirement(result, requirement)
return result

def _try_to_meet_upper_bound_requirement(
self,
node: bn.OperatorNode,
requirement: bt.Requirement,
consumer: bn.BMGNode,
edge: str,
) -> bn.BMGNode:
result = None
req_rows, req_cols = dim_req
node_rows, node_cols = node_dim
node_is_scalar = node_rows == 1 and node_cols == 1
requires_scalar = req_rows == 1 and req_cols == 1
if requires_scalar and node_is_scalar:
result = self.bmg.add_to_real(node)
elif node_rows == req_rows and node_cols == req_cols:
result = self.bmg.add_to_real_matrix(node)

if result is None:
self.errors.add_error(
Violation(
node,
self._typer[node],
bt.RealMatrix(req_rows, req_cols),
consumer,
edge,
self.bmg.execution_context.node_locations(consumer),
)
) -> Optional[bn.BMGNode]:

assert not self._node_meets_requirement(node, requirement)

node_type = self._typer[node]
if not self._type_meets_requirement(node_type, bt.upper_bound(requirement)):
return None

# If we got here then the node did NOT meet the requirement,
# but its type DID meet an upper bound requirement, which
# implies that the requirement was not an upper bound requirement.
assert not isinstance(requirement, bt.UpperBound)

# We definitely can meet the requirement by inserting some sort
# of conversion logic. We have different helper methods for
# the atomic type and matrix type cases.
if bt.must_be_matrix(requirement):
result = self._convert_operator_to_matrix_type(
node, requirement, consumer, edge
)
return node
else:
assert isinstance(requirement, bt.BMGLatticeType)
result = self._convert_operator_to_atomic_type(
node, requirement, consumer, edge
)
assert self._node_meets_requirement(result, requirement)
return result

def _try_to_force_to_prob(self, node, requirement) -> Optional[bn.BMGNode]:
# We cannot make the node meet the requirement "implicitly". We can
# "explicitly" meet a requirement of probability if we have a
# real or pos real.

node_type = self._typer[node]
if not self._can_force_to_prob(node_type, requirement):
return None
assert node_type == bt.Real or node_type == bt.PositiveReal
assert self._node_meets_requirement(node, node_type)
return self.bmg.add_to_probability(node)

def _try_to_force_to_neg_real(self, node, requirement) -> Optional[bn.BMGNode]:
# We cannot make the node meet the requirement "implicitly". We can
# "explicitly" meet a requirement of neg real if we have a value we do
# not know is positive.
node_type = self._typer[node]
if not self._can_force_to_neg_real(node_type, requirement):
return None

return self.bmg.add_to_negative_real(node)

def _meet_operator_requirement(
self,
node: bn.OperatorNode,
requirement: bt.Requirement,
consumer: bn.BMGNode,
edge: str,
) -> bn.BMGNode:
# If the operator node already meets the requirement, we're done.
# We should not have called this function if the input node already meets
# the requirement on the edge.

assert not self._node_meets_requirement(node, requirement)

# It does not meet the requirement. Can we convert this thing to a node
# whose type does meet the requirement? The lattice type is the
# smallest type that this node is convertible to, so if the lattice type
# meets an upper bound requirement, then the conversion we want exists.
# ----
#
# TODO: Is the problem that we have a scalar but we need a matrix full
# of that value? Generate a matrix fill operation.
#
# TODO: Is the problem that we have a row or column matrix but we need
# a rectangular matrix? Generate a broadcast operation.
#
# TODO: Note that in either of these cases, we might *also* need to
# generate a type conversion, so we might not meet the requirement on
# after introducing the fill / broadcast node.
#
# ----

# Is the requirement that we have a real-valued matrix, but we haven't got
# a real-valued matrix? Every value can be converted to a real-valued matrix,
# so that's the easiest case. Knock it out first.

result = self._try_to_meet_any_real_matrix_requirement(node, requirement)
if result is not None:
return result

# Is the requirement that we have any positive real-valued matrix? Every value
# except negative real scalars and matrices can be converted to a positive real
# matrix.

result = self._try_to_meet_any_pos_real_matrix_requirement(node, requirement)
if result is not None:
return result

# If we weaken the requirement to an upper bound requirement, do we meet it? If so,
# then there is a conversion node we can add.

result = self._try_to_meet_upper_bound_requirement(
node, requirement, consumer, edge
)
if result is not None:
return result

result = self._try_to_force_to_prob(node, requirement)
if result is not None:
return result

node_type = self._typer[node]
if isinstance(node_type, bt.BMGMatrixType):
rows = node_type.rows
columns = node_type.columns
else:
rows = 1
columns = 1
if isinstance(requirement, bt.RealMatrix):
return self._meet_real_matrix_requirement(

result = self._try_to_force_to_neg_real(node, requirement)
if result is not None:
return result

# Those are the only techniques we have to make an operator meet a requirement.
# We have no way to make the conversion we need, so add an error.
self.errors.add_error(
Violation(
node,
dim_req=(requirement.rows, requirement.columns),
node_dim=(rows, columns),
consumer=consumer,
edge=edge,
)
elif requirement == bt.RealMatrix:
return self._meet_real_matrix_requirement_type(node, (rows, columns))
elif requirement is bt.any_real_matrix:
result = self.bmg.add_to_real_matrix(node)
elif requirement is bt.any_pos_real_matrix:
result = self.bmg.add_to_positive_real_matrix(node)
elif self._type_meets_requirement(node_type, bt.upper_bound(requirement)):
# If we got here then the node did NOT meet the requirement,
# but its type DID meet an upper bound requirement, which
# implies that the requirement was not an upper bound requirement.
assert not isinstance(requirement, bt.UpperBound)

# We definitely can meet the requirement by inserting some sort
# of conversion logic. We have different helper methods for
# the atomic type and matrix type cases.
if bt.must_be_matrix(requirement):
result = self._convert_operator_to_matrix_type(
node, requirement, consumer, edge
)
else:
assert isinstance(requirement, bt.BMGLatticeType)
result = self._convert_operator_to_atomic_type(
node, requirement, consumer, edge
)
elif self._can_force_to_prob(node_type, requirement):
# We cannot make the node meet the requirement "implicitly". We can
# "explicitly" meet a requirement of probability if we have a
# real or pos real.
assert node_type == bt.Real or node_type == bt.PositiveReal
assert self._node_meets_requirement(node, node_type)
result = self.bmg.add_to_probability(node)
elif self._can_force_to_neg_real(node_type, requirement):
# Similarly if we have a real but need a negative real
result = self.bmg.add_to_negative_real(node)
else:
# We have no way to make the conversion we need, so add an error.
self.errors.add_error(
Violation(
node,
node_type,
requirement,
consumer,
edge,
self.bmg.execution_context.node_locations(consumer),
)
node_type,
requirement,
consumer,
edge,
self.bmg.execution_context.node_locations(consumer),
)
return node

assert self._node_meets_requirement(result, requirement)
return result
)
return node

def _check_requirement_validity(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/beanmachine/ppl/compiler/lattice_typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ def _lattice_type_for_element_type(
def _type_binary_elementwise_op(
self, node: bn.BinaryOperatorNode
) -> bt.BMGLatticeType:
# Elementwise multiplication and addition require that the operands be
# of the same type and size, and that's the resulting type. Rather than
# enforcing that here, find the supremum of the element types and a size
# where both operands can be broadcast to that size. We'll then add the
# appropriate broadcast nodes in the requirements fixer.
left_type = self[node.left]
right_type = self[node.right]
assert isinstance(left_type, bt.BMGMatrixType)
Expand Down
1 change: 1 addition & 0 deletions tests/ppl/compiler/broadcast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def broadcast_add():


class BroadcastTest(unittest.TestCase):
# TODO: Test broadcast multiplication as well.
def test_broadcast_add(self) -> None:
self.maxDiff = None
observations = {}
Expand Down
2 changes: 1 addition & 1 deletion tests/ppl/compiler/fix_vectorized_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ def test_fix_vectorized_models_8(self) -> None:
N07[label=2];
N08[label=1];
N09[label=ToMatrix];
N10[label=ToRealMatrix];
N10[label=ToPosRealMatrix];
N11[label="[5.0,6.0]"];
N12[label=ElementwiseMult];
N13[label=Query];
Expand Down