Skip to content

Commit

Permalink
comments for mars/flows retrieval function
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Jun 20, 2024
1 parent ca011b1 commit 8114cde
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/pyjuice/model/tensorcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,12 @@ def print_statistics(self):
print(f"> Number of sum parameters: {self.num_sum_params}")

def get_node_mars(self, ns: CircuitNodes):
"""
Retrieve the node values of `ns` from the previous forward pass.
:params ns: the target nodes
:type ns: CircuitNodes
"""
assert self.root_ns.contains(ns)
assert hasattr(self, "node_mars") and self.node_mars is not None
assert hasattr(self, "element_mars") and self.element_mars is not None
Expand Down Expand Up @@ -555,6 +561,12 @@ def get_node_mars(self, ns: CircuitNodes):
return self.element_mars[nsid:neid,:].detach()

def get_node_flows(self, ns: CircuitNodes, **kwargs):
"""
Retrieve the node flows of `ns` from the previous backward pass.
:params ns: the target nodes
:type ns: CircuitNodes
"""
assert self.root_ns.contains(ns)
assert hasattr(self, "node_flows") and self.node_flows is not None
assert hasattr(self, "element_flows") and self.element_flows is not None
Expand Down

0 comments on commit 8114cde

Please sign in to comment.