From 7212ddf6f762d5cf466a8a357d0a7c24e3f2b455 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Tue, 23 Apr 2024 18:40:16 +0800 Subject: [PATCH] add context manager to disable SM/DEC checks --- src/pyjuice/__init__.py | 2 +- src/pyjuice/graph/region_graph.py | 13 +++++++++-- src/pyjuice/nodes/__init__.py | 2 +- src/pyjuice/nodes/construction.py | 39 ++++++++++++++++++++++++++++--- tests/nodes/nodes_test.py | 12 +++++++++- 5 files changed, 60 insertions(+), 8 deletions(-) diff --git a/src/pyjuice/__init__.py b/src/pyjuice/__init__.py index 8ad855f8..bfc8060b 100644 --- a/src/pyjuice/__init__.py +++ b/src/pyjuice/__init__.py @@ -13,7 +13,7 @@ from pyjuice.model import compile, TensorCircuit # Construction methods -from pyjuice.nodes import multiply, summate, inputs, set_block_size +from pyjuice.nodes import multiply, summate, inputs, set_block_size, structural_properties # Distributions from pyjuice.nodes import distributions diff --git a/src/pyjuice/graph/region_graph.py b/src/pyjuice/graph/region_graph.py index f665435c..542b1c02 100644 --- a/src/pyjuice/graph/region_graph.py +++ b/src/pyjuice/graph/region_graph.py @@ -9,6 +9,11 @@ class RegionGraph(): + + # Property checks + ALLOW_NONSMOOTH = False + ALLOW_NONDECOMPOSABLE = False + def __init__(self, scope: BitSet, children: List[RegionGraph]) -> None: self.scope = scope self.children = children @@ -26,7 +31,8 @@ def __init__(self, children: List[Union[InnerRegionNode,InputRegionNode]]) -> No scope = BitSet() for n in children: - assert len(scope & n.scope) == 0, "Children of a PartitionNode have overlapping scopes." + if not self.ALLOW_NONDECOMPOSABLE: + assert len(scope & n.scope) == 0, "Children of a PartitionNode have overlapping scopes." scope |= n.scope super().__init__(scope, children) @@ -43,7 +49,10 @@ def __init__(self, children: List[Union[InputRegionNode,PartitionNode]]) -> None scope = deepcopy(children[0].scope) for n in children[1:]: - assert scope == n.scope, "Children of an InnerRegionNode must have the same scope." + if not self.ALLOW_NONSMOOTH: + assert scope == n.scope, "Children of an InnerRegionNode must have the same scope." + else: + scope |= n.scope super().__init__(scope, children) diff --git a/src/pyjuice/nodes/__init__.py b/src/pyjuice/nodes/__init__.py index 038d971f..d7fc79bb 100644 --- a/src/pyjuice/nodes/__init__.py +++ b/src/pyjuice/nodes/__init__.py @@ -2,5 +2,5 @@ from .input_nodes import InputNodes from .prod_nodes import ProdNodes from .sum_nodes import SumNodes -from .construction import multiply, summate, inputs, set_block_size +from .construction import multiply, summate, inputs, set_block_size, structural_properties from .methods.traversal import foreach, foldup_aggregate \ No newline at end of file diff --git a/src/pyjuice/nodes/construction.py b/src/pyjuice/nodes/construction.py index 58834550..079e93ed 100644 --- a/src/pyjuice/nodes/construction.py +++ b/src/pyjuice/nodes/construction.py @@ -12,6 +12,7 @@ from .prod_nodes import ProdNodes from .sum_nodes import SumNodes from .distributions import Distribution +from pyjuice.graph import RegionGraph Tensor = Union[np.ndarray,torch.Tensor] ProdNodesChs = Union[SumNodes,InputNodes] @@ -102,7 +103,8 @@ def multiply(nodes1: ProdNodesChs, *args, edge_ids: Optional[Tensor] = None, spa if edge_ids is None: assert nodes.num_node_blocks == num_node_blocks, f"Input nodes should have the same `num_node_blocks`, but got {nodes.num_node_blocks} and {num_node_blocks}." assert nodes.block_size == block_size, "Input nodes should have the same `num_node_blocks`." - assert len(nodes.scope & scope) == 0, "Children of a `ProdNodes` should have disjoint scopes." + if not RegionGraph.ALLOW_NONDECOMPOSABLE: + assert len(nodes.scope & scope) == 0, "Children of a `ProdNodes` should have disjoint scopes." chs.append(nodes) scope |= nodes.scope @@ -169,7 +171,8 @@ def summate(nodes1: SumNodesChs, *args, num_node_blocks: int = 0, num_nodes: int scope = deepcopy(nodes1.scope) for nodes in args: assert isinstance(nodes, ProdNodes) or isinstance(nodes, InputNodes), f"Children of sum nodes must be input or product nodes, but found input of type {type(nodes)}." - assert nodes.scope == scope, "Children of a `SumNodes` should have the same scope." + if not RegionGraph.ALLOW_NONSMOOTH: + assert nodes.scope == scope, "Children of a `SumNodes` should have the same scope." chs.append(nodes) return SumNodes(num_node_blocks, chs, edge_ids, block_size = block_size, **kwargs) @@ -202,4 +205,34 @@ def __enter__(self) -> None: CircuitNodes.DEFAULT_BLOCK_SIZE = self.block_size def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - CircuitNodes.DEFAULT_BLOCK_SIZE = self.original_block_size \ No newline at end of file + CircuitNodes.DEFAULT_BLOCK_SIZE = self.original_block_size + + +class structural_properties(_DecoratorContextManager): + """ + Context-manager that controls the assertions of circuit structural properties, including smoothness and decomposability. + + :param allow_nonsmooth: whether to allow non-smooth circuits + :type allow_nonsmooth: bool + + :param allow_nondecomposable: whether to allow non-decomposable circuits + :type allow_nondecomposable: bool + + Example:: + >>> with pyjuice.structural_properties(allow_nonsmooth = True): + ... nis = pyjuice.inputs(var = 0, num_node_blocks = 4, dist = Categorical(num_cats = 20)) + ... .... + """ + + def __init__(self, allow_nonsmooth = False, allow_nondecomposable = False): + + self.allow_nonsmooth = allow_nonsmooth + self.allow_nondecomposable = allow_nondecomposable + + def __enter__(self) -> None: + RegionGraph.ALLOW_NONSMOOTH = self.allow_nonsmooth + RegionGraph.ALLOW_NONDECOMPOSABLE = self.allow_nondecomposable + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + RegionGraph.ALLOW_NONSMOOTH = False + RegionGraph.ALLOW_NONDECOMPOSABLE = False diff --git a/tests/nodes/nodes_test.py b/tests/nodes/nodes_test.py index ba492494..53cacaaf 100644 --- a/tests/nodes/nodes_test.py +++ b/tests/nodes/nodes_test.py @@ -44,5 +44,15 @@ def test_nodes(): assert torch.all(torch.abs(n._params.sum(dim = 2).sum(dim = 0) - 1.0) < 1e-4) +def test_structural_properties(): + + with juice.structural_properties(allow_nonsmooth = True): + n0 = inputs(0, 4, dists.Categorical(num_cats = 5)) + n1 = inputs(1, 4, dists.Categorical(num_cats = 5)) + + ns = summate(n0, n1, num_node_blocks = 1, block_size = 1) + + if __name__ == "__main__": - test_nodes() \ No newline at end of file + test_nodes() + test_structural_properties()