From ca011b1554e56a8d958d554aa00701f415a93c3b Mon Sep 17 00:00:00 2001
From: Anji Liu
Date: Thu, 20 Jun 2024 15:39:57 +0800
Subject: [PATCH] helper functions to retrieve node mars/flows

---
 src/pyjuice/layer/input_layer.py   |  3 ++
 src/pyjuice/layer/layer.py         |  9 ++++
 src/pyjuice/layer/prod_layer.py    |  3 ++
 src/pyjuice/layer/sum_layer.py     |  3 ++
 src/pyjuice/model/tensorcircuit.py | 69 ++++++++++++++++++++++++++++--
 src/pyjuice/nodes/nodes.py         |  7 +++
 tests/model/functionality_test.py  | 63 +++++++++++++++++++++++++++
 7 files changed, 154 insertions(+), 3 deletions(-)
 create mode 100644 tests/model/functionality_test.py

diff --git a/src/pyjuice/layer/input_layer.py b/src/pyjuice/layer/input_layer.py
index 9db29bf3..9b59ca02 100644
--- a/src/pyjuice/layer/input_layer.py
+++ b/src/pyjuice/layer/input_layer.py
@@ -620,6 +620,9 @@ def _prepare_scope2nids(self):
             scope: torch.cat(ids, dim = 0).to(self.params.device) for scope, ids in scope2localgids.items()
         }
 
+    def is_input(self):
+        return True
+
     def _reorder_nodes(self, nodes):
         node_set = set(nodes)
         reordered_untied_nodes = []
diff --git a/src/pyjuice/layer/layer.py b/src/pyjuice/layer/layer.py
index f218cb12..dff94022 100644
--- a/src/pyjuice/layer/layer.py
+++ b/src/pyjuice/layer/layer.py
@@ -68,6 +68,15 @@ def disable_partial_evaluation(self, forward: bool = True, backward: bool = True
     def provided(self, var_name):
         return hasattr(self, var_name) and getattr(self, var_name) is not None
 
+    def is_sum(self):
+        return False
+
+    def is_prod(self):
+        return False
+
+    def is_input(self):
+        return False
+
     def _get_propagation_alg_kwargs(self, propagation_alg: str, **kwargs):
         if propagation_alg == "LL":
             return {"alpha": 0.0}
diff --git a/src/pyjuice/layer/prod_layer.py b/src/pyjuice/layer/prod_layer.py
index c47fa24a..7ec2b427 100644
--- a/src/pyjuice/layer/prod_layer.py
+++ b/src/pyjuice/layer/prod_layer.py
@@ -253,6 +253,9 @@ def enable_partial_evaluation(self, fw_scopes: Optional[Sequence[BitSet]] = None
             torch.cat(ids, dim = 0) if len(ids) > 0 else torch.zeros([0], dtype = torch.long) for ids in bk_fw_partition_local_ids
         ]
 
+    def is_prod(self):
+        return True
+
     @staticmethod
     # @triton.jit
     @FastJITFunction
diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py
index 2d6e564d..89ea2a6a 100644
--- a/src/pyjuice/layer/sum_layer.py
+++ b/src/pyjuice/layer/sum_layer.py
@@ -367,6 +367,9 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
 
         return None
 
+    def is_sum(self):
+        return True
+
     def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor,
                  params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor,
                  local_ids: Optional[torch.Tensor] = None,
diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py
index 302ccc6d..4d025026 100644
--- a/src/pyjuice/model/tensorcircuit.py
+++ b/src/pyjuice/model/tensorcircuit.py
@@ -106,6 +106,9 @@ def __init__(self, root_ns: CircuitNodes, layer_sparsity_tol: float = 0.5,
         self.default_propagation_alg = "LL" # Could be "LL", "MPE", or "GeneralLL"
         self.propagation_alg_kwargs = dict()
 
+        # Running parameters
+        self._run_params = dict()
+
     def to(self, device):
         super(TensorCircuit, self).to(device)
 
@@ -291,6 +294,11 @@ def backward(self, inputs: Optional[torch.Tensor] = None,
         :type input_layer_fn: Optional[Union[str,Callable]]
         """
+        self._run_params["allow_modify_flows"] = allow_modify_flows
+        self._run_params["propagation_alg"] = propagation_alg
+        self._run_params["logspace_flows"] = logspace_flows
+        self._run_params["negate_pflows"] = negate_pflows
+
         assert self.node_mars is not None and self.element_mars is not None, "Should run forward path first."
 
         if input_layer_fn is None:
             if inputs.size(0) != self.num_vars:
@@ -520,10 +528,65 @@ def print_statistics(self):
         print(f"> Number of sum parameters: {self.num_sum_params}")
 
     def get_node_mars(self, ns: CircuitNodes):
-        pass
+        assert self.root_ns.contains(ns)
+        assert hasattr(self, "node_mars") and self.node_mars is not None
+        assert hasattr(self, "element_mars") and self.element_mars is not None
+
+        nsid, neid = ns._output_ind_range
+
+        if ns.is_sum() or ns.is_input():
+            return self.node_mars[nsid:neid,:].detach()
+        else:
+            assert ns.is_prod()
+
+            target_layer = None
+            for layer_group in self.inner_layer_groups:
+                for layer in layer_group:
+                    if layer.is_prod() and ns in layer.nodes:
+                        target_layer = layer
+                        break
+
+                if target_layer is not None:
+                    break
+
+            # Rerun the corresponding product layer to get the node values
+            target_layer(self.node_mars, self.element_mars)
+
+            return self.element_mars[nsid:neid,:].detach()
+
+    def get_node_flows(self, ns: CircuitNodes, **kwargs):
+        assert self.root_ns.contains(ns)
+        assert hasattr(self, "node_flows") and self.node_flows is not None
+        assert hasattr(self, "element_flows") and self.element_flows is not None
+
+        nsid, neid = ns._output_ind_range
+
+        if ns.is_sum() or ns.is_input():
+            return self.node_flows[nsid:neid,:].detach()
+        else:
+            assert ns.is_prod()
+
+            layer_id = None
+            for idx, layer_group in enumerate(self.inner_layer_groups):
+                for layer in layer_group:
+                    if layer.is_prod() and ns in layer.nodes:
+                        layer_id = idx
+                        break
+
+                if layer_id is not None:
+                    break
+
+            # Rerun the product layer (forward) and the sum layer above it (backward) to recompute the flows
+            self.inner_layer_groups[layer_id].forward(self.node_mars, self.element_mars, _for_backward = True)
+            self.inner_layer_groups[layer_id+1].backward(
+                self.node_flows, self.element_flows, self.node_mars, self.element_mars, self.params,
+                param_flows = None, allow_modify_flows = self._run_params["allow_modify_flows"],
+                propagation_alg = self._run_params["propagation_alg"],
+                logspace_flows = self._run_params["logspace_flows"],
+                negate_pflows = self._run_params["negate_pflows"], **kwargs
+            )
 
-    def get_node_flows(self, ns: CircuitNodes):
-        pass
+        return self.element_flows[nsid:neid,:].detach()
 
     def enable_partial_evaluation(self, scopes: Union[Sequence[BitSet],Sequence[int]], forward: bool = False,
                                   backward: bool = False, overwrite: bool = False):
diff --git a/src/pyjuice/nodes/nodes.py b/src/pyjuice/nodes/nodes.py
index b8741f79..8070a7fa 100644
--- a/src/pyjuice/nodes/nodes.py
+++ b/src/pyjuice/nodes/nodes.py
@@ -206,6 +206,13 @@ def has_params(self):
         source_ns = self.get_source_ns()
         return hasattr(source_ns, "_params") and source_ns._params is not None
 
+    def contains(self, ns: CircuitNodes):
+        for n in self:
+            if n == ns:
+                return True
+
+        return False
+
     def _clear_tensor_circuit_hooks(self, recursive: bool = True):
 
         def clear_hooks(ns):
diff --git a/tests/model/functionality_test.py b/tests/model/functionality_test.py
new file mode 100644
index 00000000..b403b02c
--- /dev/null
+++ b/tests/model/functionality_test.py
@@ -0,0 +1,63 @@
+import pyjuice as juice
+import torch
+import numpy as np
+
+import pyjuice.nodes.distributions as dists
+from pyjuice.utils import BitSet
+from pyjuice.nodes import multiply, summate, inputs
+from pyjuice.model import TensorCircuit
+
+import pytest
+
+
+def test_tensorcircuit_fns():
+
+    ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2))
+    ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2))
+    ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2))
+    ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2))
+
+    m1 = multiply(ni0, ni1, edge_ids = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype = torch.long))
+    n1 = summate(m1, edge_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 1, 2, 3, 0, 1, 2, 3]], dtype = torch.long))
+
+    m2 = multiply(ni2, ni3, edge_ids = torch.tensor([[0, 0], [1, 1]], dtype = torch.long))
+    n2 = summate(m2, edge_ids = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]], dtype = torch.long))
+
+    m = multiply(n1, n2, edge_ids = torch.tensor([[0, 0], [1, 1]], dtype = torch.long))
+    n = summate(m, edge_ids = torch.tensor([[0, 0], [0, 1]], dtype = torch.long))
+
+    pc = TensorCircuit(n)
+
+    device = torch.device("cuda:0")
+    pc.to(device)
+
+    data = torch.randint(0, 2, [16, 4]).to(device)
+
+    lls = pc(data)
+    pc.backward(data.permute(1, 0), allow_modify_flows = False)
+
+    nsid, neid = n2._output_ind_range
+    n2_mars = pc.get_node_mars(n2)
+    assert n2_mars.size(0) == neid - nsid
+    assert n2_mars.size(1) == 16
+    assert (n2_mars == pc.node_mars[nsid:neid,:]).all()
+
+    n2_flows = pc.get_node_flows(n2)
+    assert n2_flows.size(0) == neid - nsid
+    assert n2_flows.size(1) == 16
+    assert (n2_flows == pc.node_flows[nsid:neid,:]).all()
+
+    nsid, neid = m2._output_ind_range
+    m2_mars = pc.get_node_mars(m2)
+    assert m2_mars.size(0) == neid - nsid
+    assert m2_mars.size(1) == 16
+    assert (m2_mars == pc.element_mars[nsid:neid,:]).all()
+
+    m2_flows = pc.get_node_flows(m2)
+    assert m2_flows.size(0) == neid - nsid
+    assert m2_flows.size(1) == 16
+    assert (m2_flows == pc.element_flows[nsid:neid,:]).all()
+
+
+if __name__ == "__main__":
+    test_tensorcircuit_fns()
\ No newline at end of file
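
Usage sketch: a minimal, hedged example of how the helpers added by this patch could be called; it assumes `pc`, `data`, and the node vectors `n2` / `m2` are built exactly as in tests/model/functionality_test.py above and is not itself part of the diff.

    # Forward fills pc.node_mars / pc.element_mars; backward fills pc.node_flows / pc.element_flows
    # and records the run parameters reused by get_node_flows.
    lls = pc(data)
    pc.backward(data.permute(1, 0), allow_modify_flows = False)

    # Sum and input node vectors are sliced directly from node_mars / node_flows.
    n2_mars = pc.get_node_mars(n2)      # shape [num nodes in n2, batch size]
    n2_flows = pc.get_node_flows(n2)

    # Product node vectors trigger a rerun of the corresponding product layer (and, for flows,
    # the sum layer above it) before element_mars / element_flows are sliced.
    m2_mars = pc.get_node_mars(m2)
    m2_flows = pc.get_node_flows(m2)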