Skip to content

Commit

Permalink
add context manager to disable SM/DEC checks
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Apr 23, 2024
1 parent 51a2463 commit 7212ddf
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/pyjuice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyjuice.model import compile, TensorCircuit

# Construction methods
from pyjuice.nodes import multiply, summate, inputs, set_block_size
from pyjuice.nodes import multiply, summate, inputs, set_block_size, structural_properties

# Distributions
from pyjuice.nodes import distributions
Expand Down
13 changes: 11 additions & 2 deletions src/pyjuice/graph/region_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@


class RegionGraph():

# Property checks
ALLOW_NONSMOOTH = False
ALLOW_NONDECOMPOSABLE = False

def __init__(self, scope: BitSet, children: List[RegionGraph]) -> None:
self.scope = scope
self.children = children
Expand All @@ -26,7 +31,8 @@ def __init__(self, children: List[Union[InnerRegionNode,InputRegionNode]]) -> No

scope = BitSet()
for n in children:
assert len(scope & n.scope) == 0, "Children of a PartitionNode have overlapping scopes."
if not self.ALLOW_NONDECOMPOSABLE:
assert len(scope & n.scope) == 0, "Children of a PartitionNode have overlapping scopes."
scope |= n.scope

super().__init__(scope, children)
Expand All @@ -43,7 +49,10 @@ def __init__(self, children: List[Union[InputRegionNode,PartitionNode]]) -> None

scope = deepcopy(children[0].scope)
for n in children[1:]:
assert scope == n.scope, "Children of an InnerRegionNode must have the same scope."
if not self.ALLOW_NONSMOOTH:
assert scope == n.scope, "Children of an InnerRegionNode must have the same scope."
else:
scope |= n.scope

super().__init__(scope, children)

Expand Down
2 changes: 1 addition & 1 deletion src/pyjuice/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from .input_nodes import InputNodes
from .prod_nodes import ProdNodes
from .sum_nodes import SumNodes
from .construction import multiply, summate, inputs, set_block_size
from .construction import multiply, summate, inputs, set_block_size, structural_properties
from .methods.traversal import foreach, foldup_aggregate
39 changes: 36 additions & 3 deletions src/pyjuice/nodes/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .prod_nodes import ProdNodes
from .sum_nodes import SumNodes
from .distributions import Distribution
from pyjuice.graph import RegionGraph

Tensor = Union[np.ndarray,torch.Tensor]
ProdNodesChs = Union[SumNodes,InputNodes]
Expand Down Expand Up @@ -102,7 +103,8 @@ def multiply(nodes1: ProdNodesChs, *args, edge_ids: Optional[Tensor] = None, spa
if edge_ids is None:
assert nodes.num_node_blocks == num_node_blocks, f"Input nodes should have the same `num_node_blocks`, but got {nodes.num_node_blocks} and {num_node_blocks}."
assert nodes.block_size == block_size, "Input nodes should have the same `num_node_blocks`."
assert len(nodes.scope & scope) == 0, "Children of a `ProdNodes` should have disjoint scopes."
if not RegionGraph.ALLOW_NONDECOMPOSABLE:
assert len(nodes.scope & scope) == 0, "Children of a `ProdNodes` should have disjoint scopes."
chs.append(nodes)
scope |= nodes.scope

Expand Down Expand Up @@ -169,7 +171,8 @@ def summate(nodes1: SumNodesChs, *args, num_node_blocks: int = 0, num_nodes: int
scope = deepcopy(nodes1.scope)
for nodes in args:
assert isinstance(nodes, ProdNodes) or isinstance(nodes, InputNodes), f"Children of sum nodes must be input or product nodes, but found input of type {type(nodes)}."
assert nodes.scope == scope, "Children of a `SumNodes` should have the same scope."
if not RegionGraph.ALLOW_NONSMOOTH:
assert nodes.scope == scope, "Children of a `SumNodes` should have the same scope."
chs.append(nodes)

return SumNodes(num_node_blocks, chs, edge_ids, block_size = block_size, **kwargs)
Expand Down Expand Up @@ -202,4 +205,34 @@ def __enter__(self) -> None:
CircuitNodes.DEFAULT_BLOCK_SIZE = self.block_size

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
CircuitNodes.DEFAULT_BLOCK_SIZE = self.original_block_size
CircuitNodes.DEFAULT_BLOCK_SIZE = self.original_block_size


class structural_properties(_DecoratorContextManager):
"""
Context-manager that controls the assertions of circuit structural properties, including smoothness and decomposability.
:param allow_nonsmooth: whether to allow non-smooth circuits
:type allow_nonsmooth: bool
:param allow_nondecomposable: whether to allow non-decomposable circuits
:type allow_nondecomposable: bool
Example::
>>> with pyjuice.structural_properties(allow_nonsmooth = True):
... nis = pyjuice.inputs(var = 0, num_node_blocks = 4, dist = Categorical(num_cats = 20))
... ....
"""

def __init__(self, allow_nonsmooth = False, allow_nondecomposable = False):

self.allow_nonsmooth = allow_nonsmooth
self.allow_nondecomposable = allow_nondecomposable

def __enter__(self) -> None:
RegionGraph.ALLOW_NONSMOOTH = self.allow_nonsmooth
RegionGraph.ALLOW_NONDECOMPOSABLE = self.allow_nondecomposable

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
RegionGraph.ALLOW_NONSMOOTH = False
RegionGraph.ALLOW_NONDECOMPOSABLE = False
12 changes: 11 additions & 1 deletion tests/nodes/nodes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,15 @@ def test_nodes():
assert torch.all(torch.abs(n._params.sum(dim = 2).sum(dim = 0) - 1.0) < 1e-4)


def test_structural_properties():

with juice.structural_properties(allow_nonsmooth = True):
n0 = inputs(0, 4, dists.Categorical(num_cats = 5))
n1 = inputs(1, 4, dists.Categorical(num_cats = 5))

ns = summate(n0, n1, num_node_blocks = 1, block_size = 1)


if __name__ == "__main__":
test_nodes()
test_nodes()
test_structural_properties()

0 comments on commit 7212ddf

Please sign in to comment.