Skip to content

Commit

Permalink
fix ll weighting
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Jun 13, 2024
1 parent b37a236 commit cdee542
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/pyjuice/model/tensorcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def _set_root_node_flows():
self.node_flows[self._root_node_range[0]:self._root_node_range[1],:] = root_flows
else:
if ll_weights.dim() == 1:
ll_weights = ll_weights.unsqueeze(1)
ll_weights = ll_weights.unsqueeze(0)

assert ll_weights.size(0) == self.num_root_nodes

Expand Down Expand Up @@ -519,6 +519,12 @@ def print_statistics(self):
print(f"> Number of edges: {self.num_edges}")
print(f"> Number of sum parameters: {self.num_sum_params}")

def get_node_mars(self, ns: CircuitNodes):
pass

def get_node_flows(self, ns: CircuitNodes):
pass

def enable_partial_evaluation(self, scopes: Union[Sequence[BitSet],Sequence[int]],
forward: bool = False, backward: bool = False, overwrite: bool = False):
# Create scope2nid cache
Expand Down
2 changes: 1 addition & 1 deletion tests/layer/propagation_algs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_general_ll_prop():

my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows

assert torch.all(torch.abs(my_pflows - param_flows) < 4e-3)
assert torch.all(torch.abs(my_pflows - param_flows) < 6e-3)


def test_mpe_prop():
Expand Down

0 comments on commit cdee542

Please sign in to comment.