Skip to content

Commit

Permalink
docstrings for node classes
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Feb 14, 2024
1 parent 84d176d commit 91d16f8
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 8 deletions.
48 changes: 45 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,51 @@
# autosummary_generate = True

# Members autodoc should document, keyed by member name. The value is a
# substring that must appear in `str(obj)` (the owning class's repr) for the
# member to be kept; the empty string matches any owner.
_DOCUMENTED_MEMBERS = {
    "need_meta_parameters": "",
    "duplicate": "Nodes",
    "get_params": "Nodes",
    "set_params": "Nodes",
    "set_meta_params": "InputNodes",
    "init_parameters": "Nodes",
    "num_nodes": "Nodes",
    "num_edges": "Nodes",
    "edge_type": "ProdNodes",
    "is_block_sparse": "ProdNodes",
    "is_sparse": "ProdNodes",
    "update_parameters": "SumNodes",
    "update_param_flows": "SumNodes",
    "gather_parameters": "SumNodes",
}


def skip(app, what, name, obj, would_skip, options):
    """
    `autodoc-skip-member` callback: skip every member except the explicitly
    whitelisted node-class API listed in `_DOCUMENTED_MEMBERS`.

    :param app: the Sphinx application object
    :param what: the type of the parent object (unused)
    :param name: the member name being considered
    :param obj: the member object; its `str()` is matched against the
        required owner substring
    :param would_skip: autodoc's own skip decision
    :param options: the autodoc directive options (unused)
    :returns: `True` to skip the member, otherwise `would_skip`
    """
    required_owner = _DOCUMENTED_MEMBERS.get(name)
    if required_owner is not None and required_owner in str(obj):
        # Whitelisted member of the right class: defer to autodoc's decision.
        return would_skip
    # Everything else is skipped unconditionally.
    return True

def setup(app):
    """
    Sphinx extension entry point: register the `skip` callback on the
    `autodoc-skip-member` event.
    """
    app.connect('autodoc-skip-member', skip)
59 changes: 56 additions & 3 deletions src/pyjuice/nodes/input_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@


class InputNodes(CircuitNodes):
"""
A class representing vectors of input nodes.
:param num_node_blocks: number of node blocks
:type num_node_blocks: int
:param scope: variable scope (set of variables)
:type scope: Union[Sequence,BitSet]
:param dist: input distribution
:type dist: Distribution
:param params: parameters of the vector of nodes
:type params: Optional[Tensor]
:param block_size: block size
:type block_size: int
"""
def __init__(self, num_node_blocks: int, scope: Union[Sequence,BitSet], dist: Distribution,
params: Optional[torch.Tensor] = None, block_size: int = 0,
_no_set_meta_params: bool = False, **kwargs) -> None:
Expand Down Expand Up @@ -38,7 +56,18 @@ def __init__(self, num_node_blocks: int, scope: Union[Sequence,BitSet], dist: Di
def num_edges(self):
    """
    Number of edges within the current node. Input nodes have no child
    nodes, so this is always 0.
    """
    return 0

def duplicate(self, scope: Optional[Union[int,Sequence,BitSet]] = None, tie_params: bool = False):
def duplicate(self, scope: Optional[Union[int,Sequence,BitSet]] = None, tie_params: bool = False) -> InputNodes:
"""
Create a duplication of the current node with the same specification (i.e., number of nodes, block size, distribution).
:param scope: variable scope of the duplication
:type scope: Optional[Union[int,Sequence,BitSet]]
:param tie_params: whether to tie the parameters of the current node and the duplicated node
:type tie_params: bool
:returns: a duplicated `InputNodes`
"""
if scope is None:
scope = self.scope
else:
Expand All @@ -56,13 +85,25 @@ def duplicate(self, scope: Optional[Union[int,Sequence,BitSet]] = None, tie_para

return ns

def get_params(self) -> Optional[torch.Tensor]:
    """
    Get the input node parameters.

    :returns: the parameter tensor, or `None` if no parameters have been
        provided for this node
    """
    # Return annotation is Optional: the `None` branch below is reachable
    # whenever `_params` was never set.
    if not self.provided("_params"):
        return None
    else:
        return self._params

def set_params(self, params: Union[torch.Tensor,Dict], normalize: bool = True):
"""
Set the input node parameters.
:param params: parameters to be set
:type params: Union[torch.Tensor,Dict]
:param normalize: whether to normalize the parameters
:type normalize: bool
"""
assert params.numel() == self.num_nodes * self.dist.num_parameters()

params = params.reshape(-1)
Expand All @@ -73,13 +114,25 @@ def set_params(self, params: Union[torch.Tensor,Dict], normalize: bool = True):
self._params = params

def set_meta_params(self, **kwargs):
    """
    Set meta-parameters of the input nodes, e.g., the mask used by the
    `MaskedCategorical` distribution. Marks the parameters as uninitialized.
    """
    meta_params = self.dist.set_meta_parameters(self.num_nodes, **kwargs)

    self._params = meta_params
    self._param_initialized = False

def init_parameters(self, perturbation: float = 2.0, recursive: bool = True,
is_root: bool = True, ret_params: bool = False, **kwargs):
is_root: bool = True, ret_params: bool = False, **kwargs) -> None:
"""
Randomly initialize node parameters.
:param perturbation: "amount of perturbation" added to the parameters (should be greater than 0)
:type perturbation: float
:param recursive: whether to recursively apply the function to child nodes
:type recursive: bool
"""
if not self.is_tied() and not self.has_params():
self._params = self.dist.init_parameters(
num_nodes = self.num_nodes,
Expand Down
3 changes: 3 additions & 0 deletions src/pyjuice/nodes/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def num_chs(self):

@property
def num_nodes(self):
    """
    Total number of PC nodes within the current node, i.e., `block_size`
    nodes for each of the `num_node_blocks` blocks.
    """
    return self.block_size * self.num_node_blocks

@property
Expand Down
49 changes: 49 additions & 0 deletions src/pyjuice/nodes/prod_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@


class ProdNodes(CircuitNodes):
"""
A class representing vectors of product nodes.
:param num_node_blocks: number of node blocks
:type num_node_blocks: int
:param chs: sequence of child nodes
:type chs: Sequence[CircuitNodes]
:param edge_ids: a matrix of size [# product node blocks, # children] - the ith product node block is connected to the `edge_ids[i,j]`th node block in the jth child
:type edge_ids: Optional[Tensor]
:param block_size: block size
:type block_size: int
"""

SPARSE = 0
BLOCK_SPARSE = 1
Expand All @@ -33,10 +48,16 @@ def __init__(self, num_node_blocks: int, chs: Sequence[CircuitNodes], edge_ids:

@property
def num_edges(self):
    """
    Total number of edges within the current node: one edge from every
    node to each of its `num_chs` children.
    """
    return self.num_chs * self.num_nodes

@property
def edge_type(self):
"""
Type of the product edge. Either `BLOCK_SPARSE` or `SPARSE`.
"""
if self.edge_ids.size(0) == self.num_node_blocks:
return self.BLOCK_SPARSE
elif self.edge_ids.size(0) == self.num_nodes:
Expand All @@ -45,12 +66,31 @@ def edge_type(self):
raise RuntimeError(f"Unexpected shape of `edge_ids`: ({self.edge_ids.size(0)}, {self.edge_ids.size(1)})")

def is_block_sparse(self):
    """
    Return whether the product edges are stored in `BLOCK_SPARSE` format.
    """
    return self.BLOCK_SPARSE == self.edge_type

def is_sparse(self):
    """
    Return whether the product edges are stored in `SPARSE` format.
    """
    return self.SPARSE == self.edge_type

def duplicate(self, *args, tie_params: bool = False, allow_type_mismatch: bool = False):
"""
Create a duplication of the current node with the same specification (i.e., number of nodes, block size).
:note: The child nodes should have the same specifications as the original child nodes.
:param args: a sequence of new child nodes
:type args: CircuitNodes
:param tie_params: whether to tie the parameters of the current node and the duplicated node
:type tie_params: bool
:returns: a duplicated `ProdNodes`
"""
chs = []
for ns in args:
assert isinstance(ns, CircuitNodes)
Expand All @@ -73,6 +113,15 @@ def duplicate(self, *args, tie_params: bool = False, allow_type_mismatch: bool =
return ProdNodes(self.num_node_blocks, chs, edge_ids, block_size = self.block_size, source_node = self if tie_params else None)

def init_parameters(self, perturbation: float = 2.0, recursive: bool = True, is_root: bool = True, **kwargs):
"""
Randomly initialize node parameters.
:param perturbation: "amount of perturbation" added to the parameters (should be greater than 0)
:type perturbation: float
:param recursive: whether to recursively apply the function to child nodes
:type recursive: bool
"""
super(ProdNodes, self).init_parameters(
perturbation = perturbation,
recursive = recursive,
Expand Down
84 changes: 82 additions & 2 deletions src/pyjuice/nodes/sum_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@


class SumNodes(CircuitNodes):
"""
A class representing vectors of sum nodes.
:param num_node_blocks: number of node blocks
:type num_node_blocks: int
:param chs: sequence of child nodes
:type chs: Sequence[CircuitNodes]
:param edge_ids: a matrix of size [2, # edges] - every size-2 column vector [i,j] defines a set of edges that fully connect the ith sum node block and the jth child node block
:type edge_ids: Optional[Tensor]
:param block_size: block size
:type block_size: int
"""

def __init__(self, num_node_blocks: int, chs: Sequence[CircuitNodes], edge_ids: Optional[Union[Tensor,Sequence[Tensor]]] = None,
params: Optional[Tensor] = None, zero_param_mask: Optional[Tensor] = None, block_size: int = 0, **kwargs) -> None:

Expand Down Expand Up @@ -50,9 +66,25 @@ def __init__(self, num_node_blocks: int, chs: Sequence[CircuitNodes], edge_ids:

@property
def num_edges(self):
    """
    Total number of sum edges: each of the `edge_ids.size(1)` block-pair
    connections fully connects a `block_size` x `ch_block_size` pair.
    """
    return self.block_size * self.ch_block_size * self.edge_ids.size(1)

def duplicate(self, *args, tie_params: bool = False):
def duplicate(self, *args, tie_params: bool = False) -> SumNodes:
"""
Create a duplication of the current node with the same specification (i.e., number of nodes, block size).
:note: The child nodes should have the same specifications as the original child nodes.
:param args: a sequence of new child nodes
:type args: CircuitNodes
:param tie_params: whether to tie the parameters of the current node and the duplicated node
:type tie_params: bool
:returns: a duplicated `SumNodes`
"""
chs = []
for ns in args:
assert isinstance(ns, CircuitNodes)
Expand All @@ -78,11 +110,26 @@ def duplicate(self, *args, tie_params: bool = False):
return SumNodes(self.num_node_blocks, chs, edge_ids, params = params, block_size = self.block_size, source_node = self if tie_params else None)

def get_params(self) -> Optional[torch.Tensor]:
    """
    Get the sum node parameters.

    :returns: the parameter tensor, or `None` if no parameters have been set
    """
    # NOTE(review): `InputNodes.get_params` uses `self.provided("_params")`
    # for the same check — consider unifying the two implementations.
    if not hasattr(self, "_params"):
        return None
    return self._params

def set_params(self, params: torch.Tensor, normalize: bool = True, pseudocount: float = 0.1):
def set_params(self, params: torch.Tensor, normalize: bool = True, pseudocount: float = 0.0):
"""
Set the sum node parameters.
:param params: parameters to be set
:type params: Union[torch.Tensor,Dict]
:param normalize: whether to normalize the parameters
:type normalize: bool
:param pseudocount: pseudo count added to the parameters
:type pseudocount: float
"""
if self._source_node is not None:
ns_source = self._source_node
ns_source.set_params(params, normalize = normalize, pseudocount = pseudocount)
Expand Down Expand Up @@ -153,6 +200,15 @@ def set_edges(self, edge_ids: Union[Tensor,Sequence[Tensor]]):
self._params = None # Clear parameters

def init_parameters(self, perturbation: float = 2.0, recursive: bool = True, is_root: bool = True, **kwargs):
"""
Randomly initialize node parameters.
:param perturbation: "amount of perturbation" added to the parameters (should be greater than 0)
:type perturbation: float
:param recursive: whether to recursively apply the function to child nodes
:type recursive: bool
"""
if self._source_node is None:
self._params = torch.exp(torch.rand([self.edge_ids.size(1), self.block_size, self.ch_block_size]) * -perturbation)

Expand All @@ -170,6 +226,15 @@ def init_parameters(self, perturbation: float = 2.0, recursive: bool = True, is_
)

def update_parameters(self, params: torch.Tensor, clone: bool = True):
"""
Update parameters from `pyjuice.TensorCircuit` to the current node.
:param params: the parameter tensor in the `TensorCircuit`
:type params: torch.Tensor
:param clone: whether to clone the parameters
:type clone: bool
"""
assert self.provided("_param_range"), "The `SumNodes` has not been compiled into a `TensorCircuit`."

if self.is_tied():
Expand All @@ -188,6 +253,15 @@ def update_parameters(self, params: torch.Tensor, clone: bool = True):
self._params = ns_params[local_parids,:,:].permute(0, 2, 1)

def update_param_flows(self, param_flows: torch.Tensor, origin_ns_only: bool = True, clone: bool = True):
"""
Update parameter flows from `pyjuice.TensorCircuit` to the current node.
:param param_flows: the parameter flow tensor in the `TensorCircuit`
:type param_flows: torch.Tensor
:param clone: whether to clone the parameters
:type clone: bool
"""
assert self.provided("_param_flow_range"), "The `SumNodes` has not been compiled into a `TensorCircuit`."

if origin_ns_only and self.is_tied():
Expand All @@ -205,6 +279,12 @@ def update_param_flows(self, param_flows: torch.Tensor, origin_ns_only: bool = T
self._param_flows = ns_param_flows[local_parfids,:,:].permute(0, 2, 1)

def gather_parameters(self, params: torch.Tensor):
"""
Update parameters from the current node to the compiled `pyjuice.TensorCircuit`.
:param params: the parameter tensor in the `TensorCircuit`
:type params: torch.Tensor
"""
assert self.provided("_param_range"), "The `SumNodes` has not been compiled into a `TensorCircuit`."

if self.is_tied() or not self.has_params():
Expand Down

0 comments on commit 91d16f8

Please sign in to comment.