Skip to content

Commit

Permalink
add a flag to disable buffer initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Oct 16, 2024
1 parent aea003a commit 41994f7
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/pyjuice/model/tensorcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def set_propagation_alg(self, propagation_alg: str, **kwargs):
def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Callable]] = None,
cache: Optional[dict] = None, return_cache: bool = False, record_cudagraph: bool = False,
apply_cudagraph: bool = True, force_use_bf16: bool = False, force_use_fp32: bool = False,
propagation_alg: Optional[Union[str,Sequence[str]]] = None, _inner_layers_only: bool = False, **kwargs):
propagation_alg: Optional[Union[str,Sequence[str]]] = None, _inner_layers_only: bool = False,
_no_buffer_reset: bool = False, **kwargs):
"""
Forward evaluation of the PC.
Expand All @@ -206,8 +207,9 @@ def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Calla

## Initialize buffers for forward pass ##

self._init_buffer(name = "node_mars", shape = (self.num_nodes, B), set_value = 0.0)
self._init_buffer(name = "element_mars", shape = (self.num_elements, B), set_value = -torch.inf)
if not _no_buffer_reset:
self._init_buffer(name = "node_mars", shape = (self.num_nodes, B), set_value = 0.0)
self._init_buffer(name = "element_mars", shape = (self.num_elements, B), set_value = -torch.inf)

# Load cached node marginals
if self._buffer_matches(name = "node_mars", cache = cache):
Expand Down

0 comments on commit 41994f7

Please sign in to comment.