Skip to content

Commit

Permalink
docstring for common transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Feb 14, 2024
1 parent 1294634 commit bc51a5d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 8 deletions.
36 changes: 35 additions & 1 deletion src/pyjuice/transformations/blockify.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,24 @@ def _copy_params_kernel(new_params, params, target_id0, target_id1, target_id2,
tl.store(new_params + offs_npars, pars)


def blockify(root_ns: CircuitNodes, sparsity_tolerance: float = 0.25, max_target_block_size: int = 32, use_cuda: bool = True):
def blockify(root_ns: CircuitNodes, sparsity_tolerance: float = 0.25, max_target_block_size: int = 32, use_cuda: bool = True) -> CircuitNodes:
"""
Generate an equivalent PC with potentially high block sizes.
:param root_ns: the input PC
:type root_ns: CircuitNodes
:param sparsity_tolerance: allowed fraction of zero parameters to be added (should be in the range (0, 1])
:type sparsity_tolerance: float
:param max_target_block_size: the maximum block size to search for
:type max_target_block_size: int
:param use_cuda: use GPU when possible
:type use_cuda: bool
:returns: An equivalent `CircuitNodes`
"""

if use_cuda:
device = torch.device("cuda:0")
Expand Down Expand Up @@ -322,6 +339,23 @@ def update_ns(ns: CircuitNodes, ns_chs: Sequence[CircuitNodes]):


def unblockify(root_ns: CircuitNodes, block_size: int = 1, recursive: bool = True, keys_to_copy: Optional[Sequence[str]] = None):
"""
Decrease the block size of a PC.
:param root_ns: the input PC
:type root_ns: CircuitNodes
:param block_size: the target block size
:type block_size: int
:param recursive: whether to do it recursively or just for the current node
:type recursive: bool
:param keys_to_copy: an optional dictionary of properties to copy
:type keys_to_copy: Optional[Sequence[str]]
:returns: An equivalent `CircuitNodes`
"""

def update_ns(ns: CircuitNodes, ns_chs: Sequence[CircuitNodes]):
new_block_size = min(block_size, ns.block_size)
Expand Down
23 changes: 19 additions & 4 deletions src/pyjuice/transformations/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,23 @@
from pyjuice.utils import BitSet


def deepcopy(root_nodes: CircuitNodes, tie_params: bool = False,
var_mapping: Optional[Dict[int,int]] = None):
def deepcopy(root_ns: CircuitNodes, tie_params: bool = False,
var_mapping: Optional[Dict[int,int]] = None) -> CircuitNodes:
"""
Create a deepcopy of the input PC.
:param root_ns: the input PC
:type root_ns: CircuitNodes
:param tie_params: whether to tie the parameters between the original PC and the copied PC (if tied, their parameters will always be the same)
:type tie_params: bool
:param var_mapping: a mapping dictionary between the variables of the original PC and the copied PC
:type var_mapping: Optional[Dict[int,int]]
:returns: a copied PC
"""

old2new = dict()
tied_ns_pairs = []

Expand Down Expand Up @@ -76,12 +91,12 @@ def dfs(ns: CircuitNodes):

old2new[ns] = new_ns

dfs(root_nodes)
dfs(root_ns)

for ns, source_ns in tied_ns_pairs:
new_ns = old2new[ns]
new_source_ns = old2new[source_ns]

new_ns._source_node = new_source_ns

return old2new[root_nodes]
return old2new[root_ns]
26 changes: 23 additions & 3 deletions src/pyjuice/transformations/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,32 @@ def merge_by_region_node(root_ns: CircuitNodes) -> CircuitNodes:
def merge(ns1: CircuitNodes, *args) -> CircuitNodes:
"""
Merge nodes with identical region node together.
:param ns1: the first PC node
:type ns1: CircuitNodes
:param args: the remaining PC nodes
:type args: CircuitNodes
Example::
>>> i00 = inputs(0, num_node_blocks, dists.Categorical(num_cats = 5))
>>> i01 = inputs(0, num_node_blocks, dists.Categorical(num_cats = 5))
>>> i10 = inputs(1, num_node_blocks, dists.Categorical(num_cats = 5))
>>> i11 = inputs(1, num_node_blocks, dists.Categorical(num_cats = 5))
>>> m00 = multiply(i00, i10)
>>> m01 = multiply(i01, i11)
>>> n0 = summate(m00, num_node_blocks = num_node_blocks)
>>> n1 = summate(m01, num_node_blocks = num_node_blocks)
>>> n_new = pyjuice.merge(n0, n1)
"""
if isinstance(ns1, SumNodes) and len(args) > 0 and isinstance(args[0], SumNodes):
if ns1.is_sum() and len(args) > 0 and args[0].is_sum():
return merge_sum_nodes(ns1, args[0], *args[1:])
elif isinstance(ns1, ProdNodes) and len(args) > 0 and isinstance(args[0], ProdNodes):
elif ns1.is_prod() and len(args) > 0 and args[0].is_prod():
return merge_prod_nodes(ns1, args[0], *args[1:])
elif len(args) == 0:
return merge_by_region_node(ns1)
else:
raise NotImplementedError()
raise ValueError()

0 comments on commit bc51a5d

Please sign in to comment.