Skip to content

Commit

Permalink
Add helper functions to retrieve node marginals ("mars") and flows
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Jun 20, 2024
1 parent 8f4469c commit ca011b1
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/pyjuice/layer/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,9 @@ def _prepare_scope2nids(self):
scope: torch.cat(ids, dim = 0).to(self.params.device) for scope, ids in scope2localgids.items()
}

def is_input(self):
    """Identify this layer as an input layer (overrides the base-class default)."""
    return True

def _reorder_nodes(self, nodes):
node_set = set(nodes)
reordered_untied_nodes = []
Expand Down
9 changes: 9 additions & 0 deletions src/pyjuice/layer/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ def disable_partial_evaluation(self, forward: bool = True, backward: bool = True
def provided(self, var_name):
    """Return True iff attribute `var_name` exists on `self` and is not None."""
    return getattr(self, var_name, None) is not None

def is_sum(self):
    """Default layer-kind predicate: not a sum layer (subclasses override)."""
    return False

def is_prod(self):
    """Default layer-kind predicate: not a product layer (subclasses override)."""
    return False

def is_input(self):
    """Default layer-kind predicate: not an input layer (subclasses override)."""
    return False

def _get_propagation_alg_kwargs(self, propagation_alg: str, **kwargs):
if propagation_alg == "LL":
return {"alpha": 0.0}
Expand Down
3 changes: 3 additions & 0 deletions src/pyjuice/layer/prod_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ def enable_partial_evaluation(self, fw_scopes: Optional[Sequence[BitSet]] = None
torch.cat(ids, dim = 0) if len(ids) > 0 else torch.zeros([0], dtype = torch.long) for ids in bk_fw_partition_local_ids
]

def is_prod(self):
    """Identify this layer as a product layer (overrides the base-class default)."""
    return True

@staticmethod
# @triton.jit
@FastJITFunction
Expand Down
3 changes: 3 additions & 0 deletions src/pyjuice/layer/sum_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,

return None

def is_sum(self):
    """Identify this layer as a sum layer (overrides the base-class default)."""
    return True

def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor,
params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor,
pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None,
Expand Down
69 changes: 66 additions & 3 deletions src/pyjuice/model/tensorcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def __init__(self, root_ns: CircuitNodes, layer_sparsity_tol: float = 0.5,
self.default_propagation_alg = "LL" # Could be "LL", "MPE", or "GeneralLL"
self.propagation_alg_kwargs = dict()

# Running parameters
self._run_params = dict()

def to(self, device):
super(TensorCircuit, self).to(device)

Expand Down Expand Up @@ -291,6 +294,11 @@ def backward(self, inputs: Optional[torch.Tensor] = None,
:type input_layer_fn: Optional[Union[str,Callable]]
"""

self._run_params["allow_modify_flows"] = allow_modify_flows
self._run_params["propagation_alg"] = propagation_alg
self._run_params["logspace_flows"] = logspace_flows
self._run_params["negate_pflows"] = negate_pflows

assert self.node_mars is not None and self.element_mars is not None, "Should run forward path first."
if input_layer_fn is None:
if inputs.size(0) != self.num_vars:
Expand Down Expand Up @@ -520,10 +528,65 @@ def print_statistics(self):
print(f"> Number of sum parameters: {self.num_sum_params}")

def get_node_mars(self, ns: "CircuitNodes"):
    """Return the values ("mars") computed for the nodes in `ns` by the last forward pass.

    Args:
        ns: a node vector contained in this circuit.

    Returns:
        A detached tensor slice of shape `[neid - nsid, batch_size]`.

    Raises:
        ValueError: if `ns` is a product node vector but no product layer
            containing it can be located.
    """
    assert self.root_ns.contains(ns), "`ns` is not part of this circuit."
    assert hasattr(self, "node_mars") and self.node_mars is not None, "Should run forward path first."
    assert hasattr(self, "element_mars") and self.element_mars is not None, "Should run forward path first."

    nsid, neid = ns._output_ind_range

    if ns.is_sum() or ns.is_input():
        # Sum/input node values are stored persistently in `node_mars`.
        return self.node_mars[nsid:neid,:].detach()

    assert ns.is_prod()

    # `element_mars` is a scratch buffer shared by all product layers, so we
    # locate the product layer that owns `ns` and rerun it to repopulate the
    # buffer before slicing.
    target_layer = None
    for layer_group in self.inner_layer_groups:
        for layer in layer_group:
            if layer.is_prod() and ns in layer.nodes:
                target_layer = layer
                break
        if target_layer is not None:
            break

    if target_layer is None:
        # Previously the loop's stale `layer` variable was (re)run here,
        # silently producing wrong values when the search failed.
        raise ValueError("Could not locate the product layer containing `ns`.")

    # Rerun the corresponding product layer to refresh `element_mars`.
    target_layer(self.node_mars, self.element_mars)

    return self.element_mars[nsid:neid,:].detach()

def get_node_flows(self, ns: "CircuitNodes", **kwargs):
    """Return the flows computed for the nodes in `ns` by the last backward pass.

    Args:
        ns: a node vector contained in this circuit.
        kwargs: forwarded to the sum layer's `backward` when flows for a
            product node vector must be recomputed.

    Returns:
        A detached tensor slice of shape `[neid - nsid, batch_size]`.

    Raises:
        ValueError: if `ns` is a product node vector but no product layer
            containing it can be located.
        RuntimeError: if product-node flows are requested before `backward`
            has been run (the replay parameters are recorded there).
    """
    assert self.root_ns.contains(ns), "`ns` is not part of this circuit."
    assert hasattr(self, "node_flows") and self.node_flows is not None, "Should run backward path first."
    assert hasattr(self, "element_flows") and self.element_flows is not None, "Should run backward path first."

    nsid, neid = ns._output_ind_range

    if ns.is_sum() or ns.is_input():
        # Sum/input node flows are stored persistently in `node_flows`.
        return self.node_flows[nsid:neid,:].detach()

    assert ns.is_prod()

    # `element_flows` is a scratch buffer: rerun the owning product layer's
    # forward (to refresh `element_mars`), then the following sum layer's
    # backward (to repopulate `element_flows`), before slicing.
    layer_id = None
    for idx, layer_group in enumerate(self.inner_layer_groups):
        for layer in layer_group:
            if layer.is_prod() and ns in layer.nodes:
                layer_id = idx
                break
        if layer_id is not None:
            break

    if layer_id is None:
        # Previously a failed search fell through to `inner_layer_groups[None]`
        # and raised an opaque TypeError.
        raise ValueError("Could not locate the product layer containing `ns`.")

    if not getattr(self, "_run_params", None):
        # The replay arguments below are recorded by `backward`; without them
        # the dict lookups would raise a bare KeyError.
        raise RuntimeError("Should run the backward pass before querying product-node flows.")

    self.inner_layer_groups[layer_id].forward(self.node_mars, self.element_mars, _for_backward = True)
    self.inner_layer_groups[layer_id + 1].backward(
        self.node_flows, self.element_flows, self.node_mars, self.element_mars, self.params,
        param_flows = None, allow_modify_flows = self._run_params["allow_modify_flows"],
        propagation_alg = self._run_params["propagation_alg"],
        logspace_flows = self._run_params["logspace_flows"],
        negate_pflows = self._run_params["negate_pflows"], **kwargs
    )

    return self.element_flows[nsid:neid,:].detach()

def enable_partial_evaluation(self, scopes: Union[Sequence[BitSet],Sequence[int]],
forward: bool = False, backward: bool = False, overwrite: bool = False):
Expand Down
7 changes: 7 additions & 0 deletions src/pyjuice/nodes/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,13 @@ def has_params(self):
source_ns = self.get_source_ns()
return hasattr(source_ns, "_params") and source_ns._params is not None

def contains(self, ns: "CircuitNodes"):
    """Return True iff `ns` appears among the node vectors reachable from `self`.

    The annotation is a string forward reference (PEP 484): the method is
    defined inside the class it names, so the name may not be bound yet at
    method-definition time.

    Args:
        ns: the node vector to search for.

    Returns:
        bool: True if `ns` is found by iterating over `self`, else False.
    """
    # Short-circuits on the first match, like the original explicit loop.
    return any(n == ns for n in self)

def _clear_tensor_circuit_hooks(self, recursive: bool = True):

def clear_hooks(ns):
Expand Down
63 changes: 63 additions & 0 deletions tests/model/functionality_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pyjuice as juice
import torch
import numpy as np

import pyjuice.nodes.distributions as dists
from pyjuice.utils import BitSet
from pyjuice.nodes import multiply, summate, inputs
from pyjuice.model import TensorCircuit

import pytest


def test_tensorcircuit_fns():
    """Check `get_node_mars`/`get_node_flows` against the raw circuit buffers.

    Builds a small 4-variable PC, runs forward/backward, then verifies that
    the helper getters return exactly the corresponding slices of
    `node_mars`/`node_flows` (sum nodes) and `element_mars`/`element_flows`
    (product nodes).
    """
    # Guard: the circuit is moved to cuda:0 below; skip instead of crashing
    # on CPU-only machines.
    if not torch.cuda.is_available():
        pytest.skip("CUDA device required for this test.")

    ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2))
    ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2))
    ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2))
    ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2))

    m1 = multiply(ni0, ni1, edge_ids = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype = torch.long))
    n1 = summate(m1, edge_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 1, 2, 3, 0, 1, 2, 3]], dtype = torch.long))

    m2 = multiply(ni2, ni3, edge_ids = torch.tensor([[0, 0], [1, 1]], dtype = torch.long))
    n2 = summate(m2, edge_ids = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]], dtype = torch.long))

    m = multiply(n1, n2, edge_ids = torch.tensor([[0, 0], [1, 1]], dtype = torch.long))
    n = summate(m, edge_ids = torch.tensor([[0, 0], [0, 1]], dtype = torch.long))

    pc = TensorCircuit(n)

    device = torch.device("cuda:0")
    pc.to(device)

    data = torch.randint(0, 2, [16, 4]).to(device)

    lls = pc(data)
    # backward expects [num_vars, batch_size]; it also records the run
    # parameters used by `get_node_flows` for product nodes.
    pc.backward(data.permute(1, 0), allow_modify_flows = False)

    # Sum node vector: getters must match slices of the persistent buffers.
    nsid, neid = n2._output_ind_range
    n2_mars = pc.get_node_mars(n2)
    assert n2_mars.size(0) == neid - nsid
    assert n2_mars.size(1) == 16
    assert (n2_mars == pc.node_mars[nsid:neid,:]).all()

    n2_flows = pc.get_node_flows(n2)
    assert n2_flows.size(0) == neid - nsid
    assert n2_flows.size(1) == 16
    assert (n2_flows == pc.node_flows[nsid:neid,:]).all()

    # Product node vector: getters must match slices of the scratch buffers
    # after the relevant layers are replayed.
    nsid, neid = m2._output_ind_range
    m2_mars = pc.get_node_mars(m2)
    assert m2_mars.size(0) == neid - nsid
    assert m2_mars.size(1) == 16
    assert (m2_mars == pc.element_mars[nsid:neid,:]).all()

    m2_flows = pc.get_node_flows(m2)
    assert m2_flows.size(0) == neid - nsid
    assert m2_flows.size(1) == 16
    assert (m2_flows == pc.element_flows[nsid:neid,:]).all()


# Allow running this test directly as a script (outside of pytest).
if __name__ == "__main__":
    test_tensorcircuit_fns()

0 comments on commit ca011b1

Please sign in to comment.